Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/tree.py: 10%

347 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-21 09:19 +0100

1import logging 

2import os 

3import re 

4from collections import Counter, defaultdict 

5from datetime import datetime 

6 

7from Bio import Phylo 

8#from ete3 import Tree, TreeNode 

9from ete4 import Tree 

10 

# Labels yielded by depth_first_traversal to mark the kind of visit.
POSTORDER = 'postorder'
INORDER = 'inorder'
PREORDER = 'preorder'

# Names of the node properties that store a node date and its confidence interval.
DATE = 'date'
DATE_CI = 'date_CI'

# A (possibly signed) number with an optional fractional part and exponent.
DATE_REGEX = r'[+-]*[\d]+[.\d]*(?:[e][+-][\d]+){0,1}'
# The templates below are raw strings so that regex escapes such as \s reach
# the re module untouched (in a plain literal \s is an invalid escape sequence).
# Doubled braces are str.format escapes for literal braces.
DATE_COMMENT_REGEX = r'[&,:]date[=]["]{{0,1}}({})["]{{0,1}}'.format(DATE_REGEX)
CI_DATE_REGEX_LSD = \
    r'[&,:]CI_date[=]["]{{0,1}}[{{]{{0,1}}({})\s*[,;]{{0,1}}\s*({})[}}]{{0,1}}["]{{0,1}}'.format(DATE_REGEX, DATE_REGEX)
CI_DATE_REGEX_PASTML = r'[&,:]date_CI[=]["]{{0,1}}({})[|]({})["]{{0,1}}'.format(DATE_REGEX, DATE_REGEX)
# To be completed with .format(column=...) before use.
COLUMN_REGEX_PASTML = r'[&,]{column}[=]([^]^,]+)'

# Name of the node property flagging nodes created while resolving polytomies.
IS_POLYTOMY = 'polytomy'

27 

28 

def get_dist_to_root(tip):
    """
    Sums the branch lengths on the path from the given node up to the root.

    :param tip: node to start from; it and its ancestors must expose ``is_root``, ``dist`` and ``up``
    :return: sum of ``dist`` over the node and all of its non-root ancestors (0 if the node is the root)
    """
    total = 0
    current = tip
    while True:
        if current.is_root:
            return total
        total += current.dist
        current = current.up

36 

37 

def annotate_dates(forest, root_dates=None):
    """
    Makes sure every node of every tree has a numeric date property, and normalises date CIs.

    Nodes are visited in preorder. A node that already has a date gets it converted to float.
    An undated root gets the corresponding root date (0 if it is missing or falsy),
    and any other undated node gets its parent's date plus its own branch length.

    A date CI that is neither a list nor a tuple is removed from the node; if it was
    a '|'-separated string of numbers, it is put back as a list of floats.

    :param forest: trees to annotate
    :type forest: list(ete4.Tree)
    :param root_dates: root dates, one per tree, in the same order as the trees (all zeros by default)
    :type root_dates: list
    :return: void, modifies the trees in place
    """
    if root_dates is None:
        root_dates = [0] * len(forest)
    for tree, root_date in zip(forest, root_dates):
        for node in tree.traverse('preorder'):
            date = node.props.get(DATE, None)
            if date is not None:
                node.add_prop(DATE, float(date))
            elif node.is_root:
                node.add_prop(DATE, root_date if root_date else 0)
            else:
                # preorder guarantees that the parent has been dated already
                node.add_prop(DATE, node.up.props.get(DATE) + node.dist)

            ci = node.props.get(DATE_CI, None)
            if ci and not isinstance(ci, (list, tuple)):
                node.del_prop(DATE_CI)
                if isinstance(ci, str) and '|' in ci:
                    try:
                        node.add_prop(DATE_CI, [float(_) for _ in ci.split('|')])
                    except ValueError:
                        # best effort: a CI that cannot be parsed is simply dropped
                        pass

58 

59 

def name_tree(tree, suffix=""):
    """
    Gives unique names to the tree nodes, unless all of them are already uniquely named.

    While counting the existing names, nodes whose name contains '.polytomy_'
    get flagged with the polytomy property.
    When renaming is needed, the root is called 'root<suffix>'; any other node keeps its name
    if that name occurs fewer than 10 times in the tree, otherwise it gets
    't<counter><suffix>' (leaves) or 'n<counter><suffix>' (internal nodes).
    Name clashes are resolved by appending '<counter><suffix>' to the name.

    :param tree: tree to be named
    :type tree: ete4.Tree
    :param suffix: text appended to every generated name
    :type suffix: str

    :return: void, modifies the original tree
    """
    name_counts = Counter()
    total_nodes = 0
    for node in tree.traverse():
        total_nodes += 1
        if not node.name:
            continue
        name_counts[node.name] += 1
        if '.polytomy_' in node.name:
            node.add_prop(IS_POLYTOMY, 1)

    # as many distinct names as nodes: nothing to rename
    if total_nodes == len(name_counts):
        return

    counter = 0
    assigned = Counter()
    for node in tree.traverse('preorder'):
        if node.name and name_counts[node.name] < 10:
            prefix = node.name
        else:
            prefix = '{}{}{}'.format('t' if node.is_leaf else 'n', counter, suffix)

        if node.is_root:
            candidate = 'root{}'.format(suffix)
        else:
            candidate = prefix

        while candidate is None or candidate in assigned:
            candidate = '{}{}{}'.format(prefix, counter, suffix)
            counter += 1

        node.name = candidate
        assigned[candidate] += 1

90 

91 

def collapse_zero_branches(forest, features_to_be_merged=None):
    """
    Collapses the internal branches of non-positive length in the given trees.

    Each internal child whose branch length is <= 0 is removed and its children
    are attached directly to its parent. For every feature to be merged, the parent gets
    the intersection of its own values with those of all the collapsed children if that
    intersection is not empty, and the union of all these values otherwise.

    :param forest: trees to process
    :type forest: list(ete4.Tree)
    :param features_to_be_merged: list of features whose values (sets) are to be merged
        in case the nodes are merged during collapsing
    :type features_to_be_merged: list(str)
    :return: void, modifies the trees in place
    """
    merged_features = [] if features_to_be_merged is None else features_to_be_merged
    collapsed_total = 0

    for tree in forest:
        # the traversal is materialised first, as the tree is modified along the way
        for node in list(tree.traverse('postorder')):
            to_collapse = [kid for kid in node.children if not kid.is_leaf and kid.dist <= 0]
            if not to_collapse:
                continue

            for feature in merged_features:
                own_values = node.props.get(feature, set())
                merged = set.intersection(*(kid.props.get(feature, set()) for kid in to_collapse)) & own_values
                if not merged:
                    merged = set.union(*(kid.props.get(feature, set()) for kid in to_collapse)) | own_values
                if merged:
                    node.add_prop(feature, merged)

            for kid in to_collapse:
                node.remove_child(kid)
                for grandkid in kid.children:
                    node.add_child(grandkid)
            collapsed_total += len(to_collapse)

    if collapsed_total:
        logging.getLogger('pastml').debug('Collapsed {} internal zero branches.'.format(collapsed_total))

130 

131 

def remove_certain_leaves(tr, to_remove=lambda node: False):
    """
    Removes all the branches leading to leaves identified positively by the to_remove function.

    A parent left with a single child after a removal is merged with that child
    (the child's branch length is increased by the parent's one).

    :param tr: the tree of interest (the leaves to test are obtained by iterating over it)
    :param to_remove: a method to check if a leaf should be removed.
    :return: the root of the resulting tree (it differs from tr if the original root got merged
        with its only remaining child), or None if a leaf to be removed is itself a root.
        The initial tree is modified.
    """
    for leaf in list(filter(to_remove, tr)):
        if leaf.is_root:
            return None
        parent = leaf.up
        parent.remove_child(leaf)
        if len(parent.children) != 1:
            continue

        # The parent has only one child now: merge them.
        sibling = parent.children[0]
        sibling.dist += parent.dist
        if not parent.is_root:
            grandparent = parent.up
            grandparent.remove_child(parent)
            grandparent.add_child(sibling)
        else:
            sibling.up = None
            tr = sibling
    return tr

158 

159 

def read_forest(tree_path, columns=None):
    """
    Reads all the trees found in a file in nexus or newick format.

    The file is first parsed as nexus; if that fails or yields no tree,
    its content is split on ';' and each piece is read as a newick tree.

    :param tree_path: path to the tree file
    :type tree_path: str
    :param columns: names of the annotation columns to be extracted from the trees (optional)
    :type columns: list(str)
    :return: the trees read from the file
    :rtype: list(ete4.Tree)
    :raises ValueError: if the file contains no ';'-terminated tree
    """
    try:
        roots = parse_nexus(tree_path, columns=columns)
        if roots:
            return roots
    except Exception:
        # Not a readable nexus file: fall back to newick below.
        pass
    with open(tree_path, 'r') as f:
        pieces = f.read().replace('\n', '').split(';')
    # Every newick tree ends with ';', so whatever follows the last ';' is not a tree.
    # str.split never returns an empty list, hence the emptiness check is done after this filtering.
    nwks = [nwk for nwk in pieces[:-1] if nwk.strip()]
    if not nwks:
        raise ValueError('Could not find any trees (in newick or nexus format) in the file {}.'.format(tree_path))
    return [read_tree(nwk + ';', columns) for nwk in nwks]

172 

173 

174# def read_tree(tree_path, columns=None): 

175# tree = None 

176# for f in (3, 2, 5, 0, 1, 4, 6, 7, 8, 9): 

177# try: 

178# tree = Tree(tree_path, format=f) 

179# break 

180# except: 

181# continue 

182# if not tree: 

183# raise ValueError('Could not read the tree {}. Is it a valid newick?'.format(tree_path)) 

184# if columns: 

185# for n in tree.traverse(): 

186# for c in columns: 

187# vs = set(getattr(n, c).split('|')) if hasattr(n, c) else set() 

188# if vs: 

189# n.add_prop(c, vs) 

190# return tree 

191 

def read_tree(tree_path, columns=None):
    """
    Builds a tree with the ete4 newick parser 1 and converts column annotations to sets.

    :param tree_path: tree description handed over to the Tree constructor as is
        (read_forest calls this function with a newick string)
    :param columns: names of the node properties whose '|'-separated values are to be split into sets
    :type columns: list(str)
    :return: the tree that was read
    :rtype: ete4.Tree
    """
    tree = Tree(tree_path, parser=1)
    if not columns:
        return tree
    for node in tree.traverse():
        for column in columns:
            raw_value = node.props.get(column)
            if not raw_value:
                continue
            values = set(raw_value.split('|'))
            if values:
                node.add_prop(column, values)
    return tree

202 

def _find_date_ci(text):
    """Returns the first date CI found in the text as a pair of strings (LSD2 notation first, then PastML), or None."""
    ci = next(iter(re.findall(CI_DATE_REGEX_LSD, text)), None)
    if ci is None:
        ci = next(iter(re.findall(CI_DATE_REGEX_PASTML, text)), None)
    return ci


def parse_nexus(tree_path, columns=None):
    """
    Reads the trees of a nexus file and converts them to ete4 trees annotated with dates and columns.

    For each clade returned by the nexus reader, a node is created with the clade's branch length
    (0 if it is missing or not a number) and name (or its confidence, if it is a string, when the
    name is missing). The date, date CI and column values are extracted from the clade's comment;
    the CI is also searched for in the root's branch_length and in the confidence if they are strings.

    Dates and CIs are then copied between parents and children linked by zero-length branches
    (if the receiver lacks them), and, if the root is still undated, dates are back-calculated
    from the first dated node up to the root.

    :param tree_path: path to the nexus file
    :type tree_path: str
    :param columns: names of the annotation columns to be extracted from the node comments (optional)
    :type columns: list(str)
    :return: the trees read from the file
    :rtype: list(ete4.Tree)
    """
    trees = []
    for nex_tree in read_nexus(tree_path):
        todo = [(nex_tree.root, None)]
        tree = None
        while todo:
            clade, parent = todo.pop()

            # Clades come from Bio.Phylo and are plain objects: their fields are attributes
            # (there is no ete-like props dictionary on them), hence the getattr calls.
            try:
                dist = float(getattr(clade, 'branch_length', None))
            except (TypeError, ValueError):
                dist = 0
            name = getattr(clade, 'name', None)
            if not name:
                name = getattr(clade, 'confidence', None)
                if not isinstance(name, str):
                    name = None

            # dist and name are set as attributes, as done in copy_forest
            node = Tree()
            node.dist = dist
            if name is not None:
                node.name = name
            if parent is None:
                tree = node
            else:
                parent.add_child(node)

            # Parse LSD2 dates and CIs, and PastML columns
            date, ci = None, None
            columns2values = defaultdict(set)
            comment = getattr(clade, 'comment', None)
            if isinstance(comment, str):
                date = next(iter(re.findall(DATE_COMMENT_REGEX, comment)), None)
                ci = _find_date_ci(comment)
                if columns:
                    for column in columns:
                        values = set()
                        for match in re.findall(COLUMN_REGEX_PASTML.format(column=column), comment):
                            values.update(match.split('|'))
                        if values:
                            columns2values[column] |= values
            # the CI might have ended up in the root's branch length or in the confidence field
            branch_length = getattr(clade, 'branch_length', None)
            if ci is None and parent is None and isinstance(branch_length, str):
                ci = _find_date_ci(branch_length)
            confidence = getattr(clade, 'confidence', None)
            if ci is None and isinstance(confidence, str):
                ci = _find_date_ci(confidence)

            if date is not None:
                try:
                    node.add_prop(DATE, float(date))
                except ValueError:
                    # best effort: a date that is not a number is ignored
                    pass
            if ci is not None:
                try:
                    node.add_prop(DATE_CI, [float(_) for _ in ci])
                except ValueError:
                    # best effort: a CI whose bounds are not numbers is ignored
                    pass
            for c, vs in columns2values.items():
                node.add_prop(c, vs)
            todo.extend((c, node) for c in clade.clades)

        # copy dates and CIs down the zero-length branches...
        for n in tree.traverse('preorder'):
            date, ci = n.props.get(DATE, None), n.props.get(DATE_CI, None)
            if date is not None or ci is not None:
                for c in n.children:
                    if c.dist == 0:
                        if c.props.get(DATE, None) is None:
                            c.add_prop(DATE, date)
                        if c.props.get(DATE_CI, None) is None:
                            c.add_prop(DATE_CI, ci)
        # ... and up the zero-length branches
        for n in tree.traverse('postorder'):
            date, ci = n.props.get(DATE, None), n.props.get(DATE_CI, None)
            if not n.is_root and n.dist == 0 and (date is not None or ci is not None):
                if n.up.props.get(DATE, None) is None:
                    n.up.add_prop(DATE, date)
                if n.up.props.get(DATE_CI, None) is None:
                    n.up.add_prop(DATE_CI, ci)

        # propagate dates up to the root if needed
        if tree.props.get(DATE, None) is None:
            dated_node = next((n for n in tree.traverse() if n.props.get(DATE, None) is not None), None)
            if dated_node is not None:
                while dated_node is not tree:
                    if dated_node.up.props.get(DATE, None) is None:
                        dated_node.up.add_prop(DATE, dated_node.props.get(DATE) - dated_node.dist)
                    dated_node = dated_node.up

        trees.append(tree)
    return trees

296 

297 

def read_nexus(tree_path):
    """
    Parses the trees of a nexus file with Bio.Phylo, after rewriting the date CIs to a parsable form.

    The rewritten content is saved to a temporary file next to the original one
    (named after it, with a timestamp and a '.temp' extension), which is removed once
    the parsing is over, whether it succeeded or not.

    :param tree_path: path to the nexus file
    :type tree_path: str
    :return: the trees found in the file
    :rtype: list
    """
    with open(tree_path, 'r') as f:
        nexus = f.read()
    # replace CI_date="2019(2018,2020)" with CI_date="2018 2020"
    nexus = re.sub(r'CI_date="({})\(({}),({})\)"'.format(DATE_REGEX, DATE_REGEX, DATE_REGEX), r'CI_date="\2 \3"',
                   nexus)
    temp = tree_path + '.{}.temp'.format(datetime.timestamp(datetime.now()))
    try:
        with open(temp, 'w') as f:
            f.write(nexus)
        return list(Phylo.parse(temp, 'nexus'))
    finally:
        # The parsing fails on files that are not in nexus format:
        # make sure that the temporary file does not stay behind in that case.
        if os.path.exists(temp):
            os.remove(temp)

310 

311 

def depth_first_traversal(node):
    """
    Walks the subtree of the given node depth first, yielding (node, visit kind) pairs.

    A node is yielded with PREORDER before its children, with INORDER between each two
    consecutive children, and with POSTORDER after all of its children.

    :param node: root of the subtree to traverse (must expose ``children``)
    :return: generator of (node, PREORDER | INORDER | POSTORDER) pairs
    """
    yield node, PREORDER
    is_first_child = True
    for child in node.children:
        if not is_first_child:
            yield node, INORDER
        is_first_child = False
        yield from depth_first_traversal(child)
    yield node, POSTORDER

320 

321 

def resolve_trees(column2states, forest):
    """
    Resolves polytomies based on state predictions:
    the children of each node with more than two children are grouped by their predicted states
    (over all the columns), and, if there is more than one group, each group is handed over to
    group_children_if_needed, which decides whether to put it under a new intermediate node.

    :param column2states: character to possible state mapping
    :type column2states: dict
    :param forest: a forest of trees of interest
    :type forest: list(ete4.Tree)
    :return: number of newly created nodes.
    :rtype: int
    """
    columns = sorted(column2states.keys())
    col2state2i = {column: {state: index for index, state in enumerate(states)}
                   for column, states in column2states.items()}

    def get_prediction(node):
        # key describing the node's states: state indices joined by '-' within a column, columns joined by '.'
        per_column = []
        for column in columns:
            indices = sorted(col2state2i[column][state] for state in node.props.get(column, set()))
            per_column.append('-'.join(str(index) for index in indices))
        return '.'.join(per_column)

    created = 0
    for tree in forest:
        stack = [tree]
        while stack:
            node = stack.pop()
            # the children are scheduled before any regrouping happens
            stack.extend(node.children)
            if len(node.children) <= 2:
                continue
            groups = defaultdict(list)
            for kid in node.children:
                groups[get_prediction(kid)].append(kid)
            if len(groups) <= 1:
                continue
            for state, members in groups.items():
                if group_children_if_needed(node, members, columns, state):
                    created += 1

    logger = logging.getLogger('pastml')
    if created:
        logger.debug(
            'Created {} new internal nodes while resolving polytomies'.format(created))
    else:
        logger.debug('Could not resolve any polytomy')
    return created

365 

366 

def states_are_different(n1, n2, columns):
    """
    Checks whether two nodes have no state in common for at least one of the columns.

    :param n1: first node (its states are sets stored in ``props``, a missing column counts as no state)
    :param n2: second node
    :param columns: columns to compare
    :return: True if for some column the state sets of the nodes do not intersect, False otherwise
    :rtype: bool
    """
    return any(not (n1.props.get(column, set()) & n2.props.get(column, set())) for column in columns)

372 

373 

def group_children_if_needed(n, children, columns, state):
    """
    Puts the given children of n under a new intermediate node, if their states differ from those of n.

    Nothing is done if there are fewer than two children, or if the child with the shortest branch
    shares at least one state with n for every column. Otherwise a child named
    '<name of n>.polytomy_<state>' is added to n at the distance of that closest child;
    it is flagged as a polytomy, gets the date and column values of the closest child,
    and a date CI made of the lower bound of n's CI and the upper bound of the closest child's CI
    (its date if there is none), or None if n has no CI list.
    The children are then moved under the new node, with their branches shortened accordingly.

    :param n: parent node
    :param children: the children of n to be grouped
    :type children: list
    :param columns: columns whose states are compared and copied
    :param state: label of the group, used in the name of the new node
    :return: True if a new node was created, False otherwise
    :rtype: bool
    """
    if len(children) <= 1:
        return False
    closest = min(children, key=lambda kid: kid.dist)
    if not states_are_different(n, closest, columns):
        return False

    shift = closest.dist
    pol = n.add_child(dist=shift, name='{}.polytomy_{}'.format(n.name, state))
    pol.add_prop(IS_POLYTOMY, 1)

    closest_date = closest.props.get(DATE)
    pol.add_prop(DATE, closest_date)

    parent_ci = n.props.get(DATE_CI, None)
    closest_ci = closest.props.get(DATE_CI, None)
    if parent_ci and isinstance(parent_ci, list):
        if closest_ci and isinstance(closest_ci, list) and len(closest_ci) > 1:
            upper_bound = closest_ci[1]
        else:
            upper_bound = closest_date
        pol_ci = [parent_ci[0], upper_bound]
    else:
        pol_ci = None
    pol.add_prop(DATE_CI, pol_ci)

    for column in columns:
        pol.add_prop(column, closest.props.get(column))

    for kid in children:
        n.remove_child(kid)
        pol.add_child(kid, dist=kid.dist - shift)
    return True

397 

398 

