Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/tree.py: 10%
347 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
1import logging
2import os
3import re
4from collections import Counter, defaultdict
5from datetime import datetime
7from Bio import Phylo
8#from ete3 import Tree, TreeNode
9from ete4 import Tree
# Labels attached to the nodes yielded by depth_first_traversal.
POSTORDER = 'postorder'
INORDER = 'inorder'
PREORDER = 'preorder'

# Names of the node properties that hold a node date and its confidence interval.
DATE = 'date'
DATE_CI = 'date_CI'

# All the patterns below are raw strings, so that regex escapes such as \d and \s
# are never interpreted as (invalid) Python string escapes.
DATE_REGEX = r'[+-]*[\d]+[.\d]*(?:[e][+-][\d]+){0,1}'
DATE_COMMENT_REGEX = r'[&,:]date[=]["]{{0,1}}({})["]{{0,1}}'.format(DATE_REGEX)
CI_DATE_REGEX_LSD = (
    r'[&,:]CI_date[=]["]{{0,1}}[{{]{{0,1}}({})\s*[,;]{{0,1}}\s*({})[}}]{{0,1}}["]{{0,1}}'
    .format(DATE_REGEX, DATE_REGEX)
)
CI_DATE_REGEX_PASTML = r'[&,:]date_CI[=]["]{{0,1}}({})[|]({})["]{{0,1}}'.format(DATE_REGEX, DATE_REGEX)
# Template: {column} must be filled in with str.format before the pattern is used.
COLUMN_REGEX_PASTML = r'[&,]{column}[=]([^]^,]+)'

# Name of the property that marks nodes created while resolving polytomies.
IS_POLYTOMY = 'polytomy'
def get_dist_to_root(tip):
    """
    Sums the branch lengths on the path from the given node up to the root.

    :param tip: node to start from (any node, not necessarily a leaf)
    :return: sum of the `dist` values of the node and of all its non-root ancestors
        (0 if the node is the root itself)
    """
    total = 0
    current = tip
    while True:
        if current.is_root:
            return total
        total += current.dist
        current = current.up
def annotate_dates(forest, root_dates=None):
    """
    Makes sure that every node of every tree has a numeric DATE property,
    and normalises the DATE_CI property where it is present.

    Nodes are visited in preorder, so a parent is always dated before its children.
    An existing date is converted to float; a missing root date is taken from root_dates
    (0 if it is not given or falsy); a missing non-root date is the parent date plus
    the node's branch length.

    A truthy DATE_CI value that is neither a list nor a tuple is removed;
    if it is a 'low|high' string, it is replaced with the list of the parsed floats.

    :param forest: list of trees to annotate
    :param root_dates: list of root dates, one per tree (all zeros by default)
    :return: void, modifies the trees in place
    """
    if root_dates is None:
        root_dates = [0] * len(forest)
    for tree, root_date in zip(forest, root_dates):
        for node in tree.traverse('preorder'):
            date = node.props.get(DATE, None)
            if date is not None:
                node.add_prop(DATE, float(date))
            elif node.is_root:
                node.add_prop(DATE, root_date if root_date else 0)
            else:
                node.add_prop(DATE, node.up.props.get(DATE) + node.dist)

            ci = node.props.get(DATE_CI, None)
            if not ci or isinstance(ci, (list, tuple)):
                continue
            node.del_prop(DATE_CI)
            if isinstance(ci, str) and '|' in ci:
                try:
                    node.add_prop(DATE_CI, [float(_) for _ in ci.split('|')])
                except ValueError:
                    # Not a pair of numbers: leave the node without a CI.
                    pass
