Coverage for trnbl\training_manager.py: 92%
146 statements
coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

import time
from types import TracebackType
from typing import Any, Callable, Generic, Iterable, Type, TypeVar, Generator, Sequence
from pathlib import Path
import warnings

import tqdm  # type: ignore[import-untyped]

# torch
try:
    import torch
except ImportError:
    warnings.warn("PyTorch not found, this might break things!")

# trnbl
from trnbl.loggers.base import TrainingLoggerBase
from trnbl.training_interval import TrainingInterval, CastableToTrainingInterval

# evaluation function should take a model and return some metrics
EvalFunction = Callable[["torch.nn.Module"], dict]


class TrainingManagerInitError(Exception):
    pass


T = TypeVar("T")


def wrapped_iterable(
    sequence: Sequence[T],
    manager: "TrainingManager",
    is_epoch: bool = False,
    use_tqdm: bool | None = None,
    tqdm_kwargs: dict[str, Any] | None = None,
) -> Generator[T, None, None]:
    length: int = len(sequence)

    # update the manager if it's not fully initialized
    # ------------------------------------------------------------
    if not manager.init_complete:
        if is_epoch:
            # if this is the epoch loop, set the total number of epochs
            manager.epochs_total = length
        else:
            # if this is the batch loop, set the batch size and per-epoch totals
            manager.batches_per_epoch = length
            try:
                manager.batch_size = sequence.batch_size  # type: ignore[attr-defined]
                manager.samples_per_epoch = len(sequence.dataset)  # type: ignore[attr-defined]
            except AttributeError as e:
                raise TrainingManagerInitError(
                    "could not get the batch size or dataset size from the dataloader passed to `TrainingManager().batch_loop()`. "
                    "pass either a `torch.utils.data.DataLoader` "
                    "or an iterable with a `batch_size: int` attribute and a `dataset: Iterable` attribute."
                ) from e

        # try to compute counters and finish init of TrainingManager
        manager.try_compute_counters()

    # set up progress bar with tqdm
    # ------------------------------------------------------------
    use_tqdm = (
        use_tqdm
        if use_tqdm is not None  # do what the user says
        else is_epoch  # otherwise, use tqdm if we are the epoch loop
    )

    if use_tqdm:
        # tqdm kwargs with defaults
        _tqdm_kwargs: dict[str, Any] = dict(
            desc="training run"
            if is_epoch
            else f"epoch {manager.epochs + 1}/{manager.epochs_total}",
            unit=" epochs" if is_epoch else " batches",
            total=length,
        )
        if tqdm_kwargs is not None:
            _tqdm_kwargs.update(tqdm_kwargs)

        # wrap with tqdm
        sequence = tqdm.tqdm(sequence, **_tqdm_kwargs)

    # yield the items, and update the manager
    # ------------------------------------------------------------
    for item in sequence:
        yield item
        if is_epoch:
            manager.epoch_update()
        # batches are not counted here: the user calls `batch_update` to log
        # metrics, and that call advances the batch counter
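
# illustrative note (assumption, not part of the library): `TrainingManager.batch_loop`
# works with a plain `torch.utils.data.DataLoader`, or with any duck-typed loader that
# exposes the attributes read above, roughly:
#
#     class MyLoader:                        # hypothetical stand-in
#         batch_size: int = 32               # read by wrapped_iterable
#         dataset = my_dataset               # anything with a __len__
#         def __len__(self): ...             # number of batches per epoch
#         def __iter__(self): ...            # yields batches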


TLogger = TypeVar("TLogger", bound=TrainingLoggerBase)


class TrainingManager(Generic[TLogger]):
97 """context manager for training a model, with logging, evals, and checkpoints
99 # Parameters:
100 - `model : torch.nn.Module`
101 ref to model being trained - used for saving checkpoints
102 - `dataloader : torch.utils.data.DataLoader`
103 ref to dataloader being used - used for calculating training progress
104 - `logger : TrainingLoggerBase`
105 logger, which can be local or interface with wandb.
106 - `epochs : int`
107 number of epochs to train for
108 (defaults to `1`)
109 - `evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None`
110 list of pairs of (interval, eval_fn) to run evals on the model. See `TrainingInterval` for interval options.
111 (defaults to `None`)
112 - `checkpoint_interval : TrainingInterval | str`
113 interval at which to save model checkpoints
114 (defaults to `TrainingInterval(1, "epochs")`)
115 - `print_metrics_interval : TrainingInterval | str`
116 interval at which to print metrics
117 (defaults to `TrainingInterval(0.1, "runs")`)
118 - `save_model : Callable[[torch.nn.Module, Path], None]`
119 function to save the model (defaults to `torch.save`)
120 (defaults to `torch.save`)
121 - `model_save_path : str`
122 format string for saving model checkpoints. uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
123 (defaults to `"{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"`)
124 - `model_save_path_special : str`
125 format string for saving special model checkpoints (final, exception, etc.). uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
126 (defaults to `"{run_path}/model.{alias}.pt"`)
128 # Usage:
129 ```python
130 with TrainingManager(
131 model=model, dataloader=TRAIN_LOADER, logger=logger, epochs=500,
132 evals={
133 "1 epochs": eval_func,
134 "0.1 runs": lambda model: logger.get_mem_usage(),
135 }.items(),
136 checkpoint_interval="50 epochs",
137 ) as tp:
139 # Training loop
140 model.train()
141 for epoch in range(epochs):
142 for inputs, targets in TRAIN_LOADER:
143 # the usual
144 optimizer.zero_grad()
145 outputs = model(inputs)
146 loss = criterion(outputs, targets)
147 loss.backward()
148 optimizer.step()
150 # compute accuracy
151 accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)
153 # log metrics
154 tp.batch_update(
155 # pass in number of samples in your batch (or it will be inferred from the batch size)
156 samples=len(targets),
157 # any other metrics you compute every loop
158 **{"train/loss": loss.item(), "train/acc": accuracy},
159 )
160 # batch_update will automatically run evals and save checkpoints as needed
162 tp.epoch_update()
163 ```
165 """

    def __init__(
        self,
        model: "torch.nn.Module",
        logger: TLogger,
        # required if you don't wrap the loops
        dataloader: "torch.utils.data.DataLoader|None" = None,
        epochs_total: int | None = None,
        save_model: Callable[["torch.nn.Module", Path], None] = torch.save,
        # everything with intervals
        evals: Iterable[tuple[CastableToTrainingInterval, EvalFunction]] | None = None,
        checkpoint_interval: CastableToTrainingInterval = TrainingInterval(1, "epochs"),
        print_metrics_interval: CastableToTrainingInterval = TrainingInterval(
            0.1, "runs"
        ),
        # everything with paths
        model_save_path: str = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt",
        model_save_path_special: str = "{run_path}/model.{alias}.pt",
    ):
        # save start time
        self.start_time: float = time.time()
        # non-path and non-interval args get copied over directly
        self.model: "torch.nn.Module" = model
        self.logger: TLogger = logger
        self.save_model: Callable[["torch.nn.Module", Path], None] = save_model

        self.logger.message("starting training manager initialization")

        # model save paths
        self.model_save_path: str = model_save_path
        self.model_save_path_special: str = model_save_path_special

        # temp intervals, processed later in `try_compute_counters`
        self._evals: Iterable[tuple[TrainingInterval, EvalFunction]]
        if evals is None:
            self._evals = []
        else:
            self._evals = [
                (TrainingInterval.from_any(interval), eval_fn)
                for interval, eval_fn in evals
            ]
        self._checkpoint_interval: TrainingInterval = TrainingInterval.from_any(
            checkpoint_interval
        )
        self._print_metrics_interval: TrainingInterval = TrainingInterval.from_any(
            print_metrics_interval
        )

        self.evals: list[tuple[int, EvalFunction]] = list()
        self.checkpoint_interval: int | None = None
        self.print_metrics_interval: int | None = None

        # counters for epochs, batches, samples, and checkpoints
        self.epochs: int = 0
        self.batches: int = 0
        self.samples: int = 0
        self.checkpoints: int = 0

        # total numbers of epochs, batches, and samples
        # passed via init kwarg, or set by the wrapped epoch loop
        self.epochs_total: int | None = epochs_total
        # from the dataloader passed here, or from the dataloader in the wrapped batch loop
        self.batches_per_epoch: int | None = None
        self.batch_size: int | None = None
        self.samples_per_epoch: int | None = None
        # computed dynamically from the above
        self.batches_total: int | None = None
        self.samples_total: int | None = None

        # whether the init is finished
        self.init_complete: bool = False

        # if we have a dataloader, we can compute some of the above
        if dataloader is not None:
            self.batches_per_epoch = len(dataloader)
            self.batch_size = dataloader.batch_size
            self.samples_per_epoch = len(dataloader.dataset)  # type: ignore[arg-type]

        self.try_compute_counters()

    def try_compute_counters(self) -> None:
        # we depend on either the TrainingManager init or the wrapped loops
        # getting the epochs_total and dataloader
        # everything else is computed dynamically

        if any(
            x is None
            for x in [
                self.epochs_total,
                self.batches_per_epoch,
                self.batch_size,
                self.samples_per_epoch,
            ]
        ):
            # if we don't have all the info we need, return early
            return

        # we can safely ignore the type check here since we just checked for `None`
        self.batches_total = self.batches_per_epoch * self.epochs_total  # type: ignore[operator]
        self.samples_total = self.samples_per_epoch * self.epochs_total  # type: ignore[operator]

        # check that the dataloader has a finite nonzero length
        if self.samples_per_epoch == 0:
            raise TrainingManagerInitError(
                f"Dataloader has no samples. Please provide a dataloader with a non-zero length. {self.samples_per_epoch = }"
            )

        if self.batches_per_epoch == 0:
            raise TrainingManagerInitError(
                f"Dataloader has no batches. Please provide a dataloader with a non-zero length. {self.batches_per_epoch = }"
            )

        if self.batch_size == 0:
            raise TrainingManagerInitError(
                f"Dataloader has a batch size of 0. Please provide a dataloader with a non-zero batch size. {self.batch_size = }"
            )

        if self.batch_size is None:
            warnings.warn(
                "batch size is None. This is likely because the dataloader passed to `TrainingManager` does not have a `batch_size` attribute."
                + "\nthis should probably be an exception"
            )

        # normalize intervals for checkpoints, metrics printing, and evals
        _batch_info_kwargs: dict[str, int | None] = dict(
            batches_per_epoch=self.batches_per_epoch,
            batchsize=self.batch_size,
            epochs=self.epochs_total,
        )

        # TODO: no idea why we need `type: ignore[arg-type]` here
        self.checkpoint_interval = TrainingInterval.process_to_batches(
            interval=self._checkpoint_interval,
            **_batch_info_kwargs,  # type: ignore[arg-type]
        )
        self.print_metrics_interval = TrainingInterval.process_to_batches(
            interval=self._print_metrics_interval,
            **_batch_info_kwargs,  # type: ignore[arg-type]
        )

        # list[tuple[int, EvalFunction]]
        self.evals = [
            (
                TrainingInterval.process_to_batches(interval, **_batch_info_kwargs),  # type: ignore[arg-type]
                eval_fn,
            )
            for interval, eval_fn in self._evals
        ]

        # log this info
        self.init_complete = True
        self.logger.message(
            "initialized training manager",
            __training_manager_init__=True,
            epochs_total=self.epochs_total,
            batches_per_epoch=self.batches_per_epoch,
            batch_size=self.batch_size,
            samples_per_epoch=self.samples_per_epoch,
            samples_total=self.samples_total,
            checkpoint_interval_batches=self.checkpoint_interval,
            print_metrics_interval_batches=self.print_metrics_interval,
            model_save_path=self.model_save_path,
            model_save_path_special=self.model_save_path_special,
            **self.training_status(),
        )
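
    # illustrative example (assumed numbers and assumed `TrainingInterval` semantics,
    # where "runs" means a fraction of the whole training run): with
    # batches_per_epoch=100 and epochs_total=10, the run is 1000 batches long, so
    # "1 epochs" resolves to every 100 batches and the default
    # print_metrics_interval of "0.1 runs" also resolves to every 100 batches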

    def __enter__(self):
        return self

    def __exit__(self, exc_type: Type, exc_val: Exception, exc_tb: TracebackType):
        # if error
        if exc_type is not None:
            # add exception info
            self.logger.error(
                str(exc_val),
                exc_type=str(exc_type),
                exc_val=str(exc_val),
                exc_tb=str(exc_tb),
            )
            # save the model
            self._save_checkpoint(alias="exception")

            # close the logger
            self.logger.finish()
        else:
            # if no error, log and save the final model
            self.logger.message(
                "training complete",
                __complete__=True,
            )
            self._save_checkpoint(alias="final")
            self.logger.finish()

    def epoch_loop(
        self,
        epochs: Sequence[int],
        use_tqdm: bool = True,
        **tqdm_kwargs,
    ) -> Generator[int, None, None]:
        """wrap the epoch loop: sets `epochs_total` and calls `epoch_update` automatically after each epoch"""
        return wrapped_iterable(
            sequence=epochs,
            manager=self,
            is_epoch=True,
            use_tqdm=use_tqdm,
            tqdm_kwargs=tqdm_kwargs,
        )

    def batch_loop(
        self,
        batches: Sequence[int],
        use_tqdm: bool = False,
        **tqdm_kwargs,
    ) -> Generator[int, None, None]:
        """wrap the batch loop: reads the batch size and dataset length from the wrapped dataloader"""
        return wrapped_iterable(
            sequence=batches,
            manager=self,
            is_epoch=False,
            use_tqdm=use_tqdm,
            tqdm_kwargs=tqdm_kwargs,
        )

    def check_is_initialized(self):
        if not self.init_complete:
            raise TrainingManagerInitError(
                "TrainingManager not correctly initialized. "
                "This is likely due to failing to specify the epoch count, or failing to specify the batch size/count. "
                "you must either wrap your epoch loop with `TrainingManager.epoch_loop` or specify `epochs_total`, "
                "AND you must either wrap your batch loop with `TrainingManager.batch_loop` or pass a `torch.utils.data.DataLoader` to the TrainingManager constructor. "
                "please note, if not wrapping the epoch loop, you must also call `TrainingManager.epoch_update` at the end of each epoch."
            )

    def get_elapsed_time(self) -> float:
        """return the elapsed time in seconds since the start of training"""
        return time.time() - self.start_time

    def training_status(self) -> dict[str, int | float]:
        """status of elapsed time, samples, batches, epochs, and checkpoints"""
        return dict(
            # timestamp handled in logger
            elapsed_time=self.get_elapsed_time(),
            samples=self.samples,
            batches=self.batches,
            epochs=self.epochs,
            latest_checkpoint=self.checkpoints,
        )

    def _get_format_kwargs(self) -> dict[str, str | int | float]:
        """keyword args for formatting model save paths. calls `TrainingManager.training_status`

        # Provides:
        - `run_path : str` - path where the run is being logged and artifacts are being saved
        - `elapsed_time : float` - the elapsed time in seconds since the start of training
        - `samples : int` - samples seen so far
        - `batches : int` - batches seen so far
        - `epochs : int` - the latest epoch number
        - `latest_checkpoint : int` - the latest checkpoint number
        """
        return {
            "run_path": (
                self.logger.run_path.as_posix()
                if isinstance(self.logger.run_path, Path)
                # HACK: just return the first one
                else self.logger.run_path[0].as_posix()
            ),
            **self.training_status(),
        }
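
    # illustrative example (assumed values): with run_path="runs/my-run" and
    # latest_checkpoint=3, the default `model_save_path` formats to
    # "runs/my-run/checkpoints/model.checkpoint-3.pt", and the default
    # `model_save_path_special` with alias="final" formats to
    # "runs/my-run/model.final.pt"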

    def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):
        """call this at the end of every batch. Pass `samples` (or it will be inferred from the batch size) and any other metrics as kwargs

        This function will:
        - update internal counters
        - run evals as needed (based on the intervals passed)
        - log all metrics and training status
        - save a checkpoint as needed (based on the checkpoint interval)
        """
        # check init is finished
        if not self.init_complete:
            self.try_compute_counters()
            self.check_is_initialized()

        # process metrics and kwargs
        if metrics is None:
            metrics = dict()

        metrics.update(kwargs)

        # update counters
        self.batches += 1
        if samples is not None:
            self.samples += samples
        else:
            # TODO: we warn if batch size is None, but don't except
            self.samples += self.batch_size  # type: ignore[operator]

        # run evals if needed
        for interval, eval_fn in self.evals:
            if (self.batches % interval == 0) or (self.batches == self.batches_total):
                metrics.update(eval_fn(self.model))

        # log metrics & training status
        self.logger.metrics({**metrics, **self.training_status()})

        # print metrics if needed

        # save checkpoint if needed
        if self.batches % self.checkpoint_interval == 0:  # type: ignore[operator]
            self._save_checkpoint()
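
    # illustrative cadence (assumed numbers): if "1 epochs" resolves to 100 batches,
    # an eval registered at that interval runs at batches 100, 200, ... and on the
    # final batch of the run; with checkpoint_interval="50 epochs" a checkpoint is
    # saved every 5000 batches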

    def epoch_update(self):
        """call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter"""
        self.logger.debug(f"completed epoch {self.epochs + 1}/{self.epochs_total}")
        self.epochs += 1

    def _save_checkpoint(self, alias: str | None = None):
        """wrapper for saving checkpoint as an artifact to the logger and incrementing the checkpoint counter"""
        # if no alias, then it's a regular checkpoint
        no_alias: bool = alias is None
        if no_alias:
            alias = f"checkpoint-{self.checkpoints}"

        # TODO: store training hist with model?

        # put together a path
        checkpoint_path: Path
        if no_alias:
            # format the model save path for a normal checkpoint
            checkpoint_path = Path(
                self.model_save_path.format(
                    **self._get_format_kwargs(),
                    alias=alias,
                )
            )
        else:
            # for a special checkpoint, use the special path
            checkpoint_path = Path(
                self.model_save_path_special.format(
                    **self._get_format_kwargs(),
                    alias=alias,
                )
            )

        # make sure directory exists
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        # save the model
        self.save_model(self.model, checkpoint_path)

        # log the checkpoint as an artifact
        self.logger.artifact(
            checkpoint_path,
            "model",
            metadata=self.training_status(),
            aliases=[alias] if alias else None,
        )

        # increment checkpoint counter
        self.checkpoints += 1
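

# ----------------------------------------------------------------------------
# illustrative usage sketch (not part of the library API): the wrapped-loop
# style, as an alternative to the manual loop shown in the `TrainingManager`
# docstring. the function name `_example_wrapped_loops`, the model, the
# dataloader, and the concrete logger instance are all placeholders assumed to
# be supplied by the caller.
def _example_wrapped_loops(
    model: "torch.nn.Module",
    dataloader: "torch.utils.data.DataLoader",
    logger: TrainingLoggerBase,
    epochs: int = 10,
) -> None:
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())

    # no `epochs_total` or `dataloader` passed here -- both are inferred from
    # the wrapped loops below
    with TrainingManager(model=model, logger=logger) as tm:
        for _epoch in tm.epoch_loop(range(epochs)):
            for inputs, targets in tm.batch_loop(dataloader):
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # logs metrics, runs any due evals, and saves any due checkpoints
                tm.batch_update(
                    samples=len(targets),
                    **{"train/loss": loss.item()},
                )
            # `epoch_loop` calls `epoch_update()` automatically after each epoch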