Coverage for MPT/MPT.py: 86%
402 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 11:20 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 11:20 +0200
1import os
2import datetime
3import warnings
4import numpy as np
5import msmhelper as mh
6import matplotlib.pyplot as plt
7import mdtraj as md
9from pathlib import Path
10from tqdm import tqdm
11from typing import Callable, List
12from numpy.typing import NDArray
13from collections.abc import Iterable
14from sklearn.metrics import davies_bouldin_score
15import pygpcca as gp
17from MPT import core
19# import MPT.core as core
20import MPT.utils as utils
21import MPT.kernel as kernel_module
22from MPT.graph import draw_knetwork
24import MPT.plot as plot
# TODO:
# - Change traj and macrotraj to lists (add one dimension). First, mark all places that need adaptation.
# - Connect with contacts and check for implications. Float contacts file: /data/PDZ3_Ali/short_ligand/reduction/trans/contacts_analysis/cluster1-7/data/dist_all
# - Internally change the trajectory to 0-based, still support 1-based input and issue a warning; do the same for macrotraj.
class MPT(object):
    """Lump the microstates of a trajectory into macrostates.

    An instance holds a 0-based microstate trajectory, the transition matrix
    estimated from it at lag ``tlag`` (via ``msmhelper``), per-frame feature
    data and - after :meth:`mpt`, :meth:`from_Z` or :meth:`gpcca` - one or
    more lumpings (``n_runs``). Values derived from the lumpings (tree,
    timescales, populations, scores, ...) are computed lazily by properties
    and cached in ``_``-prefixed attributes. That cache is cleared whenever a
    new set of lumpings is created.
    """

    def __init__(
        self,
        traj: NDArray[np.int_],
        tlag: int,
        feature_traj: NDArray[float] = None,
        contact_threshold=0.45,
        feature_type=np.float64,
        macrostate_thresholds: tuple = (0.005, 0.5),
        limits=None,
        quiet=False,
        frame_length=0.2,
    ):
        """
        traj (NDArray[int]): microstate trajectory. 1-based input is shifted
            to 0-based by the ``traj`` setter.
        tlag (int): lag time (in frames) used to estimate the transition matrix.
        feature_traj (NDArray[float] | None): frames x features. If None, one
            constant feature column is used.
        contact_threshold (float): feature values below this count as True
            in ``multi_feature_traj_bool``.
        feature_type: dtype used to store the feature data.
        macrostate_thresholds (tuple): (pop_thr, q_min) handed to the tree nodes.
        limits: forwarded to ``utils.get_multi_state_traj``; presumably the
            boundaries of concatenated sub-trajectories - TODO confirm.
        quiet (bool): suppress progress output.
        frame_length (float): frame length in ns.
        """
        self.traj = traj
        self.tlag = tlag
        self.pop_thr, self.q_min = macrostate_thresholds
        self.limits = limits
        tmat, states = mh.msm.estimate_markov_model(
            utils.get_multi_state_traj(self.traj, self.limits),
            self.tlag,
        )
        self.tmat = tmat.astype(np.float64)
        _, self.pop = np.unique(self.traj, return_counts=True)
        self.n_states = len(states)
        self.quiet = quiet
        if feature_traj is not None:
            self._add_feature(
                feature_traj,
                contact_threshold=contact_threshold,
                feature_type=feature_type,
            )
        else:
            # One constant column per frame. The number of frames must be an
            # int: passing the shape tuple itself made np.ones raise.
            self._add_feature(np.ones((self.traj.shape[0], 1)))

        self.Z = None
        self._reset_derived_cache()
        self._reference = None
        self._topology_file = None
        self._xtc_trajectory_file = None
        self._rmsd = None
        self.xtc_stride = None
        self.frame_length = frame_length

    def _reset_derived_cache(self):
        """Forget every cached value that is computed from the current lumping(s)."""
        self._timescales = None
        self._linkage = None
        self._macro_pop = None
        self._tree = None
        self._shannon_entropy = None
        self._davies_bouldin_index = None
        self._gmrq = None
        self._n_i = None
        self._macro_micro_feature = None

    def mpt(
        self,
        kernel: Callable[
            [NDArray[float], NDArray[np.int_], NDArray[np.bool_]],
            tuple[np.int_, np.int_, NDArray[np.bool_]],
        ] = None,
        feature_kernel=None,
        n: int = 1,
    ) -> None:
        """Run the MPT clustering ``n`` times, then assign macrostates.

        kernel: merge kernel passed to ``core.cluster``. If None, a fresh
            ``kernel_module.MPTKernel()`` is created for this call. A default
            instance is no longer shared between calls.
        feature_kernel: passed unchanged to ``core.cluster``.
        n (int): number of clustering runs, stored as ``n_runs``.

        Fills ``Z`` (n_runs x (n_states - 1) x 4) and ``full_pop``
        (n_runs x (2 * n_states - 1)), then calls :meth:`assign_macrostates`.
        """
        if kernel is None:
            kernel = kernel_module.MPTKernel()
        self.n_runs = n
        self.kernel = kernel
        self.feature_kernel = feature_kernel

        self.Z = np.zeros((self.n_runs, self.n_states - 1, 4), dtype=np.float64)
        self.full_pop = np.zeros((self.n_runs, 2 * self.n_states - 1), dtype=np.uint32)
        # New lumpings invalidate everything derived from previous ones -
        # in particular the cached tree read by assign_macrostates().
        self._reset_derived_cache()

        if self.quiet:
            runs = range(self.n_runs)
        else:
            print("Clustering ...")
            runs = tqdm(range(self.n_runs))
        for i in runs:
            self.Z[i], self.full_pop[i] = core.cluster(
                self.tmat,
                self.pop,
                kernel=self.kernel,
                feature_kernel=self.feature_kernel,
            )
        self.assign_macrostates()

    def _add_feature(
        self,
        feature_traj: NDArray[float],
        contact_threshold=0.45,
        feature_type=np.float64,
    ):
        """
        Add feature data to the instance.

        feature_traj (NDArray(float)): frames x features

        Sets ``multi_feature_traj`` (the cast input), ``multi_feature_traj_bool``
        (values < ``contact_threshold``), ``feature_traj`` (per-frame fraction
        of True features) and ``feature`` (per-microstate mean of
        ``feature_traj``).

        Raises ValueError if the input is not 2-D or its number of frames
        differs from ``traj``.
        """
        # Check the rank first so that a 0-d/1-d input gets the intended
        # message instead of an unrelated indexing error.
        if feature_traj.ndim != 2:
            raise ValueError("feature_traj must be 2 D")
        if feature_traj.shape[0] != self.traj.shape[0]:
            raise ValueError(
                "feature_traj must have the same length as the microstate trajectory (mpp.traj)"
            )
        self.multi_feature_traj = feature_traj.astype(feature_type)

        self.contact_threshold = contact_threshold
        self.multi_feature_traj_bool = self.multi_feature_traj < self.contact_threshold
        self.feature_traj = self.multi_feature_traj_bool.mean(axis=1)
        self.feature = np.zeros(self.n_states, dtype=feature_type)
        for i in range(self.n_states):
            self.feature[i] = self.feature_traj[self.traj == i].mean()

    def assign_macrostates(self, macrotraj_type=np.uint8):
        """Assign microstates to macrostates for every run and collect the
        associated data.

        For each run the assignment is read from ``self.tree[n_i]``. It fills
        ``macrostate_assignment`` (macrostates x microstates),
        ``macrostates_map`` (microstate -> macrostate), ``macro_tmat``,
        ``macrotraj``, ``n_macrostates``, ``macrostate_feature`` and
        ``macrostate_multi_feature`` (one entry per run).

        macrotraj_type: dtype of ``macrotraj``; it must hold the largest
            macrostate index.
        """
        self.macrostate_feature = []
        self.macrostate_multi_feature = []
        self.macrostate_assignment = []
        self.macrostates_map = []
        self.macro_tmat = []
        self.macrotraj = np.zeros(
            (self.n_runs, self.traj.shape[0]), dtype=macrotraj_type
        )
        self.n_macrostates = []

        if self.quiet:
            runs = range(self.n_runs)
        else:
            print("Assigning macrostates ...")
            runs = tqdm(range(self.n_runs))
        for n_i in runs:
            self.macrostate_assignment.append(
                utils.get_macrostate_assignment_from_tree(self.tree[n_i])
            )

            # Calculate other macrostate related values
            self.macrostates_map.append(
                np.zeros(self.n_states, dtype=self.traj.dtype.type)
            )
            mas, mis = np.where(self.macrostate_assignment[-1] == 1)
            self.macrostates_map[-1][mis] = mas
            self.macro_tmat.append(
                utils.macro_tmat(self.tmat, self.macrostate_assignment[-1], self.pop)
            )
            self.macrotraj[n_i] = utils.translate_traj(
                self.traj, self.macrostates_map[-1]
            )
            self.n_macrostates.append(self.macrostate_assignment[-1].shape[0])
            self.macrostate_feature.append(
                [
                    self.feature_traj[np.where(self.macrotraj[n_i] == i)].mean()
                    for i in np.arange(self.n_macrostates[-1])
                ]
            )
            self.macrostate_multi_feature.append(
                [
                    self.multi_feature_traj_bool[
                        np.where(self.macrotraj[n_i] == i)
                    ].mean(axis=0)
                    for i in np.arange(self.n_macrostates[-1], dtype=int)
                ]
            )

    def gpcca(self, n_macrostates, macrotraj_type=np.uint8):
        """Lump the microstates into ``n_macrostates`` macrostates with G-PCCA.

        Macrostates are renumbered so that macrostate 0 has the highest mean
        feature value. A mock ``Z``/``full_pop`` (one run) is built for the
        Sankey plot:
        - microstates of one macrostate are merged with q = 0.2;
        - macrostates are merged with each other with q = 0.9.

        ``pop_thr`` is set to 0 and ``q_min`` to 0.5 before the mock tree is
        built, so the tree uses the same thresholds the instance stores.

        NOTE(review): ``self.gpcca`` is overwritten by the ``GPCCA`` object,
        which shadows this method on the instance, so it can be called only
        once per instance. This is left unchanged because callers may read the
        ``gpcca`` attribute.
        """
        self._reset_derived_cache()
        self.gpcca = gp.GPCCA(self.tmat, method="krylov")
        self.gpcca.optimize(n_macrostates)

        self.n_runs = 1
        self.n_macrostates = [n_macrostates]

        gma = self.gpcca.macrostate_assignment
        gmt = np.zeros(self.traj.shape, dtype=self.traj.dtype)
        gmf = np.empty(self.n_macrostates[0])
        for i in range(self.n_macrostates[0]):
            gmt[np.where(np.isin(self.traj, np.where(gma == i)[0]))[0]] = i + 1
            gmf[i] = self.feature_traj[gmt == i + 1].mean()

        # Relabel: the macrostate with the largest mean feature becomes 0.
        order = np.argsort(gmf)[::-1]
        new_states = np.empty(self.n_macrostates[0], dtype=macrotraj_type)
        new_states[order] = np.arange(self.n_macrostates[0], dtype=macrotraj_type)
        self.macrostates_map = [np.empty(gma.shape, dtype=macrotraj_type)]
        for i in range(self.n_macrostates[0]):
            self.macrostates_map[0][np.where(gma == i)] = new_states[i]

        self.macrostate_assignment = [
            np.full((self.n_macrostates[0], self.macrostates_map[0].shape[0]), False)
        ]
        self.macrostate_assignment[0][
            self.macrostates_map[0],
            np.arange(self.macrostates_map[0].shape[0], dtype=int),
        ] = True
        self.macrostate_feature = [gmf[order]]
        self.macrotraj = np.empty(
            (self.n_runs, self.traj.shape[0]), dtype=macrotraj_type
        )
        self.macrotraj[0] = utils.translate_traj(self.traj, self.macrostates_map[0])
        self.macro_tmat = [
            utils.macro_tmat(self.tmat, self.macrostate_assignment[0], self.pop)
        ]
        # Keep this in sync with assign_macrostates(); otherwise a previous
        # lumping's values would remain (or the attribute would be missing).
        self.macrostate_multi_feature = [
            [
                self.multi_feature_traj_bool[self.macrotraj[0] == i].mean(axis=0)
                for i in range(self.n_macrostates[0])
            ]
        ]

        # Create mock Z and mock full_pop for Sankey plot.
        # Each Z row is (origin, target, q, merged population).
        # After implementation remove mock Z.npy file in run.py
        self.Z = np.zeros((self.n_runs, self.n_states - 1, 4), dtype=np.float64)
        self.full_pop = np.zeros((self.n_runs, 2 * self.n_states - 1), dtype=np.uint32)
        self.full_pop[0, : self.n_states] = self.pop

        last_merged = self.n_states
        merge = 0
        for macrostate in range(self.n_macrostates[0]):
            microstates = np.where(self.macrostates_map[0] == macrostate)[0]
            origin = microstates[0]
            if microstates.shape[0] > 1:
                # Chain all microstates of this macrostate together.
                for target in microstates[1:]:
                    intermediate_state = self.n_states + merge
                    self.full_pop[0, intermediate_state] = self.full_pop[
                        0, [origin, target]
                    ].sum()
                    self.Z[0, merge] = (
                        origin,
                        target,
                        0.2,
                        self.full_pop[0, intermediate_state],
                    )
                    origin = intermediate_state
                    merge += 1

            if macrostate > 0:
                # Merge this macrostate into the macrostates merged so far.
                intermediate_state = self.n_states + merge
                target = last_merged
                self.full_pop[0, intermediate_state] = self.full_pop[
                    0, [origin, target]
                ].sum()
                self.Z[0, merge] = (
                    origin,
                    target,
                    0.9,
                    self.full_pop[0, intermediate_state],
                )
                last_merged = intermediate_state
                merge += 1
            else:
                last_merged = origin

        self.pop_thr = 0
        self.q_min = 0.5
        # Accessing the property builds and caches the mock tree.
        self.tree

    def __add__(self, other):
        """'+' operator is used to calculate similarity.

        Exactly one operand must be a single-run (reference) lumping.
        Returns (reference, stochastic, utils.similarity(reference, stochastic)).
        """
        if self.n_runs == 1 and other.n_runs >= 1:
            # reference
            ref = self
            # stochastic lumping
            sto = other
        elif other.n_runs == 1 and self.n_runs >= 1:
            ref = other
            sto = self
        else:
            raise ValueError("The reference lumping must have exactly one run.")
        return ref, sto, utils.similarity(ref, sto)

    @property
    def n_i(self):
        """Index of the run with the longest first implied timescale (0 for a single run)."""
        if self._n_i is None:
            if self.n_runs > 1:
                self._n_i = np.argmax(self.timescales[:, 0])
            else:
                self._n_i = 0
        return self._n_i

    @property
    def timescales(self):
        """Implied timescales per run (n_runs x ntimescales), computed on first access."""
        if self._timescales is None:
            self.calc_timescales()
        return self._timescales

    def calc_timescales(self, ntimescales=3, dtype=np.float32):
        """Calculate the leading ``ntimescales`` implied timescales of every macrotraj at ``tlag``."""
        self._timescales = np.zeros((self.n_runs, ntimescales), dtype=dtype)
        for i, traj in enumerate(self.macrotraj):
            self._timescales[i, :] = mh.msm.implied_timescales(
                utils.get_multi_state_traj(traj, self.limits),
                [self.tlag],
                ntimescales=ntimescales,
            )[0]

    def save_macrotraj(self, out):
        """Write the macrotraj of run ``n_i`` as text to ``out`` with a short header."""
        header = (
            f"# Created by MPT class\n"
            f"# Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
            f"# Trajectory contains {self.n_macrostates[self.n_i]} states and {self.macrotraj.shape[1]} frames.\n"
            f"# Trajectory index: {self.n_i}\n"
        )
        np.savetxt(out, self.macrotraj[self.n_i], fmt="%.0f", header=header)

    def save_Z(self, out, n_i="all"):
        """Save Z matrices to a ``.npy`` file.

        out (str): output path; ``.npy`` is appended if missing.
        n_i: "all" for every run; a Python or NumPy integer for one run
            (saved with a leading axis of length 1); or an iterable of run
            indices.

        Raises ValueError for any other ``n_i``.
        """
        if not out.endswith(".npy"):
            out += ".npy"

        # Test for a string first: comparing an ndarray with "all" is
        # element-wise and has no single truth value.
        if isinstance(n_i, str):
            if n_i != "all":
                raise ValueError("n_i must be 'all', Iterable or int.")
            selection = self.Z
        elif isinstance(n_i, (int, np.integer)):
            selection = self.Z[n_i : n_i + 1]
        elif isinstance(n_i, Iterable):
            # A list selects runs; a tuple would index several axes instead.
            selection = self.Z[list(n_i)]
        else:
            raise ValueError("n_i must be 'all', Iterable or int.")
        np.save(out, selection)

    def from_Z(self, Z):
        """Load lumpings from a Z matrix and assign macrostates.

        Z (np.ndarray | str): array of shape (n_runs, n_states - 1, 4), or
            the path of a ``.npy`` file holding one (e.g. written by save_Z).

        Re-estimates the transition matrix and rebuilds ``full_pop`` from the
        microstate populations and column 3 of ``Z``. It then clears cached
        derived values and calls :meth:`assign_macrostates`.

        Raises ValueError if ``Z`` is neither an array nor an existing path.
        """
        if isinstance(Z, np.ndarray):
            self.Z = Z
        elif os.path.exists(Z):
            self.Z = np.load(Z)
        else:
            raise ValueError("Z must be a numpy array or a .npy file.")

        self.n_runs = self.Z.shape[0]
        tmat, states = mh.msm.estimate_markov_model(
            utils.get_multi_state_traj(self.traj, self.limits),
            self.tlag,
        )
        self.tmat = tmat.astype(float)
        _, self.pop = np.unique(self.traj, return_counts=True)
        self.n_states = len(states)
        self.full_pop = np.zeros((self.n_runs, 2 * self.n_states - 1), dtype=np.uint32)
        self.full_pop[:, : self.n_states] = self.pop
        self.full_pop[:, self.n_states :] = self.Z[:, :, 3]

        # The loaded Z replaces any earlier lumping; drop derived caches so
        # the tree is rebuilt from it.
        self._reset_derived_cache()
        self.assign_macrostates()

    @property
    def linkage(self):
        """Linkage matrix (via ``utils.Z_to_linkage``) of run ``n_i``."""
        if self._linkage is None:
            self._linkage = utils.Z_to_linkage(self.Z[self.n_i])
        return self._linkage

    @property
    def macro_pop(self):
        """Per run, the summed microstate population of each macrostate."""
        if self._macro_pop is None:
            self._macro_pop = []
            for j, ma in enumerate(self.macrostate_assignment):
                self._macro_pop.append(
                    np.zeros(ma.shape[0], dtype=self.full_pop.dtype.type)
                )
                for i, m in enumerate(ma):
                    self._macro_pop[-1][i] = self.full_pop[j, : self.n_states][
                        m.astype(bool)
                    ].sum()
        return self._macro_pop

    @property
    def tree(self):
        """List with the root node of the merge tree of every run."""
        if self._tree is None:
            self._tree = []
            for z, pop in zip(self.Z, self.full_pop):
                self._tree.append(self.build_tree(z, pop))
        return self._tree

    def build_tree(self, Z, full_pop):
        """Build tree using BinaryTreeNode and return root.

        Z: merge steps (rows of origin, target, q, population).
        full_pop: population of every leaf and merged node, indexed by node id.
        Merge step i creates node ``n + i`` with n = number of leaves. Leaves
        get their ``feature`` from ``self.feature``.
        """
        macrostate_thresholds = (self.pop_thr, self.q_min)
        n = Z.shape[0] + 1
        nodes = {}
        for i, (state, target_state, q, pop) in enumerate(Z):
            state = int(state)
            target_state = int(target_state)
            if state not in nodes:
                nodes[state] = core.BinaryTreeNode(
                    state,
                    self.tmat,
                    population=full_pop[state],
                    q=q,
                    macrostate_thresholds=macrostate_thresholds,
                )
            if target_state not in nodes:
                nodes[target_state] = core.BinaryTreeNode(
                    target_state,
                    self.tmat,
                    population=full_pop[target_state],
                    q=q,
                    macrostate_thresholds=macrostate_thresholds,
                )
            nodes[n + i] = core.BinaryTreeNode(
                n + i, self.tmat, q=q, macrostate_thresholds=macrostate_thresholds
            )
            nodes[n + i].left = nodes[state]
            nodes[n + i].right = nodes[target_state]
        for node in nodes[n + i].leaves:
            node.feature = self.feature[node.name]
        return nodes[n + i]

    @property
    def shannon_entropy(self):
        """Shannon entropy of the macrostate populations of every run."""
        if self._shannon_entropy is None:
            self._shannon_entropy = np.zeros(self.n_runs)
            for i, pop in enumerate(self.macro_pop):
                self._shannon_entropy[i] = utils.shannon_entropy(pop)
        return self._shannon_entropy

    @property
    def davies_bouldin_index(self):
        """Davies-Bouldin score of ``multi_feature_traj`` labelled by each run's macrotraj."""
        if self._davies_bouldin_index is None:
            self._davies_bouldin_index = np.zeros(self.n_runs)
            for i in range(self.n_runs):
                self._davies_bouldin_index[i] = davies_bouldin_score(
                    self.multi_feature_traj, self.macrotraj[i]
                )
        return self._davies_bouldin_index

    @property
    def gmrq(self):
        """GMRQ (via ``utils.gmrq``) of the macrostate transition matrices."""
        if self._gmrq is None:
            self._gmrq = utils.gmrq(self.macro_tmat)
        return self._gmrq

    @property
    def reference(self):
        """Single-run MPT of the same data with a default ``MPTKernel``, built on first access."""
        if self._reference is None:
            k = kernel_module.MPTKernel()
            self._reference = MPT(
                self.traj,
                self.tlag,
                self.multi_feature_traj,
                contact_threshold=self.contact_threshold,
                macrostate_thresholds=(self.pop_thr, self.q_min),
                limits=self.limits,
                quiet=True,
            )
            self._reference.mpt(k)
        return self._reference

    @property
    def traj(self):
        """The microstate trajectory - 0-based."""
        return self._traj

    @traj.setter
    def traj(self, value):
        """Store ``value`` as a 0-based trajectory with the smallest fitting uint dtype.

        A trajectory starting at 1 is shifted to 0 with a warning. The
        caller's array is never modified.

        Raises ValueError if the states are not numbered 0..max without gaps.
        """
        value = np.asarray(value)
        if value.min() == 1:
            # Out-of-place shift: ``value -= 1`` changed the caller's array.
            value = value - 1
            warnings.warn("1-based trajectory was shifted to 0-based.")
        # Continuous numbering means exactly max + 1 distinct, non-negative
        # labels. The former ``unique > max + 1`` test could not detect gaps.
        if value.min() < 0 or np.unique(value).shape[0] != value.max() + 1:
            raise ValueError("The state numbering in the trajectory is not continuous")
        if value.max() < 2**7:
            traj_type = np.uint8
        elif value.max() < 2**15:
            traj_type = np.uint16
        else:
            traj_type = np.uint32
        self._traj = value.astype(traj_type)

    def print_rel(self):
        """Print the first-run scores relative to those of ``reference``."""
        for label, ratio in [
            (
                "Implied Timescale: ",
                self.timescales[0, 0] / self.reference.timescales[0, 0],
            ),
            ("GMRQ: ", self.gmrq[0] / self.reference.gmrq[0]),
            (
                "DBI: ",
                # davies_bouldin_index is a property (an array), not a method.
                self.davies_bouldin_index[0]
                / self.reference.davies_bouldin_index[0],
            ),
            ("H: ", self.shannon_entropy[0] / self.reference.shannon_entropy[0]),
        ]:
            print(label + f"{ratio:.2f}")

    @property
    def topology_file(self):
        """Path of the topology file. Raises ValueError if it has not been set."""
        if self._topology_file is None:
            raise ValueError("No topology file set.")
        return self._topology_file

    @topology_file.setter
    def topology_file(self, value):
        if os.path.isfile(value):
            self._topology_file = value
        else:
            raise FileNotFoundError(f"No such file: {value}")

    @property
    def xtc_trajectory_file(self):
        """Path of the xtc trajectory. Raises ValueError if it has not been set."""
        if self._xtc_trajectory_file is None:
            raise ValueError("No xtc trajectory file set.")
        return self._xtc_trajectory_file

    @xtc_trajectory_file.setter
    def xtc_trajectory_file(self, value):
        if os.path.isfile(value):
            self._xtc_trajectory_file = value
        else:
            raise FileNotFoundError(f"No such file: {value}")

    @property
    def rmsd(self):
        """RMSD data from ``utils.calc_rmsd``. Also sets ``mean_frames`` when computed."""
        if self._rmsd is None:
            self._rmsd, self.mean_frames = utils.calc_rmsd(self, quiet=self.quiet)
        return self._rmsd

    @property
    def frame_length(self):
        """The frame_length property. Frame length in ns."""
        return self._frame_length

    @frame_length.setter
    def frame_length(self, value):
        self._frame_length = value

    @property
    def macro_micro_feature(self):
        """Assign macrostate feature values to corresponding microstates (n_states x n_runs)."""
        if self._macro_micro_feature is None:
            self._macro_micro_feature = np.zeros(
                (self.n_states, self.n_runs), dtype=self.feature_traj.dtype.type
            )
            for i, (ma, mf) in enumerate(
                zip(self.macrostate_assignment, self.macrostate_feature)
            ):
                for j, mb in enumerate(ma.astype(bool)):
                    self._macro_micro_feature[mb, i] = mf[j]
        return self._macro_micro_feature

    def save_rmsd(self, out):
        """Save ``rmsd`` with ``np.save``."""
        np.save(out, self.rmsd)

    def load_rmsd(self, f_name):
        """Load ``rmsd`` from a ``.npy`` file (``mean_frames`` is not restored)."""
        self._rmsd = np.load(f_name)

    def write_pdbs(self, out):
        """Write PDB files via ``utils.write_pdbs`` using log(rmsd) and ``mean_frames``."""
        utils.write_pdbs(
            out,
            np.log(self.rmsd),
            self.topology_file,
            self.xtc_trajectory_file,
            self.mean_frames,
        )

    def rmsd_sharpness(self):
        """Population-weighted mean of the per-macrostate mean RMSD of run ``n_i``."""
        return (
            self.rmsd.mean(axis=1) * self.macro_pop[self.n_i]
        ).sum() / self.macro_pop[self.n_i].sum()

    def draw_random_frames_indices(self, out=None, n=20):
        """
        Draw n random frames for each macrostate

        out (str): Path to directory where to save the .random[n] files
        n (int): number of frames to draw randomly (without replacement)

        Indices are multiplied by ``xtc_stride`` if it is set. If ``out`` is
        given, one ``NN.ndx`` file per macrostate is written; otherwise the
        (macrostates x n) index array is returned.
        """
        drawn_frames = np.empty((self.n_macrostates[self.n_i], n), dtype=int)
        for state in np.arange(self.n_macrostates[self.n_i]):
            frames_in_state = np.where(self.macrotraj[self.n_i] == state)[0]
            drawn_frames[state] = np.random.choice(
                frames_in_state, size=n, replace=False
            )
        if self.xtc_stride is not None:
            drawn_frames *= self.xtc_stride
        if out:
            Path(out).mkdir(parents=True, exist_ok=True)
            for s, i in enumerate(drawn_frames):
                np.savetxt(
                    os.path.join(out, f"{s + 1:02d}.ndx"),
                    i,
                    fmt="%.0f",
                    header="[frames]",
                )
        else:
            return drawn_frames

    def draw_random_frames(self, out, n=20):
        """
        Draw n random frames for each macrostate

        out (str): Path to directory where to save the pdb files (created if
            missing)
        n (int): number of frames to draw randomly (without replacement)

        Frame indices are multiplied by ``xtc_stride`` when it is set, as in
        draw_random_frames_indices, before loading them from the xtc file.
        """
        Path(out).mkdir(parents=True, exist_ok=True)
        for state in np.arange(self.n_macrostates[self.n_i]):
            frames_in_state = np.where(self.macrotraj[self.n_i] == state)[0]
            drawn_frames = np.random.choice(frames_in_state, size=n, replace=False)
            if self.xtc_stride is not None:
                drawn_frames = drawn_frames * self.xtc_stride
            for i, frame in enumerate(drawn_frames):
                f = md.load_xtc(
                    self.xtc_trajectory_file,
                    top=self.topology_file,
                    frame=frame,
                )
                f.save_pdb(os.path.join(out, f"S{state}_{i:02d}.pdb"))

    def get_best_defined_contacts(self, n=3):
        """Return, per macrostate of run ``n_i``, the ``n`` feature columns with the lowest variance."""
        contacts = np.zeros((self.n_macrostates[self.n_i], n), dtype=int)
        for i in range(self.n_macrostates[self.n_i]):
            contacts[i] = np.argsort(
                np.var(
                    self.multi_feature_traj[self.macrotraj[self.n_i] == i],
                    axis=0,
                )
            )[:n]
        return contacts

    def get_least_moving_residues(self, contact_index_file, n=3):
        """Map the best-defined contacts of each macrostate to the unique residue indices in ``contact_index_file``."""
        contact_indices = np.loadtxt(contact_index_file, dtype=int)
        contacts = self.get_best_defined_contacts(n)
        least_moving_residues = []
        for c in contacts:
            least_moving_residues.append(np.unique(contact_indices[c].flatten()))
        return least_moving_residues

    def write_least_moving_residues(self, contact_index_file, out, n=3):
        """Write one line of residue indices per macrostate to ``out``.

        Writes an empty file if ``contact_index_file`` is "none".
        """
        if contact_index_file != "none":
            least_moving_residues = self.get_least_moving_residues(
                contact_index_file, n=n
            )
            with open(out, "w") as f:
                for residues in least_moving_residues:
                    f.write(f"{' '.join(residues.astype(str))}\n")
        else:
            with open(out, "w") as f:
                f.write("")

    ### PLOT METHODS #########################################################

    def plot(self, out: str, scale=1, offset=0):
        """Plot dendrogram"""
        plot.plot_tree(
            self.tree[self.n_i],
            self.macrostate_assignment[self.n_i],
            out,
            scale=scale,
            offset=offset,
        )

    def plot_implied_timescales(self, out, use_ref=True, scale=1):
        """
        out: File to write plot
        use_ref: whether the reference lumping's macrotraj (True) or the
            microstate trajectory (False) is plotted for comparison
        scale: scaling factor for plot
        """
        if use_ref:
            ref_traj = self.reference.macrotraj[0]
        else:
            ref_traj = self.traj

        macrotraj = utils.get_multi_state_traj(self.macrotraj[self.n_i], self.limits)

        # Lag step of roughly 1 ns, but at least one frame.
        dtlag = max(1, int(1 / self.frame_length))
        plot.plot_implied_timescales(
            [ref_traj, macrotraj],
            np.arange(1, 4.5 * self.tlag + dtlag, dtlag, dtype=int),
            out,
            frame_length=self.frame_length,
            first_ref=True,
            scale=scale,
            use_ref=use_ref,
            ntimescales=self.timescales.shape[1],
        )

    def plot_macro_feature(self, out, ref=None):
        """
        Plot histogram of the macrostate feature distribution
        (``macro_micro_feature``).

        out (str): file to save the plot
        ref: clusterings that should be shown explicitly; defaults to
            ``self.reference``. Presumably a list of tuples of
            (macrostate_assignment, macrostate_feature, color, label) -
            TODO confirm against plot.plot_macro_feature.
        """
        plot.plot_macro_feature(
            self.macro_micro_feature, out, self.reference if ref is None else ref
        )

    def plot_rmsd(self, out, helices=None):
        plot.plot_rmsd(self.rmsd, self.macro_pop[self.n_i], helices, out)

    def plot_delta_rmsd(self, out, helices=None):
        plot.plot_delta_rmsd(self.rmsd, self.macro_pop[self.n_i], helices, out)

    def plot_contact_rep(self, cluster_file, out, scale=1):
        plot.contact_rep(
            self.multi_feature_traj,
            cluster_file,
            self.macrotraj[self.n_i],
            out,
            utils.get_grid_format(self.n_macrostates[self.n_i]),
            scale=scale,
        )

    def plot_relative_implied_timescales(self, out):
        plot.relative_implied_timescales(self, out)

    def plot_ck_test(self, out):
        plot.chapman_kolmogorov(self, out, self.frame_length)

    def plot_state_network(self, out):
        plot.state_network(self, out)

    def plot_stochastic_state_similarity(self, out):
        plot.evaluate_stochastic_clustering(self, self.reference, out)

    def plot_transition_matrix(self, out):
        plot.transition_matrix(self.macro_tmat[self.n_i], out)

    def plot_transition_time(self, out):
        plot.transition_time(
            self.macro_tmat[self.n_i],
            out,
            tlag=self.tlag,
            frame_length=self.frame_length,
        )

    def plot_graph(self, out, u=0, f=0):
        draw_knetwork(
            self.macrotraj[self.n_i], self.tlag, self.feature_traj, out, u=u, f=f
        )

    def plot_sankey(self, out, ax=None, scale=1):
        plot.plot_sankey(self, self.reference, out, ax=ax, scale=scale)

    def plot_macrotraj(self, out, row_length=0.2):
        plot.plot_state_trajectory(
            self.macrotraj[self.n_i],
            out,
            row_length=row_length,
            frame_length=self.frame_length,
        )
773 )