Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/visualisation/cytoscape_manager.py: 11%
561 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
1import logging
2import os
3from collections import defaultdict
4from glob import glob
5from queue import Queue
6from shutil import copyfile
8import numpy as np
9from jinja2 import Environment, PackageLoader
11from pastml import numeric2datetime
12from pastml.file import get_pastml_colour_file
13from pastml.tree import DATE, IS_POLYTOMY, copy_forest
14from pastml.visualisation import get_formatted_date
15from pastml.visualisation.colour_generator import get_enough_colours, WHITE, parse_colours
16from pastml.visualisation.tree_compressor import NUM_TIPS_INSIDE, TIPS_INSIDE, TIPS_BELOW, \
17 REASONABLE_NUMBER_OF_TIPS, compress_tree, INTERNAL_NODES_INSIDE, ROOTS, IS_TIP, ROOT_DATES, IN_FOCUS, AROUND_FOCUS, \
18 UP_FOCUS, save_to_pajek, VERTICAL
# Remote JavaScript files handed to the page template when local copies are not requested.
JS_LIST = [
    "https://pastml.pasteur.fr/static/js/jquery.min.js",
    "https://pastml.pasteur.fr/static/js/jquery.qtip.min.js",
    "https://pastml.pasteur.fr/static/js/cytoscape.min.js",
    "https://pastml.pasteur.fr/static/js/cytoscape-qtip.js",
    "https://pastml.pasteur.fr/static/js/cytoscape-svg.js",
    "https://pastml.pasteur.fr/static/js/layout-base.min.js",
    "https://pastml.pasteur.fr/static/js/cose-base.min.js",
    "https://pastml.pasteur.fr/static/js/cytoscape-cose-bilkent.min.js",
]
# Remote stylesheets handed to the page template when local copies are not requested.
CSS_LIST = [
    "https://pastml.pasteur.fr/static/css/jquery.qtip.min.css",
    "https://pastml.pasteur.fr/static/css/bootstrap.min.css",
]

# The uncompressed tree is not drawn when the forest has more tips than this.
MAX_TIPS_FOR_FULL_TREE_VISUALISATION = 5000

# Accepted timeline types.
TIMELINE_SAMPLED = 'SAMPLED'
TIMELINE_NODES = 'NODES'
TIMELINE_LTT = 'LTT'

TIP_LIMIT = 1000

# Lower bounds of the edge, font and node size scales.
MIN_EDGE_SIZE = 50
MIN_FONT_SIZE = 80
MIN_NODE_SIZE = 200

# Names of node properties.
UNRESOLVED = 'unresolved'
TIP = 'tip'

TOOLTIP = 'tooltip'
COLOUR = 'colour'

# Keys of the element dictionaries.
DATA = 'data'
ID = 'id'
EDGES = 'edges'
NODES = 'nodes'
ELEMENTS = 'elements'

# Property names (or name prefixes) of the visual attributes.
NODE_SIZE = 'node_size'
NODE_NAME = 'node_name'
BRANCH_NAME = 'branch_name'
EDGE_SIZE = 'edge_size'
EDGE_NAME = 'edge_name'
FONT_SIZE = 'node_fontsize'

MILESTONE = 'mile'

DATE_LABEL = 'date'

DIST_TO_ROOT_LABEL = 'dist. to root'
def get_fake_node(n_id, x, y):
    """
    Creates the element dictionary of an auxiliary ("fake") node placed at (x, y).

    :param n_id: identifier of the node
    :param x: x coordinate of the node
    :param y: y coordinate of the node
    :return: dict, the node element
    """
    return _get_node({ID: n_id, 'fake': 1}, position=(x, y))
def get_node(n, n_id, tooltip='', clazz=None, x=0, y=0):
    """
    Creates the element dictionary for a tree node.

    The milestone, unresolved, 'x' and 'y' properties of the node, as well as all of its properties
    whose names start with 'node_', are exported.

    :param n: tree node
    :param n_id: identifier of the node
    :param tooltip: tooltip text
    :param clazz: collection of class names, converted to one css class name
    :param x: x coordinate; if None, no position is set
    :param y: y coordinate
    :return: dict, the node element
    """
    exported = [MILESTONE, UNRESOLVED, 'x', 'y']
    features = {}
    for feature in n.props:
        if feature in exported or feature.startswith('node_'):
            features[feature] = n.props.get(feature)
    features[ID] = n_id
    if n.is_leaf:
        features[TIP] = 1
    features[TOOLTIP] = tooltip
    css_class = _clazz_list2css_class(clazz)
    position = None if x is None else (x, y)
    return _get_node(features, clazz=css_class, position=position)
def get_edge(source_name, target_name, **kwargs):
    """
    Creates the element dictionary of an edge.

    :param source_name: identifier of the source node
    :param target_name: identifier of the target node
    :param kwargs: additional edge attributes
    :return: dict, the edge element
    """
    endpoints = {'source': source_name, 'target': target_name}
    return _get_edge(**endpoints, **kwargs)
def get_scaling_function(y_m, y_M, x_m, x_M):
    r"""
    Returns a linear function y = k x + b going through (x_m, y_m) and (x_M, y_M),
    so that y \in [y_m, y_M] for x \in [x_m, x_M].

    The value of the linear function is truncated with int().
    If the input range is empty or degenerate (x_M <= x_m), a constant function is returned instead;
    it yields y_m as is, without the int() conversion.

    :param y_m: output value corresponding to x_m
    :param y_M: output value corresponding to x_M
    :param x_m: lower bound of the input range
    :param x_M: upper bound of the input range
    :return: function of one argument
    """
    if x_M <= x_m:
        return lambda _: y_m
    k = (y_M - y_m) / (x_M - x_m)
    b = y_m - k * x_m
    return lambda _: int(k * _ + b)
