docs for muutils v0.6.15
View Source on GitHub

muutils.tensor_utils

utilities for working with tensors and arrays.

notably:

  • TYPE_TO_JAX_DTYPE : a mapping from python, numpy, and torch types to jaxtyping types
  • DTYPE_MAP mapping string representations of types to their type
  • TORCH_DTYPE_MAP mapping string representations of types to torch types
  • compare_state_dicts for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match (a short usage sketch follows below)
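
A minimal usage sketch of the pieces listed above (the tensor names and shapes are illustrative only):

```python
import torch

from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP, compare_state_dicts

# resolve dtypes from their string names, e.g. as read from a config file
np_dtype = DTYPE_MAP["float32"]            # <class 'numpy.float32'>
torch_dtype = TORCH_DTYPE_MAP["float32"]   # torch.float32

# compare two state dicts; raises a StateDict*Error subclass on mismatch
sd_a = {"layer.weight": torch.ones(3, 4, dtype=torch_dtype)}
sd_b = {"layer.weight": torch.ones(3, 4, dtype=torch_dtype)}
compare_state_dicts(sd_a, sd_b)  # passes silently when everything matches
```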

  1"""utilities for working with tensors and arrays.
  2
  3notably:
  4
  5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
  6- `DTYPE_MAP` mapping string representations of types to their type
  7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
  8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
  9
 10"""
 11
 12from __future__ import annotations
 13
 14import json
 15import typing
 16
 17import jaxtyping
 18import numpy as np
 19import torch
 20
 21from muutils.errormode import ErrorMode
 22from muutils.dictmagic import dotlist_to_nested_dict
 23
 24# pylint: disable=missing-class-docstring
 25
 26
 27TYPE_TO_JAX_DTYPE: dict = {
 28    float: jaxtyping.Float,
 29    int: jaxtyping.Int,
 30    jaxtyping.Float: jaxtyping.Float,
 31    jaxtyping.Int: jaxtyping.Int,
 32    # bool
 33    bool: jaxtyping.Bool,
 34    jaxtyping.Bool: jaxtyping.Bool,
 35    np.bool_: jaxtyping.Bool,
 36    torch.bool: jaxtyping.Bool,
 37    # numpy float
 38    np.float_: jaxtyping.Float,
 39    np.float16: jaxtyping.Float,
 40    np.float32: jaxtyping.Float,
 41    np.float64: jaxtyping.Float,
 42    np.half: jaxtyping.Float,
 43    np.single: jaxtyping.Float,
 44    np.double: jaxtyping.Float,
 45    # numpy int
 46    np.int_: jaxtyping.Int,
 47    np.int8: jaxtyping.Int,
 48    np.int16: jaxtyping.Int,
 49    np.int32: jaxtyping.Int,
 50    np.int64: jaxtyping.Int,
 51    np.longlong: jaxtyping.Int,
 52    np.short: jaxtyping.Int,
 53    np.uint8: jaxtyping.Int,
 54    # torch float
 55    torch.float: jaxtyping.Float,
 56    torch.float16: jaxtyping.Float,
 57    torch.float32: jaxtyping.Float,
 58    torch.float64: jaxtyping.Float,
 59    torch.half: jaxtyping.Float,
 60    torch.double: jaxtyping.Float,
 61    torch.bfloat16: jaxtyping.Float,
 62    # torch int
 63    torch.int: jaxtyping.Int,
 64    torch.int8: jaxtyping.Int,
 65    torch.int16: jaxtyping.Int,
 66    torch.int32: jaxtyping.Int,
 67    torch.int64: jaxtyping.Int,
 68    torch.long: jaxtyping.Int,
 69    torch.short: jaxtyping.Int,
 70}
 71"dict mapping python, numpy, and torch types to `jaxtyping` types"
 72
 73
 74# TODO: add proper type annotations to this signature
 75def jaxtype_factory(
 76    name: str,
 77    array_type: type,
 78    default_jax_dtype=jaxtyping.Float,
 79    legacy_mode: ErrorMode = ErrorMode.WARN,
 80) -> type:
 81    """usage:
 82    ```
 83    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
 84    x: ATensor["dim1 dim2", np.float32]
 85    ```
 86    """
 87    legacy_mode = ErrorMode.from_any(legacy_mode)
 88
 89    class _BaseArray:
 90        """jaxtyping shorthand
 91        (backwards compatible with older versions of muutils.tensor_utils)
 92
 93        default_jax_dtype = {default_jax_dtype}
 94        array_type = {array_type}
 95        """
 96
 97        def __new__(cls, *args, **kwargs):
 98            raise TypeError("Type FArray cannot be instantiated.")
 99
100        def __init_subclass__(cls, *args, **kwargs):
101            raise TypeError(f"Cannot subclass {cls.__name__}")
102
103        @classmethod
104        def param_info(cls, params) -> str:
105            """useful for error printing"""
106            return "\n".join(
107                f"{k} = {v}"
108                for k, v in {
109                    "cls.__name__": cls.__name__,
110                    "cls.__doc__": cls.__doc__,
111                    "params": params,
112                    "type(params)": type(params),
113                }.items()
114            )
115
116        @typing._tp_cache  # type: ignore
117        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:
118            # MyTensor["dim1 dim2"]
119            if isinstance(params, str):
120                return default_jax_dtype[array_type, params]
121
122            elif isinstance(params, tuple):
123                if len(params) != 2:
124                    raise Exception(
125                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
126                    )
127
128                if isinstance(params[0], str):
129                    # MyTensor["dim1 dim2", int]
130                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]
131
132                elif isinstance(params[0], tuple):
133                    legacy_mode.process(
134                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
135                        except_cls=Exception,
136                    )
137                    # MyTensor[("dim1", "dim2"), int]
138                    shape_anot: list[str] = list()
139                    for x in params[0]:
140                        if isinstance(x, str):
141                            shape_anot.append(x)
142                        elif isinstance(x, int):
143                            shape_anot.append(str(x))
144                        elif isinstance(x, tuple):
145                            shape_anot.append("".join(str(y) for y in x))
146                        else:
147                            raise Exception(
148                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
149                            )
150
151                    return TYPE_TO_JAX_DTYPE[params[1]][
152                        array_type, " ".join(shape_anot)
153                    ]
154            else:
155                raise Exception(
156                    f"unexpected type for params:\n{cls.param_info(params)}"
157                )
158
159    _BaseArray.__name__ = name
160
161    if _BaseArray.__doc__ is None:
162        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"
163
164    _BaseArray.__doc__ = _BaseArray.__doc__.format(
165        default_jax_dtype=repr(default_jax_dtype),
166        array_type=repr(array_type),
167    )
168
169    return _BaseArray
170
171
172if typing.TYPE_CHECKING:
173    # these class definitions are only used here to make pylint happy,
174    # but they make mypy unhappy and there is no way to only run if not mypy
175    # so, later on we have more ignores
176    class ATensor(torch.Tensor):
177        @typing._tp_cache  # type: ignore
178        def __class_getitem__(cls, params):
179            raise NotImplementedError()
180
181    class NDArray(torch.Tensor):
182        @typing._tp_cache  # type: ignore
183        def __class_getitem__(cls, params):
184            raise NotImplementedError()
185
186
187ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]
188
189NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
190
191
192def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
193    """convert numpy dtype to torch dtype"""
194    if isinstance(dtype, torch.dtype):
195        return dtype
196    else:
197        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
198
199
200DTYPE_LIST: list = [
201    *[
202        bool,
203        int,
204        float,
205    ],
206    *[
207        # ----------
208        # pytorch
209        # ----------
210        # floats
211        torch.float,
212        torch.float32,
213        torch.float64,
214        torch.half,
215        torch.double,
216        torch.bfloat16,
217        # complex
218        torch.complex64,
219        torch.complex128,
220        # ints
221        torch.int,
222        torch.int8,
223        torch.int16,
224        torch.int32,
225        torch.int64,
226        torch.long,
227        torch.short,
228        # simplest
229        torch.uint8,
230        torch.bool,
231    ],
232    *[
233        # ----------
234        # numpy
235        # ----------
236        # floats
237        np.float_,
238        np.float16,
239        np.float32,
240        np.float64,
241        np.half,
242        np.single,
243        np.double,
244        # complex
245        np.complex64,
246        np.complex128,
247        # ints
248        np.int8,
249        np.int16,
250        np.int32,
251        np.int64,
252        np.int_,
253        np.longlong,
254        np.short,
255        # simplest
256        np.uint8,
257        np.bool_,
258    ],
259]
260"list of all the python, numpy, and torch numerical types I could think of"
261
262DTYPE_MAP: dict = {
263    **{str(x): x for x in DTYPE_LIST},
264    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
265}
266"mapping from string representations of types to their type"
267
268TORCH_DTYPE_MAP: dict = {
269    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
270}
271"mapping from string representations of types to specifically torch types"
272
273# no idea why we have to do this, smh
274DTYPE_MAP["bool"] = np.bool_
275TORCH_DTYPE_MAP["bool"] = torch.bool
276
277
278TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
279    "Adagrad": torch.optim.Adagrad,
280    "Adam": torch.optim.Adam,
281    "AdamW": torch.optim.AdamW,
282    "SparseAdam": torch.optim.SparseAdam,
283    "Adamax": torch.optim.Adamax,
284    "ASGD": torch.optim.ASGD,
285    "LBFGS": torch.optim.LBFGS,
286    "NAdam": torch.optim.NAdam,
287    "RAdam": torch.optim.RAdam,
288    "RMSprop": torch.optim.RMSprop,
289    "Rprop": torch.optim.Rprop,
290    "SGD": torch.optim.SGD,
291}
292
293
294def pad_tensor(
295    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
296    padded_length: int,
297    pad_value: float = 0.0,
298    rpad: bool = False,
299) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
300    """pad a 1-d tensor on the left with pad_value to length `padded_length`
301
302    set `rpad = True` to pad on the right instead"""
303
304    temp: list[torch.Tensor] = [
305        torch.full(
306            (padded_length - tensor.shape[0],),
307            pad_value,
308            dtype=tensor.dtype,
309            device=tensor.device,
310        ),
311        tensor,
312    ]
313
314    if rpad:
315        temp.reverse()
316
317    return torch.cat(temp)
318
319
320def lpad_tensor(
321    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
322) -> torch.Tensor:
323    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
324    return pad_tensor(tensor, padded_length, pad_value, rpad=False)
325
326
327def rpad_tensor(
328    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
329) -> torch.Tensor:
330    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
331    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
332
333
334def pad_array(
335    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
336    padded_length: int,
337    pad_value: float = 0.0,
338    rpad: bool = False,
339) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
340    """pad a 1-d array on the left with pad_value to length `padded_length`
341
342    set `rpad = True` to pad on the right instead"""
343
344    temp: list[np.ndarray] = [
345        np.full(
346            (padded_length - array.shape[0],),
347            pad_value,
348            dtype=array.dtype,
349        ),
350        array,
351    ]
352
353    if rpad:
354        temp.reverse()
355
356    return np.concatenate(temp)
357
358
359def lpad_array(
360    array: np.ndarray, padded_length: int, pad_value: float = 0.0
361) -> np.ndarray:
362    """pad a 1-d array on the left with pad_value to length `padded_length`"""
363    return pad_array(array, padded_length, pad_value, rpad=False)
364
365
366def rpad_array(
367    array: np.ndarray, pad_length: int, pad_value: float = 0.0
368) -> np.ndarray:
369    """pad a 1-d array on the right with pad_value to length `pad_length`"""
370    return pad_array(array, pad_length, pad_value, rpad=True)
371
372
373def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
374    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
375    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
376
377
378def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
379    """printable version of get_dict_shapes"""
380    return json.dumps(
381        dotlist_to_nested_dict(
382            {
383                k: str(
384                    tuple(v.shape)
 385                )  # to string, since indent won't play nice with tuples
386                for k, v in d.items()
387            }
388        ),
389        indent=2,
390    )
391
392
393class StateDictCompareError(AssertionError):
394    """raised when state dicts don't match"""
395
396    pass
397
398
399class StateDictKeysError(StateDictCompareError):
400    """raised when state dict keys don't match"""
401
402    pass
403
404
405class StateDictShapeError(StateDictCompareError):
406    """raised when state dict shapes don't match"""
407
408    pass
409
410
411class StateDictValueError(StateDictCompareError):
412    """raised when state dict values don't match"""
413
414    pass
415
416
417def compare_state_dicts(
418    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
419) -> None:
420    """compare two dicts of tensors
421
422    # Parameters:
423
424     - `d1 : dict`
425     - `d2 : dict`
426     - `rtol : float`
427       (defaults to `1e-5`)
428     - `atol : float`
429       (defaults to `1e-8`)
430     - `verbose : bool`
431       (defaults to `True`)
432
433    # Raises:
434
435     - `StateDictKeysError` : keys don't match
436     - `StateDictShapeError` : shapes don't match (but keys do)
437     - `StateDictValueError` : values don't match (but keys and shapes do)
438    """
439    # check keys match
440    d1_keys: set = set(d1.keys())
441    d2_keys: set = set(d2.keys())
442    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
443    keys_diff_1: set = d1_keys - d2_keys
444    keys_diff_2: set = d2_keys - d1_keys
445    # sort sets for easier debugging
446    symmetric_diff = set(sorted(symmetric_diff))
447    keys_diff_1 = set(sorted(keys_diff_1))
448    keys_diff_2 = set(sorted(keys_diff_2))
449    diff_shapes_1: str = (
450        string_dict_shapes({k: d1[k] for k in keys_diff_1})
451        if verbose
452        else "(verbose = False)"
453    )
454    diff_shapes_2: str = (
455        string_dict_shapes({k: d2[k] for k in keys_diff_2})
456        if verbose
457        else "(verbose = False)"
458    )
459    if not len(symmetric_diff) == 0:
460        raise StateDictKeysError(
461            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
462        )
463
464    # check tensors match
465    shape_failed: list[str] = list()
466    vals_failed: list[str] = list()
467    for k, v1 in d1.items():
468        v2 = d2[k]
469        # check shapes first
470        if not v1.shape == v2.shape:
471            shape_failed.append(k)
472        else:
473            # if shapes match, check values
474            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
475                vals_failed.append(k)
476
477    str_shape_failed: str = (
478        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
479    )
480    str_vals_failed: str = (
481        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
482    )
483
484    if not len(shape_failed) == 0:
485        raise StateDictShapeError(
486            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
487        )
488    if not len(vals_failed) == 0:
489        raise StateDictValueError(
490            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
491        )

TYPE_TO_JAX_DTYPE: dict = {<class 'float'>: <class 'jaxtyping.Float'>, <class 'int'>: <class 'jaxtyping.Int'>, <class 'jaxtyping.Float'>: <class 'jaxtyping.Float'>, <class 'jaxtyping.Int'>: <class 'jaxtyping.Int'>, <class 'bool'>: <class 'jaxtyping.Bool'>, <class 'jaxtyping.Bool'>: <class 'jaxtyping.Bool'>, <class 'numpy.bool_'>: <class 'jaxtyping.Bool'>, torch.bool: <class 'jaxtyping.Bool'>, <class 'numpy.float64'>: <class 'jaxtyping.Float'>, <class 'numpy.float16'>: <class 'jaxtyping.Float'>, <class 'numpy.float32'>: <class 'jaxtyping.Float'>, <class 'numpy.int32'>: <class 'jaxtyping.Int'>, <class 'numpy.int8'>: <class 'jaxtyping.Int'>, <class 'numpy.int16'>: <class 'jaxtyping.Int'>, <class 'numpy.int64'>: <class 'jaxtyping.Int'>, <class 'numpy.uint8'>: <class 'jaxtyping.Int'>, torch.float32: <class 'jaxtyping.Float'>, torch.float16: <class 'jaxtyping.Float'>, torch.float64: <class 'jaxtyping.Float'>, torch.bfloat16: <class 'jaxtyping.Float'>, torch.int32: <class 'jaxtyping.Int'>, torch.int8: <class 'jaxtyping.Int'>, torch.int16: <class 'jaxtyping.Int'>, torch.int64: <class 'jaxtyping.Int'>}

dict mapping python, numpy, and torch types to jaxtyping types
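
For illustration, a few lookups into this mapping and how the resulting `jaxtyping` class is then parameterized (the shape string is arbitrary):

```python
import numpy as np
import torch
import jaxtyping

from muutils.tensor_utils import TYPE_TO_JAX_DTYPE

assert TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float
assert TYPE_TO_JAX_DTYPE[torch.int64] is jaxtyping.Int
assert TYPE_TO_JAX_DTYPE[bool] is jaxtyping.Bool

# the looked-up jaxtyping class is then indexed with an array type and a shape string
FloatVec = jaxtyping.Float[torch.Tensor, "dim1"]
```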

def jaxtype_factory( name: str, array_type: type, default_jax_dtype=<class 'jaxtyping.Float'>, legacy_mode: muutils.errormode.ErrorMode = ErrorMode.Warn) -> type:
 76def jaxtype_factory(
 77    name: str,
 78    array_type: type,
 79    default_jax_dtype=jaxtyping.Float,
 80    legacy_mode: ErrorMode = ErrorMode.WARN,
 81) -> type:
 82    """usage:
 83    ```
 84    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
 85    x: ATensor["dim1 dim2", np.float32]
 86    ```
 87    """
 88    legacy_mode = ErrorMode.from_any(legacy_mode)
 89
 90    class _BaseArray:
 91        """jaxtyping shorthand
 92        (backwards compatible with older versions of muutils.tensor_utils)
 93
 94        default_jax_dtype = {default_jax_dtype}
 95        array_type = {array_type}
 96        """
 97
 98        def __new__(cls, *args, **kwargs):
 99            raise TypeError("Type FArray cannot be instantiated.")
100
101        def __init_subclass__(cls, *args, **kwargs):
102            raise TypeError(f"Cannot subclass {cls.__name__}")
103
104        @classmethod
105        def param_info(cls, params) -> str:
106            """useful for error printing"""
107            return "\n".join(
108                f"{k} = {v}"
109                for k, v in {
110                    "cls.__name__": cls.__name__,
111                    "cls.__doc__": cls.__doc__,
112                    "params": params,
113                    "type(params)": type(params),
114                }.items()
115            )
116
117        @typing._tp_cache  # type: ignore
118        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:
119            # MyTensor["dim1 dim2"]
120            if isinstance(params, str):
121                return default_jax_dtype[array_type, params]
122
123            elif isinstance(params, tuple):
124                if len(params) != 2:
125                    raise Exception(
126                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
127                    )
128
129                if isinstance(params[0], str):
130                    # MyTensor["dim1 dim2", int]
131                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]
132
133                elif isinstance(params[0], tuple):
134                    legacy_mode.process(
135                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
136                        except_cls=Exception,
137                    )
138                    # MyTensor[("dim1", "dim2"), int]
139                    shape_anot: list[str] = list()
140                    for x in params[0]:
141                        if isinstance(x, str):
142                            shape_anot.append(x)
143                        elif isinstance(x, int):
144                            shape_anot.append(str(x))
145                        elif isinstance(x, tuple):
146                            shape_anot.append("".join(str(y) for y in x))
147                        else:
148                            raise Exception(
149                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
150                            )
151
152                    return TYPE_TO_JAX_DTYPE[params[1]][
153                        array_type, " ".join(shape_anot)
154                    ]
155            else:
156                raise Exception(
157                    f"unexpected type for params:\n{cls.param_info(params)}"
158                )
159
160    _BaseArray.__name__ = name
161
162    if _BaseArray.__doc__ is None:
163        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"
164
165    _BaseArray.__doc__ = _BaseArray.__doc__.format(
166        default_jax_dtype=repr(default_jax_dtype),
167        array_type=repr(array_type),
168    )
169
170    return _BaseArray

usage:

ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
ATensor = <class 'jaxtype_factory.<locals>._BaseArray'>
NDArray = <class 'jaxtype_factory.<locals>._BaseArray'>
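
The generated `ATensor` and `NDArray` aliases are meant purely as annotations; a sketch of how they read in practice (the function and dimension names are made up):

```python
import numpy as np
import torch

from muutils.tensor_utils import ATensor, NDArray

# shape-only: falls back to the factory's default dtype (jaxtyping.Float)
def normalize(x: ATensor["batch dim"]) -> ATensor["batch dim"]:
    return x / x.norm(dim=-1, keepdim=True)

# shape + dtype: the dtype is looked up via TYPE_TO_JAX_DTYPE
def count_tokens(ids: ATensor["batch seq", int]) -> int:
    return int(ids.numel())

mask: NDArray["height width", bool] = np.ones((4, 4), dtype=np.bool_)
```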
def numpy_to_torch_dtype(dtype: Union[numpy.dtype, torch.dtype]) -> torch.dtype:
193def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
194    """convert numpy dtype to torch dtype"""
195    if isinstance(dtype, torch.dtype):
196        return dtype
197    else:
198        return torch.from_numpy(np.array(0, dtype=dtype)).dtype

convert numpy dtype to torch dtype
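
For example:

```python
import numpy as np
import torch

from muutils.tensor_utils import numpy_to_torch_dtype

assert numpy_to_torch_dtype(np.dtype("float32")) == torch.float32
assert numpy_to_torch_dtype(np.int64) == torch.int64
assert numpy_to_torch_dtype(torch.bfloat16) == torch.bfloat16  # torch dtypes pass through
```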

DTYPE_LIST: list = [<class 'bool'>, <class 'int'>, <class 'float'>, torch.float32, torch.float32, torch.float64, torch.float16, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.int32, torch.int8, torch.int16, torch.int32, torch.int64, torch.int64, torch.int16, torch.uint8, torch.bool, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.int8'>, <class 'numpy.int16'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.int16'>, <class 'numpy.uint8'>, <class 'numpy.bool_'>]

list of all the python, numpy, and torch numerical types I could think of

DTYPE_MAP: dict = {"<class 'bool'>": <class 'bool'>, "<class 'int'>": <class 'int'>, "<class 'float'>": <class 'float'>, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float64'>": <class 'numpy.float64'>, "<class 'numpy.float16'>": <class 'numpy.float16'>, "<class 'numpy.float32'>": <class 'numpy.float32'>, "<class 'numpy.complex64'>": <class 'numpy.complex64'>, "<class 'numpy.complex128'>": <class 'numpy.complex128'>, "<class 'numpy.int8'>": <class 'numpy.int8'>, "<class 'numpy.int16'>": <class 'numpy.int16'>, "<class 'numpy.int32'>": <class 'numpy.int32'>, "<class 'numpy.int64'>": <class 'numpy.int64'>, "<class 'numpy.uint8'>": <class 'numpy.uint8'>, "<class 'numpy.bool_'>": <class 'numpy.bool_'>, 'float64': <class 'numpy.float64'>, 'float16': <class 'numpy.float16'>, 'float32': <class 'numpy.float32'>, 'complex64': <class 'numpy.complex64'>, 'complex128': <class 'numpy.complex128'>, 'int8': <class 'numpy.int8'>, 'int16': <class 'numpy.int16'>, 'int32': <class 'numpy.int32'>, 'int64': <class 'numpy.int64'>, 'uint8': <class 'numpy.uint8'>, 'bool_': <class 'numpy.bool_'>, 'bool': <class 'numpy.bool_'>}

mapping from string representations of types to their type
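
This is mainly useful when dtypes arrive as strings, e.g. from a serialized config (the config key below is hypothetical):

```python
import numpy as np
import torch

from muutils.tensor_utils import DTYPE_MAP

config = {"activation_dtype": "float16"}  # hypothetical config value
assert DTYPE_MAP[config["activation_dtype"]] is np.float16
assert DTYPE_MAP["torch.float16"] is torch.float16  # torch-style names are keys too
```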

TORCH_DTYPE_MAP: dict = {"<class 'bool'>": torch.bool, "<class 'int'>": torch.int32, "<class 'float'>": torch.float64, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float64'>": torch.float64, "<class 'numpy.float16'>": torch.float16, "<class 'numpy.float32'>": torch.float32, "<class 'numpy.complex64'>": torch.complex64, "<class 'numpy.complex128'>": torch.complex128, "<class 'numpy.int8'>": torch.int8, "<class 'numpy.int16'>": torch.int16, "<class 'numpy.int32'>": torch.int32, "<class 'numpy.int64'>": torch.int64, "<class 'numpy.uint8'>": torch.uint8, "<class 'numpy.bool_'>": torch.bool, 'float64': torch.float64, 'float16': torch.float16, 'float32': torch.float32, 'complex64': torch.complex64, 'complex128': torch.complex128, 'int8': torch.int8, 'int16': torch.int16, 'int32': torch.int32, 'int64': torch.int64, 'uint8': torch.uint8, 'bool_': torch.bool, 'bool': torch.bool}

mapping from string representations of types to specifically torch types
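
So the same string keys resolve directly to `torch.dtype` objects; for example:

```python
import torch

from muutils.tensor_utils import TORCH_DTYPE_MAP

assert TORCH_DTYPE_MAP["float32"] == torch.float32
assert TORCH_DTYPE_MAP["bool"] == torch.bool
x = torch.zeros(8, dtype=TORCH_DTYPE_MAP["int64"])
```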

TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.optimizer.Optimizer]] = {'Adagrad': <class 'torch.optim.adagrad.Adagrad'>, 'Adam': <class 'torch.optim.adam.Adam'>, 'AdamW': <class 'torch.optim.adamw.AdamW'>, 'SparseAdam': <class 'torch.optim.sparse_adam.SparseAdam'>, 'Adamax': <class 'torch.optim.adamax.Adamax'>, 'ASGD': <class 'torch.optim.asgd.ASGD'>, 'LBFGS': <class 'torch.optim.lbfgs.LBFGS'>, 'NAdam': <class 'torch.optim.nadam.NAdam'>, 'RAdam': <class 'torch.optim.radam.RAdam'>, 'RMSprop': <class 'torch.optim.rmsprop.RMSprop'>, 'Rprop': <class 'torch.optim.rprop.Rprop'>, 'SGD': <class 'torch.optim.sgd.SGD'>}
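
A typical use for this mapping is picking an optimizer class from a config string (the config dict below is hypothetical):

```python
import torch

from muutils.tensor_utils import TORCH_OPTIMIZERS_MAP

model = torch.nn.Linear(4, 2)
optim_cfg = {"name": "AdamW", "lr": 1e-3}  # hypothetical config
optimizer = TORCH_OPTIMIZERS_MAP[optim_cfg["name"]](model.parameters(), lr=optim_cfg["lr"])
```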
def pad_tensor( tensor: jaxtyping.Shaped[Tensor, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[Tensor, 'padded_length']:
295def pad_tensor(
296    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
297    padded_length: int,
298    pad_value: float = 0.0,
299    rpad: bool = False,
300) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
301    """pad a 1-d tensor on the left with pad_value to length `padded_length`
302
303    set `rpad = True` to pad on the right instead"""
304
305    temp: list[torch.Tensor] = [
306        torch.full(
307            (padded_length - tensor.shape[0],),
308            pad_value,
309            dtype=tensor.dtype,
310            device=tensor.device,
311        ),
312        tensor,
313    ]
314
315    if rpad:
316        temp.reverse()
317
318    return torch.cat(temp)

pad a 1-d tensor on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
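
For example:

```python
import torch

from muutils.tensor_utils import pad_tensor

x = torch.tensor([1.0, 2.0, 3.0])
assert pad_tensor(x, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]             # left pad (default)
assert pad_tensor(x, 5, rpad=True).tolist() == [1.0, 2.0, 3.0, 0.0, 0.0]  # right pad
```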

def lpad_tensor( tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0) -> torch.Tensor:
321def lpad_tensor(
322    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
323) -> torch.Tensor:
324    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
325    return pad_tensor(tensor, padded_length, pad_value, rpad=False)

pad a 1-d tensor on the left with pad_value to length padded_length

def rpad_tensor( tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0) -> torch.Tensor:
328def rpad_tensor(
329    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
330) -> torch.Tensor:
331    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
332    return pad_tensor(tensor, pad_length, pad_value, rpad=True)

pad a 1-d tensor on the right with pad_value to length pad_length

def pad_array( array: jaxtyping.Shaped[ndarray, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[ndarray, 'padded_length']:
335def pad_array(
336    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
337    padded_length: int,
338    pad_value: float = 0.0,
339    rpad: bool = False,
340) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
341    """pad a 1-d array on the left with pad_value to length `padded_length`
342
343    set `rpad = True` to pad on the right instead"""
344
345    temp: list[np.ndarray] = [
346        np.full(
347            (padded_length - array.shape[0],),
348            pad_value,
349            dtype=array.dtype,
350        ),
351        array,
352    ]
353
354    if rpad:
355        temp.reverse()
356
357    return np.concatenate(temp)

pad a 1-d array on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
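
For example:

```python
import numpy as np

from muutils.tensor_utils import pad_array

a = np.array([1.0, 2.0, 3.0])
assert pad_array(a, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]
assert pad_array(a, 5, pad_value=-1.0, rpad=True).tolist() == [1.0, 2.0, 3.0, -1.0, -1.0]
```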

def lpad_array( array: numpy.ndarray, padded_length: int, pad_value: float = 0.0) -> numpy.ndarray:
360def lpad_array(
361    array: np.ndarray, padded_length: int, pad_value: float = 0.0
362) -> np.ndarray:
363    """pad a 1-d array on the left with pad_value to length `padded_length`"""
364    return pad_array(array, padded_length, pad_value, rpad=False)

pad a 1-d array on the left with pad_value to length padded_length

def rpad_array( array: numpy.ndarray, pad_length: int, pad_value: float = 0.0) -> numpy.ndarray:
367def rpad_array(
368    array: np.ndarray, pad_length: int, pad_value: float = 0.0
369) -> np.ndarray:
370    """pad a 1-d array on the right with pad_value to length `pad_length`"""
371    return pad_array(array, pad_length, pad_value, rpad=True)

pad a 1-d array on the right with pad_value to length pad_length

def get_dict_shapes(d: dict[str, torch.Tensor]) -> dict[str, tuple[int, ...]]:
374def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
375    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
376    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})

given a state dict or cache dict, compute the shapes and put them in a nested dict

def string_dict_shapes(d: dict[str, torch.Tensor]) -> str:
379def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
380    """printable version of get_dict_shapes"""
381    return json.dumps(
382        dotlist_to_nested_dict(
383            {
384                k: str(
385                    tuple(v.shape)
386                )  # to string, since indent won't play nice with tuples
387                for k, v in d.items()
388            }
389        ),
390        indent=2,
391    )

printable version of get_dict_shapes
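
For example, with a toy state dict (parameter names made up): `get_dict_shapes` nests on the dotted keys, and `string_dict_shapes` returns the same structure as indented JSON:

```python
import torch

from muutils.tensor_utils import get_dict_shapes, string_dict_shapes

sd = {
    "encoder.weight": torch.zeros(8, 4),
    "encoder.bias": torch.zeros(8),
}
assert get_dict_shapes(sd) == {"encoder": {"weight": (8, 4), "bias": (8,)}}
print(string_dict_shapes(sd))  # JSON string with the shapes rendered as strings
```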

class StateDictCompareError(builtins.AssertionError):
394class StateDictCompareError(AssertionError):
395    """raised when state dicts don't match"""
396
397    pass

raised when state dicts don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictKeysError(StateDictCompareError):
400class StateDictKeysError(StateDictCompareError):
401    """raised when state dict keys don't match"""
402
403    pass

raised when state dict keys don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictShapeError(StateDictCompareError):
406class StateDictShapeError(StateDictCompareError):
407    """raised when state dict shapes don't match"""
408
409    pass

raised when state dict shapes don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictValueError(StateDictCompareError):
412class StateDictValueError(StateDictCompareError):
413    """raised when state dict values don't match"""
414
415    pass

raised when state dict values don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
def compare_state_dicts( d1: dict, d2: dict, rtol: float = 1e-05, atol: float = 1e-08, verbose: bool = True) -> None:
418def compare_state_dicts(
419    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
420) -> None:
421    """compare two dicts of tensors
422
423    # Parameters:
424
425     - `d1 : dict`
426     - `d2 : dict`
427     - `rtol : float`
428       (defaults to `1e-5`)
429     - `atol : float`
430       (defaults to `1e-8`)
431     - `verbose : bool`
432       (defaults to `True`)
433
434    # Raises:
435
436     - `StateDictKeysError` : keys don't match
437     - `StateDictShapeError` : shapes don't match (but keys do)
438     - `StateDictValueError` : values don't match (but keys and shapes do)
439    """
440    # check keys match
441    d1_keys: set = set(d1.keys())
442    d2_keys: set = set(d2.keys())
443    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
444    keys_diff_1: set = d1_keys - d2_keys
445    keys_diff_2: set = d2_keys - d1_keys
446    # sort sets for easier debugging
447    symmetric_diff = set(sorted(symmetric_diff))
448    keys_diff_1 = set(sorted(keys_diff_1))
449    keys_diff_2 = set(sorted(keys_diff_2))
450    diff_shapes_1: str = (
451        string_dict_shapes({k: d1[k] for k in keys_diff_1})
452        if verbose
453        else "(verbose = False)"
454    )
455    diff_shapes_2: str = (
456        string_dict_shapes({k: d2[k] for k in keys_diff_2})
457        if verbose
458        else "(verbose = False)"
459    )
460    if not len(symmetric_diff) == 0:
461        raise StateDictKeysError(
462            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
463        )
464
465    # check tensors match
466    shape_failed: list[str] = list()
467    vals_failed: list[str] = list()
468    for k, v1 in d1.items():
469        v2 = d2[k]
470        # check shapes first
471        if not v1.shape == v2.shape:
472            shape_failed.append(k)
473        else:
474            # if shapes match, check values
475            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
476                vals_failed.append(k)
477
478    str_shape_failed: str = (
479        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
480    )
481    str_vals_failed: str = (
482        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
483    )
484
485    if not len(shape_failed) == 0:
486        raise StateDictShapeError(
487            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
488        )
489    if not len(vals_failed) == 0:
490        raise StateDictValueError(
491            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
492        )

compare two dicts of tensors

Parameters:

  • d1 : dict
  • d2 : dict
  • rtol : float (defaults to 1e-5)
  • atol : float (defaults to 1e-8)
  • verbose : bool (defaults to True)

Raises: