#!python
# -*- coding: utf-8 -*-
# Copyright(c) Ryuichiro Nakato <rnakato@iqb.u-tokyo.ac.jp>
# All rights reserved.

import os
os.environ["MPLBACKEND"] = "Agg"
os.environ["MPLCONFIGDIR"] = "/tmp/mplconfig"
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy.cluster.hierarchy import linkage, fcluster, leaves_list
from custardpy.HiCmodule import JuicerMatrix

def getbedGraph(file, label):
    """Load a tab-separated bedGraph file as a single-column DataFrame.

    Column 1 of the file (start position) becomes the index, columns 0
    (chromosome) and 2 (end) are discarded, and the remaining value column
    is renamed to ``label``.
    """
    table = pd.read_csv(file, sep="\t", header=None, index_col=[1])
    return table.drop(columns=[0, 2]).set_axis([label], axis=1)

def get_boundary(boundaryfile):
    """Read a tab-separated BED file of boundaries.

    Only the first three BED columns are read, so files that carry extra
    columns (name, score, strand, ...) are accepted; assigning three column
    names to a wider table would otherwise raise a length-mismatch error.

    Args:
        boundaryfile: path to a headerless, tab-separated BED file.

    Returns:
        DataFrame with columns ``chromosome``, ``start`` and ``end``.
    """
    boundary = pd.read_csv(boundaryfile, sep="\t", header=None, usecols=[0, 1, 2])
    boundary.columns = ["chromosome", "start", "end"]
    return boundary

def get_samples(dirs, labels, chr, type, resolution, cooler=False):
    """Collect the insulation score of every sample on one chromosome.

    For each input directory a contact-matrix path is built (cooler or
    Juicer directory layout). If that matrix exists, the score is computed
    through ``JuicerMatrix(...).getInsulationScore(distance=500000)`` and
    indexed by ``bin * resolution``; otherwise a precomputed bedGraph under
    ``InsulationScore/`` is loaded instead. Each path used is printed.

    Args:
        dirs: sample directories.
        labels: column name for each directory (same order as ``dirs``).
        chr: chromosome name.
        type: normalization type used in the Juicer-layout paths.
        resolution: bin size in bp.
        cooler: use the cooler-format directory layout.

    Returns:
        DataFrame indexed by genomic position with one column per sample.
    """
    samples = pd.DataFrame()

    for idx, sample_dir in enumerate(dirs):
        matrix_path = (
            f"{sample_dir}/Matrix/{resolution}/balanced.{chr}.matrix.gz"
            if cooler
            else f"{sample_dir}/Matrix/intrachromosomal/{resolution}/observed.{type}.{chr}.matrix.gz"
        )

        # NOTE(review): matrix_path always ends in ".matrix.gz", so this
        # suffix test is never true as written; kept for parity.
        is_container = matrix_path.endswith(('.hic', '.cool', '.mcool'))

        if not (is_container or os.path.exists(matrix_path)):
            track_path = f"{sample_dir}/InsulationScore/{type}/{resolution}/Insulationscore.{chr}.500k.{resolution}.bedGraph"
            print(track_path)
            column = getbedGraph(track_path, labels[idx])
        else:
            print(matrix_path)
            matrix = JuicerMatrix("RPM", matrix_path, resolution, chrom=chr)
            scores = matrix.getInsulationScore(distance=500000)
            column = pd.Series(scores, index=np.arange(len(scores)) * resolution, name=labels[idx])

        samples = pd.concat([samples, column], axis=1)

    return samples
    

