#!/usr/bin/env python

"""
This module is for training the model. See Readme.md for more details about training your own model.

Examples:
    Run local:
    $ ExpRunner --config=XXX

    Predicting with new config setup:
    $ ExpRunner --train=False --test=True --lw --config=XXX
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings
import os
import importlib
import argparse
import pickle as pkl
from pprint import pprint
import distutils.util
from os.path import join

import nibabel as nib
import numpy as np
import torch

from tractseg.libs import direction_merger
from tractseg.libs import exp_utils
from tractseg.libs import img_utils
from tractseg.libs import peak_utils
from tractseg.libs.system_config import SystemConfig as C
from tractseg.libs import trainer
from tractseg.data.data_loader_training import DataLoaderTraining as DataLoaderTraining2D
from tractseg.data.data_loader_training_3D import DataLoaderTraining as DataLoaderTraining3D
from tractseg.data.data_loader_inference import DataLoaderInference
from tractseg.data import dataset_specific_utils
from tractseg.models.base_model import BaseModel

# Silence known, benign warnings raised by third-party dependencies.
warnings.filterwarnings("ignore", category=UserWarning)  # scipy warnings
warnings.filterwarnings("ignore", category=FutureWarning)  # h5py warnings
warnings.filterwarnings("ignore", "numpy.dtype size changed")  # benign Cython warning
warnings.filterwarnings("ignore", "numpy.ufunc size changed")  # benign Cython warning


def main():
    """
    Command line entry point: build the experiment Config, then train and/or evaluate a model.

    Steps, in this order:
        1. Parse the command line arguments.
        2. Instantiate the base Config (or the custom one selected with --config) and apply the
           command line overrides.
        3. Derive experiment paths, cross-validation subjects, weights path, input dimension,
           labels filename and number of classes.
        4. When training (and not --only_val) create the experiment folder.
        5. Write the hyperparameters to "Hyperparameters.txt" and "Hyperparameters.pkl" in the
           experiment folder.
        6. Build the model and the training data loader; train if Config.TRAIN is set.
        7. Optionally write binary segmentations (--seg), run the test on the validation
           subjects (--test) and write probability maps (--probs).
        8. If the script ~/dev/bsp/eval_nonHCP.py exists and neither ONLY_VAL nor TEST is set,
           run it with the experiment name as its only argument.

    Returns:
        None
    """
    # Local imports: only needed for launching the optional external evaluation script.
    import subprocess
    import sys

    parser = argparse.ArgumentParser(
        description="Train a network on your own data to segment white matter bundles.",
        epilog="Written by Jakob Wasserthal. Please reference 'Wasserthal et al. "
               "TractSeg - Fast and accurate white matter tract segmentation. "
               "https://doi.org/10.1016/j.neuroimage.2018.07.070)'")
    parser.add_argument("--config", metavar="name", help="Name of configuration to use")
    parser.add_argument("--train", metavar="True/False", help="Train network",
                        type=distutils.util.strtobool, default=True)
    parser.add_argument("--test", metavar="True/False", help="Test network",
                        type=distutils.util.strtobool, default=False)
    parser.add_argument("--seg", action="store_true", help="Create binary segmentation")
    parser.add_argument("--probs", action="store_true", help="Create probmap segmentation")
    parser.add_argument("--lw", action="store_true", help="Load weights of pretrained net")
    parser.add_argument("--only_val", action="store_true", help="only run validation")
    parser.add_argument("--en", metavar="name", help="Experiment name")
    parser.add_argument("--fold", metavar="N", help="Which fold to train when doing CrossValidation", type=int)
    # NOTE(review): store_true combined with default=True means args.verbose is always True.
    # Kept unchanged for backward compatibility.
    parser.add_argument("--verbose", action="store_true", help="Show more intermediate output", default=True)
    args = parser.parse_args()

    # Start from the base configuration; a custom configuration replaces it completely.
    Config = getattr(importlib.import_module("tractseg.experiments.base"), "Config")()
    if args.config:
        # Config.__dict__ does not work properly therefore use this approach
        Config = getattr(importlib.import_module("tractseg.experiments.custom." + args.config), "Config")()

    if args.en:
        Config.EXP_NAME = args.en

    # strtobool returns 1/0 -> convert to real booleans
    Config.TRAIN = bool(args.train)
    Config.TEST = bool(args.test)
    if args.seg:
        Config.SEGMENT = True
    if args.probs:
        Config.GET_PROBS = True
    if args.lw:
        Config.LOAD_WEIGHTS = args.lw
    # Compare against None so that "--fold 0" is honoured (0 is falsy).
    if args.fold is not None:
        Config.CV_FOLD = args.fold
    if args.only_val:
        Config.ONLY_VAL = True
    Config.VERBOSE = args.verbose

    Config.MULTI_PARENT_PATH = join(C.EXP_PATH, Config.EXP_MULTI_NAME)
    Config.EXP_PATH = join(C.EXP_PATH, Config.EXP_MULTI_NAME, Config.EXP_NAME)
    Config.TRAIN_SUBJECTS, Config.VALIDATE_SUBJECTS, Config.TEST_SUBJECTS = \
        dataset_specific_utils.get_cv_fold(Config.CV_FOLD, dataset=Config.DATASET)

    if Config.WEIGHTS_PATH == "":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(Config.EXP_PATH, Config.LOAD_WEIGHTS)

    # Autoset input dimensions based on settings
    Config.INPUT_DIM = dataset_specific_utils.get_correct_input_dim(Config)
    Config = dataset_specific_utils.get_labels_filename(Config)

    # The first bundle name is skipped when counting classes; peak regression uses 3 values per bundle.
    nr_of_bundles = len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
    if Config.EXPERIMENT_TYPE == "peak_regression":
        Config.NR_OF_CLASSES = 3 * nr_of_bundles
    else:
        Config.NR_OF_CLASSES = nr_of_bundles

    if Config.TRAIN and not Config.ONLY_VAL:
        Config.EXP_PATH = exp_utils.create_experiment_folder(Config.EXP_NAME, Config.MULTI_PARENT_PATH, Config.TRAIN)

    if Config.DIM != "2D":
        Config.EPOCH_MULTIPLIER = 3

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    # Human readable dump of all non-callable, non-dunder attributes of the Config
    with open(join(Config.EXP_PATH, "Hyperparameters.txt"), "w") as f:
        hyperparams = {attr: getattr(Config, attr) for attr in dir(Config)
                       if not callable(getattr(Config, attr)) and not attr.startswith("__")}
        pprint(hyperparams, f)

    Config = exp_utils.get_correct_labels_type(Config)  # do after saving Hyperparameters as txt

    # Use a context manager so the pickle file is flushed and closed deterministically.
    with open(join(Config.EXP_PATH, "Hyperparameters.pkl"), "wb") as f:
        pkl.dump(Config, f)

    model = BaseModel(Config)
    if Config.DIM == "2D":
        data_loader = DataLoaderTraining2D(Config)
    else:
        data_loader = DataLoaderTraining3D(Config)

    if Config.TRAIN:
        print("Training...")
        trainer.train_model(Config, model, data_loader)

    # After Training
    if Config.TRAIN and not Config.ONLY_VAL:
        # Have to load other weights, because after training it has the weights of the last epoch
        # NOTE(review): Config.BEST_EPOCH is presumably set during trainer.train_model -- confirm.
        print("Loading best epoch: {}".format(Config.BEST_EPOCH))
        Config.WEIGHTS_PATH = Config.EXP_PATH + "/best_weights_ep" + str(Config.BEST_EPOCH) + ".npz"
        Config.LOAD_WEIGHTS = True
        # WEIGHTS_PATH already contains EXP_PATH; joining both again would duplicate the
        # directory part whenever EXP_PATH is a relative path.
        model.load_model(Config.WEIGHTS_PATH)
    # Otherwise: Weight_path already set to best model (when reading program parameters)
    # -> will be loaded automatically
    model_test = model

    if Config.SEGMENT:
        exp_utils.make_dir(join(Config.EXP_PATH, "segmentations"))
        for subject in Config.VALIDATE_SUBJECTS:
            print("Get_segmentation subject {}".format(subject))

            if Config.EXPERIMENT_TYPE == "peak_regression":
                data_loader = DataLoaderInference(Config, subject=subject)
                img_probs, _ = trainer.predict_img(Config, model_test, data_loader,
                                                   probs=True)  # only x or y or z
                img_seg = peak_utils.peak_image_to_binary_mask(img_probs,
                                                              len_thr=0.4)  # thr: 0.4 slightly better than 0.2
            else:
                # returns probs not binary seg
                img_seg, _ = direction_merger.get_seg_single_img_3_directions(Config, model_test, subject)
                img_seg = direction_merger.mean_fusion(Config.THRESHOLD, img_seg, probs=False)

            img = nib.Nifti1Image(img_seg.astype(np.uint8),
                                  dataset_specific_utils.get_dwi_affine(Config.DATASET, Config.RESOLUTION))
            nib.save(img, join(Config.EXP_PATH, "segmentations", subject + "_segmentation.nii.gz"))

    if Config.TEST:
        trainer.test_whole_subject(Config, model_test, Config.VALIDATE_SUBJECTS, "validate")

    if Config.GET_PROBS:
        exp_utils.make_dir(join(Config.EXP_PATH, "probmaps"))
        for subject in Config.VALIDATE_SUBJECTS:
            print("Get_probs subject {}".format(subject))
            data_loader = DataLoaderInference(Config, subject=subject)

            if Config.EXPERIMENT_TYPE == "peak_regression":
                img_probs, _ = trainer.predict_img(Config, model_test, data_loader, probs=True)
                img_probs = peak_utils.remove_small_peaks(img_probs, len_thr=0.4)
            else:
                img_probs, _ = direction_merger.get_seg_single_img_3_directions(Config, model_test,
                                                                                subject=subject)
                img_probs = direction_merger.mean_fusion(Config.THRESHOLD, img_probs, probs=True)

            img = nib.Nifti1Image(img_probs, dataset_specific_utils.get_dwi_affine(Config.DATASET, Config.RESOLUTION))
            nib.save(img, join(Config.EXP_PATH, "probmaps", subject + "_peak.nii.gz"))

    eval_script_path = join(os.path.expanduser("~"), "dev/bsp/eval_nonHCP.py")
    if os.path.exists(eval_script_path) and not Config.ONLY_VAL and not Config.TEST:
        print("Evaluating on non HCP data...")
        # Free GPU memory  (2.5GB remaining otherwise)
        del model, model_test
        torch.cuda.empty_cache()
        # Argument list instead of a shell string: robust against spaces and shell
        # metacharacters in the path or the experiment name. Uses the running interpreter.
        subprocess.call([sys.executable or "python", eval_script_path, Config.EXP_NAME])

# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
