docs for trnbl v0.1.1
View Source on GitHub

trnbl.loggers.wandb


 1import datetime
 2import json
 3import logging
 4from typing import Any
 5from pathlib import Path
 6import sys
 7
 8import wandb
 9from wandb.sdk.wandb_run import Run, Artifact
10
11from trnbl.loggers.base import TrainingLoggerBase
12
13
class WandbLogger(TrainingLoggerBase):
	"""wrapper around wandb logging for `TrainingLoggerBase`. create using `WandbLogger.create(config, project, job_type)`"""

	def __init__(self, run: Run):
		"""wrap an already-initialized wandb `Run`"""
		# every metric / artifact call below is forwarded to this run
		self._run: Run = run

	@classmethod
	def create(
		cls,
		config: dict,
		project: str | None = None,
		job_type: str | None = None,
		logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
		logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
		wandb_kwargs: dict | None = None,
	) -> "WandbLogger":
		"""configure stdout logging, start a wandb run, and wrap it in a logger

		# Parameters:
		 - `config : dict`
		   run configuration; passed to `wandb.init` and also logged as a message
		 - `project : str | None`
		   wandb project name, passed to `wandb.init`
		 - `job_type : str | None`
		   wandb job type, passed to `wandb.init`
		 - `logging_fmt : str`
		   `format` argument for `logging.basicConfig`
		 - `logging_datefmt : str`
		   `datefmt` argument for `logging.basicConfig`
		 - `wandb_kwargs : dict | None`
		   extra keyword arguments forwarded to `wandb.init`

		# Returns:
		 - an instance of `cls` wrapping the newly created run

		# Raises:
		 - `AssertionError` if `wandb.init` returns `None`
		"""
		# NOTE: `basicConfig` does nothing if the root logger already has handlers
		logging.basicConfig(
			stream=sys.stdout,
			level=logging.INFO,
			format=logging_fmt,
			datefmt=logging_datefmt,
		)

		run: Run  # type: ignore[return-value]
		run = wandb.init(
			config=config,
			project=project,
			job_type=job_type,
			**(wandb_kwargs if wandb_kwargs else {}),
		)

		assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

		# use `cls` rather than a hard-coded class so subclasses get their own type
		logger: WandbLogger = cls(run)
		logger.message(f"{config =}")
		return logger

	def debug(self, message: str, **kwargs) -> None:
		"""log `message` via `logging.debug`, appending any keyword arguments"""
		if kwargs:
			message += f" {kwargs =}"
		logging.debug(message)

	def message(self, message: str, **kwargs) -> None:
		"""log `message` via `logging.info`, appending any keyword arguments"""
		if kwargs:
			message += f" {kwargs =}"
		logging.info(message)

	def metrics(self, data: dict[str, Any]) -> None:
		"""send a dictionary of metrics to the wandb run via `Run.log`"""
		self._run.log(data)

	def artifact(
		self,
		path: Path,
		type: str,
		aliases: list[str] | None = None,
		metadata: dict | None = None,
	) -> None:
		"""log the file at `path` as a wandb artifact named after the run id

		if `metadata` is non-empty, a JSON description containing a timestamp,
		the path, type, aliases, and the metadata is attached to the artifact.
		"""
		# accept plain strings as well; `as_posix()` below requires a `Path`
		path = Path(path)
		# NOTE(review): every artifact from this run shares the name `run.id`;
		# wandb may reject reusing a name with a different `type` -- confirm
		artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
		artifact.add_file(str(path))
		if metadata:
			artifact.description = json.dumps(
				dict(
					timestamp=datetime.datetime.now().isoformat(),
					path=path.as_posix(),
					type=type,
					aliases=aliases,
					metadata=metadata,
				)
			)
		self._run.log_artifact(artifact, aliases=aliases)

	@property
	def url(self) -> str:
		"""URL of the wandb run, converted to a string"""
		# TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
		return str(self._run.get_url())

	@property
	def run_path(self) -> Path:
		"""path of the wandb run (`entity/project/run_id`) as a `Path`"""
		# `Run.path` is the public equivalent of the private `Run._get_path()`
		return Path(self._run.path)

	def flush(self) -> None:
		"""no-op: wandb persists logged data on its own

		calling `Run.save()` without arguments is a deprecated no-op in wandb
		that only emits a deprecation warning, so nothing is called here.
		"""

	def finish(self) -> None:
		"""Finish logging"""
		self._run.finish()

 15class WandbLogger(TrainingLoggerBase):
 16	"""wrapper around wandb logging for `TrainingLoggerBase`. create using `WandbLogger.create(config, project, job_type)`"""
 17
 18	def __init__(self, run: Run):
 19		self._run: Run = run
 20
 21	@classmethod
 22	def create(
 23		cls,
 24		config: dict,
 25		project: str | None = None,
 26		job_type: str | None = None,
 27		logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
 28		logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
 29		wandb_kwargs: dict | None = None,
 30	) -> "WandbLogger":
 31		logging.basicConfig(
 32			stream=sys.stdout,
 33			level=logging.INFO,
 34			format=logging_fmt,
 35			datefmt=logging_datefmt,
 36		)
 37
 38		run: Run  # type: ignore[return-value]
 39		run = wandb.init(
 40			config=config,
 41			project=project,
 42			job_type=job_type,
 43			**(wandb_kwargs if wandb_kwargs else {}),
 44		)
 45
 46		assert run is not None, f"wandb.init returned None: {wandb_kwargs}"
 47
 48		logger: WandbLogger = WandbLogger(run)
 49		logger.message(f"{config =}")
 50		return logger
 51
 52	def debug(self, message: str, **kwargs) -> None:
 53		if kwargs:
 54			message += f" {kwargs =}"
 55		logging.debug(message)
 56
 57	def message(self, message: str, **kwargs) -> None:
 58		if kwargs:
 59			message += f" {kwargs =}"
 60		logging.info(message)
 61
 62	def metrics(self, data: dict[str, Any]) -> None:
 63		self._run.log(data)
 64
 65	def artifact(
 66		self,
 67		path: Path,
 68		type: str,
 69		aliases: list[str] | None = None,
 70		metadata: dict | None = None,
 71	) -> None:
 72		artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
 73		artifact.add_file(str(path))
 74		if metadata:
 75			artifact.description = json.dumps(
 76				dict(
 77					timestamp=datetime.datetime.now().isoformat(),
 78					path=path.as_posix(),
 79					type=type,
 80					aliases=aliases,
 81					metadata=metadata if metadata else {},
 82				)
 83			)
 84		self._run.log_artifact(artifact, aliases=aliases)
 85
 86	@property
 87	def url(self) -> str:
 88		# TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
 89		return str(self._run.get_url())
 90
 91	@property
 92	def run_path(self) -> Path:
 93		return Path(self._run._get_path())
 94
 95	def flush(self) -> None:
 96		self._run.save()
 97
 98	def finish(self) -> None:
 99		"""Finish logging"""
100		self._run.finish()

wrapper around wandb logging for TrainingLoggerBase. create using WandbLogger.create(config, project, job_type)

WandbLogger(run: wandb.sdk.wandb_run.Run)
18	def __init__(self, run: Run):
19		self._run: Run = run
@classmethod
def create( cls, config: dict, project: str | None = None, job_type: str | None = None, logging_fmt: str = '%(asctime)s %(levelname)s %(message)s', logging_datefmt: str = '%Y-%m-%d %H:%M:%S', wandb_kwargs: dict | None = None) -> WandbLogger:
21	@classmethod
22	def create(
23		cls,
24		config: dict,
25		project: str | None = None,
26		job_type: str | None = None,
27		logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
28		logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
29		wandb_kwargs: dict | None = None,
30	) -> "WandbLogger":
31		logging.basicConfig(
32			stream=sys.stdout,
33			level=logging.INFO,
34			format=logging_fmt,
35			datefmt=logging_datefmt,
36		)
37
38		run: Run  # type: ignore[return-value]
39		run = wandb.init(
40			config=config,
41			project=project,
42			job_type=job_type,
43			**(wandb_kwargs if wandb_kwargs else {}),
44		)
45
46		assert run is not None, f"wandb.init returned None: {wandb_kwargs}"
47
48		logger: WandbLogger = WandbLogger(run)
49		logger.message(f"{config =}")
50		return logger
def debug(self, message: str, **kwargs) -> None:
52	def debug(self, message: str, **kwargs) -> None:
53		if kwargs:
54			message += f" {kwargs =}"
55		logging.debug(message)

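log a debug message which will be saved, but not printed. Note: `WandbLogger` sends it through `logging.debug`, and `WandbLogger.create` sets the root logging level to `INFO`, so by default these messages are dropped rather than saved.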

def message(self, message: str, **kwargs) -> None:
57	def message(self, message: str, **kwargs) -> None:
58		if kwargs:
59			message += f" {kwargs =}"
60		logging.info(message)

log a progress message, which will be printed to stdout

def metrics(self, data: dict[str, typing.Any]) -> None:
62	def metrics(self, data: dict[str, Any]) -> None:
63		self._run.log(data)

Log a dictionary of metrics

def artifact( self, path: pathlib.Path, type: str, aliases: list[str] | None = None, metadata: dict | None = None) -> None:
65	def artifact(
66		self,
67		path: Path,
68		type: str,
69		aliases: list[str] | None = None,
70		metadata: dict | None = None,
71	) -> None:
72		artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
73		artifact.add_file(str(path))
74		if metadata:
75			artifact.description = json.dumps(
76				dict(
77					timestamp=datetime.datetime.now().isoformat(),
78					path=path.as_posix(),
79					type=type,
80					aliases=aliases,
81					metadata=metadata if metadata else {},
82				)
83			)
84		self._run.log_artifact(artifact, aliases=aliases)

log an artifact from a file

url: str
86	@property
87	def url(self) -> str:
88		# TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
89		return str(self._run.get_url())

Get the URL for the current logging run. For offline wandb runs, `get_url()` returns `None`, so this property returns the string `"None"`.

run_path: pathlib.Path
91	@property
92	def run_path(self) -> Path:
93		return Path(self._run._get_path())

Get the path to the current logging run

def flush(self) -> None:
95	def flush(self) -> None:
96		self._run.save()

Flush the logger

def finish(self) -> None:
	def finish(self) -> None:
		"""Finish logging by closing the underlying wandb run"""
		run = self._run
		run.finish()

Finish logging