Coverage for trnbl\loggers\tensorboard.py: 98%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

1import datetime 

2from typing import Any 

3from pathlib import Path 

4import json 

5import hashlib 

6 

7from torch.utils.tensorboard import SummaryWriter 

8 

9from trnbl.loggers.base import TrainingLoggerBase, rand_syllabic_string 

10 

11 

class TensorBoardLogger(TrainingLoggerBase):
    """Training logger that records text, scalars and artifact notes in TensorBoard.

    Everything is written through a single `SummaryWriter` rooted at
    `<log_dir>/<name>`. Text and scalar entries are tagged with an internal
    step counter that advances by one on every `metrics` call.
    """

    def __init__(
        self,
        log_dir: str | Path,
        train_config: dict | None = None,
        name: str | None = None,
        **kwargs,
    ) -> None:
        """Create the run directory writer and log the training configuration.

        Args:
            log_dir: Parent directory; the run is stored in `log_dir / name`.
            train_config: Configuration serialized with `json.dumps`, logged
                under the "config" tag and written to `config.json` in the run
                directory. `None` serializes to the JSON literal `null`.
            name: Run name. When `None`, one is generated as
                `h<5 hex chars of the config's md5>-<yymmdd_HHMM>-<random syllables>`.
            **kwargs: Forwarded unchanged to `SummaryWriter`.
        """
        train_config_json: str = json.dumps(train_config, indent="\t")

        if name is None:
            # md5 only fingerprints the config for a readable run name; marking
            # it as non-security use keeps it available on FIPS-restricted
            # Python builds, where a plain hashlib.md5() call raises.
            _run_hash: str = hashlib.md5(
                train_config_json.encode(),
                usedforsecurity=False,
            ).hexdigest()
            timestamp: str = datetime.datetime.now().strftime("%y%m%d_%H%M")
            name = f"h{_run_hash[:5]}-{timestamp}-{rand_syllabic_string()}"

        # The writer receives a POSIX-style string; the same location is kept
        # as a Path for `run_path` and for writing config.json below.
        run_dir: str = (Path(log_dir) / name).as_posix()

        # Initialize the TensorBoard SummaryWriter with the run directory
        self._writer: SummaryWriter = SummaryWriter(log_dir=run_dir, **kwargs)

        # Store the run path
        self._run_path: Path = Path(run_dir)

        # Initialize the global step counter
        self._global_step: int = 0

        # Log the training configuration, both into TensorBoard and as a file.
        # NOTE(review): writing config.json relies on SummaryWriter having
        # created the run directory -- confirm if the writer backend changes.
        self._self_writer_add_text("config", train_config_json)
        self._self_writer_add_text("name", name)
        with open(self._run_path / "config.json", "w", encoding="utf-8") as f:
            f.write(train_config_json)

    def _self_writer_add_text(self, tag: str, message: str, **kwargs) -> None:
        """Write `message` as text under `tag` at the current global step.

        Any extra keyword arguments are appended to the message on a new line
        as an indented JSON object.
        """
        suffix: str = "\n" + json.dumps(kwargs, indent=4) if kwargs else ""
        self._writer.add_text(
            tag,
            message + suffix,
            global_step=self._global_step,
        )

    def debug(self, message: str, **kwargs) -> None:
        """Log `message` (plus optional JSON-encoded kwargs) under the "debug" tag."""
        self._self_writer_add_text("debug", message, **kwargs)

    def message(self, message: str, **kwargs) -> None:
        """Log `message` under the "message" tag and also print it to stdout."""
        self._self_writer_add_text("message", message, **kwargs)

        # Also print the message
        print(message)

    def metrics(self, data: dict[str, Any]) -> None:
        """Log every entry of `data` as a scalar, then advance the global step.

        The step counter is incremented once per call, regardless of how many
        entries `data` contains.
        """
        for key, value in data.items():
            self._writer.add_scalar(key, value, global_step=self._global_step)

        # Increment the global step counter
        self._global_step += 1

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Record a JSON description of an artifact under the "artifact" tag.

        Only a text entry is written (timestamp, path, type, aliases and
        metadata); this method does not copy or upload the file itself.

        Args:
            path: Location of the artifact; a plain string is accepted too.
            type: Free-form artifact type label.
            aliases: Optional alias names, stored as given.
            metadata: Optional extra data; a missing or empty value is stored as `{}`.
        """
        record: dict[str, Any] = dict(
            timestamp=datetime.datetime.now().isoformat(),
            # Path(...) is a no-op for Path inputs and lets str paths through
            # instead of failing on the missing `.as_posix` attribute.
            path=Path(path).as_posix(),
            type=type,
            aliases=aliases,
            metadata=metadata if metadata else {},
        )
        self._writer.add_text(
            tag="artifact",
            text_string=json.dumps(record),
            global_step=self._global_step,
        )

    @property
    def url(self) -> str:
        """Shell command that launches TensorBoard on this run's directory."""
        return f"tensorboard --logdir={self._run_path}"

    @property
    def run_path(self) -> Path:
        """Directory this run's event files and config.json are written to."""
        return self._run_path

    def flush(self) -> None:
        """Flush pending events of the underlying writer."""
        self._writer.flush()

    def finish(self) -> None:
        """Flush and close the TensorBoard SummaryWriter."""
        try:
            self._writer.flush()
        finally:
            # Close even when the flush fails so the writer is not leaked.
            self._writer.close()