muutils.dictmagic
making working with dictionaries easier
DefaulterDict: like a defaultdict, but default_factory is passed the key as an argument; various methods for working with dotlist-nested dicts, converting to and from them
condense_nested_dicts: condense a nested dict, by condensing numeric or matching keys with matching values to ranges; condense_tensor_dict: convert a dictionary of tensors to a dictionary of shapes; kwargs_to_nested_dict: given kwargs from fire, convert them to a nested dict
1"""making working with dictionaries easier 2 3- `DefaulterDict`: like a defaultdict, but default_factory is passed the key as an argument 4- various methods for working wit dotlist-nested dicts, converting to and from them 5- `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges 6- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes 7- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict 8""" 9 10from __future__ import annotations 11 12import typing 13import warnings 14from collections import defaultdict 15from typing import ( 16 Any, 17 Callable, 18 Generic, 19 Hashable, 20 Iterable, 21 Literal, 22 Optional, 23 TypeVar, 24 Union, 25) 26 27from muutils.errormode import ErrorMode 28 29_KT = TypeVar("_KT") 30_VT = TypeVar("_VT") 31 32 33class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]): 34 """like a defaultdict, but default_factory is passed the key as an argument""" 35 36 def __init__( 37 self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any 38 ) -> None: 39 if args: 40 raise TypeError( 41 f"DefaulterDict does not support positional arguments: *args = {args}" 42 ) 43 super().__init__(**kwargs) 44 self.default_factory: Callable[[_KT], _VT] = default_factory 45 46 def __getitem__(self, k: _KT) -> _VT: 47 if k in self: 48 return dict.__getitem__(self, k) 49 else: 50 v: _VT = self.default_factory(k) 51 dict.__setitem__(self, k, v) 52 return v 53 54 55def _recursive_defaultdict_ctor() -> defaultdict: 56 return defaultdict(_recursive_defaultdict_ctor) 57 58 59def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict: 60 """Convert a defaultdict or DefaulterDict to a normal dict, recursively""" 61 return { 62 key: ( 63 defaultdict_to_dict_recursive(value) 64 if isinstance(value, (defaultdict, DefaulterDict)) 65 else value 66 ) 67 for key, value in dd.items() 68 } 69 70 71def dotlist_to_nested_dict( 72 
dot_dict: typing.Dict[str, Any], sep: str = "." 73) -> typing.Dict[str, Any]: 74 """Convert a dict with dot-separated keys to a nested dict 75 76 Example: 77 78 >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}) 79 {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}} 80 """ 81 nested_dict: defaultdict = _recursive_defaultdict_ctor() 82 for key, value in dot_dict.items(): 83 if not isinstance(key, str): 84 raise TypeError(f"key must be a string, got {type(key)}") 85 keys: list[str] = key.split(sep) 86 current: defaultdict = nested_dict 87 # iterate over the keys except the last one 88 for sub_key in keys[:-1]: 89 current = current[sub_key] 90 current[keys[-1]] = value 91 return defaultdict_to_dict_recursive(nested_dict) 92 93 94def nested_dict_to_dotlist( 95 nested_dict: typing.Dict[str, Any], 96 sep: str = ".", 97 allow_lists: bool = False, 98) -> dict[str, Any]: 99 def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]: 100 items: dict = dict() 101 102 new_key: str 103 if isinstance(current, dict): 104 # dict case 105 if not current and parent_key: 106 items[parent_key] = current 107 else: 108 for k, v in current.items(): 109 new_key = f"{parent_key}{sep}{k}" if parent_key else k 110 items.update(_recurse(v, new_key)) 111 112 elif allow_lists and isinstance(current, list): 113 # list case 114 for i, item in enumerate(current): 115 new_key = f"{parent_key}{sep}{i}" if parent_key else str(i) 116 items.update(_recurse(item, new_key)) 117 118 else: 119 # anything else (write value) 120 items[parent_key] = current 121 122 return items 123 124 return _recurse(nested_dict) 125 126 127def update_with_nested_dict( 128 original: dict[str, Any], 129 update: dict[str, Any], 130) -> dict[str, Any]: 131 """Update a dict with a nested dict 132 133 Example: 134 >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}}) 135 {'a': {'b': 2}, 'c': -1} 136 137 # Arguments 138 - `original: dict[str, Any]` 139 the dict to update (will be modified in-place) 
140 - `update: dict[str, Any]` 141 the dict to update with 142 143 # Returns 144 - `dict` 145 the updated dict 146 """ 147 for key, value in update.items(): 148 if key in original: 149 if isinstance(original[key], dict) and isinstance(value, dict): 150 update_with_nested_dict(original[key], value) 151 else: 152 original[key] = value 153 else: 154 original[key] = value 155 156 return original 157 158 159def kwargs_to_nested_dict( 160 kwargs_dict: dict[str, Any], 161 sep: str = ".", 162 strip_prefix: Optional[str] = None, 163 when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN, 164 transform_key: Optional[Callable[[str], str]] = None, 165) -> dict[str, Any]: 166 """given kwargs from fire, convert them to a nested dict 167 168 if strip_prefix is not None, then all keys must start with the prefix. by default, 169 will warn if an unknown prefix is found, but can be set to raise an error or ignore it: 170 `when_unknown_prefix: ErrorMode` 171 172 Example: 173 ```python 174 def main(**kwargs): 175 print(kwargs_to_nested_dict(kwargs)) 176 fire.Fire(main) 177 ``` 178 running the above script will give: 179 ```bash 180 $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3 181 {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}} 182 ``` 183 184 # Arguments 185 - `kwargs_dict: dict[str, Any]` 186 the kwargs dict to convert 187 - `sep: str = "."` 188 the separator to use for nested keys 189 - `strip_prefix: Optional[str] = None` 190 if not None, then all keys must start with this prefix 191 - `when_unknown_prefix: ErrorMode = ErrorMode.WARN` 192 what to do when an unknown prefix is found 193 - `transform_key: Callable[[str], str] | None = None` 194 a function to apply to each key before adding it to the dict (applied after stripping the prefix) 195 """ 196 when_unknown_prefix_ = ErrorMode.from_any(when_unknown_prefix) 197 filtered_kwargs: dict[str, Any] = dict() 198 for key, value in kwargs_dict.items(): 199 if strip_prefix is not None: 200 if not key.startswith(strip_prefix): 201 
when_unknown_prefix_.process( 202 f"key '{key}' does not start with '{strip_prefix}'", 203 except_cls=ValueError, 204 ) 205 else: 206 key = key[len(strip_prefix) :] 207 208 if transform_key is not None: 209 key = transform_key(key) 210 211 filtered_kwargs[key] = value 212 213 return dotlist_to_nested_dict(filtered_kwargs, sep=sep) 214 215 216def is_numeric_consecutive(lst: list[str]) -> bool: 217 """Check if the list of keys is numeric and consecutive.""" 218 try: 219 numbers: list[int] = [int(x) for x in lst] 220 return sorted(numbers) == list(range(min(numbers), max(numbers) + 1)) 221 except ValueError: 222 return False 223 224 225def condense_nested_dicts_numeric_keys( 226 data: dict[str, Any], 227) -> dict[str, Any]: 228 """condense a nested dict, by condensing numeric keys with matching values to ranges 229 230 # Examples: 231 ```python 232 >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2}) 233 {'[1-3]': 1, '[4-6]': 2} 234 >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'}) 235 {"1": {"[1-2]": "a"}, "2": "b"} 236 ``` 237 """ 238 239 if not isinstance(data, dict): 240 return data 241 242 # Process each sub-dictionary 243 for key, value in list(data.items()): 244 data[key] = condense_nested_dicts_numeric_keys(value) 245 246 # Find all numeric, consecutive keys 247 if is_numeric_consecutive(list(data.keys())): 248 keys: list[str] = sorted(data.keys(), key=lambda x: int(x)) 249 else: 250 return data 251 252 # output dict 253 condensed_data: dict[str, Any] = {} 254 255 # Identify ranges of identical values and condense 256 i: int = 0 257 while i < len(keys): 258 j: int = i 259 while j + 1 < len(keys) and data[keys[j]] == data[keys[j + 1]]: 260 j += 1 261 if j > i: # Found consecutive keys with identical values 262 condensed_key: str = f"[{keys[i]}-{keys[j]}]" 263 condensed_data[condensed_key] = data[keys[i]] 264 i = j + 1 265 else: 266 condensed_data[keys[i]] = data[keys[i]] 267 i += 1 268 269 return 
condensed_data 270 271 272def condense_nested_dicts_matching_values( 273 data: dict[str, Any], 274 val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, 275) -> dict[str, Any]: 276 """condense a nested dict, by condensing keys with matching values 277 278 # Examples: TODO 279 280 # Parameters: 281 - `data : dict[str, Any]` 282 data to process 283 - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` 284 a function to apply to each value before adding it to the dict (if it's not hashable) 285 (defaults to `None`) 286 287 """ 288 289 if isinstance(data, dict): 290 data = { 291 key: condense_nested_dicts_matching_values( 292 value, val_condense_fallback_mapping 293 ) 294 for key, value in data.items() 295 } 296 else: 297 return data 298 299 # Find all identical values and condense by stitching together keys 300 values_grouped: defaultdict[Any, list[str]] = defaultdict(list) 301 data_persist: dict[str, Any] = dict() 302 for key, value in data.items(): 303 if not isinstance(value, dict): 304 try: 305 values_grouped[value].append(key) 306 except TypeError: 307 # If the value is unhashable, use a fallback mapping to find a hashable representation 308 if val_condense_fallback_mapping is not None: 309 values_grouped[val_condense_fallback_mapping(value)].append(key) 310 else: 311 data_persist[key] = value 312 else: 313 data_persist[key] = value 314 315 condensed_data = data_persist 316 for value, keys in values_grouped.items(): 317 if len(keys) > 1: 318 merged_key = f"[{', '.join(keys)}]" # Choose an appropriate method to represent merged keys 319 condensed_data[merged_key] = value 320 else: 321 condensed_data[keys[0]] = value 322 323 return condensed_data 324 325 326def condense_nested_dicts( 327 data: dict[str, Any], 328 condense_numeric_keys: bool = True, 329 condense_matching_values: bool = True, 330 val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, 331) -> dict[str, Any]: 332 """condense a nested dict, by 
condensing numeric or matching keys with matching values to ranges 333 334 combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()` 335 336 # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes 337 it's not reversible because types are lost to make the printing pretty 338 339 # Parameters: 340 - `data : dict[str, Any]` 341 data to process 342 - `condense_numeric_keys : bool` 343 whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") 344 (defaults to `True`) 345 - `condense_matching_values : bool` 346 whether to condense keys with matching values 347 (defaults to `True`) 348 - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` 349 a function to apply to each value before adding it to the dict (if it's not hashable) 350 (defaults to `None`) 351 352 """ 353 354 condensed_data: dict = data 355 if condense_numeric_keys: 356 condensed_data = condense_nested_dicts_numeric_keys(condensed_data) 357 if condense_matching_values: 358 condensed_data = condense_nested_dicts_matching_values( 359 condensed_data, val_condense_fallback_mapping 360 ) 361 return condensed_data 362 363 364def tuple_dims_replace( 365 t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None 366) -> tuple[Union[int, str], ...]: 367 if dims_names_map is None: 368 return t 369 else: 370 return tuple(dims_names_map.get(x, x) for x in t) 371 372 373TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] # noqa: F821 374TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] # noqa: F821 375TensorDictFormats = Literal["dict", "json", "yaml", "yml"] 376 377 378def _default_shapes_convert(x: tuple) -> str: 379 return str(x).replace('"', "").replace("'", "") 380 381 382def condense_tensor_dict( 383 data: TensorDict | TensorIterable, 384 fmt: TensorDictFormats = "dict", 385 
*args: Any, 386 shapes_convert: Callable[ 387 [tuple[Union[int, str], ...]], Any 388 ] = _default_shapes_convert, 389 drop_batch_dims: int = 0, 390 sep: str = ".", 391 dims_names_map: Optional[dict[int, str]] = None, 392 condense_numeric_keys: bool = True, 393 condense_matching_values: bool = True, 394 val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, 395 return_format: Optional[TensorDictFormats] = None, 396) -> Union[str, dict[str, str | tuple[int, ...]]]: 397 """Convert a dictionary of tensors to a dictionary of shapes. 398 399 by default, values are converted to strings of their shapes (for nice printing). 400 If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`. 401 402 # Parameters: 403 - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]` 404 a either a `TensorDict` dict from strings to tensors, or an `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items())` ) 405 - `fmt : TensorDictFormats` 406 format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. 407 (defaults to `'dict'`) 408 - `shapes_convert : Callable[[tuple], Any]` 409 conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes) 410 (defaults to `lambdax:str(x).replace('"', '').replace("'", '')`) 411 - `drop_batch_dims : int` 412 number of leading dimensions to drop from the shape 413 (defaults to `0`) 414 - `sep : str` 415 separator to use for nested keys 416 (defaults to `'.'`) 417 - `dims_names_map : dict[int, str] | None` 418 convert certain dimension values in shape. not perfect, can be buggy 419 (defaults to `None`) 420 - `condense_numeric_keys : bool` 421 whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. 
"[1-3]"), passed on to `condense_nested_dicts` 422 (defaults to `True`) 423 - `condense_matching_values : bool` 424 whether to condense keys with matching values, passed on to `condense_nested_dicts` 425 (defaults to `True`) 426 - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` 427 a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts` 428 (defaults to `None`) 429 - `return_format : TensorDictFormats | None` 430 legacy alias for `fmt` kwarg 431 432 # Returns: 433 - `str|dict[str, str|tuple[int, ...]]` 434 dict if `return_format='dict'`, a string for `json` or `yaml` output 435 436 # Examples: 437 ```python 438 >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2") 439 >>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml')) 440 ``` 441 ```yaml 442 embed: 443 W_E: (50257, 768) 444 pos_embed: 445 W_pos: (1024, 768) 446 blocks: 447 '[0-11]': 448 attn: 449 '[W_Q, W_K, W_V]': (12, 768, 64) 450 W_O: (12, 64, 768) 451 '[b_Q, b_K, b_V]': (12, 64) 452 b_O: (768,) 453 mlp: 454 W_in: (768, 3072) 455 b_in: (3072,) 456 W_out: (3072, 768) 457 b_out: (768,) 458 unembed: 459 W_U: (768, 50257) 460 b_U: (50257,) 461 ``` 462 463 # Raises: 464 - `ValueError` : if `return_format` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed 465 """ 466 467 # handle arg processing: 468 # ---------------------------------------------------------------------- 469 # make all args except data and format keyword-only 470 assert len(args) == 0, f"unexpected positional args: {args}" 471 # handle legacy return_format 472 if return_format is not None: 473 warnings.warn( 474 "return_format is deprecated, use fmt instead", 475 DeprecationWarning, 476 ) 477 fmt = return_format 478 479 # identity function for shapes_convert if not provided 480 if shapes_convert is None: 481 shapes_convert = lambda x: x # noqa: E731 482 483 # 
convert to iterable 484 data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = ( # type: ignore # noqa: F821 485 data.items() if hasattr(data, "items") and callable(data.items) else data # type: ignore 486 ) 487 488 # get shapes 489 data_shapes: dict[str, Union[str, tuple[int, ...]]] = { # pyright: ignore[reportAssignmentType] 490 k: shapes_convert( 491 tuple_dims_replace( 492 tuple(v.shape)[drop_batch_dims:], 493 dims_names_map, 494 ) 495 ) 496 for k, v in data_items 497 } 498 499 # nest the dict 500 data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep) 501 502 # condense the nested dict 503 data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts( 504 data=data_nested, 505 condense_numeric_keys=condense_numeric_keys, 506 condense_matching_values=condense_matching_values, 507 val_condense_fallback_mapping=val_condense_fallback_mapping, 508 ) 509 510 # return in the specified format 511 fmt_lower: str = fmt.lower() 512 if fmt_lower == "dict": 513 return data_condensed 514 elif fmt_lower == "json": 515 import json 516 517 return json.dumps(data_condensed, indent=2) 518 elif fmt_lower in ["yaml", "yml"]: 519 try: 520 import yaml # type: ignore[import-untyped] 521 522 return yaml.dump(data_condensed, sort_keys=False) 523 except ImportError as e: 524 raise ValueError("PyYAML is required for YAML output") from e 525 else: 526 raise ValueError(f"Invalid return format: {fmt}")
34class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]): 35 """like a defaultdict, but default_factory is passed the key as an argument""" 36 37 def __init__( 38 self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any 39 ) -> None: 40 if args: 41 raise TypeError( 42 f"DefaulterDict does not support positional arguments: *args = {args}" 43 ) 44 super().__init__(**kwargs) 45 self.default_factory: Callable[[_KT], _VT] = default_factory 46 47 def __getitem__(self, k: _KT) -> _VT: 48 if k in self: 49 return dict.__getitem__(self, k) 50 else: 51 v: _VT = self.default_factory(k) 52 dict.__setitem__(self, k, v) 53 return v
like a defaultdict, but default_factory is passed the key as an argument
Inherited Members
- builtins.dict
- get
- setdefault
- pop
- popitem
- keys
- items
- values
- update
- fromkeys
- clear
- copy
60def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict: 61 """Convert a defaultdict or DefaulterDict to a normal dict, recursively""" 62 return { 63 key: ( 64 defaultdict_to_dict_recursive(value) 65 if isinstance(value, (defaultdict, DefaulterDict)) 66 else value 67 ) 68 for key, value in dd.items() 69 }
Convert a defaultdict or DefaulterDict to a normal dict, recursively
72def dotlist_to_nested_dict( 73 dot_dict: typing.Dict[str, Any], sep: str = "." 74) -> typing.Dict[str, Any]: 75 """Convert a dict with dot-separated keys to a nested dict 76 77 Example: 78 79 >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}) 80 {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}} 81 """ 82 nested_dict: defaultdict = _recursive_defaultdict_ctor() 83 for key, value in dot_dict.items(): 84 if not isinstance(key, str): 85 raise TypeError(f"key must be a string, got {type(key)}") 86 keys: list[str] = key.split(sep) 87 current: defaultdict = nested_dict 88 # iterate over the keys except the last one 89 for sub_key in keys[:-1]: 90 current = current[sub_key] 91 current[keys[-1]] = value 92 return defaultdict_to_dict_recursive(nested_dict)
Convert a dict with dot-separated keys to a nested dict
Example:
>>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
95def nested_dict_to_dotlist( 96 nested_dict: typing.Dict[str, Any], 97 sep: str = ".", 98 allow_lists: bool = False, 99) -> dict[str, Any]: 100 def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]: 101 items: dict = dict() 102 103 new_key: str 104 if isinstance(current, dict): 105 # dict case 106 if not current and parent_key: 107 items[parent_key] = current 108 else: 109 for k, v in current.items(): 110 new_key = f"{parent_key}{sep}{k}" if parent_key else k 111 items.update(_recurse(v, new_key)) 112 113 elif allow_lists and isinstance(current, list): 114 # list case 115 for i, item in enumerate(current): 116 new_key = f"{parent_key}{sep}{i}" if parent_key else str(i) 117 items.update(_recurse(item, new_key)) 118 119 else: 120 # anything else (write value) 121 items[parent_key] = current 122 123 return items 124 125 return _recurse(nested_dict)
128def update_with_nested_dict( 129 original: dict[str, Any], 130 update: dict[str, Any], 131) -> dict[str, Any]: 132 """Update a dict with a nested dict 133 134 Example: 135 >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}}) 136 {'a': {'b': 2}, 'c': -1} 137 138 # Arguments 139 - `original: dict[str, Any]` 140 the dict to update (will be modified in-place) 141 - `update: dict[str, Any]` 142 the dict to update with 143 144 # Returns 145 - `dict` 146 the updated dict 147 """ 148 for key, value in update.items(): 149 if key in original: 150 if isinstance(original[key], dict) and isinstance(value, dict): 151 update_with_nested_dict(original[key], value) 152 else: 153 original[key] = value 154 else: 155 original[key] = value 156 157 return original
Update a dict with a nested dict
Example:
>>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
{'a': {'b': 2}, 'c': -1}
Arguments
`original: dict[str, Any]` — the dict to update (will be modified in-place); `update: dict[str, Any]` — the dict to update with
Returns
`dict` — the updated dict
160def kwargs_to_nested_dict( 161 kwargs_dict: dict[str, Any], 162 sep: str = ".", 163 strip_prefix: Optional[str] = None, 164 when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN, 165 transform_key: Optional[Callable[[str], str]] = None, 166) -> dict[str, Any]: 167 """given kwargs from fire, convert them to a nested dict 168 169 if strip_prefix is not None, then all keys must start with the prefix. by default, 170 will warn if an unknown prefix is found, but can be set to raise an error or ignore it: 171 `when_unknown_prefix: ErrorMode` 172 173 Example: 174 ```python 175 def main(**kwargs): 176 print(kwargs_to_nested_dict(kwargs)) 177 fire.Fire(main) 178 ``` 179 running the above script will give: 180 ```bash 181 $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3 182 {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}} 183 ``` 184 185 # Arguments 186 - `kwargs_dict: dict[str, Any]` 187 the kwargs dict to convert 188 - `sep: str = "."` 189 the separator to use for nested keys 190 - `strip_prefix: Optional[str] = None` 191 if not None, then all keys must start with this prefix 192 - `when_unknown_prefix: ErrorMode = ErrorMode.WARN` 193 what to do when an unknown prefix is found 194 - `transform_key: Callable[[str], str] | None = None` 195 a function to apply to each key before adding it to the dict (applied after stripping the prefix) 196 """ 197 when_unknown_prefix_ = ErrorMode.from_any(when_unknown_prefix) 198 filtered_kwargs: dict[str, Any] = dict() 199 for key, value in kwargs_dict.items(): 200 if strip_prefix is not None: 201 if not key.startswith(strip_prefix): 202 when_unknown_prefix_.process( 203 f"key '{key}' does not start with '{strip_prefix}'", 204 except_cls=ValueError, 205 ) 206 else: 207 key = key[len(strip_prefix) :] 208 209 if transform_key is not None: 210 key = transform_key(key) 211 212 filtered_kwargs[key] = value 213 214 return dotlist_to_nested_dict(filtered_kwargs, sep=sep)
given kwargs from fire, convert them to a nested dict
if strip_prefix is not None, then all keys must start with the prefix. by default,
will warn if an unknown prefix is found, but can be set to raise an error or ignore it:
when_unknown_prefix: ErrorMode
Example:
def main(**kwargs):
print(kwargs_to_nested_dict(kwargs))
fire.Fire(main)
running the above script will give:
$ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
Arguments
`kwargs_dict: dict[str, Any]` — the kwargs dict to convert; `sep: str = "."` — the separator to use for nested keys; `strip_prefix: Optional[str] = None` — if not None, then all keys must start with this prefix; `when_unknown_prefix: ErrorMode = ErrorMode.WARN` — what to do when an unknown prefix is found; `transform_key: Callable[[str], str] | None = None` — a function to apply to each key before adding it to the dict (applied after stripping the prefix)
217def is_numeric_consecutive(lst: list[str]) -> bool: 218 """Check if the list of keys is numeric and consecutive.""" 219 try: 220 numbers: list[int] = [int(x) for x in lst] 221 return sorted(numbers) == list(range(min(numbers), max(numbers) + 1)) 222 except ValueError: 223 return False
Check if the list of keys is numeric and consecutive.
226def condense_nested_dicts_numeric_keys( 227 data: dict[str, Any], 228) -> dict[str, Any]: 229 """condense a nested dict, by condensing numeric keys with matching values to ranges 230 231 # Examples: 232 ```python 233 >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2}) 234 {'[1-3]': 1, '[4-6]': 2} 235 >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'}) 236 {"1": {"[1-2]": "a"}, "2": "b"} 237 ``` 238 """ 239 240 if not isinstance(data, dict): 241 return data 242 243 # Process each sub-dictionary 244 for key, value in list(data.items()): 245 data[key] = condense_nested_dicts_numeric_keys(value) 246 247 # Find all numeric, consecutive keys 248 if is_numeric_consecutive(list(data.keys())): 249 keys: list[str] = sorted(data.keys(), key=lambda x: int(x)) 250 else: 251 return data 252 253 # output dict 254 condensed_data: dict[str, Any] = {} 255 256 # Identify ranges of identical values and condense 257 i: int = 0 258 while i < len(keys): 259 j: int = i 260 while j + 1 < len(keys) and data[keys[j]] == data[keys[j + 1]]: 261 j += 1 262 if j > i: # Found consecutive keys with identical values 263 condensed_key: str = f"[{keys[i]}-{keys[j]}]" 264 condensed_data[condensed_key] = data[keys[i]] 265 i = j + 1 266 else: 267 condensed_data[keys[i]] = data[keys[i]] 268 i += 1 269 270 return condensed_data
condense a nested dict, by condensing numeric keys with matching values to ranges
Examples:
>>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
{'[1-3]': 1, '[4-6]': 2}
>>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
{"1": {"[1-2]": "a"}, "2": "b"}
273def condense_nested_dicts_matching_values( 274 data: dict[str, Any], 275 val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, 276) -> dict[str, Any]: 277 """condense a nested dict, by condensing keys with matching values 278 279 # Examples: TODO 280 281 # Parameters: 282 - `data : dict[str, Any]` 283 data to process 284 - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` 285 a function to apply to each value before adding it to the dict (if it's not hashable) 286 (defaults to `None`) 287 288 """ 289 290 if isinstance(data, dict): 291 data = { 292 key: condense_nested_dicts_matching_values( 293 value, val_condense_fallback_mapping 294 ) 295 for key, value in data.items() 296 } 297 else: 298 return data 299 300 # Find all identical values and condense by stitching together keys 301 values_grouped: defaultdict[Any, list[str]] = defaultdict(list) 302 data_persist: dict[str, Any] = dict() 303 for key, value in data.items(): 304 if not isinstance(value, dict): 305 try: 306 values_grouped[value].append(key) 307 except TypeError: 308 # If the value is unhashable, use a fallback mapping to find a hashable representation 309 if val_condense_fallback_mapping is not None: 310 values_grouped[val_condense_fallback_mapping(value)].append(key) 311 else: 312 data_persist[key] = value 313 else: 314 data_persist[key] = value 315 316 condensed_data = data_persist 317 for value, keys in values_grouped.items(): 318 if len(keys) > 1: 319 merged_key = f"[{', '.join(keys)}]" # Choose an appropriate method to represent merged keys 320 condensed_data[merged_key] = value 321 else: 322 condensed_data[keys[0]] = value 323 324 return condensed_data
condense a nested dict, by condensing keys with matching values
Examples: TODO
Parameters:
`data : dict[str, Any]` — data to process; `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` — a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to `None`)
327def condense_nested_dicts( 328 data: dict[str, Any], 329 condense_numeric_keys: bool = True, 330 condense_matching_values: bool = True, 331 val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, 332) -> dict[str, Any]: 333 """condense a nested dict, by condensing numeric or matching keys with matching values to ranges 334 335 combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()` 336 337 # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes 338 it's not reversible because types are lost to make the printing pretty 339 340 # Parameters: 341 - `data : dict[str, Any]` 342 data to process 343 - `condense_numeric_keys : bool` 344 whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") 345 (defaults to `True`) 346 - `condense_matching_values : bool` 347 whether to condense keys with matching values 348 (defaults to `True`) 349 - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` 350 a function to apply to each value before adding it to the dict (if it's not hashable) 351 (defaults to `None`) 352 353 """ 354 355 condensed_data: dict = data 356 if condense_numeric_keys: 357 condensed_data = condense_nested_dicts_numeric_keys(condensed_data) 358 if condense_matching_values: 359 condensed_data = condense_nested_dicts_matching_values( 360 condensed_data, val_condense_fallback_mapping 361 ) 362 return condensed_data
condense a nested dict, by condensing numeric or matching keys with matching values to ranges
combines the functionality of condense_nested_dicts_numeric_keys() and condense_nested_dicts_matching_values()
NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
it's not reversible because types are lost to make the printing pretty
Parameters:
`data : dict[str, Any]` — data to process; `condense_numeric_keys : bool` — whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") (defaults to `True`); `condense_matching_values : bool` — whether to condense keys with matching values (defaults to `True`); `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` — a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to `None`)
def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args: Any,
    shapes_convert: Callable[
        [tuple[Union[int, str], ...]], Any
    ] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
    - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` dict from strings to tensors, or a `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items()` )
    - `fmt : TensorDictFormats`
        format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed.
        (defaults to `'dict'`)
    - `shapes_convert : Callable[[tuple], Any]`
        conversion of a shape tuple to a string or other format
        (defaults to turning it into a string and removing quotes)
    - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
    - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
    - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
    - `return_format : TensorDictFormats | None`
        legacy alias for `fmt` kwarg

    # Returns:
    - `str|dict[str, str|tuple[int, ...]]`
        dict if `fmt='dict'`, a string for `json` or `yaml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
    - `TypeError` : if positional arguments beyond `data` and `fmt` are passed
    - `ValueError` : if `fmt` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
    """

    # handle arg processing:
    # ----------------------------------------------------------------------
    # make all args except data and fmt keyword-only. raise TypeError (matching
    # DefaulterDict's convention) instead of `assert`, which is stripped under -O
    if args:
        raise TypeError(f"unexpected positional args: {args}")
    # handle legacy return_format
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
        )
        fmt = return_format

    # identity function for shapes_convert if not provided (def, not lambda: E731)
    if shapes_convert is None:

        def shapes_convert(x: Any) -> Any:
            return x

    # convert to an iterable of (key, tensor) pairs
    data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # get shapes, dropping leading batch dims and renaming dims if requested
    data_shapes: dict[str, Union[str, tuple[int, ...]]] = {  # pyright: ignore[reportAssignmentType]
        k: shapes_convert(
            tuple_dims_replace(
                tuple(v.shape)[drop_batch_dims:],
                dims_names_map,
            )
        )
        for k, v in data_items
    }

    # nest the dict by splitting the dotlist keys
    data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep)

    # condense the nested dict
    data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=data_nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # return in the specified format
    fmt_lower: str = fmt.lower()
    if fmt_lower == "dict":
        return data_condensed
    elif fmt_lower == "json":
        import json

        return json.dumps(data_condensed, indent=2)
    elif fmt_lower in ["yaml", "yml"]:
        try:
            import yaml  # type: ignore[import-untyped]

            return yaml.dump(data_condensed, sort_keys=False)
        except ImportError as e:
            raise ValueError("PyYAML is required for YAML output") from e
    else:
        raise ValueError(f"Invalid return format: {fmt}")
Convert a dictionary of tensors to a dictionary of shapes.
by default, values are converted to strings of their shapes (for nice printing).
If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.
Parameters:
- `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]` — either a `TensorDict` dict from strings to tensors, or a `TensorIterable` iterable of (key, tensor) pairs (like you might get from `dict().items()`)
- `fmt : TensorDictFormats` — format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. (defaults to `'dict'`)
- `shapes_convert : Callable[[tuple], Any]` — conversion of a shape tuple to a string or other format (defaults to `lambda x: str(x).replace('"', '').replace("'", '')`)
- `drop_batch_dims : int` — number of leading dimensions to drop from the shape (defaults to `0`)
- `sep : str` — separator to use for nested keys (defaults to `'.'`)
- `dims_names_map : dict[int, str] | None` — convert certain dimension values in shape. not perfect, can be buggy (defaults to `None`)
- `condense_numeric_keys : bool` — whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts` (defaults to `True`)
- `condense_matching_values : bool` — whether to condense keys with matching values, passed on to `condense_nested_dicts` (defaults to `True`)
- `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` — a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts` (defaults to `None`)
- `return_format : TensorDictFormats | None` — legacy alias for the `fmt` kwarg
Returns:
- `str | dict[str, str | tuple[int, ...]]` — a dict if `return_format='dict'`, a string for `json` or `yaml` output
Examples:
>>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
>>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml'))
embed:
W_E: (50257, 768)
pos_embed:
W_pos: (1024, 768)
blocks:
'[0-11]':
attn:
'[W_Q, W_K, W_V]': (12, 768, 64)
W_O: (12, 64, 768)
'[b_Q, b_K, b_V]': (12, 64)
b_O: (768,)
mlp:
W_in: (768, 3072)
b_in: (3072,)
W_out: (3072, 768)
b_out: (768,)
unembed:
W_U: (768, 50257)
b_U: (50257,)
Raises:
- `ValueError` : if `return_format` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed