Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/visualisation/cytoscape_manager.py: 11%

561 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-21 09:19 +0100

1import logging 

2import os 

3from collections import defaultdict 

4from glob import glob 

5from queue import Queue 

6from shutil import copyfile 

7 

8import numpy as np 

9from jinja2 import Environment, PackageLoader 

10 

11from pastml import numeric2datetime 

12from pastml.file import get_pastml_colour_file 

13from pastml.tree import DATE, IS_POLYTOMY, copy_forest 

14from pastml.visualisation import get_formatted_date 

15from pastml.visualisation.colour_generator import get_enough_colours, WHITE, parse_colours 

16from pastml.visualisation.tree_compressor import NUM_TIPS_INSIDE, TIPS_INSIDE, TIPS_BELOW, \ 

17 REASONABLE_NUMBER_OF_TIPS, compress_tree, INTERNAL_NODES_INSIDE, ROOTS, IS_TIP, ROOT_DATES, IN_FOCUS, AROUND_FOCUS, \ 

18 UP_FOCUS, save_to_pajek, VERTICAL 

19 

# Script URLs handed to the index.html template when local copies of the assets are not requested.
JS_LIST = ["https://pastml.pasteur.fr/static/js/jquery.min.js",
           "https://pastml.pasteur.fr/static/js/jquery.qtip.min.js",
           "https://pastml.pasteur.fr/static/js/cytoscape.min.js",
           "https://pastml.pasteur.fr/static/js/cytoscape-qtip.js",
           "https://pastml.pasteur.fr/static/js/cytoscape-svg.js",
           "https://pastml.pasteur.fr/static/js/layout-base.min.js",
           "https://pastml.pasteur.fr/static/js/cose-base.min.js",
           "https://pastml.pasteur.fr/static/js/cytoscape-cose-bilkent.min.js"]
# Stylesheet URLs handed to the index.html template when local copies of the assets are not requested.
CSS_LIST = ["https://pastml.pasteur.fr/static/css/jquery.qtip.min.css",
            "https://pastml.pasteur.fr/static/css/bootstrap.min.css"]

# visualize() does not produce the full-tree html when the forest has more tips than this.
MAX_TIPS_FOR_FULL_TREE_VISUALISATION = 5000

# Supported timeline types (see the timeline_type arguments below).
TIMELINE_SAMPLED = 'SAMPLED'
TIMELINE_NODES = 'NODES'
TIMELINE_LTT = 'LTT'

# NOTE(review): not referenced in the visible part of this file.
TIP_LIMIT = 1000

# Lower bounds used when scaling edge, font and node sizes.
MIN_EDGE_SIZE = 50
MIN_FONT_SIZE = 80
MIN_NODE_SIZE = 200

# Node property / element data keys.
UNRESOLVED = 'unresolved'
TIP = 'tip'

TOOLTIP = 'tooltip'
COLOUR = 'colour'

# Keys of the json structure passed to the templates.
DATA = 'data'
ID = 'id'
EDGES = 'edges'
NODES = 'nodes'
ELEMENTS = 'elements'

# Names of the properties that carry the visual features of nodes and edges.
NODE_SIZE = 'node_size'
NODE_NAME = 'node_name'
BRANCH_NAME = 'branch_name'
EDGE_SIZE = 'edge_size'
EDGE_NAME = 'edge_name'
FONT_SIZE = 'node_fontsize'

# Property holding a milestone index.
MILESTONE = 'mile'

# Timeline label value for which milestones are formatted as dates.
DATE_LABEL = 'date'

DIST_TO_ROOT_LABEL = 'dist. to root'

67 

68 

def get_fake_node(n_id, x, y):
    """
    Creates a node element marked as fake (data entry 'fake': 1) and positioned at (x, y).

    :param n_id: value stored under the ID key of the element data
    :param x: x coordinate of the element
    :param y: y coordinate of the element
    :return: dict, the element produced by _get_node
    """
    fake_data = {}
    fake_data[ID] = n_id
    fake_data['fake'] = 1
    coordinates = (x, y)
    return _get_node(fake_data, position=coordinates)

72 

73 

def get_node(n, n_id, tooltip='', clazz=None, x=0, y=0):
    """
    Creates a node element for the tree node n.

    The element data contains those properties of n that are named MILESTONE, UNRESOLVED, 'x', 'y'
    or start with 'node_', plus the id, the tooltip and (for leaves) a TIP flag.

    :param n: tree node exposing .props and .is_leaf
    :param n_id: value stored under the ID key
    :param tooltip: value stored under the TOOLTIP key
    :param clazz: iterable of class names (converted with _clazz_list2css_class) or None
    :param x: x coordinate; if None, no position is set
    :param y: y coordinate
    :return: dict, the element produced by _get_node
    """
    always_exported = (MILESTONE, UNRESOLVED, 'x', 'y')
    features = {}
    for feature in n.props:
        if feature in always_exported or feature.startswith('node_'):
            features[feature] = n.props.get(feature)
    features[ID] = n_id
    if n.is_leaf:
        features[TIP] = 1
    features[TOOLTIP] = tooltip

    position = None if x is None else (x, y)
    css_class = _clazz_list2css_class(clazz)
    return _get_node(features, clazz=css_class, position=position)

82 

83 

def get_edge(source_name, target_name, **kwargs):
    """
    Creates an edge element from source_name to target_name.

    :param source_name: value stored under the 'source' key
    :param target_name: value stored under the 'target' key
    :param kwargs: additional entries of the edge data
    :return: dict, the element produced by _get_edge
    """
    edge = _get_edge(source=source_name, target=target_name, **kwargs)
    return edge

86 

87 

def get_scaling_function(y_m, y_M, x_m, x_M):
    r"""
    Returns a linear function y = k x + b such that x_m is mapped to y_m and x_M to y_M,
    i.e. y \in [y_m, y_M] for x \in [x_m, x_M]. The result is truncated with int().

    If the input range is empty or degenerate (x_M <= x_m), a constant function
    returning y_m (as is, without the int conversion) is returned instead.

    :param y_m: output value corresponding to x_m
    :param y_M: output value corresponding to x_M
    :param x_m: lower bound of the input range
    :param x_M: upper bound of the input range
    :return: function of one argument
    """
    # The docstring is raw: "\i" is an invalid escape sequence in a regular string literal.
    if x_M <= x_m:
        return lambda _: y_m
    k = (y_M - y_m) / (x_M - x_m)
    b = y_m - k * x_m
    return lambda _: int(k * _ + b)

102 

103 

def set_cyto_features_compressed(n, size_scaling, e_size_scaling, font_scaling, transform_size, transform_e_size, state,
                                 root_names, root_dates, suffix='', is_mixed=False):
    """
    Sets the visual properties (name, size, font size, tooltips contents, edge name and size)
    on a node of a compressed tree.

    :param n: compressed-tree node exposing .props and .add_prop
    :param size_scaling: function converting a transformed tip count into a node size
    :param e_size_scaling: function converting a transformed edge size into an edge width
    :param font_scaling: function converting a transformed tip count into a font size
    :param transform_size: transformation applied to the tip count before scaling
    :param transform_e_size: transformation applied to the edge size before scaling
    :param state: str, state label of the node
    :param root_names: list of names of the roots represented by this node
    :param root_dates: list of (formatted) dates of these roots, in the same order as root_names
    :param suffix: str, appended to every property name that is set (used for timeline points)
    :param is_mixed: if True, nodes flagged with IN_FOCUS or UP_FOCUS are labelled
        with their first root name instead of the tip count
    :return: None, the properties are added to n
    """
    tips_inside = n.props.get(TIPS_INSIDE, [])
    tips_below = n.props.get(TIPS_BELOW, [])
    internal_nodes_inside = n.props.get(INTERNAL_NODES_INSIDE, [])
    roots = n.props.get(ROOTS, [])

    def get_min_max_str(values, default_value=0):
        # values is a list of collections: summarise their sizes as ' min-max' (or ' min' if they coincide)
        if values:
            sizes = [len(_) for _ in values]
            min_v, max_v = min(sizes), max(sizes)
        else:
            min_v, max_v = default_value, default_value
        return ' {}'.format('{}-{}'.format(min_v, max_v) if min_v != max_v else min_v), min_v, max_v

    tips_below_str, _, max_n_tips_below = get_min_max_str(tips_below)
    tips_inside_str, _, max_n_tips = get_min_max_str(tips_inside)
    internal_ns_inside_str, _, _ = get_min_max_str(internal_nodes_inside)

    # Fixed: the original called n.props.get(n, IN_FOCUS, False), i.e. dict.get with three arguments,
    # which raises a TypeError as soon as is_mixed is True.
    around_focus = n.props.get(IN_FOCUS, False) or n.props.get(UP_FOCUS, False)
    if is_mixed and around_focus:
        node_name = '{}{}{}'.format(state, ':' if state else '', root_names[0])
    else:
        node_name = '{}{}'.format(state, tips_inside_str)
    n.add_prop('{}{}'.format(NODE_NAME, suffix), node_name)

    # unresolved nodes are drawn twice as large
    size_factor = 2 if n.props.get(UNRESOLVED, False) else 1
    n.add_prop('{}{}'.format(NODE_SIZE, suffix),
               (size_scaling(transform_size(max_n_tips)) if max_n_tips else int(MIN_NODE_SIZE / 1.5))
               * size_factor)
    n.add_prop('{}{}'.format(FONT_SIZE, suffix),
               font_scaling(transform_size(max_n_tips)) if max_n_tips else MIN_FONT_SIZE)

    n.add_prop('node_{}{}'.format(TIPS_INSIDE, suffix), tips_inside_str)
    n.add_prop('node_{}{}'.format(INTERNAL_NODES_INSIDE, suffix), internal_ns_inside_str)
    n.add_prop('node_{}{}'.format(TIPS_BELOW, suffix), tips_below_str)
    # keep the name -> date association before the names get sorted
    root_name2date = dict(zip(root_names, root_dates))
    root_names = sorted(root_names)
    n.add_prop('node_{}{}'.format(ROOTS, suffix), ', '.join(root_names))
    n.add_prop('node_{}{}'.format(ROOT_DATES, suffix), ', '.join(str(root_name2date[_]) for _ in root_names))

    edge_size = len(roots)
    if edge_size > 1:
        n.add_prop('node_meta{}'.format(suffix), 1)
    n.add_prop('{}{}'.format(EDGE_NAME, suffix), str(edge_size) if edge_size != 1 else '')
    e_size = e_size_scaling(transform_e_size(edge_size))
    n.add_prop('{}{}'.format(EDGE_SIZE, suffix), e_size)

143 

144 

def set_cyto_features_tree(n, state):
    """
    Sets the visual properties of a node of a full (non-compressed) tree:
    the node is labelled with its state and its edge with n.dist.

    :param n: tree node exposing .add_prop and .dist
    :param state: value to be used as the node label
    :return: None, the properties are added to n
    """
    n.add_prop(NODE_NAME, state)
    edge_label = n.dist
    n.add_prop(EDGE_NAME, edge_label)

148 

149 

def _forest2json_compressed(forest, compressed_forest, columns, name_feature, get_date, milestones=None,
                            dates_are_dates=True, is_mixed=False):
    """
    Converts a compressed forest into the nodes/edges structure used by the cytoscape templates.

    Note that the function modifies its inputs: visual properties are added to the compressed nodes,
    their TIPS_INSIDE, TIPS_BELOW, INTERNAL_NODES_INSIDE and ROOTS properties are overwritten
    while processing the milestones, and too recent nodes are removed from the trees of forest.

    :param forest: list of original trees, in the same order as compressed_forest
    :param compressed_forest: list of compressed trees
    :param columns: list of column (property) names used for classes, tooltips and sorting
    :param name_feature: name of the property used as the node label (falsy for no label)
    :param get_date: function that returns the date of a node
    :param milestones: sorted list of timeline points; timeline properties are only set
        if it contains more than one element (None is treated as no timeline)
    :param dates_are_dates: whether the dates should be formatted as calendar dates
    :param is_mixed: passed to set_cyto_features_compressed
    :return: tuple (dict with NODES and EDGES lists, sorted list of class tuples)
    """
    e_size_scaling, font_scaling, size_scaling, transform_e_size, transform_size = \
        get_size_transformations(compressed_forest)

    sort_key = lambda n: (n.props.get(UNRESOLVED, 0),
                          get_column_value_str(n, name_feature, format_list=True) if name_feature else '',
                          *(get_column_value_str(n, column, format_list=True) for column in columns),
                          -n.props.get(NUM_TIPS_INSIDE),
                          -len(n.props.get(ROOTS)),
                          n.name)
    # Assign ids tree by tree, in breadth-first order with siblings ordered by sort_key
    i = 0
    node2id = {}
    todo = Queue()
    for compressed_tree in compressed_forest:
        todo.put_nowait(compressed_tree)
        node2id[compressed_tree] = i
        i += 1
        while not todo.empty():
            n = todo.get_nowait()
            for c in sorted(n.children, key=sort_key):
                node2id[c] = i
                i += 1
                todo.put_nowait(c)

    n2state = {}

    # Set the cytoscape features
    for compressed_tree in compressed_forest:
        for n in compressed_tree.traverse():
            state = get_column_value_str(n, name_feature, format_list=False, list_value='') if name_feature else ''
            n2state[n] = state
            root_names = [_.name for _ in n.props.get(ROOTS)]
            root_dates = [get_formatted_date(_, dates_are_dates) for _ in n.props.get(ROOTS)]
            set_cyto_features_compressed(n, size_scaling, e_size_scaling, font_scaling,
                                         transform_size, transform_e_size, state, root_names, root_dates,
                                         is_mixed=is_mixed)

    # Calculate node coordinates
    min_size = 2 * min(min(_.props.get(NODE_SIZE) for _ in compressed_tree.traverse())
                       for compressed_tree in compressed_forest)
    # width of a subtree: its children's widths plus the gaps between them, but at least the node size
    n2width = {}
    for compressed_tree in compressed_forest:
        for n in compressed_tree.traverse('postorder'):
            n2width[n] = max(n.props.get(NODE_SIZE),
                             sum(n2width[c] for c in n.children) + min_size * (len(n.children) - 1))

    n2x, n2y = {}, {compressed_tree: 0 for compressed_tree in compressed_forest}
    n2offset = {}
    tree_offset = 0
    for compressed_tree in compressed_forest:
        n2offset[compressed_tree] = tree_offset
        tree_offset += n2width[compressed_tree] + 2 * min_size
        for n in compressed_tree.traverse('preorder'):
            n2x[n] = n2offset[n] + n2width[n] / 2
            offset = n2offset[n]
            if not n.is_leaf:
                for c in sorted(n.children, key=lambda c: node2id[c]):
                    n2offset[c] = offset
                    offset += n2width[c] + min_size
                    n2y[c] = n2y[n] + n.props.get(NODE_SIZE) / 2 + c.props.get(NODE_SIZE) / 2 + min_size
        # centre internal nodes above their children
        for n in compressed_tree.traverse('postorder'):
            if not n.is_leaf:
                n2x[n] = np.mean([n2x[c] for c in n.children])

    def filter_by_date(items, date):
        return [_ for _ in items if get_date(_) <= date]

    # Set the cytoscape feature for different timeline points
    # Fixed: milestones defaults to None, on which the original len(milestones) raised a TypeError.
    if milestones is not None and len(milestones) > 1:
        for tree, compressed_tree in zip(forest, compressed_forest):
            nodes = list(compressed_tree.traverse())
            # go from the most recent milestone back in time: the nodes kept at step i
            # are the only candidates for step i - 1
            for i in range(len(milestones) - 1, -1, -1):
                milestone = milestones[i]
                nodes_i = []

                # remove too recent nodes from the original tree
                for n in tree.traverse('postorder'):
                    if n.is_root:
                        continue
                    if get_date(n) > milestone:
                        n.up.remove_child(n)

                suffix = '_{}'.format(i)
                for n in nodes:
                    state = n2state[n]
                    tips_inside, tips_below, internal_nodes_inside, roots = n.props.get(TIPS_INSIDE, []), \
                        n.props.get(TIPS_BELOW, []), \
                        n.props.get(INTERNAL_NODES_INSIDE, []), \
                        n.props.get(ROOTS, [])
                    tips_inside_i, tips_below_i, internal_nodes_inside_i, roots_i = [], [], [], []
                    for ti, tb, ini, root in zip(tips_inside, tips_below, internal_nodes_inside, roots):
                        if get_date(root) <= milestone:
                            roots_i.append(root)

                            ti = filter_by_date(ti, milestone)
                            tb = [_ for _ in tb if _.props.get(DATE) <= milestone]
                            ini = filter_by_date(ini, milestone)

                            tips_inside_i.append(ti + [_ for _ in ini if _.is_leaf])
                            tips_below_i.append(tb)
                            internal_nodes_inside_i.append([_ for _ in ini if not _.is_leaf])
                    n.add_props(**{TIPS_INSIDE: tips_inside_i, TIPS_BELOW: tips_below_i,
                                   INTERNAL_NODES_INSIDE: internal_nodes_inside_i, ROOTS: roots_i})
                    if roots_i:
                        n.add_prop(MILESTONE, i)
                        root_names = [_.props.get(BRANCH_NAME) if _.props.get(DATE) > milestone else _.name for _ in
                                      roots_i]
                        root_dates = [_.props.get(DATE) for _ in roots_i]
                        if dates_are_dates:
                            # best effort: keep the numeric dates if they cannot be converted.
                            # Fixed: a bare except also swallowed KeyboardInterrupt and SystemExit.
                            try:
                                root_dates = [numeric2datetime(_).strftime("%d %b %Y") for _ in root_dates]
                            except Exception:
                                pass
                        set_cyto_features_compressed(n, size_scaling, e_size_scaling, font_scaling, transform_size,
                                                     transform_e_size, state, root_names=root_names, suffix=suffix,
                                                     root_dates=root_dates, is_mixed=is_mixed)
                        nodes_i.append(n)
                nodes = nodes_i

    # Save the structure
    clazzes = set()
    nodes, edges = [], []

    one_column = columns[0] if len(columns) == 1 else None

    for n, n_id in node2id.items():
        if one_column:
            values = n.props.get(one_column, set())
            clazz = tuple(sorted(values))
        else:
            clazz = tuple('{}_{}'.format(column, get_column_value_str(n, column, format_list=False, list_value=''))
                          for column in columns)
        if clazz:
            clazzes.add(clazz)
        nodes.append(get_node(n, n_id, tooltip=get_tooltip(n, columns),
                              clazz=clazz, x=n2x[n], y=n2y[n]))

        for child in sorted(n.children, key=lambda _: node2id[_]):
            edge_attributes = {feature: child.props.get(feature) for feature in child.props
                               if feature.startswith('edge_') or feature == MILESTONE or feature == IS_POLYTOMY}
            source_name = n_id
            edges.append(get_edge(source_name, node2id[child], **edge_attributes))

    json_dict = {NODES: nodes, EDGES: edges}
    return json_dict, sorted(clazzes)

296 

297 

def binary_search(start, end, value, array):
    """
    Searches for the position of value in the sorted array, within the indices [start, end).

    An index i is returned as soon as array[i] equals value, or array[i] is larger than value
    while its predecessor (if i is not the start of the current range) is smaller than value.
    Once the range contains at most one element, the start of the range is returned without
    looking at the array.

    :param start: first index of the search range
    :param end: index after the last one of the search range
    :param value: value to search for
    :param array: sorted indexable collection
    :return: int, the index found
    """
    lo, hi = start, end
    while lo < hi - 1:
        mid = int((lo + hi) / 2)
        pivot = array[mid]
        if pivot == value:
            return mid
        if pivot > value:
            if mid == lo or value > array[mid - 1]:
                return mid
            # continue in the left part
            hi = mid
        else:
            # continue in the right part
            lo = mid + 1
    return lo

307 

308 

def _forest2json(forest, columns, name_feature, get_date, milestones=None, timeline_type=TIMELINE_SAMPLED,
                 dates_are_dates=True):
    """
    Converts a (non-compressed) forest into the nodes/edges structure used by the cytoscape templates.

    Every branch is drawn as two edges joined by an auxiliary fake node, which is placed
    at the child's x and the parent's y coordinate. A fake node is also added above each root.
    Note that the function adds visual properties to the nodes of the forest.

    :param forest: list of trees, whose nodes have the DATE property
    :param columns: list of column (property) names used for classes and tooltips
    :param name_feature: name of the property used as the node label (falsy for no label)
    :param get_date: function that returns the date of a node
    :param milestones: sorted list of timeline points; timeline properties are only set
        if it contains more than one element (None is treated as no timeline)
    :param timeline_type: one of the TIMELINE_* constants
    :param dates_are_dates: whether the dates should be formatted as calendar dates
    :return: tuple (dict with NODES and EDGES lists, sorted list of class tuples)
    """
    min_root_date = min(tree.props.get(DATE) for tree in forest)
    width = sum(len(tree) for tree in forest)
    height_factor = 300 * width / max(
        (max(_.props.get(DATE) for _ in tree) - min_root_date + tree.dist) for tree in forest)
    # vertical shift applied to a child that would otherwise be drawn at its parent's height
    zero_dist = min(min(min(_.dist for _ in tree.traverse() if _.dist > 0) for tree in forest), 300) * height_factor / 2

    # Calculate node coordinates
    n2x, n2y = {}, {}
    x = -600
    for tree in forest:
        x += 600
        for t in tree:
            n2x[t] = x
            x += 600

    for tree in forest:
        for n in tree.traverse('postorder'):
            state = get_column_value_str(n, name_feature, format_list=False, list_value='') if name_feature else ''
            n.add_prop('node_root_id', n.name)
            n.add_prop('node_root_date', get_formatted_date(n, dates_are_dates))
            if not n.is_leaf:
                n2x[n] = np.mean([n2x[_] for _ in n.children])
            n2y[n] = (n.props.get(DATE) - min_root_date) * height_factor
            for c in n.children:
                if n2y[c] == n2y[n]:
                    n2y[c] += zero_dist
            set_cyto_features_tree(n, state)

    # Save cytoscape features at different timeline points
    # Fixed: milestones defaults to None, on which the original len(milestones) raised a TypeError.
    if milestones is not None and len(milestones) > 1:
        for tree in forest:
            for n in tree.traverse('preorder'):
                # the search starts from the milestone index of the parent (set earlier in this preorder traversal)
                ms_i = binary_search(0 if n.is_root else n.up.props.get(MILESTONE),
                                     len(milestones), get_date(n), milestones)
                n.add_prop(MILESTONE, ms_i)
                for i in range(len(milestones) - 1, ms_i - 1, -1):
                    milestone = milestones[i]
                    suffix = '_{}'.format(i)
                    if TIMELINE_LTT == timeline_type:
                        # if it is LTT also cut the branches if needed
                        if n.props.get(DATE) > milestone:
                            n.add_prop('{}{}'.format(EDGE_NAME, suffix),
                                       np.round(milestone - n.up.props.get(DATE), 3))
                        else:
                            n.add_prop('{}{}'.format(EDGE_NAME, suffix), np.round(n.dist, 3))

    # Save the structure
    clazzes = set()
    nodes, edges = [], []

    one_column = columns[0] if len(columns) == 1 else None

    i = 0
    node2id = {}
    for tree in forest:
        for n in tree.traverse():
            node2id[n] = i
            i += 1

    for n, n_id in node2id.items():
        if n.is_root:
            # the root branch starts at a fake node placed above the root
            fake_id = 'fake_node_{}'.format(n_id)
            nodes.append(get_fake_node(fake_id, n2x[n], n2y[n] - n.dist * height_factor))
            edges.append(get_edge(fake_id, n_id, **{feature: n.props.get(feature) for feature in n.props
                                                    if feature.startswith('edge_') or feature == MILESTONE}))
        if one_column:
            values = n.props.get(one_column, set())
            clazz = tuple(sorted(values))
        else:
            clazz = tuple('{}_{}'.format(column, get_column_value_str(n, column, format_list=False, list_value=''))
                          for column in columns)
        if clazz:
            clazzes.add(clazz)
        nodes.append(get_node(n, n_id, tooltip=get_tooltip(n, columns), clazz=clazz, x=n2x[n], y=n2y[n]))

        for child in n.children:
            edge_attributes = {feature: child.props.get(feature) for feature in child.props
                               if feature.startswith('edge_') or feature == MILESTONE or feature == IS_POLYTOMY}
            source_name = n_id
            target_name = 'fake_node_{}'.format(node2id[child])
            nodes.append(get_fake_node(target_name, x=n2x[child], y=n2y[n]))
            # the first (fake) edge carries neither the edge names nor the polytomy flag
            edges.append(get_edge(source_name, target_name, fake=1,
                                  **{k: v for (k, v) in edge_attributes.items()
                                     if EDGE_NAME not in k and IS_POLYTOMY not in k}))
            source_name = target_name
            edges.append(get_edge(source_name, node2id[child], **edge_attributes))

    json_dict = {NODES: nodes, EDGES: edges}
    return json_dict, sorted(clazzes)

400 

401 

def _forest2json_transitions(states, counts, transitions, state2colour, threshold=0):
    """
    Converts state counts and transition counts into the nodes/edges structure
    used by the cytoscape transition template.

    A node is created per state that is involved in at least one drawn transition,
    and an edge per positive transition count between states with positive counts.

    :param states: sequence of states, indexed in the same way as counts and as the rows/columns of transitions
    :param counts: sequence of per-state counts (shown in tooltips as numbers of samples)
    :param transitions: 2D numpy array, transitions[i, j] is shown as the number of transitions
        from states[i] to states[j]
    :param state2colour: mapping of states to colours
    :param threshold: lower bound for the milestone values
    :return: tuple (dict with NODES and EDGES lists, list of milestone values formatted with '{:g}')
    """
    nodes, edges = [], []
    n = len(states)

    # Node and font sizes grow linearly with the state count, edge sizes with the transition count
    n_scaler = get_scaling_function(y_m=200, y_M=800, x_m=min(counts), x_M=max(counts))
    font_scaler = get_scaling_function(y_m=MIN_FONT_SIZE, y_M=MIN_FONT_SIZE * 3, x_m=min(counts), x_M=max(counts))
    positive_transitions = transitions[transitions > 0]
    e_scaler = get_scaling_function(y_m=30, y_M=200, x_m=min(positive_transitions), x_M=max(positive_transitions))

    # Sum the transitions in both directions for each pair of different states
    # (the diagonal of the transposed copy is zeroed, so that self-transitions are counted once)
    transtions_to_from = np.transpose(transitions.copy())
    np.fill_diagonal(transtions_to_from, 0)
    nums = np.triu(transitions + transtions_to_from).flatten()
    positive_nums = nums[nums > 0]

    # Milestones are the candidate values between threshold (inclusive) and the maximal number (exclusive)
    max_transition_num = max(positive_nums)
    if max_transition_num <= 2:
        miles = sorted((set({0, threshold, 1 if max_transition_num > 1 else 0}
                            | set(np.round(positive_nums, 3))) - {np.round(max_transition_num, 3)}))
    else:
        miles = sorted(set({0, threshold, 1} | set(np.trunc(positive_nums))))
    miles = np.array([_ for _ in miles if threshold <= _ < max_transition_num])
    if not len(miles):
        miles = [0]

    # we hide connections when they are < mile
    def get_mile(start, end, value):
        # Binary search for the index of the last mile that is <= value within [start, end), None if there is none
        if start == end:
            return None
        i = int((start + end) / 2)
        if miles[i] <= value:
            if i == end - 1 or value < miles[i + 1]:
                return i
            return get_mile(i + 1, end, value)
        return get_mile(start, i, value)

    # state index -> largest milestone index among its drawn edges (-1 if there are none)
    i2mile = defaultdict(lambda: -1)
    for i in range(n):
        from_state = states[i]
        n_tips = counts[i]
        i_node_size = n_scaler(n_tips)
        if n_tips > 0:
            # each unordered pair (i, j) is visited once, both directions are handled below
            for j in range(i, n):
                if counts[j] > 0:
                    to_state = states[j]
                    n_ij = transitions[i, j]
                    n_ji = transitions[j, i]
                    node_size = (i_node_size + n_scaler(counts[j])) / 2
                    if n_ij > 0:
                        mile = get_mile(0, len(miles), n_ij)
                        if mile is not None:
                            i2mile[i] = max(mile, i2mile[i])
                            i2mile[j] = max(mile, i2mile[j])
                            edges.append(get_edge(i, j,
                                                  **{ID: '{}_{}'.format(i, j),
                                                     EDGE_SIZE: e_scaler(n_ij),
                                                     NODE_SIZE: node_size / (2 if i != j else 1),
                                                     EDGE_NAME: n_ij,
                                                     TOOLTIP: '{} transitions from {} to {}'
                                                     .format(n_ij, from_state, to_state),
                                                     MILESTONE: mile}))
                    # the reverse direction (skipped for self-transitions, which are already handled above)
                    if n_ji > 0 and i != j:
                        mile = get_mile(0, len(miles), n_ji)
                        if mile is not None:
                            i2mile[i] = max(mile, i2mile[i])
                            i2mile[j] = max(mile, i2mile[j])
                            edges.append(get_edge(j, i,
                                                  **{ID: '{}_{}'.format(j, i),
                                                     EDGE_SIZE: e_scaler(n_ji),
                                                     NODE_SIZE: node_size / 2,
                                                     EDGE_NAME: n_ji,
                                                     TOOLTIP: '{} transitions from {} to {}'
                                                     .format(n_ji, to_state, from_state),
                                                     MILESTONE: mile}))
            # only the states that have at least one drawn edge get a node
            if i2mile[i] >= 0:
                nodes.append(_get_node(data={ID: i, NODE_NAME: '{} ({:.0f})'.format(from_state, n_tips),
                                             NODE_SIZE: i_node_size,
                                             FONT_SIZE: font_scaler(n_tips),
                                             TOOLTIP: '{} is represented by {:.0f} samples.'.format(from_state, n_tips),
                                             COLOUR: state2colour[from_state],
                                             MILESTONE: i2mile[i]}))
    json_dict = {NODES: nodes, EDGES: edges}
    return json_dict, ['{:g}'.format(_) for _ in miles]

484 

485 

def get_size_transformations(forest):
    """
    Prepares the functions that convert tip counts and edge sizes of compressed nodes into visual sizes.

    Node sizes are based on the NUM_TIPS_INSIDE property (at least 1), edge sizes on the length
    of the ROOTS property. If the largest value is more than 100 times larger than the smallest one,
    a logarithmic transformation is applied before the linear scaling.

    :param forest: iterable of trees, whose nodes have the NUM_TIPS_INSIDE and ROOTS properties
    :return: tuple (e_size_scaling, font_scaling, size_scaling, transform_e_size, transform_size)
    """
    tip_counts, edge_counts = [], []
    for tree in forest:
        for node in tree.traverse():
            tip_counts.append(max(node.props.get(NUM_TIPS_INSIDE), 1))
            edge_counts.append(len(node.props.get(ROOTS)))

    # the leading elements play the role of the initial values of the running max/min
    max_size = max([1] + tip_counts)
    min_size = min([np.inf] + tip_counts)
    max_e_size = max([1] + edge_counts)
    min_e_size = min([np.inf] + edge_counts)

    need_log = max_size / min_size > 100

    def transform_size(_):
        return np.power(np.log10(_ + 9) if need_log else _, 1 / 2)

    need_e_log = max_e_size / min_e_size > 100

    def transform_e_size(_):
        return np.log10(_) if need_e_log else _

    size_ratio = int(max_size / min_size)
    x_min, x_max = transform_size(min_size), transform_size(max_size)
    size_scaling = get_scaling_function(y_m=MIN_NODE_SIZE, y_M=MIN_NODE_SIZE * min(8, size_ratio),
                                        x_m=x_min, x_M=x_max)
    font_scaling = get_scaling_function(y_m=MIN_FONT_SIZE, y_M=MIN_FONT_SIZE * min(3, size_ratio),
                                        x_m=x_min, x_M=x_max)

    e_size_ratio = int(max_e_size / min_e_size)
    e_size_scaling = get_scaling_function(y_m=MIN_EDGE_SIZE, y_M=MIN_EDGE_SIZE * min(3, e_size_ratio),
                                          x_m=transform_e_size(min_e_size), x_M=transform_e_size(max_e_size))

    return e_size_scaling, font_scaling, size_scaling, transform_e_size, transform_size

511 

512 

def get_tooltip(n, columns):
    """
    Builds the html tooltip of a node: one 'column: value' entry per column, separated by <br>.

    :param n: tree node
    :param columns: iterable of column (property) names
    :return: str
    """
    entries = []
    for column in columns:
        value = get_column_value_str(n, column, format_list=True)
        entries.append('{}: {}'.format(column, value))
    return '<br>'.join(entries)

516 

517 

