Coverage for src/tracekit/streaming/chunked.py: 96% (178 statements)
1"""Streaming APIs for chunk-by-chunk processing of large files.
3This module implements memory-efficient streaming analysis for huge waveform
4files that don't fit in memory. Uses generator-based chunk loading and
5accumulator pattern for rolling statistics.
6"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

import numpy as np
from scipy import signal

from ..core.types import WaveformTrace

if TYPE_CHECKING:
    from collections.abc import Callable, Generator

    from numpy.typing import NDArray


def load_trace_chunks(
    file_path: str | Path,
    chunk_size: int | float = 100e6,
    overlap: int = 0,
    loader: Callable[[str | Path], WaveformTrace] | None = None,
    progress_callback: Callable[[int, int], None] | None = None,
) -> Generator[WaveformTrace, None, None]:
    """Load large trace files chunk-by-chunk without loading them into memory.

    Yields chunks of the trace for memory-efficient processing. Supports
    overlap between chunks for windowed operations that need continuity.

    Args:
        file_path: Path to trace file.
        chunk_size: Size of each chunk. Values below 1e6 are interpreted as a
            sample count; values of 1e6 and above are interpreted as bytes
            (assuming 8-byte samples). Default 100e6 (100 MB).
        overlap: Number of samples to overlap between chunks. Useful for
            windowed operations like FFT. Must be smaller than the chunk size
            in samples. Default 0.
        loader: Optional custom loader function. If None, uses the default
            loader.
        progress_callback: Optional callback(chunk_num, total_chunks) for
            progress reporting.

    Yields:
        WaveformTrace chunks.

    Raises:
        ValueError: If trace metadata cannot be loaded, or if overlap is not
            smaller than the chunk size in samples.

    Example:
        >>> # Stream a 10 GB file in 100 MB chunks
        >>> for chunk in tk.load_trace_chunks('huge_trace.bin', chunk_size=100e6):
        ...     mean = chunk.data.mean()
        ...     std = chunk.data.std()
        ...     print(f"Chunk stats: mean={mean:.3f}, std={std:.3f}")

    Advanced Example:
        >>> # Process with overlap for FFT continuity
        >>> for chunk in tk.load_trace_chunks(
        ...     'large_trace.bin',
        ...     chunk_size=50e6,
        ...     overlap=8192  # Overlap for continuity
        ... ):
        ...     fft_result = tk.fft(chunk, nfft=8192)
        ...     # Process FFT result...

    References:
        API-003: Streaming/Generator API for Large Files
    """
    file_path = Path(file_path)

    # Import loader here to avoid a circular dependency
    from ..loaders import load

    # Use the provided loader or the default
    load_func = loader if loader is not None else load

    # Load full trace metadata to get the total size.
    # For memory-mapped files, this doesn't load the data.
    try:
        full_trace = load_func(file_path)
    except Exception as e:
        raise ValueError(f"Failed to load trace metadata: {e}") from e

    total_samples = len(full_trace.data)  # type: ignore[union-attr]
    # Heuristic: values below 1e6 are sample counts; larger values are bytes
    # (assuming 8-byte samples).
    chunk_samples = int(chunk_size) if chunk_size < 1e6 else int(chunk_size / 8)

    if overlap >= chunk_samples:
        raise ValueError("overlap must be smaller than the chunk size in samples")

    # Calculate the number of chunks
    num_chunks = (total_samples - overlap) // (chunk_samples - overlap)
    if (total_samples - overlap) % (chunk_samples - overlap) != 0:
        num_chunks += 1

    # Yield chunks
    chunk_num = 0
    start_idx = 0

    while start_idx < total_samples:
        end_idx = min(start_idx + chunk_samples, total_samples)

        # Extract chunk
        chunk_data = full_trace.data[start_idx:end_idx]  # type: ignore[union-attr]

        # Create a chunk trace with the same metadata.
        # Cast needed for mypy: slicing a floating array returns a floating array.
        chunk_trace = WaveformTrace(
            data=cast("NDArray[np.floating[Any]]", chunk_data),
            metadata=full_trace.metadata,
        )

        # Call the progress callback if provided
        if progress_callback is not None:
            progress_callback(chunk_num, num_chunks)

        yield chunk_trace

        # Move to the next chunk, accounting for overlap
        start_idx = end_idx - overlap
        chunk_num += 1

        # Stop once we've reached the end
        if end_idx >= total_samples:
            break
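
# A minimal sketch of the overlap arithmetic above, using made-up values
# (chunk_samples=1000, overlap=100, not tied to any real file): each new
# chunk starts at end_idx - overlap, so successive windows share 100 samples.
#
#     >>> chunk_samples, overlap, total = 1000, 100, 2500
#     >>> starts = []
#     >>> s = 0
#     >>> while s < total:
#     ...     e = min(s + chunk_samples, total)
#     ...     starts.append((s, e))
#     ...     if e >= total:
#     ...         break
#     ...     s = e - overlap
#     >>> starts
#     [(0, 1000), (900, 1900), (1800, 2500)]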


class StreamingAnalyzer:
    """Accumulator for streaming analysis of large files.

    Processes traces chunk-by-chunk, accumulating statistics and measurements
    without loading the entire file into memory. Supports streaming PSD
    estimation using Welch's method and other rolling statistics.

    Example:
        >>> # Create a streaming analyzer
        >>> analyzer = tk.StreamingAnalyzer()
        >>> # Process the file in chunks
        >>> for chunk in tk.load_trace_chunks('large_trace.bin', chunk_size=50e6):
        ...     analyzer.accumulate_psd(chunk, nperseg=4096, window='hann')
        >>> # Get the aggregated result
        >>> freqs, psd = analyzer.get_psd()

    Advanced Example:
        >>> # Compute multiple statistics in streaming fashion
        >>> analyzer = tk.StreamingAnalyzer()
        >>> for chunk in tk.load_trace_chunks('huge_file.bin'):
        ...     analyzer.accumulate_statistics(chunk)
        ...     analyzer.accumulate_psd(chunk, nperseg=8192)
        >>> stats = analyzer.get_statistics()
        >>> freqs, psd = analyzer.get_psd()
        >>> print(f"Mean: {stats['mean']:.3f}, PSD shape: {psd.shape}")

    References:
        API-003: Streaming/Generator API for Large Files
        scipy.signal.welch for streaming PSD
    """

    def __init__(self) -> None:
        """Initialize streaming analyzer."""
        # Statistics accumulators
        self._n_samples = 0
        self._sum = 0.0
        self._sum_sq = 0.0
        self._min = float("inf")
        self._max = float("-inf")

        # PSD accumulators
        self._psd_sum: NDArray[np.float64] | None = None
        self._psd_freqs: NDArray[np.float64] | None = None
        self._psd_count = 0
        self._sample_rate: float | None = None

        # Histogram accumulators
        self._hist_counts: NDArray[np.int64] | None = None
        self._hist_edges: NDArray[np.float64] | None = None

    def accumulate_statistics(self, chunk: WaveformTrace) -> None:
        """Accumulate basic statistics from a chunk.

        Updates the running mean, std, min, and max from accumulated sums
        (sum and sum of squares), so chunks can be processed independently.

        Args:
            chunk: WaveformTrace chunk to process.

        Example:
            >>> analyzer.accumulate_statistics(chunk)
        """
        chunk_data = chunk.data
        self._n_samples += len(chunk_data)
        self._sum += float(chunk_data.sum())
        self._sum_sq += float((chunk_data**2).sum())
        self._min = min(self._min, float(chunk_data.min()))
        self._max = max(self._max, float(chunk_data.max()))
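
    # A minimal sanity check of the sum-based accumulation (illustrative
    # values only): summing per-chunk sums reproduces the whole-array mean.
    #
    #     >>> import numpy as np
    #     >>> x = np.arange(10, dtype=float)
    #     >>> a, b = x[:4], x[4:]
    #     >>> (a.sum() + b.sum()) / len(x) == x.mean()
    #     True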

    def accumulate_psd(
        self,
        chunk: WaveformTrace,
        nperseg: int = 4096,
        window: str = "hann",
        **welch_kwargs: Any,
    ) -> None:
        """Accumulate a PSD estimate from a chunk using Welch's method.

        Computes the PSD for the chunk and accumulates it into a running
        average.

        Args:
            chunk: WaveformTrace chunk to process.
            nperseg: Length of each segment for Welch's method.
            window: Window function name (default 'hann').
            **welch_kwargs: Additional arguments for scipy.signal.welch.

        Example:
            >>> analyzer.accumulate_psd(chunk, nperseg=4096, window='hann')

        References:
            scipy.signal.welch
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.welch.html
        """
        # Store the sample rate from the first chunk
        if self._sample_rate is None:
            self._sample_rate = chunk.metadata.sample_rate

        # Compute the PSD for this chunk using Welch's method
        freqs, psd = signal.welch(
            chunk.data,
            fs=chunk.metadata.sample_rate,
            nperseg=nperseg,
            window=window,
            **welch_kwargs,
        )

        # Initialize or accumulate
        if self._psd_sum is None:
            self._psd_sum = psd
            self._psd_freqs = freqs
        else:
            # Accumulate PSD estimates
            self._psd_sum += psd

        self._psd_count += 1
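
    # Averaging Welch estimates across chunks assumes every chunk is long
    # enough for the requested nperseg; the frequency grid then depends only
    # on nperseg and the sample rate, so the running sum stays aligned.
    # Illustrative check (synthetic data, not tied to any real trace):
    #
    #     >>> import numpy as np
    #     >>> from scipy import signal
    #     >>> rng = np.random.default_rng(0)
    #     >>> f1, _ = signal.welch(rng.normal(size=8192), fs=1e3, nperseg=1024)
    #     >>> f2, _ = signal.welch(rng.normal(size=8192), fs=1e3, nperseg=1024)
    #     >>> np.array_equal(f1, f2)
    #     True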

    def accumulate_histogram(
        self,
        chunk: WaveformTrace,
        bins: int | NDArray[np.float64] = 100,
        range: tuple[float, float] | None = None,
    ) -> None:
        """Accumulate a histogram from a chunk.

        Args:
            chunk: WaveformTrace chunk to process.
            bins: Number of bins or explicit bin edges.
            range: Range of the histogram (min, max). Pass an explicit range
                (or bin edges) so every chunk is binned consistently.

        Example:
            >>> analyzer.accumulate_histogram(chunk, bins=100, range=(-1.0, 1.0))
        """
        if self._hist_counts is None:
            counts, edges = np.histogram(chunk.data, bins=bins, range=range)
            self._hist_counts = counts.astype(np.int64)
            self._hist_edges = edges
        else:
            # Reuse the first chunk's bin edges so counts from later chunks
            # land in the same bins.
            counts, _ = np.histogram(chunk.data, bins=self._hist_edges)
            self._hist_counts += counts.astype(np.int64)
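
    # Why fixed edges matter (illustrative): with shared edges, per-chunk
    # counts add exactly; with auto-ranged edges, each chunk would bin over a
    # different interval and the accumulated sums would be meaningless.
    #
    #     >>> import numpy as np
    #     >>> edges = np.linspace(-1.0, 1.0, 11)
    #     >>> c1, _ = np.histogram(np.array([-0.5, 0.5]), bins=edges)
    #     >>> c2, _ = np.histogram(np.array([0.5]), bins=edges)
    #     >>> int((c1 + c2).sum())
    #     3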

    def get_statistics(self) -> dict[str, float]:
        """Get accumulated statistics.

        Returns:
            Dictionary with mean, std, min, max, and sample count.

        Raises:
            ValueError: If no data accumulated yet.

        Example:
            >>> stats = analyzer.get_statistics()
            >>> print(f"Mean: {stats['mean']:.3f}, Std: {stats['std']:.3f}")
        """
        if self._n_samples == 0:
            raise ValueError("No data accumulated yet")

        mean = self._sum / self._n_samples
        variance = (self._sum_sq / self._n_samples) - (mean**2)
        std = np.sqrt(max(0, variance))  # Avoid negative due to numerical errors

        return {
            "mean": mean,
            "std": std,
            "min": self._min,
            "max": self._max,
            "n_samples": self._n_samples,
        }
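
    # The variance above uses the one-pass identity Var(x) = E[x**2] - E[x]**2,
    # which is why only the sum and the sum of squares need to be streamed.
    # Worked example with x = [1, 2, 3]: E[x**2] = 14/3, E[x] = 2, so
    # Var(x) = 14/3 - 4 = 2/3, matching np.var([1.0, 2.0, 3.0]).
    # The identity can lose precision when std << mean, hence the clamp.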

    def get_psd(self) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get the accumulated PSD estimate.

        Returns:
            Tuple of (frequencies, psd) where psd is averaged over all chunks.

        Raises:
            ValueError: If no PSD data accumulated.

        Example:
            >>> freqs, psd = analyzer.get_psd()
            >>> print(f"PSD shape: {psd.shape}")
        """
        if self._psd_sum is None or self._psd_freqs is None:
            raise ValueError("No PSD data accumulated yet")

        # Return the averaged PSD
        psd_avg = self._psd_sum / self._psd_count
        return self._psd_freqs, psd_avg

    def get_histogram(self) -> tuple[NDArray[np.int64], NDArray[np.float64]]:
        """Get the accumulated histogram.

        Returns:
            Tuple of (counts, edges).

        Raises:
            ValueError: If no histogram data accumulated.

        Example:
            >>> counts, edges = analyzer.get_histogram()
        """
        if self._hist_counts is None or self._hist_edges is None:
            raise ValueError("No histogram data accumulated yet")

        return self._hist_counts, self._hist_edges

    def reset(self) -> None:
        """Reset all accumulators.

        Example:
            >>> analyzer.reset()
        """
        self._n_samples = 0
        self._sum = 0.0
        self._sum_sq = 0.0
        self._min = float("inf")
        self._max = float("-inf")
        self._psd_sum = None
        self._psd_freqs = None
        self._psd_count = 0
        self._sample_rate = None
        self._hist_counts = None
        self._hist_edges = None


def chunked_spectrogram(
    data: NDArray[np.float64],
    sample_rate: float,
    *,
    chunk_size: int = 10_000_000,
    overlap: int = 0,
    nperseg: int = 256,
    noverlap: int | None = None,
    window: str = "hann",
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
    """Compute a spectrogram for large signals using chunked processing.

    Processes large signals in overlapping chunks to compute spectrograms
    without loading the entire signal into memory. Stitches STFT results from
    chunks with proper boundary handling.

    Args:
        data: Input signal array (can be memory-mapped).
        sample_rate: Sample rate in Hz.
        chunk_size: Maximum samples per chunk (default 10M).
        overlap: Overlap samples between chunks for continuity. If 0 (the
            default), it is set to 2*nperseg, the minimum needed for proper
            STFT boundary handling.
        nperseg: Segment length for STFT (default 256).
        noverlap: Overlap between STFT segments within a chunk (default
            nperseg // 2).
        window: Window function name (default "hann").

    Returns:
        (times, frequencies, Sxx_db) - Time axis, frequency axis, and
        spectrogram magnitude in dB as a 2D array (frequencies x time).

    Raises:
        ValueError: If no valid chunks are produced.

    Example:
        >>> # Memory-efficient spectrogram on a 1 GB signal
        >>> import numpy as np
        >>> data = np.memmap('huge_trace.dat', dtype='float64', mode='r')
        >>> t, f, Sxx = chunked_spectrogram(data, sample_rate=1e9, chunk_size=10_000_000)
        >>> print(f"Spectrogram shape: {Sxx.shape}")

    References:
        MEM-004: Chunked Spectrogram requirement
        scipy.signal.spectrogram
    """
    n = len(data)

    # Handle empty input
    if n == 0:
        return np.array([]), np.array([]), np.array([]).reshape(0, 0)

    if noverlap is None:
        noverlap = nperseg // 2

    # Auto-adjust overlap if not specified to ensure continuity
    if overlap == 0:
        overlap = 2 * nperseg

    # If the data fits in one chunk, use scipy directly
    if n <= chunk_size:
        freq, times, Sxx = signal.spectrogram(
            data,
            fs=sample_rate,
            window=window,
            nperseg=nperseg,
            noverlap=noverlap,
            scaling="spectrum",
        )
        # Convert to dB
        Sxx = np.maximum(Sxx, 1e-20)
        Sxx_db = 10 * np.log10(Sxx)
        return times, freq, Sxx_db

    # Process chunks
    chunks_stft = []
    chunks_times = []
    chunk_start = 0

    while chunk_start < n:
        # Determine chunk boundaries with overlap
        chunk_end = min(chunk_start + chunk_size, n)

        # Extract the chunk with overlap extension on both sides
        extended_start = max(0, chunk_start - overlap)
        extended_end = min(n, chunk_end + overlap)

        chunk_data = data[extended_start:extended_end]

        # Compute the spectrogram for this chunk
        freq, times_chunk, Sxx_chunk = signal.spectrogram(
            chunk_data,
            fs=sample_rate,
            window=window,
            nperseg=nperseg,
            noverlap=noverlap,
            scaling="spectrum",
        )

        # Adjust the time axis for the chunk position
        time_offset = extended_start / sample_rate
        times_chunk_adjusted = times_chunk + time_offset

        # Trim the overlap regions to avoid duplication
        valid_time_start = chunk_start / sample_rate
        valid_time_end = chunk_end / sample_rate

        valid_mask = (times_chunk_adjusted >= valid_time_start) & (
            times_chunk_adjusted < valid_time_end
        )

        if np.any(valid_mask):
            Sxx_chunk = Sxx_chunk[:, valid_mask]
            times_chunk_adjusted = times_chunk_adjusted[valid_mask]

            chunks_stft.append(Sxx_chunk)
            chunks_times.append(times_chunk_adjusted)

        # Move to the next chunk
        chunk_start = chunk_end

    # Concatenate all chunks
    if len(chunks_stft) == 0:
        raise ValueError("No valid chunks produced")

    Sxx = np.concatenate(chunks_stft, axis=1)
    times = np.concatenate(chunks_times)

    # Convert to dB
    Sxx = np.maximum(Sxx, 1e-20)
    Sxx_db = 10 * np.log10(Sxx)

    return times, freq, Sxx_db
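
# A quick self-consistency sketch for chunked_spectrogram (synthetic input;
# values chosen only for illustration): the chunked path should return one
# spectrogram row per frequency bin, just like scipy.signal.spectrogram.
#
#     >>> import numpy as np
#     >>> x = np.sin(2 * np.pi * 50 * np.arange(4096) / 1e3)
#     >>> t, f, Sxx = chunked_spectrogram(x, sample_rate=1e3, chunk_size=2048)
#     >>> Sxx.shape[0] == len(f)
#     True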


def chunked_fft(
    data: NDArray[np.float64],
    sample_rate: float,
    *,
    chunk_size: int = 10_000_000,
    overlap: float = 50.0,
    window: str = "hann",
    nfft: int | None = None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute an FFT for very long signals using segmented averaging.

    Divides the signal into overlapping segments, computes an FFT for each,
    and averages the magnitude spectra. Memory use is bounded by chunk_size,
    and the averaging reduces variance (similar to Welch's method).

    Args:
        data: Input signal array (can be memory-mapped).
        sample_rate: Sample rate in Hz.
        chunk_size: Size of each segment in samples (default 10M).
        overlap: Percentage overlap between segments, 0-100 (default 50).
        window: Window function name (default "hann").
        nfft: FFT length. If None, uses the next power of 2 >= the segment
            length.

    Returns:
        (frequencies, magnitude_db) - Frequency axis and averaged magnitude in dB.

    Example:
        >>> # Memory-efficient FFT on a 1 GB signal with 50% overlap
        >>> import numpy as np
        >>> data = np.memmap('huge_trace.dat', dtype='float64', mode='r')
        >>> freq, mag = chunked_fft(data, sample_rate=1e9, chunk_size=1_000_000)
        >>> print(f"Frequency resolution: {freq[1] - freq[0]:.3f} Hz")

    References:
        MEM-006: Chunked FFT requirement
        Welch's method for spectral estimation
    """
    from ..utils.windowing import get_window

    n = len(data)

    # Handle empty input
    if n == 0:
        return np.array([]), np.array([])

    # If the data fits in one chunk, compute a single FFT
    if n <= chunk_size:
        if nfft is None:
            nfft = int(2 ** np.ceil(np.log2(n)))

        # Apply the window
        w = get_window(window, n)
        data_windowed = data * w

        # Compute the FFT
        spectrum = np.fft.rfft(data_windowed, n=nfft)

        # Frequency axis
        freq = np.fft.rfftfreq(nfft, d=1.0 / sample_rate)

        # Magnitude in dB (normalized by window gain)
        window_gain = np.sum(w) / n
        magnitude = np.abs(spectrum) / (n * window_gain)
        magnitude = np.maximum(magnitude, 1e-20)
        magnitude_db = 20 * np.log10(magnitude)

        return freq, magnitude_db
    # Calculate overlap
    overlap_samples = int(chunk_size * overlap / 100.0)
    hop = chunk_size - overlap_samples

    # Determine the number of segments
    num_segments = max(1, (n - overlap_samples) // hop)

    if nfft is None:
        nfft = int(2 ** np.ceil(np.log2(chunk_size)))

    # Prepare the window
    w = get_window(window, chunk_size)
    window_gain = np.sum(w) / chunk_size

    # Accumulate magnitude spectra
    freq = np.fft.rfftfreq(nfft, d=1.0 / sample_rate)
    magnitude_sum = np.zeros(len(freq))

    for i in range(num_segments):
        start = i * hop
        end = min(start + chunk_size, n)

        # Extract the segment
        if end - start < chunk_size:
            # Last segment: pad with zeros
            segment = np.zeros(chunk_size)
            segment[: end - start] = data[start:end]
        else:
            segment = data[start:end]

        # Detrend (remove the mean)
        segment = segment - np.mean(segment)

        # Window
        segment_windowed = segment * w

        # FFT
        spectrum = np.fft.rfft(segment_windowed, n=nfft)

        # Accumulate magnitude
        magnitude = np.abs(spectrum) / (chunk_size * window_gain)
        magnitude_sum += magnitude

    # Average
    magnitude_avg = magnitude_sum / num_segments

    # Convert to dB
    magnitude_avg = np.maximum(magnitude_avg, 1e-20)
    magnitude_db = 20 * np.log10(magnitude_avg)

    return freq, magnitude_db
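
# A small sanity sketch for chunked_fft (synthetic tone; parameters are
# illustrative): the averaged spectrum should peak within one bin of the
# tone frequency.
#
#     >>> import numpy as np
#     >>> fs = 1e3
#     >>> x = np.sin(2 * np.pi * 100 * np.arange(16384) / fs)
#     >>> freq, mag = chunked_fft(x, sample_rate=fs, chunk_size=4096)
#     >>> bool(abs(freq[np.argmax(mag)] - 100) < fs / 4096)
#     True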


__all__ = [
    "StreamingAnalyzer",
    "chunked_fft",
    "chunked_spectrogram",
    "load_trace_chunks",
]