Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/ml.py: 12%
510 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
1import logging
2from collections import Counter
4import numpy as np
5import pandas as pd
6from scipy.optimize import minimize
8from pastml import get_personalized_feature_name, CHARACTER, METHOD, NUM_SCENARIOS, NUM_UNRESOLVED_NODES, \
9 NUM_STATES_PER_NODE, PERC_UNRESOLVED, STATES
10from pastml.models import ModelWithFrequencies
11from pastml.parsimony import parsimonious_acr, MP
# Key under which the optimised log likelihood is stored in the ACR result dictionary.
LOG_LIKELIHOOD = 'log_likelihood'
# Template for the per-method restricted likelihood keys: it evaluates to
# 'log_likelihood_restricted_{}', whose remaining placeholder is filled with the method name.
RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR = '{}_restricted_{{}}'.format(LOG_LIKELIHOOD)

# Names of the prediction methods.
JOINT = 'JOINT'
MPPA = 'MPPA'
MAP = 'MAP'
ALL = 'ALL'
ML = 'ML'

# Key under which the marginal probability table is stored in the ACR result dictionary.
MARGINAL_PROBABILITIES = 'marginal_probabilities'

# Key under which the model is stored in the ACR result dictionary.
MODEL = 'model'

# log10 of the float64 machine epsilon and of the largest float64 value:
# the limits used by rescale_log to decide when a log10-likelihood array must be shifted.
MIN_VALUE = np.log10(np.finfo(np.float64).eps)
MAX_VALUE = np.log10(np.finfo(np.float64).max)

# Method groups: marginal methods, all concrete ML methods, and meta-methods combining several of them.
MARGINAL_ML_METHODS = {MPPA, MAP}
ML_METHODS = MARGINAL_ML_METHODS | {JOINT}
META_ML_METHODS = {ML, ALL}

# Suffixes combined with a character name (via get_personalized_feature_name)
# to build the names of the per-node properties used in this module.
BU_LH = 'BOTTOM_UP_LIKELIHOOD'
TD_LH = 'TOP_DOWN_LIKELIHOOD'
LH = 'LIKELIHOOD'
LH_SF = 'LIKELIHOOD_SF'
BU_LH_SF = 'BOTTOM_UP_LIKELIHOOD_SF'
BU_LH_JOINT_STATES = 'BOTTOM_UP_LIKELIHOOD_JOINT_STATES'
# NOTE(review): the value spells 'DOWM' (sic). It is a runtime property key, so it is left untouched here;
# confirm that nothing outside this module relies on it before correcting the spelling.
TD_LH_SF = 'TOP_DOWM_LIKELIHOOD_SF'
ALLOWED_STATES = 'ALLOWED_STATES'
STATE_COUNTS = 'STATE_COUNTS'
JOINT_STATE = 'JOINT_STATE'
def is_marginal(method):
    """
    Tells whether the prediction method relies on marginal probabilities.

    This is the case for MAP and MPPA, and for the meta-methods (ML, ALL).

    :param method: prediction method
    :type method: str
    :return: True if the method is a marginal one or a meta-method, False otherwise
    :rtype: bool
    """
    return method in (MARGINAL_ML_METHODS | META_ML_METHODS)
def is_ml(method):
    """
    Tells whether the prediction method is a maximum likelihood one.

    This is the case for JOINT, for the marginal methods (MAP, MPPA), and for the meta-methods (ML, ALL).

    :param method: prediction method
    :type method: str
    :return: True if the method is a maximum likelihood one, False otherwise
    :rtype: bool
    """
    return method in (ML_METHODS | META_ML_METHODS)
def is_meta_ml(method):
    """
    Tells whether the prediction method is a meta maximum likelihood method,
    i.e. one that combines several methods (ML or ALL).

    :param method: prediction method
    :type method: str
    :return: True if the method is ML or ALL, False otherwise
    :rtype: bool
    """
    combines_several_methods = method in META_ML_METHODS
    return combines_several_methods
def get_default_ml_method():
    """
    Returns the name of the maximum likelihood prediction method used by default.

    :return: the default method name (MPPA)
    :rtype: str
    """
    default_method = MPPA
    return default_method
def get_bottom_up_loglikelihood(tree, character, model, is_marginal=True, alter=True):
    """
    Computes the bottom-up log likelihood of the given tree.

    While traversing the tree in postorder, the (rescaled) likelihood array of every node is stored
    in the node property named get_personalized_feature_name(character, BU_LH),
    and the accumulated log10 scaling factor in the BU_LH_SF one.

    :param tree: tree of interest (its nodes must expose ``props`` / ``add_prop``)
    :param character: character for which the likelihood is calculated
    :type character: str
    :param model: model of character evolution (provides ``tau``, ``frequencies`` and ``get_Pij_t``)
    :type model: pastml.models.Model
    :param is_marginal: whether the likelihood is marginal (True: sum over states)
        or joint (False: max over states)
    :type is_marginal: bool
    :param alter: whether the allowed states of zero-distance node clusters should be temporarily altered
        (only done when model.tau is zero) and restored once the likelihood is calculated
    :type alter: bool
    :return: natural-log likelihood of the tree
    :rtype: float
    """

    def feature_name(suffix):
        return get_personalized_feature_name(character, suffix)

    altered_nodes = []
    if 0 == model.tau and alter:
        altered_nodes = alter_zero_node_allowed_states(tree, character)

    scaling_key = feature_name(BU_LH_SF)
    likelihood_key = feature_name(BU_LH)
    joint_states_key = feature_name(BU_LH_JOINT_STATES)
    allowed_states_key = feature_name(ALLOWED_STATES)

    transition_probabilities = model.get_Pij_t

    # Children must be processed before their parents, hence the postorder traversal.
    for current in tree.traverse('postorder'):
        calc_node_bu_likelihood(current, allowed_states_key, likelihood_key, scaling_key, joint_states_key,
                                is_marginal, transition_probabilities)

    weighted_root_likelihoods = tree.props.get(likelihood_key) * model.frequencies
    if is_marginal:
        root_likelihood = weighted_root_likelihoods.sum()
    else:
        root_likelihood = weighted_root_likelihoods.max()

    if altered_nodes:
        restore = unalter_zero_node_allowed_states if is_marginal else unalter_zero_node_joint_states
        restore(altered_nodes, character)

    # The scaling factor is kept in log10 units: dividing it by log10(e) converts it to natural-log units.
    return np.log(root_likelihood) - tree.props.get(scaling_key) / np.log10(np.e)
