Coverage for agentos/orchestration/task_decomposer.py: 28%
279 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
"""
AgentOS v1.14.7 — Intelligent Task Decomposer 2.0.

Replaces the simplistic 234-line single-prompt decomposer with:
- DAG cycle detection (Kahn's + DFS fallback)
- Dynamic re-planning on partial failure
- Task dependency validation
- Parallelism detection (independent sub-tasks)
- Confidence scoring per sub-task
- Observable decomposition trace

Architecture:
    TaskInput → TaskDecomposer.decompose() → TaskDAG
    Partial failure → TaskDecomposer.replan(dag, failed_nodes) → new TaskDAG
"""
from __future__ import annotations

import json
import logging
import time
import uuid
from collections import deque
from dataclasses import dataclass, field, replace
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)
30# ── Types ────────────────────────────────────
class TaskNodeStatus(str, Enum):
    """Execution state of a single task node.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Declared but not started yet (the default for new nodes).
    PENDING = "pending"
    # Currently executing.
    RUNNING = "running"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished with an error.
    FAILED = "failed"
    # Deliberately not executed.
    SKIPPED = "skipped"
class DecompositionStrategy(str, Enum):
    """Strategy used to break a task down into sub-tasks."""

    # Split the goal level by level, starting from the top.
    TOP_DOWN = "top_down"
    # Aggregate upwards from sub-tasks.
    BOTTOM_UP = "bottom_up"
    # Keep decomposing recursively until only atomic tasks remain.
    RECURSIVE = "recursive"
    # Rule-based / pattern-matching decomposition.
    HEURISTIC = "heuristic"
@dataclass
class TaskEdge:
    """Directed DAG edge: ``to_node`` may only run after ``from_node`` is done."""

    # ID of the prerequisite node.
    from_node: str
    # ID of the dependent node.
    to_node: str
    # Either "hard" or "soft".
    dependency_type: str = "hard"
@dataclass
class TaskNode:
    """A DAG node: one individually executable unit of work."""

    # Random short identifier of the form "tn-" + 8 hex characters.
    id: str = field(default_factory=lambda: "tn-" + uuid.uuid4().hex[:8])
    # Human-readable description of the work to perform.
    description: str = ""
    input_schema: dict = field(default_factory=dict)
    output_schema: dict = field(default_factory=dict)
    # Recommended type of agent to execute this node.
    agent_type: str = "default"
    # Estimated run time in seconds.
    estimated_duration_s: float = 0.0
    # Decomposition confidence in the range 0..1.
    confidence: float = 1.0
    # One of: once / retry_n / fallback.
    retry_policy: str = "once"
    max_retries: int = 1
    # Execution bookkeeping: state, result payload and error text.
    status: TaskNodeStatus = TaskNodeStatus.PENDING
    result: Any = None
    error: str = ""
@dataclass
class TaskDAG:
    """A complete task DAG: nodes keyed by ID plus directed dependency edges.

    An edge ``from_node -> to_node`` means ``to_node`` may only run once
    ``from_node`` has finished.
    """

    # Random short identifier of the form "dag-" + 8 hex characters.
    dag_id: str = field(default_factory=lambda: f"dag-{uuid.uuid4().hex[:8]}")
    # Original (undecomposed) task description.
    root_task: str = ""
    nodes: Dict[str, TaskNode] = field(default_factory=dict)
    edges: List[TaskEdge] = field(default_factory=list)
    strategy: DecompositionStrategy = DecompositionStrategy.TOP_DOWN
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Creation timestamp in seconds; 0.0 when not set.
    created_at: float = 0.0

    def in_degree_map(self) -> Dict[str, int]:
        """Return the number of incoming edges for every node.

        Every node in ``self.nodes`` has an entry (0 when nothing points at
        it). An edge whose target is not a known node still gets an entry, so
        the map never raises on a malformed DAG.
        """
        indeg: Dict[str, int] = {nid: 0 for nid in self.nodes}
        for e in self.edges:
            indeg[e.to_node] = indeg.get(e.to_node, 0) + 1
        return indeg

    def adjacency_map(self) -> Dict[str, List[str]]:
        """Return the adjacency list ``node ID -> list of successor IDs``.

        Every node in ``self.nodes`` has an entry. An edge whose source is not
        a known node gets its own entry instead of raising ``KeyError``, which
        keeps this method consistent with ``in_degree_map``.
        """
        adj: Dict[str, List[str]] = {nid: [] for nid in self.nodes}
        for e in self.edges:
            adj.setdefault(e.from_node, []).append(e.to_node)
        return adj

    def _unknown_endpoints(self) -> Set[str]:
        """Return the IDs referenced by edges that are not keys of ``self.nodes``."""
        return {
            endpoint
            for e in self.edges
            for endpoint in (e.from_node, e.to_node)
            if endpoint not in self.nodes
        }

    def topological_order(self) -> List[str]:
        """Return the node IDs in dependency order (Kahn's algorithm).

        Raises:
            ValueError: if an edge references an unknown node, or if the
                graph contains a cycle.
        """
        # Dangling edges would otherwise surface as a KeyError or as a
        # misleading "cycle" report, so reject them up front.
        unknown = self._unknown_endpoints()
        if unknown:
            raise ValueError(
                f"Edges reference unknown nodes: {sorted(unknown)[:5]}"
            )

        indeg = self.in_degree_map()
        adj = self.adjacency_map()
        queue = deque(nid for nid, d in indeg.items() if d == 0)
        order: List[str] = []

        while queue:
            node = queue.popleft()
            order.append(node)
            for neighbor in adj[node]:
                indeg[neighbor] -= 1
                if indeg[neighbor] == 0:
                    queue.append(neighbor)

        if len(order) != len(self.nodes):
            # Report only nodes that really sit on a cycle, not everything
            # that Kahn's algorithm could not reach.
            cyclic = sorted(self.detect_cycles())
            raise ValueError(
                f"Cycle detected in DAG. {len(cyclic)} nodes in cycle: "
                f"{cyclic[:5]}..."
            )

        return order

    def detect_cycles(self) -> Set[str]:
        """Return the IDs of all nodes that lie on at least one cycle.

        Uses an iterative version of Tarjan's strongly-connected-components
        algorithm: a node is cyclic when its component has more than one
        member or when it has a self-loop. Nodes that are merely downstream
        of a cycle are not reported. Edges that reference unknown nodes are
        ignored. The result is empty exactly when the graph is acyclic.
        """
        adj: Dict[str, List[str]] = {nid: [] for nid in self.nodes}
        for e in self.edges:
            if e.from_node in adj and e.to_node in adj:
                adj[e.from_node].append(e.to_node)

        index: Dict[str, int] = {}    # discovery order of each node
        lowlink: Dict[str, int] = {}  # smallest index reachable from the node
        stack: List[str] = []         # nodes of the components still open
        on_stack: Set[str] = set()
        cyclic: Set[str] = set()
        counter = 0

        for root in self.nodes:
            if root in index:
                continue
            index[root] = lowlink[root] = counter
            counter += 1
            stack.append(root)
            on_stack.add(root)
            # Explicit DFS stack of (node, iterator over its successors) so
            # deep graphs cannot hit the recursion limit.
            work = [(root, iter(adj[root]))]

            while work:
                node, successors = work[-1]
                descended = False
                for nxt in successors:
                    if nxt not in index:
                        index[nxt] = lowlink[nxt] = counter
                        counter += 1
                        stack.append(nxt)
                        on_stack.add(nxt)
                        work.append((nxt, iter(adj[nxt])))
                        descended = True
                        break
                    if nxt in on_stack:
                        lowlink[node] = min(lowlink[node], index[nxt])
                if descended:
                    continue

                # All successors of `node` are done: propagate and close.
                work.pop()
                if work:
                    parent = work[-1][0]
                    lowlink[parent] = min(lowlink[parent], lowlink[node])
                if lowlink[node] == index[node]:
                    component: List[str] = []
                    while True:
                        member = stack.pop()
                        on_stack.discard(member)
                        component.append(member)
                        if member == node:
                            break
                    if len(component) > 1 or node in adj[node]:
                        cyclic.update(component)

        return cyclic

    def parallel_groups(self) -> List[List[str]]:
        """Group node IDs by topological level.

        All nodes inside one group are independent of each other and may run
        in parallel; groups must run in the returned order. IDs inside a
        group are sorted.

        Raises:
            ValueError: propagated from ``topological_order`` when the graph
                is cyclic or has edges to unknown nodes.
        """
        # Called for validation only; it raises on any malformed graph.
        self.topological_order()

        indeg = self.in_degree_map()
        adj = self.adjacency_map()

        groups: List[List[str]] = []
        frontier = sorted(nid for nid, d in indeg.items() if d == 0)
        while frontier:
            groups.append(frontier)
            released: List[str] = []
            for nid in frontier:
                for neighbor in adj[nid]:
                    indeg[neighbor] -= 1
                    # A node joins the next level once its last
                    # prerequisite has been scheduled.
                    if indeg[neighbor] == 0:
                        released.append(neighbor)
            frontier = sorted(released)

        return groups
@dataclass
class DecompositionTrace:
    """One observability record describing a single decomposition step."""

    # Iteration counter value at the time the step was recorded.
    iteration: int
    # Kind of step: split / merge / refine / replan.
    action: str
    # Node as it was before the step (if any).
    node_before: Optional[TaskNode] = None
    # Nodes produced by the step.
    nodes_after: List[TaskNode] = field(default_factory=list)
    # Free-text explanation of why the step happened.
    reason: str = ""
173# ── Decomposer ───────────────────────────────
class TaskDecomposer:
    """Decompose a complex task into an executable DAG of sub-tasks.

    Core capabilities:
        1. Break a complex task down into an executable DAG.
        2. Detect and break cyclic dependencies.
        3. Re-plan dynamically after a partial failure.
        4. Record an observable decomposition trace.

    Usage:
        decomposer = TaskDecomposer()
        dag = decomposer.decompose("extract anomalies from 10GB of logs and write a daily report")
        order = dag.topological_order()
        for nid in order:
            execute(dag.nodes[nid])
        # after some nodes have failed
        new_dag = decomposer.replan(dag, failed_nodes=["tn-xxx"])
    """

    MAX_DEPTH = 8            # maximum recursion depth
    MIN_NODE_DURATION = 1.0  # nodes estimated below this many seconds are not split
    MAX_NODES = 50           # upper bound on the number of nodes in a DAG

    def __init__(
        self,
        strategy: DecompositionStrategy = DecompositionStrategy.RECURSIVE,
        llm_call: Optional[Callable] = None,
    ):
        """Create a decomposer.

        Args:
            strategy: strategy recorded on every DAG produced by ``decompose``.
            llm_call: optional callable that receives a prompt string and
                returns the sub-task list (a JSON string or an already parsed
                object). Without it, only heuristic rules are used.
        """
        self._strategy = strategy
        self._llm_call = llm_call
        self._trace: List[DecompositionTrace] = []
        self._iteration = 0

    # ── Public API ─────────────────────────

    def decompose(
        self,
        task: str,
        context: Optional[Dict[str, Any]] = None,
    ) -> TaskDAG:
        """Decompose ``task`` into a DAG.

        Resets the trace and the iteration counter before starting.

        Args:
            task: task description.
            context: extra context (trained agents, available tools, ...);
                stored as the DAG metadata and passed to the LLM prompt.

        Returns:
            The complete TaskDAG.
        """
        self._trace = []
        self._iteration = 0

        dag = TaskDAG(
            root_task=task,
            strategy=self._strategy,
            metadata=context or {},
            created_at=time.time(),
        )

        root = self._create_node(task, confidence=1.0)
        dag.nodes[root.id] = root

        # Recursive decomposition, starting from the single root node.
        self._decompose_recursive(dag, root.id, depth=0)

        # Make sure the result is acyclic.
        cycles = dag.detect_cycles()
        if cycles:
            logger.warning("Decomposition produced cycle: %s. Re-resolving.", cycles)
            dag = self._break_cycles(dag, cycles)

        return dag

    def replan(
        self,
        dag: TaskDAG,
        failed_nodes: List[str],
    ) -> TaskDAG:
        """Build a new DAG after some nodes have failed.

        Every failed node is replaced by a fresh "[Retry Plan]" node with
        reduced confidence and one more allowed retry. All edges of the failed
        node are rewired to its replacement (keeping their dependency type),
        so the downstream part of the plan stays reachable. Downstream nodes
        that have not completed are copied and reset to PENDING; all other
        nodes are carried over unchanged. The input ``dag`` is not modified.

        Args:
            dag: current DAG, including the status of executed nodes.
            failed_nodes: IDs of the failed nodes (duplicates are ignored).

        Returns:
            The re-planned DAG; only failed nodes and their downstream nodes
            differ from the input.

        Raises:
            KeyError: if a failed node ID is not part of ``dag``.
        """
        failed_ids = list(dict.fromkeys(failed_nodes))  # de-duplicate, keep order
        unknown = [nid for nid in failed_ids if nid not in dag.nodes]
        if unknown:
            raise KeyError(f"Unknown failed node(s): {unknown}")

        self._iteration += 1

        affected = self._collect_downstream(dag, failed_ids)

        new_dag = TaskDAG(
            dag_id=f"{dag.dag_id}-replan-{self._iteration}",
            root_task=dag.root_task,
            strategy=dag.strategy,
            metadata=dag.metadata,
            created_at=time.time(),
        )

        # Build one replacement node per failed node.
        replacements: Dict[str, TaskNode] = {}
        for nid in failed_ids:
            node = dag.nodes[nid]
            alt_node = self._create_node(
                f"[Retry Plan] {node.description}",
                confidence=node.confidence * 0.8,  # lower confidence after a failure
            )
            alt_node.retry_policy = "retry_n"
            alt_node.max_retries = node.max_retries + 1
            replacements[nid] = alt_node

            self._trace.append(DecompositionTrace(
                iteration=self._iteration,
                action="replan",
                node_before=node,
                nodes_after=[alt_node],
                reason=f"Node {nid} failed: {node.error or 'unknown'}",
            ))

        # Carry the nodes over, swapping failed ones for their replacement.
        for nid, node in dag.nodes.items():
            if nid in replacements:
                alt_node = replacements[nid]
                new_dag.nodes[alt_node.id] = alt_node
            elif nid in affected and node.status is not TaskNodeStatus.COMPLETED:
                # Downstream work depends on the retried node, so it has to be
                # schedulable again. A copy keeps the caller's DAG untouched.
                new_dag.nodes[nid] = replace(
                    node,
                    status=TaskNodeStatus.PENDING,
                    result=None,
                    error="",
                )
            else:
                new_dag.nodes[nid] = node

        # Rewire edges: upstream -> replacement and replacement -> downstream.
        id_map = {nid: alt.id for nid, alt in replacements.items()}
        for edge in dag.edges:
            src = id_map.get(edge.from_node, edge.from_node)
            dst = id_map.get(edge.to_node, edge.to_node)
            if src == edge.from_node and dst == edge.to_node:
                new_dag.edges.append(edge)
            else:
                new_dag.edges.append(TaskEdge(
                    from_node=src,
                    to_node=dst,
                    dependency_type=edge.dependency_type,
                ))

        return new_dag

    def get_trace(self) -> List[DecompositionTrace]:
        """Return a copy of the full decomposition trace (for observability)."""
        return list(self._trace)

    def validate_dag(self, dag: TaskDAG) -> Tuple[bool, str]:
        """Check the structural integrity of ``dag``.

        Isolated nodes are only logged as a warning; they do not make the
        DAG invalid.

        Returns:
            ``(is_valid, error_message)``; the message is ``"valid"`` on
            success.
        """
        # Empty DAG.
        if not dag.nodes:
            return False, "DAG has no nodes"

        # Edges must only reference known nodes; checked first because the
        # graph algorithms below are not guaranteed to cope with them.
        dangling = sorted({
            endpoint
            for e in dag.edges
            for endpoint in (e.from_node, e.to_node)
            if endpoint not in dag.nodes
        })
        if dangling:
            return False, f"Edges reference unknown nodes: {dangling}"

        # Cycle detection.
        cycles = dag.detect_cycles()
        if cycles:
            return False, f"DAG contains cycles: {cycles}"

        # Isolated node detection (nodes without any edge).
        connected: Set[str] = set()
        for e in dag.edges:
            connected.add(e.from_node)
            connected.add(e.to_node)
        isolated = set(dag.nodes) - connected
        if isolated and len(dag.nodes) > 1:
            logger.warning("Isolated nodes: %s", isolated)

        # Final defensive check: a full topological order must exist.
        try:
            dag.topological_order()
        except ValueError as e:
            return False, str(e)

        return True, "valid"

    # ── Internal ────────────────────────────

    def _create_node(self, description: str, confidence: float = 1.0) -> TaskNode:
        """Create a node whose duration estimate is 0.5 s per word, at least 1 s."""
        return TaskNode(
            description=description,
            confidence=confidence,
            estimated_duration_s=max(1.0, len(description.split()) * 0.5),
        )

    def _decompose_recursive(self, dag: TaskDAG, node_id: str, depth: int) -> None:
        """Recursively split ``node_id`` in place until nodes are atomic.

        The node is replaced by a sequential chain of sub-tasks; edges that
        pointed at the node are redirected to the first sub-task and edges
        leaving it start from the last one. A split is skipped whenever it
        would push the DAG beyond ``MAX_NODES`` or ``MAX_DEPTH``.
        """
        if depth >= self.MAX_DEPTH:
            return
        # Any split adds at least one node, so a full DAG cannot be split
        # further; checking here also avoids a pointless LLM call.
        if len(dag.nodes) >= self.MAX_NODES:
            return

        node = dag.nodes.get(node_id)
        if not node:
            return

        # Decide whether to keep decomposing.
        if self._is_atomic(node, depth):
            return

        self._iteration += 1

        # Ask the LLM or the heuristic rules for sub-tasks.
        sub_tasks = self._generate_sub_tasks(node, dag.metadata)

        if not sub_tasks or len(sub_tasks) <= 1:
            return  # cannot be decomposed any further

        # Replacing one node by k sub-tasks nets k - 1 additional nodes.
        if len(dag.nodes) - 1 + len(sub_tasks) > self.MAX_NODES:
            return

        # Edges of the node being replaced; a self-loop is dropped because
        # rewiring it would leave an edge pointing at the removed node.
        incoming = [
            e for e in dag.edges
            if e.to_node == node_id and e.from_node != node_id
        ]
        outgoing = [
            e for e in dag.edges
            if e.from_node == node_id and e.to_node != node_id
        ]
        dag.edges = [
            e for e in dag.edges
            if e.to_node != node_id and e.from_node != node_id
        ]

        # Swap the node for its sub-tasks, chained in sequence.
        del dag.nodes[node_id]
        for sub in sub_tasks:
            dag.nodes[sub.id] = sub
        for previous, following in zip(sub_tasks, sub_tasks[1:]):
            dag.edges.append(TaskEdge(from_node=previous.id, to_node=following.id))

        # Reconnect the original node's neighbours to the ends of the chain.
        first, last = sub_tasks[0], sub_tasks[-1]
        for e in incoming:
            dag.edges.append(TaskEdge(
                from_node=e.from_node,
                to_node=first.id,
                dependency_type=e.dependency_type,
            ))
        for e in outgoing:
            dag.edges.append(TaskEdge(
                from_node=last.id,
                to_node=e.to_node,
                dependency_type=e.dependency_type,
            ))

        self._trace.append(DecompositionTrace(
            iteration=self._iteration,
            action="split",
            node_before=node,
            nodes_after=sub_tasks,
            reason=f"Decomposed at depth {depth}",
        ))

        # Recurse into the sub-tasks; each call re-checks the node budget.
        for sub in sub_tasks:
            if sub.id in dag.nodes:
                self._decompose_recursive(dag, sub.id, depth + 1)

    def _is_atomic(self, node: TaskNode, depth: int) -> bool:
        """Return True when ``node`` should not be decomposed any further.

        NOTE(review): rule 3 counts whitespace-separated words, so text
        written without spaces (e.g. Chinese) is always treated as atomic —
        confirm whether that is intended.
        """
        # Rule 1: the estimated duration is already short enough.
        if node.estimated_duration_s < self.MIN_NODE_DURATION:
            return True
        # Rule 2: the recursion depth limit has been reached.
        if depth > self.MAX_DEPTH - 1:
            return True
        # Rule 3: the description is too simple (fewer than five words).
        if len(node.description.split()) < 5:
            return True
        return False

    def _generate_sub_tasks(
        self, node: TaskNode, context: Dict[str, Any]
    ) -> List[TaskNode]:
        """Generate the sub-tasks of ``node``.

        Uses the LLM callable when one was configured and the heuristic rules
        otherwise.
        """
        if self._llm_call:
            return self._llm_generate(node, context)
        return self._heuristic_generate(node)

    def _llm_generate(self, node: TaskNode, context: Dict[str, Any]) -> List[TaskNode]:
        """Generate sub-tasks through the configured LLM callable.

        The reply must be a JSON array (or an already parsed sequence) of
        objects with a textual ``description`` and optional ``agent_type``
        and ``estimated_duration_s``. Any failure — the call itself, invalid
        JSON, or malformed items — falls back to the heuristic rules.
        """
        context_json = json.dumps(context, default=str) if context else "None"
        prompt = (
            "Break down the following task into 2-5 subtasks.\n"
            "\n"
            f"Task: {node.description}\n"
            f"Context: {context_json}\n"
            "\n"
            "Output JSON array of subtasks, each with:\n"
            "- description: string\n"
            "- agent_type: string (default/planner/executor/analyst)\n"
            "- estimated_duration_s: float\n"
            "\n"
            "Only respond with the JSON array, no other text."
        )
        # The broad except is deliberate: the LLM is an untrusted boundary and
        # decomposition must degrade to heuristics instead of crashing.
        try:
            result = self._llm_call(prompt)
            items = json.loads(result) if isinstance(result, str) else result
            sub_tasks: List[TaskNode] = []
            for item in items:
                description = item["description"]
                # Later steps call str methods on the description, so reject
                # anything that is not usable text right here.
                if not isinstance(description, str) or not description.strip():
                    raise ValueError("sub-task is missing a textual description")
                sub_tasks.append(TaskNode(
                    description=description,
                    agent_type=str(item.get("agent_type", "default")),
                    # Coerced so that later numeric comparisons cannot fail.
                    estimated_duration_s=float(item.get("estimated_duration_s", 5.0)),
                    confidence=0.7,
                ))
            return sub_tasks
        except Exception as e:
            logger.warning("LLM decomposition failed: %s, falling back to heuristic", e)
            return self._heuristic_generate(node)

    def _heuristic_generate(self, node: TaskNode) -> List[TaskNode]:
        """Heuristic decomposition based on keyword / pattern matching.

        The first matching pattern wins; all generated sub-tasks get a
        confidence of 0.6 because heuristic splits are less reliable.
        """
        desc = node.description.lower()
        snippet = node.description[:60]
        descriptions: List[str] = []

        # Pattern 1: extract / collect -> analyze -> generate.
        if any(kw in desc for kw in ("extract", "collect", "fetch", "retrieve", "提取", "收集")):
            descriptions = [
                f"Phase 1: Collect data for: {snippet}",
                "Phase 2: Analyze/process collected data",
                f"Phase 3: Generate output/report for: {snippet}",
            ]

        # Pattern 2: comparison of two things.
        elif any(kw in desc for kw in ("compare", "vs", "对比", "比较", "versus")):
            parts = desc.replace("compare ", "").replace("对比 ", "").split(" vs ")
            if len(parts) < 2:
                parts = desc.replace("compare ", "").split(" and ")
            if len(parts) >= 2:
                descriptions = [
                    f"Analyze: {parts[0].strip()}",
                    f"Analyze: {parts[1].strip()}",
                    "Synthesize comparison results",
                ]

        # Pattern 3: transform / convert / migrate.
        elif any(kw in desc for kw in ("transform", "convert", "migrate", "转换", "迁移")):
            descriptions = [
                "Validate source data integrity",
                f"Execute transformation: {snippet}",
                "Verify output correctness",
            ]

        # Pattern 4: default step-wise split. Also used when a comparison
        # keyword matched but no two operands could be extracted, so such a
        # task is still decomposed instead of being silently left whole.
        if not descriptions:
            descriptions = [
                f"Plan: outline steps for '{snippet}'",
                f"Execute: carry out '{snippet}'",
                f"Validate: check results of '{snippet}'",
            ]

        subtasks = [self._create_node(text) for text in descriptions]
        for sub in subtasks:
            sub.confidence = 0.6  # heuristic splits deserve less confidence
        return subtasks

    def _collect_downstream(self, dag: TaskDAG, failed_nodes: List[str]) -> Set[str]:
        """Return the failed nodes plus every node reachable from them (BFS)."""
        adj = dag.adjacency_map()
        affected: Set[str] = set()

        queue = deque(failed_nodes)
        while queue:
            nid = queue.popleft()
            if nid in affected:
                continue
            affected.add(nid)
            queue.extend(n for n in adj.get(nid, []) if n not in affected)

        return affected

    def _break_cycles(self, dag: TaskDAG, cycles: Set[str]) -> TaskDAG:
        """Remove edges from ``dag`` (in place) until it is acyclic.

        Each round removes one edge that provably closes a cycle (a DFS back
        edge among the suspect nodes) and then re-checks the DAG, so several
        independent cycles are all resolved.

        Args:
            dag: the DAG to repair.
            cycles: node IDs suspected to be involved in cycles.

        Returns:
            The same ``dag`` object.
        """
        suspects = set(cycles)
        while suspects:
            edge = self._find_back_edge(dag, suspects)
            if edge is None:
                break  # nothing removable left among the suspects
            dag.edges.remove(edge)
            logger.info(
                "Removed edge %s→%s to break cycle", edge.from_node, edge.to_node
            )
            suspects = dag.detect_cycles()
        return dag

    @staticmethod
    def _find_back_edge(dag: TaskDAG, candidates: Set[str]) -> Optional[TaskEdge]:
        """Return an edge that closes a cycle inside ``candidates``, or None.

        Runs an iterative depth-first search over the sub-graph induced by
        ``candidates``; an edge pointing at a node on the current DFS path is
        a back edge and therefore part of a cycle.
        """
        out_edges: Dict[str, List[TaskEdge]] = {nid: [] for nid in candidates}
        for e in dag.edges:
            if e.from_node in out_edges and e.to_node in out_edges:
                out_edges[e.from_node].append(e)

        visited: Set[str] = set()
        for root in sorted(candidates):  # sorted for a deterministic choice
            if root in visited:
                continue
            visited.add(root)
            on_path = {root}
            stack = [(root, iter(out_edges[root]))]
            while stack:
                current, pending = stack[-1]
                edge = next(pending, None)
                if edge is None:
                    # `current` is fully explored: leave the DFS path.
                    stack.pop()
                    on_path.discard(current)
                    continue
                target = edge.to_node
                if target in on_path:
                    return edge
                if target not in visited:
                    visited.add(target)
                    on_path.add(target)
                    stack.append((target, iter(out_edges[target])))
        return None
540# ── Quick Start ──────────────────────────────
def create_decomposer(
    strategy: DecompositionStrategy = DecompositionStrategy.RECURSIVE,
    llm_call: Optional[Callable] = None,
) -> TaskDecomposer:
    """Quick-start factory: build a TaskDecomposer from the given settings.

    Args:
        strategy: decomposition strategy handed to the decomposer.
        llm_call: optional LLM callable handed to the decomposer.

    Returns:
        A new TaskDecomposer instance.
    """
    decomposer = TaskDecomposer(strategy=strategy, llm_call=llm_call)
    return decomposer