Coverage for src / invariant / executor.py: 92.41%

79 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-03 19:45 +0000

1"""Executor: The runtime engine for executing DAGs.""" 

2 

3import inspect 

4from typing import TYPE_CHECKING, Any 

5 

6from invariant.cacheable import is_cacheable 

7from invariant.expressions import resolve_params 

8from invariant.graph import Graph, GraphResolver 

9from invariant.hashing import hash_manifest 

10from invariant.node import Node, SubGraphNode 

11 

12if TYPE_CHECKING: 

13 from invariant.registry import OpRegistry 

14 from invariant.store.base import ArtifactStore 

15 

16 

17class Executor: 

18 """Runtime engine for executing DAGs. 

19 

20 Manages the two-phase execution: 

21 - Phase 1: Context Resolution (Graph -> Manifest) 

22 - Phase 2: Action Execution (Manifest -> Artifact) 

23 """ 

24 

25 def __init__( 

26 self, 

27 registry: "OpRegistry", 

28 store: "ArtifactStore", 

29 resolver: "GraphResolver | None" = None, 

30 ) -> None: 

31 """Initialize Executor. 

32 

33 Args: 

34 registry: OpRegistry for looking up operations. 

35 store: ArtifactStore for caching artifacts. 

36 resolver: Optional GraphResolver. If None, creates one with registry. 

37 """ 

38 self.registry = registry 

39 self.store = store 

40 self.resolver = resolver or GraphResolver(registry) 

41 

42 def execute( 

43 self, graph: Graph, context: dict[str, Any] | None = None 

44 ) -> dict[str, Any]: 

45 """Execute a graph and return artifacts for each node. 

46 

47 Args: 

48 graph: Dictionary mapping node IDs to Node or SubGraphNode objects. 

49 context: Optional dictionary of external dependencies (values not in graph). 

50 These are injected as artifacts available to any node that declares 

51 them in deps. 

52 

53 Returns: 

54 Dictionary mapping node IDs to their resulting artifacts. 

55 

56 Raises: 

57 ValueError: If graph validation fails or execution errors occur. 

58 """ 

59 artifacts_by_node, _ = self._execute_graph(graph, context=context) 

60 return artifacts_by_node 

61 

62 def _execute_graph( 

63 self, 

64 graph: Graph, 

65 context: dict[str, Any] | None = None, 

66 uncacheable_context_keys: set[str] | None = None, 

67 ) -> tuple[dict[str, Any], set[str]]: 

68 """Execute a graph and return artifacts plus node IDs that bypassed cache.""" 

69 # Validate and sort graph (pass context for validation) 

70 context = context or {} 

71 uncacheable_context_keys = uncacheable_context_keys or set() 

72 sorted_nodes = self.resolver.resolve(graph, context_keys=set(context.keys())) 

73 

74 # Track artifacts by node ID 

75 artifacts_by_node: dict[str, Any] = {} 

76 uncacheable_nodes = set(uncacheable_context_keys) 

77 

78 # Inject context values into artifacts_by_node before execution 

79 # This makes external dependencies available to any node that declares them in deps 

80 for key, value in context.items(): 

81 # Context values must be cacheable 

82 if not is_cacheable(value): 

83 raise ValueError( 

84 f"Context value for '{key}' is not cacheable, got {type(value)}" 

85 ) 

86 # Store native types as-is (no wrapping) 

87 artifacts_by_node[key] = value 

88 

89 # Execute nodes in topological order 

90 for node_id in sorted_nodes: 

91 node = graph[node_id] 

92 depends_on_uncacheable = any( 

93 dep_id in uncacheable_nodes for dep_id in node.deps 

94 ) 

95 

96 if isinstance(node, SubGraphNode): 

97 # SubGraphNode: run internal graph with resolved params as context 

98 manifest = self._build_manifest(node, node_id, graph, artifacts_by_node) 

99 inner_uncacheable_context_keys = ( 

100 set(manifest.keys()) if depends_on_uncacheable else set() 

101 ) 

102 inner_results, inner_uncacheable_nodes = self._execute_graph( 

103 node.graph, 

104 context=manifest, 

105 uncacheable_context_keys=inner_uncacheable_context_keys, 

106 ) 

107 if node.output not in inner_results: 

108 raise ValueError( 

109 f"SubGraphNode '{node_id}' output '{node.output}' not in " 

110 f"internal results. Keys: {list(inner_results.keys())}." 

111 ) 

112 artifacts_by_node[node_id] = inner_results[node.output] 

113 if depends_on_uncacheable or node.output in inner_uncacheable_nodes: 

