Coverage for src / tracekit / pipeline / pipeline.py: 100%
92 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Pipeline architecture for chaining trace transformations.
3This module implements sklearn-style pipeline composition for trace operations,
4enabling declarative, reusable analysis workflows.
5"""
7from __future__ import annotations
9from typing import TYPE_CHECKING, Any
11from .base import TraceTransformer
13if TYPE_CHECKING:
14 from collections.abc import Sequence
16 from ..core.types import WaveformTrace
class Pipeline(TraceTransformer):
    """Chain multiple trace transformers into a single processing pipeline.

    Pipeline applies transformers sequentially: each stage transforms the output
    of the previous stage. Supports the fit/transform pattern and can be
    serialized with pickle or joblib.

    The pipeline is itself a TraceTransformer, so pipelines can be nested.

    Attributes:
        steps: List of (name, transformer) tuples defining the pipeline stages.
        named_steps: Dictionary mapping step names to transformers.

    Example:
        >>> import tracekit as tk
        >>> pipeline = tk.Pipeline([
        ...     ('lowpass', tk.LowPassFilter(cutoff=1e6)),
        ...     ('resample', tk.Resample(rate=1e9)),
        ...     ('normalize', tk.Normalize())
        ... ])
        >>> result = pipeline.transform(trace)

    Advanced Example:
        >>> # Create analysis pipeline with fit/transform
        >>> pipeline = tk.Pipeline([
        ...     ('filter', tk.BandPassFilter(low=1e5, high=1e6)),
        ...     ('normalize', tk.Normalize(method='zscore')),
        ...     ('fft', tk.FFT(nfft=8192, window='hann')),
        ...     ('extract', tk.ExtractMeasurement('thd'))
        ... ])
        >>> # Fit on reference trace
        >>> pipeline.fit(reference_trace)
        >>> # Transform multiple traces
        >>> results = [pipeline.transform(t) for t in traces]
        >>> # Access intermediate results
        >>> filtered = pipeline.named_steps['filter'].transform(trace)
        >>> # Save for reuse
        >>> import joblib
        >>> joblib.dump(pipeline, 'analysis_pipeline.pkl')

    References:
        API-001: sklearn-style Pipeline Architecture
        sklearn.pipeline.Pipeline
        https://scikit-learn.org/stable/modules/compose.html
    """

    def __init__(self, steps: Sequence[tuple[str, TraceTransformer]]) -> None:
        """Initialize pipeline with sequence of transformers.

        Args:
            steps: Sequence of (name, transformer) tuples. Each transformer
                must be a TraceTransformer instance.

        Raises:
            TypeError: If any step is not a TraceTransformer.
            ValueError: If step names are not unique or empty.
        """
        self._validate_steps(steps)
        self.steps = list(steps)
        self.named_steps = dict(steps)
        # Per-step trace outputs cached by transform(); cleared on each run.
        self._intermediate_results: dict[str, WaveformTrace] = {}

    @staticmethod
    def _validate_steps(steps: Sequence[tuple[str, TraceTransformer]]) -> None:
        """Validate pipeline steps (shared by __init__ and set_params).

        Args:
            steps: Candidate sequence of (name, transformer) tuples.

        Raises:
            TypeError: If any step is not a TraceTransformer.
            ValueError: If steps are empty, a name is empty, or names repeat.
        """
        if not steps:
            raise ValueError("Pipeline steps cannot be empty")

        names = []
        for name, transformer in steps:
            if not name:
                raise ValueError("Step name cannot be empty")
            if not isinstance(transformer, TraceTransformer):
                raise TypeError(
                    f"All pipeline steps must be TraceTransformer instances. "
                    f"Step '{name}' is {type(transformer).__name__}"
                )
            names.append(name)

        # Check for duplicate names
        if len(names) != len(set(names)):
            duplicates = [n for n in names if names.count(n) > 1]
            raise ValueError(f"Duplicate step names: {set(duplicates)}")

    def fit(self, trace: WaveformTrace) -> Pipeline:
        """Fit all transformers in the pipeline.

        Fits each transformer sequentially on the output of the previous stage.
        This allows stateful transformers to learn parameters from the trace.

        Args:
            trace: Reference WaveformTrace to fit to.

        Returns:
            Self for method chaining.

        Example:
            >>> pipeline = Pipeline([
            ...     ('normalize', AdaptiveNormalizer()),
            ...     ('filter', AdaptiveFilter())
            ... ])
            >>> pipeline.fit(reference_trace)
        """
        current = trace
        for _name, transformer in self.steps:
            # Fit transformer to current trace
            transformer.fit(current)
            # Transform for next stage
            current = transformer.transform(current)
        return self

    def transform(self, trace: WaveformTrace) -> WaveformTrace:
        """Transform trace through all pipeline stages.

        Applies each transformer sequentially, passing the output of each
        stage to the next. Caches intermediate results for introspection.

        Args:
            trace: Input WaveformTrace to transform.

        Returns:
            Transformed WaveformTrace after passing through all stages.

        Example:
            >>> result = pipeline.transform(trace)
        """
        current = trace
        self._intermediate_results.clear()

        for name, transformer in self.steps:
            current = transformer.transform(current)
            # Cache intermediate result for introspection
            self._intermediate_results[name] = current

        return current

    def get_intermediate(self, step_name: str, key: str | None = None) -> Any:
        """Get intermediate result from a pipeline stage.

        Retrieves the cached output from a specific pipeline stage after
        transform() has been called. Can also access internal intermediate
        results from transformers that cache them (e.g., FFT coefficients).

        Args:
            step_name: Name of the pipeline step.
            key: Optional key for transformer-internal intermediate result.
                If None, returns the trace output from that stage.

        Returns:
            WaveformTrace output from that stage (if key=None), or
            specific intermediate result from the transformer.

        Raises:
            KeyError: If step name not found or transform() not yet called.

        Example:
            >>> pipeline = Pipeline([
            ...     ('filter', LowPassFilter(1e6)),
            ...     ('fft', FFT(nfft=8192)),
            ...     ('normalize', Normalize())
            ... ])
            >>> result = pipeline.transform(trace)
            >>> # Get trace output from filter stage
            >>> filtered = pipeline.get_intermediate('filter')
            >>> # Get FFT coefficients from FFT stage
            >>> fft_spectrum = pipeline.get_intermediate('fft', 'spectrum')
            >>> fft_frequencies = pipeline.get_intermediate('fft', 'frequencies')

        References:
            API-005: Intermediate Result Access
        """
        if step_name not in self._intermediate_results:
            # Distinguish "no such step" from "transform() not yet called".
            if step_name not in self.named_steps:
                raise KeyError(f"Step '{step_name}' not found in pipeline")
            raise KeyError(
                f"No intermediate result for step '{step_name}'. Call transform() first."
            )

        # If no key specified, return the trace output from that stage
        if key is None:
            return self._intermediate_results[step_name]

        # Otherwise, try to get internal intermediate from the transformer
        transformer = self.named_steps[step_name]
        return transformer.get_intermediate_result(key)

    def has_intermediate(self, step_name: str, key: str | None = None) -> bool:
        """Check if intermediate result is available.

        Args:
            step_name: Name of the pipeline step.
            key: Optional key for transformer-internal intermediate result.

        Returns:
            True if intermediate result exists.

        Example:
            >>> if pipeline.has_intermediate('fft', 'spectrum'):
            ...     spectrum = pipeline.get_intermediate('fft', 'spectrum')

        References:
            API-005: Intermediate Result Access
        """
        if step_name not in self._intermediate_results:
            return False

        if key is None:
            return True

        transformer = self.named_steps[step_name]
        return transformer.has_intermediate_result(key)

    def list_intermediates(self, step_name: str | None = None) -> list[str] | dict[str, list[str]]:
        """List available intermediate results.

        Args:
            step_name: If specified, list intermediates for that step only.
                If None, return dict of all steps with their intermediates.

        Returns:
            List of intermediate keys for a step, or dict mapping step names
            to their available intermediates.

        Raises:
            KeyError: If step_name not found in pipeline.

        Example:
            >>> # List all intermediates
            >>> all_intermediates = pipeline.list_intermediates()
            >>> print(all_intermediates)
            {'filter': ['transfer_function', 'impulse_response'],
             'fft': ['spectrum', 'frequencies', 'power', 'phase']}
            >>> # List intermediates for specific step
            >>> fft_intermediates = pipeline.list_intermediates('fft')
            >>> print(fft_intermediates)
            ['spectrum', 'frequencies', 'power', 'phase']

        References:
            API-005: Intermediate Result Access
        """
        if step_name is not None:
            if step_name not in self.named_steps:
                raise KeyError(f"Step '{step_name}' not found in pipeline")
            transformer = self.named_steps[step_name]
            return transformer.list_intermediate_results()

        # Return all intermediates for all steps
        result = {}
        for name, transformer in self.steps:
            intermediates = transformer.list_intermediate_results()
            if intermediates:  # Only include steps with intermediates
                result[name] = intermediates
        return result

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Get parameters for all transformers in the pipeline.

        Args:
            deep: If True, returns parameters for all nested transformers.

        Returns:
            Dictionary of parameters with step names as prefixes.

        Example:
            >>> params = pipeline.get_params()
            >>> print(params['filter__cutoff'])
            1000000.0
        """
        params: dict[str, Any] = {"steps": self.steps}

        if deep:
            for name, transformer in self.steps:
                transformer_params = transformer.get_params(deep=True)
                for key, value in transformer_params.items():
                    params[f"{name}__{key}"] = value

        return params

    def set_params(self, **params: Any) -> Pipeline:
        """Set parameters for transformers in the pipeline.

        Args:
            **params: Parameters to set, using step__param syntax. The
                special key ``steps`` replaces the pipeline stages; any
                remaining step__param entries are applied to the (new) steps.

        Returns:
            Self for method chaining.

        Raises:
            TypeError: If replacement steps contain non-TraceTransformer objects.
            ValueError: If parameter format is invalid, or replacement steps
                are empty / have duplicate or empty names.

        Example:
            >>> pipeline.set_params(filter__cutoff=2e6, normalize__method='peak')
        """
        # Special case: replacing the steps themselves. Unlike a plain
        # attribute assignment, this validates the new steps exactly like
        # __init__ does, and invalidates cached intermediates that refer
        # to the old stages. Remaining step__param entries are still
        # applied below (previously they were silently dropped).
        if "steps" in params:
            new_steps = params.pop("steps")
            self._validate_steps(new_steps)
            self.steps = list(new_steps)
            self.named_steps = dict(new_steps)
            self._intermediate_results.clear()

        # Parse step__param syntax
        for param_name, value in params.items():
            if "__" not in param_name:
                raise ValueError(
                    f"Pipeline parameter must use 'step__param' syntax, got '{param_name}'"
                )

            step_name, param = param_name.split("__", 1)
            if step_name not in self.named_steps:
                raise ValueError(
                    f"Step '{step_name}' not found in pipeline. "
                    f"Available steps: {list(self.named_steps.keys())}"
                )

            self.named_steps[step_name].set_params(**{param: value})

        return self

    def clone(self) -> Pipeline:
        """Create a copy of this pipeline.

        Returns:
            New Pipeline instance with cloned transformers.

        Example:
            >>> pipeline_copy = pipeline.clone()
        """
        cloned_steps = [(name, transformer.clone()) for name, transformer in self.steps]
        return Pipeline(cloned_steps)

    def __len__(self) -> int:
        """Return number of steps in the pipeline."""
        return len(self.steps)

    def __getitem__(self, index: int | str) -> TraceTransformer:
        """Get transformer by index or name.

        Args:
            index: Integer index or string name.

        Returns:
            TraceTransformer at that position.

        Example:
            >>> first_step = pipeline[0]
            >>> filter_step = pipeline['filter']
        """
        if isinstance(index, str):
            return self.named_steps[index]
        return self.steps[index][1]

    def __repr__(self) -> str:
        """String representation of the pipeline."""
        step_strs = [
            f"('{name}', {transformer.__class__.__name__})" for name, transformer in self.steps
        ]
        return "Pipeline([\n  " + ",\n  ".join(step_strs) + "\n])"
375__all__ = ["Pipeline"]