zanj.torchutil
torch utilities for zanj -- in particular the ConfiguredModel base class
note that this requires torch
"""torch utilities for zanj -- in particular the `ConfiguredModel` base class

note that this requires torch
"""

from __future__ import annotations

import abc
import typing
import warnings
from typing import Any, Type, TypeVar

try:
    import torch  # type: ignore[import-not-found]
except ImportError as e:
    raise ImportError(
        "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`"
    ) from e

from muutils.json_serialize import SerializableDataclass
from muutils.json_serialize.json_serialize import ObjectPath

from zanj.consts import _FORMAT_KEY, safe_getsource, string_as_lines

from zanj import ZANJ, register_loader_handler
from zanj.loading import LoaderHandler, load_item_recursive

# pylint: disable=protected-access

KWArgs = Any


def num_params(m: torch.nn.Module, only_trainable: bool = True):
    """return total number of parameters in a model

    - only counting shared parameters once (deduplicated by `data_ptr()`)
    - if `only_trainable` is False, will include parameters with `requires_grad = False`

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    parameters: list[torch.nn.Parameter] = list(m.parameters())
    if only_trainable:
        parameters = [p for p in parameters if p.requires_grad]

    # parameters sharing the same underlying memory collapse to one entry
    unique: list[torch.nn.Parameter] = list(
        {p.data_ptr(): p for p in parameters}.values()
    )

    return sum(p.numel() for p in unique)


def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """get the device(s) the parameters of `m` live on

    # Returns:
     - `(True, device)` if every parameter is on the same device
     - `(False, {name: device})` if parameters are on different devices,
       or `(False, {})` if `m` has no parameters
    """

    devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()}

    if len(devs) == 0:
        return False, devs

    # check if all devices are the same by getting one device
    dev_uni: torch.device = next(iter(devs.values()))

    if all(dev == dev_uni for dev in devs.values()):
        return True, dev_uni
    else:
        return False, devs


T_config = TypeVar("T_config", bound=SerializableDataclass)


class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only view of the config class attached to the concrete subclass
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store the config after checking it against the registered config class

        extra `kwargs` are forwarded to the next `__init__` in the MRO.

        # Raises:
         - `NotImplementedError` : if no config class was set via `set_config_class()`
         - `TypeError` : if `zanj_model_config` is not an instance of the config class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        # arbitrary training metadata, saved by `serialize()` and restored by `load()`
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model into a dict for ZANJ to save

        the result holds the serialized config, a `meta` dict of descriptive
        info, `training_records`, the raw `state_dict()`, and the
        `__muutils_format__` key (the class name) that `get_handler()` matches on.
        `path` is unused here.
        """
        if zanj is None:
            zanj = ZANJ()
        obj = dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            __muutils_format__=self.__class__.__name__,
        )
        return obj

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """serialize the model and write it to `file_path` using `zanj` (default: a new `ZANJ()`)"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        `kwargs` come from `zanj.custom_settings["_load_state_dict_wrapper"]`;
        this default implementation accepts none.
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        rebuilds the config, constructs `cls(config)`, loads the state dict via
        `_load_state_dict_wrapper()` and restores `training_records`.
        """

        if zanj is None:
            zanj = ZANJ()

        # the loader lambda from `get_handler()` defaults `path` to None;
        # treat that as the root path instead of failing in `tuple(None)`
        base_path: tuple = tuple() if path is None else tuple(path)

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file, asserting the result is an instance of `cls`"""
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file (deprecated alias for `read()`)"""
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            # attribute the warning to the caller of `load_file()`
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build a `LoaderHandler` that loads items serialized by this class

        an item matches when it is a dict whose `_FORMAT_KEY` value is exactly
        this class's name, or the name followed by `(`. A bare prefix match
        would let a class `Model` claim items written by `ModelLarge`.
        """
        cls_name: str = str(cls.__name__)

        def _check(json_item, path=None, z=None) -> bool:
            if not isinstance(json_item, dict) or _FORMAT_KEY not in json_item:
                return False
            fmt = json_item[_FORMAT_KEY]
            return isinstance(fmt, str) and (
                fmt == cls_name or fmt.startswith(cls_name + "(")
            )

        return LoaderHandler(
            check=_check,  # type: ignore
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self) -> int:
        """number of trainable parameters, shared parameters counted once"""
        return num_params(self)


def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator attaching `config_class` to a `ConfiguredModel` subclass
    and registering its loader handler with ZANJ

    # Raises:
     - `TypeError` : if `config_class` is not a subclass of `SerializableDataclass`
    """
    if not issubclass(config_class, SerializableDataclass):
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # set the config class
        cls._config_class = config_class

        # register the handlers
        register_loader_handler(cls.get_handler())

        # return the new class
        return cls

    return wrapper


class ConfigMismatchException(ValueError):
    """raised when two model configs do not match; `diff` holds the difference"""

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        self.diff = diff

    def __str__(self):
        return f"{super().__str__()}: {self.diff}"


def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check both models are correct instances and have the same config

    Raises:
        AssertionError: if either model (or its config) has the wrong type
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    assert isinstance(model_a, ConfiguredModel), "model_a must be a ConfiguredModel"
    assert isinstance(model_a.zanj_model_config, SerializableDataclass), (
        "model_a must have a zanj_model_config"
    )
    assert isinstance(model_b, ConfiguredModel), "model_b must be a ConfiguredModel"
    assert isinstance(model_b.zanj_model_config, SerializableDataclass), (
        "model_b must have a zanj_model_config"
    )

    cls_type: type = type(model_a.zanj_model_config)

    if not (model_a.zanj_model_config == model_b.zanj_model_config):
        raise ConfigMismatchException(
            f"configs of type {type(model_a.zanj_model_config)}, {type(model_b.zanj_model_config)} don't match",
            diff=cls_type.diff(model_a.zanj_model_config, model_b.zanj_model_config),  # type: ignore[attr-defined]
        )


def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    checks configs via `assert_model_cfg_equality`, then requires identical
    state dict keys and, per key, identical shape and elementwise-equal values.
    Prints `passed {k}` / `failed {k}` for each key.

    Raises:
        ConfigMismatchException: if the configs differ
        AssertionError: if the state dict keys or any tensor differ
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once; rebuilding it per key would be quadratic
    sd_a: dict[str, torch.Tensor] = model_a.state_dict()
    sd_b: dict[str, torch.Tensor] = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # `==` broadcasts, so differently-shaped tensors could compare equal
        # (or raise); require matching shapes first
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )
def num_params(m: torch.nn.Module, only_trainable: bool = True):
    """count the parameters of `m`, counting shared storage only once

    - parameters are deduplicated by `data_ptr()`, so tensors sharing the same
      memory contribute a single time
    - when `only_trainable` is True (the default), parameters with
      `requires_grad = False` are skipped

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    # keyed on the data pointer so tied weights collapse to a single entry;
    # a later parameter with the same pointer replaces the earlier one
    by_ptr: dict[int, torch.nn.Parameter] = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        by_ptr[param.data_ptr()] = param

    total = 0
    for param in by_ptr.values():
        total += param.numel()
    return total
return total number of parameters in a model
- only counting shared parameters once
- if `only_trainable` is False, will include parameters with `requires_grad = False`
https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """report which device(s) the parameters of `m` live on

    # Returns:
     - `(True, device)` if every parameter is on the same `device`
     - `(False, {name: device, ...})` if the parameters are on different
       devices, or `(False, {})` if `m` has no parameters
    """
    devices: dict[str, torch.device] = dict(
        (name, param.device) for name, param in m.named_parameters()
    )
    if not devices:
        return False, devices

    # compare everything against the first parameter's device
    first_device: torch.device = next(iter(devices.values()))
    uniform: bool = all(dev == first_device for dev in devices.values())
    return (True, first_device) if uniform else (False, devices)
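get the current device(s) of the module's parameters: returns (True, device) if all parameters share one device, otherwise (False, {name: device}) — including (False, {}) for a module with no parameters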
class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only view of the config class attached to the concrete subclass
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store the config after checking it against the registered config class

        extra `kwargs` are forwarded to the next `__init__` in the MRO.

        # Raises:
         - `NotImplementedError` : if no config class was set via `set_config_class()`
         - `TypeError` : if `zanj_model_config` is not an instance of the config class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        # arbitrary training metadata, saved by `serialize()` and restored by `load()`
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model into a dict for ZANJ to save

        the result holds the serialized config, a `meta` dict of descriptive
        info, `training_records`, the raw `state_dict()`, and the
        `__muutils_format__` key (the class name) that `get_handler()` matches on.
        `path` is unused here.
        """
        if zanj is None:
            zanj = ZANJ()
        obj = dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            __muutils_format__=self.__class__.__name__,
        )
        return obj

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """serialize the model and write it to `file_path` using `zanj` (default: a new `ZANJ()`)"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        `kwargs` come from `zanj.custom_settings["_load_state_dict_wrapper"]`;
        this default implementation accepts none.
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        rebuilds the config, constructs `cls(config)`, loads the state dict via
        `_load_state_dict_wrapper()` and restores `training_records`.
        """

        if zanj is None:
            zanj = ZANJ()

        # the loader lambda from `get_handler()` defaults `path` to None;
        # treat that as the root path instead of failing in `tuple(None)`
        base_path: tuple = tuple() if path is None else tuple(path)

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file, asserting the result is an instance of `cls`"""
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file (deprecated alias for `read()`)"""
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            # attribute the warning to the caller of `load_file()`
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build a `LoaderHandler` that loads items serialized by this class

        an item matches when it is a dict whose `_FORMAT_KEY` value is exactly
        this class's name, or the name followed by `(`. A bare prefix match
        would let a class `Model` claim items written by `ModelLarge`.
        """
        cls_name: str = str(cls.__name__)

        def _check(json_item, path=None, z=None) -> bool:
            if not isinstance(json_item, dict) or _FORMAT_KEY not in json_item:
                return False
            fmt = json_item[_FORMAT_KEY]
            return isinstance(fmt, str) and (
                fmt == cls_name or fmt.startswith(cls_name + "(")
            )

        return LoaderHandler(
            check=_check,  # type: ignore
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self) -> int:
        """number of trainable parameters, shared parameters counted once"""
        return num_params(self)
a model that has a configuration, for saving with ZANJ
@set_config_class(YourConfig)
class YourModule(ConfiguredModel[YourConfig]):
def __init__(self, cfg: YourConfig):
super().__init__(cfg)
`__init__()` must initialize the model from a config object only, and call
`super().__init__(zanj_model_config)`
If you are inheriting from another class + ConfiguredModel, ConfiguredModel must be the first class in the inheritance list
def serialize(
    self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
) -> dict[str, Any]:
    """turn the model into a plain dict that ZANJ can write out

    keys, in order: the serialized `zanj_model_config`; a `meta` dict
    describing the class (name, doc, source, module, MRO), its parameter count
    and `str()` form; `training_records`; the raw `state_dict()`; and
    `__muutils_format__`, set to the class name. `path` is not used.
    """
    if zanj is None:
        zanj = ZANJ()
    klass = self.__class__
    config_serialized = self.zanj_model_config.serialize()
    meta: dict[str, Any] = {
        "class_name": klass.__name__,
        "class_doc": string_as_lines(klass.__doc__),
        "class_source": safe_getsource(klass),
        "module_name": klass.__module__,
        "module_mro": [str(base) for base in klass.__mro__],
        "num_params": num_params(self),
        "as_str": string_as_lines(str(self)),
    }
    return {
        "zanj_model_config": config_serialized,
        "meta": meta,
        "training_records": self.training_records,
        "state_dict": self.state_dict(),
        "__muutils_format__": klass.__name__,
    }
@classmethod
def load(
    cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
) -> "ConfiguredModel":
    """load a model from a serialized object

    rebuilds the config, constructs `cls(config)`, loads the state dict via
    `_load_state_dict_wrapper()` (with kwargs from
    `zanj.custom_settings["_load_state_dict_wrapper"]`) and restores
    `training_records`.
    """

    if zanj is None:
        zanj = ZANJ()

    # the loader lambda built by `get_handler()` defaults `path` to None;
    # treat that as the root path instead of failing in `tuple(None)`
    base_path: tuple = tuple() if path is None else tuple(path)

    # get the config
    zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

    # get the training records
    training_records: typing.Any = load_item_recursive(
        obj.get("training_records", None),
        base_path + ("training_records",),
        zanj,
    )

    # initialize the model
    model: "ConfiguredModel" = cls(zanj_model_config)

    # load the state dict
    tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
        obj["state_dict"],
        base_path + ("state_dict",),
        zanj,
    )

    model._load_state_dict_wrapper(
        tensored_state_dict,
        **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
    )

    # set the training records
    model.training_records = training_records

    return model
load a model from a serialized object
@classmethod
def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
    """read a model from a file

    reads `file_path` with the given `zanj` (or a fresh `ZANJ()`) and asserts
    that the loaded object is an instance of `cls`.
    """
    reader: ZANJ = ZANJ() if zanj is None else zanj
    loaded: ConfiguredModel = reader.read(file_path)
    assert isinstance(loaded, cls), f"loaded object must be a {cls}, got {type(loaded)}"
    return loaded
read a model from a file
@classmethod
def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
    """read a model from a file

    deprecated alias for `read()`: emits a `DeprecationWarning`, then forwards
    both arguments to `cls.read`.
    """
    warnings.warn(
        "load_file() is deprecated, use read() instead",
        DeprecationWarning,
        # attribute the warning to the caller of `load_file()`, not this line
        stacklevel=2,
    )
    return cls.read(file_path, zanj)
read a model from a file (deprecated: use `read()` instead)
@classmethod
def get_handler(cls) -> LoaderHandler:
    """build a `LoaderHandler` that loads items serialized by this class

    an item matches when it is a dict whose `_FORMAT_KEY` value is exactly
    this class's name, or the name followed by `(`. A bare prefix match would
    let a class `Model` claim items written by `ModelLarge`. Matching items
    are loaded with `cls.load`.
    """
    cls_name: str = str(cls.__name__)

    def _check(json_item, path=None, z=None) -> bool:
        if not isinstance(json_item, dict) or _FORMAT_KEY not in json_item:
            return False
        fmt = json_item[_FORMAT_KEY]
        # guard the type: a non-str format value previously raised AttributeError
        return isinstance(fmt, str) and (
            fmt == cls_name or fmt.startswith(cls_name + "(")
        )

    return LoaderHandler(
        check=_check,  # type: ignore
        load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
        uid=cls_name,
        source_pckg=cls.__module__,
        desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
    )
Inherited Members
- torch.nn.modules.module.Module
- Module
- dump_patches
- training
- call_super_init
- forward
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- set_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- mtia
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_post_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_pre_hook
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator attaching `config_class` to a `ConfiguredModel` subclass

    the decorated class gets `_config_class = config_class`, and the handler
    returned by its `get_handler()` is passed to `register_loader_handler` so
    ZANJ can load serialized instances. The class itself is returned unchanged.

    Raises:
        TypeError: if `config_class` is not a subclass of `SerializableDataclass`
    """
    if issubclass(config_class, SerializableDataclass):

        def decorate(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
            cls._config_class = config_class
            register_loader_handler(cls.get_handler())
            return cls

        return decorate

    raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")
class ConfigMismatchException(ValueError):
    """raised when two model configs do not match

    the message is passed to `ValueError` as usual, while `diff` keeps the
    difference object so callers can inspect it; `str()` appends the diff
    to the message as `"{msg}: {diff}"`.
    """

    def __init__(self, msg: str, diff) -> None:
        self.diff = diff
        super().__init__(msg)

    def __str__(self) -> str:
        base_message: str = super().__str__()
        return f"{base_message}: {self.diff}"
Raised when two model configs do not match; the `diff` attribute holds the difference between them.
Inherited Members
- builtins.BaseException
- with_traceback
- add_note
- args
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check both models are correct instances and have the same config

    each model must be a `ConfiguredModel` whose `zanj_model_config` is a
    `SerializableDataclass`; the two configs are then compared with `==`.

    Raises:
        AssertionError: if either model (or its config) has the wrong type
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # same order as checking model_a fully, then model_b
    for label, mdl in (("model_a", model_a), ("model_b", model_b)):
        assert isinstance(mdl, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(mdl.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config
    if cfg_a == cfg_b:
        return

    config_type: type = type(cfg_a)
    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=config_type.diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )
check both models are correct instances and have the same config
Raises: ConfigMismatchException: if the configs don't match, e.diff will contain the diff
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    first checks the configs via `assert_model_cfg_equality`, then requires
    identical state dict keys and, for every key, an identical shape and
    elementwise-equal values. Prints `passed {k}` / `failed {k}` per key.

    Raises:
        ConfigMismatchException: if the configs differ
        AssertionError: if the state dict keys or any tensor differ
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once; rebuilding it per key would be quadratic
    sd_a: dict[str, torch.Tensor] = model_a.state_dict()
    sd_b: dict[str, torch.Tensor] = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # `==` broadcasts, so e.g. shapes (1,) and (3,) holding the same value
        # would compare equal (and incompatible shapes would raise); require
        # matching shapes before comparing values
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )
check the models are exactly equal, including state dict contents