#!/usr/bin/env python
import argparse
import warnings

import pandas as pd
from plotnine.exceptions import PlotnineWarning

from nanoCEM.cem_utils import identify_file_path, build_out_path, run_samtools
from nanoCEM.plot import alignment_plot

warnings.filterwarnings("ignore", category=PlotnineWarning)


def count_mis(row, feature_matrix):
    """Tally how often each symbol of ``feature_matrix`` occurs in ``row[4]``.

    For every symbol, the number of non-overlapping occurrences in the value
    stored under key ``4`` is written back into ``row`` under that symbol.

    Args:
        row: mutable mapping-like record (e.g. a DataFrame row) whose entry
            ``4`` supports ``.count``; it is modified in place.
        feature_matrix: iterable of symbols to count.

    Returns:
        The same ``row`` object, extended with one entry per symbol.
    """
    # Raw substring counts: every occurrence of the symbol in row[4] is
    # tallied, whatever role that character plays in the string.
    tallies = [(symbol, row[4].count(symbol)) for symbol in feature_matrix]
    for symbol, tally in tallies:
        row[symbol] = tally
    return row


def init_parser():
    """Build the command-line parser for the alignment-feature script.

    Returns:
        argparse.ArgumentParser: parser exposing the target region
        (``--chrom``, ``--pos``, ``--len``, ``--strand``), the ``--rna`` mode
        switch, the input files (``-r``, ``-i``, ``-c``), the CPU count
        (``-t``) and the output location (``-o``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--chrom", help="Gene or chromosome name(head of your fasta file)")
    parser.add_argument("--pos", type=int, help="site of your interest")
    parser.add_argument("--len", default=10, type=int, help="region around the position")
    # The help text used to be a copy of the --len description.
    parser.add_argument('-t', "--cpu", default=4, type=int, help="number of CPUs (threads) to use")
    parser.add_argument("--strand", default="+", help="Strand of your interest")
    parser.add_argument('--rna', action='store_true', help='Turn on the RNA mode')
    parser.add_argument('-r', "--ref", help="fasta file")
    parser.add_argument('-i', "--input_fastq",
                        help="input_fastq_file")
    parser.add_argument('-c', "--control_fastq",
                        help="control_fastq_file")
    # parser_tombo.add_argument('-b', "--bam", help="bam file to help index to speed up")
    parser.add_argument('-o', "--output", default="nanoCEM_result", help="output_file")

    return parser


def main():
    """Run the alignment-feature workflow for one target region.

    Parses the command line, collects per-position features for the sample
    (and for the control, when a control file is available) through
    ``run_samtools``, counts the match symbol and the four bases in each
    alignment string, writes the table to ``<output>/alignment_feature.csv``
    and draws the alignment plot.

    Raises:
        SystemExit: via ``parser.error`` when ``--chrom``, ``--pos``,
            ``-i/--input_fastq`` or ``-r/--ref`` is missing.
        ValueError: when the sample yields fewer rows than the requested
            window of ``2 * len + 1`` positions.
    """
    # Parse the arguments
    parser = init_parser()
    args = parser.parse_args()
    input_fastq = args.input_fastq
    control_fastq = args.control_fastq
    ref = args.ref
    cpu = str(args.cpu)

    # These options have no default; without them the region string below
    # cannot be built (None + str raises TypeError), so report a regular
    # usage error instead of a traceback.
    required = (
        ('--chrom', args.chrom),
        ('--pos', args.pos),
        ('-i/--input_fastq', input_fastq),
        ('-r/--ref', ref),
    )
    missing = [flag for flag, value in required if value is None]
    if missing:
        parser.error('the following arguments are required: ' + ', '.join(missing))

    # check input file
    identify_file_path(input_fastq)
    identify_file_path(ref)

    # prepare title
    location = args.chrom + ':' + str(args.pos - args.len) + '-' + str(args.pos + args.len)
    title = location + ':' + args.strand

    # check output path
    results_path = args.output
    build_out_path(results_path)

    print("Running samtools ...")
    sample_feature = run_samtools(input_fastq, location, ref, results_path, 'Sample', cpu)
    # One row per position is expected across the whole window.
    if sample_feature.shape[0] < args.len * 2 + 1:
        raise ValueError(
            'Invalid position selection: exceeds reference length or contains positions with no coverage.')
    pos_list = list(sample_feature[1].values)
    base_list = list(sample_feature[2].values)

    # Only the availability check may trigger the single-mode fallback.
    # Failures while processing an existing control file must surface rather
    # than be reported as "control files do not exist".
    try:
        identify_file_path(control_fastq)
    except Exception:  # exception type raised by identify_file_path is not visible here
        print("control files do not exist, will turn to the single mode")
        final_feature = sample_feature
        final_feature['Group'] = 'Single'
    else:
        control_feature = run_samtools(control_fastq, location, ref, results_path, 'Control', cpu)
        final_feature = pd.concat([sample_feature, control_feature], axis=0)

    if args.rna:
        # Show the reference bases as RNA (T -> U) on the plot.
        base_list = list(''.join(base_list).replace("T", "U"))
        title = 'RNA  ' + title
    else:
        title = 'DNA  ' + title

    print("Calculating alignment feature...")

    # First symbol is the match marker, followed by the four bases; the
    # non-'+' strand uses ',' and lower-case letters.
    if args.strand == '+':
        feature_matrix = ['.', 'A', 'T', 'C', 'G']
    else:
        feature_matrix = [",", 'a', 't', 'c', 'g']

    final_feature = final_feature.apply(lambda x: count_mis(x, feature_matrix), axis=1)
    final_feature['Match'] = final_feature[feature_matrix[0]]
    final_feature.drop(feature_matrix[0], inplace=True, axis=1)
    # NOTE(review): the rename assumes run_samtools returns six positional
    # columns plus 'Group' — confirm against cem_utils.
    final_feature.columns = ['Chrom', 'Position', 'Base', 'Coverage', 'Align string', 'Q string', 'Group', 'A', 'T',
                             'C', 'G', 'Match']
    csv_path = results_path + '/alignment_feature.csv'
    final_feature.to_csv(csv_path, index=False)
    print("Alignment feature file saved as " + csv_path)

    # Long format: one row per (Position, Group, feature) for plotting.
    final_feature = pd.melt(final_feature, id_vars=['Position', 'Group'], value_vars=['A', 'T', 'C', 'G', 'Match'])

    print("Start to plot ...")
    alignment_plot(final_feature, pos_list, base_list, title, args.pos, results_path)
    print('Finished')


if __name__ == '__main__':
    main()
