Coverage for MPP/kernel.py: 89%

101 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-22 12:22 +0200

1import numpy as np 

2import scipy as scy 

3 

4from . import utils 

5 

6__all__ = [ 

7 "LumpingKernel", 

8 "FeatureKernel", 

9] 

10 

11 

12### MERGING KERNEL ########################################################### 

13 

14 

class LumpingKernel:
    """Kernel for the most probable path (MPP) algorithm.

    This object holds the parameters of the lumping and analyzes the
    full transition matrix of the lumping based on a mask of not yet
    merged states, upon calling.

    Notes
    -----
    The similarity between two states may be composed of a dynamic
    similarity (defined in this object, c.f. parameter
    <similarity>) and / or a geometric similarity, which is
    determined by the feature kernel (passed at call).
    """

    def __init__(self, method="n", param=1, similarity="T"):
        """Initialize LumpingKernel.

        Parameters
        ----------
        method : str
            'n' : Consider <param> most similar options. (default)
            'p' : Consider as many most similar options as needed to
                  represent <param> similarity. For similarity 'T',
                  <param>=0.5 means that at least 50% of the transitions
                  to other states must be considered.
        param : int|float
            for 'n' : Number of most similar options to consider
                      (1 deterministic lumping). (default 1)
            for 'p' : Accumulated similarity threshold for most similar
                      states to consider.
        similarity : str
            - T: Utilize the transition probabilities as dynamic
                 metric. (default)
            - KL: Utilize the Kullback-Leibler divergence between the
                  transition probabilities of the options.
            - none: Utilize only the feature as similarity measure.
        """
        self.method = method
        self.param = param
        self.similarity = similarity

    def __call__(self, full_tmat, states_not_merged, mask, feature_kernel=None):
        """Find the states to be lumped together next.

        The least metastable state is selected and the most similar
        other state is determined. The similarity of two states is
        determined by the parameters of the object and the feature
        kernel.

        Parameters
        ----------
        full_tmat : ndarray
            Square transition matrix; rows are indexed by state and
            filtered with <mask>.
        states_not_merged : ndarray
            Lookup from masked position to state index. It is indexed as
            ``states_not_merged[mask][i][0]``, i.e. presumably of shape
            (n, 1) -- confirm against the caller.
        mask : ndarray of bool
            Marks the states still available for merging. It is modified
            in place: the selected state is set to False.
        feature_kernel : object, optional
            If truthy, ``feature_kernel.apply(row, state, mask)`` is used
            as an additional multiplicative weight.

        Returns
        -------
        tuple
            ``(state, target_state, mask)``.

        Raises
        ------
        ValueError
            If <method> is neither 'n' nor 'p', or if the candidate
            weights contain NaN or sum up to zero.
        """
        # Select state with least self transition probability
        mask_state = np.argmin(np.diag(full_tmat)[mask])
        # Translate the position within the masked view into a state index
        state = states_not_merged[mask][mask_state][0]

        mask[state] = False
        # Out-of-place division: never writes into the caller's matrix and
        # also works for integer-typed input.
        trans_probs = full_tmat[state][mask]
        trans_probs = trans_probs / trans_probs.sum()

        if self.similarity == "KL":
            # Sub-matrix of the remaining states; its diagonal (the self
            # transitions) is overwritten with the selected state's
            # transition probabilities.
            t = full_tmat[mask][:, mask].copy()
            np.fill_diagonal(t, trans_probs)

            # Regularization parameter
            epsilon = 1e-6
            kl = scy.stats.entropy(
                trans_probs + epsilon,
                t + epsilon,
                axis=1,
            )
            dkl = utils.weighting_function(kl)
            if dkl.shape[0] > 1:
                dkl -= dkl.min()
                dkl /= dkl.sum()
            trans_probs = dkl
        elif self.similarity == "none":
            # Uniform dynamic weight. This must stay an array with one
            # entry per candidate: a plain scalar has no shape and broke
            # the option selection below when no feature kernel was given.
            trans_probs = np.ones_like(trans_probs, dtype=float)

        # Apply feature kernel, if there is one
        if feature_kernel:
            feature = feature_kernel.apply(full_tmat[state], state, mask)
            if not isinstance(feature, np.ndarray):
                # Non-array result (the kernel signalling "no information"):
                # fall back to a neutral weight.
                feature = np.array([1.0])
            trans_probs = trans_probs * feature

        trans_probs = np.nan_to_num(trans_probs, copy=False, nan=1e-6)

        # transitions contains indices for masked tmat, most similar first
        transitions = np.argsort(trans_probs)[::-1]
        options = self._select_options(trans_probs, transitions)

        # Get similarities of the options
        if len(options) > 1:
            p_options = trans_probs[transitions[options]]
        else:
            p_options = np.ones(1)
        if np.isnan(p_options).any():
            raise ValueError(f"p_options contains NaN: {p_options}")
        total = p_options.sum()
        if total == 0:
            raise ValueError(f"sum of p_options is 0: {p_options}")

        # Select the target state
        mask_target_state = np.random.choice(
            transitions[options], p=p_options / total
        )

        target_state = states_not_merged[mask][mask_target_state][0]
        return state, target_state, mask

    def _select_options(self, trans_probs, transitions):
        """Return positions within <transitions> of the candidate targets."""
        # consider n most similar options
        if self.method == "n":
            return list(range(self.param))[: trans_probs.shape[0]]
        # consider as many most similar options until they sum up to param
        if self.method == "p":
            # Accumulate over the *ranked* weights: position k below refers
            # to the k-th most similar candidate, not to the k-th entry of
            # the unsorted weight vector.
            ranked = trans_probs[transitions] / trans_probs.sum()
            n_nonzero = np.count_nonzero(trans_probs)
            options = [0]
            while ranked[options].sum() <= self.param and len(options) < n_nonzero:
                options.append(options[-1] + 1)
            return options
        raise ValueError("Method must be either 'p' or 'n'")

    def __repr__(self):
        return "<class LumpingKernel>"

140 

141 

142# TODO: 

143# Remove similarity entirely from feature kernel 

class FeatureKernel:
    """Per-state feature averages used as geometric similarity.

    For every microstate the mean feature vector and the population are
    stored. Both arrays have ``2 * n_states - 1`` rows: the first
    ``n_states`` rows belong to the microstates (in the sorted order of
    their labels), the remaining rows are filled by `update` for merged
    (intermediate) states.
    """

    def __init__(
        self,
        feature_trajectory,
        microstate_trajectory,
        feature_type=np.float64,
        trajectory_type=np.uint16,
        similarity="JS",
    ):
        """Initialize FeatureKernel.

        Parameters
        ----------
        feature_trajectory : ndarray, shape (N, M)
            N being the number of frames and M the number of features.
            Must be two-dimensional.
        microstate_trajectory : ndarray
            State label per frame; used as a boolean selector on the rows
            of <feature_trajectory>.
        feature_type : dtype
            Type the feature trajectory is cast to.
        trajectory_type : dtype
            Type the microstate trajectory is cast to.
        similarity : str
            Currently unused; kept for interface compatibility.

        Raises
        ------
        ValueError
            If <feature_trajectory> is not two-dimensional.
        """
        if feature_trajectory.ndim != 2:
            raise ValueError("featuretrajectory must be a 2 D array.")
        self.feature_trajectory = feature_trajectory.astype(feature_type)

        self._init_feature(microstate_trajectory.astype(trajectory_type))

    def __repr__(self):
        return "<class FeatureKernel>"

    def _init_feature(self, microstate_trajectory):
        """Compute populations and mean features of all microstates."""
        states, pop = np.unique(microstate_trajectory, return_counts=True)
        self.n_states = states.shape[0]
        # Populations for all states incl intermediate states
        self.full_pop = np.zeros(2 * self.n_states - 1, dtype=np.uint32)
        self.full_pop[: self.n_states] = pop
        # corresponding feature values
        self.full_feature = np.zeros(
            (2 * self.n_states - 1, self.feature_trajectory.shape[1]),
            dtype=self.feature_trajectory.dtype.type,
        )
        # Iterate over the labels that actually occur so that row i of
        # full_feature matches row i of full_pop for any labelling
        # (1-based, 0-based or with gaps), not only for labels 1..n.
        for i, label in enumerate(states):
            self.full_feature[i] = self.feature_trajectory[
                microstate_trajectory == label
            ].mean(axis=0)

    def reset(self):
        """Clear populations and features of all merged states."""
        self.full_pop[self.n_states :] = 0
        self.full_feature[self.n_states :] = 0

    def apply(self, trans_prob, state, mask):
        """Return normalized feature weights of <state> vs. masked states.

        The weights from `js` are shifted by their minimum and divided by
        their sum. If that sum is zero, the integer 0 is returned instead
        of an array. <trans_prob> is not used.
        """
        f = self.js(state, mask)
        f -= f.min()
        total = f.sum()
        if total != 0:
            return f / total
        return 0

    def update(self, origin, target, new_state):
        """Store the merge of <origin> and <target> in row <new_state>.

        The population is the sum of both populations, the feature is
        the population-weighted mean of both feature vectors.
        """
        self.full_pop[new_state] = self.full_pop[[origin, target]].sum()
        self.full_feature[new_state] = (
            self.full_feature[origin] * self.full_pop[origin]
            + self.full_feature[target] * self.full_pop[target]
        ) / self.full_pop[new_state]

    def js(self, state, mask):
        """Weighted squared Jensen-Shannon distance of <state> to masked states."""
        p = self.full_feature[state]
        q = self.full_feature[mask]
        # Make both operands two-dimensional so they broadcast along axis 1
        if p.ndim == 1:
            p = np.expand_dims(p, axis=0)
        if q.ndim == 1:
            q = np.expand_dims(q, axis=0)
        djs = scy.spatial.distance.jensenshannon(p, q, axis=1) ** 2
        return utils.weighting_function(djs)

    def full_feature_from_Z(self, Z):
        """Rebuild the features of all merged states from merge records.

        Parameters
        ----------
        Z : ndarray
            Either one record set (2D) or a stack of them (3D). The first
            two columns of each row are read as (origin, target) indices;
            row i creates state ``n_states + i``.

        Returns
        -------
        ndarray, shape (runs, 2 * n_states - 1, M)
            Feature values per run. Also stored as ``self.n_full_feature``.
            ``full_pop`` / ``full_feature`` are left in the state of the
            last run.
        """
        # Ensure that Z is 3D
        if Z.ndim == 2:
            Z = Z.reshape((1, *Z.shape))

        full_dim = 2 * self.n_states - 1

        self.n_full_feature = np.empty(
            (Z.shape[0], full_dim, self.feature_trajectory.shape[1])
        )
        self.n_full_feature[:, : self.n_states] = self.full_feature[: self.n_states]
        for run, z in enumerate(Z):
            self.reset()
            for i, (origin, target) in enumerate(z[:, :2].astype(int)):
                self.update(origin, target, self.n_states + i)
            self.n_full_feature[run, self.n_states :] = self.full_feature[
                self.n_states :
            ]
        return self.n_full_feature