#!python
#
# Script to run Prometheus on an image

from __future__ import print_function

import os
import sys
import time
import logging
import subprocess
import traceback
from argparse import ArgumentParser
from datetime import datetime

import numpy as np
from astropy.io import fits
from astropy.table import Table
from dlnpyutils import utils as dln

import prometheus
from prometheus import prometheus as pm,utils,models

try:
    import __builtin__ as builtins # Python 2
except ImportError:
    import builtins # Python 3

# Main command-line program
if __name__ == "__main__":
    # Command-line driver: parse the options, optionally expand a list file
    # into image filenames, run Prometheus on every image and write one
    # multi-extension FITS file (catalog, model, sky, PSF) per image.

    # ---- Command-line interface ----
    # NOTE: options declared with nargs=1 come back as a one-element list when
    # given on the command line but as the bare default otherwise.
    # dln.first_el() is used below to unwrap them (assumed to return the single
    # element of a one-item list and to pass scalars through unchanged).
    parser = ArgumentParser(description='Run Prometheus on an image')
    parser.add_argument('files', type=str, nargs='+', help='Images FITS files or list.')
    parser.add_argument('-p','--psf', type=str, nargs=1, default='gaussian', help='PSF model type.')
    parser.add_argument('--iterdet', type=int, nargs=1, default=0, help='Number of detection iterations.')
    parser.add_argument('--ndetsigma', type=float, nargs=1, default=1.5, help='Number of sigma above background for detection.')
    parser.add_argument('--snrthresh', type=float, nargs=1, default=5, help='Minimum S/N for sources')
    parser.add_argument('--psfsubnei', action='store_true', help='Subtract PSF neighbors.')
    # The radii are parsed as floats so that a number (not a string) is passed on.
    parser.add_argument('--psffitradius', type=float, nargs=1, default=None, help='Fitting radius when constructing PSF.  Default is PSF FWHM.')
    parser.add_argument('--fitradius', type=float, nargs=1, default=None, help='Fitting radius of stars in image.  Default is PSF FWHM.')
    parser.add_argument('--npsfpix', type=int, nargs=1, default=51, help='Size of PSF footprint.')
    parser.add_argument('--binned', action='store_true', help='Binned analytical model.')
    parser.add_argument('--lookup', action='store_true', help='Use an empirical lookup table.')
    parser.add_argument('--lorder', type=int, nargs=1, default=0, help='Order of empirical/lookup spatial variations.')
    parser.add_argument('--psftrim', type=str, nargs=1, default='', help='Trim PSF size to radius where psftrim fraction of flux is removed.')
    parser.add_argument('--norecenter', action='store_true', help='Do not fit the x/y centroid positions when PSF fitting')
    parser.add_argument('-r','--reject', action='store_true', help='Reject high RMS PSF stars.')
    parser.add_argument('--apcorr', action='store_true', help='Apply aperture correction.')
    parser.add_argument('--outfile', type=str, nargs=1, default='', help='Output filename')
    parser.add_argument('-d','--outdir', type=str, nargs=1, default='', help='Output directory')
    parser.add_argument('-l','--list', action='store_true', help='Input is a list of FITS files')
    parser.add_argument('-v','--verbose', type=int, nargs='?', const=1, default=0, help='Verbose output')
    parser.add_argument('-t','--timestamp', action='store_true', help='Add timestamp to Verbose output')
    args = parser.parse_args()

    # ---- Unpack the parsed arguments into plain scalars ----
    t0 = time.time()
    files = args.files
    psftype = dln.first_el(args.psf)
    iterdet = dln.first_el(args.iterdet)
    ndetsigma = dln.first_el(args.ndetsigma)
    snrthresh = dln.first_el(args.snrthresh)
    psfsubnei = args.psfsubnei
    # nargs=1 guarantees a one-element list whenever the option was supplied
    psffitradius = args.psffitradius[0] if args.psffitradius is not None else None
    fitradius = args.fitradius[0] if args.fitradius is not None else None
    npsfpix = dln.first_el(args.npsfpix)
    binned = args.binned
    lookup = args.lookup
    lorder = dln.first_el(args.lorder)
    # An empty string means "no trimming"
    psftrim = dln.first_el(args.psftrim)
    psftrim = None if psftrim == '' else float(psftrim)
    norecenter = args.norecenter
    reject = args.reject
    apcorr = args.apcorr
    inpoutfile = dln.first_el(args.outfile)
    outdir = dln.first_el(args.outdir)
    if outdir == '':
        outdir = None
    elif not os.path.exists(outdir):
        # makedirs also creates missing parent directories
        os.makedirs(outdir)
    inlist = args.list      # store_true already yields a plain bool
    verbose = args.verbose
    timestamp = args.timestamp

    # Check PSF type
    if psftype not in models._models.keys():
        raise ValueError('PSF type '+str(psftype)+' not supported.  Select '+', '.join(models._models.keys()))

    # Timestamp requested, set up logger
    if timestamp and verbose:
        logger = dln.basiclogger()
        logger.handlers[0].setFormatter(logging.Formatter("%(asctime)s [%(levelname)-5.5s]  %(message)s"))
        logger.handlers[0].setStream(sys.stdout)
        builtins.logger = logger   # make it available globally across all modules

    # ---- Load files from a list ----
    if inlist:
        listfile = files[0]
        # Check that the list file exists
        if not os.path.exists(listfile):
            raise ValueError(listfile+' NOT FOUND')
        files = dln.readlines(listfile)
        # Filenames without a directory part are taken to be relative to the
        # directory that holds the list file.
        listdir = os.path.dirname(listfile)
        if listdir != '':
            files = [listdir+'/'+name if os.path.dirname(name) == '' else name
                     for name in files]
    nfiles = len(files)

    if verbose > 0:
        if nfiles > 1:
            print('--- Running Prometheus on %d images ---' % nfiles)
        elif nfiles == 1:
            print('--- Running Prometheus on %s ---' % files[0])

    # ---- Loop over the files ----
    for i,f in enumerate(files):
        # Check that the file exists
        if not os.path.exists(f):
            print(f+' NOT FOUND')
            continue

        try:
            # Load the image
            image = prometheus.read(f)

            if verbose > 0 and nfiles > 1:
                if i > 0:
                    print('')
                print('Image %3d:  %s  ' % (i+1,f))

            # Run Prometheus.  "recenter" must be a real boolean negation:
            # bitwise ~ on a bool gives -1/-2, which are both truthy.
            out, model, sky, psf = pm.run(image,psfname=psftype,iterdet=iterdet,ndetsigma=ndetsigma,
                                          snrthresh=snrthresh,psfsubnei=psfsubnei,psffitradius=psffitradius,
                                          fitradius=fitradius,npsfpix=npsfpix,binned=binned,lookup=lookup,
                                          lorder=lorder,psftrim=psftrim,recenter=(not norecenter),reject=reject,
                                          apcorr=apcorr,verbose=verbose)

            # Work out the output filename.
            # NOTE(review): an explicit --outfile is reused for every input
            # image, so with several images only the last result survives.
            if inpoutfile != '':
                outfile = inpoutfile
            else:
                fdir,base,ext = utils.splitfilename(f)
                outfile = base+'_prometheus.fits'
                if outdir is not None:
                    outfile = os.path.join(outdir,outfile)
                elif fdir != '':
                    # No output directory given: write next to the input image
                    outfile = os.path.join(fdir,outfile)
            if verbose > 0:
                print('Writing output to '+outfile)
            if os.path.exists(outfile):
                os.remove(outfile)

            # Build the output file.  Appending a table to an empty HDUList
            # makes astropy insert a default primary HDU first, so the table
            # ends up at index 1 and the primary header at index 0.
            hdulist = fits.HDUList()
            hdulist.append(fits.table_to_hdu(out))  # table
            hdulist[1].header['EXTNAME'] = 'SOURCE TABLE'
            hdulist[0].header['COMMENT']='Prometheus version '+str(prometheus.__version__)
            hdulist[0].header['COMMENT']='Date '+datetime.now().ctime()
            hdulist[0].header['COMMENT']='File '+f
            hdulist[0].header['COMMENT']='HDU#0 : Header Only'
            hdulist[0].header['COMMENT']='HDU#1 : Source catalog'
            hdulist[0].header['COMMENT']='HDU#2 : Model image'
            hdulist[0].header['COMMENT']='HDU#3 : Sky model image'
            hdulist[0].header['COMMENT']='HDU#4 : PSF model'
            hdulist.append(model.tohdu())  # model
            hdulist[2].header['EXTNAME'] = 'MODEL IMAGE'
            hdulist[2].header['COMMENT'] = 'Prometheus model image'
            hdulist.append(sky.tohdu())    # sky
            hdulist[3].header['EXTNAME'] = 'SKY MODEL IMAGE'
            hdulist[3].header['COMMENT'] = 'Prometheus sky image'
            hdulist += psf.tohdu()    # psf, returns a list
            hdulist.writeto(outfile,overwrite=True)
            hdulist.close()

        except Exception as e:
            # Keep going with the remaining images, but never fail silently:
            # the one-line message is always shown, the traceback only when
            # verbose output was requested.
            if verbose > 0:
                traceback.print_exc()
            print('Prometheus failed on '+f+' '+str(e))



