Coverage for src/configuraptor/binary_config.py: 100%

102 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-18 15:00 +0200

1""" 

2Logic to do with parsing bytestrings as configuration (using struct). 

3""" 

4 

5import collections 

6import struct 

7import typing 

8from dataclasses import dataclass 

9from io import BytesIO 

10from pathlib import Path 

11 

12from typing_extensions import Self 

13 

14from . import loaders 

15from .abs import AbstractTypedConfig 

16from .helpers import is_custom_class 

17from .loaders.register import DUMPERS 

18 

19BINARY_TYPES = typing.Union[str, float, int, bool] 

20 

21 

class BinaryConfig(AbstractTypedConfig):
    """
    Inherit this class if you want your config or a section of it to be parsed using struct.

    Fields are declared as class attributes created with ``BinaryField(...)``. Their declaration
    order defines the order (and therefore the byte layout) of the struct format.
    """

    # field name -> _BinaryField definition, in declaration order (== struct layout order)
    _fields: collections.OrderedDict[str, "_BinaryField"]

    def __init__(self) -> None:
        """
        Before filling the class with data, we store the fields (BinaryField) for later use.

        The class-level _BinaryField objects are remembered here because parsed values are later
        stored on the instance under the same names, shadowing them.
        """
        fields, elements = self._collect_fields()
        self._fields = collections.OrderedDict(zip(fields, elements))
        super().__init__()

    @classmethod
    def _collect_fields(cls) -> tuple[list[str], list["_BinaryField"]]:
        """
        Get the class' field names and _BinaryField definitions, in declaration order.

        Only attributes defined directly on this class (``cls.__dict__``) are considered.
        Private names (leading underscore) and non-BinaryField attributes are skipped.
        """
        elements: list[_BinaryField] = []
        fields: list[str] = []

        for field, value in cls.__dict__.items():
            if field.startswith("_"):
                continue
            if not isinstance(value, _BinaryField):
                # other data (methods, constants, ...), skip
                continue

            fields.append(field)
            elements.append(value)

        return fields, elements

    @classmethod
    def _parse(cls, data: bytes | dict[str, bytes]) -> dict[str, BINARY_TYPES]:
        """
        Parse a bytestring or a dict of bytestrings (in the right order).

        Args:
            data: either the full packed bytestring, or a mapping of field name -> packed bytes
                for that field. A mapping is joined in field declaration order before unpacking.

        Returns:
            A dict of field name -> value converted to the field's ``klass``.

        Raises:
            struct.error: if the data length does not match this class' struct format.
            KeyError: if ``data`` is a dict that is missing one of the fields.
        """
        # imported here rather than at module level (presumably to avoid a circular import)
        from .core import load_into

        # NOTE: annotations not used! Fields are discovered from BinaryField class attributes.
        fields, elements = cls._collect_fields()

        if isinstance(data, dict):
            # create one long bytestring of data in the right order:
            data = b"".join(data[field] for field in fields)

        fmt = " ".join(str(element) for element in elements)
        unpacked = struct.unpack(fmt, data)
        final_data: dict[str, BINARY_TYPES] = {}

        for field, value, meta in zip(fields, unpacked, elements):
            if isinstance(value, bytes) and not issubclass(meta.klass, BinaryConfig):
                # 's' values come back NUL-padded to the field length: strip and decode to str.
                # Nested BinaryConfig values stay raw bytes so they can be unpacked themselves.
                value = value.strip(b"\x00").decode()

            if meta.special:
                # e.g. load from JSON
                value = meta.special(value)

            # ensure it's the right class (e.g. bool):
            value = load_into(meta.klass, value) if is_custom_class(meta.klass) else meta.klass(value)

            final_data[field] = value

        return final_data

    @classmethod
    def _parse_into(cls, data: bytes | dict[str, bytes]) -> Self:
        """
        Create a new instance based on data.

        ``data`` accepts the same shapes as ``_parse``. Parsed values are written straight into the
        instance ``__dict__``, shadowing the class-level _BinaryField definitions.
        """
        converted = cls._parse(data)
        inst = cls()
        inst.__dict__.update(**converted)
        return inst

    def _pack(self) -> bytes:
        """
        Pack an instance back into a bytestring.

        Values are collected in field declaration order, the same order used for the struct
        format. The insertion order of the instance ``__dict__`` (and any extra public attributes
        in it) therefore does not affect the output.
        """
        fmt = " ".join(str(meta) for meta in self._fields.values())

        values = [meta.pack(getattr(self, name)) for name, meta in self._fields.items()]

        return struct.pack(fmt, *values)

    @classmethod
    def _format(cls) -> str:
        """
        Build the struct format string for this class' fields, e.g. '5s 1i'.

        No byte-order prefix is added, so struct uses native byte order, size and alignment.
        """
        _, elements = cls._collect_fields()

        return " ".join(str(element) for element in elements)

    @classmethod
    def _get_length(cls) -> int:
        """
        How many bytes do the fields of this class have?

        Computed with struct.calcsize on ``_format()``, so native alignment padding is included.
        """
        fmt = cls._format()

        return struct.calcsize(fmt)

124 

125 

126@dataclass 

127class _BinaryField: 

128 """ 

129 Class that stores info to parse the value from a bytestring. 

130 

131 Returned by BinaryField, but overwritten on instances with the actual value of type klass. 

132 """ 

133 

134 klass: typing.Type[typing.Any] 

135 length: int 

136 fmt: str 

137 special: typing.Callable[[typing.Any], dict[str, typing.Any]] | None 

138 packer: typing.Callable[[typing.Any], typing.Any] | None 

139 

140 def __str__(self) -> str: 

141 return f"{self.length}{self.fmt}" 

142 

143 def pack(self, value: typing.Any) -> typing.Any: 

144 if self.packer: 

145 value = self.packer(value) 

146 if isinstance(value, str): 

147 return value.encode() 

148 if isinstance(value, BinaryConfig): 

149 return value._pack() 

150 return value 

151 

152 

T = typing.TypeVar("T")

# Default struct format characters used by BinaryField() when no explicit ``format`` is given.
# See https://docs.python.org/3/library/struct.html for all format characters.
# NOTE: field formats are joined without a byte-order prefix, so struct uses *native* byte order,
# sizes and alignment (e.g. 'l' is 8 bytes on LP64 platforms, not 4). A hard-coded size table
# would be misleading; use struct.calcsize() to get the real size of a format.
DEFAULT_FORMATS = {
    str: "s",  # bytes, NUL-padded / truncated to the field length when packing
    int: "i",  # C int
    float: "f",  # C float (single precision: Python floats lose precision)
    bool: "?",  # C _Bool ('b', signed char, would be a possible alternative)
}

185 

186 

def BinaryField(klass: typing.Type[T], **kw: typing.Any) -> T:
    """
    Fields for BinaryConfig can not be annotated like a regular typed config, \
    because more info is required (such as struct format/type and length).

    This actually returns a _BinaryField but when using load/load_into, the value will be replaced with type 'klass'.

    Args:
        klass (type): the final type the value will have
        format (str): either one of the formats of struct (e.g. 10s) or a loadable format (json, toml, yaml etc.)
        length (int): how many bytes of data to store? Required for loadable formats. Defaults to 1 for
            plain struct formats, so set it for str unless you enter it in format already (e.g. 10s).
            For a nested BinaryConfig it defaults to that class' total size.

    Raises:
        KeyError: if no ``format`` is given and ``klass`` has no default struct format,
            or if a loadable format is used without ``length``.
        ValueError: if a multi-character ``format`` is neither a loader name nor '<count><code>'.

    Usage:
        class MyConfig(BinaryConfig):
            string = BinaryField(str, length=5) # string of 5 characters
            integer = BinaryField(int)
            complex = BinaryField(OtherClass, format='json', length=64)
            # will extract 64 bytes of string and try to convert to the linked class
            # (using regular typeconfig logic)
    """
    special = None
    packer = None

    if issubclass(klass, BinaryConfig):
        fmt = "s"  # temporarily group as one string
        # only compute the nested class' size when no explicit length was given
        length = kw["length"] if "length" in kw else klass._get_length()
    else:
        fmt = kw.get("format")
        if not fmt:
            if klass not in DEFAULT_FORMATS:
                raise KeyError(
                    f"No default struct format for {klass!r}; pass format=... to BinaryField explicitly."
                )
            fmt = DEFAULT_FORMATS[klass]

        if loader := loaders.get(fmt, None):
            # loadable format (e.g. json): stored as a string, parsed by the registered loader
            special = lambda data: loader(  # noqa: E731
                BytesIO(data if isinstance(data, bytes) else data.encode()), Path()
            )
            if _packer := DUMPERS.get(fmt, None):
                packer = lambda data: _packer(data, with_top_level_key=False)  # noqa: E731
            if "length" not in kw:
                raise KeyError(f"BinaryField with format {fmt!r} requires an explicit 'length'.")
            length = kw["length"]
            fmt = "s"
        elif len(fmt) > 1:
            # length in format: 10s
            try:
                length = int(fmt[:-1])
            except ValueError as e:
                raise ValueError(
                    f"Unknown BinaryField format {fmt!r}: expected a struct format such as '10s' "
                    "or a registered loader name (json, toml, yaml, ...)."
                ) from e
            fmt = fmt[-1]
        else:
            length = kw.get("length", 1)

    field = _BinaryField(
        klass,
        fmt=fmt,
        length=length,
        special=special,
        packer=packer,
    )

    return typing.cast(T, field)