#!/usr/bin/env python3
"""
Gridded likelihood computation module
Rachel Gray
"""

import gwcosmo
import numpy as np
import json
import pickle
import bilby
import h5py
import os
import sys
from itertools import product
from tqdm import tqdm

from gwcosmo.utilities.posterior_utilities import str2bool
from gwcosmo.utilities.arguments import create_parser
from gwcosmo.utilities.injections_utilities import default_ifar_value
from gwcosmo.utilities.mass_prior_utilities import *
from gwcosmo.utilities.cosmology import *
from gwcosmo.utilities.check_boundary import *
from gwcosmo.injections import injections_at_detector
from gwcosmo.prior.priors import *

import matplotlib
# Select the non-interactive 'agg' backend; this must happen before pyplot is
# imported.  Figures in this script are only written to file (savefig).
matplotlib.use('agg')
import matplotlib.pyplot as plt

# Global matplotlib font settings applied to every figure produced here.
matplotlib.rcParams.update({
    'font.family': 'Times New Roman',
    'font.sans-serif': ['Bitstream Vera Sans'],
    'mathtext.fontset': 'stixsans',
})

import seaborn as sns
# Seaborn defaults: 'paper' scaling context, 'ticks' axes style and the
# colour-blind-friendly palette.
sns.set_context('paper')
sns.set_style('ticks')
sns.set_palette('colorblind')


# Command-line flags understood by this script.  create_parser (from
# gwcosmo.utilities.arguments) builds the parser for exactly these options;
# the resulting object provides parse_args()/error() (argparse-style).
_CLI_OPTIONS = (
    "--method",
    "--posterior_samples",
    "--redshift_evolution",
    "--reweight_posterior_samples",
    "--sky_area",
    "--min_pixels",
    "--min_samps_in_pixel",
    "--outputfile",
    "--LOS_catalog",
    "--parameter_dict",
    "--plot",
    "--injections_path",
    "--mass_model",
    "--snr",
    "--ifar",
    "--gravity_model",
    "--sampler",
    "--nwalkers",
    "--npool",
    "--walks",
    "--nsteps",
    "--nlive",
    "--dlogz",
    "--sampler_likelihood_check",
)
parser = create_parser(*_CLI_OPTIONS)

opts = parser.parse_args()

# Echo the full run configuration so it appears in the job log.
print(opts)

# Analysis method: 'sampling' (bilby sampler) or 'gridded' (likelihood
# evaluated on a parameter grid).  Validate the value up front: an unknown
# method used to be accepted, and the script then loaded every input file,
# built the likelihood and exited without doing any analysis.
if opts.method is None:
    parser.error("Missing method. Choose 'sampling' or 'gridded'.")
method = str(opts.method)
if method not in ('sampling', 'gridded'):
    parser.error(f"Unrecognised method '{method}'. Choose 'sampling' or 'gridded'.")

# Path to the LOS catalogue file, forwarded unchanged to the likelihood.
if opts.LOS_catalog is None:
    parser.error('Missing LOS_catalog')
LOS_catalog_path = str(opts.LOS_catalog)

# The posterior-samples option must point at a .json index file mapping each
# GW event name to its PE file, samples field and skymap (example format in
# the error message below).
if opts.posterior_samples is None:
    parser.error("Missing posterior samples")

posterior_samples_json = str(opts.posterior_samples)
if not posterior_samples_json.endswith('.json'):
    err_str = """Missing posterior samples. Expecting a json file with format:\n
        {
           "GW150914_095045":
              {
	         "posterior_file_path" : "/path/to/GW150914.h5",
	         "samples_field" : "C01:Mixed",
	         "skymap_path" : "/path/to/GW150914_skymap.fits"
              }
        }
        """
    parser.error(err_str)

with open(posterior_samples_json) as json_file:
    posterior_samples_dictionary = json.load(json_file)


# Validate every GW event entry: it must name both a skymap and a PE file,
# and every key containing "_path" must point at an existing file.  Missing
# files are collected over all events and reported together before exiting.
print("\nGW events parameters are:\n")
missing_files = {}
for evt_k, evt_v in posterior_samples_dictionary.items():
    print(f"--- {evt_k}:---")
    if "skymap_path" not in evt_v:
        raise ValueError(f"GW event {evt_k} has not skymap. Exiting.")
    if "posterior_file_path" not in evt_v:
        raise ValueError(f"GW event {evt_k} has not PE samples. Exiting.")
    for k, v in evt_v.items():
        if "_path" not in k:  # only keys that name a file path are checked
            print("           Paths check: skip key {}".format(k))
            continue
        file_exists = os.path.exists(v)
        if not file_exists:
            missing_files.setdefault(evt_k, {})[k] = v
        add_txt = "FILE OK" if file_exists else "DOES NOT EXIST!"
        print(f"           {k} -> {v}: {add_txt}")
    print("\n")

if missing_files:
    print("PE or skymaps files are missing:")
    for evt_k, evt_v in missing_files.items():
        print(f"--- {evt_k}:---")
        for k, v in evt_v.items():
            print(f"      missing {k} -> {v}")
    print("Exiting.")
    # Non-zero status so pipelines / batch schedulers register the failure;
    # a bare sys.exit() reports success (status 0).
    sys.exit(1)


# JSON parameter file: per parameter, a 'value' that is either a scalar
# (held fixed) or a list giving the prior/grid range, plus prior settings.
if opts.parameter_dict is None:
    parser.error('Missing parameters file')
with open(str(opts.parameter_dict)) as parameter_file:
    parameter_dict = json.load(parameter_file)


# Open the injection file.  If the plain open raises OSError, retry with HDF5
# file locking disabled (presumably for file systems that do not support
# locking — confirm).
try:
    injdata = h5py.File(opts.injections_path, 'r')
except OSError:
    injdata = h5py.File(opts.injections_path, 'r', locking=False)

# `with` guarantees the HDF5 file is closed even if reading fails or
# parser.error() exits part-way through (previously the handle leaked).
with injdata:
    m1d = np.array(injdata['m1d'])
    if 'ifar' in injdata:
        ifar = np.array(injdata['ifar'])
    else:
        # No IFAR stored: give every injection the default IFAR value
        # (array shaped like m1d).
        ifar = default_ifar_value + 0 * m1d

    # At least one detection threshold is required; an unset one is set to 0.
    if opts.snr is None and opts.ifar is None:
        parser.error("You must specify threshold values for SNR and/or IFAR for the analysis.")
    if opts.snr is None:
        opts.snr = 0
    if opts.ifar is None:
        opts.ifar = 0
    print("Analysis will be run with cuts on SNR: {} OR IFAR: {}".format(opts.snr,opts.ifar))
    print("For software injections, the IFAR is known: their SNR should be set to -1. Fort semianalytical injections, the SNR is known and the IFAR should be set to -1.")

    injections = injections_at_detector(m1d=m1d,
                                        m2d=np.array(injdata['m2d']),
                                        dl=np.array(injdata['dl']),
                                        prior_vals=np.array(injdata['pini']),
                                        snr_det=np.array(injdata['snr']),
                                        snr_cut=opts.snr,
                                        ifar=ifar,
                                        ifar_cut=opts.ifar,
                                        ntotal=np.array(injdata['ntotal']),
                                        Tobs=np.array(injdata['Tobs']))

# Mass-distribution prior; mass_model_selector receives the parser so it can
# report an invalid model name itself.
mass_model = str(opts.mass_model)
mass_priors = mass_model_selector(mass_model, parser)
print(f'Using the {mass_model} mass model')

# Gravity model name -> cosmology factory.  The lambdas keep name lookup lazy,
# so only the selected cosmology class is ever referenced and instantiated.
gravity_model = str(opts.gravity_model)
_cosmology_factories = {
    'GR': lambda: standard_cosmology(),
    'Xi0_n': lambda: Xi0_n_cosmology(),
    'extra_dimension': lambda: extra_dimension_cosmology(),
    'cM': lambda: cM_cosmology(),
}
if gravity_model not in _cosmology_factories:
    parser.error('Unrecognized gravity model')
cosmo = _cosmology_factories[gravity_model]()
print(f'Using the {gravity_model} gravity model')

# NOTE: the prior-boundary check below is currently disabled.
#check_boundary(cosmo, parameter_dict, injections, mass_priors, gravity_model, mass_model)

# Boolean and string run options.
check_bool = str2bool(opts.sampler_likelihood_check)
plot = str2bool(opts.plot)
# NOTE(review): reweight_posterior_samples is never used in this script (it
# is not passed to the likelihood constructor) — confirm whether it should be.
reweight_posterior_samples = str2bool(opts.reweight_posterior_samples)
outputfile = str(opts.outputfile)
redshift_evolution = str(opts.redshift_evolution)

# Redshift-evolution model handed to the likelihood as ps_z.  Unknown names
# are rejected here; previously ps_z was silently left undefined and the run
# only failed later with an unhelpful NameError.
if redshift_evolution == 'PowerLaw':
    ps_z = gwcosmo.utilities.host_galaxy_merger_relations.RedshiftEvolutionPowerLaw()
elif redshift_evolution == 'Madau':
    ps_z = gwcosmo.utilities.host_galaxy_merger_relations.RedshiftEvolutionMadau()
elif redshift_evolution == 'None':
    ps_z = gwcosmo.utilities.host_galaxy_merger_relations.RedshiftEvolutionConstant()