def plot_insulation_score(samples, boundary, chr, chrlen, labels, odir):
    """Plot per-sample insulation scores and their differences from the control.

    Writes ``Insulation_score.<chr>.pdf`` and ``Insulation_score.diff.<chr>.pdf``.
    ``odir`` is used as a plain string prefix, so it should end with a path
    separator.

    Args:
        samples: insulation scores indexed by genomic position, one column per label.
        boundary: boundaries of this chromosome; the start/end columns (positions
            2 and 3 of each ``itertuples()`` row) are drawn as horizontal segments.
        chr: chromosome name (used in the output file names).
        chrlen: chromosome length in bp (x-axis range).
        labels: sample column names; ``labels[0]`` is the control.
        odir: output path prefix.
    """
    step = 1000000
    # One inch of width per 2 Mb, but never 0: chromosomes shorter than 2 Mb
    # would otherwise get a zero-width figure.
    fig_width = max(1, int(chrlen / 2000000))
    xticks = np.arange(0, int(chrlen / step) * step, step)

    def _draw_boundaries(y):
        # One horizontal segment per boundary, at height y.
        for row in boundary.itertuples():
            plt.hlines([y], row[2], row[3], "blue")

    def _finish(path):
        # Shared axis/legend setup, then save and release the figure.
        plt.xlim(0, chrlen)
        plt.xticks(xticks)
        plt.legend(bbox_to_anchor=(1.01, 1.0), loc='upper left')
        plt.tight_layout()
        plt.savefig(path)
        plt.close()

    # Raw scores of all samples.
    plt.figure(figsize=(fig_width, 3))
    for sample in labels:
        plt.plot(samples.index, samples[sample], label=sample)
    _draw_boundaries(0)
    _finish(odir + "Insulation_score." + chr + ".pdf")

    # Difference of each treated sample from the control (labels[0]).
    plt.figure(figsize=(fig_width, 3))
    for sample in labels[1:]:
        plt.plot(samples.index, samples[sample] - samples[labels[0]], label=sample)
    _draw_boundaries(-.3)
    _finish(odir + "Insulation_score.diff." + chr + ".pdf")

def annotate_boundary(d, diffsamples, labels, thre=0.13, thre_insu=-0.13, control_min=0.8):
    """Classify each boundary as 'Gain', 'Loss', 'Non-boundary' or 'Robust'.

    ``d`` is modified in place (a ``status`` column is added) and returned.

    A boundary is 'Gain' when enough treated samples have a difference
    <= ``thre_insu``, otherwise 'Loss' when enough have a difference
    > ``thre``; "enough" is half of the treated samples (rounded down) but at
    least one. Remaining boundaries are 'Non-boundary' when the control score
    is >= ``control_min`` and 'Robust' otherwise.

    Args:
        d: per-boundary table containing column ``labels[0]`` (control score).
        diffsamples: per-boundary differences (treated - control), one column
            per treated sample; rows are looked up with the labels of ``d.index``.
        labels: sample labels; ``labels[0]`` is the control.
        thre: difference threshold counting toward 'Loss'.
        thre_insu: difference threshold counting toward 'Gain'.
        control_min: control-score threshold for 'Non-boundary'.

    Returns:
        ``d`` with the ``status`` column filled in.
    """
    d['status'] = 'unknown'
    # Floor at 1: with a single treated sample int(0.5) == 0, which made the
    # 'Gain' test (count >= 0) always true and labelled everything 'Gain'.
    num_half = max(1, int(diffsamples.shape[1] * 0.5))
    control = labels[0]

    for posi in d.index:
        row = diffsamples.loc[posi, :]
        if (row <= thre_insu).sum() >= num_half:
            d.loc[posi, 'status'] = 'Gain'
        elif (row > thre).sum() >= num_half:
            d.loc[posi, 'status'] = 'Loss'
        elif d.loc[posi, control] >= control_min:
            d.loc[posi, 'status'] = 'Non-boundary'
        else:
            d.loc[posi, 'status'] = 'Robust'

    return d


