Coverage for src / tracekit / loaders / csv_loader.py: 93%

199 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

"""CSV file loader for waveform data.

This module provides loading of waveform data from CSV files with
automatic header detection and column mapping.

Example:
    >>> from tracekit.loaders.csv_loader import load_csv
    >>> trace = load_csv("oscilloscope_export.csv")
    >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")
"""

12 

13from __future__ import annotations 

14 

15import csv 

16from io import StringIO 

17from pathlib import Path 

18from typing import TYPE_CHECKING, Any 

19 

20import numpy as np 

21 

22from tracekit.core.exceptions import FormatError, LoaderError 

23from tracekit.core.types import TraceMetadata, WaveformTrace 

24 

25if TYPE_CHECKING: 

26 from os import PathLike 

27 

# Optional dependency: prefer pandas for CSV parsing when it is installed;
# load_csv falls back to the stdlib-based _load_basic path otherwise.
try:
    import pandas as pd

    PANDAS_AVAILABLE = True
except ImportError:
    PANDAS_AVAILABLE = False

35 

36 

# Common column names for time data.  All comparisons against this list in
# this module lowercase both sides, so matching is case-insensitive and the
# capitalized variants below are redundant (kept for documentation value).
TIME_COLUMN_NAMES = [
    "time",
    "t",
    "time_s",
    "time_sec",
    "seconds",
    "timestamp",
    "x",
    "Time",
    "TIME",
]

49 

# Common column names for voltage data.  As with TIME_COLUMN_NAMES, matching
# is case-insensitive (both sides are lowercased), so the capitalized entries
# are redundant but harmless.
VOLTAGE_COLUMN_NAMES = [
    "voltage",
    "v",
    "volt",
    "volts",
    "amplitude",
    "signal",
    "y",
    "value",
    "data",
    "ch1",
    "ch2",
    "ch3",
    "ch4",
    "channel1",
    "channel2",
    "Voltage",
    "VOLTAGE",
]

70 

71 

def load_csv(
    path: str | PathLike[str],
    *,
    time_column: str | int | None = None,
    voltage_column: str | int | None = None,
    sample_rate: float | None = None,
    delimiter: str | None = None,
    skip_rows: int = 0,
    encoding: str = "utf-8",
    mmap: bool = False,
) -> WaveformTrace | Any:
    """Load waveform data from a CSV file.

    Parses CSV files exported from oscilloscopes or other data sources,
    detecting header rows automatically and mapping time/voltage columns.

    Args:
        path: Path to the CSV file.
        time_column: Name or index of the time column; auto-detected if None.
        voltage_column: Name or index of the voltage column; auto-detected
            if None.
        sample_rate: Explicit sample rate; if None it is derived from the
            time column.
        delimiter: Column delimiter; auto-detected if None.
        skip_rows: Number of rows to skip before the header.
        encoding: File encoding (default: utf-8).
        mmap: If True, return a memory-mapped trace for large files.

    Returns:
        WaveformTrace with the waveform data and metadata, or the
        memory-mapped trace type when ``mmap=True``.

    Raises:
        LoaderError: If the file does not exist or cannot be loaded.

    Example:
        >>> trace = load_csv("oscilloscope.csv")
        >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")

        >>> # Specify columns explicitly
        >>> trace = load_csv("data.csv", time_column="Time", voltage_column="CH1")

        >>> # Load as memory-mapped for large files
        >>> trace = load_csv("huge_capture.csv", mmap=True)
    """
    csv_path = Path(path)

    if not csv_path.exists():
        raise LoaderError(
            "File not found",
            file_path=str(csv_path),
        )

    # Both backends share the same keyword-only signature, so dispatch once.
    backend = _load_with_pandas if PANDAS_AVAILABLE else _load_basic
    trace = backend(
        csv_path,
        time_column=time_column,
        voltage_column=voltage_column,
        sample_rate=sample_rate,
        delimiter=delimiter,
        skip_rows=skip_rows,
        encoding=encoding,
    )

    if not mmap:
        return trace

    import tempfile

    from tracekit.loaders.mmap_loader import load_mmap

    # Persist the samples to a temporary .npy file so they can be mapped.
    # NOTE(review): delete=False means the temp file is never removed here --
    # presumably it must outlive this call while the mapped trace is in use;
    # confirm who is responsible for cleanup.
    with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
        mapped_path = Path(tmp.name)

    np.save(mapped_path, trace.data)

    return load_mmap(
        mapped_path,
        sample_rate=trace.metadata.sample_rate,
    )

