Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/visualisation/tree_compressor.py: 13%

223 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-21 09:19 +0100

1import logging 

2from collections import defaultdict 

3from pastml.tree import IS_POLYTOMY, copy_forest 

4 

5import numpy as np 

6 

# Values accepted by compress_tree(pajek_timing=...): they select the compression stage
# after which the Pajek vertices/arcs are recorded.
VERTICAL = 'VERTICAL'
HORIZONTAL = 'HORIZONTAL'
TRIM = 'TRIM'

# Node property checked by remove_small_tips on the leaves of the full tree:
# leaves without it are treated as non-sampled tips and get pruned.
IS_TIP = 'is_tip'

# Default value of compress_tree(tip_size_threshold=...).
REASONABLE_NUMBER_OF_TIPS = 15

# Not referenced in the visible part of this file.
CATEGORIES = 'categories'

# Node property: number of tips compressed inside a node
# (becomes the average over the merged configurations after a horizontal merge).
NUM_TIPS_INSIDE = 'max_size'

# Node properties filled in by compress_tree for each node of the compressed tree:
# full-tree tips / internal nodes represented by the node, leaves below it,
# and the full-tree nodes it was created from.
TIPS_INSIDE = 'in_tips'
INTERNAL_NODES_INSIDE = 'in_ns'
TIPS_BELOW = 'all_tips'
ROOTS = 'roots'
# Not referenced in the visible part of this file.
ROOT_DATES = 'root_dates'

# Property set on full-tree nodes: the compressed-tree node that represents them.
COMPRESSED_NODE = 'compressed_node'

# Property set (to True) on a compressed node that absorbed its equivalent siblings
# during a horizontal merge.
METACHILD = 'metachild'

# Focus-related node properties: nodes carrying any of them are never trimmed,
# and (in mixed mode) they restrict vertical/horizontal merging.
IN_FOCUS = 'in_focus'
AROUND_FOCUS = 'around_focus'
UP_FOCUS = 'up_focus'

32 

33 

def _tree2pajek_vertices_arcs(compressed_tree, nodes, edges, columns):
    """
    Appends the Pajek vertex and arc lines describing a compressed tree to the given lists.

    Vertex ids continue from the current length of nodes (the first new id is len(nodes) + 1),
    so several trees can be accumulated into the same pair of lists.

    :param compressed_tree: compressed tree whose nodes carry the ROOTS and TIPS_INSIDE properties
    :param nodes: list of vertex lines, extended in place (one line per tree node, in preorder)
    :param edges: list of arc lines, extended in place (one "<parent_id> <child_id> <weight>" line
        per non-root node, the weight being the number of the node's ROOTS)
    :param columns: iterable of column names whose states are reported for each vertex
    :return: void, modifies nodes and edges
    """

    def _state_labels(node):
        # One "<column>:<state(s)>" label per column; several states are joined with " or ".
        labels = []
        for column in columns:
            states = node.props.get(column, set())
            if not isinstance(states, str):
                states = ' or '.join(sorted(states))
            labels.append('{}:{}'.format(column, states))
        return labels

    node2id = {}
    first_id = len(nodes) + 1
    for vertex_id, node in enumerate(compressed_tree.traverse('preorder'), start=first_id):
        node2id[node] = vertex_id
        if not node.is_root:
            # preorder guarantees that the parent already has an id
            weight = len(node.props.get(ROOTS))
            edges.append('{} {} {}'.format(node2id[node.up], vertex_id, weight))
        # tips of one configuration are comma-separated, configurations are semicolon-separated
        tips_inside = ';'.join(','.join(tip.name for tip in tip_set)
                               for tip_set in node.props.get(TIPS_INSIDE))
        quoted_states = ' '.join('"{}"'.format(label) for label in _state_labels(node))
        nodes.append('{} "{}" "{}" {}'.format(vertex_id, node.name, tips_inside, quoted_states))

52 

53 

