Coverage for src / tracekit / exploratory / error_recovery.py: 87%
243 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Error recovery and graceful degradation for signal analysis.
3This module provides error recovery mechanisms for handling corrupted,
4noisy, or incomplete signal data.
7Example:
8 >>> from tracekit.exploratory.error_recovery import recover_corrupted_data
9 >>> recovered, stats = recover_corrupted_data(trace)
10 >>> print(f"Recovered {stats.recovered_samples} samples")
11"""
13from __future__ import annotations
15import logging
16from dataclasses import dataclass
17from typing import TYPE_CHECKING, Any, TypeVar
19import numpy as np
21from tracekit.core.types import WaveformTrace
23logger = logging.getLogger(__name__)
25if TYPE_CHECKING:
26 from collections.abc import Callable
28 from numpy.typing import NDArray
30T = TypeVar("T")
@dataclass
class RecoveryStats:
    """Statistics from data recovery.

    Returned by :func:`recover_corrupted_data` alongside the repaired trace.

    Attributes:
        total_samples: Total samples in original data.
        corrupted_samples: Number of detected corrupted samples (outliers,
            NaN, or Inf).
        recovered_samples: Number of successfully recovered samples.
        unrecoverable_samples: Number that could not be recovered (gap too
            long, or no valid neighbourhood to fill from).
        recovery_method: Method used for recovery ('interpolate', 'median',
            'zero', or 'none' when no corruption was found).
        confidence: Confidence in recovered data, clamped to [0.0, 1.0].
    """

    # Counts always satisfy: corrupted = recovered + unrecoverable.
    total_samples: int
    corrupted_samples: int
    recovered_samples: int
    unrecoverable_samples: int
    recovery_method: str
    confidence: float
def recover_corrupted_data(
    trace: WaveformTrace,
    *,
    corruption_threshold: float = 3.0,
    recovery_method: str = "interpolate",
    max_gap_samples: int = 100,
) -> tuple[WaveformTrace, RecoveryStats]:
    """Recover corrupted data.

    Detects corrupted samples (robust statistical outliers, NaN, Inf) and
    attempts to repair them using the selected recovery technique.

    Args:
        trace: Trace with potentially corrupted data.
        corruption_threshold: Threshold for detecting corruption, in robust
            standard deviations of the MAD-based modified z-score.
        recovery_method: 'interpolate', 'median', or 'zero'. Any other
            value leaves detected gaps untouched and counts them as
            unrecoverable.
        max_gap_samples: Maximum contiguous gap length that can be recovered.

    Returns:
        Tuple of (recovered_trace, recovery_stats). When no corruption is
        detected the original trace object is returned unchanged.

    Example:
        >>> recovered, stats = recover_corrupted_data(trace)
        >>> print(f"Recovered {stats.recovered_samples} samples")
        >>> print(f"Confidence: {stats.confidence:.1%}")

    References:
        ERROR-001: Error Recovery from Corrupted Data
    """
    data = trace.data.copy()
    n = len(data)

    # Compute robust statistics (median / MAD) over finite samples only so
    # that NaN/Inf corruption cannot poison the detection threshold itself.
    valid_mask = np.isfinite(data)
    finite_data = data[valid_mask] if np.any(valid_mask) else data

    median = np.median(finite_data) if len(finite_data) > 0 else 0.0
    mad = np.median(np.abs(finite_data - median)) if len(finite_data) > 0 else 0.0

    if mad < 1e-10:
        # Near-constant signal: a degenerate MAD would flag every sample as
        # an outlier, so fall back to the standard deviation.
        mad = np.std(finite_data) if len(finite_data) > 0 else 1.0

    # Modified z-score; 1.4826 scales the MAD to the standard deviation of
    # a normal distribution. The epsilon guards against division by zero.
    z_scores = np.abs(data - median) / (1.4826 * mad + 1e-10)

    # A sample is corrupted if it is an outlier, NaN, or Inf. (NaN values
    # produce NaN z-scores, which compare False, hence the explicit checks.)
    corrupted_mask = z_scores > corruption_threshold
    corrupted_mask |= np.isnan(data)
    corrupted_mask |= np.isinf(data)

    corrupted_indices = np.where(corrupted_mask)[0]
    n_corrupted = len(corrupted_indices)

    if n_corrupted == 0:
        return trace, RecoveryStats(
            total_samples=n,
            corrupted_samples=0,
            recovered_samples=0,
            unrecoverable_samples=0,
            recovery_method="none",
            confidence=1.0,
        )

    # Group corrupted samples into contiguous (start, end) gaps. The early
    # return above guarantees corrupted_indices is non-empty here, so no
    # extra emptiness check is needed.
    gaps = []
    gap_start = corrupted_indices[0]
    gap_end = corrupted_indices[0]

    for idx in corrupted_indices[1:]:
        if idx == gap_end + 1:
            gap_end = idx
        else:
            gaps.append((gap_start, gap_end))
            gap_start = idx
            gap_end = idx

    gaps.append((gap_start, gap_end))

    # Attempt recovery gap by gap.
    recovered = 0
    unrecoverable = 0

    for start, end in gaps:
        gap_length = end - start + 1

        if gap_length > max_gap_samples:
            unrecoverable += gap_length
            continue

        if recovery_method == "interpolate":
            # Linear interpolation between the nearest valid neighbours.
            left_idx = max(0, start - 1)
            right_idx = min(n - 1, end + 1)

            if left_idx < start and right_idx > end:
                left_val = data[left_idx]
                right_val = data[right_idx]
                for i, idx in enumerate(range(start, end + 1)):
                    t = (i + 1) / (gap_length + 1)
                    data[idx] = left_val * (1 - t) + right_val * t
                recovered += gap_length
            else:
                # Gap touches an edge of the trace - extend the nearest
                # in-bounds value into the gap instead of interpolating.
                if left_idx >= start:
                    data[start : end + 1] = data[right_idx]
                else:
                    data[start : end + 1] = data[left_idx]
                recovered += gap_length

        elif recovery_method == "median":
            # Replace the gap with the median of nearby uncorrupted samples
            # (a window of up to 50 samples on each side).
            window_start = max(0, start - 50)
            window_end = min(n, end + 50)
            window_data = data[window_start:window_end]
            window_valid = window_data[~corrupted_mask[window_start:window_end]]

            if len(window_valid) > 0:
                data[start : end + 1] = np.median(window_valid)
                recovered += gap_length
            else:
                unrecoverable += gap_length

        elif recovery_method == "zero":
            # Replace with zero.
            data[start : end + 1] = 0
            recovered += gap_length

        else:
            # Unknown method: leave the gap untouched.
            unrecoverable += gap_length

    # Create recovered trace with the original metadata.
    recovered_trace = WaveformTrace(
        data=data,
        metadata=trace.metadata,
    )

    # Confidence combines the fraction of corrupted samples recovered with
    # a penalty for large gaps (long filled spans are less trustworthy).
    # The max() guards prevent ZeroDivisionError when nothing was corrupted
    # or when max_gap_samples is 0.
    recovery_ratio = recovered / max(n_corrupted, 1)
    gap_sizes = [g_end - g_start + 1 for g_start, g_end in gaps]
    avg_gap_size = np.mean(gap_sizes) if gap_sizes else 0
    confidence = recovery_ratio * (1 - avg_gap_size / max(max_gap_samples, 1))

    return recovered_trace, RecoveryStats(
        total_samples=n,
        corrupted_samples=n_corrupted,
        recovered_samples=recovered,
        unrecoverable_samples=unrecoverable,
        recovery_method=recovery_method,
        confidence=max(0.0, min(1.0, confidence)),
    )
@dataclass
class GracefulDegradationResult:
    """Result of gracefully degraded analysis.

    Returned by :func:`graceful_degradation`.

    Attributes:
        result: Analysis result (may be partial when the full analysis
            failed and only some features could be recovered).
        quality_level: 'full', 'degraded', 'minimal', or 'failed'.
        available_features: Features that could be computed.
        missing_features: Features that failed.
        warnings: List of warnings about degradation.
    """

    result: dict[str, Any]
    quality_level: str
    available_features: list[str]
    missing_features: list[str]
    warnings: list[str]
def graceful_degradation(
    analysis_func: Callable[..., dict[str, Any]],
    trace: WaveformTrace,
    *,
    required_features: list[str] | None = None,
    optional_features: list[str] | None = None,
    **kwargs: Any,
) -> GracefulDegradationResult:
    """Execute analysis with graceful degradation.

    Attempts the full analysis first; on failure, falls back to collecting
    individual features directly from the trace object so partial results
    can still be returned.

    Args:
        analysis_func: Analysis function to call. Expected to return a dict
            mapping feature name to value.
        trace: Trace to analyze.
        required_features: Features that must succeed for 'degraded' quality.
        optional_features: Features that can fail without lowering quality.
        **kwargs: Additional arguments passed through to analysis_func.

    Returns:
        GracefulDegradationResult with partial or full results.

    Example:
        >>> result = graceful_degradation(analyze_signal, trace)
        >>> print(f"Quality: {result.quality_level}")
        >>> print(f"Available: {result.available_features}")

    References:
        ERROR-002: Graceful Degradation
    """
    if required_features is None:
        required_features = []
    if optional_features is None:
        optional_features = []

    result: dict[str, Any] = {}
    available = []
    missing = []
    warnings = []

    # Try full analysis first.
    try:
        result = analysis_func(trace, **kwargs)
        available = list(result.keys())
        quality_level = "full"

    except Exception as e:
        logger.warning("Full analysis failed: %s", e, exc_info=True)
        warnings.append(f"Full analysis failed: {e!s}")

        # Fallback: try to recover each requested feature independently.
        for feature in required_features + optional_features:
            try:
                # Best-effort: only attributes already exposed on the trace
                # object can be recovered here.
                if hasattr(trace, feature):
                    result[feature] = getattr(trace, feature)
                    available.append(feature)
                else:
                    missing.append(feature)
            except Exception as fe:
                logger.debug("Feature %s failed: %s", feature, fe, exc_info=True)
                missing.append(feature)
                if feature in required_features:
                    warnings.append(f"Required feature {feature} failed: {fe!s}")

        # Determine quality level. Fix: previously an empty required_features
        # list made the all() check vacuously true and reported 'degraded'
        # even when nothing at all was recovered; 'degraded'/'minimal' now
        # require at least one available feature.
        if available and all(f in available for f in required_features):
            quality_level = "degraded"
        elif available:
            quality_level = "minimal"
        else:
            quality_level = "failed"
            warnings.append("Analysis completely failed")

    return GracefulDegradationResult(
        result=result,
        quality_level=quality_level,
        available_features=available,
        missing_features=missing,
        warnings=warnings,
    )
@dataclass
class PartialDecodeResult:
    """Result of partial protocol decode.

    Returned by :func:`partial_decode`.

    Attributes:
        complete_packets: Successfully decoded packets.
        partial_packets: Partially decoded packets (from segments whose
            packet density fell below the acceptance threshold).
        error_regions: Regions that could not be decoded; each entry holds
            'start_sample', 'end_sample', and 'error'.
        decode_rate: Fraction of signal successfully decoded, in [0, 1].
        confidence: Confidence in decoded data.
    """

    complete_packets: list[dict[str, Any]]
    partial_packets: list[dict[str, Any]]
    error_regions: list[dict[str, Any]]
    decode_rate: float
    confidence: float
def partial_decode(
    trace: WaveformTrace,
    decode_func: Callable[[WaveformTrace], list[dict[str, Any]]],
    *,
    segment_size: int = 10000,
    min_valid_ratio: float = 0.5,
) -> PartialDecodeResult:
    """Decode protocol with partial result support.

    First attempts a whole-trace decode. If that raises, the trace is split
    into fixed-size segments decoded independently, so an error in one
    region does not discard data from the rest of the signal.

    Args:
        trace: Trace to decode.
        decode_func: Protocol decode function.
        segment_size: Size of segments to try independently.
        min_valid_ratio: Minimum valid ratio to accept a segment's packets
            as complete rather than partial.

    Returns:
        PartialDecodeResult with all decoded data.

    Example:
        >>> result = partial_decode(trace, uart_decode)
        >>> print(f"Decoded {len(result.complete_packets)} complete packets")
        >>> print(f"Decode rate: {result.decode_rate:.1%}")

    References:
        ERROR-003: Partial Decode Support
    """
    samples = trace.data
    n = len(samples)

    complete_packets: list[dict[str, Any]] = []
    partial_packets: list[dict[str, Any]] = []
    error_regions: list[dict[str, Any]] = []

    total_samples = 0
    decoded_samples = 0

    try:
        # Happy path: decode the entire trace in one call.
        whole_result = decode_func(trace)
        if whole_result:
            complete_packets.extend(whole_result)
            decoded_samples = n
            total_samples = n
    except Exception as e:
        logger.info("Full decode failed, falling back to segment decode: %s", e)
        # Segment-by-segment fallback: keep whatever decodes cleanly.
        for seg_start in range(0, n, segment_size):
            seg_end = min(seg_start + segment_size, n)
            seg_samples = samples[seg_start:seg_end]

            seg_trace = WaveformTrace(
                data=seg_samples,
                metadata=trace.metadata,
            )

            total_samples += len(seg_samples)

            try:
                seg_result = decode_func(seg_trace)

                if seg_result:
                    # Re-base packet positions from segment-local to
                    # trace-global coordinates.
                    for packet in seg_result:
                        if "timestamp" in packet:
                            packet["timestamp"] += seg_start / trace.metadata.sample_rate
                        if "sample" in packet:
                            packet["sample"] += seg_start

                    # Heuristic packet density (packets per ~100 samples)
                    # decides whether the segment counts as fully decoded.
                    density = len(seg_result) / max(len(seg_samples) / 100, 1)

                    if density >= min_valid_ratio:
                        complete_packets.extend(seg_result)
                        decoded_samples += len(seg_samples)
                    else:
                        partial_packets.extend(seg_result)
                        decoded_samples += len(seg_samples) // 2

            except Exception as seg_err:
                logger.debug("Segment decode failed at sample %d: %s", seg_start, seg_err)
                error_regions.append(
                    {
                        "start_sample": seg_start,
                        "end_sample": seg_end,
                        "error": str(seg_err),
                    }
                )

    # Overall statistics; max() guards avoid division by zero.
    decode_rate = decoded_samples / max(total_samples, 1)
    error_ratio = len(error_regions) / max(n // segment_size, 1)
    confidence = decode_rate * (1 - error_ratio)

    return PartialDecodeResult(
        complete_packets=complete_packets,
        partial_packets=partial_packets,
        error_regions=error_regions,
        decode_rate=decode_rate,
        confidence=confidence,
    )
440@dataclass
441class ErrorContext:
442 """Preserved error context for debugging.
444 Attributes:
445 error_type: Type of error that occurred.
446 error_message: Error message.
447 location: Where in the signal the error occurred.
448 context_before: Signal context before error.
449 context_after: Signal context after error.
450 parameters: Parameters at time of error.
451 suggestions: Suggestions for fixing the error.
452 """
454 error_type: str
455 error_message: str
456 location: int | None
457 context_before: NDArray[np.float64] | None
458 context_after: NDArray[np.float64] | None
459 parameters: dict[str, Any]
460 suggestions: list[str]
462 @classmethod
463 def capture(
464 cls,
465 exception: Exception,
466 trace: WaveformTrace | None = None,
467 location: int | None = None,
468 context_samples: int = 100,
469 parameters: dict[str, Any] | None = None,
470 ) -> ErrorContext:
471 """Capture error context from exception.
473 Args:
474 exception: The exception that occurred.
475 trace: Signal trace (for context extraction).
476 location: Sample index where error occurred.
477 context_samples: Number of context samples to capture.
478 parameters: Analysis parameters at time of error.
480 Returns:
481 ErrorContext with all available information.
482 """
483 context_before = None
484 context_after = None
486 if trace is not None and location is not None:
487 data = trace.data
488 n = len(data)
490 if location >= 0 and location < n: 490 ↛ 497line 490 didn't jump to line 497 because the condition on line 490 was always true
491 start = max(0, location - context_samples)
492 end = min(n, location + context_samples)
493 context_before = data[start:location]
494 context_after = data[location:end]
496 # Generate suggestions based on error type
497 suggestions = []
498 error_str = str(exception)
500 if "insufficient" in error_str.lower(): 500 ↛ 501line 500 didn't jump to line 501 because the condition on line 500 was never true
501 suggestions.append("Try providing more data samples")
502 suggestions.append("Check if trace is complete")
504 if "threshold" in error_str.lower():
505 suggestions.append("Try adjusting threshold parameter")
506 suggestions.append("Check signal levels are as expected")
508 if "timeout" in error_str.lower(): 508 ↛ 509line 508 didn't jump to line 509 because the condition on line 508 was never true
509 suggestions.append("Increase timeout parameter")
510 suggestions.append("Process in smaller chunks")
512 if "memory" in error_str.lower():
513 suggestions.append("Use chunked processing")
514 suggestions.append("Reduce analysis window size")
516 if not suggestions:
517 suggestions.append("Check input data format")
518 suggestions.append("Verify analysis parameters")
520 return cls(
521 error_type=type(exception).__name__,
522 error_message=str(exception),
523 location=location,
524 context_before=context_before,
525 context_after=context_after,
526 parameters=parameters or {},
527 suggestions=suggestions,
528 )
@dataclass
class RetryResult:
    """Result of retry with parameter adjustment.

    Returned by :func:`retry_with_adjustment`.

    Attributes:
        success: True if retry succeeded.
        result: Analysis result (if successful, otherwise None).
        attempts: Number of attempts made.
        final_parameters: Parameters that worked (or the last set tried
            when all attempts failed).
        adjustments_made: List of adjustments made, one entry per changed
            parameter per retry.
    """

    success: bool
    result: Any
    attempts: int
    final_parameters: dict[str, Any]
    adjustments_made: list[str]
def retry_with_adjustment[T](
    func: Callable[..., T],
    trace: WaveformTrace,
    initial_params: dict[str, Any],
    *,
    max_retries: int = 3,
    adjustment_rules: dict[str, Callable[[Any, int], Any]] | None = None,
) -> RetryResult:
    """Retry analysis with automatic parameter adjustment.

    Adjusts parameters and retries when analysis fails. Each adjustment
    rule is called with the parameter's current value and the 1-based
    attempt number.

    Args:
        func: Analysis function to retry; called as ``func(trace, **params)``.
        trace: Trace to analyze.
        initial_params: Initial parameters (shallow-copied; not mutated).
        max_retries: Maximum retry attempts after the first try.
        adjustment_rules: Rules for adjusting parameters, mapping parameter
            name to ``(current_value, attempt_number) -> new_value``.

    Returns:
        RetryResult with outcome of retries.

    Example:
        >>> rules = {
        ...     'threshold': lambda v, n: v * 0.9,  # Reduce by 10% each retry
        ...     'window_size': lambda v, n: v * 2,  # Double each retry
        ... }
        >>> result = retry_with_adjustment(analyze, trace, params, adjustment_rules=rules)
        >>> if result.success:
        ...     print(f"Succeeded after {result.attempts} attempts")

    References:
        ERROR-005: Automatic Retry with Parameter Adjustment
    """
    if adjustment_rules is None:
        # Default adjustment rules
        # NOTE(review): rules receive the *current* (already adjusted) value
        # plus the attempt number, so these exponential defaults compound
        # across retries — confirm this is intended.
        adjustment_rules = {
            "threshold": lambda v, n: v * (0.9**n),
            "tolerance": lambda v, n: v * (1.2**n),
            "window_size": lambda v, n: int(v * (1.5**n)),
            "min_samples": lambda v, n: max(1, int(v * (0.8**n))),
        }

    # Work on a copy so the caller's dict is never mutated.
    params = initial_params.copy()
    adjustments_made: list[str] = []

    # One initial attempt plus up to max_retries retries.
    for attempt in range(max_retries + 1):
        try:
            result = func(trace, **params)
            return RetryResult(
                success=True,
                result=result,
                attempts=attempt + 1,
                final_parameters=params,
                adjustments_made=adjustments_made,
            )

        except Exception as e:
            logger.debug("Retry attempt %d failed: %s", attempt + 1, e)
            if attempt >= max_retries:
                logger.warning("Max retries (%d) reached, giving up: %s", max_retries, e)
                break

            # Adjust parameters for next attempt
            for param_name, adjust_func in adjustment_rules.items():
                if param_name in params:
                    old_val = params[param_name]
                    new_val = adjust_func(old_val, attempt + 1)
                    params[param_name] = new_val
                    adjustments_made.append(
                        f"Attempt {attempt + 1}: {param_name} {old_val} -> {new_val}"
                    )

    # All attempts exhausted.
    return RetryResult(
        success=False,
        result=None,
        attempts=max_retries + 1,
        final_parameters=params,
        adjustments_made=adjustments_made,
    )
# Public API of this module; keep in sync with the definitions above.
__all__ = [
    "ErrorContext",
    "GracefulDegradationResult",
    "PartialDecodeResult",
    "RecoveryStats",
    "RetryResult",
    "graceful_degradation",
    "partial_decode",
    "recover_corrupted_data",
    "retry_with_adjustment",
]