zanj.torchutil
Torch utilities for ZANJ — in particular, the `ConfiguredModel` base class.
Note that this module requires `torch`.
1"""torch utilities for zanj -- in particular the `ConfiguredModel` base class 2 3note that this requires torch 4""" 5 6from __future__ import annotations 7 8import abc 9import typing 10import warnings 11from typing import Any, Type, TypeVar 12 13try: 14 import torch # type: ignore[import-not-found] 15except ImportError as e: 16 raise ImportError( 17 "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`" 18 ) from e 19 20from muutils.json_serialize import SerializableDataclass 21from muutils.json_serialize.json_serialize import ObjectPath 22from muutils.json_serialize.util import safe_getsource, string_as_lines, _FORMAT_KEY 23 24from zanj import ZANJ, register_loader_handler 25from zanj.loading import LoaderHandler, load_item_recursive 26 27# pylint: disable=protected-access 28 29KWArgs = Any 30 31 32def num_params(m: torch.nn.Module, only_trainable: bool = True): 33 """return total number of parameters in a model 34 35 - only counting shared parameters once 36 - if `only_trainable` is False, will include parameters with `requires_grad = False` 37 38 https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model 39 """ 40 parameters: list[torch.nn.Parameter] = list(m.parameters()) 41 if only_trainable: 42 parameters = [p for p in parameters if p.requires_grad] 43 44 unique: list[torch.nn.Parameter] = list( 45 {p.data_ptr(): p for p in parameters}.values() 46 ) 47 48 return sum(p.numel() for p in unique) 49 50 51def get_module_device( 52 m: torch.nn.Module, 53) -> tuple[bool, torch.device | dict[str, torch.device]]: 54 """get the current devices""" 55 56 devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()} 57 58 if len(devs) == 0: 59 return False, devs 60 61 # check if all devices are the same by getting one device 62 dev_uni: torch.device = next(iter(devs.values())) 63 64 if all(dev == dev_uni for dev in devs.values()): 65 return True, 
dev_uni 66 else: 67 return False, devs 68 69 70T_config = TypeVar("T_config", bound=SerializableDataclass) 71 72 73class ConfiguredModel( 74 torch.nn.Module, 75 typing.Generic[T_config], 76 metaclass=abc.ABCMeta, 77): 78 """a model that has a configuration, for saving with ZANJ 79 80 ```python 81 @set_config_class(YourConfig) 82 class YourModule(ConfiguredModel[YourConfig]): 83 def __init__(self, cfg: YourConfig): 84 super().__init__(cfg) 85 ``` 86 87 `__init__()` must initialize the model from a config object only, and call 88 `super().__init__(zanj_model_config)` 89 90 If you are inheriting from another class + ConfiguredModel, 91 ConfiguredModel must be the first class in the inheritance list 92 """ 93 94 # dont set this directly, use `set_config_class()` decorator 95 _config_class: type | None = None 96 zanj_config_class = property(lambda self: type(self)._config_class) 97 98 def __init__(self, zanj_model_config: T_config, **kwargs): 99 super().__init__(**kwargs) 100 if self.zanj_config_class is None: 101 raise NotImplementedError("you need to set `config_class` for your model") 102 if not isinstance(zanj_model_config, self.zanj_config_class): # type: ignore 103 raise TypeError( 104 f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }" 105 ) 106 107 self.zanj_model_config: T_config = zanj_model_config 108 self.training_records: dict | None = None 109 110 def serialize( 111 self, path: ObjectPath = tuple(), zanj: ZANJ | None = None 112 ) -> dict[str, Any]: 113 if zanj is None: 114 zanj = ZANJ() 115 obj = dict( 116 zanj_model_config=self.zanj_model_config.serialize(), 117 meta=dict( 118 class_name=self.__class__.__name__, 119 class_doc=string_as_lines(self.__class__.__doc__), 120 class_source=safe_getsource(self.__class__), 121 module_name=self.__class__.__module__, 122 module_mro=[str(x) for x in self.__class__.__mro__], 123 num_params=num_params(self), 124 as_str=string_as_lines(str(self)), 125 ), 126 
training_records=self.training_records, 127 state_dict=self.state_dict(), 128 __muutils_format__=self.__class__.__name__, 129 ) 130 return obj 131 132 def save(self, file_path: str, zanj: ZANJ | None = None): 133 if zanj is None: 134 zanj = ZANJ() 135 zanj.save(self.serialize(), file_path) 136 137 def _load_state_dict_wrapper( 138 self, 139 state_dict: dict[str, torch.Tensor], 140 **kwargs, 141 ): 142 """wrapper for `load_state_dict()` in case you need to override it""" 143 assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}" 144 return self.load_state_dict(state_dict) 145 146 @classmethod 147 def load( 148 cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None 149 ) -> "ConfiguredModel": 150 """load a model from a serialized object""" 151 152 if zanj is None: 153 zanj = ZANJ() 154 155 # get the config 156 zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"]) # type: ignore 157 158 # get the training records 159 training_records: typing.Any = load_item_recursive( 160 obj.get("training_records", None), 161 tuple(path) + ("training_records",), 162 zanj, 163 ) 164 165 # initialize the model 166 model: "ConfiguredModel" = cls(zanj_model_config) 167 168 # load the state dict 169 tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive( 170 obj["state_dict"], 171 tuple(path) + ("state_dict",), 172 zanj, 173 ) 174 175 model._load_state_dict_wrapper( 176 tensored_state_dict, 177 **zanj.custom_settings.get("_load_state_dict_wrapper", dict()), 178 ) 179 180 # set the training records 181 model.training_records = training_records 182 183 return model 184 185 @classmethod 186 def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel": 187 """read a model from a file""" 188 if zanj is None: 189 zanj = ZANJ() 190 191 mdl: ConfiguredModel = zanj.read(file_path) 192 assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}" 193 return mdl 194 195 @classmethod 196 def load_file(cls, 
file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel": 197 """read a model from a file""" 198 warnings.warn( 199 "load_file() is deprecated, use read() instead", DeprecationWarning 200 ) 201 return cls.read(file_path, zanj) 202 203 @classmethod 204 def get_handler(cls) -> LoaderHandler: 205 cls_name: str = str(cls.__name__) 206 return LoaderHandler( 207 check=lambda json_item, path=None, z=None: ( # type: ignore 208 isinstance(json_item, dict) 209 and _FORMAT_KEY in json_item 210 and json_item[_FORMAT_KEY].startswith(cls_name) 211 ), 212 load=lambda json_item, path=None, z=None: cls.load(json_item, path, z), # type: ignore 213 uid=cls_name, 214 source_pckg=cls.__module__, 215 desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel", 216 ) 217 218 def num_params(self) -> int: 219 return num_params(self) 220 221 222def set_config_class( 223 config_class: Type[SerializableDataclass], 224) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]: 225 if not issubclass(config_class, SerializableDataclass): 226 raise TypeError(f"{config_class} must be a subclass of SerializableDataclass") 227 228 def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]: 229 # set the config class 230 cls._config_class = config_class 231 232 # register the handlers 233 register_loader_handler(cls.get_handler()) 234 235 # return the new class 236 return cls 237 238 return wrapper 239 240 241class ConfigMismatchException(ValueError): 242 def __init__(self, msg: str, diff): 243 super().__init__(msg) 244 self.diff = diff 245 246 def __str__(self): 247 return f"{super().__str__()}: {self.diff}" 248 249 250def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel): 251 """check both models are correct instances and have the same config 252 253 Raises: 254 ConfigMismatchException: if the configs don't match, e.diff will contain the diff 255 """ 256 assert isinstance(model_a, ConfiguredModel), "model_a must be a 
ConfiguredModel" 257 assert isinstance(model_a.zanj_model_config, SerializableDataclass), ( 258 "model_a must have a zanj_model_config" 259 ) 260 assert isinstance(model_b, ConfiguredModel), "model_b must be a ConfiguredModel" 261 assert isinstance(model_b.zanj_model_config, SerializableDataclass), ( 262 "model_b must have a zanj_model_config" 263 ) 264 265 cls_type: type = type(model_a.zanj_model_config) 266 267 if not (model_a.zanj_model_config == model_b.zanj_model_config): 268 raise ConfigMismatchException( 269 f"configs of type {type(model_a.zanj_model_config)}, {type(model_b.zanj_model_config)} don't match", 270 diff=cls_type.diff(model_a.zanj_model_config, model_b.zanj_model_config), # type: ignore[attr-defined] 271 ) 272 273 274def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel): 275 """check the models are exactly equal, including state dict contents""" 276 assert_model_cfg_equality(model_a, model_b) 277 278 model_a_sd_keys: set[str] = set(model_a.state_dict().keys()) 279 model_b_sd_keys: set[str] = set(model_b.state_dict().keys()) 280 assert model_a_sd_keys == model_b_sd_keys, ( 281 f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}" 282 ) 283 keys_failed: list[str] = list() 284 for k, v_a in model_a.state_dict().items(): 285 v_b = model_b.state_dict()[k] 286 if not (v_a == v_b).all(): 287 # if not torch.allclose(v, v_load): 288 keys_failed.append(k) 289 print(f"failed {k}") 290 else: 291 print(f"passed {k}") 292 assert len(keys_failed) == 0, ( 293 f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}" 294 )
def num_params(m: torch.nn.Module, only_trainable: bool = True):
    """count the parameters of `m`, counting tensors that share memory only once

    - parameters are deduplicated by `data_ptr()`, so tied/shared weights count once
    - when `only_trainable` is True (the default), parameters with
      `requires_grad = False` are skipped

    adapted from https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    by_address = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        # a later parameter at the same address replaces the earlier one
        by_address[param.data_ptr()] = param
    return sum(param.numel() for param in by_address.values())
return total number of parameters in a model
- only counting shared parameters once
- if `only_trainable` is False, will include parameters with `requires_grad = False`
https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """report which device(s) the parameters of `m` live on

    returns `(True, device)` when every parameter is on the same device, and
    `(False, {param_name: device})` otherwise; a module without parameters
    yields `(False, {})`.
    """
    devices: dict[str, torch.device] = {}
    for name, param in m.named_parameters():
        devices[name] = param.device

    if not devices:
        return False, devices

    first_device = next(iter(devices.values()))
    uniform = all(d == first_device for d in devices.values())
    return (True, first_device) if uniform else (False, devices)
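Get the device(s) of the module's parameters: returns `(True, device)` if all parameters are on the same device, otherwise `(False, {name: device})` (which is `(False, {})` for a module with no parameters).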
class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only view of the config class attached to the concrete subclass
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store the config after checking it against the registered config class

        extra `kwargs` are forwarded to the next `__init__` in the MRO.

        Raises:
            NotImplementedError: if no config class was set via `set_config_class()`
            TypeError: if `zanj_model_config` is not an instance of the config class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model into a dict that `load()` can read back

        the dict contains the serialized config, human-readable metadata about the
        class (name, docstring, source, module, MRO, parameter count, `str()` of the
        model), `training_records`, the raw `state_dict()`, and the format key set
        to the class name (which `get_handler()` matches on).

        `path` and `zanj` are accepted for interface compatibility and are not used.
        """
        return dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            __muutils_format__=self.__class__.__name__,
        )

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """serialize the model and write it to `file_path` (a default `ZANJ()` is used if `zanj` is None)"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        `load()` passes `zanj.custom_settings["_load_state_dict_wrapper"]` as `kwargs`;
        this default implementation accepts none.
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        reads the config with the registered config class, constructs the model
        from it, loads the state dict through `_load_state_dict_wrapper()`, and
        restores `training_records`.

        Raises:
            NotImplementedError: if no config class was set via `set_config_class()`
        """

        if zanj is None:
            zanj = ZANJ()

        if cls._config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")

        # the loader from `get_handler()` defaults `path` to None; treat that as the root path
        base_path: tuple = tuple() if path is None else tuple(path)

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file, asserting that the loaded object is an instance of `cls`"""
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file (deprecated: use `read()` instead)"""
        warnings.warn(
            "load_file() is deprecated, use read() instead", DeprecationWarning
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build the `LoaderHandler` that lets ZANJ load instances of this class

        an item is claimed when it is a dict whose format key is a string starting
        with this class's name, and the character right after the name (if any)
        cannot continue an identifier. Without that last condition, the handler
        for `Model` would also claim items whose format key names `ModelV2` or
        `ModelConfig(...)`.
        """
        cls_name: str = str(cls.__name__)

        def _check(json_item, path=None, z=None) -> bool:
            if not isinstance(json_item, dict):
                return False
            fmt = json_item.get(_FORMAT_KEY)
            if not isinstance(fmt, str) or not fmt.startswith(cls_name):
                return False
            next_char: str = fmt[len(cls_name) : len(cls_name) + 1]
            # an alphanumeric/underscore continuation means `fmt` names a longer, different identifier
            return not (next_char.isalnum() or next_char == "_")

        return LoaderHandler(
            check=_check,  # type: ignore
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self) -> int:
        """number of trainable parameters, counting shared parameters once (see module-level `num_params`)"""
        return num_params(self)
a model that has a configuration, for saving with ZANJ
@set_config_class(YourConfig)
class YourModule(ConfiguredModel[YourConfig]):
def __init__(self, cfg: YourConfig):
super().__init__(cfg)
`__init__()` must initialize the model from a config object only, and call `super().__init__(zanj_model_config)`.
If you are inheriting from another class + ConfiguredModel, ConfiguredModel must be the first class in the inheritance list
def serialize(
    self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
) -> dict[str, Any]:
    """serialize the model into a dict that `load()` can read back

    the dict contains the serialized config, human-readable metadata about the
    class (name, docstring, source, module, MRO, parameter count, `str()` of the
    model), `training_records`, the raw `state_dict()`, and the format key set
    to the class name.

    `path` and `zanj` are accepted for interface compatibility and are not used
    (the previous version built a `ZANJ()` here and then discarded it).
    """
    return dict(
        zanj_model_config=self.zanj_model_config.serialize(),
        meta=dict(
            class_name=self.__class__.__name__,
            class_doc=string_as_lines(self.__class__.__doc__),
            class_source=safe_getsource(self.__class__),
            module_name=self.__class__.__module__,
            module_mro=[str(x) for x in self.__class__.__mro__],
            num_params=num_params(self),
            as_str=string_as_lines(str(self)),
        ),
        training_records=self.training_records,
        state_dict=self.state_dict(),
        __muutils_format__=self.__class__.__name__,
    )
@classmethod
def load(
    cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
) -> "ConfiguredModel":
    """load a model from a serialized object

    reads the config with the registered config class, constructs the model
    from it, loads the state dict through `_load_state_dict_wrapper()` (passing
    any kwargs from `zanj.custom_settings["_load_state_dict_wrapper"]`), and
    restores `training_records`.

    Raises:
        NotImplementedError: if no config class was set via `set_config_class()`
    """
    if zanj is None:
        zanj = ZANJ()

    if cls._config_class is None:
        raise NotImplementedError("you need to set `config_class` for your model")

    # the loader from `get_handler()` defaults `path` to None; treat that as the root path
    base_path: tuple = tuple() if path is None else tuple(path)

    # get the config
    zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

    # get the training records
    training_records: typing.Any = load_item_recursive(
        obj.get("training_records", None),
        base_path + ("training_records",),
        zanj,
    )

    # initialize the model
    model: "ConfiguredModel" = cls(zanj_model_config)

    # load the state dict
    tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
        obj["state_dict"],
        base_path + ("state_dict",),
        zanj,
    )

    model._load_state_dict_wrapper(
        tensored_state_dict,
        **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
    )

    # set the training records
    model.training_records = training_records

    return model
load a model from a serialized object
@classmethod
def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
    """read the file at `file_path` with ZANJ and return the resulting model

    a fresh `ZANJ()` is used when `zanj` is None; the loaded object must be an
    instance of `cls`, which is checked with an assertion.
    """
    reader = ZANJ() if zanj is None else zanj
    loaded = reader.read(file_path)
    assert isinstance(loaded, cls), f"loaded object must be a {cls}, got {type(loaded)}"
    return loaded
read a model from a file
@classmethod
def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
    """deprecated alias of `read()`

    emits a `DeprecationWarning`, then returns `cls.read(file_path, zanj)`.
    """
    message = "load_file() is deprecated, use read() instead"
    warnings.warn(message, category=DeprecationWarning)
    return cls.read(file_path, zanj)
Read a model from a file (deprecated: emits a `DeprecationWarning`; use `read()` instead).
@classmethod
def get_handler(cls) -> LoaderHandler:
    """build the `LoaderHandler` that lets ZANJ load instances of this class

    an item is claimed when it is a dict whose format key is a string starting
    with this class's name, and the character right after the name (if any)
    cannot continue an identifier. A plain `startswith` would let the handler
    for `Model` also claim `ModelV2` or `ModelConfig(...)` items, and would raise
    on a non-string format key.
    """
    cls_name: str = str(cls.__name__)

    def _check(json_item, path=None, z=None) -> bool:
        if not isinstance(json_item, dict):
            return False
        fmt = json_item.get(_FORMAT_KEY)
        if not isinstance(fmt, str) or not fmt.startswith(cls_name):
            return False
        next_char: str = fmt[len(cls_name) : len(cls_name) + 1]
        # an alphanumeric/underscore continuation means `fmt` names a longer, different identifier
        return not (next_char.isalnum() or next_char == "_")

    return LoaderHandler(
        check=_check,  # type: ignore
        load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
        uid=cls_name,
        source_pckg=cls.__module__,
        desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
    )
Inherited Members
- torch.nn.modules.module.Module
- Module
- dump_patches
- training
- call_super_init
- forward
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- set_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- mtia
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_post_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_pre_hook
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator that attaches `config_class` to a `ConfiguredModel` subclass

    the decorated class gets `_config_class` set, and its `get_handler()` result
    is registered with ZANJ via `register_loader_handler`.

    Raises:
        TypeError: if `config_class` is not a subclass of `SerializableDataclass`
    """
    if not issubclass(config_class, SerializableDataclass):
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def attach_config(model_cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        model_cls._config_class = config_class
        handler = model_cls.get_handler()
        register_loader_handler(handler)
        return model_cls

    return attach_config
class ConfigMismatchException(ValueError):
    """raised when two model configs differ

    the constructor's `msg` becomes the exception message, and `diff` holds the
    difference between the configs; `str()` renders them as `"<msg>: <diff>"`.
    """

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        # difference between the mismatched configs, appended by __str__
        self.diff = diff

    def __str__(self):
        base_message = super().__str__()
        return "{}: {}".format(base_message, self.diff)
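Raised when two models' configs don't match; the `diff` attribute holds the difference between the configs and is included in `str()` of the exception.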
Inherited Members
- builtins.BaseException
- with_traceback
- add_note
- args
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """assert that both arguments are `ConfiguredModel`s carrying equal configs

    Raises:
        AssertionError: if either model is not a `ConfiguredModel`, or its
            `zanj_model_config` is not a `SerializableDataclass`
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
            (computed with the config class's `diff` method)
    """
    assert isinstance(model_a, ConfiguredModel), "model_a must be a ConfiguredModel"
    assert isinstance(model_a.zanj_model_config, SerializableDataclass), (
        "model_a must have a zanj_model_config"
    )
    assert isinstance(model_b, ConfiguredModel), "model_b must be a ConfiguredModel"
    assert isinstance(model_b.zanj_model_config, SerializableDataclass), (
        "model_b must have a zanj_model_config"
    )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config
    if cfg_a == cfg_b:
        return

    cfg_type: type = type(cfg_a)
    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=cfg_type.diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )
check both models are correct instances and have the same config
Raises: ConfigMismatchException: if the configs don't match, e.diff will contain the diff
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    prints `passed <key>` / `failed <key>` for every state dict entry. An entry
    fails if the two tensors differ in shape or in any element.

    Raises:
        ConfigMismatchException: if the configs don't match
        AssertionError: if the state dict keys differ or any entry fails
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once instead of once per key
    sd_a = model_a.state_dict()
    sd_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # compare shapes first: `==` broadcasts, so e.g. shapes (3,) and (1,) could otherwise pass
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )
check the models are exactly equal, including state dict contents