#!python
"""radical-orbit-run — submit one task to the bridge task dispatcher and
block until it reaches a terminal state.

Invoked per Makeflow rule by ``radical-orbit-makeflow-prep``-rewritten
commands, but usable standalone.  Everything after the ``--`` separator
is the command to run on the remote endpoint.

Usage:
    radical-orbit-run (--pool NAME | --endpoint NAME) --run-id ID \\
        [--priority N] [--in FILE]... [--out FILE]... \\
        [--pools pools.json] -- COMMAND [ARG...]

Targeting (exactly one required):
    --pool NAME   route through a dispatcher-managed pilot pool
                  (supports --in/--out file staging)
    --endpoint NAME   run directly on that endpoint's rhapsody (transparent
                  proxy; no pool, no file staging)

Options:
    --run-id ID   stable id for this workflow invocation; combined with
                  the command + inputs/outputs to derive the task id
    --priority N  dispatch priority (higher first); default 0
    --in  FILE    stage this input file into the task's scratch dir
                  (repeatable; --pool only)
    --out FILE    stage this output file back after success
                  (repeatable; --pool only)
    --pools FILE  pools.json declaring the workflow's pools; forwarded
                  to the dispatcher on register_session

Exit code mirrors the task's exit code (0 on DONE without one, 1 on
FAILED/CANCELED without one).
"""

import argparse
import hashlib
import logging
import os
import signal
import sys
import threading
import time

from pathlib import Path

# Module logger.  main() configures the root logger (level, format, stderr
# stream) through logging.basicConfig.
log = logging.getLogger('radical.orbit.run')

# Seconds slept between get_task calls when falling back from SSE to HTTP
# polling
_POLL_INTERVAL_SEC = 2.0
# How long (seconds) to wait for the dispatcher's SSE listener to connect;
# passed as the timeout to BridgeClient.wait_for_listener
_SSE_CONNECT_TIMEOUT = 15.0
# Terminal task states (mirrors task_dispatcher_state.TASK_TERMINAL_STATES).
# Reported states are upper-cased before being compared against this set.
_TERMINAL_STATES = {'DONE', 'FAILED', 'CANCELED'}


# ---------------------------------------------------------------------------
# Arg parsing
# ---------------------------------------------------------------------------

def _split_argv(argv: list[str]) -> tuple[list[str], list[str]]:
    """Partition argv around the first ``--`` token.

    The program name (``argv[0]``) is dropped.  Everything between it and
    the separator is returned as the wrapper's own options; everything
    after the separator is returned as the command to run.

    Exits with status 2 and a message on stderr when the separator is
    absent or when nothing follows it.
    """
    if '--' not in argv:
        print('radical-orbit-run: missing -- separator before the command',
              file=sys.stderr)
        sys.exit(2)

    cut = argv.index('--')
    before, after = argv[1:cut], argv[cut + 1:]

    if len(after) == 0:
        print('radical-orbit-run: no command after --', file=sys.stderr)
        sys.exit(2)

    return before, after


def _parse_opts(opts: list[str]) -> argparse.Namespace:
    '''Parse the wrapper's own options (the argv part before ``--``).

    Returns a namespace with ``endpoint``, ``pool``, ``run_id``,
    ``priority``, ``inputs``, ``outputs``, ``pools_file`` and ``help``.
    ``inputs`` / ``outputs`` are whitespace-flattened via ``_flatten``.

    Exits with status 0 after printing the module docstring when help is
    requested, and with status 2 on invalid options (argparse errors, or
    ``--in`` / ``--out`` combined with ``--endpoint``).
    '''
    # Answer an explicit help request before argparse runs: the parser
    # has required options, so ``parse_args`` would otherwise reject a
    # bare ``--help`` with a "required argument" error (exit 2) before
    # the ``help`` flag could ever be inspected.
    if any(opt in ('-h', '--help') for opt in opts):
        print(__doc__)
        sys.exit(0)

    ap = argparse.ArgumentParser(prog='radical-orbit-run',
                                  add_help=False)
    # --endpoint XOR --pool: a task targets either a pool (dispatcher
    # manages pilots) or an endpoint directly (transparent proxy to that
    # endpoint's rhapsody — no pool, no input/output staging).
    target = ap.add_mutually_exclusive_group(required=True)
    target.add_argument('--endpoint', default=None)
    target.add_argument('--pool', default=None)
    ap.add_argument('--run-id',   required=True, dest='run_id')
    ap.add_argument('--priority', type=int, default=0)
    ap.add_argument('--in',  dest='inputs',  action='append', default=[])
    ap.add_argument('--out', dest='outputs', action='append', default=[])
    ap.add_argument('--pools', dest='pools_file', default=None,
                    help='Path to a pools.json declaring the workflow\'s '
                         'pools.  Sent to the dispatcher on register_session; '
                         'idempotent across rules (matching pools attach to '
                         'existing state, new pools materialise).  Without '
                         'this flag the dispatcher uses its built-in '
                         '"default" pool.')
    # Kept so the namespace still carries ``help`` and so abbreviated
    # spellings accepted by argparse (e.g. ``--he``) are still honoured.
    ap.add_argument('-h', '--help', action='store_true')
    args = ap.parse_args(opts)
    if args.help:
        print(__doc__)
        sys.exit(0)
    # Each --in/--out occurrence takes exactly one argument, but that
    # argument may itself be a quoted, whitespace-separated list
    # (``--in "a b c"``); split those into individual file names.
    args.inputs  = _flatten(args.inputs)
    args.outputs = _flatten(args.outputs)
    # Endpoint mode does not support file staging yet.  Compare against
    # None so an explicitly given (even empty) --endpoint is recognised.
    if args.endpoint is not None and (args.inputs or args.outputs):
        print('radical-orbit-run: --in/--out are not supported with --endpoint; '
              'use --pool for staging-enabled tasks',
              file=sys.stderr)
        sys.exit(2)
    return args


