Coverage for src / tracekit / inference / protocol_dsl.py: 93%
417 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Protocol Definition Language parser and decoder generator.
3Requirements addressed: PSI-004
5This module provides a declarative DSL for defining custom protocol formats
6that can be used to generate decoders and encoders automatically.
8Key capabilities:
9- Parse YAML-based protocol definitions
10- Support all common field types
11- Conditional fields and length-prefixed data
12- Generate efficient decoders and encoders
13- Comprehensive error reporting
14"""
16import ast
17import operator
18import struct
19from collections.abc import Iterator
20from dataclasses import dataclass, field
21from pathlib import Path
22from typing import Any, Literal
24import yaml
27@dataclass
28class FieldDefinition:
29 """Protocol field definition.
31 : Field specification.
33 Attributes:
34 name: Field name
35 field_type: Field type (uint8, uint16, int32, float32, bytes, string, bitfield, array, struct)
36 Also accessible as 'type' for compatibility.
37 size: Field size (literal or reference to length field)
38 offset: Field offset (optional, auto-calculated if not provided)
39 endian: Byte order ('big' or 'little')
40 condition: Conditional expression for optional fields
41 enum: Enumeration mapping for integer fields
42 validation: Validation rules
43 default: Default value
44 description: Human-readable description
45 value: Expected value for constant fields
46 size_ref: Reference to length field (alias for size when string)
47 element: Element definition for array types (contains type and optionally fields for struct)
48 count_field: Field name that contains the array count
49 count: Fixed array count
50 fields: List of nested field definitions for struct types
51 """
53 name: str
54 field_type: str = (
55 "uint8" # uint8, uint16, int32, float32, bytes, string, bitfield, array, struct
56 )
57 size: int | str | None = None # Can be literal or reference to length field
58 offset: int | None = None
59 endian: Literal["big", "little"] = "big"
60 condition: str | None = None # Conditional field
61 enum: dict[int, str] | dict[str, Any] | None = None
62 validation: dict[str, Any] | None = None
63 default: Any = None
64 description: str = ""
65 value: Any = None # Expected constant value
66 size_ref: str | None = None # Alias for size reference
67 # Array/struct specific fields
68 element: dict[str, Any] | None = None # Element definition for arrays
69 count_field: str | None = None # Field containing array count
70 count: int | None = None # Fixed array count
71 fields: list["FieldDefinition"] | None = None # Nested fields for struct type
73 def __post_init__(self) -> None:
74 """Handle size_ref as alias for size."""
75 if self.size_ref is not None and self.size is None:
76 self.size = self.size_ref
78 @property
79 def type(self) -> str:
80 """Alias for field_type for backward compatibility."""
81 return self.field_type
83 @type.setter
84 def type(self, value: str) -> None:
85 """Set field_type via type property."""
86 self.field_type = value
89@dataclass
90class ProtocolDefinition:
91 """Complete protocol definition.
93 : Complete protocol specification.
95 Attributes:
96 name: Protocol name
97 version: Protocol version
98 description: Protocol description
99 settings: Global settings (endianness, etc.)
100 framing: Framing/sync configuration
101 fields: List of field definitions
102 computed_fields: Computed/derived fields
103 decoding: Decoding settings
104 encoding: Encoding settings
105 endian: Default endianness for all fields
106 """
108 name: str
109 description: str = ""
110 version: str = "1.0"
111 endian: Literal["big", "little"] = "big"
112 fields: list[FieldDefinition] = field(default_factory=list)
113 settings: dict[str, Any] = field(default_factory=dict)
114 framing: dict[str, Any] = field(default_factory=dict)
115 computed_fields: list[dict[str, Any]] = field(default_factory=list)
116 decoding: dict[str, Any] = field(default_factory=dict)
117 encoding: dict[str, Any] = field(default_factory=dict)
119 @classmethod
120 def from_yaml(cls, path: str | Path) -> "ProtocolDefinition":
121 """Load protocol definition from YAML file.
123 : YAML parsing.
125 Args:
126 path: Path to YAML file
128 Returns:
129 ProtocolDefinition instance
130 """
131 with open(path) as f:
132 config = yaml.safe_load(f)
134 return cls.from_dict(config)
136 @classmethod
137 def from_dict(cls, config: dict[str, Any]) -> "ProtocolDefinition":
138 """Create from dictionary.
140 : Configuration parsing.
142 Args:
143 config: Configuration dictionary
145 Returns:
146 ProtocolDefinition instance
147 """
148 # Parse field definitions
149 field_defs = []
150 default_endian = config.get("endian", "big")
151 for field_dict in config.get("fields", []):
152 field_def = cls._parse_field_definition(field_dict, default_endian)
153 field_defs.append(field_def)
155 return cls(
156 name=config.get("name", "unknown"),
157 version=config.get("version", "1.0"),
158 description=config.get("description", ""),
159 endian=default_endian,
160 settings=config.get("settings", {}),
161 framing=config.get("framing", {}),
162 fields=field_defs,
163 computed_fields=config.get("computed_fields", []),
164 decoding=config.get("decoding", {}),
165 encoding=config.get("encoding", {}),
166 )
168 @classmethod
169 def _parse_field_definition(
170 cls, field_dict: dict[str, Any], default_endian: str
171 ) -> "FieldDefinition":
172 """Parse a single field definition from dictionary.
174 Args:
175 field_dict: Field configuration dictionary
176 default_endian: Default endianness
178 Returns:
179 FieldDefinition instance
180 """
181 # Support both 'type' and 'field_type' attribute names
182 field_type = field_dict.get("type") or field_dict.get("field_type", "uint8")
184 # Parse nested fields for struct type
185 nested_fields: list[FieldDefinition] | None = None
186 if field_dict.get("fields"):
187 nested_fields = [
188 cls._parse_field_definition(f, default_endian) for f in field_dict["fields"]
189 ]
191 return FieldDefinition(
192 name=field_dict["name"],
193 field_type=field_type,
194 size=field_dict.get("size"),
195 offset=field_dict.get("offset"),
196 endian=field_dict.get("endian", default_endian),
197 condition=field_dict.get("condition"),
198 enum=field_dict.get("enum"),
199 validation=field_dict.get("validation"),
200 default=field_dict.get("default"),
201 description=field_dict.get("description", ""),
202 value=field_dict.get("value"),
203 size_ref=field_dict.get("size_ref"),
204 element=field_dict.get("element"),
205 count_field=field_dict.get("count_field"),
206 count=field_dict.get("count"),
207 fields=nested_fields,
208 )
@dataclass
class DecodedMessage:
    """Result of decoding one protocol message.

    Instances offer read-only, dictionary-like access to the decoded fields:
    membership tests, subscripting, iteration over field names, and the
    ``keys`` / ``values`` / ``items`` / ``get`` helpers all delegate to
    ``fields``.

    Attributes:
        fields: Mapping of field name to decoded value.
        raw_data: The bytes the message was decoded from.
        size: Message size in bytes.
        valid: Whether the message passed validation.
        errors: Validation / decoding error messages.
    """

    fields: dict[str, Any]
    raw_data: bytes
    size: int
    valid: bool
    errors: list[str]

    def __contains__(self, key: str) -> bool:
        """Return True when ``key`` names a decoded field."""
        decoded = self.fields
        return key in decoded

    def __getitem__(self, key: str) -> Any:
        """Return the value of field ``key`` (KeyError if absent)."""
        decoded = self.fields
        return decoded[key]

    def __iter__(self) -> Iterator[str]:
        """Yield the decoded field names."""
        decoded = self.fields
        return iter(decoded)

    def keys(self) -> Any:
        """Return a view of the field names."""
        decoded = self.fields
        return decoded.keys()

    def values(self) -> Any:
        """Return a view of the field values."""
        decoded = self.fields
        return decoded.values()

    def items(self) -> Any:
        """Return a view of ``(name, value)`` pairs."""
        decoded = self.fields
        return decoded.items()

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value of field ``key``, or ``default`` if absent."""
        decoded = self.fields
        return decoded.get(key, default)
263class _SafeConditionEvaluator(ast.NodeVisitor):
264 """Safe evaluator for protocol field conditions.
266 Only allows:
267 - Comparisons: ==, !=, <, <=, >, >=
268 - Logical operations: and, or, not
269 - Constants: numbers, strings, booleans
270 - Variable names from context
272 Security:
273 Uses AST parsing to safely evaluate conditions without eval().
274 """
276 def __init__(self, context: dict[str, Any]):
277 """Initialize with field context.
279 Args:
280 context: Dictionary of field names to values
281 """
282 self.context = context
283 self.compare_ops = {
284 ast.Eq: operator.eq,
285 ast.NotEq: operator.ne,
286 ast.Lt: operator.lt,
287 ast.LtE: operator.le,
288 ast.Gt: operator.gt,
289 ast.GtE: operator.ge,
290 }
292 def eval(self, expression: str) -> bool:
293 """Evaluate condition expression.
295 Args:
296 expression: Condition string
298 Returns:
299 Boolean result
300 """
301 try:
302 tree = ast.parse(expression, mode="eval")
303 result = self.visit(tree.body)
304 return bool(result)
305 except Exception:
306 # If evaluation fails, condition is false
307 return False
309 def visit_Compare(self, node: ast.Compare) -> Any:
310 """Visit comparison operation."""
311 left = self.visit(node.left)
312 for op, comparator in zip(node.ops, node.comparators, strict=True):
313 if type(op) not in self.compare_ops: 313 ↛ 314line 313 didn't jump to line 314 because the condition on line 313 was never true
314 return False
315 right = self.visit(comparator)
316 if not self.compare_ops[type(op)](left, right):
317 return False
318 left = right
319 return True
321 def visit_BoolOp(self, node: ast.BoolOp) -> Any:
322 """Visit boolean operation (and, or)."""
323 if isinstance(node.op, ast.And):
324 return all(self.visit(value) for value in node.values)
325 elif isinstance(node.op, ast.Or):
326 return any(self.visit(value) for value in node.values)
327 return False
329 def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
330 """Visit unary operation (not)."""
331 if isinstance(node.op, ast.Not):
332 return not self.visit(node.operand)
333 return False
335 def visit_Name(self, node: ast.Name) -> Any:
336 """Visit variable name."""
337 return self.context.get(node.id)
339 def visit_Constant(self, node: ast.Constant) -> Any:
340 """Visit constant value."""
341 return node.value
343 def visit_Num(self, node: ast.Num) -> Any:
344 """Visit number (Python <3.8 compatibility)."""
345 return node.n
347 def visit_Str(self, node: ast.Str) -> Any:
348 """Visit string (Python <3.8 compatibility)."""
349 return node.s
351 def visit_NameConstant(self, node: ast.NameConstant) -> Any:
352 """Visit named constant (Python <3.8 compatibility)."""
353 return node.value
355 def generic_visit(self, node: ast.AST) -> Any:
356 """Disallow other node types."""
357 return False
360class ProtocolDecoder:
361 """Decode binary data using protocol definition.
363 : Protocol decoder with full field type support.
364 """
366 def __init__(self, definition: ProtocolDefinition):
367 """Initialize decoder with protocol definition.
369 Args:
370 definition: Protocol definition
371 """
372 self.definition = definition
373 self._endian_map: dict[str, str] = {"big": ">", "little": "<"}
375 @classmethod
376 def load(cls, path: str | Path) -> "ProtocolDecoder":
377 """Load decoder from YAML protocol definition.
379 : Load decoder from file.
381 Args:
382 path: Path to YAML file
384 Returns:
385 ProtocolDecoder instance
386 """
387 definition = ProtocolDefinition.from_yaml(path)
388 return cls(definition)
390 def decode(self, data: bytes, offset: int = 0) -> DecodedMessage:
391 """Decode single message from binary data.
393 : Complete decoding with validation.
395 Args:
396 data: Binary data
397 offset: Starting offset in data
399 Returns:
400 DecodedMessage instance
401 """
402 fields: dict[str, Any] = {}
403 errors: list[str] = []
404 current_offset = offset
405 valid = True
407 # Check minimum length
408 if len(data) - offset < 1:
409 return DecodedMessage(
410 fields={}, raw_data=data[offset:], size=0, valid=False, errors=["Insufficient data"]
411 )
413 # Decode each field
414 for field_def in self.definition.fields:
415 # Check condition
416 if field_def.condition:
417 if not self._evaluate_condition(field_def.condition, fields):
418 continue # Skip this field
420 try:
421 value, bytes_consumed = self._decode_field(data[current_offset:], field_def, fields)
423 # Validate
424 if field_def.validation:
425 validation_error = self._validate_field(value, field_def.validation)
426 if validation_error:
427 errors.append(f"{field_def.name}: {validation_error}")
428 valid = False
430 # Store value
431 fields[field_def.name] = value
432 current_offset += bytes_consumed
434 except Exception as e:
435 errors.append(f"{field_def.name}: {e!s}")
436 valid = False
437 break
439 total_size = current_offset - offset
441 return DecodedMessage(
442 fields=fields,
443 raw_data=data[offset:current_offset],
444 size=total_size,
445 valid=valid and len(errors) == 0,
446 errors=errors,
447 )
449 def decode_stream(self, data: bytes) -> list[DecodedMessage]:
450 """Decode multiple messages from data stream.
452 : Stream decoding with sync detection.
454 Args:
455 data: Binary data stream
457 Returns:
458 List of DecodedMessage instances
459 """
460 messages = []
461 offset = 0
463 # Check for sync pattern
464 sync_pattern = self.definition.framing.get("sync_pattern")
466 while offset < len(data):
467 # Find sync if configured
468 if sync_pattern:
469 sync_offset = self.find_sync(data, offset)
470 if sync_offset is None: 470 ↛ 471line 470 didn't jump to line 471 because the condition on line 470 was never true
471 break # No more sync patterns
472 offset = sync_offset
474 # Decode message
475 msg = self.decode(data, offset)
477 if msg.valid: 477 ↛ 482line 477 didn't jump to line 482 because the condition on line 477 was always true
478 messages.append(msg)
479 offset += msg.size
480 else:
481 # Try to recover by finding next sync
482 if sync_pattern:
483 offset += 1
484 else:
485 break # Can't recover without sync
487 return messages
489 def find_sync(self, data: bytes, start: int = 0) -> int | None:
490 """Find sync pattern in data.
492 : Sync pattern detection.
494 Args:
495 data: Binary data
496 start: Starting offset
498 Returns:
499 Offset of sync pattern or None
500 """
501 sync_pattern = self.definition.framing.get("sync_pattern")
502 if not sync_pattern:
503 return start # No sync pattern, start from beginning
505 # Convert sync pattern (hex string or bytes)
506 if isinstance(sync_pattern, str):
507 if sync_pattern.startswith("0x"):
508 # Hex string like "0xAA55"
509 sync_bytes = bytes.fromhex(sync_pattern[2:])
510 else:
511 sync_bytes = sync_pattern.encode()
512 else:
513 sync_bytes = bytes(sync_pattern)
515 # Search for pattern
516 idx = data.find(sync_bytes, start)
517 if idx == -1:
518 return None
519 return idx
521 def _decode_field(
522 self, data: bytes, field: FieldDefinition, context: dict[str, Any]
523 ) -> tuple[Any, int]:
524 """Decode single field.
526 : Field decoding for all types.
528 Args:
529 data: Binary data
530 field: Field definition
531 context: Previously decoded fields
533 Returns:
534 Tuple of (value, bytes_consumed)
536 Raises:
537 ValueError: If bitfield size is unsupported or field type is unknown
538 """
539 endian = self._endian_map.get(field.endian, ">")
540 field_type = field.field_type
542 # Integer types
543 if field_type in ["uint8", "int8", "uint16", "int16", "uint32", "int32", "uint64", "int64"]:
544 return self._decode_integer(data, field_type, endian)
546 # Float types
547 elif field_type in ["float32", "float64"]:
548 return self._decode_float(data, field_type, endian)
550 # Bytes
551 elif field_type == "bytes":
552 size = self._resolve_size(field.size, context, data)
553 if size > len(data):
554 size = len(data) # Use remaining data
555 return bytes(data[:size]), size
557 # String
558 elif field_type == "string":
559 size = self._resolve_size(field.size, context, data)
560 if size > len(data): 560 ↛ 561line 560 didn't jump to line 561 because the condition on line 560 was never true
561 size = len(data) # Use remaining data
562 string_bytes = data[:size]
563 # Try to decode as UTF-8, fall back to latin-1
564 try:
565 value = string_bytes.decode("utf-8").rstrip("\x00")
566 except UnicodeDecodeError:
567 value = string_bytes.decode("latin-1").rstrip("\x00")
568 return value, size
570 # Bitfield
571 elif field_type == "bitfield":
572 # Decode as uint and extract bits
573 field_size = field.size if isinstance(field.size, int) else 1
574 if field_size == 1:
575 bitfield_value = int(data[0])
576 elif field_size == 2:
577 bitfield_value = struct.unpack(f"{endian}H", data[:2])[0]
578 elif field_size == 4:
579 bitfield_value = struct.unpack(f"{endian}I", data[:4])[0]
580 else:
581 raise ValueError(f"Unsupported bitfield size: {field_size}")
583 # Return as-is, caller can extract specific bits
584 return bitfield_value, field_size
586 # Array
587 elif field_type == "array":
588 return self._decode_array(data, field, context)
590 # Struct (nested)
591 elif field_type == "struct":
592 return self._decode_struct(data, field, context)
594 else:
595 raise ValueError(f"Unknown field type: {field_type}")
597 def _decode_array(
598 self, data: bytes, field: FieldDefinition, context: dict[str, Any]
599 ) -> tuple[list[Any], int]:
600 """Decode array field.
602 : Array field decoding.
604 Args:
605 data: Binary data
606 field: Field definition with element spec
607 context: Previously decoded fields
609 Returns:
610 Tuple of (list of values, bytes_consumed)
612 Raises:
613 ValueError: If array field is missing element definition
614 """
615 elements = []
616 total_consumed = 0
618 # Determine element count
619 count = None
620 if field.count is not None:
621 count = field.count
622 elif field.count_field is not None and field.count_field in context:
623 count = int(context[field.count_field])
625 # Get element definition
626 element_def = field.element
627 if element_def is None:
628 raise ValueError(f"Array field '{field.name}' missing element definition")
630 element_type = element_def.get("type", "uint8")
631 element_endian = element_def.get("endian", field.endian)
633 # If no count, try to decode until data exhausted
634 idx = 0
635 while len(data) - total_consumed > 0:
636 if count is not None and idx >= count: 636 ↛ 637line 636 didn't jump to line 637 because the condition on line 636 was never true
637 break
639 # Create a temporary field definition for the element
640 if element_type == "struct":
641 # Nested struct in array
642 nested_fields = element_def.get("fields", [])
643 parsed_fields = [
644 ProtocolDefinition._parse_field_definition(f, element_endian)
645 for f in nested_fields
646 ]
647 elem_field = FieldDefinition(
648 name=f"{field.name}[{idx}]",
649 field_type="struct",
650 endian=element_endian,
651 fields=parsed_fields,
652 )
653 value, consumed = self._decode_struct(data[total_consumed:], elem_field, context)
654 else:
655 # Simple element type
656 elem_field = FieldDefinition(
657 name=f"{field.name}[{idx}]",
658 field_type=element_type,
659 endian=element_endian,
660 size=element_def.get("size"),
661 )
662 value, consumed = self._decode_field(data[total_consumed:], elem_field, context)
664 if consumed == 0: 664 ↛ 665line 664 didn't jump to line 665 because the condition on line 664 was never true
665 break # Prevent infinite loop
667 elements.append(value)
668 total_consumed += consumed
669 idx += 1
671 return elements, total_consumed
673 def _decode_struct(
674 self, data: bytes, field: FieldDefinition, context: dict[str, Any]
675 ) -> tuple[dict[str, Any], int]:
676 """Decode struct field.
678 : Nested struct field decoding.
680 Args:
681 data: Binary data
682 field: Field definition with nested fields
683 context: Previously decoded fields
685 Returns:
686 Tuple of (dict of field values, bytes_consumed)
688 Raises:
689 ValueError: If struct field is missing fields definition
690 """
691 struct_fields: dict[str, Any] = {}
692 total_consumed = 0
694 # Get nested field definitions
695 nested_fields = field.fields
696 if nested_fields is None:
697 raise ValueError(f"Struct field '{field.name}' missing fields definition")
699 # Decode each nested field
700 for nested_field in nested_fields:
701 if len(data) - total_consumed < 1: 701 ↛ 702line 701 didn't jump to line 702 because the condition on line 701 was never true
702 break # Not enough data
704 # Check condition if present
705 if nested_field.condition:
706 # Use combined context (parent context + struct fields decoded so far)
707 combined_context = {**context, **struct_fields}
708 if not self._evaluate_condition(nested_field.condition, combined_context):
709 continue
711 value, consumed = self._decode_field(
712 data[total_consumed:], nested_field, {**context, **struct_fields}
713 )
714 struct_fields[nested_field.name] = value
715 total_consumed += consumed
717 return struct_fields, total_consumed
719 def _decode_integer(self, data: bytes, type_name: str, endian: str) -> tuple[int, int]:
720 """Decode integer field.
722 Args:
723 data: Binary data
724 type_name: Type name (uint8, int16, etc.)
725 endian: Endian marker
727 Returns:
728 Tuple of (value, bytes_consumed)
730 Raises:
731 ValueError: If insufficient data for the integer type
732 """
733 format_map = {
734 "uint8": ("B", 1),
735 "int8": ("b", 1),
736 "uint16": ("H", 2),
737 "int16": ("h", 2),
738 "uint32": ("I", 4),
739 "int32": ("i", 4),
740 "uint64": ("Q", 8),
741 "int64": ("q", 8),
742 }
744 fmt_char, size = format_map[type_name]
746 if len(data) < size:
747 raise ValueError(f"Insufficient data for {type_name} (need {size}, have {len(data)})")
749 # uint8/int8 don't use endianness
750 if size == 1:
751 value = struct.unpack(fmt_char, data[:size])[0]
752 else:
753 value = struct.unpack(f"{endian}{fmt_char}", data[:size])[0]
755 return value, size
757 def _decode_float(self, data: bytes, type_name: str, endian: str) -> tuple[float, int]:
758 """Decode float field.
760 Args:
761 data: Binary data
762 type_name: Type name (float32 or float64)
763 endian: Endian marker
765 Returns:
766 Tuple of (value, bytes_consumed)
768 Raises:
769 ValueError: If insufficient data for the float type
770 """
771 if type_name == "float32":
772 size = 4
773 fmt = f"{endian}f"
774 else: # float64
775 size = 8
776 fmt = f"{endian}d"
778 if len(data) < size:
779 raise ValueError(f"Insufficient data for {type_name} (need {size}, have {len(data)})")
781 value = struct.unpack(fmt, data[:size])[0]
782 return value, size
784 def _resolve_size(
785 self, size_spec: int | str | None, context: dict[str, Any], data: bytes
786 ) -> int:
787 """Resolve field size (literal or reference).
789 Args:
790 size_spec: Size specification (int, field name, or 'remaining')
791 context: Decoded fields
792 data: Current data buffer (for 'remaining' size)
794 Returns:
795 Resolved size
797 Raises:
798 ValueError: If size field not found in context or size specification is invalid
799 """
800 if size_spec is None: 800 ↛ 802line 800 didn't jump to line 802 because the condition on line 800 was never true
801 # No size specified, return 0 (caller should handle)
802 return 0
803 elif isinstance(size_spec, int):
804 return size_spec
805 elif isinstance(size_spec, str): 805 ↛ 815line 805 didn't jump to line 815 because the condition on line 805 was always true
806 # Special case: 'remaining' means use all remaining data
807 if size_spec == "remaining":
808 return len(data)
809 # Reference to another field
810 if size_spec in context:
811 return int(context[size_spec])
812 else:
813 raise ValueError(f"Size field '{size_spec}' not found in context")
814 else:
815 raise ValueError(f"Invalid size specification: {size_spec}")
817 def _evaluate_condition(self, condition: str, context: dict[str, Any]) -> bool:
818 """Evaluate field condition against decoded context.
820 : Conditional field evaluation.
822 Args:
823 condition: Condition expression (e.g., "msg_type == 0x02")
824 context: Decoded fields
826 Returns:
827 True if condition is satisfied
829 Security:
830 Uses AST-based safe evaluation. Only comparisons and logical
831 operations are permitted.
832 """
833 evaluator = _SafeConditionEvaluator(context)
834 return evaluator.eval(condition)
836 def _validate_field(self, value: Any, validation: dict[str, Any]) -> str | None:
837 """Validate field value.
839 Args:
840 value: Field value
841 validation: Validation rules
843 Returns:
844 Error message or None if valid
845 """
846 # Min/max validation
847 if "min" in validation:
848 if value < validation["min"]:
849 return f"Value {value} below minimum {validation['min']}"
851 if "max" in validation:
852 if value > validation["max"]:
853 return f"Value {value} above maximum {validation['max']}"
855 # Value validation
856 if "value" in validation:
857 if value != validation["value"]:
858 return f"Expected {validation['value']}, got {value}"
860 return None
class ProtocolEncoder:
    """Encode field values into binary messages using a protocol definition."""

    # ``struct`` format characters per DSL type name.
    _INTEGER_CODES = {
        "uint8": "B",
        "int8": "b",
        "uint16": "H",
        "int16": "h",
        "uint32": "I",
        "int32": "i",
        "uint64": "Q",
        "int64": "q",
    }
    _FLOAT_CODES = {"float32": "f", "float64": "d"}
    # Bitfields are unsigned integers of 1, 2 or 4 bytes, matching the decoder.
    _BITFIELD_CODES = {1: "B", 2: "H", 4: "I"}

    def __init__(self, definition: "ProtocolDefinition"):
        """Initialize encoder.

        Args:
            definition: Protocol definition.
        """
        self.definition = definition
        # Maps the DSL's endianness names onto ``struct`` byte-order prefixes.
        self._endian_map: dict[str, str] = {"big": ">", "little": "<"}

    def encode(self, fields: dict[str, Any]) -> bytes:
        """Encode field values to a binary message.

        Fields are emitted in definition order. A conditional field is
        skipped when its condition is false for the supplied values; a field
        that is not supplied falls back to its default.

        Args:
            fields: Dictionary of field name -> value.

        Returns:
            Encoded binary message.

        Raises:
            ValueError: If a required field is missing.
        """
        result = bytearray()

        for field_def in self.definition.fields:
            if field_def.condition:
                # AST-based evaluation; eval() is never used.
                evaluator = _SafeConditionEvaluator(fields)
                if not evaluator.eval(field_def.condition):
                    continue

            if field_def.name in fields:
                value = fields[field_def.name]
            elif field_def.default is not None:
                value = field_def.default
            else:
                raise ValueError(f"Missing required field: {field_def.name}")

            result.extend(self._encode_field(value, field_def))

        return bytes(result)

    def _encode_field(self, value: Any, field: "FieldDefinition") -> bytes:
        """Encode a single field value.

        Args:
            value: Field value.
            field: Field definition.

        Returns:
            Encoded bytes.

        Raises:
            ValueError: If a bytes value is invalid, a bitfield size is
                unsupported, or the field type is unknown for encoding.
        """
        endian = self._endian_map.get(field.endian, ">")
        field_type = field.field_type

        if field_type in self._INTEGER_CODES:
            # The byte-order prefix has no effect on single-byte types.
            return struct.pack(f"{endian}{self._INTEGER_CODES[field_type]}", int(value))

        if field_type in self._FLOAT_CODES:
            return struct.pack(f"{endian}{self._FLOAT_CODES[field_type]}", float(value))

        if field_type == "bytes":
            if isinstance(value, bytes):
                return value
            if isinstance(value, (bytearray, memoryview, list, tuple)):
                return bytes(value)
            raise ValueError(f"Invalid bytes value: {value}")

        if field_type == "string":
            if isinstance(value, str):
                return value.encode("utf-8")
            return bytes(value)

        if field_type == "bitfield":
            # Mirrors the decoder: size defaults to 1 byte when not an int.
            field_size = field.size if isinstance(field.size, int) else 1
            code = self._BITFIELD_CODES.get(field_size)
            if code is None:
                raise ValueError(f"Unsupported bitfield size: {field_size}")
            return struct.pack(f"{endian}{code}", int(value))

        if field_type == "array":
            return self._encode_array(value, field)

        if field_type == "struct":
            return self._encode_struct(value, field)

        raise ValueError(f"Unknown field type for encoding: {field_type}")

    def _encode_array(self, value: list[Any], field: "FieldDefinition") -> bytes:
        """Encode an array field by encoding each element in order.

        Args:
            value: List of values.
            field: Field definition.

        Returns:
            Encoded bytes.

        Raises:
            ValueError: If the array field is missing its element definition.
        """
        result = bytearray()

        element_def = field.element
        if element_def is None:
            raise ValueError(f"Array field '{field.name}' missing element definition")

        element_type = element_def.get("type", "uint8")
        element_endian = element_def.get("endian", field.endian)

        # Nested struct definitions are identical for every element, so they
        # are parsed once, on first use.
        struct_fields: list[FieldDefinition] | None = None

        for i, elem in enumerate(value):
            if element_type == "struct":
                if struct_fields is None:
                    struct_fields = [
                        ProtocolDefinition._parse_field_definition(f, element_endian)
                        for f in element_def.get("fields", [])
                    ]
                elem_field = FieldDefinition(
                    name=f"{field.name}[{i}]",
                    field_type="struct",
                    endian=element_endian,
                    fields=struct_fields,
                )
                result.extend(self._encode_struct(elem, elem_field))
            else:
                elem_field = FieldDefinition(
                    name=f"{field.name}[{i}]",
                    field_type=element_type,
                    endian=element_endian,
                    # Needed for sized element types such as bitfield.
                    size=element_def.get("size"),
                )
                result.extend(self._encode_field(elem, elem_field))

        return bytes(result)

    def _encode_struct(self, value: dict[str, Any], field: "FieldDefinition") -> bytes:
        """Encode a struct field.

        Nested fields are emitted in definition order. A nested field that is
        neither supplied nor has a default is omitted from the output.

        Args:
            value: Dictionary of field values.
            field: Field definition.

        Returns:
            Encoded bytes.

        Raises:
            ValueError: If the struct field is missing its fields definition.
        """
        result = bytearray()

        nested_fields = field.fields
        if nested_fields is None:
            raise ValueError(f"Struct field '{field.name}' missing fields definition")

        for nested_field in nested_fields:
            if nested_field.name in value:
                result.extend(self._encode_field(value[nested_field.name], nested_field))
            elif nested_field.default is not None:
                result.extend(self._encode_field(nested_field.default, nested_field))

        return bytes(result)
def load_protocol(path: str | Path) -> ProtocolDefinition:
    """Read a protocol definition from a YAML file.

    Convenience wrapper around ``ProtocolDefinition.from_yaml``.

    Args:
        path: Location of the YAML file.

    Returns:
        The parsed ProtocolDefinition.
    """
    definition = ProtocolDefinition.from_yaml(path)
    return definition
def decode_message(data: bytes, protocol: str | Path | ProtocolDefinition) -> DecodedMessage:
    """Decode a single message using a protocol.

    Convenience wrapper that builds a ``ProtocolDecoder`` and decodes
    ``data`` from offset 0.

    Args:
        data: Binary message data.
        protocol: Either a ProtocolDefinition instance or the path (``str``
            or ``pathlib.Path``) of a YAML protocol file to load.

    Returns:
        DecodedMessage instance.
    """
    # Path objects are accepted alongside plain strings so that both
    # spellings of a file location are loaded rather than mistaken for a
    # definition instance.
    if isinstance(protocol, (str, Path)):
        protocol_def = ProtocolDefinition.from_yaml(protocol)
    else:
        protocol_def = protocol

    decoder = ProtocolDecoder(protocol_def)
    return decoder.decode(data)