def set_cyto_features_compressed(n, size_scaling, e_size_scaling, font_scaling, transform_size, transform_e_size, state,
                                 root_names, root_dates, suffix='', is_mixed=False):
    """
    Sets the name, size, font size, content summaries and edge properties of a compressed node.

    All the property names are built from the corresponding constant (or 'node_' + constant) followed by the suffix.

    :param n: compressed node; its TIPS_INSIDE, TIPS_BELOW, INTERNAL_NODES_INSIDE and ROOTS properties are read
        (each is treated as a list of collections, missing ones as empty lists)
    :param size_scaling: function converting a transformed tip number to a node size
    :param e_size_scaling: function converting a transformed root number to an edge size
    :param font_scaling: function converting a transformed tip number to a font size
    :param transform_size: transformation applied to the tip number before scaling
    :param transform_e_size: transformation applied to the root number before scaling
    :param state: str, state label of the node
    :param root_names: list of the names of the roots represented by this node
    :param root_dates: list of the (formatted) dates of these roots, in the same order as root_names
    :param suffix: str, appended to the property names
    :param is_mixed: if True, nodes marked IN_FOCUS or UP_FOCUS are named after their state and first root
        instead of their state and tip number
    :return: void, modifies n
    """
    tips_inside, tips_below, internal_nodes_inside, roots = \
        n.props.get(TIPS_INSIDE, []), n.props.get(TIPS_BELOW, []), \
        n.props.get(INTERNAL_NODES_INSIDE, []), n.props.get(ROOTS, [])

    def get_min_max_str(values, default_value=0):
        # Summarises the sizes of the collections in values as ' min-max' (or ' size' if they are all equal).
        min_v, max_v = (min(len(_) for _ in values), max(len(_) for _ in values)) \
            if values else (default_value, default_value)
        return ' {}'.format('{}-{}'.format(min_v, max_v) if min_v != max_v else min_v), min_v, max_v

    tips_below_str, _, _ = get_min_max_str(tips_below)
    tips_inside_str, _, max_n_tips = get_min_max_str(tips_inside)
    internal_ns_inside_str, _, _ = get_min_max_str(internal_nodes_inside)

    # props.get takes a key and a default only: the focus flag has to be looked up by its name alone.
    in_focus = n.props.get(IN_FOCUS, False) or n.props.get(UP_FOCUS, False)
    if not is_mixed or not in_focus:
        node_name = '{}{}'.format(state, tips_inside_str)
    else:
        node_name = '{}{}{}'.format(state, ':' if state else '', root_names[0])
    n.add_prop('{}{}'.format(NODE_NAME, suffix), node_name)

    # Unresolved nodes are drawn twice as large.
    size_factor = 2 if n.props.get(UNRESOLVED, False) else 1
    n.add_prop('{}{}'.format(NODE_SIZE, suffix),
               (size_scaling(transform_size(max_n_tips)) if max_n_tips else int(MIN_NODE_SIZE / 1.5))
               * size_factor)
    n.add_prop('{}{}'.format(FONT_SIZE, suffix),
               font_scaling(transform_size(max_n_tips)) if max_n_tips else MIN_FONT_SIZE)

    n.add_prop('node_{}{}'.format(TIPS_INSIDE, suffix), tips_inside_str)
    n.add_prop('node_{}{}'.format(INTERNAL_NODES_INSIDE, suffix), internal_ns_inside_str)
    n.add_prop('node_{}{}'.format(TIPS_BELOW, suffix), tips_below_str)

    # Roots are listed alphabetically, with their dates following the same order.
    root_name2date = dict(zip(root_names, root_dates))
    root_names = sorted(root_names)
    n.add_prop('node_{}{}'.format(ROOTS, suffix), ', '.join(root_names))
    n.add_prop('node_{}{}'.format(ROOT_DATES, suffix), ', '.join(str(root_name2date[_]) for _ in root_names))

    edge_size = len(roots)
    if edge_size > 1:
        n.add_prop('node_meta{}'.format(suffix), 1)
    n.add_prop('{}{}'.format(EDGE_NAME, suffix), str(edge_size) if edge_size != 1 else '')
    e_size = e_size_scaling(transform_e_size(edge_size))
    n.add_prop('{}{}'.format(EDGE_SIZE, suffix), e_size)
def set_cyto_features_tree(n, state):
    """
    Sets the node name (the state) and the edge name (the branch length) properties of a full-tree node.

    :param n: tree node
    :param state: state label of the node
    :return: void, modifies n
    """
    for feature, value in ((NODE_NAME, state), (EDGE_NAME, n.dist)):
        n.add_prop(feature, value)
def _forest2json_compressed(forest, compressed_forest, columns, name_feature, get_date, milestones=None,
                            dates_are_dates=True, is_mixed=False):
    """
    Converts a compressed forest to a dictionary of Cytoscape nodes and edges.

    :param forest: list of the original trees, parallel to compressed_forest;
        nodes more recent than a milestone are removed from them while the milestones are processed
    :param compressed_forest: list of compressed trees
    :param columns: list of the column (character) names
    :param name_feature: name of the property used as the node label (nothing is used if it is empty)
    :param get_date: function returning the date at which a node appears on the timeline
    :param milestones: list of milestone dates (None is treated as no milestones)
    :param dates_are_dates: whether the numeric dates should be formatted as calendar dates
    :param is_mixed: passed on to set_cyto_features_compressed
    :return: tuple (dict with the NODES and EDGES lists, sorted list of the node classes)
    """
    # The declared default would otherwise fail on len(None) below.
    if milestones is None:
        milestones = []

    e_size_scaling, font_scaling, size_scaling, transform_e_size, transform_size = \
        get_size_transformations(compressed_forest)

    def sort_key(n):
        return (n.props.get(UNRESOLVED, 0),
                get_column_value_str(n, name_feature, format_list=True) if name_feature else '',
                *(get_column_value_str(n, column, format_list=True) for column in columns),
                -n.props.get(NUM_TIPS_INSIDE),
                -len(n.props.get(ROOTS)),
                n.name)

    # Number the nodes breadth-first, visiting siblings in the sort_key order.
    i = 0
    node2id = {}
    todo = Queue()
    for compressed_tree in compressed_forest:
        todo.put_nowait(compressed_tree)
        node2id[compressed_tree] = i
        i += 1
        while not todo.empty():
            n = todo.get_nowait()
            for c in sorted(n.children, key=sort_key):
                node2id[c] = i
                i += 1
                todo.put_nowait(c)

    n2state = {}

    # Set the cytoscape features
    for compressed_tree in compressed_forest:
        for n in compressed_tree.traverse():
            state = get_column_value_str(n, name_feature, format_list=False, list_value='') if name_feature else ''
            n2state[n] = state
            root_names = [_.name for _ in n.props.get(ROOTS)]
            root_dates = [get_formatted_date(_, dates_are_dates) for _ in n.props.get(ROOTS)]
            set_cyto_features_compressed(n, size_scaling, e_size_scaling, font_scaling,
                                         transform_size, transform_e_size, state, root_names, root_dates,
                                         is_mixed=is_mixed)

    # Calculate node coordinates
    min_size = 2 * min(min(_.props.get(NODE_SIZE) for _ in compressed_tree.traverse())
                       for compressed_tree in compressed_forest)
    n2width = {}
    for compressed_tree in compressed_forest:
        for n in compressed_tree.traverse('postorder'):
            n2width[n] = max(n.props.get(NODE_SIZE),
                             sum(n2width[c] for c in n.children) + min_size * (len(n.children) - 1))

    n2x, n2y = {}, {compressed_tree: 0 for compressed_tree in compressed_forest}
    n2offset = {}
    tree_offset = 0
    for compressed_tree in compressed_forest:
        n2offset[compressed_tree] = tree_offset
        tree_offset += n2width[compressed_tree] + 2 * min_size
        for n in compressed_tree.traverse('preorder'):
            n2x[n] = n2offset[n] + n2width[n] / 2
            offset = n2offset[n]
            if not n.is_leaf:
                for c in sorted(n.children, key=lambda c: node2id[c]):
                    n2offset[c] = offset
                    offset += n2width[c] + min_size
                    n2y[c] = n2y[n] + n.props.get(NODE_SIZE) / 2 + c.props.get(NODE_SIZE) / 2 + min_size
        # Centre each parent above its children.
        for n in compressed_tree.traverse('postorder'):
            if not n.is_leaf:
                n2x[n] = np.mean([n2x[c] for c in n.children])

    def filter_by_date(items, date):
        return [_ for _ in items if get_date(_) <= date]

    # Set the cytoscape feature for different timeline points
    for tree, compressed_tree in zip(forest, compressed_forest):
        if len(milestones) > 1:
            nodes = list(compressed_tree.traverse())
            # Go from the most recent milestone backwards, keeping only the nodes that still have roots.
            for ms_index in range(len(milestones) - 1, -1, -1):
                milestone = milestones[ms_index]
                nodes_i = []

                # remove too recent nodes from the original tree
                for n in tree.traverse('postorder'):
                    if n.is_root:
                        continue
                    if get_date(n) > milestone:
                        n.up.remove_child(n)

                suffix = '_{}'.format(ms_index)
                for n in nodes:
                    state = n2state[n]
                    tips_inside, tips_below, internal_nodes_inside, roots = n.props.get(TIPS_INSIDE, []), \
                        n.props.get(TIPS_BELOW, []), \
                        n.props.get(INTERNAL_NODES_INSIDE, []), \
                        n.props.get(ROOTS, [])
                    tips_inside_i, tips_below_i, internal_nodes_inside_i, roots_i = [], [], [], []
                    for ti, tb, ini, root in zip(tips_inside, tips_below, internal_nodes_inside, roots):
                        if get_date(root) <= milestone:
                            roots_i.append(root)

                            ti = filter_by_date(ti, milestone)
                            tb = [_ for _ in tb if _.props.get(DATE) <= milestone]
                            ini = filter_by_date(ini, milestone)

                            tips_inside_i.append(ti + [_ for _ in ini if _.is_leaf])
                            tips_below_i.append(tb)
                            internal_nodes_inside_i.append([_ for _ in ini if not _.is_leaf])
                    n.add_props(**{TIPS_INSIDE: tips_inside_i, TIPS_BELOW: tips_below_i,
                                   INTERNAL_NODES_INSIDE: internal_nodes_inside_i, ROOTS: roots_i})
                    if roots_i:
                        n.add_prop(MILESTONE, ms_index)
                        root_names = [_.props.get(BRANCH_NAME) if _.props.get(DATE) > milestone else _.name
                                      for _ in roots_i]
                        root_dates = [_.props.get(DATE) for _ in roots_i]
                        if dates_are_dates:
                            try:
                                root_dates = [numeric2datetime(_).strftime("%d %b %Y") for _ in root_dates]
                            except Exception:
                                # Best effort: keep the numeric dates if they cannot be formatted.
                                pass
                        set_cyto_features_compressed(n, size_scaling, e_size_scaling, font_scaling, transform_size,
                                                     transform_e_size, state, root_names=root_names, suffix=suffix,
                                                     root_dates=root_dates, is_mixed=is_mixed)
                        nodes_i.append(n)
                nodes = nodes_i

    # Save the structure
    clazzes = set()
    nodes, edges = [], []

    one_column = columns[0] if len(columns) == 1 else None

    for n, n_id in node2id.items():
        if one_column:
            values = n.props.get(one_column, set())
            clazz = tuple(sorted(values))
        else:
            clazz = tuple('{}_{}'.format(column, get_column_value_str(n, column, format_list=False, list_value=''))
                          for column in columns)
        if clazz:
            clazzes.add(clazz)
        nodes.append(get_node(n, n_id, tooltip=get_tooltip(n, columns),
                              clazz=clazz, x=n2x[n], y=n2y[n]))

        for child in sorted(n.children, key=lambda _: node2id[_]):
            edge_attributes = {feature: child.props.get(feature) for feature in child.props
                               if feature.startswith('edge_') or feature == MILESTONE or feature == IS_POLYTOMY}
            edges.append(get_edge(n_id, node2id[child], **edge_attributes))

    json_dict = {NODES: nodes, EDGES: edges}
    return json_dict, sorted(clazzes)
def binary_search(start, end, value, array):
    """
    Searches the sorted array between the indices start (inclusive) and end (exclusive)
    for the position of the value.

    Returns an index whose element equals the value if the search hits one. Otherwise returns the index
    of an element greater than the value whose predecessor is smaller than the value (or which is the
    first one of the current search window). If the window shrinks to at most one element,
    its start index is returned.

    :param start: first index of the search window
    :param end: index following the last one of the search window
    :param value: value to look for
    :param array: sorted indexable collection
    :return: int, the index found
    """
    low, high = start, end
    while low < high - 1:
        middle = int((low + high) / 2)
        candidate = array[middle]
        if candidate == value:
            return middle
        if candidate > value:
            if middle == low or value > array[middle - 1]:
                return middle
            high = middle
        else:
            low = middle + 1
    return low
