Coverage for agentos/orchestration/graph_executor.py: 33%

178 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Agent Graph — DAG-based multi-agent execution engine. 

3 

4Build complex agent pipelines as directed acyclic graphs where each node 

5is an agent invocation and edges define data flow dependencies. 

6""" 

7 

8from __future__ import annotations 

9 

10import time 

11from collections import deque 

12from dataclasses import dataclass, field 

13from enum import Enum 

14from typing import Any, Callable, Optional 

15 

16 

class GraphNodeState(Enum):
    """Execution state of a graph orchestrator node."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"


@dataclass
class GraphNode:
    """A single node in the agent execution graph.

    The first block of fields is configuration supplied by the caller; the
    trailing fields (``state``, ``output``, ``error``, ``latency_ms``) hold
    per-run results and are overwritten by the executing graph.
    """

    name: str
    agent_type: str
    # Template string with {input} or {node_name.output} placeholders.
    task_template: str

    # Node names this node depends on.
    depends_on: list[str] = field(default_factory=list)

    timeout_seconds: float = 120.0
    retry_count: int = 0
    # Action on failure: 'abort', 'skip', 'continue'.
    on_failure: str = "abort"

    # Per-run results.
    state: GraphNodeState = GraphNodeState.PENDING
    output: Any = None
    error: Optional[str] = None
    latency_ms: float = 0.0

    def resolve_task(self, node_outputs: dict[str, Any]) -> str:
        """Resolve template placeholders using outputs from completed nodes.

        ``{input}`` is replaced with ``str(node_outputs["__input__"])`` (an
        empty string when that key is missing), and ``{name.output}`` is
        replaced with ``str(node_outputs[name])`` for every *name* present in
        *node_outputs*. Placeholders naming unknown nodes are left untouched.

        Substitution is a single pass over the template: text introduced by a
        substituted value (user input or another node's output) is never
        itself scanned for placeholders, so it cannot inject or trigger
        further expansions.

        Args:
            node_outputs: Mapping of node name to output; the special key
                ``"__input__"`` holds the graph's initial input.

        Returns:
            The resolved task string.
        """
        import re

        pattern = re.compile(r"\{input\}|\{([^{}]*)\.output\}")

        def _substitute(match):
            node_name = match.group(1)
            if node_name is None:  # matched the bare {input} alternative
                return str(node_outputs.get("__input__", ""))
            if node_name in node_outputs:
                return str(node_outputs[node_name])
            return match.group(0)  # unknown node: keep the placeholder verbatim

        return pattern.sub(_substitute, self.task_template)

57 

58 

@dataclass
class GraphResult:
    """Summary produced by running an agent graph once."""

    # Final GraphNode objects, keyed by node name.
    node_results: dict[str, GraphNode] = field(default_factory=dict)
    # Names of the nodes in the order they were visited.
    execution_order: list[str] = field(default_factory=list)
    # Elapsed time for the whole run, in milliseconds.
    total_latency_ms: float = 0.0
    # Overall outcome flag; False means the run did not fully succeed.
    success: bool = True
    # Human-readable failure description, or None.
    error: Optional[str] = None

68 

69 

