#!python


import astropy.io.fits as pyfits
import fitsio
from desispec.calibfinder import CalibFinder
from astropy.table import Table
import matplotlib.pyplot as plt
import numpy as np
import sys
import re
import argparse
import matplotlib.pyplot as plt

from desiutil.log import get_logger

def clipped_var(mx,x2,nsig=5.) :
    '''
    Iteratively clipped estimate of a variance from an array of squared deviations.

    The variance is initialized with a NMAD-like estimate computed from
    sqrt(x2), then samples with x2 >= nsig**2 * var are rejected and the
    variance is recomputed as the mean of the retained x2. The iteration stops
    when the relative change of the variance is below 1e-4, or after 20 loops.

    Args:
        mx : array of values, same shape as x2, averaged over the retained samples
        x2 : array of squared deviations (its square root is taken, so it is
             expected to be non-negative)
        nsig : clipping threshold, in units of the current standard deviation

    Returns:
        tuple (mean, var, nkept, loop) where mean is the mean of mx over the
        retained samples, var the clipped variance, nkept the number of
        retained samples, and loop the index of the last iteration.
    '''

    ovar=0.001 # variance of the previous iteration, used for the convergence test

    # start with NMAD , less sensitive to outliers
    var = ( 1.4826*np.median( np.abs(np.sqrt(x2))) )**2
    for loop in range(20) :
        ok=(x2<nsig**2*var)
        var=np.mean(x2[ok])
        if np.abs(var/ovar-1)<0.0001 :
            break
        ovar=var

    # ok is a boolean mask : ok.size would be the total number of samples,
    # including the rejected ones, so count the True entries instead
    nkept=int(np.count_nonzero(ok))
    return np.mean(mx[ok]),var,nkept,loop


def _parse_sec_keyword(value):
    '''
    parse keywords like BIASSECB='[7:56,51:4146]' into python slices

    python and FITS have almost opposite conventions,
      * FITS 1-indexed vs. python 0-indexed
      * FITS upperlimit-inclusive vs. python upperlimit-exclusive
      * FITS[x,y] vs. python[y,x]

    i.e. BIASSEC2='[7:56,51:4146]' -> (slice(50,4146), slice(6,56))

    Optional whitespace around the numbers and separators is accepted,
    e.g. '[7:56, 51:4146]'.

    Args:
        value : string containing a section of the form [xmin:xmax,ymin:ymax]

    Returns:
        tuple (yslice, xslice) to index a 2D numpy array

    Raises:
        ValueError if no section of that form is found in value
    '''
    # raw string : '\[' and '\d' are invalid escape sequences in a normal string literal
    m = re.search(r'\[\s*(\d+)\s*:\s*(\d+)\s*,\s*(\d+)\s*:\s*(\d+)\s*\]', value)
    if m is None :
        raise ValueError('unable to parse {} as [a:b, c:d]'.format(value))

    xmin, xmax, ymin, ymax = tuple(map(int, m.groups()))

    return np.s_[ymin-1:ymax, xmin-1:xmax]


# Command line interface of the script. The parsed options are kept in the
# module-level variable 'args' and used by all the code below.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="Compute the CCD gains (in e/ADU) using a series of similar images. Images are paired according to their exposure time")
# nargs="+" : at least one image is needed, the first one is read unconditionally below
parser.add_argument('-i','--image', type = str, default = None, required = True, nargs = "+",
                    help = 'path of preprocessed image fits files')
parser.add_argument('--exptime-keyword', type = str, default = "EXPTIME",
                    help = 'change exposure time keyword for pairing images')
parser.add_argument('--max-mean-flux',type = float, default = 8000,
                    help = 'maximum mean flux of the bins used to measure the variance')
parser.add_argument('--min-pixel-flux',type = float, default = -1000,
                    help = 'minimum flux of a pixel to be used')
parser.add_argument('--max-pixel-flux',type = float, default = 15000,
                    help = 'maximum flux of a pixel to be used')
parser.add_argument('--bin-size',type = float, default = 100,
                    help = 'width of the mean flux bins')
parser.add_argument('--npx',type = int, default = 2,
                    help = 'number of adjacent pixels in a row to add before computing variance')
parser.add_argument('--npy',type = int, default = 1,
                    help = 'number of adjacent pixels in a column to add before computing variance')
parser.add_argument('--margin',type = int, default = 200,
                    help = 'remove margins around CCD')
parser.add_argument('--amplifiers',type = str, default=None, help="amplifiers being studied (default is all)")
parser.add_argument('--deg',type = int, default=3, help="max degree of polynomial fit")
parser.add_argument('--fix-rdnoise',action='store_true', help="do not refit readnoise")
parser.add_argument('--plot',action="store_true",help="show the fit")
parser.add_argument('--outfile',type = str, default=None, help="save PTC values in ASCII file")
parser.add_argument('--outgain',type = str, default=None, help="save gain values in simple ASCII table")


