Coverage for src/autoencodix/losses/vanillix_loss.py: 100%

8 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2from typing import Dict, Optional, Tuple 

3 

4from autoencodix.base._base_loss import BaseLoss 

5from autoencodix.utils._model_output import ModelOutput 

6 

7 

class VanillixLoss(BaseLoss):
    """Loss for the vanilla autoencoder: a single reconstruction term."""

    def forward(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        epoch: Optional[int] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the reconstruction loss and report it as the total loss.

        The criterion is ``self.recon_loss`` (as specified in config, e.g.
        BCE or MSE). It compares the model's reconstruction against the
        targets. No other term is added, so the reconstruction loss is
        also the total loss.

        Args:
            model_output: Custom container for model outputs such as latent
                spaces and reconstructions. Only ``reconstruction`` is read.
            targets: Original data that the reconstruction is compared with.
            epoch: Not used by the Vanillix loss. Presumably accepted so all
                losses share one call signature.
            **kwargs: Additional keyword arguments; ignored here.

        Returns:
            A pair ``(total_loss, components)``. ``total_loss`` is the
            reconstruction loss. ``components`` maps ``"recon_loss"`` to that
            same tensor object.
        """
        # Resolve the criterion first, mirroring the original evaluation order.
        criterion = self.recon_loss
        reconstruction_term = criterion(model_output.reconstruction, targets)
        loss_parts = {"recon_loss": reconstruction_term}
        return reconstruction_term, loss_parts