#!/usr/bin/env python

import numpy as np
import sys
import argparse
import matplotlib.pyplot as plt

from desiutil.log import get_logger
from desispec.io import read_image
from desispec.io.xytraceset import read_xytraceset
from desispec.preproc import _parse_sec_keyword
from desispec.calibfinder import CalibFinder
#from desispec.darktrail import *

def max_profile_index_for_fit(y) :
    """
    Returns one plus the index of the minimum of y located before a peak
    searched in the second quarter of the array.

    Args:
      y: 1D numpy array

    Returns:
      integer index (position of the minimum of y[:peak], plus one)
    """
    half = y.size // 2
    quarter = half // 2
    # the peak is searched between the first quarter and 3 samples before the middle
    search_window = y[quarter:half - 3]
    peak = quarter + np.argmax(search_window)
    lowest = np.argmin(y[:peak])
    return lowest + 1
    
def compute_fiber_cross_profile(image,tset,xyslice) :
    """
    Computes a cross dispersion profile averaged over fibers and CCD rows
    for pixels in a subsample of the image

    Args:
      image : desispec.Image class instance (only image.pix is read)
      tset : trace set object; tset.nspec and tset.x_vs_y(fiber,y) are used to locate the center of each fiber trace
      xyslice : tuple of python slices (yy,xx) where yy is indexing CCD rows (first index in 2D numpy array) and xx is the other

    Returns:
      x: 1D numpy array with the pixel grid
      prof: 1D numpy array with the cross dispersion profile sampled on x

    Raises:
      ValueError: if no fiber has a usable row inside xyslice

    """

    log = get_logger()

    log.info("compute_fiber_cross_profile")

    xmargin = 100 # half width of the profile, in pixels
    ymargin = 50  # number of CCD rows ignored at both ends of the slice
    oversampling=1
    dx = np.linspace(-xmargin,xmargin,2*xmargin*oversampling+1)

    yy=xyslice[0]
    xx=xyslice[1]

    fprofs=[]
    for fiber in range(tset.nspec) :
        # quick selection based on the trace position at the central row of the slice
        xc = tset.x_vs_y(fiber,(yy.start+yy.stop)/2.)
        if xc < xx.start+xmargin : continue
        if xc > xx.stop-xmargin : continue
        log.debug("using fiber={} xc={}".format(fiber,xc))
        profs = []
        for y in range(yy.start+ymargin,yy.stop-ymargin) :
            xc = tset.x_vs_y(fiber,y)
            xb = int(xc)-xmargin
            xe = xb + 2*xmargin + 1
            # the full window has to fit in the slice
            if xb<xx.start : continue
            if xe>xx.stop : continue
            xslice = np.arange(xb,xe)
            # resample the row on a grid centered on the trace
            prof = np.interp(dx,xslice-xc,image.pix[y,xb:xe])
            prof -= np.median(prof)
            profs.append(prof)
        if not profs :
            # np.vstack would fail on an empty list: ignore this fiber
            log.warning("no usable CCD row for fiber={}, skipping it".format(fiber))
            continue
        prof = np.mean(np.vstack(profs),axis=0)
        # normalize by the flux in the 7 central samples
        sprof = np.sum(prof[xmargin-3:xmargin+4])
        prof /= sprof
        fprofs.append(prof)

    if not fprofs :
        raise ValueError("no fiber trace with a full +-{} pixel window inside the image section {}".format(xmargin,xyslice))

    prof = np.median(np.vstack(fprofs),axis=0)

    return dx,prof
 
def compute_dark_trail(input_profile,width,amplitude=1,data_profile=None,begin=None,end=None) :
    """
    Computes the dark trail profile given an input cross-dispersion profile.
    The amplitude is fitted if the data_profile is given.

    Each sample i of the input profile contributes a negative signal
    -(d/width)*exp(-d/width)*input_profile[i] to the samples located d>0
    indices before it.

    Args:
      input_profile: 1D numpy array with the input cross dispersion profile
      width: the trail width parameter
      amplitude: the amplitude of the trail (overwritten if data_profile is not None)

    Optional argument:
      data_profile: 1D numpy array with the data cross dispersion profile. It is
      used to fit the trail amplitude.
      begin: first index of the range used for the amplitude fit (None = start of array)
      end: last index (excluded) of the range used for the amplitude fit (None = end of array)

    Returns:
      trail: 1D numpy array with the trail profile on the same grid as the input profile
      amplitude: the fitted amplitude (or the input amplitude if data profile = None)

    Raises:
      ValueError: if data_profile is given and the unscaled trail is null in begin:end,
      in which case the amplitude cannot be fitted
    """
    nn = int(input_profile.size)

    # the kernel only depends on the index offset, so it is computed once
    offsets = np.arange(nn)
    kernel = np.abs(offsets/width)*np.exp(-offsets/width)

    trail = np.zeros(nn)
    for i in range(1,nn) :
        # offsets from sample i to samples 0..i-1 are i,i-1,...,1
        trail[:i] -= kernel[i:0:-1]*input_profile[i]

    if data_profile is not None :
        # linear least square fit of the amplitude on data - input = amplitude * trail
        fit_trail = trail[begin:end]
        a = np.sum(fit_trail**2)
        if a == 0 :
            raise ValueError("cannot fit the trail amplitude: null trail in range {}:{}".format(begin,end))
        b = np.sum(fit_trail*(data_profile[begin:end]-input_profile[begin:end]))
        amplitude = b/a

    trail *= amplitude
    return trail,amplitude

  