else:
    parser.error(f"Unrecognized redshift evolution model '{redshift_evolution}'. Choose 'PowerLaw', 'Madau' or 'None'.")
print(f'Assuming a {redshift_evolution} redshift evolution model')

# Sky-pixel settings forwarded to the likelihood constructor.
min_pixels = int(opts.min_pixels)
sky_area = float(opts.sky_area)

# Joint likelihood over all GW events using the pixelated galaxy catalogue.
me = gwcosmo.likelihood.dark_siren_likelihood.PixelatedGalaxyCatalogMultipleEventLikelihood(
    posterior_samples_dictionary,
    injections,
    LOS_catalog_path,
    ps_z,
    cosmo,
    mass_priors,
    min_pixels=min_pixels,
    min_samps_in_pixel=opts.min_samps_in_pixel,
    sky_area=sky_area,
    network_snr_threshold=opts.snr,
    ifar_cut=opts.ifar,
)

# Warn (without stopping) about user parameters the likelihood does not know.
unrecognised_parameters = [name for name in parameter_dict if name not in me.parameters]
for name in unrecognised_parameters:
    print(f'WARNING: The parameter {name} from your parameter dictionary is not recognised by the likelihood module.')


if method == 'sampling':
    # ---- Stochastic sampling with bilby -----------------------------------
    sampler = str(opts.sampler)
    nwalkers = int(opts.nwalkers)
    walks = int(opts.walks)
    nsteps = int(opts.nsteps)
    nlive = int(opts.nlive)
    dlogz = float(opts.dlogz)
    npool = int(opts.npool)

    # Prior name in the parameter file (case-insensitive) -> bilby prior class.
    # Each class is constructed from value[0], value[1].
    prior_classes = {
        'uniform': bilby.core.prior.Uniform,
        'gaussian': bilby.core.prior.Gaussian,
        'loguniform': bilby.core.prior.LogUniform,
    }

    priors = bilby.core.prior.PriorDict()

    for key, settings in parameter_dict.items():
        if not isinstance(settings['value'], list):
            # Scalar value: the parameter is held fixed during sampling.
            priors[key] = float(settings['value'])
            continue
        label = settings.get('label', key)
        prior_name = str(settings['prior']).lower()
        if prior_name not in prior_classes:
            raise ValueError(f"Unrecognised prior settings for parameter {key} (specify 'Uniform', 'Gaussian' or 'Loguniform'.)")
        priors[key] = prior_classes[prior_name](settings['value'][0], settings['value'][1], name=key, latex_label=label)
        if prior_name == 'gaussian':
            print(f'Gaussian prior assigned to parameter {key}')
        elif prior_name == 'loguniform':
            # (this message used to say "Lognnormal", which was wrong)
            print(f'Loguniform prior assigned to parameter {key}')

    # Check prior constraints for the mass distribution, does nothing if not required
    priors = me.mass_priors.sampling_constraint_call(priors)

    sampling_method = "acceptance-walk" # can be "rwalk", "act-walk" (the default). "acceptance-walk" is faster (as of 20230717), see https://lscsoft.docs.ligo.org/bilby/dynesty-guide.html

    print("\nPriors used:")
    for k in priors.keys():
        print("\t{}: {}".format(k,priors[k]))
    print("\n")

    bilbyresult = bilby.run_sampler(
        likelihood=me,
        priors=priors,
        outdir=f"{opts.outputfile}",     # the directory to output the results to
        label=f"{opts.outputfile}",      # a label to apply to the output file
        plot=plot,         # by default this is True and a corner plot will be produced
        sampler=sampler,  # set the name of the sampler
        nlive=nlive,         # add in any keyword arguments used by that sampler
        dlogz=dlogz,
        nwalkers=nwalkers,         # add in any keyword arguments used by that sampler
        walks=walks,
        nsteps=nsteps,
        npool=npool,
        sample=sampling_method,
        allow_multi_valued_likelihood = check_bool
    )