class AgentGraph:
    """
    DAG-based multi-agent execution engine.

    Define execution graphs declaratively, resolve dependencies automatically,
    and execute nodes in topological order. Nodes run sequentially: each
    executor call finishes before the next node starts, even when nodes are
    independent of each other.

    Example::

        graph = AgentGraph(executor=lambda agent_type, task: f"[{agent_type}] {task}")
        graph.add_node(GraphNode(
            name="research",
            agent_type="researcher",
            task_template="Research: {input}"
        ))
        graph.add_node(GraphNode(
            name="summarize",
            agent_type="summarizer",
            task_template="Summarize: {research.output}",
            depends_on=["research"]
        ))
        result = graph.execute("quantum computing advances")
    """

    def __init__(self, executor: Optional[Callable[[str, str], Any]] = None):
        """
        Args:
            executor: Callable(agent_type, task) -> output. If not provided,
                subclasses must override _execute_node.
        """
        self._nodes: dict[str, GraphNode] = {}
        self._executor = executor

    def add_node(self, node: GraphNode) -> None:
        """Add a node to the graph. Raises ValueError on duplicate name."""
        if node.name in self._nodes:
            raise ValueError(f"Duplicate node name: {node.name}")
        self._nodes[node.name] = node

    def remove_node(self, name: str) -> None:
        """Remove a node and all edges referencing it.

        Raises:
            KeyError: If no node called *name* exists.
        """
        if name not in self._nodes:
            raise KeyError(f"Node not found: {name}")
        del self._nodes[name]
        # Rebind (rather than mutate) each remaining node's dependency list.
        for node in self._nodes.values():
            node.depends_on = [d for d in node.depends_on if d != name]

    def validate(self) -> list[str]:
        """
        Validate graph integrity.

        Checks for dependencies on unknown nodes, self-dependencies and, only
        if those pass, cycles.

        Returns:
            List of error messages (empty if valid).
        """
        errors: list[str] = []

        for name, node in self._nodes.items():
            for dep in node.depends_on:
                if dep not in self._nodes:
                    errors.append(f"Node '{name}' depends on unknown node '{dep}'")
                if dep == name:
                    errors.append(f"Node '{name}' cannot depend on itself")

        # Cycle check needs every dependency to exist (it indexes by dep name),
        # so it only runs once the edge checks above are clean.
        if not errors:
            try:
                self._topological_order()
            except ValueError as e:
                errors.append(str(e))

        return errors

    def _topological_order(self) -> list[str]:
        """Return nodes in topological order (Kahn's algorithm).

        Assumes every dependency names an existing node; callers run the edge
        checks in ``validate`` first.

        Raises:
            ValueError: If the graph contains a cycle.
        """
        in_degree: dict[str, int] = {name: 0 for name in self._nodes}
        adjacency: dict[str, list[str]] = {name: [] for name in self._nodes}

        for name, node in self._nodes.items():
            for dep in node.depends_on:
                adjacency[dep].append(name)
                in_degree[name] += 1

        queue = deque([name for name, deg in in_degree.items() if deg == 0])
        order: list[str] = []

        while queue:
            current = queue.popleft()
            order.append(current)
            for neighbor in adjacency[current]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Any node never reaching in-degree 0 sits on (or behind) a cycle.
        if len(order) != len(self._nodes):
            remaining = set(self._nodes) - set(order)
            raise ValueError(f"Cycle detected involving nodes: {remaining}")

        return order

    def execute(self, input_data: str) -> GraphResult:
        """
        Execute the graph with given input.

        Nodes run one at a time in topological order. A node runs only if
        every node it depends on ended COMPLETED; otherwise it is marked
        SKIPPED, so a failure propagates transitively to everything
        downstream. Each node is attempted up to ``1 + retry_count`` times
        (exceptions from the executor trigger a retry) before being marked
        FAILED. If a FAILED node has ``on_failure == "abort"``, every node after
        it in topological order is marked SKIPPED without running.

        Note: ``timeout_seconds`` is not enforced by this method.

        Args:
            input_data: Initial task input, accessible as {input} in templates.

        Returns:
            GraphResult with per-node outputs and execution metadata.
            ``success`` is False when validation fails or any node ended FAILED.
        """
        errors = self.validate()
        if errors:
            return GraphResult(success=False, error="; ".join(errors))

        t0 = time.perf_counter()
        node_outputs: dict[str, Any] = {"__input__": input_data}
        results: dict[str, GraphNode] = {}
        # Nodes visited before any abort (includes nodes skipped for unmet deps).
        order: list[str] = []

        # Reset per-run state so the same graph can be executed repeatedly.
        for node in self._nodes.values():
            node.state = GraphNodeState.PENDING
            node.output = None
            node.error = None
            node.latency_ms = 0.0

        try:
            topo = self._topological_order()
        except ValueError as e:
            return GraphResult(success=False, error=str(e))

        abort = False
        for name in topo:
            node = self._nodes[name]
            results[name] = node

            if abort:
                node.state = GraphNodeState.SKIPPED
                continue

            order.append(name)

            # Topological order guarantees every dependency is already in
            # `results`. Anything other than COMPLETED (FAILED *or* SKIPPED)
            # means its output is missing; checking only for FAILED would let
            # nodes below a skipped node run with unresolved placeholders.
            if any(
                results[dep].state != GraphNodeState.COMPLETED
                for dep in node.depends_on
            ):
                node.state = GraphNodeState.SKIPPED
                continue

            task = node.resolve_task(node_outputs)
            node_t0 = time.perf_counter()
            node.state = GraphNodeState.RUNNING
            attempts = 1 + max(0, int(node.retry_count or 0))

            for _attempt in range(attempts):
                try:
                    output = self._execute_node(node.agent_type, task)
                except Exception as exc:  # isolate per-node failures
                    # Keep only the most recent attempt's error.
                    node.error = f"{type(exc).__name__}: {exc}"
                    continue
                node.output = output
                node.error = None
                node.state = GraphNodeState.COMPLETED
                node_outputs[name] = output
                break
            else:
                # Every attempt raised.
                node.state = GraphNodeState.FAILED
                node_outputs[name] = None
                if node.on_failure == "abort":
                    abort = True

            # Latency covers all attempts for this node.
            node.latency_ms = (time.perf_counter() - node_t0) * 1000

        success = all(
            n.state in (GraphNodeState.COMPLETED, GraphNodeState.SKIPPED)
            for n in results.values()
        )
        total_latency = (time.perf_counter() - t0) * 1000

        return GraphResult(
            node_results=results,
            execution_order=order,
            total_latency_ms=total_latency,
            success=success,
        )

    def _execute_node(self, agent_type: str, task: str) -> Any:
        """Execute a single node. Override or provide executor callback.

        Raises:
            NotImplementedError: If no executor was passed to ``__init__``.
        """
        if self._executor is not None:
            return self._executor(agent_type, task)
        raise NotImplementedError(
            "No executor provided. Pass executor to __init__ or override _execute_node."
        )

    @staticmethod
    def _mermaid_id(name: str) -> str:
        """Map a node name to a Mermaid node id (hyphens/spaces -> underscores)."""
        return name.replace("-", "_").replace(" ", "_")

    def to_mermaid(self) -> str:
        """Export graph as Mermaid flowchart.

        Emits one labelled declaration per node, then one ``dep --> node``
        edge per dependency.
        """
        lines = ["graph TD"]
        for name, node in self._nodes.items():
            lines.append(f" {self._mermaid_id(name)}[\"{name}\\n({node.agent_type})\"]")
        for name, node in self._nodes.items():
            safe = self._mermaid_id(name)
            for dep in node.depends_on:
                lines.append(f" {self._mermaid_id(dep)} --> {safe}")
        return "\n".join(lines)

    @property
    def node_count(self) -> int:
        """Number of nodes in the graph."""
        return len(self._nodes)

    @property
    def edge_count(self) -> int:
        """Total number of dependency entries across all nodes."""
        return sum(len(n.depends_on) for n in self._nodes.values())

