Coverage for src / autoencodix / modeling / _imagevae_architecture.py: 29%
55 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2import torch.nn as nn
3from typing import Tuple, Optional, Union, Dict
4from autoencodix.configs.default_config import DefaultConfig
5from autoencodix.utils._model_output import ModelOutput
6from autoencodix.base._base_autoencoder import BaseAutoencoder
class ImageVAEArchitecture(BaseAutoencoder):
    """Convolutional variational autoencoder for square images.

    The encoder is a stack of five ``Conv2d`` layers and the decoder mirrors it
    with five ``ConvTranspose2d`` layers. Every (transposed) convolution uses
    ``kernel_size=4``, ``stride=2`` and ``padding=1`` (taken from
    https://github.com/uhlerlab/cross-modal-auto_encoders/tree/master).

    With these settings each convolution halves the spatial size. Applying the
    output-size formula of a convolution (assuming W == H):

        W_out = (W - kernel_size + 2 * padding) / stride + 1
              = (W - 4 + 2 * 1) / 2 + 1
              = (W - 2) / 2 + 1
              = W / 2 - 1 + 1
              = W / 2

    Conversely, each transposed convolution doubles the spatial size:

        W_out = (W - 1) * stride - 2 * padding + kernel_size = 2 * W

    Consequently a reconstruction has the same shape as the input only if the
    image is square and its side length is a multiple of ``2 ** 5 = 32``. This
    is validated when the network is built.

    Attributes:
        input_dim: the input image shape as passed to the constructor.
        config: Configuration object containing model architecture parameters.
        _encoder: Convolutional encoder network.
        _decoder: Transposed-convolutional decoder network.
        mu: Linear head producing the latent mean.
        logvar: Linear head producing the latent log-variance.
        d1: Linear + ReLU block projecting a latent vector back to the
            flattened encoder feature map.
        latent_dim: Dimension of the latent space.
        nc: number of channels in the input image (first entry of input_dim).
        h: height of the input image (second entry of input_dim).
        w: width of the input image (third entry of input_dim).
        img_shape: same tuple as input_dim, used to reshape reconstructions.
        hidden_dim: number of filters in the first convolutional layer.
        num__encoder_layers: number of ``Conv2d`` layers in the encoder.
        spatial_dim: side length of the feature map after the encoder.
    """

    def __init__(
        self,
        input_dim: Tuple[int, int, int],
        config: Optional[DefaultConfig],
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ):
        """Initialize the ImageVAEArchitecture with the given configuration.

        Args:
            input_dim: image shape; it is unpacked as (channels, height, width).
                Height and width must be equal, so the order of the two
                spatial entries is immaterial.
            config: Configuration object containing model parameters
                (``latent_dim`` and ``hidden_dim`` are read here). If None, a
                ``DefaultConfig()`` is used.
            ontologies: accepted for signature compatibility; not used by this
                class.
            feature_order: accepted for signature compatibility; not used by
                this class.

        Raises:
            ValueError: if the image is not square or its side length is not a
                positive multiple of ``2 ** num__encoder_layers``.
        """
        if config is None:
            config = DefaultConfig()
        self._config: DefaultConfig = config
        super().__init__(config=config, input_dim=input_dim)
        self.input_dim: Tuple[int, int, int] = input_dim
        self.latent_dim: int = self._config.latent_dim
        self.nc, self.h, self.w = input_dim
        self.img_shape: Tuple[int, int, int] = input_dim
        self.hidden_dim: int = self._config.hidden_dim
        self._build_network()
        self.apply(self._init_weights)

    def _build_network(self):
        """Construct the encoder, the latent heads and the decoder.

        Raises:
            ValueError: if the configured image is not square or its side
                length is not a positive multiple of the total downsampling
                factor of the encoder.
        """
        # Shared hyper-parameters of every (transposed) convolution.
        conv_kwargs = dict(kernel_size=4, stride=2, padding=1, bias=False)
        # Output channels of the five encoder convolutions, in order.
        widths = [self.hidden_dim * factor for factor in (1, 2, 4, 8, 8)]

        # The first encoder convolution has no batch norm; the others do.
        encoder_layers = [
            nn.Conv2d(self.nc, widths[0], **conv_kwargs),
            nn.LeakyReLU(0.2, inplace=False),
        ]
        for c_in, c_out in zip(widths, widths[1:]):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=False),
            ]
        self._encoder = nn.Sequential(*encoder_layers)

        # Every Conv2d halves the spatial size (see the class docstring), so
        # the encoder output side length is input side / 2 ** n_layers.
        self.num__encoder_layers = sum(
            1 for layer in self._encoder.children() if isinstance(layer, nn.Conv2d)
        )
        downscale = 2**self.num__encoder_layers

        # Fail early with a clear message: the reshapes in encode/decode/
        # translate assume a square image that the decoder can restore exactly.
        if self.h != self.w:
            raise ValueError(
                f"ImageVAEArchitecture requires square images, got "
                f"height={self.h} and width={self.w}."
            )
        if self.h < downscale or self.h % downscale:
            raise ValueError(
                f"Image side length must be a positive multiple of {downscale} "
                f"so that the reconstruction matches the input shape, "
                f"got {self.h}."
            )
        self.spatial_dim = self.h // downscale

        # Size of the flattened encoder output: channels * height * width of
        # the last feature map. (The reference implementation hard-coded a
        # spatial size of 2, which only fits 64x64 images.)
        flat_features = widths[-1] * self.spatial_dim * self.spatial_dim
        self.mu = nn.Linear(flat_features, self.latent_dim)
        self.logvar = nn.Linear(flat_features, self.latent_dim)

        # Projects a latent vector back to the flattened feature map that the
        # first transposed convolution consumes.
        self.d1 = nn.Sequential(
            nn.Linear(self.latent_dim, flat_features),
            nn.ReLU(inplace=False),
        )

        # The decoder mirrors the encoder; the last transposed convolution maps
        # back to the image channels and has no norm or activation.
        reversed_widths = widths[::-1]
        decoder_layers = []
        for c_in, c_out in zip(reversed_widths, reversed_widths[1:]):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=False),
            ]
        decoder_layers.append(nn.ConvTranspose2d(widths[0], self.nc, **conv_kwargs))
        self._decoder = nn.Sequential(*decoder_layers)

    def _get_spatial_dim(self) -> int:
        """Return the side length of the feature map after the encoder."""
        return self.spatial_dim

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images into latent mean and log-variance.

        Args:
            x: Input tensor, a batch of images matching the configured shape.

        Returns:
            Tuple ``(mu, logvar)``, each with ``latent_dim`` entries per sample.
        """
        h = self._encoder(x)
        # Flatten everything except the batch axis. Unlike ``view(-1, K)``
        # this never folds surplus elements into the batch axis: an input of
        # unexpected size raises a shape error in the linear heads instead.
        h = torch.flatten(h, start_dim=1)
        logvar = self.logvar(h)
        mu = self.mu(h)
        # Bound the log-variance for numerical stability.
        # NOTE(review): the lower bound 0.1 keeps the variance >= exp(0.1)
        # (about 1.105); confirm this floor is intended.
        logvar = torch.clamp(logvar, 0.1, 20)
        # Set every mean below 1e-6 to zero.
        # NOTE(review): this mask also matches all negative values, so the
        # returned means are never negative; confirm whether
        # ``mu.abs() < 1e-6`` was intended.
        mu = torch.where(mu < 0.000001, torch.zeros_like(mu), mu)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Reparameterization trick for VAE.

        Args:
            mu: mean of the latent distribution
            logvar: log-variance of the latent distribution

        Returns:
            z: sampled latent vector, ``mu + eps * exp(0.5 * logvar)`` with
            ``eps`` drawn from a standard normal distribution.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Return a sampled latent representation of the input.

        Args:
            x: Input tensor

        Returns:
            Latent vector sampled via ``reparameterize`` from the encoded
            mean and log-variance.
        """
        mu, logvar = self.encode(x)
        return self.reparameterize(mu, logvar)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode the latent tensor x.

        Args:
            x: Latent tensor

        Returns:
            Decoded tensor, reconstructed from the latent space
        """
        h = self.d1(x)
        # Restore the (n_filters, spatial_dim, spatial_dim) feature-map layout
        # expected by the first transposed convolution.
        h = h.view(-1, self.hidden_dim * 8, self.spatial_dim, self.spatial_dim)
        return self._decoder(h)

    def translate(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent tensor and reshape the result to image shape.

        Args:
            z: Latent tensor

        Returns:
            Reconstructed images; each sample has shape ``img_shape``.
        """
        out = self.decode(z)
        return out.view(-1, *self.img_shape)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Forward pass of the model.

        Args:
            x: Input tensor

        Returns:
            ModelOutput object containing the reconstruction, the sampled
            latent tensor, and the latent mean and log-variance.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return ModelOutput(
            reconstruction=self.translate(z),
            latentspace=z,
            latent_mean=mu,
            latent_logvar=logvar,
            additional_info=None,
        )