#!/usr/bin/env python3
import argparse
import sys

from pe_automator.workflows.setup_run import setup_run
from pe_automator.workflows.rescue_run import rescue_run
from pe_automator.workflows.monitor_runs import monitor_runs
from pe_automator.workflows.setup_env import setup_env as setup_env_workflow
from pe_automator.workflows.post_process import post_process
from pe_automator.workflows.dlogz_plot import dlogz_plot
from pe_automator.utils.version import is_version_newer
from pe_automator.workflows.run_webpage import run_webpage

# CLI entry point; sub-commands: setup, monitor, setup_env, rescue, post_process, dlogz, run_webpage
def main(argv=None):
    """Parse the command line and dispatch to the selected sub-command.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, exactly as
            before; passing an explicit list allows driving the CLI from code.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    # The sub-command is required, so argparse has already printed usage and
    # exited if none was given; every sub-parser registers ``func`` through
    # ``set_defaults``, so it is always present here.
    args.func(args)


def _build_parser():
    """Build the top-level parser with every sub-command and its options registered."""
    parser = argparse.ArgumentParser(description="PE Automator CLI")
    subparsers = parser.add_subparsers(dest='command', required=True)

    # Setup command. Example:
    #   setup <eventname> --data_path . --run_label run1 --account uib107 \
    #       --partition gpp --qos gp_resa --approximant IMRPhenomXPNR \
    #       --user <user> --conda_env parallel_bilby \
    #       --private_token <token> --allocation AECT-2025-2-0029
    setup_parser = subparsers.add_parser('setup', help='Set up a parameter estimation run for an event')
    setup_parser.add_argument('eventname', type=str, help='Name of the event to set up')
    setup_parser.add_argument('--data_path', type=str, default='./data', help='Path to the data directory')
    # Required, so no default: a default would never be used.
    setup_parser.add_argument('--run_label', type=str, required=True, help='Label for the run')
    setup_parser.add_argument('--account', type=str, help='Slurm account name')
    setup_parser.add_argument('--partition', type=str, help='Slurm partition name')
    setup_parser.add_argument('--memory', type=int, default=300, help='Memory in GB for the job, default is 300')
    setup_parser.add_argument('--walltime', type=str, default='71:40:00', help='Walltime for the job, default is 71:40:00')
    setup_parser.add_argument('--cpu', type=int, help='Number of CPUs for the job, default is the number of CPUs per node in the allocation')
    setup_parser.add_argument('--qos', type=str, help='Slurm Quality of Service')
    setup_parser.add_argument('--approximant', type=str, required=True, help='Gravitational wave approximant to use')
    setup_parser.add_argument('--user', type=str, required=True, help='Username for the remote server')
    setup_parser.add_argument('--conda_env', type=str, required=True, help='Conda environment to use on the remote server')
    setup_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    setup_parser.add_argument('--allocation', type=str, required=True, help='Allocation name for the job, e.g., AECT-2025-2-0029')
    setup_parser.add_argument('--npoint', type=int, default=1000, help='npoint for bilby sampler, default is 1000')
    setup_parser.add_argument('--nact', type=int, default=50, help='nact for bilby sampler, default is 50')
    setup_parser.add_argument('--naccept', type=int, default=60, help='naccept for bilby sampler, default is 60')
    setup_parser.add_argument('--maxmcmc', type=int, default=20000, help='maxmcmc for bilby sampler, default is 20000')
    setup_parser.add_argument('--priors', type=str, help='Path to the priors file (optional)', default=None)
    setup_parser.add_argument('--mode_array', type=str, help='Mode array for the run (optional)', default=None)
    setup_parser.add_argument('--parallel_bilby', action='store_true', default=False, help='If set, use parallel bilby for the run')
    setup_parser.add_argument('--wf_min_f', type=float, default=None, help='Minimum frequency for the waveform generator')
    setup_parser.add_argument('--wf_ref_f', type=float, default=None, help='Reference frequency for the waveform generator')
    setup_parser.add_argument('--min_f', type=float, default=None, help='Minimum frequency for the analysis')
    setup_parser.add_argument('--dry_run', action='store_true', help='If set, data will not be pushed to the remote server')
    setup_parser.add_argument('--comment', type=str, default='', help='Comment for the run, e.g., will be added to the GitLab issue')
    # Tri-state flag: True / False when one of the two options is given,
    # None when neither is (how None is interpreted is up to setup_run).
    group = setup_parser.add_mutually_exclusive_group()
    group.add_argument('--distance_marginalization', dest='distance_marginalization', action='store_true', help='If set, enable distance marginalization')
    group.add_argument('--no-distance_marginalization', dest='distance_marginalization', action='store_false', help='If set, disable distance marginalization')
    setup_parser.set_defaults(distance_marginalization=None)
    setup_parser.set_defaults(func=setup)

    # Monitor command
    monitor_parser = subparsers.add_parser('monitor', help='Monitor the PE Automator processes')
    monitor_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    monitor_parser.add_argument('--ssh_key', type=str, required=True, help='Path to the SSH key for remote access')
    monitor_parser.add_argument('--data_path', type=str, default='.', help='Path to the data directory')
    monitor_parser.set_defaults(func=monitor)

    # Setup environment command
    setup_env_parser = subparsers.add_parser('setup_env', help='Setup the PE Automator environment on remote servers')
    setup_env_parser.add_argument('env_name', type=str, help='Name of the conda environment to set up')
    setup_env_parser.add_argument('--source_env', type=str, help='Path to the packed source conda environment file')
    setup_env_parser.add_argument('--source_remote', type=str, help='Remote server to fetch the source environment from')
    setup_env_parser.add_argument('--data_dir', type=str, default='.', help='Path to the data directory where the environment will be set up')
    setup_env_parser.set_defaults(func=setup_env)

    # Rescue command
    rescue_parser = subparsers.add_parser('rescue', help='Rescue a PE Automator process')
    rescue_parser.add_argument('issue_number', type=int, help='GitLab issue number to rescue')
    rescue_parser.add_argument('--data_path', type=str, default='.', help='Path to the data directory')
    rescue_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    rescue_parser.add_argument('--walltime', type=str, default='', help='Walltime for the job, default will not change the original walltime')
    rescue_parser.set_defaults(func=rescue)

    # Post-process command
    post_process_parser = subparsers.add_parser('post_process', help='Post-process the results from the parameter estimation')
    post_process_parser.add_argument('--results_dir', type=str, default='./results', help='Directory containing the results of the parameter estimation')
    post_process_parser.add_argument('--data_dir', type=str, default='./data', help='Path to the data directory')
    post_process_parser.add_argument('--gwtc_dir', type=str, default='./gwtc', help='Directory containing the GWTC results')
    post_process_parser.add_argument('--output_dir', type=str, default='.', help='Directory to save the post-processed results')
    post_process_parser.set_defaults(func=post_process_cmd)

    # dlogz checker
    dlogz_parser = subparsers.add_parser('dlogz', help='Check dlogz progress')
    dlogz_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    dlogz_parser.add_argument('--ssh_key', type=str, required=True, help='Path to the SSH key for remote access')
    dlogz_parser.add_argument('--data_path', type=str, default='./data', help='Path to the data directory')
    dlogz_parser.add_argument('--save_file', type=str, default=None, help='Path to save the dlogz plot')
    dlogz_parser.set_defaults(func=dlogz_plot_cmd)

    # Webpage generation
    run_webpage_parser = subparsers.add_parser('run_webpage', help='Run the webpage generation')
    run_webpage_parser.add_argument('--data_folder', type=str, required=True, help='Path to the data folder')
    run_webpage_parser.add_argument('--private_token', type=str, required=True, help='Private token for GitLab access')
    run_webpage_parser.add_argument('--output_html', type=str, default='waveforms.html', help='Path to the output HTML file')
    run_webpage_parser.set_defaults(func=run_webpage_cmd)

    return parser


def setup(args):
    """Run the ``setup`` sub-command: set up a new run for ``args.eventname``.

    First, ``<data_path>/project/project.json`` is passed to
    ``is_version_newer``. If that check fails, the command aborts before doing
    anything else. Otherwise every parsed option is forwarded to ``setup_run``.
    The short CLI names ``wf_min_f``, ``wf_ref_f`` and ``min_f`` are mapped to
    the keyword names ``setup_run`` takes.

    Args:
        args: ``argparse.Namespace`` produced by the ``setup`` sub-parser.

    Raises:
        SystemExit: Raised with exit status 1 when the version check fails.
            The message goes to stderr, so scripts and schedulers can detect
            the failure.
    """
    if not is_version_newer(f"{args.data_path}/project/project.json"):
        # sys.exit(<str>) writes the message to stderr and exits with status 1.
        # Previously this printed the message and returned, so the process
        # exited 0 as if setup had succeeded.
        sys.exit("Your PE Automator version is not up to date. Please update to the latest version.")

    print("Setting up the PE Automator environment...")
    setup_run(
        eventname=args.eventname,
        data_path=args.data_path,
        run_label=args.run_label,
        account=args.account,
        partition=args.partition,
        memory=args.memory,
        walltime=args.walltime,
        cpu=args.cpu,
        qos=args.qos,
        approximant=args.approximant,
        user=args.user,
        conda_env=args.conda_env,
        private_token=args.private_token,
        allocation=args.allocation,
        npoint=args.npoint,
        nact=args.nact,
        naccept=args.naccept,
        maxmcmc=args.maxmcmc,
        priors=args.priors,
        mode_array=args.mode_array,
        distance_marginalization=args.distance_marginalization,
        parallel_bilby=args.parallel_bilby,
        comment=args.comment,
        waveform_minimum_frequency=args.wf_min_f,
        waveform_reference_frequency=args.wf_ref_f,
        minimum_frequency=args.min_f,
        dry_run=args.dry_run,
    )


def rescue(args):
    """Run the ``rescue`` sub-command for GitLab issue ``args.issue_number``.

    Announces the rescue on stdout, then forwards the parsed options to
    ``rescue_run``. An empty ``walltime`` (the CLI default) is passed through
    unchanged; the option's help text says that keeps the original walltime.
    """
    issue = args.issue_number
    print(f"Rescuing run for issue number {issue}...")
    rescue_run(
        issue_number=issue,
        data_path=args.data_path,
        private_token=args.private_token,
        walltime=args.walltime,
    )


def monitor(args):
    """Run the ``monitor`` sub-command by delegating to ``monitor_runs``."""
    options = {
        "private_token": args.private_token,
        "ssh_key": args.ssh_key,
        "data_path": args.data_path,
    }
    monitor_runs(**options)

def setup_env(args):
    """Run the ``setup_env`` sub-command for the conda environment ``args.env_name``.

    The workflow function is imported as ``setup_env_workflow`` so that it
    does not clash with this handler's name.
    """
    name = args.env_name
    print(f"Setting up environment: {name}")
    setup_env_workflow(
        env_name=name,
        source_env=args.source_env,
        source_remote=args.source_remote,
        data_dir=args.data_dir,
    )

def post_process_cmd(args):
    """Run the ``post_process`` sub-command by forwarding the four directory options to ``post_process``."""
    print("Post-processing results...")
    directories = dict(
        results_dir=args.results_dir,
        data_dir=args.data_dir,
        gwtc_dir=args.gwtc_dir,
        output_dir=args.output_dir,
    )
    post_process(**directories)

def dlogz_plot_cmd(args):
    """Run the ``dlogz`` sub-command by delegating to ``dlogz_plot``.

    ``args.save_file`` is ``None`` unless ``--save_file`` was given.
    """
    print("Generating dlogz plot...")
    # CLI attribute names match dlogz_plot's keyword names one-to-one.
    forwarded = ("private_token", "ssh_key", "data_path", "save_file")
    dlogz_plot(**{name: getattr(args, name) for name in forwarded})

def run_webpage_cmd(args):
    """Run the ``run_webpage`` sub-command by delegating to ``run_webpage``."""
    folder, token, html = args.data_folder, args.private_token, args.output_html
    run_webpage(data_folder=folder, private_token=token, output_html=html)

if __name__ == "__main__":
    # Script entry point: parse sys.argv and run the chosen sub-command.
    main()