#!/usr/bin/env python

"""Find exposures with a significant fraction of bad positioners.

time desi_find_badpos --reduxdir /global/cfs/cdirs/desi/spectro/redux/everest --mp 32 -o badpos-everest.fits
time desi_find_badpos --reduxdir /global/cfs/cdirs/desi/spectro/redux/daily --mp 32 -o badpos-daily-20220123.fits

To cross-reference against the existing exposure tables, do:

desi_find_badpos -o badpos-everest.fits -c
desi_find_badpos -o badpos-daily-20220123.fits -c

"""
import os, sys, argparse, pdb
import numpy as np
import fitsio
import multiprocessing
from astropy.table import Table
from desiutil.log import get_logger
from desispec.io import specprod_root, iterfiles

def _badpos_in_fibermap(args):
    """Multiprocessing adapter around :func:`badpos_in_fibermap`.

    ``Pool.map`` hands each worker a single object, so the positional
    arguments are packed into one sequence and unpacked here.

    Args:
        args: sequence of positional arguments for ``badpos_in_fibermap``.

    Returns:
        Whatever ``badpos_in_fibermap`` returns for those arguments.
    """
    result = badpos_in_fibermap(*args)
    return result

def badpos_in_fibermap(fibermapfile):
    """Compute the per-petal fraction of bad positioners in one fibermap.

    A positioner counts as bad when either DELTA_X or DELTA_Y is NaN, or
    when DELTA_X**2 + DELTA_Y**2 exceeds 0.030**2 (presumably mm, i.e. a
    30 micron offset -- confirm against the fibermap data model).

    Args:
        fibermapfile: path to a fibermap FITS file with a FIBERMAP extension.
            The two directory names directly above the file are parsed as
            the integer night and exposure ID, so a
            ``.../NIGHT/EXPID/fibermap-EXPID.fits`` layout with '/'
            separators is required.

    Returns:
        astropy Table with one row per petal found in the file and columns
        NIGHT, EXPID, PETAL_LOC and FBAD (the bad fraction).
    """
    fibermap = fitsio.read(fibermapfile, 'FIBERMAP', columns=['PETAL_LOC', 'DELTA_X', 'DELTA_Y'])

    # Fragile: relies on the .../NIGHT/EXPID/<file> directory convention.
    path_parts = fibermapfile.split('/')
    night, expid = np.array(path_parts[-3:-1]).astype(int)

    petal_loc = fibermap['PETAL_LOC']
    delta_x = fibermap['DELTA_X']
    delta_y = fibermap['DELTA_Y']

    is_nan = np.isnan(delta_x) | np.isnan(delta_y)
    is_far = delta_x**2 + delta_y**2 > 0.030**2
    is_bad = is_nan | is_far

    def _petal_row(petal):
        # Fraction of this petal's positioners flagged as bad.
        on_petal = petal_loc == petal
        npos = np.sum(on_petal) # should always be 500
        fbad = np.sum(on_petal & is_bad) / npos
        return (night, expid, petal, fbad)

    rows = [_petal_row(petal) for petal in set(petal_loc)]

    return Table(rows=rows, names=('NIGHT', 'EXPID', 'PETAL_LOC', 'FBAD'),
                 dtype=(np.int32, np.int32, np.int16, np.float32))

def _parse_bool(value):
    """Convert a command-line string such as 'true' or '0' into a bool.

    ``type=bool`` is not usable with argparse because any non-empty string
    (including 'False') is truthy; this parser accepts the usual spellings
    and rejects everything else.
    """
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _write_badpos_table(args, log):
    """Measure the bad-positioner fraction of every fibermap and write args.outfile."""
    reduxdir = args.reduxdir if args.reduxdir is not None else specprod_root()
    preprocdir = f'{reduxdir}/preproc'

    # Skip 'old' and 'stash' files *before* counting, so that the logged
    # number and the empty-list check reflect what is actually processed.
    fibermapfiles = sorted(fibermapfile for fibermapfile in iterfiles(preprocdir, 'fibermap-')
                           if 'old' not in fibermapfile and 'stash' not in fibermapfile)

    n = len(fibermapfiles)
    if n == 0:
        log.error(f'No fibermaps found in {preprocdir}')
        sys.exit(1)
    log.info(f'Processing {n} fibermaps from {preprocdir}')

    if args.mp > 1:
        mpargs = [[fibermapfile] for fibermapfile in fibermapfiles]
        with multiprocessing.Pool(args.mp) as pool:
            out = pool.map(_badpos_in_fibermap, mpargs)
    else:
        out = [badpos_in_fibermap(fibermapfile) for fibermapfile in fibermapfiles]

    # stack and write out
    out = Table(np.hstack(out))
    out.meta['EXTNAME'] = 'FBADPOS'
    out.write(args.outfile, overwrite=True)


