Coverage for dataclasses_struct / dataclass.py: 97%

252 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 22:31 +1200

1import dataclasses 

2import sys 

3from collections.abc import Generator, Iterator 

4from struct import Struct 

5from typing import ( 

6 Annotated, 

7 Any, 

8 Callable, 

9 ClassVar, 

10 Generic, 

11 Literal, 

12 Protocol, 

13 TypedDict, 

14 TypeVar, 

15 Union, 

16 get_args, 

17 get_origin, 

18 get_type_hints, 

19 overload, 

20) 

21 

22from ._typing import Buffer, TypeGuard, Unpack, dataclass_transform 

23from .field import Field, builtin_fields 

24from .types import CString, PadAfter, PadBefore 

25 

if sys.version_info < (3, 10):
    # dataclasses.KW_ONLY only exists from Python 3.10 onward. On 3.9 use a
    # private sentinel instance instead: no real annotation is ever that exact
    # object, so an identity check against it simply never matches.

    class _KW_ONLY_MARKER_TYPE:
        pass

    _KW_ONLY_MARKER = _KW_ONLY_MARKER_TYPE()
else:
    from dataclasses import KW_ONLY as _KW_ONLY_MARKER

35 

36 

def _separate_padding_from_annotation_args(args) -> tuple[int, int, object]:
    """Split ``Annotated`` metadata into padding totals and one extra item.

    Args:
        args: The metadata items that follow the type in an
            ``Annotated[...]`` annotation.

    Returns:
        ``(pad_before, pad_after, extra_arg)``: the summed ``size`` of every
        ``PadBefore`` item, the summed ``size`` of every ``PadAfter`` item,
        and the single remaining non-padding item (``None`` if there is none).

    Raises:
        TypeError: If more than one non-padding item is present.
    """
    total_before = 0
    total_after = 0
    # At most one non-padding item is allowed; it should be a Field or the
    # integer length of a bytes/list type. None means "none seen so far".
    extra = None
    for item in args:
        if isinstance(item, PadBefore):
            total_before += item.size
        elif isinstance(item, PadAfter):
            total_after += item.size
        elif extra is None:
            extra = item
        else:
            raise TypeError(f"too many annotations: {item}")
    return total_before, total_after, extra

51 

52 

def _format_str_with_padding(fmt: str, pad_before: int, pad_after: int) -> str:
    """Surround a struct format fragment with ``x`` (pad byte) codes.

    Args:
        fmt: The struct format fragment for a single field.
        pad_before: Number of pad bytes to emit before ``fmt`` (0 for none).
        pad_after: Number of pad bytes to emit after ``fmt`` (0 for none).

    Returns:
        e.g. ``"2xi3x"`` for ``("i", 2, 3)``; zero counts add nothing.
    """
    prefix = f"{pad_before}x" if pad_before else ""
    suffix = f"{pad_after}x" if pad_after else ""
    return prefix + fmt + suffix

61 

62 

T = TypeVar("T")


# struct mode character (first character of a format string) for every
# supported (size, byteorder) pair. See
# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
_SIZE_BYTEORDER_MODE_CHAR: dict[tuple[str, str], str] = dict(
    [
        (("native", "native"), "@"),
        (("std", "native"), "="),
        (("std", "little"), "<"),
        (("std", "big"), ">"),
        (("std", "network"), "!"),
    ]
)
# Inverse lookup: mode character -> (size, byteorder).
_MODE_CHAR_SIZE_BYTEORDER: dict[str, tuple[str, str]] = dict(
    zip(_SIZE_BYTEORDER_MODE_CHAR.values(), _SIZE_BYTEORDER_MODE_CHAR.keys())
)

76 

77 

@dataclasses.dataclass()
class _FieldInfo:
    """Metadata recorded for one field when a dataclass-struct is built."""

    # Attribute name of the field on the decorated class.
    name: str
    # Resolved struct field (builtin, nested struct, bytes or array field).
    field: Field[Any]
    # Annotated Python type; used to rebuild the value when unpacking.
    type_: type
    # True: passed to the class's __init__ when unpacking. False: assigned on
    # the instance after it has been constructed.
    init: bool

84 

85 

class DataclassStructInternal(Generic[T]):
    """
    Packing/unpacking state for a class decorated with `dataclass_struct`.

    An instance is stored on the decorated class as `__dataclass_struct__`.
    It holds the compiled `struct.Struct`, the decorated class, and the
    ordered per-field metadata used to flatten an instance into `struct.pack`
    arguments and to rebuild an instance from `struct.unpack` output.
    """

    struct: Struct
    cls: type[T]
    _fields: list[_FieldInfo]

    @property
    def format(self) -> str:
        """
        The format string used by the `struct` module to pack/unpack data.

        See https://docs.python.org/3/library/struct.html#format-strings.
        """
        return self.struct.format

    @property
    def size(self) -> int:
        """Size of the packed representation in bytes."""
        return self.struct.size

    @property
    def mode(self) -> str:
        """
        The `struct` mode character that determines size, alignment, and
        byteorder.

        This is the first character of the `format` field. See
        https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
        for more info.
        """
        return self.format[0]

    def __init__(
        self,
        fmt: str,
        cls: type,
        fields: list[_FieldInfo],
    ):
        self.struct = Struct(fmt)
        self.cls = cls
        self._fields = fields

    def _flattened_attrs(self, outer_self: T) -> list[Any]:
        """
        Returns a list of all attributes of `outer_self`, including those of
        any nested structs, in the order the struct format expects them.
        """
        attrs: list[Any] = []
        for field in self._fields:
            attr = getattr(outer_self, field.name)
            self._flatten_attr(attrs, attr)
        return attrs

    @staticmethod
    def _flatten_attr(attrs: list[Any], attr: object) -> None:
        """Append the leaf value(s) of `attr` to `attrs`.

        Nested dataclass-structs and (fixed-length) lists are expanded
        recursively because the struct format lists their members inline.
        """
        if is_dataclass_struct(attr):
            attrs.extend(attr.__dataclass_struct__._flattened_attrs(attr))
        elif isinstance(attr, list):
            for sub_attr in attr:
                DataclassStructInternal._flatten_attr(attrs, sub_attr)
        else:
            attrs.append(attr)

    def _pack(self, obj: T) -> bytes:
        """Pack `obj` into bytes using the compiled struct."""
        return self.struct.pack(*self._flattened_attrs(obj))

    def _arg_generator(self, args: Iterator) -> Generator:
        """Yield one rebuilt value per top-level field, consuming `args`."""
        for field in self._fields:
            yield from DataclassStructInternal._generate_args_recursively(
                args, field.field, field.type_
            )

    @staticmethod
    def _generate_args_recursively(
        args: Iterator,
        field: Field[Any],
        field_type: type,
    ) -> Generator:
        """
        Yield the single rebuilt value for one field, consuming from `args`
        every raw unpacked value that the field occupies.
        """
        if is_dataclass_struct(field_type):
            # Nested struct: rebuild it from its own share of `args`.
            yield field_type.__dataclass_struct__._init_from_args(args)
        elif isinstance(field, _FixedLengthArrayField):
            items: list = []
            for _ in range(field.n):
                items.extend(
                    DataclassStructInternal._generate_args_recursively(
                        args, field.item_field, field.item_type
                    )
                )
            yield items
        elif isinstance(field, CString):
            # Drop everything from the first NUL byte onward, if present.
            data: bytes = next(args)
            pos = data.find(0)
            yield data if pos < 0 else data[:pos]
        else:
            # Convert to the annotated type, which may be a subclass of the
            # field's base type.
            yield field_type(next(args))

    def _init_from_args(self, args: Iterator) -> T:
        """
        Returns an instance of self.cls, consuming args.

        Fields with `init=True` are passed to the constructor. The others are
        assigned on the new instance afterwards. A frozen dataclass rejects
        ordinary attribute assignment with `FrozenInstanceError`. In that case
        the value is written with `object.__setattr__`, the same mechanism
        the generated `__init__` of a frozen dataclass uses.
        """
        kwargs = {}
        no_init_args = {}

        for field, arg in zip(self._fields, self._arg_generator(args)):
            if field.init:
                kwargs[field.name] = arg
            else:
                no_init_args[field.name] = arg

        obj = self.cls(**kwargs)
        for name, arg in no_init_args.items():
            try:
                setattr(obj, name, arg)
            except dataclasses.FrozenInstanceError:
                object.__setattr__(obj, name, arg)
        return obj

    def _unpack(self, data: Buffer) -> T:
        """Build an instance of self.cls from its packed representation."""
        return self._init_from_args(iter(self.struct.unpack(data)))

