# Coverage for tests\unit\loggers\test_tensorboard_logger.py: 100%
# 104 statements
# « prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
import json
import math
import shutil
from datetime import datetime
from pathlib import Path

from tensorboard.backend.event_processing.event_file_loader import EventFileLoader  # type: ignore[import-untyped]

from trnbl.loggers.tensorboard import TensorBoardLogger
# Scratch directory shared by every test in this module. It is created
# eagerly at import time so each test can write its own subdirectory beneath it.
TEMP_PATH: Path = Path("tests") / "_temp" / "tensorboard"
TEMP_PATH.mkdir(exist_ok=True, parents=True)
15def read_event_file(event_file):
16 loader = EventFileLoader(event_file.as_posix())
17 for event in loader.Load():
18 yield event
def test_tensorboard_logger_initialization():
    """The logger's run path is placed under the requested ``log_dir``."""
    log_dir = TEMP_PATH / "init"
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        assert logger._run_path.as_posix().startswith(log_dir.as_posix())
    finally:
        # Always release the logger, even when the assertion fails.
        logger.finish()
def test_tensorboard_logger_message():
    """A message and its kwargs end up in a ``message/text_summary`` event.

    The message text and the JSON-serialized ``key="value"`` kwarg must both
    appear in the decoded string payload of that summary value.
    """
    log_dir = TEMP_PATH / "message"
    # TEMP_PATH persists between test sessions. Remove this test's directory so
    # the glob below cannot match event files left behind by an earlier run,
    # which could otherwise satisfy `found_content` on their own.
    if log_dir.exists():
        shutil.rmtree(log_dir)
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        logger.message("Test message", key="value")
        logger.flush()

        event_files = list(log_dir.glob("**/events.out.tfevents.*"))
        print(f"{event_files = }")

        assert len(event_files) > 0  # Ensure at least one event file is created

        # Inspect the content of the event files
        found_content: bool = False
        for event_file in event_files:
            print(f"{event_file = }")
            for event in read_event_file(event_file):
                print(f"\t{event = }")
                if not event.HasField("summary"):
                    continue
                for value in event.summary.value:
                    if value.tag == "message/text_summary":
                        message = value.tensor.string_val[0].decode("utf-8")
                        assert "Test message" in message
                        assert '"key": "value"' in message
                        found_content = True

        assert found_content
    finally:
        # Always release the logger, even when an assertion fails.
        logger.finish()
def test_tensorboard_logger_metrics():
    """Each metric passed to ``logger.metrics`` is written under its own tag.

    For every key in ``metrics_data``, a summary value tagged with that key
    must exist, and its first ``float_val`` must match the logged number.
    """
    log_dir = TEMP_PATH / "metrics"
    # TEMP_PATH persists between test sessions. Remove this test's directory so
    # only event files written by the logger below are inspected.
    if log_dir.exists():
        shutil.rmtree(log_dir)
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        metrics_data = {"accuracy": 0.95, "loss": 0.05}
        logger.metrics(metrics_data)
        # Flush before reading so the metric events are on disk.
        logger.flush()

        event_files = list(log_dir.glob("**/events.out.tfevents.*"))
        assert len(event_files) > 0  # Ensure at least one event file is created

        # Inspect the content of the event files
        found_tags = set()
        for event_file in event_files:
            print(f"{event_file = }")
            print("====================================")
            for event in read_event_file(event_file):
                print(f"{event = }")
                if not event.HasField("summary"):
                    continue
                for value in event.summary.value:
                    if value.tag in metrics_data:
                        found_tags.add(value.tag)
                        # rel_tol allows for the float32 round-trip in the event payload.
                        assert math.isclose(
                            value.tensor.float_val[0],
                            metrics_data[value.tag],
                            rel_tol=1e-5,
                        )

        missing = set(metrics_data) - found_tags
        assert not missing, f"Metrics not found in event files: {sorted(missing)}"
    finally:
        # Always release the logger, even when an assertion fails.
        logger.finish()
def test_tensorboard_logger_artifact():
    """An artifact is recorded as a JSON ``artifact/text_summary`` event.

    The JSON payload must carry the artifact's path, type, aliases and
    metadata as passed in, plus an ISO-format timestamp close to "now".
    """
    log_dir = TEMP_PATH / "artifact"
    # TEMP_PATH persists between test sessions. Stale artifact events from an
    # earlier run would carry old timestamps and fail the recency check below,
    # so start from an empty directory.
    if log_dir.exists():
        shutil.rmtree(log_dir)
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        artifact_path = TEMP_PATH / "test_artifact.txt"
        artifact_path.write_text("This is a test artifact.", encoding="utf-8")

        logger.artifact(
            artifact_path, type="text", aliases=["alias1"], metadata={"key": "value"}
        )
        logger.flush()

        event_files = list(log_dir.glob("**/events.out.tfevents.*"))
        assert len(event_files) > 0  # Ensure at least one event file is created

        # Inspect the content of the event files
        found_artifact: bool = False
        for event_file in event_files:
            print(f"{event_file = }")
            for event in read_event_file(event_file):
                print(f"{event = }")
                # `hasattr` is always true on protobuf messages; HasField
                # actually checks whether a summary is present.
                if not event.HasField("summary"):
                    continue
                for value in event.summary.value:
                    if value.tag != "artifact/text_summary":
                        continue
                    found_artifact = True
                    artifact_data = json.loads(value.tensor.string_val[0])

                    # Check artifact data
                    assert (
                        Path(artifact_data["path"]).as_posix()
                        == artifact_path.as_posix()
                    ), f"Expected path {artifact_path}, got {artifact_data['path']}"
                    assert artifact_data["type"] == "text", (
                        f"Expected type 'text', got {artifact_data['type']}"
                    )
                    assert artifact_data["aliases"] == ["alias1"], (
                        f"Expected aliases ['alias1'], got {artifact_data['aliases']}"
                    )
                    # Braces are doubled so the f-string renders a literal dict
                    # instead of parsing `'value'` as a format spec (ValueError).
                    assert artifact_data["metadata"] == {"key": "value"}, (
                        f"Expected metadata {{'key': 'value'}}, got {artifact_data['metadata']}"
                    )

                    # Check timestamp format (assuming it's ISO format)
                    timestamp_datetime = datetime.fromisoformat(
                        artifact_data["timestamp"]
                    )
                    # abs() so a timestamp in the future is rejected as well
                    # as one from more than 5 minutes ago.
                    assert (
                        abs((datetime.now() - timestamp_datetime).total_seconds()) < 300
                    ), f"Timestamp is not within 5 minutes of now: {timestamp_datetime}"

        assert found_artifact, "Artifact data not found in event files"
    finally:
        # Always release the logger, even when an assertion fails.
        logger.finish()
def test_tensorboard_logger_url():
    """``logger.url`` is the ``tensorboard --logdir=...`` command for the run path."""
    log_dir = TEMP_PATH / "url"
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        assert logger.url == f"tensorboard --logdir={logger._run_path}"
    finally:
        # Always release the logger, even when the assertion fails.
        logger.finish()
def test_tensorboard_logger_run_path():
    """The public ``run_path`` property lives under the requested ``log_dir``."""
    log_dir = TEMP_PATH / "run_path"
    logger = TensorBoardLogger(log_dir=log_dir)
    try:
        assert logger.run_path.as_posix().startswith(log_dir.as_posix())
    finally:
        # Always release the logger, even when the assertion fails.
        logger.finish()