#!/usr/bin/env python

"""
Report status of batch queue
"""

import os, sys, time
import subprocess
from astropy.table import Table

import argparse
#- Command line options
parser = argparse.ArgumentParser()
#- Use .get() so that a missing $USER does not raise KeyError while the
#- parser is still being built; that would crash even when -u/--user is given
parser.add_argument('-u', '--user', default=os.environ.get('USER'),
                    help='check queue for this user')
parser.add_argument('--currentjobs', help='save current jobs to this file')
parser.add_argument('--pastjobs', help='save past jobs from today to this file')
parser.add_argument('-S', '--start', help='track jobs starting after this timestamp')
parser.add_argument('-E', '--end', help='track jobs ending before this timestamp')
parser.add_argument('--debug', action='store_true',
                    help='print slurm commands and start ipython prompt at end')
args = parser.parse_args()

#- Without a user there is nothing to query; exit with a usage message
#- instead of passing None on to the slurm commands
if args.user is None:
    parser.error('$USER is not set; please specify the user with -u/--user')

#- Get past jobs
# TODO: these don't include CANCELLED jobs; figure out how to get those too
#- Build the argument list explicitly instead of splitting a formatted string,
#- so that args.user is always passed as exactly one argument
cmd = ['sacct', '-u', args.user, '-X',
       '--format', 'jobid,jobname%50,state,end', '--noheader']
#- Optional time window, passed straight through to sacct
if args.start is not None:
    cmd.extend(['-S', args.start])
if args.end is not None:
    cmd.extend(['-E', args.end])

if args.debug:
    print(' '.join(cmd))

#- One stripped text line per job; blank lines (e.g. after the final newline)
#- are kept here and skipped by the code that consumes this list
pastjobs = [line.decode().strip() for line in subprocess.check_output(cmd).split(b'\n')]

#- Query the queue for this user's running/pending jobs.
#- Columns requested: job id, partition, job name, state, reason/nodelist
squeue_format = '%.8i %.9P %.50j %.10T %R'
cmd = ['squeue', '-u', args.user, '--format', squeue_format, '--noheader']

if args.debug:
    print(' '.join(cmd))

#- Keep one stripped text line per job; blank lines are retained here
#- and filtered out by the consumers of this list
squeue_output = subprocess.check_output(cmd)
currentjobs = [raw.decode().strip() for raw in squeue_output.split(b'\n')]

#- Count states of the current jobs
#-   num_running : jobs in state RUNNING
#-   num_pending : PENDING jobs whose reason is (Priority)
#-   num_waiting : PENDING jobs whose reason is (Dependency)
#- PENDING jobs with any other reason are printed with a "???" prefix and not
#- counted; jobs in any other state are ignored.
num_pending = 0
num_waiting = 0
num_running = 0
for line in currentjobs:
    if not line:
        continue

    #- The reason is the last column and may itself contain spaces,
    #- e.g. "(launch failed requeued held)", so split off only the
    #- first four columns and keep the remainder intact
    fields = line.split(maxsplit=4)
    if len(fields) != 5:
        #- Unexpected layout: report it rather than crash on unpacking
        print(f"??? {line}")
        continue

    jobid, partition, jobname, state, reason = fields

    #- Jobs in the cron and workflow partitions are not counted
    if partition in ('cron', 'workflow'):
        continue

    if state == 'PENDING':
        if reason == '(Dependency)':
            num_waiting += 1
        elif reason == '(Priority)':
            num_pending += 1
        else:
            print(f"??? {line}")
    elif state == 'RUNNING':
        num_running += 1

#- Count stats per hour
#- statecount_per_hour maps the first 13 characters of each job's end time
#- (presumably "YYYY-MM-DDTHH"; confirm against sacct output) to a dict with
#- that HOUR plus COMPLETED / FAILED / TIMEOUT counters.
statecount_per_hour = dict()
for line in pastjobs:
    if not line:
        continue

    #- The job id is the first token and state/end time are the last two;
    #- everything in between is the job name, which may contain spaces.
    #- NOTE(review): assumes the state and end columns never contain whitespace
    tokens = line.split()
    if len(tokens) < 4:
        #- Unexpected layout: report it rather than crash on unpacking
        print(f"??? {line}")
        continue

    jobid, state, endtime = tokens[0], tokens[-2], tokens[-1]
    jobname = ' '.join(tokens[1:-2])

    #- Exclude scron jobs (nightwatch, svn and dashboard updates, etc)
    #- and jobs that have not finished yet
    if jobname.startswith('scron') or state in ('PENDING', 'RUNNING'):
        continue

    hour = endtime[0:13]
    if hour not in statecount_per_hour:
        statecount_per_hour[hour] = dict(HOUR=hour, COMPLETED=0, FAILED=0, TIMEOUT=0)

    #- Only states that have a counter are tallied; others are ignored
    if state in statecount_per_hour[hour]:
        statecount_per_hour[hour][state] += 1

#- One table row per hour, in order of first appearance in the sacct output
job_history = Table(data=list(statecount_per_hour.values()))

#- Print the report: timestamp, current queue summary, then per-hour history
print(time.asctime())
summary = f"{num_running} jobs running, {num_pending} pending, and {num_waiting} waiting on dependencies"
print(summary)

if len(job_history) == 0:
    print("No recent jobs")
else:
    print("Recent job completion history:")
    job_history.pprint_all()

#- Optionally save the past jobs as a table with one row per job
if args.pastjobs is not None:
    names = ['JOBID', 'JOBNAME', 'STATE', 'ENDTIME']
    rows = []
    for entry in pastjobs:
        tokens = entry.split()
        if len(tokens) < 4:
            #- blank line or unexpected layout; skip so all rows have 4 columns
            continue
        #- First token is the job id, last two are state and end time;
        #- the job name is whatever lies in between (it may contain spaces)
        rows.append([tokens[0], ' '.join(tokens[1:-2]), tokens[-2], tokens[-1]])

    if rows:
        jobtable = Table(rows=rows, names=names)
    else:
        #- With no rows the column count cannot be inferred from the data,
        #- so build the empty table from names and dtypes instead
        jobtable = Table(names=names, dtype=[str] * len(names))

    #- NOTE(review): no overwrite option is passed, so this may fail if the
    #- output file already exists, depending on the format -- confirm intent
    jobtable.write(args.pastjobs)
    print(f'Saved past job table to {args.pastjobs}')

#- Optionally save the current jobs as a table with one row per job
if args.currentjobs is not None:
    names = ['JOBID', 'PARTITION', 'JOBNAME', 'STATE', 'REASON']
    rows = []
    for entry in currentjobs:
        #- The reason is the last column and may contain spaces, so split off
        #- only the first four columns and keep the remainder as one field
        fields = entry.split(maxsplit=4)
        if len(fields) != len(names):
            #- blank line or unexpected layout; skip so all rows have 5 columns
            continue
        rows.append(fields)

    if rows:
        jobtable = Table(rows=rows, names=names)
    else:
        #- With no rows the column count cannot be inferred from the data,
        #- so build the empty table from names and dtypes instead
        jobtable = Table(names=names, dtype=[str] * len(names))

    #- NOTE(review): no overwrite option is passed, so this may fail if the
    #- output file already exists, depending on the format -- confirm intent
    jobtable.write(args.currentjobs)
    print(f'Saved current job table to {args.currentjobs}')

#- In debug mode, finish by opening an interactive IPython prompt.
#- IPython is imported only here, so it is needed only when --debug is used.
if args.debug:
    import IPython
    IPython.embed()



