#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings
import argparse
import importlib
import os
from os.path import join
import sys
from importlib.metadata import version
import nibabel as nib

from tractseg.libs.system_config import get_config_name
from tractseg.libs import exp_utils
from tractseg.libs import img_utils
from tractseg.libs import preprocessing
from tractseg.libs import plot_utils
from tractseg.libs import peak_utils
from tractseg.python_api import run_tractseg
from tractseg.libs.utils import bcolors
from tractseg.libs.system_config import SystemConfig as C
from tractseg.data import dataset_specific_utils

# Silence noisy warnings coming from third-party dependencies.
warnings.filterwarnings("ignore", category=UserWarning)  # scipy warnings
warnings.filterwarnings("ignore", category=FutureWarning)  # h5py warnings
# Benign notices emitted by Cython-compiled extensions.
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")


def _build_argument_parser():
    """Create the argparse parser describing all command line options of this script."""
    parser = argparse.ArgumentParser(description="Segment white matter bundles in a Diffusion MRI image.",
                                     epilog="Written by Jakob Wasserthal. Please reference 'Wasserthal et al. "
                                            "TractSeg - Fast and accurate white matter tract segmentation'. "
                                            "https://doi.org/10.1016/j.neuroimage.2018.07.070'")

    parser.add_argument("-i", metavar="filepath", dest="input",
                        help="CSD peaks in MRtrix format (4D Nifti image with dimensions [x,y,z,9])", required=True)

    parser.add_argument("-o", metavar="directory", dest="output",
                        help="Output directory (default: directory of input file)")

    parser.add_argument("--single_output_file", action="store_true",
                        help="Output all bundles in one file (4D image)",
                        default=False)

    parser.add_argument("--csd_type", metavar="csd|csd_msmt|csd_msmt_5tt", choices=["csd", "csd_msmt", "csd_msmt_5tt"],
                        help="Which MRtrix constrained spherical deconvolution (CSD) is used for peak generation.\n"
                             "'csd' [DEFAULT]: Standard CSD. Very fast.\n"
                             "'csd_msmt': Multi-shell multi-tissue CSD DHollander algorithm. Medium fast. Needs "
                             "more than one b-value shell.\n"
                             "'csd_msmt_5tt': Multi-shell multi-tissue CSD 5TT. Slow on large images. Needs more "
                             "than one b-value shell. "
                             "Needs a T1 image with all non-brain area removed (a file "
                             "'T1w_acpc_dc_restore_brain.nii.gz' must be in the input directory).",
                        default="csd")

    parser.add_argument("--output_type", metavar="tract_segmentation|endings_segmentation|TOM|dm_regression",
                        choices=["tract_segmentation", "endings_segmentation", "TOM", "dm_regression"],
                        help="TractSeg can segment not only bundles, but also the end regions of bundles. "
                             "Moreover it can create Tract Orientation Maps (TOM).\n"
                             "'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).\n"
                             "'endings_segmentation': Segmentation of bundle end regions (72 bundles).\n"
                             "'TOM': Tract Orientation Maps (20 bundles).",
                        default="tract_segmentation")

    parser.add_argument("--bvals", metavar="filename",
                        help="bvals file. Default is '<name_of_input_file>.bvals' in same directory as input")

    parser.add_argument("--bvecs", metavar="filename",
                        help="bvecs file. Default is '<name_of_input_file>.bvecs' in same directory as input")

    parser.add_argument("--brain_mask", metavar="filename",
                        help="Manually define brain mask file. If not specified will look for file "
                             "nodif_brain_mask.nii.gz in same folder as input and if not found create one using "
                             "fsl bet. Brain mask only needed if using '--raw_diffusion_input'.")

    parser.add_argument("--raw_diffusion_input", action="store_true",
                        help="Provide a Diffusion nifti image as argument to -i. "
                             "Will calculate CSD and extract the mean peaks needed as input for TractSeg.",
                        default=False)

    parser.add_argument("--keep_intermediate_files", action="store_true",
                        help="Do not remove intermediate files like CSD output and peaks",
                        default=False)

    parser.add_argument("--preview", action="store_true", help="Save preview of some tracts as png. Requires VTK.",
                        default=False)

    parser.add_argument("--flip", action="store_true",
                        help="Flip output peaks of TOM along z axis to make compatible with MITK.",
                        default=False)

    parser.add_argument("--single_orientation", action="store_true",
                        help="Do not run model 3x along x/y/z orientation with subsequent mean fusion.",
                        default=False)

    parser.add_argument("--get_probabilities", action="store_true",
                        help="Output probability map instead of binary segmentation (without any postprocessing)",
                        default=False)

    parser.add_argument("--super_resolution", action="store_true",
                        help="Keep 1.25mm resolution of model instead of downsampling back to original resolution",
                        default=False)

    parser.add_argument("--uncertainty", action="store_true",
                        help="Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)",
                        default=False)

    parser.add_argument("--no_postprocess", action="store_true",
                        help="Deactivate simple postprocessing of segmentations (removal of small blobs)",
                        default=False)

    parser.add_argument("--preprocess", action="store_true",
                        help="Move input image to MNI space (rigid registration of FA). "
                             "(Does not work together with csd_type=csd_msmt_5tt)",
                        default=False)

    parser.add_argument("--nr_cpus", metavar="n", type=int,
                        help="Number of CPUs to use. -1 means all available CPUs (default: -1)",
                        default=-1)

    parser.add_argument('--tract_segmentation_output_dir', metavar="folder_name",
                        help="name of bundle segmentations output folder (default: bundle_segmentations)",
                        default="bundle_segmentations")

    parser.add_argument('--TOM_output_dir', metavar="folder_name",
                        help="name of TOM output folder (default: TOM)",
                        default="TOM")

    parser.add_argument('--exp_name', metavar="folder_name", help="name of experiment - ONLY FOR TESTING",
                        default=None)

    parser.add_argument('--tract_definition', metavar="TractQuerier+|xtract", choices=["TractQuerier+", "xtract"],
                        help="Select which tract definitions to use. 'TractQuerier+' defines tracts mainly by their "
                             "cortical start and end region. 'xtract' defines tracts mainly by ROIs in white matter. "
                             "Both have their advantages and disadvantages. 'TractQuerier+' refers to the dataset "
                             "described in the TractSeg NeuroImage paper. "
                             "NOTE 1: 'xtract' only works for output type 'tract_segmentation' and "
                             "'dm_regression'.",
                        default="TractQuerier+")

    parser.add_argument("--rescale_dm", action="store_true",
                        help="Rescale density map to [0,100] range. Original values can be very small and therefore "
                             "inconvenient to work with.",
                        default=False)

    parser.add_argument("--tract_segmentations_path", metavar="path",
                        help="Path to tract segmentations. Only needed for TOM. If empty will look for default "
                             "TractSeg output.",
                        default=None)

    parser.add_argument("--test", action="store_true",
                        help="Only needed for unittesting.",
                        default=False)

    parser.add_argument("--verbose", action="store_true", help="Show more intermediate output",
                        default=False)

    parser.add_argument('--version', action='version', version=version("TractSeg"))

    return parser


def main():
    """Command line entry point: run TractSeg on one input image and write the results to disk.

    Steps, in order:
      1. Parse the command line and load the model configuration (a pretrained configuration, or the
         hyperparameters of a manually named experiment when `--exp_name` is given).
      2. With `--raw_diffusion_input`: obtain a brain mask, optionally move the data to MNI space and
         compute the peaks. Otherwise load the given image and validate its dimensions.
      3. Run `run_tractseg` once per part (four parts for peak regression, otherwise a single pass).
      4. Save the output as one 4D file or as multiple files, optionally move it back to subject space
         and finally clean up intermediate files.

    Raises:
        ValueError: if `--preprocess` is combined with `--output_type TOM`, or if `--single_output_file`
            and `--preprocess` are combined for peak regression output.
        FileNotFoundError: if `--preprocess` is used for an output type other than tract_segmentation or
            dm_regression and the registration matrices of a previous run are missing.

    Exits with status 1 if the input image does not have the dimensions the model expects.
    """
    args = _build_argument_parser().parse_args()


    ####################################### Set more parameters #######################################

    input_type = "peaks"  # peaks|T1
    threshold = 0.5          # specificity (for tract_segmentation and endings_segmentation)
    peak_threshold = 0.3     # specificity (for TOM)
    blob_size_thr = 25  # default: 50
    manual_exp_name = args.exp_name
    # inference_batch_size:
    #   if using 48 -> 30% faster runtime on CPU but needs 30GB RAM instead of 4.5GB
    #   if using 5 -> 12% faster runtime on CPU
    inference_batch_size = 1
    TOM_dilation = 1  # 1 also ok for HCP because in tracking again filtered by mask
    postprocess = not args.no_postprocess
    bundle_specific_postprocessing = True
    dropout_sampling = args.uncertainty
    input_path = args.input
    # TOM output is always computed with a single orientation, regardless of the flag.
    single_orientation = args.single_orientation or args.output_type == "TOM"


    ####################################### Setup configuration #######################################

    bedpostX_input = os.path.basename(input_path) == "dyads1.nii.gz"
    if bedpostX_input:
        print("BedpostX dyads detected. Will automatically combine dyads1+2[+3].")

    if manual_exp_name is None:
        config_file = get_config_name(input_type, args.output_type, dropout_sampling=dropout_sampling,
                                      tract_definition=args.tract_definition)
        Config = getattr(importlib.import_module("tractseg.experiments.pretrained_models." +
                                                 config_file), "Config")()
    else:
        Config = exp_utils.load_config_from_txt(join(C.EXP_PATH,
                                                     exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"),
                                                     "Hyperparameters.txt"))

    Config = exp_utils.get_correct_labels_type(Config)
    Config.CSD_TYPE = args.csd_type
    Config.KEEP_INTERMEDIATE_FILES = args.keep_intermediate_files
    Config.VERBOSE = args.verbose
    Config.SINGLE_OUTPUT_FILE = args.single_output_file
    Config.FLIP_OUTPUT_PEAKS = args.flip
    Config.PREDICT_IMG = input_path is not None
    if args.output:
        Config.PREDICT_IMG_OUTPUT = args.output
    elif Config.PREDICT_IMG:
        Config.PREDICT_IMG_OUTPUT = join(os.path.dirname(input_path), Config.TRACTSEG_DIR)
    tensor_model = Config.NR_OF_GRADIENTS == 18 * Config.NR_SLICES

    bvals, bvecs = exp_utils.get_bvals_bvecs_path(args)
    exp_utils.make_dir(Config.PREDICT_IMG_OUTPUT)

    if args.tract_segmentations_path is not None:
        tract_segmentations_path = args.tract_segmentations_path
    elif Config.EXPERIMENT_TYPE != "tract_segmentation" and args.preprocess:
        tract_segmentations_path = join(Config.PREDICT_IMG_OUTPUT, "bundle_segmentations_MNI")
    else:
        tract_segmentations_path = join(Config.PREDICT_IMG_OUTPUT, "bundle_segmentations")


    ####################################### Preprocessing #######################################

    if args.preprocess and args.output_type == "TOM":
        raise ValueError("The preprocess option does not work with output_type TOM.")

    if args.raw_diffusion_input:
        brain_mask = exp_utils.get_brain_mask_path(Config.PREDICT_IMG_OUTPUT, args.brain_mask, args.input)

        if brain_mask is None:
            brain_mask = preprocessing.create_brain_mask(input_path, Config.PREDICT_IMG_OUTPUT)

        if args.preprocess:
            if Config.EXPERIMENT_TYPE in ("tract_segmentation", "dm_regression"):
                input_path, bvals, bvecs, brain_mask = preprocessing.move_to_MNI_space(input_path, bvals, bvecs,
                                                                                       brain_mask,
                                                                                       Config.PREDICT_IMG_OUTPUT)
            else:
                # Other output types reuse the registration matrices written by an earlier
                # tract_segmentation run, so both must already be in the output directory.
                for mat_name in ("FA_2_MNI.mat", "MNI_2_FA.mat"):
                    mat_path = join(Config.PREDICT_IMG_OUTPUT, mat_name)
                    if not os.path.exists(mat_path):
                        raise FileNotFoundError("Could not find file " + mat_path +
                                                ". Run with options `--output_type tract_segmentation "
                                                "--preprocess` first.")

        preprocessing.create_fods(input_path, Config.PREDICT_IMG_OUTPUT, bvals, bvecs,
                                  brain_mask, Config.CSD_TYPE, nr_cpus=args.nr_cpus)

        # The peaks are loaded from the output directory after create_fods has run.
        peak_path = join(Config.PREDICT_IMG_OUTPUT, "peaks.nii.gz")
        data_img = nib.load(peak_path)
    else:
        peak_path = input_path
        if bedpostX_input:
            data_img = peak_utils.load_bedpostX_dyads(peak_path, scale=True, tensor_model=tensor_model)
        else:
            data_img = nib.load(peak_path)
        data_img_shape = data_img.get_fdata().shape
        if Config.NR_OF_GRADIENTS != 1 and not (len(data_img_shape) == 4 and
                                                data_img_shape[3] == Config.NR_OF_GRADIENTS):
            print(bcolors.ERROR + "ERROR" + bcolors.ENDC + bcolors.BOLD +
                  ": Input image must be a peak image (nifti 4D image with dimensions [x,y,z,9]). " +
                  "If you input a Diffusion image add the option '--raw_diffusion_input'." + bcolors.ENDC)
            sys.exit(1)  # non-zero status so calling scripts can detect the failure
        if Config.NR_OF_GRADIENTS == 1 and not len(data_img_shape) == 3:
            print(bcolors.ERROR + "ERROR" + bcolors.ENDC + bcolors.BOLD +
                  ": Input image must be a 3D image (nifti 3D image with dimensions [x,y,z]). " + bcolors.ENDC)
            sys.exit(1)  # non-zero status so calling scripts can detect the failure

    if tensor_model:
        data_img = peak_utils.peaks_to_tensors_nifti(data_img)

    if input_type == "T1" or Config.NR_OF_GRADIENTS == 1:
        data_img = nib.Nifti1Image(data_img.get_fdata()[..., None], data_img.affine)  # add fourth dimension

    if args.super_resolution:
        data_img = img_utils.change_spacing_4D(data_img, new_spacing=1.25)

    data_affine = data_img.affine
    data = data_img.get_fdata()
    del data_img     # free memory

    # Make image have the same signs of the affine as MNI space
    data, flip_axis = img_utils.flip_axis_to_match_MNI_space(data, data_affine)

    if len(flip_axis) > 0:
        print("Reorienting data...")

    #Use Peaks + T1
    # # t1_data = nib.load("T1w_acpc_dc_restore_brain_DWIsize.nii.gz").get_fdata()[:,:,:,None]
    # t1_data = nib.load("T1w_acpc_dc_restore_brain.nii.gz").get_fdata()[:,:,:,None]
    # # needed if upsampling of peaks resulted in one pixel less (sometimes)
    # # t1_data = nib.load("T1w_acpc_dc_restore_brain.nii.gz").get_fdata()[1:,1:-1,1:,None]
    # data = np.concatenate((data, t1_data), axis=3)

    output_float = (Config.EXPERIMENT_TYPE in ("dm_regression", "peak_regression") or
                    bool(dropout_sampling) or bool(args.get_probabilities))

    ####################################### Process #######################################

    if Config.EXPERIMENT_TYPE == "peak_regression":
        parts = ["Part1", "Part2", "Part3", "Part4"]
        if manual_exp_name is not None and "PeaksPart1" in manual_exp_name:
            print("INFO: Only using Part1")
            parts = ["Part1"]
    else:
        parts = [Config.CLASSES]

    # File name (without extension) used for each experiment type when writing a single 4D file.
    single_file_names = {
        "tract_segmentation": "bundle_segmentations",
        "endings_segmentation": "bundle_endings",
        "peak_regression": "bundle_TOMs",
        "dm_regression": "bundle_density_maps",
    }

    for part in parts:
        if part.startswith("Part"):
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

        seg = run_tractseg(data, args.output_type,
                           single_orientation=single_orientation,
                           dropout_sampling=dropout_sampling, threshold=threshold,
                           bundle_specific_postprocessing=bundle_specific_postprocessing,
                           get_probs=args.get_probabilities, peak_threshold=peak_threshold,
                           postprocess=postprocess, peak_regression_part=part,
                           input_type=input_type, blob_size_thr=blob_size_thr, nr_cpus=args.nr_cpus,
                           verbose=args.verbose, manual_exp_name=manual_exp_name,
                           inference_batch_size=inference_batch_size,
                           tract_definition=args.tract_definition, bedpostX_input=bedpostX_input,
                           tract_segmentations_path=tract_segmentations_path,
                           TOM_dilation=TOM_dilation,
                           unit_test=args.test)

        # Undo image flipping if it was applied previously
        for axis in flip_axis:
            seg = img_utils.flip_axis(seg, axis)

        ####################################### Save output #######################################

        if args.preview and Config.CLASSES not in ["All_Part2", "All_Part3", "All_Part4"]:
            print("Saving preview...")
            plot_utils.plot_tracts_matplotlib(Config.CLASSES, seg, data, Config.PREDICT_IMG_OUTPUT,
                                              threshold=Config.THRESHOLD, exp_type=Config.EXPERIMENT_TYPE)

        if Config.EXPERIMENT_TYPE == "dm_regression":
            seg[seg < Config.THRESHOLD] = 0
            if args.rescale_dm:
                # Pass the target interval as a (min, max) tuple. A `range(0, 100)` object yields 1 when
                # indexed with [1] and has a maximum of 99, so it cannot describe [0, 100].
                # NOTE(review): assumes scale_to_range reads the bounds by position - confirm in img_utils.
                seg = img_utils.scale_to_range(seg, (0, 100))

        if Config.SINGLE_OUTPUT_FILE:
            img = nib.Nifti1Image(seg, data_affine)
            del seg
            if Config.EXPERIMENT_TYPE == "tract_segmentation" and dropout_sampling:
                single_file_name = "bundle_uncertainties"
            else:
                single_file_name = single_file_names.get(Config.EXPERIMENT_TYPE)
            # Unknown experiment types are not saved (and leave output_subdir untouched).
            if single_file_name is not None:
                output_subdir = single_file_name
                nib.save(img, join(Config.PREDICT_IMG_OUTPUT, output_subdir + ".nii.gz"))
            del img  # Free memory (before we run tracking)
        else:
            if Config.EXPERIMENT_TYPE == "tract_segmentation" and dropout_sampling:
                output_subdir = "bundle_uncertainties"
                img_utils.save_multilabel_img_as_multiple_files(Config.CLASSES, seg, data_affine,
                                                                Config.PREDICT_IMG_OUTPUT,
                                                                name=output_subdir)
            elif Config.EXPERIMENT_TYPE == "tract_segmentation":
                output_subdir = args.tract_segmentation_output_dir
                img_utils.save_multilabel_img_as_multiple_files(Config.CLASSES, seg, data_affine,
                                                                Config.PREDICT_IMG_OUTPUT,
                                                                name=output_subdir)
            elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                output_subdir = "endings_segmentations"
                img_utils.save_multilabel_img_as_multiple_files_endings(Config.CLASSES, seg, data_affine,
                                                                        Config.PREDICT_IMG_OUTPUT,
                                                                        name=output_subdir)
            elif Config.EXPERIMENT_TYPE == "peak_regression":
                output_subdir = args.TOM_output_dir
                img_utils.save_multilabel_img_as_multiple_files_peaks(Config.FLIP_OUTPUT_PEAKS, Config.CLASSES, seg,
                                                                      data_affine, Config.PREDICT_IMG_OUTPUT,
                                                                      name=output_subdir)
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                output_subdir = "dm_regression"
                img_utils.save_multilabel_img_as_multiple_files(Config.CLASSES, seg, data_affine,
                                                                Config.PREDICT_IMG_OUTPUT, name=output_subdir)
            del seg  # Free memory (before we run tracking)

    # The per-part class sets were only needed inside the loop; restore the full set of classes.
    if Config.EXPERIMENT_TYPE == "peak_regression":
        Config.CLASSES = "All"

    if args.preprocess:
        if Config.SINGLE_OUTPUT_FILE:
            if Config.EXPERIMENT_TYPE == "peak_regression":
                raise ValueError("single_output_file not supported for TOMs")
            preprocessing.move_to_subject_space_single_file(Config.PREDICT_IMG_OUTPUT, Config.EXPERIMENT_TYPE,
                                                            output_subdir, output_float=output_float)
        else:
            bundles = dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:]
            preprocessing.move_to_subject_space(Config.PREDICT_IMG_OUTPUT, bundles, Config.EXPERIMENT_TYPE,
                                                output_subdir, output_float=output_float)

    preprocessing.clean_up(Config.KEEP_INTERMEDIATE_FILES, Config.PREDICT_IMG_OUTPUT, Config.CSD_TYPE,
                           preprocessing_done=args.preprocess)


# Run the command line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
