trnbl.loggers.multi
"""Logger that forwards every logging call to several `TrainingLoggerBase` instances."""

from typing import Any, TypeVar
from pathlib import Path

from trnbl.loggers.base import TrainingLoggerBase

# TODO: move this to muutils
T = TypeVar("T")


def maybe_flatten(lst: list[T | list[T]]) -> list[T]:
    """Flatten one level of list nesting.

    Items of ``lst`` that are ``list`` instances are spliced into the result;
    every other item is kept as-is. Deeper nesting is not removed. A new list
    is returned and ``lst`` is not modified.
    """
    flat_lst: list[T] = []
    for item in lst:
        if isinstance(item, list):
            flat_lst.extend(item)
        else:
            flat_lst.append(item)
    return flat_lst


class MultiLogger(TrainingLoggerBase):
    """Use multiple loggers at once.

    Every logging call is forwarded, in list order, to each logger in
    ``self.loggers``. The ``url`` and ``run_path`` properties gather the
    value from every wrapped logger and flatten one level of list nesting,
    so a wrapped logger may report either a single value or a list.

    ``flush`` and ``finish`` try every wrapped logger even if one of them
    raises, then re-raise the first exception; the other methods stop at
    the first logger that raises.
    """

    def __init__(self, loggers: list[TrainingLoggerBase]) -> None:
        # stored by reference (not copied), so later changes to the caller's
        # list are visible here
        # NOTE(review): TrainingLoggerBase.__init__ is not called; confirm the
        # base class needs no initialisation of its own
        self.loggers: list[TrainingLoggerBase] = loggers

    def debug(self, message: str, **kwargs) -> None:
        """Log a debug message which will be saved, but not printed.

        ``message`` and ``kwargs`` are forwarded unchanged to every logger.
        """
        for logger in self.loggers:
            logger.debug(message, **kwargs)

    def message(self, message: str, **kwargs) -> None:
        """Log a progress message.

        ``message`` and ``kwargs`` are forwarded unchanged to every logger.
        """
        for logger in self.loggers:
            logger.message(message, **kwargs)

    def metrics(self, data: dict[str, Any]) -> None:
        """Log a dictionary of metrics.

        The same ``data`` object (not a copy) is passed to every logger.
        """
        for logger in self.loggers:
            logger.metrics(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Log an artifact from a file.

        All arguments are forwarded by keyword to every logger's ``artifact``.
        """
        for logger in self.loggers:
            logger.artifact(path=path, type=type, aliases=aliases, metadata=metadata)

    @property
    def url(self) -> list[str]:
        """URLs for the current logging run, gathered from every wrapped logger."""
        return maybe_flatten([logger.url for logger in self.loggers])

    @property
    def run_path(self) -> list[Path]:
        """Paths to the current logging run, gathered from every wrapped logger."""
        return maybe_flatten([logger.run_path for logger in self.loggers])

    def _call_on_all(self, method_name: str) -> None:
        """Call the zero-argument method ``method_name`` on every wrapped logger.

        Every logger is attempted even if an earlier one raises, so one failing
        backend cannot leave the others un-flushed or unfinished. The first
        ``Exception`` is re-raised once all loggers have been tried; later
        exceptions are not reported.
        """
        first_error: Exception | None = None
        for logger in self.loggers:
            try:
                getattr(logger, method_name)()
            except Exception as err:
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error

    def flush(self) -> None:
        """Flush every wrapped logger.

        All loggers are flushed even if one raises; the first exception is
        re-raised afterwards.
        """
        self._call_on_all("flush")

    def finish(self) -> None:
        """Finish logging on every wrapped logger.

        All loggers are finished even if one raises; the first exception is
        re-raised afterwards.
        """
        self._call_on_all("finish")
def
maybe_flatten(lst: list[typing.Union[~T, list[~T]]]) -> list[~T]:
def maybe_flatten(lst: list[T | list[T]]) -> list[T]:
    """Remove exactly one level of list nesting.

    Every element of ``lst`` that is a ``list`` contributes its own elements
    to the result; every other element is copied over unchanged. Lists nested
    two or more levels deep remain lists in the output. The input is left
    untouched and a fresh list is returned.
    """
    return [
        leaf
        for element in lst
        for leaf in (element if isinstance(element, list) else [element])
    ]
Flatten a list by one level: items that are lists are spliced in, other items are kept as-is.
class MultiLogger(TrainingLoggerBase):
    """Use multiple loggers at once.

    Every logging call is forwarded, in list order, to each logger in
    ``self.loggers``. The ``url`` and ``run_path`` properties gather the
    value from every wrapped logger and flatten one level of list nesting,
    so a wrapped logger may report either a single value or a list.

    ``flush`` and ``finish`` try every wrapped logger even if one of them
    raises, then re-raise the first exception; the other methods stop at
    the first logger that raises.
    """

    def __init__(self, loggers: list[TrainingLoggerBase]) -> None:
        # stored by reference (not copied), so later changes to the caller's
        # list are visible here
        # NOTE(review): TrainingLoggerBase.__init__ is not called; confirm the
        # base class needs no initialisation of its own
        self.loggers: list[TrainingLoggerBase] = loggers

    def debug(self, message: str, **kwargs) -> None:
        """Log a debug message which will be saved, but not printed.

        ``message`` and ``kwargs`` are forwarded unchanged to every logger.
        """
        for logger in self.loggers:
            logger.debug(message, **kwargs)

    def message(self, message: str, **kwargs) -> None:
        """Log a progress message.

        ``message`` and ``kwargs`` are forwarded unchanged to every logger.
        """
        for logger in self.loggers:
            logger.message(message, **kwargs)

    def metrics(self, data: dict[str, Any]) -> None:
        """Log a dictionary of metrics.

        The same ``data`` object (not a copy) is passed to every logger.
        """
        for logger in self.loggers:
            logger.metrics(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Log an artifact from a file.

        All arguments are forwarded by keyword to every logger's ``artifact``.
        """
        for logger in self.loggers:
            logger.artifact(path=path, type=type, aliases=aliases, metadata=metadata)

    @property
    def url(self) -> list[str]:
        """URLs for the current logging run, gathered from every wrapped logger."""
        return maybe_flatten([logger.url for logger in self.loggers])

    @property
    def run_path(self) -> list[Path]:
        """Paths to the current logging run, gathered from every wrapped logger."""
        return maybe_flatten([logger.run_path for logger in self.loggers])

    def _call_on_all(self, method_name: str) -> None:
        """Call the zero-argument method ``method_name`` on every wrapped logger.

        Every logger is attempted even if an earlier one raises, so one failing
        backend cannot leave the others un-flushed or unfinished. The first
        ``Exception`` is re-raised once all loggers have been tried; later
        exceptions are not reported.
        """
        first_error: Exception | None = None
        for logger in self.loggers:
            try:
                getattr(logger, method_name)()
            except Exception as err:
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error

    def flush(self) -> None:
        """Flush every wrapped logger.

        All loggers are flushed even if one raises; the first exception is
        re-raised afterwards.
        """
        self._call_on_all("flush")

    def finish(self) -> None:
        """Finish logging on every wrapped logger.

        All loggers are finished even if one raises; the first exception is
        re-raised afterwards.
        """
        self._call_on_all("finish")
Use multiple loggers at once: every logging call is forwarded to each wrapped logger.
MultiLogger(loggers: list[trnbl.TrainingLoggerBase])
loggers: list[trnbl.TrainingLoggerBase]
def
debug(self, message: str, **kwargs) -> None:
def debug(self, message: str, **kwargs) -> None:
    """Send a debug message (saved, but not printed) to every wrapped logger.

    ``message`` and any extra keyword arguments reach each logger unchanged,
    one logger at a time in list order.
    """
    targets = self.loggers
    for target in targets:
        target.debug(message, **kwargs)
log a debug message which will be saved, but not printed
def
message(self, message: str, **kwargs) -> None:
def message(self, message: str, **kwargs) -> None:
    """Send a progress message to every wrapped logger.

    ``message`` and any extra keyword arguments are passed through as-is,
    visiting the loggers in list order.
    """
    targets = self.loggers
    for target in targets:
        target.message(message, **kwargs)
log a progress message
def
metrics(self, data: dict[str, typing.Any]) -> None:
def metrics(self, data: dict[str, Any]) -> None:
    """Record a dictionary of metrics with every wrapped logger.

    The very same ``data`` object (no copy is made) is handed to each logger
    in list order.
    """
    for sink in self.loggers:
        sink.metrics(data)
Log a dictionary of metrics
def
artifact( self, path: pathlib.Path, type: str, aliases: list[str] | None = None, metadata: dict | None = None) -> None:
43 def artifact( 44 self, 45 path: Path, 46 type: str, 47 aliases: list[str] | None = None, 48 metadata: dict | None = None, 49 ) -> None: 50 """log an artifact from a file""" 51 for logger in self.loggers: 52 logger.artifact(path=path, type=type, aliases=aliases, metadata=metadata)
log an artifact from a file
url: list[str]
@property
def url(self) -> list[str]:
    """URLs of the current logging run, collected from every wrapped logger.

    Each logger's ``url`` may be a single value or a list; ``maybe_flatten``
    splices list values into the result so one flat list is returned.
    """
    per_logger = [sub.url for sub in self.loggers]
    return maybe_flatten(per_logger)
Get the URLs for the current logging run from all wrapped loggers, flattened into a single list
run_path: list[pathlib.Path]
@property
def run_path(self) -> list[Path]:
    """Paths of the current logging run, collected from every wrapped logger.

    Each logger's ``run_path`` may be a single value or a list;
    ``maybe_flatten`` splices list values into the result so one flat list
    is returned.
    """
    per_logger = [sub.run_path for sub in self.loggers]
    return maybe_flatten(per_logger)
Get the paths to the current logging run from all wrapped loggers, flattened into a single list
def
flush(self) -> None:
def flush(self) -> None:
    """Flush every wrapped logger.

    Every logger is flushed even if an earlier one raises, so a single
    failing backend cannot leave the others un-flushed. The first
    ``Exception`` raised is re-raised after all loggers have been tried;
    later exceptions are not reported.
    """
    first_error: Exception | None = None
    for logger in self.loggers:
        try:
            logger.flush()
        except Exception as err:
            # remember only the first failure and keep going
            if first_error is None:
                first_error = err
    if first_error is not None:
        raise first_error
Flush every wrapped logger
def
finish(self) -> None:
def finish(self) -> None:
    """Finish logging on every wrapped logger.

    Every logger is finished even if an earlier one raises, so a single
    failing backend cannot leave the others unfinished. The first
    ``Exception`` raised is re-raised after all loggers have been tried;
    later exceptions are not reported.
    """
    first_error: Exception | None = None
    for logger in self.loggers:
        try:
            logger.finish()
        except Exception as err:
            # remember only the first failure and keep going
            if first_error is None:
                first_error = err
    if first_error is not None:
        raise first_error
Finish logging on every wrapped logger