#!/usr/bin/env python3
import argparse
import ast
from pe_automator.workflows.setup_run import setup_run
from pe_automator.workflows.rescue_run import rescue_run
from pe_automator.workflows.monitor_runs import monitor_runs
from pe_automator.workflows.setup_env import setup_env as setup_env_workflow
from pe_automator.workflows.post_process import post_process
from pe_automator.workflows.dlogz_plot import dlogz_plot
from pe_automator.utils.version import is_version_newer
from pe_automator.workflows.run_webpage import run_webpage

# CLI entry point. Subcommands: setup, monitor, setup_env, rescue, post_process, dlogz, run_webpage
def main():
    parser = argparse.ArgumentParser(description="PE Automator CLI")
    subparsers = parser.add_subparsers(dest='command', required=True)

    # Setup command
    # real data:       pe_automator setup GW150914 --run_label run1 --approximant IMRPhenomXPNR ...
    # gaussian noise:  pe_automator setup --mode gaussian_noise --inj_set inj_set1 --noise_seeds 1000 2000 --run_label run1 ...
    # real noise:      pe_automator setup --mode real_noise --inj_set inj_set1 --run_label run1 ...
    setup_parser = subparsers.add_parser('setup', help='Set up a PE run (real data, Gaussian-noise or real-noise injection)')
    setup_parser.add_argument('name', type=str,
                              help='Event name (real_data mode, e.g. GW150914) or injection set name (gaussian_noise/real_noise mode, e.g. inj_set1)')
    setup_parser.add_argument('--mode', type=str, default='real_data',
                              choices=['real_data', 'gaussian_noise', 'real_noise'],
                              help='Run mode: real_data (default), gaussian_noise, or real_noise')
    setup_parser.add_argument('--data_path', type=str, default='./data', help='Path to the data directory')
    setup_parser.add_argument('--run_label', type=str, required=True, default='run1', help='Label for the run')
    setup_parser.add_argument('--account', type=str, help='Slurm account name')
    setup_parser.add_argument('--partition', type=str, help='Slurm partition name')
    setup_parser.add_argument('--memory', type=int, default=300, help='Memory in GB for the job, default is 300')
    setup_parser.add_argument('--walltime', type=str, default='71:40:00', help='Walltime for the job, default is 71:40:00')
    setup_parser.add_argument('--cpu', type=int, help='Number of CPUs for the job, default is the number of CPUs per node in the allocation')
    setup_parser.add_argument('--qos', type=str, help='Slurm Quality of Service')
    setup_parser.add_argument('--approximant', type=str, required=True, help='Gravitational wave approximant to use')
    setup_parser.add_argument('--user', type=str, required=True, help='Username for the remote server')
    setup_parser.add_argument('--conda_env', type=str, required=True, help='Conda environment to use on the remote server')
    setup_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    setup_parser.add_argument('--allocation', type=str, required=True, help='Allocation name for the job, e.g., AECT-2025-2-0029')
    setup_parser.add_argument('--npoint', type=int, default=1000, help='npoint for bilby sampler, default is 1000')
    setup_parser.add_argument('--nact', type=int, default=50, help='nact for bilby sampler, default is 50')
    setup_parser.add_argument('--naccept', type=int, default=60, help='naccept for bilby sampler, default is 60')
    setup_parser.add_argument('--maxmcmc', type=int, default=20000, help='maxmcmc for bilby sampler, default is 20000')
    setup_parser.add_argument('--priors', type=str, help='Path to the priors file (optional)', default=None)
    setup_parser.add_argument('--mode_array', type=str, help='Mode array for the run (optional)', default=None)
    setup_parser.add_argument('--parallel_bilby', action='store_true', default=False, help='If set, use parallel bilby for the run')
    setup_parser.add_argument('--wf_min_f', type=float, default=None, help='Minimum frequency for the waveform generator')
    setup_parser.add_argument('--wf_ref_f', type=float, default=None, help='Reference frequency for the waveform generator')
    setup_parser.add_argument('--min_f', type=float, default=None, help='Minimum frequency for the analysis')
    setup_parser.add_argument('--dlogz', type=float, default=0.10, help='dlogz threshold to stop sampling, default is 0.10')
    setup_parser.add_argument('--dry_run', action='store_true', help='If set, data will not be pushed to the remote server')
    setup_parser.add_argument('--waveform_arguments_dict', type=ast.literal_eval, help='Waveform argument dictionary for the run (optional)', default=None)
    setup_parser.add_argument('--comment', type=str, default='', help='Comment for the run, e.g., will be added to the GitLab issue')
    # injection-specific options
    setup_parser.add_argument('--flow', type=float, default=20.0,
                              help='Low-frequency cutoff for noise/waveform generation (injection modes, default 20 Hz)')
    setup_parser.add_argument('--f_ref', type=float, default=20.0,
                              help='Reference frequency for waveform generation (injection modes, default 20 Hz)')
    setup_parser.add_argument('--fetch_buffer', type=int, default=16,
                              help='Extra seconds fetched around the injection window (real_noise mode, default 16)')
    setup_parser.add_argument('--force_regenerate', action='store_true', default=False,
                              help='Re-generate GWF files even if they already exist (injection modes)')
    
    group = setup_parser.add_mutually_exclusive_group()
    group.add_argument('--distance_marginalization', dest='distance_marginalization', action='store_true', help='If set, enable distance marginalization')
    group.add_argument('--no-distance_marginalization', dest='distance_marginalization', action='store_false', help='If set, disable distance marginalization')
    
    encrypt_group = setup_parser.add_mutually_exclusive_group()
    encrypt_group.add_argument('--encrypt', dest='encrypt', action='store_true', help='Enable encryption')
    encrypt_group.add_argument('--no-encrypt', dest='encrypt', action='store_false', help='Disable encryption')
    
    setup_parser.set_defaults(distance_marginalization=None, encrypt=None) # Set to encrypt=None so we can detect if the user didn't choose
    setup_parser.set_defaults(func=setup)

    # Monitor command
    monitor_parser = subparsers.add_parser('monitor', help='Monitor the PE Automator processes')
    monitor_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    monitor_parser.add_argument('--ssh_key', type=str, required=True, help='Path to the SSH key for remote access')
    monitor_parser.add_argument('--data_path', type=str, default='.', help='Path to the data directory')
    monitor_parser.set_defaults(func=monitor)

    # Setup environment command
    setup_env_parser = subparsers.add_parser('setup_env', help='Setup the PE Automator environment on remote servers')
    setup_env_parser.add_argument('env_name', type=str, help='Name of the conda environment to set up')
    setup_env_parser.add_argument('--source_env', type=str, help='Path to the packed source conda environment file')
    setup_env_parser.add_argument('--source_remote', type=str, help='Remote server to fetch the source environment from')
    setup_env_parser.add_argument('--data_dir', type=str, default='.', help='Path to the data directory where the environment will be set up')
    setup_env_parser.set_defaults(func=setup_env)

    # Rescue command
    rescue_parser = subparsers.add_parser('rescue', help='Rescue a PE Automator process')
    rescue_parser.add_argument('issue_number', type=int, help='GitLab issue number to rescue')
    rescue_parser.add_argument('--data_path', type=str, default='.', help='Path to the data directory')
    rescue_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    rescue_parser.add_argument('--walltime', type=str, default='', help='Walltime for the job, default will not change the original walltime')
    rescue_parser.set_defaults(func=rescue)

    # Post-process command
    post_process_parser = subparsers.add_parser('post_process', help='Post-process the results from the parameter estimation')
    post_process_parser.add_argument('--results_dir', type=str, default='./results', help='Directory containing the results of the parameter estimation')
    post_process_parser.add_argument('--data_dir', type=str, default='./data', help='Path to the data directory')
    post_process_parser.add_argument('--gwtc_dir', type=str, default='./gwtc', help='Directory containing the GWTC results')
    post_process_parser.add_argument('--output_dir', type=str, default='.', help='Directory to save the post-processed results')
    post_process_parser.set_defaults(func=post_process_cmd)

    # dlogz checker
    dlogz_parser = subparsers.add_parser('dlogz', help='Check dlogz progress')
    dlogz_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    dlogz_parser.add_argument('--ssh_key', type=str, required=True, help='Path to the SSH key for remote access')
    dlogz_parser.add_argument('--data_path', type=str, default='./data', help='Path to the data directory')
    dlogz_parser.add_argument('--save_file', type=str, default=None, help='Path to save the dlogz plot')
    dlogz_parser.set_defaults(func=dlogz_plot_cmd)

    # run webpage
    run_webpage_parser = subparsers.add_parser('run_webpage', help='Run the webpage generation')
    run_webpage_parser.add_argument('--data_folder', type=str, required=True, help='Path to the data folder')
    run_webpage_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    run_webpage_parser.add_argument('--output_html', type=str, default='waveforms.html', help='Path to the output HTML file')
    run_webpage_parser.set_defaults(func=run_webpage_cmd)

    args = parser.parse_args()

    if args.command == 'setup' and args.encrypt is None:    # Handle default encryption
        if args.mode in ['gaussian_noise', 'real_noise']:
            args.encrypt = False
        else:
            args.encrypt = True
    # Encrypt is currently not passed to the setup_run for injections, but this logic is here for future development to encrypt injection runs as well.
    
    if hasattr(args, 'func'):
        args.func(args)
    else:
        parser.print_help()

    # if no command is given, print help
    if len(vars(args)) == 0:
        parser.print_help()


