Coverage for agentos/tools/orchestrator.py: 46%

301 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 18:40 +0800

1""" 

2AgentOS v1.1.7 — 工具链编排引擎(Checkpoint/恢复)。 

3基因来源: Airflow DAG + LangChain Tool Composition 

4 

5支持: 

6- 顺序链 (chain): 工具A → 工具B → 工具C 

7- 并行分支 (parallel): A + B 同时 → C 

8- 条件执行 (conditional): if X then A else B 

9- 重试策略 (retry): 指数退避 / 固定间隔 

10- 超时控制 (timeout): 单工具 / 全链 

11- Checkpoint/恢复: 长时间DAG断点保存与续跑 

12""" 

13 

14from __future__ import annotations 

15 

import asyncio
import inspect
import json
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine

22 

23 

24# ── Core Types ────────────────────────────────── 

25 

class NodeState(str, Enum):
    """State of a single DAG node.

    Members also subclass ``str``, so each compares equal to its plain string
    value (``NodeState.SUCCESS == "success"``) and can be rebuilt from it
    (``NodeState("success")``), which is how checkpoints round-trip states.
    """

    # Member order is part of the public surface (iteration order); keep it.
    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    SKIPPED = "skipped"
    TIMEOUT = "timeout"

36 

37 

@dataclass
class NodeResult:
    """Outcome of running (or giving up on) one DAG node."""

    # ID of the node this result belongs to.
    node_id: str
    # State the node ended in.
    state: NodeState
    # Value returned by the tool; only meaningful for a successful node.
    output: Any = field(default=None)
    # Human-readable failure reason, if any.
    error: str | None = field(default=None)
    # Elapsed time of the attempt, in milliseconds.
    duration_ms: float = field(default=0.0)
    # How many retries were performed before this outcome.
    retries: int = field(default=0)

47 

48 

@dataclass
class DAGResult:
    """Aggregate outcome of one DAG execution."""

    # Per-node results, keyed by node ID.
    nodes: dict[str, NodeResult]
    # Output chosen as the overall result of the run, if any.
    final_output: Any = field(default=None)
    # Duration of the whole run, in milliseconds.
    total_duration_ms: float = field(default=0.0)
    # Whether the run is considered successful.
    success: bool = field(default=False)
    # Run-level failure description (e.g. a global timeout), if any.
    error: str | None = field(default=None)

57 

58 

@dataclass
class RetryPolicy:
    """How often and how long to wait before retrying a failed node."""

    # Number of retries after the first attempt.
    max_retries: int = field(default=3)
    # Base delay between attempts, in seconds.
    base_delay: float = field(default=1.0)
    # Upper bound on the delay between attempts, in seconds.
    max_delay: float = field(default=30.0)
    # Delay growth: "exponential" | "fixed" | "linear".
    backoff: str = field(default="exponential")
    # Exception types that qualify for a retry.
    retry_on: tuple = field(default=(Exception,))

67 

68 

69# ── DAG Node Types ───────────────────────────── 

70 

@dataclass
class ToolNode:
    """A DAG node that invokes one registered tool."""

    # Registry name of the tool to invoke.
    tool_name: str
    # Static keyword arguments passed to the tool.
    tool_args: dict[str, Any] = field(default_factory=dict)
    # IDs of upstream nodes that must succeed before this one runs.
    depends_on: list[str] = field(default_factory=list)
    # Per-attempt timeout, in seconds.
    timeout: float = field(default=60.0)
    # Retry policy; None means a single attempt.
    retry: RetryPolicy | None = field(default=None)
    # Optional mapping from upstream outputs ({node_id: output}) to extra
    # tool arguments merged over ``tool_args``.
    input_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = field(default=None)

82 

83 

@dataclass
class ConditionNode:
    """A branch point: a callable that picks the node to run next."""

    # Receives {dep_id: upstream result} and returns the chosen target node ID.
    condition: Callable[[dict[str, Any]], str]
    # IDs of nodes that must succeed before the condition is evaluated.
    depends_on: list[str] = field(default_factory=list)

89 

90 

