Coverage for src / tracekit / exploratory / error_recovery.py: 87%

243 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Error recovery and graceful degradation for signal analysis. 

2 

3This module provides error recovery mechanisms for handling corrupted, 

4noisy, or incomplete signal data. 

5 

6 

7Example: 

8 >>> from tracekit.exploratory.error_recovery import recover_corrupted_data 

9 >>> recovered, stats = recover_corrupted_data(trace) 

10 >>> print(f"Recovered {stats.recovered_samples} samples") 

11""" 

12 

13from __future__ import annotations 

14 

15import logging 

16from dataclasses import dataclass 

17from typing import TYPE_CHECKING, Any, TypeVar 

18 

19import numpy as np 

20 

21from tracekit.core.types import WaveformTrace 

22 

23logger = logging.getLogger(__name__) 

24 

25if TYPE_CHECKING: 

26 from collections.abc import Callable 

27 

28 from numpy.typing import NDArray 

29 

30T = TypeVar("T") 

31 

32 

@dataclass
class RecoveryStats:
    """Summary of a corrupted-data recovery pass.

    Attributes:
        total_samples: Sample count of the original trace.
        corrupted_samples: How many samples were flagged as corrupted.
        recovered_samples: How many flagged samples were successfully repaired.
        unrecoverable_samples: How many flagged samples could not be repaired.
        recovery_method: Name of the recovery strategy that was applied.
        confidence: Confidence score for the recovered data, in [0, 1].
    """

    total_samples: int
    corrupted_samples: int
    recovered_samples: int
    unrecoverable_samples: int
    recovery_method: str
    confidence: float

52 

53 

def recover_corrupted_data(
    trace: WaveformTrace,
    *,
    corruption_threshold: float = 3.0,
    recovery_method: str = "interpolate",
    max_gap_samples: int = 100,
) -> tuple[WaveformTrace, RecoveryStats]:
    """Recover corrupted data.

    Detects corrupted samples (robust statistical outliers, NaN, Inf) and
    attempts to repair them using the selected strategy.

    Args:
        trace: Trace with potentially corrupted data.
        corruption_threshold: Threshold for detecting corruption, expressed
            in robust (MAD-based) standard deviations.
        recovery_method: 'interpolate', 'median', or 'zero'. Any other value
            leaves corrupted samples untouched and counts them unrecoverable.
        max_gap_samples: Maximum contiguous gap length that will be repaired.

    Returns:
        Tuple of (recovered_trace, recovery_stats).

    Example:
        >>> recovered, stats = recover_corrupted_data(trace)
        >>> print(f"Recovered {stats.recovered_samples} samples")
        >>> print(f"Confidence: {stats.confidence:.1%}")

    References:
        ERROR-001: Error Recovery from Corrupted Data
    """
    data = trace.data.copy()
    n = len(data)

    # Robust statistics (median / MAD) are computed over finite samples only
    # so NaN/Inf corruption cannot poison the outlier threshold itself.
    valid_mask = np.isfinite(data)
    valid_data = data[valid_mask] if np.any(valid_mask) else data

    median = np.median(valid_data) if len(valid_data) > 0 else 0.0
    mad = np.median(np.abs(valid_data - median)) if len(valid_data) > 0 else 0.0

    if mad < 1e-10:
        # Degenerate MAD (e.g. >50% identical samples): fall back to std dev.
        mad = np.std(valid_data) if len(valid_data) > 0 else 1.0

    # Robust z-score; 1.4826 scales MAD to the standard deviation of a
    # normal distribution, the epsilon guards against division by zero.
    z_scores = np.abs(data - median) / (1.4826 * mad + 1e-10)

    # Corruption = statistical outlier, NaN, or Inf.
    corrupted_mask = z_scores > corruption_threshold
    corrupted_mask |= np.isnan(data)
    corrupted_mask |= np.isinf(data)

    corrupted_indices = np.where(corrupted_mask)[0]
    n_corrupted = len(corrupted_indices)

    if n_corrupted == 0:
        return trace, RecoveryStats(
            total_samples=n,
            corrupted_samples=0,
            recovered_samples=0,
            unrecoverable_samples=0,
            recovery_method="none",
            confidence=1.0,
        )

    # Group corrupted indices into maximal contiguous gaps [start, end].
    # (corrupted_indices is guaranteed non-empty here by the early return.)
    gaps: list[tuple[int, int]] = []
    gap_start = int(corrupted_indices[0])
    gap_end = int(corrupted_indices[0])
    for idx in corrupted_indices[1:]:
        if idx == gap_end + 1:
            gap_end = int(idx)
        else:
            gaps.append((gap_start, gap_end))
            gap_start = int(idx)
            gap_end = int(idx)
    gaps.append((gap_start, gap_end))

    # Attempt recovery gap by gap.
    recovered = 0
    unrecoverable = 0

    for start, end in gaps:
        gap_length = end - start + 1

        if gap_length > max_gap_samples:
            unrecoverable += gap_length
            continue

        if recovery_method == "interpolate":
            # Gaps are maximal runs, so the samples immediately outside a
            # gap are uncorrupted whenever they exist.
            has_left = start > 0
            has_right = end < n - 1

            if has_left and has_right:
                # Linear interpolation between the two valid neighbors.
                left_val = data[start - 1]
                right_val = data[end + 1]
                for i, idx in enumerate(range(start, end + 1)):
                    t = (i + 1) / (gap_length + 1)
                    data[idx] = left_val * (1 - t) + right_val * t
                recovered += gap_length
            elif has_right:
                # Gap touches the start of the trace - hold the right value.
                data[start : end + 1] = data[end + 1]
                recovered += gap_length
            elif has_left:
                # Gap touches the end of the trace - hold the left value.
                data[start : end + 1] = data[start - 1]
                recovered += gap_length
            else:
                # The entire trace is corrupted; there is no valid sample to
                # borrow from. (Previously this filled the gap with a
                # corrupted value, potentially re-injecting NaN, and still
                # counted the samples as recovered.)
                unrecoverable += gap_length

        elif recovery_method == "median":
            # Replace with the median of nearby uncorrupted samples.
            window_start = max(0, start - 50)
            window_end = min(n, end + 50)
            window = data[window_start:window_end]
            window_valid = window[~corrupted_mask[window_start:window_end]]

            if len(window_valid) > 0:
                data[start : end + 1] = np.median(window_valid)
                recovered += gap_length
            else:
                unrecoverable += gap_length

        elif recovery_method == "zero":
            # Replace with zero.
            data[start : end + 1] = 0
            recovered += gap_length

        else:
            # Unknown recovery method: leave the samples untouched.
            unrecoverable += gap_length

    # Create recovered trace with the original metadata.
    recovered_trace = WaveformTrace(
        data=data,
        metadata=trace.metadata,
    )

    # Confidence: fraction recovered, discounted by how large the gaps were
    # relative to the largest repairable gap. max() guards against
    # max_gap_samples == 0, which previously raised ZeroDivisionError.
    recovery_ratio = recovered / max(n_corrupted, 1)
    gap_sizes = [e - s + 1 for s, e in gaps]
    avg_gap_size = np.mean(gap_sizes) if gap_sizes else 0
    confidence = recovery_ratio * (1 - avg_gap_size / max(max_gap_samples, 1))

    return recovered_trace, RecoveryStats(
        total_samples=n,
        corrupted_samples=n_corrupted,
        recovered_samples=recovered,
        unrecoverable_samples=unrecoverable,
        recovery_method=recovery_method,
        confidence=max(0.0, min(1.0, confidence)),
    )

210 

211 

@dataclass
class GracefulDegradationResult:
    """Outcome of an analysis run that may have been degraded.

    Attributes:
        result: Analysis result mapping; may contain only a subset of features.
        quality_level: One of 'full', 'degraded', or 'minimal'.
        available_features: Feature names that were computed successfully.
        missing_features: Feature names that could not be computed.
        warnings: Human-readable notes about what was degraded and why.
    """

    result: dict[str, Any]
    quality_level: str
    available_features: list[str]
    missing_features: list[str]
    warnings: list[str]

229 

230 

def graceful_degradation(
    analysis_func: Callable[..., dict[str, Any]],
    trace: WaveformTrace,
    *,
    required_features: list[str] | None = None,
    optional_features: list[str] | None = None,
    **kwargs: Any,
) -> GracefulDegradationResult:
    """Execute analysis with graceful degradation.

    Attempts the full analysis first; on failure, falls back to collecting
    individual features directly off the trace object so a partial result
    can still be returned.

    Args:
        analysis_func: Analysis function to call.
        trace: Trace to analyze.
        required_features: Features that must succeed for 'degraded' quality.
        optional_features: Features that can fail without affecting quality.
        **kwargs: Additional arguments passed through to analysis_func.

    Returns:
        GracefulDegradationResult with partial or full results.

    Example:
        >>> result = graceful_degradation(analyze_signal, trace)
        >>> print(f"Quality: {result.quality_level}")
        >>> print(f"Available: {result.available_features}")

    References:
        ERROR-002: Graceful Degradation
    """
    if required_features is None:
        required_features = []
    if optional_features is None:
        optional_features = []

    result: dict[str, Any] = {}
    available: list[str] = []
    missing: list[str] = []
    warnings: list[str] = []

    # Try the full analysis first.
    try:
        result = analysis_func(trace, **kwargs)
        available = list(result.keys())
        quality_level = "full"

    except Exception as e:
        logger.warning("Full analysis failed: %s", e, exc_info=True)
        warnings.append(f"Full analysis failed: {e!s}")

        # Fall back: harvest whatever features the trace exposes directly.
        for feature in required_features + optional_features:
            try:
                if hasattr(trace, feature):
                    result[feature] = getattr(trace, feature)
                    available.append(feature)
                else:
                    missing.append(feature)
            except Exception as fe:
                logger.debug("Feature %s failed: %s", feature, fe, exc_info=True)
                missing.append(feature)
                if feature in required_features:
                    warnings.append(f"Required feature {feature} failed: {fe!s}")

        # Determine quality level. A run that recovered nothing is reported
        # as 'failed' even when no required features were requested
        # (previously the vacuous all() mislabelled that case 'degraded').
        if available and all(f in available for f in required_features):
            quality_level = "degraded"
        elif available:
            quality_level = "minimal"
        else:
            quality_level = "failed"
            warnings.append("Analysis completely failed")

    return GracefulDegradationResult(
        result=result,
        quality_level=quality_level,
        available_features=available,
        missing_features=missing,
        warnings=warnings,
    )

312 

313 

@dataclass
class PartialDecodeResult:
    """Outcome of a protocol decode that tolerates errors.

    Attributes:
        complete_packets: Packets decoded with full confidence.
        partial_packets: Packets decoded from low-quality segments.
        error_regions: Sample ranges (with error text) that failed to decode.
        decode_rate: Fraction of the signal successfully decoded, in [0, 1].
        confidence: Confidence score for the decoded data.
    """

    complete_packets: list[dict[str, Any]]
    partial_packets: list[dict[str, Any]]
    error_regions: list[dict[str, Any]]
    decode_rate: float
    confidence: float

331 

332 

def partial_decode(
    trace: WaveformTrace,
    decode_func: Callable[[WaveformTrace], list[dict[str, Any]]],
    *,
    segment_size: int = 10000,
    min_valid_ratio: float = 0.5,
) -> PartialDecodeResult:
    """Decode protocol with partial result support.

    Tries a full-trace decode first; if that raises, decodes segment by
    segment so errors in one region do not discard data from the others.

    Args:
        trace: Trace to decode.
        decode_func: Protocol decode function.
        segment_size: Size of segments to try independently. Note that
            packets straddling a segment boundary may be lost or split.
        min_valid_ratio: Minimum valid ratio to accept a segment's packets
            as complete (below it they are kept as partial).

    Returns:
        PartialDecodeResult with all decoded data.

    Example:
        >>> result = partial_decode(trace, uart_decode)
        >>> print(f"Decoded {len(result.complete_packets)} complete packets")
        >>> print(f"Decode rate: {result.decode_rate:.1%}")

    References:
        ERROR-003: Partial Decode Support
    """
    data = trace.data
    n = len(data)

    complete_packets: list[dict[str, Any]] = []
    partial_packets: list[dict[str, Any]] = []
    error_regions: list[dict[str, Any]] = []

    # The whole trace is always in scope, whichever path we take.
    # (Previously total_samples stayed 0 when the full decode returned an
    # empty packet list.)
    total_samples = n
    decoded_samples = 0

    # Try to decode the entire trace first.
    try:
        full_result = decode_func(trace)
        if full_result:
            complete_packets.extend(full_result)
            decoded_samples = n
    except Exception as e:
        logger.info("Full decode failed, falling back to segment decode: %s", e)
        # Fall back to segment-by-segment decode.
        for start in range(0, n, segment_size):
            end = min(start + segment_size, n)
            segment_data = data[start:end]

            # Create a trace for just this segment, reusing the metadata.
            segment_trace = WaveformTrace(
                data=segment_data,
                metadata=trace.metadata,
            )

            try:
                segment_result = decode_func(segment_trace)

                if segment_result:
                    # Re-base packet positions from segment-local to
                    # trace-global coordinates.
                    for packet in segment_result:
                        if "timestamp" in packet:
                            packet["timestamp"] += start / trace.metadata.sample_rate
                        if "sample" in packet:
                            packet["sample"] += start

                    # Heuristic density check: roughly one packet per 100
                    # samples counts as a fully valid segment.
                    valid_ratio = len(segment_result) / max(len(segment_data) / 100, 1)

                    if valid_ratio >= min_valid_ratio:
                        complete_packets.extend(segment_result)
                        decoded_samples += len(segment_data)
                    else:
                        partial_packets.extend(segment_result)
                        decoded_samples += len(segment_data) // 2

            except Exception as seg_err:
                logger.debug("Segment decode failed at sample %d: %s", start, seg_err)
                error_regions.append(
                    {
                        "start_sample": start,
                        "end_sample": end,
                        "error": str(seg_err),
                    }
                )

    # Calculate statistics.
    decode_rate = decoded_samples / max(total_samples, 1)

    # Segment count must be ceil(n / segment_size): with floor division a
    # trailing partial segment made error_ratio exceed 1 and pushed
    # confidence negative. Confidence is also clamped to [0, 1] now.
    n_segments = max(-(-n // segment_size), 1)
    error_ratio = len(error_regions) / n_segments
    confidence = max(0.0, min(1.0, decode_rate * (1 - error_ratio)))

    return PartialDecodeResult(
        complete_packets=complete_packets,
        partial_packets=partial_packets,
        error_regions=error_regions,
        decode_rate=decode_rate,
        confidence=confidence,
    )

438 

439 

@dataclass
class ErrorContext:
    """Preserved error context for debugging.

    Attributes:
        error_type: Exception class name of the error that occurred.
        error_message: Error message text.
        location: Sample index where the error occurred, if known.
        context_before: Signal samples immediately before the error location.
        context_after: Signal samples at and after the error location.
        parameters: Analysis parameters in effect at the time of the error.
        suggestions: Heuristic suggestions for resolving the error.
    """

    error_type: str
    error_message: str
    location: int | None
    context_before: NDArray[np.float64] | None
    context_after: NDArray[np.float64] | None
    parameters: dict[str, Any]
    suggestions: list[str]

    @classmethod
    def capture(
        cls,
        exception: Exception,
        trace: WaveformTrace | None = None,
        location: int | None = None,
        context_samples: int = 100,
        parameters: dict[str, Any] | None = None,
    ) -> ErrorContext:
        """Capture error context from an exception.

        Args:
            exception: The exception that occurred.
            trace: Signal trace (for context extraction).
            location: Sample index where the error occurred.
            context_samples: Number of context samples to capture per side.
            parameters: Analysis parameters at time of error.

        Returns:
            ErrorContext with all available information.
        """
        context_before = None
        context_after = None

        if trace is not None and location is not None:
            data = trace.data
            n = len(data)

            if location >= 0 and location < n:
                start = max(0, location - context_samples)
                end = min(n, location + context_samples)
                # Copy the slices: NumPy slicing returns views, so without
                # the copy the context would keep the whole trace array
                # alive and silently change if the caller mutates it.
                context_before = data[start:location].copy()
                context_after = data[location:end].copy()

        # Generate suggestions by keyword-matching the error message.
        suggestions = []
        error_str = str(exception)

        if "insufficient" in error_str.lower():
            suggestions.append("Try providing more data samples")
            suggestions.append("Check if trace is complete")

        if "threshold" in error_str.lower():
            suggestions.append("Try adjusting threshold parameter")
            suggestions.append("Check signal levels are as expected")

        if "timeout" in error_str.lower():
            suggestions.append("Increase timeout parameter")
            suggestions.append("Process in smaller chunks")

        if "memory" in error_str.lower():
            suggestions.append("Use chunked processing")
            suggestions.append("Reduce analysis window size")

        if not suggestions:
            # Generic fallback when no keyword matched.
            suggestions.append("Check input data format")
            suggestions.append("Verify analysis parameters")

        return cls(
            error_type=type(exception).__name__,
            error_message=str(exception),
            location=location,
            context_before=context_before,
            context_after=context_after,
            parameters=parameters or {},
            suggestions=suggestions,
        )

529 

530 

@dataclass
class RetryResult:
    """Outcome of an automatic retry loop with parameter adjustment.

    Attributes:
        success: True when one of the attempts succeeded.
        result: Value returned by the successful attempt, else None.
        attempts: Total number of attempts that were made.
        final_parameters: Parameter set used on the last attempt.
        adjustments_made: Human-readable log of each parameter change.
    """

    success: bool
    result: Any
    attempts: int
    final_parameters: dict[str, Any]
    adjustments_made: list[str]

548 

549 

550def retry_with_adjustment[T]( 

551 func: Callable[..., T], 

552 trace: WaveformTrace, 

553 initial_params: dict[str, Any], 

554 *, 

555 max_retries: int = 3, 

556 adjustment_rules: dict[str, Callable[[Any, int], Any]] | None = None, 

557) -> RetryResult: 

558 """Retry analysis with automatic parameter adjustment. 

559 

560 Adjusts parameters and retries when analysis fails. 

561 

562 Args: 

563 func: Analysis function to retry. 

564 trace: Trace to analyze. 

565 initial_params: Initial parameters. 

566 max_retries: Maximum retry attempts. 

567 adjustment_rules: Rules for adjusting parameters. 

568 

569 Returns: 

570 RetryResult with outcome of retries. 

571 

572 Example: 

573 >>> rules = { 

574 ... 'threshold': lambda v, n: v * 0.9, # Reduce by 10% each retry 

575 ... 'window_size': lambda v, n: v * 2, # Double each retry 

576 ... } 

577 >>> result = retry_with_adjustment(analyze, trace, params, adjustment_rules=rules) 

578 >>> if result.success: 

579 ... print(f"Succeeded after {result.attempts} attempts") 

580 

581 References: 

582 ERROR-005: Automatic Retry with Parameter Adjustment 

583 """ 

584 if adjustment_rules is None: 

585 # Default adjustment rules 

586 adjustment_rules = { 

587 "threshold": lambda v, n: v * (0.9**n), 

588 "tolerance": lambda v, n: v * (1.2**n), 

589 "window_size": lambda v, n: int(v * (1.5**n)), 

590 "min_samples": lambda v, n: max(1, int(v * (0.8**n))), 

591 } 

592 

593 params = initial_params.copy() 

594 adjustments_made = [] # type: ignore[var-annotated] 

595 

596 for attempt in range(max_retries + 1): 596 ↛ 623line 596 didn't jump to line 623 because the loop on line 596 didn't complete

597 try: 

598 result = func(trace, **params) 

599 return RetryResult( 

600 success=True, 

601 result=result, 

602 attempts=attempt + 1, 

603 final_parameters=params, 

604 adjustments_made=adjustments_made, 

605 ) 

606 

607 except Exception as e: 

608 logger.debug("Retry attempt %d failed: %s", attempt + 1, e) 

609 if attempt >= max_retries: 

610 logger.warning("Max retries (%d) reached, giving up: %s", max_retries, e) 

611 break 

612 

613 # Adjust parameters for next attempt 

614 for param_name, adjust_func in adjustment_rules.items(): 

615 if param_name in params: 

616 old_val = params[param_name] 

617 new_val = adjust_func(old_val, attempt + 1) 

618 params[param_name] = new_val 

619 adjustments_made.append( 

620 f"Attempt {attempt + 1}: {param_name} {old_val} -> {new_val}" 

621 ) 

622 

623 return RetryResult( 

624 success=False, 

625 result=None, 

626 attempts=max_retries + 1, 

627 final_parameters=params, 

628 adjustments_made=adjustments_made, 

629 ) 

630 

631 

# Public API of the error-recovery module; kept sorted (classes first, then
# functions) to make additions easy to review.
__all__ = [
    "ErrorContext",
    "GracefulDegradationResult",
    "PartialDecodeResult",
    "RecoveryStats",
    "RetryResult",
    "graceful_degradation",
    "partial_decode",
    "recover_corrupted_data",
    "retry_with_adjustment",
]