Source code for locpix.scripts.img_seg.membrane_performance

#!/usr/bin/env python
"""Module for evluating performance of membrane
segmentation

For the train set, it will use the probability map over
the image, and use this to assign probability of belonging
to membrane to each localisation.
Then it will use this to plot precision-recall curve,
and calculate the optimal threshold.

For the test set, it will use the probability maps to
plot a precision recall curve.
It will then use the threshold calculated from the train
set to threshold the test images

Now we have output membrane segmentations on the test set
which we can use to finally calculate the other
remaining performance metrics."""

import yaml
import os
import numpy as np
from locpix.preprocessing import datastruc

# from locpix.visualise.performance import plot_pr_curve, generate_binary_conf_matrix
import locpix.evaluate.metrics as metrics
from sklearn.metrics import precision_recall_curve, auc
import polars as pl
from datetime import datetime

# import matplotlib.pyplot as plt
# from locpix.visualise import vis_img
import argparse
import json
import time
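# The helper below is an illustrative sketch only (not part of the original
# locpix module): it demonstrates the threshold selection described in the
# module docstring, i.e. maximising the F-measure over the precision-recall
# curve. The function name and the toy arrays are hypothetical; main() below
# performs the real calculation on the aggregated train-set localisations.
def _example_threshold_selection():
    # toy ground-truth labels (1 = membrane) and predicted probabilities
    gt = np.array([0, 0, 1, 1, 1, 0, 1])
    prob = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2, 0.9])
    # precision-recall curve over all candidate thresholds
    pr, rec, thresholds = precision_recall_curve(gt, prob, pos_label=1)
    # F-measure per point on the curve; NaNs (0/0) are zeroed out
    fmeas = (2 * pr * rec) / (pr + rec)
    fmeas = np.where(np.isnan(fmeas), 0, fmeas)
    # return the threshold that maximises the F-measure
    return thresholds[np.argmax(fmeas)]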


def main():

    parser = argparse.ArgumentParser(
        description="Membrane performance metrics on data."
    )
    parser.add_argument(
        "-i",
        "--project_directory",
        action="store",
        type=str,
        help="the location of the project directory",
        required=True,
    )
    parser.add_argument(
        "-c",
        "--config",
        action="store",
        type=str,
        help="the location of the .yaml configuration file "
        "for membrane performance evaluation",
        required=True,
    )
    parser.add_argument(
        "-m",
        "--project_metadata",
        action="store_true",
        help="check the metadata for the specified project and "
        "seek confirmation!",
    )

    args = parser.parse_args()

    project_folder = args.project_directory

    # load config
    with open(args.config, "r") as ymlfile:
        config = yaml.safe_load(ymlfile)

    metadata_path = os.path.join(project_folder, "metadata.json")
    with open(
        metadata_path,
    ) as file:
        metadata = json.load(file)

    # check metadata
    if args.project_metadata:
        print("".join([f"{key} : {value} \n" for key, value in metadata.items()]))
        check = input("Are you happy with this? (YES)")
        if check != "YES":
            exit()

    # add time ran this script to metadata
    file = os.path.basename(__file__)
    if file not in metadata:
        metadata[file] = time.asctime(time.gmtime(time.time()))
    else:
        print("Overwriting metadata...")
        metadata[file] = time.asctime(time.gmtime(time.time()))

    # load in train and test files
    train_folds = metadata["train_folds"]
    val_folds = metadata["val_folds"]
    test_files = metadata["test_files"]

    with open(metadata_path, "w") as outfile:
        json.dump(metadata, outfile)

    for fold, (train_files, val_files) in enumerate(zip(train_folds, val_folds)):

        # check files
        if not set(train_files).isdisjoint(test_files):
            raise ValueError("Train files and test files shared files!!")
        if not set(train_files).isdisjoint(val_files):
            raise ValueError("Train files and val files shared files!!")
        if not set(val_files).isdisjoint(test_files):
            raise ValueError("Val files and test files shared files!!")
        if len(set(train_files)) != len(train_files):
            raise ValueError("Train files contains duplicates")
        if len(set(val_files)) != len(val_files):
            raise ValueError("Val files contains duplicates")
        if len(set(test_files)) != len(test_files):
            raise ValueError("Test files contains duplicates")

        # list items
        gt_file_path = os.path.join(project_folder, "annotate/annotated")
        try:
            files = os.listdir(gt_file_path)
        except FileNotFoundError:
            raise ValueError("There should be some files to open")

        # fig_train, ax_train = plt.subplots()
        # fig_test, ax_test = plt.subplots()
        # fig_val, ax_val = plt.subplots()
        # linestyles = ["dashdot", "-", "--", "dotted"]

        methods = ["classic", "cellpose_no_train", "cellpose_train", "ilastik"]

        # output_overlay_pr_curves = os.path.join(
        #     project_folder, f"membrane_performance/overlaid_pr_curves/{fold}"
        # )
        # if os.path.exists(output_overlay_pr_curves):
        #     raise ValueError(
        #         f"Cannot proceed as {output_overlay_pr_curves} already exists"
        #     )
        # else:
        #     os.makedirs(output_overlay_pr_curves)

        # one date for all methods
        date = datetime.today().strftime("%H_%M_%d_%m_%Y")

        for index, method in enumerate(methods):

            print(f"{method} ...")

            # get folder names
            if method == "ilastik":
                seg_folder = os.path.join(
                    project_folder, f"ilastik/output/membrane/prob_map/{fold}"
                )
            elif method == "classic" or method == "cellpose_no_train":
                seg_folder = os.path.join(
                    project_folder, f"{method}/membrane/prob_map/"
                )
            elif method == "cellpose_train":
                seg_folder = os.path.join(
                    project_folder, f"{method}/{fold}/membrane/prob_map/"
                )

            output_df_folder_test = os.path.join(
                project_folder,
                f"membrane_performance/{method}/membrane/seg_dataframes/test/{fold}",
            )
            output_df_folder_val = os.path.join(
                project_folder,
                f"membrane_performance/{method}/membrane/seg_dataframes/val/{fold}",
            )
            # output_seg_imgs_test = os.path.join(
            #     project_folder,
            #     f"membrane_performance/{method}/membrane/seg_images/test/{fold}",
            # )
            # output_seg_imgs_val = os.path.join(
            #     project_folder,
            #     f"membrane_performance/{method}/membrane/seg_images/val/{fold}",
            # )
            # output_train_pr = os.path.join(
            #     project_folder,
            #     f"membrane_performance/{method}/membrane/train_pr/{fold}",
            # )
            # output_test_pr = os.path.join(
            #     project_folder,
            #     f"membrane_performance/{method}/membrane/test_pr/{fold}",
            # )
            # output_val_pr = os.path.join(
            #     project_folder,
            #     f"membrane_performance/{method}/membrane/val_pr/{fold}",
            # )
            output_metrics = os.path.join(
                project_folder,
                f"membrane_performance/{method}/membrane/metrics/{fold}",
            )
            # output_conf_matrix = os.path.join(
            #     project_folder,
            #     f"membrane_performance/{method}/membrane/conf_matrix/{fold}",
            # )

            # create output folders
            folders = [
                output_df_folder_test,
                output_df_folder_val,
                # output_seg_imgs_val,
                # output_seg_imgs_test,
                # output_train_pr,
                # output_val_pr,
                # output_test_pr,
                output_metrics,
                # output_conf_matrix,
            ]
            for folder in folders:
                # if output directory not present create it
                if os.path.exists(folder):
                    raise ValueError(f"Cannot proceed as {folder} already exists")
                else:
                    os.makedirs(folder)

            print("Train set...")

            prob_list = np.array([])
            gt_list = np.array([])

            for file in files:
                if file.removesuffix(".parquet") not in train_files:
                    continue
                print("File ", file)

                # load df
                item = datastruc.item(None, None, None, None, None)
                item.load_from_parquet(os.path.join(gt_file_path, file))

                # load prob map
                img_prob = np.load(os.path.join(seg_folder, item.name + ".npy"))

                # merge prob map and df
                merged_df = item.mask_pixel_2_coord(img_prob)
                merged_df = merged_df.rename({"pred_label": "prob"})

                # get gt and predicted probability
                gt = merged_df.select(pl.col("gt_label")).to_numpy()
                prob = merged_df.select(pl.col("prob")).to_numpy()

                # append to aggregated data set
                prob_list = np.append(prob_list, prob)
                gt_list = np.append(gt_list, gt)

            # print("Sanity check...")
            # print("gt", len(gt_list), gt_list)

            # calculate precision recall curve
            gt_list = gt_list.flatten()
            prob_list = prob_list.flatten().round(2)
            pr, rec, pr_threshold = precision_recall_curve(
                gt_list, prob_list, pos_label=1
            )
            baseline = len(gt_list[gt_list == 1]) / len(gt_list)

            # pr, recall saved for train
            save_loc = os.path.join(output_metrics, f"train_{date}.txt")
            lines = ["Overall results", "-----------"]
            lines.append(f"prcurve_pr: {list(pr)}")
            lines.append(f"prcurve_rec: {list(rec)}")
            lines.append(f"prcurve_baseline: {baseline}")
            with open(save_loc, "w") as f:
                f.writelines("\n".join(lines))

            # plot pr curve
            # save_loc = os.path.join(output_train_pr, "_curve.pkl")
            # plot_pr_curve(
            #     ax_train,
            #     method.capitalize(),
            #     linestyles[index],
            #     "darkorange",
            #     pr,
            #     rec,
            #     baseline,
            #     # save_loc,
            #     # pickle=True,
            # )

            # calculate optimal threshold
            if config["maximise_choice"] == "recall":
                max_index = np.argmax(rec)
            elif config["maximise_choice"] == "f":
                # get f measure and find optimal
                fmeas = (2 * pr * rec) / (pr + rec)
                # remove nan from optimal threshold calculation
                nan_array = np.isnan(fmeas)
                fmeas = np.where(nan_array, 0, fmeas)
                max_index = np.argmax(fmeas)
            else:
                raise ValueError("Choice not supported")
            threshold = pr_threshold[max_index]
            print("Threshold for ", method, " :", threshold)

            print("Test set...")

            metadata = {
                "train_set": train_files,
                "test_set": test_files,
                "threshold": threshold,
                "val_set": val_files,
            }

            prob_list = np.array([])
            gt_list = np.array([])

            # sanity check all have same gt label map
            gt_label_map = None

            # threshold dataframe and save to parquet file with pred label
            for file in files:
                if file.removesuffix(".parquet") not in test_files:
                    continue
                print("File ", file)

                item = datastruc.item(None, None, None, None, None)
                item.load_from_parquet(os.path.join(gt_file_path, file))

                # load prob map
                img_prob = np.load(os.path.join(seg_folder, item.name + ".npy"))

                # merge prob map and df
                merged_df = item.mask_pixel_2_coord(img_prob)
                merged_df = merged_df.rename({"pred_label": "prob"})

                # get gt and predicted probability
                gt = merged_df.select(pl.col("gt_label")).to_numpy()
                prob = merged_df.select(pl.col("prob")).to_numpy()

                save_df = merged_df.select(
                    [
                        pl.all(),
                        pl.when(pl.col("prob") > threshold)
                        .then(1)
                        .otherwise(0)
                        .alias("pred_label"),
                    ]
                )

                # append to aggregated data set
                prob_list = np.append(prob_list, prob)
                gt_list = np.append(gt_list, gt)

                # assign save dataframe to item
                item.df = save_df

                # save dataframe with predicted label column
                item.save_to_parquet(
                    output_df_folder_test, drop_zero_label=False, drop_pixel_col=True
                )

                # sanity check all have same gt label map
                if gt_label_map is None:
                    gt_label_map = item.gt_label_map
                else:
                    assert item.gt_label_map == gt_label_map

            # print("Sanity check...")
            # print("gt", len(gt_list), gt_list)

            # calculate precision recall curve
            gt_list = gt_list.flatten()
            prob_list = prob_list.flatten().round(2)
            pr, rec, pr_threshold = precision_recall_curve(
                gt_list, prob_list, pos_label=1
            )
            baseline = len(gt_list[gt_list == 1]) / len(gt_list)

            # plot pr curve
            # save_loc = os.path.join(output_test_pr, "_curve.pkl")
            # plot_pr_curve(
            #     ax_test,
            #     method.capitalize(),
            #     linestyles[index],
            #     "darkorange",
            #     pr,
            #     rec,
            #     baseline,
            #     # save_loc,
            #     # pickle=True,
            # )

            pr_auc = auc(rec, pr)
            add_metrics = {
                "pr_auc": pr_auc,
                "prcurve_pr": list(pr),
                "prcurve_rec": list(rec),
                "prcurve_baseline": baseline,
            }

            # metric calculations based on final prediction
            save_loc = os.path.join(output_metrics, f"test_{date}.txt")
            _ = metrics.aggregated_metrics(
                output_df_folder_test,
                save_loc,
                gt_label_map,
                add_metrics=add_metrics,
                metadata=metadata,
            )

            # calculate confusion matrix
            # saveloc = os.path.join(output_conf_matrix, f"conf_matrix_test_{date}.png")
            # classes = [item.gt_label_map[0], item.gt_label_map[1]]
            # generate_binary_conf_matrix(tn, fp, fn, tp, classes, saveloc)
            # could just use aggregated metric function to plot the confusion matrix

            print("Val set...")

            prob_list = np.array([])
            gt_list = np.array([])

            # sanity check all have same gt label map
            gt_label_map = None

            # threshold dataframe and save to parquet file with pred label
            for file in files:
                if file.removesuffix(".parquet") not in val_files:
                    continue
                print("File ", file)

                item = datastruc.item(None, None, None, None, None)
                item.load_from_parquet(os.path.join(gt_file_path, file))

                # load prob map
                img_prob = np.load(os.path.join(seg_folder, item.name + ".npy"))

                # merge prob map and df
                merged_df = item.mask_pixel_2_coord(img_prob)
                merged_df = merged_df.rename({"pred_label": "prob"})

                # get gt and predicted probability
                gt = merged_df.select(pl.col("gt_label")).to_numpy()
                prob = merged_df.select(pl.col("prob")).to_numpy()

                save_df = merged_df.select(
                    [
                        pl.all(),
                        pl.when(pl.col("prob") > threshold)
                        .then(1)
                        .otherwise(0)
                        .alias("pred_label"),
                    ]
                )

                # append to aggregated data set
                prob_list = np.append(prob_list, prob)
                gt_list = np.append(gt_list, gt)

                # assign save dataframe to item
                item.df = save_df

                # save dataframe with predicted label column
                item.save_to_parquet(
                    output_df_folder_val, drop_zero_label=False, drop_pixel_col=True
                )

                # sanity check all have same gt label map
                if gt_label_map is None:
                    gt_label_map = item.gt_label_map
                else:
                    assert item.gt_label_map == gt_label_map

            # print("Sanity check...")
            # print("gt", len(gt_list), gt_list)

            # calculate precision recall curve
            gt_list = gt_list.flatten()
            prob_list = prob_list.flatten().round(2)
            pr, rec, pr_threshold = precision_recall_curve(
                gt_list, prob_list, pos_label=1
            )
            baseline = len(gt_list[gt_list == 1]) / len(gt_list)

            # plot pr curve
            # save_loc = os.path.join(output_val_pr, "_curve.pkl")
            # plot_pr_curve(
            #     ax_val,
            #     method.capitalize(),
            #     linestyles[index],
            #     "darkorange",
            #     pr,
            #     rec,
            #     baseline,
            #     # save_loc,
            #     # pickle=True,
            # )

            pr_auc = auc(rec, pr)
            add_metrics = {
                "pr_auc": pr_auc,
                "prcurve_pr": list(pr),
                "prcurve_rec": list(rec),
                "prcurve_baseline": baseline,
            }

            # metric calculations based on final prediction
            save_loc = os.path.join(output_metrics, f"val_{date}.txt")
            _ = metrics.aggregated_metrics(
                output_df_folder_val,
                save_loc,
                gt_label_map,
                add_metrics=add_metrics,
                metadata=metadata,
            )

            # calculate confusion matrix
            # saveloc = os.path.join(output_conf_matrix, f"conf_matrix_val_{date}.png")
            # classes = [item.gt_label_map[0], item.gt_label_map[1]]
            # generate_binary_conf_matrix(tn, fp, fn, tp, classes, saveloc)

        # fig_train.tight_layout()
        # fig_test.tight_layout()
        # fig_val.tight_layout()

        # get handles and labels
        # handles_train, labels_train = ax_train.get_legend_handles_labels()
        # handles_test, labels_test = ax_test.get_legend_handles_labels()

        # specify order of items in legend
        # order = [1,0,2]

        # add legend to plot
        # ax_train.legend([handles_train[idx] for idx in order],
        #                 [methods[idx] for idx in order])
        # ax_test.legend([handles_test[idx] for idx in order],
        #                [methods[idx] for idx in order])

        # fig_train.savefig(os.path.join(output_overlay_pr_curves,
        #                   "_train.png"), dpi=600)
        # fig_test.savefig(os.path.join(output_overlay_pr_curves,
        #                  "_test.png"), dpi=600)
        # fig_val.savefig(os.path.join(output_overlay_pr_curves,
        #                 "_val.png"), dpi=600)

    # save yaml file
    yaml_save_loc = os.path.join(project_folder, "membrane_performance.yaml")
    with open(yaml_save_loc, "w") as outfile:
        yaml.dump(config, outfile)


if __name__ == "__main__":
    main()