#!/usr/bin/env python
#
# See top-level LICENSE.rst file for Copyright information
#
# -*- coding: utf-8 -*-

"""
This script is used to do a visual validation of the tiles based on the automated QA results.
"""


import os,sys
import argparse
import glob
from getpass import getuser
import numpy as np
import multiprocessing
from astropy.table import Table
import fitsio
import subprocess
import datetime

from desiutil.log import get_logger
from desispec.io import specprod_root, findfile
from desispec.util import parse_int_args


def parse(options=None):
    """Build the command-line parser and parse the arguments.

    Args:
        options: optional list of argument strings to parse instead of the
            process command line; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(
                description="Tile Visual Inspection")
    parser.add_argument('--prod', type = str, default = None, required=False,
                        help = 'Path to input reduction, e.g. /global/cfs/cdirs/desi/spectro/redux/blanc/,  or simply prod version, like blanc, but requires env. variable DESI_SPECTRO_REDUX. Default is $DESI_SPECTRO_REDUX/$SPECPROD.')
    parser.add_argument('--qa-dir', type = str, default = None, required=False,
                        help = 'Path to qa directory, default is the same as the prod directory')
    # NOTE: this help text used to say "nights" (copy/paste from --nights)
    parser.add_argument('-t','--tileids', type = str, default = None, required=False,
                        help = 'Comma, or colon separated list of tileids to process. ex: 12,14 or 12:23')
    parser.add_argument('-n','--nights', type = str, default = None, required=False,
                        help = 'Comma, or colon separated list of nights to process. ex: 20210501,20210502 or 20210501:20210531')
    parser.add_argument('-i','--infile', type = str, default = 'tiles-specstatus.ecsv', required=False,
                        help = 'Specific input tile file, default is tiles-specstatus.ecsv')
    parser.add_argument('-o','--outfile', type = str, default = None, required=False,
                        help = 'Output tile file, default is same as input')
    parser.add_argument('-u','--user', type = str, default = getuser(), required=False,
                        help = 'your name or initials')
    parser.add_argument('--viewer', type = str, default = "eog", required=False,
                        help = 'image viewer (default is eog)')
    parser.add_argument('--qastatus', type=str, default='none', required=False,
                        help = ('only inspect tiles with this QA status '
                                '(e.g. none, unsure, good, bad, all).  '
                                'Default of none shows only new tiles.'))
    parser.add_argument('--survey', type = str, default = 'main', required=False,
                        help = 'look only at tiles from this survey')
    parser.add_argument('--program', type=str, nargs='+', required=False,
                        default=['dark', 'bright', 'dark1b', 'bright1b'],
                        help='look only at tiles from these programs')

    # parse_args(None) falls back to sys.argv[1:], so no branching is needed
    return parser.parse_args(options)

# user interaction wrappers
def input2(message):
    '''Prompt the user and echo both the prompt and the reply to stdout.

    The message is printed with a ``PROMPT:`` prefix, the reply is read
    after a ``>>> `` marker, and an empty reply is replaced by the
    placeholder ``(no entry)`` before being echoed and returned.
    '''
    print('PROMPT: {}'.format(message))
    # an empty string is falsy, so `or` swaps in the placeholder
    reply = input('>>> ') or '(no entry)'
    print('USER ENTERED >>> {}'.format(reply))
    return reply

# Handle on the image-viewer subprocess; main() rebinds it through
# `global viewer`.  None means no viewer is currently open.
viewer = None
# Replies that count as "nothing entered": the raw empty string and the
# '(no entry)' placeholder that input2() substitutes for it.
blank_responses = {'(no entry)', ''}

def yesnohelp(message,_pass=True, _skip=True, _help=True):
    '''Ask a question until the user gives a recognized answer.

    The reply is stripped and lower-cased, then mapped to one of the
    canonical strings below.  An unrecognized reply re-issues the same
    prompt with the same set of enabled options.

    Args:
        message: question shown to the user.
        _pass: if True, offer "unsure" (accepts ``u``/``unsure``).
        _skip: if True, offer "skip" (accepts ``s``/``skip``).
        _help: if True, offer "help" (accepts ``h``/``help``/``doc``/``man``).

    Returns:
        One of ``'yes'``, ``'no'``, ``'pass'`` (the user was unsure),
        ``'skip'`` or ``'help'``; the last three only when the matching
        option is enabled.
    '''
    suffix_parts = ['y','n']
    if _pass: suffix_parts.append('unsure')
    if _help: suffix_parts.append('help')
    if _skip: suffix_parts.append('skip')
    prompt = f'{message} ({"/".join(suffix_parts)})'

    # Loop rather than recurse: the old recursive retry called
    # yesnohelp(message) and silently re-enabled every option.
    while True:
        answer = input2(prompt).strip().lower()
        if _help and answer in {'h', 'help', 'doc', 'man'}:
            return 'help'
        if answer in {'n', 'no', 'false'}:
            return 'no'
        if _pass and answer in {'u', 'unsure'}:
            return 'pass'
        if answer in {'y', 'yes', 'true'}:
            return 'yes'
        if _skip and answer in {'s', 'skip'}:
            return 'skip'


def print_summary(new, orig):
    """Print a table counting good / unsure / bad tiles per age class and program.

    Arguments are the new and original tile tables.  A tile is "new" when
    its QA is no longer 'none' in `new` but was 'none' in `orig`; any other
    tile is "old" when its LASTNIGHT is within the last 14 days and
    "very old" otherwise.
    """
    from datetime import date, timedelta

    cutoff = date.today() - timedelta(days=14)
    # same YYYYMMDD integer encoding as the LASTNIGHT column is compared in
    cutoff_night = cutoff.year * 10000 + cutoff.month * 100 + cutoff.day

    qa = new['QA']
    just_reviewed = (qa != 'none') & (orig['QA'] == 'none')
    recent = ~just_reviewed & (new['LASTNIGHT'] >= cutoff_night)
    stale = ~(just_reviewed | recent)

    print('Tile status')

    age_masks = (
        ('new', just_reviewed),
        ('old', recent),
        ('very old', stale),
    )
    flavor = new['FAFLAVOR']
    known_programs = ['dark', 'bright', 'dark1b', 'bright1b']
    program_masks = (
        ('dark', flavor == 'maindark'),
        ('bright', flavor == 'mainbright'),
        ('dark1b', flavor == 'maindark1b'),
        ('bright1b', flavor == 'mainbright1b'),
        ('other', (new['SURVEY'] == 'main') &
                  (~np.isin(new['FAPRGRM'], known_programs))),
    )

    statuses = ('good', 'unsure', 'bad')
    table = []
    for age_label, age_mask in age_masks:
        for program_label, program_mask in program_masks:
            both = age_mask & program_mask
            counts = tuple(np.sum((qa == status) & both)
                           for status in statuses)
            table.append((age_label + ' ' + program_label,) + counts)

    print('%20s %8s %8s %8s' % (('type',) + statuses))
    for entry in table:
        print('%20s %8d %8d %8d' % entry)



def _stop_viewer():
    """Kill and reap the image-viewer subprocess, if one is running."""
    global viewer
    if viewer is not None:
        viewer.kill()
        # wait() reaps the child so it does not linger as a zombie
        viewer.wait()
        viewer = None


def _record_verdict(tiles_table, index, verdict, user, outfile, override=False):
    """Store one reviewer verdict in the tiles table and write it to disk.

    Sets QA, USER and QANIGHT (copied from LASTNIGHT) for row `index`,
    flags OVERRIDE when the reviewer disagreed with the automated answer,
    then rewrites `outfile`.
    """
    tiles_table["QA"][index] = verdict
    tiles_table["USER"][index] = user
    tiles_table["QANIGHT"][index] = tiles_table["LASTNIGHT"][index]
    if override:
        tiles_table["OVERRIDE"][index] = 1
        print("You have overridden the automated validation.")
    tiles_table.write(outfile, overwrite=True)


def main():
    """Interactively review tile QA plots and record the verdicts.

    Reads the tiles table named by --infile, selects the tiles to inspect,
    and for each one opens its QA plot in the configured viewer and asks
    whether the tile is valid.  Each answer other than "skip" is written
    to --outfile immediately.  A summary is printed at the end.

    Exits with status 12 when the input table is missing or no tile is
    selected.
    """
    global viewer

    log = get_logger()

    args=parse()

    if args.prod is None:
        args.prod = specprod_root()
    elif args.prod.find("/")<0 :
        args.prod = specprod_root(args.prod)
    if args.qa_dir is None :
        args.qa_dir = args.prod
    log.info('prod    = {}'.format(args.prod))
    log.info('qa dir  = {}'.format(args.qa_dir))

    if args.outfile is None :
        args.outfile = args.infile

    if args.tileids is not None:
        requested_tileids = parse_int_args(args.tileids)
    else:
        requested_tileids = None
    if args.nights is not None :
        requested_nights = parse_int_args(args.nights)
    else :
        requested_nights = None

    if not os.path.isfile(args.infile) :
        print("ERROR: No tiles table file {}".format(args.infile))
        print("Add or change --prod or --infile ... ?")
        sys.exit(12)
    tiles_table = Table.read(args.infile)

    # Make sure every column this script reads or writes exists.
    if not "QA" in tiles_table.dtype.names :
        tiles_table["QA"]=np.array(np.repeat("none",len(tiles_table)),dtype='<U20')
    if not "USER" in tiles_table.dtype.names :
        tiles_table["USER"]=np.array(np.repeat("none",len(tiles_table)),dtype='<U20')
    if not "OVERRIDE" in tiles_table.dtype.names :
        tiles_table["OVERRIDE"]=np.zeros(len(tiles_table),dtype=int)
    if not "QANIGHT" in tiles_table.dtype.names :
        # QANIGHT is read and written below; without this a table lacking
        # the column raised KeyError on the first verdict
        tiles_table["QANIGHT"]=np.zeros(len(tiles_table),dtype=int)
    # widen QA so longer values such as "unsure" are not truncated
    tiles_table['QA'] = tiles_table['QA'].astype('<U20')

    # Snapshot only after the default columns exist: print_summary()
    # indexes orig['QA'], which a fresh table would not have.
    tiles_table_orig = tiles_table.copy()

    log.info("Found {} tiles in {}".format(len(tiles_table),args.infile))
    selection = (tiles_table["ZDONE"]=="false")
    if requested_tileids is None:
        # if the user has specifically requested some tileids, show them
        # to them!  Only filter if requested_tileids is not None
        selection &= (tiles_table["EFFTIME_SPEC"]>=(tiles_table["GOALTIME"]*tiles_table["MINTFRAC"]))
        if args.survey is not None:
            selection &= (tiles_table["SURVEY"]==args.survey)
        if args.program is not None:
            selection &= np.isin([x.lower() for x in tiles_table['FAPRGRM']],
                                 [x.lower() for x in args.program])
        if args.qastatus != 'all':
            if not np.any(tiles_table['QA'] == args.qastatus):
                print('No tiles matching qastatus, possible options are ',
                      np.unique(tiles_table['QA']))
                return
            selection &= (tiles_table["QA"]==args.qastatus)

    log.info("Includes {} tiles with ZDONE=false but EFFTIME_SPEC>=GOALTIME*MINTFRAC".format(np.sum(selection)))

    tileids = tiles_table["TILEID"][selection]
    nights  = tiles_table["LASTNIGHT"][selection]

    ok = np.repeat(True,tileids.size)
    if requested_tileids is not None :
        ok &= np.isin(tileids,requested_tileids)
    if requested_nights is not None :
        ok &= np.isin(nights,requested_nights)
    tileids = tileids[ok]
    nights  = nights[ok]

    if len(tileids)==0 :
        print("no tile/night selected")
        sys.exit(12)
    log.info("Remains {} tiles after selection of tileid/night:".format(tileids.size))

    # yesnohelp() answers that translate into a stored QA value
    verdicts = {"yes": "good", "no": "bad", "pass": "unsure"}

    viewer=None
    try:
        for tileid,night in zip(tileids,nights) :
            log.info("Inspecting TILE={} NIGHT={} ...".format(tileid,night))

            # find qa plot ...
            qaplot_filename = findfile("tileqapng",night=night,tile=tileid,specprod_dir=args.qa_dir,groupname="cumulative")
            if not os.path.isfile(qaplot_filename) :
                log.error("Missing {}, did you run desi_tile_qa?".format(qaplot_filename))
                continue

            # None means "no automated answer available"
            software_answer = None
            # find qa fits ...
            qafits_filename = findfile("tileqa",night=night,tile=tileid,specprod_dir=args.qa_dir,groupname="cumulative")
            # check the FITS file itself (this used to re-test the plot)
            if os.path.isfile(qafits_filename) :
                qahead = fitsio.read_header(qafits_filename,"FIBERQA")
                if "VALID" in qahead :
                    # assumed to be a boolean-like value -- TODO confirm
                    software_answer = qahead["VALID"]
                else :
                    log.warning("no VALID header keyword in {} HDU FIBERQA".format(qafits_filename))
            else :
                log.warning("no {}".format(qafits_filename))

            print("Showing {} ...".format(qaplot_filename))
            if software_answer is None :
                print("The code has no automated validation for this tile")
            elif software_answer :
                print("The code thinks it's a valid tile")
            else :
                print("The code thinks it's NOT a valid tile")
            index=np.where(tiles_table["TILEID"]==tileid)[0][0]
            row = tiles_table[index]
            if row['QA'] != 'none':
                print(f'Note existing QA for this tile of {row["QA"]} by '
                      f'{row["USER"]} on {row["QANIGHT"]}!')

            _stop_viewer()
            # DEVNULL instead of PIPE: nobody reads the viewer's output, and
            # an unread pipe can fill up and block the child
            viewer = subprocess.Popen([args.viewer,qaplot_filename ],
                                      stdout=subprocess.DEVNULL,
                                      stderr=subprocess.DEVNULL)

            while True :
                res=yesnohelp("Is tile {} a valid tile?".format(tileid))
                if res=="help" :
                    print("y: yes, this tile is valid.")
                    print("n: no, it's bad and we should discard one or several exposures or fix the processing.")
                    print("skip: skip this one; just show the image and move on.")
                    print("unsure: mark this one as unsure.")
                elif res in verdicts :
                    if row['QA'] != 'none':
                        print(f'Overwriting old QA of {row["QA"]} by {row["USER"]} on '
                              f'{row["QANIGHT"]}!')
                    # only a definite yes/no can contradict the software;
                    # truthiness (not identity) also handles non-Python bools
                    override = (software_answer is not None and
                                ((res=="yes" and not software_answer) or
                                 (res=="no" and bool(software_answer))))
                    _record_verdict(tiles_table, index, verdicts[res],
                                    args.user, args.outfile, override=override)
                    break
                elif res=="skip" :
                    break

            _stop_viewer()
    finally:
        # do not leave a viewer window behind on an exception or Ctrl-C
        _stop_viewer()
    print_summary(tiles_table, tiles_table_orig)


# Script entry point: run main() and use its return value as the process
# exit status (None, which main() returns, maps to exit status 0).
if __name__ == '__main__':
    raise SystemExit(main())