@dataclass
class ParallelGroup:
    """A set of nodes that run concurrently under a concurrency cap."""

    # IDs of the member nodes.
    node_ids: list[str]
    # IDs of nodes the group as a whole depends on.
    depends_on: list[str] = field(default_factory=list)
    # Maximum number of members running at the same time.
    max_concurrency: int = field(default=5)

97 

98 

@dataclass
class DAGSpec:
    """Complete, declarative description of a DAG to orchestrate."""

    # Human-readable DAG name (also stored in checkpoints).
    name: str
    # Tool nodes keyed by node ID.
    nodes: dict[str, ToolNode] = field(default_factory=dict)
    # Groups of nodes that run concurrently.
    parallels: list[ParallelGroup] = field(default_factory=list)
    # Condition nodes keyed by condition ID.
    conditions: dict[str, ConditionNode] = field(default_factory=dict)
    # IDs of the nodes scheduling starts from.
    entry: list[str] = field(default_factory=list)
    # Timeout for the whole run, in seconds.
    global_timeout: float = field(default=300.0)

108 

109 

110# ── Checkpoint Data (v1.1.7) ────────────────── 

111 

@dataclass
class CheckpointData:
    """Serializable snapshot of a DAG run, used to resume it later."""

    # Name of the DAG the snapshot belongs to.
    dag_name: str
    # Serialized node results keyed by node ID.
    completed_nodes: dict[str, dict] = field(default_factory=dict)
    # IDs of nodes that had no result when the snapshot was taken.
    pending_nodes: list[str] = field(default_factory=list)
    # Snapshot creation time (seconds since the epoch).
    timestamp: float = 0.0
    # Snapshot format version.
    version: str = "1.0"

    def to_dict(self) -> dict:
        """Return the snapshot as a plain dict (key order is stable)."""
        return dict(
            dag_name=self.dag_name,
            completed_nodes=self.completed_nodes,
            pending_nodes=self.pending_nodes,
            timestamp=self.timestamp,
            version=self.version,
        )

    @classmethod
    def from_dict(cls, data: dict) -> "CheckpointData":
        """Rebuild a snapshot from ``to_dict()`` output; missing keys use defaults."""
        fallbacks = {
            "dag_name": "",
            "completed_nodes": {},
            "pending_nodes": [],
            "timestamp": 0.0,
            "version": "1.0",
        }
        return cls(**{key: data.get(key, default) for key, default in fallbacks.items()})

    def to_json(self) -> str:
        """Serialize to indented JSON, keeping non-ASCII characters readable."""
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False)

    @classmethod
    def from_json(cls, json_str: str) -> "CheckpointData":
        """Parse a snapshot produced by :meth:`to_json`."""
        payload = json.loads(json_str)
        return cls.from_dict(payload)

147 

148 

149# ── Orchestrator Engine ──────────────────────── 

150 

