#!/usr/bin/env python
import numpy as np
from hisv import normCove
import argparse, sys, os
import pandas as pd


def read_file(resultfile):
    '''
    read HiSV intra-chromosomal result
    :param resultfile: path of the tab-separated HiSV result file; each
        non-blank line holds the chromosome followed by four integer
        coordinates (pos1_start, pos1_end, pos2_start, pos2_end)
    :return: DataFrame with columns
        {'chrom', 'pos1_start', 'pos1_end', 'pos2_start', 'pos2_end'},
        one row per SV event (empty, but with these columns, if the file
        contains no events)
    '''
    columns = ['chrom', 'pos1_start', 'pos1_end', 'pos2_start', 'pos2_end']
    rows = []
    with open(resultfile) as f:
        for line in f:
            stripped = line.strip()
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not stripped:
                continue
            line_data = stripped.split('\t')
            rows.append(
                {'chrom': line_data[0], 'pos1_start': int(line_data[1]), 'pos1_end': int(line_data[2]),
                 'pos2_start': int(line_data[3]), 'pos2_end': int(line_data[4])})

    # Build the frame once: DataFrame.append was removed in pandas 2.0 and
    # copied the whole frame on every call.
    return pd.DataFrame(rows, columns=columns)


def call_slope(data):
    '''
    Fit a first-degree polynomial to the values in data and return its slope.
    The values are paired with the x positions 1, 2, ..., len(data).
    :param data: sequence of values (one row or column of a sub-matrix)
    :return: leading coefficient (slope) of the fitted line
    '''
    positions = np.arange(1, len(data) + 1)
    # np.polyfit returns coefficients from highest degree to lowest, so for a
    # degree-1 fit the first entry is the slope and the second the intercept.
    fitted = np.polyfit(positions, list(data), 1)
    return fitted[0]


def cal_slope(matrix):
    '''
    Calculate the slopes of the two outer columns of a sub-matrix
    :param matrix: 2-D array supporting matrix[:, k] column indexing
    :return: (slope of the first column, slope of the last column)
    '''
    first_column, last_column = matrix[:, 0], matrix[:, -1]
    return call_slope(first_column), call_slope(last_column)


