#!/usr/bin/env python

import os, sys

from desitarget.skyutilities.legacypipe.util import LegacySurveyData
from desitarget.skyfibers import select_skies, density_of_sky_fibers, get_brick_info
from desitarget import io
from desitarget.io import check_both_set
from desitarget.geomask import bundle_bricks
from desitarget.targets import resolve
from desitarget.brightmask import mask_targets, get_recent_mask_dir
from desitarget.subpriority import override_subpriority

import numpy as np
import healpy as hp
import fitsio
from glob import glob

#import warnings
#warnings.simplefilter('error')

import multiprocessing
# ADM default number of processes offered for --numproc: half of the
# ADM CPUs reported by multiprocessing.cpu_count().
# NOTE(review): integer division makes this 0 on a single-CPU machine;
# confirm that the downstream code treats numproc=0 as serial.
nproc = multiprocessing.cpu_count() // 2
# ADM default HEALPix Nside used throughout desitarget.
# ADM don't confuse this with the ns.nside parallelization input that is parsed below!!!
nside = io.desitarget_nside()

# ADM logger used for every message emitted by this script.
from desiutil.log import get_logger
log = get_logger()

from argparse import ArgumentParser
# ADM command-line interface. Two positionals (input Data Release
# ADM directory and output directory) followed by optional switches.
ap = ArgumentParser(
    description='Generates possible DESI sky fiber locations in Legacy Survey bricks'
)

# ADM required inputs.
ap.add_argument(
    "surveydir",
    help="Base directory for a Legacy Surveys Data Release (e.g. '/global/project/projectdirs/cosmo/data/legacysurvey/dr6/' at NERSC",
)
ap.add_argument(
    "dest",
    help="Output sky targets directory (the file name is built on-the-fly from other inputs)",
)

# ADM optional second Data Release to combine with the first.
ap.add_argument(
    "-s2", "--surveydir2",
    default=None,
    help='Additional Legacy Surveys directory (useful for combining, e.g., DR8 into one file of sky locations)',
)

# ADM what to generate and measure at each sky location.
ap.add_argument(
    "--nskiespersqdeg",
    type=float,
    default=None,
    help="Number of sky locations to generate per sq. deg. (don't pass to read the default from desimodel.io with a 64x margin)",
)
ap.add_argument(
    "--bands",
    default="g,r,z",
    help='Bands in this Legacy Survey Data Release to consider when deriving sky location and aperture fluxes (e.g, "g,r")',
)
ap.add_argument(
    "--apertures",
    default="0.75",
    help='Aperture radii in arcseconds at which to derive flux measurements for each sky location (e.g, "0.75, 1.0"; defaults to "0.75")',
)

# ADM restrict the run to bricks in particular HEALPixels.
ap.add_argument(
    "--nside",
    type=int,
    default=None,
    help='Process sky locations in parallel in bricks that have centers within HEALPixels at this resolution (defaults to None)',
)
ap.add_argument(
    "--healpixels",
    default=None,
    help='HEALPixels (corresponding to nside) to process (e.g. "5,7,11"). If not passed, run all bricks in the Data Release',
)

ap.add_argument(
    "--writebricks",
    action="store_true",
    help="Write sky information for EACH brick (the per-brick file names look like surveydir/metrics/(brick).3s/skies-(brick)s.fits.gz",
)

# ADM options that control the printed slurm script and parallelization.
ap.add_argument(
    "--bundlebricks",
    type=int,
    default=None,
    help=(
        "(overrides all options but surveydir) Print a slurm script to parallelize, with about this many bricks per HEALPixel (e.g. 14000). "
        "Send a small number (e.g. 1) to split by individual HEALPixels."
    ),
)
ap.add_argument(
    "--brickspersec",
    type=float,
    default=1.6,
    help="estimate of bricks completed per second by the (parallelized) code. Used with `bundlebricks` to guess run times (defaults to 1.6)",
)
ap.add_argument(
    "--numproc",
    type=int,
    default=nproc,
    help=f"number of concurrent processes to use [{nproc}]",
)

# ADM bright star masking.
ap.add_argument(
    "--nomasking",
    action="store_true",
    help="Masking occurs by default. If this is set, do NOT use a bright star mask to mask the sky locations",
)
ap.add_argument(
    "--maskdir",
    default=None,
    help="Name of the specific directory (or file) containing the bright star mask (defaults to the most recent directory in $MASK_DIR)",
)

# SB both spellings are accepted; the parsed attribute is sky_subpriorities.
ap.add_argument(
    "--sky-subpriorities", "--sky_subpriorities",
    type=str,
    help='Optional file with sky TARGETID:SUBPRIORITY override',
)

ns = ap.parse_args()
# ADM masking is on unless --nomasking was passed.
do_mask = not ns.nomasking

# ADM build the list of command line arguments as
# ADM bundlefiles potentially needs to know about them.
# ADM The resulting string is handed to bundle_bricks (which, per the
# ADM --bundlebricks help, prints a slurm script), so each value is
# ADM shell-quoted: a value such as the documented --apertures example
# ADM "0.75, 1.0" contains a space and would otherwise be split into two
# ADM arguments when the command is re-run. Values made only of
# ADM shell-safe characters (e.g. "g,r,z", "0.75", plain paths) are
# ADM left unchanged by shlex.quote.
import shlex

extra = " --numproc {}".format(ns.numproc)
nsdict = vars(ns)
for nskey in ["nskiespersqdeg", "bands", "apertures", "nomasking", "maskdir",
              "sky_subpriorities"]:
    nsval = nsdict[nskey]
    if isinstance(nsval, bool):
        # ADM store_true flags are forwarded bare, and only when set.
        if nsval:
            extra += " --{}".format(nskey)
    elif nsval is not None:
        # ADM options left at a None default are not forwarded at all.
        extra += " --{} {}".format(nskey, shlex.quote(str(nsval)))

# ADM fail early (exit status 1) if a requested Data Release directory
# ADM is missing or has no "coadd" subdirectory.
for indir in (ns.surveydir, ns.surveydir2):
    # ADM the second directory is optional.
    if indir is None:
        continue
    if not os.path.exists(indir):
        log.critical('Input directory ({}) does not exist'.format(indir))
        sys.exit(1)
    coadddir = os.path.join(indir, "coadd")
    if not os.path.exists(coadddir):
        log.critical(
            'coadd subdirectory not in input directory ({})'.format(indir)
        )
        sys.exit(1)

# ADM convert passed csv strings to lists.
# ADM band names are stripped of surrounding whitespace so that, e.g.,
# ADM "g, r" is read as ["g", "r"]; float() and int() below already
# ADM ignore such whitespace for the apertures and HEALPixels.
bands = [band.strip() for band in ns.bands.split(',')]
apertures = [float(aperture) for aperture in ns.apertures.split(',')]

# ADM pixlist stays None when --healpixels was not passed.
pixlist = ns.healpixels
if pixlist is not None:
    pixlist = [int(pixnum) for pixnum in pixlist.split(',')]

# ADM use the requested density of sky locations or, if none was
# ADM passed, look up the default with a 64x margin.
if ns.nskiespersqdeg is None:
    nskiespersqdeg = density_of_sky_fibers(margin=64)
else:
    nskiespersqdeg = ns.nskiespersqdeg

# ADM log the number of sky positions expected per brick: the next
# ADM perfect square above the (fractional) number implied by the
# ADM density. The area is taken to be 0.25 x 0.25 (presumably the
# ADM nominal brick size in sq. deg. -- confirm).
area = 0.25*0.25
nskiesfloat = area*nskiespersqdeg
nperside = np.sqrt(nskiesfloat).astype('int16') + 1
nskies = nperside**2
log.info('Generating {} sky positions in each brick'.format(nskies))

# ADM one LegacySurveyData object per input Data Release directory;
# ADM the primary directory always comes first.
surveydirs = [ns.surveydir]
if ns.surveydir2 is not None:
    surveydirs.append(ns.surveydir2)
surveys = [LegacySurveyData(survey_dir=sdir) for sdir in surveydirs]

# ADM the script runs in one of two modes:
# ADM   --bundlebricks set: group bricks by HEALPixel and hand them to
# ADM     bundle_bricks (no sky locations are generated).
# ADM   otherwise: generate, resolve, (optionally) mask and write the
# ADM     sky locations.
if ns.bundlebricks is not None:
    # ADM bricks are assigned to HEALPixels below, which needs a
    # ADM resolution; stop with a clear message rather than letting
    # ADM healpy fail on nside=None.
    if ns.nside is None:
        log.critical('--nside must be passed together with --bundlebricks')
        sys.exit(1)
    # ADM grab the HEALPixel number for each brick.
    drdirs = [survey.survey_dir for survey in surveys]
    brickdict = get_brick_info(drdirs, counts=True)
    # ADM each dictionary value is unpacked as 7 entries, of which the
    # ADM first two are used as RA/Dec and the last as a per-brick count.
    bra, bdec, _, _, _, _, cnts = np.vstack(list(brickdict.values())).T
    theta, phi = np.radians(90-bdec), np.radians(bra)
    pixnum = hp.ang2pix(ns.nside, theta, phi, nest=True)
    # ADM pixnum only contains unique bricks, need to add duplicates:
    # ADM repeat each brick's pixel number "count" times.
    allpixnum = np.repeat(pixnum, cnts.astype(int))
    bundle_bricks(
        allpixnum, ns.bundlebricks, ns.nside, prefix='skies', extra=extra,
        gather=False, surveydirs=drdirs, brickspersec=ns.brickspersec)
