Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/ml.py: 12%

510 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-21 09:19 +0100

1import logging 

2from collections import Counter 

3 

4import numpy as np 

5import pandas as pd 

6from scipy.optimize import minimize 

7 

8from pastml import get_personalized_feature_name, CHARACTER, METHOD, NUM_SCENARIOS, NUM_UNRESOLVED_NODES, \ 

9 NUM_STATES_PER_NODE, PERC_UNRESOLVED, STATES 

10from pastml.models import ModelWithFrequencies 

11from pastml.parsimony import parsimonious_acr, MP 

12 

# Keys of the result dictionaries.
LOG_LIKELIHOOD = 'log_likelihood'
# Template for the per-method restricted log likelihood key,
# e.g. RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR.format(MAP) == 'log_likelihood_restricted_MAP'.
RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR = LOG_LIKELIHOOD + '_restricted_{}'

# Names of the prediction methods.
JOINT = 'JOINT'
MPPA = 'MPPA'
MAP = 'MAP'
ALL = 'ALL'
ML = 'ML'

MARGINAL_PROBABILITIES = 'marginal_probabilities'

MODEL = 'model'

# log10 of the float64 machine epsilon and of the largest float64 value:
# rescale_log uses them as the limits beyond which log10-likelihoods get shifted.
MIN_VALUE = np.log10(np.finfo(np.float64).eps)
MAX_VALUE = np.log10(np.finfo(np.float64).max)

# Groups of prediction methods (see is_marginal, is_ml and is_meta_ml).
MARGINAL_ML_METHODS = {MPPA, MAP}
ML_METHODS = MARGINAL_ML_METHODS.union({JOINT})
META_ML_METHODS = {ML, ALL}

# Suffixes that are combined with a character name via get_personalized_feature_name
# to obtain the names of the node properties used during the likelihood calculations.
BU_LH = 'BOTTOM_UP_LIKELIHOOD'
TD_LH = 'TOP_DOWN_LIKELIHOOD'
LH = 'LIKELIHOOD'
LH_SF = 'LIKELIHOOD_SF'
BU_LH_SF = 'BOTTOM_UP_LIKELIHOOD_SF'
BU_LH_JOINT_STATES = 'BOTTOM_UP_LIKELIHOOD_JOINT_STATES'
# NOTE(review): 'DOWM' looks like a typo of 'DOWN'; the value is kept as is because it is a runtime string.
TD_LH_SF = 'TOP_DOWM_LIKELIHOOD_SF'
ALLOWED_STATES = 'ALLOWED_STATES'
STATE_COUNTS = 'STATE_COUNTS'
JOINT_STATE = 'JOINT_STATE'

43 

44 

def is_marginal(method):
    """
    Tells whether the prediction method is a marginal one,
    i.e. MAP, MPPA, or one of the meta-methods (ALL, ML).

    :param method: prediction method
    :type method: str
    :return: True if the method is marginal, False otherwise
    :rtype: bool
    """
    return method in (MARGINAL_ML_METHODS | META_ML_METHODS)

54 

55 

def is_ml(method):
    """
    Tells whether the prediction method is a maximum likelihood one,
    i.e. JOINT, one of the marginal methods, or one of the meta-methods (ALL, ML).

    :param method: prediction method
    :type method: str
    :return: True if the method is a maximum likelihood one, False otherwise
    :rtype: bool
    """
    return method in (ML_METHODS | META_ML_METHODS)

65 

66 

def is_meta_ml(method):
    """
    Tells whether the prediction method is a meta maximum likelihood method,
    i.e. one that combines several methods (ML or ALL).

    :param method: prediction method
    :type method: str
    :return: True if the method is a meta-method, False otherwise
    :rtype: bool
    """
    is_meta = method in META_ML_METHODS
    return is_meta

76 

77 

def get_default_ml_method():
    """
    Returns the name of the maximum likelihood prediction method used by default.

    :return: MPPA
    :rtype: str
    """
    default_method = MPPA
    return default_method

80 

81 

def get_bottom_up_loglikelihood(tree, character, model, is_marginal=True, alter=True):
    """
    Computes the bottom-up log likelihood of the given tree.

    The tree is traversed in postorder and calc_node_bu_likelihood is applied to each node,
    which stores the node likelihood array and its scaling factor in the properties named
    get_personalized_feature_name(character, BU_LH) and get_personalized_feature_name(character, BU_LH_SF).

    :param tree: tree of interest
    :type tree: ete3.Tree
    :param character: character for which the likelihood is calculated
    :type character: str
    :param model: model of character evolution
    :type model: pastml.models.Model
    :param is_marginal: whether the likelihood reconstruction is marginal (True) or joint (False)
    :type is_marginal: bool
    :param alter: whether the allowed states of zero-distance node clusters should be temporarily altered
        (only done when model.tau is zero)
    :type alter: bool
    :return: log likelihood (natural logarithm)
    :rtype: float
    """
    if 0 == model.tau and alter:
        altered_nodes = alter_zero_node_allowed_states(tree, character)
    else:
        altered_nodes = []

    lh_sf_feature = get_personalized_feature_name(character, BU_LH_SF)
    lh_feature = get_personalized_feature_name(character, BU_LH)
    lh_joint_state_feature = get_personalized_feature_name(character, BU_LH_JOINT_STATES)
    allowed_state_feature = get_personalized_feature_name(character, ALLOWED_STATES)

    transition_probabilities = model.get_Pij_t

    for node in tree.traverse('postorder'):
        calc_node_bu_likelihood(node, allowed_state_feature, lh_feature, lh_sf_feature, lh_joint_state_feature,
                                is_marginal, transition_probabilities)

    # Root likelihoods are weighted by the state frequencies,
    # and then summed (marginal) or maximised (joint) over the states.
    weighted_root_likelihoods = tree.props.get(lh_feature) * model.frequencies
    if is_marginal:
        root_likelihood = weighted_root_likelihoods.sum()
    else:
        root_likelihood = weighted_root_likelihoods.max()

    if altered_nodes:
        restore = unalter_zero_node_allowed_states if is_marginal else unalter_zero_node_joint_states
        restore(altered_nodes, character)

    # The scaling factor is accumulated in log10 units: dividing it by log10(e) converts it to a natural logarithm.
    return np.log(root_likelihood) - tree.props.get(lh_sf_feature) / np.log10(np.e)

124 

125 