def get_averagedIS_for_boundaries(samples, df, boundary):
    """Average the insulation score inside each boundary (legacy variant).

    For every boundary row, rows of ``df`` whose index lies in [start, end)
    are averaged per column. The averages are joined to ``boundary``, the
    result is re-indexed by the boundary start, and only boundaries whose
    mean over ``samples`` exceeds 0.3 are kept.
    """
    merged = pd.DataFrame(index=samples)
    for rec in boundary.itertuples():
        lo, hi = rec[2], rec[3]
        in_range = (df.index >= lo) & (df.index < hi)
        merged = pd.concat([merged, df[in_range].mean(axis=0)], axis=1)

    merged = merged.T
    merged.index = boundary.index
    merged = pd.concat([boundary, merged], axis=1)
    merged.index = boundary["start"]

    keep = merged.loc[:, samples].mean(axis=1) > 0.3
    return merged[keep]

def calculate_mean_within_boundaries(df, start, end):
    """Return the column-wise mean of the rows whose index lies in [start, end)."""
    positions = df.index
    inside = (positions >= start) & (positions < end)
    return df.loc[inside].mean()

def get_averaged_insulation_score_for_boundaries(samples, df, boundary, min_mean=0.3):
    """Average the insulation score inside each boundary.

    For every boundary, rows of ``df`` whose index lies in [start, end) are
    averaged per column. The averages are appended to the boundary table,
    which is then indexed by the boundary start; boundaries whose mean over
    ``samples`` is not above ``min_mean`` are dropped.

    Args:
        samples: column names of ``df`` used for the ``min_mean`` filter.
        df: insulation scores indexed by genomic position.
        boundary: boundary table; ``itertuples()`` positions 2 and 3 are start/end.
        min_mean: keep boundaries whose mean score over ``samples`` exceeds this.

    Returns:
        DataFrame with the boundary columns plus one averaged column per
        column of ``df``, indexed by boundary start. Empty (but with the same
        columns) when ``boundary`` has no rows.
    """
    results = [
        calculate_mean_within_boundaries(df, row[2], row[3])
        for row in boundary.itertuples()
    ]

    if results:
        averaged = pd.concat(results, axis=1).T
    else:
        # pd.concat([]) raises, e.g. for a chromosome without boundaries.
        averaged = pd.DataFrame(columns=df.columns, dtype=float)

    averaged.index = boundary.index
    all_df = pd.concat([boundary, averaged], axis=1)
    all_df.index = boundary["start"]
    return all_df[all_df.loc[:, samples].mean(axis=1) > min_mean]


def annotate_boundary_chromosome(chr, chrlen, samples, boundary, labels):
    """Plot, average and classify the boundaries of one chromosome.

    Plots are written under the module-level ``odir`` (bound by the
    command-line entry point).

    Args:
        chr: chromosome name.
        chrlen: chromosome length in bp.
        samples: insulation scores indexed by position, one column per label.
        boundary: boundaries on this chromosome (chromosome/start/end).
        labels: sample labels; ``labels[0]`` is the control.

    Returns:
        (df_boundary, diffsamples): the filtered per-boundary table with a
        ``status`` column, and the differences of every treated sample from
        the control (columns ``labels[1:]``), row-aligned with ``df_boundary``.

    Raises:
        ValueError: if fewer than two labels are given.
    """
    if len(labels) < 2:
        # Differences are taken against labels[0]; without a treated sample
        # there is nothing to compare.
        raise ValueError("at least two samples are required; labels[0] is the control")

    plot_insulation_score(samples, boundary, chr, chrlen, labels, odir)

    df_boundary = get_averaged_insulation_score_for_boundaries(labels, samples, boundary)

    # Subtract the control from every treated column in one aligned operation;
    # the row labels of df_boundary are preserved for the lookups below.
    diffsamples = df_boundary[list(labels[1:])].sub(df_boundary[labels[0]], axis=0)

    df_boundary = annotate_boundary(df_boundary, diffsamples, labels)

    return df_boundary, diffsamples


