Coverage for tests / unit / with_torch / test_torchutil_edge_cases.py: 97%
115 statements
coverage.py v7.13.4, created at 2026-02-22 02:02 -0700
1"""Edge case tests for zanj/torchutil.py to improve coverage."""
3from __future__ import annotations
5from pathlib import Path
7import pytest
8import torch
10from muutils.json_serialize import SerializableDataclass, serializable_dataclass
12from zanj import ZANJ
13from zanj.torchutil import (
14 ConfiguredModel,
15 assert_model_exact_equality,
16 num_params,
17 set_config_class,
18)
# Scratch directory for test artifacts, relative to the repository root.
# NOTE(review): the tests in this module write into pytest's ``tmp_path``
# instead of this directory — confirm whether it is still referenced.
TEST_DATA_PATH: Path = Path("tests", "junk_data")
@serializable_dataclass
class EdgeCaseTestConfig(SerializableDataclass):
    """Minimal two-field configuration shared by the edge-case tests.

    Fields are positional in declaration order, so ``EdgeCaseTestConfig(16, 1)``
    means ``hidden_size=16, num_layers=1``.
    """

    # width of the single square Linear layer built by EdgeCaseTestModel
    hidden_size: int
    # not read by EdgeCaseTestModel; only checked after a save/read round trip
    num_layers: int
