#!/usr/bin/env python

"""
MPI equivalent of select_mock_targets
"""

#- Parse args first to enable --help on login nodes where MPI crashes
from __future__ import absolute_import, division, print_function
import argparse
import multiprocessing
#- Default concurrency: half of the CPU count reported by multiprocessing
nproc = multiprocessing.cpu_count() // 2

#- Command line interface; every option has a default, so a bare invocation
#- parses cleanly and --help works before MPI is touched.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-c', '--config',
    default='input.yaml',
)
parser.add_argument(
    '-O', '--output_dir',
    type=str,
    default='./',
    help='Path to write the outputs',
)
parser.add_argument(
    '-s', '--seed',
    type=int,
    default=None,
    help='Seed for random number generation',
)
parser.add_argument(
    '-n', '--nproc',
    type=int,
    default=nproc,
    help='number of concurrent processes to use [{}]'.format(nproc),
)
parser.add_argument(
    '--survey',
    type=str,
    choices=['main', 'sv1'],
    default='main',
    help='Survey to simulate.',
)
parser.add_argument(
    '--nside',
    type=int,
    default=64,
    help='Divide the DESI footprint into this healpix resolution',
)
parser.add_argument(
    '--tiles',
    type=str,
    help='Path to file with tiles to cover',
)
parser.add_argument(
    '--npix',
    type=int,
    default=1,
    help='Number of healpix per process',
)
parser.add_argument(
    '--healpixels',
    type=int,
    nargs='*',
    default=None,
    help='Integer list of healpix pixels (corresponding to nside) to process.',
)
parser.add_argument(
    '-v', '--verbose',
    action='store_true',
    help='Enable verbose output.',
)
parser.add_argument(
    '--no-spectra',
    action='store_true',
    help='Do not generate spectra.',
)
parser.add_argument(
    '--no-check-env',
    action='store_true',
    help="Don't check NERSC environment variables",
)
parser.add_argument(
    '--sort-pixels',
    action='store_true',
    help="Sort pixels by galactic latitude",
)

args = parser.parse_args()

#- Then initialize MPI ASAP before proceeding with other imports
from mpi4py import MPI
#- World communicator plus this process's rank and the total number of ranks
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

import sys, os, time
import numpy as np

from astropy.table import Table
import desimodel.footprint
from desiutil.log import get_logger, DEBUG
from desitarget.mock.build import targets_truth
import desitarget.mock.io as mockio
from  desitarget.io import find_target_files
from desispec.parallel import stdouterr_redirected

#- --verbose requests a DEBUG-level logger; otherwise use the default logger
log = get_logger(DEBUG) if args.verbose else get_logger()

#- --tiles and --healpixels are mutually exclusive ways of choosing pixels.
#- Only rank 0 reports the problem, but every rank exits.
tiles_and_healpixels_given = not (args.tiles is None or args.healpixels is None)
if tiles_and_healpixels_given:
    if rank == 0:
        log.error('use --tiles or --healpixels but not both')
    sys.exit(1)

#- NERSC environment check: when $NERSC_HOST is set (and --no-check-env was
#- not given), require the expected values of a few environment variables.
#- Each entry is (variable, accepted values, message logged if it differs).
if not args.no_check_env and 'NERSC_HOST' in os.environ:
    expected_env = (
        ('OMP_NUM_THREADS', ('1', '2'),
         'You likely want $OMP_NUM_THREADS=1 at NERSC'),
        ('MPICH_GNI_FORK_MODE', ('FULLCOPY',),
         'You likely want $MPICH_GNI_FORK_MODE=FULLCOPY at NERSC'),
        ('KMP_AFFINITY', ('disabled',),
         'You likely want $KMP_AFFINITY=disabled at NERSC'),
    )

    ok = True
    for varname, accepted, hint in expected_env:
        #- An unset variable yields None, which is never an accepted value
        if os.getenv(varname) in accepted:
            continue
        ok = False
        if rank == 0:
            log.error(hint)

    #- All problems are reported before exiting; every rank exits
    if not ok:
        if rank == 0:
            log.error('Either fix env or rerun with --no-check-env; exiting...')
        sys.exit(1)

