#!python

from Bio import motifs
import math
import os
import sys
import argparse
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas
import glob
import argparse
from pathlib import Path
import pybio
import gzip
import scanRBP

# CLIP data handle shared with process()/process_cumulative(); it stays None
# unless the FASTA branch at the bottom of this script loads one for -clip.
clip = None
# Absolute path of the folder that contains this script.
scanRBP_path = Path(__file__).parent.absolute()

# Usage text, printed for -help and whenever the number of positional
# arguments is not 1 or 2.
help_0 = """Usage: scanRBP sequence output [options]
    * one sequence provided on the command line, generates output.tab.gz (+ output.png/pdf with -heatmap)

Usage: scanRBP filename.fasta [options]
    * one heatmap/matrix will be generated per sequence from the FASTA file
    * output name of the files will be sequence ids provided in the fasta file

Usage: scanRBP search search_term
    * returns list of proteins available matching search_term

Options:
     -annotate                Annotate each heatmap cell with the number
     -xlabels                 Display sequence (x-labels), default False
     -protein TARDBP.K562.00  Only analyze binding for the provided protein (default: analyze binding for all proteins in scanRBP database)
     -cumulative              Take only one protein (-protein) and plot binding across all sequences provided in the FASTA file
     -figsize "(10,20)"       Change matplotlib/seaborn figure size for the heatmap, example width=10, height=20
     -fontscale               Change font scale, default 0.2
     -heatmap title           Make heatmap (png+pdf) with title
     -vlines "[10,-10]"       Draw dashed vertical lines on the heatmap at the given x positions, negative values count from the end of the sequence
     -output_folder folder    Store all results to the output folder (default: current folder)
     -nonzero                 All negative vector values are set to 0, not enabled by default
     -clip peaks.bed.gz       Use actual CLIP data (peaks from BED file), do not use PWMs
     -force                   If output files already exist, overwrite them, otherwise do not process
     -version                 Print version and exit
"""

# Command-line interface. Built-in argparse help is disabled (add_help=False)
# because the script prints its own usage text for -help.
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, add_help=False)
# Positional arguments: "sequence output", "file.fasta", "search term", ...
parser.add_argument('commands', help="command(s) to run", nargs='*')
parser.add_argument("-help", "-h", "--help", action="store_true")
parser.add_argument("-annotate", "--annotate", action="store_true", default=False)
parser.add_argument("-xlabels", "--xlabels", action="store_true", default=False)
# Value options default to False (not None); the script tests them against False.
parser.add_argument("-protein", "--protein", default=False)
parser.add_argument("-figsize", "--figsize", default=False)
parser.add_argument("-fontscale", "--fontscale", default=0.2, type=float)
parser.add_argument("-title", "--title", default=False)
# store_true with default=True: the value is True whether or not the flag is given.
parser.add_argument("-matrix", "--matrix", action="store_true", default=True)
parser.add_argument("-clip", "--clip", default=False)
parser.add_argument("-nonzero", "--nonzero", action="store_true", default=False)
parser.add_argument("-heatmap", "--heatmap", default=False)
# -force is a plain switch (see the usage text); without store_true argparse
# demanded a value and "scanRBP seq out -force" failed with a usage error.
parser.add_argument("-force", "--force", action="store_true", default=False)
parser.add_argument("-output_folder", "--output_folder", default=".")
parser.add_argument("-cumulative", "--cumulative", action="store_true", default=False)
parser.add_argument("-version", "--version", help="Print version", action="store_true")
# Kept as a string here; it is turned into a list of x positions when the heatmap is drawn.
parser.add_argument("-vlines", "--vlines", default="[]", type=str)

# Parse the command line once; the rest of the script reads the module-level `args`.
args = parser.parse_args()

# Start-up banner, printed on every run: package version, config file and data folder.
print(f"scanRBP | v{scanRBP.version}, https://github.com/grexor/scanRBP")
print(f"scanRBP | config file: {scanRBP.config.config_fname()}")
print(f"scanRBP | data folder: {scanRBP.config.data_folder}")
print()

# find all proteins with search
def find_proteins(search):
    """Return the database entries that match ``search``.

    The match is a case-insensitive substring test against each entry's
    scan id, its ``"protein"`` field and its ``"description"`` field.

    Args:
        search: text to look for.

    Returns:
        A ``(results, results_data)`` tuple of two parallel lists: the
        matching scan ids and their records from ``scanRBP.database.proteins``,
        in database iteration order.
    """
    needle = search.lower()  # loop-invariant, lower-case it once
    results, results_data = [], []
    for scan_id, data in scanRBP.database.proteins.items():
        searched_fields = (scan_id, data["description"], data["protein"])
        if any(needle in field.lower() for field in searched_fields):
            results.append(scan_id)
            results_data.append(data)
    return results, results_data

# -version: the start-up banner has already printed the version, so just stop.
if args.version:
    sys.exit(0)

# Print the usage text when it is requested explicitly, or when the number of
# positional arguments is anything other than one or two.
if args.help or not (1 <= len(args.commands) <= 2):
    print(help_0)
    sys.exit(0)

def scan(seq):
    """Score ``seq`` against every PSSM in ``scanRBP.pwm.pssm``.

    Negative scores (including ``-inf``) are replaced by 0, and each score
    vector is right-padded with zeros to the length of ``seq``.

    Args:
        seq: sequence to score; one matrix column per character.

    Returns:
        pandas.DataFrame with one row per scan id and one column per
        position of ``seq`` (columns are labelled with the sequence letters).
    """
    heatmap_data = []
    heatmap_columns = []
    for scan_id, pssm in scanRBP.pwm.pssm.items():
        # atleast_1d: Biopython documents that calculate() returns a single
        # number when sequence and motif have the same length (assumes the
        # stored PSSMs are Biopython matrices); a bare scalar is not iterable.
        raw_scores = np.atleast_1d(pssm.calculate(seq))
        # One pass is enough: -inf also fails the ">= 0" test.
        scores = [x if x>=0 else 0 for x in raw_scores]
        if len(scores)<len(seq):
            # The motif cannot be placed on the last positions; pad them with 0.
            scores = scores + [0] * (len(seq)-len(scores))
        heatmap_data.append(scores)
        heatmap_columns.append(scan_id)
    heatmap_rows = list(seq)
    return pandas.DataFrame(heatmap_data, index=heatmap_columns, columns=heatmap_rows)

def process(seq, output_fname, args, title=None, seq_id=None): # called for each sequence
    """Score one sequence and store the scores as a tab-separated matrix.

    Rows of the matrix are proteins (scan ids) and columns are the positions
    of ``seq``. The matrix is written to
    ``<args.output_folder>/<basename>.tab.gz``; when ``args.heatmap`` is set,
    a heatmap is also saved as ``<basename>.png`` and ``<basename>.pdf``.

    Args:
        seq: sequence to score; one matrix column per character.
        output_fname: base name of the output files. "+" and "-" are spelled
            out as "plus"/"minus", ".png"/".pdf" are removed and only the
            final path component is used.
        args: parsed command-line namespace. Reads ``clip``, ``protein``,
            ``nonzero``, ``force``, ``output_folder``, ``heatmap``,
            ``fontscale``, ``annotate``, ``xlabels``, ``figsize``, ``vlines``.
        title: heatmap title; when None, ``"<output_fname>: <args.heatmap>"``.
        seq_id: sequence id, used only in CLIP mode, where it is parsed as
            ``<chrom><strand>_<start>_<stop>``.

    Returns:
        True when the matrix file already exists and ``args.force`` is not
        set (nothing is recomputed); otherwise None.
    """
    output_fname = output_fname.replace("+", "plus").replace("-", "minus")
    basename = os.path.basename(output_fname.replace(".png", "").replace(".pdf", ""))
    output_final = f"{args.output_folder}/{basename}.tab.gz"
    if os.path.exists(output_final) and not args.force:
        return True

    # (row label, raw score vector) pairs, cleaned up below
    raw_rows = []
    if args.clip: # read in clip data and compute scores using CLIP peaks from provided bedGraph file
        coords = seq_id.split('_')
        start = int(coords[-2])
        stop = int(coords[-1])
        strand = coords[-3][-1]
        chrom = '_'.join(coords[:-2])[:-1]
        row_label = args.protein
        if isinstance(row_label, (list, tuple)):
            # A list used as one index entry is read by pandas as index
            # levels, so collapse the matched scan ids into one flat label.
            row_label = ",".join(row_label)
        raw_rows.append((row_label, clip.get_vector(chrom, strand, start, stop)))
    else: # process sequences with PWMs
        for scan_id, pssm in scanRBP.pwm.pssm.items():
            if args.protein and scan_id not in args.protein: # only process specific protein(s)?
                continue
            # atleast_1d: Biopython documents that calculate() returns a
            # single number when sequence and motif have the same length
            # (assumes the stored PSSMs are Biopython matrices).
            raw_rows.append((scan_id, np.atleast_1d(pssm.calculate(seq))))

    heatmap_data = []
    heatmap_columns = []
    for row_label, raw_scores in raw_rows:
        scores = [x if x!=-math.inf else 0 for x in raw_scores]
        if args.nonzero:
            scores = [x if x>=0 else 0 for x in scores]
        if len(scores)<len(seq):
            # right-pad with zeros up to the sequence length
            scores = scores + [0] * (len(seq)-len(scores))
        assert len(seq)==len(scores), "score vector and sequence differ in length"
        heatmap_data.append(scores)
        heatmap_columns.append(row_label)

    heatmap_rows = list(seq)
    data = pandas.DataFrame(heatmap_data, index=heatmap_columns, columns=heatmap_rows)
    os.makedirs(args.output_folder, exist_ok=True)
    data.to_csv(output_final, sep="\t")

    if args.heatmap is not False:
        # No plt.figure() here: clustermap creates its own figure, and an
        # extra empty one would never be closed (one leaked figure per call).
        sns.set(font="Arial")
        sns.set(font_scale=args.fontscale)
        sns.set_style("dark")
        sns.set_style("ticks")

        rdgn = sns.diverging_palette(h_neg=130, h_pos=10, s=99, l=55, sep=3, as_cmap=True)
        optional_params = {
            "linewidths": 0.0,
            "col_cluster": False,
            # Hierarchical clustering needs at least two rows; a single row
            # (CLIP mode, one matched protein) is drawn without it.
            "row_cluster": len(data.index)>1,
            "annot": args.annotate,
            "center": 0,
            "fmt": ".0f",
            "yticklabels": True,
            "xticklabels": args.xlabels,
            "cmap": rdgn,
            "cbar_pos": None,
        }
        if args.figsize is not False:
            # NOTE(review): eval() runs arbitrary code from the command line;
            # consider ast.literal_eval if only literal tuples are expected.
            optional_params["figsize"] = eval(args.figsize)
        fig = sns.clustermap(data, **optional_params)
        fig.ax_col_dendrogram.set_visible(False)
        ax = fig.ax_heatmap

        # -vlines may be given with or without the surrounding brackets
        vlines = args.vlines
        if not vlines.startswith("["):
            vlines = "[" + vlines
        if not vlines.endswith("]"):
            vlines = vlines + "]"
        # NOTE(review): same eval() caveat as for -figsize above.
        for vline_x in eval(vlines):
            if vline_x<0:
                # negative positions count from the end of the sequence
                vline_x = len(seq)+vline_x
            ax.vlines([vline_x], *ax.get_ylim(), linestyle='dashed', label='center', colors=["#bbbbbb"], linewidth=1)

        if title is None:
            plt.title(output_fname + ": " + args.heatmap)
        else:
            plt.title(title)
        plt.tight_layout()
        plt.savefig(f"{args.output_folder}/{basename}.png", dpi=300)
        plt.savefig(f"{args.output_folder}/{basename}.pdf")
        plt.close()

def process_cumulative(seqs, protein, output_fname, args, title=None):
    """Score every sequence for one protein and store the combined matrix.

    Each row of the matrix is one sequence (labelled with its id), each
    column one position. The matrix is written to
    ``<args.output_folder>/pwm_<basename>.tab.gz``; when ``args.heatmap`` is
    set, a heatmap is also saved as ``<basename>.png``.

    Args:
        seqs: iterable of ``(seq_id, sequence)`` pairs; all sequences must
            have the same length.
        protein: scan id, used as key into ``scanRBP.pwm.pssm``.
        output_fname: base name of the output files. "+" and "-" are spelled
            out as "plus"/"minus", ".png"/".pdf" are removed and only the
            final path component is used.
        args: parsed command-line namespace. Reads ``clip``, ``nonzero``,
            ``output_folder``, ``heatmap``, ``fontscale``, ``annotate``,
            ``xlabels``, ``figsize``, ``vlines``.
        title: heatmap title; when None, ``"<output_fname>: <args.heatmap>"``.

    Exits the program with status 1 when the sequences differ in length or
    when no sequence was provided.
    """
    heatmap_data = []
    heatmap_columns = []
    pssm = scanRBP.pwm.pssm[protein]
    seq_len = None
    seq = None
    for (seq_id, seq) in seqs:
        if seq_len is None:
            seq_len = len(seq)
        elif len(seq)!=seq_len:
            print("scanRBP | please provide a FASTA file with sequences of equal length")
            sys.exit(1)
        if args.clip: # read in clip data and compute scores using CLIP peaks from provided bedGraph file
            print(f"scanRBP | loading CLIP {args.clip}")
            # Coordinates are encoded in the first whitespace-delimited token
            # of the id as <chrom><strand>_<start>_<stop>; any description
            # after it is ignored so that int() sees only the coordinate.
            coords = seq_id.split(" ")[0].split('_')
            start = int(coords[-2])
            stop = int(coords[-1])
            strand = coords[-3][-1]
            chrom = '_'.join(coords[:-2])[:-1]
            scores = clip.get_vector(chrom, strand, start, stop)
        else: # compute scores using PWM for specified protein
            # atleast_1d: Biopython documents that calculate() returns a
            # single number when sequence and motif have the same length
            # (assumes the stored PSSMs are Biopython matrices).
            scores = np.atleast_1d(pssm.calculate(seq))

        scores = [x if x!=-math.inf else 0 for x in scores]
        if args.nonzero:
            scores = [x if x>=0 else 0 for x in scores]
        if len(scores)<len(seq):
            # right-pad with zeros up to the sequence length
            scores = scores + [0] * (len(seq)-len(scores))
        assert len(seq)==len(scores), "score vector and sequence differ in length"
        heatmap_data.append(scores)
        heatmap_columns.append(seq_id)

    if not heatmap_data:
        # Without this guard the code below failed with UnboundLocalError on `seq`.
        print("scanRBP | no sequences to process, please provide a non-empty FASTA file")
        sys.exit(1)

    # Column labels are the letters of the last sequence that was read.
    heatmap_rows = list(seq)
    output_fname = output_fname.replace("+", "plus").replace("-", "minus")
    basename = os.path.basename(output_fname.replace(".png", "").replace(".pdf", ""))
    data = pandas.DataFrame(heatmap_data, index=heatmap_columns, columns=heatmap_rows)
    os.makedirs(args.output_folder, exist_ok=True)
    data.to_csv(f"{args.output_folder}/pwm_{basename}.tab.gz", sep="\t")

    if args.heatmap is not False:
        # No plt.figure() here: clustermap creates its own figure, and an
        # extra empty one would never be closed (one leaked figure per call).
        sns.set(font="Arial")
        sns.set(font_scale=args.fontscale)
        sns.set_style("dark")
        sns.set_style("ticks")

        rdgn = sns.diverging_palette(h_neg=130, h_pos=10, s=99, l=55, sep=3, as_cmap=True)
        optional_params = {
            "linewidths": 0.0,
            "col_cluster": False,
            # Hierarchical clustering needs at least two rows; a FASTA file
            # with a single sequence is drawn without it.
            "row_cluster": len(data.index)>1,
            "annot": args.annotate,
            "center": 0,
            "fmt": ".0f",
            "yticklabels": True,
            "xticklabels": args.xlabels,
            "cmap": rdgn,
            "cbar_pos": None,
        }
        if args.figsize is not False:
            # NOTE(review): eval() runs arbitrary code from the command line;
            # consider ast.literal_eval if only literal tuples are expected.
            optional_params["figsize"] = eval(args.figsize)
        fig = sns.clustermap(data, **optional_params)
        fig.ax_col_dendrogram.set_visible(False)
        ax = fig.ax_heatmap

        # -vlines may be given with or without the surrounding brackets
        vlines = args.vlines
        if not vlines.startswith("["):
            vlines = "[" + vlines
        if not vlines.endswith("]"):
            vlines = vlines + "]"
        # NOTE(review): same eval() caveat as for -figsize above.
        for vline_x in eval(vlines):
            if vline_x<0:
                # negative positions count from the end of the sequence
                vline_x = len(seq)+vline_x
            ax.vlines([vline_x], *ax.get_ylim(), linestyle='dashed', label='center', colors=["#bbbbbb"], linewidth=1)

        if title is None:
            plt.title(output_fname + ": " + args.heatmap)
        else:
            plt.title(title)
        plt.tight_layout()
        plt.savefig(f"{args.output_folder}/{basename}.png", dpi=300)
        #plt.savefig(f"{args.output_folder}/{basename}.pdf")
        plt.close()

# Command dispatch. The first positional argument selects the mode:
#   config [folder] | search term | file.fasta / file.fa | bed | sequence output
if len(args.commands)>0:

    if args.commands[0]=="config":
        if len(args.commands)>1:
            scanRBP.config.init(args.commands[1])
        sys.exit(0)

    if args.commands[0]=="search":
        if len(args.commands)>1:
            proteins, proteins_data = find_proteins(args.commands[1])
            if len(proteins)>0:
                print("[scanRBP] Found proteins in the scanRBP database:")
                table = []
                for rec in proteins_data:
                    table.append([rec['scan_id'], rec['protein'], rec['tissue'], rec['description'], rec['source']])
                df = pandas.DataFrame(table, columns = ['scan_id', 'protein', 'tissue', 'description', 'source'])
                print(df.to_string(index=False))
            else:
                print(f"[scanRBP] Found no proteins in the scanRBP database with provided seach '{args.commands[1]}'")
        else:
            print("[scanRBP] please provide search term (protein name)")
        sys.exit(0)

    if args.protein:
        # Resolve -protein into a list of scan ids. The option takes one
        # search term or a comma-separated list (name1,name2,...); the matches
        # of all terms are merged in order, without duplicates.
        requested_proteins = args.protein
        matched_proteins = []
        for protein_term in requested_proteins.split(","):
            protein_term = protein_term.strip()
            if not protein_term:
                continue
            for scan_id in find_proteins(protein_term)[0]:
                if scan_id not in matched_proteins:
                    matched_proteins.append(scan_id)
        if not matched_proteins:
            # An empty list is falsy, so the PWM scan would silently fall back
            # to scanning every protein in the database; stop instead.
            print(f"scanRBP | no proteins in the scanRBP database match '{requested_proteins}'")
            sys.exit(1)
        args.protein = matched_proteins

    if args.commands[0].lower().endswith((".fasta", ".fa")):
        basename = ".".join(args.commands[0].split(".")[:-1])
        data = pybio.data.Fasta(args.commands[0])
        if not args.protein:
            print(f"scanRBP | processing {args.commands[0]}")
        else:
            print(f"scanRBP | processing {args.commands[0]}, protein(s) {args.protein}")
        seqs = []
        if args.clip:
            # module-level handle that process()/process_cumulative() read
            clip = pybio.data.Bedgraph(args.clip)
        while data.read():
            seq_id = data.id.lstrip(" ").split(" ")[0]
            output_id = f"pwm_{seq_id}"
            seq = data.sequence
            if not args.cumulative:
                # title was previously the builtin function `id`, which put
                # "<built-in function id>" on every heatmap
                process(seq, output_id, args, title=seq_id, seq_id=seq_id)
            seqs.append((data.id.lstrip(" "), data.sequence))
        if args.cumulative:
            if not args.protein:
                print("scanRBP | please provide a specific protein (or list of proteins) with '-protein name1,name2...' to compute cumulative binding across all sequences in the FASTA file for each of the provided proteins")
                sys.exit(1)
            for protein in args.protein:
                process_cumulative(seqs, protein, f"{basename}_{protein}", args, title=f"{basename}_{protein}")
    elif args.commands[0]=="bed":
        # NOTE(review): make_bedgraph is not defined or imported anywhere in
        # the visible source; confirm where it is supposed to come from.
        make_bedgraph()
    else:
        if len(args.commands)>=2:
            if args.clip:
                # CLIP mode takes the genomic coordinates from FASTA sequence
                # ids and the CLIP data is only loaded in the FASTA branch;
                # for a bare sequence this used to crash on None.split().
                print("scanRBP | -clip requires a FASTA file whose sequence ids encode the coordinates (<chr><strand>_<start>_<stop>)")
                sys.exit(1)
            process(args.commands[0], args.commands[1], args)
        else:
            print(help_0)