201 

202 

class DataclassStructProtocol(Protocol):
    """
    Structural type of classes decorated with
    [`dataclass_struct`][dataclasses_struct.dataclass_struct] and of their
    instances.

    `is_dataclass_struct` narrows a class or object to this type.
    """

    __dataclass_struct__: ClassVar[DataclassStructInternal]
    """
    Internal data used by the library for packing and unpacking structs.

    See
    [`DataclassStructInternal`][dataclasses_struct.DataclassStructInternal].
    """

    @classmethod
    def from_packed(cls: type[T], data: Buffer) -> T:
        """Return an instance of the class from its packed representation.

        Args:
            data: The packed representation of the class as returned by
                [`pack`][dataclasses_struct.dataclass.DataclassStructProtocol.pack].

        Returns:
            An instance of the class unpacked from `data`.

        Raises:
            struct.error: If `data` is the wrong length.
        """
        ...

    def pack(self) -> bytes:
        """Return the packed representation in `bytes` of the object.

        Returns:
            The packed representation. Can be used to instantiate a new object
            with
            [`from_packed`][dataclasses_struct.dataclass.DataclassStructProtocol.from_packed].

        Raises:
            struct.error: If any of the fields are out of range or the wrong
                type.
        """
        ...

241 

242 

@overload
def is_dataclass_struct(
    obj: type,
) -> TypeGuard[type[DataclassStructProtocol]]: ...


@overload
def is_dataclass_struct(obj: object) -> TypeGuard[DataclassStructProtocol]: ...


def is_dataclass_struct(
    obj: Union[type, object],
) -> Union[
    TypeGuard[DataclassStructProtocol],
    TypeGuard[type[DataclassStructProtocol]],
]:
    """Determine whether a type or object is a dataclass-struct.

    Args:
        obj: A class or object.

    Returns:
        `True` if obj is a class that has been decorated with
        [`dataclass_struct`][dataclasses_struct.dataclass_struct] or is an
        instance of one, `False` otherwise.
    """
    if not dataclasses.is_dataclass(obj):
        return False
    # A missing attribute yields None, which fails the isinstance check.
    internal = getattr(obj, "__dataclass_struct__", None)
    return isinstance(internal, DataclassStructInternal)

274 

275 

def get_struct_size(cls_or_obj: object) -> int:
    """Get the size of the packed representation of the struct in bytes.

    Args:
        cls_or_obj: A class that has been decorated with
            [`dataclass_struct`][dataclasses_struct.dataclass_struct] or an
            instance of one.

    Returns:
        The size of the packed representation in bytes.

    Raises:
        TypeError: if `cls_or_obj` is not a dataclass-struct.
    """
    if is_dataclass_struct(cls_or_obj):
        return cls_or_obj.__dataclass_struct__.size
    raise TypeError(f"{cls_or_obj} is not a dataclass_struct")

293 

294 

class _BytesField(Field[bytes]):
    """Fixed-length `bytes` field, packed with the struct `{n}s` code."""

    field_type = bytes

    def __init__(self, n: object):
        """
        Args:
            n: Length in bytes; must be an int >= 1.

        Raises:
            ValueError: If `n` is not a positive int. `bool` is rejected
                explicitly: it subclasses int, but `True` would render the
                format as "Trues", which struct rejects with an opaque
                error.
        """
        if isinstance(n, bool) or not isinstance(n, int) or n < 1:
            raise ValueError("bytes length must be positive non-zero int")

        self.n = n

    def format(self) -> str:
        return f"{self.n}s"

    def validate_default(self, val: bytes) -> None:
        # Shorter values are accepted: struct pads `s` values with null
        # bytes when packing. Longer ones would be silently truncated, so
        # reject them.
        if len(val) > self.n:
            raise ValueError(f"bytes cannot be longer than {self.n} bytes")

    def __repr__(self) -> str:
        return f"{super().__repr__()}({self.n})"

313 

314 

class _NestedField(Field):
    """Field whose value is another dataclass-struct, laid out inline."""

    field_type: type[DataclassStructProtocol]

    def __init__(self, cls: type[DataclassStructProtocol]):
        # The nested dataclass-struct class is this field's value type.
        self.field_type = cls

    def format(self) -> str:
        nested_format = self.field_type.__dataclass_struct__.format
        # Skip the nested struct's leading mode character: only the
        # container's mode character may appear, at the start of the format.
        return nested_format[1:]

