Coverage for src/trapi_predict_kit/trapi_parser.py: 85%

157 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-25 21:14 +0200

1import re 

2 

3import requests 

4 

5from trapi_predict_kit.config import settings 

6from trapi_predict_kit.utils import get_entities_labels, log 

7 

8# TODO: add evidence path to TRAPI 

9 

10 

def is_accepted_id(id_to_check):
    """Tell whether an identifier already uses a directly supported prefix.

    The comparison is case-insensitive and only inspects the start of the
    string: identifiers beginning with ``omim`` or ``drugbank`` are accepted.

    :param id_to_check: Identifier (CURIE) string to inspect
    :return: True when the identifier starts with an accepted prefix
    """
    lowered = id_to_check.lower()
    return lowered.startswith(("omim", "drugbank"))

13 

14 

def get_biolink_parents(concept):
    """Return the BioLink ancestors of a concept, followed by the concept itself.

    The ``biolink:`` prefix is stripped and the remaining CamelCase name is
    converted to snake_case to build the lookup URL. If the lookup fails for
    any reason (network error, HTTP error status, unexpected payload), a
    warning is logged and only the original concept is returned.

    :param concept: BioLink concept, e.g. ``biolink:ChemicalEntity`` or ``biolink:treats``
    :return: List of ancestor concepts ending with ``concept``, or ``[concept]`` on failure
    """
    concept_snakecase = concept.replace("biolink:", "")
    # Insert "_" before every capital letter except a leading one, then lowercase
    concept_snakecase = re.sub(r"(?<!^)(?=[A-Z])", "_", concept_snakecase).lower()
    query_url = f"https://bl-lookup-sri.renci.org/bl/{concept_snakecase}/ancestors"
    try:
        # TODO: can't specify a BioLink version because asking for v3.1.0 does not exist, so we use latest
        # (would be: params={'version': f'v{settings.BIOLINK_VERSION}'})
        response = requests.get(query_url, timeout=settings.TIMEOUT)
        # Fail explicitly on HTTP errors instead of trying to use an error payload
        response.raise_for_status()
        ancestors = response.json()
        if not isinstance(ancestors, list):
            raise ValueError(f"expected a list of ancestors, got {type(ancestors).__name__}")
        return [*ancestors, concept]
    except Exception as e:
        log.warning(f"Error querying {query_url}, using the original IDs: {e}")
        return [concept]

30 

31 

def resolve_ids_with_nodenormalization_api(resolve_ids_list, resolved_ids_object):
    """Convert a list of identifiers to OMIM/DrugBank identifiers.

    Identifiers accepted by ``is_accepted_id`` are kept unchanged. All others
    are sent in one request to the NodeNormalization API, and every equivalent
    identifier that is accepted is added to the result.

    If the API call fails, a warning is logged and the identifiers that needed
    normalization are left out of the returned list.

    :param resolve_ids_list: Identifiers to resolve
    :param resolved_ids_object: Dict updated in place, mapping each resolved
        identifier to the identifier it was resolved from
    :return: Tuple ``(resolved_ids_list, resolved_ids_object)``
    """
    resolved_ids_list = []
    ids_to_normalize = []
    for id_to_resolve in resolve_ids_list:
        if is_accepted_id(id_to_resolve):
            resolved_ids_list.append(id_to_resolve)
            resolved_ids_object[id_to_resolve] = id_to_resolve
        else:
            ids_to_normalize.append(id_to_resolve)

    # Query Translator NodeNormalization API to convert IDs to OMIM/DrugBank IDs
    if ids_to_normalize:
        try:
            response = requests.get(
                "https://nodenormalization-sri.renci.org/get_normalized_nodes",
                params={"curie": ids_to_normalize},
                timeout=settings.TIMEOUT,
            )
            response.raise_for_status()
            # Get corresponding OMIM IDs for MONDO IDs if match
            for resolved_id, alt_ids in response.json().items():
                # Skip entries without data so that one unresolved identifier does not
                # abort the processing of all the others.
                # NOTE(review): NodeNorm is understood to answer null for CURIEs it
                # cannot resolve - confirm against the API documentation.
                if not alt_ids:
                    continue
                for alt_id in alt_ids.get("equivalent_identifiers") or []:
                    main_id = str(alt_id["identifier"])
                    if not is_accepted_id(main_id):
                        continue
                    # NOTE: fix issue when NodeNorm returns OMIM.PS: instead of OMIM:
                    if main_id.lower().startswith("omim"):
                        main_id = "OMIM:" + main_id.split(":", 1)[1]
                    resolved_ids_list.append(main_id)
                    resolved_ids_object[main_id] = resolved_id
        except Exception as e:
            log.warning(f"Error querying the NodeNormalization API, using the original IDs: {e}")
    return resolved_ids_list, resolved_ids_object

