Coverage for tests/unit/no_torch/test_load_item_recursive.py: 99%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-19 14:57 -0600

1from __future__ import annotations 

2 

3import typing 

4from pathlib import Path 

5 

6import numpy as np 

7import pytest 

8from muutils.errormode import ErrorMode 

9from muutils.json_serialize import ( 

10 SerializableDataclass, 

11 serializable_dataclass, 

12 serializable_field, 

13) 

14from muutils.json_serialize.util import _FORMAT_KEY 

15 

16from zanj import ZANJ 

17from zanj.loading import LoadedZANJ, load_item_recursive 

18 

# Directory (relative to the working directory) where tests write scratch files.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"

20 

21 

def test_load_item_recursive_basic():
    """Check that plain JSON (no format keys) comes back equal to the input."""
    payload = {
        "name": "test",
        "value": 42,
        "list": [1, 2, 3],
        "nested": {"a": 1, "b": 2},
    }

    # Empty path, no ZANJ instance: everything else is left at its default.
    loaded = load_item_recursive(payload, (), None)

    # The whole structure must match, and so must each individual entry.
    assert loaded == payload
    expected_entries = {
        "name": "test",
        "value": 42,
        "list": [1, 2, 3],
        "nested": {"a": 1, "b": 2},
    }
    for key, expected in expected_entries.items():
        assert loaded[key] == expected

41 

42 

def test_load_item_recursive_numpy_array():
    """Check that an ``array_list_meta`` JSON item is loaded as a numpy array."""
    source_array = np.random.rand(5, 5)

    # Hand-built JSON form of the array, tagged with the
    # "numpy.ndarray:array_list_meta" format string.
    serialized = {
        _FORMAT_KEY: "numpy.ndarray:array_list_meta",
        "dtype": str(source_array.dtype),
        "shape": list(source_array.shape),
        "data": source_array.tolist(),
    }

    loaded = load_item_recursive(serialized, (), None)

    # Type, shape, dtype and values must all survive the round trip.
    assert isinstance(loaded, np.ndarray)
    expected_shape = tuple(serialized["shape"])
    assert loaded.shape == expected_shape
    expected_dtype = np.dtype(serialized["dtype"])
    assert loaded.dtype == expected_dtype
    assert np.allclose(loaded, source_array)

62 

63 

def test_load_item_recursive_serializable_dataclass():
    """Check that a serialized SerializableDataclass is rebuilt as an instance."""

    @serializable_dataclass
    class TestClass(SerializableDataclass):
        name: str
        value: int
        data: typing.List[int] = serializable_field(default_factory=list)

    # Serialize a freshly built instance, then feed the result to the loader.
    serialized = TestClass("test", 42, [1, 2, 3]).serialize()
    loaded = load_item_recursive(serialized, (), None)

    # The loader must return the dataclass type with every field restored.
    assert isinstance(loaded, TestClass)
    assert loaded.name == "test"
    assert loaded.value == 42
    assert loaded.data == [1, 2, 3]

85 

86 

def test_load_item_recursive_nested_container():
    """Check that arrays nested inside lists and dicts are each loaded."""

    def array_json(*shape):
        # JSON form of a random float64 array of the given shape, tagged with
        # the "numpy.ndarray:array_list_meta" format string.
        return {
            _FORMAT_KEY: "numpy.ndarray:array_list_meta",
            "dtype": "float64",
            "shape": list(shape),
            "data": np.random.rand(*shape).tolist(),
        }

    payload = {
        "name": "test",
        "arrays": [array_json(3, 3), array_json(2, 2)],
        "nested": {"dict_with_array": array_json(4, 4)},
    }

    loaded = load_item_recursive(payload, (), None)

    # The plain string entry is passed through.
    assert loaded["name"] == "test"

    # Both list members become arrays of the declared shapes.
    loaded_arrays = loaded["arrays"]
    assert len(loaded_arrays) == 2
    for arr, expected_shape in zip(loaded_arrays, [(3, 3), (2, 2)]):
        assert isinstance(arr, np.ndarray)
        assert arr.shape == expected_shape

    # The array sitting one dict level deeper is loaded as well.
    deep_array = loaded["nested"]["dict_with_array"]
    assert isinstance(deep_array, np.ndarray)
    assert deep_array.shape == (4, 4)

128 

129 

def test_load_item_recursive_unknown_format():
    """Exercise the loader on an item whose format string has no handler."""
    unregistered_item = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # With allow_not_loading=True the item is expected back unchanged.
    passthrough = load_item_recursive(
        unregistered_item,
        (),
        None,
        allow_not_loading=True,
    )
    assert passthrough == unregistered_item

    # TODO: this doesn't raise any errors
    # Strict variant: allow_not_loading=False together with a ZANJ instance and
    # an error_mode that are both set to EXCEPT. The call is made without any
    # assertion on its outcome.
    strict_zanj = ZANJ(error_mode=ErrorMode.EXCEPT)
    load_item_recursive(
        unregistered_item,
        (),
        strict_zanj,
        error_mode=ErrorMode.EXCEPT,
        allow_not_loading=False,
    )

151 

152 

def test_load_item_recursive_with_external_reference():
    """Save a file holding a 20x20 array, reopen it, and populate externals."""
    zanj_obj = ZANJ(external_array_threshold=10)
    to_save = {"large_array": np.random.rand(20, 20)}
    save_path = TEST_DATA_PATH / "test_load_item_recursive_external.zanj"
    zanj_obj.save(to_save, save_path)

    # Reopen the saved file and resolve its external entries.
    reopened = LoadedZANJ(save_path, zanj_obj)
    reopened.populate_externals()

    # The test expects at least one external entry after population.
    assert len(reopened._externals) > 0

    # Sanity check on the JSON payload.
    # NOTE(review): this tests for the literal string "_REF_KEY", and the
    # second operand accepts any dict -- confirm this is the intended check.
    assert "_REF_KEY" in reopened._json_data or isinstance(reopened._json_data, dict)

174 

175 

def test_load_item_recursive_error_modes():
    """Run the loader under the WARN, IGNORE and EXCEPT error modes."""
    import zanj.loading

    unregistered_item = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # WARN and IGNORE must not raise; the item comes back unchanged.
    for lenient_mode in (ErrorMode.WARN, ErrorMode.IGNORE):
        returned = load_item_recursive(
            unregistered_item,
            (),
            None,
            error_mode=lenient_mode,
            allow_not_loading=True,
        )
        assert returned == unregistered_item

    class CustomHandler:
        """Stand-in handler whose ``load`` always raises ValueError."""

        def check(self, json_item, path=None, z=None):
            unknown_format = "unknown.format.that.definitely.does.not.exist"
            return json_item.get(_FORMAT_KEY) == unknown_format

        def load(self, json_item, path=None, z=None):
            raise ValueError("Forced error for testing purposes")

    # Swap the module-level lookup for one that always returns the failing
    # handler; the original is saved so the finally clause can restore it.
    saved_get_item_loader = zanj.loading.get_item_loader
    zanj.loading.get_item_loader = lambda *args, **kwargs: CustomHandler()
    try:
        # In EXCEPT mode the handler's ValueError must reach the caller.
        with pytest.raises(ValueError):
            load_item_recursive(
                unregistered_item,
                (),
                None,
                error_mode=ErrorMode.EXCEPT,
                allow_not_loading=True,
            )
    finally:
        zanj.loading.get_item_loader = saved_get_item_loader

232 zanj.loading.get_item_loader = original_get_item_loader