#!python
"""radical-orbit-makeflow-prep — preprocess a .makeflow for the task dispatcher.

Reads a user-authored Makeflow file and rewrites every rule's command
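to route through ``radical-orbit-run``.  Directives ``ENDPOINT``, ``POOL``,
``PRIORITY`` and ``POOLS`` are scoped like Makeflow's existing ``CATEGORY``:
an assignment applies to all subsequent rules until overridden.
``ENDPOINT`` and ``POOL`` are mutually exclusive (setting one clears the
other), and rules whose command starts with ``LOCAL`` are left unwrapped.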

Usage:
    radical-orbit-makeflow-prep INPUT.makeflow \\
        [--output OUTPUT.makeflow] \\
        [--default-endpoint NAME | --default-pool NAME] \\
        [--default-priority N] [--pools POOLS.json] [--run-id ID]

Input example (before prep):

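    POOL = "cpu_small"
    PRIORITY = 10

    out.dat: in.dat
        ./compute in.dat out.dat

Output (after prep; shown wrapped, emitted as one tab-indented line):

    out.dat: in.dat
        radical-orbit-run --pool=cpu_small --priority=10 \\
                          --run-id=<sha1> \\
                          --in in.dat --out out.dat \\
                          -- sh -c './compute in.dat out.dat'

Had the file set ``ENDPOINT = "perlmutter"`` instead of ``POOL``, the
rule would get ``--endpoint=perlmutter`` and no ``--in``/``--out`` flags
(endpoint mode does not stage files).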

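``run_id = sha1(abs_path_of_input + NUL + mtime_ns)[:16]`` is baked into
every rewritten command (override it with ``--run-id``).  Editing the
workflow file changes its mtime and hence the run_id, so the dispatcher
can tell different versions of a workflow apart.  Re-running or retrying
the unmodified file reuses the same run_id and stays idempotent.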

See plans/task_dispatcher_design.md §8.2.
"""

from __future__ import annotations

import argparse
import hashlib
import re
import shlex
import sys

from dataclasses import dataclass
from pathlib import Path


# Substitute Makeflow's ``$(VAR)`` references with the values
# captured from ``VAR=value`` assignments earlier in the file.
# Needed because we wrap the original rule command in
# ``sh -c '<cmd>'`` — the single quotes block Makeflow's own
# substitution, and we cannot rely on ${VAR} env expansion at the
# remote endpoint (the executor's process tree doesn't inherit
# Makeflow's exports).  Unknown vars are left alone so the user
# notices the typo rather than silently getting an empty string.
_MAKEFLOW_VAR = re.compile(r'\$\((\w+)\)')


def _expand_vars(cmd: str, vars: dict[str, str]) -> str:
    '''Replace Makeflow ``$(VAR)`` with the captured assignment value.

    Two passes so simple nesting like ``A=$(B)`` resolves; we cap at
    a small number of passes to avoid loops on circular refs.
    '''
    for _ in range(4):
        new = _MAKEFLOW_VAR.sub(
            lambda m: vars.get(m.group(1), m.group(0)), cmd)
        if new == cmd:
            return cmd
        cmd = new
    return cmd


# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------

@dataclass
class Rule:
    '''A single parsed Makeflow rule plus the directive scope in effect.

    The first four fields describe the rule as written in the file; the
    remaining ones snapshot the ENDPOINT / POOL / PRIORITY / POOLS scope
    at the point where the rule appeared.  ``header_line`` is re-emitted
    as-is by the rewriter.
    '''
    # --- rule as written -------------------------------------------------
    outputs: list[str]           # targets, left of ':'
    inputs: list[str]            # dependencies, right of ':'
    commands: list[str]          # command lines under the header
    header_line: str             # original "outputs : inputs" text

    # --- scope at the point of this rule ---------------------------------
    endpoint: str | None         # active ENDPOINT, or None
    pool: str | None             # active POOL, or None
    priority: int | None         # active PRIORITY
    pools: str | None = None     # path to pools.json (workflow-level)


@dataclass
class PrepOptions:
    '''CLI-level defaults that seed the directive scope of ``prep_stream``.

    A rule targets EITHER a specific endpoint (transparent proxy to that
    endpoint's rhapsody) OR a pool (dispatcher-managed pilot fleet) —
    never both.  Therefore at most one of ``default_endpoint`` /
    ``default_pool`` may be set.  Inside the file the same XOR holds per
    scope: ``ENDPOINT = "..."`` clears any active ``POOL`` and vice versa.
    '''
    default_endpoint: str | None = None   # fallback ENDPOINT for rules
    default_pool: str | None = None       # fallback POOL for rules
    default_priority: int = 0             # fallback PRIORITY for rules
    # Path to a pools.json declaring the workflow's pools (main() passes
    # it already resolved to an absolute path).  When set, every
    # generated rule gets --pools=<path>.  Without it the dispatcher uses
    # its built-in 'default' pool, which is rarely what a real workflow
    # wants.
    pools_file: str | None = None

    def __post_init__(self) -> None:
        # Both targets being truthy violates the endpoint/pool XOR.
        conflicting = all((self.default_endpoint, self.default_pool))
        if conflicting:
            raise PrepError(
                '--default-endpoint and --default-pool are mutually exclusive')


# ---------------------------------------------------------------------------
# Parser + rewriter
# ---------------------------------------------------------------------------

# We only care about three directives; everything else passes through.
_SCOPED_DIRECTIVES = ('ENDPOINT', 'POOL', 'PRIORITY', 'POOLS')


def prep_stream(lines: list[str], run_id: str,
                opts: PrepOptions) -> list[str]:
    '''Rewrite a Makeflow file line by line.

    * Blank and ``#`` comment lines pass through unchanged.
    * Scoped directives (``ENDPOINT`` / ``POOL`` / ``PRIORITY`` /
      ``POOLS``) update the current scope and are NOT emitted.
    * Other ``NAME = value`` assignments are captured for ``$(NAME)``
      substitution and also passed through.
    * Each rule (an unindented ``outputs: inputs`` header followed by
      tab- or 4-space-indented command lines) is re-emitted by
      :func:`_emit_rule` with the scope in effect at that point.
    * Anything else passes through verbatim.

    Inside a rule's command block, blank lines and ``#`` comment lines
    are dropped.  A command line ending in an unescaped backslash is
    joined with the following command line (shell line continuation)
    before the emitter ``;``-joins the commands.

    Args:
        lines: input lines, e.g. from ``splitlines(keepends=True)``.
        run_id: identifier baked into every wrapped command.
        opts: CLI-level defaults that seed the directive scope.

    Returns:
        The output lines (with terminating newlines preserved).

    Raises:
        PrepError: on malformed directives or rule headers, rules
            without a command, or unresolvable rule placement.  Rule
            errors report the 1-based line number of the rule header.
    '''
    out: list[str] = []

    scope_endpoint: str | None = opts.default_endpoint
    scope_pool: str | None = opts.default_pool
    scope_priority: int = opts.default_priority
    # POOLS is workflow-level (not per-rule); we track it in the same
    # scope walk for simplicity but only one final value reaches the
    # rule emitter.
    scope_pools: str | None = opts.pools_file
    # Captured ``VAR=value`` assignments to substitute at prep time
    # into wrapped rule commands (see _expand_vars).
    scope_vars: dict[str, str] = {}

    i = 0
    n = len(lines)
    while i < n:
        line = lines[i]
        stripped = line.strip()

        # Blank / comment — pass through
        if not stripped or stripped.startswith('#'):
            out.append(line)
            i += 1
            continue

        # Directive assignment (ENDPOINT/POOL/PRIORITY/POOLS).  Matching
        # is case-sensitive and allows optional whitespace around '='.
        directive = _parse_scoped_directive(stripped)
        if directive is not None:
            name, value = directive
            if name == 'ENDPOINT':
                # ENDPOINT and POOL are mutually exclusive per scope —
                # setting ENDPOINT clears any active POOL.
                scope_endpoint = _expect_string(value, name, i + 1)
                scope_pool = None
            elif name == 'POOL':
                scope_pool = _expect_string(value, name, i + 1)
                scope_endpoint = None
            elif name == 'PRIORITY':
                scope_priority = _expect_int(value, name, i + 1)
            elif name == 'POOLS':
                scope_pools = _expect_string(value, name, i + 1)
            # Our directives are consumed — don't emit to output.
            i += 1
            continue

        # Generic ``NAME = "value"`` or ``NAME=value`` assignment —
        # capture for substitution but ALSO pass the line through so
        # makeflow itself sees it (other tools, LOCAL rules, etc.).
        generic = _parse_generic_assignment(stripped)
        if generic is not None:
            name, value = generic
            scope_vars[name] = value
            out.append(line)
            i += 1
            continue

        # Rule header: has ':' at the top of a line (not indented)
        if ':' in stripped and not line.startswith(('\t', ' ')):
            # Remember the header's 1-based line number: `i` moves past
            # the command block below, so it no longer identifies the rule.
            header_no = i + 1
            rule_header = line.rstrip('\n')
            outputs, inputs = _parse_rule_header(rule_header, header_no)

            # Collect command lines (tab- or space-indented)
            commands: list[str] = []
            i += 1
            while i < n and lines[i].startswith(('\t', '    ')):
                # Strip '\r' too, so CRLF files don't leak a carriage
                # return into the quoted ``sh -c`` command.
                cmd = lines[i].lstrip('\t ').rstrip('\r\n')
                i += 1
                if not cmd or cmd.startswith('#'):
                    # Comment lines must be dropped: once the commands are
                    # joined with ';', a '#' would comment out the rest.
                    continue
                prev = commands[-1] if commands else ''
                # An odd number of trailing backslashes means the last
                # one escapes the newline (continuation); an even number
                # is a run of literal backslashes.
                if (len(prev) - len(prev.rstrip('\\'))) % 2 == 1:
                    # The continuation line was indented, so the shell
                    # would see whitespace here — keep a word break.
                    commands[-1] = prev[:-1] + ' ' + cmd
                else:
                    commands.append(cmd)

            if not commands:
                raise PrepError(
                    f'rule at line {header_no}: no command found '
                    f'(outputs={outputs!r})')

            rule = Rule(
                outputs = outputs, inputs = inputs,
                commands = commands, header_line = rule_header,
                endpoint = scope_endpoint, pool = scope_pool,
                priority = scope_priority,
                pools = scope_pools)

            _validate_placement(rule, header_no)

            out.extend(_emit_rule(rule, run_id, scope_vars))
            continue

        # Anything else — pass through verbatim
        out.append(line)
        i += 1

    return out


def _parse_scoped_directive(stripped: str) -> tuple[str, str] | None:
    '''Recognise a scoped-directive assignment.

    *stripped* is a line with surrounding whitespace already removed.
    It matches when it has the form ``NAME = value``, where NAME is
    exactly one of ``_SCOPED_DIRECTIVES``.  NAME is matched
    case-sensitively, and whitespace is allowed before the ``=``.
    On a match, return ``(NAME, value)``.  *value* has surrounding
    whitespace stripped but is otherwise raw (quotes are still present).
    Return None for any other line.
    '''
    head, equals, tail = stripped.partition('=')
    if not equals:
        return None
    candidate = head.rstrip()
    if candidate not in _SCOPED_DIRECTIVES:
        return None
    return candidate, tail.strip()


_GENERIC_ASSIGN = re.compile(r'^([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(.*)$')


def _parse_generic_assignment(stripped: str) -> tuple[str, str] | None:
    '''Return (NAME, value) for a plain ``NAME=value`` assignment.

    Strips one layer of surrounding double or single quotes from the
    value if present.  Matches Makeflow's lexical handling closely
    enough for substitution into commands.
    '''
    m = _GENERIC_ASSIGN.match(stripped)
    if not m:
        return None
    name, value = m.group(1), m.group(2).strip()
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
        value = value[1:-1]
    return name, value


def _expect_string(value: str, name: str, line_no: int) -> str:
    '''Parse ``"quoted"`` or bareword into a Python string.'''
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
        return value[1:-1]
    if not value:
        raise PrepError(f'line {line_no}: {name} is empty')
    return value


def _expect_int(value: str, name: str, line_no: int) -> int:
    value = value.strip('"\'')
    try:
        return int(value)
    except ValueError as e:
        raise PrepError(
            f'line {line_no}: {name} must be an integer, got {value!r}'
        ) from e


def _parse_rule_header(line: str, line_no: int) -> tuple[list[str], list[str]]:
    '''Split ``outputs: inputs`` on the first ``:`` at the top level.

    Makeflow allows trailing backslash continuations, but we handle
    only single-line headers in v1; complain clearly if one crosses
    lines.
    '''
    if line.rstrip().endswith('\\'):
        raise PrepError(
            f'line {line_no}: multi-line rule headers not supported in v1')
    head, _, tail = line.partition(':')
    outputs = head.split()
    inputs  = tail.split()
    if not outputs:
        raise PrepError(f'line {line_no}: rule has no outputs')
    return outputs, inputs


def _validate_placement(rule: Rule, line_no: int) -> None:
    if rule.endpoint is None and rule.pool is None:
        raise PrepError(
            f'line {line_no}: rule "{" ".join(rule.outputs)}" has no '
            f'ENDPOINT or POOL assignment (set ENDPOINT = "..." or POOL = "..." '
            f'in the file, or pass --default-endpoint/--default-pool)')
    if rule.endpoint is not None and rule.pool is not None:
        # Shouldn't happen — directive handler clears the other when
        # one is set, and the CLI mutex prevents both defaults — but
        # guard against future regressions.
        raise PrepError(
            f'line {line_no}: rule "{" ".join(rule.outputs)}" has both '
            f'ENDPOINT and POOL set; they are mutually exclusive')


def _emit_rule(rule: Rule, run_id: str,
               vars: dict[str, str] | None = None) -> list[str]:
    '''Emit the rewritten rule as two lines: header + one command line.

    Multiple command lines are joined with ``' ; '`` so that the wrapper
    launches a single process and observes a single exit code.

    A rule whose first command starts with ``LOCAL `` is not wrapped in
    radical-orbit-run.  The ``LOCAL`` prefix is re-emitted for makeflow
    to handle, and ``$(VAR)`` references are left for makeflow to
    expand.  Every other rule is wrapped via :func:`_build_wrapper_cmd`
    with *vars* (or an empty mapping) used for prep-time substitution.
    '''
    commands, is_local = _strip_local(rule.commands)
    joined = ' ; '.join(commands)
    header = rule.header_line + '\n'
    if not is_local:
        body = _build_wrapper_cmd(rule, run_id, joined, vars or {})
    else:
        # makeflow does its own $(VAR) substitution for LOCAL rules.
        body = 'LOCAL ' + joined
    return [header, '\t' + body + '\n']


def _strip_local(commands: list[str]) -> tuple[list[str], bool]:
    '''If the first command starts with ``LOCAL ``, strip that prefix
    and return (stripped_commands, True).  Otherwise return the
    commands unchanged and ``False``.
    '''
    if commands and commands[0].lstrip().startswith('LOCAL '):
        first = commands[0].lstrip()[len('LOCAL '):]
        return [first] + commands[1:], True
    return commands, False


def _build_wrapper_cmd(rule: Rule, run_id: str, original: str,
                        vars: dict[str, str]) -> str:
    '''Produce the ``radical-orbit-run ... -- sh -c '<cmd>'`` line.

    The original command is wrapped in ``sh -c``.  Its shell grammar
    (pipes, redirections, compound commands) therefore reaches the
    wrapper as a single argument instead of being split by the outer
    shell that Makeflow uses to invoke the rule.

    ``$(VAR)`` references are expanded at prep time with *vars*, both in
    the command and in the ``--in``/``--out`` file names.  The file
    names are shell-quoted by :func:`_q` whenever they contain ``$`` or
    parentheses, so makeflow's own substitution cannot reach them
    either.

    Args:
        rule: the rule to wrap; exactly one of endpoint/pool is set.
        run_id: identifier passed as ``--run-id``.
        original: the (already ``;``-joined) rule command.
        vars: captured ``NAME -> value`` assignments.

    Returns:
        The complete wrapper command line (without indentation/newline).
    '''
    parts = ['radical-orbit-run']
    # Exactly one of endpoint/pool is set (validated by the caller).
    if rule.endpoint is not None:
        parts.append(f'--endpoint={_q(rule.endpoint)}')
    else:
        parts.append(f'--pool={_q(rule.pool)}')
    parts.append(f'--priority={rule.priority or 0}')
    parts.append(f'--run-id={_q(run_id)}')
    if rule.pools:
        parts.append(f'--pools={_q(rule.pools)}')
    # Endpoint mode doesn't (yet) support input/output staging, so skip
    # the --in/--out flags entirely — radical-orbit-run would reject
    # them at argparse time otherwise.
    if rule.endpoint is None:
        parts.extend(f'--in {_q(_expand_vars(f, vars))}' for f in rule.inputs)
        parts.extend(f'--out {_q(_expand_vars(f, vars))}' for f in rule.outputs)
    parts += ['--', 'sh', '-c', shlex.quote(_expand_vars(original, vars))]
    return ' '.join(parts)


def _q(s) -> str:
    '''Shell-quote a value if it contains metacharacters, else leave bare.

    Makeflow ingests these as shell commands, so we play it safe with
    anything non-trivial but avoid gratuitous quoting for plain names.
    '''
    if s is None:
        return ''
    s = str(s)
    safe = set('-._/:0123456789abcdefghijklmnopqrstuvwxyz'
               'ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    if s and all(c in safe for c in s):
        return s
    return shlex.quote(s)


# ---------------------------------------------------------------------------
# Run-id (stable per workflow-file invocation)
# ---------------------------------------------------------------------------

def compute_run_id(path: Path) -> str:
    '''Derive a 16-hex-character run id from the workflow file's identity.

    The id is ``sha1(abs_path + NUL + mtime_ns)`` truncated to 16 hex
    characters — see design doc §8.2.  Editing the file changes its mtime
    and hence the id.  If the file cannot be stat'ed, the mtime part is
    simply left out of the digest.
    '''
    seed = str(path.resolve()).encode('utf-8') + b'\0'
    try:
        stamp = path.stat().st_mtime_ns
    except OSError:
        stamp = None  # best-effort: hash the path alone
    digest = hashlib.sha1(seed)
    if stamp is not None:
        digest.update(str(stamp).encode('ascii'))
    return digest.hexdigest()[:16]


# ---------------------------------------------------------------------------
# Errors
# ---------------------------------------------------------------------------

class PrepError(ValueError):
    '''Unrecoverable problem in the input Makeflow file or prep options.

    It derives from :class:`ValueError`, so generic callers can treat it
    as a bad-value error.  ``main`` reports it on stderr and exits with
    status 3.
    '''


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main(argv: list[str] | None = None) -> int:
    '''Command-line entry point.

    Args:
        argv: argument list; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        The process exit status:
        * 0 on success.
        * 2 if the input is missing or unreadable, or the output cannot
          be written.
        * 3 if preprocessing fails with :class:`PrepError`.
    '''
    ap = argparse.ArgumentParser(
        prog='radical-orbit-makeflow-prep',
        description='Rewrite a .makeflow for the task dispatcher.')
    ap.add_argument('input', type=Path,
                    help='input .makeflow file')
    ap.add_argument('--output', '-o', type=Path,
                    help='output path (default: <input>.prepped)')
    # A rule must end up targeting either an endpoint or a pool, never
    # both — these CLI defaults follow the same XOR.
    target_default = ap.add_mutually_exclusive_group()
    target_default.add_argument('--default-endpoint', dest='default_endpoint',
                    default=None,
                    help='Default --endpoint for each rule (transparent proxy '
                         'to that endpoint\'s rhapsody).  Mutually exclusive '
                         'with --default-pool.')
    target_default.add_argument('--default-pool', dest='default_pool',
                    default=None,
                    help='Default --pool for each rule (dispatcher-managed '
                         'pilot fleet).  Mutually exclusive with '
                         '--default-endpoint.')
    ap.add_argument('--default-priority', dest='default_priority',
                    type=int, default=0)
    ap.add_argument('--pools',            dest='pools_file',
                    help='Path to a pools.json declaring the workflow\'s '
                         'pools.  Forwarded to radical-orbit-run on every '
                         'generated rule (--pools=<path>).  Without this, '
                         'rules use the dispatcher\'s built-in "default" '
                         'pool — which usually isn\'t what you want.  '
                         'Can also be set via POOLS = "..." in the '
                         'makeflow file.')
    ap.add_argument('--run-id', dest='run_id',
                    help='override run_id (default: sha1(abs_path + mtime))')

    args = ap.parse_args(argv)

    if not args.input.is_file():
        print(f'radical-orbit-makeflow-prep: input not found: {args.input}',
              file=sys.stderr)
        return 2

    # Resolve pools path to absolute so the generated rules work
    # regardless of the makeflow client's cwd.
    pools_file = args.pools_file
    if pools_file:
        pools_file = str(Path(pools_file).resolve())

    run_id = args.run_id or compute_run_id(args.input)

    # Decode explicitly as UTF-8 so the result doesn't depend on the
    # user's locale, and report unreadable input instead of a traceback.
    try:
        text = args.input.read_text(encoding='utf-8')
    except (OSError, UnicodeDecodeError) as e:
        print(f'radical-orbit-makeflow-prep: cannot read {args.input}: {e}',
              file=sys.stderr)
        return 2

    try:
        opts = PrepOptions(
            default_endpoint = args.default_endpoint,
            default_pool     = args.default_pool,
            default_priority = args.default_priority,
            pools_file       = pools_file,
        )
        out_lines = prep_stream(text.splitlines(keepends=True), run_id, opts)
    except PrepError as e:
        print(f'radical-orbit-makeflow-prep: {e}', file=sys.stderr)
        return 3

    output = args.output or args.input.with_suffix(
        args.input.suffix + '.prepped')
    try:
        output.write_text(''.join(out_lines), encoding='utf-8')
    except OSError as e:
        print(f'radical-orbit-makeflow-prep: cannot write {output}: {e}',
              file=sys.stderr)
        return 2
    print(f'wrote {output} (run_id={run_id})')
    return 0


if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
