Coverage for src / tracekit / math / interpolation.py: 99%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Interpolation and resampling operations for TraceKit. 

2 

3This module provides interpolation, resampling, and trace alignment 

4functions for waveform data. 

5 

6 

7Example: 

8 >>> from tracekit.math import resample, align_traces 

9 >>> resampled = resample(trace, new_sample_rate=1e6) 

10 >>> aligned = align_traces(trace1, trace2) 

11 

12References: 

13 IEEE 181-2011: IEEE Standard for Transitions, Pulses, and Related Waveforms

14""" 

15 

16from __future__ import annotations 

17 

18import warnings 

19from typing import TYPE_CHECKING, Literal 

20 

21import numpy as np 

22from scipy import interpolate as sp_interp 

23from scipy import signal as sp_signal 

24 

25from tracekit.core.exceptions import InsufficientDataError 

26from tracekit.core.types import TraceMetadata, WaveformTrace 

27 

28if TYPE_CHECKING: 

29 from numpy.typing import NDArray 

30 

31 

def interpolate(
    trace: WaveformTrace,
    new_time: NDArray[np.float64],
    *,
    method: Literal["linear", "cubic", "nearest", "zero"] = "linear",
    fill_value: float | tuple[float, float] = np.nan,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Interpolate trace to new time points.

    Evaluates the waveform at ``new_time`` with ``scipy.interpolate.interp1d``
    using the requested interpolation kind. Points outside the span of
    ``trace.time_vector`` receive ``fill_value``.

    Args:
        trace: Input trace.
        new_time: New time points in seconds, on the same time axis as
            ``trace.time_vector``.
        method: Interpolation method:
            - "linear": Linear interpolation (default)
            - "cubic": Cubic spline interpolation (needs >= 4 samples)
            - "nearest": Nearest neighbor
            - "zero": Zero-order hold (step function)
        fill_value: Value for points outside original range.
            Can be a single value or (below, above) tuple.
        channel_name: Name for the result trace (optional).

    Returns:
        Interpolated WaveformTrace at new time points. Its sample rate is
        ``1 / mean(diff(new_time))``. If that spacing is not a positive
        finite number (fewer than two points, repeated/decreasing times,
        NaNs), the input trace's sample rate is kept instead.

    Raises:
        InsufficientDataError: If trace has too few samples for the method
            (2 in general, 4 for "cubic").
        ValueError: If interpolation method is unknown.

    Example:
        >>> new_time = np.linspace(0, 1e-3, 2000)
        >>> interpolated = interpolate(trace, new_time, method="cubic")
    """
    if len(trace.data) < 2:
        raise InsufficientDataError(
            "Need at least 2 samples for interpolation",
            required=2,
            available=len(trace.data),
            analysis_type="interpolate",
        )

    # Every supported method name is also the matching interp1d ``kind``.
    if method not in ("linear", "cubic", "nearest", "zero"):
        raise ValueError(f"Unknown interpolation method: {method}")

    # interp1d's cubic spline needs at least 4 points. Report the shortfall
    # through the documented exception rather than scipy's generic ValueError.
    if method == "cubic" and len(trace.data) < 4:
        raise InsufficientDataError(
            "Need at least 4 samples for cubic interpolation",
            required=4,
            available=len(trace.data),
            analysis_type="interpolate",
        )

    new_time = np.asarray(new_time, dtype=np.float64)

    interp_func = sp_interp.interp1d(
        trace.time_vector,
        trace.data.astype(np.float64),
        kind=method,
        bounds_error=False,
        fill_value=fill_value,
    )
    result_data = interp_func(new_time)

    # Derive the output rate from the mean spacing of the requested points.
    # A zero/negative/NaN spacing would otherwise yield an infinite or
    # negative sample rate, so keep the input rate in that case.
    new_sample_rate = trace.metadata.sample_rate
    if len(new_time) > 1:
        mean_step = float(np.mean(np.diff(new_time)))
        if np.isfinite(mean_step) and mean_step > 0:
            new_sample_rate = 1.0 / mean_step

    new_metadata = TraceMetadata(
        sample_rate=new_sample_rate,
        vertical_scale=trace.metadata.vertical_scale,
        vertical_offset=trace.metadata.vertical_offset,
        acquisition_time=trace.metadata.acquisition_time,
        trigger_info=trace.metadata.trigger_info,
        source_file=trace.metadata.source_file,
        channel_name=channel_name or f"{trace.metadata.channel_name or 'trace'}_interp",
    )

    return WaveformTrace(data=result_data.astype(np.float64), metadata=new_metadata)

135 

136 

