docs for trnbl v0.1.1
View Source on GitHub

trnbl


trnbl -- Training Butler

If you train a lot of models, you have probably found yourself annoyed by swapping between different loggers and fiddling with a bunch of if batch_idx % some_number == 0 statements. This package aims to fix that problem.

First, it provides a universal interface to wandb, tensorboard, and a minimal local logging solution (live demo).

  • This interface handles logging, error messages, metrics, and artifacts.
  • Swapping from one logger to another requires no changes beyond initializing the new logger and passing it instead.
  • You can even log to multiple loggers at once!

Second, it provides a TrainingManager class which handles logging, artifacts, checkpointing, evaluations, exceptions, and more, with flexibly customizable intervals.

  • Rather than specifying every interval in batches and then changing everything by hand whenever the batch size, dataset size, or number of epochs changes, you specify an interval in samples, batches, epochs, or runs. This is converted into the correct number of batches or epochs based on the current dataset and batch size (see the example after this list).

    • "1/10 runs" -- 10 times a run
    • "2.5 epochs" -- every 2 & 1/2 epochs
    • (100, "batches") -- every 100 batches
    • "10k samples" -- every 10,000 samples
  • evaluation functions are passed as (interval, function) pairs; each takes the model as an argument and returns the metrics as a dictionary

  • checkpointing is handled automatically, with an interval specified the same way as for evaluations

  • the model is saved at the end of the run; if an exception is raised, a model.exception.pt is saved instead
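
For example, here is a rough sketch of how these interval specifications resolve, using the TrainingInterval helper documented further down (the run-shape numbers are illustrative):

from trnbl import TrainingInterval

# two ways of writing an interval
TrainingInterval.from_str("2.5 epochs")      # TrainingInterval(2.5, 'epochs')
TrainingInterval.from_any((100, "batches"))  # TrainingInterval(100, 'batches')

# intervals are resolved to a concrete number of batches from the run shape
TrainingInterval.process_to_batches(
    "10k samples", batchsize=32, batches_per_epoch=1000, epochs=10
)  # 10,000 samples / 32 per batch -> 312 batches (rounded)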

Installation

pip install trnbl

Usage

Also see the notebooks/ folder:

import torch
from torch.utils.data import DataLoader
from trnbl.loggers.local import LocalLogger
from trnbl.training_manager import TrainingManager

# set up your dataset, model, optimizer, etc as usual
dataloader: DataLoader = DataLoader(my_dataset, batch_size=32)
model: torch.nn.Module = MyModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# set up a logger -- swap seamlessly between wandb, tensorboard, and local logging
logger: LocalLogger = LocalLogger(
    project="iris-demo",
    metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
    train_config=dict(
        model=str(model), optimizer=str(optimizer), criterion=str(criterion)
    ),
)

with TrainingManager(
    # pass your model and logger
    model=model,
    logger=logger,
    evals={
        # pass evaluation functions which take a model, and return a dict of metrics
        "1k samples": my_evaluation_function,
        "0.5 epochs": lambda model: logger.get_mem_usage(),
        "100 batches": my_other_eval_function,
    }.items(),
    checkpoint_interval="1/10 run", # will save a checkpoint 10 times per run
) as tr:

    # wrap the loops, and length will be automatically calculated
    # and used to figure out when to run evals, checkpoint, etc
    for epoch in tr.epoch_loop(range(120)):
        for inputs, targets in tr.batch_loop(dataloader):
            # your normal training code
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute whatever you want every batch
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log the metrics
            tr.batch_update(
                samples=len(targets),
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )

    # a `model.final.pt` checkpoint will be saved at the end of the run,
    # or a `model.exception.pt` if something crashes inside the context

LocalLogger

Intended as a minimal logging solution for local runs, for when you're too lazy to set up a new wandb project for a quick test but still want easily readable logs. It logs everything as json or jsonl files, and provides a simple web interface for viewing the data. The web interface allows you to:

  • enable or disable the visibility of individual runs
  • filter and sort runs by various stats via an interactive table
  • smooth the data and change axes scales
  • move and resize all plots and tables

You can view a live demo of the web interface at https://miv.name/trnbl/iris-demo/index.html
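
As a minimal sketch of using the LocalLogger on its own, outside of a TrainingManager (assuming the same constructor arguments as in the usage example above; the project name and metric values are illustrative):

from trnbl.loggers.local import LocalLogger

logger = LocalLogger(
    project="quick-test",
    metric_names=["train/loss"],
    train_config={"note": "standalone logger demo"},
)

logger.message("starting a quick experiment")         # printed and saved to the run log
for step in range(100):
    logger.metrics({"train/loss": 1.0 / (step + 1)})  # appended to the metrics log
logger.flush()
logger.finish()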

TODOs:

  • BUG: minifying the html/js code causes things to break?

  • frontend:

    • batch/epoch size to table in config column group
    • box to add aliases to runs
    • customizable grid snap size?
    • display the grid on the background?
  • deployment:

    • demo website for local logger
    • CI/CD for website, minification, tests, etc
    • migrate to typescript

 1"""
 2.. include:: ../README.md
 3"""
 4
 5from trnbl.training_interval import TrainingInterval, TrainingIntervalUnit
 6from trnbl.loggers.base import TrainingLoggerBase
 7from trnbl.training_manager import TrainingManager
 8
 9__all__ = [
10	"TrainingInterval",
11	"TrainingIntervalUnit",
12	"TrainingLoggerBase",
13	"TrainingManager",
14	# submodules
15	"loggers",
16	"training_interval",
17	"training_manager",
18]

@dataclass(frozen=True)
class TrainingInterval:
 47@dataclass(frozen=True)
 48class TrainingInterval:
 49	"""A training interval, which can be specified in a few different units.
 50
 51	# Attributes:
 52	- `quantity: int|float` - the quantity of the interval
 53	- `unit: TrainingIntervalUnit` - the unit of the interval, one of "runs", "epochs", "batches", or "samples"
 54
 55	# Methods:
 56	- `TrainingInterval.from_str(raw: str) -> TrainingInterval` - parse a string into a TrainingInterval object
 57	- `TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert the interval to a raw number of batches
 58	- `TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - any representation to a number of batches
 59	- `TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> None` - current interval, with units switched to batches
 60
 61	Provides methods for reading from a string or tuple, and normalizing to batches.
 62	"""
 63
 64	quantity: int | float
 65	unit: TrainingIntervalUnit
 66
 67	def __iter__(self) -> Generator[int | float | TrainingIntervalUnit, None, None]:
 68		yield self.quantity
 69		yield self.unit
 70
 71	def __getitem__(self, index: int) -> int | float | TrainingIntervalUnit:
 72		if index == 0:
 73			return self.quantity
 74		elif index == 1:
 75			return self.unit
 76		else:
 77			raise IndexError(f"invalid index {index} for TrainingInterval")
 78
 79	def __post_init__(self) -> None:
 80		try:
 81			assert isinstance(self.quantity, (int, float)), (
 82				"quantity should be an integer or float"
 83			)
 84			# TODO: Literal[...].__args__ is not defined??
 85			if self.unit not in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
 86				unit_dealised: str | None = _TRAINING_INTERVAL_UNIT_ALIASES.get(
 87					self.unit.lower(), None
 88				)
 89				if isinstance(unit_dealised, str):
 90					self.__dict__["unit"] = unit_dealised
 91				else:
 92					raise ValueError(f"invalid unit {self.unit = }")
 93
 94			assert self.unit in TrainingIntervalUnit.__args__, (  # type: ignore[attr-defined]
 95				f"invalid unit {self.unit}"
 96			)
 97		except AssertionError as e:
 98			raise AssertionError(
 99				f"Error initializing TrainingInterval\n{self}\n{e}"
100			) from e
101
102		# check values in proper ranges
103		expected_interval: Interval = _TRAINING_INTERVAL_UNITS_RANGES[self.unit]
104		if self.quantity not in expected_interval:
105			WhenIntervalLessThanBatch.process(
106				f"interval {self} has invalid quantity, expected in interval {expected_interval}, will set to closest bound if not erroring out",
107				except_cls=IntervalValueError,
108				warn_cls=IntervalValueError,
109			)
110			self.__dict__["quantity"] = expected_interval.clamp(self.quantity)
111
112		# cast if necessary
113		self.__dict__["quantity"] = _TRAINING_INTERVAL_UNITS_CAST[self.unit](
114			self.quantity
115		)
116
117	def __eq__(self, other: Any) -> bool:
118		if not isinstance(other, self.__class__):
119			raise TypeError(
120				f"invalid type {type(other)} for comparison with TrainingInterval"
121			)
122		return (
123			abs(self.quantity - other.quantity) < _EPSILON and self.unit == other.unit
124		)
125
126	def as_batch_count(
127		self,
128		batchsize: int,
129		batches_per_epoch: int,
130		epochs: int | None = None,
131	) -> int:
132		"""given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches
133
134		# Parameters:
135		 - `batchsize: int`
136		   the size of a batch
137		 - `batches_per_epoch: int`
138		   the number of batches in an epoch
139		 - `epochs: int|None`
140		   the number of epochs to run (only required if the interval is in "runs")
141
142		# Returns:
143		 - `int`
144		   the interval as a number of batches
145
146		# Raises:
147		 - `ValueError`
148		   if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR`
149		   otherwise, will warn or ignore and set the interval to 1 batch
150		 - `ValueError`
151		   if the unit is not one of "runs", "epochs", "batches", or "samples"
152
153
154		"""
155
156		output: int | float
157
158		match self.unit:
159			case "runs":
160				assert epochs is not None, (
161					"epochs must be provided to convert runs to batches"
162				)
163				output = self.quantity * epochs * batches_per_epoch
164			case "epochs":
165				output = self.quantity * batches_per_epoch
166			case "batches":
167				output = self.quantity
168			case "samples":
169				output = self.quantity / batchsize
170			case _:
171				raise ValueError(f"invalid unit {self.unit}")
172
173		# check if interval is less than 1 batch
174		if output < 1:
175			WhenIntervalLessThanBatch.process(
176				f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out",
177				except_cls=IntervalValueError,
178				warn_cls=IntervalValueError,
179			)
180			output = 1
181
182		return int(round(output))
183
184	def normalized(
185		self,
186		batchsize: int,
187		batches_per_epoch: int,
188		epochs: int | None = None,
189	) -> "TrainingInterval":
190		"""convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches"""
191		quantity: int | float = self.as_batch_count(
192			batches_per_epoch=batches_per_epoch,
193			batchsize=batchsize,
194			epochs=epochs,
195		)
196		unit: TrainingIntervalUnit = "batches"
197		return self.__class__(quantity, unit)
198
199	@classmethod
200	def from_str(cls, raw: str) -> "TrainingInterval":
201		"""parse a string into a TrainingInterval object
202
203		# Examples:
204
205		>>> TrainingInterval.from_str("5 epochs")
206		TrainingInterval(5, 'epochs')
207		>>> TrainingInterval.from_str("100 batches")
208		TrainingInterval(100, 'batches')
209		>>> TrainingInterval.from_str("0.1 runs")
210		TrainingInterval(0.1, 'runs')
211		>>> TrainingInterval.from_str("1/5 runs")
212		TrainingInterval(0.2, 'runs')
213
214		"""
215		try:
216			# remove prefix and suffix (optionally)
217			raw = raw.removeprefix("TrainingInterval(").removesuffix(")")
218
219			# process quantity
220			raw_split: list[str]
221			quantity_str: str
222			if "," in raw:
223				raw_split = raw.split(",")
224				quantity_str = ",".join(raw_split[:-1])
225			else:
226				raw_split = raw.split()
227				quantity_str = " ".join(raw_split[:-1])
228
229			quantity: int | float = str_to_numeric(quantity_str)
230
231			# process unit
232			unit: str = raw_split[-1]
233			unit.strip().strip("'\"").strip()
234
235			# unit should be one of the allowed units
236			unit_dealised: str | None
237			if unit.lower() in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
238				unit_dealised = unit.lower()
239			else:
240				unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit.lower(), None)
241			if isinstance(unit_dealised, str):
242				unit = unit_dealised
243			else:
244				raise ValueError(f"invalid unit {unit}")
245
246			assert unit in TrainingIntervalUnit.__args__  # type: ignore[attr-defined]
247		except Exception as e:
248			raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e
249
250		return cls(quantity, unit)  # type: ignore[arg-type]
251
252	@classmethod
253	def from_any(cls, *args, **kwargs) -> "TrainingInterval":
254		"""parse a string or tuple into a TrainingInterval object"""
255
256		try:
257			# no kwargs allowed
258			assert len(kwargs) == 0, "no kwargs allowed for from_any"
259
260			# split up args
261			data: Any
262			match len(args):
263				case 1:
264					data = args[0]
265				case 2:
266					data = args
267				case _:
268					raise ValueError(
269						f"invalid number of args {len(args)} for from_any: {args = }"
270					)
271
272			if isinstance(data, cls):
273				return data
274			elif isinstance(data, str):
275				return cls.from_str(data)
276			elif isinstance(data, Sequence):
277				assert len(data) == 2, (
278					f"invalid length {len(data)} for TrainingInterval: {data}"
279				)
280				quantity, unit = data
281				if isinstance(quantity, str):
282					quantity = str_to_numeric(quantity)
283				return cls(quantity, unit)
284			else:
285				raise ValueError(f"invalid type {type(data)} for TrainingInterval")
286
287		except AssertionError as e:
288			raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e
289
290	@classmethod
291	def process_to_batches(
292		cls,
293		interval: "CastableToTrainingInterval",
294		batchsize: int,
295		batches_per_epoch: int,
296		epochs: int | None = None,
297	) -> int:
298		"""directly from any representation to a number of batches"""
299
300		interval_ti: TrainingInterval = cls.from_any(interval)
301
302		return interval_ti.as_batch_count(
303			batches_per_epoch=batches_per_epoch,
304			batchsize=batchsize,
305			epochs=epochs,
306		)

A training interval, which can be specified in a few different units.

Attributes:

  • quantity: int|float - the quantity of the interval
  • unit: TrainingIntervalUnit - the unit of the interval, one of "runs", "epochs", "batches", or "samples"

Methods:

  • TrainingInterval.from_str(raw: str) -> TrainingInterval - parse a string into a TrainingInterval object
  • TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int - convert the interval to a raw number of batches
  • TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int - convert any representation directly to a number of batches
  • TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> TrainingInterval - the current interval, with units switched to batches

Provides methods for reading from a string or tuple, and normalizing to batches.

TrainingInterval( quantity: int | float, unit: Literal['runs', 'epochs', 'batches', 'samples'])
quantity: int | float
unit: Literal['runs', 'epochs', 'batches', 'samples']
def as_batch_count( self, batchsize: int, batches_per_epoch: int, epochs: int | None = None) -> int:
126	def as_batch_count(
127		self,
128		batchsize: int,
129		batches_per_epoch: int,
130		epochs: int | None = None,
131	) -> int:
132		"""given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches
133
134		# Parameters:
135		 - `batchsize: int`
136		   the size of a batch
137		 - `batches_per_epoch: int`
138		   the number of batches in an epoch
139		 - `epochs: int|None`
140		   the number of epochs to run (only required if the interval is in "runs")
141
142		# Returns:
143		 - `int`
144		   the interval as a number of batches
145
146		# Raises:
147		 - `ValueError`
148		   if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR`
149		   otherwise, will warn or ignore and set the interval to 1 batch
150		 - `ValueError`
151		   if the unit is not one of "runs", "epochs", "batches", or "samples"
152
153
154		"""
155
156		output: int | float
157
158		match self.unit:
159			case "runs":
160				assert epochs is not None, (
161					"epochs must be provided to convert runs to batches"
162				)
163				output = self.quantity * epochs * batches_per_epoch
164			case "epochs":
165				output = self.quantity * batches_per_epoch
166			case "batches":
167				output = self.quantity
168			case "samples":
169				output = self.quantity / batchsize
170			case _:
171				raise ValueError(f"invalid unit {self.unit}")
172
173		# check if interval is less than 1 batch
174		if output < 1:
175			WhenIntervalLessThanBatch.process(
176				f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out",
177				except_cls=IntervalValueError,
178				warn_cls=IntervalValueError,
179			)
180			output = 1
181
182		return int(round(output))

given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches

Parameters:

  • batchsize: int - the size of a batch
  • batches_per_epoch: int - the number of batches in an epoch
  • epochs: int|None - the number of epochs to run (only required if the interval is in "runs")

Returns:

  • int - the interval as a number of batches

Raises:

  • ValueError - if the interval is less than 1 batch and trnbl.training_interval.WhenIntervalLessThanBatch is set to muutils.errormode.ErrorMode.ERROR; otherwise, it will warn or ignore and set the interval to 1 batch
  • ValueError - if the unit is not one of "runs", "epochs", "batches", or "samples"
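
For instance, a quick worked sketch (run-shape numbers are illustrative):

TrainingInterval(2, "epochs").as_batch_count(batchsize=32, batches_per_epoch=100)
# 2 * 100 = 200 batches

TrainingInterval(640, "samples").as_batch_count(batchsize=32, batches_per_epoch=100)
# 640 / 32 = 20 batches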
def normalized( self, batchsize: int, batches_per_epoch: int, epochs: int | None = None) -> TrainingInterval:
184	def normalized(
185		self,
186		batchsize: int,
187		batches_per_epoch: int,
188		epochs: int | None = None,
189	) -> "TrainingInterval":
190		"""convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches"""
191		quantity: int | float = self.as_batch_count(
192			batches_per_epoch=batches_per_epoch,
193			batchsize=batchsize,
194			epochs=epochs,
195		)
196		unit: TrainingIntervalUnit = "batches"
197		return self.__class__(quantity, unit)

convert the units of the interval to batches, by calling as_batch_count and setting the unit to "batches"
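
A quick sketch (run-shape numbers are illustrative):

TrainingInterval(2, "epochs").normalized(batchsize=32, batches_per_epoch=100)
# -> an interval of 200 batches, i.e. TrainingInterval(200, 'batches')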

@classmethod
def from_str(cls, raw: str) -> TrainingInterval:
199	@classmethod
200	def from_str(cls, raw: str) -> "TrainingInterval":
201		"""parse a string into a TrainingInterval object
202
203		# Examples:
204
205		>>> TrainingInterval.from_str("5 epochs")
206		TrainingInterval(5, 'epochs')
207		>>> TrainingInterval.from_str("100 batches")
208		TrainingInterval(100, 'batches')
209		>>> TrainingInterval.from_str("0.1 runs")
210		TrainingInterval(0.1, 'runs')
211		>>> TrainingInterval.from_str("1/5 runs")
212		TrainingInterval(0.2, 'runs')
213
214		"""
215		try:
216			# remove prefix and suffix (optionally)
217			raw = raw.removeprefix("TrainingInterval(").removesuffix(")")
218
219			# process quantity
220			raw_split: list[str]
221			quantity_str: str
222			if "," in raw:
223				raw_split = raw.split(",")
224				quantity_str = ",".join(raw_split[:-1])
225			else:
226				raw_split = raw.split()
227				quantity_str = " ".join(raw_split[:-1])
228
229			quantity: int | float = str_to_numeric(quantity_str)
230
231			# process unit
232			unit: str = raw_split[-1]
233			unit.strip().strip("'\"").strip()
234
235			# unit should be one of the allowed units
236			unit_dealised: str | None
237			if unit.lower() in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
238				unit_dealised = unit.lower()
239			else:
240				unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit.lower(), None)
241			if isinstance(unit_dealised, str):
242				unit = unit_dealised
243			else:
244				raise ValueError(f"invalid unit {unit}")
245
246			assert unit in TrainingIntervalUnit.__args__  # type: ignore[attr-defined]
247		except Exception as e:
248			raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e
249
250		return cls(quantity, unit)  # type: ignore[arg-type]

parse a string into a TrainingInterval object

Examples:

>>> TrainingInterval.from_str("5 epochs")
TrainingInterval(5, 'epochs')
>>> TrainingInterval.from_str("100 batches")
TrainingInterval(100, 'batches')
>>> TrainingInterval.from_str("0.1 runs")
TrainingInterval(0.1, 'runs')
>>> TrainingInterval.from_str("1/5 runs")
TrainingInterval(0.2, 'runs')
@classmethod
def from_any(cls, *args, **kwargs) -> TrainingInterval:
252	@classmethod
253	def from_any(cls, *args, **kwargs) -> "TrainingInterval":
254		"""parse a string or tuple into a TrainingInterval object"""
255
256		try:
257			# no kwargs allowed
258			assert len(kwargs) == 0, "no kwargs allowed for from_any"
259
260			# split up args
261			data: Any
262			match len(args):
263				case 1:
264					data = args[0]
265				case 2:
266					data = args
267				case _:
268					raise ValueError(
269						f"invalid number of args {len(args)} for from_any: {args = }"
270					)
271
272			if isinstance(data, cls):
273				return data
274			elif isinstance(data, str):
275				return cls.from_str(data)
276			elif isinstance(data, Sequence):
277				assert len(data) == 2, (
278					f"invalid length {len(data)} for TrainingInterval: {data}"
279				)
280				quantity, unit = data
281				if isinstance(quantity, str):
282					quantity = str_to_numeric(quantity)
283				return cls(quantity, unit)
284			else:
285				raise ValueError(f"invalid type {type(data)} for TrainingInterval")
286
287		except AssertionError as e:
288			raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e

parse a string or tuple into a TrainingInterval object
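
For example, a few of the accepted forms:

TrainingInterval.from_any("1/4 runs")         # a single string
TrainingInterval.from_any(50, "batches")      # two positional args
TrainingInterval.from_any(("0.5", "epochs"))  # a (quantity, unit) tuple; the quantity may be a numeric string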

@classmethod
def process_to_batches( cls, interval: Union[str, tuple[Union[int, float, str], str], TrainingInterval], batchsize: int, batches_per_epoch: int, epochs: int | None = None) -> int:
290	@classmethod
291	def process_to_batches(
292		cls,
293		interval: "CastableToTrainingInterval",
294		batchsize: int,
295		batches_per_epoch: int,
296		epochs: int | None = None,
297	) -> int:
298		"""directly from any representation to a number of batches"""
299
300		interval_ti: TrainingInterval = cls.from_any(interval)
301
302		return interval_ti.as_batch_count(
303			batches_per_epoch=batches_per_epoch,
304			batchsize=batchsize,
305			epochs=epochs,
306		)

directly from any representation to a number of batches
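
For example (run-shape numbers are illustrative):

TrainingInterval.process_to_batches(
    "1/10 runs", batchsize=32, batches_per_epoch=100, epochs=20
)
# 0.1 runs * 20 epochs * 100 batches/epoch = 200 batches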

TrainingIntervalUnit = typing.Literal['runs', 'epochs', 'batches', 'samples']
class TrainingLoggerBase(abc.ABC):
 82class TrainingLoggerBase(ABC):
 83	"""Base class for training loggers"""
 84
 85	@abstractmethod
 86	def debug(self, message: str, **kwargs) -> None:
 87		"""log a debug message which will be saved, but not printed"""
 88		pass
 89
 90	@abstractmethod
 91	def message(self, message: str, **kwargs) -> None:
 92		"""log a progress message, which will be printed to stdout"""
 93		pass
 94
 95	def warning(self, message: str, **kwargs) -> None:
 96		"""log a warning message, which will be printed to stderr"""
 97		self.message(f"WARNING: {message}", __warning__=True, **kwargs)
 98
 99	def error(self, message: str, **kwargs) -> None:
100		"""log an error message"""
101		self.message(f"ERROR: {message}", __error__=True, **kwargs)
102
103	@abstractmethod
104	def metrics(self, data: dict[str, Any]) -> None:
105		"""Log a dictionary of metrics"""
106		pass
107
108	@abstractmethod
109	def artifact(
110		self,
111		path: Path,
112		type: str,
113		aliases: list[str] | None = None,
114		metadata: dict | None = None,
115	) -> None:
116		"""log an artifact from a file"""
117		pass
118
119	@property
120	@abstractmethod
121	def url(self) -> str | list[str]:
122		"""Get the URL for the current logging run"""
123		pass
124
125	@property
126	@abstractmethod
127	def run_path(self) -> Path | list[Path]:
128		"""Get the path to the current logging run"""
129		pass
130
131	@abstractmethod
132	def flush(self) -> None:
133		"""Flush the logger"""
134		pass
135
136	@abstractmethod
137	def finish(self) -> None:
138		"""Finish logging"""
139		pass
140
141	def get_mem_usage(self) -> dict:
142		mem_usage: dict = {}
143
144		try:
145			# CPU/Memory usage (if available)
146			if PSUTIL_AVAILABLE:
147				cpu_percent = psutil.cpu_percent()
148				mem_usage["cpu/percent"] = cpu_percent
149
150				# Memory usage
151				virtual_mem = psutil.virtual_memory()
152				mem_usage["ram/used"] = virtual_mem.used
153				mem_usage["ram/percent"] = virtual_mem.percent
154
155			# GPU information (if available)
156			if GPU_UTILS_AVAILABLE:
157				gpus = GPUtil.getGPUs()
158				for gpu in gpus:
159					gpu_id = gpu.id
160					mem_usage[f"gpu:{gpu_id}/load"] = gpu.load
161					mem_usage[f"gpu:{gpu_id}/memory_used"] = gpu.memoryUsed
162					mem_usage[f"gpu:{gpu_id}/temperature"] = gpu.temperature
163		except Exception as e:
164			self.warning(f"Error getting memory usage: {e}")
165
166		return mem_usage
167
168	def spinner_task(self, **kwargs) -> LoggerSpinner:
169		"Create a spinner task. kwargs are passed to `Spinner`."
170		return LoggerSpinner(logger=self, **kwargs)
171
172	# def seq_task(self, **kwargs) -> LoggerSpinner:
173	# 	"Create a sequential task with progress bar. kwargs are passed to `tqdm`."
174	# 	return LoggerSpinner(message=message, logger=self, **kwargs)

Base class for training loggers
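
As a minimal sketch of what a custom subclass might look like, assuming only the abstract interface shown above (the class name and behavior here are purely illustrative, not part of trnbl):

from pathlib import Path
from typing import Any

from trnbl.loggers.base import TrainingLoggerBase


class PrintLogger(TrainingLoggerBase):
	"""hypothetical logger that just prints everything to stdout"""

	def __init__(self, run_dir: Path = Path("runs/print-logger")) -> None:
		self._run_dir = run_dir

	def debug(self, message: str, **kwargs) -> None:
		pass  # "saved, but not printed" -- this sketch simply drops debug messages

	def message(self, message: str, **kwargs) -> None:
		print(message, kwargs)

	def metrics(self, data: dict[str, Any]) -> None:
		print(data)

	def artifact(
		self,
		path: Path,
		type: str,
		aliases: list[str] | None = None,
		metadata: dict | None = None,
	) -> None:
		print(f"artifact: {path} ({type}), aliases={aliases}")

	@property
	def url(self) -> str | list[str]:
		return self._run_dir.as_posix()

	@property
	def run_path(self) -> Path | list[Path]:
		return self._run_dir

	def flush(self) -> None:
		pass

	def finish(self) -> None:
		print("run finished")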

@abstractmethod
def debug(self, message: str, **kwargs) -> None:
85	@abstractmethod
86	def debug(self, message: str, **kwargs) -> None:
87		"""log a debug message which will be saved, but not printed"""
88		pass

log a debug message which will be saved, but not printed

@abstractmethod
def message(self, message: str, **kwargs) -> None:
90	@abstractmethod
91	def message(self, message: str, **kwargs) -> None:
92		"""log a progress message, which will be printed to stdout"""
93		pass

log a progress message, which will be printed to stdout

def warning(self, message: str, **kwargs) -> None:
95	def warning(self, message: str, **kwargs) -> None:
96		"""log a warning message, which will be printed to stderr"""
97		self.message(f"WARNING: {message}", __warning__=True, **kwargs)

log a warning message, which will be printed to stderr

def error(self, message: str, **kwargs) -> None:
 99	def error(self, message: str, **kwargs) -> None:
100		"""log an error message"""
101		self.message(f"ERROR: {message}", __error__=True, **kwargs)

log an error message

@abstractmethod
def metrics(self, data: dict[str, typing.Any]) -> None:
103	@abstractmethod
104	def metrics(self, data: dict[str, Any]) -> None:
105		"""Log a dictionary of metrics"""
106		pass

Log a dictionary of metrics

@abstractmethod
def artifact( self, path: pathlib.Path, type: str, aliases: list[str] | None = None, metadata: dict | None = None) -> None:
108	@abstractmethod
109	def artifact(
110		self,
111		path: Path,
112		type: str,
113		aliases: list[str] | None = None,
114		metadata: dict | None = None,
115	) -> None:
116		"""log an artifact from a file"""
117		pass

log an artifact from a file

url: str | list[str]
119	@property
120	@abstractmethod
121	def url(self) -> str | list[str]:
122		"""Get the URL for the current logging run"""
123		pass

Get the URL for the current logging run

run_path: pathlib.Path | list[pathlib.Path]
125	@property
126	@abstractmethod
127	def run_path(self) -> Path | list[Path]:
128		"""Get the path to the current logging run"""
129		pass

Get the path to the current logging run

@abstractmethod
def flush(self) -> None:
131	@abstractmethod
132	def flush(self) -> None:
133		"""Flush the logger"""
134		pass

Flush the logger

@abstractmethod
def finish(self) -> None:
136	@abstractmethod
137	def finish(self) -> None:
138		"""Finish logging"""
139		pass

Finish logging

def get_mem_usage(self) -> dict:
141	def get_mem_usage(self) -> dict:
142		mem_usage: dict = {}
143
144		try:
145			# CPU/Memory usage (if available)
146			if PSUTIL_AVAILABLE:
147				cpu_percent = psutil.cpu_percent()
148				mem_usage["cpu/percent"] = cpu_percent
149
150				# Memory usage
151				virtual_mem = psutil.virtual_memory()
152				mem_usage["ram/used"] = virtual_mem.used
153				mem_usage["ram/percent"] = virtual_mem.percent
154
155			# GPU information (if available)
156			if GPU_UTILS_AVAILABLE:
157				gpus = GPUtil.getGPUs()
158				for gpu in gpus:
159					gpu_id = gpu.id
160					mem_usage[f"gpu:{gpu_id}/load"] = gpu.load
161					mem_usage[f"gpu:{gpu_id}/memory_used"] = gpu.memoryUsed
162					mem_usage[f"gpu:{gpu_id}/temperature"] = gpu.temperature
163		except Exception as e:
164			self.warning(f"Error getting memory usage: {e}")
165
166		return mem_usage
def spinner_task(self, **kwargs) -> trnbl.loggers.base.LoggerSpinner:
168	def spinner_task(self, **kwargs) -> LoggerSpinner:
169		"Create a spinner task. kwargs are passed to `Spinner`."
170		return LoggerSpinner(logger=self, **kwargs)

Create a spinner task. kwargs are passed to Spinner.

class TrainingManager(typing.Generic[~TLogger]):
 97class TrainingManager(Generic[TLogger]):
 98	"""context manager for training a model, with logging, evals, and checkpoints
 99
100	# Parameters:
101	- `model : torch.nn.Module`
102	    ref to model being trained - used for saving checkpoints
103	- `dataloader : torch.utils.data.DataLoader`
104	    ref to dataloader being used - used for calculating training progress
105	- `logger : TrainingLoggerBase`
106	    logger, which can be local or interface with wandb.
107	- `epochs : int`
108	    number of epochs to train for
109	    (defaults to `1`)
110	- `evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None`
111	    list of pairs of (interval, eval_fn) to run evals on the model. See `TrainingInterval` for interval options.
112	    (defaults to `None`)
113	- `checkpoint_interval : TrainingInterval | str`
114	    interval at which to save model checkpoints
115	    (defaults to `TrainingInterval(1, "epochs")`)
116	- `print_metrics_interval : TrainingInterval | str`
117	    interval at which to print metrics
118	    (defaults to `TrainingInterval(0.1, "runs")`)
119	- `save_model : Callable[[torch.nn.Module, Path], None]`
120	    function to save the model (defaults to `torch.save`)
121	    (defaults to `torch.save`)
122	- `model_save_path : str`
123	    format string for saving model checkpoints. uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
124	    (defaults to `"{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"`)
125	- `model_save_path_special : str`
126	    format string for saving special model checkpoints (final, exception, etc.). uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
127	    (defaults to `"{run_path}/model.{alias}.pt"`)
128
129	# Usage:
130	```python
131	with TrainingManager(
132	    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs=500,
133	    evals={
134	        "1 epochs": eval_func,
135	        "0.1 runs": lambda model: logger.get_mem_usage(),
136	    }.items(),
137	    checkpoint_interval="50 epochs",
138	) as tp:
139
140	    # Training loop
141	    model.train()
142	    for epoch in range(epochs):
143	        for inputs, targets in TRAIN_LOADER:
144	            # the usual
145	            optimizer.zero_grad()
146	            outputs = model(inputs)
147	            loss = criterion(outputs, targets)
148	            loss.backward()
149	            optimizer.step()
150
151	            # compute accuracy
152	            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)
153
154	            # log metrics
155	            tp.batch_update(
156	                # pass in number of samples in your batch (or it will be inferred from the batch size)
157	                samples=len(targets),
158	                # any other metrics you compute every loop
159	                **{"train/loss": loss.item(), "train/acc": accuracy},
160	            )
161	            # batch_update will automatically run evals and save checkpoints as needed
162
163	        tp.epoch_update()
164	```
165
166	"""
167
168	def __init__(
169		self,
170		model: "torch.nn.Module",
171		logger: TLogger,
172		# required if you don't wrap the loops
173		dataloader: "torch.utils.data.DataLoader|None" = None,
174		epochs_total: int | None = None,
175		save_model: Callable[["torch.nn.Module", Path], None] = torch.save,
176		# everything with intervals
177		evals: Iterable[tuple[CastableToTrainingInterval, EvalFunction]] | None = None,
178		checkpoint_interval: CastableToTrainingInterval = TrainingInterval(1, "epochs"),
179		print_metrics_interval: CastableToTrainingInterval = TrainingInterval(
180			0.1, "runs"
181		),
182		# everything with paths
183		model_save_path: str = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt",
184		model_save_path_special: str = "{run_path}/model.{alias}.pt",
185	):
186		# save start time
187		self.start_time: float = time.time()
188		# non path and non-interval args get copied over directly
189		self.model: "torch.nn.Module" = model
190		self.logger: TLogger = logger
191		self.save_model: Callable[["torch.nn.Module", Path], None] = save_model
192
193		self.logger.message("starting training manager initialization")
194
195		# model save paths
196		self.model_save_path: str = model_save_path
197		self.model_save_path_special: str = model_save_path_special
198
199		# temp intervals for processing later in `try_compute_counters`
200		self._evals: Iterable[tuple[TrainingInterval, EvalFunction]]
201		if evals is None:
202			self._evals = []
203		else:
204			self._evals = [
205				(TrainingInterval.from_any(interval), eval_fn)
206				for interval, eval_fn in evals
207			]
208		self._checkpoint_interval: TrainingInterval = TrainingInterval.from_any(
209			checkpoint_interval
210		)
211		self._print_metrics_interval: TrainingInterval = TrainingInterval.from_any(
212			print_metrics_interval
213		)
214
215		self.evals: list[tuple[int, EvalFunction]] = list()
216		self.checkpoint_interval: int | None = None
217		self.print_metrics_interval: int | None = None
218
219		# counters for epochs, batches, samples, and checkpoints
220		self.epochs: int = 0
221		self.batches: int = 0
222		self.samples: int = 0
223		self.checkpoints: int = 0
224
225		# total numbers of epochs, batches, and samples
226		# pass via init kwarg or wrapped epochs loop
227		self.epochs_total: int | None = epochs_total
228		# from dataloader or dataloader in wrapped loop
229		self.batches_per_epoch: int | None = None
230		self.batch_size: int | None = None
231		self.samples_per_epoch: int | None = None
232		# computed dynamically from the above
233		self.batches_total: int | None = None
234		self.samples_total: int | None = None
235
236		# whether the init is finished
237		self.init_complete: bool = False
238
239		# if we have a dataloader, we can compute some of the above
240		if dataloader is not None:
241			self.batches_per_epoch = len(dataloader)
242			self.batch_size = dataloader.batch_size
243			self.samples_per_epoch = len(dataloader.dataset)  # type: ignore[arg-type]
244
245		self.try_compute_counters()
246
247	def try_compute_counters(self) -> None:
248		# we depend on either the TrainingManager init or the wrapped loops
249		# getting the epochs_total and dataloader
250		# everything else is computed dynamically
251
252		if any(
253			x is None
254			for x in [
255				self.epochs_total,
256				self.batches_per_epoch,
257				self.batch_size,
258				self.samples_per_epoch,
259			]
260		):
261			# if we don't have all the info we need, return early
262			return
263
264		# we can safely ignore type check here since we just checked for `None`
265		self.batches_total = self.batches_per_epoch * self.epochs_total  # type: ignore[operator]
266		self.samples_total = self.samples_per_epoch * self.epochs_total  # type: ignore[operator]
267
268		# check if the dataloader has a finite nonzero length
269		if self.samples_per_epoch == 0:
270			raise TrainingManagerInitError(
271				f"Dataloader has no samples. Please provide a dataloader with a non-zero length. {self.samples_per_epoch = }"
272			)
273
274		if self.batches_per_epoch == 0:
275			raise TrainingManagerInitError(
276				f"Dataloader has no batches. Please provide a dataloader with a non-zero length. {self.batches_per_epoch = }"
277			)
278
279		if self.batch_size == 0:
280			raise TrainingManagerInitError(
281				f"Dataloader has a batch size of 0. Please provide a dataloader with a non-zero batch size. {self.batch_size = }"
282			)
283
284		if self.batch_size is None:
285			warnings.warn(
286				"batch size is None. This is likely because the dataloader passed to `TrainingManager` does not have a `batch_size` attribute."
287				+ "\nthis should probably be an exception"
288			)
289
290		# normalize intervals for checkpoints, metrics printing, and evals
291		_batch_info_kwargs: dict[str, int | None] = dict(
292			batches_per_epoch=self.batches_per_epoch,
293			batchsize=self.batch_size,
294			epochs=self.epochs_total,
295		)
296
297		# TODO: no idea why we need `type: ignore[arg-type]` here
298		self.checkpoint_interval = TrainingInterval.process_to_batches(
299			interval=self._checkpoint_interval,
300			**_batch_info_kwargs,  # type: ignore[arg-type]
301		)
302		self.print_metrics_interval = TrainingInterval.process_to_batches(
303			interval=self._print_metrics_interval,
304			**_batch_info_kwargs,  # type: ignore[arg-type]
305		)
306
307		# list[tuple[int, EvalFunction]]
308		self.evals = [
309			(
310				TrainingInterval.process_to_batches(interval, **_batch_info_kwargs),  # type: ignore[arg-type]
311				eval_fn,
312			)
313			for interval, eval_fn in self._evals
314		]
315
316		# log this info
317		self.init_complete = True
318		self.logger.message(
319			"initialized training manager",
320			__training_manager_init__=True,
321			epochs_total=self.epochs_total,
322			batches_per_epoch=self.batches_per_epoch,
323			batch_size=self.batch_size,
324			samples_per_epoch=self.samples_per_epoch,
325			samples_total=self.samples_total,
326			checkpoint_interval_batches=self.checkpoint_interval,
327			print_metrics_interval_batches=self.print_metrics_interval,
328			model_save_path=self.model_save_path,
329			model_save_path_special=self.model_save_path_special,
330			**self.training_status(),
331		)
332
333	def __enter__(self):
334		return self
335
336	def __exit__(self, exc_type: Type, exc_val: Exception, exc_tb: TracebackType):
337		# if error
338		if exc_type is not None:
339			# add exception info
340			self.logger.error(
341				str(exc_val),
342				exc_type=str(exc_type),
343				exc_val=str(exc_val),
344				exc_tb=str(exc_tb),
345			)
346			# save the model
347			self._save_checkpoint(alias="exception")
348
349			# close the logger
350			self.logger.finish()
351		else:
352			# if no error, log and save the final model
353			self.logger.message(
354				"training complete",
355				__complete__=True,
356			)
357			self._save_checkpoint(alias="final")
358			self.logger.finish()
359
360	def epoch_loop(
361		self,
362		epochs: Sequence[int],
363		use_tqdm: bool = True,
364		**tqdm_kwargs,
365	) -> Generator[int, None, None]:
366		return wrapped_iterable(
367			sequence=epochs,
368			manager=self,
369			is_epoch=True,
370			use_tqdm=use_tqdm,
371			tqdm_kwargs=tqdm_kwargs,
372		)
373
374	def batch_loop(
375		self,
376		batches: Sequence[int],
377		use_tqdm: bool = False,
378		**tqdm_kwargs,
379	) -> Generator[int, None, None]:
380		return wrapped_iterable(
381			sequence=batches,
382			manager=self,
383			is_epoch=False,
384			use_tqdm=use_tqdm,
385			tqdm_kwargs=tqdm_kwargs,
386		)
387
388	def check_is_initialized(self):
389		if not self.init_complete:
390			raise TrainingManagerInitError(
391				"TrainingManager not correctly initialized. ",
392				"This is likely due to failing to specify the epoch count, or failing to specify batch size/count. "
393				"you must either wrap your epoch loop with `TrainingManager.epoch_loop` or specify `epochs_total`",
394				"AND you must either wrap your batch loop with `TrainingManager.batch_loop` or pass a `torch.utils.data.DataLoader` to the TrainingManager constructor.",
395				"please note, if not wrapping the epoch loop, you must also call `TrainingManager.epoch_update` at the end of each epoch.",
396			)
397
398	def get_elapsed_time(self) -> float:
399		"""return the elapsed time in seconds since the start of training"""
400		return time.time() - self.start_time
401
402	def training_status(self) -> dict[str, int | float]:
403		"""status of elapsed time, samples, batches, epochs, and checkpoints"""
404		return dict(
405			# timestamp handled in logger
406			elapsed_time=self.get_elapsed_time(),
407			samples=self.samples,
408			batches=self.batches,
409			epochs=self.epochs,
410			latest_checkpoint=self.checkpoints,
411		)
412
413	def _get_format_kwargs(self) -> dict[str, str | int | float]:
414		"""keyword args for formatting model save paths. calls `TrainingManager.training_status`
415
416		# Provides:
417		- `run_path : str` - path where the run is being logged and artifacts are being saved
418		- `elapsed_time : float` - the elapsed time in seconds since the start of training
419		- `samples : int` - samples seen so far
420		- `batches : int` - batches seen so far
421		- `epochs : int` - the latest epoch number
422		- `latest_checkpoint : int` - the latest checkpoint number
423
424		"""
425		return {
426			"run_path": (
427				self.logger.run_path.as_posix()
428				if isinstance(self.logger.run_path, Path)
429				# HACK: just return the first one
430				else self.logger.run_path[0].as_posix()
431			),
432			**self.training_status(),
433		}
434
435	def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):
436		"""call this at the end of every batch. Pass `samples` or it will be inferred from the batch size, and any other metrics as kwargs
437
438		This function will:
439		- update internal counters
440		- run evals as needed (based on the intervals passed)
441		- log all metrics and training status
442		- save a checkpoint as needed (based on the checkpoint interval)
443		"""
444		# check init is finished
445		if not self.init_complete:
446			self.try_compute_counters()
447			self.check_is_initialized()
448
449		# process metrics and kwargs
450		if metrics is None:
451			metrics = dict()
452
453		metrics.update(kwargs)
454
455		# update counters
456		self.batches += 1
457		if samples is not None:
458			self.samples += samples
459		else:
460			# TODO: we warn if batch size is None, but don't except
461			self.samples += self.batch_size  # type: ignore[operator]
462
463		# run evals if needed
464		for interval, eval_fn in self.evals:
465			if (self.batches % interval == 0) or (self.batches == self.batches_total):
466				metrics.update(eval_fn(self.model))
467
468		# log metrics & training status
469		self.logger.metrics({**metrics, **self.training_status()})
470
471		# print metrics if needed
472
473		# save checkpoint if needed
474		if self.batches % self.checkpoint_interval == 0:  # type: ignore[operator]
475			self._save_checkpoint()
476
477	def epoch_update(self):
478		"""call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter"""
479		self.logger.debug(f"completed epoch {self.epochs + 1}/{self.epochs_total}")
480		self.epochs += 1
481
482	def _save_checkpoint(self, alias: str | None = None):
483		"""wrapper for saving checkpoint as an artifact to the logger and incrementing the checkpoint counter"""
484		# if no alias, then it's a regular checkpoint
485		no_alias: bool = alias is None
486		if no_alias:
487			alias = f"checkpoint-{self.checkpoints}"
488
489		# TODO: store training hist with model?
490
491		# put together a path
492		checkpoint_path: Path
493		if no_alias:
494			# format the model save path for a normal checkpoint
495			checkpoint_path = Path(
496				self.model_save_path.format(
497					**self._get_format_kwargs(),
498					alias=alias,
499				)
500			)
501		else:
502			# for a special checkpoint, use the special path
503			checkpoint_path = Path(
504				self.model_save_path_special.format(
505					**self._get_format_kwargs(),
506					alias=alias,
507				)
508			)
509
510		# make sure directory exists
511		checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
512
513		# save the model
514		self.save_model(self.model, checkpoint_path)
515
516		# log the checkpoint as an artifact
517		self.logger.artifact(
518			checkpoint_path,
519			"model",
520			metadata=self.training_status(),
521			aliases=[alias] if alias else None,
522		)
523
524		# increment checkpoint counter
525		self.checkpoints += 1

context manager for training a model, with logging, evals, and checkpoints

Parameters:

  • model : torch.nn.Module - ref to the model being trained; used for saving checkpoints
  • dataloader : torch.utils.data.DataLoader - ref to the dataloader being used; used for calculating training progress
  • logger : TrainingLoggerBase - logger, which can be local or interface with wandb
  • epochs_total : int | None - number of epochs to train for; may instead be inferred by wrapping the epoch loop (defaults to None)
  • evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None - list of (interval, eval_fn) pairs to run evals on the model; see TrainingInterval for interval options (defaults to None)
  • checkpoint_interval : TrainingInterval | str - interval at which to save model checkpoints (defaults to TrainingInterval(1, "epochs"))
  • print_metrics_interval : TrainingInterval | str - interval at which to print metrics (defaults to TrainingInterval(0.1, "runs"))
  • save_model : Callable[[torch.nn.Module, Path], None] - function to save the model (defaults to torch.save)
  • model_save_path : str - format string for saving model checkpoints; uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt")
  • model_save_path_special : str - format string for saving special model checkpoints (final, exception, etc.); uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/model.{alias}.pt")

Usage:

with TrainingManager(
    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=500,
    evals={
        "1 epochs": eval_func,
        "0.1 runs": lambda model: logger.get_mem_usage(),
    }.items(),
    checkpoint_interval="50 epochs",
) as tp:

    # Training loop
    model.train()
    for epoch in range(epochs):
        for inputs, targets in TRAIN_LOADER:
            # the usual
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute accuracy
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log metrics
            tp.batch_update(
                # pass in number of samples in your batch (or it will be inferred from the batch size)
                samples=len(targets),
                # any other metrics you compute every loop
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )
            # batch_update will automatically run evals and save checkpoints as needed

        tp.epoch_update()
TrainingManager( model: torch.nn.modules.module.Module, logger: ~TLogger, dataloader: torch.utils.data.dataloader.DataLoader | None = None, epochs_total: int | None = None, save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType] = <function save>, evals: Optional[Iterable[tuple[Union[str, tuple[Union[int, float, str], str], TrainingInterval], Callable[[torch.nn.modules.module.Module], dict]]]] = None, checkpoint_interval: Union[str, tuple[Union[int, float, str], str], TrainingInterval] = TrainingInterval(quantity=1, unit='epochs'), print_metrics_interval: Union[str, tuple[Union[int, float, str], str], TrainingInterval] = TrainingInterval(quantity=0.1, unit='runs'), model_save_path: str = '{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt', model_save_path_special: str = '{run_path}/model.{alias}.pt')
168	def __init__(
169		self,
170		model: "torch.nn.Module",
171		logger: TLogger,
172		# required if you don't wrap the loops
173		dataloader: "torch.utils.data.DataLoader|None" = None,
174		epochs_total: int | None = None,
175		save_model: Callable[["torch.nn.Module", Path], None] = torch.save,
176		# everything with intervals
177		evals: Iterable[tuple[CastableToTrainingInterval, EvalFunction]] | None = None,
178		checkpoint_interval: CastableToTrainingInterval = TrainingInterval(1, "epochs"),
179		print_metrics_interval: CastableToTrainingInterval = TrainingInterval(
180			0.1, "runs"
181		),
182		# everything with paths
183		model_save_path: str = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt",
184		model_save_path_special: str = "{run_path}/model.{alias}.pt",
185	):
186		# save start time
187		self.start_time: float = time.time()
188		# non path and non-interval args get copied over directly
189		self.model: "torch.nn.Module" = model
190		self.logger: TLogger = logger
191		self.save_model: Callable[["torch.nn.Module", Path], None] = save_model
192
193		self.logger.message("starting training manager initialization")
194
195		# model save paths
196		self.model_save_path: str = model_save_path
197		self.model_save_path_special: str = model_save_path_special
198
199		# temp intervals for processing later in `try_compute_counters`
200		self._evals: Iterable[tuple[TrainingInterval, EvalFunction]]
201		if evals is None:
202			self._evals = []
203		else:
204			self._evals = [
205				(TrainingInterval.from_any(interval), eval_fn)
206				for interval, eval_fn in evals
207			]
208		self._checkpoint_interval: TrainingInterval = TrainingInterval.from_any(
209			checkpoint_interval
210		)
211		self._print_metrics_interval: TrainingInterval = TrainingInterval.from_any(
212			print_metrics_interval
213		)
214
215		self.evals: list[tuple[int, EvalFunction]] = list()
216		self.checkpoint_interval: int | None = None
217		self.print_metrics_interval: int | None = None
218
219		# counters for epochs, batches, samples, and checkpoints
220		self.epochs: int = 0
221		self.batches: int = 0
222		self.samples: int = 0
223		self.checkpoints: int = 0
224
225		# total numbers of epochs, batches, and samples
226		# pass via init kwarg or wrapped epochs loop
227		self.epochs_total: int | None = epochs_total
228		# from dataloader or dataloader in wrapped loop
229		self.batches_per_epoch: int | None = None
230		self.batch_size: int | None = None
231		self.samples_per_epoch: int | None = None
232		# computed dynamically from the above
233		self.batches_total: int | None = None
234		self.samples_total: int | None = None
235
236		# whether the init is finished
237		self.init_complete: bool = False
238
239		# if we have a dataloader, we can compute some of the above
240		if dataloader is not None:
241			self.batches_per_epoch = len(dataloader)
242			self.batch_size = dataloader.batch_size
243			self.samples_per_epoch = len(dataloader.dataset)  # type: ignore[arg-type]
244
245		self.try_compute_counters()
start_time: float
model: torch.nn.modules.module.Module
logger: ~TLogger
save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType]
model_save_path: str
model_save_path_special: str
evals: list[tuple[int, typing.Callable[[torch.nn.modules.module.Module], dict]]]
checkpoint_interval: int | None
print_metrics_interval: int | None
epochs: int
batches: int
samples: int
checkpoints: int
epochs_total: int | None
batches_per_epoch: int | None
batch_size: int | None
samples_per_epoch: int | None
batches_total: int | None
samples_total: int | None
init_complete: bool
def try_compute_counters(self) -> None:
247	def try_compute_counters(self) -> None:
248		# we depend on either the TrainingManager init or the wrapped loops
249		# getting the epochs_total and dataloader
250		# everything else is computed dynamically
251
252		if any(
253			x is None
254			for x in [
255				self.epochs_total,
256				self.batches_per_epoch,
257				self.batch_size,
258				self.samples_per_epoch,
259			]
260		):
261			# if we don't have all the info we need, return early
262			return
263
264		# we can safely ignore type check here since we just checked for `None`
265		self.batches_total = self.batches_per_epoch * self.epochs_total  # type: ignore[operator]
266		self.samples_total = self.samples_per_epoch * self.epochs_total  # type: ignore[operator]
267
268		# check if the dataloader has a finite nonzero length
269		if self.samples_per_epoch == 0:
270			raise TrainingManagerInitError(
271				f"Dataloader has no samples. Please provide a dataloader with a non-zero length. {self.samples_per_epoch = }"
272			)
273
274		if self.batches_per_epoch == 0:
275			raise TrainingManagerInitError(
276				f"Dataloader has no batches. Please provide a dataloader with a non-zero length. {self.batches_per_epoch = }"
277			)
278
279		if self.batch_size == 0:
280			raise TrainingManagerInitError(
281				f"Dataloader has a batch size of 0. Please provide a dataloader with a non-zero batch size. {self.batch_size = }"
282			)
283
284		if self.batch_size is None:
285			warnings.warn(
286				"batch size is None. This is likely because the dataloader passed to `TrainingManager` does not have a `batch_size` attribute."
287				+ "\nthis should probably be an exception"
288			)
289
290		# normalize intervals for checkpoints, metrics printing, and evals
291		_batch_info_kwargs: dict[str, int | None] = dict(
292			batches_per_epoch=self.batches_per_epoch,
293			batchsize=self.batch_size,
294			epochs=self.epochs_total,
295		)
296
297		# TODO: no idea why we need `type: ignore[arg-type]` here
298		self.checkpoint_interval = TrainingInterval.process_to_batches(
299			interval=self._checkpoint_interval,
300			**_batch_info_kwargs,  # type: ignore[arg-type]
301		)
302		self.print_metrics_interval = TrainingInterval.process_to_batches(
303			interval=self._print_metrics_interval,
304			**_batch_info_kwargs,  # type: ignore[arg-type]
305		)
306
307		# list[tuple[int, EvalFunction]]
308		self.evals = [
309			(
310				TrainingInterval.process_to_batches(interval, **_batch_info_kwargs),  # type: ignore[arg-type]
311				eval_fn,
312			)
313			for interval, eval_fn in self._evals
314		]
315
316		# log this info
317		self.init_complete = True
318		self.logger.message(
319			"initialized training manager",
320			__training_manager_init__=True,
321			epochs_total=self.epochs_total,
322			batches_per_epoch=self.batches_per_epoch,
323			batch_size=self.batch_size,
324			samples_per_epoch=self.samples_per_epoch,
325			samples_total=self.samples_total,
326			checkpoint_interval_batches=self.checkpoint_interval,
327			print_metrics_interval_batches=self.print_metrics_interval,
328			model_save_path=self.model_save_path,
329			model_save_path_special=self.model_save_path_special,
330			**self.training_status(),
331		)
def epoch_loop( self, epochs: Sequence[int], use_tqdm: bool = True, **tqdm_kwargs) -> Generator[int, NoneType, NoneType]:
360	def epoch_loop(
361		self,
362		epochs: Sequence[int],
363		use_tqdm: bool = True,
364		**tqdm_kwargs,
365	) -> Generator[int, None, None]:
366		return wrapped_iterable(
367			sequence=epochs,
368			manager=self,
369			is_epoch=True,
370			use_tqdm=use_tqdm,
371			tqdm_kwargs=tqdm_kwargs,
372		)
def batch_loop( self, batches: Sequence[int], use_tqdm: bool = False, **tqdm_kwargs) -> Generator[int, NoneType, NoneType]:
374	def batch_loop(
375		self,
376		batches: Sequence[int],
377		use_tqdm: bool = False,
378		**tqdm_kwargs,
379	) -> Generator[int, None, None]:
380		return wrapped_iterable(
381			sequence=batches,
382			manager=self,
383			is_epoch=False,
384			use_tqdm=use_tqdm,
385			tqdm_kwargs=tqdm_kwargs,
386		)
def check_is_initialized(self):
388	def check_is_initialized(self):
389		if not self.init_complete:
390			raise TrainingManagerInitError(
391				"TrainingManager not correctly initialized. ",
392				"This is likely due to failing to specify the epoch count, or failing to specify batch size/count. "
393				"you must either wrap your epoch loop with `TrainingManager.epoch_loop` or specify `epochs_total`",
394				"AND you must either wrap your batch loop with `TrainingManager.batch_loop` or pass a `torch.utils.data.DataLoader` to the TrainingManager constructor.",
395				"please note, if not wrapping the epoch loop, you must also call `TrainingManager.epoch_update` at the end of each epoch.",
396			)
def get_elapsed_time(self) -> float:
398	def get_elapsed_time(self) -> float:
399		"""return the elapsed time in seconds since the start of training"""
400		return time.time() - self.start_time

return the elapsed time in seconds since the start of training

def training_status(self) -> dict[str, int | float]:
402	def training_status(self) -> dict[str, int | float]:
403		"""status of elapsed time, samples, batches, epochs, and checkpoints"""
404		return dict(
405			# timestamp handled in logger
406			elapsed_time=self.get_elapsed_time(),
407			samples=self.samples,
408			batches=self.batches,
409			epochs=self.epochs,
410			latest_checkpoint=self.checkpoints,
411		)

status of elapsed time, samples, batches, epochs, and checkpoints
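
For example, the returned dictionary looks roughly like this (values are illustrative):

tr.training_status()
# {'elapsed_time': 12.3, 'samples': 6400, 'batches': 200, 'epochs': 2, 'latest_checkpoint': 1}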

def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):
435	def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):
436		"""call this at the end of every batch. Pass `samples` or it will be inferred from the batch size, and any other metrics as kwargs
437
438		This function will:
439		- update internal counters
440		- run evals as needed (based on the intervals passed)
441		- log all metrics and training status
442		- save a checkpoint as needed (based on the checkpoint interval)
443		"""
444		# check init is finished
445		if not self.init_complete:
446			self.try_compute_counters()
447			self.check_is_initialized()
448
449		# process metrics and kwargs
450		if metrics is None:
451			metrics = dict()
452
453		metrics.update(kwargs)
454
455		# update counters
456		self.batches += 1
457		if samples is not None:
458			self.samples += samples
459		else:
460			# TODO: we warn if batch size is None, but don't except
461			self.samples += self.batch_size  # type: ignore[operator]
462
463		# run evals if needed
464		for interval, eval_fn in self.evals:
465			if (self.batches % interval == 0) or (self.batches == self.batches_total):
466				metrics.update(eval_fn(self.model))
467
468		# log metrics & training status
469		self.logger.metrics({**metrics, **self.training_status()})
470
471		# print metrics if needed
472
473		# save checkpoint if needed
474		if self.batches % self.checkpoint_interval == 0:  # type: ignore[operator]
475			self._save_checkpoint()

Call this at the end of every batch. Pass samples (or it will be inferred from the batch size) and any other metrics as kwargs.

This function will:

  • update internal counters
  • run evals as needed (based on the intervals passed)
  • log all metrics and training status
  • save a checkpoint as needed (based on the checkpoint interval)
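
For example, inside the batch loop (assuming tr, loss, accuracy, and targets as in the usage example above):

# metrics can be passed as a dict, as kwargs, or both
tr.batch_update(samples=len(targets), metrics={"train/loss": loss.item()})
tr.batch_update(samples=len(targets), **{"train/loss": loss.item(), "train/acc": accuracy})

# samples=None falls back to the dataloader's batch size
tr.batch_update(samples=None, **{"train/loss": loss.item()})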
def epoch_update(self):
477	def epoch_update(self):
478		"""call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter"""
479		self.logger.debug(f"completed epoch {self.epochs + 1}/{self.epochs_total}")
480		self.epochs += 1

call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter
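
For completeness, a sketch of the non-wrapped pattern: provide the run shape up front via dataloader and epochs_total, then call batch_update and epoch_update yourself (names reuse the usage example above):

with TrainingManager(
    model=model,
    logger=logger,
    dataloader=dataloader,  # provides batches_per_epoch, batch_size, samples_per_epoch
    epochs_total=10,        # provides the total run length
) as tr:
    for epoch in range(10):
        for inputs, targets in dataloader:
            ...  # the usual forward/backward/step
            tr.batch_update(samples=len(targets))
        tr.epoch_update()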