Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/models/__init__.py: 28%
236 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
1import logging
2import os
4import numpy as np
5import pandas as pd
7from pastml import NUM_NODES, NUM_TIPS
# Keys used in the tab-delimited parameter files read by Model.parse_parameters
# and written by Model.save_parameters.
MODEL = 'model'
CHANGES_PER_AVG_BRANCH = 'state_changes_per_avg_branch'
SCALING_FACTOR = 'scaling_factor'
SMOOTHING_FACTOR = 'smoothing_factor'
FREQUENCIES = 'frequencies'
class Model(object):
    """
    Base class for PastML substitution models.

    A model keeps the (sorted) character states, forest statistics, and two basic parameters:

    * the scaling factor (sf), by which each branch length is multiplied;
    * the smoothing factor (tau), added to each branch length before branches are
      renormalised to keep the initial total tree length.

    Each parameter can either be optimised or kept fixed.
    """

    def __init__(self, states, forest_stats, sf=None, tau=0, optimise_tau=False,
                 parameter_file=None, reoptimise=False, **kwargs):
        """
        :param states: iterable of character states (stored sorted)
        :param forest_stats: forest statistics object; the attributes
            avg_nonzero_brlen, forest_length, num_nodes and num_tips are used
        :param sf: scaling factor; if None (and not preset in parameter_file),
            defaults to 1 / average non-zero branch length
        :param tau: smoothing factor (0 means no smoothing)
        :param optimise_tau: whether tau should be optimised
        :param parameter_file: optional path to a parameter file (or dict) with preset values
        :param reoptimise: whether preset values are starting points (True) or fixed (False)
        """
        self._name = None
        self._states = np.sort(states)
        self._forest_stats = forest_stats
        self._optimise_tau = optimise_tau
        self._optimise_sf = True
        self._sf = None
        self._tau = None
        # Preset values (if any) are read first and take precedence over the sf/tau arguments.
        self.parse_parameters(parameter_file, reoptimise)
        if self._sf is None:
            self._sf = sf if sf is not None else 1. / forest_stats.avg_nonzero_brlen
        if self._tau is None:
            self._tau = tau if tau else 0
        self.calc_tau_factor()
        self._extra_params_fixed = False

    def calc_tau_factor(self):
        """
        (Re)computes the renormalisation factor that keeps the total forest length
        unchanged after tau is added to every branch.

        :return: void, updates self._tau_factor
        """
        self._tau_factor = \
            self._forest_stats.forest_length / (self._forest_stats.forest_length
                                                + self._tau * (self._forest_stats.num_nodes - 1)) if self._tau else 1

    def __str__(self):
        return \
            'Model {} with parameter values:\n' \
            '{}'.format(self.name,
                        self._print_parameters())

    def _print_parameters(self):
        """
        Constructs a string representing parameter values (to be used for logging).

        :return: str representing parameter values
        """
        return self._print_basic_parameters()

    def _print_basic_parameters(self):
        # Reports sf both directly and as the expected number of state changes per average branch.
        return '\tscaling factor:\t{:.6f}, i.e. {:.6f} changes per avg branch\t{}\n' \
               '\tsmoothing factor:\t{:.6f}\t{}\n'\
            .format(self.sf, self.forest_stats.avg_nonzero_brlen * self.sf,
                    '(optimised)' if self._optimise_sf else '(fixed)',
                    self.tau, '(optimised)' if self._optimise_tau else '(fixed)')

    def save_parameters(self, filehandle):
        """
        Writes this model parameter values to the parameter file (in the same format as the input parameter file).

        :param filehandle: filehandle for the file where the parameter values should be written.
        :return: void
        """
        filehandle.write('{}\t{}\n'.format(MODEL, self.name))
        filehandle.write('{}\t{}\n'.format(NUM_NODES, self.forest_stats.num_nodes))
        filehandle.write('{}\t{}\n'.format(NUM_TIPS, self.forest_stats.num_tips))
        filehandle.write('{}\t{}\n'.format(SCALING_FACTOR, self.sf))
        filehandle.write('{}\t{}\n'.format(CHANGES_PER_AVG_BRANCH, self.sf * self.forest_stats.avg_nonzero_brlen))
        filehandle.write('{}\t{}\n'.format(SMOOTHING_FACTOR, self.tau))

    def extra_params_fixed(self):
        # Whether model-specific (non sf/tau) parameters are currently fixed.
        return self._extra_params_fixed

    def fix_extra_params(self):
        self._extra_params_fixed = True

    def unfix_extra_params(self):
        self._extra_params_fixed = False

    @property
    def forest_stats(self):
        return self._forest_stats

    @forest_stats.setter
    def forest_stats(self, forest_stats):
        self._forest_stats = forest_stats
        # Forest statistics enter the tau factor, so it must be recomputed.
        self.calc_tau_factor()

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def states(self):
        return self._states

    @states.setter
    def states(self, states):
        self._states = states

    @property
    def sf(self):
        return self._sf

    @sf.setter
    def sf(self, sf):
        if self._optimise_sf:
            self._sf = sf
        else:
            raise NotImplementedError('The scaling factor is preset and cannot be changed.')

    @property
    def tau(self):
        return self._tau

    @tau.setter
    def tau(self, tau):
        if self._optimise_tau:
            self._tau = tau
            self.calc_tau_factor()
        else:
            raise NotImplementedError('Tau is preset and cannot be changed.')

    def get_Pij_t(self, t, *args, **kwargs):
        """
        Returns a function for calculation of probability matrix of substitutions i->j over time t.

        :return: probability matrix
        :rtype: function
        """
        raise NotImplementedError("Please implement this method in the Model subclass")

    def set_params_from_optimised(self, ps, **kwargs):
        """
        Update this model parameter values from a vector representing parameters
        for the likelihood optimization algorithm.

        :param ps: np.array containing parameters of the likelihood optimization algorithm
        :param kwargs: dict of eventual other arguments
        :return: void, update this model
        """
        if self._optimise_sf:
            self.sf = ps[0]
        if self._optimise_tau:
            # Tau comes right after sf in the vector, or first if sf is fixed.
            self.tau = ps[1 if self._optimise_sf else 0]

    def get_optimised_parameters(self):
        """
        Converts this model parameters to a vector representing parameters
        for the likelihood optimization algorithm.

        :return: np.array containing parameters of the likelihood optimization algorithm
        """
        return np.hstack(([self.sf] if self._optimise_sf else [],
                          [self.tau] if self._optimise_tau else []))

    def get_bounds(self):
        """
        Get bounds for parameters for likelihood optimization algorithm.

        :return: np.array containing lower and upper (potentially infinite) bounds for each parameter
        """
        bounds = []
        if self._optimise_sf:
            # Allow between 0.001 and 10 expected changes per average branch.
            bounds += [np.array([0.001 / self.forest_stats.avg_nonzero_brlen,
                                 10. / self.forest_stats.avg_nonzero_brlen])]
        if self._optimise_tau:
            bounds += [np.array([0, self.forest_stats.avg_nonzero_brlen])]
        return np.array(bounds, np.float64)

    def get_num_params(self):
        """
        Returns the number of optimized parameters for this model.

        :return: the number of optimized parameters
        """
        return (1 if self._optimise_sf else 0) + (1 if self._optimise_tau else 0)

    def parse_parameters(self, params, reoptimise=False):
        """
        Update this model's values from the input parameters.
        The input might contain:
        (1) the scaling factor, by which each branch length will be multiplied (optional).
        The key for this parameter is pastml.models.SCALING_FACTOR;
        (2) the smoothing factor, which will be added to each branch length
        before the branches are renormalized to keep the initial tree length (optional).
        The key for this parameter is pastml.models.SMOOTHING_FACTOR;

        :param params: dict {key->value}
            or a path to the file containing a tab-delimited table with the first column containing keys
            and the second (named 'value') containing values.
        :param reoptimise: whether these model parameters should be treated as starting values (True)
            or as fixed values (False)
        :return: dict with parameter values
        """
        logger = logging.getLogger('pastml')
        if params is None:
            return {}
        if not isinstance(params, (str, dict)):
            raise ValueError('Parameters must be specified either as a dict or as a path to a csv file, not as {}!'
                             .format(type(params)))
        if isinstance(params, str):
            if not os.path.exists(params):
                raise ValueError('The specified parameter file ({}) does not exist.'.format(params))
            try:
                param_df = pd.read_csv(params, header=0, index_col=0, sep='\t')
            except Exception:
                raise ValueError('The specified parameter file {} is malformed, '
                                 'should be a tab-delimited file with two columns, '
                                 'the first one containing parameter names, '
                                 'and the second, named "value", containing parameter values.'.format(params))
            if 'value' not in param_df.columns:
                # Bug fix: this message previously lacked .format(params) and was immediately
                # swallowed by the surrounding bare except, so it could never reach the user.
                raise ValueError('Could not find the "value" column in the parameter file {}. '
                                 'It should be a tab-delimited file with two columns, '
                                 'the first one containing parameter names, '
                                 'and the second, named "value", containing parameter values.'.format(params))
            params = param_df.to_dict()['value']
        # Normalise keys to plain ASCII strings.
        params = {str(k.encode('ASCII', 'replace').decode()): v for (k, v) in params.items()}
        if SCALING_FACTOR in params:
            sf = params[SCALING_FACTOR]
            try:
                sf = np.float64(sf)
                if sf <= 0:
                    # Bug fix: the offending value is now logged
                    # (previously a local variable that was always None was formatted instead).
                    logger.error('Scaling factor cannot be negative, ignoring the value given in parameters ({}).'
                                 .format(sf))
                    sf = None
                else:
                    # A valid preset sf is fixed unless reoptimisation was requested.
                    self._optimise_sf = reoptimise
            except (ValueError, TypeError):
                logger.error('Scaling factor ({}) given in parameters is not float, ignoring it.'
                             .format(params[SCALING_FACTOR]))
                sf = None
            self._sf = sf
        if SMOOTHING_FACTOR in params:
            tau = params[SMOOTHING_FACTOR]
            try:
                tau = np.float64(tau)
                if tau < 0:
                    logger.error(
                        'Smoothing factor cannot be negative, ignoring the value given in parameters ({}).'.format(tau))
                    tau = None
            except (ValueError, TypeError):
                logger.error('Smoothing factor ({}) given in parameters is not float, ignoring it.'
                             .format(params[SMOOTHING_FACTOR]))
                tau = None
            self._tau = tau
        return params

    def freeze(self):
        """
        Prohibit parameter optimization by setting all optimization flags to False.

        :return: void
        """
        self._optimise_sf = False
        self._optimise_tau = False

    def transform_t(self, t):
        # Smooth (add tau), renormalise (tau factor) and scale (sf) a branch length.
        return (t + self.tau) * self._tau_factor * self.sf

    def basic_params_fixed(self):
        # Whether both sf and tau are fixed.
        return not self._optimise_tau and not self._optimise_sf