@set_config_class(EdgeCaseTestConfig)
class EdgeCaseTestModel(ConfiguredModel[EdgeCaseTestConfig]):
    """Tiny ConfiguredModel wrapping a single square ``torch.nn.Linear``."""

    def __init__(self, cfg: EdgeCaseTestConfig):
        super().__init__(cfg)
        width = cfg.hidden_size
        # input and output widths are both ``hidden_size``
        self.linear = torch.nn.Linear(width, width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` (last dim presumably ``hidden_size``)."""
        return self.linear(x)
class TestConfiguredModelValidation:
    """Constructor-time validation performed by ``ConfiguredModel``."""

    def test_missing_config_class_decorator(self):
        """A subclass never passed through ``@set_config_class`` raises
        NotImplementedError on construction (targeted torchutil.py line 102
        when written)."""

        # No ``__init__`` override: the inherited constructor is what raises,
        # and an override that only delegated to ``super()`` added nothing.
        class UnconfiguredModel(ConfiguredModel):
            pass

        with pytest.raises(NotImplementedError, match="need to set"):
            UnconfiguredModel(EdgeCaseTestConfig(32, 2))

    def test_wrong_config_type(self):
        """A SerializableDataclass of the wrong class is rejected with
        TypeError (targeted torchutil.py line 104 when written)."""

        class WrongConfig(SerializableDataclass):
            other_field: str = "test"

        with pytest.raises(TypeError, match="must be an instance of"):
            EdgeCaseTestModel(WrongConfig())

    def test_config_not_serializable_dataclass(self):
        """A plain dict, which is not a SerializableDataclass instance, is
        rejected with TypeError."""
        with pytest.raises(TypeError):
            EdgeCaseTestModel({"hidden_size": 32})  # type: ignore
class TestSetConfigClass:
    """Argument validation inside the ``set_config_class`` decorator."""

    def test_invalid_config_class_type(self):
        """Decorating with a class that does not derive from
        SerializableDataclass raises TypeError (targeted torchutil.py line 227
        when written)."""
        expected_message = "must be a subclass of SerializableDataclass"

        class NotSerializable:
            pass

        with pytest.raises(TypeError, match=expected_message):
            # the decorator runs at class-definition time, so the class
            # statement itself must sit inside the ``raises`` context
            @set_config_class(NotSerializable)  # type: ignore
            class BadModel(ConfiguredModel):
                pass
class TestConfiguredModelSaveLoad:
    """``save`` / ``read`` / ``serialize`` called without an explicit ``ZANJ``."""

    @staticmethod
    def _small_model() -> EdgeCaseTestModel:
        """Build the 16-wide, 1-layer model shared by these tests."""
        return EdgeCaseTestModel(EdgeCaseTestConfig(16, 1))

    def test_save_with_default_zanj(self, tmp_path):
        """``save`` with the ``zanj`` argument omitted still writes a file that
        reads back (targeted torchutil.py lines 134-136 when written)."""
        target = tmp_path / "model_default_zanj.zanj"
        self._small_model().save(str(target))  # zanj deliberately omitted

        assert target.exists()

        restored = EdgeCaseTestModel.read(str(target))
        assert restored.zanj_model_config.hidden_size == 16

    def test_load_with_default_zanj(self, tmp_path):
        """``read`` with the ``zanj`` argument omitted loads a file written by
        an explicit ``ZANJ`` (targeted torchutil.py line 154 when written)."""
        target = str(tmp_path / "model_for_load_test.zanj")
        # write with an explicit ZANJ so only the read side uses the default
        writer = ZANJ()
        writer.save(self._small_model().serialize(), target)

        restored = EdgeCaseTestModel.read(target)  # zanj deliberately omitted
        assert restored.zanj_model_config.hidden_size == 16

    def test_serialize_with_default_zanj(self):
        """``serialize`` with the ``zanj`` argument omitted produces the
        expected top-level keys (targeted torchutil.py lines 114-115 when
        written)."""
        payload = self._small_model().serialize()  # zanj deliberately omitted

        for required_key in ("zanj_model_config", "state_dict", "__muutils_format__"):
            assert required_key in payload
class TestNumParamsWrapper:
    """``ConfiguredModel.num_params`` must agree with the module-level helper."""

    def test_num_params_instance_method(self):
        """The bound method (targeted torchutil.py line 220 when written)
        returns the same positive count as ``num_params(model)``."""
        model = EdgeCaseTestModel(EdgeCaseTestConfig(16, 1))

        method_count, function_count = model.num_params(), num_params(model)

        assert method_count == function_count
        assert method_count > 0
class TestAssertModelExactEquality:
    """Behaviour of ``assert_model_exact_equality`` on equal and unequal models."""

    @staticmethod
    def _two_models():
        """Build two independently initialised models sharing one config."""
        shared_cfg = EdgeCaseTestConfig(16, 1)
        return EdgeCaseTestModel(shared_cfg), EdgeCaseTestModel(shared_cfg)

    def test_state_dict_mismatch(self):
        """Models whose tensor values differ fail with "state dict elements
        don't match" (targeted torchutil.py lines 289-290 when written)."""
        reference, perturbed = self._two_models()

        # shift every parameter of the second model by +1 so its values
        # no longer match the reference model
        with torch.no_grad():
            for tensor in perturbed.parameters():
                tensor.add_(1.0)

        with pytest.raises(AssertionError, match="state dict elements don't match"):
            assert_model_exact_equality(reference, perturbed)

    def test_equal_models_pass(self):
        """Models with identical state dicts pass without raising."""
        source, clone = self._two_models()

        # the two models start with independent weights; copy them across
        clone.load_state_dict(source.state_dict())

        assert_model_exact_equality(source, clone)  # must not raise

    def test_state_dict_keys_mismatch(self):
        """Models whose parameter names differ fail with "state dict keys
        don't match"."""

        class DifferentModel(ConfiguredModel[EdgeCaseTestConfig]):
            # NOTE(review): sets ``_config_class`` by hand instead of using
            # ``@set_config_class``; presumably equivalent here — confirm.
            _config_class = EdgeCaseTestConfig

            def __init__(self, cfg: EdgeCaseTestConfig):
                super().__init__(cfg)
                size = cfg.hidden_size
                # attribute name differs from EdgeCaseTestModel's ``linear``,
                # which is what makes the state dict keys diverge
                self.different_linear = torch.nn.Linear(size, size)

            def forward(self, x):
                return self.different_linear(x)

        shared_cfg = EdgeCaseTestConfig(16, 1)
        baseline = EdgeCaseTestModel(shared_cfg)
        renamed = DifferentModel(shared_cfg)

        with pytest.raises(AssertionError, match="state dict keys don't match"):
            assert_model_exact_equality(baseline, renamed)
class TestConfiguredModelReadRoundTrip:
    """End-to-end ``save`` then ``read`` round trip for ``ConfiguredModel``."""

    def test_full_round_trip(self, tmp_path):
        """Config fields, training records and weights all survive save + read."""
        original = EdgeCaseTestModel(EdgeCaseTestConfig(32, 2))
        original.training_records = {"epochs": 10, "loss": 0.5}

        target = str(tmp_path / "round_trip_test.zanj")
        original.save(target)
        restored = EdgeCaseTestModel.read(target)

        restored_cfg = restored.zanj_model_config
        assert (restored_cfg.hidden_size, restored_cfg.num_layers) == (32, 2)
        assert restored.training_records == {"epochs": 10, "loss": 0.5}

        # parameters must come back exactly equal to the saved ones
        assert_model_exact_equality(original, restored)