#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from os.path import join
import argparse

import nibabel as nib
import numpy as np
from tqdm import tqdm

from tractseg.libs import tractometry
from tractseg.data import dataset_specific_utils


def main():
    """Evaluate a scalar (or peak) image along fiber bundles and write a CSV.

    Parses the command line, then for every bundle in the selected bundle list:
    loads the bundle's tractogram from ``tracking_dir`` and its beginnings mask
    from ``endings_dir``, and calls ``tractometry.evaluate_along_streamlines``
    to obtain a per-point mean profile. Bundles whose tractogram is missing, or
    which contain fewer than 5 streamlines, contribute a profile of zeros.

    The first and last point of every profile are dropped before writing. The
    output CSV is ';'-delimited with one column per bundle (bundle names in the
    header line) and one row per remaining point.

    Exits via ``parser.error`` (status 2) if ``--peak_length`` is given without
    ``--TOM`` or if ``--nr_points`` is not an integer.
    """
    parser = argparse.ArgumentParser(description="Evaluate image (e.g. FA) along fiber bundles.",
                                     epilog="Written by Jakob Wasserthal. Please reference 'Wasserthal et al. "
                                            "TractSeg - Fast and accurate white matter tract segmentation. "
                                            "https://doi.org/10.1016/j.neuroimage.2018.07.070)'")
    parser.add_argument("-i", metavar="tracking_dir", dest="tracking_dir",
                        help="Folder containing the TractSeg tractograms (normally '.../tractseg_output/TOM_tracking')",
                        required=True)

    parser.add_argument("-o", metavar="csv_output", dest="csv_file_out",
                        help="CSV output file containing the results", required=True)

    parser.add_argument("-e", metavar="endings_dir", dest="endings_dir",
                        help="Folder containing the TractSeg bundle endings segmentations "
                             "(normally '.../tractseg_output/endings_segmentations'). "
                             "Needed to ensure that all fibers are starting from the same side.", required=True)

    parser.add_argument("-s", metavar="scalar_img", dest="scalar_img",
                        help="Scalar image (e.g. FA) or peak image (MRtrix peaks) if using '--peak_length'",
                        required=True)

    # type=int lets argparse reject non-integer input with a proper usage error
    # instead of an uncaught ValueError later on.
    parser.add_argument("--nr_points", metavar="n", dest="nr_points", type=int,
                        help="Number of points along streamline to evaluate (default: 100)",
                        default=100)

    parser.add_argument("--peak_length", action="store_true",
                        help="Instead of taking values of scalar image along streamlines (e.g. FA) take the length "
                             "of the peak pointing in the same direction as the respective bundle. Gives better "
                             "results in areas of crossing fibers than FA.",
                        default=False)

    parser.add_argument("--TOM", metavar="TOM_dir", dest="TOM_dir",
                        help="Folder containing Tract Orientation Maps (TOMs) (normally '.../tractseg_output/TOM'). "
                             "Needed if using the '--peak_length' option.")

    parser.add_argument("--tracking_format", metavar="tck|trk|trk_legacy", choices=["tck", "trk", "trk_legacy"],
                        help="Set output format of tracking. trk_legacy is not supported by newer nibabel "
                             "anymore. (default: tck)",
                        default="tck")

    parser.add_argument("--test", metavar="1|2|3", choices=[0, 1, 2, 3], type=int,
                        help="Only needed for unittesting.",
                        default=0)

    args = parser.parse_args()

    # Fail early with a clear message: without this check the TOM path is built
    # from None inside the bundle loop and fails with an opaque TypeError.
    if args.peak_length and args.TOM_dir is None:
        parser.error("--TOM is required when using --peak_length")

    nr_points = args.nr_points
    # Dilation >0 important because otherwise some streamlines do not start/end in beginnings region and then
    # correct reorientation/flipping of streamlines does not work anymore
    dilation = 2
    scalar_image = nib.load(args.scalar_img)

    # The first entry of every bundle list is skipped
    # (presumably the background class -- TODO confirm against dataset_specific_utils).
    if args.test == 1:
        bundles = dataset_specific_utils.get_bundle_names("test")[1:]
    elif args.test == 2:
        bundles = dataset_specific_utils.get_bundle_names("toy")[1:]
        dilation = 0
    elif args.test == 3:
        bundles = dataset_specific_utils.get_bundle_names("test_single")[1:]
    else:
        bundles = dataset_specific_utils.get_bundle_names("All_tractometry")[1:]

    # Legacy trackvis files still use the ".trk" extension on disk.
    file_ending = "trk" if args.tracking_format == "trk_legacy" else args.tracking_format

    # NaN-free copy of the scalar volume. Built once, on first use, instead of
    # once per bundle; stays None if no bundle is ever evaluated.
    # NOTE(review): assumes evaluate_along_streamlines does not modify its
    # scalar input in place -- confirm.
    scalar_data = None

    results = []
    for bundle in tqdm(bundles):
        if args.peak_length:
            predicted_peaks = nib.load(join(args.TOM_dir, bundle + ".nii.gz")).get_fdata()
        else:
            predicted_peaks = None
        beginnings = nib.load(join(args.endings_dir, bundle + "_b.nii.gz"))

        trk_path = join(args.tracking_dir, bundle + "." + file_ending)

        mean = None  # stays None when the bundle cannot be evaluated
        if not os.path.exists(trk_path):
            print("WARNING: No tracking found for bundle {}. Returning zeros.".format(bundle))
        else:
            if args.tracking_format == "trk_legacy":
                from nibabel import trackvis
                streams, _ = trackvis.read(trk_path)
                streamlines = [s[0] for s in streams]
            else:
                streamlines = nib.streamlines.load(trk_path).streamlines

            # The toy test data (--test 2) is evaluated regardless of streamline count.
            if len(streamlines) >= 5 or args.test == 2:
                if scalar_data is None:
                    scalar_data = np.nan_to_num(scalar_image.get_fdata())
                mean, _ = tractometry.evaluate_along_streamlines(scalar_data, streamlines,
                                                                 beginnings.get_fdata(), nr_points,
                                                                 dilate=dilation,
                                                                 predicted_peaks=predicted_peaks,
                                                                 affine=scalar_image.affine)
            else:
                print("WARNING: bundle {} contains less than 5 streamlines. Saving value 0 for this bundle.".
                      format(bundle))

        if mean is None:
            mean = np.zeros(nr_points)

        # Remove first and last segment as those tend to be more noisy
        results.append(mean[1:-1])

    # One column per bundle, one row per point along the bundle.
    np.savetxt(args.csv_file_out, np.array(results).transpose(), delimiter=";",
               header=";".join(bundles), comments="")

    # Notes on reproducibility
    # - map_coordinates, QuickBundles and cKDTree are deterministic for the same input streamlines
    # - Variance in final Tractometry results when running 2 probabilistic trackings (10k fibers and 100 points):
    #   - Max difference ~0.01, but on average difference around 0.004
    #   - When plotting the differences: very minor
    #   - Will get averaged out in group analysis


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()