class ToolOrchestrator:
    """Tool-chain orchestration engine.

    Executes a :class:`DAGSpec` in scheduling rounds. Each round:

    * runs every pending node whose dependencies all succeeded;
    * runs parallel groups under their concurrency cap;
    * queues the dependants of the nodes just run;
    * evaluates condition nodes.

    It also supports checkpointing the recorded node results and resuming
    from such a checkpoint.
    """

    def __init__(self, tool_registry: Any):
        # Only ``tool_registry.get(tool_name)`` is used; the returned object must
        # provide an ``execute(**kwargs)`` callable (plain or async).
        self.registry = tool_registry
        self._results: dict[str, NodeResult] = {}
        # Reset by execute(); not read anywhere in this class.
        self._aborted: bool = False
        # True only between restore_from_checkpoint() and the next execute(), so a
        # reused orchestrator does not mistake a previous run's results for a resume.
        self._restored: bool = False
        # Called after every scheduling round; installed by execute_with_checkpoint().
        self._checkpoint_hook: Callable[[], None] | None = None

    async def execute(self, dag: DAGSpec) -> DAGResult:
        """Run the whole DAG and return a :class:`DAGResult`.

        A fresh run discards results from any previous run. If
        :meth:`restore_from_checkpoint` was called since the last run, restored
        SUCCESS nodes are kept and skipped, and every other node is (re)run.

        Hitting ``dag.global_timeout`` does not raise: the returned result
        carries ``error`` plus whatever node results were recorded so far.
        """
        resuming = self._restored
        self._restored = False
        if not resuming:
            self._results = {}
        self._aborted = False
        start = time.monotonic()  # monotonic: immune to wall-clock jumps

        try:
            await asyncio.wait_for(self._execute_entry(dag), timeout=dag.global_timeout)
        except asyncio.TimeoutError:
            return DAGResult(
                nodes=self._results,
                total_duration_ms=(time.monotonic() - start) * 1000,
                error=f"DAG timeout ({dag.global_timeout}s)",
            )

        duration_ms = (time.monotonic() - start) * 1000
        success = all(
            r.state == NodeState.SUCCESS
            for r in self._results.values()
            if not self._is_conditional(dag, r.node_id)
        )
        return DAGResult(
            nodes=self._results,
            final_output=self._final_output(dag),
            total_duration_ms=duration_ms,
            success=success,
        )

    def _final_output(self, dag: DAGSpec) -> Any:
        """Pick the DAG's overall output.

        Prefers the last successful sink node, in declaration order. A sink is
        a node no other node depends on, so a chain reports its last step.
        Otherwise falls back to the last successful entry node, then ``None``.
        """
        depended_on = {dep for node in dag.nodes.values() for dep in node.depends_on}
        sinks = [nid for nid in dag.nodes if nid not in depended_on]
        for nid in [*reversed(sinks), *reversed(dag.entry)]:
            result = self._results.get(nid)
            if result is not None and result.state == NodeState.SUCCESS:
                return result.output
        return None

    async def _execute_entry(self, dag: DAGSpec) -> None:
        """Run scheduling rounds until no pending node can make progress.

        Results already marked SUCCESS (restored from a checkpoint) are kept
        and their nodes skipped. Any other restored result is dropped, so that
        node is re-run instead of blocking its dependants.
        """
        # Rebind instead of mutating: the dict returned by
        # restore_from_checkpoint() stays untouched.
        self._results = {
            nid: r for nid, r in self._results.items() if r.state == NodeState.SUCCESS
        }
        pending = {nid for nid in (dag.entry or dag.nodes) if nid not in self._results}
        # Resume: queue the not-yet-run dependants of *every* restored node (not
        # just restored entry nodes), so chains continue from where they stopped.
        pending |= self._discover_downstream(dag, self._results, pending)
        evaluated: set[str] = set()  # conditions are evaluated at most once per run
        self._evaluate_conditions(dag, pending, evaluated)

        while pending:
            # Declaration order keeps scheduling deterministic.
            ready = [
                nid
                for nid in dag.nodes
                if nid in pending and self._deps_ready(dag.nodes[nid].depends_on)
            ]
            if not ready:
                # Nothing runs between rounds, so with no ready node the state can
                # never change again (failed deps, unknown/unscheduled deps,
                # cycles). Fail the leftovers instead of polling until the
                # global timeout.
                for nid in pending:
                    self._results[nid] = NodeResult(
                        node_id=nid,
                        state=NodeState.FAILED,
                        error="Deadlock: dependencies not met",
                    )
                break

            await self._run_batch(dag, ready)
            pending.difference_update(ready)
            pending |= self._discover_downstream(dag, ready, pending)
            self._evaluate_conditions(dag, pending, evaluated)
            if self._checkpoint_hook is not None:
                self._checkpoint_hook()

    def _discover_downstream(self, dag: DAGSpec, finished, pending: set[str]) -> set[str]:
        """Return nodes that depend on any of *finished* and are neither queued nor already run."""
        finished = set(finished)
        return {
            nid
            for nid, node in dag.nodes.items()
            if nid not in pending
            and nid not in self._results
            and not finished.isdisjoint(node.depends_on)
        }

    def _evaluate_conditions(self, dag: DAGSpec, pending: set[str], evaluated: set[str]) -> None:
        """Evaluate each condition whose deps succeeded, once, and queue its target.

        The callable receives ``{dep_id: NodeResult}``. A target that is unknown
        or has already run is ignored. Re-queuing an already-run target used to
        re-execute it every round, forever.

        NOTE: a ConditionNode does not declare its candidate targets, so nodes
        that merely depend on the condition's inputs are still scheduled through
        normal dependency discovery.
        """
        for cid, cond in dag.conditions.items():
            if cid in evaluated or not self._deps_ready(cond.depends_on):
                continue
            evaluated.add(cid)
            upstream = {nid: self._results.get(nid) for nid in cond.depends_on}
            target = cond.condition(upstream)
            if target and target in dag.nodes and target not in self._results:
                pending.add(target)

    async def _run_batch(self, dag: DAGSpec, ready: list[str]) -> None:
        """Run one round of ready nodes.

        Nodes outside every parallel group run concurrently first. Then each
        parallel group runs its ready members under its own semaphore. A node
        listed in several groups (or twice in one) runs only once.
        """
        ready_set = set(ready)
        grouped = {nid for pg in dag.parallels for nid in pg.node_ids}
        ungrouped = [nid for nid in ready if nid not in grouped]
        if ungrouped:
            await asyncio.gather(*(self._run_node(dag, nid) for nid in ungrouped))

        launched: set[str] = set()
        for pg in dag.parallels:
            members = []
            for nid in pg.node_ids:
                if nid in ready_set and nid not in launched:
                    launched.add(nid)
                    members.append(nid)
            if members:
                # Semaphore(0) would block forever; always allow at least one.
                sem = asyncio.Semaphore(max(1, pg.max_concurrency))
                await asyncio.gather(*(self._run_bounded(dag, nid, sem) for nid in members))

    async def _run_bounded(self, dag: DAGSpec, nid: str, sem: asyncio.Semaphore) -> NodeResult:
        """Run one node while holding *sem*."""
        async with sem:
            return await self._run_node(dag, nid)

    async def _run_node(self, dag: DAGSpec, nid: str) -> NodeResult:
        """Run a single node under its retry policy and record the result.

        Returns:
            The recorded :class:`NodeResult`. An unknown *nid* yields a FAILED
            result that is not recorded.

        Retry behaviour:

        * Exceptions not matching ``retry.retry_on`` fail immediately.
        * Per-node timeouts are always retried while attempts remain.
        * If the last attempt timed out, the node is recorded as TIMEOUT
          rather than FAILED.
        """
        node = dag.nodes.get(nid)
        if not node:
            return NodeResult(nid, NodeState.FAILED, error=f"Unknown node: {nid}")

        # Outputs of dependencies that succeeded, keyed by dependency ID.
        upstream: dict[str, Any] = {}
        for dep in node.depends_on:
            dep_result = self._results.get(dep)
            if dep_result is not None and dep_result.state == NodeState.SUCCESS:
                upstream[dep] = dep_result.output

        args = dict(node.tool_args)
        if node.input_transform and upstream:
            try:
                args.update(node.input_transform(upstream))
            except Exception as e:
                result = NodeResult(nid, NodeState.FAILED, error=f"Input transform error: {e}")
                self._results[nid] = result
                return result

        policy = node.retry or RetryPolicy(max_retries=0)
        retry_on = policy.retry_on
        if isinstance(retry_on, list):  # isinstance() accepts a class or tuple, not a list
            retry_on = tuple(retry_on)
        attempts = max(0, policy.max_retries) + 1  # always at least one attempt
        last_error: str | None = None
        last_state = NodeState.FAILED
        attempt = 0

        for attempt in range(attempts):
            step_start = time.monotonic()
            try:
                output = await asyncio.wait_for(
                    self._execute_tool(node.tool_name, args, upstream),
                    timeout=node.timeout,
                )
            except asyncio.TimeoutError:
                last_error = f"Tool timeout ({node.timeout}s)"
                last_state = NodeState.TIMEOUT
            except Exception as e:
                last_error = str(e)
                last_state = NodeState.FAILED
                if not isinstance(e, retry_on):
                    break  # not retryable under this policy
            else:
                result = NodeResult(
                    node_id=nid,
                    state=NodeState.SUCCESS,
                    output=output,
                    duration_ms=(time.monotonic() - step_start) * 1000,
                    retries=attempt,
                )
                self._results[nid] = result
                return result

            if attempt < attempts - 1:
                await asyncio.sleep(self._calc_retry_delay(policy, attempt))

        result = NodeResult(nid, last_state, error=last_error, retries=attempt)
        self._results[nid] = result
        return result

    async def _execute_tool(
        self,
        tool_name: str,
        args: dict,
        upstream: dict[str, Any],
    ) -> Any:
        """Look up *tool_name* in the registry and invoke its ``execute``.

        Upstream outputs are injected as ``_<dep_id>`` keyword arguments plus an
        ``_upstream`` mapping. Any awaitable returned by ``execute`` is awaited,
        so both ``async def`` methods and plain callables returning coroutines work.

        Raises:
            ValueError: if the registry returns no tool for *tool_name*.
        """
        tool = self.registry.get(tool_name)
        if not tool:
            raise ValueError(f"Tool not found: {tool_name}")

        full_args = {**args}
        for dep_id, dep_output in upstream.items():
            full_args[f"_{dep_id}"] = dep_output
        full_args["_upstream"] = upstream

        # NOTE: a synchronous execute() runs on the event-loop thread; the
        # per-node timeout cannot interrupt it while it blocks.
        outcome = tool.execute(**full_args)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome

    def _deps_ready(self, deps: list[str]) -> bool:
        """Return True when every dependency has a SUCCESS result (True for no deps)."""
        for dep in deps:
            r = self._results.get(dep)
            if not r or r.state != NodeState.SUCCESS:
                return False
        return True

    def _can_proceed(self, dag: DAGSpec, nid: str) -> bool:
        """Return False if *nid* is unknown or any of its deps FAILED/TIMED OUT.

        Not used by the scheduler in this class; kept for external callers.
        """
        node = dag.nodes.get(nid)
        if not node:
            return False
        for dep in node.depends_on:
            r = self._results.get(dep)
            if r and r.state in (NodeState.FAILED, NodeState.TIMEOUT):
                return False
        return True

    def _is_conditional(self, dag: DAGSpec, nid: str) -> bool:
        """Return True if *nid* is a condition ID of *dag*."""
        return nid in dag.conditions

    def _calc_retry_delay(self, policy: RetryPolicy, attempt: int) -> float:
        """Return the delay in seconds before retry ``attempt + 1``.

        The delay is always capped at ``policy.max_delay``, including for
        ``"fixed"`` backoff. Unknown ``backoff`` values behave as exponential.
        """
        if policy.backoff == "fixed":
            delay = policy.base_delay
        elif policy.backoff == "linear":
            delay = policy.base_delay * (attempt + 1)
        else:  # exponential
            delay = policy.base_delay * (2 ** attempt)
        return min(delay, policy.max_delay)

    @property
    def results(self) -> dict[str, NodeResult]:
        """Shallow copy of the recorded node results."""
        return dict(self._results)

    # ── Checkpoint / Restore (v1.1.7) ──────────────

    def checkpoint(self, dag: DAGSpec) -> CheckpointData:
        """Snapshot the current progress of *dag*.

        ``completed_nodes`` holds every recorded result in any state, including
        FAILED/TIMEOUT. ``pending_nodes`` lists the nodes of *dag* that have no
        result yet.
        """
        completed = {
            nid: {
                "node_id": result.node_id,
                "state": result.state.value,
                "output": result.output,
                "error": result.error,
                "duration_ms": result.duration_ms,
                "retries": result.retries,
            }
            for nid, result in self._results.items()
        }
        pending = [nid for nid in dag.nodes if nid not in self._results]
        return CheckpointData(
            dag_name=dag.name,
            completed_nodes=completed,
            pending_nodes=pending,
            timestamp=time.time(),  # wall clock: the snapshot's creation time
        )

    def restore_from_checkpoint(self, dag: DAGSpec, cp: CheckpointData) -> dict[str, NodeResult]:
        """Load node results from *cp* so that the next :meth:`execute` resumes.

        The next ``execute()`` skips nodes restored as SUCCESS and re-runs
        everything else. *dag* is currently unused.

        Returns:
            The restored ``node_id -> NodeResult`` mapping.
        """
        restored = {
            nid: NodeResult(
                node_id=data.get("node_id", nid),
                state=NodeState(data["state"]),
                output=data.get("output"),
                error=data.get("error"),
                duration_ms=data.get("duration_ms", 0),
                retries=data.get("retries", 0),
            )
            for nid, data in cp.completed_nodes.items()
        }
        self._results = restored
        self._restored = True
        return restored

    async def execute_with_checkpoint(
        self,
        dag: DAGSpec,
        checkpoint_callback: Callable[[CheckpointData], None] | None = None,
        checkpoint_interval: float = 60.0,
    ) -> dict[str, NodeResult]:
        """Run the DAG while periodically emitting checkpoints.

        While running, a checkpoint is passed to *checkpoint_callback* at the
        first scheduling-round boundary after each *checkpoint_interval*
        seconds; an interval <= 0 means every round. A long-running round is
        therefore not snapshotted mid-way.

        A final checkpoint is always emitted, including when execution raises;
        in that case the exception is re-raised after the callback.

        Args:
            dag: DAG specification.
            checkpoint_callback: Receives each :class:`CheckpointData`.
            checkpoint_interval: Minimum seconds between periodic checkpoints.

        Returns:
            The final node results.
        """
        last_saved = time.monotonic()

        def maybe_checkpoint() -> None:
            nonlocal last_saved
            now = time.monotonic()
            if now - last_saved >= checkpoint_interval:
                last_saved = now
                checkpoint_callback(self.checkpoint(dag))

        self._checkpoint_hook = maybe_checkpoint if checkpoint_callback is not None else None
        try:
            await self.execute(dag)
        except Exception:
            # Preserve whatever progress was made before re-raising.
            if checkpoint_callback is not None:
                checkpoint_callback(self.checkpoint(dag))
            raise
        finally:
            self._checkpoint_hook = None

        if checkpoint_callback is not None:
            checkpoint_callback(self.checkpoint(dag))
        return self._results

