Coverage for src / tracekit / analyzers / spectral / chunked_wavelet.py: 97%
118 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Chunked wavelet transform for memory-bounded processing.
3This module implements memory-bounded wavelet transforms (CWT and DWT)
4with segment processing and boundary handling.
7Example:
8 >>> from tracekit.analyzers.spectral.chunked_wavelet import cwt_chunked, dwt_chunked
9 >>> coeffs = cwt_chunked('large_signal.bin', scales=[1, 2, 4, 8], wavelet='morl')
10 >>> print(f"CWT coefficients shape: {coeffs.shape}")
12References:
13 pywt (PyWavelets) for wavelet transforms
14 Mallat, S. (1999). "A Wavelet Tour of Signal Processing"
15"""
17from __future__ import annotations
19from pathlib import Path
20from typing import TYPE_CHECKING, Any
22import numpy as np
24if TYPE_CHECKING:
25 from collections.abc import Iterator, Sequence
27 from numpy.typing import NDArray
def cwt_chunked(
    file_path: str | Path,
    scales: Sequence[float],
    wavelet: str = "morl",
    *,
    chunk_size: int | float = 1e6,
    overlap_factor: float = 2.0,
    sample_rate: float = 1.0,
    dtype: str = "float32",
) -> tuple[NDArray[Any], NDArray[Any]]:
    """Compute continuous wavelet transform for large files.

    Reads the signal in chunks, each extended on the left by an overlap
    proportional to the largest scale (wavelet support grows with scale,
    so edge effects land in the overlap region), computes the CWT per
    chunk, trims the overlap, and stitches results along the time axis.

    Args:
        file_path: Path to signal file (raw binary samples, no header).
        scales: Scales for CWT (wavelet dilations).
        wavelet: Wavelet name ('morl', 'mexh', 'cmor', etc.).
        chunk_size: Chunk size in samples.
        overlap_factor: Overlap factor for boundaries (e.g., 2.0 = 2x max scale).
        sample_rate: Sample rate in Hz.
        dtype: Data type of input file ('float32' or 'float64').

    Returns:
        Tuple of (coefficients, frequencies) where:
            - coefficients: CWT coefficients (scales x time).
            - frequencies: Corresponding frequencies for each scale.

    Raises:
        ImportError: If pywt (PyWavelets) is not installed.
        ValueError: If no chunks could be read from the file.

    Example:
        >>> scales = np.arange(1, 128)
        >>> coeffs, freqs = cwt_chunked(
        ...     'signal.bin',
        ...     scales=scales,
        ...     wavelet='morl',
        ...     chunk_size=1e6,
        ...     sample_rate=1e6
        ... )
        >>> print(f"CWT shape: {coeffs.shape}")

    References:
        MEM-007: Chunked Wavelet Transform
    """
    try:
        import pywt
    except ImportError as e:
        raise ImportError(
            "pywt (PyWavelets) is required for wavelet transforms. "
            "Install with: pip install PyWavelets"
        ) from e

    chunk_size = int(chunk_size)
    scale_arr = np.asarray(scales)

    # Boundary overlap proportional to the largest scale.
    boundary_overlap = int(overlap_factor * np.max(scale_arr))

    # Map the dtype string to a NumPy type and its on-disk sample width.
    np_dtype = np.float32 if dtype == "float32" else np.float64
    bytes_per_sample = 4 if dtype == "float32" else 8

    # Derive the total sample count from the file size on disk.
    file_path = Path(file_path)
    total_samples = file_path.stat().st_size // bytes_per_sample

    coeffs_list: list[NDArray[Any]] = []
    freqs: NDArray[Any] | None = None
    offset = 0
    for chunk_data in _generate_chunks(
        file_path, total_samples, chunk_size, boundary_overlap, np_dtype
    ):
        coeffs_chunk, freqs = pywt.cwt(
            chunk_data,
            scale_arr,
            wavelet,
            sampling_period=1 / sample_rate,
        )

        # _generate_chunks extends each chunk leftward by up to
        # boundary_overlap samples, clamped at the start of the file.
        # Trim exactly the samples that were actually prepended so the
        # stitched result covers every sample exactly once.  Using the
        # per-chunk value (not a fixed boundary_overlap) avoids dropping
        # real data when chunk_size < boundary_overlap and the clamp hits.
        trim_left = offset - max(0, offset - boundary_overlap)
        if trim_left > 0:
            coeffs_chunk = coeffs_chunk[:, trim_left:]

        coeffs_list.append(coeffs_chunk)
        offset += chunk_size

    if not coeffs_list:
        raise ValueError(f"No chunks processed from {file_path}")

    return np.concatenate(coeffs_list, axis=1), freqs
def dwt_chunked(
    file_path: str | Path,
    wavelet: str = "db4",
    level: int | None = None,
    *,
    chunk_size: int | float = 1e6,
    mode: str = "symmetric",
    dtype: str = "float32",
) -> list[NDArray[Any]]:
    """Compute a multilevel discrete wavelet transform for large files.

    Reads the signal in overlapping chunks, decomposes each chunk with
    ``pywt.wavedec``, then trims the overlap region at every decomposition
    level and concatenates per-level coefficients across chunks.  The
    trimming is an approximation: DWT coefficients of overlapping segments
    do not align exactly sample-for-sample, so values near chunk joins may
    differ slightly from a whole-signal ``wavedec``.

    Args:
        file_path: Path to signal file (raw binary samples, no header).
        wavelet: Wavelet name ('db4', 'haar', 'sym5', etc.).
        level: Decomposition level (None = let pywt choose the maximum).
        chunk_size: Chunk size in samples.
        mode: Signal extension mode used by pywt at chunk boundaries.
        dtype: Data type of input file ('float32' or 'float64').

    Returns:
        List of coefficient arrays [cA_n, cD_n, ..., cD_1] where:
            - cA_n: Approximation coefficients at level n.
            - cD_i: Detail coefficients at level i.

    Raises:
        ImportError: If pywt is not installed.
        ValueError: If no chunks could be read from the file.

    Example:
        >>> coeffs = dwt_chunked(
        ...     'signal.bin',
        ...     wavelet='db4',
        ...     level=5,
        ...     chunk_size=1e6
        ... )
        >>> print(f"Approximation shape: {coeffs[0].shape}")

    References:
        MEM-007: Chunked Wavelet Transform
        Daubechies, I. (1992). "Ten Lectures on Wavelets"
    """
    try:
        import pywt
    except ImportError as e:
        raise ImportError(
            "pywt (PyWavelets) is required for wavelet transforms. "
            "Install with: pip install PyWavelets"
        ) from e

    chunk_size = int(chunk_size)

    # Map the dtype string to a NumPy type and its on-disk sample width.
    np_dtype = np.float32 if dtype == "float32" else np.float64
    bytes_per_sample = 4 if dtype == "float32" else 8

    # Size the inter-chunk overlap from the wavelet's decomposition filter
    # length, doubled per decomposition level.
    # NOTE(review): with level=None this evaluates (level or 1) == 1, i.e.
    # an overlap of filter_len * 2, even though wavedec may decompose much
    # deeper in that case — the overlap may be under-sized; confirm intent.
    wavelet_obj = pywt.Wavelet(wavelet)
    filter_len = wavelet_obj.dec_len
    boundary_overlap = filter_len * (2 ** (level or 1))

    # Derive the total sample count from the file size on disk.
    file_path = Path(file_path)
    file_size_bytes = file_path.stat().st_size
    total_samples = file_size_bytes // bytes_per_sample

    # Decompose each chunk independently.
    coeffs_list: list[list[NDArray[Any]]] = []
    chunk_starts: list[int] = []  # Track chunk start positions in original signal
    chunks = _generate_chunks(file_path, total_samples, chunk_size, boundary_overlap, np_dtype)

    current_offset = 0
    for chunk_data in chunks:
        # wavedec returns [cA_n, cD_n, ..., cD_1] for this chunk.
        coeffs_chunk = pywt.wavedec(chunk_data, wavelet, mode=mode, level=level)
        coeffs_list.append(coeffs_chunk)
        chunk_starts.append(current_offset)
        # Move offset by chunk_size (not including overlap)
        current_offset = min(current_offset + chunk_size, total_samples)

    # Merge coefficients from all chunks
    if len(coeffs_list) == 0:
        raise ValueError(f"No chunks processed from {file_path}")

    # For DWT, we need to handle overlaps properly at each decomposition level.
    # At each level j, the downsampling factor is 2^j, so the overlap (in
    # coefficient samples) shrinks by the same factor.
    # NOTE(review): _generate_chunks extends chunks on the LEFT only, while
    # the trimming below assumes symmetric overlap (it also cuts the right
    # edge of first/middle chunks) — verify this is the intended stitching.

    # NOTE(review): num_levels is taken from the first chunk; with
    # level=None a shorter final chunk can yield a different number of
    # coefficient arrays — verify chunk sizes keep the level count uniform.
    num_levels = len(coeffs_list[0])
    merged_coeffs = []

    for level_idx in range(num_levels):
        # Overlap expressed in coefficient samples at this level.
        downsample_factor = 2**level_idx
        level_overlap = boundary_overlap // downsample_factor

        if len(coeffs_list) == 1:
            # Only one chunk - no overlap to handle
            merged_coeffs.append(coeffs_list[0][level_idx])
        else:
            # Multiple chunks - need to handle overlaps
            merged_level_coeffs = []

            for chunk_idx, chunk_coeffs in enumerate(coeffs_list):
                level_coeffs = chunk_coeffs[level_idx]

                if chunk_idx == 0:
                    # First chunk - keep all except right overlap
                    if level_overlap > 0 and len(level_coeffs) > level_overlap:
                        merged_level_coeffs.append(level_coeffs[:-level_overlap])
                    else:
                        merged_level_coeffs.append(level_coeffs)
                elif chunk_idx == len(coeffs_list) - 1:
                    # Last chunk - trim left overlap, keep rest
                    if level_overlap > 0 and len(level_coeffs) > level_overlap:
                        merged_level_coeffs.append(level_coeffs[level_overlap:])
                    else:
                        # If overlap is too large, keep small center portion
                        center_start = max(0, level_overlap // 2)
                        merged_level_coeffs.append(level_coeffs[center_start:])
                else:
                    # Middle chunks - trim both sides
                    if level_overlap > 0 and len(level_coeffs) > 2 * level_overlap:
                        merged_level_coeffs.append(level_coeffs[level_overlap:-level_overlap])
                    else:
                        # If overlap is too large, keep small center portion
                        center_start = max(0, level_overlap // 2)
                        center_end = max(center_start + 1, len(level_coeffs) - level_overlap // 2)
                        merged_level_coeffs.append(level_coeffs[center_start:center_end])

            # Concatenate the trimmed coefficients
            if merged_level_coeffs:
                merged_coeffs.append(np.concatenate(merged_level_coeffs))
            else:
                # Fallback: concatenate all if trimming failed
                level_coeffs_list = [chunk_coeffs[level_idx] for chunk_coeffs in coeffs_list]
                merged_coeffs.append(np.concatenate(level_coeffs_list))

    return merged_coeffs
277def _generate_chunks(
278 file_path: Path,
279 total_samples: int,
280 chunk_size: int,
281 boundary_overlap: int,
282 dtype: type,
283) -> Iterator[NDArray[Any]]:
284 """Generate overlapping chunks from file.
286 Args:
287 file_path: Path to binary file.
288 total_samples: Total number of samples in file.
289 chunk_size: Samples per chunk.
290 boundary_overlap: Overlap samples between chunks.
291 dtype: NumPy dtype for data.
293 Yields:
294 Chunk arrays with boundary overlap.
295 """
296 offset = 0
298 with open(file_path, "rb") as f:
299 while offset < total_samples:
300 # Calculate chunk boundaries
301 chunk_start = max(0, offset - boundary_overlap)
302 chunk_end = min(total_samples, offset + chunk_size)
303 chunk_len = chunk_end - chunk_start
305 # Seek and read
306 f.seek(chunk_start * dtype().itemsize)
307 chunk_data: NDArray[np.float64] = np.fromfile(f, dtype=dtype, count=chunk_len)
309 if len(chunk_data) == 0: 309 ↛ 310line 309 didn't jump to line 310 because the condition on line 309 was never true
310 break
312 yield chunk_data
314 # Advance offset
315 offset += chunk_size
def cwt_chunked_generator(
    file_path: str | Path,
    scales: Sequence[float],
    wavelet: str = "morl",
    *,
    chunk_size: int | float = 1e6,
    **kwargs: Any,
) -> Iterator[tuple[NDArray[Any], NDArray[Any]]]:
    """Stream per-chunk CWT results instead of stitching them.

    Each yielded coefficient array covers one file chunk (including its
    left boundary overlap), which suits incremental processing of files
    too large to transform in one pass.

    Args:
        file_path: Path to signal file.
        scales: Scales for CWT.
        wavelet: Wavelet name.
        chunk_size: Chunk size in samples.
        **kwargs: Optional 'dtype', 'overlap_factor', 'sample_rate'.

    Yields:
        Tuples of (coefficients, frequencies) for each chunk.

    Raises:
        ImportError: If pywt (PyWavelets) is not installed.

    Example:
        >>> for coeffs_chunk, freqs in cwt_chunked_generator('file.bin', scales=[1, 2, 4]):
        ...     # Process each chunk separately
        ...     print(f"Chunk shape: {coeffs_chunk.shape}")
    """
    try:
        import pywt
    except ImportError as e:
        raise ImportError(
            "pywt (PyWavelets) is required for wavelet transforms. "
            "Install with: pip install PyWavelets"
        ) from e

    samples_per_chunk = int(chunk_size)
    scale_array = np.asarray(scales)

    # Resolve the on-disk sample format from the optional dtype kwarg.
    if kwargs.get("dtype", "float32") == "float32":
        np_dtype, sample_bytes = np.float32, 4
    else:
        np_dtype, sample_bytes = np.float64, 8

    # Total sample count follows from the file size.
    path = Path(file_path)
    total_samples = path.stat().st_size // sample_bytes

    # Left overlap sized by the largest scale, as in cwt_chunked.
    overlap = int(kwargs.get("overlap_factor", 2.0) * np.max(scale_array))
    period = 1 / kwargs.get("sample_rate", 1.0)

    for segment in _generate_chunks(path, total_samples, samples_per_chunk, overlap, np_dtype):
        # pywt.cwt returns the (coefficients, frequencies) pair we yield.
        yield pywt.cwt(segment, scale_array, wavelet, sampling_period=period)
# Public API: the chunked transform entry points; _generate_chunks stays private.
__all__ = [
    "cwt_chunked",
    "cwt_chunked_generator",
    "dwt_chunked",
]