muutils.tensor_utils
utilities for working with tensors and arrays.
notably:
- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
- `DTYPE_MAP` : mapping string representations of types to their type
- `TORCH_DTYPE_MAP` : mapping string representations of types to torch types
- `compare_state_dicts` : for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
1"""utilities for working with tensors and arrays. 2 3notably: 4 5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 6- `DTYPE_MAP` mapping string representations of types to their type 7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 9 10""" 11 12from __future__ import annotations 13 14import json 15import typing 16from typing import Any 17 18import jaxtyping 19import numpy as np 20import torch 21 22from muutils.dictmagic import dotlist_to_nested_dict 23 24# pylint: disable=missing-class-docstring 25 26 27TYPE_TO_JAX_DTYPE: dict[Any, Any] = { 28 float: jaxtyping.Float, 29 int: jaxtyping.Int, 30 jaxtyping.Float: jaxtyping.Float, 31 jaxtyping.Int: jaxtyping.Int, 32 # bool 33 bool: jaxtyping.Bool, 34 jaxtyping.Bool: jaxtyping.Bool, 35 np.bool_: jaxtyping.Bool, 36 torch.bool: jaxtyping.Bool, 37 # numpy float 38 np.float16: jaxtyping.Float, 39 np.float32: jaxtyping.Float, 40 np.float64: jaxtyping.Float, 41 np.half: jaxtyping.Float, 42 np.single: jaxtyping.Float, 43 np.double: jaxtyping.Float, 44 # numpy int 45 np.int8: jaxtyping.Int, 46 np.int16: jaxtyping.Int, 47 np.int32: jaxtyping.Int, 48 np.int64: jaxtyping.Int, 49 np.longlong: jaxtyping.Int, 50 np.short: jaxtyping.Int, 51 np.uint8: jaxtyping.Int, 52 # torch float 53 torch.float: jaxtyping.Float, 54 torch.float16: jaxtyping.Float, 55 torch.float32: jaxtyping.Float, 56 torch.float64: jaxtyping.Float, 57 torch.half: jaxtyping.Float, 58 torch.double: jaxtyping.Float, 59 torch.bfloat16: jaxtyping.Float, 60 # torch int 61 torch.int: jaxtyping.Int, 62 torch.int8: jaxtyping.Int, 63 torch.int16: jaxtyping.Int, 64 torch.int32: jaxtyping.Int, 65 torch.int64: jaxtyping.Int, 66 torch.long: jaxtyping.Int, 67 torch.short: jaxtyping.Int, 68} 69"dict mapping python, numpy, and torch types to `jaxtyping` types" 70 71# np.float_ and np.int_ were deprecated in numpy 1.20 and removed in 2.0 72# use try/except for backwards compatibility and type checker friendliness 73try: 74 TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] 75 TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] 76except AttributeError: 77 pass # numpy 2.0+ removed these deprecated aliases 78 79 80# TODO: add proper type annotations to this signature 81# TODO: maybe get rid of this altogether? 
82# def jaxtype_factory( 83# name: str, 84# array_type: type, 85# default_jax_dtype: type[jaxtyping.Float | jaxtyping.Int | jaxtyping.Bool] = jaxtyping.Float, 86# legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN, 87# ) -> type: 88# """usage: 89# ``` 90# ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) 91# x: ATensor["dim1 dim2", np.float32] 92# ``` 93# """ 94# legacy_mode_ = ErrorMode.from_any(legacy_mode) 95 96# class _BaseArray: 97# """jaxtyping shorthand 98# (backwards compatible with older versions of muutils.tensor_utils) 99 100# default_jax_dtype = {default_jax_dtype} 101# array_type = {array_type} 102# """ 103 104# def __new__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn: 105# raise TypeError("Type FArray cannot be instantiated.") 106 107# def __init_subclass__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn: 108# raise TypeError(f"Cannot subclass {cls.__name__}") 109 110# @classmethod 111# def param_info(cls, params: typing.Union[str, tuple[Any, ...]]) -> str: 112# """useful for error printing""" 113# return "\n".join( 114# f"{k} = {v}" 115# for k, v in { 116# "cls.__name__": cls.__name__, 117# "cls.__doc__": cls.__doc__, 118# "params": params, 119# "type(params)": type(params), 120# }.items() 121# ) 122 123# @typing._tp_cache # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] 124# def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type: # type: ignore[misc] 125# # MyTensor["dim1 dim2"] 126# if isinstance(params, str): 127# return default_jax_dtype[array_type, params] 128 129# elif isinstance(params, tuple): 130# if len(params) != 2: 131# raise Exception( 132# f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}" 133# ) 134 135# if isinstance(params[0], str): 136# # MyTensor["dim1 dim2", int] 137# return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] 138 139# elif isinstance(params[0], tuple): 140# legacy_mode_.process( 141# f"legacy type annotation was used:\n{cls.param_info(params) = }", 142# except_cls=Exception, 143# ) 144# # MyTensor[("dim1", "dim2"), int] 145# shape_anot: list[str] = list() 146# for x in params[0]: 147# if isinstance(x, str): 148# shape_anot.append(x) 149# elif isinstance(x, int): 150# shape_anot.append(str(x)) 151# elif isinstance(x, tuple): 152# shape_anot.append("".join(str(y) for y in x)) 153# else: 154# raise Exception( 155# f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}" 156# ) 157 158# return TYPE_TO_JAX_DTYPE[params[1]][ 159# array_type, " ".join(shape_anot) 160# ] 161# else: 162# raise Exception( 163# f"unexpected type for params:\n{cls.param_info(params)}" 164# ) 165 166# _BaseArray.__name__ = name 167 168# if _BaseArray.__doc__ is None: 169# _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }" 170 171# _BaseArray.__doc__ = _BaseArray.__doc__.format( 172# default_jax_dtype=repr(default_jax_dtype), 173# array_type=repr(array_type), 174# ) 175 176# return _BaseArray 177 178 179if typing.TYPE_CHECKING: 180 # these class definitions are only used here to make pylint happy, 181 # but they make mypy unhappy and there is no way to only run if not mypy 182 # so, later on we have more ignores 183 class ATensor(torch.Tensor): 184 @typing._tp_cache # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] 185 def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type: 186 raise 
NotImplementedError() 187 188 class NDArray(torch.Tensor): 189 @typing._tp_cache # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] 190 def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type: 191 raise NotImplementedError() 192 193 194# ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) # type: ignore[misc, assignment] 195 196# NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float) # type: ignore[misc, assignment] 197 198 199def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype: 200 """convert numpy dtype to torch dtype""" 201 if isinstance(dtype, torch.dtype): 202 return dtype 203 else: 204 return torch.from_numpy(np.array(0, dtype=dtype)).dtype 205 206 207DTYPE_LIST: list[Any] = [ 208 *[ 209 bool, 210 int, 211 float, 212 ], 213 *[ 214 # ---------- 215 # pytorch 216 # ---------- 217 # floats 218 torch.float, 219 torch.float32, 220 torch.float64, 221 torch.half, 222 torch.double, 223 torch.bfloat16, 224 # complex 225 torch.complex64, 226 torch.complex128, 227 # ints 228 torch.int, 229 torch.int8, 230 torch.int16, 231 torch.int32, 232 torch.int64, 233 torch.long, 234 torch.short, 235 # simplest 236 torch.uint8, 237 torch.bool, 238 ], 239 *[ 240 # ---------- 241 # numpy 242 # ---------- 243 # floats 244 np.float16, 245 np.float32, 246 np.float64, 247 np.half, 248 np.single, 249 np.double, 250 # complex 251 np.complex64, 252 np.complex128, 253 # ints 254 np.int8, 255 np.int16, 256 np.int32, 257 np.int64, 258 np.longlong, 259 np.short, 260 # simplest 261 np.uint8, 262 np.bool_, 263 ], 264] 265"list of all the python, numpy, and torch numerical types I could think of" 266 267# np.float_ and np.int_ were deprecated in numpy 1.20 and removed in 2.0 268try: 269 DTYPE_LIST.extend([np.float_, np.int_]) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] 270except AttributeError: 271 pass # numpy 2.0+ removed these deprecated aliases 272 273DTYPE_MAP: dict[str, Any] = { 274 **{str(x): x for x in DTYPE_LIST}, 275 **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"}, 276} 277"mapping from string representations of types to their type" 278 279TORCH_DTYPE_MAP: dict[str, torch.dtype] = { 280 key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items() 281} 282"mapping from string representations of types to specifically torch types" 283 284# no idea why we have to do this, smh 285DTYPE_MAP["bool"] = np.bool_ 286TORCH_DTYPE_MAP["bool"] = torch.bool 287 288 289TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = { 290 "Adagrad": torch.optim.Adagrad, 291 "Adam": torch.optim.Adam, 292 "AdamW": torch.optim.AdamW, 293 "SparseAdam": torch.optim.SparseAdam, 294 "Adamax": torch.optim.Adamax, 295 "ASGD": torch.optim.ASGD, 296 "LBFGS": torch.optim.LBFGS, 297 "NAdam": torch.optim.NAdam, 298 "RAdam": torch.optim.RAdam, 299 "RMSprop": torch.optim.RMSprop, 300 "Rprop": torch.optim.Rprop, 301 "SGD": torch.optim.SGD, 302} 303 304 305def pad_tensor( 306 tensor: jaxtyping.Shaped[torch.Tensor, "dim1"], # noqa: F821 307 padded_length: int, 308 pad_value: float = 0.0, 309 rpad: bool = False, 310) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]: # noqa: F821 311 """pad a 1-d tensor on the left with pad_value to length `padded_length` 312 313 set `rpad = True` to pad on the right instead""" 314 315 temp: list[torch.Tensor] = [ 316 torch.full( 317 (padded_length - tensor.shape[0],), 318 pad_value, 319 dtype=tensor.dtype, 
320 device=tensor.device, 321 ), 322 tensor, 323 ] 324 325 if rpad: 326 temp.reverse() 327 328 return torch.cat(temp) 329 330 331def lpad_tensor( 332 tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0 333) -> torch.Tensor: 334 """pad a 1-d tensor on the left with pad_value to length `padded_length`""" 335 return pad_tensor(tensor, padded_length, pad_value, rpad=False) 336 337 338def rpad_tensor( 339 tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0 340) -> torch.Tensor: 341 """pad a 1-d tensor on the right with pad_value to length `pad_length`""" 342 return pad_tensor(tensor, pad_length, pad_value, rpad=True) 343 344 345def pad_array( 346 array: jaxtyping.Shaped[np.ndarray, "dim1"], # noqa: F821 347 padded_length: int, 348 pad_value: float = 0.0, 349 rpad: bool = False, 350) -> jaxtyping.Shaped[np.ndarray, "padded_length"]: # noqa: F821 351 """pad a 1-d array on the left with pad_value to length `padded_length` 352 353 set `rpad = True` to pad on the right instead""" 354 355 temp: list[np.ndarray] = [ 356 np.full( 357 (padded_length - array.shape[0],), 358 pad_value, 359 dtype=array.dtype, 360 ), 361 array, 362 ] 363 364 if rpad: 365 temp.reverse() 366 367 return np.concatenate(temp) 368 369 370def lpad_array( 371 array: np.ndarray, padded_length: int, pad_value: float = 0.0 372) -> np.ndarray: 373 """pad a 1-d array on the left with pad_value to length `padded_length`""" 374 return pad_array(array, padded_length, pad_value, rpad=False) 375 376 377def rpad_array( 378 array: np.ndarray, pad_length: int, pad_value: float = 0.0 379) -> np.ndarray: 380 """pad a 1-d array on the right with pad_value to length `pad_length`""" 381 return pad_array(array, pad_length, pad_value, rpad=True) 382 383 384def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]: 385 """given a state dict or cache dict, compute the shapes and put them in a nested dict""" 386 return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) 387 388 389def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str: 390 """printable version of get_dict_shapes""" 391 return json.dumps( 392 dotlist_to_nested_dict( 393 { 394 k: str( 395 tuple(v.shape) 396 ) # to string, since indent wont play nice with tuples 397 for k, v in d.items() 398 } 399 ), 400 indent=2, 401 ) 402 403 404class StateDictCompareError(AssertionError): 405 """raised when state dicts don't match""" 406 407 pass 408 409 410class StateDictKeysError(StateDictCompareError): 411 """raised when state dict keys don't match""" 412 413 pass 414 415 416class StateDictShapeError(StateDictCompareError): 417 """raised when state dict shapes don't match""" 418 419 pass 420 421 422class StateDictValueError(StateDictCompareError): 423 """raised when state dict values don't match""" 424 425 pass 426 427 428def compare_state_dicts( 429 d1: dict[str, Any], 430 d2: dict[str, Any], 431 rtol: float = 1e-5, 432 atol: float = 1e-8, 433 verbose: bool = True, 434) -> None: 435 """compare two dicts of tensors 436 437 # Parameters: 438 439 - `d1 : dict` 440 - `d2 : dict` 441 - `rtol : float` 442 (defaults to `1e-5`) 443 - `atol : float` 444 (defaults to `1e-8`) 445 - `verbose : bool` 446 (defaults to `True`) 447 448 # Raises: 449 450 - `StateDictKeysError` : keys don't match 451 - `StateDictShapeError` : shapes don't match (but keys do) 452 - `StateDictValueError` : values don't match (but keys and shapes do) 453 """ 454 # check keys match 455 d1_keys: set[str] = set(d1.keys()) 456 d2_keys: set[str] = set(d2.keys()) 457 
symmetric_diff: set[str] = set.symmetric_difference(d1_keys, d2_keys) 458 keys_diff_1: set[str] = d1_keys - d2_keys 459 keys_diff_2: set[str] = d2_keys - d1_keys 460 # sort sets for easier debugging 461 symmetric_diff = set(sorted(symmetric_diff)) 462 keys_diff_1 = set(sorted(keys_diff_1)) 463 keys_diff_2 = set(sorted(keys_diff_2)) 464 diff_shapes_1: str = ( 465 string_dict_shapes({k: d1[k] for k in keys_diff_1}) 466 if verbose 467 else "(verbose = False)" 468 ) 469 diff_shapes_2: str = ( 470 string_dict_shapes({k: d2[k] for k in keys_diff_2}) 471 if verbose 472 else "(verbose = False)" 473 ) 474 if not len(symmetric_diff) == 0: 475 raise StateDictKeysError( 476 f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}" 477 ) 478 479 # check tensors match 480 shape_failed: list[str] = list() 481 vals_failed: list[str] = list() 482 for k, v1 in d1.items(): 483 v2 = d2[k] 484 # check shapes first 485 if not v1.shape == v2.shape: 486 shape_failed.append(k) 487 else: 488 # if shapes match, check values 489 if not torch.allclose(v1, v2, rtol=rtol, atol=atol): 490 vals_failed.append(k) 491 492 str_shape_failed: str = ( 493 string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else "" 494 ) 495 str_vals_failed: str = ( 496 string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else "" 497 ) 498 499 if not len(shape_failed) == 0: 500 raise StateDictShapeError( 501 f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}" 502 ) 503 if not len(vals_failed) == 0: 504 raise StateDictValueError( 505 f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}" 506 )
`TYPE_TO_JAX_DTYPE: dict[Any, Any]`

dict mapping python, numpy, and torch types to `jaxtyping` types
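For illustration, a small sketch of how the lookup behaves: every float-like key resolves to `jaxtyping.Float`, int-like to `jaxtyping.Int`, and bool-like to `jaxtyping.Bool` (the assertions assume the mapping exactly as defined in the source above):

```python
import jaxtyping
import numpy as np
import torch

from muutils.tensor_utils import TYPE_TO_JAX_DTYPE

# float-like, int-like, and bool-like keys collapse to three jaxtyping annotations
assert TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float
assert TYPE_TO_JAX_DTYPE[torch.int64] is jaxtyping.Int
assert TYPE_TO_JAX_DTYPE[bool] is jaxtyping.Bool
```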
```python
def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
```
convert numpy dtype to torch dtype
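A quick sketch of the expected behavior: numpy dtypes are converted via a zero-dim `torch.from_numpy` round trip, while torch dtypes pass through unchanged:

```python
import numpy as np
import torch

from muutils.tensor_utils import numpy_to_torch_dtype

assert numpy_to_torch_dtype(np.dtype(np.float32)) == torch.float32
assert numpy_to_torch_dtype(np.dtype(np.int64)) == torch.int64
# torch dtypes are returned as-is
assert numpy_to_torch_dtype(torch.bfloat16) == torch.bfloat16
```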
`DTYPE_LIST: list[Any]`

list of all the python, numpy, and torch numerical types I could think of
`DTYPE_MAP: dict[str, Any]`

mapping from string representations of types to their type
`TORCH_DTYPE_MAP: dict[str, torch.dtype]`

mapping from string representations of types to specifically torch types
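To illustrate the key scheme (a sketch, assuming the maps as constructed in the source above): torch types are keyed by their `str()` form, numpy types additionally by their `__name__`, and `"bool"` is special-cased after the fact:

```python
import numpy as np
import torch

from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP

assert DTYPE_MAP["torch.float32"] is torch.float32  # keyed by str(torch.float32)
assert DTYPE_MAP["float32"] is np.float32  # keyed by np.float32.__name__
assert DTYPE_MAP["bool"] is np.bool_  # explicit override after construction

# TORCH_DTYPE_MAP resolves the same string keys to torch dtypes
assert TORCH_DTYPE_MAP["float32"] == torch.float32
assert TORCH_DTYPE_MAP["bool"] == torch.bool
```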
```python
def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)
```
pad a 1-d tensor on the left with pad_value to length `padded_length`

set `rpad = True` to pad on the right instead
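For example (a sketch of the expected behavior):

```python
import torch

from muutils.tensor_utils import pad_tensor

x = torch.tensor([1.0, 2.0, 3.0])
# left-pad by default: fill values go in front
assert pad_tensor(x, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]
# rpad=True puts the fill values at the end instead
assert pad_tensor(x, 5, pad_value=-1.0, rpad=True).tolist() == [1.0, 2.0, 3.0, -1.0, -1.0]
```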
```python
def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)
```
pad a 1-d tensor on the left with pad_value to length `padded_length`
```python
def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
```
pad a 1-d tensor on the right with pad_value to length `pad_length`
```python
def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)
```
pad a 1-d array on the left with pad_value to length `padded_length`

set `rpad = True` to pad on the right instead
```python
def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)
```
pad a 1-d array on the left with pad_value to length `padded_length`
```python
def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
```
pad a 1-d array on the right with pad_value to length `pad_length`
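The array variants mirror the tensor ones; a quick sketch:

```python
import numpy as np

from muutils.tensor_utils import lpad_array, rpad_array

x = np.array([1.0, 2.0, 3.0])
assert lpad_array(x, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]
assert rpad_array(x, 5, pad_value=9.0).tolist() == [1.0, 2.0, 3.0, 9.0, 9.0]
```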
```python
def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
```
given a state dict or cache dict, compute the shapes and put them in a nested dict
```python
def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent wont play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
```
printable version of `get_dict_shapes`
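For example (a sketch that assumes `dotlist_to_nested_dict` expands dotted keys into nested dicts, as its name suggests):

```python
import torch

from muutils.tensor_utils import get_dict_shapes, string_dict_shapes

sd = {
    "embed.weight": torch.zeros(10, 4),
    "head.weight": torch.zeros(4, 10),
    "head.bias": torch.zeros(10),
}
# dotted keys become nesting; tensor values become shape tuples
assert get_dict_shapes(sd) == {
    "embed": {"weight": (10, 4)},
    "head": {"weight": (4, 10), "bias": (10,)},
}
# same structure as indented JSON, with shapes stringified
print(string_dict_shapes(sd))
```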
```python
class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass
```
raised when state dicts don't match
```python
class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass
```
raised when state dict keys don't match
```python
class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass
```
raised when state dict shapes don't match
```python
class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass
```
raised when state dict values don't match
```python
def compare_state_dicts(
    d1: dict[str, Any],
    d2: dict[str, Any],
    rtol: float = 1e-5,
    atol: float = 1e-8,
    verbose: bool = True,
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
        (defaults to `1e-5`)
    - `atol : float`
        (defaults to `1e-8`)
    - `verbose : bool`
        (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set[str] = set(d1.keys())
    d2_keys: set[str] = set(d2.keys())
    symmetric_diff: set[str] = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set[str] = d1_keys - d2_keys
    keys_diff_2: set[str] = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
```
compare two dicts of tensors
Parameters:
- `d1 : dict`
- `d2 : dict`
- `rtol : float` (defaults to `1e-5`)
- `atol : float` (defaults to `1e-8`)
- `verbose : bool` (defaults to `True`)
Raises:
- `StateDictKeysError` : keys don't match
- `StateDictShapeError` : shapes don't match (but keys do)
- `StateDictValueError` : values don't match (but keys and shapes do)
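Usage is pass-or-raise rather than return-a-bool. For example (a sketch, with the value mismatch chosen to exceed the default tolerances):

```python
import torch

from muutils.tensor_utils import StateDictValueError, compare_state_dicts

d1 = {"w": torch.ones(3), "b": torch.zeros(2)}
d2 = {"w": torch.ones(3), "b": torch.zeros(2) + 1e-3}

try:
    compare_state_dicts(d1, d2)  # keys and shapes match, but "b" values differ
except StateDictValueError as e:
    print(e)  # message names the offending keys and their shapes

# loosening atol lets the same comparison pass (returns None)
compare_state_dicts(d1, d2, atol=1e-2)
```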