Coverage for src / autoencodix / modeling / _ontix_architecture.py: 14%
120 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Optional, Union, Tuple
3import torch
4import torch.nn as nn
6from autoencodix.base._base_autoencoder import BaseAutoencoder
7from autoencodix.utils._model_output import ModelOutput
8from autoencodix.configs.default_config import DefaultConfig
10from ._layer_factory import LayerFactory
13class OntixArchitecture(BaseAutoencoder):
14 """Ontology Autoencoder implementation with separate encoder and decoder construction.
    Attributes:
        input_dim: Number of input features.
        config: Configuration object containing model architecture parameters.
        _encoder: Encoder network of the autoencoder.
        _decoder: Decoder network of the autoencoder.
        _mu: Linear layer to compute the mean of the latent distribution.
        _logvar: Linear layer to compute the log-variance of the latent distribution.
        masks: Tuple of weight masks for the decoder layers, built from the ontology.
        latent_dim: Dimension of the latent space, inferred from the first mask.
        ontologies: Ontology information.
        feature_order: Order of features for input data.
28 """
def __init__(
    self,
    config: Optional[Union[None, DefaultConfig]],
    input_dim: int,
    ontologies: tuple,
    feature_order: list,
) -> None:
    """Set up the ontology-constrained autoencoder.

    Builds the decoder masks from the ontology, constructs encoder and
    decoder, initialises the weights, clamps decoder weights to be
    non-negative and finally zeroes every decoder weight that the masks
    do not allow.

    Args:
        config: Configuration object containing model parameters. A
            ``DefaultConfig()`` is created when ``None`` is passed.
        input_dim: Number of input features.
        ontologies: Ontology information used to derive the decoder masks.
        feature_order: Order of features for input data.

    Raises:
        ValueError: If the number of masks differs from the number of
            decoder layers.
    """
    cfg = DefaultConfig() if config is None else config
    self._config = cfg
    super().__init__(cfg, input_dim)

    self.input_dim = input_dim
    self._mu: nn.Module
    self._logvar: nn.Module

    # The masks define both the latent size and the decoder layout, so they
    # have to exist before the network is built.
    self.ontologies = ontologies
    self.feature_order = feature_order
    self.masks = self._make_masks(config=self._config, feature_order=feature_order)
    self.latent_dim = self.masks[0].shape[1]
    print(f"Latent Dim: {self.latent_dim}")

    # Populates self._encoder, self._mu, self._logvar and self._decoder.
    self._build_network()
    self.apply(self._init_weights)
    # Sparse decoder only has positive weights.
    self._decoder.apply(self._positive_dec)

    # One mask per decoder layer is required; dump diagnostics otherwise.
    n_masks, n_dec_layers = len(self.masks), len(self._decoder)
    if n_masks != n_dec_layers:
        print(n_masks, n_dec_layers)
        print(self._encoder)
        print(self._decoder)
        raise ValueError("Number of masks does not match number of decoder layers")

    # Apply the weight masks to create the ontology-based decoder.
    with torch.no_grad():
        for layer, mask in zip(self._decoder, self.masks):
            layer.weight.mul_(mask)
def _make_masks(
    self, config: "DefaultConfig", feature_order: list
) -> Tuple[torch.Tensor, ...]:
    """Create the weight masks for the sparse decoder from ``self.ontologies``.

    Each ontology level is a mapping ``parent -> iterable of children``.
    For level ``x`` the children are matched against the keys of level
    ``x + 1``; for the last level they are matched against
    ``feature_order``. Children that are not found are ignored, and a
    repeated node name resolves to its first position.

    Args:
        config: Configuration object containing model parameters. It is not
            read here; the parameter is kept so existing callers keep working.
        feature_order: Order of features for input data. Fixes the row order
            of the last mask.

    Returns:
        Tuple with one mask per ontology level. Mask ``x`` has shape
        ``(number of nodes in the next level, number of parents in level x)``
        and holds 1 where a parent is connected to a node, 0 elsewhere.

    Raises:
        IndexError: If ``self.ontologies`` is empty.
    """
    # All feature names reachable from the last ontology level. Kept as a set
    # so the membership test below is O(1) per feature.
    all_feature_names = set()
    for children in self.ontologies[-1].values():
        all_feature_names.update(children)

    print("Ontix checks:")
    print(f"All possible feature names length: {len(all_feature_names)}")
    print(f"Feature order length: {len(feature_order)}")

    # Features missing from the ontology are reported but not dropped; their
    # mask rows simply stay zero.
    feature_names = list(feature_order)
    missing_features = [f for f in feature_names if f not in all_feature_names]
    if missing_features:
        print(
            f"Features in feature_order not found in all_feature_names: {missing_features}"
        )
    print(f"Feature names without filtering: {len(feature_names)}")

    last_level = len(self.ontologies) - 1
    masks = []
    for x, ont_dic in enumerate(self.ontologies):
        if x == last_level:
            # Fixed sort of the feature list.
            node_list = feature_names
        else:
            node_list = list(self.ontologies[x + 1])

        # name -> row index; setdefault keeps the first occurrence, which is
        # what list.index() would return for a repeated name.
        node_index = {}
        for row, node in enumerate(node_list):
            node_index.setdefault(node, row)

        mask = torch.zeros(len(node_list), len(ont_dic))
        for col, children in enumerate(ont_dic.values()):
            for child in children:
                row = node_index.get(child)
                if row is not None:
                    mask[row, col] = 1

        print(
            f"Mask layer {x} with shape {mask.shape} and {torch.sum(mask)} connections"
        )
        masks.append(mask)

    # Warn when any level is completely disconnected. ``any()`` also works on
    # zero-element masks, where ``torch.max`` would raise.
    if any(not bool(mask.any()) for mask in masks):
        print(
            "You provided an ontology with no connections between layers in the decoder. Please check your ontology definition."
        )

    return tuple(masks)
