Coverage for zanj/torchutil.py: 90%

116 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-20 11:17 -0600

1"""torch utilities for zanj -- in particular the `ConfiguredModel` base class 

2 

3note that this requires torch 

4""" 

5 

6from __future__ import annotations 

7 

8import abc 

9import typing 

10import warnings 

11from typing import Any, Type, TypeVar 

12 

13try: 

14 import torch # type: ignore[import-not-found] 

15except ImportError as e: 

16 raise ImportError( 

17 "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`" 

18 ) from e 

19 

20from muutils.json_serialize import SerializableDataclass 

21from muutils.json_serialize.json_serialize import ObjectPath 

22from muutils.json_serialize.util import safe_getsource, string_as_lines, _FORMAT_KEY 

23 

24from zanj import ZANJ, register_loader_handler 

25from zanj.loading import LoaderHandler, load_item_recursive 

26 

27# pylint: disable=protected-access 

28 

# loose alias for `Any`; not referenced by any other definition visible in this file
KWArgs = Any

30 

31 

def num_params(m: torch.nn.Module, only_trainable: bool = True):
    """count the parameters of a model

    - parameters whose storage starts at the same address (`data_ptr()`) are counted once
    - with `only_trainable=True` (the default), parameters with
      `requires_grad = False` are left out of the count

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    # keyed by data pointer, so a later parameter at the same address
    # replaces an earlier one instead of being counted twice
    by_pointer: dict[int, torch.nn.Parameter] = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        by_pointer[param.data_ptr()] = param

    total = 0
    for param in by_pointer.values():
        total += param.numel()
    return total

49 

50 

def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """report which device(s) the parameters of `m` live on

    returns a pair `(is_single_device, result)`:

    - `(True, device)` if every parameter is on the same device
    - `(False, {name: device})` if the parameters are spread over several devices
    - `(False, {})` if the module has no parameters at all
    """
    per_param: dict[str, torch.device] = {}
    for param_name, param in m.named_parameters():
        per_param[param_name] = param.device

    # nothing to compare against when there are no parameters
    if not per_param:
        return False, per_param

    # use the first device seen as the reference for the uniformity check
    reference: torch.device = next(iter(per_param.values()))
    is_uniform: bool = all(device == reference for device in per_param.values())

    return (True, reference) if is_uniform else (False, per_param)

68 

69 

# type variable for the config type of a `ConfiguredModel`, bound to `SerializableDataclass`
T_config = TypeVar("T_config", bound=SerializableDataclass)

71 

72 

class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only view of the config class registered on the concrete subclass
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store the config on the model after validating its type

        `**kwargs` are forwarded to the next `__init__` in the MRO

        Raises:
            NotImplementedError: if no config class was set for this model class
            TypeError: if `zanj_model_config` is not an instance of the config class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """build a dict with the serialized config, metadata, training records and state dict

        `path` and `zanj` are accepted as part of the signature but are not
        used when building the output.
        """
        obj = dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            # descriptive information about the model class
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            # format marker; `get_handler()` matches on this being prefixed by the class name
            __muutils_format__=self.__class__.__name__,
        )
        return obj

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """serialize the model and write it to `file_path` using `zanj` (a default `ZANJ()` if not given)"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        this base implementation accepts no extra keyword arguments
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls,
        obj: dict[str, Any],
        path: ObjectPath | None = None,
        zanj: ZANJ | None = None,
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        - `obj`: dict with `"zanj_model_config"` and `"state_dict"` keys,
          and optionally `"training_records"`
        - `path`: location of `obj` inside the enclosing object; `None` is
          treated as the empty path
        - `zanj`: `ZANJ` instance to load with, a default `ZANJ()` if not given
        """

        if zanj is None:
            zanj = ZANJ()

        # the loader built by `get_handler()` defaults `path` to `None` and
        # forwards it here, so normalize before extending it below
        base_path: tuple = tuple(path) if path is not None else tuple()

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        # extra kwargs for the wrapper can be supplied via the `ZANJ` custom settings
        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file, asserting the loaded object is an instance of `cls`"""
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file (deprecated, use `read()` instead)"""
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            # attribute the warning to the code calling `load_file()`, not to this module
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build the `LoaderHandler` that recognizes and loads serialized instances of `cls`

        an item is recognized when it is a dict whose format key value starts
        with the class name (a prefix match, not an exact match)
        """
        cls_name: str = str(cls.__name__)
        return LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore
                isinstance(json_item, dict)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith(cls_name)
            ),
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self) -> int:
        """parameter count of this model, via the module-level `num_params()` with its defaults"""
        return num_params(self)

220 

221 

def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator that attaches `config_class` to a `ConfiguredModel` subclass

    the decorated class gets `_config_class` set and its loader handler
    (from `get_handler()`) registered via `register_loader_handler`; the
    class itself is returned

    Raises:
        TypeError: if `config_class` is not a subclass of `SerializableDataclass`
    """
    is_valid_config: bool = issubclass(config_class, SerializableDataclass)
    if not is_valid_config:
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # remember which config type this model class is built from
        cls._config_class = config_class

        # make serialized instances of this class loadable
        handler = cls.get_handler()
        register_loader_handler(handler)

        # the class object is handed back, not a copy
        return cls

    return wrapper

239 

240 

class ConfigMismatchException(ValueError):
    """`ValueError` raised when two configs don't match

    the difference between the configs is kept on the `diff` attribute and
    appended to the message in the string form of the exception
    """

    def __init__(self, msg: str, diff):
        self.diff = diff
        super().__init__(msg)

    def __str__(self):
        # message as `ValueError` would render it, followed by the diff
        base_message = super().__str__()
        return f"{base_message}: {self.diff}"

248 

249 

def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check that both arguments are configured models and that their configs are equal

    Raises:
        AssertionError: if either argument is not a `ConfiguredModel`, or its
            `zanj_model_config` is not a `SerializableDataclass`
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # validate model_a fully, then model_b
    for label, mdl in (("model_a", model_a), ("model_b", model_b)):
        assert isinstance(mdl, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(mdl.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config

    if cfg_a == cfg_b:
        return

    # the diff is computed by the config class of model_a
    cfg_cls: type = type(cfg_a)
    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=cfg_cls.diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )

272 

273 

def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    the configs are compared first via `assert_model_cfg_equality`, then the
    state dict keys, then every state dict entry. an entry matches only if
    both values have the same shape and are elementwise equal. a
    `passed`/`failed` line is printed for every key.

    Raises:
        ConfigMismatchException: if the configs don't match
        AssertionError: if the state dict keys differ, or any entry doesn't match
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once instead of once per use / per loop iteration
    sd_a = model_a.state_dict()
    sd_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # compare shapes first: `==` broadcasts, so tensors of different
        # shapes could otherwise compare as equal (or raise if the shapes
        # can't be broadcast together)
        if v_a.shape == v_b.shape and bool((v_a == v_b).all()):
            print(f"passed {k}")
        else:
            keys_failed.append(k)
            print(f"failed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )