Coverage for dataclasses_struct / dataclass.py: 97%
252 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 22:31 +1200
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 22:31 +1200
1import dataclasses
2import sys
3from collections.abc import Generator, Iterator
4from struct import Struct
5from typing import (
6 Annotated,
7 Any,
8 Callable,
9 ClassVar,
10 Generic,
11 Literal,
12 Protocol,
13 TypedDict,
14 TypeVar,
15 Union,
16 get_args,
17 get_origin,
18 get_type_hints,
19 overload,
20)
22from ._typing import Buffer, TypeGuard, Unpack, dataclass_transform
23from .field import Field, builtin_fields
24from .types import CString, PadAfter, PadBefore
if sys.version_info < (3, 10):
    # ``dataclasses.KW_ONLY`` does not exist before Python 3.10. Substitute a
    # private sentinel instance instead: user annotations will not be this
    # exact object, so the identity check against ``_KW_ONLY_MARKER`` made
    # while building a class never matches on these versions.
    class _KW_ONLY_MARKER_TYPE:
        pass

    _KW_ONLY_MARKER = _KW_ONLY_MARKER_TYPE()
else:
    # Python 3.10+: the real marker used in ``_: KW_ONLY`` annotations.
    from dataclasses import KW_ONLY as _KW_ONLY_MARKER
def _separate_padding_from_annotation_args(args) -> tuple[int, int, object]:
    """Split ``Annotated`` metadata into padding totals and one extra item.

    Args:
        args: The metadata items of an ``Annotated[...]`` type, i.e. every
            element of ``get_args(annotation)`` after the first.

    Returns:
        ``(pad_before, pad_after, extra)``: the summed ``size`` of all
        ``PadBefore`` items, the summed ``size`` of all ``PadAfter`` items,
        and the single non-padding item (``None`` if there was none).

    Raises:
        TypeError: If a second non-padding item is encountered.
    """
    before_total = 0
    after_total = 0
    # The non-padding item is expected to be a Field, or an integer length
    # for bytes/list annotations.
    extra = None
    for item in args:
        if isinstance(item, PadBefore):
            before_total += item.size
            continue
        if isinstance(item, PadAfter):
            after_total += item.size
            continue
        if extra is not None:
            raise TypeError(f"too many annotations: {item}")
        extra = item
    return before_total, after_total, extra
def _format_str_with_padding(fmt: str, pad_before: int, pad_after: int) -> str:
    """Surround a struct format string with pad-byte (``x``) directives.

    Args:
        fmt: The format string of the field itself.
        pad_before: Pad bytes to emit before ``fmt``; 0 emits nothing.
        pad_after: Pad bytes to emit after ``fmt``; 0 emits nothing.

    Returns:
        ``fmt`` with ``"<n>x"`` prepended and/or appended for each non-zero
        pad count.
    """
    prefix = f"{pad_before}x" if pad_before else ""
    suffix = f"{pad_after}x" if pad_after else ""
    return prefix + fmt + suffix
# Generic type variable shared by the struct internals and helper signatures.
T = TypeVar("T")

# ``(size, byteorder)`` as accepted by ``dataclass_struct`` -> the ``struct``
# module's mode character (the first character of a format string).
_SIZE_BYTEORDER_MODE_CHAR: dict[tuple[str, str], str] = {
    ("native", "native"): "@",
    ("std", "native"): "=",
    ("std", "little"): "<",
    ("std", "big"): ">",
    ("std", "network"): "!",
}

# Reverse lookup: mode character -> ``(size, byteorder)``. All mode
# characters above are distinct, so inverting the table loses nothing.
_MODE_CHAR_SIZE_BYTEORDER: dict[str, tuple[str, str]] = {
    mode_char: size_and_order
    for size_and_order, mode_char in _SIZE_BYTEORDER_MODE_CHAR.items()
}
@dataclasses.dataclass
class _FieldInfo:
    """Per-field metadata kept by ``DataclassStructInternal``.

    Instances are constructed positionally
    (``_FieldInfo(name, field, type_, init)``), so the attribute order below
    is part of the interface.
    """

    # Attribute name of the field on the decorated class.
    name: str
    # Resolved struct field that provides the format and validation.
    field: Field[Any]
    # Python type of the value (the first argument of ``Annotated`` when the
    # annotation was ``Annotated``); used to rebuild values when unpacking.
    type_: type
    # Whether the value is passed to ``__init__`` when unpacking; if False it
    # is assigned onto the instance after construction instead.
    init: bool
