Coverage for src / tracekit / loaders / numpy_loader.py: 91%

146 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""NumPy NPZ file loader. 

2 

3This module provides loading of waveform data from NumPy .npz archive files. 

4 

5 

6Example: 

7 >>> from tracekit.loaders.numpy_loader import load_npz 

8 >>> trace = load_npz("waveform.npz") 

9 >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz") 

10""" 

11 

12from __future__ import annotations 

13 

14from pathlib import Path 

15from typing import TYPE_CHECKING, Any 

16 

17import numpy as np 

18from numpy.typing import NDArray 

19 

20from tracekit.core.exceptions import FormatError, LoaderError 

21from tracekit.core.types import TraceMetadata, WaveformTrace 

22 

23if TYPE_CHECKING: 

24 from os import PathLike 

25 

26 

# Array names tried, in list order, when auto-detecting the waveform array.
# They are also used as substrings when classifying arrays by name.
DATA_ARRAY_NAMES = ["data", "waveform", "signal", "samples", "y", "voltage"]

# Key names tried, in list order, when looking up each metadata value.
# Lookups try the exact key first and then a case-insensitive match.
SAMPLE_RATE_KEYS = ["sample_rate", "samplerate", "fs", "sampling_rate", "rate"]
VERTICAL_SCALE_KEYS = ["vertical_scale", "v_scale", "scale", "volts_per_div"]
VERTICAL_OFFSET_KEYS = ["vertical_offset", "v_offset", "offset"]

34 

35 

def load_npz(
    path: str | PathLike[str],
    *,
    channel: str | int | None = None,
    sample_rate: float | None = None,
    mmap: bool = False,
) -> WaveformTrace:
    """Load waveform data from a NumPy NPZ archive.

    Extracts a waveform array and its metadata from an NPZ file. Unless
    ``channel`` is given, the function looks for common array names such as
    'data', 'waveform' or 'signal'.

    Args:
        path: Path to the .npz file.
        channel: Specific array name or index to load. If None, auto-detects.
        sample_rate: Sample rate in Hz. When given, it takes precedence over
            any value stored in the file. When neither is available the
            sample rate defaults to 1e6 (1 MSa/s).
        mmap: If True, ``mmap_mode="r"`` is passed to ``numpy.load``. NumPy
            does not memory-map the members of a zipped .npz archive, so the
            selected array is normally still read into memory and converted
            to float64. The flag is kept for API compatibility.

    Returns:
        WaveformTrace containing the waveform data and metadata.

    Raises:
        LoaderError: If the file does not exist or cannot be opened.
        FormatError: If no waveform array is found or it is not numeric.

    Example:
        >>> trace = load_npz("waveform.npz")
        >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")

        >>> # Load specific channel from multi-channel file
        >>> trace = load_npz("multi.npz", channel="ch1")

    Security Warning:
        NPZ files may contain pickled Python objects. Only load NPZ files from
        trusted sources. Loading malicious NPZ files could execute arbitrary
        code. For untrusted data, prefer formats like plain NumPy arrays (.npy),
        CSV, or HDF5.
    """
    path = Path(path)

    if not path.exists():
        raise LoaderError(
            "File not found",
            file_path=str(path),
        )

    try:
        # SECURITY: allow_pickle=True is needed to read object arrays (for
        # example a pickled ``metadata`` dict), but unpickling untrusted data
        # can execute arbitrary code. See the docstring warning.
        npz = np.load(path, allow_pickle=True, mmap_mode="r" if mmap else None)
    except Exception as e:
        raise LoaderError(
            "Failed to load NPZ file",
            file_path=str(path),
            details=str(e),
        ) from e

    # The context manager closes the archive on every exit path.
    with npz:
        data_array = _find_data_array(npz, channel)

        if data_array is None:
            # Describe what was actually looked for, so a failed explicit
            # channel request is not reported as a failed auto-detection.
            if isinstance(channel, str):
                expected = f"Array named: {channel}"
            elif isinstance(channel, int):
                expected = f"Array at channel index: {channel}"
            else:
                expected = f"Array named: {', '.join(DATA_ARRAY_NAMES)}"
            available = list(npz.keys())
            raise FormatError(
                "No waveform data found in NPZ file",
                file_path=str(path),
                expected=expected,
                got=f"Arrays: {', '.join(available)}",
            )

        # A float64 memmap can be used as-is; everything else is converted
        # (and therefore copied) to float64.
        already_mapped_float64 = (
            mmap
            and isinstance(data_array, np.memmap)
            and data_array.dtype == np.float64
        )
        if already_mapped_float64:
            data = data_array
        else:
            try:
                data = data_array.astype(np.float64)
            except (ValueError, TypeError) as e:
                raise FormatError(
                    "Data array is not numeric",
                    file_path=str(path),
                    expected="Numeric dtype (int, float)",
                    got=f"{data_array.dtype}",
                ) from e

        # Extract metadata stored alongside the waveform.
        detected_sample_rate = _find_metadata_value(npz, SAMPLE_RATE_KEYS)
        detected_vertical_scale = _find_metadata_value(npz, VERTICAL_SCALE_KEYS)
        detected_vertical_offset = _find_metadata_value(npz, VERTICAL_OFFSET_KEYS)

        # An explicit argument wins over the file; fall back to 1 MSa/s.
        if sample_rate is not None:
            detected_sample_rate = sample_rate
        elif detected_sample_rate is None:
            detected_sample_rate = 1e6

        metadata = TraceMetadata(
            sample_rate=float(detected_sample_rate),
            vertical_scale=float(detected_vertical_scale)
            if detected_vertical_scale is not None
            else None,
            vertical_offset=float(detected_vertical_offset)
            if detected_vertical_offset is not None
            else None,
            source_file=str(path),
            channel_name=_get_channel_name(npz, channel),
        )

        return WaveformTrace(data=data, metadata=metadata)

165 

166 

167def _find_data_array( 

168 npz: np.lib.npyio.NpzFile, 

169 channel: str | int | None, 

170) -> NDArray[np.float64] | None: 

171 """Find the waveform data array in the NPZ file. 

