trnbl.loggers.wandb
import datetime
import json
import logging
from typing import Any
from pathlib import Path
import sys

import wandb
from wandb.sdk.wandb_run import Run, Artifact

from trnbl.loggers.base import TrainingLoggerBase


class WandbLogger(TrainingLoggerBase):
    """Wrapper around wandb logging for `TrainingLoggerBase`. Create using `WandbLogger.create(config, project, job_type)`."""

    def __init__(self, run: Run):
        """Wrap an already-initialized wandb run.

        Prefer `WandbLogger.create`, which also configures stdout logging and calls `wandb.init`.
        """
        self._run: Run = run

    @classmethod
    def create(
        cls,
        config: dict,
        project: str | None = None,
        job_type: str | None = None,
        logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
        logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
        wandb_kwargs: dict | None = None,
    ) -> "WandbLogger":
        """Configure stdout logging, start a wandb run, and wrap it in a logger.

        # Parameters:
         - `config` : passed to `wandb.init(config=...)` and also logged as a message
         - `project` : wandb project name, passed to `wandb.init`
         - `job_type` : wandb job type, passed to `wandb.init`
         - `logging_fmt` : format string for `logging.basicConfig`
         - `logging_datefmt` : date format string for `logging.basicConfig`
         - `wandb_kwargs` : extra keyword arguments forwarded to `wandb.init`

        # Returns:
         - an instance of `cls` wrapping the new run

        # Raises:
         - `AssertionError` : if `wandb.init` returns `None`
        """
        # `logging.basicConfig` does nothing if the root logger already has handlers
        logging.basicConfig(
            stream=sys.stdout,
            level=logging.INFO,
            format=logging_fmt,
            datefmt=logging_datefmt,
        )

        run: Run  # type: ignore[return-value]
        run = wandb.init(
            config=config,
            project=project,
            job_type=job_type,
            **(wandb_kwargs or {}),
        )

        assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

        # build via `cls` so that subclasses calling `create` get an instance of the subclass
        logger: WandbLogger = cls(run)
        logger.message(f"{config =}")
        return logger

    def debug(self, message: str, **kwargs) -> None:
        """Log `message` at DEBUG level on the root `logging` logger, appending any `kwargs`."""
        if kwargs:
            message += f" {kwargs =}"
        logging.debug(message)

    def message(self, message: str, **kwargs) -> None:
        """Log `message` at INFO level on the root `logging` logger, appending any `kwargs`."""
        if kwargs:
            message += f" {kwargs =}"
        logging.info(message)

    def metrics(self, data: dict[str, Any]) -> None:
        """Log a dict of metrics to the wandb run via `Run.log`."""
        self._run.log(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Log the file at `path` to the wandb run as an artifact.

        The artifact is named after the run id (`self._run.id`). Its description is a
        JSON string recording the timestamp, path, type, aliases, and metadata.

        # Parameters:
         - `path` : file to upload; anything `Path()` accepts
         - `type` : wandb artifact type
         - `aliases` : optional aliases, passed to `Run.log_artifact`
         - `metadata` : optional JSON-serializable dict stored inside the description
        """
        # accept `str` paths as well; `as_posix` below requires a `Path`
        path = Path(path)
        artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
        artifact.add_file(str(path))
        # always record provenance; a missing or empty `metadata` is stored as `{}`
        artifact.description = json.dumps(
            dict(
                timestamp=datetime.datetime.now().isoformat(),
                path=path.as_posix(),
                type=type,
                aliases=aliases,
                metadata=metadata if metadata else {},
            )
        )
        self._run.log_artifact(artifact, aliases=aliases)

    @property
    def url(self) -> str:
        """URL of the wandb run, as returned by `Run.get_url()` and converted with `str`."""
        # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
        return str(self._run.get_url())

    @property
    def run_path(self) -> Path:
        """Path of the wandb run, built from `Run._get_path()`."""
        # NOTE(review): `_get_path` is a private wandb API and may change between wandb versions
        return Path(self._run._get_path())

    def flush(self) -> None:
        """Call `Run.save()` on the wandb run."""
        # NOTE(review): wandb documents calling `save()` without a glob as deprecated;
        # confirm this still does anything useful for the installed wandb version
        self._run.save()

    def finish(self) -> None:
        """Finish logging"""
        self._run.finish()
class WandbLogger(TrainingLoggerBase):
    """Wrapper around wandb logging for `TrainingLoggerBase`. Create using `WandbLogger.create(config, project, job_type)`."""

    def __init__(self, run: Run):
        """Wrap an already-initialized wandb run.

        Prefer `WandbLogger.create`, which also configures stdout logging and calls `wandb.init`.
        """
        self._run: Run = run

    @classmethod
    def create(
        cls,
        config: dict,
        project: str | None = None,
        job_type: str | None = None,
        logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
        logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
        wandb_kwargs: dict | None = None,
    ) -> "WandbLogger":
        """Configure stdout logging, start a wandb run, and wrap it in a logger.

        # Parameters:
         - `config` : passed to `wandb.init(config=...)` and also logged as a message
         - `project` : wandb project name, passed to `wandb.init`
         - `job_type` : wandb job type, passed to `wandb.init`
         - `logging_fmt` : format string for `logging.basicConfig`
         - `logging_datefmt` : date format string for `logging.basicConfig`
         - `wandb_kwargs` : extra keyword arguments forwarded to `wandb.init`

        # Returns:
         - an instance of `cls` wrapping the new run

        # Raises:
         - `AssertionError` : if `wandb.init` returns `None`
        """
        # `logging.basicConfig` does nothing if the root logger already has handlers
        logging.basicConfig(
            stream=sys.stdout,
            level=logging.INFO,
            format=logging_fmt,
            datefmt=logging_datefmt,
        )

        run: Run  # type: ignore[return-value]
        run = wandb.init(
            config=config,
            project=project,
            job_type=job_type,
            **(wandb_kwargs or {}),
        )

        assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

        # build via `cls` so that subclasses calling `create` get an instance of the subclass
        logger: WandbLogger = cls(run)
        logger.message(f"{config =}")
        return logger

    def debug(self, message: str, **kwargs) -> None:
        """Log `message` at DEBUG level on the root `logging` logger, appending any `kwargs`."""
        if kwargs:
            message += f" {kwargs =}"
        logging.debug(message)

    def message(self, message: str, **kwargs) -> None:
        """Log `message` at INFO level on the root `logging` logger, appending any `kwargs`."""
        if kwargs:
            message += f" {kwargs =}"
        logging.info(message)

    def metrics(self, data: dict[str, Any]) -> None:
        """Log a dict of metrics to the wandb run via `Run.log`."""
        self._run.log(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Log the file at `path` to the wandb run as an artifact.

        The artifact is named after the run id (`self._run.id`). Its description is a
        JSON string recording the timestamp, path, type, aliases, and metadata.

        # Parameters:
         - `path` : file to upload; anything `Path()` accepts
         - `type` : wandb artifact type
         - `aliases` : optional aliases, passed to `Run.log_artifact`
         - `metadata` : optional JSON-serializable dict stored inside the description
        """
        # accept `str` paths as well; `as_posix` below requires a `Path`
        path = Path(path)
        artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
        artifact.add_file(str(path))
        # always record provenance; a missing or empty `metadata` is stored as `{}`
        artifact.description = json.dumps(
            dict(
                timestamp=datetime.datetime.now().isoformat(),
                path=path.as_posix(),
                type=type,
                aliases=aliases,
                metadata=metadata if metadata else {},
            )
        )
        self._run.log_artifact(artifact, aliases=aliases)

    @property
    def url(self) -> str:
        """URL of the wandb run, as returned by `Run.get_url()` and converted with `str`."""
        # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
        return str(self._run.get_url())

    @property
    def run_path(self) -> Path:
        """Path of the wandb run, built from `Run._get_path()`."""
        # NOTE(review): `_get_path` is a private wandb API and may change between wandb versions
        return Path(self._run._get_path())

    def flush(self) -> None:
        """Call `Run.save()` on the wandb run."""
        # NOTE(review): wandb documents calling `save()` without a glob as deprecated;
        # confirm this still does anything useful for the installed wandb version
        self._run.save()

    def finish(self) -> None:
        """Finish logging"""
        self._run.finish()
Wrapper around wandb logging for `TrainingLoggerBase`.
Create using `WandbLogger.create(config, project, job_type)`.
@classmethod
def
create( cls, config: dict, project: str | None = None, job_type: str | None = None, logging_fmt: str = '%(asctime)s %(levelname)s %(message)s', logging_datefmt: str = '%Y-%m-%d %H:%M:%S', wandb_kwargs: dict | None = None) -> WandbLogger:
@classmethod
def create(
    cls,
    config: dict,
    project: str | None = None,
    job_type: str | None = None,
    logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
    logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
    wandb_kwargs: dict | None = None,
) -> "WandbLogger":
    """Configure stdout logging, start a wandb run, and wrap it in a logger.

    # Parameters:
     - `config` : passed to `wandb.init(config=...)` and also logged as a message
     - `project` : wandb project name, passed to `wandb.init`
     - `job_type` : wandb job type, passed to `wandb.init`
     - `logging_fmt` : format string for `logging.basicConfig`
     - `logging_datefmt` : date format string for `logging.basicConfig`
     - `wandb_kwargs` : extra keyword arguments forwarded to `wandb.init`

    # Returns:
     - an instance of `cls` wrapping the new run

    # Raises:
     - `AssertionError` : if `wandb.init` returns `None`
    """
    # `logging.basicConfig` does nothing if the root logger already has handlers
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=logging_fmt,
        datefmt=logging_datefmt,
    )

    run: Run  # type: ignore[return-value]
    run = wandb.init(
        config=config,
        project=project,
        job_type=job_type,
        **(wandb_kwargs or {}),
    )

    assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

    # build via `cls` so that subclasses calling `create` get an instance of the subclass
    logger: WandbLogger = cls(run)
    logger.message(f"{config =}")
    return logger
def
debug(self, message: str, **kwargs) -> None:
def debug(self, message: str, **kwargs) -> None:
    """Send `message` to the root `logging` logger at DEBUG level.

    When keyword arguments are supplied, they are appended to the text as ` kwargs ={...}`.
    """
    if not kwargs:
        logging.debug(message)
        return
    logging.debug(message + f" {kwargs =}")
log a debug message which will be saved, but not printed
def
message(self, message: str, **kwargs) -> None:
def message(self, message: str, **kwargs) -> None:
    """Emit `message` at INFO level through the root `logging` logger.

    Non-empty keyword arguments are rendered onto the end of the text as ` kwargs ={...}`.
    """
    text = message + f" {kwargs =}" if kwargs else message
    logging.info(text)
log a progress message, which will be printed to stdout
def
artifact( self, path: pathlib.Path, type: str, aliases: list[str] | None = None, metadata: dict | None = None) -> None:
def artifact(
    self,
    path: Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None,
) -> None:
    """Log the file at `path` to the wandb run as an artifact.

    The artifact is named after the run id (`self._run.id`). Its description is a
    JSON string recording the timestamp, path, type, aliases, and metadata.

    # Parameters:
     - `path` : file to upload; anything `Path()` accepts
     - `type` : wandb artifact type
     - `aliases` : optional aliases, passed to `Run.log_artifact`
     - `metadata` : optional JSON-serializable dict stored inside the description
    """
    # accept `str` paths as well; `as_posix` below requires a `Path`
    path = Path(path)
    artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
    artifact.add_file(str(path))
    # always record provenance; a missing or empty `metadata` is stored as `{}`
    artifact.description = json.dumps(
        dict(
            timestamp=datetime.datetime.now().isoformat(),
            path=path.as_posix(),
            type=type,
            aliases=aliases,
            metadata=metadata if metadata else {},
        )
    )
    self._run.log_artifact(artifact, aliases=aliases)
log an artifact from a file
url: str
@property
def url(self) -> str:
    """The run's URL: whatever `self._run.get_url()` returns, coerced to `str`."""
    # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
    link = self._run.get_url()
    return str(link)
Get the URL for the current logging run