Coverage for MPT/kernel.py: 91%
97 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-09 15:48 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-09 15:48 +0200
1import numpy as np
2import scipy as scy
4import MPT.utils as utils
# Public API of this module. Every name listed here must be defined in this
# file, otherwise ``from MPT.kernel import *`` raises AttributeError.
__all__ = [
    "MPTKernel",
    "MultiFeatureKernel",
]
12### MERGING KERNEL ###########################################################
class MPTKernel(object):
    """Merging kernel that picks the next pair of states to lump together."""

    def __init__(self, method="n", param=1, similarity="T"):
        """
        Kernel for the most probable path (MPP) algorithm.

        This object holds the parameters of the lumping and analyzes the
        full transition matrix of the lumping based on a mask of not yet
        merged states, upon calling.

        Parameters
        ----------
        method : str
            'n' : Consider <param> most similar options.
            'p' : Consider as many most similar options as needed to
                  represent <param> similarity. For similarity 'T',
                  <param>=0.5 means that at least 50% of the transitions
                  to other states must be considered.
        param : int|float
            for 'n' : Number of most similar options to consider
                      (1 gives a deterministic lumping).
            for 'p' : Accumulated similarity threshold for most similar
                      states to consider.
        similarity : str
            - T: Utilize the transition probabilities as dynamic
                 metric.
            - KL: Utilize the Kullback-Leibler divergence between the
                  transition probabilities of the options.
            - none: Utilize only the feature as similarity measure.
            Any other value is treated like 'T'.

        Notes
        -----
        The similarity between two states may be composed of a dynamic
        similarity (defined in this object, c.f. parameter
        <similarity>) and / or a geometric similarity, which is
        determined by the feature kernel (passed at call).
        """
        self.method = method
        self.param = param
        self.similarity = similarity

    def __call__(self, full_tmat, states_not_merged, mask, feature_kernel=None):
        """
        Finds the states to be lumped together next.

        The least metastable state is selected and the most similar
        other state is determined. The similarity of two states is
        determined by the parameters of the object and the feature
        kernel.

        Parameters
        ----------
        full_tmat : ndarray
            Square transition matrix; its diagonal holds the self
            transition probabilities. It is not modified.
        states_not_merged : ndarray
            Lookup from mask position to state index. It is indexed as
            ``states_not_merged[mask][i][0]``, so each row must carry
            the state index in its first element
            (presumably shape (N, 1) -- TODO confirm against caller).
        mask : ndarray of bool
            True for every state that is still available. The entry of
            the selected state is set to False IN PLACE.
        feature_kernel : object, optional
            If truthy, ``feature_kernel.apply(row, state, mask)`` is
            multiplied onto the dynamic similarity. A non-array return
            value is treated as "no geometric preference".

        Returns
        -------
        state, target_state, mask
            The selected state, the state it is merged with, and the
            (mutated) mask that was passed in.

        Raises
        ------
        ValueError
            If ``self.method`` is neither 'n' nor 'p'.
        """
        # Select state with least self transition probability
        mask_state = np.argmin(np.diag(full_tmat)[mask])
        # Get correct state index
        state = states_not_merged[mask][mask_state][0]

        # The selected state must not be its own merge partner.
        mask[state] = False
        # Boolean indexing returns a copy, so normalizing in place does not
        # touch full_tmat.
        trans_probs = full_tmat[state][mask]
        trans_probs /= trans_probs.sum()

        # If Kullback-Leibler divergence is used
        if self.similarity == "KL":
            # Replace the self transition probabilities of the candidates
            # by the transition probabilities of the selected state.
            t = full_tmat[mask][:, mask].copy()
            np.fill_diagonal(t, trans_probs)

            # Regularization parameter (avoids zeros inside the logarithm)
            epsilon = 1e-6
            kl = scy.stats.entropy(
                trans_probs + epsilon,
                t + epsilon,
                axis=1,
            )
            dkl = utils.weighting_function(kl)
            if dkl.shape[0] > 1:
                dkl -= dkl.min()
            # A zero sum yields NaN here; it is replaced further below.
            dkl /= dkl.sum()
            trans_probs = dkl
        elif self.similarity == "none":
            # Neutral factor: only the feature kernel decides.
            trans_probs = 1

        # Apply feature kernel, if there is one
        if feature_kernel:
            feature = feature_kernel.apply(full_tmat[state], state, mask)
            if not isinstance(feature, np.ndarray):
                feature = np.array([1.0])
            trans_probs *= feature

        trans_probs = np.nan_to_num(trans_probs, copy=False, nan=1e-6)

        # `ranking` contains positions in the masked arrays, most similar
        # candidate first.
        ranking = np.argsort(trans_probs)[::-1]
        # consider n most similar options
        if self.method == "n":
            options = list(range(self.param))[: trans_probs.shape[0]]
        # consider as many most similar options until they sum up to param
        elif self.method == "p":
            # The accumulated share has to be taken over the RANKED values,
            # because `options` are positions in `ranking`, not in
            # `trans_probs`.
            ranked_share = trans_probs[ranking] / trans_probs.sum()
            n_candidates = np.count_nonzero(trans_probs)
            options = [0]
            while (
                len(options) < n_candidates
                and ranked_share[options].sum() <= self.param
            ):
                options.append(options[-1] + 1)
        else:
            raise ValueError("Method must be either 'p' or 'n'")

        # Get similarities of the options
        candidates = ranking[options]
        weights = trans_probs[candidates]
        total = weights.sum()
        # If every candidate has zero similarity the weights cannot be
        # normalized; choose uniformly instead of passing NaNs to numpy.
        probabilities = weights / total if total != 0 else None

        # Select the target state
        mask_target_state = np.random.choice(candidates, p=probabilities)

        target_state = states_not_merged[mask][mask_target_state][0]
        return state, target_state, mask

    def __repr__(self):
        return "<class MPTKernel>"
