Custom evaluation with fluke

This tutorial will guide you through the steps required to implement a new evaluation metric and test it with fluke.

Install fluke (if not already done)

!pip install fluke-fl

Weighted accuracy

In this tutorial, we show how to implement in fluke a metric that is quite common in Personalized Federated Learning. A common technique for evaluating a client's local model is to use a balanced test set while weighting the accuracy by the number of samples of each class. Intuitively, the weighted accuracy accounts for how frequent each class is locally, so that an error on a less frequent class is penalized less. This metric is taken from Tackling Data Heterogeneity in Federated Learning with Class Prototypes, Dai et al., and it is defined as follows:

$$ acc_i = \frac{\sum_{(x_j,y_j)\in D_{test}}\alpha_i(y_j)\,\mathbb{1}(y_j = \hat{y}_j)}{\sum_{(x_j,y_j)\in D_{test}}\alpha_i(y_j)} $$

where $\alpha_i(\cdot)$ is a positive-valued function, defined as the probability that a sample of class $y$ is observed at the $i^{th}$ client. Notice that, for $\alpha_i(\cdot) = 1$, we recover the traditional accuracy. In this tutorial, we interpret $\alpha_i(\cdot)$ as the proportion of local samples of class $y$ over all the samples of that client. Specifically, we compute the coefficient for client $i$ and class $y_j$ as $\alpha_i(y_j) = \frac{Y^i_j}{Y^i}$, where $Y^i_j$ is the number of training samples of class $y_j$ at client $i$, and $Y^i$ is the total number of training samples of client $i$.

In our case, $D_{test}$ is the dataset held by the server, that is, (usually) the original test set of the dataset.
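To make the definition concrete, here is a small worked example (the numbers are made up for illustration): a hypothetical client with 10 training samples spread over 3 classes.

import torch

# hypothetical local training labels of client i:
# 6 samples of class 0, 3 of class 1 and 1 of class 2
labels = torch.tensor([0] * 6 + [1] * 3 + [2] * 1)

# alpha_i(y) = Y^i_y / Y^i: per-class count over the local training-set size
alpha = torch.bincount(labels, minlength=3).float() / labels.numel()
print(alpha)  # tensor([0.6000, 0.3000, 0.1000])

With these weights, a misclassified test sample of class 2 lowers the weighted accuracy far less than a misclassified sample of class 0.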


Implementing the metric

In the following, we start from the classification metric implemented in fluke.evaluation (ClassificationEval) and modify it to take the per-class weights into account. As a sanity check, in the global evaluation (where the weights are uniform) accuracy and weighted accuracy will coincide.

import torch
from torchmetrics import Metric

class WeightedAccuracy(Metric):

    def __init__(self, num_classes: int, weights: torch.Tensor):
        super().__init__()
        self.num_classes = num_classes
        self.weights = weights   # one weight alpha_i(y) per class
        self.true_weights = []   # alpha_i(y_j) of the ground-truth labels
        self.pred_weights = []   # alpha_i(y_hat_j) of the predicted labels
        self.mask = []           # 1(y_j == y_hat_j)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # if `preds` are logits/probabilities, reduce them to class labels
        if preds.ndim == 2:
            preds = torch.argmax(preds, dim=1)
        # keep the weights on the same device as the incoming batch
        weights = self.weights.to(target.device)
        self.true_weights.append(weights[target])
        self.pred_weights.append(weights[preds])
        self.mask.append(torch.eq(preds, target))

    def compute(self) -> float:
        true_weights = torch.cat(self.true_weights, dim=0)
        pred_weights = torch.cat(self.pred_weights, dim=0)
        mask = torch.cat(self.mask, dim=0)
        # on correct predictions pred_weights == true_weights, so this is
        # sum_j alpha_i(y_j) * 1(y_j == y_hat_j) / sum_j alpha_i(y_j)
        pred_weights = pred_weights * mask
        return pred_weights.sum() / true_weights.sum()
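Before plugging the metric into fluke, we can run a quick standalone sanity check (again with made-up labels): with uniform weights the metric reduces to the plain accuracy, while with the $\alpha_i$ from the example above an error on the rare class 2 is discounted.

preds = torch.tensor([0, 1, 2, 1])   # predicted labels
target = torch.tensor([0, 1, 2, 2])  # ground-truth labels (one error on class 2)

# uniform weights: behaves exactly like the standard accuracy (3 correct out of 4)
metric = WeightedAccuracy(num_classes=3, weights=torch.ones(3))
metric.update(preds, target)
print(metric.compute())  # tensor(0.7500)

# non-uniform weights: the error on the rare class 2 is penalized less
metric = WeightedAccuracy(num_classes=3, weights=torch.tensor([0.6, 0.3, 0.1]))
metric.update(preds, target)
print(metric.compute())  # tensor(0.9091)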

Implementing the server-side logic

Notice that server.evaluate is called inside server.fit with only two arguments (the evaluator and the test set). As a consequence, if we wanted server.evaluate to also receive the class weights, we would have to modify server.fit as well, which would be unnecessarily verbose. The most straightforward solution is to leave server.fit and the signature of server.evaluate untouched, and instead pass the weighted metric to evaluator.evaluate through its additional_metrics argument.

import torch

from fluke.data import FastDataLoader  # NOQA
from fluke.evaluation import Evaluator  # NOQA
from fluke.client import Client  # NOQA
from fluke.server import Server  # NOQA

class MyServer(Server):

    def evaluate(self,
                 evaluator: Evaluator,
                 test_set: FastDataLoader) -> dict[str, float]:
        if test_set is not None:
            # the server-side test set is balanced, so we use uniform class
            # weights: here the weighted accuracy must match the plain accuracy
            return evaluator.evaluate(self.rounds + 1,
                                      self.model,
                                      test_set,
                                      device=self.device,
                                      additional_metrics={
                                          "weighted_accuracy": WeightedAccuracy(
                                              evaluator.n_classes,
                                              torch.ones(evaluator.n_classes))
                                      })
        return {}
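Note that the server passes torch.ones(evaluator.n_classes) as class weights: on the (balanced) server-side test set every class counts the same, so the weighted accuracy reported by the server coincides with the standard accuracy, which is exactly the sanity check mentioned above.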

Implementing the client-side logic

Following the same logic as for the server, we modify the call to evaluator.evaluate rather than touching client.local_update or the signature of client.evaluate. The only addition is that each client computes its own class weights $\alpha_i$ from its local training labels.

from torch.nn import Module

from fluke.config import OptimizerConfigurator  # NOQA

class MyClient(Client):

    def __init__(self,
                 index: int,
                 train_set: FastDataLoader,
                 test_set: FastDataLoader,
                 optimizer_cfg: OptimizerConfigurator,
                 loss_fn: Module,
                 local_epochs: int = 3,
                 **kwargs):
        super().__init__(index,
                         train_set,
                         test_set,
                         optimizer_cfg,
                         loss_fn,
                         local_epochs,
                         **kwargs)
        # alpha_i(y) = Y^i_y / Y^i: per-class count over the local training-set size
        self.class_weights = torch.bincount(self.train_set.tensors[1]).float()
        self.class_weights /= self.train_set.size

    def evaluate(self,
                 evaluator: Evaluator,
                 test_set: FastDataLoader) -> dict[str, float]:
        if self.model is not None:
            # pad with zeros in case some classes never occur in the local
            # training set (possible with highly skewed splits)
            weights = torch.zeros(evaluator.n_classes)
            weights[:self.class_weights.shape[0]] = self.class_weights
            return evaluator.evaluate(self._last_round,
                                      self.model,
                                      test_set,
                                      device=self.device,
                                      additional_metrics={
                                          "weighted_accuracy": WeightedAccuracy(
                                              evaluator.n_classes, weights)
                                      })
        return {}

Testing the new metric

Now, we are ready to test our metric!

from fluke.algorithms import CentralizedFL

class MyFLAlgorithm(CentralizedFL):

    def get_client_class(self) -> type[Client]:
        return MyClient

    def get_server_class(self) -> type[Server]:
        return MyServer

Ready to test the new federated algorithm

The rest of the code is similar to the First steps with fluke API tutorial. We keep ClassificationEval as the base evaluator; the weighted accuracy is added on top of it by our custom client and server.

from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke.evaluation import ClassificationEval
from fluke import DDict
from fluke.utils.log import Log
from fluke import FlukeENV

env = FlukeENV()
env.set_seed(42)  # we set a seed for reproducibility
env.set_eval_cfg(pre_fit=True, post_fit=True)
env.set_evaluator(ClassificationEval(eval_every=1, n_classes=10))

dataset = Datasets.get("mnist", path="./data")
splitter = DataSplitter(dataset=dataset, distribution="dir",
                        client_split=0.1, dist_args=DDict(beta=0.5))

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001),
    scheduler=DDict(
        gamma=1,
        step_size=1)
)

# we put together the hyperparameters for the algorithm
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model="MNIST_2NN")

Here is where the new federated algorithm comes into play.

algorithm = MyFLAlgorithm(n_clients=10,  # 10 clients in the federation
                          data_splitter=splitter,
                          hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)

We just need to run it!

algorithm.run(n_rounds=100, eligible_perc=0.5)
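If everything is wired correctly, the evaluation logs now report a weighted_accuracy entry next to the standard classification metrics: on the server side it matches the plain accuracy (uniform weights), while each client's value reflects its own local class proportions.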