Coverage for src / autoencodix / utils / _annealer.py: 34%

53 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Dict, Callable, Optional 

2from math import exp 

3 

4 

class AnnealingScheduler:
    """VAE loss annealing scheduler with multiple annealing strategies.

    Every strategy maps ``(epoch_current, total_epoch)`` to a weight that
    scales the VAE loss term. Strategies are selected by name through
    :meth:`get_weight`; the available names are the keys of
    ``self._strategies``.
    """

    def __init__(self) -> None:
        """Initialize the annealing scheduler with strategy mappings."""
        # Name -> bound method dispatch table used by ``get_weight``.
        self._strategies: Dict[str, Callable[[int, int], float]] = {
            "5phase-constant": self._five_phase_constant,
            "3phase-linear": self._three_phase_linear,
            "3phase-log": self._three_phase_log,
            "logistic-mid": self._logistic_mid,
            "logistic-early": self._logistic_early,
            "logistic-late": self._logistic_late,
            "no-annealing": self._no_annealing,
        }

    @staticmethod
    def get_annealing_epoch(
        *, anneal_pretraining: bool, n_epochs_pretrain: int, current_epoch: int
    ) -> Optional[int]:
        """Deprecated helper kept only so existing call sites fail loudly.

        Args:
            anneal_pretraining: Whether to apply annealing during pretraining phase.
            n_epochs_pretrain: Number of pretraining epochs.
            current_epoch: Current epoch number.

        Returns:
            Never returns; the method always raises.

        Raises:
            NotImplementedError: Always, because this method is deprecated.
        """
        # Implicit literal concatenation instead of a backslash-newline inside
        # the string: the latter embeds the next line's indentation in the
        # message as a long run of spaces.
        raise NotImplementedError(
            "Deprecated, for annealing the current epoch is passed, "
            "we split between training and pretraining, "
            "so no extra calculation is needed"
        )

    def get_weight(
        self,
        *,
        epoch_current: Optional[int],
        total_epoch: int,
        func: str = "logistic-mid",
    ) -> float:
        """Calculate VAE loss annealing weight.

        Args:
            epoch_current: Current epoch in training, or None for full weight.
            total_epoch: Total epochs for training.
            func: Specification of annealing function. Default is 'logistic-mid'.

        Returns:
            Annealing weight between 0 (no VAE loss) and 1 (full VAE loss).

        Raises:
            NotImplementedError: If the specified annealing function is not implemented.
            ZeroDivisionError: If ``total_epoch`` is 0 and the selected strategy
                divides by it ('5phase-constant' and the 'logistic-*' strategies).
        """
        # ``None`` short-circuits before the strategy lookup, so it yields the
        # full weight regardless of ``func``.
        if epoch_current is None:
            return 1.0

        try:
            strategy = self._strategies[func]
        except KeyError:
            available = ", ".join(self._strategies)
            raise NotImplementedError(
                f"The annealer is not implemented yet: {func!r}. "
                f"Available strategies: {available}"
            ) from None

        return strategy(epoch_current, total_epoch)

    def _five_phase_constant(self, epoch_current: int, total_epoch: int) -> float:
        """Five phase constant annealing strategy.

        The schedule is cut into five equal phases with constant weights
        0.0, 0.001, 0.01, 0.1 and 1.0.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        intervals = 5
        # Weights of phases 0..3; any other phase index maps to the full weight.
        phase_weights = (0.0, 0.001, 0.01, 0.1)
        current_phase = int((epoch_current / total_epoch) * intervals)

        if 0 <= current_phase < len(phase_weights):
            return phase_weights[current_phase]
        return 1.0

    def _three_phase_linear(self, epoch_current: int, total_epoch: int) -> float:
        """Three phase linear annealing strategy.

        Weight is 0.0 during the first third, rises linearly from 0 towards 1
        during the second third, and is 1.0 afterwards.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        phase_length = total_epoch / 3
        first_phase_end = phase_length
        second_phase_end = 2 * phase_length

        if epoch_current < first_phase_end:
            return 0.0
        if epoch_current < second_phase_end:
            return (epoch_current - first_phase_end) / phase_length
        return 1.0

    def _three_phase_log(self, epoch_current: int, total_epoch: int) -> float:
        """Three phase annealing strategy with a logistic middle phase.

        Weight is 0.0 during the first third and 1.0 during the last third.
        Within the second third the weight follows the 'logistic-mid' curve
        evaluated on that third alone (epoch counted from the start of the
        phase, phase length as the total).

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        phase_length = total_epoch / 3
        first_phase_end = phase_length
        second_phase_end = 2 * phase_length

        if epoch_current < first_phase_end:
            return 0.0
        if epoch_current < second_phase_end:
            return self._logistic_mid(epoch_current - first_phase_end, phase_length)
        return 1.0

    def _compute_logistic_weight(
        self, epoch_current: float, total_epoch: float, midpoint: float
    ) -> float:
        """Compute logistic annealing weight.

        Evaluates ``1 / (1 + exp(-b * (epoch_current - total_epoch * midpoint)))``
        with steepness ``b = 20 / total_epoch``.

        Args:
            epoch_current: Current epoch number (may be fractional when called
                from the three-phase strategy).
            total_epoch: Total number of epochs (may be fractional when called
                from the three-phase strategy).
            midpoint: Midpoint ratio for the logistic function (0.0 to 1.0).

        Returns:
            Annealing weight.
        """
        b_param = (1 / total_epoch) * 20
        return 1 / (1 + exp(-b_param * (epoch_current - total_epoch * midpoint)))

    def _logistic_mid(self, epoch_current: float, total_epoch: float) -> float:
        """Logistic annealing with midpoint at half of total epochs.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        return self._compute_logistic_weight(epoch_current, total_epoch, 0.5)

    def _logistic_early(self, epoch_current: int, total_epoch: int) -> float:
        """Logistic annealing with early midpoint at quarter of total epochs.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        return self._compute_logistic_weight(epoch_current, total_epoch, 0.25)

    def _logistic_late(self, epoch_current: int, total_epoch: int) -> float:
        """Logistic annealing with late midpoint at three quarters of total epochs.

        Args:
            epoch_current: Current epoch number.
            total_epoch: Total number of epochs.

        Returns:
            Annealing weight.
        """
        return self._compute_logistic_weight(epoch_current, total_epoch, 0.75)

    def _no_annealing(self, epoch_current: int, total_epoch: int) -> float:
        """No annealing strategy - constant full weight.

        Args:
            epoch_current: Current epoch number (unused).
            total_epoch: Total number of epochs (unused).

        Returns:
            Annealing weight (always 1.0).
        """
        return 1.0