Coverage for MPP/core.py: 90%

286 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-14 12:00 +0200

1""" 

2core.py 

3======= 

4 

5Core functions for MPT class 

6""" 

7 

8__all__ = [ 

9 "cluster", 

10] 

11 

12import sys 

13import numpy as np 

14from typing import Callable 

15from numpy.typing import NDArray 

16import matplotlib.pyplot as plt 

17from matplotlib.colors import Normalize 

18 

19from anytree import NodeMixin 

20from anytree.iterators import PreOrderIter 

21 

22from . import utils 

23from . import kernel as kernel_module 

24 

# Several BinaryTreeNode members (e.g. ``population``, ``feature`` and
# ``plot``) call themselves on the child nodes, so the recursion depth grows
# with the depth of the tree.  The interpreter limit is therefore set
# explicitly here.
# NOTE(review): presumably 2020 was picked to allow deeper trees than the
# interpreter default -- confirm the intended maximum tree depth.
sys.setrecursionlimit(2020)

26 

27 

class BinaryTreeNode(NodeMixin):
    """
    Node of a binary merge tree; used to plot dendrograms and to derive
    macrostates from the merge hierarchy.

    Parameters
    ----------
    name : identifier of the node. It is also used as this state's
        row/column index into ``tmat`` when a microstate is assigned to a
        macrostate (see ``assigned_macrostate``), so it must be usable as an
        array index.
    tmat : square transition matrix shared by the nodes. ``n_states`` is
        derived from it as ``(tmat.shape[0] + 1) / 2``.
        NOTE(review): this presumably expects the extended
        ``(2n - 1) x (2n - 1)`` matrix -- confirm against the caller.
    population (float): base population; only used while the node is a leaf.
    q (float): value at which the node is merged, ``0 <= q <= 1``.
    feature (float): some feature used for coloring, ``0 <= feature <= 1``.
    pop_thr (float): minimum fraction of the root population that a node
        *and* its sibling must hold to be flagged as macrostates.
    q_min (float): minimum ``q`` of the parent required for a node to be
        flagged as macrostate.
    parent : parent node.
    left : left child node.
    right : right child node.
    """

    def __init__(
        self,
        name,
        tmat,
        population=0,
        q=0,
        feature=0,
        pop_thr=0.005,
        q_min=0.5,
        parent=None,
        left=None,
        right=None,
    ):
        """Create a node; see the class docstring for the parameters."""
        # Backing fields must exist before the property setters below run,
        # because ``is_leaf`` (used by the ``population`` setter) reads them.
        self._left = None
        self._right = None
        self._is_macrostate = None
        self._macrostates = None
        self._all_macrostates = None
        self._parent_macrostate = None
        self._assigned_macrostate = None

        self.name = name
        self.tmat = tmat
        self.n_states = int((self.tmat.shape[0] + 1) / 2)
        # Base population, used if the node is a leaf.  It is assigned before
        # ``left``/``right`` so that the node still counts as a leaf here.
        self.population = population
        self.q = q
        self.feature = feature
        self.pop_thr = pop_thr
        self.q_min = q_min
        self.parent = parent
        self.left = left
        self.right = right

        # Lazily computed dendrogram coordinates.
        self._x_origin = None
        self._x_target = None
        self._y_origin = None

        # Lazily computed coloring helpers (only stored on the root).
        self._bins = None
        self._feature_norm = None
        self._colors = None

    @property
    def population(self):
        """Population of state: own value for leaves, sum of children otherwise."""
        if self.is_leaf:
            return self._population
        return (self.left.population if self.left else 0) + (
            self.right.population if self.right else 0
        )

    @population.setter
    def population(self, value):
        if not self.is_leaf:
            # The original code *returned* the exception object, which made
            # this assignment a silent no-op.
            raise ValueError("population can only be set for microstates (leaves)")
        self._population = value

    @property
    def q(self):
        """Q, e. g. self transition probability at which states were merged."""
        return self._q

    @q.setter
    def q(self, value):
        if not 0 <= value <= 1:
            raise ValueError("q must be 0 <= q <= 1")
        self._q = value

    @property
    def feature(self):
        """
        Feature for states (e. g. fraction of native contacts). For inner
        nodes it is the population-weighted mean of the children's features.
        """
        if self.is_leaf:
            return self._feature
        weighted_sum = (
            self.left.feature * self.left.population if self.left else 0
        ) + (self.right.feature * self.right.population if self.right else 0)
        return weighted_sum / self.population

    @feature.setter
    def feature(self, value):
        if not 0 <= value <= 1:
            raise ValueError("feature must be 0 <= feature <= 1")
        self._feature = value

    @property
    def left(self):
        """Left child node (or None)."""
        return self._left

    @left.setter
    def left(self, node):
        if node is not None and node.parent is not None:
            raise ValueError("Node already has a parent")
        if self._left is not None:
            # Detach the child that is being replaced.
            self._left.parent = None
        self._left = node
        if node is not None:
            node.parent = self

    @property
    def right(self):
        """Right child node (or None)."""
        return self._right

    @right.setter
    def right(self, node):
        if node is not None and node.parent is not None:
            raise ValueError("Node already has a parent")
        if self._right is not None:
            # Detach the child that is being replaced.
            self._right.parent = None
        self._right = node
        if node is not None:
            node.parent = self

    @property
    def children(self):
        """Return the existing child nodes as a list, left before right."""
        return [child for child in (self.left, self.right) if child is not None]

    @property
    def is_leaf(self):
        """Check if this node is leaf node."""
        return not (self.left or self.right)

    @property
    def is_macrostate(self):
        """
        Flag marking macrostates.

        Evaluated lazily on first access: the root is always a macrostate; any
        other node is one if its parent's ``q`` is at least ``q_min`` and both
        the node and its first sibling hold at least ``pop_thr`` of the root
        population. In that case the sibling is flagged as well.
        """
        if self._is_macrostate is None:
            if (
                self.parent is not None
                and self.parent.q >= self.q_min
                and self.population >= self.root.population * self.pop_thr
                and self.siblings[0].population >= self.root.population * self.pop_thr
            ):
                self._is_macrostate = True
                self.siblings[0].is_macrostate = True
            elif self.parent is None:
                self._is_macrostate = True
            else:
                self.is_macrostate = False
        return self._is_macrostate

    @is_macrostate.setter
    def is_macrostate(self, value):
        if not isinstance(value, bool):
            raise ValueError("is_macrostate must be boolean")
        self._is_macrostate = value

    @property
    def macrostates(self):
        """
        Returns the macrostate nodes in this subtree that contain no further
        macrostate below them (their ``all_macrostates`` holds only themselves).
        """
        if self._macrostates is None:
            self._macrostates = tuple(
                node for node in self.all_macrostates if len(node.all_macrostates) == 1
            )
        return self._macrostates

    @property
    def all_macrostates(self):
        """Returns all macrostate nodes of this subtree in pre-order."""
        if self._all_macrostates is None:
            self._all_macrostates = tuple(
                PreOrderIter(self, filter_=lambda node: node.is_macrostate)
            )
        return self._all_macrostates

    @property
    def parent_macrostate(self):
        """Closest ancestor flagged as macrostate (None if there is none)."""
        if self._parent_macrostate is None:
            ancestor = self.parent
            while ancestor is not None and not ancestor.is_macrostate:
                ancestor = ancestor.parent
            self._parent_macrostate = ancestor
        return self._parent_macrostate

    @property
    def assigned_macrostate(self):
        """
        Macrostate a microstate (leaf) is assigned to; None for inner nodes.

        A leaf that is itself a macrostate is assigned to itself. Otherwise
        the candidates are ``parent_macrostate.macrostates``: with a single
        candidate the parent macrostate is used, with several the candidate
        with the largest transition probability from this leaf (after lumping
        the candidate's leaves via ``utils.merge_states``) is chosen.
        """
        if self._assigned_macrostate is None and self.is_leaf:
            if self.is_macrostate:
                self._assigned_macrostate = self
            else:
                candidates = self.parent_macrostate.macrostates
                if len(candidates) == 1:
                    self._assigned_macrostate = self.parent_macrostate
                else:
                    trans_probs = []
                    for candidate in candidates:
                        members = candidate.leaves
                        # Names and populations are kept in separate lists so
                        # that float populations cannot turn the index values
                        # into floats (which cannot be used for indexing).
                        indices = [leaf.name for leaf in members]
                        pops = [leaf.population for leaf in members]
                        n_members = len(indices)
                        # Second to last entry: this leaf.  Last entry: a
                        # placeholder slot that receives the lumped candidate.
                        indices.extend((self.name, 0))
                        pops.extend((self.population, 0))
                        sub_tmat = self.tmat[np.ix_(indices, indices)].copy()
                        sub_tmat, _ = utils.merge_states(
                            sub_tmat,
                            list(range(n_members)),
                            -1,
                            np.array(pops),
                        )
                        trans_probs.append(sub_tmat[-2, -1])
                    self._assigned_macrostate = candidates[np.argmax(trans_probs)]
        return self._assigned_macrostate

    @property
    def bins(self):
        """11 evenly spaced edges between min and max leaf feature (kept on root)."""
        if not self.is_root:
            return self.root.bins
        if self._bins is None:
            leaf_features = [leaf.feature for leaf in self.leaves]
            self._bins = np.linspace(min(leaf_features), max(leaf_features), 11)
        return self._bins

    @property
    def feature_norm(self):
        """``Normalize`` spanning the first to the last bin edge (kept on root)."""
        if not self.is_root:
            return self.root.feature_norm
        if self._feature_norm is None:
            self._feature_norm = Normalize(self.bins[0], self.bins[-1])
        return self._feature_norm

    @property
    def colors(self):
        """The 10 colors of the ``plasma_r`` colormap (kept on root)."""
        if not self.is_root:
            return self.root.colors
        if self._colors is None:
            cmap = plt.get_cmap("plasma_r", 10)
            self._colors = [cmap(idx) for idx in range(cmap.N)]
        return self._colors

    @property
    def color(self):
        """Color according to feature; ``"k"`` if no interval matches."""
        # Evaluated once: for inner nodes ``feature`` walks the whole subtree.
        scaled = self.feature_norm(self.feature)
        for color, rlower, rhigher in zip(
            self.colors, np.arange(0, 1, 0.1), np.arange(0.1, 1.1, 0.1)
        ):
            if rlower <= scaled <= rhigher:
                return color
        return "k"

    @property
    def edge_width(self):
        """Edge width from population (6 times the fraction of the root population)."""
        return 6 * self.population / self.root.population

    @property
    def macrostate(self):
        """
        Macrostate this state belongs to. None if no macrostates are found
        above in tree.
        """
        node = self
        while not node.is_macrostate and node.parent:
            node = node.parent
        return node if node.is_macrostate else None

    @property
    def x(self):
        """X coordinates for dendrogram for this node."""
        return np.array([self.x_origin, self.x_origin, self.x_target])

    @property
    def x_origin(self):
        """
        X position where this node's line starts. Leaves return the value
        that was set explicitly (see ``plot_tree``); inner nodes start where
        their first child ends.
        """
        # ``is None`` instead of truthiness: 0 is a valid coordinate.
        if not self.is_leaf and self._x_origin is None:
            self.x_origin = self.children[0].x_target
        return self._x_origin

    @x_origin.setter
    def x_origin(self, value):
        self._x_origin = value

    @property
    def x_target(self):
        """
        X position where this node's line ends: midway between this node
        and its first sibling, or ``x_origin`` for the root.
        """
        if self._x_target is None:
            if self.is_root:
                self.x_target = self.x_origin
            else:
                self.x_target = (self.x_origin + self.siblings[0].x_origin) / 2
        return self._x_target

    @x_target.setter
    def x_target(self, value):
        self._x_target = value

    @property
    def y(self):
        """Y coordinates for dendrogram for this node."""
        return np.array([self.y_origin, self.y_target, self.y_target])

    @property
    def y_origin(self):
        """Y position where this node's line starts (0 for leaves)."""
        if self.is_leaf:
            return 0
        # ``is None`` instead of truthiness: 0 is a valid coordinate.
        if self._y_origin is None:
            self.y_origin = self.children[0].y_target
        return self._y_origin

    @y_origin.setter
    def y_origin(self, value):
        self._y_origin = value

    @property
    def y_target(self):
        """Y position where this node's line ends: the parent's ``q`` (1 for the root)."""
        if self.parent:
            return self.parent.q
        return 1

    def plot(self, ax):
        """
        Recursively draw this subtree onto ``ax`` (children first) and
        return ``ax``. Leaf ``x_origin`` values must have been set before,
        e. g. by ``plot_tree``.
        """
        for child in self.children:
            ax = child.plot(ax)
        # Remove this condition if root should be plotted as well.
        if not self.is_root:
            width = self.edge_width
            ax.plot(
                self.x,
                self.y,
                color=self.color,
                # Keep very small states visible.
                linewidth=width if width > 0.15 else 0.15,
            )
        return ax

    def plot_tree(self, ax):
        """Place the leaves at x = 0, 1, 2, ... and draw the whole tree onto ``ax``."""
        for i, leaf in enumerate(self.leaves):
            leaf.x_origin = i
        return self.plot(ax)

