Using a custom model in fluke

This tutorial guides you through the steps required to use your own neural network as the model of a federated learning algorithm in fluke.

Install fluke (if not already done)

```bash
pip install fluke-fl
```

Define your neural network

For the purpose of this tutorial, we define a very simple neural network for the MNIST dataset: a multilayer perceptron with two hidden layers (100 and 64 units) and ReLU activations.

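```python
import torch
import torch.nn.functional as F


class MyMLP(torch.nn.Module):
    """MLP for 28x28 MNIST images: 784 -> 100 -> 64 -> 10."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 100)  # first hidden layer
        self.fc2 = torch.nn.Linear(100, 64)       # second hidden layer
        self.fc3 = torch.nn.Linear(64, 10)        # output layer: one logit per digit

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1, 28 * 28)  # flatten each image into a 784-dim vector
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)          # raw logits; CrossEntropyLoss applies softmax internally
        return x
```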
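Since every round of federated training moves the model's parameters between the server and the clients, it can be worth knowing how big MyMLP is. The pure-Python sketch below counts the trainable parameters (weights plus biases) of a stack of fully connected layers from the layer sizes used above. The helper `count_mlp_params` is hypothetical, and the byte estimate assumes 32-bit floats with no compression.

```python
def count_mlp_params(layer_sizes: list[int]) -> int:
    """Count weights and biases of a stack of fully connected layers.

    layer_sizes lists the input size followed by the output size of each
    Linear layer, e.g. [784, 100, 64, 10] for MyMLP.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + bias vector
    return total


n_params = count_mlp_params([28 * 28, 100, 64, 10])
size_bytes = n_params * 4  # assumption: float32, i.e. 4 bytes per parameter
print(n_params)    # 85614
print(size_bytes)  # 342456
```

If PyTorch is at hand, `sum(p.numel() for p in MyMLP().parameters())` gives the same count, so each model transfer is roughly a third of a megabyte.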

FedAvg with your custom model

You are almost ready to use your custom model in fluke. The only thing left to do is to set an instance of MyMLP as the `model` entry of the algorithm's hyper-parameters.

Alternatively, you can set `model` to the fully qualified name of your model class, i.e., a string such as "mymodule.MyMLP". This is useful because it lets you use a custom model with the fluke command-line interface.
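A fully qualified name is the module path followed by the class name, for example "mymodule.MyMLP", or "__main__.MyMLP" for a class defined in the running script. The sketch below shows how such a string can be turned back into a class with Python's standard `importlib`. It only illustrates the idea: fluke's own loading code may work differently, `load_class` is a hypothetical helper, and the demo resolves a standard-library class so that it runs anywhere.

```python
import importlib


def load_class(qualified_name: str) -> type:
    """Resolve a 'package.module.ClassName' string to the class object."""
    module_name, _, class_name = qualified_name.rpartition(".")
    if not module_name:
        raise ValueError(f"{qualified_name!r} is not a fully qualified name")
    module = importlib.import_module(module_name)  # imports (or reuses) the module
    return getattr(module, class_name)             # looks up the class inside it


cls = load_class("collections.OrderedDict")
instance = cls()  # a fresh instance, as MyMLP() would be for "mymodule.MyMLP"
print(cls.__name__)  # OrderedDict
```

With this kind of lookup, "mymodule.MyMLP" only resolves if `mymodule` is importable (e.g. on the Python path) when the name is loaded.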

To keep things simple, we use FedAvg (the FedAVG class), but you can use any other algorithm available in fluke or even implement your own.

```python
from fluke import DDict, GlobalSettings
from fluke.algorithms.fedavg import FedAVG
from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke.evaluation import ClassificationEval
from fluke.utils.log import Log

settings = GlobalSettings()
settings.set_seed(42)  # we set a seed for reproducibility
settings.set_device("cpu")  # we use the CPU for this example

dataset = Datasets.get("mnist", path="./data")

# we set the evaluator to be used by both the server and the clients
settings.set_evaluator(ClassificationEval(eval_every=1, n_classes=dataset.num_classes))

splitter = DataSplitter(dataset=dataset,
                        distribution="iid")  # random, uniform split of the data among clients

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001),
    scheduler=DDict(
        gamma=1,      # with a step schedule, gamma=1 keeps the learning rate constant
        step_size=1)
)
```
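The `scheduler` entry above uses `gamma` and `step_size`, the two parameters of a step-decay schedule such as PyTorch's `torch.optim.lr_scheduler.StepLR`. Assuming fluke applies a schedule of that kind, the learning rate after `t` scheduler steps is `lr * gamma ** (t // step_size)`, so `gamma=1` keeps it at 0.01 throughout training. The helper `step_lr` below is hypothetical and not part of fluke; it just makes the formula concrete.

```python
def step_lr(base_lr: float, gamma: float, step_size: int, t: int) -> float:
    """Learning rate after t scheduler steps under step decay (StepLR-style)."""
    return base_lr * gamma ** (t // step_size)


# Settings from client_hp: gamma=1 means the learning rate never changes.
constant = [step_lr(0.01, gamma=1, step_size=1, t=t) for t in range(5)]

# Hypothetical alternative: halve the learning rate every 2 steps.
decayed = [step_lr(0.01, gamma=0.5, step_size=2, t=t) for t in range(5)]

print(constant)  # [0.01, 0.01, 0.01, 0.01, 0.01]
print(decayed)   # [0.01, 0.01, 0.005, 0.005, 0.0025]
```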

This is where you set the model in the hyper-parameters:

```python
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model=MyMLP())  # or model="__main__.MyMLP"
                                    # or model="mymodule.MyMLP" if the model is in a module called mymodule
```
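The server entry `weighted=True` asks for a weighted aggregation of the client models. In the original FedAvg formulation, each client's parameters are weighted by its share of the training data; assuming fluke's `weighted` flag follows that scheme (worth confirming in the fluke documentation), one aggregation step looks like the sketch below. Parameters are flattened into plain lists of floats to keep the example dependency-free, and `weighted_average` is a hypothetical helper.

```python
def weighted_average(client_params: list[list[float]], n_samples: list[int]) -> list[float]:
    """FedAvg-style aggregation: average parameters weighted by local dataset size."""
    total = sum(n_samples)
    global_params = [0.0] * len(client_params[0])
    for params, n in zip(client_params, n_samples):
        weight = n / total  # fraction of all training samples held by this client
        for i, value in enumerate(params):
            global_params[i] += weight * value
    return global_params


# Two toy clients: the second holds three times as much data as the first.
clients = [[1.0, 2.0], [5.0, 6.0]]
sizes = [100, 300]
print(weighted_average(clients, sizes))  # [4.0, 5.0]
```

Under the same assumption, giving every client the same weight instead would reduce this to a plain mean of the client parameters.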

Finally, let’s initialize the algorithm and run it.

```python
algorithm = FedAVG(n_clients=10,  # 10 clients in the federation
                   data_splitter=splitter,
                   hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)
```
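With `n_clients=10`, the "iid" distribution requested from the DataSplitter spreads the examples uniformly at random over the clients, so each client sees roughly the same label distribution as the full dataset. The sketch below shows such a split over example indices (shuffle, then deal the indices out to the clients). `iid_split` is hypothetical and only approximates what fluke's DataSplitter does, which may also handle details such as per-client test sets; 60,000 is the size of the standard MNIST training set.

```python
import random


def iid_split(n_examples: int, n_clients: int, seed: int = 42) -> list[list[int]]:
    """Shuffle example indices and deal them out to clients in near-equal shards."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)  # local RNG: reproducible, no global state
    # Client k receives every n_clients-th shuffled index, starting at position k.
    return [indices[k::n_clients] for k in range(n_clients)]


shards = iid_split(60_000, n_clients=10)
print([len(s) for s in shards])  # [6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000]
```

If a client trained on its whole 6,000-example shard, `batch_size=10` would mean 600 optimizer steps per local epoch, or 3,000 over `local_epochs=5`.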
```python
algorithm.run(n_rounds=10, eligible_perc=0.5)  # 10 rounds, half of the clients selected each round
```
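In `algorithm.run`, `n_rounds=10` sets the number of communication rounds and `eligible_perc=0.5` the fraction of clients that take part in each round, so 5 of the 10 clients are active per round. The sketch below shows one plausible way to draw those participants; `select_clients` and its rounding rule (always keep at least one client) are assumptions for illustration, not fluke's code.

```python
import random


def select_clients(n_clients: int, eligible_perc: float, rng: random.Random) -> list[int]:
    """Sample, without replacement, the clients that participate in one round."""
    n_selected = max(1, int(n_clients * eligible_perc))  # assumption: at least one client
    return sorted(rng.sample(range(n_clients), n_selected))


rng = random.Random(42)
for round_idx in range(3):
    participants = select_clients(10, 0.5, rng)
    print(f"round {round_idx}: clients {participants}")
```

Drawing a fresh subset every round spreads participation across the federation while halving the per-round communication and computation.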