Coverage for src/tracekit/analyzers/spectral/chunked_wavelet.py: 97%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Chunked wavelet transform for memory-bounded processing. 

2 

3This module implements memory-bounded wavelet transforms (CWT and DWT) 

4with segment processing and boundary handling. 

5 

6 

7Example: 

8 >>> from tracekit.analyzers.spectral.chunked_wavelet import cwt_chunked, dwt_chunked 

9 >>> coeffs = cwt_chunked('large_signal.bin', scales=[1, 2, 4, 8], wavelet='morl') 

10 >>> print(f"CWT coefficients shape: {coeffs.shape}") 

11 

12References: 

13 pywt (PyWavelets) for wavelet transforms 

14 Mallat, S. (1999). "A Wavelet Tour of Signal Processing" 

15""" 

16 

17from __future__ import annotations 

18 

19from pathlib import Path 

20from typing import TYPE_CHECKING, Any 

21 

22import numpy as np 

23 

24if TYPE_CHECKING: 

25 from collections.abc import Iterator, Sequence 

26 

27 from numpy.typing import NDArray 

28 

29 

def cwt_chunked(
    file_path: str | Path,
    scales: Sequence[float],
    wavelet: str = "morl",
    *,
    chunk_size: int | float = 1e6,
    overlap_factor: float = 2.0,
    sample_rate: float = 1.0,
    dtype: str = "float32",
) -> tuple[NDArray[Any], NDArray[Any]]:
    """Compute continuous wavelet transform for large files.

    Processes the signal in chunks with left-side overlap to handle
    boundaries, computes the CWT per chunk, and stitches the results.

    Args:
        file_path: Path to signal file (raw binary samples).
        scales: Scales for CWT (wavelet dilations). Must be non-empty.
        wavelet: Wavelet name ('morl', 'mexh', 'cmor', etc.).
        chunk_size: Chunk size in samples.
        overlap_factor: Overlap factor for boundaries (e.g., 2.0 = 2x max scale).
        sample_rate: Sample rate in Hz.
        dtype: Data type of input file ('float32' or 'float64').

    Returns:
        Tuple of (coefficients, frequencies) where:
            - coefficients: CWT coefficients (scales x time).
            - frequencies: Corresponding frequencies for each scale.

    Raises:
        ImportError: If pywt (PyWavelets) is not installed.
        ValueError: If no chunks can be read or scales is empty.

    Example:
        >>> scales = np.arange(1, 128)
        >>> coeffs, freqs = cwt_chunked(
        ...     'signal.bin',
        ...     scales=scales,
        ...     wavelet='morl',
        ...     chunk_size=1e6,
        ...     sample_rate=1e6
        ... )
        >>> print(f"CWT shape: {coeffs.shape}")

    References:
        MEM-007: Chunked Wavelet Transform
    """
    try:
        import pywt
    except ImportError as e:
        raise ImportError(
            "pywt (PyWavelets) is required for wavelet transforms. "
            "Install with: pip install PyWavelets"
        ) from e

    chunk_size = int(chunk_size)
    # Keep the parameter name intact; work on a validated local array instead
    # of rebinding `scales` (the original shadowed it behind a type: ignore).
    scale_arr: NDArray[np.float64] = np.asarray(scales, dtype=np.float64)
    if scale_arr.size == 0:
        raise ValueError("scales must be a non-empty sequence")

    # Overlap proportional to the largest scale: wider wavelets have wider
    # cones of influence, so boundary contamination extends further.
    boundary_overlap = int(overlap_factor * float(np.max(scale_arr)))

    np_dtype = np.float32 if dtype == "float32" else np.float64
    bytes_per_sample = 4 if dtype == "float32" else 8

    file_path = Path(file_path)
    total_samples = file_path.stat().st_size // bytes_per_sample

    coeffs_list: list[NDArray[Any]] = []
    freqs: NDArray[Any] | None = None

    for chunk_data in _generate_chunks(
        file_path, total_samples, chunk_size, boundary_overlap, np_dtype
    ):
        coeffs_chunk, freqs = pywt.cwt(
            chunk_data,
            scale_arr,
            wavelet,
            sampling_period=1 / sample_rate,
        )

        # Chunks after the first were read with `boundary_overlap` extra
        # samples on the left; drop that region so chunks tile the signal
        # without duplication. NOTE(review): right-edge coefficients of each
        # chunk are kept as-is, so some edge artifacts may remain near chunk
        # boundaries -- same behavior as the original implementation.
        if coeffs_list:
            coeffs_chunk = coeffs_chunk[:, boundary_overlap:]
        coeffs_list.append(coeffs_chunk)

    if not coeffs_list or freqs is None:
        raise ValueError(f"No chunks processed from {file_path}")

    return np.concatenate(coeffs_list, axis=1), freqs

130 

131 

def dwt_chunked(
    file_path: str | Path,
    wavelet: str = "db4",
    level: int | None = None,
    *,
    chunk_size: int | float = 1e6,
    mode: str = "symmetric",
    dtype: str = "float32",
) -> list[NDArray[Any]]:
    """Compute discrete wavelet transform for large files.

    Processes the signal in chunks, computes a multilevel DWT per chunk,
    and merges coefficients level by level, trimming overlap regions.
    Boundaries within each chunk use the specified extension mode.

    Args:
        file_path: Path to signal file (raw binary samples).
        wavelet: Wavelet name ('db4', 'haar', 'sym5', etc.).
        level: Decomposition level (None = maximum level).
        chunk_size: Chunk size in samples.
        mode: Signal extension mode for boundaries.
        dtype: Data type of input file ('float32' or 'float64').

    Returns:
        List of coefficient arrays [cA_n, cD_n, ..., cD_1] where:
            - cA_n: Approximation coefficients at level n.
            - cD_i: Detail coefficients at level i.

    Raises:
        ImportError: If pywt is not installed.
        ValueError: If no chunks can be read from the file.

    Example:
        >>> coeffs = dwt_chunked(
        ...     'signal.bin',
        ...     wavelet='db4',
        ...     level=5,
        ...     chunk_size=1e6
        ... )
        >>> print(f"Approximation shape: {coeffs[0].shape}")

    References:
        MEM-007: Chunked Wavelet Transform
        Daubechies, I. (1992). "Ten Lectures on Wavelets"
    """
    try:
        import pywt
    except ImportError as e:
        raise ImportError(
            "pywt (PyWavelets) is required for wavelet transforms. "
            "Install with: pip install PyWavelets"
        ) from e

    chunk_size = int(chunk_size)

    np_dtype = np.float32 if dtype == "float32" else np.float64
    bytes_per_sample = 4 if dtype == "float32" else 8

    # Filter length drives how far boundary effects propagate; scale by the
    # requested depth (heuristic: depth 1 when level is None).
    wavelet_obj = pywt.Wavelet(wavelet)
    filter_len = wavelet_obj.dec_len
    boundary_overlap = filter_len * (2 ** (level or 1))

    file_path = Path(file_path)
    total_samples = file_path.stat().st_size // bytes_per_sample

    # Decompose every chunk; merging happens below per coefficient band.
    coeffs_list: list[list[NDArray[Any]]] = []
    for chunk_data in _generate_chunks(
        file_path, total_samples, chunk_size, boundary_overlap, np_dtype
    ):
        coeffs_list.append(pywt.wavedec(chunk_data, wavelet, mode=mode, level=level))

    if len(coeffs_list) == 0:
        raise ValueError(f"No chunks processed from {file_path}")

    # wavedec returns [cA_n, cD_n, ..., cD_1]: index 0 is the *coarsest*
    # (most downsampled) band. BUG FIX: the original used 2**level_idx as the
    # downsampling factor, which assigns factor 1 to the coarsest band and the
    # largest factor to the finest band -- exactly backwards. Derive the
    # factor from the actual decomposition depth instead.
    num_levels = len(coeffs_list[0])
    decomp_depth = num_levels - 1  # n in [cA_n, cD_n, ..., cD_1]
    merged_coeffs: list[NDArray[Any]] = []

    for level_idx in range(num_levels):
        if level_idx == 0:
            downsample_factor = 2**decomp_depth  # cA_n
        else:
            downsample_factor = 2 ** (decomp_depth - level_idx + 1)  # cD band
        level_overlap = boundary_overlap // downsample_factor

        if len(coeffs_list) == 1:
            # Single chunk: nothing to trim or stitch.
            merged_coeffs.append(coeffs_list[0][level_idx])
            continue

        merged_level_coeffs: list[NDArray[Any]] = []
        last_idx = len(coeffs_list) - 1

        for chunk_idx, chunk_coeffs in enumerate(coeffs_list):
            level_coeffs = chunk_coeffs[level_idx]

            if chunk_idx == 0:
                # First chunk: keep everything except the right overlap.
                if level_overlap > 0 and len(level_coeffs) > level_overlap:
                    merged_level_coeffs.append(level_coeffs[:-level_overlap])
                else:
                    merged_level_coeffs.append(level_coeffs)
            elif chunk_idx == last_idx:
                # Last chunk: trim the left overlap.
                if level_overlap > 0 and len(level_coeffs) > level_overlap:
                    merged_level_coeffs.append(level_coeffs[level_overlap:])
                else:
                    # Overlap exceeds this short chunk; keep a center portion.
                    center_start = max(0, level_overlap // 2)
                    merged_level_coeffs.append(level_coeffs[center_start:])
            else:
                # Middle chunks: trim both sides.
                if level_overlap > 0 and len(level_coeffs) > 2 * level_overlap:
                    merged_level_coeffs.append(
                        level_coeffs[level_overlap:-level_overlap]
                    )
                else:
                    # Overlap exceeds this short chunk; keep a center portion.
                    center_start = max(0, level_overlap // 2)
                    center_end = max(
                        center_start + 1, len(level_coeffs) - level_overlap // 2
                    )
                    merged_level_coeffs.append(level_coeffs[center_start:center_end])

        # merged_level_coeffs is non-empty here: coeffs_list is non-empty and
        # every branch above appends (the original's fallback was dead code,
        # confirmed by the coverage report).
        merged_coeffs.append(np.concatenate(merged_level_coeffs))

    return merged_coeffs

275 

276 

277def _generate_chunks( 

278 file_path: Path, 

279 total_samples: int, 

280 chunk_size: int, 

281 boundary_overlap: int, 

282 dtype: type, 

283) -> Iterator[NDArray[Any]]: 

284 """Generate overlapping chunks from file. 

