Coverage for agentos/orchestration/graph_executor.py: 32%
184 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 13:55 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 13:55 +0800
1"""
2Agent Graph — DAG-based multi-agent execution engine.
4Build complex agent pipelines as directed acyclic graphs where each node
5is an agent invocation and edges define data flow dependencies.
7v1.16.1: Integrated MetricsCollector for per-node execution tracking.
8"""
10from __future__ import annotations
12import time
13from collections import deque
14from dataclasses import dataclass, field
15from enum import Enum
16from typing import Any, Callable, Optional
18from agentos.tools.metrics import MetricsCollector
class GraphNodeState(Enum):
    """Lifecycle state of a single GraphNode during an AgentGraph run.

    Each member's value is its lowercase name as a string (e.g. ``"failed"``).
    """

    PENDING = "pending"  # reset value before a run; not yet reached
    RUNNING = "running"  # executor call in progress
    COMPLETED = "completed"  # executor returned an output
    FAILED = "failed"  # executor raised an exception
    SKIPPED = "skipped"  # never executed (run aborted or a dependency did not succeed)
@dataclass
class GraphNode:
    """A single node in the agent execution graph."""

    name: str
    agent_type: str
    task_template: str
    """Template string with {input} or {node_name.output} placeholders."""

    depends_on: list[str] = field(default_factory=list)
    """Node names this node depends on."""

    # NOTE(review): timeout_seconds is carried as configuration; nothing in
    # this module's executor reads it — confirm before relying on it.
    timeout_seconds: float = 120.0
    retry_count: int = 0
    on_failure: str = "abort"
    """Action on failure: 'abort', 'skip', 'continue'."""

    # Runtime fields, reset and filled in by the executor on every run.
    state: GraphNodeState = GraphNodeState.PENDING
    output: Any = None
    error: Optional[str] = None
    latency_ms: float = 0.0

    def resolve_task(self, node_outputs: dict[str, Any]) -> str:
        """Resolve template placeholders using outputs from completed nodes.

        ``{input}`` is replaced by ``str(node_outputs["__input__"])`` (or the
        empty string when that key is absent), and ``{name.output}`` is
        replaced by ``str(node_outputs[name])`` for every ``name`` present in
        *node_outputs*. Placeholders naming nodes that are not in
        *node_outputs* are left untouched.

        Substitution is a single pass over the template. Text inserted from
        the input or from an output is never re-scanned. An output that
        happens to contain ``{other.output}`` or ``{input}`` is therefore
        inserted verbatim rather than being expanded a second time.

        Args:
            node_outputs: Mapping of node name to output; the initial input
                is stored under the ``"__input__"`` key.

        Returns:
            The task string with all recognised placeholders substituted.
        """
        import re  # local import keeps this block self-contained

        input_text = str(node_outputs.get("__input__", ""))

        alternatives = [r"\{input\}"]
        if node_outputs:
            # Only names that actually have an output are recognised, exactly
            # like the per-name replacement the template contract describes.
            names = "|".join(re.escape(n) for n in node_outputs)
            alternatives.append(r"\{(" + names + r")\.output\}")
        pattern = re.compile("|".join(alternatives))

        def _substitute(match: re.Match[str]) -> str:
            name = match.group(1)
            if name is None:  # the "{input}" alternative matched
                return input_text
            return str(node_outputs[name])

        return pattern.sub(_substitute, self.task_template)
@dataclass
class GraphResult:
    """Outcome of one ``AgentGraph.execute`` call.

    ``success`` is False when graph validation failed, in which case
    ``error`` carries the joined messages. It is also False when any node
    ended in the FAILED state.
    """

    # Node records keyed by node name.
    node_results: dict[str, GraphNode] = field(default_factory=dict)
    # Names of nodes the run reached (those skipped by an abort are omitted).
    execution_order: list[str] = field(default_factory=list)
    # Wall-clock duration of the whole run, in milliseconds.
    total_latency_ms: float = field(default=0.0)
    success: bool = field(default=True)
    error: Optional[str] = field(default=None)
