#!python

import numpy as np
import os
from collections import defaultdict
from scipy.ndimage import gaussian_filter
from sklearn.decomposition import PCA
import argparse, sys, os
from optparse import OptionParser

from itertools import groupby

"""
Recalculate SV breakpoints using a high-resolution Hi-C matrix.
Reference:
https://github.com/XiaoTaoWang/NeoLoopFinder
"""


def get_SV_intra_info(SV_file, binSize):
    '''
    Read intra-chromosomal HiSV calls and pad every breakpoint interval.

    Each non-blank line is split on tabs and read as
    ``chrom, bp1_start, bp1_end, bp2_start, bp2_end``. Both breakpoints are
    assigned to the chromosome found in the first column.

    :param SV_file: HiSV intra-chrom result file (tab-separated)
    :param binSize: padding subtracted from each start and added to each end
    :return: defaultdict(list) holding parallel lists under the keys
             'chrom1', 'bp1_start', 'bp1_end', 'chrom2', 'bp2_start', 'bp2_end'
    '''
    SV_info = defaultdict(list)
    # The context manager closes the handle even when a line fails to parse.
    with open(SV_file) as SV_line:
        for line in SV_line:
            line = line.strip()
            if not line:
                # Tolerate blank lines such as a trailing newline at end of file.
                continue
            fields = line.split('\t')
            SV_info['chrom1'].append(fields[0])
            SV_info['bp1_start'].append(int(fields[1]) - binSize)
            SV_info['bp1_end'].append(int(fields[2]) + binSize)
            # Intra-chromosomal call: the second breakpoint shares the chromosome.
            SV_info['chrom2'].append(fields[0])
            SV_info['bp2_start'].append(int(fields[3]) - binSize)
            SV_info['bp2_end'].append(int(fields[4]) + binSize)
    return SV_info


def get_SV_inter_info(SV_file, binSize):
    '''
    Read inter-chromosomal HiSV calls and pad every breakpoint interval.

    Each non-blank line is split on tabs and read as
    ``chrom1, bp1_start, bp1_end, chrom2, bp2_start, bp2_end``.

    :param SV_file: HiSV inter-chrom result file (tab-separated)
    :param binSize: padding subtracted from each start and added to each end
    :return: defaultdict(list) holding parallel lists under the keys
             'chrom1', 'bp1_start', 'bp1_end', 'chrom2', 'bp2_start', 'bp2_end'
    '''
    SV_info = defaultdict(list)
    # The context manager closes the handle even when a line fails to parse.
    with open(SV_file) as SV_line:
        for line in SV_line:
            line = line.strip()
            if not line:
                # Tolerate blank lines such as a trailing newline at end of file.
                continue
            fields = line.split('\t')
            SV_info['chrom1'].append(fields[0])
            SV_info['bp1_start'].append(int(fields[1]) - binSize)
            SV_info['bp1_end'].append(int(fields[2]) + binSize)
            SV_info['chrom2'].append(fields[3])
            SV_info['bp2_start'].append(int(fields[4]) - binSize)
            SV_info['bp2_end'].append(int(fields[5]) + binSize)
    return SV_info


def group(data):
    # Collect every index whose value differs from its right-hand neighbour,
    # i.e. the last position of each run of equal values (the final run has
    # no neighbour and is therefore never reported).
    return [
        idx
        for idx, (current, following) in enumerate(zip(data, data[1:]))
        if current != following
    ]


