Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/visualisation/tree_compressor.py: 13%
223 statements
Report generated by coverage.py v7.2.7 at 2024-03-21 09:19 +0100.
1import logging
2from collections import defaultdict
3from pastml.tree import IS_POLYTOMY, copy_forest
5import numpy as np
# Stages of compress_tree after which the Pajek description can be recorded (pajek_timing argument).
VERTICAL, HORIZONTAL, TRIM = 'VERTICAL', 'HORIZONTAL', 'TRIM'

# Property flagging the sampled tips of the full tree (nodes without it are pruned in remove_small_tips).
IS_TIP = 'is_tip'

# Default value of compress_tree's tip_size_threshold.
REASONABLE_NUMBER_OF_TIPS = 15

CATEGORIES = 'categories'

# Properties attached to the compressed-tree nodes.
NUM_TIPS_INSIDE = 'max_size'
TIPS_INSIDE = 'in_tips'
INTERNAL_NODES_INSIDE = 'in_ns'
TIPS_BELOW = 'all_tips'
ROOTS = 'roots'
ROOT_DATES = 'root_dates'

# Property of a full-tree node pointing to the compressed node that currently represents it.
COMPRESSED_NODE = 'compressed_node'

# Set on a compressed node into which equivalent siblings were merged horizontally.
METACHILD = 'metachild'

# Focus flags: nodes carrying any of them are protected from trimming.
IN_FOCUS = 'in_focus'
AROUND_FOCUS = 'around_focus'
UP_FOCUS = 'up_focus'
def _tree2pajek_vertices_arcs(compressed_tree, nodes, edges, columns):
    """
    Appends a Pajek description of the compressed tree to the given vertex and arc line lists.

    Vertices are numbered in preorder, continuing right after the lines already present in nodes.
    Every non-root node produces an arc from its parent to itself,
    weighted by the number of entries in its ROOTS list.

    :param compressed_tree: compressed tree to describe
    :param nodes: list of vertex lines, extended in place
    :param edges: list of arc lines, extended in place
    :param columns: characters whose states are reported for every vertex, in the given order
    :return: void, modifies nodes and edges
    """
    def describe_states(node):
        # One "<column>:<state(s)>" label per column; several states are sorted and joined with " or ".
        labels = []
        for column in columns:
            states = node.props.get(column, set())
            if not isinstance(states, str):
                states = ' or '.join(sorted(states))
            labels.append('{}:{}'.format(column, states))
        return labels

    vertex_of = {}
    for vertex, node in enumerate(compressed_tree.traverse('preorder'), start=len(nodes) + 1):
        vertex_of[node] = vertex
        if not node.is_root:
            edges.append('{} {} {}'.format(vertex_of[node.up], vertex, len(node.props.get(ROOTS))))
        # Tips of one configuration are comma-separated, different configurations are semicolon-separated.
        # NOTE(review): when called at the VERTICAL stage, TIPS_INSIDE is still a flat list of tip nodes,
        # so each "group" is a single tip and iterating it relies on the tree library yielding the leaf itself.
        tips_inside = ';'.join(','.join(tip.name for tip in group) for group in node.props.get(TIPS_INSIDE))
        quoted_states = ' '.join('"{}"'.format(label) for label in describe_states(node))
        nodes.append('{} "{}" "{}" {}'.format(vertex, node.name, tips_inside, quoted_states))
def save_to_pajek(nodes, edges, pajek):
    """
    Saves a compressed tree into Pajek format:

    *vertices <number_of_vertices>
    <id_1> "<vertex_name>" "<tips_inside>" "<column1>:<state(s)>" ["<column2>:<state(s)>" ...]
    ...
    *arcs
    <source_id> <target_id> <weight>
    ...

    <tips_inside> list tips that were vertically compressed inside this node: they are comma-separated.
    If the node was also horizontally merged (i.e. represents several similar configurations),
    the tip sets corresponding to different configurations are semicolon-separated.

    <state(s)> lists the states predicted for the corresponding column:
    if there are several states, they are separated with " or ".

    :param nodes: vertex lines (without trailing newlines)
    :param edges: arc lines (without trailing newlines)
    :param pajek: path to the output file (overwritten if it already exists)
    :return: void (creates a file specified in the pajek argument)
    """
    lines = ['*vertices {}'.format(len(nodes)), *nodes, '*arcs', *edges]
    # Each line, including the last one, is newline-terminated.
    # An empty vertex list no longer leaves a blank line before "*arcs".
    with open(pajek, 'w') as f:
        f.write('\n'.join(lines) + '\n')
