docs for muutils v0.9.1
View Source on GitHub

muutils.dictmagic

making working with dictionaries easier

  • DefaulterDict: like a defaultdict, but default_factory is passed the key as an argument
  • various methods for working with dotlist-nested dicts, converting to and from them
  • condense_nested_dicts: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
  • condense_tensor_dict: convert a dictionary of tensors to a dictionary of shapes
  • kwargs_to_nested_dict: given kwargs from fire, convert them to a nested dict

  1"""making working with dictionaries easier
  2
  3- `DefaulterDict`: like a defaultdict, but default_factory is passed the key as an argument
  4- various methods for working with dotlist-nested dicts, converting to and from them
  5- `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
  6- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes
  7- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict
  8"""
  9
 10from __future__ import annotations
 11
 12import typing
 13import warnings
 14from collections import defaultdict
 15from typing import (
 16    Any,
 17    Callable,
 18    Generic,
 19    Hashable,
 20    Iterable,
 21    Literal,
 22    Optional,
 23    TypeVar,
 24    Union,
 25)
 26
 27from muutils.errormode import ErrorMode
 28
# type variables for the key/value generics of `DefaulterDict`
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
 31
 32
 33class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
 34    """like a defaultdict, but default_factory is passed the key as an argument"""
 35
 36    def __init__(
 37        self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any
 38    ) -> None:
 39        if args:
 40            raise TypeError(
 41                f"DefaulterDict does not support positional arguments: *args = {args}"
 42            )
 43        super().__init__(**kwargs)
 44        self.default_factory: Callable[[_KT], _VT] = default_factory
 45
 46    def __getitem__(self, k: _KT) -> _VT:
 47        if k in self:
 48            return dict.__getitem__(self, k)
 49        else:
 50            v: _VT = self.default_factory(k)
 51            dict.__setitem__(self, k, v)
 52            return v
 53
 54
def _recursive_defaultdict_ctor() -> defaultdict:
    """factory for an arbitrarily-deep nested `defaultdict` -- each missing key creates another level"""
    return defaultdict(_recursive_defaultdict_ctor)
 57
 58
 59def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
 60    """Convert a defaultdict or DefaulterDict to a normal dict, recursively"""
 61    return {
 62        key: (
 63            defaultdict_to_dict_recursive(value)
 64            if isinstance(value, (defaultdict, DefaulterDict))
 65            else value
 66        )
 67        for key, value in dd.items()
 68    }
 69
 70
 71def dotlist_to_nested_dict(
 72    dot_dict: typing.Dict[str, Any], sep: str = "."
 73) -> typing.Dict[str, Any]:
 74    """Convert a dict with dot-separated keys to a nested dict
 75
 76    Example:
 77
 78        >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
 79        {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
 80    """
 81    nested_dict: defaultdict = _recursive_defaultdict_ctor()
 82    for key, value in dot_dict.items():
 83        if not isinstance(key, str):
 84            raise TypeError(f"key must be a string, got {type(key)}")
 85        keys: list[str] = key.split(sep)
 86        current: defaultdict = nested_dict
 87        # iterate over the keys except the last one
 88        for sub_key in keys[:-1]:
 89            current = current[sub_key]
 90        current[keys[-1]] = value
 91    return defaultdict_to_dict_recursive(nested_dict)
 92
 93
 94def nested_dict_to_dotlist(
 95    nested_dict: typing.Dict[str, Any],
 96    sep: str = ".",
 97    allow_lists: bool = False,
 98) -> dict[str, Any]:
 99    def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]:
