Coverage for src / tracekit / loaders / wav.py: 96%
73 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""WAV audio file loader.
3This module provides loading of WAV audio files using scipy.io.wavfile.
4WAV files are useful for audio signal analysis and can contain
5oscilloscope data recorded as audio.
8Example:
9 >>> from tracekit.loaders.wav import load_wav
10 >>> trace = load_wav("recording.wav")
11 >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")
12"""
14from __future__ import annotations
16from pathlib import Path
17from typing import TYPE_CHECKING
19import numpy as np
20from scipy.io import wavfile
22from tracekit.core.exceptions import FormatError, LoaderError
23from tracekit.core.types import TraceMetadata, WaveformTrace
25if TYPE_CHECKING:
26 from os import PathLike
def load_wav(
    path: str | PathLike[str],
    *,
    channel: int | str | None = None,
    normalize: bool = True,
) -> WaveformTrace:
    """Load a WAV audio file as a single-channel waveform trace.

    The file is read with ``scipy.io.wavfile.read``. For multichannel files
    one channel (or the average of all channels) is selected, the samples
    are converted to float64 and, optionally, scaled to the [-1, 1] range.

    Args:
        path: Path to the WAV file.
        channel: Channel to load from a multichannel file. Can be:
            - An integer index (Python or NumPy integer), 0-based.
            - "left", "l" or "0": First channel.
            - "right", "r" or "1": Second channel.
            - "mono", "mix" or "avg": Average of all channels.
            - None: First channel.
            String specifiers are case-insensitive. For mono files only a
            non-zero integer index is rejected; string specifiers are
            ignored and the single channel is returned.
        normalize: If True (the default), scale samples to [-1, 1].
            Integer PCM is divided by its full-scale value (8-bit unsigned
            data is first re-centred around zero). Float data is left
            untouched unless its peak magnitude exceeds 1.0, in which case
            it is divided by that peak.

    Returns:
        WaveformTrace holding the float64 samples. Its metadata carries the
        sample rate, source file, channel name and a ``trigger_info`` dict
        with ``original_dtype``, ``n_channels`` and ``normalized``.

    Raises:
        LoaderError: If the file does not exist, cannot be read, or the
            requested channel is out of range / not a valid specifier.
        FormatError: If scipy rejects the file with a ValueError (not a
            valid WAV file).

    Example:
        >>> trace = load_wav("recording.wav")
        >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")
        >>> print(f"Duration: {trace.duration:.2f} seconds")

        >>> # Load right channel of stereo file
        >>> trace = load_wav("stereo.wav", channel="right")

    References:
        WAV file format: https://en.wikipedia.org/wiki/WAV
    """
    path = Path(path)

    if not path.exists():
        raise LoaderError(
            "File not found",
            file_path=str(path),
        )

    try:
        sample_rate, data = wavfile.read(str(path))
    except ValueError as e:
        # scipy signals malformed/unsupported WAV content with ValueError.
        raise FormatError(
            "Invalid WAV file format",
            file_path=str(path),
            expected="Valid WAV audio file",
        ) from e
    except Exception as e:
        raise LoaderError(
            "Failed to read WAV file",
            file_path=str(path),
            details=str(e),
        ) from e

    audio_data, channel_name = _select_wav_channel(data, channel, path)

    # Work in float64 from here on; scaling depends on the *file's* dtype,
    # not on the dtype of the selected/averaged samples.
    audio_data = audio_data.astype(np.float64)
    if normalize:
        audio_data = _normalize_wav_samples(audio_data, data.dtype)

    metadata = TraceMetadata(
        sample_rate=float(sample_rate),
        source_file=str(path),
        channel_name=channel_name,
        trigger_info={
            "original_dtype": str(data.dtype),
            "n_channels": data.shape[1] if data.ndim == 2 else 1,
            "normalized": normalize,
        },
    )

    return WaveformTrace(data=audio_data, metadata=metadata)


def _select_wav_channel(
    data: np.ndarray,
    channel: int | str | None,
    path: Path,
) -> tuple[np.ndarray, str]:
    """Return ``(samples, channel_name)`` for the requested channel of *data*."""
    # NumPy integers are not ``int`` instances; accept them explicitly so an
    # index such as ``np.int64(1)`` is honoured instead of silently ignored.
    is_index = isinstance(channel, (int, np.integer))

    if data.ndim != 2:
        # Mono file: the only valid index is 0.
        if is_index and channel != 0:
            raise LoaderError(
                f"Channel index {channel} out of range",
                file_path=str(path),
                details="File is mono (only channel 0 available)",
            )
        return data, "mono"

    n_channels = data.shape[1]
    channel_names = (
        ["left", "right"] if n_channels == 2 else [f"ch{i}" for i in range(n_channels)]
    )

    if is_index:
        if channel < 0 or channel >= n_channels:
            raise LoaderError(
                f"Channel index {channel} out of range",
                file_path=str(path),
                details=f"Available channels: 0-{n_channels - 1}",
            )
        index = int(channel)
        return data[:, index], channel_names[index]

    if isinstance(channel, str):
        channel_lower = channel.lower()
        if channel_lower in ("left", "l", "0"):
            return data[:, 0], "left"
        if channel_lower in ("right", "r", "1") and n_channels >= 2:
            return data[:, 1], "right"
        if channel_lower in ("mono", "mix", "avg"):
            # Average all channels
            return np.mean(data, axis=1), "mono"
        raise LoaderError(
            f"Invalid channel specifier: '{channel}'",
            file_path=str(path),
            details="Use 'left', 'right', 'mono', or channel index",
        )

    # ``None`` (or any other unsupported specifier type): first channel.
    return data[:, 0], channel_names[0]


def _normalize_wav_samples(samples: np.ndarray, source_dtype: np.dtype) -> np.ndarray:
    """Scale float64 *samples* to [-1, 1] according to the file's sample dtype."""
    if source_dtype == np.uint8:
        # 8-bit WAV data is unsigned with its midpoint at 128.
        return (samples - 128.0) / 128.0

    if np.issubdtype(source_dtype, np.signedinteger):
        # Full scale is 2**(bits - 1): 32768 for int16, 2147483648 for int32,
        # and likewise for any other signed width scipy may return.
        return samples / -float(np.iinfo(source_dtype).min)

    if np.issubdtype(source_dtype, np.floating):
        # Float WAVs are nominally within [-1, 1] already; rescale (not clip)
        # only when the peak exceeds that range. An empty array has no peak,
        # and np.max would raise on it.
        if samples.size == 0:
            return samples
        peak = np.max(np.abs(samples))
        if peak > 1.0:
            return samples / peak

    return samples
def get_wav_info(
    path: str | PathLike[str],
) -> dict:  # type: ignore[type-arg]
    """Get WAV file information without loading all sample data.

    The file is opened memory-mapped so that only the shape and dtype of the
    sample array are inspected. If scipy cannot memory-map the file's sample
    layout, the file is read normally instead.

    Args:
        path: Path to the WAV file.

    Returns:
        Dictionary with file information:
            - sample_rate: Sample rate in Hz
            - n_channels: Number of channels
            - n_samples: Number of samples per channel
            - duration: Duration in seconds
            - dtype: Sample data type (NumPy dtype name as a string)

    Raises:
        LoaderError: If the file does not exist or cannot be read.

    Example:
        >>> info = get_wav_info("recording.wav")
        >>> print(f"Duration: {info['duration']:.2f}s")
        >>> print(f"Channels: {info['n_channels']}")
    """
    path = Path(path)

    if not path.exists():
        raise LoaderError(
            "File not found",
            file_path=str(path),
        )

    try:
        try:
            # Memory-map the samples: shape/dtype are available without
            # pulling the whole data chunk into memory.
            sample_rate, data = wavfile.read(str(path), mmap=True)
        except ValueError:
            # scipy raises ValueError when a layout cannot be memory-mapped
            # (e.g. 3-byte samples); retry with a regular read. A genuinely
            # invalid file raises again here and is reported below.
            sample_rate, data = wavfile.read(str(path))

        n_samples = data.shape[0]
        n_channels = data.shape[1] if data.ndim == 2 else 1
        dtype_name = str(data.dtype)
        # Drop the array (and any mapping behind it) as soon as possible.
        del data

        return {
            "sample_rate": sample_rate,
            "n_channels": n_channels,
            "n_samples": n_samples,
            "duration": n_samples / sample_rate,
            "dtype": dtype_name,
        }

    except Exception as e:
        raise LoaderError(
            "Failed to read WAV file info",
            file_path=str(path),
            details=str(e),
        ) from e
233__all__ = ["get_wav_info", "load_wav"]