Coverage for src / tracekit / loaders / csv_loader.py: 93%
199 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""CSV file loader for waveform data.
3This module provides loading of waveform data from CSV files with
4automatic header detection and column mapping.
7Example:
8 >>> from tracekit.loaders.csv_loader import load_csv
9 >>> trace = load_csv("oscilloscope_export.csv")
10 >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")
11"""
13from __future__ import annotations
15import csv
16from io import StringIO
17from pathlib import Path
18from typing import TYPE_CHECKING, Any
20import numpy as np
22from tracekit.core.exceptions import FormatError, LoaderError
23from tracekit.core.types import TraceMetadata, WaveformTrace
25if TYPE_CHECKING:
26 from os import PathLike
28# Try to import pandas for better CSV handling
29try:
30 import pandas as pd
32 PANDAS_AVAILABLE = True
33except ImportError:
34 PANDAS_AVAILABLE = False
# Header labels treated as the time axis. The loaders in this module compare
# them against lower-cased header cells.
TIME_COLUMN_NAMES = [
    "time", "t", "time_s", "time_sec", "seconds", "timestamp", "x", "Time", "TIME",
]

# Header labels treated as a voltage/amplitude channel, compared the same way.
VOLTAGE_COLUMN_NAMES = [
    "voltage", "v", "volt", "volts", "amplitude", "signal", "y", "value", "data",
    "ch1", "ch2", "ch3", "ch4", "channel1", "channel2", "Voltage", "VOLTAGE",
]
def load_csv(
    path: str | PathLike[str],
    *,
    time_column: str | int | None = None,
    voltage_column: str | int | None = None,
    sample_rate: float | None = None,
    delimiter: str | None = None,
    skip_rows: int = 0,
    encoding: str = "utf-8",
    mmap: bool = False,
) -> WaveformTrace | Any:
    """Read a waveform from a CSV file.

    The file is parsed with pandas when it is importable, otherwise with the
    standard-library ``csv`` module. Both back ends locate the time and
    voltage columns and derive the sample rate from the time column unless
    ``sample_rate`` is supplied.

    Args:
        path: Location of the CSV file.
        time_column: Header name or positional index of the time column.
            ``None`` requests auto-detection.
        voltage_column: Header name or positional index of the voltage
            column. ``None`` requests auto-detection.
        sample_rate: Explicit sample rate; when ``None`` it is computed from
            the time column.
        delimiter: Field separator; when ``None`` it is guessed from the file.
        skip_rows: Count of leading rows to discard before the header.
        encoding: Text encoding of the file (default: utf-8).
        mmap: When True, the samples are written to a temporary ``.npy`` file
            and the result of ``load_mmap`` on that file is returned.

    Returns:
        A ``WaveformTrace`` holding samples and metadata, or the object
        produced by ``load_mmap`` when ``mmap=True``.

    Raises:
        LoaderError: If ``path`` does not exist or the file cannot be loaded.

    Example:
        >>> trace = load_csv("oscilloscope.csv")
        >>> print(f"Sample rate: {trace.metadata.sample_rate} Hz")

        >>> # Specify columns explicitly
        >>> trace = load_csv("data.csv", time_column="Time", voltage_column="CH1")

        >>> # Load as memory-mapped for large files
        >>> trace = load_csv("huge_capture.csv", mmap=True)
    """
    csv_path = Path(path)
    if not csv_path.exists():
        raise LoaderError("File not found", file_path=str(csv_path))

    # Both back ends share one keyword signature, so pick the callable first.
    parse = _load_with_pandas if PANDAS_AVAILABLE else _load_basic
    trace = parse(
        csv_path,
        time_column=time_column,
        voltage_column=voltage_column,
        sample_rate=sample_rate,
        delimiter=delimiter,
        skip_rows=skip_rows,
        encoding=encoding,
    )

    if not mmap:
        return trace

    # Imported lazily: only the memory-mapped path needs these.
    import tempfile

    from tracekit.loaders.mmap_loader import load_mmap

    # delete=False keeps the file on disk after the handle closes so it can be
    # memory-mapped; nothing in this function removes it afterwards.
    with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as scratch:
        npy_path = Path(scratch.name)

    np.save(npy_path, trace.data)

    # Only the sample rate is forwarded to the memory-mapped loader.
    return load_mmap(npy_path, sample_rate=trace.metadata.sample_rate)
def _cell_is_numeric(cell: str) -> bool:
    """Return True when ``cell`` can be parsed by ``float``."""
    try:
        float(cell)
    except ValueError:
        return False
    return True


def _first_row_is_header(path: Path, delimiter: str, skip_rows: int, encoding: str) -> bool:
    """Report whether the first non-blank row after ``skip_rows`` is a header.

    A row is a header when at least one non-empty cell is not a number, which
    is the rule ``_load_basic`` uses. Whenever the row cannot be inspected
    (unreadable file, file shorter than ``skip_rows``, a delimiter the ``csv``
    module rejects, a row of empty cells) this returns True so the caller
    keeps pandas' default header handling.
    """
    try:
        with open(path, encoding=encoding, newline="") as f:
            for _ in range(skip_rows):
                if not f.readline():
                    return True
            for row in csv.reader(f, delimiter=delimiter):
                if not row:
                    continue  # blank line
                cells = [cell.strip() for cell in row if cell.strip()]
                if not cells:
                    return True
                return any(not _cell_is_numeric(cell) for cell in cells)
    except (OSError, UnicodeError, LookupError, TypeError, ValueError, csv.Error):
        return True
    return True


def _resolve_column(df: Any, column: str | int) -> int | None:
    """Map an explicit column spec to a positional index, or None if absent.

    Integers are positions (negative values count from the end); anything
    else is looked up among the frame's column labels.
    """
    count = len(df.columns)
    if isinstance(column, int):
        return column % count if -count <= column < count else None
    labels = list(df.columns)
    return labels.index(column) if column in labels else None


def _load_with_pandas(
    path: Path,
    *,
    time_column: str | int | None,
    voltage_column: str | int | None,
    sample_rate: float | None,
    delimiter: str | None,
    skip_rows: int,
    encoding: str,
) -> WaveformTrace:
    """Load a CSV waveform using pandas.

    Args:
        path: CSV file to read.
        time_column: Name or position of the time column; ``None`` auto-detects.
        voltage_column: Name or position of the voltage column; ``None``
            auto-detects (first numeric non-time column, preferring labels
            listed in ``VOLTAGE_COLUMN_NAMES``).
        sample_rate: Explicit sample rate; when ``None`` it is ``1 / median(dt)``
            of the time column, falling back to 1e6.
        delimiter: Field separator; ``None`` triggers detection.
        skip_rows: Leading rows to skip before the header.
        encoding: Text encoding of the file.

    Returns:
        WaveformTrace with float64 samples and metadata.

    Raises:
        FormatError: If the file is empty, cannot be parsed, or has no usable
            voltage column.
        LoaderError: For any other failure while loading.
    """
    try:
        if delimiter is None:
            delimiter = _detect_delimiter(path, encoding)

        # pandas would otherwise consume the first data row of a headerless
        # file as column names, dropping a sample and defeating detection.
        has_header = _first_row_is_header(path, delimiter, skip_rows, encoding)

        df = pd.read_csv(
            path,
            delimiter=delimiter,
            skiprows=skip_rows,
            encoding=encoding,
            engine="python",  # More flexible parsing
            header="infer" if has_header else None,
        )

        if df.empty:
            raise FormatError(
                "CSV file is empty",
                file_path=str(path),
            )

        # Headerless frames have integer labels; normalise for text handling.
        labels = [str(label) for label in df.columns]

        voltage_pos = None
        if voltage_column is not None:
            voltage_pos = _resolve_column(df, voltage_column)

        # Find time column
        if time_column is not None:
            time_pos = _resolve_column(df, time_column)
        elif has_header:
            time_names = {name.lower() for name in TIME_COLUMN_NAMES}
            time_pos = next(
                (i for i, label in enumerate(labels) if label.lower().strip() in time_names),
                None,
            )
        else:
            # Same convention as the basic loader: first column is time,
            # unless it is the only column or was requested as voltage.
            time_pos = 0 if len(labels) > 1 and voltage_pos != 0 else None

        # Auto-detect voltage column (first non-time numeric column)
        if voltage_column is None:
            voltage_names = {name.lower() for name in VOLTAGE_COLUMN_NAMES}
            for i, label in enumerate(labels):
                if i == time_pos or not pd.api.types.is_numeric_dtype(df.iloc[:, i]):
                    continue
                if label.lower().strip() in voltage_names:
                    # A voltage-like name beats an earlier anonymous column.
                    voltage_pos = i
                    break
                if voltage_pos is None:
                    voltage_pos = i

        if voltage_pos is None:
            raise FormatError(
                "No voltage data found in CSV",
                file_path=str(path),
                expected="Numeric column for voltage data",
                got=f"Columns: {', '.join(labels)}",
            )

        # Convert to float64
        data = np.asarray(df.iloc[:, voltage_pos].values, dtype=np.float64)

        # Determine sample rate
        detected_sample_rate = sample_rate
        if detected_sample_rate is None and time_pos is not None:
            time_data = np.asarray(df.iloc[:, time_pos].values, dtype=np.float64)
            if len(time_data) > 1:
                # Median is robust against the occasional irregular step.
                dt = np.median(np.diff(time_data))
                if dt > 0:
                    detected_sample_rate = 1.0 / dt

        if detected_sample_rate is None:
            detected_sample_rate = 1e6  # Default to 1 MSa/s

        # Headerless files have no meaningful label to use as channel name.
        channel_label = labels[voltage_pos] if has_header else ""

        metadata = TraceMetadata(
            sample_rate=detected_sample_rate,
            source_file=str(path),
            channel_name=channel_label or "CH1",
        )

        return WaveformTrace(data=data, metadata=metadata)

    except pd.errors.ParserError as e:
        raise FormatError(
            "Failed to parse CSV file",
            file_path=str(path),
            details=str(e),
        ) from e
    except (LoaderError, FormatError):
        # Already in the loader's own vocabulary; pass through untouched.
        raise
    except Exception as e:
        raise LoaderError(
            "Failed to load CSV file",
            file_path=str(path),
            details=str(e),
        ) from e
def _parses_as_float(cell: str) -> bool:
    """Return True when ``float(cell)`` succeeds."""
    try:
        float(cell)
    except ValueError:
        return False
    return True


def _load_basic(
    path: Path,
    *,
    time_column: str | int | None,
    voltage_column: str | int | None,
    sample_rate: float | None,
    delimiter: str | None,
    skip_rows: int,
    encoding: str,
) -> WaveformTrace:
    """Load a CSV waveform with the standard-library ``csv`` module.

    Args:
        path: CSV file to read.
        time_column: Header name or positional index of the time column.
        voltage_column: Header name or positional index of the voltage column.
        sample_rate: Explicit sample rate; when ``None`` it is ``1 / median(dt)``
            of the time column, falling back to 1e6.
        delimiter: Field separator; ``None`` triggers detection from content.
        skip_rows: Leading lines to discard before parsing.
        encoding: Text encoding of the file.

    Returns:
        WaveformTrace with float64 samples and metadata.

    Raises:
        FormatError: If the file has no rows or no parseable voltage values.
        LoaderError: For any other failure while loading.
    """
    try:
        with open(path, encoding=encoding) as f:
            for _ in range(skip_rows):
                next(f)
            content = f.read()

        if delimiter is None:
            delimiter = _detect_delimiter_from_content(content)

        # Blank lines carry no data; dropping them here also keeps a leading
        # blank line from hiding the header row.
        rows = [row for row in csv.reader(StringIO(content), delimiter=delimiter) if row]

        if not rows:
            raise FormatError("CSV file is empty", file_path=str(path))

        first_row = rows[0]

        # A header is a first row with at least one non-empty, non-numeric cell.
        header = None
        data_rows = rows
        if any(cell.strip() and not _parses_as_float(cell) for cell in first_row):
            header = [cell.strip() for cell in first_row]
            data_rows = rows[1:]

        time_idx = None
        voltage_idx = None

        if header is not None:
            if time_column is not None:
                if isinstance(time_column, int):
                    time_idx = time_column
                elif time_column in header:
                    time_idx = header.index(time_column)
            else:
                time_names = {name.lower() for name in TIME_COLUMN_NAMES}
                time_idx = next(
                    (i for i, col in enumerate(header) if col.lower() in time_names),
                    None,
                )

            if voltage_column is not None:
                if isinstance(voltage_column, int):
                    voltage_idx = voltage_column
                elif voltage_column in header:
                    voltage_idx = header.index(voltage_column)
            else:
                voltage_names = {name.lower() for name in VOLTAGE_COLUMN_NAMES}
                voltage_idx = next(
                    (
                        i
                        for i, col in enumerate(header)
                        if i != time_idx and col.lower() in voltage_names
                    ),
                    None,
                )
                if voltage_idx is None:
                    # No recognisable name: take the first column that is not time.
                    voltage_idx = 1 if time_idx == 0 else 0
        else:
            width = len(first_row)

            if isinstance(voltage_column, int):
                voltage_idx = voltage_column
            else:
                # Second column by convention; a single-column file can only
                # hold the samples themselves.
                voltage_idx = 1 if width > 1 else 0

            if isinstance(time_column, int):
                time_idx = time_column
            elif width > 1 and voltage_idx != 0:
                time_idx = 0  # Assume first column is time
            # Otherwise leave time unset rather than reuse the voltage column.

        time_data = []
        voltage_data = []

        for row in data_rows:
            try:
                if voltage_idx is not None and voltage_idx < len(row):
                    voltage_data.append(float(row[voltage_idx]))
                if time_idx is not None and time_idx < len(row):
                    time_data.append(float(row[time_idx]))
            except (ValueError, IndexError):
                continue  # Skip malformed rows

        if not voltage_data:
            raise FormatError(
                "No valid voltage data found in CSV",
                file_path=str(path),
            )

        data = np.array(voltage_data, dtype=np.float64)

        detected_sample_rate = sample_rate
        if detected_sample_rate is None and time_data:
            time_arr = np.array(time_data, dtype=np.float64)
            if len(time_arr) > 1:
                dt = np.median(np.diff(time_arr))
                if dt > 0:
                    detected_sample_rate = 1.0 / dt

        if detected_sample_rate is None:
            detected_sample_rate = 1e6

        channel_name = "CH1"
        if header is not None and voltage_idx is not None and voltage_idx < len(header):
            channel_name = header[voltage_idx]

        metadata = TraceMetadata(
            sample_rate=detected_sample_rate,
            source_file=str(path),
            channel_name=channel_name,
        )

        return WaveformTrace(data=data, metadata=metadata)

    except (LoaderError, FormatError):
        # Already in the loader's own vocabulary; pass through untouched.
        raise
    except Exception as e:
        raise LoaderError(
            "Failed to load CSV file",
            file_path=str(path),
            details=str(e),
        ) from e
def _detect_delimiter(path: Path, encoding: str) -> str:
    """Guess the field separator from the first 4096 characters of a file.

    Detection is best-effort: any failure while opening, decoding or
    analysing the file yields the default ",".
    """
    try:
        with open(path, encoding=encoding) as handle:
            head = handle.read(4096)
        guess = _detect_delimiter_from_content(head)
    except Exception:
        guess = ","
    return guess
458def _detect_delimiter_from_content(content: str) -> str:
459 """Detect delimiter from CSV content."""
460 # Try common delimiters and count occurrences
461 delimiters = [",", "\t", ";", "|", " "]
462 counts: dict[str, int] = {}
464 for delim in delimiters:
465 counts[delim] = content.count(delim)
467 # Return the most common delimiter
468 if counts: 468 ↛ 470line 468 didn't jump to line 470 because the condition on line 468 was always true
469 return max(counts, key=lambda d: counts[d])
470 return ","
473__all__ = ["load_csv"]