Coverage for src/configuraptor/binary_config.py: 100%

108 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2026-05-01 17:14 +0200

1""" 

2Logic to do with parsing bytestrings as configuration (using struct). 

3""" 

4 

5import collections 

6import struct 

7import typing 

8from dataclasses import dataclass 

9from io import BytesIO 

10from pathlib import Path 

11from typing import Self 

12 

13from . import loaders 

14from .abs import AbstractTypedConfig 

15from .helpers import is_custom_class 

16from .loaders.register import DUMPERS 

17 

# The Python types a parsed binary field value is annotated as (see the return type of BinaryConfig._parse).
BINARY_TYPES = typing.Union[str, float, int, bool]

19 

20 

class BinaryConfig(AbstractTypedConfig):
    """
    Inherit this class if you want your config or a section of it to be parsed using struct.

    Fields are declared as class attributes holding the result of `BinaryField(...)`.
    The order in which they are declared on the class is the order in which they are
    read from and written to the bytestring.
    """

    _fields: collections.OrderedDict[str, "_BinaryField"]

    def __init__(self) -> None:
        """
        Before filling the class with data, we store the fields (BinaryField) for later use.
        """
        fields, elements = self._collect_fields()
        self._fields = collections.OrderedDict(zip(fields, elements))
        super().__init__()

    @classmethod
    def _collect_fields(cls) -> tuple[list[str], list["_BinaryField"]]:
        """
        Get the class' field names and dataclass instances.

        Only attributes defined directly on this class (`cls.__dict__`) are inspected;
        names starting with an underscore and values that are not a `_BinaryField` are skipped.

        Returns:
            two parallel lists: the field names and the matching `_BinaryField` objects,
            both in class definition order.
        """
        elements: list[_BinaryField] = []
        fields: list[str] = []

        for field, value in cls.__dict__.items():
            if field.startswith("_"):
                continue
            if not isinstance(value, _BinaryField):
                # other data, skip
                continue

            fields.append(field)
            elements.append(value)

        return fields, elements

    @classmethod
    def _parse(cls, data: bytes | dict[str, bytes]) -> dict[str, BINARY_TYPES]:
        """
        Parse a bytestring or a dict of bytestrings (in the right order).

        Args:
            data: the raw bytes for all fields, or a mapping of field name to that field's bytes.
                  A mapping is joined in field definition order, so every field must be present.

        Returns:
            a dict of field name to the converted value.

        Raises:
            KeyError: if `data` is a dict that lacks one of the fields.
            struct.error: if the data does not match the combined struct format.
        """
        # imported here rather than at module level (presumably to avoid a circular import - TODO confirm)
        from .core import load_into

        # NOTE: annotations not used!
        fields, elements = cls._collect_fields()

        if isinstance(data, dict):
            # create one long bytestring of data in the right order:
            data = b"".join(data[field] for field in fields)

        unpacked = struct.unpack(" ".join(str(_) for _ in elements), data)
        final_data: dict[str, BINARY_TYPES] = {}

        zipped: typing.Iterable[tuple[str, typing.Any, _BinaryField]] = zip(fields, unpacked, elements)
        for field, value, meta in zipped:
            if isinstance(value, bytes) and not issubclass(meta.klass, BinaryConfig):
                # plain string data: drop the NUL bytes around the content and decode to str.
                # Nested BinaryConfig data stays as bytes so it can be parsed by that class.
                value = value.strip(b"\x00").decode()

            if meta.special:
                # e.g. load from JSON
                value = meta.special(value)

            # ensure it's the right class (e.g. bool):
            value = load_into(meta.klass, value) if is_custom_class(meta.klass) else meta.klass(value)

            final_data[field] = value

        return final_data

    @classmethod
    def _parse_into(cls, data: bytes | dict[str, bytes]) -> Self:
        """
        Create a new instance based on data.

        The parsed values are written straight into the instance `__dict__`,
        so `__setattr__` is not invoked for them.
        """
        converted = cls._parse(data)
        inst = cls()
        inst.__dict__.update(**converted)
        return inst

    def _pack(self) -> bytes:
        """
        Pack an instance back into a bytestring.

        Values are taken per declared field, in field order, so the result does not depend on
        the order in which attributes were assigned and extra (non-field) attributes are ignored.

        Raises:
            struct.error: if a value does not fit its field's struct format
                          (this includes fields that were never given a value).
        """
        fmt = " ".join(str(field) for field in self._fields.values())

        # Walk the fields (not the instance __dict__) so every value lines up with its format code.
        values = [field.pack(getattr(self, name)) for name, field in self._fields.items()]

        return struct.pack(fmt, *values)

    @classmethod
    def _format(cls) -> str:
        """
        Build the struct format string for this class: the format of each field, space-separated.
        """
        _, elements = cls._collect_fields()

        return " ".join(str(_) for _ in elements)

    @classmethod
    def _get_length(cls) -> int:
        """
        How many bytes do the fields of this class have?
        """
        fmt = cls._format()

        return struct.calcsize(fmt)

    def __setattr__(self, key: str, value: typing.Any) -> None:
        """
        When setting a new field for this config, update the _fields property to have the correct new type + size.

        Only public attributes that receive a BinaryConfig instance and that are declared fields
        are affected; anything else is stored without touching the field metadata.
        """
        if not key.startswith("_") and isinstance(value, BinaryConfig):
            field = self._fields.get(key)
            if field is not None:
                # NOTE: the objects in _fields are the same _BinaryField objects that live on the class,
                # so this change is also visible via the class and its other instances.
                field.klass = value.__class__
                field.length = value.__class__._get_length()

        return super().__setattr__(key, value)

134 

135 

136@dataclass(slots=True) 

137class _BinaryField: 

138 """ 

139 Class that stores info to parse the value from a bytestring. 

140 

141 Returned by BinaryField, but overwritten on instances with the actual value of type klass. 

142 """ 

143 

144 klass: typing.Type[typing.Any] 

145 length: int 

146 fmt: str 

147 special: typing.Callable[[typing.Any], dict[str, typing.Any]] | None 

148 packer: typing.Callable[[typing.Any], typing.Any] | None 

149 

150 def __str__(self) -> str: 

151 return f"{self.length}{self.fmt}" 

152 

153 def pack(self, value: typing.Any) -> typing.Any: 

154 if self.packer: 

155 value = self.packer(value) 

156 if isinstance(value, str): 

157 return value.encode() 

158 if isinstance(value, BinaryConfig): 

159 return value._pack() 

160 return value 

161 

162 

# Generic placeholder so that BinaryField(klass) is typed as returning a value of type klass.
T = typing.TypeVar("T")

164 

165# https://docs.python.org/3/library/struct.html 

166# DEFAULT_LENGTHS = { 

167# "x": 1, 

168# "c": 1, 

169# "b": 1, 

170# "?": 1, 

171# "h": 2, 

172# "H": 2, 

173# "i": 4, 

174# "I": 4, 

175# "l": 4, 

176# "L": 4, 

177# "q": 8, 

178# "Q": 8, 

179# "n": 8, 

180# "N": 8, 

181# "e": 2, 

182# "f": 4, 

183# "d": 8, 

184# "s": 1, 

185# "p": 1, 

186# "P": 8, 

187# } 

188 

# struct format character used for a Python type when BinaryField gets no explicit format
# (bool could alternatively be stored as "b")
DEFAULT_FORMATS = {str: "s", int: "i", float: "f", bool: "?"}

195 

196 

def BinaryField(klass: typing.Type[T], **kw: typing.Any) -> T:
    """
    Declare a field on a BinaryConfig.

    Such fields can not be annotated like a regular typed config,
    because more info is required (such as struct format/type and length).

    The returned object is really a _BinaryField; it is typed as 'klass' because load/load_into
    replace it with a value of that type.

    Args:
        klass (type): the final type the value will have
        format (str): either one of the formats of struct (e.g. 10s) or a loadable format (json, toml, yaml etc.)
        length (int): how many bytes of data to store? (required for loadable formats;
                      for struct formats it defaults to 1 unless the format already includes it)

    Raises:
        KeyError: if no format is given and klass has no default format,
                  or if a loadable format is used without a length.

    Usage:
        class MyConfig(BinaryConfig):
            string = BinaryField(str, length=5)  # string of 5 characters
            integer = BinaryField(int)
            complex = BinaryField(OtherClass, format='json', length=64)
            # will extract 64 bytes of string and try to convert to the linked class
            # (using regular typeconfig logic)
    """
    reader = None
    writer = None

    if issubclass(klass, BinaryConfig):
        # a nested binary config is read as one block of bytes for now
        code = "s"
        size = kw.get("length", klass._get_length())
    else:
        code = kw.get("format") or DEFAULT_FORMATS[klass]
        loader = loaders.get(code, None)

        if loader:
            # loadable format (e.g. json): stored as a string of fixed length,
            # converted through the loader when parsing ...
            reader = lambda data: loader(  # noqa: E731
                BytesIO(data if isinstance(data, bytes) else data.encode()), Path()
            )
            # ... and through the matching dumper (when there is one) when packing
            dumper = DUMPERS.get(code, None)
            if dumper:
                writer = lambda data: dumper(data, with_top_level_key=False)  # noqa: E731
            size = kw["length"]
            code = "s"
        elif len(code) > 1:
            # the length is part of the format, e.g. 10s
            size = int(code[:-1])
            code = code[-1]
        else:
            size = kw.get("length", 1)

    return typing.cast(
        T,
        _BinaryField(
            klass,
            length=size,
            fmt=code,
            special=reader,
            packer=writer,
        ),
    )