def _forest2json(forest, columns, name_feature, get_date, milestones=None, timeline_type=TIMELINE_SAMPLED,
                 dates_are_dates=True):
    """
    Converts a (non-compressed) forest to a dictionary of Cytoscape nodes and edges.

    Each branch is drawn as two edges connected via an auxiliary ("fake") node,
    and each root gets an auxiliary node above it for its own branch.

    :param forest: list of trees
    :param columns: list of the column (character) names
    :param name_feature: name of the property used as the node label (nothing is used if it is empty)
    :param get_date: function returning the date at which a node appears on the timeline
    :param milestones: list of milestone dates (None is treated as no milestones)
    :param timeline_type: one of the TIMELINE_* constants
    :param dates_are_dates: whether the numeric dates should be formatted as calendar dates
    :return: tuple (dict with the NODES and EDGES lists, sorted list of the node classes)
    """
    # The declared default would otherwise fail on len(None) below.
    if milestones is None:
        milestones = []

    min_root_date = min(tree.props.get(DATE) for tree in forest)
    width = sum(len(tree) for tree in forest)
    height_factor = 300 * width / max(
        (max(_.props.get(DATE) for _ in tree) - min_root_date + tree.dist) for tree in forest)
    # A tree without any positive branch length must not break the search for the shortest one.
    min_positive_dist = min(min((_.dist for _ in tree.traverse() if _.dist > 0), default=np.inf)
                            for tree in forest)
    zero_dist = min(min_positive_dist, 300) * height_factor / 2

    # Calculate node coordinates
    n2x, n2y = {}, {}
    x = -600
    for tree in forest:
        x += 600
        for t in tree:
            n2x[t] = x
            x += 600

    for tree in forest:
        for n in tree.traverse('postorder'):
            state = get_column_value_str(n, name_feature, format_list=False, list_value='') if name_feature else ''
            n.add_prop('node_root_id', n.name)
            n.add_prop('node_root_date', get_formatted_date(n, dates_are_dates))
            if not n.is_leaf:
                n2x[n] = np.mean([n2x[_] for _ in n.children])
            n2y[n] = (n.props.get(DATE) - min_root_date) * height_factor
            # Children that would be drawn at the parent's level are shifted down.
            for c in n.children:
                if n2y[c] == n2y[n]:
                    n2y[c] += zero_dist
            set_cyto_features_tree(n, state)

    # Save cytoscape features at different timeline points
    if len(milestones) > 1:
        for tree in forest:
            for n in tree.traverse('preorder'):
                ms_i = binary_search(0 if n.is_root else n.up.props.get(MILESTONE),
                                     len(milestones), get_date(n), milestones)
                n.add_prop(MILESTONE, ms_i)
                for i in range(len(milestones) - 1, ms_i - 1, -1):
                    milestone = milestones[i]
                    suffix = '_{}'.format(i)
                    if TIMELINE_LTT == timeline_type:
                        # if it is LTT also cut the branches if needed
                        if n.props.get(DATE) > milestone:
                            n.add_prop('{}{}'.format(EDGE_NAME, suffix),
                                       np.round(milestone - n.up.props.get(DATE), 3))
                        else:
                            n.add_prop('{}{}'.format(EDGE_NAME, suffix), np.round(n.dist, 3))

    # Save the structure
    clazzes = set()
    nodes, edges = [], []

    one_column = columns[0] if len(columns) == 1 else None

    i = 0
    node2id = {}
    for tree in forest:
        for n in tree.traverse():
            node2id[n] = i
            i += 1

    for n, n_id in node2id.items():
        if n.is_root:
            fake_id = 'fake_node_{}'.format(n_id)
            nodes.append(get_fake_node(fake_id, n2x[n], n2y[n] - n.dist * height_factor))
            edges.append(get_edge(fake_id, n_id, **{feature: n.props.get(feature) for feature in n.props
                                                    if feature.startswith('edge_') or feature == MILESTONE}))
        if one_column:
            values = n.props.get(one_column, set())
            clazz = tuple(sorted(values))
        else:
            clazz = tuple('{}_{}'.format(column, get_column_value_str(n, column, format_list=False, list_value=''))
                          for column in columns)
        if clazz:
            clazzes.add(clazz)
        nodes.append(get_node(n, n_id, tooltip=get_tooltip(n, columns), clazz=clazz, x=n2x[n], y=n2y[n]))

        for child in n.children:
            edge_attributes = {feature: child.props.get(feature) for feature in child.props
                               if feature.startswith('edge_') or feature == MILESTONE or feature == IS_POLYTOMY}
            # First part of the branch: from the parent to a fake node above the child, at the parent's level.
            source_name = n_id
            target_name = 'fake_node_{}'.format(node2id[child])
            nodes.append(get_fake_node(target_name, x=n2x[child], y=n2y[n]))
            edges.append(get_edge(source_name, target_name, fake=1,
                                  **{k: v for (k, v) in edge_attributes.items()
                                     if EDGE_NAME not in k and IS_POLYTOMY not in k}))
            # Second part: from the fake node to the child, carrying all the edge attributes.
            source_name = target_name
            edges.append(get_edge(source_name, node2id[child], **edge_attributes))

    json_dict = {NODES: nodes, EDGES: edges}
    return json_dict, sorted(clazzes)
