Coverage for src/tracekit/analyzers/statistical/chunked_corr.py: 91% (119 statements)
coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Chunked correlation for memory-bounded processing.
3This module implements memory-efficient cross-correlation using the
4overlap-save method for signals larger than memory.
7Example:
8 >>> from tracekit.analyzers.statistical.chunked_corr import correlate_chunked
9 >>> corr = correlate_chunked('signal1.bin', 'signal2.bin', chunk_size=1e6)
10 >>> print(f"Correlation shape: {corr.shape}")
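The .bin inputs are raw, headerless sample files whose element type matches the
dtype argument (float32 by default); for example, such a file can be written
with np.asarray(data, dtype=np.float32).tofile('signal1.bin').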

References:
    Oppenheim, A.V. & Schafer, R.W. (2009). "Discrete-Time Signal Processing".
    Chapter on overlap-save and overlap-add methods.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Literal

import numpy as np
from scipy import fft, signal

if TYPE_CHECKING:
    from collections.abc import Iterator

    from numpy.typing import NDArray
else:
    NDArray = np.ndarray


def correlate_chunked(
    signal1_path: str | Path | NDArray[np.float64],
    signal2_path: str | Path | NDArray[np.float64],
    *,
    chunk_size: int | float = 1e6,
    mode: Literal["valid", "same", "full"] = "same",
    method: Literal["fft", "direct"] = "fft",
    dtype: str = "float32",
) -> NDArray[np.float64]:
42 """Compute correlation for large signals using chunked processing.
45 Processes signals in chunks using overlap-save method to compute
46 correlation without loading entire signals into memory. Memory
47 bounded by chunk size.
49 Args:
50 signal1_path: Path to first signal file or array.
51 signal2_path: Path to second signal file or array.
52 chunk_size: Chunk size in samples.
53 mode: Correlation mode ('valid', 'same', 'full').
54 method: Correlation method ('fft' or 'direct').
55 dtype: Data type of input files ('float32' or 'float64').
57 Returns:
58 Correlation array.
60 Example:
61 >>> # Correlate two large signals
62 >>> corr = correlate_chunked(
63 ... 'signal1.bin',
64 ... 'signal2.bin',
65 ... chunk_size=1e6,
66 ... mode='same',
67 ... method='fft'
68 ... )
69 >>> print(f"Correlation peak: {np.max(np.abs(corr))}")
71 References:
72 MEM-008: Chunked Correlation
73 """
    chunk_size = int(chunk_size)

    # Handle array inputs
    if isinstance(signal1_path, np.ndarray):
        signal1 = signal1_path
        signal2 = (
            signal2_path
            if isinstance(signal2_path, np.ndarray)
            else _load_signal(signal2_path, dtype)
        )
        # signal1 is already in memory; correlate directly
        result: NDArray[np.float64] = signal.correlate(
            signal1, signal2, mode=mode, method=method
        ).astype(np.float64)
        return result

    if isinstance(signal2_path, np.ndarray):
        signal1 = _load_signal(signal1_path, dtype)
        signal2 = signal2_path
        result2: NDArray[np.float64] = signal.correlate(
            signal1, signal2, mode=mode, method=method
        ).astype(np.float64)
        return result2

    # Both are files - use chunked processing
    if method == "fft":
        return _correlate_chunked_fft(signal1_path, signal2_path, chunk_size, mode, dtype)
    else:
        # Direct method - less efficient for large signals
        signal1 = _load_signal(signal1_path, dtype)
        signal2 = _load_signal(signal2_path, dtype)
        result3: NDArray[np.float64] = signal.correlate(
            signal1, signal2, mode=mode, method="direct"
        ).astype(np.float64)
        return result3


def _correlate_chunked_fft(
    signal1_path: str | Path,
    signal2_path: str | Path,
    chunk_size: int,
    mode: str,
    dtype: str,
) -> NDArray[np.float64]:
    """FFT-based chunked correlation using overlap-save.

    Args:
        signal1_path: Path to first signal.
        signal2_path: Path to second signal.
        chunk_size: Chunk size in samples.
        mode: Correlation mode.
        dtype: Data type.

    Returns:
        Correlation array.

    Raises:
        ValueError: If signals have different lengths or mode is invalid.
    """
    # Determine dtype
    np_dtype = np.float32 if dtype == "float32" else np.float64
    bytes_per_sample = 4 if dtype == "float32" else 8

    # Get signal lengths from the file sizes
    path1 = Path(signal1_path)
    path2 = Path(signal2_path)

    len1 = path1.stat().st_size // bytes_per_sample
    len2 = path2.stat().st_size // bytes_per_sample

    if len1 != len2:
        raise ValueError(
            f"Signals must have same length for correlation. Got {len1} and {len2} samples."
        )

    n_samples = len1

    # For correlation, we need to reverse one signal.
    # Signal2 is loaded completely (the kernel is assumed to fit in memory);
    # in practice, signal2 should be the shorter signal.
    signal2 = _load_signal(signal2_path, dtype)
    signal2_rev = signal2[::-1]

    # Determine FFT size (next power of 2)
    nfft = _next_power_of_2(chunk_size + len2 - 1)

    # Pre-compute FFT of the reversed signal2
    signal2_fft = fft.rfft(signal2_rev, n=nfft)

    # Overlap-save parameters:
    # for correlation, the overlap must equal the kernel length minus 1
    overlap = len2 - 1
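
    # Illustrative numbers (not from real data): with len2 = 4 and chunk_size = 8,
    # nfft = _next_power_of_2(8 + 4 - 1) = 16 and overlap = 3; after each inverse
    # FFT the first `overlap` samples are circular-wraparound artifacts and are
    # discarded below, which is what lets the chunk results stitch together into
    # a linear (non-circular) correlation.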

    # Determine output length based on mode
    if mode == "full":
        result_len = n_samples + len2 - 1
    elif mode == "same":
        result_len = n_samples
    elif mode == "valid":
        result_len = max(0, n_samples - len2 + 1)
    else:
        raise ValueError(f"Unknown mode: {mode}")

    result = np.zeros(result_len, dtype=np_dtype)

    # Process signal1 in chunks using overlap-save method
    with open(path1, "rb") as f:
        chunk_idx = 0
        input_offset = 0

        while input_offset < n_samples:
            # Calculate chunk boundaries with overlap
            # First chunk starts at 0, subsequent chunks include overlap
            if chunk_idx == 0:
                chunk_start = 0
                chunk_end = min(n_samples, chunk_size)
            else:
                # Include overlap from previous chunk
                chunk_start = max(0, input_offset - overlap)
                chunk_end = min(n_samples, input_offset + chunk_size)
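
            # For example (illustrative only): with chunk_size = 8 and overlap = 3,
            # chunk 0 reads samples [0, 8), chunk 1 reads [5, 16), chunk 2 reads
            # [13, 24), and so on - each later chunk re-reads `overlap` samples so
            # the wraparound region it discards is covered by real data.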

            # Ensure chunk_start doesn't exceed n_samples
            if chunk_start >= n_samples:
                break

            chunk_len = chunk_end - chunk_start
            if chunk_len <= 0:
                break

            # Read chunk from file
            f.seek(chunk_start * bytes_per_sample)
            chunk1 = np.fromfile(f, dtype=np_dtype, count=chunk_len)

            if len(chunk1) == 0:
                break

            # Zero-pad to FFT size
            chunk1_padded = np.zeros(nfft, dtype=np_dtype)
            chunk1_padded[: len(chunk1)] = chunk1

            # Compute FFT
            chunk1_fft = fft.rfft(chunk1_padded)

            # Multiply in frequency domain (correlation = conj(X) * Y, but signal2 is already reversed)
            corr_fft = chunk1_fft * signal2_fft

            # Inverse FFT
            corr_chunk = fft.irfft(corr_fft, n=nfft)

            # Extract valid portion (discard wrap-around artifacts from circular convolution)
            # The first 'overlap' samples are corrupted by circular wraparound
            if chunk_idx == 0:
                # First chunk: keep all samples from overlap to end
                valid_start = overlap
                valid_length = chunk_size
            else:
                # Subsequent chunks: discard overlap region
                valid_start = overlap
                valid_length = min(chunk_size, len(chunk1) - overlap)

            valid_corr = corr_chunk[valid_start : valid_start + valid_length]

            # Output position for this chunk: in all modes the chunk result is
            # written at the current input offset; any mode-specific trimming
            # is applied after the loop.
            if mode == "full":
                output_pos = input_offset
            elif mode == "same":
                output_pos = input_offset
            else:  # valid
                output_pos = input_offset

            # Copy valid correlation to output
            copy_len = min(len(valid_corr), result_len - output_pos)
            if copy_len > 0:
                result[output_pos : output_pos + copy_len] = valid_corr[:copy_len]

            # Move to next chunk
            input_offset = chunk_end
            chunk_idx += 1

    # Adjust result for different modes
    if mode == "same":
        # Correlation in 'same' mode should be centered
        # The current result is in 'full' mode, so we need to extract the center
        if len(result) > n_samples:
            start_idx = (len(result) - n_samples) // 2
            result = result[start_idx : start_idx + n_samples]
    elif mode == "valid":
        # For 'valid' mode, only keep the center portion where signals fully overlap
        if result_len < len(result):
            start_idx = (len(result) - result_len) // 2
            result = result[start_idx : start_idx + result_len]

    return result.astype(np.float64)


def autocorrelate_chunked(
    signal_path: str | Path | NDArray[np.float64],
    *,
    chunk_size: int | float = 1e6,
    mode: Literal["same", "full"] = "same",
    normalize: bool = True,
    dtype: str = "float32",
) -> NDArray[np.float64]:
    """Compute autocorrelation for a large signal using chunked processing.

    Args:
        signal_path: Path to signal file or array.
        chunk_size: Chunk size in samples.
        mode: Correlation mode ('same' or 'full').
        normalize: Normalize by signal variance.
        dtype: Data type of input file.

    Returns:
        Autocorrelation array.

    Example:
        >>> autocorr = autocorrelate_chunked(
        ...     'signal.bin',
        ...     chunk_size=1e6,
        ...     mode='same',
        ...     normalize=True
        ... )
        >>> print(f"Zero-lag correlation: {autocorr[len(autocorr)//2]:.3f}")
    """
    # Autocorrelation is correlation of the signal with itself
    result = correlate_chunked(
        signal_path, signal_path, chunk_size=chunk_size, mode=mode, dtype=dtype
    )

    if normalize:
        # Normalize so the zero-lag value is ~1: for a zero-mean signal the
        # zero-lag autocorrelation equals variance * n_samples
        if isinstance(signal_path, np.ndarray):
            signal_data = signal_path
            variance = np.var(signal_path)
        else:
            signal_data = _load_signal(signal_path, dtype)
            variance = np.var(signal_data)

        if variance > 0:
            result = result / (variance * len(signal_data))

    return result


def _load_signal(file_path: str | Path, dtype: str) -> NDArray[np.float64]:
    """Load signal from binary file.

    Args:
        file_path: Path to signal file.
        dtype: Data type ('float32' or 'float64').

    Returns:
        Signal array.
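
    Note:
        The file is read as a flat, headerless sample stream (np.fromfile) and
        the result is cast to float64 regardless of the on-disk dtype.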
330 """
331 np_dtype = np.float32 if dtype == "float32" else np.float64
332 return np.fromfile(file_path, dtype=np_dtype).astype(np.float64)


def _next_power_of_2(n: int) -> int:
    """Return next power of 2 >= n.

    Args:
        n: Input value.

    Returns:
        Next power of 2.
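
    Example:
        >>> _next_power_of_2(1000)
        1024
        >>> _next_power_of_2(1024)
        1024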
343 """
344 if n <= 0:
345 return 1
346 return 1 << (n - 1).bit_length()


def cross_correlate_chunked_generator(
    signal1_path: str | Path,
    signal2_path: str | Path,
    *,
    chunk_size: int | float = 1e6,
    dtype: str = "float32",
) -> Iterator[NDArray[np.float64]]:
    """Generator version that yields correlation chunks.

    Useful for streaming processing of very large correlations.

    Args:
        signal1_path: Path to first signal file.
        signal2_path: Path to second signal file.
        chunk_size: Chunk size in samples.
        dtype: Data type of input files.

    Yields:
        Correlation chunks.

    Note:
        FUTURE ENHANCEMENT: True streaming correlation generator.
        Currently computes full correlation then yields chunks. A true
        streaming implementation would compute correlation incrementally.
        The current implementation provides correct results; streaming
        is a memory optimization opportunity.

    Example:
        >>> for corr_chunk in cross_correlate_chunked_generator('s1.bin', 's2.bin'):
        ...     # Process each chunk separately
        ...     print(f"Chunk max: {np.max(np.abs(corr_chunk))}")
    """
    # Future: Implement true streaming correlation generator
    # For now, compute full correlation and yield in chunks
    corr_full = correlate_chunked(signal1_path, signal2_path, chunk_size=chunk_size, dtype=dtype)

    chunk_size = int(chunk_size)
    for i in range(0, len(corr_full), chunk_size):
        yield corr_full[i : i + chunk_size]


__all__ = [
    "autocorrelate_chunked",
    "correlate_chunked",
    "cross_correlate_chunked_generator",
]
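

# Illustrative usage sketch (not part of the module's public API): write a small
# raw-float32 signal file and autocorrelate it. The sample count, chunk size,
# and file name below are arbitrary example values.
if __name__ == "__main__":
    import tempfile

    rng = np.random.default_rng(0)
    samples = rng.standard_normal(50_000).astype(np.float32)

    with tempfile.TemporaryDirectory() as tmp:
        sig_path = Path(tmp) / "signal.bin"
        # Headerless float32 samples, the layout _load_signal() reads back
        samples.tofile(sig_path)

        autocorr = autocorrelate_chunked(sig_path, chunk_size=65_536, mode="same")
        print(f"autocorrelation length: {len(autocorr)}")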