324 

325 

class _FixedLengthArrayField(Field[list]):
    """Field for `Annotated[list[<item type>], <n>]`: exactly `n` items."""

    field_type = list

    def __init__(self, item_type_annotation: Any, mode: str, n: object):
        length_ok = isinstance(n, int) and n >= 1
        if not length_ok:
            raise ValueError(
                "fixed-length array length must be positive non-zero int"
            )

        # Items are resolved like any other field annotation, so they may
        # carry their own padding or be nested structs/arrays.
        resolved = _resolve_field(item_type_annotation, mode)
        self.item_field, self.item_type, self.pad_before, self.pad_after = (
            resolved
        )
        self.n = n
        # The array is allowed in a size mode exactly when its items are.
        self.is_native = self.item_field.is_native
        self.is_std = self.item_field.is_std

    def format(self) -> str:
        padded_item = _format_str_with_padding(
            self.item_field.format(),
            self.pad_before,
            self.pad_after,
        )
        # A padded/compound item cannot be expressed with a struct repeat
        # count, so repeat the fragment text itself.
        return padded_item * self.n

    def __repr__(self) -> str:
        return f"{super().__repr__()}({self.item_field!r}, {self.n})"

    def validate_default(self, val: list) -> None:
        actual = len(val)
        if actual != self.n:
            msg = (
                f"fixed-length array must have length of {self.n}, "
                f"got {actual}"
            )
            raise ValueError(msg)

        for element in val:
            _validate_field_default(self.item_field, element)

361 

362 

def _validate_modes_match(mode: str, nested_mode: str) -> None:
    """Ensure a nested dataclass-struct uses the same mode as its container.

    Args:
        mode: Mode character of the containing struct.
        nested_mode: Mode character of the nested struct.

    Raises:
        TypeError: If the two mode characters differ.
    """
    if mode == nested_mode:
        return

    got_size, got_byteorder = _MODE_CHAR_SIZE_BYTEORDER[nested_mode]
    want_size, want_byteorder = _MODE_CHAR_SIZE_BYTEORDER[mode]
    msg = (
        "byteorder and size of nested dataclass-struct does not "
        f"match that of container (expected '{want_size}' size and "
        f"'{want_byteorder}' byteorder, got '{got_size}' size and "
        f"'{got_byteorder}' byteorder)"
    )
    raise TypeError(msg)

374 

375 

def _resolve_field(
    annotation: Any,
    mode: str,
) -> tuple[Field[Any], type, int, int]:
    """
    Returns 4-tuple of:
    * field
    * type
    * number of padding bytes before
    * number of padding bytes after

    Valid type annotations are:

    1. <bool | int | float | bytes> | Annotated[<bool | int | float | bytes>, <padding>]

        Supported builtin types.

    2. Annotated[<bool | int | float | bytes>, Field(...), <padding>]

        (These are the types defined in dataclasses_struct.types e.g. U32, F32).

    3. <dataclasses_struct class> | Annotated[<dataclasses_struct class>, <padding>]

        Must have the same size and byteorder as the container.

    4. Annotated[bytes, <n>, <padding>]

        Where <n> is >0.

    5. Annotated[list[<type>], <n>, <padding>]

        Where <n> is >0 and <type> is one of the above.

    <padding> is an optional mixture of PadBefore and PadAfter annotations,
    which may be repeated. E.g.

        Annotated[int, PadBefore(5), PadAfter(2), PadBefore(3)]

    Raises:
        TypeError: If the annotation is not one of the forms above, e.g. an
            unsupported type, a list without a length, or a list annotation
            that does not have exactly one item type.
        ValueError: If the length given for a bytes or list type is not a
            positive int.
    """  # noqa: E501

    if get_origin(annotation) == Annotated:
        type_, *args = get_args(annotation)
        pad_before, pad_after, annotation_arg = (
            _separate_padding_from_annotation_args(args)
        )
    else:
        pad_before = pad_after = 0
        type_ = annotation
        annotation_arg = None

    field: Field[Any]
    if annotation_arg is None:
        if get_origin(type_) is list:
            msg = (
                "list types must be marked as a fixed-length using "
                "Annotated, ex: Annotated[list[int], 5]"
            )
            raise TypeError(msg)

        # Must be either a nested type or one of the supported builtins
        if is_dataclass_struct(type_):
            _validate_modes_match(mode, type_.__dataclass_struct__.mode)
            field = _NestedField(type_)
        else:
            opt_field = builtin_fields.get(type_)
            if opt_field is None:
                raise TypeError(f"type not supported: {annotation}")
            field = opt_field
    elif isinstance(annotation_arg, Field) and issubclass(
        type_, annotation_arg.field_type
    ):
        field = annotation_arg
    elif get_origin(type_) is list:
        item_annotations = get_args(type_)
        # An explicit check rather than `assert` (which is stripped under
        # -O): e.g. a bare `typing.List` has no item type, and `list[int, str]`
        # is accepted at runtime.
        if len(item_annotations) != 1:
            raise TypeError(
                f"list type must have exactly one item type: {annotation!r}"
            )
        field = _FixedLengthArrayField(
            item_annotations[0], mode, annotation_arg
        )
    elif isinstance(type_, type) and issubclass(type_, bytes):
        # The isinstance guard keeps non-class types (e.g. dict[str, int],
        # Union[...]) from making issubclass raise its own unrelated
        # TypeError; they fall through to the message below.
        field = _BytesField(annotation_arg)
    else:
        raise TypeError(f"invalid field annotation: {annotation!r}")

    return field, type_, pad_before, pad_after

