fluke.server

The module fluke.server provides the base classes for the servers in fluke.

Classes

class fluke.server.Server

fit

Run the federated learning algorithm.

broadcast_model

Broadcast the global model to the clients.

get_eligible_clients

Get the clients that will participate in the current round.

get_client_models

Retrieve the models of the clients.

_get_client_weights

Get the weights of the clients for the aggregation.

aggregate

Aggregate the models of the clients.

finalize

Finalize the federated learning process.

class fluke.server.Server(model: torch.nn.Module, test_set: FastDataLoader, clients: Iterable[Client], weighted: bool = False, lr: float = 1.0)

Basic Server for Federated Learning. This class is the base class for all servers in fluke. It implements the basic functionalities of a federated learning server. The default behaviour of this server is based on the Federated Averaging algorithm. The server is responsible for coordinating the learning process, selecting the clients for each round, sending the global model to the clients, and aggregating the models received from the clients at the end of the round. The server also evaluates the model server-side (if the test data is provided).

hyper_params

The hyper-parameters of the server. The default hyper-parameters are:

  • weighted: A boolean indicating whether the clients should be weighted by their number of samples when aggregating the models.

When a new server class inherits from this class, it must add all its hyper-parameters to this dictionary.

Type:

DDict

device

The device where the server runs.

Type:

torch.device

model

The federated model to be trained.

Type:

torch.nn.Module

clients

The clients that will participate in the federated learning process.

Type:

Iterable[Client]

rounds

The number of rounds that have been executed.

Type:

int

test_set

The test data to evaluate the model. If None, the model will not be evaluated server-side.

Type:

FastDataLoader

evaluator

The evaluator to compute the evaluation metrics.

Type:

Evaluator

Parameters:
  • model (torch.nn.Module) – The federated model to be trained.

  • test_set (FastDataLoader) – The test data to evaluate the model.

  • clients (Iterable[Client]) – The clients that will participate in the federated learning process.

  • evaluator (Evaluator) – The evaluator to compute the evaluation metrics.

  • eval_every (int) – The number of rounds between evaluations. Defaults to 1.

  • weighted (bool) – A boolean indicating whether the clients should be weighted by their number of samples when aggregating the models. Defaults to False.

_get_client_weights(eligible: Iterable[Client])

Get the weights of the clients for the aggregation. If the hyperparameter weighted is True, each client's weight is proportional to its number of samples; otherwise, all clients have the same weight. In both cases the weights sum to one.

Caution

The computation of the weights does not adhere to the “best practices” of fluke, because the server should not have direct access to the number of samples of the clients. Strictly, the weights should be computed by communicating with the clients through the channel; for simplicity, this practice is not followed here. The communication overhead of such an exchange would be negligible, so skipping it does not affect the logged performance.

Parameters:

eligible (Iterable[Client]) – The clients that will participate in the aggregation.

Returns:

The weights of the clients.

Return type:

list[float]
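The weighting rule above can be illustrated with a minimal plain-Python sketch. The function name client_weights and its n_samples argument (the number of training samples of each eligible client) are hypothetical and not part of fluke's API; the sketch only shows the arithmetic, not how fluke reads the sample counts.

```python
from typing import Sequence


def client_weights(n_samples: Sequence[int], weighted: bool) -> list[float]:
    """Hypothetical helper: return aggregation weights that sum to 1.

    n_samples holds the number of training samples of each eligible client.
    """
    n = len(n_samples)
    if n == 0:
        raise ValueError("at least one client is required")
    if not weighted:
        # Uniform weighting: every client counts equally.
        return [1.0 / n] * n
    total = sum(n_samples)
    # Sample weighting: each weight is proportional to the client's data size.
    return [k / total for k in n_samples]


print(client_weights([100, 300], weighted=False))  # [0.5, 0.5]
print(client_weights([100, 300], weighted=True))   # [0.25, 0.75]
```

With weighted=True, a client holding three times as much data contributes three times as much to the aggregated model.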

aggregate(eligible: Iterable[Client]) → None

Aggregate the models of the clients. The aggregation is done by averaging the models of the clients. If the hyperparameter weighted is True, the clients are weighted by their number of samples. The method directly updates the model of the server. Formally, let \(\theta\) be the model of the server, \(\theta_i\) the model of client \(i\), and \(w_i\) the weight of client \(i\) such that \(\sum_{i=1}^{N} w_i = 1\). The aggregation is done as follows [FedAVG]:

\[\theta = \sum_{i=1}^{N} w_i \theta_i\]

Note

For networks with batch normalization layers, the running statistics of the batch normalization layers are also aggregated. For every statistic except num_batches_tracked, the mean is computed; for num_batches_tracked, the maximum between the server's value and the truncated mean of the clients' values is taken.
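The aggregation rule and the batch normalization special case can be sketched in plain Python as follows. This is an illustrative sketch, not fluke's implementation: the function name aggregate_state_dicts is hypothetical, each state-dict entry is modelled as a flat list of numbers instead of a tensor so that the code runs without PyTorch, and it assumes that batch normalization running statistics are averaged with the same weights \(w_i\) as the other parameters.

```python
def aggregate_state_dicts(
    server_state: dict[str, list[float]],
    client_states: list[dict[str, list[float]]],
    weights: list[float],
) -> dict[str, list[float]]:
    """Hypothetical helper: FedAvg-style weighted average of client states."""
    new_state: dict[str, list[float]] = {}
    for key, server_values in server_state.items():
        if key.endswith("num_batches_tracked"):
            # Integer counter: mean over clients truncated toward zero,
            # then the maximum with the server's current value.
            mean = sum(cs[key][0] for cs in client_states) / len(client_states)
            new_state[key] = [max(server_values[0], int(mean))]
        else:
            # theta = sum_i w_i * theta_i, computed element by element.
            new_state[key] = [
                sum(w * cs[key][j] for w, cs in zip(weights, client_states))
                for j in range(len(server_values))
            ]
    return new_state


server = {"fc.weight": [0.0, 0.0], "bn.num_batches_tracked": [3]}
clients = [
    {"fc.weight": [1.0, 2.0], "bn.num_batches_tracked": [4]},
    {"fc.weight": [3.0, 6.0], "bn.num_batches_tracked": [7]},
]
print(aggregate_state_dicts(server, clients, [0.5, 0.5]))
# → {'fc.weight': [2.0, 4.0], 'bn.num_batches_tracked': [5]}
```

Here the counter becomes max(3, int((4 + 7) / 2)) = 5, while the weights are averaged element by element.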

Parameters:

eligible (Iterable[Client]) – The clients that will participate in the aggregation.

References

[FedAVG]

H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. Agüera y Arcas, “Communication-Efficient Learning of Deep Networks from Decentralized Data”. In AISTATS (2017).

broadcast_model(eligible: Iterable[Client]) → None

Broadcast the global model to the clients.

Parameters:

eligible (Iterable[Client]) – The clients that will receive the global model.

finalize() → None

Finalize the federated learning process. This method is called at the end of the federated learning process. Client-side evaluation is performed only for clients that have participated in at least one round.

fit(n_rounds: int = 10, eligible_perc: float = 0.1, finalize: bool = True, **kwargs: dict[str, Any]) → None

Run the federated learning algorithm. The default behaviour of this method is to run the Federated Averaging algorithm. In each round, the server selects a fraction of the clients and sends them the global model; the selected clients compute their local updates and send them back to the server. The server then aggregates the models of the clients and repeats the process for the given number of rounds. During the process, the server evaluates the global model and the local models every eval_every rounds.

Parameters:
  • n_rounds (int, optional) – The number of rounds to run. Defaults to 10.

  • eligible_perc (float, optional) – The fraction of clients selected in each round (e.g., 0.1 selects 10% of the clients). Defaults to 0.1.

  • finalize (bool, optional) – If True, the server will finalize the federated learning process. Defaults to True.

  • **kwargs – Additional keyword arguments.
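The round structure described above (select, broadcast, train locally, aggregate) can be illustrated with a self-contained toy sketch. Everything in it is hypothetical and only mirrors the steps of the description: ToyClient holds a single float as its "model" and trains by gradient descent towards the mean of its local data, toy_fit stands in for fit, selection rounds the number of clients down (with at least one), and the periodic evaluation is omitted.

```python
import random


class ToyClient:
    """Hypothetical client whose model is a single float parameter."""

    def __init__(self, data: list[float]) -> None:
        self.data = data
        self.model = 0.0

    def receive_model(self, global_model: float) -> None:
        # Overwrite the local model with the broadcast global model.
        self.model = global_model

    def local_update(self, lr: float = 0.1, epochs: int = 5) -> float:
        # Gradient descent on the loss (model - mean(data)) ** 2.
        target = sum(self.data) / len(self.data)
        for _ in range(epochs):
            self.model -= lr * 2 * (self.model - target)
        return self.model


def toy_fit(clients: list[ToyClient], n_rounds: int = 10,
            eligible_perc: float = 0.1, seed: int = 0) -> float:
    rng = random.Random(seed)
    global_model = 0.0
    for _ in range(n_rounds):
        # 1. Select a fraction of the clients (at least one).
        n = max(1, int(len(clients) * eligible_perc))
        eligible = rng.sample(clients, n)
        # 2. Broadcast the global model to the selected clients.
        for client in eligible:
            client.receive_model(global_model)
        # 3. Each selected client trains locally and returns its model.
        local_models = [client.local_update() for client in eligible]
        # 4. Aggregate with uniform weights (the weighted=False case).
        global_model = sum(local_models) / len(local_models)
    return global_model


clients = [ToyClient([float(i)]) for i in range(1, 11)]
print(toy_fit(clients, n_rounds=50, eligible_perc=1.0))  # close to 5.5
```

With every client selected in each round, the global model converges to the average of the clients' optima (5.5 here); with a smaller eligible_perc, each round only moves it towards the optima of the sampled clients.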

get_client_models(eligible: Iterable[Client], state_dict: bool = True) → list[Any]

Retrieve the models of the clients. This method assumes that the clients have already sent their models to the server.

Parameters:
  • eligible (Iterable[Client]) – The clients that will participate in the aggregation.

  • state_dict (bool, optional) – If True, the method returns the state_dict of the models. Otherwise, it returns the models. Defaults to True.

Returns:

The models of the clients.

Return type:

list[Any] (the state_dict of each model if state_dict is True, otherwise torch.nn.Module objects)

get_eligible_clients(eligible_perc: float) → Iterable[Client]

Get the clients that will participate in the current round.

Parameters:

eligible_perc (float) – The fraction of clients to select (e.g., 0.1 selects 10% of the clients).

Returns:

The clients that will participate in the current round.

Return type:

Iterable[Client]
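Client selection can be sketched as uniform sampling without replacement. The helper sample_clients below is hypothetical, and its rounding rule (round the count down, but select at least one client) is an assumption made for illustration; it is not necessarily the rule fluke applies.

```python
import random
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def sample_clients(clients: Sequence[T], eligible_perc: float,
                   rng: Optional[random.Random] = None) -> list[T]:
    """Hypothetical helper: pick a random subset of clients for one round."""
    if not 0.0 < eligible_perc <= 1.0:
        raise ValueError("eligible_perc must be in (0, 1]")
    if rng is None:
        rng = random.Random()
    # Round the count down, but always select at least one client.
    n = max(1, int(len(clients) * eligible_perc))
    # Sampling without replacement: no client is picked twice in a round.
    return rng.sample(list(clients), n)


print(len(sample_clients(list(range(20)), 0.25, random.Random(42))))  # 5
```

Passing a seeded random.Random makes the selection reproducible across runs.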