def plot_heatmap_diffsamples(df_boundary_all, diff_boundary_all, num_clusters):
    """Cluster boundaries by differential insulation score and plot heatmaps.

    Writes ``heatmap.diff.correlation.pdf`` (only when there are at least two
    treated samples) and ``heatmap.diff.clustered.pdf`` under the
    module-level ``odir`` (bound by the command-line entry point).

    Args:
        df_boundary_all: per-boundary table whose rows are in the same order
            as ``diff_boundary_all``; a ``cluster_labels`` column is added to it.
        diff_boundary_all: boundaries x treated samples (difference from the
            control). It is not modified.
        num_clusters: maximum number of boundary clusters (``fcluster`` maxclust).

    Returns:
        ``df_boundary_all`` with ``cluster_labels``, sorted by cluster (rows
        keep their original relative order within a cluster).
    """
    n_samples = diff_boundary_all.shape[1]

    # Sample-vs-sample correlation. clustermap builds its own figure, and its
    # dendrograms need at least two samples to cluster.
    if n_samples >= 2:
        sns.clustermap(diff_boundary_all.corr(), cmap="bwr")
        plt.savefig(odir + "heatmap.diff.correlation.pdf", bbox_inches="tight")
        plt.close()
    else:
        print("Skip heatmap.diff.correlation.pdf: fewer than two treated samples.")

    # Cluster boundaries (rows) with Ward linkage.
    row_linkage = linkage(diff_boundary_all, 'ward')
    cluster_labels = fcluster(row_linkage, num_clusters, criterion='maxclust')

    # Order samples (columns) from the data alone; the integer cluster-label
    # column must not take part in this clustering. linkage needs at least
    # two observations, so a single sample keeps its position.
    if n_samples >= 2:
        col_order = leaves_list(linkage(diff_boundary_all.T, 'ward'))
    else:
        col_order = np.arange(n_samples)

    clustered = diff_boundary_all.iloc[:, col_order].copy()
    clustered['cluster_labels'] = cluster_labels
    clustered = clustered.sort_values('cluster_labels', kind='stable')

    plt.figure()
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 12])

    # Left strip: cluster membership of each (sorted) boundary.
    ax0 = plt.subplot(gs[0])
    ax0.matshow(clustered['cluster_labels'].values.reshape(-1, 1), cmap='Accent', aspect='auto')
    ax0.set_xticks([])
    ax0.set_yticks([])
    ax0.set_title("Cluster")
    for cluster_id in np.unique(cluster_labels):
        indices = np.where(clustered['cluster_labels'] == cluster_id)[0]
        # matshow centres row i at y=i, so the middle of a run is (first+last)/2.
        center_index = (indices.min() + indices.max()) / 2
        ax0.text(-1, center_index, f"{cluster_id}", va='center', ha='right', color='black')

    # Right: the differential scores in cluster order.
    ax1 = plt.subplot(gs[1])
    sns.heatmap(clustered.drop('cluster_labels', axis=1), cmap='bwr_r', yticklabels=False, cbar_kws={'label': 'Value'}, ax=ax1)
    ax1.set_title("Heatmap (differential insulation score)")
    plt.savefig(odir + "heatmap.diff.clustered.pdf", bbox_inches="tight")
    plt.close()

    # cluster_labels is in the original row order, which matches df_boundary_all.
    df_boundary_all['cluster_labels'] = cluster_labels
    df_boundary_all = df_boundary_all.sort_values('cluster_labels', kind='stable')

    return df_boundary_all

def annotate_boundary_genome(boundary, dirs, labels, resolution, genometable, cooler=False,
                             norm_type=None, exclude_chroms=("chrY", "chrM")):
    """Annotate boundaries chromosome by chromosome and stack the results.

    Args:
        boundary: genome-wide boundary table (chromosome/start/end).
        dirs: sample directories; ``dirs[0]`` is the control.
        labels: sample labels, same order as ``dirs``.
        resolution: bin size in bp.
        genometable: table indexed by chromosome name whose first column is
            the chromosome length.
        cooler: use the cooler-format directory layout.
        norm_type: normalization type used in the input paths. When None, the
            module-level ``type`` bound by the command-line entry point is
            used, or "SCALE" (the CLI default) if it is not bound.
        exclude_chroms: chromosome names to skip.

    Returns:
        (df_boundary_all, diff_boundary_all): per-boundary annotations and
        per-boundary differences from the control, stacked over chromosomes.
        Both are empty DataFrames when no chromosome is processed.
    """
    if norm_type is None:
        # The CLI stores --type in the module-global ``type``. A plain
        # reference to ``type`` would resolve to the builtin class when this
        # module is imported, and "<class 'type'>" would end up in the paths.
        norm_type = globals().get("type", "SCALE")

    annotated = []
    diffs = []
    for row in genometable.itertuples():
        chr = row[0]
        chrlen = row[1]

        if chr in exclude_chroms:
            continue
        print(chr)

        boundary_chr = boundary[boundary["chromosome"] == chr]
        samples = get_samples(dirs, labels, chr, norm_type, resolution, cooler=cooler)
        df, df_diff = annotate_boundary_chromosome(chr, chrlen, samples, boundary_chr, labels)
        annotated.append(df)
        diffs.append(df_diff)

    # Concatenate once instead of growing the frames inside the loop.
    if not annotated:
        return pd.DataFrame(), pd.DataFrame()
    return pd.concat(annotated, axis=0), pd.concat(diffs, axis=0)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # "<dir>:<label>" -> [dir, label]; components after the second are ignored.
    tp = lambda x: x.split(':')
    parser.add_argument("input",  help="<Input directory>:<label>", type=tp, nargs='*')
    parser.add_argument("--type", help="normalize type (default: SCALE)", type=str, default="SCALE")
    parser.add_argument("--boundary", help="Boundary file (BED format)", type=str)
    parser.add_argument("--gt", help="Genome table", type=str)
    parser.add_argument("-r", "--resolution", help="resolution (default: 25000)", type=int, default=25000)
    parser.add_argument("--ncluster", help="number of cluster (default: 4)", type=int, default=4)
    parser.add_argument("--odir", help="Output directory (default: output_boundary_clustering)", type=str, default="output_boundary_clustering/")
    parser.add_argument("--cooler", help="Load matrix from cooler-format directory structure", action='store_true')
    args = parser.parse_args()

    dirs = []
    labels = []
    for spec in args.input:
        dirs.append(spec[0])
        # Unlabelled inputs fall back to the directory name; a shared empty
        # label would make several samples collide on the column name "".
        if len(spec) > 1 and spec[1]:
            labels.append(spec[1])
        else:
            labels.append(os.path.basename(os.path.normpath(spec[0])))

    # parser.error() prints usage and exits with a non-zero status.
    if len(dirs) < 2:
        parser.error("specify at least two inputs (<Input directory>:<label>); the first one is used as the control.")
    if len(set(labels)) != len(labels):
        parser.error(f"sample labels must be unique: {labels}")
    if args.boundary is None:
        parser.error("specify boundary file (--boundary).")
    if args.gt is None:
        parser.error("specify genome_table file (--gt).")

    boundary = get_boundary(args.boundary)
    genometable = pd.read_csv(args.gt, delimiter='\t', index_col=[0], header=None)

    resolution = args.resolution
    # The module-level names `type` and `odir` are read by the annotation and
    # plotting functions, so they must be bound here.
    type = args.type
    num_cluster = args.ncluster
    # Ensure exactly one trailing separator (the default already has one).
    odir = os.path.join(args.odir, "")
    os.makedirs(odir, exist_ok=True)

    df_boundary_all, diff_boundary_all = annotate_boundary_genome(boundary, dirs, labels, resolution, genometable, cooler=args.cooler)

    df_boundary_all = plot_heatmap_diffsamples(df_boundary_all, diff_boundary_all, num_cluster)

    df_boundary_all.to_csv(odir + "Annotated_boundaries.tsv", sep="\t", index=False)
    