Module flashy.loggers.wandb

Expand source code
from argparse import Namespace
import typing as tp
import warnings

import dora
import torch

try:
    import wandb
except ModuleNotFoundError:
    wandb = None  # type: ignore

from . import ExperimentLogger
from .utils import (
    _add_prefix, _convert_params, _flatten_dict,
    _sanitize_params
)
from ..distrib import rank_zero_only, is_rank_zero


class WandbLogger(ExperimentLogger):
    """ExperimentLogger for Wandb (Weight and Biases)

    Args:
        save_dir (str): The directory where the experiment logs are written to
        with_media_logging (bool): Whether to save media samples with the logger or ignore them.
            This incurs bandwidth costs on Wandb but is usually acceptable
        project (str): Wandb project name
        id (Optional[str]): Wandb run id
        group (Optional[str]): Wandb group
        reinit (bool): Whether to reinitialize the run
        resume (Union[str, bool]): Whether to resume the run (or a wandb resume mode such as 'allow')
        name (str): Name for the experiment logs
        kwargs (Any): Additional wandb parameters, including tags for example
    """
    def __init__(self, save_dir: str, with_media_logging: bool = True, project: tp.Optional[str] = None,
                 id: tp.Optional[str] = None, group: tp.Optional[str] = None, reinit: bool = False,
                 resume: tp.Union[str, bool] = False, name: tp.Optional[str] = None, **kwargs: tp.Any):
        self._save_dir = save_dir
        self._with_media_logging = with_media_logging
        self._name = name or 'wandb'
        self._project = project
        self._id = id
        self._group = group
        self._reinit = reinit
        self._resume = resume
        self._wandb_run: tp.Optional[tp.Any] = None
        if wandb:
            self._wandb_run = wandb.init(
                project=self._project,
                reinit=self._reinit,
                group=self._group,
                name=self._name,
                dir=self.save_dir,
                id=self._id,
                resume=self._resume,
                **kwargs
            )
        else:
            warnings.warn("wandb package was not found: use pip install wandb")

    @property  # type: ignore
    @rank_zero_only
    def writer(self) -> tp.Optional[tp.Any]:
        """Actual wandb run object."""
        return self._wandb_run

    @rank_zero_only
    def is_disabled(self) -> bool:
        return self.writer is None

    @rank_zero_only
    def log_hyperparams(self, params: tp.Union[tp.Dict[str, tp.Any], Namespace],
                        metrics: tp.Optional[dict] = None) -> None:
        """Record experiment hyperparameters.
        This method logs the hyperparameters associated to the experiment.

        Args:
            params: Dictionary of hyperparameters
            metrics: Dictionary of final metrics for the set of hyperparameters
        """
        assert is_rank_zero(), "experiment tried to log from global_rank != 0"

        params = _convert_params(params)
        params = _flatten_dict(params)
        params = _sanitize_params(params)
        self.writer.config.update(params, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self, prefix: tp.Union[str, tp.List[str]], metrics: dict, step: tp.Optional[int] = None) -> None:
        """Records metrics.
        This method logs metrics as soon as it receives them.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger
            metrics: Dictionary with metric names as keys and measured quantities as values
            step: Step number at which the metrics should be recorded
        """
        assert is_rank_zero(), "writer tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return

        metrics = _add_prefix(metrics, prefix, self.group_separator)

        for key, val in metrics.items():
            wandb.log({key: val}, step=step)

    @rank_zero_only
    def log_audio(self, prefix: tp.Union[str, tp.List[str]], key: str, audio: tp.Any, sample_rate: int,
                  step: tp.Optional[int] = None, **kwargs: tp.Any) -> None:
        """Records audio.
        This method logs audio waveforms as soon as it receives them.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger
            key: Key for the audio
            audio: Torch Tensor representing the audio as [C, T]
            step: Step number at which the metrics should be recorded
        """
        assert is_rank_zero(), "writer tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return

        assert isinstance(audio, torch.Tensor), "Only support logging torch.Tensor as audio"

        audio = audio.t().clamp(-0.99, 0.99).numpy()
        metrics = {
            key: wandb.Audio(audio, sample_rate=sample_rate, **kwargs)
        }
        # log_metrics applies the prefix and writes through wandb.
        self.log_metrics(prefix, metrics, step)

    @rank_zero_only
    def log_image(self, prefix: tp.Union[str, tp.List[str]], key: str, image: tp.Any,
                  step: tp.Optional[int] = None, **kwargs: tp.Any) -> None:
        """Records image.
        This method logs images as soon as it receives them.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger
            key: Key for the image
            image: Torch Tensor representing the image
            step: Step number at which the metrics should be recorded
        """
        assert is_rank_zero(), "writer tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return

        assert isinstance(image, torch.Tensor), "Only support logging torch.Tensor as image"

        metrics = {
            key: wandb.Image(image, **kwargs)
        }
        self.log_metrics(prefix, metrics, step)

    @rank_zero_only
    def log_text(self, prefix: tp.Union[str, tp.List[str]], key: str, text: str,
                 step: tp.Optional[int] = None, **kwargs: tp.Any) -> None:
        """Records text.
        This method logs text as soon as it receives it.

        Args:
            prefix: Prefix(es) to use for metric names when writing to smart logger
            key: Key for the text
            text: String containing message
            step: Step number at which the metrics should be recorded
        """
        assert is_rank_zero(), "writer tried to log from global_rank != 0"
        if self.is_disabled() or not self.with_media_logging:
            return

        metrics = {
            key: wandb.Table(columns=[key], data=[text], **kwargs)
        }
        self.log_metrics(prefix, metrics, step)

    @property
    def with_media_logging(self) -> bool:
        """Whether the logger can save media or ignore them."""
        return self._with_media_logging

    @property
    def save_dir(self):
        """Directory where the data is saved."""
        return self._save_dir

    @property
    def name(self):
        """Name of the experiment logger."""
        return self._name

    @classmethod
    def from_xp(cls, with_media_logging: bool = True, project: tp.Optional[str] = None, name: tp.Optional[str] = None,
                group: tp.Optional[str] = None, **kwargs: tp.Any):
        xp = dora.get_xp()
        save_dir = xp.folder
        save_dir.mkdir(exist_ok=True)
        flag_file = xp.folder / 'wandb_flag'
        resume = flag_file.exists()
        flag_file.touch()
        config = None
        if wandb:
            api = wandb.Api()
            try:
                if project:
                    run = api.run(project + '/' + xp.sig)
                else:
                    run = api.run(xp.sig)
            except wandb.CommError:
                pass
            else:
                group = run.group
                name = run.name
                config = run.config
        return WandbLogger(str(save_dir), with_media_logging, id=xp.sig, name=name or xp.sig, group=group,
                           config=config, project=project, resume='allow' if resume else False, **kwargs)
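
For orientation, a minimal usage sketch based on the source above; the directory, project name and values are illustrative, and it assumes the wandb package is installed and the code runs on rank zero.

import torch
from flashy.loggers.wandb import WandbLogger

logger = WandbLogger(save_dir='./logs', project='my_project', name='baseline')
logger.log_hyperparams({'lr': 3e-4, 'batch_size': 16})
logger.log_metrics('train', {'loss': 0.42}, step=100)
# Audio is expected as a [C, T] torch.Tensor; it is clamped and converted before upload.
logger.log_audio('valid', 'sample_0', torch.zeros(1, 16000), sample_rate=16000, step=100)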

Classes

class WandbLogger (save_dir: str, with_media_logging: bool = True, project: Optional[str] = None, id: Optional[str] = None, group: Optional[str] = None, reinit: bool = False, resume: Union[str, bool] = False, name: Optional[str] = None, **kwargs: Any)

ExperimentLogger for Wandb (Weights & Biases)

Args

save_dir : str
The directory where the experiment logs are written to
with_media_logging : bool
Whether to save media samples with the logger or ignore them. This incurs bandwidth costs on Wandb but is usually acceptable
project : str
Wandb project name
id : Optional[str]
Wandb run id
group : Optional[str]
Wandb group
reinit : bool
Whether to reinitialize the run
resume : Union[str, bool]
Whether to resume the run (or a wandb resume mode such as 'allow')
name : str
Name for the experiment logs
kwargs : Any
Additional wandb parameters, including tags for example
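
Extra keyword arguments are forwarded to wandb.init, so wandb options such as tags can be passed straight through. A small sketch with illustrative values:

logger = WandbLogger(
    save_dir='./logs',
    with_media_logging=True,
    project='my_project',        # illustrative project name
    name='exp_001',
    tags=['baseline', 'fp16'],   # forwarded to wandb.init via **kwargs
)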

Ancestors

ExperimentLogger

Static methods

def from_xp(with_media_logging: bool = True, project: Optional[str] = None, name: Optional[str] = None, group: Optional[str] = None, **kwargs: Any)
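
from_xp builds the logger from the current Dora experiment: it uses the XP signature as run id (and default name), resumes when a wandb_flag file already exists in the XP folder, and reuses the group, name and config of an existing run if the wandb API finds one. A sketch, assuming it is called from inside a Dora-managed run so that dora.get_xp() resolves; the project and group are illustrative:

logger = WandbLogger.from_xp(project='my_project', group='ablations')
logger.log_metrics('train', {'loss': 0.42}, step=1)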

Instance variables

var writer : Optional[Any]

Actual wandb run object.
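
The writer is the underlying wandb run object, or None when the wandb package is unavailable, so wandb features the logger does not wrap can be reached through it. A small sketch; the summary key and value are illustrative:

run = logger.writer
if run is not None:
    run.summary['best_valid_loss'] = 0.21  # plain wandb Run API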

Methods

def is_disabled(self) ‑> bool

Whether the logger is disabled, i.e. no wandb run could be initialized.

def log_audio(self, prefix: Union[str, List[str]], key: str, audio: Any, sample_rate: int, step: Optional[int] = None, **kwargs: Any) ‑> None

Records audio. This method logs audio waveforms as soon as it receives them.

Args

prefix
Prefix(es) to use for metric names when writing to smart logger
key
Key for the audio
audio
Torch Tensor representing the audio as [C, T]
step
Step number at which the metrics should be recorded
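
A short sketch of logging a waveform with an already constructed logger; the one-second 440 Hz mono tone and the prefix are illustrative, and the call is skipped when media logging is disabled:

import math
import torch

sr = 16000
t = torch.arange(sr, dtype=torch.float32) / sr
wave = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)  # shape [C=1, T]
logger.log_audio('valid', 'tone_440hz', wave, sample_rate=sr, step=10)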

Inherited members