#!python

import sys,os
import argparse
import numpy as np
from astropy.table import Table
import matplotlib.pyplot as plt

from desispec.fibercrosstalk import read_crosstalk_parameters,eval_crosstalk

# Command line interface of the script: input table, camera, output yaml
# file and two run-mode switches.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="""Fits the optical crosstalk obtained with desi_compute_fiber_crosstalk""",
)

# mandatory inputs and output
parser.add_argument(
    '-i', '--infile',
    type=str, default=None, required=True,
    help='path to a fits table with the output of desi_compute_fiber_crosstalk',
)
parser.add_argument(
    '--camera',
    type=str, default=None, required=True,
    help='b,r or z',
)
parser.add_argument(
    '-o', '--outfile',
    type=str, default=None, required=True,
    help='output yaml file with fit parameters',
)

# optional switches
parser.add_argument(
    '--only-norm',
    action='store_true',
    help='only fit a normalization factor',
)
parser.add_argument(
    '--batch',
    action='store_true',
    help='do not show plots',
)

args = parser.parse_args()

# read the crosstalk measurements (output of desi_compute_fiber_crosstalk)
filename = args.infile
table = Table.read(filename)

# wavelength column of the input table
wave = table["WAVELENGTH"]

# camera name in upper case, as written in the output yaml file
camera = args.camera.upper()

# fiber number for each of the 20 bundles: 12.5, 37.5, ...
bfibers = 25 * (np.arange(20) + 0.5)

# fiber offsets for which a crosstalk is fitted
dfiber = np.array([-2, -1, 1, 2])


if args.only_norm :
    # Normalization-only mode: the crosstalk model returned by
    # read_crosstalk_parameters() is evaluated without its scale
    # (apply_scale=False) and, for each fiber offset, a single scale factor
    # is fitted by weighted linear least squares over all bundles:
    #   scale = sum(ivar*model*data) / sum(ivar*model**2)
    # The script exits at the end of this branch.
    print("only fit a normalization factor")
    params = read_crosstalk_parameters()

    # the context manager closes the output file even if a table column is
    # missing or the model evaluation raises
    with open(args.outfile,"w") as yamlfile :
        yamlfile.write("{}:\n".format(camera))
        print("{}:".format(camera))
        for df in dfiber :

            # unscaled model, indexed by bundle below
            xtalk  = eval_crosstalk(args.camera,wave,bfibers,df,params,apply_scale=False)

            # accumulate the two sums of the least-squares solution
            a=0.
            b=0.
            for bundle in range(20) :

                vals=table["CROSSTALK-B{:02d}-F{:+d}".format(bundle,df)]
                ivar=table["CROSSTALKIVAR-B{:02d}-F{:+d}".format(bundle,df)]

                a += np.sum(ivar*xtalk[bundle]**2)
                b += np.sum(ivar*xtalk[bundle]*vals)

            # no constraint at all (a==0) gives a null scale
            if a>0 :
                scale=b/a
            else :
                scale=0
            # a negative scale is clipped to zero
            if scale<0 : scale=0.

            # same text on screen and in the yaml file
            entry=" F{:+d}:\n  SCALE: {:4.3f}".format(df,scale)
            print(entry)
            yamlfile.write(entry+"\n")
    sys.exit(0)

# The first letter of the camera name selects the crosstalk model.
# A slice is used instead of an index so that an empty camera name ends up
# in the explicit error below instead of raising an IndexError.
cam=camera[:1]

if cam=="B" :
    # model is
    # XTALK = P0*(W1-WAVE)/(W1-W0)+P1*(WAVE-W0)/(W1-W0)
    W0=3500
    W1=5500
    npar=2

elif cam=="R" :
    # model is
    # XTALK = P0*(W1-WAVE)/(W1-W0)+P1*(WAVE-W0)/(W1-W0)
    W0=6000
    W1=7500
    npar=2

