Coverage for src / tracekit / workflows / multi_trace.py: 18%
189 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Multi-Trace Workflow Support.
3Provides workflows for processing and analyzing multiple traces together.
4"""
6import concurrent.futures
7from collections.abc import Iterator
8from dataclasses import dataclass, field
9from glob import glob as glob_func
10from pathlib import Path
11from typing import Any
13import numpy as np
15from tracekit.core.exceptions import TraceKitError
16from tracekit.core.progress import create_progress_tracker
class AlignmentMethod:
    """String constants naming the supported trace-alignment strategies.

    Used as the ``method`` argument to ``MultiTraceWorkflow.align``.
    """

    TRIGGER = "trigger"  # align on first rising edge past a threshold
    TIME_SYNC = "time"  # align by timestamp (placeholder implementation)
    CROSS_CORRELATION = "correlation"  # align by cross-correlation (placeholder)
    MANUAL = "manual"  # caller supplies explicit per-trace offsets
@dataclass
class TraceStatistics:
    """Aggregate statistics for one measurement computed across traces.

    Built by ``MultiTraceWorkflow.aggregate`` from the per-trace values of a
    single measurement (None/NaN values excluded).

    Attributes:
        mean: Mean value.
        std: Standard deviation.
        min: Minimum value.
        max: Maximum value.
        median: Median value.
        count: Number of traces that contributed a usable value.
    """

    mean: float
    std: float
    min: float
    max: float
    median: float
    count: int
@dataclass
class MultiTraceResults:
    """Container for the outputs of a multi-trace workflow run.

    Attributes:
        trace_ids: List of trace identifiers (file names when discovered
            from a glob pattern).
        measurements: Dict mapping trace_id -> {measurement_name: value};
            a value of None records a failed measurement.
        statistics: Dict mapping measurement_name -> TraceStatistics.
        metadata: Additional workflow metadata.
    """

    trace_ids: list[str] = field(default_factory=list)
    measurements: dict[str, dict[str, Any]] = field(default_factory=dict)
    statistics: dict[str, TraceStatistics] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)
class MultiTraceWorkflow:
    """Workflow manager for multi-trace processing.

    Provides methods to load, align, measure, and aggregate multiple traces
    with memory-efficient streaming and optional parallelization.

    Typical usage::

        wf = MultiTraceWorkflow(pattern="captures/*.csv")
        wf.align(method=AlignmentMethod.TRIGGER)
        wf.measure("rise_time", "fall_time")
        results = wf.aggregate()
    """

    def __init__(
        self,
        pattern: str | None = None,
        traces: list[Any] | None = None,
        lazy: bool = False,
    ):
        """Initialize multi-trace workflow.

        Args:
            pattern: Glob pattern for trace files (e.g., "*.csv").
            traces: Pre-loaded trace objects.
            lazy: If True, load traces on demand.

        Raises:
            TraceKitError: If neither pattern nor traces is provided, or if
                no files match the pattern.
        """
        self.pattern = pattern
        self._traces = traces or []
        self._lazy = lazy
        self._file_paths: list[Path] = []
        self._aligned = False
        self._alignment_offset: dict[str, int] = {}
        self.results = MultiTraceResults()

        # Discover files if a pattern was provided; otherwise traces required.
        if pattern:
            self._discover_files()
        elif not traces:
            raise TraceKitError("Must provide either pattern or traces")

    def _discover_files(self) -> None:
        """Discover trace files matching ``self.pattern``.

        Populates ``self._file_paths`` (sorted for deterministic ordering)
        and ``self.results.trace_ids``.

        Raises:
            TraceKitError: If no files match the pattern.
        """
        if not self.pattern:
            return

        paths = glob_func(self.pattern)  # noqa: PTH207
        if not paths:
            raise TraceKitError(f"No files match pattern: {self.pattern}")

        self._file_paths = [Path(p) for p in sorted(paths)]
        self.results.trace_ids = [p.name for p in self._file_paths]

    def _load_trace(self, path: Path) -> Any:
        """Load a single trace file, dispatching on file extension.

        Args:
            path: Path to trace file.

        Returns:
            Loaded trace object.

        Raises:
            TraceKitError: If the format is unsupported or its loader is
                not available.
        """
        ext = path.suffix.lower()

        try:
            if ext == ".csv":
                from tracekit.loaders.csv import (  # type: ignore[import-not-found]
                    load_csv,
                )

                return load_csv(str(path))
            if ext == ".bin":
                from tracekit.loaders.binary import (  # type: ignore[import-not-found]
                    load_binary,
                )

                return load_binary(str(path))
            if ext in (".h5", ".hdf5"):
                from tracekit.loaders.hdf5 import (  # type: ignore[import-not-found]
                    load_hdf5,
                )

                return load_hdf5(str(path))
            raise TraceKitError(f"Unsupported format: {ext}")
        except ImportError as e:
            # Chain the original ImportError so the root cause stays visible
            # (the original suppressed the B904 warning instead of chaining).
            raise TraceKitError(f"Loader not available for {ext}: {e}") from e

    def _iter_traces(self, lazy: bool = False) -> Iterator[tuple[str, Any]]:
        """Iterate over traces as ``(trace_id, trace)`` pairs.

        Pre-loaded traces take precedence over discovered files.

        Args:
            lazy: Reserved for eager pre-loading. Currently traces from files
                are always loaded one at a time regardless of this flag (the
                original had two identical branches for it).

        Yields:
            Tuple of (trace_id, trace).
        """
        if self._traces:
            for i, trace in enumerate(self._traces):
                # Fall back to a synthetic id when no trace_ids were recorded.
                if i < len(self.results.trace_ids):
                    trace_id = self.results.trace_ids[i]
                else:
                    trace_id = f"trace_{i}"
                yield trace_id, trace
            return

        # Stream traces from disk one at a time (memory efficient).
        for path in self._file_paths:
            yield path.name, self._load_trace(path)

    def align(
        self,
        method: str = AlignmentMethod.TRIGGER,
        channel: int = 0,
        threshold: float | None = None,
        **kwargs: Any,
    ) -> None:
        """Align traces using the specified method.

        Args:
            method: Alignment method ('trigger', 'time', 'correlation', 'manual').
            channel: Channel to use for alignment (for multi-channel traces).
            threshold: Trigger threshold (for trigger alignment).
            **kwargs: Additional method-specific parameters.

        Raises:
            TraceKitError: If the method is unknown or alignment fails.
        """
        if method == AlignmentMethod.TRIGGER:
            self._align_by_trigger(channel, threshold, **kwargs)
        elif method == AlignmentMethod.TIME_SYNC:
            self._align_by_time(**kwargs)
        elif method == AlignmentMethod.CROSS_CORRELATION:
            self._align_by_correlation(channel, **kwargs)
        elif method == AlignmentMethod.MANUAL:
            self._align_manual(**kwargs)
        else:
            raise TraceKitError(f"Unknown alignment method: {method}")

        self._aligned = True

    def _align_by_trigger(
        self,
        channel: int,
        threshold: float | None,
        **kwargs: Any,
    ) -> None:
        """Align traces on the first rising edge past a threshold.

        Args:
            channel: Channel index (currently unused; alignment reads
                ``trace.data`` directly).
            threshold: Trigger threshold. If None, an auto threshold at 50%
                of each trace's data range is used.
            **kwargs: Additional parameters (ignored).
        """
        for trace_id, trace in self._iter_traces(lazy=True):
            # Traces without a data attribute get a zero offset.
            if not hasattr(trace, "data"):
                self._alignment_offset[trace_id] = 0
                continue

            data = trace.data
            # Bug fix: compute the auto threshold in a local per trace. The
            # original assigned it back to `threshold`, so the first trace's
            # midpoint was silently reused for every subsequent trace.
            trig = threshold
            if trig is None:
                trig = 0.5 * (np.max(data) + np.min(data))

            # First rising edge: first index where data crosses above trig.
            above = data > trig
            rising = np.where(np.diff(above.astype(int)) > 0)[0]
            self._alignment_offset[trace_id] = int(rising[0]) if len(rising) > 0 else 0

    def _align_by_time(self, **kwargs: Any) -> None:
        """Align traces by timestamp.

        Placeholder implementation: records a zero offset for every trace.

        Args:
            **kwargs: Additional parameters (ignored).
        """
        for trace_id, _trace in self._iter_traces(lazy=True):
            self._alignment_offset[trace_id] = 0

    def _align_by_correlation(self, channel: int, **kwargs: Any) -> None:
        """Align traces by cross-correlation.

        Placeholder implementation: records a zero offset for every trace.

        Args:
            channel: Channel index (currently unused).
            **kwargs: Additional parameters (ignored).
        """
        for trace_id, _trace in self._iter_traces(lazy=True):
            self._alignment_offset[trace_id] = 0

    def _align_manual(self, **kwargs: Any) -> None:
        """Manual alignment with caller-specified offsets.

        Args:
            **kwargs: Must include 'offsets' dict mapping trace_id -> offset.

        Raises:
            TraceKitError: If the offsets parameter is missing or empty.
        """
        offsets = kwargs.get("offsets", {})
        if not offsets:
            raise TraceKitError("Manual alignment requires 'offsets' parameter")

        self._alignment_offset.update(offsets)

    def measure(
        self, *measurements: str, parallel: bool = False, max_workers: int | None = None
    ) -> None:
        """Measure properties across all traces.

        Results are stored in ``self.results.measurements``.

        Args:
            *measurements: Measurement names (rise_time, fall_time, etc.).
            parallel: If True, process traces in parallel threads.
            max_workers: Maximum parallel workers (None = executor default).

        Raises:
            TraceKitError: If no measurements are requested.
        """
        if not measurements:
            raise TraceKitError("At least one measurement required")

        if parallel:
            self._measure_parallel(measurements, max_workers)
        else:
            self._measure_sequential(measurements)

    def _measure_sequential(self, measurements: tuple[str, ...]) -> None:
        """Measure traces one at a time, with progress reporting."""
        progress = create_progress_tracker(  # type: ignore[call-arg]
            total=len(self.results.trace_ids),
            description="Measuring traces",
        )

        for trace_id, trace in self._iter_traces(lazy=True):
            results: dict[str, Any] = {}
            for meas_name in measurements:
                try:
                    results[meas_name] = self._perform_measurement(trace, meas_name)
                except Exception as e:
                    # Record the failure as None and keep going: one bad
                    # measurement should not abort the whole batch.
                    results[meas_name] = None
                    print(f"Warning: {meas_name} failed for {trace_id}: {e}")

            self.results.measurements[trace_id] = results
            progress.update(1)

    def _measure_parallel(self, measurements: tuple[str, ...], max_workers: int | None) -> None:
        """Measure traces concurrently using a thread pool."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(self._measure_trace, trace, measurements): trace_id
                for trace_id, trace in self._iter_traces(lazy=False)
            }

            for future in concurrent.futures.as_completed(futures):
                trace_id = futures[future]
                try:
                    self.results.measurements[trace_id] = future.result()
                except Exception as e:
                    # Best-effort: report and continue with remaining traces.
                    print(f"Error measuring {trace_id}: {e}")

    def _measure_trace(self, trace: Any, measurements: tuple[str, ...]) -> dict[str, Any]:
        """Measure a single trace.

        Args:
            trace: Trace object.
            measurements: Measurement names.

        Returns:
            Dict mapping measurement_name -> value (None on failure).
        """
        results: dict[str, Any] = {}
        for meas_name in measurements:
            try:
                results[meas_name] = self._perform_measurement(trace, meas_name)
            except Exception:
                results[meas_name] = None
        return results

    def _perform_measurement(self, trace: Any, measurement: str) -> Any:
        """Perform a single measurement.

        Placeholder: the actual measurement functions (intended to come from
        tracekit.analyzers.measurements) are not wired in yet.

        Args:
            trace: Trace object.
            measurement: Measurement name.

        Raises:
            TraceKitError: Always, until measurements are implemented.
        """
        raise TraceKitError(
            f"Measurement '{measurement}' not yet implemented in multi-trace workflow"
        )

    def aggregate(self) -> MultiTraceResults:
        """Compute aggregate statistics across traces.

        Populates ``self.results.statistics`` with one TraceStatistics per
        measurement name, skipping None and NaN values.

        Returns:
            Results with statistics.

        Raises:
            TraceKitError: If no measurements are available.
        """
        if not self.results.measurements:
            raise TraceKitError("No measurements available. Call measure() first.")

        # Union of all measurement names seen across traces (a real
        # annotation replaces the original's type-ignore comment).
        all_measurements: set[str] = set()
        for trace_results in self.results.measurements.values():
            all_measurements.update(trace_results.keys())

        for meas_name in all_measurements:
            values = [
                float(val)
                for trace_results in self.results.measurements.values()
                if (val := trace_results.get(meas_name)) is not None
                and not (isinstance(val, float) and np.isnan(val))
            ]

            if values:
                self.results.statistics[meas_name] = TraceStatistics(
                    mean=float(np.mean(values)),
                    std=float(np.std(values)),
                    min=float(np.min(values)),
                    max=float(np.max(values)),
                    median=float(np.median(values)),
                    count=len(values),
                )

        return self.results

    def export_report(self, filename: str, format: str = "pdf") -> None:
        """Export combined report.

        Args:
            filename: Output filename.
            format: Report format ('pdf', 'html', 'json').

        Raises:
            TraceKitError: If the format is unsupported or export fails.
        """
        if format == "json":
            self._export_json(filename)
        elif format == "pdf":
            self._export_pdf(filename)
        elif format == "html":
            self._export_html(filename)
        else:
            raise TraceKitError(f"Unsupported report format: {format}")

    def _export_json(self, filename: str) -> None:
        """Export trace ids, measurements, statistics, and metadata to JSON."""
        import json

        data = {
            "trace_ids": self.results.trace_ids,
            "measurements": self.results.measurements,
            "statistics": {
                name: {
                    "mean": stats.mean,
                    "std": stats.std,
                    "min": stats.min,
                    "max": stats.max,
                    "median": stats.median,
                    "count": stats.count,
                }
                for name, stats in self.results.statistics.items()
            },
            "metadata": self.results.metadata,
        }

        with open(filename, "w") as f:
            json.dump(data, f, indent=2)

    def _export_pdf(self, filename: str) -> None:
        """Export results to PDF.

        Args:
            filename: Output filename.

        Raises:
            TraceKitError: PDF export not yet implemented.
        """
        raise TraceKitError("PDF export not yet implemented")

    def _export_html(self, filename: str) -> None:
        """Export results to HTML.

        Args:
            filename: Output filename.

        Raises:
            TraceKitError: HTML export not yet implemented.
        """
        raise TraceKitError("HTML export not yet implemented")
def load_all(pattern: str, lazy: bool = True) -> list[Any]:
    """Load all traces matching pattern.

    Placeholder: currently returns the matching file paths (sorted for
    deterministic ordering, consistent with MultiTraceWorkflow's file
    discovery) rather than loaded trace objects.

    Args:
        pattern: Glob pattern.
        lazy: If True, return lazy-loading proxy objects (not yet
            implemented; the flag is currently ignored).

    Returns:
        List of Path objects for the matching files, in sorted order.

    Raises:
        TraceKitError: If no traces found.
    """
    paths = glob_func(pattern)  # noqa: PTH207
    if not paths:
        raise TraceKitError(f"No files match pattern: {pattern}")

    # Sort for deterministic ordering (glob order is filesystem-dependent),
    # matching MultiTraceWorkflow._discover_files.
    return [Path(p) for p in sorted(paths)]