Coverage for /home/deng/Projects/ete4/hackathon/ete4/ete4/gtdb_taxonomy/gtdbquery.py: 51%

505 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-07 10:27 +0200

1#!/usr/bin/env python3 

2 

3import sys 

4import os 

5 

6import pickle 

7from collections import defaultdict, Counter 

8 

9from hashlib import md5 

10 

11import sqlite3 

12import math 

13import tarfile 

14import warnings 

15import requests 

16 

17from ete4 import ETE_DATA_HOME, update_ete_data 

18 

19 

__all__ = ["GTDBTaxa", "is_taxadb_up_to_date"]

# Schema version stored in the ``stats`` table by upload_data() and checked
# by is_taxadb_up_to_date().
DB_VERSION = 2

# Default on-disk locations of the SQLite database and of the taxonomy dump.
DEFAULT_GTDBTAXADB = f'{ETE_DATA_HOME}/gtdbtaxa.sqlite'
DEFAULT_GTDBTAXADUMP = f'{ETE_DATA_HOME}/gtdbdump.tar.gz'

25 

def is_taxadb_up_to_date(dbfile=DEFAULT_GTDBTAXADB):
    """Check if a valid and up-to-date gtdbtaxa.sqlite database exists.

    The database is up to date when its ``stats`` table holds DB_VERSION.
    If dbfile is not specified, DEFAULT_GTDBTAXADB is checked.

    :param dbfile: Path of the SQLite database file.
    :return: True if the file exists, is readable by SQLite and has the
        expected schema version; False otherwise.
    """
    # sqlite3.connect() would silently create an empty file at a missing path.
    if not os.path.exists(dbfile):
        return False

    db = sqlite3.connect(dbfile)
    try:
        version = db.execute('SELECT version FROM stats;').fetchone()[0]
    except (sqlite3.DatabaseError, ValueError, IndexError, TypeError):
        # DatabaseError covers a missing stats table (OperationalError) and
        # files that are not SQLite databases; TypeError covers an empty table.
        version = None
    finally:
        db.close()

    return version == DB_VERSION

42 

43 

class GTDBTaxa:
    """
    Local transparent connector to the GTDB taxonomy database.

    The SQLite database identifies taxa by internal numeric taxids. GTDB
    names such as 'c__Thorarchaeia' or 'RS_GCF_001477695.1' are the public
    identifiers accepted and returned by the public methods.
    """

    def __init__(self, dbfile=None, taxdump_file=None, memory=False):
        """Open a GTDB taxonomy database, building or upgrading it if needed.

        :param dbfile: Path of the SQLite database (DEFAULT_GTDBTAXADB if
            not given).
        :param taxdump_file: Optional taxonomy dump (.tar.gz) used to
            rebuild the database before opening it.
        :param memory: If True, copy the database into an in-memory SQLite
            database and query that copy.
        :raises ValueError: If the database file does not exist even after
            trying to build it.
        """
        self.dbfile = dbfile or DEFAULT_GTDBTAXADB

        if taxdump_file:
            self.update_taxonomy_database(taxdump_file)

        # Build the database on first use. (Comparing the dbfile argument
        # against DEFAULT_GTDBTAXADB here skipped this step when the default
        # path was passed explicitly, which then failed just below.)
        if not os.path.exists(self.dbfile):
            print('GTDB database not present yet (first time used?)', file=sys.stderr)
            urlbase = ('https://github.com/etetoolkit/ete-data/raw/main'
                       '/gtdb_taxonomy/gtdblatest')

            update_ete_data(DEFAULT_GTDBTAXADUMP, f'{urlbase}/gtdb_latest_dump.tar.gz')

            self.update_taxonomy_database(taxdump_file=DEFAULT_GTDBTAXADUMP)

        if not os.path.exists(self.dbfile):
            raise ValueError("Cannot open taxonomy database: %s" % self.dbfile)

        if not is_taxadb_up_to_date(self.dbfile):
            print('GTDB database format is outdated. Upgrading', file=sys.stderr)
            self.update_taxonomy_database(taxdump_file)

        # Connect only once the file has been built/upgraded.
        self.db = None
        self._connect()

        if memory:
            filedb = self.db
            self.db = sqlite3.connect(':memory:')
            filedb.backup(self.db)
            filedb.close()  # the on-disk connection is no longer used

    def update_taxonomy_database(self, taxdump_file=None):
        """Update the GTDB taxonomy database.

        It updates it by downloading and parsing the latest
        gtdbtaxdump.tar.gz file.

        :param taxdump_file: Alternative location of gtdbtaxdump.tar.gz.
        """
        update_db(self.dbfile, targz_file=taxdump_file)

    def _connect(self):
        """(Re)open the SQLite connection to self.dbfile."""
        self.db = sqlite3.connect(self.dbfile)

    @staticmethod
    def _sql_literals(values):
        """Return values as a comma-separated list of SQL string literals.

        Each value is rendered with str() inside single quotes, with embedded
        single quotes doubled, so it is always parsed as a string literal.
        Double-quoted tokens are identifiers in SQLite: a value such as
        "spname" would be compared against that column, and a value
        containing '"' would break the statement. Numeric strings still
        match the INT taxid columns through SQLite's type affinity rules.
        """
        return ','.join("'%s'" % str(value).replace("'", "''") for value in values)

    def _translate_merged(self, all_taxids):
        """Translate obsolete taxids through the merged table.

        :param all_taxids: Iterable of values convertible with int().
        :return: (set of current int taxids, {old_taxid: new_taxid}).
        """
        conv_all_taxids = set(map(int, all_taxids))
        # The values are ints at this point, so they can be embedded as-is.
        cmd = ('select taxid_old, taxid_new FROM merged WHERE taxid_old IN (%s)'
               % ','.join(map(str, conv_all_taxids)))

        conversion = {}
        for old, new in self.db.execute(cmd).fetchall():
            conv_all_taxids.discard(int(old))
            conv_all_taxids.add(int(new))
            conversion[int(old)] = int(new)
        return conv_all_taxids, conversion

    # def get_fuzzy_name_translation(self, name, sim=0.9):
    #     '''
    #     Given an inexact species name, returns the best match in the NCBI database of taxa names.
    #     :argument 0.9 sim: Min word similarity to report a match (from 0 to 1).
    #     :return: taxid, species-name-match, match-score
    #     '''

    #     import sqlite3.dbapi2 as dbapi2
    #     _db = dbapi2.connect(self.dbfile)
    #     _db.enable_load_extension(True)
    #     module_path = os.path.split(os.path.realpath(__file__))[0]
    #     _db.execute("select load_extension('%s')" % os.path.join(module_path,
    #                 "SQLite-Levenshtein/levenshtein.sqlext"))

    #     print("Trying fuzzy search for %s" % name)
    #     maxdiffs = math.ceil(len(name) * (1-sim))
    #     cmd = 'SELECT taxid, spname, LEVENSHTEIN(spname, "%s") AS sim FROM species WHERE sim<=%s ORDER BY sim LIMIT 1;' % (name, maxdiffs)
    #     taxid, spname, score = None, None, len(name)
    #     result = _db.execute(cmd)
    #     try:
    #         taxid, spname, score = result.fetchone()
    #     except TypeError:
    #         cmd = 'SELECT taxid, spname, LEVENSHTEIN(spname, "%s") AS sim FROM synonym WHERE sim<=%s ORDER BY sim LIMIT 1;' % (name, maxdiffs)
    #         result = _db.execute(cmd)
    #         try:
    #             taxid, spname, score = result.fetchone()
    #         except:
    #             pass
    #         else:
    #             taxid = int(taxid)
    #     else:
    #         taxid = int(taxid)

    #     norm_score = 1 - (float(score)/len(name))
    #     if taxid:
    #         print("FOUND! %s taxid:%s score:%s (%s)" %(spname, taxid, score, norm_score))

    #     return taxid, spname, norm_score

    def _get_id2rank(self, internal_taxids):
        """Given a list of numeric ids (each one representing a taxa in GTDB), return a dictionary with their corresponding ranks.

        Examples:
        > gtdb._get_id2rank([2174, 205487, 610])
        {2174: 'family', 205487: 'order', 610: 'phylum'}

        Note: Numeric taxids are not recognized by the official GTDB taxonomy database, only for internal usage.
        """
        ids = self._sql_literals(set(internal_taxids) - {None, ''})
        result = self.db.execute('SELECT taxid, rank FROM species WHERE taxid IN (%s)' % ids)
        return {tax: rank for tax, rank in result.fetchall()}

    def get_rank(self, taxids):
        """Given a list of GTDB string taxids, return a dictionary with their corresponding ranks.

        Examples:

        > gtdb.get_rank(['c__Thorarchaeia', 'RS_GCF_001477695.1'])
        {'c__Thorarchaeia': 'class', 'RS_GCF_001477695.1': 'subspecies'}

        Keys are the names as stored in the database; unknown names are omitted.
        """
        name2ids = self._get_name_translator(taxids)
        internal_ids = {tid for ids in name2ids.values() for tid in ids} - {None, ''}
        result = self.db.execute('SELECT taxid, rank FROM species WHERE taxid IN (%s)'
                                 % self._sql_literals(internal_ids))
        rows = result.fetchall()

        # Translate all ids back to names in one query (not one per row).
        id2name = self._get_taxid_translator([tax for tax, _ in rows])
        return {id2name[tax]: rank for tax, rank in rows}

    def _get_lineage_translator(self, taxids):
        """Given a list of internal taxids, return {taxid: lineage}.

        Each lineage is a hierarchically sorted list of int taxids, from the
        root down to the taxid itself. Taxids not in the database are omitted.
        """
        all_ids = set(taxids) - {None, ''}
        result = self.db.execute('SELECT taxid, track FROM species WHERE taxid IN (%s);'
                                 % self._sql_literals(all_ids))
        id2lineages = {}
        for tax, track in result.fetchall():
            # Tracks are stored leaf-first ("self,parent,...,root").
            id2lineages[tax] = list(map(int, reversed(track.split(","))))
        return id2lineages

    def get_name_lineage(self, taxnames):
        """Given a list of GTDB names, return their lineages as names.

        :return: A list with one {name: [root_name, ..., name]} dict per
            name found in the database.
        """
        name_lineages = []
        name2taxid = self._get_name_translator(taxnames)
        for key, value in name2taxid.items():
            lineage = self._get_lineage(value[0])
            names = self._get_taxid_translator(lineage)
            name_lineages.append({key: [names[taxid] for taxid in lineage]})

        return name_lineages

    def _get_lineage(self, taxid):
        """Given a valid internal taxid, return its lineage track as a
        hierarchically sorted list of parent taxids (root first).

        :return: None for a falsy taxid.
        :raises ValueError: If the taxid (or its merged replacement) is unknown.
        """
        if not taxid:
            return None
        taxid = int(taxid)
        query = 'SELECT track FROM species WHERE taxid=?'
        raw_track = self.db.execute(query, (taxid,)).fetchone()
        if not raw_track:
            # Perhaps an obsolete taxid: look it up in the merged table.
            _, merged_conversion = self._translate_merged([taxid])
            if taxid in merged_conversion:
                raw_track = self.db.execute(query, (merged_conversion[taxid],)).fetchone()
            if not raw_track:
                raise ValueError("%s taxid not found" % taxid)
            warnings.warn("taxid %s was translated into %s" % (taxid, merged_conversion[taxid]))

        track = list(map(int, raw_track[0].split(",")))
        return list(reversed(track))

    def get_common_names(self, taxids):
        """Given a list of internal taxids, return {taxid: common_name}.

        Taxids without a (non-empty) common name are omitted.
        """
        cmd = "select taxid, common FROM species WHERE taxid IN (%s);" % self._sql_literals(taxids)
        result = self.db.execute(cmd)
        id2name = {}
        for tax, common_name in result.fetchall():
            if common_name:
                id2name[tax] = common_name
        return id2name

    def _get_taxid_translator(self, taxids, try_synonyms=True):
        """Given a list of internal taxids, return {taxid: scientific_name}.

        None and '' entries are ignored; taxids not in the database are
        omitted. try_synonyms is accepted for API compatibility but is
        currently unused (no merged-table fallback is performed here).
        """
        # Filter *before* int(): converting first made None/'' raise
        # instead of being skipped.
        all_ids = {int(tid) for tid in taxids if tid is not None and tid != ''}
        cmd = "select taxid, spname FROM species WHERE taxid IN (%s);" % self._sql_literals(all_ids)
        return {tax: spname for tax, spname in self.db.execute(cmd).fetchall()}

    def _get_name_translator(self, names):
        """
        Given a list of taxid scientific names, returns a dictionary translating them into their corresponding taxids.

        Matching is case-insensitive (the spname columns use COLLATE NOCASE),
        but keys keep the caller's spelling. Each value is a list of taxids.
        Names not found in the species table are looked up in the synonym
        table; names found in neither are omitted.
        """
        name2id = {}
        # lower-cased name -> caller's original spelling (last one wins)
        name2origname = {n.lower(): n for n in names}

        query = self._sql_literals(name2origname)
        result = self.db.execute('select spname, taxid from species where spname IN (%s)' % query)
        for sp, taxid in result.fetchall():
            oname = name2origname[sp.lower()]
            name2id.setdefault(oname, []).append(taxid)

        missing = set(name2origname) - {n.lower() for n in name2id}
        if missing:
            query = self._sql_literals(missing)
            result = self.db.execute('select spname, taxid from synonym where spname IN (%s)' % query)
            for sp, taxid in result.fetchall():
                oname = name2origname[sp.lower()]
                name2id.setdefault(oname, []).append(taxid)
        return name2id

    def _translate_to_names(self, taxids):
        """
        Given a list of taxid numbers, returns another list with their corresponding scientific names.

        Taxids without a name are returned unchanged.
        """
        id2name = self._get_taxid_translator(taxids)
        return [id2name.get(sp, sp) for sp in taxids]

    def get_descendant_taxa(self, parent, intermediate_nodes=False, rank_limit=None,
                            collapse_subspecies=False, return_tree=False):
        """Return the GTDB names of all descendants of parent.

        :param parent: A GTDB name or an internal numeric taxid.
        :param intermediate_nodes: If True, internal nodes are returned too
            (otherwise only leaves).
        :param rank_limit, collapse_subspecies: Passed to get_topology().
        :param return_tree: If True, return the get_topology() tree instead.
        :return: A list of names. A leaf taxon is its own only descendant.
        :raises ValueError: If parent is not found.
        """
        try:
            taxid = int(parent)
        except ValueError:
            try:
                taxid = self._get_name_translator([parent])[parent][0]
            except KeyError:
                raise ValueError('%s not found!' % parent) from None

        # checks if taxid is a deprecated one, and converts into the right one.
        _, conversion = self._translate_merged([taxid])  # try to find taxid in synonyms table
        if conversion:
            taxid = conversion[taxid]

        # NOTE: this pickle is written locally by update_db(); unpickling
        # executes code, so dbfile must never point to an untrusted location.
        with open(self.dbfile + ".traverse.pkl", "rb") as cached_traverse:
            prepostorder = pickle.load(cached_traverse)

        # In the pre/post-order traversal internal nodes appear twice and
        # leaves once; everything between the two visits of taxid is below it.
        descendants = {}
        found = 0
        for tid in prepostorder:
            if tid == taxid:
                found += 1
            elif found == 1:
                descendants[tid] = descendants.get(tid, 0) + 1
            elif found == 2:
                break

        if not found:
            raise ValueError("taxid not found:%s" % taxid)
        elif found == 1:
            # Visited only once: a leaf. Return its name, like other paths do.
            return self._translate_to_names([taxid])

        if rank_limit or collapse_subspecies or return_tree:
            descendants_spnames = self._get_taxid_translator(list(descendants))
            tree = self.get_topology(list(descendants_spnames.values()),
                                     intermediate_nodes=intermediate_nodes,
                                     collapse_subspecies=collapse_subspecies,
                                     rank_limit=rank_limit)
            if return_tree:
                return tree
            elif intermediate_nodes:
                return [n.name for n in tree.descendants()]
            else:
                return [n.name for n in tree]

        elif intermediate_nodes:
            return self._translate_to_names(list(descendants))
        else:
            return self._translate_to_names([tid for tid, count in descendants.items() if count == 1])

    def get_topology(self, taxnames, intermediate_nodes=False, rank_limit=None,
                     collapse_subspecies=False, annotate=True):
        """Return minimal pruned GTDB taxonomy tree containing all given taxa.

        With a single name, the whole taxonomy subtree under it is built;
        with several, only their lineages are merged into one tree.

        :param taxnames: GTDB names; names not in the database are ignored.
        :param intermediate_nodes: If True, single child nodes
            representing the complete lineage of leaf nodes are kept.
            Otherwise, the tree is pruned to contain the first common
            ancestor of each group.
        :param rank_limit: If valid rank name is provided, the
            tree is pruned at that given level. For instance, use
            rank_limit="species" to get rid of sub-species leaf nodes.
            (Only applied when several names are given.)
        :param collapse_subspecies: If True, any item under the
            species rank will be collapsed into the species upper
            node.
        :param annotate: If True, annotate the tree with annotate_tree().
        """
        from .. import PhyloTree

        # e.g. {'f__Korarchaeaceae': [2174], 'o__Peptococcales': [205487]}
        tax2id = self._get_name_translator(taxnames)
        taxids = [ids[0] for ids in tax2id.values()]

        if len(taxids) == 1:
            root, name2tax = self._subtree_from_traverse(int(taxids[0]), PhyloTree)
        else:
            taxids = set(map(int, taxids))
            root, name2tax = self._topology_from_lineages(taxids, rank_limit, PhyloTree)

        if not intermediate_nodes:
            # Remove one-child nodes that were not requested. The traversal is
            # materialized first because nodes are deleted along the way.
            for n in list(root.descendants()):
                if len(n.children) == 1 and name2tax.get(n.name) not in taxids:
                    n.delete(prevent_nondicotomic=False)

        if len(root.children) == 1:
            tree = root.children[0].detach()
        else:
            tree = root

        if collapse_subspecies:
            to_detach = []
            for node in tree.traverse():
                if node.props.get('rank') == 'species':
                    to_detach.extend(node.children)
            for n in to_detach:
                n.detach()

        if annotate:
            self.annotate_tree(tree, ignore_unclassified=False)

        return tree

    def _subtree_from_traverse(self, root_taxid, tree_class):
        """Build the complete taxonomy subtree under root_taxid.

        Uses the cached pre/post-order traversal written by update_db().
        Nodes are named by scientific name and get a 'rank' property.

        :return: (root node, {spname: taxid} for every node of the subtree).
        """
        # NOTE: locally generated cache; never unpickle untrusted files.
        with open(self.dbfile + ".traverse.pkl", "rb") as cached_traverse:
            prepostorder = pickle.load(cached_traverse)

        start = prepostorder.index(root_taxid)
        try:
            end = prepostorder.index(root_taxid, start + 1)
            subtree = prepostorder[start:end + 1]
        except ValueError:
            # If root taxid is not found in postorder, must be a tip node
            subtree = [root_taxid]

        # Leaves are the only taxids visited once in the traversal.
        leaves = {tid for tid, count in Counter(subtree).items() if count == 1}
        tax2name = self._get_taxid_translator(subtree)
        tax2rank = self._get_id2rank(subtree)
        name2tax = {spname: tid for tid, spname in tax2name.items()}

        nodes = {}
        current_parent = None  # no hidden placeholder above the real root
        for tid in subtree:
            if tid in nodes:
                # Second visit of an internal node: its subtree is complete.
                current_parent = nodes[tid].up
                continue
            node = tree_class({'name': tax2name.get(tid, '')})
            node.add_prop('rank', str(tax2rank.get(tid, 'no rank')))
            nodes[tid] = node
            if current_parent is not None:
                current_parent.add_child(node)
            if tid not in leaves:
                current_parent = node
        return nodes[root_taxid], name2tax

    def _topology_from_lineages(self, taxids, rank_limit, tree_class):
        """Merge the lineages of several internal taxids into one tree.

        Lineages stop at the first node whose rank equals rank_limit.

        :return: (root node, i.e. taxid 1, and {spname: taxid}).
        """
        id2lineage = self._get_lineage_translator(taxids)
        all_taxids = {tax for lineage in id2lineage.values() for tax in lineage}
        id2rank = self._get_id2rank(all_taxids)
        tax2name = self._get_taxid_translator(set(taxids) | all_taxids)
        name2tax = {spname: tid for tid, spname in tax2name.items()}

        elem2node = {}
        sp2track = {}
        for sp in taxids:
            track = []
            for elem in id2lineage[sp]:
                node = elem2node.get(elem)
                if node is None:
                    node = elem2node[elem] = tree_class()
                    node.name = str(tax2name[elem])
                    node.taxid = str(tax2name[elem])
                    node.add_prop("rank", str(id2rank.get(int(elem), "no rank")))
                track.append(node)
            sp2track[sp] = track

        # generate parent child relationships
        for track in sp2track.values():
            parent = None
            for node in track:
                if parent is not None and node not in parent.children:
                    parent.add_child(node)
                if rank_limit and node.props.get('rank') == rank_limit:
                    break
                parent = node
        return elem2node[1], name2tax

    def annotate_tree(self, t, taxid_attr='name', tax2name=None,
                      tax2track=None, tax2rank=None, ignore_unclassified=False):
        """Annotate a tree containing GTDB names as leaf names.

        It annotates by adding the properties 'taxid', 'sci_name',
        'common_name', 'lineage', 'named_lineage' and 'rank'.

        :param t: Tree to annotate.
        :param taxid_attr: Node attribute (property) containing the
            GTDB name associated to each leaf (or, when it is "taxid", the
            taxid property of each leaf).
        :param tax2name, tax2track, tax2rank: Pre-calculated
            dictionaries with translations from taxid number to names,
            track lineages and ranks.
        :param ignore_unclassified: If True, leaves without a lineage are
            ignored when computing the common lineage of internal nodes.
        :return: The (possibly recomputed) tax2name, tax2track, tax2rank.
        """
        # GTDB leaf name -> internal taxid, resolved with a single query.
        name2taxid = {}
        if taxid_attr == "taxid":
            taxids = {n.props[taxid_attr] for n in t.leaves() if taxid_attr in n.props}
        else:
            leaf_names = [getattr(n, taxid_attr, n.props.get(taxid_attr)) for n in t.leaves()]
            # Only strings can be looked up (other values were skipped before too).
            leaf_names = [name for name in leaf_names if isinstance(name, str)]
            name2taxid = {name: ids[0] for name, ids in
                          self._get_name_translator(leaf_names).items()}
            taxids = set(name2taxid.values())

        taxids, merged_conversion = self._translate_merged(taxids)

        if not tax2name or taxids - set(map(int, tax2name)):
            tax2name = self._get_taxid_translator(taxids)
        if not tax2track or taxids - set(map(int, tax2track)):
            tax2track = self._get_lineage_translator(taxids)

        all_taxid_codes = {tax for lineage in tax2track.values() for tax in lineage}
        tax2name.update(self._get_taxid_translator(all_taxid_codes - set(tax2name)))

        tax2common_name = self.get_common_names(tax2name.keys())

        if not tax2rank:
            tax2rank = self._get_id2rank(list(tax2name.keys()))

        n2leaves = t.get_cached_content()

        for node in t.traverse('postorder'):
            if node.is_leaf:
                node_taxid = getattr(node, taxid_attr, node.props.get(taxid_attr))
            else:
                node_taxid = None
            node.add_prop('taxid', node_taxid)
            if node_taxid:
                if node_taxid in name2taxid:
                    tmp_taxid = name2taxid[node_taxid]
                else:
                    tmp_taxid = self._get_name_translator([node_taxid]).get(node_taxid, [None])[0]
                if node_taxid in merged_conversion:
                    node_taxid = merged_conversion[node_taxid]

                rank = tax2rank.get(tmp_taxid, 'Unknown')
                if rank != 'subspecies':
                    sci_name = tax2name.get(tmp_taxid, '')
                else:
                    # For subspecies, the gtdb name (like 'RS_GCF_0062.1') is
                    # not informative; use the parent species name. Tracks
                    # are root-first, so track[-2] is the parent.
                    track = tax2track[tmp_taxid]
                    sci_name = tax2name.get(track[-2], '')

                node.add_props(sci_name=sci_name,
                               # common names are keyed by internal taxid
                               common_name=tax2common_name.get(tmp_taxid, ''),
                               lineage=tax2track.get(tmp_taxid, []),
                               rank=rank,
                               named_lineage=[tax2name.get(tax, str(tax)) for tax in tax2track.get(tmp_taxid, [])])
            elif node.is_leaf:
                node.add_props(sci_name=getattr(node, taxid_attr, node.props.get(taxid_attr, 'NA')),
                               common_name='',
                               lineage=[],
                               rank='Unknown',
                               named_lineage=[])
            else:
                if ignore_unclassified:
                    vectors = [lf.props.get('lineage') for lf in n2leaves[node] if lf.props.get('lineage')]
                else:
                    vectors = [lf.props.get('lineage') for lf in n2leaves[node]]
                lineage = self._common_lineage(vectors)

                rank = tax2rank.get(lineage[-1], 'Unknown')

                if lineage[-1]:
                    if rank != 'subspecies':
                        ancestor = self._get_taxid_translator([lineage[-1]])[lineage[-1]]
                    else:
                        ancestor = self._get_taxid_translator([lineage[-2]])[lineage[-2]]
                        lineage = lineage[:-1]  # remove subspecies from lineage
                        rank = tax2rank.get(lineage[-1], 'Unknown')  # update rank
                else:
                    ancestor = None

                node.add_props(sci_name=tax2name.get(ancestor, str(ancestor)),
                               common_name=tax2common_name.get(lineage[-1], ''),
                               taxid=ancestor,
                               lineage=lineage,
                               rank=rank,
                               named_lineage=[tax2name.get(tax, str(tax)) for tax in lineage])

        return tax2name, tax2track, tax2rank

    def _common_lineage(self, vectors):
        """Return the taxids shared by all lineage vectors, in lineage order.

        :return: The common taxids sorted by their earliest position, or
            [""] when there are none (or no vectors).
        """
        occurrence = defaultdict(int)
        pos = defaultdict(set)
        for v in vectors:
            for i, taxid in enumerate(v):
                occurrence[taxid] += 1
                pos[taxid].add(i)

        common = [taxid for taxid, ocu in occurrence.items() if ocu == len(vectors)]
        if not common:
            return [""]
        return sorted(common, key=lambda x: min(pos[x]))

    def get_broken_branches(self, t, taxa_lineages, n2content=None):
        """Returns a list of GTDB lineage names that are not monophyletic in the
        provided tree, as well as the list of affected branches and their size.

        Leaves are expected to carry 'sci_name' and 'taxid' properties (as
        set by annotate_tree); leaves without sci_name or with sci_name
        "unknown" are treated as unknown.
        CURRENTLY EXPERIMENTAL
        """
        if not n2content:
            n2content = t.get_cached_content()

        tax2node = defaultdict(set)

        unknown = set()
        for leaf in t.leaves():
            sci_name = leaf.props.get('sci_name') or 'unknown'
            if str(sci_name).lower() != "unknown":
                for tax in taxa_lineages[leaf.props.get('taxid')]:
                    tax2node[tax].add(leaf)
            else:
                unknown.add(leaf)

        broken_branches = defaultdict(set)
        broken_clades = set()
        for tax, leaves in tax2node.items():
            if len(leaves) > 1:
                # NOTE(review): confirm this method name against the installed
                # Tree API (ete4 may expose it as `common_ancestor`).
                common = t.get_common_ancestor(leaves)
            else:
                common = list(leaves)[0]
            if (leaves ^ set(n2content[common])) - unknown:
                broken_branches[common].add(tax)
                broken_clades.add(tax)

        broken_clade_sizes = [len(tax2node[tax]) for tax in broken_clades]
        return broken_branches, broken_clades, broken_clade_sizes

    # TODO: See why this code is commented out and comment it properly or remove it.
    #
    # def annotate_tree_with_taxa(self, t, name2taxa_file, tax2name=None, tax2track=None, attr_name="name"):
    #     if name2taxa_file:
    #         names2taxid = dict([map(strip, line.split("\t"))
    #                             for line in open(name2taxa_file)])
    #     else:
    #         names2taxid = dict([(n.name, getattr(n, attr_name)) for n in t.iter_leaves()])

    #     not_found = 0
    #     for n in t.iter_leaves():
    #         n.add_features(taxid=names2taxid.get(n.name, 0))
    #         n.add_features(species=n.taxid)
    #         if n.taxid == 0:
    #             not_found += 1
    #     if not_found:
    #         print >>sys.stderr, "WARNING: %s nodes where not found within NCBI taxonomy!!" %not_found

    #     return self.annotate_tree(t, tax2name, tax2track, attr_name="taxid")

691 

692 

def load_gtdb_tree_from_dump(tar):
    """Build the taxonomy tree from an opened GTDB taxonomy dump tarball.

    Reads the members names.dmp and nodes.dmp ('|'-separated fields). Each
    node is a Tree named by its node id, with 'taxname', 'rank' and (when
    available) 'common_name' properties.

    :param tar: An opened tarfile.TarFile.
    :return: (root node, set of (nodename, synonym) tuples).
    :raises ValueError: If nodes.dmp has no root node "1".
    """
    from .. import Tree
    # Download: gtdbdump/gtdbr202dump.tar.z

    # Name classes in names.dmp that are kept as synonyms.
    synonym_types = {"synonym", "equivalent name", "genbank equivalent name",
                     "anamorph", "genbank synonym", "genbank anamorph", "teleomorph"}

    child2parent = {}
    name2node = {}
    node2taxname = {}
    synonyms = set()
    node2common = {}
    print("Loading node names...")
    unique_nocase_synonyms = set()
    for line in tar.extractfile("names.dmp"):
        fields = [_f.strip() for _f in line.decode().split("|")]
        nodename = fields[0]
        name_type = fields[3].lower()
        # Clean up tax names so we make sure the don't include quotes. See https://github.com/etetoolkit/ete/issues/469
        taxname = fields[1].strip('"')

        if name_type == "scientific name":
            node2taxname[nodename] = taxname
        elif name_type == "genbank common name":
            node2common[nodename] = taxname
        elif name_type in synonym_types:
            # Keep track synonyms, but ignore duplicate case-insensitive names. See https://github.com/etetoolkit/ete/issues/469
            synonym_key = (nodename, taxname.lower())
            if synonym_key not in unique_nocase_synonyms:
                unique_nocase_synonyms.add(synonym_key)
                synonyms.add((nodename, taxname))

    print(len(node2taxname), "names loaded.")
    print(len(synonyms), "synonyms loaded.")

    print("Loading nodes...")
    for line in tar.extractfile("nodes.dmp"):
        fields = line.decode().split("|")
        nodename = fields[0].strip()
        parentname = fields[1].strip()
        n = Tree()
        n.name = nodename
        n.add_prop('taxname', node2taxname[nodename])
        if nodename in node2common:
            # Previously this stored the scientific name by mistake.
            n.add_prop('common_name', node2common[nodename])
        n.add_prop('rank', fields[2].strip())
        child2parent[nodename] = parentname
        name2node[nodename] = n
    print(len(name2node), "nodes loaded.")

    print("Linking nodes...")
    t = None
    for nodename, node in name2node.items():
        if nodename == "1":
            t = node
        else:
            name2node[child2parent[nodename]].add_child(node)
    if t is None:
        raise ValueError('nodes.dmp does not contain the root node "1"')
    print("Tree is loaded.")
    return t, synonyms

761 

def generate_table(t):
    """Write one tab-separated row per node of t to taxa.tab (in the CWD).

    Columns: node name, parent name ('' for the root), 'taxname',
    'common_name' (default ''), 'rank', and the comma-separated track of
    node names from the node up to the root. upload_data() reads this file.
    """
    with open("taxa.tab", "w") as out:
        for j, n in enumerate(t.traverse()):
            if j % 1000 == 0:
                print("\r", j, "generating entries...", end=' ')
            track = []
            temp_node = n
            while temp_node is not None:
                track.append(temp_node.name)
                temp_node = temp_node.up
            parent_name = n.up.name if n.up is not None else ""
            print('\t'.join([n.name, parent_name, n.props.get('taxname'),
                             n.props.get("common_name", ''), n.props.get("rank"),
                             ','.join(track)]), file=out)

777 

778 

def update_db(dbfile, targz_file=None):
    """(Re)build the GTDB SQLite database dbfile from a taxonomy dump.

    Also writes the pre/post-order traversal cache dbfile + '.traverse.pkl'.

    :param dbfile: Path of the SQLite database to (re)create.
    :param targz_file: Taxonomy dump (.tar.gz). If not given, the local
        default dump is downloaded/refreshed first and used.
    """
    basepath = os.path.dirname(dbfile)
    if basepath:
        os.makedirs(basepath, exist_ok=True)  # also creates missing parents

    # if users don't provide targz_file, update the latest version from ete-data
    if not targz_file:
        update_local_taxdump(DEFAULT_GTDBTAXADUMP)
        targz_file = DEFAULT_GTDBTAXADUMP

    with tarfile.open(targz_file, 'r') as tar:
        # NOTE: synonyms are loaded but not written to the database here.
        t, synonyms = load_gtdb_tree_from_dump(tar)

    prepostorder = [int(node.name) for post, node in t.iter_prepostorder()]

    with open(dbfile + '.traverse.pkl', 'wb') as fout:
        pickle.dump(prepostorder, fout, 2)

    print("Updating database: %s ..." % dbfile)
    generate_table(t)  # writes taxa.tab in the current directory

    upload_data(dbfile)

    os.remove("taxa.tab")  # portable replacement for `rm taxa.tab`

803 

def update_local_taxdump(fname=DEFAULT_GTDBTAXADUMP):
    """Download the latest GTDB taxonomy dump to fname, or refresh it.

    An existing file is re-downloaded only when its md5 differs from the
    published <url>.md5 checksum.

    :raises requests.HTTPError: If the server answers with an error status
        (instead of saving the error page as the dump).
    """
    # latest version of gtdb taxonomy dump
    url = "https://github.com/etetoolkit/ete-data/raw/main/gtdb_taxonomy/gtdblatest/gtdb_latest_dump.tar.gz"

    def fetch(target_url):
        response = requests.get(target_url, timeout=60)
        response.raise_for_status()
        return response

    def download():
        # Fetch before opening the file, so a failed request cannot leave
        # a truncated dump behind.
        content = fetch(url).content
        with open(fname, 'wb') as f:
            f.write(content)

    if not os.path.exists(fname):
        print(f'Downloading {fname} from {url} ...')
        download()
        return

    local_hash = md5()
    with open(fname, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            local_hash.update(chunk)
    md5_remote = fetch(url + '.md5').text.split()[0]

    if local_hash.hexdigest() != md5_remote:
        print(f'Updating {fname} from {url} ...')
        download()
    else:
        print(f'File {fname} is already up-to-date with {url} .')

822 

def upload_data(dbfile):
    """Recreate the GTDB schema in dbfile and load the species table.

    Rows are read from taxa.tab in the current directory (as written by
    generate_table). The stats table receives DB_VERSION; the synonym and
    merged tables are created but not filled here.
    """
    print()
    print('Uploading to', dbfile)
    basepath = os.path.dirname(dbfile)
    if basepath:
        os.makedirs(basepath, exist_ok=True)

    create_cmd = """
    DROP TABLE IF EXISTS stats;
    DROP TABLE IF EXISTS species;
    DROP TABLE IF EXISTS synonym;
    DROP TABLE IF EXISTS merged;
    CREATE TABLE stats (version INT PRIMARY KEY);
    CREATE TABLE species (taxid INT PRIMARY KEY, parent INT, spname VARCHAR(50) COLLATE NOCASE, common VARCHAR(50) COLLATE NOCASE, rank VARCHAR(50), track TEXT);
    CREATE TABLE synonym (taxid INT,spname VARCHAR(50) COLLATE NOCASE, PRIMARY KEY (spname, taxid));
    CREATE TABLE merged (taxid_old INT, taxid_new INT);
    CREATE INDEX spname1 ON species (spname COLLATE NOCASE);
    CREATE INDEX spname2 ON synonym (spname COLLATE NOCASE);
    """

    db = sqlite3.connect(dbfile)
    try:
        db.executescript(create_cmd)
        print()

        db.execute("INSERT INTO stats (version) VALUES (?);", (DB_VERSION,))
        db.commit()

        # for i, line in enumerate(open("syn.tab")):
        #     if i%5000 == 0 :
        #         print('\rInserting synonyms: % 6d' %i, end=' ', file=sys.stderr)
        #         sys.stderr.flush()
        #     taxid, spname = line.strip('\n').split('\t')
        #     db.execute("INSERT INTO synonym (taxid, spname) VALUES (?, ?);", (taxid, spname))
        # print()
        # db.commit()
        # for i, line in enumerate(open("merged.tab")):
        #     if i%5000 == 0 :
        #         print('\rInserting taxid merges: % 6d' %i, end=' ', file=sys.stderr)
        #         sys.stderr.flush()
        #     taxid_old, taxid_new = line.strip('\n').split('\t')
        #     db.execute("INSERT INTO merged (taxid_old, taxid_new) VALUES (?, ?);", (taxid_old, taxid_new))
        # print()
        # db.commit()

        with open('taxa.tab') as f_taxa:
            for i, line in enumerate(f_taxa):
                if i % 5000 == 0:
                    print('\rInserting taxids: %8d' % i, end=' ', file=sys.stderr)
                    sys.stderr.flush()
                taxid, parentid, spname, common, rank, lineage = line.strip('\n').split('\t')
                db.execute(('INSERT INTO species (taxid, parent, spname, common, rank, track) '
                            'VALUES (?, ?, ?, ?, ?, ?)'), (taxid, parentid, spname, common, rank, lineage))
        print()
        db.commit()
    finally:
        db.close()

878 

879if __name__ == "__main__": 

880 #from .. import PhyloTree 

881 gtdb = GTDBTaxa() 

882 gtdb.update_taxonomy_database(DEFAULT_GTDBTAXADUMP) 

883 

884 descendants = gtdb.get_descendant_taxa('c__Thorarchaeia', collapse_subspecies=True, return_tree=True) 

885 print(descendants.write(properties=None)) 

886 print(descendants.get_ascii(properties=['sci_name', 'taxid','rank'])) 

887 tree = gtdb.get_topology(["p__Huberarchaeota", "o__Peptococcales", "f__Korarchaeaceae", "s__Korarchaeum"], intermediate_nodes=True, collapse_subspecies=True, annotate=True) 

888 print(tree.get_ascii(properties=["taxid", "sci_name", "rank"])) 

889 

890 tree = PhyloTree('((c__Thorarchaeia, c__Lokiarchaeia_A), s__Caballeronia udeis);', sp_naming_function=lambda name: name) 

891 tax2name, tax2track, tax2rank = gtdb.annotate_tree(tree, taxid_attr="name") 

892 print(tree.get_ascii(properties=["taxid","name", "sci_name", "rank"])) 

893 

894 print(gtdb.get_name_lineage(['RS_GCF_006228565.1','GB_GCA_001515945.1']))