#!python

import sys,os
import argparse
import astropy.io.fits as pyfits
from astropy.table import Table
import fitsio
import numpy as np
import glob
from scipy.ndimage.filters import median_filter
from scipy.signal import fftconvolve,correlate


import matplotlib.pyplot as plt

from desiutil.log import get_logger
from desispec.io import specprod_root
from desispec.io import findfile,read_fiberflat
from desispec.fiberflat_vs_humidity import fit_wave_of_dip

def parse(options=None):
    """
    Parse the command line arguments of the script.

    Args:
       options: optional list of str to parse instead of the process command line.

    Returns:
       argparse.Namespace with attributes years, reference_camera, prod, outdir,
       spectrographs and first_night. years and spectrographs are returned as the
       raw strings given on the command line.
    """
    parser = argparse.ArgumentParser(
                description="Generates fiber flat-field templates as a function of humidity")
    parser.add_argument("-y", "--years", type = str, default = None, required = True,
                        help = "coma separated list of years of data to scan")
    parser.add_argument("--reference-camera", type = str, default = "b0", required = False,
                        help = "reference camera to index humidity based on measured wavelength shift")
    # Not required: the help text documents a default and main() substitutes
    # specprod_root() when the option is omitted (args.prod is None).
    parser.add_argument("-p", "--prod", type = str, default = None, required = False,
                        help = "Path to input reduction, e.g. /global/cfs/cdirs/desi/spectro/redux/daily/,  or simply prod version, like daily, but requires env. variable DESI_SPECTRO_REDUX. Default is $DESI_SPECTRO_REDUX/$SPECPROD.")
    parser.add_argument("-o", "--outdir", type = str, default = ".", required = False,
                        help = "output directory")
    parser.add_argument("-s", "--spectrographs", type = str, default = "0,1,2,3,4,5,6,7,8,9", required = False,
                        help = "comma separated list of spectographs, default is '0,1,2,3,4,5,6,7,8,9'")
    parser.add_argument("--first-night", type = int, default = None, required = False,
                        help = "specify first night (in case of change in the hardware)")

    # parse_args(None) reads sys.argv[1:], so one call covers both cases
    return parser.parse_args(options)

def define_template_bins(input_table_filename) :
    """
    Sort nightly flats and define bins for the templates based on the fit of the wavelength shift in one spectrograph.

    Bin edges are integers. The first edge is min(int(smallest DIPWAVE),4300) and the
    last edge is at least 4500. Each bin is grown in steps of 2 until it contains at
    least one night. Bins are half-open intervals [edge[i],edge[i+1]) so a night
    belongs to exactly one bin.

    Args:
       input_table_filename: input table filename with required columns NIGHT and DIPWAVE.

    Returns:
      wavebins: wavelength (of the throughput dip) bin edges (1D array) , nights_in_bins: list of list of nights

    Raises:
      ValueError: if the table contains no finite DIPWAVE value.
    """

    log=get_logger()
    log.info("Define the template bins")

    # read input table
    input_table = Table.read(input_table_filename)
    dipwave = np.asarray(input_table["DIPWAVE"],dtype=float)

    # sort by wavelength, ignoring non-finite values: a NaN would sort last and
    # make both comparisons against wave[-1] below always False, so the loop
    # defining the bins would never terminate
    wave = np.sort(dipwave[np.isfinite(dipwave)])
    if wave.size == 0 :
        raise ValueError(f"no finite DIPWAVE value in {input_table_filename}")

    # define bins
    dwave=2 # step used to enlarge a bin
    nmin=1  # minimal number of nights in a bin
    b=min(int(wave[0]),4300)
    e=int(wave[0])+1
    wavebins=[b]

    while True :
        # enlarge the current bin [b,e) until it holds enough nights
        # or until its end is beyond the largest wavelength
        while np.sum((wave>=b)&(wave<e))<nmin and e<=wave[-1]+0.001 :
            e += dwave
        wavebins.append(e)
        b=e
        if e>wave[-1] : break
    wavebins[-1]=max(wavebins[-1],4500)
    wavebins = np.array(wavebins)
    log.info(f"dip wavelenth bins = {wavebins}")
    nbins=wavebins.size-1

    nights_in_bins = []
    for b in range(nbins) :
        # half-open interval, as in the definition of the bins above; the last
        # edge is strictly larger than all values so no night is lost
        in_bin=(dipwave>=wavebins[b])&(dipwave<wavebins[b+1])
        nights_in_bins.append(list(input_table["NIGHT"][in_bin]))
        log.info("bin {} : {}-{} nights {} {}".format(b,wavebins[b],wavebins[b+1],len(nights_in_bins[b]),list(nights_in_bins[b])))

    return wavebins, nights_in_bins