def name_tree(tree, suffix=""):
    """
    Gives unique names to the tree nodes, unless all of them are already uniquely named.

    Nodes whose name contains '.polytomy_' are additionally marked with the IS_POLYTOMY property.
    When renaming is needed, the root is called 'root<suffix>', a node keeps its name
    if that name is shared by fewer than 10 nodes, and the other nodes are called
    't<i><suffix>' (leaves) or 'n<i><suffix>' (internal nodes). Whenever the chosen name
    is already taken, '<i><suffix>' is appended to the prefix, with the counter i increasing.

    :param tree: tree to be named
    :type tree: ete4.Tree
    :param suffix: string appended to the generated names
    :type suffix: str
    :return: void, modifies the original tree
    """
    name_counts = Counter()
    total = 0
    for node in tree.traverse():
        total += 1
        if not node.name:
            continue
        name_counts[node.name] += 1
        if '.polytomy_' in node.name:
            node.add_prop(IS_POLYTOMY, 1)

    # Every node is named and all the names are distinct: nothing to do.
    if total == len(name_counts):
        return

    counter = 0
    taken = set()
    for node in tree.traverse('preorder'):
        if node.name and name_counts[node.name] < 10:
            prefix = node.name
        else:
            prefix = '{}{}{}'.format('t' if node.is_leaf else 'n', counter, suffix)
        candidate = 'root{}'.format(suffix) if node.is_root else prefix
        while candidate is None or candidate in taken:
            candidate = '{}{}{}'.format(prefix, counter, suffix)
            counter += 1
        node.name = candidate
        taken.add(candidate)
def collapse_zero_branches(forest, features_to_be_merged=None):
    """
    Collapses the internal branches of non-positive length in the trees of the forest.

    Each internal child with dist <= 0 is removed and its children are attached to its parent.
    For every feature to be merged, the parent gets the intersection of its own values
    and the values of all the collapsed children, or their union if the intersection is empty.

    :param forest: list of trees
    :type forest: list(ete4.Tree)
    :param features_to_be_merged: list of features whose values (sets) are to be merged
        in case the nodes are merged during collapsing
    :type features_to_be_merged: list(str)
    :return: void, modifies the trees in place
    """
    features = [] if features_to_be_merged is None else features_to_be_merged
    collapsed_total = 0

    for tree in forest:
        # Iterate over a snapshot, as the tree is modified along the way.
        for node in list(tree.traverse('postorder')):
            to_collapse = [ch for ch in node.children if not ch.is_leaf and ch.dist <= 0]
            if not to_collapse:
                continue

            for feature in features:
                child_values = [ch.props.get(feature, set()) for ch in to_collapse]
                own_values = node.props.get(feature, set())
                merged = set.intersection(*child_values) & own_values
                if not merged:
                    merged = set.union(*child_values) | own_values
                if merged:
                    node.add_prop(feature, merged)

            for ch in to_collapse:
                node.remove_child(ch)
                for grandchild in ch.children:
                    node.add_child(grandchild)
            collapsed_total += len(to_collapse)

    if collapsed_total:
        logging.getLogger('pastml').debug('Collapsed {} internal zero branches.'.format(collapsed_total))
def remove_certain_leaves(tr, to_remove=lambda node: False):
    """
    Removes all the branches leading to leaves identified positively by to_remove function.

    Whenever a removal leaves a parent with a single child, the parent is merged with that child
    (their branch lengths are added up); if the parent was the root, the child becomes the new root.

    :param tr: the tree of interest (iterating over it is expected to yield its leaves)
    :param to_remove: a method to check if a leaf should be removed.
    :return: the root of the pruned tree (it may differ from the given one),
        or None if a leaf to be removed is the root itself.
    """
    doomed = [leaf for leaf in tr if to_remove(leaf)]
    for leaf in doomed:
        if leaf.is_root:
            return None
        parent = leaf.up
        parent.remove_child(leaf)
        if len(parent.children) != 1:
            continue

        # The parent has only one child left: merge the two of them.
        survivor = parent.children[0]
        survivor.dist += parent.dist
        if parent.is_root:
            survivor.up = None
            tr = survivor
            continue
        grandparent = parent.up
        grandparent.remove_child(parent)
        grandparent.add_child(survivor)
    return tr
def read_forest(tree_path, columns=None):
    """
    Reads the trees stored in a file in nexus or newick format.

    The file is first parsed as nexus; if that fails or produces no tree,
    its content is split on ';' and every ';'-terminated chunk is read as a newick tree.

    :param tree_path: path to the file with the tree(s)
    :type tree_path: str
    :param columns: names of the node annotations to be parsed into sets of values
    :type columns: list(str)
    :return: list of tree roots
    :raises ValueError: if the file contains no tree in either format
    """
    try:
        roots = parse_nexus(tree_path, columns=columns)
        if roots:
            return roots
    except Exception:
        # Deliberately best-effort: anything that is not valid nexus is retried as newick below.
        pass
    with open(tree_path, 'r') as f:
        # The text that follows the last ';' is not a complete tree, hence [:-1].
        # (str.split never returns an empty list, so emptiness is checked after the slice.)
        nwks = f.read().replace('\n', '').split(';')[:-1]
    if not nwks:
        raise ValueError('Could not find any trees (in newick or nexus format) in the file {}.'.format(tree_path))
    return [read_tree(nwk + ';', columns) for nwk in nwks]