def resample(
    trace: WaveformTrace,
    new_sample_rate: float | None = None,
    *,
    num_samples: int | None = None,
    method: Literal["fft", "polyphase", "interp"] = "fft",
    anti_alias: bool = True,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Resample trace to new sample rate or number of samples.

    Resamples the waveform to a different sample rate. When the sample count
    shrinks and ``anti_alias`` is set, a zero-phase 8th-order Butterworth
    low-pass (second-order sections) is applied first. When the rate drops,
    a UserWarning is emitted if the spectrum suggests aliasing.

    Args:
        trace: Input trace.
        new_sample_rate: Target sample rate in Hz. Mutually exclusive
            with num_samples.
        num_samples: Target number of samples. Mutually exclusive with
            new_sample_rate.
        method: Resampling method:
            - "fft": FFT-based resampling (default, best quality)
            - "polyphase": Polyphase filter resampling (efficient). The
              rate ratio is approximated with a denominator <= 1000, and
              the output is trimmed/edge-padded to the exact target length.
            - "interp": Linear interpolation (fastest)
        anti_alias: Apply anti-aliasing filter before downsampling.
        channel_name: Name for the result trace (optional).

    Returns:
        Resampled WaveformTrace.

    Raises:
        ValueError: If neither or both rate/samples specified, if the target
            has fewer than 1 sample, or if ``method`` is unknown.
        InsufficientDataError: If trace has insufficient samples.

    Example:
        >>> upsampled = resample(trace, new_sample_rate=2e9)
        >>> downsampled = resample(trace, num_samples=1000)

    References:
        MEM-012 (downsampling for memory management)
    """
    if (new_sample_rate is None) == (num_samples is None):
        raise ValueError("Specify exactly one of new_sample_rate or num_samples")

    if len(trace.data) < 2:
        raise InsufficientDataError(
            "Need at least 2 samples for resampling",
            required=2,
            available=len(trace.data),
            analysis_type="resample",
        )

    data = trace.data.astype(np.float64)
    original_rate = trace.metadata.sample_rate
    original_samples = len(data)

    # Derive whichever of (rate, length) the caller did not supply.
    if new_sample_rate is not None:
        target_rate = new_sample_rate
        target_samples = round(original_samples * target_rate / original_rate)
    else:
        target_samples = num_samples  # type: ignore[assignment]
        target_rate = original_rate * target_samples / original_samples

    if target_samples < 1:
        raise ValueError("Target number of samples must be at least 1")

    if method not in ("fft", "polyphase", "interp"):
        raise ValueError(f"Unknown resampling method: {method}")

    # REQ: API-019 - Validate Nyquist criterion when downsampling
    if target_rate < original_rate:
        _resample_warn_if_aliasing(data, original_rate, target_rate)

    # Check if downsampling and apply anti-alias filter
    if anti_alias and target_samples < original_samples:
        # butter() normalizes Wn to the *old* Nyquist (1.0 == original_rate/2),
        # so the new Nyquist (target_rate/2) maps to target_rate/original_rate.
        cutoff = target_rate / original_rate
        if cutoff < 1.0:
            data = _resample_lowpass(data, min(cutoff * 0.9, 0.99))

    if method == "fft":
        result_data = sp_signal.resample(data, target_samples)
    elif method == "polyphase":
        result_data = _resample_polyphase(data, target_samples)
    else:  # "interp": simple linear interpolation on uniform grids
        old_time = np.arange(original_samples) / original_rate
        new_time = np.arange(target_samples) / target_rate
        result_data = np.interp(new_time, old_time, data)

    new_metadata = TraceMetadata(
        sample_rate=target_rate,
        vertical_scale=trace.metadata.vertical_scale,
        vertical_offset=trace.metadata.vertical_offset,
        acquisition_time=trace.metadata.acquisition_time,
        trigger_info=trace.metadata.trigger_info,
        source_file=trace.metadata.source_file,
        channel_name=channel_name or f"{trace.metadata.channel_name or 'trace'}_resampled",
    )

    return WaveformTrace(data=result_data.astype(np.float64), metadata=new_metadata)


def _resample_warn_if_aliasing(
    data: NDArray[np.float64], original_rate: float, target_rate: float
) -> None:
    """Warn when ``target_rate`` is below twice the highest significant frequency.

    A frequency bin counts as significant when its FFT power exceeds 1% of
    the peak bin power.
    """
    power = np.abs(np.fft.rfft(data)) ** 2
    freqs = np.fft.rfftfreq(len(data), 1 / original_rate)
    significant_freqs = freqs[power > 0.01 * np.max(power)]
    if len(significant_freqs) == 0:
        # No bin passes the threshold (e.g. an all-zero signal).
        return

    max_frequency = np.max(significant_freqs)
    nyquist_required = 2 * max_frequency
    if target_rate < nyquist_required:
        warnings.warn(
            f"Downsampling to {target_rate:.2e} Hz violates Nyquist criterion. "
            f"Maximum signal frequency is ~{max_frequency:.2e} Hz, "
            f"requiring ≥{nyquist_required:.2e} Hz sample rate. "
            f"Aliasing may occur.",
            UserWarning,
            stacklevel=3,  # this helper -> resample() -> user code
        )


def _resample_lowpass(data: NDArray[np.float64], wn: float) -> NDArray[np.float64]:
    """Zero-phase 8th-order Butterworth low-pass; ``wn`` is relative to Nyquist."""
    # Second-order sections stay numerically stable at the small cutoffs
    # produced by large decimation ratios, where the 8th-order ``ba`` form
    # is ill-conditioned.
    sos = sp_signal.butter(8, wn, btype="low", output="sos")
    # sosfiltfilt pads 3 * (2 * n_sections + 1) samples by default and rejects
    # shorter inputs; clamp so short traces can still be filtered.
    padlen = min(3 * (2 * sos.shape[0] + 1), data.size - 1)
    return sp_signal.sosfiltfilt(sos, data, padlen=padlen)


def _resample_polyphase(data: NDArray[np.float64], target_samples: int) -> NDArray[np.float64]:
    """Polyphase-resample ``data`` to exactly ``target_samples`` samples."""
    from fractions import Fraction

    # Keep the up/down factors small so the polyphase filter stays short.
    ratio = Fraction(target_samples, data.size).limit_denominator(1000)
    if ratio.numerator == 0:
        # Ratios below 1/2000 round to 0, which resample_poly rejects;
        # fall back to plain integer decimation.
        ratio = Fraction(1, max(1, round(data.size / target_samples)))

    resampled = sp_signal.resample_poly(data, ratio.numerator, ratio.denominator)

    # The ratio is approximate, so the output can be slightly longer or
    # shorter than requested; force the exact length the metadata claims.
    if resampled.size >= target_samples:
        return resampled[:target_samples]
    return np.pad(resampled, (0, target_samples - resampled.size), mode="edge")

267 

268 

def align_traces(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    method: Literal["interpolate", "resample"] = "interpolate",
    reference: Literal["first", "second", "higher"] = "higher",
    channel_names: tuple[str | None, str | None] | None = None,
) -> tuple[WaveformTrace, WaveformTrace]:
    """Align two traces to have the same sample rate and length.

    Adjusts two traces to be compatible for arithmetic operations by bringing
    both onto a common sample rate over the time span they share. The shared
    time base starts at t = 0, with sample k of a trace at k / sample_rate.

    Args:
        trace1: First trace.
        trace2: Second trace.
        method: Alignment method:
            - "interpolate": Linearly interpolate both traces onto a shared
              grid that ends at the last instant where both have a sample.
            - "resample": Truncate both traces to their common duration,
              then resample each to the same number of samples.
        reference: Which trace to use as reference:
            - "first": Use trace1's sample rate
            - "second": Use trace2's sample rate
            - "higher": Use the higher sample rate (default)
        channel_names: Optional names for the aligned traces.

    Returns:
        Tuple of (aligned_trace1, aligned_trace2) with matching parameters.

    Raises:
        ValueError: If ``method`` or ``reference`` is not a listed value.
        InsufficientDataError: Propagated from interpolate()/resample() when
            a trace (or its truncated part) has fewer than 2 samples.

    Example:
        >>> aligned1, aligned2 = align_traces(trace1, trace2)
        >>> diff = subtract(aligned1, aligned2)
    """
    if method not in ("interpolate", "resample"):
        raise ValueError(f"Unknown alignment method: {method}")
    if reference not in ("first", "second", "higher"):
        raise ValueError(f"Unknown reference: {reference}")

    rate1 = trace1.metadata.sample_rate
    rate2 = trace2.metadata.sample_rate

    # Determine reference sample rate
    if reference == "first":
        target_rate = rate1
    elif reference == "second":
        target_rate = rate2
    else:  # "higher"
        target_rate = max(rate1, rate2)

    name1 = channel_names[0] if channel_names else None
    name2 = channel_names[1] if channel_names else None

    if method == "interpolate":
        # Last instant at which BOTH traces still have a sample. Using the
        # full duration n / rate instead overshoots the slower trace's final
        # sample, and interpolate() fills those points with NaN.
        last_common = min(
            (len(trace1.data) - 1) / rate1,
            (len(trace2.data) - 1) / rate2,
        )
        # Grid points k / target_rate <= last_common; the tiny tolerance keeps
        # a point that coincides with the last sample despite float round-off.
        num_samples = max(int(np.floor(last_common * target_rate + 1e-9)) + 1, 0)
        common_time = np.arange(num_samples) / target_rate
        aligned1 = interpolate(trace1, common_time, channel_name=name1)
        aligned2 = interpolate(trace2, common_time, channel_name=name2)
    else:  # "resample"
        # resample() preserves duration (n / rate), so work with durations
        # here and cut each trace to the shared span first. Otherwise the
        # longer trace would be squeezed into num_samples and end up at a
        # different rate than its partner.
        common_end = min(len(trace1.data) / rate1, len(trace2.data) / rate2)
        num_samples = round(common_end * target_rate)

        def _clip(trace: WaveformTrace, rate: float) -> WaveformTrace:
            keep = min(len(trace.data), round(common_end * rate))
            if keep == len(trace.data):
                return trace
            return WaveformTrace(data=trace.data[:keep], metadata=trace.metadata)

        aligned1 = resample(_clip(trace1, rate1), num_samples=num_samples, channel_name=name1)
        aligned2 = resample(_clip(trace2, rate2), num_samples=num_samples, channel_name=name2)

    return aligned1, aligned2

336 

337 

def downsample(
    trace: WaveformTrace,
    factor: int,
    *,
    anti_alias: bool = True,
    method: Literal["decimate", "average", "max", "min"] = "decimate",
    channel_name: str | None = None,
) -> WaveformTrace:
    """Downsample trace by an integer factor.

    Reduces the sample rate by keeping every Nth sample (decimate)
    or by aggregating N samples (average/max/min).

    Args:
        trace: Input trace.
        factor: Downsampling factor (must be >= 1).
        anti_alias: Apply a zero-phase 8th-order Butterworth low-pass at
            0.9 x the new Nyquist before decimation. Only used with
            method="decimate".
        method: Downsampling method:
            - "decimate": Keep every Nth sample (default)
            - "average": Average every N samples
            - "max": Maximum of every N samples
            - "min": Minimum of every N samples
        channel_name: Name for the result trace (optional).

    Returns:
        Downsampled WaveformTrace with ``len(trace.data) // factor`` samples.
        Trailing samples that do not fill a whole group are dropped. With
        ``factor == 1`` and no ``channel_name``, the input trace object
        itself is returned.

    Raises:
        ValueError: If factor is less than 1 or method is unknown.

    Example:
        >>> small = downsample(large_trace, factor=10)

    References:
        MEM-012 (memory management)
    """
    if factor < 1:
        raise ValueError(f"Factor must be >= 1, got {factor}")

    if factor == 1 and channel_name is None:
        return trace  # No change needed

    data = trace.data.astype(np.float64)

    if anti_alias and method == "decimate" and factor > 1 and data.size > 1:
        # butter() normalizes Wn to the Nyquist frequency (1.0 == fs/2), so
        # the post-decimation Nyquist is 1/factor. Cut at 90% of it, matching
        # resample().
        wn = min(0.9 / factor, 0.99)
        # Second-order sections avoid the instability of the 8th-order ``ba``
        # form at the small cutoffs produced by large factors.
        sos = sp_signal.butter(8, wn, btype="low", output="sos")
        # Clamp sosfiltfilt's default edge padding so short traces still work.
        padlen = min(3 * (2 * sos.shape[0] + 1), data.size - 1)
        data = sp_signal.sosfiltfilt(sos, data, padlen=padlen)

    # Truncate to multiple of factor
    n = len(data) // factor * factor
    data = data[:n]

    if method == "decimate":
        result_data = data[::factor]
    elif method == "average":
        result_data = data.reshape(-1, factor).mean(axis=1)
    elif method == "max":
        result_data = data.reshape(-1, factor).max(axis=1)
    elif method == "min":
        result_data = data.reshape(-1, factor).min(axis=1)
    else:
        raise ValueError(f"Unknown method: {method}")

    new_metadata = TraceMetadata(
        sample_rate=trace.metadata.sample_rate / factor,
        vertical_scale=trace.metadata.vertical_scale,
        vertical_offset=trace.metadata.vertical_offset,
        acquisition_time=trace.metadata.acquisition_time,
        trigger_info=trace.metadata.trigger_info,
        source_file=trace.metadata.source_file,
        channel_name=channel_name or f"{trace.metadata.channel_name or 'trace'}_ds{factor}",
    )

    return WaveformTrace(data=result_data, metadata=new_metadata)