Source code for fluke.algorithms.fedrod

"""Implementation of the FedROD [FedROD22]_ algorithm.

References:
    .. [FedROD22] Hong-You Chen and Wei-Lun Chao. On Bridging Generic and Personalized Federated
       Learning for Image Classification. In ICLR (2022).
       URL: https://openreview.net/pdf?id=I1hQbx10Kxn
"""
import sys
from copy import deepcopy
from typing import Any, Literal

import numpy as np
import torch
from torch.nn import functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

sys.path.append(".")
sys.path.append("..")

from ..client import Client  # NOQA
from ..data import FastDataLoader  # NOQA
from ..evaluation import Evaluator  # NOQA
from ..nets import EncoderHeadNet  # NOQA
from ..utils import OptimizerConfigurator, clear_cuda_cache  # NOQA
from ..utils.model import ModOpt  # NOQA
from . import CentralizedFL  # NOQA

__all__ = [
    "RODModel",
    "BalancedSoftmaxLoss",
    "FedRODClient",
    "FedROD"
]


class RODModel(torch.nn.Module):
    """Model that combines a global model and a local head.

    During the forward pass, the global model, formed by an encoder and a head, is used to
    extract the representation of the input (using the encoder). The representation is then
    passed both to the local head and to the global head. The output of the local head is
    added to the output of the global head and returned as the final output.

    Args:
        global_model (EncoderHeadNet): Global model.
        local_head (torch.nn.Module): Local (personalized) head.
    """

    def __init__(self, global_model: EncoderHeadNet, local_head: torch.nn.Module):
        super().__init__()
        self.local_head = local_head
        self.global_model = global_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rep = self.global_model.encoder(x)
        out_g = self.global_model.head(rep)
        # The personalized head operates on a detached representation, so no gradient
        # ever flows back into the shared encoder through the local branch.
        out_p = self.local_head(rep.detach())
        return out_g.detach() + out_p
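

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of fluke): how ``RODModel`` composes the
# two branches. ``_ToyEncoderHead`` is a hypothetical stand-in for an
# ``EncoderHeadNet`` -- any module exposing ``.encoder`` and ``.head`` works.
# ---------------------------------------------------------------------------
def _example_rod_model() -> None:
    class _ToyEncoderHead(torch.nn.Module):
        def __init__(self, in_dim: int = 8, rep_dim: int = 4, n_classes: int = 3):
            super().__init__()
            self.encoder = torch.nn.Linear(in_dim, rep_dim)
            self.head = torch.nn.Linear(rep_dim, n_classes)

    global_model = _ToyEncoderHead()
    local_head = deepcopy(global_model.head)   # personalized head, one per client
    rod = RODModel(global_model, local_head)

    x = torch.randn(2, 8)
    out = rod(x)                               # generic logits + personalized correction
    assert out.shape == (2, 3)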


class BalancedSoftmaxLoss(torch.nn.Module):
    """Compute the Balanced Softmax Loss.

    Args:
        sample_per_class (torch.Tensor): Number of samples per class.
        reduction (Literal["mean", "sum"]): Reduction applied to the per-sample losses.
            Defaults to "mean".
    """

    def __init__(self,
                 sample_per_class: torch.Tensor,
                 reduction: Literal["mean", "sum"] = "mean"):
        super().__init__()
        self.sample_per_class = sample_per_class
        self.reduction = reduction

    def forward(self, y: torch.LongTensor, logits: torch.FloatTensor) -> torch.Tensor:
        # Shift each logit by the log prior of its class so that the softmax is
        # balanced with respect to the (possibly skewed) local label distribution.
        spc = self.sample_per_class.type_as(logits)
        spc = spc.unsqueeze(0).expand(logits.shape[0], -1)
        logits = logits + spc.log()
        return F.cross_entropy(input=logits, target=y, reduction=self.reduction)
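

# ---------------------------------------------------------------------------
# Worked example (illustrative, toy numbers made up for the sketch): the loss
# adds log class priors to the logits, so minority-class targets incur a larger
# penalty than under plain cross-entropy. Note the (y, logits) argument order.
# ---------------------------------------------------------------------------
def _example_balanced_softmax() -> None:
    sample_per_class = torch.tensor([90.0, 9.0, 1.0])   # heavily skewed toy counts
    bsm = BalancedSoftmaxLoss(sample_per_class)

    logits = torch.zeros(4, 3)                          # uninformative logits
    y = torch.tensor([0, 1, 2, 2])

    loss_bsm = bsm(y, logits)
    loss_ce = F.cross_entropy(logits, y)
    # With uniform logits, plain CE gives log(3) for every sample, while the
    # balanced loss is much larger for the two minority-class targets.
    print(f"balanced: {loss_bsm.item():.3f}  plain CE: {loss_ce.item():.3f}")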


class FedRODClient(Client):

    def __init__(self,
                 index: int,
                 train_set: FastDataLoader,
                 test_set: FastDataLoader,
                 optimizer_cfg: OptimizerConfigurator,
                 loss_fn: torch.nn.Module,
                 local_epochs: int,
                 fine_tuning_epochs: int = 0,
                 clipping: float = 0,
                 **kwargs: dict[str, Any]):
        super().__init__(index=index,
                         train_set=train_set,
                         test_set=test_set,
                         optimizer_cfg=optimizer_cfg,
                         loss_fn=loss_fn,
                         local_epochs=local_epochs,
                         fine_tuning_epochs=fine_tuning_epochs,
                         clipping=clipping,
                         **kwargs)
        # Per-class sample counts of the local training set, used by the
        # balanced softmax loss to correct for label skew.
        self.sample_per_class: torch.Tensor = torch.zeros(self.train_set.num_labels)
        uniq_val, uniq_count = np.unique(self.train_set.tensors[1], return_counts=True)
        for i, c in enumerate(uniq_val.tolist()):
            self.sample_per_class[c] = uniq_count[i]
        # Model/optimizer/scheduler triple for the personalized (local) head.
        self._inner_modopt: ModOpt = ModOpt()

    @property
    def inner_model(self) -> torch.nn.Module:
        return self._inner_modopt.model

    @inner_model.setter
    def inner_model(self, model: torch.nn.Module) -> None:
        self._inner_modopt.model = model

    @property
    def optimizer_head(self) -> Optimizer:
        return self._inner_modopt.optimizer

    @optimizer_head.setter
    def optimizer_head(self, optimizer: Optimizer) -> None:
        self._inner_modopt.optimizer = optimizer

    @property
    def scheduler_head(self) -> LRScheduler:
        return self._inner_modopt.scheduler

    @scheduler_head.setter
    def scheduler_head(self, scheduler: LRScheduler) -> None:
        self._inner_modopt.scheduler = scheduler

    def receive_model(self) -> None:
        super().receive_model()
        # The personalized head is initialized once from the global head and is
        # never sent back to the server.
        if self.inner_model is None:
            self.inner_model = deepcopy(self.model.head)

    def fit(self, override_local_epochs: int = 0) -> float:
        epochs: int = (override_local_epochs if override_local_epochs
                       else self.hyper_params.local_epochs)
        self.model.train()
        self.inner_model.train()
        self.model.to(self.device)
        self.inner_model.to(self.device)

        if self.optimizer is None:
            self.optimizer, self.scheduler = self._optimizer_cfg(self.model)
            self.optimizer_head, self.scheduler_head = self._optimizer_cfg(self.inner_model)

        bsm_loss = BalancedSoftmaxLoss(self.sample_per_class)
        running_loss = 0.0
        for _ in range(epochs):
            for X, y in self.train_set:
                X, y = X.to(self.device), y.to(self.device)

                # Step 1: update encoder + generic head with the balanced softmax loss.
                rep = self.model.encoder(X)
                out_g = self.model.head(rep)
                loss = bsm_loss(y, out_g)
                self.optimizer.zero_grad()
                loss.backward()
                self._clip_grads(self.model)
                self.optimizer.step()
                running_loss += loss.item()

                # Step 2: update the personalized head on the detached representation,
                # training the combined output with the standard (empirical) loss.
                out_p = self.inner_model(rep.detach())
                loss = self.hyper_params.loss_fn(out_g.detach() + out_p, y)
                self.optimizer_head.zero_grad()
                loss.backward()
                self._clip_grads(self.inner_model)
                self.optimizer_head.step()
            self.scheduler.step()
            self.scheduler_head.step()

        running_loss /= (epochs * len(self.train_set))
        self.model.cpu()
        self.inner_model.cpu()
        clear_cuda_cache()
        return running_loss

    def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str, float]:
        if test_set is not None and self.model is not None and self.inner_model is not None:
            # Personalized evaluation: generic model and local head are combined.
            return evaluator.evaluate(self._last_round,
                                      RODModel(self.model, self.inner_model),
                                      test_set,
                                      device=self.device)
        return {}
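

# ---------------------------------------------------------------------------
# Standalone sketch (illustrative, plain PyTorch) of the decoupled update that
# ``FedRODClient.fit`` performs on every batch: the shared encoder and generic
# head are trained with the balanced softmax loss, while the personalized head
# only ever sees a detached representation. All shapes and hyperparameters
# below are made up for the example.
# ---------------------------------------------------------------------------
def _example_decoupled_step() -> None:
    encoder = torch.nn.Linear(8, 4)
    g_head = torch.nn.Linear(4, 3)    # generic head (aggregated by the server)
    p_head = torch.nn.Linear(4, 3)    # personalized head (kept on the client)

    opt_global = torch.optim.SGD(
        list(encoder.parameters()) + list(g_head.parameters()), lr=0.1)
    opt_head = torch.optim.SGD(p_head.parameters(), lr=0.1)
    bsm = BalancedSoftmaxLoss(torch.tensor([5.0, 3.0, 2.0]))

    X, y = torch.randn(10, 8), torch.randint(0, 3, (10,))

    # Step 1: generic branch, balanced softmax loss.
    rep = encoder(X)
    out_g = g_head(rep)
    opt_global.zero_grad()
    bsm(y, out_g).backward()
    opt_global.step()

    # Step 2: personalized head on the detached representation; the combined
    # output (detached generic logits + personal logits) is trained with CE.
    out_p = p_head(rep.detach())
    loss_p = F.cross_entropy(out_g.detach() + out_p, y)
    opt_head.zero_grad()
    loss_p.backward()
    opt_head.step()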


class FedROD(CentralizedFL):

    def get_client_class(self) -> type[Client]:
        return FedRODClient