#!/usr/bin/env python
import numpy as np
import ulysses
from scipy.special import zeta

__doc__="""

%prog -m 1DME --loop -o evol.pdf  PARAMETERFILE

Use --inv for the inverted mass ordering; run with --help for all options.

Example parameter file:

m     -1      # logarithm, in [eV]
M1    12      # logarithm, in [GeV]
M2    12.6    # logarithm, in [GeV]
M3    13      # logarithm, in [GeV]
delta 213     # [deg]
a21    81     # [deg]
a31   476     # [deg]
x1     90     # [deg]
x2     87     # [deg]
x3    180     # [deg]
y1   -120     # [deg]
y2      0     # [deg]
y3   -120     # [deg]
t12    33.63  # [deg]
t13     8.52  # [deg]
t23    49.58  # [deg]
"""

def _auto_ylim(ax, low_pct=2, margin=0.5, top_margin=0.2):
    """Set log-scale y-limits based on the data, ignoring machine-precision transients.

    Reference lines (axhline — constant y) are excluded from the lower-bound
    calculation but still kept visible via top_margin above their value.
    """
    import numpy as np
    data_log = []   # actual evolving curves
    ref_log  = []   # horizontal reference lines (axhline)
    for line in ax.get_lines():
        y = np.asarray(line.get_ydata(), dtype=float)
        y = y[np.isfinite(y) & (y > 0)]
        if not len(y):
            continue
        if np.ptp(y) == 0:          # constant y → reference line
            ref_log.append(np.log10(y[0]))
        else:
            data_log.append(np.log10(y))
    if not data_log:
        return
    all_data = np.concatenate(data_log)
    lower    = np.percentile(all_data, low_pct) - margin
    upper    = np.max(all_data)
    if ref_log:
        upper = max(upper, max(ref_log))
    ax.set_ylim([10**lower, 10**(upper + top_margin)])


def plotEvolution(LEPTO, f_out):
    """Plot the evolution data stored on ``LEPTO`` and save the figure to ``f_out``.

    Upper panel: |value| of every column listed by ``LEPTO.flavourindices()``
    (labelled with ``LEPTO.flavourlabels()``), |eta_B| taken from the last
    column of ``LEPTO.evolData``, and a dotted reference line at the observed
    eta_B = 6.1e-10.  If ``LEPTO.extendedindices()`` is non-empty, a second
    panel sharing the x-axis shows those columns (``LEPTO.extendedlabels()``).

    Curves are drawn on log-log axes, so magnitudes are plotted.  Samples where
    a column is non-negative are drawn solid; negative samples are drawn dashed
    in the same colour, so the sign can be read off the plot.

    Parameters
    ----------
    LEPTO : model object
        Evaluated model providing ``evolData`` (2D array whose column 0 is used
        as the x-axis) plus the index/label accessors above; ``evolname`` is an
        optional x-axis label.
    f_out : str
        Output file name passed to ``matplotlib.pyplot.savefig``.
    """
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker

    # Viridis-inspired discrete palette
    _PALETTE = ['#440154', '#31688E', '#35B779', '#F8961E', '#B12A90',
                '#FDE725', '#21918C', '#90D743', '#443983', '#E76F51']

    def _style_ax(ax, xlabel_str=None):
        """Apply consistent tick/grid styling to an axis."""
        ax.set_yscale("log")
        ax.set_xscale("log")
        ax.xaxis.set_major_locator(ticker.LogLocator(base=10, numticks=10))
        ax.xaxis.set_minor_locator(ticker.LogLocator(base=10, subs=np.arange(2,10)*0.1, numticks=100))
        ax.yaxis.set_major_locator(ticker.LogLocator(base=10, numticks=15))
        ax.yaxis.set_minor_locator(ticker.LogLocator(base=10, subs=np.arange(2,10)*0.1, numticks=100))
        ax.xaxis.set_minor_formatter(ticker.NullFormatter())
        ax.yaxis.set_minor_formatter(ticker.NullFormatter())
        ax.tick_params(axis='both', which='major', direction='in',
                       top=True, right=True, labelsize=16, pad=6, length=7, width=1.0)
        ax.tick_params(axis='both', which='minor', direction='in',
                       top=True, right=True, length=4, width=0.7)
        ax.grid(which='major', linestyle='-',  linewidth=0.4, color='grey', alpha=0.2)
        ax.grid(which='minor', linestyle=':',  linewidth=0.3, color='grey', alpha=0.15)
        if xlabel_str is not None:
            ax.set_xlabel(xlabel_str, fontsize=20, labelpad=8)

    data     = LEPTO.evolData
    active   = LEPTO.flavourindices()
    labels   = LEPTO.flavourlabels()
    DMactive = LEPTO.extendedindices()
    DMlabels = LEPTO.extendedlabels()
    xlabel   = getattr(LEPTO, 'evolname', r'$z = M_1/T$')
    lw       = 2.0
    xvals    = data[:, 0]
    legend_kw = dict(loc='upper left', bbox_to_anchor=(1.02, 1), borderaxespad=0,
                     frameon=True, framealpha=0.95, edgecolor='0.7', prop={'size': 14})

    def _plot_signed(ax, columns, names, color_offset=0):
        """Draw |data[:, idx]| per column: solid where >= 0, dashed where < 0."""
        for i, (idx, lab) in enumerate(zip(columns, names)):
            color    = _PALETTE[(color_offset + i) % len(_PALETTE)]
            vals     = data[:, idx]
            mag      = np.abs(vals)
            neg_mask = vals < 0
            # Draw each sign separately: a dashed copy overlaid on a full solid
            # |vals| curve of the same colour and width would be invisible.
            ax.plot(xvals, np.where(neg_mask, np.nan, mag),
                    label=lab, linewidth=lw, color=color)
            if neg_mask.any():
                # Grow every negative run by one sample on each side so the dashed
                # segment meets the solid one at a sign change instead of leaving a gap.
                bridge = neg_mask.copy()
                bridge[1:]  |= neg_mask[:-1]
                bridge[:-1] |= neg_mask[1:]
                ax.plot(xvals, np.where(bridge, mag, np.nan),
                        linestyle='--', color=color, linewidth=lw)

    if DMactive:
        fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(9, 9))
    else:
        fig, ax1 = plt.subplots(1, 1, figsize=(9, 6))

    try:
        # ------------------------------------------------------------------ #
        #  Upper panel: standard flavour abundances + eta_B                  #
        # ------------------------------------------------------------------ #
        _plot_signed(ax1, active, labels)
        ax1.axhline(6.1e-10, color='0.5', linewidth=1.0, linestyle=':', label=r"$\eta_B^{\rm obs}$")
        ax1.plot(xvals, np.abs(data[:,-1]), label=r"$|\eta_B|$", linewidth=lw, color='k')
        ax1.legend(**legend_kw)
        ax1.set_ylabel(r"$\left|N_{B-L}\right|,\;\eta_B$", fontsize=20, labelpad=8)
        _style_ax(ax1, xlabel_str=(None if DMactive else xlabel))
        _auto_ylim(ax1)

        # ------------------------------------------------------------------ #
        #  Lower panel: extended (model-specific) quantities, if present     #
        # ------------------------------------------------------------------ #
        if DMactive:
            # Continue the palette after the upper-panel flavours.
            _plot_signed(ax2, DMactive, DMlabels, color_offset=len(active))
            ax2.legend(**legend_kw)
            ax2.set_ylabel(r"$\left|N_{\rm ext}\right|$", fontsize=20, labelpad=8)
            _style_ax(ax2, xlabel_str=xlabel)
            _auto_ylim(ax2)

        plt.tight_layout()
        plt.savefig(f_out, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)