100        items: dict = dict()
101
102        new_key: str
103        if isinstance(current, dict):
104            # dict case
105            if not current and parent_key:
106                items[parent_key] = current
107            else:
108                for k, v in current.items():
109                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
110                    items.update(_recurse(v, new_key))
111
112        elif allow_lists and isinstance(current, list):
113            # list case
114            for i, item in enumerate(current):
115                new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
116                items.update(_recurse(item, new_key))
117
118        else:
119            # anything else (write value)
120            items[parent_key] = current
121
122        return items
123
124    return _recurse(nested_dict)
125
126
def update_with_nested_dict(
    original: dict[str, Any],
    update: dict[str, Any],
) -> dict[str, Any]:
    """Update a dict with a nested dict

    Example:
    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
    {'a': {'b': 2}, 'c': -1}

    # Arguments
    - `original: dict[str, Any]`
        the dict to update (will be modified in-place)
    - `update: dict[str, Any]`
        the dict to update with

    # Returns
    - `dict`
        the updated dict
    """
    for key, new_value in update.items():
        existing = original.get(key)
        # recurse only when both sides are dicts; otherwise overwrite
        if isinstance(existing, dict) and isinstance(new_value, dict):
            update_with_nested_dict(existing, new_value)
        else:
            original[key] = new_value

    return original
157
158
def kwargs_to_nested_dict(
    kwargs_dict: dict[str, Any],
    sep: str = ".",
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN,
    transform_key: Optional[Callable[[str], str]] = None,
) -> dict[str, Any]:
    """given kwargs from fire, convert them to a nested dict

    if `strip_prefix` is given, keys are expected to start with it; keys that
    don't are handled per `when_unknown_prefix` (warns by default, can be set
    to raise or ignore via `ErrorMode`)

    Example:
    ```python
    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)
    ```
    running the above script will give:
    ```bash
    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    ```

    # Arguments
    - `kwargs_dict: dict[str, Any]`
        the kwargs dict to convert
    - `sep: str = "."`
        the separator to use for nested keys
    - `strip_prefix: Optional[str] = None`
        if not None, then all keys must start with this prefix
    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
        what to do when an unknown prefix is found
    - `transform_key: Callable[[str], str] | None = None`
        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
    """
    error_mode: ErrorMode = ErrorMode.from_any(when_unknown_prefix)
    processed: dict[str, Any] = dict()
    for raw_key, value in kwargs_dict.items():
        key: str = raw_key
        if strip_prefix is not None:
            if key.startswith(strip_prefix):
                key = key[len(strip_prefix) :]
            else:
                # unknown prefix: warn / raise / ignore according to error_mode
                error_mode.process(
                    f"key '{key}' does not start with '{strip_prefix}'",
                    except_cls=ValueError,
                )

        # key transform runs after prefix stripping
        if transform_key is not None:
            key = transform_key(key)

        processed[key] = value

    return dotlist_to_nested_dict(processed, sep=sep)
214
215
def is_numeric_consecutive(lst: list[str]) -> bool:
    """Check if the list of keys is numeric and consecutive."""
    try:
        values: set[int] = {int(x) for x in lst}
        # consecutive iff there are no duplicates and the span equals the count
        # (max() on an empty set raises ValueError, so [] is correctly rejected)
        return len(values) == len(lst) and max(values) - min(values) + 1 == len(lst)
    except ValueError:
        return False
223
224
def condense_nested_dicts_numeric_keys(
    data: dict[str, Any],
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric keys with matching values to ranges

    NOTE: sub-dictionaries of `data` are condensed in place before the
    top-level keys are examined, so the input may be mutated.

    # Examples:
    ```python
    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {"1": {"[1-2]": "a"}, "2": "b"}
    ```
    """

    # non-dict leaves pass through unchanged (recursion base case)
    if not isinstance(data, dict):
        return data

    # Process each sub-dictionary
    # (list() because we overwrite values while iterating)
    for key, value in list(data.items()):
        data[key] = condense_nested_dicts_numeric_keys(value)

    # Find all numeric, consecutive keys
    # only condense when ALL keys form a consecutive integer run
    if is_numeric_consecutive(list(data.keys())):
        keys: list[str] = sorted(data.keys(), key=lambda x: int(x))
    else:
        return data

    # output dict
    condensed_data: dict[str, Any] = {}

    # Identify ranges of identical values and condense
    # two-pointer scan: i marks the run start, j advances over equal values
    i: int = 0
    while i < len(keys):
        j: int = i
        while j + 1 < len(keys) and data[keys[j]] == data[keys[j + 1]]:
            j += 1
        if j > i:  # Found consecutive keys with identical values
            condensed_key: str = f"[{keys[i]}-{keys[j]}]"
            condensed_data[condensed_key] = data[keys[i]]
            i = j + 1
        else:
            # run of length one: keep the key as-is
            condensed_data[keys[i]] = data[keys[i]]
            i += 1

    return condensed_data
270
271
def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing keys with matching values

    keys whose (non-dict) values compare equal are merged into a single
    `"[k1, k2, ...]"` key. unhashable values are either kept as-is, or grouped
    via `val_condense_fallback_mapping` if one is provided (in which case the
    mapped representation becomes the stored value).

    # Parameters:
     - `data : dict[str, Any]`
        data to process
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)

    """

    # non-dict leaves pass through unchanged (recursion base case)
    if not isinstance(data, dict):
        return data

    # condense sub-dictionaries first
    data = {
        k: condense_nested_dicts_matching_values(v, val_condense_fallback_mapping)
        for k, v in data.items()
    }

    # group keys by their (hashable) leaf value; dict values are never merged
    keys_by_value: defaultdict[Any, list[str]] = defaultdict(list)
    passthrough: dict[str, Any] = dict()
    for k, v in data.items():
        if isinstance(v, dict):
            passthrough[k] = v
            continue
        try:
            keys_by_value[v].append(k)
        except TypeError:
            # unhashable value: group via the fallback mapping, or keep as-is
            if val_condense_fallback_mapping is None:
                passthrough[k] = v
            else:
                keys_by_value[val_condense_fallback_mapping(v)].append(k)

    # un-merged entries first, then merged groups in first-seen order
    result: dict[str, Any] = passthrough
    for v, grouped_keys in keys_by_value.items():
        if len(grouped_keys) == 1:
            result[grouped_keys[0]] = v
        else:
            result[f"[{', '.join(grouped_keys)}]"] = v

    return result
324
325
def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric or matching keys with matching values to ranges

    combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()`

    # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
    it's not reversible because types are lost to make the printing pretty

    # Parameters:
     - `data : dict[str, Any]`
        data to process
     - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
       (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values
       (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
       (defaults to `None`)

    """
    result: dict = data
    # first pass: collapse runs of consecutive numeric keys into "[lo-hi]" ranges
    if condense_numeric_keys:
        result = condense_nested_dicts_numeric_keys(result)
    # second pass: merge keys whose values compare equal
    if condense_matching_values:
        result = condense_nested_dicts_matching_values(
            result, val_condense_fallback_mapping
        )
    return result
362
363
def tuple_dims_replace(
    t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None
) -> tuple[Union[int, str], ...]:
    """replace dimension sizes in `t` with names from `dims_names_map`, leaving unmapped sizes unchanged"""
    if dims_names_map is None:
        # no mapping: return the tuple untouched
        return t
    return tuple(dims_names_map.get(dim, dim) for dim in t)
371
372
# type alias: mapping from (dot-separated) names to tensors/arrays
TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"]  # type: ignore[name-defined] # noqa: F821
# type alias: iterable of (name, tensor) pairs, e.g. the result of `dict(...).items()`
TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]]  # type: ignore[name-defined] # noqa: F821
# output formats accepted by `condense_tensor_dict`
TensorDictFormats = Literal["dict", "json", "yaml", "yml"]
376
377
378def _default_shapes_convert(x: tuple) -> str:
379    return str(x).replace('"', "").replace("'", "")
380
381
def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args: Any,
    shapes_convert: Callable[
        [tuple[Union[int, str], ...]], Any
    ] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
     - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` dict from strings to tensors, or a `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items())` )
     - `fmt : TensorDictFormats`
        format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed.
        (defaults to `'dict'`)
     - `shapes_convert : Callable[[tuple], Any]`
        conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes)
        (defaults to `lambda x: str(x).replace('"', '').replace("'", '')`)
     - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
     - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
     - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
     - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
     - `return_format : TensorDictFormats | None`
        legacy alias for `fmt` kwarg

    # Returns:
     - `str|dict[str, str|tuple[int, ...]]`
        dict if `return_format='dict'`, a string for `json` or `yaml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
     - `ValueError` :  if `return_format` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
    """

    # handle arg processing:
    # ----------------------------------------------------------------------
    # make all args except data and format keyword-only
    assert len(args) == 0, f"unexpected positional args: {args}"
    # handle legacy return_format (deprecated, overrides fmt when given)
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
        )
        fmt = return_format

    # identity function for shapes_convert if not provided
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # convert to iterable (duck-typed: anything with a callable `.items()` is treated as a mapping)
    data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # get shapes: drop leading batch dims, rename dims, then convert
    data_shapes: dict[str, Union[str, tuple[int, ...]]] = {  # pyright: ignore[reportAssignmentType]
        k: shapes_convert(
            tuple_dims_replace(
                tuple(v.shape)[drop_batch_dims:],
                dims_names_map,
            )
        )
        for k, v in data_items
    }

    # nest the dict (keys are assumed to be `sep`-separated paths)
    data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep)

    # condense the nested dict
    data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=data_nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # return in the specified format
    fmt_lower: str = fmt.lower()
    if fmt_lower == "dict":
        return data_condensed
    elif fmt_lower == "json":
        import json

        return json.dumps(data_condensed, indent=2)
    elif fmt_lower in ["yaml", "yml"]:
        try:
            import yaml  # type: ignore[import-untyped]

            return yaml.dump(data_condensed, sort_keys=False)
        except ImportError as e:
            # surface a ValueError (documented above) rather than an ImportError
            raise ValueError("PyYAML is required for YAML output") from e
    else:
        raise ValueError(f"Invalid return format: {fmt}")

class DefaulterDict(typing.Dict[~_KT, ~_VT], typing.Generic[~_KT, ~_VT]):
34class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
35    """like a defaultdict, but default_factory is passed the key as an argument"""
36
37    def __init__(
38        self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any
39    ) -> None:
40        if args:
41            raise TypeError(
42                f"DefaulterDict does not support positional arguments: *args = {args}"
43            )
44        super().__init__(**kwargs)
45        self.default_factory: Callable[[_KT], _VT] = default_factory
46
47    def __getitem__(self, k: _KT) -> _VT:
48        if k in self:
49            return dict.__getitem__(self, k)
50        else:
51            v: _VT = self.default_factory(k)
52            dict.__setitem__(self, k, v)
53            return v

like a defaultdict, but default_factory is passed the key as an argument

default_factory: Callable[[~_KT], ~_VT]
Inherited Members
builtins.dict
get
setdefault
pop
popitem
keys
items
values
update
fromkeys
clear
copy
def defaultdict_to_dict_recursive( dd: Union[collections.defaultdict, DefaulterDict]) -> dict:
60def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
61    """Convert a defaultdict or DefaulterDict to a normal dict, recursively"""
62    return {
63        key: (
64            defaultdict_to_dict_recursive(value)
65            if isinstance(value, (defaultdict, DefaulterDict))
66            else value
67        )
68        for key, value in dd.items()
69    }

Convert a defaultdict or DefaulterDict to a normal dict, recursively

def dotlist_to_nested_dict(dot_dict: Dict[str, Any], sep: str = '.') -> Dict[str, Any]:
72def dotlist_to_nested_dict(
73    dot_dict: typing.Dict[str, Any], sep: str = "."
74) -> typing.Dict[str, Any]:
75    """Convert a dict with dot-separated keys to a nested dict
76
77    Example:
78
79        >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
80        {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
81    """
82    nested_dict: defaultdict = _recursive_defaultdict_ctor()
83    for key, value in dot_dict.items():
84        if not isinstance(key, str):
85            raise TypeError(f"key must be a string, got {type(key)}")
86        keys: list[str] = key.split(sep)
87        current: defaultdict = nested_dict
88        # iterate over the keys except the last one
89        for sub_key in keys[:-1]:
90            current = current[sub_key]
91        current[keys[-1]] = value
92    return defaultdict_to_dict_recursive(nested_dict)

Convert a dict with dot-separated keys to a nested dict

Example:

>>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
def nested_dict_to_dotlist( nested_dict: Dict[str, Any], sep: str = '.', allow_lists: bool = False) -> dict[str, typing.Any]:
 95def nested_dict_to_dotlist(
 96    nested_dict: typing.Dict[str, Any],
 97    sep: str = ".",
 98    allow_lists: bool = False,
 99) -> dict[str, Any]:
100    def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]:
101        items: dict = dict()
102
103        new_key: str
104        if isinstance(current, dict):
105            # dict case
106            if not current and parent_key:
107                items[parent_key] = current
108            else:
109                for k, v in current.items():
110                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
111                    items.update(_recurse(v, new_key))
112
113        elif allow_lists and isinstance(current, list):
114            # list case
115            for i, item in enumerate(current):
116                new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
117                items.update(_recurse(item, new_key))
118
119        else:
120            # anything else (write value)
121            items[parent_key] = current
122
123        return items
124
125    return _recurse(nested_dict)
def update_with_nested_dict( original: dict[str, typing.Any], update: dict[str, typing.Any]) -> dict[str, typing.Any]:
128def update_with_nested_dict(
129    original: dict[str, Any],
130    update: dict[str, Any],
131) -> dict[str, Any]:
132    """Update a dict with a nested dict
133
134    Example:
135    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
136    {'a': {'b': 2}, 'c': -1}
137
138    # Arguments
139    - `original: dict[str, Any]`
140        the dict to update (will be modified in-place)
141    - `update: dict[str, Any]`
142        the dict to update with
143
144    # Returns
145    - `dict`
146        the updated dict
147    """
148    for key, value in update.items():
149        if key in original:
150            if isinstance(original[key], dict) and isinstance(value, dict):
151                update_with_nested_dict(original[key], value)
152            else:
153                original[key] = value
154        else:
155            original[key] = value
156
157    return original

Update a dict with a nested dict

Example:

>>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
{'a': {'b': 2}, 'c': -1}

Arguments

  • original: dict[str, Any] the dict to update (will be modified in-place)
  • update: dict[str, Any] the dict to update with

Returns

  • dict the updated dict
def kwargs_to_nested_dict( kwargs_dict: dict[str, typing.Any], sep: str = '.', strip_prefix: Optional[str] = None, when_unknown_prefix: Union[muutils.errormode.ErrorMode, str] = ErrorMode.Warn, transform_key: Optional[Callable[[str], str]] = None) -> dict[str, typing.Any]:
160def kwargs_to_nested_dict(
161    kwargs_dict: dict[str, Any],
162    sep: str = ".",
163    strip_prefix: Optional[str] = None,
164    when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN,
165    transform_key: Optional[Callable[[str], str]] = None,
166) -> dict[str, Any]:
167    """given kwargs from fire, convert them to a nested dict
168
169    if strip_prefix is not None, then all keys must start with the prefix. by default,
170    will warn if an unknown prefix is found, but can be set to raise an error or ignore it:
171    `when_unknown_prefix: ErrorMode`
172
173    Example:
174    ```python
175    def main(**kwargs):
176        print(kwargs_to_nested_dict(kwargs))
177    fire.Fire(main)
178    ```
179    running the above script will give:
180    ```bash
181    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
182    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
183    ```
184
185    # Arguments
186    - `kwargs_dict: dict[str, Any]`
187        the kwargs dict to convert
188    - `sep: str = "."`
189        the separator to use for nested keys
190    - `strip_prefix: Optional[str] = None`
191        if not None, then all keys must start with this prefix
192    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
193        what to do when an unknown prefix is found
194    - `transform_key: Callable[[str], str] | None = None`
195        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
196    """
197    when_unknown_prefix_ = ErrorMode.from_any(when_unknown_prefix)
198    filtered_kwargs: dict[str, Any] = dict()
199    for key, value in kwargs_dict.items():
200        if strip_prefix is not None:
201            if not key.startswith(strip_prefix):
202                when_unknown_prefix_.process(
203                    f"key '{key}' does not start with '{strip_prefix}'",
204                    except_cls=ValueError,
205                )
206            else:
207                key = key[len(strip_prefix) :]
208
209        if transform_key is not None:
210            key = transform_key(key)
211
212        filtered_kwargs[key] = value
213
214    return dotlist_to_nested_dict(filtered_kwargs, sep=sep)

given kwargs from fire, convert them to a nested dict

if strip_prefix is not None, then all keys must start with the prefix. by default, will warn if an unknown prefix is found, but can be set to raise an error or ignore it: when_unknown_prefix: ErrorMode

Example:

def main(**kwargs):
    print(kwargs_to_nested_dict(kwargs))
fire.Fire(main)

running the above script will give:

$ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}

Arguments

  • kwargs_dict: dict[str, Any] the kwargs dict to convert
  • sep: str = "." the separator to use for nested keys
  • strip_prefix: Optional[str] = None if not None, then all keys must start with this prefix
  • when_unknown_prefix: ErrorMode = ErrorMode.WARN what to do when an unknown prefix is found
  • transform_key: Callable[[str], str] | None = None a function to apply to each key before adding it to the dict (applied after stripping the prefix)
def is_numeric_consecutive(lst: list[str]) -> bool:
217def is_numeric_consecutive(lst: list[str]) -> bool:
218    """Check if the list of keys is numeric and consecutive."""
219    try:
220        numbers: list[int] = [int(x) for x in lst]
221        return sorted(numbers) == list(range(min(numbers), max(numbers) + 1))
222    except ValueError:
223        return False

Check if the list of keys is numeric and consecutive.

def condense_nested_dicts_numeric_keys(data: dict[str, typing.Any]) -> dict[str, typing.Any]:
226def condense_nested_dicts_numeric_keys(
227    data: dict[str, Any],
228) -> dict[str, Any]:
229    """condense a nested dict, by condensing numeric keys with matching values to ranges
230
231    # Examples:
232    ```python
233    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
234    {'[1-3]': 1, '[4-6]': 2}
235    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
236    {"1": {"[1-2]": "a"}, "2": "b"}
237    ```
238    """
239
240    if not isinstance(data, dict):
241        return data
242
243    # Process each sub-dictionary
244    for key, value in list(data.items()):
245        data[key] = condense_nested_dicts_numeric_keys(value)
246
247    # Find all numeric, consecutive keys
248    if is_numeric_consecutive(list(data.keys())):
249        keys: list[str] = sorted(data.keys(), key=lambda x: int(x))
250    else:
251        return data
252
253    # output dict
254    condensed_data: dict[str, Any] = {}
255
256    # Identify ranges of identical values and condense
257    i: int = 0
258    while i < len(keys):
259        j: int = i
260        while j + 1 < len(keys) and data[keys[j]] == data[keys[j + 1]]:
261            j += 1
262        if j > i:  # Found consecutive keys with identical values
263            condensed_key: str = f"[{keys[i]}-{keys[j]}]"
264            condensed_data[condensed_key] = data[keys[i]]
265            i = j + 1
266        else:
267            condensed_data[keys[i]] = data[keys[i]]
268            i += 1
269
270    return condensed_data

condense a nested dict, by condensing numeric keys with matching values to ranges

Examples:

>>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
{'[1-3]': 1, '[4-6]': 2}
>>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
{"1": {"[1-2]": "a"}, "2": "b"}
def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict by merging keys whose (non-dict) values compare equal

    keys with identical hashable leaf values are merged into a single
    `"[key1, key2, ...]"` key; sub-dicts are recursed into but never merged.

    # Examples:
    ```python
    >>> condense_nested_dicts_matching_values({'a': 1, 'b': 1, 'c': 2})
    {'[a, b]': 1, 'c': 2}
    ```

    # Parameters:
     - `data : dict[str, Any]`
        data to process
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)

    """
    # non-dict leaves pass through untouched
    if not isinstance(data, dict):
        return data

    # recurse first so deeper levels are condensed before grouping here
    processed: dict[str, Any] = {
        k: condense_nested_dicts_matching_values(v, val_condense_fallback_mapping)
        for k, v in data.items()
    }

    # group leaf keys by their (hashable) value; anything we can't group is kept as-is
    groups: defaultdict[Any, list[str]] = defaultdict(list)
    kept: dict[str, Any] = dict()
    for k, v in processed.items():
        if isinstance(v, dict):
            # sub-dicts are never merged with each other
            kept[k] = v
            continue
        try:
            groups[v].append(k)
        except TypeError:
            # unhashable value: group via the fallback mapping if given, else keep as-is
            if val_condense_fallback_mapping is None:
                kept[k] = v
            else:
                groups[val_condense_fallback_mapping(v)].append(k)

    # stitch grouped keys together; singleton groups keep their original key
    result: dict[str, Any] = kept
    for v, ks in groups.items():
        if len(ks) > 1:
            result[f"[{', '.join(ks)}]"] = v
        else:
            result[ks[0]] = v

    return result

condense a nested dict, by condensing keys with matching values

Examples: TODO

Parameters:

  • data : dict[str, Any] data to process
  • val_condense_fallback_mapping : Callable[[Any], Hashable] | None a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to None)
def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric or matching keys with matching values to ranges

    combines `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()`

    # NOTE: this process is not meant to be reversible -- types are lost to make
    the output pretty, so it is intended for printing and visualization only

    # Parameters:
     - `data : dict[str, Any]`
        data to process
     - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
       (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values
       (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
       (defaults to `None`)

    """
    # apply each condensing pass in turn, each operating on the previous result
    result: dict = data
    if condense_numeric_keys:
        result = condense_nested_dicts_numeric_keys(result)
    if condense_matching_values:
        result = condense_nested_dicts_matching_values(
            result, val_condense_fallback_mapping
        )
    return result

condense a nested dict, by condensing numeric or matching keys with matching values to ranges

combines the functionality of condense_nested_dicts_numeric_keys() and condense_nested_dicts_matching_values()

NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes

it's not reversible because types are lost to make the printing pretty

Parameters:

  • data : dict[str, Any] data to process
  • condense_numeric_keys : bool whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") (defaults to True)
  • condense_matching_values : bool whether to condense keys with matching values (defaults to True)
  • val_condense_fallback_mapping : Callable[[Any], Hashable] | None a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to None)
def tuple_dims_replace(
    t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None
) -> tuple[Union[int, str], ...]:
    """substitute dimension sizes in `t` with names from `dims_names_map`

    each element of `t` that appears as a key in `dims_names_map` is replaced
    by its mapped name; other elements are kept as-is. with no mapping given,
    `t` is returned unchanged.
    """
    if dims_names_map is None:
        return t
    replaced: list[Union[int, str]] = [dims_names_map.get(dim, dim) for dim in t]
    return tuple(replaced)
TensorDict = typing.Dict[str, ForwardRef('torch.Tensor|np.ndarray')]
TensorIterable = typing.Iterable[typing.Tuple[str, ForwardRef('torch.Tensor|np.ndarray')]]
TensorDictFormats = typing.Literal['dict', 'json', 'yaml', 'yml']
def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args: Any,
    shapes_convert: Callable[
        [tuple[Union[int, str], ...]], Any
    ] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shape tuples, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
     - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` dict from strings to tensors, or a `TensorIterable`
        iterable of (key, tensor) pairs (like you might get from `dict().items()`)
     - `fmt : TensorDictFormats`
        format to return the result in -- a dict, or dump to json/yaml directly
        for pretty printing. will crash if yaml is not installed.
        (defaults to `'dict'`)
     - `shapes_convert : Callable[[tuple], Any]`
        conversion of a shape tuple to a string or other format; `None` means identity
        (defaults to stringifying the tuple and stripping quote characters)
     - `drop_batch_dims : int`
        number of leading dimensions to drop from each shape
        (defaults to `0`)
     - `sep : str`
        separator used to split keys into a nested dict
        (defaults to `'.'`)
     - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
     - `condense_numeric_keys : bool`
        whether to condense numeric keys to ranges, passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        fallback hashable mapping for unhashable values, passed on to `condense_nested_dicts`
        (defaults to `None`)
     - `return_format : TensorDictFormats | None`
        legacy alias for the `fmt` kwarg

    # Returns:
     - `str|dict[str, str|tuple[int, ...]]`
        a dict for `fmt='dict'`, a string for `'json'` or `'yaml'` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
     - `ValueError` : if `fmt` is not one of 'dict', 'json', 'yaml', or 'yml',
       or if yaml output is requested without PyYAML installed
    """

    # arg processing
    # ----------------------------------------------------------------------
    # everything after (data, fmt) must be keyword-only
    assert len(args) == 0, f"unexpected positional args: {args}"

    # `return_format` is a deprecated alias for `fmt`
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
        )
        fmt = return_format

    # `shapes_convert=None` means "keep the raw shape tuples"
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # normalize input to an iterable of (name, tensor) pairs
    pairs: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # flat dict of dotted name -> converted shape (leading batch dims dropped,
    # dimension sizes optionally renamed)
    shapes_flat: dict[str, Union[str, tuple[int, ...]]] = {}
    for name, tensor in pairs:
        trimmed_shape = tuple(tensor.shape)[drop_batch_dims:]
        shapes_flat[name] = shapes_convert(  # pyright: ignore[reportAssignmentType]
            tuple_dims_replace(trimmed_shape, dims_names_map)
        )

    # nest by the separator, then condense for pretty printing
    nested: dict[str, Any] = dotlist_to_nested_dict(shapes_flat, sep=sep)
    condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # emit in the requested format
    output_fmt: str = fmt.lower()
    if output_fmt == "dict":
        return condensed
    if output_fmt == "json":
        import json

        return json.dumps(condensed, indent=2)
    if output_fmt in ("yaml", "yml"):
        try:
            import yaml  # type: ignore[import-untyped]
        except ImportError as e:
            raise ValueError("PyYAML is required for YAML output") from e
        return yaml.dump(condensed, sort_keys=False)
    raise ValueError(f"Invalid return format: {fmt}")

Convert a dictionary of tensors to a dictionary of shapes.

by default, values are converted to strings of their shapes (for nice printing). If you want the actual shapes, set shapes_convert = lambda x: x or shapes_convert = None.

Parameters:

  • data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]] either a TensorDict dict from strings to tensors, or a TensorIterable iterable of (key, tensor) pairs (like you might get from `dict().items()`)
  • fmt : TensorDictFormats format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. (defaults to 'dict')
  • shapes_convert : Callable[[tuple], Any] conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes) (defaults to lambdax:str(x).replace('"', '').replace("'", ''))
  • drop_batch_dims : int number of leading dimensions to drop from the shape (defaults to 0)
  • sep : str separator to use for nested keys (defaults to '.')
  • dims_names_map : dict[int, str] | None convert certain dimension values in shape. not perfect, can be buggy (defaults to None)
  • condense_numeric_keys : bool whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to condense_nested_dicts (defaults to True)
  • condense_matching_values : bool whether to condense keys with matching values, passed on to condense_nested_dicts (defaults to True)
  • val_condense_fallback_mapping : Callable[[Any], Hashable] | None a function to apply to each value before adding it to the dict (if it's not hashable), passed on to condense_nested_dicts (defaults to None)
  • return_format : TensorDictFormats | None legacy alias for fmt kwarg

Returns:

  • str|dict[str, str|tuple[int, ...]] dict if return_format='dict', a string for json or yaml output

Examples:

>>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
>>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml'))
embed:
  W_E: (50257, 768)
pos_embed:
  W_pos: (1024, 768)
blocks:
  '[0-11]':
    attn:
      '[W_Q, W_K, W_V]': (12, 768, 64)
      W_O: (12, 64, 768)
      '[b_Q, b_K, b_V]': (12, 64)
      b_O: (768,)
    mlp:
      W_in: (768, 3072)
      b_in: (3072,)
      W_out: (3072, 768)
      b_out: (768,)
unembed:
  W_U: (768, 50257)
  b_U: (50257,)

Raises:

  • ValueError : if return_format is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed