docs for zanj v0.6.0
View Source on GitHub

zanj.torchutil

torch utilities for zanj -- in particular the ConfiguredModel base class

note that this requires torch


  1"""torch utilities for zanj -- in particular the `ConfiguredModel` base class
  2
  3note that this requires torch
  4"""
  5
  6from __future__ import annotations
  7
  8import abc
  9import typing
 10import warnings
 11from typing import Any, Type, TypeVar
 12
 13try:
 14    import torch  # type: ignore[import-not-found]
 15except ImportError as e:
 16    raise ImportError(
 17        "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`"
 18    ) from e
 19
 20from muutils.json_serialize import SerializableDataclass
 21from muutils.json_serialize.json_serialize import ObjectPath
 22
 23from zanj.consts import _FORMAT_KEY, safe_getsource, string_as_lines
 24
 25from zanj import ZANJ, register_loader_handler
 26from zanj.loading import LoaderHandler, load_item_recursive
 27
 28# pylint: disable=protected-access
 29
 30KWArgs = Any
 31
 32
 33def num_params(m: torch.nn.Module, only_trainable: bool = True):
 34    """return total number of parameters in a model
 35
 36    - only counting shared parameters once
 37    - if `only_trainable` is False, will include parameters with `requires_grad = False`
 38
 39    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
 40    """
 41    parameters: list[torch.nn.Parameter] = list(m.parameters())
 42    if only_trainable:
 43        parameters = [p for p in parameters if p.requires_grad]
 44
 45    unique: list[torch.nn.Parameter] = list(
 46        {p.data_ptr(): p for p in parameters}.values()
 47    )
 48
 49    return sum(p.numel() for p in unique)
 50
 51
 52def get_module_device(
 53    m: torch.nn.Module,
 54) -> tuple[bool, torch.device | dict[str, torch.device]]:
 55    """get the current devices"""
 56
 57    devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()}
 58
 59    if len(devs) == 0:
 60        return False, devs
 61
 62    # check if all devices are the same by getting one device
 63    dev_uni: torch.device = next(iter(devs.values()))
 64
 65    if all(dev == dev_uni for dev in devs.values()):
 66        return True, dev_uni
 67    else:
 68        return False, devs
 69
 70
 71T_config = TypeVar("T_config", bound=SerializableDataclass)
 72
 73
 74class ConfiguredModel(
 75    torch.nn.Module,
 76    typing.Generic[T_config],
 77    metaclass=abc.ABCMeta,
 78):
 79    """a model that has a configuration, for saving with ZANJ
 80
 81    ```python
 82    @set_config_class(YourConfig)
 83    class YourModule(ConfiguredModel[YourConfig]):
 84        def __init__(self, cfg: YourConfig):
 85            super().__init__(cfg)
 86    ```
 87
 88    `__init__()` must initialize the model from a config object only, and call
 89    `super().__init__(zanj_model_config)`
 90
 91    If you are inheriting from another class + ConfiguredModel,
 92    ConfiguredModel must be the first class in the inheritance list
 93    """
 94
 95    # dont set this directly, use `set_config_class()` decorator
 96    _config_class: type | None = None
 97    zanj_config_class = property(lambda self: type(self)._config_class)
 98
 99    def __init__(self, zanj_model_config: T_config, **kwargs):
100        super().__init__(**kwargs)
101        if self.zanj_config_class is None:
102            raise NotImplementedError("you need to set `config_class` for your model")
103        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
104            raise TypeError(
105                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
106            )
107
108        self.zanj_model_config: T_config = zanj_model_config
109        self.training_records: dict | None = None
110
111    def serialize(
112        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
113    ) -> dict[str, Any]:
114        if zanj is None:
115            zanj = ZANJ()
116        obj = dict(
117            zanj_model_config=self.zanj_model_config.serialize(),
118            meta=dict(
119                class_name=self.__class__.__name__,
120                class_doc=string_as_lines(self.__class__.__doc__),
121                class_source=safe_getsource(self.__class__),
122                module_name=self.__class__.__module__,
123                module_mro=[str(x) for x in self.__class__.__mro__],
124                num_params=num_params(self),
125                as_str=string_as_lines(str(self)),
126            ),
127            training_records=self.training_records,
128            state_dict=self.state_dict(),
129            __muutils_format__=self.__class__.__name__,
130        )
131        return obj
132
133    def save(self, file_path: str, zanj: ZANJ | None = None):
134        if zanj is None:
135            zanj = ZANJ()
136        zanj.save(self.serialize(), file_path)
137
138    def _load_state_dict_wrapper(
139        self,
140        state_dict: dict[str, torch.Tensor],
141        **kwargs,
142    ):
143        """wrapper for `load_state_dict()` in case you need to override it"""
144        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
145        return self.load_state_dict(state_dict)
146
147    @classmethod
148    def load(
149        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
150    ) -> "ConfiguredModel":
151        """load a model from a serialized object"""
152
153        if zanj is None:
154            zanj = ZANJ()
155
156        # get the config
157        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore
158
159        # get the training records
160        training_records: typing.Any = load_item_recursive(
161            obj.get("training_records", None),
162            tuple(path) + ("training_records",),
163            zanj,
164        )
165
166        # initialize the model
167        model: "ConfiguredModel" = cls(zanj_model_config)
168
169        # load the state dict
170        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
171            obj["state_dict"],
172            tuple(path) + ("state_dict",),
173            zanj,
174        )
175
176        model._load_state_dict_wrapper(
177            tensored_state_dict,
178            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
179        )
180
181        # set the training records
182        model.training_records = training_records
183
184        return model
185
186    @classmethod
187    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
188        """read a model from a file"""
189        if zanj is None:
190            zanj = ZANJ()
191
192        mdl: ConfiguredModel = zanj.read(file_path)
193        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
194        return mdl
195
196    @classmethod
197    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
198        """read a model from a file"""
199        warnings.warn(
200            "load_file() is deprecated, use read() instead", DeprecationWarning
201        )
202        return cls.read(file_path, zanj)
203
204    @classmethod
205    def get_handler(cls) -> LoaderHandler:
206        cls_name: str = str(cls.__name__)
207        return LoaderHandler(
208            check=lambda json_item, path=None, z=None: (  # type: ignore
209                isinstance(json_item, dict)
210                and _FORMAT_KEY in json_item
211                and json_item[_FORMAT_KEY].startswith(cls_name)
212            ),
213            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
214            uid=cls_name,
215            source_pckg=cls.__module__,
216            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
217        )
218
219    def num_params(self) -> int:
220        return num_params(self)
221
222
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator that attaches `config_class` to a `ConfiguredModel` subclass

    the decorated class gets `_config_class` set and its loader handler
    (from `get_handler()`) registered; the class itself is returned.

    Raises:
        TypeError: if `config_class` is not a subclass of `SerializableDataclass`
    """
    is_valid: bool = issubclass(config_class, SerializableDataclass)
    if not is_valid:
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # the config class must be attached before the handler is built and registered
        cls._config_class = config_class
        handler = cls.get_handler()
        register_loader_handler(handler)
        return cls

    return wrapper
240
241
class ConfigMismatchException(ValueError):
    """raised when two configs that were expected to be equal differ

    the difference is stored on `diff` and appended to the message by `__str__()`
    """

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        self.diff = diff

    def __reduce__(self):
        # the default reduce rebuilds the exception from `self.args == (msg,)`,
        # which fails because `diff` is a required argument -- pass it explicitly
        # so that `copy` and `pickle` can reconstruct the exception
        msg = self.args[0] if self.args else ""
        return (type(self), (msg, self.diff))

    def __str__(self):
        return f"{super().__str__()}: {self.diff}"
249
250
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check both models are `ConfiguredModel` instances and hold equal configs

    Raises:
        AssertionError: if a model is not a `ConfiguredModel`, or its
            `zanj_model_config` is not a `SerializableDataclass`
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # model_a is validated completely before model_b is looked at
    for label, model in (("model_a", model_a), ("model_b", model_b)):
        assert isinstance(model, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(model.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    if model_a.zanj_model_config == model_b.zanj_model_config:
        return

    # the diff is computed by the config class of model_a
    cfg_type: type = type(model_a.zanj_model_config)
    raise ConfigMismatchException(
        f"configs of type {type(model_a.zanj_model_config)}, {type(model_b.zanj_model_config)} don't match",
        diff=cfg_type.diff(model_a.zanj_model_config, model_b.zanj_model_config),  # type: ignore[attr-defined]
    )
273
274
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    the configs are compared first (via `assert_model_cfg_equality()`), then
    the state dict keys, then every state dict entry: two entries match only
    if they have the same shape and all elements are equal.
    a `passed <key>` / `failed <key>` line is printed for every entry.

    Raises:
        ConfigMismatchException: propagated from `assert_model_cfg_equality()`
        AssertionError: if the state dict keys differ, or any entry differs
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once instead of on every access
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(state_a.keys())
    model_b_sd_keys: set[str] = set(state_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in state_a.items():
        v_b = state_b[k]
        # `==` broadcasts, so shapes are compared first: entries of different
        # shapes are never "exactly equal" (and may not even be comparable)
        if v_a.shape == v_b.shape and (v_a == v_b).all():
            print(f"passed {k}")
        else:
            keys_failed.append(k)
            print(f"failed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )

KWArgs = typing.Any
def num_params(m: torch.nn.modules.module.Module, only_trainable: bool = True):
def num_params(m: torch.nn.Module, only_trainable: bool = True):
    """count the parameters of a model

    - parameters that report the same `data_ptr()` are only counted once
    - with `only_trainable=True` (the default) parameters with `requires_grad = False`
      are left out; pass `only_trainable=False` to include them

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    # keyed by data pointer: a later parameter with the same pointer replaces the earlier one
    by_pointer: dict[int, torch.nn.Parameter] = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        by_pointer[param.data_ptr()] = param

    total: int = 0
    for param in by_pointer.values():
        total += param.numel()
    return total

return total number of parameters in a model

  • only counting shared parameters once
  • if only_trainable is False, will include parameters with requires_grad = False

https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model

def get_module_device( m: torch.nn.modules.module.Module) -> tuple[bool, torch.device | dict[str, torch.device]]:
def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """get the device(s) the parameters of a module live on

    # Returns:
     - `(True, device)` if every parameter is on the same device
     - `(False, {parameter name: device})` if the devices differ, or if the
       module has no parameters at all (the dict is then empty)
    """
    param_devices: dict[str, torch.device] = dict()
    for name, param in m.named_parameters():
        param_devices[name] = param.device

    if not param_devices:
        return False, param_devices

    # any device works as the reference, every other one is compared against it
    reference: torch.device = next(iter(param_devices.values()))
    is_uniform: bool = all(device == reference for device in param_devices.values())

    if not is_uniform:
        return False, param_devices
    return True, reference

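get the device(s) of a module's parameters: returns (True, device) when every parameter is on the same device, otherwise (False, {parameter name: device}) — the dict is empty when the module has no parameters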

class ConfiguredModel(torch.nn.modules.module.Module, typing.Generic[~T_config]):
 75class ConfiguredModel(
 76    torch.nn.Module,
 77    typing.Generic[T_config],
 78    metaclass=abc.ABCMeta,
 79):
 80    """a model that has a configuration, for saving with ZANJ
 81
 82    ```python
 83    @set_config_class(YourConfig)
 84    class YourModule(ConfiguredModel[YourConfig]):
 85        def __init__(self, cfg: YourConfig):
 86            super().__init__(cfg)
 87    ```
 88
 89    `__init__()` must initialize the model from a config object only, and call
 90    `super().__init__(zanj_model_config)`
 91
 92    If you are inheriting from another class + ConfiguredModel,
 93    ConfiguredModel must be the first class in the inheritance list
 94    """
 95
 96    # dont set this directly, use `set_config_class()` decorator
 97    _config_class: type | None = None
 98    zanj_config_class = property(lambda self: type(self)._config_class)
 99
100    def __init__(self, zanj_model_config: T_config, **kwargs):
101        super().__init__(**kwargs)
102        if self.zanj_config_class is None:
103            raise NotImplementedError("you need to set `config_class` for your model")
104        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
105            raise TypeError(
106                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
107            )
108
109        self.zanj_model_config: T_config = zanj_model_config
110        self.training_records: dict | None = None
111
112    def serialize(
113        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
114    ) -> dict[str, Any]:
115        if zanj is None:
116            zanj = ZANJ()
117        obj = dict(
118            zanj_model_config=self.zanj_model_config.serialize(),
119            meta=dict(
120                class_name=self.__class__.__name__,
121                class_doc=string_as_lines(self.__class__.__doc__),
122                class_source=safe_getsource(self.__class__),
123                module_name=self.__class__.__module__,
124                module_mro=[str(x) for x in self.__class__.__mro__],
125                num_params=num_params(self),
126                as_str=string_as_lines(str(self)),
127            ),
128            training_records=self.training_records,
129            state_dict=self.state_dict(),
130            __muutils_format__=self.__class__.__name__,
131        )
132        return obj
133
134    def save(self, file_path: str, zanj: ZANJ | None = None):
135        if zanj is None:
136            zanj = ZANJ()
137        zanj.save(self.serialize(), file_path)
138
139    def _load_state_dict_wrapper(
140        self,
141        state_dict: dict[str, torch.Tensor],
142        **kwargs,
143    ):
144        """wrapper for `load_state_dict()` in case you need to override it"""
145        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
146        return self.load_state_dict(state_dict)
147
148    @classmethod
149    def load(
150        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
151    ) -> "ConfiguredModel":
152        """load a model from a serialized object"""
153
154        if zanj is None:
155            zanj = ZANJ()
156
157        # get the config
158        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore
159
160        # get the training records
161        training_records: typing.Any = load_item_recursive(
162            obj.get("training_records", None),
163            tuple(path) + ("training_records",),
164            zanj,
165        )
166
167        # initialize the model
168        model: "ConfiguredModel" = cls(zanj_model_config)
169
170        # load the state dict
171        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
172            obj["state_dict"],
173            tuple(path) + ("state_dict",),
174            zanj,
175        )
176
177        model._load_state_dict_wrapper(
178            tensored_state_dict,
179            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
180        )
181
182        # set the training records
183        model.training_records = training_records
184
185        return model
186
187    @classmethod
188    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
189        """read a model from a file"""
190        if zanj is None:
191            zanj = ZANJ()
192
193        mdl: ConfiguredModel = zanj.read(file_path)
194        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
195        return mdl
196
197    @classmethod
198    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
199        """read a model from a file"""
200        warnings.warn(
201            "load_file() is deprecated, use read() instead", DeprecationWarning
202        )
203        return cls.read(file_path, zanj)
204
205    @classmethod
206    def get_handler(cls) -> LoaderHandler:
207        cls_name: str = str(cls.__name__)
208        return LoaderHandler(
209            check=lambda json_item, path=None, z=None: (  # type: ignore
210                isinstance(json_item, dict)
211                and _FORMAT_KEY in json_item
212                and json_item[_FORMAT_KEY].startswith(cls_name)
213            ),
214            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
215            uid=cls_name,
216            source_pckg=cls.__module__,
217            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
218        )
219
220    def num_params(self) -> int:
221        return num_params(self)

a model that has a configuration, for saving with ZANJ

@set_config_class(YourConfig)
class YourModule(ConfiguredModel[YourConfig]):
    def __init__(self, cfg: YourConfig):
        super().__init__(cfg)

__init__() must initialize the model from a config object only, and call super().__init__(zanj_model_config)

If you are inheriting from another class + ConfiguredModel, ConfiguredModel must be the first class in the inheritance list

zanj_config_class
98    zanj_config_class = property(lambda self: type(self)._config_class)
zanj_model_config: ~T_config
training_records: dict | None
def serialize( self, path: tuple[typing.Union[str, int], ...] = (), zanj: zanj.ZANJ | None = None) -> dict[str, typing.Any]:
112    def serialize(
113        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
114    ) -> dict[str, Any]:
115        if zanj is None:
116            zanj = ZANJ()
117        obj = dict(
118            zanj_model_config=self.zanj_model_config.serialize(),
119            meta=dict(
120                class_name=self.__class__.__name__,
121                class_doc=string_as_lines(self.__class__.__doc__),
122                class_source=safe_getsource(self.__class__),
123                module_name=self.__class__.__module__,
124                module_mro=[str(x) for x in self.__class__.__mro__],
125                num_params=num_params(self),
126                as_str=string_as_lines(str(self)),
127            ),
128            training_records=self.training_records,
129            state_dict=self.state_dict(),
130            __muutils_format__=self.__class__.__name__,
131        )
132        return obj
def save(self, file_path: str, zanj: zanj.ZANJ | None = None):
134    def save(self, file_path: str, zanj: ZANJ | None = None):
135        if zanj is None:
136            zanj = ZANJ()
137        zanj.save(self.serialize(), file_path)
@classmethod
def load( cls, obj: dict[str, typing.Any], path: tuple[typing.Union[str, int], ...], zanj: zanj.ZANJ | None = None) -> ConfiguredModel:
148    @classmethod
149    def load(
150        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
151    ) -> "ConfiguredModel":
152        """load a model from a serialized object"""
153
154        if zanj is None:
155            zanj = ZANJ()
156
157        # get the config
158        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore
159
160        # get the training records
161        training_records: typing.Any = load_item_recursive(
162            obj.get("training_records", None),
163            tuple(path) + ("training_records",),
164            zanj,
165        )
166
167        # initialize the model
168        model: "ConfiguredModel" = cls(zanj_model_config)
169
170        # load the state dict
171        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
172            obj["state_dict"],
173            tuple(path) + ("state_dict",),
174            zanj,
175        )
176
177        model._load_state_dict_wrapper(
178            tensored_state_dict,
179            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
180        )
181
182        # set the training records
183        model.training_records = training_records
184
185        return model

load a model from a serialized object

@classmethod
def read( cls, file_path: str, zanj: zanj.ZANJ | None = None) -> ConfiguredModel:
187    @classmethod
188    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
189        """read a model from a file"""
190        if zanj is None:
191            zanj = ZANJ()
192
193        mdl: ConfiguredModel = zanj.read(file_path)
194        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
195        return mdl

read a model from a file

@classmethod
def load_file( cls, file_path: str, zanj: zanj.ZANJ | None = None) -> ConfiguredModel:
197    @classmethod
198    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
199        """read a model from a file"""
200        warnings.warn(
201            "load_file() is deprecated, use read() instead", DeprecationWarning
202        )
203        return cls.read(file_path, zanj)

read a model from a file (deprecated: emits a DeprecationWarning and forwards to read(); use read() instead)

@classmethod
def get_handler(cls) -> zanj.loading.LoaderHandler:
205    @classmethod
206    def get_handler(cls) -> LoaderHandler:
207        cls_name: str = str(cls.__name__)
208        return LoaderHandler(
209            check=lambda json_item, path=None, z=None: (  # type: ignore
210                isinstance(json_item, dict)
211                and _FORMAT_KEY in json_item
212                and json_item[_FORMAT_KEY].startswith(cls_name)
213            ),
214            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
215            uid=cls_name,
216            source_pckg=cls.__module__,
217            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
218        )
def num_params(self) -> int:
220    def num_params(self) -> int:
221        return num_params(self)
Inherited Members
torch.nn.modules.module.Module
Module
dump_patches
training
call_super_init
forward
register_buffer
register_parameter
add_module
register_module
get_submodule
set_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
mtia
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_post_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_pre_hook
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
def set_config_class( config_class: Type[muutils.json_serialize.serializable_dataclass.SerializableDataclass]) -> Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator that attaches `config_class` to a `ConfiguredModel` subclass

    the decorated class gets `_config_class` set and its loader handler
    (from `get_handler()`) registered; the class itself is returned.

    Raises:
        TypeError: if `config_class` is not a subclass of `SerializableDataclass`
    """
    is_valid: bool = issubclass(config_class, SerializableDataclass)
    if not is_valid:
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # the config class must be attached before the handler is built and registered
        cls._config_class = config_class
        handler = cls.get_handler()
        register_loader_handler(handler)
        return cls

    return wrapper
class ConfigMismatchException(builtins.ValueError):
class ConfigMismatchException(ValueError):
    """raised when two configs that were expected to be equal differ

    the difference is stored on `diff` and appended to the message by `__str__()`
    """

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        self.diff = diff

    def __reduce__(self):
        # the default reduce rebuilds the exception from `self.args == (msg,)`,
        # which fails because `diff` is a required argument -- pass it explicitly
        # so that `copy` and `pickle` can reconstruct the exception
        msg = self.args[0] if self.args else ""
        return (type(self), (msg, self.diff))

    def __str__(self):
        return f"{super().__str__()}: {self.diff}"

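Inappropriate argument value (of correct type): raised when two model configs don't match. The mismatch is stored in the diff attribute and appended to the exception's string form.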

ConfigMismatchException(msg: str, diff)
244    def __init__(self, msg: str, diff):
245        super().__init__(msg)
246        self.diff = diff
diff
Inherited Members
builtins.BaseException
with_traceback
add_note
args
def assert_model_cfg_equality( model_a: ConfiguredModel, model_b: ConfiguredModel):
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check both models are `ConfiguredModel` instances and hold equal configs

    Raises:
        AssertionError: if a model is not a `ConfiguredModel`, or its
            `zanj_model_config` is not a `SerializableDataclass`
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # model_a is validated completely before model_b is looked at
    for label, model in (("model_a", model_a), ("model_b", model_b)):
        assert isinstance(model, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(model.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    if model_a.zanj_model_config == model_b.zanj_model_config:
        return

    # the diff is computed by the config class of model_a
    cfg_type: type = type(model_a.zanj_model_config)
    raise ConfigMismatchException(
        f"configs of type {type(model_a.zanj_model_config)}, {type(model_b.zanj_model_config)} don't match",
        diff=cfg_type.diff(model_a.zanj_model_config, model_b.zanj_model_config),  # type: ignore[attr-defined]
    )

check both models are correct instances and have the same config

Raises: ConfigMismatchException: if the configs don't match, e.diff will contain the diff

def assert_model_exact_equality( model_a: ConfiguredModel, model_b: ConfiguredModel):
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    the configs are compared first (via `assert_model_cfg_equality()`), then
    the state dict keys, then every state dict entry: two entries match only
    if they have the same shape and all elements are equal.
    a `passed <key>` / `failed <key>` line is printed for every entry.

    Raises:
        ConfigMismatchException: propagated from `assert_model_cfg_equality()`
        AssertionError: if the state dict keys differ, or any entry differs
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once instead of on every access
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(state_a.keys())
    model_b_sd_keys: set[str] = set(state_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in state_a.items():
        v_b = state_b[k]
        # `==` broadcasts, so shapes are compared first: entries of different
        # shapes are never "exactly equal" (and may not even be comparable)
        if v_a.shape == v_b.shape and (v_a == v_b).all():
            print(f"passed {k}")
        else:
            keys_failed.append(k)
            print(f"failed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )

check the models are exactly equal, including state dict contents