Coverage for agentos/concurrency/batch.py: 55%

128 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AsyncBatchExecutor — Concurrent agent task dispatch with configurable 

3parallelism, timeout, retry, and result aggregation. 

4 

5Designed for running multiple AgentOS tasks in parallel (e.g., batch 

6evaluation, multi-model comparison, bulk processing). 

7""" 

8 

9from __future__ import annotations 

10 

11import asyncio 

12import logging 

13import time 

14from dataclasses import dataclass, field 

15from enum import Enum 

16from typing import Any, Awaitable, Callable, Dict, List, Optional 

17 

# Module-level logger; AsyncBatchExecutor reports task retries through it as warnings.
logger = logging.getLogger(__name__)

19 

20 

class TaskStatus(Enum):
    """Status of a single task in a batch.

    ``TaskResult.success`` is true only for ``SUCCESS``, and
    ``BatchResult.get_failed_ids`` treats ``FAILED`` and ``TIMEOUT`` as the
    failure states.
    """

    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    TIMEOUT = "timeout"
    RETRYING = "retrying"
    CANCELLED = "cancelled"

32 

33 

class BatchStrategy(Enum):
    """Execution strategy for batch tasks."""

    PARALLEL = "parallel"  # Tasks run concurrently, bounded by BatchConfig.max_concurrency
    SEQUENTIAL = "sequential"  # Tasks run one after another, in list order
    # NOTE(review): AsyncBatchExecutor currently runs SMART through the same
    # code path as PARALLEL; no load-based adjustment is implemented yet.
    SMART = "smart"

40 

41 

@dataclass
class TaskSpec:
    """Specification for a single task in a batch.

    Attributes:
        task_id: Identifier copied into the matching ``TaskResult``.
        coro_or_func: Callable invoked as ``coro_or_func(*args, **kwargs)``
            on every attempt; its return value is awaited.
        args: Positional arguments passed to ``coro_or_func``.
        kwargs: Keyword arguments passed to ``coro_or_func``.
        timeout: Per-attempt timeout in seconds. A falsy value (``None``,
            the default, or ``0``) defers to ``BatchConfig.default_timeout``.
        max_retries: Extra attempts allowed after a failure or timeout. A
            falsy value (``None``, the default, or ``0``) defers to
            ``BatchConfig.max_retries``.
        metadata: Free-form caller data. NOTE(review): not read by
            AsyncBatchExecutor in this module.
    """

    task_id: str
    coro_or_func: Callable[..., Awaitable[Any]]
    args: tuple = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)
    # None (rather than a concrete number) so the batch-wide defaults in
    # BatchConfig actually take effect; a hard-coded 60.0 always shadowed them.
    timeout: Optional[float] = None
    max_retries: Optional[int] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

53 

54 

@dataclass
class TaskResult:
    """Outcome record for one executed (or skipped) task.

    Attributes:
        task_id: Identifier of the originating ``TaskSpec``.
        status: Final ``TaskStatus`` of the task.
        result: Awaited return value on success, otherwise ``None``.
        error: Error description for unsuccessful outcomes, otherwise ``None``.
        duration_ms: Elapsed milliseconds recorded for the task.
        retries: Number of retries performed.
        started_at: ``time.perf_counter()`` reading at start.
        finished_at: ``time.perf_counter()`` reading at finish.
    """

    task_id: str
    status: TaskStatus
    result: Any = None
    error: Optional[str] = None
    duration_ms: float = 0.0
    retries: int = 0
    started_at: float = 0.0
    finished_at: float = 0.0

    @property
    def success(self) -> bool:
        """True exactly when ``status`` is ``TaskStatus.SUCCESS``."""
        return self.status is TaskStatus.SUCCESS

71 

72 

@dataclass
class BatchConfig:
    """Configuration for ``AsyncBatchExecutor``.

    Attributes:
        max_concurrency: Size of the executor's semaphore, i.e. how many task
            attempts may run at once. Must be at least 1.
        default_timeout: Per-attempt timeout in seconds for tasks whose own
            ``timeout`` is falsy.
        max_retries: Retry budget for tasks whose own ``max_retries`` is falsy.
        retry_delay: Seconds to sleep between attempts of a failing task.
        strategy: Scheduling strategy (see ``BatchStrategy``).
        fail_fast: Whether the executor should stop after the first
            unsuccessful task.
        collect_errors: NOTE(review): not read by AsyncBatchExecutor in this
            module — confirm whether anything else consumes it.

    Raises:
        ValueError: If ``max_concurrency`` is less than 1.
    """

    max_concurrency: int = 5
    default_timeout: float = 60.0
    max_retries: int = 1
    retry_delay: float = 1.0
    strategy: BatchStrategy = BatchStrategy.PARALLEL
    fail_fast: bool = False
    collect_errors: bool = True

    def __post_init__(self) -> None:
        # asyncio.Semaphore(0) would make every task wait forever, and a
        # negative value fails later with an unrelated-looking error.
        if self.max_concurrency < 1:
            raise ValueError(
                f"max_concurrency must be >= 1, got {self.max_concurrency}"
            )

84 

85 

@dataclass
class BatchResult:
    """Aggregate outcome of one batch run.

    ``AsyncBatchExecutor.execute`` fills in the counters from the per-task
    records stored in ``results``; the timestamps are ``time.perf_counter()``
    readings.
    """

    results: List[TaskResult] = field(default_factory=list)
    total: int = 0
    succeeded: int = 0
    failed: int = 0
    timed_out: int = 0
    total_duration_ms: float = 0.0
    started_at: float = 0.0
    finished_at: float = 0.0

    @property
    def success_rate(self) -> float:
        """Fraction of ``total`` that succeeded; ``0.0`` for an empty batch."""
        return self.succeeded / self.total if self.total else 0.0

    @property
    def all_success(self) -> bool:
        """True when ``succeeded`` equals ``total`` (vacuously so when both are 0)."""
        return self.total == self.succeeded

    def get_failed_ids(self) -> List[str]:
        """Return the IDs of results whose status is ``FAILED`` or ``TIMEOUT``, in order."""
        failure_states = (TaskStatus.FAILED, TaskStatus.TIMEOUT)
        return [res.task_id for res in self.results if res.status in failure_states]

114 

115 

class AsyncBatchExecutor:
    """Concurrently dispatch AgentOS tasks and aggregate their results.

    Concurrency is bounded by ``BatchConfig.max_concurrency``. Each task is
    retried on timeout or exception up to its retry budget, and every input
    task yields exactly one ``TaskResult``, in input order.

    ``execute`` replaces the instance's semaphore and cancellation event on
    every call, so a single instance should not run two ``execute`` calls at
    the same time.
    """

    def __init__(self, config: Optional[BatchConfig] = None):
        self.config = config if config is not None else BatchConfig()
        # Both are (re)created by execute(); None until the first batch starts.
        self._semaphore: Optional[asyncio.Semaphore] = None
        self._cancel_event: Optional[asyncio.Event] = None

    async def execute(self, tasks: List[TaskSpec]) -> BatchResult:
        """Execute *tasks* and return their aggregated results.

        Args:
            tasks: Task specifications (any iterable is accepted).

        Returns:
            A ``BatchResult`` whose ``results`` hold one entry per task, in
            input order. Tasks skipped because of ``fail_fast`` or
            ``cancel_all()`` are reported with status ``CANCELLED``.

        Raises:
            ValueError: If ``config.max_concurrency`` is less than 1, which
                would otherwise block every task forever.
        """
        tasks = list(tasks)
        if not tasks:
            return BatchResult(total=0)
        if self.config.max_concurrency < 1:
            raise ValueError(
                f"max_concurrency must be >= 1, got {self.config.max_concurrency}"
            )

        start = time.perf_counter()
        self._semaphore = asyncio.Semaphore(self.config.max_concurrency)
        self._cancel_event = asyncio.Event()

        results: List[TaskResult] = []
        if self.config.strategy == BatchStrategy.SEQUENTIAL:
            for task in tasks:
                # After a fail_fast trip, the remaining tasks return CANCELLED
                # immediately instead of silently disappearing from results.
                results.append(await self._execute_tracked(task))
        else:
            # PARALLEL and SMART share this path; SMART has no load-based tuning yet.
            results = list(
                await asyncio.gather(*(self._execute_tracked(task) for task in tasks))
            )

        finished = time.perf_counter()
        statuses = [r.status for r in results]
        return BatchResult(
            results=results,
            total=len(tasks),
            succeeded=statuses.count(TaskStatus.SUCCESS),
            failed=statuses.count(TaskStatus.FAILED),
            timed_out=statuses.count(TaskStatus.TIMEOUT),
            total_duration_ms=(finished - start) * 1000,
            # perf_counter() readings: meaningful only relative to each other.
            started_at=start,
            finished_at=finished,
        )

    async def _execute_tracked(self, task: TaskSpec) -> TaskResult:
        """Run one task and request batch cancellation on failure when ``fail_fast`` is set."""
        result = await self._execute_one(task)
        if self.config.fail_fast and not result.success:
            self.cancel_all()
        return result

    async def _execute_one(self, task: TaskSpec) -> TaskResult:
        """Run a single task with timeout and retry handling.

        Each attempt holds one semaphore slot. The slot is released during the
        back-off between attempts, so a retrying task does not block queued
        tasks while it sleeps. ``duration_ms``/``started_at``/``finished_at``
        on the returned result describe the final attempt.

        Returns:
            ``SUCCESS`` with the awaited value. ``TIMEOUT`` or ``FAILED`` once
            the retry budget is spent, or once cancellation stops further
            retries. ``CANCELLED`` if cancellation was requested before the
            first attempt started.

        Raises:
            RuntimeError: If called before ``execute`` has set up the semaphore.
        """
        semaphore = self._semaphore
        if semaphore is None:
            raise RuntimeError("_execute_one() must be driven by execute()")

        # A falsy per-task value (None or 0) falls back to the batch default.
        timeout = task.timeout or self.config.default_timeout
        max_retries = task.max_retries or self.config.max_retries
        retries = 0
        last_failure: Optional[TaskResult] = None

        while True:
            # Checked before queueing for a slot and again after acquiring it:
            # cancellation may arrive while this task waits behind others.
            if self._is_cancelled():
                return last_failure if last_failure is not None else self._cancelled_result(task)
            async with semaphore:
                if self._is_cancelled():
                    return last_failure if last_failure is not None else self._cancelled_result(task)
                started = time.perf_counter()
                try:
                    value = await asyncio.wait_for(
                        task.coro_or_func(*task.args, **task.kwargs), timeout=timeout
                    )
                except asyncio.TimeoutError:
                    last_failure = self._make_result(
                        task.task_id,
                        TaskStatus.TIMEOUT,
                        started,
                        retries,
                        error=f"Timed out after {timeout}s (retries: {retries})",
                    )
                except Exception as exc:
                    # str() of many exceptions is empty; keep a non-empty error.
                    last_failure = self._make_result(
                        task.task_id,
                        TaskStatus.FAILED,
                        started,
                        retries,
                        error=str(exc) or type(exc).__name__,
                    )
                else:
                    return self._make_result(
                        task.task_id, TaskStatus.SUCCESS, started, retries, result=value
                    )

            # Out of budget, or the batch was cancelled: report the real failure.
            if retries >= max_retries or self._is_cancelled():
                return last_failure

            retries += 1
            if last_failure.status is TaskStatus.TIMEOUT:
                logger.warning(
                    "Task '%s' timed out (attempt %s/%s), retrying...",
                    task.task_id, retries, max_retries,
                )
            else:
                logger.warning(
                    "Task '%s' failed (attempt %s/%s): %s",
                    task.task_id, retries, max_retries, last_failure.error,
                )
            await asyncio.sleep(self.config.retry_delay)

    def _is_cancelled(self) -> bool:
        """True once cancellation has been requested for the current batch."""
        return self._cancel_event is not None and self._cancel_event.is_set()

    def _cancelled_result(self, task: TaskSpec) -> TaskResult:
        """Build the result for a task that never started an attempt."""
        return self._make_result(
            task.task_id,
            TaskStatus.CANCELLED,
            time.perf_counter(),
            0,
            error="Cancelled before execution",
        )

    @staticmethod
    def _make_result(
        task_id: str,
        status: TaskStatus,
        started: float,
        retries: int,
        *,
        result: Any = None,
        error: Optional[str] = None,
    ) -> TaskResult:
        """Build a ``TaskResult`` for an attempt that began at *started* (perf_counter)."""
        finished = time.perf_counter()
        return TaskResult(
            task_id=task_id,
            status=status,
            result=result,
            error=error,
            duration_ms=(finished - started) * 1000,
            retries=retries,
            started_at=started,
            finished_at=finished,
        )

    def cancel_all(self) -> None:
        """Request cancellation of the batch currently running in ``execute``.

        Tasks that have not started an attempt are reported as ``CANCELLED``.
        Tasks that already failed at least once stop retrying and report
        their last failure. Attempts already in flight are not interrupted.
        Calling this before ``execute`` has no effect on the next batch,
        because ``execute`` creates a fresh event.
        """
        if self._cancel_event is not None:
            self._cancel_event.set()