if __name__=="__main__":

    import optparse, sys
    op = optparse.OptionParser(usage=__doc__)
    op.add_option("-o", "--output",    dest="OUTPUT",           default=None, type=str, help="Output file name for evolution plots/data --- if not provided, only calculate etaB (default: %default)")
    op.add_option("-v", "--debug",     dest="DEBUG",            default=False, action="store_true", help="Turn on some debug messages")
    op.add_option("-m", "--model",     dest="MODEL",            default="1DME", help="Selection of model (default: %default)")
    op.add_option("--zrange",          dest="ZRANGE",           default="0.1,30,500", help="Ranges and steps of the evolution variable (default: %default)")
    op.add_option("--lambda",          dest="LAMBDA",           default="1e3", help="ARS parameter Lambda which controls the scale at which the fast modes in linear term are forced to enter quasi-static regime (default: %default)")
    op.add_option("--xrange",          dest="XRANGE",           default="1e-6,None,500", help="Ranges and steps of the ARS evolution variable x=T_ew/T  (default: 1e-6,min([1, 20*131.7/M1]),500)")
    op.add_option("--inv",             dest="INVORDERING",      default=False, action='store_true', help="Use inverted mass ordering (default: %default)")
    op.add_option("--loop",            dest="LOOP",             default=False, action='store_true', help="Use loop-corrected Yukawa (default: %default)")
    op.add_option("--zcut",            dest="ZCUT",             default="1.0",  help="Set cut value for stitching in ARS model")
    op.add_option("--extended",        dest="EXTENDED",         default=False, action="store_true", help="Allow model-specific parameters beyond the standard pnames (default: %default)")
    op.add_option("--initial",         dest="INITIAL",          default=0.0,   type=float,          help="Initial RHN abundance: 0 = vanishing, 1 = thermal (default: %default)")
    op.add_option("--ars-indirect",    dest="ARS_INDIRECT",     default=False, action="store_true", help="Add indirect contributions for the 3RHN ARS scenario (default: %default)")
    opts, args = op.parse_args()


    if len(args)==0:
        print("No parameter space configuration given, exiting.")
        sys.exit(1)

    # Dissect the zrange string "zmin,zmax,zsteps".
    # Bad input is reported through op.error (usage message, then exit) rather
    # than assert, which is removed when Python runs with -O.
    try:
        zmin, zmax, zsteps = opts.ZRANGE.split(",")
        zmin   = float(zmin)
        zmax   = float(zmax)
        zsteps = int(zsteps)
    except ValueError:
        op.error("--zrange must be 'zmin,zmax,zsteps', got {!r}".format(opts.ZRANGE))

    if not zmin < zmax:
        op.error("--zrange needs zmin < zmax, got zmin={}, zmax={}".format(zmin, zmax))
    if zsteps <= 0:
        op.error("--zrange needs a positive number of steps, got {}".format(zsteps))

    def try_float(v):
        """Return float(v), or None when v is not numeric (e.g. the literal 'None')."""
        try:
            return float(v)
        except (TypeError, ValueError):
            return None

    # Dissect the xrange string "xmin,xmax,xsteps"; a non-numeric bound
    # (such as the default 'None' for xmax) is passed on to the model as None.
    try:
        xmin, xmax, xsteps = opts.XRANGE.split(",")
        xsteps = int(xsteps)
    except ValueError:
        op.error("--xrange must be 'xmin,xmax,xsteps', got {!r}".format(opts.XRANGE))

    xmin = try_float(xmin)
    xmax = try_float(xmax)

    # Only compare the bounds when both were given numerically.
    if xmin is not None and xmax is not None and not xmin < xmax:
        op.error("--xrange needs xmin < xmax, got xmin={}, xmax={}".format(xmin, xmax))
    if xsteps <= 0:
        op.error("--xrange needs a positive number of steps, got {}".format(xsteps))


    pfile, gdict = ulysses.tools.parseArgs(args)

    LEPTO = ulysses.selectModel(opts.MODEL,
                                zmin=zmin, zmax=zmax, zsteps=zsteps,
                                xmin=xmin, xmax=xmax, xsteps=xsteps,
                                Lambda=float(opts.LAMBDA),
                                ordering=int(opts.INVORDERING),
                                loop=opts.LOOP,
                                debug=opts.DEBUG,
                                zcut=float(opts.ZCUT),
                                extended_mode=opts.EXTENDED,
                                initial_abundance=opts.INITIAL,
                                use_hind=opts.ARS_INDIRECT,
                                **gdict
                                )


    # Read parameter card and very explicit checks on parameter names
    _, FIX,param = ulysses.readConfig(pfile)

    LEPTO.which_param = param

    _optional = LEPTO._OPTIONAL_PARAMS
    _required_pnames = [p for p in LEPTO.pnames if p not in _optional]

    if LEPTO.extended_mode:
        # Extended mode: required pnames must be present in FIX; extra keys are model-specific
        for p in _required_pnames:
            if p not in FIX:
                print("Required parameter {} not provided in input file {}, exiting".format(p, args[0]))
                sys.exit(1)
    else:
        _supplied = {p for p in FIX if p in LEPTO.pnames or p in _optional}
        _required_set = set(_required_pnames)
        if not _required_set.issubset(_supplied):
            for p in _required_pnames:
                if p not in FIX:
                    print("Parameter {} not provided in input file {}, exiting".format(p, args[0]))
            sys.exit(1)
        for p in FIX.keys():
            if p not in LEPTO.pnames and p not in _optional:
                print("Parameter {} in input file {} not recognised, exiting".format(p, args[0]))
                sys.exit(1)


    # Print banner with model and run card info
    ulysses.tools.print_banner(
        model_name=opts.MODEL,
        param_file=pfile,
        params=FIX,
        extra=["zrange: zmin={}, zmax={}, steps={}".format(zmin, zmax, zsteps)]
    )


    if opts.DEBUG:
        print(LEPTO)

    etaB = LEPTO(FIX)

    if opts.DEBUG:
        print(LEPTO.h)
        LEPTO.printParams()
        print(LEPTO.U)

    # Conversion constants (PDG)
    mp       = 1.672621898e-24  # proton mass [g]
    ngamma   = 410.7            # photon number density today [cm^-3]
    rhoc     = 1.87840e-29      # critical density h^2 [g cm^-3]
    # n_gamma/s today = 45 zeta(3) / (pi^4 g_s) with g_s = 43/11
    ToYb     = 45 * zeta(3) / ((43/11) * np.pi**4)
    ToOmegab = mp * ngamma / rhoc

    print("{}{}\n{}{}\n{}{}".format(
        "eta_b".ljust(      14), etaB,
        "Y_b".ljust(        14), etaB * ToYb,
        "Omega_b h^2".ljust(14), etaB * ToOmegab))

    summary = LEPTO.extended_summary()
    if summary:
        print(summary)

    if opts.OUTPUT is not None:
        # TODO header for text outputs
        D=LEPTO.evolData
        # One if/elif chain: previously ".txt" was written as text and then
        # fell through to the plotting branch, overwriting/failing on the file.
        if opts.OUTPUT.endswith((".txt", ".dat")):
            np.savetxt(opts.OUTPUT, D)
        elif opts.OUTPUT.endswith(".csv"):
            np.savetxt(opts.OUTPUT, D, delimiter=",")
        else:
            plotEvolution(LEPTO, opts.OUTPUT)
        if opts.DEBUG:
            print("Output written to {}".format(opts.OUTPUT))

