# Coverage for src/tracekit/loaders/numpy_loader.py: 91% (146 statements)
# Report generated by coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
"""NumPy NPZ file loader.

This module provides loading of waveform data from NumPy .npz archive files.

Example:
    >>> from tracekit.loaders.numpy_loader import load_npz
    >>> trace = load_npz("waveform.npz")
    >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")
"""
12from __future__ import annotations
14from pathlib import Path
15from typing import TYPE_CHECKING, Any
17import numpy as np
18from numpy.typing import NDArray
20from tracekit.core.exceptions import FormatError, LoaderError
21from tracekit.core.types import TraceMetadata, WaveformTrace
23if TYPE_CHECKING:
24 from os import PathLike
# Array names probed, in priority order, when auto-detecting the waveform array.
DATA_ARRAY_NAMES = [
    "data",
    "waveform",
    "signal",
    "samples",
    "y",
    "voltage",
]

# Key names probed, in priority order, for each scalar metadata field.
SAMPLE_RATE_KEYS = [
    "sample_rate",
    "samplerate",
    "fs",
    "sampling_rate",
    "rate",
]
VERTICAL_SCALE_KEYS = [
    "vertical_scale",
    "v_scale",
    "scale",
    "volts_per_div",
]
VERTICAL_OFFSET_KEYS = [
    "vertical_offset",
    "v_offset",
    "offset",
]
def load_npz(
    path: str | PathLike[str],
    *,
    channel: str | int | None = None,
    sample_rate: float | None = None,
    mmap: bool = False,
) -> WaveformTrace:
    """Read one waveform and its metadata out of a NumPy ``.npz`` archive.

    The waveform array is located by ``_find_data_array`` (explicit channel
    name or index, otherwise the common names in ``DATA_ARRAY_NAMES``), and
    sample rate, vertical scale and vertical offset are looked up with
    ``_find_metadata_value``.

    Args:
        path: Path to the .npz file.
        channel: Array name or index to load. If None, the array is
            auto-detected.
        sample_rate: Sample rate in Hz. When given it takes precedence over
            any value stored in the file; when neither is available the
            rate defaults to 1e6 (1 MSa/s).
        mmap: If True, ``mmap_mode="r"`` is passed to ``np.load``. The array
            is only left unconverted when ``np.load`` hands back an
            ``np.memmap`` that is already float64; every other array is
            copied to float64.
            NOTE(review): confirm that ``np.load`` honours ``mmap_mode`` for
            .npz archives at all; the coverage annotations on the original
            show the memmap branch was never taken.

    Returns:
        WaveformTrace containing the waveform data and metadata.

    Raises:
        LoaderError: If the file does not exist or cannot be opened.
        FormatError: If no waveform array is found, or the array cannot be
            converted to float64.

    Example:
        >>> trace = load_npz("waveform.npz")
        >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")

        >>> # Load specific channel from multi-channel file
        >>> trace = load_npz("multi.npz", channel="ch1")

    Security Warning:
        The archive is opened with ``allow_pickle=True``, so NPZ files may
        contain pickled Python objects. Only load NPZ files from trusted
        sources; a malicious file could execute arbitrary code. For
        untrusted data, prefer plain NumPy arrays (.npy), CSV, or HDF5.
    """
    npz_path = Path(path)
    source = str(npz_path)

    if not npz_path.exists():
        raise LoaderError("File not found", file_path=source)

    mmap_mode = "r" if mmap else None
    try:
        # SECURITY: allow_pickle=True can unpickle objects stored in the file.
        npz = np.load(npz_path, allow_pickle=True, mmap_mode=mmap_mode)
    except Exception as exc:
        raise LoaderError(
            "Failed to load NPZ file",
            file_path=source,
            details=str(exc),
        ) from exc

    try:
        raw = _find_data_array(npz, channel)
        if raw is None:
            present = list(npz.keys())
            raise FormatError(
                "No waveform data found in NPZ file",
                file_path=source,
                expected=f"Array named: {', '.join(DATA_ARRAY_NAMES)}",
                got=f"Arrays: {', '.join(present)}",
            )

        # A float64 memmap is passed through untouched; anything else is
        # copied into a float64 array.
        already_float64_memmap = (
            mmap and isinstance(raw, np.memmap) and raw.dtype == np.float64
        )
        if already_float64_memmap:
            data = raw
        else:
            try:
                data = raw.astype(np.float64)
            except (ValueError, TypeError) as exc:
                raise FormatError(
                    "Data array is not numeric",
                    file_path=source,
                    expected="Numeric dtype (int, float)",
                    got=f"{raw.dtype}",
                ) from exc

        # All three lookups run before the override is applied.
        file_rate = _find_metadata_value(npz, SAMPLE_RATE_KEYS)
        scale = _find_metadata_value(npz, VERTICAL_SCALE_KEYS)
        offset = _find_metadata_value(npz, VERTICAL_OFFSET_KEYS)

        if sample_rate is not None:
            rate = sample_rate
        elif file_rate is not None:
            rate = file_rate
        else:
            rate = 1e6  # Default to 1 MSa/s

        metadata = TraceMetadata(
            sample_rate=float(rate),
            vertical_scale=None if scale is None else float(scale),
            vertical_offset=None if offset is None else float(offset),
            source_file=source,
            channel_name=_get_channel_name(npz, channel),
        )
        return WaveformTrace(data=data, metadata=metadata)
    finally:
        # The archive handle is released on success and on every error path.
        npz.close()
def _channel_sort_key(name: str) -> tuple[str, int, str]:
    """Sort key ordering names by stem, then by trailing number (ch2 < ch10)."""
    stem = name.rstrip("0123456789")
    digits = name[len(stem):]
    # Names without a numeric suffix sort before numbered ones of the same stem;
    # the full name breaks ties such as "ch1" vs "ch01".
    return (stem, int(digits) if digits else -1, name)


def _find_data_array(
    npz: np.lib.npyio.NpzFile,
    channel: str | int | None,
) -> NDArray[np.float64] | None:
    """Find the waveform data array in the NPZ file.

    Lookup strategy depends on ``channel``:

    * ``str``: exact key match, then a case-insensitive match.
    * ``int``: index into the channel-like keys (per ``_is_channel_array``)
      ordered by their numeric suffix, so index 1 is ``ch2`` even when
      ``ch10`` exists; if that does not resolve, index into the keys that
      look like data arrays (per ``_is_data_array``) in file order.
    * ``None``: first key matching ``DATA_ARRAY_NAMES`` (exact, then
      case-insensitive, in list order); otherwise the first 1-D or 2-D array
      with more than 10 elements.

    Arrays found by name or index are returned as stored. Only the final
    fallback flattens a 2-D array with ``ravel()``.

    Args:
        npz: Loaded NPZ file.
        channel: Specific channel name or index, or None to auto-detect.

    Returns:
        Waveform data array or None if not found.
    """
    keys = list(npz.keys())

    # Channel given by name.
    if isinstance(channel, str):
        if channel in keys:
            return npz[channel]
        wanted = channel.lower()
        for key in keys:
            if key.lower() == wanted:
                return npz[key]
        return None

    # Channel given by index.
    if isinstance(channel, int):
        # Numeric-aware ordering: plain sorted() would place "ch10" before "ch2".
        channel_arrays = sorted(
            (key for key in keys if _is_channel_array(key)),
            key=_channel_sort_key,
        )
        if channel_arrays and channel < len(channel_arrays):
            return npz[channel_arrays[channel]]
        # Otherwise treat the index as "nth data-like array" in file order.
        data_arrays = [key for key in keys if _is_data_array(key)]
        if data_arrays and channel < len(data_arrays):
            return npz[data_arrays[channel]]
        return None

    # Auto-detect: common data array names, in priority order.
    for name in DATA_ARRAY_NAMES:
        if name in keys:
            return npz[name]
        wanted = name.lower()
        for key in keys:
            if key.lower() == wanted:
                return npz[key]

    # Last resort: first 1-D/2-D array large enough not to be a metadata scalar.
    for key in keys:
        arr = npz[key]
        if isinstance(arr, np.ndarray) and arr.ndim in (1, 2):
            if arr.size > 10:  # Arbitrary threshold separating data from metadata
                return arr.ravel() if arr.ndim == 2 else arr

    return None
