Coverage for src/configuraptor/binary_config.py: 100%
104 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-18 14:30 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-18 14:30 +0200
1"""
2Logic to do with parsing bytestrings as configuration (using struct).
3"""
5import collections
6import struct
7import typing
8from dataclasses import dataclass
9from io import BytesIO
10from pathlib import Path
12from typing_extensions import Self
14from . import loaders
15from .abs import AbstractTypedConfig
16from .helpers import is_custom_class
17from .loaders.register import DUMPERS
# Union of the plain Python value types used to annotate the values of the dict
# returned by BinaryConfig._parse.
BINARY_TYPES = typing.Union[str, float, int, bool]
class BinaryConfig(AbstractTypedConfig):
    """
    Inherit this class if you want your config or a section of it to be parsed using struct.

    Fields are declared as class attributes with BinaryField(...). The order in which they
    are declared is the order in which their bytes are read and written.
    """

    _fields: collections.OrderedDict[str, "_BinaryField"]

    def __init__(self) -> None:
        """
        Before filling the class with data, we store the fields (BinaryField) for later use.

        The field definitions are kept on the instance as `_fields` (name -> _BinaryField),
        because the public attributes are later overwritten with the actual values.
        """
        fields, elements = self._collect_fields()
        self._fields = collections.OrderedDict(zip(fields, elements))
        super().__init__()

    @classmethod
    def _collect_fields(cls) -> tuple[list[str], list["_BinaryField"]]:
        """
        Get the class' field names and dataclass instances.

        Only attributes defined directly on this class (cls.__dict__) are inspected.
        Names starting with an underscore and values that are not a _BinaryField are skipped.

        Returns:
            two parallel lists: the field names and their _BinaryField definitions,
            both in declaration order.
        """
        elements: list[_BinaryField] = []
        fields: list[str] = []

        for field, value in cls.__dict__.items():
            if field.startswith("_"):
                continue
            if not isinstance(value, _BinaryField):
                # other data, skip
                continue

            fields.append(field)
            elements.append(value)

        return fields, elements

    @classmethod
    def _parse(cls, data: bytes | dict[str, bytes]) -> dict[str, BINARY_TYPES]:
        """
        Parse a bytestring or a dict of bytestrings (in the right order).

        Args:
            data: either the full bytestring, or a dict mapping every field name to its
                  own chunk of bytes (the chunks are concatenated in field order).

        Returns:
            a dict of field name -> converted value.

        Raises:
            struct.error: if the data does not match the combined format of the fields.
            KeyError: if `data` is a dict that misses one of the fields.
        """
        # imported lazily, like the original code does (presumably to avoid a circular import)
        from .core import load_into

        # NOTE: annotations not used!
        fields, elements = cls._collect_fields()

        if isinstance(data, dict):
            # create one long bytestring of data in the right order:
            data = b"".join(data[field] for field in fields)

        # struct ignores whitespace between format parts, so this is the same string as cls._format()
        unpacked = struct.unpack(" ".join(str(element) for element in elements), data)
        final_data: dict[str, BINARY_TYPES] = {}

        # NOTE(review): struct.unpack yields one item per value, so a non-'s' field with a
        # repeat count above 1 (e.g. '2i') would produce more items than fields and misalign
        # this zip - confirm whether such fields are meant to be supported.
        zipped: typing.Iterable[tuple[str, typing.Any, _BinaryField]] = zip(fields, unpacked, elements)
        for field, value, meta in zipped:
            if isinstance(value, bytes) and not issubclass(meta.klass, BinaryConfig):
                # plain strings: drop the null padding and decode;
                # nested BinaryConfig data stays raw bytes so it can be parsed by its own class.
                value = value.strip(b"\x00").decode()

            if meta.special:
                # e.g. load from JSON
                value = meta.special(value)

            if is_custom_class(meta.klass):
                value = load_into(meta.klass, value)
            else:
                # ensure it's the right class (e.g. bool):
                value = meta.klass(value)

            final_data[field] = value

        return final_data

    @classmethod
    def _parse_into(cls, data: bytes | dict[str, bytes]) -> Self:
        """
        Create a new instance based on data.

        The parsed values are written straight into the instance's __dict__,
        replacing the class-level field definitions for that instance.
        """
        converted = cls._parse(data)
        inst = cls()
        inst.__dict__.update(**converted)
        return inst

    def _pack(self) -> bytes:
        """
        Pack an instance back into a bytestring.

        Values are taken in declared field order (the order of `_fields`) so they always line
        up with the format string, regardless of the order in which attributes were assigned.
        Instance attributes that are not declared fields are ignored.

        Raises:
            struct.error: if a field has no value on the instance or a value does not fit its format.
        """
        fmt = " ".join(str(meta) for meta in self._fields.values())

        own_values = self.__dict__
        # a field without an instance value is left out on purpose: struct.pack then reports
        # the missing item instead of the class-level field definition being packed by accident.
        values = [meta.pack(own_values[name]) for name, meta in self._fields.items() if name in own_values]

        return struct.pack(fmt, *values)

    @classmethod
    def _format(cls) -> str:
        """
        Build the struct format string of all fields of this class, in declaration order.
        """
        _, fields = cls._collect_fields()

        return " ".join(str(_) for _ in fields)

    @classmethod
    def _get_length(cls) -> int:
        """
        How many bytes do the fields of this class have?
        """
        fmt = cls._format()

        return struct.calcsize(fmt)
