Coverage for nn_encoding.py: 88%

try:
    from typing import Dict, List, Tuple, Optional, Union, Literal

    LiteralFalse = Literal[False]
except ImportError:
    # Python pre-3.8 compatibility
    from typing import Dict, List, Tuple, Optional, Union, NewType

    LiteralFalse = NewType("LiteralFalse", bool)

import logging
import math

import torch
from torch import Tensor, nn
from elfragmentador import constants
import pytorch_lightning as pl


class SeqPositionalEmbed(torch.nn.Module):
    def __init__(self, dims_add: int = 10, max_len: int = 30, inverted: bool = True):
        super().__init__()
        # Sinusoidal table of shape [max_len, dims_add]; row i encodes position i
        pe = torch.zeros(max_len, dims_add)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        div_term_enum = torch.arange(0, dims_add, 2).float()
        div_term_denom = (-math.log(10000.0) / dims_add) + 1
        div_term = torch.exp(div_term_enum * div_term_denom)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Row 0 is reserved for padding positions and stays at zero
        pe[0, :] = 0
        self.register_buffer("pe", pe)
        self.inverted = inverted
    def forward(self, x: torch.LongTensor):
        """Compute positional embeddings for a padded integer sequence.

        Only non-padding (non-zero) positions advance the position counter.
        When ``inverted`` is True, positions are counted from the end of the
        sequence and padding positions map to the zeroed first row of the table.

        Parameters
        ----------
        x : Tensor
            Integer Tensor of shape [BatchSize, SequenceLength], this should encode
            a sequence and be padded with zeros

        Returns
        -------
        Tensor
            Tensor of shape [SequenceLength, BatchSize, DimensionsAdded]

        Example
        -------
        >>> encoder = SeqPositionalEmbed(6, 50, inverted=True)
        >>> x = torch.cat([torch.ones(1,2), torch.ones(1,2)*2, torch.zeros((1,2))], dim = -1).long()
        >>> x[0]
        tensor([1, 1, 2, 2, 0, 0])
        >>> x.shape
        torch.Size([1, 6])
        >>> out = encoder(x)
        >>> out.shape
        torch.Size([6, 1, 6])
        >>> out
        tensor([[[-0.7568, -0.6536,  0.9803,  0.1976,  0.4533,  0.8913]],
                [[ 0.1411, -0.9900,  0.8567,  0.5158,  0.3456,  0.9384]],
                [[ 0.9093, -0.4161,  0.6334,  0.7738,  0.2331,  0.9725]],
                [[ 0.8415,  0.5403,  0.3363,  0.9418,  0.1174,  0.9931]],
                [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
                [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]])
        """
        # 1 for non-padding tokens, 0 for padding
        vals = x.bool().long()
        if self.inverted:
            vals = vals.flip(1)

        # The cumulative count of non-padding tokens indexes into the table
        out = self.pe[vals.cumsum(1)]
        if self.inverted:
            out = out.flip(1)

        return out.transpose(1, 0)
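

# A minimal usage sketch, not part of the original module: it contrasts the
# forward and inverted variants on a zero-padded toy sequence. The helper name
# and the token values are illustrative assumptions.
def _seq_positional_embed_sketch() -> Tuple[Tensor, Tensor]:
    seq = torch.tensor([[3, 3, 5, 0, 0]])  # [batch, seq len], zero-padded
    fw = SeqPositionalEmbed(dims_add=4, max_len=10, inverted=False)(seq)
    rev = SeqPositionalEmbed(dims_add=4, max_len=10, inverted=True)(seq)
    # Both are [seq len, batch, dims_add]; the inverted variant counts positions
    # from the last non-padding token, so its padding rows are all zeros.
    assert fw.shape == rev.shape == (5, 1, 4)
    assert rev[-1].abs().sum() == 0
    return fw, rev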


class ConcatenationEncoder(torch.nn.Module):
    """ConcatenationEncoder concatenates information into the embedding.

    Adds information about a continuous variable to an embedding by
    concatenating a fixed number of extra dimensions to it.

    Parameters
    ----------
    dims_add : int
        Number of dimensions to add as an encoding
    dropout : float, optional
        dropout, by default 0.1
    max_val : float, optional
        maximum expected value of the variable that will be encoded, by default 200.0
    static_size : Union[LiteralFalse, int], optional
        Optional integer to pass in order to make the size deterministic.
        This is only required if you want to export your model to torchscript,
        by default False

    Examples
    --------
    >>> x1 = torch.zeros((5, 1, 20))
    >>> x2 = torch.zeros((5, 2, 20))
    >>> encoder = ConcatenationEncoder(10, 0.1, 10)
    >>> output = encoder(x1, torch.tensor([[7]]))
    >>> output = encoder(x2, torch.tensor([[7], [4]]))
    """

    # TODO evaluate if dropout is actually useful here ...

    def __init__(
        self,
        dims_add: int,
        dropout: float = 0.1,
        max_val: Union[float, int] = 200.0,
        static_size: Union[LiteralFalse, int] = False,
    ) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        # Same sinusoidal frequency schedule as PositionalEncoding, but here
        # the encoded value plays the role of the position
        div_term = torch.exp(
            torch.arange(0, dims_add, 2).float()
            * (-math.log(float(2 * max_val)) / (dims_add))
        )
        self.register_buffer("div_term", div_term)
        self.static_size = static_size
        self.dims_add = dims_add

    def forward(self, x: Tensor, val: Tensor, debug: bool = False) -> Tensor:
        r"""Forward pass through the encoder.

        Args
        ----
        x:
            the sequence fed to the encoder model (required).
        val:
            value to be encoded into the sequence (required).

        Shape
        -----
        x: [sequence length, batch size, embed dim]
        val: [batch size, 1]
        output: [sequence length, batch size, embed_dim + added_dims]

        Examples
        --------
        >>> x1 = torch.zeros((5, 1, 20))
        >>> x2 = torch.cat([x1, x1+1], axis = 1)
        >>> encoder = ConcatenationEncoder(10, dropout = 0, max_val = 10)
        >>> output = encoder(x1, torch.tensor([[7]]))
        >>> output.shape
        torch.Size([5, 1, 30])
        >>> output = encoder(x2, torch.tensor([[7], [4]]))
        """
        if debug:
            logging.debug(f"CE: Shape of inputs val={val.shape} x={x.shape}")

        if self.static_size:
            assert self.static_size == x.size(0), (
                f"Size of the first dimension ({x.size(0)}) "
                f"does not match the expected value ({self.static_size})"
            )
            end_position = self.static_size
        else:
            end_position = x.size(0)

        e_sin = torch.sin(val * self.div_term)
        e_cos = torch.cos(val * self.div_term)
        e = torch.cat([e_sin, e_cos], axis=-1)

        if debug:
            logging.debug(f"CE: Making encodings e={e.shape}")

        assert (
            e.shape[-1] < self.dims_add + 2
        ), "Internal error in concatenation encoder"
        # Drop the extra column produced when dims_add is odd
        e = e[..., : self.dims_add]

        if debug:
            logging.debug(f"CE: clipping encodings e={e.shape}")

        # Repeat the encoding for every sequence position
        e = torch.cat([e.unsqueeze(0)] * end_position)

        if debug:
            logging.debug(f"CE: Shape before concat e={e.shape} x={x.shape}")

        x = torch.cat((x, e), axis=-1)
        if debug:
            logging.debug(f"CE: Shape after concat x={x.shape}")

        return self.dropout(x)
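

# A minimal usage sketch, not part of the original module: it shows a
# per-spectrum scalar (an arbitrary "charge"-like value here) being appended as
# extra feature dimensions at every sequence position. The helper name and the
# concrete values are assumptions for illustration.
def _concatenation_encoder_sketch() -> Tensor:
    seq_emb = torch.zeros((5, 2, 20))       # [seq len, batch, embed dim]
    scalars = torch.tensor([[2.0], [3.0]])  # one value per batch element
    encoder = ConcatenationEncoder(dims_add=10, dropout=0.0, max_val=10.0)
    out = encoder(seq_emb, scalars)
    # dims_add new features are concatenated onto the embedding dimension
    assert out.shape == (5, 2, 30)
    return out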


class PositionalEncoding(torch.nn.Module):
    r"""PositionalEncoding adds positional information to tensors.

    Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d_{model}})

        \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d_{model}})

    where :math:`pos` is the word position and :math:`i` is the embedding index.

    Args
    ----
    d_model: int
        the embed dim (required), must be even.
    dropout: float
        the dropout value (default=0.1).
    max_len: int
        the max. length of the incoming sequence (default=5000).
    static_size : Union[LiteralFalse, int], optional
        If it is an integer it is the size of the inputs that will
        be given, it is used only when tracing the model for torchscript
        (since torchscript needs fixed length inputs), by default False

    Examples
    --------
    >>> posencoder = PositionalEncoding(20, 0.1, max_len=20)
    >>> x = torch.ones((2,1,20)).float()
    >>> x.shape
    torch.Size([2, 1, 20])
    >>> posencoder(x).shape
    torch.Size([2, 1, 20])

    Therefore encodings are (seq_length, batch, encodings).
    """

    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_len: int = 5000,
        static_size: Union[LiteralFalse, int] = False,
    ) -> None:
        """__init__ Creates a new instance."""
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        # Precomputed sinusoidal table of shape [max_len, 1, d_model]
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)
        self.static_size = static_size

    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass through the encoder.

        Args
        ----
        x: the sequence fed to the positional encoder model (required).

        Shape
        -----
        x: [sequence length, batch size, embed dim]
        output: [sequence length, batch size, embed dim]

        Examples
        --------
        >>> pl.seed_everything(42)
        42
        >>> x = torch.ones((1,4,6)).float()
        >>> pos_encoder = PositionalEncoding(6, 0.1, max_len=10)
        >>> output = pos_encoder(x)
        >>> output.shape
        torch.Size([1, 4, 6])
        >>> output
        tensor([[[1.1111, 2.2222, 1.1111, 2.2222, 1.1111, 2.2222],
                 [1.1111, 2.2222, 1.1111, 2.2222, 1.1111, 2.2222],
                 [1.1111, 2.2222, 1.1111, 2.2222, 1.1111, 0.0000],
                 [1.1111, 2.2222, 1.1111, 2.2222, 1.1111, 2.2222]]])
        """
        if self.static_size:
            end_position = self.static_size
        else:
            end_position = x.size(0)

        x = x + self.pe[:end_position, :]
        return self.dropout(x)
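

# A minimal usage sketch, not part of the original module: unlike
# ConcatenationEncoder, the positional encoding is summed with the input, so
# the embedding dimension is unchanged and d_model must equal the input's last
# dimension. The helper name and sizes are illustrative assumptions.
def _positional_encoding_sketch() -> Tensor:
    token_emb = torch.zeros((7, 3, 16))  # [seq len, batch, d_model]
    pos_encoder = PositionalEncoding(d_model=16, dropout=0.0, max_len=32)
    out = pos_encoder(token_emb)
    assert out.shape == token_emb.shape  # added, not concatenated
    return out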


class AASequenceEmbedding(torch.nn.Module):
    def __init__(self, ninp, position_ratio=0.1):
        super().__init__()
        # Reserve a fraction of the dimensions for the forward and reverse
        # positional embeddings; the rest goes to the residue/PTM embeddings
        positional_ninp = int((ninp / 2) * position_ratio)
        if positional_ninp % 2:
            positional_ninp += 1
        ninp_embed = int(ninp - (2 * positional_ninp))

        # Positional information additions
        self.fw_position_embed = SeqPositionalEmbed(
            max_len=constants.MAX_TENSOR_SEQUENCE * 4,
            dims_add=positional_ninp,
            inverted=False,
        )
        self.rev_position_embed = SeqPositionalEmbed(
            max_len=constants.MAX_TENSOR_SEQUENCE * 4,
            dims_add=positional_ninp,
            inverted=True,
        )

        # Aminoacid embedding
        self.aa_encoder = nn.Embedding(constants.AAS_NUM + 1, ninp_embed, padding_idx=0)
        # PTM embedding
        self.mod_encoder = nn.Embedding(
            len(constants.MODIFICATION) + 1, ninp_embed, padding_idx=0
        )

        # Weight Initialization
        self.init_weights()
        self.ninp = ninp_embed

    def init_weights(self) -> None:
        initrange = 0.1
        ptm_initrange = initrange * 0.01
        torch.nn.init.uniform_(self.aa_encoder.weight, -initrange, initrange)
        torch.nn.init.uniform_(self.mod_encoder.weight, -ptm_initrange, ptm_initrange)

    def forward(self, src, mods, debug: bool = False):
        if debug:
            logging.debug(f"AAE: Input shapes src={src.shape}, mods={mods.shape}")

        fw_pos_emb = self.fw_position_embed(src)
        rev_pos_emb = self.rev_position_embed(src)

        # The embeddings expect [seq len, batch], so flip the batch-first inputs
        src = self.aa_encoder(src.permute(1, 0))
        mods = self.mod_encoder(mods.permute(1, 0))
        src = src + mods

        # TODO consider if this line is needed
        src = src * math.sqrt(self.ninp)
        if debug:
            logging.debug(f"AAE: Shape after embedding {src.shape}")

        src = torch.cat([src, fw_pos_emb, rev_pos_emb], dim=-1)
        if debug:
            logging.debug(f"AAE: Shape after embedding positions {src.shape}")

        return src
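

# A minimal end-to-end sketch, not part of the original module: embedding a
# batch of zero-padded residue and modification index tensors. The helper name
# and the index values are assumptions; they only need to fall inside the
# embedding-table ranges derived from elfragmentador.constants.
def _aa_sequence_embedding_sketch() -> Tensor:
    embedder = AASequenceEmbedding(ninp=64)
    src = torch.tensor([[1, 4, 7, 0, 0], [2, 5, 9, 3, 0]])  # [batch, seq len]
    mods = torch.zeros_like(src)                            # no modifications
    out = embedder(src, mods)
    # [seq len, batch, ninp]: residue+PTM embedding plus the forward and
    # reverse positional dimensions concatenated back up to ninp
    assert out.shape == (5, 2, 64)
    return out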