Using a custom model in fluke
This tutorial guides you through the steps required to use your own neural network as the model trained in a federated setting with fluke.
Install fluke (if not already done)
pip install fluke-fl
Define your neural network
For the purpose of this tutorial, we define a very simple multilayer perceptron for the MNIST dataset. The network has two hidden layers (100 and 64 units) with ReLU activations, followed by an output layer with 10 units, one per digit class.
import torch
import torch.nn.functional as F


class MyMLP(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 100)  # flattened image -> first hidden layer
        self.fc2 = torch.nn.Linear(100, 64)       # first -> second hidden layer
        self.fc3 = torch.nn.Linear(64, 10)        # second hidden layer -> 10 class logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1, 28 * 28)  # flatten each 28x28 image into a 784-dim vector
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)          # raw logits: CrossEntropyLoss applies log-softmax itself
        return x
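Every client in the federation trains its own copy of this model, and in FedAvg the updated weights travel between clients and server each round, so it is worth knowing how large the model is. The back-of-the-envelope count below uses only the layer sizes of MyMLP; the 4-bytes-per-parameter figure assumes float32 weights (PyTorch's default dtype).

```python
# Count MyMLP's trainable parameters by hand: a Linear(in, out) layer with bias
# holds an (out x in) weight matrix plus an out-sized bias vector.
layers = [(28 * 28, 100), (100, 64), (64, 10)]  # (in_features, out_features) of fc1, fc2, fc3


def linear_params(in_features: int, out_features: int) -> int:
    """Number of parameters in a fully connected layer with bias."""
    return in_features * out_features + out_features


n_params = sum(linear_params(i, o) for i, o in layers)
n_bytes = n_params * 4  # assumption: float32 weights, 4 bytes each

print(f"{n_params} parameters, about {n_bytes / 1024:.0f} KiB per model copy")
# prints: 85614 parameters, about 334 KiB per model copy
```

With PyTorch installed, sum(p.numel() for p in MyMLP().parameters()) should report the same total.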
FedAvg with your custom model
You are almost ready to use your custom model in fluke. The only thing left to do is to set an instance of MyMLP as the model in the algorithm's hyper-parameters.
Alternatively, you can provide the fully qualified name of your model class as a string (for example, "mymodule.MyMLP"). This is useful because it lets you use a custom model with the fluke command line interface: a class name, unlike an object instance, can be written in a configuration file.
To keep things simple, we use FedAvg (the FedAVG class in fluke), but you can use any other algorithm available in fluke or even implement your own.
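Passing the model by name works because a dotted name can be turned back into a class at run time. One plausible way to do this with only the standard library is sketched below; resolve_class and build_model are hypothetical names used for illustration, not fluke functions, and fluke's own resolution logic may differ.

```python
import importlib
from typing import Any


def resolve_class(qualified_name: str) -> type:
    """Turn a dotted path such as "mymodule.MyMLP" into the class object it names."""
    module_name, _, class_name = qualified_name.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {qualified_name!r}")
    module = importlib.import_module(module_name)  # imports the module if not loaded yet
    return getattr(module, class_name)              # raises AttributeError if the name is missing


def build_model(spec: Any) -> Any:
    """Hypothetical helper: accept a model instance or a fully qualified class name."""
    if isinstance(spec, str):
        return resolve_class(spec)()  # instantiate the class with no arguments
    return spec                       # already an instance: use it as-is
```

Under this sketch, "__main__.MyMLP" resolves only when MyMLP is defined in the script or notebook being run; for the command line interface, placing the class in an importable module such as mymodule is the more robust choice.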
from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke import DDict
from fluke.utils.log import Log
from fluke.algorithms.fedavg import FedAVG
from fluke.evaluation import ClassificationEval
from fluke import GlobalSettings

settings = GlobalSettings()
settings.set_seed(42)  # we set a seed for reproducibility
settings.set_device("cpu")  # we use the CPU for this example

dataset = Datasets.get("mnist", path="./data")

# we set the evaluator to be used by both the server and the clients
settings.set_evaluator(ClassificationEval(eval_every=1, n_classes=dataset.num_classes))

splitter = DataSplitter(dataset=dataset,
                        distribution="iid")

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001),
    scheduler=DDict(
        gamma=1,
        step_size=1)
)
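The distribution="iid" argument asks the splitter to give every client a statistically similar share of the data. Conceptually, an IID partition can be as simple as shuffling the example indices and dealing them out evenly. The sketch below illustrates only that idea; it is not fluke's DataSplitter implementation (which may, for instance, also set aside per-client test data), and iid_split is a hypothetical helper.

```python
import random


def iid_split(n_examples: int, n_clients: int, seed: int = 42) -> list[list[int]]:
    """Shuffle example indices and deal them into n_clients near-equal shards."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)  # a private RNG keeps the split reproducible
    # Taking every n_clients-th index gives shard sizes that differ by at most one.
    return [indices[c::n_clients] for c in range(n_clients)]


shards = iid_split(n_examples=60_000, n_clients=10)  # MNIST has 60,000 training images
print([len(s) for s in shards])  # ten shards of 6000 indices each
```

Because each shard is a uniform random sample of the whole training set, every client sees roughly the same label distribution, which is the easiest setting for FedAvg.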
Here is where you must set the model in the hyper-parameters.
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model=MyMLP())  # or model="__main__.MyMLP"
                    # or model="mymodule.MyMLP" if the model is in a module called mymodule
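With server=DDict(weighted=True), we assume the server averages the client models weighting each one by the size of its local training set, as in the original FedAvg formulation, where client k contributes with weight n_k / n. The sketch below shows only that arithmetic, on plain Python lists standing in for model parameters; weighted_average is a hypothetical helper, not fluke's aggregation code.

```python
def weighted_average(client_params: list[list[float]], n_examples: list[int]) -> list[float]:
    """FedAvg-style aggregation: average parameter vectors, weighting each
    client by its share n_k / n of the training examples."""
    total = sum(n_examples)
    averaged = [0.0] * len(client_params[0])
    for params, n_k in zip(client_params, n_examples):
        weight = n_k / total
        for j, value in enumerate(params):
            averaged[j] += weight * value
    return averaged


# Two toy clients with a 2-parameter "model"; the client with more data dominates.
print(weighted_average([[1.0, 0.0], [4.0, 3.0]], n_examples=[100, 300]))  # [3.25, 2.25]
```

Giving every client the same weight 1/K instead would reduce this to a plain mean of the client models.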
Finally, let’s initialize the algorithm and run it.
algorithm = FedAVG(n_clients=10,  # 10 clients in the federation
                   data_splitter=splitter,
                   hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)

algorithm.run(n_rounds=10, eligible_perc=0.5)
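Here eligible_perc=0.5 suggests that only part of the federation trains in each of the 10 rounds. A common scheme is to draw a fresh random subset of clients every round; the sketch below assumes that scheme and truncates the count to an integer, so 10 clients at 0.5 gives 5 per round. fluke's exact selection and rounding rules may differ, and sample_clients is a hypothetical helper.

```python
import random


def sample_clients(n_clients: int, eligible_perc: float, rng: random.Random) -> list[int]:
    """Pick the client indices that take part in one round."""
    n_selected = max(1, int(n_clients * eligible_perc))  # always select at least one client
    return sorted(rng.sample(range(n_clients), n_selected))


rng = random.Random(42)
for round_idx in range(3):
    print(round_idx, sample_clients(n_clients=10, eligible_perc=0.5, rng=rng))
```

Sampling without replacement guarantees that no client is picked twice within a round, while different rounds can involve different subsets of the federation.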