Source code for fedsim.fl.aggregators
class SerialAggregator:
    """Accumulate weighted or unweighted values under arbitrary keys.

    For every key the aggregator stores a pair ``(weighted_sum, weight)``.
    Each call to :meth:`add` adds ``value * weight`` to the running sum and
    ``weight`` to the running weight; :meth:`get` returns their ratio, i.e.
    the weighted mean of everything added so far.

    Passing ``weight=None`` adds ``value`` to the running sum as-is and
    contributes nothing to the running weight. A key that has only ever
    received ``None`` weights keeps a weight of ``None`` and :meth:`get`
    returns the plain sum for it.
    """

    def __init__(self) -> None:
        # Maps key -> (running weighted sum, running weight or None).
        self._members = {}

    def _get_pair(self, value, weight):
        """Return the ``(contribution, weight)`` pair for one observation.

        Args:
            value: the observed value.
            weight: its weight, or ``None`` for an unweighted observation.

        Returns:
            ``(value, None)`` when ``weight`` is ``None``, otherwise
            ``(value * weight, weight)``.
        """
        if weight is None:
            return value, None
        return value * weight, weight

    def _lookup(self, key):
        """Return the stored ``(sum, weight)`` pair for ``key``.

        Raises:
            Exception: if ``key`` has never been added (or was popped).
        """
        if key not in self._members:
            raise Exception("{} is not in the aggregator".format(key))
        return self._members[key]

    @staticmethod
    def _average(total, weight):
        """Return ``total / weight``, or ``total`` itself when there is no
        usable weight (``None`` or zero), which also avoids dividing by zero.
        """
        if weight is None or weight == 0:
            return total
        return total / weight

    def add(self, key, value, weight=0):
        """Fold one observation into the running aggregate of ``key``.

        Args:
            key: identifier of the aggregate; created on first use.
            value: the value to add.
            weight: weight of ``value``. With the default of ``0`` the
                contribution is ``value * 0``, so the running sum and weight
                are left numerically unchanged. ``None`` adds ``value``
                unweighted (see the class docstring).
        """
        new_v, new_w = self._get_pair(value, weight)
        # A brand-new key starts with weight None for unweighted input and
        # 0 otherwise, so numeric weights accumulate exactly as ``0 + w``.
        initial_w = None if new_w is None else 0
        sum_v, cur_w = self._members.get(key, (0, initial_w))

        # ``None`` cannot take part in arithmetic; treat it as "no weight
        # contributed" on whichever side it appears.
        if new_w is None:
            total_w = cur_w
        elif cur_w is None:
            total_w = new_w
        else:
            total_w = cur_w + new_w

        self._members[key] = (sum_v + new_v, total_w)

    def get(self, key):
        """Return the aggregate of ``key`` without removing it.

        Returns:
            The running sum divided by the running weight, or the running
            sum itself when the weight is ``None`` or zero.

        Raises:
            Exception: if ``key`` is not in the aggregator.
        """
        total, weight = self._lookup(key)
        return self._average(total, weight)

    def get_sum(self, key):
        """Return the running (weighted) sum stored for ``key``.

        Raises:
            Exception: if ``key`` is not in the aggregator.
        """
        total, _ = self._lookup(key)
        return total

    def get_weight(self, key):
        """Return the running weight stored for ``key`` (may be ``None``).

        Raises:
            Exception: if ``key`` is not in the aggregator.
        """
        _, weight = self._lookup(key)
        return weight

    def pop(self, key):
        """Remove ``key`` and return its aggregate.

        The returned value is computed exactly as in :meth:`get`.

        Raises:
            Exception: if ``key`` is not in the aggregator.
        """
        total, weight = self._lookup(key)
        del self._members[key]
        return self._average(total, weight)

    def items(self):
        """Yield ``(key, aggregate)`` pairs for every stored key.

        Each aggregate is obtained through :meth:`get`.
        """
        for key in self._members:
            yield key, self.get(key)

    def pop_all(self):
        """Remove every key and return a dict mapping key -> aggregate."""
        # Snapshot the keys first: pop() mutates the dict being walked.
        return {key: self.pop(key) for key in list(self._members)}

    def __contains__(self, key):
        """Return ``True`` if ``key`` currently has a stored aggregate."""
        return key in self._members