#!/usr/bin/env python

"""Build stellar population synthesis (SPS) template files for FastSpecFit.

This script generates the ``ftemplates-{imf}-{version}.fits`` files that
FastSpecFit uses to model galaxy stellar continua, nebular emission, dust
emission, AGN torus emission, and UV iron emission. Running it requires
``python-fsps`` (and a compiled FSPS library) installed in the active
environment — see the Notes section for setup instructions.

Usage
-----
::

  time python bin/build-templates --version 2.0.0 \\
      --templatedir $DESI_ROOT/users/ioannis/fastspecfit/templates

  time python bin/build-templates --version 2.0.0 \\
      --templatedir $DESI_ROOT/users/ioannis/fastspecfit/templates \\
      --logmets=-1.,0.,0.3

Input Files
-----------
The script expects the following raw data files under
``<templatedir>/original/``:

- ``DL07_MW3.1_{00,10,20,30,40,50,60}.dat`` — Draine & Li (2007) dust
  emission grids (Table 3 of https://arxiv.org/pdf/astro-ph/0608003).
- ``fetemplates-1.0.fits`` — Vestergaard & Wilkes (2001) UV Fe emission
  template (``VW01ALT`` extension).
- ``Nenkova08_y010_torusg_n10_q2.0.dat`` — Nenkova et al. (2008) clumpy
  AGN torus model grid.

Stellar Population Model
------------------------
Stellar spectra are computed with ``python-fsps`` using:

- Chabrier (2003) IMF (default; Salpeter and Kroupa also supported via
  ``--imf``)
- MIST isochrones
- C3K_a stellar spectral library (R ~ 3000 at 2750-9100 A)
- Tabular SFH: 5 variable-width age bins spanning 1 yr to 13.7 Gyr (the
  youngest bin covers 1 yr to 30 Myr) with constant star formation within
  each bin, following the Prospector ``adjust_continuity_agebins``
  prescription
- ``zcontinuous=1`` (metallicity interpolated in log space)
- Two full passes over the model grid: one with nebular emission lines
  added to the spectrum (``nebemlineinspec=True``) and one without them
  (nebular continuum is included in both passes), so the pure line
  contribution can be isolated as ``LINEFLUX = FLUX - FLUX_nolines``

Output FITS Extensions
----------------------
The output ``ftemplates-{imf}-{version}.fits`` file contains the
following extensions (in order):

``FLUX`` (Primary HDU, float64 [nwave x nmodel])
    Stellar continuum plus nebular emission in units of
    erg/s/cm2/Angstrom at 10 pc per solar mass formed.

``LINEFLUX`` (float64 [nwave x nmodel])
    Isolated nebular emission (``FLUX`` minus the continuum-only pass).

``DUSTFLUX`` (float64 [nwave])
    Draine & Li (2007) dust emission spectrum, normalized to unit
    integrated flux on its native wavelength grid and then interpolated
    onto the ``WAVE`` grid (zero blueward of the native grid).
    Header keywords: ``QPAH``, ``UMIN``, ``GAMMA``.

``WAVE`` (float64 [nwave])
    Vacuum wavelength array in Angstroms. Sampled at constant
    log-lambda (``PIXKMS`` km/s pixels) within ``PIXKMS_BOUNDS``;
    native FSPS spacing outside that range.

``METADATA`` (table, nmodel rows)
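    Per-template mean age in yr (``age``), log10(Z/Zsun) (``zzsun``),
    stellar mass (``mstar``), and SFR (``sfr``) from ``python-fsps``.
    Header keyword: ``IMF``.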

``AGNFLUX`` (float64 [nagnwave])
    Nenkova et al. (2008) clumpy AGN torus spectrum normalized to unit
    bolometric luminosity. Header keyword: ``AGNTAU``.

``AGNWAVE`` (float64 [nagnwave])
    Vacuum wavelength array for the AGN torus spectrum in Angstroms.

``FEFLUX`` (float64 [nfewave])
    Vestergaard & Wilkes (2001) UV Fe emission template resampled onto
    a constant log-lambda grid at ``AGN_PIXKMS`` km/s per pixel and
    normalized to the median flux.

``FEWAVE`` (float64 [nfewave])
    Vacuum wavelength array for the Fe emission template in Angstroms.

``LINEFLUXES`` (float64 [nlines x nmodel])
    Per-template FSPS emission-line fluxes in erg/s/cm2 at 10 pc per
    solar mass formed.

``LINEWAVES`` (float64 [nlines])
    Vacuum wavelengths of the FSPS emission lines in Angstroms.

Notes
-----
* To switch stellar libraries (e.g., MILES vs. C3K_a), edit
  ``$SPS_HOME/src/sps_vars.f90`` and reinstall ``python-fsps`` from
  scratch in a clean environment::

    micromamba create -y -n fsps python
    micromamba activate fsps
    cd $HOME/code
    git clone --recursive https://github.com/dfm/python-fsps.git
    export SPS_HOME=$HOME/code/python-fsps/src/fsps/libfsps
    cd python-fsps
    python -m pip install . --no-cache-dir --force-reinstall
    cd
    python -c "import fsps; sp = fsps.StellarPopulation(); print(sp.libraries)"

* The C3K_a library has resolution R(lambda/FWHM) = 3000, equivalently
  R(lambda/sigma) = 7065 (sigma = 42 km/s), over 2750-9100 A:
  https://github.com/cconroy20/fsps/tree/master/SPECTRA/C3K#readme

* Figure 3 of Leja et al. (2017) illustrates how the various free SPS
  parameters affect the resulting SED.

"""
import os, time
import numpy as np
import numpy.ma as ma
import fitsio
from scipy.ndimage import gaussian_filter1d
from scipy.interpolate import interpn

from astropy.io import fits
from astropy.table import Table
import matplotlib.pyplot as plt

from desispec.interpolation import resample_flux
from fastspecfit.util import C_LIGHT, trapz
from fastspecfit.templates import Templates


def _qa_dustem_templates(dustwave, mduste, qpah, umin, gamma, png):
    """QA plot comparing the DL07 dust emission model against FSPS.

    A 10 Myr SSP is attenuated by a simple (age-independent) power-law dust
    model with tau_V = 3 and index -0.7. The absorbed luminosity is
    re-emitted using ``mduste``, both without and with an iterative dust
    self-absorption correction (algorithm taken from FSPS ``add_dust``). The
    result is compared against the spectrum FSPS itself computes with dust
    emission enabled.

    Parameters
    ----------
    dustwave : :class:`numpy.ndarray`
        Wavelength grid of ``mduste`` in Angstroms.
    mduste : :class:`numpy.ndarray`
        Dust emission spectrum on ``dustwave`` (as returned by
        :func:`build_dustem_templates`).
    qpah, umin, gamma : :class:`float`
        DL07 parameters used to build ``mduste``. They are also passed to
        FSPS so that its reference spectrum uses the same dust model.
    png : :class:`str`
        Output path of the QA figure.

    """
    import fsps

    def attenuation(tauv, wave, alpha=-0.7):
        return np.exp(-tauv * (wave / 5500.)**alpha)

    tauv = 3.

    sp = fsps.StellarPopulation(imf_type=1, dust_type=0,
                                dust_index=-0.7, sfh=0, # SSP parameters
                                zcontinuous=1)
    # Use the same DL07 parameters as the template under test; otherwise the
    # FSPS "truth" spectrum silently uses FSPS's own default values.
    sp.params['duste_qpah'] = qpah
    sp.params['duste_umin'] = umin
    sp.params['duste_gamma'] = gamma

    # intrinsic (dust-free) spectrum with no dust emission
    sp.params["dust2"] = 0
    sp.params["add_dust_emission"] = False
    wave, flux = sp.get_spectrum(tage=0.01, peraa=True)

    # attenuated spectrum + dust emission computed by FSPS; take this
    # spectrum as "truth"
    sp.params["dust2"] = tauv
    sp.params["add_dust_emission"] = True
    _, fsps_dflux = sp.get_spectrum(tage=0.01, peraa=True)

    # For the purposes of the QA, interpolate the dust emission spectrum
    # onto the FSPS wavelength grid (zero blueward of the DL07 grid).
    mduste = np.interp(wave, dustwave, mduste, left=0.)

    # now do the dust emission calculation ourselves
    atten = attenuation(tauv, wave)
    dflux = flux * atten

    # normalize the dust emission to the luminosity absorbed by dust, i.e.,
    # keeping Lbol constant
    lbold = trapz(dflux, x=wave) # attenuated energy
    lboln = trapz(flux, x=wave)  # intrinsic energy

    labs = lboln - lbold
    norm = trapz(mduste, x=wave) # should already be ~1.0
    duste = mduste * labs / norm

    orig_duste = duste.copy()

    # Handle dust self-absorption (algorithm taken from fsps.add_dust): the
    # re-emitted light is itself attenuated and the lost luminosity is
    # re-emitted again, until the residual is small (and at least 5 times).
    iiter = 0
    tduste = 0.
    while (lboln - lbold) > 1e-2 or iiter < 5:
        oduste = duste.copy()
        duste *= atten # attenuate
        tduste += duste

        lbold = trapz(duste, x=wave)  # after self-absorption
        lboln = trapz(oduste, x=wave) # before self-absorption
        duste = mduste * (lboln - lbold) / norm

        iiter += 1

    plotwave = wave / 1e4 # [micron]

    xlim = (0.1, 500.)
    I = (plotwave > xlim[0]) & (plotwave < xlim[1])

    withselfabs = dflux + tduste   # attenuated stars + self-absorbed dust
    noselfabs = dflux + orig_duste # attenuated stars + single-pass dust

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), sharex=True)

    ax1.plot(plotwave[I], flux[I], label='Unattenuated spectrum')
    ax1.plot(plotwave[I], withselfabs[I], alpha=0.7, label='With self-absorption')
    ax1.plot(plotwave[I], noselfabs[I], alpha=0.7, label='No self-absorption')
    ax1.plot(plotwave[I], fsps_dflux[I], alpha=0.7, label='FSPS')
    ax1.set_xscale('log')
    ax1.set_yscale('log')
    ax1.legend(fontsize=10)

    # fractional residuals relative to FSPS; labels match the top panel
    ax2.plot(plotwave[I], withselfabs[I]/fsps_dflux[I]-1, alpha=0.7, label='With self-absorption')
    ax2.plot(plotwave[I], noselfabs[I]/fsps_dflux[I]-1, alpha=0.7, label='No self-absorption')
    ax2.set_xscale('log')
    ax2.legend(fontsize=10)
    fig.tight_layout()
    fig.savefig(png)
    plt.close(fig)


def build_dustem_templates(qpah=3.5, umin=1., gamma=0.01, png=None):
    """Build the DL07 dust emission template.

    Reads the Draine & Li (2007) ``DL07_MW3.1_*.dat`` grids from
    ``<templatedir>/original/``, where ``templatedir`` is the module-level
    global set by the command-line entry point. It linearly interpolates
    the grids to the requested (``qpah``, ``umin``) and mixes the two
    spectral components with weight ``gamma``. The result is converted
    from F_nu to F_lambda and normalized to unit integral.

    Parameters
    ----------
    qpah : :class:`float`, optional
        PAH fraction. Default is 3.5.
    umin : :class:`float`, optional
        Minimum radiation field strength. Default is 1.0.
    gamma : :class:`float`, optional
        Fraction of dust heated by PDRs. Default is 0.01.
    png : :class:`str` or None, optional
        If not ``None``, write a QA figure to this path.

    Returns
    -------
    dustwave : :class:`numpy.ndarray`
        Wavelength array in Angstroms.
    mduste : :class:`numpy.ndarray`
        Dust emission spectrum (F_lambda) normalized to unit integral over
        ``dustwave``. It is left unnormalized if that integral is not
        positive.

    Raises
    ------
    ValueError
        If an input file does not have the expected number of rows and
        columns, or if (``qpah``, ``umin``) lies outside the model grid
        (raised by :func:`scipy.interpolate.interpn`).

    """
    # Umin grid of the DL07 models
    umin_grid = np.array([0.10, 0.15, 0.20, 0.30, 0.40, 0.50, 0.70, 0.80,
                          1.00, 1.20, 1.50, 2.00, 2.50, 3.00, 4.00, 5.00,
                          7.00, 8.00, 12.0, 15.0, 20.0, 25.0])
    numin = len(umin_grid)

    # qpah of each model, from Table 3 in https://arxiv.org/pdf/astro-ph/0608003
    d_qpah = {
        'MW3.1_00': 0.47,
        'MW3.1_10': 1.12,
        'MW3.1_20': 1.77,
        'MW3.1_30': 2.50,
        'MW3.1_40': 3.19,
        'MW3.1_50': 3.90,
        'MW3.1_60': 4.58,
        }
    models = list(d_qpah.keys())
    qpah_grid = np.array([d_qpah[model] for model in models])
    nqpah = len(qpah_grid)

    nwave = 1001          # fixed number of wavelengths in each file
    ncol = 1 + 2 * numin  # wavelength column + two sets of numin spectra
    dustem = np.zeros((nqpah, 2 * numin, nwave))

    for imodel, model in enumerate(models):
        dustfile = os.path.join(templatedir, 'original', f'DL07_{model}.dat')
        dustem1 = np.loadtxt(dustfile, skiprows=2)
        # Fail loudly instead of letting a malformed file broadcast silently
        # into the model grid.
        if dustem1.shape != (nwave, ncol):
            raise ValueError(f'Expected shape {(nwave, ncol)} in {dustfile} '
                             f'but found {dustem1.shape}.')
        if imodel == 0:
            dustwave = dustem1[:, 0]
        dustem[imodel, :, :] = dustem1[:, 1:].T
    dustwave *= 1e4 # [mu --> Angstrom]

    # Linear interpolation in (qpah, Umin), evaluated at every native
    # wavelength. The first numin spectra of each file feed ``dumin`` and
    # the remaining numin spectra feed ``dumax``.
    points = (qpah_grid, umin_grid, dustwave)
    xi = (qpah, umin, dustwave)
    dumin = interpn(points, dustem[:, :numin, :], xi)
    dumax = interpn(points, dustem[:, numin:, :], xi)

    # construct P(U)dU as a weighted average of dumin and dumax
    mduste = (1. - gamma) * dumin + gamma * dumax

    # convert to F_lambda (from F_nu) and normalize the spectrum to unity
    mduste /= dustwave**2
    norm = trapz(mduste, x=dustwave)
    if norm > 0.:
        mduste /= norm

    if png:
        _qa_dustem_templates(dustwave, mduste, qpah, umin, gamma, png)

    return dustwave, mduste


def build_fe_templates(version='1.0', png=None):
    """Build the UV Fe emission template.

    Reads the ``VW01ALT`` extension of
    ``<templatedir>/original/fetemplates-{version}.fits``, where
    ``templatedir`` is the module-level global set by the command-line
    entry point. The template is resampled onto a constant log-lambda
    grid with ``Templates.AGN_PIXKMS`` km/s pixels spanning
    ``Templates.AGN_PIXKMS_BOUNDS``, then divided by its median.

    Parameters
    ----------
    version : :class:`str`, optional
        Template file version string. Default is ``'1.0'``.
    png : :class:`str` or None, optional
        If not ``None``, write a QA figure to this path.

    Returns
    -------
    newwave : :class:`numpy.ndarray`
        Wavelength array in Angstroms on a constant log-lambda grid.
    newflux : :class:`numpy.ndarray`
        Fe emission flux spectrum divided by its median. It is left
        unnormalized if the median is not positive.

    """
    def _qa_fe_templates(fewave, feflux, title):
        # Overlay the template smoothed to a range of velocity dispersions.
        # The grid has AGN_PIXKMS km/s pixels, so sigma in pixels is
        # vdisp / AGN_PIXKMS.
        fig, ax = plt.subplots()
        ax.plot(fewave, feflux, color='k', lw=1.5, label='Original', alpha=0.5)
        for vdisp in np.logspace(np.log10(150.), np.log10(1e4), 5):
            sigma = vdisp / Templates.AGN_PIXKMS # [pixels]
            sfeflux = gaussian_filter1d(feflux, sigma=sigma)
            ax.plot(fewave, sfeflux, lw=2, label=r'$\sigma=$'+f'{vdisp:.0f} km/s', alpha=0.75)
        ax.set_xlabel(r'Wavelength ($\AA$)')
        ax.set_ylabel('Flux (arbitrary units)')
        ax.legend(loc='upper left', ncol=2)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(png)
        plt.close(fig)

    templatefile = os.path.join(templatedir, 'original', f'fetemplates-{version}.fits')
    fe = fitsio.read(templatefile, 'VW01ALT')
    fewave = fe['WAVE'].astype('f8')
    feflux = fe['FLUX'].astype('f8')

    # rebin to constant log-lambda (i.e., constant velocity)
    dlogwave = Templates.AGN_PIXKMS / C_LIGHT / np.log(10) # pixel size [log-lambda]
    newwave = 10.**np.arange(np.log10(Templates.AGN_PIXKMS_BOUNDS[0]),
                             np.log10(Templates.AGN_PIXKMS_BOUNDS[1]),
                             dlogwave)
    newflux = resample_flux(newwave, fewave, feflux)

    norm = np.median(newflux)
    if norm > 0.:
        newflux /= norm

    if png:
        _qa_fe_templates(newwave, newflux, f'Vestergaard & Wilkes (2001) - v{version}')

    return newwave, newflux


def build_agn_templates(agntau=10., png=None):
    """Build the AGN (Nenkova+08 torus) template.

    Reads ``Nenkova08_y010_torusg_n10_q2.0.dat`` from
    ``<templatedir>/original/``, where ``templatedir`` is the module-level
    global set by the command-line entry point. Each spectrum in the file
    is converted from F_nu to F_lambda. The grid is then linearly
    interpolated in optical depth to ``agntau`` and normalized to unit
    integral.

    Parameters
    ----------
    agntau : :class:`float`, optional
        AGN optical depth. Default is 10.0.
    png : :class:`str` or None, optional
        If not ``None``, write a QA figure to this path.

    Returns
    -------
    agnwave : :class:`numpy.ndarray`
        Wavelength array in Angstroms.
    agnflux : :class:`numpy.ndarray`
        AGN flux spectrum normalized to unit integral over ``agnwave``. It
        is left unnormalized if that integral is not positive.

    Raises
    ------
    ValueError
        If the file does not contain one spectrum per optical depth in the
        hard-coded grid, or if ``agntau`` lies outside that grid (raised by
        :func:`scipy.interpolate.interpn`).

    """
    def _qa_agn_templates(agnwave, agnflux, agntau, agngrid=None, taugrid=None):
        plotwave = agnwave / 1e4 # [micron]
        xlim = (np.min(plotwave), 100.)
        I = (plotwave > xlim[0]) & (plotwave < xlim[1])

        fig, ax = plt.subplots()
        ax.plot(plotwave[I], agnflux[I], label=f'tau={agntau:.1f}', lw=2)
        if agngrid is not None:
            # overplot every grid spectrum, each normalized to unit integral
            for ii in range(agngrid.shape[0]):
                ax.plot(plotwave[I], agngrid[ii, I] / trapz(agngrid[ii, :], x=agnwave),
                        label=f'{taugrid[ii]:.1f}', lw=0.5)
        ax.set_xlabel('Wavelength (micron)')
        ax.set_ylabel('Flux (arbitrary units)')
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlim(xlim)
        ax.legend()
        fig.tight_layout()
        fig.savefig(png)
        plt.close(fig)

    templatefile = os.path.join(templatedir, 'original', 'Nenkova08_y010_torusg_n10_q2.0.dat')
    data = np.loadtxt(templatefile, skiprows=4)
    agnwave = data[:, 0]    # [nwave, Angstrom]
    agngrid = data[:, 1:].T # [nagn, nwave]
    agngrid /= agnwave[np.newaxis, :]**2 # [F_nu --> F_lam]

    # optical depth of each spectrum (column) in the file
    tau_grid = np.array([5, 10, 20, 30, 40, 60, 80, 100, 150]).astype('f4')
    if agngrid.shape[0] != len(tau_grid):
        raise ValueError(f'Expected {len(tau_grid)} spectra in {templatefile} '
                         f'but found {agngrid.shape[0]}.')

    agnflux = interpn((tau_grid, agnwave), agngrid, (agntau, agnwave))

    norm = trapz(agnflux, x=agnwave)
    if norm > 0.:
        agnflux /= norm

    if png:
        _qa_agn_templates(agnwave, agnflux, agntau, agngrid=agngrid, taugrid=tau_grid)

    return agnwave, agnflux


def build_fsps_templates(models, logages, agebins=None, imf='chabrier',
                         include_nebular=True):
    """Build FSPS stellar population synthesis template spectra.

    For each entry in ``models``, sets up a tabular star formation history
    with constant SFR over the corresponding age bin, calls
    ``python-fsps`` to compute the spectrum, and resamples the result onto
    a hybrid wavelength grid that is constant log-lambda (velocity) within
    ``Templates.PIXKMS_BOUNDS`` and native FSPS spacing outside that range.
    Called twice by ``main`` — once with and once without nebular emission
    lines — so that the pure line contribution can be isolated.

    Parameters
    ----------
    models : :class:`numpy.ndarray`
        Structured array with fields ``logmet`` (log10 Z/Zsun, float32)
        and ``logage`` (log10 mean age in yr, float32), one entry per
        model spectrum.
    logages : :class:`numpy.ndarray`
        Unique log10(age/yr) values, one per age bin. Used to look up
        which bin each model belongs to.
    agebins : :class:`numpy.ndarray`
        Age bin edges of shape (nages, 2) in Gyr. Each row gives the
        [start, end] of one constant-SFR interval. Required.
    imf : :class:`str`, optional
        Initial mass function passed to ``python-fsps``. One of
        ``'chabrier'`` (default), ``'salpeter'``, or ``'kroupa'``.
    include_nebular : :class:`bool`, optional
        If ``True`` (default), add the nebular emission lines to the
        output spectra (``nebemlineinspec``). Nebular continuum is
        requested in both cases (``add_neb_emission=True``).

    Returns
    -------
    meta : :class:`astropy.table.Table`
        Per-model metadata with columns:

        - ``age``: mean age of the bin in yr (``10**logage``).
        - ``zzsun``: log10(Z/Zsun), copied from ``logmet``.
        - ``mstar``: ``sp.stellar_mass``.
        - ``sfr``: ``sp.sfr`` at the time of observation.
    newwave : :class:`numpy.ndarray`
        Vacuum wavelength array in Angstroms on the hybrid grid.
    fluxes : :class:`numpy.ndarray`
        Model spectra of shape (nwave, nmodel) in units of
        erg/s/cm2/Angstrom at 10 pc per solar mass formed.
    linewaves : :class:`numpy.ndarray`
        Vacuum wavelengths of the FSPS nebular emission lines in
        Angstroms.
    linefluxes : :class:`numpy.ndarray`
        Per-model emission-line fluxes of shape (nlines, nmodel) in units
        of erg/s/cm2 at 10 pc per solar mass formed.

    Raises
    ------
    ValueError
        If ``agebins`` is not provided, or if FSPS returns a negative
        stellar mass.

    """
    # python-fsps is not imported at module level, so bring it into scope here.
    import fsps

    if agebins is None:
        raise ValueError('agebins is required: one [start, end] row (Gyr) per age bin.')

    # Convert FSPS outputs (per L_sun) to observed flux at 10 pc.
    lodot = 3.828e33 # [erg/s]
    tenpc2 = (10. * 3.085678e18)**2 # [cm^2]

    nsed = len(models)

    meta = Table()
    meta['age'] = 10**models['logage'] # mean age of the bin [yr]
    meta['zzsun'] = models['logmet']   # NB: log10(Z/Zsun)
    meta['mstar'] = np.zeros(nsed, 'f4')
    meta['sfr'] = np.zeros(nsed, 'f4')

    # https://dfm.io/python-fsps/current/stellarpop_api/
    imfdict = {'salpeter': 0, 'chabrier': 1, 'kroupa': 2}

    print('Instantiating the StellarPopulation object...', end='')
    t0 = time.time()
    # tabular SFH
    sp = fsps.StellarPopulation(
        compute_vega_mags=False,
        add_dust_emission=False, # note! dust emission is a separate template
        add_neb_emission=True,
        nebemlineinspec=include_nebular, # add the emission lines to the spectrum?
        imf_type=imfdict[imf],
        sfh=3,  # tabular SFH parameters
        zcontinuous=1,
    )
    print('...took {:.3f} sec'.format((time.time()-t0)))
    print(sp.libraries)

    if include_nebular:
        print('Creating {} model spectra with nebular emission...'.format(nsed), end='')
    else:
        print('Creating {} model spectra without nebular emission...'.format(nsed), end='')

    t0 = time.time()
    for imodel, model in enumerate(models):
        sp.params['logzsol'] = model['logmet']

        # lookback time of constant SFR
        agebin_indx = np.where(model['logage'] == np.float32(logages))[0]
        agebin = agebins[agebin_indx, :][0] # Gyr
        fspstime = agebin - agebin[0]       # Gyr
        tage = agebin[1] # time of observation [Gyr]

        # constant SFR of 1/dt Msun/yr for dt years, i.e., ~1 Msun formed
        dt = np.diff(agebin) * 1e9          # [yr]
        sfh = np.zeros_like(fspstime) + 1. / dt # [Msun/yr]

        # force the SFR to go to zero at the edge
        fspstime = np.hstack((fspstime, fspstime[-1]*1.01))
        sfh = np.hstack((sfh, 0.))

        sp.set_tabular_sfh(fspstime, sfh)

        wave, flux = sp.get_spectrum(tage=tage, peraa=True) # tage in Gyr
        flux = flux * lodot / (4. * np.pi * tenpc2) # [erg/s/cm2/A/Msun at 10pc]

        # On the first model, build the output wavelength grid: constant
        # log-lambda (PIXKMS km/s pixels) between PIXKMS_BOUNDS, and native
        # FSPS sampling outside that range.
        if imodel == 0:
            lo = np.searchsorted(wave, Templates.PIXKMS_BOUNDS[0], 'left')
            hi = np.searchsorted(wave, Templates.PIXKMS_BOUNDS[1], 'left')

            dlogwave = Templates.PIXKMS / C_LIGHT / np.log(10) # pixel size [log-lambda]
            optwave = 10.**np.arange(np.log10(wave[lo]), np.log10(wave[hi]), dlogwave)
            # NOTE(review): np.arange excludes wave[hi] and the red segment
            # starts at wave[hi+1], so the native pixel wave[hi] is skipped;
            # confirm whether wave[hi:] was intended.
            newwave = np.hstack((wave[:lo], optwave, wave[hi+1:]))
            npix = len(newwave)

            fluxes = np.zeros((npix, nsed), dtype=np.float64)

            # emission lines
            linewaves = sp.emline_wavelengths
            linefluxes = np.zeros((len(linewaves), nsed), dtype=np.float64)

        fluxes[:, imodel] = resample_flux(newwave, wave, flux)
        # line luminosities [L_sun] --> line fluxes [erg/s/cm2/Msun at 10pc]
        linefluxes[:, imodel] = sp.emline_luminosity * lodot / (4.0 * np.pi * tenpc2)

        meta['mstar'][imodel] = sp.stellar_mass
        meta['sfr'][imodel] = sp.sfr
        if sp.stellar_mass < 0:
            raise ValueError('Stellar mass is negative!')

    print('...took {:.3f} min'.format((time.time()-t0)/60.))

    return meta, newwave, fluxes, linewaves, linefluxes


def main(imf='chabrier', logmets=(0.0,), qpah=1.0, umin=1., gamma=0.01,
         agntau=10., test=False, version='1.0.0', templatedir=None):
    """Build all templates and write the output FITS file.

    Orchestrates calls to :func:`build_fsps_templates` (twice — with and
    without nebular emission lines), :func:`build_dustem_templates`,
    :func:`build_agn_templates`, and :func:`build_fe_templates`, then
    assembles the results into a single multi-extension FITS file at
    ``<templatedir>/<version>/ftemplates-{imf}-{version}.fits``.

    Parameters
    ----------
    imf : :class:`str`, optional
        Initial mass function. One of ``'chabrier'`` (default),
        ``'salpeter'``, or ``'kroupa'``.
    logmets : sequence of :class:`float`, optional
        log10(Z/Zsun) metallicity grid values. Default is ``(0.0,)``
        (solar metallicity only). The production templates use
        ``[-1.0, 0.0, 0.3]``.
    qpah : :class:`float`, optional
        Draine & Li (2007) PAH mass fraction parameter. Default is 1.0.
    umin : :class:`float`, optional
        Draine & Li (2007) minimum interstellar radiation field strength.
        Default is 1.0.
    gamma : :class:`float`, optional
        Draine & Li (2007) fraction of dust heated by PDRs (power-law
        radiation field component). Default is 0.01.
    agntau : :class:`float`, optional
        Nenkova et al. (2008) AGN torus optical depth. Default is 10.0.
    test : :class:`bool`, optional
        If ``True``, override ``logmets`` with ``[0.0]`` for a quick
        single-metallicity test run. Default is ``False``.
    version : :class:`str`, optional
        Version string embedded in the output filename and FITS header,
        e.g. ``'2.0.0'``. Default is ``'1.0.0'``.
    templatedir : :class:`str`
        Root template directory (required). The output FITS file is
        written to ``<templatedir>/<version>/`` and the AGN QA figure to
        ``<templatedir>/original/``. NOTE: the ``build_*_templates``
        helpers read their raw inputs through the module-level
        ``templatedir`` global set by the command-line entry point, not
        through this argument.

    Raises
    ------
    ValueError
        If ``templatedir`` is not given.

    """
    if templatedir is None:
        raise ValueError('templatedir is required.')

    # AGN + Fe template(s)
    fewave, feflux = build_fe_templates(version='1.0')
    agnwave, agnflux = build_agn_templates(agntau=agntau, png=os.path.join(templatedir, 'original', 'qa-agn.png'))

    # dust emission template(s)
    dustwave, dustflux = build_dustem_templates(qpah=qpah, umin=umin, gamma=gamma)

    print(f'<IR Dust>: q_PAH = {qpah}, U_min = {umin}, gamma = {gamma}')

    # Lookback time bins, from prospect.templates.adjust_continuity_agebins:
    # edges at 10**0 yr (1 yr), 10**lim1 yr (30 Myr), nbins-2 log-spaced
    # edges from 10**lim2 yr (100 Myr) to 0.85*tuniv, and finally tuniv.
    nbins = 5
    tuniv = 13.7
    tbinmax = (tuniv * 0.85) * 1e9
    lim1, lim2 = 7.4772, 8.0
    agelims = np.array([0, lim1] + np.linspace(lim2, np.log10(tbinmax), nbins-2).tolist() + [np.log10(tuniv*1e9)]) # log10(yr)
    agelims = 10.**agelims / 1e9 # [Gyr]

    agebins = np.array([agelims[:-1], agelims[1:]]).T # [Gyr]
    logages = np.log10(1e9*np.sum(agebins, axis=1) / 2) # mean age [log10(yr)] in each bin
    print('   <Ages>: '+', '.join(['{:.4f}'.format(10.**logage/1e9) for logage in logages]) + ' Gyr')

    # for testing; override the grid *before* reporting it
    if test:
        logmets = [0.]

    nages = len(logages)
    nmets = len(logmets)
    print(' <Z/Zsun>: '+', '.join(['{:.2f}'.format(10.**logmet) for logmet in logmets]))

    dims = (nages, nmets)

    models_dtype = np.dtype(
        [('logmet', np.float32),
         ('logage', np.float32)])

    # Let's be pedantic about the procedure so we don't mess up the indexing...
    models = np.zeros(dims, dtype=models_dtype)

    for iage, logage in enumerate(logages):
        for imet, logmet in enumerate(logmets):
            models[iage, imet]['logmet'] = logmet
            models[iage, imet]['logage'] = logage

    models = models.flatten()

    # Build models with and without line-emission.
    meta, wave, flux, linewaves, linefluxes = build_fsps_templates(
        models, logages, agebins=agebins, include_nebular=True, imf=imf)

    _, _, fluxnolines, _, _ = build_fsps_templates(
        models, logages, agebins=agebins, include_nebular=False, imf=imf)

    lineflux = flux - fluxnolines

    # Interpolate the dust emission model(s) to the nominal wavelength array.
    dustflux = np.interp(wave, dustwave, dustflux, left=0.)

    # Write out.
    outdir = os.path.join(templatedir, version)
    os.makedirs(outdir, exist_ok=True)
    outfile = os.path.join(outdir, f'ftemplates-{imf}-{version}.fits')

    hduflux1 = fits.PrimaryHDU(flux)
    hduflux1.header['EXTNAME'] = 'FLUX'
    hduflux1.header['VERSION'] = version
    hduflux1.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduflux2 = fits.ImageHDU(lineflux)
    hduflux2.header['EXTNAME'] = 'LINEFLUX'
    hduflux2.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    # dust emission and AGN spectra
    hduflux3 = fits.ImageHDU(dustflux)
    hduflux3.header['EXTNAME'] = 'DUSTFLUX'
    hduflux3.header['QPAH'] = (qpah, 'PAH fraction')
    hduflux3.header['UMIN'] = (umin, 'minimum radiation field strength')
    hduflux3.header['GAMMA'] = (gamma, 'gamma parameter')
    hduflux3.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduflux4 = fits.ImageHDU(agnflux)
    hduflux4.header['EXTNAME'] = 'AGNFLUX'
    hduflux4.header['AGNTAU'] = (agntau, 'AGN optical depth')
    hduflux4.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduflux5 = fits.ImageHDU(feflux)
    hduflux5.header['EXTNAME'] = 'FEFLUX'
    hduflux5.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduwave1 = fits.ImageHDU(wave)
    hduwave1.header['EXTNAME'] = 'WAVE'
    hduwave1.header['BUNIT'] = 'Angstrom'
    hduwave1.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')
    hduwave1.header['PIXWAVLO'] = (Templates.PIXKMS_BOUNDS[0], 'min(wave) where pixel size is PIXKMS [Angstrom]')
    hduwave1.header['PIXWAVHI'] = (Templates.PIXKMS_BOUNDS[1], 'max(wave) where pixel size is PIXKMS [Angstrom]')
    hduwave1.header['PIXKMS'] = (Templates.PIXKMS, 'pixel size in [PIXWAVLO,PIXWAVHI] [km/s]')

    hduwave2 = fits.ImageHDU(agnwave)
    hduwave2.header['EXTNAME'] = 'AGNWAVE'
    hduwave2.header['BUNIT'] = 'Angstrom'
    hduwave2.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')

    hduwave3 = fits.ImageHDU(fewave)
    hduwave3.header['EXTNAME'] = 'FEWAVE'
    hduwave3.header['BUNIT'] = 'Angstrom'
    hduwave3.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')
    hduwave3.header['PIXSZ'] = (Templates.AGN_PIXKMS, 'pixel size [km/s]')

    # metadata table
    hdutable = fits.convenience.table_to_hdu(meta)
    hdutable.header['EXTNAME'] = 'METADATA'
    hdutable.header['imf'] = imf

    # emission lines; luminosity * L_sun / (4 pi d^2) is a line flux, so
    # there is no per-Angstrom factor here
    hduflux6 = fits.ImageHDU(linefluxes)
    hduflux6.header['EXTNAME'] = 'LINEFLUXES'
    hduflux6.header['BUNIT'] = 'erg/(s cm2)'

    hduwave4 = fits.ImageHDU(linewaves)
    hduwave4.header['EXTNAME'] = 'LINEWAVES'
    hduwave4.header['BUNIT'] = 'Angstrom'
    hduwave4.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')

    # extension order is relied upon by readers; keep it stable
    hx = fits.HDUList([hduflux1, hduflux2, hduflux3, hduwave1, hdutable,
                       hduflux4, hduwave2,
                       hduflux5, hduwave3,
                       hduflux6, hduwave4])

    print(f'Writing {len(models)} model spectra to {outfile}')
    hx.writeto(outfile, overwrite=True)


if __name__ == '__main__':

    import argparse

    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument('--templatedir', required=True, help='Location of original and output templates')
    cli.add_argument('--version', required=True, help='Version number (e.g., 2.0.0)')

    # Draine & Li (2007) dust-emission parameters; see
    # https://python-fsps.readthedocs.io/en/latest/stellarpop_api
    for _flag, _default in (('--qpah', 3.5), ('--umin', 1.), ('--gamma', 0.01)):
        cli.add_argument(_flag, type=float, default=_default, help='DL07 parameter')

    # Nenkova et al. (2008) torus optical depth; see
    # https://python-fsps.readthedocs.io/en/latest/stellarpop_api
    cli.add_argument('--agntau', type=float, default=10., help='Nenkova+08 parameter')

    # FSPS stellar-population parameters
    cli.add_argument('--imf', type=str, default='chabrier', choices=['chabrier', 'salpeter', 'kroupa'],
                     help='Initial mass function')
    cli.add_argument('--logmets', type=str, default='0.', help='Stellar metallicity values')
    cli.add_argument('--test', action='store_true', help='Generate a test set of SPS models.')
    opts = cli.parse_args()

    # NOTE: ``templatedir`` must stay a module-level global with exactly this
    # name, because the build_*_templates functions read it directly.
    templatedir = os.path.expandvars(opts.templatedir)
    original_dir = os.path.join(templatedir, 'original')
    if not os.path.isdir(original_dir):
        raise OSError(f'Missing directory containing original templates {original_dir}')

    main(
        imf=opts.imf,
        logmets=np.array(opts.logmets.split(',')).astype(float),
        qpah=opts.qpah,
        umin=opts.umin,
        gamma=opts.gamma,
        agntau=opts.agntau,
        test=opts.test,
        version=opts.version,
        templatedir=templatedir,
    )
