#!/usr/bin/env python3
"""
Multi-parameter likelihood computation module for GW events, each with a unique EM counterpart

Tathagata Ghosh
"""

import bilby
import gwcosmo
import h5py
import matplotlib
import numpy as np
import json
from itertools import product
from gwcosmo.utilities.posterior_utilities import str2bool
from gwcosmo.utilities.arguments import create_parser
from gwcosmo.utilities.mass_prior_utilities import *
from gwcosmo.utilities.host_galaxy_merger_relations import RedshiftEvolutionConstant, RedshiftEvolutionPowerLaw, RedshiftEvolutionMadau
import astropy.constants as const
from gwcosmo.injections import injections_at_detector
from gwcosmo.prior import priors
from gwcosmo.utilities import host_galaxy_merger_relations
from gwcosmo.utilities.arguments import create_parser
from gwcosmo.utilities.check_boundary import *
from gwcosmo.utilities.cosmology import *
from gwcosmo.utilities.posterior_utilities import str2bool
from gwcosmo.utilities.injections_utilities import default_ifar_value
from tqdm import tqdm

# Select the non-interactive "agg" backend BEFORE pyplot is imported, so figures
# are only written to file (savefig) and no display is needed.
matplotlib.use("agg")
import matplotlib.pyplot as plt

# Global font settings for the figures produced by this script.
# NOTE(review): sns.set() below is called with its default `font` argument; per the
# seaborn set_theme docs it resets font.family (and the "ticks" style resets
# font.sans-serif), so these two settings likely have no effect. Confirm, and move
# them after sns.set() if Times New Roman is really wanted.
matplotlib.rcParams["font.family"] = "Times New Roman"
matplotlib.rcParams["font.sans-serif"] = ["Bitstream Vera Sans"]
matplotlib.rcParams["mathtext.fontset"] = "stixsans"

import seaborn as sns

# Seaborn theme: "paper" context, "ticks" axes style, colourblind-friendly palette.
sns.set(context="paper", style="ticks", palette="colorblind")

# Command-line options requested from gwcosmo's create_parser, in the order they
# are registered.
_CLI_OPTIONS = (
    "--method",
    "--posterior_samples",
    "--redshift_evolution",
    "--parameter_dict",
    "--injections_path",
    "--mass_model",
    "--snr",
    "--ifar",
    "--plot",
    "--outputfile",
    "--gravity_model",
    # sampler settings (read only by the "sampling" method)
    "--sampler",
    "--nwalkers",
    "--npool",
    "--walks",
    "--nsteps",
    "--nlive",
    "--dlogz",
    "--sampler_likelihood_check",
)

parser = create_parser(*_CLI_OPTIONS)
opts = parser.parse_args()

# Required inputs: parser.error() aborts the script with a usage message.
if not opts.posterior_samples:
    parser.error("Posterior samples file is missing")

# Validate the analysis method before any data is loaded. Only "sampling" and
# "gridded" are implemented further down, so reject anything else now rather
# than after the expensive setup.
if not opts.method:
    parser.error("Missing method. Choose 'sampling' or 'gridded'.")
if opts.method not in ("sampling", "gridded"):
    parser.error(f"Unrecognized method '{opts.method}'. Choose 'sampling' or 'gridded'.")
method = opts.method

# loading GW data (posterior samples/skymap)
if opts.posterior_samples.endswith(".json"):
    with open(opts.posterior_samples) as json_file:
        posterior_samples_dictionary = json.load(json_file)
else:  # a json file is required
    err_str = """Missing posterior samples. Expecting a json file with format:\n
    {
       "GW170817":
	  {
		 "posterior_file_path" : "/path/to/GW170817.h5",
		 "samples_field" : "C01:Mixed",
		 "skymap_path" : "/path/to/GW170817_skymap.fits"
	  }
    }
    """
    parser.error(err_str)

# Mass model: fail with a clear message when it is missing, instead of passing the
# string "None" on to mass_model_selector.
if not opts.mass_model:
    parser.error("Missing mass model")
mass_model = str(opts.mass_model)
mass_priors = mass_model_selector(mass_model, parser)
print(f'Using the {mass_model} mass model')

# parameter(s) file: JSON mapping parameter name -> settings ("value", "prior", ...)
if not opts.parameter_dict:
    parser.error("Missing parameters file")
with open(str(opts.parameter_dict), encoding="utf-8") as json_file:
    parameter_dict = json.load(json_file)

# Mass parameters cannot be inferred from a GW skymap alone.
# NOTE(review): posterior samples are required earlier in this script, so this
# branch is unreachable as written. "--skymap" is also not among the options
# requested from create_parser, hence getattr() rather than a bare attribute access.
if opts.posterior_samples is None and getattr(opts, "skymap", None) is not None:
    mass_parameters = list(extract_parameters_from_instance(mass_priors))
    for key in mass_parameters:
        if key in parameter_dict and isinstance(parameter_dict[key]["value"], list):
            parser.error(f"Mass parameter({key}) cannot be inferred while using GW skymap.")

# Detector-frame injections read from opts.injections_path. Validate the path and
# the detection thresholds before touching the file.
if not opts.injections_path:
    parser.error("Missing injections file")

if not opts.snr and not opts.ifar:
    parser.error("You must specify threshold values for SNR and/or IFAR for the analysis.")
# Either threshold may be omitted; an unset one becomes 0.0.
opts.snr = opts.snr or 0.0
opts.ifar = opts.ifar or 0.0
print(f"Analysis will be run with cuts on SNR: {opts.snr} OR IFAR: {opts.ifar}")
print("For software injections, the IFAR is known: their SNR should be set to -1. For semianalytical injections, the SNR is know and the IFAR should be set to -1.")


def _open_injections(path):
    """Open the injections HDF5 file read-only.

    If the first attempt raises ``OSError``, retry with HDF5 file locking disabled
    (presumably for filesystems that do not support locking).
    """
    try:
        return h5py.File(path, "r")
    except OSError:
        return h5py.File(path, "r", locking=False)


# The context manager closes the file even if a dataset is missing or
# injections_at_detector raises.
with _open_injections(opts.injections_path) as injdata:
    m1d = np.array(injdata["m1d"])
    if "ifar" in injdata:
        ifar = np.array(injdata["ifar"])
    else:
        # No 'ifar' dataset: every injection gets the default IFAR value.
        # (np.full, unlike `default + 0 * m1d`, cannot produce NaN from inf/NaN masses.)
        ifar = np.full(m1d.shape, default_ifar_value, dtype=float)

    injections = injections_at_detector(
        m1d=m1d,
        m2d=np.array(injdata["m2d"]),
        dl=np.array(injdata["dl"]),
        prior_vals=np.array(injdata["pini"]),
        snr_det=np.array(injdata["snr"]),
        snr_cut=opts.snr,
        ifar=ifar,
        ifar_cut=opts.ifar,
        ntotal=np.array(injdata["ntotal"]),
        Tobs=np.array(injdata["Tobs"]),
    )

# Resolve the mass model through gwcosmo.prior.priors ("-" in the CLI name maps to
# "_" in the attribute name). This rebinds mass_model and mass_priors, replacing the
# values set earlier in this script via mass_model_selector.
# NOTE(review): the earlier selection is therefore redundant; confirm which is intended.
mass_model = opts.mass_model
if not mass_model:
    parser.error("Missing mass model")

mass_prior_cls = getattr(priors, mass_model.replace("-", "_"), None)
if not mass_prior_cls:
    parser.error("Unrecognized mass model")

print(f"Using the {mass_model} mass model")
mass_priors = mass_prior_cls()

# Supported --gravity_model values and the cosmology each one builds. The lambdas
# defer both the name lookup and the construction, so only the selected cosmology
# is ever instantiated.
_GRAVITY_COSMOLOGIES = {
    "GR": lambda: standard_cosmology(),
    "Xi0_n": lambda: Xi0_n_cosmology(),
    "extra_dimension": lambda: extra_dimension_cosmology(),
    "cM": lambda: cM_cosmology(),
}

gravity_model = opts.gravity_model
make_cosmology = _GRAVITY_COSMOLOGIES.get(gravity_model)
if make_cosmology is None:
    parser.error(f"Unrecognized '{gravity_model}' gravity model")
cosmo = make_cosmology()
print(f"Using the {gravity_model} gravity model")

# redshift evolution model
redshift_evolution = opts.redshift_evolution
if redshift_evolution == "None":
    redshift_evolution = "Constant"
if not (
    ps_z := getattr(host_galaxy_merger_relations, f"RedshiftEvolution{redshift_evolution}", None)
):
    parser.error("Unrecognized redshift evolution model")
print(f"Assuming a {redshift_evolution} redshift evolution model")
ps_z = ps_z()

# Boundary check of the configuration (currently disabled):
# check_boundary(cosmo, parameter_dict, injections, mass_priors, gravity_model, mass_model)

# Remaining run options; string flags are converted with str2bool.
check_bool = str2bool(opts.sampler_likelihood_check)
plot = str2bool(opts.plot)
outputfile = str(opts.outputfile)

# Likelihood over all events in the posterior-samples dictionary, built from the
# injections, redshift-evolution model, cosmology and mass priors set up above,
# with the same SNR / IFAR thresholds that were applied to the injections.
bright_siren = gwcosmo.likelihood.bright_siren_likelihood
me = bright_siren.MultipleEventLikelihoodEM(
    posterior_samples_dictionary,
    injections,
    ps_z,
    cosmo,
    mass_priors,
    network_snr_threshold=opts.snr,
    ifar_cut=opts.ifar,
)

# Parameters the likelihood does not know are reported, not rejected.
unrecognized = [name for name in parameter_dict if name not in me.parameters]
for name in unrecognized:
    print(
        f"WARNING: The parameter {name} from your parameter dictionary is not recognized by the likelihood module."
    )

def _build_sampler_priors(parameter_dict):
    """Translate the parameter dictionary into a bilby ``PriorDict``.

    Entries whose ``"value"`` is a list are sampled: the first two list elements
    are passed to ``bilby.core.prior.Uniform``, ``Gaussian`` or ``LogUniform``
    according to the entry's ``"prior"`` field ('Uniform'/'uniform',
    'Gaussian'/'gaussian', 'Loguniform'/'loguniform'). The optional ``"label"``
    is used as the LaTeX label (default: the key). Any other ``"value"`` is
    fixed to ``float(value)``.

    Raises:
        ValueError: if a sampled entry names an unsupported prior.
    """
    sampler_priors = bilby.core.prior.PriorDict()
    for key, settings in parameter_dict.items():
        value = settings["value"]
        if not isinstance(value, list):
            sampler_priors[key] = float(value)
            continue
        label = settings.get("label", key)
        prior_type = settings["prior"]
        if prior_type in ("Uniform", "uniform"):
            sampler_priors[key] = bilby.core.prior.Uniform(
                value[0], value[1], name=key, latex_label=label
            )
        elif prior_type in ("Gaussian", "gaussian"):
            sampler_priors[key] = bilby.core.prior.Gaussian(
                value[0], value[1], name=key, latex_label=label
            )
            print(f"Gaussian prior assigned to parameter {key}")
        elif prior_type in ("Loguniform", "loguniform"):
            sampler_priors[key] = bilby.core.prior.LogUniform(
                value[0], value[1], name=key, latex_label=label
            )
            print(f"Loguniform prior assigned to parameter {key}")
        else:
            raise ValueError(
                f"Unrecognised prior settings for parameter {key} (specify 'Uniform', 'Gaussian' or 'Loguniform'.)"
            )
    return sampler_priors


def _build_parameter_grid(parameter_dict):
    """Split the parameter dictionary into gridded and fixed parameters.

    A list-valued ``"value"`` must use a 'Uniform'/'uniform' prior and be
    ``[low, high]`` (10 points) or ``[low, high, npoints]``; it becomes an
    ``np.linspace`` grid. Any other ``"value"`` is fixed to ``float(value)``.

    Returns:
        (parameter_grid, parameter_grid_indices, fixed_params): grid values and
        the matching index ranges per varied parameter, plus the fixed values,
        all in ``parameter_dict`` order.

    Raises:
        ValueError: on a non-uniform prior or a badly formatted ``"value"`` list.
    """
    parameter_grid = {}
    parameter_grid_indices = {}
    fixed_params = {}
    for key, settings in parameter_dict.items():
        value = settings["value"]
        if not isinstance(value, list):
            fixed_params[key] = float(value)
            continue
        if settings["prior"] not in ("Uniform", "uniform"):
            raise ValueError(
                f"Unrecognised prior settings for parameter {key} (specify 'Uniform')."
            )
        if len(value) not in (2, 3):
            raise ValueError(f"Incorrect formatting for parameter {key}.")
        print(
            f"Setting a uniform prior on {key} in the range [{value[0]}, {value[1]}]"
        )
        # JSON may deliver the point count as a float (e.g. 100.0); linspace needs an int.
        npoints = int(value[2]) if len(value) == 3 else 10
        parameter_grid[key] = np.linspace(value[0], value[1], npoints)
        parameter_grid_indices[key] = range(len(parameter_grid[key]))
    return parameter_grid, parameter_grid_indices, fixed_params


def _grid_spacing(grid):
    """Return the step of a uniformly spaced (linspace) grid, or 1.0 for a single point.

    Used to turn summed likelihoods into densities without indexing ``grid[1]``
    on a one-point grid.
    """
    return grid[1] - grid[0] if len(grid) > 1 else 1.0


def _plot_grid_likelihood(names, param_values, likelihood, parameter_dict, outputfile):
    """Plot the gridded likelihood and save it as ``<outputfile>.png``.

    With one varied parameter the normalised 1-D likelihood is drawn. With more,
    a corner-style grid shows normalised 1-D marginals on the diagonal and filled
    2-D marginal contours below it; panels above the diagonal are removed.
    With no varied parameter nothing is plotted.
    """
    no_params = len(names)
    if no_params == 0:
        print("No parameter was varied on the grid; skipping the likelihood plot.")
        return
    labels = [parameter_dict[name].get("label", name) for name in names]

    if no_params == 1:
        grid = param_values[0]
        ind_lik_norm = likelihood / np.sum(likelihood) / _grid_spacing(grid)
        plt.figure(figsize=[4.2, 4.2])
        plt.plot(grid, ind_lik_norm)
        plt.xlabel(labels[0], fontsize=16)
        plt.xlim(grid[0], grid[-1])
        plt.ylim(0, 1.1 * np.max(ind_lik_norm))
        plt.ylabel(f"p({labels[0]})", fontsize=16)
    else:
        fig, ax = plt.subplots(
            no_params,
            no_params,
            figsize=[4 * no_params, 4 * no_params],
            constrained_layout=True,
        )
        for column in range(no_params):
            for row in range(no_params):
                panel = ax[row, column]
                if column > row:
                    # Upper triangle is unused.
                    fig.delaxes(panel)
                    continue
                if row == column:
                    # 1-D marginal: sum the likelihood over every other axis.
                    other_axes = tuple(i for i in range(no_params) if i != row)
                    ind_lik = np.sum(likelihood, axis=other_axes)
                    ind_lik_norm = (
                        ind_lik / np.sum(ind_lik) / _grid_spacing(param_values[row])
                    )
                    panel.plot(param_values[row], ind_lik_norm)
                    panel.set_xlim(param_values[row][0], param_values[row][-1])
                    panel.set_ylim(0, 1.1 * np.max(ind_lik_norm))
                else:
                    # 2-D marginal. column < row here, so the summed array is indexed
                    # [column, row]; transpose it so rows follow the y-axis parameter.
                    other_axes = tuple(
                        i for i in range(no_params) if i not in (row, column)
                    )
                    panel.contourf(
                        param_values[column],
                        param_values[row],
                        np.sum(likelihood, axis=other_axes).T,
                        20,
                    )
                if column == 0:
                    panel.set_ylabel(labels[row], fontsize=16)
                if row == no_params - 1:
                    panel.set_xlabel(labels[column], fontsize=16)

    plt.savefig(f"{outputfile}.png", dpi=100, bbox_inches="tight")