class AgentGraph:
    """
    DAG-based multi-agent execution engine.

    Define execution graphs declaratively, resolve dependencies automatically,
    and execute nodes one at a time in topological order. Independent nodes
    are not run in parallel. They simply run in the order the topological
    sort yields.

    Example::

        graph = AgentGraph()
        graph.add_node(GraphNode(
            name="research",
            agent_type="researcher",
            task_template="Research: {input}"
        ))
        graph.add_node(GraphNode(
            name="summarize",
            agent_type="summarizer",
            task_template="Summarize: {research.output}",
            depends_on=["research"]
        ))
        result = graph.execute("quantum computing advances")
    """

    def __init__(
        self,
        executor: Optional[Callable[[str, str], Any]] = None,
        metrics: Optional[MetricsCollector] = None,
    ):
        """
        Args:
            executor: Callable(agent_type, task) -> output. If not provided,
                subclasses must override _execute_node.
            metrics: Optional MetricsCollector for per-node execution tracking.
        """
        self._nodes: dict[str, GraphNode] = {}
        self._executor = executor
        self._metrics = metrics

    def add_node(self, node: GraphNode) -> None:
        """Add a node to the graph.

        Raises:
            ValueError: if a node with the same name is already present.
        """
        if node.name in self._nodes:
            raise ValueError(f"Duplicate node name: {node.name}")
        self._nodes[node.name] = node

    def remove_node(self, name: str) -> None:
        """Remove a node and drop it from every other node's ``depends_on``.

        Raises:
            KeyError: if no node named *name* exists.
        """
        if name not in self._nodes:
            raise KeyError(f"Node not found: {name}")
        del self._nodes[name]
        for node in self._nodes.values():
            # Rebinds a fresh list rather than mutating the existing one.
            node.depends_on = [d for d in node.depends_on if d != name]

    def validate(self) -> list[str]:
        """
        Validate graph integrity.

        Checks for dependencies on unknown nodes, self-dependencies and
        (only if those checks pass) cycles.

        Returns:
            List of error messages (empty if valid).
        """
        errors: list[str] = []

        for name, node in self._nodes.items():
            for dep in node.depends_on:
                if dep not in self._nodes:
                    errors.append(f"Node '{name}' depends on unknown node '{dep}'")
                if dep == name:
                    errors.append(f"Node '{name}' cannot depend on itself")

        # The cycle check indexes adjacency by dependency name, so it is only
        # safe once every dependency is known to exist.
        if not errors:
            try:
                self._topological_order()
            except ValueError as e:
                errors.append(str(e))

        return errors

    def _topological_order(self) -> list[str]:
        """Return node names in topological order (Kahn's algorithm).

        Nodes with no pending dependencies are emitted FIFO, starting from
        insertion order of ``self._nodes``.

        Raises:
            ValueError: if the graph contains a cycle.
        """
        in_degree: dict[str, int] = {name: 0 for name in self._nodes}
        adjacency: dict[str, list[str]] = {name: [] for name in self._nodes}

        for name, node in self._nodes.items():
            for dep in node.depends_on:
                adjacency[dep].append(name)
                in_degree[name] += 1

        queue = deque([name for name, deg in in_degree.items() if deg == 0])
        order: list[str] = []

        while queue:
            current = queue.popleft()
            order.append(current)
            for neighbor in adjacency[current]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if len(order) != len(self._nodes):
            remaining = set(self._nodes) - set(order)
            raise ValueError(f"Cycle detected involving nodes: {remaining}")

        return order

    def execute(self, input_data: str) -> GraphResult:
        """
        Execute the graph with given input.

        Nodes run sequentially in topological order. For each node:

        * if an earlier node whose ``on_failure`` is ``"abort"`` failed, it
          is marked SKIPPED without running;
        * if any dependency FAILED or was itself SKIPPED, it is marked
          SKIPPED, because that dependency produced no usable output;
        * otherwise its template is resolved and its agent is invoked, with
          up to ``node.retry_count`` additional attempts if the call raises.

        Any ``on_failure`` value other than ``"abort"`` (e.g. ``"skip"`` or
        ``"continue"``) lets execution carry on past the failed node.
        ``node.timeout_seconds`` is not enforced here.

        Args:
            input_data: Initial task input, accessible as {input} in templates.

        Returns:
            GraphResult with per-node outputs and execution metadata. If
            validation fails, no node runs, ``success`` is False and ``error``
            holds the joined validation messages.
        """
        errors = self.validate()
        if errors:
            return GraphResult(success=False, error="; ".join(errors))

        t0 = time.perf_counter()
        node_outputs: dict[str, Any] = {"__input__": input_data}
        results: dict[str, GraphNode] = {}
        order: list[str] = []

        # Reset all nodes so the same graph can be executed repeatedly.
        for node in self._nodes.values():
            node.state = GraphNodeState.PENDING
            node.output = None
            node.error = None
            node.latency_ms = 0.0

        try:
            topo = self._topological_order()
        except ValueError as e:
            return GraphResult(success=False, error=str(e))

        unusable = (GraphNodeState.FAILED, GraphNodeState.SKIPPED)
        abort = False
        for name in topo:
            node = self._nodes[name]
            results[name] = node
            if abort:
                node.state = GraphNodeState.SKIPPED
                continue

            order.append(name)

            # Topological order guarantees every dependency is already in
            # `results`. Skips must propagate transitively. Otherwise a
            # grandchild of a failed node would run with an unresolved
            # "{dep.output}" placeholder in its task.
            if any(results[dep].state in unusable for dep in node.depends_on):
                node.state = GraphNodeState.SKIPPED
                continue

            task = node.resolve_task(node_outputs)
            node_t0 = time.perf_counter()
            node.state = GraphNodeState.RUNNING
            try:
                output = self._run_with_retries(node, task)
            except Exception as exc:
                node.state = GraphNodeState.FAILED
                node.error = f"{type(exc).__name__}: {exc}"
                node_outputs[name] = None
                if node.on_failure == "abort":
                    abort = True
            else:
                node.output = output
                node.state = GraphNodeState.COMPLETED
                node_outputs[name] = output

            node.latency_ms = (time.perf_counter() - node_t0) * 1000
            self._record_node_metrics(node)

        success = all(
            n.state in (GraphNodeState.COMPLETED, GraphNodeState.SKIPPED)
            for n in results.values()
        )
        total_latency = (time.perf_counter() - t0) * 1000

        return GraphResult(
            node_results=results,
            execution_order=order,
            total_latency_ms=total_latency,
            success=success,
        )

    def _run_with_retries(self, node: GraphNode, task: str) -> Any:
        """Invoke the node's agent, retrying up to ``node.retry_count`` extra times.

        A negative ``retry_count`` is treated as 0. The exception from the
        final attempt is re-raised unchanged.
        """
        remaining = max(0, node.retry_count)
        while True:
            try:
                return self._execute_node(node.agent_type, task)
            except Exception:
                if remaining <= 0:
                    raise
                remaining -= 1

    def _record_node_metrics(self, node: GraphNode) -> None:
        """Record per-node counters and latency, if a MetricsCollector was given."""
        if self._metrics is None:
            return
        self._metrics.get_counter("graph_nodes_total").inc(node.agent_type)
        self._metrics.get_counter(
            f"graph_node_{node.state.value}"
        ).inc(node.agent_type)
        self._metrics.get_timer("graph_node_latency_ms").record(node.latency_ms)

    def _execute_node(self, agent_type: str, task: str) -> Any:
        """Execute a single node. Override or provide executor callback.

        Raises:
            NotImplementedError: if no executor was passed to ``__init__``.
        """
        if self._executor is not None:
            return self._executor(agent_type, task)
        raise NotImplementedError(
            "No executor provided. Pass executor to __init__ or override _execute_node."
        )

    def to_mermaid(self) -> str:
        """Export graph as a Mermaid ``graph TD`` flowchart.

        Node IDs are the node names with every non-word character replaced
        by ``_``. Labels show the original name and agent type. Double quotes
        in a label are written as Mermaid's ``#quot;`` entity so they cannot
        terminate the quoted label.
        """
        import re  # local import keeps this block self-contained

        def _node_id(raw: str) -> str:
            return re.sub(r"\W", "_", raw)

        lines = ["graph TD"]
        for name, node in self._nodes.items():
            label = f"{name}\\n({node.agent_type})".replace('"', "#quot;")
            lines.append(f' {_node_id(name)}["{label}"]')
        for name, node in self._nodes.items():
            for dep in node.depends_on:
                lines.append(f" {_node_id(dep)} --> {_node_id(name)}")
        return "\n".join(lines)

    @property
    def node_count(self) -> int:
        """Number of nodes currently in the graph."""
        return len(self._nodes)

    @property
    def edge_count(self) -> int:
        """Total number of ``depends_on`` entries across all nodes."""
        return sum(len(n.depends_on) for n in self._nodes.values())
