#!python
"""
MPI wrapper for fastphot and fastspec.

"""
import os, time
import numpy as np
from astropy.table import Table

from fastspecfit.util import NMONTE_DEFAULT
from fastspecfit.logger import log


def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None, makeqa=False,
                    sample=None, input_redshifts=False, outdir_data='.', templates=None,
                    templateversion=None, fphotodir=None, fphotofile=None, emlinesfile=None,
                    mp_per_file=False):
    """Run fastspec, fastphot, or fastqa across MPI ranks.

    The work is split in stages:

    - Every rank calls ``plan``; rank 0 decides whether any work is left.
    - Rank 0 partitions the input files across ranks with ``weighted_partition``,
      using the per-file target counts.
    - Each rank processes its share of files. It either runs one file at a time,
      with a shared multiprocessing pool, or several files at once
      (``mp_per_file``, one file per worker).
    - After a barrier, rank 0 reports any expected output file that is missing.

    Parameters
    ----------
    args : :class:`argparse.Namespace`
        Parsed command-line arguments.
    comm : MPI communicator or None, optional
        MPI communicator; uses a single process when ``None``.
    fastphot : :class:`bool`, optional
        If ``True``, run ``fastphot`` instead of ``fastspec``.
    specprod_dir : :class:`str` or None, optional
        Override path to the spectroscopic reduction directory.
    makeqa : :class:`bool`, optional
        If ``True``, generate QA figures instead of fitting.
    sample : :class:`astropy.table.Table` or None, optional
        Optional target sample table.
    input_redshifts : :class:`bool`, optional
        If ``True``, pass input redshifts from ``sample`` to the fitter.
    outdir_data : :class:`str`, optional
        Base output data directory.
    templates : :class:`str` or None, optional
        Full path to template file; auto-detected when ``None``.
    templateversion : :class:`str` or None, optional
        Template version string; auto-detected when ``None``.
    fphotodir : :class:`str` or None, optional
        Top-level photometry directory override.
    fphotofile : :class:`str` or None, optional
        Photometry configuration file override.
    emlinesfile : :class:`str` or None, optional
        Emission-line parameter file override.
    mp_per_file : :class:`bool`, optional
        If ``True``, run multiple files in parallel (one file per worker).
        When ``False`` and ``args.mp > 1``, this is auto-enabled if the mean
        number of targets per file is smaller than ``args.mp``.

    Notes
    -----
    ``templates``, ``templateversion``, ``fphotodir``, ``fphotofile`` and
    ``emlinesfile`` are not referenced in this function. The fitter
    configuration is built from ``args`` by ``make_init_sc_args``.

    """
    from desispec.parallel import stdouterr_redirected, weighted_partition
    from fastspecfit.mpi import plan, build_cmdargs, _perfile_init, _perfile_run
    from fastspecfit.util import MPPool
    from fastspecfit.singlecopy import sc_data
    from fastspecfit.fastspecfit import make_init_sc_args

    if comm:
        rank = comm.rank
        size = comm.size
    else:
        rank, size = 0, 1

    if rank == 0:
        t0 = time.time()
    _, all_redrockfiles, all_outfiles, all_ntargets = plan(
        comm=comm, specprod=args.specprod, specprod_dir=specprod_dir,
        coadd_type=args.coadd_type, survey=args.survey, program=args.program,
        healpix=args.healpix, tile=args.tile, night=args.night,
        makeqa=args.makeqa, mp=args.mp, fastphot=fastphot,
        outdir_data=outdir_data, overwrite=args.overwrite, sample=sample)
    if rank == 0:
        log.info(f'Planning took {time.time() - t0:.2f} sec')
        # If no work is left to do, let the other ranks know so they can return
        # politely.
        alldone = len(all_redrockfiles) == 0
    else:
        alldone = False

    if comm:
        alldone = comm.bcast(alldone, root=0)

    if alldone:
        return

    # Auto-detect per-file parallelism: if the average number of targets per
    # file is less than mp, file-level parallelism keeps workers busier than
    # object-level parallelism would.  An explicit --mp-per-file flag overrides.
    # Only rank 0 has a valid all_ntargets; broadcast the decision to others.
    if not mp_per_file and args.mp > 1:
        if rank == 0 and len(all_ntargets) > 0:
            avg = float(np.mean(all_ntargets))
            mp_per_file = avg < args.mp
            if mp_per_file:
                log.info(f'Auto-selected per-file parallelism (avg {avg:.1f} targets/file < mp={args.mp})')
        if comm:
            mp_per_file = comm.bcast(mp_per_file, root=0)

    if comm:
        if rank == 0:
            groups = weighted_partition(all_ntargets, size)
            for irank in range(1, size):
                log.debug(f'Rank {rank} sending work to rank {irank}')
                comm.send(all_redrockfiles[groups[irank]], dest=irank, tag=1)
                comm.send(all_outfiles[groups[irank]], dest=irank, tag=2)
                comm.send(all_ntargets[groups[irank]], dest=irank, tag=3)
            # rank 0 gets work, too
            redrockfiles = all_redrockfiles[groups[0]]
            outfiles = all_outfiles[groups[0]]
            ntargets = all_ntargets[groups[0]]
        else:
            log.debug(f'Rank {rank}: received work from rank 0')
            redrockfiles = comm.recv(source=0, tag=1)
            outfiles = comm.recv(source=0, tag=2)
            ntargets = comm.recv(source=0, tag=3)
    else:
        redrockfiles = all_redrockfiles
        outfiles = all_outfiles
        ntargets = all_ntargets

    init_sc_args = make_init_sc_args(args, fastphot=fastphot)

    # In the one-file-at-a-time mode, initialize the single-copy data in this
    # process and create one pool that is reused for every file.
    shared_pool = None
    if not makeqa and not mp_per_file:
        sc_data.initialize(**init_sc_args)
        shared_pool = MPPool(args.mp, initializer=sc_data.initialize,
                             init_argdict=init_sc_args)

    def run_one_file(redrockfile, outfile, ntarget):
        """Process one redrock file in this process; failures are logged, not raised."""
        if rank == 0:
            log.debug(f'Rank {rank} started at {time.asctime()}')

        if args.makeqa:
            from fastspecfit.qa import fastqa as fast
        else:
            if fastphot:
                from fastspecfit.fastspecfit import fastphot as fast
            else:
                from fastspecfit.fastspecfit import fastspec as fast

        cmd, cmdargs, logfile = build_cmdargs(args, redrockfile, outfile, sample=sample,
                                              fastphot=fastphot, input_redshifts=input_redshifts)

        if rank == 0:
            log.info(f'Rank {rank}: ntargets={ntarget}: {cmd} {cmdargs}')

        if args.dry_run:
            return

        try:
            t1 = time.time()
            outdir = os.path.dirname(logfile)
            if not os.path.isdir(outdir):
                os.makedirs(outdir, exist_ok=True)

            fast_kwargs = {'args': cmdargs.split(), 'comm': None}
            if not args.makeqa and shared_pool is not None:
                fast_kwargs['mp_pool'] = shared_pool
            if args.nolog:
                err = fast(**fast_kwargs)
            else:
                with stdouterr_redirected(to=logfile, overwrite=args.overwrite, comm=None):
                    err = fast(**fast_kwargs)

            log.info(f'Rank {rank} done in {(time.time() - t1)/60.:.2f} min')
            if err != 0:
                if not os.path.exists(outfile):
                    log.warning(f'Rank {rank} missing {outfile}')
                    raise IOError
        # A failure on one file must not stop this rank from processing the
        # rest (and reaching the barrier). SystemExit is included because the
        # fitter is a CLI-style entry point that may exit via sys.exit().
        # KeyboardInterrupt is deliberately allowed to propagate.
        except (Exception, SystemExit):
            log.warning(f'Rank {rank} raised an exception')
            import traceback
            traceback.print_exc()
        return

    # loop on each file
    if mp_per_file:
        pool_init = dict(init_sc_args=init_sc_args, sample=sample, rank=rank,
                         fastphot=fastphot, input_redshifts=input_redshifts)
        perfile_pool = MPPool(args.mp, initializer=_perfile_init, init_argdict=pool_init)
        try:
            # With a single process the pool initializer is not relied upon,
            # so initialize the per-file state in this process directly.
            if args.mp <= 1:
                _perfile_init(**pool_init)
            task_args = [{'redrockfile': r, 'outfile': o, 'ntarget': n, 'args': args}
                         for r, o, n in zip(redrockfiles, outfiles, ntargets)]
            perfile_pool.starmap(_perfile_run, task_args)
        finally:
            perfile_pool.close()
    else:
        try:
            for redrockfile, outfile, ntarget in zip(redrockfiles, outfiles, ntargets):
                run_one_file(redrockfile, outfile, ntarget)
        finally:
            if shared_pool is not None:
                shared_pool.close()

    if rank == 0:
        log.debug(f'Rank {rank} is done')

    if comm:
        comm.barrier()

    if rank == 0 and not args.dry_run:
        # After the barrier every rank has finished, so rank 0 (which holds the
        # full plan) checks every expected output, not just its own share.
        for outfile in all_outfiles:
            if not os.path.exists(outfile):
                log.warning(f'Missing {outfile}')

        log.info(f'All done at {time.asctime()}')


