Coverage for src/configuraptor/binary_config.py: 100%

104 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-18 14:30 +0200

1""" 

2Logic to do with parsing bytestrings as configuration (using struct). 

3""" 

4 

5import collections 

6import struct 

7import typing 

8from dataclasses import dataclass 

9from io import BytesIO 

10from pathlib import Path 

11 

12from typing_extensions import Self 

13 

14from . import loaders 

15from .abs import AbstractTypedConfig 

16from .helpers import is_custom_class 

17from .loaders.register import DUMPERS 

18 

# Python types that a parsed binary field value can have; used as the value type of
# the dict returned by `BinaryConfig._parse`.
BINARY_TYPES = typing.Union[str, float, int, bool]

20 

21 

class BinaryConfig(AbstractTypedConfig):
    """
    Inherit this class if you want your config or a section of it to be parsed using struct.

    Fields are declared as public class attributes created with `BinaryField(...)`.
    The order in which they appear in the class body is the order in which they
    are read from / written to the bytestring.
    """

    _fields: collections.OrderedDict[str, "_BinaryField"]

    def __init__(self) -> None:
        """
        Before filling the class with data, we store the fields (BinaryField) for later use.

        The instance attributes set afterwards shadow the class-level field descriptions,
        so `_fields` is the only place where the metadata stays reachable per instance.
        """
        names, metas = self._collect_fields()
        self._fields = collections.OrderedDict(zip(names, metas))
        super().__init__()

    @classmethod
    def _collect_fields(cls) -> tuple[list[str], list["_BinaryField"]]:
        """
        Get the class' field names and dataclass instances.

        Only public attributes (no leading underscore) defined directly on this class
        (i.e. present in `cls.__dict__`) whose value is a `_BinaryField` are returned.

        Returns:
            two parallel lists: the field names and the matching `_BinaryField` objects,
            both in class-definition order.
        """
        elements: list[_BinaryField] = []
        fields: list[str] = []

        for field, value in cls.__dict__.items():
            if field.startswith("_"):
                continue
            if not isinstance(value, _BinaryField):
                # other data, skip
                continue

            fields.append(field)
            elements.append(value)

        return fields, elements

    @classmethod
    def _parse(cls, data: bytes | dict[str, bytes]) -> dict[str, BINARY_TYPES]:
        """
        Parse a bytestring or a dict of bytestrings (in the right order).

        Args:
            data: either the full packed bytestring, or a dict with one bytestring per
                field name (these are concatenated in field order before unpacking).

        Returns:
            a dict of field name to converted value.

        Raises:
            KeyError: if `data` is a dict that lacks one of the fields.
            struct.error: if the data does not match the combined struct format.
        """
        from .core import load_into

        # NOTE: annotations not used!
        fields, elements = cls._collect_fields()

        if isinstance(data, dict):
            # create one long bytestring of data in the right order:
            data = b"".join(data[field] for field in fields)

        unpacked = struct.unpack(" ".join(str(_) for _ in elements), data)
        final_data: dict[str, BINARY_TYPES] = {}

        zipped: typing.Iterable[tuple[str, typing.Any, _BinaryField]] = zip(fields, unpacked, elements)
        for field, value, meta in zipped:
            if isinstance(value, bytes) and not issubclass(meta.klass, BinaryConfig):
                # string-like data: drop the null padding and turn it into text.
                # Nested BinaryConfig data must stay raw bytes so it can be unpacked again.
                value = value.strip(b"\x00").decode()

            if meta.special:
                # e.g. load from JSON
                value = meta.special(value)

            if is_custom_class(meta.klass):
                value = load_into(meta.klass, value)
            else:
                # ensure it's the right class (e.g. bool):
                value = meta.klass(value)

            final_data[field] = value

        return final_data

    @classmethod
    def _parse_into(cls, data: bytes | dict[str, bytes]) -> Self:
        """
        Create a new instance based on data.

        The parsed values are written straight into the instance `__dict__`, replacing
        (shadowing) the class-level field descriptions.
        """
        converted = cls._parse(data)
        inst = cls()
        inst.__dict__.update(**converted)
        return inst

    def _pack(self) -> bytes:
        """
        Pack an instance back into a bytestring.

        Both the format string and the values are derived from `self._fields`, so every
        value ends up in the slot of its own field regardless of the order in which the
        attributes were assigned, and unrelated public attributes are ignored.

        Raises:
            struct.error: if a value does not fit the format of its field.
        """
        fmt = " ".join(str(meta) for meta in self._fields.values())

        values = [meta.pack(getattr(self, name)) for name, meta in self._fields.items()]
        return struct.pack(fmt, *values)

    @classmethod
    def _format(cls) -> str:
        """
        Build the combined struct format string of all fields (e.g. "5s 1i").
        """
        _, elements = cls._collect_fields()

        return " ".join(str(_) for _ in elements)

    @classmethod
    def _get_length(cls) -> int:
        """
        How many bytes do the fields of this class have?
        """
        fmt = cls._format()

        return struct.calcsize(fmt)

