#!/usr/bin/env python
#
# See top-level LICENSE.rst file for Copyright information
#
# -*- coding: utf-8 -*-

"""
This script averages the flux calibration for a DESI spectrograph camera.
"""

from desispec.io.fluxcalibration import read_flux_calibration,write_average_flux_calibration,read_average_flux_calibration
from desispec.averagefluxcalibration import AverageFluxCalib
from desiutil.log import get_logger
from desispec.io.util import replace_prefix

import argparse
import os
import os.path
import numpy as np
import sys
import fitsio
import scipy.interpolate
from importlib import resources
from desimodel.fastfiberacceptance import FastFiberAcceptance
import desimodel.io
from glob import glob
import yaml
import matplotlib.pyplot as plt


def get_ffracflux_wave(seeing, ffracflux, wave, fac_wave_power):
    """
    Compute the fiber acceptance for a given seeing, and normalize it at ffracflux at 6500A

    Args:
        seeing : in the r-band from the GFA [arcsec]
        ffracflux : FIBER_FRACFLUX from GFA
        wave : wavelength for output [A]
        fac_wave_power : shape of the wavelength-dependence of seeing (wave ** fac_wave_power)

    Returns:
        Fiber acceptance for seeing, normalized to ffracflux at 6500A,
        interpolated on wave

    Comment:
        GFA now provide ffrac for fiber_diameter=1.52", so no need to correct for different diameters
        we do the normalization using fac_wave, and then to be interpolated to wave

        The fiber diameters (physical [um] and average on-sky [arcsec]) are read
        from the desimodel parameters, and the fiber acceptance table is read from
        $DESIMODEL/data/throughput/galsim-fiber-acceptance.fits.
    """
    desi = desimodel.io.load_desiparams()
    fdiam_um = desi['fibers']['diameter_um'] # AR 107 um: physical fiber diameter
    fdiam_avg_asec = desi['fibers']['diameter_arcsec'] # AR 1.52 arcsec ; average fiber diameter
    #
    fac = FastFiberAcceptance("{}/data/throughput/galsim-fiber-acceptance.fits".format(os.getenv("DESIMODEL")))
    fac_wave = np.arange(3500, 10000) # AR to do the fac calculation, and then to be interpolated to wave
    # seeing scaled in wavelength, pivot at 6500A; use the function argument
    # (not the script-level "args") so that the function is self-contained
    seeing_wave = seeing * (fac_wave / 6500) ** fac_wave_power
    # FWHM [arcsec] -> sigma (FWHM ~ 2.35 sigma), then arcsec -> um with the fiber diameters ratio
    sigma = seeing_wave / 2.35  * (fdiam_um / fdiam_avg_asec)
    offset = np.zeros(sigma.shape) # no offset between source and fiber
    ffracflux_wave = fac.value("POINT", sigma, offset) # AR fac for seeing and fdiam_avg_asec
    ffracflux_wave /= ffracflux_wave[np.abs(fac_wave-6500).argmin()] # AR normalizing to 1 at 6500A
    ffracflux_wave *= ffracflux # AR normalising to ffracflux at 6500A
    return np.interp(wave, fac_wave, ffracflux_wave) # AR interpolating on wave



if __name__ == '__main__':
    # Script entry point:
    #  1. read the per-exposure flux calibration files and the matching GFA conditions,
    #  2. normalize each calibration vector (exposure time, fiber acceptance, optionally transparency),
    #  3. fit per wavelength a mean term plus optional airmass and seeing terms in magnitude space,
    #  4. write the resulting average flux calibration, and optionally plot it.

    parser = argparse.ArgumentParser(description="Compute the average calibration for a DESI spectrograph camera using precomputed flux calibration vectors.")

    parser.add_argument('--gfa_matched_coadd', type = str, default = None, required=True, help = 'full path to the GFA matched_coadd file, e.g. $DESI_ROOT/survey/GFA/offline_matched_coadd_ccds_SV1-thru_20210422.fits')
    parser.add_argument('-i','--infiles', type = str, default = None, required=True, help = 'path to ASCII file with full path to DESI frame calib fits files (1 per line)')
    parser.add_argument('-o','--outfile', type = str, default = None, required=True, help = 'output calibration file')
    parser.add_argument('--plot', action = 'store_true', help = 'plot the result')
    parser.add_argument('--no-airmass-term', action = 'store_true', help = 'do not try to estimate an airmass term')
    parser.add_argument('--no-seeing-term', action = 'store_true', help = 'do not try to estimate a seeing term')
    parser.add_argument('--corr_transp', action = 'store_true', help = 'correct for transparency')
    parser.add_argument('--seeing_key', type = str, default = 'RADPROF_FWHM_ASEC', choices = ['FWHM_ASEC', 'RADPROF_FWHM_ASEC'], help = 'key in GFA catalog for seeing (default=RADPROF_FWHM_ASEC)')
    parser.add_argument('--fac_wave_power', type = float, default = -0.25, help = 'wavelength dependence of fiber acceptance (default=-0.25)')
    parser.add_argument('--first_night', type = int, default = None, help = 'first night for which this calibration is usable (default=None)')
    parser.add_argument('--unflat', action='store_true', help="undo flatfield correction to get the variation of thru from spectro to spectro")

    args = parser.parse_args()
    log=get_logger()
    # log all the input arguments
    for kwargs in args._get_kwargs():
        log.info(kwargs)

    # read the data
    ###########################################################

    camera=None
    arm=None
    spec=None
    wave=None
    calibs=list()

    with_airmass = not args.no_airmass_term
    with_seeing  = not args.no_seeing_term

    # AR GFA information: airmass, seeing, transparency, fiber_fracflux
    # AR working only with airmass and fiber_fracflux
    # AR now the filename is provided as input argument
    # AR we restrict to matched_coadd, as exposures not in matched_coadd are a few
    # AR and likely problematic
    # AR as of early Mar. 2021: GFA now reports fiber_fracflux for a 1.52" diameter fiber
    # np.str has been removed from numpy (1.24), use the builtin str;
    # ndmin=1 so that a list with a single file is still a 1-d array
    filenames = np.loadtxt(args.infiles, dtype=str, ndmin=1)
    log.info("{} files in {}".format(len(filenames), args.infiles))
    expids = np.array([fitsio.read_header(fn)["EXPID"] for fn in filenames])
    # default to nan: exposures without GFA match fail the ">=" cuts below
    airmass = np.nan + np.zeros(len(filenames))
    seeing = np.nan + np.zeros(len(filenames))
    transparency = np.nan + np.zeros(len(filenames))
    ffracflux = np.nan + np.zeros(len(filenames))
    # AR matched_coadd
    # AR we discard the exposure if not present in matched_coadd
    unq_expids = np.unique(expids)
    gfa = fitsio.read(args.gfa_matched_coadd, ext=3)
    _, ii, gfa_ii = np.intersect1d(unq_expids, gfa["EXPID"], return_indices=True)
    for i, gfa_i in zip(ii, gfa_ii):
        # all the input files sharing this exposure id get the same GFA values
        jj = np.where(expids == unq_expids[i])[0]
        airmass[jj] = gfa["AIRMASS"][gfa_i]
        seeing[jj] = gfa[args.seeing_key][gfa_i]
        transparency[jj] = gfa["TRANSPARENCY"][gfa_i]
        ffracflux[jj] = gfa["FIBER_FRACFLUX"][gfa_i]
    # AR cutting the input file list on exposures having - meaningful - GFA measurements
    keep = (airmass >= 1.0) & (seeing >= 0) & (ffracflux >= 0)
    if args.corr_transp:
        keep &= transparency >= 0
    airmass, seeing, transparency, ffracflux = airmass[keep], seeing[keep], transparency[keep], ffracflux[keep]
    log.info("discarding the {} following exposures (no GFA): {}".format((~keep).sum(), ",".join(filenames[~keep])))
    filenames = filenames[keep]

    # AR storing indexes kept in the loop
    ii = []
    # AR looping on filenames
    for i in range(len(filenames)):
        filename = filenames[i]
        log.info("reading {}".format(filename))
        header=fitsio.read_header(filename)

        # all the input files have to be from the same arm (first character of the camera name)
        if camera is None :
            camera=header["camera"].strip().lower()
            arm=camera[0]
            spec=camera[-1]
        else :
            assert(arm==header["camera"].strip().lower()[0])

        exptime = float(header["exptime"])

        cal=read_flux_calibration(filename)
        if np.any(np.isnan(cal.calib)) :
            log.error("calib has nan")
            continue

        if args.unflat :
            ffilename=replace_prefix(filename, "fluxcalib", "fiberflatexp")
            log.info(f"reading {ffilename}")
            flat=fitsio.read(ffilename)
            cal.calib *= flat

        # median over the first axis of the calibration array
        mcalib = np.median(cal.calib,axis=0)

        # undo the heliocentric/barycentric correction
        #- check for heliocentric correction
        if 'HELIOCOR' in header.keys() :
            heliocor=header['HELIOCOR']
            log.info("Undo the heliocentric correction scale factor {} of the calib".format(heliocor))
            # wavelength are in solar system frame
            # first divide the multiplicative factor to have wavelength in KPNO frame
            wave_in_kpno_system = cal.wave/heliocor

            # now we want the wave grid cal.wave to be in the KPNO frame,
            # we have to resample the calib vector
            mcalib = np.interp(cal.wave,wave_in_kpno_system,mcalib)

        # all the input files have to share the same wavelength grid
        if wave is None :
            wave=cal.wave
        else :
            assert(np.all(np.abs(wave-cal.wave)<0.0001))

        if exptime <= 2. : # arbitrary cutoff
            log.warning("skip exptime={}".format(exptime))
            continue

        # AR GFA now provide ffrac for fiber_diameter=1.52", so no need to correct for different diameters
        # AR we do the normalization using fac_wave
        ffracflux_wave = get_ffracflux_wave(seeing[i], ffracflux[i], wave, args.fac_wave_power)
        # AR normalizing
        norm = exptime * ffracflux_wave
        text = "applying calibs.append(mcalib / exptime / ffracflux_wave_interp"
        if args.corr_transp:
            norm *= transparency[i]
            text += " / transparency"
        text += ")"
        log.info(text)
        calibs.append(mcalib / norm)
        #
        ii += [i]

    # nothing can be computed if all the exposures have been discarded
    if len(calibs) == 0 :
        log.error("no usable flux calibration left after selection, cannot compute an average")
        sys.exit(1)

    # AR cutting airmass, seeing, ffracflux on kept exposures
    airmass, seeing, ffracflux = airmass[ii], seeing[ii], ffracflux[ii]

    # AR we multiply by the ffracflux_wave from median conditions
    # AR to have a curve including the fiber acceptance
    # AR in the initial spirit of this script
    # AR not sure one needs to normalize by transparency if args.corr_transp; not done here
    median_ffracflux_wave = get_ffracflux_wave(np.median(seeing), np.median(ffracflux), wave, args.fac_wave_power)
    log.info("multiplying by the ffracflux_wave from median conditions")
    calibs = [calib * median_ffracflux_wave for calib in calibs]
    #
    calibs=np.array(calibs)
    nexp=calibs.shape[0]

    # compute an average calibration vector
    ###########################################################

    scalefactor=np.ones(nexp)
    ncalibs=calibs.copy()

    # NOTE(review): the median is computed from calibs (not ncalibs), so the
    # second iteration recomputes the same values as the first one; kept as is
    # in order not to change the resulting calibration - to be confirmed.
    for loop in range(2) :
        mcalib=np.median(calibs,axis=0)
        mcalib2=np.sum(mcalib**2)
        # per exposure scale factor with respect to the median, normalized to a mean of 1
        scalefactor=np.sum(mcalib*calibs,axis=-1)/mcalib2
        scalefactor /= np.mean(scalefactor)
        for e in range(nexp) :
            ncalibs[e] = calibs[e]/scalefactor[e]

    mcalib=np.median(ncalibs,axis=0)

    # interpolate over bright sky lines
    skylines=np.array([5578.94140625,5656.60644531,6302.08642578,6365.50097656,6618.06054688,7247.02929688,7278.30273438,7318.26904297,7342.94482422,7371.50878906,8346.74902344,8401.57617188,8432.42675781,8467.62011719,8770.36230469,8780.78027344,8829.57421875,8838.796875,8888.29394531,8905.66699219,8922.04785156,8960.48730469,9326.390625,9378.52734375,9442.37890625,9479.48242188,9569.9609375,9722.59082031,9793.796875])
    skymask=np.zeros(wave.size)
    for line in skylines :
        skymask[np.abs(wave-line)<2]=1 # mask +-2A around each line
    mcalib=np.interp(wave,wave[skymask==0],mcalib[skymask==0])

    # telluric mask
    ###########################################################
    telluricmask=np.zeros(wave.size)
    # mask telluric lines
    srch_filename = "data/arc_lines/telluric_lines.txt"
    telluric_mask_filename = resources.files('desispec').joinpath(srch_filename)
    if not telluric_mask_filename.is_file():
        log.error("Cannot find telluric mask file {:s}".format(srch_filename))
        raise Exception("Cannot find telluric mask file {:s}".format(srch_filename))
    # ndmin=2 so that a file with a single feature is still read as rows of (min,max)
    telluric_features = np.loadtxt(telluric_mask_filename, ndmin=2)
    log.debug("Masking telluric features from file %s"%telluric_mask_filename)
    for feature in telluric_features :
        telluricmask[(wave>=feature[0])&(wave<=feature[1])]=1


    # fit a smooth multiplicative component per exposure
    ###########################################################
    spline_res=300. #A, resolution of spline
    tmpwave=np.linspace(wave[0]+spline_res/2,wave[-1]-spline_res/2,int((wave[-1]-wave[0])/spline_res))
    # keep only wave knots outside of telluric mask
    twave=list()
    for w in tmpwave :
        ok=True
        for feature in telluric_features :
            ok &= ((w<feature[0])|(w>feature[1]))
        if ok :
            twave.append(w)
    twave=np.array(twave)

    ncalibs=calibs.copy()
    for e in range(nexp) :
        # define a spline
        # ratio to the median calibration; (mcalib==0) avoids a division by zero
        dcal=calibs[e]/(mcalib+(mcalib==0))
        # least-squares spline with fixed knots, with null weight on masked pixels
        tck=scipy.interpolate.splrep(wave,dcal,w=(mcalib>0)*(skymask==0)*(telluricmask==0),task=-1,t=twave,k=3)
        correction = scipy.interpolate.splev(wave,tck)
        if np.any(np.isnan(correction)) :
            log.warning("correction has nan for exp={}".format(e))
        else :
            ncalibs[e][telluricmask==0]=(mcalib*correction)[telluricmask==0] # this does not have any low frequency noise any more, except on telluric features

    # fit mean + several multiplicative smooth components
    ###########################################################
    # treated as additive components in log:
    # because there is virtually no more noise,
    # we can now take the log to separate the components,
    # without the risk of introducing a bias

    # log of calib. vectors
    # non-positive values are replaced by 1, i.e. lcal=0, and get a null weight
    lcal = -2.5*np.log10(ncalibs*(ncalibs>0)+(ncalibs<=0))
    weight=(lcal!=0)

    # fit a model with several terms
    # mean , derivative with seeing, derivative with airmass
    npar=1
    if with_airmass: npar +=1
    if with_seeing: npar +=1


    # normal equations A.M = B, per wavelength
    B = np.zeros((npar,wave.size))
    A = np.zeros((npar,npar,wave.size))
    M = np.zeros((npar,wave.size)) # model parameters
    X = np.ones((npar,nexp)) # design matrix, first row (ones) is the mean term
    index=1
    if with_airmass:
        pivot_airmass = np.median(airmass)
        X[index] = airmass-pivot_airmass
        airmass_index = index
        index += 1
    if with_seeing:
        pivot_seeing  = np.median(seeing)
        X[index] = seeing-pivot_seeing
        seeing_index = index
        index += 1

    ME = np.zeros((npar,wave.size)) # error on model

    lmod   = X.T.dot(M)
    res    = lcal-lmod
    for i in range(npar) :
        B[i]   = np.sum(weight*res*X[i][:,None],axis=0)
        for j in range(i,npar) :
            A[i,j] = np.sum(weight*X[i][:,None]*X[j][:,None],axis=0)
            if j !=i : A[j,i] = A[i,j]

    # solve the linear system per wavelength
    for i in range(wave.size) :
        if A[0,0,i]>0 :
            Aii=np.linalg.inv(A[:,:,i])
            dMi=Aii.dot(B[:,i])
            M[:,i] += dMi
            # M was null before this fit, so M[:,i] is also the increment dMi
            rms = np.std(res[:,i]-M[:,i].dot(X))
            ME[:,i] = np.sqrt(np.diag(Aii))*rms # approximate uncertainty on model

    if with_seeing :
        # smooth the seeing term over telluric features, freeze it and refit the two
        # other terms on telluric features
        tck=scipy.interpolate.splrep(wave,M[seeing_index],task=-1,t=twave,k=3)
        M[seeing_index] = scipy.interpolate.splev(wave,tck)
        # this estimation of uncertainty is very approximative
        tck=scipy.interpolate.splrep(wave,ME[seeing_index],task=-1,t=twave,k=3)
        ME[seeing_index] = scipy.interpolate.splev(wave,tck)

    # second fit: residuals with respect to the current model,
    # fitting only the terms other than the (frozen) seeing term
    lmod   = X.T.dot(M)
    res    = lcal-lmod
    nparbis = npar
    if with_seeing : nparbis -= 1
    B = np.zeros((nparbis,wave.size))
    A = np.zeros((nparbis,nparbis,wave.size))
    for i in range(nparbis) :
        B[i]   = np.sum(weight*res*X[i][:,None],axis=0)
        for j in range(i,nparbis) :
            A[i,j] = np.sum(weight*X[i][:,None]*X[j][:,None],axis=0)
            if j !=i : A[j,i] = A[i,j]
    # solve the linear system per wavelength
    for i in range(wave.size) :
        if A[0,0,i]>0 :
            Aii=np.linalg.inv(A[:,:,i])
            dMi=Aii.dot(B[:,i])
            M[:nparbis,i] += dMi
            # res already has the previous model subtracted: only the increment
            # dMi has to be subtracted here to get the post-fit residuals
            # (subtracting the full M[:nparbis] would count the model twice)
            rms = np.std(res[:,i]-dMi.dot(X[:nparbis]))
            ME[:nparbis,i] = np.sqrt(np.diag(Aii))*rms # approximate uncertainty on model

    # interpolate over sky lines
    for c in range(npar) :
        M[c]=np.interp(wave,wave[skymask==0],M[c][skymask==0])


    # terms which are not fitted are set to zero, with a pivot of 1
    if with_airmass :
        airmass_term=M[airmass_index]
        airmass_term_uncertainty=ME[airmass_index]
    else :
        airmass_term=np.zeros(M[0].size)
        airmass_term_uncertainty=np.zeros(M[0].size)
        pivot_airmass=1

    if with_seeing :
        seeing_term=M[seeing_index]
        seeing_term_uncertainty=ME[seeing_index]
    else :
        seeing_term=np.zeros(M[0].size)
        seeing_term_uncertainty=np.zeros(M[0].size)
        pivot_seeing=1

    # M[0] is the mean term in magnitude (-2.5 log10), converted back to linear units
    fluxcal=AverageFluxCalib(wave=wave,
                             average_calib=10**(-0.4*M[0]),
                             atmospheric_extinction=airmass_term,
                             seeing_term=seeing_term,
                             pivot_airmass=pivot_airmass,
                             pivot_seeing=pivot_seeing,
                             atmospheric_extinction_uncertainty=airmass_term_uncertainty,
                             seeing_term_uncertainty=seeing_term_uncertainty,
                             median_seeing=np.median(seeing),
                             median_ffracflux=np.median(ffracflux),
                             fac_wave_power=args.fac_wave_power,
                             ffracflux_wave=median_ffracflux_wave,
                             first_night=args.first_night)

    # write the result ...

    write_average_flux_calibration(args.outfile,fluxcal)
    log.info("wrote {}".format(args.outfile))

    if args.plot :

        # read back the file which has just been written, and plot its content
        fluxcal = read_average_flux_calibration(args.outfile)

        plt.figure()

        a=plt.subplot(3,1,1)
        a.plot(fluxcal.wave,fluxcal.value(airmass=pivot_airmass,seeing=pivot_seeing))
        a.set_ylabel("Mean calibration")

        a=plt.subplot(3,1,2)
        a.errorbar(fluxcal.wave,fluxcal.atmospheric_extinction,fluxcal.atmospheric_extinction_uncertainty)
        a.set_ylabel(r"Atmospheric extinction")# $d LC / d airmass$, with $LC=-2.5 log_{10}(calib)$")

        a=plt.subplot(3,1,3)
        a.errorbar(fluxcal.wave,fluxcal.seeing_term,fluxcal.seeing_term_uncertainty)
        a.set_ylabel(r"Seeing term")# $d LC/d seeing$, with $LC=-2.5 log_{10}(calib)$ ")
        a.set_xlabel("Wavelength [A]")

        plt.show()
