Coverage for src/trapi_predict_kit/trapi_parser.py: 85%
157 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-25 21:14 +0200
1import re
3import requests
5from trapi_predict_kit.config import settings
6from trapi_predict_kit.utils import get_entities_labels, log
# TODO: add evidence path to TRAPI
def is_accepted_id(id_to_check):
    """Return True if the CURIE is in a namespace the models accept (OMIM or DrugBank).

    The prefix check is case-insensitive.
    """
    # str.startswith accepts a tuple of prefixes, so one lowercased copy covers both
    return id_to_check.lower().startswith(("omim", "drugbank"))
def get_biolink_parents(concept):
    """Return the BioLink ancestors of a concept, with the concept itself appended.

    Queries the SRI BioLink lookup service using the snake_case form of the
    concept name. On any failure the original concept is returned alone, so
    callers can still match against it (best-effort behavior).

    :param concept: a BioLink class CURIE, e.g. "biolink:Drug"
    :return: list of ancestor concept names plus ``concept``
    """
    concept_snakecase = concept.replace("biolink:", "")
    # CamelCase -> snake_case: insert "_" before every non-leading capital
    concept_snakecase = re.sub(r"(?<!^)(?=[A-Z])", "_", concept_snakecase).lower()
    query_url = f"https://bl-lookup-sri.renci.org/bl/{concept_snakecase}/ancestors"
    try:
        resolve_curies = requests.get(query_url, timeout=settings.TIMEOUT)
        # TODO: can't specify a BioLink version because asking for v3.1.0 does not exist, so we use latest
        # resolve_curies = requests.get(query_url,
        #     params={'version': f'v{settings.BIOLINK_VERSION}'})
        resp = resolve_curies.json()
        resp.append(concept)
        return resp
    except Exception as e:
        # logger.warn() is deprecated in the stdlib logging API; use warning()
        log.warning(f"Error querying {query_url}, using the original IDs: {e}")
        return [concept]
def resolve_ids_with_nodenormalization_api(resolve_ids_list, resolved_ids_object):
    """Convert arbitrary CURIEs to accepted OMIM/DrugBank IDs via the SRI NodeNormalization API.

    IDs already in an accepted namespace are kept as-is; the rest are sent in a
    single batched request and replaced by an accepted equivalent identifier
    when one exists. On API failure the unresolved IDs are silently dropped
    (best-effort) and the error is logged.

    :param resolve_ids_list: list of CURIEs to resolve
    :param resolved_ids_object: dict mapping resolved ID -> original ID, updated in place
    :return: tuple (list of resolved IDs, updated resolved_ids_object)
    """
    resolved_ids_list = []
    ids_to_normalize = []
    for id_to_resolve in resolve_ids_list:
        if is_accepted_id(id_to_resolve):
            resolved_ids_list.append(id_to_resolve)
            resolved_ids_object[id_to_resolve] = id_to_resolve
        else:
            ids_to_normalize.append(id_to_resolve)

    # Query Translator NodeNormalization API to convert IDs to OMIM/DrugBank IDs
    if len(ids_to_normalize) > 0:
        try:
            resolve_curies = requests.get(
                "https://nodenormalization-sri.renci.org/get_normalized_nodes",
                params={"curie": ids_to_normalize},
                timeout=settings.TIMEOUT,
            )
            # Get corresponding OMIM IDs for MONDO IDs if match
            resp = resolve_curies.json()
            for resolved_id, alt_ids in resp.items():
                for alt_id in alt_ids["equivalent_identifiers"]:
                    if is_accepted_id(str(alt_id["identifier"])):
                        main_id = str(alt_id["identifier"])
                        # NOTE: fix issue when NodeNorm returns OMIM.PS: instead of OMIM:
                        if main_id.lower().startswith("omim"):
                            main_id = "OMIM:" + main_id.split(":", 1)[1]
                        resolved_ids_list.append(main_id)
                        resolved_ids_object[main_id] = resolved_id
        except Exception as e:
            # warn() is deprecated, and the original message dropped the exception
            # detail entirely — keep it so failures are debuggable
            log.warning(f"Error querying the NodeNormalization API, using the original IDs: {e}")
    # log.info(f"Resolved: {resolve_ids_list} to {resolved_ids_object}")
    return resolved_ids_list, resolved_ids_object
def resolve_id(id_to_resolve, resolved_ids_object):
    """Map a resolved ID back to its original ID, or return it unchanged when unmapped."""
    return resolved_ids_object.get(id_to_resolve, id_to_resolve)
def resolve_trapi_query(reasoner_query, endpoints_list):
    """Main function for TRAPI.

    Parse the incoming query_graph, build a per-edge query plan, run every
    registered prediction function whose subject/predicate/object match the
    requested edge, and return the predictions as a TRAPI response
    (knowledge_graph + results bindings).

    :param reasoner_query: query from the Reasoner API (TRAPI message)
    :param endpoints_list: prediction functions carrying ``_trapi_predict`` metadata
    :return: results as a ReasonerAPI (TRAPI) response dict
    """
    # Example TRAPI message: https://github.com/NCATSTranslator/ReasonerAPI/blob/master/examples/Message/simple.json
    query_graph = reasoner_query["message"]["query_graph"]
    # Default query_options
    model_id = None
    n_results = None
    min_score = None
    max_score = None
    query_options = {}
    if "query_options" in reasoner_query:
        query_options = reasoner_query["query_options"]
        if "n_results" in query_options:
            n_results = int(query_options["n_results"])
        if "min_score" in query_options:
            min_score = float(query_options["min_score"])
        if "max_score" in query_options:
            max_score = float(query_options["max_score"])
        if "model_id" in query_options:
            model_id = str(query_options["model_id"])

    query_plan = {}
    resolved_ids_object = {}

    # Parse the query_graph to build the query plan
    for edge_id, qg_edge in query_graph["edges"].items():
        # Build dict with all infos of associations to predict
        query_plan[edge_id] = {
            "qg_source_id": qg_edge["subject"],
            "qg_target_id": qg_edge["object"],
            "predicates": qg_edge["predicates"],
        }

        # Get the nodes infos in the query plan object
        for node_id, node in query_graph["nodes"].items():
            if node_id == qg_edge["subject"]:
                query_plan[edge_id]["subject_qg_id"] = node_id
                query_plan[edge_id]["subject_types"] = node.get("categories", ["biolink:NamedThing"])
                if "ids" in node:
                    query_plan[edge_id]["subject_kg_id"], resolved_ids_object = resolve_ids_with_nodenormalization_api(
                        node["ids"], resolved_ids_object
                    )
                    # Subject IDs take priority as the prediction input
                    query_plan[edge_id]["ids_to_predict"] = query_plan[edge_id]["subject_kg_id"]
                    query_plan[edge_id]["types_to_predict"] = query_plan[edge_id]["subject_types"]
                    query_plan[edge_id]["relation_to_predict"] = "subject"
                    query_plan[edge_id]["relation_predicted"] = "object"

            if node_id == qg_edge["object"]:
                query_plan[edge_id]["object_qg_id"] = node_id
                query_plan[edge_id]["object_types"] = node.get("categories", ["biolink:NamedThing"])
                if "ids" in node:
                    query_plan[edge_id]["object_kg_id"], resolved_ids_object = resolve_ids_with_nodenormalization_api(
                        node["ids"], resolved_ids_object
                    )
                    # Only predict from the object side when the subject gave no IDs
                    if "ids_to_predict" not in query_plan[edge_id]:
                        query_plan[edge_id]["ids_to_predict"] = query_plan[edge_id]["object_kg_id"]
                        query_plan[edge_id]["types_to_predict"] = query_plan[edge_id]["object_types"]
                        query_plan[edge_id]["relation_to_predict"] = "object"
                        query_plan[edge_id]["relation_predicted"] = "subject"

    knowledge_graph = {"nodes": {}, "edges": {}}
    node_dict = {}
    query_results = []
    kg_edge_count = 0

    # Now iterates the query plan to execute each query
    for edge_qg_id in query_plan:
        # TODO: exit if no ID provided? Or check already done before?
        for predict_func in endpoints_list:
            # TODO: run the functions in parallel with future.concurrent
            for prediction_relation in predict_func._trapi_predict["edges"]:
                predicate_parents = get_biolink_parents(prediction_relation["predicate"])
                subject_parents = get_biolink_parents(prediction_relation["subject"])
                object_parents = get_biolink_parents(prediction_relation["object"])

                # TODO: add support for "qualifier_constraints" on query edges. cf. https://github.com/NCATSTranslator/testing/blob/main/ars-requests/not-none/1.2/mvp2cMetformin.json

                # Check if requested subject/predicate/object are served by the function
                if (
                    any(i in predicate_parents for i in query_plan[edge_qg_id]["predicates"])
                    and any(i in subject_parents for i in query_plan[edge_qg_id]["subject_types"])
                    and any(i in object_parents for i in query_plan[edge_qg_id]["object_types"])
                ):
                    # BUGFIX: this whole execution loop previously indexed
                    # query_plan[edge_id] — the stale variable left over from the
                    # build loop above — which reads the wrong plan entry whenever
                    # the query graph has more than one edge. Use edge_qg_id.
                    for id_to_predict in query_plan[edge_qg_id]["ids_to_predict"]:
                        labels_dict = get_entities_labels([id_to_predict])
                        label_to_predict = None
                        if id_to_predict in labels_dict:
                            label_to_predict = labels_dict[id_to_predict]["id"]["label"]
                        try:
                            log.info(f"🔮⏳️ Getting predictions for: {id_to_predict}")
                            # Run function to get predictions
                            prediction_results = predict_func(
                                id_to_predict,
                                {
                                    "model_id": model_id,
                                    "min_score": min_score,
                                    "max_score": max_score,
                                    "n_results": n_results,
                                    "types": query_plan[edge_qg_id]["types_to_predict"],
                                },
                            )
                            prediction_json = prediction_results["hits"]
                        except Exception as e:
                            log.error(f"Error getting the predictions: {e}")
                            prediction_json = []

                        for association in prediction_json:
                            # id/type of nodes are registered in a dict to avoid duplicate in knowledge_graph.nodes
                            source_node_id = resolve_id(id_to_predict, resolved_ids_object)
                            target_node_id = resolve_id(association["id"], resolved_ids_object)

                            # TODO: XAI get path between source and target nodes (first create the function for this)

                            # If the target ID is given, we filter here from the predictions
                            if (
                                "subject_kg_id" in query_plan[edge_qg_id]
                                and "object_kg_id" in query_plan[edge_qg_id]
                                and target_node_id not in query_plan[edge_qg_id]["object_kg_id"]
                            ):
                                continue

                            edge_kg_id = "e" + str(kg_edge_count)
                            # Which side of the edge was the input vs the prediction
                            rel_to_predict = query_plan[edge_qg_id]["relation_to_predict"]
                            rel_predicted = query_plan[edge_qg_id]["relation_predicted"]
                            node_dict[source_node_id] = {"type": query_plan[edge_qg_id][f"{rel_to_predict}_types"]}
                            if label_to_predict:
                                node_dict[source_node_id]["label"] = label_to_predict

                            node_dict[target_node_id] = {"type": association["type"]}
                            if "label" in association:
                                node_dict[target_node_id]["label"] = association["label"]
                            else:
                                # TODO: improve to avoid to call the resolver everytime
                                labels_dict = get_entities_labels([target_node_id])
                                if target_node_id in labels_dict and labels_dict[target_node_id]:
                                    node_dict[target_node_id]["label"] = labels_dict[target_node_id]["id"]["label"]

                            association_score = str(association["score"])

                            model_id_label = model_id
                            if not model_id_label:
                                model_id_label = "openpredict_baseline"

                            # See attributes examples: https://github.com/NCATSTranslator/Evidence-Provenance-Confidence-Working-Group/blob/master/attribute_epc_examples/COHD_TRAPI1.1_Attribute_Example_2-3-21.yml
                            edge_dict = {
                                # More details on attributes: https://github.com/NCATSTranslator/ReasonerAPI/blob/master/docs/reference.md#attribute-
                                "sources": [
                                    {
                                        "resource_id": "infores:openpredict",
                                        "resource_role": "primary_knowledge_source",
                                    },
                                    {"resource_id": "infores:cohd", "resource_role": "supporting_data_source"},
                                ],
                                "attributes": [
                                    {
                                        "description": "model_id",
                                        "attribute_type_id": "EDAM:data_1048",
                                        "value": model_id_label,
                                    },
                                    # https://github.com/NCATSTranslator/ReasonerAPI/blob/1.4/ImplementationGuidance/Specifications/knowledge_level_agent_type_specification.md
                                    {
                                        "attribute_type_id": "biolink:agent_type",
                                        "value": "computational_model",
                                        "attribute_source": "infores:openpredict",
                                    },
                                    {
                                        "attribute_type_id": "biolink:knowledge_level",
                                        "value": "prediction",
                                        "attribute_source": "infores:openpredict",
                                    },
                                ],
                            }

                            # Map the source/target of query_graph to source/target of association
                            edge_dict["subject"] = source_node_id
                            edge_dict["object"] = target_node_id

                            # TODO: Define the predicate depending on the association source type returned by OpenPredict classifier
                            if len(query_plan[edge_qg_id]["predicates"]) > 0:
                                edge_dict["predicate"] = query_plan[edge_qg_id]["predicates"][0]
                            else:
                                edge_dict["predicate"] = prediction_relation["predicate"]

                            # Add the association in the knowledge_graph as edge
                            knowledge_graph["edges"][edge_kg_id] = edge_dict

                            # Add the bindings to the results object
                            result = {
                                "node_bindings": {},
                                "analyses": [
                                    {
                                        "resource_id": "infores:openpredict",
                                        "score": association_score,
                                        "scoring_method": "Model confidence between 0 and 1",
                                        "edge_bindings": {edge_qg_id: [{"id": edge_kg_id}]},
                                    }
                                ],
                            }
                            result["node_bindings"][query_plan[edge_qg_id][f"{rel_to_predict}_qg_id"]] = [
                                {"id": source_node_id}
                            ]
                            result["node_bindings"][query_plan[edge_qg_id][f"{rel_predicted}_qg_id"]] = [
                                {"id": target_node_id}
                            ]
                            query_results.append(result)

                            kg_edge_count += 1
                            # NOTE: only breaks out of the current association loop;
                            # other edges/functions may still add more results
                            if kg_edge_count == n_results:
                                break

    # Generate kg nodes from the dict of nodes + result from query to resolve labels
    for node_id, properties in node_dict.items():
        node_category = properties["type"]
        if isinstance(node_category, str) and not node_category.startswith("biolink:"):
            node_category = "biolink:" + node_category.capitalize()
        if isinstance(node_category, str):
            node_category = [node_category]
        node_to_add = {
            "categories": node_category,
        }
        if "label" in properties and properties["label"]:
            node_to_add["name"] = properties["label"]
        knowledge_graph["nodes"][node_id] = node_to_add

    return {
        "message": {"knowledge_graph": knowledge_graph, "query_graph": query_graph, "results": query_results},
        "query_options": query_options,
        "reasoner_id": "infores:openpredict",
        "schema_version": settings.TRAPI_VERSION,
        "biolink_version": settings.BIOLINK_VERSION,
        "status": "Success",
    }
# Sample TRAPI query for testing/documentation: predictions between the given
# OMIM disease / DrugBank drug and any drug/disease, scores in [0.5, 1].
example_trapi = {
    "message": {
        "query_graph": {
            "edges": {
                "e01": {
                    "subject": "n0",
                    "object": "n1",
                    "predicates": ["biolink:treated_by", "biolink:treats"],
                }
            },
            "nodes": {
                "n0": {
                    "categories": ["biolink:Disease", "biolink:Drug"],
                    "ids": ["OMIM:246300", "DRUGBANK:DB00394"],
                },
                "n1": {"categories": ["biolink:Drug", "biolink:Disease"]},
            },
        }
    },
    "query_options": {"max_score": 1, "min_score": 0.5},
}