Coverage for src/trapi_predict_kit/trapi.py: 90%

73 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-25 21:15 +0200

1import os 

2import time 

3from typing import Any, Callable, Dict, List, Optional 

4 

5from fastapi import Body, FastAPI, Request 

6from fastapi.middleware.cors import CORSMiddleware 

7from fastapi.responses import JSONResponse, RedirectResponse 

8from reasoner_pydantic import Query 

9 

10from trapi_predict_kit.trapi_parser import resolve_trapi_query 

11from trapi_predict_kit.types import PredictOptions 

12 

# Tag objects, one ``{"name": <tag>}`` dict per required tag name, in this order.
REQUIRED_TAGS = [
    {"name": tag_name}
    for tag_name in ("reasoner", "trapi", "models", "openpredict", "translator")
]

20 

21 

class TRAPI(FastAPI):
    """Translator Reasoner API - wrapper for FastAPI.

    On construction the application registers:

    - ``POST /query``: a TRAPI query resolved with ``resolve_trapi_query`` against
      ``predict_endpoints``.
    - ``GET /meta_knowledge_graph``: edges and nodes merged from every prediction
      function's ``_trapi_predict`` metadata.
    - ``GET /health`` and ``GET /`` (redirect to ``/docs``), both hidden from the schema.
    - One ``GET`` route per prediction function, at ``_trapi_predict["path"]``.

    It also adds a permissive CORS middleware and an ``X-Process-Time`` response header.
    """

    def __init__(
        self,
        *args: Any,
        predict_endpoints: List[Callable],
        ordered_servers: Optional[List[Dict[str, str]]] = None,
        itrb_url_prefix: Optional[str] = None,
        dev_server_url: Optional[str] = None,
        info: Optional[Dict[str, Any]] = None,
        title="Translator Reasoner API",
        version="1.0.0",
        openapi_version="3.0.1",
        description=(
            "Get predicted targets for a given entity\n\n\n"
            "Service supported by the [NCATS Translator project](https://ncats.nih.gov/translator/about)"
        ),
        **kwargs: Any,
    ):
        """Create the API and register all routes.

        :param args: Positional arguments forwarded to ``FastAPI``.
        :param predict_endpoints: Prediction functions. Each must carry a ``_trapi_predict``
            dict; this class reads its ``edges``, ``nodes``, ``path``, ``name``,
            ``description``, ``default_input`` and ``default_model`` keys.
        :param ordered_servers: Currently unused. The server list is always derived from
            ``itrb_url_prefix``, ``dev_server_url`` and the ``VIRTUAL_HOST`` environment
            variable. The parameter is kept for backward compatibility.
        :param itrb_url_prefix: Subdomain used to build the ITRB production/test/CI server
            URLs (only used when ``VIRTUAL_HOST`` is set).
        :param dev_server_url: URL of the development server (only used when
            ``VIRTUAL_HOST`` is set).
        :param info: Stored as ``self.info``; not otherwise used in this class.
        :param title: OpenAPI title, forwarded to ``FastAPI``.
        :param version: API version, forwarded to ``FastAPI``.
        :param openapi_version: OpenAPI spec version, forwarded to ``FastAPI``.
        :param description: OpenAPI description, forwarded to ``FastAPI``.
        :param kwargs: Other keyword arguments forwarded to ``FastAPI``.
        """
        super().__init__(
            *args,
            title=title,
            version=version,
            openapi_version=openapi_version,
            description=description,
            root_path_in_servers=False,
            **kwargs,
        )
        self.predict_endpoints = predict_endpoints
        self.info = info

        # On ITRB deployment and local dev we directly use the current server
        # (empty server list).
        self.servers = []

        # For the API deployed on our server and registered to SmartAPI we provide the complete list
        virtual_host = os.getenv("VIRTUAL_HOST")
        if virtual_host:
            if itrb_url_prefix:
                self.servers.extend(
                    [
                        {
                            "url": f"https://{itrb_url_prefix}.transltr.io",
                            "description": "TRAPI ITRB Production Server",
                            "x-maturity": "production",
                        },
                        {
                            "url": f"https://{itrb_url_prefix}.test.transltr.io",
                            "description": "TRAPI ITRB Test Server",
                            "x-maturity": "testing",
                        },
                        {
                            "url": f"https://{itrb_url_prefix}.ci.transltr.io",
                            "description": "TRAPI ITRB CI Server",
                            "x-maturity": "staging",
                        },
                    ]
                )
            if dev_server_url:
                self.servers.append(
                    {"url": dev_server_url, "description": "TRAPI Dev Server", "x-maturity": "development"}
                )

            # Put the server(s) whose URL contains the current host first, keeping every
            # other server after them in its original order. ``sorted`` is stable and
            # False sorts before True, so this is a stable partition that drops nothing.
            self.servers = sorted(self.servers, key=lambda server: virtual_host not in server["url"])

        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        trapi_example = {
            "message": {
                "query_graph": {
                    "edges": {"e01": {"object": "n1", "predicates": ["biolink:treated_by"], "subject": "n0"}},
                    "nodes": {
                        "n0": {
                            "categories": ["biolink:Disease"],
                            # Alternative example id: "MONDO:0007190"
                            "ids": ["OMIM:246300"],
                        },
                        "n1": {"categories": ["biolink:Drug"]},
                    },
                }
            },
            "query_options": {"max_score": 1, "min_score": 0.5, "n_results": 10},
        }

        def empty_trapi_response(query_graph: Dict[str, Any]) -> Dict[str, Any]:
            """Build a TRAPI message that echoes ``query_graph`` with no results."""
            return {
                "message": {
                    "knowledge_graph": {"nodes": {}, "edges": {}},
                    "query_graph": query_graph,
                    "results": [],
                }
            }

        @self.post(
            "/query",
            name="TRAPI query",
            description="""The default example TRAPI query will give you a list of predicted potential drug treatments for a given disease

You can also try this query to retrieve similar entities to a given drug:

```json
{
  "message": {
    "query_graph": {
      "edges": {
        "e01": {
          "object": "n1",
          "predicates": [ "biolink:similar_to" ],
          "subject": "n0"
        }
      },
      "nodes": {
        "n0": {
          "categories": [ "biolink:Drug" ],
          "ids": [ "DRUGBANK:DB00394" ]
        },
        "n1": {
          "categories": [ "biolink:Drug" ]
        }
      }
    }
  },
  "query_options": { "n_results": 5 }
}
```
        """,
            response_model=Query,
            tags=["reasoner"],
        )
        def post_reasoner_predict(request_body: Query = Body(..., example=trapi_example)) -> Query:
            """Get predicted associations for a given ReasonerAPI query.

            Only single-edge query graphs are resolved. Graphs with zero edges or with
            several edges get an empty (but well-formed) TRAPI message.

            :param request_body: The ReasonerStdAPI query in JSON
            :return: Predictions as a ReasonerStdAPI Message, or a 404 JSON response
                when ``resolve_trapi_query`` returns nothing
            """
            query_graph = request_body.message.query_graph.dict(exclude_none=True)

            # TODO: consider 400 "No edges" / 501 "Multi-edges queries not yet implemented"
            # error responses instead of empty results.
            if len(query_graph["edges"]) != 1:
                return empty_trapi_response(query_graph)

            reasonerapi_response = resolve_trapi_query(request_body.dict(exclude_none=True), self.predict_endpoints)
            # A Response object is always truthy, so the emptiness check must happen
            # on the resolved payload before wrapping it.
            if not reasonerapi_response:
                return JSONResponse({"detail": "Not found"}, status_code=404)
            return JSONResponse(reasonerapi_response)

        @self.get(
            "/meta_knowledge_graph",
            name="Get the meta knowledge graph",
            description="Get the meta knowledge graph",
            response_model=dict,
            tags=["trapi"],
        )
        def get_meta_knowledge_graph() -> dict:
            """Get predicates and entities provided by the API.

            Edges from all prediction functions are concatenated with duplicates removed.
            Node dicts are merged, and later functions win on key collisions.

            :return: JSON with biolink entities
            """
            metakg: Dict[str, Any] = {"edges": [], "nodes": {}}
            for predict_func in self.predict_endpoints:
                # Compare edge by edge. Testing the whole edge list for membership in a
                # list of edges never matches, so duplicates would slip through.
                for edge in predict_func._trapi_predict["edges"]:
                    if edge not in metakg["edges"]:
                        metakg["edges"].append(edge)
                metakg["nodes"].update(predict_func._trapi_predict["nodes"])
            return JSONResponse(metakg)

        @self.middleware("http")
        async def add_process_time_header(request: Request, call_next):
            """Add an ``X-Process-Time`` header with the handling duration in seconds."""
            # perf_counter is monotonic, so it is the right clock for measuring durations.
            start_time = time.perf_counter()
            response = await call_next(request)
            response.headers["X-Process-Time"] = str(time.perf_counter() - start_time)
            return response

        @self.get("/health", include_in_schema=False)
        def health_check():
            """Health check for Translator elastic load balancer"""
            return {"status": "ok"}

        @self.get("/", include_in_schema=False)
        def redirect_root_to_docs():
            """Redirect the route / to /docs"""
            return RedirectResponse(url="/docs")

        def endpoint_factory(predict_func):
            """Build the GET handler for one prediction function.

            A factory is used so each handler captures its own ``predict_func``. A
            closure created directly in the registration loop would bind late.
            """

            def prediction_endpoint(
                input_id: str = predict_func._trapi_predict["default_input"],
                model_id: str = predict_func._trapi_predict["default_model"],
                min_score: Optional[float] = None,
                max_score: Optional[float] = None,
                n_results: Optional[int] = None,
            ):
                try:
                    return predict_func(
                        input_id,
                        PredictOptions.parse_obj(
                            {
                                "model_id": model_id,
                                "min_score": min_score,
                                "max_score": max_score,
                                "n_results": n_results,
                            }
                        ),
                    )
                except Exception as e:
                    # This is an API boundary: report the failure as a real HTTP 500. A
                    # Flask-style (body, status) tuple would be serialized and validated
                    # against response_model=dict instead.
                    return JSONResponse(
                        {"detail": f"Error when getting the predictions: {e}"},
                        status_code=500,
                    )

            return prediction_endpoint

        # Generate endpoints for the loaded models
        for predict_func in self.predict_endpoints:
            self.add_api_route(
                path=predict_func._trapi_predict["path"],
                methods=["GET"],
                endpoint=endpoint_factory(predict_func),
                name=predict_func._trapi_predict["name"],
                openapi_extra={"description": predict_func._trapi_predict["description"]},
                response_model=dict,
                tags=["models"],
            )