Coverage for src / autoencodix / modeling / _imagevae_architecture.py: 29%

55 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2import torch.nn as nn 

3from typing import Tuple, Optional, Union, Dict 

4from autoencodix.configs.default_config import DefaultConfig 

5from autoencodix.utils._model_output import ModelOutput 

6from autoencodix.base._base_autoencoder import BaseAutoencoder 

7 

8 

class ImageVAEArchitecture(BaseAutoencoder):
    """Convolutional variational autoencoder (VAE) for images.

    The encoder is a stack of ``Conv2d`` layers and the decoder a mirrored
    stack of ``ConvTranspose2d`` layers. Every (transposed) convolution uses
    ``kernel_size=4``, ``stride=2`` and ``padding=1`` (taken from
    https://github.com/uhlerlab/cross-modal-auto_encoders/tree/master).

    With these settings each ``Conv2d`` halves a spatial side ``W``:

        W_out = ((W - kernel_size + 2 * padding) / stride) + 1
              = ((W - 4 + 2 * 1) / 2) + 1
              = ((W - 2) / 2) + 1
              = 0.5 * W - 1 + 1
              = 0.5 * W

    and each ``ConvTranspose2d`` doubles it again. With N encoder
    convolutions the bottleneck side is therefore ``W / 2**N``, and the
    reconstruction only has the input shape when both spatial sides are
    divisible by ``2**N``. This is validated when the network is built.

    Attributes:
        input_dim: Input image shape as a 3-tuple (channels, then two spatial sizes).
        _config: Configuration object containing model architecture parameters.
        _encoder: Convolutional encoder network.
        _decoder: Transposed-convolutional decoder network.
        mu: Linear layer mapping flattened encoder features to the latent mean.
        logvar: Linear layer mapping flattened encoder features to the latent log-variance.
        d1: Linear + ReLU layer mapping a latent vector back to flattened decoder features.
        latent_dim: Dimension of the latent space.
        nc: Number of channels, ``input_dim[0]``.
        h: First spatial size, ``input_dim[1]``.
        w: Second spatial size, ``input_dim[2]``.
        img_shape: Same value as ``input_dim``; used to shape the reconstruction.
        hidden_dim: Number of filters in the first convolutional layer.
        num__encoder_layers: Number of ``Conv2d`` layers in the encoder.
        spatial_h: Bottleneck size along the first spatial axis (``h // 2**N``).
        spatial_w: Bottleneck size along the second spatial axis (``w // 2**N``).
        spatial_dim: Alias of ``spatial_h``, kept for backward compatibility.
    """

    def __init__(
        self,
        input_dim: Tuple[int, int, int],
        config: Optional[DefaultConfig],
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ):
        """Initialize the ImageVAEArchitecture with the given configuration.

        Args:
            input_dim: Image shape as (channels, spatial size 1, spatial size 2).
                Both spatial sizes must be positive multiples of ``2**N``,
                where N is the number of encoder convolutions.
            config: Configuration object containing model parameters. A
                ``DefaultConfig()`` is created when ``None`` is passed.
            ontologies: Accepted for interface compatibility; not used by this class.
            feature_order: Accepted for interface compatibility; not used by this class.

        Raises:
            ValueError: If a spatial size of ``input_dim`` is not a positive
                multiple of the encoder's total downscaling factor.
        """
        if config is None:
            config = DefaultConfig()
        # Statement order (config attribute first, then the base initializer)
        # is kept exactly as before so the base class sees the same state.
        self._config: DefaultConfig = config
        super().__init__(config=config, input_dim=input_dim)
        self.input_dim: Tuple[int, int, int] = input_dim
        self.latent_dim: int = self._config.latent_dim
        self.nc, self.h, self.w = input_dim
        self.img_shape: Tuple[int, int, int] = input_dim
        self.hidden_dim: int = self._config.hidden_dim
        self._build_network()
        # `_init_weights` is not defined in this class; presumably it is
        # provided by BaseAutoencoder -- TODO confirm.
        self.apply(self._init_weights)

    def _build_network(self):
        """Construct the encoder, the latent heads and the decoder.

        Layers are created in the same order as before (encoder, ``mu``,
        ``logvar``, ``d1``, decoder) so parameter names and random
        initialization order are unchanged.

        Raises:
            ValueError: If ``self.h`` or ``self.w`` is not a positive multiple
                of ``2**num__encoder_layers``.
        """

        def down(c_in: int, c_out: int) -> nn.Conv2d:
            # Halves each spatial side (see the class docstring).
            return nn.Conv2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            )

        def up(c_in: int, c_out: int) -> nn.ConvTranspose2d:
            # Doubles each spatial side: (W - 1) * 2 - 2 * 1 + 4 = 2 * W.
            return nn.ConvTranspose2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            )

        hd = self.hidden_dim

        self._encoder = nn.Sequential(
            down(self.nc, hd),
            nn.LeakyReLU(0.2, inplace=False),
            down(hd, hd * 2),
            nn.BatchNorm2d(hd * 2),
            nn.LeakyReLU(0.2, inplace=False),
            down(hd * 2, hd * 4),
            nn.BatchNorm2d(hd * 4),
            nn.LeakyReLU(0.2, inplace=False),
            down(hd * 4, hd * 8),
            nn.BatchNorm2d(hd * 8),
            nn.LeakyReLU(0.2, inplace=False),
            down(hd * 8, hd * 8),
            nn.BatchNorm2d(hd * 8),
            nn.LeakyReLU(0.2, inplace=False),
        )

        # Every Conv2d halves the spatial sides, so the total downscaling
        # factor is 2**(number of Conv2d layers).
        self.num__encoder_layers = sum(
            1 for layer in self._encoder.children() if isinstance(layer, nn.Conv2d)
        )
        downscale = 2**self.num__encoder_layers

        # Fail early with a clear message. Otherwise a too-small side yields
        # zero-feature Linear layers, and a non-divisible side makes the
        # decoder output differ from img_shape, which the reshape in
        # `translate` would either reject or silently fold into the batch axis.
        for axis_name, size in (("height", self.h), ("width", self.w)):
            if size < downscale or size % downscale != 0:
                raise ValueError(
                    f"ImageVAEArchitecture requires the image {axis_name} to be a "
                    f"positive multiple of {downscale}, got {size} "
                    f"(input_dim={tuple(self.input_dim)})."
                )

        # Bottleneck size per spatial axis. `spatial_dim` keeps its previous
        # value (height based) for code that reads it or `_get_spatial_dim`.
        self.spatial_h = self.h // downscale
        self.spatial_w = self.w // downscale
        self.spatial_dim = self.spatial_h

        # Number of features after flattening the encoder output:
        # channels (hidden_dim * 8) * bottleneck height * bottleneck width.
        # The original paper used a fixed spatial dimension of 2, which only
        # worked for 64x64 images.
        self._flat_dim = hd * 8 * self.spatial_h * self.spatial_w

        self.mu = nn.Linear(self._flat_dim, self.latent_dim)
        self.logvar = nn.Linear(self._flat_dim, self.latent_dim)

        # Mirror of the latent heads: latent vector -> flattened feature map
        # that is reshaped in `decode` before the first ConvTranspose2d.
        self.d1 = nn.Sequential(
            nn.Linear(self.latent_dim, self._flat_dim),
            nn.ReLU(inplace=False),
        )
        self._decoder = nn.Sequential(
            up(hd * 8, hd * 8),
            nn.BatchNorm2d(hd * 8),
            nn.LeakyReLU(0.2, inplace=False),
            up(hd * 8, hd * 4),
            nn.BatchNorm2d(hd * 4),
            nn.LeakyReLU(0.2, inplace=False),
            up(hd * 4, hd * 2),
            nn.BatchNorm2d(hd * 2),
            nn.LeakyReLU(0.2, inplace=False),
            up(hd * 2, hd),
            nn.BatchNorm2d(hd),
            nn.LeakyReLU(0.2, inplace=False),
            up(hd, self.nc),
        )

    def _get_spatial_dim(self) -> int:
        """Return the bottleneck size along the first spatial axis."""
        return self.spatial_dim

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images into latent mean and log-variance.

        Args:
            x: Batched input tensor whose per-sample shape matches ``img_shape``.
        Returns:
            Tuple ``(mu, logvar)``, each with one row per sample and
            ``latent_dim`` columns.
        """
        h = self._encoder(x)
        # Flatten everything except the batch axis. Unlike `view(-1, n)` this
        # never changes the batch size when the feature count is unexpected
        # (the Linear layers raise a shape error instead), and it also works
        # for non-contiguous activations.
        h = torch.flatten(h, start_dim=1)
        logvar = self.logvar(h)
        mu = self.mu(h)
        # logvar is restricted to the range [0.1, 20].
        # NOTE(review): this keeps the variance above exp(0.1) > 1; confirm
        # the lower bound is intended.
        logvar = torch.clamp(logvar, 0.1, 20)
        # Every mu entry below 1e-6 -- which includes all negative values --
        # is set to exactly 0.
        # NOTE(review): the previous comment described replacing tiny values
        # with 0.1; behavior is left unchanged here, confirm the intent.
        mu = torch.where(mu < 0.000001, torch.zeros_like(mu), mu)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Apply the reparameterization trick: ``z = mu + eps * std``.

        Args:
            mu: Mean of the latent distribution.
            logvar: Log-variance of the latent distribution.
        Returns:
            Sampled latent tensor with the same shape as ``mu``; ``eps`` is
            drawn from a standard normal distribution on every call.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Return a sampled latent representation of the input.

        Args:
            x: Input tensor, as accepted by ``encode``.
        Returns:
            Latent tensor obtained by encoding ``x`` and sampling with
            ``reparameterize``.
        """
        mu, logvar = self.encode(x)
        return self.reparameterize(mu, logvar)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a latent tensor into an image-shaped tensor.

        Args:
            x: Latent tensor whose last dimension is ``latent_dim``.
        Returns:
            Output of the transposed-convolution decoder.
        """
        h = self.d1(x)
        # Restore the (channels, height, width) layout expected by the first
        # ConvTranspose2d; any leading dimensions are folded into the batch axis.
        h = h.view(-1, self.hidden_dim * 8, self.spatial_h, self.spatial_w)
        return self._decoder(h)

    def translate(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent tensor and shape the result like the input images.

        Args:
            z: Latent tensor.
        Returns:
            Reconstruction of shape ``(batch, *img_shape)``.
        """
        out = self.decode(z)
        return out.view(-1, *self.img_shape)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Run a full encode -> sample -> decode pass.

        Args:
            x: Input tensor, as accepted by ``encode``.
        Returns:
            ModelOutput holding the reconstruction, the sampled latent tensor,
            the latent mean and the latent log-variance (``additional_info``
            is ``None``).
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return ModelOutput(
            reconstruction=self.translate(z),
            latentspace=z,
            latent_mean=mu,
            latent_logvar=logvar,
            additional_info=None,
        )
294 latent_logvar=logvar, 

295 additional_info=None, 

296 )