def calc_node_bu_likelihood(node, allowed_state_feature, lh_feature, lh_sf_feature, lh_joint_state_feature, is_marginal,
                            get_pij):
    """
    Computes the bottom-up likelihood array of one node from the arrays already stored on its children.

    The result is stored on the node: the rescaled likelihood array under lh_feature,
    and the log10 scaling factor (the node's own plus those of its children) under lh_sf_feature.
    In the joint mode the per-parent-state best child states are stored on each child
    under lh_joint_state_feature.

    :param node: node to process (its children must already carry lh_feature and lh_sf_feature)
    :param allowed_state_feature: name of the property with the node's allowed state array
    :param lh_feature: name of the property with the bottom-up likelihood array
    :param lh_sf_feature: name of the property with the bottom-up scaling factor
    :param lh_joint_state_feature: name of the property with the joint child states
    :param is_marginal: True to sum over child states, False to take the max (joint reconstruction)
    :param get_pij: function returning the state transition probability matrix for a branch length
    :raise PastMLLikelihoodError: if the likelihood becomes zero for all the states
    :return: void
    """
    # NOTE(review): PastMLLikelihoodError is not among the imports visible at the top of this file;
    # confirm it is in scope, otherwise the raise below would fail with a NameError.
    allowed_states = node.props.get(allowed_state_feature)
    # Forbidden states (zeros) get a log10-likelihood of -inf.
    log_lh = np.log10(np.ones(len(allowed_states), dtype=np.float64) * allowed_states)
    own_scaling = 0
    for child in node.children:
        transitions = get_pij(child.dist) * child.props.get(lh_feature)
        if is_marginal:
            per_state = transitions.sum(axis=1)
        else:
            child.add_prop(lh_joint_state_feature, transitions.argmax(axis=1))
            per_state = transitions.max(axis=1)
        # Negative values are clipped to zero before taking the logarithm.
        log_lh += np.log10(np.maximum(per_state, 0))
        if np.isneginf(log_lh).all():
            raise PastMLLikelihoodError("The parent node {} and its child node {} have non-intersecting states, "
                                        "and are connected by a zero-length ({:g}) branch. "
                                        "This creates a zero likelihood value. "
                                        "To avoid this issue check the restrictions on these node states "
                                        "and/or use a smoothing factor (tau)."
                                        .format(node.name, child.name, child.dist))
        own_scaling += rescale_log(log_lh)
    node.add_prop(lh_feature, np.power(10, log_lh))
    children_scaling = sum(child.props.get(lh_sf_feature) for child in node.children)
    node.add_prop(lh_sf_feature, own_scaling + children_scaling)
def rescale_log(loglikelihood_array):
    """
    Shifts the log10-likelihood array (in place) if its finite values get too small/large,
    which corresponds to multiplying the likelihoods by a power of 10.

    Entries equal to -inf (zero likelihoods) are ignored when choosing the shift.

    :param loglikelihood_array: numpy array containing the log10-likelihoods to be rescaled (modified in place)
    :return: float, the exponent of 10 by which the likelihood array has been multiplied (0 if not rescaled).
    """
    upper, lower = MAX_VALUE, MIN_VALUE

    finite_values = loglikelihood_array[loglikelihood_array > -np.inf]
    smallest = np.min(finite_values)
    largest = np.max(finite_values)

    if largest > upper:
        # Too large: shift down so that the largest value ends up just below the upper limit.
        shift = upper - largest - 1
    elif smallest < lower:
        # Too small: shift up, but never so far that the largest value would exceed the upper limit.
        shift = min(lower - smallest + 1, upper - largest - 1)
    else:
        shift = 0
    loglikelihood_array += shift
    return shift
def optimize_likelihood_params(forest, character, observed_frequencies, model):
    """
    Optimizes the model parameters for the given trees by maximizing the marginal bottom-up likelihood.

    The optimisation (L-BFGS-B within model.get_bounds()) is attempted up to 100 times:
    first from the model's current parameters, then (if the model's frequencies are to be optimised)
    from the parameters obtained with the observed frequencies, then from random points within the bounds.
    The first successful attempt that is at least as good as the best starting point is kept.
    If none is found, the model is reset to the better of the two starting points.

    :param forest: trees of interest
    :type forest: list
    :param character: character for which the likelihood is optimised
    :type character: str
    :param observed_frequencies: observed state frequencies (non-positive values are replaced with 1e-10)
    :type observed_frequencies: numpy.array
    :param model: model of character evolution; its parameters are updated in place
    :type model: pastml.models.Model
    :return: the log likelihood corresponding to the parameters the model is left with
    :rtype: float
    """
    bounds = model.get_bounds()

    def negative_log_likelihood(parameters):
        # Objective for the minimiser; as a side effect it sets the model parameters.
        if np.any(pd.isnull(parameters)):
            return np.nan
        model.set_params_from_optimised(parameters)
        total = sum(get_bottom_up_loglikelihood(tree=tree, character=character, is_marginal=True, model=model,
                                                alter=True)
                    for tree in forest)
        return np.inf if pd.isnull(total) else -total

    if np.any(observed_frequencies <= 0):
        observed_frequencies = np.maximum(observed_frequencies, 1e-10)

    start_current = model.get_optimised_parameters()
    optimise_frequencies = isinstance(model, ModelWithFrequencies) and model._optimise_frequencies
    if optimise_frequencies:
        model.frequencies = observed_frequencies
        start_observed = model.get_optimised_parameters()
    else:
        start_observed = start_current

    log_lh_current = -negative_log_likelihood(start_current)
    if optimise_frequencies:
        log_lh_observed = -negative_log_likelihood(start_observed)
    else:
        log_lh_observed = log_lh_current

    best_log_lh = max(log_lh_current, log_lh_observed)

    for attempt in range(100):
        if attempt == 0:
            start = start_current
        elif attempt == 1 and optimise_frequencies:
            start = start_observed
        else:
            start = np.random.uniform(bounds[:, 0], bounds[:, 1])
        fit = minimize(negative_log_likelihood, x0=start, method='L-BFGS-B', bounds=bounds)
        if not fit.success or np.any(np.isnan(fit.x)):
            continue
        if -fit.fun >= best_log_lh:
            model.set_params_from_optimised(fit.x)
            return -fit.fun

    # The objective leaves the model at the last evaluated point: reset it to the best starting point.
    model.set_params_from_optimised(start_current if log_lh_current >= log_lh_observed else start_observed)
    return best_log_lh
def calculate_top_down_likelihood(tree, character, model):
    """
    Computes the top-down likelihood of every node of the given tree.

    The likelihood array of each node is stored in the node property named
    get_personalized_feature_name(character, TD_LH), and its scaling factor in the TD_LH_SF one.

    The top-down likelihood of a node is obtained by imagining the tree re-rooted in this node
    and combining the likelihoods of the "up-subtrees".
    E.g. for a node N1 in a state i, with a parent node P and a brother node N2,
    P becomes a child of N1 and N2 its grandchild, and the bottom-up likelihood is calculated from the P subtree:
    L_top_down(N1, i) = sum_j P(i -> j, dist(N1, P)) * L_top_down(P) * sum_k P(j -> k, dist(N2, P)) * L_bottom_up (N2).

    The top-down likelihood of the root is set to 1 for all the states.

    :param tree: tree of interest (with the bottom-up likelihoods pre-calculated)
    :param character: character whose ancestral state likelihood is being calculated
    :type character: str
    :param model: model of character evolution (provides ``get_Pij_t``)
    :type model: pastml.models.Model
    :return: void, the results are stored on the tree nodes.
    """
    feature_names = [get_personalized_feature_name(character, suffix)
                     for suffix in (TD_LH, TD_LH_SF, BU_LH, BU_LH_SF)]
    transition_probabilities = model.get_Pij_t

    # Parents must be processed before their children, hence the preorder traversal.
    for current in tree.traverse('preorder'):
        calc_node_td_likelihood(current, *feature_names, transition_probabilities)