elif cam=="Z" :
    # model is
    # XTALK = P0 + P1*(FIBER/250-1) + (P2 + P3*(FIBER/250-1)) * DW**WP + (P4 + P5*(FIBER/250-1)) * DW**WP2
    # with DW=(WAVE>W0)*((WAVE-W0)/(W1-W0))
    W0=8900
    W1=9800
    dwave=(wave-W0)/(W1-W0)
    WP=2.
    WP2=4.
    npar=6

else :
    # without this branch W0, W1 and npar stay undefined and the script
    # fails further down with an unrelated NameError
    sys.exit("unsupported camera '{}', expected b, r or z".format(args.camera))



# Weighted linear least-squares fit of the crosstalk model, with one
# independent fit per fiber offset using the 20 bundles together.
# Result: params["F+1"][j] is the fitted value of parameter Pj for the
# fiber offset +1, and similarly for the other offsets.
params={}


for df in dfiber :

    # normal equations A.P = B, accumulated over the bundles
    A=np.zeros((npar,npar))
    B=np.zeros(npar)

    for bundle in range(20) :

        # fiber dependent term of the Z model
        fiber_term = bfibers[bundle]/250.-1.

        # H[j] is the model component multiplied by parameter Pj
        H=np.zeros((npar,wave.size))

        if cam=="B" or cam=="R" :
            H[0]=(W1-wave)/(W1-W0)
            H[1]=(wave-W0)/(W1-W0)

        if cam=="Z" :
            H[0]=np.ones(wave.size)
            H[1]=fiber_term*np.ones(wave.size)
            H[2]=(wave>W0)*(np.abs(dwave))**WP
            H[3]=fiber_term*(wave>W0)*(np.abs(dwave))**WP
            if npar==6 :
                H[4]=(wave>W0)*(np.abs(dwave))**WP2
                H[5]=fiber_term*(wave>W0)*(np.abs(dwave))**WP2

        vals=table["CROSSTALK-B{:02d}-F{:+d}".format(bundle,df)]
        ivar=table["CROSSTALKIVAR-B{:02d}-F{:+d}".format(bundle,df)]

        for j in range(npar) :
            B[j] += np.sum(ivar*vals*H[j])
            for k in range(npar) :
                A[j,k] += np.sum(ivar*H[j]*H[k])

    # solve the linear system directly instead of forming the explicit
    # inverse of A and multiplying it by B
    try :
        P = np.linalg.solve(A,B)
    except np.linalg.LinAlgError as err :
        # same exception type as before, but saying which fit failed
        raise np.linalg.LinAlgError(
            "cannot fit the crosstalk for fiber offset {:+d}: {}".format(df,err)) from err

    key="F{:+d}".format(df)
    params[key]={}

    for j in range(npar) :
        params[key][j]=P[j]




# Write the fit result in the output yaml file: the camera name first, then
# a comment with the model formula, the model constants and, for each fiber
# offset, the fitted parameters P0, P1, ...
# The context manager closes the file even if one of the writes fails.
with open(args.outfile,"w") as yamlfile :
    yamlfile.write("{}:\n".format(camera))

    # model formula, as comment lines
    if cam=="B" or cam=="R" :
        yamlfile.write("# XTALK = P0*(W1-WAVE)/(W1-W0)+P1*(WAVE-W0)/(W1-W0)\n")
    if cam=="Z" :
        yamlfile.write("# fiber cross talk model\n")
        if npar==4 :
            yamlfile.write("# XTALK = P0 + P1*(FIBER/250-1) + (P2 + P3*(FIBER/250-1)) * DW**WP\n")
        else :
            yamlfile.write("# XTALK = P0 + P1*(FIBER/250-1) + (P2 + P3*(FIBER/250-1)) * DW**WP + (P4 + P5*(FIBER/250-1)) * DW**WP2\n")
        yamlfile.write("# with DW=(WAVE>W0)*((WAVE-W0)/(W1-W0))\n")

    if cam in ("B","R","Z") :
        # model constants
        yamlfile.write(" W0: {}\n".format(W0))
        yamlfile.write(" W1: {}\n".format(W1))
        if cam=="Z" :
            yamlfile.write(" WP: {}\n".format(WP))
            if npar==6: yamlfile.write(" WP2: {}\n".format(WP2))

        # fitted parameters, one yaml section per fiber offset
        for k1 in params :
            yamlfile.write(" {}:\n".format(k1))
            for k2 in params[k1] :
                yamlfile.write("  P{}: {}\n".format(k2,params[k1][k2]))
