Coverage for src / autoencodix / modeling / _maskix_architecture.py: 27%
60 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2import torch.nn as nn
3from autoencodix.configs import DefaultConfig
4from autoencodix.modeling._layer_factory import LayerFactory
5from typing import Optional, Union, Tuple, Dict, List
6from autoencodix.base._base_autoencoder import BaseAutoencoder
7from autoencodix.utils._model_output import ModelOutput
class MaskixArchitectureVanilla(BaseAutoencoder):
    """Masked autoencoder architecture following https://doi.org/10.1093/bioinformatics/btae020.

    Two variants are selected via ``config.maskix_architecture``:

    * ``"scMAE"``: to closely mimic the publication, this network is not built
      with our LayerFactory as in other architectures.
    * ``"custom"``: encoder and decoder of configurable depth built from
      ``LayerFactory.create_maskix_layer``.

    In both variants a linear mask predictor maps the latent representation to
    ``input_dim`` values, and the decoder consumes the concatenation of latent
    representation and predicted mask.

    Attributes:
        input_dim: Number of input features (must be an ``int``).
        latent_dim: Size of the latent representation (``config.latent_dim``).
        _config: Configuration object containing model architecture parameters.
        _encoder: Encoder network, ``input_dim`` -> ``latent_dim``.
        _mask_predictor: Linear layer, ``latent_dim`` -> ``input_dim``.
        _decoder: Decoder network, ``latent_dim + input_dim`` -> ``input_dim``.
    """

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: Union[int, Tuple[int, ...]],
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ):
        """Build the network and initialise its weights.

        Args:
            config: Model configuration; a ``DefaultConfig()`` is used when ``None``.
            input_dim: Number of input features. Must be an ``int`` for this
                architecture.
            ontologies: Not used by this architecture.
            feature_order: Not used by this architecture.

        Raises:
            TypeError: If ``input_dim`` is not an ``int``.
            ValueError: If ``config.maskix_architecture`` is not ``"scMAE"`` or
                ``"custom"``.
        """
        if config is None:
            config = DefaultConfig()
        self._config: DefaultConfig = config
        super().__init__(config, input_dim)
        self.input_dim: Union[int, Tuple[int, ...]] = input_dim
        if not isinstance(self.input_dim, int):
            raise TypeError(
                f"input dim needs to be int for MaskixArchitecture, got {type(self.input_dim)}"
            )
        self.latent_dim: int = self._config.latent_dim

        # populated by _build_network (together with self._mask_predictor)
        self._encoder: nn.Module
        self._decoder: nn.Module
        self._build_network()
        # _init_weights is not defined in this class (presumably inherited from
        # BaseAutoencoder); it is applied to every submodule.
        self.apply(self._init_weights)

    def _build_network(self) -> None:
        """Dispatch to the builder for ``config.maskix_architecture``.

        Raises:
            ValueError: If the configured architecture is neither ``"scMAE"``
                nor ``"custom"``.
        """
        architecture = self._config.maskix_architecture
        if architecture == "scMAE":
            self._build_scMAE()
        elif architecture == "custom":
            self._build_custom()
        else:
            raise ValueError(
                f"Got {architecture}, but expected 'scMAE' or 'custom'. "
                "This happens if you allow a new value in DefaultConfig but did not implement it here."
            )

    @staticmethod
    def _stack_maskix_layers(dims: List[int]) -> List[nn.Module]:
        """Chain ``LayerFactory.create_maskix_layer`` over consecutive pairs of ``dims``.

        Args:
            dims: Layer widths; pair ``(dims[i], dims[i + 1])`` becomes one layer block.

        Returns:
            Flat list of modules; the final pair is built with ``last_layer=True``.
        """
        layers: List[nn.Module] = []
        last_index = len(dims) - 2
        for i, (in_features, out_features) in enumerate(zip(dims[:-1], dims[1:])):
            layers.extend(
                LayerFactory.create_maskix_layer(
                    in_features=in_features,
                    out_features=out_features,
                    last_layer=i == last_index,
                )
            )
        return layers

    def _build_custom(self) -> None:
        """Build a LayerFactory-based encoder/decoder of configurable depth.

        The encoder maps ``input_dim`` -> ``latent_dim`` (input dropout first).
        The decoder maps ``[latent, predicted_mask]`` (``latent_dim + input_dim``)
        back to ``input_dim``.
        """
        self._mask_predictor = nn.Linear(self.latent_dim, self.input_dim)  # ty: ignore

        if self._config.n_layers == 0:
            # Shallow encoder: a single projection to the latent space. This
            # branch used to be overwritten unconditionally by the loop-based
            # encoder below; the ``else`` makes it effective.
            self._encoder = nn.Sequential(
                nn.Dropout(p=self.config.drop_p),
                nn.Linear(self.input_dim, self.latent_dim),  # ty: ignore
            )
        else:
            enc_dim = LayerFactory.get_layer_dimensions(
                feature_dim=self.input_dim,  # ty: ignore
                latent_dim=self._config.latent_dim,
                n_layers=self._config.n_layers,
                enc_factor=self._config.enc_factor,
            )
            self._encoder = nn.Sequential(
                nn.Dropout(p=self.config.drop_p),
                *self._stack_maskix_layers(enc_dim),
            )

        # The decoder consumes [latent, predicted_mask] (see ``forward``), so it
        # starts at latent_dim + input_dim and ends at input_dim; the
        # ``latent_dim`` argument of get_layer_dimensions is repurposed as the
        # target dimension.
        dec_start: int = self.latent_dim + self.input_dim  # ty: ignore
        dec_end: int = self.input_dim  # ty: ignore
        dec_dim = LayerFactory.get_layer_dimensions(
            feature_dim=dec_start,
            latent_dim=dec_end,
            n_layers=self._config.n_layers,
            enc_factor=self._config.enc_factor,
        )
        self._decoder = nn.Sequential(*self._stack_maskix_layers(dec_dim))

    def _build_scMAE(self) -> None:
        """Build the fixed network from the scMAE publication.

        Encoder: dropout -> Linear(input_dim, maskix_hidden_dim) -> LayerNorm -> Mish
        -> Linear(maskix_hidden_dim, latent_dim) -> LayerNorm -> Mish
        -> Linear(latent_dim, latent_dim).
        Mask predictor: Linear(latent_dim, input_dim).
        Decoder: a single Linear(latent_dim + input_dim, input_dim).
        """
        hidden_dim = self._config.maskix_hidden_dim
        self._encoder = nn.Sequential(
            nn.Dropout(p=self.config.drop_p),
            nn.Linear(self.input_dim, hidden_dim),  # ty: ignore
            nn.LayerNorm(hidden_dim),
            nn.Mish(inplace=True),
            nn.Linear(hidden_dim, self.latent_dim),
            nn.LayerNorm(self.latent_dim),
            nn.Mish(inplace=True),
            nn.Linear(self.latent_dim, self.latent_dim),
        )

        self._mask_predictor = nn.Linear(self.latent_dim, self.input_dim)  # ty: ignore
        self._decoder = nn.Linear(
            in_features=self.latent_dim + self.input_dim,  # ty: ignore
            out_features=self.input_dim,  # ty: ignore
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the input data.

        Args:
            x: Input tensor whose last dimension has ``input_dim`` features.

        Returns:
            Latent representation with last dimension ``latent_dim``.
        """
        return self._encoder(x)

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Return the latent space representation of the input data.

        Args:
            x: Input tensor whose last dimension has ``input_dim`` features.

        Returns:
            The output of :meth:`encode`.
        """
        return self.encode(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct the input from latent representation and predicted mask.

        Args:
            x: Concatenation of latent representation and predicted mask, as
                built in :meth:`forward`; its last dimension has
                ``latent_dim + input_dim`` features.

        Returns:
            Reconstruction with last dimension ``input_dim``.
        """
        return self._decoder(x)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Encode, predict the mask, and decode ``[latent, predicted_mask]``.

        Args:
            x: Input tensor; presumably shaped (batch, input_dim), since latent
                and predicted mask are concatenated along dim 1.

        Returns:
            ModelOutput with the reconstruction, the latent representation as
            ``latentspace``, ``latent_mean``/``latent_logvar`` set to ``None``
            and the predicted mask under ``additional_info["predicted_mask"]``.
        """
        latent: torch.Tensor = self.encode(x=x)
        predicted_mask: torch.Tensor = self._mask_predictor(latent)
        return ModelOutput(
            reconstruction=self.decode(torch.cat([latent, predicted_mask], dim=1)),
            latentspace=latent,
            latent_mean=None,
            latent_logvar=None,
            additional_info={"predicted_mask": predicted_mask},
        )