fluke.evaluation
This module contains the definition of the evaluation classes used to perform the evaluation of the model client-side and server-side.
Evaluator | This class is the base class for all evaluators in fluke.
ClassificationEval | Evaluate a PyTorch model for classification.
interface fluke.evaluation.Evaluator
- class fluke.evaluation.Evaluator(eval_every: int = 1)
Bases: ABC
This class is the base class for all evaluators in fluke. An evaluator object should be used to perform the evaluation of a (federated) model.
- Parameters:
eval_every (int) – The evaluation frequency expressed as the number of rounds between two evaluations. Defaults to 1, i.e., evaluate the model at each round.
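As a rough illustration of how an evaluation frequency such as eval_every could gate evaluation, the sketch below uses a hypothetical helper, should_evaluate. It is not part of fluke; the assumptions that rounds are numbered from 1 and that evaluation happens on multiples of eval_every are made here for illustration only.

```python
def should_evaluate(round: int, eval_every: int = 1) -> bool:
    """Hypothetical helper: True when `round` is a multiple of `eval_every`.

    Assumes rounds are numbered from 1; this is not fluke's own code.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be a positive integer")
    return round % eval_every == 0


# With eval_every=3, only some of the first ten rounds trigger an evaluation.
evaluated_rounds = [r for r in range(1, 11) if should_evaluate(r, eval_every=3)]
print(evaluated_rounds)
```

With the default eval_every=1, the same check passes at every round, which matches the documented default behaviour.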
- abstract evaluate(round: int, model: Module, eval_data_loader: FastDataLoader, loss_fn: Module | None, **kwargs: dict[str, Any]) → dict
Evaluate the model.
- Parameters:
round (int) – The current round.
model (Module) – The model to evaluate.
eval_data_loader (FastDataLoader) – The data loader to use for evaluation.
loss_fn (torch.nn.Module, optional) – The loss function to use for evaluation.
**kwargs – Additional keyword arguments.
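A custom evaluator would subclass the abstract base class and implement evaluate. The following is a minimal sketch under stated assumptions: EvaluatorSketch is a stand-in that mirrors the documented signature so the example runs without fluke or PyTorch installed, the model is any callable, and the data loader is a plain iterable of (inputs, targets) batches. MeanAbsoluteErrorEval, the "mae" key, and the choice to return an empty dictionary on skipped rounds are all hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]


class EvaluatorSketch(ABC):
    """Stand-in mirroring the documented Evaluator interface (not fluke's class)."""

    def __init__(self, eval_every: int = 1):
        self.eval_every = eval_every

    @abstractmethod
    def evaluate(self, round: int, model: Optional[Callable],
                 eval_data_loader: Optional[Iterable[Batch]],
                 loss_fn: Optional[Callable] = None, **kwargs: Any) -> dict:
        """Return a dictionary of metrics for the given model and data."""


class MeanAbsoluteErrorEval(EvaluatorSketch):
    """Hypothetical evaluator reporting the mean absolute error."""

    def evaluate(self, round, model, eval_data_loader, loss_fn=None, **kwargs):
        # Assumption: nothing to evaluate, or not an evaluation round -> empty dict.
        if model is None or eval_data_loader is None:
            return {}
        if round % self.eval_every != 0:
            return {}
        abs_errors = []
        for inputs, targets in eval_data_loader:
            predictions = [model(x) for x in inputs]
            abs_errors.extend(abs(p - t) for p, t in zip(predictions, targets))
        if not abs_errors:
            return {}
        return {"mae": sum(abs_errors) / len(abs_errors)}


# Toy data: the "model" doubles its input; one target is off by 1.
batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [7.0])]
evaluator = MeanAbsoluteErrorEval(eval_every=2)
result = evaluator.evaluate(round=2, model=lambda x: 2 * x, eval_data_loader=batches)
skipped = evaluator.evaluate(round=3, model=lambda x: 2 * x, eval_data_loader=batches)
print(result, skipped)
```

Because evaluate is abstract, instantiating the base class directly raises a TypeError; only concrete subclasses can be created.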
class fluke.evaluation.ClassificationEval
- class fluke.evaluation.ClassificationEval(eval_every: int, n_classes: int)
Bases: Evaluator
Evaluate a PyTorch model for classification. The metrics computed are accuracy, precision, recall, f1 and the loss according to the provided loss function loss_fn when calling the method evaluate. Metrics are computed both in a micro and macro fashion.
- evaluate(round: int, model: Module, eval_data_loader: FastDataLoader | Iterable[FastDataLoader], loss_fn: Module | None = None, device: device = device(type='cpu')) → dict
Evaluate the model. The metrics computed are accuracy, precision, recall, f1 and the loss according to the provided loss function loss_fn. Metrics are computed both in a micro and macro fashion.
- Parameters:
round (int) – The current round.
model (torch.nn.Module) – The model to evaluate. If None, the method returns an empty dictionary.
eval_data_loader (Union[FastDataLoader, Iterable[FastDataLoader]]) – The data loader(s) to use for evaluation. If None, the method returns an empty dictionary.
loss_fn (torch.nn.Module, optional) – The loss function to use for evaluation.
device (torch.device, optional) – The device to use for evaluation. Defaults to “cpu”.
- Returns:
A dictionary containing the computed metrics.
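To make the micro versus macro distinction concrete, here is a small pure-Python sketch for precision and recall. It is not fluke's implementation: the function name, the dictionary keys, and the convention of scoring a class with no predictions as 0.0 are assumptions made for illustration.

```python
from collections import Counter


def micro_macro_precision_recall(y_true, y_pred, n_classes):
    """Illustrative micro/macro precision and recall for integer class labels."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p, but it was wrong
            fn[t] += 1  # true class t was missed

    def ratio(num, den):
        # Assumption: an undefined ratio (zero denominator) counts as 0.0.
        return num / den if den else 0.0

    classes = range(n_classes)
    # Macro: compute the metric per class, then take the unweighted mean.
    macro_precision = sum(ratio(tp[c], tp[c] + fp[c]) for c in classes) / n_classes
    macro_recall = sum(ratio(tp[c], tp[c] + fn[c]) for c in classes) / n_classes

    # Micro: pool the counts over all classes, then compute the metric once.
    total_tp, total_fp, total_fn = sum(tp.values()), sum(fp.values()), sum(fn.values())
    return {
        "micro_precision": ratio(total_tp, total_tp + total_fp),
        "micro_recall": ratio(total_tp, total_tp + total_fn),
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
    }


# Imbalanced toy example: class 1 is never predicted.
metrics = micro_macro_precision_recall(
    y_true=[0, 0, 0, 0, 1, 2],
    y_pred=[0, 0, 0, 0, 0, 2],
    n_classes=3,
)
print(metrics)
```

In this example the macro scores are pulled down by the class that is never predicted, while the micro scores are dominated by the majority class, which is why reporting both is informative on imbalanced data.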
- Return type: