#!/usr/bin/env python
#
# See top-level LICENSE.rst file for Copyright information
#
# -*- coding: utf-8 -*-

"""Simulate DTS by copying data from a staging area.
"""

from __future__ import absolute_import, division, print_function

import sys
import os
import time
import shutil

from collections import OrderedDict

import argparse
import re
import warnings

import fitsio

from desispec.util import sprun

import desispec.io as io

from desiutil.log import get_logger

import desispec.pipeline as pipe


def pack_args(args):
    """Collect the job-related options that are forwarded to ``desi_night``.

    Only the NERSC / MPI options named below are considered; every other
    attribute of ``args`` is ignored.  Options whose value is ``None`` (not
    given on the command line) are skipped.

    Args:
        args: parsed options, typically an ``argparse.Namespace``.  Any
            object supported by ``vars()`` works.

    Returns:
        list: command-line tokens, all strings.  Each value option
        contributes ``"--name"`` followed by its value; a boolean option
        contributes a bare ``"--name"`` only when it is True.
    """
    forwarded = frozenset([
        "nersc",
        "nersc_queue",
        "nersc_queue_redshifts",
        "nersc_maxtime",
        "nersc_maxnodes",
        "nersc_maxnodes_small",
        "nersc_maxnodes_redshifts",
        "nersc_shifter",
        "mpi_procs",
        "mpi_run",
        "procs_per_node",
    ])
    opts = list()
    for key, val in vars(args).items():
        if key not in forwarded or val is None:
            continue
        flag = "--{}".format(key)
        if isinstance(val, bool):
            # A switch that is off must not appear on the command line.
            if val:
                opts.append(flag)
            continue
        opts.append(flag)
        # Callers join and exec these tokens, so they must all be strings.
        opts.append(str(val))
    return opts


def _run_or_exit(log, com):
    """Log and run one command, terminating the program if it fails.

    Args:
        log: logger used to report the command line.
        com (list): command and arguments (strings), handed to ``sprun``.

    Raises:
        SystemExit: with the value returned by ``sprun`` when that value is
            not zero.
    """
    log.info("  Running {}".format(" ".join(com)))
    ret = sprun(com)
    if ret != 0:
        sys.exit(ret)


def main():
    """Replay staged exposures into the raw data area, night by night.

    For every night found in the staging directory, each exposure with a
    recognised flavor (arc, flat, science) is "acquired": the script sleeps
    for the configured exposure time, copies the fibermap and raw files into
    the raw data tree and runs ``desi_night update`` for that exposure.
    After the last arc and the last flat of the night it runs
    ``desi_night arcs`` / ``desi_night flats``.  When the night's last
    science exposure is at or after ``--start_expid``, it runs
    ``desi_night redshifts`` and then pauses for ``--night_break`` minutes.

    Exposures with an ID below ``--start_expid`` are skipped.  Any command
    that returns a non-zero status ends the program with that status.
    """
    parser = argparse.ArgumentParser(description="Copy data from a staging "
        "area into DESI_SPECTRO_DATA")

    parser.add_argument("--staging", required=False, default=".",
        help="Staging directory containing night subdirs.")

    parser.add_argument("--exptime_arc", required=False, type=int, default=15,
        help="Minutes for arc exposures.")

    parser.add_argument("--exptime_flat", required=False, type=int, default=5,
        help="Minutes for flat exposures.")

    parser.add_argument("--exptime_science", required=False, type=int,
        default=15, help="Minutes for science exposures.")

    parser.add_argument("--night_break", required=False, default=0, type=int,
        help="Minutes to pause in between nights.")

    parser.add_argument("--nersc", required=False, default=None,
        help="write a script for this NERSC system (edison | cori-haswell "
        "| cori-knl)")

    parser.add_argument("--nersc_queue", required=False, default="regular",
        help="write a script for this NERSC queue (debug | regular)")

    parser.add_argument("--nersc_queue_redshifts", required=False,
        default=None, help="Use this NERSC queue for redshifts. "
        "Defaults to same as --nersc_queue.")

    parser.add_argument("--nersc_maxtime", required=False,
        default=None, help="Then maximum run time (in minutes) for a single "
        " job.  If the list of tasks cannot be run in this time, multiple "
        " job scripts will be written.  Default is the maximum time for "
        " the specified queue.")

    parser.add_argument("--nersc_maxnodes", required=False,
        default=None, help="The maximum number of nodes to use.  Default "
        " is the maximum nodes for the specified queue.")

    parser.add_argument("--nersc_maxnodes_small", required=False,
        default=None, help="The maximum number of nodes to use for 'small' "
        "steps like the per-night psf and fiberflat.  Default is to use the"
        " same value as --nersc_maxnodes.")

    parser.add_argument("--nersc_maxnodes_redshifts", required=False,
        default=None, help="The maximum number of nodes to use for "
        " redshifts.  Default is to use --nersc_maxnodes.")

    parser.add_argument("--nersc_shifter", required=False, default=None,
        help="The shifter image to use for NERSC jobs")

    parser.add_argument("--mpi_procs", required=False, default=None,
        help="The number of MPI processes to use for non-NERSC shell "
        "scripts (default 1)")

    parser.add_argument("--mpi_run", required=False,
        default=None, help="The command to launch MPI programs "
        "for non-NERSC shell scripts (default do not use MPI)")

    parser.add_argument("--procs_per_node", required=False,
        default=None, help="The number of processes to use per node.  If not "
        "specified it uses a default value for each machine.")

    parser.add_argument("--start_expid", required=False, default=None,
        help="Start processing at this exposure ID.")

    args = parser.parse_args()

    # Options passed through unchanged to every desi_night invocation.
    nightopts = pack_args(args)

    # Simulated acquisition time, in minutes, for each known flavor.
    flavor_times = {
        "arc" : args.exptime_arc,
        "flat" : args.exptime_flat,
        "science" : args.exptime_science
    }

    log = get_logger()

    startexp = 0
    if args.start_expid is not None:
        startexp = int(args.start_expid)

    # data locations

    stagedir = os.path.abspath(args.staging)
    rawdir = os.path.abspath(io.rawdata_root())

    log.info("Starting fake DTS at {}".format(time.asctime()))
    log.info("  Will copy data from staging location:")
    log.info("    {}".format(stagedir))
    log.info("  To DESI_SPECTRO_DATA at:")
    log.info("    {}".format(rawdir))
    sys.stdout.flush()

    nights = io.get_nights(strip_path=True, specprod_dir=stagedir,
                           sub_folder="")

    for nt in nights:
        # exposures for this night
        expids = io.get_exposures(nt, raw=True, rawdata_dir=stagedir)

        # Look up the flavor of each exposure, remembering the last exposure
        # seen for every flavor.  Exposures of unknown flavor are reported
        # and left out of the replay.
        last_of = dict()
        ntexp = OrderedDict()
        for ex in expids:
            fibermap = io.get_raw_files("fibermap", nt, ex,
                                        rawdata_dir=stagedir)
            _, header = fitsio.read(fibermap, header=True)
            flavor = header["FLAVOR"].strip().lower()
            if flavor not in flavor_times:
                log.error("Unknown flavor '{}' for file '{}'".format(flavor, fibermap))
                continue
            ntexp[ex] = flavor
            last_of[flavor] = ex

        lastarc = last_of.get("arc")
        lastflat = last_of.get("flat")
        lastscience = last_of.get("science")

        # Go through exposures in order
        for ex, flavor in ntexp.items():
            if int(ex) < startexp:
                log.info("Skipping exposure {} on night {} which is before --start_expid".format(ex, nt))
                continue
            exptime = flavor_times[flavor]
            log.info("Acquiring exposure {} on night {} for {} minutes...".format(ex, nt, exptime))
            sys.stdout.flush()
            time.sleep(60 * exptime)
            fmsrc = io.get_raw_files("fibermap", nt, ex, rawdata_dir=stagedir)
            rawsrc = io.get_raw_files("raw", nt, ex, rawdata_dir=stagedir)

            # <rawdir>/<night>/<8-digit expid>/ ; makedirs also creates the
            # night directory when it is missing.
            targetexpdir = os.path.join(rawdir, nt, "{:08d}".format(int(ex)))
            if not os.path.isdir(targetexpdir):
                os.makedirs(targetexpdir)

            for src in (fmsrc, rawsrc):
                targ = os.path.join(targetexpdir, os.path.basename(src))
                # Remove any copy left over from a previous run first.
                if os.path.exists(targ):
                    os.remove(targ)

            shutil.copy2(fmsrc,
                         os.path.join(targetexpdir, os.path.basename(fmsrc)))
            shutil.copy2(rawsrc,
                         os.path.join(targetexpdir, os.path.basename(rawsrc)))

            log.info("  Finished exposure {} on night {}".format(ex, nt))
            sys.stdout.flush()

            # Trigger nightly processing
            com = ["desi_night", "update", "--night", "{}".format(nt),
                   "--expid", "{}".format(ex)]
            com.extend(nightopts)
            _run_or_exit(log, com)

            # If we are at the last arc or flat, trigger nightly products.
            if ex == lastarc:
                com = ["desi_night", "arcs", "--night", "{}".format(nt)]
                com.extend(nightopts)
                _run_or_exit(log, com)

            if ex == lastflat:
                com = ["desi_night", "flats", "--night", "{}".format(nt)]
                com.extend(nightopts)
                _run_or_exit(log, com)

        # Trigger redshifts at the end of the night (for now).  A night
        # without any science exposure has nothing to fit, so it is skipped
        # here instead of failing on int(None).
        if lastscience is not None and int(lastscience) >= startexp:
            com = ["desi_night", "redshifts", "--night", "{}".format(nt)]
            com.extend(nightopts)
            _run_or_exit(log, com)

            # Pause until the next night, reporting progress at most every
            # `incr` minutes.  The final step is shortened so that the total
            # pause equals --night_break even when it is not a multiple of
            # `incr`.
            incr = 5
            remaining = args.night_break
            while remaining > 0:
                log.info("End of night {}, will resume in {} minutes...".format(nt, remaining))
                sys.stdout.flush()
                step = min(incr, remaining)
                time.sleep(60*step)
                remaining -= step

# Run the simulator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
