# Coverage for src/autoencodix/losses/disentanglix_loss.py: 26%
# 46 statements
# coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
import math
from typing import Tuple, Optional

import torch

from autoencodix.base._base_loss import BaseLoss
from autoencodix.utils._model_output import ModelOutput
from autoencodix.configs.default_config import DefaultConfig
class DisentanglixLoss(BaseLoss):
    """Implements loss for VAE with disentanglement, Disentanglix.

    The total loss is the reconstruction loss plus three weighted terms
    (mutual information, total correlation, dimension-wise KL), each
    scaled by its own beta read from the config.

    Attributes:
        config: Configuration object inherited from Base.
        _forward_impl: Bound forward method selected once in ``__init__``:
            one variant applies annealing to the betas, the other does not.
    """

    def __init__(self, config: DefaultConfig, annealing_scheduler=None):
        """Inits Disentanglix loss.

        Args:
            config: Configuration object inherited from Base.
            annealing_scheduler: Optional custom annealer. When provided it
                is the object whose ``get_weight`` is called by the annealed
                forward pass; when ``None`` the scheduler set up by the base
                class (if any) is left untouched.
        """
        super().__init__(config)

        # Honour a caller-supplied scheduler. Assigned after
        # super().__init__() so it overrides any default the base class may
        # install. NOTE(review): BaseLoss is not visible from this file --
        # confirm it exposes `annealing_scheduler` as a plain attribute.
        if annealing_scheduler is not None:
            self.annealing_scheduler = annealing_scheduler

        # Determine the forward function strategy once, at initialization,
        # so that forward() does not branch on the config at every call.
        if self.config.anneal_function == "no-annealing":
            self._forward_impl = self._forward_without_annealing
        else:
            self._forward_impl = self._forward_with_annealing

    def forward(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        epoch: int,
        n_samples: int,
        **kwargs,
    ) -> Tuple[torch.Tensor, dict]:
        """Calls the forward implementation selected in ``__init__``.

        Args:
            model_output: instance that stores model outputs like reconstructions, mus, sigma, etc...
            targets: groundtruth tensor.
            epoch: current epoch; only used by the annealed variant.
            n_samples: number of samples.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            Tuple containing the total loss and a Dict with loss type as key and corresponding subloss as value.
        """
        return self._forward_impl(
            model_output=model_output,
            targets=targets,
            n_samples=n_samples,
            epoch=epoch,
        )

    def _compute_losses(
        self, model_output: ModelOutput, targets: torch.Tensor, n_samples: int
    ) -> Tuple[torch.Tensor, ...]:
        """Compute reconstruction, mutual information, total correlation and dimension-wise KL loss terms.

        Args:
            model_output: instance that stores model outputs like reconstructions, mus, sigma, etc...
            targets: groundtruth tensor.
            n_samples: number of samples.

        Returns:
            Tuple ``(recon_loss, mut_info_loss, tot_corr_loss, dimwise_kl_loss)``.
            The last three are clamped to be non-negative.
        """
        recon_loss: torch.Tensor = self.recon_loss(model_output.reconstruction, targets)
        mut_info_loss, tot_corr_loss, dimwise_kl_loss = (
            self._compute_decomposed_vae_loss(
                z=model_output.latentspace,  # Latent space of batch samples (shape: [batch_size, latent_dim])
                mu=model_output.latent_mean,  # Mean of latent space (shape: [batch_size, latent_dim])
                logvar=model_output.latent_logvar,  # Log variance of latent space (shape: [batch_size, latent_dim])
                n_samples=n_samples,  # Number of samples of whole dataset
                use_mss=self.config.use_mss,
            )
        )

        # Clip losses to avoid negative values
        mut_info_loss = torch.clamp(mut_info_loss, min=0.0)
        tot_corr_loss = torch.clamp(tot_corr_loss, min=0.0)
        dimwise_kl_loss = torch.clamp(dimwise_kl_loss, min=0.0)

        return recon_loss, mut_info_loss, tot_corr_loss, dimwise_kl_loss

    def _compute_decomposed_vae_loss(
        self,
        z: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        n_samples: int,
        use_mss: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute decomposed VAE loss terms.

        Args:
            z: latent samples of the batch, ``[batch_size, latent_dim]``.
            mu: latent means, ``[batch_size, latent_dim]``.
            logvar: latent log variances, ``[batch_size, latent_dim]``.
            n_samples: number of samples of the whole dataset.
            use_mss: if True, weight the pairwise densities with the matrix
                returned by ``_compute_log_import_weight_mat``; otherwise
                normalise by ``log(batch_size * n_samples)``.

        Returns:
            A tuple of three scalar tensors:
                - Mutual information loss
                - Total correlation loss
                - Dimension-wise KL divergence loss

        Raises:
            ValueError: in the non-MSS branch when
                ``batch_size * n_samples`` is not positive (from ``math.log``).
        """
        batch_size = z.shape[0]

        log_q_z_given_x = self._compute_log_gauss_dense(z, mu, logvar).sum(
            dim=1
        )  # Dim [batch_size]
        log_prior = self._compute_log_gauss_dense(
            z, torch.zeros_like(z), torch.zeros_like(z)
        ).sum(
            dim=1
        )  # Dim [batch_size]

        # Density of every sample under every other sample's posterior.
        log_q_batch_perm = self._compute_log_gauss_dense(
            z.reshape(batch_size, 1, -1),
            mu.reshape(1, batch_size, -1),
            logvar.reshape(1, batch_size, -1),
        )  # Dim [batch_size, batch_size, latent_dim]

        if use_mss:
            logiw_mat = self._compute_log_import_weight_mat(batch_size, n_samples).to(
                z.device
            )
            log_q_z = torch.logsumexp(
                logiw_mat + log_q_batch_perm.sum(dim=-1), dim=-1
            )  # Dim [batch_size]
            log_product_q_z = torch.logsumexp(
                logiw_mat.reshape(batch_size, batch_size, -1) + log_q_batch_perm,
                dim=1,
            ).sum(
                dim=-1
            )  # Dim [batch_size]
        else:
            # Plain Python scalar: no per-call tensor allocation or device
            # transfer, and it broadcasts against any tensor shape/dtype.
            log_norm = math.log(batch_size * n_samples)

            log_q_z = (
                torch.logsumexp(log_q_batch_perm.sum(dim=-1), dim=-1) - log_norm
            )  # Dim [batch_size]
            # The normaliser must be subtracted per latent dimension BEFORE
            # summing over the latent axis; the parentheses are essential.
            # Without them the sum applies only to the normaliser and the
            # result keeps shape [batch_size, latent_dim].
            log_product_q_z = (
                torch.logsumexp(log_q_batch_perm, dim=1) - log_norm
            ).sum(
                dim=-1
            )  # Dim [batch_size]

        # Reduction: mean or sum over batch
        mut_info_loss = self.reduction_fn(log_q_z_given_x - log_q_z)
        tot_corr_loss = self.reduction_fn(log_q_z - log_product_q_z)
        dimwise_kl_loss = self.reduction_fn(log_product_q_z - log_prior)

        return mut_info_loss, tot_corr_loss, dimwise_kl_loss

    def _forward_without_annealing(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        n_samples: int,
        epoch: Optional[int] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """Forward pass without annealing - uses constant beta (ignores epoch).

        Returns:
            Tuple of the total loss and a dict holding the reconstruction
            loss and the three beta-weighted sub losses.
        """
        recon_loss, mut_info_loss, tot_corr_loss, dimwise_kl_loss = (
            self._compute_losses(model_output, targets, n_samples)
        )
        total_loss = (
            recon_loss
            + self.config.beta_mi * mut_info_loss
            + self.config.beta_tc * tot_corr_loss
            + self.config.beta_dimKL * dimwise_kl_loss
        )
        return total_loss, {
            "recon_loss": recon_loss,
            "mut_info_loss": mut_info_loss * self.config.beta_mi,
            "tot_corr_loss": tot_corr_loss * self.config.beta_tc,
            "dimwise_kl_loss": dimwise_kl_loss * self.config.beta_dimKL,
        }

    def _forward_with_annealing(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        n_samples: int,
        epoch: int,
    ) -> Tuple[torch.Tensor, dict]:
        """Forward pass with annealing - calculates annealing factor from epoch.

        Returns:
            Tuple of the total loss and a dict holding the reconstruction
            loss, the three weighted sub losses, the annealing factor and
            the three effective (annealed) betas as tensors.
        """
        recon_loss, mut_info_loss, tot_corr_loss, dimwise_kl_loss = (
            self._compute_losses(model_output, targets, n_samples)
        )

        # Get annealing weight
        anneal_factor = self.annealing_scheduler.get_weight(
            epoch_current=epoch,
            total_epoch=self.config.epochs,
            func=self.config.anneal_function,
        )

        # Apply annealed beta: the same factor scales all three betas.
        effective_beta_mi = self.config.beta_mi * anneal_factor
        effective_beta_tc = self.config.beta_tc * anneal_factor
        effective_beta_dimKL = self.config.beta_dimKL * anneal_factor

        # Calculate total loss
        total_loss = (
            recon_loss
            + effective_beta_mi * mut_info_loss
            + effective_beta_tc * tot_corr_loss
            + effective_beta_dimKL * dimwise_kl_loss
        )

        return total_loss, {
            "recon_loss": recon_loss,
            "mut_info_loss": mut_info_loss * effective_beta_mi,
            "tot_corr_loss": tot_corr_loss * effective_beta_tc,
            "dimwise_kl_loss": dimwise_kl_loss * effective_beta_dimKL,
            "anneal_factor": torch.tensor(anneal_factor),
            "effective_beta_mi_factor": torch.tensor(effective_beta_mi),
            "effective_beta_tc_factor": torch.tensor(effective_beta_tc),
            "effective_beta_dimKL_factor": torch.tensor(effective_beta_dimKL),
        }