Source code for locpix.scripts.img_seg.train_prep

#!/usr/bin/env python
"""Train prep module

Split the dataset into train and test files
Save these splits and also perform a k-fold split of the train files
Save all splits in the project metadata
"""

import yaml
import os
import argparse
import json
import time
from sklearn.model_selection import KFold


def main():

    parser = argparse.ArgumentParser(description="Train prep")
    parser.add_argument(
        "-i",
        "--project_directory",
        action="store",
        type=str,
        help="the location of the project directory",
        required=True,
    )
    parser.add_argument(
        "-c",
        "--config",
        action="store",
        type=str,
        help="the location of the .yaml configuration file for train prep",
        required=True,
    )
    parser.add_argument(
        "-m",
        "--project_metadata",
        action="store_true",
        help="check the metadata for the specified project and "
        "seek confirmation!",
    )

    args = parser.parse_args()

    project_folder = args.project_directory

    # load config
    with open(args.config, "r") as ymlfile:
        config = yaml.safe_load(ymlfile)

    # check the annotated files are present
    input_folder = os.path.join(project_folder, "annotate/annotated")
    try:
        os.listdir(input_folder)
    except FileNotFoundError:
        raise ValueError("There should be some files to open")

    # split train/test
    train_files = config["train_files"]
    test_files = config["test_files"]

    # check files
    if not set(train_files).isdisjoint(test_files):
        raise ValueError("Train files and test files share files!")
    if len(set(train_files)) != len(train_files):
        raise ValueError("Train files contain duplicates")
    if len(set(test_files)) != len(test_files):
        raise ValueError("Test files contain duplicates")

    # create train and test folders
    train_folder = os.path.join(project_folder, "train_files")
    test_folder = os.path.join(project_folder, "test_files")
    folders = [
        train_folder,
        test_folder,
    ]
    for folder in folders:
        if os.path.exists(folder):
            raise ValueError(f"Cannot proceed as {folder} already exists")
        else:
            os.makedirs(folder)

    # split train files into k folds
    folds = 5
    kf = KFold(n_splits=folds, shuffle=True)
    train_folds = []
    val_folds = []
    for train_index, val_index in kf.split(train_files):
        train_fold = []
        val_fold = []
        for index in train_index:
            train_fold.append(train_files[index])
        for index in val_index:
            val_fold.append(train_files[index])
        train_folds.append(train_fold)
        val_folds.append(val_fold)

    # check the files in each fold
    for fold in range(folds):
        # train/val files for this fold
        train_files = train_folds[fold]
        val_files = val_folds[fold]

        # check files
        if not set(train_files).isdisjoint(test_files):
            raise ValueError("Train files and test files share files!")
        if not set(train_files).isdisjoint(val_files):
            raise ValueError("Train files and val files share files!")
        if not set(val_files).isdisjoint(test_files):
            raise ValueError("Val files and test files share files!")
        if len(set(train_files)) != len(train_files):
            raise ValueError("Train files contain duplicates")
        if len(set(val_files)) != len(val_files):
            raise ValueError("Val files contain duplicates")
        if len(set(test_files)) != len(test_files):
            raise ValueError("Test files contain duplicates")

        print("Fold", fold)
        print("Train files")
        print(train_files)
        print("Val files")
        print(val_files)
        print("Test files")
        print(test_files)

    # load and add to metadata
    metadata_path = os.path.join(project_folder, "metadata.json")
    with open(
        metadata_path,
    ) as file:
        metadata = json.load(file)
        # check metadata
        if args.project_metadata:
            print("".join([f"{key} : {value} \n" for key, value in metadata.items()]))
            check = input("Are you happy with this? (YES)")
            if check != "YES":
                exit()
        # add train/val/test split
        # note: train_files and val_files hold the last fold at this point,
        # so together they recombine into the full training set from the config
        metadata["train_files"] = train_files + val_files
        metadata["test_files"] = test_files
        metadata["train_folds"] = train_folds
        metadata["val_folds"] = val_folds
        # add time this script was run to metadata
        file = os.path.basename(__file__)
        if file not in metadata:
            metadata[file] = time.asctime(time.gmtime(time.time()))
        else:
            print("Overwriting metadata...")
            metadata[file] = time.asctime(time.gmtime(time.time()))
        with open(metadata_path, "w") as outfile:
            json.dump(metadata, outfile)

    # save yaml file
    yaml_save_loc = os.path.join(project_folder, "train_prep.yaml")
    with open(yaml_save_loc, "w") as outfile:
        yaml.dump(config, outfile)
if __name__ == "__main__":
    main()
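
The fold construction above can be exercised on its own. Below is a minimal, self-contained sketch, with made-up file names standing in for the train_files and test_files entries of the .yaml config, showing how sklearn's KFold produces the train_folds and val_folds lists that the script records in the project metadata. The random_state argument is added here only to make the demo repeatable; the script itself does not set it.

    from sklearn.model_selection import KFold

    # hypothetical stand-ins for config["train_files"] and config["test_files"]
    train_files = [f"img_{i}" for i in range(10)]
    test_files = ["img_10", "img_11"]

    # same sanity checks the script performs before splitting
    assert set(train_files).isdisjoint(test_files)
    assert len(set(train_files)) == len(train_files)

    folds = 5
    kf = KFold(n_splits=folds, shuffle=True, random_state=0)

    train_folds = []
    val_folds = []
    for train_index, val_index in kf.split(train_files):
        # KFold yields positional indices; map them back to file names
        train_folds.append([train_files[i] for i in train_index])
        val_folds.append([train_files[i] for i in val_index])

    for fold, (train_fold, val_fold) in enumerate(zip(train_folds, val_folds)):
        # each file appears in exactly one validation fold and never in both lists at once
        assert set(train_fold).isdisjoint(val_fold)
        print("Fold", fold)
        print("  train:", train_fold)
        print("  val:  ", val_fold)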