Coverage for src / invariant / async_executor.py: 78.72%

141 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-08 09:24 +0000

1"""Async executor: scheduler-driven alternative to the safe sync Executor.""" 

2 

3import asyncio 

4from collections.abc import Iterable 

5from typing import Any 

6 

7from invariant.cacheable import is_cacheable 

8from invariant.executor import Executor 

9from invariant.graph import Graph 

10from invariant.hashing import hash_manifest 

11from invariant.node import Node, SubGraphNode, SwitchNode 

12from invariant.scheduler import InlineScheduler, InvocationRequest, InvocationScheduler 

13 

14 

class AsyncExecutor(Executor):
    """Async runtime engine for executing DAGs.

    The async executor preserves the synchronous executor's graph and cache
    semantics, while delegating operation placement to an InvocationScheduler.

    Cacheable invocations are deduplicated across concurrent callers through
    an in-flight ("singleflight") table keyed by ``(op_name, manifest
    digest)``: while one invocation for a key is running, other callers for
    the same key await that same task instead of invoking the op again.
    """

    def __init__(
        self,
        registry,
        store,
        scheduler: InvocationScheduler | None = None,
    ) -> None:
        """Initialize AsyncExecutor.

        Args:
            registry: Op registry, forwarded to ``Executor.__init__``; used to
                look up op bindings by name.
            store: Artifact store, forwarded to ``Executor.__init__``; used to
                read and write cached artifacts by ``(op_name, digest)``.
            scheduler: Scheduler that performs op invocations. Defaults to an
                ``InlineScheduler`` when omitted or ``None``.
        """
        super().__init__(registry, store)
        # Compare against None explicitly: a scheduler object that happens to
        # be falsy (e.g. defines __len__ or __bool__) must not be replaced.
        self.scheduler = scheduler if scheduler is not None else InlineScheduler()
        # In-flight cacheable invocations keyed by (op_name, manifest digest).
        self._singleflight: dict[tuple[str, str], asyncio.Task[Any]] = {}
        self._singleflight_lock = asyncio.Lock()

    async def __aenter__(self) -> "AsyncExecutor":
        """Enter an async context manager; returns the executor itself."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close scheduler resources on context exit."""
        await self.aclose()

    async def aclose(self) -> None:
        """Close scheduler resources when supported.

        Awaits ``self.scheduler.aclose()`` if the scheduler defines it;
        otherwise does nothing.
        """
        close = getattr(self.scheduler, "aclose", None)
        if close is not None:
            await close()

    async def execute(
        self,
        graph: Graph,
        outputs: Iterable[str],
        context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Execute requested graph outputs and return their artifacts.

        Args:
            graph: Graph to execute.
            outputs: Ids of the nodes whose artifacts should be returned;
                normalized via ``_normalize_outputs``.
            context: Optional external values that graph nodes may depend on
                by id. Every context value actually used must satisfy
                ``is_cacheable``.

        Returns:
            Mapping from each normalized output id to its artifact.

        Raises:
            ValueError: On cycles on the active path, dependencies missing
                from both graph and context, non-cacheable context values,
                unregistered ops, or switch targets missing from the graph.
        """
        output_ids = self._normalize_outputs(outputs, graph)
        artifacts_by_node, _ = await self._execute_requested_outputs_async(
            graph,
            output_ids,
            context=context,
        )
        return {output: artifacts_by_node[output] for output in output_ids}

    async def _execute_requested_outputs_async(
        self,
        graph: Graph,
        outputs: tuple[str, ...],
        context: dict[str, Any] | None = None,
        uncacheable_context_keys: set[str] | None = None,
    ) -> tuple[dict[str, Any], set[str]]:
        """Demand-execute active paths for requested graph outputs.

        Only nodes reachable from ``outputs`` are executed; for a SwitchNode
        only the selected target is followed. Each graph node runs in at most
        one task per call, shared by all of its dependents.

        Args:
            graph: Graph to execute.
            outputs: Ids of the nodes to resolve.
            context: External values that nodes may depend on by id.
            uncacheable_context_keys: Context ids whose values must be treated
                as uncacheable; nodes depending on them become uncacheable too.

        Returns:
            ``(artifacts_by_node, uncacheable_nodes)``: every artifact resolved
            during this call, and the ids whose artifacts must not be cached.

        Raises:
            ValueError: If a cycle is found on the active path, a dependency
                is missing from both graph and context, or a context value is
                not cacheable. Errors raised while executing a node propagate;
                in that case all still-running node tasks are cancelled first.
        """
        context = context or {}
        uncacheable_nodes = set(uncacheable_context_keys or set())
        artifacts_by_node: dict[str, Any] = {}
        node_tasks: dict[str, asyncio.Task[Any]] = {}
        # Live wait-for edges: waiter node id -> ids of node tasks it is
        # currently awaiting. Catches cycles that span separately started
        # tasks, which the per-task ``path`` tuple cannot see.
        waiting_for: dict[str, set[str]] = {}

        def check_wait_cycle(waiter: str, target: str) -> None:
            """Raise if ``target`` already (transitively) waits on ``waiter``."""
            stack: list[tuple[str, tuple[str, ...]]] = [(target, (waiter, target))]
            seen: set[str] = set()
            while stack:
                current, chain = stack.pop()
                if current == waiter:
                    cycle = " -> ".join(chain)
                    raise ValueError(
                        f"Graph contains cycles on active path: {cycle}"
                    )
                if current in seen:
                    continue
                seen.add(current)
                # Sorted so the reported cycle does not depend on set order.
                for next_id in sorted(waiting_for.get(current, ())):
                    stack.append((next_id, (*chain, next_id)))

        async def await_node_task(
            node_id: str,
            task: asyncio.Task[Any],
            waiter: str | None,
        ) -> Any:
            """Await ``task`` on behalf of ``waiter``, tracking the wait edge."""
            if task.done() or waiter is None:
                return await task
            check_wait_cycle(waiter, node_id)
            waiting_for.setdefault(waiter, set()).add(node_id)
            try:
                return await task
            finally:
                targets = waiting_for.get(waiter)
                if targets is not None:
                    targets.discard(node_id)
                    if not targets:
                        waiting_for.pop(waiter, None)

        async def resolve_artifact(
            node_id: str,
            path: tuple[str, ...] = (),
            waiter: str | None = None,
        ) -> Any:
            """Return the artifact for ``node_id``, running its node if needed.

            ``path`` is the chain of node ids that led here (for cycle
            reporting); ``waiter`` is the node id awaiting the result, or
            ``None`` for a top-level request.
            """
            if node_id in artifacts_by_node:
                return artifacts_by_node[node_id]

            if node_id in path:
                cycle = " -> ".join([*path, node_id])
                raise ValueError(f"Graph contains cycles on active path: {cycle}")

            if node_id in graph:
                task = node_tasks.get(node_id)
                if task is None:
                    task = asyncio.create_task(
                        run_graph_node(node_id, (*path, node_id))
                    )
                    node_tasks[node_id] = task
                # Fresh tasks go through the same bookkeeping as shared ones:
                # the waiter -> node_id edge must be recorded even for the task
                # this call just created, otherwise a cycle closing through it
                # goes undetected and the tasks await each other forever.
                return await await_node_task(node_id, task, waiter)

            if node_id in context:
                value = context[node_id]
                if not is_cacheable(value):
                    raise ValueError(
                        f"Context value for '{node_id}' is not cacheable, "
                        f"got {type(value)}"
                    )
                artifacts_by_node[node_id] = value
                return value

            raise ValueError(
                f"Node depends on '{node_id}' but it doesn't exist in graph "
                f"or context. Available: graph={sorted(graph)}, "
                f"context={sorted(context)}"
            )

        async def run_graph_node(node_id: str, path: tuple[str, ...]) -> Any:
            """Resolve a node's dependencies concurrently, then produce its artifact."""
            node = graph[node_id]
            await asyncio.gather(
                *(
                    resolve_artifact(dep_id, path, waiter=node_id)
                    for dep_id in node.deps
                )
            )

            depends_on_uncacheable = any(
                dep_id in uncacheable_nodes for dep_id in node.deps
            )

            if isinstance(node, SwitchNode):
                # A switch forwards the artifact of the single selected target.
                self._validate_switch_targets(node, node_id, graph)
                target_id = self._select_switch_target(
                    node,
                    node_id,
                    artifacts_by_node,
                )
                if target_id not in graph:
                    raise ValueError(
                        f"SwitchNode '{node_id}' targets '{target_id}' "
                        "which doesn't exist in graph"
                    )
                await resolve_artifact(target_id, path, waiter=node_id)
                artifacts_by_node[node_id] = artifacts_by_node[target_id]
                if depends_on_uncacheable or target_id in uncacheable_nodes:
                    uncacheable_nodes.add(node_id)
            elif isinstance(node, SubGraphNode):
                # The subgraph runs with this node's manifest as its context;
                # if any dependency is uncacheable, all inner context keys are.
                manifest = self._build_manifest(
                    node,
                    node_id,
                    graph,
                    artifacts_by_node,
                )
                inner_uncacheable_context_keys = (
                    set(manifest.keys()) if depends_on_uncacheable else set()
                )
                inner_results, inner_uncacheable_nodes = (
                    await self._execute_requested_outputs_async(
                        node.graph,
                        (node.output,),
                        context=manifest,
                        uncacheable_context_keys=inner_uncacheable_context_keys,
                    )
                )
                artifacts_by_node[node_id] = inner_results[node.output]
                if depends_on_uncacheable or node.output in inner_uncacheable_nodes:
                    uncacheable_nodes.add(node_id)
            else:
                manifest = self._build_manifest(
                    node,
                    node_id,
                    graph,
                    artifacts_by_node,
                )
                artifact = await self._execute_node_async(
                    node,
                    node_id,
                    manifest,
                    depends_on_uncacheable,
                    uncacheable_nodes,
                )
                artifacts_by_node[node_id] = artifact

            return artifacts_by_node[node_id]

        try:
            await asyncio.gather(*(resolve_artifact(output) for output in outputs))
        except Exception:
            # gather() does not stop sibling work on failure; cancel any node
            # tasks still running and wait for them before re-raising.
            for task in node_tasks.values():
                if not task.done():
                    task.cancel()
            await asyncio.gather(*node_tasks.values(), return_exceptions=True)
            raise

        return artifacts_by_node, uncacheable_nodes

    async def _execute_node_async(
        self,
        node: Node,
        node_id: str,
        manifest: dict[str, Any],
        depends_on_uncacheable: bool,
        uncacheable_nodes: set[str],
    ) -> Any:
        """Execute one Node using scheduler-driven cache semantics.

        When ``node.cache`` is false or any dependency is uncacheable, the op
        is always invoked through the scheduler and ``node_id`` is added to
        ``uncacheable_nodes``. Otherwise the artifact is read from the store
        when present; if not, one in-flight task per ``(op_name, digest)``
        invokes and stores it, and concurrent callers for that key share it.

        Raises:
            ValueError: If ``node.op_name`` is not registered.
        """
        if not self.registry.has(node.op_name):
            raise ValueError(
                f"Node '{node_id}' references unregistered op '{node.op_name}'"
            )

        binding = self.registry.get_binding(node.op_name)
        should_cache = node.cache and not depends_on_uncacheable
        if not should_cache:
            artifact = await self.scheduler.invoke(
                InvocationRequest(
                    op_name=node.op_name,
                    op=binding.op,
                    manifest=manifest,
                    traits=binding.traits,
                    implementation_ref=binding.implementation_ref,
                )
            )
            uncacheable_nodes.add(node_id)
            return artifact

        digest = hash_manifest(manifest)
        key = (node.op_name, digest)
        owner = False

        # Look up, probe the store and register the in-flight task under the
        # lock so concurrent callers for the same key share one invocation.
        async with self._singleflight_lock:
            task = self._singleflight.get(key)
            if task is None:
                if self.store.exists(node.op_name, digest):
                    return self.store.get(node.op_name, digest)
                task = asyncio.create_task(
                    self._invoke_and_store(binding, manifest, key)
                )
                self._singleflight[key] = task
                owner = True

        try:
            return await task
        finally:
            # Only the creator removes the entry, and only if it still maps to
            # this task; after a failure, later callers start a fresh attempt.
            if owner:
                async with self._singleflight_lock:
                    if self._singleflight.get(key) is task:
                        self._singleflight.pop(key, None)

    async def _invoke_and_store(self, binding, manifest, key: tuple[str, str]) -> Any:
        """Invoke an op through the scheduler and store the resulting artifact.

        ``key`` is ``(op_name, digest)``; it is also passed to the scheduler as
        the request's ``cache_key``. Returns the artifact after ``store.put``.
        """
        op_name, digest = key
        artifact = await self.scheduler.invoke(
            InvocationRequest(
                op_name=op_name,
                op=binding.op,
                manifest=manifest,
                traits=binding.traits,
                implementation_ref=binding.implementation_ref,
                cache_key=key,
            )
        )
        self.store.put(op_name, digest, artifact)
        return artifact

293 

294 

# Names exported by ``from invariant.async_executor import *``.
__all__ = [
    "AsyncExecutor",
]