
import math
import logging
from argparse import Namespace, ArgumentParser, ArgumentDefaultsHelpFormatter
from typing import Dict, List, Union

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

import elfragmentador as tp
from elfragmentador import datamodules, model
from elfragmentador.model import PepTransformerModel


def build_train_parser() -> ArgumentParser:
    """Build the argument parser used to configure a training run.

    Combines program-, model-, data- and trainer-level argument groups into a
    single parser so every option can be set from the command line.
    """
    parser = ArgumentParser(add_help=False)

    program_parser = parser.add_argument_group(
        "Program Parameters",
        "Program-level parameters; these should not change the outcome of the run",
    )
    model_parser = parser.add_argument_group(
        "Model Parameters",
        "Parameters that modify the model or its training (learning rate, scheduler, layers, dimensions ...)",
    )
    data_parser = parser.add_argument_group(
        "Data Parameters", "Parameters for loading the data"
    )
    trainer_parser = parser.add_argument_group(
        "Trainer Parameters", "Parameters that modify the training loop"
    )

    # add PROGRAM level args
    program_parser.add_argument(
        "--run_name",
        type=str,
        default="ElFragmentador",
        help="Name to be given to the run (logging)",
    )
    program_parser.add_argument(
        "--wandb_project",
        type=str,
        default="rttransformer",
        help="Wandb project to log the run to",
    )
    trainer_parser.add_argument(
        "--terminator_patience",
        type=int,
        default=5,
        help="Patience for early termination",
    )
    trainer_parser.add_argument(
        "--from_checkpoint",
        type=str,
        default=None,
        help="The path of a checkpoint to copy weights from before training",
    )

    # add model specific args
    model_parser = model.PepTransformerModel.add_model_specific_args(model_parser)

    # add data specific args
    data_parser = datamodules.PeptideDataModule.add_model_specific_args(data_parser)

    # add all the available trainer options to argparse,
    # i.e. --gpus, --num_nodes ... --fast_dev_run now all work in the CLI
    t_parser = ArgumentParser(add_help=False)
    t_parser = pl.Trainer.add_argparse_args(t_parser)

    # default to all available GPUs and mixed precision when CUDA is present
    if torch.cuda.is_available():
        t_parser.set_defaults(gpus=-1)
        t_parser.set_defaults(precision=16)

    parser = ArgumentParser(
        parents=[t_parser, parser], formatter_class=ArgumentDefaultsHelpFormatter
    )

    return parser
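
# A minimal, hypothetical usage sketch (not part of the original module): the parser
# returned by build_train_parser() exposes the program flags defined above together
# with the model, data and pl.Trainer flags added by the respective argument groups,
# so a run can be configured entirely from the command line, e.g.:
#
#   parser = build_train_parser()
#   args = parser.parse_args(["--run_name", "test_run", "--max_epochs", "5"])
#   print(args.run_name, args.max_epochs)
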

def get_callbacks(
    run_name: str, termination_patience: int = 20, wandb_project: str = "rttransformer"
) -> Dict[
    str,
    Union[
        WandbLogger, List[Union[LearningRateMonitor, ModelCheckpoint, EarlyStopping]]
    ],
]:
    """Set up the logger and callbacks used during training.

    Returns a dictionary with a "logger" entry (a WandbLogger named after the
    package version and run name) and a "callbacks" entry (learning-rate
    monitoring, plus checkpointing and early stopping on the "val_l" metric).
    """
    complete_run_name = f"{tp.__version__}_{run_name}"
    wandb_logger = WandbLogger(complete_run_name, project=wandb_project)
    lr_monitor = pl.callbacks.lr_monitor.LearningRateMonitor()
    checkpointer = pl.callbacks.ModelCheckpoint(
        monitor="val_l",
        verbose=True,
        save_top_k=2,
        save_weights_only=True,
        dirpath=".",
        save_last=True,
        mode="min",
        filename=complete_run_name + "_{v_l:.6f}_{epoch:03d}",
    )

    terminator = pl.callbacks.early_stopping.EarlyStopping(
        monitor="val_l",
        min_delta=0.00,
        patience=termination_patience,
        verbose=False,
        mode="min",
    )

    return {"logger": wandb_logger, "callbacks": [lr_monitor, checkpointer, terminator]}

def main_train(model: PepTransformerModel, args: Namespace) -> None:
    """Train the given model using the data, logger and trainer options in args."""
    # TODO add logging levels and a more structured logger ...
    logging.info(model)
    datamodule = datamodules.PeptideDataModule(
        batch_size=args.batch_size,
        base_dir=args.data_dir,
        drop_missing_vals=args.drop_missing_vals,
    )
    datamodule.setup()

    # tell the model how many batches make up one training epoch
    spe = math.ceil(len(datamodule.train_dataset) / datamodule.batch_size)
    logging.info(f">>> TRAIN: Setting steps per epoch to {spe}")
    model.steps_per_epoch = spe

    callbacks = get_callbacks(
        run_name=args.run_name,
        termination_patience=args.terminator_patience,
        wandb_project=args.wandb_project,
    )

    # have the wandb logger watch the encoder and decoder submodules
    callbacks["logger"].watch(model.encoder)
    callbacks["logger"].watch(model.decoder)

    trainer = pl.Trainer.from_argparse_args(
        args,
        profiler="simple",
        logger=callbacks["logger"],
        callbacks=callbacks["callbacks"],
    )

    trainer.fit(model, datamodule)
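

# A minimal sketch of a command-line entry point (not part of the original module).
# The bare PepTransformerModel() call and the checkpoint handling are assumptions;
# the real entry point may construct the model from the parsed arguments and copy
# checkpoint weights differently.
if __name__ == "__main__":
    cli_parser = build_train_parser()
    cli_args = cli_parser.parse_args()

    if cli_args.from_checkpoint is not None:
        # load_from_checkpoint is the standard LightningModule classmethod
        mod = PepTransformerModel.load_from_checkpoint(cli_args.from_checkpoint)
    else:
        mod = PepTransformerModel()  # hypothetical default construction

    main_train(mod, cli_args)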