Coverage for agentos/orchestration/graph.py: 31%
182 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Graph Orchestrator for NexusAgent.
4DAG-based workflow orchestration. Allows defining
5complex workflows as graphs with nodes and edges.
6"""
from __future__ import annotations

import asyncio
import functools
import inspect
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional
class NodeStatus(str, Enum):
    """Lifecycle state of one graph node within a single execution run.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value (e.g. ``NodeStatus.FAILED == "failed"``).
    """

    PENDING = "pending"      # reset state before a run; not started yet
    RUNNING = "running"      # node function has been dispatched
    COMPLETED = "completed"  # node function returned normally
    FAILED = "failed"        # node function raised
    SKIPPED = "skipped"      # an incoming conditional edge evaluated false
@dataclass
class GraphNode:
    """
    Node in execution graph.

    Attributes:
        id: Unique identifier (12 hex characters from a random UUID by default)
        name: Node name
        func: Node function; ``None`` until one is supplied
        inputs: Input parameters
        outputs: Output values
        status: Execution status
        duration: Execution duration in seconds
        error: Error message (if failed)
        metadata: Additional metadata
    """
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    name: str = ""
    # The default is None, so the annotation must admit None.
    func: Optional[Callable[..., Any]] = None
    inputs: dict[str, Any] = field(default_factory=dict)
    outputs: dict[str, Any] = field(default_factory=dict)
    status: NodeStatus = NodeStatus.PENDING
    duration: float = 0.0
    error: Optional[str] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """
        Convert to a plain dict snapshot (``func`` is not included).

        The ``inputs``, ``outputs`` and ``metadata`` dicts are shallow-copied so
        that later changes to this node do not retroactively alter snapshots
        that were already handed out.

        Returns:
            Dict with the node's fields and ``status`` as its string value.
        """
        return {
            "id": self.id,
            "name": self.name,
            "inputs": dict(self.inputs),
            "outputs": dict(self.outputs),
            "status": self.status.value,
            "duration": self.duration,
            "error": self.error,
            "metadata": dict(self.metadata),
        }
@dataclass
class GraphEdge:
    """
    Directed dependency between two nodes of an execution graph.

    Attributes:
        source: Identifier of the upstream node
        target: Identifier of the downstream node
        condition: Optional predicate evaluated against the current inputs
            dict; when it returns a falsy value the target node is skipped
        metadata: Free-form extra data attached to the edge
    """
    source: str
    target: str
    condition: Optional[Callable[[dict[str, Any]], bool]] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the edge; the ``condition`` callable is deliberately left out."""
        return dict(
            source=self.source,
            target=self.target,
            metadata=self.metadata,
        )
@dataclass
class GraphResult:
    """
    Outcome of one graph execution.

    Attributes:
        id: Identifier of this run (12 hex characters from a random UUID)
        node_results: Per-node snapshot dicts, keyed by node name
        total_duration: Wall time of the whole run
        success: False once any node (or the run itself) failed
        error: Error message for a run-level failure, if any
    """
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    node_results: dict[str, dict[str, Any]] = field(default_factory=dict)
    total_duration: float = 0.0
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> dict[str, Any]:
        """Return the result's fields as a plain dict, in declaration order."""
        exported = ("id", "node_results", "total_duration", "success", "error")
        return {key: getattr(self, key) for key in exported}
