#!/usr/bin/env python


import sys,os
import argparse
import matplotlib.pyplot as plt
import numpy as np
import fitsio
from importlib import resources

from desispec.util import parse_fibers
from desispec.qproc.io import read_qframe
from desispec.io import read_fibermap,read_frame
from desispec.interpolation import resample_flux
from desispec.fluxcalibration import isStdStar

# Command line interface.
# ArgumentDefaultsHelpFormatter appends the default value to each help string.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Input files. nargs="+" (rather than "*") so that "-i" without any file is
# rejected by argparse instead of failing later with an obscure error.
parser.add_argument('-i', '--infile', type=str, default=None, required=True, nargs="+",
                    help='path to one or several frame fits files')

# Fiber selection.
parser.add_argument('--fibers', type=str, default=None,
                    help='defines from_to which fiber to work on. (ex: --fibers=50:60,4 means that only fibers 4, and fibers from 50 to 60 (excluded) will be plotted)')

# Spectrum plot options.
parser.add_argument('--legend', action='store_true', help="show legend")
parser.add_argument('--log', action='store_true',
                    help='log scale')
parser.add_argument('--batch', action='store_true',
                    help='batch mode (to save figure and exit)')
parser.add_argument('-o', '--outfile', type=str, default=None,
                    help='save figure in this file')
parser.add_argument('--ascii-spectrum', type=str, default=None,
                    help='also plot this ascii spectrum (first column=wavelength, second column=flux)')
parser.add_argument('--sky-spectrum', action='store_true',
                    help='also plot the sky spectrum')
parser.add_argument('--rebin', type=int, default=None)

# Focal plane view options.
parser.add_argument('--focal-plane', action='store_true', help='show focal plane view with median flux per fiber')
parser.add_argument('--norm', action='store_true', help='normalize for color scale of focal plane view')
parser.add_argument('--vmin', type=float, default=None, help='min value for color scale of focal plane view')
parser.add_argument('--vmax', type=float, default=None, help='max value for color scale of focal plane view')
parser.add_argument('--radial', action='store_true', help='show radial focal plane view with median flux per fiber (requires option --focal-plane)')
parser.add_argument('--wmin', type=float, default=None, help='min of wavelength range for mean flux of focal plane view (only valid from Frame, not QFrame)')
parser.add_argument('--wmax', type=float, default=None, help='max of wavelength range for mean flux of focal plane view (only valid from Frame, not QFrame)')

# Masking and target-type selection.
parser.add_argument('--no-mask', action='store_true', help='ignore mask')
parser.add_argument('--std-stars', action='store_true', help='show only the std stars')
parser.add_argument('--sky-fibers', action='store_true', help='show only the sky fibers')


# Read the command line options
args = parser.parse_args()

if args.focal_plane :
    # Focal plane view: per-file arrays are appended to these lists by the
    # loop over input files and stacked afterwards.
    x, y, z = [], [], []
    medflux = None
else :
    # Spectrum view: a single figure with a single set of axes
    fig = plt.figure()
    subplot = plt.subplot(1,1,1)

# Loop over the input files. For each file, either accumulate the focal plane
# coordinates and median flux of the selected fibers (--focal-plane), or plot
# the spectra of the selected fibers on the current axes.

# The fibers requested on the command line do not depend on the input file,
# so they are parsed only once.
if args.fibers is not None :
    requested_fibers = np.sort(parse_fibers(args.fibers))
else :
    requested_fibers = None

for filename in args.infile :

    print(filename)

    # frame or qframe: a 2D WAVELENGTH HDU is read as a QFrame, anything else as a Frame
    head=fitsio.read_header(filename,"WAVELENGTH")
    if head["NAXIS"]==2 :
        frame = read_qframe(filename)
    else :
        frame = read_frame(filename)

    # Fiber number of each row of the frame, in row order.
    # Without a fibermap the row index is used as the fiber number.
    fibermap = frame.fibermap
    if fibermap is not None :
        frame_fibers = np.asarray(fibermap["FIBER"])
    else :
        frame_fibers = np.arange(frame.flux.shape[0])

    # These options read fibermap columns, they cannot work without a fibermap.
    if fibermap is None and (args.std_stars or args.sky_fibers or args.focal_plane) :
        print("no fibermap in {}, cannot use --std-stars, --sky-fibers or --focal-plane, skipping this file".format(filename))
        continue

    # Fibers to show: --sky-fibers overrides --std-stars which overrides --fibers.
    fibers = requested_fibers
    if args.std_stars :
        selection = isStdStar(fibermap)
        fibers = np.sort(fibermap["FIBER"][selection])
        print("showing {} std stars fibers: {}".format(fibers.size,list(fibers)))
    if args.sky_fibers :
        selection = fibermap["OBJTYPE"]=="SKY"
        fibers = np.sort(fibermap["FIBER"][selection])
        print("showing {} sky fibers: {}".format(fibers.size,list(fibers)))
    if fibers is None :
        fibers = frame_fibers

    # Boolean mask over the rows of the frame
    selection = np.isin(frame_fibers,fibers)

    if args.focal_plane :

        # Pixel range [b,e) used for the median flux. By default 200 pixels are
        # dropped on each side to avoid the edge of the camera with variable
        # dichroic transmission. The number of pixels is taken from the flux
        # array because frame.wave is 2D for a QFrame.
        nwave = frame.flux.shape[1]
        b = 200
        e = nwave-200
        if frame.wave.ndim == 1 :
            if args.wmin is not None :
                above = np.flatnonzero(frame.wave>=args.wmin)
                b = above[0] if above.size>0 else nwave
            if args.wmax is not None :
                below = np.flatnonzero(frame.wave<=args.wmax)
                e = below[-1]+1 if below.size>0 else 0
        if b>=e :
            print("empty wavelength range for the median flux in {}, skipping this file".format(filename))
            continue

        x.append(fibermap["FIBERASSIGN_X"][selection])
        y.append(fibermap["FIBERASSIGN_Y"][selection])
        z.append(np.median(frame.flux[:,b:e],axis=1)[selection])
        continue

    if not np.any(selection) :
        print("empty selection")
        print("fibers are in the range [{},{}] (included)".format(np.min(frame_fibers),np.max(frame_fibers)))
        print("for convenience, I add X*500 to the fiber numbers")
        # map the requested fibers to the block of 500 fibers of this frame
        fibers = fibers%500+500*(np.min(frame_fibers)//500)
        print("will show fibers =",list(fibers))
        selection = np.isin(frame_fibers,fibers)
        if not np.any(selection) :
            print("still empty selection!")
            sys.exit(12)

    # Keep only the selected rows. The fibers shown are listed in the row order
    # of the frame so that row i of the frame is fiber shown_fibers[i], whatever
    # the order (or duplicates) of the requested list.
    frame = frame[selection]
    shown_fibers = frame_fibers[selection]
    if np.unique(fibers).size > shown_fibers.size :
        print("not all requested fibers are in frame")
        print("will show only {}".format(shown_fibers))
    fibers = shown_fibers

    if args.no_mask :
        # give a small non-zero ivar to masked pixels with null ivar so that
        # they pass the ivar>0 cut below, then clear the mask
        frame.ivar += 0.1*(frame.ivar==0)*(frame.mask>0)
        frame.mask*=0

    for i,fiber in enumerate(fibers) :

        good=(frame.ivar[i]>0)&(frame.mask[i]==0)
        if not np.any(good) :
            print("fiber {} has no valid flux value".format(fiber))
            continue

        # one wavelength array per fiber if frame.wave is 2D, a common one otherwise
        if frame.wave.ndim>1 :
            wave=frame.wave[i]
        else :
            wave=frame.wave

        label="fiber {:03d}".format(fiber)
        if len(args.infile) > 1 :
            label += " {}".format(os.path.basename(filename))

        if args.rebin is None :
            plt.plot(wave[good],frame.flux[i,good],label=label)
        else :
            # resample on a linear grid with args.rebin times fewer points
            rwave=np.linspace(wave[0],wave[-1],wave.size//args.rebin)
            rflux,rivar=resample_flux(rwave,wave,frame.flux[i],frame.ivar[i]*(frame.mask[i]==0))
            valid=(rivar>0)
            if np.any(valid) :
                plt.plot(rwave[valid],rflux[valid],label=label)
            else :
                print("fiber {} has no valid flux value after rebinning".format(fiber))

# Finalize the spectrum plot: optional overlay of an ascii spectrum (or of the
# sky spectrum shipped with desispec), legend, log scale, axis labels.
if not args.focal_plane :

    if args.sky_spectrum :
        # --sky-spectrum is a shortcut for --ascii-spectrum with the desispec sky spectrum
        args.ascii_spectrum = resources.files('desispec').joinpath('data/spec-sky.dat')

    if args.ascii_spectrum :
        if not plt.gca().get_lines() :
            # `frame` and `wave` are only defined once a fiber has been plotted,
            # and they are needed to scale and to trim the ascii spectrum
            print("no spectrum was plotted, cannot overlay {}".format(args.ascii_spectrum))
        else :
            tmp = np.loadtxt(args.ascii_spectrum)
            spec_wave=tmp[:,0]
            spec_flux=tmp[:,1]
            # scale to the maximum flux of the last frame
            spec_flux *= np.max(frame.flux)/np.max(spec_flux)
            # keep only the wavelength range of the last plotted fiber
            kk = (spec_wave>np.min(wave))&(spec_wave<np.max(wave))
            if np.sum(kk)==0 :
                print("no intersecting wavelength in file {} (first column interpreted as wavelength)".format(args.ascii_spectrum))
            else :
                # str() because args.ascii_spectrum is not a plain string with --sky-spectrum
                plt.plot(spec_wave[kk],spec_flux[kk],label=os.path.basename(str(args.ascii_spectrum)))

    if args.legend :
        plt.legend(loc="upper left",fontsize="small")
    if args.log :
        plt.yscale("log")

    plt.xlabel("Wavelength (A)")
    plt.ylabel("Flux")
    plt.grid()
    plt.tight_layout()


# Focal plane view: scatter plot of the fiber positions colored by the median
# flux accumulated in x, y, z by the loop over the input files.
if args.focal_plane :

    if len(x)==0 :
        # np.hstack fails with an obscure error on an empty list
        print("no focal plane data to show")
        sys.exit(12)

    x=np.hstack(x)
    y=np.hstack(y)
    z=np.hstack(z)

    # ignore fibers with a null median flux
    ii = (z!=0)
    if args.norm :
        # normalize to the median and drop fibers below half of it
        z /= np.median(z[ii])
        ii = (z>0.5)

    if not np.any(ii) :
        # nothing to plot, and the color scale below would be undefined
        print("no fiber with a valid median flux to show")
        sys.exit(12)

    fig = plt.figure("focal_plane",figsize=(6,5))

    # Default color scale: median +- 3 times a scatter estimate based on the
    # median absolute deviation, unless overridden with --vmin / --vmax.
    mz=np.median(z[ii])
    rmsz=1.4*np.median(np.abs(z[ii]-mz))
    if args.vmin is not None:
        vmin = args.vmin
    else :
        vmin=mz-3*rmsz
    if args.vmax is not None:
        vmax = args.vmax
    else :
        vmax=mz+3*rmsz

    plt.scatter(x[ii],y[ii],c=z[ii],vmin=vmin,vmax=vmax,s=30)
    plt.axis("off")
    plt.colorbar()
    plt.tight_layout()

# Save the main figure (spectra or focal plane view) if an output file was given
save_requested = args.outfile is not None
if save_requested :
    fig.savefig(args.outfile)
    print("wrote {}".format(args.outfile))

# Optional second figure: median flux versus distance to the focal plane center
if args.focal_plane :
    if args.radial :
        plt.figure("radial")
        radius = np.sqrt(np.square(x) + np.square(y))
        plt.plot(radius[ii], z[ii], ".", alpha=0.6)
        plt.ylim(vmin, vmax)

# Interactive display unless running in batch mode
if args.batch :
    pass
else :
    plt.show()
