Coverage for trnbl\loggers\base.py: 72%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

1from abc import ABC, abstractmethod 

2from typing import Any 

3from pathlib import Path 

4import warnings 

5import time 

6import random 

7 

8from muutils.spinner import Spinner 

9 

# gpu utils
# GPUtil is an optional dependency: when the import fails we only emit a
# warning and record its absence in `GPU_UTILS_AVAILABLE`.
GPU_UTILS_AVAILABLE: bool
try:
    import GPUtil  # type: ignore[import-untyped]
except ImportError as import_err:
    warnings.warn(f"GPUtil not available: {import_err}")
    GPU_UTILS_AVAILABLE = False
else:
    GPU_UTILS_AVAILABLE = True

19 

# psutil
# psutil is an optional dependency: when the import fails we only emit a
# warning and record its absence in `PSUTIL_AVAILABLE`.
PSUTIL_AVAILABLE: bool
try:
    import psutil  # type: ignore[import-untyped]
except ImportError as import_err:
    warnings.warn(f"psutil not available: {import_err}")
    PSUTIL_AVAILABLE = False
else:
    PSUTIL_AVAILABLE = True

29 

30 

VOWELS: str = "aeiou"
CONSONANTS: str = "bcdfghjklmnpqrstvwxyz"


def rand_syllabic_string(length: int = 6) -> str:
    """Build a pronounceable random identifier of alternating consonants and vowels.

    The string starts with a consonant; characters at even positions are drawn
    from `CONSONANTS` and characters at odd positions from `VOWELS`, each via
    `random.choice`.

    Every consonant-vowel pair has 21 * 5 = 105 possibilities, so a string of
    length 2n has roughly 10^{2n} possible values; the default length of 6
    gives about 10^6.

    # Parameters:
     - `length : int`
        number of characters to generate; a value <= 0 yields `""`
        (defaults to `6`)

    # Returns:
     - `str`
        the generated string, exactly `length` characters long
    """
    # position parity selects the alphabet: 0 -> consonant, 1 -> vowel
    alphabets: tuple[str, str] = (CONSONANTS, VOWELS)
    return "".join(random.choice(alphabets[pos % 2]) for pos in range(length))

49 

50 

class LoggerSpinner(Spinner):
    """`Spinner` that also reports every `update_value` call to a `TrainingLoggerBase`.

    See `Spinner` for the accepted positional and keyword parameters; the only
    addition here is the required keyword-only `logger`.
    """

    def __init__(
        self,
        *args,
        logger: "TrainingLoggerBase",
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # destination for the records emitted by `update_value`
        self.logger: "TrainingLoggerBase" = logger

    def update_value(self, value: Any) -> None:
        """Log `value` through `self.logger.message`, then update the spinner with it.

        The log record carries the spinner's message, the new value and the
        seconds elapsed since `self.start_time`.
        """
        # `self.message` and `self.start_time` are not set in this class --
        # presumably provided by `Spinner`; verify against muutils
        record: dict[str, Any] = dict(
            message=self.message,
            spinner_value=value,
            spinner_elapsed_time=time.time() - self.start_time,
        )
        self.logger.message(**record)
        # the parent call is expected to just do `self.current_value = value`
        super().update_value(value)

    def __enter__(self) -> "LoggerSpinner":
        """Start the spinner and return it for use in a `with` block."""
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Stop the spinner; the exception details are ignored and nothing is suppressed."""
        self.stop()

79 

80 

class TrainingLoggerBase(ABC):
    """Base class for training loggers.

    Subclasses must implement `debug`, `message`, `metrics`, `artifact`, `url`,
    `run_path`, `flush` and `finish`. The concrete helpers `warning`, `error`,
    `get_mem_usage` and `spinner_task` are built on top of those.
    """

    @abstractmethod
    def debug(self, message: str, **kwargs) -> None:
        """log a debug message which will be saved, but not printed"""
        pass

    @abstractmethod
    def message(self, message: str, **kwargs) -> None:
        """log a progress message, which will be printed to stdout"""
        pass

    def warning(self, message: str, **kwargs) -> None:
        """Log a warning by forwarding to `message`.

        The text is prefixed with `WARNING: ` and `__warning__=True` is added
        to the keyword arguments; how the record is displayed is decided by
        the subclass's `message` implementation.
        """
        self.message(f"WARNING: {message}", __warning__=True, **kwargs)

    def error(self, message: str, **kwargs) -> None:
        """Log an error by forwarding to `message`.

        The text is prefixed with `ERROR: ` and `__error__=True` is added to
        the keyword arguments.
        """
        self.message(f"ERROR: {message}", __error__=True, **kwargs)

    @abstractmethod
    def metrics(self, data: dict[str, Any]) -> None:
        """Log a dictionary of metrics"""
        pass

    @abstractmethod
    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """log an artifact from a file"""
        pass

    @property
    @abstractmethod
    def url(self) -> str | list[str]:
        """Get the URL for the current logging run"""
        pass

    @property
    @abstractmethod
    def run_path(self) -> Path | list[Path]:
        """Get the path to the current logging run"""
        pass

    @abstractmethod
    def flush(self) -> None:
        """Flush the logger"""
        pass

    @abstractmethod
    def finish(self) -> None:
        """Finish logging"""
        pass

    def get_mem_usage(self) -> dict:
        """Collect a best-effort snapshot of CPU, RAM and GPU usage.

        Each probe runs in its own `try` block, so a failure while querying
        psutil does not prevent the GPU metrics from being collected. Probe
        failures are reported through `warning` instead of being raised, and
        whatever was gathered before the failure is kept.

        # Returns:
         - `dict`
            flat mapping, possibly empty, containing
            - `"cpu/percent"`, `"ram/used"`, `"ram/percent"`
              when `PSUTIL_AVAILABLE` is true
            - `"gpu:{id}/load"`, `"gpu:{id}/memory_used"`,
              `"gpu:{id}/temperature"` for every GPU returned by
              `GPUtil.getGPUs()` when `GPU_UTILS_AVAILABLE` is true
        """
        mem_usage: dict = {}

        # CPU/Memory usage (if available)
        if PSUTIL_AVAILABLE:
            try:
                mem_usage["cpu/percent"] = psutil.cpu_percent()

                virtual_mem = psutil.virtual_memory()
                mem_usage["ram/used"] = virtual_mem.used
                mem_usage["ram/percent"] = virtual_mem.percent
            except Exception as e:
                # best-effort: report and carry on to the GPU probe
                self.warning(f"Error getting memory usage: {e}")

        # GPU information (if available)
        if GPU_UTILS_AVAILABLE:
            try:
                for gpu in GPUtil.getGPUs():
                    gpu_id = gpu.id
                    mem_usage[f"gpu:{gpu_id}/load"] = gpu.load
                    mem_usage[f"gpu:{gpu_id}/memory_used"] = gpu.memoryUsed
                    mem_usage[f"gpu:{gpu_id}/temperature"] = gpu.temperature
            except Exception as e:
                # best-effort: keep whatever was collected so far
                self.warning(f"Error getting memory usage: {e}")

        return mem_usage

    def spinner_task(self, **kwargs) -> LoggerSpinner:
        "Create a spinner task. kwargs are passed to `Spinner`."
        return LoggerSpinner(logger=self, **kwargs)

170 

171 # def seq_task(self, **kwargs) -> LoggerSpinner: 

172 # "Create a sequential task with progress bar. kwargs are passed to `tqdm`." 

173 # return LoggerSpinner(message=message, logger=self, **kwargs)