Coverage for src / tracekit / comparison / golden.py: 99%
139 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Golden waveform comparison for TraceKit.
3This module provides golden reference waveform management and comparison
4functions for pass/fail testing against known-good waveforms.
7Example:
8 >>> from tracekit.comparison import create_golden, compare_to_golden
9 >>> golden = create_golden(reference_trace)
10 >>> result = compare_to_golden(measured_trace, golden)
12References:
13 IEEE 181-2011: Standard for Transitional Waveform Definitions
14"""
16from __future__ import annotations
18import json
19from dataclasses import dataclass, field
20from datetime import datetime
21from pathlib import Path
22from typing import TYPE_CHECKING, Any, Literal
24import numpy as np
26from tracekit.core.exceptions import AnalysisError, LoaderError
28if TYPE_CHECKING:
29 from numpy.typing import NDArray
31 from tracekit.core.types import WaveformTrace
@dataclass
class GoldenReference:
    """Golden reference waveform for comparison.

    Contains a reference waveform with tolerance bounds for pass/fail
    testing of measured waveforms.

    Attributes:
        data: Reference waveform data.
        sample_rate: Sample rate in Hz.
        upper_bound: Upper tolerance bound (one value per sample).
        lower_bound: Lower tolerance bound (one value per sample).
        tolerance: Scalar tolerance used to create the bounds.
        tolerance_type: How tolerance was applied.
        name: Reference name.
        description: Optional description.
        created: Creation timestamp (naive local time when defaulted).
        metadata: Additional metadata. Must be JSON-serializable for ``save``.
    """

    data: NDArray[np.float64]
    sample_rate: float
    upper_bound: NDArray[np.float64]
    lower_bound: NDArray[np.float64]
    tolerance: float
    tolerance_type: Literal["absolute", "percentage", "sigma"] = "absolute"
    name: str = "golden"
    description: str = ""
    created: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def num_samples(self) -> int:
        """Number of samples in the reference."""
        return len(self.data)

    @property
    def duration(self) -> float:
        """Duration in seconds (``num_samples / sample_rate``)."""
        return self.num_samples / self.sample_rate

    def to_dict(self) -> dict[str, Any]:
        """Convert to a JSON-serializable dictionary.

        Arrays become lists and the timestamp becomes an ISO-8601 string.
        ``sample_rate`` and ``tolerance`` are coerced to built-in ``float``
        because NumPy scalar types such as ``np.float32`` are not accepted by
        :mod:`json`. ``metadata`` is passed through unchanged.

        Returns:
            Dictionary suitable for :func:`json.dump` and :meth:`from_dict`.
        """
        return {
            "data": self.data.tolist(),
            "sample_rate": float(self.sample_rate),
            "upper_bound": self.upper_bound.tolist(),
            "lower_bound": self.lower_bound.tolist(),
            "tolerance": float(self.tolerance),
            "tolerance_type": self.tolerance_type,
            "name": self.name,
            "description": self.description,
            "created": self.created.isoformat(),
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> GoldenReference:
        """Create from a dictionary produced by :meth:`to_dict`.

        Args:
            data: Mapping with at least ``data``, ``sample_rate``,
                ``upper_bound``, ``lower_bound`` and ``tolerance``. Other keys
                fall back to the dataclass defaults; a missing ``created``
                uses the current time.

        Returns:
            Reconstructed GoldenReference.
        """
        return cls(
            data=np.array(data["data"], dtype=np.float64),
            sample_rate=float(data["sample_rate"]),
            upper_bound=np.array(data["upper_bound"], dtype=np.float64),
            lower_bound=np.array(data["lower_bound"], dtype=np.float64),
            tolerance=float(data["tolerance"]),
            tolerance_type=data.get("tolerance_type", "absolute"),
            name=data.get("name", "golden"),
            description=data.get("description", ""),
            created=datetime.fromisoformat(data["created"])
            if "created" in data
            else datetime.now(),
            metadata=data.get("metadata", {}),
        )

    def save(self, path: str | Path) -> None:
        """Save golden reference to file.

        Args:
            path: File path (JSON format). Existing files are overwritten.
        """
        path = Path(path)
        with path.open("w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str | Path) -> GoldenReference:
        """Load golden reference from file.

        Args:
            path: File path of a JSON file written by :meth:`save`.

        Returns:
            GoldenReference instance.

        Raises:
            LoaderError: If golden reference file not found.
        """
        path = Path(path)
        if not path.exists():
            raise LoaderError(
                f"Golden reference file not found: {path}",
                file_path=str(path),
            )

        with path.open(encoding="utf-8") as f:
            data = json.load(f)

        return cls.from_dict(data)
@dataclass
class GoldenComparisonResult:
    """Outcome of checking a measured waveform against a golden reference.

    Attributes:
        passed: Whether no compared sample fell outside the tolerance bounds.
        num_violations: How many samples lay above the upper or below the
            lower bound.
        violation_rate: ``num_violations`` as a fraction of compared samples.
        max_deviation: Largest absolute difference from the reference.
        rms_deviation: Root-mean-square difference from the reference.
        upper_violations: Indices of samples above the upper bound, or
            ``None`` when not provided.
        lower_violations: Indices of samples below the lower bound, or
            ``None`` when not provided.
        margin: Smallest distance to the nearer tolerance bound, if computed.
        margin_percentage: ``margin`` relative to the tolerance, in percent,
            if computed.
        statistics: Free-form extra statistics keyed by name.
    """

    passed: bool
    num_violations: int
    violation_rate: float
    max_deviation: float
    rms_deviation: float
    upper_violations: NDArray[np.int64] | None = field(default=None)
    lower_violations: NDArray[np.int64] | None = field(default=None)
    margin: float | None = field(default=None)
    margin_percentage: float | None = field(default=None)
    # default_factory gives every result its own dict (never shared).
    statistics: dict[str, Any] = field(default_factory=dict)
def create_golden(
    trace: WaveformTrace,
    *,
    tolerance: float | None = None,
    tolerance_pct: float | None = None,
    tolerance_sigma: float | None = None,
    name: str = "golden",
    description: str = "",
) -> GoldenReference:
    """Create a golden reference from a trace.

    Creates a golden reference waveform with constant-width tolerance bounds
    (``data + tol`` / ``data - tol``) for subsequent comparison testing.

    Only one tolerance specification is honored. If several are given, the
    precedence is ``tolerance``, then ``tolerance_pct``, then
    ``tolerance_sigma``. If none is given, 1% of the peak-to-peak range is
    used and recorded as a "percentage" tolerance.

    Args:
        trace: Reference waveform trace.
        tolerance: Absolute tolerance value.
        tolerance_pct: Percentage tolerance (0-100) of the peak-to-peak range.
        tolerance_sigma: Tolerance as multiple of standard deviation.
        name: Name for the reference.
        description: Optional description.

    Returns:
        GoldenReference for comparison testing.

    Raises:
        ValueError: If the resolved tolerance is negative (inverted bounds)
            or NaN (NaN bounds can never be violated, so every comparison
            would silently pass).

    Example:
        >>> golden = create_golden(trace, tolerance_pct=5)  # 5% tolerance
        >>> golden = create_golden(trace, tolerance=0.01)  # 10mV tolerance
    """
    data = trace.data.astype(np.float64)

    # Determine tolerance and type. float() keeps the stored value a plain
    # Python float (JSON-serializable) even if NumPy scalars are passed in.
    tol_type: Literal["absolute", "percentage", "sigma"]
    if tolerance is not None:
        tol = float(tolerance)
        tol_type = "absolute"
    elif tolerance_pct is not None:
        tol = float(np.ptp(data) * tolerance_pct / 100.0)
        tol_type = "percentage"
    elif tolerance_sigma is not None:
        tol = float(np.std(data) * tolerance_sigma)
        tol_type = "sigma"
    else:
        # Default: 1% of range
        tol = float(np.ptp(data) * 0.01)
        tol_type = "percentage"

    if np.isnan(tol) or tol < 0:
        raise ValueError(f"Tolerance must be a non-negative number, got {tol}")

    return GoldenReference(
        data=data,
        sample_rate=trace.metadata.sample_rate,
        upper_bound=data + tol,
        lower_bound=data - tol,
        tolerance=tol,
        tolerance_type=tol_type,
        name=name,
        description=description,
        metadata={
            "source_file": trace.metadata.source_file,
            "channel_name": trace.metadata.channel_name,
        },
    )
def tolerance_envelope(
    trace: WaveformTrace,
    *,
    absolute: float | None = None,
    percentage: float | None = None,
    sigma: float | None = None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Create tolerance envelope around a trace.

    Generates constant-width upper and lower bounds around the trace data.
    If several tolerance options are given, the precedence is ``absolute``,
    then ``percentage``, then ``sigma``.

    Args:
        trace: Reference trace.
        absolute: Absolute tolerance value.
        percentage: Percentage tolerance (0-100) of the peak-to-peak range.
        sigma: Tolerance as multiple of standard deviation.

    Returns:
        Tuple of (upper_bound, lower_bound) arrays.

    Raises:
        ValueError: If no tolerance type is specified, or if the resolved
            tolerance is negative (inverted envelope) or NaN.

    Example:
        >>> upper, lower = tolerance_envelope(trace, percentage=5)
    """
    data = trace.data.astype(np.float64)

    if absolute is not None:
        tol = float(absolute)
    elif percentage is not None:
        tol = float(np.ptp(data) * percentage / 100.0)
    elif sigma is not None:
        tol = float(np.std(data) * sigma)
    else:
        raise ValueError("Must specify absolute, percentage, or sigma tolerance")

    if np.isnan(tol) or tol < 0:
        raise ValueError(f"Tolerance must be a non-negative number, got {tol}")

    return data + tol, data - tol
