First steps with fluke API¶
This tutorial will guide you through the first steps with the fluke
API. We will show how to quickly run an experiment using the API.
Installation via pip¶
If you haven’t installed the package yet, you can do so by running the following command:
pip install fluke-fl
Loading and splitting the dataset¶
First of all, we need to load the dataset. Let's say we want to load the MNIST dataset.
from fluke.data.datasets import Datasets
dataset = Datasets.get("mnist", path="./data")
dataset is a DataContainer, which is a simple data structure containing the dataset as it is loaded from the files. The downloaded files are stored in the directory path. If the dataset is already downloaded, path can be set to the directory containing the files.
After loading the dataset, we need to prepare it for distribution. For simplicity, let's say that we use the test set provided by the dataset as the server-side test set (which will measure the performance of the global model), and the training set as the client-side training set (which will be distributed to the clients). For now, we will use the default data distribution strategy, which is IID, with no client-side test set.
from fluke.data import DataSplitter
splitter = DataSplitter(dataset=dataset,
                        distribution="iid")
A DataSplitter is the class responsible for splitting the dataset into the server-side and client-side datasets.
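For intuition, an IID split simply shuffles the example indices and deals them out evenly across clients, so every client's local data follows the same distribution. The following is a minimal plain-Python sketch of the idea, not fluke's internal implementation:

```python
import random

def iid_split(num_examples: int, num_clients: int, seed: int = 42) -> list[list[int]]:
    """Shuffle example indices and deal them out evenly across clients."""
    rng = random.Random(seed)
    indices = list(range(num_examples))
    rng.shuffle(indices)
    # round-robin assignment keeps the client shard sizes balanced
    return [indices[c::num_clients] for c in range(num_clients)]

shards = iid_split(num_examples=60000, num_clients=10)
print(len(shards), len(shards[0]))  # → 10 6000
```

Non-IID strategies differ only in how the indices are assigned (e.g., by label or by a Dirichlet-distributed quantity) rather than uniformly at random.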
Setting up the evaluator¶
The evaluator is the class responsible for evaluating the performance of both the global and local models.
It must be defined in the global settings of fluke as follows.
from fluke.evaluation import ClassificationEval
from fluke import GlobalSettings

evaluator = ClassificationEval(eval_every=1, n_classes=dataset.num_classes)
GlobalSettings().set_evaluator(evaluator)
Here we are using an evaluator for the classification task (currently the only one supported). eval_every is the number of communication rounds after which the models are evaluated.
Instantiate and configure the federated learning algorithm¶
Now, we are ready to instantiate our algorithm. We will go with the standard FedAvg, but many others are available in fluke.
Instantiating a federated learning algorithm requires setting a number of hyper-parameters. fluke divides these parameters into two groups:
client-side: the hyper-parameters of the clients, which include the type of optimizer (and scheduler), the learning rate, the number of local epochs, etc.
server-side: the hyper-parameters of the server, which are typically fewer than the clients' hyper-parameters, e.g., whether the aggregation is weighted or not.
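To illustrate what the weighted option controls: with weighted aggregation, the server averages client updates proportionally to their local dataset sizes, while an unweighted average treats all clients equally. The following is a toy sketch with hypothetical names, not fluke's implementation:

```python
def aggregate(client_params: list[list[float]],
              sizes: list[int],
              weighted: bool = True) -> list[float]:
    """Average per-parameter values across clients, optionally weighted by dataset size."""
    if weighted:
        total = sum(sizes)
        weights = [s / total for s in sizes]
    else:
        weights = [1 / len(client_params)] * len(client_params)
    # combine the clients' parameter vectors into a single global vector
    return [sum(w * p[i] for w, p in zip(weights, client_params))
            for i in range(len(client_params[0]))]

# two clients: one trained on 300 examples, one on 100
print(aggregate([[1.0, 2.0], [3.0, 6.0]], sizes=[300, 100]))  # → [1.5, 3.0]
```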
In the following code, we will set the hyper-parameters of the clients using a DDict, which is a convenient data structure defined in fluke. A simple dictionary can be used as well.
from fluke import DDict
client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001),
    scheduler=DDict(
        gamma=1,
        step_size=1)
)

# we put together the hyper-parameters for the algorithm
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model="MNIST_2NN")
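To give a feel for the convenience a DDict-style structure provides over a plain dictionary, namely attribute-style access to nested keys, here is a rough standalone sketch (not fluke's actual implementation):

```python
class AttrDict(dict):
    """A dict whose keys are also readable as attributes (a DDict-like sketch)."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

# nested hyper-parameters become reachable with dotted access
hp = AttrDict(batch_size=10, optimizer=AttrDict(lr=0.01, momentum=0.9))
print(hp.batch_size, hp.optimizer.lr)  # → 10 0.01
```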
As you may see, we also need to specify the model (i.e., the neural network) that will be used in the federated learning process. In this example, we will use the MNIST_2NN model, which is a simple multi-layer perceptron with two hidden layers. The model is defined in the nets module of the fluke package.
Finally, we are all set to create the federated learning algorithm. The FedAVG class is the implementation of the Federated Averaging algorithm; as shown below, it takes the number of clients, the data splitter, and the hyper-parameters:
from fluke.algorithms.fedavg import FedAVG
algorithm = FedAVG(100, splitter, hyperparams)
Before running the algorithm, we need to make sure to log the results. fluke is designed to allow different types of logging; for this reason, it implements the Observer design pattern. To attach a logger to the algorithm, we create an instance of the logger and attach it to the algorithm.
from fluke.utils.log import Log
logger = Log()
algorithm.set_callbacks(logger)
Log is a simple logger that prints the results to the console, while keeping the history of the results in dictionaries.
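The Observer pattern mentioned above works as follows: the subject keeps a list of registered observers and notifies each of them when an event occurs (e.g., the end of a round). A generic illustration of the pattern, not fluke's actual interface:

```python
class Observable:
    """Subject that notifies registered observers of named events."""
    def __init__(self):
        self._observers = []

    def attach(self, observer) -> None:
        self._observers.append(observer)

    def notify(self, event: str, **data) -> None:
        # broadcast the event to every registered observer
        for obs in self._observers:
            obs.update(event, **data)

class ConsoleLogger:
    """Observer that records and prints every event it receives."""
    def __init__(self):
        self.history = []

    def update(self, event: str, **data) -> None:
        self.history.append((event, data))
        print(f"[{event}] {data}")

algo = Observable()
logger = ConsoleLogger()
algo.attach(logger)
algo.notify("end_round", round=1, accuracy=0.91)
```

This decoupling is what lets fluke swap in different logger types without changing the algorithm's code.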
Ready to go!¶
Finally, we can run the algorithm. The run method of the algorithm requires specifying the number of rounds and the fraction of clients that will participate in each round.
algorithm.run(2, 0.5)
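For intuition on the second argument: at each round, the server samples that fraction of the registered clients to participate. A rough plain-Python sketch of this sampling loop (hypothetical names, not fluke's code):

```python
import random

def sample_clients(n_clients: int, fraction: float, rng: random.Random) -> list[int]:
    """Pick a set of distinct client ids for one round (at least one client)."""
    k = max(1, round(n_clients * fraction))
    return sorted(rng.sample(range(n_clients), k))

rng = random.Random(0)
for rnd in range(2):  # two rounds, as in algorithm.run(2, 0.5)
    participants = sample_clients(n_clients=100, fraction=0.5, rng=rng)
    print(f"round {rnd}: {len(participants)} clients selected")
```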