Coverage for src / tracekit / pipeline / base.py: 69%
68 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
"""Base classes for trace transformations and pipeline stages.

This module implements the foundational abstract base classes for creating
custom trace transformations compatible with the Pipeline architecture.
"""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from ..core.types import WaveformTrace
class TraceTransformer(ABC):
    """Abstract base class for trace transformations.

    All pipeline stages and custom transformations must inherit from this class.
    Provides the fit/transform pattern similar to sklearn transformers.

    The TraceTransformer enforces a consistent interface:
    - transform(trace) -> trace: Required transformation method
    - fit(trace) -> self: Optional learning/calibration method
    - fit_transform(trace) -> trace: Convenience method
    - get_params() / set_params(): Hyperparameter access
    - clone(): Create a copy of the transformer

    Following the sklearn convention, learned (fitted) state should be stored
    in public attributes whose names end with a trailing underscore (e.g.
    ``mean_``). Such attributes are treated as state, not hyperparameters:
    they are excluded from :meth:`get_params` so that :meth:`clone` can
    reconstruct a fresh, unfitted instance from constructor parameters alone.

    Example:
        >>> class AmplitudeScaler(TraceTransformer):
        ...     def __init__(self, scale_factor=1.0):
        ...         self.scale_factor = scale_factor
        ...
        ...     def transform(self, trace):
        ...         scaled_data = trace.data * self.scale_factor
        ...         return WaveformTrace(
        ...             data=scaled_data,
        ...             metadata=trace.metadata
        ...         )
        ...
        >>> scaler = AmplitudeScaler(scale_factor=2.0)
        >>> result = scaler.transform(trace)

    References:
        API-004: TraceTransformer Base Class
        sklearn.base.BaseEstimator, TransformerMixin
    """

    @abstractmethod
    def transform(self, trace: WaveformTrace) -> WaveformTrace:
        """Transform a trace.

        Args:
            trace: Input WaveformTrace to transform.

        Returns:
            Transformed WaveformTrace.

        Raises:
            NotImplementedError: If not implemented by subclass.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement transform() method")

    def fit(self, trace: WaveformTrace) -> TraceTransformer:
        """Fit transformer to a reference trace (optional for stateful transformers).

        This method is optional and should be overridden by stateful transformers
        that need to learn parameters from a reference trace (e.g., normalization
        statistics, adaptive filters). Learned values should be stored in
        trailing-underscore attributes (e.g. ``mean_``).

        Args:
            trace: Reference WaveformTrace to fit to.

        Returns:
            Self for method chaining.

        Example:
            >>> class AdaptiveNormalizer(TraceTransformer):
            ...     def __init__(self):
            ...         self.mean_ = None
            ...         self.std_ = None
            ...
            ...     def fit(self, trace):
            ...         self.mean_ = trace.data.mean()
            ...         self.std_ = trace.data.std()
            ...         return self
            ...
            ...     def transform(self, trace):
            ...         normalized = (trace.data - self.mean_) / self.std_
            ...         return WaveformTrace(
            ...             data=normalized,
            ...             metadata=trace.metadata
            ...         )
        """
        # Default implementation: stateless transformers need no fitting.
        return self

    def fit_transform(self, trace: WaveformTrace) -> WaveformTrace:
        """Fit to trace, then transform it.

        Convenience method that calls fit() followed by transform().

        Args:
            trace: Input WaveformTrace to fit and transform.

        Returns:
            Transformed WaveformTrace.

        Example:
            >>> normalizer = AdaptiveNormalizer()
            >>> result = normalizer.fit_transform(reference_trace)
        """
        return self.fit(trace).transform(trace)

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Get hyperparameters for this transformer.

        Collects every public, non-callable attribute that does not end with a
        trailing underscore. Trailing-underscore names are fitted state (sklearn
        convention) and are excluded so that clone() can pass the returned dict
        straight to __init__().

        Args:
            deep: If True, will return parameters for nested objects.

        Returns:
            Dictionary of parameter names mapped to their values. Nested
            transformer parameters are flattened as '<name>__<param>'.

        Example:
            >>> scaler = AmplitudeScaler(scale_factor=2.0)
            >>> params = scaler.get_params()
            >>> print(params)
            {'scale_factor': 2.0}
        """
        params: dict[str, Any] = {}
        for key in dir(self):
            # Skip private/magic names and fitted state (trailing underscore).
            # Fitted attributes are not constructor arguments; including them
            # would make clone() call __init__ with unexpected keywords.
            if key.startswith("_") or key.endswith("_"):
                continue
            value = getattr(self, key)
            if callable(value):
                continue
            params[key] = value

            # Flatten nested transformers as '<name>__<param>' if deep=True.
            if deep and hasattr(value, "get_params"):
                nested_params = value.get_params(deep=True)
                for nested_key, nested_value in nested_params.items():
                    params[f"{key}__{nested_key}"] = nested_value

        return params

    def set_params(self, **params: Any) -> TraceTransformer:
        """Set parameters for this transformer.

        Args:
            **params: Parameter names and values to set. Nested parameters use
                the '<name>__<param>' form (e.g. ``filter__cutoff``).

        Returns:
            Self for method chaining.

        Raises:
            ValueError: If parameter name is invalid, or a nested target does
                not itself support set_params().

        Example:
            >>> scaler = AmplitudeScaler(scale_factor=1.0)
            >>> scaler.set_params(scale_factor=3.0)
            >>> print(scaler.scale_factor)
            3.0
        """
        if not params:
            return self

        # Validate against shallow params only; nested keys are delegated.
        valid_params = self.get_params(deep=False)

        for key, value in params.items():
            # Handle nested parameters (e.g., 'filter__cutoff')
            if "__" in key:
                nested_obj, nested_key = key.split("__", 1)
                if nested_obj not in valid_params:
                    raise ValueError(
                        f"Invalid parameter {nested_obj} for transformer {self.__class__.__name__}"
                    )
                nested = getattr(self, nested_obj)
                if hasattr(nested, "set_params"):
                    nested.set_params(**{nested_key: value})
                else:
                    raise ValueError(f"Parameter {nested_obj} does not support set_params")
            else:
                if key not in valid_params:
                    raise ValueError(
                        f"Invalid parameter {key} for transformer "
                        f"{self.__class__.__name__}. "
                        f"Valid parameters: {list(valid_params.keys())}"
                    )
                setattr(self, key, value)

        return self

    def clone(self) -> TraceTransformer:
        """Create an unfitted copy of this transformer with the same parameters.

        Only constructor hyperparameters are carried over; fitted state
        (trailing-underscore attributes) is not copied.

        Returns:
            New instance of the transformer with same parameters.

        Example:
            >>> scaler = AmplitudeScaler(scale_factor=2.0)
            >>> scaler_copy = scaler.clone()
            >>> scaler_copy.scale_factor
            2.0
        """
        params = self.get_params(deep=False)
        return self.__class__(**params)

    def __getstate__(self) -> dict[str, Any]:
        """Get state for pickling.

        Returns:
            Dictionary containing transformer state (a shallow copy of
            the instance ``__dict__``).
        """
        return self.__dict__.copy()

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Set state from unpickling.

        Args:
            state: Dictionary containing transformer state.
        """
        self.__dict__.update(state)

    def get_intermediate_result(self, key: str) -> Any:
        """Get intermediate result from last transformation.

        Some transformers cache intermediate results (e.g., FFT coefficients,
        filter states) that can be accessed after transformation.

        Args:
            key: Name of intermediate result to retrieve.

        Returns:
            Intermediate result value.

        Raises:
            KeyError: If key not found or transformer doesn't support intermediates.

        Example:
            >>> filter = LowPassFilter(cutoff=1e6)
            >>> result = filter.transform(trace)
            >>> transfer_func = filter.get_intermediate_result('transfer_function')

        References:
            API-005: Intermediate Result Access
        """
        # The cache is created lazily by _cache_intermediate(); absence means
        # this transformer never stored any intermediates.
        if not hasattr(self, "_intermediates"):
            raise KeyError(f"{self.__class__.__name__} does not cache intermediate results")

        intermediates = self._intermediates
        if key not in intermediates:
            available = list(intermediates.keys())
            raise KeyError(
                f"Intermediate '{key}' not found in {self.__class__.__name__}. "
                f"Available: {available}"
            )

        return intermediates[key]

    def has_intermediate_result(self, key: str) -> bool:
        """Check if intermediate result is available.

        Args:
            key: Name of intermediate result.

        Returns:
            True if intermediate result exists.

        Example:
            >>> if filter.has_intermediate_result('impulse_response'):
            ...     impulse = filter.get_intermediate_result('impulse_response')

        References:
            API-005: Intermediate Result Access
        """
        if not hasattr(self, "_intermediates"):
            return False
        return key in self._intermediates

    def list_intermediate_results(self) -> list[str]:
        """List all available intermediate result keys.

        Returns:
            List of intermediate result names, or empty list if none available.

        Example:
            >>> print(filter.list_intermediate_results())
            ['transfer_function', 'impulse_response', 'frequency_response']

        References:
            API-005: Intermediate Result Access
        """
        if not hasattr(self, "_intermediates"):
            return []
        return list(self._intermediates.keys())

    def _cache_intermediate(self, key: str, value: Any) -> None:
        """Cache an intermediate result for later access.

        This is a protected method for subclasses to use when storing
        intermediate computation results. Creates the cache dict lazily on
        first use.

        Args:
            key: Name of intermediate result.
            value: Value to cache.

        Example (in subclass):
            >>> def transform(self, trace):
            ...     fft_coeffs = compute_fft(trace)
            ...     self._cache_intermediate('fft_coeffs', fft_coeffs)
            ...     return processed_trace

        References:
            API-005: Intermediate Result Access
        """
        if not hasattr(self, "_intermediates"):
            self._intermediates = {}
        self._intermediates[key] = value

    def _clear_intermediates(self) -> None:
        """Clear all cached intermediate results.

        Useful for freeing memory when intermediate results are no longer needed.
        Safe to call even if nothing was ever cached.

        Example (in subclass):
            >>> def transform(self, trace):
            ...     self._clear_intermediates()  # Clear previous results
            ...     # ... perform transformation ...

        References:
            API-005: Intermediate Result Access
        """
        if hasattr(self, "_intermediates"):
            self._intermediates.clear()
338__all__ = ["TraceTransformer"]