# -*- coding: utf-8 -*-
#
# Source code for reflectorch.ml.callbacks
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
from reflectorch.ml.basic_trainer import (
    TrainerCallback,
    Trainer,
)
from reflectorch.ml.utils import is_divisor

__all__ = [
    'SaveBestModel',
    'LogLosses',
]
class SaveBestModel(TrainerCallback):
    """Callback for periodically saving the best model weights.

    Args:
        path (str): path for saving the model weights
        freq (int, optional): frequency in iterations at which the current average loss is evaluated. Defaults to 50.
        average (int, optional): number of recent iterations over which the average loss is computed. Defaults to 10.
    """

    def __init__(self, path: str, freq: int = 50, average: int = 10):
        self.path = path
        self.freq = freq
        self.average = average
        self._best_loss = np.inf
    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """checks whether the current average loss has improved since the previous save; if so, the model is saved

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the current iteration / batch
        """
        if is_divisor(batch_num, self.freq):
            # average the most recent losses to smooth out batch-to-batch noise
            loss = np.mean(trainer.losses['total_loss'][-self.average:])
            if loss < self._best_loss:
                self._best_loss = loss
                self.save(trainer, batch_num)
    def save(self, trainer: Trainer, batch_num: int):
        """saves a dictionary containing the network weights, the learning rates, the losses, and the
        current best loss with its corresponding iteration to disk

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the current iteration / batch
        """
        prev_save = trainer.callback_params.pop('saved_iteration', 0)
        trainer.callback_params['saved_iteration'] = batch_num

        save_dict = {
            'model': trainer.model.state_dict(),
            'lrs': trainer.lrs,
            'losses': trainer.losses,
            'prev_save': prev_save,
            'batch_num': batch_num,
            'best_loss': self._best_loss,
        }
        torch.save(save_dict, self.path)
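
# A minimal restore sketch (illustrative, not part of the library): loading the
# checkpoint written by SaveBestModel.save. ``model`` stands for a hypothetical
# network instance compatible with the saved weights; only the dictionary keys
# written above are assumed.
#
#     checkpoint = torch.load('best_model.pt')
#     model.load_state_dict(checkpoint['model'])
#     print(f"best loss {checkpoint['best_loss']:.4g} "
#           f"at iteration {checkpoint['batch_num']}")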
class LogLosses(TrainerCallback):
    """Callback for logging the training losses."""
    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """logs the loss at the current iteration

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch
        """
        try:
            trainer.log('train/total_loss', trainer.losses[trainer.TOTAL_LOSS_KEY][-1])
        except IndexError:
            # no loss has been recorded yet (e.g. before the first completed batch)
            pass
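
# Example wiring (a sketch only; the exact Trainer construction and the way
# callbacks are attached may differ in the real API, and ``run_training`` is a
# hypothetical entry point):
#
#     callbacks = (SaveBestModel('best_model.pt', freq=100, average=20),
#                  LogLosses())
#     trainer = Trainer(...)  # model, data loader, optimizer, etc.
#     trainer.run_training(callbacks=callbacks)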