print("wrote",args.outfile)


# Print the fitted parameters and, for each fiber offset, draw two figures:
#  - crosstalk vs wavelength, one color per bundle (measurements and model)
#  - crosstalk vs fiber, one color per wavelength (measurements and model)
# The figures are displayed at the end unless --batch is set.

if cam=="Z" :
    # wavelength terms of the Z model: they depend neither on the bundle
    # nor on the fiber offset, so they are computed only once
    DW=(wave>W0)*(np.abs(dwave))
    DWP=DW**WP
    if npar==6 :
        DWP2=DW**WP2

for k1 in params :

    print(" {}:".format(k1))
    for k2 in params[k1] :
        print("  P{}: {}\n".format(k2,params[k1][k2]))

    # recover the fiber offset from the key, e.g. "F+1" -> 1
    df=int(k1.replace("F",""))
    print("df=",df)

    title=k1
    plt.figure(title)
    plt.subplot(111,title=title)

    # one color per bundle
    colors = plt.cm.jet(np.linspace(0,1,20))

    P=params[k1]
    wvals=[]
    werrs=[]
    wmodels=[]
    for bundle in range(20) :
        bfiber = bfibers[bundle]
        if cam=="B" or cam=="R" :
            ct=P[0]*(W1-wave)/(W1-W0)+P[1]*(wave-W0)/(W1-W0)
        if cam=="Z":
            fiber_term=bfiber/250.-1.
            ct=P[0]+P[1]*fiber_term+(P[2]+P[3]*fiber_term)*DWP
            if npar==6 :
                ct+=(P[4]+P[5]*fiber_term)*DWP2
        plt.plot(wave,ct,color=colors[bundle])

        vals=table["CROSSTALK-B{:02d}-F{:+d}".format(bundle,df)]
        wvals.append(vals)
        wmodels.append(ct)

        ivar=table["CROSSTALKIVAR-B{:02d}-F{:+d}".format(bundle,df)]
        # 1/sqrt(ivar) where ivar>0 and 0 elsewhere, without dividing by zero
        errs=(ivar>0)/np.sqrt(ivar*(ivar>0)+(ivar<=0))
        werrs.append(errs)
        # only draw the measurements with an uncertainty below 0.1
        ok=(ivar>0)&(errs<0.1)
        plt.errorbar(wave[ok],vals[ok],errs[ok],fmt=".",color=colors[bundle])

    plt.grid()
    plt.xlabel("wavelength")
    plt.ylabel("fiber crosstalk")

    # second figure: same data as a function of fiber, for each wavelength
    plt.figure(title+"-bis")
    wvals=np.array(wvals)
    werrs=np.array(werrs)
    wmodels=np.array(wmodels)

    # one color per wavelength
    colors = plt.cm.jet(np.linspace(0,1,wave.size))
    for i,w in enumerate(wave) :
        ok=(werrs[:,i]>0)&(werrs[:,i]<0.1)
        plt.errorbar(bfibers[ok],wvals[ok,i],werrs[ok,i],fmt=".",color=colors[i])
        plt.plot(bfibers,wmodels[:,i],"-",color=colors[i])
    plt.grid()
    plt.xlabel("fiber")
    plt.ylabel("fiber crosstalk")

if not args.batch :
    plt.show()