def _read_samplefile(args, comm=None, rank=0):
    """Read ``args.samplefile`` on rank 0 and share the outcome with all ranks.

    Parameters
    ----------
    args : :class:`argparse.Namespace`
        Parsed command-line arguments. Reads ``samplefile``, ``coadd_type``
        and ``input_redshifts``.
    comm : MPI communicator or None, optional
        Communicator used to broadcast the result; no broadcast when ``None``.
    rank : :class:`int`, optional
        Rank of this process; only rank 0 opens the file.

    Returns
    -------
    :class:`astropy.table.Table` or None
        The sample table on every rank, or ``None`` on every rank if the file
        does not exist.

    Raises
    ------
    NotImplementedError
        If ``--coadd-type`` is not ``healpix``.
    ValueError
        On every rank, if the file cannot be read or lacks the required
        columns (``Z`` is required only with ``--input-redshifts``).

    """
    if args.coadd_type != 'healpix':
        errmsg = 'Input --samplefile is only currently compatible with --coadd-type="healpix"'
        log.critical(errmsg)
        raise NotImplementedError(errmsg)

    # The input redshifts live in the Z column, so it must be read (and is
    # required) when fitting with --input-redshifts.
    columns = ['SURVEY', 'PROGRAM', 'HEALPIX', 'TARGETID']
    if args.input_redshifts:
        columns.append('Z')

    sample, errmsg, cause = None, None, None
    if rank == 0:
        import fitsio
        if not os.path.isfile(args.samplefile):
            log.warning(f'{args.samplefile} does not exist.')
        else:
            try:
                sample = Table(fitsio.read(args.samplefile, columns=columns))
                log.info(f'Read {len(sample)} rows from {args.samplefile}')
            except Exception as err:
                cause = err
                if args.input_redshifts:
                    errmsg = f'Sample file {args.samplefile} with --input-redshifts missing required columns ' + \
                        '{SURVEY,PROGRAM,HEALPIX,TARGETID,Z}'
                else:
                    errmsg = f'Sample file {args.samplefile} missing required columns ' + \
                        '{SURVEY,PROGRAM,HEALPIX,TARGETID}'
                log.critical(errmsg)

    # Broadcast the outcome, not just the table, so every rank continues,
    # returns, or raises together instead of some ranks blocking in bcast.
    if comm:
        sample, errmsg = comm.bcast((sample, errmsg), root=0)

    if errmsg is not None:
        raise ValueError(errmsg) from cause
    return sample


