Coverage for src / tracekit / config / protocol.py: 88%
297 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Protocol definition registry and loading.
3This module provides protocol definition management including registry,
4loading from YAML/JSON files, inheritance, hot reload support, version
5migration, and circular dependency detection.
6"""
8from __future__ import annotations
10import contextlib
11import logging
12import os
13import threading
14import time
15from dataclasses import dataclass, field
16from pathlib import Path
17from typing import TYPE_CHECKING, Any
19import yaml
21from tracekit.config.schema import validate_against_schema
22from tracekit.core.exceptions import ConfigurationError
24if TYPE_CHECKING:
25 from collections.abc import Callable
# Module-level logger: registration, file loading, hot-reload and migration
# events in this module are reported through it.
logger = logging.getLogger(__name__)
@dataclass
class ProtocolDefinition:
    """Protocol definition with metadata and configuration.

    Attributes:
        name: Protocol identifier (e.g., "uart", "spi")
        version: Protocol version string (e.g., "1.0.0")
        description: Human-readable description
        author: Protocol definition author
        timing: Timing configuration (baud rates, data bits, etc.)
        voltage_levels: Logic level configuration
        state_machine: Protocol state machine definition
        extends: Parent protocol name for inheritance
        metadata: Additional custom metadata
        source_file: Path to source file (for hot reload)
        schema_version: Schema version for migration support

    Example:
        >>> protocol = ProtocolDefinition(
        ...     name="uart",
        ...     version="1.0.0",
        ...     timing={"baud_rates": [9600, 115200]}
        ... )
    """

    name: str
    version: str = "1.0.0"
    description: str = ""
    author: str = ""
    timing: dict[str, Any] = field(default_factory=dict)
    voltage_levels: dict[str, Any] = field(default_factory=dict)
    state_machine: dict[str, Any] = field(default_factory=dict)
    extends: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    source_file: str | None = None
    schema_version: str = "1.0.0"

    @property
    def supports_digital(self) -> bool:
        """Check if protocol supports digital signals.

        Always True: every protocol is treated as digital.
        """
        return True

    @property
    def supports_analog(self) -> bool:
        """Check if protocol requires analog threshold detection.

        True whenever any voltage-level configuration is present.
        """
        return bool(self.voltage_levels)

    @property
    def sample_rate_min(self) -> float:
        """Minimum sample rate required for decoding, in Hz.

        Estimated as 10x the highest configured baud rate (a typical
        oversampling factor). ``timing["baud_rates"]`` may be a list of
        rates or a single scalar rate. When it is absent, null or empty,
        1 MHz is returned.
        """
        baud_rates = self.timing.get("baud_rates")
        # A YAML file may specify a single rate (``baud_rates: 115200``);
        # max() on a bare number would raise TypeError.
        if isinstance(baud_rates, (int, float)):
            baud_rates = [baud_rates] if baud_rates else []
        if baud_rates:
            return float(max(baud_rates) * 10)
        return 1e6  # Default 1 MHz

    @property
    def sample_rate_max(self) -> float | None:
        """Maximum useful sample rate for decoding (None means no upper limit)."""
        return None

    @property
    def bit_widths(self) -> list[int]:
        """Supported data bit widths from ``timing["data_bits"]`` (default ``[8]``).

        A scalar width is wrapped in a list. A new list is always returned,
        so callers cannot mutate ``timing`` through the result.
        """
        widths = self.timing.get("data_bits")
        if widths is None:
            return [8]
        if isinstance(widths, int):
            return [widths]
        return list(widths)
@dataclass
class ProtocolCapabilities:
    """Protocol capabilities for querying and filtering.

    Instances are built by ``ProtocolRegistry.get_capabilities`` from the
    corresponding ``ProtocolDefinition`` properties.

    Attributes:
        supports_digital: Whether protocol uses digital signals
        supports_analog: Whether protocol needs analog thresholds
        sample_rate_min: Minimum required sample rate (Hz)
        sample_rate_max: Maximum useful sample rate (Hz); None means no upper limit
        bit_widths: Supported data widths
    """

    supports_digital: bool = True
    supports_analog: bool = False
    sample_rate_min: float = 1e6  # 1 MHz, same fallback as ProtocolDefinition.sample_rate_min
    sample_rate_max: float | None = None
    # default_factory gives every instance its own list (no shared mutable default)
    bit_widths: list[int] = field(default_factory=lambda: [8])
class ProtocolRegistry:
    """Central registry of all protocol definitions.

    Provides O(1) lookup by name, version queries, capability filtering,
    and enumeration for UI integration. ``ProtocolRegistry()`` always returns
    the same process-wide instance (see ``__new__``).

    When a "latest" version has to be picked because no default is recorded,
    versions are compared numerically per dot-separated component. So
    "10.0.0" ranks above "9.0.0", which a plain string sort gets wrong.

    Example:
        >>> registry = ProtocolRegistry()
        >>> uart = registry.get("uart")
        >>> i2c = registry.get("i2c", version="2.1.0")
        >>> all_protocols = registry.list()
        >>> digital = registry.filter(supports_digital=True)
    """

    _instance: ProtocolRegistry | None = None

    def __new__(cls) -> ProtocolRegistry:
        """Return the singleton instance, initialising its storage on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # name -> {version -> definition}
            cls._instance._protocols: dict[str, dict[str, ProtocolDefinition]] = {}  # type: ignore[misc, attr-defined]
            # name -> version returned by get(name) when no version is given
            cls._instance._default_versions: dict[str, str] = {}  # type: ignore[misc, attr-defined]
            # callbacks invoked by _notify_change()
            cls._instance._watchers: list[Callable[[ProtocolDefinition], None]] = []  # type: ignore[misc, attr-defined]
        return cls._instance

    @staticmethod
    def _version_key(version: str) -> tuple[tuple[int, str], ...]:
        """Sort key that orders dotted version strings numerically.

        Each dot-separated component becomes ``(leading_integer, suffix)``.
        A component with no leading digits sorts before any numeric one.
        """
        key = []
        for part in str(version).split("."):
            suffix = part.lstrip("0123456789")
            digits = part[: len(part) - len(suffix)]
            key.append((int(digits) if digits else -1, suffix))
        return tuple(key)

    def _latest_version(self, name: str) -> str | None:
        """Return the newest registered version of ``name``, or None if there is none."""
        versions = self._protocols.get(name, {})  # type: ignore[attr-defined]
        if not versions:
            return None
        return max(versions, key=self._version_key)

    def register(
        self,
        protocol: ProtocolDefinition,
        *,
        set_default: bool = True,
        overwrite: bool = False,
    ) -> None:
        """Register a protocol definition.

        Args:
            protocol: Protocol definition to register
            set_default: If True, set as default version
            overwrite: If True, allow overwriting existing registration

        Raises:
            ValueError: If protocol already registered and overwrite=False

        Example:
            >>> registry.register(uart_protocol)
        """
        if protocol.name not in self._protocols:  # type: ignore[attr-defined]
            self._protocols[protocol.name] = {}  # type: ignore[attr-defined]

        if protocol.version in self._protocols[protocol.name] and not overwrite:  # type: ignore[attr-defined]
            raise ValueError(f"Protocol '{protocol.name}' v{protocol.version} already registered")

        self._protocols[protocol.name][protocol.version] = protocol  # type: ignore[attr-defined]

        if set_default:
            self._default_versions[protocol.name] = protocol.version  # type: ignore[attr-defined]

        logger.debug(f"Registered protocol: {protocol.name} v{protocol.version}")

    def get(self, name: str, version: str | None = None) -> ProtocolDefinition:
        """Get protocol by name and optional version.

        Args:
            name: Protocol name
            version: Specific version, or None for the default version
                (falling back to the numerically newest registered version)

        Returns:
            Protocol definition

        Raises:
            KeyError: If protocol not found

        Example:
            >>> uart = registry.get("uart")
            >>> i2c = registry.get("i2c", version="2.1.0")
        """
        if name not in self._protocols:  # type: ignore[attr-defined]
            raise KeyError(
                f"Protocol '{name}' not found. Available: {list(self._protocols.keys())}"  # type: ignore[attr-defined]
            )

        if version is None:
            version = self._default_versions.get(name)  # type: ignore[attr-defined]
            if version is None:
                version = self._latest_version(name)

        if version is None or version not in self._protocols[name]:  # type: ignore[attr-defined]
            raise KeyError(
                f"Protocol '{name}' version '{version}' not found. "
                f"Available versions: {list(self._protocols[name].keys())}"  # type: ignore[attr-defined]
            )

        return self._protocols[name][version]  # type: ignore[no-any-return, attr-defined]

    def list(self) -> list[ProtocolDefinition]:
        """List all available protocols (default versions).

        Protocols without a usable default contribute their numerically
        newest version; protocols with no versions are skipped.

        Returns:
            Protocol definitions sorted by protocol name

        Example:
            >>> for proto in registry.list():
            ...     print(f"{proto.name} v{proto.version}: {proto.description}")
        """
        protocols = []
        for name in sorted(self._protocols.keys()):  # type: ignore[attr-defined]
            versions = self._protocols[name]  # type: ignore[attr-defined]
            version = self._default_versions.get(name)  # type: ignore[attr-defined]
            if version and version in versions:
                protocols.append(versions[version])
            elif versions:
                protocols.append(versions[self._latest_version(name)])
        return protocols

    def get_capabilities(self, name: str) -> ProtocolCapabilities:
        """Query protocol capabilities.

        Args:
            name: Protocol name

        Returns:
            Protocol capabilities (``bit_widths`` is an independent copy)

        Raises:
            KeyError: If protocol not found

        Example:
            >>> caps = registry.get_capabilities("uart")
            >>> print(f"Sample rate: {caps.sample_rate_min}-{caps.sample_rate_max} Hz")
        """
        protocol = self.get(name)
        return ProtocolCapabilities(
            supports_digital=protocol.supports_digital,
            supports_analog=protocol.supports_analog,
            sample_rate_min=protocol.sample_rate_min,
            sample_rate_max=protocol.sample_rate_max,
            # Copy so mutating the capabilities never touches the definition.
            bit_widths=list(protocol.bit_widths),
        )

    def filter(
        self,
        supports_digital: bool | None = None,
        supports_analog: bool | None = None,
        sample_rate_min__gte: float | None = None,
        sample_rate_max__lte: float | None = None,
    ) -> list[ProtocolDefinition]:  # type: ignore[valid-type]
        """Filter protocols (default versions) by capabilities.

        Args:
            supports_digital: Filter by digital support
            supports_analog: Filter by analog support
            sample_rate_min__gte: Keep protocols whose minimum sample rate >= value
            sample_rate_max__lte: Keep protocols whose maximum sample rate <= value;
                protocols without an upper limit (None) always pass

        Returns:
            List of matching protocols

        Example:
            >>> digital = registry.filter(supports_digital=True)
            >>> high_speed = registry.filter(sample_rate_min__gte=1_000_000)
        """
        results = []
        for protocol in self.list():
            if supports_digital is not None and protocol.supports_digital != supports_digital:
                continue
            if supports_analog is not None and protocol.supports_analog != supports_analog:
                continue
            if sample_rate_min__gte is not None and protocol.sample_rate_min < sample_rate_min__gte:
                continue
            # Compare against None explicitly: a limit of 0.0 is still a limit.
            if (
                sample_rate_max__lte is not None
                and protocol.sample_rate_max is not None
                and protocol.sample_rate_max > sample_rate_max__lte
            ):
                continue
            results.append(protocol)
        return results

    def has_protocol(self, name: str, version: str | None = None) -> bool:
        """Check if protocol is registered.

        Args:
            name: Protocol name
            version: Specific version or None for any

        Returns:
            True if registered
        """
        if name not in self._protocols:  # type: ignore[attr-defined]
            return False
        if version is None:
            return True
        return version in self._protocols[name]  # type: ignore[attr-defined]

    def list_versions(self, name: str) -> list[str]:  # type: ignore[valid-type]
        """List all versions of a protocol.

        Args:
            name: Protocol name

        Returns:
            Version strings ordered oldest to newest (numeric per component);
            an empty list for an unknown protocol
        """
        if name not in self._protocols:  # type: ignore[attr-defined]
            return []
        return sorted(self._protocols[name], key=self._version_key)  # type: ignore[attr-defined]

    def on_change(self, callback: Callable[[ProtocolDefinition], None]) -> None:
        """Register callback for protocol changes (hot reload support).

        Args:
            callback: Function to call when protocol is reloaded

        Example:
            >>> registry.on_change(lambda proto: print(f"Reloaded {proto.name}"))
        """
        self._watchers.append(callback)  # type: ignore[attr-defined]

    def _notify_change(self, protocol: ProtocolDefinition) -> None:
        """Notify watchers of protocol change; a failing callback is logged, not raised."""
        for callback in self._watchers:  # type: ignore[attr-defined]
            try:
                callback(protocol)
            except Exception as e:
                logger.warning(f"Protocol change callback failed: {e}")
