Coverage for cli.py : 19%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from pathlib import Path
2import logging
3import logging.config
5import argparse
6from argparse import (
7 ArgumentDefaultsHelpFormatter,
8 ArgumentParser,
9)
11import elfragmentador
13try:
14 from argparse import BooleanOptionalAction
15except ImportError:
16 # Exception for py <3.8 compatibility ...
17 BooleanOptionalAction = "store_true"
20import warnings
22import pytorch_lightning as pl
23import pandas as pd
25from elfragmentador.train import build_train_parser, main_train
26from elfragmentador.model import PepTransformerModel
27from elfragmentador.spectra import sptxt_to_csv
28from elfragmentador.utils import append_preds, predict_df
29from elfragmentador import datamodules, evaluate, rt
31import uniplot
# Keyword arguments passed to logging.basicConfig by every CLI entry point
# in this module: timestamped "name - level - message" lines at DEBUG level.
DEFAULT_LOGGER_BASIC_CONF = {
    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    "level": logging.DEBUG,
}
def calculate_irt():
    """CLI entry point that calculates iRT values from Skyline csv exports.

    Parses ``sys.argv`` for one or more input files and an ``--out`` path,
    delegates the actual calculation to ``rt.calculate_multifile_iRT`` and
    writes the resulting dataframe to the output csv.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    parser = ArgumentParser()
    parser.add_argument(
        "file",
        type=argparse.FileType("r"),
        nargs="+",
        help="Input file(s) to convert (skyline csv output)",
    )
    parser.add_argument(
        "--out",
        default="out.csv",
        type=str,
        help="Name of the file where the output will be written (csv)",
    )

    args = parser.parse_args()

    # argparse.FileType("r") has already opened every input file, but only
    # the paths are needed downstream — close the handles so they don't leak.
    files = []
    for handle in args.file:
        files.append(handle.name)
        handle.close()

    df = rt.calculate_multifile_iRT(files)
    df.to_csv(str(args.out))
def append_predictions():
    """
    Appends the cosine similarity between the predicted and actual spectra
    to a percolator input.

    Parses ``sys.argv`` for the input/output pin paths and an optional model
    checkpoint, loads the model in eval mode, and delegates to
    ``append_preds``.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    parser = ArgumentParser()
    parser.add_argument(
        "--pin",
        type=str,
        help="Input percolator file",
    )
    parser.add_argument(
        "--out",
        type=str,
        # Fixed copy-paste bug: this help text previously read
        # "Input percolator file" even though --out is the output path.
        help="Output percolator file",
    )
    parser.add_argument(
        "--model_checkpoint",
        type=str,
        default=elfragmentador.DEFAULT_CHECKPOINT,
        help="Model checkpoint to use for the prediction, if nothing is passed will download a pretrained model",
    )

    args = parser.parse_args()

    model = PepTransformerModel.load_from_checkpoint(args.model_checkpoint)
    # Inference only — disable dropout/batchnorm training behavior.
    model.eval()

    return append_preds(in_pin=args.pin, out_pin=args.out, model=model)
def predict_csv():
    """CLI entry point that predicts spectra for the peptides in a csv file.

    Reads command line arguments for the input csv, an optional collision
    energy to impute, the output .sptxt path and a model checkpoint, then
    writes the predictions produced by ``predict_df`` to the output file.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)

    parser = ArgumentParser()
    parser.add_argument(
        "--csv",
        type=str,
        help="Input csv file",
    )
    parser.add_argument(
        "--impute_collision_energy",
        type=float,
        default=0,
        help="Collision energy to use if none is specified in the file",
    )
    parser.add_argument(
        "--out",
        type=str,
        help="Output .sptxt file",
    )
    parser.add_argument(
        "--model_checkpoint",
        type=str,
        default=elfragmentador.DEFAULT_CHECKPOINT,
        help="Model checkpoint to use for the prediction, if nothing is passed will download a pretrained model",
    )
    args = parser.parse_args()

    # A value of 0 means "do not impute"; predict_df expects False in that case.
    ce = args.impute_collision_energy
    nce = ce if ce != 0 else False

    model = PepTransformerModel.load_from_checkpoint(args.model_checkpoint)
    model.eval()

    predictions = predict_df(
        pd.read_csv(args.csv), impute_collision_energy=nce, model=model
    )
    with open(args.out, "w") as out_handle:
        out_handle.write(predictions)
def convert_sptxt():
    """
    convert_sptxt Provides a CLI to convert an sptxt to a csv for training

    provides a CLI for the sptxt_to_csv function, check that guy out for the
    actual implementation
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    parser = ArgumentParser()
    parser.add_argument(
        "file",
        type=argparse.FileType("r"),
        nargs="+",
        help="Input file(s) to convert (sptxt)",
    )
    parser.add_argument(
        "--warn",
        default=False,
        action=BooleanOptionalAction,
        help="Whether to show warnings or not",
    )
    parser.add_argument(
        "--keep_irts",
        default=False,
        action=BooleanOptionalAction,
        help="Whether to remove sequences that match procal and biognosys iRT peptides",
    )
    parser.add_argument(
        "--min_peaks",
        default=3,
        type=int,
        help="Minimum number of annotated peaks required to keep the spectrum",
    )
    parser.add_argument(
        "--min_delta_ascore",
        default=20,
        type=int,
        help="Minimum ascore required to keep a spectrum",
    )

    args = parser.parse_args()

    logging.info([x.name for x in args.file])

    # Partial application of sptxt_to_csv with the CLI-provided filter options.
    # (A def is preferred over a named lambda — PEP 8 / flake8 E731.)
    def converter(fname, outname):
        return sptxt_to_csv(
            fname,
            outname,
            filter_irt_peptides=args.keep_irts,
            min_delta_ascore=args.min_delta_ascore,
            min_peaks=args.min_peaks,
        )

    for f in args.file:
        out_file = f.name + ".csv"
        # Never overwrite an existing conversion; skip it instead.
        if Path(out_file).exists():
            logging.info(
                f"Skipping conversion of '{f.name}' "
                f"to '{out_file}', "
                f"because {out_file} exists."
            )
        else:
            logging.info(f"Converting '{f.name}' to '{out_file}'")
            if args.warn:
                logging.warning("Warning stuff ...")
                converter(f.name, out_file)
            else:
                # Suppress warnings during conversion but record them so the
                # last one can be summarized in the log afterwards.
                with warnings.catch_warnings(record=True) as c:
                    logging.warning("Not Warning stuff ...")
                    converter(f.name, out_file)

                if len(c) > 0:
                    logging.warning(
                        f"Last Error Message of {len(c)}: '{c[-1].message}'"
                    )
def evaluate_checkpoint():
    """CLI entry point that evaluates a model checkpoint on a dataset.

    Loads a dataset from either ``--csv`` or ``--sptxt`` (one is required),
    optionally screens a comma-separated list of collision energies
    (``--screen_nce``) or overrides a single one (``--overwrite_nce``),
    evaluates the checkpoint on each NCE, and reports/plots the best result.
    Writes the best result's dataframe to ``--out_csv`` when given.

    Raises:
        ValueError: if neither --csv nor --sptxt is provided.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    pl.seed_everything(2020)
    parser = evaluate.build_evaluate_parser()
    args = parser.parse_args()
    dict_args = vars(args)
    logging.info(dict_args)

    model = PepTransformerModel.load_from_checkpoint(args.checkpoint_path)
    if dict_args["csv"] is not None:
        ds = datamodules.PeptideDataset.from_csv(
            args.csv,
            max_spec=args.max_spec,
        )
    elif dict_args["sptxt"] is not None:
        ds = datamodules.PeptideDataset.from_sptxt(
            args.sptxt,
            max_spec=args.max_spec,
        )
    else:
        raise ValueError("Must have an argument to either --csv or --sptxt")

    # Decide which collision energies to evaluate; [False] means "do not
    # overwrite the NCE stored in the dataset".
    if dict_args["screen_nce"] is not None:
        nces = [float(x) for x in dict_args["screen_nce"].split(",")]
    elif dict_args["overwrite_nce"] is not None:
        nces = [dict_args["overwrite_nce"]]
    else:
        nces = [False]

    # Seed "best so far" with a similarity of 0 so any real result beats it.
    best_res = ({}, {"AverageSpectraCosineSimilarity": 0})
    best_nce = None
    res_history = []
    for nce in nces:
        if nce:
            logging.info(f">>>> Starting evaluation of NCE={nce}")
        res = evaluate.evaluate_on_dataset(
            model=model,
            dataset=ds,
            batch_size=args.batch_size,
            device=args.device,
            overwrite_nce=nce,
        )
        res_history.append(res[1]["AverageSpectraCosineSimilarity"])
        if (
            res[1]["AverageSpectraCosineSimilarity"]
            > best_res[1]["AverageSpectraCosineSimilarity"]
        ):
            best_res = res
            best_nce = nce

    if len(nces) > 1:
        logging.info(f"Best Nce was {best_nce}")
        # Terminal plot of similarity vs NCE for the screened values.
        uniplot.plot(ys=res_history, xs=nces)

    if dict_args["out_csv"] is not None:
        best_res[0].to_csv(dict_args["out_csv"], index=False)
def train():
    """CLI entry point that trains a PepTransformerModel.

    Builds the model from the parsed command line arguments, optionally
    warm-starts its weights from ``--from_checkpoint``, and hands off to
    ``main_train``.
    """
    logging.basicConfig(**DEFAULT_LOGGER_BASIC_CONF)
    pl.seed_everything(2020)

    parser = build_train_parser()
    args = parser.parse_args()
    dict_args = vars(args)

    logging.info("====== Passed command line args/params =====")
    for key, value in dict_args.items():
        logging.info(f">> {key}: {value}")

    mod = PepTransformerModel(**dict_args)

    checkpoint = args.from_checkpoint
    if checkpoint is not None:
        logging.info(f">> Resuming training from checkpoint {checkpoint} <<")
        # Copy only the weights; hyperparameters come from the CLI args above.
        weights_mod = PepTransformerModel.load_from_checkpoint(checkpoint)
        mod.load_state_dict(weights_mod.state_dict())
        del weights_mod

    main_train(mod, args)