Coverage for src/trapi_predict_kit/rdf_utils.py: 58%

125 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-25 17:10 +0200

1import glob 

2import os 

3import uuid 

4from datetime import datetime 

5 

6from rdflib import RDF, Graph, Literal, Namespace, URIRef 

7from rdflib.namespace import DC, RDFS, XSD 

8from SPARQLWrapper import JSON, SPARQLWrapper 

9 

10from trapi_predict_kit.config import settings 

11 

# Normalize the configured data directory so later string concatenations
# can assume a trailing slash.
if not settings.OPENPREDICT_DATA_DIR.endswith("/"):
    settings.OPENPREDICT_DATA_DIR += "/"
# RDF_DATA_PATH = settings.OPENPREDICT_DATA_DIR + 'openpredict-metadata.ttl'


OPENPREDICT_GRAPH = "https://w3id.org/openpredict/graph"
OPENPREDICT_NAMESPACE = "https://w3id.org/openpredict/"
BIOLINK = Namespace("https://w3id.org/biolink/vocab/")

OWL = Namespace("http://www.w3.org/2002/07/owl#")
SKOS = Namespace("http://www.w3.org/2004/02/skos/core#")
SCHEMA = Namespace("http://schema.org/")
DCAT = Namespace("http://www.w3.org/ns/dcat#")
PROV = Namespace("http://www.w3.org/ns/prov#")
MLS = Namespace("http://www.w3.org/ns/mls#")
OPENPREDICT = Namespace("https://w3id.org/openpredict/")

# SPARQL endpoint credentials come from environment variables, with "dba"
# as the development fallback when a variable is unset or empty.
SPARQL_ENDPOINT_PASSWORD = os.getenv("SPARQL_PASSWORD") or "dba"
SPARQL_ENDPOINT_USERNAME = os.getenv("SPARQL_USERNAME") or "dba"
SPARQL_ENDPOINT_URL = os.getenv("SPARQL_ENDPOINT_URL")
SPARQL_ENDPOINT_UPDATE_URL = os.getenv("SPARQL_ENDPOINT_UPDATE_URL")

# NOTE(review): this deliberately discards any SPARQL_ENDPOINT_URL provided via
# the environment, forcing the "query an in-memory RDF file" dev path in
# query_sparql_endpoint(). Remove this line to query a real SPARQL endpoint.
SPARQL_ENDPOINT_URL = None

50 

51 

def get_models_graph(models_dir: str = "models"):
    """Load the RDF metadata of every model stored in a directory.

    :param models_dir: directory containing one ``.ttl`` metadata file per model
    :return: rdflib Graph with the metadata of all models merged together
    """
    g = Graph()
    for ttl_file in glob.glob(f"{models_dir}/*.ttl"):
        g.parse(ttl_file)
    return g

63 

64 

def query_sparql_endpoint(query, g, parameters=()):
    """Run a SELECT SPARQL query against the configured SPARQL endpoint if any,
    otherwise against the given in-memory rdflib graph.

    :param query: SPARQL SELECT query as a string
    :param g: rdflib Graph queried when no SPARQL endpoint is configured
    :param parameters: names of the SELECT variables to extract from each result
        row when querying the rdflib graph (ignored for the remote endpoint)
    :return: list of result bindings shaped like SPARQLWrapper JSON bindings:
        ``[{var: {"value": str}, ...}, ...]``
    """
    if SPARQL_ENDPOINT_URL:
        sparql = SPARQLWrapper(SPARQL_ENDPOINT_URL)
        sparql.setReturnFormat(JSON)
        sparql.setQuery(query)
        results = sparql.query().convert()
        return results["results"]["bindings"]

    # No endpoint configured: query the rdflib graph directly in dev mode.
    # rdflib result rows do not expose which SELECT variables were bound, so the
    # caller must list them in `parameters`; we then mimic the SPARQLWrapper
    # binding shape so downstream code works with either backend.
    qres = g.query(query)
    results = []
    for row in qres:
        results.append({p: {"value": str(row[p])} for p in parameters})
    return results

104 

105 

106# def init_triplestore(): 

107# """Only initialized the triplestore if no run for openpredict_baseline can be found. 

108# Init using the data/openpredict-metadata.ttl RDF file 

109# """ 

110# # check_baseline_run_query = """SELECT DISTINCT ?runType 

111# # WHERE { 

112# # <https://w3id.org/openpredict/run/openpredict_baseline> a ?runType 

113# # } LIMIT 10 

114# # """ 

115# # results = query_sparql_endpoint(check_baseline_run_query, parameters=['runType']) 

116# # if (len(results) < 1): 

117# g = Graph() 

118# g.parse('openpredict/data/openpredict-metadata.ttl', format="ttl") 

119# insert_graph_in_sparql_endpoint(g) 

120# print('Triplestore initialized at ' + SPARQL_ENDPOINT_UPDATE_URL) 

121 

122 

def get_run_id(run_id=None):
    """Return the given run identifier, or a freshly generated UUID when none is provided."""
    return run_id or str(uuid.uuid1())

128 

129 

def get_run_metadata(scores, model_features, hyper_params, run_id=None):
    """Generate RDF metadata (ML Schema / PROV) describing a classifier run, based on the OpenPredict model:
    https://github.com/fair-workflows/openpredict/blob/master/data/rdf/results_disjoint_lr.nq

    :param scores: dict mapping evaluation measure name (e.g. "accuracy") to its value
    :param model_features: list of feature labels used to train the model
    :param hyper_params: dict mapping hyperparameter name to its setting
    :param run_id: optional run identifier; a UUID is generated when not provided
    :return: rdflib Graph containing the run metadata
    """
    g = Graph()
    g.bind("mls", Namespace("http://www.w3.org/ns/mls#"))
    g.bind("prov", Namespace("http://www.w3.org/ns/prov#"))
    g.bind("dc", Namespace("http://purl.org/dc/elements/1.1/"))
    g.bind("openpredict", Namespace("https://w3id.org/openpredict/"))

    # Reuse the shared helper instead of duplicating the UUID generation logic
    run_id = get_run_id(run_id)

    run_uri = URIRef(OPENPREDICT_NAMESPACE + "run/" + run_id)
    run_prop_prefix = OPENPREDICT_NAMESPACE + run_id + "/"
    evaluation_uri = URIRef(OPENPREDICT_NAMESPACE + "run/" + run_id + "/ModelEvaluation")
    # The same for all runs:
    implementation_uri = URIRef(OPENPREDICT_NAMESPACE + "implementation/OpenPredict")

    # Add Run metadata
    g.add((run_uri, RDF.type, MLS["Run"]))
    g.add((run_uri, DC.identifier, Literal(run_id)))
    g.add((run_uri, PROV["generatedAtTime"], Literal(datetime.now(), datatype=XSD.dateTime)))
    g.add((run_uri, MLS["realizes"], OPENPREDICT["LogisticRegression"]))
    g.add((run_uri, MLS["executes"], implementation_uri))
    g.add((run_uri, MLS["hasOutput"], evaluation_uri))
    g.add((run_uri, MLS["hasOutput"], URIRef(run_prop_prefix + "Model")))

    # Add Model, should we point it to the generated model?
    g.add((URIRef(run_prop_prefix + "Model"), RDF.type, MLS["Model"]))

    # Add implementation metadata
    g.add((OPENPREDICT["LogisticRegression"], RDF.type, MLS["Algorithm"]))
    g.add((implementation_uri, RDF.type, MLS["Implementation"]))
    g.add((implementation_uri, MLS["implements"], OPENPREDICT["LogisticRegression"]))

    # Add HyperParameters and their settings to implementation
    for hyperparam, hyperparam_setting in hyper_params.items():
        hyperparam_uri = URIRef(OPENPREDICT_NAMESPACE + "HyperParameter/" + hyperparam)
        g.add((implementation_uri, MLS["hasHyperParameter"], hyperparam_uri))
        g.add((hyperparam_uri, RDF.type, MLS["HyperParameter"]))
        g.add((hyperparam_uri, RDFS.label, Literal(hyperparam)))

        hyperparam_setting_uri = URIRef(OPENPREDICT_NAMESPACE + "HyperParameterSetting/" + hyperparam)
        g.add((implementation_uri, MLS["hasHyperParameter"], hyperparam_setting_uri))
        g.add((hyperparam_setting_uri, RDF.type, MLS["HyperParameterSetting"]))
        g.add((hyperparam_setting_uri, MLS["specifiedBy"], hyperparam_uri))
        g.add((hyperparam_setting_uri, MLS["hasValue"], Literal(hyperparam_setting)))
        g.add((run_uri, MLS["hasInput"], hyperparam_setting_uri))

    # TODO: improve how we retrieve features
    for feature in model_features:
        # Feature labels may contain spaces/parentheses: sanitize them for the URI
        feature_uri = URIRef(
            OPENPREDICT_NAMESPACE + "feature/" + feature.replace(" ", "_").replace("(", "").replace(")", "")
        )
        g.add((feature_uri, RDF.type, MLS["Feature"]))
        g.add((feature_uri, DC.identifier, Literal(feature)))
        g.add((run_uri, MLS["hasInput"], feature_uri))

    # TODO: those 2 triples are for the PLEX ontology
    g.add((evaluation_uri, RDF.type, PROV["Entity"]))
    g.add((evaluation_uri, PROV["wasGeneratedBy"], run_uri))

    # Add scores as EvaluationMeasures
    g.add((evaluation_uri, RDF.type, MLS["ModelEvaluation"]))
    for key in scores:
        key_uri = URIRef(run_prop_prefix + key)
        g.add((evaluation_uri, MLS["specifiedBy"], key_uri))
        g.add((key_uri, RDF.type, MLS["EvaluationMeasure"]))
        g.add((key_uri, RDFS.label, Literal(key)))
        g.add((key_uri, MLS["hasValue"], Literal(scores[key], datatype=XSD.double)))
        # TODO: Example 1 of the ML Schema docs puts hasValue directly in the
        # ModelEvaluation, but that prevents providing multiple values for one evaluation
        # http://ml-schema.github.io/documentation/ML%20Schema.html#overview

    return g

212 

213 

def retrieve_features(g, type="Both", run_id=None):
    """Get features in the ML model

    :param g: rdflib Graph holding the models metadata
    :param type: type of the feature (Both, Drug, Disease) — currently not used to filter
    :param run_id: optional run identifier; when given, only that run's features are returned
    :return: dict mapping feature URI to ``{"id": feature identifier}``
    """
    if run_id:
        # NOTE(review): run_id is spliced directly into the query string. Fine for
        # the UUIDs we generate ourselves, but do not pass untrusted input here
        # (SPARQL injection risk).
        query = (
            """PREFIX dct: <http://purl.org/dc/terms/>
            PREFIX mls: <http://www.w3.org/ns/mls#>
            PREFIX prov: <http://www.w3.org/ns/prov#>
            PREFIX openpredict: <https://w3id.org/openpredict/>
            PREFIX dc: <http://purl.org/dc/elements/1.1/>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
            SELECT DISTINCT ?feature ?featureId
            WHERE {
                ?run a mls:Run ;
                    dc:identifier \"""" + run_id + """\" ;
                    mls:hasInput ?feature .
                ?feature dc:identifier ?featureId .
            }"""
        )
        id_var = "featureId"
    else:
        # Doubled braces were a leftover from a removed str.format call; plain
        # braces are the intended SPARQL group pattern.
        query = """SELECT DISTINCT ?id ?feature
        WHERE {
            ?feature a <http://www.w3.org/ns/mls#Feature> ;
                <http://purl.org/dc/elements/1.1/identifier> ?id .
        }
        """
        id_var = "id"

    results = query_sparql_endpoint(query, g, parameters=["feature", id_var])
    # Single result-shaping path for both branches (only the identifier variable differs)
    return {result["feature"]["value"]: {"id": result[id_var]["value"]} for result in results}

271 

272 

def retrieve_models(g):
    """Get models with their scores and features

    :param g: rdflib Graph holding the run metadata (see get_run_metadata)
    :return: JSON with models and features, keyed by run URI
    """
    # The query returns one row per (run, feature) pair: the evaluation measures
    # are repeated on every row, and the features are aggregated in the loop below.
    sparql_get_scores = """PREFIX dct: <http://purl.org/dc/terms/>
    PREFIX mls: <http://www.w3.org/ns/mls#>
    PREFIX prov: <http://www.w3.org/ns/prov#>
    PREFIX openpredict: <https://w3id.org/openpredict/>
    PREFIX dc: <http://purl.org/dc/elements/1.1/>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
    SELECT DISTINCT ?run ?runId ?generatedAtTime ?featureId ?accuracy ?average_precision ?f1 ?precision ?recall ?roc_auc
    WHERE {
        ?run a mls:Run ;
            dc:identifier ?runId ;
            prov:generatedAtTime ?generatedAtTime ;
            mls:hasInput ?features ;
            mls:hasOutput ?evaluation .
        ?evaluation a mls:ModelEvaluation .
        ?features dc:identifier ?featureId .

        ?evaluation mls:specifiedBy [a mls:EvaluationMeasure ;
            rdfs:label "accuracy" ;
            mls:hasValue ?accuracy ] .
        ?evaluation mls:specifiedBy [ a mls:EvaluationMeasure ;
            rdfs:label "precision" ;
            mls:hasValue ?precision ] .
        ?evaluation mls:specifiedBy [ a mls:EvaluationMeasure ;
            rdfs:label "f1" ;
            mls:hasValue ?f1 ] .
        ?evaluation mls:specifiedBy [ a mls:EvaluationMeasure ;
            rdfs:label "recall" ;
            mls:hasValue ?recall ] .
        ?evaluation mls:specifiedBy [ a mls:EvaluationMeasure ;
            rdfs:label "roc_auc" ;
            mls:hasValue ?roc_auc ] .
        ?evaluation mls:specifiedBy [ a mls:EvaluationMeasure ;
            rdfs:label "average_precision" ;
            mls:hasValue ?average_precision ] .
    }
    """

    results = query_sparql_endpoint(
        sparql_get_scores,
        g,
        parameters=[
            "run",
            "runId",
            "generatedAtTime",
            "featureId",
            "accuracy",
            "average_precision",
            "f1",
            "precision",
            "recall",
            "roc_auc",
        ],
    )
    # Fold the (run, feature) rows into one entry per run: the first row for a
    # run records its metadata and scores, subsequent rows only add features.
    models_json = {}
    for result in results:
        if result["run"]["value"] in models_json:
            models_json[result["run"]["value"]]["features"].append(result["featureId"]["value"])
        else:
            models_json[result["run"]["value"]] = {
                "id": result["runId"]["value"],
                "generatedAtTime": result["generatedAtTime"]["value"],
                "features": [result["featureId"]["value"]],
                "accuracy": result["accuracy"]["value"],
                "average_precision": result["average_precision"]["value"],
                "f1": result["f1"]["value"],
                "precision": result["precision"]["value"],
                "recall": result["recall"]["value"],
                "roc_auc": result["roc_auc"]["value"],
            }

    # We could create an object with feature description instead of passing just the ID
    # features_json[result['id']['value']] = {
    #     "description": result['description']['value'],
    #     "type": result['embeddingType']['value']
    # }
    return models_json

355 

356 

357# TODO: Not really used, remove? 

358# def get_feature_metadata(id, description, type): 

359# """Generate RDF metadata for a feature 

360 

361# :param id: if used to identify the feature 

362# :param description: feature description 

363# :param type: feature type 

364# :return: rdflib graph after loading the feature 

365# """ 

366# g = Graph() 

367# feature_uri = URIRef(OPENPREDICT_NAMESPACE + 'feature/' + id) 

368# g.add((feature_uri, RDF.type, MLS['Feature'])) 

369# g.add((feature_uri, DC.identifier, Literal(id))) 

370# g.add((feature_uri, DC.description, Literal(description))) 

371# g.add((feature_uri, OPENPREDICT['embedding_type'], Literal(type))) 

372# return g 

373 

374 

375# def insert_graph_in_sparql_endpoint(g): 

376# """Insert rdflib graph in a Update SPARQL endpoint using SPARQLWrapper 

377 

378# :param g: rdflib graph to insert 

379# :return: SPARQL update query result 

380# """ 

381# if SPARQL_ENDPOINT_URL: 

382# sparql = SPARQLWrapper(SPARQL_ENDPOINT_UPDATE_URL) 

383# sparql.setMethod(POST) 

384# # sparql.setHTTPAuth(BASIC) 

385# sparql.setCredentials(SPARQL_ENDPOINT_USERNAME, 

386# SPARQL_ENDPOINT_PASSWORD) 

387# query = """INSERT DATA {{ GRAPH <{graph}> 

388# {{ 

389# {ntriples} 

390# }} 

391# }} 

392# """.format(ntriples=g.serialize(format='nt').decode('utf-8'), graph=OPENPREDICT_GRAPH) 

393 

394# sparql.setQuery(query) 

395# return sparql.query() 

396# else: 

397# # If no SPARQL endpoint provided we store to the RDF file in data/openpredict-metadata.ttl (working) 

398# graph_from_file = Graph() 

399# graph_from_file.parse(RDF_DATA_PATH, format="ttl") 

400# # graph_from_file.parse(g.serialize(format='turtle').decode('utf-8'), format="ttl") 

401# graph_from_file = graph_from_file + g 

402# graph_from_file.serialize(RDF_DATA_PATH, format='turtle')