#!/usr/bin/env python3
import flwr as fl
from opacus import PrivacyEngine
from torch.utils.data import DataLoader
from .model import Model
from .utils import params
import torch
class FlowerClient(fl.client.NumPyClient):
    """
    FlowerClient class for creating MEDfl clients.

    Wraps a MEDfl ``Model`` so Flower can drive local training and evaluation
    for one participant of the federated network, optionally through an
    Opacus ``PrivacyEngine`` (differential privacy).

    Attributes:
        cid (str): Client ID.
        local_model (Model): Local model of the federated learning network.
        trainloader (DataLoader): DataLoader for training data. When
            differential privacy is enabled it is replaced by the loader
            returned from ``PrivacyEngine.make_private_with_epsilon``.
        valloader (DataLoader): DataLoader for validation data.
        diff_priv (bool): Flag indicating whether to use differential privacy.
        device (torch.device): Device the local model was moved to.
        privacy_engine (PrivacyEngine): Opacus engine handed to training.
        epsilons (list): Epsilon value recorded at the end of each ``fit``.
        accuracies (list): Accuracy recorded by each ``evaluate``.
        losses (list): Loss recorded by each ``evaluate``.
    """

    def __init__(
        self,
        cid: str,
        local_model: Model,
        trainloader: DataLoader,
        valloader: DataLoader,
        diff_priv: bool = params["diff_privacy"],
    ):
        """
        Initializes the FlowerClient instance.

        Args:
            cid (str): Client ID. On CUDA machines a numeric ID also selects
                the GPU as ``int(cid) % torch.cuda.device_count()``.
            local_model (Model): Local model of the federated learning network.
            trainloader (DataLoader): DataLoader for training data.
            valloader (DataLoader): DataLoader for validation data.
            diff_priv (bool): Flag indicating whether to use differential
                privacy. The default is ``params["diff_privacy"]`` as it was
                when this module was imported.

        Raises:
            TypeError: If any argument has the wrong type (see ``validate``).
        """
        self.cid = cid
        self.local_model = local_model
        self.trainloader = trainloader
        self.valloader = valloader
        self.diff_priv = diff_priv
        # Validate before anything dereferences the arguments, so bad inputs
        # surface as the documented TypeError instead of an AttributeError or
        # ValueError from the setup code below.
        self.validate()

        self.device = self._select_device(self.cid)
        self.local_model.model.to(self.device)
        self.privacy_engine = PrivacyEngine(secure_mode=False)
        self.epsilons = []
        self.accuracies = []
        self.losses = []

        if self.diff_priv:
            # Opacus returns replacements for the module, optimizer and data
            # loader; install them on the wrapped Model so training uses them.
            model, optimizer, self.trainloader = self.privacy_engine.make_private_with_epsilon(
                module=self.local_model.model.train(),
                optimizer=self.local_model.optimizer,
                data_loader=self.trainloader,
                epochs=params["train_epochs"],
                target_epsilon=params["EPSILON"],
                target_delta=params["DELTA"],
                max_grad_norm=params["MAX_GRAD_NORM"],
            )
            self.local_model.model = model
            self.local_model.optimizer = optimizer

    @staticmethod
    def _select_device(cid: str) -> torch.device:
        """Return the torch device to use for client ``cid``.

        Numeric client IDs are spread round-robin over the visible GPUs;
        non-numeric IDs use GPU 0. Falls back to CPU without CUDA.
        """
        if not torch.cuda.is_available():
            return torch.device("cpu")
        # device_count() >= 1 whenever CUDA is available, so the modulo is safe.
        try:
            index = int(cid) % torch.cuda.device_count()
        except ValueError:
            index = 0
        return torch.device(f"cuda:{index}")

    @staticmethod
    def _num_examples(loader: DataLoader) -> int:
        """Return the number of samples (not batches) in ``loader``'s dataset."""
        return len(loader.dataset)

    def validate(self):
        """Validates cid, local_model, trainloader, valloader and diff_priv.

        Raises:
            TypeError: If an attribute does not have the expected type.
        """
        if not isinstance(self.cid, str):
            raise TypeError("cid argument must be a string")
        if not isinstance(self.local_model, Model):
            raise TypeError("local_model argument must be a Medfl.LearningManager.model.Model")
        if not isinstance(self.trainloader, DataLoader):
            raise TypeError("trainloader argument must be a torch.utils.data.dataloader")
        if not isinstance(self.valloader, DataLoader):
            raise TypeError("valloader argument must be a torch.utils.data.dataloader")
        if not isinstance(self.diff_priv, bool):
            raise TypeError("diff_priv argument must be a bool")

    def get_parameters(self, config):
        """
        Returns the current parameters of the local model.

        Args:
            config: Configuration information (unused).

        Returns:
            Whatever ``Model.get_parameters`` returns — presumably a list of
            NumPy arrays, as Flower's NumPyClient expects (TODO confirm).
        """
        print(f"[Client {self.cid}] get_parameters")
        return self.local_model.get_parameters()

    def fit(self, parameters, config):
        """
        Loads the server parameters, trains locally, and returns the update.

        Trains for ``params["train_epochs"]`` epochs on ``self.trainloader``
        and records the epsilon reported by the last epoch in ``self.epsilons``.

        Args:
            parameters: Parameters received from the server.
            config: Configuration information (only printed).

        Returns:
            Tuple: Parameters of the local model, number of training samples,
            and a metrics dict ``{"epsilon": epsilon}``.
        """
        print(f"[Client {self.cid}] fit, config: {config}")
        self.local_model.set_parameters(parameters)
        # Keeps ``epsilon`` defined when params["train_epochs"] is 0 and the
        # loop never runs: reuse the last recorded value, else report 0.0.
        epsilon = self.epsilons[-1] if self.epsilons else 0.0
        for epoch in range(params["train_epochs"]):
            epsilon = self.local_model.train(
                self.trainloader,
                epoch=epoch,
                device=self.device,
                privacy_engine=self.privacy_engine,
                diff_priv=self.diff_priv,
            )
        self.epsilons.append(epsilon)
        print(f"epsilon of client {self.cid} : eps = {epsilon}")
        return (
            self.local_model.get_parameters(),
            # Flower weights aggregation by sample count, not batch count.
            self._num_examples(self.trainloader),
            {"epsilon": epsilon},
        )

    def evaluate(self, parameters, config):
        """
        Evaluates the local model on the validation data.

        Args:
            parameters: Parameters received from the server.
            config: Configuration information (only printed).

        Returns:
            Tuple: Loss (float), number of validation samples, and a metrics
            dict ``{"accuracy": accuracy}``.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        self.local_model.set_parameters(parameters)
        loss, accuracy = self.local_model.evaluate(
            self.valloader, device=self.device
        )
        self.losses.append(loss)
        self.accuracies.append(accuracy)
        return float(loss), self._num_examples(self.valloader), {"accuracy": float(accuracy)}