Module flashy.loggers.tensorboard
Expand source code
from argparse import Namespace
from pathlib import Path
import typing as tp
import warnings
import dora
import torch
try:
from torch.utils.tensorboard import SummaryWriter
except ModuleNotFoundError:
SummaryWriter = None # type: ignore
from . import ExperimentLogger
from .utils import (
_add_prefix, _convert_params, _flatten_dict,
_sanitize_params
)
from ..distrib import rank_zero_only, is_rank_zero
class TensorboardLogger(ExperimentLogger):
    """ExperimentLogger for Tensorboard.

    Args:
        save_dir (str or Path): The directory where the experiment logs are written to.
        with_media_logging (bool): Whether to save media samples with the logger or ignore them.
            This comes at extra storage cost for Tensorboard, so we might not want to enforce it.
        name (str, optional): Name for the experiment logs. Defaults to ``'tensorboard'``.
        kwargs (Any): Additional keyword arguments forwarded to ``SummaryWriter``.
    """
    def __init__(self, save_dir: tp.Union[Path, str], with_media_logging: bool = True,
                 name: tp.Optional[str] = None, **kwargs: tp.Any):
        self._with_media_logging = with_media_logging
        self._save_dir = str(save_dir)
        self._name = name or 'tensorboard'
        self._writer: tp.Optional[SummaryWriter] = None
        # `SummaryWriter` is None when the optional tensorboard import failed at module load;
        # the logger then stays constructible but every log_* call returns early (is_disabled).
        if SummaryWriter:
            self._writer = SummaryWriter(self.save_dir, **kwargs)
        else:
            warnings.warn("tensorboard package was not found: use pip install tensorboard")

    @property  # type: ignore
    @rank_zero_only
    def writer(self) -> tp.Optional[SummaryWriter]:
        """Actual tensorboard writer object, or None if tensorboard is unavailable."""
        return self._writer

    @rank_zero_only
    def is_disabled(self) -> bool:
        """Whether logging is disabled because no tensorboard writer is available."""
        return self.writer is None

    @rank_zero_only
    def log_hyperparams(self, params: tp.Union[tp.Dict[str, tp.Any], Namespace],
                        metrics: tp.Optional[dict] = None) -> None:
        """Record experiment hyperparameters.

        This method logs the hyperparameters associated to the experiment.

        Args:
            params: Dictionary (or Namespace) of hyperparameters.
            metrics: Dictionary of final metrics for the set of hyperparameters. When it is
                missing or empty, the placeholder ``{'hparams_metrics': -1}`` is logged instead.
        """
        assert is_rank_zero(), "experiment tried to log from global_rank != 0"
        if self.is_disabled():
            return
        params = _convert_params(params)
        params = _flatten_dict(params)
        params = _sanitize_params(params)
        if not metrics:
            metrics = {'hparams_metrics': -1}
        self.writer.add_hparams(params, metrics)

    @rank_zero_only
    def log_metrics(self, prefix: tp.Union[str, tp.List[str]], metrics: dict,
                    step: tp.Optional[int] = None) -> None:
        """Records metrics.

        This method logs metrics as soon as it receives them.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger.
            metrics: Dictionary with metric names as keys and measured quantities as values.
                A dict value is logged with ``add_scalars``; anything else must be a scalar
                or a single-element tensor.
            step: Step number at which the metrics should be recorded.

        Raises:
            ValueError: If a value cannot be logged as a scalar.
        """
        assert is_rank_zero(), "experiment tried to log from global_rank != 0"
        if self.is_disabled():
            return
        metrics = _add_prefix(metrics, prefix, self.group_separator)
        for key, val in metrics.items():
            if isinstance(val, dict):
                self.writer.add_scalars(key, val, step)
                continue
            try:
                # `.item()` is inside the try so that multi-element tensors also produce
                # the explanatory ValueError below instead of a bare RuntimeError.
                if isinstance(val, torch.Tensor):
                    val = val.item()
                self.writer.add_scalar(key, val, step)
            except Exception as ex:
                msg = f"\n you tried to log {val} ({type(val)}) which is currently not supported. " \
                      "Try a dict or a scalar/tensor."
                raise ValueError(msg) from ex

    @rank_zero_only
    def log_audio(self, key: str, prefix: tp.Union[str, tp.List[str]], audio: tp.Any, sample_rate: int,
                  step: tp.Optional[int] = None, **kwargs: tp.Any) -> None:
        """Records audio.

        This method logs an audio wave as soon as it receives it.

        Args:
            key: Key for the audio.
            prefix: Prefix(es) to use for metric names when writing to smart logger.
            audio: Torch Tensor representing the audio as [C, T].
            sample_rate: Sample rate corresponding to the audio.
            step: Step number at which the metrics should be recorded.
            kwargs: Extra keyword arguments forwarded to ``SummaryWriter.add_audio``.
        """
        assert is_rank_zero(), "experiment tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return
        assert isinstance(audio, torch.Tensor), "Only support logging torch.Tensor as audio"
        # Downmix to one channel by averaging over the channel axis (dim -2 of [C, T]) and
        # clamp samples into [-0.99, 0.99] before handing them to tensorboard.
        metrics = {key: audio.mean(dim=-2, keepdim=True).clamp(-0.99, 0.99)}
        metrics = _add_prefix(metrics, prefix, self.group_separator)
        for name, media in metrics.items():
            self.writer.add_audio(name, media, step, sample_rate, **kwargs)

    @rank_zero_only
    def log_image(self, prefix: tp.Union[str, tp.List[str]], key: str, image: tp.Any,
                  step: tp.Optional[int] = None, **kwargs: tp.Any) -> None:
        """Records image.

        This method logs an image as soon as it receives it.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger.
            key: Key for the image.
            image: Torch Tensor representing the image.
            step: Step number at which the metrics should be recorded.
            kwargs: Extra keyword arguments forwarded to ``SummaryWriter.add_image``.
        """
        assert is_rank_zero(), "experiment tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return
        assert isinstance(image, torch.Tensor), "Only support logging torch.Tensor as image"
        metrics = {key: image}
        metrics = _add_prefix(metrics, prefix, self.group_separator)
        for name, media in metrics.items():
            self.writer.add_image(name, media, step, **kwargs)

    @rank_zero_only
    def log_text(self, prefix: tp.Union[str, tp.List[str]], key: str, text: str,
                 step: tp.Optional[int] = None, **kwargs: tp.Any) -> None:
        """Records text.

        This method logs text as soon as it receives it.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger.
            key: Key for the text.
            text: String containing message.
            step: Step number at which the metrics should be recorded.
            kwargs: Extra keyword arguments forwarded to ``SummaryWriter.add_text``.
        """
        assert is_rank_zero(), "writer tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return
        metrics = {key: text}
        metrics = _add_prefix(metrics, prefix, self.group_separator)
        for name, media in metrics.items():
            self.writer.add_text(name, media, step, **kwargs)

    @property
    def save_dir(self) -> str:
        """Directory where the data is saved."""
        return self._save_dir

    # Must be a property: the log_* methods read `self.with_media_logging` without calling
    # it. As a plain method the bound method object is always truthy, so media logging
    # could never be switched off.
    @property
    def with_media_logging(self) -> bool:
        """Whether the logger can save media or ignore them."""
        return self._with_media_logging

    @property
    def name(self) -> str:
        """Name of the experiment logger."""
        return self._name

    @classmethod
    def from_xp(cls, with_media_logging: bool = True, name: tp.Optional[str] = None,
                sub_dir: tp.Optional[str] = None, **kwargs: tp.Any) -> 'TensorboardLogger':
        """Create a logger writing into the current Dora experiment's `tensorboard` folder.

        Args:
            with_media_logging: Whether to save media samples.
            name: Name for the experiment logs.
            sub_dir: Optional sub-directory inside the experiment's `tensorboard` folder.
            kwargs: Extra keyword arguments forwarded to ``SummaryWriter``.
        Returns:
            An instance of ``cls`` (so subclasses get their own type).
        """
        save_dir = dora.get_xp().folder / 'tensorboard'
        if sub_dir:
            save_dir = save_dir / sub_dir
        save_dir.mkdir(exist_ok=True, parents=True)
        return cls(save_dir, with_media_logging, name=name, **kwargs)
Classes
class TensorboardLogger (save_dir: Union[pathlib.Path, str], with_media_logging: bool = True, name: Optional[str] = None, **kwargs: Any)
-
ExperimentLogger for Tensorboard
Args
save_dir
:Union[pathlib.Path, str]
- The directory where the experiment logs are written to
with_media_logging
:bool
- Whether to save media samples with the logger or ignore them. This comes at extra storage cost for Tensorboard, so we might not want to enforce it.
name
:str
- Name for the experiment logs
kwargs
:Any
- Additional tensorboard parameters
Expand source code
class TensorboardLogger(ExperimentLogger):
    """ExperimentLogger for Tensorboard.

    Args:
        save_dir (str or Path): The directory where the experiment logs are written to.
        with_media_logging (bool): Whether to save media samples with the logger or ignore them.
            This comes at extra storage cost for Tensorboard, so we might not want to enforce it.
        name (str, optional): Name for the experiment logs. Defaults to ``'tensorboard'``.
        kwargs (Any): Additional keyword arguments forwarded to ``SummaryWriter``.
    """
    def __init__(self, save_dir: tp.Union[Path, str], with_media_logging: bool = True,
                 name: tp.Optional[str] = None, **kwargs: tp.Any):
        self._with_media_logging = with_media_logging
        self._save_dir = str(save_dir)
        self._name = name or 'tensorboard'
        self._writer: tp.Optional[SummaryWriter] = None
        # `SummaryWriter` is None when the optional tensorboard import failed at module load;
        # the logger then stays constructible but every log_* call returns early (is_disabled).
        if SummaryWriter:
            self._writer = SummaryWriter(self.save_dir, **kwargs)
        else:
            warnings.warn("tensorboard package was not found: use pip install tensorboard")

    @property  # type: ignore
    @rank_zero_only
    def writer(self) -> tp.Optional[SummaryWriter]:
        """Actual tensorboard writer object, or None if tensorboard is unavailable."""
        return self._writer

    @rank_zero_only
    def is_disabled(self) -> bool:
        """Whether logging is disabled because no tensorboard writer is available."""
        return self.writer is None

    @rank_zero_only
    def log_hyperparams(self, params: tp.Union[tp.Dict[str, tp.Any], Namespace],
                        metrics: tp.Optional[dict] = None) -> None:
        """Record experiment hyperparameters.

        This method logs the hyperparameters associated to the experiment.

        Args:
            params: Dictionary (or Namespace) of hyperparameters.
            metrics: Dictionary of final metrics for the set of hyperparameters. When it is
                missing or empty, the placeholder ``{'hparams_metrics': -1}`` is logged instead.
        """
        assert is_rank_zero(), "experiment tried to log from global_rank != 0"
        if self.is_disabled():
            return
        params = _convert_params(params)
        params = _flatten_dict(params)
        params = _sanitize_params(params)
        if not metrics:
            metrics = {'hparams_metrics': -1}
        self.writer.add_hparams(params, metrics)

    @rank_zero_only
    def log_metrics(self, prefix: tp.Union[str, tp.List[str]], metrics: dict,
                    step: tp.Optional[int] = None) -> None:
        """Records metrics.

        This method logs metrics as soon as it receives them.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger.
            metrics: Dictionary with metric names as keys and measured quantities as values.
                A dict value is logged with ``add_scalars``; anything else must be a scalar
                or a single-element tensor.
            step: Step number at which the metrics should be recorded.

        Raises:
            ValueError: If a value cannot be logged as a scalar.
        """
        assert is_rank_zero(), "experiment tried to log from global_rank != 0"
        if self.is_disabled():
            return
        metrics = _add_prefix(metrics, prefix, self.group_separator)
        for key, val in metrics.items():
            if isinstance(val, dict):
                self.writer.add_scalars(key, val, step)
                continue
            try:
                # `.item()` is inside the try so that multi-element tensors also produce
                # the explanatory ValueError below instead of a bare RuntimeError.
                if isinstance(val, torch.Tensor):
                    val = val.item()
                self.writer.add_scalar(key, val, step)
            except Exception as ex:
                msg = f"\n you tried to log {val} ({type(val)}) which is currently not supported. " \
                      "Try a dict or a scalar/tensor."
                raise ValueError(msg) from ex

    @rank_zero_only
    def log_audio(self, key: str, prefix: tp.Union[str, tp.List[str]], audio: tp.Any, sample_rate: int,
                  step: tp.Optional[int] = None, **kwargs: tp.Any) -> None:
        """Records audio.

        This method logs an audio wave as soon as it receives it.

        Args:
            key: Key for the audio.
            prefix: Prefix(es) to use for metric names when writing to smart logger.
            audio: Torch Tensor representing the audio as [C, T].
            sample_rate: Sample rate corresponding to the audio.
            step: Step number at which the metrics should be recorded.
            kwargs: Extra keyword arguments forwarded to ``SummaryWriter.add_audio``.
        """
        assert is_rank_zero(), "experiment tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return
        assert isinstance(audio, torch.Tensor), "Only support logging torch.Tensor as audio"
        # Downmix to one channel by averaging over the channel axis (dim -2 of [C, T]) and
        # clamp samples into [-0.99, 0.99] before handing them to tensorboard.
        metrics = {key: audio.mean(dim=-2, keepdim=True).clamp(-0.99, 0.99)}
        metrics = _add_prefix(metrics, prefix, self.group_separator)
        for name, media in metrics.items():
            self.writer.add_audio(name, media, step, sample_rate, **kwargs)

    @rank_zero_only
    def log_image(self, prefix: tp.Union[str, tp.List[str]], key: str, image: tp.Any,
                  step: tp.Optional[int] = None, **kwargs: tp.Any) -> None:
        """Records image.

        This method logs an image as soon as it receives it.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger.
            key: Key for the image.
            image: Torch Tensor representing the image.
            step: Step number at which the metrics should be recorded.
            kwargs: Extra keyword arguments forwarded to ``SummaryWriter.add_image``.
        """
        assert is_rank_zero(), "experiment tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return
        assert isinstance(image, torch.Tensor), "Only support logging torch.Tensor as image"
        metrics = {key: image}
        metrics = _add_prefix(metrics, prefix, self.group_separator)
        for name, media in metrics.items():
            self.writer.add_image(name, media, step, **kwargs)

    @rank_zero_only
    def log_text(self, prefix: tp.Union[str, tp.List[str]], key: str, text: str,
                 step: tp.Optional[int] = None, **kwargs: tp.Any) -> None:
        """Records text.

        This method logs text as soon as it receives it.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger.
            key: Key for the text.
            text: String containing message.
            step: Step number at which the metrics should be recorded.
            kwargs: Extra keyword arguments forwarded to ``SummaryWriter.add_text``.
        """
        assert is_rank_zero(), "writer tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return
        metrics = {key: text}
        metrics = _add_prefix(metrics, prefix, self.group_separator)
        for name, media in metrics.items():
            self.writer.add_text(name, media, step, **kwargs)

    @property
    def save_dir(self) -> str:
        """Directory where the data is saved."""
        return self._save_dir

    # Must be a property: the log_* methods read `self.with_media_logging` without calling
    # it. As a plain method the bound method object is always truthy, so media logging
    # could never be switched off.
    @property
    def with_media_logging(self) -> bool:
        """Whether the logger can save media or ignore them."""
        return self._with_media_logging

    @property
    def name(self) -> str:
        """Name of the experiment logger."""
        return self._name

    @classmethod
    def from_xp(cls, with_media_logging: bool = True, name: tp.Optional[str] = None,
                sub_dir: tp.Optional[str] = None, **kwargs: tp.Any) -> 'TensorboardLogger':
        """Create a logger writing into the current Dora experiment's `tensorboard` folder.

        Args:
            with_media_logging: Whether to save media samples.
            name: Name for the experiment logs.
            sub_dir: Optional sub-directory inside the experiment's `tensorboard` folder.
            kwargs: Extra keyword arguments forwarded to ``SummaryWriter``.
        Returns:
            An instance of ``cls`` (so subclasses get their own type).
        """
        save_dir = dora.get_xp().folder / 'tensorboard'
        if sub_dir:
            save_dir = save_dir / sub_dir
        save_dir.mkdir(exist_ok=True, parents=True)
        return cls(save_dir, with_media_logging, name=name, **kwargs)
Ancestors
- ExperimentLogger
- abc.ABC
Static methods
def from_xp(with_media_logging: bool = True, name: Optional[str] = None, sub_dir: Optional[str] = None)
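-
Create a TensorboardLogger that writes into the `tensorboard` folder of the current Dora experiment (optionally inside `sub_dir`), creating that folder if it does not exist.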
Expand source code
@classmethod
def from_xp(cls, with_media_logging: bool = True, name: tp.Optional[str] = None,
            sub_dir: tp.Optional[str] = None, **kwargs: tp.Any):
    """Create a logger writing into the current Dora experiment's `tensorboard` folder.

    Args:
        with_media_logging: Whether to save media samples.
        name: Name for the experiment logs.
        sub_dir: Optional sub-directory inside the experiment's `tensorboard` folder.
        kwargs: Extra keyword arguments forwarded to the logger constructor.
    Returns:
        An instance of ``cls``; using ``cls`` instead of a hard-coded class name keeps
        subclasses returning their own type.
    """
    save_dir = dora.get_xp().folder / 'tensorboard'
    if sub_dir:
        save_dir = save_dir / sub_dir
    save_dir.mkdir(exist_ok=True, parents=True)
    return cls(save_dir, with_media_logging, name=name, **kwargs)
Instance variables
var writer : Optional[SummaryWriter]
-
Actual tensorboard writer object (None when the tensorboard package is not installed).
Expand source code
@property  # type: ignore
@rank_zero_only
def writer(self) -> tp.Optional[SummaryWriter]:
    """Underlying tensorboard ``SummaryWriter``; None if it could not be created."""
    current_writer = self._writer
    return current_writer
Methods
def is_disabled(self) ‑> bool
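-
Whether logging is disabled because no tensorboard writer is available (for example when the tensorboard package is not installed).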
Expand source code
@rank_zero_only
def is_disabled(self) -> bool:
    """Tell whether there is no tensorboard writer to log to."""
    has_writer = self.writer is not None
    return not has_writer
Inherited members