Coverage for tests / unit / with_torch / test_torchutil_edge_cases.py: 97%

115 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 02:02 -0700

1"""Edge case tests for zanj/torchutil.py to improve coverage.""" 

2 

3from __future__ import annotations 

4 

5from pathlib import Path 

6 

7import pytest 

8import torch 

9 

10from muutils.json_serialize import SerializableDataclass, serializable_dataclass 

11 

12from zanj import ZANJ 

13from zanj.torchutil import ( 

14 ConfiguredModel, 

15 assert_model_exact_equality, 

16 num_params, 

17 set_config_class, 

18) 

19 

20 

21TEST_DATA_PATH: Path = Path("tests/junk_data") 

22 

23 

@serializable_dataclass
class EdgeCaseTestConfig(SerializableDataclass):
    """Minimal serializable config shared by the tests in this module.

    Both fields are integers without defaults, so they must be supplied on
    construction (the tests in this file pass them positionally).
    """

    # width used for both sides of the test model's linear layer
    hidden_size: int
    # stored on the config; not read by the test model in this file
    num_layers: int

30 

31 

@set_config_class(EdgeCaseTestConfig)
class EdgeCaseTestModel(ConfiguredModel[EdgeCaseTestConfig]):
    """Single-linear-layer model used as the fixture model for these tests."""

    def __init__(self, cfg: EdgeCaseTestConfig):
        super().__init__(cfg)
        width = cfg.hidden_size
        # square layer: input and output widths are both `hidden_size`
        self.linear = torch.nn.Linear(in_features=width, out_features=width)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        projected = self.linear(x)
        return projected

42 

43 

class TestConfiguredModelValidation:
    """Tests for the config checks made when a ConfiguredModel is constructed."""

    def test_missing_config_class_decorator(self):
        """Line 102: Model without @set_config_class decorator should raise NotImplementedError."""

        # deliberately NOT decorated with @set_config_class
        class UnconfiguredModel(ConfiguredModel):
            def __init__(self, cfg):
                super().__init__(cfg)

        with pytest.raises(NotImplementedError, match="need to set"):
            UnconfiguredModel(EdgeCaseTestConfig(32, 2))

    def test_wrong_config_type(self):
        """Line 104: Passing wrong config type should raise TypeError."""

        # a SerializableDataclass subclass, but not the class the model was
        # configured with
        class WrongConfig(SerializableDataclass):
            other_field: str = "test"

        with pytest.raises(TypeError, match="must be an instance of"):
            EdgeCaseTestModel(WrongConfig())

    def test_config_not_serializable_dataclass(self):
        """Passing a plain dict instead of a config instance should raise TypeError."""
        # the dict is not an EdgeCaseTestConfig, hence the type-checker silence
        with pytest.raises(TypeError):
            EdgeCaseTestModel({"hidden_size": 32})  # type: ignore

70 

71 

class TestSetConfigClass:
    """Checks on the argument validation done by `set_config_class`."""

    def test_invalid_config_class_type(self):
        """Line 227: a config class that is not a SerializableDataclass is rejected with TypeError."""

        class NotSerializable:
            pass

        expected_message = "must be a subclass of SerializableDataclass"
        with pytest.raises(TypeError, match=expected_message):
            # Same steps as the `@set_config_class(...)` decorator syntax:
            # build the decorator, create the class, apply the decorator.
            # All of it stays inside the `raises` block.
            decorate = set_config_class(NotSerializable)  # type: ignore

            class BadModel(ConfiguredModel):
                pass

            BadModel = decorate(BadModel)

88 

89 

