trnbl.training_manager
````python
import time
from types import TracebackType
from typing import Any, Callable, Generic, Iterable, Type, TypeVar, Generator, Sequence
from pathlib import Path
import warnings

import tqdm  # type: ignore[import-untyped]

# torch
try:
    import torch
except ImportError:
    warnings.warn("PyTorch not found, this might break things!")

# trnbl
from trnbl.loggers.base import TrainingLoggerBase
from trnbl.training_interval import TrainingInterval, CastableToTrainingInterval

# evaluation function should take a model and return some metrics
EvalFunction = Callable[["torch.nn.Module"], dict]


class TrainingManagerInitError(Exception):
    pass


T = TypeVar("T")


def wrapped_iterable(
    sequence: Sequence[T],
    manager: "TrainingManager",
    is_epoch: bool = False,
    use_tqdm: bool | None = None,
    tqdm_kwargs: dict[str, Any] | None = None,
) -> Generator[T, None, None]:
    length: int = len(sequence)

    # update the manager if it's not fully initialized
    # ------------------------------------------------------------
    if not manager.init_complete:
        if is_epoch:
            # if epoch loop, set the total epochs
            manager.epochs_total = length
        else:
            # if batch loop, set other things
            manager.batches_per_epoch = length
            try:
                manager.batch_size = sequence.batch_size  # type: ignore[attr-defined]
                manager.samples_per_epoch = len(sequence.dataset)  # type: ignore[attr-defined]
            except AttributeError as e:
                raise TrainingManagerInitError(
                    "could not get the batch size or dataset size from the dataloader passed to `TrainingManager().batch_loop()`. ",
                    "pass either a `torch.utils.data.DataLoader` ",
                    "or an iterable with a `batch_size: int` attribute and a `dataset: Iterable` attribute.",
                ) from e

        # try to compute counters and finish init of TrainingManager
        manager.try_compute_counters()

    # set up progress bar with tqdm
    # ------------------------------------------------------------
    use_tqdm = (
        use_tqdm
        if use_tqdm is not None  # do what the user says
        else is_epoch  # otherwise, use tqdm if we are the epoch loop
    )

    if use_tqdm:
        # tqdm kwargs with defaults
        _tqdm_kwargs: dict[str, Any] = dict(
            desc="training run"
            if is_epoch
            else f"epoch {manager.epochs + 1}/{manager.epochs_total}",
            unit=" epochs" if is_epoch else " batches",
            total=length,
        )
        if tqdm_kwargs is not None:
            _tqdm_kwargs.update(tqdm_kwargs)

        # wrap with tqdm
        sequence = tqdm.tqdm(sequence, **_tqdm_kwargs)

    # yield the items, and update the manager
    # ------------------------------------------------------------
    for item in sequence:
        yield item
        if is_epoch:
            manager.epoch_update()
        # no need to call batch_update, since the user has to call batch_update to log metrics


TLogger = TypeVar("TLogger", bound=TrainingLoggerBase)


class TrainingManager(Generic[TLogger]):
    """context manager for training a model, with logging, evals, and checkpoints

    # Parameters:
    - `model : torch.nn.Module`
        ref to model being trained - used for saving checkpoints
    - `dataloader : torch.utils.data.DataLoader`
        ref to dataloader being used - used for calculating training progress. may be omitted if you wrap your batch loop with `TrainingManager.batch_loop`
        (defaults to `None`)
    - `logger : TrainingLoggerBase`
        logger, which can be local or interface with wandb.
    - `epochs_total : int | None`
        number of epochs to train for. may be omitted if you wrap your epoch loop with `TrainingManager.epoch_loop`
        (defaults to `None`)
    - `evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None`
        list of pairs of (interval, eval_fn) to run evals on the model. See `TrainingInterval` for interval options.
        (defaults to `None`)
    - `checkpoint_interval : TrainingInterval | str`
        interval at which to save model checkpoints
        (defaults to `TrainingInterval(1, "epochs")`)
    - `print_metrics_interval : TrainingInterval | str`
        interval at which to print metrics
        (defaults to `TrainingInterval(0.1, "runs")`)
    - `save_model : Callable[[torch.nn.Module, Path], None]`
        function to save the model
        (defaults to `torch.save`)
    - `model_save_path : str`
        format string for saving model checkpoints. uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
        (defaults to `"{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"`)
    - `model_save_path_special : str`
        format string for saving special model checkpoints (final, exception, etc.). uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
        (defaults to `"{run_path}/model.{alias}.pt"`)

    # Usage:
    ```python
    with TrainingManager(
        model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=500,
        evals={
            "1 epochs": eval_func,
            "0.1 runs": lambda model: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="50 epochs",
    ) as tp:

        # Training loop
        model.train()
        for epoch in range(500):
            for inputs, targets in TRAIN_LOADER:
                # the usual
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # compute accuracy
                accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

                # log metrics
                tp.batch_update(
                    # pass in number of samples in your batch (or it will be inferred from the batch size)
                    samples=len(targets),
                    # any other metrics you compute every loop
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )
                # batch_update will automatically run evals and save checkpoints as needed

            tp.epoch_update()
    ```

    """

    def __init__(
        self,
        model: "torch.nn.Module",
        logger: TLogger,
        # required if you don't wrap the loops
        dataloader: "torch.utils.data.DataLoader|None" = None,
        epochs_total: int | None = None,
        save_model: Callable[["torch.nn.Module", Path], None] = torch.save,
        # everything with intervals
        evals: Iterable[tuple[CastableToTrainingInterval, EvalFunction]] | None = None,
        checkpoint_interval: CastableToTrainingInterval = TrainingInterval(1, "epochs"),
        print_metrics_interval: CastableToTrainingInterval = TrainingInterval(
            0.1, "runs"
        ),
        # everything with paths
        model_save_path: str = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt",
        model_save_path_special: str = "{run_path}/model.{alias}.pt",
    ):
        # save start time
        self.start_time: float = time.time()
        # non path and non-interval args get copied over directly
        self.model: "torch.nn.Module" = model
        self.logger: TLogger = logger
        self.save_model: Callable[["torch.nn.Module", Path], None] = save_model

        self.logger.message("starting training manager initialization")

        # model save paths
        self.model_save_path: str = model_save_path
        self.model_save_path_special: str = model_save_path_special

        # temp intervals for processing later in `try_compute_counters`
        self._evals: Iterable[tuple[TrainingInterval, EvalFunction]]
        if evals is None:
            self._evals = []
        else:
            self._evals = [
                (TrainingInterval.from_any(interval), eval_fn)
                for interval, eval_fn in evals
            ]
        self._checkpoint_interval: TrainingInterval = TrainingInterval.from_any(
            checkpoint_interval
        )
        self._print_metrics_interval: TrainingInterval = TrainingInterval.from_any(
            print_metrics_interval
        )

        self.evals: list[tuple[int, EvalFunction]] = list()
        self.checkpoint_interval: int | None = None
        self.print_metrics_interval: int | None = None

        # counters for epochs, batches, samples, and checkpoints
        self.epochs: int = 0
        self.batches: int = 0
        self.samples: int = 0
        self.checkpoints: int = 0

        # total numbers of epochs, batches, and samples
        # pass via init kwarg or wrapped epochs loop
        self.epochs_total: int | None = epochs_total
        # from dataloader or dataloader in wrapped loop
        self.batches_per_epoch: int | None = None
        self.batch_size: int | None = None
        self.samples_per_epoch: int | None = None
        # computed dynamically from the above
        self.batches_total: int | None = None
        self.samples_total: int | None = None

        # whether the init is finished
        self.init_complete: bool = False

        # if we have a dataloader, we can compute some of the above
        if dataloader is not None:
            self.batches_per_epoch = len(dataloader)
            self.batch_size = dataloader.batch_size
            self.samples_per_epoch = len(dataloader.dataset)  # type: ignore[arg-type]

        self.try_compute_counters()

    def try_compute_counters(self) -> None:
        # we depend on either the TrainingManager init or the wrapped loops
        # getting the epochs_total and dataloader
        # everything else is computed dynamically

        if any(
            x is None
            for x in [
                self.epochs_total,
                self.batches_per_epoch,
                self.batch_size,
                self.samples_per_epoch,
            ]
        ):
            # if we don't have all the info we need, return early
            return

        # we can safely ignore type check here since we just checked for `None`
        self.batches_total = self.batches_per_epoch * self.epochs_total  # type: ignore[operator]
        self.samples_total = self.samples_per_epoch * self.epochs_total  # type: ignore[operator]

        # check if the dataloader has a finite nonzero length
        if self.samples_per_epoch == 0:
            raise TrainingManagerInitError(
                f"Dataloader has no samples. Please provide a dataloader with a non-zero length. {self.samples_per_epoch = }"
            )

        if self.batches_per_epoch == 0:
            raise TrainingManagerInitError(
                f"Dataloader has no batches. Please provide a dataloader with a non-zero length. {self.batches_per_epoch = }"
            )

        if self.batch_size == 0:
            raise TrainingManagerInitError(
                f"Dataloader has a batch size of 0. Please provide a dataloader with a non-zero batch size. {self.batch_size = }"
            )

        if self.batch_size is None:
            warnings.warn(
                "batch size is None. This is likely because the dataloader passed to `TrainingManager` does not have a `batch_size` attribute."
                + "\nthis should probably be an exception"
            )

        # normalize intervals for checkpoints, metrics printing, and evals
        _batch_info_kwargs: dict[str, int | None] = dict(
            batches_per_epoch=self.batches_per_epoch,
            batchsize=self.batch_size,
            epochs=self.epochs_total,
        )

        # TODO: no idea why we need `type: ignore[arg-type]` here
        self.checkpoint_interval = TrainingInterval.process_to_batches(
            interval=self._checkpoint_interval,
            **_batch_info_kwargs,  # type: ignore[arg-type]
        )
        self.print_metrics_interval = TrainingInterval.process_to_batches(
            interval=self._print_metrics_interval,
            **_batch_info_kwargs,  # type: ignore[arg-type]
        )

        # list[tuple[int, EvalFunction]]
        self.evals = [
            (
                TrainingInterval.process_to_batches(interval, **_batch_info_kwargs),  # type: ignore[arg-type]
                eval_fn,
            )
            for interval, eval_fn in self._evals
        ]

        # log this info
        self.init_complete = True
        self.logger.message(
            "initialized training manager",
            __training_manager_init__=True,
            epochs_total=self.epochs_total,
            batches_per_epoch=self.batches_per_epoch,
            batch_size=self.batch_size,
            samples_per_epoch=self.samples_per_epoch,
            samples_total=self.samples_total,
            checkpoint_interval_batches=self.checkpoint_interval,
            print_metrics_interval_batches=self.print_metrics_interval,
            model_save_path=self.model_save_path,
            model_save_path_special=self.model_save_path_special,
            **self.training_status(),
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type: Type, exc_val: Exception, exc_tb: TracebackType):
        # if error
        if exc_type is not None:
            # add exception info
            self.logger.error(
                str(exc_val),
                exc_type=str(exc_type),
                exc_val=str(exc_val),
                exc_tb=str(exc_tb),
            )
            # save the model
            self._save_checkpoint(alias="exception")

            # close the logger
            self.logger.finish()
        else:
            # if no error, log and save the final model
            self.logger.message(
                "training complete",
                __complete__=True,
            )
            self._save_checkpoint(alias="final")
            self.logger.finish()

    def epoch_loop(
        self,
        epochs: Sequence[int],
        use_tqdm: bool = True,
        **tqdm_kwargs,
    ) -> Generator[int, None, None]:
        return wrapped_iterable(
            sequence=epochs,
            manager=self,
            is_epoch=True,
            use_tqdm=use_tqdm,
            tqdm_kwargs=tqdm_kwargs,
        )

    def batch_loop(
        self,
        batches: Sequence[int],
        use_tqdm: bool = False,
        **tqdm_kwargs,
    ) -> Generator[int, None, None]:
        return wrapped_iterable(
            sequence=batches,
            manager=self,
            is_epoch=False,
            use_tqdm=use_tqdm,
            tqdm_kwargs=tqdm_kwargs,
        )

    def check_is_initialized(self):
        if not self.init_complete:
            raise TrainingManagerInitError(
                "TrainingManager not correctly initialized. ",
                "This is likely due to failing to specify the epoch count, or failing to specify batch size/count. "
                "you must either wrap your epoch loop with `TrainingManager.epoch_loop` or specify `epochs_total`",
                "AND you must either wrap your batch loop with `TrainingManager.batch_loop` or pass a `torch.utils.data.DataLoader` to the TrainingManager constructor.",
                "please note, if not wrapping the epoch loop, you must also call `TrainingManager.epoch_update` at the end of each epoch.",
            )

    def get_elapsed_time(self) -> float:
        """return the elapsed time in seconds since the start of training"""
        return time.time() - self.start_time

    def training_status(self) -> dict[str, int | float]:
        """status of elapsed time, samples, batches, epochs, and checkpoints"""
        return dict(
            # timestamp handled in logger
            elapsed_time=self.get_elapsed_time(),
            samples=self.samples,
            batches=self.batches,
            epochs=self.epochs,
            latest_checkpoint=self.checkpoints,
        )

    def _get_format_kwargs(self) -> dict[str, str | int | float]:
        """keyword args for formatting model save paths. calls `TrainingManager.training_status`

        # Provides:
        - `run_path : str` - path where the run is being logged and artifacts are being saved
        - `elapsed_time : float` - the elapsed time in seconds since the start of training
        - `samples : int` - samples seen so far
        - `batches : int` - batches seen so far
        - `epochs : int` - the latest epoch number
        - `latest_checkpoint : int` - the latest checkpoint number

        """
        return {
            "run_path": (
                self.logger.run_path.as_posix()
                if isinstance(self.logger.run_path, Path)
                # HACK: just return the first one
                else self.logger.run_path[0].as_posix()
            ),
            **self.training_status(),
        }

    def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):
        """call this at the end of every batch. Pass `samples` or it will be inferred from the batch size, and any other metrics as kwargs

        This function will:
        - update internal counters
        - run evals as needed (based on the intervals passed)
        - log all metrics and training status
        - save a checkpoint as needed (based on the checkpoint interval)
        """
        # check init is finished
        if not self.init_complete:
            self.try_compute_counters()
            self.check_is_initialized()

        # process metrics and kwargs
        if metrics is None:
            metrics = dict()

        metrics.update(kwargs)

        # update counters
        self.batches += 1
        if samples is not None:
            self.samples += samples
        else:
            # TODO: we warn if batch size is None, but don't except
            self.samples += self.batch_size  # type: ignore[operator]

        # run evals if needed
        for interval, eval_fn in self.evals:
            if (self.batches % interval == 0) or (self.batches == self.batches_total):
                metrics.update(eval_fn(self.model))

        # log metrics & training status
        self.logger.metrics({**metrics, **self.training_status()})

        # print metrics if needed

        # save checkpoint if needed
        if self.batches % self.checkpoint_interval == 0:  # type: ignore[operator]
            self._save_checkpoint()

    def epoch_update(self):
        """call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter"""
        self.logger.debug(f"completed epoch {self.epochs + 1}/{self.epochs_total}")
        self.epochs += 1

    def _save_checkpoint(self, alias: str | None = None):
        """wrapper for saving checkpoint as an artifact to the logger and incrementing the checkpoint counter"""
        # if no alias, then it's a regular checkpoint
        no_alias: bool = alias is None
        if no_alias:
            alias = f"checkpoint-{self.checkpoints}"

        # TODO: store training hist with model?

        # put together a path
        checkpoint_path: Path
        if no_alias:
            # format the model save path for a normal checkpoint
            checkpoint_path = Path(
                self.model_save_path.format(
                    **self._get_format_kwargs(),
                    alias=alias,
                )
            )
        else:
            # for a special checkpoint, use the special path
            checkpoint_path = Path(
                self.model_save_path_special.format(
                    **self._get_format_kwargs(),
                    alias=alias,
                )
            )

        # make sure directory exists
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        # save the model
        self.save_model(self.model, checkpoint_path)

        # log the checkpoint as an artifact
        self.logger.artifact(
            checkpoint_path,
            "model",
            metadata=self.training_status(),
            aliases=[alias] if alias else None,
        )

        # increment checkpoint counter
        self.checkpoints += 1
````
`EvalFunction = Callable[["torch.nn.Module"], dict]`

an evaluation function takes the model and returns a dict of metrics; pairs of `(interval, eval_fn)` are passed to `TrainingManager` via the `evals` argument.
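For concreteness, here is a minimal sketch of a function matching this signature. It computes an average validation loss; `VAL_LOADER` is an assumed iterable of `(inputs, targets)` batches from your own script, not part of `trnbl`:

```python
import torch
import torch.nn.functional as F


def val_loss_eval(model: torch.nn.Module) -> dict:
    """hypothetical eval: average cross-entropy over an assumed VAL_LOADER"""
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for inputs, targets in VAL_LOADER:  # assumed to exist in your training script
            total += F.cross_entropy(model(inputs), targets).item()
            n += 1
    model.train()
    return {"val/loss": total / max(n, 1)}
```

Paired with an interval, this could then be passed as `evals=[("2 epochs", val_loss_eval)]`.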
class TrainingManagerInitError(builtins.Exception):

raised when a `TrainingManager` is not (or cannot be) fully initialized: for example, when neither `epochs_total` nor a wrapped epoch loop provides the epoch count, or when the batch size and dataset size cannot be determined from the dataloader.
def wrapped_iterable(sequence: Sequence[T], manager: TrainingManager, is_epoch: bool = False, use_tqdm: bool | None = None, tqdm_kwargs: dict[str, Any] | None = None) -> Generator[T, None, None]:

wraps an epoch or batch iterable for a `TrainingManager`: on first use it fills in the manager's totals (epoch count, or batches/batch size/samples per epoch), optionally adds a tqdm progress bar, and, for epoch loops, calls `epoch_update` after each epoch.
class TrainingManager(Generic[TLogger]):
context manager for training a model, with logging, evals, and checkpoints
Parameters:
- `model : torch.nn.Module`
    ref to model being trained - used for saving checkpoints
- `dataloader : torch.utils.data.DataLoader`
    ref to dataloader being used - used for calculating training progress. may be omitted if you wrap your batch loop with `TrainingManager.batch_loop`
    (defaults to `None`)
- `logger : TrainingLoggerBase`
    logger, which can be local or interface with wandb.
- `epochs_total : int | None`
    number of epochs to train for. may be omitted if you wrap your epoch loop with `TrainingManager.epoch_loop`
    (defaults to `None`)
- `evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None`
    list of pairs of (interval, eval_fn) to run evals on the model. See `TrainingInterval` for interval options.
    (defaults to `None`)
- `checkpoint_interval : TrainingInterval | str`
    interval at which to save model checkpoints
    (defaults to `TrainingInterval(1, "epochs")`)
- `print_metrics_interval : TrainingInterval | str`
    interval at which to print metrics
    (defaults to `TrainingInterval(0.1, "runs")`)
- `save_model : Callable[[torch.nn.Module, Path], None]`
    function to save the model
    (defaults to `torch.save`)
- `model_save_path : str`
    format string for saving model checkpoints. uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
    (defaults to `"{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"`)
- `model_save_path_special : str`
    format string for saving special model checkpoints (final, exception, etc.). uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
    (defaults to `"{run_path}/model.{alias}.pt"`)
Usage:
```python
with TrainingManager(
    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=500,
    evals={
        "1 epochs": eval_func,
        "0.1 runs": lambda model: logger.get_mem_usage(),
    }.items(),
    checkpoint_interval="50 epochs",
) as tp:

    # Training loop
    model.train()
    for epoch in range(500):
        for inputs, targets in TRAIN_LOADER:
            # the usual
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute accuracy
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log metrics
            tp.batch_update(
                # pass in number of samples in your batch (or it will be inferred from the batch size)
                samples=len(targets),
                # any other metrics you compute every loop
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )
            # batch_update will automatically run evals and save checkpoints as needed

        tp.epoch_update()
```
TrainingManager(model: torch.nn.Module, logger: TLogger, dataloader: torch.utils.data.DataLoader | None = None, epochs_total: int | None = None, save_model: Callable[[torch.nn.Module, Path], None] = torch.save, evals: Iterable[tuple[CastableToTrainingInterval, EvalFunction]] | None = None, checkpoint_interval: CastableToTrainingInterval = TrainingInterval(1, "epochs"), print_metrics_interval: CastableToTrainingInterval = TrainingInterval(0.1, "runs"), model_save_path: str = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt", model_save_path_special: str = "{run_path}/model.{alias}.pt")
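As an illustration of the path templates, here is a hedged sketch of constructing a manager with a custom `model_save_path`. The available template keys (`run_path`, `batches`, `latest_checkpoint`, etc.) come from `_get_format_kwargs` in the source above; `val_loss_eval` is the hypothetical eval function sketched earlier, and `model`, `logger`, and `TRAIN_LOADER` are assumed to exist in your script:

```python
with TrainingManager(
    model=model,
    logger=logger,
    dataloader=TRAIN_LOADER,  # lets the manager infer batches/samples per epoch
    epochs_total=10,
    evals=[("2 epochs", val_loss_eval)],  # intervals may be strings, tuples, or TrainingInterval
    checkpoint_interval="5 epochs",
    # hypothetical template: name regular checkpoints by batches seen rather than checkpoint index
    model_save_path="{run_path}/checkpoints/model-batch{batches}.pt",
) as tm:
    ...  # training loop as in the Usage example above
```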
def try_compute_counters(self) -> None:

attempt to finish initialization: once `epochs_total`, `batches_per_epoch`, `batch_size`, and `samples_per_epoch` are all known, compute the run totals, convert the checkpoint/print/eval intervals into batch counts, and log the final configuration. returns early (without error) if any of those values are still missing.
def epoch_loop(self, epochs: Sequence[int], use_tqdm: bool = True, **tqdm_kwargs) -> Generator[int, None, None]:

wrap your epoch iterable with this to have `epochs_total` inferred from its length, get a tqdm progress bar by default, and have `epoch_update` called automatically at the end of each epoch (see `wrapped_iterable`).

def batch_loop(self, batches: Sequence[int], use_tqdm: bool = False, **tqdm_kwargs) -> Generator[int, None, None]:

wrap your batch iterable (e.g. a `torch.utils.data.DataLoader`) with this to have `batches_per_epoch`, `batch_size`, and `samples_per_epoch` inferred; the progress bar is off by default here.
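A minimal sketch of this wrapped-loop pattern, under the assumption that `TRAIN_LOADER` is a `torch.utils.data.DataLoader` (so `batch_loop` can read its `batch_size` and `dataset`) and that `model`, `logger`, `criterion`, and `optimizer` are defined elsewhere in your script:

```python
with TrainingManager(model=model, logger=logger) as tm:
    # epochs_total is inferred from len(range(10)); the epoch loop gets a tqdm bar
    for epoch in tm.epoch_loop(range(10)):
        # batches_per_epoch, batch_size, and samples_per_epoch are read from the DataLoader
        for inputs, targets in tm.batch_loop(TRAIN_LOADER):
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            tm.batch_update(samples=len(targets), **{"train/loss": loss.item()})
        # no explicit tm.epoch_update() here; epoch_loop calls it after each epoch
```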
def check_is_initialized(self):

raises `TrainingManagerInitError` if initialization never completed (see the error message in the source above for the two ways to provide the missing information).
def get_elapsed_time(self) -> float:
return the elapsed time in seconds since the start of training
def training_status(self) -> dict[str, int | float]:
status of elapsed time, samples, batches, epochs, and checkpoints
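A small sketch of reading this mid-run; `tm` is a `TrainingManager` instance and the values in the comment are illustrative only:

```python
status = tm.training_status()
# e.g. {"elapsed_time": 12.3, "samples": 2048, "batches": 64, "epochs": 1, "latest_checkpoint": 2}
print(f"epoch {status['epochs']} | batch {status['batches']} | {status['elapsed_time']:.1f}s elapsed")
```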
def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):
call this at the end of every batch. Pass `samples` (or it will be inferred from the batch size), and any other metrics either as a dict or as kwargs.
This function will:
- update internal counters
- run evals as needed (based on the intervals passed)
- log all metrics and training status
- save a checkpoint as needed (based on the checkpoint interval)
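The calling conventions below are equivalent sketches; `tp` is the manager, and the loss/accuracy values stand in for whatever you compute in your own loop:

```python
# metrics passed as a dict
tp.batch_update(samples=len(targets), metrics={"train/loss": loss.item(), "train/acc": accuracy})

# the same metrics passed as keyword arguments
tp.batch_update(samples=len(targets), **{"train/loss": loss.item(), "train/acc": accuracy})

# samples=None falls back to the dataloader's batch size
tp.batch_update(samples=None, metrics={"train/loss": loss.item()})
```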
def epoch_update(self):
call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter