#!python

from __future__ import print_function, division

import os, sys
import numpy as np
import fitsio

from desitarget import io
from desitarget.io import desitarget_version, check_both_set, get_checksums
from desitarget.cuts import select_targets, qso_selection_options
from desitarget.brightmask import mask_targets
from desitarget.QA import _parse_tcnames
from desitarget.targets import decode_targetid
from desitarget.subpriority import override_subpriority

# Record the wall-clock start; the per-file log line at the end of the
# script reports the elapsed time relative to this value.
from time import time
start = time()

#import warnings
#warnings.simplefilter('error')

import multiprocessing
# Default to half of the available cores, but never fewer than one process:
# cpu_count() // 2 evaluates to 0 on a single-core machine.
nproc = max(1, multiprocessing.cpu_count() // 2)
# ADM don't confuse this with the ns.nside input that is parsed below!!!
nside = io.desitarget_nside()

from desiutil.log import get_logger
log = get_logger()

from argparse import ArgumentParser
# Command-line interface for the script. The examples in the help strings
# follow the argument order that each option documents (min before max,
# declination within [-90, 90]).
ap = ArgumentParser(description='Generates DESI target bits from Legacy Surveys sweeps or tractor files')
ap.add_argument("sweepdir",
                help="Tractor/sweeps file or root directory with tractor/sweeps files")
ap.add_argument("dest",
                help="Output target selection directory (the file name is built on-the-fly from other inputs)")
ap.add_argument('-s2', "--sweepdir2",
                help='Additional Tractor/sweeps file or directory (useful for combining, e.g., DR8 into one file of targets)',
                default=None)
ap.add_argument('-m', "--mask",
                help="If sent then mask the targets, the name of the mask file should be supplied")
ap.add_argument('--qsoselection', choices=qso_selection_options, default='randomforest',
                help="QSO target selection method")
ap.add_argument("--gaiasub", action='store_true',
                help="Substitute Gaia EDR3 proper motions and parallaxes in place of the sweeps (DR2) values. Only valid for DR9 sweeps files.")
ap.add_argument("--numproc", type=int,
                help='number of concurrent processes to use [defaults to {}]'.format(nproc),
                default=nproc)
ap.add_argument('-t','--tcnames', default=None,
                help="Comma-separated names of target classes to run (e.g. QSO,LRG). Options are ELG, QSO, LRG, MWS, BGS, STD. Default is to run everything.")
ap.add_argument('--nside', type=int,
                help="Process targets in HEALPixels at this resolution (defaults to None). See also the 'healpixels' input flag",
                default=None)
ap.add_argument('--healpixels',
                help="HEALPixels corresponding to `nside` (e.g. '6,21,57'). Only process files that touch these pixels and return targets within these pixels",
                default=None)
ap.add_argument("--bundlefiles", type=int,
                help="(overrides all options but `sweepdir`) print slurm script to parallelize by sending (any) integer. This is an integer rather than boolean for historical reasons",
                default=None)
ap.add_argument('--radecbox',
                help="Only return targets in an RA/Dec box denoted by 'RAmin,RAmax,Decmin,Decmax' in degrees (e.g. '140,150,-20,-10')",
                default=None)
ap.add_argument('--radecrad',
                help="Only return targets in an RA/Dec circle/cap denoted by 'centerRA,centerDec,radius' in degrees (e.g. '140,-10,0.5')",
                default=None)
ap.add_argument("--noresolve", action='store_true',
                help="Do NOT resolve into northern targets in northern regions and southern targets in southern regions")
ap.add_argument("--nomaskbits", action='store_true',
                help="Do NOT apply information in MASKBITS column to target classes")
ap.add_argument("--writeall", action='store_true',
                help="Default behavior is to split targets by bright/dark-time surveys. Send this to ALSO write a file of ALL targets")
ap.add_argument("-nos", "--nosecondary", action='store_true',
                help="Do NOT create TARGETID look-up files for secondary targets in $SCNDIR/outdata/desitargetversion/priminfo-drversion-desitargetversion/$dest.fits (where $dest is the basename of dest)")
ap.add_argument("--scnddir",
                help="Base directory of secondary target files (e.g. '/project/projectdirs/desi/target/secondary' at NERSC). "+
                "Defaults to SCND_DIR environment variable. Not needed if --nosecondary is sent.")
ap.add_argument("-nob", "--nobackup", action='store_true',
                help="Do NOT run the Gaia-only backup targets (which require the GAIA_DIR environment variable to be set).")
ap.add_argument("-noc", "--nochecksum", action='store_true',
                help='Do NOT add the list of input files and their checksums to the output target file as the second ("INFILES") extension')
ap.add_argument("-check", "--checkbright", action='store_true',
                help='If passed, then log a warning about targets that could be too bright when writing output files')
# Both the hyphenated and the underscored spellings are accepted; argparse
# stores the value under the underscored attribute name either way.
ap.add_argument("--dark-subpriorities", "--dark_subpriorities", type=str,
                help='Optional file with dark TARGETID:SUBPRIORITY overrides')
ap.add_argument("--bright-subpriorities", "--bright_subpriorities", type=str,
                help='Optional file with bright TARGETID:SUBPRIORITY overrides')

ns = ap.parse_args()
# ADM rebuild the relevant command-line options as a single string,
# ADM because the bundlefiles mode potentially needs to know about them.
nsdict = vars(ns)
forwarded = ("tcnames", "noresolve", "nomaskbits", "writeall",
             "nosecondary", "nobackup", "nochecksum", "gaiasub", "scnddir",
             "dark_subpriorities", "bright_subpriorities")
pieces = [" --numproc {}".format(ns.numproc)]
for optname in forwarded:
    optval = nsdict[optname]
    if isinstance(optval, bool):
        # Boolean switches are forwarded bare, and only when they were set.
        if optval:
            pieces.append(" --{}".format(optname))
    elif optval is not None:
        # Valued options are forwarded together with their value.
        pieces.append(" --{} {}".format(optname, optval))
extra = "".join(pieces)

# Look for sweeps files first and fall back to tractor files; in each case
# the optional second directory is searched with the same file lister.
for lister in (io.list_sweepfiles, io.list_tractorfiles):
    infiles = lister(ns.sweepdir)
    if ns.sweepdir2 is not None:
        infiles += lister(ns.sweepdir2)
    if len(infiles) > 0:
        break
else:
    # Neither kind of file was found in any of the supplied locations.
    log.critical('no sweep or tractor files found')
    sys.exit(1)

# ADM Only coded for objects with Gaia matches
# ADM (e.g. DR6 or above). Fail for earlier Data Releases.
# ADM Guard against a single file being passed.
# A logical test is required here: `~isinstance(...)` is bitwise inversion
# of a bool (~True == -2, ~False == -1), which is always truthy and would
# index a lone filename string down to its first character.
fn = infiles if isinstance(infiles, str) else infiles[0]
# Only the two columns needed for the sanity checks are read.
data = fitsio.read(fn, columns=["RELEASE","PMRA"], upper=True)
if np.any(data["RELEASE"] < 6000):
    msg = 'SV cuts only coded for DR6 or above'
    log.critical(msg)
    raise ValueError(msg)
# A pre-DR7 file whose largest proper motion is exactly zero is taken to
# mean that the input sweeps were not Gaia-matched.
if np.max(data['PMRA']) == 0. and np.any(data["RELEASE"] < 7000):
    d = "/project/projectdirs/desi/target/gaia_dr2_match_dr6"
    log.info("Zero objects have a proper motion.")
    msg = "Did you mean to send the Gaia-matched sweeps in, e.g., {}?".format(d)
    log.critical(msg)
    raise IOError(msg)

if ns.bundlefiles is None:
    log.info("running on {} processors".format(ns.numproc))
    # ADM pixelized output needs the nside and the list of healpixels
    # ADM to be supplied together.
    check_both_set(ns.healpixels, ns.nside)

# ADM turn the comma-separated HEALPixel string into a list of integers.
pixlist = (
    None if ns.healpixels is None
    else [int(pix) for pix in ns.healpixels.split(',')]
)

# ADM turn each comma-separated RA/Dec region string into a list of floats,
# ADM leaving regions that were not supplied as None.
inlists = [
    None if region is None else [float(num) for num in region.split(',')]
    for region in (ns.radecbox, ns.radecrad)
]

# ADM restrict to the requested target classes, or run them all.
tcnames = _parse_tcnames(tcstring=ns.tcnames, add_all=False)

targets, infn = select_targets(
    infiles,
    numproc=ns.numproc,
    qso_selection=ns.qsoselection,
    gaiasub=ns.gaiasub,
    nside=ns.nside,
    pixlist=pixlist,
    extra=extra,
    bundlefiles=ns.bundlefiles,
    radecbox=inlists[0],
    radecrad=inlists[1],
    tcnames=tcnames,
    survey='main',
    backup=not ns.nobackup,
    resolvetargs=not ns.noresolve,
    mask=not ns.nomaskbits,
    return_infiles=True,
)

# Everything below writes output files, so it is skipped entirely when
# --bundlefiles was sent.
if ns.bundlefiles is None:
    # ADM Set the list of infiles actually processed by select_targets() to
    # ADM None if we DON'T want to write their checksums to the output file.
    if ns.nochecksum:
        shatab = None
    else:
        shatab = get_checksums(infn, verbose=True)

    # ADM only run secondary functions if --nosecondary was not passed.
    scndout = None
    if not ns.nosecondary and len(targets) > 0:
        from desitarget.secondary import _get_scxdir, match_secondary
        # ADM read secondary target directory.
        scxdir = _get_scxdir(ns.scnddir)
        # ADM construct a label for the secondary file for TARGETID look-ups.
        # NOTE(review): only the single character after the first "dr" in
        # the path is parsed, so a two-digit release such as "dr10" would
        # give 1 -- confirm this is intended before relying on the label.
        try:
            drint = int(ns.sweepdir.split("dr")[1][0])
        except (ValueError, IndexError, AttributeError):
            drint = "X"
        scndoutdn = "priminfo-dr{}-{}".format(drint, desitarget_version)
        scndoutdn = os.path.join(scxdir,
                                 "outdata", desitarget_version, scndoutdn)
        if not os.path.exists(scndoutdn):
            log.info("making directory...{}".format(scndoutdn))
        # exist_ok avoids a FileExistsError when a concurrent job creates the
        # directory between the existence check above and this call.
        os.makedirs(scndoutdn, exist_ok=True)
        # Use "X" as the pixel label when no HEALPixel list was requested.
        hplabel = pixlist if pixlist is not None else "X"
        scndoutfn = io.find_target_files(ns.dest, dr=drint, flavor="targets",
                                         survey="main", hp=hplabel)
        # ADM construct the output directory for primary match info.
        scndoutfn = os.path.basename(scndoutfn)
        scndout = os.path.join(scndoutdn, scndoutfn)
        log.info("writing files of primary matches to...{}".format(scndout))
        targets = match_secondary(targets, scxdir, scndout, sep=1.,
                                  pix=pixlist, nside=ns.nside, swfiles=infn)

    if ns.mask:
        targets = mask_targets(targets, inmaskfile=ns.mask, nside=nside)

    # ADM extra header keywords for the output fits file.
    extra = {
        "tcnames": ns.tcnames,
        "gaiasub": ns.gaiasub,
        "cmdline": ' '.join(sys.argv),  #- just in case...
        }
    extradeps = {
        "bright-subpriorities-override": str(ns.bright_subpriorities),
        "dark-subpriorities-override": str(ns.dark_subpriorities),
        }

    # ADM differentiate the Gaia-only and Legacy Surveys targets.
    _, _, _, _, _, gaiadr = decode_targetid(targets["TARGETID"])
    isgaia = gaiadr > 0
    # ADM write out bright-time and dark-time targets separately,
    # ADM together with the Gaia-only back-up objects
    # ADM and the extension targets.
    # ADM can use DARK for the back-up/supp objects as they never
    # ADM need merged with another target.
    obscons = ["BRIGHT", "DARK", "BACKUP", "BRIGHT1B", "DARK1B"]
    iis = [~isgaia, ~isgaia, isgaia, ~isgaia, ~isgaia]
    supps = [False, False, True, False, False]
    if ns.writeall:
        obscons.append(None)
        iis.append(~isgaia)
        supps.append(False)

    # SB subpriority override logic assumes that io.write_targets will set
    # subpriorities; if they have already been set, crash early so that we
    # can fix the logic
    if np.any(targets['SUBPRIORITY'] > 0.0):
        log.critical('SUBPRIORITY already set; fix override logic below')
        sys.exit(1)

    # Optional override file for each observing condition; conditions that
    # are not listed here never receive an override.
    overridefiles = {
        'BRIGHT': ns.bright_subpriorities,
        'DARK': ns.dark_subpriorities,
        }

    for obscon, ii, supp in zip(obscons, iis, supps):

        # SB apply optional subpriority override; modifies targets in-place
        targets['SUBPRIORITY'] = 0.0  # SB reset any previous overrides
        overridefile = overridefiles.get(obscon)
        if overridefile:
            subpriorities = fitsio.read(overridefile, 'SUBPRIORITY')
            overrideidx = override_subpriority(targets, subpriorities)
            log.info(f'Overriding {len(overrideidx)} {obscon} subpriorities')

        ntargs, outfile = io.write_targets(
            ns.dest, targets[ii], resolve=not(ns.noresolve), nside=nside,
            maskbits=not(ns.nomaskbits), indir=ns.sweepdir, indir2=ns.sweepdir2,
            obscon=obscon, scndout=scndout, survey="main", nsidefile=ns.nside,
            hpxlist=pixlist, supp=supp, qso_selection=ns.qsoselection,
            extra=extra, extradeps=extradeps, nosec=ns.nosecondary,
            infiles=shatab, checkbright=ns.checkbright
        )
        log.info('{} targets written to {}...t={:.1f}s'.format(ntargs, outfile, time()-start))
