Coverage for tests / unit / with_torch / test_torch_edge_cases.py: 95%

108 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-21 22:18 -0700

1from __future__ import annotations 

2 

3from pathlib import Path 

4 

5import pytest 

6import torch # type: ignore[import-not-found] 

7from muutils.json_serialize import ( 

8 SerializableDataclass, 

9 serializable_dataclass, 

10) 

11 

12from zanj import ZANJ 

13from zanj.torchutil import ( 

14 ConfiguredModel, 

15 assert_model_exact_equality, 

16 get_module_device, 

17 num_params, 

18 set_config_class, 

19) 

20 

# Scratch directory that the save/load round-trip tests write .zanj files into.
# It is a relative path, so it resolves against the current working directory.
TEST_DATA_PATH: Path = Path("tests", "junk_data")

22 

23 

def test_num_params():
    """Test `num_params` on a fully trainable and a partially frozen model.

    Absolute counts are asserted (not only their difference), so a regression
    that shifts the trainable and total counts by the same amount is caught.
    """
    # Simple model where every parameter is trainable
    model1 = torch.nn.Sequential(
        torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 1)
    )

    # Expected number of parameters:
    # Linear1: 10*20 + 20 = 220
    # Linear2: 20*1 + 1 = 21
    # Total: 241
    assert num_params(model1) == 241
    # Nothing is frozen, so both counting modes must agree
    assert num_params(model1, only_trainable=True) == 241
    assert num_params(model1, only_trainable=False) == 241

    # Model with some non-trainable parameters
    model2 = torch.nn.Sequential(
        torch.nn.Linear(10, 20),
        torch.nn.BatchNorm1d(20, track_running_stats=True),
        torch.nn.Linear(20, 1),
    )

    # Freeze the batch norm layer
    for param in model2[1].parameters():
        param.requires_grad = False

    trainable_params = num_params(model2, only_trainable=True)
    all_params = num_params(model2, only_trainable=False)

    # BatchNorm1d(20) contributes 2*20 = 40 parameters (weight and bias).
    # Its running statistics are buffers, not parameters, so they are not
    # counted.
    # Trainable: only the two Linear layers -> 220 + 21 = 241
    # All:       241 + 40 = 281
    assert trainable_params == 241
    assert all_params == 281
    assert all_params - trainable_params == 40

57 

58 

def test_get_module_device_empty():
    """`get_module_device` on a parameter-free module reports (False, {})."""

    # A module with no parameters at all. The inherited __init__ is used as-is,
    # since an override that only calls super() adds nothing.
    class EmptyModule(torch.nn.Module):
        def forward(self, x):
            return x

    outcome = get_module_device(EmptyModule())
    is_single, device_dict = outcome

    # no parameters -> not on a single device, and no device mapping
    assert is_single is False
    assert device_dict == {}

78 

79 

def test_load_state_dict_wrapper():
    """Test that `read` goes through an overridden `_load_state_dict_wrapper`
    on ConfiguredModel, and that the override still loads the weights."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)

        def forward(self, x):
            return self.linear(x)

        # Override _load_state_dict_wrapper to record that it was used
        def _load_state_dict_wrapper(self, state_dict, **kwargs):
            self.custom_wrapper_called = True
            # Still need to actually load the state dict
            return super()._load_state_dict_wrapper(state_dict)

    model = SimpleModel(SimpleConfig(10))
    # Baseline: plain construction never sets the flag; only the
    # overridden wrapper does.
    assert not hasattr(model, "custom_wrapper_called")

    path = TEST_DATA_PATH / "test_load_state_dict_wrapper.zanj"
    ZANJ().save(model, path)

    # `read` returns the loaded model, so call it on the class directly
    # instead of building a throwaway instance first.
    loaded_model = SimpleModel.read(path)

    # Check that our custom wrapper was called during loading
    assert getattr(loaded_model, "custom_wrapper_called", False) is True
    # ...and that it still delegated the real state-dict loading
    assert_model_exact_equality(model, loaded_model)

118 

119 

def test_deprecated_load_file():
    """`load_file` must still load a saved model while emitting a DeprecationWarning."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)

        def forward(self, x):
            return self.linear(x)

    # write the model to disk with a default ZANJ instance
    original = SimpleModel(SimpleConfig(size=10))
    zanj_file = TEST_DATA_PATH / "test_deprecated_load_file.zanj"
    ZANJ().save(original, zanj_file)

    # reading back through the old entry point has to warn
    with pytest.warns(DeprecationWarning):
        restored = SimpleModel.load_file(zanj_file)

    # config and weights must match the model that was saved
    assert_model_exact_equality(original, restored)

147 

148 

def test_configmodel_training_records():
    """`training_records` set on a model must survive a save/read round-trip."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)
            # start empty; the test assigns records after construction
            self.training_records = None

        def forward(self, x):
            return self.linear(x)

    original = SimpleModel(SimpleConfig(size=10))
    original.training_records = {
        "loss": [1.0, 0.9, 0.8, 0.7],
        "accuracy": [0.6, 0.7, 0.8, 0.9],
        "learning_rate": [0.01, 0.005, 0.001, 0.0005],
    }

    zanj_file = TEST_DATA_PATH / "test_training_records.zanj"
    ZANJ().save(original, zanj_file)
    restored = SimpleModel.read(zanj_file)

    # whole-dict comparison against the in-memory original...
    assert restored.training_records == original.training_records

    # ...then every series against its literal expected values
    expected_series = {
        "loss": [1.0, 0.9, 0.8, 0.7],
        "accuracy": [0.6, 0.7, 0.8, 0.9],
        "learning_rate": [0.01, 0.005, 0.001, 0.0005],
    }
    for series_name, series_values in expected_series.items():
        assert restored.training_records[series_name] == series_values

191 

192 

def test_configmodel_with_custom_settings():
    """Test that a ZANJ's `custom_settings` entry for `_load_state_dict_wrapper`
    reaches the override as keyword arguments when reading a model."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class CustomStateModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)
            self.custom_param_received = None

        def forward(self, x):
            return self.linear(x)

        def _load_state_dict_wrapper(self, state_dict, **kwargs):
            # Store the custom param if provided
            self.custom_param_received = kwargs.get("custom_param", None)
            # The extra kwargs are consumed here rather than passed on to the
            # base implementation; only the state dict is forwarded.
            return super()._load_state_dict_wrapper(state_dict)

    # Create a model and save it with a default ZANJ
    model = CustomStateModel(SimpleConfig(10))
    path = TEST_DATA_PATH / "test_custom_settings.zanj"
    ZANJ().save(model, path)

    # Create a ZANJ with custom settings
    z = ZANJ(
        custom_settings={"_load_state_dict_wrapper": {"custom_param": "test_value"}}
    )

    # Load the model with the custom ZANJ
    loaded_model = z.read(path)

    # The generic z.read must dispatch back to our model class...
    assert isinstance(loaded_model, CustomStateModel)
    # ...pass the custom param through to the wrapper...
    assert loaded_model.custom_param_received == "test_value"
    # ...and the wrapper must still have loaded the weights
    assert_model_exact_equality(model, loaded_model)