164 

165 

def _load_with_pandas(
    path: Path,
    *,
    time_column: str | int | None,
    voltage_column: str | int | None,
    sample_rate: float | None,
    delimiter: str | None,
    skip_rows: int,
    encoding: str,
) -> WaveformTrace:
    """Load CSV using pandas for better parsing.

    Args:
        path: Path to an existing CSV file.
        time_column: Name or index of the time column, or None to auto-detect.
        voltage_column: Name or index of the voltage column, or None to
            auto-detect (first numeric non-time column, preferring
            voltage-like names).
        sample_rate: Explicit sample rate; overrides the time column.
        delimiter: Column delimiter, or None to auto-detect.
        skip_rows: Number of leading rows to skip before the header.
        encoding: Text encoding of the file.

    Returns:
        WaveformTrace built from the selected voltage column.

    Raises:
        FormatError: If the file is empty, cannot be parsed, or an
            explicitly requested column is missing.
        LoaderError: For any other failure while loading.
    """
    try:
        # Auto-detect delimiter if not specified
        if delimiter is None:
            delimiter = _detect_delimiter(path, encoding)

        df = pd.read_csv(
            path,
            delimiter=delimiter,
            skiprows=skip_rows,
            encoding=encoding,
            engine="python",  # More flexible parsing than the C engine
        )

        if df.empty:
            raise FormatError(
                "CSV file is empty",
                file_path=str(path),
            )

        # Hoist the lowercased name lists out of the per-column loops.
        time_names = {n.lower() for n in TIME_COLUMN_NAMES}
        voltage_names = {n.lower() for n in VOLTAGE_COLUMN_NAMES}

        # --- Time column ---
        time_data = None
        time_col_name = None

        if time_column is not None:
            if isinstance(time_column, int):
                if time_column < len(df.columns):
                    time_col_name = df.columns[time_column]
                    time_data = df.iloc[:, time_column].values
                else:
                    # Bug fix: an out-of-range index used to be silently
                    # ignored, yielding the default sample rate.
                    raise FormatError(
                        "Time column index out of range",
                        file_path=str(path),
                        expected=f"Index < {len(df.columns)}",
                        got=str(time_column),
                    )
            elif time_column in df.columns:
                time_col_name = time_column
                time_data = df[time_column].values
            else:
                # Bug fix: a missing named column used to be silently ignored.
                raise FormatError(
                    "Time column not found",
                    file_path=str(path),
                    expected=str(time_column),
                    got=f"Columns: {', '.join(df.columns)}",
                )
        else:
            # Auto-detect time column by common names (case-insensitive).
            for col in df.columns:
                if col.lower().strip() in time_names:
                    time_col_name = col
                    time_data = df[col].values
                    break

        # --- Voltage column ---
        voltage_data = None
        voltage_col_name = None

        if voltage_column is not None:
            if isinstance(voltage_column, int):
                if voltage_column < len(df.columns):
                    voltage_col_name = df.columns[voltage_column]
                    voltage_data = df.iloc[:, voltage_column].values
                else:
                    # Bug fix: previously fell through to a misleading
                    # "No voltage data found" error.
                    raise FormatError(
                        "Voltage column index out of range",
                        file_path=str(path),
                        expected=f"Index < {len(df.columns)}",
                        got=str(voltage_column),
                    )
            elif voltage_column in df.columns:
                voltage_col_name = voltage_column
                voltage_data = df[voltage_column].values
            else:
                raise FormatError(
                    "Voltage column not found",
                    file_path=str(path),
                    expected=str(voltage_column),
                    got=f"Columns: {', '.join(df.columns)}",
                )
        else:
            # Auto-detect: first numeric non-time column, but prefer any
            # column whose name looks voltage-like.
            for col in df.columns:
                if col == time_col_name:
                    continue
                if not pd.api.types.is_numeric_dtype(df[col]):
                    continue
                if col.lower().strip() in voltage_names:
                    voltage_col_name = col
                    voltage_data = df[col].values
                    break
                if voltage_data is None:
                    voltage_col_name = col
                    voltage_data = df[col].values

        if voltage_data is None:
            raise FormatError(
                "No voltage data found in CSV",
                file_path=str(path),
                expected="Numeric column for voltage data",
                got=f"Columns: {', '.join(df.columns)}",
            )

        data = np.asarray(voltage_data, dtype=np.float64)

        # Sample rate priority: explicit override > derived from time > default.
        detected_sample_rate = sample_rate
        if detected_sample_rate is None and time_data is not None:
            time_data = np.asarray(time_data, dtype=np.float64)
            if len(time_data) > 1:
                # Median interval is robust to occasional timestamp jitter.
                dt = np.median(np.diff(time_data))
                if dt > 0:
                    detected_sample_rate = 1.0 / dt

        if detected_sample_rate is None:
            detected_sample_rate = 1e6  # Default to 1 MSa/s

        metadata = TraceMetadata(
            sample_rate=detected_sample_rate,
            source_file=str(path),
            channel_name=voltage_col_name or "CH1",
        )

        return WaveformTrace(data=data, metadata=metadata)

    except pd.errors.ParserError as e:
        raise FormatError(
            "Failed to parse CSV file",
            file_path=str(path),
            details=str(e),
        ) from e
    except Exception as e:
        if isinstance(e, LoaderError | FormatError):
            raise
        raise LoaderError(
            "Failed to load CSV file",
            file_path=str(path),
            details=str(e),
        ) from e

