#!/usr/bin/env python


import sys
import argparse
import matplotlib.pyplot as plt
import numpy as np
#import fitsio
from desiutil.log import get_logger
from desispec.io import read_spectra, read_table
from desispec.interpolation import resample_flux
from astropy.table import Table
import redrock.templates

# Command-line interface. Options whose help mentions redrock only have an
# effect when --redrock is given.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                 description="Display spectra, looping over targets if targetid not set, and optionally show best fit from redrock"
)
# nargs="+" (rather than "*") so that a bare "-i" with no file is rejected
# by argparse instead of failing later on an empty list of spectra.
parser.add_argument('-i','--infile', type = str, default = None, required = True, nargs="+",
                    help = 'path to spectra file(s)')
parser.add_argument('-t','--targetid', type = int, default = None, required = False,
                    help = 'plot specific targetid')
parser.add_argument('--rebin',type = int, default = None, required = False,
                    help = 'rebin')
parser.add_argument('--redrock',type = str, default = None, required = False,
                    help = 'redrock file')
parser.add_argument('--spectype',type = str, default = None, required = False,
                    help = 'spectype to select')
parser.add_argument('--ylim', type=float, default=None, required=False, nargs=2,
                    help = 'ymin ymax for plot')
parser.add_argument('--title', type=str, default=None, required=False,
                    help = 'plot title')
parser.add_argument('--rest-frame', action='store_true',
                    help = 'show rest-frame wavelength')
parser.add_argument('--errors', action='store_true',
                    help = 'show error bars')
parser.add_argument('--only-valid', action='store_true',
                    help = 'show only spectra with FIBERSTATUS (or COADD_FIBERSTATUS) equal to 0')
parser.add_argument('--zrange', type=float, default=None, required=False, nargs=2,
                    help = 'zmin zmax (if with redrock)')


# Parse the command line and get the shared logger.
args = parser.parse_args()

log = get_logger()

if args.redrock is not None:
    #- Index every available redrock template by (template_type, sub_type)
    #- so the best-fit model of a target can be looked up later.
    templates = {}
    for template_path in redrock.templates.find_templates():
        tx = redrock.templates.Template(template_path)
        templates[(tx.template_type, tx.sub_type)] = tx
elif args.rest_frame:
    #- The redshift comes from the redrock file; without it the wavelength
    #- axis cannot be de-redshifted, so the option is switched off.
    args.rest_frame = False
    print("cannot show rest-frame wavelength without a redrock file")

# Read every input file; with --only-valid keep only the rows whose fiber
# status column is 0.
spectra = []
for filename in args.infile :
    spec=read_spectra(filename)
    if args.only_valid :
        colnames = spec.fibermap.dtype.names
        if "FIBERSTATUS" in colnames :
            status_column = "FIBERSTATUS"
        elif "COADD_FIBERSTATUS" in colnames :
            status_column = "COADD_FIBERSTATUS"
        else :
            # Without a status column the request cannot be honoured. Stop
            # here with a clear message instead of failing on an undefined
            # selection mask a line later.
            log.error("cannot find FIBERSTATUS nor COADD_FIBERSTATUS in fibermap")
            sys.exit(1)
        spec=spec[spec.fibermap[status_column]==0]
    spectra.append(spec)

# Build the list of TARGETIDs to display.
# An explicit --targetid takes precedence over every other selection.
targetids=None
if args.targetid is not None :
    targetids=[args.targetid,]

# Redshift table from the redrock file: the "REDSHIFTS" extension is tried
# first, with "ZBEST" as fallback when it cannot be read.
redshifts=None
if  args.redrock is not None :
    try:
        redshifts=read_table(args.redrock, "REDSHIFTS")
    except (KeyError, OSError):
        redshifts=read_table(args.redrock, "ZBEST")

# The redshift table and the fibermap are matched on TARGETID, never on row
# position: nothing guarantees that both tables have the same length or the
# same row order.
if ( targetids is None ) and ( redshifts is not None ) and ( args.spectype is not None ):
    selection = (redshifts["SPECTYPE"]==args.spectype)&(redshifts["ZWARN"]==0)
    targetids = np.intersect1d(redshifts["TARGETID"][selection],
                               spectra[0].fibermap["TARGETID"])

if targetids is None :
    targetids=np.unique(spectra[0].fibermap["TARGETID"])

if redshifts is not None and args.zrange is not None :
    zmin, zmax = args.zrange
    in_zrange = (redshifts["Z"]>=zmin)&(redshifts["Z"]<=zmax)
    # Keep the targets already selected that are present in the first
    # spectra file and whose redshift is inside [zmin, zmax].
    candidates = np.intersect1d(redshifts["TARGETID"][in_zrange],
                                spectra[0].fibermap["TARGETID"])
    targetids = np.intersect1d(targetids, candidates)


# Rest-frame wavelengths of the spectral lines marked on the plot, in the
# units of the spectra wavelength grid (the axis is labelled "[A]").
# Only the part of the key before "-" is used as the text label, so
# 'OIII-a' and 'OIII-b' are both drawn as "OIII".
lines = {
    'Ha': 6562.8,
    'Hb': 4862.68,
    'Hg': 4340.464,
    'Hd': 4101.734,
    'OIII-b': 5006.843,
    'OIII-a': 4958.911,
    'MgII': 2799.49,
    'OII': 3728,
    'CIII': 1909.0,
    'CIV': 1549.06,
    'SiIV': 1393.76018,
    'LYA': 1215.67,
    'LYB': 1025.72,
}



# Main loop: one figure per selected target. Each figure shows the target's
# spectra from every input file and, when a redrock file was given, the
# best-fit model with the positions of the lines listed in `lines`.
if len(targetids)==0 :
    log.warning("no target selected, nothing to plot")