@dataclass
class GraphRecipe:
    """Declarative graph definition (YAML-friendly)."""

    name: str
    description: str = ""
    nodes: list[dict[str, Any]] = field(default_factory=list)
    """List of node dicts with keys: name, agent_type, task_template, depends_on, timeout_seconds, retry_count, on_failure."""

    @classmethod
    def from_dict(cls, data: dict) -> "GraphRecipe":
        """Create a recipe from a plain mapping (e.g. parsed YAML or JSON).

        Missing keys fall back to ``name="unnamed"``, ``description=""`` and
        an empty node list. A ``nodes`` value of ``None`` is treated as an
        empty list; a bare ``nodes:`` key in YAML produces ``None``. The list
        is shallow-copied, so later edits to *data* do not change the recipe.
        """
        return cls(
            name=data.get("name", "unnamed"),
            description=data.get("description", ""),
            nodes=list(data.get("nodes") or []),
        )

    def build(
        self,
        executor: Optional[Callable] = None,
        metrics: Optional[MetricsCollector] = None,
    ) -> AgentGraph:
        """Build an AgentGraph from this recipe.

        Args:
            executor: Passed through as ``AgentGraph(executor=...)``.
            metrics: Optional collector passed through as
                ``AgentGraph(metrics=...)``. The default ``None`` disables
                metrics.

        Returns:
            A new AgentGraph containing one GraphNode per node spec.

        Raises:
            KeyError: if a node spec lacks ``name`` or ``task_template``.
            ValueError: if two node specs share the same name.
        """
        graph = AgentGraph(executor=executor, metrics=metrics)
        for spec in self.nodes:
            graph.add_node(GraphNode(
                name=spec["name"],
                agent_type=spec.get("agent_type", "default"),
                task_template=spec["task_template"],
                # Fresh list: the node never aliases the recipe's spec, and a
                # null ``depends_on`` is treated as "no dependencies".
                depends_on=list(spec.get("depends_on") or []),
                timeout_seconds=spec.get("timeout_seconds", 120.0),
                retry_count=spec.get("retry_count", 0),
                on_failure=spec.get("on_failure", "abort"),
            ))
        return graph