#!/usr/bin/env python


import sys,string
import numpy as np
import astropy.io.fits as pyfits
import argparse
from desispec.preproc import  _parse_sec_keyword
import matplotlib.pyplot as plt
from desiutil.log import get_logger

# ---------------------------------------------------------------------------
# Command line interface.
# The script takes one or more preprocessed image files and two flux cuts
# (lower detection threshold, upper limit against saturation); --plot turns
# on the diagnostic figures shown at the very end of the run.
# ---------------------------------------------------------------------------
description_text = "Compute the electronic cross-talk coefficient among the amplifiers of a CCD image"
epilog_text = '''
                                 Input is a preprocessed arc lamp image with a sufficient number of bright lines
                                 to evaluate unambiguously the cross-talk.
                                 '''

parser = argparse.ArgumentParser(
    description=description_text,
    epilog=epilog_text,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    '-i', '--image',
    type=str, default=None, required=True, nargs="*",
    help='path of preprocessed image fits files',
)
parser.add_argument(
    '-t', '--threshold',
    type=float, default=200, required=False,
    help='flux threshold to detect cross-talk in other amps',
)
parser.add_argument(
    '-m', '--maxflux',
    type=float, default=50000, required=False,
    help='max flux to avoid saturation issues',
)
parser.add_argument(
    '--plot',
    action="store_true",
    help="show the fit",
)

# parsed options and the logger used by the rest of the script
args = parser.parse_args()
log = get_logger()


# Cross-talk slopes measured so far, keyed by amplifier pair "<amp1><amp2>":
# bright pixels are selected in amp1 and the induced signal is measured at the
# matching pixels of amp2. Each successful fit (one per image) appends a value.
coeffs = {}

amplifiers = ["A", "B", "C", "D"]

# with python axis 1 = fits axis 0 = x corresponding to column number (as in ds9 display)
# and  python axis 0 = fits axis 1 = y corresponding to row number (as in ds9 display)
# the ccd looks like :
# C D
# A B
# For cross-talk we need, for each axis, a symmetric 4x4 matrix indexed by
# (amp1, amp2) in the order A, B, C, D:
#
#    A   B   C   D
#
# A  AA  AB  AC  AD
# B  BA  BB  BC  BD
# C  CA  CB  CC  CD
# D  DA  DB  DC  DD
#
# An entry of -1 means that the sub-image of amp2 is mirrored along that axis
# before being compared pixel by pixel with the sub-image of amp1.
# These tables do not depend on the image, so they are built once.
flip_axis_0 = np.array([[1, 1, -1, -1],
                        [1, 1, -1, -1],
                        [-1, -1, 1, 1],
                        [-1, -1, 1, 1]])

flip_axis_1 = np.array([[1, -1, 1, -1],
                        [-1, 1, -1, 1],
                        [1, -1, 1, -1],
                        [-1, 1, -1, 1]])

