Source code for reflectorch.ml.schedulers

# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from torch.optim import lr_scheduler

import numpy as np

from reflectorch.ml.basic_trainer import Trainer, TrainerCallback, PeriodicTrainerCallback

__all__ = [
    'ScheduleBatchSize',
    'ScheduleLR',
    'StepLR',
    'CyclicLR',
    'LogCyclicLR',
    'ReduceLROnPlateau',
    'OneCycleLR',
]


[docs] class ScheduleBatchSize(PeriodicTrainerCallback): """Batch size scheduler Args: step (int): number of iterations after which the batch size is modified. gamma (int, optional): quantity which is added to or multiplied with the current batch size. Defaults to 2. last_epoch (int, optional): the last training iteration for which the batch size is modified. Defaults to -1. mode (str, optional): ``'add'`` for addition or ``'multiply'`` for multiplication. Defaults to 'add'. """ def __init__(self, step: int, gamma: int = 2, last_epoch: int = -1, mode: str = 'add'): super().__init__(step, last_epoch) assert mode in ('add', 'multiply') self.gamma = gamma self.mode = mode def _end_batch(self, trainer: Trainer, batch_num: int) -> None: if self.mode == 'add': trainer.batch_size += self.gamma elif self.mode == 'multiply': trainer.batch_size *= self.gamma
[docs] class ScheduleLR(TrainerCallback): """Base class for learning rate schedulers Args: lr_scheduler_cls: class of the learning rate scheduler """ def __init__(self, lr_scheduler_cls, **kwargs): self.lr_scheduler_cls = lr_scheduler_cls self.kwargs = kwargs self.lr_scheduler = None
[docs] def start_training(self, trainer: Trainer) -> None: """initializes a learning rate scheduler based on its class and keyword arguments at the start of training Args: trainer (Trainer): the trainer object """ self.lr_scheduler = self.lr_scheduler_cls(trainer.optim, **self.kwargs) trainer.callback_params['lrs'] = []
[docs] def end_batch(self, trainer: Trainer, batch_num: int) -> None: """modifies the learning rate at the end of each iteration Args: trainer (Trainer): the trainer object batch_num (int): index of the current iteration """ self.lr_scheduler.step()
[docs] class StepLR(ScheduleLR): """Learning rate scheduler which decays the learning rate of each parameter group by gamma every ``step_size`` epochs. Args: step_size (int): Period of learning rate decay gamma (float): Multiplicative factor of learning rate decay last_epoch (int, optional): The index of last iteration. Defaults to -1. """ def __init__(self, step_size: int, gamma: float, last_epoch: int = -1, **kwargs): super().__init__(lr_scheduler.StepLR, step_size=step_size, gamma=gamma, last_epoch=last_epoch, **kwargs)
[docs] def start_training(self, trainer: Trainer) -> None: trainer.optim.param_groups[0]['initial_lr'] = trainer.lr() super().start_training(trainer)
[docs] class CyclicLR(ScheduleLR): """Cyclic learning rate scheduler Args: base_lr (float): Initial learning rate which is the lower boundary in the cycle max_lr (float): Upper learning rate boundary in the cycle step_size_up (int, optional): Number of training iterations in the increasing half of a cycle. Defaults to 2000. cycle_momentum (bool, optional): If True, momentum is cycled inversely to learning rate between ``base_momentum`` and ``max_momentum``. Defaults to False. gamma (float, optional): Constant in ``‘exp_range’`` mode scaling function: gamma^(cycle iterations). Defaults to 1. mode (str, optional): One of: ``'triangular'`` (a basic triangular cycle without amplitude scaling), ``'triangular2'`` (a basic triangular cycle that scales initial amplitude by half each cycle), ``'exp_range'`` (a cycle that scales initial amplitude by gamma^iterations at each cycle iteration). Defaults to 'triangular'. """ def __init__(self, base_lr, max_lr, step_size_up: int = 2000, cycle_momentum: bool = False, gamma: float = 1., mode: str = 'triangular', **kwargs): super().__init__( lr_scheduler.CyclicLR, base_lr=base_lr, max_lr=max_lr, step_size_up=step_size_up, cycle_momentum=cycle_momentum, gamma=gamma, mode=mode, **kwargs )
[docs] class LogCyclicLR(TrainerCallback): """Cyclic learning rate scheduler on a logarithmic scale Args: base_lr (float): Lower learning rate boundary in the cycle max_lr (float): Upper learning rate boundary in the cycle period (int, optional): Number of training iterations in the cycle. Defaults to 2000. gamma (float, optional): Constant for scaling the amplitude as ``gamma`` ^ ``iterations``. Defaults to 1. start_period (int, optional): Number of starting iterations with the default learning rate. log (bool, optional): If ``True``, the cycle is in the logarithmic domain. param_groups (tupe, optional): Parameter groups of the optimizer. """ def __init__(self, base_lr, max_lr, period: int = 2000, gamma: float = None, log: bool = True, param_groups: tuple = (0,), start_period: int = 25, ): self.base_lr = base_lr self.max_lr = max_lr self.period = period self.gamma = gamma self.param_groups = param_groups self.log = log self.start_period = start_period self._axis = None self._period = None def get_lr(self, batch_num: int): return self._get_lr(batch_num) def _get_lr(self, batch_num): num_period, t = batch_num // self.period, batch_num % self.period if self._period != num_period: self._period = num_period if self.gamma and (num_period >= self.start_period): amp = (self.max_lr - self.base_lr) * (self.gamma ** (num_period - self.start_period)) max_lr = self.base_lr + amp else: max_lr = self.max_lr if self.log: self._axis = np.logspace(np.log10(self.base_lr), np.log10(max_lr), self.period // 2) else: self._axis = np.linspace(self.base_lr, max_lr, self.period // 2) if t < self.period // 2: lr = self._axis[t] else: lr = self._axis[self.period - t - 1] return lr
[docs] def end_batch(self, trainer: Trainer, batch_num: int): lr = self.get_lr(batch_num) for param_group in self.param_groups: trainer.set_lr(lr, param_group)
[docs] class ReduceLROnPlateau(TrainerCallback): """Learning rate scheduler which reduces the learning rate when the loss stops decreasing Args: gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.5. patience (int, optional): The number of allowed iterations with no improvement after which the learning rate will be reduced. Defaults to 500. average (int, optional): Size of the window over which the average loss is computed. Defaults to 50. loss_key (str, optional): Defaults to 'total_loss'. param_groups (tuple, optional): Defaults to (0,). """ def __init__( self, gamma: float = 0.5, patience: int = 500, average: int = 50, loss_key: str = 'total_loss', param_groups: tuple = (0,), ): """ """ self.patience = patience self.average = average self.gamma = gamma self.loss_key = loss_key self.param_groups = param_groups
[docs] def end_batch(self, trainer: Trainer, batch_num: int) -> None: loss = trainer.losses[self.loss_key] if len(loss) < self.patience: return if np.mean(loss[-self.patience:-(self.patience - self.average)]) <= np.mean(loss[-self.average:]): for param_group in self.param_groups: trainer.set_lr(trainer.lr(param_group) * self.gamma, param_group)
[docs] class OneCycleLR(ScheduleLR): """One-cycle learning rate scheduler (https://arxiv.org/abs/1708.07120) Args: max_lr (float): Upper learning rate boundary in the cycle total_steps (int): The total number of steps in the cycle pct_start (float, optional): The percentage of the cycle (in number of steps) spent increasing the learning rate. Defaults to 0.3. div_factor (float, optional): Determines the initial learning rate via initial_lr = ``max_lr`` / ``div_factor``. Defaults to 25.. final_div_factor (float, optional): Determines the minimum learning rate via min_lr = ``initial_lr`` / ``final_div_factor``. Defaults to 1e4. three_phase (bool, optional): If ``True``, use a third phase of the schedule to annihilate the learning rate according to ``final_div_factor`` instead of modifying the second phase. Defaults to True. """ def __init__(self, max_lr: float, total_steps: int, pct_start: float = 0.3, div_factor: float = 25., final_div_factor: float = 1e4, three_phase: bool = True, **kwargs): super().__init__( lr_scheduler.OneCycleLR, max_lr=max_lr, total_steps=total_steps, pct_start=pct_start, div_factor=div_factor , final_div_factor=final_div_factor, three_phase=three_phase, **kwargs )