Coverage for MPT/plot.py: 90%
643 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 11:01 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 11:01 +0200
1#!/usr/bin/env python3
2"""
3plot.py
4==================
6Various plot functions used in this package.
7"""
9from os.path import splitext
11import numpy as np
12import prettypyplot as pplt
13import matplotlib as mpl
14from matplotlib import pyplot as plt
15from matplotlib.cm import ScalarMappable
16from matplotlib.colors import (
17 LinearSegmentedColormap,
18 LogNorm,
19 ListedColormap,
20)
21from matplotlib import colors
22from matplotlib.cbook import boxplot_stats
23import matplotlib.animation as animation
24import matplotlib.patches as patches
25from matplotlib.ticker import MultipleLocator
26import msmhelper as mh
27from msmhelper._cli.contact_rep import load_clusters
28from scipy.stats import pearsonr
29import MPT.utils as utils
30from MPT.sankey_gap import sankey
31from MPT.graph import draw_knetwork
33plt.rcParams["font.family"] = "sans-serif"
35### DENDROGRAM ###############################################################
def plot_tree(root, macrostate_assignment, output_file, scale=1, offset=0):
    """
    Draw the dendrogram of a state tree plus a strip of macrostate memberships.

    The upper panel is drawn by ``root.plot_tree``. The thin lower panel
    shows one row per macrostate and marks the microstates assigned to it;
    each row is annotated with its 1-based macrostate number.

    root: root node of the tree; the code uses ``root.leaves`` (each leaf
        has a ``name`` used as column index), ``root.plot_tree(ax)`` and
        ``root.feature_norm``.
    macrostate_assignment (np.ndarray): 2D array with one row per macrostate
        and one column per microstate; entries equal to 1 mark membership.
    output_file (str): file to save the plot.
    scale (float): multiplier applied to the base figure size of 3.2.
    offset (float): lower limit of the metastability axis.
    """
    n_leaves = len(root.leaves)

    # configure prettypyplot / matplotlib
    pplt.use_style(figsize=3.2 * scale, figratio="golden", true_black=True)
    plt.rcParams["font.family"] = "sans-serif"

    layout = {"hspace": 0.05, "height_ratios": [9, 1]}
    fig, (ax, ax_mat) = plt.subplots(2, 1, gridspec_kw=layout)
    for spine in ax_mat.spines.values():
        spine.set_visible(False)

    # the tree draws itself and hands back the axes to continue with
    ax = root.plot_tree(ax)

    ax.set_ylabel(r"Metastability $T_{ii}$")
    ax.set_xlabel("microstates")
    ax.set_xlim(-0.005 * n_leaves, 1.005 * n_leaves)
    ax.set_ylim(offset, 1.05)

    # colorbar on top of the dendrogram
    feature_mappable = ScalarMappable(
        root.feature_norm,
        plt.get_cmap("plasma_r", 10),
    )
    plt.sca(ax)
    pplt.colorbar(
        feature_mappable,
        width="5%",
        label=r"Fraction of Contacts $q$",
        position="top",
    )

    # reorder the columns so that they follow the leaf order of the tree
    leaf_order = [leaf.name for leaf in root.leaves]
    macrostate_assignment = macrostate_assignment[:, leaf_order]

    y_edges = np.arange(0.5, 1.5 + macrostate_assignment.shape[0])
    x_edges = np.arange(0, n_leaves + 1)
    # fully transparent for 0, opaque black for 1
    strip_cmap = LinearSegmentedColormap.from_list(
        "binary",
        [(0, 0, 0, 0), (0, 0, 0, 1)],
    )

    # number every macrostate at the median position of its members
    x_centers = 0.5 * (x_edges[:-1] + x_edges[1:])
    for idx, assignment in enumerate(macrostate_assignment):
        pplt.text(
            np.median(x_centers[assignment == 1]),
            y_edges[idx] - (y_edges[1] - y_edges[0]),
            f"{idx + 1:.0f}",
            ax=ax_mat,
            va="top",
            contour=True,
            size="small",
        )

    # membership strip
    ax_mat.pcolormesh(
        x_edges,
        y_edges,
        macrostate_assignment,
        snap=True,
        cmap=strip_cmap,
        vmin=0,
        vmax=1,
    )
    ax_mat.set_yticks(y_edges)
    ax_mat.set_yticklabels([])
    ax_mat.grid(visible=True, axis="y", ls="-", lw=0.5)
    ax_mat.tick_params(axis="y", length=0, width=0)
    ax_mat.set_xlim(ax.get_xlim())
    ax.set_xlabel("")
    ax_mat.set_xlabel("Macrostates")
    ax_mat.set_ylabel("")
    fig.align_ylabels([ax, ax_mat])

    ax_mat.set_xticks(np.arange(0.5, 0.5 + n_leaves))

    # microstate ticks and labels are hidden on both panels
    for panel in (ax, ax_mat):
        for minor in (False, True):
            panel.set_xticks([], minor=minor)
        for minor in (False, True):
            panel.set_xticklabels([], minor=minor)

    pplt.savefig(output_file)
    plt.close()
132### SIMILARITY ###############################################################
def evaluate_stochastic_clustering(mpt1, mpt2, out):
    """
    Plot similarity values for a reference and a stochastic clustering.

    One histogram panel is drawn per macrostate; each panel overlays the
    three similarity distributions obtained from ``mpt1 + mpt2``.

    mpt1, mpt2: objects whose sum unpacks into ``(ref, sto, S)``. ``S``
        unpacks into three arrays that are indexed by state, ``S.shape[1]``
        is taken as the number of states, and ``sto.n_runs`` is shown in
        the y-label.
    out (str): file to save the plot.
    """
    _, sto, similarities = mpt1 + mpt2
    s1, s2, s3 = similarities
    n_states = similarities.shape[1]
    x, y = utils.get_grid_format(n_states)
    # squeeze=False keeps `axs` a 2D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes without `.flatten()`.
    fig, axs = plt.subplots(y, x, figsize=(2 * x, 2 * y), squeeze=False)
    for state, ax in enumerate(axs.flatten()[:n_states]):
        # Left limit slightly below the smallest observed value instead of 0
        lower = min(min(s1[state]), min(s2[state]), min(s3[state])) - 0.02
        bins = np.linspace(lower, 1, 21)

        ax.hist(s1[state], bins=bins, color="g", alpha=0.7)
        ax.hist(s2[state], bins=bins, color="b", alpha=0.7)
        ax.hist(s3[state], bins=bins, color="r", alpha=0.7)
        ax.set_title(f"state {state + 1}")
    fig.supxlabel("Macrostate similarity")
    fig.supylabel(f"Count of clusterings ({sto.n_runs} clusterings)")
    plt.figlegend(
        ["union", "reference", "clustering"],
        ncols=3,
        loc="lower center",
        bbox_to_anchor=(0.5, 0.05),
    )
    # leave room at the bottom for the figure legend
    plt.tight_layout(rect=(0, 0.04, 1, 1))
    plt.savefig(out)
    plt.close()
166### IMPLIED TIMESCALES #######################################################
def plot_implied_timescales(
    trajs,
    lagtimes,
    out,
    titles="",
    frame_length=0.2,
    first_ref=False,
    scale=1,
    use_ref=True,
    ntimescales=3,
):
    """
    Plot implied timescales over the lag time, one panel per trajectory.

    trajs (list): state trajectories handed to
        ``mh.msm.implied_timescales``. With `first_ref` the first entry is
        the reference: it gets no panel of its own but is drawn into every
        panel. The list passed in is not modified.
    lagtimes (np.ndarray): lag times in frames.
    out (str): file to save the plot.
    titles (list[str]): one title per panel; empty or None for no titles.
    frame_length (float): frame length in ns / frame, used to convert
        frames to ns.
    first_ref (bool): treat ``trajs[0]`` as the reference.
    scale (float): multiplier for the base figure size.
    use_ref (bool): only selects how the reference is drawn and labelled
        (":" with "ref" labels if True, "--" with "mic" labels if False).
    ntimescales (int): number of timescales to compute and plot.
    """
    if first_ref:
        # Split off the reference without mutating the caller's list
        # (the former `trajs.pop(0)` removed it from the argument itself).
        ref_traj, *trajs = trajs
    x, y = utils.get_grid_format(len(trajs))
    pplt.use_style(
        figsize=(3.8 * scale, 3.2 * scale), latex=False, colors="pastel_autumn"
    )
    fig, axs = plt.subplots(y, x, sharex=True, sharey=True)
    plt.grid(False)
    if not isinstance(axs, np.ndarray):
        # a 1x1 grid yields a bare Axes
        axs = np.array([axs])

    if titles is None or len(titles) == 0:
        titles = [""] * len(trajs)

    min_it = None
    max_it = None

    if first_ref:
        it_ref = mh.msm.implied_timescales(ref_traj, lagtimes, ntimescales=ntimescales)
        # change from frames to ns
        it_ref *= frame_length
        min_it = it_ref.min()
        max_it = it_ref.max()

    tlag = lagtimes[-1] / 4.5 * frame_length
    lagtimes_ns = lagtimes * frame_length
    ref_ls = ":" if use_ref else "--"
    for ax, traj, title in zip(axs.flatten(), trajs, titles):
        ax.axvline(tlag, color="pplt:grid")
        it = mh.msm.implied_timescales(traj, lagtimes, ntimescales=ntimescales)
        # change from frames to ns
        it *= frame_length
        min_it = it.min() if min_it is None else min(it.min(), min_it)
        max_it = it.max() if max_it is None else max(it.max(), max_it)

        if first_ref:
            _plot_impl_times(it_ref, lagtimes_ns, ax, ls=ref_ls)
        _plot_impl_times(it, lagtimes_ns, ax)
        ax.set_yscale("log")
        ax.set_title(title)

    for ax in axs.flatten():
        ax.set_ylim(min(min_it * 0.9, int(lagtimes_ns.shape[0] / 4)), max_it * 1.5)

    # axis labels only on the outer panels
    if axs.ndim == 2:
        for ax in axs[-1]:
            ax.set_xlabel(r"lag time $\tau$ / ns")
        for row in axs:
            for ax in row[1:]:
                plt.setp(ax.get_yticklabels(), visible=False)
        for ax in axs[:, 0]:
            ax.set_ylabel("time scale / ns")
    elif axs.ndim == 1:
        axs[0].set_ylabel("time scale / ns")
        for ax in axs:
            ax.set_xlabel(r"lag time $\tau$ / ns")
        for ax in axs[1:]:
            plt.setp(ax.get_yticklabels(), visible=False)

    # Take the legend entries from the first panel: it is always populated,
    # whereas the current axes (the last grid cell) may be empty when the
    # grid has more cells than trajectories.
    handles, labels = axs.flatten()[0].get_legend_handles_labels()

    if first_ref:
        # Each panel holds the reference lines first, then the trajectory's
        # own lines. Interleave them so that each timescale's pair is
        # adjacent (column-major ordering in the legend). Without a
        # reference there are only `ntimescales` handles and nothing to
        # reorder.
        desired_order = [
            k for i in range(ntimescales) for k in (i + ntimescales, i)
        ]
        handles = [handles[k] for k in desired_order]
        labels = [labels[k] for k in desired_order]

    pplt.legend(
        handles=handles, labels=labels, outside="top", frameon=False, ncols=ntimescales
    )

    plt.tight_layout()
    plt.savefig(out)
    plt.close()
def _plot_impl_times(impl_times, lagtimes, ax, ls="-"):
    """
    Plot the implied timescales into `ax`.

    impl_times (np.ndarray): 2D array; each column is one timescale
        evaluated at every lag time.
    lagtimes (np.ndarray): 1D array of lag times used as x-values.
    ax: axes-like object to draw into.
    ls (str): line style. It also selects the legend label: ":" gives
        "ref" labels, "--" gives "mic" labels, anything else plain ones.
    """
    # five base colors, cycled for any number of timescales
    palette = ["#264653", "#2A9D8F", "#E9C46A", "#f4a261", "#e76f51"]
    for idx, impl_time in enumerate(impl_times.T):
        if ls == ":":
            label = f"$t_{{\\mathrm{{ref}},{idx + 1}}}$"
        elif ls == "--":
            label = f"$t_{{\\mathrm{{mic}},{idx + 1}}}$"
        else:
            # braces keep multi-digit indices inside the subscript
            label = f"$t_{{{idx + 1}}}$"
        ax.plot(
            lagtimes,
            impl_time,
            label=label,
            color=palette[idx % len(palette)],
            ls=ls,
        )

    xlim = lagtimes[0], lagtimes[-1]
    ref_low = int(lagtimes.shape[0] / 4)
    ax.set_xlim(xlim)
    # highlight diagonal
    x_i = np.arange(ref_low, xlim[1])
    ax.fill_between(x_i, x_i, color="pplt:grid")
def relative_implied_timescales(cl, out):
    """
    Plot histograms of implied timescales relative to the reference.

    Three panels are drawn: the first relative timescale, the mean over all
    relative timescales of a clustering, and the number of macrostates.

    cl: object providing ``timescales`` (2D, one row per clustering),
        ``reference.timescales`` (used as divisor) and ``n_macrostates``.
    out (str): file to save the plot.
    """
    pplt.use_style(figsize=(8, 2.5), latex=False, colors="pastel_autumn")

    ref = cl.reference
    its = cl.timescales / ref.timescales

    fig = plt.figure()
    ax1 = fig.add_subplot(1, 3, 1)
    ax2 = fig.add_subplot(1, 3, 2, sharey=ax1)
    ax3 = fig.add_subplot(1, 3, 3)

    for ax in (ax1, ax2, ax3):
        ax.grid(False)

    xlabel = (
        r"Relative Implied Timescale $\left(\frac{t_\mathrm{stoch}}{t_\mathrm{ref}}\right)$"
    )

    ax1.hist(its[:, 0], bins=20)
    ax1.set_title("its 1")
    ax1.set_xlabel(xlabel)
    ax1.set_ylabel("Count of Clusterings")

    ax2.hist(its.mean(axis=1), bins=20)
    # The mean runs over every available timescale, so the title reports
    # the actual range instead of a fixed "1-3".
    ax2.set_title(f"Mean its 1-{its.shape[1]}")
    ax2.set_xlabel(xlabel)

    # bin edges on half-integers so that every integer count has its own bin
    bins = np.arange(min(cl.n_macrostates) - 1, max(cl.n_macrostates) + 1) + 0.5

    ax3.hist(cl.n_macrostates, bins=bins)
    ax3.set_title("n macrostates")
    ax3.set_xlabel("macrostate count")

    plt.tight_layout()
    plt.savefig(out)
    plt.close()
328### SIMILARITY MATRIX ########################################################
def plot_heatmap(a, out, title=""):
    """
    Render a matrix as a heatmap with a logarithmic color scale.

    Meant for a similarity matrix as returned from the multiplication of
    two MPT objects.

    a (np.ndarray): matrix to show; one tick is placed per row and column.
    out (str): file to save the plot.
    title (str): optional axes title; nothing is set when empty.
    """
    n_rows, n_cols = a.shape[:2]

    fig, ax = plt.subplots()
    ax.imshow(a, norm="log")
    ax.set_aspect("equal", "box")

    # one tick per matrix column / row
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    if title:
        ax.set_title(title)
    for set_label in (ax.set_xlabel, ax.set_ylabel):
        set_label("Macrostate")

    plt.tight_layout()
    plt.savefig(out)
    plt.close()
def transition_matrix(a, out, title="Transition Matrix", color_thr=0.01):
    """
    Plot a transition matrix as an annotated heatmap, values in percent.

    Diagonal and off-diagonal entries use separate logarithmic color
    scales (Reds and viridis). Entries equal to zero are drawn white
    without text; off-diagonal entries below the threshold are light gray.

    a (np.ndarray): square matrix of transition probabilities (fractions;
        they are multiplied by 100 for display).
    out (str): file to save the plot.
    title (str): axes title; nothing is set when empty.
    color_thr (float): fraction of the largest off-diagonal value below
        which off-diagonal cells are drawn gray.

    NOTE(review): cell ``a[i, j]`` is drawn at x=j, y=i while the x-axis is
    labelled "From" and the y-axis "To" - confirm this matches the
    convention of the matrices passed in.
    """
    # Scale a to percent
    a = a * 100

    # Diagonal elements: logarithmic Reds. LogNorm cannot work with a
    # non-positive lower bound, so zero entries (drawn white anyway) are
    # left out when picking vmin.
    diagonal_values = np.diag(a)
    positive_diag = diagonal_values[diagonal_values > 0]
    diag_vmin = positive_diag.min() if positive_diag.size else diagonal_values.min()
    diag_norm = LogNorm(vmin=diag_vmin, vmax=diagonal_values.max())

    # Cut off the pale end of Reds so that the lower bound is closer to red
    diag_cmap_custom = ListedColormap(plt.cm.Reds(np.linspace(0.2, 1, 256)))

    # Off-diagonal elements: logarithmic viridis
    off_diag_mask = ~np.eye(a.shape[0], dtype=bool)
    off_diag_max = a[off_diag_mask].max()

    # Threshold for light gray
    threshold = color_thr * off_diag_max
    print(f"Threshold for probabilities: {threshold:.3f} %")

    off_diag_norm = LogNorm(vmin=threshold * (1 - color_thr), vmax=off_diag_max)

    # viridis with its lowest entries replaced by light gray
    colors_list = plt.cm.viridis(np.linspace(0, 1, 256))
    gray = np.array([0.9, 0.9, 0.9, 1.0])
    colors_list[: int(color_thr * 256)] = gray
    custom_off_diag_cmap = ListedColormap(colors_list)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect("equal", "box")
    ax.grid(False)

    # weights for the perceived brightness of an RGB color
    luma = np.array([0.299, 0.587, 0.114])

    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            value = a[i, j]
            if value == 0:
                color = (1, 1, 1, 1)  # Zero probabilities are white
            elif i == j:
                color = diag_cmap_custom(diag_norm(value))
            elif value < threshold:
                color = gray
            else:
                color = custom_off_diag_cmap(off_diag_norm(value))

            ax.add_patch(patches.Rectangle((j - 0.5, i - 0.5), 1, 1, color=color))

            # Add text with transition probabilities
            if value != 0:
                grayscale = np.sum(np.array(color[:3]) * luma)
                # white text on dark cells, black text on bright ones
                text_color = "white" if grayscale < 0.5 else "black"
                ax.text(
                    j,
                    i,
                    f"{value:.2f}%",
                    ha="center",
                    va="center",
                    color=text_color,
                    fontsize=10,
                )

    ax.set_xticks(np.arange(a.shape[1]))
    ax.set_yticks(np.arange(a.shape[0]))
    ax.set_xticklabels(np.arange(1, a.shape[1] + 1))
    ax.set_yticklabels(np.arange(1, a.shape[0] + 1))
    ax.set_xlim(-0.5, a.shape[1] - 0.5)
    ax.set_ylim(-0.5, a.shape[0] - 0.5)

    # Colorbar for diagonal values; it uses the same truncated colormap as
    # the cells so that bar and cells agree.
    cbar_diag = fig.colorbar(
        plt.cm.ScalarMappable(norm=diag_norm, cmap=diag_cmap_custom),
        ax=ax,
        shrink=0.5,
    )
    cbar_diag.set_label("Self Transition Probabilities / \\%")

    # Colorbar for off-diagonal values
    cbar_off_diag = fig.colorbar(
        plt.cm.ScalarMappable(norm=off_diag_norm, cmap=custom_off_diag_cmap),
        ax=ax,
        shrink=0.5,
    )
    cbar_off_diag.set_label("Transition Probabilities / \\%")

    if title:
        ax.set_title(title)

    ax.set_xlabel("From Macrostate")
    ax.set_ylabel("To Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()
def transition_time(
    a,
    out,
    tlag=50.0,
    frame_length=0.2,
    title=r"Transition Times $\frac{t_\mathrm{lag}}{P}$",
    color_thr=0.01,
):
    """
    Plot transition times ``tlag / a * frame_length`` as annotated heatmap.

    Diagonal and off-diagonal entries use separate logarithmic color
    scales (reversed Reds and reversed viridis). Entries with zero
    probability (infinite time) are drawn white without text; off-diagonal
    times above the threshold are light gray.

    a (np.ndarray): square matrix of transition probabilities.
    out (str): file to save the plot.
    tlag (float): lag time, multiplied with `frame_length`.
    frame_length (float): frame length in ns.
    title (str): axes title; nothing is set when empty.
    color_thr (float): the gray threshold is the smallest off-diagonal
        time divided by this value.

    NOTE(review): cell ``a[i, j]`` is drawn at x=j, y=i while the x-axis is
    labelled "From" and the y-axis "To" - confirm this matches the
    convention of the matrices passed in.
    """
    # zero probabilities turn into infinite times on purpose
    with np.errstate(divide="ignore"):
        a = tlag / a * frame_length

    # Diagonal elements: logarithmic reversed Reds. Infinite entries (drawn
    # white anyway) must not define the limits, LogNorm cannot handle them.
    diagonal_values = np.diag(a)
    finite_diag = diagonal_values[np.isfinite(diagonal_values)]
    if finite_diag.size == 0:
        finite_diag = diagonal_values
    diag_norm = LogNorm(vmin=finite_diag.min(), vmax=finite_diag.max())

    # Cut off the pale end of the colormap to stay closer to red
    diag_cmap_custom = ListedColormap(plt.cm.Reds_r(np.linspace(0, 0.8, 256)))

    # Off-diagonal elements: logarithmic reversed viridis
    off_diag_mask = ~np.eye(a.shape[0], dtype=bool)
    off_diag_min = a[off_diag_mask].min()

    # Threshold for light gray
    threshold = off_diag_min / color_thr
    print(f"Threshold for transition times: {threshold:.2f} ns")

    off_diag_norm = LogNorm(vmin=off_diag_min, vmax=threshold / (1 - color_thr))

    # reversed viridis with its highest entries replaced by light gray
    colors_list = plt.cm.viridis_r(np.linspace(0, 1, 256))
    gray = np.array([0.9, 0.9, 0.9, 1.0])
    colors_list[int((1 - color_thr) * 256) :] = gray
    custom_off_diag_cmap = ListedColormap(colors_list)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect("equal", "box")
    ax.grid(False)

    # weights for the perceived brightness of an RGB color
    luma = np.array([0.299, 0.587, 0.114])

    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            value = a[i, j]
            is_infinite = value == np.inf
            if is_infinite:
                color = (1, 1, 1, 1)  # Zero probabilities are white
            elif i == j:
                color = diag_cmap_custom(diag_norm(value))
            elif value > threshold:
                color = gray
            else:
                color = custom_off_diag_cmap(off_diag_norm(value))

            ax.add_patch(patches.Rectangle((j - 0.5, i - 0.5), 1, 1, color=color))

            # Add text with transition times
            if not is_infinite:
                grayscale = np.sum(np.array(color[:3]) * luma)
                # white text on dark cells, black text on bright ones
                text_color = "white" if grayscale < 0.5 else "black"
                ax.text(
                    j,
                    i,
                    _format_transition_time(value, threshold),
                    ha="center",
                    va="center",
                    color=text_color,
                    fontsize=10,
                )

    ax.set_xticks(np.arange(a.shape[1]))
    ax.set_yticks(np.arange(a.shape[0]))
    ax.set_xticklabels(np.arange(1, a.shape[1] + 1))
    ax.set_yticklabels(np.arange(1, a.shape[0] + 1))
    ax.set_xlim(-0.5, a.shape[1] - 0.5)
    ax.set_ylim(-0.5, a.shape[0] - 0.5)

    # Colorbar for diagonal values; it uses the same truncated colormap as
    # the cells so that bar and cells agree.
    cbar_diag = fig.colorbar(
        plt.cm.ScalarMappable(norm=diag_norm, cmap=diag_cmap_custom),
        ax=ax,
        shrink=0.5,
    )
    cbar_diag.set_label("Self Transition Times / ns")

    # Colorbar for off-diagonal values
    cbar_off_diag = fig.colorbar(
        plt.cm.ScalarMappable(norm=off_diag_norm, cmap=custom_off_diag_cmap),
        ax=ax,
        shrink=0.5,
    )
    cbar_off_diag.set_label("Transition Times / ns")

    if title:
        ax.set_title(title)

    ax.set_xlabel("From Macrostate")
    ax.set_ylabel("To Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()


def _format_transition_time(value, threshold):
    """Return a compact text label for a finite transition time."""
    if value >= threshold:
        # One significant digit. Exponent notation is shortened
        # ("1e+04" -> "1e4") by parsing the exponent, so its sign and all
        # of its digits survive; plain notation ("5") is kept as is.
        text = f"{value:.1g}"
        if "e" not in text:
            return text
        mantissa, exponent = text.split("e")
        return f"{mantissa}e{int(exponent)}"
    if value >= 100:
        return f"{value:.0f}"
    return f"{value:#.3g}"
565### MACROSTATE FEATURES ######################################################
def plot_macro_feature(micro_feature, out, ref=None, pop=None):
    """
    Plot a histogram of the macrostate feature distribution.

    micro_feature (np.ndarray, NxR): N microstates, R runs, holds feature
        values of respective macrostate
    out (str): file to save the plot
    ref: optional clustering to mark explicitly via `add_ref`; the code
        reads ``ref.macrostate_assignment[ref.n_i]`` and
        ``ref.macrostate_feature[ref.n_i]``.
    pop (np.ndarray): optional weights handed to ``np.histogram``.
    """
    n_runs = micro_feature.shape[1]

    # 100 bins spanning slightly more than the observed feature range
    edges = np.linspace(
        micro_feature.min() * 0.95,
        micro_feature.max() * 1.05,
        101,
    )
    density, bins = np.histogram(
        micro_feature,
        bins=edges,
        weights=pop,
        density=True,
    )
    per_run = density / n_runs
    # lower axis limit a bit below the smallest non-empty bin
    y_min = per_run[per_run > 0].min() * 0.7

    fig, ax = plt.subplots(figsize=(8, 6))
    # draw the precomputed bin heights as a histogram
    ax.hist(bins[:-1], bins=bins, weights=per_run, label="Stochastic Clustering")
    if ref is not None:
        add_ref(ref.macrostate_assignment[ref.n_i], ref.macrostate_feature[ref.n_i], ax)
    ax.set_xlabel("Fraction of Contacts")
    ax.set_ylabel("Population")
    ax.set_title(f"Macrostate Features, {n_runs} clusterings")
    ax.set_yscale("log")
    _, y_max = ax.get_ylim()
    ax.set_ylim((y_min, y_max))
    plt.legend(loc="lower left")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()
def add_ref(
    macrostate_assignment,
    macrostate_feature,
    ax,
    color="r",
    label="Reference",
    weights=None,
):
    """
    Add a clustering to the histogram.

    Every macrostate is drawn as a vertical line at its feature value. The
    line height is the weighted sum of its assignment row scaled by 1e-3
    (hence the " / 1000" in the legend entry), and the line is annotated
    with the 1-based macrostate number.

    macrostate_assignment (np.ndarray, MxN): macrostate assignment, M: number
        of macrostates, N: number of microstates.
    macrostate_feature (np.ndarray, M): mean feature for every macrostate.
    ax: axes to draw into.
    color: color of lines and numbers.
    label (str): legend label, attached to the first line only.
    weights (np.ndarray): optional weights multiplied with each assignment
        row; without weights every entry counts as 1.
    """
    # The weights do not depend on the macrostate, so they are prepared once
    # instead of being renormalised on every iteration.
    if weights is None:
        weights = np.array([1])
    else:
        weights = weights / weights.sum()
    weight_sum = weights.sum()

    pairs = zip(macrostate_assignment, macrostate_feature)
    for i, (ma, mf) in enumerate(pairs):
        x = [mf, mf]
        # tiny positive lower end so the line survives a log-scaled axis
        y = [1e-9, (ma * weights).sum() / weight_sum * 1e-3]
        if i == 0:
            # only one legend entry for the whole clustering
            ax.plot(x, y, c=color, label=label + " / 1000")
        else:
            ax.plot(x, y, c=color)
        pplt.text(
            mf + 0.015,
            y[1] * 0.82,
            f"{i + 1:.0f}",
            c=color,
            ax=ax,
            contour=True,
            size="small",
        )
650### CONTACT REPRESENTATION ###################################################
def contact_rep(contacts, cluster_file, state_traj, output, grid, scale=1):
    """
    Adapted from msmhelper.

    Contact representation of states.

    This script creates a contact representation of states. Were the states
    are obtained by [MoSAIC](https://github.com/moldyn/MoSAIC) and the contact
    representation was introduced in Nagel et al.[^1].

    [^1]: Nagel et al., **Selecting Features for Markov Modeling: A Case Study
    on HP35.**, *J. Chem. Theory Comput.*, submitted,

    contacts (np.ndarray): 2D array, one row per frame and one column per
        contact.
    cluster_file (str): file handed to ``load_clusters``; the clusters
        define order and grouping of the contacts on the x-axis.
    state_traj (np.ndarray): state of every frame.
    output (str): file to save the first figure to. Further figures get
        ``.state<first>-<last>`` inserted before the extension. If None,
        the figures are shown instead of saved.
    grid (tuple): ``(nrows, ncols)``; at most ``nrows * ncols`` states are
        put on one figure.
    scale (float): multiplier for the base figure size.
    """
    # setup matplotlib
    pplt.use_style(
        figsize=1.2 * scale,
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )

    # load files
    states = np.unique(state_traj)
    clusters = load_clusters(cluster_file)

    contact_idxs = np.hstack(clusters)
    n_idxs = len(contact_idxs)
    n_frames = len(contacts)

    # left edge of every contact cluster on the x-axis
    xtickpos = np.cumsum([0, *[len(clust) for clust in clusters[:-1]]]) - 0.5
    nrows, ncols = grid
    hspace, wspace = 0, 0
    # ignore outliers
    ylims = 0, np.quantile(contacts, 0.999)

    # the shades are the same for every panel
    c1, c2, c3 = pplt.categorical_color(3, "C0")
    bands = (
        (c3, ("whislo", "whishi"), r"$Q_{1/3} \pm 1.5\mathrm{IQR}$"),
        (c2, ("q1", "q3"), r"$\mathrm{IQR} = Q_3 - Q_1$"),
    )
    stair_edges = np.arange(n_idxs + 1) - 0.5

    chunks = mh.plot._ck_test._split_array(states, nrows * ncols)
    for counter, chunk in enumerate(chunks):
        fig, axs = plt.subplots(
            int(np.ceil(len(chunk) / ncols)),
            ncols,
            sharex=True,
            sharey=True,
            squeeze=False,
            gridspec_kw={"wspace": wspace, "hspace": hspace},
        )

        for state, ax in zip(chunk, axs.flatten()):
            contacts_state = contacts[state_traj == state]
            pop_state = len(contacts_state) / n_frames

            stats = {
                idx: boxplot_stats(contacts_state[:, idx])[0] for idx in contact_idxs
            }

            for color, (key_low, key_high), label in bands:
                ymax = [stats[idx][key_high] for idx in contact_idxs]
                ymin = [stats[idx][key_low] for idx in contact_idxs]
                ax.stairs(
                    ymax,
                    stair_edges,
                    baseline=ymin,
                    color=color,
                    lw=0,
                    fill=True,
                    label=label,
                )

            ax.hlines(
                [stats[idx]["med"] for idx in contact_idxs],
                xmin=np.arange(n_idxs) - 0.5,
                xmax=np.arange(n_idxs) + 0.5,
                label="median",
                color=c1,
            )

            pplt.text(
                0.5,
                0.95,
                rf"S{state + 1} {pop_state:.1%}",
                ha="center",
                va="top",
                ax=ax,
                transform=ax.transAxes,
                contour=True,
            )

            ax.set_xlim([-0.5, n_idxs - 0.5])
            ax.set_ylim(*ylims)
            ax.set_xticks(xtickpos)
            ax.set_xticklabels(np.arange(len(xtickpos)) + 1)

            ax.grid(False)
            for pos in xtickpos:
                ax.axvline(pos, color="pplt:grid", lw=1.0)

        pplt.hide_empty_axes()
        pplt.legend(
            ax=axs[0, 0],
            outside="top",
            bbox_to_anchor=(
                0,
                1.0,
                axs.shape[1] + wspace * (axs.shape[1] - 1),
                0.01,
            ),
            frameon=False,
            ncol=2,
        )
        pplt.subplot_labels(
            xlabel="contact clusters",
            ylabel="distances [nm]",
        )

        # Without an output file the figure is only shown. Nothing can be
        # saved then, and deriving a file name from None would fail.
        if output is None:
            plt.show()
            plt.close()
            continue

        # save figure and continue
        if counter == 0:
            pplt.savefig(output)
        else:
            # insert state_str between pathname and extension
            path, ext = splitext(output)
            pplt.savefig(f"{path}.state{chunk[0]:.0f}-{chunk[-1]:.0f}{ext}")
        plt.close()
796### SANKEY ###################################################################
def plot_sankey(cl, ref, out, ax=None, scale=1):
    """
    Draw a sankey diagram between a clustering and the reference.

    The left side shows the macrostates of clustering ``cl.n_i`` of `cl`,
    the right side those of the first clustering of `ref`; both sides are
    weighted with ``ref.pop``. Colors are taken from the macrostates of
    `cl` sorted by descending feature and keyed by the labels "1", "2", ...

    cl: clustering object; the code uses ``tree``, ``n_i``,
        ``macrostates_map`` and ``n_macrostates``.
    ref: reference object; the code uses ``macrostates_map``,
        ``n_macrostates`` and ``pop``.
    out (str): file to save the plot; only used if `ax` is None.
    ax: axes to draw into. If None, the style is set up here and the
        figure is saved to `out` and closed afterwards.
    scale (float): multiplier for the base figure size.
    """
    standalone = ax is None
    selected = cl.n_i

    macrostates = cl.tree[selected].macrostates
    # macrostate indices sorted by descending feature
    by_feature = np.argsort([state.feature for state in macrostates])[::-1]
    colorDict = {
        str(rank + 1): macrostates[index].color
        for rank, index in enumerate(by_feature)
    }

    if standalone:
        pplt.use_style(figsize=(1.7 * scale, 3.6 * scale), true_black=True)

    left_labels = np.arange(1, cl.n_macrostates[selected] + 1).astype(str).tolist()
    right_labels = np.arange(1, ref.n_macrostates[0] + 1).astype(str).tolist()
    sankey(
        left=(cl.macrostates_map[selected] + 1).astype(str),
        right=(ref.macrostates_map[0] + 1).astype(str),
        leftWeight=ref.pop,
        rightWeight=ref.pop,
        leftLabels=left_labels,
        rightLabels=right_labels,
        colorDict=colorDict,
        ax=ax,
    )

    if standalone:
        pplt.savefig(out)
        plt.close()
824### RMSD LINES ###############################################################
def plot_rmsd(rmsds, pops, helices=None, filename=None):
    """
    Plot one log-scaled line profile per row of `rmsds`, stacked vertically.

    Next to every profile two horizontal bars show the (normalised)
    population of the row and the sum of its values without the first two
    and last two columns. With `helices`, an extra bottom row sketches the
    given segments along the x-axis.

    Parameters:
    - rmsds (np.ndarray): 2D array, one row per macrostate and one column per
      residue. Values must be positive for logarithmic scaling.
    - pops (np.ndarray): One population per row of `rmsds`; normalised to
      sum 1 before plotting.
    - helices (np.ndarray): Pairs of start and end points for blocks to be
      indicated in the bottom row. A positive start gives a filled block
      (helix); otherwise the negated pair is drawn as outlined block (sheet).
    - filename (str, optional): If provided, saves the figure to this file,
      otherwise the figure is shown.

    Raises:
    - ValueError: If `rmsds` holds non-positive values or the length of
      `pops` does not match the number of rows.
    """
    # Ensure all values are positive for logarithmic scaling
    if np.any(rmsds <= 0):
        raise ValueError(
            "All values in `rmsds` must be positive for logarithmic scaling."
        )

    if rmsds.shape[0] != len(pops):
        raise ValueError("Length of `pops` must match the number of rows in `rmsds`.")

    n_states, n_residues = rmsds.shape
    n_plots = n_states + 1 if helices is not None else n_states

    w = 0.08 * n_residues + 3
    h = 1 + 0.4 * n_plots
    pplt.use_style(
        figsize=(w, h),
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )
    # squeeze=False keeps `axs` two-dimensional even for a single row, so
    # the row-wise unpacking and the `axs[-1, k]` lookups below always work.
    fig, axs = plt.subplots(
        n_plots,
        3,
        sharex="col",
        squeeze=False,
        width_ratios=[n_residues, 8, 8],
        gridspec_kw={"wspace": 0, "hspace": 0},
    )

    ylim = 0.5 * rmsds.min(), 2 * rmsds.max()
    pops = pops / pops.sum()
    ylim_hist = 0, 1.05 * pops.max()

    # sums leave out the first two and last two columns
    rmsd_sums = rmsds[:, 2:-2].sum(axis=1)
    ylim_rmsd = 0, 1.05 * rmsd_sums.max()

    residues = np.arange(n_residues) + 1
    # the last row is reserved for the secondary structure sketch
    data_rows = axs[:-1] if helices is not None else axs

    for i, ((ax, hist_ax, rmsd_ax), rmsd, pop) in enumerate(
        zip(data_rows, rmsds, pops)
    ):
        # population bar
        hist_ax.add_patch(patches.Rectangle((0, 0.3), pop, 0.4))
        hist_ax.set_xlim(ylim_hist)
        hist_ax.set_yticks([], [])
        hist_ax.grid(False)

        # bar with the summed values of this row
        rmsd_ax.add_patch(patches.Rectangle((0, 0.3), rmsd[2:-2].sum(), 0.4))
        rmsd_ax.set_xlim(ylim_rmsd)
        rmsd_ax.set_yticks([], [])
        rmsd_ax.grid(False)

        ax.plot(residues, rmsd)
        # fill from the lower axis limit up to the profile
        ax.fill_between(
            residues,
            [ylim[0]] * rmsd.shape[0],
            rmsd,
            alpha=0.5,
        )

        ax.set_yscale("log")
        ax.set_ylabel(f"{i + 1}")
        ax.set_xlim((0.5, rmsd.shape[0] + 0.5))
        ax.set_ylim(ylim)
        ax.grid(True)

    if helices is not None:
        line_start = 1
        helices_ax = axs[-1, 0]
        for start, end in helices:
            if start > 0:
                # Helices: filled block
                start -= 0.3
                end += 0.3
                rect = patches.Rectangle(
                    (start, 0.3),
                    end - start,
                    0.4,
                    fc="#264653",
                    ec="#264653",
                    lw=2,
                )
            else:
                # Sheets: passed with negated positions, outlined block
                start, end = -start, -end
                start -= 0.5
                end += 0.5
                rect = patches.Rectangle(
                    (start, 0.3),
                    end - start,
                    0.4,
                    fc="white",
                    ec="#264653",
                    lw=2,
                )
            # connecting line from the previous block to this one
            helices_ax.plot(
                [line_start, start],
                [0.5, 0.5],
                solid_capstyle="butt",
                c="#264653",
                lw=2,
            )
            line_start = end
            helices_ax.add_patch(rect)
        # line from the last block to the end of the sequence
        helices_ax.plot(
            [line_start, n_residues],
            [0.5, 0.5],
            solid_capstyle="butt",
            c="#264653",
            lw=2,
        )

        helices_ax.set_ylim((0, 1))
        helices_ax.set_ylabel("H")
        helices_ax.set_yticks([], [])
        helices_ax.grid(False)

        axs[-1, 1].grid(False)
        axs[-1, 1].set_yticks([], [])
        axs[-1, 2].grid(False)
        axs[-1, 2].set_yticks([], [])

    # blank the first tick label of both bar columns
    for bar_ax in (axs[-1, 1], axs[-1, 2]):
        ticks = bar_ax.get_xticks()
        tick_labels = bar_ax.get_xticklabels()
        tick_labels[0] = ""
        bar_ax.set_xticks(ticks, tick_labels)

    axs[-1, 0].xaxis.set_major_locator(MultipleLocator(5))
    axs[-1, 0].xaxis.set_minor_locator(MultipleLocator(1))
    axs[-1, 0].set_xlabel("Residue")
    axs[-1, 1].set_xlabel("Population", rotation=20)
    axs[-1, 2].set_xlabel(r"$\sum$ RMSD / nm", rotation=20)
    fig.supylabel("Macrostate; RMSD Variance / nm")

    # Save to file if filename is provided
    plt.tight_layout()
    if filename:
        plt.savefig(filename, dpi=192)
    else:
        plt.show()
    plt.close()
996def plot_delta_rmsd(rmsds, pops, helices=None, filename=None):
997 """
998 Plots a 2D NumPy array as a heatmap with a logarithmic color scale and variable row heights.
1000 Parameters:
1001 - vars (np.ndarray): The 2D NumPy array to plot. Values must be positive for logarithmic scaling.
1002 - row_heights (np.ndarray): 1D array defining the height of each row.
1003 - helices (np.ndarray): Array with start and end points for blocks to be indicated in the bottom row.
1004 - filename (str, optional): If provided, saves the heatmap to this file.
1005 """
1006 # Ensure all values are positive for logarithmic scaling
1007 if np.any(rmsds <= 0): 1007 ↛ 1008line 1007 didn't jump to line 1008 because the condition on line 1007 was never true
1008 raise ValueError(
1009 "All values in `rmsds` must be positive for logarithmic scaling."
1010 )
1012 if rmsds.shape[0] != len(pops): 1012 ↛ 1013line 1012 didn't jump to line 1013 because the condition on line 1012 was never true
1013 raise ValueError("Length of `pops` must match the number of rows in `rmsds`.")
1015 if helices is not None: 1015 ↛ 1018line 1015 didn't jump to line 1018 because the condition on line 1015 was always true
1016 n_plots = rmsds.shape[0] + 1
1017 else:
1018 n_plots = rmsds.shape[0]
1020 w = 0.08 * rmsds.shape[1] + 3 # 8.6
1021 h = 1 + 0.4 * n_plots # 6
1022 pplt.use_style(
1023 figsize=(w, h),
1024 colors="pastel_autumn",
1025 true_black=True,
1026 latex=False,
1027 )
1028 fig, axs = plt.subplots(
1029 n_plots,
1030 3,
1031 sharex="col",
1032 width_ratios=[rmsds.shape[1], 8, 8],
1033 gridspec_kw={"wspace": 0, "hspace": 0},
1034 )
1036 delta_rmsd = rmsds
1037 # delta_rmsd[1:] = rmsds[1:] - rmsds[:-1]
1038 delta_rmsd[1:] = rmsds[1:] - rmsds[0]
1039 rmsds = delta_rmsd
1041 rmsd_max = rmsds.max()
1042 rmsd_min = rmsds.min()
1043 rmsd_delta = rmsd_max - rmsd_min
1045 ylim = rmsd_min - rmsd_delta * 0.1, rmsd_max + rmsd_delta * 0.15
1046 pops = pops / pops.sum()
1047 ylim_hist = 0, 1.05 * pops.max()
1049 rmsd_sums = rmsds.sum(axis=1)
1050 ylim_rmsd = 0, 1.05 * rmsd_sums.max()
1052 for i, ((ax, hist_ax, rmsd_ax), rmsd, pop) in enumerate(
1053 zip(axs[:-1] if helices is not None else axs, rmsds, pops)
1054 ):
1055 rect = patches.Rectangle(
1056 (0, 0.3), # Position of the block
1057 pop,
1058 0.4, # color='black'
1059 )
1060 hist_ax.add_patch(rect)
1061 hist_ax.set_xlim(ylim_hist)
1062 hist_ax.set_yticks([], [])
1063 hist_ax.grid(False)
1065 rect = patches.Rectangle(
1066 (0, 0.3), # Position of the block
1067 abs(rmsd).sum(),
1068 0.4, # color='black'
1069 )
1070 rmsd_ax.add_patch(rect)
1071 rmsd_ax.set_xlim(ylim_rmsd)
1072 rmsd_ax.set_yticks([], [])
1073 rmsd_ax.grid(False)
1075 ax.plot(np.arange(rmsd.shape[0]) + 1, rmsd)
1076 ax.fill_between(
1077 np.arange(rmsd.shape[0]) + 1,
1078 # [ylim[0]]*rmsd.shape[0],
1079 [0] * rmsd.shape[0],
1080 rmsd,
1081 alpha=0.5,
1082 # facecolor="none",
1083 # hatch="/",
1084 )
1086 # ax.set_yscale("log")
1087 ax.set_ylabel(f"{i + 1}")
1088 # ax.set_ylabel(f"{i+1}-{i}")
1089 # ax.set_ylabel(f"{i+1}-{i}", rotation=90, position=(-0.2, 0))
1090 ax.set_xlim((0.5, rmsd.shape[0] + 0.5))
1091 ax.set_ylim(ylim)
1092 ax.grid(True)
1094 if helices is not None: 1094 ↛ 1151line 1094 didn't jump to line 1151 because the condition on line 1094 was always true
1095 line_start = 1
1096 helices_ax = axs[-1, 0]
1097 # helices_ax.plot([1, rmsds.shape[1]], [0.5, 0.5]) #, c="k")
1098 for start, end in helices:
1099 if start > 0:
1100 # Helices
1101 start -= 0.3
1102 end += 0.3
1103 rect = patches.Rectangle(
1104 (start, 0.3), # Position of the block
1105 end - start,
1106 0.4, # color='black'
1107 fc="#264653",
1108 ec="#264653",
1109 lw=2,
1110 )
1111 else:
1112 # Sheets
1113 start, end = -start, -end
1114 start -= 0.5
1115 end += 0.5
1116 rect = patches.Rectangle(
1117 (start, 0.3), # Position of the block
1118 end - start,
1119 0.4, # color='black'
1120 fc="white",
1121 ec="#264653",
1122 lw=2,
1123 )
1124 helices_ax.plot(
1125 [line_start, start],
1126 [0.5, 0.5],
1127 solid_capstyle="butt",
1128 c="#264653",
1129 lw=2,
1130 )
1131 line_start = end
1132 helices_ax.add_patch(rect)
1133 helices_ax.plot(
1134 [line_start, rmsds.shape[1]],
1135 [0.5, 0.5],
1136 solid_capstyle="butt",
1137 c="#264653",
1138 lw=2,
1139 )
1141 helices_ax.set_ylim((0, 1))
1142 helices_ax.set_ylabel("H")
1143 helices_ax.set_yticks([], [])
1144 helices_ax.grid(False)
1146 axs[-1, 1].grid(False)
1147 axs[-1, 1].set_yticks([], [])
1148 axs[-1, 2].grid(False)
1149 axs[-1, 2].set_yticks([], [])
1151 axs[0, 0].set_ylim(0.5 * rmsds[0].min(), 2 * rmsds[0].max())
1152 axs[0, 0].set_yscale("log")
1153 # axs[0, 0].set_ylabel("1") #, rotation=90)
1155 hist_ticks = axs[-1, 1].get_xticks()
1156 hist_labels = axs[-1, 1].get_xticklabels()
1157 hist_labels[0] = ""
1158 axs[-1, 1].set_xticks(hist_ticks, hist_labels)
1160 rmsd_ticks = axs[-1, 2].get_xticks()
1161 rmsd_labels = axs[-1, 2].get_xticklabels()
1162 rmsd_labels[0] = ""
1163 axs[-1, 2].set_xticks(rmsd_ticks, rmsd_labels)
1165 axs[-1, 0].xaxis.set_major_locator(MultipleLocator(5))
1166 axs[-1, 0].xaxis.set_minor_locator(MultipleLocator(1))
1167 axs[-1, 0].set_xlabel("Residue")
1168 axs[-1, 1].set_xlabel("Population", rotation=20)
1169 axs[-1, 2].set_xlabel(r"$\sum$ |$\Delta$RMSD| / nm", rotation=20)
1170 fig.supylabel(r"$\Delta$Macrostate; $\Delta$RMSD Variance / nm wrt. 1st State")
1172 # Save to file if filename is provided
1173 plt.tight_layout()
1174 if filename: 1174 ↛ 1177line 1174 didn't jump to line 1177 because the condition on line 1174 was always true
1175 plt.savefig(filename, dpi=192) # , bbox_inches="tight"
1176 else:
1177 plt.show()
1178 plt.close()
1181### TRAJECTORY ###############################################################
def plot_state_trajectory(trajectory, filename, row_length=0.2, frame_length=0.2):
    """
    Plot state trajectory

    The trajectory is drawn as horizontal line segments, one per entry
    returned by ``utils.find_state_lengths`` (presumably one entry per run of
    consecutive identical states -- confirm against that helper). The time
    axis is wrapped over several stacked rows; a segment that crosses the end
    of a row is continued at the start of the next one.

    trajectory (np.ndarray): state trajectory
    filename (str): file name to save the plot to
    row_length (int|float):
        row_length > 1: number of frames in each row
        0 < row_length <= 1: fraction of total frames per row (1/n_rows)
    frame_length (float): frame length in ns

    Raises:
        ValueError: if ``row_length`` is not greater than zero.
    """
    if row_length > 1:
        frames_per_row = int(row_length)
    elif row_length > 0:
        frames_per_row = int(np.ceil(trajectory.shape[0] * row_length))
    else:
        raise ValueError("row_length must be > 0")

    # frame_length is given in ns, the x-axis is labeled in microseconds.
    frame_length = frame_length / 1000.0

    # Calculate states and their lengths
    unique_states, lengths = utils.find_state_lengths(trajectory)
    # Shift the state labels by one for plotting. A new array is created on
    # purpose so the helper's return value is never modified in place.
    unique_states = unique_states + 1
    lengths = lengths * frame_length
    n_rows = int(np.ceil(trajectory.shape[0] / frames_per_row))

    # Width of one row in time units.
    x_max = frames_per_row * frame_length

    figsize = (11.7, 8.3)

    # One color bin per distinct state label; bin edges sit at label +/- 0.5.
    u_states = np.unique(unique_states)
    cmap = mpl.cm.turbo
    norm = mpl.colors.BoundaryNorm(np.concatenate([[0], u_states]) + 0.5, cmap.N)

    fig, axs = plt.subplots(
        n_rows,
        1,
        sharex=True,
        figsize=figsize,
        gridspec_kw={"wspace": 0, "hspace": 0},
        squeeze=False,  # always 2D, also for a single row
        layout="compressed",
    )
    axs = axs[:, 0]
    axi = 0  # index of the row currently being filled

    # Plot each state occurrence as a line segment
    x_start = 0  # Initial x-coordinate for the first segment
    for state, length in zip(unique_states, lengths):
        x_end = x_start + length  # End position of this segment on the x-axis
        color = cmap(norm(state))

        # The segment does not fit into the current row: draw the part that
        # fits, then carry the remainder over to the next row.
        while x_end > x_max:
            axs[axi].plot(
                [x_start, x_max],
                [state, state],
                color=color[:3],
                linewidth=3,
                solid_capstyle="butt",
            )
            x_end -= x_max
            x_start = 0
            axi += 1
        # Skip zero-length leftovers (e.g. a segment ending exactly at a row end).
        if not np.isclose(x_start, x_end):
            axs[axi].plot(
                [x_start, x_end],
                [state, state],
                color=color,
                linewidth=3,
                solid_capstyle="butt",
            )

        # Move x_start to the end of the current segment for the next one
        x_start = x_end

    # Colorbar with one tick per state label
    cbar = fig.colorbar(
        ScalarMappable(norm=norm, cmap=cmap),
        ax=axs,
        orientation="vertical",
        label="Macrostate",
    )
    cbar.set_ticks(ticks=u_states, labels=u_states, minor=False)
    fig.axes[-1].tick_params(length=0)

    # Label axes
    fig.supylabel("State Index")
    axs[-1].set_xlabel(r"t / $\mu$s")

    for ax in axs:
        ax.set_ylim(unique_states.min() - 1, unique_states.max())

    # Set axis limits
    plt.xlim(0, x_max)

    # Save the plot to the specified file
    plt.savefig(filename)
    plt.close()  # Close the plot to free memory
1288### CHAPMAN-KOLMOGOROV TEST ##################################################
def chapman_kolmogorov(mpt, out, frame_length=0.2):
    """Run the Chapman-Kolmogorov test for the selected lumping and plot it.

    The test is computed with ``msmhelper`` on the macrostate trajectory
    ``mpt.macrotraj[mpt.n_i]`` (split via ``utils.get_multi_state_traj`` and
    ``mpt.limits``) and plotted with ``plot_ck_test``. Every figure is
    written to ``out``.

    mpt: object providing ``macrotraj``, ``n_i``, ``limits`` and
        ``n_macrostates``
    out (str): file name to save the plot to
    frame_length (float): frame length in ns
    """
    multi_traj = utils.get_multi_state_traj(mpt.macrotraj[mpt.n_i], mpt.limits)
    ck = mh.msm.tests.chapman_kolmogorov_test(
        multi_traj,
        [50, 50, 50, 50, 50],
        4000,
        # int(1550*frame_length),
    )

    style_kwargs = {
        "figsize": 4.8,
        "colors": "pastel_autumn",
        "true_black": True,
        "latex": False,
    }
    pplt.use_style(**style_kwargs)

    n_states = mpt.n_macrostates[mpt.n_i]
    nrows, ncols = utils.get_grid_format(n_states)
    states_per_figure = nrows * ncols
    state_chunks = mh.plot._ck_test._split_array(
        np.arange(n_states), states_per_figure
    )

    for chunk in state_chunks:
        # NOTE(review): the grid is handed over as (ncols, nrows) while
        # plot_ck_test unpacks it as (n_rows, n_cols) -- confirm that this
        # transposition is intended.
        fig = plot_ck_test(
            ck=ck,
            states=chunk,
            frames_per_unit=1 / frame_length,
            unit="ns",
            grid=(ncols, nrows),
        )

        # Move the state labels of every panel to the lower left.
        panel_texts = [text for ax in fig.axes for text in ax.texts]
        for text in panel_texts:
            text.set_position((0.15, 0.2))

        plt.savefig(out)
        plt.close()
def plot_ck_test(
    ck,
    states=None,
    frames_per_unit=1,
    unit="frames",
    grid=(3, 3),
):
    """Plot CK-Test results.

    This routine is a basic helper function to visualize the results of
    [msmhelper.msm.chapman_kolmogorov_test][].

    Parameters
    ----------
    ck : dict
        Dictionary holding for each lagtime the CK equation and with 'md' the
        reference.
    states : ndarray, optional
        List containing all states to plot the CK-test. Defaults to all
        states found in ``ck["md"]["ck"]``.
    frames_per_unit : float, optional
        Number of frames per given unit. This is used to scale the axis
        accordingly.
    unit : ['frames', 'fs', 'ps', 'ns', 'us'], optional
        Unit to use for label.
    grid : (int, int), optional
        The number of `(n_rows, n_cols)` to use for the grid layout. Only as
        many rows as needed for `states` are created; `n_rows` itself is only
        used for the legend placement and the y-label text.

    Returns
    -------
    fig : matplotlib.Figure
        Figure holding plots.

    Notes
    -----
    Adapted from msmhelper.

    """
    # load colors
    pplt.load_cmaps()
    pplt.load_colors()

    lagtimes = np.array([key for key in ck if key != "md"])
    if states is None:
        states = np.array(list(ck["md"]["ck"].keys()))

    nrows, ncols = grid
    needed_rows = int(np.ceil(len(states) / ncols))

    # squeeze=False always yields an array of shape (needed_rows, ncols).
    # Squeezing and re-expanding with np.atleast_2d would turn a single
    # column with several rows into shape (1, needed_rows) and break the
    # axs[irow, icol] indexing below.
    fig, axs = plt.subplots(
        needed_rows,
        ncols,
        sharex=True,
        sharey="row",
        gridspec_kw={"wspace": 0, "hspace": 0},
        squeeze=False,
    )

    max_time = np.max(ck["md"]["time"])
    for irow, states_row in enumerate(mh.plot._ck_test._split_array(states, ncols)):
        for icol, state in enumerate(states_row):
            ax = axs[irow, icol]

            # reference curve
            pplt.plot(
                ck["md"]["time"] / frames_per_unit,
                ck["md"]["ck"][state],
                "--",
                ax=ax,
                color="pplt:gray",
                label="MD",
            )
            # one curve per lag time
            for lagtime in lagtimes:
                pplt.plot(
                    ck[lagtime]["time"] / frames_per_unit,
                    ck[lagtime]["ck"][state],
                    ax=ax,
                    label=lagtime / frames_per_unit,
                )
            # panel label; the displayed state number is shifted by one
            pplt.text(
                0.5,
                0.9,
                f"S{state + 1}",
                contour=True,
                va="top",
                transform=ax.transAxes,
                ax=ax,
            )

            # set scale
            ax.set_xscale("log")
            ax.set_xlim(
                [
                    lagtimes[0] / frames_per_unit,
                    max_time / frames_per_unit,
                ]
            )
            ax.set_ylim([0, 1])
            # Rows touch each other (hspace=0): drop the 0 tick on all but
            # the last row.
            if irow < len(axs) - 1:
                ax.set_yticks([0.5, 1])
            else:
                ax.set_yticks([0, 0.5, 1])

            ax.grid(True, which="major", linestyle="--")
            ax.grid(True, which="minor", linestyle="dotted")
            ax.set_axisbelow(True)

    # set legend
    legend_kw = (
        {
            "outside": "right",
            "bbox_to_anchor": (2.0, (1 - nrows), 0.2, nrows),
        }
        if ncols in {1, 2}
        else {
            "outside": "top",
            "bbox_to_anchor": (0.0, 1.0, ncols, 0.01),
        }
    )
    if ncols == 3:
        legend_kw["ncol"] = 3
    pplt.legend(
        ax=axs[0, 0],
        **legend_kw,
        title=rf"$\tau_\mathrm{{lag}}$ [{unit}]",
        frameon=False,
    )

    ylabel = (
        (r"self-transition probability $P_{i\to i}$")
        if nrows >= 3
        else (r"$P_{i\to i}$")
    )

    pplt.hide_empty_axes()
    pplt.label_outer()
    pplt.subplot_labels(
        ylabel=ylabel,
        xlabel=rf"time $t$ [{unit}]",
    )
    return fig
1466### MACROSTATE GRAPH #########################################################
def state_network(lumping, out):
    """Draw the macrostate network of the selected lumping.

    Forwards the macrostate trajectory ``lumping.macrotraj[lumping.n_i]``,
    ``lumping.tlag``, ``lumping.feature_traj`` and ``out`` to
    ``draw_knetwork``.
    """
    macrotraj = lumping.macrotraj[lumping.n_i]
    lag = lumping.tlag
    features = lumping.feature_traj
    draw_knetwork(macrotraj, lag, features, out)