Coverage for src / invariant / executor.py: 94.59%
148 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-08 09:24 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-08 09:24 +0000
1"""Executor: The runtime engine for executing DAGs."""
3from collections.abc import Iterable
4from decimal import Decimal
5from typing import TYPE_CHECKING, Any
7from invariant.cacheable import is_cacheable
8from invariant.expressions import resolve_params
9from invariant.graph import Graph
10from invariant.hashing import hash_manifest
11from invariant.invocation import invoke_op
12from invariant.node import Node, SubGraphNode, SwitchNode
14if TYPE_CHECKING:
15 from invariant.registry import OpRegistry
16 from invariant.store.base import ArtifactStore
class Executor:
    """Runtime engine that evaluates DAGs on demand.

    Every executed vertex goes through two phases:
    - Phase 1: Context Resolution (Graph -> Manifest)
    - Phase 2: Action Execution (Manifest -> Artifact)
    """

    def __init__(
        self,
        registry: "OpRegistry",
        store: "ArtifactStore",
    ) -> None:
        """Bind the executor to its collaborators.

        Args:
            registry: OpRegistry queried for operations by name.
            store: ArtifactStore used to cache produced artifacts.
        """
        self.registry, self.store = registry, store

    def execute(
        self,
        graph: Graph,
        outputs: Iterable[str],
        context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Produce the artifacts for the requested output nodes.

        Evaluation is demand-driven: it starts at the requested outputs and
        only follows active paths. Vertices that are never reached, and
        SwitchNode branches that are not selected, are not dependency-resolved,
        registry-validated, cache-checked, or executed.

        Args:
            graph: Mapping of node IDs to graph vertices.
            outputs: Iterable of node IDs (present in ``graph``) to produce.
            context: Optional external values keyed by ID. A dependency that
                is not a graph node is looked up here.

        Returns:
            Mapping of each requested output ID to its artifact.

        Raises:
            ValueError: If requested outputs or active paths are invalid.
        """
        requested = self._normalize_outputs(outputs, graph)
        produced, _uncacheable = self._execute_requested_outputs(
            graph,
            requested,
            context=context,
        )
        result: dict[str, Any] = {}
        for output_id in requested:
            result[output_id] = produced[output_id]
        return result

    def _execute_requested_outputs(
        self,
        graph: Graph,
        outputs: tuple[str, ...],
        context: dict[str, Any] | None = None,
        uncacheable_context_keys: set[str] | None = None,
    ) -> tuple[dict[str, Any], set[str]]:
        """Demand-execute the active paths for requested graph outputs.

        Args:
            graph: Mapping of node IDs to graph vertices.
            outputs: Node IDs to resolve, in order.
            context: External values consulted for IDs absent from ``graph``.
            uncacheable_context_keys: IDs to treat as uncacheable from the
                start; anything depending on them is not cached either.

        Returns:
            A pair ``(artifacts, uncacheable)``: every artifact materialised
            during the walk keyed by ID, and the set of IDs marked uncacheable.

        Raises:
            ValueError: On a cycle along an active path, a dependency found in
                neither graph nor context, or a non-cacheable context value.
        """
        if not context:
            context = {}
        uncacheable: set[str] = (
            set(uncacheable_context_keys) if uncacheable_context_keys else set()
        )
        produced: dict[str, Any] = {}
        # `path` keeps visit order for the cycle message; `on_path` mirrors it
        # for constant-time membership checks.
        path: list[str] = []
        on_path: set[str] = set()

        def resolve(node_id: str) -> Any:
            # Anything already materialised (vertex or context value) is reused.
            if node_id in produced:
                return produced[node_id]

            # IDs that are not graph vertices must come from the context.
            if node_id not in graph:
                if node_id not in context:
                    raise ValueError(
                        f"Node depends on '{node_id}' but it doesn't exist in graph "
                        f"or context. Available: graph={sorted(graph)}, "
                        f"context={sorted(context)}"
                    )
                external = context[node_id]
                if not is_cacheable(external):
                    raise ValueError(
                        f"Context value for '{node_id}' is not cacheable, "
                        f"got {type(external)}"
                    )
                produced[node_id] = external
                return external

            if node_id in on_path:
                cycle = " -> ".join([*path, node_id])
                raise ValueError(f"Graph contains cycles on active path: {cycle}")

            path.append(node_id)
            on_path.add(node_id)
            try:
                vertex = graph[node_id]
                for dep_id in vertex.deps:
                    resolve(dep_id)

                # Must be computed after the deps ran: resolving them is what
                # populates `uncacheable`.
                tainted = not uncacheable.isdisjoint(vertex.deps)

                if isinstance(vertex, SwitchNode):
                    self._validate_switch_targets(vertex, node_id, graph)
                    chosen = self._select_switch_target(vertex, node_id, produced)
                    # Defensive re-check; _validate_switch_targets has already
                    # rejected targets missing from the graph.
                    if chosen not in graph:
                        raise ValueError(
                            f"SwitchNode '{node_id}' targets '{chosen}' "
                            "which doesn't exist in graph"
                        )
                    # Only the selected branch is resolved; the switch node
                    # simply forwards that branch's artifact.
                    resolve(chosen)
                    produced[node_id] = produced[chosen]
                    if tainted or chosen in uncacheable:
                        uncacheable.add(node_id)
                else:
                    manifest = self._build_manifest(
                        vertex,
                        node_id,
                        graph,
                        produced,
                    )
                    if isinstance(vertex, SubGraphNode):
                        # The manifest becomes the inner graph's context; if
                        # this node is tainted, every manifest key is too.
                        inner_tainted_keys = set(manifest) if tainted else set()
                        inner_results, inner_uncacheable = (
                            self._execute_requested_outputs(
                                vertex.graph,
                                (vertex.output,),
                                context=manifest,
                                uncacheable_context_keys=inner_tainted_keys,
                            )
                        )
                        produced[node_id] = inner_results[vertex.output]
                        if tainted or vertex.output in inner_uncacheable:
                            uncacheable.add(node_id)
                    else:
                        produced[node_id] = self._execute_node(
                            vertex,
                            node_id,
                            manifest,
                            tainted,
                            uncacheable,
                        )
            finally:
                path.pop()
                on_path.remove(node_id)

            return produced[node_id]

        for output_id in outputs:
            resolve(output_id)
        return produced, uncacheable

    def _normalize_outputs(
        self,
        outputs: Iterable[str],
        graph: Graph,
    ) -> tuple[str, ...]:
        """Validate the requested output IDs and freeze them into a tuple.

        Args:
            outputs: Candidate iterable of node IDs.
            graph: Graph the IDs must belong to.

        Returns:
            The IDs as a tuple, in the order given.

        Raises:
            ValueError: If ``outputs`` is a str/bytes, is not iterable, is
                empty, or contains a non-string, an empty string, a duplicate,
                or an ID that is not in ``graph``.
        """
        # A bare string is iterable, but iterating it would yield characters.
        if isinstance(outputs, (str, bytes)):
            raise ValueError(
                "outputs must be an iterable of node IDs, not str or bytes"
            )

        try:
            frozen = tuple(outputs)
        except TypeError as exc:
            raise ValueError("outputs must be an iterable of node IDs") from exc

        if len(frozen) == 0:
            raise ValueError("outputs must not be empty")

        already_seen: set[str] = set()
        for output in frozen:
            if not (isinstance(output, str) and output):
                raise ValueError("outputs must contain non-empty strings")
            if output in already_seen:
                raise ValueError(f"outputs contains duplicate node ID '{output}'")
            already_seen.add(output)
            if output not in graph:
                raise ValueError(
                    f"Output node '{output}' is not in graph. "
                    f"Available: {sorted(graph)}"
                )

        return frozen

    def _validate_switch_targets(
        self,
        node: SwitchNode,
        node_id: str,
        graph: Graph,
    ) -> None:
        """Check that every branch target of a SwitchNode is a graph node.

        Args:
            node: The SwitchNode whose ``cases`` and ``default`` are checked.
            node_id: ID of ``node`` (used in the error message).
            graph: Graph the targets must belong to.

        Raises:
            ValueError: On the first target that is missing from ``graph``.
        """
        case_targets = tuple(node.cases.values())
        fallback = () if node.default is None else (node.default,)
        for target in (*case_targets, *fallback):
            if target not in graph:
                raise ValueError(
                    f"SwitchNode '{node_id}' targets '{target}' which doesn't "
                    "exist in graph"
                )

    def _build_manifest(
        self,
        node: Node | SubGraphNode,
        node_id: str,
        graph: Graph,
        artifacts_by_node: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the input manifest for a node (Phase 1).

        The manifest is whatever ``resolve_params`` returns for the node's
        params. Dependency artifacts are not copied into it; they are only
        passed to ``resolve_params`` as the lookup table for resolution.

        Args:
            node: The node to build the manifest for.
            node_id: ID of ``node`` (used in the error message).
            graph: The full graph; not consulted by this method.
            artifacts_by_node: Artifacts already computed for upstream IDs.

        Returns:
            The result of resolving ``node.params`` against the dependencies.

        Raises:
            ValueError: If a declared dependency has no artifact yet.
        """
        # Fail on the first unresolved dependency before collecting anything.
        for dep_id in node.deps:
            if dep_id not in artifacts_by_node:
                raise ValueError(
                    f"Node '{node_id}' depends on '{dep_id}' but artifact not found. "
                    "This should not happen after active dependency resolution."
                )
        upstream: dict[str, Any] = {
            dep_id: artifacts_by_node[dep_id] for dep_id in node.deps
        }
        return resolve_params(node.params, upstream)

    def _resolve_selector(
        self,
        node: SwitchNode,
        node_id: str,
        artifacts_by_node: dict[str, Any],
    ) -> Any:
        """Resolve a SwitchNode's selector against its declared deps.

        Args:
            node: The SwitchNode whose ``selector`` is resolved.
            node_id: ID of ``node`` (used in the error message).
            artifacts_by_node: Artifacts already computed for upstream IDs.

        Returns:
            The resolved selector value.

        Raises:
            ValueError: If a declared dependency has no artifact yet.
        """
        for dep_id in node.deps:
            if dep_id not in artifacts_by_node:
                raise ValueError(
                    f"SwitchNode '{node_id}' depends on '{dep_id}' but artifact "
                    "not found"
                )
        upstream: dict[str, Any] = {
            dep_id: artifacts_by_node[dep_id] for dep_id in node.deps
        }
        # The selector is wrapped in a one-entry params dict so it goes
        # through the same resolution routine as ordinary params.
        resolved = resolve_params({"selector": node.selector}, upstream)
        return resolved["selector"]

    def _select_switch_target(
        self,
        node: SwitchNode,
        node_id: str,
        artifacts_by_node: dict[str, Any],
    ) -> str:
        """Resolve a SwitchNode's selector and pick the matching target ID.

        Args:
            node: The SwitchNode being evaluated.
            node_id: ID of ``node`` (used in error messages).
            artifacts_by_node: Artifacts already computed for upstream IDs.

        Returns:
            The target for the normalized selector key, or ``node.default``
            when no case matches and a default is set.

        Raises:
            ValueError: If no case matches and there is no default.
        """
        raw_selector = self._resolve_selector(node, node_id, artifacts_by_node)
        case_key = self._normalize_switch_key(raw_selector, node_id)

        if case_key in node.cases:
            return node.cases[case_key]
        if node.default is None:
            available = ", ".join(sorted(node.cases))
            raise ValueError(
                f"SwitchNode '{node_id}' selector resolved to {case_key!r}, "
                f"which has no case. Available cases: {available}"
            )
        return node.default

    def _normalize_switch_key(self, value: Any, node_id: str) -> str:
        """Convert a selector result into a switch case key.

        Strings pass through unchanged, booleans become ``"true"``/``"false"``,
        ``None`` becomes ``"null"``, and ints and Decimals become ``str(value)``.

        Args:
            value: The resolved selector value.
            node_id: ID of the SwitchNode (used in the error message).

        Returns:
            The case key.

        Raises:
            ValueError: For any other value type.
        """
        if isinstance(value, str):
            return value
        if value is None:
            return "null"
        # bool is tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            if value:
                return "true"
            return "false"
        if isinstance(value, (int, Decimal)):
            return str(value)
        raise ValueError(
            f"SwitchNode '{node_id}' selector returned unsupported "
            f"{type(value).__name__}; expected str, bool, null, int, or Decimal"
        )

    def _execute_node(
        self,
        node: Node,
        node_id: str,
        manifest: dict[str, Any],
        depends_on_uncacheable: bool,
        uncacheable_nodes: set[str],
    ) -> Any:
        """Run one Node, going through the artifact store when caching applies.

        Caching applies when ``node.cache`` is truthy and the node does not
        depend on anything uncacheable. Otherwise the op is always invoked and
        ``node_id`` is added to ``uncacheable_nodes``.

        Args:
            node: The node to execute.
            node_id: ID of ``node``.
            manifest: Resolved inputs passed to the op.
            depends_on_uncacheable: Whether any dependency is uncacheable.
            uncacheable_nodes: Set of uncacheable IDs; mutated in place.

        Returns:
            The artifact, either read from the store or freshly computed.

        Raises:
            ValueError: If the node's op is not registered.
        """
        op_name = node.op_name
        if not self.registry.has(op_name):
            raise ValueError(
                f"Node '{node_id}' references unregistered op '{node.op_name}'"
            )

        if node.cache and not depends_on_uncacheable:
            # The store is keyed by op name plus the manifest digest.
            digest = hash_manifest(manifest)
            if self.store.exists(op_name, digest):
                return self.store.get(op_name, digest)
            fresh = self._invoke_op(self.registry.get(op_name), op_name, manifest)
            self.store.put(op_name, digest, fresh)
            return fresh

        # Uncached path: run the op and record this node as uncacheable so
        # that nodes depending on it skip the cache too.
        uncached = self._invoke_op(self.registry.get(op_name), op_name, manifest)
        uncacheable_nodes.add(node_id)
        return uncached

    def _invoke_op(self, op: Any, op_name: str, manifest: dict[str, Any]) -> Any:
        """Forward an operation call to ``invoke_op``.

        NOTE(review): parameter dispatch and return-value validation are
        presumably done inside ``invoke_op``; confirm against that module.

        Args:
            op: The callable operation to invoke.
            op_name: Name of the operation.
            manifest: Mapping of parameter names to values.

        Returns:
            Whatever ``invoke_op`` returns; its exceptions propagate unchanged.
        """
        outcome = invoke_op(op, op_name, manifest)
        return outcome