Coverage for /home/deng/Projects/ete4/hackathon/ete4/ete4/gtdb_taxonomy/gtdbquery.py: 51%
505 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-07 10:27 +0200
1#!/usr/bin/env python3
3import sys
4import os
6import pickle
7from collections import defaultdict, Counter
9from hashlib import md5
11import sqlite3
12import math
13import tarfile
14import warnings
15import requests
17from ete4 import ETE_DATA_HOME, update_ete_data
__all__ = ["GTDBTaxa", "is_taxadb_up_to_date"]

# Schema version written to the `stats` table by upload_data(). GTDBTaxa
# rebuilds any database whose stored version differs from this value.
DB_VERSION = 2
# Default on-disk locations of the SQLite database and of the taxdump archive
# it is built from, both inside ETE's data directory.
DEFAULT_GTDBTAXADB = ETE_DATA_HOME + '/gtdbtaxa.sqlite'
DEFAULT_GTDBTAXADUMP = ETE_DATA_HOME + '/gtdbdump.tar.gz'


def is_taxadb_up_to_date(dbfile=DEFAULT_GTDBTAXADB):
    """Return True if `dbfile` is a GTDB taxonomy database at DB_VERSION.

    If dbfile= is not specified, DEFAULT_GTDBTAXADB is assumed. A database
    without a readable ``stats`` table (missing or empty) is reported as
    not up to date.
    """
    db = sqlite3.connect(dbfile)
    try:
        version = db.execute('SELECT version FROM stats;').fetchone()[0]
    except (sqlite3.OperationalError, ValueError, IndexError, TypeError):
        # OperationalError: no `stats` table; TypeError: table is empty
        # (fetchone() returned None).
        version = None
    finally:
        # Always release the connection, even if an unexpected sqlite error
        # propagates to the caller.
        db.close()

    return version == DB_VERSION
class GTDBTaxa:
    """Local transparent connector to the GTDB taxonomy database.

    Taxa live in a SQLite database whose ``species`` table is keyed by an
    internal numeric taxid. The GTDB names (e.g. ``'c__Thorarchaeia'`` or
    ``'RS_GCF_001477695.1'``) are stored in its ``spname`` column. Most public
    methods take and return GTDB names. Methods prefixed with ``_`` mostly
    work on the internal numeric taxids.
    """

    # Max number of bound parameters per "IN (...)" query. Values are always
    # passed as SQL parameters (never pasted into the SQL text). They are split
    # into chunks of this size so very large id lists stay below SQLite's
    # host-parameter limit (999 in SQLite builds older than 3.32).
    _SQL_CHUNK = 500

    def __init__(self, dbfile=None, taxdump_file=None, memory=False):
        """Open the GTDB database, building or upgrading it first if needed.

        :param dbfile: Path of the SQLite database (DEFAULT_GTDBTAXADB if
            not given).
        :param taxdump_file: If given, (re)build the database from this
            taxdump tar.gz archive before opening it.
        :param memory: If True, copy the database into an in-memory SQLite
            connection and query that instead of the file.
        :raises ValueError: If the database file does not exist after the
            build steps.
        """
        self.dbfile = dbfile if dbfile else DEFAULT_GTDBTAXADB

        if taxdump_file:
            self.update_taxonomy_database(taxdump_file)

        # NOTE(review): this compares the *argument* `dbfile` (None when it is
        # omitted). An explicitly passed default path that does not exist is
        # therefore never auto-downloaded. Confirm that this is intended.
        if dbfile != DEFAULT_GTDBTAXADB and not os.path.exists(self.dbfile):
            print('GTDB database not present yet (first time used?)', file=sys.stderr)
            urlbase = ('https://github.com/etetoolkit/ete-data/raw/main'
                       '/gtdb_taxonomy/gtdblatest')
            update_ete_data(f'{DEFAULT_GTDBTAXADUMP}', f'{urlbase}/gtdb_latest_dump.tar.gz')
            self.update_taxonomy_database(taxdump_file=DEFAULT_GTDBTAXADUMP)

        if not os.path.exists(self.dbfile):
            raise ValueError("Cannot open taxonomy database: %s" % self.dbfile)

        self.db = None
        self._connect()

        if not is_taxadb_up_to_date(self.dbfile):
            print('GTDB database format is outdated. Upgrading', file=sys.stderr)
            self.update_taxonomy_database(taxdump_file)

        if memory:
            filedb = self.db
            self.db = sqlite3.connect(':memory:')
            filedb.backup(self.db)
            # The on-disk connection is no longer used once it has been copied.
            filedb.close()

    def update_taxonomy_database(self, taxdump_file=None):
        """Update the GTDB taxonomy database.

        It updates it by downloading and parsing the latest
        gtdbtaxdump.tar.gz file.

        :param taxdump_file: Alternative location of gtdbtaxdump.tar.gz.
        """
        update_db(self.dbfile, targz_file=taxdump_file)

    def _connect(self):
        """(Re)open the SQLite connection to self.dbfile."""
        self.db = sqlite3.connect(self.dbfile)

    def _select_in(self, sql, values):
        """Run `sql` for `values` bound into its single ``IN ({})`` slot.

        The query runs once per chunk of at most self._SQL_CHUNK values, and
        all rows are returned as one list. Empty `values` return [].
        """
        values = list(values)
        rows = []
        for start in range(0, len(values), self._SQL_CHUNK):
            chunk = values[start:start + self._SQL_CHUNK]
            placeholders = ','.join('?' * len(chunk))
            rows.extend(self.db.execute(sql.format(placeholders), chunk).fetchall())
        return rows

    def _load_traverse_cache(self):
        """Return the pre/post-order taxid list pickled next to the database.

        The file is written by update_db(). It is unpickled, so it must come
        from a trusted source.
        """
        with open(self.dbfile + ".traverse.pkl", "rb") as cached_traverse:
            return pickle.load(cached_traverse)

    def _translate_merged(self, all_taxids):
        """Replace deprecated taxids by their current ones (``merged`` table).

        :return: ``(taxids, conversion)``. ``taxids`` is the set of int taxids
            with merged ones replaced. ``conversion`` maps old -> new taxid.
        """
        ids = {int(tid) for tid in all_taxids}
        conv_all_taxids = set(ids)
        conversion = {}
        sql = 'select taxid_old, taxid_new FROM merged WHERE taxid_old IN ({})'
        for old, new in self._select_in(sql, ids):
            conv_all_taxids.discard(int(old))
            conv_all_taxids.add(int(new))
            conversion[int(old)] = int(new)
        return conv_all_taxids, conversion

    def _get_id2rank(self, internal_taxids):
        """Given internal numeric taxids, return a dictionary with their ranks.

        Examples:
        > gtdb._get_id2rank([2174, 205487, 610])
        {2174: 'family', 205487: 'order', 610: 'phylum'}

        Note: Numeric taxids are not recognized by the official GTDB taxonomy
        database, only for internal usage. None and '' are ignored.
        """
        ids = set(internal_taxids) - {None, ''}
        rows = self._select_in('SELECT taxid, rank FROM species WHERE taxid IN ({})', ids)
        return dict(rows)

    def get_rank(self, taxids):
        """Give a list of GTDB string taxids, return a dictionary with their corresponding ranks.

        Examples:

        > gtdb.get_rank(['c__Thorarchaeia', 'RS_GCF_001477695.1'])
        {'c__Thorarchaeia': 'class', 'RS_GCF_001477695.1': 'subspecies'}

        Keys are the names as stored in the database. Unknown names are
        omitted.
        """
        name2ids = self._get_name_translator(taxids)
        internal_ids = {tid for ids in name2ids.values() for tid in ids} - {None, ''}
        rows = self._select_in('SELECT taxid, rank FROM species WHERE taxid IN ({})',
                               internal_ids)
        # Translate all found taxids back to names with a single query.
        id2name = self._get_taxid_translator([tax for tax, _ in rows])
        return {id2name[tax]: rank for tax, rank in rows}

    def _get_lineage_translator(self, taxids):
        """Given internal taxids, return {taxid: lineage}.

        Each lineage is the list of int taxids from the root down to the
        taxid itself. None and '' are ignored.
        """
        ids = set(taxids) - {None, ''}
        rows = self._select_in('SELECT taxid, track FROM species WHERE taxid IN ({});', ids)
        # `track` is stored tip-to-root, so reverse it.
        return {tax: [int(x) for x in reversed(track.split(','))] for tax, track in rows}

    def get_name_lineage(self, taxnames):
        """Given valid GTDB names, return their lineages as lists of names.

        :return: A list of ``{name: [root_name, ..., name]}`` dicts, one per
            name found in the database.
        """
        name_lineages = []
        name2taxid = self._get_name_translator(taxnames)
        for key, value in name2taxid.items():
            lineage = self._get_lineage(value[0])
            names = self._get_taxid_translator(lineage)
            name_lineages.append({key: [names[taxid] for taxid in lineage]})
        return name_lineages

    def _get_lineage(self, taxid):
        """Given an internal taxid, return its lineage as a root-first list of int taxids.

        Deprecated taxids are looked up in the ``merged`` table, and a warning
        is emitted when one is translated.

        :return: None for a falsy `taxid`.
        :raises ValueError: If the taxid (or its merged replacement) is not
            in the database.
        """
        if not taxid:
            return None
        taxid = int(taxid)
        sql = 'SELECT track FROM species WHERE taxid=?'
        raw_track = self.db.execute(sql, (taxid,)).fetchone()
        if not raw_track:
            # Perhaps it is an obsolete taxid.
            _, merged_conversion = self._translate_merged([taxid])
            if taxid in merged_conversion:
                raw_track = self.db.execute(sql, (merged_conversion[taxid],)).fetchone()
            if not raw_track:
                raise ValueError("%s taxid not found" % taxid)
            warnings.warn("taxid %s was translated into %s" % (taxid, merged_conversion[taxid]))

        track = [int(x) for x in raw_track[0].split(",")]
        return list(reversed(track))

    def get_common_names(self, taxids):
        """Return {taxid: common name} for the given internal taxids.

        Taxids with no (or an empty) common name are omitted.
        """
        rows = self._select_in('select taxid, common FROM species WHERE taxid IN ({});', taxids)
        return {tax: common_name for tax, common_name in rows if common_name}

    def _get_taxid_translator(self, taxids, try_synonyms=True):
        """Given a list of taxids, returns a dictionary with their corresponding
        scientific names.

        None and '' entries are ignored. Other entries are converted with
        int(). `try_synonyms` is currently unused and kept for API
        compatibility.
        """
        ids = {int(tid) for tid in taxids if tid is not None and tid != ''}
        rows = self._select_in("select taxid, spname FROM species WHERE taxid IN ({});", ids)
        return dict(rows)

    def _get_name_translator(self, names):
        """Given a list of GTDB names, return {name: [internal taxids]}.

        Matching is case-insensitive (the ``spname`` columns are
        COLLATE NOCASE), but keys keep the caller's spelling. Names not found
        in ``species`` are then looked up in the ``synonym`` table. Names
        found in neither table are omitted.
        """
        name2origname = {}
        for n in names:
            name2origname[n.lower()] = n

        name2id = {}
        sql = 'select spname, taxid from species where spname IN ({})'
        for sp, taxid in self._select_in(sql, name2origname):
            name2id.setdefault(name2origname[sp.lower()], []).append(taxid)

        found = {n.lower() for n in name2id}
        missing = [n for n in name2origname if n not in found]
        if missing:
            sql = 'select spname, taxid from synonym where spname IN ({})'
            for sp, taxid in self._select_in(sql, missing):
                name2id.setdefault(name2origname[sp.lower()], []).append(taxid)
        return name2id

    def _translate_to_names(self, taxids):
        """
        Given a list of taxid numbers, returns another list with their
        corresponding scientific names (untranslatable taxids are kept as-is).
        """
        id2name = self._get_taxid_translator(taxids)
        return [id2name.get(sp, sp) for sp in taxids]

    def get_descendant_taxa(self, parent, intermediate_nodes=False, rank_limit=None,
                            collapse_subspecies=False, return_tree=False):
        """Return the descendants of `parent` (an internal taxid or a GTDB name).

        By default only the tip descendants are listed, as GTDB names. If
        intermediate_nodes is set to True, internal nodes will also be
        dumped. If `return_tree` is True, the topology is returned instead of
        a list. The same happens when `rank_limit` or `collapse_subspecies`
        is given: the list is then derived from that topology.

        NOTE(review): when `parent` is itself a tip, ``[taxid]`` (the internal
        int taxid, not its name) is returned. This is kept for compatibility.

        :raises ValueError: If `parent` cannot be found.
        """
        try:
            taxid = int(parent)
        except ValueError:
            try:
                taxid = self._get_name_translator([parent])[parent][0]
            except KeyError:
                raise ValueError('%s not found!' % parent) from None

        # Check whether the taxid is a deprecated one, and convert it into the
        # right one.
        _, conversion = self._translate_merged([taxid])
        if conversion:
            taxid = conversion[taxid]

        # Internal nodes appear twice in the pre/post-order list. Everything
        # between the two occurrences of `taxid` is in its subtree, and tips
        # there appear exactly once.
        prepostorder = self._load_traverse_cache()
        descendants = {}
        found = 0
        for tid in prepostorder:
            if tid == taxid:
                found += 1
            elif found == 1:
                descendants[tid] = descendants.get(tid, 0) + 1
            elif found == 2:
                break

        if not found:
            raise ValueError("taxid not found:%s" % taxid)
        elif found == 1:
            return [taxid]

        if rank_limit or collapse_subspecies or return_tree:
            descendants_spnames = self._get_taxid_translator(list(descendants))
            tree = self.get_topology(list(descendants_spnames.values()),
                                     intermediate_nodes=intermediate_nodes,
                                     collapse_subspecies=collapse_subspecies,
                                     rank_limit=rank_limit)
            if return_tree:
                return tree
            elif intermediate_nodes:
                return [n.name for n in tree.descendants()]
            else:
                return [n.name for n in tree]
        elif intermediate_nodes:
            return self._translate_to_names(list(descendants))
        else:
            return self._translate_to_names(
                [tid for tid, count in descendants.items() if count == 1])

    def get_topology(self, taxnames, intermediate_nodes=False, rank_limit=None,
                     collapse_subspecies=False, annotate=True):
        """Return minimal pruned GTDB taxonomy tree containing all given taxa.

        :param taxnames: GTDB names to include. Names not found in the
            database are ignored. A single name yields its whole subtree.
        :param intermediate_nodes: If True, single child nodes
            representing the complete lineage of leaf nodes are kept.
            Otherwise, the tree is pruned to contain the first common
            ancestor of each group.
        :param rank_limit: If a valid rank name is provided, the tree is
            pruned at that given level. For instance, use
            rank_limit="species" to get rid of sub-species or strain leaf
            nodes. Only applied when several taxa are given.
        :param collapse_subspecies: If True, any item under the
            species rank will be collapsed into the species upper
            node.
        :param annotate: If True, annotate the result with annotate_tree().
        """
        from .. import PhyloTree

        tax2id = self._get_name_translator(taxnames)
        taxids = [ids[0] for ids in tax2id.values()]

        if len(taxids) == 1:
            root, name2tax = self._single_taxon_topology(int(taxids[0]), PhyloTree)
        else:
            taxids = set(map(int, taxids))
            root, name2tax = self._multi_taxa_topology(taxids, rank_limit, PhyloTree)

        if not intermediate_nodes:
            # Remove single-child nodes that are not one of the requested taxa.
            # The descendants are listed up front so that deleting nodes does
            # not interfere with the traversal.
            for n in list(root.descendants()):
                if len(n.children) == 1 and int(name2tax.get(n.name, n.name)) not in taxids:
                    n.delete(prevent_nondicotomic=False)

        if len(root.children) == 1:
            tree = root.children[0].detach()
        else:
            tree = root

        if collapse_subspecies:
            # NOTE(review): relies on the 'rank' property. Only the multi-taxa
            # builder sets it before this point.
            to_detach = []
            for node in tree.traverse():
                if node.props.get('rank') == 'species':
                    to_detach.extend(node.children)
            for n in to_detach:
                n.detach()

        if annotate:
            self.annotate_tree(tree, ignore_unclassified=False)

        return tree

    def _single_taxon_topology(self, root_taxid, PhyloTree):
        """Build the complete subtree under `root_taxid`; return (root, name2tax)."""
        prepostorder = self._load_traverse_cache()
        start = prepostorder.index(root_taxid)
        try:
            end = prepostorder.index(root_taxid, start + 1)
            subtree = prepostorder[start:end + 1]
        except ValueError:
            # If root taxid is not found in postorder, must be a tip node
            subtree = [root_taxid]

        leaves = {tid for tid, count in Counter(subtree).items() if count == 1}
        tax2name = self._get_taxid_translator(subtree)
        name2tax = {spname: taxid for taxid, spname in tax2name.items()}

        nodes = {root_taxid: PhyloTree({'name': str(root_taxid)})}
        current_parent = nodes[root_taxid]
        visited = set()
        for tid in subtree:
            if tid in visited:
                # Second (post-order) visit: climb back to the parent.
                current_parent = nodes[tid].up
            else:
                visited.add(tid)
                nodes[tid] = PhyloTree({'name': tax2name.get(tid, '')})
                current_parent.add_child(nodes[tid])
                if tid not in leaves:
                    current_parent = nodes[tid]
        return nodes[root_taxid], name2tax

    def _multi_taxa_topology(self, taxids, rank_limit, PhyloTree):
        """Join the lineages of several int `taxids` into one tree; return (root, name2tax)."""
        id2lineage = self._get_lineage_translator(taxids)
        all_taxids = {tax for lineage in id2lineage.values() for tax in lineage}
        id2rank = self._get_id2rank(all_taxids)

        tax2name = self._get_taxid_translator(taxids)
        tax2name.update(self._get_taxid_translator(all_taxids - set(tax2name)))
        name2tax = {spname: taxid for taxid, spname in tax2name.items()}

        # One shared node per taxid. `track` lists the lineage nodes root-first.
        elem2node = {}
        sp2track = {}
        for sp in taxids:
            track = []
            for elem in id2lineage[sp]:
                name = tax2name[elem]
                node = elem2node.get(elem)
                if node is None:
                    node = elem2node[elem] = PhyloTree()
                    node.name = str(name)
                    node.taxid = str(name)
                    node.add_prop("rank", str(id2rank.get(int(elem), "no rank")))
                track.append(node)
            sp2track[sp] = track

        # Generate parent/child relationships, stopping at rank_limit if given.
        for track in sp2track.values():
            parent = None
            for node in track:
                if parent and node not in parent.children:
                    parent.add_child(node)
                if rank_limit and node.props.get('rank') == rank_limit:
                    break
                parent = node

        # Taxid 1 is the root of the taxonomy dump.
        return elem2node[1], name2tax

    def annotate_tree(self, t, taxid_attr='name', tax2name=None,
                      tax2track=None, tax2rank=None, ignore_unclassified=False):
        """Annotate a tree containing GTDB names as leaf names.

        It annotates by adding the properties 'taxid', 'sci_name',
        'common_name', 'lineage', 'named_lineage' and 'rank'. Internal nodes
        get the common lineage of their leaves.

        :param t: Tree to annotate.
        :param taxid_attr: Node attribute (property) containing the
            GTDB name associated to each node (i.e. species in
            PhyloTree instances).
        :param tax2name, tax2track, tax2rank: Pre-calculated
            dictionaries with translations from taxid number to names,
            track lineages and ranks.
        :param ignore_unclassified: If True, leaves without a lineage are
            ignored when computing the lineage of internal nodes.
        :return: ``(tax2name, tax2track, tax2rank)`` as used for annotation.
        """
        taxids = set()
        if taxid_attr == "taxid":
            for n in t.leaves():
                if taxid_attr in n.props:
                    taxids.add(n.props[taxid_attr])
        else:
            for n in t.leaves():
                try:
                    # translate gtdb name -> internal numeric taxid
                    taxaname = getattr(n, taxid_attr, n.props.get(taxid_attr))
                    tid = self._get_name_translator([taxaname])[taxaname][0]
                    taxids.add(tid)
                except (KeyError, ValueError, AttributeError):
                    pass

        taxids, _ = self._translate_merged(taxids)

        if not tax2name or taxids - set(map(int, tax2name)):
            tax2name = self._get_taxid_translator(taxids)
        if not tax2track or taxids - set(map(int, tax2track)):
            tax2track = self._get_lineage_translator(taxids)

        # Make sure every taxid appearing in a lineage has a name.
        all_taxid_codes = {tax for lineage in tax2track.values() for tax in lineage}
        tax2name.update(self._get_taxid_translator(all_taxid_codes - set(tax2name)))

        tax2common_name = self.get_common_names(tax2name.keys())

        if not tax2rank:
            tax2rank = self._get_id2rank(list(tax2name.keys()))

        n2leaves = t.get_cached_content()

        # Post-order, so leaves are annotated before their ancestors.
        for node in t.traverse('postorder'):
            if node.is_leaf:
                node_taxid = getattr(node, taxid_attr, node.props.get(taxid_attr))
            else:
                node_taxid = None
            node.add_prop('taxid', node_taxid)

            if node_taxid:
                # Internal numeric taxid, or None if the name is unknown.
                tmp_taxid = self._get_name_translator([node_taxid]).get(node_taxid, [None])[0]
                rank = tax2rank.get(tmp_taxid, 'Unknown')
                if rank != 'subspecies':
                    sci_name = tax2name.get(tmp_taxid, '')
                else:
                    # For subspecies, the gtdb name (a genome accession like
                    # 'RS_GCF_0062.1') is not informative. Use its parent's
                    # (the species) name instead. Tracks are root-first lists
                    # of numeric taxids.
                    track = tax2track[tmp_taxid]
                    sci_name = tax2name.get(track[-2], '')

                lineage = tax2track.get(tmp_taxid, [])
                node.add_props(sci_name=sci_name,
                               # Common names are keyed by internal taxid.
                               common_name=tax2common_name.get(tmp_taxid, ''),
                               lineage=lineage,
                               rank=rank,
                               named_lineage=[tax2name.get(tax, str(tax)) for tax in lineage])
            elif node.is_leaf:
                node.add_props(sci_name=getattr(node, taxid_attr, node.props.get(taxid_attr, 'NA')),
                               common_name='',
                               lineage=[],
                               rank='Unknown',
                               named_lineage=[])
            else:
                if ignore_unclassified:
                    vectors = [lf.props.get('lineage') for lf in n2leaves[node]
                               if lf.props.get('lineage')]
                else:
                    vectors = [lf.props.get('lineage') for lf in n2leaves[node]]
                lineage = self._common_lineage(vectors)

                rank = tax2rank.get(lineage[-1], 'Unknown')

                if lineage[-1]:
                    if rank != 'subspecies':
                        ancestor = self._get_taxid_translator([lineage[-1]])[lineage[-1]]
                    else:
                        ancestor = self._get_taxid_translator([lineage[-2]])[lineage[-2]]
                        lineage = lineage[:-1]  # remove subspecies from lineage
                        rank = tax2rank.get(lineage[-1], 'Unknown')  # update rank
                else:
                    ancestor = None

                node.add_props(sci_name=tax2name.get(ancestor, str(ancestor)),
                               common_name=tax2common_name.get(lineage[-1], ''),
                               taxid=ancestor,
                               lineage=lineage,
                               rank=rank,
                               named_lineage=[tax2name.get(tax, str(tax)) for tax in lineage])

        return tax2name, tax2track, tax2rank

    def _common_lineage(self, vectors):
        """Return the taxids shared by all lineage `vectors`, in lineage order.

        Returns ``[""]`` when nothing is shared (or `vectors` is empty).
        """
        occurrence = defaultdict(int)
        pos = defaultdict(set)
        for v in vectors:
            for i, taxid in enumerate(v):
                occurrence[taxid] += 1
                pos[taxid].add(i)

        common = [taxid for taxid, ocu in occurrence.items() if ocu == len(vectors)]
        if not common:
            return [""]
        return sorted(common, key=lambda x: min(pos[x]))

    def get_broken_branches(self, t, taxa_lineages, n2content=None):
        """Returns a list of GTDB lineage names that are not monophyletic in the
        provided tree, as well as the list of affected branches and their size.
        CURRENTLY EXPERIMENTAL

        Leaves are expected to carry 'sci_name' and 'taxid' properties, as set
        by annotate_tree(). `taxa_lineages` maps each leaf's 'taxid' to its
        lineage.
        """
        if not n2content:
            n2content = t.get_cached_content()

        tax2node = defaultdict(set)

        unknown = set()
        for leaf in t.leaves():
            if str(leaf.props.get('sci_name', '')).lower() != "unknown":
                for tax in taxa_lineages[leaf.props.get('taxid')]:
                    tax2node[tax].add(leaf)
            else:
                unknown.add(leaf)

        broken_branches = defaultdict(set)
        broken_clades = set()
        for tax, leaves in tax2node.items():
            if len(leaves) > 1:
                # NOTE(review): ete3-style name. Confirm it exists in the ete4
                # Tree API (possibly renamed to common_ancestor).
                common = t.get_common_ancestor(leaves)
            else:
                common = list(leaves)[0]
            if (leaves ^ set(n2content[common])) - unknown:
                broken_branches[common].add(tax)
                broken_clades.add(tax)

        broken_clade_sizes = [len(tax2node[tax]) for tax in broken_clades]
        return broken_branches, broken_clades, broken_clade_sizes
def load_gtdb_tree_from_dump(tar):
    """Build the taxonomy tree from a taxdump archive.

    :param tar: An open tarfile containing ``names.dmp`` and ``nodes.dmp``,
        whose fields are separated by ``|``.
    :return: ``(t, synonyms)``. ``t`` is the root node (taxid ``"1"``) with
        'taxname', 'rank' and (when available) 'common_name' properties on
        every node. ``synonyms`` is a set of ``(taxid, name)`` pairs.
    :raises ValueError: If ``nodes.dmp`` contains no root node ``"1"``.
    """
    from .. import Tree

    synonym_types = {"synonym", "equivalent name", "genbank equivalent name",
                     "anamorph", "genbank synonym", "genbank anamorph", "teleomorph"}
    child2parent = {}
    name2node = {}
    node2taxname = {}
    node2common = {}
    synonyms = set()
    unique_nocase_synonyms = set()

    print("Loading node names...")
    with tar.extractfile("names.dmp") as names_file:
        for line in names_file:
            fields = [_f.strip() for _f in line.decode().split("|")]
            nodename = fields[0]
            name_type = fields[3].lower()
            # Clean up tax names so we make sure they don't include quotes.
            # See https://github.com/etetoolkit/ete/issues/469
            taxname = fields[1].strip('"')

            if name_type == "scientific name":
                node2taxname[nodename] = taxname
            elif name_type == "genbank common name":
                node2common[nodename] = taxname
            elif name_type in synonym_types:
                # Keep track of synonyms, but ignore duplicate case-insensitive
                # names. See https://github.com/etetoolkit/ete/issues/469
                synonym_key = (nodename, taxname.lower())
                if synonym_key not in unique_nocase_synonyms:
                    unique_nocase_synonyms.add(synonym_key)
                    synonyms.add((nodename, taxname))

    print(len(node2taxname), "names loaded.")
    print(len(synonyms), "synonyms loaded.")

    print("Loading nodes...")
    with tar.extractfile("nodes.dmp") as nodes_file:
        for line in nodes_file:
            fields = line.decode().split("|")
            nodename = fields[0].strip()
            n = Tree()
            n.name = nodename
            n.add_prop('taxname', node2taxname[nodename])
            if nodename in node2common:
                n.add_prop('common_name', node2common[nodename])
            n.add_prop('rank', fields[2].strip())
            child2parent[nodename] = fields[1].strip()
            name2node[nodename] = n
    print(len(name2node), "nodes loaded.")

    print("Linking nodes...")
    t = None
    for nodename, node in name2node.items():
        if nodename == "1":
            t = node
        else:
            name2node[child2parent[nodename]].add_child(node)
    if t is None:
        raise ValueError('root node "1" not found in nodes.dmp')
    print("Tree is loaded.")
    return t, synonyms
def generate_table(t):
    """Write every node of tree `t` to ``taxa.tab`` in the current directory.

    The output is tab-separated with one row per node:
    taxid, parent taxid ('' for the root), taxname, common name, rank, track.
    *track* is the comma-separated list of taxids from the node up to the
    root. upload_data() reads this file back.
    """
    with open("taxa.tab", "w") as out:
        for j, n in enumerate(t.traverse()):
            if j % 1000 == 0:
                print("\r", j, "generating entries...", end=' ')
            track = []
            temp_node = n
            while temp_node:
                track.append(temp_node.name)
                temp_node = temp_node.up
            parent_name = n.up.name if n.up else ""
            print('\t'.join([n.name, parent_name, n.props.get('taxname'),
                             n.props.get("common_name", ''), n.props.get("rank"),
                             ','.join(track)]), file=out)
def update_db(dbfile, targz_file=None):
    """(Re)build the GTDB taxonomy database `dbfile` from a taxdump archive.

    If `targz_file` is not given, the default dump is first refreshed from
    ete-data (update_local_taxdump) and then used. Besides `dbfile`, this
    writes ``dbfile + '.traverse.pkl'`` (the pre/post-order taxid list used
    by GTDBTaxa). It also uses a temporary ``taxa.tab`` in the current
    directory, which is removed afterwards.
    """
    basepath = os.path.dirname(dbfile)
    if basepath:
        os.makedirs(basepath, exist_ok=True)

    # if users don't provide targz_file, update the latest version from ete-data
    if not targz_file:
        update_local_taxdump(DEFAULT_GTDBTAXADUMP)
        targz_file = DEFAULT_GTDBTAXADUMP

    with tarfile.open(targz_file, 'r') as tar:
        t, _synonyms = load_gtdb_tree_from_dump(tar)

    prepostorder = [int(node.name) for post, node in t.iter_prepostorder()]

    with open(dbfile + '.traverse.pkl', 'wb') as fout:
        pickle.dump(prepostorder, fout, 2)

    print("Updating database: %s ..." % dbfile)
    generate_table(t)

    try:
        upload_data(dbfile)
    finally:
        # taxa.tab is only an intermediate file; don't leave it behind.
        if os.path.exists("taxa.tab"):
            os.remove("taxa.tab")
def update_local_taxdump(fname=DEFAULT_GTDBTAXADUMP):
    """Download the latest GTDB taxdump to `fname` unless it is already current.

    An existing local copy is compared with the ``.md5`` checksum published
    next to the remote file, and is only re-downloaded when they differ. The
    file is written only after a successful download.

    :raises requests.HTTPError: If the dump or its checksum cannot be fetched.
    """
    # latest version of gtdb taxonomy dump
    url = "https://github.com/etetoolkit/ete-data/raw/main/gtdb_taxonomy/gtdblatest/gtdb_latest_dump.tar.gz"
    # Connect/read timeout in seconds (not a limit on the total download time).
    timeout = 60

    def fetch(target_url):
        response = requests.get(target_url, timeout=timeout)
        # Without this, an HTTP error page would be saved as the dump.
        response.raise_for_status()
        return response

    if not os.path.exists(fname):
        print(f'Downloading {fname} from {url} ...')
    else:
        with open(fname, 'rb') as f:
            md5_local = md5(f.read()).hexdigest()
        md5_remote = fetch(url + '.md5').text.split()[0]
        if md5_local == md5_remote:
            print(f'File {fname} is already up-to-date with {url} .')
            return
        print(f'Updating {fname} from {url} ...')

    content = fetch(url).content
    with open(fname, 'wb') as f:
        f.write(content)
def upload_data(dbfile):
    """Create the database schema in `dbfile` and load ``taxa.tab`` into it.

    Existing tables are dropped first. The ``stats`` table records
    DB_VERSION, and ``species`` is filled from the ``taxa.tab`` file written
    by generate_table() in the current directory. The ``synonym`` and
    ``merged`` tables are created but not populated here.
    """
    print()
    print('Uploading to', dbfile)
    basepath = os.path.dirname(dbfile)
    if basepath:
        os.makedirs(basepath, exist_ok=True)

    create_cmd = """
    DROP TABLE IF EXISTS stats;
    DROP TABLE IF EXISTS species;
    DROP TABLE IF EXISTS synonym;
    DROP TABLE IF EXISTS merged;
    CREATE TABLE stats (version INT PRIMARY KEY);
    CREATE TABLE species (taxid INT PRIMARY KEY, parent INT, spname VARCHAR(50) COLLATE NOCASE, common VARCHAR(50) COLLATE NOCASE, rank VARCHAR(50), track TEXT);
    CREATE TABLE synonym (taxid INT,spname VARCHAR(50) COLLATE NOCASE, PRIMARY KEY (spname, taxid));
    CREATE TABLE merged (taxid_old INT, taxid_new INT);
    CREATE INDEX spname1 ON species (spname COLLATE NOCASE);
    CREATE INDEX spname2 ON synonym (spname COLLATE NOCASE);
    """

    db = sqlite3.connect(dbfile)
    try:
        for cmd in create_cmd.split(';'):
            db.execute(cmd)
        print()

        db.execute("INSERT INTO stats (version) VALUES (%d);" % DB_VERSION)
        db.commit()

        with open('taxa.tab') as f_taxa:
            for i, line in enumerate(f_taxa):
                if i % 5000 == 0:
                    print('\rInserting taxids: %8d' % i, end=' ', file=sys.stderr)
                    sys.stderr.flush()
                taxid, parentid, spname, common, rank, lineage = line.strip('\n').split('\t')
                db.execute(('INSERT INTO species (taxid, parent, spname, common, rank, track) '
                            'VALUES (?, ?, ?, ?, ?, ?)'),
                           (taxid, parentid, spname, common, rank, lineage))
        print()
        db.commit()
    finally:
        db.close()
if __name__ == "__main__":
    # Manual demo. Note that it rebuilds the local database from the default
    # dump before running the queries.
    from ete4 import PhyloTree  # needed below; it was previously never imported

    gtdb = GTDBTaxa()
    gtdb.update_taxonomy_database(DEFAULT_GTDBTAXADUMP)

    descendants = gtdb.get_descendant_taxa('c__Thorarchaeia', collapse_subspecies=True, return_tree=True)
    print(descendants.write(properties=None))
    print(descendants.get_ascii(properties=['sci_name', 'taxid', 'rank']))

    tree = gtdb.get_topology(["p__Huberarchaeota", "o__Peptococcales", "f__Korarchaeaceae", "s__Korarchaeum"],
                             intermediate_nodes=True, collapse_subspecies=True, annotate=True)
    print(tree.get_ascii(properties=["taxid", "sci_name", "rank"]))

    tree = PhyloTree('((c__Thorarchaeia, c__Lokiarchaeia_A), s__Caballeronia udeis);',
                     sp_naming_function=lambda name: name)
    tax2name, tax2track, tax2rank = gtdb.annotate_tree(tree, taxid_attr="name")
    print(tree.get_ascii(properties=["taxid", "name", "sci_name", "rank"]))

    print(gtdb.get_name_lineage(['RS_GCF_006228565.1', 'GB_GCA_001515945.1']))