#!python

from __future__ import print_function, division

import os, sys
import numpy as np
from time import time
start = time()

from desitarget import io
from desitarget.io import check_both_set
from desitarget.randoms import pixweight, select_randoms
from glob import iglob
import fitsio

#import warnings
#warnings.simplefilter('error')

import multiprocessing
# ADM default number of processes: half of the CPU count, rounded down.
nproc = multiprocessing.cpu_count() // 2
# ADM default HEALPix Nside used throughout desitarget.
# ADM don't confuse this with the ns.nside parallelization input that is parsed below!!!
nside = io.desitarget_nside()

from desiutil.log import get_logger
# ADM logger used for the status and error messages in this script.
log = get_logger()

from argparse import ArgumentParser
# ADM define the command-line interface for this script.
# ADM options declared with type=int carry integer defaults, so that
# ADM they don't rely on argparse converting a string default.
ap = ArgumentParser(description='Generate pixel-level random points and associated information from a Data Release of the Legacy Surveys')
# ADM positional arguments: the input Data Release and the output directory.
ap.add_argument("surveydir",
                help='Legacy Surveys Data Release root directory (e.g. /global/project/projectdirs/cosmo/data/legacysurvey/dr4/ for DR4 at NERSC)')
ap.add_argument("dest",
                help="Output directory for random catalog (the file name is built on-the-fly from other inputs)")
ap.add_argument("--density", type=int,
                help='Number of points per sq. deg. at which to Monte Carlo the imaging masks (defaults to 10,000)',
                default=10000)
ap.add_argument("--numproc", type=int,
                help='number of concurrent processes to use [{}]'.format(nproc),
                default=nproc)
# ADM --nside and --healpixels are checked as a pair after parsing.
ap.add_argument('--nside', type=int,
                help='Process random points in parallel in bricks that have centers within HEALPixels at this resolution (defaults to None)',
                default=None)
# ADM --healpixels is kept as a comma-separated string here and is
# ADM converted to a list of integers after parsing.
ap.add_argument('--healpixels',
                help='HEALPixels (corresponding to nside) to process (e.g. "6,21,57"). If not passed, run all bricks in the Data Release',
                default=None)
ap.add_argument("--bundlebricks", type=int,
                help="(overrides all options but surveydir) Print a slurm script to parallelize, with about this many bricks per HEALPixel (e.g. 10000)",
                default=None)
ap.add_argument("-n", "--nchunks", type=int,
                help='Number of smaller catalogs to split the random catalog into inside the `bundlebricks` slurm script. Defaults to [10].',
                default=10)
ap.add_argument("--brickspersec", type=float,
                help="estimate of bricks completed per second by the (parallelized) code. Used with `bundlebricks` to guess run times (defaults to 2.5)",
                default=2.5)
ap.add_argument("--dustdir",
                help="Directory of SFD dust maps (defaults to the equivalent of $DUST_DIR+'/maps')",
                default=None)
ap.add_argument("--aprad", type=float,
                help="Radius of aperture in arcsec in which to generate sky background flux levels (defaults to 0.75; the DESI fiber radius). If aprad < 1e-8 is passed, the code to produce these values is skipped, as a speed-up, and `APFLUX_` output values are set to zero.",
                default=0.75)
ap.add_argument("--seed", type=int,
                help="Random seed passed to desitarget.randoms.select_randoms()",
                default=1)
# ADM a flag: False unless --addmtl is present on the command line.
ap.add_argument("--addmtl", action='store_true',
                help="If passed, then add the columns needed for MTL to the random catalogs.")

ns = ap.parse_args()

# ADM collect the command line arguments into a single string, as
# ADM bundlebricks may need to pass them along.
extra = " --numproc {}".format(ns.numproc)
opts = vars(ns)
for optname in ("aprad", "density", "seed", "addmtl"):
    optval = opts[optname]
    if not isinstance(optval, bool):
        # ADM options that take a value are written as "--name value".
        extra += " --{} {}".format(optname, optval)
    elif optval:
        # ADM boolean flags are only written when they are switched on.
        extra += " --{}".format(optname)

# ADM the MTL columns are left out unless --addmtl was passed.
nomtl = not ns.addmtl

# ADM convert the comma-separated HEALPixels string to a list of
# ADM integers, leaving it as None if --healpixels wasn't passed.
pixlist = ns.healpixels
if pixlist is not None:
    pixlist = [int(pix) for pix in pixlist.split(',')]

# ADM fail early if the input directory is missing.
if not os.path.exists(ns.surveydir):
    log.critical('Input directory does not exist: {}'.format(ns.surveydir))
    sys.exit(1)

if ns.bundlebricks is None:
    log.info('running on {} processors...t = {:.1f}s'.format(ns.numproc, time()-start))
    # ADM writing pixelized files formally needs the nside and the
    # ADM list of healpixels to be set together.
    check_both_set(ns.healpixels, ns.nside)

# ADM find a maskbits file and borrow its header for the bit names.
# ADM Two directory layouts are tried in turn (pre/post DR8).
hdr = None
if not ('dr5' in ns.surveydir or 'dr6' in ns.surveydir):
    # ADM each layout is (path components below surveydir, the
    # ADM extension passed to fitsio.read_header for that layout).
    layouts = (
        (("*", "coadd", "*", "*", "*maskbits*"), 1),
        (("coadd", "*", "*", "*maskbits*"), 0),
    )
    for parts, ext in layouts:
        try:
            fn = next(iglob(os.path.join(ns.surveydir, *parts)))
            hdrall = fitsio.read_header(fn, ext)
        except StopIteration:
            # ADM nothing matched this layout, so move on to the next.
            continue
        break
    else:
        # ADM neither layout produced a maskbits file.
        log.critical(
            "No coadd or north/coadd south/coadd directories in {}?!".format(
                ns.surveydir))
        sys.exit(1)
    # ADM pull the record map out of the header object's attributes.
    recmap = vars(hdrall)['_record_map']
    # ADM copy only the maskbits-relevant records to a new header.
    hdr = fitsio.FITSHDR()
    for key, rec in recmap.items():
        if 'BITNM' in key:
            hdr[key] = rec

# ADM run the main driver with the parsed command-line options.
randres, randnorth, randsouth = select_randoms(
    ns.surveydir, density=ns.density, numproc=ns.numproc,
    nside=ns.nside, pixlist=pixlist, aprad=ns.aprad, seed=ns.seed,
    dustdir=ns.dustdir, nomtl=nomtl, extra=extra,
    nchunks=ns.nchunks, bundlebricks=ns.bundlebricks,
    brickspersec=ns.brickspersec)

# ADM files are only written when not in bundlebricks mode.
if ns.bundlebricks is None:
    # ADM extra header keywords for the output fits file.
    extra = {"density": ns.density, "aprad": ns.aprad,
             "seed": ns.seed, "addmtl": ns.addmtl}

    # ADM one file is written per catalog, each entry being
    # ADM (randoms, resolve flag, north flag).
    outputs = (
        (randres, True, None),
        (randnorth, False, True),
        (randsouth, False, False),
    )
    for rands, res, north in outputs:
        nrands, outfile = io.write_randoms(
            ns.dest, rands, indir=ns.surveydir, hdr=hdr, nside=nside,
            extra=extra, resolve=res, nsidefile=ns.nside,
            hpxlist=pixlist, north=north)
        log.info('wrote file of {} randoms to {}...t = {:.1f}s'
                 .format(nrands, outfile, time()-start))
