muutils.dictmagic
utilities for making it easier to work with dictionaries

- `DefaulterDict`: like a `defaultdict`, but `default_factory` is passed the key as an argument
- various functions for working with dotlist-nested dicts, converting to and from them
- `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes
- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict
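As a quick orientation, a sketch of how these pieces compose; the `flat` config dict below is invented for illustration:

```python
from muutils.dictmagic import condense_nested_dicts, dotlist_to_nested_dict

# a made-up flat config, as you might get from CLI flags or a logging system
flat: dict = {"layers.0.act": "relu", "layers.1.act": "relu", "layers.2.act": "gelu"}

nested: dict = dotlist_to_nested_dict(flat)
print(nested)
# {'layers': {'0': {'act': 'relu'}, '1': {'act': 'relu'}, '2': {'act': 'gelu'}}}

print(condense_nested_dicts(nested))
# {'layers': {'[0-1]': {'act': 'relu'}, '2': {'act': 'gelu'}}}
```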
1"""making working with dictionaries easier 2 3- `DefaulterDict`: like a defaultdict, but default_factory is passed the key as an argument 4- various methods for working wit dotlist-nested dicts, converting to and from them 5- `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges 6- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes 7- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict 8""" 9 10from __future__ import annotations 11 12import typing 13import warnings 14from collections import defaultdict 15from typing import ( 16 Any, 17 Callable, 18 Generic, 19 Hashable, 20 Iterable, 21 Literal, 22 Optional, 23 TypeVar, 24 Union, 25) 26 27from muutils.errormode import ErrorMode 28 29_KT = TypeVar("_KT") 30_VT = TypeVar("_VT") 31 32 33class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]): 34 """like a defaultdict, but default_factory is passed the key as an argument""" 35 36 def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs): 37 if args: 38 raise TypeError( 39 f"DefaulterDict does not support positional arguments: *args = {args}" 40 ) 41 super().__init__(**kwargs) 42 self.default_factory: Callable[[_KT], _VT] = default_factory 43 44 def __getitem__(self, k: _KT) -> _VT: 45 if k in self: 46 return dict.__getitem__(self, k) 47 else: 48 v: _VT = self.default_factory(k) 49 dict.__setitem__(self, k, v) 50 return v 51 52 53def _recursive_defaultdict_ctor() -> defaultdict: 54 return defaultdict(_recursive_defaultdict_ctor) 55 56 57def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict: 58 """Convert a defaultdict or DefaulterDict to a normal dict, recursively""" 59 return { 60 key: ( 61 defaultdict_to_dict_recursive(value) 62 if isinstance(value, (defaultdict, DefaulterDict)) 63 else value 64 ) 65 for key, value in dd.items() 66 } 67 68 69def dotlist_to_nested_dict( 70 dot_dict: typing.Dict[str, Any], sep: str = "." 
71) -> typing.Dict[str, Any]: 72 """Convert a dict with dot-separated keys to a nested dict 73 74 Example: 75 76 >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}) 77 {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}} 78 """ 79 nested_dict: defaultdict = _recursive_defaultdict_ctor() 80 for key, value in dot_dict.items(): 81 if not isinstance(key, str): 82 raise TypeError(f"key must be a string, got {type(key)}") 83 keys: list[str] = key.split(sep) 84 current: defaultdict = nested_dict 85 # iterate over the keys except the last one 86 for sub_key in keys[:-1]: 87 current = current[sub_key] 88 current[keys[-1]] = value 89 return defaultdict_to_dict_recursive(nested_dict) 90 91 92def nested_dict_to_dotlist( 93 nested_dict: typing.Dict[str, Any], 94 sep: str = ".", 95 allow_lists: bool = False, 96) -> dict[str, Any]: 97 def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]: 98 items: dict = dict() 99 100 new_key: str 101 if isinstance(current, dict): 102 # dict case 103 if not current and parent_key: 104 items[parent_key] = current 105 else: 106 for k, v in current.items(): 107 new_key = f"{parent_key}{sep}{k}" if parent_key else k 108 items.update(_recurse(v, new_key)) 109 110 elif allow_lists and isinstance(current, list): 111 # list case 112 for i, item in enumerate(current): 113 new_key = f"{parent_key}{sep}{i}" if parent_key else str(i) 114 items.update(_recurse(item, new_key)) 115 116 else: 117 # anything else (write value) 118 items[parent_key] = current 119 120 return items 121 122 return _recurse(nested_dict) 123 124 125def update_with_nested_dict( 126 original: dict[str, Any], 127 update: dict[str, Any], 128) -> dict[str, Any]: 129 """Update a dict with a nested dict 130 131 Example: 132 >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}}) 133 {'a': {'b': 2}, 'c': -1} 134 135 # Arguments 136 - `original: dict[str, Any]` 137 the dict to update (will be modified in-place) 138 - `update: dict[str, Any]` 139 the dict to update with 140 141 # Returns 142 - `dict` 143 the updated dict 144 """ 145 for key, value in update.items(): 146 if key in original: 147 if isinstance(original[key], dict) and isinstance(value, dict): 148 update_with_nested_dict(original[key], value) 149 else: 150 original[key] = value 151 else: 152 original[key] = value 153 154 return original 155 156 157def kwargs_to_nested_dict( 158 kwargs_dict: dict[str, Any], 159 sep: str = ".", 160 strip_prefix: Optional[str] = None, 161 when_unknown_prefix: ErrorMode = ErrorMode.WARN, 162 transform_key: Optional[Callable[[str], str]] = None, 163) -> dict[str, Any]: 164 """given kwargs from fire, convert them to a nested dict 165 166 if strip_prefix is not None, then all keys must start with the prefix. 
by default, 167 will warn if an unknown prefix is found, but can be set to raise an error or ignore it: 168 `when_unknown_prefix: ErrorMode` 169 170 Example: 171 ```python 172 def main(**kwargs): 173 print(kwargs_to_nested_dict(kwargs)) 174 fire.Fire(main) 175 ``` 176 running the above script will give: 177 ```bash 178 $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3 179 {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}} 180 ``` 181 182 # Arguments 183 - `kwargs_dict: dict[str, Any]` 184 the kwargs dict to convert 185 - `sep: str = "."` 186 the separator to use for nested keys 187 - `strip_prefix: Optional[str] = None` 188 if not None, then all keys must start with this prefix 189 - `when_unknown_prefix: ErrorMode = ErrorMode.WARN` 190 what to do when an unknown prefix is found 191 - `transform_key: Callable[[str], str] | None = None` 192 a function to apply to each key before adding it to the dict (applied after stripping the prefix) 193 """ 194 when_unknown_prefix = ErrorMode.from_any(when_unknown_prefix) 195 filtered_kwargs: dict[str, Any] = dict() 196 for key, value in kwargs_dict.items(): 197 if strip_prefix is not None: 198 if not key.startswith(strip_prefix): 199 when_unknown_prefix.process( 200 f"key '{key}' does not start with '{strip_prefix}'", 201 except_cls=ValueError, 202 ) 203 else: 204 key = key[len(strip_prefix) :] 205 206 if transform_key is not None: 207 key = transform_key(key) 208 209 filtered_kwargs[key] = value 210 211 return dotlist_to_nested_dict(filtered_kwargs, sep=sep) 212 213 214def is_numeric_consecutive(lst: list[str]) -> bool: 215 """Check if the list of keys is numeric and consecutive.""" 216 try: 217 numbers: list[int] = [int(x) for x in lst] 218 return sorted(numbers) == list(range(min(numbers), max(numbers) + 1)) 219 except ValueError: 220 return False 221 222 223def condense_nested_dicts_numeric_keys( 224 data: dict[str, Any], 225) -> dict[str, Any]: 226 """condense a nested dict, by condensing numeric keys with matching values to ranges 227 228 # Examples: 229 ```python 230 >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2}) 231 {'[1-3]': 1, '[4-6]': 2} 232 >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'}) 233 {"1": {"[1-2]": "a"}, "2": "b"} 234 ``` 235 """ 236 237 if not isinstance(data, dict): 238 return data 239 240 # Process each sub-dictionary 241 for key, value in list(data.items()): 242 data[key] = condense_nested_dicts_numeric_keys(value) 243 244 # Find all numeric, consecutive keys 245 if is_numeric_consecutive(list(data.keys())): 246 keys: list[str] = sorted(data.keys(), key=lambda x: int(x)) 247 else: 248 return data 249 250 # output dict 251 condensed_data: dict[str, Any] = {} 252 253 # Identify ranges of identical values and condense 254 i: int = 0 255 while i < len(keys): 256 j: int = i 257 while j + 1 < len(keys) and data[keys[j]] == data[keys[j + 1]]: 258 j += 1 259 if j > i: # Found consecutive keys with identical values 260 condensed_key: str = f"[{keys[i]}-{keys[j]}]" 261 condensed_data[condensed_key] = data[keys[i]] 262 i = j + 1 263 else: 264 condensed_data[keys[i]] = data[keys[i]] 265 i += 1 266 267 return condensed_data 268 269 270def condense_nested_dicts_matching_values( 271 data: dict[str, Any], 272 val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, 273) -> dict[str, Any]: 274 """condense a nested dict, by condensing keys with matching values 275 276 # Examples: TODO 277 278 # Parameters: 279 - `data : dict[str, Any]` 280 data to process 281 - 
`val_condense_fallback_mapping : Callable[[Any], Hashable] | None` 282 a function to apply to each value before adding it to the dict (if it's not hashable) 283 (defaults to `None`) 284 285 """ 286 287 if isinstance(data, dict): 288 data = { 289 key: condense_nested_dicts_matching_values( 290 value, val_condense_fallback_mapping 291 ) 292 for key, value in data.items() 293 } 294 else: 295 return data 296 297 # Find all identical values and condense by stitching together keys 298 values_grouped: defaultdict[Any, list[str]] = defaultdict(list) 299 data_persist: dict[str, Any] = dict() 300 for key, value in data.items(): 301 if not isinstance(value, dict): 302 try: 303 values_grouped[value].append(key) 304 except TypeError: 305 # If the value is unhashable, use a fallback mapping to find a hashable representation 306 if val_condense_fallback_mapping is not None: 307 values_grouped[val_condense_fallback_mapping(value)].append(key) 308 else: 309 data_persist[key] = value 310 else: 311 data_persist[key] = value 312 313 condensed_data = data_persist 314 for value, keys in values_grouped.items(): 315 if len(keys) > 1: 316 merged_key = f"[{', '.join(keys)}]" # Choose an appropriate method to represent merged keys 317 condensed_data[merged_key] = value 318 else: 319 condensed_data[keys[0]] = value 320 321 return condensed_data 322 323 324def condense_nested_dicts( 325 data: dict[str, Any], 326 condense_numeric_keys: bool = True, 327 condense_matching_values: bool = True, 328 val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, 329) -> dict[str, Any]: 330 """condense a nested dict, by condensing numeric or matching keys with matching values to ranges 331 332 combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()` 333 334 # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes 335 it's not reversible because types are lost to make the printing pretty 336 337 # Parameters: 338 - `data : dict[str, Any]` 339 data to process 340 - `condense_numeric_keys : bool` 341 whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. 
"[1-3]") 342 (defaults to `True`) 343 - `condense_matching_values : bool` 344 whether to condense keys with matching values 345 (defaults to `True`) 346 - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` 347 a function to apply to each value before adding it to the dict (if it's not hashable) 348 (defaults to `None`) 349 350 """ 351 352 condensed_data: dict = data 353 if condense_numeric_keys: 354 condensed_data = condense_nested_dicts_numeric_keys(condensed_data) 355 if condense_matching_values: 356 condensed_data = condense_nested_dicts_matching_values( 357 condensed_data, val_condense_fallback_mapping 358 ) 359 return condensed_data 360 361 362def tuple_dims_replace( 363 t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None 364) -> tuple[Union[int, str], ...]: 365 if dims_names_map is None: 366 return t 367 else: 368 return tuple(dims_names_map.get(x, x) for x in t) 369 370 371TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] # noqa: F821 372TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] # noqa: F821 373TensorDictFormats = Literal["dict", "json", "yaml", "yml"] 374 375 376def _default_shapes_convert(x: tuple) -> str: 377 return str(x).replace('"', "").replace("'", "") 378 379 380def condense_tensor_dict( 381 data: TensorDict | TensorIterable, 382 fmt: TensorDictFormats = "dict", 383 *args, 384 shapes_convert: Callable[[tuple], Any] = _default_shapes_convert, 385 drop_batch_dims: int = 0, 386 sep: str = ".", 387 dims_names_map: Optional[dict[int, str]] = None, 388 condense_numeric_keys: bool = True, 389 condense_matching_values: bool = True, 390 val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, 391 return_format: Optional[TensorDictFormats] = None, 392) -> Union[str, dict[str, str | tuple[int, ...]]]: 393 """Convert a dictionary of tensors to a dictionary of shapes. 394 395 by default, values are converted to strings of their shapes (for nice printing). 396 If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`. 397 398 # Parameters: 399 - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]` 400 a either a `TensorDict` dict from strings to tensors, or an `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items())` ) 401 - `fmt : TensorDictFormats` 402 format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. 403 (defaults to `'dict'`) 404 - `shapes_convert : Callable[[tuple], Any]` 405 conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes) 406 (defaults to `lambdax:str(x).replace('"', '').replace("'", '')`) 407 - `drop_batch_dims : int` 408 number of leading dimensions to drop from the shape 409 (defaults to `0`) 410 - `sep : str` 411 separator to use for nested keys 412 (defaults to `'.'`) 413 - `dims_names_map : dict[int, str] | None` 414 convert certain dimension values in shape. not perfect, can be buggy 415 (defaults to `None`) 416 - `condense_numeric_keys : bool` 417 whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. 
"[1-3]"), passed on to `condense_nested_dicts` 418 (defaults to `True`) 419 - `condense_matching_values : bool` 420 whether to condense keys with matching values, passed on to `condense_nested_dicts` 421 (defaults to `True`) 422 - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` 423 a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts` 424 (defaults to `None`) 425 - `return_format : TensorDictFormats | None` 426 legacy alias for `fmt` kwarg 427 428 # Returns: 429 - `str|dict[str, str|tuple[int, ...]]` 430 dict if `return_format='dict'`, a string for `json` or `yaml` output 431 432 # Examples: 433 ```python 434 >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2") 435 >>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml')) 436 ``` 437 ```yaml 438 embed: 439 W_E: (50257, 768) 440 pos_embed: 441 W_pos: (1024, 768) 442 blocks: 443 '[0-11]': 444 attn: 445 '[W_Q, W_K, W_V]': (12, 768, 64) 446 W_O: (12, 64, 768) 447 '[b_Q, b_K, b_V]': (12, 64) 448 b_O: (768,) 449 mlp: 450 W_in: (768, 3072) 451 b_in: (3072,) 452 W_out: (3072, 768) 453 b_out: (768,) 454 unembed: 455 W_U: (768, 50257) 456 b_U: (50257,) 457 ``` 458 459 # Raises: 460 - `ValueError` : if `return_format` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed 461 """ 462 463 # handle arg processing: 464 # ---------------------------------------------------------------------- 465 # make all args except data and format keyword-only 466 assert len(args) == 0, f"unexpected positional args: {args}" 467 # handle legacy return_format 468 if return_format is not None: 469 warnings.warn( 470 "return_format is deprecated, use fmt instead", 471 DeprecationWarning, 472 ) 473 fmt = return_format 474 475 # identity function for shapes_convert if not provided 476 if shapes_convert is None: 477 shapes_convert = lambda x: x # noqa: E731 478 479 # convert to iterable 480 data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = ( # type: ignore # noqa: F821 481 data.items() if hasattr(data, "items") and callable(data.items) else data # type: ignore 482 ) 483 484 # get shapes 485 data_shapes: dict[str, Union[str, tuple[int, ...]]] = { 486 k: shapes_convert( 487 tuple_dims_replace( 488 tuple(v.shape)[drop_batch_dims:], 489 dims_names_map, 490 ) 491 ) 492 for k, v in data_items 493 } 494 495 # nest the dict 496 data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep) 497 498 # condense the nested dict 499 data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts( 500 data=data_nested, 501 condense_numeric_keys=condense_numeric_keys, 502 condense_matching_values=condense_matching_values, 503 val_condense_fallback_mapping=val_condense_fallback_mapping, 504 ) 505 506 # return in the specified format 507 fmt_lower: str = fmt.lower() 508 if fmt_lower == "dict": 509 return data_condensed 510 elif fmt_lower == "json": 511 import json 512 513 return json.dumps(data_condensed, indent=2) 514 elif fmt_lower in ["yaml", "yml"]: 515 try: 516 import yaml # type: ignore[import-untyped] 517 518 return yaml.dump(data_condensed, sort_keys=False) 519 except ImportError as e: 520 raise ValueError("PyYAML is required for YAML output") from e 521 else: 522 raise ValueError(f"Invalid return format: {fmt}")
```python
class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
    """like a defaultdict, but default_factory is passed the key as an argument"""

    def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs):
        if args:
            raise TypeError(
                f"DefaulterDict does not support positional arguments: *args = {args}"
            )
        super().__init__(**kwargs)
        self.default_factory: Callable[[_KT], _VT] = default_factory

    def __getitem__(self, k: _KT) -> _VT:
        if k in self:
            return dict.__getitem__(self, k)
        else:
            v: _VT = self.default_factory(k)
            dict.__setitem__(self, k, v)
            return v
```
like a defaultdict, but default_factory is passed the key as an argument
Inherited Members

- `builtins.dict`: `get`, `setdefault`, `pop`, `popitem`, `keys`, `items`, `values`, `update`, `fromkeys`, `clear`, `copy`
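A minimal usage sketch (the `key_lengths` name is made up for this example): unlike `collections.defaultdict`, the factory receives the missing key, so the default can depend on it.

```python
from muutils.dictmagic import DefaulterDict

# the default value is computed from the key itself
key_lengths: DefaulterDict[str, int] = DefaulterDict(lambda key: len(key))
print(key_lengths["hello"])  # 5 -- computed by the factory and stored
print(key_lengths)           # {'hello': 5}
```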
```python
def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
    """Convert a defaultdict or DefaulterDict to a normal dict, recursively"""
    return {
        key: (
            defaultdict_to_dict_recursive(value)
            if isinstance(value, (defaultdict, DefaulterDict))
            else value
        )
        for key, value in dd.items()
    }
```
Convert a defaultdict or DefaulterDict to a normal dict, recursively
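For example, a small sketch with a nested `defaultdict`:

```python
from collections import defaultdict

from muutils.dictmagic import defaultdict_to_dict_recursive

counts: defaultdict = defaultdict(lambda: defaultdict(int))
counts["a"]["x"] += 1
counts["b"]["y"] += 2
print(defaultdict_to_dict_recursive(counts))
# {'a': {'x': 1}, 'b': {'y': 2}} -- plain dicts all the way down
```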
```python
def dotlist_to_nested_dict(
    dot_dict: typing.Dict[str, Any], sep: str = "."
) -> typing.Dict[str, Any]:
    """Convert a dict with dot-separated keys to a nested dict

    Example:

    >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    """
    nested_dict: defaultdict = _recursive_defaultdict_ctor()
    for key, value in dot_dict.items():
        if not isinstance(key, str):
            raise TypeError(f"key must be a string, got {type(key)}")
        keys: list[str] = key.split(sep)
        current: defaultdict = nested_dict
        # iterate over the keys except the last one
        for sub_key in keys[:-1]:
            current = current[sub_key]
        current[keys[-1]] = value
    return defaultdict_to_dict_recursive(nested_dict)
```
Convert a dict with dot-separated keys to a nested dict
Example:

```python
>>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
```
```python
def nested_dict_to_dotlist(
    nested_dict: typing.Dict[str, Any],
    sep: str = ".",
    allow_lists: bool = False,
) -> dict[str, Any]:
    def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]:
        items: dict = dict()

        new_key: str
        if isinstance(current, dict):
            # dict case
            if not current and parent_key:
                items[parent_key] = current
            else:
                for k, v in current.items():
                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
                    items.update(_recurse(v, new_key))

        elif allow_lists and isinstance(current, list):
            # list case
            for i, item in enumerate(current):
                new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
                items.update(_recurse(item, new_key))

        else:
            # anything else (write value)
            items[parent_key] = current

        return items

    return _recurse(nested_dict)
```
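`nested_dict_to_dotlist` has no docstring, but from the source it is the inverse of `dotlist_to_nested_dict`: it flattens a nested dict into dot-separated keys, and with `allow_lists=True` it also indexes into lists. A small sketch:

```python
from muutils.dictmagic import dotlist_to_nested_dict, nested_dict_to_dotlist

nested: dict = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}
flat: dict = nested_dict_to_dotlist(nested)
print(flat)  # {'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}

# list indices become key components when allow_lists=True
print(nested_dict_to_dotlist({"xs": [10, 20]}, allow_lists=True))
# {'xs.0': 10, 'xs.1': 20}

# round-trips back through dotlist_to_nested_dict
assert dotlist_to_nested_dict(flat) == nested
```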
```python
def update_with_nested_dict(
    original: dict[str, Any],
    update: dict[str, Any],
) -> dict[str, Any]:
    """Update a dict with a nested dict

    Example:
    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
    {'a': {'b': 2}, 'c': -1}

    # Arguments
    - `original: dict[str, Any]`
        the dict to update (will be modified in-place)
    - `update: dict[str, Any]`
        the dict to update with

    # Returns
    - `dict`
        the updated dict
    """
    for key, value in update.items():
        if key in original:
            if isinstance(original[key], dict) and isinstance(value, dict):
                update_with_nested_dict(original[key], value)
            else:
                original[key] = value
        else:
            original[key] = value

    return original
```
Update a dict with a nested dict
Example:

```python
>>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
{'a': {'b': 2}, 'c': -1}
```

Arguments

- `original: dict[str, Any]`
  the dict to update (will be modified in-place)
- `update: dict[str, Any]`
  the dict to update with

Returns

- `dict`
  the updated dict
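Another small sketch (the `config` dict is hypothetical), showing that untouched nested keys are preserved and the update happens in place:

```python
from muutils.dictmagic import update_with_nested_dict

config: dict = {"model": {"layers": 4, "d_model": 128}, "seed": 0}
update_with_nested_dict(config, {"model": {"layers": 8}})
print(config)
# {'model': {'layers': 8, 'd_model': 128}, 'seed': 0}
```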
````python
def kwargs_to_nested_dict(
    kwargs_dict: dict[str, Any],
    sep: str = ".",
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: ErrorMode = ErrorMode.WARN,
    transform_key: Optional[Callable[[str], str]] = None,
) -> dict[str, Any]:
    """given kwargs from fire, convert them to a nested dict

    if strip_prefix is not None, then all keys must start with the prefix. by default,
    will warn if an unknown prefix is found, but can be set to raise an error or ignore it:
    `when_unknown_prefix: ErrorMode`

    Example:
    ```python
    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)
    ```
    running the above script will give:
    ```bash
    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    ```

    # Arguments
    - `kwargs_dict: dict[str, Any]`
        the kwargs dict to convert
    - `sep: str = "."`
        the separator to use for nested keys
    - `strip_prefix: Optional[str] = None`
        if not None, then all keys must start with this prefix
    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
        what to do when an unknown prefix is found
    - `transform_key: Callable[[str], str] | None = None`
        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
    """
    when_unknown_prefix = ErrorMode.from_any(when_unknown_prefix)
    filtered_kwargs: dict[str, Any] = dict()
    for key, value in kwargs_dict.items():
        if strip_prefix is not None:
            if not key.startswith(strip_prefix):
                when_unknown_prefix.process(
                    f"key '{key}' does not start with '{strip_prefix}'",
                    except_cls=ValueError,
                )
            else:
                key = key[len(strip_prefix) :]

        if transform_key is not None:
            key = transform_key(key)

        filtered_kwargs[key] = value

    return dotlist_to_nested_dict(filtered_kwargs, sep=sep)
````
given kwargs from fire, convert them to a nested dict
If `strip_prefix` is not None, then all keys must start with the prefix. By default, an unknown prefix produces a warning, but `when_unknown_prefix: ErrorMode` can be set to raise an error or ignore it instead.

Example:

```python
def main(**kwargs):
    print(kwargs_to_nested_dict(kwargs))
fire.Fire(main)
```

running the above script will give:

```bash
$ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
```

Arguments

- `kwargs_dict: dict[str, Any]`
  the kwargs dict to convert
- `sep: str = "."`
  the separator to use for nested keys
- `strip_prefix: Optional[str] = None`
  if not None, then all keys must start with this prefix
- `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
  what to do when an unknown prefix is found
- `transform_key: Callable[[str], str] | None = None`
  a function to apply to each key before adding it to the dict (applied after stripping the prefix)
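A sketch of `strip_prefix` and `transform_key` without going through `fire`; the `cfg.` prefix and the dash-to-underscore transform are assumptions for this example:

```python
from muutils.dictmagic import kwargs_to_nested_dict

raw_kwargs: dict = {"cfg.model.n-layers": 4, "cfg.train.lr": 1e-3}
print(
    kwargs_to_nested_dict(
        raw_kwargs,
        strip_prefix="cfg.",
        transform_key=lambda k: k.replace("-", "_"),
    )
)
# {'model': {'n_layers': 4}, 'train': {'lr': 0.001}}
```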
```python
def is_numeric_consecutive(lst: list[str]) -> bool:
    """Check if the list of keys is numeric and consecutive."""
    try:
        numbers: list[int] = [int(x) for x in lst]
        return sorted(numbers) == list(range(min(numbers), max(numbers) + 1))
    except ValueError:
        return False
```
Check if the list of keys is numeric and consecutive.
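For example:

```python
from muutils.dictmagic import is_numeric_consecutive

print(is_numeric_consecutive(["2", "1", "3"]))  # True  -- order does not matter
print(is_numeric_consecutive(["1", "2", "4"]))  # False -- gap at 3
print(is_numeric_consecutive(["1", "b"]))       # False -- not all numeric
```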
````python
def condense_nested_dicts_numeric_keys(
    data: dict[str, Any],
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric keys with matching values to ranges

    # Examples:
    ```python
    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {"1": {"[1-2]": "a"}, "2": "b"}
    ```
    """

    if not isinstance(data, dict):
        return data

    # Process each sub-dictionary
    for key, value in list(data.items()):
        data[key] = condense_nested_dicts_numeric_keys(value)

    # Find all numeric, consecutive keys
    if is_numeric_consecutive(list(data.keys())):
        keys: list[str] = sorted(data.keys(), key=lambda x: int(x))
    else:
        return data

    # output dict
    condensed_data: dict[str, Any] = {}

    # Identify ranges of identical values and condense
    i: int = 0
    while i < len(keys):
        j: int = i
        while j + 1 < len(keys) and data[keys[j]] == data[keys[j + 1]]:
            j += 1
        if j > i:  # Found consecutive keys with identical values
            condensed_key: str = f"[{keys[i]}-{keys[j]}]"
            condensed_data[condensed_key] = data[keys[i]]
            i = j + 1
        else:
            condensed_data[keys[i]] = data[keys[i]]
            i += 1

    return condensed_data
````
condense a nested dict, by condensing numeric keys with matching values to ranges
Examples:

```python
>>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
{'[1-3]': 1, '[4-6]': 2}
>>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
{"1": {"[1-2]": "a"}, "2": "b"}
```
```python
def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing keys with matching values

    # Examples: TODO

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)

    """

    if isinstance(data, dict):
        data = {
            key: condense_nested_dicts_matching_values(
                value, val_condense_fallback_mapping
            )
            for key, value in data.items()
        }
    else:
        return data

    # Find all identical values and condense by stitching together keys
    values_grouped: defaultdict[Any, list[str]] = defaultdict(list)
    data_persist: dict[str, Any] = dict()
    for key, value in data.items():
        if not isinstance(value, dict):
            try:
                values_grouped[value].append(key)
            except TypeError:
                # If the value is unhashable, use a fallback mapping to find a hashable representation
                if val_condense_fallback_mapping is not None:
                    values_grouped[val_condense_fallback_mapping(value)].append(key)
                else:
                    data_persist[key] = value
        else:
            data_persist[key] = value

    condensed_data = data_persist
    for value, keys in values_grouped.items():
        if len(keys) > 1:
            merged_key = f"[{', '.join(keys)}]"  # Choose an appropriate method to represent merged keys
            condensed_data[merged_key] = value
        else:
            condensed_data[keys[0]] = value

    return condensed_data
```
condense a nested dict, by condensing keys with matching values
Examples: TODO
Parameters:

- `data : dict[str, Any]`
  data to process
- `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
  a function to apply to each value before adding it to the dict (if it's not hashable)
  (defaults to `None`)
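The docstring leaves its examples as TODO; a sketch of the behavior (the weight-style names below are invented):

```python
from muutils.dictmagic import condense_nested_dicts_matching_values

# keys with equal, hashable values get merged into a single "[k1, k2]" key
print(condense_nested_dicts_matching_values({"W_Q": (64, 8), "W_K": (64, 8), "b_Q": (8,)}))
# {'[W_Q, W_K]': (64, 8), 'b_Q': (8,)}

# unhashable values (e.g. lists) are only grouped when a fallback mapping is given;
# note that the mapped value (here a tuple) is what ends up in the output
print(
    condense_nested_dicts_matching_values(
        {"a": [1, 2], "b": [1, 2]},
        val_condense_fallback_mapping=tuple,
    )
)
# {'[a, b]': (1, 2)}
```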
```python
def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric or matching keys with matching values to ranges

    combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()`

    # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
    it's not reversible because types are lost to make the printing pretty

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)

    """

    condensed_data: dict = data
    if condense_numeric_keys:
        condensed_data = condense_nested_dicts_numeric_keys(condensed_data)
    if condense_matching_values:
        condensed_data = condense_nested_dicts_matching_values(
            condensed_data, val_condense_fallback_mapping
        )
    return condensed_data
```
condense a nested dict, by condensing numeric or matching keys with matching values to ranges
Combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()`.

NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes.
It's not reversible because types are lost to make the printing pretty.

Parameters:

- `data : dict[str, Any]`
  data to process
- `condense_numeric_keys : bool`
  whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
  (defaults to `True`)
- `condense_matching_values : bool`
  whether to condense keys with matching values
  (defaults to `True`)
- `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
  a function to apply to each value before adding it to the dict (if it's not hashable)
  (defaults to `None`)
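A short combined sketch (the weight names below are invented): the numeric pass merges the identical blocks, and the matching-values pass merges keys with identical values inside each block.

```python
from muutils.dictmagic import condense_nested_dicts

data: dict = {
    "0": {"W_Q": (64, 8), "W_K": (64, 8), "b": (8,)},
    "1": {"W_Q": (64, 8), "W_K": (64, 8), "b": (8,)},
}
print(condense_nested_dicts(data))
# {'[0-1]': {'[W_Q, W_K]': (64, 8), 'b': (8,)}}
```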
````python
def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args,
    shapes_convert: Callable[[tuple], Any] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
    - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` dict from strings to tensors, or a `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items()`)
    - `fmt : TensorDictFormats`
        format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed.
        (defaults to `'dict'`)
    - `shapes_convert : Callable[[tuple], Any]`
        conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes)
        (defaults to `lambda x: str(x).replace('"', '').replace("'", '')`)
    - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
    - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
    - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
    - `return_format : TensorDictFormats | None`
        legacy alias for `fmt` kwarg

    # Returns:
    - `str|dict[str, str|tuple[int, ...]]`
        dict if `return_format='dict'`, a string for `json` or `yaml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
    - `ValueError` : if `return_format` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
    """

    # handle arg processing:
    # ----------------------------------------------------------------------
    # make all args except data and format keyword-only
    assert len(args) == 0, f"unexpected positional args: {args}"
    # handle legacy return_format
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
        )
        fmt = return_format

    # identity function for shapes_convert if not provided
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # convert to iterable
    data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # get shapes
    data_shapes: dict[str, Union[str, tuple[int, ...]]] = {
        k: shapes_convert(
            tuple_dims_replace(
                tuple(v.shape)[drop_batch_dims:],
                dims_names_map,
            )
        )
        for k, v in data_items
    }

    # nest the dict
    data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep)

    # condense the nested dict
    data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=data_nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # return in the specified format
    fmt_lower: str = fmt.lower()
    if fmt_lower == "dict":
        return data_condensed
    elif fmt_lower == "json":
        import json

        return json.dumps(data_condensed, indent=2)
    elif fmt_lower in ["yaml", "yml"]:
        try:
            import yaml  # type: ignore[import-untyped]

            return yaml.dump(data_condensed, sort_keys=False)
        except ImportError as e:
            raise ValueError("PyYAML is required for YAML output") from e
    else:
        raise ValueError(f"Invalid return format: {fmt}")
````
Convert a dictionary of tensors to a dictionary of shapes.
By default, values are converted to strings of their shapes (for nice printing).
If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

Parameters:

- `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
  either a `TensorDict` (a dict from strings to tensors) or a `TensorIterable` (an iterable of `(key, tensor)` pairs, like you might get from `dict().items()`)
- `fmt : TensorDictFormats`
  format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed.
  (defaults to `'dict'`)
- `shapes_convert : Callable[[tuple], Any]`
  conversion of a shape tuple to a string or other format
  (defaults to `lambda x: str(x).replace('"', '').replace("'", '')`, i.e. turning the tuple into a string and removing quotes)
- `drop_batch_dims : int`
  number of leading dimensions to drop from the shape
  (defaults to `0`)
- `sep : str`
  separator to use for nested keys
  (defaults to `'.'`)
- `dims_names_map : dict[int, str] | None`
  convert certain dimension values in shape. not perfect, can be buggy
  (defaults to `None`)
- `condense_numeric_keys : bool`
  whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
  (defaults to `True`)
- `condense_matching_values : bool`
  whether to condense keys with matching values, passed on to `condense_nested_dicts`
  (defaults to `True`)
- `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
  a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
  (defaults to `None`)
- `return_format : TensorDictFormats | None`
  legacy alias for the `fmt` kwarg

Returns:

- `str | dict[str, str | tuple[int, ...]]`
  a dict if `fmt='dict'`, a string for `json` or `yaml` output

Examples:

```python
>>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
>>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml'))
```
```yaml
embed:
  W_E: (50257, 768)
pos_embed:
  W_pos: (1024, 768)
blocks:
  '[0-11]':
    attn:
      '[W_Q, W_K, W_V]': (12, 768, 64)
      W_O: (12, 64, 768)
      '[b_Q, b_K, b_V]': (12, 64)
      b_O: (768,)
    mlp:
      W_in: (768, 3072)
      b_in: (3072,)
      W_out: (3072, 768)
      b_out: (768,)
unembed:
  W_U: (768, 50257)
  b_U: (50257,)
```

Raises:

- `ValueError` : if `fmt` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
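A smaller, self-contained sketch that avoids the transformer_lens and PyYAML dependencies; it assumes NumPy is available (anything with a `.shape` attribute works), and the parameter names are made up:

```python
import numpy as np

from muutils.dictmagic import condense_tensor_dict

params: dict = {
    "blocks.0.attn.W_Q": np.zeros((12, 64, 64)),
    "blocks.0.mlp.W_in": np.zeros((64, 256)),
    "blocks.1.attn.W_Q": np.zeros((12, 64, 64)),
    "blocks.1.mlp.W_in": np.zeros((64, 256)),
}

# fmt="json" pretty-prints without requiring PyYAML
print(condense_tensor_dict(params, fmt="json"))
# {
#   "blocks": {
#     "[0-1]": {
#       "attn": {
#         "W_Q": "(12, 64, 64)"
#       },
#       "mlp": {
#         "W_in": "(64, 256)"
#       }
#     }
#   }
# }
```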