else:
    log.info("running on {} processors".format(ns.numproc))
    # ADM formally writing pixelized files requires both the nside
    # ADM and the list of healpixels to be set.
    check_both_set(ns.healpixels, ns.nside)
    # ADM run the main sky selection code over the passed surveys.
    skies = [
        select_skies(
            survey, numproc=ns.numproc, nskiespersqdeg=nskiespersqdeg,
            bands=bands, apertures_arcsec=apertures,
            nside=ns.nside, pixlist=pixlist, writebricks=ns.writebricks)
        for survey in surveys
    ]

    # ADM redact empty output (where there were no bricks in a survey).
    # ADM A plain list is filtered (rather than an object array) so that
    # ADM per-survey arrays of equal length are never stacked into a
    # ADM 2-D array before being concatenated.
    skies = [sk for sk in skies if sk is not None]
    if len(skies) == 0:
        log.critical('No sky locations were generated for any input survey')
        sys.exit(1)
    skies = np.concatenate(skies)

    # ADM resolve any duplicates between imaging data releases.
    resolved = resolve(skies)
    # ADM recover any unique bricks that were on the other side of the
    # ADM resolve, in case of missing N/S bricks available in the S/N.
    missedbrx = list(set(skies["BRICKNAME"]) - set(resolved["BRICKNAME"]))
    if len(missedbrx) > 0:
        # ADM True for every location in a brick that resolve dropped.
        backin = np.isin(skies["BRICKNAME"], missedbrx)
        skies = np.concatenate([resolved, skies[backin]])
        log.info('Added some post-resolve bricks back in: {}'.format(missedbrx))
    else:
        skies = resolved

    # ADM mask the sky locations using a bright star mask.
    mdcomp = None
    if do_mask:
        maskdir = get_recent_mask_dir(ns.maskdir)
        skies = mask_targets(skies, maskdir, nside=ns.nside, pixlist=pixlist)
        # ADM a compact version of the maskdir name: everything from the
        # ADM "masks" path component onwards.
        md = maskdir.split("/")
        if "masks" in md:
            mdcomp = "/".join(md[md.index("masks"):])
        else:
            # ADM e.g. a user-supplied --maskdir outside of a "masks"
            # ADM tree; record the full name instead of failing here,
            # ADM after all of the expensive work has been done.
            log.warning('No "masks" component in {}; recording the full '
                        'name in the output header'.format(maskdir))
            mdcomp = maskdir

    # ADM extra header keywords for the output fits file.
    hdrextra = {
        "masked": do_mask,
        "maskdir": mdcomp,
        "cmdline": ' '.join(sys.argv),
        }
    extradeps = {
        "sky-subpriorities-override": str(ns.sky_subpriorities),
        }

    # SB apply optional subpriority override; modifies skies in-place
    if ns.sky_subpriorities:
        subpriorities = fitsio.read(ns.sky_subpriorities, 'SUBPRIORITY')
        ii = override_subpriority(skies, subpriorities)
        if len(ii) > 0:
            log.info(f"Overriding {len(ii)} SUBPRIORITY from {ns.sky_subpriorities}")
        else:
            log.warning(f"No matching targets to override SUBPRIORITY from {ns.sky_subpriorities}")

    # ADM this correctly records the apertures in the output file header
    # ADM as well as adding HEALPixel information.
    # ADM note that "nside" is the desitarget default set at the top of
    # ADM the script whereas "nsidefile" is the --nside input.
    nwritten, outfile = io.write_skies(ns.dest, skies, indir=ns.surveydir,
                                       indir2=ns.surveydir2, nside=nside,
                                       apertures_arcsec=apertures,
                                       extra=hdrextra, extradeps=extradeps,
                                       nskiespersqdeg=nskiespersqdeg,
                                       nsidefile=ns.nside, hpxlist=pixlist)

    log.info('{} skies written to {}'.format(nwritten, outfile))
