#!/usr/bin/env python
# running HiSV from .hic format
import numpy as np
import math
import os, argparse, datetime, sys
import pandas as pd
from multiprocessing import Pool, Manager
from hisv.utils import *
import hicstraw
import prox_tv as ptv
import multiprocessing

def point2region(df, resolution, index_mapping, chrom1, chrom2, outfile):
    """
    Group the candidate SV points in ``df`` into regions and append them to ``outfile``.

    The grouping itself is delegated to ``group_position`` (not defined in this
    file; presumably supplied by the ``hisv.utils`` star import - its merging
    rule is not visible here).

    :param df: DataFrame with 'pos1' and 'pos2' columns holding filtered-matrix indices
    :param resolution: bin size used to convert bin indices to genomic coordinates
    :param index_mapping: dict with 'row_mapping' and 'col_mapping', each mapping a
                          filtered index back to the original bin index
    :param chrom1: chromosome name written in the first column
    :param chrom2: chromosome name written in the fourth column
    :param outfile: path of the tab-separated file that is opened in append mode
    :return: None
    """
    grouped = group_position(df['pos1'].tolist(), df['pos2'].tolist())
    row_map = index_mapping['row_mapping']
    col_map = index_mapping['col_mapping']

    lines = []
    for pos1_start, pos1_end, pos2_start, pos2_end in zip(grouped['pos1_start'], grouped['pos1_end'],
                                                          grouped['pos2_start'], grouped['pos2_end']):
        fields = (
            chrom1,
            row_map[pos1_start] * resolution,
            # end coordinate is the right edge of the last bin, hence "+ resolution"
            row_map[pos1_end] * resolution + resolution,
            chrom2,
            col_map[pos2_start] * resolution,
            col_map[pos2_end] * resolution + resolution,
        )
        # Terminate every record with a newline; without it all records that are
        # appended to the file end up concatenated on one single line.
        lines.append('\t'.join(str(field) for field in fields) + '\n')

    with open(outfile, 'a') as out:
        out.writelines(lines)

def local_saliency(matrix, win):
    """
    # compute the local saliency map of a contact matrix
    For every non-zero cell (r, c) the window matrix[r-d:r+d, c-d:c+d] with
    d = int(win / 2) is compared against the cell value; the saliency is
    1 - exp(-mean(|value - window| / (1 + dist))) where dist[m][n] = sqrt(m^2 + n^2).
    Cells whose window is not exactly win x win (borders, or any cell when win
    is odd, because the slice spans 2 * d bins) or whose window sums to zero stay 0.
    :param matrix: current Hi-C contact matrix (2-D numpy array)
    :param win: Local region window.
    :return: saliency map with the same shape as matrix
    """
    half = int(win / 2)
    n_rows, n_cols = len(matrix), len(matrix[0])
    saliency = np.full((n_rows, n_cols), 0.0)

    # weights[m][n] = 1 + sqrt(m^2 + n^2); it never changes, so build it once
    weights = np.full((win, win), 1.0)
    steps = np.arange(win)
    weights += np.sqrt(np.square(steps)[:, None] + np.square(steps)[None, :])

    for position, value in np.ndenumerate(matrix):
        if value == 0:
            continue
        r, c = position
        patch = matrix[r - half:r + half, c - half:c + half]
        is_full_window = patch.shape[0] == win and patch.shape[1] == win
        if is_full_window and np.sum(patch) != 0:
            saliency[position] = 1 - math.exp(-np.mean(abs(value - patch) / weights))
    return saliency


