Coverage for src / autoencodix / losses / varix_loss.py: 88%
25 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2from typing import Tuple, Optional
4from autoencodix.base._base_loss import BaseLoss
5from autoencodix.utils._model_output import ModelOutput
6from autoencodix.configs.default_config import DefaultConfig
class VarixLoss(BaseLoss):
    """Implements loss for variational autoencoder with unified interface.

    The total loss is the reconstruction loss plus the variational loss
    weighted by ``config.beta``, optionally scaled by an annealing factor.

    Attributes:
        config: Configuration object
    """

    def __init__(self, config: DefaultConfig, annealing_scheduler=None):
        """Inits VarixLoss.

        Args:
            config: Configuration object.
            annealing_scheduler: Enables passing a custom annealer class,
                defaults to our implementation of an annealer.
        """
        super().__init__(config, annealing_scheduler=annealing_scheduler)

    @staticmethod
    def _to_log_tensor(value) -> torch.Tensor:
        """Convert a scalar or tensor to a standalone tensor for the loss dict.

        ``torch.tensor`` on an existing tensor emits a UserWarning (PyTorch
        recommends ``clone().detach()`` instead), so tensors take that path
        while plain Python/NumPy scalars still go through ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)

    def _compute_losses(
        self, model_output: ModelOutput, targets: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute reconstruction and variational losses.

        Args:
            model_output: custom class that stores model output like
                latentspaces and reconstructions.
            targets: original data to compare with reconstruction

        Returns:
            Tuple of torch.Tensors: reconstruction loss and variational loss
        """
        latent = model_output.latentspace

        # Reference samples drawn from a standard normal, shape
        # (config.batch_size, config.latent_dim). They are passed to the
        # variational loss alongside ``z``. Create them on the latent tensor's
        # device/dtype so a loss comparing the two does not fail with a device
        # (or dtype) mismatch when the model runs on GPU or in reduced precision.
        sample_kwargs = {}
        if isinstance(latent, torch.Tensor) and latent.is_floating_point():
            sample_kwargs = {"device": latent.device, "dtype": latent.dtype}
        true_samples = torch.randn(
            self.config.batch_size,
            self.config.latent_dim,
            requires_grad=False,
            **sample_kwargs,
        )

        recon_loss = self.recon_loss(model_output.reconstruction, targets)
        var_loss = self.compute_variational_loss(
            mu=model_output.latent_mean,
            logvar=model_output.latent_logvar,
            z=latent,
            true_samples=true_samples,
        )

        return recon_loss, var_loss

    def forward(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        epoch: Optional[int] = None,
        total_epochs: Optional[int] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, dict]:
        """Forward pass with conditional annealing.

        Args:
            model_output: custom class that stores model output like
                latentspaces and reconstructions.
            targets: original data to compare with reconstruction
            epoch: current training epoch (forwarded to the annealing
                scheduler when annealing is enabled)
            total_epochs: number of total epochs; when falsy, ``config.epochs``
                is used instead
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            Tuple consisting of:
            - tensor of the total loss, ``recon_loss + effective_beta * var_loss``
            - Dict with loss_type as key and sub_loss value. The ``"var_loss"``
              entry is already multiplied by the effective beta.
        """
        recon_loss, var_loss = self._compute_losses(model_output, targets)

        # If we are pretraining, total_epochs is passed in; otherwise we use
        # 'epochs' from the config. (Truthiness check: None and 0 both fall back.)
        calc_epochs: int = total_epochs if total_epochs else self.config.epochs

        if self.config.anneal_function == "no-annealing":
            # Use constant beta.
            anneal_factor = 1.0
            effective_beta = self.config.beta
        else:
            anneal_factor = self.annealing_scheduler.get_weight(
                epoch_current=epoch,
                total_epoch=calc_epochs,
                func=self.config.anneal_function,
            )
            effective_beta = self.config.beta * anneal_factor

        weighted_var_loss = effective_beta * var_loss
        total_loss = recon_loss + weighted_var_loss

        return total_loss, {
            "recon_loss": recon_loss,
            "var_loss": weighted_var_loss,
            "anneal_factor": self._to_log_tensor(anneal_factor),
            "effective_beta_factor": self._to_log_tensor(effective_beta),
        }