Coverage for MPP/plot.py: 90%
641 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-14 16:23 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-14 16:23 +0200
1#!/usr/bin/env python3
2"""
3plot.py
4==================
6Various plot functions used in this package.
7"""
9from os.path import splitext
11import numpy as np
12import prettypyplot as pplt
13import matplotlib as mpl
14from matplotlib import pyplot as plt
15from matplotlib.cm import ScalarMappable
16from matplotlib.colors import (
17 LinearSegmentedColormap,
18 LogNorm,
19 ListedColormap,
20)
21from matplotlib import colors
22from matplotlib.cbook import boxplot_stats
23import matplotlib.patches as patches
24from matplotlib.ticker import MultipleLocator
25import msmhelper as mh
26from msmhelper._cli.contact_rep import load_clusters
28from . import utils
29from .sankey_gap import sankey
30from .graph import draw_knetwork
32plt.rcParams["font.family"] = "sans-serif"
34### DENDROGRAM ###############################################################
def plot_tree(root, macrostate_assignment, output_file, scale=1, offset=0):
    """
    Plot the dendrogram from a given state tree of BinaryTreeNode.

    The upper panel is drawn by ``root.plot_tree``; the lower panel marks,
    one row per macrostate, which microstates belong to that macrostate.

    root: root node of the tree. Must provide ``leaves`` (each leaf having a
        ``name`` used as column index), ``plot_tree(ax)`` and ``feature_norm``.
    macrostate_assignment (np.ndarray): one row per macrostate, one column
        per microstate. Entries equal to 1 mark membership.
    output_file: target handed to ``pplt.savefig``.
    scale: factor applied to the base figure size of 3.2.
    offset: lower y-limit of the dendrogram panel.
    """
    n_states = len(root.leaves)

    # setup matplotlib
    pplt.use_style(figsize=3.2 * scale, figratio="golden", true_black=True)
    plt.rcParams["font.family"] = "sans-serif"

    layout = {"hspace": 0.05, "height_ratios": [9, 1]}
    fig, (ax, ax_mat) = plt.subplots(2, 1, gridspec_kw=layout)
    for spine in ax_mat.spines.values():
        spine.set_visible(False)

    ax = root.plot_tree(ax)

    ax.set_ylabel(r"Metastability $T_{ii}$")
    ax.set_xlabel("microstates")
    ax.set_xlim(-0.005 * n_states, 1.005 * n_states)
    ax.set_ylim(offset, 1.05)

    # colorbar above the dendrogram; pplt.colorbar acts on the current axes
    tree_cmap = plt.get_cmap("plasma_r", 10)
    feature_mappable = ScalarMappable(root.feature_norm, tree_cmap)
    plt.sca(ax)
    pplt.colorbar(
        feature_mappable,
        width="5%",
        label=r"Fraction of Contacts $q$",
        position="top",
    )

    # bring microstates in the right order
    leaf_order = [leaf.name for leaf in root.leaves]
    macrostate_assignment = macrostate_assignment[:, leaf_order]

    yticks = np.arange(0.5, 1.5 + macrostate_assignment.shape[0])
    xticks = np.arange(0, n_states + 1)
    # transparent -> opaque black, so only assigned cells are visible
    mask_cmap = LinearSegmentedColormap.from_list(
        "binary",
        [(0, 0, 0, 0), (0, 0, 0, 1)],
    )

    # number every macrostate below the median position of its microstates
    xvals = 0.5 * (xticks[:-1] + xticks[1:])
    for idx, assignment in enumerate(macrostate_assignment):
        pplt.text(
            np.median(xvals[assignment == 1]),
            yticks[idx] - (yticks[1] - yticks[0]),
            f"{idx + 1:.0f}",
            ax=ax_mat,
            va="top",
            contour=True,
            size="small",
        )

    # Plot macrostate assignments
    ax_mat.pcolormesh(
        xticks,
        yticks,
        macrostate_assignment,
        snap=True,
        cmap=mask_cmap,
        vmin=0,
        vmax=1,
    )

    # horizontal grid lines separate the macrostate rows
    ax_mat.set_yticks(yticks)
    ax_mat.set_yticklabels([])
    ax_mat.grid(visible=True, axis="y", ls="-", lw=0.5)
    ax_mat.tick_params(axis="y", length=0, width=0)
    ax_mat.set_xlim(ax.get_xlim())
    ax.set_xlabel("")
    ax_mat.set_xlabel("Macrostates")
    ax_mat.set_ylabel("")
    fig.align_ylabels([ax, ax_mat])

    ax_mat.set_xticks(np.arange(0.5, 0.5 + n_states))

    # Hide microstate labels
    for panel in (ax, ax_mat):
        panel.set_xticks([])
        panel.set_xticks([], minor=True)
        panel.set_xticklabels([])
        panel.set_xticklabels([], minor=True)

    pplt.savefig(output_file)
    plt.close()
131### SIMILARITY ###############################################################
def stochastic_state_similarity(mpt1, mpt2, out):
    """
    Plot similarity values for a reference and a stochastic clustering.

    One histogram panel is drawn per macrostate. Each panel overlays three
    distributions, labelled "union" (green), "reference" (blue) and
    "clustering" (red) in the figure legend.

    mpt1, mpt2: objects whose sum ``mpt1 + mpt2`` yields a 3-tuple
        ``(reference, stochastic, S)``. ``S`` must unpack into three
        per-state collections and ``S.shape[1]`` is the number of states.
    out: target handed to ``plt.savefig``.
    """
    _, sto, similarity = mpt1 + mpt2
    s1, s2, s3 = similarity
    n_states = similarity.shape[1]
    x, y = utils.get_grid_format(n_states)
    # squeeze=False keeps ``axs`` a 2D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes without ``flatten``.
    fig, axs = plt.subplots(y, x, figsize=(2 * x, 2 * y), squeeze=False)
    for state, ax in enumerate(axs.flatten()[:n_states]):
        # Set left limit to minimum instead of 0
        lower = min(min(s1[state]), min(s2[state]), min(s3[state])) - 0.02
        bins = np.linspace(lower, 1, 21)
        for values, color in (
            (s1[state], "g"),
            (s2[state], "b"),
            (s3[state], "r"),
        ):
            ax.hist(values, bins=bins, color=color, alpha=0.7)
        ax.set_title(f"state {state + 1}")
    fig.supxlabel("Macrostate similarity")
    fig.supylabel(f"Count of clusterings ({sto.n_runs} clusterings)")
    plt.figlegend(
        ["union", "reference", "clustering"],
        ncols=3,
        loc="lower center",
        bbox_to_anchor=(0.5, 0.05),
    )
    # leave a strip at the bottom for the figure legend
    plt.tight_layout(rect=(0, 0.04, 1, 1))
    plt.savefig(out)
    plt.close()