def save_to_pajek(nodes, edges, pajek):
    """
    Saves a compressed tree into Pajek format:

    *vertices <number_of_vertices>
    <id_1> "<vertex_name>" "<tips_inside>" "<column1>:<state(s)>" ["<column2>:<state(s)>" ...]
    ...
    *arcs
    <source_id> <target_id> <weight>
    ...

    <tips_inside> list tips that were vertically compressed inside this node: they are comma-separated.
    If the node was also horizontally merged (i.e. represents several similar configurations),
    the tip sets corresponding to different configurations are semicolon-separated.

    <state(s)> lists the states predicted for the corresponding column:
    if there are several states, they are separated with " or ".

    :param nodes: list of already formatted vertex lines (without line breaks)
    :param edges: list of already formatted arc lines (without line breaks)
    :param pajek: path to the output file; an existing file is overwritten
    :return: void (creates a file specified in the pajek argument)
    """
    # The file is only written, so plain 'w' is enough; the encoding is fixed to UTF-8
    # so that non-ASCII vertex/tip names do not depend on the platform's default encoding.
    with open(pajek, 'w', encoding='utf-8') as f:
        f.write('*vertices {}\n'.format(len(nodes)))
        f.write('\n'.join(nodes))
        f.write('\n*arcs\n')
        f.write('\n'.join(edges))

81 

82 

def compress_tree(tree, columns, can_merge_diff_sizes=True, tip_size_threshold=REASONABLE_NUMBER_OF_TIPS, mixed=False,
                  pajek=None, pajek_timing=VERTICAL):
    """
    Builds a compressed copy of the tree: collapses it vertically, then horizontally,
    and finally trims it if it still has more than tip_size_threshold tips.

    :param tree: full tree; its nodes get the COMPRESSED_NODE property,
        and it is pruned alongside the compressed tree if trimming happens
    :param columns: set of column names (it is united with set(tree.props) when copying the tree)
    :param can_merge_diff_sizes: if True and the compressed tree has more than tip_size_threshold tips,
        a second horizontal merge is performed with node sizes binned by order of magnitude (log10)
    :param tip_size_threshold: number of tips above which the compressed tree gets trimmed
    :param mixed: passed on to collapse_vertically and collapse_horizontally
    :param pajek: None, or the (nodes, edges) arguments to be passed to _tree2pajek_vertices_arcs
    :param pajek_timing: stage after which the Pajek lines are recorded: VERTICAL, HORIZONTAL or TRIM
    :return: the compressed tree
    """
    logger = logging.getLogger('pastml')
    compressed_tree = copy_forest([tree], features=columns | set(tree.props))[0]

    def _export_to_pajek(stage):
        # Record the current state of the compressed tree if this is the requested stage.
        if pajek is not None and stage == pajek_timing:
            _tree2pajek_vertices_arcs(compressed_tree, *pajek, columns=sorted(columns))

    def _exact_size(size):
        return size

    def _size_magnitude(size):
        return int(np.log10(max(1, size)))

    # The copy has the same topology, hence the two postorder traversals go in lockstep.
    for node, full_node in zip(compressed_tree.traverse('postorder'), tree.traverse('postorder')):
        if node.is_leaf:
            tips_inside, internal_nodes_inside = [full_node], []
        elif not node.props.get(IS_POLYTOMY, False):
            tips_inside, internal_nodes_inside = [], [full_node]
        else:
            tips_inside, internal_nodes_inside = [], []
        node.add_prop(TIPS_BELOW, [list(node.leaves())])
        node.add_prop(TIPS_INSIDE, tips_inside)
        node.add_prop(INTERNAL_NODES_INSIDE, internal_nodes_inside)
        node.add_prop(ROOTS, [full_node])
        full_node.add_prop(COMPRESSED_NODE, node)

    collapse_vertically(compressed_tree, columns, mixed=mixed)
    _export_to_pajek(VERTICAL)

    # From now on TIPS_INSIDE and INTERNAL_NODES_INSIDE are lists of lists:
    # one inner list per configuration merged into the node.
    for node in compressed_tree.traverse():
        tips_inside = node.props.get(TIPS_INSIDE)
        node.add_prop(NUM_TIPS_INSIDE, len(tips_inside))
        node.add_prop(TIPS_INSIDE, [tips_inside])
        node.add_prop(INTERNAL_NODES_INSIDE, [node.props.get(INTERNAL_NODES_INSIDE)])

    get_bin = _exact_size
    collapse_horizontally(compressed_tree, columns, get_bin, mixed=mixed)

    if can_merge_diff_sizes and len(compressed_tree) > tip_size_threshold:
        get_bin = _size_magnitude
        logger.debug('Allowed merging nodes of different sizes.')
        collapse_horizontally(compressed_tree, columns, get_bin, mixed=mixed)

    _export_to_pajek(HORIZONTAL)

    if len(compressed_tree) > tip_size_threshold:
        # multiplier = product of the numbers of merged configurations on the path from the root
        for node in compressed_tree.traverse('preorder'):
            parent_multiplier = node.up.props.get('multiplier') if node.up else 1
            node.add_prop('multiplier', parent_multiplier * len(node.props.get(ROOTS)))

        def get_tsize(node):
            # nodes related to the focus are infinitely large, i.e. never trimmed
            for focus_flag in (IN_FOCUS, AROUND_FOCUS, UP_FOCUS):
                if node.props.get(focus_flag, False):
                    return np.inf
            return node.props.get(NUM_TIPS_INSIDE) * node.props.get('multiplier')

        def _is_too_small(node):
            return get_tsize(node) < threshold

        node_thresholds = []
        for node in compressed_tree.traverse('postorder'):
            largest_child_size = max((get_tsize(child) for child in node.children), default=0)
            size = get_tsize(node)
            # a node larger than all of its children has a trimming threshold
            # that is higher than the ones of its children
            if not node.is_root and size > largest_child_size:
                node_thresholds.append(size)
        node_thresholds.sort()
        threshold = node_thresholds[-tip_size_threshold]

        # the list is sorted, so its first element is the minimum
        nothing_to_trim = node_thresholds[0] >= threshold
        if nothing_to_trim and threshold == np.inf:
            logger.debug('All tips are in focus.')
        elif nothing_to_trim:
            logger.debug('No tip is smaller than the threshold ({}, the size of the {}-th largest tip).'
                         .format(threshold, tip_size_threshold))
        else:
            if threshold == np.inf:
                logger.debug('Removing all the out of focus tips (as there are at least {} tips in focus).'
                             .format(tip_size_threshold))
            else:
                logger.debug('Set tip size threshold to {} (the size of the {}-th largest tip).'
                             .format(threshold, tip_size_threshold))
            remove_small_tips(compressed_tree=compressed_tree, full_tree=tree, to_be_removed=_is_too_small)
            remove_mediators(compressed_tree, columns)
            collapse_horizontally(compressed_tree, columns, get_bin, mixed=mixed)

    _export_to_pajek(TRIM)

    return compressed_tree