def compare_to_golden(
    trace: WaveformTrace,
    golden: GoldenReference,
    *,
    align: bool = True,
    interpolate: bool = True,
) -> GoldenComparisonResult:
    """Compare a trace to a golden reference.

    Tests if the measured trace falls within the tolerance bounds
    of the golden reference.

    Args:
        trace: Measured trace to compare.
        golden: Golden reference to compare against.
        align: Attempt to align traces by cross-correlation (only for more
            than 10 samples). When the estimated lag is shorter than a
            quarter of the record, the measured data is circularly shifted
            (``np.roll``) to undo it. Samples shifted off one end wrap
            around to the other.
        interpolate: If sample counts differ, linearly resample the measured
            data onto the reference's sample count, treating both records as
            spanning the same interval. If False, all arrays are truncated
            to the shorter length instead.

    Returns:
        GoldenComparisonResult with pass/fail status.

    Raises:
        ValueError: If the measured trace or the golden reference has no
            samples.

    Example:
        >>> result = compare_to_golden(measured, golden)
        >>> if result.passed:
        ...     print("PASS")
    """
    measured = trace.data.astype(np.float64)
    reference = golden.data.copy()
    upper = golden.upper_bound.copy()
    lower = golden.lower_bound.copy()

    if len(measured) == 0 or len(reference) == 0:
        raise ValueError("Cannot compare an empty waveform to a golden reference")

    # Handle length mismatch
    if len(measured) != len(reference):
        if interpolate:
            # Resample measured onto the reference grid (normalized time axis)
            x_measured = np.linspace(0, 1, len(measured))
            x_reference = np.linspace(0, 1, len(reference))
            measured = np.interp(x_reference, x_measured, measured)
        else:
            # Truncate to shorter length
            min_len = min(len(measured), len(reference))
            measured = measured[:min_len]
            reference = reference[:min_len]
            upper = upper[:min_len]
            lower = lower[:min_len]

    # Optionally align by cross-correlation
    if align and len(measured) > 10:
        from scipy import signal as sp_signal

        # Correlate mean-removed signals so the peak reflects waveform shape;
        # with a DC offset the zero-padded correlation is biased toward the
        # lag with the largest overlap rather than the true delay.
        corr = sp_signal.correlate(
            measured - np.mean(measured),
            reference - np.mean(reference),
            mode="same",
        )
        lags = sp_signal.correlation_lags(len(measured), len(reference), mode="same")
        # A positive lag means measured[n] ~= reference[n - lag] (measured is
        # delayed), so rolling by -lag moves it back onto the reference.
        lag = int(lags[np.argmax(corr)])
        if abs(lag) < len(measured) // 4:  # Only shift if reasonable
            measured = np.roll(measured, -lag)

    # Find violations
    upper_viol = np.where(measured > upper)[0]
    lower_viol = np.where(measured < lower)[0]
    all_violations = np.union1d(upper_viol, lower_viol)

    num_violations = len(all_violations)
    # len(measured) > 0 is guaranteed by the empty-input check above.
    violation_rate = num_violations / len(measured)

    # Compute deviation statistics
    deviation = measured - reference
    max_deviation = float(np.max(np.abs(deviation)))
    rms_deviation = float(np.sqrt(np.mean(deviation**2)))

    # Margin: distance to the nearer bound; negative when a bound is crossed
    upper_margin = float(np.min(upper - measured))
    lower_margin = float(np.min(measured - lower))
    margin = min(upper_margin, lower_margin)

    # Margin as percentage of tolerance
    margin_pct = (margin / golden.tolerance * 100) if golden.tolerance > 0 else None

    # Correlation coefficient is undefined for constant data (zero std):
    # report 1.0 when the constant signals match, NaN otherwise.
    measured_std = np.std(measured)
    reference_std = np.std(reference)
    if measured_std == 0 or reference_std == 0:
        if np.allclose(measured, reference):
            correlation = 1.0
        else:
            correlation = float("nan")
    else:
        correlation = float(np.corrcoef(measured, reference)[0, 1])

    statistics = {
        "mean_deviation": float(np.mean(deviation)),
        "std_deviation": float(np.std(deviation)),
        "max_positive_deviation": float(np.max(deviation)),
        "max_negative_deviation": float(np.min(deviation)),
        "correlation": correlation,
    }

    return GoldenComparisonResult(
        passed=num_violations == 0,
        num_violations=num_violations,
        violation_rate=violation_rate,
        max_deviation=max_deviation,
        rms_deviation=rms_deviation,
        upper_violations=upper_viol if len(upper_viol) > 0 else None,
        lower_violations=lower_viol if len(lower_viol) > 0 else None,
        margin=margin,
        margin_percentage=margin_pct,
        statistics=statistics,
    )
