Coverage for tests / unit / no_torch / test_zanj_serializable_dataclass.py: 100%

91 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-21 22:18 -0700

1from __future__ import annotations 

2 

3import json 

4import sys 

5import typing 

6from pathlib import Path 

7 

8import numpy as np 

9import pandas as pd # type: ignore[import] 

10from muutils.json_serialize import ( 

11 SerializableDataclass, 

12 serializable_dataclass, 

13 serializable_field, 

14) 

15 

16from zanj import ZANJ 

17 

# Seed NumPy's global RNG so the random arrays built by the tests below are
# reproducible from run to run.
np.random.seed(0)

# Directory (relative to the working directory) that receives the .zanj files
# written by the tests in this module.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"

# True when running on Python 3.10 or newer.
# NOTE(review): presumably gates `kw_only` dataclass features; it is not used
# in the visible part of this module -- confirm before removing.
SUPPORTS_KW_ONLY: bool = sys.version_info >= (3, 10)

23 

24 

@serializable_dataclass
class BasicZanj(SerializableDataclass):
    """Flat serializable dataclass used as the simplest round-trip fixture."""

    # required string field
    a: str
    # integer field, defaults to 42
    q: int = 42
    # list of ints; each instance gets its own fresh empty list by default
    c: typing.List[int] = serializable_field(default_factory=list)

30 

31 

def test_Basic():
    """Save a `BasicZanj` with ZANJ, read it back, and check it is unchanged."""
    original = BasicZanj("hello", 42, [1, 2, 3])
    out_path = TEST_DATA_PATH / "test_BasicZanj.zanj"

    zanj = ZANJ()
    zanj.save(original, out_path)
    loaded = zanj.read(out_path)

    assert original == loaded

40 

41 

@serializable_dataclass
class Nested(SerializableDataclass):
    """Serializable dataclass that embeds a `BasicZanj` as one of its fields."""

    # label for this record
    name: str
    # embedded serializable dataclass
    basic: BasicZanj
    # plain float payload
    val: float

47 

48 

def test_Nested():
    """Round-trip a `Nested` (dataclass inside a dataclass) through a ZANJ file."""
    inner = BasicZanj("hello", 42, [1, 2, 3])
    original = Nested("hello", inner, 3.14)
    out_path = TEST_DATA_PATH / "test_Nested.zanj"

    zanj = ZANJ()
    zanj.save(original, out_path)
    loaded = zanj.read(out_path)

    assert original == loaded

57 

58 

@serializable_dataclass
class Nested_with_container(SerializableDataclass):
    """Like `Nested`, plus a list field holding further `Nested` instances."""

    # label for this record
    name: str
    # embedded serializable dataclass
    basic: BasicZanj
    # plain float payload
    val: float
    # list of nested dataclasses; defaults to a fresh empty list per instance
    container: typing.List[Nested] = serializable_field(default_factory=list)

65 

66 

def test_Nested_with_container():
    """Round-trip a dataclass whose list field contains other dataclasses."""
    top_basic = BasicZanj("hello", 42, [1, 2, 3])
    children = [
        Nested("n1", BasicZanj("n1_b", 123, [4, 5, 7]), 2.71),
        Nested("n2", BasicZanj("n2_b", 456, [7, 8, 9]), 6.28),
    ]
    original = Nested_with_container(
        "hello",
        basic=top_basic,
        val=3.14,
        container=children,
    )
    out_path = TEST_DATA_PATH / "test_Nested_with_container.zanj"

    zanj = ZANJ()
    zanj.save(original, out_path)
    loaded = zanj.read(out_path)

    assert original == loaded

83 

84 

@serializable_dataclass
class sdc_with_np_array(SerializableDataclass):
    """Serializable dataclass carrying two NumPy arrays alongside a name."""

    # label for this record
    name: str
    # first array payload
    arr1: np.ndarray
    # second array payload
    arr2: np.ndarray

90 

91 

def test_sdc_with_np_array_small():
    """Round-trip a dataclass holding two small 1-D random arrays (10 and 20 items)."""
    instance = sdc_with_np_array("small arrays", np.random.rand(10), np.random.rand(20))

    z = ZANJ()
    # Dedicated file name for this test: "test_sdc_with_np_array.zanj" is
    # written by test_sdc_with_np_array, and sharing it makes the two tests
    # overwrite each other's output (and race when tests run in parallel).
    path = TEST_DATA_PATH / "test_sdc_with_np_array_small.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered

100 

101 

def test_sdc_with_np_array():
    """Round-trip a dataclass holding two larger 2-D random arrays."""
    # Draw in the same order as the fields (arr1 first, then arr2) so the
    # values taken from the seeded global RNG stay the same.
    first = np.random.rand(128, 128)
    second = np.random.rand(256, 256)
    original = sdc_with_np_array("bigger arrays", first, second)
    out_path = TEST_DATA_PATH / "test_sdc_with_np_array.zanj"

    zanj = ZANJ()
    zanj.save(original, out_path)
    loaded = zanj.read(out_path)

    assert original == loaded

112 

113 

@serializable_dataclass
class sdc_with_df(SerializableDataclass):
    """Serializable dataclass carrying two pandas DataFrames alongside a name."""

    # label for this record
    name: str
    # first DataFrame payload
    iris_data: pd.DataFrame
    # second DataFrame payload
    brain_data: pd.DataFrame

119 

120 

def test_sdc_with_df():
    """Round-trip a dataclass whose fields are DataFrames loaded from CSV files."""
    iris = pd.read_csv("tests/input_data/iris.csv")
    brain = pd.read_csv("tests/input_data/brain_networks.csv")
    original = sdc_with_df("downloaded_data", iris_data=iris, brain_data=brain)
    out_path = TEST_DATA_PATH / "test_sdc_with_df.zanj"

    zanj = ZANJ()
    zanj.save(original, out_path)
    loaded = zanj.read(out_path)

    assert original == loaded

133 

134 

@serializable_dataclass
class sdc_container_explicit(SerializableDataclass):
    """Fixture whose `container` field uses explicit (de)serialization hooks.

    The list of `Nested` items is stored as a single JSONL string: each item is
    serialized, dumped with `json.dumps`, and the resulting lines are joined
    with newlines. Loading splits that string on newlines and passes each
    parsed line to `Nested.load`.
    """

    # label for this record
    name: str
    container: typing.List[Nested] = serializable_field(
        default_factory=list,
        # as jsonl string, for whatever reason
        serialization_fn=lambda c: "\n".join([json.dumps(n.serialize()) for n in c]),
        # An empty container serializes to "", and "".split("\n") yields [""];
        # skip blank lines so the default (empty) container loads back as []
        # instead of failing inside json.loads("").
        loading_fn=lambda data: [
            Nested.load(json.loads(line))
            for line in data["container"].split("\n")
            if line
        ],
        # TODO: explicitly specifying the following does not work, since it gets automatically converted before we call load in `loading_fn`:
        # serialization_fn=lambda c: [n.serialize() for n in c],
        # loading_fn=lambda data: [Nested.load(n) for n in data["container"]],
    )

149 

150 

def test_sdc_container_explicit():
    """Round-trip a dataclass whose container uses custom JSONL (de)serialization."""

    def make_item(idx: int) -> Nested:
        # One `Nested` per index, with values derived from the index.
        inner = BasicZanj(f"n-{idx}_b", idx * 10 + 1, [idx + 1, idx + 2, idx + 10])
        return Nested(f"n-{idx}", inner, idx * np.pi)

    original = sdc_container_explicit(
        "container explicit",
        container=[make_item(idx) for idx in range(10)],
    )
    out_path = TEST_DATA_PATH / "test_sdc_container_explicit.zanj"

    zanj = ZANJ()
    zanj.save(original, out_path)
    loaded = zanj.read(out_path)

    assert original == loaded

168 assert instance == recovered