# rye:signed:2026-04-10T04:15:32Z:28778371fb5a24c219148ad8a44793708113a6604e1a518dae6439d87bad2958:gUKj3e50qx4tHr5eL5YkmXtb-XC_Q8twIH_1HDR-Y_e5-usXhAOwJ5E1EMeV2LOoH-Ps1odaUKm8kUe2ZfVoCw:6ea18199041a1ea8
"""
safety_harness.py: Thread safety harness — limits, hooks, cancellation, permissions
"""

# Module metadata for this tool.
# NOTE(review): these dunder fields look like they are consumed by the rye
# tool registry/loader rather than by this module — confirm against the reader.
__version__ = "1.0.0"
__tool_type__ = "python"
__category__ = "rye/agent/threads"
__tool_description__ = "Thread safety harness — limits, hooks, cancellation, permissions"

import fnmatch
import logging
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# NOTE(review): `module_loader` is not a stdlib module (presumably project-local);
# it is imported after the logger setup rather than with the other imports.
from module_loader import load_module

# Directory containing this file; passed to load_module as the resolution anchor.
_ANCHOR = Path(__file__).parent

# Helper modules loaded relative to _ANCHOR (presumably <anchor>/loaders/...; TODO confirm).
# condition_evaluator.matches(context, condition) is used to filter hooks;
# interpolation.interpolate_action()/interpolate() fill ${...} placeholders.
condition_evaluator = load_module("loaders/condition_evaluator", anchor=_ANCHOR)
interpolation = load_module("loaders/interpolation", anchor=_ANCHOR)


def _is_suppressed(hook: Dict, suppress: List[str]) -> bool:
    """Check if a hook should be suppressed by directive context.

    Matches suppress values against:
      - The hook's `id` field (e.g. "system_tool_protocol")
      - The action's full `item_id` (e.g. "rye/agent/core/tool-protocol")

    Does NOT match on basename to avoid ambiguous clashes
    (e.g. "identity" matching both "rye/agent/core/identity"
    and "project/auth/identity").

    Empty identifiers never match, so a stray "" in `suppress` cannot
    suppress every hook that lacks an `id`. A missing, None, or non-dict
    `action` is treated as having no item_id instead of raising.

    NOTE(review): only the single `action` form is inspected; hooks using the
    multi-action `actions` list can only be suppressed by their `id`.

    Args:
        hook: Hook definition dict.
        suppress: Hook ids and/or full item_ids to suppress.

    Returns:
        True if any suppress value equals the hook's id or its action's item_id.
    """
    hook_id = hook.get("id") or ""
    action = hook.get("action")
    item_id = (action.get("item_id") or "") if isinstance(action, dict) else ""
    # Only non-empty identifiers are eligible for matching.
    candidates = {key for key in (hook_id, item_id) if key}
    return any(s in candidates for s in suppress)


