Coverage for src/autoencodix/losses/vanillix_loss.py: 100%
8 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2from typing import Dict, Optional, Tuple
4from autoencodix.base._base_loss import BaseLoss
5from autoencodix.utils._model_output import ModelOutput
class VanillixLoss(BaseLoss):
    """Loss for the vanilla (non-variational) autoencoder.

    The loss has a single term: the reconstruction loss computed by
    ``self.recon_loss``. That attribute is not defined in this class. It is
    presumably set up by ``BaseLoss`` from the config (e.g. BCE or MSE); verify
    against the base class.
    """

    def forward(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        epoch: Optional[int] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the reconstruction loss between the model output and the targets.

        Args:
            model_output: Container for the model's outputs. Only its
                ``reconstruction`` attribute is read here.
            targets: Original data that the reconstruction is compared against.
            epoch: Not used by this loss.
            **kwargs: Additional keyword arguments; not used by this loss.

        Returns:
            A tuple ``(total_loss, components)``. ``total_loss`` is the
            reconstruction loss. ``components`` maps ``"recon_loss"`` to that
            same tensor object.
        """
        reconstruction = model_output.reconstruction
        recon_term = self.recon_loss(reconstruction, targets)
        # Only one loss term exists, so the total is the reconstruction term itself.
        components = {"recon_loss": recon_term}
        return recon_term, components