fluke.server
The module fluke.server provides the base classes for the servers in fluke.
Classes
class fluke.server.Server
Run the federated learning algorithm.
Broadcast the global model to the clients.
Get the clients that will participate in the current round.
Retrieve the models of the clients.
Get the weights of the clients for the aggregation.
Aggregate the models of the clients.
Finalize the federated learning process.
- class fluke.server.Server(model: torch.nn.Module, test_set: FastDataLoader, clients: Iterable[Client], weighted: bool = False, lr: float = 1.0)
Basic Server for Federated Learning. This class is the base class for all servers in fluke. It implements the basic functionalities of a federated learning server. The default behaviour of this server is based on the Federated Averaging algorithm. The server is responsible for coordinating the learning process, selecting the clients for each round, sending the global model to the clients, and aggregating the models received from the clients at the end of the round. The server also evaluates the model server-side (if the test data is provided).
- hyper_params
The hyper-parameters of the server. The default hyper-parameters are:
weighted: A boolean indicating if the clients should be weighted by the number of samples when aggregating the models.
When a new server class inherits from this class, it must add all its hyper-parameters to this dictionary.
- Type:
- device
The device where the server runs.
- Type:
torch.device
- model
The federated model to be trained.
- Type:
torch.nn.Module
- clients
The clients that will participate in the federated learning process.
- Type:
Iterable[Client]
- test_set
The test data to evaluate the model. If None, the model will not be evaluated server-side.
- Type:
- Parameters:
model (torch.nn.Module) – The federated model to be trained.
test_set (FastDataLoader) – The test data to evaluate the model.
clients (Iterable[Client]) – The clients that will participate in the federated learning process.
evaluator (Evaluator) – The evaluator to compute the evaluation metrics.
eval_every (int) – The number of rounds between evaluations. Defaults to 1.
weighted (bool) – A boolean indicating if the clients should be weighted by the number of samples when aggregating the models. Defaults to False.
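The requirement that subclasses register their hyper-parameters in hyper_params can be illustrated with a minimal sketch. BaseServer and MomentumServer are hypothetical stand-ins written for this example, a plain dict replaces whatever dictionary type fluke actually uses, and the momentum hyper-parameter is invented for illustration.

```python
class BaseServer:
    """Hypothetical stand-in for fluke.server.Server (not the real class)."""

    def __init__(self, weighted=False):
        # The base class registers its own hyper-parameters.
        self.hyper_params = {"weighted": weighted}


class MomentumServer(BaseServer):
    """Hypothetical subclass that introduces one extra hyper-parameter."""

    def __init__(self, weighted=False, momentum=0.9):
        super().__init__(weighted=weighted)
        # Register the new hyper-parameter alongside the inherited ones.
        self.hyper_params.update(momentum=momentum)


server = MomentumServer(weighted=True, momentum=0.5)
print(server.hyper_params)
```

Keeping every hyper-parameter in one dictionary means a run's configuration can be read from a single place, whichever subclass is in use.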
- _get_client_weights(eligible: Iterable[Client])
Get the weights of the clients for the aggregation. If the hyperparameter weighted is True, each client is weighted by its number of samples. Otherwise, all clients have the same weight.
Caution
The computation of the weights does not adhere to the “best-practices” of fluke because the server should not have direct access to the number of samples of the clients. The weights should instead be computed by communicating with the clients through the channel, but for simplicity this practice is not followed here. The skipped communication overhead is negligible and does not affect the logged performance.
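The two weighting modes can be sketched as follows. This is an illustration, not fluke's code: the function name client_weights and the use of a plain list of sample counts are assumptions made for this example.

```python
def client_weights(n_samples, weighted):
    """Return one aggregation weight per client; the weights sum to 1."""
    if weighted:
        # Each client counts in proportion to its number of samples.
        total = sum(n_samples)
        return [n / total for n in n_samples]
    # Unweighted case: every client counts the same.
    return [1.0 / len(n_samples)] * len(n_samples)


print(client_weights([10, 30, 60], weighted=True))
print(client_weights([10, 30, 60], weighted=False))
```

In both modes the weights are normalized, which matches the constraint that the weights sum to 1 in the aggregation formula.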
- aggregate(eligible: Iterable[Client]) → None
Aggregate the models of the clients. The aggregation is done by averaging the models of the clients. If the hyperparameter weighted is True, the clients are weighted by their number of samples. The method directly updates the model of the server. Formally, let \(\theta\) be the model of the server, \(\theta_i\) the model of client \(i\), and \(w_i\) the weight of client \(i\) such that \(\sum_{i=1}^{N} w_i = 1\). The aggregation is done as follows [FedAVG]:
\[\theta = \sum_{i=1}^{N} w_i \theta_i\]
Note
In case of networks with batch normalization layers, the running statistics of the batch normalization layers are also aggregated. For all statistics except num_batches_tracked, the mean is computed; for num_batches_tracked, the maximum between the server’s value and the truncated mean of the clients’ values is taken.
- Parameters:
eligible (Iterable[Client]) – The clients that will participate in the aggregation.
References
[FedAVG] H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, “Communication-Efficient Learning of Deep Networks from Decentralized Data”. In AISTATS (2017).
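The averaging rule for aggregate, including the special handling of num_batches_tracked, can be sketched on plain Python data. This is an illustration, not fluke's implementation: state dicts are modelled as dicts mapping a name to a list of floats (or to an int for the batch counter), the function name aggregate_states is hypothetical, and the mean of the counters is assumed to be unweighted.

```python
def aggregate_states(server_state, client_states, weights):
    """Weighted average of client states; returns the new server state."""
    new_state = {}
    for key, server_value in server_state.items():
        if key.endswith("num_batches_tracked"):
            # Truncated mean of the clients' counters, then max with the server's.
            mean = sum(cs[key] for cs in client_states) / len(client_states)
            new_state[key] = max(server_value, int(mean))
        else:
            # theta = sum_i w_i * theta_i, computed element by element.
            new_state[key] = [
                sum(w * cs[key][j] for w, cs in zip(weights, client_states))
                for j in range(len(server_value))
            ]
    return new_state


server = {"fc.weight": [0.0, 0.0], "bn.num_batches_tracked": 5}
clients = [
    {"fc.weight": [1.0, 2.0], "bn.num_batches_tracked": 3},
    {"fc.weight": [3.0, 6.0], "bn.num_batches_tracked": 4},
]
print(aggregate_states(server, clients, weights=[0.5, 0.5]))
```

With equal weights the parameters become the plain mean of the two clients, while the batch counter keeps the server's value because it exceeds the truncated client mean.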
- broadcast_model(eligible: Iterable[Client]) → None
Broadcast the global model to the clients.
- Parameters:
eligible (Iterable[Client]) – The clients that will receive the global model.
- finalize() → None
Finalize the federated learning process. The finalize method is called at the end of the federated learning process. The client-side evaluation is only done if the client has participated in at least one round.
- fit(n_rounds: int = 10, eligible_perc: float = 0.1, finalize: bool = True, **kwargs: dict[str, Any]) → None
Run the federated learning algorithm. The default behaviour of this method is to run the Federated Averaging algorithm. The server selects a percentage of the clients to participate in each round and sends them the global model; the clients compute their local updates and send them back to the server. The server aggregates the models of the clients and repeats the process for a number of rounds. During the process, the server evaluates the global model and the local models every eval_every rounds.
- Parameters:
n_rounds (int, optional) – The number of rounds to run. Defaults to 10.
eligible_perc (float, optional) – The fraction of clients that will be selected for each round (for example, 0.1 selects 10% of the clients). Defaults to 0.1.
finalize (bool, optional) – If True, the server will finalize the federated learning process. Defaults to True.
**kwargs – Additional keyword arguments.
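The round structure that fit describes (select, broadcast, local update, aggregate) can be sketched as a toy simulation on a scalar model. Everything here is an assumption made for illustration: the function run_rounds, the "move halfway towards a target" local update, the unweighted average, and the rule that at least one client is selected per round. None of it is taken from fluke's source.

```python
import random


def run_rounds(global_model, client_targets, n_rounds=10, eligible_perc=0.1, seed=0):
    """Toy federated loop: each client pulls the model towards its own target."""
    rng = random.Random(seed)
    n_clients = len(client_targets)
    # Assumption: at least one client is selected in every round.
    n_eligible = max(1, int(eligible_perc * n_clients))
    for _ in range(n_rounds):
        # 1. Select the clients that participate in this round.
        eligible = rng.sample(range(n_clients), n_eligible)
        local_models = []
        for i in eligible:
            # 2. Broadcast: the client starts from the current global model.
            received = global_model
            # 3. Local update: move halfway towards the client's target.
            local_models.append(received + 0.5 * (client_targets[i] - received))
        # 4. Aggregate: unweighted mean of the returned models.
        global_model = sum(local_models) / len(local_models)
    return global_model


print(run_rounds(0.0, [4.0] * 10, n_rounds=10, eligible_perc=0.1))
```

Because every client shares the same target in this usage example, the global model approaches that target regardless of which clients are sampled.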
- get_client_models(eligible: Iterable[Client], state_dict: bool = True) → list[Any]
Retrieve the models of the clients. This method assumes that the clients have already sent their models to the server.
- Parameters:
- Returns:
The models of the clients.
- Return type:
list[torch.nn.Module]