459 

460 

def _get_default_from_dataclasses_field(field: dataclasses.Field) -> Any:
    """Return the default value declared by a `dataclasses.field(...)`.

    A plain `default` wins. Otherwise `default_factory` is called once and
    its result returned. If the field declares neither,
    `dataclasses.MISSING` is returned.
    """
    if field.default is dataclasses.MISSING:
        if field.default_factory is dataclasses.MISSING:
            return dataclasses.MISSING
        return field.default_factory()
    return field.default

469 

470 

def _validate_field_default(field: Field[T], val: Any) -> None:
    """Check a default value against a struct field.

    Args:
        field: The resolved struct field.
        val: The default value declared for the field.

    Raises:
        TypeError: If `val` is not an instance of `field.field_type`.
        ValueError: Raised by `field.validate_default` for out-of-range
            values (e.g. bytes that are too long, or an array of the wrong
            length, for the bytes/array fields in this module).
    """
    expected_type = field.field_type
    if isinstance(val, expected_type):
        field.validate_default(val)
        return

    msg = (
        "invalid type for field: expected "
        f"{expected_type} got {type(val)}"
    )
    raise TypeError(msg)

480 

481 

def _validate_and_parse_field(
    cls: type,
    *,
    name: str,
    field_type: type,
    is_native: bool,
    validate_defaults: bool,
    mode: str,
    init: bool,
) -> tuple[str, _FieldInfo]:
    """Resolve one annotated attribute into its format fragment and info.

    Args:
        cls: The class being decorated (before `dataclass` processes it).
        name: The attribute name.
        field_type: The attribute's resolved type annotation.
        is_native: Whether the struct uses native size mode.
        validate_defaults: Whether to check the attribute's default value.
        mode: The struct mode character of the containing struct.
        init: The `init` argument given to the dataclass decorator.

    Returns:
        The struct format fragment for the field (padding included) and its
        `_FieldInfo`.

    Raises:
        TypeError: If the field is not supported in the requested size mode,
            the annotation is invalid, or the default has the wrong type.
        ValueError: If the field's own validation rejects the default.
    """
    field, type_, pad_before, pad_after = _resolve_field(field_type, mode)

    if is_native and not field.is_native:
        raise TypeError(
            f"field {field} only supported in standard size mode"
        )
    if not is_native and not field.is_std:
        raise TypeError(f"field {field} only supported in native size mode")

    init_field = init
    # The raw class attribute is either a plain default value or a
    # dataclasses.Field produced by field(...).
    default = getattr(cls, name, dataclasses.MISSING)
    if isinstance(default, dataclasses.Field):
        if not default.init:
            init_field = False

        if validate_defaults:
            default = _get_default_from_dataclasses_field(default)

    if validate_defaults and default is not dataclasses.MISSING:
        _validate_field_default(field, default)

    fmt = _format_str_with_padding(field.format(), pad_before, pad_after)
    return fmt, _FieldInfo(name, field, type_, init_field)

520 

521 

def _make_pack_method() -> Callable:
    """Build the `pack` instance method installed on a dataclass-struct.

    Returns:
        A plain function `pack(self) -> bytes` (compiled from source) that
        delegates to `self.__dataclass_struct__._pack(self)`.
    """
    source = """
def pack(self) -> bytes:
    '''Pack to bytes using struct.pack.'''
    return self.__dataclass_struct__._pack(self)
"""

    namespace: dict[str, Any] = {}
    exec(source, {}, namespace)
    pack_function = namespace["pack"]
    return pack_function

532 

533 

def _make_unpack_method(cls: type) -> classmethod:
    """Build the `from_packed` classmethod installed on a dataclass-struct.

    The function is compiled from a source string, with its return
    annotation set to the concrete decorated class.

    Args:
        cls: The class being decorated; used as the return annotation.

    Returns:
        A `classmethod` wrapping `from_packed(cls, data)`, which delegates to
        `cls.__dataclass_struct__._unpack(data)`.
    """
    func = """
def from_packed(cls, data: Buffer) -> cls_type:
    '''Unpack from bytes.'''
    return cls.__dataclass_struct__._unpack(data)
"""

    # Use a single mapping as the exec globals, rather than empty globals
    # plus a separate locals dict. That way `Buffer` and `cls_type` are in
    # the function's __globals__, which is where lazily evaluated
    # annotations (PEP 649, Python 3.14+) are resolved. Names that live only
    # in exec's locals would raise NameError when the annotations are read.
    namespace: dict[str, Any] = {"cls_type": cls, "Buffer": Buffer}
    exec(func, namespace)
    return classmethod(namespace["from_packed"])

544 

545 

def _make_class(
    cls: type,
    mode: str,
    is_native: bool,
    validate_defaults: bool,
    dataclass_kwargs,
) -> type[DataclassStructProtocol]:
    """Turn `cls` into a dataclass-struct.

    Builds the struct format and per-field metadata from the class's type
    hints, attaches `__dataclass_struct__`, `pack` and `from_packed`, and then
    applies the stdlib `dataclasses.dataclass` decorator.

    Args:
        cls: The class being decorated.
        mode: The struct mode character (first character of the format).
        is_native: Whether native size mode is in use.
        validate_defaults: Whether to validate the fields' default values.
        dataclass_kwargs: Keyword arguments forwarded to `dataclass`.

    Returns:
        The class returned by `dataclasses.dataclass`.
    """
    cls_annotations = get_type_hints(cls, include_extras=True)
    struct_format = [mode]
    fields: list[_FieldInfo] = []
    init = dataclass_kwargs.get("init", True)
    for name, annotation in cls_annotations.items():
        if annotation is _KW_ONLY_MARKER:
            # KW_ONLY is handled by stdlib dataclass, nothing to do on our end.
            continue
        if annotation is ClassVar or get_origin(annotation) is ClassVar:
            # The stdlib dataclass excludes ClassVar annotations from its
            # fields (no __init__ parameter, no per-instance value), so they
            # are not part of the packed layout either.
            continue

        fmt, field_info = _validate_and_parse_field(
            cls,
            name=name,
            field_type=annotation,
            is_native=is_native,
            validate_defaults=validate_defaults,
            mode=mode,
            init=init,
        )
        struct_format.append(fmt)
        fields.append(field_info)

    setattr(  # noqa: B010
        cls,
        "__dataclass_struct__",
        DataclassStructInternal("".join(struct_format), cls, fields),
    )
    setattr(cls, "pack", _make_pack_method())  # noqa: B010
    setattr(cls, "from_packed", _make_unpack_method(cls))  # noqa: B010

    new_cls = dataclasses.dataclass(cls, **dataclass_kwargs)

    # Consider a field inherited from a base dataclass that was declared
    # with field(init=False). It is no longer a dataclasses.Field attribute
    # on `cls`: dataclass processing of the base replaced that class
    # attribute with its default, or deleted it. So the per-field check
    # above treats it as an __init__ argument. dataclasses.fields() is
    # authoritative, so mark such fields to be assigned after construction
    # when unpacking. The `fields` list is shared with __dataclass_struct__.
    init_by_name = {f.name: f.init for f in dataclasses.fields(new_cls)}
    for field_info in fields:
        if not init_by_name.get(field_info.name, True):
            field_info.init = False

    return new_cls

