Coverage for src / tracekit / analyzers / statistics / correlation.py: 97%

208 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Correlation analysis for signal data. 

2 

3This module provides autocorrelation, cross-correlation, and related 

4analysis functions for identifying signal relationships and periodicities. 

5 

6 

7Example: 

8 >>> from tracekit.analyzers.statistics.correlation import ( 

9 ... autocorrelation, cross_correlation, correlate_chunked 

10 ... ) 

11 >>> lag_times, acf = autocorrelation(trace, max_lag=1000) 

12 >>> result = cross_correlation(trace1, trace2)  # CrossCorrelationResult 

13 >>> # Memory-efficient correlation for large signals 

14 >>> result = correlate_chunked(large_signal1, large_signal2) 

15 

16References: 

17 Oppenheim, A. V. & Schafer, R. W. (2009). Discrete-Time Signal Processing 

18 IEEE 1241-2010: Standard for Terminology and Test Methods for ADCs 

19""" 

20 

21from __future__ import annotations 

22 

23from dataclasses import dataclass 

24from typing import TYPE_CHECKING, Any 

25 

26import numpy as np 

27 

28from tracekit.core.types import WaveformTrace 

29 

30if TYPE_CHECKING: 

31 from numpy.typing import NDArray 

32 

33 

@dataclass
class CrossCorrelationResult:
    """Result of cross-correlation analysis.

    Produced by :func:`cross_correlation`. Lags are reported both in samples
    and in seconds (samples divided by ``sample_rate``); the peak is located
    at the maximum of the absolute correlation.

    Attributes:
        correlation: Correlation values, one per lag (normalized to
            correlation coefficients in [-1, 1] when the caller requested it).
        lags: Lag values in samples (positive lag = second signal delayed
            relative to the first).
        lag_times: Lag values in seconds.
        peak_lag: Lag at maximum absolute correlation (samples).
        peak_lag_time: Lag at maximum absolute correlation (seconds).
        peak_coefficient: Correlation value at the peak lag (may be negative).
        sample_rate: Sample rate in Hz used for the samples-to-seconds
            conversion.
    """

    correlation: NDArray[np.float64]
    lags: NDArray[np.intp]
    lag_times: NDArray[np.float64]
    peak_lag: int
    peak_lag_time: float
    peak_coefficient: float
    sample_rate: float

55 

56 

def autocorrelation(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    max_lag: int | None = None,
    normalized: bool = True,
    sample_rate: float | None = None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute the autocorrelation function of a signal.

    Quantifies how similar a signal is to a delayed copy of itself,
    which reveals periodicities and characteristic time scales.

    Args:
        trace: Input trace or numpy array.
        max_lag: Maximum lag to compute (samples). Defaults to n // 2 and
            is always clamped to n - 1.
        normalized: If True, scale so the zero-lag value is 1 (coefficients
            in [-1, 1]).
        sample_rate: Sample rate in Hz for the time axis. Required when
            ``trace`` is a plain array.

    Returns:
        Tuple of (lags_time, autocorrelation):
            - lags_time: Time value of each lag in seconds
            - autocorrelation: Autocorrelation values, one per lag

    Raises:
        ValueError: If sample_rate is not provided when trace is array.

    Example:
        >>> lag_times, acf = autocorrelation(trace, max_lag=1000)
        >>> zero_idx = np.where(acf[1:] < 0)[0][0]
        >>> decorr_time = lag_times[zero_idx]

    References:
        Box, G. E. P. & Jenkins, G. M. (1976). Time Series Analysis
    """
    if isinstance(trace, WaveformTrace):
        data = trace.data
        fs = trace.metadata.sample_rate
    else:
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        data = trace
        fs = sample_rate

    n = len(data)
    lag_count = n // 2 if max_lag is None else max_lag
    lag_count = min(lag_count, n - 1)

    # Mean removal so the result reflects fluctuations, not DC offset.
    centered = data - np.mean(data)

    if n > 256:
        # FFT path: padding to >= 2n makes circular correlation equal linear.
        nfft = int(2 ** np.ceil(np.log2(2 * n)))
        spectrum = np.fft.rfft(centered, n=nfft)
        full_acf = np.fft.irfft(spectrum * np.conj(spectrum), n=nfft)
        acf = full_acf[: lag_count + 1]
    else:
        # Direct path is cheaper for short signals; zero lag sits at n - 1.
        full_acf = np.correlate(centered, centered, mode="full")
        acf = full_acf[n - 1 : n + lag_count]

    # Scale by the zero-lag energy (skip for an all-constant signal).
    if normalized and acf[0] > 0:
        acf = acf / acf[0]

    lag_times = np.arange(lag_count + 1) / fs
    return lag_times, acf.astype(np.float64)