class SafetyHarness:
    """Manages thread limits, hooks, cancellation, and permission enforcement.

    NOT an execution engine — it checks limits, evaluates hook conditions,
    and enforces directive permissions on tool calls.

    Permissions are declared in directive XML as capability strings:
        rye.<primary>.<kind>.<item_id_dotted>
    Example: rye.execute.tool.rye.file-system.* allows executing any
    tool under rye/file-system/.

    Capability patterns are matched with ``fnmatch.fnmatchcase`` so matching
    is case-sensitive on every OS (``fnmatch.fnmatch`` folds case on Windows).
    ``*`` also matches ``.``, so a trailing ``*`` covers every nested item
    under that prefix.

    Two hook dispatch methods:
      - run_hooks()         — for error/limit/after_step events. Returns control action or None.
      - run_hooks_context() — for context-injection events (e.g. thread_started /
                              thread_continued). Returns a dict with joined
                              "before"/"after" context strings plus raw entries.
    """

    # Item-id prefix of thread-internal tools, which bypass capability checks.
    _INTERNAL_PREFIX = "rye/agent/threads/internal/"

    def __init__(
        self,
        thread_id: str,
        limits: Dict,
        hooks: List[Dict],
        project_path: Path,
        directive_name: str = "",
        permissions: Optional[List[Dict]] = None,
        parent_capabilities: Optional[List[str]] = None,
    ):
        """Store thread configuration and compute the effective capabilities.

        Args:
            thread_id: Identifier of the thread this harness guards.
            limits: Maxima keyed by "turns", "tokens", "spend",
                "duration_seconds"; a missing/None key is not enforced.
            hooks: Hook definitions used by run_hooks()/run_hooks_context().
            project_path: Project path (only stored by this class).
            directive_name: Directive name (only stored by this class).
            permissions: Parsed permission elements; entries with
                tag == "cap" contribute their "content" as a capability.
            parent_capabilities: Capabilities of the spawning thread. When
                both sides declare capabilities the child's set is attenuated
                so it never exceeds the parent's.
        """
        self.thread_id = thread_id
        self.limits = limits
        self.hooks = hooks
        self.project_path = project_path
        self.directive_name = directive_name
        self._cancelled = False
        self.available_tools: List[Dict] = []

        # Capabilities may be written with "/" separators; normalise to ".".
        # Cap entries without content are skipped rather than raising KeyError.
        child_caps = [
            p["content"].replace("/", ".")
            for p in (permissions or [])
            if p.get("tag") == "cap" and p.get("content")
        ]
        parent_caps = [c.replace("/", ".") for c in (parent_capabilities or [])]

        if child_caps and parent_caps:
            self._capabilities = self._attenuate(child_caps, parent_caps)
        else:
            # Only one side (or neither) declared capabilities: use whichever exists.
            # NOTE(review): an explicit empty parent list is treated like "no
            # parent", so the child's caps apply unrestricted — confirm callers
            # pass None (not []) for a capability-less parent.
            self._capabilities = child_caps or parent_caps

    @staticmethod
    def _attenuate(child_caps: List[str], parent_caps: List[str]) -> List[str]:
        """Narrow child capabilities to the parent's scope (caps only narrow).

        - A child cap covered by some parent pattern is kept as-is.
        - A child cap broader than parent caps is replaced by EVERY parent
          cap it covers (not just the first one found).
        - A child cap unrelated to all parent caps is dropped.
        Duplicates are removed, preserving first-seen order.
        """
        attenuated: List[str] = []
        for cc in child_caps:
            if any(fnmatch.fnmatchcase(cc, pc) for pc in parent_caps):
                # Parent covers child exactly, or child is narrower.
                candidates = [cc]
            else:
                # Child is broader — narrow to every parent scope it covers.
                candidates = [pc for pc in parent_caps if fnmatch.fnmatchcase(pc, cc)]
            for cap in candidates:
                if cap not in attenuated:
                    attenuated.append(cap)
        return attenuated

    @staticmethod
    def _denial(message: str, primary: str, kind: str, item_id: str) -> Dict:
        """Build the permission-denied dict returned by check_permission()."""
        return {
            "error": message,
            "denied_action": primary,
            "denied_kind": kind,
            "denied_item_id": item_id,
        }

    def check_permission(self, primary: str, kind: str, item_id: str = "") -> Optional[Dict]:
        """Check if an action is permitted by directive capabilities.

        Returns None if allowed, or an error dict if denied (keys: "error",
        "denied_action", "denied_kind", "denied_item_id").

        If no capabilities are declared, all actions are denied (fail-closed).
        Internal thread tools (rye/agent/threads/internal/*) are always allowed.
        Item IDs containing a ".." path segment are always denied: a plain
        prefix check would let "rye/agent/threads/internal/../x" skip
        enforcement, and after "/"→"." conversion such IDs still match
        trailing-"*" capabilities.

        Capability format depends on the primary action:
          execute/fetch/sign: rye.<primary>.<kind>.<item_id_dotted>

        Item IDs use / separators, capabilities use . separators with fnmatch wildcards.
        Example: capability "rye.execute.tool.rye.file-system.*"
                 matches item_id "rye/file-system/fs_write"
        """
        if item_id and ".." in item_id.split("/"):
            return self._denial(
                f"Permission denied: item_id '{item_id}' contains '..' path segments",
                primary, kind, item_id,
            )

        if item_id and item_id.startswith(self._INTERNAL_PREFIX):
            return None

        if not self._capabilities:
            target = item_id or kind
            return self._denial(
                f"Permission denied: no capabilities declared. "
                f"Cannot {primary} {kind} '{target}'",
                primary, kind, item_id,
            )

        # Build the capability string to check
        if item_id:
            item_id_dotted = item_id.replace("/", ".")
            required = f"rye.{primary}.{kind}.{item_id_dotted}"
        else:
            # fetch query mode has no item_id — check rye.fetch.<kind>
            required = f"rye.{primary}.{kind}"

        if any(fnmatch.fnmatchcase(required, cap) for cap in self._capabilities):
            return None

        return self._denial(
            f"Permission denied: '{required}' not covered by "
            f"capabilities {self._capabilities}",
            primary, kind, item_id,
        )

    def check_limits(self, cost: Dict) -> Optional[Dict]:
        """Check all limits against current cost.

        Returns a limit event dict ("limit_code", "current_value",
        "current_max") for the first limit reached or exceeded, else None.
        A limit whose maximum is missing/None is not enforced; cost values
        that are missing or None count as 0.
        """
        tokens = (cost.get("input_tokens") or 0) + (cost.get("output_tokens") or 0)
        checks = [
            ("turns", cost.get("turns") or 0, self.limits.get("turns")),
            ("tokens", tokens, self.limits.get("tokens")),
            ("spend", cost.get("spend") or 0.0, self.limits.get("spend")),
            ("duration_seconds", cost.get("elapsed_seconds") or 0, self.limits.get("duration_seconds")),
        ]
        for limit_code, current, maximum in checks:
            if maximum is not None and current >= maximum:
                return {
                    "limit_code": f"{limit_code}_exceeded",
                    "current_value": current,
                    "current_max": maximum,
                }
        return None

    async def run_hooks(
        self,
        event: str,
        context: Dict,
        dispatcher: Any,
    ) -> Optional[Dict]:
        """Evaluate hooks for error/limit/after_step events.

        Hooks are visited in self.hooks order, which is expected to be
        layer 1 (directive) → layer 2 (builtin) → layer 3 (infra).
        First hook action that returns a non-None result wins (for control flow);
        once a winner exists, remaining non-infra hooks are not dispatched.
        Infra hooks (layer 3) always run regardless, and never supply the result.

        Returns:
            None = continue, Dict = terminating action (from control.py)
        """
        control_result = None
        for hook in self.hooks:
            if hook.get("event") != event:
                continue
            is_infra = hook.get("layer") == 3
            # Short-circuit: after a control action has won, only infra hooks run.
            if control_result is not None and not is_infra:
                continue
            if not condition_evaluator.matches(context, hook.get("condition", {})):
                continue

            action = hook.get("action", {})
            interpolated = interpolation.interpolate_action(action, context)
            result = await dispatcher.dispatch(interpolated)

            if is_infra or not result:
                continue

            data = result.get("data", result)
            # A bare {"success": True} acknowledgement is not a control action.
            if data is not None and data != {"success": True}:
                control_result = data

        return control_result

    @staticmethod
    def _extract_content(result: Dict, data: Dict) -> Any:
        """Return the first non-empty content field from a dispatch result."""
        return (
            data.get("content") or data.get("body") or data.get("raw", "")
            # FetchTool (ID mode) returns content at top level (no data wrapper)
            or result.get("content") or result.get("body", "")
        )

    @staticmethod
    def _context_tag(result: Dict, data: Dict, item_id: str) -> str:
        """Choose the tag/role name: parsed name, result name, id basename, or "context"."""
        return (
            data.get("name", "")
            or result.get("name", "")
            or (item_id or "").rsplit("/", 1)[-1]
            or "context"
        )

    async def run_hooks_context(
        self,
        context: Dict,
        dispatcher: Any,
        event: str,
        suppress: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Run context-injection hooks for a given event and collect context blocks.

        Unlike run_hooks(), this method:
        - Filters by the specified event (not hardcoded to thread_started)
        - Runs ALL matching hooks (no short-circuit)
        - Maps FetchTool/ExecuteTool results to XML-wrapped context blocks

        Hook position is controlled by the `position` field (default: "before"):
        - "before": injected before the directive prompt (identity, rules)
        - "after": injected after the directive prompt (supplementary knowledge)

        Hooks whose id or action item_id appears in `suppress` are skipped.
        Failed loads are logged and skipped, except integrity failures.

        Returns:
            {"before": str, "after": str, "before_raw": list, "after_raw": list}
            where the strings are blocks joined by blank lines and each raw
            entry is {"id": item_id, "content": text, "role": tag}.

        Raises:
            RuntimeError: if a dispatch result reports error_type == "integrity".
        """
        before_blocks: List[str] = []
        after_blocks: List[str] = []
        before_raw: List[Dict[str, str]] = []
        after_raw: List[Dict[str, str]] = []
        for hook in self.hooks:
            if hook.get("event") != event:
                continue
            if suppress and _is_suppressed(hook, suppress):
                continue
            if not condition_evaluator.matches(context, hook.get("condition", {})):
                continue

            # Support both single action and multi-action hooks
            actions = hook.get("actions", [])
            if not actions:
                single = hook.get("action", {})
                actions = [single] if single else []

            for action in actions:
                interpolated = interpolation.interpolate_action(action, context)
                result = await dispatcher.dispatch(interpolated)

                if not result or result.get("status") != "success":
                    item_id = action.get("item_id", "unknown")
                    error = result.get("error", "unknown error") if result else "no result"
                    # Integrity failures are fatal; other load failures are best-effort.
                    if result and result.get("error_type") == "integrity":
                        raise RuntimeError(
                            f"Context hook '{hook.get('id', '?')}' integrity failure "
                            f"loading {item_id}: {error}"
                        )
                    logger.warning(
                        "Context hook '%s' failed to load %s: %s",
                        hook.get("id", "?"), item_id, error,
                    )
                    continue

                # Tolerate results whose "data" is missing, None, or not a dict.
                data = result.get("data")
                if not isinstance(data, dict):
                    data = {}
                content = self._extract_content(result, data)
                if not content:
                    continue

                # Interpolate ${...} placeholders in content with hook context
                text = interpolation.interpolate(content, context).strip()
                item_id = action.get("item_id", "")
                kind = action.get("item_type", "")
                # Same tag for wrapped and unwrapped hooks, so raw "role" is consistent.
                tag = self._context_tag(result, data, item_id)

                # Hook-level wrap control (default: true)
                if hook.get("wrap", True):
                    type_attr = f' type="{kind}"' if kind else ""
                    block = f'<{tag} id="{item_id}"{type_attr}>\n{text}\n</{tag}>'
                else:
                    block = text

                raw_entry = {"id": item_id, "content": text, "role": tag}
                if hook.get("position", "before") == "after":
                    after_blocks.append(block)
                    after_raw.append(raw_entry)
                else:
                    before_blocks.append(block)
                    before_raw.append(raw_entry)

        return {
            "before": "\n\n".join(before_blocks),
            "after": "\n\n".join(after_blocks),
            "before_raw": before_raw,
            "after_raw": after_raw,
        }

    def request_cancel(self):
        """Mark this thread as cancelled (checked via is_cancelled())."""
        self._cancelled = True

    def is_cancelled(self) -> bool:
        """Return True once request_cancel() has been called."""
        return self._cancelled
