Coverage for src / invariant / async_executor.py: 78.72%
141 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-08 09:24 +0000
1"""Async executor: scheduler-driven alternative to the safe sync Executor."""
3import asyncio
4from collections.abc import Iterable
5from typing import Any
7from invariant.cacheable import is_cacheable
8from invariant.executor import Executor
9from invariant.graph import Graph
10from invariant.hashing import hash_manifest
11from invariant.node import Node, SubGraphNode, SwitchNode
12from invariant.scheduler import InlineScheduler, InvocationRequest, InvocationScheduler
class AsyncExecutor(Executor):
    """Async runtime engine for executing DAGs.

    The async executor preserves the synchronous executor's graph and cache
    semantics, while delegating operation placement to an InvocationScheduler.

    Cacheable invocations are de-duplicated across every execution that
    shares this executor: concurrent requests for the same
    ``(op_name, manifest digest)`` await one in-flight task instead of
    invoking the op again.
    """

    def __init__(
        self,
        registry,
        store,
        scheduler: InvocationScheduler | None = None,
    ) -> None:
        """Initialize AsyncExecutor.

        Args:
            registry: Op registry, passed through to ``Executor``.
            store: Artifact store, passed through to ``Executor``.
            scheduler: Places op invocations; an ``InlineScheduler`` is used
                when it is ``None``.
        """
        super().__init__(registry, store)
        # Test against None so a scheduler object that happens to be falsy
        # (e.g. one defining ``__len__``) is not silently replaced.
        self.scheduler = scheduler if scheduler is not None else InlineScheduler()
        # In-flight cacheable invocations keyed by (op_name, manifest digest).
        self._singleflight: dict[tuple[str, str], asyncio.Task[Any]] = {}
        # Number of callers currently awaiting each in-flight task. A task is
        # cancelled only once its last waiter has gone away, so one caller's
        # cancellation never aborts work other callers still need.
        self._singleflight_waiters: dict[asyncio.Task[Any], int] = {}
        self._singleflight_lock = asyncio.Lock()

    async def __aenter__(self) -> "AsyncExecutor":
        """Enter an async context manager; returns this executor."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close scheduler resources on context exit."""
        await self.aclose()

    async def aclose(self) -> None:
        """Close scheduler resources when supported.

        Awaits ``scheduler.aclose()`` if the scheduler defines it; otherwise
        this is a no-op.
        """
        close = getattr(self.scheduler, "aclose", None)
        if close is not None:
            await close()

    async def execute(
        self,
        graph: Graph,
        outputs: Iterable[str],
        context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Execute requested graph outputs and return their artifacts.

        Args:
            graph: Graph to execute.
            outputs: Ids of the artifacts to produce; normalized against
                ``graph`` by ``Executor._normalize_outputs``.
            context: Values for dependency ids that are not graph nodes.

        Returns:
            Mapping from each normalized output id to its artifact.

        Raises:
            ValueError: From graph execution, e.g. a dependency cycle on the
                active path, a missing or non-cacheable dependency, an
                unregistered op, or an invalid SwitchNode target.
        """
        output_ids = self._normalize_outputs(outputs, graph)
        artifacts_by_node, _ = await self._execute_requested_outputs_async(
            graph,
            output_ids,
            context=context,
        )
        return {output: artifacts_by_node[output] for output in output_ids}

    async def _execute_requested_outputs_async(
        self,
        graph: Graph,
        outputs: tuple[str, ...],
        context: dict[str, Any] | None = None,
        uncacheable_context_keys: set[str] | None = None,
    ) -> tuple[dict[str, Any], set[str]]:
        """Demand-execute active paths for requested graph outputs.

        Only nodes reachable from ``outputs`` are run (for a SwitchNode, only
        the selected target is followed). Independent nodes run concurrently
        and each node runs at most once per call.

        Args:
            graph: Graph whose nodes may be executed.
            outputs: Node or context ids to resolve.
            context: Values for ids that are not graph nodes.
            uncacheable_context_keys: Context ids whose values must be
                treated as uncacheable.

        Returns:
            ``(artifacts_by_node, uncacheable_nodes)``: every artifact
            resolved during this call, and the ids whose results must not be
            cached.

        Raises:
            ValueError: On a dependency cycle on the active path, a missing
                or non-cacheable dependency, or an invalid SwitchNode target.
        """
        context = context or {}
        uncacheable_nodes = set(uncacheable_context_keys or set())
        artifacts_by_node: dict[str, Any] = {}
        node_tasks: dict[str, asyncio.Task[Any]] = {}
        # waiter node id -> ids of the node tasks it is currently awaiting.
        waiting_for: dict[str, set[str]] = {}

        def check_wait_cycle(waiter: str, target: str) -> None:
            # Adding the edge waiter -> target deadlocks if target already
            # (transitively) waits on waiter; detect that before waiting.
            stack = [target]
            seen: set[str] = set()
            while stack:
                current = stack.pop()
                if current == waiter:
                    raise ValueError(
                        "Graph contains cycles on active path: "
                        f"{waiter} -> {target} -> {waiter}"
                    )
                if current in seen:
                    continue
                seen.add(current)
                stack.extend(waiting_for.get(current, ()))

        async def await_node_task(
            node_id: str,
            task: asyncio.Task[Any],
            waiter: str | None,
        ) -> Any:
            # Await a node task, recording the wait edge for the duration so
            # cycles spanning several concurrently started tasks are reported
            # instead of deadlocking.
            if task.done() or waiter is None:
                return await task
            check_wait_cycle(waiter, node_id)
            waiting_for.setdefault(waiter, set()).add(node_id)
            try:
                return await task
            finally:
                waiters = waiting_for.get(waiter)
                if waiters is not None:
                    waiters.discard(node_id)
                    if not waiters:
                        waiting_for.pop(waiter, None)

        async def resolve_artifact(
            node_id: str,
            path: tuple[str, ...] = (),
            waiter: str | None = None,
        ) -> Any:
            # Return the artifact for node_id, starting its task if needed.
            # ``path`` is the chain of nodes whose tasks led here; ``waiter``
            # is the node whose task is blocked on this result.
            if node_id in artifacts_by_node:
                return artifacts_by_node[node_id]

            if node_id in path:
                cycle = " -> ".join([*path, node_id])
                raise ValueError(f"Graph contains cycles on active path: {cycle}")

            if node_id in graph:
                task = node_tasks.get(node_id)
                if task is None:
                    task = asyncio.create_task(
                        run_graph_node(node_id, (*path, node_id))
                    )
                    node_tasks[node_id] = task
                # Record the wait for new tasks too: an unrecorded edge lets a
                # cycle through tasks started from different outputs hang.
                return await await_node_task(node_id, task, waiter)

            if node_id in context:
                value = context[node_id]
                if not is_cacheable(value):
                    raise ValueError(
                        f"Context value for '{node_id}' is not cacheable, "
                        f"got {type(value)}"
                    )
                artifacts_by_node[node_id] = value
                return value

            raise ValueError(
                f"Node depends on '{node_id}' but it doesn't exist in graph "
                f"or context. Available: graph={sorted(graph)}, "
                f"context={sorted(context)}"
            )

        async def run_graph_node(node_id: str, path: tuple[str, ...]) -> Any:
            # Resolve the node's dependencies concurrently, then produce its
            # artifact according to its node type.
            node = graph[node_id]
            await asyncio.gather(
                *(
                    resolve_artifact(dep_id, path, waiter=node_id)
                    for dep_id in node.deps
                )
            )

            depends_on_uncacheable = any(
                dep_id in uncacheable_nodes for dep_id in node.deps
            )

            if isinstance(node, SwitchNode):
                # A switch forwards the artifact of its selected target.
                self._validate_switch_targets(node, node_id, graph)
                target_id = self._select_switch_target(
                    node,
                    node_id,
                    artifacts_by_node,
                )
                if target_id not in graph:
                    raise ValueError(
                        f"SwitchNode '{node_id}' targets '{target_id}' "
                        "which doesn't exist in graph"
                    )
                await resolve_artifact(target_id, path, waiter=node_id)
                artifacts_by_node[node_id] = artifacts_by_node[target_id]
                if depends_on_uncacheable or target_id in uncacheable_nodes:
                    uncacheable_nodes.add(node_id)
            elif isinstance(node, SubGraphNode):
                # A subgraph runs its inner graph with the manifest as context;
                # uncacheability of the inputs is carried into the inner run.
                manifest = self._build_manifest(
                    node,
                    node_id,
                    graph,
                    artifacts_by_node,
                )
                inner_uncacheable_context_keys = (
                    set(manifest.keys()) if depends_on_uncacheable else set()
                )
                inner_results, inner_uncacheable_nodes = (
                    await self._execute_requested_outputs_async(
                        node.graph,
                        (node.output,),
                        context=manifest,
                        uncacheable_context_keys=inner_uncacheable_context_keys,
                    )
                )
                artifacts_by_node[node_id] = inner_results[node.output]
                if depends_on_uncacheable or node.output in inner_uncacheable_nodes:
                    uncacheable_nodes.add(node_id)
            else:
                manifest = self._build_manifest(
                    node,
                    node_id,
                    graph,
                    artifacts_by_node,
                )
                artifact = await self._execute_node_async(
                    node,
                    node_id,
                    manifest,
                    depends_on_uncacheable,
                    uncacheable_nodes,
                )
                artifacts_by_node[node_id] = artifact

            return artifacts_by_node[node_id]

        async def cancel_node_tasks() -> None:
            # Cancel and await every unfinished node task. Repeat, because
            # gather children orphaned by the failure (gather does not cancel
            # siblings) may still start new node tasks while we wait.
            while pending := [t for t in node_tasks.values() if not t.done()]:
                for task in pending:
                    task.cancel()
                await asyncio.gather(*pending, return_exceptions=True)

        try:
            await asyncio.gather(*(resolve_artifact(output) for output in outputs))
        except BaseException:
            # BaseException so external cancellation also cleans up and no
            # node task outlives this call.
            await cancel_node_tasks()
            raise

        return artifacts_by_node, uncacheable_nodes

    async def _execute_node_async(
        self,
        node: Node,
        node_id: str,
        manifest: dict[str, Any],
        depends_on_uncacheable: bool,
        uncacheable_nodes: set[str],
    ) -> Any:
        """Execute one Node using scheduler-driven cache semantics.

        Uncacheable work (``node.cache`` false, or an uncacheable input) is
        invoked directly and ``node_id`` is added to ``uncacheable_nodes``.
        Otherwise a stored artifact is returned when present; if not, the op
        is invoked once per ``(op_name, manifest digest)`` across all
        concurrent callers and the result is stored.

        Raises:
            ValueError: If ``node.op_name`` is not registered.
        """
        if not self.registry.has(node.op_name):
            raise ValueError(
                f"Node '{node_id}' references unregistered op '{node.op_name}'"
            )

        binding = self.registry.get_binding(node.op_name)
        should_cache = node.cache and not depends_on_uncacheable
        if not should_cache:
            artifact = await self.scheduler.invoke(
                InvocationRequest(
                    op_name=node.op_name,
                    op=binding.op,
                    manifest=manifest,
                    traits=binding.traits,
                    implementation_ref=binding.implementation_ref,
                )
            )
            uncacheable_nodes.add(node_id)
            return artifact

        digest = hash_manifest(manifest)
        key = (node.op_name, digest)

        async with self._singleflight_lock:
            task = self._singleflight.get(key)
            if task is None:
                if self.store.exists(node.op_name, digest):
                    return self.store.get(node.op_name, digest)
                task = asyncio.create_task(
                    self._invoke_and_store(binding, manifest, key)
                )
                self._singleflight[key] = task
            self._singleflight_waiters[task] = (
                self._singleflight_waiters.get(task, 0) + 1
            )

        try:
            # Shield the shared task: cancelling this caller (e.g. during
            # another execution's failure cleanup) must not cancel the
            # invocation for the other callers awaiting it.
            return await asyncio.shield(task)
        finally:
            self._release_singleflight(key, task)

    def _release_singleflight(
        self,
        key: tuple[str, str],
        task: asyncio.Task[Any],
    ) -> None:
        """Drop one waiter from an in-flight invocation.

        When the last waiter leaves, the task is forgotten (later requests
        re-check the store or start a fresh invocation) and, if it is still
        running, cancelled since nobody remains to consume its result. Runs
        synchronously, so it cannot interleave with the locked section in
        ``_execute_node_async`` (which contains no awaits).
        """
        remaining = self._singleflight_waiters.get(task, 0) - 1
        if remaining > 0:
            self._singleflight_waiters[task] = remaining
            return
        self._singleflight_waiters.pop(task, None)
        if self._singleflight.get(key) is task:
            del self._singleflight[key]
        if not task.done():
            task.cancel()

    async def _invoke_and_store(self, binding, manifest, key: tuple[str, str]) -> Any:
        """Invoke an op through the scheduler and store the resulting artifact.

        Args:
            binding: Registry binding providing ``op``, ``traits`` and
                ``implementation_ref``.
            manifest: Inputs passed to the op.
            key: ``(op_name, manifest digest)``; also the store location.

        Returns:
            The artifact produced by the op.
        """
        op_name, digest = key
        artifact = await self.scheduler.invoke(
            InvocationRequest(
                op_name=op_name,
                op=binding.op,
                manifest=manifest,
                traits=binding.traits,
                implementation_ref=binding.implementation_ref,
                cache_key=key,
            )
        )
        self.store.put(op_name, digest, artifact)
        return artifact
# Public API of this module: only the executor class is exported.
__all__ = ["AsyncExecutor"]