#!python
# -*- coding: utf-8 -*-
# Copyright(c) Ryuichiro Nakato <rnakato@iqb.u-tokyo.ac.jp>
# All rights reserved.

import argparse
import os
# Configure matplotlib through environment variables. These assignments sit
# above the matplotlib imports further down on purpose: they are placed so the
# values are already set when matplotlib is first imported.
os.environ["MPLBACKEND"] = "Agg"  # Agg is matplotlib's non-interactive, file-only backend
os.environ["MPLCONFIGDIR"] = "/tmp/mplconfig"  # NOTE(review): presumably chosen as a writable config/cache dir -- confirm
import sys
import pandas as pd
import subprocess
import numpy as np
import random
from pybedtools import BedTool
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

def calculate_ratio(border_bed, gene_bed, allgene_bed):
    """Return the fraction of border-overlapping genes that are test genes.

    Both gene sets are intersected with ``border_bed`` using ``u=True``
    (NOTE(review): in bedtools this reports each gene at most once when it
    overlaps any border interval -- confirm against the pybedtools docs).

    Args:
        border_bed: Interval set the genes are intersected with.
        gene_bed: Test genes (numerator).
        allgene_bed: Background genes (denominator).

    Returns:
        ``len(test overlaps) / len(background overlaps)``, or ``0`` when no
        background gene overlaps the borders.
    """
    overlapping_background = allgene_bed.intersect(border_bed, u=True)
    overlapping_test = gene_bed.intersect(border_bed, u=True)

    total = len(overlapping_background)
    hits = len(overlapping_test)

    # Guard against division by zero when nothing overlaps the borders.
    if total == 0:
        return 0
    return hits / total

def back_function(border, gene_bed, permutation_times, len_border, allgene_bed):
    """Summarise the background ratio distribution from random boundary draws.

    Repeats ``permutation_times`` times: sample ``len_border`` rows from
    ``border`` (``DataFrame.sample``, i.e. without replacement by default),
    convert them to a BedTool and compute ``calculate_ratio`` for the sample.

    Args:
        border: DataFrame of candidate (background) boundaries.
        gene_bed: Test genes, forwarded to ``calculate_ratio``.
        permutation_times: Number of random draws.
        len_border: Number of rows taken in each draw.
        allgene_bed: Background genes, forwarded to ``calculate_ratio``.

    Returns:
        A list with the 0.25, 0.75, 0.05, 0.95, 0.025 and 0.975 quantiles of
        the sampled ratios, in exactly that order (i.e. the lower/upper
        bounds of the central 50%, 90% and 95% ranges).
    """
    def ratio_of_random_draw():
        # One permutation: random subset of boundaries -> BED -> ratio.
        sampled_rows = border.sample(len_border)
        sampled_bed = BedTool.from_dataframe(sampled_rows)
        return calculate_ratio(sampled_bed, gene_bed, allgene_bed)

    sampled_ratios = np.array(
        [ratio_of_random_draw() for _ in range(permutation_times)]
    )

    quantile_levels = (0.25, 0.75, 0.05, 0.95, 0.025, 0.975)
    return [np.quantile(sampled_ratios, level) for level in quantile_levels]

def plot_graph(df, outputname):
    """Draw the observed ratio curve over the background bands and save it.

    Args:
        df: Table with the columns ``Distance``, ``Ratio``, ``low50``,
            ``high50``, ``low90``, ``high90``, ``low95`` and ``high95``.
            ``Distance`` is divided by 1000 for the x axis, which is
            labelled in kb.
        outputname: Path handed to ``plt.savefig``.
    """
    plt.rcParams['font.size'] = '12'

    # x axis shared by the curve and all bands.
    distance_kb = df["Distance"] / 1000

    # (lower column, upper column, alpha, legend label) for each grey band.
    bands = (
        ("low50", "high50", 0.5, '50% quantile'),
        ("low90", "high90", 0.3, '90% quantile'),
        ("low95", "high95", 0.1, '95% quantile'),
    )

    plt.plot(distance_kb, df["Ratio"], "m")
    for lower_col, upper_col, alpha, _label in bands:
        plt.fill_between(distance_kb, df[lower_col], df[upper_col],
                         color="grey", alpha=alpha)

    plt.xlabel("Distance from TAD boundary (kb)", fontsize=15)
    plt.ylabel("Fraction of DEGs", fontsize=15)

    # The legend lists only the bands; the curve itself has no entry.
    legend_handles = [
        mpatches.Patch(color='grey', alpha=alpha, label=label)
        for _lower_col, _upper_col, alpha, label in bands
    ]
    plt.legend(handles=legend_handles, fontsize=12)

    plt.savefig(outputname)

def set_border(border, i):
    """Return a widened copy of ``border``; the input frame is not modified.

    Args:
        border: DataFrame whose column ``1`` holds interval starts and
            column ``2`` interval ends (integer column labels).
        i: Number of units subtracted from every start and added to every
            end.

    Returns:
        A new DataFrame in which starts are lowered by ``i`` but never below
        0, and ends are raised by ``i``.
    """
    widened = border.copy()
    # Clamp at 0 so widening cannot produce a negative start coordinate.
    widened[1] = (border[1] - i).clip(lower=0)
    widened[2] = border[2] + i
    return widened

def permutation_test_ratio(border, allborder, gene_bed, allgene_bed, permutation_times, max_distance, distance_step):
    """Compute observed and background DEG ratios for a series of distances.

    For every distance in ``range(0, max_distance + 1, distance_step)`` both
    boundary tables are widened with ``set_border``; the observed ratio is
    computed on the widened test boundaries and the background quantiles are
    obtained from ``back_function`` on the widened control boundaries, using
    samples as large as the test boundary table.

    Args:
        border: DataFrame of boundaries to be tested.
        allborder: DataFrame of boundaries sampled as background.
        gene_bed: Test genes.
        allgene_bed: Background genes.
        permutation_times: Number of random draws per distance.
        max_distance: Largest distance to evaluate (inclusive when it lies
            on the step grid).
        distance_step: Increment between consecutive distances.

    Returns:
        A tuple ``(select_ratios, random_ratios, positions)``: the list of
        observed ratios, a numpy array with one row per quantile returned by
        ``back_function`` and one column per distance, and the list of
        distances.
    """
    positions = list(range(0, max_distance + 1, distance_step))
    sample_size = len(border)  # every random draw matches the test set size

    observed_ratios = []
    background_quantiles = []
    for distance in positions:
        print(f"Distance {distance} bp")
        widened_test = set_border(border, distance)
        widened_control = set_border(allborder, distance)

        widened_test_bed = BedTool.from_dataframe(widened_test)
        observed_ratios.append(
            calculate_ratio(widened_test_bed, gene_bed, allgene_bed)
        )
        background_quantiles.append(
            back_function(widened_control, gene_bed, permutation_times,
                          sample_size, allgene_bed)
        )

    # Rows are distances at this point; transpose so each row is one quantile.
    quantile_rows = np.array(background_quantiles).T

    return observed_ratios, quantile_rows, positions

def main(argv=None):
    """Command-line entry point: run the permutation test and save the plot.

    Parses and validates the options, loads the two boundary BED files as
    DataFrames and the two gene BED files as BedTool objects, runs
    ``permutation_test_ratio`` and hands the resulting table to
    ``plot_graph``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so calling
            ``main()`` without arguments works as before.

    Raises:
        SystemExit: with status 1 (message and usage on stderr) when a
            required option is missing or a numeric option is out of range.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--border_test",  help="<TAD boundary to be tested (BED format)>", type=str, default=None)
    parser.add_argument("--border_control",  help="<TAD boundary as background (BED format)>", type=str, default=None)
    parser.add_argument("--gene_test",  help="<Genes to be tested (BED format)>", type=str, default=None)
    parser.add_argument("--gene_control",  help="<Genes as background (BED format)>", type=str, default=None)
    parser.add_argument("-o", "--output", help="Output name (*.pdf or *.png, default: output.pdf)", type=str, default="output.pdf")
    parser.add_argument("-n", help="Number of permutation (default: 1000)", type=int, default=1000)
    parser.add_argument("--maxdistance", help="Max distance (bp, default: 300000)", type=int, default=300000)
    parser.add_argument("--step", help="Step of distance (bp, default: 10000)", type=int, default=10000)

    args = parser.parse_args(argv)

    def fail(message):
        # Report on stderr and leave with a non-zero status so that calling
        # shells and pipelines can detect the failure. sys.exit is used
        # because the bare exit() helper is only installed by the site module.
        print(message, file=sys.stderr)
        parser.print_help(sys.stderr)
        sys.exit(1)

    for option in ("border_test", "border_control", "gene_test", "gene_control"):
        if getattr(args, option) is None:
            fail(f"Error: specify --{option}.")

    # Reject values that would otherwise surface later as an obscure
    # traceback (zero step in range(), empty result arrays, no permutations).
    if args.n < 1:
        fail("Error: -n must be a positive integer.")
    if args.maxdistance < 0:
        fail("Error: --maxdistance must not be negative.")
    if args.step < 1:
        fail("Error: --step must be a positive integer.")

    print("   TAD boundary to be tested: " + args.border_test)
    print("   TAD boundary as background: " + args.border_control)
    print("   Genes to be tested: " + args.gene_test)
    print("   Genes as background: " + args.gene_control)
    print("   Permutation time: " + str(args.n))
    print("   Max distance: " + str(args.maxdistance) + " bp")
    print("   Step of distance: " + str(args.step) + " bp")
    print("   Output file: " + args.output)

    # Boundaries are kept as DataFrames because they are widened and
    # resampled; genes are only intersected, so they go straight to BedTool.
    border = pd.read_csv(args.border_test, sep="\t", header=None)
    allborder = pd.read_csv(args.border_control, sep="\t", header=None)
    gene_bed = BedTool(args.gene_test)
    allgene_bed = BedTool(args.gene_control)

    select_ratios, random_ratios, positions = permutation_test_ratio(
        border,
        allborder,
        gene_bed,
        allgene_bed,
        args.n,
        args.maxdistance,
        args.step,
    )

    # random_ratios holds one row per quantile, in the order in which
    # back_function lists its quantile levels.
    df = pd.DataFrame({
        'Distance': positions,
        'Ratio': select_ratios,
        'low50': random_ratios[0],
        'high50': random_ratios[1],
        'low90': random_ratios[2],
        'high90': random_ratios[3],
        'low95': random_ratios[4],
        'high95': random_ratios[5],
    })

    plot_graph(df, args.output)


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