114 uncacheable_nodes.add(node_id) 

115 else: 

116 # Node: Phase 1 build manifest, Phase 2 cache lookup or execute op 

117 manifest = self._build_manifest(node, node_id, graph, artifacts_by_node) 

118 should_cache = node.cache and not depends_on_uncacheable 

119 if not should_cache: 

120 # Ephemeral node: always execute, never cache. This also applies 

121 # to descendants of explicit cache=False nodes. 

122 op = self.registry.get(node.op_name) 

123 artifact = self._invoke_op(op, node.op_name, manifest) 

124 uncacheable_nodes.add(node_id) 

125 else: 

126 digest = hash_manifest(manifest) 

127 if self.store.exists(node.op_name, digest): 

128 artifact = self.store.get(node.op_name, digest) 

129 else: 

130 op = self.registry.get(node.op_name) 

131 artifact = self._invoke_op(op, node.op_name, manifest) 

132 self.store.put(node.op_name, digest, artifact) 

133 artifacts_by_node[node_id] = artifact 

134 

135 return artifacts_by_node, uncacheable_nodes 

136 

137 def _build_manifest( 

138 self, 

139 node: Node | SubGraphNode, 

140 node_id: str, 

141 graph: Graph, 

142 artifacts_by_node: dict[str, Any], 

143 ) -> dict[str, Any]: 

144 """Build the input manifest for a node (Phase 1). 

145 

146 The manifest is built entirely from resolved params. Dependencies are NOT 

147 injected into the manifest directly - they are only available for ref()/cel() 

148 resolution within params. 

149 

150 Args: 

151 node: The node to build manifest for. 

152 node_id: The ID of the node. 

153 graph: The full graph (for reference). 

154 artifacts_by_node: Already computed artifacts for upstream nodes. 

155 

156 Returns: 

157 The manifest dictionary mapping parameter names to resolved values. 

158 """ 

159 # Collect dependency artifacts for ref()/cel() resolution 

160 dependencies: dict[str, Any] = {} 

161 for dep_id in node.deps: 

162 if dep_id not in artifacts_by_node: 

163 raise ValueError( 

164 f"Node '{node_id}' depends on '{dep_id}' but artifact not found. " 

165 f"This should not happen if graph is topologically sorted or " 

166 f"if '{dep_id}' is provided in context." 

167 ) 

168 dependencies[dep_id] = artifacts_by_node[dep_id] 

169 

170 # Manifest = resolved params only. No dependency injection. 

171 # ref() and cel() markers in params are resolved using dependencies. 

172 return resolve_params(node.params, dependencies) 

173 

174 def _invoke_op(self, op: Any, op_name: str, manifest: dict[str, Any]) -> Any: 

175 """Invoke an operation with kwargs dispatch and return validation. 

176 

177 Args: 

178 op: The callable operation to invoke. 

179 op_name: The name of the operation (for error messages). 

180 manifest: The manifest dictionary mapping parameter names to values. 

181 

182 Returns: 

183 The operation result (native type or ICacheable domain type). 

184 

185 Raises: 

186 ValueError: If required parameters are missing. 

187 TypeError: If return value is not cacheable. 

188 """ 

189 # Inspect function signature to map manifest keys to function parameters 

190 sig = inspect.signature(op) 

191 kwargs: dict[str, Any] = {} 

192 

193 # Map manifest keys to function parameters by name 

194 for name, param in sig.parameters.items(): 

195 if name in manifest: 

196 value = manifest[name] 

197 kwargs[name] = value 

198 elif param.default is not inspect.Parameter.empty: 

199 # Parameter has a default value, skip it 

200 pass 

201 elif param.kind == inspect.Parameter.VAR_KEYWORD: 

202 # Function accepts **kwargs, will handle below 

203 pass 

204 else: 

205 # Required parameter missing 

206 raise ValueError(f"Op '{op_name}': missing required parameter '{name}'") 

207 

208 # If function has **kwargs, pass remaining manifest keys 

209 has_var_kwargs = any( 

210 p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() 

211 ) 

212 if has_var_kwargs: 

213 for key, val in manifest.items(): 

214 if key not in kwargs: 

215 kwargs[key] = val 

216 

217 # Invoke the operation 

218 result = op(**kwargs) 

219 

220 # Validate return value is cacheable 

221 if not is_cacheable(result): 

222 raise TypeError( 

223 f"Op '{op_name}' returned {type(result).__name__}, " 

224 f"which is not a cacheable type" 

225 ) 

226 

227 # Return as-is (no wrapping needed) 

228 return result