#!/usr/bin/env python
import argparse
import copy
import warnings

import pandas as pd
from plotnine.exceptions import PlotnineWarning

from nanoCEM.cem_utils import read_fasta_to_dic,identify_file_path,build_out_path,extract_kmer_feature,calculate_MANOVA_result
from nanoCEM.plot import current_plot,plot_PCA,MANOVA_plot

warnings.filterwarnings("ignore", category=PlotnineWarning)

def init_parser():
    """Build the command-line parser for the nanoCEM site-level tool.

    The parser has four sub-commands, stored in ``args.function``:
    ``tombo``, ``f5c_ev``, ``f5c_re`` and ``move_table``.  Each sub-command
    first declares its own input options and then receives the same set of
    shared options (reference, position, window length, strand, ...).

    Returns:
        argparse.ArgumentParser: the fully configured top-level parser.
    """
    base_shift_choices = ['auto', '0', '-1', '-2', '-3', '-4', '-5', '-6', '-7', '-8']
    pore_choices = ['r9', 'r10', 'rna004']

    def attach_shared_options(sub):
        # Options that every sub-command accepts.
        sub.add_argument("--chrom", required=True, help="Gene or chromosome name(head of your fasta file)")
        sub.add_argument("--pos", required=True, type=int, help="site of your interest")
        sub.add_argument("--len", default=10, type=int, help="region around the position")
        sub.add_argument("--strand", default="+", help="Strand of your interest")
        sub.add_argument('-r', "--ref", required=True, help="fasta file")
        sub.add_argument('-t', "--cpu", default=4, type=int, help="num of process")
        sub.add_argument('--norm', action='store_true', help='Turn on the normalization mode')
        sub.add_argument('-s', "--subsample_ratio", default=1, type=float,
                         help="Subsample ratio to select reads")
        sub.add_argument('--rna', action='store_true', help='Turn on the RNA mode')
        sub.add_argument('--kmer_size', choices=['1', '3', '5', '7'], default='3',
                         help="kmer size for PCA and MANOVA analysis")
        sub.add_argument('-o', "--output", default="nanoCEM_result", help="output_file")

    def attach_blow5_inputs(sub):
        # Sample / control inputs shared by the blow5-based sub-commands.
        sub.add_argument("-i", "--input", required=True,
                         help="blow5_path")
        sub.add_argument('-c', "--control",
                         help="control_blow5_path")

    root = argparse.ArgumentParser(
        description='A sample tool designed to visualize the features that distinguish between two groups of ONT data at the site level. It supports two re-squiggle pipeline(Tombo and f5c).')
    commands = root.add_subparsers(dest='function')

    # tombo sub-command
    tombo_cmd = commands.add_parser('tombo', help='tackle tombo re-squiggle')
    tombo_cmd.add_argument('--basecall_group', default="RawGenomeCorrected_000",
                           help='The attribute group to extract the training data from. e.g. RawGenomeCorrected_000')
    tombo_cmd.add_argument('--basecall_subgroup', default='BaseCalled_template',
                           help='Basecall subgroup Nanoraw resquiggle into. Default is BaseCalled_template')
    tombo_cmd.add_argument('-i', "--input_fast5", required=True,
                           help="input_fast5_file")
    tombo_cmd.add_argument('-c', "--control_fast5",
                           help="control_fast5_file")
    attach_shared_options(tombo_cmd)

    # The two f5c sub-commands take exactly the same options.
    f5c_commands = (
        ('f5c_ev', 'tackle f5c eventalign'),
        ('f5c_re', 'tackle f5c re-squiggle'),
    )
    for command_name, command_help in f5c_commands:
        f5c_cmd = commands.add_parser(command_name, help=command_help)
        attach_blow5_inputs(f5c_cmd)
        f5c_cmd.add_argument('--base_shift', choices=base_shift_choices, default='auto',
                             help='base shift option')
        f5c_cmd.add_argument('--pore', choices=pore_choices, help='Select pore type', default='r9')
        attach_shared_options(f5c_cmd)

    # move_table sub-command
    move_table_cmd = commands.add_parser('move_table', help='tackle move_table from basecaller')
    attach_blow5_inputs(move_table_cmd)
    move_table_cmd.add_argument('-m', '--sig_move_offset', type=int, required=True,
                                help='sig_move_offset for squigualiser reform')
    move_table_cmd.add_argument('-k', '--kmer_length', type=int, required=True,
                                help='kmer_length for squigualiser reform')
    attach_shared_options(move_table_cmd)

    return root