for tid in targetids :
    summary="TARGETID={}".format(tid)

    # Best-fit model per band; left empty when no model can be built.
    model_flux=dict()
    zval=None
    if redshifts is not None :
        zrows=np.where(redshifts["TARGETID"]==tid)[0]
        if zrows.size==0 :
            # No redshift for this target: skip it rather than fail on an
            # empty index array.
            log.warning("TARGETID {} not in redrock file, skipping".format(tid))
            continue
        j=zrows[0]
        zval=redshifts["Z"][j]
        summary += " Z={} SPECTYPE={} ZWARN={}".format(zval,redshifts["SPECTYPE"][j],redshifts["ZWARN"][j])

        # The model is built on the wavelength grid of the first spectra
        # file and uses that file's resolution for the target, so the
        # target has to be present there.
        rows0=np.where(spectra[0].fibermap["TARGETID"]==tid)[0]
        template_key=(redshifts['SPECTYPE'][j], redshifts['SUBTYPE'][j])
        if rows0.size==0 :
            log.warning("TARGETID {} not in first spectra file, cannot show model".format(tid))
        elif template_key not in templates :
            log.warning("no redrock template for {}, cannot show model".format(template_key))
        else :
            tx = templates[template_key]
            # The template combination does not depend on the band, so it
            # is computed once per target.
            coeff = redshifts['COEFF'][j][0:tx.nbasis]
            model = tx.flux.T.dot(coeff).T
            k=rows0[0]
            for band in spectra[0].bands:
                mx = resample_flux(spectra[0].wave[band], tx.wave*(1+zval), model)
                model_flux[band] = spectra[0].R[band][k].dot(mx)

    fig=plt.figure(figsize=[10,6])
    ax = fig.add_subplot(111)
    print(summary)

    # Factor applied to all plotted wavelengths (1 = observed frame).
    wavescale=1.
    if args.rest_frame and zval is not None :
        wavescale = 1./(1+zval)

    fiber=-1
    something_to_show = False
    for spec in spectra :
        jj=np.where(spec.fibermap["TARGETID"]==tid)[0]
        if len(jj)==0 :
            log.warning("TARGETID {} not in spectra".format(tid))
            continue
        something_to_show = True
        if "FIBER" in spec.fibermap.dtype.names :
            fiber=spec.fibermap["FIBER"][jj[0]]
        else :
            log.warning("no FIBER column in spectra file FIBERMAP")
            fiber=-1

        for j in jj :
            for b in spec.bands :

                # Inverse variance with masked pixels set to zero.
                good_ivar=spec.ivar[b][j]*(spec.mask[b][j]==0)
                # Pixels with ivar > 1/100**2; a band with fewer than 10 of
                # them is not plotted.
                i=np.where(good_ivar>1./100.**2)[0]
                if i.size<10 : continue
                if args.rebin is not None and args.rebin>0:
                    rwave=np.linspace(spec.wave[b][0],spec.wave[b][-1],spec.wave[b].size//args.rebin)
                    rflux,rivar = resample_flux(rwave,spec.wave[b],spec.flux[b][j],ivar=good_ivar)
                else:
                    rwave = spec.wave[b][i]
                    rflux = spec.flux[b][j, i]
                    rivar = spec.ivar[b][j, i]
                if args.errors:
                    # Rebinned ivar can be zero; the resulting infinite
                    # error is kept but the numpy warning is silenced.
                    with np.errstate(divide="ignore"):
                        sigma = 1./np.sqrt(rivar)
                    plt.fill_between(wavescale*rwave, rflux-sigma,
                                     rflux+sigma, alpha=0.5)
                    curve, = plt.plot(wavescale*rwave, rflux)
                    # Error curve in the colour of the spectrum, drawn with
                    # a constant offset of -2 in flux.
                    plt.plot(wavescale*rwave, sigma-2,
                             color=curve.get_color(), linestyle='--')
                else:
                    plt.plot(wavescale*rwave, rflux)

        print(spec.fibermap[jj])

    if not something_to_show :
        log.error("no data to show")
        sys.exit(1)
    if model_flux :
        for band in spectra[0].bands:
            band_wave=spectra[0].wave[band]
            plt.plot(wavescale*band_wave,model_flux[band],"-",alpha=0.6)
            for elem, rest_wave in lines.items() :
                obs_wave=(1+zval)*rest_wave
                if obs_wave>band_wave[0] and obs_wave<band_wave[-1] :
                    plt.axvline(wavescale*obs_wave,color="red",linestyle="--",alpha=0.4)
                    # Label placed slightly above the model at the line position.
                    y=np.interp(wavescale*obs_wave,wavescale*band_wave,model_flux[band])
                    plt.text(wavescale*(obs_wave+60),y*1.1,elem.split("-")[0],color="red")
    if args.rest_frame :
        plt.xlabel("rest-frame wavelength [A]")
    else :
        plt.xlabel("wavelength [A]")
    plt.grid()
    props = dict(boxstyle='round', facecolor='yellow', alpha=0.2)
    bla="TID = {}".format(tid)
    bla+="\nFIBER = {}".format(fiber)
    if zval is not None :
        bla+="\nZ  = {:4.3f}".format(zval)
    plt.text(0.9,0.9,bla,fontsize=12, bbox=props,transform = ax.transAxes,verticalalignment='top', horizontalalignment='right')
    if args.ylim is not None:
        plt.ylim(args.ylim[0], args.ylim[1])

    if args.title is not None:
        plt.title(args.title)

    plt.tight_layout()
    plt.show()