def _first_component_breakpoint(matrix, rowvar):
    '''
    Locate the strongest sign change of the first principal component.

    The correlation matrix of ``matrix`` (between rows when ``rowvar`` is
    True, between columns otherwise) is smoothed with a Gaussian filter,
    remaining NaNs are set to 0, and a 3-component whitened PCA is fitted.
    Only the first component is inspected.

    :param matrix: 2-D contact matrix of the SV region
    :param rowvar: forwarded to ``np.corrcoef``
    :return: index ``i`` such that the first component changes sign between
             ``i`` and ``i + 1`` with the largest absolute difference;
             0 when the component never changes sign
    '''
    # NOTE(review): NaNs are zeroed only after smoothing, so a NaN produced by
    # np.corrcoef may already have influenced neighbouring cells - confirm
    # this order is intended before changing it.
    corr = gaussian_filter(np.corrcoef(matrix, rowvar=rowvar), sigma=1)
    corr[np.isnan(corr)] = 0
    pca = PCA(n_components=3, whiten=True)
    first_component = pca.fit_transform(corr)[:, 0]

    # Candidate positions: indices where the sign differs from the next one.
    poslist = group(np.sign(first_component))
    if not poslist:
        return 0

    differ_value = [abs(first_component[cur_pos] - first_component[cur_pos + 1])
                    for cur_pos in poslist]
    # list.index returns the first maximum, so ties resolve to the earliest
    # candidate.
    return poslist[differ_value.index(max(differ_value))]


def recalculate_breakpoint(matrix):
    '''
    Estimate the breakpoint position inside an SV region matrix.

    The same procedure is applied once along the rows and once along the
    columns of ``matrix``.

    :param matrix: SV_region matrix (2-D array)
    :return: (up_row, up_col) - row index and column index of the breakpoint;
             each is 0 when no sign change is found along that axis
    '''
    up_row = _first_component_breakpoint(matrix, rowvar=True)
    up_col = _first_component_breakpoint(matrix, rowvar=False)
    return up_row, up_col


def get_SV_matrix(SV_region, matrix_file, binSize):
    '''
    Build the contact sub-matrix covering one SV region.

    The first line of ``matrix_file`` is discarded. Every following non-blank
    line is split on tabs; column 1 is read as the first position, column 3 as
    the second position and column 4 as the integer value to store.
    Columns 0 and 2 are not used here (presumably chromosome names - verify
    against the file producer).

    :param SV_region: dict with integer 'bp1_start', 'bp1_end', 'bp2_start'
                      and 'bp2_end' (HiSV result from the lower resolution
                      Hi-C matrix)
    :param matrix_file: high-resolution matrix file
    :param binSize: the bin size of high-resolution matrix
    :return: integer array of shape (rows, cols), where rows/cols are the
             number of bins spanned by breakpoint 1 / breakpoint 2; cells not
             present in the file stay 0
    '''
    # Hoist the region bounds: they are constant for the whole file scan.
    bp1_start, bp1_end = SV_region['bp1_start'], SV_region['bp1_end']
    bp2_start, bp2_end = SV_region['bp2_start'], SV_region['bp2_end']

    row_num = (bp1_end - bp1_start) // binSize + 1
    col_num = (bp2_end - bp2_start) // binSize + 1
    Init_matrix = np.full((row_num, col_num), 0)

    # The context manager closes the handle even when a line fails to parse.
    with open(matrix_file) as bed_list:
        bed_list.readline()  # discard the first line
        for line in bed_list:
            line = line.strip()
            if not line:
                # Tolerate blank lines such as a trailing newline at end of file.
                continue
            fields = line.split('\t')
            pos1 = int(fields[1])
            pos2 = int(fields[3])
            if bp1_start <= pos1 <= bp1_end and bp2_start <= pos2 <= bp2_end:
                # Convert genomic positions to bin offsets inside the region.
                cur_pos1 = (pos1 - bp1_start) // binSize
                cur_pos2 = (pos2 - bp2_start) // binSize
                Init_matrix[cur_pos1][cur_pos2] = int(fields[4])
    return Init_matrix