def _flatten(items: list[str]) -> list[str]:
    '''Split every entry on whitespace and return all tokens in order.

    This makes ``--in a --in b`` and ``--in "a b"`` produce the same list.
    '''
    return [token for entry in items for token in entry.split()]


# ---------------------------------------------------------------------------
# Task id
# ---------------------------------------------------------------------------

def compute_task_id(cmd: list[str], inputs: list[str], outputs: list[str],
                    run_id: str) -> str:
    '''Derive a deterministic task id from (cmd, inputs, outputs, run).

    The command keeps its argument order, while inputs and outputs are
    sorted first, so their order on the command line does not affect the
    id.  The result is ``'t.'`` followed by the first 16 hex digits of a
    SHA-1 digest.

    See design doc §3.3 for the resubmit semantics that this
    stability enables.
    '''
    def _nul_terminated(values) -> bytes:
        # Each value is UTF-8 encoded (unencodable characters replaced)
        # and followed by a NUL byte.
        return b''.join(v.encode('utf-8', errors='replace') + b'\0'
                        for v in values)

    digest = hashlib.sha1()
    digest.update(_nul_terminated(cmd))
    digest.update(b'|inputs|' + _nul_terminated(sorted(inputs)))
    digest.update(b'|outputs|' + _nul_terminated(sorted(outputs)))
    digest.update(b'|run|' + run_id.encode('utf-8', errors='replace'))
    return 't.' + digest.hexdigest()[:16]


# ---------------------------------------------------------------------------
# Wrapper runtime
# ---------------------------------------------------------------------------

