"""This module implements clients for decentralized federated learning (DFL) algorithms."""
from random import choice
from typing import Generator, Literal
import numpy as np
from torch.nn import Module
from ... import FlukeENV # NOQA
from ...client import Client # NOQA
from ...comm import TimedMessage # NOQA
from ...config import OptimizerConfigurator # NOQA
from ...data import DataLoader, FastDataLoader # NOQA
from ...utils.model import safe_load_state_dict, aggregate_models # NOQA
__all__ = ["AbstractDFLClient", "GossipClient"]
class AbstractDFLClient(Client):
    """Abstract client for decentralized federated learning (DFL).

    Args:
        index (int): The index of the client.
        model (Module): The model to be trained.
        neighbours (list[int]): The indices of the neighbouring clients.
        train_set (FastDataLoader | DataLoader): The training dataset.
        test_set (FastDataLoader | DataLoader): The testing dataset.
        optimizer_cfg (OptimizerConfigurator): The optimizer configuration.
        loss_fn (Module): The loss function.
        local_epochs (int): Number of local training epochs. Defaults to 3.
        fine_tuning_epochs (int): Number of fine-tuning epochs. Defaults to 0.
        clipping (float): Gradient clipping value. Defaults to 0 (no clipping).
        persistency (bool): Whether to persist the model across rounds. Defaults to True.
        activation_rate (float): Probability of the client being active in each round.
            Defaults to 1 (always active).
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        index: int,
        model: Module,
        neighbours: list[int],
        train_set: FastDataLoader | DataLoader,
        test_set: FastDataLoader | DataLoader,
        optimizer_cfg: OptimizerConfigurator,
        loss_fn: Module,
        local_epochs: int = 3,
        fine_tuning_epochs: int = 0,
        clipping: float = 0,
        persistency: bool = True,
        activation_rate: float = 1,
        **kwargs,
    ):
        super().__init__(
            index=index,
            train_set=train_set,
            test_set=test_set,
            optimizer_cfg=optimizer_cfg,
            loss_fn=loss_fn,
            local_epochs=local_epochs,
            fine_tuning_epochs=fine_tuning_epochs,
            clipping=clipping,
            persistency=persistency,
            **kwargs,
        )
        self.hyper_params.update(activation_rate=activation_rate)
        self.model = model
        # Indices of the clients this client may exchange models with.
        self.neighbours: list[int] = neighbours
        # Number of local updates performed so far; also acts as a logical clock
        # for message timestamps in concrete subclasses.
        self._num_updates: int = 0
        # Per-iteration cache of the random activation draw so that repeated
        # queries for the same iteration return a consistent answer.
        self._active_history: dict[int, bool] = {}

    def is_active(self, iter: int) -> bool:
        """Check if the client is active in the current iteration.

        The decision is drawn once per iteration (Bernoulli with probability
        ``activation_rate``) and cached, so calling this method multiple times
        with the same ``iter`` always yields the same result.

        Args:
            iter (int): The current iteration number.

        Returns:
            bool: True if the client is active, False otherwise.
        """
        if iter not in self._active_history:
            self._active_history[iter] = np.random.rand() < self.hyper_params.activation_rate
        return self._active_history[iter]

    def send_model(self, *args, **kwargs) -> None:
        """Send the local model to (a subset of) the neighbours.

        Must be implemented by concrete DFL clients.
        """
        raise NotImplementedError()

    def receive_model(self) -> Generator:
        """Receive and process the models sent by the neighbours.

        Must be implemented by concrete DFL clients.
        """
        raise NotImplementedError()

    def local_update(self, round: int) -> None:
        """Perform a local update for the given round.

        The client trains only on its very first update or when at least one
        model message is waiting on the channel; otherwise it merely gossips
        its current model again.

        Args:
            round (int): The current round number.
        """
        if self._num_updates == 0 or self.channel.has_messages(self.index, "model"):
            super().local_update(round)
            self._num_updates += 1
        elif self._num_updates > 0:
            # Nothing new was received: re-send the current model instead of
            # training on stale state.
            self.send_model()

    def _evaluate_and_notify(self, evaluator, phase: str) -> None:
        """Evaluate the current model on the local test set and notify observers.

        Args:
            evaluator: The evaluator obtained from the fluke environment.
            phase (str): Evaluation phase label (e.g. "pre-fit" or "post-fit").
        """
        metrics = self.evaluate(evaluator, self.test_set)
        if metrics:
            self.notify(
                "client_evaluation",
                round=-1,
                client_id=self.index,
                phase=phase,
                evals=metrics,
            )

    def finalize(self) -> None:
        """Finalize the client at the end of the federation.

        Restores the cached state, optionally evaluates the model before and
        after a final fit (according to the evaluation configuration), and
        caches the state back.
        """
        self._load_from_cache()
        evaluator = FlukeENV().get_evaluator()
        if FlukeENV().get_eval_cfg().pre_fit:
            self._evaluate_and_notify(evaluator, "pre-fit")
        if FlukeENV().get_eval_cfg().post_fit:
            self.fit()
            self._evaluate_and_notify(evaluator, "post-fit")
        self._save_to_cache()
class GossipClient(AbstractDFLClient):
    """A client for decentralized federated learning using gossip protocol.

    In the gossip protocol, clients send their model to a randomly chosen neighbour. Upon
    receiving models from neighbours, the client applies a specified policy to update its model.
    Possible policies include:

    - "random": Selects a random model from the received messages.
    - "aggregate": Aggregates all received models using the average.
    - "last": Uses the last received model based on the timestamp.
    - "best": Selects the model with the highest accuracy on the local test set.
      In case of ties, the last model processed in the order of receipt is chosen.

    Args:
        *args: Positional arguments passed to the parent class.
        policy (str): The policy to apply when receiving models from neighbours. Must be one of
            "random", "aggregate", "last", or "best". Defaults to "random".
        **kwargs: Keyword arguments passed to the parent class. If an ``eta`` key is present,
            it is used as the mixing coefficient for the "aggregate" policy (defaults to 1.0).

    Raises:
        AssertionError: If the provided policy is not one of the allowed values, or if the
            "best" policy is requested without a test set.
    """

    def __init__(
        self,
        *args,
        # BUGFIX: the default used to be the ``Literal`` type object itself
        # (``policy: str = Literal[...]``), which always failed the validity
        # assert below when ``policy`` was omitted. The documented default is
        # the string "random".
        policy: Literal["random", "aggregate", "last", "best"] = "random",
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert policy in ["random", "aggregate", "last", "best"], f"Invalid policy {policy}."
        if policy == "best":
            assert (
                self.test_set is not None
            ), "The 'best' policy requires a test set to evaluate model accuracy."
        self.hyper_params.update(policy=policy)
        # Mixing coefficient for model aggregation; 1.0 means a plain average.
        self.hyper_params.update(eta=kwargs.get("eta", 1.0))

    def send_model(self) -> None:
        """Send the current model to one uniformly chosen neighbour.

        The message is timestamped with the next update count so receivers can
        order models by recency.
        """
        recipient = choice(self.neighbours)
        self.channel.send(
            TimedMessage(
                self.model, "model", self.index, inmemory=True, timestamp=self._num_updates + 1
            ),
            recipient,
        )

    def _apply_policy(self, messages: list[TimedMessage]) -> tuple[Module, int]:
        """Select/combine the received models according to the configured policy.

        Args:
            messages (list[TimedMessage]): Non-empty list of received model messages.

        Returns:
            tuple[Module, int]: The selected (or aggregated) model and the timestamp
            associated with it.

        Raises:
            ValueError: If the configured policy is unknown (unreachable when the
                policy was validated in ``__init__``).
        """
        if self.hyper_params.policy == "random":
            msg = choice(messages)
            return msg.payload, msg.timestamp
        elif self.hyper_params.policy == "aggregate":
            # Keep only the most recent message from each sender (messages are
            # scanned in reverse order of receipt).
            senders_set = set()
            retained_messages = []
            for msg in messages[::-1]:
                if msg.sender not in senders_set:
                    senders_set.add(msg.sender)
                    retained_messages.append(msg)
            return (
                aggregate_models(
                    self.model,
                    (msg.payload for msg in retained_messages),
                    np.ones(len(retained_messages)) / len(retained_messages),
                    eta=self.hyper_params.eta,
                    inplace=False,
                ),
                max(msg.timestamp for msg in retained_messages),
            )
        elif self.hyper_params.policy == "last":
            # Ties on timestamp are broken in favour of the later message.
            last = 0
            last_msg = None
            for msg in messages:
                if msg.timestamp >= last:
                    last = msg.timestamp
                    last_msg = msg
            return last_msg.payload, last
        elif self.hyper_params.policy == "best":
            # Ties on accuracy are broken in favour of the later message.
            best_acc = -1
            best_msg = None
            evaluator = FlukeENV().get_evaluator()
            for msg in messages:
                acc = evaluator.evaluate(
                    msg.timestamp, msg.payload, self.test_set, device=self.device, loss_fn=None
                )["accuracy"]
                if acc >= best_acc:
                    best_acc = acc
                    best_msg = msg
            return best_msg.payload, best_msg.timestamp
        # Unreachable when the policy was validated in __init__; fail loudly
        # instead of silently returning None.
        raise ValueError(f"Invalid policy {self.hyper_params.policy}.")

    def receive_model(self) -> None:
        """Collect pending model messages and update the local model accordingly.

        Does nothing before the first local update. Otherwise, applies the
        configured policy to the received messages, adopts the selected model's
        weights, and synchronizes the local update counter with its timestamp.
        """
        if self._num_updates == 0:
            return
        messages = self.channel.receive_all(self.index, msg_type="model")
        selected_model, updates = self._apply_policy(messages)
        self._num_updates = updates
        safe_load_state_dict(self.model, selected_model.state_dict())