r""" This file contains an implementation of the following paper:
Title: "Minimizing Client Drift in Federated Learning via Adaptive Bias Estimation"
Authors: Farshid Varno, Marzie Saghayi, Laya Rafiee, Sharut Gupta, Stan Matwin, Mohammad Havaei
Publication date: [Submitted on 27 Apr 2022 (v1), last revised 23 May 2022 (this version, v2)]
Link: https://arxiv.org/abs/2204.13170
"""
from torch.nn.utils import parameters_to_vector
from functools import partial
import torch
from . import fedavg
from ..utils import vector_to_parameters_like, default_closure
from ..aggregators import SerialAggregator
class AdaBest(fedavg.FedAvg):
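    """AdaBest federated learning algorithm.

    Extends FedAvg with an adaptive estimate of each client's bias: clients
    penalize drift against a locally stored estimate ``h`` scaled by ``mu``,
    and the server corrects the aggregated parameters with a bias estimate
    scaled by ``beta``.
    """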
def __init__(
self,
data_manager,
metric_logger,
num_clients,
sample_scheme,
sample_rate,
model_class,
epochs,
loss_fn,
batch_size=32,
test_batch_size=64,
local_weight_decay=0.,
slr=1.,
clr=0.1,
clr_decay=1.,
clr_decay_type='step',
min_clr=1e-12,
clr_step_size=1000,
device='cuda',
log_freq=10,
mu=0.02,
beta=0.98,
*args,
**kwargs,
):
self.mu = mu
self.beta = beta
        # track a running average of client sample counts rather than reading
        # oracle information (such as the total number of clients), which
        # methods like FedDyn and SCAFFOLD rely on
        self.general_agg = SerialAggregator()
super(AdaBest, self).__init__(
data_manager,
metric_logger,
num_clients,
sample_scheme,
sample_rate,
model_class,
epochs,
loss_fn,
batch_size,
test_batch_size,
local_weight_decay,
slr,
clr,
clr_decay,
clr_decay_type,
min_clr,
clr_step_size,
device,
log_freq,
)
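        # server state: keep the previous round's average of client
        # parameters; ``optimize`` uses it to estimate the server-side bias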
cloud_params = self.read_server('cloud_params')
self.write_server('avg_params', cloud_params.detach().clone())
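        # per-client state: the bias estimate ``h`` (initially zero) and the
        # last round the client participated in (-1 means not yet sampled)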
for client_id in range(num_clients):
self.write_client(client_id, 'h', torch.zeros_like(cloud_params))
self.write_client(client_id, 'last_round', -1)
self.write_server('average_sample', 0)
    def send_to_server(
self,
client_id,
datasets,
epochs,
loss_fn,
batch_size,
lr,
weight_decay=0,
device='cuda',
ctx=None,
*args,
**kwargs,
):
model = ctx['model']
params_init = parameters_to_vector(model.parameters()).detach().clone()
h = self.read_client(client_id, 'h')
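        # adapt the penalty coefficient: scale mu by the ratio of the running
        # average sample count to this client's sample count, so smaller
        # clients receive a proportionally stronger penalty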
        mu_adaptive = (
            self.mu / len(datasets['train']) * self.read_server('average_sample')
        )
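        # client-side drift correction: at every local step, subtract the
        # scaled bias estimate from the gradients
        # (grad <- grad - mu_adaptive * h)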
def transform_grads_fn(model):
grad_additive = -h
grad_additive_list = vector_to_parameters_like(
mu_adaptive * grad_additive, model.parameters())
for p, g_a in zip(model.parameters(), grad_additive_list):
p.grad += g_a
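        # wrap the default closure so the base class applies the gradient
        # correction on every optimization step of local training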
step_closure_ = partial(default_closure,
transform_grads=transform_grads_fn)
opt_res = super(AdaBest, self).send_to_server(
client_id,
datasets,
epochs,
loss_fn,
batch_size,
lr,
weight_decay,
device,
ctx,
step_closure=step_closure_,
*args,
**kwargs,
)
        # update the local bias estimate: decay the stale estimate by the
        # number of rounds since this client last participated, then add the
        # new pseudo-gradient (initial minus final local parameters)
        pseudo_grads = (
            params_init - parameters_to_vector(model.parameters()).detach().clone()
        )
        t = self.rounds
        new_h = (
            1 / (t - self.read_client(client_id, 'last_round')) * h + pseudo_grads
        )
self.write_client(client_id, 'h', new_h)
self.write_client(client_id, 'last_round', self.rounds)
return opt_res
    def receive_from_client(self, client_id, client_msg, aggregation_results):
weight = 1
self.agg(client_id, client_msg, aggregation_results, weight=weight)
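        # update the running average of per-client dataset sizes; num_samples
        # is accumulated over local epochs, hence the division by self.epochs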
        self.general_agg.add(
            'avg_m', client_msg['num_samples'] / self.epochs, 1)
self.write_server('average_sample', self.general_agg.get('avg_m'))
    def optimize(self, aggregator):
if 'local_params' in aggregator:
param_avg = aggregator.pop('local_params')
optimizer = self.read_server('optimizer')
cloud_params = self.read_server('cloud_params')
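            # server-side bias estimate: beta times the difference between
            # the previous and current rounds' averaged client parameters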
h = self.beta * (self.read_server('avg_params') - param_avg)
new_params = param_avg - h
modified_pseudo_grads = cloud_params.data - new_params
# update cloud params
optimizer.zero_grad()
cloud_params.grad = modified_pseudo_grads
optimizer.step()
self.write_server('avg_params', param_avg.detach().clone())
return aggregator.pop_all()
    def deploy(self):
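        # expose both the optimized cloud parameters and the raw average of
        # client parameters for evaluation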
return dict(
cloud=self.read_server('cloud_params'),
avg=self.read_server('avg_params'),
)