Coverage for src / tracekit / workflow / dag.py: 100%

133 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Directed Acyclic Graph (DAG) execution for complex workflows. 

2 

3This module provides DAG-based workflow execution with automatic dependency 

4resolution and parallel execution of independent tasks. 

5""" 

6 

7from __future__ import annotations 

8 

9from collections import defaultdict, deque 

10from concurrent.futures import ThreadPoolExecutor, as_completed 

11from dataclasses import dataclass, field 

12from typing import TYPE_CHECKING, Any 

13 

14from tracekit.core.exceptions import AnalysisError 

15 

16if TYPE_CHECKING: 

17 from collections.abc import Callable 

18 

19 

@dataclass
class TaskNode:
    """A single unit of work inside a workflow DAG.

    Each node wraps a callable plus the names of the tasks whose output it
    needs. Execution bookkeeping (``result``/``completed``) is filled in by
    :class:`WorkflowDAG` after the node's function has run.

    Attributes:
        name: Unique task name.
        func: Callable function to execute.
        depends_on: List of task names this task depends on.
        result: Computed result (set after execution).
        completed: Whether task has been executed.

    Example:
        >>> def compute_fft(state):
        ...     return {'fft': np.fft.fft(state['trace'])}
        >>> node = TaskNode(name='fft', func=compute_fft, depends_on=['load'])

    References:
        API-013: DAG Execution
    """

    name: str
    func: Callable[[dict[str, Any]], Any]
    depends_on: list[str] = field(default_factory=list)
    result: Any = None
    completed: bool = False

45 

46 

class WorkflowDAG:
    """Directed Acyclic Graph for workflow execution.

    Manages task dependencies and executes tasks in topological order
    with automatic parallelization of independent tasks.

    Example:
        >>> from tracekit.workflow.dag import WorkflowDAG
        >>> dag = WorkflowDAG()
        >>> dag.add_task('load', load_trace, depends_on=[])
        >>> dag.add_task('fft', compute_fft, depends_on=['load'])
        >>> dag.add_task('rise_time', compute_rise_time, depends_on=['load'])
        >>> dag.add_task('enob', compute_enob, depends_on=['fft', 'rise_time'])
        >>> results = dag.execute()

    References:
        API-013: DAG Execution for Complex Workflows
    """

    def __init__(self) -> None:
        """Initialize empty DAG."""
        self.tasks: dict[str, TaskNode] = {}
        # Forward edges: task name -> names of tasks that depend on it.
        self._adjacency: dict[str, list[str]] = defaultdict(list)

    def add_task(
        self,
        name: str,
        func: Callable[[dict[str, Any]], Any],
        depends_on: list[str] | None = None,
    ) -> None:
        """Add a task to the DAG.

        Args:
            name: Unique name for the task.
            func: Function to execute. Should accept state dict and return result.
            depends_on: List of task names this task depends on.

        Raises:
            AnalysisError: If task name already exists, a dependency is
                missing, or adding the task would create a cycle.

        Example:
            >>> dag.add_task('fft', compute_fft, depends_on=['load'])

        References:
            API-013: DAG Execution
        """
        if name in self.tasks:
            raise AnalysisError(f"Task '{name}' already exists in DAG")

        depends_on = depends_on or []

        # Dependencies must already be registered; this also means edges only
        # ever point from existing nodes to the new node.
        for dep in depends_on:
            if dep not in self.tasks:
                raise AnalysisError(f"Dependency '{dep}' not found for task '{name}'")

        self.tasks[name] = TaskNode(name=name, func=func, depends_on=depends_on)

        for dep in depends_on:
            self._adjacency[dep].append(name)

        # Defensive cycle check. Because dependencies must pre-exist, a new
        # task cannot normally close a cycle, but we keep the check (with
        # rollback) to guard against future mutation paths.
        if self._has_cycle():
            del self.tasks[name]
            for dep in depends_on:
                self._adjacency[dep].remove(name)
            raise AnalysisError(f"Adding task '{name}' would create a cycle in DAG")

    def _has_cycle(self) -> bool:
        """Check if DAG contains a cycle.

        Uses iterative-restart DFS with a recursion stack: a back edge to a
        node currently on the stack indicates a cycle.

        Returns:
            True if cycle detected.
        """
        visited: set[str] = set()
        rec_stack: set[str] = set()

        def dfs(node: str) -> bool:
            visited.add(node)
            rec_stack.add(node)

            for neighbor in self._adjacency.get(node, []):
                if neighbor not in visited:
                    if dfs(neighbor):
                        return True
                elif neighbor in rec_stack:
                    # Back edge to an ancestor on the current path -> cycle.
                    return True

            rec_stack.remove(node)
            return False

        return any(task_name not in visited and dfs(task_name) for task_name in self.tasks)

    def _topological_sort(self) -> list[list[str]]:
        """Compute topological sort grouped by execution level (Kahn's algorithm).

        Tasks at the same level can be executed in parallel.

        Returns:
            List of levels, where each level is a list of task names.

        Raises:
            AnalysisError: If DAG contains a cycle or unreachable tasks.

        Example:
            >>> levels = dag._topological_sort()
            >>> # [[load], [fft, rise_time], [enob]]
        """
        in_degree = {name: len(task.depends_on) for name, task in self.tasks.items()}

        levels: list[list[str]] = []
        # Level 0 = tasks with no dependencies.
        queue = deque([name for name, degree in in_degree.items() if degree == 0])

        while queue:
            # Everything currently in the queue has all dependencies satisfied,
            # so the whole batch forms one parallelizable level.
            current_level = list(queue)
            levels.append(current_level)
            queue.clear()

            for task_name in current_level:
                for dependent in self._adjacency.get(task_name, []):
                    in_degree[dependent] -= 1
                    if in_degree[dependent] == 0:
                        queue.append(dependent)

        # If a cycle exists, its members never reach in-degree 0 and are
        # missing from the levels.
        total_tasks = sum(len(level) for level in levels)
        if total_tasks != len(self.tasks):
            raise AnalysisError("DAG contains a cycle or unreachable tasks")

        return levels

    def execute(
        self,
        *,
        initial_state: dict[str, Any] | None = None,
        parallel: bool = True,
        max_workers: int | None = None,
    ) -> dict[str, Any]:
        """Execute the workflow DAG.

        Tasks are executed in topological order with automatic parallelization
        of independent tasks at each level.

        Args:
            initial_state: Initial state dictionary passed to first tasks.
            parallel: Enable parallel execution of independent tasks.
            max_workers: Maximum number of parallel workers. None uses CPU count.

        Returns:
            Final state dictionary containing all task results.

        Raises:
            AnalysisError: If any task raises, or the DAG is malformed.

        Example:
            >>> results = dag.execute(initial_state={'trace': trace_data})
            >>> print(results['enob'])

        References:
            API-013: DAG Execution
        """
        if not self.tasks:
            return initial_state or {}

        state = initial_state or {}
        levels = self._topological_sort()

        for level in levels:
            if parallel and len(level) > 1:
                self._execute_level_parallel(level, state, max_workers)
            else:
                self._execute_level_sequential(level, state)

        return state

    def _apply_result(self, task_name: str, result: Any, state: dict[str, Any]) -> None:
        """Record a task result and merge it into the shared state.

        Dict results are merged key-by-key; any other result is stored under
        the task's own name.

        Args:
            task_name: Name of the completed task.
            result: Value returned by the task function.
            state: Shared state dictionary to update.
        """
        task = self.tasks[task_name]
        task.result = result
        task.completed = True

        if isinstance(result, dict):
            state.update(result)
        else:
            state[task_name] = result

    def _execute_level_sequential(self, level: list[str], state: dict[str, Any]) -> None:
        """Execute a level of tasks sequentially.

        Args:
            level: List of task names to execute.
            state: Shared state dictionary.

        Raises:
            AnalysisError: If task execution fails.
        """
        for task_name in level:
            task = self.tasks[task_name]
            try:
                result = task.func(state)
            except Exception as e:
                raise AnalysisError(f"Task '{task_name}' failed: {e}") from e
            self._apply_result(task_name, result, state)

    def _execute_level_parallel(
        self, level: list[str], state: dict[str, Any], max_workers: int | None
    ) -> None:
        """Execute a level of tasks in parallel.

        All tasks in the level run concurrently against a *stable* state dict:
        results are collected first and only merged into ``state`` after every
        task in the level has finished. (Previously results were merged as
        futures completed, so sibling tasks could observe — or race with —
        mid-level mutations of the shared dict, and colliding dict keys were
        resolved nondeterministically. Deferring the merge, and applying it in
        level order, removes the race and makes key collisions deterministic.)

        Args:
            level: List of task names to execute in parallel.
            state: Shared state dictionary.
            max_workers: Maximum number of workers.

        Raises:
            AnalysisError: If task execution fails.
        """
        pending: dict[str, Any] = {}
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_task = {executor.submit(self.tasks[name].func, state): name for name in level}

            for future in as_completed(future_to_task):
                task_name = future_to_task[future]
                try:
                    pending[task_name] = future.result()
                except Exception as e:
                    raise AnalysisError(f"Task '{task_name}' failed: {e}") from e

        # Merge in level order so the outcome is independent of completion order.
        for task_name in level:
            self._apply_result(task_name, pending[task_name], state)

    def get_result(self, task_name: str) -> Any:
        """Get result from a completed task.

        Args:
            task_name: Name of the task.

        Returns:
            Task result.

        Raises:
            AnalysisError: If task doesn't exist or hasn't been executed.

        Example:
            >>> fft_result = dag.get_result('fft')
        """
        if task_name not in self.tasks:
            raise AnalysisError(f"Task '{task_name}' not found in DAG")

        task = self.tasks[task_name]
        if not task.completed:
            raise AnalysisError(f"Task '{task_name}' has not been executed yet")

        return task.result

    def reset(self) -> None:
        """Reset all task completion states.

        Allows re-execution of the DAG with different initial state.

        Example:
            >>> dag.reset()
            >>> results = dag.execute(initial_state={'trace': new_trace})
        """
        for task in self.tasks.values():
            task.completed = False
            task.result = None

    def to_graphviz(self) -> str:
        """Generate Graphviz DOT representation of the DAG.

        Completed tasks are drawn bold/green; pending tasks light blue.

        Returns:
            DOT format string for visualization.

        Example:
            >>> dot = dag.to_graphviz()
            >>> with open('workflow.dot', 'w') as f:
            ...     f.write(dot)
            >>> # Then: dot -Tpng workflow.dot -o workflow.png

        References:
            API-013: DAG Execution
        """
        lines = ["digraph WorkflowDAG {", "  rankdir=LR;", "  node [shape=box];", ""]

        # Add nodes
        for task_name, task in self.tasks.items():
            style = "filled,bold" if task.completed else "filled"
            color = "lightgreen" if task.completed else "lightblue"
            lines.append(f'  "{task_name}" [style="{style}", fillcolor="{color}"];')

        lines.append("")

        # Add edges (dependency -> dependent)
        for task_name, task in self.tasks.items():
            for dep in task.depends_on:
                lines.append(f'  "{dep}" -> "{task_name}";')

        lines.append("}")
        return "\n".join(lines)

    def __repr__(self) -> str:
        """String representation of DAG."""
        return f"WorkflowDAG(tasks={len(self.tasks)})"

    def __str__(self) -> str:
        """Detailed string representation."""
        lines = [f"WorkflowDAG with {len(self.tasks)} tasks:"]
        for task_name, task in self.tasks.items():
            deps = ", ".join(task.depends_on) if task.depends_on else "none"
            status = "✓" if task.completed else "○"
            lines.append(f"  {status} {task_name} (depends on: {deps})")
        return "\n".join(lines)

375 

376 

# Public API of this module: the task node type and the DAG executor.
__all__ = ["TaskNode", "WorkflowDAG"]