def _forest2json_transitions(states, counts, transitions, state2colour, threshold=0):
    """
    Converts transition counts between states to a dictionary of Cytoscape nodes and edges.

    :param states: indexable collection of states
    :param counts: indexable collection of the sample numbers of the states, parallel to states
    :param transitions: square numpy array, where transitions[i, j] is the number of transitions
        from states[i] to states[j]
    :param state2colour: dict, state to colour
    :param threshold: minimal value considered for the thresholds ("miles")
    :return: tuple (dict with the NODES and EDGES lists, list of the thresholds formatted as strings)
    """
    nodes, edges = [], []
    n_states = len(states)

    min_count, max_count = min(counts), max(counts)
    n_scaler = get_scaling_function(y_m=200, y_M=800, x_m=min_count, x_M=max_count)
    font_scaler = get_scaling_function(y_m=MIN_FONT_SIZE, y_M=MIN_FONT_SIZE * 3, x_m=min_count, x_M=max_count)
    positive_transitions = transitions[transitions > 0]
    e_scaler = get_scaling_function(y_m=30, y_M=200, x_m=min(positive_transitions), x_M=max(positive_transitions))

    # Sum up the transitions in both directions for each pair of different states.
    reverse_transitions = np.transpose(transitions.copy())
    np.fill_diagonal(reverse_transitions, 0)
    nums = np.triu(transitions + reverse_transitions).flatten()
    positive_nums = nums[nums > 0]

    max_transition_num = max(positive_nums)
    if max_transition_num <= 2:
        candidates = ({0, threshold, 1 if max_transition_num > 1 else 0} | set(np.round(positive_nums, 3))) \
            - {np.round(max_transition_num, 3)}
    else:
        candidates = {0, threshold, 1} | set(np.trunc(positive_nums))
    miles = np.array([_ for _ in sorted(candidates) if threshold <= _ < max_transition_num])
    if not len(miles):
        miles = [0]

    # we hide connections when they are < mile
    def get_mile(start, end, value):
        # Index of the last mile that does not exceed the value, or None if there is no such mile.
        while start != end:
            i = int((start + end) / 2)
            if miles[i] <= value:
                if i == end - 1 or value < miles[i + 1]:
                    return i
                start = i + 1
            else:
                end = i
        return None

    i2mile = defaultdict(lambda: -1)

    def add_edge(source, target, num, node_size):
        # Adds the source -> target edge, unless the transition number is below all the miles.
        mile = get_mile(0, len(miles), num)
        if mile is None:
            return
        i2mile[source] = max(mile, i2mile[source])
        i2mile[target] = max(mile, i2mile[target])
        edges.append(get_edge(source, target,
                              **{ID: '{}_{}'.format(source, target),
                                 EDGE_SIZE: e_scaler(num),
                                 NODE_SIZE: node_size / (2 if source != target else 1),
                                 EDGE_NAME: num,
                                 TOOLTIP: '{} transitions from {} to {}'
                              .format(num, states[source], states[target]),
                                 MILESTONE: mile}))

    for i in range(n_states):
        from_state = states[i]
        n_tips = counts[i]
        i_node_size = n_scaler(n_tips)
        if n_tips > 0:
            for j in range(i, n_states):
                if counts[j] <= 0:
                    continue
                n_ij = transitions[i, j]
                n_ji = transitions[j, i]
                node_size = (i_node_size + n_scaler(counts[j])) / 2
                if n_ij > 0:
                    add_edge(i, j, n_ij, node_size)
                if n_ji > 0 and i != j:
                    add_edge(j, i, n_ji, node_size)
            # Only the states involved in at least one edge get a node.
            if i2mile[i] >= 0:
                nodes.append(_get_node(data={ID: i, NODE_NAME: '{} ({:.0f})'.format(from_state, n_tips),
                                             NODE_SIZE: i_node_size,
                                             FONT_SIZE: font_scaler(n_tips),
                                             TOOLTIP: '{} is represented by {:.0f} samples.'.format(from_state, n_tips),
                                             COLOUR: state2colour[from_state],
                                             MILESTONE: i2mile[i]}))
    json_dict = {NODES: nodes, EDGES: edges}
    return json_dict, ['{:g}'.format(_) for _ in miles]
def get_size_transformations(forest):
    """
    Builds the functions used to convert the tip and root numbers of compressed nodes
    to node, font and edge sizes.

    The tip number of a node is read from its NUM_TIPS_INSIDE property (values below 1 count as 1),
    and its root number is the length of its ROOTS property. If the largest value is more than
    100 times the smallest one, the values are log-transformed before scaling.

    :param forest: list of compressed trees
    :return: tuple (e_size_scaling, font_scaling, size_scaling, transform_e_size, transform_size)
    """
    largest, smallest = 1, np.inf
    largest_e, smallest_e = 1, np.inf
    for tree in forest:
        for node in tree.traverse():
            n_tips = max(node.props.get(NUM_TIPS_INSIDE), 1)
            largest = max(largest, n_tips)
            smallest = min(smallest, n_tips)
            n_roots = len(node.props.get(ROOTS))
            largest_e = max(largest_e, n_roots)
            smallest_e = min(smallest_e, n_roots)

    size_ratio = largest / smallest
    need_log = size_ratio > 100

    def transform_size(_):
        return np.power(np.log10(_ + 9) if need_log else _, 1 / 2)

    e_size_ratio = largest_e / smallest_e
    need_e_log = e_size_ratio > 100

    def transform_e_size(_):
        return np.log10(_) if need_e_log else _

    x_min, x_max = transform_size(smallest), transform_size(largest)
    size_scaling = get_scaling_function(y_m=MIN_NODE_SIZE, y_M=MIN_NODE_SIZE * min(8, int(size_ratio)),
                                        x_m=x_min, x_M=x_max)
    font_scaling = get_scaling_function(y_m=MIN_FONT_SIZE, y_M=MIN_FONT_SIZE * min(3, int(size_ratio)),
                                        x_m=x_min, x_M=x_max)
    e_size_scaling = get_scaling_function(y_m=MIN_EDGE_SIZE, y_M=MIN_EDGE_SIZE * min(3, int(e_size_ratio)),
                                          x_m=transform_e_size(smallest_e), x_M=transform_e_size(largest_e))

    return e_size_scaling, font_scaling, size_scaling, transform_e_size, transform_size
def get_tooltip(n, columns):
    """
    Builds the tooltip of a node: one 'column: value' entry per column, separated by '<br>'.

    :param n: tree node
    :param columns: iterable of column names
    :return: str, the tooltip
    """
    entries = []
    for column in columns:
        entries.append('{}: {}'.format(column, get_column_value_str(n, column, format_list=True)))
    return '<br>'.join(entries)
