"""Implementation of the [FedOpt21]_ algorithm.
References:
.. [FedOpt21] Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. Adaptive Federated Optimization.
In ICLR (2021). URL: https://openreview.net/pdf?id=LkFG3lB13U5
"""
import sys
from collections import OrderedDict
from typing import Iterable

import torch
from torch.nn import Module

sys.path.append(".")
sys.path.append("..")

from ..algorithms import CentralizedFL  # NOQA
from ..client import Client  # NOQA
from ..data import FastDataLoader  # NOQA
from ..server import Server  # NOQA
from ..utils.model import STATE_DICT_KEYS_TO_IGNORE  # NOQA
__all__ = [
    "FedOptServer",
    "FedOpt",
]

class FedOptServer(Server):
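    r"""Server that updates the global model with a server-side adaptive
    optimizer (Adam, Yogi, or Adagrad), following [FedOpt21]_.

    At each round, the (weighted) average of the client drifts acts as a
    pseudo-gradient :math:`\Delta_t`, and the server applies

    .. math::

        m_t = \beta_1 m_{t-1} + (1 - \beta_1)\,\Delta_t, \qquad
        x_t = x_{t-1} + \eta\, m_t / (\sqrt{v_t} + \tau),

    where the second moment :math:`v_t` is updated from :math:`\Delta_t^2`
    according to ``mode``.
    """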
    def __init__(self,
                 model: Module,
                 test_set: FastDataLoader,
                 clients: Iterable[Client],
                 mode: str = "adam",
                 lr: float = 0.001,
                 beta1: float = 0.9,
                 beta2: float = 0.999,
                 tau: float = 0.0001,
                 weighted: bool = True):
        super().__init__(model=model, test_set=test_set, clients=clients, weighted=weighted)
        assert mode in {"adam", "yogi", "adagrad"}, \
            "'mode' must be one of {'adam', 'yogi', 'adagrad'}"
        assert 0 <= beta1 < 1, "beta1 must be in [0, 1)"
        assert 0 <= beta2 < 1, "beta2 must be in [0, 1)"
        # tau > 0 keeps the denominator of the server update away from zero.
        assert tau > 0, "tau must be positive"
        self.hyper_params.update(
            mode=mode,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            tau=tau
        )
        self._init_moments()
    def _init_moments(self) -> None:
        # First (m) and second (v) moment estimates of the server-side
        # pseudo-gradient, one entry per learnable tensor. Both start at
        # zero; the adaptivity parameter tau keeps the denominator of the
        # server update well-defined before v accumulates any mass.
        self.m = OrderedDict()
        self.v = OrderedDict()
        for key, param in self.model.state_dict().items():
            if "num_batches_tracked" not in key:
                self.m[key] = torch.zeros_like(param)
                self.v[key] = torch.zeros_like(param)
    @torch.no_grad()
    def aggregate(self, eligible: Iterable[Client], client_models: Iterable[Module]) -> None:
        model_sd = self.model.state_dict()
        avg_model_sd = OrderedDict()
        clients_sd = [c.state_dict() for c in client_models]
        del client_models
        eligible = list(eligible)

        for key in model_sd.keys():
            if key.endswith(STATE_DICT_KEYS_TO_IGNORE):
                avg_model_sd[key] = model_sd[key].clone()
                continue

            if key.endswith("num_batches_tracked"):
                # Keep the larger of the current counter and the clients' mean.
                mean_nbt = torch.mean(torch.Tensor([c[key] for c in clients_sd])).long()
                avg_model_sd[key] = max(model_sd[key], mean_nbt)
                continue

            # Pseudo-gradient Delta_t: (weighted) average of the client drifts
            # with respect to the current global model.
            den, diff = 0, 0
            for i, client_sd in enumerate(clients_sd):
                weight = eligible[i].n_examples if self.hyper_params.weighted else 1
                diff += weight * (client_sd[key] - model_sd[key])
                den += weight
            diff /= den

            # First moment: m_t = beta1 * m_{t-1} + (1 - beta1) * Delta_t.
            self.m[key] = self.hyper_params.beta1 * self.m[key] + \
                (1 - self.hyper_params.beta1) * diff

            # Second moment: mode-specific update driven by Delta_t^2.
            diff_2 = diff ** 2
            if self.hyper_params.mode == "adam":
                self.v[key] = self.hyper_params.beta2 * self.v[key] + \
                    (1 - self.hyper_params.beta2) * diff_2
            elif self.hyper_params.mode == "yogi":
                self.v[key] -= (1 - self.hyper_params.beta2) * \
                    diff_2 * torch.sign(self.v[key] - diff_2)
            else:  # adagrad
                self.v[key] += diff_2

            # Server step: x_t = x_{t-1} + lr * m_t / (sqrt(v_t) + tau).
            update = self.hyper_params.lr * self.m[key] / \
                (torch.sqrt(self.v[key]) + self.hyper_params.tau)
            avg_model_sd[key] = model_sd[key] + update

        self.model.load_state_dict(avg_model_sd)
class FedOpt(CentralizedFL):

    def get_server_class(self) -> type[Server]:
        return FedOptServer
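

# Minimal, self-contained sketch (not part of the library API): a one-tensor
# check of the Adam-mode server step implemented in FedOptServer.aggregate.
# All values below are made-up stand-ins; hyper-parameters mirror the defaults.
if __name__ == "__main__":
    lr, beta1, beta2, tau = 0.001, 0.9, 0.999, 0.0001
    x = torch.zeros(3)                                  # global parameter x_{t-1}
    drifts = [torch.tensor([0.2, -0.1, 0.0]),           # hypothetical client drifts
              torch.tensor([0.4, 0.1, -0.2])]
    n_examples = [10, 30]                               # per-client weights

    # Pseudo-gradient Delta_t: weighted average of the drifts.
    delta = sum(w * d for w, d in zip(n_examples, drifts)) / sum(n_examples)

    # Adam-style moments starting from zero, as in _init_moments.
    m = (1 - beta1) * delta
    v = (1 - beta2) * delta ** 2

    # Server step: x_t = x_{t-1} + lr * m_t / (sqrt(v_t) + tau).
    print(x + lr * m / (torch.sqrt(v) + tau))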