165### IMPLIED TIMESCALES #######################################################
def implied_timescales(
    trajectorys,
    lagtimes,
    out,
    titles="",
    frame_length=0.2,
    first_ref=False,
    scale=1,
    use_ref=True,
    ntimescales=3,
):
    """
    Plot implied timescales over lag time, one panel per trajectory.

    trajectorys: sequence of state trajectories. If ``first_ref`` is set,
        the first entry is used as reference and drawn into every panel.
        The sequence itself is not modified.
    lagtimes: lag times in frames (array-like).
    out: target handed to ``plt.savefig``.
    titles: one title per panel; a single string is used for all panels.
    frame_length: frame length in ns / frame, used to convert to ns.
    first_ref: treat the first trajectory as reference.
    scale: factor applied to the base figure size.
    use_ref: selects the reference line style/label, ":" ("ref") if true,
        "--" ("mic") otherwise.
    ntimescales: number of implied timescales to compute and plot.
    """
    lagtimes = np.asarray(lagtimes)
    if first_ref:
        # unpack into new objects instead of pop(0) so that the caller's
        # list is left untouched
        ref_trajectory, *trajectorys = trajectorys
    x, y = utils.get_grid_format(len(trajectorys))
    pplt.use_style(
        figsize=(3.8 * scale, 3.2 * scale), latex=False, colors="pastel_autumn"
    )
    fig, axs = plt.subplots(y, x, sharex=True, sharey=True)
    plt.grid(False)
    if not isinstance(axs, np.ndarray):
        axs = np.array([axs])

    if isinstance(titles, str):
        # zip() below would otherwise iterate over single characters
        titles = [titles] * len(trajectorys)

    min_it = None
    max_it = None

    if first_ref:
        it_ref = mh.msm.implied_timescales(
            ref_trajectory, lagtimes, ntimescales=ntimescales
        )
        # change from frames to ns
        it_ref = it_ref * frame_length
        min_it = it_ref.min()
        max_it = it_ref.max()

    ref_ls = ":" if use_ref else "--"
    lagtime = lagtimes[-1] / 4.5 * frame_length
    lagtimes_ns = lagtimes * frame_length
    for ax, traj, title in zip(axs.flatten(), trajectorys, titles):
        ax.axvline(lagtime, color="pplt:grid")
        it = mh.msm.implied_timescales(traj, lagtimes, ntimescales=ntimescales)
        # change from frames to ns
        it = it * frame_length
        min_it = it.min() if min_it is None else min(it.min(), min_it)
        max_it = it.max() if max_it is None else max(it.max(), max_it)

        if first_ref:
            _plot_impl_times(it_ref, lagtimes_ns, ax, ls=ref_ls)
        _plot_impl_times(it, lagtimes_ns, ax)
        ax.set_yscale("log")
        ax.set_title(title)

    for ax in axs.flatten():
        ax.set_ylim(min(min_it * 0.9, int(lagtimes_ns.shape[0] / 4)), max_it * 1.5)

    # axis labels only on the outer panels
    if axs.ndim == 2:
        for ax in axs[-1]:
            ax.set_xlabel(r"lag time $\tau$ / ns")
        for row in axs:
            for ax in row[1:]:
                plt.setp(ax.get_yticklabels(), visible=False)
        for ax in axs[:, 0]:
            ax.set_ylabel("time scale / ns")
    elif axs.ndim == 1:
        axs[0].set_ylabel("time scale / ns")
        for ax in axs:
            ax.set_xlabel(r"lag time $\tau$ / ns")
        for ax in axs[1:]:
            plt.setp(ax.get_yticklabels(), visible=False)

    # Every populated panel carries the same legend entries. Read them from
    # the first panel: the current axes may be an unused, empty grid slot.
    handles, labels = axs.flatten()[0].get_legend_handles_labels()

    if first_ref:
        # Per panel the reference lines come first, then the model lines.
        # Interleave them (model_i, ref_i) so that the column-wise filled
        # legend shows each pair in one column. Without a reference there
        # are only ``ntimescales`` handles and nothing needs reordering.
        desired_order = [
            j for i in range(ntimescales) for j in (i + ntimescales, i)
        ]
        handles = [handles[i] for i in desired_order]
        labels = [labels[i] for i in desired_order]

    pplt.legend(
        handles=handles, labels=labels, outside="top", frameon=False, ncols=ntimescales
    )

    plt.tight_layout()
    plt.savefig(out)
    plt.close()
def _plot_impl_times(impl_times, lagtimes, ax, ls="-"):
    """
    Plot the implied timescales.

    impl_times (np.ndarray): one column per timescale, one row per lag time.
    lagtimes (np.ndarray): x values; also define the x-limits.
    ax: axes to draw on (needs ``plot``, ``set_xlim``, ``fill_between``).
    ls: line style. ":" labels the lines as "ref", "--" as "mic", any other
        style yields plain ``t_i`` labels.
    """
    palette = ("#264653", "#2A9D8F", "#E9C46A", "#f4a261", "#e76f51")
    subscript_prefix = {":": r"\mathrm{ref},", "--": r"\mathrm{mic},"}.get(ls, "")
    for idx, impl_time in enumerate(impl_times.T):
        # The whole subscript is wrapped in braces; without them mathtext
        # would only subscript the first digit of indices >= 10.
        label = f"$t_{{{subscript_prefix}{idx + 1}}}$"
        ax.plot(
            lagtimes,
            impl_time,
            label=label,
            # cycle through the palette for any number of timescales
            color=palette[idx % len(palette)],
            ls=ls,
        )

    xlim = lagtimes[0], lagtimes[-1]
    ref_low = int(lagtimes.shape[0] / 4)
    ax.set_xlim(xlim)
    # highlight diagonal
    x_i = np.arange(ref_low, xlim[1])
    ax.fill_between(x_i, x_i, color="pplt:grid")
def relative_implied_timescales(cl, out):
    """
    Plot histograms of implied timescales relative to the reference.

    Three panels are drawn: the relative first implied timescale, the mean
    relative timescale over all columns, and the number of macrostates.

    cl: object providing ``timescales`` (2D, one row per clustering),
        ``reference.timescales`` and ``n_macrostates``.
    out: target handed to ``plt.savefig``.
    """
    pplt.use_style(figsize=(8, 2.5), latex=False, colors="pastel_autumn")

    its = cl.timescales / cl.reference.timescales
    xlabel = (
        r"Relative Implied Timescale "
        r"$\left(\frac{t_\mathrm{stoch}}{t_\mathrm{ref}}\right)$"
    )

    fig = plt.figure()
    ax1 = fig.add_subplot(1, 3, 1)
    ax2 = fig.add_subplot(1, 3, 2, sharey=ax1)
    ax3 = fig.add_subplot(1, 3, 3)

    for ax in (ax1, ax2, ax3):
        ax.grid(False)

    ax1.hist(its[:, 0], bins=20)
    ax1.set_title("its 1")
    ax1.set_xlabel(xlabel)
    ax1.set_ylabel("Count of Clusterings")

    ax2.hist(its.mean(axis=1), bins=20)
    # the mean runs over every column, so name the actual range in the title
    ax2.set_title(f"Mean its 1-{its.shape[1]}")
    ax2.set_xlabel(xlabel)

    # bin edges at half integers so every macrostate count gets its own bar
    bins = np.arange(min(cl.n_macrostates) - 1, max(cl.n_macrostates) + 1) + 0.5
    ax3.hist(cl.n_macrostates, bins=bins)
    ax3.set_title("n macrostates")
    ax3.set_xlabel("macrostate count")

    plt.tight_layout()
    plt.savefig(out)
    plt.close()