for filename in args.image:
    log.info("analysing %s"%filename)

    # Read everything we need while the file is open, then let the context
    # manager close it (the original never closed the file).
    with pyfits.open(filename) as image_file:
        header = image_file[0].header
        # inverse variance, zeroed wherever the mask flags a pixel
        ivar = image_file["IVAR"].data*(image_file["MASK"].data == 0)
        # out-of-place product: a fresh array, the HDU data are left untouched
        flux = image_file[0].data*(ivar > 0)
        # index of each amplifier's region, used below directly as flux[...] / ivar[...]
        sections = [_parse_sec_keyword(header["CCDSEC%s"%amp]) for amp in amplifiers]

    # work with gradient of image along wavelength
    # to erase sensitivity of estimator to background level
    nflux = np.zeros(flux.shape)
    nflux[1:-1] = flux[1:-1]-(flux[2:]+flux[:-2])/2.
    flux = nflux

    for a1, amp1 in enumerate(amplifiers):
        a1flux = flux[sections[a1]]

        # source pixels: bright enough to induce measurable cross-talk,
        # but below the saturation limit
        mask = (a1flux > args.threshold) & (a1flux < args.maxflux)
        if np.sum(mask) < 10:
            log.warning("not enough pix above threshold in %s"%amp1)
            continue

        for a2, amp2 in enumerate(amplifiers):
            if a1 == a2:
                continue

            a2flux = flux[sections[a2]]
            a2ivar = ivar[sections[a2]]

            # mirror amp2 so that its pixels line up with those of amp1
            if flip_axis_0[a1, a2] == -1:
                a2flux = a2flux[::-1]
                a2ivar = a2ivar[::-1]
            if flip_axis_1[a1, a2] == -1:
                a2flux = a2flux[:, ::-1]
                a2ivar = a2ivar[:, ::-1]

            # keep only source pixels whose counterpart in amp2 is valid;
            # the same selection must be used for both samples and for npix
            mask12 = mask & (a2ivar > 0)
            npix = np.sum(mask12)

            f1 = a1flux[mask12].ravel()
            f2 = a2flux[mask12].ravel()

            # about 50 pixels per bin edge; at least two edges are needed to
            # define a bin (fewer would make the array sizes below negative)
            nbins = int(npix/50.)
            if nbins < 2:
                log.warning("not enough valid pixels (%d) to measure cross-talk for %s%s"%(npix, amp1, amp2))
                continue

            # logarithmically spaced bin edges between threshold and maxflux
            fbins = args.threshold*np.exp(np.linspace(0., np.log(args.maxflux/args.threshold), nbins))

            # per-bin medians of the amp1 and amp2 fluxes, and error on the latter;
            # bins that are skipped keep zeros and are filtered out afterwards
            mf1 = np.zeros(fbins.size-1)
            mf2 = np.zeros(fbins.size-1)
            ef2 = np.zeros(fbins.size-1)
            for i in range(fbins.size-1):
                ok = (f1 >= fbins[i]) & (f1 < fbins[i+1])
                if np.sum(ok) < 5:
                    continue
                if1 = f1[ok]
                if2 = f2[ok]
                imf2 = np.median(if2)  # robust estimate in bin
                imf1 = np.median(if1)  # also a median to avoid biases
                irms2 = 1.4826*np.median(np.abs(if2-imf2))  # robust estimate of rms in bins
                npt2 = np.sum(np.abs(if2-imf2) < 3*irms2)  # number of points within 3 sigma
                if npt2 < 3:
                    continue
                err2 = 1.4*irms2/np.sqrt(npt2)  # error estimation of median
                mf1[i] = imf1
                mf2[i] = imf2
                ef2[i] = err2

            good = np.where((mf1 > 0) & (ef2 > 0))[0]
            if good.size < 2:
                log.warning("cannot measure cross-talk for %s%s"%(amp1, amp2))
                continue

            mf1 = mf1[good]
            mf2 = mf2[good]
            ef2 = ef2[good]

            w = 1./ef2**2

            # Weighted linear fit mf2 = cross_talk * mf1 through the origin:
            # with gradients, do not expect cst term
            A = np.sum(w*mf1**2)
            B = np.sum(w*mf1*mf2)
            cross_talk = B/A
            cross_err = 1./np.sqrt(A)

            pair = "%s%s"%(amp1, amp2)
            log.info("%s = %f +- %f sig=%f"%(pair, cross_talk, cross_err, np.abs(cross_talk/cross_err)))

            coeffs.setdefault(pair, []).append(cross_talk)

            if args.plot:
                # one figure per pair; the figures are shown once all images are processed
                plt.figure("c%s%s"%(amp1, amp2))
                plt.errorbar(mf1, mf2, ef2, color="b", fmt="o")
                plt.plot(mf1, mf1*cross_talk, "-", color="r")
                plt.xlabel("flux in amp %s"%amp1)
                plt.ylabel("flux in amp %s"%amp2)

# ---------------------------------------------------------------------------
# Final report: combine the per-image coefficients of every amplifier pair.
# ---------------------------------------------------------------------------
print("\n".join(["", "average values", "=========================================="]))

# amplifier pairs for which at least one image gave a measurement, sorted by name
pairs = np.sort(list(coeffs))

cross_talk = {}      # adopted coefficient per pair
cross_talk_err = {}  # uncertainty on the adopted coefficient
cross_talk_rms = {}  # image-to-image scatter

for pair in pairs:
    measurements = coeffs[pair]
    nmeas = len(measurements)

    if nmeas <= 1:
        # a single measurement: report it as is, no scatter can be estimated
        single = measurements[0]
        cross_talk[pair] = single
        cross_talk_rms[pair] = 0
        cross_talk_err[pair] = 0
        log.info("%s = %g"%(pair, single))
        continue

    # several images: take the median, and derive its uncertainty from the
    # scatter between images (the sqrt(pi/2) factor is applied to rms/sqrt(n))
    center = np.median(measurements)
    scatter = np.std(measurements)
    uncertainty = np.sqrt(np.pi/2.)*scatter/np.sqrt(nmeas)

    cross_talk[pair] = center
    cross_talk_rms[pair] = scatter
    cross_talk_err[pair] = uncertainty

    snr = center/uncertainty
    log.info("%s = %g +- %g sig=%f ( rms of images  = %g )"%(pair, center, uncertainty, np.abs(snr), scatter))

# one "CROSSTALK<pair>: <value>" line per pair on standard output
for pair in pairs:
    print("CROSSTALK%s: %g"%(pair, cross_talk[pair]))

# display the diagnostic figures, if they were requested
if args.plot:
    plt.show()
