"""
The :mod:`fluke` module is the entry module of the :mod:`fluke` framework. It defines the generic
classes used by the other modules.
"""
from __future__ import annotations
import random
import re
import shutil
import uuid
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Iterable, Union
import numpy as np
import torch
from diskcache import Cache
from rich.console import Group
from rich.progress import Live, Progress
if TYPE_CHECKING:
from .evaluation import Evaluator
# Names exported by ``from fluke import *``.
# The lower-case entries are not defined in this file -- presumably they are the sub-modules of
# the package (TODO confirm against the package layout). The CamelCase entries are the classes
# defined further down in this module.
__all__ = [
'algorithms',
'client',
'comm',
'data',
'evaluation',
'get',
'nets',
'run',
'server',
'utils',
'DDict',
'FlukeCache',
'FlukeENV',
'ObserverSubject',
'Singleton'
]
[docs]
class Singleton(type):
    """Metaclass that makes every class using it a singleton.

    The first call to such a class builds its instance and records it in a registry; every
    later call returns the recorded instance (the arguments of later calls are ignored).

    Example:
        .. code-block:: python
            :linenos:

            class MyClass(metaclass=Singleton):
                pass

            a = MyClass()
            b = MyClass()
            print(a is b)  # True
    """

    # Registry shared by all the singleton classes: maps each class to its unique instance.
    _instances = {}

    def __call__(cls, *args: Any, **kwargs: Any):
        """Return the unique instance of ``cls``, creating it on the first call.

        Args:
            *args: Positional arguments forwarded to the constructor on the first call only.
            **kwargs: Keyword arguments forwarded to the constructor on the first call only.

        Returns:
            The unique instance of ``cls``.
        """
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]

    def clear(cls):
        """Forget the instance of ``cls`` so that the next call creates a fresh one.

        Only the entry of ``cls`` is removed from the shared registry: the other singleton
        classes keep their instances. Calling this method when no instance exists is a no-op.
        """
        # Remove the entry instead of rebinding ``cls._instances``: a rebinding would create a
        # per-class attribute shadowing the shared registry and would leave the old instance
        # referenced there.
        cls._instances.pop(cls, None)