class ModelWithFrequencies(Model):
    """
    A substitution model that additionally has equilibrium state frequencies.

    The frequencies can be freely optimised, smoothed (pulled towards uniform
    by a single smoothing parameter), or fixed.
    """

    def __init__(self, states, forest_stats, sf=None, frequencies=None, tau=0,
                 optimise_tau=False, frequency_smoothing=False, parameter_file=None, reoptimise=False, **kwargs):
        """
        :param frequencies: preset equilibrium frequencies
            (defaults to equal frequencies if not given here nor in parameter_file)
        :param frequency_smoothing: whether frequencies should be smoothed
            instead of freely optimised
        (other parameters as in Model)
        """
        self._frequencies = None
        self._optimise_frequencies = not frequency_smoothing
        self._frequency_smoothing = frequency_smoothing
        # Model.__init__ calls parse_parameters, which may already set self._frequencies.
        Model.__init__(self, states, forest_stats=forest_stats,
                       sf=sf, tau=tau, optimise_tau=optimise_tau, reoptimise=reoptimise,
                       parameter_file=parameter_file, **kwargs)
        if self._frequencies is None:
            self._frequencies = frequencies if frequencies is not None \
                else np.ones(len(states), dtype=np.float64) / len(states)

    @property
    def frequencies(self):
        return self._frequencies

    @frequencies.setter
    def frequencies(self, frequencies):
        if self._optimise_frequencies or self._frequency_smoothing:
            self._frequencies = frequencies
        else:
            raise NotImplementedError('The frequencies are preset and cannot be changed.')

    def get_num_params(self):
        """
        Returns the number of optimized parameters for this model.

        :return: the number of optimized parameters
        """
        # n-1 free frequencies (they sum to one), or a single smoothing parameter.
        return Model.get_num_params(self) \
               + ((len(self.frequencies) - 1)
                  if self._optimise_frequencies else (1 if self._frequency_smoothing else 0))

    def set_params_from_optimised(self, ps, **kwargs):
        """
        Update this model parameter values from a vector representing parameters
        for the likelihood optimization algorithm.

        :param ps: np.array containing parameters of the likelihood optimization algorithm
        :param kwargs: dict of eventual other arguments
        :return: void, update this model
        """
        Model.set_params_from_optimised(self, ps, **kwargs)
        if not self.extra_params_fixed():
            n_freq = len(self.frequencies)
            n_params = Model.get_num_params(self)
            freqs = self.frequencies
            if self._optimise_frequencies:
                # The vector carries n_freq - 1 frequencies relative to the last one.
                freqs = np.hstack((ps[n_params: n_params + (n_freq - 1)], [1.]))
                freqs /= freqs.sum()
                self.frequencies = freqs
            elif self._frequency_smoothing:
                # Smooth observed counts by adding the smoothing parameter, then renormalise.
                freqs = freqs * self.forest_stats.num_tips
                freqs += ps[n_params]
                freqs /= freqs.sum()
                self.frequencies = freqs

    def get_optimised_parameters(self):
        """
        Converts this model parameters to a vector representing parameters
        for the likelihood optimization algorithm.

        :return: np.array containing parameters of the likelihood optimization algorithm
        """
        if not self.extra_params_fixed():
            return np.hstack((Model.get_optimised_parameters(self),
                              self.frequencies[:-1] / self.frequencies[-1] if self._optimise_frequencies
                              else ([0] if self._frequency_smoothing else [])))
        return Model.get_optimised_parameters(self)

    def get_bounds(self):
        """
        Get bounds for parameters for likelihood optimization algorithm.

        :return: np.array containing lower and upper (potentially infinite) bounds for each parameter
        """
        if not self.extra_params_fixed():
            extras = []
            if self._optimise_frequencies:
                extras += [np.array([1e-6, 10e6], np.float64)] * (len(self.frequencies) - 1)
            if self._frequency_smoothing:
                extras.append(np.array([0, self.forest_stats.num_nodes]))
            return np.array((*Model.get_bounds(self), *extras))
        return Model.get_bounds(self)

    def parse_parameters(self, params, reoptimise=False):
        """
        Update this model's values from the input parameters.
        For a model with frequencies, apart from the basic parameters
        (scaling factor and smoothing factor, see pastml.models.Model),
        the input might contain the frequency values:
        the key for each frequency value is the name of the corresponding character state.

        :param params: dict {key->value}
        :param reoptimise: whether these model parameters should be treated as starting values (True)
            or as fixed values (False)
        :return: dict with parameter values (same as input)
        """
        params = Model.parse_parameters(self, params, reoptimise)
        logger = logging.getLogger('pastml')
        known_freq_states = set(self.states) & set(params.keys())
        if known_freq_states:
            unknown_freq_states = [state for state in self.states if state not in params.keys()]
            if unknown_freq_states and not reoptimise:
                logger.error('Frequencies for some of the states ({}) are missing, '
                             'ignoring the specified frequencies.'.format(', '.join(unknown_freq_states)))
            else:
                # Missing states (possible only when reoptimise is True) get a zero placeholder.
                self._frequencies = np.array([params[state] if state in params.keys() else 0
                                              for state in self.states])
                try:
                    self._frequencies = self._frequencies.astype(np.float64)
                    if np.round(self._frequencies.sum() - 1, 2) != 0 and not reoptimise:
                        logger.error('Frequencies given in parameters ({}) do not sum up to one ({}),'
                                     'ignoring them.'.format(self._frequencies, self._frequencies.sum()))
                        self._frequencies = None
                    elif np.any(self._frequencies < 0) and not reoptimise:
                        logger.error('Some of the frequencies given in parameters ({}) are negative,'
                                     'ignoring them.'.format(self._frequencies))
                        self._frequencies = None
                    else:
                        # Replace missing/zero frequencies by a small positive value:
                        # half of the smallest positive given frequency (capped at 1/num_tips).
                        min_freq = \
                            min(1 / self.forest_stats.num_tips,
                                min(float(params[state]) for state in known_freq_states
                                    if float(params[state]) > 0)) / 2
                        if unknown_freq_states:
                            logger.error('Frequencies for some of the states ({}) are missing from parameters, '
                                         'setting them to {}.'.format(', '.join(unknown_freq_states), min_freq))
                        frequencies = np.maximum(self._frequencies, min_freq)
                        frequencies /= frequencies.sum()
                        # Bug fix: the adjusted frequencies were previously computed but never
                        # stored, so the "setting them to min_freq" adjustment was silently lost.
                        self._frequencies = frequencies
                        self._optimise_frequencies = reoptimise and not self._frequency_smoothing
                except (ValueError, TypeError):
                    logger.error('Could not convert the frequencies given in parameters ({}) to float, '
                                 'ignoring them.'.format(self._frequencies))
                    self._frequencies = None
        return params

    def _print_parameters(self):
        """
        Constructs a string representing parameter values (to be used for logging).

        :return: str representing parameter values
        """
        return '{}' \
               '\tfrequencies\t{}\n' \
               '{}\n'.format(Model._print_parameters(self),
                             '(optimised)' if self._optimise_frequencies
                             else '(smoothed)' if self._frequency_smoothing else '(fixed)',
                             '\n'.join('\t\t{}:\t{:g}'.format(state, freq)
                                       for (state, freq) in zip(self.states, self.frequencies)))

    def freeze(self):
        """
        Prohibit parameter optimization by setting all optimization flags to False.

        :return: void
        """
        Model.freeze(self)
        self._optimise_frequencies = False
        self._frequency_smoothing = False

    def extra_params_fixed(self):
        # Extra parameters count as fixed when explicitly frozen,
        # or when only the basic (sf/tau) parameters are being optimised.
        return self._extra_params_fixed or Model.get_num_params(self) == self.get_num_params()

    def basic_params_fixed(self):
        return not Model.get_num_params(self)

    def save_parameters(self, filehandle):
        """
        Writes this model parameter values to the parameter file (in the same format as the input parameter file).

        :param filehandle: filehandle for the file where the parameter values should be written.
        :return: void
        """
        Model.save_parameters(self, filehandle)
        for state, frequency in zip(self.states, self.frequencies):
            filehandle.write('{}\t{}\n'.format(state, frequency))