Coverage for smart_pipeline / solvers.py: 100%

181 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-02 13:46 +0200

import logging
import math
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import List, Dict, Any, Set, Protocol, Optional

from .models import Step
from .executor import StepExecutor

8 

# Logger named after this module, so its records can be routed or silenced
# through the standard logging hierarchy.
logger = logging.getLogger(__name__)

11 

class Solver(Protocol):
    """
    Interface for execution logic.

    Being a ``typing.Protocol``, any object with a compatible ``solve``
    method satisfies it structurally; no inheritance is required.
    """

    def solve(self, steps: List[Step], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``steps`` starting from ``inputs`` and return the resulting variables."""
        ...

16 

class DAGSolver:
    """
    Standard Topological Sort Solver.
    Ideal for linear workflows.

    Each step runs exactly once, after every step that produces one of its
    parameters. Pipelines containing a dependency cycle are rejected.
    """

    def solve(self, steps: List[Step], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute ``steps`` once each, in dependency order.

        Args:
            steps: the pipeline steps.
            inputs: initial variables; shallow-copied before execution.

        Returns:
            The memory dict after all steps have run.

        Raises:
            ValueError: if the step graph contains a cycle.
        """
        logger.info("DAGSolver started.")
        ordered = self._topological_sort(steps, set(inputs.keys()))
        logger.debug(f"Topological sort order: {[s.name for s in ordered]}")

        memory = inputs.copy()
        for step in ordered:
            StepExecutor.run_step(step, memory)
        return memory

    def _topological_sort(self, steps: List[Step], input_keys: Set[str]) -> List[Step]:
        """
        Order ``steps`` with Kahn's algorithm.

        Steps with no pending dependencies are emitted first-in-first-out,
        seeded in the order of the indegree map built from ``steps``.

        Raises:
            ValueError: if not every step could be ordered (a cycle exists).
        """
        graph, pending = _build_dependency_graph(steps, input_keys, _map_producers(steps))

        # `ordered` doubles as the FIFO work queue: entries before `cursor`
        # have already been processed, entries after it are waiting.
        ordered = [node for node, count in pending.items() if count == 0]
        cursor = 0
        while cursor < len(ordered):
            node = ordered[cursor]
            cursor += 1
            for successor in graph[node]:
                pending[successor] -= 1
                if pending[successor] == 0:
                    ordered.append(successor)

        # Steps on a cycle never reach zero pending dependencies.
        if len(ordered) != len(steps):
            logger.error("Cycle detected in DAGSolver.")
            raise ValueError("Cycle detected in pipeline. Use HybridSolver or IterativeSolver.")

        return ordered

56 

@dataclass
class IterativeSolver:
    """
    Solves systems with feedback loops by repeated (fixed-point) passes.

    Each iteration runs the steps once, in order, against a shared memory
    dict, then measures how much the step outputs changed. Iteration stops
    once that residual is finite and strictly below ``tolerance``, or after
    ``max_iterations`` passes.

    Attributes:
        max_iterations: maximum number of passes over the steps.
        tolerance: convergence threshold for the residual.
        target_var: if set, the residual is the change of this variable only;
            otherwise it is the largest change over all numeric step outputs.
        execution_order: optional list of step names fixing the run order.
            Only the named steps are run; names matching no step are ignored.
    """
    max_iterations: int = 100
    tolerance: float = 1e-6
    target_var: Optional[str] = None
    execution_order: Optional[List[str]] = None

    def solve(self, steps: List[Step], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Iterate ``steps`` until their outputs converge or the budget runs out.

        Args:
            steps: the steps forming the feedback loop.
            inputs: initial variables (may include seed values for loop
                variables). The dict is shallow-copied; the solver writes
                only to the copy.

        Returns:
            The final memory dict. Its ``'residual_history'`` entry is a new
            list: any existing history plus this call's list of per-iteration
            residuals.
        """
        memory = inputs.copy()
        residuals: List[float] = []

        run_sequence = self._determine_execution_order(steps)
        logger.info(f"IterativeSolver started. Sequence: {[s.name for s in run_sequence]}")

        # Identify variables produced by these steps (for auto-convergence)
        produced_vars: Set[str] = set()
        for s in steps:
            produced_vars.update(s.resolve_output_names())

        for i in range(self.max_iterations):
            # Snapshot state for convergence check
            prev_state = {k: memory.get(k) for k in produced_vars if k in memory}

            # Execute
            for step in run_sequence:
                StepExecutor.run_step(step, memory)

            # Check Convergence
            diff = self._calculate_residual(prev_state, memory, produced_vars)
            residuals.append(diff)

            # Only break if we actually calculated a numeric difference (not inf).
            # A NaN residual also fails `diff < tolerance`, so it never converges.
            if diff != float('inf') and diff < self.tolerance:
                logger.info(f"Converged at iteration {i+1} with residual {diff:.6e}")
                break

            logger.debug("Iteration %d: residual %.6e", i + 1, diff)
        else:
            if residuals:
                logger.warning(f"Reached max_iterations ({self.max_iterations}) without converging. Last residual: {residuals[-1]:.6e}")
            else:
                # max_iterations <= 0: the loop body never ran, so there is no
                # residual to report (indexing residuals[-1] would raise).
                logger.warning(f"max_iterations is {self.max_iterations}; no iterations were run.")

        # Store residuals (append to potentially existing history from other cycles).
        # Build a NEW list: `inputs.copy()` is shallow, so appending in place would
        # mutate the list object owned by the caller's `inputs`.
        previous_history = memory.get('residual_history') or []
        memory['residual_history'] = [*previous_history, residuals]
        return memory

    def _calculate_residual(self, prev_state: Dict, current_memory: Dict, produced_vars: Set[str]) -> float:
        """
        Calculates the maximum change in variables between two iterations.

        Only variables whose previous and current values are both ``int`` or
        ``float`` are compared.

        Returns:
            - with ``target_var`` set: ``abs(current - previous)`` for that
              variable, or ``inf`` if either value is missing or non-numeric;
            - otherwise the largest absolute change over ``produced_vars``, or
              ``inf`` if no variable could be compared;
            - ``nan`` if a compared change is NaN, so a degenerate state is
              never mistaken for convergence.
        """
        if self.target_var:
            p = prev_state.get(self.target_var)
            c = current_memory.get(self.target_var)
            return abs(c - p) if (isinstance(p, (int, float)) and isinstance(c, (int, float))) else float('inf')

        max_diff = 0.0
        numeric_vars_found = False

        for k in produced_vars:
            p = prev_state.get(k)
            c = current_memory.get(k)

            # Strictly require both to be numeric
            if not (isinstance(p, (int, float)) and isinstance(c, (int, float))):
                continue

            diff = abs(c - p)
            # max(max_diff, nan) returns max_diff because every comparison with
            # NaN is False; without this check a NaN state could report 0.0.
            if isinstance(diff, float) and math.isnan(diff):
                return float('nan')

            max_diff = max(max_diff, diff)
            numeric_vars_found = True

        # If no numeric variables updated, we can't judge convergence numerically.
        return max_diff if numeric_vars_found else float('inf')

    def _determine_execution_order(self, steps: List[Step]) -> List[Step]:
        """
        Return the steps to run each iteration, in order.

        Without ``execution_order`` this is ``steps`` unchanged. Otherwise it
        is the steps named in ``execution_order``, in that order; steps not
        named there are left out and unknown names are skipped.
        """
        if not self.execution_order:
            return steps

        step_map = {s.name: s for s in steps}
        return [step_map[name] for name in self.execution_order if name in step_map]

138 

139 

class HybridSolver:
    """
    Advanced solver that automatically decomposes the pipeline into
    Linear (DAG) and Iterative (Cyclic) components (Strongly Connected Components).

    Components run in topological order of the condensation graph. A
    component made of a single step with no self-dependency runs once;
    any other component is solved by an ``IterativeSolver``.
    """

    def __init__(self, max_iterations: int = 100, tolerance: float = 1e-6):
        # Both settings are forwarded to the IterativeSolver built for each
        # cyclic component.
        self.max_iterations = max_iterations
        self.tolerance = tolerance

    def solve(self, steps: List[Step], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute a pipeline that may contain feedback loops.

        Args:
            steps: the pipeline steps.
            inputs: initial variables; shallow-copied before execution.

        Returns:
            The memory dict after every component has been executed.
        """
        logger.info("HybridSolver started.")
        input_keys = set(inputs.keys())
        producers_map = _map_producers(steps)

        # 1. Build Adjacency Graph (Producer -> Consumer)
        adj_list, _ = _build_dependency_graph(steps, input_keys, producers_map)

        # 2. Find Strongly Connected Components (SCCs)
        sccs = self._tarjan_scc(steps, adj_list)
        logger.debug(f"Detected {len(sccs)} execution blocks (SCCs).")

        # 3-4. Condensation graph + topological sort of the SCCs
        execution_plan = self._order_components(steps, sccs, adj_list)

        # 5. Execute
        memory = inputs.copy()

        for group in execution_plan:
            # Case A: Linear (a single step that does not consume its own output)
            if len(group) == 1 and group[0] not in adj_list[group[0]]:
                StepExecutor.run_step(group[0], memory)
                continue

            # Case B: Cyclic
            # Sort alphabetically to ensure deterministic execution order within the cycle
            group_sorted = sorted(group, key=lambda s: s.name)

            logger.info(f"Cyclic Block Detected: {[s.name for s in group_sorted]}")
            sub_solver = IterativeSolver(
                max_iterations=self.max_iterations,
                tolerance=self.tolerance
            )

            cycle_results = sub_solver.solve(group_sorted, memory)
            memory.update(cycle_results)

        return memory

    def _order_components(
        self,
        steps: List[Step],
        sccs: List[List[Step]],
        adj_list: Dict[Step, List[Step]],
    ) -> List[List[Step]]:
        """
        Topologically sort SCCs so producer components precede consumers.

        Builds the condensation graph (one node per SCC, duplicate edges
        collapsed) and runs Kahn's algorithm on it. Ready components are
        taken FIFO, seeded in ascending SCC index.
        """
        scc_map = {step: i for i, cluster in enumerate(sccs) for step in cluster}
        scc_adj = defaultdict(set)
        scc_indegree = [0] * len(sccs)

        for u in steps:
            u_scc = scc_map[u]
            for v in adj_list[u]:
                v_scc = scc_map[v]
                # Intra-component edges are handled by the iterative solver;
                # count each inter-component edge only once.
                if u_scc != v_scc and v_scc not in scc_adj[u_scc]:
                    scc_adj[u_scc].add(v_scc)
                    scc_indegree[v_scc] += 1

        queue = deque(i for i, deg in enumerate(scc_indegree) if deg == 0)
        execution_plan = []

        while queue:
            current_scc_idx = queue.popleft()
            execution_plan.append(sccs[current_scc_idx])

            for neighbor_scc in scc_adj[current_scc_idx]:
                scc_indegree[neighbor_scc] -= 1
                if scc_indegree[neighbor_scc] == 0:
                    queue.append(neighbor_scc)

        return execution_plan

    def _tarjan_scc(self, steps: List[Step], adj_list: Dict[Step, List[Step]]) -> List[List[Step]]:
        """
        Find strongly connected components with Tarjan's algorithm.

        The traversal is iterative, with an explicit stack of frames instead of
        recursion. Long dependency chains therefore cannot exceed Python's
        recursion limit. Components are returned in the order they are
        completed, the same order as the classic recursive formulation.
        """
        counter = 0
        indices: Dict[Step, int] = {}
        lowlinks: Dict[Step, int] = {}
        stack: List[Step] = []
        on_stack: Set[Step] = set()
        sccs: List[List[Step]] = []
        # Each frame is (node, iterator over its successors); the iterator
        # remembers how far that node's neighbour loop had progressed.
        work: List[Any] = []

        def visit(node: Step) -> None:
            """Number ``node``, push it on both stacks, and open its frame."""
            nonlocal counter
            indices[node] = counter
            lowlinks[node] = counter
            counter += 1
            stack.append(node)
            on_stack.add(node)
            work.append((node, iter(adj_list[node])))

        for root in steps:
            if root in indices:
                continue
            visit(root)

            while work:
                v, successors = work[-1]
                descended = False
                for w in successors:
                    if w not in indices:
                        # Equivalent of the recursive call: suspend v, explore w.
                        visit(w)
                        descended = True
                        break
                    elif w in on_stack:
                        lowlinks[v] = min(lowlinks[v], indices[w])
                if descended:
                    continue

                # All successors of v explored: v is finished.
                work.pop()
                if lowlinks[v] == indices[v]:
                    new_scc = []
                    while True:
                        w = stack.pop()
                        on_stack.remove(w)
                        new_scc.append(w)
                        if w == v:
                            break
                    sccs.append(new_scc)

                # Propagate v's lowlink to its parent, as the recursive version
                # does right after the child call returns.
                if work:
                    parent = work[-1][0]
                    lowlinks[parent] = min(lowlinks[parent], lowlinks[v])

        return sccs

256 

257# --- Helpers --- 

258 

def _map_producers(steps: List[Step]) -> Dict[str, Step]:
    """
    Map every output name to the step that produces it.

    When several steps declare the same output, the one appearing last in
    ``steps`` wins.
    """
    return {
        output_name: step
        for step in steps
        for output_name in step.resolve_output_names()
    }

265 

266def _build_dependency_graph(steps: List[Step], input_keys: Set[str], producers_map: Dict[str, Step]): 

267 adj_list = defaultdict(list) 

268 indegree = defaultdict(int) 

269 

270 for s in steps: 

271 indegree[s] = 0 

272 

273 for consumer in steps: 

274 # --- FIX: Use .get_signature() to see through decorators --- 

275 sig = consumer.get_signature() 

276 for param in sig.parameters: 

277 

278 # PRIORITY FIX: Check if it's an internal producer FIRST. 

279 if param in producers_map: 

280 producer = producers_map[param] 

281 adj_list[producer].append(consumer) 

282 indegree[consumer] += 1 

283 

284 # Only if it's NOT produced internally do we check if it's satisfied by inputs. 

285 elif param in input_keys: 

286 continue 

287 

288 return adj_list, indegree