Coverage for src/tracekit/pipeline/pipeline.py: 100%

92 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Pipeline architecture for chaining trace transformations. 

2 

3This module implements sklearn-style pipeline composition for trace operations, 

4enabling declarative, reusable analysis workflows. 

5""" 

6 

7from __future__ import annotations 

8 

9from typing import TYPE_CHECKING, Any 

10 

11from .base import TraceTransformer 

12 

13if TYPE_CHECKING: 

14 from collections.abc import Sequence 

15 

16 from ..core.types import WaveformTrace 

17 

18 

class Pipeline(TraceTransformer):
    """Chain multiple trace transformers into a single processing pipeline.

    Pipeline applies transformers sequentially: each stage transforms the output
    of the previous stage. Supports the fit/transform pattern and can be
    serialized with pickle or joblib.

    The pipeline is itself a TraceTransformer, so pipelines can be nested.

    Attributes:
        steps: List of (name, transformer) tuples defining the pipeline stages.
        named_steps: Dictionary mapping step names to transformers.

    Example:
        >>> import tracekit as tk
        >>> pipeline = tk.Pipeline([
        ...     ('lowpass', tk.LowPassFilter(cutoff=1e6)),
        ...     ('resample', tk.Resample(rate=1e9)),
        ...     ('normalize', tk.Normalize())
        ... ])
        >>> result = pipeline.transform(trace)

    Advanced Example:
        >>> # Create analysis pipeline with fit/transform
        >>> pipeline = tk.Pipeline([
        ...     ('filter', tk.BandPassFilter(low=1e5, high=1e6)),
        ...     ('normalize', tk.Normalize(method='zscore')),
        ...     ('fft', tk.FFT(nfft=8192, window='hann')),
        ...     ('extract', tk.ExtractMeasurement('thd'))
        ... ])
        >>> # Fit on reference trace
        >>> pipeline.fit(reference_trace)
        >>> # Transform multiple traces
        >>> results = [pipeline.transform(t) for t in traces]
        >>> # Access intermediate results
        >>> filtered = pipeline.named_steps['filter'].transform(trace)
        >>> # Save for reuse
        >>> import joblib
        >>> joblib.dump(pipeline, 'analysis_pipeline.pkl')

    References:
        API-001: sklearn-style Pipeline Architecture
        sklearn.pipeline.Pipeline
        https://scikit-learn.org/stable/modules/compose.html
    """

    def __init__(self, steps: Sequence[tuple[str, TraceTransformer]]) -> None:
        """Initialize pipeline with sequence of transformers.

        Args:
            steps: Sequence of (name, transformer) tuples. Each transformer
                must be a TraceTransformer instance.

        Raises:
            TypeError: If any step is not a TraceTransformer.
            ValueError: If step names are not unique or empty.
        """
        self._set_steps(steps)
        # Per-step trace outputs cached by the most recent transform() call.
        self._intermediate_results: dict[str, WaveformTrace] = {}

    def _set_steps(self, steps: Sequence[tuple[str, TraceTransformer]]) -> None:
        """Validate and install pipeline steps.

        Shared by __init__ and set_params so that replacing the steps after
        construction goes through exactly the same validation as construction
        (previously set_params bypassed validation entirely).

        Args:
            steps: Sequence of (name, transformer) tuples.

        Raises:
            TypeError: If any step is not a TraceTransformer.
            ValueError: If steps is empty, a name is empty, or names repeat.
        """
        if not steps:
            raise ValueError("Pipeline steps cannot be empty")

        names = []
        for name, transformer in steps:
            if not name:
                raise ValueError("Step name cannot be empty")
            if not isinstance(transformer, TraceTransformer):
                raise TypeError(
                    f"All pipeline steps must be TraceTransformer instances. "
                    f"Step '{name}' is {type(transformer).__name__}"
                )
            names.append(name)

        # Reject duplicate names: named_steps would silently drop one otherwise.
        if len(names) != len(set(names)):
            duplicates = [n for n in names if names.count(n) > 1]
            raise ValueError(f"Duplicate step names: {set(duplicates)}")

        self.steps = list(steps)
        self.named_steps = dict(steps)

    def fit(self, trace: WaveformTrace) -> Pipeline:
        """Fit all transformers in the pipeline.

        Fits each transformer sequentially on the output of the previous stage.
        This allows stateful transformers to learn parameters from the trace.

        Args:
            trace: Reference WaveformTrace to fit to.

        Returns:
            Self for method chaining.

        Example:
            >>> pipeline = Pipeline([
            ...     ('normalize', AdaptiveNormalizer()),
            ...     ('filter', AdaptiveFilter())
            ... ])
            >>> pipeline.fit(reference_trace)
        """
        current = trace
        for _name, transformer in self.steps:
            # Fit transformer to current trace
            transformer.fit(current)
            # Transform for next stage
            current = transformer.transform(current)
        return self

    def transform(self, trace: WaveformTrace) -> WaveformTrace:
        """Transform trace through all pipeline stages.

        Applies each transformer sequentially, passing the output of each
        stage to the next. Caches intermediate results for introspection
        via get_intermediate().

        Args:
            trace: Input WaveformTrace to transform.

        Returns:
            Transformed WaveformTrace after passing through all stages.

        Example:
            >>> result = pipeline.transform(trace)
        """
        current = trace
        # Drop results from any previous transform() call.
        self._intermediate_results.clear()

        for name, transformer in self.steps:
            current = transformer.transform(current)
            # Cache intermediate result for introspection
            self._intermediate_results[name] = current

        return current

    def get_intermediate(self, step_name: str, key: str | None = None) -> Any:
        """Get intermediate result from a pipeline stage.

        Retrieves the cached output from a specific pipeline stage after
        transform() has been called. Can also access internal intermediate
        results from transformers that cache them (e.g., FFT coefficients).

        Args:
            step_name: Name of the pipeline step.
            key: Optional key for transformer-internal intermediate result.
                If None, returns the trace output from that stage.

        Returns:
            WaveformTrace output from that stage (if key=None), or
            specific intermediate result from the transformer.

        Raises:
            KeyError: If step name not found or transform() not yet called.

        Example:
            >>> result = pipeline.transform(trace)
            >>> # Get trace output from filter stage
            >>> filtered = pipeline.get_intermediate('filter')
            >>> # Get FFT coefficients from FFT stage
            >>> fft_spectrum = pipeline.get_intermediate('fft', 'spectrum')

        References:
            API-005: Intermediate Result Access
        """
        if step_name not in self._intermediate_results:
            # Distinguish "no such step" from "transform() not called yet".
            if step_name not in self.named_steps:
                raise KeyError(f"Step '{step_name}' not found in pipeline")
            raise KeyError(
                f"No intermediate result for step '{step_name}'. Call transform() first."
            )

        # If no key specified, return the trace output from that stage
        if key is None:
            return self._intermediate_results[step_name]

        # Otherwise, try to get internal intermediate from the transformer
        transformer = self.named_steps[step_name]
        return transformer.get_intermediate_result(key)

    def has_intermediate(self, step_name: str, key: str | None = None) -> bool:
        """Check if intermediate result is available.

        Args:
            step_name: Name of the pipeline step.
            key: Optional key for transformer-internal intermediate result.

        Returns:
            True if intermediate result exists.

        Example:
            >>> if pipeline.has_intermediate('fft', 'spectrum'):
            ...     spectrum = pipeline.get_intermediate('fft', 'spectrum')

        References:
            API-005: Intermediate Result Access
        """
        if step_name not in self._intermediate_results:
            return False

        if key is None:
            return True

        transformer = self.named_steps[step_name]
        return transformer.has_intermediate_result(key)

    def list_intermediates(self, step_name: str | None = None) -> list[str] | dict[str, list[str]]:
        """List available intermediate results.

        Args:
            step_name: If specified, list intermediates for that step only.
                If None, return dict of all steps with their intermediates.

        Returns:
            List of intermediate keys for a step, or dict mapping step names
            to their available intermediates.

        Raises:
            KeyError: If step_name not found in pipeline.

        Example:
            >>> all_intermediates = pipeline.list_intermediates()
            >>> fft_intermediates = pipeline.list_intermediates('fft')

        References:
            API-005: Intermediate Result Access
        """
        if step_name is not None:
            if step_name not in self.named_steps:
                raise KeyError(f"Step '{step_name}' not found in pipeline")
            transformer = self.named_steps[step_name]
            return transformer.list_intermediate_results()

        # Return all intermediates for all steps
        result = {}
        for name, transformer in self.steps:
            intermediates = transformer.list_intermediate_results()
            if intermediates:  # Only include steps with intermediates
                result[name] = intermediates
        return result

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Get parameters for all transformers in the pipeline.

        Args:
            deep: If True, returns parameters for all nested transformers.

        Returns:
            Dictionary of parameters with step names as prefixes
            (sklearn-style 'step__param' keys).

        Example:
            >>> params = pipeline.get_params()
            >>> print(params['filter__cutoff'])
            1000000.0
        """
        params: dict[str, Any] = {"steps": self.steps}

        if deep:
            for name, transformer in self.steps:
                transformer_params = transformer.get_params(deep=True)
                for key, value in transformer_params.items():
                    params[f"{name}__{key}"] = value

        return params

    def set_params(self, **params: Any) -> Pipeline:
        """Set parameters for transformers in the pipeline.

        A 'steps' entry replaces the pipeline's steps (with the same
        validation as __init__); any remaining entries use the sklearn
        'step__param' syntax and are forwarded to the named transformer.
        Previously, passing 'steps' caused all other parameters in the same
        call to be silently ignored; now they are applied after the new
        steps are installed.

        Args:
            **params: Parameters to set, using step__param syntax, and/or
                'steps' to replace the step list.

        Returns:
            Self for method chaining.

        Raises:
            ValueError: If parameter format is invalid or a step is unknown.
            TypeError: If 'steps' contains a non-TraceTransformer.

        Example:
            >>> pipeline.set_params(filter__cutoff=2e6, normalize__method='peak')
        """
        # Work on a copy so the caller's dict is not mutated.
        params = dict(params)

        # Special case: replacing the steps themselves (validated like __init__).
        if "steps" in params:
            self._set_steps(params.pop("steps"))

        # Parse step__param syntax for everything else.
        for param_name, value in params.items():
            if "__" not in param_name:
                raise ValueError(
                    f"Pipeline parameter must use 'step__param' syntax, got '{param_name}'"
                )

            step_name, param = param_name.split("__", 1)
            if step_name not in self.named_steps:
                raise ValueError(
                    f"Step '{step_name}' not found in pipeline. "
                    f"Available steps: {list(self.named_steps.keys())}"
                )

            self.named_steps[step_name].set_params(**{param: value})

        return self

    def clone(self) -> Pipeline:
        """Create a copy of this pipeline.

        Returns:
            New Pipeline instance with cloned transformers (cached
            intermediate results are not copied).

        Example:
            >>> pipeline_copy = pipeline.clone()
        """
        cloned_steps = [(name, transformer.clone()) for name, transformer in self.steps]
        return Pipeline(cloned_steps)

    def __len__(self) -> int:
        """Return number of steps in the pipeline."""
        return len(self.steps)

    def __getitem__(self, index: int | str) -> TraceTransformer:
        """Get transformer by index or name.

        Args:
            index: Integer index or string name.

        Returns:
            TraceTransformer at that position.

        Example:
            >>> first_step = pipeline[0]
            >>> filter_step = pipeline['filter']
        """
        if isinstance(index, str):
            return self.named_steps[index]
        return self.steps[index][1]

    def __repr__(self) -> str:
        """String representation of the pipeline."""
        step_strs = [
            f"('{name}', {transformer.__class__.__name__})" for name, transformer in self.steps
        ]
        return "Pipeline([\n    " + ",\n    ".join(step_strs) + "\n])"

# Explicit public API: only the Pipeline class is exported from this module.
__all__ = ["Pipeline"]