def main():
    """MPI entry point for fastspec, fastphot, and fastqa."""
    import argparse
    from fastspecfit.mpi import plan
    from fastspecfit.templates import VDISP_NOMINAL, VDISP_BOUNDS

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--coadd-type', type=str, default='healpix', choices=['healpix', 'cumulative', 'pernight', 'perexp'],
                        help='Specify which type of spectra/zbest files to process.')
    parser.add_argument('--specprod', type=str, default='loa', help='Spectroscopic production to process.')

    parser.add_argument('--healpix', type=str, default=None, help='Comma-separated list of healpixels to process.')
    parser.add_argument('--survey', type=str, default='main,special,cmx,sv1,sv2,sv3', help='Survey to process.')
    parser.add_argument('--program', type=str, default='bright,dark,other,backup', help='Program to process.') # backup not supported
    parser.add_argument('--tile', default=None, type=str, nargs='*', help='Tile(s) to process.')
    parser.add_argument('--night', default=None, type=str, nargs='*', help='Night(s) to process (ignored if coadd-type is cumulative).')

    parser.add_argument('--samplefile', default=None, type=str, help='Full path to sample (FITS) file with {survey,program,healpix,targetid}.')
    parser.add_argument('--input-redshifts', action='store_true', help='Only used with --samplefile; if set, fit with redshift "Z" values.')
    parser.add_argument('--input-seeds', type=str, default=None, help='Comma-separated list of input random-number seeds corresponding to the (required) --targetids input.')
    parser.add_argument('--nmonte', type=int, default=NMONTE_DEFAULT, help='Number of Monte Carlo realizations.')
    parser.add_argument('--vdisp-nominal', type=float, default=VDISP_NOMINAL, help='Nominal (default) velocity dispersion in km/s.')
    parser.add_argument('--vdisp-bounds', type=float, default=VDISP_BOUNDS, nargs=2, help='Lower and upper bounds on the velocity dispersion in km/s.')
    parser.add_argument('--seed', type=int, default=1, help='Random seed for Monte Carlo reproducibility; ignored if --input-seeds is passed.')

    parser.add_argument('--mp', type=int, default=1, help='Number of multiprocessing processes per MPI rank or node.')
    parser.add_argument('--mp-per-file', default=False, action='store_true', help='Run multiprocessing on one core per spectra file with multiple simultaneous files.')
    parser.add_argument('-n', '--ntargets', type=int, help='Number of targets to process in each file.')
    parser.add_argument('--firsttarget', type=int, default=0, help='Index of first object to process in each file, zero-indexed.')
    parser.add_argument('--targetids', type=str, default=None, help='Comma-separated list of TARGETIDs to process.')

    parser.add_argument('--fastphot', action='store_true', help='Fit the broadband photometry.')

    parser.add_argument('--templateversion', type=str, default=None, help='Template version number.')
    parser.add_argument('--templates', type=str, default=None, help='Optional full path and filename to the templates.')
    parser.add_argument('--fphotodir', type=str, default=None, help='Top-level location of the source photometry.')
    parser.add_argument('--fphotofile', type=str, default=None, help='Photometric information file.')
    parser.add_argument('--emlinesfile', type=str, default=None, help='Emission line parameter file.')

    parser.add_argument('--merge', action='store_true', help='Merge all individual catalogs (for a given survey and program) into one large file.')
    parser.add_argument('--mergedir', type=str, help='Output directory for merged catalogs.')
    parser.add_argument('--merge-suffix', type=str, help='Filename suffix for merged catalogs.')
    parser.add_argument('--mergeall', action='store_true', help='Merge all the individual merged catalogs into a single merged catalog.')
    parser.add_argument('--mergeall-main', action='store_true', help='Merge all the main catalogs.')
    parser.add_argument('--mergeall-sv', action='store_true', help='Merge all the SV catalogs.')
    parser.add_argument('--mergeall-special', action='store_true', help='Merge all the special catalogs.')
    parser.add_argument('--makeqa', action='store_true', help='Build QA in parallel.')

    parser.add_argument('--ignore-quasarnet', default=False, action='store_true', help='Do not use QuasarNet to improve QSO redshifts.')
    parser.add_argument('--ignore-photometry', default=False, action='store_true', help='Ignore the broadband photometry during model fitting.')
    parser.add_argument('--no-smooth-continuum', default=False, action='store_true', help='Do not fit the smooth continuum.')

    parser.add_argument('--verbose', action='store_true', help='More verbose output.')
    parser.add_argument('--overwrite', action='store_true', help='Overwrite any existing output files.')
    parser.add_argument('--plan', action='store_true', help='Plan how many nodes to use and how to distribute the targets.')
    parser.add_argument('--profile', action='store_true', help='Write out profiling / timing files.')
    parser.add_argument('--nompi', action='store_true', help='Do not use MPI parallelism.')
    parser.add_argument('--nolog', action='store_true', help='Do not write to the log file.')
    parser.add_argument('--dry-run', action='store_true', help='Generate but do not run commands.')

    parser.add_argument('--outdir-data', default='$PSCRATCH/fastspecfit/data', type=str, help='Base output data directory.')

    args = parser.parse_args()

    specprod_dir = None
    outdir_data = os.path.expanduser(os.path.expandvars(args.outdir_data))

    # Must be set before MPI.Init(): forking after MPI initialization is unsafe.
    # https://docs.nersc.gov/development/languages/python/parallel-python/#use-the-spawn-start-method
    if args.mp > 1 and 'NERSC_HOST' in os.environ:
        import multiprocessing
        try:
            multiprocessing.set_start_method('spawn')
        except RuntimeError:
            pass  # already set

    if args.nompi:
        comm = None
    else:
        try:
            from mpi4py import MPI
            # needed when profiling; no effect otherwise
            # https://docs.linaroforge.com/24.0.6/html/forge/map/python_profiling/profile_python_script.html
            #MPI.Init_thread(MPI.THREAD_SINGLE)
            comm = MPI.COMM_WORLD
        except ImportError:
            comm = None

    if comm:
        rank = comm.rank
    else:
        rank = 0

    # If an input samplefile is provided, read it on rank 0 and broadcast it.
    if args.samplefile is None:
        sample = None
    else:
        sample = _read_samplefile(args, comm=comm, rank=rank)
        if sample is None:
            # The samplefile does not exist; every rank returns together.
            return

    # Parse some of the inputs.
    if args.samplefile is None and args.coadd_type == 'healpix':
        args.survey = args.survey.split(',')
        args.program = args.program.split(',')
        if args.healpix is not None:
            args.healpix = args.healpix.split(',')

    if args.mergeall_main or args.mergeall_sv or args.mergeall_special:
        args.mergeall = True

    if args.merge or args.mergeall:
        from fastspecfit.mpi import merge_fastspecfit

        # convenience code to make the super-merge catalogs, e.g., fastspec-iron-{main,special,sv}.fits
        if args.fastphot:
            fastprefix = 'fastphot'
        else:
            fastprefix = 'fastspec'

        if args.mergeall_main or args.mergeall_sv or args.mergeall_special:
            from glob import glob
            if args.mergedir is None:
                mergedir = os.path.join(outdir_data, args.specprod, 'catalogs')
            else:
                mergedir = args.mergedir

            if args.mergeall_main:
                args.merge_suffix = f'{args.specprod}-main'
                fastfiles_to_merge = sorted(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-main-*.fits')))
            elif args.mergeall_special:
                args.merge_suffix = f'{args.specprod}-special'
                fastfiles_to_merge = sorted(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-special-*.fits')))
            elif args.mergeall_sv:
                args.merge_suffix = f'{args.specprod}-sv'
                # Everything matching the production prefix except the main and
                # special catalogs and any previously built super-merges.
                fastfiles_to_merge = sorted(list(set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-*.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-main.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-special.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-special-*.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-main-*.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-sv.fits')))))
            else:
                fastfiles_to_merge = None
        else:
            fastfiles_to_merge = None

        if args.samplefile is not None:
            merge_fastspecfit(specprod=args.specprod, specprod_dir=specprod_dir, coadd_type='healpix',
                              sample=sample, merge_suffix=args.merge_suffix,
                              outdir_data=outdir_data, fastfiles_to_merge=fastfiles_to_merge,
                              outsuffix=args.merge_suffix, mergedir=args.mergedir, overwrite=args.overwrite,
                              fastphot=args.fastphot, supermerge=args.mergeall, mp=args.mp)
        else:
            merge_fastspecfit(specprod=args.specprod, specprod_dir=specprod_dir, coadd_type=args.coadd_type,
                              survey=args.survey, program=args.program, healpix=args.healpix,
                              tile=args.tile, night=args.night, outdir_data=outdir_data,
                              fastfiles_to_merge=fastfiles_to_merge, outsuffix=args.merge_suffix,
                              mergedir=args.mergedir, overwrite=args.overwrite,
                              fastphot=args.fastphot, supermerge=args.mergeall, mp=args.mp,
                              nside_main=1)
        return


    if args.plan:
        plan(comm=comm, specprod=args.specprod, specprod_dir=specprod_dir,
             coadd_type=args.coadd_type, survey=args.survey, program=args.program,
             healpix=args.healpix, tile=args.tile, night=args.night,
             makeqa=args.makeqa, mp=args.mp, fastphot=args.fastphot,
             outdir_data=outdir_data, overwrite=args.overwrite,
             sample=sample)
    else:
        if args.profile:
            import cProfile
            import pstats

            profiler = cProfile.Profile()
            profiler.enable()

        run_fastspecfit(args, comm=comm, fastphot=args.fastphot, specprod_dir=specprod_dir,
                        makeqa=args.makeqa, outdir_data=outdir_data, sample=sample,
                        input_redshifts=args.input_redshifts, templates=args.templates,
                        templateversion=args.templateversion, fphotodir=args.fphotodir,
                        fphotofile=args.fphotofile, emlinesfile=args.emlinesfile, mp_per_file=args.mp_per_file)

        if args.profile:
            profiler.disable()

            outfile = os.path.join(outdir_data, f'profile_rank{rank}.prof')
            log.info(f'Writing {outfile}')
            profiler.dump_stats(outfile)
            with open(os.path.join(outdir_data, f'profile_rank{rank}.txt'), 'w') as F:
                stats = pstats.Stats(profiler, stream=F)
                stats.strip_dirs()
                stats.sort_stats('cumtime')
                stats.print_stats()
                stats.print_callees()

    if comm:
        MPI.Finalize()


if __name__ == '__main__':  # executed as a script rather than imported
    main()
