Coverage for src / tracekit / comparison / golden.py: 99%

139 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Golden waveform comparison for TraceKit. 

2 

3This module provides golden reference waveform management and comparison 

4functions for pass/fail testing against known-good waveforms. 

5 

6 

7Example: 

8 >>> from tracekit.comparison import create_golden, compare_to_golden 

9 >>> golden = create_golden(reference_trace) 

10 >>> result = compare_to_golden(measured_trace, golden) 

11 

12References: 

13 IEEE 181-2011: Standard for Transitional Waveform Definitions 

14""" 

15 

16from __future__ import annotations 

17 

18import json 

19from dataclasses import dataclass, field 

20from datetime import datetime 

21from pathlib import Path 

22from typing import TYPE_CHECKING, Any, Literal 

23 

24import numpy as np 

25 

26from tracekit.core.exceptions import AnalysisError, LoaderError 

27 

28if TYPE_CHECKING: 

29 from numpy.typing import NDArray 

30 

31 from tracekit.core.types import WaveformTrace 

32 

33 

@dataclass
class GoldenReference:
    """Golden reference waveform for comparison.

    Contains a reference waveform with tolerance bounds for pass/fail
    testing of measured waveforms.

    Attributes:
        data: Reference waveform data.
        sample_rate: Sample rate in Hz.
        upper_bound: Upper tolerance bound (one value per sample of ``data``
            when built by this module's factory functions).
        lower_bound: Lower tolerance bound (one value per sample of ``data``
            when built by this module's factory functions).
        tolerance: Tolerance used to create bounds.
        tolerance_type: How tolerance was applied.
        name: Reference name.
        description: Optional description.
        created: Creation timestamp (local, naive ``datetime.now()`` by default).
        metadata: Additional metadata. Must be JSON-serializable for ``save``.
    """

    data: NDArray[np.float64]
    sample_rate: float
    upper_bound: NDArray[np.float64]
    lower_bound: NDArray[np.float64]
    tolerance: float
    tolerance_type: Literal["absolute", "percentage", "sigma"] = "absolute"
    name: str = "golden"
    description: str = ""
    created: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def num_samples(self) -> int:
        """Number of samples in the reference."""
        return len(self.data)

    @property
    def duration(self) -> float:
        """Duration in seconds (``num_samples / sample_rate``)."""
        return self.num_samples / self.sample_rate

    def to_dict(self) -> dict[str, Any]:
        """Convert to a JSON-friendly dictionary.

        Arrays become lists and ``created`` becomes an ISO-8601 string.
        """
        return {
            "data": self.data.tolist(),
            "sample_rate": self.sample_rate,
            "upper_bound": self.upper_bound.tolist(),
            "lower_bound": self.lower_bound.tolist(),
            "tolerance": self.tolerance,
            "tolerance_type": self.tolerance_type,
            "name": self.name,
            "description": self.description,
            "created": self.created.isoformat(),
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> GoldenReference:
        """Create from a dictionary produced by ``to_dict``.

        ``data``, ``sample_rate``, ``upper_bound``, ``lower_bound`` and
        ``tolerance`` are required (``KeyError`` if missing). The remaining
        keys fall back to the dataclass defaults, and ``created`` falls back
        to the current time.
        """
        return cls(
            data=np.array(data["data"], dtype=np.float64),
            sample_rate=data["sample_rate"],
            upper_bound=np.array(data["upper_bound"], dtype=np.float64),
            lower_bound=np.array(data["lower_bound"], dtype=np.float64),
            tolerance=data["tolerance"],
            tolerance_type=data.get("tolerance_type", "absolute"),
            name=data.get("name", "golden"),
            description=data.get("description", ""),
            created=datetime.fromisoformat(data["created"])
            if "created" in data
            else datetime.now(),
            metadata=data.get("metadata", {}),
        )

    def save(self, path: str | Path) -> None:
        """Save golden reference to file.

        Args:
            path: File path. Written as UTF-8 encoded JSON.
        """
        path = Path(path)
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would make files non-portable between machines.
        with path.open("w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str | Path) -> GoldenReference:
        """Load golden reference from file.

        Args:
            path: File path of a UTF-8 JSON file (as written by ``save``).

        Returns:
            GoldenReference instance.

        Raises:
            LoaderError: If golden reference file not found.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a required field is missing (see ``from_dict``).
        """
        path = Path(path)
        if not path.exists():
            raise LoaderError(
                f"Golden reference file not found: {path}",
                file_path=str(path),
            )

        # Read with the same explicit encoding used by ``save``.
        with path.open(encoding="utf-8") as f:
            data = json.load(f)

        return cls.from_dict(data)

142 

143 

@dataclass
class GoldenComparisonResult:
    """Outcome of checking one measured waveform against a golden reference.

    Instances are produced by ``compare_to_golden``. Only the verdict and the
    aggregate deviation figures are required; the remaining fields default to
    ``None`` (or an empty dict) when not supplied.

    Attributes:
        passed: ``True`` when no sample lies outside the tolerance bounds.
        num_violations: Count of samples outside the bounds.
        violation_rate: ``num_violations`` as a fraction of compared samples.
        max_deviation: Largest absolute difference from the reference.
        rms_deviation: Root-mean-square difference from the reference.
        upper_violations: Sample indices above the upper bound, if any.
        lower_violations: Sample indices below the lower bound, if any.
        margin: Smallest distance to either bound.
        margin_percentage: ``margin`` expressed as a percentage of the tolerance.
        statistics: Extra named comparison metrics.
    """

    # Verdict and violation counts.
    passed: bool
    num_violations: int
    violation_rate: float
    # Aggregate deviation from the reference waveform.
    max_deviation: float
    rms_deviation: float
    # Optional per-bound violation indices.
    upper_violations: NDArray[np.int64] | None = None
    lower_violations: NDArray[np.int64] | None = None
    # Optional distance-to-bound figures.
    margin: float | None = None
    margin_percentage: float | None = None
    # Fresh dict per instance via default_factory.
    statistics: dict[str, Any] = field(default_factory=dict)

171 

172 

def create_golden(
    trace: WaveformTrace,
    *,
    tolerance: float | None = None,
    tolerance_pct: float | None = None,
    tolerance_sigma: float | None = None,
    name: str = "golden",
    description: str = "",
) -> GoldenReference:
    """Create a golden reference from a trace.

    Creates a golden reference waveform with constant-width tolerance bounds
    (``data +/- tol``) for subsequent comparison testing.

    Only one tolerance mode is used. If several are given, the precedence is
    ``tolerance`` > ``tolerance_pct`` > ``tolerance_sigma``. If none is given,
    the tolerance defaults to 1% of the trace's peak-to-peak range.

    Args:
        trace: Reference waveform trace.
        tolerance: Absolute tolerance value.
        tolerance_pct: Percentage (0-100) of the trace's peak-to-peak range.
        tolerance_sigma: Tolerance as a multiple of the trace's standard deviation.
        name: Name for the reference.
        description: Optional description.

    Returns:
        GoldenReference for comparison testing.

    Raises:
        ValueError: If the resolved tolerance is negative or not finite.
            This happens for a negative argument, or for NaN samples when the
            tolerance is derived from the data. Either would yield bounds
            that reject, or silently accept, every sample.

    Example:
        >>> golden = create_golden(trace, tolerance_pct=5)  # 5% tolerance
        >>> golden = create_golden(trace, tolerance=0.01)  # 10mV tolerance
    """
    data = trace.data.astype(np.float64)

    # Determine tolerance and type
    tol_type: Literal["absolute", "percentage", "sigma"]
    if tolerance is not None:
        tol = float(tolerance)
        tol_type = "absolute"
    elif tolerance_pct is not None:
        # Calculate absolute tolerance from percentage of peak-to-peak range
        tol = float(np.ptp(data)) * tolerance_pct / 100.0
        tol_type = "percentage"
    elif tolerance_sigma is not None:
        # Calculate tolerance from standard deviation
        tol = float(np.std(data)) * tolerance_sigma
        tol_type = "sigma"
    else:
        # Default: 1% of range
        tol = float(np.ptp(data)) * 0.01
        tol_type = "percentage"

    # A negative tolerance inverts the bounds (every sample fails). A NaN
    # tolerance makes every bound comparison False (every sample passes).
    if not np.isfinite(tol) or tol < 0:
        raise ValueError(f"Tolerance must be finite and non-negative, got {tol}")

    # Create bounds
    upper_bound = data + tol
    lower_bound = data - tol

    return GoldenReference(
        data=data,
        sample_rate=trace.metadata.sample_rate,
        upper_bound=upper_bound,
        lower_bound=lower_bound,
        tolerance=tol,
        tolerance_type=tol_type,
        name=name,
        description=description,
        metadata={
            "source_file": trace.metadata.source_file,
            "channel_name": trace.metadata.channel_name,
        },
    )

241 

242 

def tolerance_envelope(
    trace: WaveformTrace,
    *,
    absolute: float | None = None,
    percentage: float | None = None,
    sigma: float | None = None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Create tolerance envelope around a trace.

    Generates constant-width upper and lower bounds (``data +/- tol``) based
    on the specified tolerance. If several tolerance modes are given, the
    precedence is ``absolute`` > ``percentage`` > ``sigma``.

    Args:
        trace: Reference trace.
        absolute: Absolute tolerance value.
        percentage: Percentage (0-100) of the trace's peak-to-peak range.
        sigma: Tolerance as a multiple of the trace's standard deviation.

    Returns:
        Tuple of (upper_bound, lower_bound) arrays.

    Raises:
        ValueError: If no tolerance type is specified, or if the resolved
            tolerance is negative or not finite (e.g. a negative argument, or
            NaN samples when the tolerance is derived from the data).

    Example:
        >>> upper, lower = tolerance_envelope(trace, percentage=5)
    """
    data = trace.data.astype(np.float64)

    if absolute is not None:
        tol = float(absolute)
    elif percentage is not None:
        tol = float(np.ptp(data)) * percentage / 100.0
    elif sigma is not None:
        tol = float(np.std(data)) * sigma
    else:
        raise ValueError("Must specify absolute, percentage, or sigma tolerance")

    # Negative tolerance would swap the envelope; NaN would make it empty.
    if not np.isfinite(tol) or tol < 0:
        raise ValueError(f"Tolerance must be finite and non-negative, got {tol}")

    return data + tol, data - tol

282 

283 

def compare_to_golden(
    trace: WaveformTrace,
    golden: GoldenReference,
    *,
    align: bool = True,
    interpolate: bool = True,
) -> GoldenComparisonResult:
    """Compare a trace to a golden reference.

    Tests if the measured trace falls within the tolerance bounds of the
    golden reference. Processing happens in two stages:

    1. A length mismatch is resolved first, by resampling or truncation.
    2. The measured data is then optionally time-aligned to the reference.

    Args:
        trace: Measured trace to compare.
        golden: Golden reference to compare against.
        align: Attempt to align traces by cross-correlation (only for more
            than 10 samples). The measured data is shifted circularly
            (``np.roll``), so samples pushed off one end reappear at the
            other. Shifts of a quarter of the record length or more are
            treated as unreliable and ignored.
        interpolate: If sample counts differ, linearly resample the measured
            trace onto the reference length (both are assumed to span the same
            time window). Otherwise, truncate both to the shorter length.

    Returns:
        GoldenComparisonResult with pass/fail status. ``margin`` is the
        smallest distance to either bound and is negative when a bound is
        crossed. ``margin_percentage`` is ``None`` when the golden tolerance
        is not positive.

    Example:
        >>> result = compare_to_golden(measured, golden)
        >>> if result.passed:
        ...     print("PASS")
    """
    measured = trace.data.astype(np.float64)
    reference = golden.data.copy()
    upper = golden.upper_bound.copy()
    lower = golden.lower_bound.copy()

    # Handle length mismatch
    if len(measured) != len(reference):
        if interpolate:
            # Interpolate measured to match reference length; both records are
            # mapped onto a normalized [0, 1] time axis.
            x_measured = np.linspace(0, 1, len(measured))
            x_reference = np.linspace(0, 1, len(reference))
            measured = np.interp(x_reference, x_measured, measured)
        else:
            # Truncate to shorter length
            min_len = min(len(measured), len(reference))
            measured = measured[:min_len]
            reference = reference[:min_len]
            upper = upper[:min_len]
            lower = lower[:min_len]

    # Optionally align by cross-correlation
    if align and len(measured) > 10:
        from scipy import signal as sp_signal

        n = len(measured)
        corr = sp_signal.correlate(measured, reference, mode="same")
        # In "same" mode, output index n // 2 is zero lag. A positive lag means
        # the measured waveform is delayed relative to the reference
        # (measured[i] ~ reference[i - lag]).
        lag = int(np.argmax(corr)) - n // 2
        if abs(lag) < n // 4:  # Only shift if reasonable
            # Advance measured by `lag` samples to cancel the delay.
            measured = np.roll(measured, -lag)

    # Find violations
    upper_viol = np.where(measured > upper)[0]
    lower_viol = np.where(measured < lower)[0]
    all_violations = np.union1d(upper_viol, lower_viol)

    num_violations = len(all_violations)
    violation_rate = num_violations / len(measured) if len(measured) > 0 else 0.0

    # Compute deviation statistics
    deviation = measured - reference
    max_deviation = float(np.max(np.abs(deviation)))
    rms_deviation = float(np.sqrt(np.mean(deviation**2)))

    # Compute margin (negative when a bound is crossed)
    upper_margin = float(np.min(upper - measured))
    lower_margin = float(np.min(measured - lower))
    margin = min(upper_margin, lower_margin)

    # Margin as percentage of tolerance
    margin_pct = (margin / golden.tolerance * 100) if golden.tolerance > 0 else None

    # Additional statistics
    # Handle constant data (zero std) for correlation calculation
    measured_std = np.std(measured)
    reference_std = np.std(reference)
    if measured_std == 0 or reference_std == 0:
        # For constant data, correlation is undefined (NaN) or 1.0 if both are equal
        if np.allclose(measured, reference):
            correlation = 1.0
        else:
            correlation = float("nan")
    else:
        correlation = float(np.corrcoef(measured, reference)[0, 1])

    statistics = {
        "mean_deviation": float(np.mean(deviation)),
        "std_deviation": float(np.std(deviation)),
        "max_positive_deviation": float(np.max(deviation)),
        "max_negative_deviation": float(np.min(deviation)),
        "correlation": correlation,
    }

    return GoldenComparisonResult(
        passed=num_violations == 0,
        num_violations=num_violations,
        violation_rate=violation_rate,
        max_deviation=max_deviation,
        rms_deviation=rms_deviation,
        upper_violations=upper_viol if len(upper_viol) > 0 else None,
        lower_violations=lower_viol if len(lower_viol) > 0 else None,
        margin=margin,
        margin_percentage=margin_pct,
        statistics=statistics,
    )

393 

394 

def batch_compare_to_golden(
    traces: list[WaveformTrace],
    golden: GoldenReference,
    *,
    align: bool = True,
    interpolate: bool = True,
) -> list[GoldenComparisonResult]:
    """Compare multiple traces to a golden reference.

    Tests a batch of measured traces against the same golden reference by
    calling ``compare_to_golden`` once per trace, in order.

    Args:
        traces: List of traces to compare.
        golden: Golden reference.
        align: Attempt to align traces (forwarded to ``compare_to_golden``).
        interpolate: Interpolate when sample counts differ, otherwise
            truncate (forwarded to ``compare_to_golden``).

    Returns:
        List of comparison results, one per input trace, in input order.

    Example:
        >>> results = batch_compare_to_golden(traces, golden)
        >>> pass_rate = sum(r.passed for r in results) / len(results)
    """
    return [
        compare_to_golden(trace, golden, align=align, interpolate=interpolate)
        for trace in traces
    ]

418 

419 

def golden_from_average(
    traces: list[WaveformTrace],
    *,
    tolerance_sigma: float = 3.0,
    name: str = "averaged_golden",
) -> GoldenReference:
    """Create golden reference from averaged traces.

    Creates a golden reference from the sample-wise average of multiple
    traces, with tolerance based on the sample-wise standard deviation:

    - Traces are truncated to the shortest one.
    - The bounds are per-sample: ``mean +/- tolerance_sigma * std``, with
      ``std`` being NumPy's default population (``ddof=0``) standard
      deviation.
    - The scalar ``tolerance`` stored on the result is the largest of those
      per-sample tolerances.
    - The sample rate is taken from the first trace.

    Args:
        traces: List of traces to average.
        tolerance_sigma: Number of standard deviations for tolerance.
        name: Name for the reference.

    Returns:
        GoldenReference based on averaged data.

    Raises:
        AnalysisError: If no traces are provided for averaging, or if the
            shortest trace has no samples.
        ValueError: If ``tolerance_sigma`` is negative or not finite.

    Example:
        >>> golden = golden_from_average(sample_traces, tolerance_sigma=3)
    """
    if not traces:
        raise AnalysisError("No traces provided for averaging")
    # A negative multiplier would swap the bounds; NaN would disable them.
    if not np.isfinite(tolerance_sigma) or tolerance_sigma < 0:
        raise ValueError(
            f"tolerance_sigma must be finite and non-negative, got {tolerance_sigma}"
        )

    # Get common length
    min_len = min(len(t.data) for t in traces)
    if min_len == 0:
        raise AnalysisError("Cannot average traces with no samples")

    # Stack and average
    stacked = np.array([t.data[:min_len] for t in traces], dtype=np.float64)
    avg_data = np.mean(stacked, axis=0)
    std_data = np.std(stacked, axis=0)

    # Per-sample tolerance from standard deviation
    tolerance = std_data * tolerance_sigma

    # Bounds use the per-sample tolerance. The scalar summary stored on the
    # reference is its maximum, i.e. the widest half-band.
    max_tol = float(np.max(tolerance))

    return GoldenReference(
        data=avg_data,
        sample_rate=traces[0].metadata.sample_rate,
        upper_bound=avg_data + tolerance,
        lower_bound=avg_data - tolerance,
        tolerance=max_tol,
        tolerance_type="sigma",
        name=name,
        description=f"Averaged from {len(traces)} traces, {tolerance_sigma} sigma tolerance",
        metadata={
            "num_traces_averaged": len(traces),
            "tolerance_sigma": tolerance_sigma,
        },
    )