459 

460 

461# ── DAG Builder (Fluent API) ──────────────────── 

462 

class DAGBuilder:
    """Fluent builder that assembles a :class:`DAGSpec`.

    Every declaring method returns ``self`` so calls can be chained. Nodes
    declared without dependencies are recorded as entry nodes. Lists and dicts
    passed in are copied, and :meth:`build` returns containers independent of
    the builder. Later builder calls therefore never mutate an
    already-built spec.
    """

    def __init__(self, name: str = "unnamed"):
        self.name = name
        self._nodes: dict[str, ToolNode] = {}
        self._parallels: list[ParallelGroup] = []
        self._conditions: dict[str, ConditionNode] = {}
        self._entry: list[str] = []

    def node(
        self,
        node_id: str,
        tool_name: str,
        tool_args: dict | None = None,
        depends_on: list[str] | None = None,
        timeout: float = 60.0,
        retry: RetryPolicy | None = None,
        input_transform: Callable | None = None,
    ) -> "DAGBuilder":
        """Declare (or redeclare) a tool node.

        Args:
            node_id: Unique node ID; redeclaring replaces the previous node.
            tool_name: Registry name of the tool to invoke.
            tool_args: Static keyword arguments for the tool.
            depends_on: Upstream node IDs; an empty value makes this an entry node.
            timeout: Per-attempt timeout in seconds.
            retry: Retry policy, or None for a single attempt.
            input_transform: Maps upstream outputs to extra tool arguments.
        """
        deps = list(depends_on or [])
        self._nodes[node_id] = ToolNode(
            tool_name=tool_name,
            tool_args=dict(tool_args or {}),
            depends_on=deps,
            timeout=timeout,
            retry=retry,
            input_transform=input_transform,
        )
        # Keep the entry list consistent on redeclaration: no duplicates, and a
        # node that gains dependencies stops being an entry node.
        if deps:
            if node_id in self._entry:
                self._entry.remove(node_id)
        elif node_id not in self._entry:
            self._entry.append(node_id)
        return self

    def parallel(self, node_ids: list[str], depends_on: list[str] | None = None, max_concurrency: int = 5) -> "DAGBuilder":
        """Declare a group of nodes to run concurrently (at most *max_concurrency* at once)."""
        self._parallels.append(ParallelGroup(
            node_ids=list(node_ids),
            depends_on=list(depends_on or []),
            max_concurrency=max_concurrency,
        ))
        return self

    def condition(self, cond_id: str, condition: Callable, depends_on: list[str]) -> "DAGBuilder":
        """Declare a condition evaluated once *depends_on* succeeded; it returns a target node ID."""
        self._conditions[cond_id] = ConditionNode(
            condition=condition,
            depends_on=list(depends_on),
        )
        return self

    def build(self, global_timeout: float = 300.0) -> DAGSpec:
        """Return a :class:`DAGSpec` snapshot of everything declared so far."""
        return DAGSpec(
            name=self.name,
            nodes=dict(self._nodes),
            parallels=list(self._parallels),
            conditions=dict(self._conditions),
            entry=list(self._entry),
            global_timeout=global_timeout,
        )