[docs]
class DDict(dict):
    """A dictionary that can be accessed with dot notation recursively.

    Important:
        The :class:`DDict` is a subclass of the built-in :class:`dict` class and it behaves like a
        dictionary. However, the keys must be strings to be reachable with the dot notation.

    Example:
        .. code-block:: python
            :linenos:

            d = DDict(a=1, b=2, c={'d': 3, 'e': 4})
            print(d.a)    # 1
            print(d.b)    # 2
            print(d.c.d)  # 3
            print(d.c.e)  # 4
    """

    class _MissingAttribute(AttributeError, KeyError):
        """Error raised when the dot notation is used on a key that does not exist.

        It is an :class:`AttributeError` so that ``hasattr``, ``getattr`` with a default and
        :func:`copy.deepcopy` work as expected, and it is also a :class:`KeyError` so that code
        catching ``KeyError`` around an attribute access keeps working.
        """

    def __getattr__(self, name: str) -> Any:
        # Only called when the regular attribute lookup fails: fall back to the dictionary keys.
        try:
            return self[name]
        except KeyError:
            raise DDict._MissingAttribute(name) from None

    __setattr__ = dict.__setitem__

    def __init__(self, *args: dict, **kwargs: Any):
        self.update(*args, **kwargs)

    def update(self, *args: dict, **kwargs: Any):
        """Update the :class:`DDict` with the specified key-value pairs.

        Values that are dictionaries are converted (copied) into :class:`DDict`. Positional
        arguments that are not dictionaries are ignored and a warning is issued.

        Args:
            *args (dict): Dictionary with the key-value pairs.
            **kwargs: The key-value pairs.

        Example:
            .. code-block:: python
                :linenos:

                d = DDict(a=1)
                print(d)  # {'a': 1}
                d.update(b=2, c=3)
                print(d)  # {'a': 1, 'b': 2, 'c': 3}
        """
        for source in (*args, kwargs):
            if not isinstance(source, dict):
                warnings.warn(f"Argument {source} is not a dictionary and will be ignored.")
                continue
            for key, value in source.items():
                # The nested dictionary is passed positionally (not unpacked with **) so that
                # keys such as 'self' or non-string keys do not break the conversion.
                self[key] = DDict(value) if isinstance(value, dict) else value

    def exclude(self, *keys: str):
        """Create a new :class:`DDict` excluding the specified keys.

        Args:
            *keys: The keys to be excluded.

        Returns:
            DDict: The new DDict.

        Example:
            .. code-block:: python
                :linenos:

                d = DDict(a=1, b=2, c=3)
                e = d.exclude('b', 'c')
                print(e)  # {'a': 1}
        """
        return DDict({k: v for k, v in self.items() if k not in keys})

    def match(self, other: "DDict", full: bool = True) -> bool:
        """Check if the two :class:`DDict` match.

        Args:
            other (DDict): The other :class:`DDict`.
            full (bool): If ``True``, the two :class:`DDict` must match exactly. If ``False``, the
                `other` :class:`DDict` must be a subset of the current :class:`DDict`.

        Returns:
            bool: Whether the two :class:`DDict` match.
        """
        if full:
            return self == other
        for key, theirs in other.items():
            if key not in self:
                return False
            mine = self[key]
            if isinstance(mine, DDict):
                # A nested DDict can only be matched by another dictionary.
                if not isinstance(theirs, dict) or not mine.match(theirs, False):
                    return False
            elif not (mine == theirs):
                return False
        return True

    def diff(self, other: "DDict") -> "DDict":
        """Get the difference between two :class:`DDict`.

        The result contains the entries of ``other`` that are missing from, or different in,
        the current :class:`DDict`. Nested dictionaries are compared recursively.

        Args:
            other (DDict): The other :class:`DDict`.

        Returns:
            DDict: The difference between the two :class:`DDict`.

        Example:
            .. code-block:: python
                :linenos:

                d = DDict(a=1, b=2, c=3)
                e = DDict(a=1, b=3, c=4)
                print(d.diff(e))  # {'b': 3, 'c': 4}
        """
        result = DDict()
        for key, value in other.items():
            if key not in self:
                result[key] = value
                continue
            current = self[key]
            if isinstance(current, DDict) and isinstance(value, dict):
                nested = current.diff(value)
                if nested:
                    result[key] = nested
            elif value != current:
                # Also covers a nested DDict replaced by a non-dictionary value.
                result[key] = value
        return result

    def __getstate__(self) -> dict:
        return self.__dict__

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
[docs]
class ObserverSubject():
    """Subject side of the observer pattern.

    The subject is the object being observed: it keeps the list of its observers. This class
    only manages that list (see :meth:`attach` and :meth:`detach`); the methods that notify the
    observers are left to the subclasses.

    Example:
        .. code-block:: python
            :linenos:

            class MySubject(ObserverSubject):
                def __init__(self):
                    super().__init__()
                    self._data = 0

                @property
                def data(self):
                    return self._data

                @data.setter
                def data(self, value):
                    self._data = value
                    self.notify()

                def notify(self):
                    for observer in self._observers:
                        observer.update()

            class MyObserver:
                def __init__(self, subject):
                    subject.attach(self)

                def update(self):
                    print("Data changed.")

            subject = MySubject()
            observer = MyObserver(subject)
            subject.data = 1  # "Data changed."
    """

    def __init__(self):
        self._observers: list[Any] = []

    def attach(self, observer: Union[Any, Iterable[Any]]):
        """Register one or more observers.

        Observers that are already registered are not added twice and ``None`` is ignored.
        A ``list``, ``tuple`` or ``set`` is treated as a collection of observers; any other
        object is treated as a single observer.

        Args:
            observer (Union[Any, Iterable[Any]]): The observer or a collection of observers.
        """
        if observer is None:
            return
        candidates = observer if isinstance(observer, (list, tuple, set)) else [observer]
        for candidate in candidates:
            if candidate in self._observers:
                continue
            self._observers.append(candidate)

    def detach(self, observer: Any):
        """Unregister an observer. Nothing happens if the observer is not registered.

        Args:
            observer (Any): The observer to be detached.
        """
        if observer in self._observers:
            self._observers.remove(observer)
