Coverage for src/configuraptor/binary_config.py: 100%

108 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-18 16:37 +0200

1""" 

2Logic to do with parsing bytestrings as configuration (using struct). 

3""" 

4 

5import collections 

6import struct 

7import typing 

8from dataclasses import dataclass 

9from io import BytesIO 

10from pathlib import Path 

11 

12from typing_extensions import Self 

13 

14from . import loaders 

15from .abs import AbstractTypedConfig 

16from .helpers import is_custom_class 

17from .loaders.register import DUMPERS 

18 

# The plain Python types a parsed binary field value can have.
BINARY_TYPES = typing.Union[
    str,
    float,
    int,
    bool,
]

20 

21 

class BinaryConfig(AbstractTypedConfig):
    """
    Inherit this class if you want your config or a section of it to be parsed using struct.

    Fields are declared as public class attributes holding a `_BinaryField` (created via
    `BinaryField(...)`). The order in which they appear in the class body defines the
    layout of the bytestring that is parsed and packed.
    """

    # Maps field name -> field metadata, in declaration order. Filled in __init__.
    _fields: collections.OrderedDict[str, "_BinaryField"]

    def __init__(self) -> None:
        """
        Before filling the class with data, we store the fields (BinaryField) for later use.
        """
        fields, elements = self._collect_fields()
        # NOTE: these are the very same _BinaryField objects that live on the class,
        # not copies, so mutating them (see __setattr__) is visible to every instance.
        self._fields = collections.OrderedDict(zip(fields, elements))
        super().__init__()

    @classmethod
    def _collect_fields(cls) -> tuple[list[str], list["_BinaryField"]]:
        """
        Get the class' field names and dataclass instances.

        Only public attributes found in `cls.__dict__` (i.e. declared directly on this
        class, not on a parent) whose value is a `_BinaryField` are collected.

        Returns:
            two parallel lists: the field names and the matching `_BinaryField` objects.
        """
        fields: list[str] = []
        elements: list[_BinaryField] = []

        for field, value in cls.__dict__.items():
            # skip private names and any other (non-field) class data
            if field.startswith("_") or not isinstance(value, _BinaryField):
                continue

            fields.append(field)
            elements.append(value)

        return fields, elements

    @classmethod
    def _parse(cls, data: bytes | dict[str, bytes]) -> dict[str, BINARY_TYPES]:
        """
        Parse a bytestring or a dict of bytestrings (in the right order).

        Args:
            data: either the full bytestring, or a dict of field name -> bytes which is
                concatenated in field declaration order before unpacking.

        Returns:
            a dict of field name -> converted value.

        Raises:
            struct.error: if the data does not match the combined struct format.
            KeyError: if `data` is a dict that misses one of the fields.
        """
        # imported locally rather than at module level (presumably to avoid a circular import)
        from .core import load_into

        # NOTE: annotations not used!
        fields, elements = cls._collect_fields()

        if isinstance(data, dict):
            # create one long bytestring of data in the right order:
            data = b"".join(data[field] for field in fields)

        # struct ignores whitespace between format codes, so the separator is cosmetic
        unpacked = struct.unpack(" ".join(str(element) for element in elements), data)
        final_data: dict[str, BINARY_TYPES] = {}

        zipped: typing.Iterable[tuple[str, typing.Any, _BinaryField]] = zip(fields, unpacked, elements)
        for field, value, meta in zipped:
            if isinstance(value, bytes) and not issubclass(meta.klass, BinaryConfig):
                # drop the NUL padding and turn the bytes into text;
                # nested BinaryConfig data must stay raw bytes for its own parsing
                value = value.strip(b"\x00").decode()

            if meta.special:
                # e.g. load from JSON
                value = meta.special(value)

            # ensure it's the right class (e.g. bool):
            value = load_into(meta.klass, value) if is_custom_class(meta.klass) else meta.klass(value)

            final_data[field] = value

        return final_data

    @classmethod
    def _parse_into(cls, data: bytes | dict[str, bytes]) -> Self:
        """
        Create a new instance based on data.

        The parsed values are written straight into the instance `__dict__`,
        so `__setattr__` is not triggered for them.
        """
        converted = cls._parse(data)
        inst = cls()
        inst.__dict__.update(converted)
        return inst

    def _pack(self) -> bytes:
        """
        Pack an instance back into a bytestring.

        Values are gathered by walking `_fields`, so they always line up with the
        format string regardless of the order in which attributes were assigned on
        the instance; public attributes that are not fields are ignored.

        Raises:
            struct.error: if a field has no value on the instance (too few values for
                the format) or a value does not fit its format code.
        """
        fmt = " ".join(str(field) for field in self._fields.values())

        own = self.__dict__
        # A field without an instance value is left out here on purpose:
        # struct.pack then reports the mismatch between format and values.
        values = [field.pack(own[name]) for name, field in self._fields.items() if name in own]

        return struct.pack(fmt, *values)

    @classmethod
    def _format(cls) -> str:
        """
        Build the struct format string for this class from its declared fields.
        """
        _, elements = cls._collect_fields()

        return " ".join(str(element) for element in elements)

    @classmethod
    def _get_length(cls) -> int:
        """
        How many bytes do the fields of this class have?
        """
        fmt = cls._format()

        return struct.calcsize(fmt)

    def __setattr__(self, key: str, value: typing.Any) -> None:
        """
        When setting a new field for this config, update the _fields property to have the correct new type + size.

        Only applies when a public attribute is assigned a BinaryConfig instance.

        Raises:
            KeyError: if such a value is assigned to a name that is not a declared field.
        """
        if not key.startswith("_") and isinstance(value, BinaryConfig):
            # NOTE(review): this mutates the field object stored in _fields, which
            # __init__ takes directly from the class - confirm sharing is intended.
            field = self._fields[key]
            field.klass = value.__class__
            field.length = value.__class__._get_length()

        return super().__setattr__(key, value)

135 

136 

137@dataclass 

138class _BinaryField: 

139 """ 

140 Class that stores info to parse the value from a bytestring. 

141 

142 Returned by BinaryField, but overwritten on instances with the actual value of type klass. 

143 """ 

144 

145 klass: typing.Type[typing.Any] 

146 length: int 

147 fmt: str 

148 special: typing.Callable[[typing.Any], dict[str, typing.Any]] | None 

149 packer: typing.Callable[[typing.Any], typing.Any] | None 

150 

151 def __str__(self) -> str: 

152 return f"{self.length}{self.fmt}" 

153 

154 def pack(self, value: typing.Any) -> typing.Any: 

155 if self.packer: 

156 value = self.packer(value) 

157 if isinstance(value, str): 

158 return value.encode() 

159 if isinstance(value, BinaryConfig): 

160 return value._pack() 

161 return value 

162 

163 

# Generic placeholder for "the type passed to BinaryField".
T = typing.TypeVar("T")

# https://docs.python.org/3/library/struct.html
# Reference table of byte sizes per struct format character (not used by the code):
# DEFAULT_LENGTHS = {
#     "x": 1,
#     "c": 1,
#     "b": 1,
#     "?": 1,
#     "h": 2,
#     "H": 2,
#     "i": 4,
#     "I": 4,
#     "l": 4,
#     "L": 4,
#     "q": 8,
#     "Q": 8,
#     "n": 8,
#     "N": 8,
#     "e": 2,
#     "f": 4,
#     "d": 8,
#     "s": 1,
#     "p": 1,
#     "P": 8,
# }

# struct format character used for a type when no explicit format is given
DEFAULT_FORMATS = dict(
    [
        (str, "s"),
        (int, "i"),
        (float, "f"),
        (bool, "?"),  # b
    ]
)

196 

197 

def BinaryField(klass: typing.Type[T], **kw: typing.Any) -> T:
    """
    Declare a field on a BinaryConfig.

    Fields for BinaryConfig can not be annotated like a regular typed config,
    because more info is required (such as struct format/type and length).

    This actually returns a _BinaryField but when using load/load_into,
    the value will be replaced with type 'klass'.

    Args:
        klass (type): the final type the value will have
        format (str): either one of the formats of struct (e.g. 10s) or a loadable format (json, toml, yaml etc.)
        length (int): how many bytes of data to store? (required for str, unless you enter it in format already)

    Raises:
        KeyError: if klass has no default format and none was given,
            or if a loadable format is used without 'length'.

    Usage:
        class MyConfig(BinaryConfig):
            string = BinaryField(str, length=5)  # string of 5 characters
            integer = BinaryField(int)
            complex = BinaryField(OtherClass, format='json', length=64)
            # will extract 64 bytes of string and try to convert to the linked class
            # (using regular typeconfig logic)
    """
    special = None
    packer = None

    if issubclass(klass, BinaryConfig):
        # a nested binary config is first read as one run of bytes
        length = kw.get("length", klass._get_length())
        fmt = "s"
    else:
        fmt = kw.get("format") or DEFAULT_FORMATS[klass]
        loader = loaders.get(fmt, None)

        if loader:
            # loadable format (e.g. json): stored as a fixed-size string

            def _load(data: typing.Any) -> typing.Any:
                raw = data if isinstance(data, bytes) else data.encode()
                return loader(BytesIO(raw), Path())

            special = _load

            _packer = DUMPERS.get(fmt, None)
            if _packer:

                def _dump(data: typing.Any) -> typing.Any:
                    return _packer(data, with_top_level_key=False)

                packer = _dump

            length = kw["length"]
            fmt = "s"
        elif len(fmt) > 1:
            # length in format: 10s
            length = int(fmt[:-1])
            fmt = fmt[-1]
        else:
            length = kw.get("length", 1)

    field = _BinaryField(klass, fmt=fmt, length=length, special=special, packer=packer)

    # typed as T so attribute declarations type-check as the final value type
    return typing.cast(T, field)