docs for zanj v0.5.2
View Source on GitHub

zanj.torchutil

torch utilities for zanj -- in particular the ConfiguredModel base class

Note: this module requires torch; install it with `pip install torch` or `pip install zanj[torch]`.


  1"""torch utilities for zanj -- in particular the `ConfiguredModel` base class
  2
  3note that this requires torch
  4"""
  5
  6from __future__ import annotations
  7
  8import abc
  9import typing
 10import warnings
 11from typing import Any, Type, TypeVar
 12
 13try:
 14    import torch  # type: ignore[import-not-found]
 15except ImportError as e:
 16    raise ImportError(
 17        "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`"
 18    ) from e
 19
 20from muutils.json_serialize import SerializableDataclass
 21from muutils.json_serialize.json_serialize import ObjectPath
 22from muutils.json_serialize.util import safe_getsource, string_as_lines, _FORMAT_KEY
 23
 24from zanj import ZANJ, register_loader_handler
 25from zanj.loading import LoaderHandler, load_item_recursive
 26
 27# pylint: disable=protected-access
 28
 29KWArgs = Any
 30
 31
 32def num_params(m: torch.nn.Module, only_trainable: bool = True):
 33    """return total number of parameters in a model
 34
 35    - only counting shared parameters once
 36    - if `only_trainable` is False, will include parameters with `requires_grad = False`
 37
 38    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
 39    """
 40    parameters: list[torch.nn.Parameter] = list(m.parameters())
 41    if only_trainable:
 42        parameters = [p for p in parameters if p.requires_grad]
 43
 44    unique: list[torch.nn.Parameter] = list(
 45        {p.data_ptr(): p for p in parameters}.values()
 46    )
 47
 48    return sum(p.numel() for p in unique)
 49
 50
 51def get_module_device(
 52    m: torch.nn.Module,
 53) -> tuple[bool, torch.device | dict[str, torch.device]]:
 54    """get the current devices"""
 55
 56    devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()}
 57
 58    if len(devs) == 0:
 59        return False, devs
 60
 61    # check if all devices are the same by getting one device
 62    dev_uni: torch.device = next(iter(devs.values()))
 63
 64    if all(dev == dev_uni for dev in devs.values()):
 65        return True, dev_uni
 66    else:
 67        return False, devs
 68
 69
 70T_config = TypeVar("T_config", bound=SerializableDataclass)
 71
 72
 73class ConfiguredModel(
 74    torch.nn.Module,
 75    typing.Generic[T_config],
 76    metaclass=abc.ABCMeta,
 77):
 78    """a model that has a configuration, for saving with ZANJ
 79
 80    ```python
 81    @set_config_class(YourConfig)
 82    class YourModule(ConfiguredModel[YourConfig]):
 83        def __init__(self, cfg: YourConfig):
 84            super().__init__(cfg)
 85    ```
 86
 87    `__init__()` must initialize the model from a config object only, and call
 88    `super().__init__(zanj_model_config)`
 89
 90    If you are inheriting from another class + ConfiguredModel,
 91    ConfiguredModel must be the first class in the inheritance list
 92    """
 93
 94    # dont set this directly, use `set_config_class()` decorator
 95    _config_class: type | None = None
 96    zanj_config_class = property(lambda self: type(self)._config_class)
 97
 98    def __init__(self, zanj_model_config: T_config, **kwargs):
 99        super().__init__(**kwargs)
100        if self.zanj_config_class is None:
101            raise NotImplementedError("you need to set `config_class` for your model")
102        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
103            raise TypeError(
104                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
105            )
106
107        self.zanj_model_config: T_config = zanj_model_config
108        self.training_records: dict | None = None
109
110    def serialize(
111        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
112    ) -> dict[str, Any]:
113        if zanj is None:
114            zanj = ZANJ()
115        obj = dict(
116            zanj_model_config=self.zanj_model_config.serialize(),
117            meta=dict(
118                class_name=self.__class__.__name__,
119                class_doc=string_as_lines(self.__class__.__doc__),
120                class_source=safe_getsource(self.__class__),
121                module_name=self.__class__.__module__,
122                module_mro=[str(x) for x in self.__class__.__mro__],
123                num_params=num_params(self),
124                as_str=string_as_lines(str(self)),
125            ),
126            training_records=self.training_records,
127            state_dict=self.state_dict(),
128            __muutils_format__=self.__class__.__name__,
129        )
130        return obj
131
132    def save(self, file_path: str, zanj: ZANJ | None = None):
133        if zanj is None:
134            zanj = ZANJ()
135        zanj.save(self.serialize(), file_path)
136
137    def _load_state_dict_wrapper(
138        self,
139        state_dict: dict[str, torch.Tensor],
140        **kwargs,
141    ):
142        """wrapper for `load_state_dict()` in case you need to override it"""
143        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
144        return self.load_state_dict(state_dict)
145
146    @classmethod
147    def load(
148        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
149    ) -> "ConfiguredModel":
150        """load a model from a serialized object"""
151
152        if zanj is None:
153            zanj = ZANJ()
154
155        # get the config
156        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore
157
158        # get the training records
159        training_records: typing.Any = load_item_recursive(
160            obj.get("training_records", None),
161            tuple(path) + ("training_records",),
162            zanj,
163        )
164
165        # initialize the model
166        model: "ConfiguredModel" = cls(zanj_model_config)
167
168        # load the state dict
169        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
170            obj["state_dict"],
171            tuple(path) + ("state_dict",),
172            zanj,
173        )
174
175        model._load_state_dict_wrapper(
176            tensored_state_dict,
177            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
178        )
179
180        # set the training records
181        model.training_records = training_records
182
183        return model
184
185    @classmethod
186    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
187        """read a model from a file"""
188        if zanj is None:
189            zanj = ZANJ()
190
191        mdl: ConfiguredModel = zanj.read(file_path)
192        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
193        return mdl
194
195    @classmethod
196    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
197        """read a model from a file"""
198        warnings.warn(
199            "load_file() is deprecated, use read() instead", DeprecationWarning
200        )
201        return cls.read(file_path, zanj)
202
203    @classmethod
204    def get_handler(cls) -> LoaderHandler:
205        cls_name: str = str(cls.__name__)
206        return LoaderHandler(
207            check=lambda json_item, path=None, z=None: (  # type: ignore
208                isinstance(json_item, dict)
209                and _FORMAT_KEY in json_item
210                and json_item[_FORMAT_KEY].startswith(cls_name)
211            ),
212            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
213            uid=cls_name,
214            source_pckg=cls.__module__,
215            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
216        )
217
218    def num_params(self) -> int:
219        return num_params(self)
220
221
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator that ties `config_class` to a `ConfiguredModel` subclass

    the returned decorator stores `config_class` on the decorated class as
    `_config_class`, registers the class's loader handler, and returns the
    same class

    raises `TypeError` if `config_class` is not a subclass of `SerializableDataclass`
    """
    if not issubclass(config_class, SerializableDataclass):
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def bind(model_cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # the config class is stored before the handler is built and registered
        model_cls._config_class = config_class
        handler = model_cls.get_handler()
        register_loader_handler(handler)
        return model_cls

    return bind
239
240
class ConfigMismatchException(ValueError):
    """raised when two configs that were expected to match differ

    the difference is kept on the `diff` attribute and is appended to the
    message when the exception is converted to a string
    """

    def __init__(self, msg: str, diff):
        # only `msg` ends up in `args`, `diff` is a separate attribute
        super().__init__(msg)
        self.diff = diff

    def __reduce__(self):
        """support `pickle` and `copy`

        the default reduction re-creates the exception from `args` alone,
        which fails here because `diff` is a required constructor argument
        """
        msg = self.args[0] if self.args else ""
        return type(self), (msg, self.diff), dict(self.__dict__)

    def __str__(self):
        return f"{super().__str__()}: {self.diff}"
248
249
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """assert both arguments are configured models and compare their configs

    each model must be a `ConfiguredModel` whose `zanj_model_config` is a
    `SerializableDataclass`, otherwise an `AssertionError` is raised

    Raises:
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # checked in order: model_a, its config, model_b, its config
    for label, model in (("model_a", model_a), ("model_b", model_b)):
        assert isinstance(model, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(model.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config

    if cfg_a == cfg_b:
        return

    # the diff is computed by the config class of `model_a`
    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=type(cfg_a).diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )
272
273
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    first checks the configs with `assert_model_cfg_equality()`, then asserts
    that both state dicts have the same keys and that every pair of entries
    has the same shape and elementwise-equal values

    prints a `passed`/`failed` line per state dict key
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once instead of once per key
    sd_a = model_a.state_dict()
    sd_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # compare shapes first: `==` broadcasts, so tensors of different
        # shapes could otherwise compare equal or raise instead of failing
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )

KWArgs = typing.Any
def num_params(m: torch.nn.modules.module.Module, only_trainable: bool = True):
def num_params(m: torch.nn.Module, only_trainable: bool = True):
    """count the scalar parameters of a module

    - parameters are keyed by `data_ptr()`, so parameters reporting the same
      pointer are only counted once
    - by default only parameters with `requires_grad = True` are counted,
      pass `only_trainable=False` to count the others as well

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    by_pointer: dict[int, torch.nn.Parameter] = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        # a later parameter with the same pointer replaces the earlier one
        by_pointer[param.data_ptr()] = param

    return sum(param.numel() for param in by_pointer.values())

return total number of parameters in a model

  • only counting shared parameters once
  • if only_trainable is False, will include parameters with requires_grad = False

https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model

def get_module_device( m: torch.nn.modules.module.Module) -> tuple[bool, torch.device | dict[str, torch.device]]:
def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """report which device(s) the parameters of `m` are on

    returns a pair `(is_single_device, devices)`:

    - `(True, device)` when every parameter is on the same `device`
    - `(False, {param_name: device})` when the parameters are on different
      devices, or when the module has no parameters (the dict is then empty)
    """
    per_param: dict[str, torch.device] = {}
    for name, param in m.named_parameters():
        per_param[name] = param.device

    # no parameters: there is no single device to report
    if not per_param:
        return False, per_param

    # compare every device against the first one found
    reference: torch.device = next(iter(per_param.values()))
    for device in per_param.values():
        if not (device == reference):
            return False, per_param
    return True, reference

get the current devices

class ConfiguredModel(torch.nn.modules.module.Module, typing.Generic[~T_config]):
 74class ConfiguredModel(
 75    torch.nn.Module,
 76    typing.Generic[T_config],
 77    metaclass=abc.ABCMeta,
 78):
 79    """a model that has a configuration, for saving with ZANJ
 80
 81    ```python
 82    @set_config_class(YourConfig)
 83    class YourModule(ConfiguredModel[YourConfig]):
 84        def __init__(self, cfg: YourConfig):
 85            super().__init__(cfg)
 86    ```
 87
 88    `__init__()` must initialize the model from a config object only, and call
 89    `super().__init__(zanj_model_config)`
 90
 91    If you are inheriting from another class + ConfiguredModel,
 92    ConfiguredModel must be the first class in the inheritance list
 93    """
 94
 95    # dont set this directly, use `set_config_class()` decorator
 96    _config_class: type | None = None
 97    zanj_config_class = property(lambda self: type(self)._config_class)
 98
 99    def __init__(self, zanj_model_config: T_config, **kwargs):
100        super().__init__(**kwargs)
101        if self.zanj_config_class is None:
102            raise NotImplementedError("you need to set `config_class` for your model")
103        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
104            raise TypeError(
105                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
106            )
107
108        self.zanj_model_config: T_config = zanj_model_config
109        self.training_records: dict | None = None
110
111    def serialize(
112        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
113    ) -> dict[str, Any]:
114        if zanj is None:
115            zanj = ZANJ()
116        obj = dict(
117            zanj_model_config=self.zanj_model_config.serialize(),
118            meta=dict(
119                class_name=self.__class__.__name__,
120                class_doc=string_as_lines(self.__class__.__doc__),
121                class_source=safe_getsource(self.__class__),
122                module_name=self.__class__.__module__,
123                module_mro=[str(x) for x in self.__class__.__mro__],
124                num_params=num_params(self),
125                as_str=string_as_lines(str(self)),
126            ),
127            training_records=self.training_records,
128            state_dict=self.state_dict(),
129            __muutils_format__=self.__class__.__name__,
130        )
131        return obj
132
133    def save(self, file_path: str, zanj: ZANJ | None = None):
134        if zanj is None:
135            zanj = ZANJ()
136        zanj.save(self.serialize(), file_path)
137
138    def _load_state_dict_wrapper(
139        self,
140        state_dict: dict[str, torch.Tensor],
141        **kwargs,
142    ):
143        """wrapper for `load_state_dict()` in case you need to override it"""
144        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
145        return self.load_state_dict(state_dict)
146
147    @classmethod
148    def load(
149        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
150    ) -> "ConfiguredModel":
151        """load a model from a serialized object"""
152
153        if zanj is None:
154            zanj = ZANJ()
155
156        # get the config
157        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore
158
159        # get the training records
160        training_records: typing.Any = load_item_recursive(
161            obj.get("training_records", None),
162            tuple(path) + ("training_records",),
163            zanj,
164        )
165
166        # initialize the model
167        model: "ConfiguredModel" = cls(zanj_model_config)
168
169        # load the state dict
170        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
171            obj["state_dict"],
172            tuple(path) + ("state_dict",),
173            zanj,
174        )
175
176        model._load_state_dict_wrapper(
177            tensored_state_dict,
178            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
179        )
180
181        # set the training records
182        model.training_records = training_records
183
184        return model
185
186    @classmethod
187    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
188        """read a model from a file"""
189        if zanj is None:
190            zanj = ZANJ()
191
192        mdl: ConfiguredModel = zanj.read(file_path)
193        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
194        return mdl
195
196    @classmethod
197    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
198        """read a model from a file"""
199        warnings.warn(
200            "load_file() is deprecated, use read() instead", DeprecationWarning
201        )
202        return cls.read(file_path, zanj)
203
204    @classmethod
205    def get_handler(cls) -> LoaderHandler:
206        cls_name: str = str(cls.__name__)
207        return LoaderHandler(
208            check=lambda json_item, path=None, z=None: (  # type: ignore
209                isinstance(json_item, dict)
210                and _FORMAT_KEY in json_item
211                and json_item[_FORMAT_KEY].startswith(cls_name)
212            ),
213            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
214            uid=cls_name,
215            source_pckg=cls.__module__,
216            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
217        )
218
219    def num_params(self) -> int:
220        return num_params(self)

a model that has a configuration, for saving with ZANJ

@set_config_class(YourConfig)
class YourModule(ConfiguredModel[YourConfig]):
    def __init__(self, cfg: YourConfig):
        super().__init__(cfg)

__init__() must initialize the model from a config object only, and call super().__init__(zanj_model_config)

If you are inheriting from another class + ConfiguredModel, ConfiguredModel must be the first class in the inheritance list

zanj_config_class
97    zanj_config_class = property(lambda self: type(self)._config_class)
zanj_model_config: ~T_config
training_records: dict | None
def serialize( self, path: tuple[typing.Union[str, int], ...] = (), zanj: zanj.ZANJ | None = None) -> dict[str, typing.Any]:
111    def serialize(
112        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
113    ) -> dict[str, Any]:
114        if zanj is None:
115            zanj = ZANJ()
116        obj = dict(
117            zanj_model_config=self.zanj_model_config.serialize(),
118            meta=dict(
119                class_name=self.__class__.__name__,
120                class_doc=string_as_lines(self.__class__.__doc__),
121                class_source=safe_getsource(self.__class__),
122                module_name=self.__class__.__module__,
123                module_mro=[str(x) for x in self.__class__.__mro__],
124                num_params=num_params(self),
125                as_str=string_as_lines(str(self)),
126            ),
127            training_records=self.training_records,
128            state_dict=self.state_dict(),
129            __muutils_format__=self.__class__.__name__,
130        )
131        return obj
def save(self, file_path: str, zanj: zanj.ZANJ | None = None):
133    def save(self, file_path: str, zanj: ZANJ | None = None):
134        if zanj is None:
135            zanj = ZANJ()
136        zanj.save(self.serialize(), file_path)
@classmethod
def load( cls, obj: dict[str, typing.Any], path: tuple[typing.Union[str, int], ...], zanj: zanj.ZANJ | None = None) -> ConfiguredModel:
147    @classmethod
148    def load(
149        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
150    ) -> "ConfiguredModel":
151        """load a model from a serialized object"""
152
153        if zanj is None:
154            zanj = ZANJ()
155
156        # get the config
157        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore
158
159        # get the training records
160        training_records: typing.Any = load_item_recursive(
161            obj.get("training_records", None),
162            tuple(path) + ("training_records",),
163            zanj,
164        )
165
166        # initialize the model
167        model: "ConfiguredModel" = cls(zanj_model_config)
168
169        # load the state dict
170        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
171            obj["state_dict"],
172            tuple(path) + ("state_dict",),
173            zanj,
174        )
175
176        model._load_state_dict_wrapper(
177            tensored_state_dict,
178            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
179        )
180
181        # set the training records
182        model.training_records = training_records
183
184        return model

load a model from a serialized object

@classmethod
def read( cls, file_path: str, zanj: zanj.ZANJ | None = None) -> ConfiguredModel:
186    @classmethod
187    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
188        """read a model from a file"""
189        if zanj is None:
190            zanj = ZANJ()
191
192        mdl: ConfiguredModel = zanj.read(file_path)
193        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
194        return mdl

read a model from a file

@classmethod
def load_file( cls, file_path: str, zanj: zanj.ZANJ | None = None) -> ConfiguredModel:
196    @classmethod
197    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
198        """read a model from a file"""
199        warnings.warn(
200            "load_file() is deprecated, use read() instead", DeprecationWarning
201        )
202        return cls.read(file_path, zanj)

read a model from a file (deprecated: emits a DeprecationWarning and calls read(); use read() instead)

@classmethod
def get_handler(cls) -> zanj.loading.LoaderHandler:
204    @classmethod
205    def get_handler(cls) -> LoaderHandler:
206        cls_name: str = str(cls.__name__)
207        return LoaderHandler(
208            check=lambda json_item, path=None, z=None: (  # type: ignore
209                isinstance(json_item, dict)
210                and _FORMAT_KEY in json_item
211                and json_item[_FORMAT_KEY].startswith(cls_name)
212            ),
213            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
214            uid=cls_name,
215            source_pckg=cls.__module__,
216            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
217        )
def num_params(self) -> int:
219    def num_params(self) -> int:
220        return num_params(self)
Inherited Members
torch.nn.modules.module.Module
Module
dump_patches
training
call_super_init
forward
register_buffer
register_parameter
add_module
register_module
get_submodule
set_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
mtia
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_post_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_pre_hook
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
def set_config_class( config_class: Type[muutils.json_serialize.serializable_dataclass.SerializableDataclass]) -> Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator that ties `config_class` to a `ConfiguredModel` subclass

    the returned decorator stores `config_class` on the decorated class as
    `_config_class`, registers the class's loader handler, and returns the
    same class

    raises `TypeError` if `config_class` is not a subclass of `SerializableDataclass`
    """
    if not issubclass(config_class, SerializableDataclass):
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def bind(model_cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # the config class is stored before the handler is built and registered
        model_cls._config_class = config_class
        handler = model_cls.get_handler()
        register_loader_handler(handler)
        return model_cls

    return bind
class ConfigMismatchException(builtins.ValueError):
class ConfigMismatchException(ValueError):
    """raised when two configs that were expected to match differ

    the difference is kept on the `diff` attribute and is appended to the
    message when the exception is converted to a string
    """

    def __init__(self, msg: str, diff):
        # only `msg` ends up in `args`, `diff` is a separate attribute
        super().__init__(msg)
        self.diff = diff

    def __reduce__(self):
        """support `pickle` and `copy`

        the default reduction re-creates the exception from `args` alone,
        which fails here because `diff` is a required constructor argument
        """
        msg = self.args[0] if self.args else ""
        return type(self), (msg, self.diff), dict(self.__dict__)

    def __str__(self):
        return f"{super().__str__()}: {self.diff}"

Inappropriate argument value (of correct type).

ConfigMismatchException(msg: str, diff)
243    def __init__(self, msg: str, diff):
244        super().__init__(msg)
245        self.diff = diff
diff
Inherited Members
builtins.BaseException
with_traceback
add_note
args
def assert_model_cfg_equality( model_a: ConfiguredModel, model_b: ConfiguredModel):
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """assert both arguments are configured models and compare their configs

    each model must be a `ConfiguredModel` whose `zanj_model_config` is a
    `SerializableDataclass`, otherwise an `AssertionError` is raised

    Raises:
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # checked in order: model_a, its config, model_b, its config
    for label, model in (("model_a", model_a), ("model_b", model_b)):
        assert isinstance(model, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(model.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config

    if cfg_a == cfg_b:
        return

    # the diff is computed by the config class of `model_a`
    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=type(cfg_a).diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )

check both models are correct instances and have the same config

Raises: ConfigMismatchException: if the configs don't match, e.diff will contain the diff

def assert_model_exact_equality( model_a: ConfiguredModel, model_b: ConfiguredModel):
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    first checks the configs with `assert_model_cfg_equality()`, then asserts
    that both state dicts have the same keys and that every pair of entries
    has the same shape and elementwise-equal values

    prints a `passed`/`failed` line per state dict key
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once instead of once per key
    sd_a = model_a.state_dict()
    sd_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # compare shapes first: `==` broadcasts, so tensors of different
        # shapes could otherwise compare equal or raise instead of failing
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )

check the models are exactly equal, including state dict contents