#!python

import os
import sys
from datetime import date

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

from contur.run.run_extract_xs_bf import *
from contur.run.arg_utils import *
import contur.config.version

# Parse the command line for the 'xsbfscan' tool.  Every later stage of this
# script reads its options from the resulting ``args`` mapping.
args = get_args(sys.argv[1:], 'xsbfscan')

if args["printVersion"]:
    # NOTE(review): assumes contur.config.version resolves to a string at this
    # point (not the submodule object) -- TODO confirm.
    print("Contur " + contur.config.version)
    sys.exit(0)

# The first --inputDir entry must be an existing directory.  Testing the
# truthiness of the option also covers a missing (None) value, on which
# len() would raise a TypeError instead of printing the message below.
if not args["inputDir"] or not os.path.isdir(args["inputDir"][0]):
    print(f'{args["inputDir"]} is not a valid input directory. Abort.')
    # An aborted run must not report success to the calling shell.
    sys.exit(1)

# Parameter values collected for the scanned points, and the cross-sections
# collected per process label; all three are filled while the logs are read.
x = []
y = []
results = {}

if args['xy'] is None:
    print("[ERROR] please choose which params to scan over from this list, in format '--xy var1,var2'")
    # Show the first line of sampled_points.dat (the list the message above
    # refers to).  Reading it directly avoids building a shell command from
    # an unquoted, user-supplied path.
    header_path = os.path.join(args['inputDir'][0], "sampled_points.dat")
    try:
        with open(header_path) as header_file:
            print(header_file.readline().rstrip("\n"))
    except OSError as header_error:
        print(header_error, file=sys.stderr)
    sys.exit(1)

# --xy must name two parameters; anything after the second name is ignored.
xy_names = args['xy'].split(",")
if len(xy_names) < 2:
    print("[ERROR] '--xy {}' does not name two parameters, expected format '--xy var1,var2'".format(args['xy']))
    sys.exit(1)

args['x'] = xy_names[0]
args['y'] = xy_names[1]

# Every enabled option contributes a short tag to the suffix, always in the
# order listed here.  The suffix is used below in the plot directory name.
_SUFFIX_BY_OPTION = (
    ('foldBRs', "_BRs"),
    ('foldBSMBRs', "_BSMBRs"),
    ('splitLeptons', "_sL"),
    ('mergeEWBosons', "_V"),
    ('splitIncomingPartons', "_sP"),
    ('splitAntiParticles', "_spA"),
    ('splitBQuarks', "_spB"),
    ('splitLightQuarks', "_spLq"),
)
postFix = "".join(tag for option, tag in _SUFFIX_BY_OPTION if args[option])

# The plot directory additionally records whether processes are merged into
# pools; postFix itself is left without that tag.
postFixPlotsDir = (postFix + "_pools") if args['splitIntoPools'] else postFix

# Directory name: fixed prefix, today's date as DDMMYY, then the option tags.
today_tag = date.today().strftime('%d%m%y')
stem = "process_plots_{}{}".format(today_tag, postFixPlotsDir)

# Directory that will receive the plots.  os.makedirs replaces a shell
# "mkdir -p" so that paths containing spaces or shell metacharacters work.
outdir = os.path.join(args['OUTPUTDIR'], stem)
os.makedirs(outdir, exist_ok=True)

# Largest and smallest non-zero cross-section seen in any log (reported later).
max_xs = 0
min_xs = 99999
# NOTE(review): despite its name this collects the processes that set a new
# *minimum*; it is not read again in the visible part of this script.
maxProcs = []