def _get_js_css_lists(work_dir, local_css_js):
    """
    Returns the lists of javascript and css resources to be referenced by the generated page.

    If local_css_js is set, the js, css and font files found in the templates folder next to this package
    are copied to the js, css and fonts subfolders of work_dir (which are created if needed),
    and the paths of the copied .js and .css files are returned. Otherwise, JS_LIST and CSS_LIST are returned.

    :param work_dir: path of the folder where the local copies are to be put
    :param local_css_js: whether local copies are to be used
    :return: tuple (js_list, css_list)
    """
    if not local_css_js:
        return JS_LIST, CSS_LIST

    for sub_dir in ('js', 'css', 'fonts'):
        os.makedirs(os.path.join(work_dir, sub_dir), exist_ok=True)

    template_dir = os.path.join(os.path.abspath(os.path.split(__file__)[0]), '..', 'templates')

    js_list = []
    # the patterns also match e.g. map files: they are copied as well but not listed
    for src in sorted(glob(os.path.join(template_dir, 'js', '*.js*'))):
        cp = os.path.join(work_dir, 'js', os.path.split(src)[1])
        copyfile(src, cp)
        if cp.endswith('.js'):
            js_list.append(cp)
    css_list = []
    for src in glob(os.path.join(template_dir, 'css', '*.css*')):
        cp = os.path.join(work_dir, 'css', os.path.split(src)[1])
        copyfile(src, cp)
        if cp.endswith('.css'):
            css_list.append(cp)
    for src in glob(os.path.join(template_dir, 'fonts', '*.*')):
        cp = os.path.join(work_dir, 'fonts', os.path.split(src)[1])
        copyfile(src, cp)
    return js_list, css_list


def _write_html_page(env, out_html, graph, graph_name, slider, work_dir, local_css_js):
    """
    Renders the index.html template with the given graph and slider and writes it to out_html.

    :param env: jinja2 Environment used to load the template
    :param out_html: path of the resulting html file (its folder is created if needed)
    :param graph: rendered graph script
    :param graph_name: title of the page
    :param slider: rendered slider html ('' for no slider)
    :param work_dir: path of the folder for the local js/css copies
    :param local_css_js: whether local copies of js/css are to be used
    """
    template = env.get_template('index.html')
    os.makedirs(os.path.abspath(os.path.dirname(out_html)), exist_ok=True)

    js_list, css_list = _get_js_css_lists(work_dir, local_css_js)
    page = template.render(graph=graph, title=graph_name, slider=slider, js_list=js_list, css_list=css_list)

    with open(out_html, 'w+') as fp:
        fp.write(page)


def save_as_cytoscape_html(forest, out_html, column2states, name_feature, name2colour, compressed_forest,
                           milestone_label, timeline_type, milestones, get_date, work_dir, local_css_js=False,
                           milestone_labels=None, is_mixed=False):
    """
    Converts a forest to an html representation using Cytoscape.js.

    If categories are specified they are visualised as pie-charts inside the nodes,
    given that each node contains features corresponding to these categories with values being the percentage.
    For instance, given categories ['A', 'B', 'C'], a node with features {'A': 50, 'B': 50}
    will have a half-half pie-chart (half-colored in a colour of A, and half B).

    :param forest: list of trees (their nodes are accessed via .props, i.e. presumably ete4 trees)
    :param out_html: path where to save the resulting html file.
    :param column2states: dict, whose keys are the names of the columns to be visualised
    :param name_feature: str, a node feature whose value will be used as a label
    :param name2colour: dict, str to str, category name to HEX colour mapping
    :param compressed_forest: list of compressed trees; if None, the full forest is visualised
    :param milestone_label: label of the timeline; DATE_LABEL means that the dates are calendar dates
    :param timeline_type: one of the TIMELINE_* constants
    :param milestones: sorted list of timeline points
    :param get_date: function that returns the date of a node
    :param work_dir: path of the folder for the local js/css copies
    :param local_css_js: whether local copies of js/css are to be used instead of the remote ones
    :param milestone_labels: labels of the timeline points (by default, milestones formatted with '{:g}')
    :param is_mixed: passed to the compressed forest conversion
    """
    graph_name = os.path.splitext(os.path.basename(out_html))[0]
    columns = sorted(column2states.keys())
    if milestone_labels is None:
        milestone_labels = ['{:g}'.format(_) for _ in milestones]

    is_compressed = compressed_forest is not None
    if is_compressed:
        json_dict, clazzes \
            = _forest2json_compressed(forest, compressed_forest, columns, name_feature=name_feature, get_date=get_date,
                                      milestones=milestones, dates_are_dates=milestone_label == DATE_LABEL,
                                      is_mixed=is_mixed)
    else:
        json_dict, clazzes \
            = _forest2json(forest, columns, name_feature=name_feature, get_date=get_date, milestones=milestones,
                           timeline_type=timeline_type, dates_are_dates=milestone_label == DATE_LABEL)
    loader = PackageLoader('pastml')
    env = Environment(loader=loader)
    template = env.get_template('pie_tree.js' if is_compressed else 'pie_tree_simple.js')

    # One css snippet per combination of categories: the pie is split equally among them
    clazz2css = {}
    for clazz_list in clazzes:
        n = len(clazz_list)
        css_chunks = []
        for i, cat in enumerate(clazz_list, start=1):
            # "\\%" emits a backslash followed by %, exactly as the original "\%" did,
            # but without relying on an invalid escape sequence.
            css_chunks.append("""
                'pie-{i}-background-color': "{colour}",
                'pie-{i}-background-size': '{percent}\\%',
        """.format(i=i, percent=round(100 / n, 2), colour=name2colour[cat]))
        clazz2css[_clazz_list2css_class(clazz_list)] = ''.join(css_chunks)

    if TIMELINE_SAMPLED == timeline_type:
        tips = 'samples'
    elif TIMELINE_LTT == timeline_type:
        tips = 'lineages ending'
    else:
        tips = 'external nodes'
    internal_nodes = 'internal nodes' if TIMELINE_NODES == timeline_type else 'diversification events'
    graph = template.render(clazz2css=clazz2css.items(), elements=json_dict, title=graph_name,
                            years=milestone_labels, tips=tips, internal_nodes=internal_nodes,
                            age_label=milestone_label)

    # the slider is only needed when there is more than one timeline point
    if len(milestones) > 1:
        slider = env.get_template('time_slider.html').render(min_date=0, max_date=len(milestones) - 1,
                                                             cur_date=len(milestones) - 1,
                                                             name=milestone_label)
    else:
        slider = ''

    _write_html_page(env, out_html, graph, graph_name, slider, work_dir, local_css_js)