def calc_node_td_likelihood(node, td_lh_feature, td_lh_sf_feature, bu_lh_feature, bu_lh_sf_feature, get_pij):
    """
    Computes the top-down likelihood array of one node and stores it (with its scaling factor) on the node.

    For the root the array is all ones and the scaling factor is zero.
    For any other node the parent's top-down and bottom-up likelihoods are combined,
    the node's own contribution to the parent's bottom-up likelihood is divided out,
    and the result is propagated along the branch leading to the node.

    :param node: node to process (its parent must already carry the top-down properties)
    :param td_lh_feature: name of the property with the top-down likelihood array
    :param td_lh_sf_feature: name of the property with the top-down scaling factor
    :param bu_lh_feature: name of the property with the bottom-up likelihood array
    :param bu_lh_sf_feature: name of the property with the bottom-up scaling factor
    :param get_pij: function returning the state transition probability matrix for a branch length
    :return: void
    """
    if node.is_root:
        n_states = len(node.props.get(bu_lh_feature))
        node.add_prop(td_lh_feature, np.ones(n_states, np.float64))
        node.add_prop(td_lh_sf_feature, 0)
        return

    parent = node.up
    transposed_pij = np.transpose(get_pij(node.dist))

    # What this node contributed to its parent's bottom-up likelihood;
    # non-positive entries are replaced with 1, so that their log10 is 0 and nothing is divided out for them.
    own_share = node.props.get(bu_lh_feature).dot(transposed_pij)
    own_share[own_share <= 0] = 1

    log_td_parent = np.log10(parent.props.get(td_lh_feature))
    log_bu_parent = np.log10(parent.props.get(bu_lh_feature))
    parent_loglikelihood = log_td_parent + log_bu_parent - np.log10(own_share)

    scaling = parent.props.get(td_lh_sf_feature) \
        + parent.props.get(bu_lh_sf_feature) - node.props.get(bu_lh_sf_feature)
    scaling += rescale_log(parent_loglikelihood)

    td_likelihood = np.power(10, parent_loglikelihood).dot(transposed_pij)
    node.add_prop(td_lh_feature, np.maximum(td_likelihood, 0))
    node.add_prop(td_lh_sf_feature, scaling)
def initialize_allowed_states(tree, feature, states):
    """
    Initializes the allowed state arrays of all the tree nodes based on their states given by the feature.

    A node whose state is missing (property absent, None, empty collection or empty string)
    is allowed to be in any state. Otherwise only its own state(s) are allowed.

    :param tree: tree whose nodes are to be initialized
    :param feature: name of the node property in which the states are stored
        (the value could be missing/None/'' for a missing state, a string for a single state,
        or a collection if multiple states are possible)
    :type feature: str
    :param states: ordered array of states.
    :type states: numpy.array
    :raise KeyError: if a node has a state that is not among the given states
    :return: void, adds the get_personalized_feature_name(feature, ALLOWED_STATES) property to the tree nodes.
    """
    allowed_states_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    n = len(states)
    state2index = {state: index for index, state in enumerate(states)}

    for node in tree.traverse():
        node_states = node.props.get(feature, set())
        if isinstance(node_states, str):
            # A string is a single state rather than a collection of one-letter states.
            # An empty string stands for a missing state (as elsewhere in this module):
            # wrapping it into a set would make it look like a real, unknown state and trigger a KeyError.
            node_states = {node_states} if node_states else set()

        if node_states:
            allowed_states = np.zeros(n, dtype=int)
            allowed_states[[state2index[state] for state in node_states]] = 1
        else:
            allowed_states = np.ones(n, dtype=int)
        node.add_prop(allowed_states_feature, allowed_states)
def get_zero_clusters_with_states(tree, feature):
    """
    Finds the zero-distance clusters of the given tree, i.e. groups of nodes connected by zero-length branches.

    :param tree: the tree of interest (the root node)
    :param feature: name of the node property containing the state
    :return: iterator of lists of nodes that are at zero distance from each other
        and have states specified for them (only clusters with more than one such node are reported).
    """

    def has_state(candidate):
        value = candidate.props.get(feature, None)
        if value is None:
            return False
        return value != ''

    # Nodes that start a new cluster: the root and every node reached via a non-zero branch.
    pending_roots = [tree]

    while pending_roots:
        cluster = []
        frontier = [pending_roots.pop()]

        while frontier:
            current = frontier.pop()
            if has_state(current):
                cluster.append(current)
            for child in current.children:
                # Zero-length branches keep the child in the current cluster.
                (frontier if child.dist == 0 else pending_roots).append(child)

        if len(cluster) > 1:
            yield cluster
def alter_zero_node_allowed_states(tree, feature):
    """
    Alters the allowed state arrays of zero-distance nodes
    to make sure they do not contradict other zero-distance node siblings/ancestors/descendants.

    If the nodes of a zero-distance cluster have no state in common,
    all of them get the union of the cluster's states as their allowed states
    (one array object, shared by all the cluster nodes), while a copy of their initial
    allowed states is saved in the ALLOWED_STATES + '.initial' property.

    :param tree: the tree of interest
    :param feature: str, character for which the allowed states are altered
    :return: list of the altered nodes
    """
    allowed_key = get_personalized_feature_name(feature, ALLOWED_STATES)
    backup_key = get_personalized_feature_name(feature, ALLOWED_STATES + '.initial')

    altered = []

    for cluster in get_zero_clusters_with_states(tree, feature):
        # A state allowed for every cluster member means there is no contradiction: nothing to do.
        times_allowed = np.sum([member.props.get(allowed_key) for member in cluster], axis=0)
        if times_allowed.max() == len(cluster):
            continue

        # Otherwise every member gets the union of the cluster states.
        shared_union = None
        for member in cluster:
            original = member.props.get(allowed_key).copy()
            if shared_union is None:
                shared_union = original.copy()
            else:
                shared_union[original != 0] = 1
            member.add_prop(allowed_key, shared_union)
            member.add_prop(backup_key, original)
            altered.append(member)
    return altered