args = parser.parse_args()
log  = get_logger()

exposure_times = []
camera = None

# loop on preprocessed images : check that they all come from the same camera
# and collect their exposure times
filenames=args.image
for filename in filenames :
    # only the primary header is needed, the file is closed when leaving the 'with' block
    with pyfits.open(filename) as hdulist :
        header=hdulist[0].header
        this_camera=header["CAMERA"].strip()
        this_exptime=header[args.exptime_keyword]
    if camera is None :
        camera = this_camera
    else :
        if this_camera != camera :
            log.error("Not the same camera for all images, I find {} and {}".format(camera,this_camera))
            sys.exit(1)

    exposure_times.append(this_exptime)

# explicit array, needed for the arithmetic with scalars below
exposure_times = np.asarray(exposure_times)

# Exposure times are grouped with a tolerance instead of using np.unique,
# such that images with slightly different exposure times can be paired.
unique_exposure_times = np.array([exposure_times[0]])
threshold=0.2 # sec
for exptime in exposure_times[1:] :
    diff=np.min(np.abs(unique_exposure_times-exptime))
    if diff>threshold : unique_exposure_times = np.append(unique_exposure_times,exptime)

log.info("Exposure times = {}".format(unique_exposure_times))

# pair consecutive images of the same group of exposure time,
# an unpaired last image in a group is ignored
pairs = []
for exptime in unique_exposure_times :
    ii=np.where(np.abs(exposure_times-exptime)<threshold)[0]
    npairs=ii.size//2
    for p in range(npairs) :
        pairs.append( [ filenames[ii[p*2]],filenames[ii[p*2+1]] ] )
log.info("Pairs of exposures = {}".format(pairs))

if len(pairs)==0 :
    # without any pair there is no data point to fit
    log.error("No pair of images with similar exposure times, cannot measure the gains")
    sys.exit(1)

# Without an explicit list of amplifiers on the command line, use the value of
# "AMPLIFIERS" given by the CalibFinder for the header of the first image.
if args.amplifiers is None :
    head = fitsio.read_header(filenames[0])
    cfinder = CalibFinder([head])
    args.amplifiers = cfinder.value("AMPLIFIERS")
    print("amplifiers=",args.amplifiers)

# Optional ASCII file with the measured points. It stays open while looping
# over the amplifiers and is closed at the end of the script.
ofile = open(args.outfile,"w") if args.outfile is not None else None
if ofile is not None :
    ofile.write("# mean var nval amplifier pair\n")
    ofile.write("## amplifiers = {} for {}\n".format(np.arange(len(args.amplifiers)),args.amplifiers))

# Table of gains, one row per amplifier, filled with zeros for now
gain_tbl = Table()
gain_tbl['AMP'] = list(args.amplifiers)
for colname in ('GAIN','ERRGAIN') :
    gain_tbl[colname] = np.zeros(len(gain_tbl))




# Number of original pixels added in one rebinned pixel. It is the same for
# all amplifiers and pairs, and it is defined here so that it exists for the
# plot labels whatever happens in the loop on pairs.
npix=args.npx*args.npy

