Coverage for src/driada/dim_reduction/neural.py: 92.74%
124 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4import torch.optim as optim
5# import torchvision
6from torch.utils.data import Dataset, DataLoader
class Encoder(nn.Module):
    """Two-layer fully connected encoder whose code is squashed by a sigmoid.

    Pipeline: linear ``orig_dim -> inter_dim``, dropout, leaky ReLU,
    linear ``inter_dim -> code_dim``, sigmoid.

    :param orig_dim: number of input features
    :param inter_dim: width of the hidden layer
    :param code_dim: number of output (code) features
    :param kwargs: options dict; only the ``'dropout'`` key is read
        (``None`` is treated as an empty dict)
    :param device: device remembered in ``self._device``; when ``None`` it is
        set to CUDA if available, otherwise CPU. The layers are NOT moved
        to this device here.
    :raises ValueError: if the dropout rate is outside ``0 <= dropout < 1``
    """

    def __init__(self, orig_dim, inter_dim, code_dim, kwargs, device=None):
        super().__init__()
        dropout = (kwargs or {}).get('dropout', None)

        self.encoder_hidden_layer = nn.Linear(
            in_features=orig_dim, out_features=inter_dim
        )
        self.encoder_output_layer = nn.Linear(
            in_features=inter_dim, out_features=code_dim
        )

        if dropout is not None:
            if 0 <= dropout < 1:
                self.dropout = nn.Dropout(p=dropout)
            else:
                raise ValueError('Dropout rate should be in the range 0<=dropout<1')
        else:
            # A zero-rate dropout keeps forward() free of special cases.
            self.dropout = nn.Dropout(0.0)

        if device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self._device = device

    def forward(self, features):
        """Encode ``features`` and return the sigmoid-activated code."""
        activation = self.encoder_hidden_layer(features)
        # Apply dropout to the activations directly: the mask is then created
        # on the activations' own device/dtype instead of on ``self._device``,
        # so the module keeps working after ``.to(...)`` / ``.double()``.
        activation = self.dropout(activation)
        activation = F.leaky_relu(activation)
        code = self.encoder_output_layer(activation)
        return torch.sigmoid(code)
class VAEEncoder(nn.Module):
    """Special encoder for VAE that doesn't use sigmoid activation.

    Pipeline: linear ``orig_dim -> inter_dim``, dropout, leaky ReLU,
    linear ``inter_dim -> code_dim``. The raw linear output is returned.

    :param orig_dim: number of input features
    :param inter_dim: width of the hidden layer
    :param code_dim: number of output features
    :param kwargs: options dict; only the ``'dropout'`` key is read
        (``None`` is treated as an empty dict)
    :param device: device remembered in ``self._device``; when ``None`` it is
        set to CUDA if available, otherwise CPU. The layers are NOT moved
        to this device here.
    :raises ValueError: if the dropout rate is outside ``0 <= dropout < 1``
    """

    def __init__(self, orig_dim, inter_dim, code_dim, kwargs, device=None):
        super().__init__()
        dropout = (kwargs or {}).get('dropout', None)

        self.encoder_hidden_layer = nn.Linear(
            in_features=orig_dim, out_features=inter_dim
        )
        self.encoder_output_layer = nn.Linear(
            in_features=inter_dim, out_features=code_dim
        )

        if dropout is not None:
            if 0 <= dropout < 1:
                self.dropout = nn.Dropout(p=dropout)
            else:
                raise ValueError('Dropout rate should be in the range 0<=dropout<1')
        else:
            # A zero-rate dropout keeps forward() free of special cases.
            self.dropout = nn.Dropout(0.0)

        if device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self._device = device

    def forward(self, features):
        """Encode ``features`` and return the un-activated linear output."""
        activation = self.encoder_hidden_layer(features)
        # Apply dropout to the activations directly so the mask follows the
        # activations' device/dtype rather than the stored ``self._device``.
        activation = self.dropout(activation)
        activation = F.leaky_relu(activation)
        # No sigmoid activation for VAE! The output represents mean and log variance
        return self.encoder_output_layer(activation)
class Decoder(nn.Module):
    """Two-layer fully connected decoder with a linear (un-activated) output.

    Pipeline: linear ``code_dim -> inter_dim``, dropout, leaky ReLU,
    linear ``inter_dim -> orig_dim``.

    :param code_dim: number of input (code) features
    :param inter_dim: width of the hidden layer
    :param orig_dim: number of reconstructed output features
    :param kwargs: options dict; only the ``'dropout'`` key is read
        (``None`` is treated as an empty dict)
    :param device: device remembered in ``self._device``; when ``None`` it is
        set to CUDA if available, otherwise CPU. The layers are NOT moved
        to this device here.
    :raises ValueError: if the dropout rate is outside ``0 <= dropout < 1``
    """

    def __init__(self, code_dim, inter_dim, orig_dim, kwargs, device=None):
        super().__init__()
        dropout = (kwargs or {}).get('dropout', None)

        self.decoder_hidden_layer = nn.Linear(
            in_features=code_dim, out_features=inter_dim
        )
        self.decoder_output_layer = nn.Linear(
            in_features=inter_dim, out_features=orig_dim
        )

        if dropout is not None:
            if 0 <= dropout < 1:
                self.dropout = nn.Dropout(p=dropout)
            else:
                raise ValueError('Dropout rate should be in the range 0<=dropout<1')
        else:
            # A zero-rate dropout keeps forward() free of special cases.
            self.dropout = nn.Dropout(0.0)

        if device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self._device = device

    def forward(self, features):
        """Decode ``features`` (a code) and return the raw reconstruction."""
        activation = self.decoder_hidden_layer(features)
        # Apply dropout to the activations directly so the mask follows the
        # activations' device/dtype rather than the stored ``self._device``.
        activation = self.dropout(activation)
        activation = F.leaky_relu(activation)
        # The output layer is deliberately left without an activation.
        return self.decoder_output_layer(activation)
