Coverage for src/configuraptor/binary_config.py: 100%
102 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-18 15:00 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-18 15:00 +0200
1"""
2Logic to do with parsing bytestrings as configuration (using struct).
3"""
5import collections
6import struct
7import typing
8from dataclasses import dataclass
9from io import BytesIO
10from pathlib import Path
12from typing_extensions import Self
14from . import loaders
15from .abs import AbstractTypedConfig
16from .helpers import is_custom_class
17from .loaders.register import DUMPERS
# Value types that BinaryConfig._parse declares in its return annotation.
BINARY_TYPES: typing.TypeAlias = typing.Union[str, float, int, bool]
class BinaryConfig(AbstractTypedConfig):
    """
    Inherit this class if you want your config or a section of it to be parsed using struct.

    Declare fields as class attributes created with BinaryField(...). Every public class attribute
    that is a _BinaryField becomes part of the struct layout, in class-definition order.
    """

    # Field name -> parsing/packing spec, in declaration order (filled in __init__).
    _fields: collections.OrderedDict[str, "_BinaryField"]

    def __init__(self) -> None:
        """
        Before filling the class with data, we store the fields (BinaryField) for later use.
        """
        fields, elements = self._collect_fields()
        self._fields = collections.OrderedDict(zip(fields, elements))
        super().__init__()

    @classmethod
    def _collect_fields(cls) -> tuple[list[str], list["_BinaryField"]]:
        """
        Get the class' field names and their _BinaryField specs, in definition order.

        Only attributes defined directly on `cls` are inspected (its own __dict__, not base
        classes); names starting with an underscore and values that are not _BinaryField are skipped.

        Returns:
            Two parallel lists: the field names and the matching _BinaryField instances.
        """
        elements: list[_BinaryField] = []
        fields: list[str] = []

        for field, value in cls.__dict__.items():
            if field.startswith("_"):
                continue
            if not isinstance(value, _BinaryField):
                # other data, skip
                continue

            fields.append(field)
            elements.append(value)

        return fields, elements

    @classmethod
    def _parse(cls, data: bytes | dict[str, bytes]) -> dict[str, BINARY_TYPES]:
        """
        Parse a bytestring or a dict of bytestrings (in the right order).

        Args:
            data: the packed struct bytes, or a mapping of field name -> that field's raw bytes
                (joined in field order before unpacking).

        Returns:
            A dict of field name -> converted value.
        """
        # imported here instead of at module level (presumably to avoid a circular import)
        from .core import load_into

        # NOTE: annotations not used!
        fields, elements = cls._collect_fields()

        if isinstance(data, dict):
            # create one long bytestring of data in the right order:
            data = b"".join(data[field] for field in fields)

        # NOTE(review): the format has no byte-order prefix, so struct uses native byte order, size
        # and alignment (padding may be inserted between fields). It also assumes every spec unpacks
        # to exactly one value (true for "s" specs and a count of 1); otherwise the zip misaligns.
        unpacked = struct.unpack(" ".join(str(meta) for meta in elements), data)
        final_data: dict[str, BINARY_TYPES] = {}

        zipped: typing.Iterable[tuple[str, typing.Any, _BinaryField]] = zip(fields, unpacked, elements)
        for field, value, meta in zipped:
            if isinstance(value, bytes) and not issubclass(meta.klass, BinaryConfig):
                # fixed-size struct strings are NUL-padded; drop the padding before decoding
                value = value.strip(b"\x00").decode()

            if meta.special:
                # e.g. load from JSON
                value = meta.special(value)

            # ensure it's the right class (e.g. bool):
            value = load_into(meta.klass, value) if is_custom_class(meta.klass) else meta.klass(value)

            final_data[field] = value

        return final_data

    @classmethod
    def _parse_into(cls, data: bytes | dict[str, bytes]) -> Self:
        """
        Create a new instance based on data.
        """
        converted = cls._parse(data)
        inst = cls()
        inst.__dict__.update(**converted)
        return inst

    def _pack(self) -> bytes:
        """
        Pack an instance back into a bytestring.

        Values are taken in field-declaration order, which is the same order used to build the
        struct format. As a result, the order in which attributes were assigned on the instance does
        not matter, and public attributes that are not BinaryFields are ignored.
        """
        fmt = " ".join(str(meta) for meta in self._fields.values())

        # Iterate the field specs (not self.__dict__) so each value lines up with its format spec.
        values = [meta.pack(getattr(self, name)) for name, meta in self._fields.items()]

        return struct.pack(fmt, *values)

    @classmethod
    def _format(cls) -> str:
        """
        Build the struct format string for this class: each field's spec, space-separated.
        """
        _, fields = cls._collect_fields()

        return " ".join(str(meta) for meta in fields)

    @classmethod
    def _get_length(cls) -> int:
        """
        How many bytes do the fields of this class have?
        """
        fmt = cls._format()

        return struct.calcsize(fmt)
126@dataclass
127class _BinaryField:
128 """
129 Class that stores info to parse the value from a bytestring.
131 Returned by BinaryField, but overwritten on instances with the actual value of type klass.
132 """
134 klass: typing.Type[typing.Any]
135 length: int
136 fmt: str
137 special: typing.Callable[[typing.Any], dict[str, typing.Any]] | None
138 packer: typing.Callable[[typing.Any], typing.Any] | None
140 def __str__(self) -> str:
141 return f"{self.length}{self.fmt}"
143 def pack(self, value: typing.Any) -> typing.Any:
144 if self.packer:
145 value = self.packer(value)
146 if isinstance(value, str):
147 return value.encode()
148 if isinstance(value, BinaryConfig):
149 return value._pack()
150 return value
T = typing.TypeVar("T")

# Reference: struct format characters grouped by size in bytes, as listed in the
# (previously commented-out) DEFAULT_LENGTHS table; see
# https://docs.python.org/3/library/struct.html
#   1 byte:  x c b ? s p   ("s"/"p" are per character)
#   2 bytes: h H e
#   4 bytes: i I l L f
#   8 bytes: q Q n N d P

# struct format character used for each plain Python type when BinaryField() is not given a
# (non-empty) `format`.
DEFAULT_FORMATS: dict[type, str] = {
    str: "s",
    int: "i",
    float: "f",
    bool: "?",  # "b" (signed char) was noted as an alternative
}
def BinaryField(klass: typing.Type[T], **kw: typing.Any) -> T:
    """
    Declare a field on a BinaryConfig subclass.

    Binary fields can not be plain annotations like in a regular typed config, because struct needs
    more information (the format character and the length).

    This actually returns a _BinaryField, but when using load/load_into the attribute on the
    instance is replaced by a value of type `klass`.

    Args:
        klass (type): the final type the value will have
        format (str): either a struct format (e.g. 10s) or a loadable format (json, toml, yaml etc.)
        length (int): how many bytes of data to store (required for str unless it is already in
            the format; required for loadable formats)

    Usage:
        class MyConfig(BinaryConfig):
            string = BinaryField(str, length=5)  # string of 5 characters
            integer = BinaryField(int)
            complex = BinaryField(OtherClass, format='json', length=64)
            # will extract 64 bytes of string and try to convert to the linked class
            # (using regular typeconfig logic)
    """
    special: typing.Callable[[typing.Any], typing.Any] | None = None
    packer: typing.Callable[[typing.Any], typing.Any] | None = None

    if issubclass(klass, BinaryConfig):
        # A nested BinaryConfig is temporarily stored as one byte string, which by default spans
        # the nested class' full packed size.
        nested_length = klass._get_length()
        length = kw.get("length", nested_length)
        fmt = "s"
    else:
        fmt = kw.get("format") or DEFAULT_FORMATS[klass]
        loader = loaders.get(fmt, None)

        if loader:
            # Loadable format: the field is raw text that is parsed after unpacking and, if a
            # dumper exists, serialized again before packing.
            def _load(data: typing.Any) -> typing.Any:
                raw = data if isinstance(data, bytes) else data.encode()
                return loader(BytesIO(raw), Path())

            special = _load

            dumper = DUMPERS.get(fmt, None)
            if dumper:

                def _dump(data: typing.Any) -> typing.Any:
                    return dumper(data, with_top_level_key=False)

                packer = _dump

            length = kw["length"]
            fmt = "s"
        elif len(fmt) > 1:
            # count embedded in the format, e.g. "10s" -> length 10, fmt "s"
            length, fmt = int(fmt[:-1]), fmt[-1]
        else:
            length = kw.get("length", 1)

    spec = _BinaryField(
        klass,
        fmt=fmt,
        length=length,
        special=special,
        packer=packer,
    )
    return typing.cast(T, spec)