#!python
#
# See top-level LICENSE.rst file for Copyright information
#
# -*- coding: utf-8 -*-

"""
This script computes the sky magnitudes (g, r, z) of DESI exposures and writes them to a summary table.
"""


import os,sys
import argparse
import glob
import numpy as np
import multiprocessing
from astropy.table import Table

from desiutil.log import get_logger
from desispec.io import specprod_root
from desispec.skymag import compute_skymag
from desispec.util import parse_int_args



def parse(options=None):
    """
    Parse the command-line arguments of this script.

    Args:
        options (list of str, optional): argument strings to parse.
            When None, argparse reads them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with attributes ``outfile``, ``update``, ``prod``,
        ``expids``, ``nights`` and ``nproc``.
    """
    parser = argparse.ArgumentParser(
                description="Compute sky magnitudes for exposures")
    parser.add_argument('-o','--outfile', type=str, default=None, required=False,
                        help = 'Output summary file')
    parser.add_argument('--update', action = 'store_true',
                        help = 'Update pre-existing output summary file (replace or append)')
    parser.add_argument('--prod', type = str, default = None, required=False,
                        help = 'Path to input reduction, e.g. /global/cfs/cdirs/desi/spectro/redux/blanc/,  or simply prod version, like blanc, but requires env. variable DESI_SPECTRO_REDUX. Default is $DESI_SPECTRO_REDUX/$SPECPROD.')
    parser.add_argument('-e','--expids', type = str, default = None, required=False,
                        help = 'Comma separated list of exp ids to process')
    parser.add_argument('-n','--nights', type = str, default = None, required=False,
                        help = 'Comma, or colon separated list of nights to process. ex: 20210501,20210502 or 20210501:20210531')
    parser.add_argument('--nproc', type = int, default = 1,
                        help = 'Multiprocessing.')

    # parse_args(None) falls back to sys.argv[1:], so there is no need to
    # branch on whether `options` was given.
    return parser.parse_args(options)




def func(night, expid, specprod_dir):
    """
    Compute the sky magnitudes of one exposure and pack them into a table row.

    Args:
        night: night of the exposure, forwarded to compute_skymag
        expid: exposure id, forwarded to compute_skymag
        specprod_dir: production directory, forwarded to compute_skymag

    Returns:
        dict with keys NIGHT, EXPID, SKY_MAG_G, SKY_MAG_R and SKY_MAG_Z; the
        three magnitudes are elements 0, 1 and 2 of the compute_skymag result.
    """
    logger = get_logger()
    sky_mags = compute_skymag(night, expid, specprod_dir)

    row = {'NIGHT': night, 'EXPID': expid}
    row['SKY_MAG_G'] = sky_mags[0]
    row['SKY_MAG_R'] = sky_mags[1]
    row['SKY_MAG_Z'] = sky_mags[2]

    logger.info(str(row))
    return row

def _func(arg):
    """
    Single-argument adapter around func, usable with multiprocessing Pool.map.

    Args:
        arg (dict): keyword arguments of func (night, expid, specprod_dir)

    Returns:
        whatever func returns for these keyword arguments
    """
    result = func(**arg)
    return result

def _int_named_entries(pattern, log):
    """
    Return the basenames, converted to int, of the paths matching a glob pattern.

    Paths are visited in sorted order; a path whose basename is not an
    integer is skipped with a warning.
    """
    values = []
    for path in sorted(glob.glob(pattern)):
        try:
            values.append(int(os.path.basename(path)))
        except ValueError:
            log.warning("ignore {}".format(path))
    return values


def main():
    """
    Entry point: compute sky magnitudes for the selected exposures.

    Steps:
      1. parse the command line and resolve the production directory;
      2. build the list of nights (from --nights, or from the integer-named
         sub-directories of <prod>/exposures);
      3. for each night, list the integer-named exposure directories,
         optionally restricted to --expids, and call func on each of them
         (serially when --nproc is 1, otherwise with a multiprocessing pool);
      4. write all rows to --outfile, overwriting any existing file.

    NOTE: the --update option is parsed but not acted upon; the output file
    is always overwritten.
    """
    log = get_logger()

    args = parse()

    if args.prod is None:
        args.prod = specprod_root()
    elif args.prod.find("/") < 0:
        # a bare production name (no path separator) is resolved by specprod_root
        args.prod = specprod_root(args.prod)

    log.info('prod    = {}'.format(args.prod))
    log.info('outfile = {}'.format(args.outfile))

    if args.outfile is None:
        # Say so up front instead of failing in Table.write after all the work.
        log.warning("no --outfile given, the summary table will not be written")
    if args.update:
        log.warning("--update is not implemented, the output file will be overwritten")

    if args.expids is not None:
        # builtin int: the np.int alias was removed in NumPy 1.24
        expids = [int(x) for x in args.expids.split(',')]
    else:
        expids = None

    if args.nights is None:
        nights = _int_named_entries('{}/exposures/*'.format(args.prod), log)
    else:
        nights = parse_int_args(args.nights)

    log.info("nights = {}".format(nights))
    if expids is not None:
        log.info('expids = {}'.format(expids))

    summary_rows = list()

    for night in nights:

        night_expids = _int_named_entries(
            '{}/exposures/{}/*'.format(args.prod, night), log)
        if expids is not None:
            night_expids = np.intersect1d(expids, night_expids)
            if night_expids.size == 0:
                continue
        log.info("{} {}".format(night, night_expids))

        func_args = [{'night': night, 'expid': expid, 'specprod_dir': args.prod}
                     for expid in night_expids]

        if args.nproc == 1:
            for func_arg in func_args:
                entry = func(**func_arg)
                if entry is not None:
                    summary_rows.append(entry)
        else:
            log.info("Multiprocessing with {} procs".format(args.nproc))
            pool = multiprocessing.Pool(args.nproc)
            try:
                results = pool.map(_func, func_args)
            finally:
                # release the worker processes even if a task raised
                pool.close()
                pool.join()
            for entry in results:
                if entry is not None:
                    summary_rows.append(entry)

    if len(summary_rows) == 0:
        log.error("no data")
        return

    if args.outfile is None:
        # already warned above; the rows were logged by func
        return

    colnames = list(summary_rows[0].keys())
    table = Table(rows=summary_rows, names=colnames)
    table.write(args.outfile, overwrite=True)
    log.info("wrote {}".format(args.outfile))
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