def batch_compare_to_golden(
    traces: list[WaveformTrace],
    golden: GoldenReference,
    *,
    align: bool = True,
    interpolate: bool = True,
) -> list[GoldenComparisonResult]:
    """Compare multiple traces to a golden reference.

    Tests a batch of measured traces against the same golden reference by
    calling :func:`compare_to_golden` on each one, in order.

    Args:
        traces: List of traces to compare.
        golden: Golden reference.
        align: Attempt to align traces (forwarded to ``compare_to_golden``).
        interpolate: Interpolate when sample counts differ (forwarded to
            ``compare_to_golden``); if False, traces are truncated instead.

    Returns:
        List of comparison results, one per input trace.

    Example:
        >>> results = batch_compare_to_golden(traces, golden)
        >>> pass_rate = sum(r.passed for r in results) / len(results)
    """
    return [
        compare_to_golden(trace, golden, align=align, interpolate=interpolate)
        for trace in traces
    ]
def golden_from_average(
    traces: list[WaveformTrace],
    *,
    tolerance_sigma: float = 3.0,
    name: str = "averaged_golden",
) -> GoldenReference:
    """Build a golden reference from the sample-wise mean of several traces.

    All traces are truncated to the length of the shortest one. The
    reference is the per-sample mean. The bounds are the mean plus/minus
    ``tolerance_sigma`` times the per-sample standard deviation, so their
    width varies from sample to sample. The scalar ``tolerance`` field
    records the widest of those per-sample tolerances.

    Args:
        traces: List of traces to average.
        tolerance_sigma: Number of standard deviations for tolerance.
        name: Name for the reference.

    Returns:
        GoldenReference based on averaged data (sample rate taken from the
        first trace).

    Raises:
        AnalysisError: If no traces provided for averaging.

    Example:
        >>> golden = golden_from_average(sample_traces, tolerance_sigma=3)
    """
    if not traces:
        raise AnalysisError("No traces provided for averaging")

    # Cut every record to the shortest one so the rows stack into a 2-D array.
    common_len = min(len(tr.data) for tr in traces)
    rows = np.array([tr.data[:common_len] for tr in traces], dtype=np.float64)

    mean_wave = rows.mean(axis=0)
    per_sample_tol = rows.std(axis=0) * tolerance_sigma

    # The bounds use the per-sample tolerance; the scalar field keeps the widest.
    widest_tol = float(per_sample_tol.max())
    count = len(traces)

    return GoldenReference(
        data=mean_wave,
        sample_rate=traces[0].metadata.sample_rate,
        upper_bound=mean_wave + per_sample_tol,
        lower_bound=mean_wave - per_sample_tol,
        tolerance=widest_tol,
        tolerance_type="sigma",
        name=name,
        description=f"Averaged from {count} traces, {tolerance_sigma} sigma tolerance",
        metadata={
            "num_traces_averaged": count,
            "tolerance_sigma": tolerance_sigma,
        },
    )