#!python

import os, sys
import numpy as np
import fitsio
from time import time
start = time()
import fitsio

from desiutil.log import get_logger
log = get_logger()
from desitarget.io import write_in_chunks

from argparse import ArgumentParser
# ADM set up the command-line interface.
ap = ArgumentParser(
    description=(
        'Combine N random and "-outside-" catalogs into "-allsky-" catalogs. Make the data model '
        'resemble the random catalogs with zeros where the "-outside-" catalogs have no data. This can be easily '
        'batched in a slurm file, for example, one line in such a file might be: '
        'srun -N 1 combine_randoms /global/cfs/cdirs/desi/target/catalogs/dr9/0.47.0/randoms/resolve/ -s 1-13 &'
    )
)
ap.add_argument("randir",
                help='Full path to the DIRECTORY containing random files (e.g /global/cfs/cdirs/desi/target/catalogs/dr9/0.47.0/randoms/resolve).')
# ADM the seed string is needed to build every file name, so require
# ADM it rather than silently looking for files named "randoms-None.fits".
ap.add_argument("-s", "--seedstring", required=True,
                help='Part of the filename that corresponds to the unique seed(s) used to run the randoms. The (corresponding) catalogs should be called \
                randoms-{seedstring}.fits and randoms-outside-{seedstring}.fits')
ap.add_argument("-nc", "--nchunks", type=int,
                help='Number of chunks to write the file in to save memory [10].',
                default=10)

ns = ap.parse_args()

# ADM fail fast on a non-positive number of chunks: the chunk size is
# ADM computed by dividing the number of rows by this value, so zero
# ADM would only fail after the (slow) catalog reads have finished.
if ns.nchunks < 1:
    ap.error("--nchunks must be a positive integer (got {})".format(ns.nchunks))

# ADM report which seed string is being processed.
log.info("Working on file {}...t={:.1f}s".format(ns.seedstring, time()-start))

# ADM build the names of the random catalog, the supplemental
# ADM ("outside") catalog and the combined ("allsky") output file.
rcfn, suppfn, combfn = [
    os.path.join(ns.randir, template.format(ns.seedstring))
    for template in ("randoms-{}.fits",
                     "randoms-outside-{}.fits",
                     "randoms-allsky-{}.fits")
]

# ADM read the random catalog together with its header...
tempins, hdr = fitsio.read(rcfn, header=True)
# ADM ...and the supplemental catalog (only the data are kept).
tempouts = fitsio.read(suppfn)
log.info("Read in files {} inside, {} outside...t={:.1f}s"
         .format(len(tempins), len(tempouts), time()-start))

# ADM the common data model for the combined catalog. Columns that
# ADM are absent from the "outside" file stay zero-filled, and columns
# ADM of the random catalog that are not listed here are dropped.
dt = [
    ('RA', '>f8'), ('DEC', '>f8'),
    ('BRICKNAME', '<U8'), ('BRICKID', '>i4'),
    ('NOBS_G', '>i2'), ('NOBS_R', '>i2'), ('NOBS_Z', '>i2'),
    ('PSFDEPTH_G', '>f4'), ('PSFDEPTH_R', '>f4'), ('PSFDEPTH_Z', '>f4'),
    ('GALDEPTH_G', '>f4'), ('GALDEPTH_R', '>f4'), ('GALDEPTH_Z', '>f4'),
    ('PSFDEPTH_W1', '>f4'), ('PSFDEPTH_W2', '>f4'),
    ('PSFSIZE_G', '>f4'), ('PSFSIZE_R', '>f4'), ('PSFSIZE_Z', '>f4'),
    ('APFLUX_G', '>f4'), ('APFLUX_R', '>f4'), ('APFLUX_Z', '>f4'),
    ('APFLUX_IVAR_G', '>f4'), ('APFLUX_IVAR_R', '>f4'),
    ('APFLUX_IVAR_Z', '>f4'),
    ('MASKBITS', '>i2'), ('WISEMASK_W1', 'u1'), ('WISEMASK_W2', 'u1'),
    ('EBV', '>f4'), ('PHOTSYS', '<U1'), ('HPXPIXEL', '>i8'),
]

# ADM allocate zero-filled arrays, one row per input row, in the
# ADM common data model.
inside, outside = (
    np.zeros(len(catalog), dtype=dt) for catalog in (tempins, tempouts)
)

# ADM the two copies deliberately iterate over different column lists:
# ADM every column present in the "outside" file is copied across (the
# ADM remaining columns stay zero)...
for name in tempouts.dtype.names:
    outside[name] = tempouts[name]
# ADM ...whereas only the columns of the common data model are taken
# ADM from the random catalog (any extra columns are discarded).
for name, _ in dt:
    inside[name] = tempins[name]

log.info("Populated zeros...t={:.1f}s".format(time()-start))

# ADM combine the inside and outside randoms...
rands = np.concatenate([outside, inside])
# ADM ...shuffle them. Only an index array is shuffled, so that the
# ADM (large) catalog itself is re-ordered one chunk at a time when it
# ADM is written. The fixed seed makes the row order reproducible.
nrands = len(rands)
indexes = np.arange(nrands)
np.random.seed(646)
np.random.shuffle(indexes)

log.info("Writing out file...t={:.1f}s".format(time()-start))
# ADM ...and write them out.
# ADM guard against a non-positive number of chunks, which would
# ADM otherwise divide by zero or never create the output extension.
nchunks = max(ns.nchunks, 1)
# ADM the number of rows in each full chunk.
chunk = nrands//nchunks
# ADM if there are fewer rows than chunks there are no full chunks.
nfull = nchunks if chunk > 0 else 0
# ADM the [lo, hi) range of positions in the shuffled index array that
# ADM each write covers; empty ranges are never generated...
bounds = [(i*chunk, (i+1)*chunk) for i in range(nfull)]
# ADM ...except that a catalog with no rows at all still gets a single
# ADM (empty) write so that the extension and header are created.
if nfull*chunk < nrands or nfull == 0:
    bounds.append((nfull*chunk, nrands))

# ADM the context manager closes the file even if a write fails.
with fitsio.FITS(combfn, 'rw', clobber=True) as outy:
    # ADM write the chunks one-by-one.
    for i, (lo, hi) in enumerate(bounds):
        if i < nfull:
            log.info("Writing chunk {}/{} from index {} to {}...t = {:.1f}s"
                     .format(i+1, nchunks, lo, hi-1, time()-start))
        else:
            log.info("Writing final partial chunk from index {} to {}...t = {:.1f}s"
                     .format(lo, hi-1, time()-start))
        ichunk = indexes[lo:hi]
        # ADM if this is the first chunk, write the data and header...
        if i == 0:
            outy.write(rands[ichunk], extname='RANDOMS', header=hdr)
        # ADM ...otherwise just append to the existing extension.
        else:
            outy[-1].append(rands[ichunk])

log.info("Done...t={:.1f}s".format(time()-start))