class Wrapper:
    '''Encapsulates one invocation's state so SIGTERM can call back in.

    One instance handles exactly one task: :meth:`run` connects to the
    dispatcher, optionally stages inputs, submits the task, waits for a
    terminal state, optionally stages outputs and returns the process
    exit code.  :meth:`_close` releases the bridge client afterwards.
    '''

    # While the SSE listener is connected we still re-check the task via
    # ``get_task`` at this (slow) interval, so a notification that was
    # emitted before the listener connected, or was otherwise lost,
    # cannot leave the wrapper blocked forever.
    _SSE_RECHECK_SEC = 30.0

    def __init__(self, args: argparse.Namespace, cmd: list[str]) -> None:
        '''Store the parsed options and command and derive the task id.'''
        self.args      = args
        self.cmd       = cmd
        self.task_id   = compute_task_id(cmd, args.inputs,
                                          args.outputs, args.run_id)
        # Both clients are created lazily by _get_td().
        self._td_client: object | None = None
        self._bc:        object | None = None
        # Set by the SSE callback once our task reaches a terminal state;
        # _final_state then holds the notification payload.
        self._terminal_event = threading.Event()
        self._final_state: dict | None = None

    def _get_td(self):
        '''Lazy import + connect.  Done inside run() so argv errors
        fail before we bring up a full BridgeClient.

        The dispatcher is bridge-hosted, so we always target the
        ``'bridge'`` endpoint for plugin lookup; ``--endpoint`` (if given) is
        a per-task routing hint forwarded in the submit body.

        Returns the ``task_dispatcher`` plugin client (also stored in
        ``self._td_client``).
        '''
        from radical.orbit import BridgeClient  # type: ignore[import-untyped]
        self._bc = BridgeClient()
        endpoint_c   = self._bc.get_endpoint_client('bridge')

        # Forward --pools to register_session if provided.  The body
        # must match the pools.json schema (top-level "pools" list).
        kwargs: dict = {}
        pools_file = getattr(self.args, 'pools_file', None)
        if pools_file:
            import json
            # JSON text is UTF-8; do not depend on the locale's encoding.
            with open(pools_file, encoding='utf-8') as f:
                body = json.load(f)
            # Accept both {"pools": [...]} and a bare top-level value.
            pools = body.get('pools', body) if isinstance(body, dict) else body
            kwargs['pools'] = pools

        self._td_client = endpoint_c.get_plugin('task_dispatcher', **kwargs)
        return self._td_client

    def _close(self) -> None:
        '''Close the bridge client, if one was created (best effort).'''
        if self._bc is not None:
            try:
                self._bc.close()
            except Exception:
                # Shutdown is best effort and must not mask the task's
                # exit code, but leave a trace for debugging.
                log.debug('closing the bridge client failed', exc_info=True)

    def _on_task_status(self, endpoint: str, plugin: str,
                         topic: str, data: dict) -> None:
        '''SSE callback — fire terminal event when our task_id hits a
        terminal state.  Runs in the BridgeClient listener thread.'''
        del endpoint, plugin, topic  # filtered at registration time
        if data.get('task_id') != self.task_id:
            return
        state = str(data.get('state', '')).upper()
        if state in _TERMINAL_STATES:
            # Publish the payload before setting the event so a waiter
            # that wakes up always finds it.
            self._final_state = data
            self._terminal_event.set()

    def _stage_inputs(self, td) -> str:
        '''Upload declared inputs and return the cwd path on the endpoint.

        If there are no inputs, uploads a zero-byte sentinel so the
        dispatcher creates the task's scratch dir and returns its path.

        Inputs are uploaded under their basename.  Exits with status 3
        if a declared input cannot be read locally.
        '''
        cwd = None
        if not self.args.inputs:
            # Sentinel upload — named so it's obvious in the scratch dir.
            reply = td.stage_in(
                self.args.pool, self.task_id,
                '.radical-orbit-run.scratch-init', b'',
                overwrite=True)
            cwd = reply['cwd']
        else:
            for path in self.args.inputs:
                try:
                    with open(path, 'rb') as fh:
                        content = fh.read()
                except OSError as e:
                    print(f'radical-orbit-run: cannot read input {path}: {e}',
                          file=sys.stderr)
                    sys.exit(3)
                filename = os.path.basename(path)
                reply = td.stage_in(
                    self.args.pool, self.task_id, filename, content,
                    overwrite=True)
                # Every reply carries the scratch dir; the last one wins.
                cwd = reply['cwd']
        if cwd is None:
            raise RuntimeError('dispatcher did not report a task cwd')
        return cwd

    def _stage_outputs(self, td) -> None:
        '''Pull declared outputs into the local Makeflow workdir.

        Outputs are fetched by basename and written to the declared
        local path (parent directories are created as needed).

        Missing outputs are reported as warnings; Makeflow will mark the
        rule failed if its declared outputs are not present, which is
        the correct behaviour anyway.
        '''
        for local_path in self.args.outputs:
            name = os.path.basename(local_path)
            try:
                content = td.stage_out(self.task_id, name)
            except Exception as e:
                print(f'radical-orbit-run: stage_out failed for {name}: {e}',
                      file=sys.stderr)
                continue
            dest = Path(local_path)
            dest.parent.mkdir(parents=True, exist_ok=True)
            dest.write_bytes(content)

    def _wait_for_terminal(self) -> dict:
        '''Block until the task is terminal and return its final dict.

        The SSE callback (:meth:`_on_task_status`) is the primary
        signal.  Because a notification can be emitted before the
        listener has connected (or be lost), the task is also checked
        via ``get_task``: once immediately, then every
        ``_SSE_RECHECK_SEC`` while SSE is connected, or every
        ``_POLL_INTERVAL_SEC`` when the listener did not connect within
        ``_SSE_CONNECT_TIMEOUT``.

        Returns the SSE payload if the callback fired, otherwise the
        ``get_task`` result.  Raises ``RuntimeError`` if the clients
        have not been initialised.
        '''
        if self._bc is None:
            raise RuntimeError('bridge client not initialised')
        connected = self._bc.wait_for_listener(timeout=_SSE_CONNECT_TIMEOUT)

        td = self._td_client
        if connected:
            interval = self._SSE_RECHECK_SEC
        else:
            # Fallback: poll
            print('radical-orbit-run: SSE not connected; falling back to polling',
                  file=sys.stderr)
            if td is None:
                raise RuntimeError('task dispatcher client not initialised')
            interval = _POLL_INTERVAL_SEC

        # Makeflow will SIGTERM us if the rule needs to be cancelled, so
        # there is no overall deadline.  Timed waits keep the loop
        # responsive to both the SSE event and the polled state.
        while not self._terminal_event.is_set():
            if td is not None:
                try:
                    t = td.get_task(self.task_id)  # type: ignore[attr-defined]
                except Exception as e:
                    print(f'radical-orbit-run: get_task failed: {e}',
                          file=sys.stderr)
                else:
                    if str(t.get('state', '')).upper() in _TERMINAL_STATES:
                        return t
            self._terminal_event.wait(timeout=interval)

        final = self._final_state
        if final is None:
            raise RuntimeError('terminal event set without a final state')
        return final

    def _install_signal_handler(self) -> None:
        '''On SIGTERM/SIGINT, cancel the task and exit.

        The handler exits with 130 for SIGINT and 143 otherwise.  A
        failing ``cancel_task`` is reported on stderr but does not
        prevent the exit.
        '''
        def _on_signal(signum, _frame):
            print(f'radical-orbit-run: caught signal {signum}; '
                  f'cancelling task {self.task_id}', file=sys.stderr)
            try:
                if self._td_client is not None:
                    self._td_client.cancel_task(self.task_id)  # type: ignore[attr-defined]
            except Exception as e:
                print(f'radical-orbit-run: cancel_task failed: {e}',
                      file=sys.stderr)
            sys.exit(130 if signum == signal.SIGINT else 143)

        signal.signal(signal.SIGTERM, _on_signal)
        signal.signal(signal.SIGINT,  _on_signal)

    @staticmethod
    def _exit_code(state: str, rc) -> int:
        '''Map a terminal (upper-cased) state and reported exit code to
        this process's exit code.'''
        if rc is None:
            # No exit code reported — pick a sentinel.
            return 0 if state == 'DONE' else 1
        code = int(rc)
        if code == 0 and state != 'DONE':
            # A FAILED/CANCELED task must never look successful to
            # Makeflow, even if the dispatcher reports exit code 0.
            return 1
        return code

    def run(self) -> int:
        '''Submit the task, wait for it to finish and return the exit code.

        Returns the task's exit code; 0 for DONE and 1 for any other
        terminal state when no usable code is reported; 4 if the submit
        call fails.  May also terminate the process via ``sys.exit``
        (unreadable input, or a caught signal).
        '''
        self._install_signal_handler()

        td = self._get_td()

        # Register a persistent SSE callback BEFORE submitting so no
        # notification can slip through between submit and subscribe.
        # The dispatcher is always on the 'bridge' endpoint.
        if self._bc is not None:
            self._bc.register_callback(
                endpoint_id='bridge',
                plugin_name='task_dispatcher',
                topic='task_status',
                callback=self._on_task_status)

        if self.args.pool:
            cwd = self._stage_inputs(td)
        else:
            # Endpoint mode: no staging; the task's cwd is the directory
            # we were invoked from (makeflow rules run their child in
            # the workflow's workdir).
            cwd = os.getcwd()

        try:
            td.submit_task(
                pool     = self.args.pool,
                endpoint     = self.args.endpoint,
                task_id  = self.task_id,
                cmd      = self.cmd,
                cwd      = cwd,
                priority = self.args.priority,
                inputs   = [os.path.basename(p)
                            for p in self.args.inputs],
                outputs  = [os.path.basename(p)
                            for p in self.args.outputs],
            )
        except Exception as e:
            print(f'radical-orbit-run: submit failed: {e}', file=sys.stderr)
            return 4

        final = self._wait_for_terminal()
        state = str(final.get('state', '')).upper()
        rc    = final.get('exit_code')

        # Outputs are only staged back for successful pool tasks.
        if state == 'DONE' and self.args.pool:
            self._stage_outputs(td)

        return self._exit_code(state, rc)


def main(argv: list[str] | None = None) -> int:
    logging.basicConfig(
        level=(os.environ.get('RADICAL_ORBIT_LOG_LVL')
               or os.environ.get('RADICAL_LOG_LVL') or 'WARNING'),
        format='%(asctime)s radical-orbit-run %(levelname)s %(message)s',
        stream=sys.stderr)

    if argv is None:
        argv = sys.argv
    # --help before argument parsing so "radical-orbit-run --help" works.
    if any(a in ('-h', '--help') for a in argv[1:]):
        print(__doc__)
        return 0
    opts, cmd = _split_argv(argv)
    args      = _parse_opts(opts)

    wrapper = Wrapper(args, cmd)
    try:
        return wrapper.run()
    finally:
        wrapper._close()


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    sys.exit(main())
