#!/usr/bin/env python
## -*- coding: utf-8 -*-

"""Script to test DQR scripts
"""

import gwdetchar


import numpy as np
import os, sys, shutil
import bilby
import numpy as np
import subprocess, argparse
import tempfile

from lalsimulation import SimInspiralTransformPrecessingWvf2PE

from pycbc.filter import matched_filter, match, overlap, sigma, sigmasq
from pycbc.detector import Detector

import h5py as h5

from gwosc.datasets import event_gps
from gwpy.timeseries import TimeSeries
from gwpy.frequencyseries import FrequencySeries
from pycbc.filter import resample_to_delta_t, highpass
from pycbc.psd import interpolate, inverse_spectrum_truncation

from gwpy.table import EventTable

# Echo the command line that invoked this script.
print(list(sys.argv))

def convert_to_bilby_spins(inj_params, f_ref):
    """Add bilby's precessing-spin parameters to an injection dictionary.

    The inclination and Cartesian spin components in ``inj_params`` are
    transformed with LALSimulation's ``SimInspiralTransformPrecessingWvf2PE``
    into the angle/magnitude parametrisation taken by
    ``bilby.gw.source.lal_binary_black_hole``.

    Parameters
    ----------
    inj_params : dict
        Must contain ``inclination``, ``spin_1x``, ``spin_1y``, ``spin_1z``,
        ``spin_2x``, ``spin_2y``, ``spin_2z``, ``mass_1``, ``mass_2`` and
        ``phase``. It is updated in place.
    f_ref : float
        Reference frequency passed to the LAL transformation.

    Returns
    -------
    dict
        The same ``inj_params`` object, with ``theta_jn``, ``phi_jl``,
        ``tilt_1``, ``tilt_2``, ``phi_12``, ``a_1`` and ``a_2`` added.
    """
    args_list = bilby.gw.utils.convert_args_list_to_float(
        inj_params['inclination'],
        inj_params['spin_1x'], inj_params['spin_1y'], inj_params['spin_1z'],
        inj_params['spin_2x'], inj_params['spin_2y'], inj_params['spin_2z'],
        inj_params['mass_1'], inj_params['mass_2'], f_ref,
        inj_params['phase'])

    (theta_jn, phi_jl, tilt_1, tilt_2,
     phi_12, a_1, a_2) = SimInspiralTransformPrecessingWvf2PE(*args_list)

    # Use the key names bilby's source model and parameter conversion look up.
    # The spin magnitudes must be stored as a_1/a_2, not chi_1/chi_2: bilby
    # interprets chi_1/chi_2 as aligned-spin components and would then reset
    # the tilts and in-plane angles computed above.
    inj_params.update(
        theta_jn=theta_jn, phi_jl=phi_jl,
        tilt_1=tilt_1, tilt_2=tilt_2, phi_12=phi_12,
        a_1=a_1, a_2=a_2)
    return inj_params

parser = argparse.ArgumentParser(
    description='Inject a simulated signal into Gaussian noise coloured by an '
                'estimated PSD, write it to a frame file and run a command on it.')
# Every argument except --psd-file and --no-injections is used
# unconditionally below, so fail fast if one is missing.
parser.add_argument('--exc-args', type=str, required=True,
                    help='command to execute: either a whitespace-separated '
                         'string or a path to a text file of arguments. '
                         '{inj_num}, {output_dir}, {ifo}, {source}, {channel} '
                         'and any injection-table column name in braces are '
                         'substituted')
parser.add_argument('--injection-file', type=os.path.abspath, required=True,
                    help="HDF5 file with 'found' and 'injections' groups")
parser.add_argument('--injection-num', type=int, required=True,
                    help='row index of the injection to use, counted after '
                         'the IFAR cut')
parser.add_argument('--ifar-cut', type=float, required=True,
                    help='keep only found injections with ifar_exc above '
                         'this value')
parser.add_argument('--psd-file', type=os.path.abspath,
                    help='currently unused; the PSD is estimated from open '
                         'data around GW190814')
parser.add_argument('--ifo', type=str, required=True,
                    help='detector prefix, e.g. H1 or L1')
parser.add_argument('--output-dir', type=os.path.abspath, required=True,
                    help='output directory (created if missing); substituted '
                         'for {output_dir}')
parser.add_argument('--no-injections', action='store_true',
                    help='set the luminosity distance to 100000 so the '
                         'injected signal is effectively absent')
args = parser.parse_args()
# makedirs also creates missing parents and tolerates an existing directory.
os.makedirs(args.output_dir, exist_ok=True)


ifo = args.ifo

# Open strain around GW190814 is used only to estimate a noise PSD for this
# detector; the noise handed to the command is simulated from that PSD.
gps = event_gps('GW190814')
half_width = 64 + 2  # 64 s each side, plus 2 s cropped after filtering
segment = (int(gps) - half_width, int(gps) + half_width)

strain = TimeSeries.fetch_open_data(ifo, *segment, verbose=True, cache=True)
# Downsample to 2048 Hz and high-pass at 15 Hz (FIR filter of order 512).
ts = strain.resample(2048).to_pycbc().highpass_fir(15, 512)
data_raw = resample_to_delta_t(ts, 1.0 / 2048).crop(2, 2)

# PSD from 4 s segments, interpolated onto the data's frequency resolution,
# then truncated to a 4 s inverse-spectrum filter above 5 Hz.
psd = interpolate(data_raw.psd(4), data_raw.delta_f)
psd = inverse_spectrum_truncation(
    psd, int(4 * data_raw.sample_rate), low_frequency_cutoff=5.0)

# Length and sampling rate of the simulated noise segment (passed to bilby
# as duration and sampling_frequency).
ran_duration = 256
sample_rate = 2048

