Coverage for MPT/core.py: 90%
285 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-08 16:59 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-08 16:59 +0200
"""
core.py
=======

Core building blocks of the MPT package: the ``cluster`` routine, which
hierarchically merges the states of a transition matrix, and the
``BinaryTreeNode`` class, used to represent and plot the resulting
dendrogram.
"""

# Only ``cluster`` is exported by ``from MPT.core import *``.
__all__ = ["cluster"]
12import sys
13import numpy as np
14from typing import Callable
15from numpy.typing import NDArray
16import matplotlib.pyplot as plt
17from matplotlib.colors import Normalize
19from anytree import NodeMixin
20from anytree.iterators import PreOrderIter
22import MPT.utils as utils
23import MPT.kernel as kern
# Recursive tree properties (e.g. BinaryTreeNode.population) walk the whole
# dendrogram, so deep trees need more than Python's default recursion limit.
# Only ever *raise* the limit: importing this module must not lower a limit
# that the application has already increased.
sys.setrecursionlimit(max(sys.getrecursionlimit(), 2020))
class BinaryTreeNode(NodeMixin):
    """
    Node of a binary merge tree (dendrogram).

    Leaves represent microstates; every inner node represents the state
    obtained by merging its ``left`` and ``right`` children. Besides the tree
    structure, the node provides macrostate detection and the coordinates and
    colors needed to draw a dendrogram with matplotlib.
    """

    def __init__(
        self,
        name,
        tmat,
        population=0,
        q=0,
        feature=0,
        macrostate_thresholds=(0.005, 0.5),
        parent=None,
        left=None,
        right=None,
    ):
        """
        This class is used to plot dendrograms.

        parameters:
        -----------
        name: name of the node. For leaves it is also used as a row/column
            index into ``tmat`` when assigning macrostates.
        tmat (np.ndarray): transition matrix shared by the tree. Presumably
            the full (2n - 1, 2n - 1) matrix, since ``n_states`` is derived
            as (rows + 1) / 2 -- TODO confirm with the caller.
        population (float): population of the node; only used while the
            node is a leaf (inner nodes sum their children)
        q (float): value at which the node is merged, 0 <= q <= 1
        feature (float): feature used for coloring, 0 <= feature <= 1; only
            used while the node is a leaf (inner nodes report the
            population-weighted mean of their children)
        macrostate_thresholds (tuple): (pop_thr, q_min). A node is a
            macrostate candidate if its parent's q is >= q_min and both it
            and its sibling hold >= pop_thr of the root population.
        parent: parent node
        left: left child node
        right: right child node
        """
        # Child slots; must exist before any property that checks is_leaf.
        self._left = None
        self._right = None
        # Lazily computed caches (see the corresponding properties).
        self._is_macrostate = None
        self._macrostates = None
        self._all_macrostates = None
        self._parent_macrostate = None
        self._assigned_macrostate = None

        self.name = name
        self.tmat = tmat
        self.n_states = int((self.tmat.shape[0] + 1) / 2)
        # Set while the node is still a leaf (children are attached below).
        self.population = population
        self.q = q
        self.feature = feature
        self.pop_thr, self.q_min = macrostate_thresholds
        self.parent = parent
        self.left = left
        self.right = right

        # Dendrogram coordinates, filled by plot_tree or computed lazily.
        self._x_origin = None
        self._x_target = None
        self._y_origin = None

        # Plotting caches, only populated on the root node.
        self._bins = None
        self._feature_norm = None
        self._colors = None

    @property
    def population(self):
        """Population of the state: stored value for leaves, sum of children otherwise."""
        if self.is_leaf:
            return self._population
        left_pop = self.left.population if self.left is not None else 0
        right_pop = self.right.population if self.right is not None else 0
        return left_pop + right_pop

    @population.setter
    def population(self, value):
        # Inner-node populations are derived from the children, so storing a
        # value there would be silently ignored -> reject it explicitly.
        if not self.is_leaf:
            raise ValueError("population can only be set for microstates (leaves)")
        self._population = value

    @property
    def q(self):
        """Q, e. g. self transition probability at which states were merged."""
        return self._q

    @q.setter
    def q(self, value):
        if not 0 <= value <= 1:
            raise ValueError("q must be 0 <= q <= 1")
        self._q = value

    @property
    def feature(self):
        """
        Feature for states (e. g. fraction of native contacts). Inner nodes
        report the population-weighted mean of their children's features.
        """
        if self.is_leaf:
            return self._feature
        weighted = 0
        if self.left is not None:
            weighted += self.left.feature * self.left.population
        if self.right is not None:
            weighted += self.right.feature * self.right.population
        return weighted / self.population

    @feature.setter
    def feature(self, value):
        if not 0 <= value <= 1:
            raise ValueError("feature must be 0 <= feature <= 1")
        self._feature = value

    @property
    def left(self):
        """Left child node (or None)."""
        return self._left

    @left.setter
    def left(self, node):
        if node is not None and node.parent is not None:
            raise ValueError("Node already has a parent")
        # Detach the previous child before replacing it.
        if self._left is not None:
            self._left.parent = None
        self._left = node
        if node is not None:
            node.parent = self

    @property
    def right(self):
        """Right child node (or None)."""
        return self._right

    @right.setter
    def right(self, node):
        if node is not None and node.parent is not None:
            raise ValueError("Node already has a parent")
        # Detach the previous child before replacing it.
        if self._right is not None:
            self._right.parent = None
        self._right = node
        if node is not None:
            node.parent = self

    @property
    def children(self):
        """Return the existing child nodes, left first."""
        return [child for child in (self.left, self.right) if child is not None]

    @property
    def is_leaf(self):
        """Check if this node is leaf node (has neither a left nor a right child)."""
        return self.left is None and self.right is None

    @property
    def is_macrostate(self):
        """
        Whether this node is a macrostate (computed once and cached).

        The root is always a macrostate. Any other node is one if its
        parent merged at q >= q_min and both it and its sibling hold at
        least pop_thr of the root population; the sibling is then marked
        as well.
        """
        if self._is_macrostate is None:
            if (
                self.parent is not None
                and self.parent.q >= self.q_min
                and self.population >= self.root.population * self.pop_thr
                and self.siblings[0].population >= self.root.population * self.pop_thr
            ):
                self._is_macrostate = True
                self.siblings[0].is_macrostate = True
            elif self.parent is None:
                self._is_macrostate = True
            else:
                self.is_macrostate = False
        return self._is_macrostate

    @is_macrostate.setter
    def is_macrostate(self, value):
        if not isinstance(value, bool):
            raise ValueError("is_macrostate must be boolean")
        self._is_macrostate = value

    @property
    def macrostates(self):
        """
        Macrostates in this subtree that contain no further macrostates,
        i.e. the finest macrostates below (and including) this node.
        """
        if self._macrostates is None:
            self._macrostates = tuple(
                macrostate
                for macrostate in self.all_macrostates
                if len(macrostate.all_macrostates) == 1
            )
        return self._macrostates

    @property
    def all_macrostates(self):
        """All macrostate nodes of this subtree (pre-order, including this node)."""
        if self._all_macrostates is None:
            self._all_macrostates = tuple(
                PreOrderIter(self, filter_=lambda node: node.is_macrostate)
            )
        return self._all_macrostates

    @property
    def parent_macrostate(self):
        """Closest proper ancestor that is a macrostate (None for the root)."""
        if self._parent_macrostate is None:
            parent = self.parent
            while parent is not None and not parent.is_macrostate:
                parent = parent.parent
            self._parent_macrostate = parent
        return self._parent_macrostate

    @property
    def assigned_macrostate(self):
        """
        Macrostate a leaf is assigned to (None for inner nodes).

        A macrostate leaf is assigned to itself. Otherwise, if the parent
        macrostate contains a single finest macrostate, that one is used;
        if not, the leaf goes to the finest macrostate it has the highest
        transition probability into after merging that macrostate's leaves.
        """
        if self._assigned_macrostate is None:
            if self.is_leaf:
                if self.is_macrostate:
                    self._assigned_macrostate = self
                elif len(self.parent_macrostate.macrostates) == 1:
                    self._assigned_macrostate = self.parent_macrostate
                else:
                    trans_probs = []
                    for m in self.parent_macrostate.macrostates:
                        # Columns: leaf name (tmat index), leaf population.
                        macrostate = np.array(
                            [(s.name, s.population) for s in m.leaves]
                        )
                        indices = list(macrostate[:, 0])
                        indices.append(self.name)
                        # Extra slot used as the merge target (index -1);
                        # presumably overwritten by utils.merge_states.
                        indices.append(0)
                        tmp_tmat = self.tmat[np.ix_(indices, indices)].copy()
                        pops = list(macrostate[:, 1])
                        pops.append(self.population)
                        pops.append(0)
                        tmp_tmat, pops = utils.merge_states(
                            tmp_tmat,
                            list(range(macrostate.shape[0])),
                            -1,
                            np.array(pops),
                        )
                        # Transition from this leaf (-2) into the merged macrostate (-1).
                        trans_probs.append(tmp_tmat[-2, -1])
                    self._assigned_macrostate = self.parent_macrostate.macrostates[
                        np.argmax(trans_probs)
                    ]
            else:
                self._assigned_macrostate = None
        return self._assigned_macrostate

    @property
    def bins(self):
        """Eleven equally spaced edges spanning the leaf features (stored on the root)."""
        if not self.is_root:
            return self.root.bins
        if self._bins is None:
            leaf_features = [leaf.feature for leaf in self.leaves]
            self._bins = np.linspace(min(leaf_features), max(leaf_features), 11)
        return self._bins

    @property
    def feature_norm(self):
        """Normalize mapping the feature range of ``bins`` onto [0, 1] (stored on the root)."""
        if not self.is_root:
            return self.root.feature_norm
        if self._feature_norm is None:
            self._feature_norm = Normalize(self.bins[0], self.bins[-1])
        return self._feature_norm

    @property
    def colors(self):
        """Ten colors sampled from the ``plasma_r`` colormap (stored on the root)."""
        if not self.is_root:
            return self.root.colors
        if self._colors is None:
            cmap = plt.get_cmap("plasma_r", 10)
            self._colors = [cmap(idx) for idx in range(cmap.N)]
        return self._colors

    @property
    def color(self):
        """Color according to the normalized feature; "k" if no bin matches."""
        value = self.feature_norm(self.feature)
        for color, rlower, rhigher in zip(
            self.colors, np.arange(0, 1, 0.1), np.arange(0.1, 1.1, 0.1)
        ):
            if rlower <= value <= rhigher:
                return color
        return "k"

    @property
    def edge_width(self):
        """Edge width proportional to the population fraction (6 for the root)."""
        return 6 * self.population / self.root.population

    @property
    def macrostate(self):
        """
        Macrostate this state belongs to: the node itself or its closest
        macrostate ancestor. None if no macrostates are found above in tree.
        """
        node = self
        while not node.is_macrostate and node.parent is not None:
            node = node.parent
        return node if node.is_macrostate else None

    @property
    def x(self):
        """X coordinates for the dendrogram branch of this node."""
        return np.array([self.x_origin, self.x_origin, self.x_target])

    @property
    def x_origin(self):
        """
        X position where this node's branch starts. Leaves get it assigned
        by ``plot_tree``; inner nodes start at their first child's x_target.
        """
        # ``is None`` so that a cached 0 is not mistaken for "not computed".
        if not self.is_leaf and self._x_origin is None:
            self.x_origin = self.children[0].x_target
        return self._x_origin

    @x_origin.setter
    def x_origin(self, value):
        self._x_origin = value

    @property
    def x_target(self):
        """X position where this node joins its parent: midpoint with its sibling."""
        if self._x_target is None:
            if self.is_root:
                self.x_target = self.x_origin
            else:
                self.x_target = (self.x_origin + self.siblings[0].x_origin) / 2
        return self._x_target

    @x_target.setter
    def x_target(self, value):
        self._x_target = value

    @property
    def y(self):
        """Y coordinates for the dendrogram branch of this node."""
        return np.array([self.y_origin, self.y_target, self.y_target])

    @property
    def y_origin(self):
        """Y position where this node's branch starts (0 for leaves)."""
        if self.is_leaf:
            return 0
        # ``is None`` so that a cached q of 0 is not recomputed every call.
        if self._y_origin is None:
            self.y_origin = self.children[0].y_target
        return self._y_origin

    @y_origin.setter
    def y_origin(self, value):
        self._y_origin = value

    @property
    def y_target(self):
        """Y position where this node joins its parent (the parent's q; 1 for the root)."""
        return self.parent.q if self.parent is not None else 1

    def plot(self, ax):
        """Recursively draw this subtree onto ``ax`` and return ``ax``."""
        for c in self.children:
            ax = c.plot(ax)
        # Remove this condition if root should be plotted as well.
        if not self.is_root:
            width = self.edge_width
            ax.plot(
                self.x,
                self.y,
                color=self.color,
                # Keep very small populations visible.
                linewidth=width if width > 0.15 else 0.15,
            )
        return ax

    def plot_tree(self, ax):
        """Place the leaves at x = 0, 1, 2, ... and draw the whole tree onto ``ax``."""
        for i, leaf in enumerate(self.leaves):
            leaf.x_origin = i
        return self.plot(ax)
407def cluster(
408 tmat: NDArray[float],
409 pop: NDArray[np.int_],
410 kernel: Callable[
411 [NDArray[float], NDArray[np.int_], NDArray[np.bool_]],
412 [np.int_, np.int_, NDArray[np.bool_]],
413 ] = kern.MPTKernel(),
414 feature_kernel=None,
415) -> (NDArray[float], NDArray[np.int_]):
416 """
417 cluster
418 -------
419 Perform full clustering for a transition matrix, given populations and a
420 kernel.
422 tmat (NDArray[float]): transition matrix, e. g. from
423 mh.msm.estimate_markov_model
424 pop (NDArray[float]): populations of microstates
425 kernel: kernel object that determines the next merge
427 returns Z (np.ndarray), full_pop (np.ndarray):
428 The Z matrix holds the full merging of microstates:
429 0: origin state
430 1: target state
431 2: distance between origin and target
432 3: joint population
433 i: Z[i, 0] and Z[i, 1] are combined to cluster n + i
434 reference: scipy.cluster.hierarchy.linkage
435 full_pop holds all state populations from state 0 to n + i
436 """
437 n = tmat.shape[0]
439 full_tmat = np.zeros((2 * n - 1, 2 * n - 1), dtype=tmat.dtype.type)
440 full_tmat[:n, :n] = tmat
442 full_pop = np.zeros(2 * n - 1, dtype=pop.dtype.type)
443 full_pop[:n] = pop
445 if tmat.shape[0] < 2**7: 445 ↛ 446line 445 didn't jump to line 446 because the condition on line 445 was never true
446 states_type = np.uint8
447 elif tmat.shape[0] < 2**15: 447 ↛ 450line 447 didn't jump to line 450 because the condition on line 447 was always true
448 states_type = np.uint16
449 else:
450 states_type = np.uint32
452 # complete linkage
453 full_states = np.zeros((2 * n - 1, 2), dtype=states_type)
454 full_states[:n, 0] = np.arange(0, n)
456 mask = np.full(2 * n - 1, False)
457 mask[:n] = True
459 # 0: state a
460 # 1: state b
461 # 2: distance between a and b
462 # 3: population
463 # i: Z[i, 0] and Z[i, 1] are combined to cluster n + i
464 Z = np.zeros((n - 1, 4), dtype=np.float32)
466 if feature_kernel:
467 feature_kernel.reset()
468 for i in range(n - 1):
469 # Index of new state
470 new_state = n + i
472 # Use feature only for determination of target state
473 if feature_kernel:
474 # state, target_state, mask = kernel(feature_kernel * full_tmat, full_states, mask)
475 state, target_state, mask = kernel(
476 full_tmat, full_states, mask, feature_kernel
477 )
478 feature_kernel.update(state, target_state, new_state)
479 else:
480 state, target_state, mask = kernel(full_tmat, full_states, mask)
482 metastability = full_tmat[state, state]
483 # Merge states in transition matrix
484 full_tmat, full_pop = utils.merge_states(
485 full_tmat, [state, target_state], new_state, full_pop
486 )
488 # Update state linkage
489 full_states[state, 1] = new_state
490 full_states[target_state, 1] = new_state
491 full_states[new_state:, 0] = new_state
493 Z[i] = [state, target_state, metastability, full_pop[new_state]]
495 # Update mask
496 mask[new_state] = True
497 mask[target_state] = False
499 return Z, full_pop