Coverage for src / autoencodix / losses / maskix_loss.py: 95%

38 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2from torch.nn.functional import binary_cross_entropy_with_logits as bce_logits 

3from torch.nn.functional import mse_loss 

4from autoencodix.configs import DefaultConfig 

5from autoencodix.base._base_loss import BaseLoss 

6from autoencodix.utils._model_output import ModelOutput 

7from typing import Tuple, Dict 

8 

9 

class MaskixLoss(BaseLoss):
    """Loss for the Maskix architecture.

    Combines two terms, mixed by ``config.delta_mask_predictor``:

    * a mask-prediction term: binary cross-entropy with logits between the
      model's ``predicted_mask`` and the positions where ``corrupted_input``
      differs from ``targets``;
    * a reconstruction term: element-wise MSE between
      ``model_output.reconstruction`` and ``targets``. Corrupted positions are
      weighted by ``config.delta_mask_corrupted`` and uncorrupted ones by
      ``1 - config.delta_mask_corrupted``. The result is then reduced with
      ``self.reduction_fn``, which is not defined in this class and is
      presumably set up by ``BaseLoss`` (verify it matches
      ``config.loss_reduction``).
    """

    def __init__(self, config: DefaultConfig, annealing_scheduler=None):
        """Inits MaskixLoss.

        Args:
            config: Configuration object.
            annealing_scheduler: Enables passing a custom annealer class;
                defaults to our implementation of an annealer.
        """
        super().__init__(config, annealing_scheduler=annealing_scheduler)
        self._validate_loss_config()

    def _validate_model_output(self, model_output: ModelOutput):
        """Check that ``model_output`` carries the mask prediction Maskix needs.

        Args:
            model_output: Output of the architecture's forward pass.

        Raises:
            ValueError: if ``additional_info`` is ``None`` or lacks the
                ``"predicted_mask"`` key.
            TypeError: if ``additional_info`` is not a ``dict``.
        """
        additional_info = model_output.additional_info
        if additional_info is None:
            raise ValueError(
                "For Maskix, we need to provide an 'additional_info' attribute in the ModelOutput, this likely went wrong in the architecture's forward method"
            )
        if not isinstance(additional_info, dict):
            raise TypeError(
                f"The `additional_info` attribute of ModelOutput needs to be of type dict, got {type(model_output.additional_info)}"
            )
        if "predicted_mask" not in additional_info:
            raise ValueError(
                f"For Maskix, we require 'predicted_mask' to be in the additional_info attribute of ModelOutput, got: {model_output.additional_info.keys()}"
            )

    def _validate_loss_config(self):
        """Warn about config values that deviate from the Maskix setup.

        Calls ``warnings.warn`` when ``config.loss_reduction`` is not
        ``"mean"`` and when ``config.reconstruction_loss`` is not ``"mse"``.
        Neither case raises. ``forward`` always uses MSE for reconstruction,
        whatever ``config.reconstruction_loss`` says.
        """
        import warnings  # local import: keeps the module import block untouched

        if self.config.loss_reduction != "mean":
            warnings.warn(
                f"You chose loss reduction: {self.config.loss_reduction}, this deviates from the implementation in the literature for this architecture, the authors used 'mean'"
            )
        if self.config.reconstruction_loss != "mse":
            warnings.warn(
                f"You chose {self.config.reconstruction_loss}, however we support only 'mse' for Maskix, we will use 'mse'"
            )

    def forward(
        self,
        model_output: ModelOutput,
        targets: torch.Tensor,
        corrupted_input: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the Maskix loss.

        Args:
            model_output: Model output. ``additional_info["predicted_mask"]``
                must hold mask logits shaped like ``targets``, and
                ``reconstruction`` holds the reconstructed input.
            targets: Uncorrupted input the model should reconstruct.
            corrupted_input: Input actually fed to the model. Positions where
                it differs from ``targets`` are treated as corrupted.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(total_loss, components)``.

            ``total_loss`` is
            ``delta_mask_predictor * mask_loss
            + (1 - delta_mask_predictor) * recon_loss_weighted``.

            ``components`` contains:

            * ``"recon_loss"``: the unweighted MSE, reduced with
              ``config.loss_reduction``;
            * ``"recon_loss_weighted"``: the already-mixed reconstruction term;
            * ``"mask_loss"``: the already-mixed mask term.

        Raises:
            ValueError: if ``corrupted_input`` and ``targets`` differ in shape,
                or if ``predicted_mask`` does not match that shape. Also
                propagates the ``ValueError``/``TypeError`` raised by
                ``_validate_model_output``.
        """
        self._validate_model_output(model_output)
        predicted_mask = model_output.additional_info["predicted_mask"]  # ty: ignore

        # Without this check a mismatch would silently broadcast into the
        # inferred mask and only surface later as a confusing error.
        if corrupted_input.shape != targets.shape:
            raise ValueError(
                f"Shape mismatch: corrupted_input {corrupted_input.shape} vs targets {targets.shape}"
            )

        # NOTE(review): the mask is inferred from value inequality. A position
        # corrupted with a value equal to the original counts as unmasked.
        is_masked = (corrupted_input != targets).float()
        if predicted_mask.shape != is_masked.shape:
            raise ValueError(
                f"Shape mismatch: predicted_mask {predicted_mask.shape} vs mask {is_masked.shape}"
            )
        mask_loss: torch.Tensor = bce_logits(
            predicted_mask, is_masked, reduction=self.config.loss_reduction
        )

        # From the publication: corrupted features get a larger weight so the
        # model has a stronger incentive to reconstruct them correctly.
        delta_corrupted = self.config.delta_mask_corrupted
        corrupted_weight_matrix: torch.Tensor = (
            delta_corrupted * is_masked + (1 - is_masked) * (1 - delta_corrupted)
        )
        squared_error: torch.Tensor = mse_loss(
            model_output.reconstruction, targets, reduction="none"
        )
        recon_loss_weighted = self.reduction_fn(corrupted_weight_matrix * squared_error)

        # Mix both terms once and reuse them for the total and the report.
        delta_predictor = self.config.delta_mask_predictor
        weighted_mask_loss = mask_loss * delta_predictor
        weighted_recon_loss = (1 - delta_predictor) * recon_loss_weighted
        total_loss: torch.Tensor = weighted_mask_loss + weighted_recon_loss

        # Plain, unweighted reconstruction error, reported alongside the terms.
        unweighted_recon_loss: torch.Tensor = mse_loss(
            model_output.reconstruction, targets, reduction=self.config.loss_reduction
        )

        return total_loss, {
            "recon_loss": unweighted_recon_loss,
            "recon_loss_weighted": weighted_recon_loss,
            "mask_loss": weighted_mask_loss,
        }