Coverage for agentos/orchestration/graph.py: 31%

182 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Graph Orchestrator for NexusAgent. 

3 

4DAG-based workflow orchestration. Allows defining 

5complex workflows as graphs with nodes and edges. 

6""" 

7 

8from __future__ import annotations 

9 

10import asyncio 

11import time 

12import uuid 

13from dataclasses import dataclass, field 

14from enum import Enum 

15from typing import Any, Callable, Optional 

16 

17 

class NodeStatus(str, Enum):
    """
    Lifecycle state of a graph node.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value (e.g. ``NodeStatus.PENDING == "pending"``).
    """

    # Not started yet.
    PENDING = "pending"
    # Function is currently being invoked.
    RUNNING = "running"
    # Function returned without raising.
    COMPLETED = "completed"
    # Function raised an exception.
    FAILED = "failed"
    # Not run because an incoming edge condition evaluated falsy.
    SKIPPED = "skipped"

25 

26 

@dataclass
class GraphNode:
    """
    Node in execution graph.

    Attributes:
        id: Unique identifier (12 hex characters taken from a random UUID)
        name: Node name
        func: Node function, or None when no function has been attached
        inputs: Input parameters
        outputs: Output values
        status: Execution status
        duration: Execution duration
        error: Error message (if failed)
        metadata: Additional metadata
    """
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    name: str = ""
    # The default is None, so the annotation must admit it.
    func: Optional[Callable[..., Any]] = None
    inputs: dict[str, Any] = field(default_factory=dict)
    outputs: dict[str, Any] = field(default_factory=dict)
    status: NodeStatus = NodeStatus.PENDING
    duration: float = 0.0
    error: Optional[str] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """
        Convert to dict.

        ``func`` is intentionally left out. The ``inputs``, ``outputs`` and
        ``metadata`` dicts are shallow-copied so the returned snapshot does
        not alias the node's own mutable state.

        Returns:
            Dict with the node's id, name, inputs, outputs, status value,
            duration, error and metadata.
        """
        return {
            "id": self.id,
            "name": self.name,
            "inputs": dict(self.inputs),
            "outputs": dict(self.outputs),
            "status": self.status.value,
            "duration": self.duration,
            "error": self.error,
            "metadata": dict(self.metadata),
        }

65 

66 

@dataclass
class GraphEdge:
    """
    Directed connection between two nodes of the execution graph.

    Attributes:
        source: ID of the node the edge leaves
        target: ID of the node the edge enters
        condition: Optional predicate taking a dict and returning a bool;
            None means the edge is unconditional
        metadata: Free-form extra data attached to the edge
    """
    source: str
    target: str
    condition: Optional[Callable[[dict[str, Any]], bool]] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return source, target and metadata as a dict (condition is omitted)."""
        exported_fields = ("source", "target", "metadata")
        return {key: getattr(self, key) for key in exported_fields}

90 

91 

@dataclass
class GraphResult:
    """
    Outcome of one graph run.

    Attributes:
        id: Unique identifier (12 hex characters taken from a random UUID)
        node_results: Per-node result dicts keyed by a string
        total_duration: Duration of the whole run
        success: False once the run is marked as failed, True otherwise
        error: Error message, or None when no error was recorded
    """
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    node_results: dict[str, dict[str, Any]] = field(default_factory=dict)
    total_duration: float = 0.0
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> dict[str, Any]:
        """Return all five fields as a plain dict, in declaration order."""
        return dict(
            id=self.id,
            node_results=self.node_results,
            total_duration=self.total_duration,
            success=self.success,
            error=self.error,
        )

119 

120 

