Coverage for src / tracekit / loaders / wav.py: 96%

73 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""WAV audio file loader. 

2 

3This module provides loading of WAV audio files using scipy.io.wavfile. 

4WAV files are useful for audio signal analysis and can contain 

5oscilloscope data recorded as audio. 

6 

7 

8Example: 

9 >>> from tracekit.loaders.wav import load_wav 

10 >>> trace = load_wav("recording.wav") 

11 >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz") 

12""" 

13 

14from __future__ import annotations 

15 

16from pathlib import Path 

17from typing import TYPE_CHECKING 

18 

19import numpy as np 

20from scipy.io import wavfile 

21 

22from tracekit.core.exceptions import FormatError, LoaderError 

23from tracekit.core.types import TraceMetadata, WaveformTrace 

24 

25if TYPE_CHECKING: 

26 from os import PathLike 

27 

28 

# String channel specifiers understood by ``load_wav(channel=...)``.
_LEFT_SPECIFIERS = ("left", "l", "0")
_RIGHT_SPECIFIERS = ("right", "r", "1")
_MIX_SPECIFIERS = ("mono", "mix", "avg")


def _select_channel(
    data: np.ndarray,
    channel: int | str | None,
    path: Path,
) -> tuple[np.ndarray, str]:
    """Pick the requested channel (or a channel mix) from a ``wavfile`` array.

    Args:
        data: Sample array as returned by ``scipy.io.wavfile.read``. It is
            1-D for mono files and 2-D ``(n_samples, n_channels)`` otherwise.
        channel: Channel selector; see ``load_wav`` for accepted values.
        path: Source file path, used only in error messages.

    Returns:
        Tuple of ``(selected samples, channel name)``.

    Raises:
        LoaderError: If an integer index is out of range or a string
            specifier is not recognized.
    """
    if data.ndim != 2:
        # Mono file: channel 0 is the only channel that exists.
        if isinstance(channel, int) and channel != 0:
            raise LoaderError(
                f"Channel index {channel} out of range",
                file_path=str(path),
                details="File is mono (only channel 0 available)",
            )
        # Any known string specifier yields the single channel, but unknown
        # strings are rejected just like in the multichannel branch.
        if isinstance(channel, str) and channel.lower() not in (
            _LEFT_SPECIFIERS + _RIGHT_SPECIFIERS + _MIX_SPECIFIERS
        ):
            raise LoaderError(
                f"Invalid channel specifier: '{channel}'",
                file_path=str(path),
                details="Use 'left', 'right', 'mono', or channel index",
            )
        return data, "mono"

    n_channels = data.shape[1]
    channel_names = (
        ["left", "right"] if n_channels == 2 else [f"ch{i}" for i in range(n_channels)]
    )

    if channel is None:
        # Default to the first channel.
        return data[:, 0], channel_names[0]

    if isinstance(channel, int):
        if not 0 <= channel < n_channels:
            raise LoaderError(
                f"Channel index {channel} out of range",
                file_path=str(path),
                details=f"Available channels: 0-{n_channels - 1}",
            )
        return data[:, channel], channel_names[channel]

    if isinstance(channel, str):
        key = channel.lower()
        if key in _LEFT_SPECIFIERS:
            return data[:, 0], "left"
        if key in _RIGHT_SPECIFIERS and n_channels >= 2:
            return data[:, 1], "right"
        if key in _MIX_SPECIFIERS:
            # Average all channels into a single signal.
            return np.mean(data, axis=1), "mono"
        raise LoaderError(
            f"Invalid channel specifier: '{channel}'",
            file_path=str(path),
            details="Use 'left', 'right', 'mono', or channel index",
        )

    # Not reachable for well-typed callers; keep the historical fallback of
    # loading the first channel.
    return data[:, 0], channel_names[0]  # type: ignore[unreachable]


def _normalize_samples(samples: np.ndarray, dtype: np.dtype) -> np.ndarray:
    """Scale float64 ``samples`` into [-1, 1] based on the file's ``dtype``.

    Decisions use ``dtype.kind`` and ``np.iinfo``, so byte order does not
    matter. Big-endian (RIFX) data with dtypes such as ``>i2`` is therefore
    normalized too.

    Args:
        samples: Samples already converted to float64.
        dtype: Original dtype returned by ``scipy.io.wavfile.read``.

    Returns:
        The normalized samples. Dtypes that are not integer or float are
        returned unchanged.
    """
    if dtype.kind == "i":
        # Signed PCM. scipy returns 24-bit data left-justified in int32 and
        # 40–64-bit data in int64, so full scale is always 2**(bits - 1).
        return samples / float(-np.iinfo(dtype).min)
    if dtype.kind == "u":
        # Unsigned PCM (WAV uses this for <= 8-bit) is offset-binary,
        # centred on mid-scale (128 for uint8).
        midpoint = (np.iinfo(dtype).max + 1) / 2.0
        return (samples - midpoint) / midpoint
    if dtype.kind == "f":
        # Float data is nominally in [-1, 1]. If it exceeds that range,
        # rescale by the peak magnitude (this rescales rather than clips).
        # An empty array has no peak, so it is returned as-is.
        if samples.size == 0:
            return samples
        peak = np.max(np.abs(samples))
        if peak > 1.0:
            return samples / peak
    return samples


def load_wav(
    path: str | PathLike[str],
    *,
    channel: int | str | None = None,
    normalize: bool = True,
) -> WaveformTrace:
    """Load a WAV audio file.

    Extracts audio samples and sample rate from WAV files. Supports mono,
    stereo and multichannel files, with optional normalization to the
    [-1, 1] range.

    Args:
        path: Path to the WAV file.
        channel: Channel to load for multichannel files. Can be:
            - 0 or "left" / "l" / "0": First (left) channel
            - 1 or "right" / "r" / "1": Second (right) channel
            - any other int: That channel index
            - "mono" / "mix" / "avg": Average of all channels
            - None: First channel (left for stereo)
            For mono files only channel 0 (or a known string specifier, or
            None) is accepted.
        normalize: If True, normalize samples to the [-1, 1] range. Signed
            integer PCM is divided by its full scale and unsigned (8-bit)
            PCM is re-centred on mid-scale. Float data is rescaled by its
            peak only when the peak exceeds 1. Default is True.

    Returns:
        WaveformTrace containing the float64 audio data and metadata.

    Raises:
        LoaderError: If the file does not exist, cannot be read, or the
            channel selector is invalid for the file.
        FormatError: If the file is not a valid WAV file.

    Example:
        >>> trace = load_wav("recording.wav")
        >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")
        >>> print(f"Duration: {trace.duration:.2f} seconds")

        >>> # Load right channel of stereo file
        >>> trace = load_wav("stereo.wav", channel="right")

    References:
        WAV file format: https://en.wikipedia.org/wiki/WAV
    """
    path = Path(path)

    if not path.exists():
        raise LoaderError(
            "File not found",
            file_path=str(path),
        )

    try:
        sample_rate, data = wavfile.read(str(path))
    except ValueError as e:
        # scipy raises ValueError for unrecognized/corrupt RIFF structure.
        raise FormatError(
            "Invalid WAV file format",
            file_path=str(path),
            expected="Valid WAV audio file",
        ) from e
    except Exception as e:
        raise LoaderError(
            "Failed to read WAV file",
            file_path=str(path),
            details=str(e),
        ) from e

    samples, channel_name = _select_channel(data, channel, path)

    # Convert to native-endian float64 before any scaling.
    audio_data = samples.astype(np.float64)

    if normalize:
        # Normalization depends on the ORIGINAL sample dtype, not float64.
        audio_data = _normalize_samples(audio_data, data.dtype)

    metadata = TraceMetadata(
        sample_rate=float(sample_rate),
        source_file=str(path),
        channel_name=channel_name,
        trigger_info={
            "original_dtype": str(data.dtype),
            "n_channels": data.shape[1] if data.ndim == 2 else 1,
            "normalized": normalize,
        },
    )

    return WaveformTrace(data=audio_data, metadata=metadata)

176 

177 

def get_wav_info(
    path: str | PathLike[str],
) -> dict[str, int | float | str]:
    """Get WAV file information without loading all sample data.

    The sample data is memory-mapped (``wavfile.read(..., mmap=True)``), so
    only the header is parsed and the array shape is derived from the data
    chunk size. scipy cannot memory-map every file, e.g. 24-bit PCM, whose
    3-byte sample containers have no numpy dtype. For those files, and for
    any other mmap failure, the data is read normally as a fallback.

    Args:
        path: Path to the WAV file.

    Returns:
        Dictionary with file information:
            - sample_rate: Sample rate in Hz (int)
            - n_channels: Number of channels
            - n_samples: Number of samples per channel
            - duration: Duration in seconds (float)
            - dtype: Sample data type, as ``load_wav`` would receive it

    Raises:
        LoaderError: If the file does not exist or cannot be read.

    Example:
        >>> info = get_wav_info("recording.wav")
        >>> print(f"Duration: {info['duration']:.2f}s")
        >>> print(f"Channels: {info['n_channels']}")
    """
    path = Path(path)

    if not path.exists():
        raise LoaderError(
            "File not found",
            file_path=str(path),
        )

    try:
        try:
            sample_rate, data = wavfile.read(str(path), mmap=True)
        except (ValueError, OSError):
            # mmap unsupported for this file (e.g. 24-bit containers,
            # truncated data chunk, filesystem without mmap): read normally.
            # A genuinely invalid file raises again here and is wrapped below.
            sample_rate, data = wavfile.read(str(path))

        n_samples = data.shape[0]
        n_channels = data.shape[1] if data.ndim == 2 else 1
        dtype_name = str(data.dtype)
        # Drop the (possibly memory-mapped) array promptly to release the
        # mapping; only its shape and dtype were needed.
        del data
        duration = n_samples / sample_rate
    except Exception as e:
        raise LoaderError(
            "Failed to read WAV file info",
            file_path=str(path),
            details=str(e),
        ) from e

    return {
        "sample_rate": sample_rate,
        "n_channels": n_channels,
        "n_samples": n_samples,
        "duration": duration,
        "dtype": dtype_name,
    }

231 

232 

233__all__ = ["get_wav_info", "load_wav"]