def save_as_cytoscape_html(forest, out_html, column2states, name_feature, name2colour, compressed_forest,
                           milestone_label, timeline_type, milestones, get_date, work_dir, local_css_js=False,
                           milestone_labels=None, is_mixed=False):
    """
    Converts a forest to an html representation using Cytoscape.js.

    If categories are specified they are visualised as pie-charts inside the nodes,
    given that each node contains features corresponding to these categories with values being the percentage.
    For instance, given categories ['A', 'B', 'C'], a node with features {'A': 50, 'B': 50}
    will have a half-half pie-chart (half-colored in a colour of A, and half B).

    :param forest: list of trees
    :param out_html: path where to save the resulting html file.
    :param column2states: dict, column name to its states; its sorted keys are used as the columns
    :param name_feature: str, a node feature whose value will be used as a label
    :param name2colour: dict, str to str, category name to HEX colour mapping
    :param compressed_forest: list of compressed trees, or None to visualise the forest itself
    :param milestone_label: str, name of the timeline; DATE_LABEL means the dates are calendar dates
    :param timeline_type: one of the TIMELINE_* constants
    :param milestones: list of milestone dates
    :param get_date: function returning the date at which a node appears on the timeline
    :param work_dir: directory to which the js, css and font files are copied if local_css_js is True
    :param local_css_js: whether the page should refer to local copies of the js and css files
        instead of the remote ones
    :param milestone_labels: list of milestone labels (by default the milestones formatted with '{:g}')
    :param is_mixed: passed on to the compressed forest conversion
    :return: void, writes the html file
    """
    graph_name = os.path.splitext(os.path.basename(out_html))[0]
    columns = sorted(column2states.keys())
    if milestone_labels is None:
        milestone_labels = ['{:g}'.format(_) for _ in milestones]

    if compressed_forest is not None:
        json_dict, clazzes \
            = _forest2json_compressed(forest, compressed_forest, columns, name_feature=name_feature, get_date=get_date,
                                      milestones=milestones, dates_are_dates=milestone_label == DATE_LABEL,
                                      is_mixed=is_mixed)
    else:
        json_dict, clazzes \
            = _forest2json(forest, columns, name_feature=name_feature, get_date=get_date, milestones=milestones,
                           timeline_type=timeline_type, dates_are_dates=milestone_label == DATE_LABEL)
    loader = PackageLoader('pastml')
    env = Environment(loader=loader)
    template = env.get_template('pie_tree.js') if compressed_forest is not None \
        else env.get_template('pie_tree_simple.js')

    # The backslash before the percent sign is emitted as is: it is escaped here
    # only to avoid an invalid escape sequence in the Python literal.
    pie_slice_css = "\n 'pie-{i}-background-color': \"{colour}\",\n 'pie-{i}-background-size': '{percent}\\%',\n "
    clazz2css = {}
    for clazz_list in clazzes:
        # All the categories of a class get equal slices of the pie.
        percent = round(100 / len(clazz_list), 2)
        clazz2css[_clazz_list2css_class(clazz_list)] = ''.join(
            pie_slice_css.format(i=i, percent=percent, colour=name2colour[cat])
            for i, cat in enumerate(clazz_list, start=1))
    graph = template.render(clazz2css=clazz2css.items(), elements=json_dict, title=graph_name,
                            years=milestone_labels,
                            tips='samples' if TIMELINE_SAMPLED == timeline_type
                            else ('lineages ending' if TIMELINE_LTT == timeline_type else 'external nodes'),
                            internal_nodes='internal nodes' if TIMELINE_NODES == timeline_type
                            else 'diversification events',
                            age_label=milestone_label)
    # The slider is only needed if there is more than one milestone.
    slider = env.get_template('time_slider.html').render(min_date=0, max_date=len(milestones) - 1,
                                                         cur_date=len(milestones) - 1,
                                                         name=milestone_label) if len(milestones) > 1 else ''

    template = env.get_template('index.html')
    os.makedirs(os.path.abspath(os.path.dirname(out_html)), exist_ok=True)

    if local_css_js:
        js_list = []
        os.makedirs(os.path.join(work_dir, 'js'), exist_ok=True)
        os.makedirs(os.path.join(work_dir, 'css'), exist_ok=True)
        os.makedirs(os.path.join(work_dir, 'fonts'), exist_ok=True)

        # Copy the bundled files to the work directory; only the *.js and *.css ones are listed in the page.
        template_dir = os.path.join(os.path.abspath(os.path.split(__file__)[0]), '..', 'templates')
        for _ in sorted(glob(os.path.join(template_dir, 'js', '*.js*'))):
            cp = os.path.join(work_dir, 'js', os.path.split(_)[1])
            copyfile(_, cp)
            if cp.endswith('.js'):
                js_list.append(cp)
        css_list = []
        for _ in glob(os.path.join(template_dir, 'css', '*.css*')):
            cp = os.path.join(work_dir, 'css', os.path.split(_)[1])
            copyfile(_, cp)
            if cp.endswith('.css'):
                css_list.append(cp)
        for _ in glob(os.path.join(template_dir, 'fonts', '*.*')):
            cp = os.path.join(work_dir, 'fonts', os.path.split(_)[1])
            copyfile(_, cp)
    else:
        js_list = JS_LIST
        css_list = CSS_LIST
    page = template.render(graph=graph, title=graph_name, slider=slider, js_list=js_list, css_list=css_list)

    with open(out_html, 'w+') as fp:
        fp.write(page)
def save_as_transition_html(character, states, counts, transitions, out_html, state2colour, work_dir,
                            local_css_js=False, threshold=0):
    """
    Converts transition count data to an html representation using Cytoscape.js.

    Note that the transitions below the threshold are set to zero in the input array itself.

    :param character: passed to the graph template as the character
    :param states: collection of states
    :param counts: collection of the sample numbers of the states
    :param transitions: array of the transition numbers between the states (modified in place)
    :param out_html: path where to save the resulting html file.
    :param state2colour: dict, state to colour
    :param work_dir: directory to which the js, css and font files are copied if local_css_js is True
    :param local_css_js: whether the page should refer to local copies of the js and css files
        instead of the remote ones
    :param threshold: transition numbers below it are discarded
    :return: void, writes the html file
    """
    graph_name = os.path.splitext(os.path.basename(out_html))[0]
    transitions[transitions < threshold] = 0
    json_dict, thresholds = _forest2json_transitions(states, counts, transitions, state2colour, threshold=threshold)

    env = Environment(loader=PackageLoader('pastml'))
    graph = env.get_template('transitions.js').render(elements=json_dict, character=character,
                                                      years=thresholds)
    # The slider is only needed if there is more than one threshold.
    if len(thresholds) > 1:
        slider = env.get_template('time_slider.html').render(min_date=0, max_date=len(thresholds) - 1,
                                                             name='transition number threshold', cur_date=0)
    else:
        slider = ''

    index_template = env.get_template('index.html')
    os.makedirs(os.path.abspath(os.path.dirname(out_html)), exist_ok=True)

    if not local_css_js:
        js_list = JS_LIST
        css_list = CSS_LIST
    else:
        for sub_dir in ('js', 'css', 'fonts'):
            os.makedirs(os.path.join(work_dir, sub_dir), exist_ok=True)

        # Copy the bundled files to the work directory; only the *.js and *.css ones are listed in the page.
        template_dir = os.path.join(os.path.abspath(os.path.split(__file__)[0]), '..', 'templates')
        js_list = []
        for source_path in sorted(glob(os.path.join(template_dir, 'js', '*.js*'))):
            target_path = os.path.join(work_dir, 'js', os.path.split(source_path)[1])
            copyfile(source_path, target_path)
            if target_path.endswith('.js'):
                js_list.append(target_path)
        css_list = []
        for source_path in glob(os.path.join(template_dir, 'css', '*.css*')):
            target_path = os.path.join(work_dir, 'css', os.path.split(source_path)[1])
            copyfile(source_path, target_path)
            if target_path.endswith('.css'):
                css_list.append(target_path)
        for source_path in glob(os.path.join(template_dir, 'fonts', '*.*')):
            copyfile(source_path, os.path.join(work_dir, 'fonts', os.path.split(source_path)[1]))
    page = index_template.render(graph=graph, title=graph_name, slider=slider, js_list=js_list, css_list=css_list)

    with open(out_html, 'w+') as fp:
        fp.write(page)
