Coverage for tests / unit / no_torch / test_load_item_recursive.py: 99%

86 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 19:31 -0700

1from __future__ import annotations 

2 

3import typing 

4from pathlib import Path 

5 

6import numpy as np 

7import pytest 

8from muutils.errormode import ErrorMode 

9from muutils.json_serialize import ( 

10 SerializableDataclass, 

11 serializable_dataclass, 

12 serializable_field, 

13) 

14from zanj import ZANJ 

15from zanj.consts import _FORMAT_KEY 

16from zanj.loading import LoadedZANJ, load_item_recursive 

17 

# Directory for files written by these tests (e.g. the saved ``.zanj`` archive).
TEST_DATA_PATH: Path = Path("tests/junk_data")

19 

20 

def test_load_item_recursive_basic():
    """Check that plain JSON containers are returned equal to the input."""
    payload = {
        "name": "test",
        "value": 42,
        "list": [1, 2, 3],
        "nested": {"a": 1, "b": 2},
    }

    # empty path, no ZANJ instance
    loaded = load_item_recursive(payload, (), None)

    # whole-structure equality first, then each top-level entry on its own
    assert loaded == payload
    expected_entries = (
        ("name", "test"),
        ("value", 42),
        ("list", [1, 2, 3]),
        ("nested", {"a": 1, "b": 2}),
    )
    for key, expected in expected_entries:
        assert loaded[key] == expected

40 

41 

def test_load_item_recursive_numpy_array():
    """Check that an ``array_list_meta`` dict is loaded as a numpy array."""
    expected = np.random.rand(5, 5)

    # hand-built JSON form of the array, tagged with the format key
    serialized = {
        _FORMAT_KEY: "numpy.ndarray:array_list_meta",
        "dtype": str(expected.dtype),
        "shape": list(expected.shape),
        "data": expected.tolist(),
    }

    # empty path, no ZANJ instance
    loaded = load_item_recursive(serialized, (), None)

    assert isinstance(loaded, np.ndarray)
    assert loaded.shape == tuple(serialized["shape"])
    assert loaded.dtype == np.dtype(serialized["dtype"])
    assert np.allclose(loaded, expected)

61 

62 

def test_load_item_recursive_serializable_dataclass():
    """Check that a serialized ``SerializableDataclass`` loads back as an instance."""

    # NOTE(review): the class name is deliberately left unchanged; it presumably
    # ends up in the serialized payload -- confirm before renaming.
    @serializable_dataclass
    class TestClass(SerializableDataclass):
        name: str
        value: int
        data: typing.List[int] = serializable_field(default_factory=list)

    # serialize a freshly built instance, then load it back
    payload = TestClass("test", 42, [1, 2, 3]).serialize()
    loaded = load_item_recursive(payload, (), None)

    assert isinstance(loaded, TestClass)
    for attr, expected in (("name", "test"), ("value", 42), ("data", [1, 2, 3])):
        assert getattr(loaded, attr) == expected

84 

85 

def test_load_item_recursive_nested_container():
    """Check that array dicts nested inside lists and dicts are loaded as arrays."""

    def _array_entry(*shape):
        # JSON form of a random float64 array with the given shape
        return {
            _FORMAT_KEY: "numpy.ndarray:array_list_meta",
            "dtype": "float64",
            "shape": list(shape),
            "data": np.random.rand(*shape).tolist(),
        }

    payload = {
        "name": "test",
        "arrays": [_array_entry(3, 3), _array_entry(2, 2)],
        "nested": {"dict_with_array": _array_entry(4, 4)},
    }

    # empty path, no ZANJ instance
    loaded = load_item_recursive(payload, (), None)

    assert loaded["name"] == "test"

    # the list of arrays
    arrays = loaded["arrays"]
    assert len(arrays) == 2
    for arr in arrays:
        assert isinstance(arr, np.ndarray)
    for arr, expected_shape in zip(arrays, ((3, 3), (2, 2))):
        assert arr.shape == expected_shape

    # the array sitting inside a nested dict
    inner = loaded["nested"]["dict_with_array"]
    assert isinstance(inner, np.ndarray)
    assert inner.shape == (4, 4)

127 

128 

def test_load_item_recursive_unknown_format():
    """Check loading of an item whose format key matches no known format."""
    unknown_item = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # with allow_not_loading=True the item is expected back unchanged
    passthrough = load_item_recursive(unknown_item, (), None, allow_not_loading=True)
    assert passthrough == unknown_item

    # TODO: this doesn't raise any errors
    # Intended to check allow_not_loading=False with EXCEPT error mode; for now
    # the call is only exercised, nothing is asserted about its outcome.
    strict_zanj = ZANJ(error_mode=ErrorMode.EXCEPT)
    load_item_recursive(
        unknown_item,
        (),
        strict_zanj,
        error_mode=ErrorMode.EXCEPT,
        allow_not_loading=False,
    )

150 

151 

def test_load_item_recursive_with_external_reference():
    """Save an array with a low external threshold and check externals get populated."""
    writer = ZANJ(external_array_threshold=10)
    contents = {"large_array": np.random.rand(20, 20)}
    zanj_file = TEST_DATA_PATH / "test_load_item_recursive_external.zanj"
    writer.save(contents, zanj_file)

    # reopen the saved file and fill in its externals
    archive = LoadedZANJ(zanj_file, writer)
    archive.populate_externals()

    assert len(archive._externals) > 0

    # NOTE(review): this checks for the literal string "_REF_KEY", not a constant,
    # so in practice it only verifies that the JSON data is a dict -- confirm intent.
    raw_json = archive._json_data
    assert "_REF_KEY" in raw_json or isinstance(raw_json, dict)

173 

174 

def test_load_item_recursive_error_modes():
    """Exercise ``load_item_recursive`` under the WARN, IGNORE and EXCEPT error modes."""
    bad_item = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # WARN and IGNORE: no exception expected, item comes back unchanged
    for lenient_mode in (ErrorMode.WARN, ErrorMode.IGNORE):
        loaded = load_item_recursive(
            bad_item,
            (),
            None,
            error_mode=lenient_mode,
            allow_not_loading=True,
        )
        assert loaded == bad_item

    class CustomHandler:
        """Handler stand-in that matches the unknown format and always fails to load."""

        def check(self, json_item, path=None, z=None):
            fmt = json_item.get(_FORMAT_KEY)
            return fmt == "unknown.format.that.definitely.does.not.exist"

        def load(self, json_item, path=None, z=None):
            raise ValueError("Forced error for testing purposes")

    import zanj.loading

    saved_loader = zanj.loading.get_item_loader

    def _always_custom_handler(*args, **kwargs):
        # ignore the arguments and hand back the failing handler
        return CustomHandler()

    # NOTE(review): the patch only matters if load_item_recursive looks up
    # get_item_loader through the zanj.loading module namespace -- confirm.
    try:
        zanj.loading.get_item_loader = _always_custom_handler

        # EXCEPT: a ValueError is expected to propagate
        with pytest.raises(ValueError):
            load_item_recursive(
                bad_item,
                (),
                None,
                error_mode=ErrorMode.EXCEPT,
                allow_not_loading=True,
            )
    finally:
        # always put the real loader lookup back
        zanj.loading.get_item_loader = saved_loader