class DataclassStructInternal(Generic[T]):
    """Packing/unpacking machinery stored on a dataclass-struct class.

    ``dataclass_struct`` attaches one instance to each decorated class as
    ``__dataclass_struct__``. It holds the compiled `struct.Struct` and the
    ordered per-field metadata used to flatten an object into packed values
    and to rebuild an object from unpacked values.
    """

    struct: Struct
    cls: type[T]
    _fields: list[_FieldInfo]

    @property
    def format(self) -> str:
        """
        The format string used by the `struct` module to pack/unpack data.

        See https://docs.python.org/3/library/struct.html#format-strings.
        """
        return self.struct.format

    @property
    def size(self) -> int:
        """Size of the packed representation in bytes."""
        return self.struct.size

    @property
    def mode(self) -> str:
        """
        The `struct` mode character that determines size, alignment, and
        byteorder.

        This is the first character of the `format` field. See
        https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
        for more info.
        """
        return self.format[0]

    def __init__(
        self,
        fmt: str,
        cls: type,
        fields: list[_FieldInfo],
    ):
        self.struct = Struct(fmt)
        self.cls = cls
        self._fields = fields

    def _flattened_attrs(self, outer_self: T) -> list[Any]:
        """
        Returns a list of all attributes of `outer_self`, including those of
        any nested structs.
        """
        attrs: list[Any] = []
        for field in self._fields:
            attr = getattr(outer_self, field.name)
            self._flatten_attr(attrs, attr)
        return attrs

    @staticmethod
    def _flatten_attr(attrs: list[Any], attr: object) -> None:
        """Append the packable value(s) of ``attr`` to ``attrs``.

        Nested dataclass-structs are expanded into their own flattened
        attributes and lists are expanded item by item (recursively); any
        other value is appended as-is.
        """
        if is_dataclass_struct(attr):
            attrs.extend(attr.__dataclass_struct__._flattened_attrs(attr))
        elif isinstance(attr, list):
            for sub_attr in attr:
                DataclassStructInternal._flatten_attr(attrs, sub_attr)
        else:
            attrs.append(attr)

    def _pack(self, obj: T) -> bytes:
        """Pack ``obj`` using this struct's format."""
        return self.struct.pack(*self._flattened_attrs(obj))

    def _arg_generator(self, args: Iterator) -> Generator:
        """Yield one reconstructed value per field, consuming ``args``."""
        for field in self._fields:
            yield from DataclassStructInternal._generate_args_recursively(
                args, field.field, field.type_
            )

    @staticmethod
    def _generate_args_recursively(
        args: Iterator,
        field: Field[Any],
        field_type: type,
    ) -> Generator:
        """Yield the value of a single field, consuming as many items of
        ``args`` as the field occupies in the packed data.

        Nested structs are rebuilt into instances, fixed-length arrays into
        lists of ``field.n`` items, and C strings are truncated at the first
        NUL byte; any other value is converted with ``field_type(...)``.
        """
        if is_dataclass_struct(field_type):
            yield field_type.__dataclass_struct__._init_from_args(args)
        elif isinstance(field, _FixedLengthArrayField):
            items: list = []
            for _ in range(field.n):
                items.extend(
                    DataclassStructInternal._generate_args_recursively(
                        args, field.item_field, field.item_type
                    )
                )
            yield items
        elif isinstance(field, CString):
            data: bytes = next(args)
            pos = data.find(0)
            yield data if pos < 0 else data[:pos]
        else:
            yield field_type(next(args))

    def _init_from_args(self, args: Iterator) -> T:
        """
        Returns an instance of self.cls, consuming args.

        Fields marked ``init=False`` are not passed to the constructor; they
        are assigned onto the new instance afterwards.
        """
        kwargs = {}
        no_init_args = {}

        for field, arg in zip(self._fields, self._arg_generator(args)):
            if field.init:
                kwargs[field.name] = arg
            else:
                no_init_args[field.name] = arg

        obj = self.cls(**kwargs)
        for name, arg in no_init_args.items():
            try:
                setattr(obj, name, arg)
            except dataclasses.FrozenInstanceError:
                # Frozen dataclasses reject normal attribute assignment. Set
                # the value the same way the stdlib-generated ``__init__``
                # does for frozen instances. Non-frozen classes still go
                # through their normal ``__setattr__``.
                object.__setattr__(obj, name, arg)
        return obj

    def _unpack(self, data: Buffer) -> T:
        """Unpack ``data`` and build an instance of ``self.cls`` from it."""
        return self._init_from_args(iter(self.struct.unpack(data)))
class DataclassStructProtocol(Protocol):
    """Structural type of a class decorated with ``dataclass_struct``.

    ``dataclass_struct`` sets ``__dataclass_struct__``, ``pack`` and
    ``from_packed`` on the decorated class. This protocol declares those
    members so type checkers know about them. The method bodies here are
    stubs (``...``).
    """

    __dataclass_struct__: ClassVar[DataclassStructInternal]
    """
    Internal data used by the library for packing and unpacking structs.

    See
    [`DataclassStructInternal`][dataclasses_struct.DataclassStructInternal].
    """

    @classmethod
    def from_packed(cls: type[T], data: Buffer) -> T:
        """Return an instance of the class from its packed representation.

        Args:
            data: The packed representation of the class as returned by
                [`pack`][dataclasses_struct.dataclass.DataclassStructProtocol.pack].

        Returns:
            An instance of the class unpacked from `data`.

        Raises:
            struct.error: If `data` is the wrong length.
        """
        ...

    def pack(self) -> bytes:
        """Return the packed representation in `bytes` of the object.

        Returns:
            The packed representation. Can be used to instantiate a new object
            with
            [`from_packed`][dataclasses_struct.dataclass.DataclassStructProtocol.from_packed].

        Raises:
            struct.error: If any of the fields are out of range or the wrong
                type.
        """
        ...
@overload
def is_dataclass_struct(
    obj: type,
) -> TypeGuard[type[DataclassStructProtocol]]: ...


@overload
def is_dataclass_struct(obj: object) -> TypeGuard[DataclassStructProtocol]: ...


def is_dataclass_struct(
    obj: Union[type, object],
) -> Union[
    TypeGuard[DataclassStructProtocol],
    TypeGuard[type[DataclassStructProtocol]],
]:
    """Determine whether a type or object is a dataclass-struct.

    Args:
        obj: Any class or object.

    Returns:
        `True` if obj is a class that has been decorated with
        [`dataclass_struct`][dataclasses_struct.dataclass_struct] or is an
        instance of one: it must be a dataclass carrying a
        ``__dataclass_struct__`` attribute of type ``DataclassStructInternal``.
        `False` otherwise.
    """
    if not dataclasses.is_dataclass(obj):
        return False
    if not hasattr(obj, "__dataclass_struct__"):
        return False
    return isinstance(obj.__dataclass_struct__, DataclassStructInternal)
def get_struct_size(cls_or_obj: object) -> int:
    """Get the size of the packed representation of the struct in bytes.

    Args:
        cls_or_obj: A class that has been decorated with
            [`dataclass_struct`][dataclasses_struct.dataclass_struct] or an
            instance of one.

    Returns:
        The size of the packed representation in bytes.

    Raises:
        TypeError: if `cls_or_obj` is not a dataclass-struct.
    """
    if is_dataclass_struct(cls_or_obj):
        return cls_or_obj.__dataclass_struct__.size
    raise TypeError(f"{cls_or_obj} is not a dataclass_struct")
class _BytesField(Field[bytes]):
    """Fixed-length ``bytes`` field, packed with the ``<n>s`` format code."""

    field_type = bytes

    def __init__(self, n: object):
        # ``n`` comes straight from ``Annotated`` metadata, so it may be any
        # object; only positive integers are accepted.
        if isinstance(n, int) and n >= 1:
            self.n = n
        else:
            raise ValueError("bytes length must be positive non-zero int")

    def format(self) -> str:
        return f"{self.n}s"

    def validate_default(self, val: bytes) -> None:
        """Reject default values longer than the field's length."""
        if len(val) <= self.n:
            return
        raise ValueError(f"bytes cannot be longer than {self.n} bytes")

    def __repr__(self) -> str:
        base = super().__repr__()
        return f"{base}({self.n})"
