#!/usr/bin/env python

import os, sys
import numpy as np
import fitsio
from time import time
start = time()
import fitsio

from desitarget.randoms import finalize_randoms, add_default_mtl

from desiutil.log import get_logger
log = get_logger()

from argparse import ArgumentParser
# ADM command-line interface: one positional input catalog, plus options
# ADM that set the number of output catalogs, whether MTL columns are
# ADM added, and whether a missing input is tolerated.
ap = ArgumentParser(description='Split a random catalog into N smaller catalogs. Shuffle the random catalog first to ensure randomness.')
ap.add_argument("randomcat",
                help='A random catalog (e.g /project/projectdirs/desi/target/catalogs/randoms-dr4-0.20.0.fits). For an input catalog /X/X.fits N smaller catalogs will be written to /X/X-[0:N-1].fits')
ap.add_argument("-n", "--nchunks", type=int,
                help='Number of smaller catalogs to split the random catalog into. Defaults to [10].',
                default=10)
ap.add_argument("--addmtl", action='store_true',
                help="If passed, then add the columns needed for MTL to the random catalogs after they are split.")
ap.add_argument("--skip", action='store_true',
                help="Check if the input random catalog exists. If it doesn't, do absolutely nothing (useful for scripting across all possible HEALPixels).")

ns = ap.parse_args()

# ADM the number of output catalogs is later used as a divisor and as a
# ADM loop count, so zero would raise a ZeroDivisionError and a negative
# ADM value would silently write nothing. Reject both with a usage error.
if ns.nchunks < 1:
    ap.error("--nchunks must be a positive integer (got {})".format(ns.nchunks))

# ADM nothing can be split without the input catalog. When --skip was
# ADM passed a missing file is expected, so report it and exit cleanly;
# ADM otherwise treat it as a failure and exit with a non-zero status.
if not os.path.exists(ns.randomcat):
    missing = 'Input catalog does not exist: {}'.format(ns.randomcat)
    if ns.skip:
        log.info(missing)
    else:
        log.critical(missing)
    sys.exit(0 if ns.skip else 1)

log.info("Read in randoms from {} and split into {} catalogs...t = {:.1f}s"
         .format(ns.randomcat, ns.nchunks, time()-start))

# ADM only the header is read at this stage; the number of rows in the
# ADM RANDOMS extension is taken from its NAXIS2 card.
hdr = fitsio.read_header(ns.randomcat, "RANDOMS")
nrands = hdr["NAXIS2"]

# ADM every output catalog must receive at least one random. Checking
# ADM here avoids a ZeroDivisionError (zero chunks) and zero-length
# ADM row selections (more chunks than randoms) further down.
if ns.nchunks < 1 or ns.nchunks > nrands:
    log.critical("Cannot split {} randoms into {} catalogs"
                 .format(nrands, ns.nchunks))
    sys.exit(1)

# ADM read in the seed used to make the catalog.
seed = hdr["SEED"]

# ADM shuffle to ensure randomness.
log.info("Read in {:.1e} randoms. Shuffling indexes...t = {:.1f}s"
         .format(nrands, time()-start))
indexes = np.arange(nrands)
# ADM use a fixed seed, recorded in the output header, so that the
# ADM split can be reproduced.
spltseed = 636
np.random.seed(spltseed)
hdr["SPLTSEED"] = spltseed
np.random.shuffle(indexes)

# ADM note whether we requested the MTL columns to be added.
if ns.addmtl:
    hdr["MTLSPLIT"] = True

# ADM write in chunks to save memory.
chunk = nrands//ns.nchunks
# ADM all catalogs have the same size, so when the number of randoms is
# ADM not a multiple of the number of catalogs the remainder (the tail
# ADM of the shuffled indexes) is not written anywhere. Say so.
leftover = nrands - chunk*ns.nchunks
if leftover > 0:
    log.info("Each catalog will contain {} randoms; {} randoms will not be written to any catalog"
             .format(chunk, leftover))
# ADM remember that the density has effectively gone down.
hdr["DENSITY"] //= ns.nchunks

# ADM output files share the root of the input name, so that an input
# ADM /X/X.fits is written to /X/X-0.fits, /X/X-1.fits, etc.
outroot = os.path.splitext(ns.randomcat)[0]

# ADM write out smaller files one-by-one.
for i in range(ns.nchunks):
    # ADM the slice of the shuffled indexes that belongs to this chunk.
    first, last = i*chunk, (i+1)*chunk
    # ADM read in randoms at the relevant indexes. Name the extension
    # ADM explicitly so the rows come from the same HDU as the header
    # ADM (and hence the row count) that was read earlier.
    rands = fitsio.read(ns.randomcat, ext="RANDOMS",
                        rows=indexes[first:last])
    # ADM if requested, add the MTL-relevant columns.
    # NOTE(review): if either helper reseeds NumPy's global generator
    # with `seed`, every chunk would receive the same write permutation
    # below -- confirm whether that matters.
    if ns.addmtl:
        rands = add_default_mtl(finalize_randoms(rands), seed=seed)
    # ADM re-randomize, as fitsio reads in order.
    write_indexes = np.arange(chunk)
    np.random.shuffle(write_indexes)
    # ADM open the file for writing.
    outfile = "{}-{}.fits".format(outroot, i)
    log.info("Writing chunk {} from index {} to {}...t = {:.1f}s"
             .format(i, first, last, time()-start))
    # ADM clobber=True overwrites any existing file of the same name.
    fitsio.write(outfile, rands[write_indexes], extname='RANDOMS', header=hdr,
                 clobber=True)

print("Done...t = {:.1f}s".format(time()-start))