class GraphOrchestrator:
    """
    DAG-based workflow orchestrator.

    Allows defining complex workflows as graphs:
    - Nodes represent tasks
    - Edges represent dependencies
    - Conditions for branching

    Every node without incoming edges is a start node and every node without
    outgoing edges is an end node. Both lists are recomputed from the edge set
    after each structural change, so isolated nodes and nodes orphaned by
    ``remove_node``/``remove_edge`` are still executed.

    Usage:
        orchestrator = GraphOrchestrator()

        # Add nodes
        orchestrator.add_node("step1", step1_func)
        orchestrator.add_node("step2", step2_func)

        # Add edges
        orchestrator.add_edge("step1", "step2")

        # Execute
        result = await orchestrator.execute({"input": "data"})
    """

    def __init__(self):
        """Initialize an empty graph orchestrator."""
        self._nodes: dict[str, GraphNode] = {}
        self._edges: list[GraphEdge] = []
        # Derived from _nodes/_edges; kept in sync by _refresh_boundaries().
        self._start_nodes: list[str] = []
        self._end_nodes: list[str] = []

    def _refresh_boundaries(self) -> None:
        """Recompute start (no incoming edge) and end (no outgoing edge) nodes."""
        targets = {edge.target for edge in self._edges}
        sources = {edge.source for edge in self._edges}
        self._start_nodes = [name for name in self._nodes if name not in targets]
        self._end_nodes = [name for name in self._nodes if name not in sources]

    def _predecessors(self) -> dict[str, list[str]]:
        """Map every node name to the sources of its incoming edges, in edge order."""
        preds: dict[str, list[str]] = {name: [] for name in self._nodes}
        for edge in self._edges:
            preds.setdefault(edge.target, []).append(edge.source)
        return preds

    def add_node(
        self,
        name: str,
        func: Callable[..., Any],
        **metadata
    ) -> GraphNode:
        """
        Add a node to the graph.

        Re-adding an existing name replaces the stored node; edges that
        reference the name are kept.

        Args:
            name: Node name
            func: Node function (sync or async)
            **metadata: Additional metadata

        Returns:
            Created GraphNode
        """
        node = GraphNode(
            name=name,
            func=func,
            metadata=metadata,
        )
        self._nodes[name] = node
        self._refresh_boundaries()
        return node

    def remove_node(self, name: str) -> bool:
        """
        Remove a node and every edge touching it.

        Args:
            name: Node name

        Returns:
            True if removed, False if not found
        """
        if name not in self._nodes:
            return False

        del self._nodes[name]
        self._edges = [
            e for e in self._edges
            if e.source != name and e.target != name
        ]
        # Former successors may now have no incoming edges and become start nodes.
        self._refresh_boundaries()
        return True

    def add_edge(
        self,
        source: str,
        target: str,
        condition: Optional[Callable[[dict[str, Any]], bool]] = None,
        **metadata
    ) -> GraphEdge:
        """
        Add an edge to the graph.

        Args:
            source: Source node name
            target: Target node name
            condition: Optional predicate over the current inputs; when it
                returns a falsy value the target node is skipped
            **metadata: Additional metadata

        Returns:
            Created GraphEdge

        Raises:
            ValueError: If either endpoint is not a known node.
        """
        if source not in self._nodes:
            raise ValueError(f"Source node not found: {source}")
        if target not in self._nodes:
            raise ValueError(f"Target node not found: {target}")

        edge = GraphEdge(
            source=source,
            target=target,
            condition=condition,
            metadata=metadata,
        )
        self._edges.append(edge)
        self._refresh_boundaries()
        return edge

    def remove_edge(self, source: str, target: str) -> bool:
        """
        Remove the first edge from ``source`` to ``target``.

        Args:
            source: Source node name
            target: Target node name

        Returns:
            True if removed, False if not found
        """
        for index, edge in enumerate(self._edges):
            if edge.source == source and edge.target == target:
                del self._edges[index]
                self._refresh_boundaries()
                return True
        return False

    def get_node(self, name: str) -> Optional[GraphNode]:
        """
        Get a node by name.

        Args:
            name: Node name

        Returns:
            GraphNode if found, None otherwise
        """
        return self._nodes.get(name)

    def list_nodes(self) -> list[str]:
        """
        List all nodes.

        Returns:
            List of node names in insertion order
        """
        return list(self._nodes.keys())

    def list_edges(self) -> list[tuple[str, str]]:
        """
        List all edges.

        Returns:
            List of (source, target) tuples
        """
        return [(e.source, e.target) for e in self._edges]

    async def execute(
        self,
        inputs: dict[str, Any],
        **metadata
    ) -> GraphResult:
        """
        Execute the graph level by level.

        Start nodes run first (concurrently); afterwards every node whose
        dependencies have all run is executed, until nothing is left. Each
        node's outputs are merged into a working copy of ``inputs`` that later
        nodes receive as keyword arguments (together with ``metadata``).

        Args:
            inputs: Input parameters (not modified)
            **metadata: Additional keyword arguments passed to every node

        Returns:
            GraphResult; ``success`` is False if any node failed or some nodes
            could not be scheduled (e.g. because of a cycle).
        """
        start_time = time.perf_counter()
        result = GraphResult()
        # Work on a copy: node outputs are merged into this dict during the
        # run, and the caller's dict must not be mutated.
        context: dict[str, Any] = dict(inputs or {})

        # Reset per-run node state
        for node in self._nodes.values():
            node.status = NodeStatus.PENDING
            node.outputs = {}
            node.error = None
            node.duration = 0.0

        try:
            await self._execute_nodes(self._start_nodes, context, result, metadata)

            # Execute remaining nodes in dependency order
            executed = set(self._start_nodes)
            while len(executed) < len(self._nodes):
                next_nodes = self._get_next_nodes(executed)
                if not next_nodes:
                    break
                await self._execute_nodes(next_nodes, context, result, metadata)
                executed.update(next_nodes)

            pending = [name for name in self._nodes if name not in executed]
            if pending:
                # Only nodes on, or downstream of, a cycle can remain unscheduled.
                result.success = False
                result.error = (
                    f"Nodes not executed (cycle or unmet dependencies): "
                    f"{', '.join(pending)}"
                )

        except Exception as e:
            result.success = False
            result.error = str(e)

        result.total_duration = time.perf_counter() - start_time
        return result

    async def _execute_nodes(
        self,
        node_names: list[str],
        inputs: dict[str, Any],
        result: GraphResult,
        metadata: dict[str, Any],
    ) -> None:
        """Run the named nodes concurrently and wait for all of them."""
        nodes = [self._nodes[name] for name in node_names if name in self._nodes]
        if not nodes:
            return

        outcomes = await asyncio.gather(
            *(self._execute_node(node, inputs, result, metadata) for node in nodes),
            return_exceptions=True,
        )
        for node, outcome in zip(nodes, outcomes):
            if isinstance(outcome, BaseException):
                # _execute_node records node errors itself; anything that still
                # escapes must not be silently dropped by return_exceptions=True.
                node.status = NodeStatus.FAILED
                node.error = str(outcome)
                result.success = False
                result.node_results[node.name] = node.to_dict()

    async def _execute_node(
        self,
        node: GraphNode,
        inputs: dict[str, Any],
        result: GraphResult,
        metadata: dict[str, Any],
    ) -> None:
        """Execute a single node, record its result and merge its outputs."""
        try:
            # Skip the node when any conditional edge leading into it rejects
            # the current inputs.
            skipped = any(
                not edge.condition(inputs)
                for edge in self._edges
                if edge.target == node.name and edge.condition
            )
        except Exception as e:
            # A raising condition fails this node rather than escaping into gather().
            node.status = NodeStatus.FAILED
            node.error = str(e)
            result.success = False
            result.node_results[node.name] = node.to_dict()
            return

        if skipped:
            node.status = NodeStatus.SKIPPED
            result.node_results[node.name] = node.to_dict()
            return

        node.status = NodeStatus.RUNNING
        start_time = time.perf_counter()
        try:
            if inspect.iscoroutinefunction(node.func):
                output = await node.func(**inputs, **metadata)
            else:
                # Bind the arguments here, on the event-loop thread, so the worker
                # thread never reads `inputs` while sibling nodes merge into it.
                call = functools.partial(node.func, **inputs, **metadata)
                output = await asyncio.get_running_loop().run_in_executor(None, call)

            node.outputs = output if isinstance(output, dict) else {"result": output}
            node.status = NodeStatus.COMPLETED
        except Exception as e:
            node.status = NodeStatus.FAILED
            node.error = str(e)
            result.success = False
        node.duration = time.perf_counter() - start_time

        result.node_results[node.name] = node.to_dict()

        # Make this node's outputs visible to downstream nodes
        inputs.update(node.outputs)

    def _get_next_nodes(self, executed: set[str]) -> list[str]:
        """
        Get nodes that are ready to run.

        A node is ready when it has not run yet, at least one predecessor has
        run, and all of its predecessors have run. Each node is returned at
        most once, even when several of its incoming edges qualify.
        """
        preds = self._predecessors()
        ready: list[str] = []
        considered: set[str] = set()

        for edge in self._edges:
            target = edge.target
            if edge.source not in executed or target in executed or target in considered:
                continue
            considered.add(target)
            if all(dep in executed for dep in preds[target]):
                ready.append(target)

        return ready

    def get_execution_order(self) -> list[str]:
        """
        Get topological execution order.

        Dependencies always precede their dependents; otherwise nodes follow
        insertion order. Uses an explicit stack, so deep graphs do not hit the
        recursion limit.

        Returns:
            List of node names in execution order

        Raises:
            ValueError: If the graph contains a cycle.
        """
        preds = self._predecessors()
        order: list[str] = []
        done: set[str] = set()
        on_path: set[str] = set()

        for root in self._nodes:
            if root in done:
                continue
            on_path.add(root)
            stack = [(root, iter(preds.get(root, ())))]
            while stack:
                name, deps = stack[-1]
                for dep in deps:
                    if dep in done:
                        continue
                    if dep in on_path:
                        # Back edge: dep is already on the current DFS path.
                        raise ValueError(f"Cycle detected involving node: {dep}")
                    on_path.add(dep)
                    stack.append((dep, iter(preds.get(dep, ()))))
                    break
                else:
                    # All dependencies emitted; emit this node (post-order).
                    stack.pop()
                    on_path.discard(name)
                    done.add(name)
                    order.append(name)

        return order

    def validate(self) -> bool:
        """
        Validate the graph.

        Returns:
            True if the graph is non-empty and acyclic, False otherwise
        """
        try:
            self.get_execution_order()
        except ValueError:
            return False

        # Start/end lists are derived from the edges, so for an acyclic graph
        # they are empty only when the graph has no nodes.
        return bool(self._start_nodes and self._end_nodes)

    def clear(self) -> None:
        """Remove all nodes and edges."""
        self._nodes.clear()
        self._edges.clear()
        self._start_nodes.clear()
        self._end_nodes.clear()