Coverage for src / tracekit / analyzers / statistical / chunked_corr.py: 91%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Chunked correlation for memory-bounded processing. 

2 

3This module implements memory-efficient cross-correlation using the 

4overlap-save method for signals larger than memory. 

5 

6 

7Example: 

8 >>> from tracekit.analyzers.statistical.chunked_corr import correlate_chunked 

9 >>> corr = correlate_chunked('signal1.bin', 'signal2.bin', chunk_size=1e6) 

10 >>> print(f"Correlation shape: {corr.shape}") 

11 

12References: 

13 Oppenheim, A.V. & Schafer, R.W. (2009). "Discrete-Time Signal Processing" 

14 Chapter on overlap-save and overlap-add methods 

15""" 

16 

17from __future__ import annotations 

18 

19from pathlib import Path 

20from typing import TYPE_CHECKING, Literal 

21 

22import numpy as np 

23from scipy import fft, signal 

24 

25if TYPE_CHECKING: 

26 from collections.abc import Iterator 

27 

28 from numpy.typing import NDArray 

29else: 

30 NDArray = np.ndarray 

31 

32 

33def correlate_chunked( 

34 signal1_path: str | Path | NDArray[np.float64], 

35 signal2_path: str | Path | NDArray[np.float64], 

36 *, 

37 chunk_size: int | float = 1e6, 

38 mode: Literal["valid", "same", "full"] = "same", 

39 method: Literal["fft", "direct"] = "fft", 

40 dtype: str = "float32", 

41) -> NDArray[np.float64]: 

42 """Compute correlation for large signals using chunked processing. 

43 

44 

45 Processes signals in chunks using overlap-save method to compute 

46 correlation without loading entire signals into memory. Memory 

47 bounded by chunk size. 

48 

49 Args: 

50 signal1_path: Path to first signal file or array. 

51 signal2_path: Path to second signal file or array. 

52 chunk_size: Chunk size in samples. 

53 mode: Correlation mode ('valid', 'same', 'full'). 

54 method: Correlation method ('fft' or 'direct'). 

55 dtype: Data type of input files ('float32' or 'float64'). 

56 

57 Returns: 

58 Correlation array. 

59 

60 Example: 

61 >>> # Correlate two large signals 

62 >>> corr = correlate_chunked( 

63 ... 'signal1.bin', 

64 ... 'signal2.bin', 

65 ... chunk_size=1e6, 

66 ... mode='same', 

67 ... method='fft' 

68 ... ) 

69 >>> print(f"Correlation peak: {np.max(np.abs(corr))}") 

70 

71 References: 

72 MEM-008: Chunked Correlation 

73 """ 

74 chunk_size = int(chunk_size) 

75 

76 # Handle array inputs 

77 if isinstance(signal1_path, np.ndarray): 

78 signal1 = signal1_path 

79 signal2 = ( 

80 signal2_path 

81 if isinstance(signal2_path, np.ndarray) 

82 else _load_signal(signal2_path, dtype) 

83 ) 

84 # If both are arrays, use direct correlation 

85 result: NDArray[np.float64] = signal.correlate( 

86 signal1, signal2, mode=mode, method=method 

87 ).astype(np.float64) 

88 return result 

89 

90 if isinstance(signal2_path, np.ndarray): 

91 signal1 = _load_signal(signal1_path, dtype) 

92 signal2 = signal2_path 

93 result2: NDArray[np.float64] = signal.correlate( 

94 signal1, signal2, mode=mode, method=method 

95 ).astype(np.float64) 

96 return result2 

97 

98 # Both are files - use chunked processing 

99 if method == "fft": 

100 return _correlate_chunked_fft(signal1_path, signal2_path, chunk_size, mode, dtype) 

101 else: 

102 # Direct method - less efficient for large signals 

103 signal1 = _load_signal(signal1_path, dtype) 

104 signal2 = _load_signal(signal2_path, dtype) 

105 result3: NDArray[np.float64] = signal.correlate( 

106 signal1, signal2, mode=mode, method="direct" 

107 ).astype(np.float64) 

108 return result3 

109 

110 

111def _correlate_chunked_fft( 

112 signal1_path: str | Path, 

113 signal2_path: str | Path, 

114 chunk_size: int, 

115 mode: str, 

116 dtype: str, 

117) -> NDArray[np.float64]: 

118 """FFT-based chunked correlation using overlap-save. 

119 

120 Args: 

121 signal1_path: Path to first signal. 

122 signal2_path: Path to second signal. 

123 chunk_size: Chunk size in samples. 

124 mode: Correlation mode. 

125 dtype: Data type. 

126 

127 Returns: 

128 Correlation array. 

129 

130 Raises: 

131 ValueError: If signals have different lengths or mode is invalid. 

132 """ 

133 # Determine dtype 

134 np_dtype = np.float32 if dtype == "float32" else np.float64 

135 bytes_per_sample = 4 if dtype == "float32" else 8 

136 

137 # Get signal lengths 

138 path1 = Path(signal1_path) 

139 path2 = Path(signal2_path) 

140 

141 len1 = path1.stat().st_size // bytes_per_sample 

142 len2 = path2.stat().st_size // bytes_per_sample 

143 

144 if len1 != len2: 

145 raise ValueError( 

146 f"Signals must have same length for correlation. Got {len1} and {len2} samples." 

147 ) 

148 

149 n_samples = len1 

150 

151 # For correlation, we need to reverse one signal 

152 # Load signal2 completely (assumed smaller than memory for kernel) 

153 # In practice, signal2 should be the shorter signal 

154 signal2 = _load_signal(signal2_path, dtype) 

155 signal2_rev = signal2[::-1] 

156 

157 # Determine FFT size (next power of 2) 

158 nfft = _next_power_of_2(chunk_size + len2 - 1) 

159 

160 # Pre-compute FFT of reversed signal2 

161 signal2_fft = fft.rfft(signal2_rev, n=nfft) 

162 

163 # Overlap-save parameters 

164 # For correlation, we need overlap equal to the kernel length minus 1 

165 overlap = len2 - 1 

166 

167 # Determine output length based on mode 

168 if mode == "full": 

169 result_len = n_samples + len2 - 1 

170 elif mode == "same": 

171 result_len = n_samples 

172 elif mode == "valid": 172 ↛ 175line 172 didn't jump to line 175 because the condition on line 172 was always true

173 result_len = max(0, n_samples - len2 + 1) 

174 else: 

175 raise ValueError(f"Unknown mode: {mode}") 

176 

177 result = np.zeros(result_len, dtype=np_dtype) 

178 

179 # Process signal1 in chunks using overlap-save method 

180 with open(path1, "rb") as f: 

181 chunk_idx = 0 

182 input_offset = 0 

183 

184 while input_offset < n_samples: 

185 # Calculate chunk boundaries with overlap 

186 # First chunk starts at 0, subsequent chunks include overlap 

187 if chunk_idx == 0: 

188 chunk_start = 0 

189 chunk_end = min(n_samples, chunk_size) 

190 else: 

191 # Include overlap from previous chunk 

192 chunk_start = max(0, input_offset - overlap) 

193 chunk_end = min(n_samples, input_offset + chunk_size) 

194 

195 # Ensure chunk_start doesn't exceed total_samples 

196 if chunk_start >= n_samples: 196 ↛ 197line 196 didn't jump to line 197 because the condition on line 196 was never true

197 break 

198 

199 chunk_len = chunk_end - chunk_start 

200 if chunk_len <= 0: 200 ↛ 201line 200 didn't jump to line 201 because the condition on line 200 was never true

201 break 

202 

203 # Read chunk from file 

204 f.seek(chunk_start * bytes_per_sample) 

205 chunk1 = np.fromfile(f, dtype=np_dtype, count=chunk_len) 

206 

207 if len(chunk1) == 0: 207 ↛ 208line 207 didn't jump to line 208 because the condition on line 207 was never true

208 break 

209 

210 # Zero-pad to FFT size 

211 chunk1_padded = np.zeros(nfft, dtype=np_dtype) 

212 chunk1_padded[: len(chunk1)] = chunk1 

213 

214 # Compute FFT 

215 chunk1_fft = fft.rfft(chunk1_padded) 

216 

217 # Multiply in frequency domain (correlation = conj(X) * Y, but signal2 is already reversed) 

218 corr_fft = chunk1_fft * signal2_fft 

219 

220 # Inverse FFT 

221 corr_chunk = fft.irfft(corr_fft, n=nfft) 

222 

223 # Extract valid portion (discard wrap-around artifacts from circular convolution) 

224 # The first 'overlap' samples are corrupted by circular wraparound 

225 if chunk_idx == 0: 

226 # First chunk: keep all samples from overlap to end 

227 valid_start = overlap 

228 valid_length = chunk_size 

229 else: 

230 # Subsequent chunks: discard overlap region 

231 valid_start = overlap 

232 valid_length = min(chunk_size, len(chunk1) - overlap) 

233 

234 valid_corr = corr_chunk[valid_start : valid_start + valid_length] 

235 

236 # Calculate output position for this chunk 

237 # For 'full' mode, output starts at 0 

238 # For 'same' mode, output is centered 

239 if mode == "full": 

240 output_pos = input_offset 

241 elif mode == "same": 

242 # Center the correlation (shift by half kernel length) 

243 output_pos = input_offset 

244 else: # valid 

245 output_pos = input_offset 

246 

247 # Copy valid correlation to output 

248 copy_len = min(len(valid_corr), result_len - output_pos) 

249 if copy_len > 0: 

250 result[output_pos : output_pos + copy_len] = valid_corr[:copy_len] 

251 

252 # Move to next chunk 

253 input_offset = chunk_end 

254 chunk_idx += 1 

255 

256 # Adjust result for different modes 

257 if mode == "same": 

258 # Correlation in 'same' mode should be centered 

259 # The current result is in 'full' mode, so we need to extract the center 

260 if len(result) > n_samples: 260 ↛ 261line 260 didn't jump to line 261 because the condition on line 260 was never true

261 start_idx = (len(result) - n_samples) // 2 

262 result = result[start_idx : start_idx + n_samples] 

263 elif mode == "valid": 

264 # For 'valid' mode, only keep the center portion where signals fully overlap 

265 if result_len < len(result): 265 ↛ 266line 265 didn't jump to line 266 because the condition on line 265 was never true

266 start_idx = (len(result) - result_len) // 2 

267 result = result[start_idx : start_idx + result_len] 

268 

269 return result.astype(np.float64) 

270 

271 

def autocorrelate_chunked(
    signal_path: str | Path | NDArray[np.float64],
    *,
    chunk_size: int | float = 1e6,
    mode: Literal["same", "full"] = "same",
    normalize: bool = True,
    dtype: str = "float32",
) -> NDArray[np.float64]:
    """Compute autocorrelation for large signal using chunked processing.

    Args:
        signal_path: Path to signal file or array.
        chunk_size: Chunk size in samples.
        mode: Correlation mode ('same' or 'full').
        normalize: Normalize by signal variance.
        dtype: Data type of input file.

    Returns:
        Autocorrelation array.

    Example:
        >>> autocorr = autocorrelate_chunked(
        ...     'signal.bin',
        ...     chunk_size=1e6,
        ...     mode='same',
        ...     normalize=True
        ... )
        >>> print(f"Zero-lag correlation: {autocorr[len(autocorr)//2]:.3f}")
    """
    # Autocorrelation is simply the correlation of the signal with itself.
    corr = correlate_chunked(
        signal_path, signal_path, chunk_size=chunk_size, mode=mode, dtype=dtype
    )

    if not normalize:
        return corr

    # Normalization needs the samples themselves to compute the variance;
    # arrays are used directly, file paths are loaded from disk.
    # NOTE(review): dividing by variance * N equals the zero-lag value only
    # for zero-mean signals — confirm this is the intended convention.
    if isinstance(signal_path, np.ndarray):
        samples = signal_path
    else:
        samples = _load_signal(signal_path, dtype)

    variance = np.var(samples)
    if variance > 0:
        corr = corr / (variance * len(samples))

    return corr

319 

320 

321def _load_signal(file_path: str | Path, dtype: str) -> NDArray[np.float64]: 

322 """Load signal from binary file. 