# Walk the entries of the input directory in sorted order.  For every
# sub-directory a log is produced by run_extract_xs_bf (or reused from the
# cache/ tree) and then parsed into x, y and results.
ldir = sorted(os.listdir(args['inputDir'][0]))
for counter, entry in enumerate(ldir):
    if counter % 10 == 0:
        print("Point {}/{}: {}".format(counter, len(ldir), entry))

    # Test the entry name rather than the joined path: otherwise an input
    # directory whose own path contains "ANA" would skip every point.
    if "ANA" in entry:
        continue
    point_dir = os.path.join(args['inputDir'][0], entry)
    if not os.path.isdir(point_dir):
        continue

    # The cache directory is only created for entries that are processed.
    flog = "cache/{}".format(point_dir)
    os.makedirs(flog, exist_ok=True)
    log_file = "{}/out{}.log".format(flog, postFix)
    if os.path.isfile(log_file) and args['clearCache']:
        os.remove(log_file)
    # (Re)generate the log when none is cached or the cache is to be ignored.
    if not os.path.isfile(log_file) or args['ignoreCache']:
        run_extract_xs_bf(args, input_dir=point_dir, outfile=log_file)

    # Lines as interpreted by this parser:
    #   * containing "::"  -> comma-separated name=value settings follow the
    #                         last "::"; those named by --xy extend x and y
    #   * containing ":", "totalXS" or "Skip" -> ignored
    #   * anything else    -> fields separated by ", "; the first is the
    #                         cross-section (with "fb" removed), the last is
    #                         the process label
    with open(log_file) as log:
        for logline in log:
            logline = logline.strip()
            if not logline:
                # A blank line carries no data and would crash float("").
                continue
            if "::" in logline:
                for varVal in logline.split("::")[-1].split(","):
                    name_and_value = varVal.strip().split("=")
                    var = name_and_value[0]
                    val = float(name_and_value[1])
                    if var == args['x']:
                        x.append(val)
                    if var == args['y']:
                        y.append(val)
                continue
            if ":" in logline or "totalXS" in logline or "Skip" in logline:
                continue

            fields = logline.split(", ")
            xs = float(fields[0].replace("fb", ""))
            proc = fields[-1]
            if xs > max_xs:
                max_xs = xs
            if xs < min_xs and xs != 0:
                min_xs = xs
                if proc not in maxProcs:
                    maxProcs.append(proc)
            if proc not in results:
                # First sighting of this process: pad with zeros for all the
                # points recorded before the current one.
                results[proc] = [0] * (len(x) - 1)
            results[proc].append(xs)

    # Processes absent from this point get an explicit zero, so every list
    # holds at least one entry per recorded point.
    for v in results.values():
        v.extend([0] * (len(x) - len(v)))

# Distinct parameter values along each axis, in ascending order.
xBins = sorted(set(x))
yBins = sorted(set(y))
print("xBins:{!r} {!r}".format(len(xBins), xBins))
print("yBins:{!r} {!r}".format(len(yBins), yBins))
print(f"max_xs: {max_xs} fb")
print(f"min_xs: {min_xs} fb")

# Zero-filled template with one row per y value and one column per x value.
arr = np.zeros((len(yBins), len(xBins)))
vals = len(xBins) * len(yBins)

# Bookkeeping for the stages that follow.
newresults = {}
plotcounter = 0
counter = -1

# merge into pools if requested
if args['splitIntoPools']:
    # Tokens treated as known particles; any other space-separated token in
    # a final state is assumed to be a stable BSM particle seen as MET.
    known_tokens = ("e", "\\nu", "mu", "tau", "l", "\\gamma", "b", "t",
                    "q", "g", "W", "Z", "H", "")
    for process, xs_values in results.items():
        # Only the text after the last "\rightarrow " is classified.
        final_state = process.split("\\rightarrow ")[-1]
        print(final_state)

        # Multiplicities are plain substring counts on the label.
        n_photons = final_state.count("gamma")
        n_bjets = final_state.count("b")
        # "gamma" contains a "g", so photons are subtracted from the jets.
        n_jets = (final_state.count("q") + final_state.count("g")
                  + n_bjets - n_photons)
        n_leptons = sum(final_state.count(tok) for tok in ("l", "e", "mu", "tau"))

        has_met = "\\nu" in final_state
        for token in final_state.split(" "):
            if token in known_tokens:
                continue
            # assume stable, non-interacting BSM
            print("Assuming that " + token + " is a stable BSM particle which shows up as MET")
            has_met = True

        # Build the pool label from the non-empty categories, in fixed order.
        tags = []
        if n_leptons:
            tags.append("nLeptons=%d" % n_leptons)
        if n_jets >= 2:
            tags.append("nJets>=2")
        elif n_jets == 1:
            tags.append("nJets=1")
        if n_bjets > 0:
            tags.append("nBJets=%d" % n_bjets)
        if n_photons:
            tags.append("photons=%d" % n_photons)
        if has_met:
            tags.append("MET")
        pool = ", ".join(tags)
        print("this process %s gets this label: %s" % (final_state, pool))

        # Processes sharing a label are summed point by point.
        if pool in newresults:
            newresults[pool] += np.array(xs_values)
        else:
            newresults[pool] = np.array(xs_values)
    results = newresults


