Coverage for tests / unit / with_torch / test_torch_edge_cases.py: 95%
108 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-21 22:18 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-21 22:18 -0700
1from __future__ import annotations
3from pathlib import Path
5import pytest
6import torch # type: ignore[import-not-found]
7from muutils.json_serialize import (
8 SerializableDataclass,
9 serializable_dataclass,
10)
12from zanj import ZANJ
13from zanj.torchutil import (
14 ConfiguredModel,
15 assert_model_exact_equality,
16 get_module_device,
17 num_params,
18 set_config_class,
19)
# Relative directory under which the tests in this module write their
# ``.zanj`` files.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"
def test_num_params():
    """Check ``num_params`` on a fully trainable and a partly frozen model.

    The first model is called with default arguments. The second model has
    its batch-norm layer frozen and is counted with ``only_trainable=True``
    and ``only_trainable=False``; both absolute counts and their difference
    are asserted, so an error that shifts both counts equally is caught too.
    """
    # Linear(10, 20): 10*20 weights + 20 biases = 220
    # Linear(20, 1):  20*1 weights + 1 bias     = 21
    linear_param_count = (10 * 20 + 20) + (20 * 1 + 1)  # 241
    # BatchNorm1d(20): weight + bias, 20 entries each = 40
    batchnorm_param_count = 2 * 20

    # Simple model in which every parameter is trainable
    model1 = torch.nn.Sequential(
        torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 1)
    )
    assert num_params(model1) == linear_param_count

    # Same linear layers with a batch-norm layer in between
    model2 = torch.nn.Sequential(
        torch.nn.Linear(10, 20),
        torch.nn.BatchNorm1d(20, track_running_stats=True),
        torch.nn.Linear(20, 1),
    )

    # Freeze the batch norm layer so that only the linear layers stay trainable
    model2[1].requires_grad_(False)

    trainable_params = num_params(model2, only_trainable=True)
    all_params = num_params(model2, only_trainable=False)

    # Pin the absolute counts, not just their difference
    assert trainable_params == linear_param_count
    assert all_params == linear_param_count + batchnorm_param_count
    assert all_params - trainable_params == batchnorm_param_count
def test_get_module_device_empty():
    """Check what ``get_module_device`` returns for a parameter-free module.

    The expected result is the pair ``(False, {})``.
    """

    # Identity module: nothing is registered on it, so it has no parameters.
    class EmptyModule(torch.nn.Module):
        def forward(self, x):
            return x

    is_single, device_dict = get_module_device(EmptyModule())

    # No parameters means no single device and no per-parameter entries
    assert is_single is False
    assert device_dict == {}
def test_load_state_dict_wrapper():
    """Check that reading a saved model goes through ``_load_state_dict_wrapper``.

    A ``ConfiguredModel`` subclass overrides ``_load_state_dict_wrapper`` to
    set a flag on the instance it is called on. The model is saved with ZANJ
    and read back; the test then asserts that the returned model carries the
    flag and that its config and weights match the saved model.
    """

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)

        def forward(self, x):
            return self.linear(x)

        # Override _load_state_dict_wrapper to test custom behavior
        def _load_state_dict_wrapper(self, state_dict, **kwargs):
            # Mark the instance so the test can tell this override was used
            self.custom_wrapper_called = True
            # Still delegate to the parent so the weights are actually loaded
            return super()._load_state_dict_wrapper(state_dict)

    # Create a model and save it
    model = SimpleModel(SimpleConfig(10))
    path = TEST_DATA_PATH / "test_load_state_dict_wrapper.zanj"
    ZANJ().save(model, path)

    # Read through the class: the assertions below only inspect the object
    # returned by ``read``, so no pre-built instance is needed.
    loaded_model = SimpleModel.read(path)

    # The flag can only exist if the override ran on the loaded model
    assert hasattr(loaded_model, "custom_wrapper_called")
    assert loaded_model.custom_wrapper_called is True

    # The override must still have loaded the saved weights
    assert_model_exact_equality(model, loaded_model)
def test_deprecated_load_file():
    """Check that ``load_file`` warns about deprecation and still loads.

    A model is saved with ZANJ and read back through ``load_file`` inside a
    ``pytest.warns(DeprecationWarning)`` context; the result must be exactly
    equal to the saved model.
    """

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)

        def forward(self, x):
            return self.linear(x)

    # Build the reference model and persist it
    saved_model = SimpleModel(SimpleConfig(10))
    save_path = TEST_DATA_PATH / "test_deprecated_load_file.zanj"
    ZANJ().save(saved_model, save_path)

    # The deprecated entry point has to emit a DeprecationWarning
    with pytest.warns(DeprecationWarning):
        restored_model = SimpleModel.load_file(save_path)

    # ... and the loaded model has to match the saved one
    assert_model_exact_equality(saved_model, restored_model)
def test_configmodel_training_records():
    """Check that ``training_records`` survive a ZANJ save/read round trip.

    A dict of metric name to list of floats is attached to the model before
    saving; after reading, the dict as a whole and every individual series
    must equal what was stored.
    """

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)
            self.training_records = None  # Initialize to None

        def forward(self, x):
            return self.linear(x)

    # Expected values, kept separate from the dict handed to the model
    expected_records = {
        "loss": [1.0, 0.9, 0.8, 0.7],
        "accuracy": [0.6, 0.7, 0.8, 0.9],
        "learning_rate": [0.01, 0.005, 0.001, 0.0005],
    }

    # Create a model and attach a copy of the records
    model = SimpleModel(SimpleConfig(10))
    model.training_records = {
        metric: list(series) for metric, series in expected_records.items()
    }

    # Save, then load through the class
    save_path = TEST_DATA_PATH / "test_training_records.zanj"
    ZANJ().save(model, save_path)
    loaded_model = SimpleModel.read(save_path)

    # Whole-dict comparison first, then each series individually
    assert loaded_model.training_records == model.training_records
    for metric, series in expected_records.items():
        assert loaded_model.training_records[metric] == series
def test_configmodel_with_custom_settings():
    """Check that ZANJ ``custom_settings`` reach ``_load_state_dict_wrapper``.

    The model's override records the ``custom_param`` keyword it receives.
    The model is saved with a default ZANJ and read back with a ZANJ whose
    ``custom_settings`` carry ``custom_param`` under the
    ``"_load_state_dict_wrapper"`` key; the loaded model must have recorded
    that value.
    """

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class CustomStateModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)
            self.custom_param_received = None

        def forward(self, x):
            return self.linear(x)

        def _load_state_dict_wrapper(self, state_dict, **kwargs):
            # Remember the custom param (None when it was not supplied)
            self.custom_param_received = kwargs.get("custom_param")
            # NOTE(review): kwargs are not forwarded to the parent here;
            # presumably the parent does not accept extra keywords - confirm.
            return super(CustomStateModel, self)._load_state_dict_wrapper(state_dict)

    # Save a freshly built model with default settings
    save_path = TEST_DATA_PATH / "test_custom_settings.zanj"
    ZANJ().save(CustomStateModel(SimpleConfig(10)), save_path)

    # Reader configured with keyword arguments for the wrapper
    wrapper_kwargs = {"custom_param": "test_value"}
    reader = ZANJ(custom_settings={"_load_state_dict_wrapper": wrapper_kwargs})

    # Load through the configured reader
    loaded_model = reader.read(save_path)

    # The override must have seen the configured value
    assert loaded_model.custom_param_received == "test_value"