def Calculating_diagonal_data(matrix):
    """
    # per-diagonal statistics used to remove the distance bias
    For every offset d in range(N) (N = number of rows) collect the cells
    matrix[r][r + d] for r < N - d and store their mean and standard deviation.
    :param matrix: contact matrix, indexed as matrix[row][col]
    :return: (Diagonal_mean, Diagonal_std), two arrays whose length is the
             number of columns of matrix
    """
    n_rows, n_cols = len(matrix), len(matrix[0])
    Diagonal_mean = np.full(n_cols, 0.0)
    Diagonal_std = np.full(n_cols, 0.0)
    for offset in range(n_rows):
        band = np.array([matrix[row][row + offset] for row in range(n_rows - offset)])
        Diagonal_mean[offset] = np.mean(band)
        Diagonal_std[offset] = np.std(band)
    return Diagonal_mean, Diagonal_std


def Distance_normalization(matrix):
    """
    # z-score every diagonal of the upper triangle to remove the distance bias
    norm_data = (data - diagonal_mean) / diagonal_std
    Cells below their diagonal mean, and whole diagonals whose std is 0, are set to 0.
    The matrix is modified in place and also returned.
    :param matrix: contact matrix, indexed as matrix[row][col]
    :return: the same matrix object after normalization
    """
    diag_mean, diag_std = Calculating_diagonal_data(matrix)
    size = len(matrix)
    for offset in range(size):
        mu = diag_mean[offset]
        sigma = diag_std[offset]
        for row in range(size - offset):
            col = row + offset
            value = matrix[row][col]
            if sigma == 0 or value - mu < 0:
                matrix[row][col] = 0
            else:
                matrix[row][col] = (value - mu) / sigma
    return matrix


def calculate_bins_to_delete(gap_df):
    """
    Collect every bin index covered by the gap regions.

    :param gap_df: DataFrame with 'start_bin' and 'end_bin' columns; both ends
                   are inclusive
    :return: set of int bin indices
    """
    bins_to_delete = set()
    # Read the two columns directly instead of DataFrame.iterrows(): iterrows()
    # builds one Series per row and therefore upcasts the integer bins to float
    # as soon as the frame holds any float column, which makes range() raise
    # TypeError.
    for start_bin, end_bin in zip(gap_df["start_bin"], gap_df["end_bin"]):
        bins_to_delete.update(range(int(start_bin), int(end_bin) + 1))
    return bins_to_delete


def extend_bins(bins_set, max_bins, zero_win=10):
    """
    Pad every bin index with its neighbours on both sides.

    Each index b contributes the range [b - zero_win, b + zero_win], clipped
    to [0, max_bins - 1].

    :param bins_set: iterable of integer bin indices
    :param max_bins: total number of bins (exclusive upper bound of the result)
    :param zero_win: number of neighbouring bins added on each side; the default
                     of 10 is the value that used to be hard-coded
    :return: set of int bin indices
    """
    extended_bins = set()
    for bin_idx in bins_set:
        start = max(0, bin_idx - zero_win)
        end = min(max_bins, bin_idx + zero_win + 1)
        extended_bins.update(range(start, end))
    return extended_bins


def filter_zero_region_intra(matrix):
    """
    drop the zero regions of an intra-chromosomal contact matrix
    Rows whose sum is 0 are padded with extend_bins and the resulting bins are
    removed from both the rows and the columns.
    :param matrix: square 2-D numpy contact matrix
    :return: (filtered matrix, index_mapping) where index_mapping maps a
             filtered index to the original bin index (shared by rows and cols)
    """
    n_bins = matrix.shape[0]
    empty_rows = set(np.where(matrix.sum(axis=1) == 0)[0])
    dropped = extend_bins(empty_rows, n_bins)
    kept = [bin_idx for bin_idx in range(n_bins) if bin_idx not in dropped]
    index_mapping = dict(enumerate(kept))
    return matrix[np.ix_(kept, kept)], index_mapping


def filter_gap_region_inter(matrix):
    """
    drop the zero regions of an inter-chromosomal contact matrix
    Zero-sum rows and zero-sum columns are each padded with extend_bins and
    removed independently.
    :param matrix: 2-D numpy contact matrix (rows: first chromosome, cols: second)
    :return: (filtered matrix, index_mapping) where index_mapping holds
             'row_mapping' and 'col_mapping', each mapping a filtered index to
             the original bin index
    """
    n_rows, n_cols = np.shape(matrix)
    empty_rows = set(np.where(matrix.sum(axis=1) == 0)[0])
    empty_cols = set(np.where(matrix.sum(axis=0) == 0)[0])
    dropped_rows = extend_bins(empty_rows, n_rows)
    dropped_cols = extend_bins(empty_cols, n_cols)

    kept_rows = [bin_idx for bin_idx in range(n_rows) if bin_idx not in dropped_rows]
    kept_cols = [bin_idx for bin_idx in range(n_cols) if bin_idx not in dropped_cols]

    index_mapping = {
        "row_mapping": dict(enumerate(kept_rows)),
        "col_mapping": dict(enumerate(kept_cols)),
    }
    return matrix[np.ix_(kept_rows, kept_cols)], index_mapping


def process_chrom_pair_hic(chrom_names, i, j, win, weight, cutoff):
    """
    Load, filter and call SVs for one chromosome pair.

    Reads the module-level names ``hicfile``, ``intra_resolution`` and
    ``inter_resolution`` (assigned in the ``__main__`` section of this file).

    :param chrom_names: list of chromosome names
    :param i: index of the first chromosome in chrom_names
    :param j: index of the second chromosome; i == j selects the intra-chromosomal path
    :param win: local region window passed on to the SV caller
    :param weight: regularization parameter passed on to the SV caller
    :param cutoff: threshold passed on to the SV caller
    :return: list of SV records, empty when there is nothing to analyse
    """
    if i == j:
        cur_chrom = chrom_names[i]
        print(f"Processing {cur_chrom} intra-chromosomal")
        matrix = load_matrix_from_hic(hicfile, cur_chrom, cur_chrom, intra_resolution)
        if np.sum(matrix) == 0:
            return []
        # filter gap region
        matrix, index_mapping = filter_zero_region_intra(matrix)
        # The zero-region padding can remove every bin (e.g. a short contig with
        # a single empty row). Calling the SV caller on a 0x0 matrix raises
        # IndexError, so stop here just like the inter-chromosomal path does.
        if matrix.size == 0:
            return []
        # call SV
        return call_SV_intra_matrix(matrix, index_mapping, cur_chrom, win, weight, cutoff)

    chrom1, chrom2 = chrom_names[i], chrom_names[j]
    print(f"Processing {chrom1} and {chrom2} inter-chromosomal")
    matrix = load_matrix_from_hic(hicfile, chrom1, chrom2, inter_resolution)
    # filter gap region
    matrix, index_mapping = filter_gap_region_inter(matrix)
    # an empty matrix also sums to 0, so this covers the "everything filtered" case
    if np.sum(matrix) == 0:
        return []
    # call SV
    return call_SV_inter_matrix(matrix, index_mapping, chrom1, chrom2, win, weight, cutoff)

