Coverage for src / autoencodix / base / _base_autoencoder.py: 86%
37 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from abc import ABC, abstractmethod
2from typing import Optional, Tuple, Union, Dict
4import torch
5import torch.nn as nn
7from autoencodix.utils._model_output import ModelOutput
8from autoencodix.configs.default_config import DefaultConfig
class BaseAutoencoder(ABC, nn.Module):
    """Abstract interface shared by all autoencoder models.

    Declares the methods every autoencoder variant (e.g. a VAE) must
    implement: building the networks, encoding to a latent space, decoding
    back to the input space, and the combined forward pass. Also provides a
    weight initialization helper for ``nn.Linear`` layers.

    Attributes:
        input_dim: Number of input features, or a tuple of dimensions.
        config: Configuration object containing model architecture parameters.
        _encoder: Encoder network; ``None`` until a subclass builds it.
        _decoder: Decoder network; ``None`` until a subclass builds it.
        ontologies: Ontology information, if provided for Ontix.
        feature_order: Feature order, for Ontix.
        init_args: The constructor arguments as a dict, with ``config``
            already replaced by the default configuration when ``None``
            was passed.
    """

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: Union[int, Tuple[int, ...]],
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ):
        """Initializes the BaseAutoencoder.

        Args:
            config: Configuration object containing model parameters.
                If None, a ``DefaultConfig()`` is created and used.
            input_dim: Number of input features, or a tuple of dimensions.
            ontologies: Ontology information, if provided for Ontix.
            feature_order: Feature order, for Ontix.
        """
        super().__init__()
        if config is None:
            config = DefaultConfig()
        self.input_dim = input_dim
        # The networks are not created here; subclasses populate them in
        # ``_build_network``.
        self._encoder: Optional[nn.Module] = None
        self._decoder: Optional[nn.Module] = None
        self.config = config
        self.ontologies = ontologies
        self.feature_order = feature_order
        # Snapshot of the (resolved) constructor arguments.
        # NOTE(review): presumably kept so an equivalent instance can be
        # re-created later -- confirm against callers.
        self.init_args = dict(
            config=config,
            input_dim=input_dim,
            ontologies=ontologies,
            feature_order=feature_order,
        )

    @abstractmethod
    def _build_network(self) -> None:
        """Builds the encoder and decoder networks for the autoencoder model.

        Populates the ``self._encoder`` and ``self._decoder`` attributes.
        Subclasses implement this method to define the architecture of the
        encoder and decoder networks.
        """

    @abstractmethod
    def encode(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encodes the input into the latent space.

        Args:
            x: The input tensor to be encoded.

        Returns:
            The encoded latent space representation, or mu and logvar for
            VAEs.
        """

    @abstractmethod
    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the latent space representation of the input.

        Unifies latent-space retrieval between variational and vanilla
        autoencoders: implementations wrap ``encode`` or, for a VAE, the
        reparameterization step.

        Args:
            x: The input tensor to be encoded.

        Returns:
            The latent space representation of the input tensor.
        """

    @abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decodes the latent representation back to the input space.

        Args:
            x: The latent tensor to be decoded.

        Returns:
            The decoded tensor, reconstructed from the latent space.
        """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Combines encoding and decoding steps for the autoencoder.

        Args:
            x: The input tensor to be processed.

        Returns:
            The reconstructed input tensor and any additional information,
            depending on the model type.
        """

    def _init_weights(self, m: nn.Module) -> None:
        """Initializes a linear layer with Xavier uniform weights.

        For ``nn.Linear`` modules the weight matrix is filled with Xavier
        uniform values and the bias, when the layer has one, is set to the
        constant 0.01. Xavier initialization aims to keep the variance of
        activations comparable across layers, which helps training
        stability. Modules of any other type are left untouched.

        The method takes a single module, so it can be used as the callback
        of ``nn.Module.apply`` (presumably how subclasses invoke it).

        Args:
            m: The module to initialize.
        """
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        # Layers created with ``bias=False`` have ``bias`` set to None;
        # skip them instead of failing on the attribute access.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)