Coverage for src / tracekit / comparison / compare.py: 98%
133 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Trace comparison functions for TraceKit.
3This module provides functions for comparing waveform traces including
4difference calculation, correlation, and similarity scoring.
7Example:
8 >>> from tracekit.comparison import compare_traces, similarity_score
9 >>> result = compare_traces(trace1, trace2)
10 >>> score = similarity_score(trace1, trace2)
12References:
13 IEEE 181-2011: Standard for Transitional Waveform Definitions
14"""
16from __future__ import annotations
18import warnings
19from dataclasses import dataclass
20from typing import TYPE_CHECKING, Literal
22import numpy as np
23from scipy import signal as sp_signal
24from scipy import stats
26from tracekit.core.types import TraceMetadata, WaveformTrace
28if TYPE_CHECKING:
29 from numpy.typing import NDArray
@dataclass
class ComparisonResult:
    """Result of a trace comparison operation.

    Produced by :func:`compare_traces`. All scores and statistics are
    computed over the overlapping (length-aligned) portion of the two
    input traces.

    Attributes:
        match: True if traces are considered matching under the selected
            comparison method and tolerance.
        similarity: Similarity score (0.0 to 1.0) from similarity_score().
        max_difference: Maximum absolute difference between the traces.
        rms_difference: RMS of the difference waveform.
        correlation: Pearson correlation coefficient of the two traces.
        difference_trace: Difference waveform (optional; populated when
            compare_traces() is called with include_difference=True).
        violations: Sample indices where |difference| exceeds the tolerance,
            or None when there are no violations.
        statistics: Additional comparison statistics (mean/std/median of the
            difference, violation counts and rate, paired t-test p-value).
    """

    match: bool
    similarity: float
    max_difference: float
    rms_difference: float
    correlation: float
    difference_trace: WaveformTrace | None = None
    violations: NDArray[np.int64] | None = None
    statistics: dict | None = None  # type: ignore[type-arg]
def difference(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    normalize: bool = False,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Compute difference between two traces.

    Calculates the element-wise difference (trace1 - trace2). Traces
    are aligned to the shorter length before subtracting.

    Args:
        trace1: First trace.
        trace2: Second trace.
        normalize: Normalize difference to percentage of the reference
            (trace2) peak-to-peak range. Skipped when the aligned traces
            are empty or the reference range is zero, since there is no
            meaningful scale to divide by.
        channel_name: Name for the result trace (defaults to "difference").

    Returns:
        WaveformTrace containing the difference. Sample rate, acquisition
        time, trigger info, and source file are copied from trace1.

    Raises:
        ValueError: If input traces contain NaN or Inf values.

    Example:
        >>> diff = difference(measured, reference)
        >>> max_error = np.max(np.abs(diff.data))
    """
    # Get data as float64 so integer traces subtract without wrapping.
    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    # Reject non-finite samples up front so the difference is meaningful.
    if np.any(~np.isfinite(data1)) or np.any(~np.isfinite(data2)):
        raise ValueError("Input traces contain NaN or Inf values")

    # Align lengths: compare only the overlapping prefix.
    min_len = min(len(data1), len(data2))
    data1 = data1[:min_len]
    data2 = data2[:min_len]

    # Compute difference
    diff = data1 - data2

    # The min_len guard avoids np.ptp() raising on a zero-size array when
    # both inputs are empty.
    if normalize and min_len > 0:
        # Normalize to percentage of reference range; a zero range
        # (constant reference) would divide by zero, so leave the raw
        # difference in that case.
        ref_range = np.ptp(data2)
        if ref_range > 0:
            diff = (diff / ref_range) * 100.0

    # Vertical scale/offset no longer apply to a derived difference signal.
    new_metadata = TraceMetadata(
        sample_rate=trace1.metadata.sample_rate,
        vertical_scale=None,
        vertical_offset=None,
        acquisition_time=trace1.metadata.acquisition_time,
        trigger_info=trace1.metadata.trigger_info,
        source_file=trace1.metadata.source_file,
        channel_name=channel_name or "difference",
    )

    return WaveformTrace(data=diff, metadata=new_metadata)
def correlation(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    mode: Literal["full", "same", "valid"] = "same",
    normalize: bool = True,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute cross-correlation between two traces.

    Calculates the cross-correlation of two waveforms, useful for
    finding time delays and pattern matching.

    Args:
        trace1: First trace.
        trace2: Second trace.
        mode: Correlation mode:
            - "full": Full correlation (length N+M-1)
            - "same": Same length as the first input
            - "valid": Only overlapping region
        normalize: Normalize to correlation coefficient (-1 to 1).

    Returns:
        Tuple of (lags, correlation_values). Lags are in samples; a
        positive lag means trace1 lags trace2.

    Example:
        >>> lags, corr = correlation(trace1, trace2)
        >>> delay = lags[np.argmax(corr)]
    """
    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    if normalize:
        # Zero-mean, unit-std inputs; the epsilon keeps constant traces
        # from dividing by zero.
        data1 = (data1 - np.mean(data1)) / (np.std(data1) + 1e-10)
        data2 = (data2 - np.mean(data2)) / (np.std(data2) + 1e-10)

    # Compute cross-correlation
    corr = sp_signal.correlate(data1, data2, mode=mode)

    if normalize:
        # Normalize by length for correlation coefficient
        corr = corr / len(data1)

    # Lag axis in samples. correlation_lags() matches scipy's correlate()
    # output indexing exactly for every mode and length combination; the
    # previous hand-rolled arithmetic was off by one for odd-length "same"
    # inputs and wrong for unequal-length inputs.
    lags = sp_signal.correlation_lags(len(data1), len(data2), mode=mode)

    return lags.astype(np.float64), corr
def similarity_score(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    method: Literal["correlation", "rms", "mse", "cosine"] = "correlation",
    normalize_amplitude: bool = True,
    normalize_offset: bool = True,
) -> float:
    """Score how similar two traces are, on a 0.0-to-1.0 scale.

    1.0 means identical (under the chosen metric), 0.0 means completely
    different. Traces are truncated to the shorter length before scoring.

    Args:
        trace1: First trace.
        trace2: Second trace.
        method: Similarity metric:
            - "correlation": Pearson correlation coefficient (default)
            - "rms": 1 - normalized RMS difference
            - "mse": 1 - normalized mean squared error
            - "cosine": Cosine similarity
        normalize_amplitude: Scale each trace to unit standard deviation
            before comparison.
        normalize_offset: Remove each trace's DC offset before comparison.

    Returns:
        Similarity score (0.0 to 1.0).

    Raises:
        ValueError: If input traces contain NaN or Inf values, or if
            ``method`` is not one of the supported metrics.

    Example:
        >>> score = similarity_score(measured, reference)
        >>> if score > 0.95:
        ...     print("Traces match")
    """
    a = trace1.data.astype(np.float64).copy()
    b = trace2.data.astype(np.float64).copy()

    # Non-finite samples would poison every metric below.
    if not (np.all(np.isfinite(a)) and np.all(np.isfinite(b))):
        raise ValueError("Input traces contain NaN or Inf values")

    # Truncate both traces to the overlapping prefix.
    n = min(a.size, b.size)
    a = a[:n]
    b = b[:n]

    # Optional DC removal (in place; both arrays are private copies).
    if normalize_offset:
        a -= np.mean(a)
        b -= np.mean(b)

    # Optional amplitude normalization; constant traces (zero std) are
    # left untouched to avoid division by zero.
    if normalize_amplitude:
        scale_a = np.std(a)
        scale_b = np.std(b)
        if scale_a > 0:
            a = a / scale_a
        if scale_b > 0:
            b = b / scale_b

    if method == "correlation":
        # Pearson r is undefined for constant inputs, so suppress the
        # warning and resolve the NaN case by direct comparison.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=stats.ConstantInputWarning)
            try:
                r, _ = stats.pearsonr(a, b)
                if np.isnan(r):
                    # Two identical constant traces are a perfect match;
                    # otherwise treat as no correlation.
                    r = 1.0 if np.allclose(a, b, equal_nan=False) else 0.0
            except Exception:
                r = 0.0
        # Map r from [-1, 1] onto [0, 1].
        return float((r + 1) / 2)

    if method == "rms":
        # 1 minus the RMS error relative to the reference RMS, floored at 0.
        rms_err = np.sqrt(np.mean((a - b) ** 2))
        rms_ref = np.sqrt(np.mean(b**2)) + 1e-10
        return float(max(0, 1 - rms_err / rms_ref))

    if method == "mse":
        # 1 minus MSE relative to the reference variance, floored at 0.
        return float(max(0, 1 - np.mean((a - b) ** 2) / (np.var(b) + 1e-10)))

    if method == "cosine":
        # Cosine of the angle between the traces, mapped onto [0, 1].
        denom = (np.linalg.norm(a) + 1e-10) * (np.linalg.norm(b) + 1e-10)
        cosine = np.dot(a, b) / denom
        return float((cosine + 1) / 2)

    raise ValueError(f"Unknown similarity method: {method}")
def compare_traces(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    tolerance: float | None = None,
    tolerance_pct: float | None = None,
    method: Literal["absolute", "relative", "statistical"] = "absolute",
    include_difference: bool = True,
) -> ComparisonResult:
    """Compare two traces and determine if they match.

    Comprehensive comparison of two waveforms including difference
    analysis, correlation, and match determination. Traces are aligned
    to the shorter length before comparison.

    Args:
        trace1: First trace (typically measured).
        trace2: Second trace (typically reference).
        tolerance: Absolute tolerance for matching. Takes precedence over
            tolerance_pct when both are given.
        tolerance_pct: Percentage tolerance (0-100) relative to reference
            range. When neither tolerance is given, 1% of the reference
            range is used.
        method: Comparison method:
            - "absolute": Compare absolute values
            - "relative": Compare relative to reference range
            - "statistical": Paired t-test (p > 0.05 means match)
        include_difference: Include difference trace in result.

    Returns:
        ComparisonResult with match status and statistics.

    Raises:
        ValueError: If method is unknown.

    Example:
        >>> result = compare_traces(measured, golden, tolerance=0.01)
        >>> if result.match:
        ...     print(f"Match! Similarity: {result.similarity:.1%}")
    """
    # Get data
    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    # Align lengths: compare only the overlapping prefix.
    min_len = min(len(data1), len(data2))
    data1 = data1[:min_len]
    data2 = data2[:min_len]

    # Compute difference
    diff = data1 - data2

    # Compute statistics
    max_diff = float(np.max(np.abs(diff)))
    rms_diff = float(np.sqrt(np.mean(diff**2)))

    # Pearson correlation; undefined for constant inputs (e.g. DC
    # signals), so suppress the warning and fall back to 0.0 on failure.
    if min_len > 1:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=stats.ConstantInputWarning)
            try:
                corr, _ = stats.pearsonr(data1, data2)
            except Exception:
                # Fallback for any correlation computation issues
                corr = 0.0
    else:
        corr = 1.0 if data1[0] == data2[0] else 0.0

    # Compute similarity score
    sim_score = similarity_score(trace1, trace2)

    # Resolve the tolerance: an explicit absolute value wins, then a
    # percentage of the reference peak-to-peak range, then a 1% default.
    if tolerance is None:
        ref_range = float(np.ptp(data2))
        if tolerance_pct is not None:
            tolerance = ref_range * tolerance_pct / 100.0
        else:
            tolerance = ref_range * 0.01

    # Find violations
    violations = np.where(np.abs(diff) > tolerance)[0]

    # Paired t-test p-value, computed once and reused for both the
    # "statistical" match decision and the statistics dict. The result is
    # NaN when the per-sample differences have zero variance (identical
    # traces or a pure constant offset), and undefined for a single sample.
    if min_len > 1:
        p_value = float(stats.ttest_rel(data1, data2)[1])
    else:
        p_value = 1.0

    # Determine match
    if method == "absolute":
        match = max_diff <= tolerance
    elif method == "relative":
        ref_range = float(np.ptp(data2)) + 1e-10
        relative_max = max_diff / ref_range
        match = relative_max <= (tolerance_pct or 1.0) / 100.0
    elif method == "statistical":
        if np.isnan(p_value):
            # The t-test cannot distinguish zero-variance differences
            # (identical traces previously reported as NOT matching here);
            # fall back to the absolute tolerance check.
            match = max_diff <= tolerance
        else:
            match = p_value > 0.05  # No significant difference
    else:
        raise ValueError(f"Unknown method: {method}")

    # Create difference trace if requested
    diff_trace = None
    if include_difference:
        diff_trace = difference(trace1, trace2, channel_name="comparison_diff")

    # Compute additional statistics
    statistics = {
        "mean_difference": float(np.mean(diff)),
        "std_difference": float(np.std(diff)),
        "median_difference": float(np.median(diff)),
        "num_violations": len(violations),
        "violation_rate": len(violations) / min_len if min_len > 0 else 0,
        "p_value": p_value,
    }

    return ComparisonResult(
        match=match,
        similarity=sim_score,
        max_difference=max_diff,
        rms_difference=rms_diff,
        correlation=float(corr),
        difference_trace=diff_trace,
        violations=violations if len(violations) > 0 else None,
        statistics=statistics,
    )