Coverage for agentos/tools/orchestrator.py: 46%
301 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 18:40 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 18:40 +0800
1"""
2AgentOS v1.1.7 — 工具链编排引擎(Checkpoint/恢复)。
3基因来源: Airflow DAG + LangChain Tool Composition
5支持:
6- 顺序链 (chain): 工具A → 工具B → 工具C
7- 并行分支 (parallel): A + B 同时 → C
8- 条件执行 (conditional): if X then A else B
9- 重试策略 (retry): 指数退避 / 固定间隔
10- 超时控制 (timeout): 单工具 / 全链
11- Checkpoint/恢复: 长时间DAG断点保存与续跑
12"""
14from __future__ import annotations
16import asyncio
17import json
18import time
19from dataclasses import dataclass, field
20from enum import Enum
21from typing import Any, Callable, Coroutine
24# ── Core Types ──────────────────────────────────
class NodeState(str, Enum):
    """Lifecycle state of a single DAG node.

    The members are also plain strings, so they compare equal to their
    values and can be written to JSON without conversion.
    """

    PENDING = "pending"    # declared but not started
    RUNNING = "running"    # currently executing
    SUCCESS = "success"    # finished and produced an output
    FAILED = "failed"      # finished with an error
    SKIPPED = "skipped"    # deliberately not executed
    TIMEOUT = "timeout"    # exceeded its time budget
@dataclass
class NodeResult:
    """Outcome of executing one DAG node."""

    node_id: str                 # id of the node this result belongs to
    state: NodeState             # final state reached by the node
    output: Any = None           # value returned by the tool, if any
    error: str | None = None     # error description when the node did not succeed
    duration_ms: float = 0.0     # duration of the recorded attempt, in milliseconds
    retries: int = 0             # retry count recorded for the node
@dataclass
class DAGResult:
    """Aggregate outcome of one DAG run."""

    nodes: dict[str, NodeResult]      # per-node results keyed by node id
    final_output: Any = None          # output selected as the DAG's result
    total_duration_ms: float = 0.0    # wall time of the whole run, in milliseconds
    success: bool = False             # True only when the run counted as successful
    error: str | None = None          # DAG-level error description, if any
@dataclass
class RetryPolicy:
    """Retry settings attached to a tool node."""

    max_retries: int = 3              # extra attempts after the first one
    base_delay: float = 1.0           # seconds; starting delay between attempts
    max_delay: float = 30.0           # seconds; upper bound for a computed delay
    backoff: str = "exponential"      # "exponential" | "fixed" | "linear"
    retry_on: tuple = (Exception,)    # exception types considered retryable
69# ── DAG Node Types ─────────────────────────────
@dataclass
class ToolNode:
    """A DAG node that invokes one registered tool."""

    tool_name: str                                            # name looked up in the tool registry
    tool_args: dict[str, Any] = field(default_factory=dict)   # static keyword arguments for the tool
    depends_on: list[str] = field(default_factory=list)       # ids of upstream nodes
    timeout: float = 60.0                                     # per-attempt time budget, in seconds
    retry: RetryPolicy | None = None                          # None means "no retry policy attached"

    # Optional hook mapping upstream outputs to additional tool arguments.
    input_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None
@dataclass
class ConditionNode:
    """A branching point in the DAG."""

    # Called with a mapping keyed by the ids in ``depends_on``; returns the
    # id of the node that should be scheduled next.
    condition: Callable[[dict[str, Any]], str]
    # Ids of the upstream nodes the decision is based on.
    depends_on: list[str] = field(default_factory=list)
@dataclass
class ParallelGroup:
    """A set of nodes that are launched together."""

    node_ids: list[str]                                    # members of the group
    depends_on: list[str] = field(default_factory=list)    # declared upstream node ids
    max_concurrency: int = 5                               # cap on members running at once
