Coverage for MPP/core.py: 90%
286 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-14 12:00 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-14 12:00 +0200
"""
core.py
=======

Core functions for MPT class: hierarchical lumping of a transition matrix
(``cluster``) and the binary merge-tree node (``BinaryTreeNode``) used to
determine macrostates and to draw dendrograms.
"""

# Public API for ``from ... import *``; BinaryTreeNode must be imported by name.
__all__ = ["cluster"]
12import sys
13import numpy as np
14from typing import Callable
15from numpy.typing import NDArray
16import matplotlib.pyplot as plt
17from matplotlib.colors import Normalize
19from anytree import NodeMixin
20from anytree.iterators import PreOrderIter
22from . import utils
23from . import kernel as kernel_module
# Dendrogram handling recurses with the tree height (e.g. BinaryTreeNode.plot
# calls plot on every child, and the x/y coordinate properties call into
# neighbouring nodes), so deep trees need more than Python's default limit.
# Only ever raise the limit: importing this module must not lower a limit that
# was already set higher elsewhere in the process.
sys.setrecursionlimit(max(sys.getrecursionlimit(), 2020))
class BinaryTreeNode(NodeMixin):
    """
    Node of the binary merge tree (dendrogram) of a hierarchical lumping.

    Leaves are microstates; every inner node is the merge of its ``left`` and
    ``right`` children. Population and feature of inner nodes are derived from
    their children. Nodes also determine macrostates and provide the
    coordinates and colors used to draw the dendrogram.
    """

    def __init__(
        self,
        name,
        tmat,
        population=0,
        q=0,
        feature=0,
        pop_thr=0.005,
        q_min=0.5,
        parent=None,
        left=None,
        right=None,
    ):
        """
        This class is used to plot dendrograms.

        Parameters
        ----------
        name : int
            Name of the node. For leaves it is used as a row/column index into
            ``tmat`` when assigning macrostates, so it must be a valid index.
        tmat : numpy.ndarray
            Transition matrix. ``n_states`` is computed as
            ``(tmat.shape[0] + 1) / 2``, which presumes ``tmat`` covers
            n microstates plus n - 1 merged states -- TODO confirm with callers.
        population : float
            Population of the node. Only stored for leaves; inner nodes sum
            the populations of their children.
        q : float
            Value in [0, 1] at which the children of this node were merged
            (e.g. a self-transition probability).
        feature : float
            Value in [0, 1] used for coloring (e.g. fraction of native
            contacts); inner nodes use the population-weighted mean.
        pop_thr : float
            Minimum population, as a fraction of the root population, that a
            node and its sibling both need to become macrostates.
        q_min : float
            Minimum ``q`` of the parent for its children to become macrostates.
        parent : BinaryTreeNode, optional
            Parent node.
        left, right : BinaryTreeNode, optional
            Child nodes; they must not have a parent yet.

        Raises
        ------
        ValueError
            If ``q`` or ``feature`` is outside [0, 1], or if ``left`` or
            ``right`` already has a parent.
        """
        # Backing fields of the left/right child properties.
        self._left = None
        self._right = None
        # Lazily filled caches of the macrostate properties.
        self._is_macrostate = None
        self._macrostates = None
        self._all_macrostates = None
        self._parent_macrostate = None
        self._assigned_macrostate = None

        self.name = name
        self.tmat = tmat
        self.n_states = int((self.tmat.shape[0] + 1) / 2)
        # No children are attached yet, so the node is still a leaf here and
        # the setter stores the base population.
        self.population = population
        self.q = q
        self.feature = feature
        self.pop_thr = pop_thr
        self.q_min = q_min
        self.parent = parent
        self.left = left
        self.right = right

        # Dendrogram layout caches.
        self._x_origin = None
        self._x_target = None
        self._y_origin = None

        # Color caches; only used on the root node.
        self._bins = None
        self._feature_norm = None
        self._colors = None

    @property
    def population(self):
        """Population of the state; inner nodes sum over their children."""
        if self.is_leaf:
            return self._population
        return (self.left.population if self.left is not None else 0) + (
            self.right.population if self.right is not None else 0
        )

    @population.setter
    def population(self, value):
        # Inner-node populations are derived from the leaves; raise instead of
        # silently ignoring the assignment.
        if not self.is_leaf:
            raise ValueError("population can only be set for microstates (leaves)")
        self._population = value

    @property
    def q(self):
        """Q, e. g. self transition probability at which states were merged."""
        return self._q

    @q.setter
    def q(self, value):
        if not 0 <= value <= 1:
            raise ValueError("q must be 0 <= q <= 1")
        self._q = value

    @property
    def feature(self):
        """
        Feature of the state (e. g. fraction of native contacts). Inner nodes
        return the population-weighted mean of their children's features.
        """
        if self.is_leaf:
            return self._feature
        left_part = (
            self.left.feature * self.left.population if self.left is not None else 0
        )
        right_part = (
            self.right.feature * self.right.population
            if self.right is not None
            else 0
        )
        return (left_part + right_part) / self.population

    @feature.setter
    def feature(self, value):
        if not 0 <= value <= 1:
            raise ValueError("feature must be 0 <= feature <= 1")
        self._feature = value

    @property
    def left(self):
        """Left child node, or None."""
        return self._left

    @left.setter
    def left(self, node):
        if node is not None and node.parent is not None:
            raise ValueError("Node already has a parent")
        # Detach the previous child before attaching the new one.
        if self._left is not None:
            self._left.parent = None
        self._left = node
        if node is not None:
            node.parent = self

    @property
    def right(self):
        """Right child node, or None."""
        return self._right

    @right.setter
    def right(self, node):
        if node is not None and node.parent is not None:
            raise ValueError("Node already has a parent")
        # Detach the previous child before attaching the new one.
        if self._right is not None:
            self._right.parent = None
        self._right = node
        if node is not None:
            node.parent = self

    @property
    def children(self):
        """Return the existing child nodes (left first, then right)."""
        return [child for child in (self.left, self.right) if child is not None]

    @property
    def is_leaf(self):
        """Check if this node is leaf node (has neither left nor right child)."""
        return self.left is None and self.right is None

    @property
    def is_macrostate(self):
        """
        Whether this node is a macrostate (computed once, then cached).

        The root is always a macrostate. Any other node is one if its parent
        was merged at ``q >= q_min`` and both it and its sibling hold at least
        ``pop_thr`` of the root population; the sibling is then flagged too.
        """
        if self._is_macrostate is None:
            if self.parent is None:
                self._is_macrostate = True
            else:
                min_pop = self.root.population * self.pop_thr
                # Short-circuit keeps siblings[0] from being evaluated unless
                # the first two conditions hold.
                if (
                    self.parent.q >= self.q_min
                    and self.population >= min_pop
                    and self.siblings[0].population >= min_pop
                ):
                    self._is_macrostate = True
                    self.siblings[0].is_macrostate = True
                else:
                    self._is_macrostate = False
        return self._is_macrostate

    @is_macrostate.setter
    def is_macrostate(self, value):
        if not isinstance(value, bool):
            raise ValueError("is_macrostate must be boolean")
        self._is_macrostate = value

    @property
    def macrostates(self):
        """
        Macrostates in this subtree that contain no further macrostates below
        them, i.e. the finest macrostates (cached).
        """
        if self._macrostates is None:
            self._macrostates = tuple(
                macrostate
                for macrostate in self.all_macrostates
                if len(macrostate.all_macrostates) == 1
            )
        return self._macrostates

    @property
    def all_macrostates(self):
        """All macrostate nodes in this subtree, including this node if it is
        one, in pre-order (cached)."""
        if self._all_macrostates is None:
            self._all_macrostates = tuple(
                PreOrderIter(self, filter_=lambda node: node.is_macrostate)
            )
        return self._all_macrostates

    @property
    def parent_macrostate(self):
        """Closest proper ancestor that is a macrostate, or None."""
        if self._parent_macrostate is None:
            ancestor = self.parent
            while ancestor is not None and not ancestor.is_macrostate:
                ancestor = ancestor.parent
            self._parent_macrostate = ancestor
        return self._parent_macrostate

    def _transition_probability_to(self, macrostate):
        """Transition probability from this leaf into ``macrostate`` after
        lumping all leaves of ``macrostate`` into a single state."""
        members = macrostate.leaves
        # Keep names and populations in separate lists: packing both into one
        # array upcasts the names to float when populations are floats, which
        # makes them unusable as indices.
        indices = [leaf.name for leaf in members]
        indices.append(self.name)
        # Extra slot at position -1; merge_states lumps the members into it.
        indices.append(0)
        sub_tmat = self.tmat[np.ix_(indices, indices)].copy()
        pops = [leaf.population for leaf in members]
        pops.append(self.population)
        pops.append(0)
        sub_tmat, _ = utils.merge_states(
            sub_tmat,
            list(range(len(members))),
            -1,
            np.array(pops),
        )
        # Row -2 is this leaf, column -1 the lumped macrostate.
        return sub_tmat[-2, -1]

    @property
    def assigned_macrostate(self):
        """
        Macrostate this leaf (microstate) is assigned to; None for inner nodes.

        A leaf that is a macrostate is assigned to itself. Otherwise, if the
        closest macrostate ancestor has a single entry in ``macrostates``, the
        leaf goes to that ancestor; else to the candidate it has the highest
        transition probability to after lumping the candidate's leaves.
        """
        if self._assigned_macrostate is None and self.is_leaf:
            if self.is_macrostate:
                self._assigned_macrostate = self
            else:
                candidates = self.parent_macrostate.macrostates
                if len(candidates) == 1:
                    self._assigned_macrostate = self.parent_macrostate
                else:
                    trans_probs = [
                        self._transition_probability_to(m) for m in candidates
                    ]
                    self._assigned_macrostate = candidates[np.argmax(trans_probs)]
        return self._assigned_macrostate

    @property
    def bins(self):
        """Eleven evenly spaced edges spanning the leaf features (cached on root)."""
        if not self.is_root:
            return self.root.bins
        if self._bins is None:
            leaf_features = [leaf.feature for leaf in self.leaves]
            self._bins = np.linspace(min(leaf_features), max(leaf_features), 11)
        return self._bins

    @property
    def feature_norm(self):
        """Normalize instance mapping the bin range onto [0, 1] (cached on root)."""
        if not self.is_root:
            return self.root.feature_norm
        if self._feature_norm is None:
            self._feature_norm = Normalize(self.bins[0], self.bins[-1])
        return self._feature_norm

    @property
    def colors(self):
        """Ten discrete colors of the "plasma_r" colormap (cached on root)."""
        if not self.is_root:
            return self.root.colors
        if self._colors is None:
            cmap = plt.get_cmap("plasma_r", 10)
            self._colors = [cmap(idx) for idx in range(cmap.N)]
        return self._colors

    @property
    def color(self):
        """Color of the 0.1-wide interval containing the normalized feature;
        "k" (black) if no interval matches."""
        value = self.feature_norm(self.feature)
        for color, lower, upper in zip(
            self.colors, np.arange(0, 1, 0.1), np.arange(0.1, 1.1, 0.1)
        ):
            if lower <= value <= upper:
                return color
        return "k"

    @property
    def edge_width(self):
        """Edge width from population: 6 times the share of the root population."""
        return 6 * self.population / self.root.population

    @property
    def macrostate(self):
        """
        Macrostate this state belongs to (itself or the closest macrostate
        ancestor). None if no macrostates are found above in tree.
        """
        node = self
        while not node.is_macrostate and node.parent:
            node = node.parent
        return node if node.is_macrostate else None

    @property
    def x(self):
        """X coordinates of this node's dendrogram line: a vertical segment at
        x_origin followed by a horizontal one to x_target."""
        return np.array([self.x_origin, self.x_origin, self.x_target])

    @property
    def x_origin(self):
        """X position of the vertical segment; inner nodes use their first
        child's x_target, leaves the value set by plot_tree."""
        if not self.is_leaf and not self._x_origin:
            self.x_origin = self.children[0].x_target
        return self._x_origin

    @x_origin.setter
    def x_origin(self, value):
        self._x_origin = value

    @property
    def x_target(self):
        """X position where this node joins its parent: the midpoint between
        its own and its sibling's x_origin (the root uses its x_origin)."""
        if not self._x_target:
            if self.is_root:
                self.x_target = self.x_origin
            else:
                self.x_target = (self.x_origin + self.siblings[0].x_origin) / 2
        return self._x_target

    @x_target.setter
    def x_target(self, value):
        self._x_target = value

    @property
    def y(self):
        """Y coordinates of this node's dendrogram line."""
        return np.array([self.y_origin, self.y_target, self.y_target])

    @property
    def y_origin(self):
        """Lower y coordinate: 0 for leaves, otherwise the first child's
        y_target (which is this node's q)."""
        if self.is_leaf:
            return 0
        if not self._y_origin:
            self.y_origin = self.children[0].y_target
        return self._y_origin

    @y_origin.setter
    def y_origin(self, value):
        self._y_origin = value

    @property
    def y_target(self):
        """Height at which this node joins its parent (the parent's q); 1 for
        the root."""
        if self.parent:
            return self.parent.q
        return 1

    def plot(self, ax):
        """Recursively draw this subtree on ``ax`` (children first) and return
        ``ax``."""
        for child in self.children:
            ax = child.plot(ax)
        # Remove this condition if root should be plotted as well.
        if not self.is_root:
            width = self.edge_width
            ax.plot(
                self.x,
                self.y,
                color=self.color,
                # Enforce a minimum line width so tiny states stay visible.
                linewidth=width if width > 0.15 else 0.15,
            )
        return ax

    def plot_tree(self, ax):
        """Place the leaves at x = 0, 1, 2, ... and draw the whole subtree."""
        for i, leaf in enumerate(self.leaves):
            leaf.x_origin = i
        return self.plot(ax)
