Coverage for agentos/concurrent/parallel.py: 0%

178 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Async Parallel Execution Primitives — fan-out/fan-in for multi-agent tasks. 

3 

Provides structured concurrency patterns for AgentOS:
- TaskThrottler: semaphore-based bounded concurrency for ``async with`` blocks
- ParallelExecutor: fan-out/fan-in with timeout, cancellation, throttling
- FanOutExecutor: fan-out with "all" / "first" / "merge" result aggregation
- parallel_gather: await multiple coroutines with timeout & partial results
- parallel_map: map a function over items concurrently with bounded parallelism

9 

10Key features: 

11- Structured concurrency (all-or-nothing or partial results) 

12- Semaphore-based throttling (bounded parallelism) 

- Global timeout per batch (a task that raises asyncio.TimeoutError itself is recorded as TIMEOUT)

14- Result aggregation with success/failure tracking 

15- Graceful cancellation propagation 

16""" 

17 

18from __future__ import annotations 

19 

20import asyncio 

21import time 

22import uuid 

23from dataclasses import dataclass, field 

24from enum import Enum 

25from typing import Any, Callable, Coroutine, Optional, TypeVar, Union 

26 

# Generic type variables for the signatures in this module.
# T: element type of the ``items`` accepted by ParallelExecutor.map / parallel_map.
T = TypeVar("T")
# R: not referenced by any code in this module as written — TODO(review): use or drop.
R = TypeVar("R")

29 

30 

31# ── Data Structures ────────────────────────────────────────────── 

32 

class TaskStatus(str, Enum):
    """
    Lifecycle state of a single parallel task.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value (e.g. ``TaskStatus.FAILED == "failed"``).
    """

    # Not picked up yet; this is the TaskResult default.
    PENDING = "pending"
    # Picked up by the executor (possibly still waiting for a throttler slot).
    RUNNING = "running"
    # The coroutine returned normally.
    COMPLETED = "completed"
    # The coroutine raised an exception.
    FAILED = "failed"
    # The task, or the gather it belongs to, ran out of time.
    TIMEOUT = "timeout"
    # The task was cancelled before finishing.
    CANCELLED = "cancelled"

41 

42 

@dataclass
class TaskResult:
    """
    Outcome record for one coroutine run by the parallel executors.

    Timestamps are ``time.monotonic()`` readings taken by ParallelExecutor;
    ``duration_ms`` is their difference in milliseconds.
    """

    # Identifier chosen by the executor (ParallelExecutor uses 10 hex chars of a uuid4).
    task_id: str
    status: TaskStatus = field(default=TaskStatus.PENDING)
    # Return value of the coroutine; only meaningful when status is COMPLETED.
    result: Any = field(default=None)
    # Exception recorded for FAILED tasks (and for tasks that raised a timeout themselves).
    error: Optional[Exception] = field(default=None)
    started_at: float = field(default=0.0)
    finished_at: float = field(default=0.0)
    duration_ms: float = field(default=0.0)
    # Not incremented anywhere in this module as written.
    retries: int = field(default=0)

54 

55 

@dataclass
class GatherResult:
    """
    Summary of one parallel_gather / ParallelExecutor.gather run.

    ``results`` holds one TaskResult per task; the integer fields are tallies
    of those results by status, and ``all_succeeded`` is set by the executor.
    """

    results: list[TaskResult] = field(default_factory=list)
    total: int = 0
    completed: int = 0
    failed: int = 0
    timed_out: int = 0
    cancelled: int = 0
    total_duration_ms: float = 0.0
    all_succeeded: bool = False

    @property
    def success_rate(self) -> float:
        """Fraction of tasks that completed; 0.0 when there were no tasks."""
        if self.total <= 0:
            return 0.0
        return self.completed / self.total

    def get_results(self) -> list[Any]:
        """Return the result values of COMPLETED tasks, in ``results`` order."""
        return [entry.result for entry in self.results if entry.status == TaskStatus.COMPLETED]

    def get_errors(self) -> list[tuple[str, Exception]]:
        """Return ``(task_id, error)`` pairs for every task that recorded an error."""
        return [(entry.task_id, entry.error) for entry in self.results if entry.error]

79 

80 

81# ── Semaphore-based Task Throttler ─────────────────────────────── 

82 

class TaskThrottler:
    """
    Cap how many coroutines may be inside an ``async with`` block at once.

    Wraps an ``asyncio.Semaphore`` and tracks how many holders are currently
    inside (``active``) and the highest count observed so far (``peak``).

    Usage:
        throttler = TaskThrottler(max_concurrent=5)
        async with throttler:
            await do_work()

    Raises:
        ValueError: if ``max_concurrent`` is less than 1.
    """

    def __init__(self, max_concurrent: int = 10):
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be >= 1")
        self.max_concurrent = max_concurrent
        self._semaphore = asyncio.Semaphore(max_concurrent)
        # Bookkeeping only; the semaphore alone enforces the limit.
        self._active = self._peak = 0

    @property
    def active(self) -> int:
        """Number of holders currently inside the throttled section."""
        return self._active

    @property
    def peak(self) -> int:
        """Largest value ``active`` has reached since construction."""
        return self._peak

    async def __aenter__(self) -> "TaskThrottler":
        await self._semaphore.acquire()
        self._active += 1
        self._peak = max(self._peak, self._active)
        return self

    async def __aexit__(self, *exc_info) -> None:
        # Returning None lets any exception from the body propagate.
        self._active -= 1
        self._semaphore.release()

119 

120 

121# ── Parallel Executor ──────────────────────────────────────────── 

122 

class ParallelExecutor:
    """
    Fan-out/fan-in executor for running multiple coroutines concurrently.

    Supports structured concurrency: either wait for all (fail-fast or tolerant),
    or collect partial results on timeout.

    Every gather()/map() call on one executor shares the same TaskThrottler, so
    ``max_concurrent`` bounds the number of coroutines running at once across
    concurrent calls on this executor.

    Usage:
        executor = ParallelExecutor(max_concurrent=8, timeout=30.0)
        result = await executor.gather([
            agent1.run(task_a),
            agent2.run(task_b),
            agent3.run(task_c),
        ])
        print(f"{result.completed}/{result.total} succeeded")
    """

    def __init__(
        self,
        max_concurrent: int = 8,
        timeout: float = 60.0,
        fail_fast: bool = False,
    ):
        self.max_concurrent = max_concurrent
        self.timeout = timeout
        self.fail_fast = fail_fast
        # Shared by every gather() on this executor (see class docstring).
        self.throttler = TaskThrottler(max_concurrent)

    async def gather(
        self,
        coros: list[Coroutine],
        timeout: Optional[float] = None,
        return_partial: bool = True,
    ) -> GatherResult:
        """
        Execute multiple coroutines in parallel with bounded concurrency.

        Args:
            coros: Coroutines (or other awaitables) to execute.
            timeout: Global timeout in seconds (overrides the executor default).
            return_partial: If True, return partial results on timeout; if False,
                cancel the unfinished tasks and raise ``TimeoutError``.

        Returns:
            GatherResult whose ``results`` follow the order of ``coros``.
            Tasks still unfinished at the global timeout are cancelled and
            reported as TIMEOUT. With ``fail_fast``, the first failing task
            cancels the rest; those are reported as CANCELLED. Failures are
            recorded, not raised.

        Raises:
            TimeoutError: on a global timeout when ``return_partial`` is False.

        Note:
            Cancelled tasks are awaited before returning, so every TaskResult is
            final when this returns. A coroutine that suppresses cancellation
            will delay the return.
        """
        effective_timeout = timeout if timeout is not None else self.timeout
        total_start = time.monotonic()
        coro_list = list(coros)

        # One TaskResult per input, created up front: keeps input order and counts
        # every coroutine, even one that was cancelled before it ever ran.
        results = [TaskResult(task_id=uuid.uuid4().hex[:10]) for _ in coro_list]
        if not coro_list:
            # asyncio.wait() raises ValueError on an empty set; nothing to run.
            return self._aggregate(results, total_start)

        tasks: dict[asyncio.Task, TaskResult] = {
            asyncio.create_task(self._run_one(coro, tr)): tr
            for coro, tr in zip(coro_list, results)
        }

        try:
            done, pending = await asyncio.wait(
                list(tasks),
                timeout=effective_timeout,
                return_when=asyncio.FIRST_EXCEPTION if self.fail_fast else asyncio.ALL_COMPLETED,
            )
            if pending:
                # With fail_fast, tasks only end with an exception after they
                # re-raised a failure. In that case wait() returned early because
                # of a failure, not because of the timeout.
                failed_fast = self.fail_fast and any(
                    not t.cancelled() and t.exception() is not None for t in done
                )
                for t in pending:
                    # Label before cancelling; _run_one keeps a pre-set label.
                    tasks[t].status = TaskStatus.CANCELLED if failed_fast else TaskStatus.TIMEOUT
                    t.cancel()
                # Let cancellations finish so no TaskResult changes after we return.
                await asyncio.wait(pending)
                if not failed_fast and not return_partial:
                    raise TimeoutError(f"Gather timed out after {effective_timeout}s")
        except BaseException:
            # Includes CancelledError of this gather itself: never orphan children.
            for t in tasks:
                if not t.done():
                    t.cancel()
            raise
        finally:
            # Mark task exceptions (fail_fast re-raises) as retrieved to avoid
            # "Task exception was never retrieved" log noise; they live in results.
            for t in tasks:
                if t.done() and not t.cancelled():
                    t.exception()

        return self._aggregate(results, total_start)

    async def _run_one(self, coro: Coroutine, tr: TaskResult) -> None:
        """Run one coroutine under the throttler and record its outcome in *tr*.

        Failures are recorded on *tr*. With ``fail_fast`` they are also re-raised
        so that ``asyncio.wait(FIRST_EXCEPTION)`` notices them. Cancellation is
        always re-raised so the task really ends cancelled.
        """
        tr.status = TaskStatus.RUNNING
        # Measured from pickup, so the duration includes any wait for a throttler slot.
        tr.started_at = time.monotonic()
        try:
            async with self.throttler:
                tr.result = await coro
            tr.status = TaskStatus.COMPLETED
        except asyncio.CancelledError:
            # gather() may have already labelled this task TIMEOUT/CANCELLED.
            if tr.status == TaskStatus.RUNNING:
                tr.status = TaskStatus.CANCELLED
            raise
        except asyncio.TimeoutError:
            tr.status = TaskStatus.TIMEOUT
            tr.error = TimeoutError(f"Task {tr.task_id} timed out")
            if self.fail_fast:
                raise
        except Exception as e:
            tr.status = TaskStatus.FAILED
            tr.error = e
            if self.fail_fast:
                raise
        finally:
            # No-op for a finished coroutine. For one cancelled while queued on
            # the throttler, this prevents "coroutine was never awaited".
            if asyncio.iscoroutine(coro):
                coro.close()
            tr.finished_at = time.monotonic()
            tr.duration_ms = (tr.finished_at - tr.started_at) * 1000

    @staticmethod
    def _aggregate(results: list[TaskResult], total_start: float) -> GatherResult:
        """Tally per-task statuses into a GatherResult."""
        total_duration = (time.monotonic() - total_start) * 1000
        completed = sum(1 for r in results if r.status == TaskStatus.COMPLETED)
        failed = sum(1 for r in results if r.status == TaskStatus.FAILED)
        timed_out = sum(1 for r in results if r.status == TaskStatus.TIMEOUT)
        cancelled = sum(1 for r in results if r.status == TaskStatus.CANCELLED)
        return GatherResult(
            results=results,
            total=len(results),
            completed=completed,
            failed=failed,
            timed_out=timed_out,
            cancelled=cancelled,
            total_duration_ms=total_duration,
            all_succeeded=(completed == len(results)),
        )

    async def map(
        self,
        func: Callable[[T], Coroutine],
        items: list[T],
        timeout: Optional[float] = None,
        return_partial: bool = True,
    ) -> GatherResult:
        """
        Map an async function over a list of items with bounded concurrency.

        Args:
            func: Async function taking one item and returning a value.
            items: Input items.
            timeout: Global timeout (overrides the executor default).
            return_partial: Forwarded to gather().

        Returns:
            GatherResult with one entry per item, in item order.
        """
        coros: list[Coroutine] = []
        try:
            for item in items:
                coros.append(func(item))
        except BaseException:
            # func() raised synchronously: close the coroutines already created
            # so they do not trigger "never awaited" warnings, then propagate.
            for c in coros:
                if asyncio.iscoroutine(c):
                    c.close()
            raise
        return await self.gather(coros, timeout=timeout, return_partial=return_partial)

277 

278 

279# ── Convenience Functions ──────────────────────────────────────── 

280 

async def parallel_gather(
    *coros: Coroutine,
    max_concurrent: int = 8,
    timeout: float = 60.0,
    return_partial: bool = True,
) -> GatherResult:
    """
    Await several coroutines concurrently using a one-off ParallelExecutor.

    Args:
        *coros: Coroutines to run.
        max_concurrent: Upper bound on how many run at the same time.
        timeout: Global timeout in seconds for the whole batch.
        return_partial: Forwarded to ParallelExecutor.gather (False means a
            timeout raises instead of returning partial results).

    Usage:
        result = await parallel_gather(
            fetch_url(url1),
            fetch_url(url2),
            fetch_url(url3),
            max_concurrent=5, timeout=30.0,
        )
        for r in result.get_results():
            print(r)
    """
    runner = ParallelExecutor(timeout=timeout, max_concurrent=max_concurrent)
    batch = [*coros]
    return await runner.gather(batch, return_partial=return_partial)

302 

303 

async def parallel_map(
    func: Callable[[T], Coroutine],
    items: list[T],
    max_concurrent: int = 8,
    timeout: float = 60.0,
) -> GatherResult:
    """
    Apply an async function to every item concurrently, with bounded parallelism.

    A fresh ParallelExecutor is built for each call, so the concurrency limit
    applies only to this batch.

    Usage:
        result = await parallel_map(process_document, documents, max_concurrent=4)
        print(f"Processed {result.completed}/{result.total} docs")
    """
    return await ParallelExecutor(
        timeout=timeout,
        max_concurrent=max_concurrent,
    ).map(func, items)

319 

320 

321# ── Fan-Out / Fan-In with Aggregation ──────────────────────────── 

322 

@dataclass
class FanOutConfig:
    """Settings consumed by FanOutExecutor."""

    # Maximum number of workers running at the same time.
    max_concurrent: int = field(default=8)
    # Global timeout in seconds, handed to the underlying ParallelExecutor.
    timeout: float = field(default=60.0)
    # Aggregation mode: "all" | "first" | "merge".
    aggregation: str = field(default="all")
    # NOTE(review): FanOutExecutor does not act on this flag as written.
    retry_failed: bool = field(default=False)
    # NOTE(review): not referenced by any code in this module as written.
    max_retries: int = field(default=2)

331 

332 

class FanOutExecutor:
    """
    Fan-out pattern: dispatch tasks to N workers, collect results.

    Supports aggregation modes (``FanOutConfig.aggregation``):
    - "all": Wait for all, return the GatherResult
    - "first": Race: return the first result to complete successfully and
      cancel the remaining workers
    - "merge": Run all, pass the successful results (input order) to a merge function
    Any other mode value is treated as "all".
    """

    def __init__(self, config: Optional[FanOutConfig] = None):
        self.config = config or FanOutConfig()
        self.executor = ParallelExecutor(
            max_concurrent=self.config.max_concurrent,
            timeout=self.config.timeout,
        )

    async def fan_out(
        self,
        worker_coros: list[Coroutine],
        merge_fn: Optional[Callable[[list[Any]], Any]] = None,
    ) -> Union[list[Any], Any, GatherResult]:
        """
        Fan out tasks to workers and collect results.

        Args:
            worker_coros: List of worker coroutines.
            merge_fn: Merge function for "merge" mode (list of results -> merged value).

        Returns:
            Depends on aggregation mode:
            - "all" (and unknown modes): the GatherResult
            - "first": the first successful result
            - "merge": merge_fn applied to the successful results

        Raises:
            ValueError: "merge" mode without merge_fn (checked before any worker runs).
            Exception: "first" mode when no worker succeeds. The error of the
                earliest failed worker (input order) is re-raised, or
                RuntimeError if none raised.
        """
        mode = self.config.aggregation

        if mode == "first":
            return await self._race_first(worker_coros)

        if mode == "merge":
            if merge_fn is None:
                # Fail before running anything; close the unused coroutines.
                self._close_coros(worker_coros)
                raise ValueError("merge_fn required for 'merge' mode")
            gather_result = await self.executor.gather(worker_coros)
            return merge_fn(gather_result.get_results())

        # "all". NOTE: config.retry_failed / max_retries are not implemented.
        # A coroutine object can be awaited only once, so retrying would need
        # the caller to supply coroutine factories instead.
        return await self.executor.gather(worker_coros)

    async def _race_first(self, worker_coros: list[Coroutine]) -> Any:
        """Return the first worker result that completes without raising.

        Runs workers under the executor's throttler and timeout. Once a winner
        is found or the timeout expires, the remaining workers are cancelled and
        awaited.
        """
        throttler = self.executor.throttler

        async def _guarded(coro: Coroutine) -> Any:
            try:
                async with throttler:
                    return await coro
            finally:
                # Prevent "never awaited" if cancelled while queued on the throttler.
                if asyncio.iscoroutine(coro):
                    coro.close()

        tasks = [asyncio.create_task(_guarded(c)) for c in worker_coros]
        position = {t: i for i, t in enumerate(tasks)}
        errors: dict[int, BaseException] = {}
        pending = set(tasks)
        timeout = self.executor.timeout
        deadline = None if timeout is None else time.monotonic() + timeout

        try:
            while pending:
                remaining = None
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                done, pending = await asyncio.wait(
                    pending, timeout=remaining, return_when=asyncio.FIRST_COMPLETED
                )
                # Several workers may finish in one loop iteration; prefer the
                # earliest in input order for a deterministic winner.
                for t in sorted(done, key=position.__getitem__):
                    if t.cancelled():
                        continue
                    exc = t.exception()
                    if exc is None:
                        return t.result()
                    errors[position[t]] = exc
        finally:
            for t in pending:
                t.cancel()
            if pending:
                await asyncio.wait(pending)
            # Mark stored exceptions as retrieved to avoid asyncio log noise.
            for t in tasks:
                if t.done() and not t.cancelled():
                    t.exception()

        if errors:
            raise errors[min(errors)]
        raise RuntimeError("All workers failed with no result")

    @staticmethod
    def _close_coros(coros: list[Coroutine]) -> None:
        """Close coroutine objects that will never be awaited."""
        for c in coros:
            if asyncio.iscoroutine(c):
                c.close()

401 

402 

403# ── Agent Loop Integration ──────────────────────────────────────── 

404 

def create_parallel_agent_gather(
    max_concurrent: int = 8,
    timeout: float = 60.0,
) -> Callable:
    """
    Build an async gather callable with pre-bound concurrency and timeout limits.

    The returned coroutine function accepts coroutines positionally and
    forwards them to parallel_gather with the limits captured here. This makes
    it suitable for registering as an agent tool.

    Usage:
        agent_tools["parallel_gather"] = create_parallel_agent_gather(max_concurrent=5)
    """
    async def agent_gather(*coros: Coroutine) -> GatherResult:
        # The limits are closed over from the enclosing call.
        outcome = await parallel_gather(
            *coros,
            timeout=timeout,
            max_concurrent=max_concurrent,
        )
        return outcome

    return agent_gather