#!/usr/bin/env python

"""
Create the tiles table for a specprod
"""

import os, sys, glob
import multiprocessing
import argparse
import numpy as np
from astropy.table import Table
import fitsio

from desispec.io import findfile, specprod_root
from desispec.io.meta import faflavor2program
from desispec.io.util import get_tempfilename
from desispec.tsnr import tsnr2_to_efftime
from desiutil.log import get_logger

def get_tile_info(tileid, specprod=None):
    """
    Return row of tiles table for the given tileid

    Args:
        tileid (int): tileid to lookup

    Options:
        specprod (str): override $SPECPROD

    Returns dict with TILEID, SURVEY, PROGRAM ... to stack into a tiles table

    Raises:
        FileNotFoundError: if no cumulative tile-qa or coadd files are found
            for this tile
    """
    log = get_logger()
    log.info(f'Processing tile {tileid}')

    # we don't know the NIGHT, so glob for it
    fakenight = '99999999'
    tileqa_glob = findfile('tileqa', tile=tileid, groupname='cumulative',
                           night=fakenight,
                           specprod_dir=specprod_root(specprod),
                           readonly=True)
    tileqa_glob = tileqa_glob.replace(fakenight, '*')
    tileqa_files = sorted(glob.glob(tileqa_glob))
    if not tileqa_files:
        raise FileNotFoundError(f'No tile-qa files found matching {tileqa_glob}')

    # use the last match in sorted order (presumably the most recent night)
    tileqa_file = tileqa_files[-1]
    ztiledir = os.path.dirname(tileqa_file)

    coadd_glob = f'{ztiledir}/coadd-?-{tileid}-thru*.fits'
    coadd_files = sorted(glob.glob(coadd_glob))
    if not coadd_files:
        raise FileNotFoundError(f'No coadd files found matching {coadd_glob}')

    qahdr = fitsio.read_header(tileqa_file, 'FIBERQA')
    cohdr = fitsio.read_header(coadd_files[0], 'FIBERMAP')

    program = faflavor2program(cohdr['FAFLAVOR']).lower()
    if 'GOALTYPE' in qahdr:
        goaltype = qahdr['GOALTYPE'].lower()
    else:
        log.error(f'{tileqa_file} missing GOALTYPE header keyword; using {program}')
        goaltype = program

    # gather the per-target columns from every coadd file of this tile
    fibermaps = list()
    fibermap_columns = ['COADD_NUMEXP', 'COADD_EXPTIME']
    scores = list()
    scores_columns = ['TSNR2_LRG', 'TSNR2_ELG', 'TSNR2_BGS', 'TSNR2_LYA']
    for coaddfile in coadd_files:
        with fitsio.FITS(coaddfile) as fx:
            fibermaps.append(fx['FIBERMAP'].read(columns=fibermap_columns))
            scores.append(fx['SCORES'].read(columns=scores_columns))

    nexp = np.max( np.concatenate([fm['COADD_NUMEXP'] for fm in fibermaps]) )
    exptime = np.max( np.concatenate([fm['COADD_EXPTIME'] for fm in fibermaps]) )

    tsnr2_lrg = np.concatenate([s['TSNR2_LRG'] for s in scores])
    tsnr2_elg = np.concatenate([s['TSNR2_ELG'] for s in scores])
    tsnr2_bgs = np.concatenate([s['TSNR2_BGS'] for s in scores])
    tsnr2_lya = np.concatenate([s['TSNR2_LYA'] for s in scores])

    # one effective time per tracer, from the median TSNR2 over all targets.
    # NOTE: these must be scalars (no trailing commas), otherwise
    # np.float32(...) below produces length-1 arrays instead of scalars.
    lrg_efftime = tsnr2_to_efftime(np.median(tsnr2_lrg), 'LRG')
    elg_efftime = tsnr2_to_efftime(np.median(tsnr2_elg), 'ELG')
    bgs_efftime = tsnr2_to_efftime(np.median(tsnr2_bgs), 'BGS')
    lya_efftime = tsnr2_to_efftime(np.median(tsnr2_lya), 'LYA')

    #- TODO: efftime_spec definition
    if program == 'dark':
        efftime_spec = lrg_efftime
    else:
        efftime_spec = bgs_efftime

    info = dict(
            TILEID   = qahdr['TILEID'],
            SURVEY   = qahdr['SURVEY'].lower(),
            PROGRAM  = program,
            FAPRGRM  = qahdr['FAPRGRM'].lower(),
            FAFLAVOR = cohdr['FAFLAVOR'].lower(),
            NEXP     = nexp,
            EXPTIME  = exptime,
            TILERA   = qahdr['TILERA'],
            TILEDEC  = qahdr['TILEDEC'],
            EFFTIME_ETC  = np.float32(0.0), # TODO
            EFFTIME_SPEC = np.float32(efftime_spec),
            EFFTIME_GFA  = np.float32(0.0), # TODO
            GOALTIME     = np.float32(qahdr['GOALTIME']),
            OBSSTATUS    = 'obsend',        # TODO
            LRG_EFFTIME_DARK   = np.float32(lrg_efftime),
            ELG_EFFTIME_DARK   = np.float32(elg_efftime),
            BGS_EFFTIME_BRIGHT = np.float32(bgs_efftime),
            LYA_EFFTIME_BRIGHT = np.float32(lya_efftime),
            GOALTYPE  = goaltype,
            MINTFRAC  = np.float32(qahdr['MINTFRAC']),
            LASTNIGHT = qahdr['LASTNITE'],
            )

    return info

