#!/usr/bin/env python
import numpy as np
import ulysses

__doc__="""

Scan of EtaB in two variables

%prog -m 1DME --loop -o scan.pdf -x 100 -y 100 PARAMETERFILE

Example parameter file:

m     -3   -1 # logarithm, in [eV]
M1    11   14 # logarithm, in [GeV]
M2    12.6    # logarithm, in [GeV]
M3    13      # logarithm, in [GeV]
delta 213     # [deg]
a21    81     # [deg]
a31   476     # [deg]
x1     90     # [deg]
x2     87     # [deg]
x3    180     # [deg]
y1   -120     # [deg]
y2      0     # [deg]
y3   -120     # [deg]
t12    33.63  # [deg]
t13     8.52  # [deg]
t23    49.58  # [deg]
"""

def plotEtaB_2D(X, Y, data, f_out, pxname, pyname):
    """Draw a contour plot of eta_B (scaled by 1e10) over a 2D grid and save it.

    Parameters
    ----------
    X, Y : array_like
        Grid values of the first and second scan parameter.
    data : array_like
        Scan results with one row per grid point, ordered with the first
        parameter as the outer loop: row ``i*len(Y) + j`` belongs to
        ``(X[i], Y[j])``. Column 2 holds eta_B. Only the first
        ``len(X)*len(Y)`` rows are used.
    f_out : str
        Output file name, passed to ``Figure.savefig``.
    pxname, pyname : str
        Labels for the x and y axes.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap

    colors = ["#2364aa", "#3da5d9", "#73bfb8", "#fec601", "#ea7317"]
    cmap1 = LinearSegmentedColormap.from_list("mycmap", colors)

    nx, ny = len(X), len(Y)
    # Rows are X-major, so reshape to (nx, ny). Then transpose, because
    # contour expects Z[j, i] with j indexing Y and i indexing X.
    Z = 1e10 * np.asarray(data)[: nx * ny, 2].reshape(nx, ny).T

    fig, ax = plt.subplots()

    ax.contour(X, Y, Z, cmap=cmap1, levels=20)
    filled = ax.contourf(X, Y, Z, cmap=cmap1, levels=20, alpha=0.5)
    cbar = fig.colorbar(filled, ax=ax)
    # Raw string: "\e" and "\c" are invalid escape sequences in a normal
    # string literal (SyntaxWarning since Python 3.12). The value is unchanged.
    cbar.ax.set_ylabel(r"$\eta_B\cdot 10^{10}$")
    ax.set_xlabel(pxname)
    ax.set_ylabel(pyname)

    fig.savefig(f_out)
    # Release the figure; otherwise pyplot keeps it alive after saving.
    plt.close(fig)


if __name__=="__main__":

    import optparse
    import os
    import sys

    op = optparse.OptionParser(usage=__doc__)
    op.add_option("-o", "--output",    dest="OUTPUT",      default="scan.pdf", type=str, help="Output file name for evolution plots/data (default: %default)")
    op.add_option("-v", "--debug",     dest="DEBUG",       default=False, action="store_true", help="Turn on some debug messages")
    op.add_option("-m", "--model",     dest="MODEL",       default="1DME", help="Selection of of model (default: %default)")
    op.add_option("-x", "--nx-scan",   dest="NSCANX",      default=30, type=int, help="Number of point to scan in first variable (default: %default)")
    op.add_option("-y", "--ny-scan",   dest="NSCANY",      default=30, type=int, help="Number of point to scan in second variable (default: %default)")
    op.add_option("--zrange",          dest="ZRANGE",      default="0.1,100,1000", help="Ranges and steps of the evolution variable (default: %default)")
    op.add_option("--lambda",          dest="LAMBDA",      default="1e3", help="ARS parameter Lambda which controls the scale at which the fast modes in linear term are forced to enter quasi-static regime (default: %default)")
    op.add_option("--xrange",          dest="XRANGE",      default="1e-6,None,500", help="Ranges and steps of the ARS evolution variable (default: 1e-6,min([1, 20*131.7/M1]),500)")
    op.add_option("--inv",             dest="INVORDERING", default=False, action='store_true', help="Use inverted mass ordering (default: %default)")
    op.add_option("--loop",            dest="LOOP",        default=False, action='store_true', help="Use loop-corrected Yukawa (default: %default)")
    op.add_option("--zcut",            dest="ZCUT",        default="1.0",  help="Set cut value for stitching in ARS model")
    op.add_option("--extended",        dest="EXTENDED",    default=False, action="store_true", help="Allow model-specific parameters beyond the standard pnames (default: %default)")
    op.add_option("--initial",         dest="INITIAL",     default=0.0,   type=float,          help="Initial RHN abundance: 0 = vanishing, 1 = thermal (default: %default)")
    op.add_option("--ars-indirect",    dest="ARS_INDIRECT", default=False, action="store_true", help="Add indirect contributions for the 3RHN ARS scenario (default: %default)")
    opts, args = op.parse_args()

    def fail(msg):
        """Print msg and exit with status 1 (this script's error convention)."""
        print(msg)
        sys.exit(1)

    def try_float(v):
        """Return float(v), or None when v is not a number (e.g. the text "None")."""
        try:
            return float(v)
        except (TypeError, ValueError):
            return None

    if len(args)==0:
        fail("No parameter space configuration given, exiting.")

    # Make sure specified file actually exists
    if not os.path.exists(args[0]):
        fail("Specified input file {} does not exist, exiting.".format(args[0]))

    # Validate explicitly rather than with assert: asserts are stripped
    # when Python runs with -O.
    if opts.NSCANX < 1 or opts.NSCANY < 1:
        fail("Number of scan points (-x, -y) must be positive, exiting.")

    # Dissect the zrange string "zmin,zmax,zsteps"
    try:
        zmin, zmax, zsteps = opts.ZRANGE.split(",")
        zmin, zmax, zsteps = float(zmin), float(zmax), int(zsteps)
    except ValueError:
        fail("Invalid --zrange '{}', expected 'zmin,zmax,zsteps', exiting.".format(opts.ZRANGE))
    if not zmin < zmax:
        fail("Invalid --zrange: zmin ({}) must be smaller than zmax ({}), exiting.".format(zmin, zmax))
    if zsteps <= 0:
        fail("Invalid --zrange: number of steps must be positive, exiting.")

    # Dissect the xrange string "xmin,xmax,xsteps"; xmin/xmax may be
    # non-numeric (e.g. "None"), which selects the model default.
    try:
        xmin, xmax, xsteps = opts.XRANGE.split(",")
        xsteps = int(xsteps)
    except ValueError:
        fail("Invalid --xrange '{}', expected 'xmin,xmax,xsteps', exiting.".format(opts.XRANGE))
    xmin = try_float(xmin)
    xmax = try_float(xmax)
    if xmin is not None and xmax is not None and not xmin < xmax:
        fail("Invalid --xrange: xmin ({}) must be smaller than xmax ({}), exiting.".format(xmin, xmax))
    if xsteps <= 0:
        fail("Invalid --xrange: number of steps must be positive, exiting.")

    pfile, gdict = ulysses.tools.parseArgs(args)

    LEPTO = ulysses.selectModel(opts.MODEL,
            zmin=zmin, zmax=zmax, zsteps=zsteps,
            xmin=xmin, xmax=xmax, xsteps=xsteps,
            Lambda=float(opts.LAMBDA),
            ordering=int(opts.INVORDERING),
            loop=opts.LOOP,
            debug=opts.DEBUG,
            zcut=float(opts.ZCUT),
            extended_mode=opts.EXTENDED,
            initial_abundance=opts.INITIAL,
            use_hind=opts.ARS_INDIRECT,
            **gdict
            )

    # Read parameter card and very explicit checks on parameter names
    RNG, FIX, param = ulysses.readConfig(pfile)

    LEPTO.which_param = param

    if len(RNG) != 2:
        fail("Exactly two scan parameters required in input file {}, found {}, exiting.".format(args[0], len(RNG)))
    pxscan, pyscan = list(RNG.keys())
    pxmin, pxmax = RNG[pxscan]
    pymin, pymax = RNG[pyscan]

    _optional = LEPTO._OPTIONAL_PARAMS
    _required_pnames = [p for p in LEPTO.pnames if p not in _optional]
    # Every required parameter must be fixed or scanned.
    covered = set(FIX.keys()) | set(RNG.keys())

    if LEPTO.extended_mode:
        # Extended mode: required pnames must appear in FIX or as scan parameters;
        # extra keys may be extended (model-specific) parameters
        for p in _required_pnames:
            if p not in covered:
                fail("Required parameter {} not provided in input file {}, exiting".format(p, args[0]))
    else:
        for p in _required_pnames:
            if p not in covered:
                fail("Parameter {} not provided in input file {}, exiting".format(p, args[0]))
        for p in FIX.keys():
            if p not in LEPTO.pnames and p not in _optional:
                fail("Parameter {} in input file {} not recognised, exiting".format(p, args[0]))
        for pscan in (pxscan, pyscan):
            if pscan not in LEPTO.pnames:
                fail("Scan-parameter {} in input file {} not recognised, exiting".format(pscan, args[0]))

    ulysses.tools.print_banner(
        model_name=opts.MODEL,
        param_file=pfile,
        params={**FIX, **{pxscan: "{} to {}".format(pxmin, pxmax), pyscan: "{} to {}".format(pymin, pymax)}},
        extra=["scan: {}={} to {}, steps={}".format(pxscan, pxmin, pxmax, opts.NSCANX),
               "      {}={} to {}, steps={}".format(pyscan, pymin, pymax, opts.NSCANY),
               "zrange: zmin={}, zmax={}, steps={}".format(zmin, zmax, zsteps)]
    )

    if opts.DEBUG:
        print(LEPTO)

    PPX = np.linspace(pxmin, pxmax, opts.NSCANX)
    PPY = np.linspace(pymin, pymax, opts.NSCANY)
    EE  = []

    print("2D Scanning {} in [{},{}] for {} values and {} in [{},{}] for {} values".format(pxscan, pxmin, pxmax, opts.NSCANX,
                                                                                           pyscan, pymin, pymax, opts.NSCANY))

    from tqdm import tqdm

    # The first scan parameter is the outer loop. The DATA layout below and
    # the reshape in plotEtaB_2D both rely on this order.
    for px in tqdm(PPX):
        for py in tqdm(PPY, leave = False):
            FIX[pxscan] = px
            FIX[pyscan] = py
            etaB = LEPTO(FIX)
            EE.append(etaB)

    nx, ny = opts.NSCANX, opts.NSCANY
    # Row k = i*ny + j holds (PPX[i], PPY[j], eta_B), in the same order EE was filled.
    DATA = np.empty((nx * ny, 3))
    DATA[:, 0] = np.repeat(PPX, ny)
    DATA[:, 1] = np.tile(PPY, nx)
    DATA[:, 2] = np.ravel(EE)

    if opts.OUTPUT is not None:
        if opts.OUTPUT.endswith(".txt"):
            np.savetxt(opts.OUTPUT, DATA)
        elif opts.OUTPUT.endswith(".csv"):
            np.savetxt(opts.OUTPUT, DATA, delimiter=",")
        else:
            plotEtaB_2D(PPX, PPY, DATA, opts.OUTPUT, pxscan, pyscan)
        if opts.DEBUG:
            print("Output written to {}".format(opts.OUTPUT))
