# Source code for fluke

"""
The ``fluke`` module is the entry module of the ``fluke`` framework. It defines the generic
classes used by the other modules.
"""
from __future__ import annotations

import random
import re
import warnings
from typing import TYPE_CHECKING, Any, Iterable, Union

import numpy as np
import torch
from rich.console import Group
from rich.progress import Live, Progress

if TYPE_CHECKING:
    from .evaluation import Evaluator


# Names exported by ``from fluke import *``.
# The lowercase entries are presumably submodules of the package (they are not defined in this
# file -- confirm against the package layout); the last four entries are the classes defined
# in this module.
__all__ = [
    'algorithms',
    'client',
    'comm',
    'data',
    'evaluation',
    'get',
    'nets',
    'run',
    'server',
    'utils',
    'DDict',
    'GlobalSettings',
    'ObserverSubject',
    'Singleton'
]


class Singleton(type):
    """Metaclass that limits every class built with it to a single instance.

    The first call to such a class constructs the instance and caches it. Every later call
    returns the cached object; the arguments of those later calls are ignored.

    Example:
        .. code-block:: python
            :linenos:

            class MyClass(metaclass=Singleton):
                pass

            a = MyClass()
            b = MyClass()
            print(a is b)  # True
    """

    # Cache shared by all classes created through this metaclass, keyed by the class object.
    _instances = {}

    def __call__(cls, *args, **kwargs: dict[str, Any]):
        """Return the cached instance of ``cls``, creating it on the first call.

        Args:
            *args: Positional arguments forwarded to the constructor on the first call.
            **kwargs: Keyword arguments forwarded to the constructor on the first call.

        Returns:
            The unique instance of ``cls``.
        """
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        # First call for this class: build the instance through the regular ``type.__call__``.
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
class DDict(dict):
    """A dictionary that can be accessed with dot notation recursively.

    Values that are dictionaries are converted to ``DDict`` when they are stored through the
    constructor or :meth:`update`, so nested entries can be reached with chained attribute
    access. Reading an attribute that is not a key returns ``None`` because ``__getattr__``
    is ``dict.get``.

    Example:
        .. code-block:: python
            :linenos:

            d = DDict(a=1, b=2, c={'d': 3, 'e': 4})
            print(d.a)  # 1
            print(d.b)  # 2
            print(d.c.d)  # 3
            print(d.c.e)  # 4
    """

    # Attribute reads/writes are delegated to the mapping itself.
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__

    def __init__(self, *args: dict, **kwargs: dict[str, Any]):
        """Build the ``DDict`` from dictionaries and/or keyword arguments.

        Args:
            *args (dict): Dictionaries whose key-value pairs are copied in, in order.
            **kwargs: Additional key-value pairs.
        """
        self.update(*args, **kwargs)

    def update(self, *args: dict, **kwargs: dict[str, Any]):
        """Update the ``DDict`` with the specified key-value pairs.

        Dictionary values are wrapped into a new ``DDict`` (recursively). Positional arguments
        that are not dictionaries are skipped with a warning.

        Args:
            *args (dict): Dictionary with the key-value pairs.
            **kwargs: The key-value pairs.

        Example:
            .. code-block:: python
                :linenos:

                d = DDict(a=1)
                print(d)  # {'a': 1}
                d.update(b=2, c=3)
                print(d)  # {'a': 1, 'b': 2, 'c': 3}
        """
        # Positional dictionaries are applied first, keyword arguments last.
        for source in (*args, kwargs):
            if not isinstance(source, dict):
                warnings.warn(f"Argument {source} is not a dictionary and will be ignored.")
                continue
            for key, value in source.items():
                # The nested mapping is passed positionally (not unpacked with ``**``) so that
                # non-string keys and keys such as ``'self'`` do not raise a TypeError.
                self[key] = DDict(value) if isinstance(value, dict) else value

    def exclude(self, *keys: str):
        """Create a new ``DDict`` excluding the specified keys.

        Args:
            *keys: The keys to be excluded.

        Returns:
            DDict: The new DDict.

        Example:
            .. code-block:: python
                :linenos:

                d = DDict(a=1, b=2, c=3)
                e = d.exclude('b', 'c')
                print(e)  # {'a': 1}
        """
        # Passed positionally for the same reason as in ``update``: keys need not be strings.
        return DDict({k: v for k, v in self.items() if k not in keys})
