Coverage for src/driada/dim_reduction/neural.py: 92.74%

124 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import torch 

2import torch.nn as nn 

3import torch.nn.functional as F 

4import torch.optim as optim 

5# import torchvision 

6from torch.utils.data import Dataset, DataLoader 

7 

8 

class Encoder(nn.Module):
    """Two-layer fully connected encoder whose code is squashed by a sigmoid.

    The input is projected to ``inter_dim`` units, passed through dropout and
    a leaky ReLU, projected to ``code_dim`` units and finally passed through a
    sigmoid.

    Args:
        orig_dim: Size of the last dimension of the input features.
        inter_dim: Width of the hidden layer.
        code_dim: Size of the produced code.
        kwargs: Mapping of optional settings (``None`` is treated as empty).
            Only ``'dropout'`` is read: a rate satisfying ``0 <= rate < 1``;
            a missing key or ``None`` disables dropout.
        device: Device recorded in ``self._device``. When ``None``, CUDA is
            recorded if it is available, otherwise CPU. The layers themselves
            are not moved by this constructor.

    Raises:
        ValueError: If the dropout rate is outside ``0 <= rate < 1``.
    """

    def __init__(self, orig_dim, inter_dim, code_dim, kwargs, device=None):
        super().__init__()
        dropout = (kwargs or {}).get('dropout', None)

        self.encoder_hidden_layer = nn.Linear(
            in_features=orig_dim, out_features=inter_dim
        )
        self.encoder_output_layer = nn.Linear(
            in_features=inter_dim, out_features=code_dim
        )

        if dropout is None:
            # A zero rate makes the dropout layer a no-op.
            self.dropout = nn.Dropout(0.0)
        elif 0 <= dropout < 1:
            self.dropout = nn.Dropout(p=dropout)
        else:
            raise ValueError('Dropout rate should be in the range 0<=dropout<1')

        if device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self._device = device

    def forward(self, features):
        """Encode ``features`` and return the sigmoid-activated code."""
        hidden = self.encoder_hidden_layer(features)
        # Apply dropout to the activation itself: the mask is then drawn on
        # the activation's own device and dtype, so the forward pass works
        # wherever the parameters live, even if that differs from
        # ``self._device``.
        hidden = self.dropout(hidden)
        hidden = F.leaky_relu(hidden)
        return torch.sigmoid(self.encoder_output_layer(hidden))

45 

46 

class VAEEncoder(nn.Module):
    """Special encoder for VAE that doesn't use sigmoid activation.

    The input is projected to ``inter_dim`` units, passed through dropout and
    a leaky ReLU, then projected linearly to ``code_dim`` units. The output is
    left unbounded because it represents the mean and log variance.

    Args:
        orig_dim: Size of the last dimension of the input features.
        inter_dim: Width of the hidden layer.
        code_dim: Size of the raw output vector.
        kwargs: Mapping of optional settings (``None`` is treated as empty).
            Only ``'dropout'`` is read: a rate satisfying ``0 <= rate < 1``;
            a missing key or ``None`` disables dropout.
        device: Device recorded in ``self._device``. When ``None``, CUDA is
            recorded if it is available, otherwise CPU. The layers themselves
            are not moved by this constructor.

    Raises:
        ValueError: If the dropout rate is outside ``0 <= rate < 1``.
    """

    def __init__(self, orig_dim, inter_dim, code_dim, kwargs, device=None):
        super().__init__()
        dropout = (kwargs or {}).get('dropout', None)

        self.encoder_hidden_layer = nn.Linear(
            in_features=orig_dim, out_features=inter_dim
        )
        self.encoder_output_layer = nn.Linear(
            in_features=inter_dim, out_features=code_dim
        )

        if dropout is None:
            # A zero rate makes the dropout layer a no-op.
            self.dropout = nn.Dropout(0.0)
        elif 0 <= dropout < 1:
            self.dropout = nn.Dropout(p=dropout)
        else:
            raise ValueError('Dropout rate should be in the range 0<=dropout<1')

        if device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self._device = device

    def forward(self, features):
        """Encode ``features`` and return the raw (unbounded) output."""
        hidden = self.encoder_hidden_layer(features)
        # Apply dropout to the activation itself so the mask is created on the
        # activation's own device and dtype, independent of ``self._device``.
        hidden = self.dropout(hidden)
        hidden = F.leaky_relu(hidden)
        # No sigmoid activation for VAE! The output represents mean and log variance
        return self.encoder_output_layer(hidden)

81 

82 

class Decoder(nn.Module):
    """Two-layer fully connected decoder with a linear output.

    The code is projected to ``inter_dim`` units, passed through dropout and a
    leaky ReLU, then projected linearly to ``orig_dim`` units. No output
    non-linearity is applied.

    Args:
        code_dim: Size of the last dimension of the incoming code.
        inter_dim: Width of the hidden layer.
        orig_dim: Size of the reconstructed output.
        kwargs: Mapping of optional settings (``None`` is treated as empty).
            Only ``'dropout'`` is read: a rate satisfying ``0 <= rate < 1``;
            a missing key or ``None`` disables dropout.
        device: Device recorded in ``self._device``. When ``None``, CUDA is
            recorded if it is available, otherwise CPU. The layers themselves
            are not moved by this constructor.

    Raises:
        ValueError: If the dropout rate is outside ``0 <= rate < 1``.
    """

    def __init__(self, code_dim, inter_dim, orig_dim, kwargs, device=None):
        super().__init__()
        dropout = (kwargs or {}).get('dropout', None)

        self.decoder_hidden_layer = nn.Linear(
            in_features=code_dim, out_features=inter_dim
        )
        self.decoder_output_layer = nn.Linear(
            in_features=inter_dim, out_features=orig_dim
        )

        if dropout is None:
            # A zero rate makes the dropout layer a no-op.
            self.dropout = nn.Dropout(0.0)
        elif 0 <= dropout < 1:
            self.dropout = nn.Dropout(p=dropout)
        else:
            raise ValueError('Dropout rate should be in the range 0<=dropout<1')

        if device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self._device = device

    def forward(self, features):
        """Decode ``features`` (a code) and return the linear reconstruction."""
        hidden = self.decoder_hidden_layer(features)
        # Apply dropout to the activation itself so the mask is created on the
        # activation's own device and dtype, independent of ``self._device``.
        hidden = self.dropout(hidden)
        hidden = F.leaky_relu(hidden)
        return self.decoder_output_layer(hidden)

118 

119 

class AE(nn.Module):
    """Plain autoencoder: an ``Encoder`` followed by a ``Decoder``.

    Args:
        orig_dim: Size of the input (and reconstructed) feature vector.
        inter_dim: Width of the hidden layer in both encoder and decoder.
        code_dim: Size of the latent code.
        enc_kwargs: Optional settings forwarded to the encoder; ``None`` is
            treated as an empty mapping.
        dec_kwargs: Optional settings forwarded to the decoder; ``None`` is
            treated as an empty mapping.
        device: Device forwarded to the encoder and decoder and stored,
            unchanged, in ``self._device``.
    """

    def __init__(self, orig_dim, inter_dim, code_dim, enc_kwargs=None, dec_kwargs=None, device=None):
        super().__init__()

        self.encoder = Encoder(
            orig_dim=orig_dim,
            inter_dim=inter_dim,
            code_dim=code_dim,
            kwargs=enc_kwargs or {},
            device=device,
        )
        self.decoder = Decoder(
            orig_dim=orig_dim,
            inter_dim=inter_dim,
            code_dim=code_dim,
            kwargs=dec_kwargs or {},
            device=device,
        )
        self.orig_dim = orig_dim
        self.inter_dim = inter_dim
        self.code_dim = code_dim
        self._device = device

    def forward(self, features):
        """Encode ``features`` and return their reconstruction."""
        # Call the sub-modules (not ``.forward``) so registered hooks run.
        code = self.encoder(features)
        return self.decoder(code)

    def get_code_embedding(self, input_):
        """Return the encoder output for ``input_`` as a transposed NumPy array.

        The code is detached from the autograd graph and copied to the CPU
        before conversion. For a 2-D input the result therefore has the code
        dimension first and the sample dimension second.
        """
        embedding = self.encoder(input_)
        return embedding.detach().cpu().numpy().T