[docs]
class FlukeENV(metaclass=Singleton):
    """Environment class for the :mod:`fluke` framework.

    This class is a singleton and it contains environment settings that are used by the other
    classes. The environment includes:

    - The device (``"cpu"``, ``"cuda[:N]"``, ``"auto"``, ``"mps"``);
    - The ``seed`` for reproducibility;
    - If the models are stored in memory or on disk (when not in use);
    - The evaluation configuration;
    - The saving settings;
    - The progress bars for the federated learning process, clients and the server;
    - The live renderer, which is used to render the progress bars.

    Note:
        The attributes below are class-level defaults. A setter stores the new value on the
        instance, hence only the values that have been explicitly set are part of the pickled
        state. The ``rich`` objects are never pickled.
    """

    # general settings
    _device: torch.device = torch.device('cpu')
    _seed: int = 0
    _inmemory: bool = True
    _cache: FlukeCache = None

    # saving settings
    _save_path: str = None
    _save_every: int = 0
    _global_only: bool = False
    _temp_path: str = None

    # evaluation settings
    _evaluator: Evaluator = None
    # Default configuration: it is never modified in place (see set_eval_cfg).
    _eval_cfg: dict = {
        "pre_fit": False,
        "post_fit": False,
        "locals": False,
        "server": True
    }

    # progress bars (None until __init__ runs, e.g., on an instance restored by pickle)
    _rich_progress_FL: Progress = None
    _rich_progress_clients: Progress = None
    _rich_progress_server: Progress = None
    _rich_live_renderer: Live = None

    def __init__(self):
        super().__init__()
        self._rich_progress_FL: Progress = Progress(transient=True)
        self._rich_progress_clients: Progress = Progress(transient=True)
        self._rich_progress_server: Progress = Progress(transient=True)
        self._rich_live_renderer: Live = Live(Group(self._rich_progress_FL,
                                                    self._rich_progress_clients,
                                                    self._rich_progress_server))

    def get_seed(self) -> int:
        """Get the seed.

        Returns:
            int: The seed.
        """
        return self._seed

    def get_eval_cfg(self) -> DDict:
        """Get the evaluation configuration.

        Returns:
            DDict: A copy of the evaluation configuration.
        """
        return DDict(self._eval_cfg)

    def set_eval_cfg(self, cfg: DDict) -> None:
        """Set the evaluation configuration.

        The entries of ``cfg`` override the current ones; the entries that are not in ``cfg``
        keep their current value.

        Args:
            cfg (DDict): The evaluation configuration.
        """
        # Build a new dictionary bound to the instance instead of updating the current one in
        # place: the current one may be the class-level default, which must stay untouched.
        self._eval_cfg = {**self._eval_cfg, **cfg}

    def get_evaluator(self) -> Evaluator:
        """Get the evaluator.

        Returns:
            Evaluator: The evaluator.
        """
        return self._evaluator

    def set_evaluator(self, evaluator: Evaluator) -> None:
        """Set the evaluator.

        Args:
            evaluator (Evaluator): The evaluator.
        """
        self._evaluator = evaluator

    def set_seed(self, seed: int) -> None:
        """Set seed for reproducibility. The seed is used to set the random seed for the following
        libraries: ``torch``, ``torch.cuda``, ``numpy``, ``random``. The cuDNN benchmark mode is
        also disabled and the cuDNN deterministic mode is enabled.

        Args:
            seed (int): The seed.
        """
        self._seed = seed
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    def auto_device(self) -> torch.device:
        """Set device to ``cuda`` or ``mps`` if available, otherwise ``cpu``.

        Returns:
            torch.device: The device.
        """
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self._device = torch.device('mps')
        else:
            self._device = torch.device('cpu')
        return self._device

    def set_device(self, device: str) -> torch.device:
        """Set the device. The device can be ``cpu``, ``auto``, ``mps``, ``cuda`` or ``cuda:N``,
        where ``N`` is the GPU index.

        Args:
            device (str): The device as string.

        Returns:
            torch.device: The selected device as torch.device.

        Raises:
            AssertionError: If the device string is not valid.
        """
        is_valid = device in ('cpu', 'auto', 'mps', 'cuda') or re.fullmatch(r'cuda:\d+', device)
        if not is_valid:
            # Raised explicitly (same exception type of an ``assert``) so that the validation
            # is not stripped when Python runs with the -O flag.
            raise AssertionError(f"Invalid device {device}.")

        if device == "auto":
            return self.auto_device()

        if device.startswith('cuda') and ":" in device:
            idx = int(device.split(":")[1])
            self._device = torch.device("cuda", idx)
        else:
            self._device = torch.device(device)
        return self._device

    def get_device(self) -> torch.device:
        """Get the current device.

        Returns:
            torch.device: The device.
        """
        return self._device

    def get_progress_bar(self, progress_type: str) -> Progress:
        """Get the progress bar.
        The possible progress bar types are:

        - ``FL``: The progress bar for the federated learning process.
        - ``clients``: The progress bar for the clients.
        - ``server``: The progress bar for the server.

        Args:
            progress_type (str): The type of progress bar.

        Returns:
            Progress: The progress bar.

        Raises:
            ValueError: If the progress bar type is invalid.
        """
        if progress_type == 'FL':
            return self._rich_progress_FL
        if progress_type == 'clients':
            return self._rich_progress_clients
        if progress_type == 'server':
            return self._rich_progress_server
        raise ValueError(f'Invalid type of progress bar type {progress_type}.')

    def get_live_renderer(self) -> Live:
        """Get the live renderer. The live renderer is used to render the progress bars.

        Returns:
            Live: The live renderer.
        """
        return self._rich_live_renderer

    def get_save_options(self) -> tuple[str, int, bool]:
        """Get the save options.

        Returns:
            tuple[str, int, bool]: The save path, the save frequency and the global only flag.
        """
        return self._save_path, self._save_every, self._global_only

    def set_save_options(self,
                         path: str | None = None,
                         save_every: int | None = None,
                         global_only: bool | None = None) -> None:
        """Set the save options. The options that are ``None`` are left unchanged.

        Args:
            path (str): The path to save the checkpoints.
            save_every (int): The frequency of saving the checkpoints.
            global_only (bool): If ``True``, only the global model is saved.
        """
        if path is not None:
            self._save_path = path
        if save_every is not None:
            self._save_every = save_every
        if global_only is not None:
            self._global_only = global_only

    def set_inmemory(self, inmemory: bool) -> None:
        """Set if the data is stored in memory.

        Args:
            inmemory (bool): If ``True``, the data is stored in memory, otherwise it is stored on
                disk.
        """
        self._inmemory = inmemory

    def get_cache(self) -> FlukeCache:
        """Get the cache.

        Returns:
            FlukeCache: The cache, or ``None`` if no cache is open.
        """
        return self._cache

    def open_cache(self, path: str) -> None:
        """Open the cache at the specified path if the ``inmemory`` flag is ``False``.

        If a cache is already open, it is left untouched and a warning is issued.

        Note:
            The full path to the cache is ``tmp/path`` where ``path`` is the specified path.
            We suggest to use as path the UUID of the experiment.

        Args:
            path (str): The path to the cache.
        """
        if self._cache is not None:
            warnings.warn("Cache already open.")
        elif not self._inmemory:
            self._cache = FlukeCache(path)

    def close_cache(self) -> None:
        """Close the cache. Nothing happens if no cache is open."""
        if self._cache is not None:
            self._cache.close()
            self._cache = None

    def is_inmemory(self) -> bool:
        """Check if the data is stored in memory.

        Returns:
            bool: If ``True``, the data is stored in memory, otherwise it is stored on disk.
        """
        return self._inmemory

    def force_close(self) -> None:
        """Force close the progress bars and the live renderer.

        All the tasks are removed from the three progress bars, then the live renderer is
        refreshed and stopped. Progress bars and renderer that have not been created (i.e.,
        that are ``None``) are skipped.
        """
        bars = (self._rich_progress_FL, self._rich_progress_clients, self._rich_progress_server)
        for progress in bars:
            if progress is None:
                continue
            # Collect the ids first: the list of tasks changes while the tasks are removed.
            for task_id in [task.id for task in progress.tasks]:
                progress.remove_task(task_id)
        if self._rich_live_renderer is not None:
            self._rich_live_renderer.refresh()
            self._rich_live_renderer.stop()

    def __getstate__(self) -> dict:
        # The rich objects (progress bars and live renderer) are left out of the pickled state.
        return {k: v for k, v in self.__dict__.items() if not k.startswith('_rich')}

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
[docs]
class FlukeCache():
    """A cache class that can store data on disk.

    Each stored object is wrapped in a reference counter and saved in the underlying disk cache
    under a unique identifier. The user-facing keys are mapped (in memory) to these identifiers,
    so that several keys can share the same stored object. An object is removed from the disk
    cache when its last key is popped, deleted or overwritten.
    """

    class _ObjectRef():
        """A reference to an object in the cache.

        The reference is a unique identifier that is used to store and retrieve the object from the
        cache.
        """

        def __init__(self):
            self._id = str(uuid.uuid4().hex)

        @property
        def id(self) -> str:
            """Get the unique identifier of the reference.

            Returns:
                str: The unique identifier.
            """
            return self._id

    class _RefCounter():
        """A reference counter for an object in the cache."""

        def __init__(self, value: Any, refs: int = 1):
            self._value = value
            self._refs = refs
            self._id = FlukeCache._ObjectRef()

        @property
        def id(self) -> "FlukeCache._ObjectRef":
            """Get the reference that identifies the counted object.

            Returns:
                FlukeCache._ObjectRef: The reference to the object.
            """
            return self._id

        @property
        def value(self) -> Any:
            """Get the value pointed by the reference.

            Returns:
                Any: The value.
            """
            return self._value

        @property
        def refs(self) -> int:
            """Get the number of references to the object in the cache.

            Returns:
                int: The number of references.
            """
            return self._refs

        def dec(self) -> "FlukeCache._RefCounter":
            """Decrement the number of references to the object in the cache.

            Returns:
                FlukeCache._RefCounter: The reference counter.
            """
            self._refs -= 1
            return self

        def inc(self) -> "FlukeCache._RefCounter":
            """Increment the number of references to the object in the cache.

            Returns:
                FlukeCache._RefCounter: The reference counter.
            """
            self._refs += 1
            return self

    def __init__(self, path: str, **kwargs):
        """Open the disk cache in the directory ``tmp/<path>``.

        Args:
            path (str): The name of the cache directory (created under ``tmp/``).
            **kwargs: Options forwarded to the underlying disk cache. If ``size_limit`` is not
                specified, it is set to ``2**34`` bytes.
        """
        kwargs.setdefault('size_limit', 2**34)
        self._cache: Cache = Cache(f"tmp/{path}", **kwargs)
        self._key2ref: dict[str, FlukeCache._ObjectRef] = {}

    def __getitem__(self, key: str):
        return self._cache[self._key2ref[key].id].value

    @property
    def cache_dir(self) -> str:
        """Get the cache directory.

        Returns:
            str: The cache directory.
        """
        return self._cache.directory

    def _release(self, ref: "FlukeCache._ObjectRef") -> "FlukeCache._RefCounter":
        """Drop one reference to the stored object, removing it when no reference is left.

        Args:
            ref (FlukeCache._ObjectRef): The reference to the stored object.

        Returns:
            FlukeCache._RefCounter: The updated counter (it still holds the object).

        Raises:
            KeyError: If the referenced object is not in the disk cache.
        """
        # The entry is read (i.e., deserialized) from disk only once.
        counter = self._cache[ref.id].dec()
        if counter.refs <= 0:
            self._cache.delete(ref.id)
        else:
            self._cache[ref.id] = counter
        return counter

    def get(self, key: str, default: Any = None):
        """Get the object identified by the key from the cache.
        If the object is not in the cache, the default value is returned.

        Note:
            The object is still in the cache after this operation.

        Args:
            key (str): The key of the object.
            default (Any, optional): The default value to return if the object is not in the cache.
                Defaults to None.

        Returns:
            Any: The object in the cache or the default value.
        """
        ref = self._key2ref.get(key)
        if ref is None:
            return default
        # A private sentinel tells "entry not on disk" apart from any value of ``default``.
        missing = object()
        entry = self._cache.get(ref.id, default=missing)
        return default if entry is missing else entry.value

    def push(self, key: str, value: Any) -> "FlukeCache._ObjectRef":
        """Push an object to the cache.

        Note:
            If the object that is pushed is already a cache reference, then the referenced object is
            already in the cache and its reference counter is incremented.
            If the key is already in use, the object it was pointing to loses one reference
            (and it is removed from the cache if no reference is left).

        Args:
            key (str): The key of the object.
            value (Any): The object to store in the cache.

        Returns:
            FlukeCache._ObjectRef: The reference to the object in the cache.

        Raises:
            AssertionError: If ``value`` is a reference to an object that is not in the cache.
        """
        previous = self._key2ref.get(key)

        if isinstance(value, FlukeCache._ObjectRef):
            if value.id not in self._cache:
                raise AssertionError(f"Reference {value.id} not in cache.")
            if previous is not None and previous.id == value.id:
                # The key already points to this object: counting it again would leak it.
                return value
            self._cache[value.id] = self._cache[value.id].inc()
            ref = value
        else:
            counter = FlukeCache._RefCounter(value)
            ref = counter.id
            self._cache[ref.id] = counter

        # The key is about to be overwritten: release the object it was pointing to.
        if previous is not None and previous.id in self._cache:
            self._release(previous)

        self._key2ref[key] = ref
        return ref

    def pop(self, key: str, copy: bool = True) -> Any:
        """Pop an object from the cache given its key.
        If the key is not in the cache, ``None`` is returned.

        Note:
            The object is removed from the disk only if no other key points to it.

        Args:
            key (str): The key of the object.
            copy (bool, optional): If ``True``, a copy of the object is returned.
                Defaults to ``True``.

        Returns:
            Any: The object in the cache or its copy.
        """
        ref = self._key2ref.pop(key, None)
        if ref is None:
            return None
        obj = self._release(ref).value
        return deepcopy(obj) if copy else obj

    def delete(self, key: str) -> None:
        """Remove an object from the cache without returning it.
        If the key is not in the cache, nothing happens.

        Args:
            key (str): The key of the object.
        """
        ref = self._key2ref.pop(key, None)
        if ref is not None:
            self._release(ref)

    def close(self) -> None:
        """Close the cache.

        The cache is emptied and closed, then its directory is removed (errors raised while
        removing the directory are ignored). Nothing happens if the cache is already closed.
        """
        if self._cache is not None:
            self._cache.clear()
            self._cache.close()
            try:
                shutil.rmtree(self._cache.directory)
            except OSError:  # Windows wonkiness
                pass
            self._cache = None
            self._key2ref = {}

    @property
    def occupied(self) -> int:
        """Get the number of different objects in the cache.

        Returns:
            int: The number of objects in the cache.
        """
        return sum(1 for _ in self._cache.iterkeys())

    def cleanup(self) -> None:
        """Clean up the cache by removing the objects that are not referenced.
        This operation should not be necessary if the cache is used correctly.
        """
        live = {ref.id for ref in self._key2ref.values()}
        # The stale identifiers are collected before deleting anything, so that the cache is
        # not modified while it is being iterated.
        stale = [stored for stored in self._cache.iterkeys() if stored not in live]
        for stored in stale:
            self._cache.delete(stored)