def unalter_zero_node_allowed_states(altered_nodes, feature):
    """
    Restores the allowed state arrays of the zero-distance nodes altered by alter_zero_node_allowed_states.

    Each node keeps the intersection of its current allowed states with its initial ones;
    if this intersection is empty, the initial allowed states are restored as they were.

    :param altered_nodes: list of modified nodes
    :param feature: str, character for which the allowed states were altered
    :return: void, modifies the get_personalized_feature_name(feature, ALLOWED_STATES) property of the nodes.
    """
    allowed_key = get_personalized_feature_name(feature, ALLOWED_STATES)
    backup_key = get_personalized_feature_name(feature, ALLOWED_STATES + '.initial')
    for node in altered_nodes:
        original = node.props.get(backup_key)
        restored = node.props.get(allowed_key) & original
        if not np.any(restored > 0):
            restored = original
        node.add_prop(allowed_key, restored)
def unalter_zero_node_joint_states(altered_nodes, feature):
    """
    Restores the joint states of the zero-distance nodes altered by alter_zero_node_allowed_states,
    so that they only refer to the states that were initially allowed for these nodes.

    A node with a single initially allowed state gets this state for every parent state.
    Otherwise its existing joint state array is fixed in place:
    entries pointing to initially forbidden states are replaced
    with the first state having the maximal value in the initial allowed state array.

    :param altered_nodes: list of modified nodes
    :param feature: str, character for which the allowed states were altered
    :return: void, modifies the get_personalized_feature_name(feature, BU_LH_JOINT_STATES) property of the nodes.
    """
    joint_key = get_personalized_feature_name(feature, BU_LH_JOINT_STATES)
    backup_key = get_personalized_feature_name(feature, ALLOWED_STATES + '.initial')
    for node in altered_nodes:
        initial_mask = node.props.get(backup_key)
        fallback = np.argmax(initial_mask)
        n_states = len(initial_mask)
        if np.count_nonzero(initial_mask > 0) == 1:
            node.add_prop(joint_key, np.ones(n_states, int) * fallback)
            continue
        joint_states = node.props.get(joint_key)
        for position in range(n_states):
            if not initial_mask[joint_states[position]]:
                joint_states[position] = fallback
def calculate_marginal_likelihoods(tree, feature, frequencies, clean_up=True):
    """
    Calculates the marginal likelihoods of each tree node
    by multiplying state frequencies with their bottom-up and top-down likelihoods.

    :param tree: the tree of interest (with bottom-up and top-down likelihoods pre-calculated)
    :param feature: str, character for which the likelihood is calculated
    :param frequencies: numpy array of state frequencies
    :param clean_up: bool, whether the bottom-up and top-down likelihood properties
        should be removed from the nodes once the marginal likelihoods are calculated
    :return: void, stores the node marginal likelihoods
        in the get_personalized_feature_name(feature, LH) property (and their scaling factors in the LH_SF one).
    """

    def feature_name(suffix):
        return get_personalized_feature_name(feature, suffix)

    bu_key, bu_sf_key = feature_name(BU_LH), feature_name(BU_LH_SF)
    td_key, td_sf_key = feature_name(TD_LH), feature_name(TD_LH_SF)
    lh_key, lh_sf_key = feature_name(LH), feature_name(LH_SF)
    allowed_key = feature_name(ALLOWED_STATES)

    for current in tree.traverse('preorder'):
        calc_node_marginal_likelihood(current, lh_key, lh_sf_key, bu_key, bu_sf_key, td_key,
                                      td_sf_key, allowed_key, frequencies, clean_up)
def calc_node_marginal_likelihood(node, lh_feature, lh_sf_feature, bu_lh_feature, bu_lh_sf_feature, td_lh_feature,
                                  td_lh_sf_feature, allowed_state_feature, frequencies, clean_up):
    """
    Calculates the marginal likelihood array of one node as the product of its bottom-up likelihoods,
    its top-down likelihoods, and the state frequencies restricted to the node's allowed states.

    The (rescaled) array is stored under lh_feature, and the total log10 scaling factor under lh_sf_feature.

    :param node: node to process (must carry the bottom-up, top-down and allowed state properties)
    :param lh_feature: name of the property for the marginal likelihood array
    :param lh_sf_feature: name of the property for the marginal likelihood scaling factor
    :param bu_lh_feature: name of the property with the bottom-up likelihood array
    :param bu_lh_sf_feature: name of the property with the bottom-up scaling factor
    :param td_lh_feature: name of the property with the top-down likelihood array
    :param td_lh_sf_feature: name of the property with the top-down scaling factor
    :param allowed_state_feature: name of the property with the allowed state array
    :param frequencies: numpy array of state frequencies
    :param clean_up: whether the bottom-up and top-down properties should be removed from the node afterwards
    :return: void
    """
    props = node.props
    # The product is calculated in the log10 space, forbidden states getting -inf.
    loglikelihood = np.log10(props.get(bu_lh_feature)) + np.log10(props.get(td_lh_feature)) \
        + np.log10(frequencies * props.get(allowed_state_feature))
    own_scaling = rescale_log(loglikelihood)
    node.add_prop(lh_feature, np.power(10, loglikelihood))
    node.add_prop(lh_sf_feature, own_scaling + props.get(td_lh_sf_feature) + props.get(bu_lh_sf_feature))
    if clean_up:
        for obsolete in (bu_lh_feature, bu_lh_sf_feature, td_lh_feature, td_lh_sf_feature):
            node.del_prop(obsolete)
def check_marginal_likelihoods(tree, feature):
    """
    Sanity check: the total marginal log likelihood (summed over states and corrected for the scaling factor)
    of each node must be the same as that of its parent, up to two decimal places.

    :param tree: the tree of interest (with marginal likelihoods pre-calculated)
    :param feature: str, character for which the likelihood was calculated
    :raise AssertionError: if the log likelihoods of a node and its parent differ
    :return: void
    """
    lh_key = get_personalized_feature_name(feature, LH)
    lh_sf_key = get_personalized_feature_name(feature, LH_SF)

    def total_loglh(some_node):
        return np.log10(some_node.props.get(lh_key).sum()) - some_node.props.get(lh_sf_key)

    for current in tree.traverse():
        if current.is_root:
            continue
        node_loglh = total_loglh(current)
        parent_loglh = total_loglh(current.up)
        assert (round(node_loglh, 2) == round(parent_loglh, 2))
def convert_likelihoods_to_probabilities(tree, feature, states):
    """
    Normalizes the marginal likelihoods of each node (so that they sum up to one)
    to convert them to marginal probabilities.

    :param tree: the tree of interest (with marginal likelihoods pre-calculated)
    :param feature: str, character for which the probabilities are calculated
    :param states: numpy array of states in the order corresponding to the marginal likelihood arrays
    :return: pandas DataFrame, that maps node names (index) to their marginal probabilities (one column per state).
    """
    lh_key = get_personalized_feature_name(feature, LH)
    likelihoods_by_name = ((current.name, current.props.get(lh_key)) for current in tree.traverse())
    name2probs = {name: likelihoods / likelihoods.sum() for name, likelihoods in likelihoods_by_name}
    return pd.DataFrame.from_dict(name2probs, orient='index', columns=states)