162 

163 

def collapse_horizontally(tree, columns, tips2bin, mixed=False):
    """
    Merges sibling nodes that have equivalent configurations into a single node.

    A configuration is (size bin, states for each column, sorted configurations of the children).
    Among equivalent siblings the first one is kept: it absorbs the TIPS_INSIDE, INTERNAL_NODES_INSIDE,
    ROOTS and TIPS_BELOW of the others, is marked as METACHILD,
    and its NUM_TIPS_INSIDE becomes the average size over the merged configurations.

    :param tree: compressed tree, modified in place
    :param columns: iterable of column names whose states are compared
    :param tips2bin: function mapping the NUM_TIPS_INSIDE value of a node to its size bin
    :param mixed: if True, the children in focus (and the nodes above them) are not merged
    :return: void, modifies the input tree
    """
    name2config = {}

    def _configuration(node):
        # Configuration is (branch_width, (size, states, child_configurations)),
        # where branch_width is only used for recursive calls and is ignored when considering a merge
        if node.name not in name2config:
            size_bin = tips2bin(node.props.get(NUM_TIPS_INSIDE))
            states = tuple(tuple(sorted(node.props.get(column, set()))) for column in columns)
            child_configs = tuple(sorted([_configuration(child) for child in node.children]))
            name2config[node.name] = (len(node.props.get(TIPS_INSIDE)), (size_bin, states, child_configs))
        return name2config[node.name]

    num_merged_groups = 0
    unmergeable_names = set()

    for parent in tree.traverse('postorder'):
        config2siblings = defaultdict(list)
        for kid in parent.children:
            if mixed and (kid.props.get(IN_FOCUS, False) or kid.name in unmergeable_names):
                unmergeable_names.update((kid.name, parent.name))
            else:
                # use (size, states, child_configurations) as configuration (ignore branch width)
                config2siblings[_configuration(kid)[1]].append(kid)

        for siblings in config2siblings.values():
            if len(siblings) < 2:
                continue
            num_merged_groups += 1
            survivor, *duplicates = siblings
            for duplicate in duplicates:
                for inside_key in (TIPS_INSIDE, INTERNAL_NODES_INSIDE):
                    absorbed = duplicate.props.get(inside_key)
                    survivor.props.get(inside_key).extend(absorbed)
                    # the full-tree nodes are now represented by the surviving node
                    for full_nodes in absorbed:
                        for full_node in full_nodes:
                            full_node.add_prop(COMPRESSED_NODE, survivor)
                survivor.props.get(ROOTS).extend(duplicate.props.get(ROOTS))
                survivor.props.get(TIPS_BELOW).extend(duplicate.props.get(TIPS_BELOW))
                parent.remove_child(duplicate)

            survivor.add_prop(METACHILD, True)
            merged_tips = survivor.props.get(TIPS_INSIDE)
            survivor.add_prop(NUM_TIPS_INSIDE, sum(len(tips) for tips in merged_tips) / len(merged_tips))
            if survivor.name in name2config:
                # only the branch width changed, the configuration itself stays the same
                name2config[survivor.name] = (len(merged_tips), name2config[survivor.name][1])

    if num_merged_groups:
        logging.getLogger('pastml').debug(
            'Collapsed {} sets of equivalent configurations horizontally.'.format(num_merged_groups))

212 

213 

def remove_small_tips(compressed_tree, full_tree, to_be_removed):
    """
    Recursively prunes the leaves of the compressed tree that are selected by to_be_removed,
    and removes the nodes that they represent from the full tree.

    The pruning is repeated until no leaf gets removed: removing leaves can turn their parents
    into leaves, which are then tested in their turn. The root is never removed.
    Afterwards, the leaves of the full tree that are not marked with IS_TIP are pruned as well.

    :param compressed_tree: compressed tree, modified in place;
        its nodes carry the TIPS_INSIDE and INTERNAL_NODES_INSIDE properties (lists of node lists)
    :param full_tree: non-compressed tree, modified in place;
        its nodes carry the COMPRESSED_NODE property
    :param to_be_removed: function that takes a compressed-tree leaf and returns True if it is to be pruned
    :return: void, modifies the input trees
    """
    num_removed = 0
    changed = True
    while changed:
        changed = False
        # Take a snapshot of the leaves (the tree is modified inside the loop);
        # leaves() is the accessor used by the rest of this module.
        for leaf in list(compressed_tree.leaves()):
            parent = leaf.up
            if parent is not None and to_be_removed(leaf):
                num_removed += 1
                parent.remove_child(leaf)
                # remove the corresponding nodes from the non-collapsed tree
                for inside_key in (TIPS_INSIDE, INTERNAL_NODES_INSIDE):
                    for full_nodes in leaf.props.get(inside_key):
                        for full_node in full_nodes:
                            full_node.up.remove_child(full_node)
                changed = True

    # if the full tree now contains non-sampled tips,
    # remove them from the tree and from the corresponding collapsed nodes
    # NOTE(review): list(full_tree) is presumably the list of the full tree's leaves -- confirm
    todo = list(full_tree)
    while todo:
        t = todo.pop()
        if t.props.get(IS_TIP, False):
            continue
        parent = t.up
        if parent is None:
            # a parentless node (the root) cannot be detached from anything
            continue
        parent.remove_child(t)
        if parent.is_leaf:
            # the parent lost its last child: it might need to be pruned too
            todo.append(parent)
        for ini_list in t.props.get(COMPRESSED_NODE).props.get(INTERNAL_NODES_INSIDE):
            if t in ini_list:
                ini_list.remove(t)

    logging.getLogger('pastml').debug(
        'Recursively removed {} tips of size smaller than the threshold.'.format(num_removed))

