#!/usr/bin/env python
"""Generate load-balanced Slurm job scripts for mpi-fastspecfit.

This script queries the remaining work for a given specprod and stage by
calling fastspecfit.mpi.plan(), divides the outstanding healpix pixels into
N balanced jobs (by target count), and writes one .slurm file per job.

For samplefile runs (--samplefile), use etc/fastspecfit-sample.sh instead.
A samplefile job runs as a single mpi-fastspecfit invocation across as many
nodes as needed; MPI distributes the healpix files across ranks automatically.

Requirements
------------
The DESI software environment must be sourced before running this script so
that fastspecfit and desispec are importable and DESI_SPECTRO_REDUX is set.

Examples
--------
# Print remaining-work summary only (no files written):
generate-fast-slurm --specprod loa --njobs 4 --nodes 16 --mp 16 --plan

# Generate 4 balanced .slurm files in the current directory:
generate-fast-slurm --specprod loa --njobs 4 --nodes 16 --mp 16

# Generate and immediately submit:
generate-fast-slurm --specprod loa --njobs 4 --nodes 16 --mp 16 --submit

# Override defaults for a fastphot run:
generate-fast-slurm --specprod loa --stage fastphot --njobs 2 --nodes 32 --mp 16 \\
    --survey main --program dark,bright
"""

import os
import argparse
import numpy as np

# --- NERSC / DESI defaults (override via CLI) --------------------------------
# Slurm metadata written into every generated #SBATCH header.
_ACCOUNT = 'desi'
_QOS = 'regular'
_MAIL_USER = 'jmoustakas@siena.edu'

# Base output directory. $PSCRATCH is expanded once, at import time; if the
# variable is unset, os.path.expandvars leaves the reference verbatim.
_OUTDIR = os.path.expandvars('$PSCRATCH/fastspecfit/data')

# Environment setup script, looked up next to this file by default.
_ENV_SCRIPT = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'fastspecfit-env.sh',
)

# Upper limit (hours) on the automatic wall-time estimate for each QOS;
# only consulted when --time is not given.
_MAX_HOURS = dict(regular=12, debug=0.5, premium=24)

# Conservative empirical throughput: objects fitted per minute per physical core.
_OBJ_PER_MIN_PER_CORE = 1.5

# ---------------------------------------------------------------------------

# Bash job script filled in with str.format() by generate_slurm(). Every
# {name} is a Python format field, so the bash below must not use literal
# braces (e.g. ${var}). {stage_flag} / {stage_flag_warmup} become an extra
# 'args+=...' / 'warmup_args+=...' line for non-default stages, or an empty
# line for fastspec.
_SLURM_TEMPLATE = """\
#!/bin/bash -l
#SBATCH --account={account}
#SBATCH --qos={qos}
#SBATCH --constraint=cpu
#SBATCH --mail-user={mail_user}
#SBATCH --mail-type=ALL
#SBATCH --nodes={nodes}
#SBATCH --time={walltime}
#SBATCH --job-name={job_name}
#SBATCH --output={job_name}-%j.log

echo "Starting {job_name} (job {job_index}/{njobs}) at $(date)"
echo "Healpix pixels: {nhealpix} | Targets: {ntargets_total:,d}"

source "{env_script}"

N={nodes}
mp={mp}

ntasks=$(( 128 * N / mp ))
if [[ $mp -gt 1 ]]; then
    cpus_per_task=$(( mp * 2 ))
    cpu_bind="none"
else
    cpus_per_task=2
    cpu_bind="cores"
fi

HEALPIX="{healpix_csv}"

mpiscript=$(type -p mpi-fastspecfit)
if [[ -z "$mpiscript" ]]; then
    echo "FAILED: {job_name}: mpi-fastspecfit not found on PATH at $(date)"
    exit 1
fi

# Warm up Numba cache on a single rank before the parallel run.
# Compiled artifacts go to $NUMBA_CACHE_DIR (shared Lustre), so all
# subsequent ranks load from cache rather than recompiling.
FIRST_HEALPIX=$(echo "$HEALPIX" | cut -d',' -f1)
warmup_args="--specprod={specprod} --coadd-type={coadd_type}"
warmup_args+=" --survey={survey} --program={program}"
warmup_args+=" --healpix=$FIRST_HEALPIX"
warmup_args+=" --mp=1 --ntargets=1 --nompi"
warmup_args+=" --outdir-data=/tmp/fastspecfit-warmup --overwrite"
{stage_flag_warmup}
echo "Warming up Numba cache on one rank..."
srun --nodes=1 --ntasks=1 --cpus-per-task=2 --cpu-bind=cores $mpiscript $warmup_args
echo "Warm-up complete."

args="--specprod={specprod}"
args+=" --coadd-type={coadd_type}"
args+=" --survey={survey}"
args+=" --program={program}"
args+=" --mp=$mp"
args+=" --outdir-data={outdir_data}"
args+=" --healpix=$HEALPIX"
{stage_flag}
srun_args="--nodes=$N --ntasks=$ntasks --cpus-per-task=$cpus_per_task --cpu-bind=$cpu_bind"

# Use the bash 'time' keyword directly: a 'time' that only appears after
# expanding $cmd is run as an external time binary, which may be absent.
cmd="srun $srun_args $mpiscript $args"
echo "time $cmd"
time $cmd
status=$?

if [ $status -eq 0 ]; then
    echo "SUCCESS: {job_name} done at $(date)"
else
    echo "FAILED: {job_name} done at $(date)"
    exit 1
fi
"""


def query_remaining(specprod, coadd_type, survey, program, fastphot, outdir_data, mp=1):
    """Return (redrockfiles, healpix_ids, ntargets) for files not yet processed.

    Parameters
    ----------
    specprod : str
        Spectroscopic production name.
    coadd_type : str
        Coadd type. Only 'healpix' is supported: the healpix ID is parsed from
        each redrock filename, and the generated jobs select work by --healpix.
    survey, program : str
        Comma-separated survey and program names (split before calling plan()).
    fastphot : bool
        Forwarded to fastspecfit.mpi.plan() to plan the fastphot stage.
    outdir_data : str
        Base output data directory, forwarded to plan() with overwrite=False.
    mp : int, optional
        Forwarded to plan().

    Returns
    -------
    redrockfiles : ndarray
        Remaining redrock file paths.
    healpix_ids : ndarray of int
        Healpix ID parsed from each file name.
    ntargets : ndarray
        Per-file target counts as returned by plan().
    All three are empty when nothing remains.

    Raises
    ------
    ValueError
        If coadd_type is not 'healpix'.
    """
    if coadd_type != 'healpix':
        # Checked before importing/calling plan(): the filename parsing below
        # expects redrock-{survey}-{program}-{healpix} names, and the job
        # scripts select work with --healpix.
        raise ValueError(
            f'coadd_type={coadd_type!r} is not supported; jobs are split by '
            "healpix ID, so only 'healpix' coadds can be planned.")

    from fastspecfit.mpi import plan

    _, redrockfiles, _, ntargets = plan(
        comm=None,
        specprod=specprod,
        coadd_type=coadd_type,
        survey=survey.split(','),
        program=program.split(','),
        outdir_data=outdir_data,
        mp=mp,
        fastphot=fastphot,
        overwrite=False,
    )

    if len(redrockfiles) == 0:
        return np.array([]), np.array([], dtype=int), np.array([], dtype=int)

    # Parse healpix from filenames: redrock-{survey}-{program}-{healpix}.fits[.gz]
    healpix_ids = np.array([
        int(os.path.basename(f).split('-')[-1].split('.')[0])
        for f in redrockfiles
    ], dtype=int)

    return np.asarray(redrockfiles), healpix_ids, np.asarray(ntargets)


