#!python

"""
Redrock for DESI - MPI entry point
"""
import os
import sys
import argparse
import socket
import numpy as np

from redrock.utils import nersc_login_node, getGPUCountMPI
from redrock.external import desi
from desispec.util import runcmd
from desispec.scripts import qsoqn, qsomgii, emlinefit

# MPI environment availability
#
# wrap_rrdesi requires mpi4py and must not run on a login node. Both failure
# paths terminate the process with a non-zero exit status so that the calling
# batch script or workflow manager can tell that nothing was processed.
have_mpi = None
if nersc_login_node():
    print ("wrap_rrdesi should not be run on a login node.")
    sys.exit(1)
else:
    try:
        import mpi4py.MPI as MPI
    except ImportError:
        have_mpi = False
        print ("MPI not available - required to run wrap_rrdesi")
        sys.exit(1)
    else:
        have_mpi = True

# Command-line interface. Abbreviated option names are disabled, and any
# option not declared here is collected in args_to_pass for redrock.
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("-i", "--input", type=str, required=True,
                    help="input ASCII file list")
parser.add_argument("--input-dir", type=str,
                    help="input directory")
parser.add_argument("-o", "--output", type=str, required=True,
                    help="output directory")
parser.add_argument("--gpu", action="store_true",
                    help="use GPUs")
parser.add_argument("--overwrite", action="store_true",
                    help="Overwrite existing output files")
parser.add_argument("--rrdetails", action="store_true",
                    help="Write out the rrdetails files.")
parser.add_argument("--afterburners", action="store_true",
                    help="Run all after-burners.")
parser.add_argument("--cpu-per-task", type=int, default=32,
                    help="Maximum number of CPUs to use on each input file")
parser.add_argument("--gpuonly", action="store_true",
                    help="Use ONLY GPUs")

# Gather args; any unrecognized args are to be passed to redrock
args, args_to_pass = parser.parse_known_args()

# Module-level shorthands for the parsed options
inputdir = args.input_dir          # stays None when --input-dir is not given
outdir = args.output
cpu_per_task = args.cpu_per_task
overwrite = args.overwrite
rrdetails = args.rrdetails
afterburners = args.afterburners

#- global communicator across all nodes
comm = MPI.COMM_WORLD
comm_rank = comm.rank

#print ("COMM", comm.size, comm.rank)
env = os.environ

# The absence of SLURM_STEP_RESV_PORTS is taken as the sign that the script
# was not started through srun; rank 0 alone prints the usage hint.
if comm.rank == 0 and 'SLURM_STEP_RESV_PORTS' not in os.environ:
    for _msg in (
        "WARNING: Detected that wrap_rrdesi is not being run with srun command.",
        "WARNING: Calling directly can lead to under-utilizing resources.",
        "Recommended syntax: srun -N nodes -n tasks -c 2 --gpu-bind=map_gpu:3,2,1,0  ./wrap_rrdesi [options]",
        "\tEx: 8 tasks each with GPU support on 2 nodes:",
        "\t\tsrun -N 2 -n 8 -c 2 --gpu-bind=map_gpu:3,2,1,0  wrap_rrdesi ...",
        "\tEx: 64 tasks on 1 node and 4 GPUs - this will run on both GPU and non-GPU nodes at once:",
        "\t\tsrun -N 1 -n 64 -c 2 --gpu-bind=map_gpu:3,2,1,0  wrap_rrdesi ...",
    ):
        print (_msg)


#Get number of nodes
nhosts = os.getenv('SLURM_NNODES')
if nhosts is not None:
    nhosts = int(nhosts)
else:
    # SLURM_NNODES is not set: gather every rank's hostname on rank 0,
    # count the distinct ones there, and share the result with all ranks.
    hostnames = comm.gather(socket.gethostname(), root=0)
    nhosts = len(set(hostnames)) if comm.rank == 0 else None
    nhosts = comm.bcast(nhosts, root=0)

# GPU configuration
# Without --gpu both counts stay 0 and every rank is treated as a CPU rank.
ngpu = 0
gpu_per_node = 0
if args.gpu:
    gpu_per_node = os.getenv('SLURM_GPUS_PER_NODE')
    # Fall back to utils.getGPUCountMPI (which will look at
    # /proc/driver/nvidia/gpus/) when Slurm does not export the count.
    gpu_per_node = getGPUCountMPI(comm) if gpu_per_node is None else int(gpu_per_node)
    ngpu = gpu_per_node*nhosts

# There cannot be more GPU workers than MPI ranks: warn once, then clamp.
if comm.rank == 0 and ngpu > comm.size:
    print (f"WARNING: wrap_rrdesi was called with {ngpu} GPUs but only {comm.size} MPI ranks.")
    print (f"WARNING: Will only use {comm.size} GPUs.")
ngpu = min(ngpu, comm.size)

#Set GPU nodes
#We want the first gpu_per_node ranks of each host
# Clamp to at least 1: with fewer MPI ranks than hosts the integer division
# yields 0 and the modulo below would raise ZeroDivisionError.
ranks_per_host = max(1, comm.size // nhosts)
use_gpu = (comm_rank % ranks_per_host) < gpu_per_node
# Number of CPU sub-communicators: the ranks left over after the GPU ranks,
# in groups of at most cpu_per_task, i.e. ceil((comm.size - ngpu) / cpu_per_task).
# --gpuonly disables the CPU groups entirely.
if args.gpuonly:
    ncpu_ranks = 0
else:
    ncpu_ranks = (comm.size - ngpu -1) // cpu_per_task + 1
#if comm.rank == 0:
#    print (f'{ngpu=}, {gpu_per_node=}, {nhosts=}')
#    print (f'{ranks_per_host=}, {use_gpu=}, {ncpu_ranks=}')
#    print (f'{comm.size=}, {comm_rank=}, {cpu_per_task=}')

# Rank 0 creates the output directory (including missing parents) if needed.
if comm.rank == 0 and not os.access(outdir, os.F_OK):
    try:
        os.makedirs(outdir, exist_ok=True)
    except OSError as ex:
        print (ex)
        print ("Error: could not make output directory "+outdir, flush=True)
        # Abort the whole MPI job with a non-zero status. Exiting rank 0
        # alone would leave the other ranks blocked in the next collective
        # call, waiting for a broadcast that never comes.
        comm.Abort(1)

#- read and broadcast input files
# Rank 0 reads the list (one path per line) and broadcasts it to all ranks.
inputfiles = None
if comm.rank == 0:
    with open(args.input, 'r') as f:
        inputfiles = [line.strip() for line in f]
    # Ignore blank lines; otherwise they turn into empty filenames (or a
    # bare "inputdir/") that would be handed to redrock as input files.
    inputfiles = [name for name in inputfiles if name]
    if inputdir is not None:
        inputfiles = [inputdir+'/'+name for name in inputfiles]
inputfiles = comm.bcast(inputfiles, root=0)

#- split subcommunicators
#number of communicators
ncomm = ngpu + ncpu_ranks

# Group ids ("myhost") must cover 0..ncomm-1 without gaps, because the
# file list is later split into ncomm chunks indexed by this id:
#   0 .. ngpu-1       one single-rank group per GPU rank
#   ngpu .. ncomm-1   CPU groups of up to cpu_per_task ranks each
host_index = comm.rank // ranks_per_host
# GPU ranks actually present on each host; this is smaller than gpu_per_node
# when there are fewer ranks per host than GPUs (ngpu was clamped above).
gpu_ranks_per_host = min(gpu_per_node, ranks_per_host)
if use_gpu:
    myhost = (comm.rank % ranks_per_host) + host_index*gpu_ranks_per_host
else:
    # Index of this rank among CPU ranks only: subtract the GPU ranks of all
    # preceding hosts AND of this host, since GPU ranks come first on a host.
    cpu_index = comm.rank - gpu_ranks_per_host*(host_index + 1)
    myhost = ngpu + cpu_index // cpu_per_task
subcomm = comm.Split(myhost)
#print (f'{comm.rank=}, {ncomm=}, {myhost=}, {subcomm.size=}')

if comm.rank == 0:
    print("Running "+str(len(inputfiles))+" input files on "+str(ngpu)+" GPUs and "+str(ncomm)+" total procs...")

#- each subcommunicator processes a subset of files
# In --gpuonly mode, CPU procs will not enter this block
if myhost < ncomm:
    # The file list is cut into ncomm chunks; this group takes chunk `myhost`.
    myfiles = np.array_split(inputfiles, ncomm)[myhost]
    nfiles = len(myfiles)
    #print (f'DEBUG: {myhost=} {ncomm=} {nfiles=} {myfiles=}, {comm.rank=}')
    for infile in myfiles:
        # Output names are derived from the input basename by swapping the
        # 'coadd-' prefix for the product-specific prefix.
        basename = os.path.basename(infile)
        redrockfile = os.path.join(outdir, basename.replace('coadd-', 'redrock-'))
        if os.path.isfile(redrockfile) and not overwrite:
            if subcomm.rank == 0:
                print(f'Warning: skipping existing Redrock file {redrockfile}.')
        else:
            if os.path.isfile(redrockfile) and subcomm.rank == 0:
                print(f'Warning: overwriting existing Redrock file {redrockfile}.')

            opts = ['-i', infile, '-o', redrockfile]
            if rrdetails:
                h5file = os.path.join(outdir, basename.replace('coadd-', 'rrdetails-').replace('.fits', '.h5'))
                opts += ['-d', h5file]

            # Forward the options this wrapper did not recognize
            if args_to_pass is not None:
                opts.extend(args_to_pass)
            if use_gpu:
                opts.append('--gpu')
            print (f'Running rrdesi on {myhost=} {subcomm.rank=} with options {opts=}')
            desi.rrdesi(opts, comm=subcomm)

        # optionally run all the afterburners
        # (done even when the Redrock step above was skipped)
        if afterburners:
            for prefix, maincmd in (('qso_qn', qsoqn.main),
                                    ('qso_mgii', qsomgii.main),
                                    ('emline', emlinefit.main)):
                afterfile = os.path.join(outdir, basename.replace('coadd-', f'{prefix}-'))
                if os.path.isfile(afterfile) and not overwrite:
                    if subcomm.rank == 0:
                        print(f'Warning: skipping existing afterburner file {afterfile}.')
                else:
                    if os.path.isfile(afterfile) and subcomm.rank == 0:
                        print(f'Warning: overwriting existing afterburner file {afterfile}.')

                    # Build the argument list directly rather than formatting
                    # a command string and splitting it on whitespace, which
                    # would break paths that contain spaces.
                    after_args = ['--coadd', str(infile),
                                  '--redrock', redrockfile,
                                  '--output', afterfile]
                    if 'qso_' in prefix:
                        after_args += ['--target_selection', 'all', '--save_target', 'all']
                    # Only the qso_qn afterburner is given the sub-communicator.
                    # NOTE(review): the other two are invoked without comm by
                    # every rank of subcomm - confirm this is intended and that
                    # concurrent runs writing the same output file are safe.
                    if prefix == 'qso_qn':
                        runcmd(maincmd, args=after_args, inputs=[infile, redrockfile], outputs=[afterfile], comm=subcomm)
                    else:
                        runcmd(maincmd, args=after_args, inputs=[infile, redrockfile], outputs=[afterfile])