def compute_humidity_table(camera,years,specprod_dir,fit_dip_wavelength=False,first_night=None) :
    """
    Look for nightly flatfields in a series of years for a camera and record the night and
    humidity values. Optionally computes the wavelength of the absorption feature.

    The humidity of a night is the mean of the BHUMID values found, for the
    spectrograph of this camera, in the SPECTCONS HDU of the raw files of the
    exposures listed with OBSTYPE 'flat' in the exposure table of the night.
    Nights without exposure table or without valid humidity value are skipped.

    Args:
       camera: str, camera, like 'b0'
       years: list(int), years
       specprod_dir: str, full path to a production directory

    Optional:
       fit_dip_wavelength: fit the wavelength of the absorption feature.
       first_night: first night to consider

    Returns:
      astropy.Table with columns 'NIGHT', 'HUMIDITY' and optionally 'DIPWAVE'
    """

    log = get_logger()
    log.info(f"Computing humidity table for {camera}")

    # last character of the camera name is the spectrograph number
    spectrograph=int(camera[-1])

    vals=dict()
    vals["NIGHT"]=[]
    vals["HUMIDITY"]=[]
    if fit_dip_wavelength :
        vals["DIPWAVE"]=[]

    # collect the nightly fiberflats of all years first. The loop on files below
    # must be outside of the loop on years, otherwise the files of the first
    # years are processed several times and their nights duplicated in the table
    flat_filenames = []
    for year in years :
        flat_filenames.extend(sorted(glob.glob(f"{specprod_dir}/calibnight/{year}*/fiberflatnight-{camera}-*.fits")))

    # loop on nights with fiberflat
    for flat_filename in flat_filenames :
        log.info(f"reading {flat_filename}")
        head=fitsio.read_header(flat_filename)
        night=int(head['NIGHT'])

        if first_night is not None and night < first_night :
            log.warning(f"skip {night} < {first_night}")
            continue

        if camera == "b1" and night == 20220413 :
            log.warning("skip 20220413 for b1 because we tested a different mirror")
            continue

        # exposure tables are stored per month (YYYYMM)
        month=night//100
        et_filename=f"{specprod_dir}/exposure_tables/{month}/exposure_table_{night}.csv"
        if not os.path.isfile(et_filename) :
            log.error(f"missing {et_filename}")
            continue
        et=Table.read(et_filename)
        ii=(et['OBSTYPE']=='flat')
        expids=list(et['EXPID'][ii])

        humidity_vals=[]
        for expid in expids :
            desi_filename=findfile("raw",night=night,expid=expid)
            spectcons=fitsio.read(desi_filename,"SPECTCONS")
            ii=(spectcons["unit"]==spectrograph)
            if np.sum(ii)==0 :
                log.warning(f"no spectro {spectrograph} in {desi_filename}")
                continue
            if "BHUMID" not in spectcons.dtype.names :
                log.warning(f"no BHUMID column in hdu SPECTCONS of {desi_filename}")
                continue
            # use the first matching row: float() applied to the whole selection
            # fails if several rows match and relies on a deprecated array to
            # scalar conversion when only one row matches
            humidity=float(spectcons["BHUMID"][ii][0])
            if humidity<=0 or humidity>100 :
                log.warning(f"unphysical humidity value = {humidity}")
                continue
            humidity_vals.append(humidity)
        if len(humidity_vals)==0 :
            log.error(f"couldn't get humidity info for night {night}")
            continue

        humidity=np.mean(humidity_vals)
        vals["NIGHT"].append(night)
        vals["HUMIDITY"].append(humidity)

        if fit_dip_wavelength :
            ff=read_fiberflat(flat_filename)
            dipwave=fit_wave_of_dip(ff.wave,ff.fiberflat)
            vals["DIPWAVE"].append(dipwave)

    table=Table()
    for k in vals :
        table[k]=vals[k]

    return table


def main():
    """
    Entry point of the script.

    Steps:
      - parse the command line and resolve the production directory,
      - compute (or reuse if the file exists) the humidity table of the reference
        camera, including the fitted wavelength of the dip, and define the
        template bins from it,
      - for each requested spectrograph (blue camera only), compute (or reuse)
        its humidity table, then for each bin compute a median fiberflat
        template from the nights of the bin and write it in a per-bin file
        (existing per-bin files are not recomputed),
      - merge the valid per-bin files in one file per camera.
    """
    log=get_logger()

    args = parse()

    args.years = [ int(v) for v in args.years.split(",") ]

    # args.prod is either a full path or a production name
    if args.prod is None:
        args.prod = specprod_root()
    elif "/" not in args.prod :
        args.prod = specprod_root(args.prod)

    if not os.path.isdir(args.outdir):
        log.info("creating {}".format(args.outdir))
        os.makedirs(args.outdir, exist_ok=True)


    # reference table with columns NIGHT HUMIDITY DIPWAVE
    reference_humidity_table_filename=f"{args.outdir}/humidity_table_{args.reference_camera}.csv"
    if not os.path.isfile(reference_humidity_table_filename) :
        table = compute_humidity_table(args.reference_camera,args.years,args.prod,fit_dip_wavelength=True,first_night=args.first_night)
        table.write(reference_humidity_table_filename)
        log.info("wrote "+reference_humidity_table_filename)

    xbins , nights_in_bins = define_template_bins(reference_humidity_table_filename)
    nbins = xbins.size-1

    spectrographs = [ int(s) for s in args.spectrographs.split(',')]

    for spectro in spectrographs :
        cam=f"b{spectro}"

        humidity_table_filename=f"{args.outdir}/humidity_table_{cam}.csv"

        if not os.path.isfile(humidity_table_filename) :
            table = compute_humidity_table(cam,args.years,args.prod,fit_dip_wavelength=False,first_night=args.first_night)
            table.write(humidity_table_filename)
            log.info("wrote "+humidity_table_filename)

        humidity_table = Table.read(humidity_table_filename)
        ii = ~np.isnan(humidity_table["HUMIDITY"])
        humidity_table = humidity_table[ii]

        # wavelength array and list of good fibers are set once per camera,
        # from the first usable flat
        wave = None
        goodfibers= None

        # loop on bins
        # NOTE(review): nights_in_bins has nbins entries, so range(nbins-1) never
        # processes the last bin (here and in the merging loop below). Kept as is,
        # please confirm whether this is intended.
        for bin_index in range(nbins-1) :

            extname="HUM{:02d}".format(bin_index)
            ofilename=f"{args.outdir}/fiberflat-vs-humidity-{cam}-{extname}.fits"
            if os.path.isfile(ofilename) :
                log.info(f"skip existing {ofilename}")
                continue

            log.info("nights in bins = {}".format(list(nights_in_bins[bin_index])))
            selection = np.isin(humidity_table["NIGHT"],nights_in_bins[bin_index])
            if np.sum(selection)==0 :
                log.warning("bin={} nflats={} (skip)".format(bin_index,np.sum(selection)))
                continue

            nights_in_bin     = humidity_table["NIGHT"][selection]
            log.info("bin={} nflats={} nights={}".format(bin_index,np.sum(selection),list(nights_in_bin)))


            fiberflats_in_bin = []
            humidities_in_bin = []

            # loop on nights with fiberflat
            for night in nights_in_bin :
                flat_filename = f"{args.prod}/calibnight/{night}/fiberflatnight-{cam}-{night}.fits"

                if not os.path.isfile(flat_filename) :
                    log.warning(f"MISSING {flat_filename}")
                    continue
                log.info(f"Humidity bin #{bin_index} night={night} {flat_filename}")
                flat=read_fiberflat(flat_filename)

                if wave is None :
                    wave=flat.wave
                    # fixed list of good fibers
                    # NOTE(review): the previous version also computed
                    # flat.ivar*(flat.mask==0) here but never used it, the
                    # selection is based on the unmasked ivar
                    hasnan=np.sum(np.isnan(flat.fiberflat),axis=1)
                    usable=(hasnan==0)&(np.sum(flat.ivar,axis=1)>0)
                    medflat=np.median(flat.fiberflat,axis=1)
                    mmedflat=np.median(medflat[usable])
                    # also discard fibers with a median flat below half the typical one
                    goodfibers=np.where(usable&(medflat>0.5*mmedflat))[0]
                    if len(goodfibers)<300 :
                        # not enough good fibers, try again with the next flat
                        wave=None
                        continue

                # mean spectrum: median of the good fibers, after a first
                # normalization of each fiber by its median ratio to a first
                # estimate of the mean spectrum
                mflat=np.median(flat.fiberflat[goodfibers],axis=0)
                flats=[]
                for fiber in goodfibers :
                    norm=np.sum(flat.ivar[fiber]*mflat**2)
                    if norm<=0 :
                        continue
                    s=np.median(flat.fiberflat[fiber]/mflat)
                    flats.append(flat.fiberflat[fiber]/s)
                mflat=np.median(np.array(flats),axis=0)
                # divide all fibers by the mean spectrum (avoiding division by zero)
                flat.fiberflat /= (mflat+(mflat==0))

                fiberflats_in_bin.append(flat.fiberflat)
                humidities_in_bin.append(np.mean(humidity_table["HUMIDITY"][humidity_table["NIGHT"]==night]))

            if len(humidities_in_bin)<1 : continue

            log.info("now median in bin")
            hdulist=pyfits.HDUList([pyfits.PrimaryHDU(wave)])
            hdulist[0].header["EXTNAME"]="WAVELENGTH"
            hdulist[0].header["CAMERA"]=cam

            # array of indices night,fiber,wavelength
            fiberflats_in_bin = np.array(fiberflats_in_bin)
            # template of the bin, equal to 1 for fibers without good data
            fiberflat_in_bin  = np.ones((fiberflats_in_bin.shape[1],fiberflats_in_bin.shape[2]))
            for fiber in range(fiberflat_in_bin.shape[0]) :
                # median value of flat for this fiber across wavelength for each night
                medval=np.median(fiberflats_in_bin[:,fiber,:],axis=-1)
                # and across nights, selecting the ones with the largest value
                ii=np.argsort(medval)[-5:]
                medmedval=np.median(medval[ii])
                if medmedval<0.01 :
                    log.warning(f"median flat={medmedval}, no good data for fiber {fiber} in humidity bin {bin_index}")
                    continue
                # keep only measurements where median flat is stable
                good=np.abs(medval/medmedval-1)<0.1
                if np.sum(good)<1 :
                    log.warning(f"no good data for fiber {fiber} with median flats={medval} in humidity bin {bin_index}")
                    continue
                # normalize correction per fiber by median flat
                norm = np.median(fiberflats_in_bin[good,fiber])
                fiberflats_in_bin[good,fiber] /= norm
                fiberflat_in_bin[fiber] = np.median(fiberflats_in_bin[good,fiber],axis=0)

            # filtering across fibers
            fiberflat_in_bin = median_filter(fiberflat_in_bin,[5,1])

            hdulist.append(pyfits.ImageHDU(fiberflat_in_bin,name=extname))
            hdulist[extname].header["MEDHUM"]=np.median(humidities_in_bin)
            hdulist[extname].header["MINHUM"]=np.min(humidities_in_bin)
            hdulist[extname].header["MAXHUM"]=np.max(humidities_in_bin)

            # now measure the DIP WAVE
            if cam == "b4" :
                dipwave = 0 # no dip !
            else :
                dipwave = fit_wave_of_dip(wave,fiberflat_in_bin)
            hdulist[extname].header["DIPWAVE"]=dipwave
            hdulist.writeto(ofilename,overwrite=True)
            log.info(f"wrote {ofilename}")


        hdulist=None

        # merging
        # loop on bins
        for bin_index in range(nbins-1) :
            extname="HUM{:02d}".format(bin_index)
            filename=f"{args.outdir}/fiberflat-vs-humidity-{cam}-{extname}.fits"
            if not os.path.isfile(filename): continue
            log.info("adding "+filename)
            # the file is closed at the end of the with block, so HDUs are
            # copied in memory (and not memory mapped) before being merged
            with pyfits.open(filename,memmap=False) as h :
                data=h[extname].data
                # identify null data
                rms=np.std(data,axis=1)
                nzero=np.sum(rms<1e-6)
                if nzero > 0 :
                    log.warning("{} nzero={}".format(filename,nzero))
                    continue

                if hdulist is None :
                    # first valid file: keep all of its HDUs (wavelength and template)
                    hdulist=pyfits.HDUList([hdu.copy() for hdu in h])
                else :
                    hdulist.append(h[extname].copy())

        if hdulist is None :
            # no valid per-bin file for this camera, nothing to write
            log.warning(f"no valid humidity bin template for {cam}, nothing to merge")
            continue

        ofilename=f"{args.outdir}/fiberflat-vs-humidity-{cam}.fits"
        hdulist.writeto(ofilename,overwrite=True)
        log.info(f"wrote {ofilename}")

# Run main() only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
