Coverage for src / tracekit / workflows / multi_trace.py: 18%

189 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Multi-Trace Workflow Support. 

2 

3Provides workflows for processing and analyzing multiple traces together. 

4""" 

5 

6import concurrent.futures 

7from collections.abc import Iterator 

8from dataclasses import dataclass, field 

9from glob import glob as glob_func 

10from pathlib import Path 

11from typing import Any 

12 

13import numpy as np 

14 

15from tracekit.core.exceptions import TraceKitError 

16from tracekit.core.progress import create_progress_tracker 

17 

18 

class AlignmentMethod:
    """Alignment method constants.

    String values accepted by ``MultiTraceWorkflow.align``'s ``method``
    argument.
    """

    TRIGGER = "trigger"  # align on first rising edge past a threshold
    TIME_SYNC = "time"  # align by trace timestamps
    CROSS_CORRELATION = "correlation"  # align by cross-correlation
    MANUAL = "manual"  # caller supplies explicit per-trace offsets

26 

27 

@dataclass
class TraceStatistics:
    """Statistics for a measurement across traces.

    Produced by ``MultiTraceWorkflow.aggregate`` from the per-trace
    measurement values (None/NaN values are excluded before computing).

    Attributes:
        mean: Mean value
        std: Standard deviation (numpy default, i.e. population std)
        min: Minimum value
        max: Maximum value
        median: Median value
        count: Number of traces
    """

    mean: float
    std: float
    min: float
    max: float
    median: float
    count: int

47 

48 

@dataclass
class MultiTraceResults:
    """Results from multi-trace workflow.

    Accumulated by ``MultiTraceWorkflow``: ``measurements`` is filled by
    ``measure()`` and ``statistics`` by ``aggregate()``.

    Attributes:
        trace_ids: List of trace identifiers
        measurements: Dict mapping trace_id -> measurement results
        statistics: Dict mapping measurement_name -> TraceStatistics
        metadata: Additional workflow metadata
    """

    trace_ids: list[str] = field(default_factory=list)
    measurements: dict[str, dict[str, Any]] = field(default_factory=dict)
    statistics: dict[str, TraceStatistics] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

64 

65 

66class MultiTraceWorkflow: 

67 """Workflow manager for multi-trace processing. 

68 

69 Provides methods to load, align, process, and analyze multiple traces 

70 with memory-efficient streaming and optional parallelization. 

71 """ 

72 

73 def __init__( 

74 self, 

75 pattern: str | None = None, 

76 traces: list[Any] | None = None, 

77 lazy: bool = False, 

78 ): 

79 """Initialize multi-trace workflow. 

80 

81 Args: 

82 pattern: Glob pattern for trace files (e.g., "*.csv") 

83 traces: Pre-loaded trace objects 

84 lazy: If True, load traces on demand 

85 

86 Raises: 

87 TraceKitError: If neither pattern nor traces provided 

88 """ 

89 self.pattern = pattern 

90 self._traces = traces or [] 

91 self._lazy = lazy 

92 self._file_paths: list[Path] = [] 

93 self._aligned = False 

94 self._alignment_offset: dict[str, int] = {} 

95 self.results = MultiTraceResults() 

96 

97 # Discover files if pattern provided 

98 if pattern: 

99 self._discover_files() 

100 elif not traces: 

101 raise TraceKitError("Must provide either pattern or traces") 

102 

103 def _discover_files(self) -> None: 

104 """Discover trace files matching pattern.""" 

105 if not self.pattern: 

106 return 

107 

108 paths = glob_func(self.pattern) # noqa: PTH207 

109 if not paths: 

110 raise TraceKitError(f"No files match pattern: {self.pattern}") 

111 

112 self._file_paths = [Path(p) for p in sorted(paths)] 

113 self.results.trace_ids = [p.name for p in self._file_paths] 

114 

115 def _load_trace(self, path: Path) -> Any: 

116 """Load a single trace file. 

117 

118 Args: 

119 path: Path to trace file 

120 

121 Returns: 

122 Loaded trace object 

123 

124 Raises: 

125 TraceKitError: If trace cannot be loaded 

126 """ 

127 # Determine loader based on extension 

128 ext = path.suffix.lower() 

129 

130 try: 

131 if ext == ".csv": 

132 from tracekit.loaders.csv import ( # type: ignore[import-not-found] 

133 load_csv, # type: ignore[import-not-found] 

134 ) 

135 

136 return load_csv(str(path)) 

137 elif ext == ".bin": 

138 from tracekit.loaders.binary import ( # type: ignore[import-not-found] 

139 load_binary, # type: ignore[import-not-found] 

140 ) 

141 

142 return load_binary(str(path)) 

143 elif ext in (".h5", ".hdf5"): 

144 from tracekit.loaders.hdf5 import ( # type: ignore[import-not-found] 

145 load_hdf5, # type: ignore[import-not-found] 

146 ) 

147 

148 return load_hdf5(str(path)) 

149 else: 

150 raise TraceKitError(f"Unsupported format: {ext}") 

151 

152 except ImportError as e: 

153 raise TraceKitError(f"Loader not available for {ext}: {e}") # noqa: B904 

154 

155 def _iter_traces(self, lazy: bool = False) -> Iterator[tuple[str, Any]]: 

156 """Iterate over traces. 

157 

158 Args: 

159 lazy: If True, load on demand; if False, load all first 

160 

161 Yields: 

162 Tuple of (trace_id, trace) 

163 """ 

164 # Use pre-loaded traces if available 

165 if self._traces: 

166 for i, trace in enumerate(self._traces): 

167 trace_id = ( 

168 self.results.trace_ids[i] if i < len(self.results.trace_ids) else f"trace_{i}" 

169 ) 

170 yield trace_id, trace 

171 return 

172 

173 # Load from files 

174 for path in self._file_paths: 

175 trace_id = path.name 

176 if lazy or self._lazy: 

177 # Load on demand 

178 trace = self._load_trace(path) 

179 else: 

180 # Would load all at once (not implemented here) 

181 trace = self._load_trace(path) 

182 yield trace_id, trace 

183 

184 def align( 

185 self, 

186 method: str = AlignmentMethod.TRIGGER, 

187 channel: int = 0, 

188 threshold: float | None = None, 

189 **kwargs: Any, 

190 ) -> None: 

191 """Align traces using specified method. 

192 

193 Args: 

194 method: Alignment method ('trigger', 'time', 'correlation', 'manual') 

195 channel: Channel to use for alignment (for multi-channel traces) 

196 threshold: Trigger threshold (for trigger alignment) 

197 **kwargs: Additional method-specific parameters 

198 

199 Raises: 

200 TraceKitError: If alignment fails 

201 """ 

202 if method == AlignmentMethod.TRIGGER: 

203 self._align_by_trigger(channel, threshold, **kwargs) 

204 elif method == AlignmentMethod.TIME_SYNC: 

205 self._align_by_time(**kwargs) 

206 elif method == AlignmentMethod.CROSS_CORRELATION: 

207 self._align_by_correlation(channel, **kwargs) 

208 elif method == AlignmentMethod.MANUAL: 

209 self._align_manual(**kwargs) 

210 else: 

211 raise TraceKitError(f"Unknown alignment method: {method}") 

212 

213 self._aligned = True 

214 

215 def _align_by_trigger( 

216 self, 

217 channel: int, 

218 threshold: float | None, 

219 **kwargs: Any, 

220 ) -> None: 

221 """Align traces by trigger point. 

222 

223 Args: 

224 channel: Channel index 

225 threshold: Trigger threshold 

226 **kwargs: Additional parameters 

227 """ 

228 # Find trigger point in each trace 

229 for trace_id, trace in self._iter_traces(lazy=True): 

230 # Find first crossing of threshold 

231 if hasattr(trace, "data"): 

232 data = trace.data 

233 if threshold is None: 

234 # Auto threshold: 50% of max 

235 threshold = 0.5 * (np.max(data) + np.min(data)) 

236 

237 # Find first rising edge 

238 above = data > threshold 

239 edges = np.diff(above.astype(int)) 

240 rising = np.where(edges > 0)[0] 

241 

242 if len(rising) > 0: 

243 self._alignment_offset[trace_id] = int(rising[0]) 

244 else: 

245 self._alignment_offset[trace_id] = 0 

246 else: 

247 self._alignment_offset[trace_id] = 0 

248 

249 def _align_by_time(self, **kwargs: Any) -> None: 

250 """Align traces by timestamp. 

