#!/usr/bin/env python3
import os, math, re
import logging
import sys
import time
import random
import csv
import tarfile
from collections import defaultdict, Counter
import numpy as np
from scipy import stats
import requests
from ete4.parser.newick import NewickError
from ete4 import SeqGroup
from ete4 import Tree, PhyloTree
from ete4 import GTDBTaxa
from ete4 import NCBITaxa
from treeprofiler.src import utils
from treeprofiler.src.phylosignal import run_acr_discrete, run_delta
from treeprofiler.src.ls import run_ls
from treeprofiler.src import b64pickle
from multiprocessing import Pool

DESC = "annotate tree"

TAXONOMICDICT = {  # starts with leaf name
    'rank': str,
    'sci_name': str,
    'taxid': str,
    'lineage': str,
    'named_lineage': str,
    'evoltype': str,
    'dup_sp': str,
    'dup_percent': float,
}

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def populate_annotate_args(parser):
    gmeta = parser.add_argument_group(
        title='METADATA TABLE parameters',
        description="Input parameters of METADATA")
    add = gmeta.add_argument
    add('-m', '--metadata', nargs='+',
        help="<metadata.csv> .csv, .tsv. Mandatory input.")
    add('--data-matrix', nargs='+',
        help="<datamatrix.csv> .csv, .tsv. Metadata table attached to the tree as a data matrix (array); please do not provide column headers in this file.")
    add('-s', '--metadata-sep', default='\t',
        help="column separator of metadata table [default: \\t]")
    add('--no-headers', action='store_true',
        help="metadata table doesn't contain column names; a namespace of col+index, such as col1, will be assigned as the property key.")
    add('--duplicate', action='store_true',
        help="if the metadata contains duplicated rows, treeprofiler will aggregate the duplicated metadata into a list property.")
    add('--text-prop', nargs='+',
        help=("<col1> <col2> names, column index or index range of columns which "
              "need to be read as categorical data"))
    add('--multiple-text-prop', nargs='+',
        help=("<col1> <col2> names, column index or index range of columns which "
              "need to be read as categorical data containing more than one "
              "value separated by ',', such "
              "as GO:0000003,GO:0000902,GO:0000904,GO:0003006"))
    add('--num-prop', nargs='+',
        help=("<col1> <col2> names, column index or index range of columns which "
              "need to be read as numerical data"))
    add('--bool-prop', nargs='+',
        help=("<col1> <col2> names, column index or index range of columns which "
              "need to be read as boolean data"))
    add('--text-prop-idx', nargs='+',
        help="1 2 3 or [1-5] index of columns which need to be read as categorical data")
    add('--num-prop-idx', nargs='+',
        help="1 2 3 or [1-5] index of columns which need to be read as numerical data")
    add('--bool-prop-idx', nargs='+',
        help="1 2 3 or [1-5] index of columns which need to be read as boolean data")
    add('--acr-discrete-columns', nargs='+',
        help="<col1> <col2> names of columns on which to perform ACR analysis for discrete traits")
    # add('--acr-continuous-columns', nargs='+',
    #     help="<col1> <col2> names of columns on which to perform ACR analysis for continuous traits")
    add('--ls-columns', nargs='+',
        help="<col1> <col2> names of columns on which to perform lineage specificity analysis")
    # add('--taxatree',
    #     help=("<kingdom|phylum|class|order|family|genus|species|subspecies> "
    #           "reference tree from taxonomic database"))
    add('--taxadb', type=str.upper,
        choices=['NCBI', 'GTDB', 'customdb'],
        help="<NCBI|GTDB> for taxonomic annotation or fetching the taxa tree")
    add('--gtdb-version', type=int,
        choices=[95, 202, 207, 214, 220],
        help='GTDB version for taxonomic annotation, such as 220. If not provided, the latest version will be used.')
    add('--taxa-dump', type=str,
        help='Path to a taxonomy database dump file for a specific version, such as the GTDB taxa dump https://github.com/etetoolkit/ete-data/raw/main/gtdb_taxonomy/gtdblatest/gtdb_latest_dump.tar.gz or the NCBI taxa dump https://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdump.tar.gz')
    add('--taxon-column',
        help="Activate taxonomic annotation using <col1>, the name of the column which needs to be read as taxon data. "
             "If the taxon data is in the leaf name, use 'name' as input, such as --taxon-column name.")
    add('--taxon-delimiter', default=None,
        help="delimiter of taxa columns. [default: None]")
    add('--taxa-field', type=int, default=0,
        help="field of taxa name after delimiter. [default: 0]")
    add('--ignore-unclassified', action='store_true',
        help="Ignore unclassified taxa in taxonomic annotation")
    add('--emapper-annotations',
        help="attach eggNOG-mapper output out.emapper.annotations")
    add('--emapper-pfam',
        help="attach eggNOG-mapper pfam output out.emapper.pfams")
    add('--emapper-smart',
        help="attach eggNOG-mapper smart output out.emapper.smart")
    add('--alignment',
        help="Sequence alignment, .fasta format")

    annotation_group = parser.add_argument_group(title='Internal nodes annotation arguments',
        description="Annotation parameters")
    annotation_group.add_argument('--column-summary-method',
        nargs='+',
        required=False,
        help="Specify the summary method for individual columns in the format COL=METHOD. Method options are the same as for --counter-stat and --num-stat.")
    annotation_group.add_argument('--num-stat',
        default='all',
        choices=['all', 'sum', 'avg', 'max', 'min', 'std', 'none'],
        type=str,
        required=False,
        help="statistic calculation to perform on numerical data in internal nodes [all, sum, avg, max, min, std, none]. If 'none' is chosen, numerical properties won't be summarized or annotated in internal nodes. [default: all]")
    annotation_group.add_argument('--counter-stat',
        default='raw',
        choices=['raw', 'relative', 'none'],
        type=str,
        required=False,
        help="statistic calculation to perform on categorical data in internal nodes, raw count or percentage [raw, relative, none]. If 'none' is chosen, categorical and boolean properties won't be summarized or annotated in internal nodes. [default: raw]")

    acr_group = parser.add_argument_group(title='Ancestral Character Reconstruction arguments',
        description="ACR parameters")
    acr_group.add_argument('--prediction-method',
        default='MPPA',
        choices=['MPPA', 'MAP', 'JOINT', 'DOWNPASS', 'ACCTRAN', 'DELTRAN', 'COPY', 'ALL', 'ML', 'MP'],
        type=str,
        required=False,
        help="prediction method for ACR discrete analysis [default: MPPA]")
    acr_group.add_argument('--model',
        default='F81',
        choices=['JC', 'F81', 'EFT', 'HKY', 'JTT', 'CUSTOM_RATES'],
        type=str,
        required=False,
        help="Evolutionary model for ML methods in ACR discrete analysis [default: F81]")
    acr_group.add_argument('--threads',
        default=4,
        type=int,
        required=False,
        help="Number of threads to use for annotation [default: 4]")

    delta_group = parser.add_argument_group(title='Delta statistic arguments',
        description="Delta statistic parameters")
    delta_group.add_argument('--delta-stats',
        action='store_true',
        required=False,
        help="Calculate the delta statistic for discrete traits in ACR analysis, ONLY for the MPPA or MAP prediction methods. [default: False]")
    delta_group.add_argument('--ent-type',
        default='SE',
        choices=['LSE', 'SE', 'GINI'],
        type=str,
        required=False,
        help="Entropy method to measure the degree of phylogenetic signal between a discrete trait and the phylogeny. "
             "[default: SE] for Shannon Entropy; other options are GINI for Gini impurity and LSE for Linear Shannon Entropy.")
    delta_group.add_argument('--iteration',
        default=10000,
        type=int,
        required=False,
        help="Number of iterations for the delta statistic calculation. [default: 10000]")
    delta_group.add_argument('--lambda0',
        type=float,
        default=0.1,
        help='Rate parameter of the delta statistic calculation.')
    delta_group.add_argument('--se',
        type=float,
        default=0.5,
        help='Standard deviation of the delta statistic calculation.')
    delta_group.add_argument('--thin',
        type=int,
        default=10,
        help='Keep only every xth iterate.')
    delta_group.add_argument('--burn',
        type=int,
        default=100,
        help='Number of burn-in iterates.')

    ls_group = parser.add_argument_group(title='Lineage Specificity Analysis arguments',
        description="ls parameters")
    ls_group.add_argument('--prec-cutoff',
        default=0.95,
        type=float,
        required=False,
        help="Precision cutoff for lineage specificity analysis [default: 0.95]")
    ls_group.add_argument('--sens-cutoff',
        default=0.95,
        type=float,
        required=False,
        help="Sensitivity cutoff for lineage specificity analysis [default: 0.95]")

    group = parser.add_argument_group(title='OUTPUT options',
        description="")
    group.add_argument('-o', '--outdir',
        type=str,
        required=True,
        help="Directory for annotated outputs.")
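
# Illustrative command line (a hedged sketch; it assumes these options are exposed
# through a `treeprofiler annotate` entry point together with a tree-input flag
# from the shared parser, and the file names are hypothetical):
#
#   treeprofiler annotate \
#       --tree example.nw \
#       --metadata metadata.tsv \
#       --num-stat all --counter-stat relative \
#       -o ./annotated_out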

def run_tree_annotate(tree, input_annotated_tree=False,
        metadata_dict={}, node_props=[], columns={}, prop2type={},
        emapper_annotations=None,
        text_prop=[], text_prop_idx=[], multiple_text_prop=[], num_prop=[], num_prop_idx=[],
        bool_prop=[], bool_prop_idx=[], prop2type_file=None, alignment=None, emapper_pfam=None,
        emapper_smart=None, counter_stat='raw', num_stat='all', column2method={},
        taxadb='GTDB', gtdb_version=None, taxa_dump=None, taxon_column=None,
        taxon_delimiter='', taxa_field=0, ignore_unclassified=False,
        rank_limit=None, pruned_by=None,
        acr_discrete_columns=None, prediction_method="MPPA", model="F81",
        delta_stats=False, ent_type="SE",
        iteration=100, lambda0=0.1, se=0.5, thin=10, burn=100,
        ls_columns=None, prec_cutoff=0.95, sens_cutoff=0.95,
        threads=1, outdir='./'):
    total_color_dict = []
    layouts = []
    level = 1  # level 1 is the leaf name

    if emapper_annotations:
        emapper_metadata_dict, emapper_node_props, emapper_columns = parse_emapper_annotations(emapper_annotations)
        metadata_dict.update(emapper_metadata_dict)
        node_props.extend(emapper_node_props)
        columns.update(emapper_columns)
        prop2type.update({
            'name': str,
            'dist': float,
            'support': float,
            'seed_ortholog': str,
            'evalue': float,
            'score': float,
            'eggNOG_OGs': list,
            'max_annot_lvl': str,
            'COG_category': str,
            'Description': str,
            'Preferred_name': str,
            'GOs': list,
            'EC': str,
            'KEGG_ko': list,
            'KEGG_Pathway': list,
            'KEGG_Module': list,
            'KEGG_Reaction': list,
            'KEGG_rclass': list,
            'BRITE': list,
            'KEGG_TC': list,
            'CAZy': list,
            'BiGG_Reaction': list,
            'PFAMs': list
        })

    text_prop = text_prop or []
    multiple_text_prop = multiple_text_prop or []
    num_prop = num_prop or []
    bool_prop = bool_prop or []

    if text_prop_idx:
        index_list = []
        for i in text_prop_idx:
            if i[0] == '[' and i[-1] == ']':
                text_prop_start, text_prop_end = get_range(i)
                for j in range(text_prop_start, text_prop_end + 1):
                    index_list.append(j)
            else:
                index_list.append(int(i))
        text_prop = [node_props[index - 1] for index in index_list]

    if num_prop_idx:
        index_list = []
        for i in num_prop_idx:
            if i[0] == '[' and i[-1] == ']':
                num_prop_start, num_prop_end = get_range(i)
                for j in range(num_prop_start, num_prop_end + 1):
                    index_list.append(j)
            else:
                index_list.append(int(i))
        num_prop = [node_props[index - 1] for index in index_list]

    if bool_prop_idx:
        index_list = []
        for i in bool_prop_idx:
            if i[0] == '[' and i[-1] == ']':
                bool_prop_start, bool_prop_end = get_range(i)
                for j in range(bool_prop_start, bool_prop_end + 1):
                    index_list.append(j)
            else:
                index_list.append(int(i))
        bool_prop = [node_props[index - 1] for index in index_list]

    if prop2type_file:
        prop2type = {}
        with open(prop2type_file, 'r') as f:
            for line in f:
                line = line.rstrip()
                prop, value = line.split('\t')
                prop2type[prop] = eval(value)  # value is a type name written by run(), e.g. 'str'
    else:
        # infer the datatype of each property of each tree node, including internal nodes
        if prop2type:
            for key, dtype in prop2type.items():
                if key in text_prop + multiple_text_prop + num_prop + bool_prop:
                    pass
                # the taxon prop won't be processed as a numerical/text/bool/list value
                elif (taxon_column and key in taxon_column):
                    pass
                else:
                    if dtype == list:
                        multiple_text_prop.append(key)
                    if dtype == str:
                        if key not in multiple_text_prop:
                            text_prop.append(key)
                        else:
                            pass
                    if dtype == float:
                        num_prop.append(key)
                    if dtype == bool:
                        bool_prop.append(key)

        # parameters can overwrite the defaults
        if emapper_annotations:
            text_prop.extend([
                'seed_ortholog',
                'max_annot_lvl',
                'COG_category',
                'EC'
            ])
            num_prop.extend([
                'evalue',
                'score'
            ])
            multiple_text_prop.extend([
                'eggNOG_OGs', 'GOs', 'KEGG_ko', 'KEGG_Pathway',
                'KEGG_Module', 'KEGG_Reaction', 'KEGG_rclass',
                'BRITE', 'KEGG_TC', 'CAZy', 'BiGG_Reaction', 'PFAMs'])

        for prop in text_prop:
            prop2type[prop] = str

        for prop in bool_prop:
            prop2type[prop] = bool

        for prop in multiple_text_prop:
            prop2type[prop] = list

        for prop in num_prop:
            prop2type[prop] = float

        prop2type.update({  # starts with leaf name
            'name': str,
            'dist': float,
            'support': float,
        })

    # load annotations to leaves
    start = time.time()

    # alignment annotation
    if alignment:
        alignment_prop = 'alignment'
        name2seq = parse_fasta(alignment)
        for leaf in tree.leaves():
            leaf.add_prop(alignment_prop, name2seq.get(leaf.name, ''))
        prop2type.update({
            alignment_prop: str
        })

    # domain annotation before other annotations
    if emapper_pfam:
        domain_prop = 'dom_arq'
        if not alignment:
            raise ValueError("Please provide an alignment file using '--alignment' for pfam annotation.")
        annot_tree_pfam_table(tree, emapper_pfam, alignment, domain_prop=domain_prop)
        prop2type.update({
            domain_prop: str
        })
    if emapper_smart:
        domain_prop = 'dom_arq'
        if not alignment:
            raise ValueError("Please provide an alignment file using '--alignment' for smart annotation.")
        annot_tree_smart_table(tree, emapper_smart, alignment, domain_prop=domain_prop)
        prop2type.update({
            domain_prop: str
        })

    # load all metadata to leaf nodes.
    # input_annotated_tree determines whether the input tree is already annotated;
    # if so, the metadata is no longer needed.
    if not input_annotated_tree:
        if taxon_column:  # identify the taxon column as the taxa property from metadata
            annotated_tree = load_metadata_to_tree(tree, metadata_dict, prop2type=prop2type, taxon_column=taxon_column, taxon_delimiter=taxon_delimiter, taxa_field=taxa_field, ignore_unclassified=ignore_unclassified)
        else:
            annotated_tree = load_metadata_to_tree(tree, metadata_dict, prop2type=prop2type)
    else:
        annotated_tree = tree

    end = time.time()
    print('Time for load_metadata_to_tree to run: ', end - start)

    # Ancestral Character Reconstruction analysis
    # data preparation
    if acr_discrete_columns:
        logging.info(f"Performing ACR analysis with characters {acr_discrete_columns} via the {prediction_method} method with the {model} model.......\n")
        # the characters need to be discrete traits
        discrete_traits = text_prop + bool_prop
        for k in acr_discrete_columns:
            if k:
                if k not in discrete_traits:
                    raise ValueError(f"Character {k} is not a discrete trait, please check your input.")
        #############################
        start = time.time()
        acr_discrete_columns_dict = {k: v for k, v in columns.items() if k in acr_discrete_columns}
        acr_results, annotated_tree = run_acr_discrete(annotated_tree, acr_discrete_columns_dict,
            prediction_method=prediction_method, model=model, threads=threads, outdir=outdir)

        # clear extra features
        utils.clear_extra_features([annotated_tree], prop2type.keys())

        # get the observed delta;
        # only the MPPA and MAP methods provide the marginal probabilities needed to calculate delta
        if delta_stats:
            if prediction_method in ['MPPA', 'MAP']:
                logging.info(f"Performing Delta Statistic analysis with characters {acr_discrete_columns}...\n")
                prop2delta = run_delta(acr_results, annotated_tree, ent_type=ent_type,
                    lambda0=lambda0, se=se, sim=iteration, burn=burn, thin=thin,
                    threads=threads)
                for prop, delta_result in prop2delta.items():
                    logging.info(f"Delta statistic of {prop} is: {delta_result}")
                    tree.add_prop(utils.add_suffix(prop, "delta"), delta_result)

                # start calculating the p_value
                logging.info("Calculating the p_value for the delta statistic...")
                # get a copy of the tree
                dump_tree = annotated_tree.copy()
                utils.clear_extra_features([dump_tree], ["name", "dist", "support"])

                prop2array = {}
                for prop in columns.keys():
                    prop2array.update(convert_to_prop_array(metadata_dict, prop))

                prop2delta_array = get_pval(prop2array, dump_tree, acr_discrete_columns_dict,
                    iteration=100, prediction_method=prediction_method, model=model,
                    ent_type=ent_type, lambda0=lambda0, se=se, sim=iteration, burn=burn, thin=thin,
                    threads=threads)

                for prop, delta_array in prop2delta_array.items():
                    p_value = np.sum(np.array(delta_array) > prop2delta[prop]) / len(delta_array)
                    logging.info(f"p_value of {prop} is {p_value}")
                    tree.add_prop(utils.add_suffix(prop, "pval"), p_value)
                    prop2type.update({
                        utils.add_suffix(prop, "pval"): float
                    })

                for prop in acr_discrete_columns:
                    prop2type.update({
                        utils.add_suffix(prop, "delta"): float
                    })
            else:
                logging.warning(f"Delta statistic analysis only supports the MPPA and MAP prediction methods; {prediction_method} is not supported.")
        end = time.time()
        print('Time for acr to run: ', end - start)

    # lineage specificity analysis
    if ls_columns:
        logging.info(f"Performing Lineage Specificity analysis with characters {ls_columns}...\n")
        if all(column in bool_prop for column in ls_columns):
            best_node, qualified_nodes = run_ls(annotated_tree, props=ls_columns,
                precision_cutoff=prec_cutoff, sensitivity_cutoff=sens_cutoff)
            for prop in ls_columns:
                prop2type.update({
                    utils.add_suffix(prop, "prec"): float,
                    utils.add_suffix(prop, "sens"): float,
                    utils.add_suffix(prop, "f1"): float
                })
        else:
            logging.warning(f"Lineage specificity analysis only supports boolean properties; {ls_columns} is not a boolean property.")

    # statistic methods: counter_stat is 'raw' or 'relative'; num_stat is e.g. 'all'

    # merge annotations depending on the column datatype
    start = time.time()

    # choose the summary method based on the datatype
    for prop in text_prop + multiple_text_prop + bool_prop:
        if prop not in column2method:
            column2method[prop] = counter_stat
        if column2method[prop] != 'none':
            prop2type[utils.add_suffix(prop, "counter")] = str

    for prop in num_prop:
        if prop not in column2method:
            column2method[prop] = num_stat
        if column2method[prop] == 'all':
            prop2type[utils.add_suffix(prop, "avg")] = float
            prop2type[utils.add_suffix(prop, "sum")] = float
            prop2type[utils.add_suffix(prop, "max")] = float
            prop2type[utils.add_suffix(prop, "min")] = float
            prop2type[utils.add_suffix(prop, "std")] = float
        elif column2method[prop] == 'none':
            pass
        else:
            prop2type[utils.add_suffix(prop, column2method[prop])] = float

    if not input_annotated_tree:
        node2leaves = annotated_tree.get_cached_content()

        # prepare data for all nodes
        nodes_data = []
        nodes = []
        for node in annotated_tree.traverse("postorder"):
            if not node.is_leaf:
                nodes.append(node)
                node_data = (node, node2leaves[node], text_prop, multiple_text_prop, bool_prop, num_prop, column2method, alignment if 'alignment' in locals() else None, name2seq if 'name2seq' in locals() else None)
                nodes_data.append(node_data)

        # process nodes in parallel if more than one thread is specified
        if threads > 1:
            with Pool(threads) as pool:
                results = pool.map(process_node, nodes_data)
        else:
            # for single-threaded execution, process nodes sequentially
            results = map(process_node, nodes_data)

        # integrate the results back into the tree
        for node, result in zip(nodes, results):
            internal_props, consensus_seq = result
            for key, value in internal_props.items():
                node.add_prop(key, value)
            if consensus_seq:
                node.add_prop(alignment_prop, consensus_seq)
    else:
        pass

    end = time.time()
    print('Time for merge annotations to run: ', end - start)

    # taxa annotations
    start = time.time()
    if taxon_column:
        if not taxadb:
            raise Exception('Please specify which taxa db to use with --taxadb <GTDB|NCBI>')
        else:
            if taxadb == 'GTDB':
                if gtdb_version and taxa_dump:
                    raise Exception('Please specify either a GTDB version or a taxa dump file, not both.')
                if gtdb_version:
                    # get the taxa dump from ete-data
                    gtdbtaxadump = get_gtdbtaxadump(gtdb_version)
                    logging.info(f"Loading GTDB database dump file {gtdbtaxadump}...")
                    GTDBTaxa().update_taxonomy_database(gtdbtaxadump)
                elif taxa_dump:
                    logging.info(f"Loading GTDB database dump file {taxa_dump}...")
                    GTDBTaxa().update_taxonomy_database(taxa_dump)
                else:
                    logging.info("No specific version or dump file provided; using latest GTDB data...")
                    GTDBTaxa().update_taxonomy_database()
            elif taxadb == 'NCBI':
                if taxa_dump:
                    logging.info(f"Loading NCBI database dump file {taxa_dump}...")
                    NCBITaxa().update_taxonomy_database(taxa_dump)
                # else:
                #     NCBITaxa().update_taxonomy_database()

        annotated_tree, rank2values = annotate_taxa(annotated_tree, db=taxadb,
            taxid_attr=taxon_column, sp_delimiter=taxon_delimiter, sp_field=taxa_field,
            ignore_unclassified=ignore_unclassified)

        # evolutionary events annotation
        annotated_tree = annotate_evol_events(annotated_tree, sp_delimiter=taxon_delimiter, sp_field=taxa_field)
        prop2type.update(TAXONOMICDICT)
    else:
        rank2values = {}

    end = time.time()
    print('Time for annotate_taxa to run: ', end - start)

    # prune the tree by rank
    if rank_limit:
        annotated_tree = utils.taxatree_prune(annotated_tree, rank_limit=rank_limit)

    # prune the tree by condition
    if pruned_by:  # needs to be wrapped in quotes
        condition_strings = pruned_by
        annotated_tree = utils.conditional_prune(annotated_tree, condition_strings, prop2type)

    # name internal nodes
    annotated_tree = name_nodes(annotated_tree)
    return annotated_tree, prop2type
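
# Minimal programmatic usage sketch (hedged; the file paths are hypothetical, and
# utils.validate_tree is called with the same positional arguments as in run()
# below, whose exact accepted values are assumptions):
#
#   tree, _ = utils.validate_tree('example.nw', 'newick', 'name')
#   metadata_dict, node_props, columns, prop2type = parse_csv(['metadata.tsv'])
#   annotated_tree, prop2type = run_tree_annotate(
#       tree, metadata_dict=metadata_dict, node_props=node_props,
#       columns=columns, prop2type=prop2type, outdir='./out')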

def run_array_annotate(tree, array_dict, num_stat='none', column2method={}):
    matrix_props = list(array_dict.keys())
    # annotate the leaves
    for node in tree.traverse():
        if node.is_leaf:
            for filename, array in array_dict.items():
                if array.get(node.name):
                    node.add_prop(filename, array.get(node.name))

    # merge annotations to internal nodes
    for node in tree.traverse():
        if not node.is_leaf:
            for prop in matrix_props:
                # get the arrays from the children leaf nodes
                arrays = [child.get_prop(prop) for child in node.leaves() if child.get_prop(prop) is not None]
                if column2method.get(prop) is not None:
                    num_stat = column2method.get(prop)
                stats = compute_matrix_statistics(arrays, num_stat=num_stat)
                if stats:
                    for stat, value in stats.items():
                        node.add_prop(utils.add_suffix(prop, stat), value.tolist())
                        # prop2type[utils.add_suffix(prop, stat)] = float
    return tree

def run(args):
    total_color_dict = []
    layouts = []
    level = 1  # level 1 is the leaf name
    prop2type = {}
    metadata_dict = {}
    column2method = {}

    # check that the input files and the output directory exist
    if not os.path.exists(args.tree):
        raise FileNotFoundError(f"Input tree {args.tree} does not exist.")

    if args.metadata:
        for metadata_file in args.metadata:
            if not os.path.exists(metadata_file):
                raise FileNotFoundError(f"Metadata {metadata_file} does not exist.")

    if not os.path.exists(args.outdir):
        raise FileNotFoundError(f"Output directory {args.outdir} does not exist.")

    # parse the tree
    try:
        tree, eteformat_flag = utils.validate_tree(args.tree, args.input_type, args.internal)
    except utils.TreeFormatError as e:
        print(e)
        sys.exit(1)

    # resolve polytomies
    if args.resolve_polytomy:
        tree.resolve_polytomy()

    # parse csv to metadata table
    start = time.time()
    print("start parsing...")
    # parse metadata
    if args.metadata:  # make a series of metadata tables
        metadata_dict, node_props, columns, prop2type = parse_csv(args.metadata, delimiter=args.metadata_sep,
            no_headers=args.no_headers, duplicate=args.duplicate)
    else:  # annotated tree
        node_props = []
        columns = {}

    if args.data_matrix:
        array_dict = parse_tsv_to_array(args.data_matrix, delimiter=args.metadata_sep)
    end = time.time()
    print('Time for parse_csv to run: ', end - start)

    if args.emapper_annotations:
        emapper_metadata_dict, emapper_node_props, emapper_columns = parse_emapper_annotations(args.emapper_annotations)
        metadata_dict.update(emapper_metadata_dict)
        node_props.extend(emapper_node_props)
        columns.update(emapper_columns)
        prop2type.update({
            'name': str,
            'dist': float,
            'support': float,
            'seed_ortholog': str,
            'evalue': float,
            'score': float,
            'eggNOG_OGs': list,
            'max_annot_lvl': str,
            'COG_category': str,
            'Description': str,
            'Preferred_name': str,
            'GOs': list,
            'EC': str,
            'KEGG_ko': list,
            'KEGG_Pathway': list,
            'KEGG_Module': list,
            'KEGG_Reaction': list,
            'KEGG_rclass': list,
            'BRITE': list,
            'KEGG_TC': list,
            'CAZy': list,
            'BiGG_Reaction': list,
            'PFAMs': list
        })

    # start annotation
    if args.column_summary_method:
        column2method = process_column_summary_methods(args.column_summary_method)

    annotated_tree, prop2type = run_tree_annotate(tree, input_annotated_tree=args.annotated_tree,
        metadata_dict=metadata_dict, node_props=node_props, columns=columns,
        prop2type=prop2type,
        text_prop=args.text_prop, text_prop_idx=args.text_prop_idx,
        multiple_text_prop=args.multiple_text_prop, num_prop=args.num_prop, num_prop_idx=args.num_prop_idx,
        bool_prop=args.bool_prop, bool_prop_idx=args.bool_prop_idx,
        prop2type_file=args.prop2type, alignment=args.alignment,
        emapper_pfam=args.emapper_pfam, emapper_smart=args.emapper_smart,
        counter_stat=args.counter_stat, num_stat=args.num_stat, column2method=column2method,
        taxadb=args.taxadb, gtdb_version=args.gtdb_version,
        taxa_dump=args.taxa_dump, taxon_column=args.taxon_column,
        taxon_delimiter=args.taxon_delimiter, taxa_field=args.taxa_field, ignore_unclassified=args.ignore_unclassified,
        rank_limit=args.rank_limit, pruned_by=args.pruned_by,
        acr_discrete_columns=args.acr_discrete_columns,
        prediction_method=args.prediction_method, model=args.model,
        delta_stats=args.delta_stats, ent_type=args.ent_type,
        iteration=args.iteration, lambda0=args.lambda0, se=args.se,
        thin=args.thin, burn=args.burn,
        ls_columns=args.ls_columns, prec_cutoff=args.prec_cutoff, sens_cutoff=args.sens_cutoff,
        threads=args.threads, outdir=args.outdir)

    if args.data_matrix:
        annotated_tree = run_array_annotate(annotated_tree, array_dict, num_stat=args.num_stat, column2method=column2method)

    if args.outdir:
        base = os.path.splitext(os.path.basename(args.tree))[0]
        out_newick = base + '_annotated.nw'
        out_prop2type = base + '_prop2type.txt'
        out_ete = base + '_annotated.ete'
        out_tsv = base + '_annotated.tsv'

        ### out newick
        annotated_tree.write(outfile=os.path.join(args.outdir, out_newick), props=None,
            parser=utils.get_internal_parser(args.internal), format_root_node=True)

        ### out prop2type
        with open(os.path.join(args.outdir, out_prop2type), "w") as f:
            for key, value in prop2type.items():
                f.write("{}\t{}\n".format(key, value.__name__))

        ### out ete
        with open(os.path.join(args.outdir, out_ete), 'w') as f:
            f.write(b64pickle.dumps(annotated_tree, encoder='pickle', pack=False))

        ### out tsv
        prop_keys = list(prop2type.keys())
        if args.taxon_column:
            prop_keys.extend(list(TAXONOMICDICT.keys()))
        if args.annotated_tree:
            tree2table(annotated_tree, internal_node=True, props=None, outfile=os.path.join(args.outdir, out_tsv))
        else:
            tree2table(annotated_tree, internal_node=True, props=prop_keys, outfile=os.path.join(args.outdir, out_tsv))

    # if args.outtsv:
    #     tree2table(annotated_tree, internal_node=True, outfile=args.outtsv)
    return

def check_missing(input_string):
    """
    Define missing as:
    1) one or more non-word characters making up the entire string,
    2) the exact strings "none", "None", "null", "Null" or "NaN",
    3) an empty string (zero characters).
    """
    pattern = r'^(?:\W+|none|None|null|Null|NaN|)$'
    if input_string is None:
        return True
    elif re.match(pattern, input_string):
        # the input contains only non-alphanumeric characters, a null spelling, or is empty
        return True
    else:
        return False
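
# Behaviour of check_missing on typical cell values (derived from the regex above):
#   >>> check_missing(None), check_missing(''), check_missing('none'), check_missing('-')
#   (True, True, True, True)
#   >>> check_missing('GO:0000003')
#   False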

def check_tar_gz(file_path):
    try:
        with tarfile.open(file_path, 'r:gz') as tar:
            return True
    except tarfile.ReadError:
        return False

def parse_csv(input_files, delimiter='\t', no_headers=False, duplicate=False):
    """
    Takes a tsv table as input and returns:
        metadata, a dictionary of dictionaries with each node's metadata
        node_props, a list of property names (the column names of the metadata table)
        columns, a dictionary of property names and their values
    """
    metadata = {}
    columns = defaultdict(list)
    prop2type = {}

    def update_metadata(reader, node_header):
        for row in reader:
            nodename = row[node_header]
            del row[node_header]
            for k, v in row.items():  # replace missing values
                if check_missing(v):
                    row[k] = 'NaN'
                else:
                    row[k] = v
            if nodename in metadata.keys():
                for prop, value in row.items():
                    if duplicate:
                        if prop in metadata[nodename]:
                            existing_value = metadata[nodename][prop]
                            new_value = ','.join([existing_value, value])
                            metadata[nodename][prop] = new_value
                            columns[prop].append(new_value)
                        else:
                            metadata[nodename][prop] = value
                            columns[prop].append(value)
                    else:
                        metadata[nodename][prop] = value
                        columns[prop].append(value)
            else:
                metadata[nodename] = dict(row)
                for (prop, value) in row.items():  # go over each column name and value
                    columns[prop].append(value)  # append the value to the list for that column name

    def update_prop2type(node_props):
        for prop in node_props:
            if set(columns[prop]) == {'NaN'}:
                prop2type[prop] = str
            else:
                dtype = infer_dtype(columns[prop])
                prop2type[prop] = dtype

    for input_file in input_files:
        # check the file type
        if check_tar_gz(input_file):
            with tarfile.open(input_file, 'r:gz') as tar:
                for member in tar.getmembers():
                    if member.isfile() and member.name.endswith('.tsv'):
                        with tar.extractfile(member) as tsv_file:
                            tsv_text = tsv_file.read().decode('utf-8').splitlines()
                            if no_headers:
                                fields_len = len(tsv_text[0].split(delimiter))
                                headers = ['col' + str(i) for i in range(fields_len)]
                                reader = csv.DictReader(tsv_text, delimiter=delimiter, fieldnames=headers)
                            else:
                                reader = csv.DictReader(tsv_text, delimiter=delimiter)
                                headers = reader.fieldnames
                            node_header, node_props = headers[0], headers[1:]
                            update_metadata(reader, node_header)
                update_prop2type(node_props)
        else:
            with open(input_file, 'r') as f:
                # read the first line to determine the number of fields
                first_line = next(f)
                fields_len = len(first_line.split(delimiter))
                # reset the file pointer to the beginning
                f.seek(0)
                if no_headers:
                    # generate header names
                    headers = ['col' + str(i) for i in range(fields_len)]
                    # create a CSV reader with the generated headers
                    reader = csv.DictReader(f, delimiter=delimiter, fieldnames=headers)
                else:
                    # use the existing headers in the file
                    reader = csv.DictReader(f, delimiter=delimiter)
                    headers = reader.fieldnames
                node_header, node_props = headers[0], headers[1:]
                # same merging logic as the tar.gz branch
                update_metadata(reader, node_header)
            update_prop2type(node_props)

    return metadata, node_props, columns, prop2type
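
# Example (hedged sketch; a hypothetical metadata.tsv with headers "name  x  y"
# and one row per leaf):
#
#   metadata, node_props, columns, prop2type = parse_csv(['metadata.tsv'])
#   # metadata   -> {'leafA': {'x': '1.0', 'y': 'red'}, ...}   (values stay strings)
#   # node_props -> ['x', 'y']
#   # columns    -> {'x': ['1.0', ...], 'y': ['red', ...]}
#   # prop2type  -> {'x': float, 'y': str}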

def parse_tsv_to_array(input_files, delimiter='\t', no_headers=True):
    """
    Parses each TSV file into a dictionary with the first item of each row as
    the key and the rest of the items in the row as the value list.

    :param input_files: Paths to the TSV files to be parsed.
    :return: A dictionary keyed by file basename, with values mapping each
             row's first item to the list of its remaining items.
    """
    is_float = True
    matrix2array = {}
    for input_file in input_files:
        leaf2array = {}
        prefix = os.path.basename(input_file)
        with open(input_file, 'r') as file:
            for line in file:
                # split each line by the delimiter; strip removes the trailing newline
                row = line.strip().split(delimiter)
                node = row[0]
                value = row[1:]  # the rest of the items as the value
                # replace empty strings with np.nan
                value_list = [np.nan if x == '' else x for x in value]
                try:
                    np_array = np.array(value_list).astype(np.float64)
                    leaf2array[node] = np_array.tolist()
                except ValueError:
                    # handle the case where conversion fails
                    print(f"Warning: Non-numeric data found in {prefix} for node {node}. Skipping.")
                    leaf2array[node] = None
                    is_float = False
        matrix2array[prefix] = leaf2array
    return matrix2array
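
# Example (hedged; a hypothetical headerless matrix.tsv with rows like "leafA\t1\t2"):
#
#   parse_tsv_to_array(['matrix.tsv'])
#   # -> {'matrix.tsv': {'leafA': [1.0, 2.0], ...}}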

def process_column_summary_methods(column_summary_methods):
    column_methods = {}
    if column_summary_methods:
        for entry in column_summary_methods:
            try:
                column, method = entry.split('=')
                column_methods[column] = method
            except ValueError:
                raise ValueError(f"Invalid format for --column-summary-method: '{entry}'. Expected format: ColumnName=Method")
    return column_methods
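
# Example (derived from the parsing above):
#   >>> process_column_summary_methods(['abundance=avg', 'habitat=relative'])
#   {'abundance': 'avg', 'habitat': 'relative'}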

def get_comma_separated_values(lst):
    for item in lst:
        if isinstance(item, str) and any(',' in x for x in item.split()):
            return True
    return False

def can_convert_to_bool(column):
    true_values = {'true', 't', 'yes', 'y', '1'}
    false_values = {'false', 'f', 'no', 'n', '0'}
    ignore_values = {'nan', 'none', ''}  # add other representations of NaN as needed

    # sets holding the observed spellings of true and false values
    true_representations = set()
    false_representations = set()

    for value in column:
        str_val = str(value).strip()
        if str_val.lower() in ignore_values:
            continue  # skip this value
        if str_val.lower() in true_values:
            true_representations.add(str_val)
        elif str_val.lower() in false_values:
            false_representations.add(str_val)
        else:
            return False

    # check that the true values and the false values each use at most one spelling
    return len(true_representations) <= 1 and len(false_representations) <= 1
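
# Examples (derived from the sets above): a column is boolean only if each truth
# value uses a single consistent spelling, with NaN-like entries ignored:
#   >>> can_convert_to_bool(['yes', 'no', 'yes', 'nan'])
#   True
#   >>> can_convert_to_bool(['yes', 'Y', 'no'])   # two spellings of True
#   False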

def convert_column_data(column, np_dtype):
    try:
        data = np.array(column).astype(np_dtype)
        return np_dtype
    except ValueError as e:
        return None

def convert_to_prop_array(metadata_dict, prop):
    """
    Convert a dictionary of metadata to a structured array format.

    Parameters:
        metadata_dict (dict): The original dictionary containing metadata.
        prop (str): The property to extract from the metadata.

    Returns:
        dict: A dictionary with the property as key and a structured array as value, e.g.
        {"leaf1": {"prop": "value1"}, "leaf2": {"prop": "value2"}}
        becomes
        {"prop": [["leaf1", "leaf2"], ["value1", "value2"]]}
    """
    prop_array = {prop: [[], []]}
    for leaf, value in metadata_dict.items():
        prop_array[prop][0].append(leaf)  # append the leaf name
        prop_array[prop][1].append(value.get(prop, None))  # append the property value; handle missing values
    return prop_array

def convert_back_to_original(prop2array):
    """
    Convert the structured array format back to the original dictionary format.

    Parameters:
        prop2array (dict): The structured array format dictionary.

    Returns:
        dict: The original format of the data.
    """
    metadata_dict = {}
    for key in prop2array:
        identifiers, values = prop2array[key]
        for identifier, value in zip(identifiers, values):
            if identifier not in metadata_dict:
                metadata_dict[identifier] = {}
            metadata_dict[identifier][key] = value
    return metadata_dict
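
# Round-trip example (derived from the two converters above):
#   >>> p = convert_to_prop_array({'leaf1': {'x': '1'}, 'leaf2': {'x': '2'}}, 'x')
#   >>> p
#   {'x': [['leaf1', 'leaf2'], ['1', '2']]}
#   >>> convert_back_to_original(p)
#   {'leaf1': {'x': '1'}, 'leaf2': {'x': '2'}}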

def infer_dtype(column):
    if get_comma_separated_values(column):
        return list
    elif can_convert_to_bool(column):
        return bool
    else:
        dtype_dict = {
            float: np.float64,
            str: np.str_
        }
        for dtype, np_dtype in dtype_dict.items():
            result = convert_column_data(column, np_dtype)
            if result is not None:
                # successful inference; exit the loop
                return dtype
        return None
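
# Examples (derived from the checks above):
#   >>> infer_dtype(['GO:1,GO:2', 'GO:3'])   # comma-separated values
#   <class 'list'>
#   >>> infer_dtype(['yes', 'no'])
#   <class 'bool'>
#   >>> infer_dtype(['1.5', '2'])
#   <class 'float'>
#   >>> infer_dtype(['red', 'blue'])
#   <class 'str'>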

def load_metadata_to_tree(tree, metadata_dict, prop2type={}, taxon_column=None, taxon_delimiter='', taxa_field=0, ignore_unclassified=False):
    multi_text_separator = ','
    common_ancestor_separator = '||'

    name2node = defaultdict(list)
    # preload all named nodes to save time instead of searching the tree
    for node in tree.traverse():
        if node.name:
            name2node[node.name].append(node)

    def add_prop_to_node(target_node, key, value):
        # taxa
        if key == taxon_column:
            if taxon_delimiter:
                taxon_prop = value.split(taxon_delimiter)[taxa_field]
            else:
                taxon_prop = value
            target_node.add_prop(key, taxon_prop)
        # numerical
        elif key in prop2type and prop2type[key] == float:
            try:
                float_value = float(value)
                if math.isnan(float_value):
                    target_node.add_prop(key, 'NaN')
                else:
                    target_node.add_prop(key, float_value)
            except (ValueError, TypeError):
                target_node.add_prop(key, 'NaN')
        # categorical: list
        elif key in prop2type and prop2type[key] == list:
            value_list = value.split(multi_text_separator)
            target_node.add_prop(key, value_list)
        # categorical: str
        else:
            target_node.add_prop(key, value)

    # load all metadata to leaf nodes
    for node, props in metadata_dict.items():
        if node in name2node.keys():
            for target_node in name2node[node]:
                for key, value in props.items():
                    add_prop_to_node(target_node, key, value)
        elif common_ancestor_separator in node:
            # rows named "leafA||leafB" annotate the common ancestor of those leaves
            children = node.split(common_ancestor_separator)
            target_node = tree.common_ancestor(children)
            for key, value in props.items():
                add_prop_to_node(target_node, key, value)

    return tree
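
# Usage sketch (hedged; Tree comes from ete4, the newick string is hypothetical,
# and '||' keys follow the common-ancestor convention handled above):
#
#   t = Tree('((leafA,leafB),leafC);')
#   load_metadata_to_tree(t, {'leafA': {'x': '1'}, 'leafA||leafB': {'clade': 'AB'}},
#                         prop2type={'x': float})
#   # leafA gets x=1.0; the common ancestor of leafA and leafB gets clade='AB'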

def process_node(node_data):
    node, node_leaves, text_prop, multiple_text_prop, bool_prop, num_prop, column2method, alignment, name2seq = node_data
    internal_props = {}

    # process text, multitext, bool and num properties
    if text_prop:
        internal_props_text = merge_text_annotations(node_leaves, text_prop, column2method)
        internal_props.update(internal_props_text)

    if multiple_text_prop:
        internal_props_multi = merge_multitext_annotations(node_leaves, multiple_text_prop, column2method)
        internal_props.update(internal_props_multi)

    if bool_prop:
        internal_props_bool = merge_text_annotations(node_leaves, bool_prop, column2method)
        internal_props.update(internal_props_bool)

    if num_prop:
        internal_props_num = merge_num_annotations(node_leaves, num_prop, column2method)
        if internal_props_num:
            internal_props.update(internal_props_num)

    # generate the consensus sequence
    consensus_seq = None
    if alignment and name2seq is not None:  # check alignment and name2seq together
        aln_sum = column2method.get('alignment')
        if aln_sum is None or aln_sum != 'none':
            matrix_string = build_matrix_string(node, name2seq)  # assuming 'name2seq' is accessible here
            consensus_seq = utils.get_consensus_seq(matrix_string, threshold=0.7)

    return internal_props, consensus_seq

def merge_text_annotations(nodes, target_props, column2method):
    pair_separator = "--"
    item_separator = "||"
    internal_props = {}
    for target_prop in target_props:
        counter_stat = column2method.get(target_prop, "raw")
        if counter_stat == 'raw':
            prop_list = utils.children_prop_array_missing(nodes, target_prop)
            internal_props[utils.add_suffix(target_prop, 'counter')] = item_separator.join(
                [utils.add_suffix(str(key), value, pair_separator) for key, value in sorted(dict(Counter(prop_list)).items())])
        elif counter_stat == 'relative':
            prop_list = utils.children_prop_array_missing(nodes, target_prop)
            counter_line = []
            total = sum(dict(Counter(prop_list)).values())
            for key, value in sorted(dict(Counter(prop_list)).items()):
                rel_val = '{0:.2f}'.format(float(value) / total)
                counter_line.append(utils.add_suffix(key, rel_val, pair_separator))
            internal_props[utils.add_suffix(target_prop, 'counter')] = item_separator.join(counter_line)
        else:
            # invalid stat method; skip
            pass
    return internal_props
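
# Example of the encoded counters (assuming utils.add_suffix joins its two
# arguments with the given separator): leaves with values ['vhs', 'vhs', 'dsdna']
# yield "dsdna--1||vhs--2" under 'raw', or "dsdna--0.33||vhs--0.67" under 'relative'.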

def merge_multitext_annotations(nodes, target_props, column2method):
    # multiple text values are comma-separated, e.g. 'GO:0000003,GO:0000902,GO:0000904'
    pair_separator = "--"
    item_separator = "||"

    internal_props = {}
    for target_prop in target_props:
        counter_stat = column2method.get(target_prop, "raw")
        if counter_stat == 'raw':
            prop_list = utils.children_prop_array(nodes, target_prop)
            multi_prop_list = []
            for elements in prop_list:
                for j in elements:
                    multi_prop_list.append(j)
            internal_props[utils.add_suffix(target_prop, 'counter')] = item_separator.join(
                [utils.add_suffix(str(key), value, pair_separator) for key, value in sorted(dict(Counter(multi_prop_list)).items())])
        elif counter_stat == 'relative':
            prop_list = utils.children_prop_array(nodes, target_prop)
            multi_prop_list = []
            for elements in prop_list:
                for j in elements:
                    multi_prop_list.append(j)
            counter_line = []
            total = sum(dict(Counter(multi_prop_list)).values())
            for key, value in sorted(dict(Counter(multi_prop_list)).items()):
                rel_val = '{0:.2f}'.format(float(value) / total)
                counter_line.append(utils.add_suffix(key, rel_val, pair_separator))
            internal_props[utils.add_suffix(target_prop, 'counter')] = item_separator.join(counter_line)
        else:
            # invalid stat method; skip
            pass
    return internal_props

def merge_num_annotations(nodes, target_props, column2method):
    internal_props = {}
    for target_prop in target_props:
        num_stat = column2method.get(target_prop, None)
        if num_stat != 'none':
            if target_prop != 'dist' and target_prop != 'support':
                prop_array = np.array(utils.children_prop_array(nodes, target_prop), dtype=np.float64)
                prop_array = prop_array[~np.isnan(prop_array)]  # remove nan data

                if prop_array.any():
                    n, (smin, smax), sm, sv, ss, sk = stats.describe(prop_array)
                    # stats.describe returns the variance; take the square root for the standard deviation
                    sstd = np.sqrt(sv) if not math.isnan(sv) else 0
                    if num_stat == 'all':
                        internal_props[utils.add_suffix(target_prop, 'avg')] = sm
                        internal_props[utils.add_suffix(target_prop, 'sum')] = np.sum(prop_array)
                        internal_props[utils.add_suffix(target_prop, 'max')] = smax
                        internal_props[utils.add_suffix(target_prop, 'min')] = smin
                        internal_props[utils.add_suffix(target_prop, 'std')] = sstd
                    elif num_stat == 'avg':
                        internal_props[utils.add_suffix(target_prop, 'avg')] = sm
                    elif num_stat == 'sum':
                        internal_props[utils.add_suffix(target_prop, 'sum')] = np.sum(prop_array)
                    elif num_stat == 'max':
                        internal_props[utils.add_suffix(target_prop, 'max')] = smax
                    elif num_stat == 'min':
                        internal_props[utils.add_suffix(target_prop, 'min')] = smin
                    elif num_stat == 'std':
                        internal_props[utils.add_suffix(target_prop, 'std')] = sstd
                    else:
                        # invalid stat method; skip
                        pass
            else:
                pass

    if internal_props:
        return internal_props
    else:
        return None

def compute_matrix_statistics(matrix, num_stat=None):
    """
    Computes the requested statistics for the given matrix.

    :param matrix: A list of lists representing the matrix.
    :param num_stat: Which statistics to compute: "avg", "max", "min", "sum", "std", "all", or "none".
    :return: A dictionary with the requested statistics, or an empty dictionary.
    """
    stats = {}

    if num_stat == 'none':
        return stats  # return an empty dictionary if no statistics are requested

    # replace None with 0 before creating the array
    if matrix is not None:
        cleaned_matrix = [[0 if x is None else x for x in row] for row in matrix]
        np_matrix = np.array(cleaned_matrix, dtype=np.float64)
    else:
        return {}  # return an empty dictionary if there is no matrix

    if np_matrix.size == 0:
        return {}  # return an empty dictionary if the matrix is empty

    if np_matrix.ndim == 2 and np_matrix.shape[1] > 0:
        available_stats = {
            'avg': np_matrix.mean(axis=0),
            'max': np_matrix.max(axis=0),
            'min': np_matrix.min(axis=0),
            'sum': np_matrix.sum(axis=0),
            'std': np_matrix.std(axis=0)
        }
        if num_stat == "all":
            return available_stats
        elif num_stat in available_stats:
            stats[num_stat] = available_stats[num_stat]
        else:
            raise ValueError(f"Unsupported stat '{num_stat}'. Supported stats are 'avg', 'max', 'min', 'sum', 'std', or 'all'.")
    return stats
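
# Example (column-wise statistics over leaf arrays):
#   >>> compute_matrix_statistics([[1, 2], [3, 4]], num_stat='avg')
#   {'avg': array([2., 3.])}
#   >>> compute_matrix_statistics([[1, 2], [3, 4]], num_stat='none')
#   {}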

def name_nodes(tree):
    for i, node in enumerate(tree.traverse("postorder")):
        if not node.name or node.name == 'None':
            if not node.is_root:
                node.name = 'N' + str(i)
            else:
                node.name = 'Root'
    return tree
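
# Naming sketch (derived from the loop above): unnamed internal nodes get
# 'N' + their postorder index and an unnamed root becomes 'Root'. For
# Tree('((a,b),c);') the leaves keep their names, the inner node becomes 'N2'
# and the root becomes 'Root'.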

def gtdb_accession_to_taxid(accession):
    """Given a GenBank (GCA) or RefSeq (GCF) accession, returns the GTDB-prefixed accession."""
    if accession.startswith('GCA'):
        prefix = 'GB_'
        return prefix + accession
    elif accession.startswith('GCF'):
        prefix = 'RS_'
        return prefix + accession
    else:
        return accession
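
# Examples (derived from the prefix rules above):
#   >>> gtdb_accession_to_taxid('GCA_002356115.1')
#   'GB_GCA_002356115.1'
#   >>> gtdb_accession_to_taxid('GCF_000009045.1')
#   'RS_GCF_000009045.1'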

def get_gtdbtaxadump(version):
    """
    Download the GTDB taxonomy dump.
    """
    url = f"https://github.com/etetoolkit/ete-data/raw/main/gtdb_taxonomy/gtdb{version}/gtdb{version}dump.tar.gz"
    fname = f"gtdb{version}dump.tar.gz"
    logging.info(f'Downloading GTDB taxa dump {fname} from {url} ...')
    with open(fname, 'wb') as f:
        f.write(requests.get(url).content)
    return fname

def annotate_taxa(tree, db="GTDB", taxid_attr="name", sp_delimiter='.', sp_field=0, ignore_unclassified=False):
    global rank2values
    logging.info(f"\n==============Annotating tree with the {db} taxonomic database============")

    def return_spcode_ncbi(leaf):
        try:
            return leaf.props.get(taxid_attr).split(sp_delimiter)[sp_field]
        except (IndexError, ValueError):
            return leaf.props.get(taxid_attr)

    def return_spcode_gtdb(leaf):
        try:
            if sp_delimiter:
                species_attribute = leaf.props.get(taxid_attr).split(sp_delimiter)[sp_field]
                return gtdb_accession_to_taxid(species_attribute)
            else:
                return gtdb_accession_to_taxid(leaf.props.get(taxid_attr))
        except (IndexError, ValueError):
            return gtdb_accession_to_taxid(leaf.props.get(taxid_attr))

    def merge_dictionaries(dict_ranks, dict_names):
        """
        Merges two dictionaries into one where the key is the rank from dict_ranks
        and the value is the corresponding name from dict_names.

        :param dict_ranks: Dictionary where the key is a numeric id and the value is a rank.
        :param dict_names: Dictionary where the key is the same numeric id and the value is a name.
        :return: A new dictionary where the rank is the key and the name is the value.
        """
        merged_dict = {}
        for key, rank in dict_ranks.items():
            if key in dict_names:  # ensure the key exists in both dictionaries
                # keep the first name seen for a real rank; 'no rank' entries are always updated
                if rank not in merged_dict or rank == 'no rank':
                    merged_dict[rank] = dict_names[key]
        return merged_dict

    if db == "GTDB":
        gtdb = GTDBTaxa()
        tree.set_species_naming_function(return_spcode_gtdb)
        gtdb.annotate_tree(tree, taxid_attr="species", ignore_unclassified=ignore_unclassified)
        suffix_to_rank_dict = {
            'd__': 'superkingdom',  # Domain or Superkingdom
            'p__': 'phylum',
            'c__': 'class',
            'o__': 'order',
            'f__': 'family',
            'g__': 'genus',
            's__': 'species'
        }
        gtdb_re = r'^(GB_GCA_[0-9]+\.[0-9]+|RS_GCF_[0-9]+\.[0-9]+)'
        for n in tree.traverse():
            # in case something is missed
            if n.props.get('named_lineage'):
                lca_dict = {}
                for taxa in n.props.get("named_lineage"):
                    if re.match(gtdb_re, taxa):
                        potential_rank = 'subspecies'
                        lca_dict[potential_rank] = n.props.get("sci_name")
                    else:
                        potential_rank = suffix_to_rank_dict.get(taxa[:3], None)
                        if potential_rank:
                            lca_dict[potential_rank] = taxa
                n.add_prop("lca", utils.dict_to_string(lca_dict))
    elif db == "NCBI":
        ncbi = NCBITaxa()
        # extract species codes from leaf names
        tree.set_species_naming_function(return_spcode_ncbi)
        ncbi.annotate_tree(tree, taxid_attr="species", ignore_unclassified=ignore_unclassified)
        for n in tree.traverse():
            if n.props.get('lineage'):
                lca_dict = {}
                lineage2rank = ncbi.get_rank(n.props.get("lineage"))
                taxid2name = ncbi.get_taxid_translator(n.props.get("lineage"))
                lca_dict = merge_dictionaries(lineage2rank, taxid2name)
                n.add_prop("named_lineage", list(taxid2name.values()))
                n.add_prop("lca", utils.dict_to_string(lca_dict))

    # collect the sci_name values observed per rank
    rank2values = defaultdict(list)
    for n in tree.traverse():
        if db == 'NCBI':
            n.del_prop('_speciesFunction')
        if n.props.get('rank') and n.props.get('rank') != 'Unknown':
            rank2values[n.props.get('rank')].append(n.props.get('sci_name', ''))

    return tree, rank2values

def annotate_evol_events(tree, sp_delimiter='.', sp_field=0):
    def return_spcode(leaf):
        try:
            return leaf.name.split(sp_delimiter)[sp_field]
        except (IndexError, ValueError):
            return leaf.name

    tree.set_species_naming_function(return_spcode)
    node2species = tree.get_cached_content('species')
    for n in tree.traverse():
        n.props['species'] = node2species[n]
        if len(n.children) == 2:
            dup_sp = node2species[n.children[0]] & node2species[n.children[1]]
            if dup_sp:
                n.props['evoltype'] = 'D'
                n.props['dup_sp'] = ','.join(dup_sp)
                n.props['dup_percent'] = round(len(dup_sp) / len(node2species[n]), 3) * 100
            else:
                n.props['evoltype'] = 'S'
        n.del_prop('_speciesFunction')
    return tree

def get_range(input_range):
    column_range = input_range[input_range.find("[") + 1:input_range.find("]")]
    column_start, column_end = [int(i) for i in column_range.split('-')]
    return column_start, column_end
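
# Example (derived from the slicing above):
#   >>> get_range('[1-5]')
#   (1, 5)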

def parse_emapper_annotations(input_file, delimiter='\t', no_headers=False):
    metadata = {}
    columns = defaultdict(list)
    prop2type = {}
    headers = ["#query", "seed_ortholog", "evalue", "score", "eggNOG_OGs",
               "max_annot_lvl", "COG_category", "Description", "Preferred_name", "GOs",
               "EC", "KEGG_ko", "KEGG_Pathway", "KEGG_Module", "KEGG_Reaction", "KEGG_rclass",
               "BRITE", "KEGG_TC", "CAZy", "BiGG_Reaction", "PFAMs"]

    with open(input_file, 'r') as f:
        # skip lines starting with '##'
        filtered_lines = (line for line in f if not line.startswith('##'))
        if no_headers:
            reader = csv.DictReader(filtered_lines, delimiter=delimiter, fieldnames=headers)
        else:
            reader = csv.DictReader(filtered_lines, delimiter=delimiter)

        node_header, node_props = headers[0], headers[1:]
        for row in reader:
            nodename = row[node_header]
            del row[node_header]
            for k, v in row.items():  # replace missing values
                row[k] = 'NaN' if check_missing(v) else v
            metadata[nodename] = dict(row)
            for k, v in row.items():  # go over each column name and value
                columns[k].append(v)  # append the value to the list for column name k

    return metadata, node_props, columns
def annot_tree_pfam_table(post_tree, pfam_table, alg_fasta, domain_prop='dom_arq'):
    pair_delimiter = "@"
    item_separator = "||"
    fasta = SeqGroup(alg_fasta)  # aligned fasta
    raw2alg = defaultdict(dict)
    # Map each sequence's raw residue positions to alignment columns
    for name, seq, _ in fasta:
        p_raw = 1
        for p_alg, a in enumerate(seq, 1):
            if a != '-':
                raw2alg[name][p_raw] = p_alg
                p_raw += 1

    seq2doms = defaultdict(list)
    with open(pfam_table) as f_in:
        for line in f_in:
            if not line.startswith('#'):
                info = line.strip().split('\t')
                seq_name = info[0]
                dom_name = info[1]
                dom_start = int(info[7])
                dom_end = int(info[8])
                if raw2alg.get(seq_name):
                    try:
                        trans_dom_start = raw2alg[seq_name][dom_start]
                        trans_dom_end = raw2alg[seq_name][dom_end]
                    except KeyError:
                        raise KeyError(f"Cannot find position {dom_start} or {dom_end} in {seq_name}")
                    dom_info_string = pair_delimiter.join([dom_name, str(trans_dom_start), str(trans_dom_end)])
                    seq2doms[seq_name].append(dom_info_string)

    for l in post_tree:
        if l.name in seq2doms:
            domains = seq2doms[l.name]
            domains_string = item_separator.join(domains)
            l.add_prop(domain_prop, domains_string)

    # Internal nodes inherit the domain string of their closest leaf
    for n in post_tree.traverse():
        if not n.is_leaf:
            closest_leaf_domains = n.get_closest_leaf()[0].props.get(domain_prop, 'none@none@none')
            n.add_prop(domain_prop, closest_leaf_domains)
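
# Illustrative sketch (hypothetical sequence) of the raw-to-alignment
# coordinate translation performed by annot_tree_pfam_table and
# annot_tree_smart_table: gaps shift the alignment column of each residue.
def _example_raw_to_alignment_coords():
    seq = 'M-KT-'
    raw2alg, p_raw = {}, 1
    for p_alg, a in enumerate(seq, 1):
        if a != '-':
            raw2alg[p_raw] = p_alg
            p_raw += 1
    assert raw2alg == {1: 1, 2: 3, 3: 4}
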
def annot_tree_smart_table(post_tree, smart_table, alg_fasta, domain_prop='dom_arq'):
    pair_delimiter = "@"
    item_separator = "||"
    fasta = SeqGroup(alg_fasta)  # aligned fasta
    raw2alg = defaultdict(dict)
    # Map each sequence's raw residue positions to alignment columns
    for name, seq, _ in fasta:
        p_raw = 1
        for p_alg, a in enumerate(seq, 1):
            if a != '-':
                raw2alg[name][p_raw] = p_alg
                p_raw += 1

    seq2doms = defaultdict(list)
    with open(smart_table) as f_in:
        for line in f_in:
            if not line.startswith('#'):
                info = line.strip().split('\t')
                seq_name = info[0]
                dom_name = info[1]
                dom_start = int(info[2])
                dom_end = int(info[3])
                if raw2alg.get(seq_name):
                    try:  # same guard as in annot_tree_pfam_table
                        trans_dom_start = raw2alg[seq_name][dom_start]
                        trans_dom_end = raw2alg[seq_name][dom_end]
                    except KeyError:
                        raise KeyError(f"Cannot find position {dom_start} or {dom_end} in {seq_name}")
                    dom_info_string = pair_delimiter.join([dom_name, str(trans_dom_start), str(trans_dom_end)])
                    seq2doms[seq_name].append(dom_info_string)

    for l in post_tree:
        if l.name in seq2doms:
            domains = seq2doms[l.name]
            domains_string = item_separator.join(domains)
            l.add_prop(domain_prop, domains_string)

    # Internal nodes inherit the domain string of their closest leaf
    for n in post_tree.traverse():
        if not n.is_leaf:
            closest_leaf_domains = n.get_closest_leaf()[0].props.get(domain_prop, 'none@none@none')
            n.add_prop(domain_prop, closest_leaf_domains)
def parse_fasta(fastafile):
    fasta_dict = {}
    with open(fastafile, 'r') as f:
        head = ''
        seq = ''
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                # Flush the previous record before starting a new one
                if seq:
                    fasta_dict[head] = seq
                    seq = ''
                head = line[1:]
            else:
                seq += line
        if head:
            fasta_dict[head] = seq
    return fasta_dict
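
# Usage sketch (hypothetical path and contents): multi-line sequences are
# concatenated under their header line.
def _example_parse_fasta(tmp_path='/tmp/toy.fasta'):
    with open(tmp_path, 'w') as f:
        f.write(">seqA\nMKT\nLLI\n>seqB\nGGS\n")
    assert parse_fasta(tmp_path) == {'seqA': 'MKTLLI', 'seqB': 'GGS'}
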
def _worker_function(iteration_data):
    # Unpack the data needed for one shuffling iteration
    (prop2array, dump_tree, acr_discrete_columns_dict, prediction_method,
     model, ent_type, lambda0, se, sim, burn, thin, threads) = iteration_data

    # Shuffle each trait column to build one sample of the null distribution
    shuffled_dict = {}
    for column, trait in acr_discrete_columns_dict.items():
        shuffled_trait = np.random.choice(trait, len(trait), replace=False)
        prop2array[column][1] = list(shuffled_trait)
        shuffled_dict[column] = list(shuffled_trait)

    # Convert back to the original dictionary format and re-annotate the leaves
    new_metadata_dict = convert_back_to_original(prop2array)
    updated_tree = load_metadata_to_tree(dump_tree, new_metadata_dict)

    # Run ACR on the shuffled traits, then compute the delta statistic
    random_acr_results, updated_tree = run_acr_discrete(updated_tree, shuffled_dict,
                                                        prediction_method=prediction_method,
                                                        model=model, threads=threads, outdir=None)
    random_delta = run_delta(random_acr_results, updated_tree, ent_type=ent_type,
                             lambda0=lambda0, se=se, sim=sim, burn=burn, thin=thin,
                             threads=threads)

    # Clear extra features from the tree before the next iteration
    utils.clear_extra_features([updated_tree], ["name", "dist", "support"])
    return random_delta

def get_pval(prop2array, dump_tree, acr_discrete_columns_dict, iteration=100,
             prediction_method="MPPA", model="F81", ent_type='SE',
             lambda0=0.1, se=0.5, sim=10000, burn=100, thin=10, threads=1):
    prop2delta_array = {}

    # Prepare the (identical) argument tuple for each shuffling iteration
    iteration_data = [(prop2array, dump_tree, acr_discrete_columns_dict,
                       prediction_method, model, ent_type, lambda0, se, sim,
                       burn, thin, threads) for _ in range(iteration)]

    # Run iterations in a multiprocessing pool when more than one thread is requested
    if threads > 1:
        with Pool(threads) as pool:
            results = pool.map(_worker_function, iteration_data)
    else:
        results = map(_worker_function, iteration_data)

    # Aggregate the delta statistics per property into null distributions
    for delta_result in results:
        for prop, result in delta_result.items():
            prop2delta_array.setdefault(prop, []).append(result)

    return prop2delta_array
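
# Sketch (hypothetical names) of turning the null distributions returned by
# get_pval into empirical p-values; observed_delta would come from run_delta
# on the real, unshuffled traits.
def _example_empirical_pval(prop2delta_array, observed_delta, prop='habitat'):
    null = prop2delta_array[prop]
    return sum(1 for d in null if d >= observed_delta) / len(null)
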
# Build the FASTA-formatted matrix string for the leaves under a node
def build_matrix_string(node, name2seq):
    matrix = ''
    for leaf in node.leaves():
        if name2seq.get(leaf.name):
            matrix += f">{leaf.name}\n{name2seq.get(leaf.name)}\n"
    return matrix
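
# Usage sketch (hypothetical tree and sequences): leaves missing from
# name2seq ('C' here) are skipped.
def _example_build_matrix_string():
    t = Tree('((A:1,B:1):1,C:1);')
    name2seq = {'A': 'MKT-', 'B': 'MK-T'}
    print(build_matrix_string(t, name2seq))  # FASTA entries for A and B only
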
def tree2table(tree, internal_node=True, props=None, outfile='tree2table.csv'):
    if not props:
        # Collect every public property present anywhere in the tree
        props = set()
        for node in tree.traverse():
            props |= node.props.keys()
        props = [p for p in props if not p.startswith("_")]

    with open(outfile, 'w', newline='') as csvfile:
        if '_speciesFunction' in props:
            props.remove('_speciesFunction')
        fieldnames = ['name', 'dist', 'support']
        fieldnames.extend(x for x in sorted(props) if x not in fieldnames)

        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter='\t', extrasaction='ignore')
        writer.writeheader()
        for node in tree.traverse():
            # Write named leaves always; named internal nodes only on request
            if node.name and (node.is_leaf or internal_node):
                output_row = dict(node.props)
                for k, prop in output_row.items():
                    if isinstance(prop, list):
                        output_row[k] = '|'.join(str(v) for v in prop)
                writer.writerow(output_row)
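
# Usage sketch (hypothetical paths and properties): list-valued properties
# are '|'-joined in the output table.
def _example_tree2table(outfile='/tmp/tree2table.tsv'):
    t = Tree('((A:1,B:1):1,C:1);')
    for leaf in t:
        leaf.add_prop('go_terms', ['GO:0000003', 'GO:0000902'])
    tree2table(t, internal_node=False, outfile=outfile)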