1"""Chunked FFT computation for very long signals. 

2 

3This module implements FFT computation for signals larger than memory 

4using overlapping segments with result aggregation. 

5 

6 

7Example: 

8 >>> from tracekit.analyzers.spectral.chunked_fft import fft_chunked 

9 >>> freqs, spectrum = fft_chunked('huge_signal.bin', segment_size=1e6, overlap_pct=50) 

10 >>> print(f"FFT shape: {spectrum.shape}") 

11 

12References: 

13 scipy.fft for FFT computation 

14 Welch's method for spectral averaging 

15""" 

16 

17from __future__ import annotations 

18 

19from pathlib import Path 

20from typing import TYPE_CHECKING, Any 

21 

22import numpy as np 

23from scipy import fft, signal 

24 

25from tracekit.core.memoize import memoize_analysis 

26 

27if TYPE_CHECKING: 

28 from collections.abc import Callable, Iterator 

29 

30 from numpy.typing import NDArray 

31 

32 

@memoize_analysis(maxsize=16)
def fft_chunked(
    file_path: str | Path,
    segment_size: int | float,
    overlap_pct: float = 50.0,
    *,
    window: str | NDArray[np.float64] = "hann",
    nfft: int | None = None,
    detrend: str | bool = False,
    scaling: str = "density",
    average_method: str = "mean",
    sample_rate: float = 1.0,
    dtype: str = "float32",
    preserve_phase: bool = False,
) -> tuple[NDArray[np.float64], NDArray[np.float64] | NDArray[np.complex128]]:
    """Compute FFT for very long signals using overlapping segments.

    Processes the signal in overlapping segments, computes an FFT per
    segment, and aggregates the results with the chosen method. Windowing
    mitigates edge effects at segment boundaries.

    Args:
        file_path: Path to signal file (binary format).
        segment_size: Segment size in samples.
        overlap_pct: Overlap percentage between segments (0-100).
        window: Window function name or array.
        nfft: FFT length (default: segment_size; input is zero-padded if
            nfft is larger).
        detrend: Detrend type ('constant', 'linear', or False).
        scaling: Scaling mode ('density' or 'spectrum').
        average_method: Aggregation method ('mean', 'median', or 'max').
        sample_rate: Sample rate in Hz (for the frequency axis).
        dtype: Data type of input file ('float32' or 'float64').
        preserve_phase: If True, preserve phase information (complex output).

    Returns:
        Tuple of (frequencies, spectrum) where:
            - frequencies: Frequency bins in Hz.
            - spectrum: Averaged FFT magnitude (or complex if
              preserve_phase=True).

    Raises:
        ValueError: If overlap_pct is not in [0, 100) or no segments could
            be read from the file.

    Example:
        >>> # Process a 10 GB file with 1M-sample segments, 50% overlap
        >>> freqs, spectrum = fft_chunked(
        ...     'huge_signal.bin',
        ...     segment_size=1e6,
        ...     overlap_pct=50,
        ...     window='hann',
        ...     sample_rate=1e9,
        ...     dtype='float32'
        ... )
        >>> print(f"Spectrum shape: {spectrum.shape}")

    References:
        MEM-006: Chunked FFT for Very Long Signals
    """
    if not 0 <= overlap_pct < 100:
        raise ValueError(
            f"overlap_pct must be in [0, 100), got {overlap_pct}. "
            "Note: 100% overlap would create an infinite loop."
        )

    segment_size = int(segment_size)
    if nfft is None:
        nfft = segment_size

    # Calculate overlap in samples
    noverlap = int(segment_size * overlap_pct / 100)

    # Determine dtype
    np_dtype = np.float32 if dtype == "float32" else np.float64
    bytes_per_sample = 4 if dtype == "float32" else 8

    # Derive the total sample count from the file size on disk
    file_path = Path(file_path)
    file_size_bytes = file_path.stat().st_size
    total_samples = file_size_bytes // bytes_per_sample

    # Generate window
    if isinstance(window, str):
        window_arr = signal.get_window(window, segment_size)
    else:
        window_arr = np.asarray(window)

    # Initialize accumulators
    fft_accum: list[NDArray[np.float64] | NDArray[np.complex128]] = []

    # Process segments
    for segment in _generate_segments(file_path, total_samples, segment_size, noverlap, np_dtype):
        # Apply detrending
        if detrend:
            segment = signal.detrend(segment, type=detrend)

        # Apply window
        windowed = segment * window_arr[: len(segment)]

        # Zero-pad if needed
        if len(windowed) < nfft:
            windowed = np.pad(windowed, (0, nfft - len(windowed)), mode="constant")

        # Compute FFT
        fft_result = fft.rfft(windowed, n=nfft)

        # Store result (magnitude or complex)
        if preserve_phase:
            fft_accum.append(fft_result)
        else:
            fft_accum.append(np.abs(fft_result))

    # Aggregate results
    if len(fft_accum) == 0:
        raise ValueError(f"No segments processed from {file_path}")

    if average_method == "mean":
        spectrum = np.mean(fft_accum, axis=0)
    elif average_method == "median":
        spectrum = np.median(fft_accum, axis=0)
    elif average_method == "max":
        spectrum = np.max(fft_accum, axis=0)
    else:
        raise ValueError(
            f"Unknown average_method: {average_method}. Use 'mean', 'median', or 'max'."
        )

    # Apply scaling
    if scaling == "density" and not preserve_phase:
        # Convert to PSD-like scaling
        spectrum = spectrum**2 / (sample_rate * np.sum(window_arr**2))
    elif scaling == "spectrum" and not preserve_phase:
        # RMS scaling
        spectrum = spectrum / len(window_arr)

    # Frequency axis
    frequencies = fft.rfftfreq(nfft, d=1 / sample_rate)

    return frequencies, spectrum

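# A minimal sanity-check sketch: with hop = segment_size - noverlap, the
# generator below starts a segment at every multiple of hop before the end
# of the file. This helper is hypothetical and illustrative only; it is not
# part of the public API.
def _expected_segment_count(total_samples: int, segment_size: int, noverlap: int) -> int:
    """Illustrative only: number of segments _generate_segments will yield."""
    hop = segment_size - noverlap
    # _generate_segments advances offset by hop while offset < total_samples,
    # so it yields ceil(total_samples / hop) segments (the last may be partial).
    return -(-total_samples // hop)
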

def _generate_segments(
    file_path: Path,
    total_samples: int,
    segment_size: int,
    noverlap: int,
    dtype: type,
) -> Iterator[NDArray[np.float64]]:
    """Generate overlapping segments from file.

    Args:
        file_path: Path to binary file.
        total_samples: Total number of samples in file.
        segment_size: Samples per segment.
        noverlap: Overlap samples between segments.
        dtype: NumPy dtype for data.

    Yields:
        Segment arrays.
    """
    hop = segment_size - noverlap
    offset = 0

    with open(file_path, "rb") as f:
        while offset < total_samples:
            # Read segment
            f.seek(offset * dtype().itemsize)
            segment_data: NDArray[np.float64] = np.fromfile(f, dtype=dtype, count=segment_size)

            if len(segment_data) == 0:
                break

            yield segment_data

            offset += hop

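# A hedged usage sketch for the generator above: write ten float32 samples
# to a temporary file and read them back as overlapping segments. The
# parameter values are illustrative only.
def _demo_generate_segments() -> None:
    """Illustrative only: overlapping reads from a tiny binary file."""
    import tempfile

    data = np.arange(10, dtype=np.float32)
    with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp:
        data.tofile(tmp)
        path = Path(tmp.name)

    # segment_size=4, noverlap=2 -> hop=2: segments start at offsets
    # 0, 2, 4, 6, 8; the final read returns a partial (2-sample) segment.
    for seg in _generate_segments(path, 10, 4, 2, np.float32):
        print(seg)
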

def welch_psd_chunked(
    file_path: str | Path,
    segment_size: int | float = 256,
    overlap_pct: float = 50.0,
    *,
    window: str | NDArray[np.float64] = "hann",
    nfft: int | None = None,
    detrend: str | bool = "constant",
    scaling: str = "density",
    sample_rate: float = 1.0,
    dtype: str = "float32",
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute a Welch-style PSD estimate for very long signals.

    Similar to fft_chunked, but configured for power spectral density
    estimation in the spirit of Welch's method. Note that segment
    magnitudes are averaged before squaring, so the result can differ
    slightly from a textbook Welch estimate, which averages the squared
    periodograms.

    Args:
        file_path: Path to signal file.
        segment_size: Segment size for Welch's method.
        overlap_pct: Overlap percentage (typically 50%).
        window: Window function.
        nfft: FFT length.
        detrend: Detrend type.
        scaling: Scaling mode ('density' or 'spectrum').
        sample_rate: Sample rate in Hz.
        dtype: Data type of input file.

    Returns:
        Tuple of (frequencies, psd).

    Example:
        >>> freqs, psd = welch_psd_chunked('signal.bin', segment_size=1024, sample_rate=1e6)
        >>> print(f"PSD shape: {psd.shape}")

    References:
        MEM-005: Chunked Welch PSD
        Welch, P.D. (1967). "The use of fast Fourier transform for the
        estimation of power spectra"
    """
    freqs, spectrum = fft_chunked(
        file_path,
        segment_size=segment_size,
        overlap_pct=overlap_pct,
        window=window,
        nfft=nfft,
        detrend=detrend,
        scaling=scaling,
        average_method="mean",
        sample_rate=sample_rate,
        dtype=dtype,
        preserve_phase=False,
    )
    # preserve_phase=False guarantees float64 output, not complex128
    return freqs, spectrum  # type: ignore[return-value]

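# A hedged cross-check sketch: for a signal that fits in memory, the chunked
# estimate can be compared against scipy.signal.welch with matching
# parameters. The temporary file and sizes are illustrative; the frequency
# axes match exactly, while the PSD values agree only approximately because
# the aggregation order differs (see the docstring above).
def _demo_welch_cross_check() -> None:
    """Illustrative only: compare welch_psd_chunked with scipy.signal.welch."""
    import tempfile

    rng = np.random.default_rng(0)
    x = rng.standard_normal(4096).astype(np.float32)
    with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp:
        x.tofile(tmp)
        path = Path(tmp.name)

    freqs_chunked, psd_chunked = welch_psd_chunked(path, segment_size=256)
    freqs_ref, psd_ref = signal.welch(x, fs=1.0, nperseg=256)
    assert np.allclose(freqs_chunked, freqs_ref)
    print(f"mean PSD (chunked) = {psd_chunked.mean():.4g}, (scipy) = {psd_ref.mean():.4g}")
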

def fft_chunked_parallel(
    file_path: str | Path,
    segment_size: int | float,
    overlap_pct: float = 50.0,
    *,
    n_workers: int = 4,
    **kwargs: Any,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute chunked FFT with parallel processing.

    Similar to fft_chunked but uses multiple workers for parallel
    segment processing. Useful for very large files on multi-core systems.

    Args:
        file_path: Path to signal file.
        segment_size: Segment size in samples.
        overlap_pct: Overlap percentage.
        n_workers: Number of parallel workers.
        **kwargs: Additional arguments passed to fft_chunked.

    Returns:
        Tuple of (frequencies, spectrum).

    Note:
        FUTURE ENHANCEMENT: Parallel processing with multiprocessing/joblib.
        Currently uses serial processing (the n_workers parameter is reserved
        for future implementation). The serial fallback provides correct
        results; parallelization is an optimization opportunity.

    Example:
        >>> freqs, spectrum = fft_chunked_parallel(
        ...     'signal.bin',
        ...     segment_size=1e6,
        ...     overlap_pct=50,
        ...     n_workers=8
        ... )
    """
    # Future: Implement parallel processing with multiprocessing or joblib
    # For now, fall back to serial processing
    freqs, spectrum = fft_chunked(file_path, segment_size, overlap_pct, **kwargs)
    # kwargs may contain preserve_phase, handle both float64 and complex128
    return freqs, spectrum  # type: ignore[return-value]

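# A hedged design sketch for the future parallel path: per-segment FFTs are
# independent, so they map naturally onto joblib's Parallel/delayed. This is
# an illustrative sketch assuming joblib is available, not the implemented
# behavior of fft_chunked_parallel above.
def _demo_parallel_sketch(
    segments: list[NDArray[np.float64]],
    nfft: int,
    n_workers: int = 4,
) -> list[NDArray[np.complex128]]:
    """Illustrative only: fan per-segment rfft calls out across workers."""
    from joblib import Parallel, delayed

    # Each task is a pure function of one segment, so results can be
    # aggregated afterwards exactly as in the serial path.
    return Parallel(n_jobs=n_workers)(delayed(fft.rfft)(seg, n=nfft) for seg in segments)
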

def streaming_fft(
    file_path: str | Path,
    segment_size: int | float,
    overlap_pct: float = 50.0,
    *,
    window: str | NDArray[np.float64] = "hann",
    nfft: int | None = None,
    detrend: str | bool = False,
    sample_rate: float = 1.0,
    dtype: str = "float32",
    progress_callback: Callable[[int, int], None] | None = None,
) -> Iterator[tuple[NDArray[np.float64], NDArray[np.float64]]]:
    """Stream FFT computation, yielding per-segment spectra as computed.

    Implements a streaming/generator API for memory-efficient FFT computation
    on very large files. Yields each segment's spectrum as it is computed,
    allowing downstream processing before all segments are complete.

    Args:
        file_path: Path to signal file (binary format).
        segment_size: Segment size in samples.
        overlap_pct: Overlap percentage between segments (0-100).
        window: Window function name or array.
        nfft: FFT length (default: segment_size).
        detrend: Detrend type ('constant', 'linear', or False).
        sample_rate: Sample rate in Hz (for the frequency axis).
        dtype: Data type of input file ('float32' or 'float64').
        progress_callback: Optional callback(current, total) to report progress.

    Yields:
        Tuple of (frequencies, fft_magnitude) for each segment.

    Raises:
        ValueError: If overlap_pct is not in [0, 100).

    Example:
        >>> # Stream FFT results as computed
        >>> def on_progress(current, total):
        ...     print(f"Progress: {current}/{total} ({current/total*100:.1f}%)")
        >>>
        >>> for frequencies, magnitude in streaming_fft(
        ...     'huge_signal.bin',
        ...     segment_size=1e6,
        ...     overlap_pct=50,
        ...     progress_callback=on_progress
        ... ):
        ...     # Process each segment immediately
        ...     peak_freq = frequencies[magnitude.argmax()]
        ...     print(f"Peak frequency: {peak_freq:.2e} Hz")

    References:
        API-003: Streaming/Generator API for Large Files
    """
    if not 0 <= overlap_pct < 100:
        raise ValueError(
            f"overlap_pct must be in [0, 100), got {overlap_pct}. "
            "Note: 100% overlap would create an infinite loop."
        )

    segment_size = int(segment_size)
    if nfft is None:
        nfft = segment_size

    # Calculate overlap in samples
    noverlap = int(segment_size * overlap_pct / 100)

    # Determine dtype
    np_dtype = np.float32 if dtype == "float32" else np.float64
    bytes_per_sample = 4 if dtype == "float32" else 8

    # Derive the total sample count from the file size on disk
    file_path = Path(file_path)
    file_size_bytes = file_path.stat().st_size
    total_samples = file_size_bytes // bytes_per_sample

    # Total segments for progress reporting: ceil(total_samples / hop),
    # matching _generate_segments (the trailing partial segment counts too)
    hop = segment_size - noverlap
    total_segments = max(1, -(-total_samples // hop))

    # Generate window
    if isinstance(window, str):
        window_arr = signal.get_window(window, segment_size)
    else:
        window_arr = np.asarray(window)

    # Frequency axis (computed once)
    frequencies = fft.rfftfreq(nfft, d=1 / sample_rate)

    # Process and yield segments
    segment_count = 0
    for segment in _generate_segments(file_path, total_samples, segment_size, noverlap, np_dtype):
        # Apply detrending
        if detrend:
            segment = signal.detrend(segment, type=detrend)

        # Apply window
        windowed = segment * window_arr[: len(segment)]

        # Zero-pad if needed
        if len(windowed) < nfft:
            windowed = np.pad(windowed, (0, nfft - len(windowed)), mode="constant")

        # Compute FFT
        fft_result = fft.rfft(windowed, n=nfft)
        magnitude = np.abs(fft_result)

        # Yield result immediately
        yield frequencies, magnitude

        # Update progress
        segment_count += 1  # noqa: SIM113
        if progress_callback is not None:
            progress_callback(segment_count, total_segments)

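# A hedged consumption sketch: streaming_fft lets a caller fold segment
# spectra into a running average instead of holding them all in memory.
# The temporary file, sizes, and seed are illustrative only.
def _demo_streaming_running_average() -> None:
    """Illustrative only: fold streamed FFT magnitudes into a running mean."""
    import tempfile

    rng = np.random.default_rng(2)
    x = rng.standard_normal(16384).astype(np.float32)
    with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp:
        x.tofile(tmp)
        path = Path(tmp.name)

    running = None
    count = 0
    for _freqs, magnitude in streaming_fft(path, segment_size=4096, overlap_pct=50):
        count += 1
        if running is None:
            running = magnitude.copy()
        else:
            # Incremental mean: m_k = m_{k-1} + (x_k - m_{k-1}) / k
            running += (magnitude - running) / count
    print(f"Averaged {count} segments")
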

class StreamingAnalyzer:
    """Accumulator for streaming analysis across chunks.

    Enables processing of huge files chunk-by-chunk with accumulation
    of statistics, PSD estimates, and other aggregated measurements.

    Attributes:
        chunk_count: Number of chunks processed. The per-chunk PSD and
            statistics accumulators are private; retrieve aggregates via
            get_psd() and get_stats().

    Example:
        >>> analyzer = StreamingAnalyzer()
        >>> for chunk in load_trace_chunks('large.bin', chunk_size=50e6):
        ...     analyzer.accumulate_psd(chunk, nperseg=4096)
        ...     analyzer.accumulate_stats(chunk)
        >>> psd = analyzer.get_psd()
        >>> stats = analyzer.get_stats()

    References:
        API-003: Streaming/Generator API for Large Files
    """

    def __init__(self) -> None:
        """Initialize streaming analyzer."""
        self.chunk_count: int = 0
        self._psd_accumulator: list[NDArray[Any]] = []
        self._psd_frequencies: NDArray[Any] | None = None
        self._psd_config: dict[str, Any] = {}
        self._stats_accumulator: dict[str, list[float]] = {
            "mean": [],
            "std": [],
            "min": [],
            "max": [],
        }

    def accumulate_psd(
        self,
        chunk: NDArray[Any],
        nperseg: int = 256,
        window: str = "hann",
        sample_rate: float = 1.0,
    ) -> None:
        """Accumulate PSD estimate from chunk using Welch's method.

        Args:
            chunk: Data chunk to process.
            nperseg: Length of each segment for Welch's method.
            window: Window function name.
            sample_rate: Sample rate in Hz.

        Example:
            >>> analyzer.accumulate_psd(chunk, nperseg=4096, window='hann')
        """
        # Compute Welch PSD for this chunk
        freqs, psd = signal.welch(chunk, fs=sample_rate, nperseg=nperseg, window=window)

        # Store frequencies on first call
        if self._psd_frequencies is None:
            self._psd_frequencies = freqs
            self._psd_config = {
                "nperseg": nperseg,
                "window": window,
                "sample_rate": sample_rate,
            }

        # Accumulate PSD
        self._psd_accumulator.append(psd)
        self.chunk_count += 1

    def accumulate_stats(self, chunk: NDArray[np.float64]) -> None:
        """Accumulate basic statistics from chunk.

        Args:
            chunk: Data chunk to process.

        Example:
            >>> analyzer.accumulate_stats(chunk)
        """
        self._stats_accumulator["mean"].append(float(np.mean(chunk)))
        self._stats_accumulator["std"].append(float(np.std(chunk)))
        self._stats_accumulator["min"].append(float(np.min(chunk)))
        self._stats_accumulator["max"].append(float(np.max(chunk)))

    def get_psd(self) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get aggregated PSD estimate.

        Returns:
            Tuple of (frequencies, psd) with the PSD averaged across chunks.

        Raises:
            ValueError: If no PSD data has been accumulated.

        Example:
            >>> freqs, psd = analyzer.get_psd()
        """
        if not self._psd_accumulator:
            raise ValueError("No PSD data accumulated. Call accumulate_psd() first.")

        if self._psd_frequencies is None:
            raise ValueError("PSD frequencies not initialized. Call accumulate_psd() first.")

        # Average PSDs across all chunks
        avg_psd = np.mean(self._psd_accumulator, axis=0)
        return self._psd_frequencies, avg_psd

    def get_stats(self) -> dict[str, float]:
        """Get aggregated statistics.

        Returns:
            Dictionary with overall mean, std, min, and max. The returned
            std is the mean of the per-chunk standard deviations, not the
            exact global standard deviation.

        Example:
            >>> stats = analyzer.get_stats()
            >>> print(f"Overall mean: {stats['mean']:.3f}")
        """
        if not self._stats_accumulator["mean"]:
            return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0}

        return {
            "mean": float(np.mean(self._stats_accumulator["mean"])),
            "std": float(np.mean(self._stats_accumulator["std"])),
            "min": float(np.min(self._stats_accumulator["min"])),
            "max": float(np.max(self._stats_accumulator["max"])),
        }

    def reset(self) -> None:
        """Reset all accumulated data.

        Example:
            >>> analyzer.reset()
        """
        self.chunk_count = 0
        self._psd_accumulator.clear()
        self._psd_frequencies = None
        self._psd_config.clear()
        for key in self._stats_accumulator:
            self._stats_accumulator[key].clear()

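# A hedged end-to-end sketch for StreamingAnalyzer using synthetic chunks;
# in real use the chunks would come from a file loader such as the
# load_trace_chunks helper referenced in the class docstring (not defined
# in this module). Sizes and the seed are illustrative only.
def _demo_streaming_analyzer() -> None:
    """Illustrative only: accumulate PSD and stats over synthetic chunks."""
    rng = np.random.default_rng(1)
    analyzer = StreamingAnalyzer()
    for _ in range(4):
        chunk = rng.standard_normal(2048)
        analyzer.accumulate_psd(chunk, nperseg=256, sample_rate=1e3)
        analyzer.accumulate_stats(chunk)
    freqs, psd = analyzer.get_psd()
    stats = analyzer.get_stats()
    # psd holds the chunk-averaged Welch estimate on the shared freqs axis
    assert psd.shape == freqs.shape
    print(f"{len(freqs)} frequency bins, overall mean = {stats['mean']:.3f}")
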

__all__ = [
    "StreamingAnalyzer",
    "fft_chunked",
    "fft_chunked_parallel",
    "streaming_fft",
    "welch_psd_chunked",
]