docs for trnbl v0.1.1
View Source on GitHub

trnbl.loggers.multi


 1from typing import Any, TypeVar
 2from pathlib import Path
 3
 4from trnbl.loggers.base import TrainingLoggerBase
 5
 6# TODO: move this to muutils
 7T = TypeVar("T")
 8
 9
def maybe_flatten(lst: list[T | list[T]]) -> list[T]:
	"""flatten a list if it is nested

	Exactly one level of nesting is removed. Elements that are `list`
	instances are spliced into the result. Every other element is kept as-is,
	in its original position. Lists nested deeper are left untouched.

	```python
	>>> maybe_flatten([1, [2, 3], 4])
	[1, 2, 3, 4]
	>>> maybe_flatten([[1, [2]]])
	[1, [2]]
	```
	"""
	return [
		element
		for entry in lst
		# wrap scalars so both cases share the same inner loop
		for element in (entry if isinstance(entry, list) else [entry])
	]
19
20
class MultiLogger(TrainingLoggerBase):
	"""use multiple loggers at once

	Every logging call is forwarded, in list order, to each logger in
	`loggers`. `flush` and `finish` attempt every wrapped logger even if an
	earlier one raises. The first exception is re-raised once all of them
	have been attempted.
	"""

	def __init__(self, loggers: list[TrainingLoggerBase]) -> None:
		"""Wrap `loggers`, which receive every call made on this object."""
		# stored as given (not copied): later changes to the caller's list are visible here
		self.loggers: list[TrainingLoggerBase] = loggers

	def debug(self, message: str, **kwargs) -> None:
		"""log a debug message which will be saved, but not printed

		`message` and any extra keyword arguments are handed unchanged to
		each wrapped logger, in order.
		"""
		for sub_logger in self.loggers:
			sub_logger.debug(message, **kwargs)

	def message(self, message: str, **kwargs) -> None:
		"""log a progress message on every wrapped logger

		`message` and any extra keyword arguments are handed unchanged to
		each wrapped logger, in order.
		"""
		for sub_logger in self.loggers:
			sub_logger.message(message, **kwargs)

	def metrics(self, data: dict[str, Any]) -> None:
		"""Log a dictionary of metrics to every wrapped logger

		The same `data` dict object is passed to each logger; it is not copied.
		"""
		for sub_logger in self.loggers:
			sub_logger.metrics(data)

	def artifact(
		self,
		path: Path,
		type: str,
		aliases: list[str] | None = None,
		metadata: dict | None = None,
	) -> None:
		"""log an artifact from a file on every wrapped logger

		All arguments are forwarded by keyword, unchanged, to each wrapped
		logger's `artifact` method.
		"""
		for sub_logger in self.loggers:
			sub_logger.artifact(
				path=path,
				type=type,
				aliases=aliases,
				metadata=metadata,
			)

	@property
	def url(self) -> list[str]:
		"""Get the URLs for the current logging run from every wrapped logger

		A wrapped logger whose `url` is itself a list (e.g. a nested
		`MultiLogger`) contributes each of its entries individually.
		"""
		per_logger = [sub_logger.url for sub_logger in self.loggers]
		return maybe_flatten(per_logger)

	@property
	def run_path(self) -> list[Path]:
		"""Get the paths to the current logging run from every wrapped logger

		A wrapped logger whose `run_path` is itself a list (e.g. a nested
		`MultiLogger`) contributes each of its entries individually.
		"""
		per_logger = [sub_logger.run_path for sub_logger in self.loggers]
		return maybe_flatten(per_logger)

	def flush(self) -> None:
		"""Flush every wrapped logger

		Every logger is flushed even if an earlier one raises. After all have
		been attempted, the first exception is re-raised; any later ones are
		dropped.
		"""
		first_error: Exception | None = None
		for sub_logger in self.loggers:
			try:
				sub_logger.flush()
			except Exception as err:
				# keep going so one failing backend cannot block the others
				if first_error is None:
					first_error = err
		if first_error is not None:
			raise first_error

	def finish(self) -> None:
		"""Finish logging on every wrapped logger

		Every logger is finished even if an earlier one raises, so no run is
		left open. After all have been attempted, the first exception is
		re-raised; any later ones are dropped.
		"""
		first_error: Exception | None = None
		for sub_logger in self.loggers:
			try:
				sub_logger.finish()
			except Exception as err:
				# keep going so one failing backend cannot leave the others unfinished
				if first_error is None:
					first_error = err
		if first_error is not None:
			raise first_error

def maybe_flatten(lst: list[typing.Union[~T, list[~T]]]) -> list[~T]:
def maybe_flatten(lst: list[T | list[T]]) -> list[T]:
	"""flatten a list if it is nested

	Exactly one level of nesting is removed. Elements that are `list`
	instances are spliced into the result. Every other element is kept as-is,
	in its original position. Lists nested deeper are left untouched.

	```python
	>>> maybe_flatten([1, [2, 3], 4])
	[1, 2, 3, 4]
	>>> maybe_flatten([[1, [2]]])
	[1, [2]]
	```
	"""
	return [
		element
		for entry in lst
		# wrap scalars so both cases share the same inner loop
		for element in (entry if isinstance(entry, list) else [entry])
	]

Flatten a list by one level: elements that are lists are spliced into the result, and all other elements are kept as-is.