def main():
    """
    Build the tiles table for a specprod and write it to --outfile

    Finds the integer-named tile directories under
    ``{specprod}/tiles/cumulative/``, gathers one row per tile in parallel
    with :func:`get_tile_info`, and writes the stacked table to --outfile.
    If --outfile ends in .csv, .fits, or .fits.gz, the table is also written
    in the other format (.fits or .csv) alongside it.
    """
    import argparse
    import functools
    p = argparse.ArgumentParser()
    p.add_argument('-o', '--outfile', required=True, help="output filename")
    p.add_argument('--specprod', required=False, help="specprod to use")
    p.add_argument('--nproc', type=int, default=32, help="number of processes to use")
    p.add_argument('--debug', action="store_true", help="start ipython at the end")
    p.epilog = 'Note: in addition to --outfile, also writes .csv or .fits equivalent'
    args = p.parse_args()

    log = get_logger()

    reduxdir = specprod_root(args.specprod)
    tiledirs = sorted(glob.glob(f'{reduxdir}/tiles/cumulative/*'))
    tileids = list()
    for path in tiledirs:
        try:
            tileids.append( int(os.path.basename(path)) )
        except ValueError:
            # not a TILEID directory; skip anything with a non-integer name
            pass

    tileids = np.array(sorted(tileids))
    log.info(f'Creating tile table for {len(tileids)} tiles')

    # warm up tsnr2_to_efftime ensemble cache before multiprocessing
    # so that they are read only once, instead of once per process
    tsnr2_to_efftime(10, 'elg')

    # get tile info in parallel; forward --specprod so that the workers read
    # from the same production that the tile list came from
    tile_info_func = functools.partial(get_tile_info, specprod=args.specprod)
    with multiprocessing.Pool(args.nproc) as pool:
        tileinfo = pool.map(tile_info_func, tileids)

    results = Table(tileinfo)
    results.meta['EXTNAME'] = 'TILE_COMPLETENESS'

    # write to a temporary file then rename, so that a partially written
    # file never appears under the final name
    tmpfile = get_tempfilename(args.outfile)
    results.write(tmpfile)
    os.rename(tmpfile, args.outfile)
    log.info(f'Wrote {args.outfile}')

    # Pick the alternate-format filename.  Match on the full suffix instead
    # of os.path.splitext, which only splits off the last extension and thus
    # reports '.gz' (never '.fits.gz') for gzipped FITS files.
    altoutfile = None
    for suffix, altsuffix in (('.csv', '.fits'),
                              ('.fits.gz', '.csv'),
                              ('.fits', '.csv')):
        if args.outfile.endswith(suffix):
            altoutfile = args.outfile[:-len(suffix)] + altsuffix
            break
    else:
        extension = os.path.splitext(args.outfile)[1]
        log.info(f'Unrecognized outfile extension {extension}; not writing alternate .csv/.fits file')

    if altoutfile is not None:
        tmpfile = get_tempfilename(altoutfile)
        results.write(tmpfile)
        os.rename(tmpfile, altoutfile)
        log.info(f'Wrote {altoutfile}')

    if args.debug:
        import IPython; IPython.embed()

# Run the command-line entry point only when executed as a script,
# not when this file is imported as a module.
if __name__ == "__main__":
    main()
