Coverage for src / tracekit / inference / protocol_dsl.py: 93%

417 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Protocol Definition Language parser and decoder generator. 

2 

3Requirements addressed: PSI-004 

4 

5This module provides a declarative DSL for defining custom protocol formats 

6that can be used to generate decoders and encoders automatically. 

7 

8Key capabilities: 

9- Parse YAML-based protocol definitions 

10- Support all common field types 

11- Conditional fields and length-prefixed data 

12- Generate efficient decoders and encoders 

13- Comprehensive error reporting 

14""" 

15 

16import ast 

17import operator 

18import struct 

19from collections.abc import Iterator 

20from dataclasses import dataclass, field 

21from pathlib import Path 

22from typing import Any, Literal 

23 

24import yaml 

25 

26 

27@dataclass 

28class FieldDefinition: 

29 """Protocol field definition. 

30 

31 : Field specification. 

32 

33 Attributes: 

34 name: Field name 

35 field_type: Field type (uint8, uint16, int32, float32, bytes, string, bitfield, array, struct) 

36 Also accessible as 'type' for compatibility. 

37 size: Field size (literal or reference to length field) 

38 offset: Field offset (optional, auto-calculated if not provided) 

39 endian: Byte order ('big' or 'little') 

40 condition: Conditional expression for optional fields 

41 enum: Enumeration mapping for integer fields 

42 validation: Validation rules 

43 default: Default value 

44 description: Human-readable description 

45 value: Expected value for constant fields 

46 size_ref: Reference to length field (alias for size when string) 

47 element: Element definition for array types (contains type and optionally fields for struct) 

48 count_field: Field name that contains the array count 

49 count: Fixed array count 

50 fields: List of nested field definitions for struct types 