def save_as_transition_html(character, states, counts, transitions, out_html, state2colour, work_dir,
                            local_css_js=False, threshold=0):
    """
    Converts transition count data to an html representation using Cytoscape.js.

    Note that transitions is modified in place: the values below the threshold are set to zero.

    :param character: passed to the transitions.js template as character
    :param states: sequence of states, indexed in the same way as counts and as the rows/columns of transitions
    :param counts: sequence of per-state counts
    :param transitions: 2D numpy array of transition counts
    :param out_html: path where to save the resulting html file.
    :param state2colour: mapping of states to colours
    :param work_dir: path of the folder for the local js/css copies
    :param local_css_js: whether local copies of js/css are to be used instead of the remote ones
    :param threshold: transition counts below this value are discarded
    """
    graph_name = os.path.splitext(os.path.basename(out_html))[0]
    transitions[transitions < threshold] = 0
    json_dict, thresholds = _forest2json_transitions(states, counts, transitions, state2colour, threshold=threshold)

    loader = PackageLoader('pastml')
    env = Environment(loader=loader)
    template = env.get_template('transitions.js')

    graph = template.render(elements=json_dict, character=character,
                            years=thresholds)
    # the slider is only needed when there is more than one threshold to choose from
    if len(thresholds) > 1:
        slider = env.get_template('time_slider.html').render(min_date=0, max_date=len(thresholds) - 1,
                                                             name='transition number threshold', cur_date=0)
    else:
        slider = ''

    _write_html_page(env, out_html, graph, graph_name, slider, work_dir, local_css_js)

659 

660 

661def _clazz_list2css_class(clazz_list): 

662 if not clazz_list: 

663 return None 

664 return ''.join(c for c in '-'.join(clazz_list) if c.isalnum() or '-' == c) 

665 

666 

def _get_node(data, clazz=None, position=None):
    """
    Wraps node data into an element dict.

    Note that data is modified in place when a position is given.

    :param data: dict with the node data
    :param clazz: value of the 'classes' entry of the element (skipped if falsy)
    :param position: pair (x, y) stored in data as 'node_x' and 'node_y' (skipped if falsy)
    :return: dict, the element
    """
    element = {DATA: data}
    if position:
        data['node_x'] = position[0]
        data['node_y'] = position[1]
    if clazz:
        element['classes'] = clazz
    return element

675 

676 

def _get_edge(**data):
    """
    Wraps the keyword arguments into an edge element dict.

    :param data: entries of the edge data
    :return: dict, the element
    """
    element = {DATA: data}
    return element

679 

680 

def get_column_value_str(n, column, format_list=True, list_value=''):
    """
    Represents the value of the node's property as a string.

    :param n: node exposing .props
    :param column: name of the property (a missing property is treated as an empty set)
    :param format_list: whether several values should be listed (sorted and joined with ' or ')
    :param list_value: str to be returned when the number of values is not one and format_list is False
    :return: str; a str property value is returned as is
    """
    values = n.props.get(column, set())
    if isinstance(values, str):
        return values
    if not format_list and len(values) != 1:
        return list_value
    return ' or '.join(sorted(values))

686 

687 

688def visualize(forest, column2states, work_dir, name_column=None, html=None, html_compressed=None, html_mixed=None, 

689 tip_size_threshold=REASONABLE_NUMBER_OF_TIPS, date_label='Dist. to root', timeline_type=TIMELINE_SAMPLED, 

690 local_css_js=False, column2colours=None, focus=None, pajek=None, pajek_timing=VERTICAL): 

691 for tree in forest: 

692 nodes_in_focus = set() 

693 for node in tree.traverse(): 

694 for column in column2states.keys(): 

695 col_state = node.props.get(column, set()) 

696 if focus and col_state & focus[column]: 

697 nodes_in_focus.add(node) 

698 for node in nodes_in_focus: 

699 node.add_prop(IN_FOCUS, True) 

700 if not node.is_root and node.up not in nodes_in_focus: 

701 node.up.add_prop(UP_FOCUS, True) 

702 for c in node.children: 

703 if c not in nodes_in_focus: 

704 c.add_prop(AROUND_FOCUS, True) 

705 

706 one_column = next(iter(column2states.keys())) if len(column2states) == 1 else None 

707 

708 name2colour = {} 

709 for column, states in column2states.items(): 

710 num_unique_values = len(states) 

711 colours = None 

712 if column2colours and column in column2colours: 

713 try: 

714 colours = parse_colours(column2colours[column], states) 

715 except ValueError as e: 

716 logging.getLogger('pastml').error('Failed to parse the input colours: {}'.format(e)) 

717 if colours is None: 

718 colours = get_enough_colours(num_unique_values) 

719 for value, col in zip(states, colours): 

720 name2colour[value if one_column else '{}_{}'.format(column, value)] = col 

721 state2color = dict(zip(states, colours)) 

722 # let ambiguous values be white 

723 if one_column is None: 

724 name2colour['{}_'.format(column)] = WHITE 

725 if column == name_column: 

726 for tree in forest: 

727 for n in tree.traverse(): 

728 sts = n.props.get(column, set()) 

729 if len(sts) == 1 and not n.is_root and n.up.props.get(column, set()) == sts: 

730 n.add_prop('edge_color', state2color[next(iter(sts))]) 

731 out_colour_file = os.path.join(work_dir, get_pastml_colour_file(column)) 

732 # Not using DataFrames to speed up document writing 

733 with open(out_colour_file, 'w+') as f: 

734 f.write('state\tcolour\n') 

735 for s in sorted(states): 

736 f.write('{}\t{}\n'.format(s, state2color[s])) 

737 logging.getLogger('pastml').debug('Mapped states to colours for {} as following: {} -> {}, ' 

738 'and serialized this mapping to {}.' 

739 .format(column, states, colours, out_colour_file)) 

740 for tree in forest: 

741 for node in tree.traverse(): 

742 if node.is_leaf: 

743 node.add_prop(IS_TIP, True) 

744 node.add_prop(BRANCH_NAME, '{}-{}'.format(node.up.name if not node.is_root else '', node.name)) 

745 for column in column2states.keys(): 

746 col_state = node.props.get(column, set()) 

747 if len(col_state) != 1: 

748 node.add_prop(UNRESOLVED, 1) 

749 break 

750 

751 if TIMELINE_NODES == timeline_type: 

752 def get_date(node): 

753 return node.props.get(DATE) 

754 elif TIMELINE_SAMPLED == timeline_type: 

755 max_date = max(max(_.props.get(DATE) for _ in tree) for tree in forest) 

756 

757 def get_date(node): 

758 tips = [_ for _ in node if _.props.get(IS_TIP, False)] 

759 return min(_.props.get(DATE) for _ in tips) if tips else max_date 

760 elif TIMELINE_LTT == timeline_type: 

761 def get_date(node): 

762 return node.props.get(DATE) if node.is_root else (node.up.props.get(DATE) + 1e-6) 

763 else: 

