#!/usr/bin/env python

import numpy as np
import argparse
import matplotlib.pyplot as plt
import fitsio

from desispec.io import read_xytraceset
from desispec.util import parse_fibers

def u(wave,wavemin,wavemax) :
    """Linearly rescale `wave` so that `wavemin` maps to -1 and `wavemax` to +1.

    Works on scalars as well as numpy arrays, since only arithmetic
    operators are applied to `wave`.
    """
    offset = wave-wavemin
    span = wavemax-wavemin
    return 2.*offset/span-1.

#- Command line interface. Only the input PSF/trace file is mandatory; all
#- other options select what is drawn or compared. Optional arguments rely on
#- argparse's own defaults (default=None, required=False; False for the
#- store_true flags), so those keywords are not repeated.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-i', '--inpsf', type=str, required=True,
                    help='path to psf files')
parser.add_argument('--fibers', type=str,
                    help='defines from_to which fiber to work on. (ex: --fibers=50:60,4 means that only fibers 4, and fibers from 50 to 60 (excluded) will be plotted)')
parser.add_argument('--image', type=str,
                    help='overplot traces on image')
parser.add_argument('--mask', action='store_true',
                    help='Use image mask when plotting')
parser.add_argument('--lines', type=str,
                    help='comma separated list of lines')
parser.add_argument('--vmin', type=float,
                    help='min value for image display')
parser.add_argument('--vmax', type=float,
                    help='max value for image display')
parser.add_argument('--zscale', action='store_true',
                    help='Use IRAF/ds9-like zscale limits for image')
parser.add_argument('--other-psf', type=str,
                    help='other psf to compare with')
parser.add_argument('--wavelength', type=float,
                    help='compare traces for this wavelength (default is central wavelength), used only along with --other-psf option')

args = parser.parse_args()



#- Optional list of values (parsed as floats from the comma separated
#- --lines string) at which a curve across fibers is drawn later on.
lines=None
if args.lines :
    lines=[float(tmp) for tmp in args.lines.split(",")]
    print("lines=",lines)

#- Trace set read from the input PSF file; provides nspec, wavemin, wavemax
#- and the x_vs_wave / y_vs_wave evaluators used below.
tset = read_xytraceset(args.inpsf)

if args.fibers is not None :
    fibers = parse_fibers(args.fibers)
else :
    fibers = np.arange(tset.nspec)

#- shift fibers to [0,500) so that they can be interpreted as indices
fibers = fibers%500

#- Wavelength grid on which each trace is sampled for plotting
nw=50
wave = np.linspace(tset.wavemin,tset.wavemax,nw)

print(wave)
print(fibers)

#- NOTE: no x/y coordinate buffers are allocated here; the trace coordinates
#- are computed per fiber right where they are plotted.




plt.figure("traces")

#- Traces are drawn in light gray when overplotted on an image; otherwise
#- matplotlib picks the colors (color=None).
color = "lightgray" if args.image is not None else None

if args.image is not None :
    img=fitsio.read(args.image)

    if args.zscale:
        from astropy.visualization import ZScaleInterval
        vmin, vmax = ZScaleInterval().get_limits(img)
    else:
        #- Default upper limit: start at 1000 and repeatedly replace it by the
        #- median of the pixels above the current limit. Stop as soon as no
        #- pixel is above the limit, because the median of an empty selection
        #- is NaN and would leave the display range undefined.
        vmin=0
        vmax=1000
        for _ in range(5) :
            bright=img[img>vmax]
            if bright.size==0 :
                break
            vmax=np.median(bright)
        #- Explicit command line limits take precedence over the defaults
        if args.vmin is not None :
            vmin=args.vmin
        if args.vmax is not None :
            vmax=args.vmax

    #- Apply mask -> NaN only after the display limits have been computed, so
    #- that masked pixels do not affect the medians above.
    #- Use NaN instead of 0 so that masked pixels aren't colored
    #- like normal pixels
    if args.mask:
        mask = fitsio.read(args.image, 'MASK')
        #- NaN cannot be stored in an integer array: promote to float first
        if not np.issubdtype(img.dtype, np.floating) :
            img = img.astype(float)
        img[mask!=0] = np.nan

    plt.imshow(img,origin="lower",vmin=vmin,vmax=vmax,aspect="auto")

#- One curve per fiber, sampled on the wavelength grid
for fiber in fibers :
    x = tset.x_vs_wave(fiber,wave)
    y = tset.y_vs_wave(fiber,wave)
    plt.plot(x,y,color=color)


#- One curve per requested line, connecting its position in all fibers
if lines is not None :
    for line in lines :
        xl=np.zeros(fibers.size)
        yl=np.zeros(fibers.size)
        for i,fiber in enumerate(fibers) :
            xl[i] = tset.x_vs_wave(fiber,line)
            yl[i] = tset.y_vs_wave(fiber,line)
        plt.plot(xl,yl,color=color)

plt.xlabel("xccd")
plt.ylabel("yccd")
plt.tight_layout()



#- Optional comparison with a second trace set: offsets versus fiber at one
#- wavelength ("delta-x" figure) and y offset versus y along one fiber
#- ("delta-y" figure).
if args.other_psf is not None :

    #- Wavelength of the per-fiber comparison: user value, or the mean of
    #- the sampling grid
    if args.wavelength is not None :
        wave_to_compare = args.wavelength
    else :
        wave_to_compare = np.mean(wave)

    otset = read_xytraceset(args.other_psf)

    yy=np.zeros(len(fibers))
    dx=np.zeros(len(fibers))
    dy=np.zeros(len(fibers))
    for i,fiber in enumerate(fibers) :
        x = tset.x_vs_wave(fiber,wave_to_compare)
        y = tset.y_vs_wave(fiber,wave_to_compare)
        dx[i] = x-otset.x_vs_wave(fiber,wave_to_compare)
        dy[i] = y-otset.y_vs_wave(fiber,wave_to_compare)
        yy[i] = y

    plt.figure("delta-x")
    plt.subplot(111,title="lambda={:d}A <y>={:d}".format(int(wave_to_compare),int(np.mean(yy))))
    plt.plot(fibers,dx,"-",label="dx")
    plt.plot(fibers,dy,"-",label="dy")
    plt.xlabel("fiber")
    plt.ylabel("delta X")
    plt.legend()
    plt.tight_layout()

    plt.figure("delta-y")
    nw=50
    waves = np.linspace(tset.wavemin,tset.wavemax,nw)
    #- Use the (lower) median of the selected fibers so that the compared
    #- fiber is always one of the requested ones; for a contiguous range this
    #- is the same fiber as int(np.median(fibers)).
    fiber_to_compare = int(np.sort(fibers)[(len(fibers)-1)//2])
    print("fiber=",fiber_to_compare)
    #- Axes are created once, before plotting
    plt.subplot(111,title="fiber={:d}".format(int(fiber_to_compare)))
    #- Evaluate both traces on the whole wavelength grid in one call
    y_ref = tset.y_vs_wave(fiber_to_compare,waves)
    dy_ref = y_ref-otset.y_vs_wave(fiber_to_compare,waves)
    plt.plot(y_ref,dy_ref,"-",label="dy")
    plt.xlabel("Y")
    plt.ylabel("delta Y")
    plt.tight_layout()

plt.show()
