#!python

# -*- coding: utf-8 -*-
# Copyright (C) Arianna I. Renzini 2024
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import distutils
import json
import os
from pathlib import Path

import bilby
import numpy as np
import tqdm
from bilby.core.prior import Interped
from bilby.core.utils import infer_args_from_function_except_n_args
from gwpopulation.models.mass import SinglePeakSmoothedMassDistribution
from gwpopulation.models.redshift import MadauDickinsonRedshift
from gwpopulation.utils import xp

from popstock.PopulationOmegaGW import PopulationOmegaGW

"""
Command-line arguments and run settings.
"""

def _str_to_bool(value):
    """Convert a command-line truth string into a ``bool``.

    Accepts the same case-insensitive spellings as the former
    ``distutils.util.strtobool``: ``y``, ``yes``, ``t``, ``true``, ``on``,
    ``1`` map to True; ``n``, ``no``, ``f``, ``false``, ``off``, ``0`` map to
    False.

    Raises
    ------
    ValueError
        If ``value`` is not one of the recognised spellings.
    """
    lowered = value.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError(f"invalid truth value {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument(
    '-ns', '--number_samples',
    help="number of samples.",
    type=int,
    default=None,
)
parser.add_argument(
    '-nt', '--number_trials',
    help="number of population trials.",
    type=int,
    default=10000,
)
parser.add_argument(
    '-wf', '--waveform_approximant',
    help="Wavefrom approximant. Default is IMRPhenomD.",
    type=str,
    default='IMRPhenomD',
)
parser.add_argument(
    '-rd', '--run_directory',
    help="Run directory.",
    type=str,
    default='./',
)
parser.add_argument(
    '-sm', '--samples',
    help="Samples to use.",
    type=str,
    default=None,
)
parser.add_argument(
    '-sid', '--sample_index',
    help="Index of first sample to include in the cacluation. Samples will be considered in order after this sample. Only considered if both number of samples and sample set are also provided. Default is 0.",
    type=int,
    default=0,
)
parser.add_argument(
    '-fr', '--frequencies',
    help="txt file with frequencies to use for the spectrum calculation.",
    type=str,
    default=None,
)
parser.add_argument(
    '-mp', '--multiprocess',
    help="If True (default), pools the DEdF calculation.",
    type=str,
    default='True',
)
args = parser.parse_args()

# The flag arrives as a string; an empty string falls back to the default
# (True). The conversion is done locally instead of through
# ``distutils.util.strtobool``: a bare ``import distutils`` does not load the
# ``util`` submodule, and distutils was removed from the standard library in
# Python 3.12.
args.multiprocess = _str_to_bool(args.multiprocess) if args.multiprocess else True

# Short aliases used throughout the rest of the script.
N_trials = args.number_trials
wave_approx = args.waveform_approximant
rundir = Path(args.run_directory)

"""
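Build the population model, compute the fiducial Omega_GW spectrum, then recompute it for each posterior hyper-parameter sample.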
"""

# Population models handed to PopulationOmegaGW: a mass distribution and a
# redshift distribution with z_max=10.
mass_obj = SinglePeakSmoothedMassDistribution()
redshift_obj = MadauDickinsonRedshift(z_max=10)

models = {
    'mass_model': mass_obj,
    'redshift_model': redshift_obj,
}

# Frequency grid: either the built-in default (10 up to, but excluding, 2000
# in steps of 2.5) or the values read from the user-supplied text file.
if args.frequencies is None:
    freqs = np.arange(10, 2000, 2.5)
else:
    try:
        freqs = np.loadtxt(args.frequencies)
    except ValueError as err:
        # Re-raise with a friendlier message, keeping the original parsing
        # error attached as the explicit cause for debugging.
        raise ValueError(f"{args.frequencies} is not a txt file.") from err

newpop = PopulationOmegaGW(models=models, frequency_array=freqs)

# Proposal samples come either from a user-supplied JSON file or are drawn
# from a hard-coded set of fiducial hyper-parameters.
using_sample_file = args.samples is not None

if using_sample_file:
    with open(args.samples) as handle:
        samples_dict = json.load(handle)
    # The hyper-parameters are stored under 'Lambda_0' in the same file;
    # take them out so only the remaining entries are passed on as samples.
    Lambda_0 = samples_dict.pop('Lambda_0')

    if args.number_samples is None:
        # No count requested: use every sample in the file.
        args.number_samples = len(samples_dict['redshift'])
        tag = f'{wave_approx}_{args.number_samples}_samples_{N_trials}_trials'
    else:
        # Keep a contiguous window of samples starting at sample_index.
        window = slice(args.sample_index, args.sample_index + args.number_samples)
        for key, values in list(samples_dict.items()):
            samples_dict[key] = values[window]
        tag = f'{wave_approx}_{args.sample_index}_{args.number_samples}_samples_{N_trials}_trials'

    newpop.set_proposal_samples(proposal_samples=samples_dict)
    print(f'Using {args.number_samples} samples...')
else:
    Lambda_0 = {
        'alpha': 2.5,
        'beta': 1,
        'delta_m': 3,
        'lam': 0.04,
        'mmax': 100,
        'mmin': 4,
        'mpp': 33,
        'sigpp': 5,
        'gamma': 2.7,
        'kappa': 3,
        'z_peak': 1.9,
        'rate': 15,
    }
    newpop.draw_and_set_proposal_samples(Lambda_0, N_proposal_samples=args.number_samples)
    tag = f'{wave_approx}_{args.number_samples}_samples_{N_trials}_trials'

# Fiducial spectrum for Lambda_0.
newpop.calculate_omega_gw(
    waveform_approximant=wave_approx,
    Lambda=Lambda_0,
    multiprocess=args.multiprocess,
)

# Save the fiducial result. When the samples were drawn here (rather than
# loaded from a file) they are stored too, together with newpop.pdraws.
fiducial_path = os.path.join(rundir, f'omegagw_0_{tag}.npz')
if using_sample_file:
    np.savez(
        fiducial_path,
        omega_gw=newpop.omega_gw,
        freqs=newpop.frequency_array,
        Lambda_0=Lambda_0,
    )
else:
    np.savez(
        fiducial_path,
        omega_gw=newpop.omega_gw,
        freqs=newpop.frequency_array,
        fiducial_samples=newpop.proposal_samples,
        Lambda_0=Lambda_0,
        draw_dict=newpop.pdraws,
    )

# Per-trial results: the hyper-parameters used, the effective number of
# samples computed from the weights, and the resulting spectrum.
new_omegas = {
    'Lambdas': [],
    'Neff': [],
    'omega_gw': [],
}

result = bilby.core.result.read_in_result(filename='../test_data/o1o2o3_mass_c_iid_mag_iid_tilt_powerlaw_redshift_result.json')
# NOTE(review): DataFrame.sample draws without replacement by default, so
# this presumably fails when N_trials exceeds the number of posterior rows
# -- confirm against the result file in use.
lambda_samples = result.posterior.sample(N_trials).to_dict('list')

# Hyper-parameters taken from the posterior draw for each trial.
sampled_keys = (
    'alpha',
    'beta',
    'delta_m',
    'lam',
    'mmax',
    'mmin',
    'mpp',
    'sigpp',
    'rate',
)
# Redshift hyper-parameters held fixed across all trials instead of being
# read from the posterior.
fixed_redshift_params = {
    'gamma': 2.7,
    'kappa': 5.6,
    'z_peak': 1.9,
}

print('Running trials...')
for idx in tqdm.tqdm(range(N_trials)):
    Lambda_new = {key: lambda_samples[key][idx] for key in sampled_keys}
    Lambda_new.update(fixed_redshift_params)
    new_omegas['Lambdas'].append(Lambda_new)

    newpop.calculate_omega_gw(sampling_frequency=2048, Lambda=Lambda_new)
    # Effective sample size: (sum of weights)^2 / sum of squared weights.
    weights = newpop.weights
    new_omegas['Neff'].append(float((xp.sum(weights)**2) / (xp.sum(weights**2))))
    new_omegas['omega_gw'].append(newpop.omega_gw.tolist())

new_omegas['freqs'] = newpop.frequency_array.tolist()

# Serialise first so a serialisation error does not leave a truncated file,
# then write inside a context manager so the handle is always closed.
omegas_dict = json.dumps(new_omegas)
with open(os.path.join(rundir, f'new_omegas_{tag}.json'), "w") as f:
    f.write(omegas_dict)

print('Done!')

# Explicit clean termination without relying on the ``exit`` helper that the
# ``site`` module injects, which is not guaranteed to be available.
raise SystemExit