def balanced_partition(ntargets, njobs):
    """Partition indices into njobs groups with roughly equal total ntargets.

    Uses the longest-processing-time-first (LPT) greedy rule. Items are taken
    in decreasing order of ntargets, and each goes to the group with the
    currently smallest total. Graham's bound guarantees the largest group
    total is at most (4/3 - 1/(3*njobs)) times the optimal maximum.

    Parameters
    ----------
    ntargets : array-like of int
        Work (target count) per item.
    njobs : int
        Number of groups; must be >= 1.

    Returns
    -------
    groups : list of list of int
        Item indices assigned to each group; some groups may be empty.
    totals : ndarray of int64
        Sum of ntargets in each group.

    Raises
    ------
    ValueError
        If njobs < 1.
    """
    if njobs < 1:
        raise ValueError(f'njobs must be >= 1, got {njobs}')

    ntargets = np.asarray(ntargets)
    order = np.argsort(ntargets)[::-1]  # largest items first
    totals = np.zeros(njobs, dtype=np.int64)
    groups = [[] for _ in range(njobs)]

    for idx in order:
        g = int(np.argmin(totals))  # currently lightest group
        groups[g].append(int(idx))
        totals[g] += ntargets[idx]

    return groups, totals


def estimate_walltime(ntargets_total, nodes, mp, qos, padding=1.5):
    """Rough wall time estimate based on empirical throughput.

    The estimate is rounded up to whole hours, with a minimum of one hour,
    then capped at the QOS limit in _MAX_HOURS (12 h for unknown QOS).
    Because that limit may be fractional (e.g. debug = 0.5 h), the result is
    formatted from total minutes rather than whole hours.

    Parameters
    ----------
    ntargets_total : int
        Number of targets the job must fit.
    nodes : int
        Nodes per job; assumes 128 physical cores per node (Perlmutter CPU).
    mp : int
        Unused; kept for interface compatibility.
    qos : str
        Slurm QOS name used to look up the wall-time cap.
    padding : float, optional
        Safety factor applied to the raw estimate.

    Returns
    -------
    str
        Wall time as 'HH:MM:SS'.
    """
    throughput = nodes * 128 * _OBJ_PER_MIN_PER_CORE  # objects per minute
    minutes = ntargets_total / throughput * padding

    requested = max(1, int(np.ceil(minutes / 60))) * 60
    limit = int(round(_MAX_HOURS.get(qos, 12) * 60))
    total_minutes = min(requested, limit)

    hours, mins = divmod(total_minutes, 60)
    return f'{hours:02d}:{mins:02d}:00'


def generate_slurm(args, groups, group_totals, healpix_ids, ntargets):
    """Write one .slurm script per non-empty group and return their paths.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options: Slurm metadata (account, qos, mail_user, nodes, time),
        work selection (specprod, stage, coadd_type, survey, program, mp) and
        paths (outdir_scripts, outdir_data, env_script).
    groups : list of list of int
        Indices into healpix_ids for each job, e.g. from balanced_partition().
    group_totals : sequence of int
        Total target count of each group.
    healpix_ids : array-like of int
        Healpix ID for each index referenced by groups.
    ntargets : array-like of int
        Per-index target counts. Unused here; totals come from group_totals.

    Returns
    -------
    list of str
        Paths of the written .slurm files, in job order.
    """
    os.makedirs(args.outdir_scripts, exist_ok=True)

    prefix = 'fastphot' if args.stage == 'fastphot' else 'fastspec'

    # Extra mpi-fastspecfit flag for non-default stages, appended to both
    # the warm-up and the main bash argument strings.
    stage_opt = {'fastphot': ' --fastphot', 'makeqa': ' --makeqa'}.get(args.stage, '')
    if stage_opt:
        stage_flag = f'args+="{stage_opt}"'
        stage_flag_warmup = f'warmup_args+="{stage_opt}"'
    else:
        stage_flag = stage_flag_warmup = ''

    env_script = os.path.abspath(args.env_script)
    if not os.path.isfile(env_script):
        # Not fatal: the scripts may be run where the file does exist.
        print(f'Warning: environment script {env_script} not found; '
              'the generated jobs source it at run time.')
    outdir_data = os.path.expandvars(args.outdir_data)
    healpix_ids = np.asarray(healpix_ids)

    njobs = sum(1 for g in groups if g)
    slurm_files = []
    job_index = 0
    for group, total in zip(groups, group_totals):
        if not group:
            continue
        job_index += 1

        # Several files (e.g. different programs) can share one healpix ID;
        # list each ID once, sorted.
        group_healpix = np.unique(healpix_ids[group])
        healpix_csv = ','.join(str(int(h)) for h in group_healpix)
        nhealpix = len(group_healpix)

        walltime = args.time or estimate_walltime(int(total), args.nodes, args.mp, args.qos)
        job_name = f'{prefix}-{args.specprod}-job{job_index:02d}'

        content = _SLURM_TEMPLATE.format(
            account=args.account,
            qos=args.qos,
            mail_user=args.mail_user,
            nodes=args.nodes,
            walltime=walltime,
            job_name=job_name,
            job_index=job_index,
            njobs=njobs,
            nhealpix=nhealpix,
            ntargets_total=int(total),
            env_script=env_script,
            mp=args.mp,
            specprod=args.specprod,
            coadd_type=args.coadd_type,
            survey=args.survey,
            program=args.program,
            outdir_data=outdir_data,
            healpix_csv=healpix_csv,
            stage_flag=stage_flag,
            stage_flag_warmup=stage_flag_warmup,
        )

        slurm_file = os.path.join(args.outdir_scripts, f'{job_name}.slurm')
        with open(slurm_file, 'w') as fh:
            fh.write(content)
        print(f'  {slurm_file}  ({nhealpix} healpix, {int(total):,d} targets, walltime={walltime})')
        slurm_files.append(slurm_file)

    return slurm_files


