# Source code for fluke.algorithms.fedbabu

"""Implementation of the Federated Averaging with Body Aggregation and Body Update [FedBABU22]_
algorithm.

References:
    .. [FedBABU22] Jaehoon Oh, Sangmook Kim, Se-Young Yun. FedBABU: Towards Enhanced Representation
       for Federated Image Classification. In ICLR (2022). URL: https://arxiv.org/abs/2106.06042
"""
import sys
from typing import Any, Iterable

from rich.progress import Progress
from torch.nn import Module

sys.path.append(".")
sys.path.append("..")

from .. import FlukeENV  # NOQA
from ..algorithms import PersonalizedFL  # NOQA
from ..client import Client  # NOQA
from ..comm import Message  # NOQA
from ..data import FastDataLoader  # NOQA
from ..nets import EncoderHeadNet  # NOQA
from ..server import Server  # NOQA
from ..utils import OptimizerConfigurator, clear_cuda_cache  # NOQA
from ..utils.model import safe_load_state_dict  # NOQA


class FedBABUClient(Client):
    """Client of the FedBABU [FedBABU22]_ algorithm.

    During federated training only the body (``model.encoder``) is exchanged
    with the server, while the head is kept frozen (see
    :meth:`receive_model`). At the end of the federated phase the client
    fine-tunes its personalized model locally, updating the part of the model
    selected by ``mode``.

    Args:
        index (int): Index of the client.
        model (EncoderHeadNet): The model to personalize (encoder + head).
        train_set (FastDataLoader): Local training data loader.
        test_set (FastDataLoader): Local test data loader.
        optimizer_cfg (OptimizerConfigurator): Optimizer configurator.
        loss_fn (Module): Loss function.
        local_epochs (int): Number of local training epochs.
        mode (str): Which part of the model to fine-tune; one of
            ``"head"``, ``"body"`` or ``"full"``.
        fine_tuning_epochs (int): Number of local fine-tuning epochs.
        clipping (float): Gradient clipping threshold (0 disables clipping).

    Raises:
        ValueError: If ``mode`` is not one of ``"head"``, ``"body"`` or
            ``"full"``.
    """

    def __init__(self,
                 index: int,
                 model: EncoderHeadNet,
                 train_set: FastDataLoader,
                 test_set: FastDataLoader,
                 optimizer_cfg: OptimizerConfigurator,
                 loss_fn: Module,
                 local_epochs: int,
                 mode: str,
                 fine_tuning_epochs: int,
                 clipping: float = 0,
                 **kwargs: dict[str, Any]):
        # Explicit validation instead of `assert`, which is stripped when
        # Python runs with -O and would silently accept an invalid mode.
        if mode not in ("head", "body", "full"):
            raise ValueError(f"Invalid mode '{mode}': must be one of 'head', 'body' or 'full'.")
        super().__init__(index=index,
                         train_set=train_set,
                         test_set=test_set,
                         optimizer_cfg=optimizer_cfg,
                         loss_fn=loss_fn,
                         local_epochs=local_epochs,
                         fine_tuning_epochs=fine_tuning_epochs,
                         clipping=clipping,
                         **kwargs)
        self.hyper_params.update(mode=mode)
        self._load_from_cache()
        self.model = model
        self._save_to_cache()

    def send_model(self) -> None:
        """Send only the encoder (body) of the local model to the server."""
        self.channel.send(Message(self.model.encoder, "model", self, inmemory=True),
                          self.server)

    def receive_model(self) -> None:
        """Receive the global encoder from the server and load it into the
        local model's encoder. The head is frozen so that federated training
        only updates the body."""
        msg = self.channel.receive(self, self.server, msg_type="model")
        safe_load_state_dict(self.model.encoder, msg.payload.state_dict())
        # Deactivate gradient for the head: FedBABU trains the body only
        # during the federated phase.
        for param in self.model.head.parameters():
            param.requires_grad = False

    def _set_trainable_params(self) -> None:
        """Freeze/unfreeze encoder and head according to ``hyper_params.mode``."""
        mode = self.hyper_params.mode
        train_encoder = mode in ("full", "body")
        train_head = mode in ("full", "head")
        for param in self.model.encoder.parameters():
            param.requires_grad = train_encoder
        for param in self.model.head.parameters():
            param.requires_grad = train_head

    def fine_tune(self) -> None:
        """Fine-tune the personalized model.

        Trains the part of the model selected by ``hyper_params.mode`` for
        ``hyper_params.fine_tuning_epochs`` epochs on the local training set,
        then moves the model back to CPU and caches it.
        """
        self._load_from_cache()
        self._set_trainable_params()
        self.model.train()
        self.model.to(self.device)
        self.pers_optimizer, self.pers_scheduler = self._optimizer_cfg(self.model)
        for _ in range(self.hyper_params.fine_tuning_epochs):
            for X, y in self.train_set:
                X, y = X.to(self.device), y.to(self.device)
                self.pers_optimizer.zero_grad()
                y_hat = self.model(X)
                loss = self.hyper_params.loss_fn(y_hat, y)
                loss.backward()
                self.pers_optimizer.step()
            # Scheduler is stepped once per fine-tuning epoch.
            self.pers_scheduler.step()
        self.model.cpu()
        clear_cuda_cache()
        self._save_to_cache()

    def finalize(self) -> None:
        """Evaluate the (fine-tuned) personalized model on the local test set
        and notify the observers with the resulting metrics."""
        self._load_from_cache()
        metrics = self.evaluate(FlukeENV().get_evaluator(), self.test_set)
        if metrics:
            self._notify_evaluation(-1, "post-fit", metrics)
        self._save_to_cache()
class FedBABUServer(Server):
    """Server of the FedBABU [FedBABU22]_ algorithm.

    The server performs no server-side evaluation (the received ``test_set``
    is discarded and ``None`` is passed to the superclass). At the end of the
    federated training, :meth:`finalize` asks every client that participated
    in at least one round to fine-tune and evaluate its personalized model.

    Args:
        model (Module): The global model.
        test_set (FastDataLoader): Ignored — the server keeps no test set.
        clients (Iterable[Client]): The clients participating in the
            federation.
        weighted (bool): Whether the aggregation is weighted by the clients'
            dataset sizes.
    """

    def __init__(self,
                 model: Module,
                 test_set: FastDataLoader,  # not used
                 clients: Iterable[Client],
                 weighted: bool = False):
        super().__init__(model=model, test_set=None, clients=clients, weighted=weighted)

    def finalize(self) -> None:
        """Trigger local fine-tuning and final evaluation on all clients that
        took part in the training, then notify the observers."""
        with Progress(transient=True) as progress:
            task = progress.add_task("[cyan]Client's fine tuning",
                                     total=len(self._participants))
            # Only clients that actually participated in the training are
            # fine-tuned and evaluated.
            clients_ft = [client for client in self.clients
                          if client.index in self._participants]
            for client in clients_ft:
                client.fine_tune()
                client.finalize()
                progress.update(task, advance=1)
        self._notify_finalize()
class FedBABU(PersonalizedFL):
    """FedBABU [FedBABU22]_ federated learning algorithm.

    Personalized federated learning algorithm that wires together
    :class:`FedBABUClient` and :class:`FedBABUServer`.
    """

    def get_client_class(self) -> type[Client]:
        # Returns the client *class* (not an instance), hence `type[Client]`.
        return FedBABUClient

    def get_server_class(self) -> type[Server]:
        # Returns the server *class* (not an instance), hence `type[Server]`.
        return FedBABUServer