Coverage for MPT/kernel.py: 91%

97 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-09 15:48 +0200

1import numpy as np 

2import scipy as scy 

3 

4import MPT.utils as utils 

5 

# Public names exported by ``from MPT.kernel import *``. Every entry must be
# defined in this module, otherwise the star-import raises AttributeError.
__all__ = [
    "MPTKernel",
    "MultiFeatureKernel",
]

10 

11 

12### MERGING KERNEL ########################################################### 

13 

14 

class MPTKernel(object):
    def __init__(self, method="n", param=1, similarity="T"):
        """
        Kernel for the most probable path (MPP) algorithm.

        This object holds the parameters of the lumping and analyzes the
        full transition matrix of the lumping based on a mask of not yet
        merged states, upon calling.

        Parameters
        ----------
        method : str
            'n' : Consider <param> most similar options.
            'p' : Consider as many most similar options as needed to
                  represent <param> similarity. For similarity 'T',
                  <param>=0.5 means that at least 50% of the transitions
                  to other states must be considered.
        param : int|float
            for 'n' : Number of most similar options to consider
                      (1 deterministic lumping).
            for 'p' : Accumulated similarity threshold for most similar
                      states to consider.
        similarity : str
            - T: Utilize the transition probabilities as dynamic
              metric.
            - KL: Utilize the Kullback-Leibler divergence between the
              transition probabilities of the options.
            - none: Utilize only the feature as similarity measure
              (all candidates get the same dynamic weight).

        Notes
        -----
        The similarity between two states may be composed of a dynamic
        similarity (defined in this object, c.f. parameter
        <similarity>) and / or a geometric similarity, which is
        determined by the feature kernel (passed at call).
        """
        self.method = method
        self.param = param
        self.similarity = similarity

    def __call__(self, full_tmat, states_not_merged, mask, feature_kernel=None):
        """
        Finds the states to be lumped together next.

        The least metastable state is selected and the most similar
        other state is determined. The similarity of two states is
        determined by the parameters of the object and the feature
        kernel.

        Parameters
        ----------
        full_tmat : ndarray
            Square transition matrix; row ``i`` holds the transition
            probabilities of state ``i``.
        states_not_merged : ndarray
            Indexed with ``mask`` along its first axis; the first entry
            of each selected row is used as state index
            (presumably shape (n, 1) -- confirm against caller).
        mask : ndarray of bool
            Marks the states that are still merge candidates.
            **Modified in place**: the selected state is set to False.
        feature_kernel : object, optional
            If given, ``feature_kernel.apply(full_tmat[state], state, mask)``
            is multiplied onto the dynamic similarity. A non-ndarray
            return value is treated as "no geometric preference".

        Returns
        -------
        state
            Index of the selected (least metastable) state.
        target_state
            Index of the state it should be merged with.
        mask : ndarray of bool
            The same object that was passed in, with ``state`` unset.

        Raises
        ------
        ValueError
            If ``self.method`` is neither 'n' nor 'p'.
        """
        # Select state with least self transition probability
        mask_state = np.argmin(np.diag(full_tmat)[mask])
        # Translate the position within the masked view into a state index
        state = states_not_merged[mask][mask_state][0]

        # The selected state is not a candidate for its own merge
        mask[state] = False
        # Indexing with the mask yields a copy, so the in-place
        # normalisation below leaves full_tmat untouched.
        trans_probs = full_tmat[state][mask]
        trans_probs /= trans_probs.sum()

        # If Kullback-Leibler divergence is used
        if self.similarity == "KL":
            # Transition matrix restricted to the remaining candidates;
            # its diagonal is overwritten with the selected state's
            # transition probabilities.
            t = full_tmat[mask][:, mask].copy()
            np.fill_diagonal(t, trans_probs)

            # Regularization parameter (avoids log(0) / division by zero)
            epsilon = 1e-6
            kl = scy.stats.entropy(
                trans_probs + epsilon,
                t + epsilon,
                axis=1,
            )
            dkl = utils.weighting_function(kl)
            if dkl.shape[0] > 1:
                # With a single candidate this shift would zero it out
                dkl -= dkl.min()
            dkl /= dkl.sum()
            trans_probs = dkl
        elif self.similarity == "none":
            # Uniform dynamic weight for every remaining candidate. An
            # array (instead of the scalar 1) keeps the candidate count
            # intact even without a feature kernel or when the feature
            # kernel reports no preference.
            trans_probs = np.ones_like(trans_probs)

        # Apply feature kernel, if there is one
        if feature_kernel:
            feature = feature_kernel.apply(full_tmat[state], state, mask)
            if not isinstance(feature, np.ndarray):
                # Kernel signalled "all candidates equal": neutral factor
                feature = np.array([1.0])
            trans_probs *= feature

        trans_probs = np.nan_to_num(trans_probs, copy=False, nan=1e-6)

        # transitions contains indices for masked tmat, most similar first
        transitions = np.argsort(trans_probs)[::-1]
        # consider n most similar options
        if self.method == "n":
            options = list(range(self.param))[: trans_probs.shape[0]]
        # consider as many most probable options until they sum up to param
        elif self.method == "p":
            # `options` are ranks into `transitions`, so the accumulated
            # similarity has to be taken over the *ranked* values.
            ranked_norm = trans_probs[transitions] / trans_probs.sum()
            n_candidates = np.count_nonzero(trans_probs)
            options = [0]
            while (
                ranked_norm[options].sum() <= self.param
                and len(options) < n_candidates
            ):
                options.append(options[-1] + 1)
        else:
            raise ValueError("Method must be either 'p' or 'n'")

        # Get similarities of the options
        p_options = trans_probs[transitions[options]]

        # Select the target state, weighted by similarity
        mask_target_state = np.random.choice(
            transitions[options], p=p_options / p_options.sum()
        )

        target_state = states_not_merged[mask][mask_target_state][0]
        return state, target_state, mask

    def __repr__(self):
        return "<class MPTKernel>"

