Coverage for src/tracekit/analyzers/digital/correlation.py: 96%
213 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
"""Multi-channel time correlation for synchronized analysis.

This module provides tools for correlating and aligning multiple signal channels
that may have timing offsets, different sample rates, or require trigger-based
synchronization.

Example:
    >>> from tracekit.analyzers.digital.correlation import correlate_channels, align_by_trigger
    >>> result = correlate_channels(channel_a, channel_b, sample_rate=1e9)
    >>> print(f"Time offset: {result.offset_seconds:.9f} seconds")
    >>> aligned = align_by_trigger(channels, trigger_channel='clk', edge='rising')
    >>> print(f"Aligned channels: {aligned.channel_names}")

References:
    Oppenheim & Schafer: Discrete-Time Signal Processing (3rd Ed), Chapter 2
    Press et al: Numerical Recipes (3rd Ed), Section 13.2 - Correlation
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal

import numpy as np
from scipy import signal

from tracekit.core.exceptions import InsufficientDataError, ValidationError

if TYPE_CHECKING:
    from numpy.typing import NDArray
@dataclass
class CorrelationResult:
    """Result of cross-correlation analysis.

    Attributes:
        offset_samples: Time offset in samples (positive = channel_b leads).
        offset_seconds: Time offset in seconds (offset_samples / sample_rate).
        correlation_coefficient: Peak correlation value (-1.0 to 1.0).
        confidence: Confidence score (0.0 to 1.0) based on peak sharpness.
        quality: Quality classification ('excellent', 'good', 'fair', or 'poor').
    """

    offset_samples: int
    offset_seconds: float
    correlation_coefficient: float
    confidence: float
    quality: str  # 'excellent', 'good', 'fair', 'poor'
class CorrelatedChannels:
    """Container for time-aligned multi-channel data.

    Attributes:
        channels: Dictionary mapping channel names to aligned data arrays.
        sample_rate: Common sample rate for all channels (Hz).
        offsets: Dictionary mapping channel names to their time offsets (samples).
    """

    def __init__(
        self, channels: dict[str, NDArray[np.float64]], sample_rate: float, offsets: dict[str, int]
    ):
        """Initialize correlated channels container.

        Args:
            channels: Dictionary of channel name -> aligned data array.
            sample_rate: Sample rate in Hz (same for all channels).
            offsets: Dictionary of channel name -> offset in samples.

        Raises:
            ValidationError: If channels are empty, channel lengths differ,
                or the sample rate is not positive.
        """
        if not channels:
            raise ValidationError("At least one channel is required")

        # All channels must share one length so a single time vector applies.
        lengths = {name: len(data) for name, data in channels.items()}
        if len(set(lengths.values())) > 1:
            raise ValidationError(f"Channel length mismatch: {lengths}")

        if sample_rate <= 0:
            raise ValidationError(f"Sample rate must be positive, got {sample_rate}")

        self.channels = channels
        self.sample_rate = float(sample_rate)
        self.offsets = offsets

    @property
    def channel_names(self) -> list[str]:
        """Get list of channel names."""
        return list(self.channels.keys())

    def get_channel(self, name: str) -> NDArray[np.float64]:
        """Get aligned data for a specific channel.

        Args:
            name: Channel name.

        Returns:
            Aligned data array.

        Raises:
            KeyError: If the channel name is not present.
        """
        return self.channels[name]

    def get_time_vector(self) -> NDArray[np.float64]:
        """Get time vector for aligned data.

        Returns:
            Time array in seconds, starting from 0, with one entry per sample
            of the (common-length) channels.
        """
        first_channel = next(iter(self.channels.values()))
        n_samples = len(first_channel)
        return np.arange(n_samples) / self.sample_rate
class ChannelCorrelator:
    """Correlate multiple signal channels in time.

    This class provides methods for aligning channels using cross-correlation,
    trigger edge detection, or resampling to a common sample rate.
    """

    def __init__(self, reference_channel: str | None = None):
        """Initialize correlator.

        Args:
            reference_channel: Name of reference channel for multi-channel alignment.
                If None, the first channel will be used as reference.
        """
        self.reference_channel = reference_channel

    def correlate(
        self,
        signal1: NDArray[np.float64],
        signal2: NDArray[np.float64],
    ) -> float:
        """Compute Pearson correlation coefficient between two signals.

        Simple correlation interface for test compatibility. Signals of
        unequal length are truncated to the shorter length.

        Args:
            signal1: First signal array.
            signal2: Second signal array.

        Returns:
            Correlation coefficient (-1.0 to 1.0). Returns 0.0 for degenerate
            input (fewer than 2 samples, or a constant signal).

        Example:
            >>> correlator = ChannelCorrelator()
            >>> corr = correlator.correlate(signal1, signal2)
        """
        signal1 = np.asarray(signal1, dtype=np.float64)
        signal2 = np.asarray(signal2, dtype=np.float64)

        if len(signal1) != len(signal2):
            # Use shorter length
            min_len = min(len(signal1), len(signal2))
            signal1 = signal1[:min_len]
            signal2 = signal2[:min_len]

        if len(signal1) < 2:
            return 0.0

        # Compute Pearson correlation coefficient
        s1_centered = signal1 - np.mean(signal1)
        s2_centered = signal2 - np.mean(signal2)

        num = np.sum(s1_centered * s2_centered)
        denom = np.sqrt(np.sum(s1_centered**2) * np.sum(s2_centered**2))

        if denom == 0:
            # At least one signal is constant; correlation is undefined.
            return 0.0

        return float(num / denom)

    def find_lag(
        self,
        signal1: NDArray[np.float64],
        signal2: NDArray[np.float64],
    ) -> int:
        """Find the time lag between two signals using cross-correlation.

        Args:
            signal1: First signal array.
            signal2: Second signal array.

        Returns:
            Lag in samples. Positive means signal2 leads signal1 (signal2's
            features occur earlier), matching the sign convention of
            correlate_channels(). Returns 0 for signals shorter than 2 samples.

        Example:
            >>> correlator = ChannelCorrelator()
            >>> lag = correlator.find_lag(signal1, signal2)
        """
        signal1 = np.asarray(signal1, dtype=np.float64)
        signal2 = np.asarray(signal2, dtype=np.float64)

        if len(signal1) < 2 or len(signal2) < 2:
            return 0

        # Center signals for robustness against DC offsets
        s1_centered = signal1 - np.mean(signal1)
        s2_centered = signal2 - np.mean(signal2)

        # Compute cross-correlation
        correlation = np.correlate(s1_centered, s2_centered, mode="full")

        # Find peak (strongest alignment, regardless of sign)
        peak_idx = np.argmax(np.abs(correlation))

        # In "full" mode the zero-lag term sits at index len(signal2) - 1
        lag = peak_idx - (len(signal2) - 1)

        return int(lag)

    def correlation_matrix(
        self,
        channels: list[NDArray[np.float64]],
    ) -> NDArray[np.float64]:
        """Compute pairwise correlation matrix for multiple channels.

        Args:
            channels: List of signal arrays.

        Returns:
            NxN symmetric correlation matrix where N is the number of
            channels; the diagonal is 1.0.

        Example:
            >>> correlator = ChannelCorrelator()
            >>> matrix = correlator.correlation_matrix([ch1, ch2, ch3])
        """
        n = len(channels)
        matrix = np.ones((n, n), dtype=np.float64)

        # Only the upper triangle is computed; correlation is symmetric.
        for i in range(n):
            for j in range(i + 1, n):
                corr = self.correlate(channels[i], channels[j])
                matrix[i, j] = corr
                matrix[j, i] = corr

        return matrix

    def correlate_channels(
        self,
        channel_a: NDArray[np.float64],
        channel_b: NDArray[np.float64],
        sample_rate: float = 1.0,
    ) -> CorrelationResult:
        """Find time offset between two channels using cross-correlation.

        Uses normalized cross-correlation to find the time offset that maximizes
        alignment between two channels. Handles zero-mean normalization for
        robustness against DC offsets. Note the peak value is normalized by the
        total signal energies, so it approximates the correlation coefficient
        (exact at zero lag with full overlap).

        Args:
            channel_a: First channel data.
            channel_b: Second channel data.
            sample_rate: Sample rate in Hz (default 1.0 for sample-based results).

        Returns:
            CorrelationResult with offset (positive = channel_b leads) and
            quality metrics.

        Raises:
            InsufficientDataError: If channels are too short.
            ValidationError: If sample rate is invalid.
        """
        if len(channel_a) < 2 or len(channel_b) < 2:
            raise InsufficientDataError("Channels must have at least 2 samples")

        if sample_rate <= 0:
            raise ValidationError(f"Sample rate must be positive, got {sample_rate}")

        # Convert to zero-mean for better correlation
        a_mean = np.mean(channel_a)
        b_mean = np.mean(channel_b)
        a_centered = channel_a - a_mean
        b_centered = channel_b - b_mean

        # Compute cross-correlation using scipy (more efficient than numpy)
        correlation = signal.correlate(a_centered, b_centered, mode="full", method="auto")

        # Normalize by signal energies for correlation coefficient
        a_energy = np.sum(a_centered**2)
        b_energy = np.sum(b_centered**2)

        if a_energy == 0 or b_energy == 0:
            # One or both signals are constant; no meaningful offset exists.
            return CorrelationResult(
                offset_samples=0,
                offset_seconds=0.0,
                correlation_coefficient=0.0,
                confidence=0.0,
                quality="poor",
            )

        normalization = np.sqrt(a_energy * b_energy)
        correlation_normalized = correlation / normalization

        # Find peak correlation (strongest alignment, regardless of sign)
        peak_idx = np.argmax(np.abs(correlation_normalized))
        peak_value = correlation_normalized[peak_idx]

        # Convert peak index to offset (positive = channel_b leads)
        offset_samples = peak_idx - (len(channel_b) - 1)
        offset_seconds = offset_samples / sample_rate

        # Estimate confidence from peak sharpness
        # High confidence = sharp peak, low confidence = broad/weak peak
        confidence = self._estimate_correlation_confidence(correlation_normalized, int(peak_idx))

        # Classify quality
        quality = self._classify_correlation_quality(abs(peak_value), confidence)

        return CorrelationResult(
            offset_samples=int(offset_samples),
            offset_seconds=float(offset_seconds),
            correlation_coefficient=float(peak_value),
            confidence=float(confidence),
            quality=quality,
        )

    def align_by_trigger(
        self,
        channels: dict[str, NDArray[np.float64]],
        trigger_channel: str,
        edge: Literal["rising", "falling"] = "rising",
        threshold: float = 0.5,
    ) -> CorrelatedChannels:
        """Align channels using trigger edge from one channel.

        Aligns all channels by detecting the first trigger edge in the specified
        channel and trimming all channels to start from that point.

        Args:
            channels: Dictionary of channel name -> data array.
            trigger_channel: Name of channel to use for trigger detection.
            edge: Edge type to detect ('rising' or 'falling').
            threshold: Trigger threshold. Values in [0.0, 1.0] are interpreted
                as a normalized fraction of the trigger channel's data range;
                any other value is used as an absolute level.

        Returns:
            CorrelatedChannels with aligned data (sample_rate is set to 1.0
            since no rate is given; offsets are in samples).

        Raises:
            InsufficientDataError: If trigger channel is too short.
            ValidationError: If trigger channel not found or no edge detected.
        """
        if trigger_channel not in channels:
            raise ValidationError(f"Trigger channel '{trigger_channel}' not found")

        trigger_data = channels[trigger_channel]

        if len(trigger_data) < 2:
            raise InsufficientDataError("Trigger channel too short")

        # Normalize threshold if needed
        if 0.0 <= threshold <= 1.0:
            data_min = np.min(trigger_data)
            data_max = np.max(trigger_data)
            threshold_abs = float(data_min + threshold * (data_max - data_min))
        else:
            threshold_abs = float(threshold)

        # Detect first edge
        trigger_idx = self._find_first_edge(trigger_data, edge, threshold_abs)

        if trigger_idx is None:
            raise ValidationError(f"No {edge} edge found in trigger channel")

        # Align all channels by trimming to trigger point
        aligned_channels = {}
        offsets = {}

        for name, data in channels.items():
            if trigger_idx < len(data):
                aligned_channels[name] = data[trigger_idx:]
                offsets[name] = trigger_idx
            else:
                # Trigger point is beyond this channel's data
                aligned_channels[name] = np.array([])
                offsets[name] = len(data)

        # Assume all channels have same sample rate (no rate given)
        # Use default of 1.0 Hz for sample-based indexing
        return CorrelatedChannels(aligned_channels, sample_rate=1.0, offsets=offsets)

    def resample_to_common_rate(
        self,
        channels: dict[str, tuple[NDArray[np.float64], float]],
        target_rate: float | None = None,
    ) -> CorrelatedChannels:
        """Resample all channels to common sample rate.

        Resamples channels with different sample rates to a common rate using
        FFT-based resampling (scipy.signal.resample). Uses the highest sample
        rate as target if not specified.

        Args:
            channels: Dictionary of channel name -> (data, sample_rate) tuples.
            target_rate: Target sample rate in Hz. If None, uses highest rate.

        Returns:
            CorrelatedChannels with resampled data at common rate.

        Raises:
            ValidationError: If channels are empty or rates are invalid.
        """
        if not channels:
            raise ValidationError("At least one channel is required")

        # Determine target rate
        if target_rate is None:
            rates = [rate for _, rate in channels.values()]
            target_rate = max(rates)

        if target_rate <= 0:
            raise ValidationError(f"Target rate must be positive, got {target_rate}")

        resampled_channels = {}
        offsets = {}

        for name, (data, original_rate) in channels.items():
            if original_rate <= 0:
                raise ValidationError(f"Invalid sample rate for '{name}': {original_rate}")

            if len(data) < 2:
                # Skip empty/trivial channels
                resampled_channels[name] = data
                offsets[name] = 0
                continue

            # Calculate resampling ratio
            ratio = target_rate / original_rate

            if abs(ratio - 1.0) < 1e-6:
                # Already at target rate
                resampled_channels[name] = data
            else:
                # Resample via FFT (scipy.signal.resample)
                num_samples = int(np.round(len(data) * ratio))
                resampled_channels[name] = signal.resample(data, num_samples)

            offsets[name] = 0

        return CorrelatedChannels(resampled_channels, sample_rate=target_rate, offsets=offsets)

    def auto_align(
        self,
        channels: dict[str, NDArray[np.float64]],
        sample_rate: float,
        method: Literal["correlation", "trigger", "edge"] = "correlation",
    ) -> CorrelatedChannels:
        """Auto-align channels using best-guess method.

        Automatically aligns multiple channels using the specified method.
        For correlation method, aligns all channels to the reference channel.

        Args:
            channels: Dictionary of channel name -> data array.
            sample_rate: Sample rate in Hz (same for all channels).
            method: Alignment method to use.

        Returns:
            CorrelatedChannels with aligned data.

        Raises:
            ValidationError: If method is invalid or alignment fails.
        """
        if not channels:
            raise ValidationError("At least one channel is required")

        if len(channels) < 2:
            # Single channel, no alignment needed
            return CorrelatedChannels(
                channels=channels, sample_rate=sample_rate, offsets=dict.fromkeys(channels, 0)
            )

        # Determine reference channel
        if self.reference_channel and self.reference_channel in channels:
            ref_name = self.reference_channel
        else:
            ref_name = next(iter(channels))

        ref_data = channels[ref_name]

        if method == "correlation":
            # Correlate all channels to reference
            aligned_channels = {ref_name: ref_data}
            offsets = {ref_name: 0}

            for name, data in channels.items():
                if name == ref_name:
                    continue

                # Cross-correlate with reference
                result = self.correlate_channels(ref_data, data, sample_rate)

                # Apply offset to align
                offset = -result.offset_samples  # Negative because we want to shift data

                if offset > 0:
                    # Trim start of data
                    aligned_channels[name] = data[offset:]
                elif offset < 0:
                    # Pad start of data with zeros
                    pad = np.zeros(-offset)
                    aligned_channels[name] = np.concatenate([pad, data])
                else:
                    aligned_channels[name] = data

                offsets[name] = offset

            # Trim all to same length
            min_len = min(len(d) for d in aligned_channels.values())
            aligned_channels = {name: data[:min_len] for name, data in aligned_channels.items()}

            return CorrelatedChannels(aligned_channels, sample_rate, offsets)

        elif method in ("trigger", "edge"):
            # Use reference channel as trigger
            return self.align_by_trigger(channels, ref_name, edge="rising")

        else:
            raise ValidationError(f"Unknown alignment method: {method}")

    def _estimate_correlation_confidence(
        self, correlation: NDArray[np.float64], peak_idx: int
    ) -> float:
        """Estimate confidence from correlation peak sharpness.

        Args:
            correlation: Normalized correlation array.
            peak_idx: Index of peak correlation.

        Returns:
            Confidence score 0.0 to 1.0.
        """
        peak_value = abs(correlation[peak_idx])

        # Calculate peak-to-sidelobe ratio within a local window.
        # Higher ratio = sharper peak = higher confidence
        window_size = min(20, len(correlation) // 10)
        start = max(0, peak_idx - window_size)
        end = min(len(correlation), peak_idx + window_size + 1)

        # Exclude peak itself
        sidelobe_indices = np.concatenate(
            [np.arange(start, peak_idx), np.arange(peak_idx + 1, end)]
        )

        if len(sidelobe_indices) > 0:
            max_sidelobe = np.max(np.abs(correlation[sidelobe_indices]))
            if max_sidelobe > 0:
                ratio = peak_value / max_sidelobe
                # Map ratio to confidence (empirically tuned)
                confidence = min(1.0, ratio / 5.0)
            else:
                confidence = 1.0
        else:
            # No sidelobes to compare against; fall back to peak magnitude
            confidence = peak_value

        return float(confidence)

    def _classify_correlation_quality(self, correlation: float, confidence: float) -> str:
        """Classify correlation quality.

        Args:
            correlation: Correlation coefficient magnitude (0.0 to 1.0).
            confidence: Confidence score (0.0 to 1.0).

        Returns:
            str: Quality rating - 'excellent', 'good', 'fair', or 'poor'.
        """
        # Equal-weight average of correlation strength and peak confidence
        score = (correlation + confidence) / 2.0

        if score >= 0.8:
            return "excellent"
        elif score >= 0.6:
            return "good"
        elif score >= 0.4:
            return "fair"
        else:
            return "poor"

    def _find_first_edge(
        self, data: NDArray[np.float64], edge: str, threshold: float
    ) -> int | None:
        """Find first edge in data.

        Args:
            data: Signal data.
            edge: Edge type ('rising' or 'falling').
            threshold: Threshold value (absolute).

        Returns:
            Index of the first sample at/after the crossing, or None if no
            crossing is found.
        """
        if edge == "rising":
            # First point where signal crosses from below to at/above threshold
            crossings = np.where((data[:-1] < threshold) & (data[1:] >= threshold))[0]
        else:  # falling
            crossings = np.where((data[:-1] > threshold) & (data[1:] <= threshold))[0]

        if len(crossings) > 0:
            return int(crossings[0] + 1)  # Return index after crossing
        else:
            return None
# Convenience functions


def correlate_channels(
    channel_a: NDArray[np.float64], channel_b: NDArray[np.float64], sample_rate: float = 1.0
) -> CorrelationResult:
    """Find time offset between two channels.

    Convenience function for correlating two channels without creating
    a ChannelCorrelator instance.

    Args:
        channel_a: First channel data.
        channel_b: Second channel data.
        sample_rate: Sample rate in Hz (default 1.0 for sample-based results).

    Returns:
        CorrelationResult with offset and quality metrics.

    Example:
        >>> result = correlate_channels(ch1, ch2, sample_rate=1e9)
        >>> print(f"Offset: {result.offset_seconds*1e9:.2f} ns")
    """
    correlator = ChannelCorrelator()
    return correlator.correlate_channels(channel_a, channel_b, sample_rate)
def align_by_trigger(
    channels: dict[str, NDArray[np.float64]],
    trigger_channel: str,
    edge: Literal["rising", "falling"] = "rising",
    threshold: float = 0.5,
) -> CorrelatedChannels:
    """Align channels using trigger edge.

    Convenience function for trigger-based alignment without creating
    a ChannelCorrelator instance.

    Args:
        channels: Dictionary of channel name -> data array.
        trigger_channel: Name of channel to use for trigger detection.
        edge: Edge type to detect ('rising' or 'falling').
        threshold: Trigger threshold (0-1 normalized or absolute).

    Returns:
        CorrelatedChannels with aligned data.

    Example:
        >>> aligned = align_by_trigger(
        ...     {'clk': clk_data, 'data': data_signal},
        ...     trigger_channel='clk',
        ...     edge='rising'
        ... )
    """
    correlator = ChannelCorrelator()
    return correlator.align_by_trigger(channels, trigger_channel, edge, threshold)
def resample_to_common_rate(
    channels: dict[str, tuple[NDArray[np.float64], float]], target_rate: float | None = None
) -> CorrelatedChannels:
    """Resample all channels to common rate.

    Convenience function for resampling channels without creating
    a ChannelCorrelator instance.

    Args:
        channels: Dictionary of channel name -> (data, sample_rate) tuples.
        target_rate: Target sample rate in Hz. If None, uses highest rate.

    Returns:
        CorrelatedChannels with resampled data at common rate.

    Example:
        >>> resampled = resample_to_common_rate({
        ...     'ch1': (data1, 1e9),
        ...     'ch2': (data2, 2e9)
        ... })
        >>> print(f"Common rate: {resampled.sample_rate} Hz")
    """
    correlator = ChannelCorrelator()
    return correlator.resample_to_common_rate(channels, target_rate)
# Public API of this module.
__all__ = [
    "ChannelCorrelator",
    "CorrelatedChannels",
    "CorrelationResult",
    "align_by_trigger",
    "correlate_channels",
    "resample_to_common_rate",
]