132 

133 

def cross_correlation(
    trace1: WaveformTrace | NDArray[np.floating[Any]],
    trace2: WaveformTrace | NDArray[np.floating[Any]],
    *,
    max_lag: int | None = None,
    normalized: bool = True,
    sample_rate: float | None = None,
) -> CrossCorrelationResult:
    """Compute cross-correlation between two signals.

    Evaluates similarity as a function of relative time shift, which is
    useful for measuring delays and aligning related signals.

    Args:
        trace1: First input trace or numpy array (reference).
        trace2: Second input trace or numpy array.
        max_lag: Maximum lag to compute (samples). If None, uses min(n1, n2) // 2.
        normalized: If True, normalize to correlation coefficients [-1, 1].
        sample_rate: Sample rate in Hz. Required if traces are arrays.

    Returns:
        CrossCorrelationResult with correlation data and optimal lag.

    Raises:
        ValueError: If sample_rate is not provided when traces are arrays.

    Example:
        >>> result = cross_correlation(trace1, trace2)
        >>> print(f"Optimal lag: {result.peak_lag_time * 1e6:.1f} us")
        >>> print(f"Correlation: {result.peak_coefficient:.3f}")

    References:
        Oppenheim, A. V. & Schafer, R. W. (2009). Discrete-Time Signal Processing
    """
    if isinstance(trace1, WaveformTrace):
        data1 = trace1.data
        fs = trace1.metadata.sample_rate
    else:
        if sample_rate is None:
            raise ValueError("sample_rate required when traces are arrays")
        data1 = trace1
        fs = sample_rate

    if isinstance(trace2, WaveformTrace):
        data2 = trace2.data
        # Fall back to trace2's rate when trace1 carried none.
        if not isinstance(trace1, WaveformTrace):
            fs = trace2.metadata.sample_rate
    else:
        data2 = trace2

    n1, n2 = len(data1), len(data2)
    if max_lag is None:
        max_lag = min(n1, n2) // 2

    # Remove means so offsets do not bias the correlation.
    d1 = data1 - np.mean(data1)
    d2 = data2 - np.mean(data2)

    # correlate(d2, d1) places positive lags where data2 is delayed
    # relative to data1; zero lag lands at index n1 - 1 of the full array.
    full = np.correlate(d2, d1, mode="full")
    zero_idx = n1 - 1
    lo = max(0, zero_idx - max_lag)
    hi = min(len(full), zero_idx + max_lag + 1)
    xcorr = full[lo:hi]
    lags = np.arange(lo - zero_idx, hi - zero_idx)

    if normalized:
        # Energy normalization yields coefficients in [-1, 1].
        norm1 = np.sqrt(np.sum(d1**2))
        norm2 = np.sqrt(np.sum(d2**2))
        if norm1 > 0 and norm2 > 0:
            xcorr = xcorr / (norm1 * norm2)

    # Peak is the strongest correlation of either sign.
    peak_idx = int(np.argmax(np.abs(xcorr)))
    peak_lag = int(lags[peak_idx])

    return CrossCorrelationResult(
        correlation=xcorr.astype(np.float64),
        lags=lags,
        lag_times=(lags / fs).astype(np.float64),
        peak_lag=peak_lag,
        peak_lag_time=peak_lag / fs,
        peak_coefficient=float(xcorr[peak_idx]),
        sample_rate=fs,
    )

236 

237 

def correlation_coefficient(
    trace1: WaveformTrace | NDArray[np.floating[Any]],
    trace2: WaveformTrace | NDArray[np.floating[Any]],
) -> float:
    """Compute Pearson correlation coefficient between two signals.

    A zero-lag measure of the linear relationship between the signals;
    inputs of unequal length are truncated to the shorter one.

    Args:
        trace1: First input trace or numpy array.
        trace2: Second input trace or numpy array.

    Returns:
        Correlation coefficient in range [-1, 1].

    Example:
        >>> r = correlation_coefficient(trace1, trace2)
        >>> print(f"Correlation: {r:.3f}")
    """
    a = trace1.data if isinstance(trace1, WaveformTrace) else trace1
    b = trace2.data if isinstance(trace2, WaveformTrace) else trace2

    # corrcoef needs equal-length inputs; keep the common prefix.
    length = min(len(a), len(b))
    return float(np.corrcoef(a[:length], b[:length])[0, 1])

268 

269 

def find_periodicity(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    min_period_samples: int = 2,
    max_period_samples: int | None = None,
    sample_rate: float | None = None,
) -> dict[str, float | int | list[dict[str, int | float]]]:
    """Find dominant periodicity in signal using autocorrelation.

    Locates the strongest local maximum of the autocorrelation function
    (after lag 0) and treats its lag as the primary period; multiples of
    that lag are then checked for harmonics.

    Args:
        trace: Input trace or numpy array.
        min_period_samples: Minimum period to consider (samples).
        max_period_samples: Maximum period to consider (samples).
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        Dictionary with periodicity analysis:
            - period_samples: Period in samples (NaN if too little data)
            - period_time: Period in seconds
            - frequency: Frequency in Hz
            - strength: Autocorrelation at the period (0-1)
            - harmonics: List of detected harmonics

    Raises:
        ValueError: If sample_rate is not provided when trace is array.

    Example:
        >>> result = find_periodicity(trace)
        >>> print(f"Period: {result['period_time']*1e6:.2f} us")
        >>> print(f"Frequency: {result['frequency']/1e3:.1f} kHz")
    """
    if isinstance(trace, WaveformTrace):
        data = trace.data
        fs = trace.metadata.sample_rate
    else:
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        data = trace
        fs = sample_rate

    n = len(data)
    if max_period_samples is None:
        max_period_samples = n // 2

    _unused_times, acf = autocorrelation(
        trace,
        max_lag=max_period_samples,
        sample_rate=sample_rate if sample_rate else fs,
    )

    # Only lags at or beyond the minimum period are candidates.
    searched = acf[min_period_samples:]

    if len(searched) < 3:
        # Not enough lags to identify even one interior local maximum.
        return {
            "period_samples": np.nan,
            "period_time": np.nan,
            "frequency": np.nan,
            "strength": np.nan,
            "harmonics": [],
        }

    # Interior points that exceed both neighbors are local maxima;
    # the +1 restores the offset introduced by the [1:-1] view.
    is_peak = (searched[1:-1] > searched[:-2]) & (searched[1:-1] > searched[2:])
    peak_positions = np.where(is_peak)[0] + 1

    if len(peak_positions) == 0:
        # No interior maxima; fall back to the global maximum of the range.
        primary_idx = int(np.argmax(searched)) + min_period_samples
    else:
        strongest = peak_positions[int(np.argmax(searched[peak_positions]))]
        primary_idx = int(strongest) + min_period_samples
    strength = float(acf[primary_idx])

    period_samples = int(primary_idx)
    period_time = period_samples / fs
    frequency = 1.0 / period_time if period_time > 0 else np.nan

    # Check lags near integer multiples of the period for harmonics.
    harmonics: list[dict[str, int | float]] = []
    for order in range(2, 6):  # up to the 5th harmonic
        expected_lag = order * period_samples
        if expected_lag >= len(acf):
            continue
        window = max(1, period_samples // 4)
        lo = int(max(0, expected_lag - window))
        hi = int(min(len(acf), expected_lag + window))
        best_idx = int(lo + int(np.argmax(acf[lo:hi])))
        value = float(acf[best_idx])

        if value > 0.3:  # significance threshold
            harmonics.append(
                {
                    "harmonic": order,
                    "lag_samples": best_idx,
                    "strength": value,
                }
            )

    return {
        "period_samples": period_samples,
        "period_time": float(period_time),
        "frequency": float(frequency),
        "strength": strength,
        "harmonics": harmonics,
    }

385 

386 

def coherence(
    trace1: WaveformTrace | NDArray[np.floating[Any]],
    trace2: WaveformTrace | NDArray[np.floating[Any]],
    *,
    nperseg: int | None = None,
    sample_rate: float | None = None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute magnitude-squared coherence between two signals.

    A frequency-resolved correlation measure: coherence of 1 at a
    frequency indicates a perfect linear relationship there.

    Args:
        trace1: First input trace or numpy array.
        trace2: Second input trace or numpy array.
        nperseg: Segment length for estimation. If None, auto-selected
            (min(256, n // 4), but at least 16).
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        Tuple of (frequencies, coherence):
            - frequencies: Frequency values in Hz
            - coherence: Magnitude-squared coherence [0, 1]

    Raises:
        ValueError: If sample_rate is not provided when traces are arrays.

    Example:
        >>> freq, coh = coherence(trace1, trace2)
        >>> high_coh_freqs = freq[coh > 0.8]
    """
    # Imported lazily so the module loads without scipy installed.
    from scipy import signal as sp_signal

    if isinstance(trace1, WaveformTrace):
        data1 = trace1.data
        fs = trace1.metadata.sample_rate
    else:
        if sample_rate is None:
            raise ValueError("sample_rate required when traces are arrays")
        data1 = trace1
        fs = sample_rate

    data2 = trace2.data if isinstance(trace2, WaveformTrace) else trace2

    # Truncate to the common length.
    length = min(len(data1), len(data2))
    x = data1[:length]
    y = data2[:length]

    if nperseg is None:
        # Auto-select: quarter of the data capped at 256, floor of 16.
        nperseg = max(min(256, length // 4), 16)

    freq, coh = sp_signal.coherence(x, y, fs=fs, nperseg=nperseg, noverlap=nperseg // 2)

    return freq, coh.astype(np.float64)

443 

444 

def correlate_chunked(
    signal1: NDArray[np.floating[Any]],
    signal2: NDArray[np.floating[Any]],
    *,
    mode: str = "same",
    chunk_size: int | None = None,
) -> NDArray[np.float64]:
    """Memory-efficient cross-correlation using overlap-save FFT method.

    Equivalent to ``numpy.correlate(signal1, signal2, mode=mode)`` (for
    ``len(signal1) >= len(signal2)``, the intended usage) but processes
    signal1 in fixed-size chunks so very long signals never require a
    full-length FFT or a full-length intermediate array.

    Args:
        signal1: First input signal array (the long signal).
        signal2: Second input signal array (kernel/template); assumed no
            longer than signal1 for the 'same'/'valid' output conventions.
        mode: Correlation mode - 'same', 'valid', or 'full' (default 'same').
        chunk_size: Size of chunks for processing. If None, auto-selected
            (~100 MB of float64 samples, rounded down to a power of two).

    Returns:
        Cross-correlation result with same semantics as numpy.correlate.

    Raises:
        ValueError: If signals are empty or mode is invalid.

    Example:
        >>> import numpy as np
        >>> signal1 = np.random.randn(100_000_000)
        >>> signal2 = np.random.randn(10_000)
        >>> result = correlate_chunked(signal1, signal2, mode='same')
        >>> print(f"Result shape: {result.shape}")

    Notes:
        Fixes over the previous implementation: the final chunk's
        convolution tail is now kept (the last M-1 outputs of 'full' and
        the last (M-1)//2 outputs of 'same' were previously left zero),
        the leading outputs of 'valid' mode are no longer dropped, and the
        internal chunk length is forced large enough relative to the
        kernel that the loop always makes forward progress even when
        chunk_size <= len(signal2).

    References:
        MEM-008: Chunked Correlation
        Oppenheim & Schafer (2009): Discrete-Time Signal Processing, Ch 8
    """
    if len(signal1) == 0 or len(signal2) == 0:
        raise ValueError("Input signals cannot be empty")

    if mode not in ("same", "valid", "full"):
        raise ValueError(f"Invalid mode: {mode}. Must be 'same', 'valid', or 'full'")

    n1 = len(signal1)
    n2 = len(signal2)

    # Correlation equals convolution with the time-reversed kernel
    # (real-valued inputs, so no conjugation is needed).
    kernel = signal2[::-1].copy()

    # Determine chunk size
    if chunk_size is None:
        # Auto-select: aim for ~100MB chunks of float64 samples.
        bytes_per_sample = 8  # float64
        target_bytes = 100 * 1024 * 1024
        chunk_size = min(target_bytes // bytes_per_sample, n1)
        # Round down to power of 2 for FFT efficiency
        chunk_size = 2 ** int(np.log2(chunk_size))

    # For small signals, defer to numpy's direct implementation.
    if n1 < chunk_size and n2 < chunk_size:
        from typing import Literal, cast

        mode_literal = cast("Literal['same', 'valid', 'full']", mode)
        return np.correlate(signal1, signal2, mode=mode_literal).astype(np.float64)

    # Overlap-save parameters. L must be >= 2*(M-1) so that the step
    # (L - M + 1) is positive AND chunk_start = pos - (M-1) never goes
    # negative after the first chunk; otherwise the loop stalls or reads
    # wrapped (negative-index) slices.
    M = n2
    L = max(chunk_size, 2 * (M - 1), 1)
    overlap = M - 1

    # FFT size: power of 2 >= L + M - 1 so circular convolution is linear.
    nfft = int(2 ** np.ceil(np.log2(L + M - 1)))
    kernel_fft = np.fft.fft(kernel, n=nfft)

    # Output length and the shift that maps full-convolution indices into
    # this mode's output frame (output[i] == full[i + shift]).
    if mode == "full":
        output_len = n1 + n2 - 1
        shift = 0
    elif mode == "same":
        output_len = n1
        shift = (M - 1) // 2
    else:  # valid
        output_len = max(0, n1 - n2 + 1)
        shift = M - 1

    output = np.zeros(output_len, dtype=np.float64)

    pos = 0  # Full-convolution index where this chunk's output begins
    while pos < n1:
        # Chunks after the first re-read `overlap` trailing input samples
        # of the previous chunk (the overlap-save prefix).
        chunk_start = 0 if pos == 0 else pos - overlap
        chunk = signal1[chunk_start : min(chunk_start + L, n1)]
        is_last = chunk_start + len(chunk) >= n1

        # Zero-pad and convolve via FFT.
        padded = np.zeros(nfft, dtype=np.float64)
        padded[: len(chunk)] = chunk
        conv_result = np.fft.ifft(np.fft.fft(padded) * kernel_fft).real

        # Discard the prefix already produced by the previous chunk.
        valid_start = 0 if pos == 0 else overlap
        # Only the final chunk keeps its M-1 tail samples: earlier chunks'
        # tails are incomplete (they lack contributions from later input).
        valid_end = min(len(chunk) + (overlap if is_last else 0), len(conv_result))
        valid_output = conv_result[valid_start:valid_end]

        # valid_output[0] corresponds to full-convolution index `pos`;
        # translate into the output frame and clip at both boundaries.
        out_start = pos - shift
        if out_start < 0:
            valid_output = valid_output[-out_start:]
            out_start = 0
        copy_len = min(len(valid_output), output_len - out_start)
        if copy_len > 0:
            output[out_start : out_start + copy_len] = valid_output[:copy_len]

        if is_last:
            break
        pos += L - overlap

    return output

609 

610 

# Explicit public API of this module (ASCII-sorted, so the class sorts first).
__all__ = [
    "CrossCorrelationResult",
    "autocorrelation",
    "coherence",
    "correlate_chunked",
    "correlation_coefficient",
    "cross_correlation",
    "find_periodicity",
]