285 

286 Args: 

287 file_path: Path to binary file. 

288 total_samples: Total number of samples in file. 

289 chunk_size: Samples per chunk. 

290 boundary_overlap: Overlap samples between chunks. 

291 dtype: NumPy dtype for data. 

292 

293 Yields: 

294 Chunk arrays with boundary overlap. 

295 """ 

296 offset = 0 

297 

298 with open(file_path, "rb") as f: 

299 while offset < total_samples: 

300 # Calculate chunk boundaries 

301 chunk_start = max(0, offset - boundary_overlap) 

302 chunk_end = min(total_samples, offset + chunk_size) 

303 chunk_len = chunk_end - chunk_start 

304 

305 # Seek and read 

306 f.seek(chunk_start * dtype().itemsize) 

307 chunk_data: NDArray[np.float64] = np.fromfile(f, dtype=dtype, count=chunk_len) 

308 

309 if len(chunk_data) == 0: 309 ↛ 310line 309 didn't jump to line 310 because the condition on line 309 was never true

310 break 

311 

312 yield chunk_data 

313 

314 # Advance offset 

315 offset += chunk_size 

316 

317 

def cwt_chunked_generator(
    file_path: str | Path,
    scales: Sequence[float],
    wavelet: str = "morl",
    *,
    chunk_size: int | float = 1e6,
    **kwargs: Any,
) -> Iterator[tuple[NDArray[Any], NDArray[Any]]]:
    """Generator version that yields CWT results chunk by chunk.

    Useful for streaming pipelines that process coefficients per chunk
    instead of materializing the full transform in memory.

    Args:
        file_path: Path to signal file.
        scales: Scales for CWT.
        wavelet: Wavelet name.
        chunk_size: Chunk size in samples.
        **kwargs: Optional 'dtype', 'overlap_factor', and 'sample_rate'
            (defaults: 'float32', 2.0, 1.0); other keys are ignored.

    Yields:
        Tuples of (coefficients, frequencies) for each chunk.

    Raises:
        ImportError: If pywt (PyWavelets) is not installed.

    Example:
        >>> for coeffs_chunk, freqs in cwt_chunked_generator('file.bin', scales=[1, 2, 4]):
        ...     # Process each chunk separately
        ...     print(f"Chunk shape: {coeffs_chunk.shape}")
    """
    try:
        import pywt
    except ImportError as e:
        raise ImportError(
            "pywt (PyWavelets) is required for wavelet transforms. "
            "Install with: pip install PyWavelets"
        ) from e

    samples_per_chunk = int(chunk_size)
    scale_values = np.asarray(scales)

    # Resolve optional settings from kwargs with the same defaults as the
    # eager cwt_chunked() variant.
    dtype_name = kwargs.get("dtype", "float32")
    np_dtype = np.float32 if dtype_name == "float32" else np.float64
    sample_bytes = 4 if dtype_name == "float32" else 8

    path = Path(file_path)
    total_samples = path.stat().st_size // sample_bytes

    # Left overlap sized to the widest wavelet's region of influence.
    overlap = int(kwargs.get("overlap_factor", 2.0) * np.max(scale_values))
    fs = kwargs.get("sample_rate", 1.0)

    for segment in _generate_chunks(
        path, total_samples, samples_per_chunk, overlap, np_dtype
    ):
        # pywt.cwt returns (coefficients, frequencies); yield the pair as-is.
        yield pywt.cwt(segment, scale_values, wavelet, sampling_period=1 / fs)

385 

386 

# Public API of this module: eager and streaming chunked wavelet transforms.
__all__ = [
    "cwt_chunked",
    "cwt_chunked_generator",
    "dwt_chunked",
]