Coverage for agentos/orchestration/graph_executor.py: 33%
178 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
Agent Graph — DAG-based multi-agent execution engine.

Build complex agent pipelines as directed acyclic graphs where each node
is an agent invocation and edges define data flow dependencies.
"""
from __future__ import annotations

import re
import time
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional
class GraphNodeState(Enum):
    """Lifecycle states a graph node can be in during a graph run."""

    # Initial state; also the default for a freshly created node.
    PENDING = "pending"
    # The node's agent invocation is in progress.
    RUNNING = "running"
    # The node finished and produced an output.
    COMPLETED = "completed"
    # The node's invocation raised an exception.
    FAILED = "failed"
    # The node was not run (aborted graph or unmet dependency).
    SKIPPED = "skipped"
@dataclass
class GraphNode:
    """A single node in the agent execution graph.

    The first group of fields describes the node (what to run and what it
    depends on); the second group (``state``, ``output``, ``error``,
    ``latency_ms``) holds per-run results and starts out empty.
    """

    name: str
    agent_type: str
    # Template string with {input} or {node_name.output} placeholders.
    task_template: str

    # Node names this node depends on.
    depends_on: list[str] = field(default_factory=list)

    timeout_seconds: float = 120.0
    retry_count: int = 0
    # Action on failure: 'abort', 'skip', 'continue'.
    on_failure: str = "abort"

    state: GraphNodeState = GraphNodeState.PENDING
    output: Any = None
    error: Optional[str] = None
    latency_ms: float = 0.0

    def resolve_task(self, node_outputs: dict[str, Any]) -> str:
        """Resolve template placeholders using outputs from completed nodes.

        Args:
            node_outputs: Mapping of node name to that node's output. The
                special key ``"__input__"`` holds the graph input and is
                substituted for ``{input}`` (empty string when absent).

        Returns:
            The task template with every known placeholder replaced by the
            ``str()`` of its value. Unknown placeholders are left untouched.
        """
        # Placeholder text -> raw value. Values are stringified lazily so
        # outputs that the template never references are not converted.
        values: dict[str, Any] = {"{input}": node_outputs.get("__input__", "")}
        for name, output in node_outputs.items():
            values[f"{{{name}.output}}"] = output

        # Substitute in ONE pass over the original template. Chained
        # str.replace calls would re-scan already-substituted text, so a
        # graph input (or an upstream output) containing "{other.output}"
        # would be expanded too.
        pattern = re.compile("|".join(re.escape(key) for key in values))
        return pattern.sub(lambda m: str(values[m.group(0)]), self.task_template)
@dataclass
class GraphResult:
    """Outcome of one graph execution."""

    # Node objects keyed by node name.
    node_results: dict[str, GraphNode] = field(default_factory=dict)
    # Node names in the order they were visited.
    execution_order: list[str] = field(default_factory=list)
    # Wall-clock duration of the whole run, in milliseconds.
    total_latency_ms: float = 0.0
    # Overall success flag; defaults to True.
    success: bool = True
    # Error description, or None when no error was recorded.
    error: Optional[str] = None
class AgentGraph:
    """
    DAG-based multi-agent execution engine.

    Define execution graphs declaratively, resolve dependencies automatically,
    and execute nodes one at a time in topological order.

    Example::

        graph = AgentGraph()
        graph.add_node(GraphNode(
            name="research",
            agent_type="researcher",
            task_template="Research: {input}"
        ))
        graph.add_node(GraphNode(
            name="summarize",
            agent_type="summarizer",
            task_template="Summarize: {research.output}",
            depends_on=["research"]
        ))
        result = graph.execute("quantum computing advances")
    """

    def __init__(self, executor: Optional[Callable[[str, str], Any]] = None):
        """
        Args:
            executor: Callable(agent_type, task) -> output. If not provided,
                subclasses must override _execute_node.
        """
        self._nodes: dict[str, GraphNode] = {}
        self._executor = executor

    def add_node(self, node: GraphNode) -> None:
        """Add a node to the graph. Raises ValueError on duplicate name."""
        if node.name in self._nodes:
            raise ValueError(f"Duplicate node name: {node.name}")
        self._nodes[node.name] = node

    def remove_node(self, name: str) -> None:
        """Remove a node and all edges referencing it. Raises KeyError if absent."""
        if name not in self._nodes:
            raise KeyError(f"Node not found: {name}")
        del self._nodes[name]
        # Rebind (rather than mutate) each list so lists shared with callers
        # are left untouched.
        for node in self._nodes.values():
            node.depends_on = [d for d in node.depends_on if d != name]

    def validate(self) -> list[str]:
        """
        Validate graph integrity.

        Checks that every dependency names a known node, that no node depends
        on itself, and (only if those checks pass) that the graph is acyclic.

        Returns:
            List of error messages (empty if valid).
        """
        errors: list[str] = []

        for name, node in self._nodes.items():
            for dep in node.depends_on:
                if dep not in self._nodes:
                    errors.append(f"Node '{name}' depends on unknown node '{dep}'")
                if dep == name:
                    errors.append(f"Node '{name}' cannot depend on itself")

        # The cycle check indexes by dependency name, so it is only safe once
        # all dependencies are known to exist.
        if not errors:
            try:
                self._topological_order()
            except ValueError as e:
                errors.append(str(e))

        return errors

    def _topological_order(self) -> list[str]:
        """Return nodes in topological order. Raises ValueError on cycle."""
        in_degree: dict[str, int] = {name: 0 for name in self._nodes}
        adjacency: dict[str, list[str]] = {name: [] for name in self._nodes}

        for name, node in self._nodes.items():
            for dep in node.depends_on:
                adjacency[dep].append(name)
                in_degree[name] += 1

        # Kahn's algorithm: repeatedly emit nodes with no unprocessed dependencies.
        queue = deque([name for name, deg in in_degree.items() if deg == 0])
        order: list[str] = []

        while queue:
            current = queue.popleft()
            order.append(current)
            for neighbor in adjacency[current]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Nodes never emitted still have incoming edges, i.e. sit on a cycle.
        if len(order) != len(self._nodes):
            remaining = set(self._nodes) - set(order)
            raise ValueError(f"Cycle detected involving nodes: {remaining}")

        return order

    def execute(self, input_data: str) -> GraphResult:
        """
        Execute the graph with given input.

        Nodes run sequentially in topological order. A node is skipped when
        any of its dependencies failed or was itself skipped, so a failure
        propagates to all transitive dependents. A failed node is retried up
        to ``retry_count`` additional times before being marked FAILED; if its
        ``on_failure`` is ``"abort"``, every remaining node is skipped.
        ``timeout_seconds`` is not enforced here.

        Args:
            input_data: Initial task input, accessible as {input} in templates.

        Returns:
            GraphResult with per-node outputs and execution metadata.
            ``execution_order`` lists the nodes visited before an abort,
            including nodes skipped because of an unmet dependency.
            ``success`` is True when every node is COMPLETED or SKIPPED.
        """
        errors = self.validate()
        if errors:
            return GraphResult(success=False, error="; ".join(errors))

        t0 = time.perf_counter()
        node_outputs: dict[str, Any] = {"__input__": input_data}
        results: dict[str, GraphNode] = {}
        order: list[str] = []

        # Reset per-run state so a graph can be executed more than once.
        for node in self._nodes.values():
            node.state = GraphNodeState.PENDING
            node.output = None
            node.error = None
            node.latency_ms = 0.0

        try:
            topo = self._topological_order()
        except ValueError as e:
            return GraphResult(success=False, error=str(e))

        unmet = (GraphNodeState.FAILED, GraphNodeState.SKIPPED)
        abort = False
        for name in topo:
            node = self._nodes[name]
            results[name] = node

            if abort:
                node.state = GraphNodeState.SKIPPED
                continue

            order.append(name)

            # Topological order guarantees every dependency is already in
            # `results`. A skipped dependency produced no output, so it blocks
            # its dependents just like a failed one.
            if any(results[dep].state in unmet for dep in node.depends_on):
                node.state = GraphNodeState.SKIPPED
                continue

            task = node.resolve_task(node_outputs)
            node_t0 = time.perf_counter()
            succeeded = self._run_node(node, task)
            node.latency_ms = (time.perf_counter() - node_t0) * 1000

            # On failure node.output is still None from the reset above.
            node_outputs[name] = node.output
            if not succeeded and node.on_failure == "abort":
                abort = True

        success = all(
            n.state in (GraphNodeState.COMPLETED, GraphNodeState.SKIPPED)
            for n in results.values()
        )
        total_latency = (time.perf_counter() - t0) * 1000

        return GraphResult(
            node_results=results,
            execution_order=order,
            total_latency_ms=total_latency,
            success=success,
        )

    def _run_node(self, node: GraphNode, task: str) -> bool:
        """Run one node, retrying on error; return True if it completed."""
        # retry_count counts extra attempts; negative or missing values mean none.
        attempts = max(0, int(node.retry_count or 0)) + 1
        node.state = GraphNodeState.RUNNING
        for _ in range(attempts):
            try:
                output = self._execute_node(node.agent_type, task)
            except Exception as exc:
                # Executor errors are recorded on the node, not raised, so the
                # graph can apply the node's on_failure policy.
                node.error = f"{type(exc).__name__}: {exc}"
            else:
                node.output = output
                node.error = None  # clear the error left by an earlier attempt
                node.state = GraphNodeState.COMPLETED
                return True
        node.state = GraphNodeState.FAILED
        return False

    def _execute_node(self, agent_type: str, task: str) -> Any:
        """Execute a single node. Override or provide executor callback.

        Raises:
            NotImplementedError: If no executor was passed to __init__ and
                this method has not been overridden.
        """
        # Compare against None so a callable object that happens to be falsy
        # is still used.
        if self._executor is not None:
            return self._executor(agent_type, task)
        raise NotImplementedError(
            "No executor provided. Pass executor to __init__ or override _execute_node."
        )

    @staticmethod
    def _mermaid_id(name: str) -> str:
        """Turn a node name into a Mermaid identifier (dashes/spaces -> underscores)."""
        return name.replace("-", "_").replace(" ", "_")

    def to_mermaid(self) -> str:
        """Export graph as Mermaid flowchart."""
        lines = ["graph TD"]
        # Declare every node first, then the dependency edges between them.
        for name, node in self._nodes.items():
            safe = self._mermaid_id(name)
            lines.append(f' {safe}["{name}\\n({node.agent_type})"]')
        for name, node in self._nodes.items():
            safe = self._mermaid_id(name)
            for dep in node.depends_on:
                safe_dep = self._mermaid_id(dep)
                lines.append(f" {safe_dep} --> {safe}")
        return "\n".join(lines)

    @property
    def node_count(self) -> int:
        """Number of nodes currently in the graph."""
        return len(self._nodes)

    @property
    def edge_count(self) -> int:
        """Number of dependency edges (sum of all depends_on lengths)."""
        return sum(len(n.depends_on) for n in self._nodes.values())
@dataclass
class GraphRecipe:
    """Declarative graph definition (YAML-friendly)."""

    name: str
    description: str = ""
    # List of node dicts with keys: name, agent_type, task_template,
    # depends_on, timeout_seconds, retry_count, on_failure.
    nodes: list[dict[str, Any]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict) -> "GraphRecipe":
        """Create a recipe from a plain mapping (e.g. parsed YAML).

        Missing keys fall back to ``"unnamed"``, ``""`` and no nodes. The
        node list is copied so later changes to ``data["nodes"]`` do not
        alter the recipe; a null ``nodes`` value is treated as empty.
        """
        return cls(
            name=data.get("name", "unnamed"),
            description=data.get("description", ""),
            nodes=list(data.get("nodes") or []),
        )

    def build(self, executor: Optional[Callable] = None) -> AgentGraph:
        """Build an AgentGraph from this recipe.

        Args:
            executor: Optional callable forwarded to ``AgentGraph``.

        Returns:
            A new graph with one node per entry in ``nodes``.

        Raises:
            KeyError: If a node spec lacks ``name`` or ``task_template``.
            ValueError: If two node specs share a name.
        """
        graph = AgentGraph(executor=executor)
        for spec in self.nodes:
            deps = spec.get("depends_on") or []
            if isinstance(deps, str):
                # A scalar dependency would otherwise be iterated per character.
                deps = [deps]
            graph.add_node(GraphNode(
                name=spec["name"],
                agent_type=spec.get("agent_type", "default"),
                task_template=spec["task_template"],
                # Copy so the node does not share a list with the recipe spec.
                depends_on=list(deps),
                timeout_seconds=spec.get("timeout_seconds", 120.0),
                retry_count=spec.get("retry_count", 0),
                on_failure=spec.get("on_failure", "abort"),
            ))
        return graph