Coverage for src / autoencodix / modeling / _imgfast_architecture.py: 29%

55 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2import torch.nn as nn 

3from typing import Tuple, Optional, Union, Dict 

4from autoencodix.configs.default_config import DefaultConfig 

5from autoencodix.utils._model_output import ModelOutput 

6from autoencodix.base._base_autoencoder import BaseAutoencoder 

7 

8 

class ImageVAEFastArchitecture(BaseAutoencoder):
    """Convolutional VAE for square images.

    The encoder is a stack of five ``Conv2d`` layers and the decoder a mirrored
    stack of five ``ConvTranspose2d`` layers, all with a fixed
    ``kernel_size=4``, ``stride=2`` and ``padding=1`` (configuration taken from
    https://github.com/uhlerlab/cross-modal-autoencoders). The posterior
    mean/log-variance (and the sampled latent) have shape
    ``(batch_size, latent_dim)``, and reconstructions have the same shape as the
    input images.

    Why the spatial size halves at every encoder layer: for ``Conv2d``

        W_out = floor((W + 2 * padding - kernel_size) / stride) + 1
              = floor((W + 2 - 4) / 2) + 1
              = floor((W - 2) / 2) + 1
              = floor(W / 2)

    and each ``ConvTranspose2d`` with the same settings exactly doubles it:

        W_out = (W - 1) * stride - 2 * padding + kernel_size
              = 2W - 2 - 2 + 4
              = 2W

    The decoder therefore always produces ``spatial_dim * 2**num_layers`` pixels
    per side, so the input must be square with a side length that is a positive
    multiple of ``2**num_layers`` (32); this is validated in ``_build_network``.

    Attributes:
        input_dim: (C, H, W) input image shape, as passed to the constructor.
        _config: Configuration object containing model architecture parameters.
        _encoder: Convolutional encoder network.
        _decoder: Transposed-convolutional decoder network.
        mu: Linear layer producing the posterior mean.
        logvar: Linear layer producing the posterior log-variance.
        d1: Linear + ReLU layer mapping the latent vector back to the flattened
            feature map consumed by ``_decoder``.
        latent_dim: Dimension of the latent space.
        nc: Number of channels in the input image.
        h: Height of the input image.
        w: Width of the input image.
        img_shape: (C, H, W) input image shape, used to reshape reconstructions.
        hidden_dim: Number of filters in the first convolutional layer.
        num__encoder_layers: Number of ``Conv2d`` layers in ``_encoder``.
        spatial_dim: Side length of the encoder's output feature map.
    """

    def __init__(
        self,
        input_dim: Tuple[int, int, int],  # (C, H, W) input image shape
        config: Optional[DefaultConfig],
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ):
        """Initialize the architecture from the image shape and configuration.

        Args:
            input_dim: Input image shape as (C, H, W). H and W must be equal
                and a positive multiple of 32.
            config: Configuration object; ``latent_dim`` and ``hidden_dim``
                (number of filters in the first convolutional layer) are read
                from it. If None, ``DefaultConfig()`` is used.
            ontologies: Unused by this architecture.
            feature_order: Unused by this architecture.

        Raises:
            ValueError: If the image is not square or its side length is not a
                positive multiple of ``2 ** num__encoder_layers``.
        """
        if config is None:
            config = DefaultConfig()
        self._config: DefaultConfig = config
        super().__init__(config=config, input_dim=input_dim)
        self.input_dim: Tuple[int, int, int] = input_dim
        self.latent_dim: int = self._config.latent_dim
        self.nc, self.h, self.w = input_dim
        self.img_shape: Tuple[int, int, int] = input_dim
        self.hidden_dim: int = self._config.hidden_dim
        self._build_network()
        self.apply(self._init_weights)

    def _build_network(self):
        """Construct the encoder, the latent heads and the decoder.

        Modules are created in the same order and with the same
        ``nn.Sequential`` indices as the original hand-written layout, so
        ``state_dict`` keys and parameter-initialisation order are unchanged.

        Raises:
            ValueError: If the input image shape cannot be reconstructed
                exactly by the decoder (see class docstring).
        """
        hd = self.hidden_dim

        # Encoder: nc -> hd -> 2hd -> 4hd -> 8hd -> 8hd, each Conv2d followed
        # by LeakyReLU; every layer halves the spatial size.
        enc_channels = [self.nc, hd, hd * 2, hd * 4, hd * 8, hd * 8]
        enc_layers = []
        for c_in, c_out in zip(enc_channels[:-1], enc_channels[1:]):
            enc_layers.append(
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                )
            )
            enc_layers.append(nn.LeakyReLU(0.2, inplace=True))
        self._encoder = nn.Sequential(*enc_layers)

        # The image side length is divided by 2 once per Conv2d layer.
        self.num__encoder_layers = sum(
            1 for layer in self._encoder.children() if isinstance(layer, nn.Conv2d)
        )
        downscale = 2**self.num__encoder_layers
        # The decoder always upsamples back to spatial_dim * downscale pixels
        # per side. Any other side length would make translate()'s final view()
        # fail or silently fold pixels into the batch dimension, and a
        # non-square input would do the same in encode()'s flattening view().
        if self.h != self.w or self.h < downscale or self.h % downscale != 0:
            raise ValueError(
                "ImageVAEFastArchitecture requires square images whose side "
                f"length is a positive multiple of {downscale}; got "
                f"input_dim={(self.nc, self.h, self.w)} (C, H, W)."
            )
        # The original paper hard-coded a spatial dimension of 2, which only
        # worked for 64x64 images; deriving it from the input generalises that.
        self.spatial_dim = self.h // downscale

        # The linear heads see the flattened (8*hd, spatial_dim, spatial_dim)
        # encoder output.
        flat_dim = hd * 8 * self.spatial_dim * self.spatial_dim
        self.mu = nn.Linear(flat_dim, self.latent_dim)
        self.logvar = nn.Linear(flat_dim, self.latent_dim)

        # Map the latent vector back to the size of the flattened encoder
        # output, which decode() reshapes into the first decoder input.
        self.d1 = nn.Sequential(
            nn.Linear(self.latent_dim, flat_dim),
            nn.ReLU(inplace=False),
        )

        # Decoder mirrors the encoder: 8hd -> 8hd -> 4hd -> 2hd -> hd -> nc.
        # Every ConvTranspose2d doubles the spatial size; all but the last are
        # followed by LeakyReLU.
        dec_channels = [hd * 8, hd * 8, hd * 4, hd * 2, hd, self.nc]
        dec_pairs = list(zip(dec_channels[:-1], dec_channels[1:]))
        dec_layers = []
        for idx, (c_in, c_out) in enumerate(dec_pairs):
            dec_layers.append(
                nn.ConvTranspose2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                )
            )
            if idx < len(dec_pairs) - 1:
                dec_layers.append(nn.LeakyReLU(0.2, inplace=True))
        self._decoder = nn.Sequential(*dec_layers)

    def _get_spatial_dim(self) -> int:
        """Return the side length of the encoder's output feature map."""
        return self.spatial_dim

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images into posterior parameters.

        Args:
            x: Input images, expected shape (B, C, H, W) matching ``img_shape``.

        Returns:
            Tuple ``(mu, logvar)``, each of shape (B, latent_dim).
        """
        h = self._encoder(x)
        # Flatten the (B, 8*hidden_dim, s, s) feature map into one vector per
        # sample for the linear heads.
        h = h.view(-1, self.hidden_dim * 8 * self.spatial_dim * self.spatial_dim)
        logvar = self.logvar(h)
        mu = self.mu(h)
        # Clamp the log-variance to a bounded range.
        # NOTE(review): the lower bound 0.1 forces variance >= exp(0.1) ~= 1.105,
        # i.e. wider than a N(0, 1) prior; kept unchanged, confirm it is intended.
        logvar = torch.clamp(logvar, 0.1, 20)
        # Zero out only near-zero means. The previous test `mu < 1e-6` also
        # zeroed every negative mean, restricting the posterior mean to >= 0.
        mu = torch.where(mu.abs() < 0.000001, torch.zeros_like(mu), mu)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I).

        Args:
            mu: Mean of the latent distribution.
            logvar: Log-variance of the latent distribution.

        Returns:
            Sampled latent tensor with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return a sampled latent representation.

        Args:
            x: Input images.

        Returns:
            Latent tensor sampled via :meth:`reparameterize`.
        """
        mu, logvar = self.encode(x)
        return self.reparameterize(mu, logvar)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors into image tensors.

        Args:
            x: Latent tensor of shape (B, latent_dim).

        Returns:
            Decoder output of shape (B, C, spatial_dim * 32, spatial_dim * 32).
        """
        h = self.d1(x)
        # Reshape the flat vector into (8*hidden_dim, spatial_dim, spatial_dim)
        # feature maps for the first ConvTranspose2d layer.
        h = h.view(-1, self.hidden_dim * 8, self.spatial_dim, self.spatial_dim)
        return self._decoder(h)

    def translate(self, z: torch.Tensor) -> torch.Tensor:
        """Decode ``z`` and reshape the result to the input image shape.

        Args:
            z: Latent tensor.

        Returns:
            Reconstructed images of shape (B, C, H, W).
        """
        out = self.decode(z)
        return out.view(-1, *self.img_shape)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Run encode -> reparameterize -> decode.

        Args:
            x: Input images.

        Returns:
            ModelOutput with the reconstruction, sampled latent, posterior mean
            and log-variance; ``additional_info`` is None.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return ModelOutput(
            reconstruction=self.translate(z),
            latentspace=z,
            latent_mean=mu,
            latent_logvar=logvar,
            additional_info=None,
        )

288 )