Coverage for MPT/core.py: 90%

285 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-08 16:59 +0200

1""" 

2core.py 

3======= 

4 

5Core functions for MPT class 

6""" 

7 

8__all__ = [ 

9 "cluster", 

10] 

11 

12import sys 

13import numpy as np 

14from typing import Callable 

15from numpy.typing import NDArray 

16import matplotlib.pyplot as plt 

17from matplotlib.colors import Normalize 

18 

19from anytree import NodeMixin 

20from anytree.iterators import PreOrderIter 

21 

22import MPT.utils as utils 

23import MPT.kernel as kern 

24 

25sys.setrecursionlimit(2020) 

26 

27 

class BinaryTreeNode(NodeMixin):
    """
    Node of a binary merge tree, used to plot dendrograms and to derive
    macrostates from a hierarchical clustering of microstates.

    Leaves carry their own population and feature; inner nodes derive both
    from their children. Plot coordinates and macrostate assignments are
    computed lazily and cached on first access.
    """

    def __init__(
        self,
        name,
        tmat,
        population=0,
        q=0,
        feature=0,
        macrostate_thresholds=(0.005, 0.5),
        parent=None,
        left=None,
        right=None,
    ):
        """
        Create a tree node.

        Parameters:
        -----------

        name: name of the node. It is also used as a row/column index into
            ``tmat`` when a leaf is assigned to a macrostate.
        tmat: transition matrix; ``n_states`` is derived from its first
            dimension as ``(tmat.shape[0] + 1) / 2``
        population (float): population of the node (only used for leaves)
        q (float): value at which the node is merged, must be in [0, 1]
        feature (float): some feature used for coloring, must be in [0, 1]
        macrostate_thresholds (tuple): ``(pop_thr, q_min)``; a node is a
            macrostate if its parent has ``q >= q_min`` and both the node
            and its first sibling hold at least ``pop_thr`` times the root
            population
        parent: parent node
        left: left node
        right: right node
        """
        # The private slots must exist before any property setter runs:
        # the population setter consults is_leaf, which reads left/right.
        self._left = None
        self._right = None
        self._is_macrostate = None
        self._macrostates = None
        self._all_macrostates = None
        self._parent_macrostate = None
        self._assigned_macrostate = None

        self.name = name
        self.tmat = tmat
        self.n_states = int((self.tmat.shape[0] + 1) / 2)
        # Base population, used if the node is a leaf. Must be assigned
        # before the children are attached, because the setter rejects
        # inner nodes.
        self.population = population
        self.q = q
        self.feature = feature
        self.pop_thr, self.q_min = macrostate_thresholds
        self.parent = parent
        self.left = left
        self.right = right

        # Lazily computed dendrogram coordinates.
        self._x_origin = None
        self._x_target = None
        self._y_origin = None

        # Lazily computed coloring helpers (only stored on the root).
        self._bins = None
        self._feature_norm = None
        self._colors = None

    @property
    def population(self):
        """Population of state; for inner nodes the sum over both children."""
        if self.is_leaf:
            return self._population
        else:
            return (self.left.population if self.left else 0) + (
                self.right.population if self.right else 0
            )

    @population.setter
    def population(self, value):
        if self.is_leaf:
            self._population = value
        else:
            # The exception used to be returned instead of raised, which
            # made the assignment on an inner node a silent no-op.
            raise ValueError("population can only be set for microstates (leaves)")

    @property
    def q(self):
        """Q, e. g. self transition probability at which states were merged."""
        return self._q

    @q.setter
    def q(self, value):
        if 0 <= value <= 1:
            self._q = value
        else:
            raise ValueError("q must be 0 <= q <= 1")

    @property
    def feature(self):
        """
        Feature for states (e. g. fraction of native contacts). Is forwarded
        weighted by population
        """
        if self.is_leaf:
            return self._feature
        else:
            return (
                (self.left.feature * self.left.population if self.left else 0)
                + (self.right.feature * self.right.population if self.right else 0)
            ) / self.population

    @feature.setter
    def feature(self, value):
        if 0 <= value <= 1:
            self._feature = value
        else:
            raise ValueError("feature must be 0 <= feature <= 1")

    @property
    def left(self):
        """Left child node or None."""
        return self._left

    @left.setter
    def left(self, node):
        if node is not None and node.parent is not None:
            raise ValueError("Node already has a parent")
        # Detach a previously attached child before replacing it.
        if self._left is not None:
            self._left.parent = None
        self._left = node
        if node is not None:
            node.parent = self

    @property
    def right(self):
        """Right child node or None."""
        return self._right

    @right.setter
    def right(self, node):
        if node is not None and node.parent is not None:
            raise ValueError("Node already has a parent")
        # Detach a previously attached child before replacing it.
        if self._right is not None:
            self._right.parent = None
        self._right = node
        if node is not None:
            node.parent = self

    @property
    def children(self):
        """Return the child nodes that are set, left before right."""
        children = []
        if self.left is not None:
            children.append(self.left)
        if self.right is not None:
            children.append(self.right)
        return children

    @property
    def is_leaf(self):
        """Check if this node is leaf node."""
        return not (self.left or self.right)

    @property
    def is_macrostate(self):
        """
        Mark macrostates using this flag.

        Evaluated lazily on first access: the root is always a macrostate;
        any other node is one if its parent has ``q >= q_min`` and both the
        node and its first sibling reach the population threshold. In that
        case the sibling is flagged as macrostate as well.
        """
        if self._is_macrostate is None:
            if (
                self.parent is not None
                and self.parent.q >= self.q_min
                and self.population >= self.root.population * self.pop_thr
                and self.siblings[0].population >= self.root.population * self.pop_thr
            ):
                self._is_macrostate = True
                self.siblings[0].is_macrostate = True
            elif self.parent is None:
                self._is_macrostate = True
            else:
                self.is_macrostate = False
        return self._is_macrostate

    @is_macrostate.setter
    def is_macrostate(self, value):
        if isinstance(value, bool):
            self._is_macrostate = value
        else:
            raise ValueError("is_macrostate must be boolean")

    @property
    def macrostates(self):
        """
        Returns the macrostate nodes at or below this node that contain no
        further macrostate below them.
        """
        if self._macrostates is None:
            true_macrostates = []
            for macrostate in self.all_macrostates:
                # all_macrostates of a macrostate includes the node itself,
                # so a length of 1 means there is no macrostate below it.
                if len(macrostate.all_macrostates) == 1:
                    true_macrostates.append(macrostate)
            self._macrostates = tuple(true_macrostates)
        return self._macrostates

    @property
    def all_macrostates(self):
        """Returns all macrostate nodes at or below this node (pre-order)."""
        if self._all_macrostates is None:
            self._all_macrostates = tuple(
                PreOrderIter(self, filter_=lambda node: node.is_macrostate)
            )
        return self._all_macrostates

    @property
    def parent_macrostate(self):
        """Closest ancestor that is a macrostate, None if there is none."""
        if self._parent_macrostate is None:
            parent = self.parent
            while parent is not None and not parent.is_macrostate:
                parent = parent.parent
            self._parent_macrostate = parent
        return self._parent_macrostate

    @property
    def assigned_macrostate(self):
        """
        Macrostate a leaf is assigned to; None for inner nodes.

        A leaf that is a macrostate is assigned to itself. Otherwise it is
        assigned to its parent macrostate if that holds a single macrostate,
        or else to the candidate macrostate with the largest entry
        ``[-2, -1]`` of the merged temporary transition matrix.
        """
        if self._assigned_macrostate is None:
            if self.is_leaf:
                if self.is_macrostate:
                    self._assigned_macrostate = self
                else:
                    if len(self.parent_macrostate.macrostates) == 1:
                        self._assigned_macrostate = self.parent_macrostate
                    else:
                        trans_probs = []
                        for m in self.parent_macrostate.macrostates:
                            macrostate = np.array(
                                [(s.name, s.population) for s in m.leaves]
                            )
                            # Sub-matrix over the macrostate's leaves, this
                            # leaf, and one extra slot (index 0) that is
                            # passed to merge_states as the target (-1).
                            indices = list(macrostate[:, 0])
                            indices.append(self.name)
                            indices.append(0)
                            tmp_tmat = self.tmat[np.ix_(indices, indices)].copy()
                            pops = list(macrostate[:, 1])
                            pops.append(self.population)
                            pops.append(0)
                            tmp_tmat, pops = utils.merge_states(
                                tmp_tmat,
                                list(range(macrostate.shape[0])),
                                -1,
                                np.array(pops),
                            )
                            trans_probs.append(tmp_tmat[-2, -1])
                        self._assigned_macrostate = self.parent_macrostate.macrostates[
                            np.argmax(trans_probs)
                        ]
            else:
                self._assigned_macrostate = None
        return self._assigned_macrostate

    @property
    def bins(self):
        """
        Eleven evenly spaced bin edges between the smallest and the largest
        leaf feature; computed on and stored in the root.
        """
        if self.is_root and self._bins is None:
            leaf_features = [leaf.feature for leaf in self.leaves]
            min_feature = min(leaf_features)
            max_feature = max(leaf_features)
            self._bins = np.linspace(min_feature, max_feature, 11)
        if self.is_root:
            return self._bins
        else:
            return self.root.bins

    @property
    def feature_norm(self):
        """Normalize instance spanning the first to the last bin edge."""
        if self.is_root and self._feature_norm is None:
            self._feature_norm = Normalize(self.bins[0], self.bins[-1])
        if self.is_root:
            return self._feature_norm
        else:
            return self.root.feature_norm

    @property
    def colors(self):
        """List of the colors of the 10-step "plasma_r" colormap."""
        if self.is_root and self._colors is None:
            cmap = plt.get_cmap("plasma_r", 10)
            self._colors = [cmap(idx) for idx in range(cmap.N)]
        if self.is_root:
            return self._colors
        else:
            return self.root.colors

    @property
    def color(self):
        """Color according to feature; "k" if no interval matches."""
        # feature (and the population it is weighted by) is computed
        # recursively over the subtree, so normalise it once instead of
        # once per color interval.
        normalized = self.feature_norm(self.feature)
        for color, rlower, rhigher in zip(
            self.colors, np.arange(0, 1, 0.1), np.arange(0.1, 1.1, 0.1)
        ):
            if rlower <= normalized <= rhigher:
                return color
        return "k"

    @property
    def edge_width(self):
        """Edge width from population."""
        return 6 * self.population / self.root.population

    @property
    def macrostate(self):
        """
        Macrostate this state belongs to. None if no macrostates are found
        above in tree.
        """
        node = self
        while not node.is_macrostate and node.parent:
            node = node.parent
        if node.is_macrostate:
            return node
        else:
            return None

    @property
    def x(self):
        """X coordinates for dendrogram for this node"""
        return np.array([self.x_origin, self.x_origin, self.x_target])

    @property
    def x_origin(self):
        """
        X position where the edge of this node starts. Leaves return the
        stored value (assigned in plot_tree); inner nodes take the x_target
        of their first child.
        """
        if not self.is_leaf:
            # NOTE(review): a falsy cached value (0) is recomputed on every
            # access; kept as is to preserve behaviour.
            if not self._x_origin:
                self.x_origin = self.children[0].x_target
        return self._x_origin

    @x_origin.setter
    def x_origin(self, value):
        self._x_origin = value

    @property
    def x_target(self):
        """
        X position where the edge of this node ends: midway between this
        node and its first sibling, or x_origin for the root.
        """
        if not self._x_target:
            if self.is_root:
                self.x_target = self.x_origin
            else:
                self.x_target = (self.x_origin + self.siblings[0].x_origin) / 2
        return self._x_target

    @x_target.setter
    def x_target(self, value):
        self._x_target = value

    @property
    def y(self):
        """Y coordinates for dendrogram for this node"""
        return np.array([self.y_origin, self.y_target, self.y_target])

    @property
    def y_origin(self):
        """
        Y position where the edge of this node starts: 0 for leaves,
        otherwise the y_target of the first child.
        """
        if self.is_leaf:
            return 0
        else:
            if not self._y_origin:
                self.y_origin = self.children[0].y_target
            return self._y_origin

    @y_origin.setter
    def y_origin(self, value):
        self._y_origin = value

    @property
    def y_target(self):
        """Y position where the edge ends: q of the parent, 1 for the root."""
        if self.parent:
            return self.parent.q
        else:
            return 1

    def plot(self, ax):
        """
        Draw the edges of this subtree on ``ax`` (children first) and
        return ``ax``.
        """
        for c in self.children:
            ax = c.plot(ax)
        # Remove this condition if root should be plotted as well.
        if not self.is_root:
            # edge_width sums populations recursively; evaluate it once.
            width = self.edge_width
            ax.plot(
                self.x,
                self.y,
                color=self.color,
                linewidth=width if width > 0.15 else 0.15,
            )
        return ax

    def plot_tree(self, ax):
        """
        Place the leaves at consecutive integer x positions, then plot the
        whole subtree on ``ax`` and return ``ax``.
        """
        for i, leaf in enumerate(self.leaves):
            leaf.x_origin = i
        return self.plot(ax)

