Source code for MEDfl.LearningManager.utils

#!/usr/bin/env python3

import pkg_resources
import torch
import yaml
from sklearn.metrics import *
from yaml.loader import SafeLoader

from scripts.base import *
import json


import pandas as pd
import numpy as np

yaml_path = pkg_resources.resource_filename(__name__, "params.yaml")
with open(yaml_path) as g:
    params = yaml.load(g, Loader=SafeLoader)

global_yaml_path = pkg_resources.resource_filename(__name__, "../../global_params.yaml")
with open(global_yaml_path) as g:
    global_params = yaml.load(g, Loader=SafeLoader)


def custom_classification_report(y_true, y_pred_prob):
    """
    Compute custom classification report metrics including accuracy, sensitivity,
    specificity, precision, NPV, F1-score, false positive rate, and true positive rate.

    Args:
        y_true (array-like): True labels.
        y_pred_prob (array-like): Predicted probabilities for the positive class.

    Returns:
        dict: A dictionary containing custom classification report metrics.
    """
    y_pred = y_pred_prob.round()  # Round predicted probabilities to the nearest integer
    auc = roc_auc_score(y_true, y_pred_prob)  # Calculate AUC

    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

    # Accuracy
    denominator_acc = tp + tn + fp + fn
    acc = (tp + tn) / denominator_acc if denominator_acc != 0 else 0.0

    # Sensitivity/Recall
    denominator_sen = tp + fn
    sen = tp / denominator_sen if denominator_sen != 0 else 0.0

    # Specificity
    denominator_sp = tn + fp
    sp = tn / denominator_sp if denominator_sp != 0 else 0.0

    # PPV/Precision
    denominator_ppv = tp + fp
    ppv = tp / denominator_ppv if denominator_ppv != 0 else 0.0

    # NPV
    denominator_npv = tn + fn
    npv = tn / denominator_npv if denominator_npv != 0 else 0.0

    # F1 Score
    denominator_f1 = sen + ppv
    f1 = 2 * (sen * ppv) / denominator_f1 if denominator_f1 != 0 else 0.0

    # False Positive Rate
    denominator_fpr = fp + tn
    fpr = fp / denominator_fpr if denominator_fpr != 0 else 0.0

    # True Positive Rate
    denominator_tpr = tp + fn
    tpr = tp / denominator_tpr if denominator_tpr != 0 else 0.0

    return {
        "confusion matrix": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
        "Accuracy": round(acc, 3),
        "Sensitivity/Recall": round(sen, 3),
        "Specificity": round(sp, 3),
        "PPV/Precision": round(ppv, 3),
        "NPV": round(npv, 3),
        "F1-score": round(f1, 3),
        "False positive rate": round(fpr, 3),
        "True positive rate": round(tpr, 3),
        "auc": auc,
    }

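
# Illustrative usage (not part of the original module): a minimal sketch of
# calling custom_classification_report on a small hand-made batch. The label
# and probability arrays below are invented for demonstration.
def _example_custom_report():
    y_true = np.array([1, 0, 1, 1, 0, 0])
    y_prob = np.array([0.92, 0.18, 0.41, 0.77, 0.09, 0.63])
    return custom_classification_report(y_true, y_prob)
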
def test(model, test_loader, device=torch.device("cpu")):
    """
    Evaluate a model using a test loader and return a custom classification report.

    Args:
        model (torch.nn.Module): PyTorch model to evaluate.
        test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
        device (torch.device, optional): Device for model evaluation. Default is "cpu".

    Returns:
        dict: A dictionary containing custom classification report metrics.
    """
    model.eval()
    with torch.no_grad():
        X_test, y_test = (
            test_loader.dataset[:][0].to(device),
            test_loader.dataset[:][1].to(device),
        )
        y_hat_prob = torch.squeeze(model(X_test), 1).cpu()

    return custom_classification_report(y_test.cpu().numpy(), y_hat_prob.cpu().numpy())

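
# Illustrative usage (not part of the original module): a minimal sketch of
# evaluating an untrained toy classifier with test() on a synthetic
# TensorDataset. The model, tensors, and batch size are invented for
# demonstration.
def _example_test_run():
    from torch.utils.data import DataLoader, TensorDataset

    X = torch.rand(32, 4)                    # 32 samples, 4 features
    y = (X.sum(dim=1) > 2).float()           # synthetic binary labels
    loader = DataLoader(TensorDataset(X, y), batch_size=8)

    model = torch.nn.Sequential(torch.nn.Linear(4, 1), torch.nn.Sigmoid())
    return test(model, loader)               # evaluates on CPU by default
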
column_map = {"object": "VARCHAR(255)", "int64": "INT", "float64": "FLOAT"}
def empty_db():
    """
    Empty the database by deleting records from multiple tables and resetting
    auto-increment counters.

    Returns:
        None
    """
    # my_eng.execute(text(f"DELETE FROM {'DataSets'}"))
    my_eng.execute(text(f"DELETE FROM {'Nodes'}"))
    my_eng.execute(text(f"DELETE FROM {'FedDatasets'}"))
    my_eng.execute(text(f"DELETE FROM {'Networks'}"))
    my_eng.execute(text(f"DELETE FROM {'FLsetup'}"))
    my_eng.execute(text(f"DELETE FROM {'FLpipeline'}"))

    my_eng.execute(text(f"ALTER TABLE {'Nodes'} AUTO_INCREMENT = 1"))
    my_eng.execute(text(f"ALTER TABLE {'Networks'} AUTO_INCREMENT = 1"))
    my_eng.execute(text(f"ALTER TABLE {'FedDatasets'} AUTO_INCREMENT = 1"))
    my_eng.execute(text(f"ALTER TABLE {'FLsetup'} AUTO_INCREMENT = 1"))
    my_eng.execute(text(f"ALTER TABLE {'FLpipeline'} AUTO_INCREMENT = 1"))

    my_eng.execute(text(f"DELETE FROM {'testResults'}"))
    my_eng.execute(text(f"DROP TABLE IF EXISTS {'MasterDataset'}"))
    my_eng.execute(text(f"DROP TABLE IF EXISTS {'DataSets'}"))

def get_pipeline_from_name(name):
    """
    Get the pipeline ID from its name in the database.

    Args:
        name (str): Name of the pipeline.

    Returns:
        int: ID of the pipeline.
    """
    NodeId = int(
        pd.read_sql(
            text(f"SELECT id FROM FLpipeline WHERE name = '{name}'"), my_eng
        ).iloc[0, 0]
    )
    return NodeId

def get_pipeline_confusion_matrix(pipeline_id):
    """
    Get the global confusion matrix for a pipeline based on test results.

    Args:
        pipeline_id (int): ID of the pipeline.

    Returns:
        dict: A dictionary representing the global confusion matrix.
    """
    data = pd.read_sql(
        text(f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}'"),
        my_eng,
    )

    # Convert the column of strings into a list of dictionaries representing confusion matrices
    confusion_matrices = [
        json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
    ]

    # Initialize counters for the global confusion matrix
    global_TP = global_FP = global_FN = global_TN = 0

    # Sum the corresponding values of each per-node confusion matrix
    for matrix in confusion_matrices:
        global_TP += matrix['TP']
        global_FP += matrix['FP']
        global_FN += matrix['FN']
        global_TN += matrix['TN']

    # Assemble the global confusion matrix as a dictionary
    global_confusion_matrix = {
        'TP': global_TP,
        'FP': global_FP,
        'FN': global_FN,
        'TN': global_TN,
    }

    # Return the aggregated (global) confusion matrix
    return global_confusion_matrix

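
# Illustrative usage (not part of the original module): a minimal sketch of
# deriving a global accuracy from the aggregated confusion matrix returned
# above. The pipeline id is invented for demonstration.
def _example_global_accuracy(pipeline_id=1):
    cm = get_pipeline_confusion_matrix(pipeline_id)
    total = cm["TP"] + cm["TN"] + cm["FP"] + cm["FN"]
    return (cm["TP"] + cm["TN"]) / total if total != 0 else 0.0
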
def get_node_confusion_matrix(pipeline_id, node_name):
    """
    Get the confusion matrix for a specific node in a pipeline based on test results.

    Args:
        pipeline_id (int): ID of the pipeline.
        node_name (str): Name of the node.

    Returns:
        dict: A dictionary representing the confusion matrix for the specified node.
    """
    data = pd.read_sql(
        text(
            f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}' AND nodename = '{node_name}'"
        ),
        my_eng,
    )

    # Convert the column of strings into a list of dictionaries representing confusion matrices
    confusion_matrices = [
        json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
    ]

    # Return the confusion matrix of the requested node (first matching row)
    return confusion_matrices[0]

def get_pipeline_result(pipeline_id):
    """
    Get the test results for a pipeline.

    Args:
        pipeline_id (int): ID of the pipeline.

    Returns:
        pandas.DataFrame: DataFrame containing test results for the specified pipeline.
    """
    data = pd.read_sql(
        text(f"SELECT * FROM testResults WHERE pipelineid = '{pipeline_id}'"),
        my_eng,
    )
    return data
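
# Illustrative usage (not part of the original module): a minimal sketch of
# looking a pipeline up by name and fetching its stored test results together
# with the aggregated confusion matrix. The pipeline name is invented for
# demonstration and assumes a matching row exists in FLpipeline.
def _example_pipeline_report(name="my_pipeline"):
    pipeline_id = get_pipeline_from_name(name)
    results = get_pipeline_result(pipeline_id)
    global_cm = get_pipeline_confusion_matrix(pipeline_id)
    return results, global_cm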