Coverage for tests/unit/no_torch/test_load_item_recursive.py: 99%
86 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-19 14:57 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-19 14:57 -0600
1from __future__ import annotations
3import typing
4from pathlib import Path
6import numpy as np
7import pytest
8from muutils.errormode import ErrorMode
9from muutils.json_serialize import (
10 SerializableDataclass,
11 serializable_dataclass,
12 serializable_field,
13)
14from muutils.json_serialize.util import _FORMAT_KEY
16from zanj import ZANJ
17from zanj.loading import LoadedZANJ, load_item_recursive
# Scratch directory for files written by these tests. This is a relative path,
# so it is resolved against the current working directory when pytest runs.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"
def test_load_item_recursive_basic():
    """Plain JSON with no format keys should be returned by load_item_recursive as-is."""
    plain = {
        "name": "test",
        "value": 42,
        "list": [1, 2, 3],
        "nested": {"a": 1, "b": 2},
    }

    # empty object path, no ZANJ instance
    loaded = load_item_recursive(plain, (), None)

    # whole-structure equality first, then per-entry checks against literals
    assert loaded == plain
    assert (loaded["name"], loaded["value"]) == ("test", 42)
    assert loaded["list"] == [1, 2, 3]
    assert loaded["nested"] == {"a": 1, "b": 2}
def test_load_item_recursive_numpy_array():
    """A dict tagged ``numpy.ndarray:array_list_meta`` should load back as an ndarray."""
    expected = np.random.rand(5, 5)
    dtype_name = str(expected.dtype)
    shape_list = list(expected.shape)

    # serialized form: format tag plus dtype/shape metadata and nested-list data
    payload = {
        _FORMAT_KEY: "numpy.ndarray:array_list_meta",
        "dtype": dtype_name,
        "shape": shape_list,
        "data": expected.tolist(),
    }

    loaded = load_item_recursive(payload, (), None)

    assert isinstance(loaded, np.ndarray)
    assert loaded.shape == tuple(shape_list)
    assert loaded.dtype == np.dtype(dtype_name)
    assert np.allclose(loaded, expected)
def test_load_item_recursive_serializable_dataclass():
    """A serialized SerializableDataclass should be rebuilt as an instance of its class."""

    # NOTE: class and field names are part of the serialized form; keep them stable
    @serializable_dataclass
    class TestClass(SerializableDataclass):
        name: str
        value: int
        data: typing.List[int] = serializable_field(default_factory=list)

    # round trip: instance -> serialize() -> load_item_recursive
    original = TestClass("test", 42, [1, 2, 3])
    loaded = load_item_recursive(original.serialize(), (), None)

    assert isinstance(loaded, TestClass)
    assert (loaded.name, loaded.value) == ("test", 42)
    assert loaded.data == [1, 2, 3]
def test_load_item_recursive_nested_container():
    """Arrays nested inside lists and dicts should each be loaded as ndarrays.

    Besides the container structure and array shapes, this also checks each
    loaded array's dtype and values. A loader that produced correctly shaped
    but wrong data would otherwise go unnoticed.
    """

    def _encode(arr: np.ndarray) -> dict:
        """Encode ``arr`` in the ``numpy.ndarray:array_list_meta`` JSON form."""
        return {
            _FORMAT_KEY: "numpy.ndarray:array_list_meta",
            "dtype": str(arr.dtype),
            "shape": list(arr.shape),
            "data": arr.tolist(),
        }

    # fixed seed so any failure is reproducible
    rng = np.random.default_rng(0)
    first = rng.random((3, 3))
    second = rng.random((2, 2))
    deep = rng.random((4, 4))

    json_data = {
        "name": "test",
        "arrays": [_encode(first), _encode(second)],
        "nested": {"dict_with_array": _encode(deep)},
    }

    result = load_item_recursive(json_data, tuple(), None)

    assert result["name"] == "test"
    assert len(result["arrays"]) == 2

    loaded_deep = result["nested"]["dict_with_array"]
    pairs = list(zip(result["arrays"], (first, second))) + [(loaded_deep, deep)]
    for loaded, expected in pairs:
        assert isinstance(loaded, np.ndarray)
        assert loaded.shape == expected.shape
        assert loaded.dtype == expected.dtype
        assert np.allclose(loaded, expected)
