# Source code for ise.utils.training

import torch
import warnings

class CheckpointSaver:
    """Save model/optimizer checkpoints while tracking the lowest loss seen.

    Args:
        model: Model whose ``state_dict`` is written to / restored from disk.
        optimizer: Optimizer whose ``state_dict`` is written to / restored
            from disk.
        checkpoint_path: Default file used by :meth:`save_checkpoint` and
            :meth:`load_checkpoint` when no explicit path is given.
        verbose: When True, each save stores a human-readable message in
            ``self.log`` and :meth:`load_checkpoint` prints a resume message.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        checkpoint_path: str,
        verbose: bool = False,
    ):
        self.checkpoint_path = checkpoint_path
        self.model = model
        self.optimizer = optimizer
        self.best_loss = float('inf')
        self.verbose = verbose
        # Message describing the most recent call; None until one is produced.
        self.log = None

    def __call__(self, loss, epoch, save_best_only=True):
        """Record ``loss`` for ``epoch`` and save a checkpoint if appropriate.

        Args:
            loss: Loss value for this epoch (lower is better).
            epoch: Epoch index stored in the checkpoint.
            save_best_only: If True, save only when ``loss`` improves on the
                best loss so far. If False, save every call. ``best_loss``
                is still only updated on improvement.

        Returns:
            True if a checkpoint was written, False otherwise.
        """
        previous_best = self.best_loss
        improved = self._determine_if_better(loss)

        if save_best_only and not improved:
            self.log = ""
            return False

        if improved:
            # Update before saving so the checkpoint's 'best_loss' includes
            # this epoch. Otherwise a resumed run would compare against a
            # stale, higher best and could overwrite the best checkpoint.
            self._update_best_loss(loss)
        try:
            self.save_checkpoint(epoch, loss, self.checkpoint_path)
        except Exception:
            # Nothing was persisted, so do not keep the in-memory improvement.
            if improved:
                self._update_best_loss(previous_best)
            raise

        if self.verbose:
            if improved:
                self.log = (
                    f"Loss decreased ({previous_best:.6f} --> {loss:.6f}). "
                    f"Saving checkpoint to {self.checkpoint_path}."
                )
            else:
                self.log = (
                    f"Loss did not improve ({previous_best:.6f} --> {loss:.6f}). "
                    f"Saving checkpoint to {self.checkpoint_path}."
                )
        return True

    def _determine_if_better(self, loss: float):
        """Return True if ``loss`` is strictly lower than the best loss so far."""
        return loss < self.best_loss

    def _update_best_loss(self, loss):
        """Overwrite the tracked best loss."""
        self.best_loss = loss

    def save_checkpoint(self, epoch, loss, path: str = None):
        """Write model/optimizer state plus bookkeeping to ``path``.

        Args:
            epoch: Epoch index stored under ``'epoch'``.
            loss: Loss of this epoch, stored under ``'loss'``.
            path: Destination file. Falls back to ``self.checkpoint_path``
                when falsy.
        """
        checkpoint_path = path or self.checkpoint_path
        checkpoint = {
            'epoch': epoch,
            'loss': loss,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_loss': self.best_loss,
        }
        torch.save(checkpoint, checkpoint_path)

    def load_checkpoint(self, path: str = None, map_location=None):
        """Restore model/optimizer state and ``best_loss`` from a checkpoint.

        Args:
            path: Source file. Falls back to ``self.checkpoint_path`` when
                falsy.
            map_location: Forwarded to ``torch.load``. For example, use
                ``'cpu'`` to load a GPU-saved checkpoint on a CPU-only host.
                ``None`` keeps torch's default behaviour.

        Returns:
            The epoch to resume from: the saved epoch + 1, or 1 if the
            checkpoint has no ``'epoch'`` entry.
        """
        checkpoint_path = path or self.checkpoint_path
        # NOTE: torch.load may unpickle arbitrary objects depending on the
        # torch version; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.best_loss = checkpoint.get('best_loss', float('inf'))
        start_epoch = checkpoint.get('epoch', 0) + 1
        if self.verbose:
            print(f"Loaded checkpoint from {checkpoint_path}, resuming from epoch {start_epoch}")
        return start_epoch
class EarlyStoppingCheckpointer(CheckpointSaver):
    """Checkpoint saver that also signals early stopping after stagnation.

    Args:
        model: Model passed through to :class:`CheckpointSaver`.
        optimizer: Optimizer passed through to :class:`CheckpointSaver`.
        checkpoint_path: Checkpoint file path.
        patience: Number of consecutive non-improving calls after which
            ``self.early_stop`` is set to True.
        verbose: When True, prints the patience counter on each
            non-improving call (and enables the base class's messages).
    """

    def __init__(self, model, optimizer, checkpoint_path='checkpoint.pt', patience=10, verbose=False):
        super().__init__(model, optimizer, checkpoint_path, verbose)
        self.patience = patience
        # Consecutive calls without a strict improvement in loss.
        self.counter = 0
        # Becomes True once counter reaches patience; never reset here.
        self.early_stop = False

    def __call__(self, loss, epoch, save_best_only=True):
        """Record ``loss``, save via the base class, and update the patience counter.

        Args:
            loss: Loss value for this epoch (lower is better).
            epoch: Epoch index stored in the checkpoint.
            save_best_only: Forwarded to :meth:`CheckpointSaver.__call__`.

        Returns:
            Whatever the base class returned: True if a checkpoint was written.
        """
        # Decide improvement *before* delegating: the base class may update
        # best_loss. Keying the counter on "saved" would also reset it on
        # every call when save_best_only=False, so early stopping would
        # never trigger.
        improved = self._determine_if_better(loss)
        saved = super().__call__(loss, epoch, save_best_only)
        if improved:
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        return saved