Custom evaluation with fluke

This tutorial will guide you through the steps required to implement a new evaluation metric and test it with fluke.


Install fluke (if not already done)

pip install fluke-fl

Weighted accuracy

In this tutorial, we show how to implement in fluke a metric that is quite common in Personalized Federated Learning. A common way of evaluating the local model is on a balanced test set, weighting the accuracy according to the number of samples of each class seen by the client. Intuitively, the weighted accuracy takes into account how frequent each class is locally, lowering the penalty for errors on the less frequent classes. This metric is taken from Tackling Data Heterogeneity in Federated Learning with Class Prototypes, Dai et al., and it is defined as follows:

\[acc_i = \frac{\sum_{x_j,y_j\in \mathcal{D}_{test}}\alpha_i(y_j)\mathbb{1}(y_j = \hat{y}_j)}{\sum_{x_j,y_j\in \mathcal{D}_{test}}\alpha_i(y_j)}\]

where \(\alpha_i(\cdot)\) is a positive-valued function, defined as the probability that a sample of class \(y\) is observed by the \(i^{th}\) client. Notice that for \(\alpha_i(\cdot) = 1\) we recover the standard accuracy. In this tutorial, we interpret \(\alpha_i(y)\) as the proportion of the local samples of class \(y\) over all the samples of that client. Specifically, we compute the coefficient for client \(i\) and class \(y_j\) as \(\alpha_i(y_j) = \frac{Y^i_j}{Y^i}\), where \(Y^i_j\) is the number of training samples of class \(y_j\) for client \(i\), and \(Y^i\) is the total number of training samples of client \(i\).
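
To make the definition concrete, here is a minimal sketch (plain PyTorch, with made-up labels) of how the per-class weights \(\alpha_i\) can be computed from a client's local training labels; the same computation appears in MyClient below.

import torch

# hypothetical local training labels of one client (3 classes)
train_labels = torch.tensor([0, 0, 0, 1, 2, 2])

# alpha_i(y) = (# local samples of class y) / (total # local samples)
class_weights = torch.bincount(train_labels, minlength=3).float()
class_weights /= train_labels.shape[0]
print(class_weights)  # tensor([0.5000, 0.1667, 0.3333])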

In our case \(\mathcal{D}_{test}\) will be the dataset on the server, that is (usually) the original test set of the dataset.

Implementing the server-side logic

Notice that server.evaluate is called in server.fit with only two arguments (the evaluator and the eligible clients). As a consequence, if we wanted server.evaluate to receive the class weights through its arguments, we would have to modify server.fit as well, which is too verbose. The most straightforward solution is to leave server.fit and the signature of server.evaluate untouched, and instead extend the arguments passed to evaluator.evaluate with the class weights.

import numpy as np
import torch

from fluke.data import FastDataLoader  # NOQA
from fluke.evaluation import Evaluator  # NOQA
from fluke.client import Client  # NOQA
from fluke.server import Server  # NOQA


class MyServer(Server):

    def evaluate(self,
                 evaluator: Evaluator,
                 test_set: FastDataLoader) -> dict[str, float]:
        if test_set is not None:
            # The server evaluates on the (balanced) global test set, so we pass
            # uniform class weights: weighted accuracy == standard accuracy here.
            return evaluator.evaluate(self.rounds + 1,
                                      self.model,
                                      test_set,
                                      device=self.device,
                                      weights=torch.ones(evaluator.n_classes))
        return {}

Implementing the client-side logic

Following the same logic as for the server, we modify the call to evaluator.evaluate rather than touching client.local_update or the signature of client.evaluate.

from typing import Any

from torch.nn import Module

from fluke.utils import OptimizerConfigurator  # NOQA


class MyClient(Client):

    def __init__(self,
                 index: int,
                 train_set: FastDataLoader,
                 test_set: FastDataLoader,
                 optimizer_cfg: OptimizerConfigurator,
                 loss_fn: Module,
                 local_epochs: int = 3,
                 **kwargs: dict[str, Any]):
        super().__init__(index,
                         train_set,
                         test_set,
                         optimizer_cfg,
                         loss_fn,
                         local_epochs,
                         **kwargs)
        # alpha_i(y): proportion of local training samples of each class
        self.class_weights = torch.bincount(self.train_set.tensors[1]).float()
        self.class_weights /= self.train_set.size

    def evaluate(self,
                 evaluator: Evaluator,
                 test_set: FastDataLoader) -> dict[str, float]:
        if self.model is not None:
            return evaluator.evaluate(self._last_round,
                                      self.model,
                                      test_set,
                                      device=self.device,
                                      weights=self.class_weights)
        return {}

Implementing the metric

In the following, we start from the classification metric present in eval.py and modify it to take into account the weight of each class. As a sanity check, in the global (server-side) evaluation the accuracy and the weighted accuracy will be the same, since the server passes uniform weights (see the snippet right after the code below).

from typing import Iterable, Optional, Union

import numpy as np
import torch
from torchmetrics import Accuracy

from fluke.utils import clear_cuda_cache  # NOQA


class WeightedClassificationEval(Evaluator):

    def __init__(self, eval_every: int, n_classes: int):
        super().__init__(eval_every=eval_every)
        self.n_classes: int = n_classes

    def evaluate(self,
                 round: int,
                 model: torch.nn.Module,
                 eval_data_loader: Union[FastDataLoader,
                                         Iterable[FastDataLoader]],
                 loss_fn: Optional[torch.nn.Module] = None,
                 device: torch.device = torch.device("cpu"),
                 weights: Optional[torch.Tensor] = None) -> dict:

        if round % self.eval_every != 0:
            return {}

        if (model is None) or (eval_data_loader is None):
            return {}

        model.eval()
        model.to(device)
        task = "multiclass"  # if self.n_classes >= 2 else "binary"
        accs, losses, weight_accs = [], [], []

        if not isinstance(eval_data_loader, list):
            eval_data_loader = [eval_data_loader]

        # default to uniform weights (i.e., standard accuracy)
        if weights is None:
            weights = torch.ones(self.n_classes)
        weights = weights.to(device)

        for data_loader in eval_data_loader:
            accuracy = Accuracy(task=task,
                                num_classes=self.n_classes,
                                top_k=1,
                                average="micro")
            loss = 0
            true_weights, pred_weights, mask = [], [], []
            for X, y in data_loader:
                X, y = X.to(device), y.to(device)
                with torch.no_grad():
                    y_hat = model(X)
                    if loss_fn is not None:
                        loss += loss_fn(y_hat, y).item()
                    y_hat = torch.max(y_hat, dim=1)[1]
                # denominator (true_weights) and numerator (pred_weights * mask)
                # of the weighted accuracy
                true_weights.append(weights[y])
                pred_weights.append(weights[y_hat])
                mask.append(torch.eq(y, y_hat))
                accuracy.update(y_hat.cpu(), y.cpu())

            true_weights = torch.cat(true_weights, dim=0)
            pred_weights = torch.cat(pred_weights, dim=0) * torch.cat(mask, dim=0)
            weight_accs.append(pred_weights.sum().item() / true_weights.sum().item())

            accs.append(accuracy.compute().item())
            losses.append(loss / len(data_loader))

        model.cpu()
        clear_cuda_cache()

        result = {
            "accuracy": np.round(sum(accs) / len(accs), 5).item(),
            "weighted_accuracy": np.round(sum(weight_accs) / len(weight_accs), 5).item(),
        }
        if loss_fn is not None:
            result["loss"] = np.round(sum(losses) / len(losses), 5).item()

        return result

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(eval_every={self.eval_every}" + \
               f", n_classes={self.n_classes})[accuracy, weighted_accuracy]"

    def __repr__(self) -> str:
        return str(self)
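
As a quick sanity check of the formula itself, the snippet below (plain PyTorch, with made-up predictions) computes the weighted accuracy directly: with uniform weights it coincides with the standard accuracy, while with class-proportion weights errors on rarer classes are penalized less.

import torch

def weighted_accuracy(y: torch.Tensor, y_hat: torch.Tensor, weights: torch.Tensor) -> float:
    # sum of weights of correctly classified samples over sum of weights of all samples
    correct = torch.eq(y, y_hat).float()
    return (weights[y_hat] * correct).sum().item() / weights[y].sum().item()

# made-up labels and predictions over 3 classes
y = torch.tensor([0, 1, 2, 2, 1, 0])
y_hat = torch.tensor([0, 1, 1, 2, 1, 2])

print(weighted_accuracy(y, y_hat, torch.ones(3)))                  # 0.6667 == plain accuracy (4/6)
print(weighted_accuracy(y, y_hat, torch.tensor([0.5, 0.3, 0.2])))  # 0.65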

Creating the new federated algorithm

We are now ready to put our metric to the test. First, we define the federated algorithm that uses our custom client and server classes.

from fluke.algorithms import CentralizedFL


class MyFLAlgorithm(CentralizedFL):

    def get_client_class(self) -> Client:
        return MyClient

    def get_server_class(self) -> Server:
        return MyServer

Ready to test the new federated algorithm

The rest of the code is similar to the First steps with fluke API tutorial; we just replace ClassificationEval with our custom WeightedClassificationEval.

from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke import DDict
from fluke.utils.log import Log
from fluke import FlukeENV

env = FlukeENV()
env.set_seed(42)  # we set a seed for reproducibility
env.set_device("cpu")  # we use the CPU for this example
# we set the evaluation configuration
env.set_eval_cfg(DDict(pre_fit=True, post_fit=True))

# we set the evaluator to be used by both the server and the clients
env.set_evaluator(WeightedClassificationEval(eval_every=1, n_classes=10))

dataset = Datasets.get("mnist", path="./data")
splitter = DataSplitter(dataset=dataset, distribution="iid")

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001),
    scheduler=DDict(
        gamma=1,
        step_size=1)
)

# we put together the hyperparameters for the algorithm
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model="MNIST_2NN")

Here is where the new federated algorithm comes into play.

algorithm = MyFLAlgorithm(n_clients=10,  # 10 clients in the federation
                          data_splitter=splitter,
                          hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)

Now we just need to run it!

algorithm.run(n_rounds=100, eligible_perc=0.5)