661def _clazz_list2css_class(clazz_list):
662 if not clazz_list:
663 return None
664 return ''.join(c for c in '-'.join(clazz_list) if c.isalnum() or '-' == c)
def _get_node(data, clazz=None, position=None):
    """
    Wraps node data into an element dictionary.

    :param data: dict of node attributes; the coordinates are added to it if the position is given
    :param clazz: value of the 'classes' entry (omitted if empty)
    :param position: pair of (x, y) coordinates, stored as 'node_x' and 'node_y'
    :return: dict, the node element
    """
    if position:
        data['node_x'] = position[0]
        data['node_y'] = position[1]
    if not clazz:
        return {DATA: data}
    return {DATA: data, 'classes': clazz}
def _get_edge(**attributes):
    """
    Wraps edge attributes, given as keyword arguments, into an element dictionary.

    :return: dict, the edge element
    """
    edge = {}
    edge[DATA] = attributes
    return edge
def get_column_value_str(n, column, format_list=True, list_value=''):
    """
    Represents the value(s) of a column for a node as a string.

    :param n: node, whose props contain the column values
    :param column: name of the column
    :param format_list: whether several (or zero) values should be joined with ' or '
    :param list_value: returned for several (or zero) values if format_list is False
    :return: str: the value itself if it is a string, the sorted values joined with ' or '
        if format_list is True or there is exactly one value, list_value otherwise
    """
    values = n.props.get(column, set())
    if isinstance(values, str):
        return values
    if not format_list and len(values) != 1:
        return list_value
    return ' or '.join(sorted(values))
