docs for trnbl v0.1.1
View Source on GitHub

trnbl.loggers.base


  1from abc import ABC, abstractmethod
  2from typing import Any
  3from pathlib import Path
  4import warnings
  5import time
  6import random
  7
  8from muutils.spinner import Spinner
  9
 10# gpu utils
 11GPU_UTILS_AVAILABLE: bool
 12try:
 13	import GPUtil  # type: ignore[import-untyped]
 14
 15	GPU_UTILS_AVAILABLE = True
 16except ImportError as e:
 17	warnings.warn(f"GPUtil not available: {e}")
 18	GPU_UTILS_AVAILABLE = False
 19
 20# psutil
 21PSUTIL_AVAILABLE: bool
 22try:
 23	import psutil  # type: ignore[import-untyped]
 24
 25	PSUTIL_AVAILABLE = True
 26except ImportError as e:
 27	warnings.warn(f"psutil not available: {e}")
 28	PSUTIL_AVAILABLE = False
 29
 30
 31VOWELS: str = "aeiou"
 32CONSONANTS: str = "bcdfghjklmnpqrstvwxyz"
 33
 34
 35def rand_syllabic_string(length: int = 6) -> str:
 36	"""Generate a random string of alternating consonants and vowels to use as a unique identifier
 37
 38	for a length of 2n, there are about 10^{2n} possible strings
 39
 40	default is 6 characters, which gives 10^6 possible strings
 41	"""
 42	string: str = ""
 43	for i in range(length):
 44		if i % 2 == 0:
 45			string += random.choice(CONSONANTS)
 46		else:
 47			string += random.choice(VOWELS)
 48	return string
 49
 50
 51class LoggerSpinner(Spinner):
 52	"see `Spinner` for parameters. catches `update_value` and passes it to the `LocalLogger`"
 53
 54	def __init__(
 55		self,
 56		*args,
 57		logger: "TrainingLoggerBase",
 58		**kwargs,
 59	):
 60		super().__init__(*args, **kwargs)
 61		self.logger: "TrainingLoggerBase" = logger
 62
 63	def update_value(self, value: Any) -> None:
 64		"""update the value of the spinner and log it"""
 65		self.logger.message(
 66			message=self.message,
 67			spinner_value=value,
 68			spinner_elapsed_time=time.time() - self.start_time,
 69		)
 70		super().update_value(value)
 71		# the above should just be calling `self.current_value = value`
 72
 73	def __enter__(self) -> "LoggerSpinner":
 74		self.start()
 75		return self
 76
 77	def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
 78		self.stop()
 79
 80
 81class TrainingLoggerBase(ABC):
 82	"""Base class for training loggers"""
 83
 84	@abstractmethod
 85	def debug(self, message: str, **kwargs) -> None:
 86		"""log a debug message which will be saved, but not printed"""
 87		pass
 88
 89	@abstractmethod
 90	def message(self, message: str, **kwargs) -> None:
 91		"""log a progress message, which will be printed to stdout"""
 92		pass
 93
 94	def warning(self, message: str, **kwargs) -> None:
 95		"""log a warning message, which will be printed to stderr"""
 96		self.message(f"WARNING: {message}", __warning__=True, **kwargs)
 97
 98	def error(self, message: str, **kwargs) -> None:
 99		"""log an error message"""
100		self.message(f"ERROR: {message}", __error__=True, **kwargs)
101
102	@abstractmethod
103	def metrics(self, data: dict[str, Any]) -> None:
104		"""Log a dictionary of metrics"""
105		pass
106
107	@abstractmethod
108	def artifact(
109		self,
110		path: Path,
111		type: str,
112		aliases: list[str] | None = None,
113		metadata: dict | None = None,
114	) -> None:
115		"""log an artifact from a file"""
116		pass
117
118	@property
119	@abstractmethod
120	def url(self) -> str | list[str]:
121		"""Get the URL for the current logging run"""
122		pass
123
124	@property
125	@abstractmethod
126	def run_path(self) -> Path | list[Path]:
127		"""Get the path to the current logging run"""
128		pass
129
130	@abstractmethod
131	def flush(self) -> None:
132		"""Flush the logger"""
133		pass
134
135	@abstractmethod
136	def finish(self) -> None:
137		"""Finish logging"""
138		pass
139
140	def get_mem_usage(self) -> dict:
141		mem_usage: dict = {}
142
143		try:
144			# CPU/Memory usage (if available)
145			if PSUTIL_AVAILABLE:
146				cpu_percent = psutil.cpu_percent()
147				mem_usage["cpu/percent"] = cpu_percent
148
149				# Memory usage
150				virtual_mem = psutil.virtual_memory()
151				mem_usage["ram/used"] = virtual_mem.used
152				mem_usage["ram/percent"] = virtual_mem.percent
153
154			# GPU information (if available)
155			if GPU_UTILS_AVAILABLE:
156				gpus = GPUtil.getGPUs()
157				for gpu in gpus:
158					gpu_id = gpu.id
159					mem_usage[f"gpu:{gpu_id}/load"] = gpu.load
160					mem_usage[f"gpu:{gpu_id}/memory_used"] = gpu.memoryUsed
161					mem_usage[f"gpu:{gpu_id}/temperature"] = gpu.temperature
162		except Exception as e:
163			self.warning(f"Error getting memory usage: {e}")
164
165		return mem_usage
166
167	def spinner_task(self, **kwargs) -> LoggerSpinner:
168		"Create a spinner task. kwargs are passed to `Spinner`."
169		return LoggerSpinner(logger=self, **kwargs)
170
171	# def seq_task(self, **kwargs) -> LoggerSpinner:
172	# 	"Create a sequential task with progress bar. kwargs are passed to `tqdm`."
173	# 	return LoggerSpinner(message=message, logger=self, **kwargs)

GPU_UTILS_AVAILABLE: bool = True
PSUTIL_AVAILABLE: bool = True
VOWELS: str = 'aeiou'
CONSONANTS: str = 'bcdfghjklmnpqrstvwxyz'
def rand_syllabic_string(length: int = 6) -> str:
def rand_syllabic_string(length: int = 6) -> str:
	"""Generate a random string of alternating consonants and vowels to use as a unique identifier

	even positions (0, 2, 4, ...) are drawn from `CONSONANTS`, odd positions from
	`VOWELS`. with 21 consonants and 5 vowels, a string of length 2n has
	(21 * 5)^n = 105^n, i.e. roughly 10^{2n}, possible values.

	default is 6 characters, which gives 105^3 = 1,157,625 (about 10^6) possible strings

	uses the non-cryptographic `random` module: results are reproducible under
	`random.seed` and must not be used as secrets.

	# Parameters:
	 - `length : int`
		number of characters to generate; `length <= 0` yields `""`
		(defaults to `6`)

	# Returns:
	 - `str`
		the generated identifier
	"""
	# one draw per position, from the pool matching the position's parity;
	# `join` builds the result in one pass instead of repeated `+=`
	return "".join(
		random.choice(CONSONANTS if i % 2 == 0 else VOWELS) for i in range(length)
	)

Generate a random string of alternating consonants and vowels to use as a unique identifier

For a length of $2n$ there are $(21 \cdot 5)^n = 105^n \approx 10^{2n}$ possible strings (21 consonants, 5 vowels).

The default of 6 characters gives $105^3 = 1{,}157{,}625 \approx 10^6$ possible strings.

class LoggerSpinner(muutils.spinner.Spinner):
class LoggerSpinner(Spinner):
	"""a `Spinner` that also reports every value update to a training logger

	all arguments other than `logger` are forwarded unchanged to `Spinner` --
	see it for the available parameters. `logger` is keyword-only and may be any
	`TrainingLoggerBase`. each `update_value` call is first sent to
	`logger.message` (with the spinner's `message` text and the elapsed time
	`time.time() - self.start_time`) and only then applied to the spinner.
	"""

	def __init__(
		self,
		*args,
		logger: "TrainingLoggerBase",
		**kwargs,
	):
		super().__init__(*args, **kwargs)
		# the logger that `update_value` reports to
		self.logger: "TrainingLoggerBase" = logger

	def update_value(self, value: Any) -> None:
		"""send `value` to the logger, then update the spinner with it"""
		current_message = self.message
		elapsed: float = time.time() - self.start_time
		self.logger.message(
			message=current_message,
			spinner_value=value,
			spinner_elapsed_time=elapsed,
		)
		# delegate the actual update to `Spinner`; presumably this only stores
		# the new value (`current_value`) -- NOTE(review): confirm against muutils
		super().update_value(value)

	def __enter__(self) -> "LoggerSpinner":
		"""start the spinner and hand it to the `with` block"""
		self.start()
		return self

	def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
		"""stop the spinner; returns `None`, so exceptions are not suppressed"""
		self.stop()

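See `Spinner` for parameters. Catches `update_value` and passes each new value to the attached `logger` (any `TrainingLoggerBase`, such as `LocalLogger`) before updating the spinner itself.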

LoggerSpinner(*args, logger: TrainingLoggerBase, **kwargs)
def __init__(
	self,
	*args,
	logger: "TrainingLoggerBase",
	**kwargs,
):
	"""initialise the underlying `Spinner` and remember the keyword-only `logger`"""
	# every argument except `logger` belongs to `Spinner`
	super().__init__(*args, **kwargs)
	# the logger that `update_value` reports to
	self.logger: "TrainingLoggerBase" = logger
def update_value(self, value: Any) -> None:
def update_value(self, value: Any) -> None:
	"""send `value` to the logger, then update the spinner with it

	the logger's `message` receives the spinner's current `message` text plus
	`spinner_value` and `spinner_elapsed_time` (`time.time() - self.start_time`)
	"""
	current_message = self.message
	elapsed: float = time.time() - self.start_time
	self.logger.message(
		message=current_message,
		spinner_value=value,
		spinner_elapsed_time=elapsed,
	)
	# delegate the actual update to `Spinner`; presumably this only stores
	# the new value (`current_value`) -- NOTE(review): confirm against muutils
	super().update_value(value)

update the value of the spinner and log it

Inherited Members
muutils.spinner.Spinner
config
format_string_when_updated
update_interval
message
current_value
format_string
output_stream
start_time
stop_spinner
spinner_thread
value_changed
term_width
state
spin
start
stop
class TrainingLoggerBase(abc.ABC):
 82class TrainingLoggerBase(ABC):
 83	"""Base class for training loggers"""
 84
 85	@abstractmethod
 86	def debug(self, message: str, **kwargs) -> None:
 87		"""log a debug message which will be saved, but not printed"""
 88		pass
 89
 90	@abstractmethod
 91	def message(self, message: str, **kwargs) -> None:
 92		"""log a progress message, which will be printed to stdout"""
 93		pass
 94
 95	def warning(self, message: str, **kwargs) -> None:
 96		"""log a warning message, which will be printed to stderr"""
 97		self.message(f"WARNING: {message}", __warning__=True, **kwargs)
 98
 99	def error(self, message: str, **kwargs) -> None:
100		"""log an error message"""
101		self.message(f"ERROR: {message}", __error__=True, **kwargs)
102
103	@abstractmethod
104	def metrics(self, data: dict[str, Any]) -> None:
105		"""Log a dictionary of metrics"""
106		pass
107
108	@abstractmethod
109	def artifact(
110		self,
111		path: Path,
112		type: str,
113		aliases: list[str] | None = None,
114		metadata: dict | None = None,
115	) -> None:
116		"""log an artifact from a file"""
117		pass
118
119	@property
120	@abstractmethod
121	def url(self) -> str | list[str]:
122		"""Get the URL for the current logging run"""
123		pass
124
125	@property
126	@abstractmethod
127	def run_path(self) -> Path | list[Path]:
128		"""Get the path to the current logging run"""
129		pass
130
131	@abstractmethod
132	def flush(self) -> None:
133		"""Flush the logger"""
134		pass
135
136	@abstractmethod
137	def finish(self) -> None:
138		"""Finish logging"""
139		pass
140
141	def get_mem_usage(self) -> dict:
142		mem_usage: dict = {}
143
144		try:
145			# CPU/Memory usage (if available)
146			if PSUTIL_AVAILABLE:
147				cpu_percent = psutil.cpu_percent()
148				mem_usage["cpu/percent"] = cpu_percent
149
150				# Memory usage
151				virtual_mem = psutil.virtual_memory()
152				mem_usage["ram/used"] = virtual_mem.used
153				mem_usage["ram/percent"] = virtual_mem.percent
154
155			# GPU information (if available)
156			if GPU_UTILS_AVAILABLE:
157				gpus = GPUtil.getGPUs()
158				for gpu in gpus:
159					gpu_id = gpu.id
160					mem_usage[f"gpu:{gpu_id}/load"] = gpu.load
161					mem_usage[f"gpu:{gpu_id}/memory_used"] = gpu.memoryUsed
162					mem_usage[f"gpu:{gpu_id}/temperature"] = gpu.temperature
163		except Exception as e:
164			self.warning(f"Error getting memory usage: {e}")
165
166		return mem_usage
167
168	def spinner_task(self, **kwargs) -> LoggerSpinner:
169		"Create a spinner task. kwargs are passed to `Spinner`."
170		return LoggerSpinner(logger=self, **kwargs)
171
172	# def seq_task(self, **kwargs) -> LoggerSpinner:
173	# 	"Create a sequential task with progress bar. kwargs are passed to `tqdm`."
174	# 	return LoggerSpinner(message=message, logger=self, **kwargs)

Base class for training loggers

@abstractmethod
def debug(self, message: str, **kwargs) -> None:
@abstractmethod
def debug(self, message: str, **kwargs) -> None:
	"""log a debug message which will be saved, but not printed

	abstract: every concrete logger must override this; extra `kwargs` are
	logger-specific
	"""
	...

log a debug message which will be saved, but not printed

@abstractmethod
def message(self, message: str, **kwargs) -> None:
@abstractmethod
def message(self, message: str, **kwargs) -> None:
	"""log a progress message, which will be printed to stdout

	abstract: every concrete logger must override this; `warning` and `error`
	also route through it
	"""
	...

log a progress message, which will be printed to stdout

def warning(self, message: str, **kwargs) -> None:
def warning(self, message: str, **kwargs) -> None:
	"""log a warning message, which will be printed to stderr

	forwards to `self.message`, prefixing the text with "WARNING: " and adding
	`__warning__=True` to the keyword arguments
	"""
	text: str = f"WARNING: {message}"
	self.message(text, __warning__=True, **kwargs)

log a warning message, which will be printed to stderr

def error(self, message: str, **kwargs) -> None:
 99	def error(self, message: str, **kwargs) -> None:
100		"""log an error message"""
101		self.message(f"ERROR: {message}", __error__=True, **kwargs)

log an error message

@abstractmethod
def metrics(self, data: dict[str, typing.Any]) -> None:
@abstractmethod
def metrics(self, data: dict[str, Any]) -> None:
	"""Log a dictionary of metrics

	abstract: every concrete logger must override this
	"""
	...

Log a dictionary of metrics

@abstractmethod
def artifact( self, path: pathlib.Path, type: str, aliases: list[str] | None = None, metadata: dict | None = None) -> None:
108	@abstractmethod
109	def artifact(
110		self,
111		path: Path,
112		type: str,
113		aliases: list[str] | None = None,
114		metadata: dict | None = None,
115	) -> None:
116		"""log an artifact from a file"""
117		pass

log an artifact from a file

url: str | list[str]
119	@property
120	@abstractmethod
121	def url(self) -> str | list[str]:
122		"""Get the URL for the current logging run"""
123		pass

Get the URL for the current logging run

run_path: pathlib.Path | list[pathlib.Path]
125	@property
126	@abstractmethod
127	def run_path(self) -> Path | list[Path]:
128		"""Get the path to the current logging run"""
129		pass

Get the path to the current logging run

@abstractmethod
def flush(self) -> None:
@abstractmethod
def flush(self) -> None:
	"""Flush the logger

	abstract: every concrete logger must override this
	"""
	...

Flush the logger

@abstractmethod
def finish(self) -> None:
@abstractmethod
def finish(self) -> None:
	"""Finish logging

	abstract: every concrete logger must override this
	"""
	...

Finish logging

def get_mem_usage(self) -> dict:
def get_mem_usage(self) -> dict:
	"""collect current CPU, RAM and GPU usage as a flat dict of metrics

	keys (each group is only present when its library imported successfully):
	 - `cpu/percent`, `ram/used` (bytes), `ram/percent` -- from `psutil`
	 - `gpu:{id}/load`, `gpu:{id}/memory_used`, `gpu:{id}/temperature`
	   -- from `GPUtil`, one set per GPU returned by `GPUtil.getGPUs()`

	the `psutil` and `GPUtil` queries are isolated from each other: if one
	raises, a warning is logged via `self.warning` and whatever the other
	collected is still returned.
	"""
	mem_usage: dict = {}

	# CPU/Memory usage (if available)
	if PSUTIL_AVAILABLE:
		try:
			# NOTE: without an `interval`, psutil measures against the previous
			# call, so the very first reading in a process is a meaningless 0.0
			mem_usage["cpu/percent"] = psutil.cpu_percent()

			virtual_mem = psutil.virtual_memory()
			mem_usage["ram/used"] = virtual_mem.used
			mem_usage["ram/percent"] = virtual_mem.percent
		except Exception as e:
			self.warning(f"Error getting memory usage: {e}")

	# GPU information (if available) -- queried even if psutil failed above
	if GPU_UTILS_AVAILABLE:
		try:
			for gpu in GPUtil.getGPUs():
				gpu_id = gpu.id
				mem_usage[f"gpu:{gpu_id}/load"] = gpu.load
				mem_usage[f"gpu:{gpu_id}/memory_used"] = gpu.memoryUsed
				mem_usage[f"gpu:{gpu_id}/temperature"] = gpu.temperature
		except Exception as e:
			self.warning(f"Error getting memory usage: {e}")

	return mem_usage
def spinner_task(self, **kwargs) -> LoggerSpinner:
def spinner_task(self, **kwargs) -> LoggerSpinner:
	"""create a `LoggerSpinner` that reports its value updates to this logger

	all `kwargs` are forwarded to `Spinner`
	"""
	spinner: LoggerSpinner = LoggerSpinner(logger=self, **kwargs)
	return spinner

Create a spinner task. kwargs are passed to Spinner.