Coverage for src / autoencodix / modeling / _vanillix_architecture.py: 95%

38 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Optional, Union, Tuple, Dict 

2 

3import torch 

4import torch.nn as nn 

5 

6from autoencodix.base._base_autoencoder import BaseAutoencoder 

7from autoencodix.utils._model_output import ModelOutput 

8from autoencodix.configs.default_config import DefaultConfig 

9 

10from ._layer_factory import LayerFactory 

11 

12 

13# internal check done 

14# write tests: done 

class VanillixArchitecture(BaseAutoencoder):
    """Vanilla autoencoder with separately constructed encoder and decoder stacks.

    The decoder is built over the encoder's layer dimensions in reverse order.

    Attributes:
        input_dim: Number of input features.
        _config: Configuration object; this class reads ``latent_dim``,
            ``n_layers``, ``enc_factor`` and ``drop_p`` from it.
        _encoder: ``nn.Sequential`` holding the encoder layers.
        _decoder: ``nn.Sequential`` holding the decoder layers.
    """

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: int,
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ) -> None:
        """Initialize the vanilla autoencoder with the given configuration.

        Args:
            config: Configuration object containing model parameters. When
                ``None``, a fresh ``DefaultConfig()`` is used.
            input_dim: Number of input features.
            ontologies: Accepted but not used by this class.
                NOTE(review): presumably kept for signature parity with other
                architectures -- confirm against callers.
            feature_order: Accepted but not used by this class (see note on
                ``ontologies``).
        """
        if config is None:
            config = DefaultConfig()
        self._config = config
        super().__init__(config, input_dim)
        self.input_dim = input_dim

        # Populate self._encoder and self._decoder.
        self._build_network()
        # `_init_weights` is not defined in this class; presumably inherited
        # from BaseAutoencoder -- verify there.
        self.apply(self._init_weights)

    def _build_sequential(self, dims) -> nn.Sequential:
        """Stack the layers produced for each consecutive pair of sizes in ``dims``.

        Args:
            dims: Sliceable sequence of layer sizes; ``dims[k]`` feeds
                ``dims[k + 1]``.
        Returns:
            nn.Sequential wrapping every layer returned by
            ``LayerFactory.create_layer``, in order.
        """
        layers = []
        # Index of the final (in, out) transition; it gets ``last_layer=True``.
        final_idx = len(dims) - 2
        for idx, (in_features, out_features) in enumerate(zip(dims[:-1], dims[1:])):
            layers.extend(
                LayerFactory.create_layer(
                    in_features=in_features,
                    out_features=out_features,
                    dropout_p=self._config.drop_p,
                    last_layer=idx == final_idx,
                )
            )
        return nn.Sequential(*layers)

    def _build_network(self) -> None:
        """Construct the encoder and the mirrored decoder."""
        # Calculate layer dimensions.
        enc_dim = LayerFactory.get_layer_dimensions(
            feature_dim=self.input_dim,
            latent_dim=self._config.latent_dim,
            n_layers=self._config.n_layers,
            enc_factor=self._config.enc_factor,
        )
        # Reversed copy of the encoder dimensions drives the decoder.
        dec_dim = enc_dim[::-1]

        # Encoder layers are created before decoder layers, as before.
        self._encoder = self._build_sequential(enc_dim)
        self._decoder = self._build_sequential(dec_dim)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the input data.

        Args:
            x: input Tensor
        Returns:
            torch.Tensor produced by the encoder stack.
        """
        return self._encoder(x)

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Return the latent space representation of the input data.

        Delegates to :meth:`encode`.

        Args:
            x: input Tensor
        Returns:
            torch.Tensor
        """
        return self.encode(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode the latent representation.

        Args:
            x: input Tensor
        Returns:
            torch.Tensor produced by the decoder stack.
        """
        return self._decoder(x)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Run the forward pass: encode ``x``, then decode the latent.

        Args:
            x: input Tensor
        Returns:
            ModelOutput carrying the reconstruction and the latent; the
            ``latent_mean``, ``latent_logvar`` and ``additional_info`` fields
            are set to ``None``.
        """
        latent = self.encode(x)
        return ModelOutput(
            reconstruction=self.decode(latent),
            latentspace=latent,
            latent_mean=None,
            latent_logvar=None,
            additional_info=None,
        )

131 reconstruction=self.decode(latent), 

132 latentspace=latent, 

133 latent_mean=None, 

134 latent_logvar=None, 

135 additional_info=None, 

136 )