def calc_node_bu_likelihood(node, allowed_state_feature, lh_feature, lh_sf_feature, lh_joint_state_feature, is_marginal,
                            get_pij):
    """
    Computes the bottom-up likelihood array of a node from the likelihood arrays of its children
    and stores it (together with the accumulated log10 scaling factor) in the node properties.

    :param node: node of interest, whose children already have their bottom-up likelihoods calculated
    :param allowed_state_feature: name of the property containing the node's allowed state array
    :param lh_feature: name of the property in which the bottom-up likelihood arrays are stored
    :param lh_sf_feature: name of the property in which the bottom-up scaling factors are stored
    :param lh_joint_state_feature: name of the child property in which, for a joint reconstruction,
        the best child state index for each parent state is stored
    :param is_marginal: whether the reconstruction is marginal (True) or joint (False)
    :param get_pij: function returning the state transition probabilities for a given branch length
    :return: void, modifies the node (and, for a joint reconstruction, its children) properties.
    :raises PastMLLikelihoodError: if the likelihood array becomes zero for all the states
    """
    allowed_states = node.props.get(allowed_state_feature)
    # log10 of the allowed state mask: 0 for the allowed states, -inf for the forbidden ones.
    log_likelihood_array = np.log10(np.asarray(allowed_states, dtype=np.float64))
    scaling_factor = 0
    for child in node.children:
        transition_likelihoods = get_pij(child.dist) * child.props.get(lh_feature)
        if not is_marginal:
            # Joint reconstruction: remember the best child state for each parent state.
            child.add_prop(lh_joint_state_feature, transition_likelihoods.argmax(axis=1))
            per_state_likelihoods = transition_likelihoods.max(axis=1)
        else:
            per_state_likelihoods = transition_likelihoods.sum(axis=1)
        log_likelihood_array += np.log10(np.maximum(per_state_likelihoods, 0))
        if np.all(log_likelihood_array == -np.inf):
            # NOTE(review): PastMLLikelihoodError is not among the imports visible at the top of this file;
            # confirm that it is in scope.
            raise PastMLLikelihoodError("The parent node {} and its child node {} have non-intersecting states, "
                                        "and are connected by a zero-length ({:g}) branch. "
                                        "This creates a zero likelihood value. "
                                        "To avoid this issue check the restrictions on these node states "
                                        "and/or use a smoothing factor (tau)."
                                        .format(node.name, child.name, child.dist))
        scaling_factor += rescale_log(log_likelihood_array)
    node.add_prop(lh_feature, np.power(10, log_likelihood_array))
    children_scaling_factor = sum(_.props.get(lh_sf_feature) for _ in node.children)
    node.add_prop(lh_sf_feature, scaling_factor + children_scaling_factor)

151 

152 

def rescale_log(loglikelihood_array):
    """
    Rescales the log10-likelihood array in place if its values get too small/large,
    by adding a constant to it (i.e. by multiplying the likelihoods by a power of 10).

    Only the entries above -inf (i.e. the non-zero likelihoods) are considered when deciding on the shift.

    :param loglikelihood_array: numpy array containing the log10-likelihoods to be rescaled (modified in place)
    :return: float, the value that has been added to the log10-likelihood array (0 if no rescaling was needed).
    :raises ValueError: if the array contains no entry above -inf (i.e. all the likelihoods are zero)
    """
    non_zero_loglh_array = loglikelihood_array[loglikelihood_array > -np.inf]
    if not non_zero_loglh_array.size:
        # Without this check np.min would fail with an obscure "zero-size array" ValueError.
        raise ValueError('rescale_log: cannot rescale a likelihood array whose values are all zero '
                         '(all the log-likelihoods are -inf).')

    max_limit = MAX_VALUE
    min_limit = MIN_VALUE

    min_lh_value = np.min(non_zero_loglh_array)
    max_lh_value = np.max(non_zero_loglh_array)

    factors = 0
    if max_lh_value > max_limit:
        # Too large: shift down so that the maximum gets just below the upper limit.
        factors = max_limit - max_lh_value - 1
    elif min_lh_value < min_limit:
        # Too small: shift up, but not so much that the maximum would exceed the upper limit.
        factors = min(min_limit - min_lh_value + 1, max_limit - max_lh_value - 1)
    loglikelihood_array += factors
    return factors

174 

175 

def optimize_likelihood_params(forest, character, observed_frequencies, model):
    """
    Optimizes the model parameters by maximising the marginal bottom-up log likelihood of the given trees.

    Up to 100 L-BFGS-B optimisations are run: the first one starts from the current model parameters,
    the second one (if the model frequencies are to be optimised) from the parameters
    obtained after setting the model frequencies to the observed ones,
    and the others from values drawn uniformly at random within the parameter bounds.
    The first successful optimisation whose log likelihood is at least as good as the best starting one is kept.
    If none is found, the model is reset to the best of the starting parameters.

    :param forest: trees of interest
    :type forest: list(ete3.Tree)
    :param character: character for which the likelihood is optimised
    :type character: str
    :param observed_frequencies: array of observed state frequencies
        (non-positive values are replaced with 1e-10)
    :type observed_frequencies: numpy.array
    :param model: model of character evolution (its parameters are modified)
    :type model: pastml.model.Model
    :return: the log likelihood corresponding to the parameters that were set in the model
    :rtype: float
    """
    bounds = model.get_bounds()

    def negated_loglikelihood(parameters):
        """Sets the parameters in the model and returns minus the forest log likelihood."""
        if np.any(pd.isnull(parameters)):
            return np.nan
        model.set_params_from_optimised(parameters)
        loglikelihood = sum(get_bottom_up_loglikelihood(tree=tree, character=character, is_marginal=True,
                                                        model=model, alter=True)
                            for tree in forest)
        if pd.isnull(loglikelihood):
            return np.inf
        return -loglikelihood

    if np.any(observed_frequencies <= 0):
        observed_frequencies = np.maximum(observed_frequencies, 1e-10)

    # Starting point 1: the current model parameters.
    start_current = model.get_optimised_parameters()
    optimise_frequencies = isinstance(model, ModelWithFrequencies) and model._optimise_frequencies
    # Starting point 2: the parameters with the observed frequencies (if the frequencies are to be optimised).
    start_observed = start_current
    if optimise_frequencies:
        model.frequencies = observed_frequencies
        start_observed = model.get_optimised_parameters()
    log_lh_current = -negated_loglikelihood(start_current)
    if optimise_frequencies:
        log_lh_observed = -negated_loglikelihood(start_observed)
    else:
        log_lh_observed = log_lh_current

    best_log_lh = max(log_lh_current, log_lh_observed)

    for attempt in range(100):
        if attempt == 0:
            start = start_current
        elif optimise_frequencies and attempt == 1:
            start = start_observed
        else:
            start = np.random.uniform(bounds[:, 0], bounds[:, 1])
        fres = minimize(negated_loglikelihood, x0=start, method='L-BFGS-B', bounds=bounds)
        if not fres.success or np.any(np.isnan(fres.x)):
            continue
        if -fres.fun >= best_log_lh:
            model.set_params_from_optimised(fres.x)
            return -fres.fun
    # No optimisation improved on the starting points: fall back to the best of them.
    model.set_params_from_optimised(start_current if log_lh_current >= log_lh_observed else start_observed)
    return best_log_lh