#- Calculate which pixels cover these tiles and broadcast to all ranks
def _select_pixels():
    """Return the array of healpix pixels that still need to be processed.

    Intended to run on rank 0 only.  Reads the module-level ``args`` and:

    * takes the pixels from ``--healpixels`` if given, otherwise from
      ``desimodel.footprint.tiles2pix`` for the ``--tiles`` file (or with
      ``tiles=None`` when no tiles file was given);
    * drops pixels for which a dark or bright truth file already exists
      under ``args.output_dir``;
    * pre-creates the output directory of every remaining pixel;
    * optionally (``--sort-pixels``) orders pixels by decreasing
      \|galactic latitude\|.

    Any exception (unreadable tiles file, filesystem error, ...) propagates
    to the caller.
    """
    if args.healpixels is None:
        if args.tiles is not None:
            if args.tiles.endswith('.ecsv'):
                tiles = Table.read(args.tiles, format='ascii.ecsv')
            else:
                tiles = Table.read(args.tiles)
            log.info('{} tiles'.format(len(tiles)))
        else:
            tiles = None
            log.info('Running on the full DESI footprint')
        pixels = desimodel.footprint.tiles2pix(args.nside, tiles)
    else:
        #- Force an integer dtype: "--healpixels" with no values gives an
        #- empty list, which numpy would otherwise turn into a float array.
        pixels = np.array(args.healpixels, dtype=int)

    #- Skip pixels that already have a truth file for either obscon
    keep = list()
    for i, pixnum in enumerate(pixels):
        truth_dark = find_target_files(args.output_dir, flavor="truth", obscon='dark',
                                        hp=pixnum, nside=args.nside, mock=True)
        truth_bright = find_target_files(args.output_dir, flavor="truth", obscon='bright',
                                        hp=pixnum, nside=args.nside, mock=True)
        if os.path.exists(truth_dark) or os.path.exists(truth_bright):
            continue
        keep.append(i)

    log.info('{}/{} pixels remaining to do'.format(len(keep), len(pixels)))
    pixels = pixels[keep]

    #- pre-create output directories; exist_ok makes this safe to repeat
    for pixnum in pixels:
        outdir = os.path.dirname(mockio.findfile('blat', args.nside, pixnum,
                                                 basedir=args.output_dir))
        os.makedirs(outdir, exist_ok=True)

    #- Optionally sort pixels by -|galactic latitude| to do pixels with
    #- lowest star densities first.  When MPI size >> 1 this doesn't work
    #- as well because high ranks will still get low latitude pixels.
    if args.sort_pixels:
        log.info('Sorting pixels by galactic latitude (highest lat first)')
        import healpy as hp
        from astropy.coordinates import SkyCoord
        from astropy import units
        theta, phi = hp.pix2ang(args.nside, pixels, nest=True)
        ra, dec = np.degrees(phi), 90-np.degrees(theta)
        c = SkyCoord(ra*units.deg, dec*units.deg)
        b = c.galactic.b.value
        ii = np.argsort(-np.abs(b))
        pixels = pixels[ii]

    return pixels


if rank == 0:
    try:
        pixels = _select_pixels()
    except Exception:
        #- If rank 0 simply died here, every other rank would wait forever
        #- in the barrier/bcast below.  Report the error and abort the whole
        #- MPI job instead.
        import traceback
        log.error('Rank 0 failed while selecting pixels; aborting all ranks')
        log.error(traceback.format_exc())
        sys.stdout.flush()
        sys.stderr.flush()
        comm.Abort(1)
else:
    pixels = None

comm.barrier()
pixels = comm.bcast(pixels, root=0)

#- Read config file and broadcast to all ranks
if rank == 0:
    try:
        import yaml
        with open(args.config, 'r') as pfile:
            params = yaml.safe_load(pfile)
    except Exception:
        #- A missing or unparsable config file (or a missing yaml module) on
        #- rank 0 would otherwise leave all other ranks blocked in the bcast
        #- below.  Report the error and abort the whole MPI job instead.
        import traceback
        log.error('Rank 0 failed to read config file {}; aborting all ranks'.format(args.config))
        log.error(traceback.format_exc())
        sys.stdout.flush()
        sys.stderr.flush()
        comm.Abort(1)
else:
    params = None

params = comm.bcast(params, root=0)

#- Seed numpy's global random state from --seed, then draw one integer seed
#- (below 2**31) per pixel
np.random.seed(args.seed)
npixels = len(pixels)
pixseeds = np.random.randint(2**31, size=npixels)

#- Cut the pixel list into `size` contiguous slices using evenly spaced
#- integer boundaries; this rank takes the slice between its two boundaries
#- (which may be empty when there are fewer pixels than ranks)
bounds = np.linspace(0, npixels, size+1, dtype=int)
first, last = bounds[rank], bounds[rank+1]
rankpix = pixels[first:last]
rankseeds = pixseeds[first:last]
log.info('rank {} processes {} pixels {}'.format(rank, last-first, rankpix))
sys.stdout.flush()
comm.barrier()

if len(rankpix) != 0:
    #- Walk through this rank's pixels in chunks of --npix (default 1, i.e.
    #- one pixel per targets_truth() call to limit memory use); the loop is
    #- written so that larger chunks can be used if memory allows.
    step = args.npix
    for start in range(0, len(rankpix), step):
        chunk = rankpix[start:start+step]
        #- Each chunk uses the seed of its first pixel
        chunkseed = rankseeds[start]
        #- The log file is named after the first pixel of the chunk
        logfile = mockio.findfile('build', args.nside, rankpix[start], ext='log', basedir=args.output_dir)
        log.info('Logging pixels {} to {}'.format(chunk, logfile))
        tstart = time.time()
        try:
            #- stdout/stderr are redirected to the per-chunk log file while
            #- targets_truth() runs
            with stdouterr_redirected(to=logfile):
                log.info('Calling targets_truth() with the following options')
                for template, value in (
                        ('  output_dir {}', args.output_dir),
                        ('  seed       {}', chunkseed),
                        ('  nside      {}', args.nside),
                        ('  nproc      {}', args.nproc),
                        ('  verbose    {}', args.verbose),
                        ('  healpixels {}', chunk),
                        ):
                    log.info(template.format(value))
                targets_truth(params, output_dir=args.output_dir, seed=chunkseed,
                              nside=args.nside, nproc=args.nproc, verbose=args.verbose,
                              healpixels=chunk, no_spectra=args.no_spectra,
                              survey=args.survey)

            elapsed = (time.time() - tstart) / 60
            log.info('Pixels {} took {:.1f} minutes'.format(chunk, elapsed))
        except Exception:
            #- A failed chunk is reported with its traceback, and the loop
            #- then carries on with the next chunk
            import traceback
            elapsed = (time.time() - tstart) / 60
            log.error('Pixels {} failed after {:.1f} minutes'.format(chunk, elapsed))
            log.error(traceback.format_exc())
            sys.stdout.flush()
            sys.stderr.flush()

