#!/usr/bin/env python

from __future__ import print_function, division

import os
import fitsio
from fitsio import FITS
import numpy as np

from time import time
# ADM reference time for the elapsed-time ("t = ...s") values that
# ADM are appended to the log messages throughout this script.
start = time()

#import warnings
#warnings.simplefilter('error')

from desiutil.log import get_logger
# ADM the logger through which all progress messages are emitted.
log = get_logger()

# ADM the default number of chunks to write in to save memory.
nchunks = 10

from argparse import ArgumentParser, ArgumentTypeError


def _positive_int(value):
    """Convert a command-line string to an integer that is at least 1.

    Parameters
    ----------
    value : :class:`str`
        The raw string passed on the command line.

    Returns
    -------
    :class:`int`
        The converted value.

    Raises
    ------
    argparse.ArgumentTypeError
        If `value` is not an integer or is smaller than 1. argparse
        turns this into a usage error for the offending option.
    """
    try:
        number = int(value)
    except ValueError:
        raise ArgumentTypeError("invalid int value: {!r}".format(value))
    if number < 1:
        raise ArgumentTypeError(
            "must be a positive integer (got {})".format(number))
    return number


ap = ArgumentParser(description='Concatenate multiple FITS files into one large file. Retains the header from the FIRST listed input file')
ap.add_argument("infiles", 
                help="SEMI-COLON separated list of input files, which may have to be enclosed by quotes (e.g. 'file1;file2;file3;file4')")
ap.add_argument("outfile", 
                help="Output file name")
ap.add_argument("targtype", choices=['skies', 'randoms', 'targets', 'gfas'],
                help="Type of target run with parallelization/multiprocessing code to gather")
ap.add_argument("--norandomize", action='store_true',
                help="Do NOT randomly shuffle the output file before writing it (using seed 626)")
# ADM the number of chunks is used as a divisor when the output is
# ADM written, so zero or negative values are rejected up front.
ap.add_argument("--numchunks", type=_positive_int,
                help='number of chunks in which to write to save memory [defaults to {}]'.format(nchunks),
                default=nchunks)
ap.add_argument("--skip", action='store_true',
                help="Skip input files that don't exist without flagging an error")
ap.add_argument("--columns",
                help='Limit the output file to just this set of columns (comma-separated without spaces, e.g., "RA,DEC"). \
                Pass the string "pixweight" to limit to just columns needed for making the pixweight files',
                default=None)

ns = ap.parse_args()

# ADM the extension name is the upper-cased target type, except for
# ADM the gfas, which are actually called 'gfa_targets'.
tt = "GFA_TARGETS" if ns.targtype == 'gfas' else ns.targtype
extname = tt.upper()

# ADM work out which columns to read: the fixed pixweight set, an
# ADM explicit comma-separated list, or (None) every column.
if ns.columns == "pixweight":
    obsval = ['PSFDEPTH', 'NOBS', 'GALDEPTH', 'PSFSIZE']
    columns = ['RA', 'DEC', 'EBV', 'MASKBITS', 'PSFDEPTH_W1', 'PSFDEPTH_W2']
    # ADM add one column per quantity and optical band, ordered by
    # ADM quantity first and then by band.
    for val in obsval:
        for band in 'GRZ':
            columns.append('{}_{}'.format(val, band))
elif ns.columns is not None:
    columns = ns.columns.split(',')
else:
    columns = None

# ADM convert the passed semi-colon separated string to a list of
# ADM file names. Surrounding whitespace is removed and empty entries
# ADM (e.g. from a trailing ';') are ignored.
fns = [fn.strip() for fn in ns.infiles.split(';') if fn.strip()]

if ns.skip:
    # ADM drop files that don't exist. This isn't an error when --skip
    # ADM is passed, but note each one in the log so the omission can
    # ADM be traced later.
    present = []
    for fn in fns:
        if os.path.isfile(os.path.expandvars(fn)):
            present.append(fn)
        else:
            log.info('Skipping non-existent input file {}'.format(fn))
    fns = present

# ADM with nothing left to gather, stop now with a clear usage error
# ADM rather than an IndexError when the first file is accessed.
if len(fns) == 0:
    ap.error('no input files to gather (infiles was "{}")'.format(ns.infiles))

# ADM read the header from the first file. The context manager closes
# ADM the file handle as soon as the header has been read.
with fitsio.FITS(fns[0]) as fx:
    hdr = fx[1].read_header()

# ADM if combining files, combine the HEALPixel coverage. If the
# ADM HEALPixel coverage is the same in each file, increase the density.
israndoms = ns.targtype == 'randoms'
# ADM read each file's header once and take every needed key from it.
hdr0 = fitsio.read_header(fns[0], extname)
hpx0 = hdr0["FILEHPX"]
hpx = [hpx0]
if israndoms:
    dens = hdr0["DENSITY"]
for fn in fns[1:]:
    hdrfn = fitsio.read_header(fn, extname)
    hpxfn = hdrfn["FILEHPX"]
    # ADM coverage is only ever compared to that of the FIRST file.
    if hpxfn != hpx0:
        hpx.append(hpxfn)
    elif israndoms:
        dens += hdrfn["DENSITY"]
# ADM str.join only accepts strings, so convert each entry in case a
# ADM FILEHPX value was read back from a header as a number.
hdr["FILEHPX"] = ",".join(str(pix) for pix in hpx)
if israndoms:
    hdr["DENSITY"] = dens

# ADM retain a list of the units: one entry per TFIELDS column, with
# ADM an empty string for any column that has no TUNIT keyword.
# ADM NOTE(review): units is not referenced again in this script, so
# ADM confirm whether it was meant to be passed to the writer.
hdrkeys = hdr.keys()
tunits = []
units = []
for colnum in range(1, hdr["TFIELDS"]+1):
    tunit = "TUNIT{}".format(colnum)
    tunits.append(tunit)
    if tunit in hdrkeys:
        units.append(hdr[tunit])
    else:
        units.append("")

beginmsg = 'Begin writing {} to {}...t = {:.1f}s'
log.info(beginmsg.format(ns.targtype, ns.outfile, time()-start))

# ADM read each input file in turn (restricted to the requested
# ADM columns, if any) and stack the resulting tables in file order.
tables = []
for fn in fns:
    elapsed = time()-start
    log.info('Reading file {}...t = {:.1f}s'.format(fn, elapsed))
    tables.append(fitsio.read(fn, columns=columns))
data = np.concatenate(tables)
# ADM drop the per-file tables so only the combined array is held.
del tables
ndata = len(data)

# ADM the order in which rows will be written: initially file order...
indexes = np.arange(ndata)
# ADM ...then shuffled to ensure randomness, unless told otherwise.
if not ns.norandomize:
    shufflemsg = "Read in {:.1e} objects. Shuffling indexes...t = {:.1f}s"
    log.info(shufflemsg.format(ndata, time()-start))
    # ADM the seed for the "final" shuffle, which is also recorded in
    # ADM the output header.
    reseed = 626
    hdr["RESEED"] = reseed
    np.random.seed(reseed)
    np.random.shuffle(indexes)

# ADM write in chunks to save memory.
# ADM guard against a non-positive number of chunks, which would
# ADM otherwise divide by zero (0) or write no rows at all (negative).
numchunks = ns.numchunks
if numchunks < 1:
    log.warning("numchunks must be at least 1 (got {}); writing in 1 chunk"
                .format(numchunks))
    numchunks = 1

# ADM rows per chunk. Never zero, so no empty chunks are written when
# ADM there are fewer rows than requested chunks.
chunk = max(1, ndata//numchunks)
# ADM the first row of each chunk; the last chunk holds any remainder.
# ADM An empty input still gets a single (empty) write so that the
# ADM output extension and its header are created.
firsts = list(range(0, ndata, chunk)) or [0]

tmpfile = ns.outfile+".tmp"
# ADM the context manager closes the file even if a write fails.
with fitsio.FITS(tmpfile, 'rw', clobber=True) as outy:
    for i, first in enumerate(firsts):
        last = min(first+chunk, ndata)
        log.info("Writing chunk {}/{} from index {} to {}...t = {:.1f}s"
                 .format(i+1, len(firsts), first, last-1, time()-start))
        datachunk = data[indexes[first:last]]
        # ADM if this is the first chunk, write the data and header...
        if i == 0:
            outy.write(datachunk, extname=extname, header=hdr)
        # ADM ...otherwise just append to the existing file object.
        else:
            outy[-1].append(datachunk)

# ADM only move the file to its final name once it is fully written.
os.rename(tmpfile, ns.outfile)

log.info('Finished writing to {}...t = {:.1f}s'.format(ns.outfile, time()-start))
