#!/usr/bin/env python

import os
import fitsio
from glob import glob
import numpy as np

# ADM record the start time so that log messages can report elapsed time.
from time import time
start = time()

# ADM set up the logger used for all of the progress messages.
from desiutil.log import get_logger
log = get_logger()

from argparse import ArgumentParser
# ADM set up the command-line interface: a wildcard string that selects
# ADM the input files, and the name of the output file.
ap = ArgumentParser(
    description=(
        'Create a FITS files for which each column is the average of '
        'columns in multiple FITS files. Retains the header from the '
        'FIRST input file encountered.'
    )
)
ap.add_argument(
    "infilestring",
    help=(
        "A string with a wildcard to identify multiple files. "
        "Passed to glob. Will likely have to be passed in quotation marks."
    ),
)
ap.add_argument("outfile", help="Output file name")

# ADM parse the command line into a namespace.
ns = ap.parse_args()

# ADM expand the wildcard; sort so the "first" file is deterministic.
fns = sorted(glob(ns.infilestring))

nfns = len(fns)
# ADM fail with a clear message (rather than an IndexError on fns[0])
# ADM if the wildcard didn't match any files.
if not fns:
    msg = 'No files match {}'.format(ns.infilestring)
    log.critical(msg)
    raise IOError(msg)

# ADM the first file seeds the running sums and supplies the header.
objs, hdr = fitsio.read(fns[0], header=True)
log.info('Done with file {}...t = {:.1f}s'.format(fns[0], time()-start))
for fn in fns[1:]:
    obj = fitsio.read(fn)
    # ADM columns are combined row-by-row, so the row counts must agree
    # ADM (a one-row file would otherwise be silently broadcast).
    if len(obj) != len(objs):
        msg = 'File {} has {} rows but file {} has {} rows'.format(
            fn, len(obj), fns[0], len(objs))
        log.critical(msg)
        raise ValueError(msg)
    # ADM sum across columns.
    for col in objs.dtype.names:
        objs[col] += obj[col]
    log.info('Done with file {}...t = {:.1f}s'.format(fn, time()-start))

# ADM divide each column by the number of files to return the mean.
# ADM NOTE: the sums accumulate in, and the means are stored back in,
# ADM each column's own dtype, so integer columns are truncated.
for col in objs.dtype.names:
    objs[col] = objs[col] / nfns

# ADM retain a list of the units, one entry per column of the first
# ADM file; columns that have no TUNIT keyword get an empty string.
tunits = ["TUNIT{}".format(i) for i in range(1, hdr["TFIELDS"]+1)]
units = [hdr[tunit] if tunit in hdr else "" for tunit in tunits]

# ADM EXTNAME is an optional keyword, so fall back to an unnamed
# ADM extension instead of raising a KeyError if it is missing.
extname = hdr["EXTNAME"] if "EXTNAME" in hdr else None

# ADM write atomically: write to a temporary file, then rename it so
# ADM that a partially written file never appears under the final name.
tmpfile = ns.outfile+'.tmp'
fitsio.write(tmpfile, objs, units=units, extname=extname,
             header=hdr, clobber=True)
os.rename(tmpfile, ns.outfile)

log.info('Finished writing to {}...t = {:.1f}s'.format(ns.outfile, time()-start))
