Coverage for src / autoencodix / modeling / _maskix_architecture.py: 27%

60 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2import torch.nn as nn 

3from autoencodix.configs import DefaultConfig 

4from autoencodix.modeling._layer_factory import LayerFactory 

5from typing import Optional, Union, Tuple, Dict, List 

6from autoencodix.base._base_autoencoder import BaseAutoencoder 

7from autoencodix.utils._model_output import ModelOutput 

8 

9 

class MaskixArchitectureVanilla(BaseAutoencoder):
    """Masked autoencoder following https://doi.org/10.1093/bioinformatics/btae020

    The layout is selected through ``config.maskix_architecture``:

    * ``"scMAE"``: a fixed, hand-written layer stack. To closely mimic the
      publication it is not built with our LayerFactory as in other
      architectures.
    * ``"custom"``: encoder and decoder widths come from
      ``LayerFactory.get_layer_dimensions`` and the layers from
      ``LayerFactory.create_maskix_layer``.

    In both layouts the encoder starts with a dropout layer, a linear mask
    predictor maps the latent code to ``input_dim`` values, and the decoder
    consumes the latent code concatenated with the predicted mask
    (``latent_dim + input_dim`` input features).

    Attributes:
        input_dim: number of input features (must be an ``int``)
        latent_dim: size of the latent code, taken from ``config.latent_dim``
        _config: Configuration object containing model architecture parameters
        _encoder: Encoder network of the autoencoder
        _mask_predictor: Linear layer predicting the mask from the latent code
        _decoder: Decoder network of the autoencoder
    """

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: Union[int, Tuple[int, ...]],
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ):
        """Builds the network selected by ``config.maskix_architecture``.

        Args:
            config: model configuration; a fresh ``DefaultConfig`` is used
                when ``None`` is passed
            input_dim: number of input features, only ``int`` is supported
            ontologies: not used by this architecture
            feature_order: not used by this architecture
        Raises:
            TypeError: if ``input_dim`` is not an ``int``
            ValueError: if ``config.maskix_architecture`` is neither
                ``'scMAE'`` nor ``'custom'``
        """
        if config is None:
            config = DefaultConfig()
        self._config: DefaultConfig = config
        super().__init__(config, input_dim)
        self.input_dim: Union[int, Tuple[int, ...]] = input_dim
        if not isinstance(self.input_dim, int):
            raise TypeError(
                f"input dim needs to be int for MaskixArchitecture, got {type(self.input_dim)}"
            )
        self.latent_dim: int = self._config.latent_dim

        # populated by self._build_network()
        self._encoder: nn.Module
        self._mask_predictor: nn.Module
        self._decoder: nn.Module
        self._build_network()
        self.apply(self._init_weights)

    def _build_network(self):
        """Dispatches to the builder matching ``maskix_architecture``.

        Raises:
            ValueError: if the configured architecture name is unknown
        """
        architecture = self._config.maskix_architecture
        if architecture == "scMAE":
            self._build_scMAE()
        elif architecture == "custom":
            self._build_custom()
        else:
            # The two literals are concatenated, so the first one has to end
            # with its own sentence separator.
            raise ValueError(
                f"Got {architecture}, but expected 'scMAE' or 'custom'. "
                "This happens if you allow a new value in DefaultConfig but did not implement it here."
            )

    @staticmethod
    def _stack_maskix_layers(dims) -> List[nn.Module]:
        """Creates the maskix layers connecting consecutive widths in ``dims``.

        Args:
            dims: sequence of layer widths; every neighbouring pair becomes
                one ``LayerFactory.create_maskix_layer`` call
        Returns:
            Flat list of the created modules; the final pair is created with
            ``last_layer=True``
        """
        layers: List[nn.Module] = []
        last_index = len(dims) - 2
        for i, (in_features, out_features) in enumerate(zip(dims[:-1], dims[1:])):
            layers.extend(
                LayerFactory.create_maskix_layer(
                    in_features=in_features,
                    out_features=out_features,
                    last_layer=i == last_index,
                )
            )
        return layers

    def _build_custom(self):
        """Builds encoder, mask predictor and decoder via the LayerFactory."""
        self._mask_predictor = nn.Linear(self.latent_dim, self.input_dim)  # ty: ignore

        enc_dim = LayerFactory.get_layer_dimensions(
            feature_dim=self.input_dim,  # ty: ignore
            latent_dim=self._config.latent_dim,
            n_layers=self._config.n_layers,
            enc_factor=self._config.enc_factor,
        )
        first_layer = nn.Dropout(p=self.config.drop_p)
        # No special case for n_layers == 0: the encoder is always assembled
        # from enc_dim (a separately built encoder would be overwritten here).
        self._encoder = nn.Sequential(first_layer, *self._stack_maskix_layers(enc_dim))

        # The decoder receives the latent code concatenated with the predicted
        # mask and maps it back to the input width.
        dec_start: int = self.latent_dim + self.input_dim  # ty: ignore
        dec_end: int = self.input_dim  # ty: ignore
        dec_dim = LayerFactory.get_layer_dimensions(
            feature_dim=dec_start,
            latent_dim=dec_end,  # Repurpose 'latent_dim' param as target dim
            n_layers=self._config.n_layers,
            enc_factor=self._config.enc_factor,
        )
        self._decoder = nn.Sequential(*self._stack_maskix_layers(dec_dim))

    def _build_scMAE(self):
        """Builds the fixed, hand-written scMAE layer stack."""
        self._encoder = nn.Sequential(
            nn.Dropout(p=self.config.drop_p),
            nn.Linear(self.input_dim, self._config.maskix_hidden_dim),  # ty: ignore
            nn.LayerNorm(self._config.maskix_hidden_dim),
            nn.Mish(inplace=True),
            nn.Linear(self._config.maskix_hidden_dim, self.latent_dim),
            nn.LayerNorm(self.latent_dim),
            nn.Mish(inplace=True),
            nn.Linear(self.latent_dim, self.latent_dim),
        )

        self._mask_predictor = nn.Linear(self.latent_dim, self.input_dim)  # ty: ignore
        self._decoder = nn.Linear(
            in_features=self.latent_dim + self.input_dim,  # ty: ignore
            out_features=self.input_dim,  # ty: ignore
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encodes the input data.

        Args:
            x: input Tensor
        Returns:
            torch.Tensor

        """
        return self._encoder(x)

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the latent space representation of the input data.

        Args:
            x: input Tensor
        Returns:
            torch.Tensor

        """
        return self.encode(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decodes the latent representation.

        Args:
            x: input Tensor; ``forward`` passes the latent code concatenated
                with the predicted mask
        Returns:
            torch.Tensor

        """
        return self._decoder(x)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Encodes ``x``, predicts the mask and reconstructs the input.

        Args:
            x: input Tensor; the concatenation along ``dim=1`` presumably
                expects a (batch, features) layout - TODO confirm
        Returns:
            ModelOutput with the reconstruction, the latent code and the
            predicted mask under ``additional_info["predicted_mask"]``;
            ``latent_mean`` and ``latent_logvar`` are ``None``
        """
        latent: torch.Tensor = self.encode(x=x)
        predicted_mask: torch.Tensor = self._mask_predictor(latent)
        return ModelOutput(
            reconstruction=self.decode(torch.cat([latent, predicted_mask], dim=1)),
            latentspace=latent,
            latent_mean=None,
            latent_logvar=None,
            additional_info={"predicted_mask": predicted_mask},
        )