#!/usr/bin/env python

import fitsio
import numpy as np
from desitarget.secondary import select_secondary, _get_scxdir
from desitarget.brightmask import is_in_bright_mask, get_recent_mask_dir
from desitarget import io
from desitarget.subpriority import override_subpriority
import os
from glob import glob
from time import time
# Record the script start time; the final log messages report the
# elapsed wall-clock time in minutes relative to this value.
time0 = time()

from desiutil.log import get_logger
# Script-wide logger used for all status and error messages.
log = get_logger()

from argparse import ArgumentParser

# Command-line interface. Two positional arguments (the directory of
# previously matched primary/secondary targets and the output directory)
# plus optional switches controlling matching radius, the location of the
# secondary files, bright-star masking and subpriority overrides.
ap = ArgumentParser(description='Generate file of secondary-only targets from $SCND_DIR, write matches to primary targets back to $SCND_DIR')
ap.add_argument("priminfodir",
                help="Location of the directory that has previously matched primary and secondary targets to recover the unique primary TARGETIDs " +
                "(as made by, e.g., select_targets without the --nosecondary option set).")
# The closing parenthesis was missing from this help text.
ap.add_argument("dest",
                help="Output secondary-only targets directory (the file names are built on-the-fly from other inputs).")
ap.add_argument("-s", "--separation", type=float, default=1.,
                help='Angular separation at which to match secondary targets to themselves. Defaults to [1] arcsec.')
ap.add_argument("--scnddir",
                help="Base directory of secondary target files (e.g. '/project/projectdirs/desi/target/secondary' at NERSC). " +
                "Defaults to SCND_DIR environment variable.")
ap.add_argument("--writeall",
                action='store_true',
                help="Default behavior is to split targets by bright/dark-time surveys. Set this to ALSO write a file of ALL targets")
ap.add_argument("--nomasking", action='store_true',
                help="Masking occurs by default. If this is set, do NOT use a bright star mask to mask the secondary targets")
ap.add_argument("--maskdir",
                help="Name of the specific directory (or file) containing the bright star mask (defaults to the most recent directory in $MASK_DIR)",
                default=None)
ap.add_argument("--dark-subpriorities", type=str,
                help='Optional file with dark TARGETID:SUBPRIORITY overrides')
ap.add_argument("--bright-subpriorities", type=str,
                help='Optional file with bright TARGETID:SUBPRIORITY overrides')

ns = ap.parse_args()
# Masking is on by default; --nomasking switches it off.
do_mask = not ns.nomasking

# ADM Sanity check that priminfodir exists.
if not os.path.exists(ns.priminfodir):
    msg = "{} doesn't exist".format(ns.priminfodir)
    log.critical(msg)
    raise ValueError(msg)

# ADM read survey type from the header of the first file in priminfodir.
fns = sorted(glob(os.path.join(ns.priminfodir, "*fits")))
# ADM fail with a clear message (rather than an IndexError on fns[0])
# ADM if the directory exists but holds no FITS files.
if not fns:
    msg = "no *fits files found in {}".format(ns.priminfodir)
    log.critical(msg)
    raise ValueError(msg)
hdr = fitsio.read_header(fns[0], 'SCND_TARG')
# ADM strip trailing padding from the header value.
surv = hdr["SURVEY"].rstrip()

# ADM find the SCND_DIR environment variable, if it wasn't passed.
scxdir = _get_scxdir(ns.scnddir, survey=surv)
# ADM and augment the scxdir if this is an SV set of primary files.
if surv != 'main':
    scxdir = os.path.join(scxdir, surv)

# ADM match the secondary targets against the primaries; bright/dark
# ADM splitting is requested unless --writeall was passed.
scx = select_secondary(ns.priminfodir, sep=ns.separation, scxdir=scxdir, darkbright=(not ns.writeall))

# ADM record provenance in the header taken from the first primary
# ADM file: the primary directory and the matching radius that was used.
hdr['PRIMDIR'] = ns.priminfodir
hdr['SEP'] = float(ns.separation)

# ADM attempt to parse the Data Release out of the priminfodir name: the
# ADM text following the first 'dr', up to the next '-'. Fall back to
# ADM the placeholder "X" if the name doesn't have that form.
try:
    after_dr = ns.priminfodir.split('dr')[1]
    drint = after_dr.split('-')[0]
except (ValueError, IndexError, AttributeError):
    drint = "X"

# ADM if secondaries need masked, grab the mask directory and find which
# ADM targets are masked. Also add the masking information to the header.
# ADM mdcomp stays None when no masking is performed.
mdcomp = None
if do_mask:
    # ADM grab the mask directory
    maskdir = get_recent_mask_dir(ns.maskdir)
    log.info("Masking secondaries using bright star masks in {}".format(maskdir))
    # ADM read in the masks across the entire sky.
    Mx = io.read_targets_in_quick(maskdir, shape='box',
                                  radecbox=[0.0, 360.0, -90.0, 90.0])
    log.info("Checking {} targets against {} masks".format(len(scx), len(Mx)))
    in_mask, _ = is_in_bright_mask(scx, Mx, inonly=True)
    # ADM the output in_mask is a list. We want the array it contains.
    in_mask = in_mask[0]
    log.info("{} targets are in bright star masks".format(np.sum(in_mask)))
    # ADM a compact version of the maskdir name: everything from the
    # ADM "masks" path component onward. If the path has no such
    # ADM component (e.g. a custom --maskdir), record the full path
    # ADM instead of crashing with a ValueError after the masking work.
    md = maskdir.split("/")
    if "masks" in md:
        mdcomp = "/".join(md[md.index("masks"):])
    else:
        mdcomp = maskdir
hdr["masked"] = do_mask
hdr["maskdir"] = mdcomp

# SB subpriority override logic assumes that io.write_targets will set
# subpriorities; if they have already been set, crash early so that we
# can fix the logic
if np.any(scx['SUBPRIORITY'] > 0.0):
    log.critical('SUBPRIORITY already set; fix override logic below')
    # The sys module is not imported by this script, so sys.exit(1)
    # would raise NameError; raising SystemExit directly is equivalent.
    raise SystemExit(1)

# ADM write out the secondary targets, with bright-time
# ADM and dark-time targets written separately.
obscons = ["BRIGHT", "DARK"]

# ADM --writeall adds an extra pass (obscon=None) for a file of ALL targets.
if ns.writeall:
    obscons.append(None)

# SB optional per-survey files of TARGETID:SUBPRIORITY overrides; an
# SB entry is None when the corresponding option wasn't passed.
override_files = {"BRIGHT": ns.bright_subpriorities,
                  "DARK": ns.dark_subpriorities}

for obscon in obscons:

    # SB apply optional subpriority override; modifies targets in-place
    scx['SUBPRIORITY'] = 0.0  # SB reset any previous overrides
    override_fn = override_files.get(obscon)
    if override_fn:
        subpriorities = fitsio.read(override_fn, 'SUBPRIORITY')
        overrideidx = override_subpriority(scx, subpriorities)
        log.info(f'Overriding {len(overrideidx)} {obscon} subpriorities')

    # ADM only drop masked targets if masking was actually performed:
    # ADM in_mask does not exist when --nomasking is set, so indexing
    # ADM with it unconditionally raised a NameError. Keying on do_mask
    # ADM also keeps every written file consistent with the "masked"
    # ADM keyword in the header.
    if do_mask:
        outtargs = scx[~in_mask]
    else:
        outtargs = scx

    # ADM pass a dict() copy of hdr so that each write starts from the
    # ADM same header, whatever write_secondary does with its copy.
    ntargs, outfile = io.write_secondary(ns.dest, outtargs,
                                         primhdr=dict(hdr), scxdir=scxdir,
                                         obscon=obscon, drint=drint)

    log.info('{} standalone secondary targets written to {}...t={:.1f}mins'
             .format(ntargs, outfile, (time()-time0)/60.))