class ObserverSubject():
    """Subject side of the observer pattern.

    The subject is the observed object, hence it keeps the list of its observers. This class
    only manages that list (:meth:`attach` / :meth:`detach`); how the observers are notified
    is left to the subclasses.

    Example:
        .. code-block:: python
            :linenos:

            class MySubject(ObserverSubject):
                def __init__(self):
                    super().__init__()
                    self._data = 0

                @property
                def data(self):
                    return self._data

                @data.setter
                def data(self, value):
                    self._data = value
                    self.notify()

                def notify(self):
                    for observer in self._observers:
                        observer.update()

            class MyObserver:
                def __init__(self, subject):
                    subject.attach(self)

                def update(self):
                    print("Data changed.")

            subject = MySubject()
            observer = MyObserver(subject)
            subject.data = 1  # "Data changed."
    """

    def __init__(self):
        # Observers in attachment order, without duplicates.
        self._observers: list[Any] = []

    def attach(self, observer: Union[Any, Iterable[Any]]):
        """Attach one or more observers.

        ``None`` is ignored, and observers that are already attached are not added twice.
        Only ``list``, ``tuple`` and ``set`` arguments are treated as collections of
        observers; any other object is attached as a single observer.

        Args:
            observer (Union[Any, Iterable[Any]]): The observer or a list of observers.
        """
        if observer is None:
            return

        candidates = observer if isinstance(observer, (list, tuple, set)) else [observer]
        for candidate in candidates:
            if candidate in self._observers:
                continue
            self._observers.append(candidate)

    def detach(self, observer: Any):
        """Detach an observer.

        Nothing happens if the observer is not attached.

        Args:
            observer (Any): The observer to be detached.
        """
        if observer in self._observers:
            self._observers.remove(observer)