if method == "sampling":

    sampler = str(opts.sampler)
    nwalkers = int(opts.nwalkers)
    walks = int(opts.walks)
    nsteps = int(opts.nsteps)
    nlive = int(opts.nlive)
    dlogz = float(opts.dlogz)
    npool = int(opts.npool)

    # Named sampler_priors so the imported gwcosmo `priors` module is not shadowed.
    sampler_priors = _build_sampler_priors(parameter_dict)

    # Check prior constraints for the mass distribution, does nothing if not required
    sampler_priors = me.mass_priors.sampling_constraint_call(sampler_priors)

    # Alternatives are "rwalk" and "act-walk" (the default); "acceptance-walk" is
    # faster (as of 20230717), see https://lscsoft.docs.ligo.org/bilby/dynesty-guide.html
    sampling_method = "acceptance-walk"

    bilbyresult = bilby.run_sampler(
        likelihood=me,
        priors=sampler_priors,
        outdir=f"{opts.outputfile}",  # the directory to output the results to
        label=f"{opts.outputfile}",  # a label to apply to the output file
        plot=plot,  # by default this is True and a corner plot will be produced
        sampler=sampler,  # set the name of the sampler
        nlive=nlive,  # keyword arguments forwarded to the chosen sampler
        dlogz=dlogz,
        nwalkers=nwalkers,
        walks=walks,
        nsteps=nsteps,
        npool=npool,
        sample=sampling_method,
        allow_multi_valued_likelihood=check_bool,
    )

elif method == "gridded":

    parameter_grid, parameter_grid_indices, fixed_params = _build_parameter_grid(
        parameter_dict
    )

    for key, fixed_value in fixed_params.items():
        me.parameters[key] = fixed_value
        print(f"Setting parameter {key} to {fixed_value}")

    shape = [len(x) for x in parameter_grid.values()]
    likelihood = np.zeros(shape)

    names = list(parameter_grid)
    values = list(product(*parameter_grid.values()))
    indices = list(product(*parameter_grid_indices.values()))

    print(f'Computing the likelihood on a grid over {names}')

    # Constraint check, raises value error if priors not allowed by a constraint
    constraint_grid = np.zeros(shape)
    me.mass_priors.grid_constraint_call(constraint_grid, values, parameter_grid, fixed_params)
    for value, index in zip(tqdm(values), indices):
        for name, param_value in zip(names, value):
            me.parameters[name] = param_value
        # Points set to -inf by grid_constraint_call stay at -inf without
        # evaluating the likelihood.
        if constraint_grid[index] == -np.inf:
            likelihood[index] = -np.inf
        else:
            likelihood[index] = me.log_likelihood()

    # Shift the log-likelihood so its peak is 0 before exponentiating. If no point
    # is finite (e.g. every point is -inf), skip the shift: -inf - (-inf) would turn
    # the whole grid into NaN instead of zeros.
    max_log_likelihood = np.nanmax(likelihood)
    if np.isfinite(max_log_likelihood):
        likelihood -= max_log_likelihood
    likelihood = np.exp(likelihood)

    param_values = [parameter_grid[k] for k in names]

    mylist = np.array([names, param_values, likelihood, opts, parameter_dict], dtype=object)
    np.savez(outputfile + ".npz", mylist)

    if plot:
        _plot_grid_likelihood(names, param_values, likelihood, parameter_dict, outputfile)

else:
    parser.error(f"Unrecognized method '{method}'. Choose 'sampling' or 'gridded'.")
