Coverage for src/tracekit/streaming/chunked.py: 96%


"""Streaming APIs for chunk-by-chunk processing of large files.

This module implements memory-efficient streaming analysis for huge waveform
files that don't fit in memory. Uses generator-based chunk loading and an
accumulator pattern for rolling statistics.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

import numpy as np
from scipy import signal

from ..core.types import WaveformTrace

if TYPE_CHECKING:
    from collections.abc import Callable, Generator

    from numpy.typing import NDArray


def load_trace_chunks(
    file_path: str | Path,
    chunk_size: int | float = 100e6,
    overlap: int = 0,
    loader: Callable[[str | Path], WaveformTrace] | None = None,
    progress_callback: Callable[[int, int], None] | None = None,
) -> Generator[WaveformTrace, None, None]:
    """Load large trace files chunk-by-chunk without loading them fully into memory.

    Yields chunks of the trace for memory-efficient processing. Supports
    overlap between chunks for windowed operations that need continuity.

    Args:
        file_path: Path to trace file.
        chunk_size: Size of each chunk in samples (if int) or bytes (if float).
            Default 100e6 (100 MB).
        overlap: Number of samples to overlap between chunks. Useful for
            windowed operations like FFT. Default 0.
        loader: Optional custom loader function. If None, uses the default loader.
        progress_callback: Optional callback(chunk_num, total_chunks) for
            progress reporting.

    Yields:
        WaveformTrace chunks.

    Raises:
        ValueError: If trace metadata cannot be loaded, or if overlap is not
            smaller than the chunk size in samples.

    Example:
        >>> # Stream a 10 GB file in 100 MB chunks
        >>> for chunk in tk.load_trace_chunks('huge_trace.bin', chunk_size=100e6):
        ...     mean = chunk.data.mean()
        ...     std = chunk.data.std()
        ...     print(f"Chunk stats: mean={mean:.3f}, std={std:.3f}")

    Advanced Example:
        >>> # Process with overlap for FFT continuity
        >>> for chunk in tk.load_trace_chunks(
        ...     'large_trace.bin',
        ...     chunk_size=50e6,
        ...     overlap=8192,  # Overlap for continuity
        ... ):
        ...     fft_result = tk.fft(chunk, nfft=8192)
        ...     # Process FFT result...

    References:
        API-003: Streaming/Generator API for Large Files
    """

    file_path = Path(file_path)

    # Import loader here to avoid a circular dependency
    from ..loaders import load

    # Use provided loader or default
    load_func = loader if loader is not None else load

    # Load full trace metadata to get total size.
    # For memory-mapped files, this doesn't load data.
    try:
        full_trace = load_func(file_path)
    except Exception as e:
        raise ValueError(f"Failed to load trace metadata: {e}") from e

    total_samples = len(full_trace.data)  # type: ignore[union-attr]

    # Interpret chunk_size as documented: int means samples, float means
    # bytes (assuming 8-byte float64 samples).
    if isinstance(chunk_size, float):
        chunk_samples = int(chunk_size / 8)
    else:
        chunk_samples = int(chunk_size)

    if overlap >= chunk_samples:
        raise ValueError("overlap must be smaller than the chunk size in samples")

    # Number of chunks: ceil((total_samples - overlap) / (chunk_samples - overlap)),
    # since each chunk after the first advances by (chunk_samples - overlap) samples
    num_chunks = (total_samples - overlap) // (chunk_samples - overlap)
    if (total_samples - overlap) % (chunk_samples - overlap) != 0:
        num_chunks += 1

    # Yield chunks
    chunk_num = 0
    start_idx = 0

    while start_idx < total_samples:
        end_idx = min(start_idx + chunk_samples, total_samples)

        # Extract chunk
        chunk_data = full_trace.data[start_idx:end_idx]  # type: ignore[union-attr]

        # Create chunk trace with the same metadata.
        # Cast needed for mypy: slicing a floating array returns a floating array.
        chunk_trace = WaveformTrace(
            data=cast("NDArray[np.floating[Any]]", chunk_data),
            metadata=full_trace.metadata,
        )

        # Call progress callback if provided
        if progress_callback is not None:
            progress_callback(chunk_num, num_chunks)

        yield chunk_trace

        # Move to next chunk, accounting for overlap
        start_idx = end_idx - overlap
        chunk_num += 1

        # Stop once the final chunk has been yielded
        if end_idx >= total_samples:
            break
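
# Sizing note (a worked example, not executed): the default chunk_size=100e6
# is a float, so it is interpreted as bytes; with 8-byte float64 samples each
# chunk holds 100e6 / 8 = 12_500_000 samples, and a 1e9-sample trace with
# overlap=0 yields ceil(1e9 / 12.5e6) = 80 chunks.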

class StreamingAnalyzer:
    """Accumulator for streaming analysis of large files.

    Processes traces chunk-by-chunk, accumulating statistics and measurements
    without loading the entire file into memory. Supports streaming PSD
    estimation using Welch's method and other rolling statistics.

    Example:
        >>> # Create streaming analyzer
        >>> analyzer = tk.StreamingAnalyzer()
        >>> # Process file in chunks
        >>> for chunk in tk.load_trace_chunks('large_trace.bin', chunk_size=50e6):
        ...     analyzer.accumulate_psd(chunk, nperseg=4096, window='hann')
        >>> # Get the aggregated result
        >>> freqs, psd = analyzer.get_psd()

    Advanced Example:
        >>> # Compute multiple statistics in streaming fashion
        >>> analyzer = tk.StreamingAnalyzer()
        >>> for chunk in tk.load_trace_chunks('huge_file.bin'):
        ...     analyzer.accumulate_statistics(chunk)
        ...     analyzer.accumulate_psd(chunk, nperseg=8192)
        >>> stats = analyzer.get_statistics()
        >>> freqs, psd = analyzer.get_psd()
        >>> print(f"Mean: {stats['mean']:.3f}, PSD shape: {psd.shape}")

    References:
        API-003: Streaming/Generator API for Large Files
        scipy.signal.welch for streaming PSD
    """

    def __init__(self) -> None:
        """Initialize streaming analyzer."""
        # Statistics accumulators
        self._n_samples = 0
        self._sum = 0.0
        self._sum_sq = 0.0
        self._min = float("inf")
        self._max = float("-inf")

        # PSD accumulators
        self._psd_sum: NDArray[np.float64] | None = None
        self._psd_freqs: NDArray[np.float64] | None = None
        self._psd_count = 0
        self._sample_rate: float | None = None

        # Histogram accumulators
        self._hist_counts: NDArray[np.int64] | None = None
        self._hist_edges: NDArray[np.float64] | None = None

    def accumulate_statistics(self, chunk: WaveformTrace) -> None:
        """Accumulate basic statistics from a chunk.

        Updates running sums for mean/std plus min and max. (Uses plain sum
        and sum-of-squares accumulators rather than Welford's update.)

        Args:
            chunk: WaveformTrace chunk to process.

        Example:
            >>> analyzer.accumulate_statistics(chunk)
        """
        chunk_data = chunk.data
        self._n_samples += len(chunk_data)
        self._sum += float(chunk_data.sum())
        self._sum_sq += float((chunk_data**2).sum())
        self._min = min(self._min, float(chunk_data.min()))
        self._max = max(self._max, float(chunk_data.max()))

    def accumulate_psd(
        self,
        chunk: WaveformTrace,
        nperseg: int = 4096,
        window: str = "hann",
        **welch_kwargs: Any,
    ) -> None:
        """Accumulate a PSD estimate from a chunk using Welch's method.

        Computes the PSD for the chunk and adds it to a running sum;
        get_psd() returns the average over all accumulated chunks.

        Args:
            chunk: WaveformTrace chunk to process.
            nperseg: Length of each segment for Welch's method.
            window: Window function name (default 'hann').
            **welch_kwargs: Additional arguments for scipy.signal.welch.

        Example:
            >>> analyzer.accumulate_psd(chunk, nperseg=4096, window='hann')

        References:
            scipy.signal.welch
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.welch.html
        """
        # Store sample rate from the first chunk
        if self._sample_rate is None:
            self._sample_rate = chunk.metadata.sample_rate

        # Compute PSD for this chunk using Welch's method
        freqs, psd = signal.welch(
            chunk.data,
            fs=chunk.metadata.sample_rate,
            nperseg=nperseg,
            window=window,
            **welch_kwargs,
        )

        # Initialize on the first chunk, accumulate afterwards. Chunks must
        # share the same sample rate and nperseg for the arrays to align.
        if self._psd_sum is None:
            self._psd_sum = psd
            self._psd_freqs = freqs
        else:
            self._psd_sum += psd

        self._psd_count += 1

    def accumulate_histogram(
        self,
        chunk: WaveformTrace,
        bins: int | NDArray[np.float64] = 100,
        range: tuple[float, float] | None = None,
    ) -> None:
        """Accumulate a histogram from a chunk.

        The bin edges are fixed by the first chunk (or by explicit edges or
        an explicit range); later chunks are binned against those same edges
        so that counts from different chunks stay comparable.

        Args:
            chunk: WaveformTrace chunk to process.
            bins: Number of bins or explicit bin edges.
            range: Range of histogram (min, max).

        Example:
            >>> analyzer.accumulate_histogram(chunk, bins=100)
        """
        if self._hist_counts is None or self._hist_edges is None:
            counts, edges = np.histogram(chunk.data, bins=bins, range=range)
            self._hist_counts = counts.astype(np.int64)
            self._hist_edges = edges
        else:
            # Reuse the stored edges; re-deriving edges per chunk would bin
            # each chunk on its own data range and make the sums meaningless
            counts, _ = np.histogram(chunk.data, bins=self._hist_edges)
            self._hist_counts += counts.astype(np.int64)

    def get_statistics(self) -> dict[str, float]:
        """Get accumulated statistics.

        Returns:
            Dictionary with mean, std, min, max, and sample count.

        Raises:
            ValueError: If no data has been accumulated yet.

        Example:
            >>> stats = analyzer.get_statistics()
            >>> print(f"Mean: {stats['mean']:.3f}, Std: {stats['std']:.3f}")
        """
        if self._n_samples == 0:
            raise ValueError("No data accumulated yet")

        mean = self._sum / self._n_samples
        # Var(x) = E[x^2] - E[x]^2. Note this sum-of-squares form can lose
        # precision when the mean is much larger than the spread.
        variance = (self._sum_sq / self._n_samples) - (mean**2)
        std = np.sqrt(max(0, variance))  # Avoid negative due to numerical errors

        return {
            "mean": mean,
            "std": std,
            "min": self._min,
            "max": self._max,
            "n_samples": self._n_samples,
        }

    def get_psd(self) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get the accumulated PSD estimate.

        Returns:
            Tuple of (frequencies, psd) where psd is averaged over all chunks.

        Raises:
            ValueError: If no PSD data has been accumulated.

        Example:
            >>> freqs, psd = analyzer.get_psd()
            >>> print(f"PSD shape: {psd.shape}")
        """
        if self._psd_sum is None or self._psd_freqs is None:
            raise ValueError("No PSD data accumulated yet")

        # Return the averaged PSD
        psd_avg = self._psd_sum / self._psd_count
        return self._psd_freqs, psd_avg

    def get_histogram(self) -> tuple[NDArray[np.int64], NDArray[np.float64]]:
        """Get the accumulated histogram.

        Returns:
            Tuple of (counts, edges).

        Raises:
            ValueError: If no histogram data has been accumulated.

        Example:
            >>> counts, edges = analyzer.get_histogram()
        """
        if self._hist_counts is None or self._hist_edges is None:
            raise ValueError("No histogram data accumulated yet")

        return self._hist_counts, self._hist_edges

    def reset(self) -> None:
        """Reset all accumulators.

        Example:
            >>> analyzer.reset()
        """
        self._n_samples = 0
        self._sum = 0.0
        self._sum_sq = 0.0
        self._min = float("inf")
        self._max = float("-inf")
        self._psd_sum = None
        self._psd_freqs = None
        self._psd_count = 0
        self._sample_rate = None
        self._hist_counts = None
        self._hist_edges = None


def chunked_spectrogram(
    data: NDArray[np.float64],
    sample_rate: float,
    *,
    chunk_size: int = 10_000_000,
    overlap: int = 0,
    nperseg: int = 256,
    noverlap: int | None = None,
    window: str = "hann",
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
    """Compute a spectrogram for large signals using chunked processing.

    Processes large signals in overlapping chunks to compute spectrograms
    without loading the entire signal into memory. Stitches STFT results from
    chunks with proper boundary handling.

    Args:
        data: Input signal array (can be memory-mapped).
        sample_rate: Sample rate in Hz.
        chunk_size: Maximum samples per chunk (default 10M).
        overlap: Overlap samples between chunks for continuity. If 0 (the
            default), it is set to 2 * nperseg for proper STFT boundary
            handling.
        nperseg: Segment length for STFT (default 256).
        noverlap: Overlap between STFT segments within a chunk (default nperseg // 2).
        window: Window function name (default "hann").

    Returns:
        (times, frequencies, Sxx_db) - Time axis, frequency axis, and
        spectrogram magnitude in dB as a 2D array (frequencies x time).

    Raises:
        ValueError: If no valid chunks were produced.

    Example:
        >>> # Memory-efficient spectrogram on a 1 GB signal
        >>> import numpy as np
        >>> data = np.memmap('huge_trace.dat', dtype='float64', mode='r')
        >>> t, f, Sxx = chunked_spectrogram(data, sample_rate=1e9, chunk_size=10_000_000)
        >>> print(f"Spectrogram shape: {Sxx.shape}")

    References:
        MEM-004: Chunked Spectrogram requirement
        scipy.signal.spectrogram
    """

    n = len(data)

    # Handle empty input
    if n == 0:
        return np.array([]), np.array([]), np.array([]).reshape(0, 0)

    if noverlap is None:
        noverlap = nperseg // 2

    # Auto-adjust overlap if not specified, to ensure continuity at chunk
    # boundaries (one full window of context on each side)
    if overlap == 0:
        overlap = 2 * nperseg

    # If the data fits in one chunk, use scipy directly
    if n <= chunk_size:
        freq, times, Sxx = signal.spectrogram(
            data,
            fs=sample_rate,
            window=window,
            nperseg=nperseg,
            noverlap=noverlap,
            scaling="spectrum",
        )
        # Convert to dB, flooring at 1e-20 to avoid log10(0)
        Sxx = np.maximum(Sxx, 1e-20)
        Sxx_db = 10 * np.log10(Sxx)
        return times, freq, Sxx_db
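
    # --- Chunked path --------------------------------------------------------
    # Each chunk is extended by `overlap` samples on both sides so STFT
    # windows near the boundaries see real data; after the transform, only
    # time bins inside the chunk's own [chunk_start, chunk_end) region are
    # kept, so the extensions are never double-counted.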

    # Process chunks
    chunks_stft = []
    chunks_times = []
    chunk_start = 0

    while chunk_start < n:
        # Determine chunk boundaries
        chunk_end = min(chunk_start + chunk_size, n)

        # Extract the chunk with overlap extension on both sides
        extended_start = max(0, chunk_start - overlap)
        extended_end = min(n, chunk_end + overlap)

        chunk_data = data[extended_start:extended_end]

        # Compute spectrogram for the chunk
        freq, times_chunk, Sxx_chunk = signal.spectrogram(
            chunk_data,
            fs=sample_rate,
            window=window,
            nperseg=nperseg,
            noverlap=noverlap,
            scaling="spectrum",
        )

        # Adjust the time axis for the chunk's position
        time_offset = extended_start / sample_rate
        times_chunk_adjusted = times_chunk + time_offset

        # Trim overlap regions to avoid duplicating time bins
        valid_time_start = chunk_start / sample_rate
        valid_time_end = chunk_end / sample_rate

        valid_mask = (times_chunk_adjusted >= valid_time_start) & (
            times_chunk_adjusted < valid_time_end
        )

        # Keep only the in-region bins; skip the chunk entirely if none remain
        if np.any(valid_mask):
            chunks_stft.append(Sxx_chunk[:, valid_mask])
            chunks_times.append(times_chunk_adjusted[valid_mask])

        # Move to next chunk
        chunk_start = chunk_end

    # Concatenate all chunks
    if len(chunks_stft) == 0:
        raise ValueError("No valid chunks produced")
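
    # Note: every chunk shares the same frequency axis because nperseg and fs
    # are constant across chunks, so the stitch below is a plain concatenation
    # along the time axis.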

    Sxx = np.concatenate(chunks_stft, axis=1)
    times = np.concatenate(chunks_times)

    # Convert to dB
    Sxx = np.maximum(Sxx, 1e-20)
    Sxx_db = 10 * np.log10(Sxx)

    return times, freq, Sxx_db


def chunked_fft(
    data: NDArray[np.float64],
    sample_rate: float,
    *,
    chunk_size: int = 10_000_000,
    overlap: float = 50.0,
    window: str = "hann",
    nfft: int | None = None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute an FFT for very long signals using segmented averaging.

    Divides the signal into overlapping segments, computes an FFT for each,
    and averages the magnitude spectra. This is memory-bounded by chunk_size
    and provides variance reduction through averaging (similar to Welch's
    method).

    Args:
        data: Input signal array (can be memory-mapped).
        sample_rate: Sample rate in Hz.
        chunk_size: Size of each segment in samples (default 10M).
        overlap: Percentage overlap between segments, 0-100 (default 50).
        window: Window function name (default "hann").
        nfft: FFT length. If None, uses the next power of 2 at or above the
            segment length (chunk_size, or the full signal length when it
            fits in a single chunk).

    Returns:
        (frequencies, magnitude_db) - Frequency axis and averaged magnitude in dB.

    Example:
        >>> # Memory-efficient FFT on a 1 GB signal with 50% overlap
        >>> import numpy as np
        >>> data = np.memmap('huge_trace.dat', dtype='float64', mode='r')
        >>> freq, mag = chunked_fft(data, sample_rate=1e9, chunk_size=1_000_000)
        >>> print(f"Frequency resolution: {freq[1] - freq[0]:.3f} Hz")

    References:
        MEM-006: Chunked FFT requirement
        Welch's method for spectral estimation
    """

    from ..utils.windowing import get_window

    n = len(data)

    # Handle empty input
    if n == 0:
        return np.array([]), np.array([])

    # If the data fits in one chunk, compute a single FFT
    if n <= chunk_size:
        if nfft is None:
            nfft = int(2 ** np.ceil(np.log2(n)))

        # Apply window
        w = get_window(window, n)
        data_windowed = data * w

        # Compute FFT
        spectrum = np.fft.rfft(data_windowed, n=nfft)

        # Frequency axis
        freq = np.fft.rfftfreq(nfft, d=1.0 / sample_rate)

        # Magnitude in dB (normalized by window gain)
        window_gain = np.sum(w) / n
        magnitude = np.abs(spectrum) / (n * window_gain)
        magnitude = np.maximum(magnitude, 1e-20)
        magnitude_db = 20 * np.log10(magnitude)

        return freq, magnitude_db

    # Calculate overlap in samples and the hop between segment starts
    overlap_samples = int(chunk_size * overlap / 100.0)
    hop = chunk_size - overlap_samples

    # Determine the number of segments
    num_segments = max(1, (n - overlap_samples) // hop)

    if nfft is None:
        nfft = int(2 ** np.ceil(np.log2(chunk_size)))

    # Prepare the window and its coherent gain
    w = get_window(window, chunk_size)
    window_gain = np.sum(w) / chunk_size

    # Accumulate magnitude spectra
    freq = np.fft.rfftfreq(nfft, d=1.0 / sample_rate)
    magnitude_sum = np.zeros(len(freq))
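
    # Welch-style averaging: the loop below sums per-segment magnitude
    # spectra; dividing by the segment count at the end reduces estimator
    # variance roughly in proportion to the number of independent segments.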

    for i in range(num_segments):
        start = i * hop
        end = min(start + chunk_size, n)

        # Extract segment
        if end - start < chunk_size:
            # Last segment: pad with zeros to the full segment length
            segment = np.zeros(chunk_size)
            segment[: end - start] = data[start:end]
        else:
            segment = data[start:end]

        # Detrend (remove mean)
        segment = segment - np.mean(segment)

        # Window
        segment_windowed = segment * w

        # FFT
        spectrum = np.fft.rfft(segment_windowed, n=nfft)

        # Accumulate magnitude
        magnitude = np.abs(spectrum) / (chunk_size * window_gain)
        magnitude_sum += magnitude

    # Average
    magnitude_avg = magnitude_sum / num_segments

    # Convert to dB
    magnitude_avg = np.maximum(magnitude_avg, 1e-20)
    magnitude_db = 20 * np.log10(magnitude_avg)

    return freq, magnitude_db


__all__ = [
    "StreamingAnalyzer",
    "chunked_fft",
    "chunked_spectrogram",
    "load_trace_chunks",
]
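

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative assumption, not part of the
    # public API): exercise chunked_fft and chunked_spectrogram on a small
    # synthetic tone. Run via `python -m tracekit.streaming.chunked` so the
    # relative imports resolve. Sizes are kept small for speed.
    rng = np.random.default_rng(0)
    fs = 1e6  # 1 MHz sample rate (arbitrary for the demo)
    t = np.arange(200_000) / fs
    tone = np.sin(2 * np.pi * 50e3 * t) + 0.01 * rng.standard_normal(t.size)

    # Segmented-average FFT: the peak bin should land at the 50 kHz tone
    freq, mag_db = chunked_fft(tone, sample_rate=fs, chunk_size=32_768)
    print(f"Peak at {freq[np.argmax(mag_db)]:.0f} Hz (expected ~50000 Hz)")

    # Chunked spectrogram over four ~65k-sample chunks
    times, freqs, sxx_db = chunked_spectrogram(
        tone, sample_rate=fs, chunk_size=65_536, nperseg=1024
    )
    print(f"Spectrogram shape: {sxx_db.shape}")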