#!python

import argparse
import numpy as np

from desispec.io import read_xytraceset,write_xytraceset
from desispec.xytraceset import XYTraceSet
from desispec.util import parse_fibers

def extrapolate(x,xo,yo) :
    """Interpolate the samples (xo,yo) at x, extending linearly past both ends.

    Inside [xo[0], xo[-1]] the result is that of ``np.interp``. Below xo[0]
    the straight line through the first two samples is used, above xo[-1]
    the straight line through the last two samples.

    Args:
        x: scalar or array-like, abscissae where the value is wanted.
        xo: 1D array-like of sample abscissae (handed to ``np.interp``,
            which expects them in increasing order).
        yo: 1D array-like of sample values, same length as xo.

    Returns:
        numpy array with the shape of x, or a numpy scalar if x is a scalar.

    Raises:
        ValueError: if some x lies outside [xo[0], xo[-1]] while fewer than
            two samples are available to define a slope (``np.interp``
            itself rejects an empty xo).
    """
    xo = np.asarray(xo)
    yo = np.asarray(yo)

    # Work on an array view of x so that boolean masks can be used even when
    # the caller passes a plain number or a list.
    scalar_input = (np.ndim(x) == 0)
    xq = np.atleast_1d(x)

    y = np.interp(xq,xo,yo)

    below = xq<xo[0]
    above = xq>xo[-1]

    if (below.any() or above.any()) and xo.size<2 :
        raise ValueError("extrapolation requires at least 2 sample points, got {}".format(xo.size))

    if below.any() :
        # np.interp clamps to yo[0] here; replace with the linear extension
        y[below] = yo[0] + (xq[below]-xo[0])*(yo[1]-yo[0])/(xo[1]-xo[0])
    if above.any() :
        # np.interp clamps to yo[-1] here; replace with the linear extension
        y[above] = yo[-1] + (xq[above]-xo[-1])*(yo[-2]-yo[-1])/(xo[-2]-xo[-1])

    if scalar_input :
        return y[0]
    return y


# Command line interface: three mandatory string options.
parser = argparse.ArgumentParser(
    description="Insert 'manually' a fiber in a psf boot file. This is to get the right count of 500 fiber traces for book keeping despite the fact some of them are broken")

# Keyword arguments shared by every option of this script.
required_str = dict(type=str, default=None, required=True)

parser.add_argument('-i', '--infile',
                    help='input psf fits file',
                    **required_str)
parser.add_argument('-o', '--outfile',
                    help='output psf fits file',
                    **required_str)
parser.add_argument('--fibers',
                    help='fiber indices (i, or i:j or i,j,k) (for i, a fiber will be inserted between index i-1 and i in the file)',
                    **required_str)

args = parser.parse_args()

# Read the input trace set and the list of fiber indices to insert.
tset = read_xytraceset(args.infile)
fibers = parse_fibers(args.fibers)

# Work on copies of the coefficient arrays so that the input trace set is
# left untouched. The x and y trace sets are always read.
coefs = {
    "xcoef": tset.x_vs_wave_traceset._coeff.copy(),
    "ycoef": tset.y_vs_wave_traceset._coeff.copy(),
}

# The sigma trace sets may be None; this is recorded as a None entry so that
# all four keys are always present.
for key, traceset in (("xsigcoef", tset.xsig_vs_wave_traceset),
                      ("ysigcoef", tset.ysig_vs_wave_traceset)):
    coefs[key] = None if traceset is None else traceset._coeff.copy()

# Middle of the wavelength range of the trace set, printed here and used for
# the position printouts further down.
wave = (tset.wavemin + tset.wavemax) / 2.
print("wave={:4.1f}A".format(wave))

# Insert the requested fibers one at a time. For each of them, every
# coefficient array gets one extra row at index `fiber`, whose values are
# linearly extrapolated (coefficient by coefficient) from existing neighbors.
for fiber in fibers :

    # Number of traces currently in the arrays; it grows by one for each
    # inserted fiber, hence it is re-read at every iteration.
    ntraces = coefs["xcoef"].shape[0]
    if fiber < 0 or fiber > ntraces :
        raise ValueError("cannot insert fiber {}: index must be in the range [0,{}]".format(fiber,ntraces))

    # The extrapolation needs two existing traces on the chosen side.
    nleft = fiber
    nright = ntraces-fiber

    copy_from_left = (fiber%25>12) # the fiber is on the right side of the bundle

    if (not copy_from_left) and nright<2 :
        # Not enough traces to the right (fiber inserted at the very end of
        # the file): use the left neighbors instead of failing.
        print("Fewer than 2 fibers to the right of fiber {}, using the left side instead".format(fiber))
        copy_from_left = True

    if (nleft if copy_from_left else nright)<2 :
        raise ValueError("cannot insert fiber {}: need at least 2 neighboring fibers to extrapolate from".format(fiber))

    if copy_from_left :
        print("Adding fiber {}, copying from the left".format(fiber))
    else :
        print("Adding fiber {}, copying from the right".format(fiber))

    for k in coefs :

        coef=coefs[k]
        if coef is None : continue

        # Same content as coef with an empty row opened at index `fiber`
        new_coef = np.zeros((coef.shape[0]+1,coef.shape[1]))
        new_coef[:fiber] = coef[:fiber]
        new_coef[fiber+1:] = coef[fiber:]

        # Row indices, in new_coef, of the traces that were already present
        other_fibers = np.arange(coef.shape[0])
        other_fibers[fiber:] += 1

        # Up to three nearest existing traces on the chosen side
        if copy_from_left :
            neighbors = other_fibers[other_fibers<fiber][-3:]
        else :
            neighbors = other_fibers[other_fibers>fiber][:3]

        for c in range(coef.shape[1]) :
            new_coef[fiber,c] = extrapolate(np.array([fiber,]),neighbors,new_coef[neighbors,c])[0]
            if k=="xcoef" and c == 0 :
                print("x({})={} -> x({})={}".format(neighbors,new_coef[neighbors,c],fiber,new_coef[fiber,c]))

        print(new_coef.shape)
        coefs[k] = new_coef


# Build the output trace set from the extended coefficient arrays; the
# wavelength range, npix_y and meta are taken from the input trace set.
ntset = XYTraceSet(coefs["xcoef"],coefs["ycoef"],tset.wavemin,tset.wavemax,
                   npix_y=tset.npix_y,xsigcoef=coefs["xsigcoef"],ysigcoef=coefs["ysigcoef"],
                   meta = tset.meta)


# Print the x,y positions at wavelength `wave` around each insertion index,
# first in the input trace set and then in the new one.
print("Old trace set with {} fibers:".format(tset.nspec))
for fiber in fibers :
    for f in [fiber-1,fiber]:
        if 0 <= f < tset.nspec :
            print("fiber #{:d} x={:4.1f} y={:4.1f}".format(f,tset.x_vs_wave(f,wave),tset.y_vs_wave(f,wave)))
print("New trace set with {} fibers:".format(ntset.nspec))
for fiber in fibers :
    for f in [fiber-1,fiber,fiber+1] :
        # Indices are bounded by the size of the NEW trace set: with the size
        # of the old one, the last valid trace of the new set is never shown.
        if 0 <= f < ntset.nspec :
            print("fiber #{:d} x={:4.1f} y={:4.1f}".format(f,ntset.x_vs_wave(f,wave),ntset.y_vs_wave(f,wave)))


write_xytraceset(args.outfile,ntset)
print("wrote",args.outfile)



 