281 

282 

@dataclass
class GraphRecipe:
    """Declarative graph definition (YAML-friendly)."""

    name: str
    description: str = ""
    # List of node dicts with keys: name, agent_type, task_template,
    # depends_on, timeout_seconds, retry_count, on_failure.
    nodes: list[dict[str, Any]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict) -> "GraphRecipe":
        """Create a recipe from a plain mapping (e.g. parsed YAML/JSON).

        Missing ``name`` defaults to ``"unnamed"``. A missing or null
        ``nodes`` becomes an empty list. The node list is copied so later
        edits to the recipe do not alias the caller's data.
        """
        return cls(
            name=data.get("name", "unnamed"),
            description=data.get("description", ""),
            nodes=list(data.get("nodes") or []),
        )

    @staticmethod
    def _normalize_depends_on(value: Any) -> list[str]:
        """Coerce a ``depends_on`` spec value into a fresh list of node names.

        ``None`` becomes ``[]``. A single string (e.g. YAML ``depends_on: a``)
        is treated as one dependency rather than a sequence of characters.
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    def build(self, executor: Optional[Callable] = None) -> AgentGraph:
        """Build an AgentGraph from this recipe.

        Raises:
            KeyError: If a node spec lacks ``name`` or ``task_template``.
            ValueError: If two node specs share a name.
        """
        graph = AgentGraph(executor=executor)
        for spec in self.nodes:
            graph.add_node(GraphNode(
                name=spec["name"],
                agent_type=spec.get("agent_type", "default"),
                task_template=spec["task_template"],
                depends_on=self._normalize_depends_on(spec.get("depends_on")),
                timeout_seconds=spec.get("timeout_seconds", 120.0),
                retry_count=spec.get("retry_count", 0),
                on_failure=spec.get("on_failure", "abort"),
            ))
        return graph