class MultiLogger(TrainingLoggerBase):
	"""use multiple loggers at once

	Every logging call is forwarded, in list order, to each logger in
	`loggers`. `flush` and `finish` attempt every wrapped logger even if an
	earlier one raises. The first exception is re-raised once all of them
	have been attempted.
	"""

	def __init__(self, loggers: list[TrainingLoggerBase]) -> None:
		"""Wrap `loggers`, which receive every call made on this object."""
		# stored as given (not copied): later changes to the caller's list are visible here
		self.loggers: list[TrainingLoggerBase] = loggers

	def debug(self, message: str, **kwargs) -> None:
		"""log a debug message which will be saved, but not printed

		`message` and any extra keyword arguments are handed unchanged to
		each wrapped logger, in order.
		"""
		for sub_logger in self.loggers:
			sub_logger.debug(message, **kwargs)

	def message(self, message: str, **kwargs) -> None:
		"""log a progress message on every wrapped logger

		`message` and any extra keyword arguments are handed unchanged to
		each wrapped logger, in order.
		"""
		for sub_logger in self.loggers:
			sub_logger.message(message, **kwargs)

	def metrics(self, data: dict[str, Any]) -> None:
		"""Log a dictionary of metrics to every wrapped logger

		The same `data` dict object is passed to each logger; it is not copied.
		"""
		for sub_logger in self.loggers:
			sub_logger.metrics(data)

	def artifact(
		self,
		path: Path,
		type: str,
		aliases: list[str] | None = None,
		metadata: dict | None = None,
	) -> None:
		"""log an artifact from a file on every wrapped logger

		All arguments are forwarded by keyword, unchanged, to each wrapped
		logger's `artifact` method.
		"""
		for sub_logger in self.loggers:
			sub_logger.artifact(
				path=path,
				type=type,
				aliases=aliases,
				metadata=metadata,
			)

	@property
	def url(self) -> list[str]:
		"""Get the URLs for the current logging run from every wrapped logger

		A wrapped logger whose `url` is itself a list (e.g. a nested
		`MultiLogger`) contributes each of its entries individually.
		"""
		per_logger = [sub_logger.url for sub_logger in self.loggers]
		return maybe_flatten(per_logger)

	@property
	def run_path(self) -> list[Path]:
		"""Get the paths to the current logging run from every wrapped logger

		A wrapped logger whose `run_path` is itself a list (e.g. a nested
		`MultiLogger`) contributes each of its entries individually.
		"""
		per_logger = [sub_logger.run_path for sub_logger in self.loggers]
		return maybe_flatten(per_logger)

	def flush(self) -> None:
		"""Flush every wrapped logger

		Every logger is flushed even if an earlier one raises. After all have
		been attempted, the first exception is re-raised; any later ones are
		dropped.
		"""
		first_error: Exception | None = None
		for sub_logger in self.loggers:
			try:
				sub_logger.flush()
			except Exception as err:
				# keep going so one failing backend cannot block the others
				if first_error is None:
					first_error = err
		if first_error is not None:
			raise first_error

	def finish(self) -> None:
		"""Finish logging on every wrapped logger

		Every logger is finished even if an earlier one raises, so no run is
		left open. After all have been attempted, the first exception is
		re-raised; any later ones are dropped.
		"""
		first_error: Exception | None = None
		for sub_logger in self.loggers:
			try:
				sub_logger.finish()
			except Exception as err:
				# keep going so one failing backend cannot leave the others unfinished
				if first_error is None:
					first_error = err
		if first_error is not None:
			raise first_error

Use multiple loggers at once: every logging call is forwarded, in order, to each wrapped logger.

MultiLogger(loggers: list[trnbl.TrainingLoggerBase])
25	def __init__(self, loggers: list[TrainingLoggerBase]) -> None:
26		self.loggers: list[TrainingLoggerBase] = loggers
loggers: list[trnbl.TrainingLoggerBase]
def debug(self, message: str, **kwargs) -> None:
28	def debug(self, message: str, **kwargs) -> None:
29		"""log a debug message which will be saved, but not printed"""
30		for logger in self.loggers:
31			logger.debug(message, **kwargs)

log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None:
33	def message(self, message: str, **kwargs) -> None:
34		"""log a progress message"""
35		for logger in self.loggers:
36			logger.message(message, **kwargs)

log a progress message

def metrics(self, data: dict[str, typing.Any]) -> None:
38	def metrics(self, data: dict[str, Any]) -> None:
39		"""Log a dictionary of metrics"""
40		for logger in self.loggers:
41			logger.metrics(data)

Log a dictionary of metrics

def artifact(self, path: pathlib.Path, type: str, aliases: list[str] | None = None, metadata: dict | None = None) -> None:
43	def artifact(
44		self,
45		path: Path,
46		type: str,
47		aliases: list[str] | None = None,
48		metadata: dict | None = None,
49	) -> None:
50		"""log an artifact from a file"""
51		for logger in self.loggers:
52			logger.artifact(path=path, type=type, aliases=aliases, metadata=metadata)

log an artifact from a file

url: list[str]
54	@property
55	def url(self) -> list[str]:
56		"""Get the URL for the current logging run"""
57		return maybe_flatten([logger.url for logger in self.loggers])

Get the URLs for the current logging run, collected from every wrapped logger.

run_path: list[pathlib.Path]
59	@property
60	def run_path(self) -> list[Path]:
61		"""Get the paths to the current logging run"""
62		return maybe_flatten([logger.run_path for logger in self.loggers])

Get the paths to the current logging run

def flush(self) -> None:
64	def flush(self) -> None:
65		"""Flush the logger"""
66		for logger in self.loggers:
67			logger.flush()

Flush every wrapped logger.

def finish(self) -> None:
69	def finish(self) -> None:
70		"""Finish logging"""
71		for logger in self.loggers:
72			logger.finish()

Finish logging on every wrapped logger.