# For each amplifier : measure the variance of the difference of the two images
# of each pair in bins of mean flux, then fit the variance as a function of flux
# with polynomials, the gain being derived from the linear coefficient.
for a,amp in enumerate(args.amplifiers) :

    ax=[] # mean fluxes
    ay=[] # variances for mean fluxes
    ardn=[] # variance of readnoise
    an=[] # number of data points in bin

    nbins=int(args.max_mean_flux/args.bin_size)

    for p,pair in enumerate(pairs) :
        log.info("pair #{} {} {}".format(p,pair[0],pair[1]))

        # Read all that is needed from the two files and close them right away.
        # The image sections are copied in order not to depend on the files
        # after they are closed.
        with pyfits.open(pair[0]) as h1, pyfits.open(pair[1]) as h2 :
            k="CCDSEC%s"%amp

            rows,cols=_parse_sec_keyword(h1[0].header[k])

            img1=h1[0].data[rows,cols].copy()
            img2=h2[0].data[rows,cols].copy()
            # keep pixels that are unmasked and have a positive ivar in both images
            ok = (h1["MASK"].data[rows,cols]==0)*(h1["IVAR"].data[rows,cols]>0)*(h2["MASK"].data[rows,cols]==0)*(h2["IVAR"].data[rows,cols]>0)

            # one read noise value per image of the pair
            rdnoise1=h1[0].header["OBSRDN%s"%amp]
            rdnoise2=h2[0].header["OBSRDN%s"%amp]

        img1 *= ok
        img2 *= ok

        rdnoise=np.sqrt((rdnoise1**2+rdnoise2**2)/2.)
        #log.info("rdnoise : {},{} -> {}".format(rdnoise1,rdnoise2,rdnoise))


        if args.margin>0 :
            margin=args.margin

            # the side that is trimmed depends on the amplifier
            if amp=="B" or amp=="D" :
                img1=img1[:,:-margin] # remove margin pix on margin
                img2=img2[:,:-margin] # remove margin pix on margin
            else :
                img1=img1[:,margin:] # remove margin pix on margin
                img2=img2[:,margin:] # remove margin pix on margin
            if amp=="A" or amp=="B" :
                img1=img1[margin:]
                img2=img2[margin:]
            else :
                img1=img1[:-margin]
                img2=img2[:-margin]

        # a huge value fails the flux selections applied below,
        # also after the rebinning
        img1[img1>args.max_pixel_flux]=1e40 # kill entry
        img2[img2>args.max_pixel_flux]=1e40 # kill entry

        if args.npx>1 : # rebinning lines
            r=args.npx
            n0=img1.shape[0]
            n1=(img1.shape[1]//r)*r
            img1=img1[:,:n1].reshape(n0,n1//r,r).sum(axis=2)
            img2=img2[:,:n1].reshape(n0,n1//r,r).sum(axis=2)
        if args.npy>1 : # rebinning columns
            r=args.npy
            n0=(img1.shape[0]//r)*r
            n1=img1.shape[1]
            img1=img1[:n0,:].reshape(n0//r,r,n1).sum(axis=1)
            img2=img2[:n0,:].reshape(n0//r,r,n1).sum(axis=1)

        # 1D scaling : one scale factor per row, ratio of the summed fluxes
        # of the selected pixels of the two images
        scale=np.zeros(img1.shape[0])
        for j in range(img1.shape[0]) :
            ok=(img1[j]>args.min_pixel_flux*npix)&(img1[j]<args.max_pixel_flux*npix)&(img2[j]>args.min_pixel_flux*npix)&(img2[j]<args.max_pixel_flux*npix)
            sum1=np.sum(img1[j][ok])
            sum2=np.sum(img2[j][ok])
            if sum2>100 :
                scale[j]=sum1/sum2 # better
        for j in range(img1.shape[0]) :
            # rows without a measured scale (scale=0) or with a scale
            # more than 20% away from 1 are discarded
            if np.abs(scale[j]-1)>0.2 :
                img1[j]=1e40 # kill
                img2[j]=1e40 # kill
            else :
                img2[j] *= scale[j]
        if args.plot :
            plt.figure("ratio")
            j=np.arange(img1.shape[0])
            plt.plot(j[scale>0],scale[scale>0],label="pair %d"%p)


        ok=(img1>args.min_pixel_flux*npix)&(img1<args.max_pixel_flux*npix)&(img2>args.min_pixel_flux*npix)&(img2<args.max_pixel_flux*npix)
        xx=(img1[ok]+img2[ok])/2.
        yy=(img1[ok]-img2[ok])**2/2. # we use the difference of images so the var of a single flux = var(diff)/2.
        bins=(xx/args.bin_size).astype(int)
        ok=(bins>=0)&(bins<nbins)
        ubins=np.unique(bins[ok])

        x=[]
        y=[]
        n=[]

        for b in ubins :
            ok=np.where((xx>=args.bin_size*b)&(xx<args.bin_size*(b+1)))[0]
            if ok.size<400 : continue # not enough data in this bin
            mean,var,ndata,nloop = clipped_var(xx[ok],yy[ok])
            log.debug("flux=%f var=%f n=%d nloop=%d"%(mean,var,ndata,nloop))
            x.append(mean)
            y.append(var)
            n.append(ndata)
            ax.append(mean)
            ay.append(var)
            an.append(ndata)
            ardn.append(npix*rdnoise**2)


        if ofile is not None :
            for bin_mean,bin_var,bin_ndata in zip(x,y,n) :
                ofile.write("%f %f %d %d %d\n"%(bin_mean,bin_var,bin_ndata,a,p))


    # fit the data for this amplifier

    ok=(np.array(ay)>0.1) # var
    x=np.array(ax)[ok]
    y=np.array(ay)[ok]
    y0=np.array(ardn)[ok]
    n=np.array(an)[ok]
    err=np.sqrt(2./n)*y
    xs=np.mean(x) # flux scale used to normalize x in the polynomials

    w=1./(err**2+1.**2)

    def myfit(w,x,y,deg) :
        '''
        Weighted linear least-squares fit of y-y0 with a polynomial of x/mean(x).

        y0 (variance of the read noise) and args are taken from the enclosing scope.
        Returns the array of deg+1 coefficients, by increasing degree.
        '''
        A=np.zeros((deg+1,deg+1))
        B=np.zeros((deg+1))
        xs=np.mean(x)
        for i in range(deg+1) :
            B[i] = np.sum(w*(y-y0)*(x/xs)**(i))
            for j in range(i+1) :
                A[i,j] = np.sum(w*(x/xs)**(i+j))
                if (j!=i) : A[j,i] = A[i,j]
        if args.fix_rdnoise : # contrain first term to 0 because this is readout noise
            A[0,0] += 1e12
        Ai=np.linalg.inv(A)
        coef=Ai.dot(B)
        return coef

    def mypol(x,coef) :
        '''
        Value at x of the polynomial of x/xs with coefficients coef,
        xs being taken from the enclosing scope.
        '''
        res=np.zeros(x.shape)
        for i in range(coef.size) :
            res += coef[i]*(x/xs)**i
        return res

    # one fit per polynomial degree, the spread of the gains is used as uncertainty
    coef=[]
    gain=[]
    for deg in range(1,args.deg+1) :
        coef.append(myfit(w,x,y,deg))
        # the linear coefficient is for x/xs, so the slope of variance vs flux is coef[1]/xs
        gain.append(1./coef[deg-1][1]*xs)
        log.info("%s %s deg=%d gain = %4.3f e/ADU"%(camera,amp,deg,gain[deg-1]))

    mgain=np.mean(gain)
    errgain=np.max(gain)-np.min(gain)
    log.info("%s %s gain = $%4.3f \\pm %4.3f$ e/ADU"%(camera,amp,mgain,errgain))

    # Save
    gain_tbl['GAIN'][a] = mgain
    gain_tbl['ERRGAIN'][a] = errgain

    #chi2=np.sum((y-y0-mypol(x,coef))**2/err**2)
    #ndf=x.size-coef.size
    #gain=1./coef[1]*xs

    if args.plot :
        ms=5
        fig=plt.figure("{}-{}".format(camera,amp))
        # top panel : variance vs flux with the fits of all degrees
        plt.subplot(2,1,1)
        plt.errorbar(x,y,err,fmt="o",ms=ms,color="gray")
        tx=np.arange(np.min(x),np.max(x),100)
        my0=np.mean(y0)
        for d in range(1,args.deg+1) :
            plt.plot(tx,my0+mypol(tx,coef[d-1]),label="deg %d fit, gain=%4.3f"%(d,gain[d-1]))

        plt.ylabel("variance(flux)")
        plt.legend(loc="upper left",title="%s-%s"%(camera,amp))
        # bottom panel : same after subtraction of the degree 1 fit
        plt.subplot(2,1,2)
        plt.errorbar(x,y-y0-mypol(x,coef[0]),err,fmt="o",ms=ms,color="gray")
        lfit=mypol(tx,coef[0])
        for d in range(1,args.deg+1) :
            plt.plot(tx,mypol(tx,coef[d-1])-lfit,label="deg %d fit"%d)
        plt.plot(tx,0*tx,"-",color="b")
        if npix==1 :
            plt.xlabel("flux (single pixel)")
        else :
            plt.xlabel("flux (sum of %dx%d pixels)"%(args.npx,args.npy))
        plt.ylabel("variance(flux) - linear relation")
        fig.savefig("%s-%s.png"%(camera,amp))
        plt.figure("ratio")
        plt.legend(fontsize="small")
    # end of loop on amplifiers


# optional table of gains, one row per amplifier
if args.outgain is not None :
    gain_tbl.write(args.outgain, format='ascii.fixed_width', overwrite=True)

# the file with the measured points was filled in the loop on amplifiers
if ofile is not None :
    ofile.close()
    log.info("wrote {}".format(args.outfile))

# summary on the standard output : gains with their uncertainties first ...
print(camera)
for amp_name,amp_gain,amp_errgain in zip(gain_tbl['AMP'],gain_tbl['GAIN'],gain_tbl['ERRGAIN']) :
    print("{}-{} gain = {:4.3f} +- {:4.3f}".format(camera,amp_name,amp_gain,amp_errgain))
print("")
# ... then one "GAIN<amp>: value" line per amplifier
print(camera)
for amp_name,amp_gain in zip(gain_tbl['AMP'],gain_tbl['GAIN']) :
    print("GAIN{}: {:4.3f}".format(amp_name,amp_gain))

# display all the figures created above
if args.plot :
    plt.show()
