#!/usr/bin/env python

import os, sys
import numpy as np
import numpy.lib.recfunctions as rfn
import fitsio

# Column layout (name, dtype) of the records written to the "QSO prob"
# files. The array is deliberately zero-length: the script below only
# reads its dtype as a template for the per-file output arrays.
datamodel = np.empty(
    0,
    dtype=np.dtype([
        ('RELEASE', '>i2'),
        ('BRICKID', '>i4'),
        ('BRICKNAME', 'S8'),
        ('BRICK_OBJID', '>i4'),
        ('RA', '>f8'),
        ('DEC', '>f8'),
        ('PHOTSYS', '<U1'),
        ('PQSO', '>f4'),
        ('PQSO_HIZ', '>f4'),
    ]),
)

from desiutil import depend
from desitarget import io
from desitarget.cuts import _prepare_optical_wise, isQSO_randomforest, _isonnorthphotsys
from desitarget.targets import main_cmx_or_sv

# Record the wall-clock start time; the log messages below report the
# elapsed time as time()-start.
from time import time
start = time()

# Logger used for the progress and error messages emitted below.
from desiutil.log import get_logger
log = get_logger()

from argparse import ArgumentParser

# Command-line interface: two required positionals, the input location
# ("src") and the directory that receives the output files ("dest").
# The description used to stop mid-sentence ("...and write "); it now
# states what the script below actually writes for each input file.
ap = ArgumentParser(
    description='Take a set of target or sweep files and write '
                'a "-pqso.fits" file of QSO random forest probabilities '
                '(PQSO, PQSO_HIZ) for each input file')
ap.add_argument("src",
                help="Sweeps or target file or root directory with sweeps or target files")
ap.add_argument("dest",
                help="Output directory")

# Parsed arguments; read below as ns.src and ns.dest.
ns = ap.parse_args()

# Look for sweep files under the source first; if there are none, fall
# back to target files. `extname` records which kind was found and is
# the FITS extension read from each file later on.
extname = "SWEEP"
infiles = io.list_sweepfiles(ns.src)
if len(infiles) == 0:
    extname = "TARGETS"
    infiles = io.list_targetfiles(ns.src)
    if len(infiles) == 0:
        # Neither kind of file was found: nothing to do.
        log.critical('no sweep or target files found')
        sys.exit(1)

# A single file name is wrapped in a list so it can be iterated over.
if isinstance(infiles, str):
    infiles = [infiles]

# Main processing loop. For every input file: read it, build an output
# array following `datamodel`, fill PQSO / PQSO_HIZ from the random
# forests (run separately for northern and southern photometry), and
# write a "<name>-pqso.fits" file into the output directory.

# Create the output directory up front so that a missing directory does
# not surface only at write time, after the random forests have run.
if ns.dest:
    os.makedirs(ns.dest, exist_ok=True)

for fn in infiles:
    log.info("Working on file {}...t={:.1f}s".format(fn, time()-start))
    # ADM read the file.
    targs, hdr = fitsio.read(fn, header=True, ext=extname)

    # ADM the input targets may have had some column names changed.
    targs = rfn.rename_fields(targs, {'MORPHTYPE': 'TYPE'})

    # ADM Add DESI_TARGET to the data model for target files.
    dt = datamodel.dtype
    if extname == "TARGETS":
        # The first name returned by main_cmx_or_sv is appended to the
        # data model as an extra 64-bit integer column.
        tcols, _, _ = main_cmx_or_sv(targs)
        dt = datamodel.dtype.descr + [(tcols[0], '>i8')]
    # ADM will need to add PHOTSYS column to sweeps files.
    else:
        targs = io.add_photsys(targs)

    # ADM populate the data model.
    # Columns absent from the input are left at their zero default.
    data = np.zeros(len(targs), dtype=dt)
    for col in data.dtype.names:
        if col in targs.dtype.names:
            data[col] = targs[col]

    # ADM run targets split by north/south, which have different RFs.
    norths = _isonnorthphotsys(targs["PHOTSYS"])

    for ii, south in zip([norths, ~norths], [False, True]):
        log.info("south is {}; running on {} targets...t={:.1f}s".format(
            south, np.sum(ii), time()-start))

        # Nothing to classify in this hemisphere: skip the RF call. The
        # PQSO columns are only written where ii is True, so the output
        # is the same as if the call had been made.
        if not np.any(ii):
            continue

        # ADM grab the necessary information for running the RF.
        # The full tuple is unpacked to document the return order; only
        # some of these values are passed on to the RF below.
        (photsys_north, photsys_south, obs_rflux, gflux, rflux, zflux,
         w1flux, w2flux, gfiberflux, rfiberflux, zfiberflux,
         objtype, release, gfluxivar, rfluxivar, zfluxivar,
         gnobs, rnobs, znobs, gfracflux, rfracflux, zfracflux,
         gfracmasked, rfracmasked, zfracmasked,
         gfracin, rfracin, zfracin, gallmask, rallmask, zallmask,
         gsnr, rsnr, zsnr, w1snr, w2snr,
         dchisq, deltaChi2, maskbits, refcat) = _prepare_optical_wise(
            targs[ii], mask=True)

        # ADM run the RFs.
        _, _, pqso, pqsohiz = isQSO_randomforest(
            primary=None, zflux=zflux, rflux=rflux, gflux=gflux,
            w1flux=w1flux, w2flux=w2flux, deltaChi2=deltaChi2,
            maskbits=maskbits, gnobs=gnobs, rnobs=rnobs, znobs=znobs,
            objtype=objtype, release=release, south=south,
            return_probs=True)

        # ADM add the RFs to the output data.
        data["PQSO"][ii] = pqso
        data["PQSO_HIZ"][ii] = pqsohiz

    # ADM write the file.
    # Record the code version in the header that is written with the data.
    depend.setdep(hdr, 'desitarget-pqso-git', io.gitversion())
    outfn = os.path.join(ns.dest,
                         os.path.basename(fn).replace(".fits", '-pqso.fits'))

    log.info("Writing to {}...t={:.1f}s".format(outfn, time()-start))
    io.write_with_units(outfn, data, extname="PQSO", header=hdr)

log.info("Done...t={:.1f}s".format(time()-start))
