Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/file.py: 46%
39 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
1import os
3from pastml.parsimony import is_meta_mp, get_default_mp_method
4from pastml import col_name2cat, get_personalized_feature_name
5from pastml.ml import is_ml, is_marginal, is_meta_ml, get_default_ml_method
# Name template of the PastML working directory; {tree} is filled in by
# get_pastml_work_dir with the tree path stripped of its extension.
PASTML_WORK_DIR = '{tree}_pastml'

# File name templates for tree-level outputs.
COMBINED_ANCESTRAL_STATE_TAB = 'combined_ancestral_states.tab'
NAMED_TREE_NWK = 'named.tree_{tree}.nwk'

# File name templates for per-character outputs; {state} is the character
# (column) name, {method} the prediction method and {model} the evolution model.
PASTML_ML_PARAMS_TAB = 'params.character_{state}.method_{method}.model_{model}.tab'
PASTML_COLOUR_TAB = 'colours.character_{state}.tab'
PASTML_MP_PARAMS_TAB = 'params.character_{state}.method_{method}.tab'
PASTML_MARGINAL_PROBS_TAB = 'marginal_probabilities.character_{state}.model_{model}.tab'
def get_column_method(column, method):
    """
    Resolve the column name and the prediction method to use in file names.

    The column is first passed through col_name2cat. If the method is a meta
    ML or meta MP method, it is replaced by the corresponding default method
    and the column name is personalised with that method; otherwise the
    converted column and the method are returned as they are.

    :param column: str, the column for which ancestral states are reconstructed with PASTML.
    :param method: str, the ancestral state prediction method used by PASTML.
    :return: tuple (column, method) after the resolution described above.
    """
    category = col_name2cat(column)
    if is_meta_ml(method):
        resolved_method = get_default_ml_method()
    elif is_meta_mp(method):
        resolved_method = get_default_mp_method()
    else:
        # Not a meta-method: nothing to resolve, nothing to personalise.
        return category, method
    return get_personalized_feature_name(category, resolved_method), resolved_method
def get_pastml_parameter_file(method, model, column):
    """
    Get the filename where the PastML parameters are saved.
    This file is inside the work_dir that can be specified for the pastml_pipeline method.

    ML methods (as reported by is_ml) use the template that includes the model,
    all other methods use the template without the model part.

    :param method: str, the ancestral state prediction method used by PASTML.
    :param model: str, the state evolution model used by PASTML
        (it does not appear in the filename for non-ML methods).
    :param column: str, the column for which ancestral states are reconstructed with PASTML.
    :return: str, filename
    """
    # The template is chosen from the method as given, before any meta-method
    # is resolved to its default by get_column_method.
    if is_ml(method):
        name_template = PASTML_ML_PARAMS_TAB
    else:
        name_template = PASTML_MP_PARAMS_TAB
    state, resolved_method = get_column_method(column, method)
    return name_template.format(state=state, method=resolved_method, model=model)
def get_pastml_colour_file(column):
    """
    Get the filename of the table with the PastML colours used for visualisation.
    This file is inside the work_dir that can be specified for the pastml_pipeline method.

    :param column: str, the column for which ancestral states are reconstructed with PASTML
        (inserted into the filename as given).
    :return: str, filename
    """
    return PASTML_COLOUR_TAB.format(state=column)
def get_combined_ancestral_state_file():
    """
    Get the filename of the table with the combined ancestral states
    (the same file is used whether one or several columns are analysed).
    This file is inside the work_dir that can be specified for the pastml_pipeline method.

    :return: str, filename (a fixed name, independent of the input)
    """
    return COMBINED_ANCESTRAL_STATE_TAB
def get_pastml_work_dir(tree):
    """
    Get the path of the PastML work dir corresponding to the input tree.

    The path is the tree path with its extension removed and the PastML suffix appended.

    :param tree: str, path to the input tree.
    :return: str, work dir path
    """
    path_without_ext, _ = os.path.splitext(tree)
    return PASTML_WORK_DIR.format(tree=path_without_ext)
def get_named_tree_file(tree):
    """
    Get the filename where the PastML tree (input tree but named and with collapsed zero branches) is saved.
    This file is inside the work_dir that can be specified for the pastml_pipeline method.

    The name is built from the base name of the input tree without its extension.

    :param tree: str, path to the input tree in newick format.
    :return: str, filename
    """
    base_name = os.path.basename(tree)
    stem, _ = os.path.splitext(base_name)
    # Fall back to a generic name when the path yields an empty stem.
    return NAMED_TREE_NWK.format(tree=stem or 'tree')
def get_pastml_marginal_prob_file(method, model, column):
    """
    Get the filename where the PastML marginal probabilities of node states are saved.
    This file is inside the work_dir that can be specified for the pastml_pipeline method.

    :param method: str, the ancestral state prediction method used by PASTML.
    :param model: str, the state evolution model used by PASTML.
    :param column: str, the column for which ancestral states are reconstructed with PASTML.
    :return: str, filename, or None if the method is not marginal (according to is_marginal).
    """
    if is_marginal(method):
        # Only the resolved column name is needed: the filename has no method part.
        state, _ = get_column_method(column, method)
        return PASTML_MARGINAL_PROBS_TAB.format(state=state, model=model)
    return None