Source code for fedsim.fl.utils

from functools import partial
from inspect import signature
import torch
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.convert_parameters import _check_param_device


def get_metric_scores(metric_fn_dict, y_true, y_pred):
    answer = {}
    if metric_fn_dict is None:
        return answer
    for name, fn in metric_fn_dict.items():
        args = dict()
        # Metric callables may follow the sklearn naming (y_true/y_pred) or
        # the torch naming (target/input); dispatch by inspecting the signature.
        if 'y_true' in signature(fn).parameters:
            args['y_true'] = y_true
        elif 'target' in signature(fn).parameters:
            args['target'] = y_true
        else:
            raise NotImplementedError
        if 'y_pred' in signature(fn).parameters:
            args['y_pred'] = y_pred
        elif 'input' in signature(fn).parameters:
            args['input'] = y_pred
        else:
            raise NotImplementedError
        answer[name] = fn(**args)
    return answer
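
# Illustrative usage sketch (not part of fedsim's API): because
# get_metric_scores dispatches on parameter names, any callable whose
# signature uses y_true/y_pred (sklearn style) or target/input (torch style)
# works. The accuracy metric below is a hypothetical stand-in.
def _example_get_metric_scores():
    def accuracy(y_true, y_pred):
        correct = sum(t == p for t, p in zip(y_true, y_pred))
        return correct / len(y_true)

    scores = get_metric_scores({'accuracy': accuracy},
                               y_true=[0, 1, 1, 0],
                               y_pred=[0, 1, 0, 0])
    assert scores == {'accuracy': 0.75}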

def default_closure(x, y, model, loss_fn, optimizer, metric_fn_dict,
                    max_grad_norm=1000, link_fn=partial(torch.argmax, dim=1),
                    device='cpu', transform_grads=None, transform_y=None,
                    **kwargs):
    if transform_y is not None:
        y = transform_y(y)
    y_true = y.tolist()
    x = x.to(device)
    y = y.reshape(-1).long()
    y = y.to(device)
    model.train()
    outputs = model(x)
    loss = loss_fn(outputs, y)
    if loss.isnan() or loss.isinf():
        return loss
    # backpropagation
    loss.backward()
    if transform_grads is not None:
        transform_grads(model)
    # Clip gradients
    clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
    # optimize
    optimizer.step()
    optimizer.zero_grad()
    y_pred = link_fn(outputs).tolist()
    metrics = get_metric_scores(metric_fn_dict, y_true, y_pred)
    return loss, metrics
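
# Illustrative usage sketch (not part of fedsim's API): one local training
# step with default_closure on a toy linear classifier. The model, data, and
# optimizer here are hypothetical placeholders. Note that if the loss is
# nan/inf, default_closure returns only the loss, so the unpacking below
# assumes a finite loss.
def _example_default_closure():
    model = torch.nn.Linear(4, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(8, 4)
    y = torch.randint(0, 3, (8,))
    loss, metrics = default_closure(
        x, y, model,
        loss_fn=torch.nn.CrossEntropyLoss(),
        optimizer=optimizer,
        metric_fn_dict=None,
    )
    return loss, metrics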

def vector_to_parameters_like(vec, parameters_like):
    r"""Convert one vector to new parameters like the ones provided.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        parameters_like (Iterable[Tensor]): an iterator of Tensors that are
            the parameters of a model. This is only used to get the sizes;
            new parameters are defined.
    """
    # Ensure vec is of type Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError('expected torch.Tensor, but got: {}'.format(
            torch.typename(vec)))
    # Flag for the device where the parameters are located
    param_device = None
    # Pointer for slicing the vector for each parameter
    pointer = 0
    new_params = []
    for param in parameters_like:
        # Ensure the parameters are located on the same device
        param_device = _check_param_device(param, param_device)
        # The length of the parameter
        num_param = param.numel()
        # Slice the vector and reshape it like the reference parameter
        new_params.append(vec[pointer:pointer + num_param].view_as(param).data)
        # Increment the pointer
        pointer += num_param
    return new_params
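
# Illustrative usage sketch (not part of fedsim's API): round-trip a model's
# parameters through a flat vector, as is common when exchanging models in
# federated aggregation. parameters_to_vector is the torch counterpart of
# this helper; the model below is a hypothetical placeholder.
def _example_vector_to_parameters_like():
    from torch.nn.utils import parameters_to_vector

    model = torch.nn.Linear(4, 2)
    flat = parameters_to_vector(model.parameters())
    new_params = vector_to_parameters_like(flat * 0.5, model.parameters())
    # new_params holds rescaled tensors shaped like the model's weight and
    # bias; the model itself is left untouched.
    return new_params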

class ModelReconstructor(torch.nn.Module):
    def __init__(self, feature_extractor, classifier,
                 connection_fn=None) -> None:
        super(ModelReconstructor, self).__init__()
        self.feature_extractor = feature_extractor
        self.classifier = classifier
        self.connection_fn = connection_fn

    def forward(self, input):
        features = self.feature_extractor(input)
        # Optionally transform the extracted features (e.g. flatten them)
        # before feeding them to the classifier head.
        if self.connection_fn is not None:
            features = self.connection_fn(features)
        return self.classifier(features)
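
# Illustrative usage sketch (not part of fedsim's API): stitch a feature
# extractor and a classifier head back into one module, flattening the
# features in between via connection_fn. The layer sizes are hypothetical.
def _example_model_reconstructor():
    feature_extractor = torch.nn.Sequential(
        torch.nn.Conv2d(1, 8, kernel_size=3),
        torch.nn.ReLU(),
    )
    classifier = torch.nn.Linear(8 * 26 * 26, 10)
    model = ModelReconstructor(
        feature_extractor,
        classifier,
        connection_fn=lambda t: t.flatten(start_dim=1),
    )
    logits = model(torch.randn(2, 1, 28, 28))  # -> shape (2, 10)
    return logits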