240 

241 

def calculate_top_down_likelihood(tree, character, model):
    """
    Calculates the top-down likelihood for the given tree.
    The likelihood of each node is stored in the property named
    get_personalized_feature_name(character, TD_LH), and its scaling factor in the one named
    get_personalized_feature_name(character, TD_LH_SF).

    To calculate the top-down likelihood of a node, we assume that the tree is rooted in this node
    and combine the likelihoods of the "up-subtrees",
    e.g. to calculate the top-down likelihood of a node N1 being in a state i,
    given that its parent node is P and its brother node is N2, we imagine that the tree is re-rooted in N1,
    therefore P becoming the child of N1, and N2 its grandchild.
    We then calculate the bottom-up likelihood from the P subtree:
    L_top_down(N1, i) = sum_j P(i -> j, dist(N1, P)) * L_top_down(P) * sum_k P(j -> k, dist(N2, P)) * L_bottom_up (N2).

    For the root node its top-down likelihood is set to 1 for all the states.

    :param tree: tree of interest (with bottom-up likelihood pre-calculated)
    :type tree: ete3.Tree
    :param character: character whose ancestral state likelihood is being calculated
    :type character: str
    :param model: model of character evolution
    :type model: pastml.models.Model
    :return: void, stores the node top-down likelihoods in the corresponding node properties.
    """
    feature_names = [get_personalized_feature_name(character, suffix)
                     for suffix in (TD_LH, TD_LH_SF, BU_LH, BU_LH_SF)]
    transition_probabilities = model.get_Pij_t
    # Preorder: each parent is processed before its children.
    for node in tree.traverse('preorder'):
        calc_node_td_likelihood(node, *feature_names, transition_probabilities)

274 

275 

def calc_node_td_likelihood(node, td_lh_feature, td_lh_sf_feature, bu_lh_feature, bu_lh_sf_feature, get_pij):
    """
    Computes the top-down likelihood array of a node and stores it
    (together with the corresponding log10 scaling factor) in the node properties.

    :param node: node of interest, whose parent already has its top-down likelihood calculated
    :param td_lh_feature: name of the property in which the top-down likelihood arrays are stored
    :param td_lh_sf_feature: name of the property in which the top-down scaling factors are stored
    :param bu_lh_feature: name of the property containing the bottom-up likelihood arrays
    :param bu_lh_sf_feature: name of the property containing the bottom-up scaling factors
    :param get_pij: function returning the state transition probabilities for a given branch length
    :return: void, modifies the node properties.
    """
    if node.is_root:
        # The root top-down likelihood is one for every state, with no scaling.
        n_states = len(node.props.get(bu_lh_feature))
        node.add_prop(td_lh_feature, np.ones(n_states, np.float64))
        node.add_prop(td_lh_sf_feature, 0)
        return

    parent = node.up
    transposed_pij = np.transpose(get_pij(node.dist))

    # What this node contributed to the bottom-up likelihood of its parent:
    # it is divided out below (non-positive values are replaced with one to keep the logarithm defined).
    own_contribution = node.props.get(bu_lh_feature).dot(transposed_pij)
    own_contribution[own_contribution <= 0] = 1

    parent_td_log = np.log10(parent.props.get(td_lh_feature))
    parent_bu_log = np.log10(parent.props.get(bu_lh_feature))
    parent_loglikelihood = parent_td_log + parent_bu_log - np.log10(own_contribution)

    scaling_factor = parent.props.get(td_lh_sf_feature) \
                     + parent.props.get(bu_lh_sf_feature) - node.props.get(bu_lh_sf_feature)
    scaling_factor += rescale_log(parent_loglikelihood)

    td_likelihood = np.power(10, parent_loglikelihood).dot(transposed_pij)
    node.add_prop(td_lh_feature, np.maximum(td_likelihood, 0))
    node.add_prop(td_lh_sf_feature, scaling_factor)

294 

295 

def initialize_allowed_states(tree, feature, states):
    """
    Initializes the allowed state arrays of the tree nodes based on their states given by the feature.

    A node without a state (missing or empty value) is allowed to be in any state;
    otherwise only the listed state(s) are allowed.

    :param tree: tree for which the allowed states are to be initialized
    :type tree: ete3.Tree
    :param feature: property in which the node states are stored
        (the value could be missing/empty for an unknown state,
        a string for a single state, or a collection if multiple states are possible)
    :type feature: str
    :param states: ordered array of states.
    :type states: numpy.array
    :return: void, adds the get_personalized_feature_name(feature, ALLOWED_STATES) property to the tree nodes.
    """
    allowed_states_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    n = len(states)
    state2index = {state: index for index, state in enumerate(states)}

    for node in tree.traverse():
        node_states = node.props.get(feature, set())
        # A single state given as a string is treated as a one-element set
        if isinstance(node_states, str):
            node_states = {node_states}

        if node_states:
            allowed_states = np.zeros(n, dtype=int)
            allowed_states[[state2index[state] for state in node_states]] = 1
        else:
            allowed_states = np.ones(n, dtype=int)
        node.add_prop(allowed_states_feature, allowed_states)

326 

def get_zero_clusters_with_states(tree, feature):
    """
    Finds the zero-distance clusters of the given tree,
    i.e. groups of nodes connected to each other by zero-length branches.

    :param tree: ete3.Tree, the tree of interest
    :param feature: str, name of the property in which the node states are stored
    :return: iterator of lists of nodes that are at zero distance from each other and have states specified for them
        (only the lists containing more than one such node are returned).
    """
    # Nodes from which a new cluster starts (the root and the nodes at the end of non-zero branches).
    cluster_roots = [tree]

    while cluster_roots:
        cluster_nodes_with_states = []
        # Nodes of the current cluster that are still to be visited.
        to_visit = [cluster_roots.pop()]

        while to_visit:
            current = to_visit.pop()
            state = current.props.get(feature, None)
            if state is not None and state != '':
                cluster_nodes_with_states.append(current)
            for child in current.children:
                # Zero-distance children extend the current cluster, the others start new ones.
                target = to_visit if child.dist == 0 else cluster_roots
                target.append(child)
        if len(cluster_nodes_with_states) > 1:
            yield cluster_nodes_with_states

356 

357 