def _build_network(self) -> None:
    """Assemble the encoder, the mu/logvar heads and the sparse decoder.

    With ``n_layers == 0`` the encoder stays an empty ``nn.Sequential`` and
    mu/logvar map the input directly to the latent space. The decoder
    layout is taken from ``self.masks`` and consists of linear layers only.
    """
    # Encoder (same scheme as the varix architecture).
    layer_sizes = LayerFactory.get_layer_dimensions(
        feature_dim=self.input_dim,
        latent_dim=self.latent_dim,
        n_layers=self._config.n_layers,
        enc_factor=self._config.enc_factor,
    )

    # Start from the direct mapping (no hidden layers); it is replaced below
    # when hidden layers are configured.
    self._encoder = nn.Sequential()
    self._mu = nn.Linear(self.input_dim, self.latent_dim)
    self._logvar = nn.Linear(self.input_dim, self.latent_dim)

    if self._config.n_layers > 0:
        hidden_layers = []
        # The final (in, out) pair is left out: mu and logvar take its place.
        for n_in, n_out in zip(layer_sizes[:-2], layer_sizes[1:-1]):
            hidden_layers.extend(
                LayerFactory.create_layer(
                    in_features=n_in,
                    out_features=n_out,
                    dropout_p=self._config.drop_p,
                    last_layer=False,  # only relevant for the decoder
                )
            )
        self._encoder = nn.Sequential(*hidden_layers)
        head_in = layer_sizes[-2]
        self._mu = nn.Linear(head_in, self.latent_dim)
        self._logvar = nn.Linear(head_in, self.latent_dim)

    # Decoder with sparse connections: widths follow the mask row counts.
    widths = [self.latent_dim]
    widths.extend(mask.shape[0] for mask in self.masks)

    sparse_layers = []
    for n_in, n_out in zip(widths, widths[1:]):
        sparse_layers.extend(
            LayerFactory.create_layer(
                in_features=n_in,
                out_features=n_out,
                dropout_p=0,  # no dropout in the sparse decoder
                last_layer=True,  # only linear layers in the sparse decoder
            )
        )
    self._decoder = nn.Sequential(*sparse_layers)
def _positive_dec(self, m):
    """Clamp the weights of a linear layer to be non-negative.

    Intended for use with ``nn.Module.apply``; modules that are not
    ``nn.Linear`` are left untouched.

    Args:
        m: Module to process.
    """
    if not isinstance(m, nn.Linear):
        return
    non_negative = m.weight.data.clamp(min=0)
    m.weight.data = non_negative
def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Map the input to the parameters of the latent distribution.

    Args:
        x: Input tensor.

    Returns:
        Tuple ``(mu, logvar)``. ``logvar`` is clamped to ``[0.01, 20]`` and
        every entry of ``mu`` below ``1e-7`` is replaced by zero.
    """
    # An empty encoder means n_layers == 0: feed the input straight to the heads.
    hidden = self._encoder(x) if len(self._encoder) > 0 else x
    mean = self._mu(hidden)
    # Clamp for numeric stability.
    log_variance = torch.clamp(self._logvar(hidden), 0.01, 20)
    mean = torch.where(mean < 1e-7, torch.zeros_like(mean), mean)
    return mean, log_variance
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Draw a latent sample with the VAE reparameterization trick.

    Args:
        mu: Mean tensor.
        logvar: Log-variance tensor.

    Returns:
        ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from a standard
        normal distribution of the same shape.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma
def decode(self, x: torch.Tensor) -> torch.Tensor:
    """Run a latent tensor through the decoder network.

    Args:
        x: Latent tensor.

    Returns:
        Decoded tensor.
    """
    reconstruction = self._decoder(x)
    return reconstruction
def forward(self, x: torch.Tensor) -> ModelOutput:
    """Run the full pass: encode, sample a latent vector and decode it.

    Args:
        x: Input tensor.

    Returns:
        ModelOutput holding the reconstruction, the sampled latent tensor,
        the latent mean and the latent log-variance.
    """
    latent_mean, latent_logvar = self.encode(x)
    sample = self.reparameterize(latent_mean, latent_logvar)
    return ModelOutput(
        reconstruction=self.decode(sample),
        latentspace=sample,
        latent_mean=latent_mean,
        latent_logvar=latent_logvar,
        additional_info=None,
    )
def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
    """Return the latent space representation of the input.

    The result is a sample produced by ``reparameterize`` from the encoded
    mean and log-variance, not the mean itself.

    Args:
        x: Input tensor.

    Returns:
        Latent space representation.
    """
    return self.reparameterize(*self.encode(x))