docs for muutils v0.9.1
View Source on GitHub

muutils.tensor_utils

utilities for working with tensors and arrays.

notably:

  • TYPE_TO_JAX_DTYPE : a mapping from python, numpy, and torch types to jaxtyping types
  • DTYPE_MAP mapping string representations of types to their type
  • TORCH_DTYPE_MAP mapping string representations of types to torch types
  • compare_state_dicts for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match

  1"""utilities for working with tensors and arrays.
  2
  3notably:
  4
  5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
  6- `DTYPE_MAP` mapping string representations of types to their type
  7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
  8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
  9
 10"""
 11
 12from __future__ import annotations
 13
 14import json
 15import typing
 16from typing import Any
 17
 18import jaxtyping
 19import numpy as np
 20import torch
 21
 22from muutils.dictmagic import dotlist_to_nested_dict
 23
 24# pylint: disable=missing-class-docstring
 25
 26
 27TYPE_TO_JAX_DTYPE: dict[Any, Any] = {
 28    float: jaxtyping.Float,
 29    int: jaxtyping.Int,
 30    jaxtyping.Float: jaxtyping.Float,
 31    jaxtyping.Int: jaxtyping.Int,
 32    # bool
 33    bool: jaxtyping.Bool,
 34    jaxtyping.Bool: jaxtyping.Bool,
 35    np.bool_: jaxtyping.Bool,
 36    torch.bool: jaxtyping.Bool,
 37    # numpy float
 38    np.float16: jaxtyping.Float,
 39    np.float32: jaxtyping.Float,
 40    np.float64: jaxtyping.Float,
 41    np.half: jaxtyping.Float,
 42    np.single: jaxtyping.Float,
 43    np.double: jaxtyping.Float,
 44    # numpy int
 45    np.int8: jaxtyping.Int,
 46    np.int16: jaxtyping.Int,
 47    np.int32: jaxtyping.Int,
 48    np.int64: jaxtyping.Int,
 49    np.longlong: jaxtyping.Int,
 50    np.short: jaxtyping.Int,
 51    np.uint8: jaxtyping.Int,
 52    # torch float
 53    torch.float: jaxtyping.Float,
 54    torch.float16: jaxtyping.Float,
 55    torch.float32: jaxtyping.Float,
 56    torch.float64: jaxtyping.Float,
 57    torch.half: jaxtyping.Float,
 58    torch.double: jaxtyping.Float,
 59    torch.bfloat16: jaxtyping.Float,
 60    # torch int
 61    torch.int: jaxtyping.Int,
 62    torch.int8: jaxtyping.Int,
 63    torch.int16: jaxtyping.Int,
 64    torch.int32: jaxtyping.Int,
 65    torch.int64: jaxtyping.Int,
 66    torch.long: jaxtyping.Int,
 67    torch.short: jaxtyping.Int,
 68}
 69"dict mapping python, numpy, and torch types to `jaxtyping` types"
 70
 71# np.float_ was removed in numpy 2.0, so these attribute lookups can fail
 72# use try/except for backwards compatibility and type checker friendliness
 73try:
 74    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
 75    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
 76except AttributeError:
 77    pass  # these aliases are not available on this numpy version
 78
 79
 80# TODO: add proper type annotations to this signature
 81# TODO: maybe get rid of this altogether?
 82# def jaxtype_factory(
 83#     name: str,
 84#     array_type: type,
 85#     default_jax_dtype: type[jaxtyping.Float | jaxtyping.Int | jaxtyping.Bool] = jaxtyping.Float,
 86#     legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
 87# ) -> type:
 88#     """usage:
 89#     ```
 90#     ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
 91#     x: ATensor["dim1 dim2", np.float32]
 92#     ```
 93#     """
 94#     legacy_mode_ = ErrorMode.from_any(legacy_mode)
 95
 96#     class _BaseArray:
 97#         """jaxtyping shorthand
 98#         (backwards compatible with older versions of muutils.tensor_utils)
 99
100#         default_jax_dtype = {default_jax_dtype}
101#         array_type = {array_type}
102#         """
103
104#         def __new__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn:
105#             raise TypeError("Type FArray cannot be instantiated.")
106
107#         def __init_subclass__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn:
108#             raise TypeError(f"Cannot subclass {cls.__name__}")
109
110#         @classmethod
111#         def param_info(cls, params: typing.Union[str, tuple[Any, ...]]) -> str:
112#             """useful for error printing"""
113#             return "\n".join(
114#                 f"{k} = {v}"
115#                 for k, v in {
116#                     "cls.__name__": cls.__name__,
117#                     "cls.__doc__": cls.__doc__,
118#                     "params": params,
119#                     "type(params)": type(params),
120#                 }.items()
121#             )
122
123#         @typing._tp_cache  # type: ignore[attr-defined]  # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
124#         def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type:  # type: ignore[misc]
125#             # MyTensor["dim1 dim2"]
126#             if isinstance(params, str):
127#                 return default_jax_dtype[array_type, params]
128
129#             elif isinstance(params, tuple):
130#                 if len(params) != 2:
131#                     raise Exception(
132#                         f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
133#                     )
134
135#                 if isinstance(params[0], str):
136#                     # MyTensor["dim1 dim2", int]
137#                     return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]
138
139#                 elif isinstance(params[0], tuple):
140#                     legacy_mode_.process(
141#                         f"legacy type annotation was used:\n{cls.param_info(params) = }",
142#                         except_cls=Exception,
143#                     )
144#                     # MyTensor[("dim1", "dim2"), int]
145#                     shape_anot: list[str] = list()
146#                     for x in params[0]:
147#                         if isinstance(x, str):
148#                             shape_anot.append(x)
149#                         elif isinstance(x, int):
150#                             shape_anot.append(str(x))
151#                         elif isinstance(x, tuple):
152#                             shape_anot.append("".join(str(y) for y in x))
153#                         else:
154#                             raise Exception(
155#                                 f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
156#                             )
157
158#                     return TYPE_TO_JAX_DTYPE[params[1]][
159#                         array_type, " ".join(shape_anot)
160#                     ]
161#             else:
162#                 raise Exception(
163#                     f"unexpected type for params:\n{cls.param_info(params)}"
164#                 )
165
166#     _BaseArray.__name__ = name
167
168#     if _BaseArray.__doc__ is None:
169#         _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"
170
171#     _BaseArray.__doc__ = _BaseArray.__doc__.format(
172#         default_jax_dtype=repr(default_jax_dtype),
173#         array_type=repr(array_type),
174#     )
175
176#     return _BaseArray
177
178
179if typing.TYPE_CHECKING:
180    # these class definitions exist only to make pylint happy,
181    # but they make mypy unhappy, and there is no way to run them
182    # only when mypy is not type-checking, so more ignores appear below
183    class ATensor(torch.Tensor):
184        @typing._tp_cache  # type: ignore[attr-defined]  # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
185        def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type:
186            raise NotImplementedError()
187
188    class NDArray(torch.Tensor):
189        @typing._tp_cache  # type: ignore[attr-defined]  # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
190        def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type:
191            raise NotImplementedError()
192
193
194# ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]
195
196# NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
197
198
199def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
200    """convert numpy dtype to torch dtype"""
201    if isinstance(dtype, torch.dtype):
202        return dtype
203    else:
204        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
205
206
207DTYPE_LIST: list[Any] = [
208    *[
209        bool,
210        int,
211        float,
212    ],
213    *[
214        # ----------
215        # pytorch
216        # ----------
217        # floats
218        torch.float,
219        torch.float32,
220        torch.float64,
221        torch.half,
222        torch.double,
223        torch.bfloat16,
224        # complex
225        torch.complex64,
226        torch.complex128,
227        # ints
228        torch.int,
229        torch.int8,
230        torch.int16,
231        torch.int32,
232        torch.int64,
233        torch.long,
234        torch.short,
235        # simplest
236        torch.uint8,
237        torch.bool,
238    ],
239    *[
240        # ----------
241        # numpy
242        # ----------
243        # floats
244        np.float16,
245        np.float32,
246        np.float64,
247        np.half,
248        np.single,
249        np.double,
250        # complex
251        np.complex64,
252        np.complex128,
253        # ints
254        np.int8,
255        np.int16,
256        np.int32,
257        np.int64,
258        np.longlong,
259        np.short,
260        # simplest
261        np.uint8,
262        np.bool_,
263    ],
264]
265"list of all the python, numpy, and torch numerical types I could think of"
266
267# np.float_ was removed in numpy 2.0, so this lookup can fail
268try:
269    DTYPE_LIST.extend([np.float_, np.int_])  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
270except AttributeError:
271    pass  # these aliases are not available on this numpy version
272
273DTYPE_MAP: dict[str, Any] = {
274    **{str(x): x for x in DTYPE_LIST},
275    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
276}
277"mapping from string representations of types to their type"
278
279TORCH_DTYPE_MAP: dict[str, torch.dtype] = {
280    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
281}
282"mapping from string representations of types to specifically torch types"
283
284# pin the plain "bool" key explicitly: str(bool) is "<class 'bool'>", and numpy may expose its bool type as "bool_"
285DTYPE_MAP["bool"] = np.bool_
286TORCH_DTYPE_MAP["bool"] = torch.bool
287
288
289TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
290    "Adagrad": torch.optim.Adagrad,
291    "Adam": torch.optim.Adam,
292    "AdamW": torch.optim.AdamW,
293    "SparseAdam": torch.optim.SparseAdam,
294    "Adamax": torch.optim.Adamax,
295    "ASGD": torch.optim.ASGD,
296    "LBFGS": torch.optim.LBFGS,
297    "NAdam": torch.optim.NAdam,
298    "RAdam": torch.optim.RAdam,
299    "RMSprop": torch.optim.RMSprop,
300    "Rprop": torch.optim.Rprop,
301    "SGD": torch.optim.SGD,
302}
303
304
305def pad_tensor(
306    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
307    padded_length: int,
308    pad_value: float = 0.0,
309    rpad: bool = False,
310) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
311    """pad a 1-d tensor on the left with pad_value to length `padded_length`
312
313    set `rpad = True` to pad on the right instead"""
314
315    temp: list[torch.Tensor] = [
316        torch.full(
317            (padded_length - tensor.shape[0],),
318            pad_value,
319            dtype=tensor.dtype,
320            device=tensor.device,
321        ),
322        tensor,
323    ]
324
325    if rpad:
326        temp.reverse()
327
328    return torch.cat(temp)
329
330
331def lpad_tensor(
332    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
333) -> torch.Tensor:
334    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
335    return pad_tensor(tensor, padded_length, pad_value, rpad=False)
336
337
338def rpad_tensor(
339    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
340) -> torch.Tensor:
341    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
342    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
343
344
345def pad_array(
346    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
347    padded_length: int,
348    pad_value: float = 0.0,
349    rpad: bool = False,
350) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
351    """pad a 1-d array on the left with pad_value to length `padded_length`
352
353    set `rpad = True` to pad on the right instead"""
354
355    temp: list[np.ndarray] = [
356        np.full(
357            (padded_length - array.shape[0],),
358            pad_value,
359            dtype=array.dtype,
360        ),
361        array,
362    ]
363
364    if rpad:
365        temp.reverse()
366
367    return np.concatenate(temp)
368
369
370def lpad_array(
371    array: np.ndarray, padded_length: int, pad_value: float = 0.0
372) -> np.ndarray:
373    """pad a 1-d array on the left with pad_value to length `padded_length`"""
374    return pad_array(array, padded_length, pad_value, rpad=False)
375
376
377def rpad_array(
378    array: np.ndarray, pad_length: int, pad_value: float = 0.0
379) -> np.ndarray:
380    """pad a 1-d array on the right with pad_value to length `pad_length`"""
381    return pad_array(array, pad_length, pad_value, rpad=True)
382
383
384def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
385    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
386    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
387
388
389def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
390    """printable version of get_dict_shapes"""
391    return json.dumps(
392        dotlist_to_nested_dict(
393            {
394                k: str(
395                    tuple(v.shape)
 396                )  # to string, since indent won't play nice with tuples
397                for k, v in d.items()
398            }
399        ),
400        indent=2,
401    )
402
403
404class StateDictCompareError(AssertionError):
405    """raised when state dicts don't match"""
406
407    pass
408
409
410class StateDictKeysError(StateDictCompareError):
411    """raised when state dict keys don't match"""
412
413    pass
414
415
416class StateDictShapeError(StateDictCompareError):
417    """raised when state dict shapes don't match"""
418
419    pass
420
421
422class StateDictValueError(StateDictCompareError):
423    """raised when state dict values don't match"""
424
425    pass
426
427
428def compare_state_dicts(
429    d1: dict[str, Any],
430    d2: dict[str, Any],
431    rtol: float = 1e-5,
432    atol: float = 1e-8,
433    verbose: bool = True,
434) -> None:
435    """compare two dicts of tensors
436
437    # Parameters:
438
439     - `d1 : dict`
440     - `d2 : dict`
441     - `rtol : float`
442       (defaults to `1e-5`)
443     - `atol : float`
444       (defaults to `1e-8`)
445     - `verbose : bool`
446       (defaults to `True`)
447
448    # Raises:
449
450     - `StateDictKeysError` : keys don't match
451     - `StateDictShapeError` : shapes don't match (but keys do)
452     - `StateDictValueError` : values don't match (but keys and shapes do)
453    """
454    # check keys match
455    d1_keys: set[str] = set(d1.keys())
456    d2_keys: set[str] = set(d2.keys())
457    symmetric_diff: set[str] = set.symmetric_difference(d1_keys, d2_keys)
458    keys_diff_1: set[str] = d1_keys - d2_keys
459    keys_diff_2: set[str] = d2_keys - d1_keys
460    # sort sets for easier debugging
461    symmetric_diff = set(sorted(symmetric_diff))
462    keys_diff_1 = set(sorted(keys_diff_1))
463    keys_diff_2 = set(sorted(keys_diff_2))
464    diff_shapes_1: str = (
465        string_dict_shapes({k: d1[k] for k in keys_diff_1})
466        if verbose
467        else "(verbose = False)"
468    )
469    diff_shapes_2: str = (
470        string_dict_shapes({k: d2[k] for k in keys_diff_2})
471        if verbose
472        else "(verbose = False)"
473    )
474    if not len(symmetric_diff) == 0:
475        raise StateDictKeysError(
476            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
477        )
478
479    # check tensors match
480    shape_failed: list[str] = list()
481    vals_failed: list[str] = list()
482    for k, v1 in d1.items():
483        v2 = d2[k]
484        # check shapes first
485        if not v1.shape == v2.shape:
486            shape_failed.append(k)
487        else:
488            # if shapes match, check values
489            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
490                vals_failed.append(k)
491
492    str_shape_failed: str = (
493        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
494    )
495    str_vals_failed: str = (
496        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
497    )
498
499    if not len(shape_failed) == 0:
500        raise StateDictShapeError(
501            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
502        )
503    if not len(vals_failed) == 0:
504        raise StateDictValueError(
505            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
506        )

TYPE_TO_JAX_DTYPE: dict[typing.Any, typing.Any] = {<class 'float'>: <class 'jaxtyping.Float'>, <class 'int'>: <class 'jaxtyping.Int'>, <class 'jaxtyping.Float'>: <class 'jaxtyping.Float'>, <class 'jaxtyping.Int'>: <class 'jaxtyping.Int'>, <class 'bool'>: <class 'jaxtyping.Bool'>, <class 'jaxtyping.Bool'>: <class 'jaxtyping.Bool'>, <class 'numpy.bool'>: <class 'jaxtyping.Bool'>, torch.bool: <class 'jaxtyping.Bool'>, <class 'numpy.float16'>: <class 'jaxtyping.Float'>, <class 'numpy.float32'>: <class 'jaxtyping.Float'>, <class 'numpy.float64'>: <class 'jaxtyping.Float'>, <class 'numpy.int8'>: <class 'jaxtyping.Int'>, <class 'numpy.int16'>: <class 'jaxtyping.Int'>, <class 'numpy.int32'>: <class 'jaxtyping.Int'>, <class 'numpy.int64'>: <class 'jaxtyping.Int'>, <class 'numpy.longlong'>: <class 'jaxtyping.Int'>, <class 'numpy.uint8'>: <class 'jaxtyping.Int'>, torch.float32: <class 'jaxtyping.Float'>, torch.float16: <class 'jaxtyping.Float'>, torch.float64: <class 'jaxtyping.Float'>, torch.bfloat16: <class 'jaxtyping.Float'>, torch.int32: <class 'jaxtyping.Int'>, torch.int8: <class 'jaxtyping.Int'>, torch.int16: <class 'jaxtyping.Int'>, torch.int64: <class 'jaxtyping.Int'>}

dict mapping python, numpy, and torch types to jaxtyping types
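
For illustration, the mapping can be indexed with a runtime dtype and then subscripted to build a `jaxtyping` annotation. A minimal sketch (the dimension names are hypothetical):

```python
import numpy as np
import torch

from muutils.tensor_utils import TYPE_TO_JAX_DTYPE

# numpy and torch float dtypes both resolve to jaxtyping.Float
assert TYPE_TO_JAX_DTYPE[np.float32] is TYPE_TO_JAX_DTYPE[torch.float32]

# subscripting the result yields an annotation like jaxtyping.Float[torch.Tensor, "batch dim"]
FloatBatch = TYPE_TO_JAX_DTYPE[torch.float32][torch.Tensor, "batch dim"]
```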

def numpy_to_torch_dtype(dtype: Union[numpy.dtype, torch.dtype]) -> torch.dtype:
200def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
201    """convert numpy dtype to torch dtype"""
202    if isinstance(dtype, torch.dtype):
203        return dtype
204    else:
205        return torch.from_numpy(np.array(0, dtype=dtype)).dtype

convert numpy dtype to torch dtype
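
A quick usage sketch:

```python
import numpy as np
import torch

from muutils.tensor_utils import numpy_to_torch_dtype

assert numpy_to_torch_dtype(np.dtype("float32")) == torch.float32
# torch dtypes are passed through unchanged
assert numpy_to_torch_dtype(torch.int64) == torch.int64
```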

DTYPE_LIST: list[typing.Any] = [<class 'bool'>, <class 'int'>, <class 'float'>, torch.float32, torch.float32, torch.float64, torch.float16, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.int32, torch.int8, torch.int16, torch.int32, torch.int64, torch.int64, torch.int16, torch.uint8, torch.bool, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.int8'>, <class 'numpy.int16'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.longlong'>, <class 'numpy.int16'>, <class 'numpy.uint8'>, <class 'numpy.bool'>]

list of all the python, numpy, and torch numerical types I could think of

DTYPE_MAP: dict[str, typing.Any] = {"<class 'bool'>": <class 'bool'>, "<class 'int'>": <class 'int'>, "<class 'float'>": <class 'float'>, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": <class 'numpy.float16'>, "<class 'numpy.float32'>": <class 'numpy.float32'>, "<class 'numpy.float64'>": <class 'numpy.float64'>, "<class 'numpy.complex64'>": <class 'numpy.complex64'>, "<class 'numpy.complex128'>": <class 'numpy.complex128'>, "<class 'numpy.int8'>": <class 'numpy.int8'>, "<class 'numpy.int16'>": <class 'numpy.int16'>, "<class 'numpy.int32'>": <class 'numpy.int32'>, "<class 'numpy.int64'>": <class 'numpy.int64'>, "<class 'numpy.longlong'>": <class 'numpy.longlong'>, "<class 'numpy.uint8'>": <class 'numpy.uint8'>, "<class 'numpy.bool'>": <class 'numpy.bool'>, 'float16': <class 'numpy.float16'>, 'float32': <class 'numpy.float32'>, 'float64': <class 'numpy.float64'>, 'complex64': <class 'numpy.complex64'>, 'complex128': <class 'numpy.complex128'>, 'int8': <class 'numpy.int8'>, 'int16': <class 'numpy.int16'>, 'int32': <class 'numpy.int32'>, 'int64': <class 'numpy.int64'>, 'longlong': <class 'numpy.longlong'>, 'uint8': <class 'numpy.uint8'>, 'bool': <class 'numpy.bool'>}

mapping from string representations of types to their type
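
A usage sketch, using keys visible in the rendered mapping above:

```python
import numpy as np
import torch

from muutils.tensor_utils import DTYPE_MAP

# keys are str(dtype) for every entry, plus the bare __name__ for numpy types
assert DTYPE_MAP["torch.float32"] is torch.float32
assert DTYPE_MAP["float32"] is np.float32
```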

TORCH_DTYPE_MAP: dict[str, torch.dtype] = {"<class 'bool'>": torch.bool, "<class 'int'>": torch.int64, "<class 'float'>": torch.float64, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": torch.float16, "<class 'numpy.float32'>": torch.float32, "<class 'numpy.float64'>": torch.float64, "<class 'numpy.complex64'>": torch.complex64, "<class 'numpy.complex128'>": torch.complex128, "<class 'numpy.int8'>": torch.int8, "<class 'numpy.int16'>": torch.int16, "<class 'numpy.int32'>": torch.int32, "<class 'numpy.int64'>": torch.int64, "<class 'numpy.longlong'>": torch.int64, "<class 'numpy.uint8'>": torch.uint8, "<class 'numpy.bool'>": torch.bool, 'float16': torch.float16, 'float32': torch.float32, 'float64': torch.float64, 'complex64': torch.complex64, 'complex128': torch.complex128, 'int8': torch.int8, 'int16': torch.int16, 'int32': torch.int32, 'int64': torch.int64, 'longlong': torch.int64, 'uint8': torch.uint8, 'bool': torch.bool}

mapping from string representations of types to specifically torch types
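
The same string keys resolve to torch dtypes here:

```python
import torch

from muutils.tensor_utils import TORCH_DTYPE_MAP

assert TORCH_DTYPE_MAP["float64"] == torch.float64
assert TORCH_DTYPE_MAP["bool"] == torch.bool
```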

TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.optimizer.Optimizer]] = {'Adagrad': <class 'torch.optim.adagrad.Adagrad'>, 'Adam': <class 'torch.optim.adam.Adam'>, 'AdamW': <class 'torch.optim.adamw.AdamW'>, 'SparseAdam': <class 'torch.optim.sparse_adam.SparseAdam'>, 'Adamax': <class 'torch.optim.adamax.Adamax'>, 'ASGD': <class 'torch.optim.asgd.ASGD'>, 'LBFGS': <class 'torch.optim.lbfgs.LBFGS'>, 'NAdam': <class 'torch.optim.nadam.NAdam'>, 'RAdam': <class 'torch.optim.radam.RAdam'>, 'RMSprop': <class 'torch.optim.rmsprop.RMSprop'>, 'Rprop': <class 'torch.optim.rprop.Rprop'>, 'SGD': <class 'torch.optim.sgd.SGD'>}
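
This map is convenient for choosing an optimizer from a config string. A minimal sketch (the model, the "AdamW" key, and the learning rate are illustrative, not part of muutils):

```python
import torch

from muutils.tensor_utils import TORCH_OPTIMIZERS_MAP

model = torch.nn.Linear(4, 2)
# look up the optimizer class by name, then construct it as usual
optimizer = TORCH_OPTIMIZERS_MAP["AdamW"](model.parameters(), lr=1e-3)
```
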
def pad_tensor( tensor: jaxtyping.Shaped[Tensor, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[Tensor, 'padded_length']:
306def pad_tensor(
307    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
308    padded_length: int,
309    pad_value: float = 0.0,
310    rpad: bool = False,
311) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
312    """pad a 1-d tensor on the left with pad_value to length `padded_length`
313
314    set `rpad = True` to pad on the right instead"""
315
316    temp: list[torch.Tensor] = [
317        torch.full(
318            (padded_length - tensor.shape[0],),
319            pad_value,
320            dtype=tensor.dtype,
321            device=tensor.device,
322        ),
323        tensor,
324    ]
325
326    if rpad:
327        temp.reverse()
328
329    return torch.cat(temp)

pad a 1-d tensor on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
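
For example (a minimal sketch; `lpad_tensor` and `rpad_tensor` below are thin wrappers over the same call):

```python
import torch

from muutils.tensor_utils import pad_tensor

x = torch.tensor([1.0, 2.0, 3.0])
assert pad_tensor(x, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]
assert pad_tensor(x, 5, rpad=True).tolist() == [1.0, 2.0, 3.0, 0.0, 0.0]
```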

def lpad_tensor( tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0) -> torch.Tensor:
332def lpad_tensor(
333    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
334) -> torch.Tensor:
335    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
336    return pad_tensor(tensor, padded_length, pad_value, rpad=False)

pad a 1-d tensor on the left with pad_value to length padded_length

def rpad_tensor( tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0) -> torch.Tensor:
339def rpad_tensor(
340    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
341) -> torch.Tensor:
342    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
343    return pad_tensor(tensor, pad_length, pad_value, rpad=True)

pad a 1-d tensor on the right with pad_value to length pad_length

def pad_array( array: jaxtyping.Shaped[ndarray, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[ndarray, 'padded_length']:
346def pad_array(
347    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
348    padded_length: int,
349    pad_value: float = 0.0,
350    rpad: bool = False,
351) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
352    """pad a 1-d array on the left with pad_value to length `padded_length`
353
354    set `rpad = True` to pad on the right instead"""
355
356    temp: list[np.ndarray] = [
357        np.full(
358            (padded_length - array.shape[0],),
359            pad_value,
360            dtype=array.dtype,
361        ),
362        array,
363    ]
364
365    if rpad:
366        temp.reverse()
367
368    return np.concatenate(temp)

pad a 1-d array on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
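
The numpy version behaves the same way:

```python
import numpy as np

from muutils.tensor_utils import pad_array

a = np.array([1.0, 2.0, 3.0])
assert pad_array(a, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]
assert pad_array(a, 5, rpad=True).tolist() == [1.0, 2.0, 3.0, 0.0, 0.0]
```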

def lpad_array( array: numpy.ndarray, padded_length: int, pad_value: float = 0.0) -> numpy.ndarray:
371def lpad_array(
372    array: np.ndarray, padded_length: int, pad_value: float = 0.0
373) -> np.ndarray:
374    """pad a 1-d array on the left with pad_value to length `padded_length`"""
375    return pad_array(array, padded_length, pad_value, rpad=False)

pad a 1-d array on the left with pad_value to length padded_length

def rpad_array( array: numpy.ndarray, pad_length: int, pad_value: float = 0.0) -> numpy.ndarray:
378def rpad_array(
379    array: np.ndarray, pad_length: int, pad_value: float = 0.0
380) -> np.ndarray:
381    """pad a 1-d array on the right with pad_value to length `pad_length`"""
382    return pad_array(array, pad_length, pad_value, rpad=True)

pad a 1-d array on the right with pad_value to length pad_length

def get_dict_shapes(d: dict[str, torch.Tensor]) -> dict[str, tuple[int, ...]]:
385def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
386    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
387    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})

given a state dict or cache dict, compute the shapes and put them in a nested dict
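
A sketch of the expected behavior (the dotted key names are hypothetical):

```python
import torch

from muutils.tensor_utils import get_dict_shapes

sd = {"block.weight": torch.zeros(4, 2), "block.bias": torch.zeros(2)}
# dotted keys become nested dicts; leaves are shape tuples
assert get_dict_shapes(sd) == {"block": {"weight": (4, 2), "bias": (2,)}}
```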

def string_dict_shapes(d: dict[str, torch.Tensor]) -> str:
390def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
391    """printable version of get_dict_shapes"""
392    return json.dumps(
393        dotlist_to_nested_dict(
394            {
395                k: str(
396                    tuple(v.shape)
397                )  # to string, since indent won't play nice with tuples
398                for k, v in d.items()
399            }
400        ),
401        indent=2,
402    )

printable version of get_dict_shapes
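
For the same hypothetical dict, this renders indented JSON with the shape tuples stringified:

```python
import torch

from muutils.tensor_utils import string_dict_shapes

print(string_dict_shapes({"block.weight": torch.zeros(4, 2)}))
# {
#   "block": {
#     "weight": "(4, 2)"
#   }
# }
```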

class StateDictCompareError(builtins.AssertionError):
405class StateDictCompareError(AssertionError):
406    """raised when state dicts don't match"""
407
408    pass

raised when state dicts don't match

class StateDictKeysError(StateDictCompareError):
411class StateDictKeysError(StateDictCompareError):
412    """raised when state dict keys don't match"""
413
414    pass

raised when state dict keys don't match

class StateDictShapeError(StateDictCompareError):
417class StateDictShapeError(StateDictCompareError):
418    """raised when state dict shapes don't match"""
419
420    pass

raised when state dict shapes don't match

class StateDictValueError(StateDictCompareError):
423class StateDictValueError(StateDictCompareError):
424    """raised when state dict values don't match"""
425
426    pass

raised when state dict values don't match

def compare_state_dicts( d1: dict[str, typing.Any], d2: dict[str, typing.Any], rtol: float = 1e-05, atol: float = 1e-08, verbose: bool = True) -> None:
429def compare_state_dicts(
430    d1: dict[str, Any],
431    d2: dict[str, Any],
432    rtol: float = 1e-5,
433    atol: float = 1e-8,
434    verbose: bool = True,
435) -> None:
436    """compare two dicts of tensors
437
438    # Parameters:
439
440     - `d1 : dict`
441     - `d2 : dict`
442     - `rtol : float`
443       (defaults to `1e-5`)
444     - `atol : float`
445       (defaults to `1e-8`)
446     - `verbose : bool`
447       (defaults to `True`)
448
449    # Raises:
450
451     - `StateDictKeysError` : keys don't match
452     - `StateDictShapeError` : shapes don't match (but keys do)
453     - `StateDictValueError` : values don't match (but keys and shapes do)
454    """
455    # check keys match
456    d1_keys: set[str] = set(d1.keys())
457    d2_keys: set[str] = set(d2.keys())
458    symmetric_diff: set[str] = set.symmetric_difference(d1_keys, d2_keys)
459    keys_diff_1: set[str] = d1_keys - d2_keys
460    keys_diff_2: set[str] = d2_keys - d1_keys
461    # sort sets for easier debugging
462    symmetric_diff = set(sorted(symmetric_diff))
463    keys_diff_1 = set(sorted(keys_diff_1))
464    keys_diff_2 = set(sorted(keys_diff_2))
465    diff_shapes_1: str = (
466        string_dict_shapes({k: d1[k] for k in keys_diff_1})
467        if verbose
468        else "(verbose = False)"
469    )
470    diff_shapes_2: str = (
471        string_dict_shapes({k: d2[k] for k in keys_diff_2})
472        if verbose
473        else "(verbose = False)"
474    )
475    if not len(symmetric_diff) == 0:
476        raise StateDictKeysError(
477            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
478        )
479
480    # check tensors match
481    shape_failed: list[str] = list()
482    vals_failed: list[str] = list()
483    for k, v1 in d1.items():
484        v2 = d2[k]
485        # check shapes first
486        if not v1.shape == v2.shape:
487            shape_failed.append(k)
488        else:
489            # if shapes match, check values
490            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
491                vals_failed.append(k)
492
493    str_shape_failed: str = (
494        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
495    )
496    str_vals_failed: str = (
497        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
498    )
499
500    if not len(shape_failed) == 0:
501        raise StateDictShapeError(
502            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
503        )
504    if not len(vals_failed) == 0:
505        raise StateDictValueError(
506            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
507        )

compare two dicts of tensors

Parameters:

  • d1 : dict
  • d2 : dict
  • rtol : float (defaults to 1e-5)
  • atol : float (defaults to 1e-8)
  • verbose : bool (defaults to True)

Raises:

  • StateDictKeysError : keys don't match
  • StateDictShapeError : shapes don't match (but keys do)
  • StateDictValueError : values don't match (but keys and shapes do)
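
A usage sketch (the key name and perturbation are illustrative; 1e-3 exceeds the default tolerances, so the value check fails):

```python
import torch

from muutils.tensor_utils import StateDictValueError, compare_state_dicts

d1 = {"w": torch.ones(2, 2)}
d2 = {"w": torch.ones(2, 2) + 1e-3}  # same keys and shapes, values differ

try:
    compare_state_dicts(d1, d2)
except StateDictValueError as e:
    print(e)  # reports which keys failed the allclose check
```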