#!python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings
import argparse
import importlib
import os
from os.path import join
from tqdm import tqdm

from tractseg.libs.system_config import get_config_name
from tractseg.libs import exp_utils
from tractseg.libs import tracking
from tractseg.data import dataset_specific_utils

# Silence noisy third-party warnings for the rest of the run:
# UserWarning is ignored to hide scipy warnings, FutureWarning to hide h5py warnings.
for _category in (UserWarning, FutureWarning):
    warnings.simplefilter("ignore", _category)
# Benign Cython warnings are matched by the start of their message text.
for _message in ("numpy.dtype size changed", "numpy.ufunc size changed"):
    warnings.filterwarnings("ignore", message=_message)
del _category, _message  # do not leave the loop variables in the module namespace


def parse_bundles_string(bundles_string, classes):
    """Convert the value of the ``--bundles`` option into a list of bundle names.

    Args:
        bundles_string: Either "all" or a comma separated list of bundle names.
            Whitespace around the individual names and empty entries (e.g. from
            a trailing comma) are ignored. An empty selection is treated like
            "all", matching the help text of ``--bundles``.
        classes: Passed on to ``dataset_specific_utils.get_bundle_names`` to
            obtain the list of valid bundle names.

    Returns:
        List of bundle names to track. For "all" (or an empty selection) this is
        the complete list of valid bundles, otherwise the requested names in the
        order they were given.

    Raises:
        ValueError: If a requested name is not one of the valid bundle names.
    """
    # The first entry returned by get_bundle_names() is skipped
    # (presumably a background class -- TODO confirm).
    all_bundles = dataset_specific_utils.get_bundle_names(classes)[1:]

    # Strip every name on its own so that "A, B" and "A,B," are accepted as well.
    requested = [name.strip() for name in bundles_string.split(",")]
    requested = [name for name in requested if name]

    if not requested or requested == ["all"]:
        return all_bundles

    for bundle in requested:
        if bundle not in all_bundles:
            raise ValueError("Invalid bundle name: {}".format(bundle))
    return requested

def _build_arg_parser():
    """Create the argparse parser describing all command line options of this script."""
    parser = argparse.ArgumentParser(description="Tracking on the TOMs created by TractSeg.",
                                     epilog="Written by Jakob Wasserthal. Please reference 'Wasserthal et al. "
                                            "TractSeg - Fast and accurate white matter tract segmentation. "
                                            "https://doi.org/10.1016/j.neuroimage.2018.07.070'")

    parser.add_argument("-i", metavar="filepath", dest="input",
                        help="CSD peaks in MRtrix format (4D Nifti image with dimensions [x,y,z,9])", required=True)

    parser.add_argument("-o", metavar="directory", dest="output",
                        help="Output directory (default: tractseg_output)")

    parser.add_argument("--tracking_dir", metavar="folder_name",
                        help="Set name of folder which will be created to save the tracking output.",
                        default="auto")

    parser.add_argument("--algorithm", metavar="prob|det",
                        choices=["prob", "det"],
                        help="Choose tracking algorithm. 'prob' uses a custom probabilistic algorithm which is more "
                             "sensitive. 'det' uses Mrtrix deterministic tracking (FACT) which is more specific. "
                             "(default: prob)",
                        default="prob")

    parser.add_argument("--track_FODs", metavar="False|FACT|SD_STREAM|iFOD2",
                        choices=["False", "FACT", "SD_STREAM", "iFOD2"],
                        help="Running tracking on FODs (provided as argument to '-i') instead of tracking on TOMs. "
                             "Only works if you do not use '--no_filtering_by_endpoints'. Uses MRtrix tracking. If you "
                             "choose 'FACT' you have to pass a peak image instead of FODs to -i.",
                        default="False")

    parser.add_argument("--track_best_orig", action="store_true",
                        help="Select peak from the original peaks (provided as argument to -i) which is closest to "
                             "the TOM peak. Track on these peaks.",
                        default=False)

    parser.add_argument("--tracking_dilation", metavar="n", type=int,
                        help="Dilate the endpoint and bundle masks by the "
                             "respective number of voxels. (default: 0)",
                        default=0)  # Info: dilation of endpoint mask: dilation+1

    parser.add_argument("--nr_fibers", metavar="n", type=int, help="Number of fibers to create (default: 2000)",
                        default=2000)

    parser.add_argument("--tracking_format", metavar="tck|trk|trk_legacy", choices=["tck", "trk", "trk_legacy"],
                        help="Set output format of tracking. trk_legacy is not supported by newer nibabel "
                             "anymore. (default: tck)",
                        default="tck")

    parser.add_argument("--no_filtering_by_endpoints", action="store_true",
                        help="Run tracking on TOMs without filtering results by tract mask and endpoint masks. "
                             "MRtrix FACT tracking will be used instead of the 'probabilistic' tracking on peaks.",
                        default=False)

    parser.add_argument("--bundles", metavar="A,B,C", dest="bundles_string",
                        help="Comma separated list (without spaces) of bundles you want to track. " +
                             "Leave empty if you want to track all.",
                        default="all")

    parser.add_argument("--nr_cpus", metavar="n", type=int,
                        help="Number of CPUs to use. -1 means all available CPUs (default: -1)",
                        default=-1)

    parser.add_argument("--test", metavar="0|1|2|3", choices=[0, 1, 2, 3], type=int,
                        help="Only needed for unittesting.",
                        default=0)

    # NOTE: --verbose is accepted for command line compatibility but is not read anywhere in this script.
    parser.add_argument("--verbose", action="store_true", help="Show more intermediate output",
                        default=False)

    return parser


def _select_tracking_setup(tracking_type, track_fods, prob_tracking_software):
    """Derive (tracking_on_FODs, tracking_software, tracking_algorithm) from the CLI choices.

    Args:
        tracking_type: Value of ``--algorithm`` ("prob" or "det").
        track_fods: Value of ``--track_FODs`` ("False" or the name of a MRtrix algorithm).
        prob_tracking_software: Software to use for probabilistic tracking on TOMs ("tractseg" or "mrtrix").
    """
    # --track_FODs is a string option, so "disabled" is the literal string "False".
    tracking_on_fods = track_fods != "False"

    # Only probabilistic tracking on TOMs uses the configurable software; everything else runs with MRtrix.
    if not tracking_on_fods and tracking_type == "prob":
        tracking_software = prob_tracking_software
    else:
        tracking_software = "mrtrix"

    if tracking_on_fods:
        tracking_algorithm = track_fods
    elif tracking_type == "det":
        tracking_algorithm = "FACT"
    elif tracking_software == "mrtrix":  # "prob" with MRtrix
        tracking_algorithm = "iFOD2"
    else:  # "prob" with the TractSeg tracking
        tracking_algorithm = "fixed_prob"

    return tracking_on_fods, tracking_software, tracking_algorithm


def _select_bundles(test_mode, bundles_string, classes):
    """Return the bundles to track: a fixed test set for ``--test`` 1-3, otherwise the ``--bundles`` selection."""
    test_classes = {1: "test", 2: "toy", 3: "test_single"}
    if test_mode in test_classes:
        return dataset_specific_utils.get_bundle_names(test_classes[test_mode])[1:]
    return parse_bundles_string(bundles_string, classes)


def main():
    """Command line entry point: run tracking for every selected bundle.

    Parses the command line, loads the pretrained model ``Config``, derives the
    output directory as well as the tracking software and algorithm from the
    options, and then calls ``tracking.track`` once per bundle while showing a
    tqdm progress bar.

    Raises:
        ValueError: If ``--bundles`` contains an unknown bundle name (raised by
            ``parse_bundles_string``).
    """
    args = _build_arg_parser().parse_args()

    # Private parameters
    prob_tracking_software = "tractseg"  # tractseg|mrtrix  (if mrtrix then use iFOD2 on TOMs)
    use_as_prior = False  # Track on weighted average between best original peaks and TOM
    # use default naming scheme plus this postfix when looking for TractSeg output subfolders during tracking
    dir_postfix = ""    # e.g. "_BSThr"
    tract_definition = "TractQuerier+"  # TractQuerier+|xtract
    next_step_displacement_std = 0.15

    input_path = args.input

    # Load the Config class of the pretrained tract segmentation model and instantiate it.
    config_file = get_config_name("peaks", "tract_segmentation", dropout_sampling=False,
                                  tract_definition=tract_definition)
    Config = getattr(importlib.import_module("tractseg.experiments.pretrained_models." + config_file), "Config")()
    Config = exp_utils.get_correct_labels_type(Config)

    # "-i" is a required option, so input_path is always set when the parser is used.
    Config.PREDICT_IMG = input_path is not None
    if args.output:
        Config.PREDICT_IMG_OUTPUT = args.output
    elif Config.PREDICT_IMG:
        # Without "-o" the output goes into Config.TRACTSEG_DIR next to the input image.
        Config.PREDICT_IMG_OUTPUT = join(os.path.dirname(input_path), Config.TRACTSEG_DIR)

    # NOTE(review): the help text of --track_FODs says it only works without
    # --no_filtering_by_endpoints, but that combination is not rejected here -- confirm
    # whether tracking.track handles it.
    tracking_on_FODs, tracking_software, tracking_algorithm = _select_tracking_setup(
        args.algorithm, args.track_FODs, prob_tracking_software)

    bundles = _select_bundles(args.test, args.bundles_string, Config.CLASSES)

    for bundle in tqdm(bundles):
        tracking.track(bundle, input_path, Config.PREDICT_IMG_OUTPUT,
                       tracking_on_FODs, tracking_software, tracking_algorithm,
                       # Use best original peak instead of TOM peak for tracking (ProbDet tracking)
                       use_best_original_peaks=args.track_best_orig, use_as_prior=use_as_prior,
                       filter_by_endpoints=not args.no_filtering_by_endpoints,
                       tracking_folder=args.tracking_dir, dir_postfix=dir_postfix, dilation=args.tracking_dilation,
                       next_step_displacement_std=next_step_displacement_std,
                       output_format=args.tracking_format, nr_fibers=args.nr_fibers, nr_cpus=args.nr_cpus)


# Start the command line interface only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
