#!python

"""
This script computes the line-of-sight redshift prior for a given galaxy catalogue, accounting for catalogue incompleteness.

Rachel Gray
Freija Beirnaert
"""

from gwcosmo.prior.catalog import load_catalog_from_opts
import gwcosmo
import numpy as np
import healpy as hp
import h5py

from gwcosmo.utilities.posterior_utilities import str2bool
from gwcosmo.utilities.arguments import create_parser
from gwcosmo.maps import create_mth_map, create_norm_map
from gwcosmo.likelihood.skymap import ra_dec_from_ipix, ipix_from_ra_dec

import multiprocessing as mp
import threading

import os
import sys
import logging
# Logging setup: the root logger passes INFO and above to both handlers.
# handler_out has no level of its own, so it writes every such record to stdout.
handler_out = logging.StreamHandler(stream=sys.stdout)
# handler_err only emits ERROR and above, so errors appear on stderr as well as stdout.
handler_err = logging.StreamHandler(stream=sys.stderr)
handler_err.setLevel(logging.ERROR)
logging.basicConfig(handlers=[handler_out, handler_err], level = logging.INFO)
logger = logging.getLogger(__name__)

# Sentinel put on the task and result queues to tell the consumers to stop.
POISON_PILL = None

def handle_results(h5file: h5py.File, queue: mp.Queue, npix=1, denom=1., offset=1.):
    """
    Consume results from `queue` and write them to `h5file` until POISON_PILL is received.

    Every queue item other than the poison pill is a tuple
    ``(p_of_z, z_array, pixel_index)``. For each one the dataset named
    ``str(pixel_index)`` is created, holding ``log(p_of_z / denom + offset)``.
    If `npix` differs from 1, the raw ``p_of_z`` arrays are also summed and,
    when the poison pill arrives, the dataset ``"combined_pixels"`` is written
    as ``log(sum / (denom * npix) + offset)``.

    A failure while handling one pixel is logged and the loop carries on, so a
    single bad result does not silently stop the writer and drop every later
    pixel.

    Parameters
    ----------
    h5file : object providing ``create_dataset`` (an open h5py.File)
    queue : queue-like object providing a blocking ``get()``
    npix : number of pixels being processed; 1 disables the combined dataset
    denom : normalisation every ``p_of_z`` is divided by
    offset : constant added before taking the logarithm
    """
    logger.info("Writer thread started.")
    # None means "nothing accumulated yet"; becomes an array on the first result.
    combined_pixels = None

    while True:
        result = queue.get()
        # Identity check: the sentinel is a singleton, results are tuples.
        if result is POISON_PILL:
            logger.info("Writer thread being stopped.")
            if npix != 1 and combined_pixels is not None:
                combined_pixels /= (denom * npix)
                combined_pixels += offset
                combined_pixels = np.log(combined_pixels)
                logger.info("Writing back combined pixels.")
                h5file.create_dataset("combined_pixels", (len(combined_pixels),), dtype='f', data=combined_pixels)
            return

        (p_of_z, z_array, pixel_index) = result

        try:
            # Work on a float array so the arithmetic below is valid for any
            # numeric input dtype.
            p_of_z = np.asarray(p_of_z, dtype=float)

            if npix != 1:
                if combined_pixels is None:
                    # Copy so the accumulator never aliases a single result.
                    combined_pixels = p_of_z.copy()
                else:
                    combined_pixels += p_of_z

            log_p_of_z = np.log(p_of_z / denom + offset)
            logger.info(f"pixel {pixel_index}: writing back results.")
            h5file.create_dataset(
                f"{pixel_index}", (len(log_p_of_z),), dtype='f', data=log_p_of_z)
        except Exception:
            # Keep the writer alive; remaining pixels must still be written.
            logger.exception(f"pixel {pixel_index}: failed to write back results.")

def LOS_mp_thread(in_queue, out_queue, catalog, cosmo, zprior, opts, band, nside, sp, luminosity_prior, luminosity_weights, apply_Kcorr=None, mth=None, galaxy_norm=None, zcut=None, zmax=None, catalog_is_complete=False, min_gals_for_threshold=10):
    """
    Worker loop: compute the LOS redshift prior for every pixel index taken from `in_queue`.

    Pixel indices are read from `in_queue` until POISON_PILL is received. For
    each index a LineOfSightRedshiftPrior is built, its
    ``create_redshift_prior()`` is called, and the tuple
    ``(p_of_z, z_array, pixel_index)`` is put on `out_queue`. An exception
    raised for one pixel is logged and the worker moves on to the next pixel.

    `catalog_is_complete` is forwarded to LineOfSightRedshiftPrior. A string
    value is converted with ``str2bool`` first, so that a textual ``"False"``
    is not treated as truthy; if that conversion fails the flag falls back to
    False. `opts` is accepted but not used by this function.
    """
    if isinstance(catalog_is_complete, str):
        try:
            catalog_is_complete = str2bool(catalog_is_complete)
        except Exception:
            logger.exception(f"Could not interpret catalog_is_complete={catalog_is_complete!r}; assuming False.")
            catalog_is_complete = False

    while True:
        pixel_index = in_queue.get()
        if pixel_index is POISON_PILL:
            break
        try:
            LOS_zprior = gwcosmo.prior.LOS_redshift_prior.LineOfSightRedshiftPrior(
                pixel_index, catalog, band, nside, sp, cosmo, zprior,
                luminosity_prior, luminosity_weights,
                Kcorr=apply_Kcorr, mth=mth, galaxy_norm=galaxy_norm,
                zcut=zcut, zmax=zmax,
                # Honour the caller's flag instead of a hard-coded False.
                catalog_is_complete=catalog_is_complete,
                min_gals_for_threshold=min_gals_for_threshold)
            (p_of_z, z_array) = LOS_zprior.create_redshift_prior()
            res = (p_of_z, z_array, LOS_zprior.pixel_index)
            out_queue.put(res)
        except Exception:
            logger.exception(f"During calculation of pixel {pixel_index} the following error occurred:")

def LOS_empty_catalog(h5file:h5py.File, catalog, cosmo, zprior, opts, band, nside, sp, luminosity_prior, luminosity_weights, apply_Kcorr=None, zcut=None, zmax=None, catalog_is_complete=False, offset=1.):
    """
    Compute and store the LOS redshift prior of an empty pixel.

    A LineOfSightRedshiftPrior is built for pixel 0 with ``mth="inf"`` and no
    galaxy normalisation. Its prior is normalised to unit integral over the
    redshift array, shifted by `offset`, log-transformed and written to
    `h5file` as ``"empty_catalogue"``, alongside the redshift grid as
    ``"z_array"``.

    Returns the normalisation integral, which callers use to divide all other
    LOS redshift priors. If anything fails, the error is logged and the value
    held at that point is returned (1. unless the integral was already
    computed).
    """
    empty_pixel = 0
    norm = 1.
    try:
        prior = gwcosmo.prior.LOS_redshift_prior.LineOfSightRedshiftPrior(
            empty_pixel, catalog, band, nside, sp, cosmo, zprior,
            luminosity_prior, luminosity_weights,
            Kcorr=apply_Kcorr, mth="inf", galaxy_norm=None,
            zcut=zcut, zmax=zmax, catalog_is_complete=catalog_is_complete)
        p_of_z, z_array = prior.create_redshift_prior()

        # Normalise to unit area, then store log(p + offset).
        norm = np.trapz(p_of_z, z_array)
        p_of_z /= norm
        p_of_z += offset
        log_prior = np.log(p_of_z)

        h5file.create_dataset("z_array", (len(z_array),), dtype='f', data=z_array)
        h5file.create_dataset("empty_catalogue", (len(log_prior),), dtype='f', data=log_prior)
    except Exception:
        logger.exception(f"During calculation for pixel {empty_pixel} an error occurred:")
    return norm

