Coverage for agentos/concurrency/batch.py: 55%
128 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
AsyncBatchExecutor — Concurrent agent task dispatch with configurable
parallelism, timeout, retry, and result aggregation.

Designed for running multiple AgentOS tasks in parallel (e.g., batch
evaluation, multi-model comparison, bulk processing).
"""
9from __future__ import annotations
11import asyncio
12import logging
13import time
14from dataclasses import dataclass, field
15from enum import Enum
16from typing import Any, Awaitable, Callable, Dict, List, Optional
# Module-level logger; AsyncBatchExecutor reports retry warnings through it.
logger = logging.getLogger(name=__name__)
class TaskStatus(Enum):
    """Task status enumeration.

    ``TaskResult.status`` holds one of these members.
    """

    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"  # the task's awaitable completed without raising
    FAILED = "failed"  # the task raised on its final allowed attempt
    TIMEOUT = "timeout"  # the final allowed attempt exceeded its timeout
    RETRYING = "retrying"
    CANCELLED = "cancelled"
class BatchStrategy(Enum):
    """Execution strategy for batch tasks.

    Selected through ``BatchConfig.strategy``.
    """

    PARALLEL = "parallel"  # All tasks run concurrently (limited by max_concurrency)
    SEQUENTIAL = "sequential"  # One after another
    # NOTE(review): AsyncBatchExecutor.execute sends SMART down the same code
    # path as PARALLEL; no load-based adjustment is implemented there.
    SMART = "smart"  # Dynamically adjust based on system load
@dataclass
class TaskSpec:
    """Specification for a single task in a batch.

    Attributes:
        task_id: Identifier copied onto the task's ``TaskResult``.
        coro_or_func: Callable invoked as ``coro_or_func(*args, **kwargs)``;
            the executor awaits its return value under a timeout.
        args: Positional arguments passed to ``coro_or_func``.
        kwargs: Keyword arguments passed to ``coro_or_func``.
        timeout: Per-attempt timeout in seconds. ``None`` (the default) or
            ``0`` means "use ``BatchConfig.default_timeout``". The old
            default of ``60.0`` made ``BatchConfig.default_timeout`` dead
            configuration, because the executor only falls back to it when
            this value is falsy.
        max_retries: Extra attempts after the first one. ``0`` (the default)
            means "use ``BatchConfig.max_retries``".
        metadata: Free-form caller data; not read by AsyncBatchExecutor in
            this module.
    """

    task_id: str
    coro_or_func: Callable[..., Awaitable[Any]]
    args: tuple = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)
    timeout: Optional[float] = None
    max_retries: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class TaskResult:
    """Outcome of running one TaskSpec.

    AsyncBatchExecutor fills the timing fields from ``time.perf_counter()``,
    so ``started_at``/``finished_at`` are monotonic readings rather than
    wall-clock timestamps.
    """

    task_id: str
    status: TaskStatus
    result: Any = None  # value produced by the task, when it succeeded
    error: Optional[str] = None  # human-readable failure reason, if any
    duration_ms: float = 0.0  # elapsed time of the recorded attempt
    retries: int = 0  # how many retries were used before this outcome
    started_at: float = 0.0
    finished_at: float = 0.0

    @property
    def success(self) -> bool:
        """Report whether ``status`` equals ``TaskStatus.SUCCESS``."""
        outcome_ok = self.status == TaskStatus.SUCCESS
        return outcome_ok
@dataclass
class BatchConfig:
    """Configuration for AsyncBatchExecutor.

    Field order matters: instances may be built positionally.
    """

    # Initial value of the semaphore that bounds how many tasks run at once.
    max_concurrency: int = 5
    # Per-attempt timeout in seconds, used when a TaskSpec's timeout is falsy.
    default_timeout: float = 60.0
    # Retry count used when a TaskSpec's max_retries is falsy (e.g. 0).
    max_retries: int = 1
    # Seconds slept after a failed or timed-out attempt before retrying.
    retry_delay: float = 1.0
    # SEQUENTIAL runs tasks one at a time; any other value runs them concurrently.
    strategy: BatchStrategy = BatchStrategy.PARALLEL
    # In SEQUENTIAL mode, stop after the first task that does not succeed.
    fail_fast: bool = False
    # NOTE(review): not read by AsyncBatchExecutor in this module — confirm use.
    collect_errors: bool = True
@dataclass
class BatchResult:
    """Summary of one batch run: per-task results plus aggregate counters."""

    results: List[TaskResult] = field(default_factory=list)
    total: int = 0  # number of tasks submitted, including any never run
    succeeded: int = 0
    failed: int = 0
    timed_out: int = 0
    total_duration_ms: float = 0.0
    started_at: float = 0.0
    finished_at: float = 0.0

    @property
    def success_rate(self) -> float:
        """Fraction of submitted tasks that succeeded; 0.0 for an empty batch."""
        return self.succeeded / self.total if self.total != 0 else 0.0

    @property
    def all_success(self) -> bool:
        """True when every submitted task succeeded (vacuously so when empty)."""
        return self.total == self.succeeded

    def get_failed_ids(self) -> List[str]:
        """IDs of tasks whose final status is FAILED or TIMEOUT, in result order."""
        failed_ids: List[str] = []
        for outcome in self.results:
            if outcome.status in (TaskStatus.FAILED, TaskStatus.TIMEOUT):
                failed_ids.append(outcome.task_id)
        return failed_ids
class AsyncBatchExecutor:
    """Concurrently dispatch multiple AgentOS tasks and aggregate their results.

    Each :meth:`execute` call creates a fresh semaphore, which bounds
    concurrency to ``config.max_concurrency``, and a fresh cancellation event.
    :meth:`cancel_all` sets that event. Any task that has not yet begun an
    attempt then reports ``TaskStatus.CANCELLED``. Attempts already in flight
    are not interrupted; they run until they finish or time out.
    """

    def __init__(self, config: Optional[BatchConfig] = None):
        """Create an executor.

        Args:
            config: Batch settings; ``BatchConfig()`` is used when omitted/None.
        """
        self.config = config or BatchConfig()
        # Both are recreated at the start of every execute() call.
        self._semaphore: Optional[asyncio.Semaphore] = None
        self._cancel_event: Optional[asyncio.Event] = None

    async def execute(self, tasks: List[TaskSpec]) -> BatchResult:
        """Run *tasks* according to ``config.strategy`` and aggregate results.

        SEQUENTIAL runs tasks one at a time. With ``fail_fast`` it stops after
        the first unsuccessful task, so ``results`` may be shorter than
        ``total``.

        PARALLEL and SMART (currently identical) run tasks concurrently,
        bounded by ``max_concurrency``. With ``fail_fast``, the first
        unsuccessful task cancels every task that has not started an attempt
        yet; those report ``TaskStatus.CANCELLED``.

        In both modes, results follow the input order.

        Args:
            tasks: Task specifications to run.

        Returns:
            A ``BatchResult`` with per-task results and counters.
            ``started_at``/``finished_at`` are ``time.perf_counter()`` readings.

        Raises:
            ValueError: If ``config.max_concurrency`` is less than 1. A
                zero-valued semaphore would otherwise block forever.
        """
        if not tasks:
            return BatchResult(total=0)
        if self.config.max_concurrency < 1:
            raise ValueError(
                f"max_concurrency must be >= 1, got {self.config.max_concurrency}"
            )

        start = time.perf_counter()
        self._semaphore = asyncio.Semaphore(self.config.max_concurrency)
        self._cancel_event = asyncio.Event()

        if self.config.strategy == BatchStrategy.SEQUENTIAL:
            results = await self._run_sequential(tasks)
        else:
            # PARALLEL or SMART
            results = await self._run_parallel(tasks)

        # One clock read so total_duration_ms and finished_at agree exactly.
        finished = time.perf_counter()
        return BatchResult(
            results=results,
            total=len(tasks),
            succeeded=sum(1 for r in results if r.status == TaskStatus.SUCCESS),
            failed=sum(1 for r in results if r.status == TaskStatus.FAILED),
            timed_out=sum(1 for r in results if r.status == TaskStatus.TIMEOUT),
            total_duration_ms=(finished - start) * 1000,
            started_at=start,
            finished_at=finished,
        )

    async def _run_sequential(self, tasks: List[TaskSpec]) -> List[TaskResult]:
        """Run tasks one by one, stopping early on failure when fail_fast is set."""
        results: List[TaskResult] = []
        for task in tasks:
            result = await self._execute_one(task)
            results.append(result)
            if self.config.fail_fast and not result.success:
                break
        return results

    async def _run_parallel(self, tasks: List[TaskSpec]) -> List[TaskResult]:
        """Run all tasks concurrently under the semaphore, preserving input order."""
        cancel_event = self._cancel_event

        async def run_and_watch(task: TaskSpec) -> TaskResult:
            result = await self._execute_one(task)
            if (
                self.config.fail_fast
                and not result.success
                and cancel_event is not None
            ):
                # Stop tasks that have not started yet; running ones finish.
                cancel_event.set()
            return result

        return list(await asyncio.gather(*(run_and_watch(t) for t in tasks)))

    async def _execute_one(self, task: TaskSpec) -> TaskResult:
        """Run a single task under the batch semaphore, retrying on failure.

        Timeouts and exceptions are both retried up to ``max_retries`` times.
        Between attempts the task sleeps ``config.retry_delay`` seconds, and
        keeps holding its concurrency slot during that sleep. Cancellation is
        checked before every attempt.

        Returns:
            A ``TaskResult`` whose timing fields describe the final attempt.

        Raises:
            RuntimeError: If called before :meth:`execute` created the semaphore.
        """
        # Falsy per-task values (None/0) fall back to the batch config, so an
        # explicit max_retries=0 on a TaskSpec cannot disable config retries.
        timeout = task.timeout or self.config.default_timeout
        max_retries = task.max_retries or self.config.max_retries
        retries = 0

        # Explicit check rather than `assert`, which is stripped under -O.
        if self._semaphore is None:
            raise RuntimeError("semaphore must be set before calling execute()")

        async with self._semaphore:
            while True:
                started = time.perf_counter()
                if self._is_cancelled():
                    return self._make_result(
                        task,
                        TaskStatus.CANCELLED,
                        started,
                        retries,
                        error="Cancelled before the attempt started",
                    )
                try:
                    # Invoking the callable inside the try means synchronous
                    # errors are retried/reported exactly like async ones.
                    value = await asyncio.wait_for(
                        task.coro_or_func(*task.args, **task.kwargs),
                        timeout=timeout,
                    )
                except asyncio.TimeoutError:
                    if retries < max_retries:
                        retries += 1
                        logger.warning(
                            "Task '%s' timed out (attempt %s/%s), retrying...",
                            task.task_id,
                            retries,
                            max_retries,
                        )
                        await asyncio.sleep(self.config.retry_delay)
                        continue
                    return self._make_result(
                        task,
                        TaskStatus.TIMEOUT,
                        started,
                        retries,
                        error=f"Timed out after {timeout}s (retries: {retries})",
                    )
                except Exception as exc:
                    if retries < max_retries:
                        retries += 1
                        logger.warning(
                            "Task '%s' failed (attempt %s/%s): %s",
                            task.task_id,
                            retries,
                            max_retries,
                            exc,
                        )
                        await asyncio.sleep(self.config.retry_delay)
                        continue
                    return self._make_result(
                        task, TaskStatus.FAILED, started, retries, error=str(exc)
                    )
                return self._make_result(
                    task, TaskStatus.SUCCESS, started, retries, result=value
                )

    @staticmethod
    def _make_result(
        task: TaskSpec,
        status: TaskStatus,
        started: float,
        retries: int,
        *,
        result: Any = None,
        error: Optional[str] = None,
    ) -> TaskResult:
        """Build a TaskResult whose timing spans ``started`` up to now."""
        finished = time.perf_counter()
        return TaskResult(
            task_id=task.task_id,
            status=status,
            result=result,
            error=error,
            duration_ms=(finished - started) * 1000,
            retries=retries,
            started_at=started,
            finished_at=finished,
        )

    def _is_cancelled(self) -> bool:
        """Return True once the current batch's cancellation event is set."""
        return self._cancel_event is not None and self._cancel_event.is_set()

    def cancel_all(self) -> None:
        """Cancel every task in the current batch that has not started an attempt.

        Has no effect before the first execute() call. Each new execute()
        call starts with a fresh, unset event.
        """
        if self._cancel_event is not None:
            self._cancel_event.set()