def alter_zero_node_allowed_states(tree, feature):
    """
    Alters the allowed state arrays of zero-distance nodes
    to make sure they do not contradict with other zero-distance node siblings/ancestors/descendants.

    If the nodes of a zero-distance cluster have no state in common,
    the allowed states of each of them are replaced with the union of their allowed states
    (one array shared by all the cluster nodes), and their initial allowed states are saved
    in the get_personalized_feature_name(feature, ALLOWED_STATES + '.initial') property.

    :param tree: ete3.Tree, the tree of interest
    :param feature: str, character for which the allowed states are altered
    :return: list of the nodes whose allowed states have been altered.
    """
    allowed_state_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    allowed_state_feature_unaltered = get_personalized_feature_name(feature, ALLOWED_STATES + '.initial')

    altered_nodes = []

    for cluster in get_zero_clusters_with_states(tree, feature):
        current_arrays = [node.props.get(allowed_state_feature) for node in cluster]

        # Count in how many cluster nodes each state is allowed:
        # if some state is allowed in all of them there is nothing to do
        counts = current_arrays[0].copy()
        for array in current_arrays[1:]:
            counts += array
        if counts.max() == len(cluster):
            continue

        # Otherwise set the states of all the cluster nodes to the union of their states
        initial_arrays = [array.copy() for array in current_arrays]
        state_union = initial_arrays[0].copy()
        for initial_array in initial_arrays[1:]:
            state_union[np.nonzero(initial_array)] = 1
        for node, initial_array in zip(cluster, initial_arrays):
            node.add_prop(allowed_state_feature, state_union)
            node.add_prop(allowed_state_feature_unaltered, initial_array)
            altered_nodes.append(node)
    return altered_nodes

394 

395 

def unalter_zero_node_allowed_states(altered_nodes, feature):
    """
    Restricts the allowed state arrays of the previously altered zero-distance nodes
    back to their initial states.

    The new allowed states of a node are the intersection of its current and initial allowed states,
    or the initial allowed states if this intersection is empty.

    :param altered_nodes: list of modified nodes
    :param feature: str, character for which the allowed states were altered
    :return: void, modifies the get_personalized_feature_name(feature, ALLOWED_STATES) property of the given nodes.
    """
    allowed_state_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    allowed_state_feature_unaltered = get_personalized_feature_name(feature, ALLOWED_STATES + '.initial')
    for node in altered_nodes:
        initial_allowed_states = node.props.get(allowed_state_feature_unaltered)
        restored_states = node.props.get(allowed_state_feature) & initial_allowed_states
        if not np.any(restored_states > 0):
            restored_states = initial_allowed_states
        node.add_prop(allowed_state_feature, restored_states)

412 

413 

def unalter_zero_node_joint_states(altered_nodes, feature):
    """
    Corrects the joint states of the previously altered zero-distance nodes
    so that they only point to the states that were initially allowed for these nodes.

    :param altered_nodes: list of modified nodes
    :param feature: str, character for which the allowed states were altered
    :return: void, modifies the get_personalized_feature_name(feature, BU_LH_JOINT_STATES) property
        of the given nodes.
    """
    lh_joint_state_feature = get_personalized_feature_name(feature, BU_LH_JOINT_STATES)
    allowed_state_feature_unaltered = get_personalized_feature_name(feature, ALLOWED_STATES + '.initial')
    for node in altered_nodes:
        initial_allowed_states = node.props.get(allowed_state_feature_unaltered)
        # Index of the first initially allowed state
        allowed_index = np.argmax(initial_allowed_states)
        n_initially_allowed = len(initial_allowed_states[initial_allowed_states > 0])
        if n_initially_allowed == 1:
            # The only initially allowed state is the joint state whatever the parent state is
            node.add_prop(lh_joint_state_feature, np.ones(len(initial_allowed_states), int) * allowed_index)
            continue
        # Otherwise replace (in place) only those joint states that were not initially allowed
        joint_states = node.props.get(lh_joint_state_feature)
        for parent_state_index in range(len(initial_allowed_states)):
            chosen_state_index = joint_states[parent_state_index]
            if not initial_allowed_states[chosen_state_index]:
                joint_states[parent_state_index] = allowed_index

435 

436 

def calculate_marginal_likelihoods(tree, feature, frequencies, clean_up=True):
    """
    Calculates marginal likelihoods for each tree node
    by multiplying state frequencies with their bottom-up and top-down likelihoods.

    :param tree: ete3.Tree, the tree of interest
    :param feature: str, character for which the likelihood is calculated
    :param frequencies: numpy array of state frequencies
    :param clean_up: bool, whether the bottom-up and top-down likelihood properties (and their scaling factors)
        should be removed from the nodes once the marginal likelihoods are calculated
    :return: void, stores the node marginal likelihoods in the get_personalized_feature_name(feature, LH) property.
    """
    # The suffix order corresponds to the positional arguments of calc_node_marginal_likelihood.
    feature_names = [get_personalized_feature_name(feature, suffix)
                     for suffix in (LH, LH_SF, BU_LH, BU_LH_SF, TD_LH, TD_LH_SF, ALLOWED_STATES)]

    for node in tree.traverse('preorder'):
        calc_node_marginal_likelihood(node, *feature_names, frequencies, clean_up)

458 

459 

def calc_node_marginal_likelihood(node, lh_feature, lh_sf_feature, bu_lh_feature, bu_lh_sf_feature, td_lh_feature,
                                  td_lh_sf_feature, allowed_state_feature, frequencies, clean_up):
    """
    Computes the marginal likelihood array of a node as the product of its bottom-up likelihoods,
    top-down likelihoods, and state frequencies restricted to the allowed states,
    and stores it (together with the corresponding log10 scaling factor) in the node properties.

    :param node: node of interest, with bottom-up and top-down likelihoods pre-calculated
    :param lh_feature: name of the property in which the marginal likelihood array is stored
    :param lh_sf_feature: name of the property in which the marginal likelihood scaling factor is stored
    :param bu_lh_feature: name of the property containing the bottom-up likelihood array
    :param bu_lh_sf_feature: name of the property containing the bottom-up scaling factor
    :param td_lh_feature: name of the property containing the top-down likelihood array
    :param td_lh_sf_feature: name of the property containing the top-down scaling factor
    :param allowed_state_feature: name of the property containing the allowed state array
    :param frequencies: numpy array of state frequencies
    :param clean_up: whether the bottom-up and top-down properties should be removed from the node afterwards
    :return: void, modifies the node properties.
    """
    bu_log = np.log10(node.props.get(bu_lh_feature))
    td_log = np.log10(node.props.get(td_lh_feature))
    frequency_log = np.log10(frequencies * node.props.get(allowed_state_feature))
    loglikelihood = bu_log + td_log + frequency_log

    scaling_factor = rescale_log(loglikelihood)
    node.add_prop(lh_feature, np.power(10, loglikelihood))
    node.add_prop(lh_sf_feature,
                  scaling_factor + node.props.get(td_lh_sf_feature) + node.props.get(bu_lh_sf_feature))
    if clean_up:
        for obsolete_feature in (bu_lh_feature, bu_lh_sf_feature, td_lh_feature, td_lh_sf_feature):
            node.del_prop(obsolete_feature)

472 

473 

