Module flashy.loggers.utils
Utilities for loggers.
Expand source code
# Most of the utilities here are derived from PyTorch Lightning
"""Utilities for loggers."""
from argparse import Namespace
import typing as tp
import torch
def _fmt_prefix(prefix: tp.Union[str, tp.List[str]], delimiter: str = "/") -> str:
"""Format prefix(es) to a single prefix string."""
if isinstance(prefix, str):
return prefix
else:
return delimiter.join(prefix)
def _add_prefix(metrics: dict, prefix: tp.Union[str, tp.List[str]], delimiter: str = "/") -> dict:
"""Insert prefix before each key in a dict, separated by the delimiter.
Args:
metrics: Dictionary with metric names as keys and values
prefix: Prefix to insert before each key
delimiter: Separates prefix and original key name. Defaults to ``'/'``
Returns:
Dictionary with prefix and delimiter inserted before each key
"""
if prefix is None:
return metrics
prefix = _fmt_prefix(prefix, delimiter)
metrics = {f"{prefix}{delimiter}{k}": v for k, v in metrics.items()}
return metrics
def _convert_params(params: tp.Union[dict, Namespace]) -> dict:
"""Ensure parameters are a dict or convert to dict if necessary.
Args:
params: Target to be converted to a dictionary
Returns:
params as a dictionary
"""
# in case converting from namespace
if isinstance(params, Namespace):
params = vars(params)
if params is None:
params = {}
return params
def _flatten_dict(params: tp.Dict[str, tp.Any], delimiter: str = ".") -> tp.Dict[str, tp.Any]:
"""Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
Args:
params: Dictionary containing the hyperparameters
delimiter: Delimiter to express the hierarchy. Defaults to ``'.'``
Returns:
Flattened dict
Examples:
>>> _flatten_dict({'a': {'b': 'c'}})
{'a/b': 'c'}
>>> _flatten_dict({'a': {'b': 123}})
{'a/b': 123}
>>> _flatten_dict({5: {'a': 123}})
{'5/a': 123}
"""
def _dict_generator(
input_dict: tp.Any, prefixes: tp.List[tp.Optional[str]] = None
):
prefixes = prefixes[:] if prefixes else []
if isinstance(input_dict, tp.MutableMapping):
for key, value in input_dict.items():
key = str(key)
if isinstance(value, (tp.MutableMapping, Namespace)):
value = vars(value) if isinstance(value, Namespace) else value
yield from _dict_generator(value, prefixes + [key])
else:
yield prefixes + [key, value if value is not None else str(None)]
else:
yield prefixes + [input_dict if input_dict is None else str(input_dict)]
return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}
def _sanitize_params(params: tp.Dict[str, tp.Any]) -> tp.Dict[str, tp.Any]:
"""Returns params with non-primitvies converted to strings for logging.
>>> params = {"float": 0.3,
... "int": 1,
... "string": "abc",
... "bool": True,
... "list": [1, 2, 3],
... "namespace": Namespace(foo=3),
... "layer": torch.nn.BatchNorm1d}
>>> import pprint
>>> pprint.pprint(_sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE
{'bool': True,
'float': 0.3,
'int': 1,
'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
'list': '[1, 2, 3]',
'namespace': 'Namespace(foo=3)',
'string': 'abc'}
"""
for k in params.keys():
# convert relevant np scalars to python types first (instead of str)
if type(params[k]) not in [bool, int, float, str, torch.Tensor]:
params[k] = str(params[k])
return params