Coverage for agentos/orchestration/parallel.py: 31%

166 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Native Parallel Agent Scheduler — Multi-agent parallel execution with DAG dependency. 

3 

4Features: 

5 - Task DAG: define task dependencies, auto topological sort 

6 - Concurrency pool: limit max parallel agents with asyncio.Semaphore 

7 - Load balancing: round-robin or least-busy agent selection 

8 - Progress tracking: per-task status with callback hooks 

9 - Resource limits: per-agent memory/token/cpu budgets 

10 - Error isolation: one agent failure doesn't crash the others 

11 - Streaming results: async generator for real-time output 

12 

13Usage: 

14 executor = ParallelExecutor(max_concurrent=8) 

15 

16 # Define tasks as a DAG 

17 dag = { 

18 "research": {"agent": "researcher", "prompt": "Research topic X"}, 

19 "draft": {"agent": "writer", "prompt": "Draft based on research", 

20 "depends_on": ["research"]}, 

21 "review": {"agent": "reviewer", "prompt": "Review the draft", 

22 "depends_on": ["draft"]}, 

23 "translate": {"agent": "translator", "prompt": "Translate to Chinese", 

24 "depends_on": ["draft"]}, 

25 } 

26 

27 results = await executor.execute(dag) 

28""" 

29 

from __future__ import annotations

import asyncio
import inspect
import logging
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Callable, Coroutine, Optional

39 

40 

41# ── Task Models ── 

42 

class TaskStatus(str, Enum):
    """States a task can be reported in.

    The ``str`` mixin makes every member compare equal to its raw string
    value (``TaskStatus.DONE == "done"``).
    """

    QUEUED = "queued"
    RUNNING = "running"    # Initial status of a result while its task executes
    DONE = "done"          # The agent callable returned a value
    FAILED = "failed"      # Timed out or raised on the final attempt
    SKIPPED = "skipped"    # Dependency failed

49 

50 

@dataclass
class TaskResult:
    """Outcome record for one task of a parallel run."""

    task_id: str                # Key of the task in the submitted task mapping
    status: TaskStatus
    agent: str                  # Name of the agent the task was assigned to
    output: Any = None          # Value returned by the agent on success
    error: str = ""             # Failure / skip reason; empty by default
    started_at: float = 0.0     # time.time() timestamp; 0.0 if never started
    finished_at: float = 0.0    # time.time() timestamp; 0.0 if never finished
    retry_count: int = 0        # Zero-based index of the last attempt made

    @property
    def duration_ms(self) -> float:
        """Time between ``started_at`` and ``finished_at`` in milliseconds."""
        elapsed_seconds = self.finished_at - self.started_at
        return elapsed_seconds * 1000

    @property
    def ok(self) -> bool:
        """``True`` only when the task ended with ``TaskStatus.DONE``."""
        return self.status == TaskStatus.DONE

70 

71 

@dataclass
class RunResult:
    """Summary of one complete parallel execution run."""

    run_id: str                 # Short identifier generated per run
    total: int                  # Number of tasks submitted
    done: int                   # Tasks that ended DONE
    failed: int                 # Tasks that ended FAILED
    skipped: int                # Tasks that ended SKIPPED
    total_duration_ms: float    # Wall-clock time of the whole run
    tasks: list[TaskResult]     # Individual task results

    @property
    def success_rate(self) -> float:
        """Fraction of tasks that finished DONE (0.0 when ``total`` is 0)."""
        # Clamp the denominator so an empty run does not divide by zero.
        denominator = max(self.total, 1)
        return self.done / denominator

86 

87 

88# ── Parallel Executor ── 

89 

# Type of the pluggable callable that actually runs one agent task.
ParallelAgentFn = Callable[[str, str, dict], Coroutine[Any, Any, Any]]
"""Async callable: ``fn(agent_name: str, prompt: str, context: dict) -> Any``."""

92 

93 

class ParallelExecutor:
    """Execute multiple agent tasks concurrently with DAG dependency resolution.

    Tasks are grouped into dependency levels (Kahn's algorithm). Levels run
    one after another; the tasks inside one level are awaited together with
    ``asyncio.gather`` and the number of simultaneously running agent calls
    is bounded by a shared ``asyncio.Semaphore``.

    Args:
        max_concurrent: Maximum number of simultaneously running tasks
            (default 8). Must be at least 1.
        agent_fn: Async callable that executes a single agent task.
            Signature: async (agent_name, prompt, context) -> result.
            When omitted, a built-in mock agent is used.
        max_retries: Per-task retry count on failure (default 1). Must not
            be negative.
        timeout: Per-task timeout in seconds (default 300).

    Raises:
        ValueError: If ``max_concurrent < 1`` (a zero-sized semaphore would
            block every task forever) or ``max_retries < 0`` (no attempt
            would ever be made).
    """

    _log = logging.getLogger(__name__)

    def __init__(
        self,
        max_concurrent: int = 8,
        agent_fn: Optional[ParallelAgentFn] = None,
        max_retries: int = 1,
        timeout: float = 300.0,
    ):
        if max_concurrent < 1:
            raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")
        if max_retries < 0:
            raise ValueError(f"max_retries must be >= 0, got {max_retries}")

        self._max_concurrent = max_concurrent
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._agent_fn = agent_fn
        self._max_retries = max_retries
        self._timeout = timeout

        self._progress_hooks: list[Callable[[TaskResult], Any]] = []
        # Number of executed tasks per agent name (see ``stats``).
        self._task_counter: dict[str, int] = defaultdict(int)

    def on_progress(self, hook: Callable[[TaskResult], Any]) -> None:
        """Register a progress callback.

        The hook receives the ``TaskResult`` of every task that was actually
        executed, once it reached its final state (DONE or FAILED). Tasks
        skipped because of a failed dependency do not trigger hooks. A hook
        may be a plain function or return an awaitable, which is awaited.
        Exceptions raised by a hook are logged and never fail the task.
        """
        self._progress_hooks.append(hook)

    # ── Execute ──

    async def execute(
        self,
        tasks: dict[str, dict],
        context: Optional[dict] = None,
    ) -> RunResult:
        """Execute a DAG of tasks in parallel.

        Args:
            tasks: {task_id: {agent, prompt, depends_on?, context?}, ...}
            context: Global context injected into every task.

        Returns:
            RunResult with aggregated stats. Within each level, skipped
            tasks are listed before executed ones.

        Raises:
            ValueError: If a task depends on an id that is not in ``tasks``
                or if the dependencies contain a cycle.
        """
        run_id = uuid.uuid4().hex[:12]
        start_time = time.time()

        collected: list[TaskResult] = []
        async for batch in self._iter_levels(tasks, context):
            collected.extend(batch)

        return RunResult(
            run_id=run_id,
            total=len(tasks),
            done=sum(1 for r in collected if r.status == TaskStatus.DONE),
            failed=sum(1 for r in collected if r.status == TaskStatus.FAILED),
            skipped=sum(1 for r in collected if r.status == TaskStatus.SKIPPED),
            total_duration_ms=(time.time() - start_time) * 1000,
            tasks=collected,
        )

    # ── Streaming ──

    async def execute_stream(
        self,
        tasks: dict[str, dict],
        context: Optional[dict] = None,
    ) -> AsyncIterator[TaskResult]:
        """Execute tasks and yield results as they complete (per-level batches).

        Results of a dependency level are yielded as soon as that level has
        finished, before the next level starts. If the consumer stops
        iterating, later levels are not started.

        Raises:
            ValueError: On the first iteration, for an unknown dependency or
                a dependency cycle (same conditions as ``execute``).
        """
        async for batch in self._iter_levels(tasks, context):
            for task_result in batch:
                yield task_result

    # ── Batch Dispatch (no DAG) ──

    async def fan_out(
        self,
        agent: str,
        prompts: list[str],
        context: Optional[dict] = None,
    ) -> list[TaskResult]:
        """Run the same agent on many prompts in parallel and await them all.

        Each prompt becomes an independent task named ``task_<index>``; the
        returned list follows the order of ``prompts``.
        """
        tasks = {f"task_{i}": {"agent": agent, "prompt": p} for i, p in enumerate(prompts)}
        result = await self.execute(tasks, context)
        return result.tasks

    # ── Internal ──

    @staticmethod
    def _build_dependencies(tasks: dict[str, dict]) -> dict[str, list[str]]:
        """Extract and validate the ``depends_on`` list of every task.

        ``depends_on`` may be missing, ``None``, a single task id string, or
        an iterable of task ids.

        Raises:
            ValueError: If a task references an id that is not in ``tasks``.
        """
        dependencies: dict[str, list[str]] = {}
        for task_id, spec in tasks.items():
            raw = spec.get("depends_on") or []
            # A bare string would otherwise be iterated character by character.
            deps = [raw] if isinstance(raw, str) else list(raw)
            unknown = [dep for dep in deps if dep not in tasks]
            if unknown:
                raise ValueError(
                    f"Task {task_id!r} depends on unknown task(s): {unknown}"
                )
            dependencies[task_id] = deps
        return dependencies

    async def _iter_levels(
        self,
        tasks: dict[str, dict],
        context: Optional[dict],
    ) -> AsyncIterator[list[TaskResult]]:
        """Run the DAG level by level, yielding each level's results.

        Every yielded batch lists the level's skipped tasks first, followed
        by the executed tasks in level order.
        """
        dependencies = self._build_dependencies(tasks)
        levels = self._topological_sort(dependencies)

        finished: dict[str, TaskResult] = {}
        outputs: dict[str, Any] = {}

        for level in levels:
            batch: list[TaskResult] = []
            # Ids and coroutines are kept in lockstep so gathered results are
            # matched to the right task even when part of the level is skipped.
            runnable_ids: list[str] = []
            pending = []

            for task_id in level:
                spec = tasks[task_id]
                deps = dependencies.get(task_id, [])
                deps_failed = [
                    d for d in deps if d in finished and not finished[d].ok
                ]

                if deps_failed:
                    skipped = TaskResult(
                        task_id=task_id,
                        status=TaskStatus.SKIPPED,
                        agent=spec.get("agent", "unknown"),
                        error=f"Dependency failed: {deps_failed}",
                    )
                    finished[task_id] = skipped
                    batch.append(skipped)
                    continue

                # Merged context: global, then per-task, then dependency outputs.
                merged_context: dict = {}
                if context:
                    merged_context.update(context)
                if spec.get("context"):
                    merged_context.update(spec["context"])
                for dep_id in deps:
                    if dep_id in outputs:
                        merged_context[f"_dep_{dep_id}"] = outputs[dep_id]

                runnable_ids.append(task_id)
                pending.append(
                    self._run_one(
                        task_id=task_id,
                        agent=spec.get("agent", "default"),
                        prompt=spec.get("prompt", ""),
                        context=merged_context,
                    )
                )

            if pending:
                gathered = await asyncio.gather(*pending, return_exceptions=True)
                for task_id, outcome in zip(runnable_ids, gathered):
                    # BaseException also covers CancelledError, which is not
                    # an ``Exception`` subclass on Python 3.8+.
                    if isinstance(outcome, BaseException):
                        outcome = TaskResult(
                            task_id=task_id,
                            status=TaskStatus.FAILED,
                            agent=tasks[task_id].get("agent", "unknown"),
                            error=str(outcome) or type(outcome).__name__,
                        )
                    finished[task_id] = outcome
                    if outcome.ok:
                        outputs[task_id] = outcome.output
                    batch.append(outcome)

            yield batch

    async def _run_one(
        self,
        task_id: str,
        agent: str,
        prompt: str,
        context: dict,
    ) -> TaskResult:
        """Execute a single task with semaphore, retry, and timeout.

        Up to ``max_retries + 1`` attempts are made. Any failed attempt
        (timeout or exception) is followed by a linear backoff of
        ``0.5 * attempt_number`` seconds before the next attempt.
        ``started_at`` / ``finished_at`` describe the last attempt only.
        """
        result = TaskResult(task_id=task_id, status=TaskStatus.RUNNING, agent=agent)
        run_agent = self._agent_fn or self._default_agent

        for attempt in range(self._max_retries + 1):
            result.started_at = time.time()
            result.retry_count = attempt

            try:
                async with self._semaphore:
                    output = await asyncio.wait_for(
                        run_agent(agent, prompt, context),
                        timeout=self._timeout,
                    )
            except asyncio.TimeoutError:
                result.error = f"Timeout after {self._timeout}s"
                result.status = TaskStatus.FAILED
            except Exception as exc:
                result.error = str(exc) or type(exc).__name__
                result.status = TaskStatus.FAILED
            else:
                result.output = output
                result.error = ""  # drop the message of an earlier failed attempt
                result.status = TaskStatus.DONE
            finally:
                result.finished_at = time.time()

            if result.status == TaskStatus.DONE or attempt >= self._max_retries:
                break
            # Back off outside the semaphore so waiting does not hold a slot.
            await asyncio.sleep(0.5 * (attempt + 1))

        self._task_counter[agent] += 1
        await self._notify(result)
        return result

    async def _notify(self, result: TaskResult) -> None:
        """Call every progress hook; hook failures are logged, not raised."""
        for hook in self._progress_hooks:
            try:
                outcome = hook(result)
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception:
                self._log.exception(
                    "Progress hook %r failed for task %s", hook, result.task_id
                )

    async def _default_agent(self, agent: str, prompt: str, context: dict) -> str:
        """Default agent execution (mock for testing; override with agent_fn)."""
        await asyncio.sleep(0.1)
        return f"[{agent}] processed: {prompt[:80]}"

    # ── Topological Sort ──

    @staticmethod
    def _topological_sort(
        dependencies: dict[str, list[str]]
    ) -> list[list[str]]:
        """Kahn's algorithm → ordered levels for parallel execution.

        Args:
            dependencies: {node: [nodes it depends on]}. A dependency that
                is not itself a key is treated as a node without
                dependencies.

        Returns:
            Levels in execution order; every node appears after all of its
            dependencies, and nodes of one level are mutually independent.

        Raises:
            ValueError: If the graph contains a cycle.
        """
        in_degree: dict[str, int] = {node: 0 for node in dependencies}
        children: dict[str, list[str]] = defaultdict(list)

        for node, deps in dependencies.items():
            for dep in deps:
                in_degree.setdefault(dep, 0)
                children[dep].append(node)
                in_degree[node] += 1

        # Start with nodes that have no dependencies
        frontier = [node for node, degree in in_degree.items() if degree == 0]
        levels: list[list[str]] = []
        placed = 0

        while frontier:
            levels.append(frontier)
            placed += len(frontier)
            next_frontier: list[str] = []

            for node in frontier:
                for child in children.get(node, ()):
                    in_degree[child] -= 1
                    if in_degree[child] == 0:
                        next_frontier.append(child)

            frontier = next_frontier

        if placed != len(in_degree):
            # Nodes on a cycle (and everything downstream) never reach degree 0.
            unresolved = [node for node, degree in in_degree.items() if degree > 0]
            raise ValueError(
                f"Dependency cycle detected; unresolved tasks: {unresolved}"
            )

        return levels

    # ── Stats ──

    def stats(self) -> dict[str, Any]:
        """Return the executor configuration and per-agent executed-task counts."""
        return {
            "max_concurrent": self._max_concurrent,
            "max_retries": self._max_retries,
            "timeout": self._timeout,
            "task_counts": dict(self._task_counter),
        }

359 "timeout": self._timeout, 

360 "task_counts": dict(self._task_counter), 

361 }