Coverage for src/autoencodix/modeling/_vanillix_architecture.py: 95%
38 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Optional, Union, Tuple, Dict
3import torch
4import torch.nn as nn
6from autoencodix.base._base_autoencoder import BaseAutoencoder
7from autoencodix.utils._model_output import ModelOutput
8from autoencodix.configs.default_config import DefaultConfig
10from ._layer_factory import LayerFactory
13# internal check done
14# write tests: done
class VanillixArchitecture(BaseAutoencoder):
    """Vanilla autoencoder with separately constructed encoder and decoder.

    The encoder is a chain of ``LayerFactory`` layer blocks whose widths come
    from ``LayerFactory.get_layer_dimensions``. The decoder mirrors the encoder
    by walking the same widths in reverse order.

    Attributes:
        input_dim: Number of input features.
        _config: Configuration object holding the architecture parameters
            read here (``latent_dim``, ``n_layers``, ``enc_factor``, ``drop_p``).
        _encoder: ``nn.Sequential`` mapping inputs to the latent space.
        _decoder: ``nn.Sequential`` mapping latent codes back to input space.
    """

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: int,
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ) -> None:
        """Initialize the autoencoder and build its encoder and decoder.

        Args:
            config: Configuration object containing model parameters. If
                ``None``, a fresh ``DefaultConfig()`` is used.
            input_dim: Number of input features.
            ontologies: Not used by this architecture. Presumably accepted so
                the constructor signature matches sibling architectures —
                TODO confirm.
            feature_order: Not used by this architecture (same reason as
                ``ontologies``).
        """
        if config is None:
            config = DefaultConfig()
        # Assigned before super().__init__() to keep the original
        # initialization order unchanged.
        self._config = config
        super().__init__(config, input_dim)
        self.input_dim = input_dim

        # Populate self._encoder / self._decoder, then run the _init_weights
        # hook (presumably provided by BaseAutoencoder) over every submodule.
        self._build_network()
        self.apply(self._init_weights)

    def _build_network(self) -> None:
        """Construct the encoder and the mirrored decoder.

        Sets ``self._encoder`` and ``self._decoder``.
        """
        enc_dim = LayerFactory.get_layer_dimensions(
            feature_dim=self.input_dim,
            latent_dim=self._config.latent_dim,
            n_layers=self._config.n_layers,
            enc_factor=self._config.enc_factor,
        )
        # The encoder stack is fully created (and registered) before the
        # decoder stack, matching the original layer-construction order.
        self._encoder = self._build_layer_stack(enc_dim)
        # The decoder walks the same widths in reverse; slicing makes a copy.
        self._decoder = self._build_layer_stack(enc_dim[::-1])

    def _build_layer_stack(self, dims) -> nn.Sequential:
        """Chain ``LayerFactory`` layer blocks through consecutive widths.

        Args:
            dims: Sequence of layer widths. One layer block is created for
                each consecutive pair ``(dims[i], dims[i + 1])``.

        Returns:
            An ``nn.Sequential`` containing every created layer, in order. Only
            the block for the final pair is created with ``last_layer=True``.
        """
        n_blocks = len(dims) - 1
        layers = []
        for idx, (in_features, out_features) in enumerate(zip(dims[:-1], dims[1:])):
            layers.extend(
                LayerFactory.create_layer(
                    in_features=in_features,
                    out_features=out_features,
                    dropout_p=self._config.drop_p,
                    last_layer=idx == n_blocks - 1,
                )
            )
        return nn.Sequential(*layers)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map input data to its latent representation.

        Args:
            x: Input tensor.

        Returns:
            The output of the encoder stack.
        """
        return self._encoder(x)

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Return the latent representation of the input data.

        This is identical to ``encode``.

        Args:
            x: Input tensor.

        Returns:
            The output of the encoder stack.
        """
        return self.encode(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Map a latent representation back to input space.

        Args:
            x: Latent tensor.

        Returns:
            The output of the decoder stack.
        """
        return self._decoder(x)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Encode and then decode ``x``.

        Args:
            x: Input tensor.

        Returns:
            A ``ModelOutput`` with ``reconstruction`` and ``latentspace`` set.
            ``latent_mean``, ``latent_logvar`` and ``additional_info`` are
            ``None``, because this model produces a single latent code with no
            distribution parameters.
        """
        latent = self.encode(x)
        return ModelOutput(
            reconstruction=self.decode(latent),
            latentspace=latent,
            latent_mean=None,
            latent_logvar=None,
            additional_info=None,
        )