#!python

"""Validate the redshift catalogs by making sure every spectrum exists.

desi_validate_zcatalog_spectra --reduxdir /global/cfs/cdirs/desi/spectro/redux/fuji --mp 32 --validate-zpix
desi_validate_zcatalog_spectra --reduxdir /global/cfs/cdirs/desi/spectro/redux/fuji --mp 32 --validate-ztile

"""
import os, sys, argparse, pdb
from glob import glob
import numpy as np
import fitsio
import multiprocessing
from astropy.table import Table
from desispec.io import specprod_root

from desiutil.log import get_logger, DEBUG
# Module-level logger, requested at DEBUG level; main() rebinds its own
# local ``log`` via get_logger() with no level argument.
log = get_logger(DEBUG)

# Value of the DESI_ROOT environment variable, or None when it is unset.
# NOTE(review): not referenced by any code visible in this file.
desi_root = os.environ.get('DESI_ROOT')

def _one_healpix(args):
    """Multiprocessing adapter: unpack ``args`` into a one_healpix() call.

    Pool.map() hands each worker a single object, so the positional
    arguments of one_healpix() travel as one sequence and are expanded here.
    """
    result = one_healpix(*args)
    return result

def _one_tile(args):
    """Multiprocessing adapter: unpack ``args`` into a one_tile() call.

    Pool.map() hands each worker a single object, so the positional
    arguments of one_tile() travel as one sequence and are expanded here.
    """
    result = one_tile(*args)
    return result

def one_healpix(healpix, survey, program, reduxdir):
    """Check that the coadd and redrock files of one healpixel exist.

    The files are looked up as
    ``reduxdir/healpix/survey/program/(healpix//100)/healpix/PREFIX-survey-program-healpix.fits``
    with PREFIX being ``coadd`` and then ``redrock``.

    Args:
        healpix: healpixel number (integer-like; ``//`` is applied to it).
        survey: survey name used in the directory and file names.
        program: program name used in the directory and file names.
        reduxdir: base directory of the spectroscopic production.

    Raises:
        IOError: for the first expected file that is not a regular file.
    """
    pixdir = os.path.join(
        reduxdir, 'healpix', survey, program, str(healpix//100), str(healpix))

    # Lazily build the expected paths so the first missing one aborts the scan.
    expected = (
        os.path.join(pixdir, '{}-{}-{}-{}.fits'.format(kind, survey, program, healpix))
        for kind in ('coadd', 'redrock')
    )
    for path in expected:
        if os.path.isfile(path):
            continue
        raise IOError('Missing {}'.format(path))

def _require_petal_files(filedir, petal, tileid, suffix):
    """Raise IOError unless the coadd and redrock files of one petal exist.

    The expected names are ``PREFIX-petal-tileid-suffix.fits`` inside
    ``filedir``, checked for PREFIX ``coadd`` and then ``redrock``.
    """
    for prefix in ['coadd', 'redrock']:
        specfile = os.path.join(filedir, '{}-{}-{}-{}.fits'.format(prefix, petal, tileid, suffix))
        if not os.path.isfile(specfile):
            raise IOError('Missing {}'.format(specfile))


def one_tile(tileid, zcat, coadd_type, reduxdir):
    """Validate one tile worth of spectra.

    For every group of catalog rows (night, exposure, or SPGRPVAL, depending
    on ``coadd_type``) and every petal present in that group, check that the
    corresponding coadd and redrock files exist under
    ``reduxdir/tiles/coadd_type/tileid/``.

    Args:
        tileid: tile identifier used in the directory and file names.
        zcat: catalog rows of this tile; indexed by column name and then by a
            boolean mask. Must provide ``PETAL_LOC`` plus ``LASTNIGHT``
            (cumulative), ``NIGHT`` (pernight), ``EXPID`` (perexp) or
            ``SPGRPVAL`` (any other coadd type).
        coadd_type: ``cumulative``, ``pernight``, ``perexp``; any other value
            is handled through the ``SPGRPVAL`` column.
        reduxdir: base directory of the spectroscopic production.

    Raises:
        IOError: when an expected file is missing, or (for the SPGRPVAL
            branch) when a petal does not match exactly one file per prefix.
    """
    tiledir = os.path.join(reduxdir, 'tiles', coadd_type, str(tileid))

    if coadd_type == 'cumulative':
        for night in set(zcat['LASTNIGHT']):
            filedir = os.path.join(tiledir, str(night))
            for petal in set(zcat['PETAL_LOC'][night == zcat['LASTNIGHT']]):
                _require_petal_files(filedir, petal, tileid, 'thru{}'.format(night))
    elif coadd_type == 'pernight':
        for night in set(zcat['NIGHT']):
            filedir = os.path.join(tiledir, str(night))
            for petal in set(zcat['PETAL_LOC'][night == zcat['NIGHT']]):
                _require_petal_files(filedir, petal, tileid, '{}'.format(night))
    elif coadd_type == 'perexp':
        for expid in set(zcat['EXPID']):
            # Exposure IDs are zero-padded to 8 digits in both the directory
            # and the file name.
            expstr = '{:08}'.format(expid)
            filedir = os.path.join(tiledir, expstr)
            for petal in set(zcat['PETAL_LOC'][expid == zcat['EXPID']]):
                _require_petal_files(filedir, petal, tileid, 'exp{}'.format(expstr))
    else:
        for spgrpval in set(zcat['SPGRPVAL']):
            filedir = os.path.join(tiledir, str(spgrpval))
            for petal in set(zcat['PETAL_LOC'][spgrpval == zcat['SPGRPVAL']]):
                for prefix in ['coadd', 'redrock']:
                    # The trailing part of the file name is not reconstructed
                    # here, so match it with a wildcard and require a unique hit.
                    specfile = glob(os.path.join(filedir, '{}-{}-{}-{}'.format(prefix, petal, tileid, '*.fits')))
                    # Explicit check rather than ``assert``: assertions are
                    # stripped under ``python -O`` and would skip validation.
                    if len(specfile) != 1:
                        raise IOError('Problematic {}'.format(specfile))
            
def _run_validation(worker, mpargs, nproc):
    """Apply ``worker`` to each entry of ``mpargs``, in parallel if ``nproc > 1``.

    ``worker`` takes one argument list (see _one_healpix / _one_tile); any
    exception it raises propagates to the caller.
    """
    if nproc > 1:
        with multiprocessing.Pool(nproc) as P:
            P.map(worker, mpargs)
    else:
        for mparg in mpargs:
            worker(mparg)


def main():
    """Command-line entry point: validate the zpix- and/or ztile- catalogs.

    Parses ``--reduxdir``, ``--mp``, ``--validate-zpix`` and
    ``--validate-ztile``; for each matching catalog under
    ``reduxdir/zcatalog`` it checks that the coadd and redrock files of
    every healpixel (zpix) or tile (ztile) listed in the catalog exist.

    Raises:
        IOError: propagated from one_healpix() / one_tile() on the first
            missing or ambiguous file.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--reduxdir', type=str, help='spectro redux base dir overrides $DESI_SPECTRO_REDUX/$SPECPROD')
    p.add_argument('--mp', type=int, default=1, help='number of multiprocessing cores')
    p.add_argument('--validate-zpix', action='store_true', help='Validate the zpix- catalogs.')
    p.add_argument('--validate-ztile', action='store_true', help='Validate the ztile- catalogs.')

    args = p.parse_args()
    log = get_logger()

    if args.reduxdir is None:
        args.reduxdir = specprod_root()

    zcatdir = os.path.join(args.reduxdir, 'zcatalog')

    if not (args.validate_zpix or args.validate_ztile):
        log.warning('Nothing to do; specify --validate-zpix and/or --validate-ztile')

    if args.validate_zpix:
        zprefix = 'zpix'
        columns = ['TARGETID', 'HEALPIX', 'TARGET_RA', 'TARGET_DEC']

        zcatfiles = sorted(set(glob(os.path.join(zcatdir, '{}-*.fits'.format(zprefix)))))
        for zcatfile in zcatfiles:
            log.info('Working on redshift catalog {}'.format(zcatfile))
            zcat = Table(fitsio.read(zcatfile, 'ZCATALOG', columns=columns))

            # Survey and program come from the catalog header and select the
            # healpix directory tree to validate.
            hdr = fitsio.read_header(zcatfile, ext='ZCATALOG')
            survey, program = hdr['SURVEY'], hdr['PROGRAM']

            # multiprocess over healpixels
            mpargs = [[healpix, survey, program, args.reduxdir] for healpix in set(zcat['HEALPIX'])]
            _run_validation(_one_healpix, mpargs, args.mp)

    if args.validate_ztile:
        zprefix = 'ztile'
        basecolumns = ['TARGETID', 'TILEID', 'PETAL_LOC', 'TARGET_RA', 'TARGET_DEC']

        # Extra catalog columns needed to locate the files of each coadd type;
        # the insertion order is the processing order.
        extracolumns = {
            'cumulative': ['LASTNIGHT'],
            'perexp': ['NIGHT', 'EXPID'],
            'pernight': ['NIGHT'],
            '1x_depth': ['SPGRPVAL'],
            '4x_depth': ['SPGRPVAL'],
            'lowspeed': ['SPGRPVAL'],
        }

        for coadd_type, extra in extracolumns.items():
            log.info('Working on coadd type {}'.format(coadd_type))
            columns = basecolumns + extra
            zcatfiles = sorted(set(glob(os.path.join(zcatdir, '{}-*{}.fits'.format(zprefix, coadd_type)))))

            for zcatfile in zcatfiles:
                log.info('Working on redshift catalog {}'.format(zcatfile))
                zcat = Table(fitsio.read(zcatfile, 'ZCATALOG', columns=columns))

                # multiprocess over tiles; each task carries only its own rows
                mpargs = [[tileid, zcat[tileid == zcat['TILEID']], coadd_type, args.reduxdir]
                          for tileid in set(zcat['TILEID'])]
                _run_validation(_one_tile, mpargs, args.mp)
    
# Run the validation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
