Coverage for src / autoencodix / modeling / _imgfast_architecture.py: 29%
55 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2import torch.nn as nn
3from typing import Tuple, Optional, Union, Dict
4from autoencodix.configs.default_config import DefaultConfig
5from autoencodix.utils._model_output import ModelOutput
6from autoencodix.base._base_autoencoder import BaseAutoencoder
class ImageVAEFastArchitecture(BaseAutoencoder):
    """Convolutional variational autoencoder (VAE) for square images.

    The model encodes an image batch into a latent mean and log-variance of
    shape ``(batch, latent_dim)``, samples a latent vector with the
    reparameterization trick and decodes it back to an image with the same
    shape as the input.

    Every (transposed) convolution uses a fixed ``kernel_size=4``,
    ``stride=2`` and ``padding=1`` (taken from
    https://github.com/uhlerlab/cross-modal-auto_encoders/tree/master).
    With these settings each ``Conv2d`` halves the spatial size. Using the
    output-size formula

        W_out = (W - kernel_size + 2 * padding) / stride + 1

    we get

        W_out = (W - 4 + 2 * 1) / 2 + 1
              = (W - 2) / 2 + 1
              = 0.5 * W - 1 + 1
              = 0.5 * W

    and, symmetrically, each ``ConvTranspose2d`` doubles it. After ``N``
    encoder convolutions the spatial size is therefore ``W / 2**N``. For the
    decoder to restore the exact input size, the image must be square and
    its side length must be a positive multiple of ``2**N`` (32 for the five
    layers built here); this is validated when the network is built.

    Attributes:
        input_dim: The input image shape as a 3-tuple. It is unpacked as
            ``(nc, h, w)``; because the image must be square the order of
            the two spatial entries does not matter.
        _config: Configuration object containing model architecture
            parameters.
        _encoder: Convolutional encoder network.
        _decoder: Transposed-convolutional decoder network.
        mu: Linear layer producing the latent mean.
        logvar: Linear layer producing the latent log-variance.
        d1: Linear layer (plus ReLU) mapping a latent vector back to the
            flattened encoder feature map.
        latent_dim: Dimension of the latent space.
        nc: Number of channels in the input image.
        h: Height of the input image.
        w: Width of the input image.
        img_shape: Same tuple as ``input_dim``; used to shape the output.
        hidden_dim: Number of filters in the first convolutional layer.
        num__encoder_layers: Number of ``Conv2d`` layers in the encoder.
        spatial_dim: Side length of the feature map after the encoder.
    """

    def __init__(
        self,
        input_dim: Tuple[int, int, int],
        config: Optional[DefaultConfig],
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ):
        """Initialize the architecture and build all layers.

        Args:
            input_dim: Shape of a single input image, unpacked as
                ``(channels, height, width)``.
            config: Configuration object containing model parameters
                (``latent_dim`` and ``hidden_dim`` are read here). If
                ``None``, a fresh ``DefaultConfig()`` is used.
            ontologies: Not used by this class. NOTE(review): presumably
                accepted for signature parity with other architectures;
                confirm against the callers.
            feature_order: Not used by this class (see ``ontologies``).

        Raises:
            ValueError: If the image is not square or its side length is not
                a positive multiple of ``2**num__encoder_layers``.
        """
        if config is None:
            config = DefaultConfig()
        self._config: DefaultConfig = config
        super().__init__(config=config, input_dim=input_dim)
        self.input_dim: Tuple[int, int, int] = input_dim
        self.latent_dim: int = self._config.latent_dim
        self.nc, self.h, self.w = input_dim
        self.img_shape: Tuple[int, int, int] = input_dim
        self.hidden_dim: int = self._config.hidden_dim
        self._build_network()
        # `_init_weights` is not defined in this class; it is expected to be
        # provided by the base class.
        self.apply(self._init_weights)

    def _build_network(self):
        """Construct the encoder, the latent heads and the decoder.

        Raises:
            ValueError: If the configured image shape cannot be reconstructed
                exactly by this architecture (non-square, or side length not
                a positive multiple of ``2**num__encoder_layers``).
        """
        base = self.hidden_dim
        # Shared hyper-parameters of every (transposed) convolution; they are
        # what makes each layer halve / double the spatial size.
        conv_kwargs = dict(kernel_size=4, stride=2, padding=1, bias=False)

        # Encoder: nc -> base -> 2*base -> 4*base -> 8*base -> 8*base, each
        # convolution followed by a LeakyReLU.
        encoder_channels = [self.nc, base, base * 2, base * 4, base * 8, base * 8]
        encoder_layers = []
        for c_in, c_out in zip(encoder_channels[:-1], encoder_channels[1:]):
            encoder_layers.append(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, **conv_kwargs)
            )
            encoder_layers.append(nn.LeakyReLU(0.2, inplace=True))
        self._encoder = nn.Sequential(*encoder_layers)

        # The spatial size halves after every Conv2d layer (see the class
        # docstring), so the size after the encoder is in_size / 2**N_layers.
        self.num__encoder_layers = sum(
            isinstance(layer, nn.Conv2d) for layer in self._encoder.children()
        )
        downscale = 2**self.num__encoder_layers

        # Without this check a non-square or badly sized image would not fail
        # here: the `view(-1, ...)` calls further down would silently fold
        # the surplus (or missing) pixels into the batch dimension.
        if self.h != self.w or self.h <= 0 or self.h % downscale != 0:
            raise ValueError(
                f"{type(self).__name__} requires square images whose side "
                f"length is a positive multiple of {downscale}; got "
                f"input_dim={(self.nc, self.h, self.w)}."
            )
        self.spatial_dim = self.h // downscale

        # The linear heads work on the flattened encoder output, i.e. on
        # (number of filters) * (feature-map height) * (feature-map width).
        # The original paper used a fixed spatial size of 2, which only
        # worked for 64x64 images.
        flat_features = base * 8 * self.spatial_dim * self.spatial_dim
        self.mu = nn.Linear(flat_features, self.latent_dim)
        self.logvar = nn.Linear(flat_features, self.latent_dim)

        # The first decoder stage mirrors the heads: it maps a latent vector
        # back to the flattened shape of the last encoder feature map.
        self.d1 = nn.Sequential(
            nn.Linear(self.latent_dim, flat_features),
            nn.ReLU(inplace=False),
        )

        # Decoder: 8*base -> 8*base -> 4*base -> 2*base -> base -> nc.
        decoder_channels = [base * 8, base * 8, base * 4, base * 2, base, self.nc]
        decoder_layers = []
        for c_in, c_out in zip(decoder_channels[:-1], decoder_channels[1:]):
            decoder_layers.append(
                nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, **conv_kwargs)
            )
            decoder_layers.append(nn.LeakyReLU(0.2, inplace=True))
        # The final transposed convolution has no activation after it.
        decoder_layers.pop()
        self._decoder = nn.Sequential(*decoder_layers)

    def _get_spatial_dim(self) -> int:
        """Return the side length of the feature map after the encoder."""
        return self.spatial_dim

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode an image batch into latent mean and log-variance.

        Args:
            x: Input image tensor.

        Returns:
            A tuple ``(mu, logvar)``, each with ``latent_dim`` entries per
            sample.

        Raises:
            ValueError: If the encoder feature map does not have the shape
                the linear heads were built for (i.e. ``x`` does not match
                the ``input_dim`` given at construction).
        """
        h = self._encoder(x)

        # Guard the flattening below: with a mismatching image size,
        # `reshape(-1, ...)` would silently change the batch dimension
        # instead of raising.
        expected = (self.hidden_dim * 8, self.spatial_dim, self.spatial_dim)
        if tuple(h.shape[-3:]) != expected:
            raise ValueError(
                f"Encoder produced a feature map of shape {tuple(h.shape[-3:])}, "
                f"expected {expected}; the input does not match "
                f"input_dim={self.img_shape}."
            )

        # Flatten channels and spatial positions into one feature axis so the
        # linear heads see one row per sample. `reshape` gives the same
        # result as `view` for contiguous tensors and additionally supports
        # non-contiguous ones.
        h = h.reshape(-1, self.hidden_dim * 8 * self.spatial_dim * self.spatial_dim)
        logvar = self.logvar(h)
        mu = self.mu(h)

        # Clamp the log-variance to [0.1, 20]; the upper bound keeps
        # exp(0.5 * logvar) in `reparameterize` finite.
        # NOTE(review): the lower bound of 0.1 forces the variance to be
        # greater than 1 - confirm that this is intended.
        logvar = torch.clamp(logvar, 0.1, 20)
        # Set every mean entry smaller than 1e-6 to zero.
        # NOTE(review): this also zeroes all negative means, so the latent
        # mean is constrained to be non-negative. An earlier comment spoke of
        # replacing near-zero values with 0.1; confirm whether
        # `mu.abs() < 1e-6` was intended. Behaviour is kept unchanged here.
        mu = torch.where(mu < 0.000001, torch.zeros_like(mu), mu)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample a latent vector with the reparameterization trick.

        Args:
            mu: Mean of the latent distribution.
            logvar: Log-variance of the latent distribution.

        Returns:
            ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps``
            drawn from a standard normal distribution.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Return a sampled latent representation of the input.

        Args:
            x: Input image tensor.

        Returns:
            Latent vector sampled from the encoded distribution (the result
            is stochastic because it goes through ``reparameterize``).
        """
        mu, logvar = self.encode(x)
        return self.reparameterize(mu, logvar)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a latent tensor into an image tensor.

        Args:
            x: Latent tensor whose last dimension is ``latent_dim``.

        Returns:
            Output of the transposed-convolutional decoder.
        """
        h = self.d1(x)
        # Mirror of the flattening in `encode`: restore the
        # (n_filters, reduced_img_dim, reduced_img_dim) layout that the first
        # ConvTranspose2d layer expects.
        h = h.view(-1, self.hidden_dim * 8, self.spatial_dim, self.spatial_dim)
        return self._decoder(h)

    def translate(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent tensor and shape the result like the input images.

        Args:
            z: Latent tensor.

        Returns:
            Reconstructed images of shape ``(-1, *img_shape)``.
        """
        out = self.decode(z)
        return out.view(-1, *self.img_shape)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Run a full encode / sample / decode pass.

        Args:
            x: Input image tensor.

        Returns:
            ModelOutput holding the reconstruction, the sampled latent
            vector, the latent mean and the latent log-variance.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return ModelOutput(
            reconstruction=self.translate(z),
            latentspace=z,
            latent_mean=mu,
            latent_logvar=logvar,
            additional_info=None,
        )