def call_SV_intra_matrix(matrix, index_mapping, chrom, win, weight, cutoff):
    """
    call SV from an intra-chromosomal contact matrix
    Steps: zero the lower triangle (main diagonal included), distance
    normalization, local saliency, TV segmentation, thresholding, grouping.
    The input matrix is modified in place. Reads the module-level name
    intra_resolution when converting bins to coordinates.
    :param matrix: square 2-D numpy contact matrix after zero-region filtering
    :param index_mapping: dict mapping a filtered index to the original bin index
    :param chrom: chromosome name
    :param win: local region window
    :param weight: regularization weight handed to ptv.tv1_2d
    :param cutoff: cells of the segmented map above this value are kept
    :return: list of [chrom, start1, end1, chrom, start2, end2] records
    """
    size = np.shape(matrix)[0]
    matrix[np.tril_indices(size)] = 0
    # z-score for distance normalization
    normalized = Distance_normalization(matrix)
    # calculating local saliency
    saliency = local_saliency(normalized, win)
    # Segmentation
    seg_mat = ptv.tv1_2d(saliency, weight, n_threads=1, max_iters=0, method='dr')

    # Filter
    hits = []
    for m, n in np.ndindex(size, size):
        if seg_mat[m][n] > cutoff:
            print(chrom, index_mapping[m], index_mapping[n], seg_mat[m][n])
            hits.append((m, n))

    combine_result = []
    if not hits:
        return combine_result

    # Combine
    grouped = group_position([m for m, _ in hits], [n for _, n in hits])
    for pos1_start, pos1_end, pos2_start, pos2_end in zip(grouped['pos1_start'], grouped['pos1_end'],
                                                          grouped['pos2_start'], grouped['pos2_end']):
        combine_result.append([
            chrom,
            index_mapping[pos1_start] * intra_resolution,
            index_mapping[pos1_end] * intra_resolution + intra_resolution,
            chrom,
            index_mapping[pos2_start] * intra_resolution,
            index_mapping[pos2_end] * intra_resolution + intra_resolution,
        ])
    return combine_result
    


def call_SV_inter_matrix(matrix, index_mapping, chrom1, chrom2, win, weight, cutoff):
    """
    call SV from an inter-chromosomal contact matrix
    Steps: local saliency, TV segmentation, thresholding, grouping.
    Reads the module-level name inter_resolution when converting bins to
    coordinates.
    :param matrix: 2-D numpy contact matrix after zero-region filtering
    :param index_mapping: dict with 'row_mapping' and 'col_mapping', each
                          mapping a filtered index to the original bin index
    :param chrom1: chromosome name of the rows
    :param chrom2: chromosome name of the columns
    :param win: local region window
    :param weight: regularization weight handed to ptv.tv1_2d
    :param cutoff: cells of the segmented map above this value are kept
    :return: list of [chrom1, start1, end1, chrom2, start2, end2] records
    """
    n_rows, n_cols = np.shape(matrix)
    # calculating local saliency
    saliency = local_saliency(matrix, win)
    # Segmentation
    seg_mat = ptv.tv1_2d(saliency, weight, n_threads=1, max_iters=0, method='dr')

    # Filter
    hits = [(m, n) for m, n in np.ndindex(n_rows, n_cols) if seg_mat[m][n] > cutoff]

    combine_result = []
    if not hits:
        return combine_result

    # combine result
    row_map = index_mapping['row_mapping']
    col_map = index_mapping['col_mapping']
    grouped = group_position([m for m, _ in hits], [n for _, n in hits])
    for pos1_start, pos1_end, pos2_start, pos2_end in zip(grouped['pos1_start'], grouped['pos1_end'],
                                                          grouped['pos2_start'], grouped['pos2_end']):
        combine_result.append([
            chrom1,
            row_map[pos1_start] * inter_resolution,
            row_map[pos1_end] * inter_resolution + inter_resolution,
            chrom2,
            col_map[pos2_start] * inter_resolution,
            col_map[pos2_end] * inter_resolution + inter_resolution,
        ])
    return combine_result


