Coverage for train.py: 33%

import math
import logging
from argparse import Namespace, ArgumentParser, ArgumentDefaultsHelpFormatter
from typing import Dict, List, Union

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

import elfragmentador as tp
from elfragmentador import datamodules, model
from elfragmentador.model import PepTransformerModel

def build_train_parser() -> ArgumentParser:
    """Build the argument parser used to configure a training run."""
    parser = ArgumentParser(add_help=False)

    program_parser = parser.add_argument_group(
        "Program Parameters",
        "Program-level parameters; these should not change the outcome of the run",
    )
    model_parser = parser.add_argument_group(
        "Model Parameters",
        "Parameters that modify the model or its training (learning rate, scheduler, layers, dimensions, ...)",
    )
    data_parser = parser.add_argument_group(
        "Data Parameters", "Parameters for the loading of data"
    )
    trainer_parser = parser.add_argument_group(
        "Trainer Parameters", "Parameters that modify the model or its training"
    )

    # add PROGRAM level args
    program_parser.add_argument(
        "--run_name",
        type=str,
        default="ElFragmentador",
        help="Name to be given to the run (logging)",
    )
    program_parser.add_argument(
        "--wandb_project",
        type=str,
        default="rttransformer",
        help="Wandb project to log to",
    )
    trainer_parser.add_argument(
        "--terminator_patience",
        type=int,
        default=5,
        help="Patience for early termination",
    )
    trainer_parser.add_argument(
        "--from_checkpoint",
        type=str,
        default=None,
        help="The path of a checkpoint to copy weights from before training",
    )

    # add model-specific args
    model_parser = model.PepTransformerModel.add_model_specific_args(model_parser)

    # add data-specific args
    data_parser = datamodules.PeptideDataModule.add_model_specific_args(data_parser)

    # add all the available trainer options to argparse,
    # i.e. now --gpus, --num_nodes, ..., --fast_dev_run all work in the CLI
    t_parser = ArgumentParser(add_help=False)
    t_parser = pl.Trainer.add_argparse_args(t_parser)

    if torch.cuda.is_available():
        t_parser.set_defaults(gpus=-1)
        t_parser.set_defaults(precision=16)

    parser = ArgumentParser(
        parents=[t_parser, parser], formatter_class=ArgumentDefaultsHelpFormatter
    )

    return parser

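# A quick illustration of the parser in use. The specific flag values below are
# illustrative assumptions only; most flags (e.g. --max_epochs) are registered
# by the model, the data module, and pl.Trainer.add_argparse_args rather than
# spelled out in this file.
#
#     args = build_train_parser().parse_args(
#         ["--run_name", "my_run", "--terminator_patience", "10", "--max_epochs", "5"]
#     )
#     print(args.run_name, args.terminator_patience, args.max_epochs)
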
def get_callbacks(
    run_name: str, termination_patience: int = 20, wandb_project: str = "rttransformer"
) -> Dict[
    str,
    Union[
        WandbLogger, List[Union[LearningRateMonitor, ModelCheckpoint, EarlyStopping]]
    ],
]:
    """Build the Wandb logger and the callbacks (LR monitor, checkpointing, early stopping) for a run."""
    complete_run_name = f"{tp.__version__}_{run_name}"
    wandb_logger = WandbLogger(complete_run_name, project=wandb_project)
    lr_monitor = LearningRateMonitor()
    checkpointer = ModelCheckpoint(
        monitor="val_l",
        verbose=True,
        save_top_k=2,
        save_weights_only=True,
        dirpath=".",
        save_last=True,
        mode="min",
        filename=complete_run_name + "_{v_l:.6f}_{epoch:03d}",
    )

    terminator = EarlyStopping(
        monitor="val_l",
        min_delta=0.00,
        patience=termination_patience,
        verbose=False,
        mode="min",
    )

    return {"logger": wandb_logger, "callbacks": [lr_monitor, checkpointer, terminator]}

def main_train(model: PepTransformerModel, args: Namespace) -> None:
    """Set up the data module, callbacks, and trainer, then fit the model."""
    # TODO: add logging levels and a more structured logger ...
    logging.info(model)
    datamodule = datamodules.PeptideDataModule(
        batch_size=args.batch_size,
        base_dir=args.data_dir,
        drop_missing_vals=args.drop_missing_vals,
    )
    datamodule.setup()
    spe = math.ceil(len(datamodule.train_dataset) / datamodule.batch_size)
    logging.info(f">>> TRAIN: Setting steps per epoch to {spe}")
    model.steps_per_epoch = spe

    callbacks = get_callbacks(
        run_name=args.run_name,
        termination_patience=args.terminator_patience,
        wandb_project=args.wandb_project,
    )

    callbacks["logger"].watch(model.encoder)
    callbacks["logger"].watch(model.decoder)

    trainer = pl.Trainer.from_argparse_args(
        args,
        profiler="simple",
        logger=callbacks["logger"],
        callbacks=callbacks["callbacks"],
    )

    trainer.fit(model, datamodule)
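

# A minimal sketch of a CLI entry point wiring the pieces above together. How
# the model is built from the parsed arguments is an assumption here:
# load_from_checkpoint is the standard LightningModule classmethod, but
# constructing a fresh PepTransformerModel with default arguments is
# hypothetical, and the real package may map model-specific flags explicitly.
if __name__ == "__main__":
    train_parser = build_train_parser()
    cli_args = train_parser.parse_args()

    if cli_args.from_checkpoint is not None:
        # restore hyperparameters and weights from a previous run
        mod = PepTransformerModel.load_from_checkpoint(cli_args.from_checkpoint)
    else:
        # hypothetical: assumes the constructor's defaults are acceptable
        mod = PepTransformerModel()

    main_train(mod, cli_args)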