329### SIMILARITY MATRIX ########################################################
def plot_heatmap(a, out, title=""):
    """
    Plot a matrix as heatmap with logarithmic color normalisation.

    Intended for a similarity matrix as returned from the multiplication of
    two MPT objects.

    a (np.ndarray): 2D matrix to show.
    out: target handed to ``plt.savefig``.
    title (str): optional axes title; skipped when empty.
    """
    fig, ax = plt.subplots()
    ax.imshow(a, norm="log")
    ax.set_aspect("equal", "box")

    # one tick per column / row
    n_rows, n_cols = a.shape[0], a.shape[1]
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    if title:
        ax.set_title(title)
    for set_label in (ax.set_xlabel, ax.set_ylabel):
        set_label("Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()
def transition_matrix(a, out, title="Transition Matrix", color_thr=0.01):
    """
    Plot a square transition matrix as annotated heatmap, values in percent.

    Diagonal and off-diagonal elements use separate logarithmic color scales
    (reds / viridis), each with its own colorbar. Off-diagonal values below
    ``color_thr`` times the largest off-diagonal value are drawn light gray,
    exact zeros are white and carry no text.

    Row ``i`` of ``a`` is placed at y = i (axis labelled "To Macrostate"),
    column ``j`` at x = j (axis labelled "From Macrostate").

    a (np.ndarray): square matrix of probabilities (scaled by 100 here).
    out: target handed to ``plt.savefig``.
    title (str): axes title; skipped when empty.
    color_thr (float): relative threshold for the gray band.
    """
    # Scale a to percent
    a = a * 100

    # Diagonal elements: logarithmic Reds, skipping the pale lower 20 % so
    # that the smallest values are still clearly red.
    diagonal_values = np.diag(a)
    diag_norm = LogNorm(vmin=diagonal_values.min(), vmax=diagonal_values.max())
    diag_cmap = ListedColormap(plt.cm.Reds(np.linspace(0.2, 1, 256)))

    # Off-diagonal elements: logarithmic viridis
    off_diag_values = a[~np.eye(a.shape[0], dtype=bool)]

    # Threshold for light gray
    threshold = color_thr * off_diag_values.max()
    print(f"Threshold for probabilities: {threshold:.3f} %")

    off_diag_norm = LogNorm(
        vmin=threshold * (1 - color_thr), vmax=off_diag_values.max()
    )

    # viridis with its lowest entries replaced by light gray
    colors_list = plt.cm.viridis(np.linspace(0, 1, 256))
    gray = np.array([0.9, 0.9, 0.9, 1.0])
    colors_list[: int(color_thr * 256)] = gray
    off_diag_cmap = ListedColormap(colors_list)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect("equal", "box")
    ax.grid(False)

    luma_weights = np.array([0.299, 0.587, 0.114])
    for (i, j), value in np.ndenumerate(a):
        if value == 0:
            color = (1, 1, 1, 1)  # Zero probabilities are white
        elif i == j:
            color = diag_cmap(diag_norm(value))
        elif value < threshold:
            color = gray
        else:
            color = off_diag_cmap(off_diag_norm(value))

        ax.add_patch(patches.Rectangle((j - 0.5, i - 0.5), 1, 1, color=color))

        # Add text with transition probabilities
        if value != 0:
            # pick white text on dark cells, black text on bright cells
            grayscale = np.sum(np.array(color[:3]) * luma_weights)
            text_color = "white" if grayscale < 0.5 else "black"
            ax.text(
                j,
                i,
                f"{value:.2f}%",
                ha="center",
                va="center",
                color=text_color,
                fontsize=10,
            )

    ax.set_xticks(np.arange(a.shape[1]))
    ax.set_yticks(np.arange(a.shape[0]))
    ax.set_xticklabels(np.arange(1, a.shape[1] + 1))
    ax.set_yticklabels(np.arange(1, a.shape[0] + 1))
    ax.set_xlim(-0.5, a.shape[1] - 0.5)
    ax.set_ylim(-0.5, a.shape[0] - 0.5)

    # Colorbar for diagonal values; it must use the same (truncated)
    # colormap as the cells, otherwise its colors do not match the plot.
    cbar_diag = fig.colorbar(
        ScalarMappable(norm=diag_norm, cmap=diag_cmap), ax=ax, shrink=0.5
    )
    cbar_diag.set_label("Self Transition Probabilities / \\%")

    # Add a colorbar for off-diagonal values
    cbar_off_diag = fig.colorbar(
        ScalarMappable(norm=off_diag_norm, cmap=off_diag_cmap),
        ax=ax,
        shrink=0.5,
    )
    cbar_off_diag.set_label("Transition Probabilities / \\%")

    if title:
        ax.set_title(title)

    ax.set_xlabel("From Macrostate")
    ax.set_ylabel("To Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()
def transition_time(
    a,
    out,
    lagtime=50.0,
    frame_length=0.2,
    title=r"Transition Times $\frac{t_\mathrm{lag}}{P}$",
    color_thr=0.01,
):
    """
    Plot transition times ``lagtime / a * frame_length`` as annotated heatmap.

    Diagonal and off-diagonal elements use separate logarithmic color scales
    (reversed reds / reversed viridis), each with its own colorbar.
    Off-diagonal times above the smallest off-diagonal time divided by
    ``color_thr`` are drawn light gray; infinite times (zero probability)
    are white and carry no text.

    Row ``i`` of ``a`` is placed at y = i (axis labelled "To Macrostate"),
    column ``j`` at x = j (axis labelled "From Macrostate").

    a (np.ndarray): square matrix of transition probabilities.
    out: target handed to ``plt.savefig``.
    lagtime (float): lag time, multiplied with ``frame_length``.
    frame_length (float): frame length in ns.
    title (str): axes title; skipped when empty.
    color_thr (float): relative threshold for the gray band.
    """
    # zero probabilities turn into inf on purpose, silence the warning
    with np.errstate(divide="ignore"):
        a = lagtime / a * frame_length

    # Diagonal elements: logarithmic reversed Reds, cut at 80 % so that the
    # largest values are still clearly red.
    diagonal_values = np.diag(a)
    diag_norm = LogNorm(vmin=diagonal_values.min(), vmax=diagonal_values.max())
    diag_cmap = ListedColormap(plt.cm.Reds_r(np.linspace(0, 0.8, 256)))

    # Off-diagonal elements: logarithmic reversed viridis
    off_diag_values = a[~np.eye(a.shape[0], dtype=bool)]

    # Threshold for light gray
    threshold = off_diag_values.min() / color_thr
    print(f"Threshold for transition times: {threshold:.2f} ns")

    off_diag_norm = LogNorm(
        vmin=off_diag_values.min(), vmax=threshold / (1 - color_thr)
    )

    # reversed viridis with its highest entries replaced by light gray
    colors_list = plt.cm.viridis_r(np.linspace(0, 1, 256))
    gray = np.array([0.9, 0.9, 0.9, 1.0])
    colors_list[int((1 - color_thr) * 256) :] = gray
    off_diag_cmap = ListedColormap(colors_list)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect("equal", "box")
    ax.grid(False)

    luma_weights = np.array([0.299, 0.587, 0.114])
    for (i, j), value in np.ndenumerate(a):
        if value == np.inf:
            color = (1, 1, 1, 1)  # Zero probabilities are white
        elif i == j:
            color = diag_cmap(diag_norm(value))
        elif value > threshold:
            color = gray
        else:
            color = off_diag_cmap(off_diag_norm(value))

        ax.add_patch(patches.Rectangle((j - 0.5, i - 0.5), 1, 1, color=color))

        # Add text with transition times
        if value != np.inf:
            # pick white text on dark cells, black text on bright cells
            grayscale = np.sum(np.array(color[:3]) * luma_weights)
            text_color = "white" if grayscale < 0.5 else "black"
            ax.text(
                j,
                i,
                _format_transition_time(value, threshold),
                ha="center",
                va="center",
                color=text_color,
                fontsize=10,
            )

    ax.set_xticks(np.arange(a.shape[1]))
    ax.set_yticks(np.arange(a.shape[0]))
    ax.set_xticklabels(np.arange(1, a.shape[1] + 1))
    ax.set_yticklabels(np.arange(1, a.shape[0] + 1))
    ax.set_xlim(-0.5, a.shape[1] - 0.5)
    ax.set_ylim(-0.5, a.shape[0] - 0.5)

    # Colorbar for diagonal values; it must use the same (truncated)
    # colormap as the cells, otherwise its colors do not match the plot.
    cbar_diag = fig.colorbar(
        ScalarMappable(norm=diag_norm, cmap=diag_cmap), ax=ax, shrink=0.5
    )
    cbar_diag.set_label("Self Transition Times / ns")

    # Add a colorbar for off-diagonal values
    cbar_off_diag = fig.colorbar(
        ScalarMappable(norm=off_diag_norm, cmap=off_diag_cmap),
        ax=ax,
        shrink=0.5,
    )
    cbar_off_diag.set_label("Transition Times / ns")

    if title:
        ax.set_title(title)

    ax.set_xlabel("From Macrostate")
    ax.set_ylabel("To Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()


def _format_transition_time(value, threshold):
    """Return the short cell annotation for a finite transition time."""
    if value >= threshold:
        # One significant digit with a compact exponent ("1e+04" -> "1e4").
        # Splitting at "e" also handles values printed without exponent
        # as well as negative or two-digit exponents.
        coarse = f"{value:.1g}"
        mantissa, has_exponent, exponent = coarse.partition("e")
        return f"{mantissa}e{int(exponent)}" if has_exponent else coarse
    if value >= 100:
        return f"{value:.0f}"
    return f"{value:#.3g}"
566### MACROSTATE FEATURES ######################################################
def macro_feature(micro_feature, out, ref=None, pop=None):
    """
    Plot histogram of feature distribution.

    micro_feature (np.ndarray, NxR): N microstates, R runs, holds feature
        values of respective macrostate
    out (str): file to save the plot
    ref: optional clustering that should be shown explicitly. It has to
        provide ``macrostate_assignment``, ``macrostate_feature`` and
        ``n_i``; the entries at index ``n_i`` are passed on to ``add_ref``.
    pop: optional weights handed to ``np.histogram``.
    """
    n_runs = micro_feature.shape[1]

    # 100 equally wide bins with a 5 % margin around the data range
    edges = np.linspace(
        micro_feature.min() * 0.95,
        micro_feature.max() * 1.05,
        101,
    )
    counts, bins = np.histogram(
        micro_feature,
        bins=edges,
        weights=pop,
        density=True,
    )
    norm_counts = counts / n_runs
    # lower y-limit slightly below the smallest non-empty bin (log axis)
    y_min = norm_counts[norm_counts > 0].min() * 0.7

    fig, ax = plt.subplots(figsize=(8, 6))
    # draw the precomputed histogram: one sample per bin, weighted by height
    ax.hist(bins[:-1], bins=bins, weights=norm_counts, label="Stochastic Clustering")
    if ref is not None:
        add_ref(ref.macrostate_assignment[ref.n_i], ref.macrostate_feature[ref.n_i], ax)
    ax.set_xlabel("Fraction of Contacts")
    ax.set_ylabel("Population")
    ax.set_title(f"Macrostate Features, {micro_feature.shape[1]} clusterings")
    ax.set_yscale("log")
    _, y_max = ax.get_ylim()
    ax.set_ylim((y_min, y_max))
    plt.legend(loc="lower left")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()
def add_ref(
    macrostate_assignment,
    macrostate_feature,
    ax,
    color="r",
    label="Reference",
    weights=None,
):
    """
    Add a clustering to the histogram.

    For every macrostate a vertical line is drawn at its feature value and
    annotated with the (1-based) macrostate number. The line height is the
    weighted sum of the assignment row, scaled by 1e-3.

    macrostate_assignment (np.ndarray, MxN): macrostate assignement, M: number
        of macrostates, N: number of microstates.
    macrostate_feature (np.ndarray, M): mean feature for every macrostate.
    ax: axes to draw on.
    color: color of lines and numbers.
    label (str): legend label; " / 1000" is appended.
    weights (np.ndarray, N): optional microstate weights, normalised to sum
        1. Without weights every microstate counts 1.
    """
    # The weights do not depend on the macrostate: prepare them once instead
    # of re-normalising on every loop iteration.
    if weights is None:
        weights = np.array([1])
    else:
        weights = np.asarray(weights, dtype=float)
        weights = weights / weights.sum()
    weight_sum = weights.sum()

    for i, (ma, mf) in enumerate(zip(macrostate_assignment, macrostate_feature)):
        height = (ma * weights).sum() / weight_sum * 1e-3
        # only the first line gets a label, so the legend shows one entry
        legend_kwargs = {"label": label + " / 1000"} if i == 0 else {}
        # 1e-9 as lower end keeps the line valid on a logarithmic y-axis
        ax.plot([mf, mf], [1e-9, height], c=color, **legend_kwargs)
        pplt.text(
            mf + 0.015,
            height * 0.82,
            f"{i + 1:.0f}",
            c=color,
            ax=ax,
            contour=True,
            size="small",
        )
651### CONTACT REPRESENTATION ###################################################
def contact_rep(contacts, cluster_file, state_trajectory, output, grid, scale=1):
    """
    Adapted from msmhelper.

    Contact representation of states.

    This script creates a contact representation of states, where the states
    are obtained by [MoSAIC](https://github.com/moldyn/MoSAIC) and the contact
    representation was introduced in Nagel et al.[^1].

    [^1]: Nagel et al., **Selecting Features for Markov Modeling: A Case Study
    on HP35.**, *J. Chem. Theory Comput.*, submitted,

    contacts (np.ndarray): one row per frame, one column per contact.
    cluster_file: file read by ``load_clusters``; defines which contact
        columns are shown and how they are grouped.
    state_trajectory (np.ndarray): state of every frame.
    output (str or None): file to save the figure. If the states do not fit
        into one figure, further figures are saved with
        ``.state<first>-<last>`` inserted before the extension. With
        ``None`` the figures are shown instead of saved.
    grid (tuple): ``(nrows, ncols)`` of panels per figure.
    scale: factor applied to the base figure size of 1.2.
    """
    # setup matplotlib
    pplt.use_style(
        figsize=1.2 * scale,
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )

    # load files
    states = np.unique(state_trajectory)
    clusters = load_clusters(cluster_file)

    contact_idxs = np.hstack(clusters)
    n_idxs = len(contact_idxs)
    n_frames = len(contacts)

    # left border of every cluster, used as tick and separator position
    cluster_sizes = [len(clust) for clust in clusters[:-1]]
    xtickpos = np.cumsum([0, *cluster_sizes]) - 0.5
    nrows, ncols = grid
    hspace, wspace = 0, 0
    # ignore outliers
    ylims = 0, np.quantile(contacts, 0.999)

    # loop invariants: colors and bin edges are the same for every panel
    c1, c2, c3 = pplt.categorical_color(3, "C0")
    edges = np.arange(n_idxs + 1) - 0.5
    bands = (
        (c3, ("whislo", "whishi"), r"$Q_{1/3} \pm 1.5\mathrm{IQR}$"),
        (c2, ("q1", "q3"), r"$\mathrm{IQR} = Q_3 - Q_1$"),
    )

    chunks = mh.plot._ck_test._split_array(states, nrows * ncols)
    for chunk_idx, chunk in enumerate(chunks):
        fig, axs = plt.subplots(
            int(np.ceil(len(chunk) / ncols)),
            ncols,
            sharex=True,
            sharey=True,
            squeeze=False,
            gridspec_kw={"wspace": wspace, "hspace": hspace},
        )

        for state, ax in zip(chunk, axs.flatten()):
            contacts_state = contacts[state_trajectory == state]
            pop_state = len(contacts_state) / n_frames

            stats = {
                idx: boxplot_stats(contacts_state[:, idx])[0] for idx in contact_idxs
            }

            for color, (key_low, key_high), label in bands:
                ymax = [stats[idx][key_high] for idx in contact_idxs]
                ymin = [stats[idx][key_low] for idx in contact_idxs]
                ax.stairs(
                    ymax,
                    edges,
                    baseline=ymin,
                    color=color,
                    lw=0,
                    fill=True,
                    label=label,
                )

            ax.hlines(
                [stats[idx]["med"] for idx in contact_idxs],
                xmin=np.arange(n_idxs) - 0.5,
                xmax=np.arange(n_idxs) + 0.5,
                label="median",
                color=c1,
            )

            pplt.text(
                0.5,
                0.95,
                rf"S{state + 1} {pop_state:.1%}",
                ha="center",
                va="top",
                ax=ax,
                transform=ax.transAxes,
                contour=True,
            )

            ax.set_xlim([-0.5, n_idxs - 0.5])
            ax.set_ylim(*ylims)
            ax.set_xticks(xtickpos)
            ax.set_xticklabels(np.arange(len(xtickpos)) + 1)

            ax.grid(False)
            for pos in xtickpos:
                ax.axvline(pos, color="pplt:grid", lw=1.0)

        pplt.hide_empty_axes()
        pplt.legend(
            ax=axs[0, 0],
            outside="top",
            bbox_to_anchor=(
                0,
                1.0,
                axs.shape[1] + wspace * (axs.shape[1] - 1),
                0.01,
            ),
            frameon=False,
            ncol=2,
        )
        pplt.subplot_labels(
            xlabel="contact clusters",
            ylabel="distances [nm]",
        )

        # save figure and continue
        if output is None:
            # Nothing to save: show the figure and move on. Falling through
            # would call splitext(None) and raise a TypeError.
            plt.show()
            plt.close()
            continue

        if chunk_idx == 0:
            pplt.savefig(output)
        else:
            # insert state_str between pathname and extension
            path, ext = splitext(output)
            pplt.savefig(f"{path}.state{chunk[0]:.0f}-{chunk[-1]:.0f}{ext}")
        plt.close()
797### SANKEY ###################################################################
def sankey_diagram(cl, ref, out, ax=None, scale=1):
    """
    Draw a sankey diagram between a clustering and a reference clustering.

    The left side shows the macrostates of ``cl`` at index ``cl.n_i``, the
    right side the macrostates of ``ref`` at index 0. Both sides are
    weighted with ``ref.pop``. Colors are taken from the macrostates of
    ``cl`` sorted by descending feature.

    cl: clustering providing ``tree``, ``n_i``, ``macrostate_map`` and
        ``n_macrostates``.
    ref: reference providing ``macrostate_map``, ``n_macrostates``, ``pop``.
    out: target handed to ``pplt.savefig``; only used if ``ax`` is None.
    ax: axes to draw on. If None, the style is set up here and the figure
        is saved and closed afterwards.
    scale: factor applied to the base figure size.
    """
    macrostates = cl.tree[cl.n_i].macrostates
    features = [macrostate.feature for macrostate in macrostates]
    # highest feature first
    ma_order = np.argsort(features)[::-1]
    color_by_label = {
        str(rank + 1): macrostates[o].color for rank, o in enumerate(ma_order)
    }

    standalone = ax is None
    if standalone:
        pplt.use_style(figsize=(1.7 * scale, 3.6 * scale), true_black=True)

    left_names = np.arange(1, cl.n_macrostates[cl.n_i] + 1).astype(str).tolist()
    right_names = np.arange(1, ref.n_macrostates[0] + 1).astype(str).tolist()
    sankey(
        left=(cl.macrostate_map[cl.n_i] + 1).astype(str),
        right=(ref.macrostate_map[0] + 1).astype(str),
        leftWeight=ref.pop,
        rightWeight=ref.pop,
        leftLabels=left_names,
        rightLabels=right_names,
        colorDict=color_by_label,
        ax=ax,
    )

    if standalone:
        pplt.savefig(out)
        plt.close()
825### RMSD LINES ###############################################################
def rmsd(rmsds, pops, helices=None, filename=None):
    """
    Plot one per-residue RMSD profile per macrostate on a log scale.

    Every row of the figure consists of three panels: the RMSD profile over
    the residues, a bar with the normalised population and a bar with the
    RMSD summed over the residues (excluding the first and last two). If
    ``helices`` is given, an extra row at the bottom sketches the secondary
    structure.

    Parameters:
    - rmsds (np.ndarray): 2D array, one row per macrostate and one column
      per residue. Values must be positive for logarithmic scaling.
    - pops (np.ndarray): population of every macrostate (one per row).
    - helices (iterable of (start, end), optional): blocks to be indicated
      in the bottom row. A positive start marks a helix (filled box);
      otherwise the pair holds negated residue numbers and marks a sheet
      (hollow box).
    - filename (str, optional): If provided, saves the figure to this file,
      otherwise the figure is shown.

    Raises:
    - ValueError: if ``rmsds`` contains non-positive values or the length
      of ``pops`` does not match the number of rows.
    """
    rmsds = np.asarray(rmsds)
    pops = np.asarray(pops)

    # Ensure all values are positive for logarithmic scaling
    if np.any(rmsds <= 0):
        raise ValueError(
            "All values in `rmsds` must be positive for logarithmic scaling."
        )

    if rmsds.shape[0] != len(pops):
        raise ValueError("Length of `pops` must match the number of rows in `rmsds`.")

    show_structure = helices is not None
    n_residues = rmsds.shape[1]
    n_plots = rmsds.shape[0] + 1 if show_structure else rmsds.shape[0]

    w = 0.08 * n_residues + 3  # 8.6
    h = 1 + 0.4 * n_plots  # 6
    pplt.use_style(
        figsize=(w, h),
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )
    # squeeze=False keeps ``axs`` 2D even for a single row; the indexing
    # below (axs[-1, 1], row unpacking) relies on that.
    fig, axs = plt.subplots(
        n_plots,
        3,
        sharex="col",
        squeeze=False,
        width_ratios=[n_residues, 8, 8],
        gridspec_kw={"wspace": 0, "hspace": 0},
    )

    ylim = 0.5 * rmsds.min(), 2 * rmsds.max()
    pops = pops / pops.sum()
    ylim_hist = 0, 1.05 * pops.max()

    rmsd_sums = rmsds[:, 2:-2].sum(axis=1)
    ylim_rmsd = 0, 1.05 * rmsd_sums.max()

    residues = np.arange(n_residues) + 1
    state_rows = axs[:-1] if show_structure else axs
    for i, ((ax, hist_ax, rmsd_ax), profile, pop) in enumerate(
        zip(state_rows, rmsds, pops)
    ):
        # horizontal bar: population of the macrostate
        hist_ax.add_patch(patches.Rectangle((0, 0.3), pop, 0.4))
        hist_ax.set_xlim(ylim_hist)
        hist_ax.set_yticks([], [])
        hist_ax.grid(False)

        # horizontal bar: summed RMSD of the macrostate
        rmsd_ax.add_patch(patches.Rectangle((0, 0.3), profile[2:-2].sum(), 0.4))
        rmsd_ax.set_xlim(ylim_rmsd)
        rmsd_ax.set_yticks([], [])
        rmsd_ax.grid(False)

        ax.plot(residues, profile)
        ax.fill_between(
            residues,
            [ylim[0]] * n_residues,
            profile,
            alpha=0.5,
        )

        ax.set_yscale("log")
        ax.set_ylabel(f"{i + 1}")
        ax.set_xlim((0.5, n_residues + 0.5))
        ax.set_ylim(ylim)
        ax.grid(True)

    if show_structure:
        _draw_secondary_structure(axs[-1, 0], helices, n_residues)

        axs[-1, 1].grid(False)
        axs[-1, 1].set_yticks([], [])
        axs[-1, 2].grid(False)
        axs[-1, 2].set_yticks([], [])

    # blank the first tick label of both bar columns
    for bar_ax in (axs[-1, 1], axs[-1, 2]):
        ticks = bar_ax.get_xticks()
        tick_labels = bar_ax.get_xticklabels()
        tick_labels[0] = ""
        bar_ax.set_xticks(ticks, tick_labels)

    axs[-1, 0].xaxis.set_major_locator(MultipleLocator(5))
    axs[-1, 0].xaxis.set_minor_locator(MultipleLocator(1))
    axs[-1, 0].set_xlabel("Residue")
    axs[-1, 1].set_xlabel("Population", rotation=20)
    axs[-1, 2].set_xlabel(r"$\sum$ RMSD / nm", rotation=20)
    fig.supylabel("Macrostate; RMSD Variance / nm")

    # Save to file if filename is provided
    plt.tight_layout()
    if filename:
        plt.savefig(filename, dpi=192)  # , bbox_inches="tight"
    else:
        plt.show()
    plt.close()


def _draw_secondary_structure(track_ax, helices, n_residues):
    """Sketch helices (filled) and sheets (hollow) joined by a line."""
    dark = "#264653"
    line_start = 1
    for start, end in helices:
        if start > 0:
            # Helices
            start, end, face = start - 0.3, end + 0.3, dark
        else:
            # Sheets are given with negated residue numbers
            start, end, face = -start - 0.5, -end + 0.5, "white"
        # connecting line from the previous block up to this one
        track_ax.plot(
            [line_start, start],
            [0.5, 0.5],
            solid_capstyle="butt",
            c=dark,
            lw=2,
        )
        line_start = end
        track_ax.add_patch(
            patches.Rectangle(
                (start, 0.3),  # Position of the block
                end - start,
                0.4,
                fc=face,
                ec=dark,
                lw=2,
            )
        )
    # remaining line from the last block to the final residue
    track_ax.plot(
        [line_start, n_residues],
        [0.5, 0.5],
        solid_capstyle="butt",
        c=dark,
        lw=2,
    )

    track_ax.set_ylim((0, 1))
    track_ax.set_ylabel("H")
    track_ax.set_yticks([], [])
    track_ax.grid(False)
997def delta_rmsd(rmsds, pops, helices=None, filename=None):
998 """
999 Plots a 2D NumPy array as a heatmap with a logarithmic color scale and variable row heights.
1001 Parameters:
1002 - vars (np.ndarray): The 2D NumPy array to plot. Values must be positive for logarithmic scaling.
1003 - row_heights (np.ndarray): 1D array defining the height of each row.
1004 - helices (np.ndarray): Array with start and end points for blocks to be indicated in the bottom row.
1005 - filename (str, optional): If provided, saves the heatmap to this file.
1006 """
1007 # Ensure all values are positive for logarithmic scaling
1008 if np.any(rmsds <= 0): 1008 ↛ 1009line 1008 didn't jump to line 1009 because the condition on line 1008 was never true
1009 raise ValueError(
1010 "All values in `rmsds` must be positive for logarithmic scaling."
1011 )
1013 if rmsds.shape[0] != len(pops): 1013 ↛ 1014line 1013 didn't jump to line 1014 because the condition on line 1013 was never true
1014 raise ValueError("Length of `pops` must match the number of rows in `rmsds`.")
1016 if helices is not None: 1016 ↛ 1019line 1016 didn't jump to line 1019 because the condition on line 1016 was always true
1017 n_plots = rmsds.shape[0] + 1
1018 else:
1019 n_plots = rmsds.shape[0]
1021 w = 0.08 * rmsds.shape[1] + 3 # 8.6
1022 h = 1 + 0.4 * n_plots # 6
1023 pplt.use_style(
1024 figsize=(w, h),
1025 colors="pastel_autumn",
1026 true_black=True,
1027 latex=False,
1028 )
1029 fig, axs = plt.subplots(
1030 n_plots,
1031 3,
1032 sharex="col",
1033 width_ratios=[rmsds.shape[1], 8, 8],
1034 gridspec_kw={"wspace": 0, "hspace": 0},
1035 )
1037 delta_rmsd = rmsds
1038 # delta_rmsd[1:] = rmsds[1:] - rmsds[:-1]
1039 delta_rmsd[1:] = rmsds[1:] - rmsds[0]
1040 rmsds = delta_rmsd
1042 rmsd_max = rmsds.max()
1043 rmsd_min = rmsds.min()
1044 rmsd_delta = rmsd_max - rmsd_min
1046 ylim = rmsd_min - rmsd_delta * 0.1, rmsd_max + rmsd_delta * 0.15
1047 pops = pops / pops.sum()
1048 ylim_hist = 0, 1.05 * pops.max()
1050 rmsd_sums = rmsds.sum(axis=1)
1051 ylim_rmsd = 0, 1.05 * rmsd_sums.max()
1053 for i, ((ax, hist_ax, rmsd_ax), rmsd, pop) in enumerate(
1054 zip(axs[:-1] if helices is not None else axs, rmsds, pops)
1055 ):
1056 rect = patches.Rectangle(
1057 (0, 0.3), # Position of the block
1058 pop,
1059 0.4, # color='black'
1060 )
1061 hist_ax.add_patch(rect)
1062 hist_ax.set_xlim(ylim_hist)
1063 hist_ax.set_yticks([], [])
1064 hist_ax.grid(False)
1066 rect = patches.Rectangle(
1067 (0, 0.3), # Position of the block
1068 abs(rmsd).sum(),
1069 0.4, # color='black'
1070 )
1071 rmsd_ax.add_patch(rect)
1072 rmsd_ax.set_xlim(ylim_rmsd)
1073 rmsd_ax.set_yticks([], [])
1074 rmsd_ax.grid(False)
1076 ax.plot(np.arange(rmsd.shape[0]) + 1, rmsd)
1077 ax.fill_between(
1078 np.arange(rmsd.shape[0]) + 1,
1079 # [ylim[0]]*rmsd.shape[0],
1080 [0] * rmsd.shape[0],
1081 rmsd,
1082 alpha=0.5,
1083 # facecolor="none",
1084 # hatch="/",
1085 )
1087 # ax.set_yscale("log")
1088 ax.set_ylabel(f"{i + 1}")
1089 # ax.set_ylabel(f"{i+1}-{i}")
1090 # ax.set_ylabel(f"{i+1}-{i}", rotation=90, position=(-0.2, 0))
1091 ax.set_xlim((0.5, rmsd.shape[0] + 0.5))
1092 ax.set_ylim(ylim)
1093 ax.grid(True)
1095 if helices is not None: 1095 ↛ 1152line 1095 didn't jump to line 1152 because the condition on line 1095 was always true
1096 line_start = 1
1097 helices_ax = axs[-1, 0]
1098 # helices_ax.plot([1, rmsds.shape[1]], [0.5, 0.5]) #, c="k")
1099 for start, end in helices:
1100 if start > 0:
1101 # Helices
1102 start -= 0.3
1103 end += 0.3
1104 rect = patches.Rectangle(
1105 (start, 0.3), # Position of the block
1106 end - start,
1107 0.4, # color='black'
1108 fc="#264653",
1109 ec="#264653",
1110 lw=2,
1111 )
1112 else:
1113 # Sheets
1114 start, end = -start, -end
1115 start -= 0.5
1116 end += 0.5
1117 rect = patches.Rectangle(
1118 (start, 0.3), # Position of the block
1119 end - start,
1120 0.4, # color='black'
1121 fc="white",
1122 ec="#264653",
1123 lw=2,
1124 )
1125 helices_ax.plot(
1126 [line_start, start],
1127 [0.5, 0.5],
1128 solid_capstyle="butt",
1129 c="#264653",
1130 lw=2,
1131 )
1132 line_start = end
1133 helices_ax.add_patch(rect)
1134 helices_ax.plot(
1135 [line_start, rmsds.shape[1]],
1136 [0.5, 0.5],
1137 solid_capstyle="butt",
1138 c="#264653",
1139 lw=2,
1140 )
1142 helices_ax.set_ylim((0, 1))
1143 helices_ax.set_ylabel("H")
1144 helices_ax.set_yticks([], [])
1145 helices_ax.grid(False)
1147 axs[-1, 1].grid(False)
1148 axs[-1, 1].set_yticks([], [])
1149 axs[-1, 2].grid(False)
1150 axs[-1, 2].set_yticks([], [])
1152 axs[0, 0].set_ylim(0.5 * rmsds[0].min(), 2 * rmsds[0].max())
1153 axs[0, 0].set_yscale("log")
1154 # axs[0, 0].set_ylabel("1") #, rotation=90)
1156 hist_ticks = axs[-1, 1].get_xticks()
1157 hist_labels = axs[-1, 1].get_xticklabels()
1158 hist_labels[0] = ""
1159 axs[-1, 1].set_xticks(hist_ticks, hist_labels)
1161 rmsd_ticks = axs[-1, 2].get_xticks()
1162 rmsd_labels = axs[-1, 2].get_xticklabels()
1163 rmsd_labels[0] = ""
1164 axs[-1, 2].set_xticks(rmsd_ticks, rmsd_labels)
1166 axs[-1, 0].xaxis.set_major_locator(MultipleLocator(5))
1167 axs[-1, 0].xaxis.set_minor_locator(MultipleLocator(1))
1168 axs[-1, 0].set_xlabel("Residue")
1169 axs[-1, 1].set_xlabel("Population", rotation=20)
1170 axs[-1, 2].set_xlabel(r"$\sum$ |$\Delta$RMSD| / nm", rotation=20)
1171 fig.supylabel(r"$\Delta$Macrostate; $\Delta$RMSD Variance / nm wrt. 1st State")
1173 # Save to file if filename is provided
1174 plt.tight_layout()
1175 if filename: 1175 ↛ 1178line 1175 didn't jump to line 1178 because the condition on line 1175 was always true
1176 plt.savefig(filename, dpi=192) # , bbox_inches="tight"
1177 else:
1178 plt.show()
1179 plt.close()
1182### TRAJECTORY ###############################################################
def state_trajectory(trajectory, filename, row_length=0.2, frame_length=0.2):
    """
    Plot a state trajectory as horizontal line segments, wrapped over rows.

    Every segment reported by `utils.find_state_lengths` is drawn at the
    height of its (1-based) state index. Once the time axis of a row is
    full, drawing continues at the start of the next row.

    trajectory (np.ndarray): state trajectory
    filename (str): file name to save the plot to
    row_length (int|float):
        row_length > 1: number of frames in each row
        0 < row_length <= 1: fraction of total frames per row (1/n_rows)
    frame_length (float): frame length in ns

    Raises:
        ValueError: if `row_length` is not positive.
    """
    if row_length > 1:
        x_max = int(row_length)
    elif row_length > 0:
        x_max = int(np.ceil(trajectory.shape[0] * row_length))
    else:
        raise ValueError("row_length must be > 0")

    # The time axis is labelled in microseconds, frame_length is given in ns.
    frame_length = frame_length / 1000.0

    # One (state, length) pair per segment.
    # NOTE(review): presumably runs of consecutive identical states - confirm
    # against utils.find_state_lengths.
    unique_states, lengths = utils.find_state_lengths(trajectory)
    # Shift to 1-based state labels without touching the helper's array.
    unique_states = unique_states + 1
    lengths = lengths * frame_length
    n_rows = int(np.ceil(trajectory.shape[0] / x_max))

    # From here on x_max is the row width in time units instead of frames.
    x_max *= frame_length

    figsize = (11.7, 8.3)

    # Discrete colour per state: bin edges sit half-way between state labels.
    u_states = np.unique(unique_states)
    cmap = mpl.cm.turbo
    norm = mpl.colors.BoundaryNorm(np.concatenate([[0], u_states]) + 0.5, cmap.N)

    fig, axs = plt.subplots(
        n_rows,
        1,
        sharex=True,
        figsize=figsize,
        gridspec_kw={"wspace": 0, "hspace": 0},
        squeeze=False,
        layout="compressed",
    )
    axs = axs[:, 0]
    axi = 0  # index of the row currently being filled

    # Plot each state occurrence as a line segment
    x_start = 0  # Initial x-coordinate for the first segment
    for state, length in zip(unique_states, lengths):
        x_end = x_start + length  # End position of this segment on the x-axis
        color = cmap(norm(state))

        # Segment runs past the end of the row: draw the part that fits and
        # carry the remainder over to the next row (possibly several times).
        while x_end > x_max:
            axs[axi].plot(
                [x_start, x_max],
                [state, state],
                color=color,
                linewidth=3,
                solid_capstyle="butt",
            )
            x_end -= x_max
            x_start = 0
            axi += 1
        # Skip zero-length leftovers (e.g. a segment ending exactly at the
        # row boundary).
        if not np.isclose(x_start, x_end):
            axs[axi].plot(
                [x_start, x_end],
                [state, state],
                color=color,
                linewidth=3,
                solid_capstyle="butt",
            )

        # Move x_start to the end of the current segment for the next one
        x_start = x_end

    # Colorbar shared by all rows, one tick per occurring state
    cbar = fig.colorbar(
        ScalarMappable(norm=norm, cmap=cmap),
        ax=axs,
        orientation="vertical",
        label="Macrostate",
    )
    cbar.set_ticks(ticks=u_states, labels=u_states, minor=False)
    fig.axes[-1].tick_params(length=0)

    fig.supylabel("State Index")
    axs[-1].set_xlabel(r"t / $\mu$s")

    for ax in axs:
        ax.set_ylim(unique_states.min() - 1, unique_states.max())

    # Set axis limits; the rows share their x-axis, so one call covers all.
    axs[-1].set_xlim(0, x_max)

    # Save and close this figure explicitly instead of relying on pyplot's
    # notion of the "current" figure.
    fig.savefig(filename)
    plt.close(fig)  # Close the plot to free memory
1289### CHAPMAN-KOLMOGOROV TEST ##################################################
def chapman_kolmogorov(mpt, out, frame_length=0.2):
    """Run the Chapman-Kolmogorov test for the selected lumping and plot it.

    mpt: object providing `macrostate_trajectory`, `n_macrostates`, `limits`
        and the selection index `n_i`
    out (str): file name the figure is saved to
    frame_length (float): frame length in ns
    """
    selected = mpt.n_i
    trajectories = utils.get_multi_state_trajectory(
        mpt.macrostate_trajectory[selected], mpt.limits
    )
    # NOTE(review): lag times and the maximum time (4000) are hard-coded,
    # presumably in frames - confirm against msmhelper.
    lagtimes = [50] * 5
    ck = mh.msm.tests.chapman_kolmogorov_test(trajectories, lagtimes, 4000)

    style_kwargs = {
        "figsize": 4.8,
        "colors": "pastel_autumn",
        "true_black": True,
        "latex": False,
    }
    pplt.use_style(**style_kwargs)

    n_states = mpt.n_macrostates[selected]
    nrows, ncols = utils.get_grid_format(n_states)
    state_chunks = mh.plot._ck_test._split_array(
        np.arange(n_states), nrows * ncols
    )
    for chunk in state_chunks:
        # NOTE(review): the grid is handed over as (ncols, nrows) while
        # plot_ck_test documents (n_rows, n_cols) - confirm this is intended.
        fig = plot_ck_test(
            ck=ck,
            states=chunk,
            frames_per_unit=1 / frame_length,
            unit="ns",
            grid=(ncols, nrows),
        )

        # Reposition the state labels inside every panel.
        for label in (text for ax in fig.axes for text in ax.texts):
            label.set_position((0.15, 0.2))

        # Every chunk is written to the same path `out`.
        plt.savefig(out)
        plt.close()
def plot_ck_test(
    ck,
    states=None,
    frames_per_unit=1,
    unit="frames",
    grid=(3, 3),
):
    """Plot CK-Test results.

    This routine is a basic helper function to visualize the results of
    [msmhelper.msm.chapman_kolmogorov_test][].

    Parameters
    ----------
    ck : dict
        Dictionary holding for each lagtime the CK equation and with 'md' the
        reference.
    states : ndarray, optional
        List containing all states to plot the CK-test. If omitted, all
        states found in `ck['md']['ck']` are plotted.
    frames_per_unit : float, optional
        Number of frames per given unit. This is used to scale the axis
        accordingly.
    unit : ['frames', 'fs', 'ps', 'ns', 'us'], optional
        Unit to use for label.
    grid : (int, int), optional
        The number of `(n_rows, n_cols)` to use for the grid layout. The
        number of rows actually created is derived from `len(states)` and
        `n_cols`; `n_rows` only influences the legend placement and the
        choice between the long and the short y-label.

    Returns
    -------
    fig : matplotlib.Figure
        Figure holding plots.

    Notes
    -----
    Adapted from msmhelper.

    """
    # load colors
    pplt.load_cmaps()
    pplt.load_colors()

    lagtimes = np.array([key for key in ck if key != "md"])
    if states is None:
        states = np.array(list(ck["md"]["ck"].keys()))

    nrows, ncols = grid
    needed_rows = int(np.ceil(len(states) / ncols))

    # squeeze=False always yields a (needed_rows, ncols) array. Squeezing and
    # re-expanding with np.atleast_2d would turn a single-column grid with
    # several rows into shape (1, needed_rows) and break axs[irow, icol].
    fig, axs = plt.subplots(
        needed_rows,
        ncols,
        sharex=True,
        sharey="row",
        squeeze=False,
        gridspec_kw={"wspace": 0, "hspace": 0},
    )

    max_time = np.max(ck["md"]["time"])
    for irow, states_row in enumerate(mh.plot._ck_test._split_array(states, ncols)):
        for icol, state in enumerate(states_row):
            ax = axs[irow, icol]

            # reference curve
            pplt.plot(
                ck["md"]["time"] / frames_per_unit,
                ck["md"]["ck"][state],
                "--",
                ax=ax,
                color="pplt:gray",
                label="MD",
            )
            # one curve per lag time
            for lagtime in lagtimes:
                pplt.plot(
                    ck[lagtime]["time"] / frames_per_unit,
                    ck[lagtime]["ck"][state],
                    ax=ax,
                    label=lagtime / frames_per_unit,
                )
            # panel label with 1-based state index
            pplt.text(
                0.5,
                0.9,
                f"S{state + 1}",
                contour=True,
                va="top",
                transform=ax.transAxes,
                ax=ax,
            )

            # set scale
            ax.set_xscale("log")
            ax.set_xlim(
                [
                    lagtimes[0] / frames_per_unit,
                    max_time / frames_per_unit,
                ]
            )
            ax.set_ylim([0, 1])
            # Rows are stacked without vertical gap: drop the 0 tick on all
            # but the last row so it does not collide with the row below.
            if irow < len(axs) - 1:
                ax.set_yticks([0.5, 1])
            else:
                ax.set_yticks([0, 0.5, 1])

            ax.grid(True, which="major", linestyle="--")
            ax.grid(True, which="minor", linestyle="dotted")
            ax.set_axisbelow(True)

    # set legend: to the right for narrow grids, on top otherwise
    legend_kw = (
        {
            "outside": "right",
            "bbox_to_anchor": (2.0, (1 - nrows), 0.2, nrows),
        }
        if ncols in {1, 2}
        else {
            "outside": "top",
            "bbox_to_anchor": (0.0, 1.0, ncols, 0.01),
        }
    )
    if ncols == 3:
        legend_kw["ncol"] = 3
    pplt.legend(
        ax=axs[0, 0],
        **legend_kw,
        title=rf"$\tau_\mathrm{{lag}}$ [{unit}]",
        frameon=False,
    )

    # The long label only fits when the grid is tall enough.
    ylabel = (
        (r"self-transition probability $P_{i\to i}$")
        if nrows >= 3
        else (r"$P_{i\to i}$")
    )

    pplt.hide_empty_axes()
    pplt.label_outer()
    pplt.subplot_labels(
        ylabel=ylabel,
        xlabel=rf"time $t$ [{unit}]",
    )
    return fig
1469### MACROSTATE GRAPH #########################################################
def state_network(lumping, out):
    """Draw the macrostate network of the selected lumping.

    Forwards the macrostate trajectory at index `lumping.n_i`, the lag time
    and the mean feature trajectory of `lumping` to `draw_knetwork`, together
    with the output target `out`.
    """
    macro_trajectory = lumping.macrostate_trajectory[lumping.n_i]
    lagtime = lumping.lagtime
    mean_features = lumping.mean_feature_trajectory
    draw_knetwork(macro_trajectory, lagtime, mean_features, out)
1478 )