Source code for fedsim.fl.fl_algorithm

from tqdm import trange
import random
import math
import torch
from torch import nn
from torch.utils.data import DataLoader
from typing import Any, Dict, Hashable, Iterable, Mapping, Optional, Union
import logging
import inspect
from pprint import pformat

from .aggregators import SerialAggregator
from fedsim.utils import search_in_submodules


class FLAlgorithm(object):

    def __init__(
        self,
        data_manager,
        metric_logger,
        num_clients,
        sample_scheme,
        sample_rate,
        model_class,
        epochs,
        loss_fn,
        batch_size,
        test_batch_size,
        local_weight_decay,
        slr,
        clr,
        clr_decay,
        clr_decay_type,
        min_clr,
        clr_step_size,
        device,
        log_freq,
        *args,
        **kwargs,
    ):
        grandpa = inspect.getmro(self.__class__)[-2]
        cls = self.__class__

        grandpa_args = set(inspect.signature(grandpa).parameters.keys())
        self_args = set(inspect.signature(cls).parameters.keys())
        added_args = self_args - grandpa_args
        added_args_dict = {key: getattr(self, key) for key in added_args}
        if len(added_args_dict) > 0:
            logging.info('algorithm arguments: ' + pformat(added_args_dict))

        self._data_manager = data_manager
        self.num_clients = num_clients
        self.sample_scheme = sample_scheme
        self.sample_count = int(sample_rate * num_clients)
        if not 1 <= self.sample_count <= num_clients:
            raise Exception(
                'invalid client sample size for {}% of {} clients'.format(
                    sample_rate, num_clients))

        if isinstance(model_class, str):
            self.model_class = search_in_submodules('fedsim.models',
                                                    model_class)
        elif issubclass(model_class, nn.Module):
            self.model_class = model_class
        else:
            raise Exception("incompatible model!")

        self.epochs = epochs
        if loss_fn == 'ce':
            self.loss_fn = nn.CrossEntropyLoss()
        elif loss_fn == 'mse':
            self.loss_fn = nn.MSELoss()
        else:
            raise NotImplementedError

        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.local_weight_decay = local_weight_decay
        self.slr = slr
        self.clr = clr
        self.clr_decay = clr_decay
        self.clr_decay_type = clr_decay_type
        self.min_clr = min_clr
        self.clr_step_size = clr_step_size

        self.metric_logger = metric_logger
        self.device = device
        self.log_freq = log_freq

        self._server_memory: Dict[Hashable, object] = dict()
        self._client_memory: Dict[int, Dict[Hashable, object]] = {
            k: dict()
            for k in range(num_clients)
        }

        self.global_dataloaders = {
            key: DataLoader(
                dataset,
                batch_size=self.test_batch_size,
                pin_memory=True,
            )
            for key, dataset in
            self._data_manager.get_global_dataset().items()
        }

        self.oracle_dataset = self._data_manager.get_oracle_dataset()
        self.rounds = 0

        # for internal use only
        self._last_client_sampled: Optional[int] = None
    def write_server(self, key, obj):
        self._server_memory[key] = obj
    def write_client(self, client_id, key, obj):
        self._client_memory[client_id][key] = obj
    def read_server(self, key):
        return self._server_memory[key] if key in self._server_memory else None
    def read_client(self, client_id, key):
        if client_id >= self.num_clients:
            raise Exception("invalid client id {} >= {}".format(
                client_id, self.num_clients))
        if key in self._client_memory[client_id]:
            return self._client_memory[client_id][key]
        return None
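    # Minimal usage sketch of the memory helpers above, called from inside a
    # child algorithm's methods. The keys ('cloud_params', 'local_steps') are
    # hypothetical and only for illustration:
    #
    #     self.write_server('cloud_params', params)
    #     params = self.read_server('cloud_params')       # None if not stored
    #     self.write_client(client_id, 'local_steps', 10)
    #     steps = self.read_client(client_id, 'local_steps')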
    def _sample_clients(self):
        if self.sample_scheme == 'uniform':
            clients = random.sample(range(self.num_clients),
                                    self.sample_count)
        elif self.sample_scheme == 'sequential':
            last_sampled = -1 if self._last_client_sampled is None \
                else self._last_client_sampled
            clients = [
                (i + 1) % self.num_clients
                for i in range(last_sampled, last_sampled + self.sample_count)
            ]
            self._last_client_sampled = clients[-1]
        else:
            raise NotImplementedError
        return clients

    def _send_to_client(self, client_id):
        return self.send_to_client(client_id=client_id)

    def _send_to_server(self, client_id):
        if self.clr_decay_type == 'step':
            decayed_clr = self.clr * \
                (self.clr_decay ** (self.rounds // self.clr_step_size))
        elif self.clr_decay_type == 'cosine':
            T_i = self.clr_step_size
            T_cur = self.rounds % T_i
            decayed_clr = self.min_clr + 0.5 * \
                (self.clr - self.min_clr) * \
                (1 + math.cos(math.pi * T_cur / T_i))
        else:
            raise NotImplementedError

        client_ctx = self.send_to_server(
            client_id,
            self._data_manager.get_local_dataset(client_id),
            self.epochs,
            self.loss_fn,
            self.batch_size,
            decayed_clr,
            self.local_weight_decay,
            self.device,
            ctx=self._send_to_client(client_id))
        if not isinstance(client_ctx, dict):
            raise Exception('client should only return a dict!')
        return {**client_ctx, 'client_id': client_id}

    def _receive_from_client(self, client_msg, aggregator):
        client_id = client_msg.pop('client_id')
        return self.receive_from_client(client_id, client_msg, aggregator)

    def _optimize(self, aggregator):
        reports = self.optimize(aggregator)
        # purge aggregated results
        del aggregator
        return reports

    def _report(self, optimize_reports=None, deployment_points=None):
        self.report(self.global_dataloaders, self.metric_logger, self.device,
                    optimize_reports, deployment_points)

    def _train(self, rounds):
        for self.rounds in trange(rounds):
            aggregator = SerialAggregator()
            for client_id in self._sample_clients():
                client_msg = self._send_to_server(client_id)
                self._receive_from_client(client_msg, aggregator)
            opt_reports = self._optimize(aggregator)
            if self.rounds % self.log_freq == 0:
                deploy_point = self.deploy()
                self._report(opt_reports, deploy_point)
        # one last report
        if self.rounds % self.log_freq > 0:
            deploy_point = self.deploy()
            self._report(opt_reports, deploy_point)
        return 0
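    # Worked example of the client learning-rate schedules computed in
    # _send_to_server (the numbers are illustrative, not library defaults).
    # With clr=0.1, clr_decay=0.5, clr_step_size=10, at round 25:
    #
    #     step decay:   0.1 * 0.5 ** (25 // 10) = 0.1 * 0.25 = 0.025
    #     cosine decay (min_clr=0.0):
    #                   0.0 + 0.5 * (0.1 - 0.0) * (1 + cos(pi * (25 % 10) / 10))
    #                   = 0.05 * (1 + cos(pi / 2)) = 0.05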
    def train(self, rounds):
        return self._train(rounds=rounds)
    def get_model_class(self):
        return self.model_class
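    # End-to-end usage sketch. ``MyAlgorithm`` is a hypothetical concrete
    # subclass, ``MyModel`` a hypothetical nn.Module subclass, and ``dm`` /
    # ``logger`` stand for a compatible data manager and metric logger; only
    # the constructor arguments themselves come from __init__ above:
    #
    #     alg = MyAlgorithm(
    #         data_manager=dm, metric_logger=logger, num_clients=100,
    #         sample_scheme='uniform', sample_rate=0.1, model_class=MyModel,
    #         epochs=1, loss_fn='ce', batch_size=32, test_batch_size=64,
    #         local_weight_decay=0.0, slr=1.0, clr=0.05, clr_decay=1.0,
    #         clr_decay_type='step', min_clr=1e-6, clr_step_size=1,
    #         device='cuda', log_freq=10,
    #     )
    #     alg.train(rounds=100)   # sample clients -> local step -> optimize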
    # we do not use type hints elsewhere in this class; however, hints are
    # provided for the abstract methods below to clarify the expected
    # interface for users
    def send_to_client(self, client_id: int) -> Mapping[Hashable, Any]:
        """ returns the context to send to the client corresponding to
        client_id. Do not send shared objects (e.g., the server model)
        without deep copying them first.

        Args:
            client_id (int): id of the receiving client

        Raises:
            NotImplementedError: abstract method to be implemented by child

        Returns:
            Mapping[Hashable, Any]: the context to be sent in form of a
                Mapping
        """
        raise NotImplementedError
    def send_to_server(
        self,
        client_id: int,
        datasets: Dict[str, Iterable],
        epochs: int,
        loss_fn: nn.Module,
        batch_size: int,
        lr: float,
        weight_decay: float = 0,
        device: Union[int, str] = 'cuda',
        ctx: Optional[Dict[Hashable, Any]] = None,
        *args,
        **kwargs,
    ) -> Mapping[str, Any]:
        """ client operation on the received information.

        Args:
            client_id (int): id of the client
            datasets (Dict[str, Iterable]): the client's local datasets,
                provided by the data manager
            epochs (int): number of local epochs to train
            loss_fn (nn.Module): instantiated loss module (cross-entropy or
                MSE, depending on the ``loss_fn`` constructor argument)
            batch_size (int): training batch_size
            lr (float): client learning rate
            weight_decay (float, optional): weight decay for SGD.
                Defaults to 0.
            device (Union[int, str], optional): Defaults to 'cuda'.
            ctx (Optional[Dict[Hashable, Any]], optional): context received
                from the server. Defaults to None.

        Raises:
            NotImplementedError: abstract method to be implemented by child

        Returns:
            Mapping[str, Any]: client context to be sent to the server
        """
        raise NotImplementedError
    def receive_from_client(self, client_id: int,
                            client_msg: Mapping[Hashable, Any],
                            aggregator: Any):
        """ receive and aggregate info from selected clients

        Args:
            client_id (int): id of the sender (client)
            client_msg (Mapping[Hashable, Any]): the context sent by the
                client
            aggregator (Any): aggregator instance to collect info

        Raises:
            NotImplementedError: abstract method to be implemented by child
        """
        raise NotImplementedError
    def optimize(self, aggregator: Any) -> Mapping[Hashable, Any]:
        """ optimize server model(s) and return metrics to be reported

        Args:
            aggregator (Any): Aggregator instance

        Raises:
            NotImplementedError: abstract method to be implemented by child

        Returns:
            Mapping[Hashable, Any]: context to be reported
        """
        raise NotImplementedError
    def deploy(self) -> Optional[Mapping[Hashable, Any]]:
        """ return a Mapping of name -> parameter set on which to test the
        model

        Raises:
            NotImplementedError: abstract method to be implemented by child
        """
        raise NotImplementedError
    def report(
        self,
        dataloaders: Dict[str, Any],
        metric_logger: Any,
        device: str,
        optimize_reports: Mapping[Hashable, Any],
        deployment_points: Optional[Mapping[Hashable, torch.Tensor]] = None,
    ) -> None:
        """ test on global data and report info

        Args:
            dataloaders (Dict[str, Any]): dict of data loaders to test the
                global model(s)
            metric_logger (Any): the logging object (e.g., SummaryWriter)
            device (str): 'cuda', 'cpu' or gpu number
            optimize_reports (Mapping[Hashable, Any]): dict returned by the
                optimize method
            deployment_points (Mapping[Hashable, torch.Tensor], optional):
                output of the deploy method

        Raises:
            NotImplementedError: abstract method to be implemented by child
        """
        raise NotImplementedError
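# What follows is a hedged, minimal FedAvg-style subclass sketch that ties the
# abstract methods above together. It is illustrative only: the aggregator is
# assumed to expose ``add(key, value, weight)`` and ``pop(key)`` returning a
# weight-averaged value (check fedsim.fl.aggregators.SerialAggregator for the
# actual interface); the split names ('train', 'test'), the memory key
# 'cloud_model', the no-argument model constructor, and the SummaryWriter-like
# ``add_scalar`` call are assumptions made for this example, not fedsim's API.
from copy import deepcopy

from torch.nn.utils import parameters_to_vector, vector_to_parameters


class _FedAvgSketch(FLAlgorithm):

    def __init__(self, *args, **kwargs):
        super(_FedAvgSketch, self).__init__(*args, **kwargs)
        # build the cloud model once and keep it in server memory so every
        # round can read and update it
        self.write_server('cloud_model',
                          self.get_model_class()().to(self.device))

    def send_to_client(self, client_id):
        # deep copy so local training cannot mutate the server-side copy
        return dict(model=deepcopy(self.read_server('cloud_model')))

    def send_to_server(self, client_id, datasets, epochs, loss_fn, batch_size,
                       lr, weight_decay=0, device='cuda', ctx=None,
                       *args, **kwargs):
        # plain local SGD on the client's 'train' split
        model = ctx['model'].to(device)
        opt = torch.optim.SGD(model.parameters(), lr=lr,
                              weight_decay=weight_decay)
        loader = DataLoader(datasets['train'], batch_size=batch_size,
                            shuffle=True)
        for _ in range(epochs):
            for x, y in loader:
                opt.zero_grad()
                loss_fn(model(x.to(device)), y.to(device)).backward()
                opt.step()
        flat_params = parameters_to_vector(model.parameters()).detach().cpu()
        return dict(params=flat_params, num_samples=len(datasets['train']))

    def receive_from_client(self, client_id, client_msg, aggregator):
        # weight each client's flattened parameters by its sample count
        aggregator.add('params', client_msg['params'],
                       client_msg['num_samples'])

    def optimize(self, aggregator):
        # load the weighted average of client parameters into the cloud model
        avg_params = aggregator.pop('params')
        vector_to_parameters(avg_params.to(self.device),
                             self.read_server('cloud_model').parameters())
        return dict()

    def deploy(self):
        return dict(avg=self.read_server('cloud_model'))

    def report(self, dataloaders, metric_logger, device, optimize_reports,
               deployment_points=None):
        # evaluate the deployed model on the global 'test' loader and log it
        model = deployment_points['avg'].to(device)
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in dataloaders['test']:
                pred = model(x.to(device)).argmax(dim=1).cpu()
                correct += (pred == y).sum().item()
                total += y.numel()
        metric_logger.add_scalar('test/accuracy', correct / total, self.rounds)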