405 

406 

def cluster(
    tmat: NDArray[float],
    pop: NDArray[np.int_],
    kernel: Callable[
        [NDArray[float], NDArray[np.int_], NDArray[np.bool_]],
        [np.int_, np.int_, NDArray[np.bool_]],
    ] = kern.MPTKernel(),
    feature_kernel=None,
) -> (NDArray[float], NDArray[np.int_]):
    """
    cluster
    -------
    Merge all microstates of a transition matrix step by step, where a
    kernel picks the pair of states to merge in every step.

    tmat (NDArray[float]): transition matrix, e. g. from
        mh.msm.estimate_markov_model
    pop: populations of microstates
    kernel: callable invoked as kernel(full_tmat, full_states, mask) that
        returns (state, target_state, mask); if feature_kernel is given it
        is passed to the kernel as fourth argument
    feature_kernel: optional object; it is reset once before the first
        merge and updated with (state, target_state, new_state) after
        every kernel call

    returns Z (np.ndarray), full_pop (np.ndarray):
        Z has one row per merge (layout follows
        scipy.cluster.hierarchy.linkage):
            0: origin state
            1: target state
            2: self-transition probability of the origin state right
               before the merge
            3: joint population
            i: Z[i, 0] and Z[i, 1] are combined to cluster n + i
        full_pop holds all state populations from state 0 to n + i
    """
    n_micro = tmat.shape[0]
    # n microstates yield n - 1 merged states, 2n - 1 states in total.
    n_total = 2 * n_micro - 1

    full_tmat = np.zeros((n_total, n_total), dtype=tmat.dtype.type)
    full_tmat[:n_micro, :n_micro] = tmat

    full_pop = np.zeros(n_total, dtype=pop.dtype.type)
    full_pop[:n_micro] = pop

    # Choose the dtype of the linkage table from the number of microstates.
    if n_micro >= 2**15:
        states_type = np.uint32
    elif n_micro >= 2**7:
        states_type = np.uint16
    else:
        states_type = np.uint8

    # complete linkage
    full_states = np.zeros((n_total, 2), dtype=states_type)
    full_states[:n_micro, 0] = np.arange(0, n_micro)

    # Only the microstates are flagged at the start.
    mask = np.full(n_total, False)
    mask[:n_micro] = True

    # One row per merge, see docstring for the column layout.
    Z = np.zeros((n_micro - 1, 4), dtype=np.float32)

    if feature_kernel:
        feature_kernel.reset()

    # Merge i creates the state with index n + i.
    for i, new_state in enumerate(range(n_micro, n_total)):
        # Use feature only for determination of target state
        if feature_kernel:
            state, target_state, mask = kernel(
                full_tmat, full_states, mask, feature_kernel
            )
            feature_kernel.update(state, target_state, new_state)
        else:
            state, target_state, mask = kernel(full_tmat, full_states, mask)

        # Has to be read before the transition matrix is merged.
        metastability = full_tmat[state, state]

        # Merge states in transition matrix
        full_tmat, full_pop = utils.merge_states(
            full_tmat, [state, target_state], new_state, full_pop
        )

        # Update state linkage
        full_states[state, 1] = new_state
        full_states[target_state, 1] = new_state
        full_states[new_state:, 0] = new_state

        Z[i] = [state, target_state, metastability, full_pop[new_state]]

        # Update mask
        mask[new_state] = True
        mask[target_state] = False

    return Z, full_pop