Coverage for smart_pipeline / solvers.py: 100%
181 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-02 13:46 +0200
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-02 13:46 +0200
1import logging
2from collections import defaultdict, deque
3from dataclasses import dataclass
4from typing import List, Dict, Any, Set, Protocol, Optional
6from .models import Step
7from .executor import StepExecutor
9# Initialize module-level logger
10logger = logging.getLogger(__name__)
class Solver(Protocol):
    """Structural interface shared by the pipeline solvers.

    Any object exposing a ``solve`` method with this signature satisfies
    the protocol; no inheritance is required.
    """

    def solve(self, steps: List[Step], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Process ``steps`` using ``inputs`` and return a name -> value mapping."""
        ...
class DAGSolver:
    """
    Standard Topological Sort Solver.

    Orders the steps so that every producer runs before its consumers and
    executes each step exactly once. Pipelines containing a dependency cycle
    are rejected with ``ValueError``.
    """

    def solve(self, steps: List[Step], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Execute ``steps`` once each, in dependency order.

        Args:
            steps: Steps to execute.
            inputs: Initial values; only a shallow copy is handed to the steps.

        Returns:
            The working copy of ``inputs`` after every step was passed to
            ``StepExecutor.run_step``.

        Raises:
            ValueError: If the dependency graph of ``steps`` has a cycle.
        """
        logger.info("DAGSolver started.")
        ordered = self._topological_sort(steps, set(inputs))
        step_names = [s.name for s in ordered]
        logger.debug(f"Topological sort order: {step_names}")

        memory = inputs.copy()
        for step in ordered:
            StepExecutor.run_step(step, memory)
        return memory

    def _topological_sort(self, steps: List[Step], input_keys: Set[str]) -> List[Step]:
        """Return ``steps`` sorted with Kahn's algorithm.

        Raises:
            ValueError: If not every step could be scheduled, i.e. some steps
                never reached an in-degree of zero (a cycle).
        """
        graph, remaining = _build_dependency_graph(
            steps, input_keys, _map_producers(steps)
        )

        # Start with every step that has no unmet dependency.
        ready = deque(step for step, count in remaining.items() if count == 0)
        ordered = []

        while ready:
            node = ready.popleft()
            ordered.append(node)

            # Scheduling `node` satisfies one dependency of each successor.
            for successor in graph[node]:
                remaining[successor] -= 1
                if remaining[successor] == 0:
                    ready.append(successor)

        if len(ordered) == len(steps):
            return ordered

        # Steps left over never became ready, so they sit on a cycle.
        logger.error("Cycle detected in DAGSolver.")
        raise ValueError("Cycle detected in pipeline. Use HybridSolver or IterativeSolver.")
@dataclass
class IterativeSolver:
    """
    Solves systems with feedback loops.

    The steps are executed repeatedly in a fixed sequence. After every pass a
    residual (the absolute change of the monitored numeric variables since
    the previous pass) is computed; iteration stops as soon as the residual
    drops below ``tolerance`` or after ``max_iterations`` passes.
    """
    # Upper bound on the number of passes over the steps.
    max_iterations: int = 100
    # A pass whose residual is strictly below this value ends the iteration.
    tolerance: float = 1e-6
    # If set, only this variable is monitored for convergence; otherwise all
    # output names of the given steps are monitored.
    target_var: Optional[str] = None
    # Optional list of step names fixing the order inside one pass.
    execution_order: Optional[List[str]] = None

    def solve(self, steps: List[Step], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``steps`` repeatedly until the residual falls below ``tolerance``.

        Args:
            steps: Steps forming the loop.
            inputs: Initial values. A shallow copy is used as working memory.

        Returns:
            The working memory. The list of per-pass residuals of this call is
            appended to the list stored under ``'residual_history'``. Because
            the copy is shallow, a list already present under that key in
            ``inputs`` is extended in place.
        """
        memory = inputs.copy()
        residuals = []

        run_sequence = self._determine_execution_order(steps)
        logger.info(f"IterativeSolver started. Sequence: {[s.name for s in run_sequence]}")

        # Identify variables produced by these steps (for auto-convergence)
        produced_vars = set()
        for s in steps:
            produced_vars.update(s.resolve_output_names())

        for i in range(self.max_iterations):
            # Snapshot state for convergence check
            prev_state = {k: memory[k] for k in produced_vars if k in memory}

            # Execute
            for step in run_sequence:
                StepExecutor.run_step(step, memory)

            # Check Convergence
            diff = self._calculate_residual(prev_state, memory, produced_vars)
            residuals.append(diff)

            # Only break if we actually calculated a numeric difference (not inf)
            if diff != float('inf') and diff < self.tolerance:
                logger.info(f"Converged at iteration {i+1} with residual {diff:.6e}")
                break

            logger.debug(f"Iteration {i+1}: residual {diff:.6e}")
        else:
            # With max_iterations <= 0 no pass ran, so there is no residual to
            # report; fall back to inf instead of indexing an empty list.
            last_residual = residuals[-1] if residuals else float('inf')
            logger.warning(f"Reached max_iterations ({self.max_iterations}) without converging. Last residual: {last_residual:.6e}")

        # Store residuals (append to potentially existing history from other cycles)
        memory.setdefault('residual_history', []).append(residuals)
        return memory

    @staticmethod
    def _numeric_change(previous: Any, current: Any) -> Optional[float]:
        """Return ``abs(current - previous)``, ``inf`` for NaN, ``None`` if not both int/float."""
        if not (isinstance(previous, (int, float)) and isinstance(current, (int, float))):
            return None
        change = abs(current - previous)
        # NaN is the only value unequal to itself. It must count as "not
        # converged": max() would otherwise silently drop it and report 0.
        if change != change:
            return float('inf')
        return change

    def _calculate_residual(self, prev_state: Dict, current_memory: Dict, produced_vars: Set[str]) -> float:
        """
        Calculates the maximum change in variables.

        With ``target_var`` set, only that variable is compared. Otherwise the
        largest change over all ``produced_vars`` that are numeric (int/float)
        in both states is returned.

        Returns:
            The residual, or ``inf`` when no numeric comparison was possible
            or a compared value turned into NaN.
        """
        if self.target_var:
            change = self._numeric_change(
                prev_state.get(self.target_var),
                current_memory.get(self.target_var),
            )
            return float('inf') if change is None else change

        max_diff = 0.0
        numeric_vars_found = False

        for k in produced_vars:
            # Strictly require both to be numeric
            change = self._numeric_change(prev_state.get(k), current_memory.get(k))
            if change is None:
                continue
            max_diff = max(max_diff, change)
            numeric_vars_found = True

        # If no numeric variables updated, we can't judge convergence numerically.
        return max_diff if numeric_vars_found else float('inf')

    def _determine_execution_order(self, steps: List[Step]) -> List[Step]:
        """Return the steps in the order used for each pass.

        Without ``execution_order`` the given list is returned unchanged.
        Otherwise the steps are looked up by name in that order; names that
        match no step are skipped, and steps that are not named are left out.
        """
        if not self.execution_order:
            return steps

        step_map = {s.name: s for s in steps}
        return [step_map[name] for name in self.execution_order if name in step_map]
class HybridSolver:
    """
    Advanced solver that automatically decomposes the pipeline into
    Linear (DAG) and Iterative (Cyclic) components (Strongly Connected Components).

    Components are executed in topological order. A component made of a
    single step that does not depend on itself is run once; every other
    component is delegated to an ``IterativeSolver``.
    """

    def __init__(self, max_iterations: int = 100, tolerance: float = 1e-6):
        """Store the settings forwarded to the ``IterativeSolver`` of each cyclic block."""
        self.max_iterations = max_iterations
        self.tolerance = tolerance

    def solve(self, steps: List[Step], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Execute ``steps``, iterating only over the cyclic parts.

        Args:
            steps: Steps of the pipeline.
            inputs: Initial values; a shallow copy is used as working memory.

        Returns:
            The working memory after all execution blocks were processed.
        """
        logger.info("HybridSolver started.")

        # 1. Producer -> consumer adjacency (the in-degrees are not needed here).
        adj_list, _ = _build_dependency_graph(
            steps, set(inputs.keys()), _map_producers(steps)
        )

        # 2. Strongly connected components.
        sccs = self._tarjan_scc(steps, adj_list)
        logger.debug(f"Detected {len(sccs)} execution blocks (SCCs).")

        # 3. Condensation graph: one node per SCC, addressed by its index.
        component_of = {}
        for idx, cluster in enumerate(sccs):
            for member in cluster:
                component_of[member] = idx

        successors = [set() for _ in sccs]
        pending = [0] * len(sccs)
        for src in steps:
            src_idx = component_of[src]
            for dst in adj_list[src]:
                dst_idx = component_of[dst]
                # Skip edges inside one SCC and edges already recorded.
                if dst_idx == src_idx or dst_idx in successors[src_idx]:
                    continue
                successors[src_idx].add(dst_idx)
                pending[dst_idx] += 1

        # 4. Kahn's algorithm over the SCC indices.
        ready = deque(idx for idx, count in enumerate(pending) if count == 0)
        execution_plan = []
        while ready:
            idx = ready.popleft()
            execution_plan.append(sccs[idx])
            for following in successors[idx]:
                pending[following] -= 1
                if pending[following] == 0:
                    ready.append(following)

        # 5. Execute the blocks in plan order.
        memory = inputs.copy()
        for group in execution_plan:
            is_linear = len(group) == 1 and group[0] not in adj_list[group[0]]
            if is_linear:
                StepExecutor.run_step(group[0], memory)
            else:
                # Alphabetical order keeps the execution inside a cycle deterministic.
                group_sorted = sorted(group, key=lambda s: s.name)
                block_names = [s.name for s in group_sorted]
                logger.info(f"Cyclic Block Detected: {block_names}")

                sub_solver = IterativeSolver(
                    max_iterations=self.max_iterations,
                    tolerance=self.tolerance,
                )
                memory.update(sub_solver.solve(group_sorted, memory))

        return memory

    def _tarjan_scc(self, steps: List[Step], adj_list: Dict[Step, List[Step]]) -> List[List[Step]]:
        """Return the strongly connected components of the dependency graph.

        Implements Tarjan's algorithm recursively. Each component lists its
        members in the order they were popped from the traversal stack, and a
        component is emitted only after every component reachable from it.
        """
        order = {}      # discovery number of each visited node
        low = {}        # smallest discovery number reachable from the node
        path = []       # nodes of the components still being built
        on_path = set() # fast membership test for `path`
        components = []

        def visit(node):
            # The next discovery number equals the count of nodes seen so far.
            number = len(order)
            order[node] = number
            low[node] = number
            path.append(node)
            on_path.add(node)

            for target in adj_list[node]:
                if target not in order:
                    visit(target)
                    low[node] = min(low[node], low[target])
                elif target in on_path:
                    low[node] = min(low[node], order[target])

            if low[node] != order[node]:
                return

            # `node` is the root of a component: unwind the path down to it.
            component = []
            while True:
                member = path.pop()
                on_path.remove(member)
                component.append(member)
                if member == node:
                    break
            components.append(component)

        for step in steps:
            if step not in order:
                visit(step)

        return components
257# --- Helpers ---
def _map_producers(steps: List[Step]) -> Dict[str, Step]:
    """Map every output name to the step that produces it.

    When several steps declare the same output name, the step appearing
    last in ``steps`` is the one recorded.
    """
    return {
        output_name: step
        for step in steps
        for output_name in step.resolve_output_names()
    }
def _build_dependency_graph(steps: List[Step], input_keys: Set[str], producers_map: Dict[str, Step]):
    """Build the producer -> consumer graph of ``steps``.

    For every parameter of a step's signature that is an output of some step
    (according to ``producers_map``), an edge producer -> consumer is added
    and the consumer's in-degree is incremented. An internal producer takes
    priority: a name that is also present in ``input_keys`` still creates an
    edge. Parameters without an internal producer add no edge, whether or
    not they appear in ``input_keys``, so ``input_keys`` does not influence
    the result.

    Returns:
        ``(adj_list, indegree)``: a ``defaultdict(list)`` mapping each
        producer to its consumers (one entry per consumed parameter), and a
        ``defaultdict(int)`` holding the in-degree of every step.
    """
    adj_list = defaultdict(list)
    # Every step gets an entry, including steps without any dependency.
    indegree = defaultdict(int, dict.fromkeys(steps, 0))

    for consumer in steps:
        # The step's own get_signature() is used; per the original note it
        # sees through decorators.
        for param in consumer.get_signature().parameters:
            if param not in producers_map:
                continue
            adj_list[producers_map[param]].append(consumer)
            indegree[consumer] += 1

    return adj_list, indegree