130@dataclass
131class _BinaryField:
132 """
133 Class that stores info to parse the value from a bytestring.
135 Returned by BinaryField, but overwritten on instances with the actual value of type klass.
136 """
138 klass: typing.Type[typing.Any]
139 length: int
140 fmt: str
141 special: typing.Callable[[typing.Any], typing.Any] | None
142 packer: typing.Callable[[typing.Any], typing.Any] | None
144 def __str__(self) -> str:
145 return f"{self.length}{self.fmt}"
147 def pack(self, value: typing.Any) -> typing.Any:
148 if self.packer:
149 value = self.packer(value)
150 if isinstance(value, str):
151 return value.encode()
152 return value
# Generic type variable: BinaryField takes `klass: Type[T]` and is annotated as returning `T`.
T = typing.TypeVar("T")
# https://docs.python.org/3/library/struct.html
# DEFAULT_LENGTHS = {
#     "x": 1,
#     "c": 1,
#     "b": 1,
#     "?": 1,
#     "h": 2,
#     "H": 2,
#     "i": 4,
#     "I": 4,
#     "l": 4,
#     "L": 4,
#     "q": 8,
#     "Q": 8,
#     "n": 8,
#     "N": 8,
#     "e": 2,
#     "f": 4,
#     "d": 8,
#     "s": 1,
#     "p": 1,
#     "P": 8,
# }
# Default struct format character per Python type; BinaryField falls back to this
# mapping when no explicit `format` keyword is given.
DEFAULT_FORMATS = {
    str: "s",
    int: "i",
    float: "f",
    bool: "?",  # b (presumably noted as an alternative format character - confirm)
}
def BinaryField(klass: typing.Type[T], **kw: typing.Any) -> T:
    """
    Define a field of a BinaryConfig.

    Fields for BinaryConfig can not be annotated like a regular typed config,
    because more info is required (such as struct format/type and length).

    This actually returns a _BinaryField but when using load/load_into, the value will be replaced with type 'klass'.

    Args:
        klass (type): the final type the value will have
        format (str): either one of the formats of struct (e.g. 10s) or a loadable format (json, toml, yaml etc.)
        length (int): how many bytes of data to store? (required for str, unless you enter it in format already)

    Raises:
        KeyError: if no format is given and klass has no default format,
                  or if a loadable format is used without 'length'.
        ValueError: if a multi-character format is not a number followed by one format character.

    Usage:
        class MyConfig(BinaryConfig):
            string = BinaryField(str, length=5) # string of 5 characters
            integer = BinaryField(int)
            complex = BinaryField(OtherClass, format='json', length=64)
            # will extract 64 bytes of string and try to convert to the linked class
            # (using regular typeconfig logic)
    """
    special = None
    packer = None

    if issubclass(klass, BinaryConfig):
        fmt = "s"  # temporarily group as one string
        length = kw.get("length", klass._get_length())
    else:
        fmt = kw.get("format") or DEFAULT_FORMATS[klass]
        if loader := loaders.get(fmt, None):
            # loadable format (e.g. json): the raw data is stored as a string of `length` bytes
            # and handed to the loader as a file-like object when parsing.
            special = lambda data: loader(  # noqa: E731
                BytesIO(data if isinstance(data, bytes) else data.encode()), Path()
            )
            if _packer := DUMPERS.get(fmt, None):
                packer = lambda data: _packer(data, with_top_level_key=False)  # noqa: E731
            length = kw["length"]
            fmt = "s"
        elif len(fmt) > 1:
            # length in format: 10s
            length, fmt = int(fmt[:-1]), fmt[-1]
        else:
            length = kw.get("length", 1)

    field = _BinaryField(
        klass,
        fmt=fmt,
        length=length,
        special=special,
        packer=packer,
    )

    # Deliberate lie to the type checker: on loaded instances the attribute holds a value
    # of type T, so the declaration is typed as T even though a _BinaryField is returned here.
    return typing.cast(T, field)