Coverage for src / tracekit / math / interpolation.py: 99%
119 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Interpolation and resampling operations for TraceKit.
3This module provides interpolation, resampling, and trace alignment
4functions for waveform data.
7Example:
8 >>> from tracekit.math import resample, align_traces
9 >>> resampled = resample(trace, new_sample_rate=1e6)
10 >>> aligned = align_traces(trace1, trace2)
12References:
13 IEEE 181-2011: Standard for Transitional Waveform Definitions
14"""
16from __future__ import annotations
18import warnings
19from typing import TYPE_CHECKING, Literal
21import numpy as np
22from scipy import interpolate as sp_interp
23from scipy import signal as sp_signal
25from tracekit.core.exceptions import InsufficientDataError
26from tracekit.core.types import TraceMetadata, WaveformTrace
28if TYPE_CHECKING:
29 from numpy.typing import NDArray
def interpolate(
    trace: WaveformTrace,
    new_time: NDArray[np.float64],
    *,
    method: Literal["linear", "cubic", "nearest", "zero"] = "linear",
    fill_value: float | tuple[float, float] = np.nan,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Evaluate a trace at an arbitrary set of time points.

    Builds a SciPy ``interp1d`` interpolator over the trace's own time
    vector and samples it at ``new_time``.

    Args:
        trace: Input trace.
        new_time: Time points (seconds) at which to evaluate the trace.
        method: Interpolation scheme:
            - "linear": Linear interpolation (default)
            - "cubic": Cubic spline interpolation
            - "nearest": Nearest neighbor
            - "zero": Zero-order hold (step function)
        fill_value: Value used for points outside the original time range.
            Either a single value or a (below, above) tuple.
        channel_name: Name for the result trace. Defaults to the source
            channel name (or "trace") with an "_interp" suffix.

    Returns:
        A new WaveformTrace holding the interpolated samples. Its sample
        rate is the reciprocal of the mean spacing of ``new_time``; when
        ``new_time`` has a single point the source sample rate is kept.

    Raises:
        InsufficientDataError: If the trace has fewer than 2 samples.
        ValueError: If ``method`` is not one of the supported names.

    Example:
        >>> new_time = np.linspace(0, 1e-3, 2000)
        >>> interpolated = interpolate(trace, new_time, method="cubic")
    """
    sample_count = len(trace.data)
    if sample_count < 2:
        raise InsufficientDataError(
            "Need at least 2 samples for interpolation",
            required=2,
            available=sample_count,
            analysis_type="interpolate",
        )

    source_time = trace.time_vector
    source_values = trace.data.astype(np.float64)

    # Every supported method name doubles as the SciPy ``kind`` string, so a
    # single validated call replaces one branch per method.
    if method not in ("linear", "cubic", "nearest", "zero"):
        raise ValueError(f"Unknown interpolation method: {method}")

    interpolator = sp_interp.interp1d(
        source_time,
        source_values,
        kind=method,
        bounds_error=False,
        fill_value=fill_value,
    )
    resampled_values = interpolator(new_time)

    # Derive the effective sample rate from the requested time grid.
    effective_rate = (
        1.0 / np.mean(np.diff(new_time))
        if len(new_time) > 1
        else trace.metadata.sample_rate
    )

    source_meta = trace.metadata
    result_meta = TraceMetadata(
        sample_rate=effective_rate,
        vertical_scale=source_meta.vertical_scale,
        vertical_offset=source_meta.vertical_offset,
        acquisition_time=source_meta.acquisition_time,
        trigger_info=source_meta.trigger_info,
        source_file=source_meta.source_file,
        channel_name=channel_name or f"{trace.metadata.channel_name or 'trace'}_interp",
    )

    return WaveformTrace(data=resampled_values.astype(np.float64), metadata=result_meta)
def resample(
    trace: WaveformTrace,
    new_sample_rate: float | None = None,
    *,
    num_samples: int | None = None,
    method: Literal["fft", "polyphase", "interp"] = "fft",
    anti_alias: bool = True,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Resample trace to new sample rate or number of samples.

    Resamples the waveform to a different sample rate. When the target has
    fewer samples than the source and ``anti_alias`` is set, a zero-phase
    Butterworth low-pass is applied first.

    Args:
        trace: Input trace.
        new_sample_rate: Target sample rate in Hz. Mutually exclusive
            with num_samples.
        num_samples: Target number of samples. Mutually exclusive with
            new_sample_rate.
        method: Resampling method:
            - "fft": FFT-based resampling (default)
            - "polyphase": Polyphase filter resampling
            - "interp": Linear interpolation
        anti_alias: Apply anti-aliasing filter before downsampling.
        channel_name: Name for the result trace. Defaults to the source
            channel name (or "trace") with a "_resampled" suffix.

    Returns:
        Resampled WaveformTrace with ``target_samples`` samples and the
        corresponding sample rate in its metadata.

    Raises:
        ValueError: If neither or both of rate/samples are specified, if the
            target sample count is below 1, or if ``method`` is unknown.
        InsufficientDataError: If the trace has fewer than 2 samples.

    Warns:
        UserWarning: When downsampling below twice the highest frequency
            that carries more than 1% of the peak spectral power.

    Example:
        >>> upsampled = resample(trace, new_sample_rate=2e9)
        >>> downsampled = resample(trace, num_samples=1000)

    References:
        MEM-012 (downsampling for memory management)
    """
    if (new_sample_rate is None) == (num_samples is None):
        raise ValueError("Specify exactly one of new_sample_rate or num_samples")

    if len(trace.data) < 2:
        raise InsufficientDataError(
            "Need at least 2 samples for resampling",
            required=2,
            available=len(trace.data),
            analysis_type="resample",
        )

    data = trace.data.astype(np.float64)
    original_rate = trace.metadata.sample_rate
    original_samples = len(data)

    # Derive whichever of (rate, sample count) the caller did not supply.
    if new_sample_rate is not None:
        target_rate = new_sample_rate
        target_samples = round(original_samples * target_rate / original_rate)
    else:
        target_samples = num_samples  # type: ignore[assignment]
        target_rate = original_rate * target_samples / original_samples

    if target_samples < 1:
        raise ValueError("Target number of samples must be at least 1")

    # REQ: API-019 - Validate Nyquist criterion when downsampling
    if target_rate < original_rate:
        fft_data = np.fft.rfft(data)
        fft_freqs = np.fft.rfftfreq(len(data), 1 / original_rate)
        # Treat the highest bin holding more than 1% of the peak bin's power
        # as the maximum signal frequency.
        power = np.abs(fft_data) ** 2
        power_threshold = 0.01 * np.max(power)
        significant_freqs = fft_freqs[power > power_threshold]
        if len(significant_freqs) > 0:
            max_frequency = np.max(significant_freqs)
            nyquist_required = 2 * max_frequency
            if target_rate < nyquist_required:
                warnings.warn(
                    f"Downsampling to {target_rate:.2e} Hz violates Nyquist criterion. "
                    f"Maximum signal frequency is ~{max_frequency:.2e} Hz, "
                    f"requiring ≥{nyquist_required:.2e} Hz sample rate. "
                    f"Aliasing may occur.",
                    UserWarning,
                    stacklevel=2,
                )

    # Low-pass at the new Nyquist before reducing the sample count.
    if anti_alias and target_samples < original_samples:
        # butter() expects the cutoff normalised to the ORIGINAL Nyquist
        # frequency: (target_rate / 2) / (original_rate / 2).
        cutoff = target_rate / original_rate
        if cutoff < 1.0:
            # Second-order sections keep an order-8 IIR filter numerically
            # stable at low cutoffs, where the (b, a) form breaks down.
            sos = sp_signal.butter(8, min(cutoff * 0.9, 0.99), btype="low", output="sos")
            # Same edge padding SciPy would pick by default, shortened for
            # traces too short to support it (which would otherwise raise).
            padlen = min(3 * (2 * len(sos) + 1), original_samples - 1)
            data = sp_signal.sosfiltfilt(sos, data, padlen=padlen)

    # Resample
    if method == "fft":
        result_data = sp_signal.resample(data, target_samples)
    elif method == "polyphase":
        from fractions import Fraction

        exact_ratio = Fraction(target_samples, original_samples)
        # Prefer a small rational approximation (cheaper filter design) ...
        ratio = exact_ratio.limit_denominator(1000)
        up, down = ratio.numerator, ratio.denominator
        # ... unless it rounds to zero or would produce too few samples, in
        # which case the exact ratio yields exactly target_samples.
        produced = (original_samples * up + down - 1) // down
        if up < 1 or produced < target_samples:
            up, down = exact_ratio.numerator, exact_ratio.denominator
        result_data = sp_signal.resample_poly(data, up, down)
        # Trim any surplus from the approximated ratio to the exact length.
        result_data = result_data[:target_samples]
    elif method == "interp":
        # Linear interpolation between uniformly spaced sample instants.
        old_time = np.arange(original_samples) / original_rate
        new_time = np.arange(target_samples) / target_rate
        result_data = np.interp(new_time, old_time, data)
    else:
        raise ValueError(f"Unknown resampling method: {method}")

    new_metadata = TraceMetadata(
        sample_rate=target_rate,
        vertical_scale=trace.metadata.vertical_scale,
        vertical_offset=trace.metadata.vertical_offset,
        acquisition_time=trace.metadata.acquisition_time,
        trigger_info=trace.metadata.trigger_info,
        source_file=trace.metadata.source_file,
        channel_name=channel_name or f"{trace.metadata.channel_name or 'trace'}_resampled",
    )

    return WaveformTrace(data=result_data.astype(np.float64), metadata=new_metadata)
def align_traces(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    method: Literal["interpolate", "resample"] = "interpolate",
    reference: Literal["first", "second", "higher"] = "higher",
    channel_names: tuple[str | None, str | None] | None = None,
) -> tuple[WaveformTrace, WaveformTrace]:
    """Align two traces to have the same sample rate and length.

    Adjusts two traces to be compatible for arithmetic operations by
    bringing both onto a common sample rate over the shorter of the two
    durations (``len(data) / sample_rate``).

    Args:
        trace1: First trace.
        trace2: Second trace.
        method: Alignment method:
            - "interpolate": Interpolate to common time points
            - "resample": Resample to common rate
        reference: Which trace to use as reference:
            - "first": Use trace1's sample rate
            - "second": Use trace2's sample rate
            - "higher": Use the higher sample rate (default)
        channel_names: Optional names for the aligned traces.

    Returns:
        Tuple of (aligned_trace1, aligned_trace2) with matching length.
        With "interpolate", common time points that fall outside a trace's
        own time vector take ``interpolate``'s default fill value (NaN).

    Raises:
        ValueError: If ``method`` or ``reference`` is not a supported value.

    Example:
        >>> aligned1, aligned2 = align_traces(trace1, trace2)
        >>> diff = subtract(aligned1, aligned2)
    """
    # Reject unknown options up front instead of silently falling through to
    # the "resample" / "higher" branches.
    if method not in ("interpolate", "resample"):
        raise ValueError(f"Unknown alignment method: {method}")
    if reference not in ("first", "second", "higher"):
        raise ValueError(f"Unknown reference: {reference}")

    rate1 = trace1.metadata.sample_rate
    rate2 = trace2.metadata.sample_rate

    # Determine reference sample rate
    if reference == "first":
        target_rate = rate1
    elif reference == "second":
        target_rate = rate2
    else:  # "higher"
        target_rate = max(rate1, rate2)

    # Determine time span (use overlapping portion)
    t1_end = len(trace1.data) / rate1
    t2_end = len(trace2.data) / rate2
    common_end = min(t1_end, t2_end)

    # Number of samples covering the common span at the target rate
    num_samples = round(common_end * target_rate)

    name1 = channel_names[0] if channel_names else None
    name2 = channel_names[1] if channel_names else None

    if method == "interpolate":
        # Interpolate both traces onto one shared time grid.
        common_time = np.arange(num_samples) / target_rate
        aligned1 = interpolate(trace1, common_time, channel_name=name1)
        aligned2 = interpolate(trace2, common_time, channel_name=name2)
    else:  # "resample"
        # resample() stretches whatever it is given across num_samples, so
        # each trace is first cropped to the common span; otherwise the
        # longer trace would come out at a different effective sample rate.
        cropped = []
        for trace, rate in ((trace1, rate1), (trace2, rate2)):
            keep = min(len(trace.data), round(common_end * rate))
            if keep < len(trace.data):
                trace = WaveformTrace(data=trace.data[:keep], metadata=trace.metadata)
            cropped.append(trace)
        aligned1 = resample(cropped[0], num_samples=num_samples, channel_name=name1)
        aligned2 = resample(cropped[1], num_samples=num_samples, channel_name=name2)

    return aligned1, aligned2
def downsample(
    trace: WaveformTrace,
    factor: int,
    *,
    anti_alias: bool = True,
    method: Literal["decimate", "average", "max", "min"] = "decimate",
    channel_name: str | None = None,
) -> WaveformTrace:
    """Downsample trace by an integer factor.

    Reduces the sample rate by keeping every Nth sample (decimate)
    or by aggregating N samples (average/max/min). Trailing samples that
    do not fill a complete group of ``factor`` are dropped.

    Args:
        trace: Input trace.
        factor: Downsampling factor (must be >= 1).
        anti_alias: Apply a zero-phase low-pass filter before decimation.
            Only used with the "decimate" method.
        method: Downsampling method:
            - "decimate": Keep every Nth sample (default)
            - "average": Average every N samples
            - "max": Maximum of every N samples
            - "min": Minimum of every N samples
        channel_name: Name for the result trace. Defaults to the source
            channel name (or "trace") with a "_ds<factor>" suffix.

    Returns:
        Downsampled WaveformTrace. When ``factor`` is 1 the input trace
        object itself is returned unchanged.

    Raises:
        ValueError: If factor is less than 1 or method is unknown.

    Example:
        >>> small = downsample(large_trace, factor=10)

    References:
        MEM-012 (memory management)
    """
    if factor < 1:
        raise ValueError(f"Factor must be >= 1, got {factor}")

    if factor == 1:
        return trace  # No change needed

    data = trace.data.astype(np.float64)

    # A single sample (or none) has nothing to filter.
    if anti_alias and method == "decimate" and len(data) > 1:
        # butter() takes the cutoff normalised to the current Nyquist
        # frequency, so the post-decimation Nyquist sits at 1 / factor;
        # 0.9 leaves a transition margin below it.
        cutoff = 0.9 / factor
        # Second-order sections keep the order-8 filter stable at the low
        # cutoffs that large factors require.
        sos = sp_signal.butter(8, cutoff, btype="low", output="sos")
        # Same edge padding SciPy would pick by default, shortened for
        # traces too short to support it (which would otherwise raise).
        padlen = min(3 * (2 * len(sos) + 1), len(data) - 1)
        data = sp_signal.sosfiltfilt(sos, data, padlen=padlen)

    # Truncate to a whole number of groups of ``factor`` samples.
    n = len(data) // factor * factor
    data = data[:n]

    if method == "decimate":
        result_data = data[::factor]
    elif method == "average":
        result_data = data.reshape(-1, factor).mean(axis=1)
    elif method == "max":
        result_data = data.reshape(-1, factor).max(axis=1)
    elif method == "min":
        result_data = data.reshape(-1, factor).min(axis=1)
    else:
        raise ValueError(f"Unknown method: {method}")

    new_metadata = TraceMetadata(
        sample_rate=trace.metadata.sample_rate / factor,
        vertical_scale=trace.metadata.vertical_scale,
        vertical_offset=trace.metadata.vertical_offset,
        acquisition_time=trace.metadata.acquisition_time,
        trigger_info=trace.metadata.trigger_info,
        source_file=trace.metadata.source_file,
        channel_name=channel_name or f"{trace.metadata.channel_name or 'trace'}_ds{factor}",
    )

    return WaveformTrace(data=result_data, metadata=new_metadata)