fluke.server
The module fluke.server provides the base classes for the servers in fluke.
Classes
class fluke.server.Server
Aggregate the models of the clients.
Broadcast the global model to the clients.
The channel to communicate with the clients.
Evaluate the global federated model on the test_set.
Finalize the federated learning process.
Run the federated learning algorithm.
Get the clients that will participate in the current round.
Return whether the server can evaluate the model.
Return whether the server owns a global model.
Load the server's state from file.
Retrieve the models of the clients.
Save the server's state to file.
Return the server's state as a dictionary.
Get the weights of the clients for the aggregation.
- class fluke.server.Server(model: torch.nn.Module, test_set: FastDataLoader, clients: Iterable[Client], weighted: bool = False, lr: float = 1.0, **kwargs: dict[str, Any])
Basic Server for Federated Learning.
This class is the base class for all servers in fluke. It implements the basic functionalities of a federated learning server, and its default behaviour is based on the Federated Averaging algorithm. The server coordinates the learning process: it selects the clients for each round, sends the global model to them, and aggregates the models received from the clients at the end of the round. If test data is provided, the server also evaluates the model server-side.
- hyper_params
The hyper-parameters of the server. The default hyper-parameters are:
weighted: A boolean indicating if the clients should be weighted by the number of samples when aggregating the models.
When a new server class inherits from this class, it must add all its hyper-parameters to this dictionary.
- device
The device where the server runs.
- model
The federated model to be trained.
- Type:
torch.nn.Module
- clients
The clients that will participate in the federated learning process.
- Type:
Iterable[Client]
- test_set
The test data used to evaluate the model. If None, the model is not evaluated server-side.
- Type:
FastDataLoader
- Parameters:
model (torch.nn.Module) – The federated model to be trained.
test_set (FastDataLoader) – The test data to evaluate the model.
clients (Iterable[Client]) – The clients that will participate in the federated learning process.
weighted (bool) – A boolean indicating if the clients should be weighted by the number of samples when aggregating the models. Defaults to False.
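The rule that a subclass must register its hyper-parameters in hyper_params can be illustrated with a minimal, self-contained sketch. It assumes hyper_params behaves like a plain dict; SketchServer, SketchCustomServer, and the extra hyper-parameter mu are hypothetical names used only for illustration, not part of fluke's API.

```python
from typing import Any


class SketchServer:
    """Hypothetical stand-in for a base server; not fluke's actual Server class."""

    def __init__(self, weighted: bool = False, **kwargs: Any) -> None:
        # The base class registers its own hyper-parameter.
        self.hyper_params: dict[str, Any] = {"weighted": weighted}


class SketchCustomServer(SketchServer):
    """Hypothetical subclass that introduces one extra hyper-parameter, `mu`."""

    def __init__(self, mu: float = 0.01, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Subclasses add their own hyper-parameters to the same dictionary.
        self.hyper_params.update({"mu": mu})


server = SketchCustomServer(mu=0.1, weighted=True)
print(server.hyper_params)  # {'weighted': True, 'mu': 0.1}
```

Because the subclass extends the dictionary instead of replacing it, the base hyper-parameter weighted stays available alongside the new one.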
- _get_client_weights(eligible: Iterable[Client]) → list[float]
Get the weights of the clients for the aggregation. If the hyperparameter weighted is True, each client is weighted by its number of samples; otherwise, all clients have the same weight.
Caution
The computation of the weights does not adhere to the “best practices” of fluke, because the server should not have direct access to the number of samples of the clients. The weights should instead be computed by communicating with the clients through the channel but, for simplicity, this practice is not followed here. The communication overhead of doing so would be negligible, so skipping it does not affect the logged performance.
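A minimal sketch of the weighting rule, working on a plain list of per-client sample counts rather than on Client objects. The function name client_weights is hypothetical, and normalizing the weights so that they sum to 1 is an assumption consistent with the constraint stated in the aggregation formula of aggregate.

```python
def client_weights(num_samples: list[int], weighted: bool) -> list[float]:
    """Hypothetical sketch: per-client aggregation weights that sum to 1."""
    n_clients = len(num_samples)
    if weighted:
        total = sum(num_samples)
        # Each client's weight is proportional to its number of samples.
        return [n / total for n in num_samples]
    # Uniform weights: every client contributes equally.
    return [1.0 / n_clients] * n_clients


print(client_weights([100, 300], weighted=True))   # [0.25, 0.75]
print(client_weights([100, 300], weighted=False))  # [0.5, 0.5]
```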
- aggregate(eligible: Iterable[Client], client_models: Iterable[torch.nn.Module]) → None
Aggregate the models of the clients. The aggregation is done by averaging the models of the clients. If the hyperparameter weighted is True, the clients are weighted by their number of samples. The method directly updates the model of the server. Formally, let \(\theta\) be the model of the server, \(\theta_i\) the model of client \(i\), \(N\) the number of clients taking part in the aggregation, and \(w_i\) the weight of client \(i\) such that \(\sum_{i=1}^{N} w_i = 1\). The aggregation is done as follows [FedAVG]:
\[\theta = \sum_{i=1}^{N} w_i \theta_i\]
Note
For networks with batch normalization layers, the running statistics of the batch normalization layers are also aggregated. All statistics except num_batches_tracked are aggregated by computing the mean, while for the num_batches_tracked buffer the maximum of the clients' values is taken.
- Parameters:
eligible (Iterable[Client]) – The clients that will participate in the aggregation.
client_models (Iterable[torch.nn.Module]) – The models of the clients.
References
[FedAVG] H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, “Communication-Efficient Learning of Deep Networks from Decentralized Data”. In AISTATS (2017).
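The aggregation rule, including the batch-normalization note, can be sketched on plain Python state dictionaries that map parameter names to flat lists of numbers (an assumption made to keep the example free of PyTorch). Applying the weights \(w_i\) to the running statistics as well is also an assumption; with uniform weights it reduces to the plain mean the note describes. The function name aggregate_states and the example keys are hypothetical.

```python
StateDict = dict[str, list[float]]


def aggregate_states(client_states: list[StateDict], weights: list[float]) -> StateDict:
    """Hypothetical sketch of theta = sum_i w_i * theta_i, with sum_i w_i = 1."""
    if abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("weights must sum to 1")
    global_state: StateDict = {}
    for key in client_states[0]:
        # values[i] is client i's entry for this key.
        values = [state[key] for state in client_states]
        if key.endswith("num_batches_tracked"):
            # Counter buffer: element-wise maximum over the clients.
            global_state[key] = [max(col) for col in zip(*values)]
        else:
            # Parameters and running statistics: weighted sum over the clients.
            global_state[key] = [
                sum(w * v for w, v in zip(weights, col)) for col in zip(*values)
            ]
    return global_state


client_states = [
    {"fc.weight": [1.0, 2.0], "bn.running_mean": [0.0], "bn.num_batches_tracked": [5]},
    {"fc.weight": [3.0, 4.0], "bn.running_mean": [2.0], "bn.num_batches_tracked": [8]},
]
result = aggregate_states(client_states, [0.5, 0.5])
print(result)
# {'fc.weight': [2.0, 3.0], 'bn.running_mean': [1.0], 'bn.num_batches_tracked': [8]}
```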
- broadcast_model(eligible: Iterable[Client]) → None
Broadcast the global model to the clients.
- Parameters:
eligible (Iterable[Client]) – The clients that will receive the global model.
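Because the server is meant to exchange data with clients only through its channel, broadcasting amounts to sending a copy of the global model to each eligible client. The sketch below uses a hypothetical in-memory SketchChannel with one mailbox per client; its send/receive methods and the deep copy on send are assumptions for illustration, not fluke's Channel API.

```python
import copy
from collections import defaultdict
from typing import Any


class SketchChannel:
    """Hypothetical in-memory channel: one FIFO mailbox per recipient."""

    def __init__(self) -> None:
        self._mailboxes: dict[str, list[Any]] = defaultdict(list)

    def send(self, payload: Any, recipient: str) -> None:
        # Deep-copy so the recipient cannot mutate the sender's object.
        self._mailboxes[recipient].append(copy.deepcopy(payload))

    def receive(self, recipient: str) -> Any:
        # Pop the oldest message addressed to `recipient`.
        return self._mailboxes[recipient].pop(0)


def broadcast_model(channel: SketchChannel, global_state: dict, eligible: list[str]) -> None:
    """Send a copy of the global model to every eligible client."""
    for client_id in eligible:
        channel.send(global_state, client_id)


channel = SketchChannel()
global_state = {"fc.weight": [0.1, 0.2]}
broadcast_model(channel, global_state, ["client_0", "client_2"])
received = channel.receive("client_2")
received["fc.weight"][0] = 99.0   # modifies only the client's copy
print(global_state["fc.weight"])  # [0.1, 0.2]
```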
- property channel: Channel
The channel to communicate with the clients.
Important
Always use this channel to exchange data/information with the clients. The server should directly call the clients’ methods only to trigger specific actions.
- Returns:
The channel to communicate with the clients.
- Return type:
Channel
- evaluate(evaluator: Evaluator, test_set: FastDataLoader) → dict[str, float]
Evaluate the global federated model on the test_set. If no test set is provided, the method returns an empty dictionary.
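A minimal sketch of the empty-dictionary behaviour. It assumes a hypothetical predict callable in place of the Evaluator, accuracy as the only metric, and a test set given as a plain list of (features, label) pairs.

```python
from typing import Callable, Optional, Sequence

Example = tuple[list[float], int]  # (features, label)


def evaluate(
    predict: Callable[[list[float]], int],
    test_set: Optional[Sequence[Example]],
) -> dict[str, float]:
    """Hypothetical sketch: accuracy of `predict`, or {} when there is no test data."""
    if not test_set:
        # No (or empty) server-side test data: nothing to evaluate.
        return {}
    correct = sum(1 for features, label in test_set if predict(features) == label)
    return {"accuracy": correct / len(test_set)}


def threshold_model(features: list[float]) -> int:
    # Toy classifier: predict 1 when the first feature exceeds 0.5.
    return int(features[0] > 0.5)


test_data = [([0.9], 1), ([0.2], 0), ([0.7], 0), ([0.1], 0)]
print(evaluate(threshold_model, test_data))  # {'accuracy': 0.75}
print(evaluate(threshold_model, None))       # {}
```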
- finalize() → None
Finalize the federated learning process. This method is called at the end of the federated learning process. Client-side evaluation is performed only for clients that have participated in at least one round.
- fit(n_rounds: int = 10, eligible_perc: float = 0.1, finalize: bool = True, **kwargs: dict[str, Any]) → None
Run the federated learning algorithm. By default, this method runs the Federated Averaging algorithm: in each round, the server selects a percentage of the clients and sends them the global model; the selected clients compute their local updates and send them back to the server, which aggregates them. The process is repeated for n_rounds rounds.
- Parameters:
n_rounds (int, optional) – The number of rounds to run. Defaults to 10.
eligible_perc (float, optional) – The percentage of clients that will be selected for each round. Defaults to 0.1.
finalize (bool, optional) – If True, the server will finalize the federated learning process. Defaults to True.
**kwargs – Additional keyword arguments.
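The round structure of fit, together with the participation check done by finalize, can be sketched end to end with toy scalar models. Everything here is hypothetical: SketchClient, its local_update step on a quadratic loss, and the seed parameter are stand-ins chosen so the example runs without fluke or PyTorch, and aggregation uses uniform weights (the weighted=False case).

```python
import random


class SketchClient:
    """Hypothetical client with a scalar 'model' and a local optimum `target`."""

    def __init__(self, target: float) -> None:
        self.target = target
        self.rounds_participated = 0

    def local_update(self, theta: float, lr: float = 0.5) -> float:
        self.rounds_participated += 1
        # One gradient step on the local loss 0.5 * (theta - target) ** 2.
        return theta - lr * (theta - self.target)


def fit(clients: list[SketchClient], n_rounds: int = 10,
        eligible_perc: float = 0.1, seed: int = 0) -> float:
    """Hypothetical sketch of a FedAvg loop with uniform weights."""
    rng = random.Random(seed)
    theta = 0.0  # the global model, a single scalar in this sketch
    for _ in range(n_rounds):
        # Select a random subset of the clients for this round.
        n_eligible = max(1, int(len(clients) * eligible_perc))
        eligible = rng.sample(clients, n_eligible)
        # "Broadcast" theta; each eligible client returns its updated copy.
        local_models = [client.local_update(theta) for client in eligible]
        # Aggregate: plain average of the returned models.
        theta = sum(local_models) / len(local_models)
    return theta


def finalize(clients: list[SketchClient]) -> list[SketchClient]:
    # Only clients that trained in at least one round are evaluated client-side.
    return [c for c in clients if c.rounds_participated > 0]


clients = [SketchClient(target=1.0) for _ in range(10)]
theta = fit(clients, n_rounds=10, eligible_perc=0.5)
print(theta)  # 0.9990234375, i.e. 1 - 0.5 ** 10
print(sum(c.rounds_participated for c in clients))  # 50
```

With all clients sharing the same optimum, each round halves the distance to 1.0 regardless of which clients are selected, which makes the result easy to check.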
- get_eligible_clients(eligible_perc: float) → Iterable[Client]
Get the clients that will participate in the current round. Clients are selected randomly based on the percentage of eligible clients.
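Random client selection can be sketched with the standard library's random.sample, which draws without replacement. The rounding rule (round down, but keep at least one client) and the explicit rng argument are assumptions of this sketch; the source does not state how the number of selected clients is rounded.

```python
import random
from typing import Sequence, TypeVar

T = TypeVar("T")


def get_eligible_clients(clients: Sequence[T], eligible_perc: float,
                         rng: random.Random) -> list[T]:
    """Hypothetical sketch: sample a fraction `eligible_perc` of the clients."""
    if not 0.0 < eligible_perc <= 1.0:
        raise ValueError("eligible_perc must be in (0, 1]")
    # Assumed rounding rule: round down, but always select at least one client.
    n_eligible = max(1, int(len(clients) * eligible_perc))
    return rng.sample(list(clients), n_eligible)


rng = random.Random(42)
round_clients = get_eligible_clients([f"client_{i}" for i in range(100)], 0.1, rng)
print(len(round_clients))  # 10
print(len(get_eligible_clients(["a", "b", "c"], 0.1, rng)))  # 1
```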
- property has_model: bool
Return whether the server owns a global model.
- Returns:
True if the server owns a global model, False otherwise.
- Return type:
bool
- property has_test: bool
Return whether the server can evaluate the model.
- Returns:
True if the server can evaluate the model, False otherwise.
- Return type:
bool
- load(path: str) → None
Load the server’s state from file.
- Parameters:
path (str) – The path to load the server’s state.
- receive_client_models(eligible: Iterable[Client], state_dict: bool = True) → Generator[torch.nn.Module, None, None]
Retrieve the models of the clients. This method assumes that the clients have already sent their models to the server. The models are received through the channel in the same order as the clients in eligible.
Caution
The method returns a generator over the models of the clients to avoid cluttering memory with all the models at once. As a consequence, the models can be iterated only once, and the method is expected to be called only once per round. If the models are needed multiple times, convert the generator to a list, tuple, or other reusable collection.
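- Parameters:
eligible (Iterable[Client]) – The clients whose models are retrieved.
state_dict (bool, optional) – Defaults to True.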
- Returns:
The models of the clients.
- Return type:
Generator[torch.nn.Module]
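The single-pass behaviour described in the Caution for receive_client_models follows from Python's generator semantics. In this hypothetical sketch, a plain list stands in for the messages received through the channel.

```python
from typing import Generator


def receive_client_models(mailbox: list[dict]) -> Generator[dict, None, None]:
    """Hypothetical sketch: yield the received client models one at a time."""
    for model in mailbox:
        # The consumer handles one model at a time instead of a full list.
        yield model


mailbox = [{"w": [1.0]}, {"w": [3.0]}]

models = receive_client_models(mailbox)
first_pass = [m["w"][0] for m in models]
second_pass = [m["w"][0] for m in models]  # the generator is already exhausted
print(first_pass, second_pass)  # [1.0, 3.0] []

# When the models are needed more than once, materialize them first.
models_list = list(receive_client_models(mailbox))
print(len(models_list))  # 2
```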