elif method == 'gridded':
    # ---- Likelihood evaluated on a regular parameter grid -----------------
    parameter_grid = {}
    fixed_params = {}

    for key, settings in parameter_dict.items():
        value = settings['value']
        if not isinstance(value, list):
            fixed_params[key] = float(value)
            continue
        if settings['prior'] not in ('Uniform', 'uniform'):
            raise ValueError(f"Unrecognised prior settings for parameter {key} (specify 'Uniform').")
        # Validate the format before printing value[1], so a malformed
        # 1-element list gives this error rather than an IndexError.
        if len(value) not in (2, 3):
            raise ValueError(f"Incorrect formatting for parameter {key}.")
        print(f"Setting a uniform prior on {key} in the range [{value[0]}, {value[1]}]")
        # Optional third entry = number of grid points (default 10).  JSON may
        # encode it as a float (e.g. 20.0), which np.linspace rejects.
        n_points = int(value[2]) if len(value) == 3 else 10
        parameter_grid[key] = np.linspace(value[0], value[1], n_points)

    for key in fixed_params:
        me.parameters[key] = fixed_params[key]
        print(f'Setting parameter {key} to {fixed_params[key]}')

    shape = [len(x) for x in parameter_grid.values()]
    likelihood = np.zeros(shape)

    # Rough estimate assuming 0.25 s per grid point per event.
    est_seconds = 0.25*likelihood.size*len(posterior_samples_dictionary)
    print(f'Estimated runtime for this set of parameters and events is {round(est_seconds,1)} seconds, which is {round(est_seconds/60,1)} minutes or {round(est_seconds/(60*60),1)} hours.')

    # Every grid point (parameter values) and its matching array index; both
    # products iterate in the same order, so values[i] belongs at indices[i].
    names = list(parameter_grid.keys())
    values = list(product(*parameter_grid.values()))
    indices = list(product(*(range(len(grid)) for grid in parameter_grid.values())))

    print(f'Computing the likelihood on a grid over {names}')

    # Constraint check, raises value error if priors not allowed by a constraint
    constraint_grid = np.zeros(shape)
    me.mass_priors.grid_constraint_call(constraint_grid, values, parameter_grid, fixed_params)
    for value, index in zip(tqdm(values), indices):
        for name, param_value in zip(names, value):
            me.parameters[name] = param_value
        # Grid points marked -inf in constraint_grid are skipped (log L = -inf).
        if constraint_grid[index] == -np.inf:
            likelihood[index] = constraint_grid[index]
        else:
            likelihood[index] = me.log_likelihood()

    # Rescale the log-likelihood by its maximum (excluding +inf/NaN) before
    # exponentiating, to avoid overflow.
    max_log_likelihood = np.nanmax(likelihood[likelihood < np.inf])
    if not np.isfinite(max_log_likelihood):
        print('WARNING: no grid point has a finite log-likelihood; the saved likelihood will be NaN.')
    likelihood -= max_log_likelihood
    likelihood = np.exp(likelihood)

    param_values = [parameter_grid[k] for k in names]

    mylist = np.array([names,param_values,likelihood,opts,parameter_dict],dtype=object)
    np.savez(outputfile+'.npz',mylist)

    if plot:
        no_params = len(names)
        labels = [parameter_dict[name].get("label", name) for name in names]

        if no_params == 0:
            # plt.subplots(0, 0) would raise; there is no grid to show.
            print('All parameters are fixed: no likelihood plot produced.')
        else:
            if no_params == 1:
                # Normalise to a probability density on the uniform grid.
                ind_lik_norm = likelihood/np.sum(likelihood)/(param_values[0][1]-param_values[0][0])
                plt.figure(figsize=[4.2,4.2])
                plt.plot(param_values[0],ind_lik_norm)
                plt.xlabel(labels[0],fontsize=16)
                plt.xlim(param_values[0][0],param_values[0][-1])
                plt.ylim(0,1.1*np.max(ind_lik_norm))
                plt.ylabel(f'p({labels[0]})',fontsize=16)
            else:
                # Corner plot: 1D marginals on the diagonal, 2D marginals below.
                fig, ax = plt.subplots(no_params, no_params, figsize=[4*no_params,4*no_params],constrained_layout=True)
                for column in range(no_params):
                    for row in range(no_params):
                        axis = ax[row, column]
                        if column > row:
                            fig.delaxes(axis)
                        elif row == column:
                            # Marginalise over every other parameter axis.
                            sum_axes = tuple(i for i in range(no_params) if i != row)
                            ind_lik = np.sum(likelihood,axis=sum_axes)
                            ind_lik_norm = ind_lik/np.sum(ind_lik)/(param_values[row][1]-param_values[row][0])

                            axis.plot(param_values[row],ind_lik_norm)
                            axis.set_xlim(param_values[row][0],param_values[row][-1])
                            axis.set_ylim(0,1.1*np.max(ind_lik_norm))
                        else:
                            # column < row: the remaining axes are ordered
                            # (column, row), so transpose for contourf's (y, x).
                            sum_axes = tuple(i for i in range(no_params) if i not in (row, column))
                            axis.contourf(param_values[column],param_values[row],np.sum(likelihood,axis=sum_axes).T,20)

                        if column == 0:
                            axis.set_ylabel(labels[row], fontsize=16)
                        if row == no_params-1:
                            axis.set_xlabel(labels[column],fontsize=16)

            plt.savefig(f'{outputfile}.png',dpi=100,bbox_inches='tight')