if __name__ == '__main__':
    # Script entry point: extract the current features around one reference
    # site for a sample (and optionally a control) group, save them as a CSV
    # table and draw the current / PCA / MANOVA plots into the output folder.

    # Parse the arguments
    parser = init_parser()
    args = parser.parse_args()

    # args check
    if args.function is None:
        raise Exception("Please choose function from tombo, f5c_ev, f5c_re or move_table, "
                        "or use -h to view the help document")
    identify_file_path(args.ref)
    fasta = read_fasta_to_dic(args.ref)
    try:
        reference_seq = fasta[args.chrom]
    except KeyError:
        # Report a readable message instead of a bare KeyError traceback.
        raise Exception("Can not find '" + str(args.chrom) + "' in the reference file " + str(args.ref)) from None

    # --pos is given 1-based; from here on args.pos is a 0-based index.
    # args is mutated on purpose because the tombo reader receives the whole namespace.
    args.pos = args.pos - 1
    length_gene = len(reference_seq)
    # The window [pos-len, pos+len] must lie strictly inside the sequence.
    # NOTE(review): the lower bound also rejects pos-len == 0 (0-based), which is
    # stricter than the message suggests - confirm whether that margin is required.
    if args.pos + args.len + 1 >= length_gene or args.pos - args.len <= 0:
        raise Exception("The position requested is too close to the border (pos-len>0 and pos+len<length of fasta)")

    # build title
    base_list = reference_seq[args.pos - args.len:args.pos + args.len + 1].upper()
    title = args.chrom + ':' + str(args.pos - args.len + 1) + '-' + str(args.pos + args.len + 1) + ':' + args.strand
    if args.rna:
        base_list = base_list.replace("T", "U")
        print("Running RNA mode ...")
        title = 'RNA  ' + title
    else:
        print("Running DNA mode ...")
        title = 'DNA  ' + title

    results_path = args.output
    build_out_path(results_path)
    subsample_ratio = args.subsample_ratio

    # Each branch defines load_group(path), which reads one group with the
    # selected backend, then loads the sample group and records where the
    # optional control group lives.
    if args.function == 'tombo':
        from nanoCEM.read_tombo_resquiggle import create_read_list_file, extract_group

        control_path = args.control_fast5

        def load_group(path):
            read_list_file = create_read_list_file(path, results_path)
            return extract_group(args, read_list_file, subsample_ratio)

        df_wt, aligned_num_wt = load_group(args.input_fast5)

    elif args.function in ('f5c_ev', 'f5c_re'):
        # Both f5c modes expose read_blow5 with the same call signature;
        # only the module that provides it differs.
        if args.function == 'f5c_ev':
            from nanoCEM.read_f5c_eventalign import read_blow5
        else:
            from nanoCEM.read_f5c_resquiggle import read_blow5
        if args.base_shift != 'auto':
            args.base_shift = int(args.base_shift)
        control_path = args.control

        def load_group(path):
            return read_blow5(path, args.pos, args.ref, args.len, args.chrom, args.strand, args.pore,
                              subsample_ratio, args.base_shift, args.norm, str(args.cpu), args.rna)

        df_wt, aligned_num_wt, nucleotide_type = load_group(args.input)
        if nucleotide_type == 'RNA' and not args.rna:
            raise RuntimeError("You need to add --rna to turn on the rna mode")

    elif args.function == 'move_table':
        from nanoCEM.read_move_table import read_basecall_bam

        control_path = args.control

        def load_group(path):
            return read_basecall_bam(path, args.pos, args.ref, args.len, args.chrom, args.strand,
                                     args.sig_move_offset, args.kmer_length,
                                     subsample_ratio, args.norm, str(args.cpu), args.rna)

        df_wt, aligned_num_wt = load_group(args.input)

    else:
        raise RuntimeError("Wrong option")

    df_wt['Group'] = 'Sample'

    # Without a control path there is nothing to load: run in single mode.
    single_mode = control_path is None
    if not single_mode:
        try:
            # Star-unpacking ignores the extra nucleotide type returned by read_blow5.
            df_ivt, aligned_num_ivt, *_ = load_group(control_path)
            df_ivt['Group'] = 'Control'
        except Exception as err:
            # Keep the best-effort fallback, but tell the user why the control
            # was dropped instead of silently switching mode.
            print("Warning: failed to load the control group from " + str(control_path)
                  + " (" + str(err) + "), running in single mode")
            single_mode = True

    if single_mode:
        df = df_wt
        df_wt['Group'] = 'Single'
        title = title + '   Sample:' + str(aligned_num_wt)
    else:
        df = pd.concat([df_wt, df_ivt])
        title = title + '   Sample:' + str(aligned_num_wt) + '  Control:' + str(aligned_num_ivt)
        category = pd.api.types.CategoricalDtype(categories=['Sample', "Control"], ordered=True)
        df['Group'] = df['Group'].astype(category)

    # save result table (positions are written 1-based, matching --pos)
    df_copy = copy.deepcopy(df)
    df_copy['Position'] = df_copy['Position'].astype(int) + 1
    df_copy.to_csv(results_path + '/current_feature.csv', index=False)
    print("Feature file saved  in " + results_path + '/current_feature.csv')
    # draw current feature
    current_plot(df, results_path, args.pos, base_list, title, filter=True)
    if not single_mode:
        # --kmer_size always has a value (parser default '3'); it arrives as a string.
        kmer_size = int(args.kmer_size)
        # draw PCA
        feature_matrix, label = extract_kmer_feature(df_copy, kmer_size, args.pos + 1)
        plot_PCA(feature_matrix, label, results_path)
        # draw p_value
        new_df = calculate_MANOVA_result(args.pos + 1, df_copy, args.len, 500, args.len, kmer_size)
        # NOTE(review): the last argument is hard-coded to 3 rather than kmer_size -
        # confirm this is intended.
        MANOVA_plot(new_df, results_path, 3)
    print('Finished')