172 

173 Args: 

174 npz: Loaded NPZ file. 

175 channel: Specific channel name or index. 

176 

177 Returns: 

178 Waveform data array or None if not found. 

179 """ 

180 keys = list(npz.keys()) 

181 

182 # If channel specified by name 

183 if isinstance(channel, str): 

184 if channel in keys: 

185 return npz[channel] 

186 # Try case-insensitive match 

187 channel_lower = channel.lower() 

188 for key in keys: 

189 if key.lower() == channel_lower: 

190 return npz[key] 

191 return None 

192 

193 # If channel specified by index 

194 if isinstance(channel, int): 

195 # Find numeric-suffixed arrays (ch1, ch2, etc.) 

196 channel_arrays = [k for k in keys if _is_channel_array(k)] 

197 if channel_arrays and channel < len(channel_arrays): 

198 return npz[sorted(channel_arrays)[channel]] 

199 # Or use nth array 

200 data_arrays = [k for k in keys if _is_data_array(k)] 

201 if data_arrays and channel < len(data_arrays): 

202 return npz[data_arrays[channel]] 

203 return None 

204 

205 # Auto-detect: look for common data array names 

206 for name in DATA_ARRAY_NAMES: 

207 if name in keys: 

208 return npz[name] 

209 # Try case-insensitive match 

210 name_lower = name.lower() 

211 for key in keys: 

212 if key.lower() == name_lower: 

213 return npz[key] 

214 

215 # Fall back to first 1D or 2D array 

216 for key in keys: 

217 arr = npz[key] 

218 if isinstance(arr, np.ndarray) and arr.ndim in (1, 2): 

219 # Skip metadata scalars 

220 if arr.size > 10: # Arbitrary threshold 

221 return arr.ravel() if arr.ndim == 2 else arr 

222 

223 return None 

224 

225 

226def _is_channel_array(name: str) -> bool: 

227 """Check if array name looks like a channel (ch1, channel1, etc.).""" 

228 name_lower = name.lower() 

229 return ( 

230 name_lower.startswith("ch") 

231 or name_lower.startswith("channel") 

232 or name_lower.startswith("analog") 

233 ) 

234 

235 

def _is_data_array(name: str) -> bool:
    """Return True when the name contains any entry of DATA_ARRAY_NAMES.

    Matching is a case-insensitive substring test, so short entries such as
    "y" match any name containing that letter.
    """
    lowered = name.lower()
    for candidate in DATA_ARRAY_NAMES:
        if candidate in lowered:
            return True
    return False

240 

241 

242def _find_metadata_value( 

243 npz: np.lib.npyio.NpzFile, 

244 key_names: list[str], 

245) -> float | None: 

246 """Find a metadata value by trying multiple key names. 

247 

248 Args: 

249 npz: Loaded NPZ file. 

250 key_names: List of possible key names to try. 

251 

252 Returns: 

253 Metadata value or None if not found. 

254 """ 

255 keys = list(npz.keys()) 

256 

257 for name in key_names: 

258 # Exact match 

259 if name in keys: 

260 value = npz[name] 

261 if np.isscalar(value): 261 ↛ 262line 261 didn't jump to line 262 because the condition on line 261 was never true

262 return float(value) # type: ignore[arg-type] 

263 elif isinstance(value, np.ndarray) and value.size == 1: 263 ↛ 267line 263 didn't jump to line 267 because the condition on line 263 was always true

264 return float(value.item()) # type: ignore[arg-type] 

265 

266 # Case-insensitive match 

267 name_lower = name.lower() 

268 for key in keys: 

269 if key.lower() == name_lower: 

270 value = npz[key] 

271 if np.isscalar(value): 271 ↛ 272line 271 didn't jump to line 272 because the condition on line 271 was never true

272 return float(value) # type: ignore[arg-type] 

273 elif isinstance(value, np.ndarray) and value.size == 1: 273 ↛ 268line 273 didn't jump to line 268 because the condition on line 273 was always true

274 return float(value.item()) # type: ignore[arg-type] 

275 

276 # Check for metadata dict 

277 if "metadata" in keys: 

278 metadata = npz["metadata"] 

279 if isinstance(metadata, np.ndarray): 279 ↛ 289line 279 didn't jump to line 289 because the condition on line 279 was always true

280 try: 

281 meta_dict = metadata.item() 

282 if isinstance(meta_dict, dict): 282 ↛ 289line 282 didn't jump to line 289 because the condition on line 282 was always true

283 for name in key_names: 283 ↛ 289line 283 didn't jump to line 289 because the loop on line 283 didn't complete

284 if name in meta_dict: 284 ↛ 283line 284 didn't jump to line 283 because the condition on line 284 was always true

285 return float(meta_dict[name]) 

286 except (ValueError, TypeError): 

287 pass 

288 

289 return None 

290 

291 

292def _get_channel_name( 

293 npz: np.lib.npyio.NpzFile, 

294 channel: str | int | None, 

295) -> str: 

296 """Get a channel name for the loaded data. 

297 

298 Args: 

299 npz: Loaded NPZ file. 

300 channel: Channel specification. 

301 

302 Returns: 

303 Channel name string. 

304 """ 

305 if isinstance(channel, str): 

306 return channel 

307 elif isinstance(channel, int): 

308 return f"CH{channel + 1}" 

309 

310 # Try to find channel name in metadata 

311 keys = list(npz.keys()) 

312 if "channel_name" in keys: 

313 value = npz["channel_name"] 

314 # NPZ values are always ndarrays 

315 return str(value.item()) 

316 

317 return "CH1" 

318 

319 

def list_arrays(path: str | PathLike[str]) -> list[str]:
    """List all arrays in an NPZ file.

    Only the member names are read; no array is loaded.

    Args:
        path: Path to the NPZ file.

    Returns:
        List of array names.

    Raises:
        LoaderError: If file not found or cannot be read.

    Example:
        >>> arrays = list_arrays("multi.npz")
        >>> print(arrays)
        ['ch1', 'ch2', 'sample_rate']
    """
    path = Path(path)
    if not path.exists():
        raise LoaderError("File not found", file_path=str(path))

    try:
        # Listing names never loads array contents, so pickle support is not
        # needed here and is left disabled.
        with np.load(path, allow_pickle=False) as npz:
            return list(npz.keys())
    except Exception as e:
        raise LoaderError(
            "Failed to read NPZ file",
            file_path=str(path),
            details=str(e),
        ) from e

350 

351 

def load_raw_binary(
    path: str | PathLike[str],
    *,
    dtype: str = "float32",
    sample_rate: float = 1e6,
    mmap: bool = False,
    offset: int = 0,
    count: int = -1,
) -> WaveformTrace:
    """Load waveform data from a raw binary file.

    Loads raw binary waveform data with optional memory mapping for
    files larger than available RAM.

    Args:
        path: Path to the raw binary file.
        dtype: Data type of samples (float32, float64, int16, etc.).
        sample_rate: Sample rate in Hz.
        mmap: If True, return a read-only memory map in the file's own dtype
            (no float64 conversion, which would read the whole file). If
            False, the samples are read into memory and converted to float64.
        offset: Number of elements to skip at start of file. Applied the same
            way in both modes.
        count: Number of elements to read (-1 = all). In mmap mode any value
            <= 0 maps the rest of the file.

    Returns:
        WaveformTrace containing the waveform data and metadata.

    Raises:
        LoaderError: If the file cannot be loaded.

    Example:
        >>> # Load entire file into memory
        >>> trace = load_raw_binary("signal.bin", dtype="float32", sample_rate=1e9)

        >>> # Memory-map large file
        >>> trace = load_raw_binary("large.bin", dtype="float32", sample_rate=1e9, mmap=True)
        >>> # Access subset: trace.data[1000:2000]

        >>> # Load only portion of file
        >>> trace = load_raw_binary("signal.bin", dtype="int16", offset=1000, count=10000)
    """
    path = Path(path)

    if not path.exists():
        raise LoaderError("File not found", file_path=str(path))

    try:
        # Both np.memmap and np.fromfile take their offset in bytes, while the
        # public ``offset`` argument counts elements.
        byte_offset = offset * np.dtype(dtype).itemsize

        data: NDArray[np.float64] | np.memmap[Any, np.dtype[Any]]
        if mmap:
            # Memory-mapped array (stays on disk, keeps its on-disk dtype).
            data = np.memmap(
                path,
                dtype=dtype,
                mode="r",
                offset=byte_offset,
                shape=(count,) if count > 0 else None,
            )
        else:
            # Load into memory and convert to float64.
            data_raw = np.fromfile(
                path, dtype=dtype, count=count, offset=byte_offset
            )
            data = data_raw.astype(np.float64)

        metadata = TraceMetadata(
            sample_rate=sample_rate,
            source_file=str(path),
            channel_name="RAW",
        )

        return WaveformTrace(data=data, metadata=metadata)

    except Exception as e:
        raise LoaderError(
            "Failed to load raw binary file",
            file_path=str(path),
            details=str(e),
        ) from e

434 

435 

# Names exported by ``from ... import *``; underscore-prefixed helpers are not listed.
__all__ = ["list_arrays", "load_npz", "load_raw_binary"]