141 

142 

class VAE(nn.Module):
    """Variational autoencoder built from a ``VAEEncoder`` and a ``Decoder``.

    The encoder emits ``2 * code_dim`` values per sample: the first
    ``code_dim`` are read as the mean and the remaining ``code_dim`` as the
    log variance of the latent distribution.

    Args:
        orig_dim: Size of the input (and reconstructed) feature vector.
        inter_dim: Width of the hidden layer in both encoder and decoder.
        code_dim: Size of the latent code.
        enc_kwargs: Optional settings forwarded to the encoder; ``None`` is
            treated as an empty mapping.
        dec_kwargs: Optional settings forwarded to the decoder; ``None`` is
            treated as an empty mapping.
        device: Device forwarded to the encoder and decoder and stored,
            unchanged, in ``self._device``.
    """

    def __init__(self, orig_dim, inter_dim, code_dim, enc_kwargs=None, dec_kwargs=None, device=None):
        super().__init__()

        # Use VAEEncoder instead of regular Encoder; it outputs mean and
        # log variance side by side, hence the doubled output width.
        self.encoder = VAEEncoder(
            orig_dim=orig_dim,
            inter_dim=inter_dim,
            code_dim=2 * code_dim,
            kwargs=enc_kwargs or {},
            device=device,
        )
        self.decoder = Decoder(
            orig_dim=orig_dim,
            inter_dim=inter_dim,
            code_dim=code_dim,
            kwargs=dec_kwargs or {},
            device=device,
        )
        self.orig_dim = orig_dim
        self.inter_dim = inter_dim
        self.code_dim = code_dim
        self._device = device

    def reparameterization(self, mu, log_var):
        """Draw a latent sample as ``mu + eps * std`` with ``eps ~ N(0, 1)``.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        :return: tensor shaped like ``mu`` holding one random sample
        """
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)  # noise with the same shape as ``std``
        return mu + eps * std

    def get_code(self, features):
        """Encode ``features`` and return ``(code, mu, log_var)``.

        ``code`` is a random sample obtained through reparameterization, so
        repeated calls with the same input give different codes.
        """
        # Call the sub-module (not ``.forward``) so registered hooks run.
        raw = self.encoder(features)
        # Split each output row into its two halves: index 0 is the mean,
        # index 1 the log variance.
        raw = raw.view(-1, 2, self.code_dim)
        mu = raw[:, 0, :]
        log_var = raw[:, 1, :]

        code = self.reparameterization(mu, log_var)
        return code, mu, log_var

    def forward(self, features):
        """Return ``(reconstructed, mu, log_var)`` for ``features``."""
        code, mu, log_var = self.get_code(features)
        reconstructed = self.decoder(code)
        return reconstructed, mu, log_var

    def get_code_embedding(self, input_):
        """Return a sampled latent code for ``input_`` as a transposed NumPy array.

        The code is the random sample from ``get_code`` (not the mean),
        detached from the autograd graph and copied to the CPU. The result has
        the code dimension first and the sample dimension second.
        """
        embedding, _mu, _log_var = self.get_code(input_)
        return embedding.detach().cpu().numpy().T

194 

195 

class NeuroDataset(Dataset):
    """Neural activity dataset.

    Stores the transpose of ``data`` and serves it one entry at a time, so
    each item is one column of the array that was passed in.

    Args:
        data: Array-like exposing ``.T`` and supporting ``len`` and indexing
            on the transposed result (presumably laid out as
            features x samples; confirm against callers).
        transform: Stored in ``self.transform``. It is not applied to the
            items returned by ``__getitem__``.
    """

    # Placeholder target returned with every item; no labels are stored.
    _DUMMY_TARGET = -42

    def __init__(self, data, transform=None):
        self.data = data.T
        self.transform = transform

    def __len__(self):
        """Return the number of entries along the first axis of ``self.data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(self.data[idx], -42, idx)``.

        A tensor index is converted to plain Python values first. The second
        element is a constant placeholder target and the third echoes the
        (converted) index back to the caller.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.data[idx], self._DUMMY_TARGET, idx