294 

295 

def _load_basic(
    path: Path,
    *,
    time_column: str | int | None,
    voltage_column: str | int | None,
    sample_rate: float | None,
    delimiter: str | None,
    skip_rows: int,
    encoding: str,
) -> WaveformTrace:
    """Basic CSV loader using only the stdlib csv module (no pandas).

    Arguments, return value and raised exceptions mirror _load_with_pandas.

    Raises:
        FormatError: If the file is empty or contains no parsable voltage data.
        LoaderError: For any other failure while loading.
    """
    try:
        with open(path, encoding=encoding) as f:
            # Skip leading rows before the (potential) header.
            for _ in range(skip_rows):
                next(f)
            content = f.read()

        if delimiter is None:
            delimiter = _detect_delimiter_from_content(content)

        reader = csv.reader(StringIO(content), delimiter=delimiter)
        rows = list(reader)

        if not rows:
            raise FormatError("CSV file is empty", file_path=str(path))

        # Header detection: first row is a header if any cell is non-empty
        # and non-numeric.
        header = None
        data_start = 0
        first_row = rows[0]

        is_header = False
        for cell in first_row:
            try:
                float(cell)
            except ValueError:
                if cell.strip():
                    is_header = True
                    break

        if is_header:
            header = [cell.strip() for cell in first_row]
            data_start = 1

        # Hoist the lowercased name lists out of the per-column loops.
        time_names = {n.lower() for n in TIME_COLUMN_NAMES}
        voltage_names = {n.lower() for n in VOLTAGE_COLUMN_NAMES}

        # Determine column indices.
        time_idx = None
        voltage_idx = None

        if header:
            if time_column is not None:
                if isinstance(time_column, int):
                    time_idx = time_column
                elif time_column in header:
                    time_idx = header.index(time_column)
            else:
                # Auto-detect the time column by common names.
                for i, col in enumerate(header):
                    if col.lower() in time_names:
                        time_idx = i
                        break

            if voltage_column is not None:
                if isinstance(voltage_column, int):
                    voltage_idx = voltage_column
                elif voltage_column in header:
                    voltage_idx = header.index(voltage_column)
            else:
                # Auto-detect: first voltage-named column that is not time;
                # otherwise fall back positionally.
                for i, col in enumerate(header):
                    if i == time_idx:
                        continue
                    if col.lower() in voltage_names:
                        voltage_idx = i
                        break
                if voltage_idx is None:
                    voltage_idx = 1 if time_idx == 0 else 0
        else:
            # No header: assume time first, voltage second unless indices
            # were given explicitly.
            time_idx = time_column if isinstance(time_column, int) else 0
            voltage_idx = voltage_column if isinstance(voltage_column, int) else 1

        # Extract data.  Parse the whole row before appending anything so a
        # malformed time cell cannot leave voltage_data and time_data with
        # different lengths (bug fix: the old code appended the voltage
        # value first and then skipped the row when the time cell failed to
        # parse, misaligning the two series).
        time_data = []
        voltage_data = []

        for row in rows[data_start:]:
            if not row:
                continue
            try:
                voltage_val = (
                    float(row[voltage_idx])
                    if voltage_idx is not None and voltage_idx < len(row)
                    else None
                )
                time_val = (
                    float(row[time_idx])
                    if time_idx is not None and time_idx < len(row)
                    else None
                )
            except (ValueError, IndexError):
                continue  # Skip malformed rows entirely
            if voltage_val is not None:
                voltage_data.append(voltage_val)
            if time_val is not None:
                time_data.append(time_val)

        if not voltage_data:
            raise FormatError(
                "No valid voltage data found in CSV",
                file_path=str(path),
            )

        data = np.array(voltage_data, dtype=np.float64)

        # Sample rate priority: explicit override > derived from time > default.
        detected_sample_rate = sample_rate
        if detected_sample_rate is None and time_data:
            time_arr = np.array(time_data, dtype=np.float64)
            if len(time_arr) > 1:
                # Median interval is robust to occasional timestamp jitter.
                dt = np.median(np.diff(time_arr))
                if dt > 0:
                    detected_sample_rate = 1.0 / dt

        if detected_sample_rate is None:
            detected_sample_rate = 1e6  # Default to 1 MSa/s

        # Channel name comes from the header when available.
        channel_name = "CH1"
        if header and voltage_idx is not None and voltage_idx < len(header):
            channel_name = header[voltage_idx]

        metadata = TraceMetadata(
            sample_rate=detected_sample_rate,
            source_file=str(path),
            channel_name=channel_name,
        )

        return WaveformTrace(data=data, metadata=metadata)

    except Exception as e:
        if isinstance(e, LoaderError | FormatError):
            raise
        raise LoaderError(
            "Failed to load CSV file",
            file_path=str(path),
            details=str(e),
        ) from e

446 

447 

def _detect_delimiter(path: Path, encoding: str) -> str:
    """Sniff the delimiter from the first 4 KiB of *path*; ',' on any error."""
    try:
        with open(path, encoding=encoding) as fh:
            head = fh.read(4096)
        return _detect_delimiter_from_content(head)
    except Exception:
        # Unreadable file or undecodable bytes: fall back to comma.
        return ","

456 

457 

458def _detect_delimiter_from_content(content: str) -> str: 

459 """Detect delimiter from CSV content.""" 

460 # Try common delimiters and count occurrences 

461 delimiters = [",", "\t", ";", "|", " "] 

462 counts: dict[str, int] = {} 

463 

464 for delim in delimiters: 

465 counts[delim] = content.count(delim) 

466 

467 # Return the most common delimiter 

468 if counts: 468 ↛ 470line 468 didn't jump to line 470 because the condition on line 468 was always true

469 return max(counts, key=lambda d: counts[d]) 

470 return "," 

471 

472 

473__all__ = ["load_csv"]