def compress_tree(tree, columns, can_merge_diff_sizes=True, tip_size_threshold=REASONABLE_NUMBER_OF_TIPS, mixed=False,
                  pajek=None, pajek_timing=VERTICAL):
    """
    Creates a compressed (summary) copy of the tree.

    The compression proceeds in three stages:
      1. vertical: children in the same state(s) as their parent are merged into it (collapse_vertically);
      2. horizontal: sibling subtrees with equivalent configurations are merged (collapse_horizontally).
         If can_merge_diff_sizes is True and the result still has more than tip_size_threshold tips,
         the merge is repeated, comparing sizes by their order of magnitude (int(log10(size)));
      3. trimming: if the result still has more than tip_size_threshold tips, a size threshold is chosen
         (the tip_size_threshold-th largest per-node threshold). Compressed tips below it are recursively
         removed (remove_small_tips), then mediators are removed and the horizontal merge is redone.
         Nodes flagged IN_FOCUS, AROUND_FOCUS or UP_FOCUS get an infinite size and are therefore never trimmed.

    :param tree: tree annotated with the states for the given columns. NB: it is modified in place.
        Its nodes get a COMPRESSED_NODE property, and the nodes corresponding to trimmed or mediator
        compressed nodes are removed from it.
    :param columns: characters whose states drive the compression (a set or a list of column names)
    :param can_merge_diff_sizes: whether to allow a second horizontal merge of siblings of similar (same order
        of magnitude) rather than identical sizes
    :param tip_size_threshold: number of compressed tips above which the tree gets further compressed/trimmed
    :param mixed: if True, nodes in focus are protected from vertical and horizontal merging
    :param pajek: None, or a pair of lists (nodes, edges) to which the Pajek description of the compressed tree
        is appended at the stage given by pajek_timing
    :param pajek_timing: VERTICAL, HORIZONTAL or TRIM: the stage after which the Pajek description is recorded
    :return: the compressed tree
    """
    logger = logging.getLogger('pastml')

    # set(columns): accept a list as well as a set of columns (list | set would raise a TypeError).
    compressed_tree = copy_forest([tree], features=set(columns) | set(tree.props))[0]

    # NOTE(review): relies on copy_forest preserving the node order, so that both postorder traversals
    # visit corresponding nodes in lockstep.
    for n_compressed, n in zip(compressed_tree.traverse('postorder'), tree.traverse('postorder')):
        n_compressed.add_prop(TIPS_BELOW, [list(n_compressed.leaves())])
        n_compressed.add_prop(TIPS_INSIDE, [])
        n_compressed.add_prop(INTERNAL_NODES_INSIDE, [])
        n_compressed.add_prop(ROOTS, [n])
        if n_compressed.is_leaf:
            n_compressed.props.get(TIPS_INSIDE).append(n)
        elif not n_compressed.props.get(IS_POLYTOMY, False):
            n_compressed.props.get(INTERNAL_NODES_INSIDE).append(n)
        n.add_prop(COMPRESSED_NODE, n_compressed)

    collapse_vertically(compressed_tree, columns, mixed=mixed)
    if pajek is not None and VERTICAL == pajek_timing:
        _tree2pajek_vertices_arcs(compressed_tree, *pajek, columns=sorted(columns))

    # From here on a node may stand for several horizontally merged configurations,
    # so its tip and internal-node lists become lists of lists (one inner list per configuration).
    for n in compressed_tree.traverse():
        n.add_prop(NUM_TIPS_INSIDE, len(n.props.get(TIPS_INSIDE)))
        n.add_prop(TIPS_INSIDE, [n.props.get(TIPS_INSIDE)])
        n.add_prop(INTERNAL_NODES_INSIDE, [n.props.get(INTERNAL_NODES_INSIDE)])

    def exact_size_bin(size):
        # Sizes must be identical to be merged.
        return size

    def order_of_magnitude_bin(size):
        # Sizes only need the same order of magnitude to be merged.
        return int(np.log10(max(1, size)))

    get_bin = exact_size_bin
    collapse_horizontally(compressed_tree, columns, get_bin, mixed=mixed)

    if can_merge_diff_sizes and len(compressed_tree) > tip_size_threshold:
        get_bin = order_of_magnitude_bin
        logger.debug('Allowed merging nodes of different sizes.')
        collapse_horizontally(compressed_tree, columns, get_bin, mixed=mixed)

    if pajek is not None and HORIZONTAL == pajek_timing:
        _tree2pajek_vertices_arcs(compressed_tree, *pajek, columns=sorted(columns))

    if len(compressed_tree) > tip_size_threshold:
        # multiplier: product of the ROOTS counts along the path from the root down to (and including) the node
        for n in compressed_tree.traverse('preorder'):
            multiplier = (n.up.props.get('multiplier') if n.up else 1) * len(n.props.get(ROOTS))
            n.add_prop('multiplier', multiplier)

        def get_tsize(n):
            # Focus-related nodes are infinitely large, hence never below the trimming threshold.
            if n.props.get(IN_FOCUS, False) or n.props.get(AROUND_FOCUS, False) or n.props.get(UP_FOCUS, False):
                return np.inf
            return n.props.get(NUM_TIPS_INSIDE) * n.props.get('multiplier')

        node_thresholds = []
        for n in compressed_tree.traverse('postorder'):
            children_bs = max((get_tsize(_) for _ in n.children), default=0)
            bs = get_tsize(n)
            # if bs > children_bs it means that the trimming threshold for the node is higher
            # than the ones for its children
            if not n.is_root and bs > children_bs:
                node_thresholds.append(bs)
        threshold = sorted(node_thresholds)[-tip_size_threshold]

        if min(node_thresholds) >= threshold:
            if threshold == np.inf:
                logger.debug('All tips are in focus.')
            else:
                logger.debug('No tip is smaller than the threshold ({}, the size of the {}-th largest tip).'
                             .format(threshold, tip_size_threshold))
        else:
            if threshold == np.inf:
                logger.debug('Removing all the out of focus tips (as there are at least {} tips in focus).'
                             .format(tip_size_threshold))
            else:
                logger.debug('Set tip size threshold to {} (the size of the {}-th largest tip).'
                             .format(threshold, tip_size_threshold))
            remove_small_tips(compressed_tree=compressed_tree, full_tree=tree,
                              to_be_removed=lambda _: get_tsize(_) < threshold)
            remove_mediators(compressed_tree, columns)
            collapse_horizontally(compressed_tree, columns, get_bin, mixed=mixed)

    if pajek is not None and TRIM == pajek_timing:
        _tree2pajek_vertices_arcs(compressed_tree, *pajek, columns=sorted(columns))

    return compressed_tree
def collapse_horizontally(tree, columns, tips2bin, mixed=False):
    """
    Merges sibling subtrees that have equivalent configurations.

    A configuration consists of three parts:
      - the binned size of the node (tips2bin applied to NUM_TIPS_INSIDE);
      - its states for every column;
      - the sorted configurations of its children. For a child, the configuration also includes its branch
        width, i.e. the number of configurations it already stands for (len of TIPS_INSIDE).

    Equivalent siblings are folded into the first of them (the "child"):
      - its TIPS_INSIDE, INTERNAL_NODES_INSIDE, ROOTS and TIPS_BELOW lists are extended with the siblings' ones;
      - the full-tree nodes the siblings represented are re-pointed to it via COMPRESSED_NODE;
      - it is flagged METACHILD;
      - its NUM_TIPS_INSIDE becomes the average tip count over the merged configurations.

    :param tree: compressed tree, modified in place
    :param columns: characters whose states must match
    :param tips2bin: function mapping a node size to the value compared when matching sizes
    :param mixed: if True, nodes in focus, and transitively their ancestors, are never merged
    :return: void, modifies the input tree
    """
    # Keyed on node identity rather than node name: names are not guaranteed to be unique
    # (e.g. unnamed internal nodes), and a name collision would make unrelated subtrees
    # share a cached configuration and thus be wrongly merged.
    # No new nodes are created below, so the ids of the nodes we look up stay unique.
    config_cache = {}

    def get_configuration(n):
        key = id(n)
        if key not in config_cache:
            # Configuration is (branch_width, (size, states, child_configurations)),
            # where branch_width is only used for recursive calls and is ignored when considering a merge
            config_cache[key] = (len(n.props.get(TIPS_INSIDE)),
                                 (tips2bin(n.props.get(NUM_TIPS_INSIDE)),
                                  tuple(tuple(sorted(n.props.get(column, set()))) for column in columns),
                                  tuple(sorted([get_configuration(_) for _ in n.children]))))
        return config_cache[key]

    collapsed_configurations = 0

    uncompressable_ids = set()
    for n in tree.traverse('postorder'):
        config2children = defaultdict(list)
        for _ in n.children:
            if mixed and (_.props.get(IN_FOCUS, False) or id(_) in uncompressable_ids):
                # A node in focus (or above one) blocks merging for itself and, transitively, for its ancestors.
                uncompressable_ids.add(id(_))
                uncompressable_ids.add(id(n))
            else:
                # use (size, states, child_configurations) as configuration (ignore branch width)
                config2children[get_configuration(_)[1]].append(_)
        for children in (_ for _ in config2children.values() if len(_) > 1):
            collapsed_configurations += 1
            child = children[0]
            for sibling in children[1:]:
                child.props.get(TIPS_INSIDE).extend(sibling.props.get(TIPS_INSIDE))
                for ti in sibling.props.get(TIPS_INSIDE):
                    for _ in ti:
                        _.add_prop(COMPRESSED_NODE, child)
                child.props.get(INTERNAL_NODES_INSIDE).extend(sibling.props.get(INTERNAL_NODES_INSIDE))
                for ii in sibling.props.get(INTERNAL_NODES_INSIDE):
                    for _ in ii:
                        _.add_prop(COMPRESSED_NODE, child)
                child.props.get(ROOTS).extend(sibling.props.get(ROOTS))
                child.props.get(TIPS_BELOW).extend(sibling.props.get(TIPS_BELOW))
                n.remove_child(sibling)
            child.add_prop(METACHILD, True)
            child.add_prop(NUM_TIPS_INSIDE,
                           sum(len(_) for _ in child.props.get(TIPS_INSIDE)) / len(child.props.get(TIPS_INSIDE)))
            # The merged child now stands for more configurations: refresh its branch width in the cache.
            child_key = id(child)
            if child_key in config_cache:
                config_cache[child_key] = (len(child.props.get(TIPS_INSIDE)), config_cache[child_key][1])
    if collapsed_configurations:
        logging.getLogger('pastml').debug(
            'Collapsed {} sets of equivalent configurations horizontally.'.format(collapsed_configurations))
def remove_small_tips(compressed_tree, full_tree, to_be_removed):
    """
    Recursively removes compressed tips that satisfy to_be_removed,
    together with the full-tree nodes they represent.

    Removing a tip may turn its parent into a tip, so the leaves are re-examined until nothing changes
    (the root of the compressed tree is never removed). Afterwards, full-tree leaves that are not flagged
    IS_TIP (non-sampled tips, e.g. internal nodes whose children were all removed) are pruned bottom-up.
    They are also dropped from the INTERNAL_NODES_INSIDE lists of the compressed node they point to.

    :param compressed_tree: compressed tree, modified in place
    :param full_tree: the uncompressed tree, modified in place
    :param to_be_removed: predicate taking a compressed tip and returning True if it must be removed
    :return: void
    """
    num_removed = 0
    changed = True
    while changed:
        changed = False
        # Iterate over a snapshot: leaves are detached (and new leaves may appear) during the loop.
        # leaves() is the same tree API used to collect TIPS_BELOW in compress_tree.
        for l in list(compressed_tree.leaves()):
            parent = l.up
            if parent and to_be_removed(l):
                num_removed += 1
                parent.remove_child(l)
                # remove the corresponding nodes from the non-collapsed tree
                for ti in l.props.get(TIPS_INSIDE):
                    for _ in ti:
                        _.up.remove_child(_)
                for ii in l.props.get(INTERNAL_NODES_INSIDE):
                    for _ in ii:
                        _.up.remove_child(_)
                changed = True

    # if the full tree now contains non-sampled tips,
    # remove them from the tree and from the corresponding collapsed nodes
    todo = list(full_tree.leaves())
    while todo:
        t = todo.pop()
        if t.props.get(IS_TIP, False):
            continue
        parent = t.up
        if parent is None:
            # The full tree was pruned down to its root: there is nothing to detach it from.
            continue
        parent.remove_child(t)
        if parent.is_leaf:
            todo.append(parent)
        for ini_list in t.props.get(COMPRESSED_NODE).props.get(INTERNAL_NODES_INSIDE):
            if t in ini_list:
                ini_list.remove(t)

    logging.getLogger('pastml').debug(
        'Recursively removed {} tips of size smaller than the threshold.'.format(num_removed))
def collapse_vertically(tree, columns, mixed=False):
    """
    Collapses a child node into its parent if they are in the same state.

    Nodes are visited in postorder. When a child has exactly the same states as its parent for every column,
    several things happen:
      - its TIPS_INSIDE and INTERNAL_NODES_INSIDE entries are appended to the parent's;
      - those entries are re-pointed to the parent via COMPRESSED_NODE;
      - the child is detached, and its own children are re-attached to the parent.

    :param columns: a list of characters
    :param tree: ete3.Tree
    :param mixed: if True then the nodes in focus will not get collapsed. Moreover, when exactly one of
        a parent/child pair is flagged UP_FOCUS, they are not merged and the other one gets flagged AROUND_FOCUS.
    :return: void, modifies the input tree
    """
    def can_merge(parent, child):
        if any(parent.props.get(column, set()) != child.props.get(column, set()) for column in columns):
            return False
        if not mixed:
            return True
        if parent.props.get(IN_FOCUS, False) or child.props.get(IN_FOCUS, False):
            return False
        # Neither node is in focus here: only a mismatch in UP_FOCUS prevents the merge.
        parent_up = parent.props.get(UP_FOCUS, False)
        child_up = child.props.get(UP_FOCUS, False)
        if parent_up and not child_up:
            child.add_prop(AROUND_FOCUS, True)
            return False
        if child_up and not parent_up:
            parent.add_prop(AROUND_FOCUS, True)
            return False
        return True

    merged = 0
    for node in tree.traverse('postorder'):
        if node.is_leaf:
            continue

        for child in list(node.children):
            if not can_merge(node, child):
                continue

            child_tips = child.props.get(TIPS_INSIDE)
            node.props.get(TIPS_INSIDE).extend(child_tips)
            for tip in child_tips:
                tip.add_prop(COMPRESSED_NODE, node)

            child_internals = child.props.get(INTERNAL_NODES_INSIDE)
            node.props.get(INTERNAL_NODES_INSIDE).extend(child_internals)
            for internal in child_internals:
                internal.add_prop(COMPRESSED_NODE, node)

            node.remove_child(child)
            for grandchild in list(child.children):
                node.add_child(grandchild)
            merged += 1

    if merged:
        logging.getLogger('pastml').debug('Collapsed vertically {} internal nodes without state change.'
                                          .format(merged))
def remove_mediators(tree, columns):
    """
    Removes intermediate nodes that are just mediators between their parent and child states.

    A candidate must satisfy all of the following:
      - it is a non-root node;
      - it is not flagged METACHILD;
      - it has exactly one child;
      - it has no tips inside (NUM_TIPS_INSIDE is not positive).

    A candidate is removed (its only child re-attached to its parent) if, for every column, it has at least
    two states and these states are exactly the union of its parent's and its child's states. The full-tree
    internal nodes it represented are removed as well, with their children re-attached to their own parents.

    :param columns: list of characters
    :param tree: ete3.Tree
    :return: void, modifies the input tree
    """
    removed = 0
    for node in tree.traverse('postorder'):
        if node.props.get(METACHILD, False) or node.is_leaf:
            continue
        if len(node.children) > 1 or node.is_root:
            continue
        if node.props.get(NUM_TIPS_INSIDE) > 0:
            continue

        parent, child = node.up, node.children[0]

        # if mediator has unresolved states, it should hesitate between the parent and the child:
        is_mediator = all(
            len(node.props.get(column, set())) >= 2
            and node.props.get(column, set()) == child.props.get(column, set()) | parent.props.get(column, set())
            for column in columns)
        if not is_mediator:
            continue

        parent.remove_child(node)
        parent.add_child(child)
        # update the uncompressed tree
        for internal_group in node.props.get(INTERNAL_NODES_INSIDE):
            for full_node in internal_group:
                full_parent = full_node.up
                for grandchild in list(full_node.children):
                    full_parent.add_child(grandchild)
                full_parent.remove_child(full_node)
        removed += 1

    if removed:
        logging.getLogger('pastml').debug("Removed {} internal node{}"
                                          " with the state unresolved between the parent's and the only child's."
                                          .format(removed, '' if removed == 1 else 's'))