65 

66 

def resolve_id(id_to_resolve, resolved_ids_object):
    """Look up an identifier in the resolution mapping.

    :param id_to_resolve: Identifier to look up
    :param resolved_ids_object: Mapping of known identifiers to their counterpart
    :return: The mapped value when the identifier is a key of the mapping,
        otherwise the identifier itself, unchanged
    """
    return resolved_ids_object.get(id_to_resolve, id_to_resolve)

71 

72 

def _parse_query_options(reasoner_query):
    """Read the optional ``query_options`` of a query and coerce the known keys.

    :return: Tuple ``(query_options, model_id, n_results, min_score, max_score)``;
        each of the last four is None when its key is absent
    """
    # A missing or null "query_options" is treated as no options
    query_options = reasoner_query.get("query_options") or {}
    n_results = int(query_options["n_results"]) if "n_results" in query_options else None
    min_score = float(query_options["min_score"]) if "min_score" in query_options else None
    max_score = float(query_options["max_score"]) if "max_score" in query_options else None
    model_id = str(query_options["model_id"]) if "model_id" in query_options else None
    return query_options, model_id, n_results, min_score, max_score


def _build_query_plan(query_graph, resolved_ids_object):
    """Build, for each query graph edge, the dict describing what to predict.

    When the subject node provides ``ids`` the predictions start from the
    subject; otherwise they start from the object node if it provides ``ids``.

    :return: Tuple ``(query_plan, resolved_ids_object)``
    """
    query_plan = {}
    for edge_id, qg_edge in query_graph["edges"].items():
        # Dict with all infos of the associations to predict for this edge
        edge_plan = {
            "qg_source_id": qg_edge["subject"],
            "qg_target_id": qg_edge["object"],
            "predicates": qg_edge["predicates"],
        }
        query_plan[edge_id] = edge_plan

        # Get the nodes infos in the query plan object
        for node_id, node in query_graph["nodes"].items():
            if node_id == qg_edge["subject"]:
                edge_plan["subject_qg_id"] = node_id
                edge_plan["subject_types"] = node.get("categories", ["biolink:NamedThing"])
                if "ids" in node:
                    edge_plan["subject_kg_id"], resolved_ids_object = resolve_ids_with_nodenormalization_api(
                        node["ids"], resolved_ids_object
                    )
                    # Subject IDs always take precedence as the starting point
                    edge_plan["ids_to_predict"] = edge_plan["subject_kg_id"]
                    edge_plan["types_to_predict"] = edge_plan["subject_types"]
                    edge_plan["relation_to_predict"] = "subject"
                    edge_plan["relation_predicted"] = "object"

            if node_id == qg_edge["object"]:
                edge_plan["object_qg_id"] = node_id
                edge_plan["object_types"] = node.get("categories", ["biolink:NamedThing"])
                if "ids" in node:
                    edge_plan["object_kg_id"], resolved_ids_object = resolve_ids_with_nodenormalization_api(
                        node["ids"], resolved_ids_object
                    )
                    # Only start from the object when the subject gave no IDs
                    if "ids_to_predict" not in edge_plan:
                        edge_plan["ids_to_predict"] = edge_plan["object_kg_id"]
                        edge_plan["types_to_predict"] = edge_plan["object_types"]
                        edge_plan["relation_to_predict"] = "object"
                        edge_plan["relation_predicted"] = "subject"
    return query_plan, resolved_ids_object


def _build_kg_nodes(node_dict):
    """Generate the knowledge graph nodes from the dict of collected nodes.

    :param node_dict: Maps a node ID to a dict with a ``type`` (string or list)
        and an optional ``label``
    :return: Dict mapping each node ID to its ``categories`` and, if known, ``name``
    """
    kg_nodes = {}
    for node_id, properties in node_dict.items():
        node_category = properties["type"]
        if isinstance(node_category, str) and not node_category.startswith("biolink:"):
            node_category = "biolink:" + node_category.capitalize()
        if isinstance(node_category, str):
            node_category = [node_category]
        node_to_add = {"categories": node_category}
        if "label" in properties and properties["label"]:
            node_to_add["name"] = properties["label"]
        kg_nodes[node_id] = node_to_add
    return kg_nodes


def resolve_trapi_query(reasoner_query, endpoints_list):
    """Main function for TRAPI.

    Parse the query graph of a TRAPI query, run the matching prediction
    functions, and convert the predictions to the ReasonerAPI format:
    {disease: OMIM:1567, drug: DRUGBANK:DB0001, score: 0.9}

    Example TRAPI message:
    https://github.com/NCATSTranslator/ReasonerAPI/blob/master/examples/Message/simple.json

    :param reasoner_query: Query from Reasoner API, a dict with a
        ``message.query_graph`` and optional ``query_options`` (``n_results``,
        ``min_score``, ``max_score``, ``model_id``)
    :param endpoints_list: Prediction functions. Each one exposes a
        ``_trapi_predict`` dict whose ``edges`` list the served
        subject/predicate/object, is called with ``(id, options)`` and returns
        a dict with a ``hits`` list (each hit has ``id``, ``type``, ``score``
        and optionally ``label``)
    :return: Results as ReasonerAPI object
    """
    query_graph = reasoner_query["message"]["query_graph"]
    query_options, model_id, n_results, min_score, max_score = _parse_query_options(reasoner_query)

    # Parse the query_graph to build the query plan
    query_plan, resolved_ids_object = _build_query_plan(query_graph, {})

    knowledge_graph = {"nodes": {}, "edges": {}}
    node_dict = {}
    query_results = []
    kg_edge_count = 0

    # Now iterates the query plan to execute each query
    for edge_qg_id, edge_plan in query_plan.items():
        for predict_func in endpoints_list:
            # TODO: run the functions in parallel with future.concurrent

            for prediction_relation in predict_func._trapi_predict["edges"]:
                predicate_parents = get_biolink_parents(prediction_relation["predicate"])
                subject_parents = get_biolink_parents(prediction_relation["subject"])
                object_parents = get_biolink_parents(prediction_relation["object"])

                # TODO: add support for "qualifier_constraints" on query edges. cf. https://github.com/NCATSTranslator/testing/blob/main/ars-requests/not-none/1.2/mvp2cMetformin.json

                # Check if requested subject/predicate/object are served by the function
                if not (
                    any(i in predicate_parents for i in edge_plan["predicates"])
                    and any(i in subject_parents for i in edge_plan["subject_types"])
                    and any(i in object_parents for i in edge_plan["object_types"])
                ):
                    continue

                # Always read the plan of the edge being executed (not the last edge
                # parsed). An edge without any node IDs has nothing to predict from.
                for id_to_predict in edge_plan.get("ids_to_predict", []):
                    labels_dict = get_entities_labels([id_to_predict])
                    label_to_predict = None
                    if id_to_predict in labels_dict:
                        label_to_predict = labels_dict[id_to_predict]["id"]["label"]
                    try:
                        log.info(f"🔮⏳️ Getting predictions for: {id_to_predict}")
                        # Run function to get predictions
                        prediction_results = predict_func(
                            id_to_predict,
                            {
                                "model_id": model_id,
                                "min_score": min_score,
                                "max_score": max_score,
                                "n_results": n_results,
                                "types": edge_plan["types_to_predict"],
                            },
                        )
                        prediction_json = prediction_results["hits"]
                    except Exception as e:
                        # Best effort: a failing prediction function yields no hits
                        log.error(f"Error getting the predictions: {e}")
                        prediction_json = []

                    for association in prediction_json:
                        # id/type of nodes are registered in a dict to avoid duplicate in knowledge_graph.nodes
                        source_node_id = resolve_id(id_to_predict, resolved_ids_object)
                        target_node_id = resolve_id(association["id"], resolved_ids_object)

                        # TODO: XAI get path between source and target nodes (first create the function for this)

                        # If the target IDs are given, we filter here from the predictions
                        if (
                            "subject_kg_id" in edge_plan
                            and "object_kg_id" in edge_plan
                            and target_node_id not in edge_plan["object_kg_id"]
                        ):
                            continue

                        edge_kg_id = "e" + str(kg_edge_count)
                        rel_to_predict = edge_plan["relation_to_predict"]
                        rel_predicted = edge_plan["relation_predicted"]
                        node_dict[source_node_id] = {"type": edge_plan[f"{rel_to_predict}_types"]}
                        if label_to_predict:
                            node_dict[source_node_id]["label"] = label_to_predict

                        node_dict[target_node_id] = {"type": association["type"]}
                        if "label" in association:
                            node_dict[target_node_id]["label"] = association["label"]
                        else:
                            # TODO: improve to avoid to call the resolver everytime
                            target_labels = get_entities_labels([target_node_id])
                            if target_node_id in target_labels and target_labels[target_node_id]:
                                node_dict[target_node_id]["label"] = target_labels[target_node_id]["id"]["label"]

                        association_score = str(association["score"])
                        model_id_label = model_id or "openpredict_baseline"

                        # See attributes examples: https://github.com/NCATSTranslator/Evidence-Provenance-Confidence-Working-Group/blob/master/attribute_epc_examples/COHD_TRAPI1.1_Attribute_Example_2-3-21.yml
                        # More details on attributes: https://github.com/NCATSTranslator/ReasonerAPI/blob/master/docs/reference.md#attribute-
                        edge_dict = {
                            "sources": [
                                {
                                    "resource_id": "infores:openpredict",
                                    "resource_role": "primary_knowledge_source",
                                },
                                {"resource_id": "infores:cohd", "resource_role": "supporting_data_source"},
                            ],
                            "attributes": [
                                {
                                    "description": "model_id",
                                    "attribute_type_id": "EDAM:data_1048",
                                    "value": model_id_label,
                                },
                                # TODO: add the score as an attribute (use has_confidence_level?)
                                # https://github.com/NCATSTranslator/ReasonerAPI/blob/1.4/ImplementationGuidance/Specifications/knowledge_level_agent_type_specification.md
                                {
                                    "attribute_type_id": "biolink:agent_type",
                                    "value": "computational_model",
                                    "attribute_source": "infores:openpredict",
                                },
                                {
                                    "attribute_type_id": "biolink:knowledge_level",
                                    "value": "prediction",
                                    "attribute_source": "infores:openpredict",
                                },
                            ],
                            "subject": source_node_id,
                            "object": target_node_id,
                        }

                        # TODO: Define the predicate depending on the association source type returned by OpenPredict classifier
                        if len(edge_plan["predicates"]) > 0:
                            edge_dict["predicate"] = edge_plan["predicates"][0]
                        else:
                            edge_dict["predicate"] = prediction_relation["predicate"]

                        # Add the association in the knowledge_graph as edge
                        knowledge_graph["edges"][edge_kg_id] = edge_dict

                        # Add the bindings to the results object
                        result = {
                            "node_bindings": {
                                edge_plan[f"{rel_to_predict}_qg_id"]: [{"id": source_node_id}],
                                edge_plan[f"{rel_predicted}_qg_id"]: [{"id": target_node_id}],
                            },
                            "analyses": [
                                {
                                    "resource_id": "infores:openpredict",
                                    "score": association_score,
                                    "scoring_method": "Model confidence between 0 and 1",
                                    "edge_bindings": {edge_qg_id: [{"id": edge_kg_id}]},
                                }
                            ],
                        }
                        query_results.append(result)

                        kg_edge_count += 1
                        # NOTE(review): this only stops the current list of hits; the
                        # following IDs/functions keep adding results - confirm whether
                        # n_results is meant to cap the whole response.
                        if kg_edge_count == n_results:
                            break

    # Generate kg nodes from the dict of nodes + result from query to resolve labels
    knowledge_graph["nodes"].update(_build_kg_nodes(node_dict))

    return {
        "message": {"knowledge_graph": knowledge_graph, "query_graph": query_graph, "results": query_results},
        "query_options": query_options,
        "reasoner_id": "infores:openpredict",
        "schema_version": settings.TRAPI_VERSION,
        "biolink_version": settings.BIOLINK_VERSION,
        "status": "Success",
    }

360 

361 

# Sample TRAPI query: a single edge "e01" from node "n0" (which provides IDs)
# to node "n1" (categories only), with score bounds given as query options.
example_trapi = {
    "message": {
        "query_graph": {
            "edges": {
                "e01": {
                    "object": "n1",
                    "predicates": [
                        "biolink:treated_by",
                        "biolink:treats",
                    ],
                    "subject": "n0",
                },
            },
            "nodes": {
                "n0": {
                    "categories": ["biolink:Disease", "biolink:Drug"],
                    "ids": ["OMIM:246300", "DRUGBANK:DB00394"],
                },
                "n1": {
                    "categories": ["biolink:Drug", "biolink:Disease"],
                },
            },
        },
    },
    "query_options": {
        "max_score": 1,
        "min_score": 0.5,
    },
}

373}