def getargs():
    '''
    Parse the command-line options of the breakpoint refinement script.

    When the script is started without any argument, ``-h`` is injected so
    that the help text is printed and the process exits. The options
    ``--hic_file``, ``--HiSV_result``, ``--name`` and ``--output`` are
    mandatory: the script uses all of them to build file paths, so a missing
    one is reported by argparse instead of failing later on a ``None`` value.

    :return: (args, commands) - the parsed namespace and the list of raw
             arguments that was parsed
    '''
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Determination of SV breakpoint''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--hic_file', required=True,
                        help='High resolution hi-C contact matrix file.')
    parser.add_argument('--binsize', type=int, default=5000,
                        help='Bin size of high-resolution Hi-C contact matrix.')
    parser.add_argument('--HiSV_result', required=True, help='Result file from HiSV')
    parser.add_argument('--order_bin', type=int, default=50000,
                        help='''Bin size of Hi-C contact matrix for HiSV.''')
    parser.add_argument('--name', required=True, help='''Sample name''')
    parser.add_argument('--output', required=True, help='''Output file path.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No arguments at all: show the help text rather than an error.
        commands.append('-h')
    args = parser.parse_args(commands)
    return args, commands


if __name__ == '__main__':
    # Script entry point: read the HiSV calls, check that every required
    # high-resolution matrix file exists, then refine each breakpoint pair and
    # write one line per SV to the output file.

    args, commands = getargs()
    hicfile = args.hic_file
    result_file = args.HiSV_result
    name = args.name
    binSize = args.binsize
    order_bin = args.order_bin
    output = args.output
    os.makedirs(output, exist_ok=True)
    if not os.path.exists(hicfile):
        print("high resolution Hi-C file does not exist.")
        sys.exit(-1)
    if not os.path.exists(result_file):
        print("HiSV result file does not exist.")
        sys.exit(-1)

    # Decide between inter- and intra-chromosomal parsing from the file name
    # only; testing the whole path would misclassify an intra result stored
    # under a directory whose name happens to contain "inter".
    if "inter" in os.path.basename(result_file):
        SV_info = get_SV_inter_info(result_file, order_bin)
        output_file = os.path.join(output, 'HiSV_inter_SV_breakpoint.txt')
    else:
        SV_info = get_SV_intra_info(result_file, order_bin)
        output_file = os.path.join(output, 'HiSV_intra_SV_breakpoint.txt')

    # One matrix file per SV, named <name>_<binSize in kb>kb_<chrom1>_<chrom2>_bed.txt
    # inside the hicfile directory. Built once and reused by both passes below.
    matrix_files = [
        os.path.join(hicfile, name + '_%skb_%s_%s_bed.txt' % (str(binSize // 1000),
                                                             chrom1, chrom2))
        for chrom1, chrom2 in zip(SV_info['chrom1'], SV_info['chrom2'])
    ]

    # Determine whether the high-resolution hi-c matrix file is consistent with the HiSV result
    for matrix_file in matrix_files:
        if not os.path.exists(matrix_file):
            print("high-resolution matrix file does not exist.")
            sys.exit(-1)

    # Re-determine SV breakpoint by high-resolution Hi-C matrix
    # NOTE(review): the file is opened in append mode, so re-running into the
    # same output directory adds to earlier results - confirm this is intended.
    with open(output_file, 'a') as out:
        for i, matrix_file in enumerate(matrix_files):
            chrom1 = SV_info['chrom1'][i]
            chrom2 = SV_info['chrom2'][i]
            SV_region = {'bp1_start': SV_info['bp1_start'][i], 'bp1_end': SV_info['bp1_end'][i],
                         'bp2_start': SV_info['bp2_start'][i], 'bp2_end': SV_info['bp2_end'][i]}
            matrix = get_SV_matrix(SV_region, matrix_file, binSize)
            row_pos, column_pos = recalculate_breakpoint(matrix)
            # Convert bin offsets inside the padded region back to coordinates.
            start_pos = row_pos * binSize + SV_info['bp1_start'][i]
            end_pos = column_pos * binSize + SV_info['bp2_start'][i]
            print('num ', i+1, "start: ", start_pos, "end: ", end_pos)
            out.write(str(chrom1) + '\t' + str(start_pos) + '\t' + str(chrom2) + '\t' + str(end_pos) + '\n')

