Coverage for zanj/torchutil.py: 90%
116 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 11:17 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 11:17 -0600
1"""torch utilities for zanj -- in particular the `ConfiguredModel` base class
3note that this requires torch
4"""
6from __future__ import annotations
8import abc
9import typing
10import warnings
11from typing import Any, Type, TypeVar
13try:
14 import torch # type: ignore[import-not-found]
15except ImportError as e:
16 raise ImportError(
17 "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`"
18 ) from e
20from muutils.json_serialize import SerializableDataclass
21from muutils.json_serialize.json_serialize import ObjectPath
22from muutils.json_serialize.util import safe_getsource, string_as_lines, _FORMAT_KEY
24from zanj import ZANJ, register_loader_handler
25from zanj.loading import LoaderHandler, load_item_recursive
27# pylint: disable=protected-access
# alias used to annotate arbitrary keyword-argument values
KWArgs = typing.Any
def num_params(m: torch.nn.Module, only_trainable: bool = True) -> int:
    """Count the scalar parameters of a module, counting shared storage once.

    Args:
        m: the module whose parameters are counted.
        only_trainable: when True (the default), parameters with
            `requires_grad = False` are skipped; when False they are included.

    Returns:
        the total number of elements across the distinct parameters. Parameters
        are deduplicated by `data_ptr()`, so tensors sharing the same memory
        address are counted only once.

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    by_address: dict[int, torch.nn.Parameter] = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        # a later parameter at the same address replaces the earlier one
        by_address[param.data_ptr()] = param
    return sum(param.numel() for param in by_address.values())
def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """Report which device(s) the parameters of `m` live on.

    Returns:
        `(True, device)` when `m` has at least one parameter and every parameter
        is on that same device. Otherwise `(False, mapping)`, where `mapping`
        maps each parameter name to its device; it is empty when `m` has no
        parameters.
    """
    per_param: dict[str, torch.device] = {
        param_name: param.device for param_name, param in m.named_parameters()
    }
    seen = list(per_param.values())
    if seen and all(device == seen[0] for device in seen):
        return True, seen[0]
    return False, per_param
# type variable for the config class of a `ConfiguredModel`; bounded to
# `SerializableDataclass` because configs are serialized and loaded with it
T_config = typing.TypeVar("T_config", bound=SerializableDataclass)
class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # don't set this directly, use the `set_config_class()` decorator
    _config_class: type | None = None
    # instance-level read access to the class's registered config class
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """validate and store the config

        Args:
            zanj_model_config: must be an instance of the class registered via
                `set_config_class()`
            **kwargs: forwarded to the next `__init__` in the MRO

        Raises:
            NotImplementedError: if no config class has been registered
            TypeError: if `zanj_model_config` is not an instance of it
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model into a dict holding config, metadata, training
        records, and state dict, tagged with the class name as its format key

        `path` and `zanj` are accepted for compatibility with the ZANJ
        serialization interface but are not used here.
        """
        return dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            __muutils_format__=self.__class__.__name__,
        )

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """serialize the model and write it to `file_path` using `zanj`
        (a default `ZANJ()` if none is given)"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        `kwargs` come from `zanj.custom_settings["_load_state_dict_wrapper"]`;
        this default implementation accepts none.
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        Args:
            obj: dict as produced by `serialize()`
            path: location of `obj` within the enclosing serialized tree; may
                be None (as passed by the loader handler's default), in which
                case it is treated as the root
            zanj: loader to use; a default `ZANJ()` if None

        Raises:
            NotImplementedError: if no config class has been registered
        """
        if zanj is None:
            zanj = ZANJ()
        if cls._config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")

        base_path: tuple = tuple(path) if path is not None else tuple()

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file, asserting it is an instance of `cls`"""
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """deprecated alias for `read()`"""
        # stacklevel=2 attributes the warning to the caller, not to this method
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build the ZANJ loader handler that recognizes and loads this class"""
        cls_name: str = str(cls.__name__)

        def _check(json_item, path=None, z=None) -> bool:  # type: ignore
            if not isinstance(json_item, dict):
                return False
            fmt = json_item.get(_FORMAT_KEY)
            if not isinstance(fmt, str):
                return False
            # exact class name, or the name followed by a parenthesized
            # qualifier. A bare prefix match would let a handler for `Model`
            # claim items written by `ModelV2` and load them as the wrong class.
            return fmt == cls_name or fmt.startswith(cls_name + "(")

        return LoaderHandler(
            check=_check,
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self) -> int:
        """number of trainable parameters in this model, shared storage counted once"""
        return num_params(self)
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """Create a class decorator that attaches `config_class` to a model class.

    The returned decorator:
    - sets the decorated class's `_config_class` to `config_class`,
    - registers the class's loader handler via `register_loader_handler`,
    - returns the same class object.

    Raises:
        TypeError: if `config_class` is not a subclass of SerializableDataclass
    """
    if issubclass(config_class, SerializableDataclass):

        def _attach_config(model_cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
            model_cls._config_class = config_class
            register_loader_handler(model_cls.get_handler())
            return model_cls

        return _attach_config

    raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")
class ConfigMismatchException(ValueError):
    """raised when two model configs differ

    Args:
        msg: human-readable description of the mismatch
        diff: the difference between the configs, stored as `self.diff` and
            appended to the string form of the exception
    """

    def __init__(self, msg: str, diff: Any):
        super().__init__(msg)
        self.diff = diff

    def __str__(self) -> str:
        return f"{super().__str__()}: {self.diff}"

    def __reduce__(self):
        # the default exception reduce replays only `self.args` (just `msg`)
        # into `__init__`, which then fails on the missing `diff` argument;
        # pass both explicitly so the exception survives pickling
        msg = self.args[0] if self.args else ""
        return (type(self), (msg, self.diff), self.__dict__)
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """Verify both arguments are ConfiguredModel instances with equal configs.

    Raises:
        AssertionError: if either argument is not a ConfiguredModel, or its
            `zanj_model_config` is not a SerializableDataclass
        ConfigMismatchException: if the configs compare unequal; `e.diff`
            holds the result of the config class's `diff()` method
    """
    checks = (
        (model_a, "model_a must be a ConfiguredModel", "model_a must have a zanj_model_config"),
        (model_b, "model_b must be a ConfiguredModel", "model_b must have a zanj_model_config"),
    )
    for mdl, not_model_msg, no_cfg_msg in checks:
        assert isinstance(mdl, ConfiguredModel), not_model_msg
        assert isinstance(mdl.zanj_model_config, SerializableDataclass), no_cfg_msg

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config
    if cfg_a == cfg_b:
        return

    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=type(cfg_a).diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    Checks config equality first (via `assert_model_cfg_equality`), then that
    both state dicts have the same keys, then that every entry has the same
    shape and identical elements. Prints `passed <key>` / `failed <key>` per entry.

    Raises:
        ConfigMismatchException: if the configs differ
        AssertionError: if the state dict keys differ, or any entry differs
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once; the original rebuilt model_b's per key
    state_a: dict[str, torch.Tensor] = model_a.state_dict()
    state_b: dict[str, torch.Tensor] = model_b.state_dict()

    model_a_sd_keys: set[str] = set(state_a.keys())
    model_b_sd_keys: set[str] = set(state_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )

    keys_failed: list[str] = []
    for k, v_a in state_a.items():
        v_b = state_b[k]
        # compare shapes first: `==` broadcasts, so e.g. a (1,)-shaped tensor
        # would otherwise "match" a (3,)-shaped tensor holding the same value
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")

    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )