New federated algorithm with fluke¶
This tutorial will guide you through the steps required to implement a new federated learning algorithm that can be tested with fluke.
Attention
This tutorial does not go into the details of the implementation; it only provides a quick overview of the steps required to implement a new federated learning algorithm. For a more in-depth guide, please refer to this section.
Install fluke (if not already done)¶
pip install fluke-fl
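If you want to make sure the installation succeeded, a quick optional sanity check (assuming the package is importable as fluke, as in the snippets below) is:

python -c "import fluke; print('fluke is installed')"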
Implementing the server-side logic¶
To keep it simple, we use a very easy and not particularly smart :) example of a new FL algorithm. Let’s say we want to define a new federated algorithm with these two characteristics:
At each round, the server selects only two clients among the participants to be merged;
When selected, a client performs its local training for a number of epochs chosen at random between 1 and the maximum number of epochs, which is a hyperparameter.
Let’s start with the server. Given the characteristics of the algorithm, the only thing the server does differently from the standard FedAvg server is to select only two clients to be merged. The rest of the logic is the same.
from typing import Iterable
from fluke.client import Client
from fluke.server import Server
import numpy as np

class MyServer(Server):

    # we override the aggregate method to implement our aggregation strategy
    def aggregate(self, eligible: Iterable[Client]) -> None:

        # eligible is the list of clients that participated in the last round;
        # here we randomly select only two of them
        selected = np.random.choice(eligible, 2, replace=False)

        # we call the parent class method to aggregate the selected clients
        return super().aggregate(selected)
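If you want to see what the selection step does in isolation, here is a minimal standalone sketch (the strings are just stand-ins for the Client objects that fluke passes to aggregate at runtime):

import numpy as np

# stand-ins for the Client objects received by MyServer.aggregate
eligible = ["client_0", "client_1", "client_2", "client_3", "client_4"]

# pick two distinct clients at random, as done in MyServer.aggregate
selected = np.random.choice(eligible, 2, replace=False)
print(selected)  # e.g. ['client_3' 'client_1']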
Easy! Most of the server’s behaviour is the same as in FedAvg, which is already implemented in fluke.server.Server.
Implementing the client-side logic¶
Let’s implement the client-side logic now. Also in this case, we can start from the FedAvg client that is already implemented in fluke.client.Client and modify it to fit our needs.
class MyClient(Client):

    # we override the fit method to implement our training "strategy"
    def fit(self, override_local_epochs: int = 0) -> float:
        # we draw a random number of local epochs (between 1 and the maximum)
        # and pass it to the parent class method
        new_local_epochs = np.random.randint(1, self.hyper_params.local_epochs + 1)
        return super().fit(new_local_epochs)
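As a quick sanity check of the sampling above: with local_epochs = 5 (the value we will use later in this tutorial), np.random.randint(1, 6) returns an integer between 1 and 5, both included.

import numpy as np

max_epochs = 5  # same value as the local_epochs hyperparameter used below
draws = [np.random.randint(1, max_epochs + 1) for _ in range(10)]
print(draws)  # e.g. [3, 5, 1, 2, 4, 5, 1, 3, 2, 4] -- always within 1..5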
Implementing the new federated algorithm¶
Now, we only need to put everything together in a new class that inherits from fluke.algorithms.CentralizedFL, specifying the server and client classes we just implemented.
from fluke.algorithms import CentralizedFL

class MyFLAlgorithm(CentralizedFL):

    # tell CentralizedFL which client class to instantiate
    def get_client_class(self) -> Client:
        return MyClient

    # tell CentralizedFL which server class to instantiate
    def get_server_class(self) -> Server:
        return MyServer
Everything is ready! Now we can test our new federated algorithm with fluke!
Ready to test the new federated algorithm¶
The rest of the code is the same as in the First steps with fluke API tutorial.
from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke import DDict
from fluke.utils.log import Log
from fluke.evaluation import ClassificationEval
from fluke import GlobalSettings

settings = GlobalSettings()
settings.set_seed(42)  # we set a seed for reproducibility
settings.set_device("cpu")  # we use the CPU for this example

dataset = Datasets.get("mnist", path="./data")

# we set the evaluator to be used by both the server and the clients
settings.set_evaluator(ClassificationEval(eval_every=1, n_classes=dataset.num_classes))

splitter = DataSplitter(dataset=dataset,
                        distribution="iid")

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001),
    scheduler=DDict(
        gamma=1,
        step_size=1)
)

# we put together the hyperparameters for the algorithm
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model="MNIST_2NN")
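Note that DDict is a dictionary-like container that also supports attribute-style access, which is exactly how MyClient reads self.hyper_params.local_epochs above. For example:

print(client_hp.local_epochs)   # 5
print(client_hp.optimizer.lr)   # 0.01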
Here is where the new federated algorithm comes into play.
algorithm = MyFLAlgorithm(n_clients=10,  # 10 clients in the federation
                          data_splitter=splitter,
                          hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)
We just need to run it!
algorithm.run(n_rounds=10, eligible_perc=0.5)
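With these settings, at each of the 10 rounds fluke lets 50% of the 10 clients (i.e., 5 clients) train locally, each for a random number of epochs between 1 and 5, and MyServer then aggregates the models of only two of them.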