#!python

import sys,os
import argparse
import matplotlib.pyplot as plt
import numpy as np
import astropy.io.fits as pyfits
from desispec.io import read_xytraceset
import fitsio

# Largest absolute correction tolerated for each quantity; a file whose
# maximum |value| exceeds this is discarded before the PCA.
MAX_ABS_DELTA = {"DWAVE": 0.3, "DLSF": 2.}

# Number of leading eigenvalues stored in the *_EIGENVALS extensions.
MAX_SAVED_EIGENVALS = 40


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: list of argument strings; None means use sys.argv[1:].

    Returns:
        argparse.Namespace with infile, outfile, ncomp and min_exptime.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description="PCA of sky corrections")
    parser.add_argument('-i','--infile', type = str, default = None, required = True, nargs="*",
                        help = 'path to skycorr fits files')
    parser.add_argument('-o','--outfile', type = str, default = None, required = True,
                        help = 'output fits file with PCA coefficients')
    parser.add_argument('--ncomp', type = int, default = 3, required = False,
                        help = 'number of PCA components to save (in addition to the mean)')
    parser.add_argument('--min-exptime', type = float, default = 700, required = False,
                        help = 'minimum exposure time (to ensure a good precision on sky lines)')
    return parser.parse_args(argv)


def select_deltas(filenames, what, min_exptime):
    """Read HDU `what` from every skycorr file that passes the quality cuts.

    A file is skipped when its companion frame file (same path with
    "skycorr-" replaced by "frame-") is missing, when the frame EXPTIME is
    below `min_exptime`, when the data are not finite, have a zero rms or
    an rms above 1, or when max(|data|) exceeds MAX_ABS_DELTA[what].

    Args:
        filenames: list of skycorr fits file paths.
        what: name of the HDU to read, a key of MAX_ABS_DELTA.
        min_exptime: minimum value of the frame header EXPTIME keyword.

    Returns:
        (deltas, wave, head, nights) where deltas is the list of selected
        arrays, wave is the WAVELENGTH data of the first selected file (None
        if no file was selected), head is the last frame header that was read
        (None if none was found) and nights lists the NIGHT keyword of the
        files that passed the exposure time cut.
    """
    threshold = MAX_ABS_DELTA[what]
    deltas = []
    nights = []
    wave = None
    head = None

    print("reading {} files".format(len(filenames)))
    for index, filename in enumerate(filenames):

        print("reading {}/{} {}".format(index+1, len(filenames), filename))

        frame_filename = filename.replace("skycorr-", "frame-")
        if not os.path.isfile(frame_filename):
            print("warning: cannot open frame file {}".format(frame_filename))
            continue

        head = fitsio.read_header(frame_filename)
        exptime = head["EXPTIME"]
        if exptime < min_exptime:
            print("ignore {} with exptime={} < {}".format(filename, exptime, min_exptime))
            continue

        nights.append(head["NIGHT"])

        # The context manager closes the file; the arrays are copied so they
        # do not keep referring to the (possibly memory-mapped) file.
        with pyfits.open(filename) as hdulist:
            delta = np.array(hdulist[what].data)
            rms = np.std(delta)
            # A NaN rms would pass the two comparisons and poison the PCA.
            if not np.isfinite(rms) or rms == 0 or rms > 1.:
                print("no valid data for {}".format(filename))
                continue
            print("{} rms {} = {}".format(filename, what, rms))

            mdelta = np.max(np.abs(delta))
            if mdelta > threshold:
                print(f"skip {filename} with max {what} = {mdelta} > {threshold} A")
                continue

            deltas.append(delta)
            if wave is None:
                wave = np.array(hdulist["WAVELENGTH"].data)

    return deltas, wave, head, nights