def setup(args):
    """Handle the ``setup`` subcommand by creating a PE run via ``setup_run``.

    Aborts with a message when
    ``is_version_newer(f"{data_path}/project/project.json")`` returns a falsy
    value. Otherwise it calls ``setup_run`` once, with ``mode=args.mode``.

    Options shared by every mode are forwarded identically. Mode-specific
    extras:

    * ``real_data``: ``eventname`` (from ``args.name``), ``encrypt_label``,
      ``waveform_arguments_dict``.
    * ``gaussian_noise``: ``inj_set`` (from ``args.name``), ``flow``,
      ``f_ref``, ``force_regenerate``.
    * ``real_noise``: as ``gaussian_noise`` plus ``fetch_buffer``.

    Args:
        args: Parsed ``argparse.Namespace`` from the ``setup`` subparser.

    Raises:
        ValueError: If ``args.mode`` is not one of the three supported modes.
            The CLI's ``choices=`` already prevents this; the check exists
            for programmatic callers.
    """
    if not is_version_newer(f"{args.data_path}/project/project.json"):
        print("Your PE Automator version is not up to date. Please update to the latest version.")
        return

    # Keyword arguments forwarded to setup_run identically in every mode.
    common = dict(
        data_path=args.data_path,
        run_label=args.run_label,
        account=args.account,
        partition=args.partition,
        qos=args.qos,
        memory=args.memory,
        walltime=args.walltime,
        cpu=args.cpu,
        approximant=args.approximant,
        user=args.user,
        conda_env=args.conda_env,
        private_token=args.private_token,
        allocation=args.allocation,
        npoint=args.npoint,
        nact=args.nact,
        naccept=args.naccept,
        maxmcmc=args.maxmcmc,
        priors=args.priors,
        mode_array=args.mode_array,
        distance_marginalization=args.distance_marginalization,
        parallel_bilby=args.parallel_bilby,
        waveform_minimum_frequency=args.wf_min_f,
        waveform_reference_frequency=args.wf_ref_f,
        minimum_frequency=args.min_f,
        dlogz=args.dlogz,
        comment=args.comment,
        dry_run=args.dry_run,
    )

    if args.mode == 'real_data':
        print(f"Setting up real-data run for event '{args.name}'...")
        setup_run(
            mode="real_data",
            eventname=args.name,
            encrypt_label=args.encrypt,
            waveform_arguments_dict=args.waveform_arguments_dict,
            **common,
        )
        return

    if args.mode not in ('gaussian_noise', 'real_noise'):
        raise ValueError(f"Unknown setup mode: {args.mode!r}")

    if args.mode == 'gaussian_noise':
        print(f"Setting up Gaussian-noise injection runs for inj_set '{args.name}'...")
    else:
        print(f"Setting up real-noise injection runs for inj_set '{args.name}'...")

    # These options are not forwarded for injection runs. Tell the user
    # instead of dropping them silently. `encrypt` is only truthy here if
    # --encrypt was given explicitly (main() defaults it to False for
    # injection modes).
    ignored = []
    if args.waveform_arguments_dict is not None:
        ignored.append('--waveform_arguments_dict')
    if args.encrypt:
        ignored.append('--encrypt')
    if ignored:
        print(f"Warning: {', '.join(ignored)} is not forwarded for {args.mode} runs and will be ignored.")

    injection = dict(
        inj_set=args.name,
        flow=args.flow,
        f_ref=args.f_ref,
        force_regenerate=args.force_regenerate,
    )
    if args.mode == 'real_noise':
        # Only the real-noise path fetches detector data around the injection.
        injection['fetch_buffer'] = args.fetch_buffer

    setup_run(mode=args.mode, **injection, **common)


def rescue(args):
    """Handle the ``rescue`` subcommand by delegating to ``rescue_run``.

    Forwards the GitLab issue number, data path, private token and walltime
    override from the parsed ``rescue`` arguments.
    """
    issue = args.issue_number
    print(f"Rescuing run for issue number {issue}...")
    rescue_run(
        issue_number=issue,
        data_path=args.data_path,
        private_token=args.private_token,
        walltime=args.walltime,
    )


def monitor(args):
    """Handle the ``monitor`` subcommand by delegating to ``monitor_runs``."""
    options = {
        'private_token': args.private_token,
        'ssh_key': args.ssh_key,
        'data_path': args.data_path,
    }
    monitor_runs(**options)

def setup_env(args):
    """Handle the ``setup_env`` subcommand by delegating to the setup-env workflow.

    The workflow function is imported as ``setup_env_workflow`` so it does not
    collide with this handler's name.
    """
    env_name = args.env_name
    print(f"Setting up environment: {env_name}")
    setup_env_workflow(
        env_name=env_name,
        source_env=args.source_env,
        source_remote=args.source_remote,
        data_dir=args.data_dir,
    )

def post_process_cmd(args):
    """Handle the ``post_process`` subcommand by delegating to ``post_process``."""
    print("Post-processing results...")
    directories = {
        'results_dir': args.results_dir,
        'data_dir': args.data_dir,
        'gwtc_dir': args.gwtc_dir,
        'output_dir': args.output_dir,
    }
    post_process(**directories)

def dlogz_plot_cmd(args):
    """Handle the ``dlogz`` subcommand by delegating to ``dlogz_plot``.

    ``save_file`` is forwarded as given; it is ``None`` unless
    ``--save_file`` was passed.
    """
    print("Generating dlogz plot...")
    options = {
        'private_token': args.private_token,
        'ssh_key': args.ssh_key,
        'data_path': args.data_path,
        'save_file': args.save_file,
    }
    dlogz_plot(**options)

def run_webpage_cmd(args):
    """Handle the ``run_webpage`` subcommand by delegating to ``run_webpage``."""
    options = {
        'data_folder': args.data_folder,
        'private_token': args.private_token,
        'output_html': args.output_html,
    }
    run_webpage(**options)

if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with
    # status 0 unless argparse or a handler raises SystemExit itself.
    raise SystemExit(main())