def run_HiSV_hic(hic_file, intra_resolution, inter_resolution, result_file, win, weight, cut_off, ncores):
    """
    Call SVs for every chromosome pair of a .hic file and write them to result_file.

    Both resolutions are validated against the resolutions stored in the file.
    They are not forwarded to the worker tasks, which receive only
    (chrom_names, i, j, win, weight, cut_off).

    :param hic_file: path of the .hic file
    :param intra_resolution: resolution expected for intra-chromosomal matrices
    :param inter_resolution: resolution expected for inter-chromosomal matrices
    :param result_file: path of the tab-separated output file
    :param win: local region window
    :param weight: regularization parameter
    :param cut_off: threshold for filtering SV segments
    :param ncores: number of worker processes handed to multiprocessing.Pool
    :return: "Parameter Error." when a resolution is not available, otherwise None
    """
    hic = hicstraw.HiCFile(hic_file)
    available = hic.getResolutions()
    # error resolution parameter
    for requested in (intra_resolution, inter_resolution):
        if requested not in available:
            print("The resolution parameter is inconsistent with the resolution of the HiC data.")
            return "Parameter Error."

    # get chrom list
    skipped = ['All', 'Y', 'M', 'MT', 'ALL', 'chrY', 'chrM']
    chrom_names = []
    for chrom in hic.getChromosomes():
        name = chrom.name
        if name not in skipped:
            chrom_names.append(name)

    # one task per unordered chromosome pair, the pair (i, i) included
    n_chroms = len(chrom_names)
    tasks = [(chrom_names, i, j, win, weight, cut_off)
             for i in range(n_chroms)
             for j in range(i, n_chroms)]
    with multiprocessing.Pool(processes=ncores) as pool:
        results = pool.starmap(process_chrom_pair_hic, tasks)

    rows = [record for group in results for record in group]
    result_df = pd.DataFrame(rows, columns=['chrom1', 'start_pos1', 'start_pos2', 'chrom2', 'end_pos1', 'end_pos2'])
    result_df.to_csv(result_file, sep='\t', index=False)

def getargs():
    """
    Build the command-line parser and parse sys.argv.

    When no argument is given, '-h' is used so that argparse prints the help
    text and exits.

    :return: (args, commands) - the parsed namespace and the argument list that was parsed
    """
    parser = argparse.ArgumentParser(
        description='''Identification of SVs based on Hi-C interaction matrix.''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flag, add_argument keyword arguments), registered in this order
    option_specs = [
        ('--hic_file', {'help': 'The hi-c file based on .hic format.'}),
        ('--intra_resolution', {'type': int, 'default': 50000,
                                'help': 'Resolution of intra Hi-C contact matrix.'}),
        ('--inter_resolution', {'type': int, 'default': 100000,
                                'help': 'Resolution of inter Hi-C contact matrix.'}),
        ('--window', {'type': int, 'default': 10, 'help': '''Local region window.'''}),
        ('--regularization', {'type': float, 'default': 0.2, 'help': 'The regularization parameter'}),
        ('--cutoff', {'type': float, 'default': 0.6, 'help': '''Threshold for filtering SV segments'''}),
        ('--cores', {'type': int, 'help': 'The number of cores used for parallel computing'}),
        ('--output', {'help': '''Output file path.'''}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    commands = sys.argv[1:] or ['-h']
    args = parser.parse_args(commands)
    return args, commands

if __name__ == '__main__':
    # Script entry point: parse the command line, run the SV calling on the
    # .hic file and report the elapsed wall-clock time.

    start_t = datetime.datetime.now()
    args, commands = getargs()

    result_file = args.output
    # hicfile, intra_resolution and inter_resolution are module-level names:
    # process_chrom_pair_hic reads all three and call_SV_intra_matrix /
    # call_SV_inter_matrix read the resolutions as globals, so they must keep
    # exactly these names.
    # NOTE(review): the pool workers presumably only see these values when they
    # inherit this module's state (fork start method) - confirm on platforms
    # that spawn fresh interpreters.
    hicfile = args.hic_file
    intra_resolution = args.intra_resolution
    inter_resolution = args.inter_resolution
    win = args.window
    cutoff = args.cutoff
    reg = args.regularization
    # None when --cores is omitted (the option has no default)
    n_cores = args.cores
    run_HiSV_hic(hic_file=hicfile, intra_resolution=intra_resolution, inter_resolution=inter_resolution, 
                 result_file=result_file, win=win, weight=reg, cut_off=cutoff, ncores=n_cores)

    end_t = datetime.datetime.now()
    elapsed_sec = (end_t - start_t).total_seconds()
    print("running time: " + "{:.2f}".format(elapsed_sec) + " seconds")
   