Coverage for tests\unit\json_serialize\serializable_dataclass\test_sdc_defaults.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-09 01:48 -0600

1from __future__ import annotations 

2 

3 

4from muutils.json_serialize import ( 

5 JsonSerializer, 

6 SerializableDataclass, 

7 serializable_dataclass, 

8 serializable_field, 

9) 

10 

11# pylint: disable=missing-class-docstring 

12 

13 

@serializable_dataclass
class Config(SerializableDataclass):
    """Small config where every field has a default, so ``Config()`` takes no arguments."""

    # all three fields are defaulted; the tests below construct this class empty
    name: str = serializable_field(
        default="default_name",
    )
    save_model: bool = serializable_field(
        default=True,
    )
    batch_size: int = serializable_field(
        default=64,
    )

19 

20 

def test_sdc_empty():
    """Round-trip a default-constructed ``Config`` through ``serialize`` / ``load``."""
    expected_serialized = {
        "name": "default_name",
        "save_model": True,
        "batch_size": 64,
        "__format__": "Config(SerializableDataclass)",
    }
    original = Config()
    as_dict = original.serialize()
    assert as_dict == expected_serialized
    # loading the serialized dict must give back an instance equal to the original
    assert Config.load(as_dict) == original

32 

33 

def test_sdc_strip_format_jser():
    """Serialize ``Config`` via ``JsonSerializer(write_only_format=True)`` and reload it.

    With that option the format marker is expected under ``"__write_format__"``
    rather than the ``"__format__"`` key produced by ``Config.serialize()``,
    and ``Config.load`` must still accept the result.
    """
    original = Config()
    serializer: JsonSerializer = JsonSerializer(write_only_format=True)
    as_json = serializer.json_serialize(original)
    assert as_json == {
        "name": "default_name",
        "save_model": True,
        "batch_size": 64,
        "__write_format__": "Config(SerializableDataclass)",
    }
    assert Config.load(as_json) == original

46 

47 

# maps a builtin type's ``__name__`` back to the type itself, e.g. "float" -> float;
# used to deserialize ``ComplicatedConfig.some_type``
TYPE_MAP: dict[str, type] = {
    builtin.__name__: builtin for builtin in (int, float, str, bool)
}

49 

50 

def _type_to_name(type_obj: type) -> str:
    """Serialize a type as its ``__name__`` (e.g. ``float`` -> ``"float"``)."""
    return type_obj.__name__


def _type_from_serialized(data: dict) -> type:
    """Rebuild ``some_type`` from the serialized dict by looking its name up in ``TYPE_MAP``.

    Raises ``KeyError`` for names not in ``TYPE_MAP`` (anything other than
    int/float/str/bool), even though ``_type_to_name`` will serialize them.
    """
    return TYPE_MAP[data["some_type"]]


@serializable_dataclass
class ComplicatedConfig(SerializableDataclass):
    """Same defaulted fields as ``Config`` plus a ``type``-valued field with custom (de)serialization."""

    name: str = serializable_field(default="default_name")
    save_model: bool = serializable_field(default=True)
    batch_size: int = serializable_field(default=64)

    # a type object can't be stored in JSON directly, so it is written as its
    # name and recovered on load; note the loading function receives the whole
    # serialized dict and picks out "some_type" itself
    some_type: type = serializable_field(
        default=float,
        serialization_fn=_type_to_name,
        loading_fn=_type_from_serialized,
    )

62 

63 

def test_sdc_empty_complicated():
    """Round-trip ``ComplicatedConfig``, checking the custom ``type`` (de)serialization.

    Like ``test_sdc_empty`` does for ``Config``, this pins the serialized form
    so that a regression in ``serialization_fn`` is caught directly rather than
    only via a failing load. It also round-trips a non-default type so
    ``loading_fn`` is exercised on a value other than the default.
    """
    instance = ComplicatedConfig()
    serialized = instance.serialize()
    assert serialized == {
        "name": "default_name",
        "save_model": True,
        "batch_size": 64,
        # the type is stored by name via serialization_fn
        "some_type": "float",
        "__format__": "ComplicatedConfig(SerializableDataclass)",
    }
    recovered = ComplicatedConfig.load(serialized)
    assert recovered == instance

    # a non-default type (present in TYPE_MAP) must survive the round trip too
    custom = ComplicatedConfig(some_type=int)
    custom_serialized = custom.serialize()
    assert custom_serialized["some_type"] == "int"
    custom_recovered = ComplicatedConfig.load(custom_serialized)
    assert custom_recovered == custom
    assert custom_recovered.some_type is int