Coverage for tests\unit\loggers\test_local_logger.py: 100%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

1import json 

2from pathlib import Path 

3 

4 

5from trnbl.loggers.local import LocalLogger, FilePaths 

6 

# Scratch directory for files the tests create. It is made at import time so
# it already exists before any LocalLogger below is given it as base_path.
TEMP_PATH = Path("tests/_temp")
TEMP_PATH.mkdir(exist_ok=True, parents=True)


# Training configuration handed to every LocalLogger in this module.
train_config = dict(
    batch_size=32,
    learning_rate=0.001,
    epochs=10,
)


# --- test functions --------------------------------------------------------

def test_logger_initialization():
    """Constructor arguments are exposed as attributes on the LocalLogger."""
    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=TEMP_PATH,
    )
    # finish() sits in ``finally`` so the logger is finalized even when an
    # assertion fails; previously a failure skipped it entirely.
    try:
        assert logger.project == "test_project"
        assert "accuracy" in logger.metric_names
        assert "loss" in logger.metric_names
        assert logger.train_config == train_config
    finally:
        logger.finish()

29 

30 

def test_logger_message():
    """message() appends an entry whose "message" field holds the given text."""
    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=TEMP_PATH,
    )
    try:
        # Baseline the length: the constructor may already have logged entries,
        # in which case ``len(...) > 0`` would not show that message() appended.
        n_before = len(logger.log_list)
        logger.message("Test message")
        assert len(logger.log_list) > n_before
        assert logger.log_list[-1]["message"] == "Test message"
    finally:
        # Always finalize the logger, even if an assertion above failed.
        logger.finish()

43 

44 

def test_logger_metrics():
    """metrics() appends an entry carrying each logged metric value."""
    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=TEMP_PATH,
    )
    try:
        # Baseline the length so the check proves metrics() itself appended.
        n_before = len(logger.metrics_list)
        logger.metrics({"accuracy": 0.95, "loss": 0.05})
        assert len(logger.metrics_list) > n_before
        # Exact float equality is intended: the stored value should be the
        # very value passed in, not a recomputed one.
        assert logger.metrics_list[-1]["accuracy"] == 0.95
        assert logger.metrics_list[-1]["loss"] == 0.05
    finally:
        # Always finalize the logger, even if an assertion above failed.
        logger.finish()

58 

59 

def test_logger_artifact():
    """artifact() records the artifact's POSIX-style path in artifacts_list."""
    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=TEMP_PATH,
    )
    # Derived from TEMP_PATH rather than repeating the "tests/_temp" literal;
    # as_posix() still yields "tests/_temp/test_artifact.txt".
    artifact_path = TEMP_PATH / "test_artifact.txt"
    try:
        artifact_path.parent.mkdir(parents=True, exist_ok=True)
        artifact_path.write_text("This is a test artifact.", encoding="utf-8")
        n_before = len(logger.artifacts_list)
        logger.artifact(artifact_path, type="text")
        assert len(logger.artifacts_list) > n_before
        assert logger.artifacts_list[-1]["path"] == artifact_path.as_posix()
    finally:
        # Same order as before (remove the scratch file, then finish), but it
        # now also runs on failure; the nested finally guarantees finish()
        # even if the unlink raises.
        try:
            artifact_path.unlink(missing_ok=True)
        finally:
            logger.finish()

76 

77 

def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path*, as ``readlines()`` does."""
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) is not
    # guaranteed to match what the logger wrote.
    with open(path, "r", encoding="utf-8") as fh:
        return fh.readlines()


def test_logger_files():
    """After flush(), message/metrics/artifact entries appear in their files."""
    # Separate base path so this test's run directory is kept apart from the
    # other tests, which log under TEMP_PATH directly.
    test_run_path = TEMP_PATH / "test_run"
    test_run_path.mkdir(parents=True, exist_ok=True)

    logger = LocalLogger(
        project="test_project",
        metric_names=["accuracy", "loss"],
        group="test_run",
        train_config=train_config,
        base_path=test_run_path,
    )
    try:
        # Log message: the last line of the log file must mention it.
        logger.message("Test message")
        logger.flush()
        log_file_path = logger.run_path / FilePaths.LOG
        log_content = _read_lines(log_file_path)
        assert log_content, f"{log_file_path} is empty"
        assert "Test message" in log_content[-1], (
            f"unexpected last line in {log_file_path}: {log_content[-1]!r}"
        )

        # Log metrics: the last line of the metrics file is parsed as JSON.
        logger.metrics({"accuracy": 0.95, "loss": 0.05})
        logger.flush()
        metrics_file_path = logger.run_path / FilePaths.METRICS
        metrics_content = _read_lines(metrics_file_path)
        assert metrics_content, f"{metrics_file_path} is empty"
        metrics_data = json.loads(metrics_content[-1])
        assert metrics_data["accuracy"] == 0.95
        assert metrics_data["loss"] == 0.05

        # Log artifact: the last line of the artifacts file is parsed as JSON.
        artifact_path = logger.run_path / "test_artifact.txt"
        artifact_path.write_text("This is a test artifact.", encoding="utf-8")
        logger.artifact(artifact_path, type="text")
        logger.flush()
        artifacts_file_path = logger.run_path / FilePaths.ARTIFACTS
        artifacts_content = _read_lines(artifacts_file_path)
        assert artifacts_content, f"{artifacts_file_path} is empty"
        artifact_data = json.loads(artifacts_content[-1])
        assert artifact_data["path"] == artifact_path.as_posix()
    finally:
        # Always finalize the logger, even if an assertion above failed.
        logger.finish()