def test_load_item_recursive_unknown_format():
    """An item whose ``_FORMAT_KEY`` matches no loader should be left unloaded.

    Both the lenient call (``allow_not_loading=True``) and the strict call
    (``allow_not_loading=False`` with ``ErrorMode.EXCEPT``) are checked. The
    strict call currently does not raise (see TODO), so its return value is
    pinned instead of being discarded.
    """
    # format string deliberately not registered with any loader
    json_data = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # lenient: the JSON should come back unchanged
    result = load_item_recursive(json_data, tuple(), None, allow_not_loading=True)
    assert result == json_data

    # TODO: with allow_not_loading=False and ErrorMode.EXCEPT this was expected
    # to raise, but it does not. NOTE(review): presumably because no loader
    # matches, the item is walked as a plain dict whose values are all JSON
    # primitives -- confirm against zanj.loading and decide the intended behavior.
    z = ZANJ(error_mode=ErrorMode.EXCEPT)
    strict_result = load_item_recursive(
        json_data, tuple(), z, error_mode=ErrorMode.EXCEPT, allow_not_loading=False
    )
    assert strict_result == json_data
def test_load_item_recursive_with_external_reference():
    """Saving a large array should produce externals that LoadedZANJ can populate."""
    # low external_array_threshold so the 20x20 array is written as an external
    z = ZANJ(external_array_threshold=10)
    data = {"large_array": np.random.rand(20, 20)}
    path = TEST_DATA_PATH / "test_load_item_recursive_external.zanj"
    z.save(data, path)

    # open the saved file and pull the externals into the JSON data
    loaded_zanj = LoadedZANJ(path, z)
    loaded_zanj.populate_externals()

    assert len(loaded_zanj._externals) > 0

    # The saved top-level object was a dict, so the loaded JSON should be a dict
    # that still carries the original key. Checking the key itself makes this
    # assertion able to fail.
    assert isinstance(loaded_zanj._json_data, dict)
    assert "large_array" in loaded_zanj._json_data
def test_load_item_recursive_error_modes():
    """Exercise load_item_recursive under different error modes.

    - WARN and IGNORE with ``allow_not_loading=True`` return an unknown-format
      item unchanged.
    - With EXCEPT, a ValueError raised by the resolved loader's ``load``
      propagates to the caller.
    """
    from unittest import mock

    import zanj.loading

    # format string deliberately not registered with any loader
    json_data = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # lenient modes: should not raise, just return the data
    for mode in (ErrorMode.WARN, ErrorMode.IGNORE):
        result = load_item_recursive(
            json_data, tuple(), None, error_mode=mode, allow_not_loading=True
        )
        assert result == json_data

    class FailingHandler:
        """Stand-in loader handler whose ``load`` always raises ValueError."""

        def check(self, json_item, path=None, z=None):
            # mirrors the handler interface; the patched lookup below
            # returns this handler without consulting check()
            return (
                json_item.get(_FORMAT_KEY)
                == "unknown.format.that.definitely.does.not.exist"
            )

        def load(self, json_item, path=None, z=None):
            raise ValueError("Forced error for testing purposes")

    # Replace the loader lookup so every item resolves to FailingHandler.
    # patch.object restores zanj.loading.get_item_loader on exit, even if the
    # body raises. load_item_recursive looks the name up in zanj.loading's
    # globals at call time, so the patch takes effect.
    with mock.patch.object(
        zanj.loading, "get_item_loader", lambda *args, **kwargs: FailingHandler()
    ):
        with pytest.raises(ValueError):
            load_item_recursive(
                json_data,
                tuple(),
                None,
                error_mode=ErrorMode.EXCEPT,
                allow_not_loading=True,
            )