#!/usr/bin/env python

"""
Utility for moving jobs from the regular queue into a reservation, with limits
"""

import os, sys
import numpy as np
import subprocess
import json
import time

from desispec.workflow.queue import get_jobs_in_queue
from desiutil.log import get_logger

def get_reservation_info(name):
    """
    Return dictionary of reservation info from "scontrol show res NAME --json"

    Args:
        name (str): name of the reservation

    Returns:
        dict: the single entry of the "reservations" list in the scontrol
        JSON output.

    Exits via sys.exit with code 1 if no reservation is returned, 2 if more
    than one is returned, and 4 if the returned reservation has a different
    name than requested.  A non-zero scontrol exit status propagates as
    subprocess.CalledProcessError.
    """
    log = get_logger()

    #- Human-readable form of the command, used only for log messages
    cmd = f"scontrol show res {name} --json"
    log.info(f'Getting reservation info with: {cmd}')

    #- Pass arguments as a list so that the reservation name is always a single
    #- argument to scontrol, even if it were to contain whitespace
    args = ['scontrol', 'show', 'res', str(name), '--json']
    resinfo = json.loads(subprocess.check_output(args))

    reservations = resinfo['reservations']
    if len(reservations) == 0:
        log.critical(f'reservation {name} not found')
        sys.exit(1)
    elif len(reservations) > 1:
        log.critical(f'"{cmd}" returned info for more than one reservation')
        sys.exit(2)

    resinfo = reservations[0]

    #- Explicit check instead of assert so that it still runs under "python -O"
    if resinfo['name'] != name:
        log.critical(f'"{cmd}" returned reservation {resinfo["name"]} instead of {name}')
        sys.exit(4)

    return resinfo

def use_reservation(name=None, resinfo=None, extra_nodes=0, dry_run=False):
    """
    Move jobs from regular queue into a reservation

    Options:
        name (str): name of the reservation
        resinfo (dict): dictionary of reservation info from get_reservation_info
        extra_nodes (int): over-fill the reservation by this many nodes
        dry_run (bool): if True, print scontrol commands but don't move jobs

    Must provide name or pre-cached resinfo=get_reservation_info(name);
    raises ValueError if neither is given.  Exits with code 3 if the
    reservation partition is not recognized as CPU or GPU.

    This will move eligible jobs from the regular queue into the
    requested reservation, up to the reservation size + extra_nodes.
    Jobs are moved whole, so the last job selected may push the total
    somewhat past that target.
    It auto-detects CPU vs. GPU reservations, and prioritizes what type
    of jobs are most important to move first (e.g. psfnight jobs because
    those block other jobs more than a single arc job does).

    It does not move jobs that are still waiting on dependencies so that
    they don't fill up a spot in the reservation without being able to run.
    """

    ## NOTE:
    ## "scontrol show res kibo26_cpu --json" returns a "partition" that
    ## seems to match the "PARTITION" of "squeue -u desi -o '%i,%P,%v,%j,%u,%t,%M,%D,%R'"
    ## for jobs that are in the regular queue, but jobs that are in the reservation have
    ## a squeue-reported "PARTITION" of "resv", not the partition of the reservation...

    log = get_logger()

    if resinfo is None and name is None:
        msg = 'Must provide either name or resinfo'
        log.critical(msg)
        raise ValueError(msg)

    if resinfo is None:
        resinfo = get_reservation_info(name)

    if name is None:
        name = resinfo['name']

    ressize = resinfo['node_count']

    #- job types for CPU and GPU, in the order they should be considered for the reservation
    cpujobs = ['linkcal', 'ccdcalib', 'psfnight', 'arc']
    gpujobs = ['nightlyflat', 'flat', 'ztile', 'tilenight', 'zpix']

    #- which regular queue partition is eligible for this reservation?
    regular_partition = resinfo['partition']

    #- Determine CPU vs. GPU reservation
    #- NOTE: some NERSC Perlmutter-specific hardcoding here
    if regular_partition.startswith('gpu'):
        restype = 'GPU'
        jobtypes = gpujobs
    elif regular_partition.startswith('regular'):
        restype = 'CPU'
        jobtypes = cpujobs
    else:
        log.critical(f'Unrecognized reservation type partition={resinfo["partition"]}')
        sys.exit(3)

    #- Get the jobs currently in the queue; filtered below to those relevant to this reservation
    jobs = get_jobs_in_queue()

    #- Sort by name so that earlier nights are prioritized over later nights
    jobs.sort('NAME')

    #- Filter which jobs are in reservation vs. eligible to move into reservation
    jobs_in_reservation = jobs[jobs['RESERVATION'] == name]

    eligible_for_reservation = jobs['RESERVATION'] == 'null'
    eligible_for_reservation &= jobs['PARTITION'] == regular_partition

    #- Only move jobs that are currently eligible to run, not those waiting for dependencies
    #- so that we don't fill reservation with jobs that can't run
    eligible_for_reservation &= jobs['ST'] == 'PD'
    eligible_for_reservation &= jobs['NODELISTREASON'] != 'Dependency'
    jobs_eligible = jobs[eligible_for_reservation]

    #- if there are 20x more tilenight or ztile jobs eligible than flats,
    #- move them to second priority, right after the first entry (nightlyflat).
    #- If both qualify, ztile is inserted last and thus ends up ahead of tilenight.
    #- Only applies when the job type is in the active list, i.e. not for CPU
    #- reservations where list.remove would otherwise raise ValueError.
    num_eligible_flat = np.sum(np.char.startswith(jobs_eligible['NAME'], 'flat'))
    for bulktype in ('tilenight', 'ztile'):
        if bulktype not in jobtypes:
            continue

        num_eligible_bulk = np.sum(np.char.startswith(jobs_eligible['NAME'], bulktype))
        if num_eligible_flat > 0 and num_eligible_bulk > 20*num_eligible_flat:
            log.info(f'{num_eligible_bulk} {bulktype} jobs >> {num_eligible_flat} flat jobs; prioritizing {bulktype}')
            jobtypes.remove(bulktype)
            jobtypes.insert(1, bulktype)

    #- Counting jobs and nodes in and out of the reservation
    njobs_in_reservation = len(jobs_in_reservation)
    njobnodes_in_reservation = np.sum(jobs_in_reservation['NODES'])

    njobs_eligible = len(jobs_eligible)
    njobnodes_eligible = np.sum(jobs_eligible['NODES'])

    #- Nodes still needed to fill reservation + extra_nodes, capped by what is available
    njobnodes_to_add = max(0, ressize + extra_nodes - njobnodes_in_reservation)
    njobnodes_to_add = min(njobnodes_to_add, njobnodes_eligible)

    log.info(f'At {time.asctime()}, {name} ({ressize} nodes) has {njobs_in_reservation} jobs using {njobnodes_in_reservation} nodes')
    log.info(f'{njobs_eligible} {restype} jobs using {njobnodes_eligible} nodes are eligible to be moved into the reservation')

    if njobs_eligible == 0:
        log.info('No available jobs to add')
        return

    if njobnodes_to_add == 0:
        log.info('Reservation full, no need to add more jobs at this time')
        return

    #- Select jobs in priority order of job type (and name order within a type)
    #- until enough nodes are covered.  jobids can remain empty if none of the
    #- eligible jobs match a job type in the list.
    log.info(f'Adding jobs to use {njobnodes_to_add} additional nodes')
    jobnodes_added = 0
    jobids = list()
    for jobtype in jobtypes:
        ii = np.char.startswith(jobs_eligible['NAME'], jobtype)
        for row in jobs_eligible[ii]:
            jobname = row['NAME']
            log.info(f'Move {jobname} to {name}')
            jobnodes_added += row['NODES']
            jobids.append(row['JOBID'])

            if jobnodes_added >= njobnodes_to_add:
                break

        if jobnodes_added >= njobnodes_to_add:
            break

    if len(jobids) > 0:
        if dry_run:
            log.info('Dry run mode; will print what to do but not actually run the commands')
        else:
            log.info('Running scontrol commands')

        #- Update queue in batches of batch_size jobs
        batch_size = 10
        for i in range(0, len(jobids), batch_size):
            jobids_csv = ','.join([str(x) for x in jobids[i:i+batch_size]])
            cmd = f'scontrol update ReservationName={name} JobID={jobids_csv}'
            if dry_run:
                #- Purposefully print, not log, to make it easier to cut-and-paste
                print(cmd)
            else:
                log.info(cmd)
                #- Explicit argument list so each key=value stays a single argument
                args = ['scontrol', 'update', f'ReservationName={name}', f'JobID={jobids_csv}']
                try:
                    subprocess.run(args, check=True)
                except subprocess.CalledProcessError as err:
                    #- Best effort: a failed batch shouldn't block the remaining batches
                    log.error(str(err))
                    log.warning('Continuing anyway')

