Coverage for src / autoencodix / utils / _annealer.py: 34%
53 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Dict, Callable, Optional
2from math import exp
class AnnealingScheduler:
    """VAE loss annealing scheduler with multiple annealing strategies.

    Each strategy maps ``(epoch_current, total_epoch)`` to a weight that
    ramps from 0 (no VAE loss) to 1 (full VAE loss). Strategies are selected
    by name through :meth:`get_weight`.
    """

    # Weights of the "5phase-constant" schedule, indexed by phase number.
    _FIVE_PHASE_WEIGHTS = (0.0, 0.001, 0.01, 0.1, 1.0)

    def __init__(self) -> None:
        """Initialize the annealing scheduler with strategy mappings."""
        self._strategies: Dict[str, Callable[[int, int], float]] = {
            "5phase-constant": self._five_phase_constant,
            "3phase-linear": self._three_phase_linear,
            "3phase-log": self._three_phase_log,
            "logistic-mid": self._logistic_mid,
            "logistic-early": self._logistic_early,
            "logistic-late": self._logistic_late,
            "no-annealing": self._no_annealing,
        }

    @staticmethod
    def get_annealing_epoch(
        *, anneal_pretraining: bool, n_epochs_pretrain: int, current_epoch: int
    ) -> Optional[int]:
        """Deprecated helper that used to derive the annealing epoch.

        Args:
            anneal_pretraining: Whether to apply annealing during pretraining phase.
            n_epochs_pretrain: Number of pretraining epochs.
            current_epoch: Current epoch number.

        Returns:
            Never returns; the method always raises.

        Raises:
            NotImplementedError: Always, because this method is deprecated.
        """
        # Implicit string concatenation keeps the message free of the
        # indentation whitespace that a backslash continuation inside the
        # literal would embed.
        raise NotImplementedError(
            "Deprecated, for annealing the current epoch is passed, "
            "we split between training and pretraining, "
            "so no extra calculation is needed"
        )

    def get_weight(
        self,
        *,
        epoch_current: Optional[int],
        total_epoch: int,
        func: str = "logistic-mid",
    ) -> float:
        """Calculate VAE loss annealing weight.

        Args:
            epoch_current: Current epoch in training, or None for full weight.
            total_epoch: Total epochs for training.
            func: Specification of annealing function. Default is 'logistic-mid'.

        Returns:
            Annealing weight between 0 (no VAE loss) and 1 (full VAE loss).

        Raises:
            NotImplementedError: If the specified annealing function is not implemented.
        """
        if epoch_current is None:
            return 1.0

        strategy = self._strategies.get(func)
        if strategy is None:
            available = ", ".join(sorted(self._strategies))
            # The original message is kept as a prefix so callers matching on
            # it keep working; the suffix says what went wrong.
            raise NotImplementedError(
                f"The annealer is not implemented yet: {func!r}. "
                f"Available strategies: {available}"
            )

        return strategy(epoch_current, total_epoch)

    def _five_phase_constant(self, epoch_current: int, total_epoch: int) -> float:
        """Five phase constant annealing strategy.

        Training is split into five equal phases with the constant weights
        0, 0.001, 0.01, 0.1 and 1.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight. Returns 1.0 when ``total_epoch`` is not
            positive, matching what the three-phase strategies return for
            that input.
        """
        if total_epoch <= 0:
            # No schedule can be derived; avoid dividing by zero.
            return 1.0

        n_phases = len(self._FIVE_PHASE_WEIGHTS)
        phase = int((epoch_current / total_epoch) * n_phases)
        # Clamp so epochs before the start get the first weight and epochs at
        # or past the end get the last weight.
        phase = min(max(phase, 0), n_phases - 1)
        return self._FIVE_PHASE_WEIGHTS[phase]

    def _three_phase_linear(self, epoch_current: int, total_epoch: int) -> float:
        """Three phase linear annealing strategy.

        Weight is 0 in the first third, rises linearly from 0 to 1 across the
        second third, and is 1 in the last third.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        phase_length = total_epoch / 3
        first_phase_end = phase_length
        second_phase_end = 2 * phase_length

        if epoch_current < first_phase_end:
            return 0.0
        if epoch_current < second_phase_end:
            # Only reachable when phase_length > 0, so the division is safe.
            return (epoch_current - first_phase_end) / phase_length
        return 1.0

    def _three_phase_log(self, epoch_current: int, total_epoch: int) -> float:
        """Three phase annealing strategy with a logistic middle phase.

        Weight is 0 in the first third, follows the mid-point logistic curve
        across the second third, and is 1 in the last third. The strategy is
        registered as "3phase-log", but the curve it applies is logistic.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        phase_length = total_epoch / 3
        first_phase_end = phase_length
        second_phase_end = 2 * phase_length

        if epoch_current < first_phase_end:
            return 0.0
        if epoch_current < second_phase_end:
            # Rescale to the second phase: position within the phase and the
            # phase length are both fractional, so call the float-typed helper.
            return self._compute_logistic_weight(
                epoch_current - first_phase_end, phase_length, 0.5
            )
        return 1.0

    def _compute_logistic_weight(
        self, epoch_current: float, total_epoch: float, midpoint: float
    ) -> float:
        """Compute logistic annealing weight.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.
            midpoint: Midpoint ratio for the logistic function (0.0 to 1.0).

        Returns:
            Annealing weight. Returns 1.0 when ``total_epoch`` is not
            positive, and 0.0 when the epoch lies so far before the midpoint
            that the exponential overflows.
        """
        if total_epoch <= 0:
            # No schedule can be derived; avoid dividing by zero.
            return 1.0

        # Steepness scales with 1 / total_epoch so the curve has the same
        # shape relative to the training length.
        steepness = (1 / total_epoch) * 20
        exponent = -steepness * (epoch_current - total_epoch * midpoint)
        try:
            return 1 / (1 + exp(exponent))
        except OverflowError:
            # exp() overflowed, so the true value is indistinguishable from 0.
            return 0.0

    def _logistic_mid(self, epoch_current: int, total_epoch: int) -> float:
        """Logistic annealing with midpoint at half of total epochs.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        return self._compute_logistic_weight(epoch_current, total_epoch, 0.5)

    def _logistic_early(self, epoch_current: int, total_epoch: int) -> float:
        """Logistic annealing with early midpoint at quarter of total epochs.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        return self._compute_logistic_weight(epoch_current, total_epoch, 0.25)

    def _logistic_late(self, epoch_current: int, total_epoch: int) -> float:
        """Logistic annealing with late midpoint at three quarters of total epochs.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        return self._compute_logistic_weight(epoch_current, total_epoch, 0.75)

    def _no_annealing(self, epoch_current: int, total_epoch: int) -> float:
        """No annealing strategy - constant full weight.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight (always 1.0).
        """
        return 1.0