Coverage for src/trapi_predict_kit/trapi.py: 90%
73 statements
« prev ^ index » next — coverage.py v7.2.7, created at 2023-07-25 21:15 +0200
1import os
2import time
3from typing import Any, Callable, Dict, List, Optional
5from fastapi import Body, FastAPI, Request
6from fastapi.middleware.cors import CORSMiddleware
7from fastapi.responses import JSONResponse, RedirectResponse
8from reasoner_pydantic import Query
10from trapi_predict_kit.trapi_parser import resolve_trapi_query
11from trapi_predict_kit.types import PredictOptions
# OpenAPI tags every TRAPI app must expose (used when registering to SmartAPI)
REQUIRED_TAGS = [
    {"name": tag_name}
    for tag_name in ("reasoner", "trapi", "models", "openpredict", "translator")
]
class TRAPI(FastAPI):
    """Translator Reasoner API - wrapper for FastAPI.

    Configures a TRAPI-compliant FastAPI application: the ``/query`` and
    ``/meta_knowledge_graph`` TRAPI routes, CORS and timing middlewares,
    ``/health`` and root-redirect endpoints, plus one GET route per
    registered prediction function.
    """

    def __init__(
        self,
        *args: Any,
        predict_endpoints: List[Callable],
        ordered_servers: Optional[List[Dict[str, str]]] = None,
        itrb_url_prefix: Optional[str] = None,
        dev_server_url: Optional[str] = None,
        info: Optional[Dict[str, Any]] = None,
        title="Translator Reasoner API",
        version="1.0.0",
        openapi_version="3.0.1",
        description="""Get predicted targets for a given entity
\n\nService supported by the [NCATS Translator project](https://ncats.nih.gov/translator/about)""",
        **kwargs: Any,
    ):
        """Initialize the TRAPI application.

        :param predict_endpoints: prediction functions, each carrying a
            ``_trapi_predict`` dict (path, name, description, edges, nodes,
            default_input, default_model)
        :param ordered_servers: list of OpenAPI server dicts.
            NOTE(review): this argument is currently discarded when
            VIRTUAL_HOST is set (see below) — confirm intended semantics.
        :param itrb_url_prefix: subdomain prefix for the ITRB ``*.transltr.io``
            production/test/CI deployments
        :param dev_server_url: URL of a development server to advertise
        :param info: extra metadata stored on the app (e.g. for SmartAPI)
        """
        super().__init__(
            *args,
            title=title,
            version=version,
            openapi_version=openapi_version,
            description=description,
            root_path_in_servers=False,
            **kwargs,
        )
        self.predict_endpoints = predict_endpoints
        self.info = info

        # On ITRB deployment and local dev we directly use the current server
        self.servers = []

        virtual_host = os.getenv("VIRTUAL_HOST")
        # For the API deployed on our server and registered to SmartAPI we provide the complete list
        if virtual_host:
            if itrb_url_prefix:
                self.servers.extend(
                    [
                        {
                            "url": f"https://{itrb_url_prefix}.transltr.io",
                            "description": "TRAPI ITRB Production Server",
                            "x-maturity": "production",
                        },
                        {
                            "url": f"https://{itrb_url_prefix}.test.transltr.io",
                            "description": "TRAPI ITRB Test Server",
                            "x-maturity": "testing",
                        },
                        {
                            "url": f"https://{itrb_url_prefix}.ci.transltr.io",
                            "description": "TRAPI ITRB CI Server",
                            "x-maturity": "staging",
                        },
                    ]
                )
            if dev_server_url:
                self.servers.append(
                    {"url": dev_server_url, "description": "TRAPI Dev Server", "x-maturity": "development"}
                )

            # TODO(review): any caller-provided ordered_servers is discarded
            # here, matching the historical behavior — confirm this is intended.
            ordered_servers = []
            # Add the current server as 1st server in the list
            for server in self.servers:
                if virtual_host in server["url"]:
                    ordered_servers.append(server)
                    break
            # Add the other servers after it
            ordered_servers += [s for s in self.servers if virtual_host not in s["url"]]
            self.servers = ordered_servers

        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # Example payload shown in the OpenAPI docs for the /query endpoint
        trapi_example = {
            "message": {
                "query_graph": {
                    "edges": {"e01": {"object": "n1", "predicates": ["biolink:treated_by"], "subject": "n0"}},
                    "nodes": {
                        "n0": {
                            "categories": ["biolink:Disease"],
                            "ids": [
                                "OMIM:246300",
                                # "MONDO:0007190"
                            ],
                        },
                        "n1": {"categories": ["biolink:Drug"]},
                    },
                }
            },
            "query_options": {"max_score": 1, "min_score": 0.5, "n_results": 10},
        }

        @self.post(
            "/query",
            name="TRAPI query",
            description="""The default example TRAPI query will give you a list of predicted potential drug treatments for a given disease

You can also try this query to retrieve similar entities to a given drug:

```json
{
  "message": {
    "query_graph": {
      "edges": {
        "e01": {
          "object": "n1",
          "predicates": [ "biolink:similar_to" ],
          "subject": "n0"
        }
      },
      "nodes": {
        "n0": {
          "categories": [ "biolink:Drug" ],
          "ids": [ "DRUGBANK:DB00394" ]
        },
        "n1": {
          "categories": [ "biolink:Drug" ]
        }
      }
    }
  },
  "query_options": { "n_results": 5 }
}
```
    """,
            response_model=Query,
            tags=["reasoner"],
        )
        def post_reasoner_predict(request_body: Query = Body(..., example=trapi_example)) -> Query:
            """Get predicted associations for a given ReasonerAPI query.

            :param request_body: The ReasonerStdAPI query in JSON
            :return: Predictions as a ReasonerStdAPI Message
            """
            query_graph = request_body.message.query_graph.dict(exclude_none=True)

            # Empty query graphs and multi-edge queries (not implemented yet)
            # both get an empty TRAPI message back.
            if len(query_graph["edges"]) != 1:
                return {
                    "message": {
                        "knowledge_graph": {"nodes": {}, "edges": {}},
                        "query_graph": query_graph,
                        "results": [],
                    }
                }

            reasonerapi_response = resolve_trapi_query(request_body.dict(exclude_none=True), self.predict_endpoints)

            # NOTE: a JSONResponse instance is always truthy, so the previous
            # `or ("Not found", 404)` fallback was dead code and was removed.
            return JSONResponse(reasonerapi_response)

        @self.get(
            "/meta_knowledge_graph",
            name="Get the meta knowledge graph",
            description="Get the meta knowledge graph",
            response_model=dict,
            tags=["trapi"],
        )
        def get_meta_knowledge_graph() -> dict:
            """Get predicates and entities provided by the API.

            :return: JSON with biolink entities
            """
            metakg = {"edges": [], "nodes": {}}
            for predict_func in self.predict_endpoints:
                # Deduplicate edge by edge: the previous check compared the
                # whole edge list against the elements of metakg["edges"],
                # which was always true and never filtered duplicates.
                for edge in predict_func._trapi_predict["edges"]:
                    if edge not in metakg["edges"]:
                        metakg["edges"].append(edge)
                # Merge nodes dicts (later endpoints override duplicate keys)
                metakg["nodes"].update(predict_func._trapi_predict["nodes"])
            return JSONResponse(metakg)

        @self.middleware("http")
        async def add_process_time_header(request: Request, call_next):
            """Attach the request processing time as an X-Process-Time header."""
            start_time = time.time()
            response = await call_next(request)
            process_time = time.time() - start_time
            response.headers["X-Process-Time"] = str(process_time)
            return response

        @self.get("/health", include_in_schema=False)
        def health_check():
            """Health check for Translator elastic load balancer"""
            return {"status": "ok"}

        @self.get("/", include_in_schema=False)
        def redirect_root_to_docs():
            """Redirect the route / to /docs"""
            return RedirectResponse(url="/docs")

        # Generate endpoints for the loaded models
        def endpoint_factory(predict_func):
            """Build a GET handler bound to one prediction function (binding
            inside a factory avoids the late-binding closure pitfall)."""

            def prediction_endpoint(
                input_id: str = predict_func._trapi_predict["default_input"],
                model_id: str = predict_func._trapi_predict["default_model"],
                min_score: Optional[float] = None,
                max_score: Optional[float] = None,
                n_results: Optional[int] = None,
            ):
                try:
                    return predict_func(
                        input_id,
                        PredictOptions.parse_obj(
                            {
                                "model_id": model_id,
                                "min_score": min_score,
                                "max_score": max_score,
                                "n_results": n_results,
                                # "types": ['biolink:Drug'],
                            }
                        ),
                    )
                except Exception as e:
                    # Previously returned a (msg, 500) tuple, which FastAPI
                    # serialized as a 200 JSON array; return a real HTTP 500.
                    return JSONResponse(
                        status_code=500,
                        content={"detail": f"Error when getting the predictions: {e}"},
                    )

            return prediction_endpoint

        for predict_func in self.predict_endpoints:
            self.add_api_route(
                path=predict_func._trapi_predict["path"],
                methods=["GET"],
                endpoint=endpoint_factory(predict_func),
                name=predict_func._trapi_predict["name"],
                openapi_extra={"description": predict_func._trapi_predict["description"]},
                response_model=dict,
                tags=["models"],
            )