inj_dict = {}

inj_out_loc = args.injection_file
# Collect the found-injection columns and the matching injection parameters;
# the context manager closes the file even if a key is missing.
with h5.File(inj_out_loc, 'r') as out_file:
    for k, obj in out_file['found'].items():
        if isinstance(obj, h5.Dataset):
            inj_dict[k] = obj[:]
        else:
            # Sub-groups (presumably one per detector) contribute only their
            # 'time' dataset, stored under '<group>_time'.
            inj_dict[k + '_time'] = obj['time'][:]
    # Select the injection parameters of the found injections, so that row i
    # of every column describes the same found injection.
    found_index = inj_dict['injection_index']
    for k, obj in out_file['injections'].items():
        inj_dict[k] = obj[:][found_index]

inj_table = EventTable(data=list(inj_dict.values()), names=list(inj_dict))

# Keep only injections found above the requested IFAR threshold.
inj_table = inj_table.filter('ifar_exc > %.4f' % float(args.ifar_cut))

# Row index into the filtered table; also used later to seed the noise.
inj_num = args.injection_num
inj = inj_table[inj_num]

# Geocentric end time in GPS seconds (integer seconds + nanoseconds).
gps = inj['geocent_end_time'] + inj['geocent_end_time_ns'] * 1e-9

dist = inj['distance']
if args.no_injections:
    # Presumably far enough away that the injected signal is negligible.
    dist = 100000

# Source parameters of the selected injection. The Cartesian spins and the
# inclination are turned into bilby's spin angles by convert_to_bilby_spins
# at injection time.
injection_parameters = {
    'mass_1': inj['mass1'],
    'mass_2': inj['mass2'],
    'spin_1x': inj['spin1x'],
    'spin_1y': inj['spin1y'],
    'spin_1z': inj['spin1z'],
    'spin_2x': inj['spin2x'],
    'spin_2y': inj['spin2y'],
    'spin_2z': inj['spin2z'],
    'luminosity_distance': dist,
    'inclination': inj['inclination'],
    'psi': inj['polarization'],
    'phase': inj['coa_phase'],
    'geocent_time': gps,
    'ra': inj['ra'],
    'dec': inj['dec'],
}

reference_frequency = 50.
waveform_arguments = {
    'waveform_approximant': 'IMRPhenomPv2',
    'reference_frequency': reference_frequency,
    'minimum_frequency': 10.,
}

# Frequency-domain BBH model; parameters go through bilby's LAL conversion
# before the waveform is evaluated.
waveform_generator = bilby.gw.WaveformGenerator(
    duration=ran_duration,
    sampling_frequency=sample_rate,
    frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
    waveform_arguments=waveform_arguments,
)



# =======

tmpdir = tempfile.mkdtemp()

# Wrap the estimated PSD so bilby can draw a Gaussian noise realisation from it.
bilby_psd = bilby.gw.detector.PowerSpectralDensity.from_power_spectral_density_array(
    psd_array=psd,
    frequency_array=psd.sample_frequencies)
# Seed with the injection index so each injection gets a repeatable noise draw.
# NOTE(review): newer bilby releases may draw noise from their own generator
# rather than numpy's global state; confirm this seed still controls the
# realisation for the installed bilby version.
np.random.seed(int(inj_num))
data_ran_f, _ = bilby_psd.get_noise_realisation(sample_rate, ran_duration)
data_ran_t = bilby.gw.detector.InterferometerStrainData()
data_ran_t.set_from_frequency_domain_strain(
    frequency_domain_strain=data_ran_f,
    sampling_frequency=sample_rate,
    duration=ran_duration)

data = data_ran_t.to_gwpy_timeseries()
# Start the segment so the injection time falls about 3/4 of the way through.
data.t0 = int(gps - ran_duration * 0.75)

det = bilby.gw.detector.get_empty_interferometer(ifo)
det.strain_data.set_from_gwpy_timeseries(data)
det.inject_signal(parameters=convert_to_bilby_spins(injection_parameters, reference_frequency),
                  waveform_generator=waveform_generator)
data = det.strain_data.to_gwpy_timeseries()
channel_name = args.ifo + ":DATA"
data.name = channel_name

# Name the frame file after the detector actually being simulated.
data_loc = os.path.join(tmpdir, '%s-data.gwf' % ifo)
data.write(data_loc)

exc_args = args.exc_args
if os.path.isfile(exc_args):
    # atleast_1d keeps a single-token file from collapsing to a 0-d array,
    # which list() cannot iterate.
    exc_args = list(np.atleast_1d(np.loadtxt(exc_args, unpack=True, dtype=str)))
else:
    exc_args = exc_args.split()

# Placeholder -> value, applied in order: every injection-table column
# ("{mass1}", "{spin1z}", ...) first, then the run-specific values.
substitutions = [('{%s}' % k, inj[k]) for k in inj.keys()]
substitutions += [
    ('{inj_num}', inj_num),
    ('{output_dir}', args.output_dir),
    ('{ifo}', args.ifo),
    ('{source}', data_loc),
    ('{channel}', channel_name),
]

for i, arg in enumerate(exc_args):
    # Substitute cumulatively so an argument holding several placeholders
    # gets all of them replaced; str() because column values are numeric.
    for placeholder, value in substitutions:
        if placeholder in arg:
            print('Replacing %s with' % placeholder, value)
            arg = arg.replace(placeholder, str(value))
    exc_args[i] = arg

print(exc_args)

# check_call lets the command's output through instead of capturing and
# discarding it; a non-zero exit status still raises CalledProcessError.
# tmpdir is intentionally left in place: the command may hand the frame file
# to jobs that run after it returns.
subprocess.check_call(exc_args)





              