def compute_pca(deltas, ncomp):
    """Principal component analysis of a set of equally shaped arrays.

    The eigen decomposition is done on the (nmeas x nmeas) matrix of scalar
    products between mean-subtracted measurements. Each returned component
    is the sum of the mean-subtracted measurements weighted by the entries
    of a unit eigenvector, so its squared norm equals the eigenvalue and its
    overall sign is arbitrary.

    Args:
        deltas: sequence of nmeas arrays with identical shapes.
        ncomp: maximum number of components to return.

    Returns:
        (mean, components, eigenvals) where mean is the average over the
        measurements, components is a list of at most ncomp float64 arrays
        with the shape of one measurement, ordered by decreasing eigenvalue,
        and eigenvals holds all nmeas eigenvalues in decreasing order, with
        negative values set to zero.
    """
    stack = np.asarray(deltas)
    nmeas = stack.shape[0]
    mean = np.mean(stack, axis=0)

    # float64 for the scalar products whatever the precision of the input
    centered = stack.astype(np.float64) - mean
    flat = centered.reshape(nmeas, -1)
    gram = flat @ flat.T

    # gram is symmetric: eigh returns real eigenvalues and orthonormal vectors
    eigenvals, eigenvecs = np.linalg.eigh(gram)
    eigenvals[eigenvals < 0] = 0.  # negative values can only be round-off
    order = np.argsort(eigenvals)[::-1]

    components = [np.tensordot(eigenvecs[:, i], centered, axes=1)
                  for i in order[:ncomp]]
    return mean, components, eigenvals[order]


def main(argv=None):
    """Run the PCA for DWAVE and DLSF and write the result to args.outfile.

    The output file contains, for each quantity, a <what>_MEAN HDU (the
    first one is the primary HDU), up to ncomp <what>_EIG<n> HDUs and a
    <what>_EIGENVALS HDU, followed by a WAVELENGTH HDU. CAMERA, DETECTOR,
    NIGHTMIN and NIGHTMAX are written in the primary header.

    Args:
        argv: list of argument strings; None means use sys.argv[1:].
    """
    args = parse_args(argv)

    output_hdulist = pyfits.HDUList()
    wave = None
    head = None
    nights = []

    for what in MAX_ABS_DELTA:
        print("processing", what)
        deltas, this_wave, this_head, this_nights = select_deltas(
            args.infile, what, args.min_exptime)

        if this_head is not None:
            head = this_head
        if wave is None:
            wave = this_wave
        nights.extend(this_nights)

        if not deltas:
            # nothing to average: stop here instead of writing NaNs
            print(f"ERROR: no valid {what} measurement in the input files")
            sys.exit(1)

        print("number of measurements=", len(deltas))
        print("subtract mean, fill scalar product matrix, eigen decomposition")
        mdeltas, components, eigenvals = compute_pca(deltas, args.ncomp)
        print("eigen vals")
        print(eigenvals)

        if len(output_hdulist) == 0:
            output_hdulist.append(pyfits.PrimaryHDU(mdeltas))
            output_hdulist[0].header["EXTNAME"] = what+"_MEAN"
        else:
            output_hdulist.append(pyfits.ImageHDU(mdeltas, name=what+"_MEAN"))

        for rank, component in enumerate(components, start=1):
            output_hdulist.append(
                pyfits.ImageHDU(component, name=what+"_EIG{}".format(rank)))

        output_hdulist.append(
            pyfits.ImageHDU(eigenvals[:MAX_SAVED_EIGENVALS], name=what+"_EIGENVALS"))

    output_hdulist.append(pyfits.ImageHDU(wave, name="WAVELENGTH"))
    if head is not None:
        for k in ["CAMERA", "DETECTOR"]:
            output_hdulist[0].header[k] = head[k]
    if nights:
        output_hdulist[0].header["NIGHTMIN"] = min(nights)
        output_hdulist[0].header["NIGHTMAX"] = max(nights)
    output_hdulist.writeto(args.outfile, overwrite=True)
    print("wrote", args.outfile)


if __name__ == "__main__":
    main()
