Coverage for src/configuraptor/binary_config.py: 100%

91 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-18 12:33 +0200

1""" 

2Logic to do with parsing bytestrings as configuration (using struct). 

3""" 

4 

5import collections 

6import struct 

7import typing 

8from dataclasses import dataclass 

9from io import BytesIO 

10from pathlib import Path 

11 

12from typing_extensions import Self 

13 

14from . import loaders 

15from .abs import AbstractTypedConfig 

16from .helpers import is_custom_class 

17from .loaders.register import DUMPERS 

18 

# Value types used in the return annotation of BinaryConfig._parse.
BINARY_TYPES = typing.Union[str, float, int, bool]

20 

21 

class BinaryConfig(AbstractTypedConfig):
    """
    Inherit this class if you want your config or a section of it to be parsed using struct.

    Fields are declared as class attributes holding the result of `BinaryField(...)`.
    Their declaration order determines the order of the values in the bytestring.
    """

    _fields: collections.OrderedDict[str, "_BinaryField"]

    def __init__(self) -> None:
        """
        Before filling the class with data, we store the fields (BinaryField) for later use.
        """
        fields, elements = self._collect_fields()
        self._fields = collections.OrderedDict(zip(fields, elements))
        super().__init__()

    @classmethod
    def _collect_fields(cls) -> tuple[list[str], list["_BinaryField"]]:
        """
        Get the class' field names and dataclass instances.

        Only attributes found in `cls.__dict__` are considered; names starting with an
        underscore and values that are not a `_BinaryField` are skipped.

        Returns:
            two parallel lists (names, field descriptions) in class definition order.
        """
        elements: list[_BinaryField] = []
        fields: list[str] = []

        for field, value in cls.__dict__.items():
            if field.startswith("_") or not isinstance(value, _BinaryField):
                # private name or other data, skip
                continue

            fields.append(field)
            elements.append(value)

        return fields, elements

    @classmethod
    def _parse(cls, data: bytes | dict[str, bytes]) -> dict[str, BINARY_TYPES]:
        """
        Parse a bytestring or a dict of bytestrings (in the right order).

        Args:
            data: either the full bytestring, or a dict mapping every field name to its
                own chunk of bytes (the chunks are joined in field order before unpacking).

        Returns:
            a dict of field name to converted value.

        Raises:
            KeyError: if `data` is a dict that lacks one of the fields.
            struct.error: if the data does not match the combined struct format.
        """
        from .core import load_into  # imported locally, presumably to avoid a circular import

        # NOTE: annotations not used!
        fields, elements = cls._collect_fields()

        if isinstance(data, dict):
            # create one long bytestring of data in the right order:
            data = b"".join(data[field] for field in fields)

        # every field renders itself as '<length><format char>', e.g. '5s 1i 1f'
        unpacked = struct.unpack(" ".join(str(_) for _ in elements), data)
        final_data: dict[str, BINARY_TYPES] = {}
        zipped: typing.Iterable[tuple[str, typing.Any, _BinaryField]] = zip(fields, unpacked, elements)
        for field, value, meta in zipped:
            if isinstance(value, bytes):
                # drop NUL bytes at both ends (struct pads 's' values with NULs) and turn into str
                value = value.strip(b"\x00").decode()

            if meta.special:
                # e.g. a json/toml/yaml loader that turns the text into structured data
                value = meta.special(value)

            if is_custom_class(meta.klass):
                value = load_into(meta.klass, value)
            else:
                # ensure it's the right class (e.g. bool):
                value = meta.klass(value)

            final_data[field] = value

        return final_data

    @classmethod
    def _parse_into(cls, data: bytes | dict[str, bytes]) -> Self:
        """
        Create a new instance based on data.

        The parsed values are written straight into the instance `__dict__`, replacing
        the `_BinaryField` placeholders that live on the class.
        """
        converted = cls._parse(data)
        inst = cls()
        inst.__dict__.update(**converted)
        return inst

    def _pack(self) -> bytes:
        """
        Pack an instance back into a bytestring.

        Values are emitted in field declaration order (the same order used to build the
        struct format), regardless of the order in which the attributes were assigned on
        the instance. Instance attributes that are not binary fields are ignored.

        Raises:
            struct.error: if a field has no value on the instance or a value does not fit its format.
        """
        fmt = " ".join(str(_) for _ in self._fields.values())

        # Look values up by field name instead of relying on __dict__ insertion order,
        # so each value always lines up with its own slot in the format string.
        current = self.__dict__
        values = [meta.pack(current[name]) for name, meta in self._fields.items() if name in current]
        return struct.pack(fmt, *values)

109 

110 

111@dataclass 

112class _BinaryField: 

113 """ 

114 Class that stores info to parse the value from a bytestring. 

115 

116 Returned by BinaryField, but overwritten on instances with the actual value of type klass. 

117 """ 

118 

119 klass: typing.Type[typing.Any] 

120 length: int 

121 fmt: str 

122 special: typing.Callable[[typing.Any], typing.Any] | None 

123 packer: typing.Callable[[typing.Any], typing.Any] | None 

124 

125 def __str__(self) -> str: 

126 return f"{self.length}{self.fmt}" 

127 

128 def pack(self, value: typing.Any) -> typing.Any: 

129 if self.packer: 

130 value = self.packer(value) 

131 if isinstance(value, str): 

132 return value.encode() 

133 return value 

134 

135 

# Type variable tying the class passed to BinaryField to its (pretended) return type.
T = typing.TypeVar("T")

# Reference table kept from the original source (not used by the code), see
# https://docs.python.org/3/library/struct.html
# DEFAULT_LENGTHS = {
#     "x": 1,
#     "c": 1,
#     "b": 1,
#     "?": 1,
#     "h": 2,
#     "H": 2,
#     "i": 4,
#     "I": 4,
#     "l": 4,
#     "L": 4,
#     "q": 8,
#     "Q": 8,
#     "n": 8,
#     "N": 8,
#     "e": 2,
#     "f": 4,
#     "d": 8,
#     "s": 1,
#     "p": 1,
#     "P": 8,
# }

# struct format character used for a type when BinaryField gets no explicit format.
DEFAULT_FORMATS = dict(
    (
        (str, "s"),
        (int, "i"),
        (float, "f"),
        (bool, "b"),
    )
)

168 

169 

def BinaryField(klass: typing.Type[T], **kw: typing.Any) -> typing.Type[T]:
    """
    Declare a field on a BinaryConfig.

    Plain annotations are not enough for binary configs, because extra information
    (struct format and length) is needed to read the value from a bytestring.

    The object returned is really a _BinaryField; it is only typed as 'klass' because
    load/load_into replace it with a value of that type.

    Args:
        klass (type): the final type the value will have
        format (str): either a struct format (optionally with length, e.g. 10s) or a
            loadable format known to the loaders (json, toml, yaml etc.).
            When omitted, the default struct format for klass is used.
        length (int): how many bytes of data to store. Mandatory with a loadable format,
            ignored when the format already contains a length, otherwise defaults to 1.

    Raises:
        KeyError: if no format is given and klass has no default format,
            or if a loadable format is used without length.

    Usage:
        class MyConfig(BinaryConfig):
            string = BinaryField(str, length=5) # string of 5 characters
            integer = BinaryField(int)
            complex = BinaryField(OtherClass, format='json', length=64)
            # will extract 64 bytes of string and try to convert to the linked class
            # (using regular typeconfig logic)
    """
    chosen = kw.get("format") or DEFAULT_FORMATS[klass]
    on_load: typing.Callable[[typing.Any], typing.Any] | None = None
    on_dump: typing.Callable[[typing.Any], typing.Any] | None = None

    load_func = loaders.get(chosen, None)
    if load_func:
        # loadable format: the raw text is stored as a fixed-size string of bytes

        def _read(data: typing.Any) -> typing.Any:
            raw = data if isinstance(data, bytes) else data.encode()
            return load_func(BytesIO(raw), Path())

        on_load = _read

        dump_func = DUMPERS.get(chosen, None)
        if dump_func:

            def _write(data: typing.Any) -> typing.Any:
                return dump_func(data, with_top_level_key=False)

            on_dump = _write

        size = kw["length"]
        code = "s"
    elif len(chosen) > 1:
        # length in format: 10s
        size = int(chosen[:-1])
        code = chosen[-1]
    else:
        size = kw.get("length", 1)
        code = chosen

    described = _BinaryField(klass, fmt=code, length=size, special=on_load, packer=on_dump)
    return typing.cast(typing.Type[T], described)