Coverage for src/autoencodix/base/_base_autoencoder.py: 86%

37 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from abc import ABC, abstractmethod 

2from typing import Optional, Tuple, Union, Dict 

3 

4import torch 

5import torch.nn as nn 

6 

7from autoencodix.utils._model_output import ModelOutput 

8from autoencodix.configs.default_config import DefaultConfig 

9 

10 

class BaseAutoencoder(ABC, nn.Module):
    """Interface for building autoencoder models.

    Defines standard methods for encoding data to a latent space and decoding
    back to the original space. Includes a weight initialization method for
    stable training. Intended to be extended by specific autoencoder variants
    like VAE.

    Attributes:
        input_dim: Number of input features, or an input shape tuple.
        config: Configuration object containing model architecture parameters.
        _encoder: Encoder network; ``None`` until a subclass builds it in
            ``_build_network``.
        _decoder: Decoder network; ``None`` until a subclass builds it in
            ``_build_network``.
        ontologies: Ontology information, if provided (used by Ontix).
        feature_order: Feature order information, if provided (used by Ontix).
        init_args: Dict of the constructor arguments, with ``config`` already
            resolved to a concrete configuration object.
    """

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: Union[int, Tuple[int, ...]],
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ):
        """Initializes the BaseAutoencoder.

        Args:
            config: Configuration object containing model parameters.
                If None, a fresh ``DefaultConfig()`` is used.
            input_dim: Number of input features, or an input shape tuple.
            ontologies: Ontology information, if provided (used by Ontix).
            feature_order: Feature order information, if provided
                (used by Ontix).
        """
        # nn.Module.__init__ must run before any attribute assignment.
        super().__init__()
        if config is None:
            config = DefaultConfig()
        self.input_dim = input_dim
        self._encoder: Optional[nn.Module] = None
        self._decoder: Optional[nn.Module] = None
        self.config = config
        self.ontologies = ontologies
        self.feature_order = feature_order
        # Record the constructor arguments; `config` here is the resolved
        # object (DefaultConfig() when None was passed), not the raw argument.
        self.init_args = dict(
            config=config,
            input_dim=input_dim,
            ontologies=ontologies,
            feature_order=feature_order,
        )

    @abstractmethod
    def _build_network(self) -> None:
        """Builds the encoder and decoder networks for the autoencoder model.

        Populates the self._encoder and self._decoder attributes.
        This method should be implemented by subclasses to define
        the architecture of the encoder and decoder networks.
        """
        pass

    @abstractmethod
    def encode(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encodes the input into the latent space.

        Args:
            x: The input tensor to be encoded.

        Returns:
            The encoded latent space representation, or mu and logvar for VAEs.
        """
        pass

    @abstractmethod
    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the latent space representation of the input.

        Method for unification of getting a latent space between Variational
        and Vanilla Autoencoders. This method is a wrapper around the encode
        method, or the reparameterization method for VAE.

        Args:
            x: The input tensor to be encoded.

        Returns:
            The latent space representation of the input tensor.
        """
        pass

    @abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decodes the latent representation back to the input space.

        Args:
            x: The latent tensor to be decoded.

        Returns:
            The decoded tensor, reconstructed from the latent space.
        """
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Combines encoding and decoding steps for the autoencoder.

        Args:
            x: The input tensor to be processed.

        Returns:
            The reconstructed input tensor and any additional information,
            depending on the model type.
        """
        pass

    def _init_weights(self, m: nn.Module) -> None:
        """Initializes weights using Xavier uniform initialization.

        This weight initialization method helps maintain the variance of
        activations across layers, preventing gradients from vanishing or
        exploding during training. Only ``nn.Linear`` modules are modified:
        their weight gets Xavier-uniform values and their bias (if present)
        is set to 0.01. Any other module type is left unchanged.

        Args:
            m: The module to initialize.
        """
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        # nn.Linear(..., bias=False) stores None in m.bias, so an
        # unconditional fill would raise AttributeError for such layers.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)