import logging

try:
    from typing import Dict, List, Tuple, Optional, Union, Literal

    LiteralFalse = Literal[False]
except ImportError:
    # Python pre-3.8 compatibility
    from typing import Dict, List, Tuple, Optional, Union, NewType

    LiteralFalse = NewType("LiteralFalse", bool)

import warnings
import math
import time
from collections import namedtuple

import torch
from torch import Tensor, nn
import pytorch_lightning as pl

from argparse import _ArgumentGroup, ArgumentParser

import elfragmentador
from elfragmentador import constants
from elfragmentador import encoding_decoding
from elfragmentador.spectra import Spectrum
from elfragmentador.datamodules import TrainBatch
from elfragmentador.metrics import CosineLoss
from elfragmentador.nn_encoding import (
    ConcatenationEncoder,
    AASequenceEmbedding,
)
from elfragmentador.math_utils import nanmean
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import (
    CosineAnnealingWarmRestarts,
    OneCycleLR,
    ReduceLROnPlateau,
)

PredictionResults = namedtuple("PredictionResults", "irt spectra")
ForwardBatch = namedtuple("ForwardBatch", "src nce mods charge")


class MLP(nn.Module):
    """MLP implements a very simple multi-layer perceptron (also called FFN).

    Stacks hidden linear layers with activations for ``num_layers`` layers.
    This implementation uses GELU instead of ReLU:
    (linear > gelu) * (n-1) > linear

    Based on: https://github.com/facebookresearch/detr/blob/models/detr.py#L289
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ) -> None:
        """__init__ Creates a new instance of the MLP.

        Parameters
        ----------
        input_dim : int
            Expected dimensions for the input
        hidden_dim : int
            Number of dimensions of the hidden layers
        output_dim : int
            Output dimensions
        num_layers : int
            Number of layers (total)
        """
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass over the network.

        Parameters
        ----------
        x : Tensor
            Dimensions should match the ones specified when instantiating the class

        Returns
        -------
        Tensor
            The dims of this tensor are defined when instantiating the class

        Examples
        --------
        >>> pl.seed_everything(42)
        42
        >>> net = MLP(1000, 512, 2, 10)
        >>> out = net.forward(torch.rand([5, 1000]))
        >>> out
        tensor([[-0.0061, -0.0219],
                [-0.0061, -0.0219],
                [-0.0061, -0.0220],
                [-0.0061, -0.0220],
                [-0.0061, -0.0219]], grad_fn=<AddmmBackward>)
        >>> out.shape
        torch.Size([5, 2])
        """
        for i, layer in enumerate(self.layers):
            x = (
                torch.nn.functional.gelu(layer(x))
                if i < self.num_layers - 1
                else layer(x)
            )
        return x


class _PeptideTransformerEncoder(torch.nn.Module):
    """Encodes peptide sequences (and their modifications) with a transformer encoder."""

    def __init__(
        self, ninp: int, dropout: float, nhead: int, nhid: int, layers: int
    ) -> None:
        super().__init__()

        # Aminoacid embedding
        self.aa_encoder = AASequenceEmbedding(ninp=ninp, position_ratio=0.1)

        # Transformer encoder sections
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=ninp,
            nhead=nhead,
            dim_feedforward=nhid,
            dropout=dropout,
            activation="gelu",
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, layers)

    def forward(self, src: Tensor, mods: Tensor, debug: bool = False) -> Tensor:
        # Positions encoded as 0 are treated as padding
        trans_encoder_mask = ~src.bool()
        if debug:
            logging.debug(f"TE: Shape of mask {trans_encoder_mask.shape}")

        src = self.aa_encoder.forward(src=src, mods=mods, debug=debug)
        if debug:
            logging.debug(f"TE: Shape after AASequence encoder {src.shape}")

        trans_encoder_output = self.transformer_encoder.forward(
            src, src_key_padding_mask=trans_encoder_mask
        )
        if debug:
            logging.debug(f"TE: Shape after trans encoder {trans_encoder_output.shape}")

        return trans_encoder_output
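
# Illustrative shape sketch for _PeptideTransformerEncoder (an assumption, not a
# guarantee; the exact tensor layout depends on AASequenceEmbedding, which is
# defined elsewhere in the package):
#
#     encoder = _PeptideTransformerEncoder(ninp=64, dropout=0.1, nhead=2, nhid=128, layers=1)
#     src = torch.randint(1, 20, (4, 25))   # [batch, seq_len], 0 reserved for padding
#     mods = torch.zeros_like(src)          # no modifications
#     memory = encoder(src, mods)           # expected shape: [seq_len, batch, ninp]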

class _PeptideTransformerDecoder(torch.nn.Module):
    """Decodes fragment intensities from the encoder memory, charge and NCE."""

    def __init__(
        self,
        ninp: int,
        nhead: int,
        nhid: int,
        layers: int,
        dropout: float,
        charge_dims_pct: float = 0.05,
        nce_dims_pct: float = 0.05,
    ) -> None:
        super().__init__()
        logging.info(
            f"Creating TransformerDecoder nhid={nhid}, ninp={ninp} nhead={nhead} layers={layers}"
        )
        charge_dims = math.ceil(ninp * charge_dims_pct)
        nce_dims = math.ceil(ninp * nce_dims_pct)
        n_embeds = ninp - (charge_dims + nce_dims)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=ninp,
            nhead=nhead,
            dim_feedforward=nhid,
            dropout=dropout,
            activation="gelu",
        )
        self.trans_decoder = nn.TransformerDecoder(decoder_layer, num_layers=layers)
        self.peak_decoder = MLP(ninp, ninp, output_dim=1, num_layers=3)

        logging.info(
            f"Creating embedding for spectra of length {constants.NUM_FRAG_EMBEDINGS}"
        )
        self.trans_decoder_embedding = nn.Embedding(
            constants.NUM_FRAG_EMBEDINGS, n_embeds
        )
        self.charge_encoder = ConcatenationEncoder(
            dims_add=charge_dims, dropout=dropout, max_val=10.0
        )
        self.nce_encoder = ConcatenationEncoder(
            dims_add=nce_dims, dropout=dropout, max_val=100.0
        )

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.trans_decoder_embedding.weight, -initrange, initrange)

    def forward(
        self,
        src: Tensor,
        charge: Tensor,
        nce: Tensor,
        debug: bool = False,
    ) -> Tensor:
        # One learned query per fragment, repeated across the batch
        trans_decoder_tgt = self.trans_decoder_embedding.weight.unsqueeze(1)
        trans_decoder_tgt = trans_decoder_tgt.repeat(1, charge.size(0), 1)
        trans_decoder_tgt = self.charge_encoder(trans_decoder_tgt, charge, debug=debug)
        trans_decoder_tgt = self.nce_encoder(trans_decoder_tgt, nce)
        if debug:
            logging.debug(f"TD: Shape of query embedding {trans_decoder_tgt.shape}")

        spectra_output = self.trans_decoder(memory=src, tgt=trans_decoder_tgt)
        if debug:
            logging.debug(f"TD: Shape of the output spectra {spectra_output.shape}")

        spectra_output = self.peak_decoder(spectra_output)
        if debug:
            logging.debug(f"TD: Shape of the MLP spectra {spectra_output.shape}")

        spectra_output = spectra_output.squeeze(-1).permute(1, 0)
        if debug:
            logging.debug(f"TD: Shape of the permuted spectra {spectra_output.shape}")

        # leaky_relu during training keeps gradients flowing for negative
        # intensities; relu at inference clamps them to zero
        if self.training:
            spectra_output = nn.functional.leaky_relu(spectra_output)
        else:
            spectra_output = nn.functional.relu(spectra_output)

        return spectra_output
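
# Illustrative shape sketch for _PeptideTransformerDecoder (assumptions, not
# guarantees; ConcatenationEncoder is defined elsewhere in the package):
#
#     decoder = _PeptideTransformerDecoder(ninp=64, nhead=2, nhid=128, layers=1, dropout=0.1)
#     memory = torch.rand(25, 4, 64)        # encoder output [seq_len, batch, ninp]
#     charge = torch.tensor([[2.0]] * 4)    # [batch, 1]
#     nce = torch.tensor([[27.0]] * 4)      # [batch, 1]
#     spectra = decoder(memory, charge, nce)
#     # expected shape: [batch, constants.NUM_FRAG_EMBEDINGS]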

_model_sections = [
    "TransEncoder",
    "TransDecoder",
    "AAEmbedding",
    "MODEmbedding",
    "FragmentEmbedding",
    "FragmentFFN",
    "RTFFN",
]


class PepTransformerModel(pl.LightningModule):
    """PepTransformerModel Predicts retention times and HCD spectra from peptides."""

    accepted_schedulers = ["plateau", "cosine", "onecycle"]
    model_sections = _model_sections
    __version__ = elfragmentador.__version__

    def __init__(
        self,
        num_decoder_layers: int = 6,
        num_encoder_layers: int = 6,
        nhid: int = 2024,
        ninp: int = 516,
        nhead: int = 4,
        dropout: float = 0.1,
        lr: float = 1e-4,
        scheduler: str = "plateau",
        lr_ratio: Union[float, int] = 200,
        steps_per_epoch: Optional[int] = None,
        loss_ratio: float = 5,
        trainable_sections: List[str] = _model_sections,
        *args,
        **kwargs,
    ) -> None:
        """__init__ Instantiates the class.

        Generates a new instance of the PepTransformerModel.

        Parameters
        ----------
        num_decoder_layers : int, optional
            Number of layers in the transformer decoder, by default 6
        num_encoder_layers : int, optional
            Number of layers in the transformer encoder, by default 6
        nhid : int, optional
            Number of dimensions used in the feedforward networks inside
            the transformer encoder and decoders, by default 2024
        ninp : int, optional
            Number of features to pass to the transformer encoder.
            The embedding transforms the input to this dimension, by default 516
        nhead : int, optional
            Number of attention heads in the transformer, by default 4
        dropout : float, optional
            Dropout probability, by default 0.1
        lr : float, optional
            Learning rate, by default 1e-4
        scheduler : str, optional
            Which scheduler to use, check the available ones with
            `PepTransformerModel.accepted_schedulers`, by default "plateau"
        lr_ratio : Union[float, int], optional
            For cosine annealing:
            Ratio of the initial learning rate to use with cosine annealing; for
            instance an lr of 1 and a ratio of 10 would give a minimum learning
            rate of 0.1.

            For onecycle:
            Ratio of the initial lr to the maximum one;
            for instance if lr is 0.1 and the ratio is 10, the maximum learning
            rate would be 1.0.

            By default 200
        steps_per_epoch : Optional[int], optional
            Expected number of steps per epoch, used internally to calculate
            learning rates when using the onecycle scheduler, by default None
        loss_ratio : float, optional
            The ratio of the spectrum to retention time loss to use when adding
            them before passing to the optimizer. Higher values mean more weight
            to spectra with respect to the retention time. By default 5
        trainable_sections : List[str], optional
            Sections of the model to train, any subset of
            `PepTransformerModel.model_sections`. By default all sections.
        """
        super().__init__()
        self.save_hyperparameters()

        # Peptide encoder
        self.encoder = _PeptideTransformerEncoder(
            ninp=ninp,
            dropout=dropout,
            nhead=nhead,
            nhid=nhid,
            layers=num_encoder_layers,
        )

        # Peptide decoder
        self.decoder = _PeptideTransformerDecoder(
            ninp=ninp,
            nhead=nhead,
            nhid=nhid,
            layers=num_decoder_layers,
            dropout=dropout,
        )

        # In this implementation, the rt predictor is a simple MLP
        # that combines the features from the transformer encoder
        self.rt_decoder = MLP(ninp, ninp, output_dim=1, num_layers=4)

        # Training related things
        self.mse_loss = nn.MSELoss()
        self.angle_loss = CosineLoss(dim=1, eps=1e-4)
        self.lr = lr

        assert (
            scheduler in self.accepted_schedulers
        ), f"Passed scheduler '{scheduler}' is not one of {self.accepted_schedulers}"
        self.scheduler = scheduler
        self.lr_ratio = lr_ratio
        self.steps_per_epoch = steps_per_epoch
        self.loss_ratio = loss_ratio

        self.model_sections = {
            "TransEncoder": self.encoder.transformer_encoder,
            "TransDecoder": self.decoder.trans_decoder,
            "AAEmbedding": self.encoder.aa_encoder.aa_encoder,
            "MODEmbedding": self.encoder.aa_encoder.mod_encoder,
            "FragmentEmbedding": self.decoder.trans_decoder_embedding,
            "FragmentFFN": self.decoder.peak_decoder,
            "RTFFN": self.rt_decoder,
        }

        self.make_trainable_sections(trainable_sections)
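
    # Illustrative instantiation (values are arbitrary, not recommendations);
    # a smaller configuration for quick experiments could look like:
    #
    #     model = PepTransformerModel(
    #         num_encoder_layers=2, num_decoder_layers=2,
    #         nhid=256, ninp=64, nhead=2,
    #         scheduler="onecycle", steps_per_epoch=100, lr=1e-4,
    #     )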

    def forward(
        self,
        src: Tensor,
        nce: Tensor,
        mods: Optional[Tensor] = None,
        charge: Optional[Tensor] = None,
        debug: bool = False,
    ) -> PredictionResults:
        """Forward Generate predictions.

        Provides the forward pass of the model.

        Parameters
        ----------
        src : Tensor
            Encoded peptide sequence [B, L] (see Details below)
        nce : Tensor
            float Tensor with the normalized collision energies [B, 1]
        mods : Optional[Tensor], optional
            Encoded modification sequence [B, L], by default None
        charge : Optional[Tensor], optional
            long Tensor with the charges [B, 1], by default None
        debug : bool, optional
            When set, it will log (a lot of) the shapes of the intermediate
            tensors inside the model. By default False

        Details
        -------
        src:
            The peptide is encoded as one integer per amino acid.
            "AAA" encoded for a max length of 5 would be
            torch.Tensor([ 1, 1, 1, 0, 0]).long()
        nce:
            Normalized collision energy to use during the prediction.
        charge:
            A tensor corresponding to the charges of each of the
            peptide precursors (long)
        mods:
            Modifications encoded as integers

        Returns
        -------
        PredictionResults
            A named tuple with two named results; irt and spectra
            iRT prediction [B, 1]
            Spectra prediction [B, self.num_queries]
        """
        if debug:
            logging.debug(
                f"PT: Shape of inputs src={src.shape},"
                f" mods={mods.shape if mods is not None else None},"
                f" nce={nce.shape},"
                f" charge={charge.shape if charge is not None else None}"
            )

        trans_encoder_output = self.encoder.forward(src=src, mods=mods, debug=debug)
        rt_output = self.rt_decoder.forward(trans_encoder_output)
        if debug:
            logging.debug(f"PT: Shape after RT decoder {rt_output.shape}")

        rt_output = rt_output.mean(dim=0)
        if debug:
            logging.debug(f"PT: Shape of RT output {rt_output.shape}")

        spectra_output = self.decoder.forward(
            src=trans_encoder_output, charge=charge, nce=nce, debug=debug
        )

        if debug:
            logging.debug(
                f"PT: Final Outputs of shapes {rt_output.shape}, {spectra_output.shape}"
            )

        return PredictionResults(rt_output, spectra_output)

    def batch_forward(
        self, inputs: TrainBatch, debug: bool = False
    ) -> PredictionResults:
        """batch_forward Forward function that takes a `TrainBatch` as an input.

        This function is a wrapper around forward but takes a named tuple as an
        input instead of the positional/keyword arguments.

        Parameters
        ----------
        inputs : TrainBatch
            Named tuple (check the documentation of that object for the field names)
        debug : bool, optional
            When set, it will log (a lot of) the shapes of the intermediate
            tensors inside the model. By default False

        Returns
        -------
        PredictionResults
            A named tuple with two named results; irt and spectra

        """

        def unsqueeze_if_needed(x, dims):
            if len(x.shape) != dims:
                if debug:
                    logging.debug(f"PT: Unsqueezing tensor of shape {x.shape}")
                x = x.unsqueeze(1)
            else:
                if debug:
                    logging.debug(f"PT: Skipping unsqueezing tensor of shape {x.shape}")
            return x

        if isinstance(inputs, list):
            inputs = TrainBatch(*inputs)

        out = self.forward(
            src=unsqueeze_if_needed(inputs.encoded_sequence, 2),
            mods=unsqueeze_if_needed(inputs.encoded_mods, 2),
            nce=unsqueeze_if_needed(inputs.nce, 2),
            charge=unsqueeze_if_needed(inputs.charge, 2),
            debug=debug,
        )
        return out
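
    # Usage sketch (illustrative; the exact field set of TrainBatch lives in
    # elfragmentador.datamodules and is assumed here from how it is accessed in
    # this module):
    #
    #     batch = TrainBatch(
    #         encoded_sequence=..., encoded_mods=..., charge=..., nce=...,
    #         encoded_spectra=..., norm_irt=...,
    #     )
    #     preds = model.batch_forward(batch)  # -> PredictionResults(irt=..., spectra=...)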

    @staticmethod
    def torch_batch_from_seq(seq: str, nce: float, charge: int) -> ForwardBatch:
        """
        Generate an input batch for the model from a sequence string.

        Args:
            seq (str): String describing the sequence to be predicted, e.g. "PEPT[PHOSPHO]IDEPINK"
            nce (float): Collision energy to use for the prediction, e.g. 27.0
            charge (int): Charge of the precursor to use for the prediction, e.g. 3

        Examples:
            >>> PepTransformerModel.torch_batch_from_seq("PEPTIDEPINK", 27.0, 3)
            ForwardBatch(src=tensor([[23, 13, 4, 13, 17, ...]]), nce=tensor([[27.]]), mods=tensor([[0, ... 0]]), charge=tensor([[3]]))
        """
        encoded_seq, encoded_mods = encoding_decoding.encode_mod_seq(seq)

        src = torch.Tensor(encoded_seq).unsqueeze(0).long()
        mods = torch.Tensor(encoded_mods).unsqueeze(0).long()
        in_charge = torch.Tensor([charge]).unsqueeze(0).long()
        in_nce = torch.Tensor([nce]).unsqueeze(0).float()

        # This is a named tuple
        out = ForwardBatch(src=src, nce=in_nce, mods=mods, charge=in_charge)
        return out

    def to_torchscript(self):
        """Exports a traced TorchScript version of the model.

        The static sizes of the NCE and charge encoders are set temporarily so
        that tracing uses fixed shapes, and are restored afterwards.
        """
        _fake_input_data_torchscript = self.torch_batch_from_seq(
            seq="MYM[OXIDATION]DIFIEDPEPTYDE", charge=3, nce=27.0
        )

        bkp_1 = self.decoder.nce_encoder.static_size
        self.decoder.nce_encoder.static_size = constants.NUM_FRAG_EMBEDINGS

        bkp_2 = self.decoder.charge_encoder.static_size
        self.decoder.charge_encoder.static_size = constants.NUM_FRAG_EMBEDINGS

        script = super().to_torchscript(
            example_inputs=_fake_input_data_torchscript, method="trace"
        )

        self.decoder.nce_encoder.static_size = bkp_1
        self.decoder.charge_encoder.static_size = bkp_2

        return script
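
    # Usage sketch (illustrative; the file name is hypothetical):
    #
    #     model = PepTransformerModel()
    #     script = model.to_torchscript()
    #     script.save("peptransformer_traced.pt")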

    def predict_from_seq(
        self,
        seq: str,
        charge: int,
        nce: float,
        as_spectrum: bool = False,
        debug: bool = False,
    ) -> Union[PredictionResults, Spectrum]:
        """predict_from_seq Predicts spectra from a sequence as a string.

        Utility method that gets a sequence as a string, encodes it internally
        to the correct input form and outputs the predicted spectra.

        Note that the spectrum is not decoded as an output, please check
        `elfragmentador.encoding_decoding.decode_fragment_tensor` for the
        decoding.

        The irt is scaled by 100 and is in the Biognosys scale.

        TODO: consider if the output should be decoded ...

        Parameters
        ----------
        seq : str
            Sequence to use for prediction, supports modifications in the form
            of S[PHOSPHO], S[+80] and T[181]
        charge : int
            Precursor charge to be assumed during the fragmentation
        nce : float
            Normalized collision energy to use during the prediction
        as_spectrum : bool, optional
            Whether to return a Spectrum object instead of the raw tensor
            predictions, by default False
        debug : bool, optional
            When set, it will log (at debug level) a lot of the shapes of the
            intermediate tensors inside the model. By default False

        Returns
        -------
        PredictionResults
            A named tuple with two named results; irt and spectra
        Spectrum
            A spectrum object with the predicted spectrum

        Examples
        --------
        >>> pl.seed_everything(42)
        42
        >>> my_model = PepTransformerModel()  # Or load the model from a checkpoint
        >>> _ = my_model.eval()
        >>> my_model.predict_from_seq("MYPEPT[PHOSPHO]IDEK", 3, 27)
        PredictionResults(irt=tensor([...], grad_fn=<SqueezeBackward1>), \
spectra=tensor([...], grad_fn=<SqueezeBackward1>))
        >>> out = my_model.predict_from_seq("MYPEPT[PHOSPHO]IDEK", 3, 27, as_spectrum=True)
        >>> type(out)
        <class 'elfragmentador.spectra.Spectrum'>
        >>> # my_model.predict_from_seq("MYPEPT[PHOSPHO]IDEK", 3, 27, debug=True)
        """

        in_batch = self.torch_batch_from_seq(seq, nce, charge)

        if debug:
            logging.debug(
                f">>PT: PEPTIDE INPUT Shape of peptide"
                f" inputs {in_batch.src.shape}, {in_batch.charge.shape}"
            )

        # TODO consider adding GPU inference
        out = self.forward(debug=debug, **in_batch._asdict())
        out = PredictionResults(*[x.squeeze(0) for x in out])
        logging.debug(out)

        # rt should be in seconds for spectrast ...
        # irt should be non-dimensional
        if as_spectrum:
            out = Spectrum.from_tensors(
                sequence_tensor=in_batch.src.squeeze().numpy(),
                fragment_tensor=out.spectra / out.spectra.max(),
                mod_tensor=in_batch.mods.squeeze().numpy(),
                charge=charge,
                nce=nce,
                rt=float(out.irt) * 100 * 60,
                irt=float(out.irt) * 100,
            )

        return out

    @staticmethod
    def add_model_specific_args(parser: _ArgumentGroup) -> _ArgumentGroup:
        """add_model_specific_args Adds arguments to a parser.

        It is used to add the command line arguments for the training/generation
        of the model.

        Parameters
        ----------
        parser : _ArgumentGroup
            An argparse parser (anything that has the `.add_argument` method)
            to which the arguments will be added

        Returns
        -------
        _ArgumentGroup
            Same parser with the added arguments
        """
        parser.add_argument(
            "--num_queries",
            default=150,
            type=int,
            help="Expected encoding length of the spectra",
        )
        parser.add_argument(
            "--num_decoder_layers",
            default=6,
            type=int,
            help="Number of sub-decoder-layers in the decoder",
        )
        parser.add_argument(
            "--num_encoder_layers",
            default=6,
            type=int,
            help="Number of sub-encoder-layers in the encoder",
        )
        parser.add_argument(
            "--nhid",
            default=1024,
            type=int,
            help="Dimension of the feedforward networks",
        )
        parser.add_argument(
            "--ninp",
            default=516,
            type=int,
            help="Number of input features to the transformer encoder",
        )
        parser.add_argument(
            "--nhead", default=12, type=int, help="Number of attention heads"
        )
        parser.add_argument("--dropout", default=0.1, type=float)
        parser.add_argument("--lr", default=1e-4, type=float)
        parser.add_argument(
            "--scheduler",
            default="plateau",
            type=str,
            help=(
                "Scheduler to use during training, "
                f"either of {PepTransformerModel.accepted_schedulers}"
            ),
        )
        parser.add_argument(
            "--lr_ratio",
            default=200.0,
            type=float,
            help=(
                "For cosine annealing: "
                "ratio of the initial learning rate to use with cosine annealing;"
                " for instance an lr of 1 and a ratio of 10 would give a minimum"
                " learning rate of 0.1\n"
                "For onecycle: "
                "ratio of the initial lr to the maximum one; "
                "for instance if lr is 0.1 and the ratio is 10, the max learning"
                " rate would be 1.0"
            ),
        )
        parser.add_argument(
            "--loss_ratio",
            default=5.0,
            type=float,
            help=(
                "Ratio between the retention time and the spectrum loss"
                " (higher values mean more weight to the spectra loss"
                " with respect to the retention time loss)"
            ),
        )
        parser.add_argument(
            "--trainable_sections",
            nargs="+",
            type=str,
            default=PepTransformerModel.model_sections,
            help=(
                "Sections of the model to train, "
                f"can be any subset of {PepTransformerModel.model_sections}"
            ),
        )

        return parser
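
    # Usage sketch (illustrative):
    #
    #     parser = ArgumentParser()
    #     PepTransformerModel.add_model_specific_args(parser)
    #     args = parser.parse_args(["--nhead", "8", "--scheduler", "onecycle"])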

    def make_trainable_sections(self, sections: List) -> None:
        """Freezes the whole model and then unfreezes the requested sections.

        `sections` can be any subset of `PepTransformerModel.model_sections`.
        """

        def set_grad_section(model_section, trainable=True):
            """Freezes or unfreezes a model section."""
            for param in model_section.parameters():
                param.requires_grad = trainable

        logging.warning("Freezing the model")
        set_grad_section(self, trainable=False)

        for section in sections:
            logging.warning(f"Unfreezing {section}")
            set_grad_section(self.model_sections[section], trainable=True)
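
    # Usage sketch (illustrative): fine-tune only the retention time head and
    # the fragment intensity head, keeping everything else frozen.
    #
    #     model.make_trainable_sections(["RTFFN", "FragmentFFN"])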

    def configure_optimizers(
        self,
    ) -> Union[
        Tuple[List[AdamW], List[Dict[str, Union[ReduceLROnPlateau, str]]]],
        Tuple[List[AdamW], List[Dict[str, Union[CosineAnnealingWarmRestarts, str]]]],
        Tuple[List[AdamW], List[Dict[str, Union[OneCycleLR, str]]]],
    ]:
        """configure_optimizers Configures the optimizers for training.

        It is internally used by pytorch_lightning during training, so far I
        implemented 3 options (set when making the module).

        OneCycleLR seems to give the best results overall in the least amount
        of time. The only tradeoff that I see is that resuming training does
        not seem to be really easy.

        Check the pytorch_lightning documentation to see how this is used in the
        training loop.

        Returns
        -------
        Union[
            Tuple[List[AdamW], List[Dict[str, Union[ReduceLROnPlateau, str]]]],
            Tuple[List[AdamW], List[Dict[str, Union[CosineAnnealingWarmRestarts, str]]]],
            Tuple[List[AdamW], List[Dict[str, Union[OneCycleLR, str]]]],
        ]
            Two lists, one containing the optimizer and another containing the
            scheduler.

        Raises
        ------
        ValueError
            Raised when a scheduler that is not supported is requested.
            If you want to use another one, please over-write this method
            or make a subclass with the modification. (PRs are also welcome)

        """
        opt = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr
        )

        if self.scheduler == "plateau":
            scheduler_dict = {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                    opt, mode="min", factor=0.5, patience=2, verbose=True
                ),
                "interval": "epoch",
                "monitor": "val_l",
            }
        elif self.scheduler == "cosine":
            assert self.lr_ratio > 1
            scheduler_dict = {
                "scheduler": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                    opt,
                    T_0=1,
                    T_mult=2,
                    eta_min=self.lr / self.lr_ratio,
                    last_epoch=-1,
                    verbose=False,
                ),
                "interval": "step",
            }
        elif self.scheduler == "onecycle":
            assert self.steps_per_epoch is not None, "Please set steps_per_epoch"
            if self.trainer.max_epochs == 1000:
                warnings.warn("Max epochs was 1000, make sure you want this")
            if self.lr_ratio > 20:
                warnings.warn(
                    f"Provided LR ratio '{self.lr_ratio}' seems a little high,"
                    " make sure you want that for the OneCycleLR scheduler"
                )
                time.sleep(3)  # just so the user has time to see the message...
            max_lr = self.lr * self.lr_ratio
            logging.info(
                f">> Scheduler setup: max_lr {max_lr}, "
                f"Max Epochs: {self.trainer.max_epochs}, "
                f"Steps per epoch: {self.steps_per_epoch}, "
                f"Accumulate Batches {self.trainer.accumulate_grad_batches}"
            )
            spe = self.steps_per_epoch // self.trainer.accumulate_grad_batches
            scheduler_dict = {
                "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                    opt,
                    max_lr,
                    epochs=self.trainer.max_epochs,
                    steps_per_epoch=spe,
                ),
                "interval": "step",
            }

        else:
            raise ValueError(
                "Scheduler should be one of 'plateau', 'cosine' or 'onecycle', passed: ",
                self.scheduler,
            )
        # TODO check if using different optimizers for different parts of the
        # model would work better
        logging.info(f"\n\n>>> Setting up schedulers:\n\n{scheduler_dict}")

        return [opt], [scheduler_dict]
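
    # Worked example (illustrative): with the onecycle scheduler, lr=1e-4 and
    # lr_ratio=25, the peak learning rate is max_lr = 1e-4 * 25 = 2.5e-3, and
    # the effective steps per epoch are steps_per_epoch // accumulate_grad_batches.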

    def _step(self, batch: TrainBatch, batch_idx: int) -> Dict[str, Tensor]:
        """Run main functionality during training and testing steps.

        Internally used in training and evaluation steps during the training
        loop in pytorch_lightning.

        Does inference, loss calculation, handling of missing values ...
        """

        if isinstance(batch, list):
            batch = TrainBatch(*batch)

        yhat_irt, yhat_spectra = self.batch_forward(batch)
        yhat_irt = yhat_irt[~batch.norm_irt.isnan()]
        norm_irt = batch.norm_irt[~batch.norm_irt.isnan()]

        loss_irt = self.mse_loss(yhat_irt, norm_irt.float())
        loss_spectra = self.angle_loss(yhat_spectra, batch.encoded_spectra).mean()

        if len(norm_irt.data) == 0:
            total_loss = loss_spectra
        else:
            total_loss = loss_irt + loss_spectra * self.loss_ratio
            total_loss = total_loss / (self.loss_ratio + 1)

        out = {
            "l": total_loss,
            "irt_l": loss_irt,
            "spec_l": loss_spectra,
        }

        assert not torch.isnan(total_loss), logging.error(
            f"Fail at... \n Loss: {total_loss},\n"
            f"\n loss_irt: {loss_irt}\n"
            f"\n loss_spectra: {loss_spectra}\n"
            f"\n yhat_spec: {yhat_spectra},\n"
            f"\n y_spec: {batch.encoded_spectra}\n"
            f"\n y_irt: {norm_irt}, {len(norm_irt.data)}"
        )

        return out
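
    # Worked example of the loss combination above (illustrative): with
    # loss_ratio = 5, loss_irt = 0.2 and loss_spectra = 0.1, the combined loss is
    # (0.2 + 0.1 * 5) / (5 + 1) = 0.7 / 6 ≈ 0.117.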

    def training_step(
        self, batch: TrainBatch, batch_idx: Optional[int] = None
    ) -> Dict[str, Tensor]:
        """See pytorch_lightning documentation."""
        step_out = self._step(batch, batch_idx=batch_idx)
        log_dict = {"train_" + k: v for k, v in step_out.items()}
        log_dict.update({"LR": self.trainer.optimizers[0].param_groups[0]["lr"]})

        self.log_dict(
            log_dict,
            prog_bar=True,
            # reduce_fx=nanmean,
        )

        return {"loss": step_out["l"]}

    def validation_step(
        self, batch: TrainBatch, batch_idx: Optional[int] = None
    ) -> None:
        """See pytorch_lightning documentation."""
        step_out = self._step(batch, batch_idx=batch_idx)
        log_dict = {"val_" + k: v for k, v in step_out.items()}

        self.log_dict(log_dict, prog_bar=True, reduce_fx=nanmean)

    __doc__ += "\n\n" + __init__.__doc__ + "\n\n" + forward.__doc__