def cluster(
    tmat: NDArray[float],
    pop: NDArray[np.number],
    kernel: Callable[
        ..., tuple[np.int_, np.int_, NDArray[np.bool_]]
    ] = kernel_module.LumpingKernel(),
    feature_kernel=None,
) -> tuple[NDArray[np.float32], NDArray[np.number]]:
    """
    cluster
    -------
    Perform full clustering for a transition matrix, given populations and a
    kernel. Starting from n microstates, n - 1 merges are performed; in step
    ``i`` the two states chosen by ``kernel`` are lumped into state ``n + i``
    via ``utils.merge_states``.

    Parameters
    ----------
    tmat : NDArray[float]
        Square (n, n) transition matrix, e. g. from
        mh.msm.estimate_markov_model.
    pop : NDArray
        Populations of the n microstates; ``full_pop`` keeps this dtype.
    kernel : callable
        Kernel object that determines the next merge. Called as
        ``kernel(full_tmat, full_states, mask)`` -- or with ``feature_kernel``
        as a fourth argument when one is given -- and must return
        ``(state, target_state, mask)``.
        NOTE: the default instance is created once at import time and is
        shared by every call that relies on the default.
    feature_kernel : optional
        If given, it is ``reset()`` before clustering, passed to ``kernel``
        and notified of each merge via ``update(state, target_state, new_state)``.

    Returns
    -------
    Z : NDArray[np.float32]
        (n - 1, 4) matrix holding the full merging of microstates, following
        the layout of scipy.cluster.hierarchy.linkage. Row ``i`` combines
        ``Z[i, 0]`` (origin state) and ``Z[i, 1]`` (target state) into cluster
        ``n + i``; ``Z[i, 2]`` is the self-transition probability
        (metastability) of the origin state just before the merge and
        ``Z[i, 3]`` the joint population. Stored as float32; scipy's hierarchy
        routines require float64, so cast with ``Z.astype(np.float64)`` first.
    full_pop : NDArray
        (2n - 1,) populations of all states: microstates ``0 .. n - 1``
        followed by the merged states ``n .. 2n - 2``.

    Raises
    ------
    ValueError
        If ``tmat`` is not a non-empty square matrix or ``pop`` does not hold
        exactly one entry per microstate.
    """
    # Validate explicitly: mismatched shapes would otherwise broadcast
    # silently into the padded arrays below (e.g. a 1-D tmat or a length-1 pop).
    if tmat.ndim != 2 or tmat.shape[0] != tmat.shape[1] or tmat.shape[0] == 0:
        raise ValueError(
            f"tmat must be a non-empty square matrix, got shape {tmat.shape}"
        )
    n = tmat.shape[0]
    if pop.shape != (n,):
        raise ValueError(f"pop must have shape ({n},) to match tmat, got {pop.shape}")

    # n microstates plus n - 1 merged states.
    size = 2 * n - 1

    full_tmat = np.zeros((size, size), dtype=tmat.dtype.type)
    full_tmat[:n, :n] = tmat

    full_pop = np.zeros(size, dtype=pop.dtype.type)
    full_pop[:n] = pop

    # Smallest unsigned type holding every state index (up to 2n - 2).
    if n < 2**7:
        states_type = np.uint8
    elif n < 2**15:
        states_type = np.uint16
    else:
        states_type = np.uint32

    # State linkage table passed to the kernel. Column 0 starts as the state
    # index for microstates; after each merge rows >= new_state are set to
    # new_state. Column 1 is the state a row was merged into (0 until then).
    full_states = np.zeros((size, 2), dtype=states_type)
    full_states[:n, 0] = np.arange(0, n)

    # Active states: initially the microstates. Each merge activates the new
    # state and deactivates the target; the kernel returns an updated mask too.
    mask = np.full(size, False)
    mask[:n] = True

    Z = np.zeros((n - 1, 4), dtype=np.float32)

    if feature_kernel:
        feature_kernel.reset()
    for i in range(n - 1):
        # Index of new state
        new_state = n + i

        # Use feature only for determination of target state
        if feature_kernel:
            state, target_state, mask = kernel(
                full_tmat, full_states, mask, feature_kernel
            )
            feature_kernel.update(state, target_state, new_state)
        else:
            state, target_state, mask = kernel(full_tmat, full_states, mask)

        # Read before merging: the origin's self-transition probability.
        metastability = full_tmat[state, state]
        full_tmat, full_pop = utils.merge_states(
            full_tmat, [state, target_state], new_state, full_pop
        )

        # Update state linkage
        full_states[state, 1] = new_state
        full_states[target_state, 1] = new_state
        full_states[new_state:, 0] = new_state

        Z[i] = [state, target_state, metastability, full_pop[new_state]]

        # Update mask
        mask[new_state] = True
        mask[target_state] = False

    return Z, full_pop