class AE(nn.Module):
    """Plain autoencoder: an ``Encoder`` followed by a ``Decoder``.

    :param orig_dim: number of input/reconstructed features
    :param inter_dim: hidden-layer width used by both halves
    :param code_dim: size of the code produced by the encoder
    :param enc_kwargs: options dict forwarded to the encoder
        (``None`` is treated as an empty dict)
    :param dec_kwargs: options dict forwarded to the decoder
        (``None`` is treated as an empty dict)
    :param device: device forwarded to both halves and kept in ``self._device``
    """

    def __init__(self, orig_dim, inter_dim, code_dim, enc_kwargs, dec_kwargs, device):
        super().__init__()

        self.encoder = Encoder(
            orig_dim=orig_dim,
            inter_dim=inter_dim,
            code_dim=code_dim,
            kwargs=enc_kwargs or {},
            device=device,
        )
        self.decoder = Decoder(
            orig_dim=orig_dim,
            inter_dim=inter_dim,
            code_dim=code_dim,
            kwargs=dec_kwargs or {},
            device=device,
        )
        self.orig_dim = orig_dim
        self.inter_dim = inter_dim
        self.code_dim = code_dim
        self._device = device

    def forward(self, features):
        """Encode then decode ``features`` and return the reconstruction."""
        # Call the sub-modules (not ``.forward``) so registered hooks run.
        code = self.encoder(features)
        return self.decoder(code)

    def get_code_embedding(self, input_):
        """Return the encoder output for ``input_`` as a transposed NumPy array."""
        # The result is detached anyway, so skip building the autograd graph.
        with torch.no_grad():
            embedding = self.encoder(input_)
        return embedding.detach().cpu().numpy().T
class VAE(nn.Module):
    """Variational autoencoder built from a ``VAEEncoder`` and a ``Decoder``.

    The encoder emits ``2 * code_dim`` values per sample; they are viewed as
    ``(-1, 2, code_dim)`` and split into a mean and a log-variance, from
    which the code is sampled with the reparameterization trick.

    :param orig_dim: number of input/reconstructed features
    :param inter_dim: hidden-layer width used by both halves
    :param code_dim: size of the sampled latent code
    :param enc_kwargs: options dict forwarded to the encoder (``None`` -> ``{}``)
    :param dec_kwargs: options dict forwarded to the decoder (``None`` -> ``{}``)
    :param device: device forwarded to the encoder and the decoder
    """

    def __init__(self, orig_dim, inter_dim, code_dim, enc_kwargs=None, dec_kwargs=None, device=None):
        super().__init__()

        # Use VAEEncoder instead of regular Encoder; it outputs mean and
        # log-variance together, hence the doubled output width.
        self.encoder = VAEEncoder(
            orig_dim=orig_dim,
            inter_dim=inter_dim,
            code_dim=2 * code_dim,
            kwargs=enc_kwargs or {},
            device=device,
        )
        self.decoder = Decoder(
            orig_dim=orig_dim,
            inter_dim=inter_dim,
            code_dim=code_dim,
            kwargs=dec_kwargs or {},
            device=device,
        )
        self.orig_dim = orig_dim
        self.inter_dim = inter_dim
        self.code_dim = code_dim

    def reparameterization(self, mu, log_var):
        """
        Draw ``mu + eps * std`` with ``eps`` sampled from a standard normal.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        :return: sampled tensor with the same shape as ``mu``
        """
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)  # noise with the same shape/dtype/device as std
        return mu + eps * std

    def get_code(self, features):
        """Encode ``features`` and return ``(code, mu, log_var)``."""
        # Call the sub-module (not ``.forward``) so registered hooks run.
        stats = self.encoder(features)
        stats = stats.view(-1, 2, self.code_dim)

        mu = stats[:, 0, :]  # first half of the encoder output: mean
        log_var = stats[:, 1, :]  # second half: log variance

        # get the latent vector through reparameterization
        code = self.reparameterization(mu, log_var)
        return code, mu, log_var

    def forward(self, features):
        """Return ``(reconstructed, mu, log_var)`` for ``features``."""
        code, mu, log_var = self.get_code(features)
        reconstructed = self.decoder(code)
        return reconstructed, mu, log_var

    def get_code_embedding(self, input_):
        """Return a sampled code for ``input_`` as a transposed NumPy array.

        Note that this is the stochastic sample, not the mean.
        """
        # The result is detached anyway, so skip building the autograd graph.
        with torch.no_grad():
            embedding, _mu, _log_var = self.get_code(input_)
        return embedding.detach().cpu().numpy().T
class NeuroDataset(Dataset):
    """Neural activity dataset.

    The input ``data`` is stored transposed, so the dataset is indexed along
    what was the second axis of ``data``.

    :param data: array-like exposing ``.T``
    :param transform: optional callable invoked on a per-item sample dict
    """

    def __init__(self, data, transform=None):
        self.transform = transform
        self.data = data.T

    def __len__(self):
        """Return the length of the transposed data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(self.data[idx], -42, idx)`` for the requested index.

        NOTE(review): the sample dict is built and handed to ``transform``,
        but the transform's result is discarded; the returned tuple always
        comes from the untransformed data, with ``-42`` as a fixed
        placeholder in the second slot.
        """
        index = idx.tolist() if torch.is_tensor(idx) else idx

        payload = dict(vector=self.data[index].reshape(-1, 1), target=0)
        if self.transform:
            self.transform(payload)

        return self.data[index], -42, index