@dataclass
class DAGSpec:
    """Declarative description of a DAG to be orchestrated."""

    name: str                                                             # label of the DAG
    nodes: dict[str, ToolNode] = field(default_factory=dict)              # tool nodes keyed by id
    parallels: list[ParallelGroup] = field(default_factory=list)          # parallel groups
    conditions: dict[str, ConditionNode] = field(default_factory=dict)    # conditions keyed by id
    entry: list[str] = field(default_factory=list)                        # ids of the starting nodes
    global_timeout: float = 300.0                                         # budget for the whole run, in seconds
110# ── Checkpoint Data (v1.1.7) ──────────────────
@dataclass
class CheckpointData:
    """Serializable snapshot of a DAG run, used to resume after an interruption."""

    dag_name: str
    completed_nodes: dict[str, dict] = field(default_factory=dict)   # node id -> serialized result
    pending_nodes: list[str] = field(default_factory=list)           # node ids without a result yet
    timestamp: float = 0.0                                           # time the snapshot was taken
    version: str = "1.0"                                             # snapshot format version

    def to_dict(self) -> dict:
        """Return the snapshot as a plain dict.

        The top-level containers are copied so that callers editing the
        returned dict do not silently modify this snapshot.
        """
        return {
            "dag_name": self.dag_name,
            "completed_nodes": dict(self.completed_nodes),
            "pending_nodes": list(self.pending_nodes),
            "timestamp": self.timestamp,
            "version": self.version,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "CheckpointData":
        """Build a snapshot from a plain dict.

        Missing keys and keys explicitly set to ``None`` (JSON ``null``) both
        fall back to the field default. Containers are copied so the snapshot
        does not alias the caller's data.
        """

        def pick(key: str, default: Any) -> Any:
            value = data.get(key)
            return default if value is None else value

        return cls(
            dag_name=pick("dag_name", ""),
            completed_nodes=dict(pick("completed_nodes", {})),
            pending_nodes=list(pick("pending_nodes", [])),
            timestamp=pick("timestamp", 0.0),
            version=pick("version", "1.0"),
        )

    def to_json(self) -> str:
        """Serialize the snapshot to pretty-printed JSON, keeping non-ASCII text as is."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "CheckpointData":
        """Parse a snapshot from the JSON produced by :meth:`to_json`."""
        return cls.from_dict(json.loads(json_str))
149# ── Orchestrator Engine ────────────────────────
class ToolOrchestrator:
    """Tool-chain orchestration engine.

    Executes a ``DAGSpec``: runs nodes whose dependencies have succeeded,
    launches parallel groups with a concurrency cap, evaluates condition
    nodes, applies per-node retry and timeout, and can snapshot / restore
    progress through ``CheckpointData``.

    The registry passed to the constructor must expose ``get(name)`` returning
    an object with an ``execute(**kwargs)`` method (sync or async).
    """

    def __init__(self, tool_registry: Any):
        self.registry = tool_registry
        self._results: dict[str, NodeResult] = {}
        self._aborted: bool = False

    async def execute(self, dag: DAGSpec) -> DAGResult:
        """Run the whole DAG and return its aggregated result.

        Results already present on the orchestrator (normally put there by
        ``restore_from_checkpoint``) are kept, and nodes that succeeded are
        not executed again.

        NOTE(review): results left over from a previous ``execute`` call on
        the same instance are treated exactly like restored state; use a
        fresh orchestrator for an unrelated run.

        Returns:
            A ``DAGResult``. When ``dag.global_timeout`` elapses the result
            carries a timeout error and ``success`` stays ``False``.
        """
        was_restored = bool(self._results)
        if not was_restored:
            self._results = {}
        self._aborted = False
        start = time.monotonic()

        try:
            await asyncio.wait_for(
                self._execute_entry(dag),
                timeout=dag.global_timeout,
            )
        except asyncio.TimeoutError:
            return DAGResult(
                nodes=self._results,
                total_duration_ms=(time.monotonic() - start) * 1000,
                error=f"DAG timeout ({dag.global_timeout}s)",
            )

        duration_ms = (time.monotonic() - start) * 1000
        success = all(
            r.state == NodeState.SUCCESS
            for r in self._results.values()
            if not self._is_conditional(dag, r.node_id)
        )

        # Final output = output of the last entry node that succeeded.
        last = None
        for nid in reversed(dag.entry):
            if nid in self._results and self._results[nid].state == NodeState.SUCCESS:
                last = self._results[nid].output
                break

        return DAGResult(
            nodes=self._results,
            final_output=last,
            total_duration_ms=duration_ms,
            success=success,
        )

    def _discover_downstream(self, dag: DAGSpec, finished: set, pending: set) -> set:
        """Return nodes depending on any id in ``finished`` that are neither queued nor have a result."""
        return {
            nid
            for nid, node in dag.nodes.items()
            if nid not in pending
            and nid not in self._results
            and any(dep in finished for dep in node.depends_on)
        }

    async def _run_group(self, dag: DAGSpec, node_ids: list, max_concurrency: int) -> None:
        """Run ``node_ids`` concurrently, at most ``max_concurrency`` at a time."""
        # A non-positive cap would hang (0) or raise (negative); treat it as 1.
        sem = asyncio.Semaphore(max(1, max_concurrency))

        async def bounded(nid: str) -> None:
            async with sem:
                await self._run_node(dag, nid)

        await asyncio.gather(*(bounded(nid) for nid in node_ids))

    async def _execute_entry(self, dag: DAGSpec):
        """Drive the scheduling loop until no node is left to run."""
        pending = set(dag.entry or dag.nodes.keys())

        # Resume support: skip everything that already succeeded and continue
        # from the frontier. Discovery starts from *all* successful results,
        # not only entry nodes, so a chain restored several hops in
        # still finds its next step.
        succeeded = {
            nid for nid, r in self._results.items() if r.state == NodeState.SUCCESS
        }
        pending -= succeeded
        pending |= self._discover_downstream(dag, succeeded, pending)

        grouped = {nid for pg in dag.parallels for nid in pg.node_ids}
        evaluated_conditions: set = set()

        while pending:
            ready = [
                nid
                for nid in pending
                if nid in dag.nodes and self._deps_ready(dag.nodes[nid].depends_on)
            ]

            if not ready:
                # Every node is awaited inline, so if nothing is runnable now
                # nothing will ever become runnable: fail the remainder
                # instead of spinning until the global timeout.
                for nid in pending:
                    self._results[nid] = NodeResult(
                        node_id=nid, state=NodeState.FAILED,
                        error="Deadlock: dependencies not met",
                    )
                break

            # Nodes outside any parallel group run first, concurrently.
            ungrouped = [nid for nid in ready if nid not in grouped]
            if ungrouped:
                await asyncio.gather(*(self._run_node(dag, nid) for nid in ungrouped))

            # Then each parallel group, honouring its concurrency cap. A node
            # listed in several groups (or twice in one) is launched only once.
            remaining = set(ready)
            for pg in dag.parallels:
                group_ready = [nid for nid in dict.fromkeys(pg.node_ids) if nid in remaining]
                if group_ready:
                    remaining.difference_update(group_ready)
                    await self._run_group(dag, group_ready, pg.max_concurrency)

            pending -= set(ready)

            # Queue the nodes that depend on what just ran.
            pending |= self._discover_downstream(dag, set(ready), pending)

            # Evaluate each condition once, as soon as its inputs succeeded.
            # Re-evaluating on every pass would re-queue (and re-run) the
            # chosen target forever.
            for cid, cond in dag.conditions.items():
                if cid in evaluated_conditions or not self._deps_ready(cond.depends_on):
                    continue
                evaluated_conditions.add(cid)
                upstream = {nid: self._results.get(nid) for nid in cond.depends_on}
                target = cond.condition(upstream)
                if target and target in dag.nodes and target not in self._results:
                    pending.add(target)

    async def _run_node(self, dag: DAGSpec, nid: str) -> NodeResult:
        """Execute one node with retry and timeout, and record its result.

        Returns:
            The ``NodeResult`` stored for the node. A node that ran out of
            time on its last attempt is recorded as ``TIMEOUT``; any other
            failure is recorded as ``FAILED``.
        """
        node = dag.nodes.get(nid)
        if not node:
            return NodeResult(nid, NodeState.FAILED, error=f"Unknown node: {nid}")

        # Collect the outputs of the upstream nodes that succeeded.
        upstream = {}
        for dep in node.depends_on:
            dep_result = self._results.get(dep)
            if dep_result and dep_result.state == NodeState.SUCCESS:
                upstream[dep] = dep_result.output

        # Let the node derive extra arguments from upstream outputs.
        args = dict(node.tool_args)
        if node.input_transform and upstream:
            try:
                args.update(node.input_transform(upstream))
            except Exception as e:
                result = NodeResult(nid, NodeState.FAILED, error=f"Input transform error: {e}")
                self._results[nid] = result
                return result

        policy = node.retry or RetryPolicy(max_retries=0)
        max_retries = max(policy.max_retries, 0)  # always make at least one attempt
        final_state = NodeState.FAILED
        last_error = None
        attempt = 0

        for attempt in range(max_retries + 1):
            step_start = time.monotonic()
            try:
                output = await asyncio.wait_for(
                    self._execute_tool(node.tool_name, args, upstream),
                    timeout=node.timeout,
                )
            except asyncio.TimeoutError as exc:
                final_state = NodeState.TIMEOUT
                last_error = f"Tool timeout ({node.timeout}s)"
                retryable = isinstance(exc, policy.retry_on)
            except Exception as exc:
                final_state = NodeState.FAILED
                # Some exceptions stringify to ""; fall back to the type name.
                last_error = str(exc) or type(exc).__name__
                retryable = isinstance(exc, policy.retry_on)
            else:
                result = NodeResult(
                    node_id=nid, state=NodeState.SUCCESS,
                    output=output,
                    duration_ms=(time.monotonic() - step_start) * 1000,
                    retries=attempt,
                )
                self._results[nid] = result
                return result

            # Stop when attempts are exhausted or the policy does not cover
            # this kind of failure.
            if attempt >= max_retries or not retryable:
                break
            await asyncio.sleep(self._calc_retry_delay(policy, attempt))

        result = NodeResult(nid, final_state, error=last_error, retries=attempt)
        self._results[nid] = result
        return result

    async def _execute_tool(
        self,
        tool_name: str,
        args: dict,
        upstream: dict[str, Any],
    ) -> Any:
        """Look up ``tool_name`` in the registry and call its ``execute``.

        Besides ``args`` the tool receives one ``_<dep_id>`` keyword per
        upstream output and the full mapping as ``_upstream``.

        Raises:
            ValueError: if the registry has no tool under ``tool_name``.
        """
        tool = self.registry.get(tool_name)
        if not tool:
            raise ValueError(f"Tool not found: {tool_name}")

        full_args = {**args}
        for dep_id, dep_output in upstream.items():
            full_args[f"_{dep_id}"] = dep_output
        full_args["_upstream"] = upstream

        if asyncio.iscoroutinefunction(tool.execute):
            return await tool.execute(**full_args)
        else:
            # A synchronous tool runs inline on the event loop, so the node
            # timeout cannot interrupt it while it is running.
            return tool.execute(**full_args)

    def _deps_ready(self, deps: list[str]) -> bool:
        """Return True when every dependency has a SUCCESS result."""
        for dep in deps:
            r = self._results.get(dep)
            if not r or r.state != NodeState.SUCCESS:
                return False
        return True

    def _can_proceed(self, dag: DAGSpec, nid: str) -> bool:
        """Return False when the node is unknown or a dependency failed or timed out."""
        node = dag.nodes.get(nid)
        if not node:
            return False
        for dep in node.depends_on:
            r = self._results.get(dep)
            if r and r.state in (NodeState.FAILED, NodeState.TIMEOUT):
                return False
        return True

    def _is_conditional(self, dag: DAGSpec, nid: str) -> bool:
        """Return True when ``nid`` is the id of a condition in the DAG."""
        return nid in dag.conditions

    def _calc_retry_delay(self, policy: RetryPolicy, attempt: int) -> float:
        """Return the delay in seconds to wait after the 0-based ``attempt``.

        "fixed" returns ``base_delay`` unchanged; "linear" and the default
        exponential mode are capped at ``max_delay``.
        """
        if policy.backoff == "fixed":
            return policy.base_delay
        elif policy.backoff == "linear":
            return min(policy.base_delay * (attempt + 1), policy.max_delay)
        else:  # exponential
            return min(policy.base_delay * (2 ** attempt), policy.max_delay)

    @property
    def results(self) -> dict[str, NodeResult]:
        """Shallow copy of the results recorded so far, keyed by node id."""
        return dict(self._results)

    # ── Checkpoint / Restore (v1.1.7) ──────────────

    def checkpoint(self, dag: DAGSpec) -> CheckpointData:
        """Snapshot the current progress of ``dag``.

        Every recorded result is serialized into ``completed_nodes``
        (whatever its state); DAG nodes without a result are listed in
        ``pending_nodes``.
        """
        completed = {}
        for nid, result in self._results.items():
            completed[nid] = {
                "node_id": result.node_id,
                "state": result.state.value,
                "output": result.output,
                "error": result.error,
                "duration_ms": result.duration_ms,
                "retries": result.retries,
            }
        pending = [nid for nid in dag.nodes if nid not in self._results]
        return CheckpointData(
            dag_name=dag.name,
            completed_nodes=completed,
            pending_nodes=pending,
            timestamp=time.time(),
        )

    def restore_from_checkpoint(self, dag: DAGSpec, cp: CheckpointData) -> dict[str, NodeResult]:
        """Replace the recorded results with those stored in ``cp``.

        Call this before ``execute`` to resume a run. ``dag`` is accepted for
        symmetry with ``checkpoint`` and is not inspected here.

        Returns:
            The restored results, keyed by node id.
        """
        restored = {}
        for nid, data in cp.completed_nodes.items():
            restored[nid] = NodeResult(
                node_id=data["node_id"],
                state=NodeState(data["state"]),
                output=data.get("output"),
                error=data.get("error"),
                duration_ms=data.get("duration_ms", 0),
                retries=data.get("retries", 0),
            )
        self._results = restored
        return restored

    async def execute_with_checkpoint(
        self,
        dag: DAGSpec,
        checkpoint_callback: Callable[[CheckpointData], None] | None = None,
        checkpoint_interval: float = 60.0,
    ) -> dict[str, NodeResult]:
        """Run the DAG and hand a snapshot to ``checkpoint_callback`` when it ends.

        The callback is invoked once: after a normal finish, or just before
        an exception raised by ``execute`` is re-raised.

        Args:
            dag: the DAG to run.
            checkpoint_callback: receives the ``CheckpointData`` snapshot.
            checkpoint_interval: accepted for API compatibility.
                NOTE(review): no periodic snapshot is taken yet, so this
                value is currently unused.

        Returns:
            The recorded results, keyed by node id.
        """
        try:
            await self.execute(dag)
        except Exception:
            # Preserve whatever progress was made before propagating.
            if checkpoint_callback:
                checkpoint_callback(self.checkpoint(dag))
            else:
                self.checkpoint(dag)
            raise
        else:
            cp = self.checkpoint(dag)
            if checkpoint_callback:
                checkpoint_callback(cp)
        return self._results
461# ── DAG Builder (Fluent API) ────────────────────
class DAGBuilder:
    """Fluent helper for assembling a ``DAGSpec`` step by step.

    Every mutator returns the builder itself so calls can be chained.
    """

    def __init__(self, name: str = "unnamed"):
        self.name = name
        self._nodes: dict[str, ToolNode] = {}
        self._parallels: list[ParallelGroup] = []
        self._conditions: dict[str, ConditionNode] = {}
        self._entry: list[str] = []

    def node(
        self,
        node_id: str,
        tool_name: str,
        tool_args: dict | None = None,
        depends_on: list[str] | None = None,
        timeout: float = 60.0,
        retry: RetryPolicy | None = None,
        input_transform: Callable | None = None,
    ) -> "DAGBuilder":
        """Register a tool node; a node without dependencies becomes an entry node."""
        has_dependencies = bool(depends_on)
        spec = ToolNode(
            tool_name=tool_name,
            tool_args=tool_args or {},
            depends_on=depends_on or [],
            timeout=timeout,
            retry=retry,
            input_transform=input_transform,
        )
        self._nodes[node_id] = spec
        if not has_dependencies:
            self._entry.append(node_id)
        return self

    def parallel(self, node_ids: list[str], depends_on: list[str] | None = None, max_concurrency: int = 5) -> "DAGBuilder":
        """Declare that ``node_ids`` form a parallel group."""
        group = ParallelGroup(
            node_ids=node_ids,
            depends_on=depends_on or [],
            max_concurrency=max_concurrency,
        )
        self._parallels.append(group)
        return self

    def condition(self, cond_id: str, condition: Callable, depends_on: list[str]) -> "DAGBuilder":
        """Register a condition under ``cond_id``."""
        branch = ConditionNode(condition=condition, depends_on=depends_on)
        self._conditions[cond_id] = branch
        return self

    def build(self, global_timeout: float = 300.0) -> DAGSpec:
        """Produce the ``DAGSpec``; it shares the builder's containers rather than copying them."""
        return DAGSpec(
            name=self.name,
            nodes=self._nodes,
            parallels=self._parallels,
            conditions=self._conditions,
            entry=self._entry,
            global_timeout=global_timeout,
        )
521# ── Pre-built Chains ────────────────────────────
def chain_builder(name: str, tool_names: list[str]) -> DAGSpec:
    """Build a strictly sequential chain.

    Step ``i`` is registered as ``step_<i>`` and depends on ``step_<i-1>``;
    the first step has no dependency.
    """
    builder = DAGBuilder(name)
    for index, tool_name in enumerate(tool_names):
        upstream = [] if index == 0 else [f"step_{index - 1}"]
        builder.node(f"step_{index}", tool_name, depends_on=upstream)
    return builder.build()
def parallel_then_merge(name: str, parallel_tools: list[str], merge_tool: str) -> DAGSpec:
    """Build a fan-out / fan-in DAG.

    Each tool becomes a dependency-free node ``par_<i>``; all of them form
    one parallel group, and a final ``merge`` node depends on every one.
    """
    builder = DAGBuilder(name)
    branch_ids = [f"par_{index}" for index in range(len(parallel_tools))]
    for branch_id, tool_name in zip(branch_ids, parallel_tools):
        builder.node(branch_id, tool_name)
    builder.parallel(branch_ids)
    builder.node("merge", merge_tool, depends_on=branch_ids)
    return builder.build()
def if_then_else(name: str, check_tool: str, true_tool: str, false_tool: str) -> DAGSpec:
    """Build an if/then/else DAG.

    A ``check`` node runs first; the ``cond`` condition then selects
    ``true_branch`` when the check's output is truthy and ``false_branch``
    otherwise.

    NOTE(review): both branch nodes declare ``check`` as their only
    dependency, so a scheduler that queues every dependent of a finished
    node will run both branches regardless of the condition — confirm this
    is the intended behaviour.
    """

    def pick_branch(upstream: dict[str, Any]) -> str:
        check = upstream.get("check")
        # The orchestrator passes result objects exposing ``.output``; plain
        # dicts are accepted too. Calling ``.get`` on a result object, as the
        # previous lambda did, raised AttributeError.
        if isinstance(check, dict):
            verdict = check.get("output")
        else:
            verdict = getattr(check, "output", None)
        return "true_branch" if verdict else "false_branch"

    builder = DAGBuilder(name)
    builder.node("check", check_tool)
    builder.condition("cond", pick_branch, depends_on=["check"])
    builder.node("true_branch", true_tool, depends_on=["check"])
    builder.node("false_branch", false_tool, depends_on=["check"])
    return builder.build()