class _NestedField(Field):
    """Field whose value is an instance of another dataclass-struct class."""

    field_type: type[DataclassStructProtocol]

    def __init__(self, cls: type[DataclassStructProtocol]):
        self.field_type = cls

    def format(self) -> str:
        # The nested struct's format begins with its own mode character.
        # Drop it so the nested fields are embedded in the container's format.
        nested_format = self.field_type.__dataclass_struct__.format
        return nested_format[1:]
class _FixedLengthArrayField(Field[list]):
    """Fixed-length list field: ``Annotated[list[<item type>], <n>]``.

    The item annotation is resolved with the same rules as a top-level field,
    including its own padding. The padded item format is then repeated ``n``
    times.
    """

    field_type = list

    def __init__(self, item_type_annotation: Any, mode: str, n: object):
        valid_length = isinstance(n, int) and n >= 1
        if not valid_length:
            raise ValueError(
                "fixed-length array length must be positive non-zero int"
            )

        (
            self.item_field,
            self.item_type,
            self.pad_before,
            self.pad_after,
        ) = _resolve_field(item_type_annotation, mode)
        self.n = n
        # The array supports a size mode exactly when its item type does.
        self.is_native = self.item_field.is_native
        self.is_std = self.item_field.is_std

    def format(self) -> str:
        single_item = _format_str_with_padding(
            self.item_field.format(), self.pad_before, self.pad_after
        )
        return single_item * self.n

    def __repr__(self) -> str:
        return "{}({!r}, {})".format(super().__repr__(), self.item_field, self.n)

    def validate_default(self, val: list) -> None:
        """Require exactly ``n`` items, each valid for the item field."""
        length = len(val)
        if length != self.n:
            msg = f"fixed-length array must have length of {self.n}, got {length}"
            raise ValueError(msg)

        for item in val:
            _validate_field_default(self.item_field, item)
def _validate_modes_match(mode: str, nested_mode: str) -> None:
    """Ensure a nested dataclass-struct uses the container's mode character.

    Args:
        mode: Mode character of the containing struct.
        nested_mode: Mode character of the nested struct.

    Raises:
        TypeError: If the two differ. The message names the expected and the
            actual size/byteorder.
    """
    if mode == nested_mode:
        return

    got_size, got_order = _MODE_CHAR_SIZE_BYTEORDER[nested_mode]
    want_size, want_order = _MODE_CHAR_SIZE_BYTEORDER[mode]
    raise TypeError(
        "byteorder and size of nested dataclass-struct does not "
        f"match that of container (expected '{want_size}' size and "
        f"'{want_order}' byteorder, got '{got_size}' size and "
        f"'{got_order}' byteorder)"
    )
