try:
    from typing import Dict, List, Tuple, Optional, Union, Literal

    LiteralFalse = Literal[False]
except ImportError:
    # Python pre-3.8 compatibility
    from typing import Dict, List, Tuple, Optional, Union, NewType

    LiteralFalse = NewType("LiteralFalse", bool)

import logging
import math
import torch
from torch import Tensor, nn
from elfragmentador import constants
import pytorch_lightning as pl


class SeqPositionalEmbed(torch.nn.Module):
    def __init__(self, dims_add: int = 10, max_len: int = 30, inverted=True):
        super().__init__()
        pe = torch.zeros(max_len, dims_add)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # Sinusoidal frequencies for the positional embedding
        div_term_enum = torch.arange(0, dims_add, 2).float()
        div_term_denom = -math.log(10000.0) / dims_add + 1
        div_term = torch.exp(div_term_enum * div_term_denom)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # The first row is zeroed so it can act as a padding embedding
        pe[0, :] = 0
        self.register_buffer("pe", pe)
        self.inverted = inverted

    def forward(self, x: torch.LongTensor):
        """Computes positional embeddings for a zero-padded integer sequence.

        Positions are counted over the non-zero (non-padding) entries of the
        input; when ``inverted`` is True they are counted from the end of the
        sequence towards the start, so trailing padding maps to the zeroed
        first row of the embedding table (see the example below).

        Parameters
        ----------
        x : Tensor
            Integer Tensor of shape [BatchSize, SequenceLength], this should encode
            a sequence and be padded with zeros

        Returns
        -------
        Tensor
            Tensor of shape [SequenceLength, BatchSize, DimensionsAdded]

        Example
        -------
        >>> encoder = SeqPositionalEmbed(6, 50, inverted=True)
        >>> x = torch.cat([torch.ones(1,2), torch.ones(1,2)*2, torch.zeros((1,2))], dim = -1).long()
        >>> x[0]
        tensor([1, 1, 2, 2, 0, 0])
        >>> x.shape
        torch.Size([1, 6])
        >>> out = encoder(x)
        >>> out.shape
        torch.Size([6, 1, 6])
        >>> out
        tensor([[[-0.7568, -0.6536,  0.9803,  0.1976,  0.4533,  0.8913]],
                [[ 0.1411, -0.9900,  0.8567,  0.5158,  0.3456,  0.9384]],
                [[ 0.9093, -0.4161,  0.6334,  0.7738,  0.2331,  0.9725]],
                [[ 0.8415,  0.5403,  0.3363,  0.9418,  0.1174,  0.9931]],
                [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
                [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]])
        """
        vals = x.bool().long()
        if self.inverted:
            vals = vals.flip(1)

        out = self.pe[vals.cumsum(1)]
        if self.inverted:
            out = out.flip(1)

        return out.transpose(1, 0)
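
# A minimal usage sketch of SeqPositionalEmbed on a zero-padded batch; the
# helper name `_demo_seq_positional_embed` is illustrative only and not
# assumed to exist elsewhere in the package.
def _demo_seq_positional_embed() -> None:
    # Two sequences of different lengths, right-padded with zeros
    x = torch.tensor([[3, 5, 2, 0, 0], [1, 4, 2, 2, 1]])
    fw = SeqPositionalEmbed(dims_add=4, max_len=10, inverted=False)
    rev = SeqPositionalEmbed(dims_add=4, max_len=10, inverted=True)
    fw_out, rev_out = fw(x), rev(x)
    # Outputs are [SequenceLength, BatchSize, dims_add]
    assert fw_out.shape == (5, 2, 4)
    assert rev_out.shape == (5, 2, 4)
    # In the inverted variant, positions are counted from the end, so the
    # right-padded entries of the first sequence map to the zeroed row
    assert torch.all(rev_out[3:, 0] == 0)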

class ConcatenationEncoder(torch.nn.Module):
    """ConcatenationEncoder concatenates information into the embedding.

    Adds information about a continuous variable to an embedding by
    concatenating an additional set of dimensions to it.

    Parameters
    ----------
    dims_add : int
        Number of dimensions to add as an encoding
    dropout : float, optional
        dropout, by default 0.1
    max_val : float, optional
        maximum expected value of the variable that will be encoded, by default 200.0
    static_size : Union[LiteralFalse, int], optional
        Optional integer to pass in order to make the size deterministic.
        This is only required if you want to export your model to torchscript, by default False

    Examples
    --------
    >>> x1 = torch.zeros((5, 1, 20))
    >>> x2 = torch.zeros((5, 2, 20))
    >>> encoder = ConcatenationEncoder(10, 0.1, 10)
    >>> output = encoder(x1, torch.tensor([[7]]))
    >>> output = encoder(x2, torch.tensor([[7], [4]]))
    """

    # TODO evaluate if dropout is actually useful here ...

    def __init__(
        self,
        dims_add: int,
        dropout: float = 0.1,
        max_val: Union[float, int] = 200.0,
        static_size: Union[LiteralFalse, int] = False,
    ) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        # Unlike PositionalEncoding, the "position" here is the continuous
        # value provided at forward time, so only the frequencies are precomputed
        div_term = torch.exp(
            torch.arange(0, dims_add, 2).float()
            * (-math.log(float(2 * max_val)) / (dims_add))
        )
        self.register_buffer("div_term", div_term)
        self.static_size = static_size
        self.dims_add = dims_add

    def forward(self, x: Tensor, val: Tensor, debug: bool = False) -> Tensor:
        r"""Forward pass through the encoder.

        Args
        ----
        x:
            the sequence fed to the encoder model (required).
        val:
            value to be encoded into the sequence (required).

        Shape:
            x: [sequence length, batch size, embed dim]
            val: [batch size, 1]
            output: [sequence length, batch size, embed_dim + added_dims]

        Examples
        --------
        >>> x1 = torch.zeros((5, 1, 20))
        >>> x2 = torch.cat([x1, x1+1], axis = 1)
        >>> encoder = ConcatenationEncoder(10, dropout = 0, max_val = 10)
        >>> output = encoder(x1, torch.tensor([[7]]))
        >>> output.shape
        torch.Size([5, 1, 30])
        >>> output = encoder(x2, torch.tensor([[7], [4]]))
        """
        if debug:
            logging.debug(f"CE: Shape of inputs val={val.shape} x={x.shape}")

        if self.static_size:
            assert self.static_size == x.size(0), (
                f"Size of the first dimension ({x.size(0)}) "
                f"does not match the expected value ({self.static_size})"
            )
            end_position = self.static_size
        else:
            end_position = x.size(0)

        e_sin = torch.sin(val * self.div_term)
        e_cos = torch.cos(val * self.div_term)
        e = torch.cat([e_sin, e_cos], axis=-1)

        if debug:
            logging.debug(f"CE: Making encodings e={e.shape}")

        assert (
            e.shape[-1] < self.dims_add + 2
        ), "Internal error in concatenation encoder"
        e = e[..., : self.dims_add]

        if debug:
            logging.debug(f"CE: clipping encodings e={e.shape}")

        # Repeat the encoded value along the sequence dimension
        e = torch.cat([e.unsqueeze(0)] * end_position)

        if debug:
            logging.debug(f"CE: Shape before concat e={e.shape} x={x.shape}")

        x = torch.cat((x, e), axis=-1)
        if debug:
            logging.debug(f"CE: Shape after concat x={x.shape}")
        return self.dropout(x)
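
# A minimal usage sketch of ConcatenationEncoder: a per-sequence scalar (for
# example a charge state or collision energy) is encoded and concatenated to a
# [seq_len, batch, embed_dim] tensor. The helper name
# `_demo_concatenation_encoder` and the example values are illustrative only.
def _demo_concatenation_encoder() -> None:
    seq_len, batch, embed_dim = 7, 3, 16
    x = torch.zeros((seq_len, batch, embed_dim))
    # One scalar value per batch element, shaped [batch, 1]
    vals = torch.tensor([[25.0], [30.0], [35.0]])
    encoder = ConcatenationEncoder(dims_add=8, dropout=0.0, max_val=100.0)
    out = encoder(x, vals)
    # The encoded value occupies 8 extra dimensions at the end of the embedding
    assert out.shape == (seq_len, batch, embed_dim + 8)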

class PositionalEncoding(torch.nn.Module):
    r"""PositionalEncoding adds positional information to tensors.

    Injects some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d_{model}})

        \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d_{model}})

    where :math:`pos` is the word position and :math:`i` is the embedding index.

    Args
    ----
    d_model: int
        the embed dim (required), must be even.
    dropout: float
        the dropout value (default=0.1).
    max_len: int
        the max. length of the incoming sequence (default=5000).
    static_size : Union[LiteralFalse, int], optional
        If it is an integer it is the size of the inputs that will
        be given, it is used only when tracing the model for torchscript
        (since torchscript needs fixed length inputs), by default False

    Examples
    --------
    >>> posencoder = PositionalEncoding(20, 0.1, max_len=20)
    >>> x = torch.ones((2,1,20)).float()
    >>> x.shape
    torch.Size([2, 1, 20])
    >>> posencoder(x).shape
    torch.Size([2, 1, 20])

    Therefore the inputs and outputs are of shape (seq_length, batch, embedding_dim).
    """

    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_len: int = 5000,
        static_size: Union[LiteralFalse, int] = False,
    ) -> None:
        """Creates a new instance of PositionalEncoding."""
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)
        self.static_size = static_size

    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass through the encoder.

        Args
        ----
        x: the sequence fed to the positional encoder model (required).

        Shape
        -----
        x: [sequence length, batch size, embed dim]
        output: [sequence length, batch size, embed dim]

        Examples
        --------
        >>> pl.seed_everything(42)
        42
        >>> x = torch.ones((1,4,6)).float()
        >>> pos_encoder = PositionalEncoding(6, 0.1, max_len=10)
        >>> output = pos_encoder(x)
        >>> output.shape
        torch.Size([1, 4, 6])
        >>> output
        tensor([[[1.1111, 2.2222, 1.1111, 2.2222, 1.1111, 2.2222],
                 [1.1111, 2.2222, 1.1111, 2.2222, 1.1111, 2.2222],
                 [1.1111, 2.2222, 1.1111, 2.2222, 1.1111, 0.0000],
                 [1.1111, 2.2222, 1.1111, 2.2222, 1.1111, 2.2222]]])
        """
        if self.static_size:
            end_position = self.static_size
        else:
            end_position = x.size(0)

        x = x + self.pe[:end_position, :]
        return self.dropout(x)
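
# A minimal usage sketch of PositionalEncoding with dropout disabled so the
# result is deterministic; `_demo_positional_encoding` is an illustrative
# helper name only.
def _demo_positional_encoding() -> None:
    pos_encoder = PositionalEncoding(d_model=8, dropout=0.0, max_len=32)
    x = torch.zeros((10, 2, 8))
    out = pos_encoder(x)
    assert out.shape == (10, 2, 8)
    # With a zero input the output is exactly the positional encoding, so
    # different sequence positions produce different vectors
    assert not torch.allclose(out[0, 0], out[1, 0])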

class AASequenceEmbedding(torch.nn.Module):
    """Embeds an amino acid sequence together with its modifications.

    The amino acid and modification embeddings are summed and then concatenated
    with forward and reverse positional embeddings of the sequence.
    """

    def __init__(self, ninp, position_ratio=0.1):
        super().__init__()
        # Number of dimensions reserved for each positional embedding
        positional_ninp = int((ninp / 2) * position_ratio)
        if positional_ninp % 2:
            positional_ninp += 1
        ninp_embed = int(ninp - (2 * positional_ninp))

        # Positional information additions
        self.fw_position_embed = SeqPositionalEmbed(
            max_len=constants.MAX_TENSOR_SEQUENCE * 4,
            dims_add=positional_ninp,
            inverted=False,
        )
        self.rev_position_embed = SeqPositionalEmbed(
            max_len=constants.MAX_TENSOR_SEQUENCE * 4,
            dims_add=positional_ninp,
            inverted=True,
        )

        # Amino acid embedding
        self.aa_encoder = nn.Embedding(constants.AAS_NUM + 1, ninp_embed, padding_idx=0)
        # PTM embedding
        self.mod_encoder = nn.Embedding(
            len(constants.MODIFICATION) + 1, ninp_embed, padding_idx=0
        )

        # Weight Initialization
        self.init_weights()
        self.ninp = ninp_embed

    def init_weights(self) -> None:
        initrange = 0.1
        ptm_initrange = initrange * 0.01
        torch.nn.init.uniform_(self.aa_encoder.weight, -initrange, initrange)
        torch.nn.init.uniform_(self.mod_encoder.weight, -ptm_initrange, ptm_initrange)

    def forward(self, src, mods, debug: bool = False):
        if debug:
            logging.debug(f"AAE: Input shapes src={src.shape}, mods={mods.shape}")
        fw_pos_emb = self.fw_position_embed(src)
        rev_pos_emb = self.rev_position_embed(src)

        # [batch, seq] -> [seq, batch, ninp_embed]
        src = self.aa_encoder(src.permute(1, 0))
        mods = self.mod_encoder(mods.permute(1, 0))
        src = src + mods

        # TODO consider if this line is needed
        src = src * math.sqrt(self.ninp)
        if debug:
            logging.debug(f"AAE: Shape after embedding {src.shape}")

        src = torch.cat([src, fw_pos_emb, rev_pos_emb], dim=-1)
        if debug:
            logging.debug(f"AAE: Shape after embedding positions {src.shape}")

        return src
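
# A minimal usage sketch of AASequenceEmbedding, assuming the `constants`
# module imported above provides the values the class already relies on
# (AAS_NUM, MODIFICATION, MAX_TENSOR_SEQUENCE). The helper name
# `_demo_aa_sequence_embedding` and the toy token indices are illustrative only.
def _demo_aa_sequence_embedding() -> None:
    embed = AASequenceEmbedding(ninp=64, position_ratio=0.1)
    # [batch, seq_len] integer tensors, zero-padded on the right
    src = torch.tensor([[1, 1, 1, 0, 0]])
    mods = torch.zeros_like(src)
    out = embed(src, mods)
    # Output is [seq_len, batch, ninp]: amino acid + modification embeddings
    # concatenated with forward and reverse positional embeddings
    assert out.shape == (5, 1, 64)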