323 

324 Args: 

325 file_path: Path to signal file. 

326 dtype: Data type ('float32' or 'float64'). 

327 

328 Returns: 

329 Signal array. 

330 """ 

331 np_dtype = np.float32 if dtype == "float32" else np.float64 

332 return np.fromfile(file_path, dtype=np_dtype).astype(np.float64) 

333 

334 

335def _next_power_of_2(n: int) -> int: 

336 """Return next power of 2 >= n. 

337 

338 Args: 

339 n: Input value. 

340 

341 Returns: 

342 Next power of 2. 

343 """ 

344 if n <= 0: 

345 return 1 

346 return 1 << (n - 1).bit_length() 

347 

348 

def cross_correlate_chunked_generator(
    signal1_path: str | Path,
    signal2_path: str | Path,
    *,
    chunk_size: int | float = 1e6,
    dtype: str = "float32",
) -> Iterator[NDArray[np.float64]]:
    """Generator version that yields correlation chunks.

    Useful for streaming processing of very large correlations.

    Args:
        signal1_path: Path to first signal file.
        signal2_path: Path to second signal file.
        chunk_size: Chunk size in samples.
        dtype: Data type of input files.

    Yields:
        Correlation chunks.

    Note:
        FUTURE ENHANCEMENT: True streaming correlation generator.
        Currently computes full correlation then yields chunks. A true
        streaming implementation would compute correlation incrementally.
        The current implementation provides correct results; streaming
        is a memory optimization opportunity.

    Example:
        >>> for corr_chunk in cross_correlate_chunked_generator('s1.bin', 's2.bin'):
        ...     # Process each chunk separately
        ...     print(f"Chunk max: {np.max(np.abs(corr_chunk))}")
    """
    # Future: Implement true streaming correlation generator.
    # For now the entire correlation is materialized once, then sliced.
    full_corr = correlate_chunked(
        signal1_path, signal2_path, chunk_size=chunk_size, dtype=dtype
    )

    step = int(chunk_size)
    start = 0
    while start < len(full_corr):
        yield full_corr[start : start + step]
        start += step

388 

389 

# Public API of the chunked-correlation module.
__all__ = [
    "autocorrelate_chunked",
    "correlate_chunked",
    "cross_correlate_chunked_generator",
]