Coverage for src / autoencodix / losses / maskix_loss.py: 95%
38 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2from torch.nn.functional import binary_cross_entropy_with_logits as bce_logits
3from torch.nn.functional import mse_loss
4from autoencodix.configs import DefaultConfig
5from autoencodix.base._base_loss import BaseLoss
6from autoencodix.utils._model_output import ModelOutput
7from typing import Tuple, Dict
class MaskixLoss(BaseLoss):
    """Loss for the Maskix architecture.

    Combines two terms:

    * a mask-prediction loss: binary cross-entropy (with logits) between the
      model's ``predicted_mask`` and the mask of corrupted entries, where an
      entry counts as corrupted when ``corrupted_input != targets``;
    * a reconstruction loss: element-wise MSE between the reconstruction and the
      uncorrupted ``targets``, weighted by ``delta_mask_corrupted`` on corrupted
      entries and ``1 - delta_mask_corrupted`` on uncorrupted ones.

    The total is
    ``delta_mask_predictor * mask_loss + (1 - delta_mask_predictor) * recon_loss_weighted``.
    """

    def __init__(self, config: DefaultConfig, annealing_scheduler=None):
        """Inits MaskixLoss.

        Args:
            config: Configuration object. This loss reads ``loss_reduction``,
                ``reconstruction_loss``, ``delta_mask_corrupted`` and
                ``delta_mask_predictor`` from it.
            annealing_scheduler: Enables passing a custom annealer class,
                defaults to our implementation of an annealer.
        """
        super().__init__(config, annealing_scheduler=annealing_scheduler)
        self._validate_loss_config()

    def _validate_model_output(self, model_output: ModelOutput) -> None:
        """Check that ``model_output`` carries the predicted mask logits.

        Args:
            model_output: Output of the Maskix forward pass.

        Raises:
            ValueError: If ``additional_info`` is None or has no
                ``"predicted_mask"`` entry.
            TypeError: If ``additional_info`` is not a dict.
        """
        additional_info = model_output.additional_info
        if additional_info is None:
            raise ValueError(
                "For Maskix, we need to provide an 'additional_info' attribute in the ModelOutput, this likely went wrong in the architecture's forward method"
            )
        if not isinstance(additional_info, dict):
            raise TypeError(
                f"The `additional_info` attribute of ModelOutput needs to be of type dict, got {type(additional_info)}"
            )
        if "predicted_mask" not in additional_info:
            raise ValueError(
                f"For Maskix, we require 'predicted_mask' to be in the additional_info attribute of ModelOutput, got: {additional_info.keys()}"
            )

    def _validate_loss_config(self) -> None:
        """Warn when the config deviates from what this loss implements.

        A reduction other than ``"mean"`` is still applied, but it differs from
        the reference implementation. Any ``reconstruction_loss`` other than
        ``"mse"`` is ignored, because ``forward`` always uses MSE.
        """
        import warnings

        if self.config.loss_reduction != "mean":
            warnings.warn(
                f"You chose loss reduction: {self.config.loss_reduction}, this deviates from the implementation in the literature for this architecture, the authors used 'mean'"
            )
        if self.config.reconstruction_loss != "mse":
            warnings.warn(
                f"You chose {self.config.reconstruction_loss}, however we support only 'mse' for Maskix, we will use 'mse'"
            )

    def forward(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        corrupted_input: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the combined Maskix loss.

        Args:
            model_output: Model output. ``reconstruction`` is compared against
                ``targets``; ``additional_info["predicted_mask"]`` holds the
                mask logits and must have the same shape as ``targets``.
            targets: Uncorrupted input the model should reconstruct.
            corrupted_input: Input after corruption. Entries that differ from
                ``targets`` are treated as masked.
            **kwargs: Ignored by this loss.

        Returns:
            A tuple ``(total_loss, components)``. ``components`` contains:

            * ``"recon_loss"``: the unweighted MSE, reduced with
              ``config.loss_reduction``;
            * ``"recon_loss_weighted"``: the weighted reconstruction term;
            * ``"mask_loss"``: the mask-prediction term.

            The last two are already scaled by their ``delta_mask_predictor``
            share, so they sum to ``total_loss``.

        Raises:
            ValueError: If ``additional_info`` is invalid, or if the predicted
                mask or the reconstruction shape does not match the targets.
            TypeError: If ``additional_info`` is not a dict.
        """
        self._validate_model_output(model_output)
        predicted_mask = model_output.additional_info["predicted_mask"]  # ty: ignore

        # An entry is labelled "masked" only if corruption actually changed its
        # value; an entry replaced by an identical value counts as unmasked.
        is_masked = (corrupted_input != targets).float()
        if predicted_mask.shape != is_masked.shape:
            raise ValueError(
                f"Shape mismatch: predicted_mask {predicted_mask.shape} vs mask {is_masked.shape}"
            )
        reconstruction = model_output.reconstruction
        # mse_loss would only warn and silently broadcast mismatched shapes,
        # yielding a wrong loss; fail loudly instead.
        if reconstruction.shape != targets.shape:
            raise ValueError(
                f"Shape mismatch: reconstruction {reconstruction.shape} vs targets {targets.shape}"
            )

        mask_loss: torch.Tensor = bce_logits(
            predicted_mask, is_masked, reduction=self.config.loss_reduction
        )

        # This is from the publication: the authors give higher incentives to
        # reconstruct corrupted features correctly, so corrupted entries get
        # weight delta_mask_corrupted and uncorrupted ones 1 - delta_mask_corrupted.
        delta_corrupted = self.config.delta_mask_corrupted
        corrupted_weight_matrix: torch.Tensor = (
            delta_corrupted * is_masked + (1 - is_masked) * (1 - delta_corrupted)
        )
        recon_loss_elementwise: torch.Tensor = mse_loss(
            reconstruction, targets, reduction="none"
        )
        # reduction_fn comes from BaseLoss (presumably derived from
        # config.loss_reduction — TODO confirm).
        recon_loss_weighted = self.reduction_fn(
            torch.mul(corrupted_weight_matrix, recon_loss_elementwise)
        )

        # Compute each scaled term once and reuse it for both the total and the
        # reported components.
        delta_predictor = self.config.delta_mask_predictor
        scaled_mask_loss = mask_loss * delta_predictor
        scaled_recon_loss_weighted = (1 - delta_predictor) * recon_loss_weighted
        total_loss: torch.Tensor = scaled_mask_loss + scaled_recon_loss_weighted

        # Unweighted reconstruction error; it is returned alongside the other
        # components but does not contribute to total_loss.
        recon_loss: torch.Tensor = mse_loss(
            reconstruction, targets, reduction=self.config.loss_reduction
        )

        return total_loss, {
            "recon_loss": recon_loss,
            "recon_loss_weighted": scaled_recon_loss_weighted,
            "mask_loss": scaled_mask_loss,
        }