Coverage for tests\unit\json_serialize\serializable_dataclass\test_serializable_dataclass.py: 100%

204 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-09 01:48 -0600

1from __future__ import annotations 

2 

3import typing 

4from typing import Any 

5 

6import pytest 

7 

8from muutils.json_serialize import ( 

9 SerializableDataclass, 

10 serializable_dataclass, 

11 serializable_field, 

12) 

13 

14from muutils.json_serialize.serializable_dataclass import ( 

15 FieldIsNotInitOrSerializeWarning, 

16) 

17 

18# pylint: disable=missing-class-docstring, unused-variable 

19 

20 

@serializable_dataclass
class BasicAutofields(SerializableDataclass):
    """Serializable dataclass whose fields are declared only by annotation.

    No `serializable_field(...)` options are given, so every field relies on
    the decorator's default handling.
    """

    a: str
    b: int
    c: typing.List[int]

26 

27 

def test_basic_auto_fields():
    """Serialize, compare, diff and reload a dataclass with annotation-only fields."""
    data = {"a": "hello", "b": 42, "c": [1, 2, 3]}
    instance = BasicAutofields(**data)

    # serialize() returns the field values plus a "__format__" tag naming the class
    expected = {**data, "__format__": "BasicAutofields(SerializableDataclass)"}
    assert instance.serialize() == expected

    # compare against an independently built instance; `instance == instance`
    # would be trivially true and prove nothing about field-wise equality
    twin = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    assert instance == twin
    assert instance.diff(instance) == {}
    assert instance.diff(twin) == {}

    # a serialize -> load round trip must reproduce an equal instance
    assert BasicAutofields.load(instance.serialize()) == instance

36 

37 

def test_basic_diff():
    """`diff` reports only differing fields, as {"self": ..., "other": ...} pairs."""
    base = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    other_a = BasicAutofields(a="goodbye", b=42, c=[1, 2, 3])
    other_b = BasicAutofields(a="hello", b=-1, c=[1, 2, 3])
    other_bc = BasicAutofields(a="hello", b=-1, c=[42])

    # each case: (left, right, expected result of left.diff(right))
    cases = [
        (base, other_a, {"a": {"self": "hello", "other": "goodbye"}}),
        (base, other_b, {"b": {"self": 42, "other": -1}}),
        (
            base,
            other_bc,
            {
                "b": {"self": 42, "other": -1},
                "c": {"self": [1, 2, 3], "other": [42]},
            },
        ),
        (base, base, {}),
        (
            other_a,
            other_b,
            {
                "a": {"self": "goodbye", "other": "hello"},
                "b": {"self": 42, "other": -1},
            },
        ),
    ]
    for left, right, expected in cases:
        assert left.diff(right) == expected

55 

56 

@serializable_dataclass
class SimpleFields(SerializableDataclass):
    """Serializable dataclass mixing one required field with two defaulted fields."""

    d: str
    e: int = 42
    # default_factory=list (as with dataclasses.field): a fresh list per instance
    # instead of a shared mutable default
    f: typing.List[int] = serializable_field(default_factory=list)

62 

63 

@serializable_dataclass
class FieldOptions(SerializableDataclass):
    """Serializable dataclass exercising per-field options of `serializable_field`."""

    a: str = serializable_field()
    b: str = serializable_field()
    # excluded from __init__, serialization, repr and comparison; no default is
    # given, so __init__ never assigns `c`
    c: str = serializable_field(init=False, serialize=False, repr=False, compare=False)
    # serialized upper-cased; loading_fn is handed the whole serialized dict,
    # hence the x["d"] lookup before lower-casing it back
    d: str = serializable_field(
        serialization_fn=lambda x: x.upper(), loading_fn=lambda x: x["d"].lower()
    )

72 

73 

@serializable_dataclass(properties_to_serialize=["full_name"])
class WithProperty(SerializableDataclass):
    """Serializable dataclass that also emits its computed `full_name` property."""

    first_name: str
    last_name: str

    @property
    def full_name(self) -> str:
        """First and last name separated by a single space."""
        return "{} {}".format(self.first_name, self.last_name)

82 

83 

class Child(FieldOptions, WithProperty):
    """Subclass combining `FieldOptions` and `WithProperty` by multiple inheritance.

    It is not itself decorated with `serializable_dataclass`.
    NOTE(review): not referenced by any test in this section of the file —
    confirm whether it is still needed.
    """

    pass

86 

87 

@pytest.fixture
def simple_fields_instance():
    """Provide a `SimpleFields` with every field passed explicitly."""
    values = {"d": "hello", "e": 42, "f": [1, 2, 3]}
    return SimpleFields(**values)

91 

92 

@pytest.fixture
def field_options_instance():
    """Provide a `FieldOptions`; `c` is init=False so it is not passed."""
    kwargs = {"a": "hello", "b": "world", "d": "case"}
    return FieldOptions(**kwargs)

96 

97 

@pytest.fixture
def with_property_instance():
    """Provide a `WithProperty` whose `full_name` is "John Doe"."""
    person = WithProperty(
        first_name="John",
        last_name="Doe",
    )
    return person

101 

102 

def test_simple_fields_serialization(simple_fields_instance):
    """Serializing `SimpleFields` yields each field plus the `__format__` tag."""
    expected = {
        "__format__": "SimpleFields(SerializableDataclass)",
        "d": "hello",
        "e": 42,
        "f": [1, 2, 3],
    }
    assert simple_fields_instance.serialize() == expected

111 

112 

def test_simple_fields_loading(simple_fields_instance):
    """A serialize/load round trip reproduces an equal `SimpleFields`."""
    original = simple_fields_instance
    restored = SimpleFields.load(original.serialize())

    assert restored == original
    # the diff must be empty in both directions
    for left, right in ((restored, original), (original, restored)):
        assert left.diff(right) == {}

121 

122 

def test_field_options_serialization(field_options_instance):
    """`c` is omitted (serialize=False) and `d` passes through its serialization_fn."""
    payload = field_options_instance.serialize()
    assert payload == dict(
        a="hello",
        b="world",
        d="CASE",
        __format__="FieldOptions(SerializableDataclass)",
    )

131 

132 

def test_field_options_loading(field_options_instance):
    """Loading `FieldOptions` must emit `FieldIsNotInitOrSerializeWarning` and round-trip.

    `FieldOptions.c` is declared with init=False and serialize=False, which is
    presumably what triggers the warning. The test requires the warning to be
    raised and the loaded instance to equal the original.
    """
    serialized = field_options_instance.serialize()
    # expect (not ignore) the warning: pytest.warns fails the test if it is absent
    with pytest.warns(FieldIsNotInitOrSerializeWarning):
        loaded = FieldOptions.load(serialized)
    assert loaded == field_options_instance

139 

140 

def test_with_property_serialization(with_property_instance):
    """Properties listed in `properties_to_serialize` appear in the output."""
    result = with_property_instance.serialize()
    expected = {"__format__": "WithProperty(SerializableDataclass)"}
    expected.update(first_name="John", last_name="Doe", full_name="John Doe")
    assert result == expected

149 

150 

def test_with_property_loading(with_property_instance):
    """Loading data that includes a serialized property restores an equal instance."""
    payload = with_property_instance.serialize()
    assert WithProperty.load(payload) == with_property_instance

155 

156 

@serializable_dataclass
class Address(SerializableDataclass):
    """Plain serializable dataclass; used as the nested field type of `Person`."""

    street: str
    city: str
    zip_code: str

162 

163 

@serializable_dataclass
class Person(SerializableDataclass):
    """Serializable dataclass containing another serializable dataclass."""

    name: str
    age: int
    # nested SerializableDataclass; serialized as its own dict with its own __format__
    address: Address

169 

170 

@pytest.fixture
def address_instance():
    """Provide a fully populated `Address`."""
    address = Address(
        street="123 Main St",
        city="New York",
        zip_code="10001",
    )
    return address

174 

175 

@pytest.fixture
def person_instance(address_instance):
    """Provide a `Person` whose nested address comes from `address_instance`."""
    person = Person(name="John Doe", age=30, address=address_instance)
    return person

