Coverage for zanj / torchutil.py: 97%
116 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 01:52 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 01:52 -0700
1"""torch utilities for zanj -- in particular the `ConfiguredModel` base class
3note that this requires torch
4"""
6from __future__ import annotations
8import abc
9import typing
10import warnings
11from typing import Any, Type, TypeVar
13try:
14 import torch # type: ignore[import-not-found]
15except ImportError as e:
16 raise ImportError(
17 "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`"
18 ) from e
20from muutils.json_serialize import SerializableDataclass
21from muutils.json_serialize.json_serialize import ObjectPath
23from zanj.consts import _FORMAT_KEY, safe_getsource, string_as_lines
25from zanj import ZANJ, register_loader_handler
26from zanj.loading import LoaderHandler, load_item_recursive
28# pylint: disable=protected-access
30KWArgs = Any
def num_params(m: torch.nn.Module, only_trainable: bool = True):
    """count the parameters of the model `m`

    - parameters are keyed by their `data_ptr()`, so parameters reporting the same pointer are only counted once
    - with `only_trainable=True` (the default), parameters with `requires_grad = False` are left out

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    # one entry per data pointer; a later parameter with the same pointer replaces the earlier one
    by_pointer: dict[int, torch.nn.Parameter] = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        by_pointer[param.data_ptr()] = param

    return sum(param.numel() for param in by_pointer.values())
52def get_module_device(
53 m: torch.nn.Module,
54) -> tuple[bool, torch.device | dict[str, torch.device]]:
55 """get the current devices"""
57 devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()}
59 if len(devs) == 0:
60 return False, devs
62 # check if all devices are the same by getting one device
63 dev_uni: torch.device = next(iter(devs.values()))
65 if all(dev == dev_uni for dev in devs.values()):
66 return True, dev_uni
67 else:
68 return False, devs
71T_config = TypeVar("T_config", bound=SerializableDataclass)
class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only view of the config class registered on the concrete model class
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store the config on the model, after checking it has the registered type

        `**kwargs` are forwarded to the next `__init__` in the MRO

        raises:
        - `NotImplementedError` if no config class was registered for this model class
        - `TypeError` if `zanj_model_config` is not an instance of the registered config class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        # starts out empty; `load()` restores it from the serialized object
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """return a dict with the serialized config, metadata, training records and state dict

        `path` and `zanj` are accepted but not used by this implementation,
        so no `ZANJ` instance is created here.

        the state dict is included as returned by `state_dict()`; the output is
        tagged with the class name under the `__muutils_format__` key
        """
        obj = dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            # descriptive metadata about the class and this instance
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            __muutils_format__=self.__class__.__name__,
        )
        return obj

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """serialize the model and save it to `file_path`, using `zanj` or a default `ZANJ()`"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        this default implementation accepts no extra kwargs and asserts that none were passed
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        reads the `"zanj_model_config"` and `"state_dict"` entries of `obj`
        (both required) and the optional `"training_records"` entry;
        `path` is the location of `obj`, extended with the key name when
        loading the nested items
        """

        if zanj is None:
            zanj = ZANJ()

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            tuple(path) + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            tuple(path) + ("state_dict",),
            zanj,
        )

        # extra kwargs for the wrapper come from the `"_load_state_dict_wrapper"`
        # entry of `zanj.custom_settings`, if there is one
        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file

        asserts that the object read from `file_path` is an instance of `cls`
        """
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """deprecated alias of `read()`, emits a `DeprecationWarning`"""
        # stacklevel=2 so the warning is attributed to the code calling `load_file()`
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build the `LoaderHandler` for this class

        the handler matches dicts whose `_FORMAT_KEY` entry starts with the
        class name, and loads them via `cls.load()`
        """
        cls_name: str = str(cls.__name__)
        return LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore
                isinstance(json_item, dict)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith(cls_name)
            ),
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self, only_trainable: bool = True) -> int:
        """return the number of parameters of this model

        `only_trainable` is passed on to the module-level `num_params()`:
        if False, parameters with `requires_grad = False` are included
        """
        # inside a method, `num_params` resolves to the module-level function
        return num_params(self, only_trainable=only_trainable)
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator factory: attach `config_class` to a `ConfiguredModel` subclass

    the returned decorator stores `config_class` as the `_config_class` of the
    decorated class, registers the loader handler from `cls.get_handler()`,
    and returns the same class object

    raises `TypeError` if `config_class` is not a subclass of `SerializableDataclass`
    """
    is_valid_config: bool = issubclass(config_class, SerializableDataclass)
    if not is_valid_config:
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def decorate(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # remember which config class belongs to this model class
        cls._config_class = config_class

        # make the model class loadable: register its handler
        handler = cls.get_handler()
        register_loader_handler(handler)

        # the class itself is handed back
        return cls

    return decorate
class ConfigMismatchException(ValueError):
    """`ValueError` for two configs that don't match, carrying their diff in `diff`"""

    def __init__(self, msg: str, diff):
        ValueError.__init__(self, msg)
        # kept on the exception so that callers can inspect it
        self.diff = diff

    def __str__(self):
        # the plain message, followed by the diff
        base_message = ValueError.__str__(self)
        return f"{base_message}: {self.diff}"
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check that both models are `ConfiguredModel`s whose configs compare equal

    Raises:
        AssertionError: if a model is not a `ConfiguredModel`, or its `zanj_model_config` is not a `SerializableDataclass`
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # same checks for both models, `model_a` first
    for label, model in (("model_a", model_a), ("model_b", model_b)):
        assert isinstance(model, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(model.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config

    # equal configs: nothing to report
    if cfg_a == cfg_b:
        return

    # the diff is computed with the `diff` of the first model's config type
    cfg_type: type = type(cfg_a)
    raise ConfigMismatchException(
        f"configs of type {cfg_type}, {type(cfg_b)} don't match",
        diff=cfg_type.diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    the configs are checked first with `assert_model_cfg_equality()`, then the
    state dicts must have the same keys, and every pair of entries must have
    the same shape and be elementwise equal

    prints a `passed {key}` / `failed {key}` line for every state dict entry

    Raises:
        AssertionError: if the state dict keys differ or any entry doesn't match
            (or from the checks in `assert_model_cfg_equality()`)
        ConfigMismatchException: if the configs don't match
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict only once, instead of once per key compared
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(state_a.keys())
    model_b_sd_keys: set[str] = set(state_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )

    keys_failed: list[str] = list()
    for k, v_a in state_a.items():
        v_b = state_b[k]
        # `==` broadcasts, so entries of different shapes could compare equal
        # elementwise (or raise if not broadcastable): check the shapes first
        same_shape: bool = getattr(v_a, "shape", None) == getattr(v_b, "shape", None)
        if not same_shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")

    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )