Coverage for trnbl\loggers\wandb.py: 0%
46 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-30 04:47 -0700
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-30 04:47 -0700
1import datetime
2import json
3import logging
4from typing import Any
5from pathlib import Path
6import sys
8import wandb
9from wandb.sdk.wandb_run import Run, Artifact
11from trnbl.loggers.base import TrainingLoggerBase
class WandbLogger(TrainingLoggerBase):
    """`TrainingLoggerBase` implementation that forwards to a wandb `Run`.

    Text messages go through the standard `logging` module; metrics and
    artifacts go to the wrapped run. Prefer building instances with
    `WandbLogger.create(config, project, job_type)`.
    """

    def __init__(self, run: Run):
        """Wrap an already-initialized wandb run."""
        self._run: Run = run

    @classmethod
    def create(
        cls,
        config: dict,
        project: str | None = None,
        job_type: str | None = None,
        logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
        logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
        wandb_kwargs: dict | None = None,
    ) -> "WandbLogger":
        """Configure stdout logging, start a wandb run, and wrap it.

        Args:
            config: passed to `wandb.init` as `config`, and echoed once via
                `message` after the run is created.
            project: passed to `wandb.init` as `project`.
            job_type: passed to `wandb.init` as `job_type`.
            logging_fmt: `format` argument for `logging.basicConfig`.
            logging_datefmt: `datefmt` argument for `logging.basicConfig`.
            wandb_kwargs: extra keyword arguments splatted into `wandb.init`;
                `None` or an empty dict adds nothing.

        Returns:
            A `WandbLogger` wrapping the run returned by `wandb.init`.

        Raises:
            AssertionError: if `wandb.init` returns `None`.
        """
        # root logging goes to stdout at INFO level
        logging.basicConfig(
            stream=sys.stdout,
            level=logging.INFO,
            format=logging_fmt,
            datefmt=logging_datefmt,
        )

        extra_init_kwargs: dict = wandb_kwargs or {}
        run: Run  # type: ignore[return-value]
        run = wandb.init(
            config=config,
            project=project,
            job_type=job_type,
            **extra_init_kwargs,
        )

        assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

        wandb_logger: WandbLogger = WandbLogger(run)
        # record the config in the text log as well
        wandb_logger.message(f"{config =}")
        return wandb_logger

    def debug(self, message: str, **kwargs) -> None:
        """Log `message` at DEBUG level, appending `kwargs` when any are given."""
        text: str = (message + f" {kwargs =}") if kwargs else message
        logging.debug(text)

    def message(self, message: str, **kwargs) -> None:
        """Log `message` at INFO level, appending `kwargs` when any are given."""
        text: str = (message + f" {kwargs =}") if kwargs else message
        logging.info(text)

    def metrics(self, data: dict[str, Any]) -> None:
        """Send a dict of metrics to the wrapped run via `Run.log`."""
        self._run.log(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Log the file at `path` as a wandb artifact named after the run id.

        Args:
            path: file to add to the artifact.
            type: artifact type, passed to `wandb.Artifact`.
            aliases: forwarded to `Run.log_artifact`.
            metadata: when truthy, a JSON description (timestamp, path, type,
                aliases, metadata) is attached to the artifact; it must be
                serializable by `json.dumps`.
        """
        wandb_artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
        wandb_artifact.add_file(str(path))

        # NOTE(review): the description is only written when `metadata` is
        # non-empty -- confirm that skipping it otherwise is intended.
        if metadata:
            description_payload: dict[str, Any] = {
                "timestamp": datetime.datetime.now().isoformat(),
                "path": path.as_posix(),
                "type": type,
                "aliases": aliases,
                "metadata": metadata,
            }
            wandb_artifact.description = json.dumps(description_payload)

        self._run.log_artifact(wandb_artifact, aliases=aliases)

    @property
    def url(self) -> str:
        """String form of the run's URL as reported by `Run.get_url`."""
        # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
        run_url = self._run.get_url()
        return str(run_url)

    @property
    def run_path(self) -> Path:
        """The run's path, from the private `Run._get_path`, as a `Path`."""
        raw_path = self._run._get_path()
        return Path(raw_path)

    def flush(self) -> None:
        """Flush by calling `save()` on the wrapped run."""
        self._run.save()

    def finish(self) -> None:
        """Finish logging by calling `finish()` on the wrapped run."""
        self._run.finish()