#!python
"""
Check gravitational-wave posterior sample (PE) files and write their details to a JSON file
Rachel Gray
"""

import gwcosmo
import json
import os
import sys
import hashlib
import re

from gwcosmo.utilities.posterior_utilities import str2bool
from gwcosmo.utilities.arguments import create_parser
from gwcosmo.likelihood.posterior_samples import load_posterior_samples

# Build the command-line parser with only the options this script reads,
# parse the command line and echo the parsed options.
parser = create_parser("--posterior_samples","--outputfile","--check_PE_h5")
opts = parser.parse_args()
print(opts)

# Build posterior_samples_dictionary, keyed by event name, from either a
# single PE file (--check_PE_h5) or a json description (--posterior_samples).
if opts.check_PE_h5 is not None:
    # Look for an event name of the form GWXXXXXX in the given path.
    name_matches = re.findall('GW[1-3][0-9][0-1][0-9][0-3][0-9]',opts.check_PE_h5)
    if name_matches:
        evt_key = name_matches[0]
        if len(name_matches) > 1:
            print("Multiple decoding of GW event name. Setting to first occurence: {} -> {}".format(name_matches,evt_key))
    else:
        # No event name found: fall back to the file's basename.
        evt_key = os.path.basename(opts.check_PE_h5)

    posterior_samples_dictionary = {evt_key: {"posterior_file_path": opts.check_PE_h5}}

elif opts.posterior_samples is None:
    parser.error("Missing posterior samples")

elif str(opts.posterior_samples).endswith('.json'):
    with open(str(opts.posterior_samples)) as json_file:
        posterior_samples_dictionary = json.load(json_file)

else:
    # Anything other than a .json file is rejected with a usage example.
    err_str = """Missing posterior samples. Expecting a json file with format:\n
        {
           "GW150914_095045":
              {
	         "posterior_file_path" : "/path/to/GW150914.h5",
	         "samples_field" : "C01:Mixed",
	         "skymap_path" : "/path/to/GW150914_skymap.fits"
              }
        }
        """
    parser.error(err_str)


def _md5_of_file(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at *path*.

    The file is read in blocks of *chunk_size* bytes so it is never held in
    memory in one piece, and the handle is closed when hashing is done.
    """
    digest = hashlib.md5()
    with open(path, 'rb') as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


# For every event, check each "*_path*" entry: record the files that do not
# exist in missing_files and store the md5 of the existing ones under the key
# 'md5_<key>' of the event's dictionary.
print("\nGW events parameters are:\n")
missing_files = {}
for evt_k, evt_v in posterior_samples_dictionary.items():
    print(f"--- {evt_k}:---")
    if "posterior_file_path" not in evt_v:
        raise ValueError(f"GW event {evt_k} has not PE samples. Exiting.")
    # New entries are collected separately because evt_v cannot be modified
    # while it is being iterated over.
    ndict = {}
    for k, v in evt_v.items():
        if "_path" not in k: # check if it is a key associated with a file path
            print("Paths check: skip key {}".format(k))
            continue
        if not os.path.exists(v):
            missing_files.setdefault(evt_k, {})[k] = v
            print(f"           {k} -> {v}: DOES NOT EXIST!")
            # Do not try to hash a missing file: opening it would raise
            # FileNotFoundError before the missing files can be reported.
            continue
        print(f"           {k} -> {v}: FILE OK")
        ndict['md5_'+k] = _md5_of_file(v)
    print("\n")
    evt_v.update(ndict)

# Report every missing file, grouped by event, then stop the script.
if missing_files:
    report_lines = ["PE or skymaps files are missing:"]
    for event_name, absent in missing_files.items():
        report_lines.append(f"--- {event_name}:---")
        report_lines.extend(
            f"      missing {path_key} -> {path_value}"
            for path_key, path_value in absent.items()
        )
    report_lines.append("Exiting.")
    print("\n".join(report_lines))
    sys.exit()

# Load the posterior samples of every event. The "samples_field" entry is
# dropped for good; the PE prior entries are only hidden during the call to
# load_posterior_samples and put back afterwards.
for evt_name, evt_params in posterior_samples_dictionary.items():
    evt_params.pop("samples_field", None)
    # pop() returns None when the key is absent, in which case nothing is restored
    prior_path = evt_params.pop("PEprior_file_path", None)
    prior_kind = evt_params.pop("PEprior_kind", None)
    try:
        samples = load_posterior_samples(evt_params,choose_default_waveform_for_analysis=False)
    except ValueError as ve:
        # A ValueError is only reported; the remaining events are still processed.
        print(ve)
    if prior_path:
        evt_params["PEprior_file_path"] = prior_path
    if prior_kind:
        evt_params["PEprior_kind"] = prior_kind

    
# Print every entry of every event, indented by the width of the event name.
print("\nPosterior sample file contents:\n")
for evt_name, evt_fields in posterior_samples_dictionary.items():
    print("{}:".format(evt_name))
    # Blank padding as wide as the event name (at least one character)
    indent = " ".ljust(len(str(evt_name)))
    for field_name, field_value in evt_fields.items():
        print("{} -> {}: {}".format(indent,field_name,field_value))

# Choose the output file: "./data.json" when --outputfile was left at its
# default value, otherwise the given name with a ".json" extension enforced.
if opts.outputfile == parser.get_default('outputfile'):
    jsonfile = "./data.json"
else:
    jsonfile = opts.outputfile
    if not jsonfile.endswith('.json'):
        jsonfile += '.json'

# Write mode, not append: appending a second JSON document to an existing
# file would leave a file that json.load() can no longer read.
with open(jsonfile,'w',encoding='utf-8') as f:
    json.dump(posterior_samples_dictionary,f,ensure_ascii=False,indent=4,sort_keys=True)

print("GW event(s) details written in file {}".format(jsonfile))
