Coverage for src/tracekit/pipeline/base.py: 69%

68 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Base classes for trace transformations and pipeline stages. 

2 

3This module implements the foundational abstract base classes for creating 

4custom trace transformations compatible with the Pipeline architecture. 

5""" 

6 

7from __future__ import annotations 

8 

9from abc import ABC, abstractmethod 

10from typing import TYPE_CHECKING, Any 

11 

12if TYPE_CHECKING: 

13 from ..core.types import WaveformTrace 

14 

15 

class TraceTransformer(ABC):
    """Abstract base class for trace transformations.

    All pipeline stages and custom transformations must inherit from this class.
    Provides the fit/transform pattern similar to sklearn transformers.

    The TraceTransformer enforces a consistent interface:
    - transform(trace) -> trace: Required transformation method
    - fit(trace) -> self: Optional learning/calibration method
    - fit_transform(trace) -> trace: Convenience method
    - get_params() / set_params(): Hyperparameter access
    - clone(): Create a copy of the transformer

    Subclasses may additionally cache intermediate computation results with
    ``_cache_intermediate()``; callers read them back through
    ``get_intermediate_result()``, ``has_intermediate_result()`` and
    ``list_intermediate_results()``.

    Example:
        >>> class AmplitudeScaler(TraceTransformer):
        ...     def __init__(self, scale_factor=1.0):
        ...         self.scale_factor = scale_factor
        ...
        ...     def transform(self, trace):
        ...         scaled_data = trace.data * self.scale_factor
        ...         return WaveformTrace(
        ...             data=scaled_data,
        ...             metadata=trace.metadata
        ...         )
        ...
        >>> scaler = AmplitudeScaler(scale_factor=2.0)
        >>> result = scaler.transform(trace)

    References:
        API-004: TraceTransformer Base Class
        sklearn.base.BaseEstimator, TransformerMixin
    """

    @abstractmethod
    def transform(self, trace: WaveformTrace) -> WaveformTrace:
        """Transform a trace.

        Args:
            trace: Input WaveformTrace to transform.

        Returns:
            Transformed WaveformTrace.

        Raises:
            NotImplementedError: If not implemented by subclass.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement transform() method")

    def fit(self, trace: WaveformTrace) -> TraceTransformer:
        """Fit transformer to a reference trace (optional for stateful transformers).

        This method is optional and should be overridden by stateful transformers
        that need to learn parameters from a reference trace (e.g., normalization
        statistics, adaptive filters).

        Args:
            trace: Reference WaveformTrace to fit to.

        Returns:
            Self for method chaining.

        Example:
            >>> class AdaptiveNormalizer(TraceTransformer):
            ...     def __init__(self):
            ...         self.mean_ = None
            ...         self.std_ = None
            ...
            ...     def fit(self, trace):
            ...         self.mean_ = trace.data.mean()
            ...         self.std_ = trace.data.std()
            ...         return self
            ...
            ...     def transform(self, trace):
            ...         normalized = (trace.data - self.mean_) / self.std_
            ...         return WaveformTrace(
            ...             data=normalized,
            ...             metadata=trace.metadata
            ...         )
        """
        # Default implementation: no fitting required
        return self

    def fit_transform(self, trace: WaveformTrace) -> WaveformTrace:
        """Fit to trace, then transform it.

        Convenience method that calls fit() followed by transform().

        Args:
            trace: Input WaveformTrace to fit and transform.

        Returns:
            Transformed WaveformTrace.

        Example:
            >>> normalizer = AdaptiveNormalizer()
            >>> result = normalizer.fit_transform(reference_trace)
        """
        return self.fit(trace).transform(trace)

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Get parameters for this transformer.

        Every public (no leading underscore), non-callable attribute reported
        by ``dir(self)`` is treated as a parameter.

        Args:
            deep: If True, parameters of nested objects exposing ``get_params``
                are included as well, keyed as ``"<attribute>__<nested name>"``.

        Returns:
            Dictionary of parameter names mapped to their values.

        Example:
            >>> scaler = AmplitudeScaler(scale_factor=2.0)
            >>> params = scaler.get_params()
            >>> print(params)
            {'scale_factor': 2.0}
        """
        params: dict[str, Any] = {}
        for key in dir(self):
            # Skip private/magic attributes without touching them
            if key.startswith("_"):
                continue

            # Read the attribute exactly once so properties are not evaluated twice
            value = getattr(self, key)
            if callable(value):
                continue
            params[key] = value

            # Handle nested transformers if deep=True
            if deep and hasattr(value, "get_params"):
                for nested_key, nested_value in value.get_params(deep=True).items():
                    params[f"{key}__{nested_key}"] = nested_value

        return params

    def set_params(self, **params: Any) -> TraceTransformer:
        """Set parameters for this transformer.

        Keys containing ``"__"`` are split at the first occurrence and the
        remainder is forwarded to the ``set_params`` method of the nested
        object stored under the leading attribute name.

        Args:
            **params: Parameter names and values to set.

        Returns:
            Self for method chaining.

        Raises:
            ValueError: If parameter name is invalid, or a nested target does
                not provide ``set_params``.

        Example:
            >>> scaler = AmplitudeScaler(scale_factor=1.0)
            >>> scaler.set_params(scale_factor=3.0)
            >>> print(scaler.scale_factor)
            3.0
        """
        if not params:
            return self

        # Only top-level names are needed for validation
        valid_params = self.get_params(deep=False)

        for key, value in params.items():
            # Handle nested parameters (e.g., 'filter__cutoff')
            if "__" in key:
                nested_obj, nested_key = key.split("__", 1)
                if nested_obj not in valid_params:
                    raise ValueError(
                        f"Invalid parameter {nested_obj} for transformer {self.__class__.__name__}"
                    )
                nested = getattr(self, nested_obj)
                if hasattr(nested, "set_params"):
                    nested.set_params(**{nested_key: value})
                else:
                    raise ValueError(f"Parameter {nested_obj} does not support set_params")
            else:
                if key not in valid_params:
                    raise ValueError(
                        f"Invalid parameter {key} for transformer "
                        f"{self.__class__.__name__}. "
                        f"Valid parameters: {list(valid_params.keys())}"
                    )
                setattr(self, key, value)

        return self

    def clone(self) -> TraceTransformer:
        """Create a copy of this transformer with the same parameters.

        Only parameters the constructor can accept by keyword are passed on.
        Public attributes that are not constructor arguments (for example
        state learned in ``fit()``) are therefore not forwarded, and the new
        instance holds whatever the constructor assigns to them.

        Returns:
            New instance of the transformer with same parameters.

        Example:
            >>> scaler = AmplitudeScaler(scale_factor=2.0)
            >>> scaler_copy = scaler.clone()
            >>> scaler_copy.scale_factor
            2.0
        """
        import inspect

        params = self.get_params(deep=False)

        try:
            init_parameters = inspect.signature(type(self)).parameters.values()
        except (TypeError, ValueError):
            # Signature not introspectable: pass everything through
            return self.__class__(**params)

        # A **kwargs constructor accepts any keyword, so nothing is filtered
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in init_parameters):
            return self.__class__(**params)

        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        accepted = {p.name for p in init_parameters if p.kind in keyword_kinds}

        # Drop attributes the constructor does not know about; passing them
        # would raise TypeError for transformers holding fitted state
        init_kwargs = {name: value for name, value in params.items() if name in accepted}
        return self.__class__(**init_kwargs)

    def __getstate__(self) -> dict[str, Any]:
        """Get state for pickling.

        Returns:
            Shallow copy of the instance dictionary.
        """
        return self.__dict__.copy()

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Set state from unpickling.

        Args:
            state: Dictionary containing transformer state.
        """
        self.__dict__.update(state)

    def get_intermediate_result(self, key: str) -> Any:
        """Get intermediate result from last transformation.

        Some transformers cache intermediate results (e.g., FFT coefficients,
        filter states) that can be accessed after transformation.

        Args:
            key: Name of intermediate result to retrieve.

        Returns:
            Intermediate result value.

        Raises:
            KeyError: If key not found or transformer doesn't support intermediates.

        Example:
            >>> filter = LowPassFilter(cutoff=1e6)
            >>> result = filter.transform(trace)
            >>> transfer_func = filter.get_intermediate_result('transfer_function')

        References:
            API-005: Intermediate Result Access
        """
        # The cache is created lazily by _cache_intermediate(), so it may be absent
        if not hasattr(self, "_intermediates"):
            raise KeyError(f"{self.__class__.__name__} does not cache intermediate results")

        intermediates = self._intermediates
        if key not in intermediates:
            available = list(intermediates.keys())
            raise KeyError(
                f"Intermediate '{key}' not found in {self.__class__.__name__}. "
                f"Available: {available}"
            )

        return intermediates[key]

    def has_intermediate_result(self, key: str) -> bool:
        """Check if intermediate result is available.

        Args:
            key: Name of intermediate result.

        Returns:
            True if intermediate result exists.

        Example:
            >>> if filter.has_intermediate_result('impulse_response'):
            ...     impulse = filter.get_intermediate_result('impulse_response')

        References:
            API-005: Intermediate Result Access
        """
        if not hasattr(self, "_intermediates"):
            return False
        return key in self._intermediates

    def list_intermediate_results(self) -> list[str]:
        """List all available intermediate result keys.

        Returns:
            List of intermediate result names, or empty list if none available.

        Example:
            >>> print(filter.list_intermediate_results())
            ['transfer_function', 'impulse_response', 'frequency_response']

        References:
            API-005: Intermediate Result Access
        """
        if not hasattr(self, "_intermediates"):
            return []
        return list(self._intermediates.keys())

    def _cache_intermediate(self, key: str, value: Any) -> None:
        """Cache an intermediate result for later access.

        This is a protected method for subclasses to use when storing
        intermediate computation results.

        Args:
            key: Name of intermediate result.
            value: Value to cache.

        Example (in subclass):
            >>> def transform(self, trace):
            ...     fft_coeffs = compute_fft(trace)
            ...     self._cache_intermediate('fft_coeffs', fft_coeffs)
            ...     return processed_trace

        References:
            API-005: Intermediate Result Access
        """
        # Create the cache on first use so subclasses need not set it up in __init__
        if not hasattr(self, "_intermediates"):
            self._intermediates = {}
        self._intermediates[key] = value

    def _clear_intermediates(self) -> None:
        """Clear all cached intermediate results.

        Useful for freeing memory when intermediate results are no longer needed.

        Example (in subclass):
            >>> def transform(self, trace):
            ...     self._clear_intermediates()  # Clear previous results
            ...     # ... perform transformation ...

        References:
            API-005: Intermediate Result Access
        """
        if hasattr(self, "_intermediates"):
            self._intermediates.clear()

336 

337 

338__all__ = ["TraceTransformer"]