Coverage for trnbl\loggers\base.py: 72%
94 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
1from abc import ABC, abstractmethod
2from typing import Any
3from pathlib import Path
4import warnings
5import time
6import random
8from muutils.spinner import Spinner
# gpu utils: GPUtil is optional -- record whether it could be imported
GPU_UTILS_AVAILABLE: bool
try:
    import GPUtil  # type: ignore[import-untyped]
except ImportError as gputil_import_error:
    # degrade gracefully: warn once at import time and disable GPU stats
    warnings.warn(f"GPUtil not available: {gputil_import_error}")
    GPU_UTILS_AVAILABLE = False
else:
    GPU_UTILS_AVAILABLE = True
# psutil: optional -- record whether it could be imported
PSUTIL_AVAILABLE: bool
try:
    import psutil  # type: ignore[import-untyped]
except ImportError as psutil_import_error:
    # degrade gracefully: warn once at import time and disable CPU/RAM stats
    warnings.warn(f"psutil not available: {psutil_import_error}")
    PSUTIL_AVAILABLE = False
else:
    PSUTIL_AVAILABLE = True
31VOWELS: str = "aeiou"
32CONSONANTS: str = "bcdfghjklmnpqrstvwxyz"
35def rand_syllabic_string(length: int = 6) -> str:
36 """Generate a random string of alternating consonants and vowels to use as a unique identifier
38 for a length of 2n, there are about 10^{2n} possible strings
40 default is 6 characters, which gives 10^6 possible strings
41 """
42 string: str = ""
43 for i in range(length):
44 if i % 2 == 0:
45 string += random.choice(CONSONANTS)
46 else:
47 string += random.choice(VOWELS)
48 return string
class LoggerSpinner(Spinner):
    """A `Spinner` that mirrors every value update to a training logger.

    The keyword-only `logger` receives a `message` call on each
    `update_value`; every other positional/keyword argument is handed to
    `Spinner.__init__` unchanged, so see `Spinner` for those parameters.
    """

    def __init__(
        self,
        *args,
        logger: "TrainingLoggerBase",
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # the logger that `update_value` reports to
        self.logger: "TrainingLoggerBase" = logger

    def update_value(self, value: Any) -> None:
        """Report `value` to `self.logger`, then update the spinner itself.

        The log call carries the spinner's `message`, the new value, and the
        seconds elapsed since `self.start_time` (an attribute inherited from
        `Spinner` -- presumably set when the spinner starts; confirm in muutils).
        """
        log_fields: dict[str, Any] = {
            "message": self.message,
            "spinner_value": value,
            "spinner_elapsed_time": time.time() - self.start_time,
        }
        self.logger.message(**log_fields)
        # delegate the display update to `Spinner`
        # (presumably this just sets `self.current_value = value`)
        super().update_value(value)

    def __enter__(self) -> "LoggerSpinner":
        """Start spinning and return this spinner for use inside the `with` block."""
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Stop the spinner; returns None, so any exception from the block propagates."""
        self.stop()
81class TrainingLoggerBase(ABC):
82 """Base class for training loggers"""
84 @abstractmethod
85 def debug(self, message: str, **kwargs) -> None:
86 """log a debug message which will be saved, but not printed"""
87 pass
89 @abstractmethod
90 def message(self, message: str, **kwargs) -> None:
91 """log a progress message, which will be printed to stdout"""
92 pass
94 def warning(self, message: str, **kwargs) -> None:
95 """log a warning message, which will be printed to stderr"""
96 self.message(f"WARNING: {message}", __warning__=True, **kwargs)
98 def error(self, message: str, **kwargs) -> None:
99 """log an error message"""
100 self.message(f"ERROR: {message}", __error__=True, **kwargs)
102 @abstractmethod
103 def metrics(self, data: dict[str, Any]) -> None:
104 """Log a dictionary of metrics"""
105 pass
107 @abstractmethod
108 def artifact(
109 self,
110 path: Path,
111 type: str,
112 aliases: list[str] | None = None,
113 metadata: dict | None = None,
114 ) -> None:
115 """log an artifact from a file"""
116 pass
118 @property
119 @abstractmethod
120 def url(self) -> str | list[str]:
121 """Get the URL for the current logging run"""
122 pass
124 @property
125 @abstractmethod
126 def run_path(self) -> Path | list[Path]:
127 """Get the path to the current logging run"""
128 pass
130 @abstractmethod
131 def flush(self) -> None:
132 """Flush the logger"""
133 pass
135 @abstractmethod
136 def finish(self) -> None:
137 """Finish logging"""
138 pass
140 def get_mem_usage(self) -> dict:
141 mem_usage: dict = {}
143 try:
144 # CPU/Memory usage (if available)
145 if PSUTIL_AVAILABLE:
146 cpu_percent = psutil.cpu_percent()
147 mem_usage["cpu/percent"] = cpu_percent
149 # Memory usage
150 virtual_mem = psutil.virtual_memory()
151 mem_usage["ram/used"] = virtual_mem.used
152 mem_usage["ram/percent"] = virtual_mem.percent
154 # GPU information (if available)
155 if GPU_UTILS_AVAILABLE:
156 gpus = GPUtil.getGPUs()
157 for gpu in gpus:
158 gpu_id = gpu.id
159 mem_usage[f"gpu:{gpu_id}/load"] = gpu.load
160 mem_usage[f"gpu:{gpu_id}/memory_used"] = gpu.memoryUsed
161 mem_usage[f"gpu:{gpu_id}/temperature"] = gpu.temperature
162 except Exception as e:
163 self.warning(f"Error getting memory usage: {e}")
165 return mem_usage
167 def spinner_task(self, **kwargs) -> LoggerSpinner:
168 "Create a spinner task. kwargs are passed to `Spinner`."
169 return LoggerSpinner(logger=self, **kwargs)
171 # def seq_task(self, **kwargs) -> LoggerSpinner:
172 # "Create a sequential task with progress bar. kwargs are passed to `tqdm`."
173 # return LoggerSpinner(message=message, logger=self, **kwargs)