# Draw one heat map per process (or pool), largest peak cross-section first.
# Each map is written to outdir as both PNG and PDF.

# Fixed colour-scale limits in fb, identical for every plot.  They are kept
# in their own names so the scan-derived max_xs / min_xs are not overwritten.
colour_min_fb = 1e-3
colour_max_fb = 1e5
contour_levels = [0.01, 0.1, 1, 10, 100, 1000, 10000]

# Value -> bin position, built once instead of a list.index() scan per point.
x_bin_of = {value: position for position, value in enumerate(xBins)}
y_bin_of = {value: position for position, value in enumerate(yBins)}

x_label, y_label = args['xy'].split(",")[:2]

ranked = sorted(results.items(), key=lambda item: max(item[1]), reverse=True)
for rank, (k, v) in enumerate(ranked, start=1):
    print("K==", k, "V==", v)
    peak = max(v)
    print("%d/%d doing %s (max = %.3f fb)" % (rank, len(results), k, peak))
    if peak < args['tolerance']:
        print("skipping due to xs< tolerance !", args['tolerance'])
        continue

    # Every cross-section needs exactly one (x, y) point.  A mismatch (for
    # example a --xy name that never occurs in the logs) would otherwise
    # surface as an unexplained IndexError.
    if not (len(v) == len(x) == len(y)):
        print("[ERROR] %s has %d cross-sections but there are %d values of '%s' and %d values of '%s'"
              % (k, len(v), len(x), x_label, len(y), y_label))
        sys.exit(1)

    # Rows follow yBins and columns follow xBins; if a point occurs more
    # than once the last value wins.
    heatmap = np.copy(arr)
    for x_val, y_val, xs in zip(x, y, v):
        heatmap[y_bin_of[y_val], x_bin_of[x_val]] = xs

    # Turn the label into mathtext for the title.
    title = k.replace("_ch_UFO", r"_{ch,UFO}")
    title = title.replace("_ch", r"_{ch}")
    title = title.replace("_UFO", r"_{UFO}")
    title = title.replace("0", r"^{0}")

    fig, ax1 = plt.subplots(1, 1)
    print("$%s$" % title)
    ax1.set_title("$%s$" % title)

    im = ax1.pcolormesh(xBins, yBins, heatmap,
                        norm=LogNorm(vmin=colour_min_fb, vmax=colour_max_fb),
                        shading='auto')
    # contour() needs at least a 2x2 grid; with a single bin along an axis
    # the colour mesh is still drawn, just without contour lines.
    if len(xBins) > 1 and len(yBins) > 1:
        cs = ax1.contour(xBins, yBins, heatmap, levels=contour_levels,
                         colors='w')
        ax1.clabel(cs, inline=1, fontsize=10, fmt="%.0e")

    ax1.set_xlabel(x_label)
    ax1.set_ylabel(y_label)

    # Derive a file name from the title; the order of the replacements
    # matters (">=" must be handled before ">" and "=").
    file_stem = title
    for old, new in (("\\rightarrow", "to"), (">=", "_geq_"), (">", "_gt_"),
                     ("=", "_eq_"), ("^", ""), (",", ""), ("\\", ""),
                     (" ", "_"), ("{", ""), ("}", "")):
        file_stem = file_stem.replace(old, new)

    print(outdir + "/%s.pdf" % file_stem)
    fig.savefig(outdir + "/%s.png" % file_stem)
    fig.savefig(outdir + "/%s.pdf" % file_stem)
    plotcounter += 1

    # NOTE(review): the colour bar is only created after both files have
    # been written and nothing is saved afterwards, so it appears in no
    # output.  Kept as-is; presumably it was meant to be saved on its own
    # as a separate "cbar" image -- confirm and either save it or drop it.
    cbar = fig.colorbar(im, ax=ax1)
    ax1.remove()
    cbar.ax.set_ylabel("cross-section [fb]", rotation=-90, va="bottom")
    plt.close(fig)


print("made %d plots: %s" % (plotcounter, outdir))
