Coverage for src/configuraptor/binary_config.py: 100%
108 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-18 16:37 +0200
"""
Logic to do with parsing bytestrings as configuration (using struct).
"""
5import collections
6import struct
7import typing
8from dataclasses import dataclass
9from io import BytesIO
10from pathlib import Path
12from typing_extensions import Self
14from . import loaders
15from .abs import AbstractTypedConfig
16from .helpers import is_custom_class
17from .loaders.register import DUMPERS
# Union of the plain Python types used as value types in the dict returned by BinaryConfig._parse.
BINARY_TYPES = typing.Union[str, float, int, bool]
class BinaryConfig(AbstractTypedConfig):
    """
    Inherit this class if you want your config or a section of it to be parsed using struct.

    Fields are declared as public class attributes holding a _BinaryField (as returned by
    BinaryField). The order in which they appear in the class body is the order used for
    the struct format and therefore for the packed bytes.
    """

    # name -> field description, in declaration order; filled in by __init__
    _fields: collections.OrderedDict[str, "_BinaryField"]

    def __init__(self) -> None:
        """
        Before filling the class with data, we store the fields (BinaryField) for later use.
        """
        fields, elements = self._collect_fields()
        self._fields = collections.OrderedDict(zip(fields, elements))
        super().__init__()

    @classmethod
    def _collect_fields(cls) -> tuple[list[str], list["_BinaryField"]]:
        """
        Get the class' field names and dataclass instances.

        Only attributes found in ``cls.__dict__`` are inspected; names starting with an
        underscore and values that are not a _BinaryField are skipped.

        Returns:
            two parallel lists: the field names and their _BinaryField objects.
        """
        elements: list[_BinaryField] = []
        fields: list[str] = []

        for field, value in cls.__dict__.items():
            if field.startswith("_"):
                continue
            if not isinstance(value, _BinaryField):
                # other data, skip
                continue

            fields.append(field)
            elements.append(value)

        return fields, elements

    @classmethod
    def _parse(cls, data: bytes | dict[str, bytes]) -> dict[str, BINARY_TYPES]:
        """
        Parse a bytestring or a dict of bytestrings (in the right order).

        Args:
            data: either the full bytestring, or a mapping of field name to that field's bytes.
                  A mapping is joined in field order before unpacking.

        Returns:
            a dict of field name to converted value.

        Raises:
            KeyError: if `data` is a dict that lacks one of the fields.
            struct.error: if the data does not match the struct format of the fields.
        """
        # imported here instead of at module level (presumably to avoid a circular import - confirm)
        from .core import load_into

        # NOTE: annotations not used!
        fields, elements = cls._collect_fields()

        if isinstance(data, dict):
            # create one long bytestring of data in the right order:
            data = b"".join(data[field] for field in fields)

        unpacked = struct.unpack(" ".join(str(_) for _ in elements), data)
        final_data: dict[str, BINARY_TYPES] = {}

        zipped: typing.Iterable[tuple[str, typing.Any, _BinaryField]] = zip(fields, unpacked, elements)
        for field, value, meta in zipped:
            if isinstance(value, bytes) and not issubclass(meta.klass, BinaryConfig):
                # plain string data: drop the NUL bytes and decode to str.
                # Nested BinaryConfig data stays bytes so it can be parsed by its own class.
                value = value.strip(b"\x00").decode()

            if meta.special:
                # e.g. load from JSON
                value = meta.special(value)

            # ensure it's the right class (e.g. bool):
            value = load_into(meta.klass, value) if is_custom_class(meta.klass) else meta.klass(value)

            final_data[field] = value

        return final_data

    @classmethod
    def _parse_into(cls, data: bytes | dict[str, bytes]) -> Self:
        """
        Create a new instance based on data.

        The parsed values are written straight into the instance's ``__dict__``,
        so ``__setattr__`` is not invoked for them.
        """
        converted = cls._parse(data)
        inst = cls()
        inst.__dict__.update(**converted)
        return inst

    def _pack(self) -> bytes:
        """
        Pack an instance back into a bytestring.

        Values are gathered in declared field order so they line up with the struct format,
        regardless of the order in which the attributes were assigned on the instance.
        Instance attributes that are not declared fields are ignored.

        Raises:
            struct.error: if a field has no value on the instance or a value does not fit its format.
        """
        fmt = " ".join(str(_) for _ in self._fields.values())

        own = self.__dict__
        # Follow self._fields (format order) instead of __dict__ insertion order.
        # Fields without an instance value are left out, so struct.pack reports the mismatch.
        values = [meta.pack(own[name]) for name, meta in self._fields.items() if name in own]

        return struct.pack(fmt, *values)

    @classmethod
    def _format(cls) -> str:
        """
        Build the struct format string for this class: one '<length><fmt>' item per field, space-separated.
        """
        _, fields = cls._collect_fields()

        return " ".join(str(_) for _ in fields)

    @classmethod
    def _get_length(cls) -> int:
        """
        How many bytes do the fields of this class have?
        """
        fmt = cls._format()

        return struct.calcsize(fmt)

    def __setattr__(self, key: str, value: typing.Any) -> None:
        """
        When setting a new field for this config, update the _fields property to have the correct new type + size.

        Only applies when a BinaryConfig instance is assigned to a public attribute.

        Raises:
            KeyError: if a BinaryConfig is assigned to a public name that is not a declared field.
        """
        if not key.startswith("_") and isinstance(value, BinaryConfig):
            # NOTE: the objects in self._fields are the ones collected from the class' __dict__,
            # so this change is visible to the class and to every other instance as well.
            field = self._fields[key]
            field.klass = value.__class__
            field.length = value.__class__._get_length()

        return super().__setattr__(key, value)
137@dataclass
138class _BinaryField:
139 """
140 Class that stores info to parse the value from a bytestring.
142 Returned by BinaryField, but overwritten on instances with the actual value of type klass.
143 """
145 klass: typing.Type[typing.Any]
146 length: int
147 fmt: str
148 special: typing.Callable[[typing.Any], dict[str, typing.Any]] | None
149 packer: typing.Callable[[typing.Any], typing.Any] | None
151 def __str__(self) -> str:
152 return f"{self.length}{self.fmt}"
154 def pack(self, value: typing.Any) -> typing.Any:
155 if self.packer:
156 value = self.packer(value)
157 if isinstance(value, str):
158 return value.encode()
159 if isinstance(value, BinaryConfig):
160 return value._pack()
161 return value
# Type variable that lets BinaryField(klass) be typed as returning an instance of klass.
T = typing.TypeVar("T")
# https://docs.python.org/3/library/struct.html
# DEFAULT_LENGTHS = {
#     "x": 1,
#     "c": 1,
#     "b": 1,
#     "?": 1,
#     "h": 2,
#     "H": 2,
#     "i": 4,
#     "I": 4,
#     "l": 4,
#     "L": 4,
#     "q": 8,
#     "Q": 8,
#     "n": 8,
#     "N": 8,
#     "e": 2,
#     "f": 4,
#     "d": 8,
#     "s": 1,
#     "p": 1,
#     "P": 8,
# }
# Struct format character used for a type when BinaryField is not given an explicit format.
DEFAULT_FORMATS: dict[type, str] = {
    str: "s",
    int: "i",
    float: "f",
    bool: "?",  # b
}
def BinaryField(klass: typing.Type[T], **kw: typing.Any) -> T:
    """
    Fields for BinaryConfig can not be annotated like a regular typed config,
    because more info is required (such as struct format/type and length).

    This actually returns a _BinaryField but when using load/load_into, the value will be replaced with type 'klass'.

    Args:
        klass (type): the final type the value will have
        format (str): either one of the formats of struct (e.g. 10s) or a loadable format (json, toml, yaml etc.)
        length (int): how many bytes of data to store? (defaults to 1; mandatory for loadable formats;
            ignored when the format already contains a length such as '10s')

    Raises:
        KeyError: if no format is given and klass has no default format,
            or if a loadable format is used without a length.

    Usage:
        class MyConfig(BinaryConfig):
            string = BinaryField(str, length=5) # string of 5 characters
            integer = BinaryField(int)
            complex = BinaryField(OtherClass, format='json', length=64)
            # will extract 64 bytes of string and try to convert to the linked class
            # (using regular typeconfig logic)
    """
    special = None
    packer = None

    if issubclass(klass, BinaryConfig):
        fmt = "s"  # temporarily group as one string
        # only compute the size of the nested config when no explicit length is given
        length = kw["length"] if "length" in kw else klass._get_length()
    else:
        fmt = kw.get("format")
        if not fmt:
            try:
                fmt = DEFAULT_FORMATS[klass]
            except KeyError:
                raise KeyError(f"no default struct format for {klass!r}, please pass format=...") from None

        if loader := loaders.get(fmt, None):
            if "length" not in kw:
                # still a KeyError (as before), but one that says what is missing
                raise KeyError(f"length is required for a field with format {fmt!r}")

            # the loader is called with a file-like object, so wrap the raw data in BytesIO
            special = lambda data: loader(  # noqa: E731
                BytesIO(data if isinstance(data, bytes) else data.encode()), Path()
            )
            if _packer := DUMPERS.get(fmt, None):
                packer = lambda data: _packer(data, with_top_level_key=False)  # noqa: E731
            length = kw["length"]
            # loadable formats are stored as a plain string of `length` bytes
            fmt = "s"
        elif len(fmt) > 1:
            # length in format: 10s
            length, fmt = int(fmt[:-1]), fmt[-1]
        else:
            length = kw.get("length", 1)

    field = _BinaryField(
        klass,
        fmt=fmt,
        length=length,
        special=special,
        packer=packer,
    )

    # typed as T so attribute access on config instances type-checks as the final value type
    return typing.cast(T, field)