Coverage for src / autoencodix / losses / disentanglix_loss.py: 26%

46 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2from typing import Tuple, Optional 

3 

4from autoencodix.base._base_loss import BaseLoss 

5from autoencodix.utils._model_output import ModelOutput 

6from autoencodix.configs.default_config import DefaultConfig 

7 

8 

class DisentanglixLoss(BaseLoss):
    """Implements the loss for Disentanglix, a VAE with disentanglement.

    The total loss is the reconstruction loss plus three separately weighted
    terms: mutual information, total correlation and dimension-wise KL.

    Attributes:
        config: Configuration object inherited from Base.
        _forward_impl: Forward strategy selected once at construction time.
            There is one variant with and one without annealing of the
            beta weights.
    """

    def __init__(self, config: DefaultConfig, annealing_scheduler=None):
        """Inits Disentanglix loss.

        Args:
            config: Configuration object inherited from Base.
            annealing_scheduler: Optional custom annealer. When ``None`` the
                scheduler configured by the base class (if any) is kept.
        """
        super().__init__(config)

        # The argument used to be accepted but silently dropped. Only override
        # when the caller actually supplied one, so the default path is
        # unchanged.
        # NOTE(review): assumes the base class exposes ``annealing_scheduler``
        # as a plain attribute - confirm against BaseLoss.
        if annealing_scheduler is not None:
            self.annealing_scheduler = annealing_scheduler

        # Determine the forward function strategy at initialization so that
        # ``forward`` does not have to branch on every call.
        if self.config.anneal_function == "no-annealing":
            self._forward_impl = self._forward_without_annealing
        else:
            self._forward_impl = self._forward_with_annealing

    def forward(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        epoch: int,
        n_samples: int,
        **kwargs,
    ) -> Tuple[torch.Tensor, dict]:
        """Delegates to the forward strategy chosen in ``__init__``.

        Args:
            model_output: instance that stores model outputs like
                reconstructions, mus, sigma, etc...
            targets: groundtruth tensor.
            epoch: current epoch; only used by the annealing variant.
            n_samples: number of samples.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Tuple containing the total loss and a Dict with loss type as key
            and corresponding subloss as value.
        """
        return self._forward_impl(
            model_output=model_output,
            targets=targets,
            n_samples=n_samples,
            epoch=epoch,
        )

    def _compute_losses(
        self, model_output: ModelOutput, targets: torch.Tensor, n_samples: int
    ) -> Tuple[torch.Tensor, ...]:
        """Compute reconstruction, mutual information, total correlation and
        dimension-wise KL loss terms.

        Args:
            model_output: instance that stores model outputs like
                reconstructions, mus, sigma, etc...
            targets: groundtruth tensor.
            n_samples: number of samples.

        Returns:
            Tuple ``(recon_loss, mut_info_loss, tot_corr_loss,
            dimwise_kl_loss)``; the last three are clamped to be >= 0.
        """
        recon_loss: torch.Tensor = self.recon_loss(
            model_output.reconstruction, targets
        )
        mut_info_loss, tot_corr_loss, dimwise_kl_loss = (
            self._compute_decomposed_vae_loss(
                z=model_output.latentspace,  # Latent space of batch samples (shape: [batch_size, latent_dim])
                mu=model_output.latent_mean,  # Mean of latent space (shape: [batch_size, latent_dim])
                logvar=model_output.latent_logvar,  # Log variance of latent space (shape: [batch_size, latent_dim])
                n_samples=n_samples,  # Number of samples of whole dataset
                use_mss=self.config.use_mss,
            )
        )

        # The decomposed terms are estimated from a minibatch and can come out
        # slightly negative; clip them to avoid negative loss contributions.
        mut_info_loss = torch.clamp(mut_info_loss, min=0.0)
        tot_corr_loss = torch.clamp(tot_corr_loss, min=0.0)
        dimwise_kl_loss = torch.clamp(dimwise_kl_loss, min=0.0)

        return recon_loss, mut_info_loss, tot_corr_loss, dimwise_kl_loss

    def _compute_decomposed_vae_loss(
        self,
        z: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        n_samples: int,
        use_mss: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute decomposed VAE loss terms.

        Args:
            z: latent samples of the batch, ``[batch_size, latent_dim]``.
            mu: latent means, ``[batch_size, latent_dim]``.
            logvar: latent log-variances, ``[batch_size, latent_dim]``.
            n_samples: number of samples of the whole dataset.
            use_mss: if True, weight the pairwise densities with the matrix
                from ``_compute_log_import_weight_mat``; otherwise normalise
                by ``log(batch_size * n_samples)``.

        Returns:
            A tuple of three scalar tensors:
            - Mutual information loss
            - Total correlation loss
            - Dimension-wise KL divergence loss
        """
        batch_size = z.shape[0]

        # log q(z|x) for each sample under its own posterior.
        log_q_z_given_x = self._compute_log_gauss_dense(z, mu, logvar).sum(
            dim=1
        )  # Dim [batch_size]

        # log p(z): same density helper with zero mean and zero log-variance.
        zeros = torch.zeros_like(z)
        log_prior = self._compute_log_gauss_dense(z, zeros, zeros).sum(
            dim=1
        )  # Dim [batch_size]

        # Pairwise densities: entry [i, j, d] pairs sample i with posterior j.
        log_q_batch_perm = self._compute_log_gauss_dense(
            z.reshape(batch_size, 1, -1),
            mu.reshape(1, batch_size, -1),
            logvar.reshape(1, batch_size, -1),
        )  # Dim [batch_size, batch_size, latent_dim]

        if use_mss:
            logiw_mat = self._compute_log_import_weight_mat(
                batch_size, n_samples
            ).to(z.device)
            log_q_z = torch.logsumexp(
                logiw_mat + log_q_batch_perm.sum(dim=-1), dim=-1
            )  # Dim [batch_size]
            log_product_q_z = torch.logsumexp(
                logiw_mat.reshape(batch_size, batch_size, -1) + log_q_batch_perm,
                dim=1,
            ).sum(
                dim=-1
            )  # Dim [batch_size]
        else:
            # Shared normaliser log(batch_size * n_samples), shape [1].
            log_norm = torch.log(
                torch.tensor([float(batch_size * n_samples)], device=z.device)
            )
            log_q_z = (
                torch.logsumexp(log_q_batch_perm.sum(dim=-1), dim=-1) - log_norm
            )  # Dim [batch_size]
            # The difference must be parenthesised: a method call binds
            # tighter than binary minus, so without the parentheses ``.sum``
            # would reduce only the normaliser and leave the result with shape
            # [batch_size, latent_dim].
            log_product_q_z = (
                torch.logsumexp(log_q_batch_perm, dim=1) - log_norm
            ).sum(
                dim=-1
            )  # Dim [batch_size]

        # Reduction: mean or sum over batch.
        mut_info_loss = self.reduction_fn(log_q_z_given_x - log_q_z)
        tot_corr_loss = self.reduction_fn(log_q_z - log_product_q_z)
        dimwise_kl_loss = self.reduction_fn(log_product_q_z - log_prior)

        return mut_info_loss, tot_corr_loss, dimwise_kl_loss

    def _forward_without_annealing(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        n_samples: int,
        epoch: Optional[int] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """Forward pass without annealing - uses constant beta (ignores epoch).

        Returns:
            The total loss and a dict of the reconstruction loss plus the
            three beta-weighted sub-losses.
        """
        recon_loss, mut_info_loss, tot_corr_loss, dimwise_kl_loss = (
            self._compute_losses(model_output, targets, n_samples)
        )
        total_loss = (
            recon_loss
            + self.config.beta_mi * mut_info_loss
            + self.config.beta_tc * tot_corr_loss
            + self.config.beta_dimKL * dimwise_kl_loss
        )
        return total_loss, {
            "recon_loss": recon_loss,
            "mut_info_loss": mut_info_loss * self.config.beta_mi,
            "tot_corr_loss": tot_corr_loss * self.config.beta_tc,
            "dimwise_kl_loss": dimwise_kl_loss * self.config.beta_dimKL,
        }

    def _forward_with_annealing(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        n_samples: int,
        epoch: int,
    ) -> Tuple[torch.Tensor, dict]:
        """Forward pass with annealing - calculates annealing factor from epoch.

        Returns:
            The total loss and a dict of the reconstruction loss, the three
            sub-losses weighted by their annealed betas, the annealing factor
            and the three effective beta values (as tensors).
        """
        recon_loss, mut_info_loss, tot_corr_loss, dimwise_kl_loss = (
            self._compute_losses(model_output, targets, n_samples)
        )

        # Get annealing weight for the current epoch.
        anneal_factor = self.annealing_scheduler.get_weight(
            epoch_current=epoch,
            total_epoch=self.config.epochs,
            func=self.config.anneal_function,
        )

        # Apply the same annealing factor to every beta.
        effective_beta_mi = self.config.beta_mi * anneal_factor
        effective_beta_tc = self.config.beta_tc * anneal_factor
        effective_beta_dimKL = self.config.beta_dimKL * anneal_factor

        # Calculate total loss.
        total_loss = (
            recon_loss
            + effective_beta_mi * mut_info_loss
            + effective_beta_tc * tot_corr_loss
            + effective_beta_dimKL * dimwise_kl_loss
        )

        return total_loss, {
            "recon_loss": recon_loss,
            "mut_info_loss": mut_info_loss * effective_beta_mi,
            "tot_corr_loss": tot_corr_loss * effective_beta_tc,
            "dimwise_kl_loss": dimwise_kl_loss * effective_beta_dimKL,
            "anneal_factor": torch.tensor(anneal_factor),
            "effective_beta_mi_factor": torch.tensor(effective_beta_mi),
            "effective_beta_tc_factor": torch.tensor(effective_beta_tc),
            "effective_beta_dimKL_factor": torch.tensor(effective_beta_dimKL),
        }