# Coverage for model.py: 61%

import logging

try:
    from typing import Dict, List, Tuple, Optional, Union, Literal

    LiteralFalse = Literal[False]
except ImportError:
    # Python pre-3.8 compatibility
    from typing import Dict, List, Tuple, Optional, Union, NewType

    LiteralFalse = NewType("LiteralFalse", bool)

import warnings
import math
import time
from collections import namedtuple

import torch
from torch import Tensor, nn
import pytorch_lightning as pl

from argparse import _ArgumentGroup, ArgumentParser

import elfragmentador
from elfragmentador import constants
from elfragmentador import encoding_decoding
from elfragmentador.spectra import Spectrum
from elfragmentador.datamodules import TrainBatch
from elfragmentador.metrics import CosineLoss
from elfragmentador.nn_encoding import (
    ConcatenationEncoder,
    AASequenceEmbedding,
)
from elfragmentador.math_utils import nanmean
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import (
    CosineAnnealingWarmRestarts,
    OneCycleLR,
    ReduceLROnPlateau,
)

PredictionResults = namedtuple("PredictionResults", "irt spectra")
ForwardBatch = namedtuple("ForwardBatch", "src nce mods charge")
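
# The named tuples defined above act as lightweight containers: PredictionResults
# holds the two model outputs (irt, spectra) and ForwardBatch holds the four inputs
# expected by PepTransformerModel.forward (src, nce, mods, charge). Illustrative use
# with hypothetical tensors:
#
#     batch = ForwardBatch(src=src, nce=nce, mods=mods, charge=charge)
#     model.forward(**batch._asdict())   # as done in predict_from_seq below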


class MLP(nn.Module):
    """MLP implements a very simple multi-layer perceptron (also called FFN).

    Concatenates hidden linear layers with activations for n layers.
    This implementation uses gelu instead of relu:
    (linear > gelu) * (n-1) > linear

    Based on: https://github.com/facebookresearch/detr/blob/models/detr.py#L289
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ) -> None:
        """__init__ creates a new instance of the MLP.

        Parameters
        ----------
        input_dim : int
            Expected dimensions for the input
        hidden_dim : int
            Number of dimensions of the hidden layers
        output_dim : int
            Output dimensions
        num_layers : int
            Number of layers (total)
        """
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass over the network.

        Parameters
        ----------
        x : Tensor
            Dimensions should match the ones specified when instantiating the class

        Returns
        -------
        Tensor
            The dims of this tensor are defined when instantiating the class

        Examples
        --------
        >>> pl.seed_everything(42)
        42
        >>> net = MLP(1000, 512, 2, 10)
        >>> out = net.forward(torch.rand([5, 1000]))
        >>> out
        tensor([[-0.0061, -0.0219],
                [-0.0061, -0.0219],
                [-0.0061, -0.0220],
                [-0.0061, -0.0220],
                [-0.0061, -0.0219]], grad_fn=<AddmmBackward>)
        >>> out.shape
        torch.Size([5, 2])
        """
        for i, layer in enumerate(self.layers):
            x = (
                torch.nn.functional.gelu(layer(x))
                if i < self.num_layers - 1
                else layer(x)
            )
        return x


class _PeptideTransformerEncoder(torch.nn.Module):
    def __init__(
        self, ninp: int, dropout: float, nhead: int, nhid: int, layers: int
    ) -> None:
        super().__init__()

        # Amino acid embedding
        self.aa_encoder = AASequenceEmbedding(ninp=ninp, position_ratio=0.1)

        # Transformer encoder sections
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=ninp,
            nhead=nhead,
            dim_feedforward=nhid,
            dropout=dropout,
            activation="gelu",
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, layers)

    def forward(self, src: Tensor, mods: Tensor, debug: bool = False) -> Tensor:
        trans_encoder_mask = ~src.bool()
        if debug:
            logging.debug(f"TE: Shape of mask {trans_encoder_mask.shape}")

        src = self.aa_encoder.forward(src=src, mods=mods, debug=debug)
        if debug:
            logging.debug(f"TE: Shape after AASequence encoder {src.shape}")

        trans_encoder_output = self.transformer_encoder.forward(
            src, src_key_padding_mask=trans_encoder_mask
        )
        if debug:
            logging.debug(f"TE: Shape after trans encoder {trans_encoder_output.shape}")

        return trans_encoder_output
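
# Note on the padding mask above (illustrative sketch): positions encoded as 0 are
# treated as padding, so `~src.bool()` flags them as True, which
# nn.TransformerEncoder interprets as "ignore this key". For a single peptide
# ("AAA" padded to length 5; values only for illustration):
#
#     src = torch.tensor([[1, 1, 1, 0, 0]])
#     ~src.bool()  # tensor([[False, False, False,  True,  True]])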


class _PeptideTransformerDecoder(torch.nn.Module):
    def __init__(
        self,
        ninp: int,
        nhead: int,
        nhid: int,
        layers: int,
        dropout: float,
        charge_dims_pct: float = 0.05,
        nce_dims_pct: float = 0.05,
    ) -> None:
        super().__init__()
        logging.info(
            f"Creating TransformerDecoder nhid={nhid}, ninp={ninp} nhead={nhead} layers={layers}"
        )
        charge_dims = math.ceil(ninp * charge_dims_pct)
        nce_dims = math.ceil(ninp * nce_dims_pct)
        n_embeds = ninp - (charge_dims + nce_dims)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=ninp,
            nhead=nhead,
            dim_feedforward=nhid,
            dropout=dropout,
            activation="gelu",
        )
        self.trans_decoder = nn.TransformerDecoder(decoder_layer, num_layers=layers)
        self.peak_decoder = MLP(ninp, ninp, output_dim=1, num_layers=3)

        logging.info(
            f"Creating embedding for spectra of length {constants.NUM_FRAG_EMBEDINGS}"
        )
        self.trans_decoder_embedding = nn.Embedding(
            constants.NUM_FRAG_EMBEDINGS, n_embeds
        )
        self.charge_encoder = ConcatenationEncoder(
            dims_add=charge_dims, dropout=dropout, max_val=10.0
        )
        self.nce_encoder = ConcatenationEncoder(
            dims_add=nce_dims, dropout=dropout, max_val=100.0
        )

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.trans_decoder_embedding.weight, -initrange, initrange)

    def forward(
        self,
        src: Tensor,
        charge: Tensor,
        nce: Tensor,
        debug: bool = False,
    ) -> Tensor:
        trans_decoder_tgt = self.trans_decoder_embedding.weight.unsqueeze(1)
        trans_decoder_tgt = trans_decoder_tgt.repeat(1, charge.size(0), 1)
        trans_decoder_tgt = self.charge_encoder(trans_decoder_tgt, charge, debug=debug)
        trans_decoder_tgt = self.nce_encoder(trans_decoder_tgt, nce)
        if debug:
            logging.debug(f"TD: Shape of query embedding {trans_decoder_tgt.shape}")

        spectra_output = self.trans_decoder(memory=src, tgt=trans_decoder_tgt)
        if debug:
            logging.debug(f"TD: Shape of the output spectra {spectra_output.shape}")

        spectra_output = self.peak_decoder(spectra_output)
        if debug:
            logging.debug(f"TD: Shape of the MLP spectra {spectra_output.shape}")

        spectra_output = spectra_output.squeeze(-1).permute(1, 0)
        if debug:
            logging.debug(f"TD: Shape of the permuted spectra {spectra_output.shape}")

        if self.training:
            spectra_output = nn.functional.leaky_relu(spectra_output)
        else:
            spectra_output = nn.functional.relu(spectra_output)

        return spectra_output
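
# Shape sketch for the decoder above (assuming ConcatenationEncoder appends its
# `dims_add` features along the last dimension, which is what the dimension
# bookkeeping in __init__ implies):
#
#     trans_decoder_embedding.weight          # [NUM_FRAG_EMBEDINGS, n_embeds]
#     .unsqueeze(1).repeat(1, batch, 1)       # [NUM_FRAG_EMBEDINGS, batch, n_embeds]
#     charge_encoder / nce_encoder            # [NUM_FRAG_EMBEDINGS, batch, ninp]
#     trans_decoder(...)                      # [NUM_FRAG_EMBEDINGS, batch, ninp]
#     peak_decoder(...)                       # [NUM_FRAG_EMBEDINGS, batch, 1]
#     .squeeze(-1).permute(1, 0)              # [batch, NUM_FRAG_EMBEDINGS]
#
# The learned fragment queries play the same role as the object queries in the DETR
# implementation referenced in the MLP docstring.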


_model_sections = [
    "TransEncoder",
    "TransDecoder",
    "AAEmbedding",
    "MODEmbedding",
    "FragmentEmbedding",
    "FragmentFFN",
    "RTFFN",
]
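
# These names correspond to the sub-modules registered in
# PepTransformerModel.__init__ (self.model_sections) and are the valid values for
# the `trainable_sections` argument and for `make_trainable_sections`.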


class PepTransformerModel(pl.LightningModule):
    """PepTransformerModel Predicts retention times and HCD spectra from peptides."""

    accepted_schedulers = ["plateau", "cosine", "onecycle"]
    model_sections = _model_sections
    __version__ = elfragmentador.__version__

    def __init__(
        self,
        num_decoder_layers: int = 6,
        num_encoder_layers: int = 6,
        nhid: int = 2024,
        ninp: int = 516,
        nhead: int = 4,
        dropout: float = 0.1,
        lr: float = 1e-4,
        scheduler: str = "plateau",
        lr_ratio: Union[float, int] = 200,
        steps_per_epoch: Optional[int] = None,
        loss_ratio: float = 5,
        trainable_sections: List[str] = _model_sections,
        *args,
        **kwargs,
    ) -> None:
        """__init__ Instantiates the class.

        Generates a new instance of the PepTransformerModel.

        Parameters
        ----------
        num_decoder_layers : int, optional
            Number of layers in the transformer decoder, by default 6
        num_encoder_layers : int, optional
            Number of layers in the transformer encoder, by default 6
        nhid : int, optional
            Number of dimensions used in the feedforward networks inside
            the transformer encoder and decoders, by default 2024
        ninp : int, optional
            Number of features to pass to the transformer encoder.
            The embedding transforms the input to this dimension, by default 516
        nhead : int, optional
            Number of attention heads in the transformer, by default 4
        dropout : float, optional
            dropout, by default 0.1
        lr : float, optional
            Learning rate, by default 1e-4
        scheduler : str, optional
            What scheduler to use, check the available ones with
            `PepTransformerModel.accepted_schedulers`, by default "plateau"
        lr_ratio : Union[float, int], optional
            For cosine annealing:
            Ratio of the initial learning rate to use with cosine annealing;
            for instance a lr of 1 and a ratio of 10 would give a minimum
            learning rate of 0.1.

            For onecycle:
            Ratio of the initial lr to the maximum one;
            for instance if lr is 0.1 and the ratio is 10, the max learning
            rate would be 1.0.

            By default 200
        steps_per_epoch : Optional[int], optional
            Expected number of steps per epoch, used internally to calculate
            learning rates when using the onecycle scheduler, by default None
        loss_ratio: float, optional
            The ratio of the spectrum to retention time loss to use when adding
            before passing to the optimizer. Higher values mean more weight to
            spectra with respect to the retention time. By default 5
        """
        super().__init__()
        self.save_hyperparameters()

        # Peptide encoder
        self.encoder = _PeptideTransformerEncoder(
            ninp=ninp,
            dropout=dropout,
            nhead=nhead,
            nhid=nhid,
            layers=num_encoder_layers,
        )

        # Peptide decoder
        self.decoder = _PeptideTransformerDecoder(
            ninp=ninp,
            nhead=nhead,
            nhid=nhid,
            layers=num_decoder_layers,
            dropout=dropout,
        )

        # In this implementation, the rt predictor is a simple MLP
        # that combines the features from the transformer encoder
        self.rt_decoder = MLP(ninp, ninp, output_dim=1, num_layers=4)

        # Training related things
        self.mse_loss = nn.MSELoss()
        self.angle_loss = CosineLoss(dim=1, eps=1e-4)
        self.lr = lr

        assert (
            scheduler in self.accepted_schedulers
        ), f"Passed scheduler '{scheduler}' is not one of {self.accepted_schedulers}"
        self.scheduler = scheduler
        self.lr_ratio = lr_ratio
        self.steps_per_epoch = steps_per_epoch
        self.loss_ratio = loss_ratio

        self.model_sections = {
            "TransEncoder": self.encoder.transformer_encoder,
            "TransDecoder": self.decoder.trans_decoder,
            "AAEmbedding": self.encoder.aa_encoder.aa_encoder,
            "MODEmbedding": self.encoder.aa_encoder.mod_encoder,
            "FragmentEmbedding": self.decoder.trans_decoder_embedding,
            "FragmentFFN": self.decoder.peak_decoder,
            "RTFFN": self.rt_decoder,
        }

        self.make_trainable_sections(trainable_sections)
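
    # Illustrative construction (sketch only; the hyperparameter values below are
    # arbitrary placeholders, not tuned settings):
    #
    #     model = PepTransformerModel(
    #         num_encoder_layers=2,
    #         num_decoder_layers=2,
    #         nhid=512,
    #         ninp=64,
    #         nhead=4,
    #         scheduler="onecycle",
    #         steps_per_epoch=100,
    #         trainable_sections=["FragmentFFN", "RTFFN"],
    #     )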

    def forward(
        self,
        src: Tensor,
        nce: Tensor,
        mods: Optional[Tensor] = None,
        charge: Optional[Tensor] = None,
        debug: bool = False,
    ) -> PredictionResults:
        """Forward Generate predictions.

        Provides the function for the forward pass to the model.

        Parameters
        ----------
        src : Tensor
            Encoded peptide sequence [B, L] (see the Details section)
        nce : Tensor
            float Tensor with the normalized collision energies [B, 1]
        mods : Optional[Tensor], optional
            Encoded modification sequence [B, L], by default None
        charge : Optional[Tensor], optional
            long Tensor with the charges [B, 1], by default None
        debug : bool, optional
            When set, it will log (verbosely) the shapes of the intermediate
            tensors inside the model. By default False

        Details
        -------
        src:
            The peptide is encoded as integers for the amino acids.
            "AAA" encoded for a max length of 5 would be
            torch.Tensor([ 1, 1, 1, 0, 0]).long()
        nce:
            Normalized collision energy to use during the prediction.
        charge:
            A tensor corresponding to the charges of each of the
            peptide precursors (long)
        mods:
            Modifications encoded as integers

        Returns
        -------
        PredictionResults
            A named tuple with two named results; irt and spectra
            iRT prediction [B, 1]
            Spectra prediction [B, self.num_queries]
        """
        if debug:
            logging.debug(
                f"PT: Shape of inputs src={src.shape},"
                f" mods={mods.shape if mods is not None else None},"
                f" nce={nce.shape}"
                f" charge={charge.shape}"
            )

        trans_encoder_output = self.encoder.forward(src=src, mods=mods, debug=debug)
        rt_output = self.rt_decoder.forward(trans_encoder_output)
        if debug:
            logging.debug(f"PT: Shape after RT decoder {rt_output.shape}")

        rt_output = rt_output.mean(dim=0)
        if debug:
            logging.debug(f"PT: Shape of RT output {rt_output.shape}")

        spectra_output = self.decoder.forward(
            src=trans_encoder_output, charge=charge, nce=nce, debug=debug
        )

        if debug:
            logging.debug(
                f"PT: Final Outputs of shapes {rt_output.shape}, {spectra_output.shape}"
            )

        return PredictionResults(rt_output, spectra_output)
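
    # Illustrative call for an instance `model` of this class (sketch; the tensor
    # contents are placeholders, the shapes follow the docstring above):
    #
    #     src = torch.tensor([[23, 13, 4, 13, 0, 0]]).long()   # one encoded peptide, padded
    #     mods = torch.zeros_like(src)                          # no modifications
    #     nce = torch.tensor([[27.0]])                          # normalized collision energy
    #     charge = torch.tensor([[2]]).long()                   # precursor charge
    #     irt, spectra = model.forward(src=src, nce=nce, mods=mods, charge=charge)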

    def batch_forward(
        self, inputs: TrainBatch, debug: bool = False
    ) -> PredictionResults:
        """batch_forward Forward function that takes a `TrainBatch` as an input.

        This function is a wrapper around forward but takes a named tuple as an
        input instead of the positional/keyword arguments.

        Parameters
        ----------
        inputs : TrainBatch
            Named tuple (check the documentation of that object for the field names)
        debug : bool, optional
            When set, it will log (verbosely) the shapes of the intermediate
            tensors inside the model. By default False

        Returns
        -------
        PredictionResults
            A named tuple with two named results; irt and spectra
        """

        def unsqueeze_if_needed(x, dims):
            if len(x.shape) != dims:
                if debug:
                    logging.debug(f"PT: Unsqueezing tensor of shape {x.shape}")
                x = x.unsqueeze(1)
            else:
                if debug:
                    logging.debug(f"PT: Skipping unsqueezing tensor of shape {x.shape}")
            return x

        if isinstance(inputs, list):
            inputs = TrainBatch(*inputs)

        out = self.forward(
            src=unsqueeze_if_needed(inputs.encoded_sequence, 2),
            mods=unsqueeze_if_needed(inputs.encoded_mods, 2),
            nce=unsqueeze_if_needed(inputs.nce, 2),
            charge=unsqueeze_if_needed(inputs.charge, 2),
            debug=debug,
        )
        return out
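
    # Only the TrainBatch fields named above (encoded_sequence, encoded_mods, nce,
    # charge) are consumed here; the target fields (norm_irt, encoded_spectra) are
    # used later in `_step` for the loss calculation.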

    @staticmethod
    def torch_batch_from_seq(seq: str, nce: float, charge: int):
        """
        Generate an input batch for the model from a sequence string.

        Args:
            seq (str): String describing the sequence to be predicted, e.g. "PEPT[PHOSPHO]IDEPINK"
            nce (float): Collision energy to use for the prediction, e.g. 27.0
            charge (int): Charge of the precursor to use for the prediction, e.g. 3

        Examples:
            >>> PepTransformerModel.torch_batch_from_seq("PEPTIDEPINK", 27.0, 3)
            ForwardBatch(src=tensor([[23, 13, 4, 13, 17, ...]]), nce=tensor([[27.]]), mods=tensor([[0, ... 0]]), charge=tensor([[3]]))
        """
        encoded_seq, encoded_mods = encoding_decoding.encode_mod_seq(seq)

        src = torch.Tensor(encoded_seq).unsqueeze(0).long()
        mods = torch.Tensor(encoded_mods).unsqueeze(0).long()
        in_charge = torch.Tensor([charge]).unsqueeze(0).long()
        in_nce = torch.Tensor([nce]).unsqueeze(0).float()

        # This is a named tuple
        out = ForwardBatch(src=src, nce=in_nce, mods=mods, charge=in_charge)
        return out

    def to_torchscript(self):
        _fake_input_data_torchscript = self.torch_batch_from_seq(
            seq="MYM[OXIDATION]DIFIEDPEPTYDE", charge=3, nce=27.0
        )

        # The ConcatenationEncoders are temporarily pinned to the number of fragment
        # queries while tracing and restored afterwards
        bkp_1 = self.decoder.nce_encoder.static_size
        self.decoder.nce_encoder.static_size = constants.NUM_FRAG_EMBEDINGS

        bkp_2 = self.decoder.charge_encoder.static_size
        self.decoder.charge_encoder.static_size = constants.NUM_FRAG_EMBEDINGS

        script = super().to_torchscript(
            example_inputs=_fake_input_data_torchscript, method="trace"
        )

        self.decoder.nce_encoder.static_size = bkp_1
        self.decoder.charge_encoder.static_size = bkp_2

        return script
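
    # Illustrative export for an instance `model` (sketch; the path is a placeholder):
    #
    #     traced = model.to_torchscript()
    #     traced.save("path/to/model.pt")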

    def predict_from_seq(
        self, seq: str, charge: int, nce: float, as_spectrum=False, debug: bool = False
    ) -> Union[PredictionResults, Spectrum]:
        """predict_from_seq Predicts spectra from a sequence as a string.

        Utility method that gets a sequence as a string, encodes it internally
        to the correct input form and outputs the predicted spectra.

        Note that the spectrum is not decoded as an output, please check
        `elfragmentador.encoding_decoding.decode_fragment_tensor` for the
        decoding.

        The irt is scaled by 100 and is in the Biognosys scale.

        TODO: consider if the output should be decoded ...

        Parameters
        ----------
        seq : str
            Sequence to use for prediction, supports modifications in the form
            of S[PHOSPHO], S[+80] and T[181]
        charge : int
            Precursor charge to be assumed during the fragmentation
        nce : float
            Normalized collision energy to use during the prediction
        as_spectrum : bool
            Whether to return a Spectrum object instead of the raw tensor predictions
        debug : bool, optional
            When set, it will write to logging at a debug level (verbosely) the
            shapes of the intermediate tensors inside the model. By default False

        Returns
        -------
        PredictionResults
            A named tuple with two named results; irt and spectra
        Spectrum
            A spectrum object with the predicted spectrum

        Examples
        --------
        >>> pl.seed_everything(42)
        42
        >>> my_model = PepTransformerModel() # Or load the model from a checkpoint
        >>> _ = my_model.eval()
        >>> my_model.predict_from_seq("MYPEPT[PHOSPHO]IDEK", 3, 27)
        PredictionResults(irt=tensor([...], grad_fn=<SqueezeBackward1>), \
spectra=tensor([...], grad_fn=<SqueezeBackward1>))
        >>> out = my_model.predict_from_seq("MYPEPT[PHOSPHO]IDEK", 3, 27, as_spectrum=True)
        >>> type(out)
        <class 'elfragmentador.spectra.Spectrum'>
        >>> # my_model.predict_from_seq("MYPEPT[PHOSPHO]IDEK", 3, 27, debug=True)
        """

        in_batch = self.torch_batch_from_seq(seq, nce, charge)

        if debug:
            logging.debug(
                f">>PT: PEPTIDE INPUT Shape of peptide"
                f" inputs {in_batch.src.shape}, {in_batch.charge.shape}"
            )

        # TODO: consider adding GPU inference
        out = self.forward(debug=debug, **in_batch._asdict())
        out = PredictionResults(*[x.squeeze(0) for x in out])
        logging.debug(out)

        # rt should be in seconds for spectrast ...
        # irt should be non-dimensional
        if as_spectrum:
            out = Spectrum.from_tensors(
                sequence_tensor=in_batch.src.squeeze().numpy(),
                fragment_tensor=out.spectra / out.spectra.max(),
                mod_tensor=in_batch.mods.squeeze().numpy(),
                charge=charge,
                nce=nce,
                rt=float(out.irt) * 100 * 60,
                irt=float(out.irt) * 100,
            )

        return out

    @staticmethod
    def add_model_specific_args(parser: _ArgumentGroup) -> _ArgumentGroup:
        """add_model_specific_args Adds arguments to a parser.

        It is used to add the command line arguments for the training/generation
        of the model.

        Parameters
        ----------
        parser : _ArgumentGroup
            An argparse parser (anything that has the `.add_argument` method)
            to which the arguments will be added

        Returns
        -------
        _ArgumentGroup
            Same parser with the added arguments
        """
        parser.add_argument(
            "--num_queries",
            default=150,
            type=int,
            help="Expected encoding length of the spectra",
        )
        parser.add_argument(
            "--num_decoder_layers",
            default=6,
            type=int,
            help="Number of sub-decoder-layers in the decoder",
        )
        parser.add_argument(
            "--num_encoder_layers",
            default=6,
            type=int,
            help="Number of sub-encoder-layers in the encoder",
        )
        parser.add_argument(
            "--nhid",
            default=1024,
            type=int,
            help="Dimension of the feedforward networks",
        )
        parser.add_argument(
            "--ninp",
            default=516,
            type=int,
            help="Number of input features to the transformer encoder",
        )
        parser.add_argument(
            "--nhead", default=12, type=int, help="Number of attention heads"
        )
        parser.add_argument("--dropout", default=0.1, type=float)
        parser.add_argument("--lr", default=1e-4, type=float)
        parser.add_argument(
            "--scheduler",
            default="plateau",
            type=str,
            help=(
                "Scheduler to use during training, "
                f"either of {PepTransformerModel.accepted_schedulers}"
            ),
        )
        parser.add_argument(
            "--lr_ratio",
            default=200.0,
            type=float,
            help=(
                "For cosine annealing: "
                "ratio of the initial learning rate to use with cosine annealing;"
                " for instance a lr of 1 and a ratio of 10 would give a minimum"
                " learning rate of 0.1\n"
                "For onecycle: "
                "ratio of the initial lr to the maximum one;"
                " for instance if lr is 0.1 and the ratio is 10, the max learning"
                " rate would be 1.0"
            ),
        )
        parser.add_argument(
            "--loss_ratio",
            default=5.0,
            type=float,
            help=(
                "Ratio between the retention time and the spectrum loss"
                " (higher values mean more weight to the spectra loss"
                " with respect to the retention time loss)"
            ),
        )
        parser.add_argument(
            "--trainable_sections",
            nargs="+",
            type=str,
            default=PepTransformerModel.model_sections,
            help=(
                f"Sections of the model to train, "
                f"can be any subset of {PepTransformerModel.model_sections}"
            ),
        )

        return parser
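
    # Illustrative CLI wiring (sketch; the argument values are placeholders):
    #
    #     parser = ArgumentParser()
    #     parser = PepTransformerModel.add_model_specific_args(parser)
    #     args = parser.parse_args(["--nhead", "8", "--scheduler", "onecycle"])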

    def make_trainable_sections(self, sections: List[str]) -> None:
        """Freezes the whole model and then unfreezes the requested sections."""

        def set_grad_section(model_section, trainable=True):
            """Freezes or unfreezes a model section."""
            for param in model_section.parameters():
                param.requires_grad = trainable

        logging.warning("Freezing the model")
        set_grad_section(self, trainable=False)

        for section in sections:
            logging.warning(f"Unfreezing {section}")
            set_grad_section(self.model_sections[section], trainable=True)
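
    # Illustrative use for an instance `model` (sketch): train only the two
    # prediction heads and check what remains trainable.
    #
    #     model.make_trainable_sections(["FragmentFFN", "RTFFN"])
    #     trainable = [n for n, p in model.named_parameters() if p.requires_grad]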

    def configure_optimizers(
        self,
    ) -> Union[
        Tuple[List[AdamW], List[Dict[str, Union[ReduceLROnPlateau, str]]]],
        Tuple[List[AdamW], List[Dict[str, Union[CosineAnnealingWarmRestarts, str]]]],
        Tuple[List[AdamW], List[Dict[str, Union[OneCycleLR, str]]]],
    ]:
        """configure_optimizers Configures the optimizers for training.

        It is internally used by pytorch_lightning during training, so far I
        implemented 3 options (set when making the module).

        OneCycleLR seems to give the best results overall in the least amount
        of time. The only tradeoff that I see is that resuming training does
        not seem to be really easy.

        Check the pytorch_lightning documentation to see how this is used in the
        training loop.

        Returns
        -------
        Union[
            Tuple[List[AdamW], List[Dict[str, Union[ReduceLROnPlateau, str]]]],
            Tuple[List[AdamW], List[Dict[str, Union[CosineAnnealingWarmRestarts, str]]]],
            Tuple[List[AdamW], List[Dict[str, Union[OneCycleLR, str]]]],
        ]
            Two lists, one containing the optimizer and another containing the
            scheduler.

        Raises
        ------
        ValueError
            Raised when a scheduler that is not supported is requested.
            If you want to use another one, please over-write this method
            or make a subclass with the modification. (PRs are also welcome)
        """
        opt = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr
        )

        if self.scheduler == "plateau":
            scheduler_dict = {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                    opt, mode="min", factor=0.5, patience=2, verbose=True
                ),
                "interval": "epoch",
                "monitor": "val_l",
            }
        elif self.scheduler == "cosine":
            assert self.lr_ratio > 1
            scheduler_dict = {
                "scheduler": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                    opt,
                    T_0=1,
                    T_mult=2,
                    eta_min=self.lr / self.lr_ratio,
                    last_epoch=-1,
                    verbose=False,
                ),
                "interval": "step",
            }
        elif self.scheduler == "onecycle":
            assert self.steps_per_epoch is not None, "Please set steps_per_epoch"
            if self.trainer.max_epochs == 1000:
                warnings.warn("Max epochs was 1000, make sure you want this")
            if self.lr_ratio > 20:
                warnings.warn(
                    f"Provided LR ratio '{self.lr_ratio}' seems a little high,"
                    " make sure you want that for the OneCycleLR scheduler"
                )
                time.sleep(3)  # just so the user has time to see the message...
            max_lr = self.lr * self.lr_ratio
            logging.info(
                f">> Scheduler setup: max_lr {max_lr}, "
                f"Max Epochs: {self.trainer.max_epochs}, "
                f"Steps per epoch: {self.steps_per_epoch}, "
                f"Accumulate Batches {self.trainer.accumulate_grad_batches}"
            )
            spe = self.steps_per_epoch // self.trainer.accumulate_grad_batches
            scheduler_dict = {
                "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                    opt,
                    max_lr,
                    epochs=self.trainer.max_epochs,
                    steps_per_epoch=spe,
                ),
                "interval": "step",
            }

        else:
            raise ValueError(
                f"Scheduler should be one of {self.accepted_schedulers}, passed: ",
                self.scheduler,
            )
        # TODO check if using different optimizers for different parts of the
        # model would work better
        logging.info(f"\n\n>>> Setting up schedulers:\n\n{scheduler_dict}")

        return [opt], [scheduler_dict]
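
    # Note on lr_ratio in the onecycle branch above: max_lr = lr * lr_ratio, so with
    # the defaults lr=1e-4 and lr_ratio=200 the schedule would peak at 2e-2 (hence
    # the warning for ratios above 20). For cosine annealing, eta_min = lr / lr_ratio.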

    def _step(self, batch: TrainBatch, batch_idx: int) -> Dict[str, Tensor]:
        """Run main functionality during training and testing steps.

        Internally used in training and evaluation steps during the training
        loop in pytorch_lightning.

        Does inference, loss calculation, handling of missing values ...
        """

        if isinstance(batch, list):
            batch = TrainBatch(*batch)

        yhat_irt, yhat_spectra = self.batch_forward(batch)
        yhat_irt = yhat_irt[~batch.norm_irt.isnan()]
        norm_irt = batch.norm_irt[~batch.norm_irt.isnan()]

        loss_irt = self.mse_loss(yhat_irt, norm_irt.float())
        loss_spectra = self.angle_loss(yhat_spectra, batch.encoded_spectra).mean()

        if len(norm_irt.data) == 0:
            total_loss = loss_spectra
        else:
            total_loss = loss_irt + loss_spectra * self.loss_ratio
            total_loss = total_loss / (self.loss_ratio + 1)

        out = {
            "l": total_loss,
            "irt_l": loss_irt,
            "spec_l": loss_spectra,
        }

        assert not torch.isnan(total_loss), logging.error(
            f"Fail at... \n Loss: {total_loss},\n"
            f"\n loss_irt: {loss_irt}\n"
            f"\n loss_spectra: {loss_spectra}\n"
            f"\n yhat_spec: {yhat_spectra},\n"
            f"\n y_spec: {batch.encoded_spectra}\n"
            f"\n y_irt: {norm_irt}, {len(norm_irt.data)}"
        )

        return out
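
    # Loss weighting above, spelled out: when retention times are available,
    #
    #     total_loss = (loss_irt + loss_ratio * loss_spectra) / (loss_ratio + 1)
    #
    # i.e. a weighted average in which the default loss_ratio of 5 gives the
    # spectral loss 5/6 of the weight and the iRT loss 1/6.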

    def training_step(
        self, batch: TrainBatch, batch_idx: Optional[int] = None
    ) -> Dict[str, Tensor]:
        """See pytorch_lightning documentation."""
        step_out = self._step(batch, batch_idx=batch_idx)
        log_dict = {"train_" + k: v for k, v in step_out.items()}
        log_dict.update({"LR": self.trainer.optimizers[0].param_groups[0]["lr"]})

        self.log_dict(
            log_dict,
            prog_bar=True,
            # reduce_fx=nanmean,
        )

        return {"loss": step_out["l"]}

    def validation_step(
        self, batch: TrainBatch, batch_idx: Optional[int] = None
    ) -> None:
        """See pytorch_lightning documentation."""
        step_out = self._step(batch, batch_idx=batch_idx)
        log_dict = {"val_" + k: v for k, v in step_out.items()}

        self.log_dict(log_dict, prog_bar=True, reduce_fx=nanmean)

    __doc__ += "\n\n" + __init__.__doc__ + "\n\n" + forward.__doc__