Coverage for agentos/orchestration/task_decomposer.py: 28%

279 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:36 +0800

1""" 

2AgentOS v1.14.7 — Intelligent Task Decomposer 2.0. 

3 

4Replaces the simplistic 234-line single-prompt decomposer with: 

5- DAG cycle detection (Kahn's + DFS fallback) 

6- Dynamic re-planning on partial failure 

7- Task dependency validation 

8- Parallelism detection (independent sub-tasks) 

9- Confidence scoring per sub-task 

10- Observable decomposition trace 

11 

12Architecture: 

13 TaskInput → Decomposer.decompose() → TaskDAG 

14 Partial failure → Decomposer.replan(failed_node) → new_path 

15""" 

16 

from __future__ import annotations

import json
import logging
import time
import uuid
from collections import deque
from dataclasses import dataclass, field, replace
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

26 

27logger = logging.getLogger(__name__) 

28 

29 

30# ── Types ──────────────────────────────────── 

31 

32 

class TaskNodeStatus(str, Enum):
    """Lifecycle state of a single DAG node.

    Members are also ``str`` instances, so they compare equal to their
    plain string value.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"

39 

40 

class DecompositionStrategy(str, Enum):
    """Strategy label attached to a decomposition.

    Members are also ``str`` instances, so they compare equal to their
    plain string value.
    """

    TOP_DOWN = "top_down"  # break the goal down level by level
    BOTTOM_UP = "bottom_up"  # aggregate upwards from sub-tasks
    RECURSIVE = "recursive"  # keep splitting until tasks are atomic
    HEURISTIC = "heuristic"  # rule / pattern-matching based

47 

48 

@dataclass
class TaskEdge:
    """Directed DAG edge: ``to_node`` may only run after ``from_node`` finishes."""

    from_node: str
    to_node: str
    # Either "hard" or "soft".
    dependency_type: str = "hard"

55 

56 

@dataclass
class TaskNode:
    """DAG node: one executable unit of work."""

    # Auto-generated as "tn-" followed by 8 hex characters of a random UUID.
    id: str = field(default_factory=lambda: "tn-" + uuid.uuid4().hex[:8])
    description: str = ""
    input_schema: dict = field(default_factory=dict)
    output_schema: dict = field(default_factory=dict)
    # Recommended type of agent to execute this node.
    agent_type: str = "default"
    estimated_duration_s: float = 0.0
    # Decomposition confidence, in the range 0..1.
    confidence: float = 1.0
    # One of: once / retry_n / fallback.
    retry_policy: str = "once"
    max_retries: int = 1
    status: TaskNodeStatus = TaskNodeStatus.PENDING
    result: Any = None
    error: str = ""

72 

73 

@dataclass
class TaskDAG:
    """A complete task DAG: nodes keyed by ID plus directed dependency edges.

    An edge ``from_node -> to_node`` means ``to_node`` may only start once
    ``from_node`` has finished.

    Edges whose endpoints are not both present in ``nodes`` ("dangling"
    edges) are ignored by every graph query on this class.  This keeps a
    stale edge from crashing a traversal with ``KeyError`` or from being
    misreported as a dependency cycle.
    """

    dag_id: str = field(default_factory=lambda: f"dag-{uuid.uuid4().hex[:8]}")
    root_task: str = ""  # original task description
    nodes: Dict[str, TaskNode] = field(default_factory=dict)
    edges: List[TaskEdge] = field(default_factory=list)
    strategy: DecompositionStrategy = DecompositionStrategy.TOP_DOWN
    metadata: Dict[str, Any] = field(default_factory=dict)
    created_at: float = 0.0

    def _known_edges(self) -> List[TaskEdge]:
        """Return the edges whose two endpoints both exist in ``nodes``."""
        nodes = self.nodes
        return [e for e in self.edges if e.from_node in nodes and e.to_node in nodes]

    def in_degree_map(self) -> Dict[str, int]:
        """Return ``{node_id: number of incoming edges}`` for every node."""
        indeg: Dict[str, int] = dict.fromkeys(self.nodes, 0)
        for edge in self._known_edges():
            indeg[edge.to_node] += 1
        return indeg

    def adjacency_map(self) -> Dict[str, List[str]]:
        """Return ``{node_id: [successor ids]}`` for every node."""
        adj: Dict[str, List[str]] = {nid: [] for nid in self.nodes}
        for edge in self._known_edges():
            adj[edge.from_node].append(edge.to_node)
        return adj

    def _kahn_order(self) -> List[str]:
        """Run Kahn's algorithm and return every node it manages to order.

        Nodes that sit on a cycle, or that are only reachable through one,
        never reach in-degree zero and are therefore absent from the result.
        """
        indeg = self.in_degree_map()
        adj = self.adjacency_map()
        ready = deque(nid for nid, degree in indeg.items() if degree == 0)
        order: List[str] = []

        while ready:
            current = ready.popleft()
            order.append(current)
            for successor in adj[current]:
                indeg[successor] -= 1
                if indeg[successor] == 0:
                    ready.append(successor)

        return order

    def topological_order(self) -> List[str]:
        """Return the node IDs in topological order (Kahn's algorithm).

        Raises:
            ValueError: if some nodes cannot be ordered because the graph
                contains a cycle.  The message lists up to five of the
                unordered node IDs, sorted for reproducibility.
        """
        order = self._kahn_order()
        if len(order) != len(self.nodes):
            remaining = sorted(set(self.nodes) - set(order))
            raise ValueError(
                f"Cycle detected in DAG. {len(remaining)} nodes in cycle: "
                f"{remaining[:5]}..."
            )
        return order

    def detect_cycles(self) -> Set[str]:
        """Return the IDs of all nodes that cannot be topologically ordered.

        That is every node on a cycle plus every node reachable only through
        a cycle.  An empty set means the graph is acyclic.
        """
        return set(self.nodes) - set(self._kahn_order())

    def parallel_groups(self) -> List[List[str]]:
        """Group nodes by topological level; nodes in one group can run in parallel.

        Each group is sorted by node ID.

        Raises:
            ValueError: if the graph contains a cycle.
        """
        self.topological_order()  # called only for its cycle check

        indeg = self.in_degree_map()
        adj = self.adjacency_map()

        groups: List[List[str]] = []
        frontier = sorted(nid for nid, degree in indeg.items() if degree == 0)
        while frontier:
            groups.append(frontier)
            released: List[str] = []
            for nid in frontier:
                for successor in adj[nid]:
                    indeg[successor] -= 1
                    if indeg[successor] == 0:
                        released.append(successor)
            frontier = sorted(released)

        return groups

161 

162 

@dataclass
class DecompositionTrace:
    """One observability record of a decomposition step."""

    iteration: int
    # One of: split / merge / refine / replan.
    action: str
    node_before: Optional["TaskNode"] = None
    nodes_after: List["TaskNode"] = field(default_factory=list)
    reason: str = ""

171 

172 

173# ── Decomposer ─────────────────────────────── 

174 

175 

class TaskDecomposer:
    """Intelligent task decomposer.

    Capabilities:
        1. Break a complex task into an executable DAG.
        2. Detect and remove cyclic dependencies.
        3. Re-plan dynamically after a partial failure.
        4. Record an observable trace of every decomposition step.

    Usage:
        decomposer = TaskDecomposer()
        dag = decomposer.decompose("Extract anomalies from 10 GB of logs and write a daily report")
        order = dag.topological_order()
        for nid in order:
            execute(dag.nodes[nid])
        # after some nodes have failed
        new_dag = decomposer.replan(dag, failed_nodes=["tn-xxx"])
    """

    MAX_DEPTH = 8  # maximum recursion depth
    MIN_NODE_DURATION = 1.0  # nodes estimated below this many seconds are not split
    MAX_NODES = 50  # upper bound on the number of nodes in one DAG

    def __init__(
        self,
        strategy: DecompositionStrategy = DecompositionStrategy.RECURSIVE,
        llm_call: Optional[Callable] = None,
    ):
        """Create a decomposer.

        Args:
            strategy: Label stored on every DAG produced by ``decompose``.
                The splitting itself always uses the recursive splitter.
            llm_call: Optional callable that receives a prompt string and
                returns either a JSON string or an already-parsed list of
                sub-task dicts.  Without it, keyword heuristics are used.
        """
        self._strategy = strategy
        self._llm_call = llm_call
        self._trace: List[DecompositionTrace] = []
        self._iteration = 0
        # IDs of nodes produced by the heuristic templates; see _is_atomic.
        self._template_ids: Set[str] = set()

    # ── Public API ─────────────────────────

    def decompose(
        self,
        task: str,
        context: Optional[Dict[str, Any]] = None,
    ) -> TaskDAG:
        """Decompose a task into a DAG.

        Resets the trace and iteration counter of this decomposer.

        Args:
            task: Task description.
            context: Extra context (available agents, tools, ...); stored as
                the DAG's ``metadata`` and passed on to the LLM prompt.

        Returns:
            The resulting TaskDAG.
        """
        self._trace = []
        self._iteration = 0
        self._template_ids = set()

        dag = TaskDAG(
            root_task=task,
            strategy=self._strategy,
            metadata=context or {},
            created_at=time.time(),
        )

        root = self._create_node(task, confidence=1.0)
        dag.nodes[root.id] = root

        self._decompose_recursive(dag, root.id, depth=0)

        # Make sure the result really is acyclic.
        cycles = dag.detect_cycles()
        if cycles:
            logger.warning("Decomposition produced cycle: %s. Re-resolving.", cycles)
            dag = self._break_cycles(dag, cycles)

        return dag

    def replan(
        self,
        dag: TaskDAG,
        failed_nodes: List[str],
    ) -> TaskDAG:
        """Build a new DAG after some nodes failed.

        Every failed node is replaced by a fresh "[Retry Plan]" node with
        reduced confidence and one more allowed retry.  All edges that
        touched a failed node are re-routed to its replacement, keeping their
        ``dependency_type``.  Nodes downstream of a failure are kept, as
        copies reset to PENDING, so the re-planned DAG can still complete the
        original task.  Nodes unaffected by the failure are carried over
        as-is.  The input DAG and its nodes are not modified.

        Args:
            dag: Current DAG, including execution state of its nodes.
            failed_nodes: IDs of the nodes that failed; duplicates are ignored.

        Returns:
            The re-planned DAG.

        Raises:
            KeyError: if any ID in ``failed_nodes`` is not a node of ``dag``.
                Raised before any decomposer state is changed.
        """
        failed = list(dict.fromkeys(failed_nodes))  # de-duplicate, keep order
        unknown = [nid for nid in failed if nid not in dag.nodes]
        if unknown:
            raise KeyError(f"Unknown failed node id(s): {unknown}")

        self._iteration += 1
        failed_set = set(failed)
        affected = self._collect_downstream(dag, failed)

        new_dag = TaskDAG(
            dag_id=f"{dag.dag_id}-replan-{self._iteration}",
            root_task=dag.root_task,
            strategy=dag.strategy,
            metadata=dag.metadata,
            created_at=time.time(),
        )

        for nid, node in dag.nodes.items():
            if nid in failed_set:
                continue  # replaced by a retry node below
            if nid in affected:
                # Downstream of a failure: has to run (again) in the new plan.
                new_dag.nodes[nid] = replace(
                    node, status=TaskNodeStatus.PENDING, result=None, error=""
                )
            else:
                new_dag.nodes[nid] = node

        # failed node id -> id of the node that replaces it
        substitute: Dict[str, str] = {}
        for nid in failed:
            node = dag.nodes[nid]
            alt_node = self._create_node(
                f"[Retry Plan] {node.description}",
                confidence=node.confidence * 0.8,  # lower confidence after a failure
            )
            alt_node.retry_policy = "retry_n"
            alt_node.max_retries = node.max_retries + 1
            new_dag.nodes[alt_node.id] = alt_node
            substitute[nid] = alt_node.id

            self._trace.append(DecompositionTrace(
                iteration=self._iteration,
                action="replan",
                node_before=node,
                nodes_after=[alt_node],
                reason=f"Node {nid} failed: {node.error or 'unknown'}",
            ))

        for edge in dag.edges:
            src = substitute.get(edge.from_node, edge.from_node)
            dst = substitute.get(edge.to_node, edge.to_node)
            if src not in new_dag.nodes or dst not in new_dag.nodes:
                continue  # edge was already dangling in the input DAG
            if src == edge.from_node and dst == edge.to_node:
                new_dag.edges.append(edge)
            else:
                new_dag.edges.append(TaskEdge(
                    from_node=src,
                    to_node=dst,
                    dependency_type=edge.dependency_type,
                ))

        return new_dag

    def get_trace(self) -> List[DecompositionTrace]:
        """Return a copy of the recorded decomposition trace."""
        return list(self._trace)

    def validate_dag(self, dag: TaskDAG) -> Tuple[bool, str]:
        """Check the structural integrity of a DAG.

        Isolated nodes are only logged as a warning; they do not make the
        DAG invalid.

        Returns:
            ``(is_valid, message)``; the message is ``"valid"`` on success.
        """
        if not dag.nodes:
            return False, "DAG has no nodes"

        # Reject edges that point at unknown nodes before running any graph
        # algorithm, so a malformed DAG yields a result instead of an exception.
        dangling = [
            (e.from_node, e.to_node)
            for e in dag.edges
            if e.from_node not in dag.nodes or e.to_node not in dag.nodes
        ]
        if dangling:
            return False, f"DAG has edges referencing unknown nodes: {dangling[:5]}"

        cycles = dag.detect_cycles()
        if cycles:
            return False, f"DAG contains cycles: {cycles}"

        connected: Set[str] = set()
        for e in dag.edges:
            connected.add(e.from_node)
            connected.add(e.to_node)
        isolated = set(dag.nodes) - connected
        if isolated and len(dag.nodes) > 1:
            logger.warning("Isolated nodes: %s", isolated)

        return True, "valid"

    # ── Internal ────────────────────────────

    @staticmethod
    def _estimate_words(text: str) -> int:
        """Estimate the number of words in ``text``.

        Text without CJK ideographs is counted as whitespace-separated
        tokens.  Chinese is written without spaces, so each ideograph
        (U+4E00..U+9FFF) is treated as a separator and counted as half a
        word (rounded up); this is a rough heuristic, not a tokenizer.
        """
        cjk_chars = 0
        pieces: List[str] = []
        for ch in text:
            if "\u4e00" <= ch <= "\u9fff":
                cjk_chars += 1
                pieces.append(" ")
            else:
                pieces.append(ch)
        if not cjk_chars:
            return len(text.split())
        return len("".join(pieces).split()) + (cjk_chars + 1) // 2

    def _create_node(self, description: str, confidence: float = 1.0) -> TaskNode:
        """Create a node whose duration estimate is 0.5 s per word, at least 1 s."""
        return TaskNode(
            description=description,
            confidence=confidence,
            estimated_duration_s=max(1.0, self._estimate_words(description) * 0.5),
        )

    def _decompose_recursive(self, dag: TaskDAG, node_id: str, depth: int):
        """Recursively split a node, in place, until its parts are atomic.

        The node is replaced by a sequential chain of sub-tasks.  Its former
        incoming edges are attached to the first sub-task and its outgoing
        edges to the last one.  A split that would push the DAG beyond
        ``MAX_NODES`` is not performed.
        """
        if depth >= self.MAX_DEPTH:
            return

        node = dag.nodes.get(node_id)
        if node is None or self._is_atomic(node, depth):
            return

        # Ask the LLM, or fall back to heuristic rules, for sub-tasks.
        sub_tasks = self._generate_sub_tasks(node, dag.metadata)
        if not sub_tasks or len(sub_tasks) <= 1:
            return  # cannot be split any further

        # The split removes one node and adds len(sub_tasks).
        if len(dag.nodes) - 1 + len(sub_tasks) > self.MAX_NODES:
            return

        self._iteration += 1

        # Edges to re-attach; a self-loop on the node is dropped.
        incoming = [e for e in dag.edges if e.to_node == node_id and e.from_node != node_id]
        outgoing = [e for e in dag.edges if e.from_node == node_id and e.to_node != node_id]
        dag.edges = [e for e in dag.edges if e.to_node != node_id and e.from_node != node_id]

        del dag.nodes[node_id]
        for sub in sub_tasks:
            dag.nodes[sub.id] = sub

        first, last = sub_tasks[0], sub_tasks[-1]
        # Sub-tasks run sequentially.
        dag.edges.extend(
            TaskEdge(from_node=prev.id, to_node=nxt.id)
            for prev, nxt in zip(sub_tasks, sub_tasks[1:])
        )
        dag.edges.extend(
            TaskEdge(from_node=e.from_node, to_node=first.id, dependency_type=e.dependency_type)
            for e in incoming
        )
        dag.edges.extend(
            TaskEdge(from_node=last.id, to_node=e.to_node, dependency_type=e.dependency_type)
            for e in outgoing
        )

        self._trace.append(DecompositionTrace(
            iteration=self._iteration,
            action="split",
            node_before=node,
            nodes_after=sub_tasks,
            reason=f"Decomposed at depth {depth}",
        ))

        for sub in sub_tasks:
            # Re-check the cap for every child: earlier siblings may have grown the DAG.
            if len(dag.nodes) >= self.MAX_NODES:
                break
            if sub.id in dag.nodes:
                self._decompose_recursive(dag, sub.id, depth + 1)

    def _is_atomic(self, node: TaskNode, depth: int) -> bool:
        """Return True when the node should not be split any further."""
        # Rule 0: heuristic template nodes are terminal.  Their text embeds
        # the parent description and the template keywords, so splitting
        # them again would only nest the same templates inside each other.
        if node.id in self._template_ids:
            return True
        # Rule 1: the estimated duration is already short enough.
        if node.estimated_duration_s < self.MIN_NODE_DURATION:
            return True
        # Rule 2: the recursion depth limit has been reached.
        if depth > self.MAX_DEPTH - 1:
            return True
        # Rule 3: the description is too short to be more than one step.
        if self._estimate_words(node.description) < 5:
            return True
        return False

    def _generate_sub_tasks(
        self, node: TaskNode, context: Dict[str, Any]
    ) -> List[TaskNode]:
        """Generate sub-tasks for a node: LLM if configured, else heuristics."""
        if self._llm_call:
            return self._llm_generate(node, context)
        return self._heuristic_generate(node)

    @staticmethod
    def _node_from_llm_item(item: Any) -> TaskNode:
        """Build a node from one LLM sub-task entry.

        Raises an exception if the entry is malformed, so that the caller
        can fall back to the heuristic instead of storing a broken node.
        """
        description = item["description"]
        if not isinstance(description, str) or not description.strip():
            raise ValueError(f"invalid subtask description: {description!r}")
        return TaskNode(
            description=description,
            agent_type=item.get("agent_type", "default"),
            # float() rejects e.g. None or "soon" here rather than letting a
            # later numeric comparison on the node fail.
            estimated_duration_s=float(item.get("estimated_duration_s", 5.0)),
            confidence=0.7,
        )

    def _llm_generate(self, node: TaskNode, context: Dict[str, Any]) -> List[TaskNode]:
        """Generate sub-tasks via the configured LLM callable.

        Any failure (callable error, invalid JSON, malformed entries) is
        logged and answered with the heuristic decomposition instead.
        """
        context_text = json.dumps(context, default=str) if context else "None"
        prompt = (
            "Break down the following task into 2-5 subtasks.\n"
            "\n"
            f"Task: {node.description}\n"
            f"Context: {context_text}\n"
            "\n"
            "Output JSON array of subtasks, each with:\n"
            "- description: string\n"
            "- agent_type: string (default/planner/executor/analyst)\n"
            "- estimated_duration_s: float\n"
            "\n"
            "Only respond with the JSON array, no other text."
        )
        try:
            result = self._llm_call(prompt)
            items = json.loads(result) if isinstance(result, str) else result
            if not isinstance(items, (list, tuple)):
                raise ValueError(f"expected an array of subtasks, got {type(items).__name__}")
            return [self._node_from_llm_item(item) for item in items]
        except Exception as e:  # llm_call is arbitrary user code: catch broadly, log, degrade
            logger.warning("LLM decomposition failed: %s, falling back to heuristic", e)
            return self._heuristic_generate(node)

    def _heuristic_generate(self, node: TaskNode) -> List[TaskNode]:
        """Heuristic decomposition based on keyword matching.

        Returns an empty list when a comparison task cannot be split into
        two sides.  Every returned node gets confidence 0.6 and is recorded
        as a terminal template node.
        """
        desc = node.description.lower()
        topic = node.description[:60]

        # Pattern 1: extract/collect -> analyze -> generate
        if any(kw in desc for kw in ("extract", "collect", "fetch", "retrieve", "提取", "收集")):
            texts = [
                f"Phase 1: Collect data for: {topic}",
                "Phase 2: Analyze/process collected data",
                f"Phase 3: Generate output/report for: {topic}",
            ]

        # Pattern 2: comparison
        elif any(kw in desc for kw in ("compare", "vs", "对比", "比较", "versus")):
            parts = desc.replace("compare ", "").replace("对比 ", "").split(" vs ")
            if len(parts) < 2:
                parts = desc.replace("compare ", "").split(" and ")
            if len(parts) >= 2:
                texts = [
                    f"Analyze: {parts[0].strip()}",
                    f"Analyze: {parts[1].strip()}",
                    "Synthesize comparison results",
                ]
            else:
                texts = []

        # Pattern 3: transform/convert/migrate
        elif any(kw in desc for kw in ("transform", "convert", "migrate", "转换", "迁移")):
            texts = [
                "Validate source data integrity",
                f"Execute transformation: {topic}",
                "Verify output correctness",
            ]

        # Pattern 4: default plan / execute / validate split
        else:
            texts = [
                f"Plan: outline steps for '{topic}'",
                f"Execute: carry out '{topic}'",
                f"Validate: check results of '{topic}'",
            ]

        # Heuristic splits get a lower confidence than LLM ones.
        subtasks = [self._create_node(text, confidence=0.6) for text in texts]
        self._template_ids.update(sub.id for sub in subtasks)
        return subtasks

    def _collect_downstream(self, dag: TaskDAG, failed_nodes: List[str]) -> Set[str]:
        """Return the failed nodes plus every node reachable from them."""
        adj = dag.adjacency_map()
        affected: Set[str] = set()

        queue = deque(failed_nodes)
        while queue:
            nid = queue.popleft()
            if nid in affected:
                continue
            affected.add(nid)
            queue.extend(nb for nb in adj.get(nid, []) if nb not in affected)

        return affected

    @staticmethod
    def _find_back_edge(dag: TaskDAG, candidates: Set[str]) -> Optional[TaskEdge]:
        """Return an edge that closes a cycle among ``candidates``, or None.

        Runs an iterative depth-first search over the edges whose endpoints
        are both in ``candidates``.  An edge pointing at a node that is still
        on the current DFS path is a back edge and therefore lies on a cycle.
        """
        unvisited, on_path, done = 0, 1, 2
        out_edges: Dict[str, List[TaskEdge]] = {nid: [] for nid in candidates}
        for edge in dag.edges:
            if edge.from_node in candidates and edge.to_node in candidates:
                out_edges[edge.from_node].append(edge)

        state = dict.fromkeys(candidates, unvisited)
        for start in sorted(candidates):  # sorted for a deterministic choice
            if state[start] != unvisited:
                continue
            state[start] = on_path
            stack = [(start, iter(out_edges[start]))]
            while stack:
                nid, pending = stack[-1]
                edge = next(pending, None)
                if edge is None:
                    state[nid] = done
                    stack.pop()
                elif state[edge.to_node] == on_path:
                    return edge
                elif state[edge.to_node] == unvisited:
                    state[edge.to_node] = on_path
                    stack.append((edge.to_node, iter(out_edges[edge.to_node])))
        return None

    def _break_cycles(self, dag: TaskDAG, cycles: Set[str]) -> TaskDAG:
        """Remove cycle-closing edges, in place, until the DAG is acyclic.

        Args:
            dag: The DAG to repair; its ``edges`` list is modified.
            cycles: Node IDs reported by ``dag.detect_cycles()``.

        Returns:
            The same ``dag`` object.
        """
        while cycles:
            edge = self._find_back_edge(dag, cycles)
            if edge is None:
                break  # nothing removable left; avoid looping forever
            dag.edges.remove(edge)
            logger.info("Removed edge %s→%s to break cycle", edge.from_node, edge.to_node)
            cycles = dag.detect_cycles()
        return dag

538 

539 

540# ── Quick Start ────────────────────────────── 

541 

542 

def create_decomposer(
    strategy: DecompositionStrategy = DecompositionStrategy.RECURSIVE,
    llm_call: Optional[Callable] = None,
) -> TaskDecomposer:
    """Convenience factory: build a TaskDecomposer from the given options.

    Args:
        strategy: Strategy passed through to the decomposer.
        llm_call: Optional LLM callable passed through to the decomposer.
    """
    decomposer = TaskDecomposer(strategy=strategy, llm_call=llm_call)
    return decomposer