def choose_ancestral_states_mppa(tree, feature, states, force_joint=True):
    """
    Chooses node ancestral states based on their marginal probabilities using MPPA method.

    For each node the number k of states to keep is the one minimizing the squared distance
    between the sorted marginal probability array and the array (0, ..., 0, 1/k, ..., 1/k).
    The k most likely states are then kept.
    If force_joint is True, the joint state is always among the kept states
    (it is put last in the sorted array, and selected first).

    Note that the marginal likelihood arrays stored on the nodes that have initial allowed states
    (ALLOWED_STATES + '.initial' property) are multiplied by these initial allowed states in place.

    :param force_joint: make sure that Joint state is chosen even if it has a low probability.
    :type force_joint: bool
    :param tree: tree of interest (with marginal likelihoods, and joint states if force_joint, pre-calculated)
    :param feature: character for which the ancestral states are to be chosen
    :type feature: str
    :param states: possible character states in order corresponding to the probabilities array
    :type states: numpy.array
    :return: (number of ancestral scenarios selected, calculated by multiplying the number of selected states
        for all nodes; number of nodes with more than one selected state;
        total number of selected states over all nodes).
        Also modifies the get_personalized_feature_name(feature, ALLOWED_STATES) property of each node
        to only contain the selected states.
    :rtype: tuple
    """
    lh_feature = get_personalized_feature_name(feature, LH)
    allowed_state_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    initial_allowed_state_feature = allowed_state_feature + '.initial'
    joint_state_feature = get_personalized_feature_name(feature, JOINT_STATE)

    n = len(states)
    _, state2array = get_state2allowed_states(states, False)

    num_scenarios = 1
    unresolved_nodes = 0
    num_states = 0

    for node in tree.traverse():
        marginal_likelihoods = node.props.get(lh_feature)
        if initial_allowed_state_feature in node.props:
            marginal_likelihoods *= node.props.get(initial_allowed_state_feature)
        marginal_probs = marginal_likelihoods / marginal_likelihoods.sum()

        joint_index = None
        if force_joint:
            # Sort as [lowest_non_joint_mp, ..., highest_non_joint_mp, joint_mp],
            # so that the joint state is part of the selection for any k.
            joint_index = node.props.get(joint_state_feature)
            joint_prob = marginal_probs[joint_index]
            marginal_probs = np.hstack((np.sort(np.delete(marginal_probs, joint_index)), [joint_prob]))
        else:
            marginal_probs = np.sort(marginal_probs)

        # Select k in 1:n such that the distance between 0, 0, ..., 1/k, ..., 1/k and the sorted array is min.
        best_k = n
        best_correction = np.inf
        for k in range(1, n + 1):
            correction = np.hstack((np.zeros(n - k), np.ones(k) / k)) - marginal_probs
            correction = correction.dot(correction)
            if correction < best_correction:
                best_correction = correction
                best_k = k

        num_scenarios *= best_k
        num_states += best_k

        if force_joint:
            # The joint state goes first, the other states follow by decreasing likelihood.
            # (The state index, not the number of states, must be compared to the joint index:
            # otherwise the joint state would never be prioritised.)
            indices_selected = sorted(
                range(n),
                key=lambda index: (0 if index == joint_index else 1, -marginal_likelihoods[index]))[:best_k]
        else:
            indices_selected = sorted(range(n), key=lambda index: -marginal_likelihoods[index])[:best_k]

        if best_k == 1:
            allowed_states = state2array[indices_selected[0]]
        else:
            allowed_states = np.zeros(len(states), dtype=int)
            allowed_states[indices_selected] = 1
            unresolved_nodes += 1
        node.add_prop(allowed_state_feature, allowed_states)

    return num_scenarios, unresolved_nodes, num_states
def choose_ancestral_states_map(tree, feature, states):
    """
    Chooses node ancestral states based on their marginal likelihoods using MAP method,
    i.e. keeps the most likely state for each node.

    Note that the marginal likelihood arrays stored on the nodes that have initial allowed states
    (ALLOWED_STATES + '.initial' property) are multiplied by these initial allowed states in place.

    :param tree: the tree of interest (with marginal likelihoods pre-calculated)
    :param feature: str, character for which the ancestral states are to be chosen
    :param states: numpy.array of possible character states in order corresponding to the likelihood arrays
    :return: void, modifies the get_personalized_feature_name(feature, ALLOWED_STATES) property of each node
        to only contain the selected state.
    """
    lh_key = get_personalized_feature_name(feature, LH)
    allowed_key = get_personalized_feature_name(feature, ALLOWED_STATES)
    initial_allowed_key = allowed_key + '.initial'
    single_state_arrays = get_state2allowed_states(states, False)[1]

    for current in tree.traverse():
        likelihoods = current.props.get(lh_key)
        if initial_allowed_key in current.props.keys():
            likelihoods *= current.props.get(initial_allowed_key)
        best_state_index = likelihoods.argmax()
        current.add_prop(allowed_key, single_state_arrays[best_state_index])
def choose_ancestral_states_joint(tree, feature, states, frequencies):
    """
    Chooses node ancestral states using the joint method.

    The root gets the state maximizing the product of its (joint) bottom-up likelihood and the state frequency.
    Every other node gets the state that was recorded as the best one for the state chosen for its parent
    (BU_LH_JOINT_STATES property, filled in by the joint bottom-up likelihood calculation).

    The tree is processed iteratively (with an explicit stack) rather than recursively,
    so that deep trees do not hit the Python recursion limit.

    :param frequencies: numpy array of state frequencies
    :param tree: the tree of interest (with joint bottom-up likelihoods pre-calculated)
    :param feature: str, character for which the ancestral states are to be chosen
    :param states: numpy.array of possible character states in order corresponding to the likelihood arrays
    :return: void, modifies the get_personalized_feature_name(feature, ALLOWED_STATES) property of each node
        to only contain the selected state, and stores the selected state index in the JOINT_STATE property.
    """
    lh_feature = get_personalized_feature_name(feature, BU_LH)
    lh_state_feature = get_personalized_feature_name(feature, BU_LH_JOINT_STATES)
    allowed_state_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    joint_state_feature = get_personalized_feature_name(feature, JOINT_STATE)
    _, state2array = get_state2allowed_states(states, False)

    root_state_index = (tree.props.get(lh_feature) * frequencies).argmax()
    todo = [(tree, root_state_index)]
    while todo:
        node, state_index = todo.pop()
        node.add_prop(joint_state_feature, state_index)
        node.add_prop(allowed_state_feature, state2array[state_index])
        # The state of a child only depends on the state chosen for its parent.
        for child in node.children:
            todo.append((child, child.props.get(lh_state_feature)[state_index]))
