Coverage for src / autoencodix / modeling / _varix_architecture.py: 94%

64 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Optional, Union, Tuple, Dict 

2 

3import torch 

4import torch.nn as nn 

5 

6from autoencodix.base._base_autoencoder import BaseAutoencoder 

7from autoencodix.utils._model_output import ModelOutput 

8from autoencodix.configs.default_config import DefaultConfig 

9 

10from ._layer_factory import LayerFactory 

11 

12 

class VarixArchitecture(BaseAutoencoder):
    """Variational autoencoder with separately constructed encoder and decoder.

    The encoder maps the input to a hidden representation. Two linear heads
    (``_mu`` and ``_logvar``) turn that representation into the mean and
    log-variance of a diagonal Gaussian latent distribution. A latent sample is
    drawn with the reparameterization trick and passed through the decoder to
    reconstruct the input.

    Attributes:
        input_dim: Number of input features.
        _config: Configuration object containing model architecture parameters.
        _encoder: Hidden encoder layers. This is an empty ``nn.Sequential``
            (identity) when ``n_layers == 0``.
        _decoder: Decoder network mapping latent vectors back to input space.
        _mu: Linear layer computing the mean of the latent distribution.
        _logvar: Linear layer computing the log-variance of the latent
            distribution.
    """

    # Bounds applied to the predicted log-variance in ``encode``.
    # The upper bound keeps exp(logvar) finite in ``reparameterize``.
    # The lower bound is negative so the posterior can become narrower than
    # the N(0, I) prior (std can shrink to exp(-10) ~ 4.5e-5).
    # A positive lower bound would force the posterior std to be >= 1.
    _LOGVAR_MIN: float = -20.0
    _LOGVAR_MAX: float = 20.0

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: int,
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ) -> None:
        """Initialize the variational autoencoder with the given configuration.

        Args:
            config: Configuration object containing model parameters. If None,
                a ``DefaultConfig()`` is used.
            input_dim: Number of input features.
            ontologies: Not used by this architecture. Presumably accepted so
                that all architectures share one constructor signature
                (TODO confirm).
            feature_order: Not used by this architecture (see ``ontologies``).
        """
        if config is None:
            config = DefaultConfig()
        self._config: DefaultConfig = config
        super().__init__(config=config, input_dim=input_dim)
        self.input_dim: int = input_dim
        self._mu: nn.Module
        self._logvar: nn.Module
        self._encoder: nn.Module
        self._decoder: nn.Module

        # Populates self._encoder, self._mu, self._logvar and self._decoder.
        self._build_network()
        self.apply(self._init_weights)

    def _build_network(self) -> None:
        """Construct the encoder, the latent heads and the decoder.

        ``LayerFactory.get_layer_dimensions`` returns the sequence of layer
        widths from ``input_dim`` to ``latent_dim``. In the encoder, the last
        transition of that sequence is realized by the ``_mu`` / ``_logvar``
        heads rather than by a regular layer. With ``n_layers == 0`` the encoder
        is empty and the heads read the raw input directly. The decoder mirrors
        the full dimension sequence in reverse.
        """
        enc_dim = LayerFactory.get_layer_dimensions(
            feature_dim=self.input_dim,
            latent_dim=self._config.latent_dim,
            n_layers=self._config.n_layers,
            enc_factor=self._config.enc_factor,
        )
        latent_dim = self._config.latent_dim

        encoder_layers = []
        if self._config.n_layers > 0:
            # Build layers for enc_dim[0] -> ... -> enc_dim[-2].
            # The final enc_dim[-2] -> latent transition is done by the heads.
            for in_features, out_features in zip(enc_dim[:-2], enc_dim[1:-1]):
                encoder_layers.extend(
                    LayerFactory.create_layer(
                        in_features=in_features,
                        out_features=out_features,
                        dropout_p=self._config.drop_p,
                        last_layer=False,  # only relevant for the decoder
                    )
                )
            head_in_features = enc_dim[-2]
        else:
            # No hidden layers: the heads map the input straight to the latent space.
            head_in_features = self.input_dim

        self._encoder = nn.Sequential(*encoder_layers)
        self._mu = nn.Linear(head_in_features, latent_dim)
        self._logvar = nn.Linear(head_in_features, latent_dim)

        # The decoder is the same for both cases: reverse the dimension sequence
        # (slicing makes a copy).
        dec_dim = enc_dim[::-1]
        n_decoder_transitions = len(dec_dim) - 1
        decoder_layers = []
        for i, (in_features, out_features) in enumerate(zip(dec_dim[:-1], dec_dim[1:])):
            decoder_layers.extend(
                LayerFactory.create_layer(
                    in_features=in_features,
                    out_features=out_features,
                    dropout_p=self._config.drop_p,
                    # Flag the final (output) layer so LayerFactory can treat it specially.
                    last_layer=i == n_decoder_transitions - 1,
                )
            )

        self._decoder = nn.Sequential(*decoder_layers)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` into the parameters of the latent Gaussian.

        Args:
            x: Input tensor with ``input_dim`` features (assumed to be the last
                dimension, as consumed by ``nn.Linear``).

        Returns:
            Tuple ``(mu, logvar)``: the mean and the log-variance of the latent
            distribution. ``logvar`` is clamped to
            ``[_LOGVAR_MIN, _LOGVAR_MAX]``.
        """
        # When n_layers == 0, _encoder is an empty nn.Sequential and returns x unchanged.
        hidden = self._encoder(x)
        # mu is returned as-is: the N(0, I) prior is symmetric, so negative means are valid.
        mu = self._mu(hidden)
        # Numeric stability: keep exp(logvar) (used in reparameterize) finite.
        logvar = torch.clamp(
            self._logvar(hidden), min=self._LOGVAR_MIN, max=self._LOGVAR_MAX
        )
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw a latent sample with the reparameterization trick.

        Computes ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.
        Sampling happens regardless of training/eval mode.

        Args:
            mu: Mean of the latent distribution.
            logvar: Log-variance of the latent distribution (same shape as ``mu``).

        Returns:
            A sampled latent tensor with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Return a sampled latent representation of the input.

        Args:
            x: Input tensor.

        Returns:
            A latent sample ``z`` drawn via ``reparameterize``. The result is
            stochastic and is not the posterior mean.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode the latent tensor ``x``.

        Args:
            x: Latent tensor.

        Returns:
            Reconstruction in input space.
        """
        return self._decoder(x)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Run encode -> reparameterize -> decode.

        Args:
            x: Input tensor.

        Returns:
            A ModelOutput containing:
            - ``reconstruction``: the reconstructed input.
            - ``latentspace``: the sampled latent ``z``.
            - ``latent_mean`` / ``latent_logvar``: the latent distribution parameters.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return ModelOutput(
            reconstruction=x_hat,
            latentspace=z,
            latent_mean=mu,
            latent_logvar=logvar,
            additional_info=None,
        )