Coverage for trnbl\loggers\multi.py: 0%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any, TypeVar

from trnbl.loggers.base import TrainingLoggerBase

5 

6# TODO: move this to muutils 

7T = TypeVar("T") 

8 

9 

10def maybe_flatten(lst: list[T | list[T]]) -> list[T]: 

11 """flatten a list if it is nested""" 

12 flat_lst: list[T] = [] 

13 for item in lst: 

14 if isinstance(item, list): 

15 flat_lst.extend(item) 

16 else: 

17 flat_lst.append(item) 

18 return flat_lst 

19 

20 

class MultiLogger(TrainingLoggerBase):
    """use multiple loggers at once

    every logging call is forwarded, in order, to each wrapped logger. if one
    of them raises, the remaining loggers are still called, and the first
    exception is re-raised once all of them have been tried -- so a single
    failing backend cannot stop the others from receiving data, flushing, or
    finishing. exceptions after the first are not re-raised.
    """

    def __init__(self, loggers: Iterable[TrainingLoggerBase]) -> None:
        """wrap `loggers`

        the loggers are copied into a new list, so a one-shot iterator is not
        exhausted after the first forwarded call and later mutation of the
        caller's container does not change which loggers are used
        """
        self.loggers: list[TrainingLoggerBase] = list(loggers)

    def _call_each(self, call: Callable[[TrainingLoggerBase], None]) -> None:
        """apply `call` to every logger, re-raising the first error at the end"""
        first_error: Exception | None = None
        for logger in self.loggers:
            try:
                call(logger)
            except Exception as err:
                # keep going so the remaining loggers still get the call;
                # only the first failure is remembered and re-raised
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error

    def debug(self, message: str, **kwargs) -> None:
        """log a debug message which will be saved, but not printed

        `message` and any extra keyword arguments are forwarded unchanged to
        each wrapped logger's `debug`
        """
        self._call_each(lambda logger: logger.debug(message, **kwargs))

    def message(self, message: str, **kwargs) -> None:
        """log a progress message

        `message` and any extra keyword arguments are forwarded unchanged to
        each wrapped logger's `message`
        """
        self._call_each(lambda logger: logger.message(message, **kwargs))

    def metrics(self, data: dict[str, Any]) -> None:
        """Log a dictionary of metrics

        the same `data` dict object is passed to every wrapped logger
        """
        self._call_each(lambda logger: logger.metrics(data))

    def artifact(
        self,
        path: Path,
        type: str,  # shadows the builtin, but is part of the public keyword API
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """log an artifact from a file

        all arguments are forwarded unchanged, as keyword arguments, to each
        wrapped logger's `artifact`
        """
        self._call_each(
            lambda logger: logger.artifact(
                path=path, type=type, aliases=aliases, metadata=metadata
            )
        )

    @property
    def url(self) -> list[str]:
        """Get the URL for the current logging run

        collects `url` from every wrapped logger, in logger order; a logger
        whose `url` is a list contributes each of its entries
        """
        return maybe_flatten([logger.url for logger in self.loggers])

    @property
    def run_path(self) -> list[Path]:
        """Get the paths to the current logging run

        collects `run_path` from every wrapped logger, in logger order; a
        logger whose `run_path` is a list contributes each of its entries
        """
        return maybe_flatten([logger.run_path for logger in self.loggers])

    def flush(self) -> None:
        """Flush every wrapped logger, even if an earlier one fails"""
        self._call_each(lambda logger: logger.flush())

    def finish(self) -> None:
        """Finish logging on every wrapped logger, even if an earlier one fails"""
        self._call_each(lambda logger: logger.finish())