#--------------------------------------------------------------------

def main():
    """
    Command-line entry point: move jobs into a reservation, optionally repeating

    Parses the command-line options, looks up the reservation once, then calls
    use_reservation.  Without --sleep or --until this runs a single pass.
    With --sleep it repeats indefinitely, and with --until it repeats until
    that time has passed, sleeping between passes (5 minutes if --sleep
    isn't given).
    """
    import argparse
    import datetime
    p = argparse.ArgumentParser()
    p.add_argument('-r', '--reservation', required=True, help="batch reservation name")
    p.add_argument('-n', '--extra-nodes', type=int, default=0,
                   help="Add jobs for this number of additional nodes beyond reservation size")
    p.add_argument('--sleep', type=int, help="Sleep this number of minutes between checks")
    p.add_argument('--until', type=str, help="Keep running until this YEAR-MM-DDThh:mm(:ss) time")
    p.add_argument('--dry-run', action="store_true", help="Print what to do, but don't actually move jobs to reservation")
    args = p.parse_args()

    log = get_logger()

    if args.until is not None:
        datetime_until = datetime.datetime.fromisoformat(args.until)
        log.info(f'Will keep checking until {args.until}')

        #- args.until implies sleeping in a loop, so make sure that is set
        if args.sleep is None:
            args.sleep = 5
    else:
        datetime_until = None

    #- Cache reservation information so that we don't have to look it up every loop
    resinfo = get_reservation_info(args.reservation)

    #- Move jobs into the reservation; optionally sleep, repeat.
    #- "now" is evaluated with the tzinfo of --until (None for a naive time, which
    #- gives local time) so that a value with a UTC offset can be compared
    #- without raising TypeError for mixing naive and aware datetimes.
    while (datetime_until is None) or datetime.datetime.now(datetime_until.tzinfo) < datetime_until:
        use_reservation(resinfo=resinfo, extra_nodes=args.extra_nodes, dry_run=args.dry_run)
        if args.sleep is None:
            break
        else:
            log.info(f'Sleeping {args.sleep} minutes before next check')
            time.sleep(60*args.sleep)

    log.info(f'Done checking at {time.asctime()}')

#- Run the command-line interface only when executed as a script, not when imported
if __name__ == '__main__':
    main()
