Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/models/CustomRatesModel.py: 31%

48 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-21 09:19 +0100

1import logging 

2 

3import numpy as np 

4 

5from pastml.models import ModelWithFrequencies 

6from pastml.models.generator import get_diagonalisation, get_pij_matrix 

7 

# Name of the custom-rates model: assigned to CustomRatesModel.name and quoted in its constructor's error message.
8CUSTOM_RATES = 'CUSTOM_RATES' 

9 

10 

def load_custom_rates(infile):
    """
    Loads state names and a symmetric rate matrix from a text file.

    The first line of the file must start with '#' and list the state names separated by single spaces.
    The remaining (non-comment) lines must contain a square, symmetric matrix of space-separated rates.
    The diagonal of the matrix is ignored (it is reset to zero).

    :param infile: path to the rate matrix file
    :type infile: str
    :return: the state names sorted alphabetically, and the rate matrix
        whose rows and columns are reordered to match the sorted states
    :rtype: tuple(np.array, np.array)
    :raises ValueError: if the matrix is not square or not symmetric, if the first line does not start with '#',
        or if the number of state names differs from the matrix dimension
    """
    rate_matrix = np.loadtxt(infile, dtype=np.float64, comments='#', delimiter=' ')
    if rate_matrix.ndim != 2 or rate_matrix.shape[0] != rate_matrix.shape[1]:
        # The shape holds ints, so each dimension must be converted before joining.
        raise ValueError('The input rate matrix must be squared, but yours is {}.'
                         .format('x'.join(str(dim) for dim in rate_matrix.shape)))
    if not np.all(rate_matrix == rate_matrix.transpose()):
        raise ValueError('The input rate matrix must be symmetric, but yours is not.')
    np.fill_diagonal(rate_matrix, 0)
    n = len(rate_matrix)
    # With a zeroed diagonal, a matrix without zero rates has exactly n * (n - 1) non-zero entries.
    if np.count_nonzero(rate_matrix) != n * (n - 1):
        logging.getLogger('pastml').warning('The rate matrix contains zero rates (apart from the diagonal).')
    with open(infile, 'r') as f:
        # Only the header line is needed, so do not read the rest of the file.
        header = f.readline()
    if not header.startswith('#'):
        raise ValueError('The rate matrix file should start with state names, '
                         'separated by whitespaces and preceded by # .')
    states = np.array(header.strip('#').strip('\n').strip().split(' '), dtype=str)
    if len(states) != n:
        raise ValueError(
            'The number of specified state names ({}) does not correspond to the rate matrix dimensions ({}x{}).'
            .format(len(states), *rate_matrix.shape))
    # Sort the states alphabetically and permute the matrix columns and rows accordingly.
    new_order = np.argsort(states)
    return states[new_order], rate_matrix[:, new_order][new_order, :]

33 

34 

class CustomRatesModel(ModelWithFrequencies):
    """
    Model with frequencies whose rate matrix is preset:
    it is either passed directly (together with the states) or loaded from a rate matrix file.
    """

    def __init__(self, forest_stats, sf=None, frequencies=None, rate_matrix_file=None, states=None, rate_matrix=None, tau=0,
                 optimise_tau=False, frequency_smoothing=False, parameter_file=None, reoptimise=False, **kwargs):
        """
        Initialises the parent model, then sets the rate matrix and its diagonalisation.

        :param rate_matrix_file: path to a file readable by load_custom_rates;
            when given, both the states and the rate matrix are taken from it
        :param states: state names, required together with rate_matrix when no rate_matrix_file is given
        :param rate_matrix: rate matrix used when no rate_matrix_file is given
        :raises ValueError: if neither rate_matrix_file nor the pair (rate_matrix, states) is specified

        The other arguments are forwarded unchanged to ModelWithFrequencies.__init__.
        """
        ModelWithFrequencies.__init__(self, forest_stats=forest_stats, states=states,
                                      sf=sf, tau=tau, optimise_tau=optimise_tau,
                                      frequencies=frequencies, frequency_smoothing=frequency_smoothing,
                                      parameter_file=parameter_file, reoptimise=reoptimise,
                                      **kwargs)
        self.name = CUSTOM_RATES
        if rate_matrix_file is not None:
            # The file overrides whatever states were set by the super constructor.
            self._states, self._rate_matrix = load_custom_rates(rate_matrix_file)
        elif rate_matrix is None or states is None:
            raise ValueError('Either the rate matrix file '
                             'or the rate matrix plus the states must be specified for {} model'.format(CUSTOM_RATES))
        else:
            # States are already set in the super constructor
            self._rate_matrix = rate_matrix
        diagonalisation = get_diagonalisation(self.frequencies, self._rate_matrix)
        self.D_DIAGONAL, self.A, self.A_INV = diagonalisation

    @property
    def rate_matrix(self):
        """The preset rate matrix (read-only)."""
        return self._rate_matrix

    @rate_matrix.setter
    def rate_matrix(self, rate_matrix):
        """Always raises NotImplementedError: the rate matrix cannot be reassigned."""
        raise NotImplementedError('The rate matrix is preset and cannot be changed.')

    @ModelWithFrequencies.frequencies.setter
    def frequencies(self, frequencies):
        """
        Updates the frequencies and recomputes the diagonalisation of the rate matrix.

        :raises NotImplementedError: if neither frequency optimisation nor frequency smoothing is enabled
        """
        if not (self._optimise_frequencies or self._frequency_smoothing):
            raise NotImplementedError('The frequencies are preset and cannot be changed.')
        self._frequencies = frequencies
        diagonalisation = get_diagonalisation(frequencies, self.rate_matrix)
        self.D_DIAGONAL, self.A, self.A_INV = diagonalisation

    def get_Pij_t(self, t, *args, **kwargs):
        """
        Evaluates get_pij_matrix at the transformed time, using the stored diagonalisation of the rate matrix.

        :param t: time, passed through self.transform_t before the evaluation
        :return: the value produced by get_pij_matrix
            (presumably the probability matrix of substitutions i->j over time t -- confirm against get_pij_matrix)
        """
        transformed_t = self.transform_t(t)
        return get_pij_matrix(transformed_t, self.D_DIAGONAL, self.A, self.A_INV)

79