Coverage for src / tracekit / comparison / visualization.py: 86%

177 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Visualization utilities for trace comparison. 

2 

3This module provides visualization functions for comparing traces including 

4overlay plots, difference plots, and comparison heat maps. 

5 

6 

7Example: 

8 >>> from tracekit.comparison.visualization import ( 

9 ... plot_overlay, 

10 ... plot_difference, 

11 ... plot_comparison_heatmap 

12 ... ) 

13 >>> fig = plot_overlay(trace1, trace2) 

14 >>> fig = plot_difference(trace1, trace2) 

15 

16References: 

17 - Tufte, E. R. (2001). The Visual Display of Quantitative Information 

18""" 

19 

20from __future__ import annotations 

21 

22from typing import TYPE_CHECKING, Any 

23 

24import matplotlib.pyplot as plt 

25import numpy as np 

26from matplotlib.gridspec import GridSpec 

27 

28if TYPE_CHECKING: 

29 from matplotlib.figure import Figure 

30 

31 from tracekit.comparison.compare import ComparisonResult 

32 from tracekit.core.types import WaveformTrace 

33 

34 

def _mask_runs(mask: np.ndarray) -> list[tuple[int, int]]:
    """Return half-open ``(start, stop)`` index pairs for each run of True in *mask*."""
    # Pad with False on both ends so every run has exactly one rising edge
    # (+1) and one falling edge (-1) in the first difference.
    padded = np.concatenate(([False], np.asarray(mask, dtype=bool), [False]))
    edges = np.flatnonzero(np.diff(padded.astype(np.int8)))
    # Edges alternate rising/falling, so pair them up.
    return [(int(start), int(stop)) for start, stop in zip(edges[::2], edges[1::2])]


def plot_overlay(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    labels: tuple[str, str] = ("Trace 1", "Trace 2"),
    title: str = "Trace Comparison - Overlay",
    highlight_differences: bool = True,
    difference_threshold: float | None = None,
    figsize: tuple[float, float] = (10, 6),
    **kwargs: Any,
) -> Figure:
    """Create an overlay plot showing both traces on the same axes.

    Both traces are truncated to the length of the shorter one. When
    ``highlight_differences`` is enabled, each contiguous run of samples whose
    absolute difference exceeds the threshold is shaded red. Only the first
    shaded run gets the "Difference" legend entry.

    Args:
        trace1: First trace. Its ``metadata.sample_rate`` (if present) defines
            the time axis; otherwise the x axis is the sample index.
        trace2: Second trace.
        labels: Legend labels for the two traces.
        title: Plot title.
        highlight_differences: Shade regions where the traces differ.
        difference_threshold: Absolute-difference threshold for shading. If
            ``None``, ``mean(|d|) + 2 * std(|d|)`` of the difference is used.
        figsize: Figure size ``(width, height)``.
        **kwargs: Additional keyword arguments passed to both ``plot()``
            calls. ``alpha`` and ``linewidth`` override the defaults
            (0.7 and 1.5).

    Returns:
        Matplotlib Figure object.

    Example:
        >>> from tracekit.comparison.visualization import plot_overlay
        >>> fig = plot_overlay(measured, reference,
        ...                    labels=("Measured", "Reference"))
        >>> plt.show()

    References:
        CMP-003: Overlay plot with difference highlighting
    """
    fig, ax = plt.subplots(figsize=figsize)

    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    # Compare only the overlapping portion of the two traces.
    min_len = min(len(data1), len(data2))
    data1 = data1[:min_len]
    data2 = data2[:min_len]

    if hasattr(trace1, "metadata") and trace1.metadata.sample_rate is not None:
        time = np.arange(min_len) / trace1.metadata.sample_rate
        xlabel = "Time (s)"
    else:
        time = np.arange(min_len, dtype=np.float64)
        xlabel = "Sample"

    # Defaults first so caller-supplied kwargs may override them instead of
    # raising a duplicate-keyword TypeError.
    line_kwargs = {"alpha": 0.7, "linewidth": 1.5, **kwargs}
    ax.plot(time, data1, label=labels[0], **line_kwargs)
    ax.plot(time, data2, label=labels[1], **line_kwargs)

    # Skip highlighting for empty input: mean/std of an empty array are NaN.
    if highlight_differences and min_len > 0:
        diff = np.abs(data1 - data2)
        if difference_threshold is None:
            difference_threshold = float(np.mean(diff) + 2 * np.std(diff))

        for run_number, (start, stop) in enumerate(_mask_runs(diff > difference_threshold)):
            ax.axvspan(
                time[start],
                time[stop - 1],  # last sample inside the differing run
                alpha=0.2,
                color="red",
                # Label only the first span so the legend shows a single
                # "Difference" entry; "_"-prefixed labels are excluded.
                label="Difference" if run_number == 0 else "_nolegend_",
            )

    ax.set_xlabel(xlabel)
    ax.set_ylabel("Amplitude")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    return fig

134 

135 

def plot_difference(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    title: str = "Trace Comparison - Difference",
    normalize: bool = False,
    show_statistics: bool = True,
    figsize: tuple[float, float] = (10, 6),
    **kwargs: Any,
) -> Figure:
    """Create a difference plot of ``trace1 - trace2``.

    Both traces are truncated to the length of the shorter one. A dashed
    zero line is drawn for reference.

    Args:
        trace1: First trace. Its ``metadata.sample_rate`` (if present) defines
            the time axis; otherwise the x axis is the sample index.
        trace2: Second trace (used as the reference for normalization).
        title: Plot title.
        normalize: Express the difference as a percentage of ``trace2``'s
            peak-to-peak range. If that range is zero (or the traces are
            empty), no normalization is applied. The y label then stays
            "Difference" so it matches the plotted data.
        show_statistics: Show a text box with max |d|, RMS, mean and std of
            the plotted difference. It is omitted for empty input.
        figsize: Figure size.
        **kwargs: Additional keyword arguments passed to ``plot()``. A
            ``label`` entry overrides the default "Difference".

    Returns:
        Matplotlib Figure object.

    Example:
        >>> from tracekit.comparison.visualization import plot_difference
        >>> fig = plot_difference(measured, reference, normalize=True)
        >>> plt.show()

    References:
        CMP-003: Comparison Visualization
    """
    fig, ax = plt.subplots(figsize=figsize)

    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    min_len = min(len(data1), len(data2))
    data1 = data1[:min_len]
    data2 = data2[:min_len]

    diff = data1 - data2

    ylabel = "Difference"
    if normalize:
        # np.ptp raises on an empty array, so treat empty input as zero range.
        ref_range = float(np.ptp(data2)) if min_len > 0 else 0.0
        if ref_range > 0:
            diff = (diff / ref_range) * 100.0
            ylabel = "Difference (%)"
        # Otherwise the data stay in absolute units; labelling them as a
        # percentage would be misleading.

    if hasattr(trace1, "metadata") and trace1.metadata.sample_rate is not None:
        time = np.arange(min_len) / trace1.metadata.sample_rate
        xlabel = "Time (s)"
    else:
        time = np.arange(min_len, dtype=np.float64)
        xlabel = "Sample"

    ax.plot(time, diff, **{"label": "Difference", **kwargs})
    ax.axhline(y=0, color="k", linestyle="--", alpha=0.5, linewidth=1)

    # np.max on an empty array raises, so only summarise non-empty data.
    if show_statistics and diff.size > 0:
        max_diff = float(np.max(np.abs(diff)))
        rms_diff = float(np.sqrt(np.mean(diff**2)))
        mean_diff = float(np.mean(diff))
        std_diff = float(np.std(diff))

        stats_text = (
            f"Max: {max_diff:.3f}\nRMS: {rms_diff:.3f}\nMean: {mean_diff:.3f}\nStd: {std_diff:.3f}"
        )

        ax.text(
            0.02,
            0.98,
            stats_text,
            transform=ax.transAxes,
            verticalalignment="top",
            bbox={"boxstyle": "round", "facecolor": "wheat", "alpha": 0.8},
            fontsize=9,
            family="monospace",
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    return fig

235 

236 

def plot_comparison_heatmap(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    title: str = "Trace Comparison - Difference Heatmap",
    window_size: int = 100,
    figsize: tuple[float, float] = (10, 8),
    **kwargs: Any,
) -> Figure:
    """Create a difference heatmap showing where changes occur.

    The overlapping portion of the traces is split into consecutive windows
    of ``window_size`` samples. Any trailing partial window is left out of
    the heatmap but still appears in the trace panel. Within each window the
    amplitude range spanned by both traces is divided into 10 equal bins.
    Each sample's ``|trace1 - trace2|`` is added to the bin that holds its
    ``trace1`` value. Each column is then divided by ``window_size``, so a
    column sums to the mean absolute difference of that window. Amplitude
    bins are relative to each window's own min/max.

    The heatmap columns are placed on the same x axis (time or sample index)
    as the ``|difference|`` trace drawn underneath, which shares that axis.

    Args:
        trace1: First trace. Its ``metadata.sample_rate`` (if present) defines
            the time axis; otherwise the x axis is the sample index.
        trace2: Second trace.
        title: Plot title.
        window_size: Number of samples per heatmap column (must be >= 1). If
            the traces are shorter than this, a single window covering all
            samples is used.
        figsize: Figure size.
        **kwargs: Additional keyword arguments passed to ``imshow()``. They
            override the defaults (``aspect``, ``cmap``, ``origin``,
            ``interpolation``, ``extent``).

    Returns:
        Matplotlib Figure object.

    Raises:
        ValueError: If ``window_size`` < 1 or the traces contain no samples.

    Example:
        >>> from tracekit.comparison.visualization import plot_comparison_heatmap
        >>> fig = plot_comparison_heatmap(trace1, trace2, window_size=50)
        >>> plt.show()

    References:
        CMP-003: Difference heat map showing where changes occur
    """
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")

    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    min_len = min(len(data1), len(data2))
    # Validate before creating the figure so a failure does not leak one.
    if min_len == 0:
        raise ValueError("Cannot build a comparison heatmap from empty traces")
    data1 = data1[:min_len]
    data2 = data2[:min_len]

    diff = np.abs(data1 - data2)

    n_windows = min_len // window_size
    if n_windows == 0:
        # Trace shorter than one window: use a single window over everything.
        n_windows = 1
        window_size = min_len

    n_bins = 10
    heatmap_data = np.zeros((n_bins, n_windows))
    for i in range(n_windows):
        start = i * window_size
        end = start + window_size  # n_windows * window_size <= min_len
        window_data1 = data1[start:end]
        window_data2 = data2[start:end]
        y_min = min(np.min(window_data1), np.min(window_data2))
        y_max = max(np.max(window_data1), np.max(window_data2))

        if y_max - y_min > 0:
            bin_edges = np.linspace(y_min, y_max, n_bins + 1)
            bin_idx = np.clip(np.digitize(window_data1, bin_edges) - 1, 0, n_bins - 1)
            heatmap_data[:, i] = np.bincount(
                bin_idx, weights=diff[start:end], minlength=n_bins
            )

    heatmap_data /= window_size

    if hasattr(trace1, "metadata") and trace1.metadata.sample_rate is not None:
        sample_rate = trace1.metadata.sample_rate
        time = np.arange(min_len) / sample_rate
        heat_right = n_windows * window_size / sample_rate
        xlabel = "Time (s)"
    else:
        time = np.arange(min_len, dtype=np.float64)
        heat_right = float(n_windows * window_size)
        xlabel = "Sample"

    fig = plt.figure(figsize=figsize)
    gs = GridSpec(2, 1, height_ratios=[3, 1], hspace=0.3)
    ax_heat = fig.add_subplot(gs[0])
    ax_trace = fig.add_subplot(gs[1], sharex=ax_heat)

    imshow_kwargs = {
        "aspect": "auto",
        "cmap": "hot",
        "origin": "lower",
        "interpolation": "nearest",
        # Put the window columns in the trace panel's x units so the shared
        # x axis lines the two panels up. Row centres stay at 0..9.
        "extent": (0.0, heat_right, -0.5, n_bins - 0.5),
        **kwargs,
    }
    im = ax_heat.imshow(heatmap_data, **imshow_kwargs)
    plt.colorbar(im, ax=ax_heat, label="Average Difference")

    ax_heat.set_ylabel("Amplitude Bin")
    ax_heat.set_title(title)

    ax_trace.plot(time, diff, linewidth=0.5)
    ax_trace.set_xlabel(xlabel)
    ax_trace.set_ylabel("Difference")
    ax_trace.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

345 

346 

def plot_comparison_summary(
    result: ComparisonResult,
    *,
    title: str = "Trace Comparison Summary",
    figsize: tuple[float, float] = (12, 8),
) -> Figure:
    """Create a comprehensive comparison summary figure.

    The figure uses a 3x2 grid:

    * row 0: a table of scalar metrics from *result*. The match-status cell
      is coloured light green (match) or light red (no match).
    * row 1: the difference trace against sample index (only when
      ``result.difference_trace`` is not None).
    * row 2, left: a 50-bin histogram of the difference values (same
      condition).
    * row 2, right: tick marks at each index in ``result.violations``, or a
      "No Violations" message.

    Args:
        result: ComparisonResult from compare_traces().
        title: Plot title.
        figsize: Figure size.

    Returns:
        Matplotlib Figure object.

    Example:
        >>> from tracekit.comparison import compare_traces
        >>> from tracekit.comparison.visualization import plot_comparison_summary
        >>> result = compare_traces(trace1, trace2)
        >>> fig = plot_comparison_summary(result)
        >>> plt.show()

    References:
        CMP-003: Summary table of key differences
    """
    fig = plt.figure(figsize=figsize)
    gs = GridSpec(3, 2, hspace=0.4, wspace=0.3)

    # Statistics table
    ax_stats = fig.add_subplot(gs[0, :])
    ax_stats.axis("off")

    stats_data = [
        ["Match Status", "PASS ✓" if result.match else "FAIL ✗"],
        ["Similarity Score", f"{result.similarity:.4f}"],
        ["Correlation", f"{result.correlation:.4f}"],
        ["Max Difference", f"{result.max_difference:.6f}"],
        ["RMS Difference", f"{result.rms_difference:.6f}"],
    ]

    if result.statistics:
        # NOTE(review): assumes a non-empty statistics mapping always has
        # these three keys (a missing key raises KeyError) — confirm against
        # compare_traces().
        stats_data.extend(
            [
                ["Mean Difference", f"{result.statistics['mean_difference']:.6f}"],
                ["Violations", f"{result.statistics['num_violations']}"],
                ["Violation Rate", f"{result.statistics['violation_rate'] * 100:.2f}%"],
            ]
        )

    table = ax_stats.table(
        cellText=stats_data,
        colLabels=["Metric", "Value"],
        cellLoc="left",
        loc="center",
        bbox=[0, 0, 1, 1],  # type: ignore[arg-type]
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Row 0 holds the column labels, so (1, 1) is the "Match Status" value.
    table[(1, 1)].set_facecolor("#90EE90" if result.match else "#FFB6C1")  # green / red

    ax_stats.set_title(title, fontsize=14, fontweight="bold", pad=20)

    diff_trace = result.difference_trace
    if diff_trace is not None:
        diff_data = diff_trace.data

        # Difference trace against sample index
        ax_overlay = fig.add_subplot(gs[1, :])
        ax_overlay.plot(np.arange(len(diff_data)), diff_data, label="Difference")
        ax_overlay.axhline(y=0, color="k", linestyle="--", alpha=0.5)
        ax_overlay.set_xlabel("Sample")
        ax_overlay.set_ylabel("Difference")
        ax_overlay.set_title("Difference Trace")
        ax_overlay.legend()
        ax_overlay.grid(True, alpha=0.3)

        # Histogram of differences
        ax_hist = fig.add_subplot(gs[2, 0])
        ax_hist.hist(diff_data, bins=50, edgecolor="black", alpha=0.7)
        ax_hist.axvline(x=0, color="r", linestyle="--", linewidth=2, label="Zero difference")
        ax_hist.set_xlabel("Difference")
        ax_hist.set_ylabel("Count")
        ax_hist.set_title("Difference Distribution")
        ax_hist.legend()
        ax_hist.grid(True, alpha=0.3)

    # Violation locations
    ax_viol = fig.add_subplot(gs[2, 1])
    violations = result.violations
    if violations is not None and len(violations) > 0:
        ax_viol.scatter(
            violations,
            np.ones_like(violations),
            marker="|",
            s=100,
            color="red",
            alpha=0.5,
        )
        if diff_trace is not None:
            x_max = len(diff_trace.data)
        else:
            # Without a difference trace, keep the historical 1000-sample
            # axis but widen it so later violations are not clipped off.
            x_max = max(1000, int(np.max(violations)) + 1)
        ax_viol.set_xlim(0, x_max)
        ax_viol.set_ylim(0.5, 1.5)
        ax_viol.set_xlabel("Sample Index")
        ax_viol.set_title(f"Violation Locations ({len(violations)} total)")
        ax_viol.set_yticks([])
    else:
        ax_viol.text(
            0.5,
            0.5,
            "No Violations",
            ha="center",
            va="center",
            fontsize=14,
            color="green",
        )
        ax_viol.axis("off")

    plt.tight_layout()
    return fig

474 

475 

# Public plotting API of this module; these are the names exported by
# ``from tracekit.comparison.visualization import *``.
__all__ = [
    "plot_comparison_heatmap", "plot_comparison_summary",
    "plot_difference", "plot_overlay",
]