#!/usr/bin/env python

from desitarget.skyfibers import supplement_skies, density_of_sky_fibers
from desitarget import io
from desitarget.geomask import bundle_bricks, shares_hp, pixarea2nside
from desitarget.brightmask import mask_targets, get_recent_mask_dir
from desitarget.subpriority import override_subpriority

import os, sys
import numpy as np
import healpy as hp
import fitsio
from glob import glob

#import warnings
#warnings.simplefilter('error')

import multiprocessing
# ADM by default, run on half of the CPUs that the machine reports.
nproc = multiprocessing.cpu_count() // 2
# ADM default HEALPix Nside used throughout desitarget.
# ADM this is NOT the ns.nside parallelization input parsed below!!!
nside = io.desitarget_nside()

from desiutil.log import get_logger
log = get_logger()

from argparse import ArgumentParser
parser = ArgumentParser(
    description=(
        "Generate supplemental sky locations using Gaia-G-band avoidance "
        "(for regions beyond the Legacy Surveys) while also NOT generating "
        "supplemental skies near an existing sky location"
    )
)

# ADM positional arguments: where to read the skies and where to write.
parser.add_argument(
    "skydir",
    help=(
        "Full path to the directory of skies to supplement. Skies in this "
        "directory must have the same HEALPixel nside, and number of sky "
        "locations per sq. deg. in their header (note that those quantities "
        "have to be passed anyway, as supplemental skies in some pixels are "
        "outside of the Legacy surveys footprint, so won't have a "
        "corresponding skies file)"
    ),
)
parser.add_argument(
    "dest",
    help=(
        "Output supplemental sky targets directory (the file name is "
        "built on-the-fly from other inputs)"
    ),
)

# ADM optional arguments controlling how the skies are generated.
parser.add_argument(
    "--nskiespersqdeg",
    type=float,
    default=None,
    help=(
        "Number of sky locations to generate per sq. deg. (don't pass to "
        "read the default from desimodel.io with a 64x margin)"
    ),
)
parser.add_argument(
    "--numproc",
    type=int,
    default=nproc,
    help=f"number of concurrent processes to use (defaults to [{nproc}])",
)
parser.add_argument(
    "--gaiadir",
    type=str,
    default=None,
    help=(
        "Pass to set the GAIA_DIR environment variable directly in the "
        "code (i.e. the input directory that stores Gaia files)"
    ),
)
parser.add_argument(
    "--nside",
    type=int,
    default=None,
    help=(
        "Process supplemental skies in HEALPixels at this resolution "
        "(defaults to None). See also the 'healpixels' input flag"
    ),
)
parser.add_argument(
    "--healpixels",
    default=None,
    help=(
        "HEALPixels corresponding to `nside` (e.g. '6,21,57'). Only process "
        "files that touch these pixels and return targets within these "
        "pixels. The first entry is used to set RELEASE for TARGETIDs, and "
        "must be < 1000 (to prevent confusion with DR1 and above)"
    ),
)
parser.add_argument(
    "--bundlefiles",
    type=int,
    default=None,
    help=(
        "(overrides all options) print slurm script to parallelize by "
        "sending (any) integer. This is an integer rather than boolean for "
        "consistency with select_targets."
    ),
)
parser.add_argument(
    "--mindec",
    type=float,
    default=-90.,
    help=(
        "Minimum declination to include in output file (degrees; "
        "defaults to [-90])"
    ),
)
parser.add_argument(
    "--mingalb",
    type=float,
    default=0.,
    help=(
        "Closest latitude to Galactic plane to output for "
        "NON-LEGACY-SURVEYS targets (e.g. send 10 to limit to areas beyond "
        "-10o <= b < 10o; defaults to [0])"
    ),
)
parser.add_argument(
    "--radius",
    type=float,
    default=2.,
    help=(
        "Radius at which to avoid (all) Gaia sources (arcseconds; "
        "defaults to [2])"
    ),
)

# ADM optional arguments controlling the bright star masking.
parser.add_argument(
    "--nomasking",
    action="store_true",
    help=(
        "Masking occurs by default. If this is set, do NOT use "
        "Tycho+Gaia+URAT bright star mask to mask the sky locations"
    ),
)
parser.add_argument(
    "--maskdir",
    default=None,
    help=(
        "Name of the specific directory (or file) containing the bright "
        "star mask (defaults to the most recent directory in $MASK_DIR)"
    ),
)

# SB both spellings are accepted; the parsed value is ns.sky_subpriorities.
parser.add_argument(
    "--sky-subpriorities", "--sky_subpriorities",
    type=str,
    help="Optional file with sky TARGETID:SUBPRIORITY override",
)

ns = parser.parse_args()
# ADM masking is on unless --nomasking was sent.
do_mask = not ns.nomasking

# ADM the input sky files must be in the "official" format; bail if not.
official = io.is_sky_dir_official(ns.skydir)
if not official:
    msg = (
        f"Files in input sky directory ({ns.skydir}) not in official format!!!"
        " The remedy is typically to run bin/repartition_skies"
    )
    log.critical(msg)
    raise ValueError(msg)

# ADM use the GAIA directory from the command line if one was passed...
if ns.gaiadir is not None:
    gaiadir = ns.gaiadir
else:
    # ADM ...otherwise retrieve it from the environment variable.
    from desitarget.gaiamatch import get_gaia_dir
    gaiadir = get_gaia_dir()