226def _is_channel_array(name: str) -> bool:
227 """Check if array name looks like a channel (ch1, channel1, etc.)."""
228 name_lower = name.lower()
229 return (
230 name_lower.startswith("ch")
231 or name_lower.startswith("channel")
232 or name_lower.startswith("analog")
233 )
def _is_data_array(name: str) -> bool:
    """Report whether ``name`` contains any of the common data array names.

    The comparison is a case-insensitive substring test against each entry
    of ``DATA_ARRAY_NAMES``.
    """
    lowered = name.lower()
    for candidate in DATA_ARRAY_NAMES:
        if candidate in lowered:
            return True
    return False
242def _find_metadata_value(
243 npz: np.lib.npyio.NpzFile,
244 key_names: list[str],
245) -> float | None:
246 """Find a metadata value by trying multiple key names.
248 Args:
249 npz: Loaded NPZ file.
250 key_names: List of possible key names to try.
252 Returns:
253 Metadata value or None if not found.
254 """
255 keys = list(npz.keys())
257 for name in key_names:
258 # Exact match
259 if name in keys:
260 value = npz[name]
261 if np.isscalar(value): 261 ↛ 262line 261 didn't jump to line 262 because the condition on line 261 was never true
262 return float(value) # type: ignore[arg-type]
263 elif isinstance(value, np.ndarray) and value.size == 1: 263 ↛ 267line 263 didn't jump to line 267 because the condition on line 263 was always true
264 return float(value.item()) # type: ignore[arg-type]
266 # Case-insensitive match
267 name_lower = name.lower()
268 for key in keys:
269 if key.lower() == name_lower:
270 value = npz[key]
271 if np.isscalar(value): 271 ↛ 272line 271 didn't jump to line 272 because the condition on line 271 was never true
272 return float(value) # type: ignore[arg-type]
273 elif isinstance(value, np.ndarray) and value.size == 1: 273 ↛ 268line 273 didn't jump to line 268 because the condition on line 273 was always true
274 return float(value.item()) # type: ignore[arg-type]
276 # Check for metadata dict
277 if "metadata" in keys:
278 metadata = npz["metadata"]
279 if isinstance(metadata, np.ndarray): 279 ↛ 289line 279 didn't jump to line 289 because the condition on line 279 was always true
280 try:
281 meta_dict = metadata.item()
282 if isinstance(meta_dict, dict): 282 ↛ 289line 282 didn't jump to line 289 because the condition on line 282 was always true
283 for name in key_names: 283 ↛ 289line 283 didn't jump to line 289 because the loop on line 283 didn't complete
284 if name in meta_dict: 284 ↛ 283line 284 didn't jump to line 283 because the condition on line 284 was always true
285 return float(meta_dict[name])
286 except (ValueError, TypeError):
287 pass
289 return None
292def _get_channel_name(
293 npz: np.lib.npyio.NpzFile,
294 channel: str | int | None,
295) -> str:
296 """Get a channel name for the loaded data.
298 Args:
299 npz: Loaded NPZ file.
300 channel: Channel specification.
302 Returns:
303 Channel name string.
304 """
305 if isinstance(channel, str):
306 return channel
307 elif isinstance(channel, int):
308 return f"CH{channel + 1}"
310 # Try to find channel name in metadata
311 keys = list(npz.keys())
312 if "channel_name" in keys:
313 value = npz["channel_name"]
314 # NPZ values are always ndarrays
315 return str(value.item())
317 return "CH1"
def list_arrays(path: str | PathLike[str]) -> list[str]:
    """List all arrays in an NPZ file.

    Only the archive's member names are read; no array is loaded.

    Args:
        path: Path to the NPZ file.

    Returns:
        List of array names.

    Raises:
        LoaderError: If file not found or cannot be read.

    Example:
        >>> arrays = list_arrays("multi.npz")
        >>> print(arrays)
        ['ch1', 'ch2', 'sample_rate']
    """
    path = Path(path)
    if not path.exists():
        raise LoaderError("File not found", file_path=str(path))

    try:
        # Listing names never needs unpickling. With allow_pickle=True,
        # np.load would unpickle a file that is neither .npz nor .npy,
        # which can execute arbitrary code; allow_pickle=False makes it
        # raise instead, and that is reported as a LoaderError below.
        with np.load(path, allow_pickle=False) as npz:
            return list(npz.keys())
    except Exception as e:
        raise LoaderError(
            "Failed to read NPZ file",
            file_path=str(path),
            details=str(e),
        ) from e
def load_raw_binary(
    path: str | PathLike[str],
    *,
    dtype: str = "float32",
    sample_rate: float = 1e6,
    mmap: bool = False,
    offset: int = 0,
    count: int = -1,
) -> WaveformTrace:
    """Load waveform data from a raw binary file.

    Loads raw binary waveform data with optional memory mapping for
    files larger than available RAM.

    Args:
        path: Path to the raw binary file.
        dtype: Data type of samples (float32, float64, int16, etc.).
        sample_rate: Sample rate in Hz.
        mmap: If True, use memory mapping to avoid loading entire file.
            The mapped array keeps the on-disk ``dtype``; it is not
            converted to float64 here because that would read the whole
            file into memory.
        offset: Number of elements to skip at start of file. It is
            converted to a byte offset using the item size of ``dtype``
            in both the memory-mapped and the in-memory path.
        count: Number of elements to read (-1 = all). With ``mmap=True``
            any value that is not positive maps everything after
            ``offset``.

    Returns:
        WaveformTrace containing the waveform data and metadata. Without
        ``mmap`` the data is a float64 array.

    Raises:
        LoaderError: If the file does not exist or cannot be loaded.

    Example:
        >>> # Load entire file into memory
        >>> trace = load_raw_binary("signal.bin", dtype="float32", sample_rate=1e9)

        >>> # Memory-map large file
        >>> trace = load_raw_binary("large.bin", dtype="float32", sample_rate=1e9, mmap=True)
        >>> # Access subset: trace.data[1000:2000]

        >>> # Load only portion of file
        >>> trace = load_raw_binary("signal.bin", dtype="int16", offset=1000, count=10000)
    """
    path = Path(path)

    if not path.exists():
        raise LoaderError("File not found", file_path=str(path))

    try:
        sample_dtype = np.dtype(dtype)
        # Both np.memmap and np.fromfile take their offset in BYTES, while
        # this function's ``offset`` is expressed in elements.
        byte_offset = offset * sample_dtype.itemsize

        data: NDArray[np.float64] | np.memmap[Any, np.dtype[Any]]
        if mmap:
            # Memory-mapped array (stays on disk, original dtype preserved).
            data = np.memmap(
                path,
                dtype=sample_dtype,
                mode="r",
                offset=byte_offset,
                shape=(count,) if count > 0 else None,
            )
        else:
            # Read into memory, then normalise to float64.
            data_raw = np.fromfile(
                path,
                dtype=sample_dtype,
                count=count,
                offset=byte_offset,
            )
            data = data_raw.astype(np.float64)

        metadata = TraceMetadata(
            sample_rate=sample_rate,
            source_file=str(path),
            channel_name="RAW",
        )

        return WaveformTrace(data=data, metadata=metadata)

    except Exception as e:
        raise LoaderError(
            "Failed to load raw binary file",
            file_path=str(path),
            details=str(e),
        ) from e
436__all__ = ["list_arrays", "load_npz", "load_raw_binary"]