def check_marginal_likelihoods(tree, feature):
    """
    Sanity check: the total (summed over the states) marginal log likelihood
    must be the same (up to 2 decimal places) for each node and its parent.

    :param tree: ete3.Tree, the tree of interest (with marginal likelihoods pre-calculated)
    :param feature: str, character for which the likelihood was calculated
    :return: void, fails with an AssertionError if the check does not pass.
    """
    lh_feature = get_personalized_feature_name(feature, LH)
    lh_sf_feature = get_personalized_feature_name(feature, LH_SF)

    def total_loglh(n):
        """log10 of the summed node likelihoods, with the scaling factor removed."""
        return np.log10(n.props.get(lh_feature).sum()) - n.props.get(lh_sf_feature)

    for node in tree.traverse():
        if node.is_root:
            continue
        node_loglh = total_loglh(node)
        parent_loglh = total_loglh(node.up)
        assert (round(node_loglh, 2) == round(parent_loglh, 2))

490 

491 

def convert_likelihoods_to_probabilities(tree, feature, states):
    """
    Normalizes the marginal likelihoods of each node (so that they sum up to one)
    to convert them to marginal probabilities.

    :param tree: ete3.Tree, the tree of interest
    :param feature: str, character for which the probabilities are calculated
    :param states: numpy array of states in the order corresponding to the marginal likelihood arrays
    :return: pandas DataFrame, that maps node names (index) to their marginal probabilities (one column per state).
    """
    lh_feature = get_personalized_feature_name(feature, LH)

    def to_probabilities(likelihoods):
        return likelihoods / likelihoods.sum()

    name2probs = {node.name: to_probabilities(node.props.get(lh_feature)) for node in tree.traverse()}
    return pd.DataFrame.from_dict(name2probs, orient='index', columns=states)

509 

510 

def choose_ancestral_states_mppa(tree, feature, states, force_joint=True):
    """
    Chooses node ancestral states based on their marginal probabilities using MPPA method.

    :param force_joint: make sure that Joint state is chosen even if it has a low probability.
    :type force_joint: bool
    :param tree: tree of interest
    :type tree: ete3.Tree
    :param feature: character for which the ancestral states are to be chosen
    :type feature: str
    :param states: possible character states in order corresponding to the probabilities array
    :type states: numpy.array
    :return: a tuple of (number of ancestral scenarios selected,
        calculated by multiplying the number of selected states for all nodes;
        number of unresolved nodes, i.e. those with more than one state selected;
        total number of states selected over all the nodes).
        Also modifies the get_personalized_feature_name(feature, ALLOWED_STATES) property of each node
        to only contain the selected states.
    :rtype: tuple(int, int, int)
    """
    lh_feature = get_personalized_feature_name(feature, LH)
    allowed_state_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    initial_allowed_state_feature = allowed_state_feature + '.initial'
    joint_state_feature = get_personalized_feature_name(feature, JOINT_STATE)

    n = len(states)
    _, state2array = get_state2allowed_states(states, False)

    num_scenarios = 1
    unresolved_nodes = 0
    num_states = 0

    # If force_joint == True,
    # we make sure that the joint state is always chosen,
    # for this we sort the marginal probabilities array as [lowest_non_joint_mp, ..., highest_non_joint_mp, joint_mp]
    # select k in 1:n such as the correction between choosing 0, 0, ..., 1/k, ..., 1/k and our sorted array is min
    # and return the corresponding states
    for node in tree.traverse():
        marginal_likelihoods = node.props.get(lh_feature)
        if initial_allowed_state_feature in node.props:
            # Zero out (in place) the likelihoods of the states that were not initially allowed for this node.
            marginal_likelihoods *= node.props.get(initial_allowed_state_feature)
        marginal_probs = marginal_likelihoods / marginal_likelihoods.sum()
        if force_joint:
            joint_index = node.props.get(joint_state_feature)
            joint_prob = marginal_probs[joint_index]
            marginal_probs = np.hstack((np.sort(np.delete(marginal_probs, joint_index)), [joint_prob]))
        else:
            marginal_probs = np.sort(marginal_probs)

        best_k = n
        best_correction = np.inf
        for k in range(1, n + 1):
            correction = np.hstack((np.zeros(n - k), np.ones(k) / k)) - marginal_probs
            correction = correction.dot(correction)
            if correction < best_correction:
                best_correction = correction
                best_k = k

        num_scenarios *= best_k
        num_states += best_k
        if force_joint:
            # The joint state goes first (so that it is always selected),
            # the other states follow by decreasing likelihood.
            # The state index (not the number of states) must be compared with the joint index here,
            # otherwise the joint state is never prioritised.
            indices_selected = sorted(range(n),
                                      key=lambda i: (0 if i == joint_index else 1, -marginal_likelihoods[i]))[:best_k]
        else:
            indices_selected = sorted(range(n), key=lambda i: -marginal_likelihoods[i])[:best_k]
        if best_k == 1:
            allowed_states = state2array[indices_selected[0]]
        else:
            allowed_states = np.zeros(len(states), dtype=int)
            allowed_states[indices_selected] = 1
            unresolved_nodes += 1
        node.add_prop(allowed_state_feature, allowed_states)

    return num_scenarios, unresolved_nodes, num_states

582 

583 

def choose_ancestral_states_map(tree, feature, states):
    """
    Chooses node ancestral states based on their marginal likelihoods using MAP method,
    i.e. for each node the state with the highest marginal likelihood is selected.

    :param tree: ete3.Tree, the tree of interest
    :param feature: str, character for which the ancestral states are to be chosen
    :param states: numpy.array of possible character states in order corresponding to the likelihood arrays
    :return: void, modifies the get_personalized_feature_name(feature, ALLOWED_STATES) property of each node
        to only contain the selected state.
    """
    lh_feature = get_personalized_feature_name(feature, LH)
    allowed_state_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    initial_allowed_state_feature = allowed_state_feature + '.initial'
    _, index2array = get_state2allowed_states(states, False)

    for node in tree.traverse():
        marginal_likelihoods = node.props.get(lh_feature)
        if initial_allowed_state_feature in node.props.keys():
            # Zero out (in place) the likelihoods of the states that were not initially allowed for this node.
            marginal_likelihoods *= node.props.get(initial_allowed_state_feature)
        best_state_index = marginal_likelihoods.argmax()
        node.add_prop(allowed_state_feature, index2array[best_state_index])

604 

605 

