#!/usr/bin/env python
#
# See top-level LICENSE file for Copyright information
#
# -*- coding: utf-8 -*-

"""
This script is opening a PSF file and dumping the psf values at +-1,0 0,+-1 and +-1,+-1 pixels offsets at several location in the ccds to determine the parameters used in the cosmic ray rejection algorithm.
"""

from specter.psf import GaussHermitePSF,SpotGridPSF
from desiutil.log import get_logger
import astropy.io.fits as pyfits
import argparse
import numpy as np
import pylab



def _load_psf(filename):
    """
    Build the specter PSF object matching the PSFTYPE keyword of a PSF file.

    Args:
        filename : path of a PSF fits file whose primary header has a PSFTYPE keyword

    Returns:
        a GaussHermitePSF or SpotGridPSF instance, or None when PSFTYPE is
        neither "GAUSS-HERMITE" nor "SPOTGRID"
    """
    # Context manager so the file is closed even if the PSFTYPE keyword is missing.
    with pyfits.open(filename) as hdulist:
        psftype = hdulist[0].header["PSFTYPE"]

    if psftype == "GAUSS-HERMITE":
        return GaussHermitePSF(filename)
    if psftype == "SPOTGRID":
        return SpotGridPSF(filename)
    return None


def _neighbor_pixel_ratios(psf, sharpest, fiber_step=10, nwave=20, wave_margin=100.):
    """
    Measure the PSF value in the pixels next to the peak, relative to the peak.

    The PSF image returned by psf.xypix is sampled for one fiber out of every
    `fiber_step` and for `nwave` wavelengths evenly spaced between
    psf.wmin+wave_margin and psf.wmax-wave_margin. Each image is normalized by
    its maximum pixel, and the neighbors of that pixel are read along 4 axes
    (index offsets [0,1], [1,0], [1,1], [1,-1]), on both sides of the peak.
    Neighbors outside the image or with a non-positive value are ignored.

    Args:
        psf : PSF object providing nspec, wmin, wmax and xypix(fiber, wave)
        sharpest : if True keep the smallest ratio per axis, otherwise the mean
        fiber_step : spacing between sampled fibers
        nwave : number of sampled wavelengths
        wave_margin : margin removed at both ends of the wavelength range

    Returns:
        array of 4 floats, one per axis; NaN for an axis without any valid sample
    """
    offsets = np.array([[0, 1], [1, 0], [1, 1], [1, -1]])
    naxis = len(offsets)

    # In "sharpest" mode start from +inf so that the running minimum can go
    # down; starting from 0 would always win against positive pixel values.
    if sharpest:
        psfparams = np.full(naxis, np.inf)
    else:
        psfparams = np.zeros(naxis)
    counts = np.zeros(naxis, dtype=int)

    wavelengths = np.linspace(psf.wmin + wave_margin, psf.wmax - wave_margin, nwave)

    for group in range(psf.nspec // fiber_step):
        # fiber in the middle of each group of fiber_step fibers
        fiber = int((group + 0.5) * fiber_step)
        for wave in wavelengths:
            # only the pixel image is used; the first two return values are ignored
            _, _, pix = psf.xypix(fiber, wave)
            if pix.size == 0:
                continue

            i0max, i1max = np.unravel_index(np.argmax(pix), pix.shape)
            peak = pix[i0max, i1max]
            if not peak > 0:
                # cannot normalize by a null, negative or NaN peak
                continue

            # Normalize into a new array to leave the array returned by the PSF untouched.
            pix = pix / peak
            n0, n1 = pix.shape

            for axis, (d0, d1) in enumerate(offsets):
                for sign in (-1, 1):
                    i0 = int(i0max + sign * d0)
                    i1 = int(i1max + sign * d1)
                    if not (0 <= i0 < n0 and 0 <= i1 < n1):
                        continue
                    value = pix[i0, i1]
                    if value > 0:
                        if sharpest:
                            psfparams[axis] = min(psfparams[axis], value)
                        else:
                            psfparams[axis] += value
                        counts[axis] += 1

    # Axes without any sample are reported as NaN instead of dividing by zero
    # (average mode) or returning the +inf start value (sharpest mode).
    empty = counts == 0
    if not sharpest:
        psfparams[~empty] /= counts[~empty]
    psfparams[empty] = np.nan
    return psfparams


def main():
    """
    Command-line entry point.

    For each PSF file given with -i/--inpsf, print the file name and the 4
    neighbor-to-peak pixel ratios ("psfparams"). When several files are given,
    also print the average and the minimum of those parameters over the files.
    Exits with status 12 if the PSF type of a file is not supported.
    """
    # Imported here because the module-level imports do not include sys.
    import sys

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # nargs="+" : at least one file is needed, a bare "-i" is a usage error
    parser.add_argument('-i', '--inpsf', type = str, nargs="+", default = None, required=True,
                        help = 'path of DESI PSF fits file')
    parser.add_argument('--sharpest', action = "store_true",
                        help = 'chose the wavelength and fiber for which the PSF is the sharpest, otherwise use average')

    args = parser.parse_args()
    log = get_logger()

    all_psfparams = []

    for filename in args.inpsf:
        log.info("reading in %s"%filename)
        psf = _load_psf(filename)
        if psf is None:
            print("error ... cannot read PSF in file",filename)
            sys.exit(12)

        psfparams = _neighbor_pixel_ratios(psf, args.sharpest)
        if np.any(np.isnan(psfparams)):
            log.warning("no valid PSF pixel found for at least one axis in %s"%filename)

        print(filename)
        print("psfparams=[%f,%f,%f,%f]"%(psfparams[0],psfparams[1],psfparams[2],psfparams[3]))
        all_psfparams.append(psfparams)

    if len(args.inpsf)>1:
        all_psfparams = np.array(all_psfparams)
        average_psfparams = np.mean(all_psfparams,axis=0)
        min_psfparams = np.min(all_psfparams,axis=0)
        print("")
        print("average : [%f,%f,%f,%f]"%(average_psfparams[0],average_psfparams[1],average_psfparams[2],average_psfparams[3]))
        print("min     : [%f,%f,%f,%f]"%(min_psfparams[0],min_psfparams[1],min_psfparams[2],min_psfparams[3]))

    log.info("done")
    
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