179 

180 

def test_nested_serialization(person_instance):
    """A nested dataclass field serializes to its own dict with its own tag."""
    address_ser = {
        "street": "123 Main St",
        "city": "New York",
        "zip_code": "10001",
        "__format__": "Address(SerializableDataclass)",
    }
    expected = {
        "name": "John Doe",
        "age": 30,
        "address": address_ser,
        "__format__": "Person(SerializableDataclass)",
    }
    assert person_instance.serialize() == expected

195 

196 

def test_nested_loading(person_instance):
    """Loading restores the nested `Address` as well as the outer `Person`."""
    restored = Person.load(person_instance.serialize())
    assert restored == person_instance
    # check the nested object explicitly as well
    assert restored.address == person_instance.address

202 

203 

def test_with_printing():
    """Round-trip a class with a custom-serialized field and a serialized property.

    Both the serialized data and the loaded instance are printed (visible with
    ``pytest -s``). They are also asserted, so a broken round trip fails the
    test instead of merely printing something different.
    """

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class MyClass(SerializableDataclass):
        name: str
        # stored as age + 1; loading_fn receives the whole dict and undoes the shift
        age: int = serializable_field(
            serialization_fn=lambda x: x + 1, loading_fn=lambda x: x["age"] - 1
        )
        items: list = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    my_instance = MyClass(name="John", age=30, items=["apple", "banana"])
    serialized_data = my_instance.serialize()
    print(serialized_data)
    # pin the values the example demonstrates rather than only printing them
    assert serialized_data["age"] == 31
    assert serialized_data["items"] == ["apple", "banana"]
    assert serialized_data["full_name"] == "John Doe"

    loaded_instance = MyClass.load(serialized_data)
    print(loaded_instance)
    assert loaded_instance == my_instance

224 

225 

def test_simple_class_serialization():
    """A locally defined serializable dataclass round-trips through serialize/load."""

    @serializable_dataclass
    class SimpleClass(SerializableDataclass):
        a: int
        b: str

    original = SimpleClass(a=42, b="hello")
    payload = original.serialize()
    assert payload == {
        "__format__": "SimpleClass(SerializableDataclass)",
        "a": 42,
        "b": "hello",
    }
    assert SimpleClass.load(payload) == original

242 

243 

def test_error_when_init_and_not_serialize():
    """A field declared with init=True but serialize=False must raise ValueError."""
    # the whole class definition sits inside the context manager so the error is
    # caught whether it comes from serializable_field or from the decorator
    with pytest.raises(ValueError):

        @serializable_dataclass
        class _InitButNotSerialized(SerializableDataclass):
            a: int = serializable_field(init=True, serialize=False)

250 

251 

def test_person_serialization():
    """Defaults, default_factory fields and a serialized property all round-trip."""

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class FullPerson(SerializableDataclass):
        name: str = serializable_field()
        age: int = serializable_field(default=-1)
        items: typing.List[str] = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    person = FullPerson(name="John", items=["apple", "banana"])
    serialized = person.serialize()
    expected_ser = dict(
        name="John",
        age=-1,  # not passed to the constructor, so the field default applies
        items=["apple", "banana"],
        full_name="John Doe",
        __format__="FullPerson(SerializableDataclass)",
    )
    assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}"

    assert FullPerson.load(serialized) == person

277 

278 

def test_custom_serialization():
    """serialization_fn / loading_fn transform a field on the way out and back in."""

    @serializable_dataclass
    class CustomSerialization(SerializableDataclass):
        # doubled when serialized; loading_fn reads the dict entry and halves it
        data: Any = serializable_field(
            serialization_fn=lambda x: x * 2, loading_fn=lambda x: x["data"] // 2
        )

    original = CustomSerialization(data=5)
    payload = original.serialize()
    expected = {"data": 10}
    expected["__format__"] = "CustomSerialization(SerializableDataclass)"
    assert payload == expected

    assert CustomSerialization.load(payload) == original

295 

296 

