Coverage for zanj / torchutil.py: 97%

116 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 01:52 -0700

1"""torch utilities for zanj -- in particular the `ConfiguredModel` base class 

2 

3note that this requires torch 

4""" 

5 

6from __future__ import annotations 

7 

8import abc 

9import typing 

10import warnings 

11from typing import Any, Type, TypeVar 

12 

13try: 

14 import torch # type: ignore[import-not-found] 

15except ImportError as e: 

16 raise ImportError( 

17 "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`" 

18 ) from e 

19 

20from muutils.json_serialize import SerializableDataclass 

21from muutils.json_serialize.json_serialize import ObjectPath 

22 

23from zanj.consts import _FORMAT_KEY, safe_getsource, string_as_lines 

24 

25from zanj import ZANJ, register_loader_handler 

26from zanj.loading import LoaderHandler, load_item_recursive 

27 

# pylint: disable=protected-access

# loose alias for annotating arbitrary keyword-argument payloads; it is not
# referenced elsewhere in this module, so it appears to be kept for external use
KWArgs = typing.Any

31 

32 

def num_params(m: torch.nn.Module, only_trainable: bool = True) -> int:
    """return total number of parameters in a model

    - only counting shared parameters once (parameters backed by the same
      memory on the same device are counted a single time)
    - if `only_trainable` is False, will include parameters with `requires_grad = False`
    - parameters on the `meta` device have no real storage, so they are
      deduplicated by object identity instead of by data pointer

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    parameters: list[torch.nn.Parameter] = list(m.parameters())
    if only_trainable:
        parameters = [p for p in parameters if p.requires_grad]

    # key on (device, address) so distinct Parameter objects sharing storage are
    # counted once; meta tensors cannot be told apart by `data_ptr()`, so use `id()`
    unique: dict[tuple[torch.device, int], torch.nn.Parameter] = {
        (p.device, id(p) if p.device.type == "meta" else p.data_ptr()): p
        for p in parameters
    }

    return sum(p.numel() for p in unique.values())

50 

51 

52def get_module_device( 

53 m: torch.nn.Module, 

54) -> tuple[bool, torch.device | dict[str, torch.device]]: 

55 """get the current devices""" 

56 

57 devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()} 

58 

59 if len(devs) == 0: 

60 return False, devs 

61 

62 # check if all devices are the same by getting one device 

63 dev_uni: torch.device = next(iter(devs.values())) 

64 

65 if all(dev == dev_uni for dev in devs.values()): 

66 return True, dev_uni 

67 else: 

68 return False, devs 

69 

70 

# type variable for the config class of a `ConfiguredModel`; bounded so that
# only `SerializableDataclass` subclasses can parameterize the model
T_config = typing.TypeVar("T_config", bound=SerializableDataclass)

72 

73 

class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only view of the config class registered on the concrete subclass
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """validate and store the model config

        # Parameters:
         - `zanj_model_config : T_config`
           config object; must be an instance of the class registered via `set_config_class()`
         - `**kwargs`
           forwarded to the next `__init__` in the MRO

        # Raises:
         - `NotImplementedError` : if no config class has been registered for this model class
         - `TypeError` : if `zanj_model_config` is not an instance of the registered config class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        # populated by `load()`; `None` for a freshly constructed model
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model into a JSON-like dict for saving with ZANJ

        the result contains the serialized config, descriptive metadata (class
        name/doc/source, module, MRO, parameter count, `str()` of the model), the
        training records, the raw `state_dict()`, and the format key (the class
        name) that `get_handler()` matches on when loading.

        `path` and `zanj` are accepted for compatibility with the ZANJ
        serialization interface but are not used here.
        """
        return dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            __muutils_format__=self.__class__.__name__,
        )

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """serialize the model and write it to `file_path` (a default `ZANJ()` is used if none is given)"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        the base implementation accepts no extra kwargs; subclasses that override
        it can receive kwargs from `zanj.custom_settings["_load_state_dict_wrapper"]`
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        # Parameters:
         - `obj : dict[str, Any]`
           serialized model, as produced by `serialize()`
         - `path : ObjectPath`
           location of `obj` within the enclosing ZANJ object; `None` is treated as the root
         - `zanj : ZANJ | None`
           ZANJ instance used to load nested items (defaults to a fresh `ZANJ()`)

        # Returns:
         - `ConfiguredModel` : a new `cls` instance with state dict and training records restored
        """
        if zanj is None:
            zanj = ZANJ()

        # the loader built by `get_handler()` defaults `path` to None, and
        # `tuple(None)` would raise -- treat a missing path as the root
        base_path: tuple = tuple() if path is None else tuple(path)

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        # extra kwargs for the wrapper may be supplied through the zanj custom settings
        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file, asserting the loaded object is an instance of `cls`"""
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file (deprecated alias of `read()`)"""
        warnings.warn(
            "load_file() is deprecated, use read() instead", DeprecationWarning
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build the ZANJ loader handler for this class

        the handler accepts dicts whose format key is a string starting with the
        class name, and loads them with `cls.load()`.
        NOTE(review): this is a prefix match, so a handler for `Foo` also accepts
        items serialized by `FooBar`; presumably handler ordering resolves this -- confirm.
        """
        cls_name: str = str(cls.__name__)
        return LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore
                isinstance(json_item, dict)
                # guard: a non-string format value must not crash the check
                and isinstance(json_item.get(_FORMAT_KEY), str)
                and json_item[_FORMAT_KEY].startswith(cls_name)
            ),
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self) -> int:
        """number of trainable parameters in this model, counting shared parameters once"""
        return num_params(self)

