Coverage for trnbl\loggers\wandb.py: 0%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-30 04:47 -0700

1import datetime 

2import json 

3import logging 

4from typing import Any 

5from pathlib import Path 

6import sys 

7 

8import wandb 

9from wandb.sdk.wandb_run import Run, Artifact 

10 

11from trnbl.loggers.base import TrainingLoggerBase 

12 

13 

class WandbLogger(TrainingLoggerBase):
    """Wrapper around wandb logging for `TrainingLoggerBase`.

    Create using `WandbLogger.create(config, project, job_type)`.
    Text messages go through the standard `logging` module; metrics and
    artifacts are sent to the wrapped wandb `Run`.
    """

    def __init__(self, run: Run):
        """Wrap an already-initialized wandb run.

        # Parameters:
         - `run : Run`
           the wandb run that metrics and artifacts are logged to
        """
        self._run: Run = run

    @classmethod
    def create(
        cls,
        config: dict,
        project: str | None = None,
        job_type: str | None = None,
        logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
        logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
        wandb_kwargs: dict | None = None,
    ) -> "WandbLogger":
        """Configure stdout logging, start a wandb run, and wrap it in a logger.

        # Parameters:
         - `config : dict`
           passed to `wandb.init` as `config`, and echoed via `message`
         - `project : str | None`
           passed to `wandb.init` as `project`
         - `job_type : str | None`
           passed to `wandb.init` as `job_type`
         - `logging_fmt : str`
           `format` argument for `logging.basicConfig`
         - `logging_datefmt : str`
           `datefmt` argument for `logging.basicConfig`
         - `wandb_kwargs : dict | None`
           extra keyword arguments forwarded to `wandb.init`

        # Returns:
         - `WandbLogger`
           an instance of `cls` wrapping the new run

        # Raises:
         - `AssertionError` : if `wandb.init` returns `None`
        """
        # route `debug`/`message` output to stdout at INFO level
        logging.basicConfig(
            stream=sys.stdout,
            level=logging.INFO,
            format=logging_fmt,
            datefmt=logging_datefmt,
        )

        run: Run  # type: ignore[return-value]
        run = wandb.init(
            config=config,
            project=project,
            job_type=job_type,
            **(wandb_kwargs or {}),
        )

        assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

        # use `cls` rather than the hard-coded class name so that subclasses
        # calling `.create(...)` get an instance of the subclass back
        logger: WandbLogger = cls(run)
        logger.message(f"{config =}")
        return logger

    def debug(self, message: str, **kwargs) -> None:
        """Log `message` at DEBUG level, appending `kwargs` if any are given."""
        if kwargs:
            message += f" {kwargs =}"
        logging.debug(message)

    def message(self, message: str, **kwargs) -> None:
        """Log `message` at INFO level, appending `kwargs` if any are given."""
        if kwargs:
            message += f" {kwargs =}"
        logging.info(message)

    def metrics(self, data: dict[str, Any]) -> None:
        """Send a dictionary of metrics to the wandb run via `Run.log`."""
        self._run.log(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Log the file at `path` as a wandb artifact named after the run id.

        The artifact description is a JSON object with the keys `timestamp`,
        `path`, `type`, `aliases`, and `metadata`. It is always written;
        `metadata` is stored as `{}` when none is provided.

        # Parameters:
         - `path : Path`
           file to add to the artifact
         - `type : str`
           artifact type passed to `wandb.Artifact`
           (shadows the builtin, but the name is part of the public interface)
         - `aliases : list[str] | None`
           aliases passed to `Run.log_artifact`
         - `metadata : dict | None`
           extra information embedded in the description; must be
           JSON-serializable
        """
        artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
        artifact.add_file(str(path))
        # always record the description: previously this was skipped entirely
        # when `metadata` was empty, which also dropped the timestamp and path
        artifact.description = json.dumps(
            dict(
                timestamp=datetime.datetime.now().isoformat(),
                path=path.as_posix(),
                type=type,
                aliases=aliases,
                metadata=metadata if metadata else {},
            )
        )
        self._run.log_artifact(artifact, aliases=aliases)

    @property
    def url(self) -> str:
        """URL of the wandb run, converted to `str`."""
        # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
        # NOTE: until then, a `None` url is returned as the literal string "None"
        return str(self._run.get_url())

    @property
    def run_path(self) -> Path:
        """Path of the wandb run, as returned by the run's `_get_path()`."""
        return Path(self._run._get_path())

    def flush(self) -> None:
        """Flush by calling `save()` on the underlying wandb run."""
        self._run.save()

    def finish(self) -> None:
        """Finish logging"""
        self._run.finish()

98 """Finish logging""" 

99 self._run.finish()