def _resolve_field(
    annotation: Any,
    mode: str,
) -> tuple[Field[Any], type, int, int]:
    """
    Returns 4-tuple of:
    * field
    * type
    * number of padding bytes before
    * number of padding bytes after

    Valid type annotations are:

    1. <bool | int | float | bytes> | Annotated[<bool | int | float | bytes>, <padding>]

        Supported builtin types.

    2. Annotated[<bool | int | float | bytes>, Field(...), <padding>]

        (These are the types defined in dataclasses_struct.types e.g. U32, F32).

    3. <dataclasses_struct class> | Annotated[<dataclasses_struct class>, <padding>]

        Must have the same size and byteorder as the container.

    4. Annotated[bytes, <n>, <padding>]

        Where <n> is >0.

    5. Annotated[list[<type>], <n>, <padding>]

        Where <n> is >0 and <type> is one of the above.

    <padding> is an optional mixture of PadBefore and PadAfter annotations,
    which may be repeated. E.g.

        Annotated[int, PadBefore(5), PadAfter(2), PadBefore(3)]

    Raises:
        TypeError: If the annotation is not one of the forms above, a list
            annotation does not have exactly one item type, or a nested
            struct's size/byteorder differs from ``mode``.
        ValueError: If a bytes or list length is not a positive int.
    """  # noqa: E501
    if get_origin(annotation) is Annotated:
        type_, *args = get_args(annotation)
        pad_before, pad_after, annotation_arg = (
            _separate_padding_from_annotation_args(args)
        )
    else:
        pad_before = pad_after = 0
        type_ = annotation
        annotation_arg = None

    field: Field[Any]
    if annotation_arg is None:
        if get_origin(type_) is list:
            msg = (
                "list types must be marked as a fixed-length using "
                "Annotated, ex: Annotated[list[int], 5]"
            )
            raise TypeError(msg)

        # Must be either a nested type or one of the supported builtins
        if is_dataclass_struct(type_):
            _validate_modes_match(mode, type_.__dataclass_struct__.mode)
            field = _NestedField(type_)
        else:
            opt_field = builtin_fields.get(type_)
            if opt_field is None:
                raise TypeError(f"type not supported: {annotation}")
            field = opt_field
    elif isinstance(annotation_arg, Field) and issubclass(
        type_, annotation_arg.field_type
    ):
        field = annotation_arg
    elif get_origin(type_) is list:
        item_annotations = get_args(type_)
        # Raise a real TypeError rather than ``assert``: asserts are stripped
        # under ``-O``, and e.g. a bare ``typing.List`` has no item type.
        if len(item_annotations) != 1:
            msg = (
                "fixed-length list type must have exactly one item type, "
                f"ex: Annotated[list[int], 5], got {type_!r}"
            )
            raise TypeError(msg)
        field = _FixedLengthArrayField(
            item_annotations[0], mode, annotation_arg
        )
    elif isinstance(type_, type) and issubclass(type_, bytes):
        # The isinstance guard keeps non-class annotations (e.g. a Union)
        # from crashing issubclass with an unhelpful message; they fall
        # through to the descriptive error below instead.
        field = _BytesField(annotation_arg)
    else:
        raise TypeError(f"invalid field annotation: {annotation!r}")

    return field, type_, pad_before, pad_after
def _get_default_from_dataclasses_field(field: dataclasses.Field) -> Any:
    """Return the default value a ``dataclasses.field(...)`` would produce.

    An explicit ``default`` takes precedence. Otherwise ``default_factory`` is
    called, producing a fresh value on every call. Returns
    ``dataclasses.MISSING`` when the field declares neither.
    """
    default = field.default
    if default is not dataclasses.MISSING:
        return default
    factory = field.default_factory
    return dataclasses.MISSING if factory is dataclasses.MISSING else factory()
def _validate_field_default(field: Field[T], val: Any) -> None:
    """Check that ``val`` is an acceptable default value for ``field``.

    The type check comes first. The field's own ``validate_default`` hook
    then runs only for values of the right type.

    Raises:
        TypeError: If ``val`` is not an instance of ``field.field_type``.
        Anything raised by ``field.validate_default`` (e.g. ValueError for
        the bytes/list fields in this module).
    """
    expected = field.field_type
    if isinstance(val, expected):
        field.validate_default(val)
        return
    msg = (
        "invalid type for field: expected "
        f"{expected} got {type(val)}"
    )
    raise TypeError(msg)
def _validate_and_parse_field(
    cls: type,
    *,
    name: str,
    field_type: type,
    is_native: bool,
    validate_defaults: bool,
    mode: str,
    init: bool,
) -> tuple[str, _FieldInfo]:
    """Resolve one annotated attribute of ``cls`` into struct metadata.

    Args:
        cls: The class being decorated (before ``dataclasses.dataclass`` runs).
        name: The attribute name.
        field_type: The attribute's (possibly ``Annotated``) type hint.
        is_native: Whether the struct uses native size mode.
        validate_defaults: Whether to check the attribute's default value.
        mode: Struct mode character of the containing class.
        init: Class-level ``init`` setting. A per-field
            ``dataclasses.field(init=False)`` overrides it to ``False``.

    Returns:
        The field's format string (with padding) and its ``_FieldInfo``.

    Raises:
        TypeError: If the field is not usable in the requested size mode, or
            from annotation resolution or default validation.
    """
    field, type_, pad_before, pad_after = _resolve_field(field_type, mode)

    if is_native and not field.is_native:
        raise TypeError(
            f"field {field} only supported in standard size mode"
        )
    if not is_native and not field.is_std:
        raise TypeError(f"field {field} only supported in native size mode")

    include_in_init = init
    if hasattr(cls, name):
        declared = getattr(cls, name)
        is_dc_field = isinstance(declared, dataclasses.Field)
        if is_dc_field and not declared.init:
            include_in_init = False

        if validate_defaults:
            # A ``dataclasses.field(...)`` wraps the real default; a plain
            # class attribute is the default itself.
            default = (
                _get_default_from_dataclasses_field(declared)
                if is_dc_field
                else declared
            )
            if default is not dataclasses.MISSING:
                _validate_field_default(field, default)

    fmt = _format_str_with_padding(field.format(), pad_before, pad_after)
    return fmt, _FieldInfo(name, field, type_, include_in_init)
def _make_pack_method() -> Callable:
    """Build the ``pack`` method attached to every dataclass-struct class.

    The function is compiled from source with ``exec`` in an empty global
    namespace. The result is a plain function whose ``__name__`` and
    ``__qualname__`` are both ``pack``. It delegates to the class's
    ``__dataclass_struct__._pack``.
    """
    source = (
        "\n"
        "def pack(self) -> bytes:\n"
        "    '''Pack to bytes using struct.pack.'''\n"
        "    return self.__dataclass_struct__._pack(self)\n"
    )
    namespace: dict[str, Any] = {}
    exec(source, {}, namespace)
    return namespace["pack"]
def _make_unpack_method(cls: type) -> classmethod:
    """Build the ``from_packed`` classmethod attached to ``cls``.

    The function is compiled with ``exec``. ``cls_type`` (bound to ``cls``)
    and ``Buffer`` are supplied through the local namespace so that its
    annotations resolve at definition time. It delegates to the class's
    ``__dataclass_struct__._unpack``.
    """
    source = (
        "\n"
        "def from_packed(cls, data: Buffer) -> cls_type:\n"
        "    '''Unpack from bytes.'''\n"
        "    return cls.__dataclass_struct__._unpack(data)\n"
    )
    namespace: dict[str, Any] = {"cls_type": cls, "Buffer": Buffer}
    exec(source, {}, namespace)
    return classmethod(namespace["from_packed"])
def _make_class(
    cls: type,
    mode: str,
    is_native: bool,
    validate_defaults: bool,
    dataclass_kwargs,
) -> type[DataclassStructProtocol]:
    """Turn ``cls`` into a dataclass-struct and return the resulting dataclass.

    Each resolved type hint of ``cls`` (``Annotated`` extras included) becomes
    one struct field, in hint order. The joined format string, prefixed with
    ``mode``, is compiled into a ``DataclassStructInternal`` and stored as
    ``cls.__dataclass_struct__``. Then ``pack`` and ``from_packed`` are
    attached, and finally the stdlib ``dataclasses.dataclass`` is applied with
    ``dataclass_kwargs``.
    """
    hints = get_type_hints(cls, include_extras=True)
    default_init = dataclass_kwargs.get("init", True)

    format_parts = [mode]
    field_infos: list[_FieldInfo] = []
    for attr_name, hint in hints.items():
        # A KW_ONLY pseudo-field only shapes the stdlib dataclass's
        # ``__init__`` signature; it contributes nothing to the struct.
        if hint is _KW_ONLY_MARKER:
            continue

        fmt, info = _validate_and_parse_field(
            cls,
            name=attr_name,
            field_type=hint,
            is_native=is_native,
            validate_defaults=validate_defaults,
            mode=mode,
            init=default_init,
        )
        format_parts.append(fmt)
        field_infos.append(info)

    internal = DataclassStructInternal("".join(format_parts), cls, field_infos)
    setattr(cls, "__dataclass_struct__", internal)  # noqa: B010
    setattr(cls, "pack", _make_pack_method())  # noqa: B010
    setattr(cls, "from_packed", _make_unpack_method(cls))  # noqa: B010

    return dataclasses.dataclass(cls, **dataclass_kwargs)
class _DataclassKwargsPre310(TypedDict, total=False):
    """``dataclasses.dataclass`` keyword arguments available on every Python
    version this module supports, including 3.9.

    ``total=False`` makes every key optional. ``slots`` and ``weakref_slot``
    are deliberately absent; ``dataclass_struct`` rejects them.
    """

    init: bool
    repr: bool
    eq: bool
    order: bool
    unsafe_hash: bool
    frozen: bool
if sys.version_info < (3, 10):

    class DataclassKwargs(_DataclassKwargsPre310, total=False):
        """Keyword arguments forwarded to ``dataclasses.dataclass``.

        Python 3.9's ``dataclass`` has no ``match_args``/``kw_only``
        parameters, so only the common keys are available.
        """

else:

    class DataclassKwargs(_DataclassKwargsPre310, total=False):
        """Keyword arguments forwarded to ``dataclasses.dataclass``.

        Adds the ``match_args`` and ``kw_only`` parameters introduced in
        Python 3.10.
        """

        match_args: bool
        kw_only: bool
@overload
def dataclass_struct(
    *,
    size: Literal["native"] = "native",
    byteorder: Literal["native"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]: ...


@overload
def dataclass_struct(
    *,
    size: Literal["std"],
    byteorder: Literal["native", "big", "little", "network"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]: ...


@dataclass_transform()
def dataclass_struct(
    *,
    size: Literal["native", "std"] = "native",
    byteorder: Literal["native", "big", "little", "network"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]:
    """Create a dataclass struct.

    Should be used as a decorator on a class:

    ```python
    import dataclasses_struct as dcs

    @dcs.dataclass_struct()
    class A:
        data: dcs.Pointer
        size: dcs.UnsignedSize
    ```

    The allowed `size` and `byteorder` argument combinations are as follows.

    | `size`     | `byteorder` | Notes                                                               |
    | ---------- | ----------- | ------------------------------------------------------------------- |
    | `"native"` | `"native"`  | The default. Native alignment and padding.                          |
    | `"std"`    | `"native"`  | Standard integer sizes and system endianness, no alignment/padding. |
    | `"std"`    | `"little"`  | Standard integer sizes and little endian, no alignment/padding.     |
    | `"std"`    | `"big"`     | Standard integer sizes and big endian, no alignment/padding.        |
    | `"std"`    | `"network"` | Equivalent to `byteorder="big"`.                                    |

    Args:
        size: The size mode.
        byteorder: The byte order of the generated struct. If `size="native"`,
            only `"native"` is allowed.
        validate_defaults: Whether to validate the default values of any
            fields.
        dataclass_kwargs: Any additional keyword arguments to pass to the
            [stdlib
            `dataclass`](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass)
            decorator. The `slots` and `weakref_slot` keyword arguments are not
            supported.

    Raises:
        ValueError: Immediately, if the `size` and `byteorder` args are
            invalid or if the unsupported `slots` or `weakref_slot` keyword
            arguments are passed. When the returned decorator is applied, if
            a `bytes` or fixed-length `list` annotation has an invalid
            length, or if `validate_defaults=True` and a field's default
            value is invalid for its field (e.g. a `bytes` default that is
            too long).
        TypeError: When the returned decorator is applied, if any of the
            fields' type annotations are invalid or not supported, or if
            `validate_defaults=True` and a field's default value has the
            wrong type.
    """  # noqa: E501
    is_native = size == "native"
    if is_native:
        if byteorder != "native":
            raise ValueError("'native' size requires 'native' byteorder")
    elif size != "std":
        raise ValueError(f"invalid size: {size}")
    if byteorder not in ("native", "big", "little", "network"):
        raise ValueError(f"invalid byteorder: {byteorder}")

    for kwarg in ("slots", "weakref_slot"):
        if kwarg in dataclass_kwargs:
            msg = f"dataclass '{kwarg}' keyword argument is not supported"
            raise ValueError(msg)

    # Every (size, byteorder) pair that survives the checks above is a key of
    # the table, so the mode character is looked up once rather than on every
    # decoration.
    mode = _SIZE_BYTEORDER_MODE_CHAR[(size, byteorder)]

    def decorator(cls: type) -> type:
        return _make_class(
            cls,
            mode=mode,
            is_native=is_native,
            validate_defaults=validate_defaults,
            dataclass_kwargs=dataclass_kwargs,
        )

    return decorator