Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from pathlib import Path 

2import logging 

3import logging.config 

4 

5import argparse 

6from argparse import ( 

7 ArgumentDefaultsHelpFormatter, 

8 ArgumentParser, 

9) 

10 

11import elfragmentador 

12 

13try: 

14 from argparse import BooleanOptionalAction 

15except ImportError: 

16 # Exception for py <3.8 compatibility ... 

17 BooleanOptionalAction = "store_true" 

18 

19 

20import warnings 

21 

22import pytorch_lightning as pl 

23import pandas as pd 

24 

25from elfragmentador.train import build_train_parser, main_train 

26from elfragmentador.model import PepTransformerModel 

27from elfragmentador.spectra import sptxt_to_csv 

28from elfragmentador.utils import append_preds, predict_df 

29from elfragmentador import datamodules, evaluate, rt 

30 

31import uniplot 

32 

# Shared keyword arguments passed to logging.basicConfig() by every CLI
# entry point in this module, so all commands log with the same format/level.
DEFAULT_LOGGER_BASIC_CONF = {
    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    "level": logging.DEBUG,
}

37 

38 

def calculate_irt():
    """CLI entry point that computes iRT values from skyline csv outputs.

    Parses the input file paths and the output destination from the command
    line, delegates the actual calculation to ``rt.calculate_multifile_iRT``,
    and writes the resulting dataframe to the requested csv file.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    parser = ArgumentParser()
    parser.add_argument(
        "file",
        type=argparse.FileType("r"),
        nargs="+",
        help="Input file(s) to convert (skyline csv output)",
    )
    parser.add_argument(
        "--out",
        default="out.csv",
        type=str,
        help="Name of the file where the output will be written (csv)",
    )

    args = parser.parse_args()

    # FileType is used only for existence/readability validation; the
    # downstream helper wants plain path names.
    input_names = [handle.name for handle in args.file]
    irt_df = rt.calculate_multifile_iRT(input_names)
    irt_df.to_csv(str(args.out))

59 

60 

def append_predictions():
    """CLI entry point that appends prediction-based features to a percolator input.

    Appends the cosine similarity between the predicted and actual spectra
    to a percolator input file (``--pin``) and writes the augmented file to
    ``--out``, using the model loaded from ``--model_checkpoint``.

    Returns:
        Whatever ``append_preds`` returns for the given input/output pair.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    parser = ArgumentParser()
    parser.add_argument(
        "--pin",
        type=str,
        help="Input percolator file",
    )
    parser.add_argument(
        "--out",
        type=str,
        # Fixed help text: this is the output path, not a second input
        # (the original string was copy-pasted from --pin).
        help="Output percolator file",
    )
    parser.add_argument(
        "--model_checkpoint",
        type=str,
        default=elfragmentador.DEFAULT_CHECKPOINT,
        help="Model checkpoint to use for the prediction, if nothing is passed will download a pretrained model",
    )

    args = parser.parse_args()

    model = PepTransformerModel.load_from_checkpoint(args.model_checkpoint)
    # Inference only: disable dropout/batch-norm training behavior.
    model.eval()

    return append_preds(in_pin=args.pin, out_pin=args.out, model=model)

91 

92 

def predict_csv():
    """CLI entry point that predicts spectra for the peptides in a csv file.

    Reads peptides from ``--csv``, predicts them with the model loaded from
    ``--model_checkpoint``, and writes the predictions to the ``--out``
    .sptxt file. ``--impute_collision_energy`` supplies an NCE for rows
    that lack one; 0 (the default) disables imputation.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    parser = ArgumentParser()
    parser.add_argument(
        "--csv",
        type=str,
        help="Input csv file",
    )
    parser.add_argument(
        "--impute_collision_energy",
        type=float,
        default=0,
        help="Collision energy to use if none is specified in the file",
    )
    parser.add_argument(
        "--out",
        type=str,
        help="Output .sptxt file",
    )
    parser.add_argument(
        "--model_checkpoint",
        type=str,
        default=elfragmentador.DEFAULT_CHECKPOINT,
        help="Model checkpoint to use for the prediction, if nothing is passed will download a pretrained model",
    )

    args = parser.parse_args()

    # A value of 0 means "do not impute"; predict_df expects False in that case.
    impute = args.impute_collision_energy
    nce = impute if impute != 0 else False

    model = PepTransformerModel.load_from_checkpoint(args.model_checkpoint)
    # Inference only: disable dropout/batch-norm training behavior.
    model.eval()

    predictions = predict_df(
        pd.read_csv(args.csv), impute_collision_energy=nce, model=model
    )
    with open(args.out, "w") as out_handle:
        out_handle.write(predictions)

135 

136 

def convert_sptxt():
    """CLI entry point to convert sptxt file(s) to csv files for training.

    Thin CLI wrapper around ``sptxt_to_csv``; see that function for the
    actual conversion logic. Each input ``foo.sptxt`` produces ``foo.sptxt.csv``
    next to it, skipping any input whose output file already exists.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    parser = ArgumentParser()
    parser.add_argument(
        "file",
        type=argparse.FileType("r"),
        nargs="+",
        help="Input file(s) to convert (sptxt)",
    )
    parser.add_argument(
        "--warn",
        default=False,
        action=BooleanOptionalAction,
        help="Whether to show warnings or not",
    )
    parser.add_argument(
        "--keep_irts",
        default=False,
        action=BooleanOptionalAction,
        help="Whether to remove sequences that match procal and biognosys iRT peptides",
    )
    parser.add_argument(
        "--min_peaks",
        default=3,
        type=int,
        help="Minimum number of annotated peaks required to keep the spectrum",
    )
    parser.add_argument(
        "--min_delta_ascore",
        default=20,
        type=int,
        help="Minimum ascore required to keep a spectrum",
    )

    args = parser.parse_args()

    logging.info([x.name for x in args.file])

    # Bind the CLI filtering options once; each input file reuses them.
    # NOTE(review): --keep_irts is forwarded as filter_irt_peptides unchanged
    # from the original code, even though the flag name ("keep") and the help
    # text ("remove") read as opposites — confirm intent against sptxt_to_csv.
    def _convert(fname, outname):
        """Convert a single sptxt file to csv with the parsed CLI options."""
        return sptxt_to_csv(
            fname,
            outname,
            filter_irt_peptides=args.keep_irts,
            min_delta_ascore=args.min_delta_ascore,
            min_peaks=args.min_peaks,
        )

    for f in args.file:
        out_file = f.name + ".csv"
        if Path(out_file).exists():
            # Idempotent re-runs: never clobber an existing conversion.
            logging.info(
                f"Skipping conversion of '{f.name}' "
                f"to '{out_file}', "
                f"because {out_file} exists."
            )
        else:
            logging.info(f"Converting '{f.name}' to '{out_file}'")
            if args.warn:
                logging.warning("Warning stuff ...")
                _convert(f.name, out_file)
            else:
                # Capture (instead of display) warnings, then surface only a
                # summary of the last one so the log stays readable.
                with warnings.catch_warnings(record=True) as c:
                    logging.warning("Not Warning stuff ...")
                    _convert(f.name, out_file)

                    if len(c) > 0:
                        logging.warning(
                            f"Last Error Message of {len(c)}: '{c[-1].message}'"
                        )

213 

def evaluate_checkpoint():
    """CLI entry point that evaluates a model checkpoint on a dataset.

    Loads the checkpoint, builds a dataset from either ``--csv`` or
    ``--sptxt``, and evaluates it at one or more collision energies
    (``--screen_nce`` takes precedence over ``--overwrite_nce``; with
    neither, the file's own NCEs are used). Tracks the NCE with the best
    average spectral cosine similarity and optionally writes the best
    per-spectrum results to ``--out_csv``.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    pl.seed_everything(2020)
    args = evaluate.build_evaluate_parser().parse_args()
    dict_args = vars(args)
    logging.info(dict_args)

    model = PepTransformerModel.load_from_checkpoint(args.checkpoint_path)

    # Exactly one of --csv / --sptxt must be provided; --csv wins if both are.
    if dict_args["csv"] is not None:
        ds = datamodules.PeptideDataset.from_csv(args.csv, max_spec=args.max_spec)
    elif dict_args["sptxt"] is not None:
        ds = datamodules.PeptideDataset.from_sptxt(args.sptxt, max_spec=args.max_spec)
    else:
        raise ValueError("Must have an argument to either --csv or --sptxt")

    # Decide which NCE values to evaluate; [False] means "do not overwrite".
    if dict_args["screen_nce"] is not None:
        nces = [float(x) for x in dict_args["screen_nce"].split(",")]
    elif dict_args["overwrite_nce"] is not None:
        nces = [dict_args["overwrite_nce"]]
    else:
        nces = [False]

    # Sentinel "best so far": similarity 0 so any real result replaces it.
    # NOTE(review): if no evaluation beats 0, best_res[0] stays a plain dict
    # and the .to_csv call below would fail — preserved from the original.
    best_res = ({}, {"AverageSpectraCosineSimilarity": 0})
    best_nce = None
    res_history = []

    for nce in nces:
        if nce:
            logging.info(f">>>> Starting evaluation of NCE={nce}")
        res = evaluate.evaluate_on_dataset(
            model=model,
            dataset=ds,
            batch_size=args.batch_size,
            device=args.device,
            overwrite_nce=nce,
        )
        similarity = res[1]["AverageSpectraCosineSimilarity"]
        res_history.append(similarity)
        if similarity > best_res[1]["AverageSpectraCosineSimilarity"]:
            best_res, best_nce = res, nce

    if len(nces) > 1:
        logging.info(f"Best Nce was {best_nce}")
        uniplot.plot(ys=res_history, xs=nces)

    if dict_args["out_csv"] is not None:
        best_res[0].to_csv(dict_args["out_csv"], index=False)

270 

271 

def train():
    """CLI entry point that builds a model from CLI args and trains it.

    Optionally warm-starts the freshly constructed model with the weights
    of an existing checkpoint (``--from_checkpoint``) before handing off
    to ``main_train``.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    pl.seed_everything(2020)
    args = build_train_parser().parse_args()
    dict_args = vars(args)

    logging.info("====== Passed command line args/params =====")
    for key, value in dict_args.items():
        logging.info(f">> {key}: {value}")

    mod = PepTransformerModel(**dict_args)
    if args.from_checkpoint is not None:
        logging.info(f">> Resuming training from checkpoint {args.from_checkpoint} <<")
        # Copy only the weights; hyperparameters come from the CLI args above.
        weights_mod = PepTransformerModel.load_from_checkpoint(args.from_checkpoint)
        mod.load_state_dict(weights_mod.state_dict())
        # Free the donor model before training starts.
        del weights_mod

    main_train(mod, args)