def choose_ancestral_states_joint(tree, feature, states, frequencies):
    """
    Chooses node ancestral states using the joint method:
    the root gets the state maximising its frequency-weighted bottom-up likelihood,
    and each other node gets the state that was recorded (during the joint bottom-up likelihood calculation)
    as the best one given the state chosen for its parent.

    The tree is processed with an explicit stack instead of recursion,
    so that deep trees do not exceed the interpreter recursion limit.

    :param frequencies: numpy array of state frequencies
    :param tree: ete3.Tree, the tree of interest
    :param feature: str, character for which the ancestral states are to be chosen
    :param states: numpy.array of possible character states in order corresponding to the likelihood arrays
    :return: void, modifies the get_personalized_feature_name(feature, ALLOWED_STATES) property of each node
        to only contain the selected state, and stores the selected state index
        in the get_personalized_feature_name(feature, JOINT_STATE) property.
    """
    lh_feature = get_personalized_feature_name(feature, BU_LH)
    lh_state_feature = get_personalized_feature_name(feature, BU_LH_JOINT_STATES)
    allowed_state_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    joint_state_feature = get_personalized_feature_name(feature, JOINT_STATE)
    _, state2array = get_state2allowed_states(states, False)

    root_state_index = (tree.props.get(lh_feature) * frequencies).argmax()
    # Pairs of (node, index of the state chosen for it) that are still to be processed.
    todo = [(tree, root_state_index)]
    while todo:
        node, state_index = todo.pop()
        node.add_prop(joint_state_feature, state_index)
        node.add_prop(allowed_state_feature, state2array[state_index])
        for child in node.children:
            # The child state that is consistent with the state chosen for its parent.
            todo.append((child, child.props.get(lh_state_feature)[state_index]))

631 

632 

def get_state2allowed_states(states, by_name=True):
    """
    Builds the allowed state arrays corresponding to each of the states.

    The arrays are meant to be shared (e.g. between tips) rather than modified.

    :param states: ordered collection of states
    :param by_name: whether the arrays should be keyed by the state names (True) or by the state indices (False).
        In the former case the None and '' keys are also added, and mapped to the all-ones array.
    :return: tuple of (all-ones array, i.e. any state is allowed;
        dict mapping each state (or its index) to the array with a one in the state's position and zeros elsewhere).
    """
    n = len(states)
    all_ones = np.ones(n, int)

    def single_state_array(index):
        array = np.zeros(n, int)
        array[index] = 1
        return array

    state2array = {(state if by_name else index): single_state_array(index)
                   for index, state in enumerate(states)}
    if by_name:
        # A missing state means that any state is allowed.
        state2array.update({None: all_ones, '': all_ones})
    return all_ones, state2array

646 

647 

648def ml_acr(forest, character, prediction_method, model, observed_frequencies, force_joint=True): 

649 """ 

650 Calculates ML states on the trees and stores them in the corresponding feature. 

651 

652 :param prediction_method: MPPA (marginal approximation), MAP (max a posteriori), JOINT or ML 

653 :type prediction_method: str 

654 :param forest: trees of interest 

655 :type forest: list(ete3.Tree) 

656 :param character: character for which the ML states are reconstructed 

657 :type character: str 

658 :param model: character state change model 

659 :type model: pastml.models.Model 

660 :return: mapping between reconstruction parameters and values 

661 :rtype: dict 

662 """ 

663 logger = logging.getLogger('pastml') 

664 likelihood = \ 

665 optimise_likelihood(forest=forest, character=character, model=model, 

666 observed_frequencies=observed_frequencies) 

667 result = {LOG_LIKELIHOOD: likelihood, CHARACTER: character, METHOD: prediction_method, MODEL: model, 

668 STATES: model.states} 

669 

670 results = [] 

671 

672 def process_reconstructed_states(method): 

673 if method == prediction_method or is_meta_ml(prediction_method): 

674 method_character = get_personalized_feature_name(character, method) \ 

675 if prediction_method != method else character 

676 for tree in forest: 

677 convert_allowed_states2feature(tree, character, model.states, method_character) 

678 res = result.copy() 

679 res[CHARACTER] = method_character 

680 res[METHOD] = method 

681 results.append(res) 

682 

683 def process_restricted_likelihood_and_states(method): 

684 restricted_likelihood = \ 

685 sum(get_bottom_up_loglikelihood(tree=tree, character=character, 

686 is_marginal=True, model=model, alter=True) for tree in forest) 

687 note_restricted_likelihood(method, restricted_likelihood) 

688 process_reconstructed_states(method) 

689 

690 def note_restricted_likelihood(method, restricted_likelihood): 

691 logger.debug('Log likelihood for {} after {} state selection:\t{:.6f}' 

692 .format(character, method, restricted_likelihood)) 

693 result[RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR.format(method)] = restricted_likelihood 

694 

695 if prediction_method != MAP: 

696 # Calculate joint restricted likelihood 

697 restricted_likelihood = \ 

698 sum(get_bottom_up_loglikelihood(tree=tree, character=character, 

699 is_marginal=False, model=model, alter=True) for tree in forest) 

700 note_restricted_likelihood(JOINT, restricted_likelihood) 

701 for tree in forest: 

702 choose_ancestral_states_joint(tree, character, model.states, model.frequencies) 

703 process_reconstructed_states(JOINT) 

704 

705 if is_marginal(prediction_method): 

706 mps = [] 

707 for tree in forest: 

708 initialize_allowed_states(tree, character, model.states) 

709 altered_nodes = [] 

710 if 0 == model.tau: 

711 altered_nodes = alter_zero_node_allowed_states(tree, character) 

712 get_bottom_up_loglikelihood(tree=tree, character=character, is_marginal=True, model=model, alter=False) 

713 calculate_top_down_likelihood(tree, character, model=model) 

714 calculate_marginal_likelihoods(tree, character, model.frequencies) 

715 # check_marginal_likelihoods(tree, character) 

716 mps.append(convert_likelihoods_to_probabilities(tree, character, model.states)) 

717 

718 if altered_nodes: 

719 unalter_zero_node_allowed_states(altered_nodes, character) 

720 choose_ancestral_states_map(tree, character, model.states) 

721 result[MARGINAL_PROBABILITIES] = pd.concat(mps, copy=False) if len(mps) != 1 else mps[0] 

722 process_restricted_likelihood_and_states(MAP) 

723 

724 if MPPA == prediction_method or is_meta_ml(prediction_method): 

725 

726 if ALL == prediction_method: 

727 pars_acr_results = parsimonious_acr(forest, character, MP, model.states, 

728 model.forest_stats.num_nodes, model.forest_stats.num_tips) 

729 results.extend(pars_acr_results) 

730 for pars_acr_res in pars_acr_results: 

731 for tree in forest: 

732 _parsimonious_states2allowed_states(tree, pars_acr_res[CHARACTER], character, model.states) 

733 try: 

734 restricted_likelihood = \ 

735 sum(get_bottom_up_loglikelihood(tree=tree, character=character, is_marginal=True, 

736 model=model, alter=True) 

737 for tree in forest) 

738 note_restricted_likelihood(pars_acr_res[METHOD], restricted_likelihood) 

739 except PastMLLikelihoodError as e: 

740 logger.error('{}\n{} parsimonious state selection is inconsistent in terms of ML.' 

741 .format(e.message, pars_acr_res[METHOD])) 

