Coverage for tests/unit/no_torch/test_load_item_recursive.py: 99%
86 statements
coverage.py v7.13.4, created at 2026-02-22 19:31 -0700
1from __future__ import annotations
3import typing
4from pathlib import Path
6import numpy as np
7import pytest
8from muutils.errormode import ErrorMode
9from muutils.json_serialize import (
10 SerializableDataclass,
11 serializable_dataclass,
12 serializable_field,
13)
14from zanj import ZANJ
15from zanj.consts import _FORMAT_KEY
16from zanj.loading import LoadedZANJ, load_item_recursive
# Scratch directory for files written by these tests; a relative path, so it is
# resolved against the current working directory when used.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"
def test_load_item_recursive_basic():
    """Plain JSON with no format keys comes back from load_item_recursive unchanged."""
    plain = dict(
        name="test",
        value=42,
        list=[1, 2, 3],
        nested={"a": 1, "b": 2},
    )

    loaded = load_item_recursive(plain, (), None)

    # whole-structure equality first, then every field on its own
    assert loaded == plain
    expected_fields = {
        "name": "test",
        "value": 42,
        "list": [1, 2, 3],
        "nested": {"a": 1, "b": 2},
    }
    for key, expected in expected_fields.items():
        assert loaded[key] == expected
def test_load_item_recursive_numpy_array():
    """An ndarray encoded as ``numpy.ndarray:array_list_meta`` is decoded back to an ndarray."""
    original = np.random.rand(5, 5)
    dtype_name = str(original.dtype)
    shape_list = list(original.shape)
    encoded = {
        # the format string names the numpy encoding used for the payload below
        _FORMAT_KEY: "numpy.ndarray:array_list_meta",
        "dtype": dtype_name,
        "shape": shape_list,
        "data": original.tolist(),
    }

    decoded = load_item_recursive(encoded, (), None)

    assert isinstance(decoded, np.ndarray)
    assert decoded.shape == tuple(shape_list)
    assert decoded.dtype == np.dtype(dtype_name)
    assert np.allclose(decoded, original)
def test_load_item_recursive_serializable_dataclass():
    """A serialized SerializableDataclass is rebuilt as an instance of the same class."""

    @serializable_dataclass
    class TestClass(SerializableDataclass):
        name: str
        value: int
        data: typing.List[int] = serializable_field(default_factory=list)

    payload = TestClass("test", 42, [1, 2, 3]).serialize()

    restored = load_item_recursive(payload, (), None)

    assert isinstance(restored, TestClass)
    assert (restored.name, restored.value, restored.data) == ("test", 42, [1, 2, 3])
def test_load_item_recursive_nested_container():
    """Format-tagged arrays nested inside lists and dicts are each decoded to ndarrays."""

    def encoded_random_array(*shape):
        # one float64 array of random values in the ``array_list_meta`` JSON encoding
        return {
            _FORMAT_KEY: "numpy.ndarray:array_list_meta",
            "dtype": "float64",
            "shape": list(shape),
            "data": np.random.rand(*shape).tolist(),
        }

    structure = {
        "name": "test",
        "arrays": [encoded_random_array(3, 3), encoded_random_array(2, 2)],
        "nested": {"dict_with_array": encoded_random_array(4, 4)},
    }

    loaded = load_item_recursive(structure, (), None)

    assert loaded["name"] == "test"

    arrays = loaded["arrays"]
    assert len(arrays) == 2
    for arr, expected_shape in zip(arrays, [(3, 3), (2, 2)]):
        assert isinstance(arr, np.ndarray)
        assert arr.shape == expected_shape

    inner = loaded["nested"]["dict_with_array"]
    assert isinstance(inner, np.ndarray)
    assert inner.shape == (4, 4)
def test_load_item_recursive_unknown_format():
    """An item whose format key matches no registered loader is passed through
    unchanged when ``allow_not_loading=True``."""
    bogus_format = "unknown.format.that.definitely.does.not.exist"
    raw = {_FORMAT_KEY: bogus_format, "data": [1, 2, 3]}

    # allow_not_loading=True: nothing can load this, so the raw JSON is returned as-is
    assert load_item_recursive(raw, (), None, allow_not_loading=True) == raw

    # TODO: with allow_not_loading=False and ErrorMode.EXCEPT this was expected to
    # raise, but it currently does not. The call only exercises the code path; nothing
    # is asserted about its result, and it will surface as a test error if it starts
    # raising.
    strict_zanj = ZANJ(error_mode=ErrorMode.EXCEPT)
    load_item_recursive(
        raw,
        (),
        strict_zanj,
        error_mode=ErrorMode.EXCEPT,
        allow_not_loading=False,
    )
def test_load_item_recursive_with_external_reference():
    """Arrays stored as externals in a .zanj file are reloaded by load_item_recursive.

    Saves a dict holding an array large enough to be written as an external,
    reopens the file with ``LoadedZANJ``, populates the externals, then runs
    ``load_item_recursive`` over the resulting JSON and checks the array round-trips.
    """
    z = ZANJ(external_array_threshold=10)
    # 20x20 = 400 elements; intended to exceed the threshold so the array becomes an
    # external (confirmed by the ``_externals`` assertion below)
    original = np.random.rand(20, 20)
    path = TEST_DATA_PATH / "test_load_item_recursive_external.zanj"
    # ensure the scratch directory exists before writing into it
    path.parent.mkdir(parents=True, exist_ok=True)
    z.save({"large_array": original}, path)

    loaded_zanj = LoadedZANJ(path, z)
    loaded_zanj.populate_externals()

    # at least one external must have been written and read back
    assert len(loaded_zanj._externals) > 0

    # run the recursive loader over the populated JSON and check the array survives
    result = load_item_recursive(loaded_zanj._json_data, tuple(), z)
    loaded_array = result["large_array"]
    assert isinstance(loaded_array, np.ndarray)
    assert loaded_array.shape == original.shape
    assert np.allclose(loaded_array, original)
def test_load_item_recursive_error_modes():
    """Check load_item_recursive under the WARN, IGNORE and EXCEPT error modes.

    WARN and IGNORE, with ``allow_not_loading=True``, must hand an unrecognised item
    back unchanged. For EXCEPT, ``zanj.loading.get_item_loader`` is patched to return
    a handler whose ``load`` raises, and that ValueError must reach the caller.
    """
    import zanj.loading

    bogus_format = "unknown.format.that.definitely.does.not.exist"
    json_data = {
        _FORMAT_KEY: bogus_format,
        "data": [1, 2, 3],
    }

    # WARN: must not raise; the raw JSON is returned
    result = load_item_recursive(
        json_data, tuple(), None, error_mode=ErrorMode.WARN, allow_not_loading=True
    )
    assert result == json_data

    # IGNORE: must not raise; the raw JSON is returned
    result = load_item_recursive(
        json_data, tuple(), None, error_mode=ErrorMode.IGNORE, allow_not_loading=True
    )
    assert result == json_data

    class CustomHandler:
        """Stand-in loader handler whose ``load`` always fails."""

        def check(self, json_item, path=None, z=None):
            # the patched lookup below never calls this; kept so the object still
            # offers the same check/load pair as a loader handler
            return json_item.get(_FORMAT_KEY) == bogus_format

        def load(self, json_item, path=None, z=None):
            raise ValueError("Forced error for testing purposes")

    def mock_get_item_loader(*args, **kwargs):
        # ignore the item entirely and always hand back the failing handler
        return CustomHandler()

    # MonkeyPatch.context() restores zanj.loading.get_item_loader on exit, even if
    # the assertion inside fails
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(zanj.loading, "get_item_loader", mock_get_item_loader)

        # EXCEPT: the handler's ValueError must propagate; ``match`` ensures it is
        # the forced error and not some unrelated ValueError
        with pytest.raises(ValueError, match="Forced error for testing purposes"):
            load_item_recursive(
                json_data,
                tuple(),
                None,
                error_mode=ErrorMode.EXCEPT,
                allow_not_loading=True,
            )