import datetime
from typing import List
import json
import pandas as pd
# File: create_query.py
from sqlalchemy import text
from torch.utils.data import DataLoader, TensorDataset
from Medfl.LearningManager.server import FlowerServer
from Medfl.LearningManager.utils import params, test
from scripts.base import my_eng
from Medfl.NetManager.net_helper import get_flpipeline_from_name
from Medfl.NetManager.net_manager_queries import (CREATE_FLPIPELINE_QUERY,
DELETE_FLPIPELINE_QUERY , CREATE_TEST_RESULTS_QUERY)
[docs]
def create_query(name, description, creation_date, result):
query = text(
f"INSERT INTO FLpipeline(name, description, creation_date, results) "
f"VALUES ('{name}', '{description}', '{creation_date}', '{result}')"
)
return query
[docs]
class FLpipeline:
"""
FLpipeline class for managing Federated Learning pipelines.
Attributes:
name (str): The name of the FLpipeline.
description (str): A description of the FLpipeline.
server (FlowerServer): The FlowerServer object associated with the FLpipeline.
Methods:
__init__(self, name: str, description: str, server: FlowerServer) -> None:
Initialize FLpipeline with the specified name, description, and server.
"""
def __init__(
self, name: str, description: str, server: FlowerServer
) -> None:
self.name = name
self.description = description
self.server = server
self.validate()
[docs]
def validate(self) -> None:
"""
Validate the name, description, and server attributes.
Raises:
TypeError: If the name is not a string, the description is not a string,
or the server is not a FlowerServer object.
"""
if not isinstance(self.name, str):
raise TypeError("name argument must be a string")
if not isinstance(self.description, str):
raise TypeError("description argument must be a string")
if not isinstance(self.server, FlowerServer):
raise TypeError("server argument must be a FlowerServer")
[docs]
def create(self, result: str) -> None:
"""
Create a new FLpipeline entry in the database with the given result.
Args:
result (str): The result string to store in the database.
"""
creation_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
query = CREATE_FLPIPELINE_QUERY.format(
name=self.name,
description=self.description,
creation_date=creation_date,
result=result,
)
my_eng.execute(text(query))
self.id = get_flpipeline_from_name(self.name)
try:
self.server.fed_dataset.update(
FLpipeId=self.id, FedId=self.server.fed_dataset.id
)
except:
pass
[docs]
def delete(self) -> None:
"""
Delete the FLpipeline entry from the database based on its name.
Note: This is a placeholder method and needs to be implemented based on your specific database setup.
"""
# Placeholder code for deleting the FLpipeline entry from the database based on the name.
# You need to implement the actual deletion based on your database setup.
my_eng.execute(DELETE_FLPIPELINE_QUERY.format(self.name))
[docs]
def test_by_node(self, node_name: str, test_frac=1) -> dict:
"""
Test the FLpipeline by node with the specified test_frac.
Args:
node_name (str): The name of the node to test.
test_frac (float, optional): The fraction of the test data to use. Default is 1.
Returns:
dict: A dictionary containing the node name and the classification report.
"""
idx = self.server.fed_dataset.test_nodes.index(node_name)
global_model, test_loader = (
self.server.global_model,
self.server.fed_dataset.testloaders[idx],
)
test_data = test_loader.dataset
test_data = TensorDataset(
test_data[: int(test_frac * len(test_data))][0],
test_data[: int(test_frac * len(test_data))][1],
)
test_loader = DataLoader(
test_data, batch_size=params["test_batch_size"]
)
classification_report = test(
model=global_model.model, test_loader=test_loader
)
return {
"node_name": node_name,
"classification_report": str(classification_report),
}
[docs]
def auto_test(self, test_frac=1) -> List[dict]:
"""
Automatically test the FLpipeline on all nodes with the specified test_frac.
Args:
test_frac (float, optional): The fraction of the test data to use. Default is 1.
Returns:
List[dict]: A list of dictionaries containing the node names and the classification reports.
"""
result = [
self.test_by_node(node, test_frac)
for node in self.server.fed_dataset.test_nodes
]
self.create("\n".join(str(res).replace("'", '"') for res in result))
# stockage des resultats des tests
for entry in result:
node_name = entry['node_name']
classification_report_str = entry['classification_report']
# Convert the 'classification_report' string to a dictionary
classification_report_dict = json.loads(classification_report_str.replace("'", "\""))
try:
# Insert record into the 'testResults' table
query = CREATE_TEST_RESULTS_QUERY.format(
pipelineId = self.id,
nodeName = node_name ,
confusion_matrix = json.dumps(classification_report_dict['confusion matrix']),
accuracy =classification_report_dict['Accuracy'] ,
sensivity = classification_report_dict['Sensitivity/Recall'] ,
ppv = classification_report_dict['PPV/Precision'] ,
npv= classification_report_dict['NPV'] ,
f1score= classification_report_dict['F1-score'] ,
fpr= classification_report_dict['False positive rate'] ,
tpr= classification_report_dict['True positive rate']
)
my_eng.execute(text(query))
except Exception as e:
# This block will catch any other exceptions
print(f"An unexpected error occurred: {e}")
return result