def main():
    """
    Command-line entry point: compute the LOS redshift prior and write it to an HDF5 file.

    Steps, in order:
      1. Parse the command-line options and build the cosmology, the Schechter
         parameters, the luminosity prior and the luminosity weighting.
      2. Determine the redshift cut `zcut` and locate (or create) the magnitude
         threshold map and the normalisation map.
      3. Open the output HDF5 file, load the catalogue, and write ``z_array``
         and ``empty_catalogue`` through LOS_empty_catalog.
      4. Start a writer thread (handle_results) and a pool of worker processes
         (LOS_mp_thread) which communicate via manager queues; one dataset is
         written per pixel, plus ``combined_pixels`` when more than one pixel
         is processed.
      5. Store the effective options as the file attribute ``"opts"`` and check
         the output for missing datasets and nan/inf values.

    The output file, the writer thread and the manager are released even when
    an error occurs part-way through.
    """
    parser = create_parser("--Kcorrections", "--zmax", "--galaxy_weighting", "--assume_complete_catalog", "--zcut", "--mth", "--schech_alpha", "--schech_Mstar", "--schech_Mmin", "--schech_Mmax", "--H0", "--Omega_m", "--w0", "--wa", "--nside", "--coarse_nside", "--maps_path", "--catalog", "--catalog_band", "--min_gals_for_threshold", "--pixel_index", "--num_threads", "--offset")
    opts = parser.parse_args()
    logger.info(opts)

    zmax = float(opts.zmax)
    apply_Kcorr = str2bool(opts.Kcorrections)
    band = str(opts.catalog_band)
    nside = int(opts.nside)

    cosmo = gwcosmo.utilities.cosmology.standard_cosmology(float(opts.H0), float(opts.Omega_m), float(opts.w0), float(opts.wa))

    #############################################################
    ################## LUMINOSITY FUNCTIONS #####################
    #############################################################

    logger.info(f"Using the galaxy catalogue {band} band magnitudes.")

    # Unset Schechter options stay None and are handed to SchechterParams as such.
    schech_alpha = float(opts.schech_alpha) if opts.schech_alpha is not None else None
    schech_Mstar = float(opts.schech_Mstar) if opts.schech_Mstar is not None else None
    schech_Mmin = float(opts.schech_Mmin) if opts.schech_Mmin is not None else None
    schech_Mmax = float(opts.schech_Mmax) if opts.schech_Mmax is not None else None

    sp = gwcosmo.utilities.schechter_params.SchechterParams(band, schech_alpha, schech_Mstar, schech_Mmin, schech_Mmax)
    logger.info("Schechter function with parameters: alpha={}, Mstar={}, Mmin={}, Mmax={}, ".format(sp.alpha, sp.Mstar, sp.Mmin, sp.Mmax))

    luminosity_prior = gwcosmo.utilities.luminosity_function.SchechterMagFunction(Mstar_obs=sp.Mstar, alpha=sp.alpha)

    galaxy_weighting = str2bool(opts.galaxy_weighting)
    if galaxy_weighting:
        luminosity_weights = gwcosmo.utilities.host_galaxy_merger_relations.LuminosityWeighting()
        logger.info("Assuming galaxies are luminosity weighted.")
        weights = "luminosity_weighted"
    else:
        luminosity_weights = gwcosmo.utilities.host_galaxy_merger_relations.UniformWeighting()
        logger.info("Assuming galaxies are uniformly weighted.")
        weights = "uniformly_weighted"

    #############################################################
    ########################## MAPS #############################
    #############################################################

    coarse_nside = opts.coarse_nside
    if opts.pixel_index is None:
        coarse_pixel_index = None
    else:
        # Map the requested fine pixel onto the coarse pixel that contains its centre.
        coarse_ra, coarse_dec = ra_dec_from_ipix(nside, opts.pixel_index, True)
        coarse_pixel_index = ipix_from_ra_dec(coarse_nside, coarse_ra, coarse_dec, True)

    # Set redshift limits based on whether K-corrections are applied
    if apply_Kcorr:
        if opts.zcut is None:
            if band == 'W1':
                # Polynomial k corrections out to z=1
                zcut = 1.0
                logger.info(f"Automatically applying a redshift cut at z=1 as K-corrections are not valid above this.")
            else:
                # color-based k corrections valid to z=0.5
                zcut = 0.5
                logger.info(f"Automatically applying a redshift cut at z=0.5 as K-corrections are not valid above this.")
            logger.info(f"Override this by specifying a value for zcut. But do this at your own peril!")
        else:
            zcut = float(opts.zcut)
            # Warn against the limit of the band actually in use: a W1 zcut
            # between 0.5 and 1.0 is inside the valid range.
            if band == 'W1':
                if zcut > 1.0:
                    logger.warning(f"Your requested zcut {zcut} is greater than the valid range (1.0) for W1-band k corrections")
            elif zcut > 0.5:
                logger.warning(f"Your requested zcut {zcut} is greater than the valid range (0.5) for k corrections")
    else:
        # Without K-corrections a user-supplied zcut must still be honoured;
        # otherwise fall back to zmax.
        zcut = zmax if opts.zcut is None else float(opts.zcut)

    if opts.maps_path is None:
        # Directory containing create_norm_map.py. dirname() is used because
        # str.rstrip() strips a set of characters, not a suffix.
        maps_path = os.path.dirname(os.path.abspath(create_norm_map.__file__))
    else:
        maps_path = (opts.maps_path).rstrip("/")

    mth = opts.mth
    if opts.mth is None:
        # Try to use all sky map if it exists
        mth_map_path = f"{maps_path}/mth_map_{opts.catalog}_band{band}_nside{coarse_nside}_pixel_indexNone_mingals{opts.min_gals_for_threshold}.fits"
        if not os.path.exists(mth_map_path):
            # Try to use map for coarse_pixel_index if it exists
            mth_map_path = f"{maps_path}/mth_map_{opts.catalog}_band{band}_nside{coarse_nside}_pixel_index{coarse_pixel_index}_mingals{opts.min_gals_for_threshold}.fits"
            if not os.path.exists(mth_map_path):
                # Create map for coarse_pixel_index, which can be None
                create_mth_map.create_mth_map(mth_map_path, opts.catalog, band, coarse_nside, opts.min_gals_for_threshold, coarse_pixel_index)
        mth = mth_map_path
    logger.info(f"mth: {mth}")

    # Try to use all sky map if it exists
    norm_map_path = f"{maps_path}/norm_map_{opts.catalog}_band{band}_nside{coarse_nside}_pixel_indexNone_zmax{str(zmax).replace('.', ',')}_zcut{str(zcut).replace('.', ',')}_Mmax{str(sp.Mmax).replace('.', ',')}_Kcorr{apply_Kcorr}.fits"
    if not os.path.exists(norm_map_path):
        # Try to use map for coarse_pixel_index if it exists
        norm_map_path = f"{maps_path}/norm_map_{opts.catalog}_band{band}_nside{coarse_nside}_pixel_index{coarse_pixel_index}_zmax{str(zmax).replace('.', ',')}_zcut{str(zcut).replace('.', ',')}_Mmax{str(sp.Mmax).replace('.', ',')}_Kcorr{apply_Kcorr}.fits"
        if not os.path.exists(norm_map_path):
            # Create map for coarse_pixel_index, which can be None
            create_norm_map.create_norm_map(norm_map_path, opts.catalog, band, coarse_nside, coarse_pixel_index, zmax, zcut, mth, cosmo, sp.Mmax, apply_Kcorr)
    galaxy_norm = norm_map_path
    logger.info(f"norm map path: {norm_map_path}")

    #############################################################
    ##################### MAIN FUNCTIONS ########################
    #############################################################

    logger.info(f"npixels_tot =  {hp.nside2npix(opts.nside)}")
    if opts.pixel_index is None:
        npix = hp.nside2npix(opts.nside)
        pixel_indices = np.arange(0, npix)
    else:
        npix = 1
        pixel_indices = np.array([int(opts.pixel_index)])

    offset = opts.offset

    output_path = f"{opts.catalog}_LOS_redshift_prior_{band}_band_{weights}_nside_{opts.nside}_pixel_index_{opts.pixel_index}.hdf5"

    # The context manager closes the file on every exit path.
    with h5py.File(output_path, "w") as f1:
        # Precompute redshift prior
        zprior = cosmo.p_z

        catalog = load_catalog_from_opts(opts)

        # Compute z_array and z_prior for empty_catalogue. The offset is passed
        # on so that empty_catalogue uses the same offset as the pixel datasets.
        # NOTE(review): opts.assume_complete_catalog is forwarded unparsed,
        # unlike the other boolean options which go through str2bool — confirm
        # its type against the argument parser.
        logger.info("Creating and writing z_array and empty catalogue z_prior")
        denom = LOS_empty_catalog(f1, catalog, cosmo, zprior, opts, band, nside, sp, luminosity_prior, luminosity_weights, apply_Kcorr, zcut, zmax, opts.assume_complete_catalog, offset=offset)

        # Create queues for multiprocessing message passing
        mgr = mp.Manager()
        try:
            task_queue = mgr.Queue()
            results_queue = mgr.Queue()

            logger.info("Starting worker threads")
            logger.info(npix)
            n_threads = min(npix, opts.num_threads)
            logger.info("Starting writer thread.")
            writer_thread = threading.Thread(target=handle_results, args=(f1, results_queue), kwargs={"npix": npix, "denom": denom, "offset": offset})
            writer_thread.start()
            try:
                with mp.Pool(n_threads) as p:
                    # Create the mp threads
                    logger.info(f"Launching {n_threads} worker threads")
                    args = [(task_queue, results_queue, catalog, cosmo, zprior, opts, band, nside, sp, luminosity_prior, luminosity_weights, apply_Kcorr, mth, galaxy_norm, zcut, zmax, opts.assume_complete_catalog, opts.min_gals_for_threshold) for _ in range(n_threads)]

                    async_result = p.starmap_async(LOS_mp_thread, args)
                    for pixel_index in pixel_indices:
                        task_queue.put(pixel_index)

                    logger.info(f"Awaiting end of all redshift_prior calculations.")
                    # One poison pill per worker actually started.
                    for _ in range(n_threads):
                        task_queue.put(POISON_PILL)

                    p.close()
                    p.join()

                    # All tasks have finished; report an error that escaped a
                    # worker instead of discarding it with the async result.
                    try:
                        async_result.get()
                    except Exception:
                        logger.exception("A worker terminated with an error:")
            finally:
                # Always stop the writer, otherwise its blocking get() would
                # keep the process alive after a failure.
                logger.info("Stopping writer thread")
                results_queue.put(POISON_PILL)

                # Wait until everything is written back
                writer_thread.join()
                logger.info(f"Writer thread stopped.")
        finally:
            mgr.shutdown()

        # Record the values actually used, not only those given on the command line.
        opts.zcut = zcut
        opts.schech_alpha = sp.alpha
        opts.schech_Mstar = sp.Mstar
        opts.schech_Mmin = sp.Mmin
        opts.schech_Mmax = sp.Mmax
        opts.mth = mth
        # np.bytes_ is the same type under its NumPy 2 name (np.string_ was removed).
        opts_string = np.bytes_(vars(opts))
        f1.attrs["opts"] = opts_string
        logger.info(f"opts_string: {opts_string}")

        #############################################################
        ###################### CHECK OUTPUT #########################
        #############################################################
        logger.info(f"Checking output file")

        npix_out = sum(1 for key in f1.keys() if key.isdigit())
        if npix_out != npix:
            logger.warning(f"Number of pixels in output file is {npix_out} which doesn't correspond to expected {npix}")
        arr_names = ["z_array", "empty_catalogue"]
        if npix != 1:
            arr_names += ["combined_pixels"]
        arr_names += list(pixel_indices)
        for name in arr_names:
            try:
                arr = f1[str(name)][()]
            except KeyError:
                # A missing dataset raises KeyError on lookup.
                logger.warning(f"Output file doesn't contain {name}")
                continue
            if np.isnan(arr).any():
                logger.warning(f"{name} contains nan values")
            if np.isinf(arr).any():
                logger.warning(f"{name} contains inf values")

# Run the computation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
