Coverage for trnbl\loggers\tensorboard.py: 98%
45 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
1import datetime
2from typing import Any
3from pathlib import Path
4import json
5import hashlib
7from torch.utils.tensorboard import SummaryWriter
9from trnbl.loggers.base import TrainingLoggerBase, rand_syllabic_string
class TensorBoardLogger(TrainingLoggerBase):
    """Training logger backed by a TensorBoard ``SummaryWriter``.

    Each run is stored in its own subdirectory ``log_dir / name``. Text
    (config, run name, debug/info messages, artifact records) is logged with
    ``add_text``; metrics are logged with ``add_scalar``. The serialized
    training config is additionally written to ``config.json`` in the run
    directory.

    A single global step counter is shared by all logging calls; it is only
    advanced by :meth:`metrics`.
    """

    def __init__(
        self,
        log_dir: str | Path,
        train_config: dict | None = None,
        name: str | None = None,
        **kwargs,
    ) -> None:
        """Create the run directory and writer, and log the training config.

        Args:
            log_dir: Parent directory; the run is stored in ``log_dir / name``.
            train_config: Training configuration, serialized with
                ``json.dumps`` (tab-indented). ``None`` serializes to ``"null"``.
            name: Run name. If ``None``, one is generated from a short hash of
                the serialized config, the current local time
                (``%y%m%d_%H%M``), and ``rand_syllabic_string()``.
            **kwargs: Forwarded unchanged to ``SummaryWriter``.
        """
        train_config_json: str = json.dumps(train_config, indent="\t")

        if name is None:
            # md5 only derives a short, config-dependent run id -- it is not a
            # security primitive. usedforsecurity=False keeps this working on
            # FIPS-restricted Python builds where plain md5 is disabled.
            run_hash: str = hashlib.md5(
                train_config_json.encode(), usedforsecurity=False
            ).hexdigest()
            timestamp: str = datetime.datetime.now().strftime("%y%m%d_%H%M")
            name = f"h{run_hash[:5]}-{timestamp}-{rand_syllabic_string()}"

        log_dir = (Path(log_dir) / name).as_posix()

        # TensorBoard writer for this run's directory.
        self._writer: SummaryWriter = SummaryWriter(log_dir=log_dir, **kwargs)
        self._run_path: Path = Path(log_dir)
        # Step attached to every logged item; advanced only by `metrics()`.
        self._global_step: int = 0

        self._self_writer_add_text("config", train_config_json)
        self._self_writer_add_text("name", name)

        # Ensure the run directory exists before writing config.json; this is
        # a no-op if the writer has already created it.
        self._run_path.mkdir(parents=True, exist_ok=True)
        (self._run_path / "config.json").write_text(
            train_config_json, encoding="utf-8"
        )

    def _self_writer_add_text(self, tag: str, message: str, **kwargs) -> None:
        """Log ``message`` under ``tag`` at the current global step.

        If extra keyword arguments are given, they are appended on a new line
        as a JSON object indented by 4 spaces (so they must be
        JSON-serializable).
        """
        text: str = message
        if kwargs:
            text += "\n" + json.dumps(kwargs, indent=4)
        self._writer.add_text(tag, text, global_step=self._global_step)

    def debug(self, message: str, **kwargs) -> None:
        """Log a debug message (plus optional kwargs) under the ``"debug"`` tag."""
        self._self_writer_add_text("debug", message, **kwargs)

    def message(self, message: str, **kwargs) -> None:
        """Log a message under the ``"message"`` tag and also print it to stdout."""
        self._self_writer_add_text("message", message, **kwargs)
        print(message)

    def metrics(self, data: dict[str, Any]) -> None:
        """Log each ``key: value`` pair as a scalar, then advance the global step.

        All metrics in one call share the same step; the step is incremented
        once afterwards.
        """
        for key, value in data.items():
            self._writer.add_scalar(key, value, global_step=self._global_step)
        self._global_step += 1

    def artifact(
        self,
        path: Path | str,
        type: str,  # shadows the builtin, but the name is part of the public API
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Record an artifact as a JSON text entry under the ``"artifact"`` tag.

        The file itself is not copied or uploaded; only a JSON record with a
        local ISO timestamp, the POSIX-style path, the artifact type, its
        aliases, and its metadata (``{}`` if falsy) is logged via ``add_text``.

        Args:
            path: Artifact location; a ``str`` is accepted and converted with
                ``Path`` before serialization.
            type: Free-form artifact type label.
            aliases: Optional alias names, stored as-is (``None`` -> ``null``).
            metadata: Optional JSON-serializable metadata dict.
        """
        record: dict[str, Any] = dict(
            timestamp=datetime.datetime.now().isoformat(),
            path=Path(path).as_posix(),
            type=type,
            aliases=aliases,
            metadata=metadata if metadata else {},
        )
        self._writer.add_text(
            tag="artifact",
            text_string=json.dumps(record),
            global_step=self._global_step,
        )

    @property
    def url(self) -> str:
        """Shell command that launches TensorBoard on this run's directory."""
        return f"tensorboard --logdir={self._run_path}"

    @property
    def run_path(self) -> Path:
        """Directory this run's logs and ``config.json`` are written to."""
        return self._run_path

    def flush(self) -> None:
        """Flush pending events in the underlying ``SummaryWriter``."""
        self._writer.flush()

    def finish(self) -> None:
        """Flush pending events and close the underlying ``SummaryWriter``."""
        self._writer.flush()
        self._writer.close()