Coverage for src / autoencodix / trainers / _maskix_trainer.py: 33%

39 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2from typing import Optional, Type, Union, Tuple, Dict, Any, List 

3from autoencodix.trainers._general_trainer import GeneralTrainer 

4from autoencodix.base._base_dataset import BaseDataset 

5from autoencodix.base._base_loss import BaseLoss 

6from autoencodix.base._base_autoencoder import BaseAutoencoder 

7from autoencodix.utils._result import Result 

8from autoencodix.configs.default_config import DefaultConfig 

9 

10 

class MaskixTrainer(GeneralTrainer):
    """Specialized trainer for Maskix (masking-based) autoencoders.

    Inputs are corrupted by ``maskix_hook`` before use. By default each entry
    of a ``(batch_size, num_features)`` input is, with probability
    ``config.maskix_swap_prob``, replaced by the value of the same feature
    taken from a randomly permuted row of the same batch. A user-supplied
    ``masking_fn`` can replace this built-in corruption.

    Attributes:
        masking_fn: Optional user-provided corruption function, called as
            ``masking_fn(X, **masking_fn_kwargs)``. ``None`` selects the
            built-in Maskix swap corruption.
        masking_fn_kwargs: Extra keyword arguments forwarded to ``masking_fn``.
        Inherits all other attributes from GeneralTrainer.
    """

    def __init__(
        self,
        trainset: Optional[BaseDataset],
        validset: Optional[BaseDataset],
        result: Result,
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        ontologies: Optional[Union[Tuple, List]] = None,
        masking_fn: Optional[Any] = None,
        masking_fn_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Initializes the MaskixTrainer with the given datasets, model, and configuration.

        Args:
            trainset: The dataset used for training.
            validset: The dataset used for validation, if provided.
            result: An object to store and manage training results.
            config: Configuration object containing training hyperparameters
                and settings; ``maskix_swap_prob`` is read from it.
            model_type: The autoencoder model class to be trained.
            loss_type: The loss function class specific to the model.
            ontologies: Forwarded unchanged to GeneralTrainer.
            masking_fn: Optional callable replacing the built-in corruption.
                It receives the input tensor as its first argument and must
                return a tensor of the same shape.
            masking_fn_kwargs: Extra keyword arguments passed to
                ``masking_fn``. ``None`` (the default) means no extra kwargs.
            **kwargs: Accepted for interface compatibility; currently ignored.

        Raises:
            TypeError: If ``masking_fn`` is given but is not callable.
            ValueError: If ``masking_fn`` fails the checks performed by
                ``_validate_masking_fn``.
        """
        super().__init__(
            trainset=trainset,
            validset=validset,
            result=result,
            config=config,
            model_type=model_type,
            loss_type=loss_type,
            ontologies=ontologies,
        )
        # NOTE(review): self._config and self._model are assumed to be set by
        # GeneralTrainer.__init__ — confirm against the base class.
        # One swap probability per input feature. float() guarantees a
        # floating-point tensor (torch.bernoulli rejects integer inputs), even
        # if maskix_swap_prob is configured as an int such as 0 or 1.
        self._mask_probas = torch.full(
            (self._model.input_dim,),
            float(self._config.maskix_swap_prob),
            device=self._model.device,
        )
        self.masking_fn = masking_fn
        # None instead of a shared mutable {} default; normalize here so that
        # `**self.masking_fn_kwargs` is always valid.
        self.masking_fn_kwargs = (
            masking_fn_kwargs if masking_fn_kwargs is not None else {}
        )
        if self.masking_fn is not None:
            self._validate_masking_fn()

    def _validate_masking_fn(self) -> None:
        """Sanity-check the user-provided masking function.

        The function is called once with a random CPU tensor of shape
        ``(100, self._model.input_dim)``, i.e. with as many columns as the
        model's input, plus ``self.masking_fn_kwargs``.

        Raises:
            TypeError: If ``self.masking_fn`` is not callable.
            ValueError: If the function accepts no arguments, does not return
                a ``torch.Tensor``, or returns a tensor whose shape differs
                from its input.
        """
        import inspect

        if not callable(self.masking_fn):
            raise TypeError("The provided masking function must be callable.")

        sig = inspect.signature(self.masking_fn)
        if len(sig.parameters) < 1:
            raise ValueError(
                "The provided masking function must accept at least one argument (the input tensor)."
            )

        # Use the model's real feature count so that functions depending on
        # the number of features (e.g. per-feature parameters) are validated
        # against the width they will actually receive.
        test_input = torch.randn(100, self._model.input_dim)
        test_output = self.masking_fn(test_input, **self.masking_fn_kwargs)
        if not isinstance(test_output, torch.Tensor):
            raise ValueError(
                "The provided masking function must return a torch.Tensor as output."
            )
        if test_output.shape != test_input.shape:
            raise ValueError(
                "The output shape of the masking function must match the input shape."
            )

    def maskix_hook(self, X: torch.Tensor) -> torch.Tensor:
        """Applies the configured corruption mechanism to the input data X.

        If a user ``masking_fn`` was supplied, it is called as
        ``masking_fn(X, **masking_fn_kwargs)``. Otherwise the built-in
        Maskix swap corruption (``_maskix_hook``) is used. For each feature,
        with the probability stored in ``self._mask_probas``, the value is
        replaced by the same feature's value from a randomly permuted row of
        the batch.

        Args:
            X: Input data tensor of shape (batch_size, num_features).

        Returns:
            A corrupted version of the input data tensor.
        """
        if self.masking_fn is None:
            return self._maskix_hook(X)
        return self.masking_fn(X, **self.masking_fn_kwargs)

    def _maskix_hook(self, X: torch.Tensor) -> torch.Tensor:
        """Built-in Maskix corruption: per-feature swap with a shuffled batch.

        Args:
            X: Input tensor of shape ``(batch_size, num_features)``, where
                ``num_features`` must equal ``len(self._mask_probas)``.

        Returns:
            A tensor with the same shape as ``X``. Each entry is, with its
            feature's probability, taken from a randomly chosen row of the
            same column (the permutation may map a row to itself); otherwise
            it is left unchanged.
        """
        # Broadcast per-feature probabilities to X's shape. .to() is a no-op
        # when X already lives on the same device, and it avoids a device
        # mismatch in torch.where when it does not.
        probs = self._mask_probas.to(X.device).expand(X.shape)

        # Boolean mask: True = swap, False = keep.
        should_swap = torch.bernoulli(probs).bool()

        # Column-wise shuffling: argsort of i.i.d. uniforms along dim 0 gives
        # an independent random row permutation for every column.
        rand_indices = torch.rand(X.shape, device=X.device).argsort(dim=0)
        shuffled_X = torch.gather(X, 0, rand_indices)

        return torch.where(should_swap, shuffled_X, X)