# Source code for fluke.algorithms.fat

"""Implementation of the FAT [FAT20]_ algorithm.

References:
    .. [FAT20] Giulio Zizzo, Ambrish Rawat, Mathieu Sinn, Beat Buesser.
       FAT: Federated Adversarial Training. In SpicyFL@NeurIPS (2020).
       URL: https://arxiv.org/abs/2012.01791

"""
import sys
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Module

# Extend the module search path with the current and the parent directory.
# NOTE(review): the imports that follow are package-relative, so these tweaks look
# unnecessary -- confirm nothing relies on them before removing.
sys.path.append(".")
sys.path.append("..")

from ..client import Client  # NOQA
from ..data import FastDataLoader  # NOQA
from ..utils import OptimizerConfigurator, clear_cuda_cache  # NOQA
from . import CentralizedFL  # NOQA

# Names exported by ``from <this module> import *``: the FAT client and the algorithm class.
__all__ = [
    "FATClient",
    "FAT"
]


class FATClient(Client):
    """Client performing Federated Adversarial Training (FAT).

    During local training, adversarial examples are crafted with PGD for the first
    ``k = int(k_prop * batch_size)`` samples of every mini-batch and appended to the clean
    batch before the optimisation step.

    Args:
        index (int): Client identifier, forwarded to :class:`Client`.
        train_set (FastDataLoader): Local training data. Its labels (``tensors[1]``) are
            counted once at construction time to fill ``sample_per_class``.
        test_set (FastDataLoader): Local test data, forwarded to :class:`Client`.
        optimizer_cfg (OptimizerConfigurator): Optimizer/scheduler factory, forwarded to
            :class:`Client`.
        loss_fn (Module): Training loss, forwarded to :class:`Client`. :meth:`fit` reads it
            back from ``self.hyper_params.loss_fn``; the PGD attack itself always uses
            cross-entropy.
        local_epochs (int): Number of local epochs, forwarded to :class:`Client`.
        fine_tuning_epochs (int, optional): Forwarded to :class:`Client`. Defaults to 0.
        clipping (float, optional): Forwarded to :class:`Client`. Defaults to 0.
        eps (float, optional): Radius of the L-infinity ball around the clean input in which
            the adversarial example is kept. Defaults to 0.1.
        alpha (float, optional): PGD step size. Defaults to ``2.0 / 255``.
        adv_iters (int, optional): Number of PGD iterations. Defaults to 10.
        k_prop (float, optional): Fraction (in ``[0, 1]``) of the batch size for which
            adversarial examples are generated. Defaults to 1.0.
        **kwargs: Additional keyword arguments forwarded to :class:`Client`.
    """

    def __init__(self,
                 index: int,
                 train_set: FastDataLoader,
                 test_set: FastDataLoader,
                 optimizer_cfg: OptimizerConfigurator,
                 loss_fn: Module,
                 local_epochs: int,
                 fine_tuning_epochs: int = 0,
                 clipping: float = 0,
                 eps: float = 0.1,
                 alpha: float = 2.0 / 255,
                 adv_iters: int = 10,
                 k_prop: float = 1.0,
                 **kwargs: dict[str, Any]):
        assert 0.0 <= k_prop <= 1.0, "k_prop must be in [0, 1]"

        # Per-class sample counts of the local training set (zero for absent classes).
        self.sample_per_class = torch.zeros(train_set.num_labels)
        uniq_val, uniq_count = np.unique(train_set.tensors[1], return_counts=True)
        for label, count in zip(uniq_val.tolist(), uniq_count.tolist()):
            self.sample_per_class[label] = count

        super().__init__(index=index,
                         train_set=train_set,
                         test_set=test_set,
                         loss_fn=loss_fn,
                         optimizer_cfg=optimizer_cfg,
                         local_epochs=local_epochs,
                         fine_tuning_epochs=fine_tuning_epochs,
                         clipping=clipping,
                         **kwargs)
        self.hyper_params.update(eps=eps,
                                 alpha=alpha,
                                 adv_iters=adv_iters,
                                 k_prop=k_prop)

    def generate_adversarial(self,
                             model: Module,
                             inputs: torch.Tensor,
                             targets: torch.Tensor,
                             eps: float = 0.1,
                             alpha: float = 2.0 / 255,
                             iters: int = 10) -> torch.Tensor:
        """Generate adversarial examples using Projected Gradient Descent (PGD).

        Each iteration takes a signed-gradient ascent step of size ``alpha`` on the
        cross-entropy loss, projects the result back into the ``eps``-ball (L-infinity)
        around ``inputs`` and clamps it to ``[0, 1]``.

        Only the gradient with respect to the input is computed, so the ``.grad`` fields of
        the model parameters are left untouched.

        Args:
            model (nn.Module): The neural network model.
            inputs (Tensor): Input samples. The result is clamped to ``[0, 1]``, so inputs
                are presumably expected to live in that range.
            targets (Tensor): True labels for the inputs.
            eps (float): Maximum per-element distance from ``inputs``.
            alpha (float): Step size for each iteration.
            iters (int): Number of attack iterations. With ``iters <= 0`` a detached copy of
                ``inputs`` is returned.

        Returns:
            Tensor: Adversarial examples (detached, ``requires_grad`` is ``False``).
        """
        clean = inputs.detach()
        lower, upper = clean - eps, clean + eps

        adv_inputs = clean.clone()
        for _ in range(iters):
            adv_inputs.requires_grad_(True)
            loss = F.cross_entropy(model(adv_inputs), targets)
            # Differentiate w.r.t. the input only: same input gradient as a full backward
            # pass, without writing attack gradients into the model parameters.
            (grad,) = torch.autograd.grad(loss, adv_inputs)

            adv_inputs = adv_inputs.detach() + alpha * grad.sign()
            # Projection onto the eps-ball around the clean input, then onto [0, 1].
            adv_inputs = torch.min(torch.max(adv_inputs, lower), upper)
            adv_inputs = torch.clamp(adv_inputs, 0, 1).detach()

        return adv_inputs

    def fit(self, override_local_epochs: int = 0) -> float:
        """Run local adversarial training.

        For every mini-batch, adversarial versions of its first ``k`` samples are generated
        with :meth:`generate_adversarial` (using the ``eps``, ``alpha`` and ``adv_iters``
        hyper-parameters of this client) and concatenated to the clean batch; the model is
        then updated on the augmented batch.

        Args:
            override_local_epochs (int, optional): If greater than 0, number of epochs to
                run instead of ``hyper_params.local_epochs``. Defaults to 0.

        Returns:
            float: The training loss averaged over all processed batches.
        """
        hp = self.hyper_params
        epochs: int = (override_local_epochs if override_local_epochs > 0
                       else hp.local_epochs)

        self.model.train()
        self.model.to(self.device)

        if self.optimizer is None:
            self.optimizer, self.scheduler = self._optimizer_cfg(self.model)

        running_loss = 0.0
        # Number of leading samples of each batch that get an adversarial counterpart.
        k = int(hp.k_prop * self.train_set.batch_size)
        for _ in range(epochs):
            for X, y in self.train_set:
                X, y = X.to(self.device), y.to(self.device)

                # With k == 0 there is nothing to perturb; skipping also avoids a forward
                # pass on an empty batch.
                if k > 0:
                    X_adv = self.generate_adversarial(self.model,
                                                      X[:k],
                                                      y[:k],
                                                      eps=hp.eps,
                                                      alpha=hp.alpha,
                                                      iters=hp.adv_iters)
                    X = torch.cat((X, X_adv), dim=0)
                    y = torch.cat((y, y[:k]), dim=0)

                self.optimizer.zero_grad()
                y_hat = self.model(X)
                loss = hp.loss_fn(y_hat, y)
                loss.backward()
                self._clip_grads(self.model)
                self.optimizer.step()
                running_loss += loss.item()
            self.scheduler.step()

        running_loss /= (epochs * len(self.train_set))
        self.model.cpu()
        clear_cuda_cache()
        return running_loss
class FAT(CentralizedFL):
    """Federated Adversarial Training algorithm.

    Only the client class is overridden with respect to :class:`CentralizedFL`.
    """

    def get_client_class(self):
        """Return the client class used by this algorithm, i.e. :class:`FATClient`."""
        return FATClient