Coverage for agentos/orchestration/graph_executor.py: 32%

184 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 13:55 +0800

1""" 

2Agent Graph — DAG-based multi-agent execution engine. 

3 

4Build complex agent pipelines as directed acyclic graphs where each node 

5is an agent invocation and edges define data flow dependencies. 

6 

7v1.16.1: Integrated MetricsCollector for per-node execution tracking. 

8""" 

9 

10from __future__ import annotations 

11 

12import time 

13from collections import deque 

14from dataclasses import dataclass, field 

15from enum import Enum 

16from typing import Any, Callable, Optional 

17 

18from agentos.tools.metrics import MetricsCollector 

19 

20 

class GraphNodeState(Enum):
    """Execution state of a graph orchestrator node."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"


@dataclass
class GraphNode:
    """A single node in the agent execution graph.

    The first block of fields is the static definition of the node; the
    second block (``state``, ``output``, ``error``, ``latency_ms``) holds
    per-run results and starts out at neutral defaults.
    """

    name: str
    agent_type: str
    # Template string with {input} or {node_name.output} placeholders.
    task_template: str

    # Node names this node depends on.
    depends_on: list[str] = field(default_factory=list)

    # Execution policy knobs; they are stored here and interpreted by
    # whatever engine runs the node, not by this class.
    timeout_seconds: float = 120.0
    retry_count: int = 0
    # Action on failure: 'abort', 'skip', 'continue'.
    on_failure: str = "abort"

    # Per-run results.
    state: GraphNodeState = GraphNodeState.PENDING
    output: Any = None
    error: Optional[str] = None
    latency_ms: float = 0.0

    def resolve_task(self, node_outputs: dict[str, Any]) -> str:
        """Resolve template placeholders using outputs from completed nodes.

        ``{input}`` is replaced by ``node_outputs["__input__"]`` (empty string
        when absent) and ``{<key>.output}`` by ``str(node_outputs[<key>])`` for
        every key present in ``node_outputs``. Placeholders whose key is not in
        ``node_outputs`` are left untouched.

        Args:
            node_outputs: Mapping of node name to its output, plus the
                ``"__input__"`` entry holding the graph input.

        Returns:
            The task string with all known placeholders substituted.
        """
        import re  # local import: keeps this block self-contained

        values: dict[str, Any] = {"{input}": node_outputs.get("__input__", "")}
        for key, value in node_outputs.items():
            values[f"{{{key}.output}}"] = value

        # Substitute in ONE pass over the template. Chained str.replace calls
        # would re-scan text that was just inserted, so an input or a node
        # output that happens to contain "{other.output}" would be expanded
        # again (and the result would depend on dict ordering).
        # Longest placeholder first so a longer token is never shadowed by a
        # shorter alternative matching at the same position.
        alternatives = sorted(values, key=len, reverse=True)
        pattern = re.compile("|".join(re.escape(token) for token in alternatives))
        return pattern.sub(lambda m: str(values[m.group(0)]), self.task_template)

61 

62 

@dataclass
class GraphResult:
    """Outcome of one graph run.

    Every field has a default, so a failure can be reported with just
    ``GraphResult(success=False, error="...")``.
    """

    # Node objects keyed by node name.
    node_results: dict[str, GraphNode] = field(default_factory=dict)
    # Node names in the order they were recorded during the run.
    execution_order: list[str] = field(default_factory=list)
    # Duration of the whole run, in milliseconds.
    total_latency_ms: float = 0.0
    # Overall verdict for the run.
    success: bool = True
    # Error message for the run as a whole, when there is one.
    error: Optional[str] = None

72 

73 

class AgentGraph:
    """
    DAG-based multi-agent execution engine.

    Define execution graphs declaratively, resolve dependencies automatically,
    and execute nodes one at a time in topological order.

    A node's ``timeout_seconds`` is not enforced by this class.

    Example::

        graph = AgentGraph()
        graph.add_node(GraphNode(
            name="research",
            agent_type="researcher",
            task_template="Research: {input}"
        ))
        graph.add_node(GraphNode(
            name="summarize",
            agent_type="summarizer",
            task_template="Summarize: {research.output}",
            depends_on=["research"]
        ))
        result = graph.execute("quantum computing advances")
    """

    def __init__(self, executor: Optional[Callable[[str, str], Any]] = None, metrics: Optional[MetricsCollector] = None):
        """
        Args:
            executor: Callable(agent_type, task) -> output. If not provided,
                subclasses must override _execute_node.
            metrics: Optional MetricsCollector for per-node execution tracking.
        """
        self._nodes: dict[str, GraphNode] = {}
        self._executor = executor
        self._metrics = metrics

    def add_node(self, node: GraphNode) -> None:
        """Add a node to the graph. Raises ValueError on duplicate name."""
        if node.name in self._nodes:
            raise ValueError(f"Duplicate node name: {node.name}")
        self._nodes[node.name] = node

    def remove_node(self, name: str) -> None:
        """Remove a node and all edges referencing it.

        Raises:
            KeyError: If no node with that name exists.
        """
        if name not in self._nodes:
            raise KeyError(f"Node not found: {name}")
        del self._nodes[name]
        # Drop the removed node from every remaining dependency list.
        for node in self._nodes.values():
            node.depends_on = [d for d in node.depends_on if d != name]

    def validate(self) -> list[str]:
        """
        Validate graph integrity.

        Checks for dependencies on unknown nodes and self-dependencies; only
        when none are found is the graph also checked for cycles.

        Returns:
            List of error messages (empty if valid).
        """
        errors: list[str] = []

        for name, node in self._nodes.items():
            for dep in node.depends_on:
                if dep not in self._nodes:
                    errors.append(f"Node '{name}' depends on unknown node '{dep}'")
                if dep == name:
                    errors.append(f"Node '{name}' cannot depend on itself")

        # Check for cycles using topological sort
        if not errors:
            try:
                self._topological_order()
            except ValueError as e:
                errors.append(str(e))

        return errors

    def _topological_order(self) -> list[str]:
        """Return nodes in topological order.

        Raises:
            ValueError: On a cycle, or when a node depends on a name that is
                not in the graph.
        """
        in_degree: dict[str, int] = {name: 0 for name in self._nodes}
        adjacency: dict[str, list[str]] = {name: [] for name in self._nodes}

        for name, node in self._nodes.items():
            for dep in node.depends_on:
                if dep not in adjacency:
                    # Keep the documented ValueError contract instead of
                    # leaking a KeyError from the adjacency lookup.
                    raise ValueError(f"Node '{name}' depends on unknown node '{dep}'")
                adjacency[dep].append(name)
                in_degree[name] += 1

        # Kahn's algorithm: repeatedly emit nodes with no unmet dependencies.
        queue = deque(name for name, deg in in_degree.items() if deg == 0)
        order: list[str] = []

        while queue:
            current = queue.popleft()
            order.append(current)
            for neighbor in adjacency[current]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if len(order) != len(self._nodes):
            remaining = set(self._nodes) - set(order)
            raise ValueError(f"Cycle detected involving nodes: {remaining}")

        return order

    def execute(self, input_data: str) -> GraphResult:
        """
        Execute the graph with given input.

        Nodes run sequentially in topological order. A node is marked SKIPPED
        without running when any of its dependencies FAILED or was itself
        SKIPPED, so a failure propagates to all transitive dependents. A node
        whose ``on_failure`` is ``"abort"`` additionally causes every node
        later in the order to be SKIPPED; ``"skip"`` and ``"continue"`` are
        not distinguished here. A failing node is retried up to
        ``retry_count`` additional times before it is marked FAILED.

        Args:
            input_data: Initial task input, accessible as {input} in templates.

        Returns:
            GraphResult with per-node outputs and execution metadata.
            ``success`` is True only if no node ended in a state other than
            COMPLETED or SKIPPED.
        """
        errors = self.validate()
        if errors:
            return GraphResult(success=False, error="; ".join(errors))

        t0 = time.perf_counter()
        node_outputs: dict[str, Any] = {"__input__": input_data}
        results: dict[str, GraphNode] = {}
        order: list[str] = []

        # Reset all nodes so results from a previous run do not leak in.
        for node in self._nodes.values():
            node.state = GraphNodeState.PENDING
            node.output = None
            node.error = None
            node.latency_ms = 0.0

        try:
            topo = self._topological_order()
        except ValueError as e:
            return GraphResult(success=False, error=str(e))

        # Dependency states that prevent a node from running. SKIPPED is
        # included so that dependents of a skipped node are skipped as well,
        # rather than run with an unresolved {dep.output} placeholder.
        blocking = (GraphNodeState.FAILED, GraphNodeState.SKIPPED)

        abort = False
        for name in topo:
            node = self._nodes[name]
            results[name] = node

            if abort:
                node.state = GraphNodeState.SKIPPED
                continue

            order.append(name)

            # Topological order guarantees every dependency is in `results`.
            if any(results[dep].state in blocking for dep in node.depends_on):
                node.state = GraphNodeState.SKIPPED
                continue

            task = node.resolve_task(node_outputs)
            attempts = max(0, int(node.retry_count or 0)) + 1
            node_t0 = time.perf_counter()
            node.state = GraphNodeState.RUNNING

            for _ in range(attempts):
                try:
                    output = self._execute_node(node.agent_type, task)
                except Exception as exc:
                    # Remember the most recent failure; try again if the
                    # retry budget allows.
                    node.error = f"{type(exc).__name__}: {exc}"
                else:
                    node.output = output
                    node.error = None
                    node.state = GraphNodeState.COMPLETED
                    break

            if node.state is not GraphNodeState.COMPLETED:
                node.state = GraphNodeState.FAILED
                if node.on_failure == "abort":
                    abort = True

            # node.output is still None when every attempt failed.
            node_outputs[name] = node.output
            # Latency covers all attempts for this node.
            node.latency_ms = (time.perf_counter() - node_t0) * 1000

            # Metrics: track per-node execution
            if self._metrics is not None:
                self._metrics.get_counter("graph_nodes_total").inc(node.agent_type)
                self._metrics.get_counter(
                    f"graph_node_{node.state.value}"
                ).inc(node.agent_type)
                self._metrics.get_timer("graph_node_latency_ms").record(node.latency_ms)

        success = all(
            n.state in (GraphNodeState.COMPLETED, GraphNodeState.SKIPPED)
            for n in results.values()
        )
        total_latency = (time.perf_counter() - t0) * 1000

        return GraphResult(
            node_results=results,
            execution_order=order,
            total_latency_ms=total_latency,
            success=success,
        )

    def _execute_node(self, agent_type: str, task: str) -> Any:
        """Execute a single node. Override or provide executor callback.

        Raises:
            NotImplementedError: If no executor was given and the method has
                not been overridden.
        """
        # Identity check: a callable object that happens to be falsy is still
        # a valid executor.
        if self._executor is not None:
            return self._executor(agent_type, task)
        raise NotImplementedError(
            "No executor provided. Pass executor to __init__ or override _execute_node."
        )

    @staticmethod
    def _mermaid_id(name: str) -> str:
        """Turn a node name into the identifier used in the Mermaid output."""
        return name.replace("-", "_").replace(" ", "_")

    def to_mermaid(self) -> str:
        """Export graph as Mermaid flowchart.

        Emits one labelled node line per graph node, followed by one
        ``dep --> node`` edge line per dependency.
        """
        lines = ["graph TD"]
        for name, node in self._nodes.items():
            # "\\n" is emitted as a literal backslash-n inside the label.
            lines.append(f' {self._mermaid_id(name)}["{name}\\n({node.agent_type})"]')
        for name, node in self._nodes.items():
            target = self._mermaid_id(name)
            for dep in node.depends_on:
                lines.append(f" {self._mermaid_id(dep)} --> {target}")
        return "\n".join(lines)

    @property
    def node_count(self) -> int:
        """Number of nodes currently in the graph."""
        return len(self._nodes)

    @property
    def edge_count(self) -> int:
        """Number of dependency edges (sum of all ``depends_on`` lengths)."""
        return sum(len(n.depends_on) for n in self._nodes.values())

295 

296 

@dataclass
class GraphRecipe:
    """Declarative graph definition (YAML-friendly)."""

    name: str
    description: str = ""
    # List of node dicts with keys: name, agent_type, task_template,
    # depends_on, timeout_seconds, retry_count, on_failure.
    nodes: list[dict[str, Any]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict) -> "GraphRecipe":
        """Create a recipe from a plain mapping.

        Missing keys fall back to ``"unnamed"``, ``""`` and an empty node
        list. The node list is copied, so later changes to ``recipe.nodes``
        do not alter ``data``; a null/empty ``nodes`` value is treated as no
        nodes.
        """
        return cls(
            name=data.get("name", "unnamed"),
            description=data.get("description", ""),
            nodes=list(data.get("nodes") or []),
        )

    def build(
        self,
        executor: Optional[Callable] = None,
        metrics: Optional[MetricsCollector] = None,
    ) -> AgentGraph:
        """Build an AgentGraph from this recipe.

        Args:
            executor: Forwarded to ``AgentGraph`` as its executor.
            metrics: Forwarded to ``AgentGraph`` as its metrics collector.

        Returns:
            A new graph containing one ``GraphNode`` per entry in ``nodes``.

        Raises:
            KeyError: If a node spec lacks ``name`` or ``task_template``.
        """
        graph = AgentGraph(executor=executor, metrics=metrics)
        for spec in self.nodes:
            deps = spec.get("depends_on") or []
            if isinstance(deps, str):
                # A scalar such as `depends_on: research` means one
                # dependency, not one dependency per character.
                deps = [deps]
            graph.add_node(GraphNode(
                name=spec["name"],
                agent_type=spec.get("agent_type", "default"),
                task_template=spec["task_template"],
                # Copy so the node never shares a list with the recipe or
                # with nodes of other graphs built from the same recipe.
                depends_on=list(deps),
                timeout_seconds=spec.get("timeout_seconds", 120.0),
                retry_count=spec.get("retry_count", 0),
                on_failure=spec.get("on_failure", "abort"),
            ))
        return graph