764 raise ValueError('Unknown timeline type: {}. Allowed ones are {}, {} and {}.' 

765 .format(timeline_type, TIMELINE_NODES, TIMELINE_SAMPLED, TIMELINE_LTT)) 

766 dates = [] 

767 for tree in forest: 

768 dates.extend([_.props.get(DATE) for _ in (tree.traverse() 

769 if timeline_type in [TIMELINE_LTT, TIMELINE_NODES] else tree)]) 

770 dates = sorted(dates) 

771 milestones = sorted({dates[0], dates[len(dates) // 8], dates[len(dates) // 4], dates[3 * len(dates) // 8], 

772 dates[len(dates) // 2], dates[5 * len(dates) // 8], dates[3 * len(dates) // 4], 

773 dates[7 * len(dates) // 8], dates[-1]}) 

774 milestone_labels = None 

775 if DATE_LABEL == date_label: 

776 try: 

777 milestone_labels = [numeric2datetime(_).strftime("%d %b %Y") for _ in milestones] 

778 except: 

779 pass 

780 if milestone_labels is None: 

781 milestone_labels = ['{:g}'.format(_) for _ in milestones] 

782 

783 if html: 

784 total_num_tips = sum(len(tree) for tree in forest) 

785 if total_num_tips > MAX_TIPS_FOR_FULL_TREE_VISUALISATION: 

786 logging.getLogger('pastml').error('The full tree{} will not be visualised as {} too large ({} tips): ' 

787 'the limit is {} tips. Check out upload to iTOL option instead.' 

788 .format('s' if len(forest) > 1 else '', 

789 'they are' if len(forest) > 1 else 'it is', 

790 total_num_tips, MAX_TIPS_FOR_FULL_TREE_VISUALISATION)) 

791 else: 

792 save_as_cytoscape_html(forest, html, column2states=column2states, name2colour=name2colour, 

793 name_feature='name', compressed_forest=None, milestone_label=date_label, 

794 timeline_type=timeline_type, milestones=milestones, get_date=get_date, 

795 work_dir=work_dir, local_css_js=local_css_js, milestone_labels=milestone_labels) 

796 if html_compressed and html_mixed: 

797 forest_mixed = copy_forest(forest) 

798 else: 

799 forest_mixed = forest 

800 

801 if html_compressed: 

802 pajek_vert_arcs = [[], []] if pajek else None 

803 compressed_forest = [compress_tree(tree, columns=column2states.keys(), tip_size_threshold=tip_size_threshold, 

804 mixed=False, pajek=pajek_vert_arcs, pajek_timing=pajek_timing) 

805 for tree in forest] 

806 if pajek: 

807 save_to_pajek(*pajek_vert_arcs, pajek) 

808 

809 milestone_labels, milestones = update_milestones(forest, date_label, milestone_labels, milestones, 

810 timeline_type) 

811 

812 save_as_cytoscape_html(forest, html_compressed, 

813 column2states=column2states, name2colour=name2colour, 

814 name_feature=name_column, compressed_forest=compressed_forest, 

815 milestone_label=date_label, timeline_type=timeline_type, 

816 milestones=milestones, get_date=get_date, work_dir=work_dir, local_css_js=local_css_js, 

817 milestone_labels=milestone_labels, is_mixed=False) 

818 

819 if html_mixed: 

820 mixed_forest = [compress_tree(tree, columns=column2states.keys(), tip_size_threshold=tip_size_threshold, 

821 mixed=True) for tree in forest_mixed] 

822 milestone_labels, milestones = update_milestones(forest_mixed, date_label, milestone_labels, milestones, 

823 timeline_type) 

824 save_as_cytoscape_html(forest_mixed, html_mixed, 

825 column2states=column2states, name2colour=name2colour, 

826 name_feature=name_column, compressed_forest=mixed_forest, 

827 milestone_label=date_label, timeline_type=timeline_type, 

828 milestones=milestones, get_date=get_date, work_dir=work_dir, local_css_js=local_css_js, 

829 milestone_labels=milestone_labels, is_mixed=True) 

830 

831 

def update_milestones(forest, date_label, milestone_labels, milestones, timeline_type):
    """
    Adjusts the timeline milestones (and their labels) to the date range actually present in the forest.

    If we trimmed a few tips while compressing and they happened to be the oldest/newest ones,
    the milestones must be updated accordingly: milestones outside of the [first date, last date] range
    of the forest are dropped, and the first and last dates are added as boundary milestones if missing.

    :param forest: iterable of trees whose nodes store their date in the DATE property
    :param date_label: label type; if it equals DATE_LABEL the labels are formatted as calendar dates
        ("%d %b %Y"), falling back to the numeric '{:g}' format if that conversion fails
    :param milestone_labels: labels matching the input milestones (or None); they are reused only
        if the milestones did not change and calendar-date labels were not produced
    :param milestones: milestone values, assumed to be sorted in ascending order
        (the input list itself is not modified)
    :param timeline_type: for TIMELINE_LTT and TIMELINE_NODES the dates of all the traversed nodes are considered,
        for any other type only the nodes obtained by iterating over a tree directly
        (presumably its tips -- to be confirmed against the tree API)
    :return: tuple (milestone_labels, milestones) with one label per milestone
    """
    # The node selection strategy does not depend on the tree, so decide it once.
    use_all_nodes = timeline_type in [TIMELINE_LTT, TIMELINE_NODES]

    first_date, last_date = np.inf, -np.inf
    for tree in forest:
        for node in (tree.traverse() if use_all_nodes else tree):
            date = node.props.get(DATE)
            first_date = min(first_date, date)
            last_date = max(last_date, date)

    updated_milestones = [ms for ms in milestones if first_date <= ms <= last_date]
    # first_date > last_date only happens when no node was visited (e.g. empty forest):
    # there are no boundaries to add then.
    if first_date <= last_date:
        # The filtering above might have removed every milestone, hence the emptiness check
        # before looking at the boundary elements.
        if not updated_milestones or updated_milestones[0] > first_date:
            updated_milestones.insert(0, first_date)
        if updated_milestones[-1] < last_date:
            updated_milestones.append(last_date)

    updated_labels = None
    if DATE_LABEL == date_label:
        try:
            updated_labels = [numeric2datetime(_).strftime("%d %b %Y") for _ in updated_milestones]
        except Exception:
            # Best effort: dates that cannot be represented as calendar dates
            # are shown as plain numbers instead.
            updated_labels = None
    if updated_labels is None:
        if milestone_labels is not None and updated_milestones == list(milestones):
            # Nothing changed, so the labels that were given to us are still valid.
            updated_labels = milestone_labels
        else:
            # The given labels (if any) describe the old milestones: regenerate them
            # to keep labels and milestones in sync.
            updated_labels = ['{:g}'.format(_) for _ in updated_milestones]
    return updated_labels, updated_milestones