# ADM use the sky density from the command line if one was passed...
if ns.nskiespersqdeg is not None:
    nskiespersqdeg = ns.nskiespersqdeg
else:
    # ADM ...otherwise look up the default density (with a margin of 64).
    nskiespersqdeg = density_of_sky_fibers(margin=64)
log.info(f'Generating sky positions at a density of {nskiespersqdeg}')

# ADM assemble the command-line options as a single string, as the
# ADM bundlefiles (slurm) option potentially needs to know about them.
nsdict = vars(ns)
cmdparts = [
    f" --numproc {ns.numproc}",
    f" --nskiespersqdeg {nskiespersqdeg}",
    f" --gaiadir {gaiadir}",
]
for nskey in ("mindec", "mingalb", "radius", "nomasking", "maskdir",
              "sky_subpriorities"):
    value = nsdict[nskey]
    if isinstance(value, bool):
        # ADM boolean flags take no value, and are only added when set.
        if value:
            cmdparts.append(f" --{nskey}")
    elif value is not None:
        # ADM options that were never set (None) are left out altogether.
        cmdparts.append(f" --{nskey} {value}")
extra = "".join(cmdparts)

# ADM only proceed if we're not writing a slurm script.
if ns.bundlefiles is None:
    # ADM parse the comma-separated list of HEALPixels in which to run.
    pixlist = ns.healpixels
    if pixlist is not None:
        pixlist = [int(pix) for pix in pixlist.split(',')]

    # ADM read the locations and the header for skies in these HEALPixels.
    skies, hdr = io.read_targets_in_hp(ns.skydir, ns.nside, pixlist,
                                       columns=["RA", "DEC"], header=True)

    # ADM check passed quantities for supp_skies are consistent with skies.
    msg = ""
    if hdr["NPERSDEG"] != nskiespersqdeg:
        msg += "num/sq. deg. differs: sky files ({}), supp_skies ({})".format(
            hdr["NPERSDEG"], nskiespersqdeg)
    if hdr["FILENSID"] != ns.nside:
        msg += " nside differs: sky files ({}), supp_skies ({})".format(
            hdr["FILENSID"], ns.nside)
    if msg:
        log.critical(msg)
        # ADM attach the message so the traceback explains the failure.
        raise ValueError(msg)

    # ADM generate the supplemental sky locations.
    supp_skies = supplement_skies(nskiespersqdeg=nskiespersqdeg, numproc=ns.numproc,
                                  gaiadir=gaiadir, radius=ns.radius, nside=ns.nside,
                                  pixlist=pixlist, mindec=ns.mindec,
                                  mingalb=ns.mingalb)

    # ADM mask the supplemental sky locations using a bright star mask.
    # ADM mdcomp stays None when --nomasking is sent, as there is then no
    # ADM mask directory to record in the output header.
    mdcomp = None
    if do_mask:
        maskdir = get_recent_mask_dir(ns.maskdir)
        supp_skies = mask_targets(supp_skies, maskdir, nside=ns.nside,
                                  pixlist=pixlist, bricks_are_hpx=True)

        # ADM a compact version of the maskdir name: everything from the
        # ADM "masks" path component onwards. Fall back on the full name
        # ADM if the path has no such component, rather than failing
        # ADM after the skies have already been generated.
        md = maskdir.split("/")
        if "masks" in md:
            mdcomp = "/".join(md[md.index("masks"):])
        else:
            mdcomp = maskdir

    # ADM remove supplemental skies that share HEALPixels with skies.
    nside_resol = pixarea2nside(1./18000)//2
    ii, _ = shares_hp(nside_resol, supp_skies, skies)
    supp_skies = supp_skies[~ii]
    log.info("Removed {} supp skies that matched skies".format(np.sum(ii)))

    # ADM extra header keywords for the output fits file.
    extra = {
        "radius": ns.radius,
        "mindec": ns.mindec,
        "mingalb": ns.mingalb,
        "masked": do_mask,
        "maskdir": mdcomp,
        "cmdline": ' '.join(sys.argv),
        }
    # SB record the subpriority override file (or "None") as a dependency.
    extradeps = {
        "sky-subpriorities-override": str(ns.sky_subpriorities),
        }

    # SB apply optional subpriority override; modifies supp_skies in-place.
    if ns.sky_subpriorities:
        subpriorities = fitsio.read(ns.sky_subpriorities, 'SUBPRIORITY')
        ii = override_subpriority(supp_skies, subpriorities)
        if len(ii) > 0:
            log.info(f"Overriding {len(ii)} SUBPRIORITY from {ns.sky_subpriorities}")
        else:
            log.warning(f"No matching targets to override SUBPRIORITY from {ns.sky_subpriorities}")

    # ADM note that nside is the desitarget-wide default set at the top of
    # ADM the script, whereas nsidefile is the ns.nside parallelization input.
    nskies, outfile = io.write_skies(ns.dest, supp_skies, supp=True, indir=gaiadir,
                                     nside=nside, nskiespersqdeg=nskiespersqdeg,
                                     extra=extra, extradeps=extradeps,
                                     nsidefile=ns.nside, hpxlist=pixlist)

    log.info('{} supplemental skies written to {}'.format(nskies, outfile))
else:
    # ADM if the bundlefiles option was sent, call the slurming code.
    # ADM here, extra is the string of command-line options built above.
    bundle_bricks([0], ns.bundlefiles, ns.nside, gather=False,
                  prefix='supp-skies', surveydirs=[ns.skydir], extra=extra)