134### FEATURE KERNEL ###########################################################
137# TODO:
138# Remove similarity entirely from feature kernel
class MultiFeatureKernel(object):
    """Geometric similarity between states based on per-state mean features."""

    def __init__(
        self,
        feature_traj,
        microstate_traj,
        feature_type=np.float64,
        traj_type=np.uint16,
        similarity="JS",
    ):
        """
        Parameters
        ----------
        feature_traj : ndarray
            N x M array, N being the number of frames and M the number
            of features. A 1-D array is rejected.
        microstate_traj : ndarray
            State label of every frame; the labels are used to
            boolean-index the rows of ``feature_traj``.
        feature_type : dtype
            Type the feature trajectory is cast to.
        traj_type : dtype
            Type the microstate trajectory is cast to.
        similarity : str
            Currently unused; kept for interface compatibility.

        Raises
        ------
        ValueError
            If ``feature_traj`` is not two-dimensional.
        """
        if feature_traj.ndim != 2:
            raise ValueError("featuretraj must be a 2 D array.")
        self.feature_traj = feature_traj.astype(feature_type)

        self._init_feature(microstate_traj.astype(traj_type))

    def __repr__(self):
        return "<class MultiFeatureKernel>"

    def _init_feature(self, microstate_traj):
        """Compute population and mean feature vector of every microstate.

        Row ``i`` of ``full_pop`` / ``full_feature`` belongs to the i-th
        smallest state label. The arrays have ``2 * n_states - 1`` rows:
        the first ``n_states`` rows are the microstates, the remaining
        rows are reserved for merged (intermediate) states and start at 0.
        """
        states, pop = np.unique(microstate_traj, return_counts=True)
        self.n_states = states.shape[0]
        # Populations for all states incl intermediate states
        self.full_pop = np.zeros(2 * self.n_states - 1, dtype=np.uint32)
        self.full_pop[: self.n_states] = pop
        # corresponding feature values
        self.full_feature = np.zeros(
            (2 * self.n_states - 1, self.feature_traj.shape[1]),
            dtype=self.feature_traj.dtype.type,
        )
        # Use the labels found by np.unique so that features and
        # populations stay aligned for any labelling (0-based, 1-based,
        # with gaps), not only for labels 1..n_states.
        for i, label in enumerate(states):
            self.full_feature[i] = self.feature_traj[microstate_traj == label].mean(
                axis=0
            )

    def reset(self):
        """Clear population and feature of all intermediate states."""
        self.full_pop[self.n_states :] = 0
        self.full_feature[self.n_states :] = 0

    def apply(self, trans_prob, state, mask):
        """Return the normalized feature similarity of `state` to the masked states.

        ``trans_prob`` is accepted for interface reasons but not used.
        The similarities are shifted so that the smallest one is zero and
        scaled to sum to one. If all similarities are equal, the integer
        ``0`` is returned instead of an array.
        """
        f = self.js(state, mask)
        f -= f.min()
        if f.sum() != 0:
            return f / f.sum()
        else:
            return 0

    def update(self, origin, target, new_state):
        """Store the merge of `origin` and `target` as `new_state`.

        The new population is the sum of both populations and the new
        feature is the population weighted mean of both features.
        """
        self.full_pop[new_state] = self.full_pop[[origin, target]].sum()
        self.full_feature[new_state] = (
            self.full_feature[origin] * self.full_pop[origin]
            + self.full_feature[target] * self.full_pop[target]
        ) / self.full_pop[new_state]

    def js(self, state, mask):
        """Weighted squared Jensen-Shannon distance of `state` to the masked states.

        The squared distances are passed through ``utils.weighting_function``
        (defined outside this file) before being returned.
        """
        p = self.full_feature[state]
        q = self.full_feature[mask]
        # Bring both operands to 2-D so that axis=1 is the feature axis.
        if p.ndim == 1:
            p = np.expand_dims(p, axis=0)
        if q.ndim == 1:
            q = np.expand_dims(q, axis=0)
        djs = scy.spatial.distance.jensenshannon(p, q, axis=1) ** 2
        return utils.weighting_function(djs)

    def full_feature_from_Z(self, Z):
        """Replay the merges in `Z` and return the features of all states.

        Parameters
        ----------
        Z : ndarray
            One merge table (2-D) or a stack of merge tables (3-D). The
            first two columns of every row are read as the indices of
            the two merged states; row ``i`` creates state
            ``n_states + i``.

        Returns
        -------
        ndarray
            Array of shape ``(n_runs, 2 * n_states - 1, n_features)``,
            also stored as ``self.n_full_feature``.

        Notes
        -----
        The intermediate entries of ``full_pop`` and ``full_feature`` are
        reset and overwritten for every run.
        """
        # Ensure that Z is 3D
        if Z.ndim == 2:
            Z = Z.reshape((1, *Z.shape))

        full_dim = 2 * self.n_states - 1

        self.n_full_feature = np.empty(
            (Z.shape[0], full_dim, self.feature_traj.shape[1])
        )
        # The microstate rows are identical for every run.
        self.n_full_feature[:, : self.n_states] = self.full_feature[: self.n_states]
        for run, z in enumerate(Z):
            self.reset()
            for i, (origin, target) in enumerate(z[:, :2].astype(int)):
                self.update(origin, target, self.n_states + i)
            self.n_full_feature[run, self.n_states :] = self.full_feature[
                self.n_states :
            ]
        return self.n_full_feature