#!python
"""
profile-fastspecfit  --  Single-process profiling harness for FastSpecFit.

Mirrors the pytest conftest fixtures so that no separate test run is needed.
Phases timed independently:

  1. sc_data.initialize()         -- template load + FFT pre-cache + JIT warmup
  2. DESISpectra I/O              -- gather_metadata + read (or read_stacked)
  3. fastspec_one cold run(s)     -- first pass; JIT compilation may fire here
  4. fastspec_one warm run(s)     -- subsequent passes; steady-state throughput

Usage examples
--------------
# Basic timing breakdown (fastspec mode):
  profile-fastspecfit

# fastphot mode with 3 warm repeats, save cProfile output:
  profile-fastspecfit --mode fastphot --nwarm 3 --cprofile

# Use a pre-downloaded template file:
  profile-fastspecfit --templates /path/to/ftemplates-chabrier-2.0.0.fits

# Point py-spy at this script (no --cprofile needed):
  py-spy record -o profile.svg -- profile-fastspecfit --mode fastspec --nwarm 5

Environment variables
---------------------
FTEMPLATES_CACHE_DIR  Directory where template files are cached between runs.
                      If set, the template is not deleted after the run.
"""
import argparse
import cProfile
import io
import os
import pathlib
import pstats
import sys
import tempfile
import time
from importlib import resources
from urllib.request import urlretrieve


# Version tag shared by the template file name and its download URL.
TEMPLATE_VERSION = '2.0.0'

# Base name of the template FITS file for that version.
TEMPLATE_FILENAME = f'ftemplates-chabrier-{TEMPLATE_VERSION}.fits'

# Public location the template file is downloaded from when it is not cached.
TEMPLATE_URL = (
    'https://data.desi.lbl.gov/public/external/templates/fastspecfit/'
    f'{TEMPLATE_VERSION}/{TEMPLATE_FILENAME}'
)


def parse_args(argv=None):
    """Parse the command-line options of the profiling harness.

    Parameters
    ----------
    argv : list of str, optional
        Argument vector to parse.  ``None`` lets argparse read ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Parsed options.  ``mode`` is forced to ``'fastphot'`` whenever the
        ``--fastphot`` shorthand is present, regardless of ``--mode``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )

    # What to profile.
    parser.add_argument(
        '--mode',
        default='fastspec',
        choices=['fastspec', 'fastphot', 'stackfit'],
        help='Fitting mode to profile (default: fastspec).',
    )
    parser.add_argument(
        '--fastphot',
        action='store_true',
        help='Shorthand for --mode fastphot.',
    )
    parser.add_argument(
        '--templates',
        default=None,
        help='Path to template FITS file.  Overrides FTEMPLATES_CACHE_DIR lookup.',
    )

    # How much work to do.
    parser.add_argument(
        '--nwarm',
        type=int,
        default=2,
        help='Number of warm fitting passes after the cold pass (default: 2).',
    )
    parser.add_argument(
        '--nmonte',
        type=int,
        default=10,
        help='Monte Carlo realizations per object (default: 10).',
    )

    # Profiler output.
    parser.add_argument(
        '--cprofile',
        action='store_true',
        help='Run cProfile on the warm fitting loop and write <mode>.prof.',
    )
    parser.add_argument(
        '--cprofile-top',
        type=int,
        default=30,
        help='Number of top cProfile entries to print (default: 30).',
    )
    parser.add_argument(
        '--outdir',
        default=None,
        help='Directory for output FITS and .prof files.  Defaults to a temp dir.',
    )

    namespace = parser.parse_args(argv)

    # The shorthand flag takes precedence over whatever --mode was given.
    if namespace.fastphot:
        namespace.mode = 'fastphot'
    return namespace


# ---------------------------------------------------------------------------
# Template management
# ---------------------------------------------------------------------------

def get_template_path(args):
    """Return a path to the template file, downloading it if necessary.

    Resolution order:

    1. ``args.templates`` when given; returned unchanged (its existence is
       not checked here).
    2. ``$FTEMPLATES_CACHE_DIR/<TEMPLATE_VERSION>/<TEMPLATE_FILENAME>``; the
       file is downloaded on first use and kept for later runs.
    3. A fresh temporary directory; the file is downloaded there and the
       directory is removed again when the interpreter exits.

    The download is written to a sibling ``*.part`` file and renamed into
    place only after it completes, so an interrupted transfer never leaves a
    truncated file that a later run would mistake for a valid cache entry.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options; only ``args.templates`` is read.

    Returns
    -------
    str
        Path to the template FITS file.

    Raises
    ------
    Exception
        Whatever ``urlretrieve`` raises if the download fails; the partial
        file is removed before the exception propagates.
    """
    if args.templates:
        return args.templates

    cache_dir = os.environ.get('FTEMPLATES_CACHE_DIR')
    if cache_dir:
        tdir = pathlib.Path(cache_dir) / TEMPLATE_VERSION
        tdir.mkdir(parents=True, exist_ok=True)
    else:
        # Local imports: only this uncached branch needs them.
        import atexit
        import shutil

        tdir = pathlib.Path(tempfile.mkdtemp())
        # mkdtemp() leaves cleanup to the caller; without a cache directory
        # the download is throw-away, so drop it when the process ends.
        atexit.register(shutil.rmtree, str(tdir), ignore_errors=True)
    tfile = tdir / TEMPLATE_FILENAME

    if tfile.exists():
        print(f'Using cached templates: {tfile}', flush=True)
        return str(tfile)

    print(f'Downloading templates to {tfile} ...', flush=True)
    t0 = time.perf_counter()
    partial = tfile.with_name(tfile.name + '.part')
    try:
        urlretrieve(TEMPLATE_URL, partial)
        # Same-directory rename: the final name only ever refers to a
        # completely downloaded file.
        os.replace(partial, tfile)
    finally:
        # No-op after a successful rename; removes the leftover otherwise.
        partial.unlink(missing_ok=True)
    print(f'  done in {time.perf_counter()-t0:.1f}s', flush=True)

    return str(tfile)


# ---------------------------------------------------------------------------
# Setup helpers (mirror conftest.py)
# ---------------------------------------------------------------------------

def get_test_paths():
    """Locate the test data shipped inside the ``fastspecfit`` package.

    Returns
    -------
    dict
        The four directory entries (``redux_dir``, ``specproddir``,
        ``mapdir``, ``fphotodir``) all point at the package's ``test/data``
        resource; ``redrockfile`` and ``stackfile`` point at individual
        files inside it.
    """
    data_dir = resources.files('fastspecfit').joinpath('test/data')

    # Every directory-valued entry resolves to the same test-data location.
    paths = dict.fromkeys(
        ('redux_dir', 'specproddir', 'mapdir', 'fphotodir'), data_dir)
    paths['redrockfile'] = data_dir.joinpath('redrock-4-80613-thru20210324.fits')
    paths['stackfile'] = data_dir.joinpath('stack-LRG.fits')
    return paths


def build_init_args(paths, templates, mode):
    """Assemble the keyword arguments passed to ``sc_data.initialize``.

    Parameters
    ----------
    paths : dict
        Accepted for signature symmetry with the other helpers; not read.
    templates : str or None
        Template file path, forwarded as ``template_file``.
    mode : str
        Fitting mode; ``'fastphot'`` and ``'stackfit'`` toggle the matching
        flags, any other value leaves them all ``False``.

    Returns
    -------
    dict
        Keyword arguments for the initialize call.
    """
    stacked = mode == 'stackfit'
    init_args = dict(
        emlines_file=None,
        fphotofile=None,
        fastphot=mode == 'fastphot',
        fitstack=stacked,
        # Photometry is ignored exactly when fitting stacked spectra.
        ignore_photometry=stacked,
        template_file=templates,
        template_version=TEMPLATE_VERSION,
        template_imf='chabrier',
        log_verbose=False,
    )
    return init_args


# ---------------------------------------------------------------------------
# Timing helper
# ---------------------------------------------------------------------------

class Timer:
    """Record labelled timestamps and print the time spent between them.

    The first mark is the time origin: its label is never printed.  Every
    later mark is reported with the time elapsed since the preceding mark,
    followed by a total spanning the first to the last mark.
    """

    def __init__(self):
        # List of (label, time.perf_counter() value) in the order recorded.
        self._marks = []

    def mark(self, label):
        """Record the current ``time.perf_counter()`` value under *label*."""
        self._marks.append((label, time.perf_counter()))

    def report(self):
        """Print the per-phase table to stdout.

        With no marks (or only the origin mark) the table has no phase rows
        and a total of 0.000 instead of raising ``IndexError``.
        """
        print('\n' + '='*62)
        print(f'  {"Phase":<40} {"Elapsed (s)":>10}')
        print('  ' + '-'*58)

        total = 0.0
        if self._marks:
            t_start = self._marks[0][1]
            t_prev = t_start
            for label, t in self._marks[1:]:
                dt = t - t_prev
                print(f'  {label:<40} {dt:>10.3f}')
                t_prev = t
            total = self._marks[-1][1] - t_start

        print('  ' + '-'*58)
        print(f'  {"TOTAL":<40} {total:>10.3f}')
        print('='*62 + '\n')


# ---------------------------------------------------------------------------
# Phase runners
# ---------------------------------------------------------------------------

def phase_init(sc_data, init_args, timer):
    """Phase 1: run ``sc_data.initialize`` bracketed by two timer marks.

    Parameters
    ----------
    sc_data : object
        Object exposing ``initialize(**kwargs)``.
    init_args : dict
        Keyword arguments forwarded to ``initialize``.
    timer : object
        Object exposing ``mark(label)``; receives the ``'start'`` mark
        before the call and the phase mark after it.
    """
    # Time origin for everything recorded on this timer afterwards.
    timer.mark('start')

    sc_data.initialize(**init_args)

    # Closes the phase: everything since 'start' is attributed to it.
    timer.mark('sc_data.initialize (templates + JIT warmup)')


def phase_io(paths, mode, timer):
    """Phase 2: build a ``DESISpectra`` reader and load the input spectra.

    In ``'stackfit'`` mode the stacked-spectra file is read with
    ``read_stacked``; in every other mode the redrock file is registered
    with ``gather_metadata`` and the spectra are loaded with ``read``.

    Parameters
    ----------
    paths : dict
        Must provide ``'fphotodir'``, ``'mapdir'`` and ``'redux_dir'``, plus
        ``'stackfile'`` (stackfit mode) or ``'redrockfile'`` and
        ``'specproddir'`` (other modes).  Values are converted with ``str``.
    mode : str
        ``'fastspec'``, ``'fastphot'`` or ``'stackfit'``.
    timer : object
        Object exposing ``mark(label)``; marked once after the read finishes.

    Returns
    -------
    tuple
        ``(Spec, data, meta)``: the reader instance and the two objects
        returned by its read call.
    """
    # Project imports stay local so importing this script does not require
    # fastspecfit to be importable.
    from fastspecfit.io import DESISpectra
    from fastspecfit.singlecopy import sc_data

    fastphot = (mode == 'fastphot')
    fitstack = (mode == 'stackfit')

    Spec = DESISpectra(
        phot=sc_data.photometry,
        cosmo=sc_data.cosmology,
        fphotodir=str(paths['fphotodir']),
        mapdir=str(paths['mapdir']),
        redux_dir=str(paths['redux_dir']),
    )

    if fitstack:
        data, meta = Spec.read_stacked(
            [str(paths['stackfile'])],
            firsttarget=0, ntargets=None,
        )
    else:
        Spec.gather_metadata(
            [str(paths['redrockfile'])],
            firsttarget=0,
            redrockfile_prefix='redrock-',
            specfile_prefix='coadd-',
            qnfile_prefix='qso_qn-',
            use_quasarnet=False,
            specprod_dir=str(paths['specproddir']),
        )
        data, meta = Spec.read(sc_data.photometry, fastphot=fastphot)

    # The label records how many objects were read.
    timer.mark(f'DESISpectra I/O ({len(meta)} object(s))')
    return Spec, data, meta


def build_fitargs(Spec, data, meta, mode, nmonte):
    """Build the list of per-object argument dicts for fastspec_one.

    Parameters
    ----------
    Spec : object
        Reader returned by ``phase_io``; only ``Spec.specprod`` is read.
    data : sequence
        Per-object data, indexed in step with *meta*.  Outside fastphot
        mode ``data[0]['cameras']`` is used to build the output dtypes.
    meta : sequence
        Per-object metadata; its length defines the number of objects.
    mode : str
        ``'fastspec'``, ``'fastphot'`` or ``'stackfit'``.
    nmonte : int
        Forwarded unchanged as the ``nmonte`` entry of every dict.

    Returns
    -------
    list of dict
        One keyword-argument dict per object, or ``[]`` when *meta* is
        empty.  Seeds are drawn from a generator seeded with 1, so they are
        identical from run to run.
    """
    nobj = len(meta)
    if nobj == 0:
        # Nothing to fit.  Returning here also avoids indexing data[0]
        # below, which raises IndexError on empty input.
        return []

    import numpy as np
    from fastspecfit.io import get_output_dtype
    from fastspecfit.singlecopy import sc_data

    fastphot = (mode == 'fastphot')
    fitstack = (mode == 'stackfit')

    # Arguments shared by both output-dtype lookups.
    dtype_kwargs = dict(
        phot=sc_data.photometry,
        linetable=sc_data.emlines.table,
        ncoeff=sc_data.templates.ntemplates,
        cameras=None if fastphot else data[0]['cameras'],
        fastphot=fastphot,
        fitstack=fitstack,
    )
    fastfit_dtype, _ = get_output_dtype(Spec.specprod, **dtype_kwargs)
    specphot_dtype, _ = get_output_dtype(
        Spec.specprod, specphot=True, **dtype_kwargs)

    rng   = np.random.default_rng(seed=1)
    seeds = rng.integers(2**32, size=nobj, dtype=np.int64)

    return [
        dict(
            iobj=i, data=data[i], meta=meta[i],
            fastfit_dtype=fastfit_dtype, specphot_dtype=specphot_dtype,
            broadlinefit=True, fastphot=fastphot, fitstack=fitstack,
            constrain_age=False, no_smooth_continuum=False,
            debug_plots=False, uncertainty_floor=0.01,
            minsnr_balmer_broad=2.5, nmonte=nmonte, seed=int(seeds[i]),
        )
        for i in range(nobj)
    ]


def run_fitting_loop(fitargs):
    """Call fastspec_one once per entry of *fitargs* and time the whole loop.

    Each entry is deep-copied right before its call because one_spectrum()
    mutates the 'data' dict in-place (it deletes the '*0' raw-array keys
    after processing); copying keeps *fitargs* reusable for later passes.
    The copies happen inside the timed region.

    Parameters
    ----------
    fitargs : iterable of dict
        Keyword-argument dicts for ``fastspec_one``.

    Returns
    -------
    float
        Wall-clock seconds spent in the loop.
    """
    from copy import deepcopy
    from fastspecfit.fastspecfit import fastspec_one

    started = time.perf_counter()
    for kwargs in fitargs:
        fastspec_one(**deepcopy(kwargs))
    elapsed = time.perf_counter() - started
    return elapsed


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main(argv=None):
    """Run the profiling phases, print the timing table, optionally profile.

    Sequence: parse options, resolve the template file, initialize
    ``sc_data``, read the test spectra, run one cold fitting pass and
    ``--nwarm`` warm passes, print the timer report and, with
    ``--cprofile``, run one more pass under cProfile and dump the stats to
    ``<outdir>/<mode>.prof``.

    Parameters
    ----------
    argv : list of str, optional
        Argument vector; ``None`` lets argparse read ``sys.argv``.

    Returns
    -------
    int
        Always 0; failures propagate as exceptions.
    """
    args = parse_args(argv)

    # Resolve the output directory up front so an unusable --outdir fails
    # before any expensive work.  A temporary directory is created only when
    # something (the cProfile dump) will actually be written into it.
    if args.outdir:
        outdir = pathlib.Path(args.outdir)
        outdir.mkdir(parents=True, exist_ok=True)
    elif args.cprofile:
        outdir = pathlib.Path(tempfile.mkdtemp())
    else:
        outdir = None

    templates = get_template_path(args)
    paths     = get_test_paths()
    mode      = args.mode

    print(f'\n--- FastSpecFit profiling harness  (mode={mode}) ---\n')

    from fastspecfit.singlecopy import sc_data
    init_args = build_init_args(paths, templates, mode)

    timer = Timer()

    # Phase 1: initialize (template load, FFT pre-cache, first JIT touch)
    phase_init(sc_data, init_args, timer)

    # Phase 2: I/O
    Spec, data, meta = phase_io(paths, mode, timer)
    fitargs = build_fitargs(Spec, data, meta, mode, args.nmonte)
    nobj    = len(fitargs)

    # Phase 3: cold run (JIT compilation fires here)
    print(f'Cold pass ({nobj} object(s)) ...', flush=True)
    cold_s = run_fitting_loop(fitargs)
    timer.mark(f'fastspec_one cold pass ({nobj} obj)  [{cold_s:.3f}s]')
    if nobj > 0:
        print(f'  cold: {cold_s:.3f}s total  |  {cold_s/nobj:.3f}s/obj', flush=True)

    # Phase 4: warm runs
    warm_times = []
    for i in range(args.nwarm):
        print(f'Warm pass {i+1}/{args.nwarm} ({nobj} object(s)) ...', flush=True)
        warm_times.append(run_fitting_loop(fitargs))
    if warm_times:
        avg_s = sum(warm_times) / len(warm_times)
        # The table row for this mark spans all warm passes together; the
        # per-pass average is carried in the label.
        timer.mark(
            f'fastspec_one warm x{args.nwarm} ({nobj} obj)'
            f'  [avg {avg_s:.3f}s]'
        )
        # Same guard as the cold pass: no per-object figure without objects.
        if nobj > 0:
            print(f'  warm avg: {avg_s:.3f}s total  |  {avg_s/nobj:.3f}s/obj', flush=True)

    timer.report()

    # Optional: cProfile on one warm pass
    if args.cprofile:
        prof_path = outdir / f'{mode}.prof'
        print(f'Running cProfile (1 warm pass) -> {prof_path} ...', flush=True)

        pr = cProfile.Profile()
        pr.enable()
        run_fitting_loop(fitargs)
        pr.disable()

        pr.dump_stats(str(prof_path))

        # Print top-N to stdout
        buf = io.StringIO()
        ps  = pstats.Stats(pr, stream=buf).sort_stats('cumulative')
        ps.print_stats(args.cprofile_top)
        print(buf.getvalue())
        print(f'Full profile saved to: {prof_path}')
        print(f'  View with:  python -m snakeviz {prof_path}')
        print(f'  Or:         python -m pstats {prof_path}')

    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