def compute_prof_model(x,prof,left) :
    """
    Builds a model of the cross-dispersion profile without trail, by replacing
    one half of the profile with a rescaled mirror image of the other half.

    Args:
      x: 1D numpy array giving the pixel grid of the cross-dispersion profile
         (presumably centered on 0 at index x.size//2, as returned by compute_fiber_cross_profile)
      prof: 1D numpy array representing the cross dispersion profile (not modified)
      left: if true, the left half of the profile is the one replaced, otherwise
            the profile is flipped before the computation and the model is flipped back

    Returns:
      1D numpy array with the model, same size and orientation as prof
    """
    # local logger, so the function does not depend on a module-level variable
    log = get_logger()
    nh=prof.size//2
    y=prof.copy()
    if not left :
        # flip the profile so we always deal with same orientation
        y = y[::-1]
    # use unaffected right side of profile for pedestal subtraction
    # (note that -nh//2 is evaluated as (-nh)//2)
    y -= np.median(y[-nh//2:])
    # positions of the half maximum on both sides of the peak;
    # the right half is reversed because np.interp needs increasing abscissa
    # (this assumes the profile decreases away from the peak)
    x2=np.interp(0.5*y[nh],y[nh:][::-1],x[nh:][::-1])
    x1=np.interp(0.5*y[nh],y[:nh],x[:nh])
    log.info("half max at {:4.3f} {:4.3f} pix".format(x1,x2))
    # ratio of the half widths at half maximum, right over left
    scale=x2/(-x1)
    my = y.copy()
    # left half = mirror of the right half, with the pixel scale adjusted
    my[:nh] = np.interp(-x[:nh]/scale,x[nh+1:],y[nh+1:])
    # least square amplitude of the mirrored model on the 5 samples before the center
    a = np.sum(my[nh-5:nh]*y[nh-5:nh])/np.sum(my[nh-5:nh]**2)
    log.info("width scale = {:4.3f}; amplitude scale = {:4.3f}".format(scale,a))
    my[:nh]*=a
    if not left :
        # flip back
        my = my[::-1]

    return my

def fit_dark_trail(x,prof,left,begin,end) :
    """
    Fits the width and amplitude parameters of a dark trail given a cross-dispersion profile.
    The unaffected side of the profile is used as a model for the other side,
    assuming a symmetric profile.

    Args:
      x: 1D numpy array giving the pixel grid of the cross-dispersion profile.
      prof: 1D numpy array representing the cross dispersion profile.
      left: if true, the trail is to the left of the profile, as is the case for amplifiers B and D.
      begin: first index of the range used in the fit
      end: last index (excluded) of the range used in the fit

    Returns:
      amplitude: the fitted amplitude of the trail profile
      width: the fitted width of the trail profile
    """
    log = get_logger()

    log.info("Fitting dark trail function")

    model = compute_prof_model(x,prof,left)

    if not left :
        # flip the profile so we always deal with same orientation
        prof = prof[::-1]
        model = model[::-1]

    # two passes: a coarse scan of the width, then a finer scan around the best coarse value
    width = None
    for loop in (0,1) :
        if loop == 0 :
            candidates = np.linspace(1,200,201)
        else :
            candidates = np.linspace(max(1.,width-2),width+3,50)

        chi2 = np.zeros(candidates.size)
        for index,candidate in enumerate(candidates) :
            # the amplitude is fitted for each candidate width
            trail,amp = compute_dark_trail(model,width=candidate,data_profile=prof,begin=begin,end=end)
            residual = (prof-(model+trail))[begin:end]
            chi2[index] = np.sum(residual**2)
            log.debug("width={} amp={}, chi2={}".format(candidate,amp,chi2[index]))

        best = np.argmin(chi2)
        width = candidates[best]
        log.info("loop {}/2 best fit width = {:3.1f} pixels".format(loop+1,width))

    # amplitude for the best fit width
    trail,amplitude = compute_dark_trail(model,width=width,data_profile=prof,begin=begin,end=end)
    log.info("amplitude={:6.5f} width={:3.2f} pix".format(amplitude,width))

    return amplitude,width
    
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="Measure the readout negative  'dark' trails found in some CCD images and fit it with a model of the form a*(x/width)*exp(-x/width). The required inputs are a preprocessed continuum image (obtained by running desi_preproc with the option --nodarktrail) and a PSF file where the coordinates of the spectral traces are stored. The output are simply the amplitude and width of the fitted dark trails.")
parser.add_argument('-i','--infile', type = str, default = None, required = True,
                    help = 'path to a preprocessed continuum fits image')
parser.add_argument('-p','--psf', type = str, default = None, required = False,
                    help = 'path to input psf fits file to get the trace coordinates (default is psf in $DESI_CCD_CALIBRATION_DATA)')
parser.add_argument('--plot', action = 'store_true',
                    help = 'plot figures')
parser.add_argument('-a','--amplifier', type = str, required = True,
                    help = 'A B C or D')
parser.add_argument('--width', type = float, required = False, default = None, help = "fixed width")
parser.add_argument('--range', type = str, required = False, default = None, help = "range of pixel offset for fit, like 5:30 (positive)")

args = parser.parse_args()
log  = get_logger()


image = read_image(args.infile)

# use the PSF of the calibration data if none is given
if args.psf is None :
    cfinder = CalibFinder([image.meta])
    args.psf = cfinder.findfile("PSF")

log.info(" Using PSF {}".format(args.psf))
tset    = read_xytraceset(args.psf)

# same upper case amplifier name for the header keyword and for the trail side
amplifier = args.amplifier.upper()
xyslice = _parse_sec_keyword(image.meta["CCDSEC"+amplifier])

x,prof  = compute_fiber_cross_profile(image,tset,xyslice)

left = amplifier in ("B","D") # tail is to the left
nn=prof.size
nh=nn//2
# NOTE(review): prof is flipped here when not left, and it is flipped again
# inside compute_prof_model and fit_dark_trail which receive the same 'left'
# flag. Behavior is kept as is; confirm the intended orientation for
# amplifiers A and C.
if not left : prof = prof[::-1]
prof -= np.median(prof[-nh:])
model=compute_prof_model(x,prof,left)

# range of pixel offsets from the trace center used in the fit
begin=5
end=30
if args.range is not None :
    try :
        # fails with ValueError if there are not exactly two integers
        begin,end = (int(v) for v in args.range.split(":"))
    except ValueError :
        log.error("cannot interpret range '{}' , expect 'begin:end'".format(args.range))
        sys.exit(1)
# convert the positive offsets into indices on the left side of the profile
flipped_begin=max(0,-end+prof.size//2)
flipped_end=max(0,-begin+prof.size//2)
begin=flipped_begin
end=flipped_end
log.info("range of indices is {}:{}".format(begin,end))

if args.width :
    # fixed width, only the amplitude is fitted
    width =args.width
    trail,amplitude = compute_dark_trail(model,width=width,data_profile=prof,begin=begin,end=end)
else :
    amplitude,width = fit_dark_trail(x,prof,left,begin=begin,end=end)
    trail,_ = compute_dark_trail(model,width=width,data_profile=prof,begin=begin,end=end)

log.info("Fitted tail amplitude={:6.5f} width={:3.2f} pix".format(amplitude,width))



if args.plot :
    plt.figure("cross-dispersion-profile")
    plt.plot(x,prof,label="data profile")

    # the trail computed above is reused, it was obtained with the same inputs
    if not left :
        trail = trail[::-1]
        model = model[::-1]

    plt.figure("cross-dispersion-profile")
    plt.plot(x,model,":",label="model (symmetrized data)")
    plt.plot(x,model+trail,"--",label="model + dark trail")
    plt.plot(x[begin:end],(model+trail)[begin:end],":",label="model + dark trail,fit range")
    plt.xlabel("pixels from trace center")
    plt.ylabel("cross dispersion profile")
    plt.grid()
    plt.legend()

    plt.figure("only-trail-profile")
    plt.plot(x[begin:end],(prof-model)[begin:end],"-",label="data-mirror(data)")
    plt.plot(x,trail,":")
    plt.plot(x[begin:end],trail[begin:end],label="trail model")

    plt.show()


    
