#!/usr/bin/env python

import argparse
import sys

import astropy.io.fits as pyfits
import fitsio
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage
from scipy.optimize import minimize
from scipy.signal import fftconvolve


# Radii at which the kernel amplitude is specified; the last node sets the
# half-width of the grids built below.
nodes = np.array([0, 5, 10, 20, 30, 40, 50, 100, 150, 200, 300])
hw = nodes[-1]
npix = 2 * hw + 1

# 1-D coordinate axis from -hw to +hw in unit steps.
x1d = np.linspace(-hw, hw, npix)
# Square grid in which every row is a copy of x1d.
x2d = np.tile(x1d, (npix, 1))
# Distance of every grid point from the grid centre.
r = np.sqrt(x2d * x2d + x2d.T * x2d.T)

def convolve1d(spectrum,vals) :
    """Convolve `spectrum` with a kernel interpolated from the node values.

    The kernel amplitude is given by `vals` at the module-level `nodes`
    radii, linearly interpolated in between and set to 0 beyond the last
    node.

    A 1-D `spectrum` is convolved with a 1-D kernel sampled on |x1d|.
    Previously the kernel was always sampled on the 2-D radius grid `r`,
    which made `fftconvolve` reject 1-D input because both operands must
    have the same number of dimensions.  Inputs with any other number of
    dimensions keep the original behaviour (kernel sampled on `r`).

    Args:
        spectrum: array to convolve.
        vals: kernel amplitude at each entry of the module-level `nodes`.

    Returns:
        The convolution result, same shape as `spectrum` (mode="same").
    """
    if np.ndim(spectrum) == 1 :
        # np.interp has no notion of symmetry, so fold the axis onto r >= 0.
        kernel = np.interp(np.abs(x1d),nodes,vals,right=0)
    else :
        kernel = np.interp(r,nodes,vals,right=0)
    return fftconvolve(spectrum,kernel,mode="same")
    
def convolve(spectrum,nodes,vals) :
    """Convolve `spectrum` with the 1-D projection of a radial kernel.

    The kernel is evaluated on the module-level radius grid `r` by linear
    interpolation of `vals` over `nodes` (0 beyond the last node), then
    summed along axis 1 to obtain a 1-D kernel.

    Note that `r` is built from the module-level `nodes`, not from the
    `nodes` argument, so its extent does not follow the argument.

    Args:
        spectrum: 1-D array to convolve.
        nodes: radii at which `vals` is specified.
        vals: kernel amplitude at each node.

    Returns:
        The convolved array, same shape as `spectrum` (mode="same").
    """
    radial_kernel = np.interp(r, nodes, vals, right=0)
    projected_kernel = radial_kernel.sum(axis=1)
    return fftconvolve(spectrum, projected_kernel, mode="same")
    

def func(params, nodes, input_spectrum, convolved_spectrum):
    """Objective for the kernel fit: sum of squared residuals.

    Args:
        params: kernel amplitude at each node (the quantities being fitted).
        nodes: radii at which `params` is specified.
        input_spectrum: array that is convolved with the kernel.
        convolved_spectrum: array the convolved model is compared against.

    Returns:
        Sum of squared differences between model and `convolved_spectrum`,
        ignoring the first and last 200 samples of each array.
    """
    core = slice(200, -200)
    predicted = convolve(input_spectrum, nodes, params)
    residual = predicted[core] - convolved_spectrum[core]
    return np.sum(residual ** 2)
   
# Command-line interface: one required input image and an optional plot switch.
parser = argparse.ArgumentParser(
    description="Compute scattered light convolution kernel using preprocessed arc lamp image",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    '-i', '--infile',
    type=str,
    default=None,
    required=True,
    help='path of preprocessed arc lamp image fits files',
)
parser.add_argument(
    '--plot',
    action='store_true',
    help='plot',
)

# Parsed at module level: the statements that follow read `args` directly.
args = parser.parse_args()

# Read the image and the camera identifier stored in its header.
filename = args.infile
img = fitsio.read(filename)
head = fitsio.read_header(filename)
cam = head["CAMERA"].strip().lower()

# Average columns 1850..1999 to get one value per image row.
prof = np.mean(img[:, 1850:2000], axis=1)

# Running median over 50 samples; mode='constant' pads the edges with the
# filter's constant fill value (0 by default), so the ends are biased low.
# Called through scipy.ndimage directly: the scipy.ndimage.filters namespace
# is deprecated.
mprof = scipy.ndimage.median_filter(prof, 50, mode='constant')

# Profile with its running median removed.  The fit further down uses
# `lines` as the input of the convolution and `mprof` as the target.
lines = prof - mprof



# Fit the kernel amplitude at every node, starting from an all-zero kernel.
params = np.zeros(nodes.shape)
fit = minimize(func, params, args=(nodes, lines, mprof))
if not fit.success :
    # Best effort: fit.x is still used below, but report that the optimizer
    # did not declare convergence instead of hiding it.
    sys.stderr.write("WARNING: kernel fit did not converge: {}\n".format(fit.message))
params = fit.x

# Zero the first non-positive value and everything after it.
# Assign instead of multiplying by 0: the product would leave -0.0 for
# negative entries and NaN for non-finite ones.
nonpositive = np.where(params <= 0)[0]
if nonpositive.size > 0 :
    params[nonpositive[0]:] = 0.
# The last node is always forced to zero.
params[-1] = 0.

# Print the kernel, normalized to its value at the first node, as a line of
# Python source.
norm = params[0]
if norm > 0 :
    # Adding 0.0 turns an IEEE negative zero into +0.0, so truncated entries
    # are not printed as "-0.0000"; every other value is unchanged.
    scaled = params / norm + 0.0
else :
    # The first node is zero (or not a number): dividing would print "nan"
    # for every entry, so emit an explicit all-zero kernel and say why.
    sys.stderr.write("WARNING: first kernel value is not positive; cannot normalize, writing zeros\n")
    scaled = [0.0] * len(params)

line = "params['{}']=np.array([".format(cam)
line += "".join("{:5.4f},".format(value) for value in scaled)
line += "])"
print(line)


if args.plot :

    # First figure: the row profile, the model obtained by convolving the
    # line profile with the fitted kernel, and their difference.  Every
    # curve is labelled so the legend identifies all three.
    model = convolve(lines,nodes,params)
    plt.figure()
    plt.plot(prof,label="profile")
    plt.plot(model,label="model")
    plt.plot(prof-model,label="profile - model")
    plt.legend()

    # Second figure: fitted kernel amplitude at each node.
    plt.figure()
    plt.plot(nodes,params,"o-")

    plt.show()
