fluke.utils
This module contains utility functions and classes used in fluke.
Submodules
- Logging utilities.
- Utilities for PyTorch model manipulation.
Classes included in fluke.utils
- Configuration – Fluke configuration class.
- OptimizerConfigurator – This class is used to configure the optimizer and the learning rate scheduler.
- ClientObserver – Client observer interface.
- ServerObserver – Server observer interface.
Functions included in fluke.utils
- clear_cache – Clear the CUDA cache.
- get_class_from_str – Get a class from its name.
- get_class_from_qualified_name – Get a class from its fully qualified name.
- get_full_classname – Get the fully qualified name of a class.
- get_loss – Get a loss function from its name.
- get_model – Get a model from its name.
- Get an optimizer from its name.
- get_scheduler – Get a learning rate scheduler from its name.
- import_module_from_str – Import a module from its name.
- plot_distribution – Plot the distribution of classes for each client.
Classes
class fluke.utils.Configuration
- class fluke.utils.Configuration(config_exp_path: str, config_alg_path: str)
Fluke configuration class. This class is used to store the configuration of an experiment. The configuration must adhere to a specific structure. The configuration is validated when the class is instantiated.
- Parameters:
- Raises:
ValueError – If the configuration is not valid.
- property client: DDict
Get quick access to the client hyperparameters.
- Returns:
The client hyperparameters.
- Return type:
DDict
class fluke.utils.OptimizerConfigurator
- class fluke.utils.OptimizerConfigurator(optimizer_cfg: DDict | dict, scheduler_cfg: DDict | dict | None = None)
This class is used to configure the optimizer and the learning rate scheduler.
- __call__(model: Module, filter_fun: callable | None = None, **override_kwargs)
Creates the optimizer and the scheduler.
- Parameters:
model (Module) – The model whose parameters will be optimized.
filter_fun (callable) – A function that takes the model and returns the set of parameters that the optimizer will consider.
override_kwargs (dict) – The optimizer’s keyword arguments to override the default ones.
- Returns:
The optimizer and the scheduler.
- Return type:
tuple[Optimizer, StepLR]
interface fluke.utils.ClientObserver
- class fluke.utils.ClientObserver
Bases: object
Client observer interface. This interface is used to observe the client during the federated learning process. For example, it can be used to log the performance of the local model, as is done by the Log class.
- client_evaluation(round: int, client_id: int, phase: Literal['pre-fit', 'post-fit'], evals: dict[str, float], **kwargs: dict[str, Any])
This method is called when the client evaluates the local model. The evaluation can be done before (‘pre-fit’) and/or after (‘post-fit’) the local training process. The ‘pre-fit’ evaluation is usually the evaluation of the global model on the local test set, and the ‘post-fit’ evaluation is the evaluation of the just-updated local model on the local test set.
- end_fit(round: int, client_id: int, model: Module, loss: float, **kwargs: dict[str, Any])
This method is called when the client ends the local training process.
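As an illustration of the two callbacks above, here is a minimal sketch of an observer that simply records what it receives. It is a duck-typed stand-in so that it runs without fluke or PyTorch installed; the class name `RecordingClientObserver` and its attributes are hypothetical, and in a real project the class would presumably subclass `fluke.utils.ClientObserver`.

```python
from typing import Any, Dict, Tuple


class RecordingClientObserver:
    """Hypothetical observer that stores the values passed to its callbacks."""

    def __init__(self) -> None:
        # (round, client_id, phase) -> evaluation metrics
        self.evaluations: Dict[Tuple[int, int, str], Dict[str, float]] = {}
        # (round, client_id) -> loss reported at the end of local training
        self.losses: Dict[Tuple[int, int], float] = {}

    def client_evaluation(self, round: int, client_id: int, phase: str,
                          evals: Dict[str, float], **kwargs: Any) -> None:
        # phase is expected to be 'pre-fit' or 'post-fit'
        self.evaluations[(round, client_id, phase)] = dict(evals)

    def end_fit(self, round: int, client_id: int, model: Any,
                loss: float, **kwargs: Any) -> None:
        # The model is ignored here; a real observer could inspect it
        self.losses[(round, client_id)] = loss


# Simulate the calls a client might make during round 1
observer = RecordingClientObserver()
observer.client_evaluation(1, 0, "pre-fit", {"accuracy": 0.5})
observer.end_fit(1, 0, model=None, loss=0.7)
observer.client_evaluation(1, 0, "post-fit", {"accuracy": 0.75})
```

Keying the stored metrics by round, client and phase keeps the ‘pre-fit’ and ‘post-fit’ results of the same round side by side, which makes it easy to measure how much local training changed the model.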
interface fluke.utils.ServerObserver
- class fluke.utils.ServerObserver
Bases: object
Server observer interface. This interface is used to observe the server during the federated learning process. For example, it can be used to log the performance of the global model and the communication costs, as is done by the Log class.
- end_round(round: int) -> None
This method is called when a round ends.
- Parameters:
round (int) – The round number.
- finished(round: int) -> None
This method is called when the federated learning process has ended.
- Parameters:
round (int) – The last round number.
- selected_clients(round: int, clients: Iterable) -> None
This method is called when the clients have been selected for the current round.
- Parameters:
round (int) – The round number.
clients (Iterable) – The clients selected for the current round.
- server_evaluation(round: int, type: Literal['global', 'locals'], evals: dict[str, float] | dict[int, dict[str, float]], **kwargs: dict[str, Any]) -> None
This method is called when the server evaluates the global or the local models on its test set.
- Parameters:
round (int) – The round number.
type (Literal['global', 'locals']) – The type of evaluation. If ‘global’, the evaluation is done on the global model. If ‘locals’, the evaluation is done on the local models of the clients on the test set of the server.
evals (dict[str, float] | dict[int, dict[str, float]]) – The evaluation metrics. In case of ‘global’ evaluation, it is a dictionary with the evaluation metrics. In case of ‘locals’ evaluation, it is a dictionary of dictionaries where the keys are the client IDs and the values are the evaluation metrics.
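The two shapes of `evals` described above can be handled as in the following sketch. It is a standalone, duck-typed stand-in rather than a subclass of `fluke.utils.ServerObserver`, so it runs without fluke installed; the class name `SummaryServerObserver`, its attributes, and the choice to average the ‘locals’ metrics are all assumptions made for illustration.

```python
from typing import Any, Dict, Iterable, List, Optional


class SummaryServerObserver:
    """Hypothetical observer that records server-side events."""

    def __init__(self) -> None:
        self.selected: Dict[int, List[Any]] = {}
        self.global_evals: Dict[int, Dict[str, float]] = {}
        self.local_means: Dict[int, Dict[str, float]] = {}
        self.rounds_done = 0
        self.last_round: Optional[int] = None

    def selected_clients(self, round: int, clients: Iterable) -> None:
        self.selected[round] = list(clients)

    def server_evaluation(self, round: int, type: str, evals: dict,
                          **kwargs: Any) -> None:
        if type == "global":
            # evals maps metric name -> value
            self.global_evals[round] = dict(evals)
        elif type == "locals":
            # evals maps client ID -> {metric name -> value}.
            # Average each metric over clients (assumes all clients
            # report the same set of metrics).
            sums: Dict[str, float] = {}
            for metrics in evals.values():
                for name, value in metrics.items():
                    sums[name] = sums.get(name, 0.0) + value
            count = len(evals)
            self.local_means[round] = (
                {name: total / count for name, total in sums.items()}
                if count else {}
            )
        else:
            raise ValueError(f"Unknown evaluation type: {type!r}")

    def end_round(self, round: int) -> None:
        self.rounds_done += 1

    def finished(self, round: int) -> None:
        self.last_round = round


# Simulate one round as the server might report it
server_obs = SummaryServerObserver()
server_obs.selected_clients(1, [0, 1])
server_obs.server_evaluation(1, "global", {"accuracy": 0.8})
server_obs.server_evaluation(1, "locals", {0: {"accuracy": 0.5}, 1: {"accuracy": 1.0}})
server_obs.end_round(1)
server_obs.finished(1)
```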
Functions
- fluke.utils.clear_cache(ipc: bool = False)
Clear the CUDA cache. This function should be used to free the GPU memory after the training process has ended. It is usually used after the local training of the clients.
- Parameters:
ipc (bool, optional) – Whether to force collecting GPU memory after it has been released by CUDA IPC.
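For readers unfamiliar with what clearing the CUDA cache involves, the sketch below shows one plausible shape for such a helper using PyTorch's public `torch.cuda.empty_cache()` and `torch.cuda.ipc_collect()` calls. It is not taken from fluke's source: the function name `clear_cache_sketch`, the boolean return value, and the guarded import (which lets the snippet run on machines without PyTorch) are assumptions.

```python
import gc

try:
    import torch  # optional, so the sketch also runs where PyTorch is missing
except ImportError:
    torch = None


def clear_cache_sketch(ipc: bool = False) -> bool:
    """Free cached GPU memory; return True if CUDA calls were actually made."""
    # Drop unreachable Python objects first so their tensors can be released
    gc.collect()
    if torch is None or not torch.cuda.is_available():
        return False
    # Return cached, unused blocks held by PyTorch's allocator to the driver
    torch.cuda.empty_cache()
    if ipc:
        # Also collect memory that was released by CUDA IPC
        torch.cuda.ipc_collect()
    return True


cleared = clear_cache_sketch(ipc=True)
```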
- fluke.utils.get_class_from_str(module_name: str, class_name: str) -> Any
Get a class from its name. This function is used to get a class from its name and the name of the module where it is defined. It is used to dynamically import classes.
- fluke.utils.get_class_from_qualified_name(qualname: str) -> Any
Get a class from its fully qualified name.
- Parameters:
qualname (str) – The fully qualified name of the class.
- Returns:
The class.
- Return type:
Any
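The dynamic-import technique behind both `get_class_from_str` and `get_class_from_qualified_name` can be sketched with the standard library alone. The function names below are hypothetical stand-ins, and the splitting rule (everything before the last dot is the module) is an assumption; it would not resolve nested classes.

```python
import collections
import importlib
from typing import Any


def class_from_str(module_name: str, class_name: str) -> Any:
    """Import module_name and return the attribute called class_name."""
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def class_from_qualified_name(qualname: str) -> Any:
    """Split 'package.module.Class' at the last dot and resolve the class."""
    module_name, _, class_name = qualname.rpartition(".")
    if not module_name:
        raise ValueError(f"{qualname!r} is not a fully qualified name")
    return class_from_str(module_name, class_name)


# Both lookups resolve to the same class object
by_parts = class_from_str("collections", "OrderedDict")
by_qualname = class_from_qualified_name("collections.OrderedDict")
print(by_parts is collections.OrderedDict and by_qualname is collections.OrderedDict)  # True
```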
- fluke.utils.get_full_classname(classtype: type) -> str
Get the fully qualified name of a class.
- Parameters:
classtype (type) – The class.
- Returns:
The fully qualified name of the class.
- Return type:
str
Example
Let A be a class defined in the module fluke.utils:
    # This is the content of the file fluke/utils.py
    class A:
        pass

    get_full_classname(A)  # 'fluke.utils.A'
If the class is defined in the __main__ module, then:
    from fluke.utils import get_full_classname

    if __name__ == "__main__":
        class B:
            pass

        get_full_classname(B)  # '__main__.B'
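A fully qualified name of this kind can be derived from a class's `__module__` and `__qualname__` attributes. The sketch below uses a hypothetical helper, `full_classname`, and standard-library classes so that it runs without fluke; whether fluke builds the name in exactly this way is an assumption.

```python
from fractions import Fraction


def full_classname(classtype: type) -> str:
    """Join the defining module and the class's qualified name with a dot."""
    return f"{classtype.__module__}.{classtype.__qualname__}"


print(full_classname(Fraction))  # fractions.Fraction
print(full_classname(int))       # builtins.int
```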
- fluke.utils.get_loss(lname: str) -> Module
Get a loss function from its name. The supported loss functions are the ones defined in the torch.nn module.
- Parameters:
lname (str) – The name of the loss function.
- Returns:
The loss function.
- Return type:
Module
- fluke.utils.get_model(mname: str, **kwargs: dict[str, Any]) -> Module
Get a model from its name. This function dynamically imports a torch model class by name and instantiates it with the given keyword arguments. If mname is not a fully qualified name, the model is assumed to be defined in the fluke.nets module.
- Parameters:
mname (str) – The name of the model.
**kwargs – The keyword arguments to pass to the model’s constructor.
- Returns:
The model.
- Return type:
Module
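The fallback rule described above (a bare name is looked up in a default module, a fully qualified name in its own module) can be sketched as follows. The helper `build_from_name` and its `default_module` parameter are hypothetical, and standard-library classes stand in for models so that the snippet runs without fluke or PyTorch; for `get_model`, the default module would be `fluke.nets` according to the description above.

```python
import importlib
from typing import Any


def build_from_name(name: str, default_module: str, **kwargs: Any) -> Any:
    """Instantiate a class by name, using default_module for bare names."""
    module_name, _, class_name = name.rpartition(".")
    # An empty module_name means the name was not fully qualified
    module = importlib.import_module(module_name or default_module)
    return getattr(module, class_name)(**kwargs)


# Bare name: resolved inside the default module
half = build_from_name("Fraction", "fractions", numerator=1, denominator=2)

# Fully qualified name: the default module is ignored
delay = build_from_name("datetime.timedelta", "fractions", seconds=90)
```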
- fluke.utils.get_scheduler(sname: str) -> type[LRScheduler]
Get a learning rate scheduler from its name. It is used to dynamically import learning rate schedulers. The supported schedulers are the ones defined in the torch.optim.lr_scheduler module.
- Parameters:
sname (str) – The name of the scheduler.
- Returns:
The learning rate scheduler.
- Return type:
type[LRScheduler]
- fluke.utils.import_module_from_str(name: str) -> Any
Import a module from its name.
- Parameters:
name (str) – The name of the module.
- Returns:
The module.
- Return type:
Any
- fluke.utils.plot_distribution(clients: list[Client], train: bool = True, type: str = 'ball') -> None
Plot the distribution of classes for each client. The plot can be a scatter plot, a heatmap, or a bar plot.
The scatter plot (type='ball') shows filled circles whose size is proportional to the number of examples of a class.
The heatmap (type='mat') shows a matrix where the rows represent the classes and the columns represent the clients, with a color intensity proportional to the number of examples of a class.
The bar plot (type='bar') shows a stacked bar plot where the height of the bars is proportional to the number of examples of a class.
Warning
If the number of clients is greater than 30, the type is automatically switched to 'bar' for better visualization.
- Parameters: