#!python
"""
This file is a wrapper script that runs isONform in parallel (parallelizing over batches).
The script was adapted from isONcorrect (https://github.com/ksahlin/isoncorrect/blob/master/run_isoncorrect),
originally written by Kristoffer Sahlin, and modified by Alexander Petri to work with the isONform code base.

"""


from __future__ import print_function
import argparse
import tempfile
from ast import Param
from time import time
from pathlib import Path
import signal
from multiprocessing import Pool
import multiprocessing as mp

import subprocess
import sys
import os
from sys import stdout
import shutil
import errno

from modules import batch_merging_parallel
from modules import help_functions
from modules import Parallelization_side_functions

def restructure_isoncorrect_output(directory):
    """Normalize the input folder to a flat layout of one FASTQ file per cluster.

    A folder that holds no regular files (only sub-folders) is taken to be
    isONcorrect output: for every sub-folder ``<name>``, its
    ``corrected_reads.fastq`` is moved to ``<directory>/<name>.fastq``, and
    afterwards ``Parallelization_side_functions.remove_folders`` is called on
    ``directory``. A folder that already holds regular files is assumed to be
    isONclust output and is left unchanged.
    """
    entries = os.listdir(directory)
    if any(os.path.isfile(os.path.join(directory, entry)) for entry in entries):
        print("isONclust structure")
        return

    print("isONcorrect structure")
    for entry in entries:
        entry_path = os.path.join(directory, entry)
        if not os.path.isdir(entry_path):
            continue
        corrected = os.path.join(entry_path, 'corrected_reads.fastq')
        flattened = os.path.join(directory, f"{entry}.fastq")
        if not os.path.exists(corrected):
            print(f"File {corrected} does not exist.")
            continue
        shutil.move(corrected, flattened)
        print(f"Moved and renamed {corrected} to {flattened}")
    # Per the original note: this removes all folders in the input folder.
    Parallelization_side_functions.remove_folders(directory)

def wccount(filename):
    """Return the number of newline characters in ``filename`` (like ``wc -l``).

    As with ``wc -l``, a final line without a trailing newline is not counted.
    The count is computed in-process by scanning the file in binary chunks.
    This avoids depending on the external ``wc`` binary (absent on Windows) and
    spawning a subprocess per call. A missing or unreadable file raises the
    appropriate ``OSError`` instead of a confusing ``ValueError`` from parsing
    ``wc``'s error message.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    count = 0
    with open(filename, 'rb') as handle:
        # 1 MiB chunks keep memory flat regardless of file size.
        for chunk in iter(lambda: handle.read(1 << 20), b''):
            count += chunk.count(b'\n')
    return count


def isONform(data):
    """Run the isONform ``main`` script on one batch (multiprocessing worker).

    Args:
        data: tuple ``(isONform_location, read_fastq_file, outfolder, batch_id,
            isONform_algorithm_params, cl_id)`` as built in ``main``; only the
            first six items are used.

    The child's stderr goes to ``<outfolder>/stderr.txt`` and its stdout to
    ``<outfolder>/stdout<cl_id>_<batch_id>.txt``. Both files are closed even if
    the run fails.

    Returns:
        batch_id, so the caller can report which batch finished.

    Raises:
        subprocess.CalledProcessError: if the isONform run exits non-zero.
    """
    (isONform_location, read_fastq_file, outfolder, batch_id,
     isONform_algorithm_params, cl_id) = data[:6]
    help_functions.mkdir_p(outfolder)
    isONform_exec = os.path.join(isONform_location, "main")
    # NOTE(review): all batches of one cluster share this outfolder, so concurrent
    # batches of the same cluster truncate/write the same stderr.txt. Consider a
    # per-batch name if nothing downstream relies on "stderr.txt".
    isONform_error_file = os.path.join(outfolder, "stderr.txt")
    isONform_stdout_file = os.path.join(outfolder, "stdout{0}_{1}.txt".format(cl_id, batch_id))
    params = isONform_algorithm_params
    # sys.executable: run the child with the same interpreter/environment as this
    # script, not whichever "python" happens to be first on PATH.
    command = [sys.executable, isONform_exec, "--fastq", read_fastq_file, "--outfolder", outfolder,
               "--exact_instance_limit", str(params["exact_instance_limit"]),
               "--k", str(params["k"]), "--w", str(params["w"]),
               "--xmin", str(params["xmin"]), "--xmax", str(params["xmax"]),
               "--delta_len", str(params["delta_len"]),
               "--exact", "--parallel", "True",
               "--delta_iso_len_3", str(params["delta_iso_len_3"]),
               "--delta_iso_len_5", str(params["delta_iso_len_5"])]
    # Both handles in one `with` so neither leaks when check_call raises.
    with open(isONform_error_file, "w") as error_file, open(isONform_stdout_file, "w") as out_file:
        print('Running isONform batch_id:{0}.{1}...'.format(cl_id, batch_id), end=' ')
        stdout.flush()
        subprocess.check_call(command, stderr=error_file, stdout=out_file)
    print('Done with batch_id:{0}.{1}'.format(cl_id, batch_id))
    stdout.flush()
    return batch_id

#splits files containing more than max_seqs reads into smaller files, that can be parallelized upon
def splitfile(indir, tmp_outdir, fname, chunksize,cl_id,ext):
    # from https://stackoverflow.com/a/27641636/2060202
    # fpath, fname = os.path.split(infilepath)
    #cl_id, ext = fname.rsplit('.',1)
    infilepath = os.path.join(indir, fname)
    #infilepath=indir
    # print(fpath, cl_id, ext)
    #print("now at splitfile")
    #print(indir, tmp_outdir, cl_id, ext)

    i = 0
    written = False
    with open(infilepath) as infile:
        while True:
            outfilepath = os.path.join(tmp_outdir, '{0}_{1}.{2}'.format(cl_id, i, ext) ) #"{}_{}.{}".format(foutpath, fname, i, ext)
            #print(outfilepath)
            with open(outfilepath, 'w') as outfile:
                for line in (infile.readline() for _ in range(chunksize)):
                    outfile.write(line)
                written = bool(line)
            # print(os.stat(outfilepath).st_size == 0)
            if os.stat(outfilepath).st_size == 0: # Corner case: Original file is even multiple of max_seqs, hence the last file becomes empty. Remove this
                os.remove(outfilepath)
            if not written:
                break
            i += 1


def symlink_force(target, link_name):
    """Create a symlink ``link_name`` -> ``target``, replacing what is there (like ``ln -sf``).

    If ``link_name`` already is a symlink to exactly ``target``, nothing changes.
    Otherwise, an existing entry at ``link_name`` (a broken link, a link to a
    different file, or a regular file left by an earlier run) is removed and the
    link is recreated. A stale entry is therefore never silently reused.

    Raises:
        OSError: for any failure other than "already exists" (and if the
            existing entry cannot be removed, e.g. it is a directory).
    """
    try:
        os.symlink(target, link_name)
        return
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
    if os.path.islink(link_name):
        if os.readlink(link_name) == os.fspath(target):
            return  # already points where we want
        # os.path.exists follows the link and resolves relative targets from the
        # link's own directory, which is how the OS resolves them too.
        if not os.path.exists(link_name):
            print('path %s is a broken symlink' % link_name)
    os.remove(link_name)
    os.symlink(target, link_name)

def split_cluster_in_batches_corrected(indir, outdir, tmp_work_dir, max_seqs):
    """Split clusters into smaller batches from a folder of per-cluster sub-folders.

    Every ``*.fastq`` found recursively under ``indir`` is treated as one
    cluster whose id is the name of its (resolved) parent folder. Files inside a
    folder named ``Analysis`` are skipped. Files with more than ``4 * max_seqs``
    lines (i.e. more than ``max_seqs`` reads, assuming 4-line FASTQ records) are
    split via ``splitfile``. Smaller ones are symlinked as ``<cl_id>_0.<ext>``.
    All outputs go to ``<tmp_work_dir>/split_in_batches``.

    Args:
        outdir: unused; kept for interface compatibility.

    Returns:
        Path of the ``split_in_batches`` folder.
    """
    tmp_work_dir = os.path.join(tmp_work_dir, 'split_in_batches')
    help_functions.mkdir_p(tmp_work_dir)
    # Materialize the listing before writing anything, so newly created batch
    # files can never be picked up by the recursive glob.
    file_list = list(Path(indir).rglob('*.fastq'))
    for filepath in file_list:
        fastq_path = filepath.resolve()
        cl_id = fastq_path.parent.name
        # we do not want to look at the analysis fastq file
        if cl_id == "Analysis":
            continue
        fastq_file = fastq_path.name
        ext = fastq_file.rsplit('.', 1)[1]
        with open(fastq_path) as handle:
            num_lines = sum(1 for _ in handle)
        if num_lines > 4 * max_seqs:
            splitfile(str(fastq_path.parent), tmp_work_dir, fastq_file, 4 * max_seqs, cl_id, ext)  # is fastq file
        else:
            # Absolute target: a relative one would be resolved from the link's own
            # directory (tmp_work_dir) and produce a broken link.
            symlink_force(str(fastq_path), os.path.join(tmp_work_dir, '{0}_{1}.{2}'.format(cl_id, 0, ext)))
    return tmp_work_dir

def split_cluster_in_batches_clust(indir, outdir, tmp_work_dir, max_seqs):
    """Split isONclust-style clusters (``<cl_id>.fastq`` files) into batches.

    For every ``*.fastq`` found recursively under ``indir``, the cluster id is
    the (resolved) file name without its extension. Files with more than
    ``4 * max_seqs`` lines (more than ``max_seqs`` reads, assuming 4-line FASTQ
    records) are split via ``splitfile``. Smaller ones are symlinked as
    ``<cl_id>_0.<ext>``. All outputs go to ``<tmp_work_dir>/split_in_batches``.

    Args:
        outdir: unused; kept for interface compatibility.

    Returns:
        Path of the ``split_in_batches`` folder.
    """
    tmp_work_dir = os.path.join(tmp_work_dir, 'split_in_batches')
    help_functions.mkdir_p(tmp_work_dir)
    # Materialize the listing before writing anything into tmp_work_dir.
    file_list = list(Path(indir).rglob('*.fastq'))
    for file_ in file_list:
        # Use the actual location of the file: rglob may find files in
        # sub-folders, which os.path.join(indir, name) would get wrong.
        fastq_path = file_.resolve()
        fastq_file = fastq_path.name
        cl_id, ext = fastq_file.rsplit('.', 1)
        with open(fastq_path) as handle:
            num_lines = sum(1 for _ in handle)
        if num_lines > 4 * max_seqs:
            splitfile(str(fastq_path.parent), tmp_work_dir, fastq_file, 4 * max_seqs, cl_id, ext)  # is fastq file
        else:
            link_name = os.path.join(tmp_work_dir, '{0}_{1}.{2}'.format(cl_id, 0, ext))
            print("SYMLINK", link_name)
            # Absolute target: a relative one would be resolved from tmp_work_dir
            # and produce a broken link when indir is a relative path.
            symlink_force(str(fastq_path), link_name)
    return tmp_work_dir


def _build_instances(args, split_directory, isONform_location):
    """Collect one isONform work item per ``.fastq`` batch file in ``split_directory``.

    Batch files are expected to be named ``<cl_id>.fastq`` or
    ``<cl_id>_<batch_id>.fastq``. Both ids must parse as integers because the
    result is sorted numerically on them. With ``--keep_old``, a batch is skipped
    when ``<outfolder>/<cl_id>/isoforms.fastq`` exists and has the same number of
    lines as the batch file.

    Returns:
        List of ``(isONform_location, fastq_file_path, outfolder, batch_id,
        isONform_algorithm_params, cl_id)`` tuples, the layout ``isONform`` unpacks.
    """
    # The parameters are the same for every batch, so build the dict once.
    isONform_algorithm_params = {"set_w_dynamically": args.set_w_dynamically,
                                 "exact_instance_limit": args.exact_instance_limit,
                                 "delta_len": args.delta_len, "--exact": True,
                                 "k": args.k, "w": args.w, "xmin": args.xmin, "xmax": args.xmax,
                                 "max_seqs": args.max_seqs, "parallel": True, "--slow": True,
                                 "delta_iso_len_3": args.delta_iso_len_3,
                                 "delta_iso_len_5": args.delta_iso_len_5}
    split_dir = os.fsdecode(split_directory)
    instances = []
    for file_ in os.listdir(split_directory):
        read_fastq_file = os.fsdecode(file_)
        if not read_fastq_file.endswith(".fastq"):
            continue
        id_parts = read_fastq_file.split(".")[0].split("_")
        cl_id = id_parts[0]
        batch_id = id_parts[1] if len(id_parts) > 1 else 0
        outfolder = os.path.join(args.outfolder, cl_id)
        fastq_file_path = os.path.join(split_dir, read_fastq_file)
        if args.keep_old:
            candidate_corrected_file = os.path.join(outfolder, "isoforms.fastq")
            if (os.path.isfile(candidate_corrected_file)
                    and wccount(candidate_corrected_file) == wccount(fastq_file_path)):
                continue  # already computed and complete
        instances.append((isONform_location, fastq_file_path, outfolder, batch_id,
                          isONform_algorithm_params, cl_id))
    # sorting on cluster_id and then batch_id numerically
    instances.sort(key=lambda x: (int(x[5]), int(x[3])))
    return instances


def _run_instances(instances, nr_cores):
    """Run ``isONform`` on every instance in a pool of ``nr_cores`` worker processes.

    On Ctrl-C the pool is terminated and the process exits with status 130, so
    calling pipelines see the run as failed. Any other error also terminates the
    pool before it propagates.
    """
    # SIGINT is deliberately NOT ignored while the pool is created. The workers,
    # and the isONform subprocesses they start, would inherit SIG_IGN, and those
    # subprocesses would then survive Ctrl-C as orphans.
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        pass  # start method already fixed for this process
    print(mp.get_context())
    print("Environment set:", mp.get_context())
    print("Using {0} cores.".format(nr_cores))
    start_multi = time()
    pool = Pool(processes=int(nr_cores))
    try:
        start = time()
        for x in pool.imap_unordered(isONform, instances):
            print("{} (Time elapsed: {}s)".format(x, int(time() - start)))
    except KeyboardInterrupt:
        print("Caught KeyboardInterrupt, terminating workers")
        pool.terminate()
        sys.exit(130)  # non-zero: an interrupted run must not look successful
    except Exception:
        pool.terminate()
        raise
    else:
        pool.close()
    pool.join()
    print("Time elapsed multiprocessing:", time() - start_multi)


#PYTHONHASHSEED = 0
def main(args):
    """Run isONform on all clusters in ``args.fastq_folder`` and merge the results.

    Steps:
      1. Flatten isONcorrect-style input in place (``restructure_isoncorrect_output``).
      2. With ``--split_wrt_batches``, split large clusters into batches of
         ``max_seqs`` reads inside a temporary work directory.
      3. Run one isONform process per batch in parallel.
      4. Merge the batch results and write the final output to ``args.outfolder``.

    A temporary directory created here via ``tempfile.mkdtemp`` is removed at
    the end. A user-supplied ``--tmpdir`` is kept (only its ``split_in_batches``
    folder is removed).
    """
    globstart = time()
    directory = args.fastq_folder
    # Please note that this alters the output of isONcorrect, i.e. the structure
    # of "corrected" in the isONpipeline.
    restructure_isoncorrect_output(directory)

    write_low_abundance = False
    isONform_location = os.path.dirname(os.path.realpath(__file__))
    created_tmp_dir = None
    if args.split_wrt_batches:
        if args.tmpdir:
            tmp_work_dir = args.tmpdir
            Parallelization_side_functions.mkdir_p(tmp_work_dir)
        else:
            tmp_work_dir = tempfile.mkdtemp()
            created_tmp_dir = tmp_work_dir
        print("Temporary workdirectory:", tmp_work_dir)
        split_tmp_directory = split_cluster_in_batches_clust(directory, args.outfolder, tmp_work_dir,
                                                             args.max_seqs)
        split_directory = os.fsencode(split_tmp_directory)
    else:
        split_directory = os.fsencode(directory)
        print("SplitDIR", split_directory)

    instances = _build_instances(args, split_directory, isONform_location)
    print("Printing instances")
    for t in instances:
        print(t)
    _run_instances(instances, args.nr_cores)

    print("Merging...")
    file_handling = time()
    write_fastq = bool(args.write_fastq)
    batch_merging_parallel.join_back_via_batch_merging(args.outfolder, args.delta, args.delta_len, args.delta_iso_len_3, args.delta_iso_len_5, args.max_seqs_to_spoa, args.iso_abundance, write_fastq, write_low_abundance)
    Parallelization_side_functions.generate_full_output(args.outfolder, write_fastq, write_low_abundance)
    Parallelization_side_functions.remove_folders(args.outfolder)
    if args.split_wrt_batches:
        print("Removed the split directory")
        shutil.rmtree(split_directory)
        if created_tmp_dir is not None:
            # We created this mkdtemp() directory ourselves; remove it so it does not leak.
            shutil.rmtree(created_tmp_dir, ignore_errors=True)
    print("Joined back batched files in:", time() - file_handling)
    print("Finished full algo after :", time() - globstart)
    return


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="De novo reconstruction of long-read transcriptome reads",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--version', action='version', version='%(prog)s 0.3.9')
    parser.add_argument('--fastq_folder', type=str, default=False,
                        help='Path to input fastq folder with reads in clusters')
    parser.add_argument('--t', dest="nr_cores", type=int, default=8, help='Number of cores (parallel isONform processes) to use')
    parser.add_argument('--k', type=int, default=20, help='Kmer size')
    parser.add_argument('--w', type=int, default=31, help='Window size')
    parser.add_argument('--xmin', type=int, default=18, help='Lower interval length')
    parser.add_argument('--xmax', type=int, default=80, help='Upper interval length')
    parser.add_argument('--exact_instance_limit', type=int, default=50,
                        help='Do exact correction for clusters under this threshold')
    parser.add_argument('--keep_old', action="store_true",
                        help='Do not recompute previous results if isoforms.fastq is found in the cluster output folder and has the same number of lines as the input batch file (i.e., is complete).')
    parser.add_argument('--set_w_dynamically', action="store_true",
                        help='Set w = k + max(2*k, floor(cluster_size/1000)).')
    parser.add_argument('--max_seqs', type=int, default=1000,
                        help='Maximum number of seqs to correct at a time (in case of large clusters).')
    parser.add_argument('--split_wrt_batches', action="store_true",
                        help='Process reads per batch (of max_seqs sequences) instead of per cluster. Significantly decrease runtime when few very large clusters are less than the number of cores used.')
    parser.add_argument('--outfolder', type=str, default=None, help='Outfolder with all corrected reads.')
    parser.add_argument('--delta_len', type=int, default=5,
                        help='Maximum length difference between two reads intervals for which they would still be merged')
    parser.add_argument('--delta', type=float, default=0.1, help='diversity rate used to compare sequences')
    parser.add_argument('--max_seqs_to_spoa', type=int, default=200, help='Maximum number of seqs to spoa')
    parser.add_argument('--verbose', action="store_true", help='Print various developer stats.')
    parser.add_argument('--iso_abundance', type=int, default=5,
                        help='Cutoff parameter: abundance of reads that have to support an isoform to show in results')
    parser.add_argument('--delta_iso_len_3', type=int, default=30,
                        help='Cutoff parameter: maximum length difference at 3prime end, for which subisoforms are still merged into longer isoforms')
    parser.add_argument('--delta_iso_len_5', type=int, default=50,
                        help='Cutoff parameter: maximum length difference at 5prime end, for which subisoforms are still merged into longer isoforms')
    parser.add_argument('--tmpdir', type=str, default=None, help='OPTIONAL PARAMETER: Absolute path to custom folder in which to store temporary files. If tmpdir is not specified, isONform will attempt to write the temporary files into the tmp folder on your system. It is advised to only use this parameter if the symlinking does not work on your system.')
    parser.add_argument('--write_fastq', action="store_true", help='Indicates that we want to output the final output (transcriptome) as fastq file (New standard: fasta)')

    args = parser.parse_args()
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit()

    # Both folders are mandatory. Fail early with a clear usage error instead of
    # an obscure crash inside main() (os.listdir(False) / os.path.join(None, ...)).
    if not args.fastq_folder:
        parser.error("--fastq_folder is required")
    if not args.outfolder:
        parser.error("--outfolder is required")

    os.makedirs(args.outfolder, exist_ok=True)

    main(args)