583 

584 

# Keyword arguments of dataclasses.dataclass() that are available on every
# supported Python version. Every key is optional (total=False).
_DataclassKwargsPre310 = TypedDict(
    "_DataclassKwargsPre310",
    {
        "init": bool,
        "repr": bool,
        "eq": bool,
        "order": bool,
        "unsafe_hash": bool,
        "frozen": bool,
    },
    total=False,
)

592 

593 

if sys.version_info < (3, 10):

    class DataclassKwargs(_DataclassKwargsPre310, total=False):
        """Keyword arguments forwarded to the stdlib `dataclass` decorator."""

else:

    class DataclassKwargs(_DataclassKwargsPre310, total=False):
        """Keyword arguments forwarded to the stdlib `dataclass` decorator."""

        # Parameters added to dataclasses.dataclass() in Python 3.10.
        match_args: bool
        kw_only: bool

603 

604 

@overload
def dataclass_struct(
    *,
    size: Literal["native"] = "native",
    byteorder: Literal["native"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]: ...


@overload
def dataclass_struct(
    *,
    size: Literal["std"],
    byteorder: Literal["native", "big", "little", "network"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]: ...


@dataclass_transform()
def dataclass_struct(
    *,
    size: Literal["native", "std"] = "native",
    byteorder: Literal["native", "big", "little", "network"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]:
    """Create a dataclass struct.

    Should be used as a decorator on a class:

    ```python
    import dataclasses_struct as dcs

    @dcs.dataclass_struct()
    class A:
        data: dcs.Pointer
        size: dcs.UnsignedSize
    ```

    The allowed `size` and `byteorder` argument combinations are as follows.

    | `size`     | `byteorder` | Notes                                                               |
    | ---------- | ----------- | ------------------------------------------------------------------- |
    | `"native"` | `"native"`  | The default. Native alignment and padding.                          |
    | `"std"`    | `"native"`  | Standard integer sizes and system endianness, no alignment/padding. |
    | `"std"`    | `"little"`  | Standard integer sizes and little endian, no alignment/padding.     |
    | `"std"`    | `"big"`     | Standard integer sizes and big endian, no alignment/padding.        |
    | `"std"`    | `"network"` | Equivalent to `byteorder="big"`.                                    |

    Args:
        size: The size mode.
        byteorder: The byte order of the generated struct. If `size="native"`,
            only `"native"` is allowed.
        validate_defaults: Whether to validate the default values of any
            fields.
        dataclass_kwargs: Any additional keyword arguments to pass to the
            [stdlib
            `dataclass`](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass)
            decorator. The `slots` and `weakref_slot` keyword arguments are not
            supported.

    Raises:
        ValueError: If the `size` and `byteorder` args are invalid, if the
            unsupported `slots` or `weakref_slot` keyword argument is given,
            or if `validate_defaults=True` and any of the fields' default
            values are invalid for their type (e.g. too long).
        TypeError: If any of the fields' type annotations are invalid or
            not supported, or if `validate_defaults=True` and a field's
            default value is not an instance of the field's type.

    Errors about the arguments above are raised when `dataclass_struct(...)`
    is called. Errors about the fields are raised when the returned decorator
    is applied to a class.
    """  # noqa: E501
    is_native = size == "native"
    if is_native:
        if byteorder != "native":
            raise ValueError("'native' size requires 'native' byteorder")
    elif size != "std":
        raise ValueError(f"invalid size: {size}")
    if byteorder not in ("native", "big", "little", "network"):
        raise ValueError(f"invalid byteorder: {byteorder}")

    for kwarg in ("slots", "weakref_slot"):
        if kwarg in dataclass_kwargs:
            msg = f"dataclass '{kwarg}' keyword argument is not supported"
            raise ValueError(msg)

    def decorator(cls: type) -> type:
        return _make_class(
            cls,
            mode=_SIZE_BYTEORDER_MODE_CHAR[(size, byteorder)],
            is_native=is_native,
            validate_defaults=validate_defaults,
            dataclass_kwargs=dataclass_kwargs,
        )

    return decorator