Coverage for src / tracekit / core / config.py: 92%

146 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""TraceKit configuration loading and management. 

2 

3This module provides configuration loading from YAML files with support 

4for nested structures, user overrides, and schema validation. 

5 

6 

7Example: 

8 >>> from tracekit.core.config import load_config 

9 >>> config = load_config() 

10 >>> print(config["defaults"]["sample_rate"]) 

11 1000000.0 

12""" 

13 

14from __future__ import annotations 

15 

16import copy 

17from pathlib import Path 

18from typing import Any 

19 

20import numpy as np 

21 

22try: 

23 import yaml 

24 

25 YAML_AVAILABLE = True 

26except ImportError: 

27 YAML_AVAILABLE = False 

28 

29from tracekit.core.exceptions import ConfigurationError 

30 

# Built-in baseline configuration, applied before any user overrides.
# load_config() deep-copies this, so callers may mutate their result freely.
DEFAULT_CONFIG: dict[str, Any] = {
    "version": "1.0",
    # Core signal-processing defaults.
    "defaults": {
        "sample_rate": 1e6,  # 1 MHz default
        "window_function": "hann",
        "fft_size": 1024,
    },
    # File-loader behavior and per-format options.
    "loaders": {
        "auto_detect": True,
        "formats": ["wfm", "csv", "npz", "hdf5", "tdms", "vcd", "sr", "wav"],
        "tektronix": {"byte_order": "little"},
        "csv": {"delimiter": ",", "skip_header": 0},
    },
    # Reference levels / thresholds for timing measurements.
    "measurements": {
        "rise_time": {"ref_levels": [0.1, 0.9]},
        "fall_time": {"ref_levels": [0.9, 0.1]},
        "frequency": {"min_periods": 3},
    },
    # Spectral-analysis defaults.
    "spectral": {
        "default_window": "hann",
        "overlap": 0.5,
        "nfft": None,  # Auto-determine from signal length
    },
    # Plotting defaults.
    "visualization": {
        "default_style": "seaborn",
        "figure_size": [10, 6],
        "dpi": 100,
    },
    # Export-format options.
    "export": {
        "csv": {"precision": 6},
        "hdf5": {"compression": "gzip", "compression_opts": 4},
    },
}

65 

66 

67def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: 

68 """Recursively merge two dictionaries. 

69 

70 Values from override take precedence. Nested dictionaries are 

71 merged recursively. 

72 

73 Args: 

74 base: Base dictionary. 

75 override: Dictionary with values to override. 

76 

77 Returns: 

78 Merged dictionary. 

79 """ 

80 result = base.copy() 

81 

82 for key, value in override.items(): 

83 if key in result and isinstance(result[key], dict) and isinstance(value, dict): 

84 result[key] = _deep_merge(result[key], value) 

85 else: 

86 result[key] = value 

87 

88 return result 

89 

90 

def load_config(
    config_path: str | Path | None = None,
    *,
    use_defaults: bool = True,
) -> dict[str, Any]:
    """Load configuration from YAML file.

    Loads configuration from the specified path and optionally merges
    with default values. If no path is specified, looks for configuration
    in standard locations (current directory first, then the user's home).

    Args:
        config_path: Path to YAML configuration file. If None, searches
            for config in standard locations.
        use_defaults: If True, merge loaded config with defaults.

    Returns:
        Configuration dictionary. When no file is found and
        ``use_defaults`` is False, an empty dict is returned.

    Raises:
        ConfigurationError: If the configuration file cannot be found,
            read, or parsed, if its top level is not a mapping, or if
            PyYAML is not installed.

    Example:
        >>> config = load_config()
        >>> print(config["defaults"]["sample_rate"])
        1000000.0

        >>> config = load_config("~/.tracekit/config.yaml")
        >>> print(config["measurements"]["rise_time"]["ref_levels"])
        [0.1, 0.9]
    """
    config: dict[str, Any] = {}

    if use_defaults:
        # Deep copy so callers can mutate the result without corrupting
        # the module-level DEFAULT_CONFIG.
        config = copy.deepcopy(DEFAULT_CONFIG)

    if config_path is None:
        # Search standard locations, nearest (cwd) first.
        search_paths = [
            Path.cwd() / "tracekit.yaml",
            Path.cwd() / ".tracekit.yaml",
            Path.home() / ".tracekit" / "config.yaml",
            Path.home() / ".config" / "tracekit" / "config.yaml",
        ]

        for path in search_paths:
            if path.exists():
                config_path = path
                break

    if config_path is not None:
        config_path = Path(config_path).expanduser()

        if not config_path.exists():
            raise ConfigurationError(
                "Configuration file not found",
                config_key=str(config_path),
                fix_hint=f"Create configuration file at {config_path}",
            )

        if not YAML_AVAILABLE:
            raise ConfigurationError(
                "YAML support not available",
                details="PyYAML package is required for configuration loading",
                fix_hint="Install PyYAML: pip install pyyaml",
            )

        try:
            with open(config_path, encoding="utf-8") as f:
                user_config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ConfigurationError(
                "Failed to parse configuration file",
                config_key=str(config_path),
                details=str(e),
            ) from e
        except OSError as e:
            raise ConfigurationError(
                "Failed to read configuration file",
                config_key=str(config_path),
                details=str(e),
            ) from e

        if user_config is not None:
            # BUG FIX: a YAML file whose top level is a scalar or list used
            # to crash _deep_merge with AttributeError; reject it with a
            # clear configuration error instead.
            if not isinstance(user_config, dict):
                raise ConfigurationError(
                    "Configuration file must contain a mapping at the top level",
                    config_key=str(config_path),
                    details=f"Top-level YAML value has type {type(user_config).__name__}",
                )
            config = _deep_merge(config, user_config) if use_defaults else user_config

    return config

178 

179 

180def validate_config(config: dict[str, Any]) -> bool: 

181 """Validate configuration against schema. 

182 

183 Checks that required fields exist and have valid types. 

184 

185 Args: 

186 config: Configuration dictionary to validate. 

187 

188 Returns: 

189 True if configuration is valid. 

190 

191 Raises: 

192 ConfigurationError: If configuration is invalid. 

193 

194 Example: 

195 >>> config = load_config() 

196 >>> validate_config(config) 

197 True 

198 """ 

199 # Required top-level sections 

200 required_sections = ["defaults", "loaders"] 

201 

202 for section in required_sections: 

203 if section not in config: 

204 raise ConfigurationError( 

205 f"Missing required configuration section: {section}", 

206 config_key=section, 

207 ) 

208 

209 # Validate defaults section 

210 defaults = config.get("defaults", {}) 

211 if "sample_rate" in defaults: 

212 sample_rate = defaults["sample_rate"] 

213 if not isinstance(sample_rate, int | float) or sample_rate <= 0: 

214 raise ConfigurationError( 

215 "Invalid sample_rate in defaults", 

216 config_key="defaults.sample_rate", 

217 expected_type="positive number", 

218 actual_value=sample_rate, 

219 ) 

220 

221 # Validate loaders section 

222 loaders = config.get("loaders", {}) 

223 if "formats" in loaders: 

224 formats = loaders["formats"] 

225 if not isinstance(formats, list): 

226 raise ConfigurationError( 

227 "Invalid formats in loaders", 

228 config_key="loaders.formats", 

229 expected_type="list of strings", 

230 actual_value=type(formats).__name__, 

231 ) 

232 

233 # Validate measurements section 

234 measurements = config.get("measurements", {}) 

235 for name, settings in measurements.items(): 

236 if "ref_levels" in settings: 

237 ref_levels = settings["ref_levels"] 

238 if not isinstance(ref_levels, list) or len(ref_levels) != 2: 

239 raise ConfigurationError( 

240 f"Invalid ref_levels for {name}", 

241 config_key=f"measurements.{name}.ref_levels", 

242 expected_type="list of 2 numbers", 

243 actual_value=ref_levels, 

244 ) 

245 

246 return True 

247 

248 

def get_config_value(
    config: dict[str, Any],
    key_path: str,
    default: Any = None,
) -> Any:
    """Get a configuration value by dot-separated path.

    Args:
        config: Configuration dictionary.
        key_path: Dot-separated path to the value (e.g., "defaults.sample_rate").
        default: Default value if key not found.

    Returns:
        Configuration value or default.

    Example:
        >>> config = load_config()
        >>> get_config_value(config, "defaults.sample_rate", 1e6)
        1000000.0
        >>> get_config_value(config, "unknown.key", "default")
        'default'
    """
    node: Any = config
    for part in key_path.split("."):
        # Bail out as soon as the path leaves dict territory or a
        # segment is missing.
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node

281 

282 

def save_config(config: dict[str, Any], config_path: str | Path) -> None:
    """Save configuration to YAML file.

    Creates the parent directory if it does not already exist; ``~`` in
    the path is expanded.

    Args:
        config: Configuration dictionary to save.
        config_path: Path to save configuration to.

    Raises:
        ConfigurationError: If PyYAML is unavailable, the configuration
            cannot be serialized to YAML, or the file cannot be written.

    Example:
        >>> config = load_config()
        >>> config["defaults"]["sample_rate"] = 2e6
        >>> save_config(config, "~/my_config.yaml")
    """
    if not YAML_AVAILABLE:
        raise ConfigurationError(
            "YAML support not available",
            details="PyYAML package is required for configuration saving",
            fix_hint="Install PyYAML: pip install pyyaml",
        )

    config_path = Path(config_path).expanduser()

    # Create parent directory if needed
    config_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(config, f, default_flow_style=False, sort_keys=False)
    except yaml.YAMLError as e:
        # BUG FIX: unserializable values make yaml.dump raise
        # RepresenterError (a YAMLError subclass), which previously escaped
        # uncaught; wrap it like load_config wraps parse errors.
        raise ConfigurationError(
            "Failed to serialize configuration to YAML",
            config_key=str(config_path),
            details=str(e),
        ) from e
    except OSError as e:
        raise ConfigurationError(
            "Failed to save configuration file",
            config_key=str(config_path),
            details=str(e),
        ) from e

319 

320 

class SmartDefaults:
    """Intelligent defaults configuration.

    Provides smart default parameters that work for 80% of use cases,
    with automatic parameter selection based on signal characteristics.

    All auto-selected parameters are logged with rationale for transparency.
    Users can override any parameter.

    Example:
        >>> from tracekit.core.config import SmartDefaults
        >>> defaults = SmartDefaults()
        >>> # Get smart default for FFT size based on signal length
        >>> fft_size = defaults.get_fft_size(signal_length=10000)
        >>> print(fft_size)
        16384

    References:
        Best practices from scipy.signal, matplotlib, and numpy
    """

    def __init__(self, verbose: bool = False):
        """Initialize SmartDefaults.

        Args:
            verbose: If True, log parameter selection rationale.
        """
        self.verbose = verbose
        self._log_buffer: list[str] = []

    def _log(self, message: str) -> None:
        """Log a parameter selection message.

        The message is always buffered; it is also printed when
        ``verbose`` is enabled.

        Args:
            message: Log message.
        """
        self._log_buffer.append(message)
        if self.verbose:
            print(f"[SmartDefaults] {message}")

    def get_fft_size(
        self,
        signal_length: int,
        *,
        min_size: int = 256,
        max_size: int = 2**16,
    ) -> int:
        """Get smart default FFT size.

        Rounds the signal length up to the next power of two, then clamps
        the result into ``[min_size, max_size]``.

        Args:
            signal_length: Length of signal in samples (must be >= 1).
            min_size: Minimum FFT size.
            max_size: Maximum FFT size.

        Returns:
            Recommended FFT size (power of 2).

        Raises:
            ValueError: If ``signal_length`` is less than 1.
        """
        # BUG FIX: signal_length <= 0 used to fail deep inside numpy with
        # an obscure log2 error; validate up front with a clear message.
        if signal_length < 1:
            raise ValueError(f"signal_length must be >= 1, got {signal_length}")

        # Next power of 2 at or above signal length (computed once; the
        # original recomputed it for the log message).
        next_pow2 = 2 ** int(np.ceil(np.log2(signal_length)))

        # Clamp to reasonable range
        nfft = max(min_size, min(next_pow2, max_size))

        self._log(
            f"FFT size: {nfft} (signal_length={signal_length}, "
            f"next_power_of_2={next_pow2})"
        )

        return int(nfft)

    def get_window_function(
        self,
        application: str = "general",
        *,
        dynamic_range_db: float = 60.0,
    ) -> str:
        """Get smart default window function.

        Args:
            application: Application type ('general', 'narrowband', 'transient').
            dynamic_range_db: Required dynamic range in dB.

        Returns:
            Window function name.
        """
        # Precedence: transient beats everything, then dynamic range,
        # then narrowband, then the general-purpose fallback.
        if application == "transient":
            window = "boxcar"
            reason = "rectangular window for transient analysis"
        elif dynamic_range_db > 80:
            window = "blackman-harris"
            reason = f"high dynamic range ({dynamic_range_db} dB) requires Blackman-Harris"
        elif dynamic_range_db > 60:
            window = "blackman"
            reason = f"moderate-high dynamic range ({dynamic_range_db} dB) uses Blackman"
        elif application == "narrowband":
            window = "flattop"
            reason = "narrowband analysis uses flat-top for amplitude accuracy"
        else:
            window = "hann"
            reason = "general purpose uses Hann window"

        self._log(f"Window function: {window} ({reason})")

        return window

    def get_overlap(
        self,
        method: str = "welch",
        window: str = "hann",
    ) -> float:
        """Get smart default overlap for windowed methods.

        Args:
            method: Analysis method ('welch', 'bartlett', 'stft').
            window: Window function name.

        Returns:
            Overlap fraction (0-1).
        """
        if method == "bartlett":
            overlap = 0.0
            reason = "Bartlett method uses no overlap"
        elif window in ["hann", "hamming", "blackman"]:
            overlap = 0.5
            reason = f"{window} window typically uses 50% overlap"
        elif window == "blackman-harris":
            overlap = 0.75
            reason = "Blackman-Harris uses 75% overlap for smoothness"
        else:
            overlap = 0.5
            reason = "default 50% overlap for general windows"

        self._log(f"Overlap: {overlap * 100:.0f}% ({reason})")

        return overlap

    def get_reference_levels(
        self,
        measurement: str = "rise_time",
    ) -> tuple[float, float]:
        """Get smart default reference levels for timing measurements.

        Args:
            measurement: Measurement type ('rise_time', 'fall_time', etc.).

        Returns:
            Tuple of (low_level, high_level) as fractions (0-1).
        """
        if measurement in ["rise_time", "slew_rate"]:
            levels = (0.1, 0.9)
            reason = "10%-90% is IEEE 181-2011 standard for rise time"
        elif measurement == "fall_time":
            levels = (0.9, 0.1)
            reason = "90%-10% for fall time per IEEE 181-2011"
        elif measurement in ["propagation_delay", "setup_time", "hold_time"]:
            levels = (0.5, 0.5)
            reason = "50% threshold for timing measurements"
        else:
            levels = (0.1, 0.9)
            reason = "default 10%-90% for general measurements"

        self._log(f"Reference levels: {levels[0]:.0%}-{levels[1]:.0%} ({reason})")

        return levels

    def get_log_messages(self) -> list[str]:
        """Get all logged parameter selection messages.

        Returns:
            List of log messages (a copy; mutating it does not affect
            the internal buffer).
        """
        return self._log_buffer.copy()

    def clear_log(self) -> None:
        """Clear the log buffer."""
        self._log_buffer.clear()

497 

498 

499__all__ = [ 

500 "DEFAULT_CONFIG", 

501 "SmartDefaults", 

502 "get_config_value", 

503 "load_config", 

504 "save_config", 

505 "validate_config", 

506]