def load_protocol(path: str | Path, validate: bool = True) -> ProtocolDefinition:
    """Load protocol definition from YAML or JSON file.

    Files ending in ``.yaml``/``.yml`` (case-insensitive) are parsed with
    ``yaml.safe_load``; any other suffix is parsed as JSON. A top-level
    ``protocol`` mapping, if present, is unwrapped before validation.

    Args:
        path: Path to protocol definition file
        validate: If True, validate against the "protocol" schema

    Returns:
        Loaded protocol definition. Sections given as null in the file
        (``timing``, ``voltage_levels``, ``state_machine``, ``metadata``)
        become empty dicts. ``schema_version`` is taken from the file when
        present.

    Raises:
        ConfigurationError: If the file is missing, cannot be read or parsed,
            does not contain a mapping, or fails schema validation

    Example:
        >>> protocol = load_protocol("configs/uart.yaml")
        >>> protocol = load_protocol("configs/i2c.json")
    """
    import json

    path = Path(path)

    if not path.exists():
        raise ConfigurationError(
            f"Protocol definition file not found: {path.name}", details=f"File path: {path}"
        )

    try:
        content = path.read_text(encoding="utf-8")
        if path.suffix.lower() in (".yaml", ".yml"):
            data = yaml.safe_load(content)
        else:
            data = json.loads(content)
    except yaml.YAMLError as e:
        raise ConfigurationError(
            f"YAML parse error in {path.name}", details=f"File: {path}\nError: {e}"
        ) from e
    except Exception as e:
        raise ConfigurationError(
            f"Failed to load protocol file: {path.name}", details=f"File: {path}\nError: {e}"
        ) from e

    # Handle nested 'protocol' key
    if isinstance(data, dict) and "protocol" in data:
        data = data["protocol"]

    # An empty YAML file yields None, and a list/scalar document is not a
    # definition. Report these as configuration errors rather than crashing
    # on data.get() below.
    if not isinstance(data, dict):
        raise ConfigurationError(
            f"Protocol definition in {path.name} must be a mapping",
            details=f"File: {path}\nGot: {type(data).__name__}",
        )

    if validate:
        try:
            validate_against_schema(data, "protocol")
        except Exception as e:
            raise ConfigurationError(
                f"Protocol validation failed for {path.name}",
                details=f"File: {path}\nError: {e}",
            ) from e

    protocol = ProtocolDefinition(
        name=data.get("name", path.stem),
        version=data.get("version", "1.0.0"),
        description=data.get("description", ""),
        author=data.get("author", ""),
        # "or {}" also covers keys present with a null value (e.g. "timing:").
        timing=data.get("timing") or {},
        voltage_levels=data.get("voltage_levels") or {},
        state_machine=data.get("state_machine") or {},
        extends=data.get("extends"),
        metadata=data.get("metadata") or {},
        source_file=str(path),
        schema_version=data.get("schema_version", "1.0.0"),
    )

    logger.info(f"Loaded protocol: {protocol.name} v{protocol.version} from {path}")
    return protocol