51 """ 

52 

53 name: str 

54 field_type: str = ( 

55 "uint8" # uint8, uint16, int32, float32, bytes, string, bitfield, array, struct 

56 ) 

57 size: int | str | None = None # Can be literal or reference to length field 

58 offset: int | None = None 

59 endian: Literal["big", "little"] = "big" 

60 condition: str | None = None # Conditional field 

61 enum: dict[int, str] | dict[str, Any] | None = None 

62 validation: dict[str, Any] | None = None 

63 default: Any = None 

64 description: str = "" 

65 value: Any = None # Expected constant value 

66 size_ref: str | None = None # Alias for size reference 

67 # Array/struct specific fields 

68 element: dict[str, Any] | None = None # Element definition for arrays 

69 count_field: str | None = None # Field containing array count 

70 count: int | None = None # Fixed array count 

71 fields: list["FieldDefinition"] | None = None # Nested fields for struct type 

72 

73 def __post_init__(self) -> None: 

74 """Handle size_ref as alias for size.""" 

75 if self.size_ref is not None and self.size is None: 

76 self.size = self.size_ref 

77 

78 @property 

79 def type(self) -> str: 

80 """Alias for field_type for backward compatibility.""" 

81 return self.field_type 

82 

83 @type.setter 

84 def type(self, value: str) -> None: 

85 """Set field_type via type property.""" 

86 self.field_type = value 

87 

88 

@dataclass
class ProtocolDefinition:
    """Complete protocol specification.

    Attributes:
        name: Protocol name.
        description: Protocol description.
        version: Protocol version string.
        endian: Default endianness applied to fields that omit one.
        fields: Ordered list of field definitions.
        settings: Global settings (endianness overrides, etc.).
        framing: Framing/sync configuration (e.g. ``sync_pattern``).
        computed_fields: Computed/derived field specifications.
        decoding: Decoder-specific settings.
        encoding: Encoder-specific settings.
    """

    name: str
    description: str = ""
    version: str = "1.0"
    endian: Literal["big", "little"] = "big"
    fields: list[FieldDefinition] = field(default_factory=list)
    settings: dict[str, Any] = field(default_factory=dict)
    framing: dict[str, Any] = field(default_factory=dict)
    computed_fields: list[dict[str, Any]] = field(default_factory=list)
    decoding: dict[str, Any] = field(default_factory=dict)
    encoding: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_yaml(cls, path: str | Path) -> "ProtocolDefinition":
        """Load a protocol definition from a YAML file.

        Args:
            path: Path to the YAML file.

        Returns:
            ProtocolDefinition instance.
        """
        with open(path) as stream:
            raw = yaml.safe_load(stream)
        return cls.from_dict(raw)

    @classmethod
    def from_dict(cls, config: dict[str, Any]) -> "ProtocolDefinition":
        """Build a protocol definition from a configuration mapping.

        Args:
            config: Configuration dictionary.

        Returns:
            ProtocolDefinition instance.
        """
        default_endian = config.get("endian", "big")
        parsed_fields = [
            cls._parse_field_definition(entry, default_endian)
            for entry in config.get("fields", [])
        ]

        return cls(
            name=config.get("name", "unknown"),
            version=config.get("version", "1.0"),
            description=config.get("description", ""),
            endian=default_endian,
            settings=config.get("settings", {}),
            framing=config.get("framing", {}),
            fields=parsed_fields,
            computed_fields=config.get("computed_fields", []),
            decoding=config.get("decoding", {}),
            encoding=config.get("encoding", {}),
        )

    @classmethod
    def _parse_field_definition(
        cls, field_dict: dict[str, Any], default_endian: str
    ) -> "FieldDefinition":
        """Parse one field entry into a FieldDefinition.

        Args:
            field_dict: Field configuration dictionary.
            default_endian: Endianness used when the field omits one.

        Returns:
            FieldDefinition instance.
        """
        # Accept both 'type' and 'field_type' spellings.
        declared_type = field_dict.get("type") or field_dict.get("field_type", "uint8")

        # Recursively parse nested members for struct-typed fields.
        raw_children = field_dict.get("fields")
        children: list[FieldDefinition] | None = None
        if raw_children:
            children = [
                cls._parse_field_definition(child, default_endian)
                for child in raw_children
            ]

        return FieldDefinition(
            name=field_dict["name"],
            field_type=declared_type,
            size=field_dict.get("size"),
            offset=field_dict.get("offset"),
            endian=field_dict.get("endian", default_endian),
            condition=field_dict.get("condition"),
            enum=field_dict.get("enum"),
            validation=field_dict.get("validation"),
            default=field_dict.get("default"),
            description=field_dict.get("description", ""),
            value=field_dict.get("value"),
            size_ref=field_dict.get("size_ref"),
            element=field_dict.get("element"),
            count_field=field_dict.get("count_field"),
            count=field_dict.get("count"),
            fields=children,
        )

209 

210 

@dataclass
class DecodedMessage:
    """A decoded protocol message with dict-like field access.

    Supports ``"name" in msg``, ``msg["name"]``, iteration over field
    names, and the usual ``keys``/``values``/``items``/``get`` helpers.

    Attributes:
        fields: Mapping of field name to decoded value.
        raw_data: Original binary data the message was decoded from.
        size: Message size in bytes.
        valid: Whether the message passed validation.
        errors: Validation/decoding error messages.
    """

    fields: dict[str, Any]
    raw_data: bytes
    size: int
    valid: bool
    errors: list[str]

    def __contains__(self, key: str) -> bool:
        """Return True when *key* names a decoded field."""
        return key in self.fields

    def __getitem__(self, key: str) -> Any:
        """Return the value of field *key* (KeyError when absent)."""
        return self.fields[key]

    def __iter__(self) -> Iterator[str]:
        """Yield field names in decode order."""
        yield from self.fields

    def keys(self) -> Any:
        """Return a view of field names."""
        return self.fields.keys()

    def values(self) -> Any:
        """Return a view of field values."""
        return self.fields.values()

    def items(self) -> Any:
        """Return a view of (name, value) pairs."""
        return self.fields.items()

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value of field *key*, or *default* when absent."""
        return self.fields.get(key, default)

261 

262 

263class _SafeConditionEvaluator(ast.NodeVisitor): 

264 """Safe evaluator for protocol field conditions. 

265 

266 Only allows: 

267 - Comparisons: ==, !=, <, <=, >, >= 

268 - Logical operations: and, or, not 

269 - Constants: numbers, strings, booleans 

270 - Variable names from context 

271 

272 Security: 

273 Uses AST parsing to safely evaluate conditions without eval(). 

274 """ 

275 

276 def __init__(self, context: dict[str, Any]): 

277 """Initialize with field context. 

278 

279 Args: 

280 context: Dictionary of field names to values 

281 """ 

282 self.context = context 

283 self.compare_ops = { 

284 ast.Eq: operator.eq, 

285 ast.NotEq: operator.ne, 

286 ast.Lt: operator.lt, 

287 ast.LtE: operator.le, 

288 ast.Gt: operator.gt, 

289 ast.GtE: operator.ge, 

290 } 

291 

292 def eval(self, expression: str) -> bool: 

293 """Evaluate condition expression. 

294 

295 Args: 

296 expression: Condition string 

297 

298 Returns: 

299 Boolean result 

300 """ 

301 try: 

302 tree = ast.parse(expression, mode="eval") 

303 result = self.visit(tree.body) 

304 return bool(result) 

305 except Exception: 

306 # If evaluation fails, condition is false 

307 return False 

308 

309 def visit_Compare(self, node: ast.Compare) -> Any: 

310 """Visit comparison operation.""" 

311 left = self.visit(node.left) 

312 for op, comparator in zip(node.ops, node.comparators, strict=True): 

313 if type(op) not in self.compare_ops: 313 ↛ 314line 313 didn't jump to line 314 because the condition on line 313 was never true

314 return False 

315 right = self.visit(comparator) 

316 if not self.compare_ops[type(op)](left, right): 

317 return False 

318 left = right 

319 return True 

320 

321 def visit_BoolOp(self, node: ast.BoolOp) -> Any: 

322 """Visit boolean operation (and, or).""" 

323 if isinstance(node.op, ast.And): 

324 return all(self.visit(value) for value in node.values) 

325 elif isinstance(node.op, ast.Or): 

326 return any(self.visit(value) for value in node.values) 

327 return False 

328 

329 def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: 

330 """Visit unary operation (not).""" 

331 if isinstance(node.op, ast.Not): 

332 return not self.visit(node.operand) 

333 return False 

334 

335 def visit_Name(self, node: ast.Name) -> Any: 

336 """Visit variable name.""" 

337 return self.context.get(node.id) 

338 

339 def visit_Constant(self, node: ast.Constant) -> Any: 

340 """Visit constant value.""" 

341 return node.value 

342 

343 def visit_Num(self, node: ast.Num) -> Any: 

344 """Visit number (Python <3.8 compatibility).""" 

345 return node.n 

346 

347 def visit_Str(self, node: ast.Str) -> Any: 

348 """Visit string (Python <3.8 compatibility).""" 

349 return node.s 

350 

351 def visit_NameConstant(self, node: ast.NameConstant) -> Any: 

352 """Visit named constant (Python <3.8 compatibility).""" 

353 return node.value 

354 

355 def generic_visit(self, node: ast.AST) -> Any: 

356 """Disallow other node types.""" 

357 return False 

358 

359 