251 

252 Args: 

253 **kwargs: Additional parameters 

254 """ 

255 # Align based on trace timestamps 

256 # Placeholder implementation 

257 for trace_id, _trace in self._iter_traces(lazy=True): 

258 self._alignment_offset[trace_id] = 0 

259 

260 def _align_by_correlation(self, channel: int, **kwargs: Any) -> None: 

261 """Align traces by cross-correlation. 

262 

263 Args: 

264 channel: Channel index 

265 **kwargs: Additional parameters 

266 """ 

267 # Use cross-correlation to find alignment 

268 # Placeholder implementation 

269 for trace_id, _trace in self._iter_traces(lazy=True): 

270 self._alignment_offset[trace_id] = 0 

271 

272 def _align_manual(self, **kwargs: Any) -> None: 

273 """Manual alignment with specified offsets. 

274 

275 Args: 

276 **kwargs: Must include 'offsets' dict mapping trace_id -> offset 

277 

278 Raises: 

279 TraceKitError: If offsets parameter not provided. 

280 """ 

281 offsets = kwargs.get("offsets", {}) 

282 if not offsets: 

283 raise TraceKitError("Manual alignment requires 'offsets' parameter") 

284 

285 self._alignment_offset.update(offsets) 

286 

287 def measure( 

288 self, *measurements: str, parallel: bool = False, max_workers: int | None = None 

289 ) -> None: 

290 """Measure properties across all traces. 

291 

292 Args: 

293 *measurements: Measurement names (rise_time, fall_time, etc.) 

294 parallel: If True, process traces in parallel 

295 max_workers: Maximum parallel workers (None = CPU count) 

296 

297 Raises: 

298 TraceKitError: If measurement fails 

299 """ 

300 if not measurements: 

301 raise TraceKitError("At least one measurement required") 

302 

303 if parallel: 

304 self._measure_parallel(measurements, max_workers) 

305 else: 

306 self._measure_sequential(measurements) 

307 

308 def _measure_sequential(self, measurements: tuple[str, ...]) -> None: 

309 """Measure sequentially.""" 

310 # Progress tracking 

311 progress = create_progress_tracker( # type: ignore[call-arg] 

312 total=len(self.results.trace_ids), 

313 description="Measuring traces", 

314 ) 

315 

316 for trace_id, trace in self._iter_traces(lazy=True): 

317 results = {} 

318 for meas_name in measurements: 

319 try: 

320 # Call measurement function 

321 # Placeholder - would call actual measurement 

322 results[meas_name] = self._perform_measurement(trace, meas_name) 

323 except Exception as e: 

324 results[meas_name] = None 

325 print(f"Warning: {meas_name} failed for {trace_id}: {e}") 

326 

327 self.results.measurements[trace_id] = results 

328 progress.update(1) 

329 

330 def _measure_parallel(self, measurements: tuple[str, ...], max_workers: int | None) -> None: 

331 """Measure in parallel.""" 

332 with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: 

333 futures = {} 

334 

335 for trace_id, trace in self._iter_traces(lazy=False): 

336 future = executor.submit(self._measure_trace, trace, measurements) 

337 futures[future] = trace_id 

338 

339 for future in concurrent.futures.as_completed(futures): 

340 trace_id = futures[future] 

341 try: 

342 results = future.result() 

343 self.results.measurements[trace_id] = results 

344 except Exception as e: 

345 print(f"Error measuring {trace_id}: {e}") 

346 

347 def _measure_trace(self, trace: Any, measurements: tuple[str, ...]) -> dict[str, Any]: 

348 """Measure a single trace. 

349 

350 Args: 

351 trace: Trace object 

352 measurements: Measurement names 

353 

354 Returns: 

355 Dict mapping measurement_name -> value 

356 """ 

357 results = {} 

358 for meas_name in measurements: 

359 try: 

360 results[meas_name] = self._perform_measurement(trace, meas_name) 

361 except Exception: 

362 results[meas_name] = None 

363 return results 

364 

365 def _perform_measurement(self, trace: Any, measurement: str) -> Any: 

366 """Perform a single measurement. 

367 

368 Args: 

369 trace: Trace object 

