Coverage for src / invariant / executor.py: 92.41%
79 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-03 19:52 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-03 19:52 +0000
1"""Executor: The runtime engine for executing DAGs."""
3import inspect
4from typing import TYPE_CHECKING, Any
6from invariant.cacheable import is_cacheable
7from invariant.expressions import resolve_params
8from invariant.graph import Graph, GraphResolver
9from invariant.hashing import hash_manifest
10from invariant.node import Node, SubGraphNode
12if TYPE_CHECKING:
13 from invariant.registry import OpRegistry
14 from invariant.store.base import ArtifactStore
class Executor:
    """Runtime engine for executing DAGs.

    Manages the two-phase execution:
    - Phase 1: Context Resolution (Graph -> Manifest)
    - Phase 2: Action Execution (Manifest -> Artifact)

    Nodes whose result must not be cached (``cache`` is falsy) are executed
    on every run, and that "uncacheable" status propagates to every node that
    lists them in ``deps``.
    """

    def __init__(
        self,
        registry: "OpRegistry",
        store: "ArtifactStore",
        resolver: "GraphResolver | None" = None,
    ) -> None:
        """Initialize Executor.

        Args:
            registry: OpRegistry for looking up operations.
            store: ArtifactStore for caching artifacts.
            resolver: Optional GraphResolver. If None, creates one with registry.
        """
        self.registry = registry
        self.store = store
        # Compare against None explicitly so that a caller-supplied resolver is
        # never discarded merely because it evaluates as falsy.
        self.resolver = (
            resolver if resolver is not None else GraphResolver(registry)
        )

    def execute(
        self, graph: Graph, context: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Execute a graph and return artifacts for each node.

        Args:
            graph: Dictionary mapping node IDs to Node or SubGraphNode objects.
            context: Optional dictionary of external dependencies (values not in
                graph). These are injected as artifacts available to any node
                that declares them in deps.

        Returns:
            Dictionary mapping node IDs to their resulting artifacts. Injected
            context values are included under their own keys.

        Raises:
            ValueError: If graph validation fails or execution errors occur.
        """
        # The set of cache-bypassing node IDs is only needed when recursing
        # into sub-graphs, so it is dropped at the public boundary.
        artifacts_by_node, _ = self._execute_graph(graph, context=context)
        return artifacts_by_node

    def _execute_graph(
        self,
        graph: Graph,
        context: dict[str, Any] | None = None,
        uncacheable_context_keys: set[str] | None = None,
    ) -> tuple[dict[str, Any], set[str]]:
        """Execute a graph and return artifacts plus node IDs that bypassed cache.

        Args:
            graph: Mapping of node IDs to Node or SubGraphNode objects.
            context: External values injected as pre-computed artifacts.
            uncacheable_context_keys: Context keys that were derived from
                uncacheable results; nodes depending on them skip the cache.

        Returns:
            A ``(artifacts_by_node, uncacheable_nodes)`` tuple. The second item
            holds every ID (context key or node) whose value bypassed the cache.

        Raises:
            ValueError: If a context value is not cacheable, or if a
                SubGraphNode's declared output is missing from its results.
        """
        # Validate and sort graph (pass context keys for validation)
        context = context or {}
        uncacheable_context_keys = uncacheable_context_keys or set()
        sorted_nodes = self.resolver.resolve(graph, context_keys=set(context.keys()))

        # Track artifacts by node ID
        artifacts_by_node: dict[str, Any] = {}
        # Work on a copy so the caller's set is never mutated.
        uncacheable_nodes = set(uncacheable_context_keys)

        # Inject context values into artifacts_by_node before execution.
        # This makes external dependencies available to any node that declares
        # them in deps.
        for key, value in context.items():
            # Context values must be cacheable
            if not is_cacheable(value):
                raise ValueError(
                    f"Context value for '{key}' is not cacheable, got {type(value)}"
                )
            # Store native types as-is (no wrapping)
            artifacts_by_node[key] = value

        # Execute nodes in topological order
        for node_id in sorted_nodes:
            node = graph[node_id]
            depends_on_uncacheable = any(
                dep_id in uncacheable_nodes for dep_id in node.deps
            )
            # Phase 1: resolve params against upstream artifacts.
            manifest = self._build_manifest(node, node_id, graph, artifacts_by_node)

            if isinstance(node, SubGraphNode):
                # SubGraphNode: run internal graph with resolved params as context.
                # If any dependency bypassed the cache, every resolved param is
                # treated as uncacheable inside the sub-graph.
                inner_uncacheable_context_keys = (
                    set(manifest.keys()) if depends_on_uncacheable else set()
                )
                inner_results, inner_uncacheable_nodes = self._execute_graph(
                    node.graph,
                    context=manifest,
                    uncacheable_context_keys=inner_uncacheable_context_keys,
                )
                if node.output not in inner_results:
                    raise ValueError(
                        f"SubGraphNode '{node_id}' output '{node.output}' not in "
                        f"internal results. Keys: {list(inner_results.keys())}."
                    )
                artifacts_by_node[node_id] = inner_results[node.output]
                if depends_on_uncacheable or node.output in inner_uncacheable_nodes:
                    uncacheable_nodes.add(node_id)
                continue

            # Node: Phase 2 cache lookup or execute op
            should_cache = node.cache and not depends_on_uncacheable
            if not should_cache:
                # Ephemeral node: always execute, never cache. This also applies
                # to descendants of explicit cache=False nodes.
                op = self.registry.get(node.op_name)
                artifact = self._invoke_op(op, node.op_name, manifest)
                uncacheable_nodes.add(node_id)
            else:
                digest = hash_manifest(manifest)
                if self.store.exists(node.op_name, digest):
                    artifact = self.store.get(node.op_name, digest)
                else:
                    op = self.registry.get(node.op_name)
                    artifact = self._invoke_op(op, node.op_name, manifest)
                    self.store.put(node.op_name, digest, artifact)
            artifacts_by_node[node_id] = artifact

        return artifacts_by_node, uncacheable_nodes

    def _build_manifest(
        self,
        node: Node | SubGraphNode,
        node_id: str,
        graph: Graph,
        artifacts_by_node: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the input manifest for a node (Phase 1).

        The manifest is built entirely from resolved params. Dependencies are NOT
        injected into the manifest directly - they are only available for
        ref()/cel() resolution within params.

        Args:
            node: The node to build manifest for.
            node_id: The ID of the node (used in error messages).
            graph: The full graph (for reference; not read by this method).
            artifacts_by_node: Already computed artifacts for upstream nodes.

        Returns:
            The manifest dictionary mapping parameter names to resolved values.

        Raises:
            ValueError: If a declared dependency has no artifact yet.
        """
        # Collect dependency artifacts for ref()/cel() resolution
        dependencies: dict[str, Any] = {}
        for dep_id in node.deps:
            if dep_id not in artifacts_by_node:
                raise ValueError(
                    f"Node '{node_id}' depends on '{dep_id}' but artifact not found. "
                    f"This should not happen if graph is topologically sorted or "
                    f"if '{dep_id}' is provided in context."
                )
            dependencies[dep_id] = artifacts_by_node[dep_id]

        # Manifest = resolved params only. No dependency injection.
        # ref() and cel() markers in params are resolved using dependencies.
        return resolve_params(node.params, dependencies)

    def _invoke_op(self, op: Any, op_name: str, manifest: dict[str, Any]) -> Any:
        """Invoke an operation with kwargs dispatch and return validation.

        Manifest entries are matched to the op's named parameters and passed by
        keyword. Manifest keys that match no named parameter are forwarded only
        when the op declares ``**kwargs``; otherwise they are ignored.

        Args:
            op: The callable operation to invoke.
            op_name: The name of the operation (for error messages).
            manifest: The manifest dictionary mapping parameter names to values.

        Returns:
            The operation result, returned as-is (no wrapping).

        Raises:
            ValueError: If required parameters are missing.
            TypeError: If return value is not cacheable.
        """
        # Inspect function signature to map manifest keys to function parameters
        sig = inspect.signature(op)
        kwargs: dict[str, Any] = {}
        has_var_kwargs = False

        for name, param in sig.parameters.items():
            # *args / **kwargs collectors are never "required": an op that
            # declares them must not be rejected for a missing 'args'/'kwargs'.
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                has_var_kwargs = True
                continue
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                continue

            if name in manifest:
                kwargs[name] = manifest[name]
            elif param.default is inspect.Parameter.empty:
                # Required parameter missing
                raise ValueError(f"Op '{op_name}': missing required parameter '{name}'")
            # Otherwise the parameter has a default; let the op supply it.

        # If function has **kwargs, pass remaining manifest keys
        if has_var_kwargs:
            for key, val in manifest.items():
                kwargs.setdefault(key, val)

        # Invoke the operation
        result = op(**kwargs)

        # Validate return value is cacheable
        if not is_cacheable(result):
            raise TypeError(
                f"Op '{op_name}' returned {type(result).__name__}, "
                f"which is not a cacheable type"
            )

        # Return as-is (no wrapping needed)
        return result