def _check_exposure_tables(outfile, thresh, log):
    """Cross-reference petals with FBAD > thresh in outfile against the exposure tables."""
    log.info('Reading {}'.format(outfile))
    allpos = Table.read(outfile)

    pos = allpos[allpos['FBAD'] > thresh]
    if len(pos) == 0:
        log.info('No bad petals found with threshold = {:.4f}'.format(thresh))
        return
    log.info('Flagging {} bad petals with threshold = {:.4f}'.format(len(pos), thresh))

    # Fail with a clear message rather than a TypeError inside os.path.join.
    desi_root = os.environ.get('DESI_ROOT')
    if desi_root is None:
        log.error('Environment variable DESI_ROOT must be set to locate the exposure tables')
        sys.exit(1)

    # Loop on each unique night (sorted so the report order is reproducible)
    for night in sorted(set(pos['NIGHT'])):
        strnight = str(night)
        expfile = os.path.join(desi_root, 'spectro', 'redux', 'exptabs', 'exposure_tables',
                               strnight[:6], 'exposure_table_{}.csv'.format(strnight))
        nightpos = pos[pos['NIGHT'] == night]

        # If the table exists, read it and cross-reference against the identified petals
        if os.path.exists(expfile):
            exps = Table.read(expfile)
            # Let's be conservative and loop one at a time.
            for expid in sorted(set(nightpos['EXPID'])):
                if not np.any(exps['EXPID'] == expid): # not accounted for
                    log.info('Exposure {} on night {} is not in {}'.format(expid, night, expfile))
        else:
            log.info('Need new exposure table with the following content:')
            log.info(expfile)
            for onepos in nightpos:
                print('{} {} {}'.format(onepos['NIGHT'], onepos['EXPID'], onepos['PETAL_LOC']))
            print()


def main():
    """Command-line entry point: build and/or cross-check the bad-positioner table.

    The per-petal bad fractions are (re)computed and written to ``--outfile``
    when that file does not exist or ``--overwrite`` is given.  With
    ``--check-exposure-tables`` the petals whose bad fraction exceeds
    ``--thresh`` are then compared against the per-night exposure tables
    under ``$DESI_ROOT/spectro/redux/exptabs/exposure_tables``.

    Exits with status 1 when no fibermaps are found, or when the
    cross-check needs ``$DESI_ROOT`` and it is not set.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--reduxdir', type=str, help='spectro redux base dir overrides $DESI_SPECTRO_REDUX/$SPECPROD')
    p.add_argument('--mp', type=int, default=1, help='number of multiprocessing cores')
    p.add_argument('--thresh', type=float, default=0.99, help='Threshold for identifying bad exposures.')
    p.add_argument('-o', '--outfile', type=str, required=True, help='output FITS file')
    p.add_argument('-c', '--check-exposure-tables', action='store_true', help='Cross-reference against existing exposure tables.')
    # Works as a bare flag (--overwrite) and, for backward compatibility,
    # with an explicit value (--overwrite True / --overwrite False).
    p.add_argument('--overwrite', type=_parse_bool, nargs='?', const=True, default=False,
                   help='Overwrite an existing file.')

    args = p.parse_args()
    log = get_logger()

    outfile_exists = os.path.isfile(args.outfile)
    if outfile_exists and not args.overwrite and not args.check_exposure_tables:
        log.info('Output file {} exists; use --overwrite'.format(args.outfile))
        return

    if not outfile_exists or args.overwrite:
        _write_badpos_table(args, log)

    # check / cross-reference against the exposure tables
    if args.check_exposure_tables:
        _check_exposure_tables(args.outfile, args.thresh, log)

if __name__ == '__main__':
    main()