def getargs():
    '''
    Parse the command-line arguments of the SV-typing script.
    If no argument is given, '-h' is appended so that argparse prints the
    help text and exits.
    :return: (args, commands) - the parsed argparse namespace and the list of
        raw command-line tokens that was parsed
    '''
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Define the type of SV event recognized by each HiSV''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--hic_file', help='Hi-C contact matrix file.')
    parser.add_argument('--gc_file', help='GC content file of the reference genome.', default=None)
    parser.add_argument('--map_file', help='Mappability file of the reference genome.', default=None)
    parser.add_argument('--RE_file', help='RE sites file of the reference genome.', default=None)
    parser.add_argument('--binsize', type=int, default=50000, help='Bin size of final Hi-C contact matrix.')
    parser.add_argument('--withoutRE', action='store_true', default=False,
                        help='Normalize the coverage without the RE sites file.')
    parser.add_argument('--HiSV_result', help='Result file from HiSV')
    parser.add_argument('--name', help='''Sample name''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # Nothing was passed: show the usage instead of running with no inputs.
        commands.append('-h')
    args = parser.parse_args(commands)
    return args, commands


def _matrix_path(hicfile, name, binSize, chrID):
    '''
    Build the path of the intra-chromosomal contact matrix file of one chromosome
    '''
    return os.path.join(hicfile, name + '_%skb_%s_%s_matrix.txt' % (str(binSize // 1000),
                                                                    chrID, chrID))


def _exit_if_missing(path, message):
    '''
    Print message and exit with status -1 if path is unset or does not exist
    '''
    # os.path.exists(None) raises TypeError, so an omitted argument is checked first.
    if path is None or not os.path.exists(path):
        print(message)
        sys.exit(-1)


def run():
    '''
    Define the type of every SV event listed in a HiSV result file.

    Workflow:
      1. parse the command line and check that every input file exists
         (the process exits with status -1 otherwise);
      2. compute the normalized coverage of each chromosome in the result;
      3. for each SV, compare the mean positive coverage between pos1_start
         and pos2_end with the chromosome-wide mean: more than 1.1x is
         printed as a duplication, less than 0.9x as a deletion;
      4. otherwise fit the slopes of the first and last columns of the
         contact sub-matrix and print unbalanced_translocation, inversion or
         balanced_translocation according to their signs.

    All results are printed to stdout; nothing is returned.
    '''
    args, _ = getargs()
    hicfile = args.hic_file
    gc_file = args.gc_file
    mapscore_file = args.map_file
    cutsites_file = args.RE_file
    result_file = args.HiSV_result
    name = args.name
    binSize = args.binsize
    withoutRE = args.withoutRE
    print(withoutRE)

    _exit_if_missing(hicfile, "hic file does not exist.")
    _exit_if_missing(gc_file, "gc file does not exist.")
    _exit_if_missing(mapscore_file, "Mappability file does not exist.")
    if not withoutRE:
        _exit_if_missing(cutsites_file, "RE file does not exist.")
    _exit_if_missing(result_file, "HiSV result file does not exist.")
    if name is None:
        # The sample name is part of every matrix file name.
        print("sample name is required.")
        sys.exit(-1)

    # read result file from HiSV
    result = read_file(result_file)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # processing order is reproducible between runs (set order is not).
    chr_list = list(dict.fromkeys(result['chrom']))
    print("chromosome list:", chr_list)
    print("Normalization Coverage...")

    # Determine whether the hi-c matrix file is consistent with the HiSV result
    matrix_files = {chrID: _matrix_path(hicfile, name, binSize, chrID) for chrID in chr_list}
    for matrix_file in matrix_files.values():
        _exit_if_missing(matrix_file, "matrix file does not exist.")

    # Normalized coverage per chromosome, keyed by chromosome name.
    coverage_by_chrom = {}
    for chrID, matrix_file in matrix_files.items():
        if withoutRE:
            normed_coverage = normCove.normalized_coverage_withoutRE(matrix_file, gc_file, mapscore_file,
                                                                     binSize, chrID)
        else:
            normed_coverage = normCove.normalized_coverage(matrix_file, gc_file, mapscore_file, cutsites_file,
                                                           binSize, chrID)
        coverage_by_chrom[chrID] = normed_coverage

    # Only the most recently used contact matrix is kept, so consecutive SVs on
    # the same chromosome share one load while memory stays at one matrix.
    loaded_chr = None
    contact_matrix = None

    # Define the type for each SV event
    for row in result.itertuples():
        print("________________________________________")
        cur_chr = row.chrom
        sv_number = str(row.Index + 1)

        # NOTE(review): assumes normCove returns a 1-D NumPy array with one
        # value per bin (boolean-mask indexing is used below) - confirm.
        cur_cover = coverage_by_chrom[cur_chr]
        all_mean = np.mean(cur_cover[cur_cover > 0])

        start_1 = int(row.pos1_start / binSize)
        end_1 = int(row.pos1_end / binSize)
        start_2 = int(row.pos2_start / binSize)
        end_2 = int(row.pos2_end / binSize)

        # Mean coverage over the bins with signal between both breakpoints.
        # If no such bin exists the mean is NaN, both comparisons below are
        # False and the SV falls through to the slope-based classification.
        sv_cover = cur_cover[start_1:end_2]
        sv_mean = np.mean(sv_cover[sv_cover > 0])
        print('SV_mean', sv_mean)
        print('all_mean', all_mean)

        if sv_mean > all_mean * 1.1:
            print("Num ", sv_number, 'duplication')
        elif sv_mean < all_mean * 0.9:
            print("Num ", sv_number, 'deletion')
        # calculate the slope
        else:
            if cur_chr != loaded_chr:
                contact_matrix = np.loadtxt(matrix_files[cur_chr])
                loaded_chr = cur_chr
            cur_matrix = contact_matrix[start_1:end_1, start_2:end_2]
            slope1, slope2 = cal_slope(cur_matrix)
            print(slope1, slope2)
            if (slope1 >= 0 and slope2 >= 0) or (slope1 <= 0 and slope2 <= 0):
                sv_type = 'unbalanced_translocation'
            elif slope1 < 0 < slope2:
                sv_type = 'inversion'
            else:
                sv_type = 'balanced_translocation'
            print("Num ", sv_number, sv_type)


# Run the SV-typing workflow only when this file is executed as a script.
if __name__ == '__main__':
    run()