def get_state2allowed_states(states, by_name=True):
    """
    Builds the allowed state arrays corresponding to each single state.

    The arrays are meant to be shared between nodes rather than copied, so they should not be modified.

    :param states: ordered collection of states
    :param by_name: if True, the arrays are keyed by state, and the missing states (None and '')
        are mapped to the all-ones array; if False, the arrays are keyed by state index
    :return: (all-ones array, mapping between states (or their indices) and their allowed state arrays)
    :rtype: tuple
    """
    n_states = len(states)
    unrestricted = np.ones(n_states, int)
    identity = np.eye(n_states, dtype=int)
    # Each row is copied, so that every state gets its own independent 1-D array.
    state2array = {(state if by_name else index): identity[index].copy()
                   for index, state in enumerate(states)}
    if by_name:
        for missing in (None, ''):
            state2array[missing] = unrestricted
    return unrestricted, state2array
def ml_acr(forest, character, prediction_method, model, observed_frequencies, force_joint=True):
    """
    Calculates ML states on the trees and stores them in the corresponding feature.

    The model likelihood is optimised first (optimise_likelihood), then, depending on the prediction method,
    the JOINT, MAP and MPPA state selections are performed (in this order),
    each of them overwriting the allowed states of the tree nodes.
    For the ALL method the parsimonious reconstructions are performed as well.

    :param prediction_method: MPPA (marginal approximation), MAP (max a posteriori), JOINT, ML or ALL
    :type prediction_method: str
    :param forest: trees of interest
    :type forest: list
    :param character: character for which the ML states are reconstructed
    :type character: str
    :param model: character state change model
    :type model: pastml.models.Model
    :param observed_frequencies: observed state frequencies, passed on to the likelihood optimisation
    :param force_joint: whether the joint state must always be among the states selected by MPPA
    :type force_joint: bool
    :return: list of mappings between reconstruction parameters and values, one per reported method
    :rtype: list
    """
    logger = logging.getLogger('pastml')
    likelihood = \
        optimise_likelihood(forest=forest, character=character, model=model,
                            observed_frequencies=observed_frequencies)
    # Values shared by all the reported methods; the restricted likelihoods are added to it as they get calculated,
    # hence a per-method result (a copy) only contains those calculated before it was created.
    result = {LOG_LIKELIHOOD: likelihood, CHARACTER: character, METHOD: prediction_method, MODEL: model,
              STATES: model.states}

    results = []

    def process_reconstructed_states(method):
        # Reports the method if it is the requested one, or if a meta-method (reporting everything) was requested.
        if method == prediction_method or is_meta_ml(prediction_method):
            # For meta-methods the states are stored under a method-specific character name.
            method_character = get_personalized_feature_name(character, method) \
                if prediction_method != method else character
            for tree in forest:
                convert_allowed_states2feature(tree, character, model.states, method_character)
            res = result.copy()
            res[CHARACTER] = method_character
            res[METHOD] = method
            results.append(res)

    def process_restricted_likelihood_and_states(method):
        # Marginal likelihood of the forest with the node states restricted to those selected by the method.
        restricted_likelihood = \
            sum(get_bottom_up_loglikelihood(tree=tree, character=character,
                                            is_marginal=True, model=model, alter=True) for tree in forest)
        note_restricted_likelihood(method, restricted_likelihood)
        process_reconstructed_states(method)

    def note_restricted_likelihood(method, restricted_likelihood):
        logger.debug('Log likelihood for {} after {} state selection:\t{:.6f}'
                     .format(character, method, restricted_likelihood))
        result[RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR.format(method)] = restricted_likelihood

    # The joint reconstruction is performed for every method but MAP
    # (MPPA reads the joint states of the nodes when force_joint is set).
    if prediction_method != MAP:
        # Calculate joint restricted likelihood
        restricted_likelihood = \
            sum(get_bottom_up_loglikelihood(tree=tree, character=character,
                                            is_marginal=False, model=model, alter=True) for tree in forest)
        note_restricted_likelihood(JOINT, restricted_likelihood)
        for tree in forest:
            choose_ancestral_states_joint(tree, character, model.states, model.frequencies)
        process_reconstructed_states(JOINT)

    if is_marginal(prediction_method):
        mps = []
        for tree in forest:
            # Reset the allowed states (they might have been restricted by the joint state selection above).
            initialize_allowed_states(tree, character, model.states)
            altered_nodes = []
            if 0 == model.tau:
                altered_nodes = alter_zero_node_allowed_states(tree, character)
            # alter=False: the zero-distance nodes were altered just above and are restored explicitly below.
            get_bottom_up_loglikelihood(tree=tree, character=character, is_marginal=True, model=model, alter=False)
            calculate_top_down_likelihood(tree, character, model=model)
            calculate_marginal_likelihoods(tree, character, model.frequencies)
            # check_marginal_likelihoods(tree, character)
            mps.append(convert_likelihoods_to_probabilities(tree, character, model.states))

            if altered_nodes:
                unalter_zero_node_allowed_states(altered_nodes, character)
            choose_ancestral_states_map(tree, character, model.states)
        result[MARGINAL_PROBABILITIES] = pd.concat(mps, copy=False) if len(mps) != 1 else mps[0]
        process_restricted_likelihood_and_states(MAP)

        if MPPA == prediction_method or is_meta_ml(prediction_method):

            if ALL == prediction_method:
                pars_acr_results = parsimonious_acr(forest, character, MP, model.states,
                                                    model.forest_stats.num_nodes, model.forest_stats.num_tips)
                results.extend(pars_acr_results)
                for pars_acr_res in pars_acr_results:
                    for tree in forest:
                        _parsimonious_states2allowed_states(tree, pars_acr_res[CHARACTER], character, model.states)
                    # The parsimonious state selection might have a zero likelihood:
                    # in that case the error is logged and no restricted likelihood is noted for it.
                    try:
                        restricted_likelihood = \
                            sum(get_bottom_up_loglikelihood(tree=tree, character=character, is_marginal=True,
                                                            model=model, alter=True)
                                for tree in forest)
                        note_restricted_likelihood(pars_acr_res[METHOD], restricted_likelihood)
                    except PastMLLikelihoodError as e:
                        # NOTE(review): PastMLLikelihoodError is not among the imports visible at the top of this file,
                        # and e.message presumably is an attribute defined by that class; confirm both.
                        logger.error('{}\n{} parsimonious state selection is inconsistent in terms of ML.'
                                     .format(e.message, pars_acr_res[METHOD]))

            # Accumulate the MPPA statistics over the forest.
            result[NUM_SCENARIOS], result[NUM_UNRESOLVED_NODES], result[NUM_STATES_PER_NODE] = 1, 0, 0
            for tree in forest:
                ns, nun, nspn = choose_ancestral_states_mppa(tree, character, model.states, force_joint=force_joint)
                result[NUM_SCENARIOS] *= ns
                result[NUM_UNRESOLVED_NODES] += nun
                result[NUM_STATES_PER_NODE] += nspn
            # Convert the total number of selected states to the average per node.
            result[NUM_STATES_PER_NODE] /= model.forest_stats.num_nodes
            result[PERC_UNRESOLVED] = result[NUM_UNRESOLVED_NODES] * 100 / model.forest_stats.num_nodes
            logger.debug('{} node{} unresolved ({:.2f}%) for {} by {}, '
                         'i.e. {:.4f} state{} per node in average.'
                         .format(result[NUM_UNRESOLVED_NODES], 's are' if result[NUM_UNRESOLVED_NODES] != 1 else ' is',
                                 result[PERC_UNRESOLVED], character, MPPA,
                                 result[NUM_STATES_PER_NODE], 's' if result[NUM_STATES_PER_NODE] > 1 else ''))
            process_restricted_likelihood_and_states(MPPA)

    return results