class TestConfiguredModelSaveLoad:
    """Save / read / serialize calls made without passing a ZANJ instance."""

    def test_save_with_default_zanj(self, tmp_path):
        """Lines 134-136: `save()` given only a path writes a file that `read()` can load."""
        file_path = str(tmp_path / "model_default_zanj.zanj")
        model = EdgeCaseTestModel(EdgeCaseTestConfig(16, 1))

        # no zanj argument is passed here
        model.save(file_path)
        assert Path(file_path).exists()

        # the written file must be readable again
        restored = EdgeCaseTestModel.read(file_path)
        assert restored.zanj_model_config.hidden_size == 16

    def test_load_with_default_zanj(self, tmp_path):
        """Line 154: `read()` given only a path loads a file written by an explicit ZANJ."""
        model = EdgeCaseTestModel(EdgeCaseTestConfig(16, 1))
        file_path = str(tmp_path / "model_for_load_test.zanj")

        # write through an explicitly created ZANJ so that only the read
        # side goes without a zanj argument
        writer = ZANJ()
        writer.save(model.serialize(), file_path)

        restored = EdgeCaseTestModel.read(file_path)
        assert restored.zanj_model_config.hidden_size == 16

    def test_serialize_with_default_zanj(self):
        """Line 114-115: `serialize()` with no arguments yields the expected top-level keys."""
        model = EdgeCaseTestModel(EdgeCaseTestConfig(16, 1))

        payload = model.serialize()

        assert "zanj_model_config" in payload
        assert "state_dict" in payload
        assert "__muutils_format__" in payload

132 

133 

class TestNumParamsWrapper:
    """Checks that the `num_params` method agrees with the module-level function."""

    def test_num_params_instance_method(self):
        """Line 220: `model.num_params()` equals `num_params(model)` and is positive."""
        model = EdgeCaseTestModel(EdgeCaseTestConfig(16, 1))

        via_method = model.num_params()
        via_function = num_params(model)

        assert via_method == via_function
        assert via_method > 0

150 

151 

class TestAssertModelExactEquality:
    """Pass and fail cases for `assert_model_exact_equality`."""

    def test_state_dict_mismatch(self):
        """Lines 289-290: same keys but different values must fail the assertion."""
        shared_cfg = EdgeCaseTestConfig(16, 1)
        reference = EdgeCaseTestModel(shared_cfg)
        tampered = EdgeCaseTestModel(shared_cfg)

        # shift every parameter of the second model in place by 1.0 so that
        # its values differ from the first model's
        with torch.no_grad():
            for weight in tampered.parameters():
                weight.add_(1.0)

        with pytest.raises(AssertionError, match="state dict elements don't match"):
            assert_model_exact_equality(reference, tampered)

    def test_equal_models_pass(self):
        """Two models holding the same state dict must compare equal."""
        shared_cfg = EdgeCaseTestConfig(16, 1)
        reference = EdgeCaseTestModel(shared_cfg)
        clone = EdgeCaseTestModel(shared_cfg)

        # overwrite the second model's weights with the first model's
        clone.load_state_dict(reference.state_dict())

        # passes by not raising
        assert_model_exact_equality(reference, clone)

    def test_state_dict_keys_mismatch(self):
        """Models whose state dicts use different keys must fail the assertion."""

        class DifferentModel(ConfiguredModel[EdgeCaseTestConfig]):
            # config class is set directly on the class body here
            _config_class = EdgeCaseTestConfig

            def __init__(self, cfg: EdgeCaseTestConfig):
                super().__init__(cfg)
                width = cfg.hidden_size
                # attribute name differs from EdgeCaseTestModel's `linear`
                self.different_linear = torch.nn.Linear(width, width)

            def forward(self, x):
                return self.different_linear(x)

        shared_cfg = EdgeCaseTestConfig(16, 1)
        reference = EdgeCaseTestModel(shared_cfg)
        other = DifferentModel(shared_cfg)

        with pytest.raises(AssertionError, match="state dict keys don't match"):
            assert_model_exact_equality(reference, other)

202 

203 

class TestConfiguredModelReadRoundTrip:
    """End-to-end save-then-read check for a ConfiguredModel."""

    def test_full_round_trip(self, tmp_path):
        """Config, training records and weights all survive `save()` followed by `read()`."""
        expected_records = {"epochs": 10, "loss": 0.5}

        model = EdgeCaseTestModel(EdgeCaseTestConfig(32, 2))
        # give the model its own copy so the expected value stays independent
        model.training_records = dict(expected_records)

        file_path = str(tmp_path / "round_trip_test.zanj")
        model.save(file_path)
        restored = EdgeCaseTestModel.read(file_path)

        restored_cfg = restored.zanj_model_config
        assert restored_cfg.hidden_size == 32
        assert restored_cfg.num_layers == 2
        assert restored.training_records == expected_records

        # weights must match exactly, not just the metadata
        assert_model_exact_equality(model, restored)