Coverage for tests\unit\loggers\test_local_logger.py: 100%
70 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
import json
from pathlib import Path

from trnbl.loggers.local import LocalLogger, FilePaths
# Scratch directory shared by the tests in this module. It is created at
# import time so each test can write into it without extra setup.
TEMP_PATH = Path("tests/_temp")
TEMP_PATH.mkdir(parents=True, exist_ok=True)

# Sample training configuration handed to every LocalLogger under test.
train_config = {"batch_size": 32, "learning_rate": 0.001, "epochs": 10}
# Define the test functions
def test_logger_initialization():
    """Check that constructor arguments are exposed as logger attributes."""
    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=TEMP_PATH,
    )
    try:
        assert logger.project == "test_project"
        assert "accuracy" in logger.metric_names
        assert logger.train_config == train_config
    finally:
        # Finish the logger even when an assertion above fails.
        logger.finish()
def test_logger_message():
    """Check that ``message`` appends the text as the newest ``log_list`` entry."""
    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=TEMP_PATH,
    )
    try:
        logger.message("Test message")
        assert len(logger.log_list) > 0
        assert logger.log_list[-1]["message"] == "Test message"
    finally:
        # Finish the logger even when an assertion above fails.
        logger.finish()
def test_logger_metrics():
    """Check that ``metrics`` records the given values as the newest entry."""
    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=TEMP_PATH,
    )
    try:
        logger.metrics({"accuracy": 0.95, "loss": 0.05})
        assert len(logger.metrics_list) > 0
        latest = logger.metrics_list[-1]
        assert latest["accuracy"] == 0.95
        assert latest["loss"] == 0.05
    finally:
        # Finish the logger even when an assertion above fails.
        logger.finish()
def test_logger_artifact():
    """Check that ``artifact`` records the file's POSIX path as the newest entry."""
    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=TEMP_PATH,
    )
    # Built from TEMP_PATH rather than repeating the "tests/_temp" literal.
    artifact_path = TEMP_PATH / "test_artifact.txt"
    try:
        artifact_path.parent.mkdir(parents=True, exist_ok=True)
        artifact_path.write_text("This is a test artifact.")
        logger.artifact(artifact_path, type="text")
        assert len(logger.artifacts_list) > 0
        assert logger.artifacts_list[-1]["path"] == artifact_path.as_posix()
    finally:
        # Same order as before (remove the file, then finish), but now it also
        # runs when an assertion fails. missing_ok covers the case where the
        # file was never written.
        artifact_path.unlink(missing_ok=True)
        logger.finish()
def test_logger_files():
    """Check that flushed messages, metrics and artifacts land in the run files.

    After each ``flush`` the matching file under ``logger.run_path`` is read
    back and its last line is compared with what was just logged.
    """
    # Define a specific path for this test run
    test_run_path = TEMP_PATH / "test_run"
    test_run_path.mkdir(parents=True, exist_ok=True)

    # Initialize the logger with the test run path
    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=test_run_path,
    )

    try:
        # Log message
        logger.message("Test message")
        logger.flush()
        log_file_path = logger.run_path / FilePaths.LOG
        print(f"{log_file_path = }")
        with open(log_file_path, "r") as log_file:
            log_content = log_file.readlines()
        print("Log file content:", log_content)  # Debug print
        assert len(log_content) > 0
        assert "Test message" in log_content[-1]

        # Log metrics
        logger.metrics({"accuracy": 0.95, "loss": 0.05})
        logger.flush()
        metrics_file_path = logger.run_path / FilePaths.METRICS
        with open(metrics_file_path, "r") as metrics_file:
            metrics_content = metrics_file.readlines()
        print("Metrics file content:", metrics_content)  # Debug print
        assert len(metrics_content) > 0
        # The last line is parsed as a standalone JSON document.
        metrics_data = json.loads(metrics_content[-1])
        assert metrics_data["accuracy"] == 0.95
        assert metrics_data["loss"] == 0.05

        # Log artifact
        artifact_path = logger.run_path / "test_artifact.txt"
        artifact_path.write_text("This is a test artifact.")
        logger.artifact(artifact_path, type="text")
        logger.flush()
        artifacts_file_path = logger.run_path / FilePaths.ARTIFACTS
        with open(artifacts_file_path, "r") as artifacts_file:
            artifacts_content = artifacts_file.readlines()
        print("Artifacts file content:", artifacts_content)  # Debug print
        assert len(artifacts_content) > 0
        artifact_data = json.loads(artifacts_content[-1])
        assert artifact_data["path"] == artifact_path.as_posix()
    finally:
        # Clean up after tests, even when an assertion above fails.
        logger.finish()