132 

133 

134### FEATURE KERNEL ########################################################### 

135 

136 

137# TODO: 

138# Remove similarity entirely from feature kernel 

class MultiFeatureKernel(object):
    def __init__(
        self,
        feature_traj,
        microstate_traj,
        feature_type=np.float64,
        traj_type=np.uint16,
        similarity="JS",
    ):
        """
        Geometric (feature based) similarity kernel.

        Parameters
        ----------
        feature_traj : ndarray, shape (N, M)
            Feature values, N being the number of frames and M the
            number of features. Must be two-dimensional.
        microstate_traj : ndarray
            State label of every frame; used to select rows of
            ``feature_traj``.
        feature_type : dtype
            dtype the feature trajectory is cast to.
        traj_type : dtype
            dtype the microstate trajectory is cast to.
        similarity : str
            Currently unused; kept for interface compatibility.

        Raises
        ------
        ValueError
            If ``feature_traj`` is not two-dimensional.
        """
        if feature_traj.ndim != 2:
            raise ValueError("featuretraj must be a 2 D array.")
        self.feature_traj = feature_traj.astype(feature_type)

        self._init_feature(microstate_traj.astype(traj_type))

    def __repr__(self):
        return "<class MultiFeatureKernel>"

    def _init_feature(self, microstate_traj):
        """
        Set up populations and mean feature vectors of the initial states.

        Both arrays are sized ``2 * n_states - 1`` so that the states
        created by later merges (see `update`) have a slot as well.
        """
        states, pop = np.unique(microstate_traj, return_counts=True)
        self.n_states = states.shape[0]
        # Populations for all states incl intermediate states
        self.full_pop = np.zeros(2 * self.n_states - 1, dtype=np.uint32)
        self.full_pop[: self.n_states] = pop
        # corresponding feature values
        self.full_feature = np.zeros(
            (2 * self.n_states - 1, self.feature_traj.shape[1]),
            dtype=self.feature_traj.dtype.type,
        )
        # Iterate over the labels actually present so that row i of
        # full_feature always belongs to the same state as full_pop[i],
        # regardless of how the states are numbered.
        for i, label in enumerate(states):
            self.full_feature[i] = self.feature_traj[microstate_traj == label].mean(
                axis=0
            )

    def reset(self):
        """Clear populations and features of all merged (non-initial) states."""
        self.full_pop[self.n_states :] = 0
        self.full_feature[self.n_states :] = 0

    def apply(self, trans_prob, state, mask):
        """
        Return the normalised feature similarity of `state` to the
        states selected by `mask`.

        Returns an ndarray summing to one, or the integer ``0`` if all
        candidates are equally similar (callers check for a non-array
        return value). ``trans_prob`` is not used.
        """
        f = self.js(state, mask)
        f -= f.min()
        total = f.sum()
        if total != 0:
            return f / total
        return 0

    def update(self, origin, target, new_state):
        """
        Register the merge of `origin` and `target` into `new_state`.

        The new population is the sum of both populations, the new
        feature vector their population-weighted mean.
        """
        self.full_pop[new_state] = self.full_pop[[origin, target]].sum()
        self.full_feature[new_state] = (
            self.full_feature[origin] * self.full_pop[origin]
            + self.full_feature[target] * self.full_pop[target]
        ) / self.full_pop[new_state]

    def js(self, state, mask):
        """
        Weighted squared Jensen-Shannon distance between the feature
        vector of `state` and the feature vectors selected by `mask`.
        """
        p = self.full_feature[state]
        q = self.full_feature[mask]
        # Make both operands 2D so that axis=1 is the feature axis
        if p.ndim == 1:
            p = np.expand_dims(p, axis=0)
        if q.ndim == 1:
            q = np.expand_dims(q, axis=0)
        djs = scy.spatial.distance.jensenshannon(p, q, axis=1) ** 2
        return utils.weighting_function(djs)

    def full_feature_from_Z(self, Z):
        """
        Replay the merges stored in `Z` and collect the feature vectors
        of all initial and merged states.

        Parameters
        ----------
        Z : ndarray, 2D or 3D
            One merge history (2D) or a stack of histories (3D). Per
            merge, the first two columns are read as origin and target
            state index (presumably a linkage-style matrix -- confirm
            against caller).

        Returns
        -------
        ndarray, shape (n_runs, 2 * n_states - 1, n_features)
            Also stored as ``self.n_full_feature``. The merged-state
            slots of ``full_pop`` / ``full_feature`` are overwritten
            in the process.
        """
        # Ensure that Z is 3D
        if Z.ndim == 2:
            Z = Z.reshape((1, *Z.shape))

        full_dim = 2 * self.n_states - 1

        self.n_full_feature = np.empty(
            (Z.shape[0], full_dim, self.feature_traj.shape[1])
        )
        # Initial states are identical for every run
        self.n_full_feature[:, : self.n_states] = self.full_feature[: self.n_states]
        for run, z in enumerate(Z):
            self.reset()
            for i, (origin, target) in enumerate(z[:, :2].astype(int)):
                self.update(origin, target, self.n_states + i)
            self.n_full_feature[run, self.n_states :] = self.full_feature[
                self.n_states :
            ]
        return self.n_full_feature