407 

408 

def cluster(
    tmat: NDArray[float],
    pop: NDArray[np.int_],
    kernel: Callable[
        [NDArray[float], NDArray[np.int_], NDArray[np.bool_]],
        [np.int_, np.int_, NDArray[np.bool_]],
    ] = kernel_module.LumpingKernel(),
    feature_kernel=None,
) -> (NDArray[float], NDArray[np.int_]):
    """
    cluster
    -------
    Merge the n states of a transition matrix pairwise, one pair per step,
    until a single state is left (n - 1 merges).

    tmat (NDArray[float]): transition matrix, e. g. from
        mh.msm.estimate_markov_model
    pop: populations of microstates
    kernel: callable that determines the next merge. It is called with the
        extended transition matrix, the state linkage table and the mask,
        plus ``feature_kernel`` if one is given, and returns
        ``(state, target_state, mask)``.
    feature_kernel: optional object passed on to ``kernel``. It is
        ``reset()`` once before the first merge and ``update()``d with
        ``(state, target_state, new_state)`` after every kernel call.

    returns Z (np.ndarray), full_pop (np.ndarray):
        Z is a float32 array of shape (n - 1, 4), one row per merge:
            0: origin state
            1: target state
            2: ``tmat[origin, origin]`` right before the merge
            3: population of the merged state
            i: Z[i, 0] and Z[i, 1] are combined to cluster n + i
        reference: scipy.cluster.hierarchy.linkage
        full_pop holds all state populations from state 0 to n + i
    """
    n = tmat.shape[0]
    # n original states plus one new state for each of the n - 1 merges.
    n_total = 2 * n - 1

    # Extended arrays: the first n entries are the input, the rest is filled
    # in as merged states are created.
    full_tmat = np.zeros((n_total, n_total), dtype=tmat.dtype.type)
    full_tmat[:n, :n] = tmat
    full_pop = np.zeros(n_total, dtype=pop.dtype.type)
    full_pop[:n] = pop

    # Unsigned integer type for the linkage table, chosen from the number of
    # input states.
    states_type = (
        np.uint8 if n < 2**7 else np.uint16 if n < 2**15 else np.uint32
    )

    # State linkage table: column 0 starts out as the state's own index,
    # column 1 receives the index of the cluster a state is merged into.
    full_states = np.zeros((n_total, 2), dtype=states_type)
    full_states[:n, 0] = np.arange(0, n)

    # Boolean mask over all (original and merged) states; initially only the
    # original states are switched on.
    mask = np.full(n_total, False)
    mask[:n] = True

    Z = np.zeros((n - 1, 4), dtype=np.float32)

    if feature_kernel:
        feature_kernel.reset()

    for step, merged in enumerate(range(n, n_total)):
        # ``merged`` is the index of the state created in this step.
        if feature_kernel:
            # Use feature only for determination of target state.
            origin, target, mask = kernel(
                full_tmat, full_states, mask, feature_kernel
            )
            feature_kernel.update(origin, target, merged)
        else:
            origin, target, mask = kernel(full_tmat, full_states, mask)

        # Must be read before the merge below replaces ``full_tmat``.
        self_transition = full_tmat[origin, origin]

        # Merge states in transition matrix.
        full_tmat, full_pop = utils.merge_states(
            full_tmat, [origin, target], merged, full_pop
        )

        # Update state linkage.
        full_states[origin, 1] = merged
        full_states[target, 1] = merged
        full_states[merged:, 0] = merged

        Z[step] = [origin, target, self_transition, full_pop[merged]]

        # Update mask: the new state is switched on, the target is switched off.
        mask[merged] = True
        mask[target] = False

    return Z, full_pop

499 mask[target_state] = False 

500 

501 return Z, full_pop