Custom evaluation with fluke¶
This tutorial will guide you through the steps required to implement a new evaluation metric and test it with fluke.
Install fluke (if not already done)¶
pip install fluke-fl
Weighted accuracy¶
In this tutorial, we will show how to implement in fluke a metric that is quite common in Personalized Federated Learning. In particular, a common technique for evaluating the local model is to use a balanced test set and weight the accuracy according to the number of samples of each class. Intuitively, the weighted accuracy takes the per-class sample counts into account, lowering the penalty when an error occurs on a less frequent class. This metric is taken from Tackling Data Heterogeneity in Federated Learning with Class Prototypes, Dai et al., and it is defined as follows:

\[\text{WAcc}_i = \frac{\sum_{(x, y) \in \mathcal{D}_{test}} \alpha_i(y) \cdot \mathbb{1}[\hat{y} = y]}{\sum_{(x, y) \in \mathcal{D}_{test}} \alpha_i(y)}\]

where \(\hat{y}\) is the prediction of client \(i\)'s model on \(x\), and \(\alpha_i(\cdot)\) is a positive-valued function, defined as the probability that a sample belongs to class \(y\) in the \(i^{th}\) client. Notice that for \(\alpha_i(\cdot) = 1\) we obtain the traditional accuracy. In this tutorial, we will interpret \(\alpha_i(\cdot)\) as the proportion of the local samples of class \(y\) over all the samples of that client. Specifically, we calculate this coefficient for client \(i\) and class \(y_j\) as \(\alpha_i(y_j) = \frac{Y^i_j}{Y^i}\), where \(Y^i_j\) is the number of samples of class \(y_j\) in the training set of client \(i\), and \(Y^i\) is the total number of training examples of client \(i\).
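As a quick illustration of this definition, the coefficients \(\alpha_i\) can be computed from a label tensor with torch.bincount. The snippet below is a toy sketch with made-up labels (not part of fluke):

import torch

# Toy training labels for client i: 3 samples of class 0, 2 of class 1, 1 of class 2.
labels = torch.tensor([0, 0, 0, 1, 1, 2])

# alpha_i(y_j) = Y^i_j / Y^i: per-class counts over the total number of local samples.
alpha = torch.bincount(labels, minlength=3).float() / labels.numel()
print(alpha)  # tensor([0.5000, 0.3333, 0.1667])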
In our case, \(\mathcal{D}_{test}\) will be the dataset held by the server, that is (usually) the original test split of the dataset.
Implementing the server-side logic¶
Notice that server.evaluate is called in server.fit with only two arguments (the evaluator and the eligible clients). As a consequence, if we wanted server.evaluate to take the class weights into account, we would have to modify server.fit as well. However, this is too verbose. The most straightforward solution is to leave server.fit and the input arguments of server.evaluate untouched, and to modify the input arguments of evaluator.evaluate instead, adding the class weights.
import numpy as np
import torch

from fluke.data import FastDataLoader  # NOQA
from fluke.evaluation import Evaluator  # NOQA
from fluke.client import Client  # NOQA
from fluke.server import Server  # NOQA


class MyServer(Server):

    def evaluate(self,
                 evaluator: Evaluator,
                 test_set: FastDataLoader) -> dict[str, float]:
        if test_set is not None:
            # The server has no local class distribution, so it passes
            # uniform weights (one per class) to the evaluator.
            return evaluator.evaluate(self.rounds + 1,
                                      self.model,
                                      test_set,
                                      device=self.device,
                                      weights=torch.ones(evaluator.n_classes))
        return {}
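Since the server passes uniform weights, the global (server-side) weighted accuracy is expected to coincide with the plain accuracy; we will use this property as a sanity check when implementing the metric below.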
Implementing the client-side logic¶
Following the same logic as for the server, we modify the input arguments of evaluator.evaluate instead of touching the whole client.local_update and the inputs of client.evaluate.
from typing import Any

import torch
from torch.nn import Module

from fluke.utils import OptimizerConfigurator  # NOQA


class MyClient(Client):

    def __init__(self,
                 index: int,
                 train_set: FastDataLoader,
                 test_set: FastDataLoader,
                 optimizer_cfg: OptimizerConfigurator,
                 loss_fn: Module,
                 local_epochs: int = 3,
                 **kwargs: dict[str, Any]):
        super().__init__(index,
                         train_set,
                         test_set,
                         optimizer_cfg,
                         loss_fn,
                         local_epochs,
                         **kwargs)
        # alpha_i(y_j) = Y^i_j / Y^i: per-class sample counts on the local
        # training set, normalized by the local training set size.
        self.class_weights = torch.bincount(self.train_set.tensors[1]).float()
        self.class_weights /= self.train_set.size

    def evaluate(self,
                 evaluator: Evaluator,
                 test_set: FastDataLoader) -> dict[str, float]:
        if self.model is not None:
            return evaluator.evaluate(self._last_round,
                                      self.model,
                                      test_set,
                                      device=self.device,
                                      weights=self.class_weights)
        return {}
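A caveat worth flagging (our note, not from fluke): torch.bincount returns a vector of length max(label) + 1, so if a client's shard contains no samples of the highest class indices, class_weights ends up shorter than the number of classes. With the IID split used later this cannot happen, but for strongly non-IID splits you may want torch.bincount(..., minlength=n_classes), which requires making the number of classes available to the client.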
Implementing the metric¶
In the following, we start from the classification metric present in eval.py and modify it to take the per-class weights into account. As a sanity check, in the global evaluation accuracy and weighted_accuracy will be the same.
from typing import Iterable, Optional, Union

import numpy as np
import torch
from torchmetrics import Accuracy

from fluke.utils import clear_cuda_cache  # NOQA


class WeightedClassificationEval(Evaluator):

    def __init__(self, eval_every: int, n_classes: int):
        super().__init__(eval_every=eval_every)
        self.n_classes: int = n_classes

    def evaluate(self,
                 round: int,
                 model: torch.nn.Module,
                 eval_data_loader: Union[FastDataLoader,
                                         Iterable[FastDataLoader]],
                 loss_fn: Optional[torch.nn.Module] = None,
                 device: torch.device = torch.device("cpu"),
                 weights: Optional[torch.Tensor] = None) -> dict:

        if round % self.eval_every != 0:
            return {}

        if (model is None) or (eval_data_loader is None):
            return {}

        if weights is None:
            # No weights provided: fall back to uniform weights, i.e.,
            # the weighted accuracy reduces to the plain accuracy.
            weights = torch.ones(self.n_classes)
        weights = weights.to(device)

        model.eval()
        model.to(device)
        task = "multiclass"  # if self.n_classes >= 2 else "binary"
        accs, losses, weight_accs = [], [], []

        if not isinstance(eval_data_loader, list):
            eval_data_loader = [eval_data_loader]

        for data_loader in eval_data_loader:
            accuracy = Accuracy(task=task,
                                num_classes=self.n_classes,
                                top_k=1,
                                average="micro")
            # Per-loader accumulators (reset for each data loader).
            true_weights, pred_weights, mask = [], [], []
            loss = 0
            for X, y in data_loader:
                X, y = X.to(device), y.to(device)
                with torch.no_grad():
                    y_hat = model(X)
                    if loss_fn is not None:
                        loss += loss_fn(y_hat, y).item()
                y_hat = torch.max(y_hat, dim=1)[1]
                # alpha_i(y) for the true and the predicted labels, plus a
                # mask that is True where the prediction is correct.
                true_weights.append(weights[y])
                pred_weights.append(weights[y_hat])
                mask.append(torch.eq(y, y_hat))
                accuracy.update(y_hat.cpu(), y.cpu())

            true_weights = torch.cat(true_weights, dim=0)
            pred_weights = torch.cat(pred_weights, dim=0)
            mask = torch.cat(mask, dim=0)
            # Weighted accuracy: sum of the weights of the correctly
            # classified samples over the sum of the weights of all samples.
            pred_weights = pred_weights * mask
            weight_accs.append(pred_weights.sum().item() / true_weights.sum().item())

            accs.append(accuracy.compute().item())
            losses.append(loss / len(data_loader))

        model.cpu()
        clear_cuda_cache()

        result = {
            "accuracy": np.round(sum(accs) / len(accs), 5).item(),
            "weighted_accuracy": np.round(sum(weight_accs) / len(weight_accs), 5).item(),
        }
        if loss_fn is not None:
            result["loss"] = np.round(sum(losses) / len(losses), 5).item()

        return result

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(eval_every={self.eval_every}" + \
               f", n_classes={self.n_classes})[accuracy, weighted_accuracy]"

    def __repr__(self) -> str:
        return str(self)
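To verify the sanity check without running a full federation, we can call the evaluator directly on a toy model with uniform weights. The snippet below is a self-contained sketch (the toy model and data are made up): it relies on the fact that any list of (X, y) batches duck-types as a data loader for our evaluate method, rather than building a real FastDataLoader.

from torch import nn

# A toy (untrained) linear model and a single random batch of 32 samples.
toy_model = nn.Linear(4, 10)
X, y = torch.randn(32, 4), torch.randint(0, 10, (32,))

evaluator = WeightedClassificationEval(eval_every=1, n_classes=10)
res = evaluator.evaluate(round=1,
                         model=toy_model,
                         eval_data_loader=[[(X, y)]],
                         weights=torch.ones(10))

# With uniform weights, the weighted accuracy reduces to the plain accuracy.
assert res["accuracy"] == res["weighted_accuracy"]
print(res)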
Creating the new federated algorithm¶
Now we are ready to test our metric! We only need to define a new algorithm that uses our custom client and server classes.
from fluke.algorithms import CentralizedFL


class MyFLAlgorithm(CentralizedFL):

    def get_client_class(self) -> Client:
        return MyClient

    def get_server_class(self) -> Server:
        return MyServer
Ready to test the new federated algorithm¶
The rest of the code is similar to the First steps with fluke API tutorial. We just replace ClassificationEval with our custom evaluator.
from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke import DDict
from fluke.utils.log import Log
from fluke import FlukeENV

env = FlukeENV()
env.set_seed(42)  # we set a seed for reproducibility
env.set_device("cpu")  # we use the CPU for this example
# we set the evaluation configuration
env.set_eval_cfg(DDict(pre_fit=True, post_fit=True))

# we set the evaluator to be used by both the server and the clients
env.set_evaluator(WeightedClassificationEval(eval_every=1, n_classes=10))

dataset = Datasets.get("mnist", path="./data")
splitter = DataSplitter(dataset=dataset, distribution="iid")

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001),
    scheduler=DDict(
        gamma=1,
        step_size=1)
)

# we put together the hyperparameters for the algorithm
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model="MNIST_2NN")
Here is where the new federated algorithm comes into play.
algorithm = MyFLAlgorithm(n_clients=10,  # 10 clients in the federation
                          data_splitter=splitter,
                          hyper_params=hyperparams)

logger = Log()
algorithm.set_callbacks(logger)
Now we just need to run it!
algorithm.run(n_rounds=100, eligible_perc=0.5)