def marginal_counts(forest, character, model, n_repetitions=1_000):
    """
    Estimates, by repeated random sampling, how often each state is followed by each other state
    along the branches of the trees.

    For every tree, root states are drawn n_repetitions times from the root marginal probabilities.
    The tree is then visited in level order: for each child, its states are drawn conditionally on
    the states previously drawn for its parent, and the (parent state, child state) pairs
    are accumulated in a matrix. The drawn state counts are stored on every node
    in the STATE_COUNTS feature of the character.

    :param forest: trees of interest
    :type forest: list of trees (nodes must provide props, add_prop, children, dist, is_root and traverse;
        presumably ete4 trees -- TODO confirm)
    :param character: character for which the state changes are counted
    :type character: str
    :param model: evolutionary model; this function reads its states, frequencies, tau and get_Pij_t
    :type model: Model
    :param n_repetitions: number of states drawn at the root of each tree
    :type n_repetitions: int
    :return: matrix of shape (n_states, n_states): entry [i, j] accumulates the draws
        with parent state i and child state j (after the diagonal correction described in the code),
        divided by n_repetitions
    :rtype: numpy.ndarray
    """

    lh_feature = get_personalized_feature_name(character, LH)
    lh_sf_feature = get_personalized_feature_name(character, LH_SF)
    td_lh_feature = get_personalized_feature_name(character, TD_LH)
    td_lh_sf_feature = get_personalized_feature_name(character, TD_LH_SF)
    bu_lh_feature = get_personalized_feature_name(character, BU_LH)
    bu_lh_sf_feature = get_personalized_feature_name(character, BU_LH_SF)

    allowed_state_feature = get_personalized_feature_name(character, ALLOWED_STATES)
    initial_allowed_state_feature = allowed_state_feature + '.initial'
    state_count_feature = get_personalized_feature_name(character, STATE_COUNTS)

    get_pij = model.get_Pij_t
    n_states = len(model.states)
    state_ids = np.array(list(range(n_states)))

    # result[i, j]: accumulated number of draws with parent state i and child state j
    result = np.zeros((n_states, n_states), dtype=float)

    for tree in forest:
        initialize_allowed_states(tree, character, model.states)
        altered_nodes = []
        if 0 == model.tau:
            altered_nodes = alter_zero_node_allowed_states(tree, character)
        # NOTE(review): the altered nodes are not passed back to unalter_zero_node_allowed_states here,
        # unlike in the marginal reconstruction code earlier in this file -- confirm this is intended.
        get_bottom_up_loglikelihood(tree=tree, character=character, is_marginal=True, model=model, alter=False)
        calculate_top_down_likelihood(tree, character, model=model)

        # Level order guarantees that a node's state counts were stored (while processing its parent)
        # before the node itself is visited as a parent.
        for parent in tree.traverse('levelorder'):
            if parent.is_root:
                calc_node_marginal_likelihood(parent, lh_feature, lh_sf_feature, bu_lh_feature, bu_lh_sf_feature,
                                              td_lh_feature, td_lh_sf_feature, allowed_state_feature,
                                              model.frequencies, False)
                marginal_likelihoods = parent.props.get(lh_feature)
                marginal_probs = marginal_likelihoods / marginal_likelihoods.sum()
                # draw random states according to marginal probabilities (n_repetitions times)
                drawn_state_nums = Counter(np.random.choice(state_ids, size=n_repetitions, p=marginal_probs))
                parent_state_counts = np.array([drawn_state_nums[_] for _ in state_ids])
                parent.add_prop(state_count_feature, parent_state_counts)
            else:
                # Fixed: this line used to read `arent.props.get(...)`,
                # which raised a NameError for every non-root node.
                parent_state_counts = parent.props.get(state_count_feature)

            if parent in altered_nodes:
                # Project the drawn counts onto the initially allowed states of the altered parent,
                # rescaled to n_repetitions; if none of the drawn states was initially allowed,
                # spread n_repetitions evenly over the initially allowed states.
                initial_allowed_states = parent.props.get(initial_allowed_state_feature)
                ps_counts_initial = parent_state_counts * initial_allowed_states
                if np.count_nonzero(ps_counts_initial):
                    ps_counts_initial = n_repetitions * ps_counts_initial / ps_counts_initial.sum()
                else:
                    ps_counts_initial = n_repetitions * initial_allowed_states / initial_allowed_states.sum()
            else:
                ps_counts_initial = parent_state_counts

            # same_state_counts[i]: draws (summed over the children) in which the child kept parent state i
            same_state_counts = np.zeros(n_states)

            for node in parent.children:
                node_pjis = np.transpose(get_pij(node.dist))
                marginal_loglikelihood = np.log10(node.props.get(bu_lh_feature)) + np.log10(node_pjis) \
                                         + np.log10(model.frequencies * node.props.get(allowed_state_feature))
                rescale_log(marginal_loglikelihood)
                marginal_likelihood = np.power(10, marginal_loglikelihood)
                # Normalise every row: row j is used below as the distribution of the child state
                # for the repetitions in which the parent was drawn in state j.
                marginal_probs = marginal_likelihood / marginal_likelihood.sum(axis=1)[:, np.newaxis]
                state_counts = np.zeros(n_states, dtype=int)

                # Raw draws are accumulated directly only if neither end of the branch was altered.
                update_results = parent not in altered_nodes and node not in altered_nodes

                for j in range(n_states):
                    parent_count_j = parent_state_counts[j]
                    if parent_count_j:
                        drawn_state_nums_j = \
                            Counter(np.random.choice(state_ids, size=parent_count_j, p=marginal_probs[j, :]))
                        state_counts_j = np.array([drawn_state_nums_j[_] for _ in state_ids])
                        state_counts += state_counts_j
                        if update_results:
                            result[j, :] += state_counts_j
                            same_state_counts[j] += state_counts_j[j]

                if node in altered_nodes:
                    # Same projection onto the initially allowed states as for an altered parent.
                    initial_allowed_states = node.props.get(initial_allowed_state_feature)
                    counts_initial = state_counts * initial_allowed_states
                    if np.count_nonzero(counts_initial):
                        counts_initial = n_repetitions * counts_initial / counts_initial.sum()
                    else:
                        counts_initial = n_repetitions * initial_allowed_states / initial_allowed_states.sum()
                else:
                    counts_initial = state_counts

                if not update_results:
                    # For a branch with an altered end, distribute the (projected) parent counts
                    # over the child states proportionally to the (projected) child counts.
                    norm_counts = counts_initial / counts_initial.sum()
                    for i in np.argwhere(ps_counts_initial > 0):
                        child_counts_adjusted = norm_counts * ps_counts_initial[i]
                        result[i, :] += child_counts_adjusted
                        same_state_counts[i] += child_counts_adjusted[i]

                node.add_prop(state_count_feature, state_counts)

            # Reduce the diagonal by the same-state draws of this parent's children,
            # capped at the parent's own count for that state.
            for i in range(n_states):
                result[i, i] -= np.minimum(ps_counts_initial[i], same_state_counts[i])

    return result / n_repetitions
