Coverage for src / tracekit / comparison / compare.py: 98%

133 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Trace comparison functions for TraceKit. 

2 

3This module provides functions for comparing waveform traces including 

4difference calculation, correlation, and similarity scoring. 

5 

6 

7Example: 

8 >>> from tracekit.comparison import compare_traces, similarity_score 

9 >>> result = compare_traces(trace1, trace2) 

10 >>> score = similarity_score(trace1, trace2) 

11 

12References: 

13 IEEE 181-2011: Standard for Transitional Waveform Definitions 

14""" 

15 

16from __future__ import annotations 

17 

18import warnings 

19from dataclasses import dataclass 

20from typing import TYPE_CHECKING, Literal 

21 

22import numpy as np 

23from scipy import signal as sp_signal 

24from scipy import stats 

25 

26from tracekit.core.types import TraceMetadata, WaveformTrace 

27 

28if TYPE_CHECKING: 

29 from numpy.typing import NDArray 

30 

31 

@dataclass
class ComparisonResult:
    """Result of a trace comparison operation.

    Attributes:
        match: True if traces are considered matching.
        similarity: Similarity score (0.0 to 1.0).
        max_difference: Maximum absolute difference.
        rms_difference: RMS of the difference.
        correlation: Correlation coefficient.
        difference_trace: Difference waveform (optional).
        violations: Indices where difference exceeds threshold.
        statistics: Additional comparison statistics.
    """

    match: bool
    similarity: float
    max_difference: float
    rms_difference: float
    correlation: float
    # None when the caller did not request a difference trace.
    difference_trace: WaveformTrace | None = None
    # Sample indices exceeding tolerance; None when nothing exceeded it.
    violations: NDArray[np.int64] | None = None
    # Keyed by statistic name (e.g. "mean_difference"); values are numeric.
    statistics: dict[str, float] | None = None

55 

56 

def difference(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    normalize: bool = False,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Compute the element-wise difference (trace1 - trace2).

    Traces of unequal length are truncated to the shorter one before
    subtracting.

    Args:
        trace1: First trace.
        trace2: Second trace.
        normalize: Express the difference as a percentage of the
            reference (trace2) peak-to-peak range.
        channel_name: Name for the result trace; defaults to "difference".

    Returns:
        WaveformTrace containing the difference.

    Raises:
        ValueError: If input traces contain NaN or Inf values.

    Example:
        >>> diff = difference(measured, reference)
        >>> max_error = np.max(np.abs(diff.data))
    """
    samples_a = trace1.data.astype(np.float64)
    samples_b = trace2.data.astype(np.float64)

    # Reject non-finite input up front; a NaN anywhere would poison the result.
    if not (np.all(np.isfinite(samples_a)) and np.all(np.isfinite(samples_b))):
        raise ValueError("Input traces contain NaN or Inf values")

    # Truncate both traces to the common length, then subtract.
    common = min(len(samples_a), len(samples_b))
    samples_a = samples_a[:common]
    samples_b = samples_b[:common]
    delta = samples_a - samples_b

    if normalize:
        # Scale to percent of the reference trace's peak-to-peak range;
        # a flat reference (zero range) leaves the difference unscaled.
        reference_span = np.ptp(samples_b)
        if reference_span > 0:
            delta = (delta / reference_span) * 100.0

    result_metadata = TraceMetadata(
        sample_rate=trace1.metadata.sample_rate,
        vertical_scale=None,
        vertical_offset=None,
        acquisition_time=trace1.metadata.acquisition_time,
        trigger_info=trace1.metadata.trigger_info,
        source_file=trace1.metadata.source_file,
        channel_name=channel_name or "difference",
    )

    return WaveformTrace(data=delta, metadata=result_metadata)

118 

119 

def correlation(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    mode: Literal["full", "same", "valid"] = "same",
    normalize: bool = True,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute cross-correlation between two traces.

    Calculates the cross-correlation of two waveforms, useful for
    finding time delays and pattern matching.

    Args:
        trace1: First trace.
        trace2: Second trace.
        mode: Correlation mode:
            - "full": Full correlation (length N+M-1)
            - "same": Same length as trace1
            - "valid": Only overlapping region
        normalize: Normalize to correlation coefficient (-1 to 1).

    Returns:
        Tuple of (lags, correlation_values), both the same length.

    Example:
        >>> lags, corr = correlation(trace1, trace2)
        >>> delay = lags[np.argmax(corr)]
    """
    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    if normalize:
        # Standardize inputs so the peak approximates a correlation
        # coefficient; the epsilon guards against zero variance.
        data1 = (data1 - np.mean(data1)) / (np.std(data1) + 1e-10)
        data2 = (data2 - np.mean(data2)) / (np.std(data2) + 1e-10)

    # Compute cross-correlation
    corr = sp_signal.correlate(data1, data2, mode=mode)

    if normalize:
        # Normalize by length for correlation coefficient
        corr = corr / len(data1)

    # Lag axis in samples. Use scipy's own lag computation so the lags
    # line up exactly with how correlate() centers/crops each mode.
    # (The previous hand-rolled arange for "same" mode used -len//2
    # arithmetic that was off by one for odd lengths and wrong for
    # unequal-length inputs.)
    lags = sp_signal.correlation_lags(len(data1), len(data2), mode=mode)

    return lags.astype(np.float64), corr

172 

173 

def similarity_score(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    method: Literal["correlation", "rms", "mse", "cosine"] = "correlation",
    normalize_amplitude: bool = True,
    normalize_offset: bool = True,
) -> float:
    """Compute similarity score between two traces.

    Returns a score from 0.0 (completely different) to 1.0 (identical).

    Args:
        trace1: First trace.
        trace2: Second trace.
        method: Similarity metric:
            - "correlation": Pearson correlation coefficient (default)
            - "rms": 1 - normalized RMS difference
            - "mse": 1 - normalized mean squared error
            - "cosine": Cosine similarity
        normalize_amplitude: Normalize amplitude before comparison.
        normalize_offset: Remove DC offset before comparison.

    Returns:
        Similarity score (0.0 to 1.0).

    Raises:
        ValueError: If input traces contain NaN or Inf values, or the
            method name is unknown.

    Example:
        >>> score = similarity_score(measured, reference)
        >>> if score > 0.95:
        ...     print("Traces match")
    """
    a = trace1.data.astype(np.float64).copy()
    b = trace2.data.astype(np.float64).copy()

    # Non-finite samples would corrupt every metric below; reject early.
    if not (np.all(np.isfinite(a)) and np.all(np.isfinite(b))):
        raise ValueError("Input traces contain NaN or Inf values")

    # Truncate both traces to their common length.
    common = min(len(a), len(b))
    a = a[:common]
    b = b[:common]

    # Optional DC removal, then optional amplitude normalization
    # (std is computed after the offset removal, matching that order).
    if normalize_offset:
        a = a - np.mean(a)
        b = b - np.mean(b)

    if normalize_amplitude:
        spread_a = np.std(a)
        spread_b = np.std(b)
        if spread_a > 0:
            a = a / spread_a
        if spread_b > 0:
            b = b / spread_b

    if method == "correlation":
        # Pearson r; constant inputs yield NaN, handled explicitly below.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=stats.ConstantInputWarning)
            try:
                r = stats.pearsonr(a, b)[0]
            except Exception:
                r = 0.0
            else:
                if np.isnan(r):
                    # Constant traces: perfect score only when identical.
                    r = 1.0 if np.allclose(a, b, equal_nan=False) else 0.0
        # Map r from [-1, 1] onto [0, 1].
        return float((r + 1) / 2)

    if method == "rms":
        # 1 minus the RMS error relative to the reference RMS, floored at 0.
        rms_diff = np.sqrt(np.mean((a - b) ** 2))
        rms_ref = np.sqrt(np.mean(b**2)) + 1e-10
        return float(max(0, 1 - rms_diff / rms_ref))

    if method == "mse":
        # 1 minus the MSE relative to the reference variance, floored at 0.
        mse = np.mean((a - b) ** 2)
        var_ref = np.var(b) + 1e-10
        return float(max(0, 1 - mse / var_ref))

    if method == "cosine":
        # Cosine of the angle between the two sample vectors.
        dot = np.dot(a, b)
        norm1 = np.linalg.norm(a) + 1e-10
        norm2 = np.linalg.norm(b) + 1e-10
        cosine = dot / (norm1 * norm2)
        # Map from [-1, 1] onto [0, 1].
        return float((cosine + 1) / 2)

    raise ValueError(f"Unknown similarity method: {method}")

277 

278 

def compare_traces(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    tolerance: float | None = None,
    tolerance_pct: float | None = None,
    method: Literal["absolute", "relative", "statistical"] = "absolute",
    include_difference: bool = True,
) -> ComparisonResult:
    """Compare two traces and determine if they match.

    Comprehensive comparison of two waveforms including difference
    analysis, correlation, and match determination. Traces of unequal
    length are aligned to the shorter one.

    Args:
        trace1: First trace (typically measured).
        trace2: Second trace (typically reference).
        tolerance: Absolute tolerance for matching.
        tolerance_pct: Percentage tolerance (0-100) relative to reference range.
        method: Comparison method:
            - "absolute": Compare absolute values
            - "relative": Compare relative to reference
            - "statistical": Use a paired t-test
        include_difference: Include difference trace in result.

    Returns:
        ComparisonResult with match status and statistics.

    Raises:
        ValueError: If method is unknown or the traces are empty.

    Example:
        >>> result = compare_traces(measured, golden, tolerance=0.01)
        >>> if result.match:
        ...     print(f"Match! Similarity: {result.similarity:.1%}")
    """
    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    # Align to the shorter trace so element-wise operations are defined.
    min_len = min(len(data1), len(data2))
    if min_len == 0:
        # Explicit guard; np.max on an empty array would raise an obscure
        # reduction error instead.
        raise ValueError("Cannot compare empty traces")
    data1 = data1[:min_len]
    data2 = data2[:min_len]

    # Compute difference and its summary statistics
    diff = data1 - data2
    max_diff = float(np.max(np.abs(diff)))
    rms_diff = float(np.sqrt(np.mean(diff**2)))

    # Pearson correlation. Constant inputs make pearsonr warn and return
    # NaN, so resolve that case explicitly instead of propagating NaN
    # into the result (identical constant traces count as correlated).
    if min_len > 1:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=stats.ConstantInputWarning)
            try:
                corr, _ = stats.pearsonr(data1, data2)
            except Exception:
                # Fallback for any correlation computation issues
                corr = 0.0
        if np.isnan(corr):
            corr = 1.0 if np.allclose(data1, data2) else 0.0
    else:
        corr = 1.0 if data1[0] == data2[0] else 0.0

    # Compute similarity score
    sim_score = similarity_score(trace1, trace2)

    # Resolve the effective absolute tolerance.
    # Default (neither given): 1% of the reference range.
    if tolerance is None:
        ref_range = float(np.ptp(data2))
        pct = tolerance_pct if tolerance_pct is not None else 1.0
        tolerance = ref_range * pct / 100.0

    # Sample indices where the difference exceeds the tolerance.
    violations = np.where(np.abs(diff) > tolerance)[0]

    # Paired t-test p-value, shared by the "statistical" method and the
    # statistics dict. ttest_rel returns NaN when the per-sample
    # differences have zero variance; resolve that explicitly:
    # identical traces -> 1.0 (no difference), constant offset -> 0.0.
    if min_len > 1:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            p_value = float(stats.ttest_rel(data1, data2)[1])
        if np.isnan(p_value):
            p_value = 1.0 if np.allclose(data1, data2) else 0.0
    else:
        p_value = 1.0 if data1[0] == data2[0] else 0.0

    # Determine match
    if method == "absolute":
        match = bool(max_diff <= tolerance)
    elif method == "relative":
        ref_range = float(np.ptp(data2)) + 1e-10
        # tolerance_pct may legitimately be 0; only fall back to the 1%
        # default when it is unset (the previous `tolerance_pct or 1.0`
        # silently ignored an explicit 0).
        pct_limit = (tolerance_pct if tolerance_pct is not None else 1.0) / 100.0
        match = bool(max_diff / ref_range <= pct_limit)
    elif method == "statistical":
        # p > 0.05: no statistically significant difference. The NaN
        # handling above makes identical traces match (the raw t-test
        # yields NaN for them, which previously compared as non-matching).
        match = bool(p_value > 0.05)
    else:
        raise ValueError(f"Unknown method: {method}")

    # Create difference trace if requested
    diff_trace = None
    if include_difference:
        diff_trace = difference(trace1, trace2, channel_name="comparison_diff")

    # Compute additional statistics
    statistics = {
        "mean_difference": float(np.mean(diff)),
        "std_difference": float(np.std(diff)),
        "median_difference": float(np.median(diff)),
        "num_violations": len(violations),
        "violation_rate": len(violations) / min_len,
        "p_value": p_value,
    }

    return ComparisonResult(
        match=match,
        similarity=sim_score,
        max_difference=max_diff,
        rms_difference=rms_diff,
        correlation=float(corr),
        difference_trace=diff_trace,
        violations=violations if len(violations) > 0 else None,
        statistics=statistics,
    )