#!/usr/bin/env python
"""UNET segmentation module
Take in items and trains UNET
"""
import yaml
import os
import argparse
import json
import time
from locpix.img_processing.models.unet import two_d_UNet
from locpix.scripts.img_seg import img_train_prep
from locpix.img_processing.training import train, loss
from locpix.img_processing.data_loading import dataset
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import numpy as np
from locpix.img_processing import watershed
from locpix.preprocessing import datastruc
import wandb
def main():
# Load in config
parser = argparse.ArgumentParser(description="Train UNET.")
parser.add_argument(
"-i",
"--project_directory",
action="store",
type=str,
help="the location of the project directory",
required=True,
)
parser.add_argument(
"-c",
"--config",
action="store",
type=str,
help="the location of the .yaml configuaration file\
for preprocessing",
required=True,
)
parser.add_argument(
"-m",
"--project_metadata",
action="store_true",
help="check the metadata for the specified project and" "seek confirmation!",
)
args = parser.parse_args()
project_folder = args.project_directory
# load config
with open(args.config, "r") as ymlfile:
config = yaml.safe_load(ymlfile)
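    # an illustrative config (values are examples only; the keys are the
    # ones this script reads):
    #   learning_rate: 1.0e-3
    #   weight_decay: 1.0e-5
    #   epochs: 50
    #   batch_size: 4
    #   num_workers: 2
    #   use_gpu: true
    #   loss_fn: bce            # "bce" or "dice"
    #   train_transforms: ...
    #   test_transforms: ...
    #   wandb_project: my-project
    #   wandb_dataset: my-dataset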
metadata_path = os.path.join(project_folder, "metadata.json")
with open(
metadata_path,
) as file:
metadata = json.load(file)
# check metadata
if args.project_metadata:
print("".join([f"{key} : {value} \n" for key, value in metadata.items()]))
check = input("Are you happy with this? (YES)")
if check != "YES":
exit()
    # record the time this script was run in the metadata
    file = os.path.basename(__file__)
    if file in metadata:
        print("Overwriting metadata...")
    metadata[file] = time.asctime(time.gmtime(time.time()))
with open(metadata_path, "w") as outfile:
json.dump(metadata, outfile)
    # number of cross-validation folds
    folds = len(metadata["train_folds"])
    # timestamp used to tag this run's files (hour_minute_second, day_month_year)
    time_o = time.gmtime(time.time())
    time_o = (
        f"time_{time_o[3]}_{time_o[4]}_{time_o[5]}"
        f"_date_{time_o[2]}_{time_o[1]}_{time_o[0]}"
    )
print("------ Training --------")
for fold in range(folds):
print(f"----- Fold {fold} -------")
# make folder
model_save_path = os.path.join(project_folder, f"unet/models/{fold}")
if os.path.exists(model_save_path):
raise ValueError(f"Cannot proceed as {model_save_path} already exists")
else:
os.makedirs(model_save_path)
# image train prep
img_train_prep.preprocess_train_files(
project_folder, config, metadata, fold, "unet"
)
# train model
train_folder = os.path.abspath(
os.path.join(project_folder, "train_files/unet/train")
)
val_folder = os.path.abspath(
os.path.join(project_folder, "train_files/unet/val")
)
lr = config["learning_rate"]
wd = config["weight_decay"]
epochs = config["epochs"]
batch_size = config["batch_size"]
num_workers = config["num_workers"]
tf = config["train_transforms"]
if config["use_gpu"] is True and torch.cuda.is_available():
device = torch.device("cuda")
pin_memory = False
elif config["use_gpu"] is False:
pin_memory = True
device = torch.device("cpu")
else:
raise ValueError("Specify cpu or gpu !")
print("Device: ", device)
        # initialise model and optimiser, moving the model to the device
        model = two_d_UNet(1, 1, bilinear=False)
        model = model.to(device)
        optimiser = Adam(model.parameters(), lr=lr, weight_decay=wd)
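        # note: weight_decay in torch.optim.Adam is an L2 penalty folded into
        # the gradients; AdamW implements decoupled weight decay instead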
# init transforms, dataset and dataloader
train_files = metadata["train_folds"][fold]
val_files = metadata["val_folds"][fold]
train_dataset = dataset.ImgDataset(
train_folder,
train_files,
tf,
train=True,
)
val_dataset = dataset.ImgDataset(
val_folder,
val_files,
tf,
train=False,
mean=train_dataset.mean,
std=train_dataset.std,
)
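        # the validation set is normalised with the training-set mean/std, so
        # no statistics leak from validation data into training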
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=pin_memory,
num_workers=num_workers,
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=pin_memory,
num_workers=num_workers,
)
# initialise loss function
if config["loss_fn"] == "bce":
print("Using BCE loss!")
loss_fn = torch.nn.BCEWithLogitsLoss()
elif config["loss_fn"] == "dice":
print("Using DICE loss")
loss_fn = loss.dice_loss()
else:
raise ValueError("Loss function must be specified in config")
# initialise wandb
# start a new wandb run to track this script
wandb.init(
# set the wandb project where this run will be logged
project=config["wandb_project"],
# track hyperparameters and run metadata
config={
"learning_rate": lr,
"dataset": config["wandb_dataset"],
"epochs": epochs,
"fold": fold,
},
)
# train loop
model = train.train_loop(
epochs,
model,
optimiser,
train_loader,
val_loader,
loss_fn,
device,
os.path.join(model_save_path, f"{time_o}_model.pt"),
)
# need to save mean and std from normalisation in model path
mean = torch.tensor(train_dataset.mean)
std = torch.tensor(train_dataset.std)
np.save(os.path.join(model_save_path, f"{time_o}_mean.npy"), mean)
np.save(os.path.join(model_save_path, f"{time_o}_std.npy"), std)
# clean up
img_train_prep.clean_up(project_folder, "unet")
# stop wandb
wandb.finish()
print("------ Outputting for evaluation -------- ")
# prep images
img_train_prep.preprocess_all_files(project_folder, config, metadata, "unet")
img_folder = os.path.join(project_folder, "train_files/unet/all")
# create dataset
train_files = metadata["train_files"]
test_files = metadata["test_files"]
all_files = train_files + test_files
tf = config["test_transforms"]
for fold in range(folds):
model_folder = os.path.join(project_folder, f"unet/models/{fold}")
mean_path = os.path.join(model_folder, f"{time_o}_mean.npy")
std_path = os.path.join(model_folder, f"{time_o}_std.npy")
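        # note: this block reuses the in-memory model from the last training
        # fold; to evaluate each fold's own weights, reload the checkpoint
        # saved at {time_o}_model.pt in model_folder here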
img_dataset = dataset.ImgDataset(
img_folder,
all_files,
tf,
train=False,
mean=np.load(mean_path),
std=np.load(std_path),
)
# output folder
cell_seg_df_folder = os.path.join(
project_folder, f"unet/{fold}/cell/seg_dataframes"
)
cell_seg_img_folder = os.path.join(project_folder, f"unet/{fold}/cell/seg_img")
cell_memb_folder = os.path.join(
project_folder, f"unet/{fold}/membrane/prob_map"
)
folders = [
cell_memb_folder,
cell_seg_df_folder,
cell_seg_img_folder,
]
for folder in folders:
if os.path.exists(folder):
raise ValueError(f"Cannot proceed as {folder} already exists")
else:
os.makedirs(folder)
        # run through model
        for index in range(len(img_dataset)):
            with torch.no_grad():
                # get the image and derive its name from the file path
                img, label = img_dataset[index]
                img_name = img_dataset.input_data[index]
                img_name = os.path.basename(img_name)
                img_name = os.path.splitext(img_name)[0]
                # add a batch dimension of size 1
                img = torch.unsqueeze(img, 0)
# make sure model in eval mode
model.eval()
                # set_to_none=True frees the gradient tensors rather than
                # zeroing them, which has a smaller memory footprint
                optimiser.zero_grad(set_to_none=True)
# move data to device
img = img.to(device)
# label = label.to(device)
                # forward pass - with autocasting on whichever device is in use
                with torch.autocast(device_type=device.type):
output = model(img)
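                    # sigmoid maps the raw logits to per-pixel probabilities
                    # in [0, 1]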
output = torch.sigmoid(output)
# reduce batch and channel dimensions
output = torch.squeeze(output, 0)
output = torch.squeeze(output, 0)
# convert tensor to numpy
output = output.cpu().numpy()
# ---- segment cells ----
# get markers
markers_loc = os.path.join(project_folder, "markers")
markers_loc = os.path.join(markers_loc, img_name + ".npy")
try:
markers = np.load(markers_loc)
                except FileNotFoundError:
                    raise ValueError(
                        f"Couldn't open the markers file for {img_name}; "
                        "no markers were found at the expected location"
                    )
                # tested on a small sample and the line below works better
                # than running watershed on grey_log_img
instance_mask = watershed.watershed_segment(
output, coords=markers
) # watershed on the grey image
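                # each marker seeds one watershed basin, so instance_mask
                # labels each cell with its own integer id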
# ---- save ----
# save membrane mask
save_loc = os.path.join(cell_memb_folder, img_name + ".npy")
np.save(save_loc, output)
# save instance mask to dataframe
item_path = os.path.join(
project_folder, "annotate/annotated", img_name + ".parquet"
)
item = datastruc.item(None, None, None, None, None)
item.load_from_parquet(item_path)
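                # map the per-pixel instance labels back onto the raw
                # localisations in the dataframe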
df = item.mask_pixel_2_coord(instance_mask)
item.df = df
item.save_to_parquet(
cell_seg_df_folder, drop_zero_label=False, drop_pixel_col=True
)
# save cell segmentation image (as .npy) - consider only one channel
save_loc = os.path.join(cell_seg_img_folder, item.name + ".npy")
np.save(save_loc, instance_mask)
# clean up
img_train_prep.clean_up_all(project_folder, "unet")
# save train yaml file
yaml_save_loc = os.path.join(project_folder, "unet.yaml")
with open(yaml_save_loc, "w") as outfile:
yaml.dump(config, outfile)
if __name__ == "__main__":
main()