class GraphOrchestrator:
    """
    DAG-based workflow orchestrator.

    Allows defining complex workflows as graphs:
    - Nodes represent tasks
    - Edges represent dependencies
    - Conditions for branching

    Start nodes (no incoming edge) and end nodes (no outgoing edge) are
    recomputed from the current nodes and edges after every structural
    change, so a node without any edge is both a start and an end node.

    Usage:
        orchestrator = GraphOrchestrator()

        # Add nodes
        orchestrator.add_node("step1", step1_func)
        orchestrator.add_node("step2", step2_func)

        # Add edges
        orchestrator.add_edge("step1", "step2")

        # Execute
        result = await orchestrator.execute({"input": "data"})
    """

    def __init__(self):
        """Initialize graph orchestrator."""
        self._nodes: dict[str, GraphNode] = {}
        self._edges: list[GraphEdge] = []
        self._start_nodes: list[str] = []
        self._end_nodes: list[str] = []

    def _refresh_boundaries(self) -> None:
        """Recompute the start/end node lists from the current nodes and edges."""
        with_incoming = {edge.target for edge in self._edges}
        with_outgoing = {edge.source for edge in self._edges}
        # Slice-assign so any outside reference to the lists stays valid.
        self._start_nodes[:] = [n for n in self._nodes if n not in with_incoming]
        self._end_nodes[:] = [n for n in self._nodes if n not in with_outgoing]

    def add_node(
        self,
        name: str,
        func: Callable[..., Any],
        **metadata
    ) -> GraphNode:
        """
        Add a node to the graph.

        Adding a name that already exists replaces the stored node; edges
        that reference the name are kept.

        Args:
            name: Node name
            func: Node function
            **metadata: Additional metadata

        Returns:
            Created GraphNode
        """
        node = GraphNode(
            name=name,
            func=func,
            metadata=metadata,
        )
        self._nodes[name] = node
        self._refresh_boundaries()
        return node

    def remove_node(self, name: str) -> bool:
        """
        Remove a node and every edge touching it.

        Args:
            name: Node name

        Returns:
            True if removed, False if not found
        """
        if name not in self._nodes:
            return False

        del self._nodes[name]
        self._edges = [
            e for e in self._edges
            if e.source != name and e.target != name
        ]
        # Neighbours of the removed node may have become start/end nodes.
        self._refresh_boundaries()
        return True

    def add_edge(
        self,
        source: str,
        target: str,
        condition: Optional[Callable[[dict[str, Any]], bool]] = None,
        **metadata
    ) -> GraphEdge:
        """
        Add an edge to the graph.

        Args:
            source: Source node name
            target: Target node name
            condition: Optional condition function
            **metadata: Additional metadata

        Returns:
            Created GraphEdge

        Raises:
            ValueError: If source or target is not a known node
        """
        if source not in self._nodes:
            raise ValueError(f"Source node not found: {source}")
        if target not in self._nodes:
            raise ValueError(f"Target node not found: {target}")

        edge = GraphEdge(
            source=source,
            target=target,
            condition=condition,
            metadata=metadata,
        )
        self._edges.append(edge)
        self._refresh_boundaries()
        return edge

    def remove_edge(self, source: str, target: str) -> bool:
        """
        Remove the first edge going from source to target.

        Args:
            source: Source node name
            target: Target node name

        Returns:
            True if removed, False if not found
        """
        for index, edge in enumerate(self._edges):
            if edge.source == source and edge.target == target:
                del self._edges[index]
                # The endpoints may have become start/end nodes again.
                self._refresh_boundaries()
                return True
        return False

    def get_node(self, name: str) -> Optional[GraphNode]:
        """
        Get a node by name.

        Args:
            name: Node name

        Returns:
            GraphNode if found, None otherwise
        """
        return self._nodes.get(name)

    def list_nodes(self) -> list[str]:
        """
        List all nodes.

        Returns:
            List of node names
        """
        return list(self._nodes)

    def list_edges(self) -> list[tuple[str, str]]:
        """
        List all edges.

        Returns:
            List of (source, target) tuples
        """
        return [(e.source, e.target) for e in self._edges]

    async def execute(
        self,
        inputs: dict[str, Any],
        **metadata
    ) -> GraphResult:
        """
        Execute the graph.

        Start nodes run first; afterwards every node whose dependencies have
        all been processed runs, wave by wave. Nodes in the same wave run
        concurrently. Node outputs are merged into a working copy of
        ``inputs`` for later nodes; the caller's dict is not modified.
        Nodes downstream of a failed or skipped node are still run.

        Args:
            inputs: Input parameters
            **metadata: Additional keyword arguments passed to every node
                function

        Returns:
            GraphResult
        """
        start_time = time.perf_counter()
        result = GraphResult()

        # Reset per-run node state so repeated runs do not leak old values.
        for node in self._nodes.values():
            node.status = NodeStatus.PENDING
            node.outputs = {}
            node.error = None
            node.duration = 0.0

        # Work on a copy: node outputs are merged into this dict as we go.
        context: dict[str, Any] = dict(inputs or {})

        try:
            wave = list(self._start_nodes)
            executed: set[str] = set()
            while wave:
                await self._execute_nodes(wave, context, result, metadata)
                executed.update(wave)
                # Empty when everything ran or the rest is unreachable.
                wave = self._get_next_nodes(executed)
        except Exception as e:
            result.success = False
            result.error = str(e)

        result.total_duration = time.perf_counter() - start_time
        return result

    async def _execute_nodes(
        self,
        node_names: list[str],
        inputs: dict[str, Any],
        result: GraphResult,
        metadata: dict[str, Any],
    ) -> None:
        """Execute multiple nodes concurrently; unknown names are ignored."""
        tasks = [
            self._execute_node(self._nodes[name], inputs, result, metadata)
            for name in node_names
            if name in self._nodes
        ]
        if not tasks:
            return

        # return_exceptions=True lets every sibling finish before we react.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                # Surface unexpected errors instead of silently dropping them.
                raise outcome

    @staticmethod
    def _mark_failed(node: GraphNode, result: GraphResult, message: str) -> None:
        """Record a node failure on both the node and the graph result."""
        node.status = NodeStatus.FAILED
        node.error = message
        result.success = False
        if result.error is None:
            # Keep the first failure as the graph-level error message.
            result.error = f"Node '{node.name}' failed: {message}"
        result.node_results[node.name] = node.to_dict()

    async def _execute_node(
        self,
        node: GraphNode,
        inputs: dict[str, Any],
        result: GraphResult,
        metadata: dict[str, Any],
    ) -> None:
        """
        Execute a single node.

        The node is skipped when any conditional incoming edge evaluates
        falsy. A raising condition or node function marks the node failed.
        """
        # Check conditions on incoming edges.
        try:
            for edge in self._edges:
                if edge.target == node.name and edge.condition:
                    if not edge.condition(inputs):
                        node.status = NodeStatus.SKIPPED
                        result.node_results[node.name] = node.to_dict()
                        return
        except Exception as e:
            self._mark_failed(node, result, f"Edge condition raised: {e}")
            return

        node.status = NodeStatus.RUNNING
        start_time = time.perf_counter()
        # Snapshot now: sibling nodes may update `inputs` while this one runs
        # (and sync functions unpack their kwargs in a worker thread).
        call_kwargs = dict(inputs)

        try:
            if asyncio.iscoroutinefunction(node.func):
                output = await node.func(**call_kwargs, **metadata)
            else:
                loop = asyncio.get_running_loop()
                output = await loop.run_in_executor(
                    None, lambda: node.func(**call_kwargs, **metadata)
                )
        except Exception as e:
            node.duration = time.perf_counter() - start_time
            self._mark_failed(node, result, str(e))
            return

        # Non-dict return values are wrapped under the "result" key.
        node.outputs = output if isinstance(output, dict) else {"result": output}
        node.status = NodeStatus.COMPLETED
        node.duration = time.perf_counter() - start_time
        result.node_results[node.name] = node.to_dict()

        # Update inputs for next nodes
        inputs.update(node.outputs)

    def _get_next_nodes(self, executed: set[str]) -> list[str]:
        """
        Get next nodes to execute.

        Returns each not-yet-executed node whose dependencies are all in
        ``executed``, exactly once, ordered by its first incoming edge.
        """
        dependencies: dict[str, list[str]] = {}
        for edge in self._edges:
            dependencies.setdefault(edge.target, []).append(edge.source)

        return [
            target
            for target, sources in dependencies.items()
            if target not in executed and all(s in executed for s in sources)
        ]

    def get_execution_order(self) -> list[str]:
        """
        Get topological execution order.

        Dependencies are listed before the nodes that depend on them. This
        method does not detect cycles; use validate() for that.

        Returns:
            List of node names in execution order
        """
        predecessors: dict[str, list[str]] = {}
        for edge in self._edges:
            predecessors.setdefault(edge.target, []).append(edge.source)

        order: list[str] = []
        visited: set[str] = set()

        def visit(node_name: str) -> None:
            if node_name in visited:
                return
            visited.add(node_name)

            # Visit dependencies first
            for dependency in predecessors.get(node_name, ()):
                visit(dependency)

            order.append(node_name)

        for node_name in self._nodes:
            visit(node_name)

        return order

    def _has_cycle(self) -> bool:
        """Return True when the edges contain a cycle (Kahn's algorithm)."""
        indegree: dict[str, int] = {name: 0 for name in self._nodes}
        successors: dict[str, list[str]] = {}
        for edge in self._edges:
            indegree.setdefault(edge.source, 0)
            indegree[edge.target] = indegree.get(edge.target, 0) + 1
            successors.setdefault(edge.source, []).append(edge.target)

        ready = [name for name, degree in indegree.items() if degree == 0]
        processed = 0
        while ready:
            current = ready.pop()
            processed += 1
            for following in successors.get(current, ()):
                indegree[following] -= 1
                if indegree[following] == 0:
                    ready.append(following)

        # Nodes on a cycle never reach in-degree zero.
        return processed < len(indegree)

    def validate(self) -> bool:
        """
        Validate the graph.

        Returns:
            True if the graph is acyclic and has at least one start node
            and one end node, False otherwise
        """
        if self._has_cycle():
            return False

        return bool(self._start_nodes) and bool(self._end_nodes)

    def clear(self) -> None:
        """Clear the graph."""
        self._nodes.clear()
        self._edges.clear()
        self._start_nodes.clear()
        self._end_nodes.clear()

474 self._end_nodes.clear()