Source code for MEDfl.LearningManager.model

#!/usr/bin/env python3
# forked from https://github.com/pythonlessons/mltu/blob/main/mltu/torch/model.py

import typing
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score

from .utils import params


class Model:
    """
    Model class for training and testing PyTorch neural networks.

    Attributes:
        model (torch.nn.Module): PyTorch neural network.
        optimizer (torch.optim.Optimizer): PyTorch optimizer.
        criterion (typing.Callable): Loss function.
    """
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: typing.Callable,
    ) -> None:
        """
        Initialize Model class with the specified model, optimizer, and criterion.

        Args:
            model (torch.nn.Module): PyTorch neural network.
            optimizer (torch.optim.Optimizer): PyTorch optimizer.
            criterion (typing.Callable): Loss function.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        # Check that model and optimizer have the expected types
        self.validate()
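    # Illustrative sketch (not part of the original module): wrapping a small
    # binary classifier. The concrete layers, optimizer, and criterion below
    # are assumptions; the class accepts any nn.Module / Optimizer / callable.
    #
    #   net = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
    #   wrapper = Model(net, torch.optim.SGD(net.parameters(), lr=0.1), nn.BCELoss())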
    def validate(self) -> None:
        """
        Validate model and optimizer.
        """
        if not isinstance(self.model, torch.nn.Module):
            raise TypeError("model argument must be a torch.nn.Module")
        if not isinstance(self.optimizer, torch.optim.Optimizer):
            raise TypeError(
                "optimizer argument must be a torch.optim.Optimizer"
            )
    def get_parameters(self) -> List[np.ndarray]:
        """
        Get the parameters of the model as a list of NumPy arrays.

        Returns:
            List[np.ndarray]: The parameters of the model as a list of NumPy arrays.
        """
        return [
            val.cpu().numpy() for _, val in self.model.state_dict().items()
        ]
    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        """
        Set the parameters of the model from a list of NumPy arrays.

        Args:
            parameters (List[np.ndarray]): The parameters to be set.
        """
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)
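    # Illustrative sketch: get_parameters/set_parameters implement the
    # NumPy-based weight exchange used by federated-learning clients (e.g.
    # Flower), letting a server aggregate and redistribute model weights:
    #
    #   weights = wrapper.get_parameters()   # List[np.ndarray], one per tensor
    #   wrapper.set_parameters(weights)      # round-trip restores the state dict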
    def train(
        self, train_loader, epoch, device, privacy_engine, diff_priv=False
    ) -> float:
        """
        Train the model on the given train_loader for one epoch.

        Args:
            train_loader: The data loader for training data.
            epoch (int): The current epoch number.
            device: The device on which to perform the training.
            privacy_engine: The privacy engine used for differential privacy (if enabled).
            diff_priv (bool, optional): Whether differential privacy is used. Default is False.

        Returns:
            float: The value of epsilon used in differential privacy.
        """
        self.model.to(device)
        self.model.train()
        epsilon = 0
        losses = []
        top1_acc = []

        for i, (X_train, y_train) in enumerate(train_loader):
            # Move the batch to the requested device
            X_train, y_train = X_train.to(device), y_train.to(device)
            self.optimizer.zero_grad()

            # compute output
            y_hat = torch.squeeze(self.model(X_train), 1)
            loss = self.criterion(y_hat, y_train)

            # Binary predictions: round the (sigmoid) outputs to 0/1
            preds = y_hat.detach().cpu().numpy().round()
            labels = y_train.detach().cpu().numpy()

            # measure accuracy and record loss
            acc = (preds == labels).mean()
            losses.append(loss.item())
            top1_acc.append(acc)

            loss.backward()
            self.optimizer.step()

            if diff_priv:
                epsilon = privacy_engine.get_epsilon(params["DELTA"])

            if (i + 1) % 10 == 0:
                if diff_priv:
                    epsilon = privacy_engine.get_epsilon(params["DELTA"])
                    print(
                        f"\tTrain Epoch: {epoch} \t"
                        f"Loss: {np.mean(losses):.6f} "
                        f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
                        f"(ε = {epsilon:.2f}, δ = {params['DELTA']})"
                    )
                else:
                    print(
                        f"\tTrain Epoch: {epoch} \t"
                        f"Loss: {np.mean(losses):.6f} "
                        f"Acc@1: {np.mean(top1_acc) * 100:.6f}"
                    )

        return epsilon
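    # Illustrative sketch (assumes the Opacus package is installed): to train
    # with differential privacy, wrap the model, optimizer, and loader with an
    # Opacus PrivacyEngine before calling train(..., diff_priv=True); the
    # noise_multiplier and max_grad_norm values here are placeholders:
    #
    #   from opacus import PrivacyEngine
    #   engine = PrivacyEngine()
    #   net, opt, loader = engine.make_private(
    #       module=net, optimizer=opt, data_loader=loader,
    #       noise_multiplier=1.0, max_grad_norm=1.0,
    #   )
    #   wrapper.train(loader, epoch=1, device=device,
    #                 privacy_engine=engine, diff_priv=True)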
    def evaluate(
        self, val_loader, device=torch.device("cpu")
    ) -> Tuple[float, float]:
        """
        Evaluate the model on the given validation data.

        Args:
            val_loader: The data loader for validation data.
            device: The device on which to perform the evaluation. Default is 'cpu'.

        Returns:
            Tuple[float, float]: The evaluation loss and accuracy.
        """
        correct, total, loss, accuracy = 0, 0, 0.0, []
        self.model.eval()

        with torch.no_grad():
            for X_test, y_test in val_loader:
                y_hat = torch.squeeze(self.model(X_test), 1)
                accuracy.append(
                    accuracy_score(
                        y_test.cpu().numpy(), y_hat.round().cpu().numpy()
                    )
                )
                loss += self.criterion(y_hat, y_test).item()
                total += y_test.size(0)
                correct += np.sum(
                    y_hat.round().cpu().numpy() == y_test.cpu().numpy()
                )

        # criterion returns a per-batch mean, so average over batches
        loss /= len(val_loader)
        return loss, np.mean(accuracy)
    @staticmethod
    def save_model(model, model_name: str):
        """
        Saves a PyTorch model to a file.

        Args:
            model (torch.nn.Module): PyTorch model to be saved.
            model_name (str): Name of the model file.

        Raises:
            Exception: If there is an issue during the saving process.

        Returns:
            None
        """
        try:
            torch.save(
                model,
                '../../notebooks/.ipynb_checkpoints/trainedModels/'
                + model_name
                + ".pth",
            )
        except Exception as e:
            raise Exception(f"Error saving the model: {str(e)}")
    @staticmethod
    def load_model(model_name: str):
        """
        Loads a PyTorch model from a file.

        Args:
            model_name (str): Name of the model file to be loaded.

        Returns:
            torch.nn.Module: Loaded PyTorch model.
        """
        loadedModel = torch.load(
            '../../notebooks/.ipynb_checkpoints/trainedModels/'
            + model_name
            + ".pth"
        )
        return loadedModel
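
if __name__ == "__main__":
    # Illustrative smoke test (an assumption-laden sketch, not part of the
    # original module): a tiny binary classifier whose sigmoid output matches
    # the squeeze/round logic in train() and evaluate(). Because of the
    # relative import above, run it in package context, e.g.:
    #   python -m MEDfl.LearningManager.model
    from torch.utils.data import DataLoader, TensorDataset

    net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid())
    X = torch.randn(64, 4)
    y = (X.sum(dim=1) > 0).float()
    loader = DataLoader(TensorDataset(X, y), batch_size=16)

    wrapper = Model(net, torch.optim.SGD(net.parameters(), lr=0.1), nn.BCELoss())
    wrapper.train(loader, epoch=1, device=torch.device("cpu"), privacy_engine=None)
    loss, acc = wrapper.evaluate(loader)
    print(f"eval loss={loss:.4f} acc={acc:.4f}")

    # save_model/load_model write to the hard-coded notebooks path above, so
    # they only work when that directory exists:
    #   Model.save_model(net, "toy")
    #   restored = Model.load_model("toy")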