Coverage for agentos/concurrent/parallel.py: 0%
178 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
Async Parallel Execution Primitives — fan-out/fan-in for multi-agent tasks.

Provides structured concurrency patterns for AgentOS:
- TaskThrottler: semaphore-based bound on concurrently running coroutines
- ParallelExecutor: fan-out/fan-in with timeout, cancellation, throttling
- FanOutExecutor: fan-out with "all" / "first" / "merge" result aggregation
- parallel_gather: await multiple coroutines with timeout & partial results
- parallel_map: map a function over items concurrently with bounded parallelism
- create_parallel_agent_gather: factory for an agent-tool wrapper around
  parallel_gather

Key features:
- Fail-fast or tolerant execution (partial results on timeout)
- Semaphore-based throttling (bounded parallelism)
- Global timeout; a TimeoutError raised by an individual task is recorded as
  a TIMEOUT result
- Result aggregation with success/failure tracking
- Cancellation of outstanding tasks on timeout / fail-fast
"""
18from __future__ import annotations
20import asyncio
21import time
22import uuid
23from dataclasses import dataclass, field
24from enum import Enum
25from typing import Any, Callable, Coroutine, Optional, TypeVar, Union
# Generic type variables: T is the input-item type for map()/parallel_map();
# R is not referenced elsewhere in this module (presumably kept for result types).
T, R = TypeVar("T"), TypeVar("R")
31# ── Data Structures ──────────────────────────────────────────────
class TaskStatus(str, Enum):
    """Lifecycle state of a single task inside a parallel run.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values (e.g. ``TaskStatus.FAILED == "failed"``).
    """

    PENDING = "pending"      # default state of a fresh TaskResult
    RUNNING = "running"      # execution has begun
    COMPLETED = "completed"  # finished and produced a result
    FAILED = "failed"        # raised an exception
    TIMEOUT = "timeout"      # exceeded a time limit
    CANCELLED = "cancelled"  # was cancelled before finishing
@dataclass
class TaskResult:
    """Outcome record for one task executed by the parallel helpers."""

    # Identifier assigned by the executor for this task.
    task_id: str
    # Current/final lifecycle state.
    status: TaskStatus = TaskStatus.PENDING
    # Value returned by the task's coroutine (only meaningful when COMPLETED).
    result: Any = None
    # Exception captured for a failed or timed-out task, else None.
    error: Optional[Exception] = None
    # time.monotonic() timestamps of start and end.
    started_at: float = 0.0
    finished_at: float = 0.0
    # finished_at - started_at, in milliseconds.
    duration_ms: float = 0.0
    # Retry counter; not incremented anywhere in this module.
    retries: int = 0
@dataclass
class GatherResult:
    """Summary of one gather call: per-task records plus outcome counters."""

    # Per-task records.
    results: list[TaskResult] = field(default_factory=list)
    # Outcome counters filled in by the executor.
    total: int = 0
    completed: int = 0
    failed: int = 0
    timed_out: int = 0
    cancelled: int = 0
    # Wall-clock duration of the whole gather, in milliseconds.
    total_duration_ms: float = 0.0
    # True when every task completed.
    all_succeeded: bool = False

    @property
    def success_rate(self) -> float:
        """Fraction of tasks that completed; 0.0 when ``total`` is not positive."""
        if self.total > 0:
            return self.completed / self.total
        return 0.0

    def get_results(self) -> list[Any]:
        """Return the result values of the COMPLETED tasks, in record order."""
        return [rec.result for rec in self.results if rec.status == TaskStatus.COMPLETED]

    def get_errors(self) -> list[tuple[str, Exception]]:
        """Return ``(task_id, error)`` for every record that carries an error."""
        return [(rec.task_id, rec.error) for rec in self.results if rec.error]
81# ── Semaphore-based Task Throttler ───────────────────────────────
class TaskThrottler:
    """
    Async context manager that caps how many bodies run at once.

    Backed by an ``asyncio.Semaphore``; it also tracks how many holders are
    currently inside (``active``) and the highest value seen (``peak``).

    Usage:
        throttler = TaskThrottler(max_concurrent=5)
        async with throttler:
            await do_work()
    """

    def __init__(self, max_concurrent: int = 10):
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be >= 1")
        self.max_concurrent = max_concurrent
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._active = self._peak = 0

    @property
    def active(self) -> int:
        """Number of holders currently inside the context."""
        return self._active

    @property
    def peak(self) -> int:
        """Highest simultaneous holder count observed so far."""
        return self._peak

    async def __aenter__(self):
        # Block until a slot frees up, then account for the new holder.
        await self._semaphore.acquire()
        self._active += 1
        self._peak = max(self._peak, self._active)
        return self

    async def __aexit__(self, *args):
        # Never suppresses exceptions (implicitly returns None).
        self._active -= 1
        self._semaphore.release()
121# ── Parallel Executor ────────────────────────────────────────────
class ParallelExecutor:
    """
    Fan-out/fan-in executor for running multiple coroutines concurrently.

    Concurrency is bounded by one ``TaskThrottler`` shared by every call on
    this executor. Two wait modes are supported:

    - tolerant (``fail_fast=False``): wait for every task up to the global
      timeout; failures are recorded in the returned ``GatherResult``.
    - fail-fast (``fail_fast=True``): as soon as one task raises, cancel the
      remaining tasks (reported as CANCELLED) and return what was collected.

    Tasks still unfinished when the global timeout expires are cancelled and
    reported as TIMEOUT, or ``TimeoutError`` is raised when
    ``return_partial=False``.

    Usage:
        executor = ParallelExecutor(max_concurrent=8, timeout=30.0)
        result = await executor.gather([
            agent1.run(task_a),
            agent2.run(task_b),
            agent3.run(task_c),
        ])
        print(f"{result.completed}/{result.total} succeeded")
    """

    def __init__(
        self,
        max_concurrent: int = 8,
        timeout: float = 60.0,
        fail_fast: bool = False,
        task_timeout: Optional[float] = None,
    ):
        """
        Args:
            max_concurrent: Maximum number of coroutines running at once.
            timeout: Default global timeout in seconds for ``gather``
                (``None`` waits indefinitely).
            fail_fast: Stop and cancel the remaining tasks on the first failure.
            task_timeout: Optional per-task limit in seconds, measured from the
                moment a task obtains a throttler slot; ``None`` disables it.
        """
        self.max_concurrent = max_concurrent
        self.timeout = timeout
        self.fail_fast = fail_fast
        self.task_timeout = task_timeout
        self.throttler = TaskThrottler(max_concurrent)

    async def gather(
        self,
        coros: list[Coroutine],
        timeout: Optional[float] = None,
        return_partial: bool = True,
    ) -> GatherResult:
        """
        Execute multiple coroutines in parallel with bounded concurrency.

        Args:
            coros: Coroutines to execute (consumed; each is awaited at most once).
            timeout: Global timeout (overrides executor default).
            return_partial: If True, return partial results on timeout; if
                False, raise ``TimeoutError`` instead.

        Returns:
            GatherResult whose ``results`` follow the order of ``coros``. An
            empty input yields an empty result.

        Raises:
            TimeoutError: the global timeout expired and ``return_partial`` is
                False.
        """
        effective_timeout = timeout if timeout is not None else self.timeout
        total_start = time.monotonic()

        coros = list(coros)
        # One record per input, created up front: results keep input order and
        # a task cancelled before it ever ran is still reported.
        results = [TaskResult(task_id=uuid.uuid4().hex[:10]) for _ in coros]
        if not coros:
            # asyncio.wait() rejects an empty collection.
            return self._aggregate(results, total_start)

        tasks = [
            asyncio.create_task(self._run_one(coro, tr))
            for coro, tr in zip(coros, results)
        ]
        try:
            done, pending = await asyncio.wait(
                tasks,
                timeout=effective_timeout,
                return_when=asyncio.FIRST_EXCEPTION if self.fail_fast else asyncio.ALL_COMPLETED,
            )
            # Only fail-fast tasks re-raise, so this is empty in tolerant mode.
            # Calling .exception() on every done task also marks it retrieved.
            failures = [t for t in done if not t.cancelled() and t.exception() is not None]
            # Leftover tasks come from either a fail-fast stop or the global timeout.
            timed_out = bool(pending) and not failures

            if pending:
                for t in pending:
                    t.cancel()
                # Let the cancelled tasks run their handlers now, so no
                # TaskResult is mutated after this method has returned.
                await asyncio.wait(pending)
                for t in pending:
                    if not t.cancelled():
                        t.exception()  # consume, avoids "never retrieved" logs

            if timed_out and not return_partial:
                raise TimeoutError(f"Gather timed out after {effective_timeout}s")

            interrupted = TaskStatus.TIMEOUT if timed_out else TaskStatus.CANCELLED
            for t, tr, coro in zip(tasks, results, coros):
                if t not in pending:
                    continue
                if tr.status == TaskStatus.PENDING:
                    # Cancelled before its first step: the coroutine never ran.
                    self._discard(coro)
                    tr.started_at = tr.finished_at = time.monotonic()
                    tr.duration_ms = 0.0
                # Keep genuine outcomes (a task may still complete or fail
                # while handling its cancellation).
                if tr.status in (TaskStatus.PENDING, TaskStatus.RUNNING, TaskStatus.CANCELLED):
                    tr.status = interrupted
        except BaseException:
            # Also covers cancellation of gather() itself (CancelledError is a
            # BaseException): never leave child tasks running behind us.
            for t in tasks:
                if not t.done():
                    t.cancel()
            raise

        return self._aggregate(results, total_start)

    async def _run_one(self, coro: Coroutine, tr: TaskResult) -> None:
        """Run one coroutine under the throttler and record its outcome in *tr*.

        Failures are recorded instead of raised, except in fail-fast mode,
        where they are re-raised so ``asyncio.wait(FIRST_EXCEPTION)`` returns.
        """
        tr.status = TaskStatus.RUNNING
        tr.started_at = time.monotonic()
        try:
            async with self.throttler:
                if self.task_timeout is None:
                    tr.result = await coro
                else:
                    tr.result = await asyncio.wait_for(coro, self.task_timeout)
            tr.status = TaskStatus.COMPLETED
        except asyncio.CancelledError:
            tr.status = TaskStatus.CANCELLED
            # Cancelled while still queued on the throttler => never awaited.
            self._discard(coro)
            if self.fail_fast:
                raise
        except asyncio.TimeoutError:
            tr.status = TaskStatus.TIMEOUT
            tr.error = TimeoutError(f"Task {tr.task_id} timed out")
            if self.fail_fast:
                raise
        except Exception as e:
            tr.status = TaskStatus.FAILED
            tr.error = e
            if self.fail_fast:
                raise
        finally:
            tr.finished_at = time.monotonic()
            tr.duration_ms = (tr.finished_at - tr.started_at) * 1000

    @staticmethod
    def _discard(coro: Any) -> None:
        """Close *coro* if possible; a no-op for an already-finished coroutine."""
        close = getattr(coro, "close", None)
        if close is not None:
            close()

    @staticmethod
    def _aggregate(results: list[TaskResult], total_start: float) -> GatherResult:
        """Build the GatherResult summary from the per-task records."""
        def count(status: TaskStatus) -> int:
            return sum(1 for r in results if r.status == status)

        completed = count(TaskStatus.COMPLETED)
        return GatherResult(
            results=results,
            total=len(results),
            completed=completed,
            failed=count(TaskStatus.FAILED),
            timed_out=count(TaskStatus.TIMEOUT),
            cancelled=count(TaskStatus.CANCELLED),
            total_duration_ms=(time.monotonic() - total_start) * 1000,
            all_succeeded=(completed == len(results)),
        )

    async def map(
        self,
        func: Callable[[T], Coroutine],
        items: list[T],
        timeout: Optional[float] = None,
        return_partial: bool = True,
    ) -> GatherResult:
        """
        Map an async function over a list of items with bounded concurrency.

        Args:
            func: Async function taking one item and returning a value.
            items: List of input items.
            timeout: Global timeout (overrides executor default).
            return_partial: Forwarded to ``gather``.

        Returns:
            GatherResult whose ``results`` follow the order of ``items``.
        """
        coros = [func(item) for item in items]
        return await self.gather(coros, timeout=timeout, return_partial=return_partial)
279# ── Convenience Functions ────────────────────────────────────────
async def parallel_gather(
    *coros: Coroutine,
    max_concurrent: int = 8,
    timeout: float = 60.0,
    return_partial: bool = True,
) -> GatherResult:
    """
    Await several coroutines concurrently through a one-off ParallelExecutor.

    A fresh executor (and therefore a fresh concurrency limit) is created for
    every call.

    Usage:
        result = await parallel_gather(
            fetch_url(url1),
            fetch_url(url2),
            fetch_url(url3),
            max_concurrent=5, timeout=30.0,
        )
        for r in result.get_results():
            print(r)
    """
    runner = ParallelExecutor(max_concurrent=max_concurrent, timeout=timeout)
    pending = [*coros]
    return await runner.gather(pending, return_partial=return_partial)
async def parallel_map(
    func: Callable[[T], Coroutine],
    items: list[T],
    max_concurrent: int = 8,
    timeout: float = 60.0,
) -> GatherResult:
    """
    Apply an async function to every item concurrently via a one-off executor.

    Usage:
        result = await parallel_map(process_document, documents, max_concurrent=4)
        print(f"Processed {result.completed}/{result.total} docs")
    """
    runner = ParallelExecutor(timeout=timeout, max_concurrent=max_concurrent)
    outcome = await runner.map(func, items)
    return outcome
321# ── Fan-Out / Fan-In with Aggregation ────────────────────────────
@dataclass
class FanOutConfig:
    """Settings consumed by FanOutExecutor."""

    # Upper bound on simultaneously running workers.
    max_concurrent: int = 8
    # Global timeout in seconds passed to the underlying ParallelExecutor.
    timeout: float = 60.0
    # Aggregation mode: "all" | "first" | "merge".
    aggregation: str = "all"
    # Retry settings; FanOutExecutor does not act on them.
    retry_failed: bool = False
    max_retries: int = 2
class FanOutExecutor:
    """
    Fan-out pattern: dispatch tasks to N workers, collect results.

    Aggregation modes (``FanOutConfig.aggregation``):
    - "all": wait for all workers, return the ``GatherResult``
    - "first": return the first successful result (race) and cancel the rest
    - "merge": run all, combine the successful results with ``merge_fn``
    Any other value is treated as "all".
    """

    def __init__(self, config: Optional[FanOutConfig] = None):
        self.config = config or FanOutConfig()
        self.executor = ParallelExecutor(
            max_concurrent=self.config.max_concurrent,
            timeout=self.config.timeout,
        )

    async def fan_out(
        self,
        worker_coros: list[Coroutine],
        merge_fn: Optional[Callable[[list[Any]], Any]] = None,
    ) -> Union[list[Any], Any, GatherResult]:
        """
        Fan out tasks to workers and collect results.

        Args:
            worker_coros: List of worker coroutines.
            merge_fn: Merge function for "merge" mode (list of results -> merged value).

        Returns:
            Depends on aggregation mode:
            - "all": the GatherResult of the run
            - "first": the first successful result
            - "merge": merged result via merge_fn

        Raises:
            ValueError: "merge" mode without ``merge_fn`` (raised before any
                worker runs).
            Exception: in "first" mode, the first worker error (input order)
                when no worker succeeds.
            RuntimeError: in "first" mode, when no worker succeeds and none
                raised (e.g. all timed out).
        """
        mode = self.config.aggregation

        if mode == "first":
            return await self._race_first(worker_coros)

        if mode == "merge":
            if not merge_fn:
                # Validate before doing any work; close the unused coroutines
                # so they do not trigger "never awaited" warnings.
                for coro in worker_coros:
                    close = getattr(coro, "close", None)
                    if close is not None:
                        close()
                raise ValueError("merge_fn required for 'merge' mode")
            gather_result = await self.executor.gather(worker_coros)
            return merge_fn(gather_result.get_results())

        # "all" (and fallback for unknown modes).
        # NOTE: config.retry_failed / max_retries are not applied. A coroutine
        # object can only be awaited once, so retrying would need the caller
        # to supply coroutine factories instead.
        return await self.executor.gather(worker_coros)

    async def _race_first(self, worker_coros: list[Coroutine]) -> Any:
        """Return the first worker result to succeed, cancelling the others.

        Workers share the executor's throttler and global timeout. Among
        workers finishing in the same loop step, input order wins.
        """
        throttler = self.executor.throttler

        async def _throttled(coro: Coroutine) -> Any:
            async with throttler:
                return await coro

        tasks = [asyncio.ensure_future(_throttled(c)) for c in worker_coros]
        errors: dict[asyncio.Future, Exception] = {}
        loop = asyncio.get_running_loop()
        timeout = self.executor.timeout
        deadline = None if timeout is None else loop.time() + timeout
        remaining = set(tasks)
        try:
            while remaining:
                budget = None
                if deadline is not None:
                    budget = deadline - loop.time()
                    if budget <= 0:
                        break
                done, remaining = await asyncio.wait(
                    remaining, timeout=budget, return_when=asyncio.FIRST_COMPLETED
                )
                if not done:
                    break  # global timeout reached
                for t in tasks:
                    if t not in done or t.cancelled():
                        continue
                    exc = t.exception()
                    if exc is None:
                        return t.result()
                    if isinstance(exc, Exception):
                        errors[t] = exc
        finally:
            for t in tasks:
                if not t.done():
                    t.cancel()
            # Wait for the losers to unwind and consume their exceptions.
            await asyncio.gather(*tasks, return_exceptions=True)

        # No success: raise the first error in input order, if any.
        for t in tasks:
            if t in errors:
                raise errors[t]
        raise RuntimeError("All workers failed with no result")
403# ── Agent Loop Integration ────────────────────────────────────────
def create_parallel_agent_gather(
    max_concurrent: int = 8,
    timeout: float = 60.0,
) -> Callable:
    """
    Build an async ``agent_gather(*coros)`` tool with fixed limits baked in.

    Each call of the returned coroutine function delegates to
    ``parallel_gather`` using the ``max_concurrent`` and ``timeout`` captured
    here.

    Usage:
        agent_tools["parallel_gather"] = create_parallel_agent_gather(max_concurrent=5)
    """

    async def agent_gather(*coros: Coroutine) -> GatherResult:
        # Captured settings are forwarded unchanged on every invocation.
        outcome = await parallel_gather(
            *coros,
            timeout=timeout,
            max_concurrent=max_concurrent,
        )
        return outcome

    tool = agent_gather
    return tool