Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/models/CustomRatesModel.py: 31%
48 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-21 09:19 +0100
1import logging
3import numpy as np
5from pastml.models import ModelWithFrequencies
6from pastml.models.generator import get_diagonalisation, get_pij_matrix
# Name of the custom-rates model: assigned to CustomRatesModel.name and interpolated into its error message.
8CUSTOM_RATES = 'CUSTOM_RATES'
def load_custom_rates(infile):
    """
    Loads a symmetric rate matrix and the corresponding state names from a file.

    The first line of the file must start with '#' and list the state names separated by whitespace.
    The other lines contain the square, symmetric rate matrix with space-separated values
    (lines starting with '#' are treated as comments).
    The diagonal of the matrix is reset to zero, and the states (together with the matrix
    rows and columns) are reordered so that the state names are sorted.

    :param infile: path to the rate matrix file
    :return: tuple (sorted array of state names, rate matrix reordered accordingly)
    :raises ValueError: if the matrix is not square or not symmetric, if the first line
        does not start with '#', or if the number of state names differs from the matrix dimension
    """
    rate_matrix = np.loadtxt(infile, dtype=np.float64, comments='#', delimiter=' ')
    if rate_matrix.ndim != 2 or rate_matrix.shape[0] != rate_matrix.shape[1]:
        # The shape entries are ints, so they must be converted to str before joining.
        raise ValueError('The input rate matrix must be squared, but yours is {}.'
                         .format('x'.join(str(dim) for dim in rate_matrix.shape)))
    if not np.all(rate_matrix == rate_matrix.transpose()):
        raise ValueError('The input rate matrix must be symmetric, but yours is not.')
    # The diagonal values are ignored.
    np.fill_diagonal(rate_matrix, 0)
    n = len(rate_matrix)
    if np.count_nonzero(rate_matrix) != n * (n - 1):
        logging.getLogger('pastml').warning('The rate matrix contains zero rates (apart from the diagonal).')

    # Only the header line is needed, so do not read the rest of the file again.
    with open(infile, 'r') as f:
        header = f.readline()
    if not header.startswith('#'):
        raise ValueError('The rate matrix file should start with state names, '
                         'separated by whitespaces and preceded by # .')
    # split() without arguments accepts any run of whitespace between the state names.
    states = np.array(header.strip('#').strip().split(), dtype=str)
    if len(states) != n:
        raise ValueError(
            'The number of specified state names ({}) does not correspond to the rate matrix dimensions ({}x{}).'
            .format(len(states), *rate_matrix.shape))

    # Apply the same permutation to the state names and to both matrix axes.
    new_order = np.argsort(states)
    return states[new_order], rate_matrix[np.ix_(new_order, new_order)]
class CustomRatesModel(ModelWithFrequencies):
    """
    Model whose rate matrix is preset: given either directly (together with the states)
    or via a rate matrix file.
    """

    def __init__(self, forest_stats, sf=None, frequencies=None, rate_matrix_file=None, states=None,
                 rate_matrix=None, tau=0, optimise_tau=False, frequency_smoothing=False, parameter_file=None,
                 reoptimise=False, **kwargs):
        """
        Initialises the model with the given rate matrix.

        :param rate_matrix_file: path to the file to load the states and the rate matrix from;
            if specified, takes precedence over the rate_matrix and states arguments
        :param states: states (required, together with rate_matrix, if rate_matrix_file is not specified)
        :param rate_matrix: rate matrix (required, together with states, if rate_matrix_file is not specified)
        :raises ValueError: if neither the rate matrix file nor the rate matrix plus the states are specified

        The other arguments are passed as is to the ModelWithFrequencies constructor.
        """
        ModelWithFrequencies.__init__(self, states=states, forest_stats=forest_stats, sf=sf, tau=tau,
                                      frequencies=frequencies, optimise_tau=optimise_tau,
                                      frequency_smoothing=frequency_smoothing, reoptimise=reoptimise,
                                      parameter_file=parameter_file, **kwargs)
        self.name = CUSTOM_RATES

        if rate_matrix_file is not None:
            # The file defines both the states and the rate matrix.
            self._states, self._rate_matrix = load_custom_rates(rate_matrix_file)
        elif rate_matrix is not None and states is not None:
            # States are already set in the super constructor
            self._rate_matrix = rate_matrix
        else:
            raise ValueError('Either the rate matrix file '
                             'or the rate matrix plus the states must be specified for {} model'
                             .format(CUSTOM_RATES))

        self.D_DIAGONAL, self.A, self.A_INV = get_diagonalisation(self.frequencies, self._rate_matrix)

    @property
    def rate_matrix(self):
        """
        The preset rate matrix (read-only).
        """
        return self._rate_matrix

    @rate_matrix.setter
    def rate_matrix(self, rate_matrix):
        raise NotImplementedError('The rate matrix is preset and cannot be changed.')

    @ModelWithFrequencies.frequencies.setter
    def frequencies(self, frequencies):
        """
        Sets the frequencies (only allowed if they are optimised or smoothed)
        and recalculates the diagonalisation for them.
        """
        if not (self._optimise_frequencies or self._frequency_smoothing):
            raise NotImplementedError('The frequencies are preset and cannot be changed.')
        self._frequencies = frequencies
        self.D_DIAGONAL, self.A, self.A_INV = get_diagonalisation(frequencies, self.rate_matrix)

    def get_Pij_t(self, t, *args, **kwargs):
        """
        Calculates the probability matrix of substitutions i->j over time t,
        with the given rate matrix.

        :param t: time; it is passed through self.transform_t before the calculation
        :return: the result of get_pij_matrix for the transformed time and the stored diagonalisation
            (presumably the probability matrix of substitutions i->j as an np.array)
        """
        transformed_t = self.transform_t(t)
        return get_pij_matrix(transformed_t, self.D_DIAGONAL, self.A, self.A_INV)