Coverage for src / autoencodix / modeling / _varix_architecture.py: 94%
64 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Optional, Union, Tuple, Dict
3import torch
4import torch.nn as nn
6from autoencodix.base._base_autoencoder import BaseAutoencoder
7from autoencodix.utils._model_output import ModelOutput
8from autoencodix.configs.default_config import DefaultConfig
10from ._layer_factory import LayerFactory
class VarixArchitecture(BaseAutoencoder):
    """Variational autoencoder with separately constructed encoder and decoder.

    The encoder body feeds two parallel linear heads that produce the mean and
    the log-variance of the latent distribution. The decoder mirrors the
    encoder's layer dimensions in reverse order.

    Attributes:
        input_dim: Number of input features.
        _config: Configuration object containing model architecture parameters.
        _encoder: Encoder body (an empty ``nn.Sequential`` when ``n_layers`` is 0).
        _decoder: Decoder network of the autoencoder.
        _mu: Linear layer computing the mean of the latent distribution.
        _logvar: Linear layer computing the log-variance of the latent
            distribution.
    """

    # Bounds applied to the predicted log-variance in ``encode``.
    _LOGVAR_MIN: float = 0.01
    _LOGVAR_MAX: float = 20
    # Latent means below this value are replaced by zero in ``encode``.
    _MU_ZERO_THRESHOLD: float = 0.0000001

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: int,
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ) -> None:
        """Initialize the variational autoencoder with the given configuration.

        Args:
            config: Configuration object containing model parameters. When
                ``None``, a fresh ``DefaultConfig()`` is used.
            input_dim: Number of input features.
            ontologies: Accepted as part of the constructor signature; not used
                by this architecture.
            feature_order: Accepted as part of the constructor signature; not
                used by this architecture.
        """
        if config is None:
            config = DefaultConfig()
        self._config: DefaultConfig = config
        super().__init__(config=config, input_dim=input_dim)
        self.input_dim: int = input_dim
        self._mu: nn.Module
        self._logvar: nn.Module
        self._encoder: nn.Module
        self._decoder: nn.Module

        # Populate self._encoder, self._mu, self._logvar and self._decoder,
        # then run the weight initializer over every submodule.
        self._build_network()
        self.apply(self._init_weights)

    def _build_network(self) -> None:
        """Construct the encoder body, the mu/logvar heads and the decoder.

        With ``n_layers == 0`` the encoder body is empty and the mu/logvar
        heads map the input features directly to the latent dimension.
        """
        latent_dim = self._config.latent_dim
        enc_dim = LayerFactory.get_layer_dimensions(
            feature_dim=self.input_dim,
            latent_dim=latent_dim,
            n_layers=self._config.n_layers,
            enc_factor=self._config.enc_factor,
        )

        # Case 1: no hidden layers (direct mapping).
        # These modules are always created first and replaced below when
        # hidden layers exist; the order is kept so that random-number
        # consumption during layer construction stays as it was.
        self._encoder = nn.Sequential()
        self._mu = nn.Linear(self.input_dim, latent_dim)
        self._logvar = nn.Linear(self.input_dim, latent_dim)

        # Case 2: at least one hidden layer.
        if self._config.n_layers > 0:
            encoder_layers = []
            # The final (hidden -> latent) transition is handled by the
            # mu/logvar heads, so the last dimension pair is left out here.
            for in_features, out_features in zip(enc_dim[:-2], enc_dim[1:-1]):
                encoder_layers.extend(
                    LayerFactory.create_layer(
                        in_features=in_features,
                        out_features=out_features,
                        dropout_p=self._config.drop_p,
                        last_layer=False,  # only relevant for the decoder
                    )
                )

            self._encoder = nn.Sequential(*encoder_layers)
            self._mu = nn.Linear(enc_dim[-2], latent_dim)
            self._logvar = nn.Linear(enc_dim[-2], latent_dim)

        # Decoder (same for both cases): encoder dimensions in reverse order.
        dec_dim = enc_dim[::-1]  # reversed copy
        final_index = len(dec_dim) - 2
        decoder_layers = []
        for i, (in_features, out_features) in enumerate(
            zip(dec_dim[:-1], dec_dim[1:])
        ):
            decoder_layers.extend(
                LayerFactory.create_layer(
                    in_features=in_features,
                    out_features=out_features,
                    dropout_p=self._config.drop_p,
                    last_layer=(i == final_index),
                )
            )

        self._decoder = nn.Sequential(*decoder_layers)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the input tensor into latent distribution parameters.

        Args:
            x: Input tensor.

        Returns:
            A ``(mu, logvar)`` tuple: the latent mean (with values below
            ``_MU_ZERO_THRESHOLD`` set to zero) and the latent log-variance
            (clamped to ``[_LOGVAR_MIN, _LOGVAR_MAX]``).
        """
        # With n_layers == 0 the heads operate on the raw input.
        latent = self._encoder(x) if self._config.n_layers > 0 else x
        mu = self._mu(latent)
        logvar = self._logvar(latent)

        # Numeric stability.
        # NOTE(review): a lower bound of 0.01 on logvar keeps the standard
        # deviation exp(0.5 * logvar) above ~1.005 - confirm this is intended.
        logvar = torch.clamp(logvar, self._LOGVAR_MIN, self._LOGVAR_MAX)
        # NOTE(review): the comparison is not on the absolute value, so every
        # negative mean is also mapped to zero - confirm this is intended.
        mu = torch.where(mu < self._MU_ZERO_THRESHOLD, torch.zeros_like(mu), mu)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Apply the VAE reparameterization trick.

        Args:
            mu: Mean of the latent distribution.
            logvar: Log-variance of the latent distribution.

        Returns:
            ``mu + eps * std`` where ``std = exp(0.5 * logvar)`` and ``eps`` is
            standard-normal noise with the same shape as ``std``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Return a sampled latent representation of the input.

        Args:
            x: Input tensor.

        Returns:
            Latent sample obtained by encoding ``x`` and reparameterizing.
        """
        mu, logvar = self.encode(x)
        return self.reparameterize(mu, logvar)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a latent tensor.

        Args:
            x: Latent tensor.

        Returns:
            Output of the decoder network for ``x``.
        """
        return self._decoder(x)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Run a full forward pass: encode, sample, decode.

        Args:
            x: Input tensor.

        Returns:
            ModelOutput holding the reconstruction, the sampled latent tensor,
            the latent mean and the latent log-variance; ``additional_info``
            is ``None``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return ModelOutput(
            reconstruction=x_hat,
            latentspace=z,
            latent_mean=mu,
            latent_logvar=logvar,
            additional_info=None,
        )