Coverage for src / tracekit / workflow / dag.py: 100%
133 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Directed Acyclic Graph (DAG) execution for complex workflows.
3This module provides DAG-based workflow execution with automatic dependency
4resolution and parallel execution of independent tasks.
5"""
7from __future__ import annotations
9from collections import defaultdict, deque
10from concurrent.futures import ThreadPoolExecutor, as_completed
11from dataclasses import dataclass, field
12from typing import TYPE_CHECKING, Any
14from tracekit.core.exceptions import AnalysisError
16if TYPE_CHECKING:
17 from collections.abc import Callable
@dataclass
class TaskNode:
    """A single unit of work inside a workflow DAG.

    Each node wraps a callable plus the bookkeeping needed by the
    scheduler: which tasks must finish first, and — after execution —
    the produced result and a completion flag.

    Attributes:
        name: Unique task name used to address this node in the DAG.
        func: Callable invoked with the shared state dict; its return
            value becomes the task result.
        depends_on: Names of tasks that must complete before this one.
        result: Result produced by ``func`` (``None`` until executed).
        completed: ``True`` once the task has been executed.

    Example:
        >>> def compute_fft(state):
        ...     return {'fft': np.fft.fft(state['trace'])}
        >>> node = TaskNode(name='fft', func=compute_fft, depends_on=['load'])

    References:
        API-013: DAG Execution
    """

    name: str
    func: Callable[[dict[str, Any]], Any]
    depends_on: list[str] = field(default_factory=list)
    result: Any = None
    completed: bool = False
class WorkflowDAG:
    """Directed Acyclic Graph for workflow execution.

    Manages task dependencies and executes tasks in topological order
    with automatic parallelization of independent tasks.

    Example:
        >>> from tracekit.workflow.dag import WorkflowDAG
        >>> dag = WorkflowDAG()
        >>> dag.add_task('load', load_trace, depends_on=[])
        >>> dag.add_task('fft', compute_fft, depends_on=['load'])
        >>> dag.add_task('rise_time', compute_rise_time, depends_on=['load'])
        >>> dag.add_task('enob', compute_enob, depends_on=['fft', 'rise_time'])
        >>> results = dag.execute()

    References:
        API-013: DAG Execution for Complex Workflows
    """

    def __init__(self) -> None:
        """Initialize empty DAG."""
        # Task registry keyed by unique task name.
        self.tasks: dict[str, TaskNode] = {}
        # Forward edges: dependency name -> names of tasks that depend on it.
        self._adjacency: dict[str, list[str]] = defaultdict(list)

    def add_task(
        self,
        name: str,
        func: Callable[[dict[str, Any]], Any],
        depends_on: list[str] | None = None,
    ) -> None:
        """Add a task to the DAG.

        Args:
            name: Unique name for the task.
            func: Function to execute. Should accept state dict and return result.
            depends_on: List of task names this task depends on.

        Raises:
            AnalysisError: If task name already exists, a dependency is
                missing, or adding the task would create a cycle.

        Example:
            >>> dag.add_task('fft', compute_fft, depends_on=['load'])

        References:
            API-013: DAG Execution
        """
        if name in self.tasks:
            raise AnalysisError(f"Task '{name}' already exists in DAG")

        # Copy the caller's list so later external mutation cannot
        # silently alter this node's recorded dependencies.
        depends_on = list(depends_on) if depends_on else []

        # Verify dependencies exist before touching any internal structure.
        for dep in depends_on:
            if dep not in self.tasks:
                raise AnalysisError(f"Dependency '{dep}' not found for task '{name}'")

        # Create task node and register it.
        task = TaskNode(name=name, func=func, depends_on=depends_on)
        self.tasks[name] = task

        # Update adjacency list (edges run dependency -> dependent).
        for dep in depends_on:
            self._adjacency[dep].append(name)

        # Check for cycles. Defensive: since edges only point *into* the
        # new node, a fresh insertion cannot create one, but the check
        # keeps the acyclic invariant explicit and future-proof.
        if self._has_cycle():
            # Rollback - remove task
            del self.tasks[name]
            for dep in depends_on:
                self._adjacency[dep].remove(name)
            raise AnalysisError(f"Adding task '{name}' would create a cycle in DAG")

    def _has_cycle(self) -> bool:
        """Check if DAG contains a cycle.

        Uses iterative-deepening DFS with a recursion stack: a back edge
        to a node currently on the stack indicates a cycle.

        Returns:
            True if cycle detected.
        """
        visited: set[str] = set()
        rec_stack: set[str] = set()

        def dfs(node: str) -> bool:
            visited.add(node)
            rec_stack.add(node)

            for neighbor in self._adjacency.get(node, []):
                if neighbor not in visited:
                    if dfs(neighbor):
                        return True
                elif neighbor in rec_stack:
                    # Back edge to an ancestor on the current path -> cycle.
                    return True

            rec_stack.remove(node)
            return False

        # Start a DFS from every not-yet-visited node (graph may be a forest).
        return any(task_name not in visited and dfs(task_name) for task_name in self.tasks)

    def _topological_sort(self) -> list[list[str]]:
        """Compute topological sort grouped by execution level (Kahn's algorithm).

        Tasks at the same level can be executed in parallel.

        Returns:
            List of levels, where each level is a list of task names.

        Raises:
            AnalysisError: If DAG contains a cycle or unreachable tasks.

        Example:
            >>> levels = dag._topological_sort()
            >>> # [[load], [fft, rise_time], [enob]]
        """
        # In-degree = number of unmet dependencies for each node.
        in_degree = {name: len(task.depends_on) for name, task in self.tasks.items()}

        # Seed the queue with nodes that have no dependencies (level 0).
        levels: list[list[str]] = []
        queue = deque([name for name, degree in in_degree.items() if degree == 0])

        while queue:
            # Everything currently in the queue has all dependencies met,
            # so the whole batch forms one parallelizable level.
            current_level = list(queue)
            levels.append(current_level)
            queue.clear()

            # Releasing this level may satisfy dependents' last dependency.
            for task_name in current_level:
                for dependent in self._adjacency.get(task_name, []):
                    in_degree[dependent] -= 1
                    if in_degree[dependent] == 0:
                        queue.append(dependent)

        # If any node never reached in-degree 0, it sits on a cycle.
        total_tasks = sum(len(level) for level in levels)
        if total_tasks != len(self.tasks):
            raise AnalysisError("DAG contains a cycle or unreachable tasks")

        return levels

    def execute(
        self,
        *,
        initial_state: dict[str, Any] | None = None,
        parallel: bool = True,
        max_workers: int | None = None,
    ) -> dict[str, Any]:
        """Execute the workflow DAG.

        Tasks are executed in topological order with automatic parallelization
        of independent tasks at each level. Note: if ``initial_state`` is
        provided, that same dict is updated in place and returned.

        Args:
            initial_state: Initial state dictionary passed to first tasks.
            parallel: Enable parallel execution of independent tasks.
            max_workers: Maximum number of parallel workers. None uses CPU count.

        Returns:
            Final state dictionary containing all task results.

        Raises:
            AnalysisError: If any task fails or the DAG is cyclic.

        Example:
            >>> results = dag.execute(initial_state={'trace': trace_data})
            >>> print(results['enob'])

        References:
            API-013: DAG Execution
        """
        if not self.tasks:
            return initial_state or {}

        state = initial_state or {}
        levels = self._topological_sort()

        for level in levels:
            if parallel and len(level) > 1:
                # Independent siblings: run concurrently.
                self._execute_level_parallel(level, state, max_workers)
            else:
                # Single task (or parallelism disabled): run inline.
                self._execute_level_sequential(level, state)

        return state

    def _execute_level_sequential(self, level: list[str], state: dict[str, Any]) -> None:
        """Execute a level of tasks sequentially.

        Args:
            level: List of task names to execute.
            state: Shared state dictionary.

        Raises:
            AnalysisError: If task execution fails.
        """
        for task_name in level:
            task = self.tasks[task_name]
            try:
                result = task.func(state)
            except Exception as e:
                raise AnalysisError(f"Task '{task_name}' failed: {e}") from e
            task.result = result
            task.completed = True
            self._merge_result(task_name, result, state)

    def _execute_level_parallel(
        self, level: list[str], state: dict[str, Any], max_workers: int | None
    ) -> None:
        """Execute a level of tasks in parallel.

        State updates are buffered and applied only after *every* task in
        the level has finished. Merging eagerly inside the as_completed
        loop would mutate the shared dict while sibling tasks are still
        running and reading it (a race; concurrent dict iteration can
        raise RuntimeError).

        Args:
            level: List of task names to execute in parallel.
            state: Shared state dictionary.
            max_workers: Maximum number of workers.

        Raises:
            AnalysisError: If task execution fails.
        """
        pending_merges: list[tuple[str, Any]] = []

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all tasks in the level; each reads the same state dict.
            future_to_task = {executor.submit(self.tasks[name].func, state): name for name in level}

            # Collect results as tasks finish, but defer state mutation.
            for future in as_completed(future_to_task):
                task_name = future_to_task[future]
                try:
                    result = future.result()
                except Exception as e:
                    raise AnalysisError(f"Task '{task_name}' failed: {e}") from e

                task = self.tasks[task_name]
                task.result = result
                task.completed = True
                pending_merges.append((task_name, result))

        # All workers have exited; it is now safe to mutate the shared state.
        for task_name, result in pending_merges:
            self._merge_result(task_name, result, state)

    def _merge_result(self, task_name: str, result: Any, state: dict[str, Any]) -> None:
        """Merge one task result into the shared state dict.

        Dict results are flattened into the state; any other result is
        stored under the task's own name.

        Args:
            task_name: Name of the task that produced ``result``.
            result: Value returned by the task function.
            state: Shared state dictionary to update in place.
        """
        if isinstance(result, dict):
            state.update(result)
        else:
            state[task_name] = result

    def get_result(self, task_name: str) -> Any:
        """Get result from a completed task.

        Args:
            task_name: Name of the task.

        Returns:
            Task result.

        Raises:
            AnalysisError: If task doesn't exist or hasn't been executed.

        Example:
            >>> fft_result = dag.get_result('fft')
        """
        if task_name not in self.tasks:
            raise AnalysisError(f"Task '{task_name}' not found in DAG")

        task = self.tasks[task_name]
        if not task.completed:
            raise AnalysisError(f"Task '{task_name}' has not been executed yet")

        return task.result

    def reset(self) -> None:
        """Reset all task completion states.

        Allows re-execution of the DAG with different initial state.

        Example:
            >>> dag.reset()
            >>> results = dag.execute(initial_state={'trace': new_trace})
        """
        for task in self.tasks.values():
            task.completed = False
            task.result = None

    def to_graphviz(self) -> str:
        """Generate Graphviz DOT representation of the DAG.

        Completed tasks are rendered bold/green, pending ones blue.

        Returns:
            DOT format string for visualization.

        Example:
            >>> dot = dag.to_graphviz()
            >>> with open('workflow.dot', 'w') as f:
            ...     f.write(dot)
            >>> # Then: dot -Tpng workflow.dot -o workflow.png

        References:
            API-013: DAG Execution
        """
        lines = ["digraph WorkflowDAG {", "  rankdir=LR;", "  node [shape=box];", ""]

        # Add nodes
        for task_name, task in self.tasks.items():
            style = "filled,bold" if task.completed else "filled"
            color = "lightgreen" if task.completed else "lightblue"
            lines.append(f'  "{task_name}" [style="{style}", fillcolor="{color}"];')

        lines.append("")

        # Add edges (dependency -> dependent)
        for task_name, task in self.tasks.items():
            for dep in task.depends_on:
                lines.append(f'  "{dep}" -> "{task_name}";')

        lines.append("}")
        return "\n".join(lines)

    def __repr__(self) -> str:
        """String representation of DAG."""
        return f"WorkflowDAG(tasks={len(self.tasks)})"

    def __str__(self) -> str:
        """Detailed string representation."""
        lines = [f"WorkflowDAG with {len(self.tasks)} tasks:"]
        for task_name, task in self.tasks.items():
            deps = ", ".join(task.depends_on) if task.depends_on else "none"
            status = "✓" if task.completed else "○"
            lines.append(f"  {status} {task_name} (depends on: {deps})")
        return "\n".join(lines)
377__all__ = ["TaskNode", "WorkflowDAG"]