Coverage for src / autoencodix / trainers / _maskix_trainer.py: 33%
39 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2from typing import Optional, Type, Union, Tuple, Dict, Any, List
3from autoencodix.trainers._general_trainer import GeneralTrainer
4from autoencodix.base._base_dataset import BaseDataset
5from autoencodix.base._base_loss import BaseLoss
6from autoencodix.base._base_autoencoder import BaseAutoencoder
7from autoencodix.utils._result import Result
8from autoencodix.configs.default_config import DefaultConfig
class MaskixTrainer(GeneralTrainer):
    """Specialized trainer for Maskix (mask-based) autoencoders.

    Extends GeneralTrainer with an input-corruption step. The corruption is
    either the built-in column-wise value swap (see ``_maskix_hook``) or a
    user-supplied ``masking_fn``.

    Attributes:
        masking_fn: Optional user-provided corruption callable. When None,
            the built-in swap corruption is used.
        masking_fn_kwargs: Keyword arguments forwarded to ``masking_fn`` on
            every call (always a dict, never None).
        _mask_probas: 1-D tensor holding one swap probability per input
            feature, filled with ``config.maskix_swap_prob``.
        Inherits all other attributes from GeneralTrainer.
    """

    def __init__(
        self,
        trainset: Optional[BaseDataset],
        validset: Optional[BaseDataset],
        result: Result,
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        ontologies: Optional[Union[Tuple, List]] = None,
        masking_fn: Optional[Any] = None,
        masking_fn_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Initializes the MaskixTrainer with the given datasets, model, and configuration.

        Args:
            trainset: The dataset used for training.
            validset: The dataset used for validation, if provided.
            result: An object to store and manage training results.
            config: Configuration object containing training hyperparameters and settings.
            model_type: The autoencoder model class to be trained.
            loss_type: The loss function class specific to the model.
            ontologies: Forwarded unchanged to GeneralTrainer.
            masking_fn: Optional callable used instead of the built-in
                corruption. It is called as ``masking_fn(X, **masking_fn_kwargs)``
                and must return a tensor with the same shape as ``X``.
            masking_fn_kwargs: Optional keyword arguments for ``masking_fn``.
                None (the default) is treated as "no extra arguments".
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            TypeError: If ``masking_fn`` is given but is not callable.
            ValueError: If ``masking_fn`` fails the validation probe.
        """
        super().__init__(
            trainset=trainset,
            validset=validset,
            result=result,
            config=config,
            model_type=model_type,
            loss_type=loss_type,
            ontologies=ontologies,
        )
        # One swap probability per input feature, kept on the model's device.
        mask_probas_list: List[float] = [
            self._config.maskix_swap_prob
        ] * self._model.input_dim
        self._mask_probas = torch.tensor(mask_probas_list).to(self._model.device)
        self.masking_fn = masking_fn
        # A fresh dict per instance: avoids the shared mutable default and makes
        # an explicit ``None`` safe to unpack with ``**`` later on.
        self.masking_fn_kwargs = (
            masking_fn_kwargs if masking_fn_kwargs is not None else {}
        )
        if self.masking_fn is not None:
            self._validate_masking_fn()

    def _validate_masking_fn(self):
        """Check that the user-provided masking function maps a tensor to a same-shaped tensor.

        The function is probed once with random data that has the model's
        feature count and lives on the model's device, so the probe resembles
        the batches the function will receive during training.

        Raises:
            TypeError: If ``self.masking_fn`` is not callable.
            ValueError: If the function accepts no arguments, does not return
                a ``torch.Tensor``, or changes the shape of its input.
        """
        import inspect

        if not callable(self.masking_fn):
            raise TypeError(
                "The provided masking function must be callable."
            )

        sig = inspect.signature(self.masking_fn)
        if len(sig.parameters) < 1:
            raise ValueError(
                "The provided masking function must accept at least one argument (the input tensor)."
            )

        # Probe with (batch, input_dim) on the model device rather than a fixed
        # CPU shape, so per-feature masking functions are not falsely rejected.
        test_input = torch.randn(
            100, self._model.input_dim, device=self._model.device
        )
        test_output = self.masking_fn(test_input, **self.masking_fn_kwargs)
        if not isinstance(test_output, torch.Tensor):
            raise ValueError(
                "The provided masking function must return a torch.Tensor as output."
            )
        if test_output.shape != test_input.shape:
            raise ValueError(
                "The output shape of the masking function must match the input shape."
            )

    def maskix_hook(self, X: torch.Tensor) -> torch.Tensor:
        """Applies the Maskix corruption mechanism to the input data X.

        If a custom ``masking_fn`` was supplied, it is called with X and the
        stored keyword arguments. Otherwise the built-in corruption is used:
        for each feature in X, with a probability defined in self._mask_probas,
        the feature value is replaced by a value from another randomly
        selected sample.

        Args:
            X: Input data tensor of shape (batch_size, num_features).
        Returns:
            A corrupted version of the input data tensor.
        """
        if self.masking_fn is None:
            return self._maskix_hook(X)
        return self.masking_fn(X, **self.masking_fn_kwargs)

    def _maskix_hook(self, X: torch.Tensor) -> torch.Tensor:
        """Built-in corruption: swap entries with values from random rows of the same column.

        Args:
            X: Input data tensor whose last dimension matches the length of
                ``self._mask_probas``.
        Returns:
            A tensor shaped like X where each entry was, with its feature's
            swap probability, replaced by an entry from the same column of a
            randomly chosen row.
        """
        # Expand the per-feature probabilities to the input shape for Bernoulli
        # sampling; move them to X's device first so a batch that does not live
        # on the model device cannot cause a device mismatch below.
        probs = self._mask_probas.to(X.device).expand(X.shape)

        # Boolean mask (True = swap, False = keep).
        should_swap = torch.bernoulli(probs).bool()

        # Column-wise shuffling: argsort of a random matrix along dim 0 yields
        # an independent random row permutation for every column.
        rand_indices = torch.rand(X.shape, device=X.device).argsort(dim=0)

        # Reorder X along dim 0 with those per-column permutations.
        shuffled_X = torch.gather(X, 0, rand_indices)
        corrupted_X = torch.where(should_swap, shuffled_X, X)

        return corrupted_X