def visualize(forest, column2states, work_dir, name_column=None, html=None, html_compressed=None, html_mixed=None,
              tip_size_threshold=REASONABLE_NUMBER_OF_TIPS, date_label='Dist. to root', timeline_type=TIMELINE_SAMPLED,
              local_css_js=False, column2colours=None, focus=None, pajek=None, pajek_timing=VERTICAL):
    """
    Annotates the forest with visualisation properties, saves the state colours,
    and produces the requested html visualisations (full, compressed and/or mixed).

    :param forest: list of trees (their nodes are annotated in place)
    :param column2states: dict, column name to its states
    :param work_dir: directory where the colour files (and local js/css files, if requested) are saved
    :param name_column: column used for the node labels in the compressed and mixed visualisations
        and for the 'edge_color' property
    :param html: path of the full tree visualisation (skipped if empty)
    :param html_compressed: path of the compressed visualisation (skipped if empty)
    :param html_mixed: path of the mixed visualisation (skipped if empty)
    :param tip_size_threshold: passed to the tree compression
    :param date_label: name of the timeline; if it equals DATE_LABEL, the milestones are labelled with dates
    :param timeline_type: one of TIMELINE_NODES, TIMELINE_SAMPLED or TIMELINE_LTT
    :param local_css_js: whether the pages should refer to local copies of the js and css files
    :param column2colours: dict, column name to its colour specification, parsed with parse_colours
    :param focus: mapping from column names to the states in focus
        (presumably sets, as it is intersected with the node states; verify against the caller)
    :param pajek: destination passed to save_to_pajek for the compressed forest (skipped if empty)
    :param pajek_timing: passed to the tree compression
    :raises ValueError: if the timeline type is unknown
    :return: void
    """
    # Mark the nodes in focus, their parents and their children that are not in focus themselves.
    for tree in forest:
        nodes_in_focus = set()
        for node in tree.traverse():
            for column in column2states.keys():
                col_state = node.props.get(column, set())
                if focus and col_state & focus[column]:
                    nodes_in_focus.add(node)
        for node in nodes_in_focus:
            node.add_prop(IN_FOCUS, True)
            if not node.is_root and node.up not in nodes_in_focus:
                node.up.add_prop(UP_FOCUS, True)
            for c in node.children:
                if c not in nodes_in_focus:
                    c.add_prop(AROUND_FOCUS, True)

    one_column = next(iter(column2states.keys())) if len(column2states) == 1 else None

    # Assign colours to the states and save them to files.
    name2colour = {}
    for column, states in column2states.items():
        num_unique_values = len(states)
        colours = None
        if column2colours and column in column2colours:
            try:
                colours = parse_colours(column2colours[column], states)
            except ValueError as e:
                logging.getLogger('pastml').error('Failed to parse the input colours: {}'.format(e))
        if colours is None:
            colours = get_enough_colours(num_unique_values)
        for value, col in zip(states, colours):
            name2colour[value if one_column else '{}_{}'.format(column, value)] = col
        state2color = dict(zip(states, colours))
        # let ambiguous values be white
        if one_column is None:
            name2colour['{}_'.format(column)] = WHITE
        if column == name_column:
            # Colour the branches that connect two nodes resolved to the same state.
            for tree in forest:
                for n in tree.traverse():
                    sts = n.props.get(column, set())
                    if len(sts) == 1 and not n.is_root and n.up.props.get(column, set()) == sts:
                        n.add_prop('edge_color', state2color[next(iter(sts))])
        out_colour_file = os.path.join(work_dir, get_pastml_colour_file(column))
        # Not using DataFrames to speed up document writing
        with open(out_colour_file, 'w+') as f:
            f.write('state\tcolour\n')
            for s in sorted(states):
                f.write('{}\t{}\n'.format(s, state2color[s]))
        logging.getLogger('pastml').debug('Mapped states to colours for {} as following: {} -> {}, '
                                          'and serialized this mapping to {}.'
                                          .format(column, states, colours, out_colour_file))

    for tree in forest:
        for node in tree.traverse():
            if node.is_leaf:
                node.add_prop(IS_TIP, True)
            node.add_prop(BRANCH_NAME, '{}-{}'.format(node.up.name if not node.is_root else '', node.name))
            # A node is unresolved if it does not have exactly one state for at least one column.
            for column in column2states.keys():
                col_state = node.props.get(column, set())
                if len(col_state) != 1:
                    node.add_prop(UNRESOLVED, 1)
                    break

    if TIMELINE_NODES == timeline_type:
        def get_date(node):
            return node.props.get(DATE)
    elif TIMELINE_SAMPLED == timeline_type:
        max_date = max(max(_.props.get(DATE) for _ in tree) for tree in forest)

        def get_date(node):
            tips = [_ for _ in node if _.props.get(IS_TIP, False)]
            return min(_.props.get(DATE) for _ in tips) if tips else max_date
    elif TIMELINE_LTT == timeline_type:
        def get_date(node):
            return node.props.get(DATE) if node.is_root else (node.up.props.get(DATE) + 1e-6)
    else:
        raise ValueError('Unknown timeline type: {}. Allowed ones are {}, {} and {}.'
                         .format(timeline_type, TIMELINE_NODES, TIMELINE_SAMPLED, TIMELINE_LTT))

    # Milestones are the octiles of the sorted dates (plus the last date), without duplicates.
    dates = []
    for tree in forest:
        dates.extend([_.props.get(DATE) for _ in (tree.traverse()
                                                  if timeline_type in [TIMELINE_LTT, TIMELINE_NODES] else tree)])
    dates = sorted(dates)
    milestones = sorted({dates[0], dates[len(dates) // 8], dates[len(dates) // 4], dates[3 * len(dates) // 8],
                         dates[len(dates) // 2], dates[5 * len(dates) // 8], dates[3 * len(dates) // 4],
                         dates[7 * len(dates) // 8], dates[-1]})
    milestone_labels = None
    if DATE_LABEL == date_label:
        try:
            milestone_labels = [numeric2datetime(_).strftime("%d %b %Y") for _ in milestones]
        except Exception:
            # Best effort: fall back to the numeric labels if the dates cannot be formatted.
            pass
    if milestone_labels is None:
        milestone_labels = ['{:g}'.format(_) for _ in milestones]

    if html:
        total_num_tips = sum(len(tree) for tree in forest)
        if total_num_tips > MAX_TIPS_FOR_FULL_TREE_VISUALISATION:
            logging.getLogger('pastml').error('The full tree{} will not be visualised as {} too large ({} tips): '
                                              'the limit is {} tips. Check out upload to iTOL option instead.'
                                              .format('s' if len(forest) > 1 else '',
                                                      'they are' if len(forest) > 1 else 'it is',
                                                      total_num_tips, MAX_TIPS_FOR_FULL_TREE_VISUALISATION))
        else:
            save_as_cytoscape_html(forest, html, column2states=column2states, name2colour=name2colour,
                                   name_feature='name', compressed_forest=None, milestone_label=date_label,
                                   timeline_type=timeline_type, milestones=milestones, get_date=get_date,
                                   work_dir=work_dir, local_css_js=local_css_js, milestone_labels=milestone_labels)

    # The mixed visualisation works on its own copy of the forest if the compressed one is also requested.
    if html_compressed and html_mixed:
        forest_mixed = copy_forest(forest)
    else:
        forest_mixed = forest

    if html_compressed:
        pajek_vert_arcs = [[], []] if pajek else None
        compressed_forest = [compress_tree(tree, columns=column2states.keys(), tip_size_threshold=tip_size_threshold,
                                           mixed=False, pajek=pajek_vert_arcs, pajek_timing=pajek_timing)
                             for tree in forest]
        if pajek:
            save_to_pajek(*pajek_vert_arcs, pajek)

        milestone_labels, milestones = update_milestones(forest, date_label, milestone_labels, milestones,
                                                         timeline_type)

        save_as_cytoscape_html(forest, html_compressed,
                               column2states=column2states, name2colour=name2colour,
                               name_feature=name_column, compressed_forest=compressed_forest,
                               milestone_label=date_label, timeline_type=timeline_type,
                               milestones=milestones, get_date=get_date, work_dir=work_dir, local_css_js=local_css_js,
                               milestone_labels=milestone_labels, is_mixed=False)

    if html_mixed:
        mixed_forest = [compress_tree(tree, columns=column2states.keys(), tip_size_threshold=tip_size_threshold,
                                      mixed=True) for tree in forest_mixed]
        milestone_labels, milestones = update_milestones(forest_mixed, date_label, milestone_labels, milestones,
                                                         timeline_type)
        save_as_cytoscape_html(forest_mixed, html_mixed,
                               column2states=column2states, name2colour=name2colour,
                               name_feature=name_column, compressed_forest=mixed_forest,
                               milestone_label=date_label, timeline_type=timeline_type,
                               milestones=milestones, get_date=get_date, work_dir=work_dir, local_css_js=local_css_js,
                               milestone_labels=milestone_labels, is_mixed=True)
def update_milestones(forest, date_label, milestone_labels, milestones, timeline_type):
    """
    Adjusts the milestones (and their labels) to the dates that are present in the forest.

    The milestones outside of the forest's date range are dropped, and the first and the last
    dates of the forest are added as milestones if they are not covered.
    The given labels are kept only if the milestones did not change; otherwise they are regenerated.

    :param forest: list of trees
    :param date_label: name of the timeline; if it equals DATE_LABEL, the labels are formatted as dates
    :param milestone_labels: list of the current milestone labels (can be None)
    :param milestones: list of the current milestones (not modified)
    :param timeline_type: for TIMELINE_LTT and TIMELINE_NODES the dates of all the nodes are considered,
        otherwise only the dates of the nodes yielded by iterating over each tree
    :return: tuple (milestone_labels, milestones)
    """
    # If we trimmed a few tips while compressing and they happened to be the oldest/newest ones,
    # we should update the milestones accordingly.
    first_date, last_date = np.inf, -np.inf
    for tree in forest:
        for _ in (tree.traverse() if timeline_type in [TIMELINE_LTT, TIMELINE_NODES] else tree):
            first_date = min(first_date, _.props.get(DATE))
            last_date = max(last_date, _.props.get(DATE))

    previous_milestones = list(milestones)
    milestones = [ms for ms in milestones if first_date <= ms <= last_date]
    # The date range is empty only if no date was found at all.
    if first_date <= last_date:
        # The filtering might have removed all the milestones: guard the [0] and [-1] accesses.
        if not milestones or milestones[0] > first_date:
            milestones.insert(0, first_date)
        if milestones[-1] < last_date:
            milestones.append(last_date)

    # Labels that were made for a different milestone list would be misleading.
    if milestones != previous_milestones:
        milestone_labels = None

    if DATE_LABEL == date_label:
        try:
            milestone_labels = [numeric2datetime(_).strftime("%d %b %Y") for _ in milestones]
        except Exception:
            # Best effort: fall back to the other labels if the dates cannot be formatted.
            pass
    if milestone_labels is None:
        milestone_labels = ['{:g}'.format(_) for _ in milestones]
    return milestone_labels, milestones