class ProtocolDecoder:
    """Decode binary data using a protocol definition.

    Supports scalar integer/float types, bytes, strings, bitfields,
    arrays, and nested structs, plus conditional fields, length-field
    references, and sync-pattern framing for streams.
    """

    def __init__(self, definition: ProtocolDefinition):
        """Initialize decoder with protocol definition.

        Args:
            definition: Protocol definition
        """
        self.definition = definition
        # Maps endian names to struct format prefixes.
        self._endian_map: dict[str, str] = {"big": ">", "little": "<"}

    @classmethod
    def load(cls, path: str | Path) -> "ProtocolDecoder":
        """Load decoder from YAML protocol definition.

        Args:
            path: Path to YAML file

        Returns:
            ProtocolDecoder instance
        """
        definition = ProtocolDefinition.from_yaml(path)
        return cls(definition)

    def decode(self, data: bytes, offset: int = 0) -> DecodedMessage:
        """Decode single message from binary data.

        Fields are decoded in definition order; conditional fields are
        skipped when their condition evaluates false against the fields
        decoded so far. Decoding stops at the first field error.

        Args:
            data: Binary data
            offset: Starting offset in data

        Returns:
            DecodedMessage instance (valid=False on any error)
        """
        fields: dict[str, Any] = {}
        errors: list[str] = []
        current_offset = offset
        valid = True

        # Check minimum length
        if len(data) - offset < 1:
            return DecodedMessage(
                fields={}, raw_data=data[offset:], size=0, valid=False, errors=["Insufficient data"]
            )

        # Decode each field
        for field_def in self.definition.fields:
            # Skip field when its presence condition is not met
            if field_def.condition:
                if not self._evaluate_condition(field_def.condition, fields):
                    continue

            try:
                value, bytes_consumed = self._decode_field(data[current_offset:], field_def, fields)

                # Apply validation rules, if any
                if field_def.validation:
                    validation_error = self._validate_field(value, field_def.validation)
                    if validation_error:
                        errors.append(f"{field_def.name}: {validation_error}")
                        valid = False

                # Store value
                fields[field_def.name] = value
                current_offset += bytes_consumed

            except Exception as e:
                # Abort on first decoding failure; already-decoded fields
                # are kept so callers can inspect the partial message.
                errors.append(f"{field_def.name}: {e!s}")
                valid = False
                break

        total_size = current_offset - offset

        return DecodedMessage(
            fields=fields,
            raw_data=data[offset:current_offset],
            size=total_size,
            valid=valid and len(errors) == 0,
            errors=errors,
        )

    def decode_stream(self, data: bytes) -> list[DecodedMessage]:
        """Decode multiple messages from data stream.

        When a sync pattern is configured, messages are aligned on it and
        decode errors are recovered by advancing one byte and resyncing.

        Args:
            data: Binary data stream

        Returns:
            List of DecodedMessage instances
        """
        messages = []
        offset = 0

        # Check for sync pattern
        sync_pattern = self.definition.framing.get("sync_pattern")

        while offset < len(data):
            # Find sync if configured
            if sync_pattern:
                sync_offset = self.find_sync(data, offset)
                if sync_offset is None:
                    break  # No more sync patterns
                offset = sync_offset

            # Decode message
            msg = self.decode(data, offset)

            if msg.valid:
                messages.append(msg)
                if msg.size == 0:
                    # BUGFIX: a valid zero-byte message (e.g. every field
                    # conditional and skipped) cannot advance the offset;
                    # stop instead of looping forever.
                    break
                offset += msg.size
            else:
                # Try to recover by finding next sync
                if sync_pattern:
                    offset += 1
                else:
                    break  # Can't recover without sync

        return messages

    def find_sync(self, data: bytes, start: int = 0) -> int | None:
        """Find sync pattern in data.

        The pattern may be configured as a hex string ("0xAA55"), a plain
        string (encoded as UTF-8), or a bytes-like sequence.

        Args:
            data: Binary data
            start: Starting offset

        Returns:
            Offset of sync pattern or None
        """
        sync_pattern = self.definition.framing.get("sync_pattern")
        if not sync_pattern:
            return start  # No sync pattern, start from beginning

        # Convert sync pattern (hex string or bytes)
        if isinstance(sync_pattern, str):
            if sync_pattern.startswith("0x"):
                # Hex string like "0xAA55"
                sync_bytes = bytes.fromhex(sync_pattern[2:])
            else:
                sync_bytes = sync_pattern.encode()
        else:
            sync_bytes = bytes(sync_pattern)

        # Search for pattern
        idx = data.find(sync_bytes, start)
        if idx == -1:
            return None
        return idx

    def _decode_field(
        self, data: bytes, field: FieldDefinition, context: dict[str, Any]
    ) -> tuple[Any, int]:
        """Decode single field.

        Args:
            data: Binary data
            field: Field definition
            context: Previously decoded fields

        Returns:
            Tuple of (value, bytes_consumed)

        Raises:
            ValueError: If data is insufficient, the bitfield size is
                unsupported, or the field type is unknown
        """
        endian = self._endian_map.get(field.endian, ">")
        field_type = field.field_type

        # Integer types
        if field_type in ["uint8", "int8", "uint16", "int16", "uint32", "int32", "uint64", "int64"]:
            return self._decode_integer(data, field_type, endian)

        # Float types
        elif field_type in ["float32", "float64"]:
            return self._decode_float(data, field_type, endian)

        # Bytes
        elif field_type == "bytes":
            size = self._resolve_size(field.size, context, data)
            if size > len(data):
                size = len(data)  # Use remaining data
            return bytes(data[:size]), size

        # String
        elif field_type == "string":
            size = self._resolve_size(field.size, context, data)
            if size > len(data):
                size = len(data)  # Use remaining data
            string_bytes = data[:size]
            # Try to decode as UTF-8, fall back to latin-1 (never fails)
            try:
                value = string_bytes.decode("utf-8").rstrip("\x00")
            except UnicodeDecodeError:
                value = string_bytes.decode("latin-1").rstrip("\x00")
            return value, size

        # Bitfield
        elif field_type == "bitfield":
            # Decode as an unsigned integer; caller extracts specific bits.
            field_size = field.size if isinstance(field.size, int) else 1
            # BUGFIX: check available data before unpacking so callers get
            # a clear error instead of IndexError/struct.error.
            if len(data) < field_size:
                raise ValueError(
                    f"Insufficient data for bitfield (need {field_size}, have {len(data)})"
                )
            if field_size == 1:
                bitfield_value = int(data[0])
            elif field_size == 2:
                bitfield_value = struct.unpack(f"{endian}H", data[:2])[0]
            elif field_size == 4:
                bitfield_value = struct.unpack(f"{endian}I", data[:4])[0]
            else:
                raise ValueError(f"Unsupported bitfield size: {field_size}")

            # Return as-is, caller can extract specific bits
            return bitfield_value, field_size

        # Array
        elif field_type == "array":
            return self._decode_array(data, field, context)

        # Struct (nested)
        elif field_type == "struct":
            return self._decode_struct(data, field, context)

        else:
            raise ValueError(f"Unknown field type: {field_type}")

    def _decode_array(
        self, data: bytes, field: FieldDefinition, context: dict[str, Any]
    ) -> tuple[list[Any], int]:
        """Decode array field.

        The element count comes from ``field.count``, from the value of
        ``field.count_field`` in the context, or — when neither is set —
        elements are decoded until the data is exhausted.

        Args:
            data: Binary data
            field: Field definition with element spec
            context: Previously decoded fields

        Returns:
            Tuple of (list of values, bytes_consumed)

        Raises:
            ValueError: If array field is missing element definition
        """
        elements = []
        total_consumed = 0

        # Determine element count
        count = None
        if field.count is not None:
            count = field.count
        elif field.count_field is not None and field.count_field in context:
            count = int(context[field.count_field])

        # Get element definition
        element_def = field.element
        if element_def is None:
            raise ValueError(f"Array field '{field.name}' missing element definition")

        element_type = element_def.get("type", "uint8")
        element_endian = element_def.get("endian", field.endian)

        # If no count, try to decode until data exhausted
        idx = 0
        while len(data) - total_consumed > 0:
            if count is not None and idx >= count:
                break

            # Create a temporary field definition for the element
            if element_type == "struct":
                # Nested struct in array
                nested_fields = element_def.get("fields", [])
                parsed_fields = [
                    ProtocolDefinition._parse_field_definition(f, element_endian)
                    for f in nested_fields
                ]
                elem_field = FieldDefinition(
                    name=f"{field.name}[{idx}]",
                    field_type="struct",
                    endian=element_endian,
                    fields=parsed_fields,
                )
                value, consumed = self._decode_struct(data[total_consumed:], elem_field, context)
            else:
                # Simple element type
                elem_field = FieldDefinition(
                    name=f"{field.name}[{idx}]",
                    field_type=element_type,
                    endian=element_endian,
                    size=element_def.get("size"),
                )
                value, consumed = self._decode_field(data[total_consumed:], elem_field, context)

            if consumed == 0:
                break  # Prevent infinite loop

            elements.append(value)
            total_consumed += consumed
            idx += 1

        return elements, total_consumed

    def _decode_struct(
        self, data: bytes, field: FieldDefinition, context: dict[str, Any]
    ) -> tuple[dict[str, Any], int]:
        """Decode struct field.

        Args:
            data: Binary data
            field: Field definition with nested fields
            context: Previously decoded fields

        Returns:
            Tuple of (dict of field values, bytes_consumed)

        Raises:
            ValueError: If struct field is missing fields definition
        """
        struct_fields: dict[str, Any] = {}
        total_consumed = 0

        # Get nested field definitions
        nested_fields = field.fields
        if nested_fields is None:
            raise ValueError(f"Struct field '{field.name}' missing fields definition")

        # Decode each nested field
        for nested_field in nested_fields:
            if len(data) - total_consumed < 1:
                break  # Not enough data

            # Check condition if present
            if nested_field.condition:
                # Use combined context (parent context + struct fields decoded so far)
                combined_context = {**context, **struct_fields}
                if not self._evaluate_condition(nested_field.condition, combined_context):
                    continue

            value, consumed = self._decode_field(
                data[total_consumed:], nested_field, {**context, **struct_fields}
            )
            struct_fields[nested_field.name] = value
            total_consumed += consumed

        return struct_fields, total_consumed

    def _decode_integer(self, data: bytes, type_name: str, endian: str) -> tuple[int, int]:
        """Decode integer field.

        Args:
            data: Binary data
            type_name: Type name (uint8, int16, etc.)
            endian: Endian marker

        Returns:
            Tuple of (value, bytes_consumed)

        Raises:
            ValueError: If insufficient data for the integer type
        """
        # struct format character and width for each supported type
        format_map = {
            "uint8": ("B", 1),
            "int8": ("b", 1),
            "uint16": ("H", 2),
            "int16": ("h", 2),
            "uint32": ("I", 4),
            "int32": ("i", 4),
            "uint64": ("Q", 8),
            "int64": ("q", 8),
        }

        fmt_char, size = format_map[type_name]

        if len(data) < size:
            raise ValueError(f"Insufficient data for {type_name} (need {size}, have {len(data)})")

        # uint8/int8 don't use endianness
        if size == 1:
            value = struct.unpack(fmt_char, data[:size])[0]
        else:
            value = struct.unpack(f"{endian}{fmt_char}", data[:size])[0]

        return value, size

    def _decode_float(self, data: bytes, type_name: str, endian: str) -> tuple[float, int]:
        """Decode float field.

        Args:
            data: Binary data
            type_name: Type name (float32 or float64)
            endian: Endian marker

        Returns:
            Tuple of (value, bytes_consumed)

        Raises:
            ValueError: If insufficient data for the float type
        """
        if type_name == "float32":
            size = 4
            fmt = f"{endian}f"
        else:  # float64
            size = 8
            fmt = f"{endian}d"

        if len(data) < size:
            raise ValueError(f"Insufficient data for {type_name} (need {size}, have {len(data)})")

        value = struct.unpack(fmt, data[:size])[0]
        return value, size

    def _resolve_size(
        self, size_spec: int | str | None, context: dict[str, Any], data: bytes
    ) -> int:
        """Resolve field size (literal or reference).

        Args:
            size_spec: Size specification (int, field name, or 'remaining')
            context: Decoded fields
            data: Current data buffer (for 'remaining' size)

        Returns:
            Resolved size

        Raises:
            ValueError: If size field not found in context or size specification is invalid
        """
        if size_spec is None:
            # No size specified, return 0 (caller should handle)
            return 0
        elif isinstance(size_spec, int):
            return size_spec
        elif isinstance(size_spec, str):
            # Special case: 'remaining' means use all remaining data
            if size_spec == "remaining":
                return len(data)
            # Reference to another field
            if size_spec in context:
                return int(context[size_spec])
            else:
                raise ValueError(f"Size field '{size_spec}' not found in context")
        else:
            raise ValueError(f"Invalid size specification: {size_spec}")

    def _evaluate_condition(self, condition: str, context: dict[str, Any]) -> bool:
        """Evaluate field condition against decoded context.

        Args:
            condition: Condition expression (e.g., "msg_type == 0x02")
            context: Decoded fields

        Returns:
            True if condition is satisfied

        Security:
            Uses AST-based safe evaluation. Only comparisons and logical
            operations are permitted.
        """
        evaluator = _SafeConditionEvaluator(context)
        return evaluator.eval(condition)

    def _validate_field(self, value: Any, validation: dict[str, Any]) -> str | None:
        """Validate field value.

        Args:
            value: Field value
            validation: Validation rules (min/max/value keys)

        Returns:
            Error message or None if valid
        """
        # Min/max validation
        if "min" in validation:
            if value < validation["min"]:
                return f"Value {value} below minimum {validation['min']}"

        if "max" in validation:
            if value > validation["max"]:
                return f"Value {value} above maximum {validation['max']}"

        # Exact-value validation (constant fields)
        if "value" in validation:
            if value != validation["value"]:
                return f"Expected {validation['value']}, got {value}"

        return None

862 

class ProtocolEncoder:
    """Encode field values into binary messages using a protocol definition."""

    def __init__(self, definition: ProtocolDefinition):
        """Initialize encoder.

        Args:
            definition: Protocol definition
        """
        self.definition = definition
        # Maps endian names to struct format prefixes.
        self._endian_map: dict[str, str] = {"big": ">", "little": "<"}

    def encode(self, fields: dict[str, Any]) -> bytes:
        """Encode field values to binary message.

        Conditional fields whose condition evaluates false against the
        supplied values are omitted. Missing fields fall back to their
        declared defaults.

        Args:
            fields: Dictionary of field name -> value

        Returns:
            Encoded binary message

        Raises:
            ValueError: If a required field has neither a value nor a default
        """
        out = bytearray()

        for spec in self.definition.fields:
            if spec.condition:
                # Conditions go through the same safe AST evaluator used
                # for decoding (no eval()).
                if not _SafeConditionEvaluator(fields).eval(spec.condition):
                    continue

            if spec.name in fields:
                value = fields[spec.name]
            elif spec.default is not None:
                value = spec.default
            else:
                raise ValueError(f"Missing required field: {spec.name}")

            out += self._encode_field(value, spec)

        return bytes(out)

    def _encode_field(self, value: Any, field: FieldDefinition) -> bytes:
        """Encode single field value.

        Args:
            value: Field value
            field: Field definition

        Returns:
            Encoded bytes

        Raises:
            ValueError: If bytes value is invalid or field type is unknown for encoding
        """
        endian = self._endian_map.get(field.endian, ">")
        field_type = field.field_type

        # Fixed-width integer types share one struct-based path.
        int_formats = {
            "uint8": "B",
            "int8": "b",
            "uint16": "H",
            "int16": "h",
            "uint32": "I",
            "int32": "i",
            "uint64": "Q",
            "int64": "q",
        }
        fmt = int_formats.get(field_type)
        if fmt is not None:
            if fmt in ("B", "b"):
                # Single bytes carry no byte order.
                return struct.pack(fmt, int(value))
            return struct.pack(f"{endian}{fmt}", int(value))

        if field_type == "float32":
            return struct.pack(f"{endian}f", float(value))
        if field_type == "float64":
            return struct.pack(f"{endian}d", float(value))

        if field_type == "bytes":
            if isinstance(value, bytes):
                return value
            if isinstance(value, list | tuple):
                return bytes(value)
            raise ValueError(f"Invalid bytes value: {value}")

        if field_type == "string":
            return value.encode("utf-8") if isinstance(value, str) else bytes(value)

        if field_type == "array":
            return self._encode_array(value, field)

        if field_type == "struct":
            return self._encode_struct(value, field)

        raise ValueError(f"Unknown field type for encoding: {field_type}")

    def _encode_array(self, value: list[Any], field: FieldDefinition) -> bytes:
        """Encode array field element by element.

        Args:
            value: List of values
            field: Field definition carrying the element spec

        Returns:
            Encoded bytes

        Raises:
            ValueError: If array field is missing element definition
        """
        element_def = field.element
        if element_def is None:
            raise ValueError(f"Array field '{field.name}' missing element definition")

        element_type = element_def.get("type", "uint8")
        element_endian = element_def.get("endian", field.endian)

        out = bytearray()
        for index, item in enumerate(value):
            if element_type == "struct":
                # Struct elements: parse member specs, then encode nested.
                parsed = [
                    ProtocolDefinition._parse_field_definition(f, element_endian)
                    for f in element_def.get("fields", [])
                ]
                item_spec = FieldDefinition(
                    name=f"{field.name}[{index}]",
                    field_type="struct",
                    endian=element_endian,
                    fields=parsed,
                )
                out += self._encode_struct(item, item_spec)
            else:
                item_spec = FieldDefinition(
                    name=f"{field.name}[{index}]",
                    field_type=element_type,
                    endian=element_endian,
                )
                out += self._encode_field(item, item_spec)

        return bytes(out)

    def _encode_struct(self, value: dict[str, Any], field: FieldDefinition) -> bytes:
        """Encode struct field from a mapping of member values.

        Members absent from *value* fall back to their defaults; members
        with neither a value nor a default are silently skipped.
        NOTE(review): that is more lenient than ``encode``, which raises
        for missing top-level fields — confirm the asymmetry is intended.

        Args:
            value: Dictionary of member values
            field: Field definition carrying the nested member specs

        Returns:
            Encoded bytes

        Raises:
            ValueError: If struct field is missing fields definition
        """
        members = field.fields
        if members is None:
            raise ValueError(f"Struct field '{field.name}' missing fields definition")

        out = bytearray()
        for member in members:
            if member.name in value:
                out += self._encode_field(value[member.name], member)
            elif member.default is not None:
                out += self._encode_field(member.default, member)

        return bytes(out)

1058 

1059 

def load_protocol(path: str | Path) -> ProtocolDefinition:
    """Load a protocol definition from a YAML file.

    Convenience wrapper around ``ProtocolDefinition.from_yaml``.

    Args:
        path: Path to YAML file

    Returns:
        ProtocolDefinition instance
    """
    return ProtocolDefinition.from_yaml(path)

1072 

1073 

def decode_message(data: bytes, protocol: str | Path | ProtocolDefinition) -> DecodedMessage:
    """Decode a message using a protocol.

    Convenience function wrapping ProtocolDecoder.

    Args:
        data: Binary message data
        protocol: Protocol definition — a YAML file path (str or Path)
            or a ProtocolDefinition instance

    Returns:
        DecodedMessage instance
    """
    # BUGFIX: previously only str paths were detected, so a pathlib.Path
    # fell through to the "instance" branch and failed. Accept both path
    # forms, matching load_protocol()'s signature.
    if isinstance(protocol, (str, Path)):
        protocol_def = ProtocolDefinition.from_yaml(protocol)
    else:
        protocol_def = protocol

    decoder = ProtocolDecoder(protocol_def)
    return decoder.decode(data)