Server class
This class is the core of the federated learning simulation in fluke. Before extending it, make sure you clearly understand which data are exchanged between the server and the clients and how the learning process is orchestrated. This is crucial to avoid introducing bugs and to keep the code clean.
Overview
The Server class is responsible for coordinating the federated learning process. The learning process starts when the fit method is called on the Server object. During fit, the server iterates for the number of rounds specified by the n_rounds argument. Each important server operation triggers a callback to the observers registered with the server. Finally, at the end of fit, the server finalizes the federated learning process.
Server initialization
The Server constructor is responsible for initializing the server. Usually, there is not much more to it than setting the server’s attributes. However, there is an important notion you should be aware of: all the server’s hyperparameters should be stored in the hyper_params attribute, which is a DDict. This best practice ensures that the hyperparameters are easily accessible and stored in a single place.
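To illustrate the practice, the following sketch keeps every tunable value in one attribute-accessible dictionary. HyperParams and ToyServer are hypothetical stand-ins written for this example: they are not fluke’s DDict or Server, whose exact APIs may differ.

```python
from typing import Any


class HyperParams(dict):
    """Hypothetical DDict-like container: a dict whose keys are also attributes."""

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails: fall back to the keys.
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err

    def __setattr__(self, name: str, value: Any) -> None:
        # Store new "attributes" as dictionary items.
        self[name] = value


class ToyServer:
    """Hypothetical server that keeps all its hyperparameters in one place."""

    def __init__(self, n_clients: int, weighted: bool = True, lr: float = 1.0) -> None:
        self.n_clients = n_clients  # plain attribute, not a hyperparameter
        # Every tunable value goes into hyper_params, so it can be inspected
        # or logged from a single attribute.
        self.hyper_params = HyperParams(weighted=weighted, lr=lr)


server = ToyServer(n_clients=10, lr=0.5)
print(server.hyper_params.lr)     # 0.5
print(dict(server.hyper_params))  # {'weighted': True, 'lr': 0.5}
```

Because the container is still a plain dict, dumping it to a log or a configuration file requires no extra code.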
Sequence of operations of a single round
The standard behaviour of the Server class follows the sequence of operations of the standard Federated Averaging algorithm. The main methods of the Server class involved in a single round are:
- get_eligible_clients: called at the beginning of each round to select the clients that will participate in the round. The selection is based on the eligible_perc argument of the fit method.
- broadcast_model: called at the beginning of each round to send the global model to the clients participating in the round.
- aggregate: called at the end of each round to aggregate the models of the clients that participated in the round.
- evaluate: called at the end of each round to evaluate the global model on the server-side test set (if any).
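The round structure above can be sketched with a toy example in plain Python. This is not fluke’s implementation: ToyClient, ToyServer, receive_model and the scalar "model" are hypothetical, aggregation weights are assumed proportional to each client’s local data size, and the model is handed to the clients by a direct call instead of going through the Channel.

```python
import random
from typing import Dict, List


class ToyClient:
    """Hypothetical client holding a scalar 'model' and a local data size."""

    def __init__(self, index: int, target: float, n_examples: int) -> None:
        self.index = index
        self.target = target          # value that local training pulls the model towards
        self.n_examples = n_examples  # used as the aggregation weight
        self.model = 0.0

    def receive_model(self, model: float) -> None:
        # In fluke this exchange goes through the Channel; direct call for brevity.
        self.model = model

    def fit(self) -> None:
        # One "local training" step: move halfway towards the local target.
        self.model += 0.5 * (self.target - self.model)


class ToyServer:
    """Hypothetical server following the FedAvg-style round described above."""

    def __init__(self, clients: List[ToyClient], seed: int = 42) -> None:
        self.clients = clients
        self.model = 0.0
        self.rng = random.Random(seed)

    def get_eligible_clients(self, eligible_perc: float) -> List[ToyClient]:
        # Sample a fraction of the clients (at least one) for this round.
        n_eligible = max(1, int(len(self.clients) * eligible_perc))
        return self.rng.sample(self.clients, n_eligible)

    def broadcast_model(self, eligible: List[ToyClient]) -> None:
        for client in eligible:
            client.receive_model(self.model)

    def aggregate(self, eligible: List[ToyClient]) -> None:
        # Weighted average of the client models, weights proportional to data size.
        total = sum(c.n_examples for c in eligible)
        self.model = sum(c.n_examples / total * c.model for c in eligible)

    def evaluate(self) -> Dict[str, float]:
        # Hypothetical server-side "test": distance from the global optimum 1.0.
        return {"error": abs(1.0 - self.model)}

    def fit(self, n_rounds: int, eligible_perc: float) -> List[Dict[str, float]]:
        history = []
        for _ in range(n_rounds):
            eligible = self.get_eligible_clients(eligible_perc)
            self.broadcast_model(eligible)
            for client in eligible:
                client.fit()  # triggers local training on the client side
            self.aggregate(eligible)
            history.append(self.evaluate())
        return history


clients = [ToyClient(0, 0.5, 10), ToyClient(1, 1.0, 20), ToyClient(2, 1.5, 10)]
server = ToyServer(clients)
history = server.fit(n_rounds=3, eligible_perc=1.0)
print([h["error"] for h in history])  # [0.5, 0.25, 0.125]
```

With all clients selected, the weighted optimum of the three targets is 1.0, so the error halves at every round.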
The following figure shows the sequence of operations performed by the Server class during a call to fit.
Disclaimer
For brevity, many details have been omitted or simplified. However, the figure below shows the key methods and calls involved in a round.
For a complete description of the Server class, please refer to the Server’s API documentation.
The sequence diagram above shows the operations of the Server class during a single round. It highlights the dependencies between the methods of the Server class and those of the Client class. Moreover, it shows that the server and the clients communicate through the Channel. The only direct call from the server to a client shown in the diagram is the fit method of the Client class, which triggers the training process on the client side.
Finalization
At the end of the fit method, the Server class finalizes the federated learning process by calling the finalize method. Ideally, this method should be used to perform any final operation, for example, to get the final evaluation of the global (and/or local) model(s) or to save the model(s). It can also be used to trigger fine-tuning operations on the client side, as happens in personalized federated learning.
In its standard implementation, the finalize method calls the evaluate method to obtain the final evaluation of the global model on the server-side test set (if any), and it also performs the client-side evaluation after the global model has been broadcast for the last time.
Observer pattern
As mentioned above, the Server class triggers callbacks to the observers registered with the server. The default notifications are:
- _notify_start_round: triggered at the beginning of each round. It calls ServerObserver.start_round on each observer;
- _notify_selected_clients: triggered after the clients have been selected for the round. It calls ServerObserver.selected_clients on each observer;
- _notify_end_round: triggered at the end of a round. It calls ServerObserver.end_round on each observer;
- _notify_evaluation: should be triggered after an evaluation has been performed. It calls ServerObserver.evaluation on each observer;
- _notify_finalize: triggered at the end of the finalize method. It calls ServerObserver.finished on each observer.
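The following minimal sketch shows how such notifications can be wired together. ServerObserver, ObserverSubject, EventLog and ToyServer here are simplified, hypothetical stand-ins: the callback arguments and the exact position of each notification inside fluke’s fit and finalize are assumptions made for illustration.

```python
from typing import Dict, List, Tuple


class ServerObserver:
    """Hypothetical observer interface; each callback is a no-op by default."""

    def start_round(self, rnd: int) -> None: ...
    def selected_clients(self, rnd: int, clients: List[int]) -> None: ...
    def evaluation(self, rnd: int, metrics: Dict[str, float]) -> None: ...
    def end_round(self, rnd: int) -> None: ...
    def finished(self, rnd: int) -> None: ...


class ObserverSubject:
    """Hypothetical subject: stores observers and forwards each notification."""

    def __init__(self) -> None:
        self._observers: List[ServerObserver] = []

    def attach(self, observer: ServerObserver) -> None:
        self._observers.append(observer)

    def _notify_start_round(self, rnd: int) -> None:
        for obs in self._observers:
            obs.start_round(rnd)

    def _notify_selected_clients(self, rnd: int, clients: List[int]) -> None:
        for obs in self._observers:
            obs.selected_clients(rnd, clients)

    def _notify_evaluation(self, rnd: int, metrics: Dict[str, float]) -> None:
        for obs in self._observers:
            obs.evaluation(rnd, metrics)

    def _notify_end_round(self, rnd: int) -> None:
        for obs in self._observers:
            obs.end_round(rnd)

    def _notify_finalize(self, rnd: int) -> None:
        for obs in self._observers:
            obs.finished(rnd)


class EventLog(ServerObserver):
    """Observer that records the name and round of every callback it receives."""

    def __init__(self) -> None:
        self.events: List[Tuple[str, int]] = []

    def start_round(self, rnd: int) -> None:
        self.events.append(("start_round", rnd))

    def selected_clients(self, rnd: int, clients: List[int]) -> None:
        self.events.append(("selected_clients", rnd))

    def evaluation(self, rnd: int, metrics: Dict[str, float]) -> None:
        self.events.append(("evaluation", rnd))

    def end_round(self, rnd: int) -> None:
        self.events.append(("end_round", rnd))

    def finished(self, rnd: int) -> None:
        self.events.append(("finished", rnd))


class ToyServer(ObserverSubject):
    """Hypothetical server emitting notifications in one plausible order."""

    def fit(self, n_rounds: int) -> None:
        for rnd in range(1, n_rounds + 1):
            self._notify_start_round(rnd)
            self._notify_selected_clients(rnd, [0, 1])  # placeholder client ids
            # ... broadcast, local training and aggregation would happen here ...
            self._notify_evaluation(rnd, {"accuracy": 0.0})  # placeholder metrics
            self._notify_end_round(rnd)
        self.finalize(n_rounds)

    def finalize(self, rnd: int) -> None:
        self._notify_evaluation(rnd, {"accuracy": 0.0})  # final evaluation
        self._notify_finalize(rnd)  # notify the end of the process last


log = EventLog()
server = ToyServer()
server.attach(log)
server.fit(n_rounds=1)
print([name for name, _ in log.events])
# ['start_round', 'selected_clients', 'evaluation', 'end_round', 'evaluation', 'finished']
```

Observers only see what the subject decides to notify, which is why an overridden method must still call the matching _notify_* method.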
Hint
Refer to the API documentation of the ServerObserver interface and the ObserverSubject interface for more details.
Creating your Server class
Creating a custom Server class is straightforward: create a class that inherits from the Server class and override the methods you want to customize. As long as the federated protocol you are implementing follows the standard Federated Averaging protocol, you can reuse the default implementation of the fit method and override only the methods that are specific to your protocol.
Let’s see an example of a custom Server class that overrides the aggregate method while keeping the default implementation of the other methods.
Hint
Here we show a single example, but you can check the other algorithm implementations available in fluke for more examples of custom Server.aggregate methods.
The example follows the implementation of the FedExP algorithm. We also report the standard implementation of the aggregate method for comparison.
FedExP:

```python
from copy import deepcopy
from typing import Iterable

import torch

# Methods of the FedExP server class; flatten_parameters,
# STATE_DICT_KEYS_TO_IGNORE and Client are provided by fluke.

@torch.no_grad()
def aggregate(self, eligible: Iterable[Client]) -> None:
    # Flatten the global model and each client model to compute eta
    W = flatten_parameters(self.model)
    clients_model = self.get_client_models(eligible, state_dict=False)
    Wi = [flatten_parameters(client_model) for client_model in clients_model]
    eta = self._compute_eta(W, Wi)

    # Reuse the received models instead of accessing client.model directly
    clients_sd = [client_model.state_dict() for client_model in clients_model]
    avg_model_sd = deepcopy(self.model.state_dict())
    for key in self.model.state_dict().keys():
        if key.endswith(STATE_DICT_KEYS_TO_IGNORE):
            continue  # keep the global value

        if key.endswith("num_batches_tracked"):
            # Integer BatchNorm counter: keep the max of global and mean client value
            mean_nbt = torch.mean(torch.Tensor([c[key] for c in clients_sd])).long()
            avg_model_sd[key] = max(avg_model_sd[key], mean_nbt)
            continue

        # Extrapolated update: w <- w - eta * mean_i(w - w_i)
        avg_model_sd[key] = avg_model_sd[key] - eta * torch.mean(
            torch.stack([avg_model_sd[key] - client_sd[key] for client_sd in clients_sd]),
            dim=0)
    self.model.load_state_dict(avg_model_sd)

def _compute_eta(self, W: torch.Tensor, Wi: Iterable[torch.Tensor], eps: float = 1e-4) -> float:
    ...
```
Standard Server.aggregate (FedAvg):

```python
from collections import OrderedDict
from typing import Iterable

import torch

# STATE_DICT_KEYS_TO_IGNORE and Client are provided by fluke.

@torch.no_grad()
def aggregate(self, eligible: Iterable[Client]) -> None:
    avg_model_sd = OrderedDict()
    clients_sd = self.get_client_models(eligible)
    weights = self._get_client_weights(eligible)
    for key in self.model.state_dict().keys():
        if key.endswith(STATE_DICT_KEYS_TO_IGNORE):
            # Keep the global value for entries that must not be averaged
            avg_model_sd[key] = self.model.state_dict()[key].clone()
            continue

        if key.endswith("num_batches_tracked"):
            mean_nbt = torch.mean(torch.Tensor([c[key] for c in clients_sd])).long()
            # avg_model_sd has no entry for key yet: compare with the global counter
            avg_model_sd[key] = max(self.model.state_dict()[key], mean_nbt)
            continue

        # Weighted average: accumulate weights[i] * client parameters
        for i, client_sd in enumerate(clients_sd):
            if key not in avg_model_sd:
                avg_model_sd[key] = weights[i] * client_sd[key]
            else:
                avg_model_sd[key] += weights[i] * client_sd[key]
    self.model.load_state_dict(avg_model_sd)
```
Let’s start by summarizing the implementation of FedAvg’s aggregate method. The goal of this method is to aggregate the models of the clients that participated in the round to update the global model. The aggregation computes the weighted average of the client models: the method first collects the models of the participating clients (self.get_client_models(eligible)) and then computes the weighted average (the for loop) using the client weights (self._get_client_weights(eligible)). Finally, the global model is updated with the averaged model (self.model.load_state_dict(avg_model_sd)).
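The weighted average computed by the for loop can be reproduced on toy data. This minimal sketch uses plain Python lists in place of tensors; weighted_average is a hypothetical helper, and we assume weights proportional to each client’s number of training examples (the usual FedAvg choice), which may not be exactly what _get_client_weights returns.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]  # plain lists stand in for tensors


def weighted_average(clients_sd: List[StateDict], weights: List[float]) -> StateDict:
    """Average each parameter across clients, weighting client i by weights[i]."""
    avg_sd: StateDict = {}
    for key in clients_sd[0]:
        # Element-wise sum of weights[i] * client_sd[key], as in the for loop above.
        avg_sd[key] = [
            sum(w * sd[key][j] for w, sd in zip(weights, clients_sd))
            for j in range(len(clients_sd[0][key]))
        ]
    return avg_sd


# Two hypothetical clients with 30 and 10 training examples.
n_examples = [30, 10]
weights = [n / sum(n_examples) for n in n_examples]  # [0.75, 0.25]
clients_sd = [
    {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]},
    {"fc.weight": [3.0, 6.0], "fc.bias": [4.0]},
]
print(weighted_average(clients_sd, weights))
# {'fc.weight': [1.5, 3.0], 'fc.bias': [1.0]}
```

The client with more data pulls the result towards its own parameters, which is the intended effect of weighting.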
The custom implementation of the aggregate method for the FedExP algorithm follows a slightly different approach. The main difference lies in the update rule of the global model, which is based on the differences between the global model and the client models rather than on the client models themselves. For this reason, the method first flattens the parameters of the global model (W) and of each client model (Wi) and uses them to compute the global learning rate eta (eta = self._compute_eta(W, Wi)). Then, for each parameter, it moves the global model along the average difference between the global and the client parameters, scaled by eta; with eta equal to 1, this update reduces to the plain average of the client models. The rest of the method mirrors the standard implementation (ignored keys, num_batches_tracked handling, and the final load_state_dict call).
Please refer to the original paper of the FedExP algorithm for more details on the update rule.
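To make the update rule concrete, here is a framework-free sketch of the computation that _compute_eta is responsible for, with plain Python lists in place of flattened tensors. The step-size formula is our reading of the FedExP paper, eta = max(1, sum_i ||w - w_i||^2 / (2M(||mean_i(w - w_i)||^2 + eps))), where M is the number of participating clients; fluke’s actual _compute_eta may differ in details such as how eps is applied, and compute_eta and fedexp_update are hypothetical names.

```python
from typing import List


def compute_eta(w: List[float], w_clients: List[List[float]], eps: float = 1e-4) -> float:
    """Sketch of the FedExP global step size from flattened parameters."""
    m = len(w_clients)
    # Differences between the global model and each client model.
    diffs = [[wj - wij for wj, wij in zip(w, wi)] for wi in w_clients]
    sum_sq_norms = sum(sum(d * d for d in diff) for diff in diffs)
    # Average difference, computed coordinate by coordinate.
    mean_diff = [sum(col) / m for col in zip(*diffs)]
    mean_sq_norm = sum(d * d for d in mean_diff)
    # eta is never smaller than 1 (plain averaging).
    return max(1.0, sum_sq_norms / (2 * m * (mean_sq_norm + eps)))


def fedexp_update(w: List[float], w_clients: List[List[float]], eta: float) -> List[float]:
    """w <- w - eta * mean_i(w - w_i), as in the aggregate loop above."""
    m = len(w_clients)
    return [wj - eta * sum(wj - wi[j] for wi in w_clients) / m for j, wj in enumerate(w)]


w = [0.0, 0.0]
w_clients = [[3.0, 1.0], [-1.0, 1.0]]
eta = compute_eta(w, w_clients)
print(round(eta, 3))                                    # 1.5
print([round(x, 3) for x in fedexp_update(w, w_clients, eta)])  # [1.5, 1.5]
print(fedexp_update(w, w_clients, 1.0))                 # [1.0, 1.0] (plain average)
```

When the client updates partly cancel each other out, the average difference is small relative to the individual ones, so eta grows above 1 and the server extrapolates beyond the FedAvg average.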
Tip
In general, when you override a method of the Server class, start from the default implementation provided in the Server class and modify only the aspects that do not suit your federated protocol, preserving as much of the default implementation as possible. The same consideration applies to all the other methods you override when there is no need to override fit.
Attention
When overriding methods that notify the observers, make sure to call the corresponding notification method of the ObserverSubject interface. For example, if you override the finalize method, you should call the _notify_finalize method at the end of it. See, for instance, the implementation of FedBABU.
The fit method
Sometimes you might also need to override the fit method of the Server class. This is the case when the federated protocol you are implementing requires a different sequence of operations from the standard Federated Averaging protocol. This is quite uncommon, but it can happen. Currently, the only algorithms in fluke that override the fit method are FedHP and FedDyn. In both cases, the protocol differs from standard Federated Averaging only in the initial phase of learning, so the fit method is overridden to add these additional operations and then call super().fit() to trigger the standard behaviour.
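The pattern of doing protocol-specific work first and then delegating to super().fit() can be sketched as follows. BaseServer, WarmStartServer and the logged strings are hypothetical; the actual extra operations performed by FedHP and FedDyn are not reproduced here.

```python
from typing import List


class BaseServer:
    """Hypothetical stand-in for a server with a standard fit loop."""

    def __init__(self) -> None:
        self.log: List[str] = []

    def fit(self, n_rounds: int) -> None:
        for rnd in range(1, n_rounds + 1):
            self.log.append(f"round {rnd}")
        self.log.append("finalize")


class WarmStartServer(BaseServer):
    """Hypothetical protocol that only adds a preliminary phase before training."""

    def fit(self, n_rounds: int) -> None:
        # Protocol-specific operations before the standard rounds
        # (e.g., an initial exchange with the clients).
        self.log.append("warm-up")
        # Delegate everything else to the standard sequence of operations.
        super().fit(n_rounds)


server = WarmStartServer()
server.fit(n_rounds=2)
print(server.log)  # ['warm-up', 'round 1', 'round 2', 'finalize']
```

Delegating to the parent keeps the rounds, notifications and finalization in a single place, so later fixes to the standard loop are inherited automatically.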
When overriding the fit method, follow these best practices:
- Progress bars: track the progress of the learning process using progress bars. In fluke, this is done with the rich library. In rich, progress bars and status indicators use a live display, which is an instance of the Live class. You can reuse fluke’s Live instance from GlobalSettings through the get_live_renderer method. In this live display, you can show the progress of the client-side and server-side learning, which is already available in GlobalSettings through get_progress_bar("clients") and get_progress_bar("server"). For details on how to update the progress bars and, more generally, on how to use the rich library, please refer to its official documentation. The following is an excerpt of the fit method showing how to initialize the progress bars:

```python
with GlobalSettings().get_live_renderer():
    progress_fl = GlobalSettings().get_progress_bar("FL")
    progress_client = GlobalSettings().get_progress_bar("clients")
    # Number of clients selected in each round
    client_x_round = int(self.n_clients * eligible_perc)
    task_rounds = progress_fl.add_task("[red]FL Rounds", total=n_rounds*client_x_round)
    task_local = progress_client.add_task("[green]Local Training", total=client_x_round)
    ...
```
- Communication: in fluke, clients and server should generally not call each other’s methods directly. There are very few exceptions to this rule, for example, when the server needs to trigger training on the client side or asks the clients to perform an evaluation. In all other cases, the server and the clients must communicate through the Channel class (see the Channel API reference). The Channel instance is available in the Server class (the _channel private attribute or the channel property), and it must be used to send and receive messages, which must be encapsulated in a Message object. Using a channel allows fluke, through the logger (see Log), to keep track of the exchanged messages and thus to compute the communication cost automatically. The following is the implementation of the broadcast_model method, which uses the Channel to send the global model to the clients:

```python
def broadcast_model(self, eligible: Iterable[Client]) -> None:
    # Wrap the global model in a Message of type "model" sent by the server
    self.channel.broadcast(Message(self.model, "model", self), eligible)
```
- Minimal changes principle: this principle applies to software development in general, but it is particularly important when overriding the fit method, because fit is where the whole simulation is orchestrated. Start by copying the standard implementation of the fit method and then modify only the parts that are specific to your federated protocol. This helps keep the code clean and avoid introducing subtle bugs.