def resolve_inheritance(
    protocol: ProtocolDefinition,
    registry: ProtocolRegistry,
    *,
    max_depth: int = 5,
    deep_merge: bool = False,
    _visited: set[str] | list[str] | None = None,
) -> ProtocolDefinition:
    """Resolve protocol inheritance chain with circular detection.

    Supports multi-level inheritance (up to ``max_depth`` levels, 5 by
    default) with both shallow and deep merge strategies for nested
    properties. Parents are looked up by name via ``registry.get`` (their
    default version).

    Args:
        protocol: Protocol with potential inheritance
        registry: Registry to look up parent protocols
        max_depth: Maximum inheritance depth (default 5)
        deep_merge: If True, recursively merge nested dicts; else shallow merge
        _visited: Internal. Names already on the inheritance chain, in order.
            Used for cycle detection and error messages; it is copied, never
            mutated.

    Returns:
        Protocol with inherited properties merged and ``extends`` cleared.
        With the shallow merge, nested values not overridden by the child
        are the parent's objects (not copies).

    Raises:
        ConfigurationError: If circular inheritance is detected, the depth is
            exceeded, or a parent protocol is missing

    Example:
        >>> resolved = resolve_inheritance(spi_variant, registry)
        >>> resolved_deep = resolve_inheritance(spi_variant, registry, deep_merge=True)
    """
    # Ordered chain of names from the original child up to this protocol. A
    # list (not a set) keeps error messages in real inheritance order.
    chain: list[str] = list(_visited) if _visited is not None else []

    if not protocol.extends:
        return protocol

    # Cycle detection: this protocol is already on the chain.
    if protocol.name in chain:
        cycle_list = [*chain, protocol.name]
        cycle = " → ".join(cycle_list)
        raise ConfigurationError(
            f"Circular inheritance detected: {cycle}",
            details=f"Protocol inheritance forms a cycle. Remove 'extends' from one of: {', '.join(cycle_list)}",
            fix_hint=f"Break the cycle by removing the 'extends' field from {protocol.name}",
        )

    # Depth limit check
    if len(chain) >= max_depth:
        chain_text = " → ".join([*chain, protocol.name])
        raise ConfigurationError(
            f"Inheritance depth exceeded maximum of {max_depth}",
            details=f"Current chain: {chain_text}",
            fix_hint="Flatten the inheritance hierarchy or increase max_depth",
        )

    chain.append(protocol.name)

    # Get parent protocol
    try:
        parent = registry.get(protocol.extends)
    except KeyError as e:
        available = ", ".join(registry._protocols.keys())  # type: ignore[attr-defined]
        raise ConfigurationError(
            f"Parent protocol '{protocol.extends}' not found",
            details=f"Protocol '{protocol.name}' extends missing parent. Available: {available}",
            fix_hint=f"Add protocol '{protocol.extends}' to registry or fix 'extends' field",
        ) from e

    # Recursively resolve parent
    resolved_parent = resolve_inheritance(
        parent, registry, max_depth=max_depth, deep_merge=deep_merge, _visited=chain
    )

    # Merge properties (child overrides parent)
    if deep_merge:
        merged_timing = _deep_merge_dicts(resolved_parent.timing, protocol.timing)
        merged_voltage = _deep_merge_dicts(resolved_parent.voltage_levels, protocol.voltage_levels)
        merged_state = _deep_merge_dicts(resolved_parent.state_machine, protocol.state_machine)
        merged_metadata = _deep_merge_dicts(resolved_parent.metadata, protocol.metadata)
    else:
        # Shallow merge (default): child top-level keys replace parent keys.
        merged_timing = {**resolved_parent.timing, **protocol.timing}
        merged_voltage = {**resolved_parent.voltage_levels, **protocol.voltage_levels}
        merged_state = {**resolved_parent.state_machine, **protocol.state_machine}
        merged_metadata = {**resolved_parent.metadata, **protocol.metadata}

    return ProtocolDefinition(
        name=protocol.name,
        version=protocol.version,
        description=protocol.description or resolved_parent.description,
        author=protocol.author or resolved_parent.author,
        timing=merged_timing,
        voltage_levels=merged_voltage,
        state_machine=merged_state,
        extends=None,  # Clear extends after resolution
        metadata=merged_metadata,
        source_file=protocol.source_file,
        schema_version=protocol.schema_version,
    )
526def _deep_merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
527 """Deep merge two dictionaries recursively.
529 Args:
530 base: Base dictionary
531 override: Override dictionary (takes precedence)
533 Returns:
534 Merged dictionary
536 Example:
537 >>> base = {"a": {"b": 1, "c": 2}}
538 >>> override = {"a": {"c": 3, "d": 4}}
539 >>> _deep_merge_dicts(base, override)
540 {'a': {'b': 1, 'c': 3, 'd': 4}}
541 """
542 result = base.copy()
543 for key, value in override.items():
544 if key in result and isinstance(result[key], dict) and isinstance(value, dict):
545 result[key] = _deep_merge_dicts(result[key], value)
546 else:
547 result[key] = value
548 return result
class ProtocolWatcher:
    """File watcher for hot-reloading protocol definitions.

    Polls a directory tree from a background daemon thread. When a file seen
    in an earlier scan changes its modification time, the file is reloaded.
    Files that first appear after the initial scan are only recorded; they
    are reloaded once they change again. With the default 1 s poll interval,
    changes are picked up within about one interval.

    Example:
        >>> watcher = ProtocolWatcher("configs/")
        >>> watcher.on_change(lambda proto: print(f"Reloaded {proto.name}"))
        >>> watcher.start()
        >>> # ... later ...
        >>> watcher.stop()
    """

    # Glob patterns watched by default (matched recursively under the
    # directory); both YAML suffixes that load_protocol parses as YAML.
    DEFAULT_PATTERNS: tuple[str, ...] = ("*.yaml", "*.yml")

    def __init__(
        self,
        directory: str | Path,
        *,
        poll_interval: float = 1.0,
        registry: ProtocolRegistry | None = None,
        patterns: tuple[str, ...] | None = None,
    ):
        """Initialize watcher for directory.

        Args:
            directory: Directory to watch for protocol files
            poll_interval: Polling interval in seconds (default 1.0 for <2s latency)
            registry: Registry to auto-register reloaded protocols
            patterns: Glob patterns of files to watch, matched recursively
                (default ``DEFAULT_PATTERNS``); e.g. add ``"*.json"`` to also
                watch JSON definitions
        """
        self.directory = Path(directory)
        self.poll_interval = poll_interval
        self.registry = registry
        self.patterns: tuple[str, ...] = (
            tuple(patterns) if patterns is not None else self.DEFAULT_PATTERNS
        )
        self._callbacks: list[Callable[[ProtocolDefinition], None]] = []
        self._running = False
        self._thread: threading.Thread | None = None
        self._file_mtimes: dict[str, float] = {}
        # Each start() creates a fresh event, and stop() sets it. The polling
        # thread waits on it instead of sleeping, so it exits promptly even
        # when poll_interval is longer than stop()'s join timeout.
        self._stop_event = threading.Event()

    def on_change(self, callback: Callable[[ProtocolDefinition], None]) -> None:
        """Register callback for protocol changes.

        Args:
            callback: Function to call with reloaded protocol
        """
        self._callbacks.append(callback)

    def start(self) -> None:
        """Start watching for file changes in a background daemon thread.

        Records the current modification times first, then polls every
        ``poll_interval`` seconds. Calling start() while running only logs
        a warning.
        """
        if self._running:
            logger.warning("Protocol watcher already running")
            return

        self._running = True
        stop_event = threading.Event()
        self._stop_event = stop_event
        self._scan_files()

        # Start background polling thread
        self._thread = threading.Thread(target=self._watch_loop, args=(stop_event,), daemon=True)
        self._thread.start()

        logger.info(
            f"Started watching protocols in {self.directory} (poll interval: {self.poll_interval}s)"
        )

    def stop(self) -> None:
        """Stop watching for file changes and wait (up to 2 s) for the thread to exit."""
        self._running = False
        self._stop_event.set()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=2.0)
        logger.info("Stopped protocol watcher")

    def _watch_loop(self, stop_event: threading.Event | None = None) -> None:
        """Background polling loop; ends when the watcher stops or ``stop_event`` is set."""
        event = stop_event if stop_event is not None else self._stop_event
        while self._running and not event.is_set():
            try:
                self.check_changes()
            except Exception as e:
                logger.error(f"Error in protocol watcher: {e}")
            # Returns early as soon as stop() sets the event.
            event.wait(self.poll_interval)

    def check_changes(self) -> list[ProtocolDefinition]:
        """Check for changed files and reload.

        Returns:
            List of reloaded protocols (empty when the watcher is not running)
        """
        if not self._running:
            return []

        reloaded = []
        for file_path in self._iter_protocol_files():
            try:
                mtime = os.path.getmtime(file_path)  # noqa: PTH204
            except OSError:
                continue

            str_path = str(file_path)
            previous = self._file_mtimes.get(str_path)

            # Any mtime change counts, not only a newer one. Files restored
            # from a backup or checked out by VCS can carry an older timestamp.
            if previous is not None and mtime != previous:
                try:
                    protocol = load_protocol(file_path)
                    reloaded.append(protocol)

                    # Auto-register if registry provided
                    if self.registry is not None:
                        self.registry.register(protocol, overwrite=True)
                        self.registry._notify_change(protocol)

                    self._notify(protocol)
                    logger.info(f"Hot-reloaded protocol: {protocol.name} from {file_path}")
                except Exception as e:
                    logger.warning(f"Failed to reload {file_path}: {e}")

            self._file_mtimes[str_path] = mtime

        return reloaded

    def _iter_protocol_files(self) -> list[Path]:
        """Return watched regular files under ``directory`` (recursive), sorted, each once."""
        found: set[Path] = set()
        for pattern in self.patterns:
            found.update(p for p in self.directory.glob(f"**/{pattern}") if p.is_file())
        return sorted(found)

    def _scan_files(self) -> None:
        """Initial scan: record the modification time of every watched file."""
        for file_path in self._iter_protocol_files():
            with contextlib.suppress(OSError):
                self._file_mtimes[str(file_path)] = os.path.getmtime(file_path)  # noqa: PTH204

    def _notify(self, protocol: ProtocolDefinition) -> None:
        """Notify callbacks of protocol change; a failing callback is logged, not raised."""
        for callback in self._callbacks:
            try:
                callback(protocol)
            except Exception as e:
                logger.warning(f"Protocol change callback failed: {e}")
# Global registry instance, created lazily by get_protocol_registry().
_registry: ProtocolRegistry | None = None


def get_protocol_registry() -> ProtocolRegistry:
    """Return the process-wide protocol registry.

    The first call creates the registry and registers the built-in protocol
    definitions into it. Later calls return the cached instance.

    Returns:
        Global ProtocolRegistry instance
    """
    global _registry
    if _registry is not None:
        return _registry
    _registry = ProtocolRegistry()
    _register_builtin_protocols(_registry)
    return _registry
def _register_builtin_protocols(registry: ProtocolRegistry) -> None:
    """Register the built-in UART, SPI, I2C and CAN protocol definitions.

    ``ProtocolRegistry`` is a singleton, so user code may already have
    registered protocols with these names before this runs. Built-ins
    therefore never clobber existing state:

    - an exact name/version pair that is already registered is skipped
      (instead of raising ``ValueError``);
    - a built-in only becomes the default version when no version of that
      protocol name was registered before.

    Args:
        registry: Registry to populate
    """
    builtins = [
        # UART
        ProtocolDefinition(
            name="uart",
            version="1.0.0",
            description="Universal Asynchronous Receiver/Transmitter",
            timing={
                "baud_rates": [
                    9600,
                    19200,
                    38400,
                    57600,
                    115200,
                    230400,
                    460800,
                    921600,
                ],
                "data_bits": [7, 8],
                "stop_bits": [1, 1.5, 2],
                "parity": ["none", "even", "odd", "mark", "space"],
            },
            voltage_levels={"logic_family": "TTL", "idle_state": "high"},
            state_machine={
                "states": ["IDLE", "START", "DATA", "PARITY", "STOP"],
                "initial_state": "IDLE",
            },
        ),
        # SPI
        ProtocolDefinition(
            name="spi",
            version="1.0.0",
            description="Serial Peripheral Interface",
            timing={
                "data_bits": [8, 16, 32],
                "clock_polarity": [0, 1],
                "clock_phase": [0, 1],
            },
            state_machine={"states": ["IDLE", "ACTIVE"], "initial_state": "IDLE"},
        ),
        # I2C
        ProtocolDefinition(
            name="i2c",
            version="1.0.0",
            description="Inter-Integrated Circuit",
            timing={
                "speed_modes": ["standard", "fast", "fast_plus", "high_speed"],
                "data_bits": [8],
            },
            state_machine={
                "states": ["IDLE", "START", "ADDRESS", "DATA", "ACK", "STOP"],
                "initial_state": "IDLE",
            },
        ),
        # CAN
        ProtocolDefinition(
            name="can",
            version="1.0.0",
            description="Controller Area Network",
            timing={"baud_rates": [125000, 250000, 500000, 1000000]},
            state_machine={
                "states": [
                    "IDLE",
                    "SOF",
                    "ARBITRATION",
                    "CONTROL",
                    "DATA",
                    "CRC",
                    "ACK",
                    "EOF",
                ],
                "initial_state": "IDLE",
            },
        ),
    ]

    for protocol in builtins:
        if registry.has_protocol(protocol.name, protocol.version):
            logger.debug(
                f"Skipping built-in protocol {protocol.name} v{protocol.version}: already registered"
            )
            continue
        # Keep any default version the user chose for this protocol name.
        registry.register(protocol, set_default=not registry.has_protocol(protocol.name))
def migrate_protocol_schema(
    protocol_data: dict[str, Any], from_version: str, to_version: str = "1.0.0"
) -> dict[str, Any]:
    """Migrate protocol definition between schema versions.

    Migration steps run on a deep copy, so ``protocol_data`` (including its
    nested sections such as ``timing``) is never modified. When
    ``from_version == to_version`` the input is returned as-is (same object).

    Args:
        protocol_data: Protocol data dictionary
        from_version: Source schema version
        to_version: Target schema version (default current)

    Returns:
        Migrated protocol data with ``schema_version`` set to ``to_version``

    Raises:
        ConfigurationError: If migration fails or unsupported version

    Example:
        >>> old_proto = {"name": "uart", "timing": {...}}
        >>> new_proto = migrate_protocol_schema(old_proto, "0.9.0", "1.0.0")
    """
    import copy

    if from_version == to_version:
        return protocol_data

    # Define migration paths
    migrations = {
        ("0.9.0", "1.0.0"): _migrate_0_9_to_1_0,
        ("0.8.0", "0.9.0"): _migrate_0_8_to_0_9,
        ("0.8.0", "1.0.0"): lambda d: _migrate_0_9_to_1_0(_migrate_0_8_to_0_9(d)),
    }

    migration_key = (from_version, to_version)
    if migration_key not in migrations:
        raise ConfigurationError(
            f"No migration path from schema {from_version} to {to_version}",
            details="Supported migrations: " + ", ".join(f"{k[0]}→{k[1]}" for k in migrations),
            fix_hint="Manually update the protocol definition or use an intermediate version",
        )

    logger.info(f"Migrating protocol schema from {from_version} to {to_version}")
    try:
        # Deep copy: steps rewrite nested sections (e.g. timing.baudrate), and
        # a shallow copy would leak those edits into the caller's dict.
        working = copy.deepcopy(protocol_data)
        migrated = migrations[migration_key](working)  # type: ignore[no-untyped-call]
        migrated["schema_version"] = to_version
        return migrated
    except Exception as e:
        raise ConfigurationError(
            f"Schema migration failed from {from_version} to {to_version}",
            details=str(e),
            fix_hint="Check migration logs and manually update protocol definition",
        ) from e
843def _migrate_0_8_to_0_9(data: dict[str, Any]) -> dict[str, Any]:
844 """Migrate from schema 0.8.0 to 0.9.0."""
845 # Example migration: rename 'baudrate' to 'baud_rates' and convert to list
846 if "baudrate" in data.get("timing", {}): 846 ↛ 849line 846 didn't jump to line 849 because the condition on line 846 was always true
847 data.setdefault("timing", {})
848 data["timing"]["baud_rates"] = [data["timing"].pop("baudrate")]
849 return data
852def _migrate_0_9_to_1_0(data: dict[str, Any]) -> dict[str, Any]:
853 """Migrate from schema 0.9.0 to 1.0.0."""
854 # Example migration: add required fields with defaults
855 data.setdefault("version", "1.0.0")
856 data.setdefault("description", "")
857 data.setdefault("author", "")
859 # Convert old state format if needed
860 if "state" in data:
861 data["state_machine"] = data.pop("state")
863 return data
# Public API of this module: the names exported by ``from ... import *``.
# Underscore-prefixed helpers (e.g. _deep_merge_dicts, _migrate_*) are internal.
__all__ = [
    "ProtocolCapabilities",
    "ProtocolDefinition",
    "ProtocolRegistry",
    "ProtocolWatcher",
    "get_protocol_registry",
    "load_protocol",
    "migrate_protocol_schema",
    "resolve_inheritance",
]