New federated algorithm with fluke

This tutorial will guide you through the steps required to implement a new federated learning algorithm that can be tested with fluke.

Attention

This tutorial does not go into the details of the implementation, but it provides a quick overview of the steps required to implement a new federated learning algorithm. For a more in-depth guide on how to implement your own federated learning algorithm, please refer to this section.

Try this notebook: Open in Colab

Install fluke (if not already done)

pip install fluke-fl

Implementing the server-side logic

To keep it simple, we use a very easy (and not particularly smart :)) example of a new FL algorithm. Let’s say we want to define a new federated algorithm with these two characteristics:

  • At each round, the server selects only two clients among the participants whose models will be merged;

  • When selected, a client performs local training for a number of epochs chosen uniformly at random between 1 and the maximum number of epochs, which is a hyperparameter.

Let’s start with the server. Given the characteristics of the algorithm, the only thing the server does differently from the standard FedAvg server is to select only two clients to be merged. The rest of the logic is the same.

from typing import Iterable
from fluke.client import Client
from fluke.server import Server
import numpy as np

class MyServer(Server):

    # we override the aggregate method to implement our aggregation strategy
    def aggregate(self, eligible: Iterable[Client]) -> None:

        # eligible is a list of clients that participated in the last round;
        # here we randomly select only two of them
        selected = np.random.choice(eligible, 2, replace=False)

        # we call the parent class method to aggregate the selected clients
        return super().aggregate(selected)

Easy! Most of the server’s behaviour is identical to FedAvg, which is already implemented in fluke.server.Server.
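The selection step relies on np.random.choice with replace=False, which guarantees two *distinct* clients. Here is a quick standalone check of that behaviour, using plain strings as stand-ins for fluke Client objects:

```python
import numpy as np

# stand-in "clients" (strings instead of fluke Client objects)
clients = [f"client_{i}" for i in range(5)]

rng = np.random.default_rng(42)
# replace=False guarantees the two picks are distinct
selected = rng.choice(clients, size=2, replace=False)

assert len(selected) == 2
assert selected[0] != selected[1]
```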

Implementing the client-side logic

Let’s implement the client-side logic now. Here too, we can start from the FedAvg client already implemented in fluke.client.Client and modify it to fit our needs.

class MyClient(Client):

    # we override the fit method to implement our training "strategy"
    def fit(self, override_local_epochs: int = 0) -> float:
        # we pick a random number of local epochs and call the parent class method
        new_local_epochs = np.random.randint(1, self.hyper_params.local_epochs + 1)
        return super().fit(new_local_epochs)
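The +1 in the np.random.randint call matters, because its upper bound is exclusive. A quick sanity check (local_epochs here is a stand-in for self.hyper_params.local_epochs) confirms that the draws stay in the intended range [1, local_epochs]:

```python
import numpy as np

np.random.seed(0)  # for reproducibility of this check
local_epochs = 5   # stand-in for self.hyper_params.local_epochs

# randint's upper bound is exclusive, hence the +1:
# draws are uniform over {1, ..., local_epochs}
draws = [np.random.randint(1, local_epochs + 1) for _ in range(1000)]

assert min(draws) >= 1
assert max(draws) <= local_epochs
```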

Implementing the new federated algorithm

Now, we only need to put everything together in a new class that inherits from fluke.algorithms.CentralizedFL, specifying the server and client classes we just implemented.

from fluke.algorithms import CentralizedFL

class MyFLAlgorithm(CentralizedFL):

    def get_client_class(self) -> Client:
        return MyClient

    def get_server_class(self) -> Server:
        return MyServer
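What we just used is a plain factory-method pattern: the base algorithm asks its subclasses which client and server classes to instantiate. A minimal standalone sketch of the same pattern (with stand-in classes, so fluke is not required to run it):

```python
# stand-in base classes, mimicking fluke's Client/Server
class Client: ...
class Server: ...

class MyClient(Client): ...
class MyServer(Server): ...

class CentralizedFL:
    # the base algorithm defaults to the vanilla classes
    def get_client_class(self):
        return Client

    def get_server_class(self):
        return Server

class MyFLAlgorithm(CentralizedFL):
    # the subclass swaps in the custom client and server
    def get_client_class(self):
        return MyClient

    def get_server_class(self):
        return MyServer

algo = MyFLAlgorithm()
assert algo.get_client_class() is MyClient
assert algo.get_server_class() is MyServer
```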

Everything is ready! Now we can test our new federated algorithm with fluke!

Ready to test the new federated algorithm

The rest of the code is the same as in the First steps with fluke API tutorial.

from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke import DDict
from fluke.utils.log import Log
from fluke.evaluation import ClassificationEval
from fluke import GlobalSettings

settings = GlobalSettings()
settings.set_seed(42) # we set a seed for reproducibility
settings.set_device("cpu") # we use the CPU for this example

dataset = Datasets.get("mnist", path="./data")

# we set the evaluator to be used by both the server and the clients
settings.set_evaluator(ClassificationEval(eval_every=1, n_classes=dataset.num_classes))

splitter = DataSplitter(dataset=dataset,
                        distribution="iid")

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
      lr=0.01,
      momentum=0.9,
      weight_decay=0.0001),
    scheduler=DDict(
      gamma=1,
      step_size=1)
)

# we put together the hyperparameters for the algorithm
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model="MNIST_2NN")

Here is where the new federated algorithm comes into play.

algorithm = MyFLAlgorithm(n_clients=10, # 10 clients in the federation
                          data_splitter=splitter,
                          hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)

We just need to run it!

algorithm.run(n_rounds=10, eligible_perc=0.5)
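As a back-of-the-envelope check of what this run does (assuming, which is not verified here against fluke's internals, that roughly n_clients * eligible_perc clients are sampled per round): with 10 clients and eligible_perc=0.5, five clients train each round, and MyServer then aggregates only two of them.

```python
# Assumption (hypothetical, for illustration): fluke samples about
# n_clients * eligible_perc clients per round
n_clients = 10
eligible_perc = 0.5

participants_per_round = int(n_clients * eligible_perc)  # clients that train
aggregated_per_round = 2  # MyServer merges only two of them

assert participants_per_round == 5
assert aggregated_per_round < participants_per_round
```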