@serializable_dataclass
class Nested_with_Container(SerializableDataclass):
    """Serializable dataclass holding a list of other serializable dataclasses.

    The list field uses explicit serialization/loading functions to convert each
    element to and from its serialized dict form.
    """

    val_int: int
    val_str: str
    val_list: typing.List[BasicAutofields] = serializable_field(
        default_factory=list,
        # serialize element by element
        serialization_fn=lambda x: [y.serialize() for y in x],
        # loading_fn receives the whole serialized dict, so index the field by name
        loading_fn=lambda x: [BasicAutofields.load(y) for y in x["val_list"]],
    )

306 

307 

def test_nested_with_container():
    """A list of nested dataclasses serializes element-wise and loads back equal."""
    instance = Nested_with_Container(
        val_int=42,
        val_str="hello",
        val_list=[
            BasicAutofields(a="a", b=1, c=[1, 2, 3]),
            BasicAutofields(a="b", b=2, c=[4, 5, 6]),
        ],
    )

    fmt_basic = "BasicAutofields(SerializableDataclass)"
    expected_ser = {
        "val_int": 42,
        "val_str": "hello",
        "val_list": [
            {"a": "a", "b": 1, "c": [1, 2, 3], "__format__": fmt_basic},
            {"a": "b", "b": 2, "c": [4, 5, 6], "__format__": fmt_basic},
        ],
        "__format__": "Nested_with_Container(SerializableDataclass)",
    }

    serialized = instance.serialize()
    assert serialized == expected_ser
    assert Nested_with_Container.load(serialized) == instance

344 

345 

class Custom_class_with_serialization:
    """Custom class that does not inherit from SerializableDataclass.

    It still provides ``serialize`` and ``load``, so it can be used as the type
    of a serializable dataclass field.
    """

    def __init__(self, a: int, b: str):
        self.a: int = a
        self.b: str = b

    def serialize(self):
        """Return the attributes as a plain ``{"a": ..., "b": ...}`` dict."""
        return {"a": self.a, "b": self.b}

    @classmethod
    def load(cls, data):
        """Build an instance from a dict with keys ``"a"`` and ``"b"``."""
        return cls(data["a"], data["b"])

    def __eq__(self, other):
        """Compare ``a`` and ``b`` with another instance of this class.

        Returns NotImplemented for other types. Python then falls back to its
        default comparison instead of raising AttributeError on ``other.a``
        (e.g. for ``x == None`` or ``x != 5``).
        """
        if not isinstance(other, Custom_class_with_serialization):
            return NotImplemented
        return (self.a == other.a) and (self.b == other.b)

362 

363 

@serializable_dataclass
class nested_custom(SerializableDataclass):
    """Serializable dataclass with a field whose type is a plain class exposing serialize/load."""

    value: float
    # not a SerializableDataclass; Custom_class_with_serialization supplies its own serialize()/load()
    data1: Custom_class_with_serialization

368 

369 

def test_nested_custom(recwarn):
    """A field of a custom (non-dataclass) type round-trips via its serialize/load.

    The ``recwarn`` fixture is requested to record the warnings this test is
    known to emit, without asserting on them.
    NOTE(review): which call emits them is not visible here — confirm and
    consider asserting on ``recwarn`` explicitly.
    """
    original = nested_custom(
        value=42.0, data1=Custom_class_with_serialization(1, "hello")
    )
    payload = original.serialize()
    assert payload == {
        "value": 42.0,
        "data1": {"a": 1, "b": "hello"},
        "__format__": "nested_custom(SerializableDataclass)",
    }
    assert nested_custom.load(payload) == original

383 

384 

def test_deserialize_fn():
    """`deserialize_fn` converts the stored field value back when loading."""

    @serializable_dataclass
    class DeserializeFn(SerializableDataclass):
        # stored as a string; deserialize_fn gets only this field's stored value
        # (unlike loading_fn, which indexes the whole dict) and parses it back
        data: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x),
        )

    original = DeserializeFn(data=5)
    payload = original.serialize()
    assert payload == {
        "__format__": "DeserializeFn(SerializableDataclass)",
        "data": "5",
    }

    restored = DeserializeFn.load(payload)
    assert restored == original
    assert restored.data == 5