class GlobalSettings(metaclass=Singleton):
    """Global settings for ``fluke``.

    This class is a singleton that holds the global settings for ``fluke``. The settings include:

    - The device (``"cpu"``, ``"cuda[:N]"``, ``"auto"``, ``"mps"``);
    - The ``seed`` for reproducibility;
    - The evaluation configuration;
    - The saving settings;
    - The progress bars for the federated learning process, clients and the server;
    - The live renderer, which is used to render the progress bars.
    """

    # general settings
    # Stored as a ``torch.device`` so that ``get_device`` returns the documented type even
    # before ``set_device`` / ``auto_device`` has been called.
    _device: torch.device = torch.device('cpu')
    _seed: int = 0

    # saving settings
    _save_path: str | None = None
    _save_every: int = 0
    _global_only: bool = False

    # evaluation settings
    _evaluator: Evaluator = None
    _eval_cfg: dict = {
        "pre_fit": False,
        "post_fit": False,
        "locals": False,
        "server": True
    }

    # progress bars
    _progress_FL: Progress = None
    _progress_clients: Progress = None
    _progress_server: Progress = None
    _live_renderer: Live = None

    def __init__(self):
        """Create the three progress bars and the live renderer that groups them."""
        super().__init__()
        self._progress_FL: Progress = Progress(transient=True)
        self._progress_clients: Progress = Progress(transient=True)
        self._progress_server: Progress = Progress(transient=True)
        self._live_renderer: Live = Live(Group(self._progress_FL,
                                               self._progress_clients,
                                               self._progress_server))

    def get_seed(self) -> int:
        """Get the seed.

        Returns:
            int: The seed.
        """
        return self._seed

    def get_eval_cfg(self) -> DDict:
        """Get the evaluation configuration.

        Returns:
            DDict: A ``DDict`` built from the current evaluation configuration.
        """
        return DDict(self._eval_cfg)

    def set_eval_cfg(self, cfg: DDict) -> None:
        """Set the evaluation configuration.

        The entries of ``cfg`` are merged into the current configuration: keys that are not
        in ``cfg`` keep their current value.

        Args:
            cfg (DDict): The evaluation configuration.
        """
        for key, value in cfg.items():
            self._eval_cfg[key] = value

    def get_evaluator(self) -> Evaluator:
        """Get the evaluator.

        Returns:
            Evaluator: The evaluator.
        """
        return self._evaluator

    def set_evaluator(self, evaluator: Evaluator) -> None:
        """Set the evaluator.

        Args:
            evaluator (Evaluator): The evaluator.
        """
        self._evaluator = evaluator

    def set_seed(self, seed: int) -> None:
        """Set seed for reproducibility.

        The seed is used to set the random seed for the following libraries: ``torch``,
        ``torch.cuda``, ``numpy``, ``random``. cuDNN benchmarking is disabled and cuDNN
        deterministic mode is enabled as well.

        Args:
            seed (int): The seed.
        """
        self._seed = seed
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    def auto_device(self) -> torch.device:
        """Set device to ``cuda`` or ``mps`` if available, otherwise ``cpu``.

        Returns:
            torch.device: The device.
        """
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self._device = torch.device('mps')
        else:
            self._device = torch.device('cpu')
        return self._device

    def set_device(self, device: str) -> torch.device:
        """Set the device.

        The device can be ``cpu``, ``auto``, ``mps``, ``cuda`` or ``cuda:N``, where ``N`` is
        the GPU index.

        Args:
            device (str): The device as string.

        Returns:
            torch.device: The selected device as torch.device.

        Raises:
            AssertionError: If ``device`` is not one of the accepted values.
        """
        # NOTE(review): ``assert`` is stripped under ``python -O``; it is kept so that callers
        # relying on ``AssertionError`` are not broken.
        assert device in ['cpu', 'auto', 'mps', 'cuda'] or re.match(r'^cuda:\d+$', device), \
            f"Invalid device {device}."

        if device == "auto":
            # Use this very instance instead of going through the singleton constructor again.
            return self.auto_device()

        if device.startswith('cuda') and ":" in device:
            idx = int(device.split(":")[1])
            self._device = torch.device("cuda", idx)
        else:
            self._device = torch.device(device)
        return self._device

    def get_device(self) -> torch.device:
        """Get the current device.

        Returns:
            torch.device: The device.
        """
        return self._device

    def get_progress_bar(self, progress_type: str) -> Progress:
        """Get the progress bar.

        The possible progress bar types are:

        - ``FL``: The progress bar for the federated learning process.
        - ``clients``: The progress bar for the clients.
        - ``server``: The progress bar for the server.

        Args:
            progress_type (str): The type of progress bar.

        Returns:
            Progress: The progress bar.

        Raises:
            ValueError: If the progress bar type is invalid.
        """
        if progress_type == 'FL':
            return self._progress_FL
        if progress_type == 'clients':
            return self._progress_clients
        if progress_type == 'server':
            return self._progress_server
        raise ValueError(f'Invalid type of progress bar type {progress_type}.')

    def get_live_renderer(self) -> Live:
        """Get the live renderer.

        The live renderer is used to render the progress bars.

        Returns:
            Live: The live renderer.
        """
        return self._live_renderer

    def get_save_options(self) -> tuple[str, int, bool]:
        """Get the save options.

        Returns:
            tuple[str, int, bool]: The save path, the save frequency and the global only flag.
        """
        return self._save_path, self._save_every, self._global_only

    def set_save_options(self,
                         path: str | None = None,
                         save_every: int | None = None,
                         global_only: bool | None = None) -> None:
        """Set the save options.

        Arguments left to ``None`` keep their current value.

        Args:
            path (str): The path to save the checkpoints.
            save_every (int): The frequency of saving the checkpoints.
            global_only (bool): If ``True``, only the global model is saved.
        """
        if path is not None:
            self._save_path = path
        if save_every is not None:
            self._save_every = save_every
        if global_only is not None:
            self._global_only = global_only