Coverage for src / autoencodix / losses / varix_loss.py: 88%

25 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2from typing import Tuple, Optional 

3 

4from autoencodix.base._base_loss import BaseLoss 

5from autoencodix.utils._model_output import ModelOutput 

6from autoencodix.configs.default_config import DefaultConfig 

7 

8 

class VarixLoss(BaseLoss):
    """Loss for a variational autoencoder exposed through the unified loss interface.

    The total loss is the reconstruction term plus the variational term
    weighted by an effective beta (``config.beta``, optionally scaled by an
    annealing factor).

    Attributes:
        config: Configuration object.
    """

    def __init__(self, config: DefaultConfig, annealing_scheduler=None):
        """Initialise VarixLoss by delegating to the base class.

        Args:
            config: Configuration object.
            annealing_scheduler: Enables passing a custom annealer class,
                defaults to our implementation of an annealer.
        """
        super().__init__(config, annealing_scheduler=annealing_scheduler)

    def _compute_losses(
        self, model_output: ModelOutput, targets: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the unweighted reconstruction and variational losses.

        Args:
            model_output: custom class that stores model output like
                latent spaces and reconstructions.
            targets: original data to compare with the reconstruction.

        Returns:
            Tuple of torch.Tensors: reconstruction loss and variational loss.
        """
        # Draw standard-normal reference samples shaped (batch_size, latent_dim).
        # NOTE(review): the shape comes from the config and the tensor is
        # created on the default device, not from the runtime latent tensor --
        # confirm the base class copes with a smaller last batch / other device.
        reference_shape = (self.config.batch_size, self.config.latent_dim)
        reference_samples = torch.randn(*reference_shape, requires_grad=False)

        reconstruction_term = self.recon_loss(model_output.reconstruction, targets)
        variational_term = self.compute_variational_loss(
            mu=model_output.latent_mean,
            logvar=model_output.latent_logvar,
            z=model_output.latentspace,
            true_samples=reference_samples,
        )
        return reconstruction_term, variational_term

    def forward(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        epoch: Optional[int] = None,
        total_epochs: Optional[int] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, dict]:
        """Forward pass with conditional annealing.

        Args:
            model_output: custom class that stores model output like
                latent spaces and reconstructions.
            targets: original data to compare with the reconstruction.
            epoch: current training epoch; only forwarded to the annealing
                scheduler when annealing is active.
            total_epochs: number of total epochs; when falsy (``None`` or
                ``0``) the value of ``config.epochs`` is used instead.
            **kwargs: accepted for interface compatibility, not used here.

        Returns:
            Tuple consisting of:
                - tensor of the total loss
                - Dict with loss_type as key and sub_loss value.
        """
        recon_loss, var_loss = self._compute_losses(model_output, targets)

        # When pretraining, the caller passes total_epochs; otherwise the
        # annealing horizon is the 'epochs' value from the config.
        configured_epochs: int = self.config.epochs
        horizon: int = total_epochs if total_epochs else configured_epochs

        anneal_mode = self.config.anneal_function
        if anneal_mode != "no-annealing":
            # Scale beta by the scheduler's weight for the current epoch.
            anneal_factor = self.annealing_scheduler.get_weight(
                epoch_current=epoch,
                total_epoch=horizon,
                func=anneal_mode,
            )
            effective_beta = self.config.beta * anneal_factor
        else:
            # Constant beta: the reported factor is fixed at 1.0.
            anneal_factor = 1.0
            effective_beta = self.config.beta

        total_loss = recon_loss + effective_beta * var_loss

        # The reported "var_loss" is the beta-weighted variational term.
        breakdown = {
            "recon_loss": recon_loss,
            "var_loss": var_loss * effective_beta,
            "anneal_factor": torch.tensor(anneal_factor),
            "effective_beta_factor": torch.tensor(effective_beta),
        }
        return total_loss, breakdown