Coverage for trnbl\loggers\multi.py: 0%
38 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
1from typing import Any, TypeVar
2from pathlib import Path
4from trnbl.loggers.base import TrainingLoggerBase
6# TODO: move this to muutils
7T = TypeVar("T")
10def maybe_flatten(lst: list[T | list[T]]) -> list[T]:
11 """flatten a list if it is nested"""
12 flat_lst: list[T] = []
13 for item in lst:
14 if isinstance(item, list):
15 flat_lst.extend(item)
16 else:
17 flat_lst.append(item)
18 return flat_lst
class MultiLogger(TrainingLoggerBase):
    """use multiple loggers at once

    Every method forwards its arguments, unchanged, to the same-named method
    of each logger in `self.loggers`, in list order.
    """

    def __init__(self, loggers: list[TrainingLoggerBase]) -> None:
        """wrap `loggers`; the list is stored by reference, not copied"""
        # NOTE(review): does not call super().__init__(); confirm that
        # TrainingLoggerBase needs no initialisation of its own.
        self.loggers: list[TrainingLoggerBase] = loggers

    def debug(self, message: str, **kwargs) -> None:
        """forward a debug message (and any extra kwargs) to every logger

        whether it is saved and/or printed is decided by each wrapped logger
        """
        for logger in self.loggers:
            logger.debug(message, **kwargs)

    def message(self, message: str, **kwargs) -> None:
        """forward a progress message (and any extra kwargs) to every logger"""
        for logger in self.loggers:
            logger.message(message, **kwargs)

    def metrics(self, data: dict[str, Any]) -> None:
        """forward a dictionary of metrics to every logger

        the same dict object is passed to each logger
        """
        for logger in self.loggers:
            logger.metrics(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """log an artifact from a file on every logger

        all arguments are forwarded by keyword, unchanged
        """
        for logger in self.loggers:
            logger.artifact(path=path, type=type, aliases=aliases, metadata=metadata)

    @property
    def url(self) -> list[str]:
        """URLs of the current logging run, collected from every logger

        a logger whose `url` is itself a list contributes each of its
        entries (one level of flattening)
        """
        return maybe_flatten([logger.url for logger in self.loggers])

    @property
    def run_path(self) -> list[Path]:
        """paths of the current logging run, collected from every logger

        a logger whose `run_path` is itself a list contributes each of its
        entries (one level of flattening)
        """
        return maybe_flatten([logger.run_path for logger in self.loggers])

    def flush(self) -> None:
        """flush every wrapped logger, in order"""
        for logger in self.loggers:
            logger.flush()

    def finish(self) -> None:
        """finish logging on every wrapped logger

        `finish()` is attempted on *every* logger even if an earlier one
        raises, so a failure in one backend does not leave the others
        unfinished. Once all loggers have been tried, the first exception
        encountered (if any) is re-raised.

        # Raises:
         - whatever the first failing logger's `finish()` raised
        """
        first_error: Exception | None = None
        for logger in self.loggers:
            try:
                logger.finish()
            except Exception as err:  # deliberately deferred; re-raised below
                # Only `Exception` is caught: KeyboardInterrupt / SystemExit
                # still propagate immediately.
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error