def _build_parser():
    """Return the argparse parser for the command line."""
    parser = argparse.ArgumentParser(
        description='Generate balanced Slurm scripts for mpi-fastspecfit.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Work selection
    parser.add_argument('--specprod', required=True, help='Spectroscopic production name.')
    parser.add_argument('--stage', default='fastspec', choices=['fastspec', 'fastphot', 'makeqa'])
    parser.add_argument('--coadd-type', default='healpix',
                        choices=['healpix', 'cumulative', 'pernight', 'perexp'])
    parser.add_argument('--survey', default='main', help='Comma-separated survey names.')
    parser.add_argument('--program', default='dark,bright', help='Comma-separated program names.')

    # Job sizing
    parser.add_argument('--njobs', type=int, default=1, help='Number of Slurm jobs to generate.')
    parser.add_argument('--nodes', type=int, default=32, help='Perlmutter CPU nodes per job.')
    parser.add_argument('--mp', type=int, default=16, help='Multiprocessing workers per MPI rank.')
    parser.add_argument('--time', default=None,
                        help='Wall time per job as HH:MM:SS; auto-estimated when omitted.')

    # Slurm metadata
    parser.add_argument('--account', default=_ACCOUNT)
    parser.add_argument('--qos', default=_QOS)
    parser.add_argument('--mail-user', default=_MAIL_USER)

    # Paths
    parser.add_argument('--outdir-data', default=_OUTDIR, help='Base output data directory.')
    parser.add_argument('--outdir-scripts', default='.', help='Directory for generated .slurm files.')
    parser.add_argument('--env-script', default=_ENV_SCRIPT,
                        help='Path to fastspecfit-env.sh (absolute or relative to CWD).')

    # Actions
    parser.add_argument('--plan', action='store_true',
                        help='Print work summary and partition plan; do not write files.')
    parser.add_argument('--submit', action='store_true',
                        help='Submit each generated .slurm with sbatch.')
    return parser


def _aggregate_by_healpix(healpix_ids, ntargets):
    """Sum per-file target counts onto unique healpix IDs.

    Each job runs mpi-fastspecfit with --healpix plus the full --survey and
    --program lists. Presumably it then processes every matching
    survey/program file for each listed pixel. Partitioning per file could
    put two files with the same pixel (e.g. main-dark and main-bright) into
    different jobs, and both jobs would then process both files. Partitioning
    by unique pixel keeps all of a pixel's files in one job.

    Returns
    -------
    pixels : ndarray
        Sorted unique healpix IDs.
    counts : ndarray of int64
        Total targets per unique pixel.
    """
    pixels, inverse = np.unique(np.asarray(healpix_ids), return_inverse=True)
    counts = np.zeros(len(pixels), dtype=np.int64)
    np.add.at(counts, inverse.ravel(), np.asarray(ntargets, dtype=np.int64))
    return pixels, counts


def _submit(slurm_files):
    """Submit each script with sbatch; return the number of failed submissions."""
    import subprocess

    print('\nSubmitting ...')
    nfailed = 0
    for i, f in enumerate(slurm_files):
        try:
            result = subprocess.run(['sbatch', f], capture_output=True, text=True)
        except FileNotFoundError:
            print('  FAILED: sbatch not found on PATH; remaining scripts not submitted.')
            return nfailed + len(slurm_files) - i
        if result.returncode == 0:
            print(f'  {f}  →  {result.stdout.strip()}')
        else:
            print(f'  FAILED {f}: {result.stderr.strip()}')
            nfailed += 1
    return nfailed


def main():
    """Command-line entry point.

    Queries remaining work, balances it across jobs by healpix pixel, prints
    the plan, and (unless --plan) writes .slurm files and optionally submits
    them. Exits with status 1 if any sbatch submission fails.
    """
    parser = _build_parser()
    args = parser.parse_args()

    for name in ('njobs', 'nodes', 'mp'):
        if getattr(args, name) < 1:
            parser.error(f'--{name} must be >= 1.')
    if args.coadd_type != 'healpix':
        # The jobs split and select work by healpix ID (--healpix).
        parser.error('only --coadd-type=healpix is supported (jobs are split by healpix ID).')

    fastphot = args.stage == 'fastphot'
    outdir_data = os.path.expandvars(args.outdir_data)

    print(f'Querying remaining work: specprod={args.specprod}, stage={args.stage}, '
          f'survey={args.survey}, program={args.program} ...')
    _, healpix_ids, ntargets = query_remaining(
        specprod=args.specprod, coadd_type=args.coadd_type,
        survey=args.survey, program=args.program,
        fastphot=fastphot, outdir_data=outdir_data)

    n_files = len(healpix_ids)
    n_targets = int(np.sum(ntargets)) if n_files > 0 else 0
    print(f'Remaining: {n_files:,d} healpix files, {n_targets:,d} targets')

    if n_files == 0:
        print('Nothing to do.')
        return

    pixels, pixel_targets = _aggregate_by_healpix(healpix_ids, ntargets)
    n_pix = len(pixels)
    if n_pix < n_files:
        print(f'  ({n_pix:,d} unique healpix pixels; files sharing a pixel go to the same job)')

    njobs = min(args.njobs, n_pix)
    if njobs < args.njobs:
        print(f'Warning: reduced njobs from {args.njobs} to {njobs} (fewer healpix pixels than jobs).')

    groups, group_totals = balanced_partition(pixel_targets, njobs)

    # Print balance summary (groups can be empty only if some pixels have 0 targets).
    nonempty = [(g, t) for g, t in zip(groups, group_totals) if g]
    totals = [t for _, t in nonempty]
    print(f'\nPartition across {len(nonempty)} job(s)  '
          f'[imbalance: {max(totals)/max(min(totals), 1):.2f}x]:')
    for i, (g, t) in enumerate(nonempty, 1):
        est = args.time or estimate_walltime(int(t), args.nodes, args.mp, args.qos)
        print(f'  Job {i:02d}: {len(g):5,d} healpix  {int(t):9,d} targets  est. walltime {est}')

    if args.plan:
        return

    print(f'\nWriting {len(nonempty)} .slurm file(s) to {args.outdir_scripts}/ ...')
    slurm_files = generate_slurm(args, groups, group_totals, pixels, pixel_targets)

    if args.submit and _submit(slurm_files):
        raise SystemExit(1)


# Run the CLI when this file is executed directly (the module docstring's
# examples invoke it as generate-fast-slurm).
if __name__ == '__main__':
    main()
