Coverage for src / invariant / executor.py: 94.59%

148 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-08 09:24 +0000

1"""Executor: The runtime engine for executing DAGs.""" 

2 

3from collections.abc import Iterable 

4from decimal import Decimal 

5from typing import TYPE_CHECKING, Any 

6 

7from invariant.cacheable import is_cacheable 

8from invariant.expressions import resolve_params 

9from invariant.graph import Graph 

10from invariant.hashing import hash_manifest 

11from invariant.invocation import invoke_op 

12from invariant.node import Node, SubGraphNode, SwitchNode 

13 

14if TYPE_CHECKING: 

15 from invariant.registry import OpRegistry 

16 from invariant.store.base import ArtifactStore 

17 

18 

class Executor:
    """Runtime engine that turns requested graph outputs into artifacts.

    Every executed node passes through two phases:
    - Phase 1: Context Resolution (Graph -> Manifest)
    - Phase 2: Action Execution (Manifest -> Artifact)
    """

    def __init__(
        self,
        registry: "OpRegistry",
        store: "ArtifactStore",
    ) -> None:
        """Bind the executor to an op registry and an artifact store.

        Args:
            registry: OpRegistry used to look up operations by name.
            store: ArtifactStore used to cache produced artifacts.
        """
        self.store = store
        self.registry = registry

    def execute(
        self,
        graph: Graph,
        outputs: Iterable[str],
        context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Produce the artifacts for the requested output nodes.

        Work is demand-driven: only vertices reachable from the requested
        outputs are visited. Unreachable vertices and inactive SwitchNode
        branches are not dependency-resolved, registry-validated,
        cache-checked, or executed.

        Args:
            graph: Mapping of node IDs to graph vertices.
            outputs: Iterable of node IDs whose artifacts are wanted.
            context: Optional mapping of external values (not present in the
                graph) that nodes may name in their deps.

        Returns:
            Mapping of each requested output ID to its artifact.

        Raises:
            ValueError: If requested outputs or active paths are invalid.
        """
        requested = self._normalize_outputs(outputs, graph)
        produced, _uncacheable = self._execute_requested_outputs(
            graph,
            requested,
            context=context,
        )
        # Only the requested roots are exposed, not every intermediate node.
        return {wanted: produced[wanted] for wanted in requested}

    def _execute_requested_outputs(
        self,
        graph: Graph,
        outputs: tuple[str, ...],
        context: dict[str, Any] | None = None,
        uncacheable_context_keys: set[str] | None = None,
    ) -> tuple[dict[str, Any], set[str]]:
        """Resolve the active paths below ``outputs`` on demand.

        Args:
            graph: Mapping of node IDs to graph vertices.
            outputs: Node IDs to resolve, in order.
            context: Optional external values, consulted for IDs absent
                from ``graph``.
            uncacheable_context_keys: Context keys that must be treated as
                uncacheable from the start (used for nested sub-graphs).

        Returns:
            A pair of (artifacts for every resolved ID, IDs marked
            uncacheable during this run).

        Raises:
            ValueError: On an active-path cycle, an unknown dependency, a
                non-cacheable context value, or an invalid switch target.
        """
        context = context or {}
        tainted: set[str] = set(uncacheable_context_keys or set())
        resolved: dict[str, Any] = {}
        # The list keeps visit order for the cycle message; the set gives
        # constant-time membership tests.
        path: list[str] = []
        on_path: set[str] = set()

        def resolve(node_id: str) -> Any:
            if node_id in resolved:
                return resolved[node_id]

            if node_id not in graph:
                # Not a vertex: it must be supplied by the caller's context.
                if node_id not in context:
                    raise ValueError(
                        f"Node depends on '{node_id}' but it doesn't exist in graph "
                        f"or context. Available: graph={sorted(graph)}, "
                        f"context={sorted(context)}"
                    )
                value = context[node_id]
                if not is_cacheable(value):
                    raise ValueError(
                        f"Context value for '{node_id}' is not cacheable, "
                        f"got {type(value)}"
                    )
                resolved[node_id] = value
                return value

            if node_id in on_path:
                cycle = " -> ".join([*path, node_id])
                raise ValueError(f"Graph contains cycles on active path: {cycle}")

            path.append(node_id)
            on_path.add(node_id)
            try:
                node = graph[node_id]
                for dep_id in node.deps:
                    resolve(dep_id)

                # True when at least one declared dep is marked uncacheable.
                tainted_by_deps = not tainted.isdisjoint(node.deps)

                if isinstance(node, SwitchNode):
                    self._validate_switch_targets(node, node_id, graph)
                    target_id = self._select_switch_target(node, node_id, resolved)
                    # Defensive re-check: _validate_switch_targets has already
                    # verified every case and default target.
                    if target_id not in graph:
                        raise ValueError(
                            f"SwitchNode '{node_id}' targets '{target_id}' "
                            "which doesn't exist in graph"
                        )
                    # Only the chosen branch is resolved.
                    resolve(target_id)
                    resolved[node_id] = resolved[target_id]
                    if tainted_by_deps or target_id in tainted:
                        tainted.add(node_id)
                elif isinstance(node, SubGraphNode):
                    inner_context = self._build_manifest(
                        node,
                        node_id,
                        graph,
                        resolved,
                    )
                    # If any outer dep is uncacheable, every value handed to
                    # the inner graph is treated as uncacheable too.
                    inner_tainted_keys = (
                        set(inner_context.keys()) if tainted_by_deps else set()
                    )
                    inner_results, inner_tainted = self._execute_requested_outputs(
                        node.graph,
                        (node.output,),
                        context=inner_context,
                        uncacheable_context_keys=inner_tainted_keys,
                    )
                    resolved[node_id] = inner_results[node.output]
                    if tainted_by_deps or node.output in inner_tainted:
                        tainted.add(node_id)
                else:
                    node_manifest = self._build_manifest(
                        node,
                        node_id,
                        graph,
                        resolved,
                    )
                    resolved[node_id] = self._execute_node(
                        node,
                        node_id,
                        node_manifest,
                        tainted_by_deps,
                        tainted,
                    )
            finally:
                # Unwind the active path even when resolution fails.
                path.pop()
                on_path.remove(node_id)

            return resolved[node_id]

        for output in outputs:
            resolve(output)
        return resolved, tainted

    def _normalize_outputs(
        self,
        outputs: Iterable[str],
        graph: Graph,
    ) -> tuple[str, ...]:
        """Check the requested output IDs and return them as a tuple.

        Args:
            outputs: Iterable of node IDs requested by the caller.
            graph: Graph the IDs must belong to.

        Returns:
            The requested IDs, in their original order.

        Raises:
            ValueError: If ``outputs`` is a str/bytes, is not iterable, is
                empty, holds a non-string or empty ID, repeats an ID, or
                names a node missing from ``graph``.
        """
        # A bare string is iterable but would be split into characters.
        if isinstance(outputs, (str, bytes)):
            raise ValueError(
                "outputs must be an iterable of node IDs, not str or bytes"
            )

        try:
            frozen = tuple(outputs)
        except TypeError as exc:
            raise ValueError("outputs must be an iterable of node IDs") from exc

        if not frozen:
            raise ValueError("outputs must not be empty")

        already_listed: set[str] = set()
        for output in frozen:
            if not (isinstance(output, str) and output):
                raise ValueError("outputs must contain non-empty strings")
            if output in already_listed:
                raise ValueError(f"outputs contains duplicate node ID '{output}'")
            already_listed.add(output)
            if output not in graph:
                raise ValueError(
                    f"Output node '{output}' is not in graph. "
                    f"Available: {sorted(graph)}"
                )

        return frozen

    def _validate_switch_targets(
        self,
        node: SwitchNode,
        node_id: str,
        graph: Graph,
    ) -> None:
        """Ensure every case target and the default of a SwitchNode exist.

        Args:
            node: The SwitchNode being resolved.
            node_id: ID of that SwitchNode (for error messages).
            graph: Graph the targets must belong to.

        Raises:
            ValueError: If any target is missing from ``graph``.
        """
        branch_targets = tuple(node.cases.values())
        if node.default is not None:
            branch_targets += (node.default,)

        for target in branch_targets:
            if target in graph:
                continue
            raise ValueError(
                f"SwitchNode '{node_id}' targets '{target}' which doesn't "
                "exist in graph"
            )

    def _build_manifest(
        self,
        node: Node | SubGraphNode,
        node_id: str,
        graph: Graph,
        artifacts_by_node: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the input manifest for a node (Phase 1).

        The manifest consists solely of the node's resolved params.
        Dependency artifacts are not copied into it; they are only handed to
        ``resolve_params`` so that ref()/cel() markers inside the params can
        be resolved.

        Args:
            node: The node whose manifest is built.
            node_id: ID of the node (for error messages).
            graph: The full graph (not read by this method).
            artifacts_by_node: Artifacts already computed for upstream nodes.

        Returns:
            Mapping of parameter names to resolved values.

        Raises:
            ValueError: If a declared dependency has no artifact yet.
        """
        dependencies: dict[str, Any] = {}
        for dep_id in node.deps:
            if dep_id in artifacts_by_node:
                dependencies[dep_id] = artifacts_by_node[dep_id]
                continue
            raise ValueError(
                f"Node '{node_id}' depends on '{dep_id}' but artifact not found. "
                "This should not happen after active dependency resolution."
            )

        return resolve_params(node.params, dependencies)

    def _resolve_selector(
        self,
        node: SwitchNode,
        node_id: str,
        artifacts_by_node: dict[str, Any],
    ) -> Any:
        """Evaluate a SwitchNode selector against its declared deps.

        Args:
            node: The SwitchNode whose selector is evaluated.
            node_id: ID of that SwitchNode (for error messages).
            artifacts_by_node: Artifacts already computed for upstream nodes.

        Returns:
            The resolved selector value.

        Raises:
            ValueError: If a declared dependency has no artifact yet.
        """
        unresolved = [
            dep_id for dep_id in node.deps if dep_id not in artifacts_by_node
        ]
        if unresolved:
            # Report the first missing dependency in declaration order.
            dep_id = unresolved[0]
            raise ValueError(
                f"SwitchNode '{node_id}' depends on '{dep_id}' but artifact "
                "not found"
            )

        dependencies = {dep_id: artifacts_by_node[dep_id] for dep_id in node.deps}
        # The selector is wrapped in a one-entry params dict so that it goes
        # through the same resolution as ordinary params.
        resolved_params = resolve_params({"selector": node.selector}, dependencies)
        return resolved_params["selector"]

    def _select_switch_target(
        self,
        node: SwitchNode,
        node_id: str,
        artifacts_by_node: dict[str, Any],
    ) -> str:
        """Pick the target node ID that a SwitchNode routes to.

        Args:
            node: The SwitchNode being resolved.
            node_id: ID of that SwitchNode (for error messages).
            artifacts_by_node: Artifacts already computed for upstream nodes.

        Returns:
            The matching case target, or the default when no case matches.

        Raises:
            ValueError: If no case matches and the node has no default.
        """
        case_key = self._normalize_switch_key(
            self._resolve_selector(node, node_id, artifacts_by_node),
            node_id,
        )
        if case_key in node.cases:
            return node.cases[case_key]
        if node.default is None:
            available = ", ".join(sorted(node.cases))
            raise ValueError(
                f"SwitchNode '{node_id}' selector resolved to {case_key!r}, "
                f"which has no case. Available cases: {available}"
            )
        return node.default

    def _normalize_switch_key(self, value: Any, node_id: str) -> str:
        """Convert a selector result into the key used to look up a case.

        Args:
            value: The resolved selector value.
            node_id: ID of the SwitchNode (for error messages).

        Returns:
            ``value`` itself for strings, ``"true"``/``"false"`` for bools,
            ``"null"`` for None, and ``str(value)`` for ints and Decimals.

        Raises:
            ValueError: If ``value`` has any other type.
        """
        if value is None:
            return "null"
        if isinstance(value, str):
            return value
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return "true" if value else "false"
        if isinstance(value, (int, Decimal)):
            return str(value)
        raise ValueError(
            f"SwitchNode '{node_id}' selector returned unsupported "
            f"{type(value).__name__}; expected str, bool, null, int, or Decimal"
        )

    def _execute_node(
        self,
        node: Node,
        node_id: str,
        manifest: dict[str, Any],
        depends_on_uncacheable: bool,
        uncacheable_nodes: set[str],
    ) -> Any:
        """Run a single Node, going through the artifact store when allowed.

        Args:
            node: The node to execute.
            node_id: ID of the node.
            manifest: Resolved params passed to the op.
            depends_on_uncacheable: Whether any dep is marked uncacheable.
            uncacheable_nodes: Set that receives ``node_id`` when the result
                bypasses the store.

        Returns:
            The artifact, freshly computed or read from the store.

        Raises:
            ValueError: If the node's op is not registered.
        """
        if not self.registry.has(node.op_name):
            raise ValueError(
                f"Node '{node_id}' references unregistered op '{node.op_name}'"
            )

        if not node.cache or depends_on_uncacheable:
            # Bypass the store entirely and mark this node uncacheable so
            # that nodes depending on it are not cached either.
            result = self._invoke_op(
                self.registry.get(node.op_name), node.op_name, manifest
            )
            uncacheable_nodes.add(node_id)
            return result

        digest = hash_manifest(manifest)
        if not self.store.exists(node.op_name, digest):
            result = self._invoke_op(
                self.registry.get(node.op_name), node.op_name, manifest
            )
            self.store.put(node.op_name, digest, result)
            return result
        return self.store.get(node.op_name, digest)

    def _invoke_op(self, op: Any, op_name: str, manifest: dict[str, Any]) -> Any:
        """Call an operation by delegating to ``invoke_op``.

        Args:
            op: The callable operation to invoke.
            op_name: Name of the operation (for error messages).
            manifest: Mapping of parameter names to values.

        Returns:
            Whatever ``invoke_op`` returns for this call.

        Raises:
            ValueError: If required parameters are missing (raised by
                ``invoke_op``, per the original documentation).
            TypeError: If the return value is not cacheable (raised by
                ``invoke_op``, per the original documentation).
        """
        outcome = invoke_op(op, op_name, manifest)
        return outcome

374 return invoke_op(op, op_name, manifest)