docs for trnbl v0.1.1
View Source on GitHub

trnbl.training_manager


  1import time
  2from types import TracebackType
  3from typing import Any, Callable, Generic, Iterable, Type, TypeVar, Generator, Sequence
  4from pathlib import Path
  5import warnings
  6
  7import tqdm  # type: ignore[import-untyped]
  8
  9# torch
 10try:
 11	import torch
 12except ImportError:
 13	warnings.warn("PyTorch not found, this might break things!")
 14
 15# trnbl
 16from trnbl.loggers.base import TrainingLoggerBase
 17from trnbl.training_interval import TrainingInterval, CastableToTrainingInterval
 18
 19# evaluation function should take a model and return some metrics
 20EvalFunction = Callable[["torch.nn.Module"], dict]
 21
 22
 23class TrainingManagerInitError(Exception):
 24	pass
 25
 26
 27T = TypeVar("T")
 28
 29
 30def wrapped_iterable(
 31	sequence: Sequence[T],
 32	manager: "TrainingManager",
 33	is_epoch: bool = False,
 34	use_tqdm: bool | None = None,
 35	tqdm_kwargs: dict[str, Any] | None = None,
 36) -> Generator[T, None, None]:
 37	length: int = len(sequence)
 38
 39	# update the manager if it's not fully initialized
 40	# ------------------------------------------------------------
 41	if not manager.init_complete:
 42		if is_epoch:
 43			# if epoch loop, set the total epochs
 44			manager.epochs_total = length
 45		else:
 46			# if batch loop, set other things
 47			manager.batches_per_epoch = length
 48			try:
 49				manager.batch_size = sequence.batch_size  # type: ignore[attr-defined]
 50				manager.samples_per_epoch = len(sequence.dataset)  # type: ignore[attr-defined]
 51			except AttributeError as e:
 52				raise TrainingManagerInitError(
 53					"could not get the batch size or dataset size from the dataloader passed to `TrainingManager().batch_loop()`. "
 54					"pass either a `torch.utils.data.DataLoader` "
 55					"or an iterable with a `batch_size: int` attribute and a `dataset: Iterable` attribute."
 56				) from e
 57
 58		# try to compute counters and finish init of TrainingManager
 59		manager.try_compute_counters()
 60
 61	# set up progress bar with tqdm
 62	# ------------------------------------------------------------
 63	use_tqdm = (
 64		use_tqdm
 65		if use_tqdm is not None  # do what the user says
 66		else is_epoch  # otherwise, use tqdm if we are the epoch loop
 67	)
 68
 69	if use_tqdm:
 70		# tqdm kwargs with defaults
 71		_tqdm_kwargs: dict[str, Any] = dict(
 72			desc="training run"
 73			if is_epoch
 74			else f"epoch {manager.epochs + 1}/{manager.epochs_total}",
 75			unit=" epochs" if is_epoch else " batches",
 76			total=length,
 77		)
 78		if tqdm_kwargs is not None:
 79			_tqdm_kwargs.update(tqdm_kwargs)
 80
 81		# wrap with tqdm
 82		sequence = tqdm.tqdm(sequence, **_tqdm_kwargs)
 83
 84	# yield the items, and update the manager
 85	# ------------------------------------------------------------
 86	for item in sequence:
 87		yield item
 88		if is_epoch:
 89			manager.epoch_update()
 90		# no need to call batch_update here - the user calls it themselves to log metrics
 91
 92
 93TLogger = TypeVar("TLogger", bound=TrainingLoggerBase)
 94
 95
 96class TrainingManager(Generic[TLogger]):
 97	"""context manager for training a model, with logging, evals, and checkpoints
 98
 99	# Parameters:
100	- `model : torch.nn.Module`
101	    ref to model being trained - used for saving checkpoints
102	- `dataloader : torch.utils.data.DataLoader | None`
103	    ref to dataloader being used - used for calculating training progress. optional if you instead wrap your batch loop with `batch_loop` (defaults to `None`)
104	- `logger : TrainingLoggerBase`
105	    logger, which can be local or interface with wandb.
106	- `epochs_total : int | None`
107	    total number of epochs to train for. optional if you instead wrap your epoch loop with `epoch_loop`
108	    (defaults to `None`)
109	- `evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None`
110	    list of pairs of (interval, eval_fn) to run evals on the model. See `TrainingInterval` for interval options.
111	    (defaults to `None`)
112	- `checkpoint_interval : TrainingInterval | str`
113	    interval at which to save model checkpoints
114	    (defaults to `TrainingInterval(1, "epochs")`)
115	- `print_metrics_interval : TrainingInterval | str`
116	    interval at which to print metrics
117	    (defaults to `TrainingInterval(0.1, "runs")`)
118	- `save_model : Callable[[torch.nn.Module, Path], None]`
119	    function used to save the model
120	    (defaults to `torch.save`)
121	- `model_save_path : str`
122	    format string for saving model checkpoints. uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
123	    (defaults to `"{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"`)
124	- `model_save_path_special : str`
125	    format string for saving special model checkpoints (final, exception, etc.). uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
126	    (defaults to `"{run_path}/model.{alias}.pt"`)
127
128	# Usage:
129	```python
130	with TrainingManager(
131	    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=500,
132	    evals={
133	        "1 epochs": eval_func,
134	        "0.1 runs": lambda model: logger.get_mem_usage(),
135	    }.items(),
136	    checkpoint_interval="50 epochs",
137	) as tp:
138
139	    # Training loop
140	    model.train()
141	    for epoch in range(epochs):
142	        for inputs, targets in TRAIN_LOADER:
143	            # the usual
144	            optimizer.zero_grad()
145	            outputs = model(inputs)
146	            loss = criterion(outputs, targets)
147	            loss.backward()
148	            optimizer.step()
149
150	            # compute accuracy
151	            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)
152
153	            # log metrics
154	            tp.batch_update(
155	                # pass in number of samples in your batch (or it will be inferred from the batch size)
156	                samples=len(targets),
157	                # any other metrics you compute every loop
158	                **{"train/loss": loss.item(), "train/acc": accuracy},
159	            )
160	            # batch_update will automatically run evals and save checkpoints as needed
161
162	        tp.epoch_update()
163	```
164
165	"""
166
167	def __init__(
168		self,
169		model: "torch.nn.Module",
170		logger: TLogger,
171		# required if you don't wrap the loops
172		dataloader: "torch.utils.data.DataLoader|None" = None,
173		epochs_total: int | None = None,
174		save_model: Callable[["torch.nn.Module", Path], None] = torch.save,
175		# everything with intervals
176		evals: Iterable[tuple[CastableToTrainingInterval, EvalFunction]] | None = None,
177		checkpoint_interval: CastableToTrainingInterval = TrainingInterval(1, "epochs"),
178		print_metrics_interval: CastableToTrainingInterval = TrainingInterval(
179			0.1, "runs"
180		),
181		# everything with paths
182		model_save_path: str = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt",
183		model_save_path_special: str = "{run_path}/model.{alias}.pt",
184	):
185		# save start time
186		self.start_time: float = time.time()
187		# non path and non-interval args get copied over directly
188		self.model: "torch.nn.Module" = model
189		self.logger: TLogger = logger
190		self.save_model: Callable[["torch.nn.Module", Path], None] = save_model
191
192		self.logger.message("starting training manager initialization")
193
194		# model save paths
195		self.model_save_path: str = model_save_path
196		self.model_save_path_special: str = model_save_path_special
197
198		# temp intervals for processing later in `try_compute_counters`
199		self._evals: Iterable[tuple[TrainingInterval, EvalFunction]]
200		if evals is None:
201			self._evals = []
202		else:
203			self._evals = [
204				(TrainingInterval.from_any(interval), eval_fn)
205				for interval, eval_fn in evals
206			]
207		self._checkpoint_interval: TrainingInterval = TrainingInterval.from_any(
208			checkpoint_interval
209		)
210		self._print_metrics_interval: TrainingInterval = TrainingInterval.from_any(
211			print_metrics_interval
212		)
213
214		self.evals: list[tuple[int, EvalFunction]] = list()
215		self.checkpoint_interval: int | None = None
216		self.print_metrics_interval: int | None = None
217
218		# counters for epochs, batches, samples, and checkpoints
219		self.epochs: int = 0
220		self.batches: int = 0
221		self.samples: int = 0
222		self.checkpoints: int = 0
223
224		# total numbers of epochs, batches, and samples
225		# pass via init kwarg or wrapped epochs loop
226		self.epochs_total: int | None = epochs_total
227		# from dataloader or dataloader in wrapped loop
228		self.batches_per_epoch: int | None = None
229		self.batch_size: int | None = None
230		self.samples_per_epoch: int | None = None
231		# computed dynamically from the above
232		self.batches_total: int | None = None
233		self.samples_total: int | None = None
234
235		# whether the init is finished
236		self.init_complete: bool = False
237
238		# if we have a dataloader, we can compute some of the above
239		if dataloader is not None:
240			self.batches_per_epoch = len(dataloader)
241			self.batch_size = dataloader.batch_size
242			self.samples_per_epoch = len(dataloader.dataset)  # type: ignore[arg-type]
243
244		self.try_compute_counters()
245
246	def try_compute_counters(self) -> None:
247		# we depend on either the TrainingManager init or the wrapped loops
248		# getting the epochs_total and dataloader
249		# everything else is computed dynamically
250
251		if any(
252			x is None
253			for x in [
254				self.epochs_total,
255				self.batches_per_epoch,
256				self.batch_size,
257				self.samples_per_epoch,
258			]
259		):
260			# if we don't have all the info we need, return early
261			return
262
263		# we can safely ignore type check here since we just checked for `None`
264		self.batches_total = self.batches_per_epoch * self.epochs_total  # type: ignore[operator]
265		self.samples_total = self.samples_per_epoch * self.epochs_total  # type: ignore[operator]
266
267		# check if the dataloader has a finite nonzero length
268		if self.samples_per_epoch == 0:
269			raise TrainingManagerInitError(
270				f"Dataloader has no samples. Please provide a dataloader with a non-zero length. {self.samples_per_epoch = }"
271			)
272
273		if self.batches_per_epoch == 0:
274			raise TrainingManagerInitError(
275				f"Dataloader has no batches. Please provide a dataloader with a non-zero length. {self.batches_per_epoch = }"
276			)
277
278		if self.batch_size == 0:
279			raise TrainingManagerInitError(
280				f"Dataloader has a batch size of 0. Please provide a dataloader with a non-zero batch size. {self.batch_size = }"
281			)
282
283		if self.batch_size is None:
284			warnings.warn(
285				"batch size is None. This is likely because the dataloader passed to `TrainingManager` does not have a `batch_size` attribute."
286				+ "\nthis should probably be an exception"
287			)
288
289		# normalize intervals for checkpoints, metrics printing, and evals
290		_batch_info_kwargs: dict[str, int | None] = dict(
291			batches_per_epoch=self.batches_per_epoch,
292			batchsize=self.batch_size,
293			epochs=self.epochs_total,
294		)
295
296		# TODO: no idea why we need `type: ignore[arg-type]` here
297		self.checkpoint_interval = TrainingInterval.process_to_batches(
298			interval=self._checkpoint_interval,
299			**_batch_info_kwargs,  # type: ignore[arg-type]
300		)
301		self.print_metrics_interval = TrainingInterval.process_to_batches(
302			interval=self._print_metrics_interval,
303			**_batch_info_kwargs,  # type: ignore[arg-type]
304		)
305
306		# list[tuple[int, EvalFunction]]
307		self.evals = [
308			(
309				TrainingInterval.process_to_batches(interval, **_batch_info_kwargs),  # type: ignore[arg-type]
310				eval_fn,
311			)
312			for interval, eval_fn in self._evals
313		]
314
315		# log this info
316		self.init_complete = True
317		self.logger.message(
318			"initialized training manager",
319			__training_manager_init__=True,
320			epochs_total=self.epochs_total,
321			batches_per_epoch=self.batches_per_epoch,
322			batch_size=self.batch_size,
323			samples_per_epoch=self.samples_per_epoch,
324			samples_total=self.samples_total,
325			checkpoint_interval_batches=self.checkpoint_interval,
326			print_metrics_interval_batches=self.print_metrics_interval,
327			model_save_path=self.model_save_path,
328			model_save_path_special=self.model_save_path_special,
329			**self.training_status(),
330		)
331
332	def __enter__(self):
333		return self
334
335	def __exit__(self, exc_type: Type, exc_val: Exception, exc_tb: TracebackType):
336		# if error
337		if exc_type is not None:
338			# add exception info
339			self.logger.error(
340				str(exc_val),
341				exc_type=str(exc_type),
342				exc_val=str(exc_val),
343				exc_tb=str(exc_tb),
344			)
345			# save the model
346			self._save_checkpoint(alias="exception")
347
348			# close the logger
349			self.logger.finish()
350		else:
351			# if no error, log and save the final model
352			self.logger.message(
353				"training complete",
354				__complete__=True,
355			)
356			self._save_checkpoint(alias="final")
357			self.logger.finish()
358
359	def epoch_loop(
360		self,
361		epochs: Sequence[int],
362		use_tqdm: bool = True,
363		**tqdm_kwargs,
364	) -> Generator[int, None, None]:
365		return wrapped_iterable(
366			sequence=epochs,
367			manager=self,
368			is_epoch=True,
369			use_tqdm=use_tqdm,
370			tqdm_kwargs=tqdm_kwargs,
371		)
372
373	def batch_loop(
374		self,
375		batches: Sequence[int],
376		use_tqdm: bool = False,
377		**tqdm_kwargs,
378	) -> Generator[int, None, None]:
379		return wrapped_iterable(
380			sequence=batches,
381			manager=self,
382			is_epoch=False,
383			use_tqdm=use_tqdm,
384			tqdm_kwargs=tqdm_kwargs,
385		)
386
387	def check_is_initialized(self):
388		if not self.init_complete:
389			raise TrainingManagerInitError(
390				"TrainingManager not correctly initialized. "
391				"This is likely due to failing to specify the epoch count, or failing to specify batch size/count. "
392				"you must either wrap your epoch loop with `TrainingManager.epoch_loop` or specify `epochs_total`, "
393				"AND you must either wrap your batch loop with `TrainingManager.batch_loop` or pass a `torch.utils.data.DataLoader` to the TrainingManager constructor. "
394				"please note, if not wrapping the epoch loop, you must also call `TrainingManager.epoch_update` at the end of each epoch."
395			)
396
397	def get_elapsed_time(self) -> float:
398		"""return the elapsed time in seconds since the start of training"""
399		return time.time() - self.start_time
400
401	def training_status(self) -> dict[str, int | float]:
402		"""status of elapsed time, samples, batches, epochs, and checkpoints"""
403		return dict(
404			# timestamp handled in logger
405			elapsed_time=self.get_elapsed_time(),
406			samples=self.samples,
407			batches=self.batches,
408			epochs=self.epochs,
409			latest_checkpoint=self.checkpoints,
410		)
411
412	def _get_format_kwargs(self) -> dict[str, str | int | float]:
413		"""keyword args for formatting model save paths. calls `TrainingManager.training_status`
414
415		# Provides:
416		- `run_path : str` - path where the run is being logged and artifacts are being saved
417		- `elapsed_time : float` - the elapsed time in seconds since the start of training
418		- `samples : int` - samples seen so far
419		- `batches : int` - batches seen so far
420		- `epochs : int` - the latest epoch number
421		- `latest_checkpoint : int` - the latest checkpoint number
422
423		"""
424		return {
425			"run_path": (
426				self.logger.run_path.as_posix()
427				if isinstance(self.logger.run_path, Path)
428				# HACK: just return the first one
429				else self.logger.run_path[0].as_posix()
430			),
431			**self.training_status(),
432		}
433
434	def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):
435		"""call this at the end of every batch. Pass `samples` (or it will be inferred from the batch size) and any other metrics as kwargs
436
437		This function will:
438		- update internal counters
439		- run evals as needed (based on the intervals passed)
440		- log all metrics and training status
441		- save a checkpoint as needed (based on the checkpoint interval)
442		"""
443		# check init is finished
444		if not self.init_complete:
445			self.try_compute_counters()
446			self.check_is_initialized()
447
448		# process metrics and kwargs
449		if metrics is None:
450			metrics = dict()
451
452		metrics.update(kwargs)
453
454		# update counters
455		self.batches += 1
456		if samples is not None:
457			self.samples += samples
458		else:
459			# TODO: we warn if batch size is None, but don't except
460			self.samples += self.batch_size  # type: ignore[operator]
461
462		# run evals if needed
463		for interval, eval_fn in self.evals:
464			if (self.batches % interval == 0) or (self.batches == self.batches_total):
465				metrics.update(eval_fn(self.model))
466
467		# log metrics & training status
468		self.logger.metrics({**metrics, **self.training_status()})
469
470		# print metrics if needed
471
472		# save checkpoint if needed
473		if self.batches % self.checkpoint_interval == 0:  # type: ignore[operator]
474			self._save_checkpoint()
475
476	def epoch_update(self):
477		"""call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter"""
478		self.logger.debug(f"completed epoch {self.epochs + 1}/{self.epochs_total}")
479		self.epochs += 1
480
481	def _save_checkpoint(self, alias: str | None = None):
482		"""wrapper for saving checkpoint as an artifact to the logger and incrementing the checkpoint counter"""
483		# if no alias, then it's a regular checkpoint
484		no_alias: bool = alias is None
485		if no_alias:
486			alias = f"checkpoint-{self.checkpoints}"
487
488		# TODO: store training hist with model?
489
490		# put together a path
491		checkpoint_path: Path
492		if no_alias:
493			# format the model save path for a normal checkpoint
494			checkpoint_path = Path(
495				self.model_save_path.format(
496					**self._get_format_kwargs(),
497					alias=alias,
498				)
499			)
500		else:
501			# for a special checkpoint, use the special path
502			checkpoint_path = Path(
503				self.model_save_path_special.format(
504					**self._get_format_kwargs(),
505					alias=alias,
506				)
507			)
508
509		# make sure directory exists
510		checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
511
512		# save the model
513		self.save_model(self.model, checkpoint_path)
514
515		# log the checkpoint as an artifact
516		self.logger.artifact(
517			checkpoint_path,
518			"model",
519			metadata=self.training_status(),
520			aliases=[alias] if alias else None,
521		)
522
523		# increment checkpoint counter
524		self.checkpoints += 1

EvalFunction = typing.Callable[['torch.nn.Module'], dict]
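
An eval function takes the model and returns a `dict` of metrics. A minimal sketch, assuming a `val_loader` and `criterion` already exist in your training script (both are placeholders, not part of `trnbl`):

```python
import torch

def eval_func(model: torch.nn.Module) -> dict:
    """compute validation loss; `val_loader` and `criterion` are assumed to exist"""
    model.eval()
    total_loss: float = 0.0
    n_batches: int = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            total_loss += criterion(model(inputs), targets).item()
            n_batches += 1
    model.train()
    return {"val/loss": total_loss / max(n_batches, 1)}
```

Pass it to `TrainingManager` via the `evals` argument, e.g. `evals={"1 epochs": eval_func}.items()`.
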
class TrainingManagerInitError(builtins.Exception):

Raised when a `TrainingManager` cannot be fully initialized - typically because the epoch count, batch size, or dataset size could not be determined.

Inherited Members
  • builtins.Exception: Exception
  • builtins.BaseException: with_traceback, add_note, args
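
This error is raised when the manager cannot work out the loop sizes. A hedged sketch of a setup that triggers it (`model` and `logger` are assumed to exist):

```python
# neither a dataloader nor epochs_total is given, and the loops are not wrapped
manager = TrainingManager(model=model, logger=logger)

# raises TrainingManagerInitError: epochs_total, batches_per_epoch,
# and batch_size cannot be inferred
manager.batch_update(samples=32)
```
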
def wrapped_iterable( sequence: Sequence[~T], manager: TrainingManager, is_epoch: bool = False, use_tqdm: bool | None = None, tqdm_kwargs: dict[str, typing.Any] | None = None) -> Generator[~T, NoneType, NoneType]:
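
Helper generator behind `TrainingManager.epoch_loop` and `TrainingManager.batch_loop`. It uses `len(sequence)` to fill in the manager's missing totals (`epochs_total` for the epoch loop; `batches_per_epoch`, `batch_size`, and `samples_per_epoch` for the batch loop), calls `manager.try_compute_counters()` to finish initialization, optionally wraps the sequence in a `tqdm` progress bar (by default only for the epoch loop), and calls `manager.epoch_update()` after each yielded epoch. You will normally reach it through `epoch_loop` and `batch_loop` rather than calling it directly.
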
class TrainingManager(typing.Generic[~TLogger]):

context manager for training a model, with logging, evals, and checkpoints

Parameters:

  • model : torch.nn.Module ref to model being trained - used for saving checkpoints
  • dataloader : torch.utils.data.DataLoader | None ref to dataloader being used - used for calculating training progress. optional if you instead wrap your batch loop with batch_loop (defaults to None)
  • logger : TrainingLoggerBase logger, which can be local or interface with wandb.
  • epochs_total : int | None total number of epochs to train for. optional if you instead wrap your epoch loop with epoch_loop (defaults to None)
  • evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None list of pairs of (interval, eval_fn) to run evals on the model. See TrainingInterval for interval options. (defaults to None)
  • checkpoint_interval : TrainingInterval | str interval at which to save model checkpoints (defaults to TrainingInterval(1, "epochs"))
  • print_metrics_interval : TrainingInterval | str interval at which to print metrics (defaults to TrainingInterval(0.1, "runs"))
  • save_model : Callable[[torch.nn.Module, Path], None] function used to save the model (defaults to torch.save)
  • model_save_path : str format string for saving model checkpoints. uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt")
  • model_save_path_special : str format string for saving special model checkpoints (final, exception, etc.). uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/model.{alias}.pt")

Usage:

with TrainingManager(
    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=500,
    evals={
        "1 epochs": eval_func,
        "0.1 runs": lambda model: logger.get_mem_usage(),
    }.items(),
    checkpoint_interval="50 epochs",
) as tp:

    # Training loop
    model.train()
    for epoch in range(epochs):
        for inputs, targets in TRAIN_LOADER:
            # the usual
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute accuracy
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log metrics
            tp.batch_update(
                # pass in number of samples in your batch (or it will be inferred from the batch size)
                samples=len(targets),
                # any other metrics you compute every loop
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )
            # batch_update will automatically run evals and save checkpoints as needed

        tp.epoch_update()
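
If you would rather not pass `dataloader` and `epochs_total` up front, the manager can infer them when you wrap your loops with `epoch_loop` and `batch_loop`. A hedged sketch of that style, reusing the assumed `model`, `logger`, `TRAIN_LOADER`, `criterion`, and `optimizer` from above, and also showing a custom `model_save_path` built from the format kwargs documented under `_get_format_kwargs`:

```python
with TrainingManager(
    model=model,
    logger=logger,
    checkpoint_interval="50 epochs",
    model_save_path="{run_path}/checkpoints/model-e{epochs}-b{batches}.pt",
) as tp:
    model.train()
    for epoch in tp.epoch_loop(range(500)):
        for inputs, targets in tp.batch_loop(TRAIN_LOADER):
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            tp.batch_update(samples=len(targets), **{"train/loss": loss.item()})
        # no explicit tp.epoch_update() here - epoch_loop calls it after each epoch
```
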
TrainingManager(
    model: torch.nn.modules.module.Module,
    logger: ~TLogger,
    dataloader: torch.utils.data.dataloader.DataLoader | None = None,
    epochs_total: int | None = None,
    save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType] = <function save>,
    evals: Optional[Iterable[tuple[Union[str, tuple[Union[int, float, str], str], trnbl.TrainingInterval], Callable[[torch.nn.modules.module.Module], dict]]]] = None,
    checkpoint_interval: Union[str, tuple[Union[int, float, str], str], trnbl.TrainingInterval] = TrainingInterval(quantity=1, unit='epochs'),
    print_metrics_interval: Union[str, tuple[Union[int, float, str], str], trnbl.TrainingInterval] = TrainingInterval(quantity=0.1, unit='runs'),
    model_save_path: str = '{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt',
    model_save_path_special: str = '{run_path}/model.{alias}.pt',
)
start_time: float
model: torch.nn.modules.module.Module
logger: ~TLogger
save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType]
model_save_path: str
model_save_path_special: str
evals: list[tuple[int, typing.Callable[[torch.nn.modules.module.Module], dict]]]
checkpoint_interval: int | None
print_metrics_interval: int | None
epochs: int
batches: int
samples: int
checkpoints: int
epochs_total: int | None
batches_per_epoch: int | None
batch_size: int | None
samples_per_epoch: int | None
batches_total: int | None
samples_total: int | None
init_complete: bool
def try_compute_counters(self) -> None:
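
Once `epochs_total`, `batches_per_epoch`, `batch_size`, and `samples_per_epoch` are all known, every interval is normalized to a whole number of batches via `TrainingInterval.process_to_batches`. A rough worked example - the concrete numbers are assumptions, and the exact rounding is up to `TrainingInterval`:

```python
from trnbl.training_interval import TrainingInterval

# assume a run of 10 epochs, 100 batches per epoch, batch size 32
checkpoint_every = TrainingInterval.process_to_batches(
    interval=TrainingInterval.from_any("1 epochs"),
    batches_per_epoch=100,
    batchsize=32,
    epochs=10,
)
# expected to come out to 100 batches, i.e. a checkpoint at the end of every epoch;
# "0.1 runs" would similarly normalize to roughly 100 batches (a tenth of 10 * 100)
```
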
def epoch_loop( self, epochs: Sequence[int], use_tqdm: bool = True, **tqdm_kwargs) -> Generator[int, NoneType, NoneType]:
def batch_loop( self, batches: Sequence[int], use_tqdm: bool = False, **tqdm_kwargs) -> Generator[int, NoneType, NoneType]:
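
`batch_loop` accepts a `torch.utils.data.DataLoader`, or any sized iterable exposing `batch_size` and `dataset` attributes (these are what `wrapped_iterable` reads to finish initialization). A hedged sketch of a minimal custom wrapper:

```python
from dataclasses import dataclass
from typing import Iterator, Sequence

@dataclass
class SimpleBatches:
    """minimal stand-in for a DataLoader: yields fixed-size slices of `dataset`"""
    dataset: Sequence
    batch_size: int

    def __len__(self) -> int:
        # number of batches per epoch
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[Sequence]:
        for i in range(0, len(self.dataset), self.batch_size):
            yield self.dataset[i : i + self.batch_size]

# usable as: for batch in tp.batch_loop(SimpleBatches(my_data, batch_size=32)): ...
```
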
def check_is_initialized(self):
def get_elapsed_time(self) -> float:

return the elapsed time in seconds since the start of training

def training_status(self) -> dict[str, int | float]:

status of elapsed time, samples, batches, epochs, and checkpoints
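
The returned dict mirrors the manager's counters; an illustration of its shape for some initialized manager `tp` (the values here are made up):

```python
status = tp.training_status()
# e.g. {"elapsed_time": 42.7, "samples": 12800, "batches": 400,
#       "epochs": 3, "latest_checkpoint": 2}
```
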

def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):

call this at the end of every batch. Pass samples (or it will be inferred from the batch size) and any other metrics as kwargs

This function will:

  • update internal counters
  • run evals as needed (based on the intervals passed)
  • log all metrics and training status
  • save a checkpoint as needed (based on the checkpoint interval)
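
Metrics can be passed either as keyword arguments (as in the class-level usage example) or via the `metrics` dict, which is convenient for keys like `"train/loss"`. A small sketch, with `loss`, `accuracy`, and `targets` assumed to come from your training step:

```python
tp.batch_update(
    samples=len(targets),
    metrics={"train/loss": loss.item(), "train/acc": accuracy},
)
```
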
def epoch_update(self):

call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter
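
You only need to call this yourself when you do not wrap your epoch loop with `epoch_loop` (which calls it automatically). A hedged sketch of that manual style, assuming `tp` was constructed with `epochs_total=10` and that the usual `model`, `criterion`, `optimizer`, and `TRAIN_LOADER` exist:

```python
for epoch in range(10):
    for inputs, targets in tp.batch_loop(TRAIN_LOADER):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        tp.batch_update(samples=len(targets), **{"train/loss": loss.item()})
    # required here because the epoch loop is not wrapped with epoch_loop
    tp.epoch_update()
```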