742 

743 result[NUM_SCENARIOS], result[NUM_UNRESOLVED_NODES], result[NUM_STATES_PER_NODE] = 1, 0, 0 

744 for tree in forest: 

745 ns, nun, nspn = choose_ancestral_states_mppa(tree, character, model.states, force_joint=force_joint) 

746 result[NUM_SCENARIOS] *= ns 

747 result[NUM_UNRESOLVED_NODES] += nun 

748 result[NUM_STATES_PER_NODE] += nspn 

749 result[NUM_STATES_PER_NODE] /= model.forest_stats.num_nodes 

750 result[PERC_UNRESOLVED] = result[NUM_UNRESOLVED_NODES] * 100 / model.forest_stats.num_nodes 

751 logger.debug('{} node{} unresolved ({:.2f}%) for {} by {}, ' 

752 'i.e. {:.4f} state{} per node in average.' 

753 .format(result[NUM_UNRESOLVED_NODES], 's are' if result[NUM_UNRESOLVED_NODES] != 1 else ' is', 

754 result[PERC_UNRESOLVED], character, MPPA, 

755 result[NUM_STATES_PER_NODE], 's' if result[NUM_STATES_PER_NODE] > 1 else '')) 

756 process_restricted_likelihood_and_states(MPPA) 

757 

758 return results 

759 

760 

def marginal_counts(forest, character, model, n_repetitions=1_000):
    """
    Samples state changes along the branches of the forest, based on marginal likelihoods,
    and returns the per-repetition average of parent-state/child-state counts.

    The root state of each tree is drawn n_repetitions times from its marginal probabilities;
    then, in level order, child states are drawn conditionally on the states drawn for the parent.
    The per-node draw counts are stored on the nodes in the STATE_COUNTS feature of the character.

    :param forest: trees of interest
    :type forest: list of trees (nodes are accessed via traverse(), children, dist, props and add_prop)
    :param character: character for which the state changes are counted
    :type character: str
    :param model: evolutionary model; must provide states, frequencies, tau and get_Pij_t
    :type model: Model
    :param n_repetitions: number of state draws performed per node
    :type n_repetitions: int
    :return: array of shape (n_states, n_states), whose entry [i, j] is the accumulated count
        of (parent in state i, child in state j) draws, divided by n_repetitions;
        the diagonal is reduced (see the comment at the end of the parent loop).
    :rtype: numpy.ndarray
    """

    lh_feature = get_personalized_feature_name(character, LH)
    lh_sf_feature = get_personalized_feature_name(character, LH_SF)
    td_lh_feature = get_personalized_feature_name(character, TD_LH)
    td_lh_sf_feature = get_personalized_feature_name(character, TD_LH_SF)
    bu_lh_feature = get_personalized_feature_name(character, BU_LH)
    bu_lh_sf_feature = get_personalized_feature_name(character, BU_LH_SF)

    allowed_state_feature = get_personalized_feature_name(character, ALLOWED_STATES)
    initial_allowed_state_feature = allowed_state_feature + '.initial'
    state_count_feature = get_personalized_feature_name(character, STATE_COUNTS)

    get_pij = model.get_Pij_t
    n_states = len(model.states)
    state_ids = np.array(list(range(n_states)))

    result = np.zeros((n_states, n_states), dtype=float)

    for tree in forest:
        initialize_allowed_states(tree, character, model.states)
        altered_nodes = []
        if 0 == model.tau:
            altered_nodes = alter_zero_node_allowed_states(tree, character)
        get_bottom_up_loglikelihood(tree=tree, character=character, is_marginal=True, model=model, alter=False)
        calculate_top_down_likelihood(tree, character, model=model)

        # Level order guarantees that a node's state counts are set (as a child)
        # before the node is visited as a parent.
        for parent in tree.traverse('levelorder'):
            if parent.is_root:
                calc_node_marginal_likelihood(parent, lh_feature, lh_sf_feature, bu_lh_feature, bu_lh_sf_feature,
                                              td_lh_feature, td_lh_sf_feature, allowed_state_feature,
                                              model.frequencies, False)
                marginal_likelihoods = parent.props.get(lh_feature)
                marginal_probs = marginal_likelihoods / marginal_likelihoods.sum()
                # draw random states according to marginal probabilities (n_repetitions times)
                drawn_state_nums = Counter(np.random.choice(state_ids, size=n_repetitions, p=marginal_probs))
                parent_state_counts = np.array([drawn_state_nums[_] for _ in state_ids])
                parent.add_prop(state_count_feature, parent_state_counts)
            else:
                # Counts stored when this node was processed as a child of its own parent.
                parent_state_counts = parent.props.get(state_count_feature)

            if parent in altered_nodes:
                # Keep only the draws compatible with the initially allowed states and rescale
                # them to n_repetitions; if none is compatible, spread the repetitions evenly
                # over the initially allowed states.
                initial_allowed_states = parent.props.get(initial_allowed_state_feature)
                ps_counts_initial = parent_state_counts * initial_allowed_states
                if np.count_nonzero(ps_counts_initial):
                    ps_counts_initial = n_repetitions * ps_counts_initial / ps_counts_initial.sum()
                else:
                    ps_counts_initial = n_repetitions * initial_allowed_states / initial_allowed_states.sum()
            else:
                ps_counts_initial = parent_state_counts

            # For each state, how many child draws (summed over the children) kept the parent's state.
            same_state_counts = np.zeros(n_states)

            for node in parent.children:
                node_pjis = np.transpose(get_pij(node.dist))
                marginal_loglikelihood = np.log10(node.props.get(bu_lh_feature)) + np.log10(node_pjis) \
                                         + np.log10(model.frequencies * node.props.get(allowed_state_feature))
                rescale_log(marginal_loglikelihood)
                marginal_likelihood = np.power(10, marginal_loglikelihood)
                # Normalise each row, so that row j can be used as the child state distribution
                # for the draws in which the parent is in state j.
                marginal_probs = marginal_likelihood / marginal_likelihood.sum(axis=1)[:, np.newaxis]
                state_counts = np.zeros(n_states, dtype=int)

                # Direct counting is only used when neither end of the branch was altered.
                update_results = parent not in altered_nodes and node not in altered_nodes

                for j in range(n_states):
                    parent_count_j = parent_state_counts[j]
                    if parent_count_j:
                        drawn_state_nums_j = \
                            Counter(np.random.choice(state_ids, size=parent_count_j, p=marginal_probs[j, :]))
                        state_counts_j = np.array([drawn_state_nums_j[_] for _ in state_ids])
                        state_counts += state_counts_j
                        if update_results:
                            result[j, :] += state_counts_j
                            same_state_counts[j] += state_counts_j[j]

                if node in altered_nodes:
                    # Same restriction to the initially allowed states as for an altered parent.
                    initial_allowed_states = node.props.get(initial_allowed_state_feature)
                    counts_initial = state_counts * initial_allowed_states
                    if np.count_nonzero(counts_initial):
                        counts_initial = n_repetitions * counts_initial / counts_initial.sum()
                    else:
                        counts_initial = n_repetitions * initial_allowed_states / initial_allowed_states.sum()
                else:
                    counts_initial = state_counts

                if not update_results:
                    # An altered node is involved: distribute the (adjusted) parent counts
                    # over the child states proportionally to the (adjusted) child counts.
                    norm_counts = counts_initial / counts_initial.sum()
                    for i in np.argwhere(ps_counts_initial > 0):
                        child_counts_adjusted = norm_counts * ps_counts_initial[i]
                        result[i, :] += child_counts_adjusted
                        same_state_counts[i] += child_counts_adjusted[i]

                node.add_prop(state_count_feature, state_counts)

            # Reduce the diagonal by the smaller of the parent's count in state i
            # and the total same-state count accumulated over its children.
            for i in range(n_states):
                result[i, i] -= np.minimum(ps_counts_initial[i], same_state_counts[i])

    return result / n_repetitions

872 

873 

def optimise_likelihood(forest, character, model, observed_frequencies):
    """
    Optimises the model parameters for the given character on the forest
    and returns the resulting log likelihood.

    The allowed states are (re)initialised on every tree and the bottom-up marginal log likelihood
    is calculated with the current parameter values. If the model reports no parameters to optimise
    (model.get_num_params() is falsy), this value is returned as is. Otherwise:
    (1) unless the basic parameters are all fixed, they are pre-optimised with the extra parameters fixed;
    (2) unless the extra parameters are all fixed, the parameters are optimised once more.

    :param forest: trees of interest
    :type forest: list of trees
    :param character: character for which the likelihood is optimised
    :type character: str
    :param model: evolutionary model whose parameters are optimised
    :type model: Model
    :param observed_frequencies: passed through to optimize_likelihood_params
    :return: log likelihood obtained with the final parameter values
    :raises PastMLLikelihoodError: if the initial likelihood is NaN,
        or if an optimisation step yields NaN or minus infinity.
    """

    def describe(printer):
        # Works whether the model exposes the description as a method or as a ready-made value;
        # without the call a method would be logged as a bound-method repr.
        return printer() if callable(printer) else printer

    def is_invalid(value):
        # Element-wise check, so that it is safe for both scalars and arrays.
        return np.any(np.isnan(value) | np.isneginf(value))

    def likelihood_error(action):
        return PastMLLikelihoodError('Failed to {} the likelihood for your tree, '.format(action)
                                     + 'please check that you do not have contradicting {} states specified '
                                       'for internal tree nodes, '
                                       'and if not - submit a bug at https://github.com/evolbioinfo/pastml/issues'
                                     .format(character))

    for tree in forest:
        initialize_allowed_states(tree, character, model.states)
    logger = logging.getLogger('pastml')
    likelihood = sum(get_bottom_up_loglikelihood(tree=tree, character=character,
                                                 is_marginal=True, model=model, alter=True)
                     for tree in forest)
    if np.isnan(likelihood):
        raise likelihood_error('calculate')

    if not model.get_num_params():
        logger.debug('All the parameters are fixed for {}:\n{}{}.'
                     .format(character,
                             describe(model._print_parameters),
                             '\tlog likelihood:\t{:.6f}'.format(likelihood)))
        return likelihood

    logger.debug('Initial values for {} parameter optimisation:\n{}{}.'
                 .format(character,
                         describe(model._print_parameters),
                         '\tlog likelihood:\t{:.6f}'.format(likelihood)))

    if not model.basic_params_fixed():
        # Pre-optimise the basic parameters while the extra ones are kept fixed.
        model.fix_extra_params()
        likelihood = optimize_likelihood_params(forest=forest, character=character, model=model,
                                                observed_frequencies=observed_frequencies)
        if is_invalid(likelihood):
            raise likelihood_error('optimise')
        model.unfix_extra_params()
        if not model.extra_params_fixed():
            logger.debug('Pre-optimised basic parameters for {}:\n{}{}.'
                         .format(character,
                                 describe(model._print_basic_parameters),
                                 '\tlog likelihood:\t{:.6f}'.format(likelihood)))

    if not model.extra_params_fixed():
        likelihood = optimize_likelihood_params(forest=forest, character=character, model=model,
                                                observed_frequencies=observed_frequencies)
        if is_invalid(likelihood):
            raise likelihood_error('calculate')

    logger.debug('Optimised parameters for {}:\n{}{}'
                 .format(character,
                         describe(model._print_parameters),
                         '\tlog likelihood:\t{:.6f}'.format(likelihood)))
    return likelihood

930 

931 

def convert_allowed_states2feature(tree, feature, states, out_feature=None):
    """
    Stores on every node of the tree the set of states selected by its allowed-state mask.

    :param tree: tree whose nodes are annotated
    :param feature: character whose ALLOWED_STATES mask is read from the nodes
    :type feature: str
    :param states: array of possible states, indexed with the boolean mask
    :type states: numpy.array
    :param out_feature: name of the node property to write to (defaults to feature)
    :type out_feature: str
    :return: void, the nodes are modified in place
    """
    target_feature = feature if out_feature is None else out_feature
    mask_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    for node in tree.traverse():
        selected = node.props.get(mask_feature).astype(bool)
        node.add_prop(target_feature, set(states[selected]))

938 

939 

def _parsimonious_states2allowed_states(tree, ps_feature, feature, states):
    """
    Rewrites the allowed-state mask of every node from the states listed in another node property.

    :param tree: tree whose nodes are updated
    :param ps_feature: name of the node property holding an iterable of states
    :type ps_feature: str
    :param feature: character whose ALLOWED_STATES mask is overwritten
    :type feature: str
    :param states: possible states; the position of a state defines its index in the mask
    :return: void, the nodes are modified in place
    """
    n_states = len(states)
    index_of = {state: position for position, state in enumerate(states)}
    mask_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    for node in tree.traverse():
        listed_states = node.props.get(ps_feature)
        # 0/1 mask with a 1 at the index of each listed state
        mask = np.zeros(n_states, dtype=int)
        for state in listed_states:
            mask[index_of[state]] = 1
        node.add_prop(mask_feature, mask)

950 

951 

class PastMLLikelihoodError(Exception):
    """
    Exception signalling that a likelihood could not be calculated or optimised.
    """

    def __init__(self, *args):
        # Only the first positional argument (if any) is kept as the message.
        self.message = None
        if args:
            self.message = args[0]

    def __str__(self):
        # A missing or empty message falls back to the generic text.
        if not self.message:
            return 'PastMLLikelihoodError has been raised.'
        return 'PastMLLikelihoodError, {}'.format(self.message)