519 

520 

521# ── Pre-built Chains ──────────────────────────── 

522 

def chain_builder(name: str, tool_names: list[str]) -> DAGSpec:
    """Build a linear chain ``step_0 -> step_1 -> ...`` over *tool_names*.

    Node ``step_i`` runs ``tool_names[i]`` and depends on ``step_{i-1}``;
    ``step_0`` is the only entry node. An empty *tool_names* yields a DAG
    with no nodes.
    """
    builder = DAGBuilder(name)
    for i, tool_name in enumerate(tool_names):
        deps = [f"step_{i - 1}"] if i > 0 else []
        builder.node(f"step_{i}", tool_name, depends_on=deps)
    return builder.build()

532 

533 

def parallel_then_merge(name: str, parallel_tools: list[str], merge_tool: str) -> DAGSpec:
    """Build a fan-out/fan-in DAG.

    Each tool in *parallel_tools* becomes node ``par_<i>`` (an entry node).
    All of them form one parallel group, and a final ``merge`` node running
    *merge_tool* depends on every branch.
    """
    branches = {f"par_{idx}": tool for idx, tool in enumerate(parallel_tools)}
    builder = DAGBuilder(name)
    for branch_id, tool in branches.items():
        builder.node(branch_id, tool)
    # The same ID list is shared by the group and the merge node's deps.
    branch_ids = list(branches)
    builder.parallel(branch_ids)
    builder.node("merge", merge_tool, depends_on=branch_ids)
    return builder.build()