249 

250 

def collapse_vertically(tree, columns, mixed=False):
    """
    Collapses a child node into its parent if they are in the same state.

    The parent absorbs the TIPS_INSIDE and INTERNAL_NODES_INSIDE of the collapsed child
    (the COMPRESSED_NODE property of the absorbed nodes is redirected to the parent)
    and adopts the child's children.

    :param tree: tree whose nodes keep their properties in .props, modified in place
    :param columns: a list of characters
    :param mixed: if True then the nodes in focus will not get collapsed
    :return: void, modifies the input tree
    """

    def _same_states(node1, node2, columns):
        if any(node1.props.get(column, set()) != node2.props.get(column, set()) for column in columns):
            return False
        if not mixed:
            return True
        if node1.props.get(IN_FOCUS, False) or node2.props.get(IN_FOCUS, False):
            return False
        # A node above the focus is not merged with a node outside of it:
        # the outside node gets marked as being around the focus instead.
        for above, other in ((node1, node2), (node2, node1)):
            if above.props.get(UP_FOCUS, False) and not other.props.get(IN_FOCUS, False) \
                    and not other.props.get(UP_FOCUS, False):
                other.add_prop(AROUND_FOCUS, True)
                return False
        return True

    num_collapsed = 0
    for node in tree.traverse('postorder'):
        if node.is_leaf:
            continue

        # iterate over a copy, as the children get modified inside the loop
        for child in list(node.children):
            # merge the child into this node if their states are the same
            if not _same_states(node, child, columns):
                continue
            for inside_key in (TIPS_INSIDE, INTERNAL_NODES_INSIDE):
                absorbed = child.props.get(inside_key)
                node.props.get(inside_key).extend(absorbed)
                for full_node in absorbed:
                    full_node.add_prop(COMPRESSED_NODE, node)

            node.remove_child(child)
            for grandchild in list(child.children):
                node.add_child(grandchild)
            num_collapsed += 1

    if num_collapsed:
        logging.getLogger('pastml').debug('Collapsed vertically {} internal nodes without state change.'
                                          .format(num_collapsed))

299 

300 

def remove_mediators(tree, columns):
    """
    Removes intermediate nodes that are just mediators between their parent and child states.

    A mediator is a non-root node with exactly one child, no tips inside and no METACHILD mark,
    whose states for every column are unresolved (at least two states)
    and equal to the union of its parent's and its child's states.
    The child of a removed mediator gets attached to the mediator's parent,
    and the full-tree nodes represented by the mediator are removed from the full tree in the same way.

    :param tree: compressed tree, modified in place
    :param columns: list of characters
    :return: void, modifies the input tree
    """

    def _hesitates(node, parent, child, column):
        # if mediator has unresolved states, it should hesitate between the parent and the child
        states = node.props.get(column, set())
        if len(states) < 2:
            return False
        parent_states = parent.props.get(column, set())
        child_states = child.props.get(column, set())
        if states != child_states | parent_states:
            return False
        return True

    num_removed = 0
    for node in tree.traverse('postorder'):
        if node.props.get(METACHILD, False) or node.is_leaf or len(node.children) > 1 or node.is_root \
                or node.props.get(NUM_TIPS_INSIDE) > 0:
            continue

        parent = node.up
        only_child = node.children[0]

        if not all(_hesitates(node, parent, only_child, column) for column in columns):
            continue

        parent.remove_child(node)
        parent.add_child(only_child)
        # update the uncompressed tree
        for full_nodes in node.props.get(INTERNAL_NODES_INSIDE):
            for full_node in full_nodes:
                for full_child in list(full_node.children):
                    full_node.up.add_child(full_child)
                full_node.up.remove_child(full_node)
        num_removed += 1

    if num_removed:
        logging.getLogger('pastml').debug("Removed {} internal node{}"
                                          " with the state unresolved between the parent's and the only child's."
                                          .format(num_removed, '' if num_removed == 1 else 's'))