221 

222 

def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator that binds a config class to a `ConfiguredModel` subclass

    the decorated class gets its `_config_class` set to `config_class`, and the
    loader handler from `cls.get_handler()` is registered so that serialized
    instances can be loaded back.

    # Parameters:
     - `config_class : Type[SerializableDataclass]`
       the config class the model is constructed from

    # Returns:
     - a decorator that mutates and returns the same class it receives

    # Raises:
     - `TypeError` : if `config_class` is not a subclass of `SerializableDataclass`
    """
    # `issubclass` raises its own, less descriptive TypeError for non-class
    # arguments (e.g. a config *instance*), so check that we got a class first
    if not (
        isinstance(config_class, type)
        and issubclass(config_class, SerializableDataclass)
    ):
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # set the config class
        cls._config_class = config_class

        # register the handlers
        register_loader_handler(cls.get_handler())

        # return the same class, now configured
        return cls

    return wrapper

240 

241 

class ConfigMismatchException(ValueError):
    """raised when two model configs are not equal

    `diff` holds a description of the difference, and is appended to the
    message when the exception is converted to a string.
    """

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        self.diff = diff

    def __str__(self):
        return f"{super().__str__()}: {self.diff}"

    def __reduce__(self):
        # default exception pickling re-calls the class with `self.args` only
        # (just `msg`), which fails because `__init__` also requires `diff`;
        # pass both explicitly and keep any extra instance attributes as state
        msg = self.args[0] if self.args else ""
        return (type(self), (msg, self.diff), self.__dict__)

249 

250 

def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """verify that both models are `ConfiguredModel`s carrying equal configs

    each model must be a `ConfiguredModel` whose `zanj_model_config` is a
    `SerializableDataclass` (checked with `assert`, so an `AssertionError`
    is raised otherwise).

    Raises:
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    for label, model in (("model_a", model_a), ("model_b", model_b)):
        assert isinstance(model, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(model.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config
    if cfg_a == cfg_b:
        return

    # the diff is computed via the class of model_a's config
    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=type(cfg_a).diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )

273 

274 

def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    prints `passed <key>` / `failed <key>` for every state dict entry.

    Raises:
        ConfigMismatchException: if the configs don't match (from `assert_model_cfg_equality`)
        AssertionError: if either model is not a valid `ConfiguredModel`, if the
            state dict keys differ, or if any entry differs in shape or value
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once, instead of once per key inside the loop
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(state_a.keys())
    model_b_sd_keys: set[str] = set(state_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )

    keys_failed: list[str] = list()
    for k, v_a in state_a.items():
        v_b = state_b[k]
        # `==` broadcasts, so tensors of different shapes could compare equal
        # (or raise on incompatible shapes); require identical shapes first
        shapes_match: bool = not (
            isinstance(v_a, torch.Tensor)
            and isinstance(v_b, torch.Tensor)
            and v_a.shape != v_b.shape
        )
        if shapes_match and (v_a == v_b).all():
            print(f"passed {k}")
        else:
            keys_failed.append(k)
            print(f"failed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )