#!/usr/bin/env python

"""
Flip MRtrix peaks (https://mrtrix.readthedocs.io/en/latest/reference/commands/sh2peaks.html) along a specific axis,
i.e. negate the chosen (x, y or z) component of every peak vector.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import nibabel as nib


def main():
    """Command-line entry point: negate one component of every peak vector.

    Loads the image given by ``-i``, treats its fourth axis as consecutive
    (x, y, z) triplets, multiplies the component selected by ``-a`` by -1 in
    every complete triplet and saves the result to ``-o`` together with the
    affine of the input image.

    Exits with an argparse usage error if the input image has fewer than
    four dimensions.
    """
    parser = argparse.ArgumentParser(
        description="Flip Mrtrix peaks along specific axis.",
        epilog="Written by Jakob Wasserthal. Please reference 'Wasserthal et al. "
               "TractSeg - Fast and accurate white matter tract segmentation. "
               "https://doi.org/10.1016/j.neuroimage.2018.07.070)'")
    parser.add_argument("-i", metavar="file_in", dest="file_in", help="input file (4D nifti image)", required=True)
    parser.add_argument("-o", metavar="file_out", dest="file_out", help="output file (4D nifti image)", required=True)
    parser.add_argument("-a", metavar="axis", dest="axis", help="axis (x|y|z)", choices=["x", "y", "z"], required=True)
    args = parser.parse_args()

    img = nib.load(args.file_in)
    data = img.get_fdata()

    # Without a fourth axis there are no peak components to flip; report that
    # cleanly instead of failing with an IndexError on data.shape[3] below.
    if data.ndim < 4:
        parser.error("input image must be 4D, but '{}' has shape {}".format(args.file_in, data.shape))

    nr_peaks = data.shape[3] // 3

    # Position of the requested component inside each (x, y, z) triplet.
    # argparse's `choices` guarantees args.axis is one of these keys.
    component = {"x": 0, "y": 1, "z": 2}[args.axis]

    # We only flip data, not affine.
    # The strided slice selects the chosen component of every complete
    # triplet; volumes beyond nr_peaks * 3 (an incomplete trailing triplet)
    # are left untouched.
    data[:, :, :, component:nr_peaks * 3:3] *= -1

    # NOTE(review): the input header is not passed on here, so the output is
    # written with a header derived from `data` and the affine only -
    # confirm that this is intended.
    new_image = nib.Nifti1Image(data, img.affine)
    nib.save(new_image, args.file_out)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