def unresolve_trees(column2states, forest):
    """
    Unresolves polytomies whose states do not correspond to child states after likelihood recalculation.

    A polytomy node is kept only if all of its children have the same predicted states, these states
    are compatible with its own, and its own states differ from those of its parent.
    Otherwise it is removed and its children are reattached to its parent (with the branch lengths summed).
    When a node is removed because of its child states, the children of its parent are regrouped by
    predicted states, and new polytomy nodes are created for the groups that need one
    (apart from the group that would restore the node that has just been removed).

    :param column2states: character to possible state mapping
    :type column2states: dict
    :param forest: a forest of trees of interest
    :type forest: list(ete4.Tree)
    :return: number of newly deleted nodes.
    :rtype: int
    """
    columns = sorted(column2states.keys())

    col2state2i = {c: dict(zip(states, range(len(states)))) for (c, states) in column2states.items()}

    def get_prediction(n):
        # key describing the node's states: state indices joined by '-' within a column, columns joined by '.'
        return '.'.join('-'.join(str(i) for i in sorted([col2state2i[c][_] for _ in n.props.get(c, set())]))
                        for c in columns)

    def remove_node(n):
        parent = n.up
        # iterate over a copy: the children are being moved to another parent
        for c in list(n.children):
            parent.add_child(c, dist=c.dist + n.dist)
        parent.remove_child(n)

    num_removed_nodes = 0
    num_new_nodes = 0
    num_polytomies = 0

    for tree in forest:
        # The traversal is materialised first, as nodes get removed and created along the way.
        for n in list(tree.traverse('postorder')):
            if not n.props.get(IS_POLYTOMY, False):
                continue
            num_polytomies += 1

            state2children = defaultdict(list)
            n_children = list(n.children)
            for c in n_children:
                state2children[get_prediction(c)].append(c)
            parent = n.up

            # if the state is the same as all the child states, it's still a good polytomy resolution
            if len(state2children) == 1 and not states_are_different(n, n_children[0], columns):
                # Just need to check that it is not the same state as the parent (then we don't need this polytomy)
                if not states_are_different(n, parent, columns):
                    num_removed_nodes += 1
                    remove_node(n)
                continue

            num_removed_nodes += 1
            remove_node(n)

            # now let's try to create new polytomies above:
            # the parent's children are grouped in their own dictionary, which is the one examined below
            above_state2children = defaultdict(list)
            for c in parent.children:
                above_state2children[get_prediction(c)].append(c)
            for state, children in above_state2children.items():
                if len(children) <= 1:
                    continue
                # do not recreate the polytomy that has just been removed
                if set(children) != set(n_children):
                    if group_children_if_needed(parent, children, columns, state):
                        num_new_nodes += 1

    if num_removed_nodes:
        logging.getLogger('pastml').debug(
            'Removed {} polytomy resolution{} as inconsistent with model parameters.'
            .format(num_removed_nodes, 's' if num_removed_nodes > 1 else ''))
    if num_new_nodes:
        logging.getLogger('pastml').debug(
            'Created {} new polytomy resolution{}.'.format(num_new_nodes, 's' if num_new_nodes > 1 else ''))
    elif num_polytomies - num_removed_nodes + num_new_nodes:
        logging.getLogger('pastml').debug('All the polytomy resolutions are consistent with model parameters.')
    return num_removed_nodes

471 

472 

def clear_extra_features(forest, features):
    """
    Deletes all the node properties apart from the given ones and 'name', 'dist' and 'support'.

    :param forest: trees to clean
    :type forest: list(ete4.Tree)
    :param features: names of the properties to keep
    :return: void, modifies the trees in place
    """
    kept = set(features) | {'name', 'dist', 'support'}
    for tree in forest:
        for node in tree.traverse():
            # the names are collected first, as the properties are modified by the deletion
            extra = [prop for prop in node.props if prop not in kept]
            for prop in extra:
                node.del_prop(prop)

480 

481 

def copy_forest(forest, features=None):
    """
    Creates copies of the given trees that keep the topology, names, branch lengths, supports
    and the selected properties of the nodes.

    :param forest: trees to copy
    :type forest: list(ete4.Tree)
    :param features: names of the node properties to copy
        (by default, the properties present on the root of the first tree)
    :return: the copied trees, in the same order (an empty list for an empty forest)
    :rtype: list(ete4.Tree)
    """
    if not forest:
        # nothing to copy (and no first tree to take the default features from)
        return []
    features = set(features if features else forest[0].props)

    copied_forest = []
    for tree in forest:
        copied_tree = Tree()
        todo = [(tree, copied_tree)]
        copied_forest.append(copied_tree)
        while todo:
            n, copied_n = todo.pop()
            copied_n.dist = n.dist
            copied_n.support = n.support
            copied_n.name = n.name
            for f in features:
                if f in n.props:
                    copied_n.add_prop(f, n.props.get(f))
            # the children are created in the original order, hence the order is preserved in the copy
            for c in n.children:
                todo.append((c, copied_n.add_child()))
    return copied_forest