1"""Correlation analysis for signal data.
3This module provides autocorrelation, cross-correlation, and related
4analysis functions for identifying signal relationships and periodicities.
7Example:
8 >>> from tracekit.analyzers.statistics.correlation import (
9 ... autocorrelation, cross_correlation, correlate_chunked
10 ... )
11 >>> acf = autocorrelation(trace, max_lag=1000)
12 >>> xcorr, lag, coef = cross_correlation(trace1, trace2)
13 >>> # Memory-efficient correlation for large signals
14 >>> result = correlate_chunked(large_signal1, large_signal2)
16References:
17 Oppenheim, A. V. & Schafer, R. W. (2009). Discrete-Time Signal Processing
18 IEEE 1241-2010: Standard for Terminology and Test Methods for ADCs
19"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import numpy as np

from tracekit.core.types import WaveformTrace

if TYPE_CHECKING:
    from numpy.typing import NDArray


@dataclass
class CrossCorrelationResult:
    """Result of cross-correlation analysis.

    Attributes:
        correlation: Full correlation array.
        lags: Lag values in samples.
        lag_times: Lag values in seconds.
        peak_lag: Lag at maximum correlation (samples).
        peak_lag_time: Lag at maximum correlation (seconds).
        peak_coefficient: Maximum correlation coefficient.
        sample_rate: Sample rate used for time conversion.
    """

    correlation: NDArray[np.float64]
    lags: NDArray[np.intp]
    lag_times: NDArray[np.float64]
    peak_lag: int
    peak_lag_time: float
    peak_coefficient: float
    sample_rate: float


def autocorrelation(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    max_lag: int | None = None,
    normalized: bool = True,
    sample_rate: float | None = None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute autocorrelation of a signal.

    Measures self-similarity of a signal at different time lags.
    Useful for detecting periodicities and characteristic time scales.

    Args:
        trace: Input trace or numpy array.
        max_lag: Maximum lag to compute (samples). If None, uses n // 2.
        normalized: If True, normalize to correlation coefficients [-1, 1].
        sample_rate: Sample rate in Hz (for time axis). Required if trace is an array.

    Returns:
        Tuple of (lag_times, autocorrelation):
            - lag_times: Time values for each lag in seconds
            - autocorrelation: Autocorrelation values (correlation
              coefficients when ``normalized`` is True)

    Raises:
        ValueError: If sample_rate is not provided when trace is an array.

    Example:
        >>> lag_times, acf = autocorrelation(trace, max_lag=1000)
        >>> # Find first zero crossing for decorrelation time
        >>> zero_idx = int(np.where(acf < 0)[0][0])
        >>> decorr_time = lag_times[zero_idx]

    References:
        Box, G. E. P. & Jenkins, G. M. (1976). Time Series Analysis
    """
    if isinstance(trace, WaveformTrace):
        data = trace.data
        fs = trace.metadata.sample_rate
    else:
        data = trace
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        fs = sample_rate

    n = len(data)

    if max_lag is None:
        max_lag = n // 2

    max_lag = min(max_lag, n - 1)

    # Remove mean for proper correlation
    data_centered = data - np.mean(data)

    # Compute autocorrelation via FFT (faster for large n)
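    # The FFT path uses the Wiener-Khinchin relation: the autocorrelation
    # of x equals the inverse transform of its power spectrum,
    #   acf = IFFT(FFT(x) * conj(FFT(x))) = IFFT(|FFT(x)|**2).
    # Zero-padding to nfft >= 2 * n makes the circular correlation implied
    # by the FFT match the desired linear autocorrelation.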
    if n > 256:
        # Zero-pad for full correlation
        nfft = int(2 ** np.ceil(np.log2(2 * n)))
        fft_data = np.fft.rfft(data_centered, n=nfft)
        acf_full = np.fft.irfft(fft_data * np.conj(fft_data), n=nfft)
        acf = acf_full[: max_lag + 1]
    else:
        # Direct computation for small n
        acf = np.correlate(data_centered, data_centered, mode="full")
        acf = acf[n - 1 : n + max_lag]

    # Normalize
    if normalized and acf[0] > 0:
        acf = acf / acf[0]

    # Time axis
    lags = np.arange(max_lag + 1)
    lag_times = lags / fs

    return lag_times, acf.astype(np.float64)


def cross_correlation(
    trace1: WaveformTrace | NDArray[np.floating[Any]],
    trace2: WaveformTrace | NDArray[np.floating[Any]],
    *,
    max_lag: int | None = None,
    normalized: bool = True,
    sample_rate: float | None = None,
) -> CrossCorrelationResult:
    """Compute cross-correlation between two signals.

    Measures similarity between signals at different time lags.
    Useful for finding time delays, alignments, and relationships.

    Args:
        trace1: First input trace or numpy array (reference).
        trace2: Second input trace or numpy array.
        max_lag: Maximum lag to compute (samples). If None, uses min(n1, n2) // 2.
        normalized: If True, normalize to correlation coefficients [-1, 1].
        sample_rate: Sample rate in Hz. Required if traces are arrays.

    Returns:
        CrossCorrelationResult with correlation data and optimal lag.

    Raises:
        ValueError: If sample_rate is not provided when traces are arrays.

    Example:
        >>> result = cross_correlation(trace1, trace2)
        >>> print(f"Optimal lag: {result.peak_lag_time * 1e6:.1f} us")
        >>> print(f"Correlation: {result.peak_coefficient:.3f}")

    References:
        Oppenheim, A. V. & Schafer, R. W. (2009). Discrete-Time Signal Processing
    """
    if isinstance(trace1, WaveformTrace):
        data1 = trace1.data
        fs = trace1.metadata.sample_rate
    else:
        data1 = trace1
        if sample_rate is None:
            raise ValueError("sample_rate required when traces are arrays")
        fs = sample_rate

    if isinstance(trace2, WaveformTrace):
        data2 = trace2.data
        # Use trace2 sample rate if available and trace1 wasn't a WaveformTrace
        if not isinstance(trace1, WaveformTrace):
            fs = trace2.metadata.sample_rate
    else:
        data2 = trace2

    n1, n2 = len(data1), len(data2)

    if max_lag is None:
        max_lag = min(n1, n2) // 2

    # Center the data
    data1_centered = data1 - np.mean(data1)
    data2_centered = data2 - np.mean(data2)

    # Full cross-correlation
    # Note: np.correlate(a, b) computes sum(a[n+k] * conj(b[k]))
    # For cross-correlation where we want to detect b delayed relative to a,
    # we need correlate(b, a) so positive lag means b is delayed
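    # Concrete check of the sign convention: if data2 equals data1 delayed
    # by d samples, np.correlate(data2, data1, mode="full") peaks at index
    # (n1 - 1) + d, which the re-centering below maps to lag +d.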
    xcorr_full = np.correlate(data2_centered, data1_centered, mode="full")

    # Extract relevant portion around zero lag
    # Full correlation has length n1 + n2 - 1, with zero lag at index n1 - 1
    # (since we swapped the order above)
    zero_lag_idx = n1 - 1
    start_idx = max(0, zero_lag_idx - max_lag)
    end_idx = min(len(xcorr_full), zero_lag_idx + max_lag + 1)
    xcorr = xcorr_full[start_idx:end_idx]

    # Create lag array
    lags = np.arange(start_idx - zero_lag_idx, end_idx - zero_lag_idx)

    # Normalize
    if normalized:
        norm1 = np.sqrt(np.sum(data1_centered**2))
        norm2 = np.sqrt(np.sum(data2_centered**2))
        if norm1 > 0 and norm2 > 0:
            xcorr = xcorr / (norm1 * norm2)

    # Find peak
    peak_local_idx = np.argmax(np.abs(xcorr))
    peak_lag = int(lags[peak_local_idx])
    peak_coefficient = float(xcorr[peak_local_idx])

    # Time values
    lag_times = lags / fs
    peak_lag_time = peak_lag / fs

    return CrossCorrelationResult(
        correlation=xcorr.astype(np.float64),
        lags=lags,
        lag_times=lag_times.astype(np.float64),
        peak_lag=peak_lag,
        peak_lag_time=peak_lag_time,
        peak_coefficient=peak_coefficient,
        sample_rate=fs,
    )


def correlation_coefficient(
    trace1: WaveformTrace | NDArray[np.floating[Any]],
    trace2: WaveformTrace | NDArray[np.floating[Any]],
) -> float:
    """Compute Pearson correlation coefficient between two signals.

    Simple measure of linear relationship between signals at zero lag.

    Args:
        trace1: First input trace or numpy array.
        trace2: Second input trace or numpy array.

    Returns:
        Correlation coefficient in range [-1, 1].

    Example:
        >>> r = correlation_coefficient(trace1, trace2)
        >>> print(f"Correlation: {r:.3f}")
    """
    data1 = trace1.data if isinstance(trace1, WaveformTrace) else trace1
    data2 = trace2.data if isinstance(trace2, WaveformTrace) else trace2

    # Ensure same length
    n = min(len(data1), len(data2))
    data1 = data1[:n]
    data2 = data2[:n]
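    # Pearson's r at zero lag:
    #   r = sum((x - mean(x)) * (y - mean(y)))
    #       / sqrt(sum((x - mean(x))**2) * sum((y - mean(y))**2))
    # np.corrcoef returns the 2x2 normalized covariance matrix;
    # the off-diagonal entry [0, 1] is r.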

    # Compute correlation
    return float(np.corrcoef(data1, data2)[0, 1])


def find_periodicity(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    min_period_samples: int = 2,
    max_period_samples: int | None = None,
    sample_rate: float | None = None,
) -> dict[str, float | int | list[dict[str, int | float]]]:
    """Find dominant periodicity in signal using autocorrelation.

    Detects the primary periodic component by finding the strongest
    peak in the autocorrelation function.

    Args:
        trace: Input trace or numpy array.
        min_period_samples: Minimum period to consider (samples).
        max_period_samples: Maximum period to consider (samples).
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        Dictionary with periodicity analysis:
            - period_samples: Period in samples
            - period_time: Period in seconds
            - frequency: Frequency in Hz
            - strength: Autocorrelation at period (0-1)
            - harmonics: List of detected harmonics

    Raises:
        ValueError: If sample_rate is not provided when trace is an array.

    Example:
        >>> result = find_periodicity(trace)
        >>> print(f"Period: {result['period_time'] * 1e6:.2f} us")
        >>> print(f"Frequency: {result['frequency'] / 1e3:.1f} kHz")
    """
    if isinstance(trace, WaveformTrace):
        data = trace.data
        fs = trace.metadata.sample_rate
    else:
        data = trace
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        fs = sample_rate

    n = len(data)

    if max_period_samples is None:
        max_period_samples = n // 2

    # Compute autocorrelation
    _lag_times, acf = autocorrelation(
        trace,
        max_lag=max_period_samples,
        sample_rate=fs,
    )

    # Find peaks in autocorrelation (after lag 0)
    # Look for local maxima
    acf_search = acf[min_period_samples:]

    if len(acf_search) < 3:
        return {
            "period_samples": np.nan,
            "period_time": np.nan,
            "frequency": np.nan,
            "strength": np.nan,
            "harmonics": [],
        }

    # Find local maxima
    local_max = (acf_search[1:-1] > acf_search[:-2]) & (acf_search[1:-1] > acf_search[2:])
    max_indices = np.where(local_max)[0] + 1  # +1 for offset from [1:-1]
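    # The vectorized comparison marks interior samples strictly greater
    # than both neighbors; flat-topped (plateau) peaks are therefore not
    # detected, which is acceptable for smooth autocorrelation curves.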

    if len(max_indices) == 0:
        # No local maxima found, use global max
        primary_idx = int(np.argmax(acf_search)) + min_period_samples
        strength = float(acf[primary_idx])
    else:
        # Find strongest peak
        peak_values = acf_search[max_indices]
        best_peak_idx = int(np.argmax(peak_values))
        primary_idx = int(max_indices[best_peak_idx]) + min_period_samples
        strength = float(acf[primary_idx])

    period_samples = int(primary_idx)
    period_time = period_samples / fs
    frequency = 1.0 / period_time if period_time > 0 else np.nan

    # Find harmonics (peaks at multiples of period)
    harmonics: list[dict[str, int | float]] = []
    for h in range(2, 6):  # Check up to 5th harmonic
        harmonic_lag = h * period_samples
        if harmonic_lag < len(acf):
            # Look for peak near expected harmonic
            search_range = max(1, period_samples // 4)
            start = int(max(0, harmonic_lag - search_range))
            end = int(min(len(acf), harmonic_lag + search_range))
            local_max_idx = int(start + np.argmax(acf[start:end]))
            harmonic_strength = float(acf[local_max_idx])

            if harmonic_strength > 0.3:  # Threshold for significant harmonic
                harmonics.append(
                    {
                        "harmonic": h,
                        "lag_samples": local_max_idx,
                        "strength": harmonic_strength,
                    }
                )

    return {
        "period_samples": period_samples,
        "period_time": float(period_time),
        "frequency": float(frequency),
        "strength": strength,
        "harmonics": harmonics,
    }


def coherence(
    trace1: WaveformTrace | NDArray[np.floating[Any]],
    trace2: WaveformTrace | NDArray[np.floating[Any]],
    *,
    nperseg: int | None = None,
    sample_rate: float | None = None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute magnitude-squared coherence between two signals.

    Measures frequency-domain correlation between signals.
    Coherence of 1 indicates a perfect linear relationship at that frequency.

    Args:
        trace1: First input trace or numpy array.
        trace2: Second input trace or numpy array.
        nperseg: Segment length for estimation. If None, auto-selected.
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        Tuple of (frequencies, coherence):
            - frequencies: Frequency values in Hz
            - coherence: Magnitude-squared coherence in [0, 1]

    Raises:
        ValueError: If sample_rate is not provided when traces are arrays.

    Example:
        >>> freq, coh = coherence(trace1, trace2)
        >>> # Find frequencies with high coherence
        >>> high_coh_freqs = freq[coh > 0.8]
    """
    from scipy import signal as sp_signal

    if isinstance(trace1, WaveformTrace):
        data1 = trace1.data
        fs = trace1.metadata.sample_rate
    else:
        data1 = trace1
        if sample_rate is None:
            raise ValueError("sample_rate required when traces are arrays")
        fs = sample_rate

    data2 = trace2.data if isinstance(trace2, WaveformTrace) else trace2

    # Ensure same length
    n = min(len(data1), len(data2))
    data1 = data1[:n]
    data2 = data2[:n]

    if nperseg is None:
        nperseg = min(256, n // 4)
        nperseg = max(nperseg, 16)
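    # Coherence is estimated by Welch averaging over segments. With a
    # single segment the magnitude-squared coherence is identically 1
    # regardless of the inputs, so nperseg is limited to n // 4 (with a
    # floor of 16) so that several 50%-overlapped segments are averaged.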

    freq, coh = sp_signal.coherence(data1, data2, fs=fs, nperseg=nperseg, noverlap=nperseg // 2)

    return freq, coh.astype(np.float64)


def correlate_chunked(
    signal1: NDArray[np.floating[Any]],
    signal2: NDArray[np.floating[Any]],
    *,
    mode: str = "same",
    chunk_size: int | None = None,
) -> NDArray[np.float64]:
    """Memory-efficient cross-correlation using the overlap-save FFT method.

    Computes cross-correlation for large signals that don't fit in memory
    by processing in chunks using the overlap-save method with FFT.

    Args:
        signal1: First input signal array.
        signal2: Second input signal array (kernel/template).
        mode: Correlation mode - 'same', 'valid', or 'full' (default 'same').
        chunk_size: Size of chunks for processing. If None, auto-selected.

    Returns:
        Cross-correlation result with the same semantics as numpy.correlate.

    Raises:
        ValueError: If signals are empty or mode is invalid.

    Example:
        >>> import numpy as np
        >>> # Large signals
        >>> signal1 = np.random.randn(100_000_000)
        >>> signal2 = np.random.randn(10_000)
        >>> # Memory-efficient correlation
        >>> result = correlate_chunked(signal1, signal2, mode='same')
        >>> print(f"Result shape: {result.shape}")

    Notes:
        Uses overlap-save FFT-based convolution, which is memory-efficient
        and faster than direct correlation for large signals.

    References:
        MEM-008: Chunked Correlation
        Oppenheim & Schafer (2009): Discrete-Time Signal Processing, Ch. 8
    """
    if len(signal1) == 0 or len(signal2) == 0:
        raise ValueError("Input signals cannot be empty")

    if mode not in ("same", "valid", "full"):
        raise ValueError(f"Invalid mode: {mode}. Must be 'same', 'valid', or 'full'")

    n1 = len(signal1)
    n2 = len(signal2)

    # For correlation, we need to flip signal2
    signal2_flipped = signal2[::-1].copy()

    # Determine chunk size
    if chunk_size is None:
        # Auto-select: aim for ~100MB chunks
        bytes_per_sample = 8  # float64
        target_bytes = 100 * 1024 * 1024
        chunk_size = min(target_bytes // bytes_per_sample, n1)
        # Round down to a power of 2 for FFT efficiency
        chunk_size = 2 ** int(np.log2(chunk_size))
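        # Worked example of the auto-selection: 100 MiB / 8 bytes per
        # sample = 13_107_200 samples, which rounds down to
        # 2**23 = 8_388_608 samples (~64 MiB) per chunk.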

    # For small signals, use direct method
    if n1 < chunk_size and n2 < chunk_size:
        # Cast mode to literal type for numpy.correlate
        from typing import Literal, cast

        mode_literal = cast("Literal['same', 'valid', 'full']", mode)
        result = np.correlate(signal1, signal2, mode=mode_literal)
        return result.astype(np.float64)

    # Overlap-save parameters
    # L = chunk size, M = filter length
    L = chunk_size
    M = n2
    overlap = M - 1

    # FFT size (power of 2, >= L + M - 1)
    nfft = int(2 ** np.ceil(np.log2(L + M - 1)))

    # Pre-compute FFT of flipped signal2 (kernel)
    kernel_fft = np.fft.fft(signal2_flipped, n=nfft)
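    # Illustrative sizing (assumed values, not defaults): with a chunk of
    # L = 2**20 samples and a kernel of M = 10_001 samples, overlap is
    # 10_000 and the smallest power of two >= L + M - 1 = 1_058_576 is
    # nfft = 2**21, so each FFT operates on ~2M points while only one
    # chunk of signal1 is resident in memory at a time.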

    # Output length based on mode
    if mode == "full":
        output_len = n1 + n2 - 1
    elif mode == "same":
        output_len = n1
    else:  # valid
        output_len = max(0, n1 - n2 + 1)
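    # Example: with n1 = 10 and n2 = 4 these lengths match numpy.correlate:
    # full -> 13, same -> 10, valid -> 7.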

    # Initialize output
    output = np.zeros(output_len, dtype=np.float64)

    # Process chunks with overlap-save
    pos = 0  # Position in signal1

    while pos < n1:
        # Extract chunk with overlap from previous chunk
        if pos == 0:
            # First chunk: no overlap needed
            chunk_start = 0
            chunk = signal1[0 : min(L, n1)]
        else:
            # Subsequent chunks: include overlap
            chunk_start = pos - overlap
            chunk = signal1[chunk_start : min(chunk_start + L, n1)]

        # Zero-pad chunk to FFT size
        chunk_padded = np.zeros(nfft, dtype=np.float64)
        chunk_padded[: len(chunk)] = chunk

        # Perform FFT-based convolution
        chunk_fft = np.fft.fft(chunk_padded)
        conv_fft = chunk_fft * kernel_fft
        conv_result = np.fft.ifft(conv_fft).real

        # Extract valid portion (discard transient at start)
        if pos == 0:
            # First chunk
            valid_start = 0
            valid_end = min(L, len(conv_result))
        else:
            # Subsequent chunks: discard overlap region
            valid_start = overlap
            valid_end = min(len(chunk), len(conv_result))

        valid_output = conv_result[valid_start:valid_end]

        # Determine output range based on mode
        if mode == "full":
            # Full convolution includes all overlap
            out_start = pos
            out_end = min(out_start + len(valid_output), output_len)
        elif mode == "same":
            # Same mode: center-aligned
            offset = (M - 1) // 2
            out_start = max(0, pos - offset)
            out_end = min(out_start + len(valid_output), output_len)
            # Adjust valid_output if we're at boundaries
            if pos == 0 and offset > 0:
                valid_output = valid_output[offset:]
        else:  # valid
            # Valid mode: only where signals fully overlap
            offset = M - 1
            if pos < offset:
                # Skip this chunk, not in valid region yet
                pos += L - overlap
                continue
            out_start = pos - offset
            out_end = min(out_start + len(valid_output), output_len)

        # Copy to output
        copy_len = min(len(valid_output), out_end - out_start)
        if copy_len > 0:
            output[out_start : out_start + copy_len] = valid_output[:copy_len]

        # Move to next chunk
        pos += L - overlap
        if pos <= chunk_start:
            # Prevent infinite loop
            pos = chunk_start + L

    return output


__all__ = [
    "CrossCorrelationResult",
    "autocorrelation",
    "coherence",
    "correlate_chunked",
    "correlation_coefficient",
    "cross_correlation",
    "find_periodicity",
]
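

# A minimal usage sketch on synthetic data; the sample rate, tone
# frequency, and delay below are illustrative assumptions, not values
# taken from the library.
if __name__ == "__main__":
    fs = 1_000_000.0  # assumed 1 MHz sample rate
    t = np.arange(10_000) / fs
    tone = np.sin(2 * np.pi * 10_000.0 * t)  # 10 kHz tone

    # The ACF of a 10 kHz tone sampled at 1 MHz peaks at a 100-sample lag
    period = find_periodicity(tone, sample_rate=fs)
    print(f"Detected frequency: {period['frequency'] / 1e3:.1f} kHz")

    # Delay estimation: white noise circularly shifted by 25 samples
    rng = np.random.default_rng(0)
    noise = rng.standard_normal(10_000)
    delayed = np.roll(noise, 25)
    result = cross_correlation(noise, delayed, sample_rate=fs)
    print(f"Detected delay: {result.peak_lag} samples")  # expect 25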