174# def read_tree(tree_path, columns=None):
175# tree = None
176# for f in (3, 2, 5, 0, 1, 4, 6, 7, 8, 9):
177# try:
178# tree = Tree(tree_path, format=f)
179# break
180# except:
181# continue
182# if not tree:
183# raise ValueError('Could not read the tree {}. Is it a valid newick?'.format(tree_path))
184# if columns:
185# for n in tree.traverse():
186# for c in columns:
187# vs = set(getattr(n, c).split('|')) if hasattr(n, c) else set()
188# if vs:
189# n.add_prop(c, vs)
190# return tree
def read_tree(tree_path, columns=None):
    """
    Builds an ete4 tree and converts the requested annotations into sets of values.

    :param tree_path: tree description passed on to the Tree constructor with parser=1
        (read_forest passes a newick string here)
    :param columns: names of the node properties whose '|'-separated string value
        is to be replaced with the set of its parts
    :type columns: list(str)
    :return: the tree
    """
    tree = Tree(tree_path, parser=1)
    if not columns:
        return tree
    for node in tree.traverse():
        for column in columns:
            raw = node.props.get(column)
            if not raw:
                continue
            values = set(raw.split('|'))
            if values:
                node.add_prop(column, values)
    return tree
def _first_match(pattern, text):
    """Returns the first re.findall match of the pattern in the text, or None if there is none."""
    return next(iter(re.findall(pattern, text)), None)


def _find_date_ci(text):
    """Looks for a date CI in the text: in the LSD format first, in the PastML format otherwise."""
    ci = _first_match(CI_DATE_REGEX_LSD, text)
    if ci is None:
        ci = _first_match(CI_DATE_REGEX_PASTML, text)
    return ci


def parse_nexus(tree_path, columns=None):
    """
    Reads the trees of a nexus file and converts them to ete4 trees.

    The clades produced by read_nexus (Bio.Phylo) are copied node by node. Dates,
    date confidence intervals and the requested column values are extracted from the
    clade comments (CIs are also looked for in the root branch length and in the confidence
    when these are strings), and stored in the DATE, DATE_CI and column properties.
    Dates and CIs are then shared across zero-length branches, and, if the root is still
    undated, dates are propagated from a dated node up to the root.

    :param tree_path: path to the nexus file
    :type tree_path: str
    :param columns: names of the annotations to be extracted from the comments
    :type columns: list(str)
    :return: list of tree roots
    """
    trees = []
    for nex_tree in read_nexus(tree_path):
        todo = [(nex_tree.root, None)]
        tree = None
        while todo:
            clade, parent = todo.pop()
            try:
                dist = float(clade.branch_length)
            except (TypeError, ValueError):
                # Missing or non-numeric branch length.
                dist = 0
            # Clades are Bio.Phylo objects: their data are plain attributes, not ete props.
            name = getattr(clade, 'name', None)
            if not name:
                name = getattr(clade, 'confidence', None)
                if not isinstance(name, str):
                    name = None
            # Same construction pattern as in copy_forest: an empty node filled in via its properties.
            # NOTE(review): Tree(dist=..., name=...) is presumably rejected by the ete4 constructor
            # used elsewhere in this file (Tree(data, parser=...)) - confirm.
            node = Tree()
            node.dist = dist
            if name is not None:
                node.name = name
            if parent is None:
                tree = node
            else:
                parent.add_child(node)

            # Parse LSD2 dates and CIs, and PastML columns
            date, ci = None, None
            columns2values = defaultdict(set)
            comment = getattr(clade, 'comment', None)
            if isinstance(comment, str):
                date = _first_match(DATE_COMMENT_REGEX, comment)
                ci = _find_date_ci(comment)
                if columns:
                    for column in columns:
                        values = set()
                        for found in re.findall(COLUMN_REGEX_PASTML.format(column=column), comment):
                            values.update(found.split('|'))
                        if values:
                            columns2values[column] |= values
            # The root annotation may end up in the branch length field.
            comment = getattr(clade, 'branch_length', None)
            if not ci and parent is None and isinstance(comment, str):
                ci = _find_date_ci(comment)
            comment = getattr(clade, 'confidence', None)
            if ci is None and isinstance(comment, str):
                ci = _find_date_ci(comment)

            if date is not None:
                try:
                    node.add_prop(DATE, float(date))
                except (TypeError, ValueError):
                    # Unparsable date: leave the node undated.
                    pass
            if ci is not None:
                try:
                    node.add_prop(DATE_CI, [float(_) for _ in ci])
                except (TypeError, ValueError):
                    # Unparsable CI: leave the node without it.
                    pass
            for c, vs in columns2values.items():
                node.add_prop(c, vs)
            todo.extend((c, node) for c in clade.clades)

        # Nodes separated by a zero-length branch share their date and CI: parent -> child...
        for n in tree.traverse('preorder'):
            date, ci = n.props.get(DATE, None), n.props.get(DATE_CI, None)
            if date is not None or ci is not None:
                for c in n.children:
                    if c.dist == 0:
                        if c.props.get(DATE, None) is None:
                            c.add_prop(DATE, date)
                        if c.props.get(DATE_CI, None) is None:
                            c.add_prop(DATE_CI, ci)
        # ... and child -> parent.
        for n in tree.traverse('postorder'):
            date, ci = n.props.get(DATE, None), n.props.get(DATE_CI, None)
            if not n.is_root and n.dist == 0 and (date is not None or ci is not None):
                if n.up.props.get(DATE, None) is None:
                    n.up.add_prop(DATE, date)
                if n.up.props.get(DATE_CI, None) is None:
                    n.up.add_prop(DATE_CI, ci)

        # propagate dates up to the root if needed
        if tree.props.get(DATE, None) is None:
            dated_node = next((n for n in tree.traverse() if n.props.get(DATE, None) is not None), None)
            if dated_node is not None:
                while dated_node != tree:
                    if dated_node.up.props.get(DATE, None) is None:
                        dated_node.up.add_prop(DATE, dated_node.props.get(DATE) - dated_node.dist)
                    dated_node = dated_node.up

        trees.append(tree)
    return trees
def read_nexus(tree_path):
    """
    Parses a nexus file with Bio.Phylo, after rewriting the CI annotations it contains.

    CI_date="2019(2018,2020)" is replaced with CI_date="2018 2020"; the rewritten content
    is written to a temporary file next to the original one, which is parsed and then removed.

    :param tree_path: path to the nexus file
    :type tree_path: str
    :return: list of the trees parsed by Bio.Phylo
    """
    with open(tree_path, 'r') as f:
        nexus = f.read()
    # replace CI_date="2019(2018,2020)" with CI_date="2018 2020"
    nexus = re.sub(r'CI_date="({})\(({}),({})\)"'.format(DATE_REGEX, DATE_REGEX, DATE_REGEX), r'CI_date="\2 \3"',
                   nexus)
    temp = tree_path + '.{}.temp'.format(datetime.timestamp(datetime.now()))
    try:
        with open(temp, 'w') as f:
            f.write(nexus)
        return list(Phylo.parse(temp, 'nexus'))
    finally:
        # Remove the temporary file even if writing or parsing failed.
        if os.path.exists(temp):
            os.remove(temp)
def depth_first_traversal(node):
    """
    Walks the subtree of the node depth-first, yielding (node, event) pairs.

    Each node is yielded with PREORDER before its children, with INORDER between
    every two consecutive children, and with POSTORDER after its last child.

    :param node: root of the subtree to traverse
    :return: generator of (node, PREORDER | INORDER | POSTORDER) tuples
    """
    yield node, PREORDER
    first = True
    for child in node.children:
        if not first:
            yield node, INORDER
        first = False
        yield from depth_first_traversal(child)
    yield node, POSTORDER
def resolve_trees(column2states, forest):
    """
    Resolves polytomies based on state predictions:
    if a parent P in a state A has n children, m of which (2 <= m < n) are in a state B,
    a new node in the state B is inserted between P and these m children,
    at the distance of the closest of them.

    :param column2states: character to possible state mapping
    :type column2states: dict
    :param forest: a forest of trees of interest
    :type forest: list(ete4.Tree)
    :return: number of newly created nodes.
    :rtype: int
    """
    columns = sorted(column2states.keys())
    state_index = {column: {state: i for i, state in enumerate(states)}
                   for (column, states) in column2states.items()}

    def get_prediction(node):
        # Key such as '0-2.1': sorted state indices joined with '-', columns joined with '.'.
        parts = []
        for column in columns:
            indices = sorted(state_index[column][state] for state in node.props.get(column, set()))
            parts.append('-'.join(str(i) for i in indices))
        return '.'.join(parts)

    num_new_nodes = 0

    for tree in forest:
        stack = [tree]
        while stack:
            node = stack.pop()
            stack.extend(node.children)
            # Only polytomies (more than two children) are of interest.
            if len(node.children) <= 2:
                continue
            groups = defaultdict(list)
            for child in node.children:
                groups[get_prediction(child)].append(child)
            if len(groups) <= 1:
                continue
            for state, children in groups.items():
                if group_children_if_needed(node, children, columns, state):
                    num_new_nodes += 1

    logger = logging.getLogger('pastml')
    if num_new_nodes:
        logger.debug(
            'Created {} new internal nodes while resolving polytomies'.format(num_new_nodes))
    else:
        logger.debug('Could not resolve any polytomy')
    return num_new_nodes
def states_are_different(n1, n2, columns):
    """
    Checks whether the two nodes have no state in common for at least one of the columns.

    :param n1: first node
    :param n2: second node
    :param columns: names of the properties (sets of states) to compare
    :return: True if for some column the state sets of the nodes do not intersect
        (a missing property counts as an empty set), False otherwise
    :rtype: bool
    """
    return any(
        not (n1.props.get(column, set()) & n2.props.get(column, set()))
        for column in columns
    )
def group_children_if_needed(n, children, columns, state):
    """
    Inserts a new polytomy-resolving node between n and the given children, if it is justified.

    Nothing is done if there are fewer than two children, or if the states of n and of the closest
    child are not different. Otherwise the new node is placed at the distance of the closest child,
    is marked with IS_POLYTOMY, and takes the date and the column values of that child;
    its date CI is [lower bound of n, upper bound (or date) of the closest child] if n has a list CI,
    and None otherwise.

    :param n: parent node
    :param children: children of n that share the same state prediction
    :param columns: names of the state properties
    :param state: state prediction key, used in the name of the new node
    :return: True if a new node was created, False otherwise
    :rtype: bool
    """
    if len(children) <= 1:
        return False
    closest = min(children, key=lambda ch: ch.dist)
    if not states_are_different(n, closest, columns):
        return False

    dist = closest.dist
    pol = n.add_child(dist=dist, name='{}.polytomy_{}'.format(n.name, state))
    pol.add_prop(IS_POLYTOMY, 1)
    closest_date = closest.props.get(DATE)
    pol.add_prop(DATE, closest_date)

    parent_ci = n.props.get(DATE_CI, None)
    closest_ci = closest.props.get(DATE_CI, None)
    if parent_ci and isinstance(parent_ci, list):
        if closest_ci and isinstance(closest_ci, list) and len(closest_ci) > 1:
            upper = closest_ci[1]
        else:
            upper = closest_date
        pol_ci = [parent_ci[0], upper]
    else:
        pol_ci = None
    pol.add_prop(DATE_CI, pol_ci)

    for column in columns:
        pol.add_prop(column, closest.props.get(column))
    # Move the children under the new node, shortening their branches accordingly.
    for ch in children:
        n.remove_child(ch)
        pol.add_child(ch, dist=ch.dist - dist)
    return True
def unresolve_trees(column2states, forest):
    """
    Unresolves polytomies whose states do not correspond to child states after likelihood recalculation.

    A polytomy-resolving node (marked with IS_POLYTOMY) is kept only if all its children share
    its state and that state differs from the state of its parent. Otherwise the node is removed,
    its children are attached to its parent, and new resolutions are attempted
    among the children of that parent.

    :param column2states: character to possible state mapping
    :type column2states: dict
    :param forest: a forest of trees of interest
    :type forest: list(ete4.Tree)
    :return: number of newly deleted nodes.
    :rtype: int
    """
    columns = sorted(column2states.keys())

    col2state2i = {c: dict(zip(states, range(len(states)))) for (c, states) in column2states.items()}

    def get_prediction(n):
        return '.'.join('-'.join(str(i) for i in sorted([col2state2i[c][_] for _ in n.props.get(c, set())]))
                        for c in columns)

    num_removed_nodes = 0
    num_new_nodes = 0

    def remove_node(n):
        parent = n.up
        # Iterate over a copy, as the children are being moved to another parent.
        for c in list(n.children):
            parent.add_child(c, dist=c.dist + n.dist)
        parent.remove_child(n)

    num_polytomies = 0

    for tree in forest:
        # Iterate over a snapshot, as nodes are removed and created along the way.
        for n in list(tree.traverse('postorder')):
            if n.props.get(IS_POLYTOMY, False):
                num_polytomies += 1

                state2children = defaultdict(list)
                n_children = list(n.children)
                for c in n_children:
                    state2children[get_prediction(c)].append(c)
                parent = n.up

                # if the state is the same as all the child states, it's still a good polytomy resolution
                if len(state2children) == 1 and not states_are_different(n, n_children[0], columns):
                    # Just need to check that it is not the same state as the parent (then we don't need this polytomy)
                    if not states_are_different(n, parent, columns):
                        num_removed_nodes += 1
                        remove_node(n)
                    continue

                num_removed_nodes += 1
                remove_node(n)

                # now let's try to create new polytomies above
                above_state2children = defaultdict(list)
                for c in parent.children:
                    # The groups must go to above_state2children (not to state2children),
                    # otherwise the loop below has nothing to iterate over.
                    above_state2children[get_prediction(c)].append(c)
                for state, children in above_state2children.items():
                    if len(children) <= 1:
                        continue
                    # Do not recreate the very grouping that has just been removed.
                    if set(children) != set(n_children):
                        if group_children_if_needed(parent, children, columns, state):
                            num_new_nodes += 1
    if num_removed_nodes:
        logging.getLogger('pastml').debug(
            'Removed {} polytomy resolution{} as inconsistent with model parameters.'
            .format(num_removed_nodes, 's' if num_removed_nodes > 1 else ''))
    if num_new_nodes:
        logging.getLogger('pastml').debug(
            'Created {} new polytomy resolution{}.'.format(num_new_nodes, 's' if num_new_nodes > 1 else ''))
    elif num_polytomies - num_removed_nodes + num_new_nodes:
        logging.getLogger('pastml').debug('All the polytomy resolutions are consistent with model parameters.')
    return num_removed_nodes
def clear_extra_features(forest, features):
    """
    Deletes from every node all the properties except the given ones and 'name', 'dist', 'support'.

    :param forest: list of trees to clean
    :param features: names of the properties to keep
    :return: void, modifies the trees in place
    """
    keep = set(features) | {'name', 'dist', 'support'}
    for tree in forest:
        for node in tree.traverse():
            # Collect first, delete afterwards: the props must not change while being iterated.
            extra = [prop for prop in set(node.props) if prop not in keep]
            for prop in extra:
                node.del_prop(prop)
def copy_forest(forest, features=None):
    """
    Creates a copy of the forest that keeps the topology, the names, the branch lengths,
    the supports and the selected properties of the nodes.

    :param forest: list of trees to copy
    :param features: names of the properties to copy; if it is empty or None,
        the properties of the root of the first tree are used
    :return: list of copied trees (an empty list for an empty forest)
    """
    if not forest:
        # Nothing to copy, and forest[0] below would fail.
        return []
    features = set(features if features else forest[0].props)

    copied_forest = []
    for tree in forest:
        copied_tree = Tree()
        todo = [(tree, copied_tree)]
        copied_forest.append(copied_tree)
        while todo:
            n, copied_n = todo.pop()
            # NOTE(review): the ete4 name/dist/support setters presumably convert their argument
            # (str/float), so a missing (None) value is skipped rather than copied - confirm.
            if n.dist is not None:
                copied_n.dist = n.dist
            if n.support is not None:
                copied_n.support = n.support
            if n.name is not None:
                copied_n.name = n.name
            for f in features:
                if f in n.props:
                    copied_n.add_prop(f, n.props.get(f))
            for c in n.children:
                todo.append((c, copied_n.add_child()))
    return copied_forest