Coverage for agentos/orchestration/parallel.py: 31%
166 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
Native Parallel Agent Scheduler — Multi-agent parallel execution with DAG dependency.

Features:
  - Task DAG: define task dependencies, auto topological sort
  - Concurrency pool: limit max parallel agents with asyncio.Semaphore
  - Load balancing: round-robin or least-busy agent selection
  - Progress tracking: per-task status with callback hooks
  - Resource limits: per-agent memory/token/cpu budgets
  - Error isolation: one agent failure doesn't crash the others
  - Streaming results: async generator for real-time output

Usage:
    executor = ParallelExecutor(max_concurrent=8)

    # Define tasks as a DAG
    dag = {
        "research": {"agent": "researcher", "prompt": "Research topic X"},
        "draft": {"agent": "writer", "prompt": "Draft based on research",
                  "depends_on": ["research"]},
        "review": {"agent": "reviewer", "prompt": "Review the draft",
                   "depends_on": ["draft"]},
        "translate": {"agent": "translator", "prompt": "Translate to Chinese",
                      "depends_on": ["draft"]},
    }

    results = await executor.execute(dag)
"""
from __future__ import annotations

import asyncio
import logging
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Callable, Coroutine, Optional
41# ── Task Models ──
class TaskStatus(str, Enum):
    """Lifecycle state of a scheduled task.

    The ``str`` mixin makes every member compare equal to its plain string
    value (e.g. ``TaskStatus.DONE == "done"``).
    """

    QUEUED = "queued"
    RUNNING = "running"
    DONE = "done"
    FAILED = "failed"
    SKIPPED = "skipped"  # Dependency failed
@dataclass
class TaskResult:
    """Result of a single parallel task."""

    task_id: str            # Key of the task in the submitted task mapping
    status: TaskStatus      # Current/final lifecycle state
    agent: str              # Name of the agent the task was dispatched to
    output: Any = None      # Value returned by the agent on success
    error: str = ""         # Human-readable failure/skip reason
    started_at: float = 0.0   # Start timestamp in seconds; 0.0 if never started
    finished_at: float = 0.0  # End timestamp in seconds; 0.0 if not finished
    retry_count: int = 0    # Zero-based index of the last attempt made

    @property
    def duration_ms(self) -> float:
        """Elapsed time between start and finish, in milliseconds.

        Returns 0.0 when the task never ran or has not finished yet, instead
        of a meaningless negative value derived from the 0.0 placeholder.
        """
        elapsed = self.finished_at - self.started_at
        return max(elapsed, 0.0) * 1000

    @property
    def ok(self) -> bool:
        """True only when the task completed successfully."""
        return self.status == TaskStatus.DONE
@dataclass
class RunResult:
    """Aggregate result of a parallel execution run."""

    run_id: str
    total: int
    done: int
    failed: int
    skipped: int
    total_duration_ms: float
    tasks: list[TaskResult]

    @property
    def success_rate(self) -> float:
        """Fraction of tasks that finished successfully (``done / total``).

        The divisor is floored at 1 so an empty run does not divide by zero.
        """
        divisor = self.total if self.total > 1 else 1
        return self.done / divisor
88# ── Parallel Executor ──
# Type of the pluggable agent runner: an async callable that receives the
# agent name, the prompt, and the merged context dict, and resolves to the
# agent's output.
ParallelAgentFn = Callable[[str, str, dict], Coroutine[Any, Any, Any]]
""" async fn(agent_name: str, prompt: str, context: dict) -> Any """
class ParallelExecutor:
    """Execute multiple agent tasks concurrently with DAG dependency resolution.

    Tasks are grouped into dependency levels. All runnable tasks of a level are
    started together (bounded by the concurrency semaphore) and the next level
    starts once the whole level has finished.

    Args:
        max_concurrent: Maximum number of simultaneously running tasks (default 8).
        agent_fn: Async callable that executes a single agent task.
            Signature: async (agent_name, prompt, context) -> result
            When omitted, a built-in mock agent is used.
        max_retries: Per-task retry count on failure (default 1).
        timeout: Per-task timeout in seconds (default 300).
    """

    def __init__(
        self,
        max_concurrent: int = 8,
        agent_fn: Optional[ParallelAgentFn] = None,
        max_retries: int = 1,
        timeout: float = 300.0,
    ):
        self._max_concurrent = max_concurrent
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._agent_fn = agent_fn
        self._max_retries = max_retries
        self._timeout = timeout

        self._progress_hooks: list[Callable[[TaskResult], Any]] = []
        self._task_counter: dict[str, int] = defaultdict(int)

    def on_progress(self, hook: Callable[[TaskResult], Any]) -> None:
        """Register a progress callback — called once per executed task.

        The hook receives the final ``TaskResult``. It may be a plain function
        or an ``async def``; a returned coroutine is awaited. Exceptions raised
        by a hook are logged and never affect the task result.
        """
        self._progress_hooks.append(hook)

    # ── Execute ──

    async def execute(
        self,
        tasks: dict[str, dict],
        context: Optional[dict] = None,
    ) -> RunResult:
        """Execute a DAG of tasks in parallel.

        Args:
            tasks: {task_id: {agent, prompt, depends_on?, context?}, ...}
            context: Global context injected into every task.

        Returns:
            RunResult with aggregated stats. Every submitted task appears
            exactly once in ``RunResult.tasks``: tasks whose dependency failed,
            is unknown, or is part of a cycle are reported as SKIPPED.
        """
        run_id = uuid.uuid4().hex[:12]
        start_time = time.time()

        results = [r async for r in self._iter_results(tasks, context)]

        tally: dict[TaskStatus, int] = defaultdict(int)
        for r in results:
            tally[r.status] += 1

        return RunResult(
            run_id=run_id,
            total=len(tasks),
            done=tally[TaskStatus.DONE],
            failed=tally[TaskStatus.FAILED],
            skipped=tally[TaskStatus.SKIPPED],
            total_duration_ms=(time.time() - start_time) * 1000,
            tasks=results,
        )

    # ── Streaming ──

    async def execute_stream(
        self,
        tasks: dict[str, dict],
        context: Optional[dict] = None,
    ) -> AsyncIterator[TaskResult]:
        """Execute tasks and yield results as they complete (per-level batches).

        Results of a dependency level are yielded as soon as that level has
        finished, before the next level starts. If the consumer stops iterating
        early, the remaining levels are not executed.
        """
        async for result in self._iter_results(tasks, context):
            yield result

    # ── Batch Dispatch (no DAG) ──

    async def fan_out(
        self,
        agent: str,
        prompts: list[str],
        context: Optional[dict] = None,
    ) -> list[TaskResult]:
        """Run the same agent on many prompts in parallel and collect all results.

        Task ids are ``task_0``, ``task_1``, ... in prompt order; the call
        returns only after every prompt has been processed.
        """
        tasks = {f"task_{i}": {"agent": agent, "prompt": p} for i, p in enumerate(prompts)}
        result = await self.execute(tasks, context)
        return result.tasks

    # ── Internal ──

    async def _iter_results(
        self,
        tasks: dict[str, dict],
        context: Optional[dict],
    ) -> AsyncIterator[TaskResult]:
        """Run the DAG level by level, yielding each level's results in order.

        Within a level, skipped tasks are yielded first, then executed tasks in
        level order. Tasks that can never become ready (dependency cycles) are
        yielded last as SKIPPED.
        """
        dependencies = {
            task_id: self._normalize_deps(spec.get("depends_on"))
            for task_id, spec in tasks.items()
        }

        settled: dict[str, TaskResult] = {}
        outputs: dict[str, Any] = {}

        for level in self._topological_sort(dependencies):
            batch: list[TaskResult] = []
            runnable: list[str] = []
            pending = []

            for task_id in level:
                spec = tasks[task_id]
                deps = dependencies[task_id]

                unknown = [d for d in deps if d not in tasks]
                deps_failed = [d for d in deps if d in settled and not settled[d].ok]
                if unknown or deps_failed:
                    if unknown:
                        reason = f"Unknown dependency: {unknown}"
                    else:
                        reason = f"Dependency failed: {deps_failed}"
                    skipped = TaskResult(
                        task_id=task_id,
                        status=TaskStatus.SKIPPED,
                        agent=spec.get("agent", "unknown"),
                        error=reason,
                    )
                    settled[task_id] = skipped
                    batch.append(skipped)
                    continue

                # Build merged context: global + per-task + dependency outputs
                merged_context: dict = {}
                if context:
                    merged_context.update(context)
                if spec.get("context"):
                    merged_context.update(spec["context"])
                for dep_id in deps:
                    if dep_id in outputs:
                        merged_context[f"_dep_{dep_id}"] = outputs[dep_id]

                # Keep ids and coroutines in lockstep so results can be paired
                # back by position even when part of the level was skipped.
                runnable.append(task_id)
                pending.append(
                    self._run_one(
                        task_id=task_id,
                        agent=spec.get("agent", "default"),
                        prompt=spec.get("prompt", ""),
                        context=merged_context,
                    )
                )

            if pending:
                outcomes = await asyncio.gather(*pending, return_exceptions=True)
                for task_id, outcome in zip(runnable, outcomes):
                    # BaseException also covers CancelledError, which gather
                    # hands back as a value when return_exceptions is set.
                    if isinstance(outcome, BaseException):
                        outcome = TaskResult(
                            task_id=task_id,
                            status=TaskStatus.FAILED,
                            agent=tasks[task_id].get("agent", "unknown"),
                            error=str(outcome) or type(outcome).__name__,
                        )
                    settled[task_id] = outcome
                    if outcome.ok:
                        outputs[task_id] = outcome.output
                    batch.append(outcome)

            for result in batch:
                yield result

        # Anything never scheduled sits on, or downstream of, a dependency cycle.
        for task_id, spec in tasks.items():
            if task_id not in settled:
                stuck = TaskResult(
                    task_id=task_id,
                    status=TaskStatus.SKIPPED,
                    agent=spec.get("agent", "unknown"),
                    error="Dependency cycle: task can never become ready",
                )
                settled[task_id] = stuck
                yield stuck

    async def _run_one(
        self,
        task_id: str,
        agent: str,
        prompt: str,
        context: dict,
    ) -> TaskResult:
        """Execute a single task with semaphore, retry, and timeout.

        The semaphore is held only while the agent is running, not during the
        backoff sleep between attempts. Never raises for agent errors; failures
        are reported through the returned ``TaskResult``.
        """
        result = TaskResult(task_id=task_id, status=TaskStatus.RUNNING, agent=agent)
        agent_fn = self._agent_fn if self._agent_fn else self._default_agent
        # Always make at least one attempt, even for a negative retry setting.
        attempts = max(self._max_retries, 0) + 1

        for attempt in range(attempts):
            result.started_at = time.time()
            result.retry_count = attempt

            try:
                async with self._semaphore:
                    output = await asyncio.wait_for(
                        agent_fn(agent, prompt, context),
                        timeout=self._timeout,
                    )
            except asyncio.TimeoutError:
                result.error = f"Timeout after {self._timeout}s"
                result.status = TaskStatus.FAILED
            except Exception as e:
                # Fall back to the exception type when the message is empty.
                result.error = str(e) or type(e).__name__
                result.status = TaskStatus.FAILED
            else:
                result.output = output
                result.error = ""  # drop the message left by an earlier failed attempt
                result.status = TaskStatus.DONE
            finally:
                result.finished_at = time.time()

            if result.status == TaskStatus.DONE or attempt == attempts - 1:
                break
            # Linear backoff before the next attempt (timeouts included).
            await asyncio.sleep(0.5 * (attempt + 1))

        self._task_counter[agent] += 1
        for hook in self._progress_hooks:
            try:
                pending_hook = hook(result)
                if asyncio.iscoroutine(pending_hook):
                    await pending_hook
            except Exception:
                # Hooks are best-effort: report the problem, keep the result.
                logging.getLogger(__name__).warning(
                    "Progress hook %r failed for task %s", hook, task_id, exc_info=True
                )

        return result

    async def _default_agent(self, agent: str, prompt: str, context: dict) -> str:
        """Default agent execution (mock for testing; override with agent_fn)."""
        await asyncio.sleep(0.1)
        return f"[{agent}] processed: {prompt[:80]}"

    @staticmethod
    def _normalize_deps(raw: Any) -> list[str]:
        """Coerce a ``depends_on`` value (None, a single id, or an iterable) to a list."""
        if not raw:
            return []
        if isinstance(raw, str):
            return [raw]
        return list(raw)

    # ── Topological Sort ──

    @staticmethod
    def _topological_sort(
        dependencies: dict[str, list[str]]
    ) -> list[list[str]]:
        """Kahn's algorithm → ordered levels for parallel execution.

        Only keys of ``dependencies`` appear in the result. Dependencies on
        ids that are not keys are ignored for ordering, and nodes that sit on
        (or depend on) a cycle are omitted because they never become ready.
        """
        in_degree: dict[str, int] = {node: 0 for node in dependencies}
        children: dict[str, list[str]] = defaultdict(list)

        for node, deps in dependencies.items():
            # dict.fromkeys drops duplicate ids while keeping their order.
            for dep in dict.fromkeys(deps):
                if dep not in in_degree:
                    continue
                children[dep].append(node)
                in_degree[node] += 1

        # Start with nodes that have no (known) dependencies
        frontier = [node for node, degree in in_degree.items() if degree == 0]
        levels: list[list[str]] = []

        while frontier:
            levels.append(frontier)
            upcoming: list[str] = []
            for node in frontier:
                for child in children.get(node, ()):
                    in_degree[child] -= 1
                    if in_degree[child] == 0:
                        upcoming.append(child)
            frontier = upcoming

        return levels

    # ── Stats ──

    def stats(self) -> dict[str, Any]:
        """Return the executor configuration and per-agent executed-task counts."""
        return {
            "max_concurrent": self._max_concurrent,
            "max_retries": self._max_retries,
            "timeout": self._timeout,
            "task_counts": dict(self._task_counter),
        }