Coverage for tests\unit\loggers\test_tensorboard_logger.py: 100%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

1import json 

2import math 

3from pathlib import Path 

4from datetime import datetime 

5 

6from tensorboard.backend.event_processing.event_file_loader import EventFileLoader # type: ignore[import-untyped] 

7 

8from trnbl.loggers.tensorboard import TensorBoardLogger 

9 

# Scratch directory shared by every test in this module. It is created eagerly
# at import time (relative to the current working directory) so tests can
# write into it without any further setup.
TEMP_PATH = Path("tests") / "_temp" / "tensorboard"
TEMP_PATH.mkdir(exist_ok=True, parents=True)

13 

14 

def read_event_file(event_file):
    """Lazily yield every event record stored in a TensorBoard event file.

    Args:
        event_file: path object (anything with ``as_posix()``) pointing at an
            ``events.out.tfevents.*`` file.

    Yields:
        The events produced by ``EventFileLoader.Load()``, in file order.
    """
    yield from EventFileLoader(event_file.as_posix()).Load()

19 

20 

def test_tensorboard_logger_initialization():
    """The logger's private run path is placed inside the requested log_dir."""
    log_dir = TEMP_PATH / "init"
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        assert logger._run_path.as_posix().startswith(log_dir.as_posix())
    finally:
        # Always close the logger, even if the assertion fails, so its
        # writer/file handles are released.
        logger.finish()

26 

27 

def test_tensorboard_logger_message():
    """A logged message lands in a ``message/text_summary`` summary value.

    The decoded text must contain both the message and the extra keyword
    data (serialized as ``"key": "value"``).
    """
    log_dir = TEMP_PATH / "message"
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        logger.message("Test message", key="value")
        logger.flush()

        # Search only this logger's run directory: log_dir is reused across
        # test runs and never cleaned, so stale event files from an earlier
        # run could otherwise satisfy the checks below.
        event_files = list(Path(logger.run_path).glob("**/events.out.tfevents.*"))
        print(f"{event_files = }")
        assert len(event_files) > 0  # Ensure at least one event file is created

        # Inspect the content of the event files
        found_content: bool = False
        for event_file in event_files:
            print(f"{event_file = }")
            for event in read_event_file(event_file):
                print(f"\t{event = }")
                if not event.HasField("summary"):
                    continue
                for value in event.summary.value:
                    if value.tag == "message/text_summary":
                        message = value.tensor.string_val[0].decode("utf-8")
                        assert "Test message" in message
                        assert '"key": "value"' in message
                        found_content = True

        assert found_content
    finally:
        # Close the writer even when an assertion fails.
        logger.finish()

59 

60 

def test_tensorboard_logger_metrics():
    """Every logged metric appears as a summary value tagged with its name.

    The logged number is read from ``tensor.float_val[0]`` and compared with a
    relative tolerance, because it is round-tripped through the event file.
    """
    log_dir = TEMP_PATH / "metrics"
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        metrics_data = {"accuracy": 0.95, "loss": 0.05}
        logger.metrics(metrics_data)
        # Flush before locating/reading event files so the scalars are on disk.
        logger.flush()

        # Search only this logger's run directory: log_dir is reused across
        # test runs, so stale event files could otherwise satisfy the checks.
        event_files = list(Path(logger.run_path).glob("**/events.out.tfevents.*"))
        assert len(event_files) > 0  # Ensure at least one event file is created

        # Inspect the content of the event files
        found_tags = set()
        for event_file in event_files:
            print(f"{event_file = }")
            print("====================================")
            for event in read_event_file(event_file):
                print(f"{event = }")
                if not event.HasField("summary"):
                    continue
                for value in event.summary.value:
                    if value.tag in metrics_data:
                        found_tags.add(value.tag)
                        assert math.isclose(
                            value.tensor.float_val[0],
                            metrics_data[value.tag],
                            rel_tol=1e-5,
                        )

        missing = set(metrics_data) - found_tags
        assert not missing, f"Metrics not found in event files: {sorted(missing)}"
    finally:
        # Close the writer even when an assertion fails.
        logger.finish()

98 

99 

def test_tensorboard_logger_artifact():
    """An artifact is recorded as JSON in an ``artifact/text_summary`` value.

    The JSON must carry the artifact's path, type, aliases, metadata and an
    ISO-format timestamp from the last 5 minutes.
    """
    log_dir = TEMP_PATH / "artifact"
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        artifact_path = TEMP_PATH / "test_artifact.txt"
        artifact_path.write_text("This is a test artifact.")

        logger.artifact(
            artifact_path, type="text", aliases=["alias1"], metadata={"key": "value"}
        )
        logger.flush()

        # Search only this logger's run directory: log_dir is reused across
        # test runs, so stale event files could otherwise satisfy the checks.
        event_files = list(Path(logger.run_path).glob("**/events.out.tfevents.*"))
        assert len(event_files) > 0  # Ensure at least one event file is created

        # Inspect the content of the event files
        found_artifact = False
        for event_file in event_files:
            print(f"{event_file = }")
            for event in read_event_file(event_file):
                print(f"{event = }")
                # HasField, not hasattr: hasattr is always true for a declared
                # protobuf field, whether or not it is set.
                if not event.HasField("summary"):
                    continue
                for value in event.summary.value:
                    if value.tag != "artifact/text_summary":
                        continue
                    found_artifact = True
                    artifact_data = json.loads(value.tensor.string_val[0])

                    # Check artifact data
                    assert (
                        Path(artifact_data["path"]).as_posix()
                        == artifact_path.as_posix()
                    ), f"Expected path {artifact_path}, got {artifact_data['path']}"
                    assert artifact_data["type"] == "text", (
                        f"Expected type 'text', got {artifact_data['type']}"
                    )
                    assert artifact_data["aliases"] == ["alias1"], (
                        f"Expected aliases ['alias1'], got {artifact_data['aliases']}"
                    )
                    # Braces are doubled so they are literal text; a single
                    # "{ 'key': 'value'}" is parsed as a replacement field
                    # with an invalid format spec and raises ValueError.
                    assert artifact_data["metadata"] == {"key": "value"}, (
                        f"Expected metadata {{'key': 'value'}}, got {artifact_data['metadata']}"
                    )

                    # Check timestamp format (assuming it's ISO format)
                    timestamp_datetime = datetime.fromisoformat(
                        artifact_data["timestamp"]
                    )
                    # Check if the timestamp is within the last 5 minutes
                    assert (
                        datetime.now() - timestamp_datetime
                    ).total_seconds() < 300, (
                        f"Timestamp is not within the last 5 minutes: {timestamp_datetime}"
                    )

        assert found_artifact, "Artifact data not found in event files"
    finally:
        # Close the writer even when an assertion fails.
        logger.finish()

157 

158 

def test_tensorboard_logger_url():
    """The logger's url is the tensorboard command pointing at its run path."""
    log_dir = TEMP_PATH / "url"
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        assert logger.url == f"tensorboard --logdir={logger._run_path}"
    finally:
        # Close the writer even when the assertion fails.
        logger.finish()

164 

165 

def test_tensorboard_logger_run_path():
    """The public ``run_path`` property is located inside the requested log_dir."""
    log_dir = TEMP_PATH / "run_path"
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        assert logger.run_path.as_posix().startswith(log_dir.as_posix())
    finally:
        # Close the writer even when the assertion fails.
        logger.finish()