370 measurement: Measurement name 

371 

372 Raises: 

373 TraceKitError: If measurement not available 

374 """ 

375 # Placeholder - would call actual measurement functions 

376 # from tracekit.analyzers.measurements 

377 raise TraceKitError( 

378 f"Measurement '{measurement}' not yet implemented in multi-trace workflow" 

379 ) 

380 

381 def aggregate(self) -> MultiTraceResults: 

382 """Compute aggregate statistics across traces. 

383 

384 Returns: 

385 Results with statistics 

386 

387 Raises: 

388 TraceKitError: If no measurements available 

389 """ 

390 if not self.results.measurements: 

391 raise TraceKitError("No measurements available. Call measure() first.") 

392 

393 # Compute statistics for each measurement type 

394 all_measurements = set() # type: ignore[var-annotated] 

395 for trace_results in self.results.measurements.values(): 

396 all_measurements.update(trace_results.keys()) 

397 

398 for meas_name in all_measurements: 

399 values = [] 

400 for trace_results in self.results.measurements.values(): 

401 val = trace_results.get(meas_name) 

402 if val is not None and not (isinstance(val, float) and np.isnan(val)): 

403 values.append(float(val)) 

404 

405 if values: 

406 self.results.statistics[meas_name] = TraceStatistics( 

407 mean=float(np.mean(values)), 

408 std=float(np.std(values)), 

409 min=float(np.min(values)), 

410 max=float(np.max(values)), 

411 median=float(np.median(values)), 

412 count=len(values), 

413 ) 

414 

415 return self.results 

416 

417 def export_report(self, filename: str, format: str = "pdf") -> None: 

418 """Export combined report. 

419 

420 Args: 

421 filename: Output filename 

422 format: Report format ('pdf', 'html', 'json') 

423 

424 Raises: 

425 TraceKitError: If export fails 

426 """ 

427 if format == "json": 

428 self._export_json(filename) 

429 elif format == "pdf": 

430 self._export_pdf(filename) 

431 elif format == "html": 

432 self._export_html(filename) 

433 else: 

434 raise TraceKitError(f"Unsupported report format: {format}") 

435 

436 def _export_json(self, filename: str) -> None: 

437 """Export results to JSON.""" 

438 import json 

439 

440 data = { 

441 "trace_ids": self.results.trace_ids, 

442 "measurements": self.results.measurements, 

443 "statistics": { 

444 name: { 

445 "mean": stats.mean, 

446 "std": stats.std, 

447 "min": stats.min, 

448 "max": stats.max, 

449 "median": stats.median, 

450 "count": stats.count, 

451 } 

452 for name, stats in self.results.statistics.items() 

453 }, 

454 "metadata": self.results.metadata, 

455 } 

456 

457 with open(filename, "w") as f: 

458 json.dump(data, f, indent=2) 

459 

460 def _export_pdf(self, filename: str) -> None: 

461 """Export results to PDF. 

462 

463 Args: 

464 filename: Output filename 

465 

466 Raises: 

467 TraceKitError: PDF export not yet implemented 

468 """ 

469 raise TraceKitError("PDF export not yet implemented") 

470 

471 def _export_html(self, filename: str) -> None: 

472 """Export results to HTML. 

473 

474 Args: 

475 filename: Output filename 

476 

477 Raises: 

478 TraceKitError: HTML export not yet implemented 

479 """ 

480 raise TraceKitError("HTML export not yet implemented") 

481 

482 

def load_all(pattern: str, lazy: bool = True) -> list[Any]:
    """Load all traces matching pattern.

    Args:
        pattern: Glob pattern
        lazy: If True, return lazy-loading proxy objects (lazy proxies are
            not yet implemented; the flag is currently ignored)

    Returns:
        List of trace file paths, sorted for deterministic ordering.
        (Placeholder: actual trace loading / lazy proxies are not yet
        implemented.)

    Raises:
        TraceKitError: If no traces found
    """
    paths = glob_func(pattern)  # noqa: PTH207
    if not paths:
        raise TraceKitError(f"No files match pattern: {pattern}")

    # glob() order is OS-dependent; sort so results are deterministic and
    # consistent with MultiTraceWorkflow._discover_files.
    return [Path(p) for p in sorted(paths)]