545 

546 

def if_then_else(name: str, check_tool: str, true_tool: str, false_tool: str) -> DAGSpec:
    """Build an if-then-else DAG.

    Node ``check`` runs *check_tool*. Condition ``cond`` selects
    ``"true_branch"`` when check's output is truthy and ``"false_branch"``
    otherwise.

    The orchestrator passes condition callables ``{dep_id: NodeResult}``, and
    ``NodeResult`` has no ``.get``. The previous lambda's ``.get("output")``
    therefore raised AttributeError. The selector reads ``.output`` and still
    accepts dict-shaped upstream values.

    NOTE: both branches also declare ``depends_on=["check"]``, so the
    orchestrator's dependency discovery schedules them after ``check`` as well.
    """
    def _choose_branch(upstream: dict) -> str:
        check = upstream.get("check")
        if isinstance(check, dict):
            value = check.get("output")
        else:
            value = getattr(check, "output", None)  # NodeResult (or None if missing)
        return "true_branch" if value else "false_branch"

    builder = DAGBuilder(name)
    builder.node("check", check_tool)
    builder.condition("cond", _choose_branch, depends_on=["check"])
    builder.node("true_branch", true_tool, depends_on=["check"])
    builder.node("false_branch", false_tool, depends_on=["check"])
    return builder.build()