def _is_invalid_likelihood(likelihood):
    """Tells whether a log likelihood value (scalar or array) contains NaN or -inf."""
    # Element-wise `|` instead of `or`: `or` needs a single truth value
    # and raises a ValueError on arrays with more than one element.
    return bool(np.any(np.isnan(likelihood) | np.isneginf(likelihood)))


def _likelihood_failure(action, character):
    """Creates the error reported when the likelihood could not be obtained (action: 'calculate'/'optimise')."""
    return PastMLLikelihoodError('Failed to {} the likelihood for your tree, '
                                 'please check that you do not have contradicting {} states specified '
                                 'for internal tree nodes, '
                                 'and if not - submit a bug at https://github.com/evolbioinfo/pastml/issues'
                                 .format(action, character))


def optimise_likelihood(forest, character, model, observed_frequencies):
    """
    Optimises the free parameters of the model for the given character and returns the log likelihood.

    The allowed states are (re)initialised on every tree and the likelihood is calculated
    with the current parameter values. If the model has free parameters, they are optimised in two steps:
    first the basic parameters (with the extra parameters temporarily fixed), if any of them is free,
    then the parameters again if any extra parameter is free.

    :param forest: trees of interest
    :type forest: list of trees
    :param character: character whose likelihood is optimised
    :type character: str
    :param model: evolutionary model, its parameters are updated by the optimisation
    :type model: Model
    :param observed_frequencies: passed through to optimize_likelihood_params
    :return: log likelihood (summed over the trees of the forest if nothing had to be optimised,
        otherwise the value returned by the last call of optimize_likelihood_params)
    :raises PastMLLikelihoodError: if the initial likelihood is NaN,
        or an optimised likelihood is NaN or -inf
    """
    for tree in forest:
        initialize_allowed_states(tree, character, model.states)
    logger = logging.getLogger('pastml')
    likelihood = sum(get_bottom_up_loglikelihood(tree=tree, character=character,
                                                 is_marginal=True, model=model, alter=True)
                     for tree in forest)
    if np.isnan(likelihood):
        raise _likelihood_failure('calculate', character)

    # NOTE(review): model._print_parameters is formatted without being called below,
    # while model._print_basic_parameters() is called; if _print_parameters is a plain method
    # (not a property), the log will show its repr instead of the parameter values -- confirm against Model.
    if not model.get_num_params():
        logger.debug('All the parameters are fixed for {}:\n{}{}.'
                     .format(character,
                             model._print_parameters,
                             '\tlog likelihood:\t{:.6f}'.format(likelihood))
                     )
        return likelihood

    logger.debug('Initial values for {} parameter optimisation:\n{}{}.'
                 .format(character,
                         model._print_parameters,
                         '\tlog likelihood:\t{:.6f}'.format(likelihood))
                 )
    if not model.basic_params_fixed():
        # Step 1: optimise the basic parameters only, keeping the extra ones fixed.
        model.fix_extra_params()
        likelihood = optimize_likelihood_params(forest=forest, character=character, model=model,
                                                observed_frequencies=observed_frequencies)
        if _is_invalid_likelihood(likelihood):
            raise _likelihood_failure('optimise', character)
        model.unfix_extra_params()
        if not model.extra_params_fixed():
            logger.debug('Pre-optimised basic parameters for {}:\n{}{}.'
                         .format(character,
                                 model._print_basic_parameters(),
                                 '\tlog likelihood:\t{:.6f}'.format(likelihood)))
    if not model.extra_params_fixed():
        # Step 2: optimise again now that the extra parameters are free.
        likelihood = \
            optimize_likelihood_params(forest=forest, character=character, model=model,
                                       observed_frequencies=observed_frequencies)
        if _is_invalid_likelihood(likelihood):
            raise _likelihood_failure('calculate', character)
    logger.debug('Optimised parameters for {}:\n{}{}'
                 .format(character,
                         model._print_parameters,
                         '\tlog likelihood:\t{:.6f}'.format(likelihood)))
    return likelihood
def convert_allowed_states2feature(tree, feature, states, out_feature=None):
    """
    Converts the allowed-state mask of every node into a set of state values
    and stores this set on the node.

    :param tree: tree whose nodes are annotated
    :param feature: character whose ALLOWED_STATES mask is read from the nodes
    :type feature: str
    :param states: state values, indexed with the boolean version of the mask
        (presumably a numpy array -- TODO confirm)
    :param out_feature: name of the node property to write the set to (by default, the feature itself)
    :type out_feature: str
    :return: void, modifies the tree nodes
    """
    target_feature = feature if out_feature is None else out_feature
    mask_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    for node in tree.traverse():
        # The stored mask is turned into a boolean selector over the states.
        selector = node.props.get(mask_feature).astype(bool)
        node.add_prop(target_feature, set(states[selector]))
def _parsimonious_states2allowed_states(tree, ps_feature, feature, states):
    """
    Sets the allowed-state mask of every node from the states stored in its ps_feature property.

    :param tree: tree whose nodes are updated
    :param ps_feature: name of the node property containing an iterable of states
    :type ps_feature: str
    :param feature: character whose ALLOWED_STATES mask is (over)written
    :type feature: str
    :param states: all possible states, their order defines the positions in the mask
    :return: void, modifies the tree nodes
    """
    n_states = len(states)
    position_of = {state: position for position, state in enumerate(states)}
    mask_feature = get_personalized_feature_name(feature, ALLOWED_STATES)
    for node in tree.traverse():
        mask = np.zeros(n_states, dtype=int)
        # Switch on the positions of the states listed for this node.
        for position in (position_of[state] for state in node.props.get(ps_feature)):
            mask[position] = 1
        node.add_prop(mask_feature, mask)
class PastMLLikelihoodError(Exception):
    """
    Error raised when a likelihood cannot be calculated or optimised.

    The first positional argument (if any) is kept as the message.
    """

    def __init__(self, *args):
        # Only the first argument is used as the message; None if there are no arguments.
        if args:
            self.message = args[0]
        else:
            self.message = None

    def __str__(self):
        # A missing or empty message falls back to the generic text.
        if not self.message:
            return 'PastMLLikelihoodError has been raised.'
        return 'PastMLLikelihoodError, {}'.format(self.message)