Coverage for src/configuraptor/binary_config.py: 100%
108 statements
« prev ^ index » next coverage.py v7.2.7, created at 2026-05-01 16:56 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2026-05-01 16:56 +0200
1"""
2Logic to do with parsing bytestrings as configuration (using struct).
3"""
5import collections
6import struct
7import typing
8from dataclasses import dataclass
9from io import BytesIO
10from pathlib import Path
11from typing import Self
13from . import loaders
14from .abs import AbstractTypedConfig
15from .helpers import is_custom_class
16from .loaders.register import DUMPERS
# Union of the plain value types used in the return annotation of BinaryConfig._parse.
BINARY_TYPES = typing.Union[(str, float, int, bool)]
class BinaryConfig(AbstractTypedConfig):
    """
    Inherit this class if you want your config or a section of it to be parsed using struct.

    Fields are declared as class attributes via BinaryField(...); their declaration order defines the layout of the
    bytestring. The generated struct format has no byte-order prefix, so struct's native byte order, size and
    alignment rules apply.
    """

    # field name -> parsing metadata, in declaration order
    _fields: collections.OrderedDict[str, "_BinaryField"]

    def __init__(self) -> None:
        """
        Before filling the class with data, we store the fields (BinaryField) for later use.
        """
        fields, elements = self._collect_fields()
        # NOTE(review): these are the very _BinaryField objects stored on the class, so __setattr__ mutating them
        # also changes the class-level definition (and later _parse/_get_length results). Confirm this is intended.
        self._fields = collections.OrderedDict(zip(fields, elements))
        super().__init__()

    @classmethod
    def _collect_fields(cls) -> tuple[list[str], list["_BinaryField"]]:
        """
        Get the class' field names and _BinaryField instances, in declaration order.

        Only attributes defined directly on this class (cls.__dict__) are considered; private names and values that
        are not _BinaryField instances are skipped.
        """
        elements: list[_BinaryField] = []
        fields: list[str] = []

        for field, value in cls.__dict__.items():
            if field.startswith("_"):
                continue
            if not isinstance(value, _BinaryField):
                # other data (methods, constants, ...), skip
                continue

            fields.append(field)
            elements.append(value)

        return fields, elements

    @classmethod
    def _parse(cls, data: bytes | dict[str, bytes]) -> dict[str, BINARY_TYPES]:
        """
        Parse a bytestring, or a dict mapping field names to their raw bytes, into a dict of converted values.

        When a dict is given, its values are concatenated in the class' field declaration order (the dict's own
        order does not matter), so every declared field must be present as a key.
        """
        from .core import load_into  # local import, presumably to avoid an import cycle with .core

        # NOTE: annotations not used! The layout comes from the BinaryField class attributes only.
        fields, elements = cls._collect_fields()

        if isinstance(data, dict):
            # create one long bytestring of data in the declared field order:
            data = b"".join(data[field] for field in fields)

        # NOTE(review): struct yields one item per value, so a non-'s' field with a count > 1 (e.g. "4i") would
        # produce extra items and shift every following field in the zip below.
        unpacked = struct.unpack(" ".join(str(element) for element in elements), data)
        final_data: dict[str, BINARY_TYPES] = {}

        zipped: typing.Iterable[tuple[str, typing.Any, _BinaryField]] = zip(fields, unpacked, elements)
        for field, value, meta in zipped:
            if isinstance(value, bytes) and not issubclass(meta.klass, BinaryConfig):
                # fixed-size string data is NUL-padded; this strips NULs from both ends before decoding to str
                value = value.strip(b"\x00").decode()

            if meta.special:
                # e.g. load from JSON
                value = meta.special(value)

            # ensure it's the right class (e.g. bool):
            value = load_into(meta.klass, value) if is_custom_class(meta.klass) else meta.klass(value)

            final_data[field] = value

        return final_data

    @classmethod
    def _parse_into(cls, data: bytes | dict[str, bytes]) -> Self:
        """
        Create a new instance based on data.

        The converted values are written straight into the instance __dict__, bypassing __setattr__.
        """
        converted = cls._parse(data)
        inst = cls()
        inst.__dict__.update(**converted)
        return inst

    def _pack(self) -> bytes:
        """
        Pack an instance back into a bytestring.

        Values are read from the instance in field declaration order (the same order as the struct format), no
        matter in which order they were assigned. Public attributes that are not binary fields are ignored.

        Raises:
            struct.error: if a declared field has no value on this instance, or a value does not fit its format.
        """
        fmt = " ".join(str(field) for field in self._fields.values())

        missing = [name for name in self._fields if name not in self.__dict__]
        if missing:
            # same exception type struct.pack raises when given the wrong number of values
            raise struct.error(f"no value set for binary field(s): {', '.join(missing)}")

        values = [field.pack(self.__dict__[name]) for name, field in self._fields.items()]

        return struct.pack(fmt, *values)

    @classmethod
    def _format(cls) -> str:
        """
        Build the struct format string for this class' fields, e.g. "5s 1i".
        """
        _, fields = cls._collect_fields()

        return " ".join(str(field) for field in fields)

    @classmethod
    def _get_length(cls) -> int:
        """
        How many bytes do the fields of this class have?
        """
        fmt = cls._format()

        return struct.calcsize(fmt)

    def __setattr__(self, key: str, value: typing.Any) -> None:
        """
        When assigning a nested BinaryConfig to a declared field, update that _fields entry to the new type + size.

        Private attributes, non-BinaryConfig values and public attributes that are not declared binary fields are
        set without touching _fields.
        """
        # check the cheap conditions first so _fields is only read when a BinaryConfig value is assigned
        if not key.startswith("_") and isinstance(value, BinaryConfig):
            field = self._fields.get(key)
            if field is not None:
                field.klass = value.__class__
                field.length = value.__class__._get_length()

        return super().__setattr__(key, value)
136@dataclass(slots=True)
137class _BinaryField:
138 """
139 Class that stores info to parse the value from a bytestring.
141 Returned by BinaryField, but overwritten on instances with the actual value of type klass.
142 """
144 klass: typing.Type[typing.Any]
145 length: int
146 fmt: str
147 special: typing.Callable[[typing.Any], dict[str, typing.Any]] | None
148 packer: typing.Callable[[typing.Any], typing.Any] | None
150 def __str__(self) -> str:
151 return f"{self.length}{self.fmt}"
153 def pack(self, value: typing.Any) -> typing.Any:
154 if self.packer:
155 value = self.packer(value)
156 if isinstance(value, str):
157 return value.encode()
158 if isinstance(value, BinaryConfig):
159 return value._pack()
160 return value
T = typing.TypeVar("T")

# Default struct format character per Python type, used by BinaryField when no explicit `format` is given.
# Byte sizes per format character are listed in the struct documentation:
# https://docs.python.org/3/library/struct.html#format-characters
DEFAULT_FORMATS = {
    str: "s",
    int: "i",
    float: "f",
    bool: "?",  # _Bool; "b" was noted here as a possible alternative
}
def BinaryField(klass: typing.Type[T], **kw: typing.Any) -> T:
    """
    Declare a field on a BinaryConfig subclass.

    Fields for BinaryConfig can not be annotated like a regular typed config, because more info is required (such as
    the struct format/type and length). This actually returns a _BinaryField, but when using load/load_into the
    value will be replaced with one of type 'klass'.

    Args:
        klass (type): the final type the value will have
        format (str): either one of the formats of struct (e.g. 10s) or a loadable format (json, toml, yaml etc.)
        length (int): number of bytes of data to store for string-like data. Required for loadable formats; defaults
            to 1 for str, and to the nested class' packed size for BinaryConfig subclasses. Ignored when the number
            is already part of `format`. For other struct formats struct treats it as a repeat count, and
            BinaryConfig expects one value per field, so leave it at the default of 1.

    Usage:
        class MyConfig(BinaryConfig):
            string = BinaryField(str, length=5)  # string of 5 characters
            integer = BinaryField(int)
            complex = BinaryField(OtherClass, format='json', length=64)
            # will extract 64 bytes of string and try to convert to the linked class
            # (using regular typeconfig logic)
    """
    special: typing.Callable[[typing.Any], typing.Any] | None = None
    packer: typing.Callable[[typing.Any], typing.Any] | None = None

    if issubclass(klass, BinaryConfig):
        # a nested config is (temporarily) treated as one string of its own packed size
        fmt, length = "s", kw.get("length", klass._get_length())
    else:
        fmt = kw.get("format") or DEFAULT_FORMATS[klass]
        loader = loaders.get(fmt, None)

        if loader:
            # loadable format (json, toml, ...): stored as raw 's' bytes and run through the loader after unpacking

            def _load(data: typing.Any) -> typing.Any:
                buffer = BytesIO(data if isinstance(data, bytes) else data.encode())
                return loader(buffer, Path())

            special = _load

            dumper = DUMPERS.get(fmt, None)
            if dumper:

                def _dump(data: typing.Any) -> typing.Any:
                    return dumper(data, with_top_level_key=False)

                packer = _dump

            length = kw["length"]
            fmt = "s"
        elif len(fmt) > 1:
            # length embedded in the format, e.g. "10s" -> length 10, format "s"
            length = int(fmt[:-1])
            fmt = fmt[-1]
        else:
            length = kw.get("length", 1)

    return typing.cast(T, _BinaryField(klass, fmt=fmt, length=length, special=special, packer=packer))