Coverage for src / autoencodix / modeling / _ontix_architecture.py: 14%

120 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Optional, Union, Tuple 

2 

3import torch 

4import torch.nn as nn 

5 

6from autoencodix.base._base_autoencoder import BaseAutoencoder 

7from autoencodix.utils._model_output import ModelOutput 

8from autoencodix.configs.default_config import DefaultConfig 

9 

10from ._layer_factory import LayerFactory 

11 

12 

class OntixArchitecture(BaseAutoencoder):
    """Ontology-guided variational autoencoder with a sparse, masked decoder.

    The encoder is a dense network (same construction as the Varix
    architecture) followed by ``_mu``/``_logvar`` heads. The decoder is a stack
    of linear layers whose weights are clamped to be non-negative and then
    multiplied by binary masks derived from ``ontologies``, so each decoder
    unit starts out connected only to the units the ontology links it to.

    Attributes:
        input_dim: Number of input features.
        _config: Configuration object containing model architecture parameters.
        _encoder: Dense encoder network (an empty ``nn.Sequential`` when
            ``n_layers == 0``).
        _decoder: Sparse decoder network; one linear layer per mask.
        _mu: Linear layer computing the mean of the latent distribution.
        _logvar: Linear layer computing the log-variance of the latent
            distribution.
        masks: Tuple of binary weight masks for the decoder layers, one per
            ontology level. Mask ``i`` has the same ``(out_features,
            in_features)`` shape as the weight of decoder layer ``i``.
        latent_dim: Dimension of the latent space, ``masks[0].shape[1]``
            (the number of keys in the first ontology level).
        ontologies: Sequence of ontology levels. Each level is a mapping from
            a node to the names it connects to in the next level; the values
            of the last level are input feature names.
        feature_order: Order of features for input data; defines the row order
            of the last mask and hence of the reconstruction.
    """

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: int,
        ontologies: tuple,
        feature_order: list,
    ) -> None:
        """Initialize the Ontix autoencoder.

        Args:
            config: Configuration object containing model parameters. ``None``
                falls back to ``DefaultConfig()``.
            input_dim: Number of input features.
            ontologies: Ontology levels, ordered from the latent side to the
                feature side (see class docstring).
            feature_order: Order of features for input data.

        Raises:
            ValueError: If ``ontologies`` is empty, or if the number of masks
                does not match the number of decoder layers.
        """
        if config is None:
            config = DefaultConfig()
        self._config = config
        super().__init__(config, input_dim)
        self.input_dim = input_dim
        self._mu: nn.Module
        self._logvar: nn.Module
        self.ontologies = ontologies
        self.feature_order = feature_order
        # One binary mask per ontology level defines the sparse decoder wiring.
        self.masks = self._make_masks(config=self._config, feature_order=feature_order)
        # Columns of the first mask are the keys of the first ontology level,
        # which therefore make up the latent space.
        self.latent_dim = self.masks[0].shape[1]
        print("Latent Dim: " + str(self.latent_dim))
        # Populate self._encoder, self._mu, self._logvar and self._decoder.
        self._build_network()
        self.apply(self._init_weights)
        # Sparse decoder only has positive weights.
        self._decoder.apply(self._positive_dec)

        # Apply the weight masks to create the ontology-based decoder. Every
        # decoder layer needs exactly one mask.
        if len(self.masks) != len(self._decoder):
            raise ValueError(
                "Number of masks does not match number of decoder layers: "
                f"{len(self.masks)} masks vs {len(self._decoder)} decoder layers.\n"
                f"Encoder: {self._encoder}\nDecoder: {self._decoder}"
            )
        # NOTE(review): masks are applied once here; whether they are
        # re-applied after optimizer updates is not visible in this file.
        with torch.no_grad():
            for layer, mask in zip(self._decoder, self.masks):
                layer.weight.mul_(mask)

    def _make_masks(
        self, config: DefaultConfig, feature_order: list
    ) -> Tuple[torch.Tensor, ...]:
        """Create the binary decoder masks from ``self.ontologies``.

        The mask for ontology level ``x`` has shape ``(n_rows, n_keys_x)``.
        Column ``p`` is the ``p``-th key of level ``x``. Row ``f`` is the
        ``f``-th key of level ``x + 1`` or, for the last level, the ``f``-th
        entry of ``feature_order``. An entry is 1 when the key lists that row's
        name among its values. Values with no matching row are ignored.

        Args:
            config: Configuration object (not read here; kept for interface
                compatibility).
            feature_order: Order of features for input data; defines the rows
                of the last mask.

        Returns:
            Tuple containing one mask per ontology level, in decoder order.

        Raises:
            ValueError: If ``self.ontologies`` is empty.
        """
        if not self.ontologies:
            raise ValueError(
                "At least one ontology level is required to build Ontix masks"
            )

        # Feature names are all values in the last ontology level; a set makes
        # the membership checks below O(1).
        all_feature_names = set()
        for values in self.ontologies[-1].values():
            all_feature_names.update(values)
        print("Ontix checks:")
        print(f"All possible feature names length: {len(all_feature_names)}")
        print(f"Feature order length: {len(feature_order)}")
        # Report features in feature_order that no ontology node points to.
        feature_names = list(feature_order)
        missing_features = [f for f in feature_names if f not in all_feature_names]
        if missing_features:
            print(
                f"Features in feature_order not found in all_feature_names: {missing_features}"
            )
        print(f"Feature names without filtering: {len(feature_names)}")

        masks = tuple()
        last_level = len(self.ontologies) - 1
        for x, ont_dic in enumerate(self.ontologies):
            # Rows: keys of the next level, or the features in their fixed
            # order for the last level. Columns: keys of the current level.
            node_list = (
                feature_names
                if x == last_level
                else list(self.ontologies[x + 1].keys())
            )
            # Map name -> row index. setdefault keeps the first occurrence,
            # matching list.index() semantics if a name is duplicated.
            row_of = {}
            for row, node in enumerate(node_list):
                row_of.setdefault(node, row)

            mask = torch.zeros(len(node_list), len(ont_dic))
            for col, p_id in enumerate(ont_dic):
                for f_id in ont_dic[p_id]:
                    row = row_of.get(f_id)
                    if row is not None:
                        mask[row, col] = 1
            print(
                f"Mask layer {x} with shape {mask.shape} and {torch.sum(mask)} connections"
            )
            masks += (mask,)

            # torch.any (unlike torch.max) is also defined for empty masks.
            if not torch.any(mask):
                print(
                    "You provided an ontology with no connections between layers in the decoder. Please check your ontology definition."
                )

        return masks

    def _build_network(self) -> None:
        """Construct the encoder, the ``_mu``/``_logvar`` heads and the decoder.

        Handles ``n_layers == 0`` with an empty encoder, so that ``_mu`` and
        ``_logvar`` act directly on the input. The decoder gets one linear
        layer per mask, with layer sizes taken from the mask shapes.
        """
        # --- Encoder (same construction as the Varix architecture) ---
        enc_dim = LayerFactory.get_layer_dimensions(
            feature_dim=self.input_dim,
            latent_dim=self.latent_dim,
            n_layers=self._config.n_layers,
            enc_factor=self._config.enc_factor,
        )

        # Case 1: no hidden layers (direct mapping input -> latent heads).
        # These layers are created even when Case 2 replaces them. Keeping
        # that order leaves the random-number draws of layer construction
        # unchanged.
        self._encoder = nn.Sequential()
        self._mu = nn.Linear(self.input_dim, self.latent_dim)
        self._logvar = nn.Linear(self.input_dim, self.latent_dim)

        # Case 2: at least one hidden layer.
        if self._config.n_layers > 0:
            encoder_layers = []
            # The final transition (-> latent) is handled by _mu/_logvar, so
            # only the hidden transitions become encoder layers.
            for in_features, out_features in zip(enc_dim[:-2], enc_dim[1:-1]):
                encoder_layers.extend(
                    LayerFactory.create_layer(
                        in_features=in_features,
                        out_features=out_features,
                        dropout_p=self._config.drop_p,
                        last_layer=False,  # only relevant for the decoder
                    )
                )
            self._encoder = nn.Sequential(*encoder_layers)
            self._mu = nn.Linear(enc_dim[-2], self.latent_dim)
            self._logvar = nn.Linear(enc_dim[-2], self.latent_dim)

        # --- Sparse decoder: dimensions are determined by the masks ---
        # Layer i maps dec_dim[i] -> dec_dim[i + 1], matching mask i's
        # (out_features, in_features) shape.
        dec_dim = [self.latent_dim] + [mask.shape[0] for mask in self.masks]
        decoder_layers = []
        for in_features, out_features in zip(dec_dim[:-1], dec_dim[1:]):
            decoder_layers.extend(
                LayerFactory.create_layer(
                    in_features=in_features,
                    out_features=out_features,
                    dropout_p=0,  # no dropout in the sparse decoder
                    last_layer=True,  # only linear layers in the sparse decoder
                )
            )
        self._decoder = nn.Sequential(*decoder_layers)

    def _positive_dec(self, m):
        """Clamp the weights of a linear module to be non-negative.

        Meant to be passed to ``nn.Module.apply``. Modules that are not
        ``nn.Linear`` are left untouched.
        """
        if isinstance(m, nn.Linear):
            m.weight.data = m.weight.data.clamp(min=0)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the input tensor into latent mean and log-variance.

        Args:
            x: Input tensor.

        Returns:
            Tuple ``(mu, logvar)``. ``logvar`` is clamped to ``[0.01, 20]``,
            and entries of ``mu`` below ``1e-7`` are set to zero.
        """
        # With n_layers == 0 the encoder is empty and the heads see x directly.
        latent = self._encoder(x) if len(self._encoder) > 0 else x
        mu = self._mu(latent)
        logvar = self._logvar(latent)
        # Numeric stability.
        # NOTE(review): the 0.01 floor keeps the variance >= exp(0.01) ~= 1.01;
        # confirm this lower bound is intended.
        logvar = torch.clamp(logvar, 0.01, 20)
        # Zero out negative (and near-zero) means.
        mu = torch.where(mu < 0.0000001, torch.zeros_like(mu), mu)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw a latent sample with the VAE reparameterization trick.

        Args:
            mu: Mean tensor.
            logvar: Log-variance tensor.

        Returns:
            ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a latent tensor with the sparse decoder.

        Args:
            x: Latent tensor.

        Returns:
            Decoded (reconstructed) tensor.
        """
        return self._decoder(x)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Run encode -> reparameterize -> decode.

        Args:
            x: Input tensor.

        Returns:
            ModelOutput holding the reconstruction, the sampled latent tensor,
            and the latent mean and log-variance.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return ModelOutput(
            reconstruction=x_hat,
            latentspace=z,
            latent_mean=mu,
            latent_logvar=logvar,
            additional_info=None,
        )

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Return a latent representation of the input.

        The result is a reparameterized sample, not the mean, so repeated
        calls on the same input can differ.

        Args:
            x: Input tensor.

        Returns:
            Sampled latent tensor.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z