128 

129 

130@dataclass 

131class _BinaryField: 

132 """ 

133 Class that stores info to parse the value from a bytestring. 

134 

135 Returned by BinaryField, but overwritten on instances with the actual value of type klass. 

136 """ 

137 

138 klass: typing.Type[typing.Any] 

139 length: int 

140 fmt: str 

141 special: typing.Callable[[typing.Any], typing.Any] | None 

142 packer: typing.Callable[[typing.Any], typing.Any] | None 

143 

144 def __str__(self) -> str: 

145 return f"{self.length}{self.fmt}" 

146 

147 def pack(self, value: typing.Any) -> typing.Any: 

148 if self.packer: 

149 value = self.packer(value) 

150 if isinstance(value, str): 

151 return value.encode() 

152 return value 

153 

154 

# Type variable linking the `klass` argument of `BinaryField` to its annotated return type.
T = typing.TypeVar("T")

156 

157# https://docs.python.org/3/library/struct.html 

158# DEFAULT_LENGTHS = { 

159# "x": 1, 

160# "c": 1, 

161# "b": 1, 

162# "?": 1, 

163# "h": 2, 

164# "H": 2, 

165# "i": 4, 

166# "I": 4, 

167# "l": 4, 

168# "L": 4, 

169# "q": 8, 

170# "Q": 8, 

171# "n": 8, 

172# "N": 8, 

173# "e": 2, 

174# "f": 4, 

175# "d": 8, 

176# "s": 1, 

177# "p": 1, 

178# "P": 8, 

179# } 

180 

# Struct format character used for a plain Python type when BinaryField is called
# without an explicit `format`.
DEFAULT_FORMATS: dict[type, str] = {
    str: "s",  # bytes of a given length
    int: "i",  # C int
    float: "f",  # C float
    bool: "?",  # C _Bool; the original note mentions "b" as an alternative
}

187 

188 

def BinaryField(klass: typing.Type[T], **kw: typing.Any) -> T:
    """
    Fields for BinaryConfig can not be annotated like a regular typed config,
    because more info is required (such as struct format/type and length).

    This actually returns a _BinaryField but when using load/load_into, the value will be replaced with type 'klass'.

    Args:
        klass (type): the final type the value will have
        format (str): either one of the formats of struct (e.g. 10s) or a loadable format (json, toml, yaml etc.)
        length (int): how many bytes of data to store? (required for str, unless you enter it in format already)

    Raises:
        KeyError: if no `format` is given and `klass` has no entry in DEFAULT_FORMATS,
            or if a loadable format is used without `length`.
        ValueError: if a multi-character struct format does not start with a number (e.g. "xs").

    Usage:
        class MyConfig(BinaryConfig):
            string = BinaryField(str, length=5) # string of 5 characters
            integer = BinaryField(int)
            complex = BinaryField(OtherClass, format='json', length=64)
            # will extract 64 bytes of string and try to convert to the linked class
            # (using regular typeconfig logic)
    """
    special = None
    packer = None

    if issubclass(klass, BinaryConfig):
        fmt = "s"  # temporarily group as one string
        # by default a nested config takes exactly as many bytes as its own fields need
        length = kw.get("length", klass._get_length())
    else:
        fmt = kw.get("format") or DEFAULT_FORMATS[klass]
        if loader := loaders.get(fmt, None):
            # loadable format (e.g. json): stored as a fixed-size string and
            # converted by the loader after unpacking
            special = lambda data: loader(  # noqa: E731
                BytesIO(data if isinstance(data, bytes) else data.encode()), Path()
            )
            if _packer := DUMPERS.get(fmt, None):
                packer = lambda data: _packer(data, with_top_level_key=False)  # noqa: E731
            length = kw["length"]
            fmt = "s"
        elif len(fmt) > 1:
            # length in format: 10s
            length, fmt = int(fmt[:-1]), fmt[-1]
        else:
            length = kw.get("length", 1)

    field = _BinaryField(
        klass,
        fmt=fmt,
        length=length,
        special=special,
        packer=packer,
    )

    # Deliberate lie to the type checker: the declared return type is T (an instance of
    # klass), because that is what the attribute holds once the config has been loaded.
    return typing.cast(T, field)