Coverage for MPP/plot.py: 90%

641 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-14 16:23 +0200

1#!/usr/bin/env python3 

2""" 

3plot.py 

4================== 

5 

6Various plot functions used in this package. 

7""" 

8 

9from os.path import splitext 

10 

11import numpy as np 

12import prettypyplot as pplt 

13import matplotlib as mpl 

14from matplotlib import pyplot as plt 

15from matplotlib.cm import ScalarMappable 

16from matplotlib.colors import ( 

17 LinearSegmentedColormap, 

18 LogNorm, 

19 ListedColormap, 

20) 

21from matplotlib import colors 

22from matplotlib.cbook import boxplot_stats 

23import matplotlib.patches as patches 

24from matplotlib.ticker import MultipleLocator 

25import msmhelper as mh 

26from msmhelper._cli.contact_rep import load_clusters 

27 

28from . import utils 

29from .sankey_gap import sankey 

30from .graph import draw_knetwork 

31 

32plt.rcParams["font.family"] = "sans-serif" 

33 

34### DENDROGRAM ############################################################### 

35 

36 

def plot_tree(root, macrostate_assignment, output_file, scale=1, offset=0):
    """
    Draw the dendrogram of a state tree with the macrostate lumping below it.

    The upper panel is drawn by ``root.plot_tree``. The lower panel shows
    ``macrostate_assignment`` (one row per macrostate, one column per
    microstate) as a black/transparent matrix, with the columns reordered
    to follow the leaf order of the tree.

    root: tree node providing ``leaves``, ``plot_tree(ax)`` and
        ``feature_norm``.
    macrostate_assignment (np.ndarray): indexed as ``[:, leaf.name]``; rows
        are compared against 1 to locate the members of each macrostate.
    output_file (str): file the figure is saved to.
    scale (float): multiplier for the figure size.
    offset (float): lower limit of the metastability axis.
    """
    n_states = len(root.leaves)

    # setup matplotlib
    pplt.use_style(figsize=3.2 * scale, figratio="golden", true_black=True)
    plt.rcParams["font.family"] = "sans-serif"

    fig, (ax, ax_mat) = plt.subplots(
        2,
        1,
        gridspec_kw={"hspace": 0.05, "height_ratios": [9, 1]},
    )
    for spine in ax_mat.spines.values():
        spine.set_visible(False)

    ax = root.plot_tree(ax)

    ax.set_ylabel(r"Metastability $T_{ii}$")
    ax.set_xlabel("microstates")
    ax.set_xlim(-0.005 * n_states, 1.005 * n_states)
    ax.set_ylim(offset, 1.05)

    # colorbar on top of the dendrogram
    feature_mappable = ScalarMappable(
        root.feature_norm, plt.get_cmap("plasma_r", 10)
    )
    plt.sca(ax)
    pplt.colorbar(
        feature_mappable,
        width="5%",
        label=r"Fraction of Contacts $q$",
        position="top",
    )

    # reorder the microstate columns to follow the leaves of the tree
    leaf_names = [leaf.name for leaf in root.leaves]
    macrostate_assignment = macrostate_assignment[:, leaf_names]

    yticks = np.arange(0.5, 1.5 + macrostate_assignment.shape[0])
    xticks = np.arange(0, n_states + 1)
    binary_cmap = LinearSegmentedColormap.from_list(
        "binary",
        [(0, 0, 0, 0), (0, 0, 0, 1)],
    )

    # number every macrostate at the median position of its microstates
    centers = 0.5 * (xticks[:-1] + xticks[1:])
    for idx, assignment in enumerate(macrostate_assignment):
        pplt.text(
            np.median(centers[assignment == 1]),
            yticks[idx] - (yticks[1] - yticks[0]),
            f"{idx + 1:.0f}",
            ax=ax_mat,
            va="top",
            contour=True,
            size="small",
        )

    # macrostate assignment matrix
    ax_mat.pcolormesh(
        xticks,
        yticks,
        macrostate_assignment,
        snap=True,
        cmap=binary_cmap,
        vmin=0,
        vmax=1,
    )
    ax_mat.set_yticks(yticks)
    ax_mat.set_yticklabels([])
    ax_mat.grid(visible=True, axis="y", ls="-", lw=0.5)
    ax_mat.tick_params(axis="y", length=0, width=0)
    ax_mat.set_xlim(ax.get_xlim())
    ax.set_xlabel("")
    ax_mat.set_xlabel("Macrostates")
    ax_mat.set_ylabel("")
    fig.align_ylabels([ax, ax_mat])

    ax_mat.set_xticks(np.arange(0.5, 0.5 + n_states))

    # hide the microstate ticks and labels on both panels
    for panel in (ax, ax_mat):
        for minor in (False, True):
            panel.set_xticks([], minor=minor)
        for minor in (False, True):
            panel.set_xticklabels([], minor=minor)

    pplt.savefig(output_file)
    plt.close()

129 

130 

131### SIMILARITY ############################################################### 

132 

133 

def stochastic_state_similarity(mpt1, mpt2, out):
    """
    Plot similarity values for a reference and a stochastic clustering.

    One histogram panel is drawn per macrostate, each overlaying the three
    similarity measures ("union", "reference", "clustering").

    mpt1, mpt2: objects whose sum ``mpt1 + mpt2`` unpacks into
        ``(reference, stochastic, S)``. ``S`` is unpacked into three
        per-state series and ``S.shape[1]`` is the number of macrostates.
    out (str): file the figure is saved to.
    """
    _, sto, S = mpt1 + mpt2
    s1, s2, s3 = S
    n_states = S.shape[1]
    x, y = utils.get_grid_format(n_states)
    # squeeze=False keeps ``axs`` a 2D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes without ``flatten``.
    fig, axs = plt.subplots(y, x, figsize=(2 * x, 2 * y), squeeze=False)
    for state, ax in enumerate(axs.flatten()[:n_states]):
        # Set left limit to the minimum instead of 0
        left = min(min(s1[state]), min(s2[state]), min(s3[state])) - 0.02
        bins = np.linspace(left, 1, 21)

        for values, color in ((s1[state], "g"), (s2[state], "b"), (s3[state], "r")):
            ax.hist(values, bins=bins, color=color, alpha=0.7)
        ax.set_title(f"state {state + 1}")
    fig.supxlabel("Macrostate similarity")
    fig.supylabel(f"Count of clusterings ({sto.n_runs} clusterings)")
    plt.figlegend(
        ["union", "reference", "clustering"],
        ncols=3,
        loc="lower center",
        bbox_to_anchor=(0.5, 0.05),
    )
    # leave room at the bottom for the figure legend
    plt.tight_layout(rect=(0, 0.04, 1, 1))
    plt.savefig(out)
    plt.close()

163 

164 

165### IMPLIED TIMESCALES ####################################################### 

166 

167 

def implied_timescales(
    trajectorys,
    lagtimes,
    out,
    titles="",
    frame_length=0.2,
    first_ref=False,
    scale=1,
    use_ref=True,
    ntimescales=3,
):
    """
    Plot implied timescales against lag time, one panel per trajectory.

    trajectorys (list): state trajectories. If ``first_ref`` is true, the
        first entry is treated as reference and drawn into every panel.
        The list passed by the caller is not modified.
    lagtimes (array-like): lag times in frames.
    out (str): file the figure is saved to.
    titles (list[str] or ""): one title per panel; "" means no titles.
    frame_length (float): ns / frame, used to convert frames to ns.
    first_ref (bool): treat ``trajectorys[0]`` as the reference.
    scale (float): multiplier for the figure size.
    use_ref (bool): selects line style and label of the reference curves
        (":" / "ref" if true, "--" / "mic" otherwise).
    ntimescales (int): number of implied timescales per trajectory.
    """
    # Work on a copy so that removing the reference does not alter the
    # caller's list.
    trajectorys = list(trajectorys)
    ref_trajectory = trajectorys.pop(0) if first_ref else None

    x, y = utils.get_grid_format(len(trajectorys))
    pplt.use_style(
        figsize=(3.8 * scale, 3.2 * scale), latex=False, colors="pastel_autumn"
    )
    fig, axs = plt.subplots(y, x, sharex=True, sharey=True)
    plt.grid(False)
    if not isinstance(axs, np.ndarray):
        axs = np.array([axs])

    if isinstance(titles, str) and titles == "":
        titles = [""] * len(trajectorys)

    min_it = None
    max_it = None
    it_ref = None

    if first_ref:
        # change from frames to ns
        it_ref = (
            mh.msm.implied_timescales(
                ref_trajectory, lagtimes, ntimescales=ntimescales
            )
            * frame_length
        )
        min_it = it_ref.min()
        max_it = it_ref.max()

    lagtimes_ns = np.asarray(lagtimes) * frame_length
    lagtime = lagtimes[-1] / 4.5 * frame_length
    ref_ls = ":" if use_ref else "--"
    for ax, traj, title in zip(axs.flatten(), trajectorys, titles):
        ax.axvline(lagtime, color="pplt:grid")
        # change from frames to ns
        it = (
            mh.msm.implied_timescales(traj, lagtimes, ntimescales=ntimescales)
            * frame_length
        )
        min_it = it.min() if min_it is None else min(it.min(), min_it)
        max_it = it.max() if max_it is None else max(it.max(), max_it)

        if first_ref:
            _plot_impl_times(it_ref, lagtimes_ns, ax, ls=ref_ls)
        _plot_impl_times(it, lagtimes_ns, ax)
        ax.set_yscale("log")
        ax.set_title(title)

    for ax in axs.flatten():
        ax.set_ylim(min(min_it * 0.9, int(lagtimes_ns.shape[0] / 4)), max_it * 1.5)

    if len(axs.shape) == 2:
        for ax in axs[-1]:
            ax.set_xlabel(r"lag time $\tau$ / ns")
        for row in axs:
            for ax in row[1:]:
                plt.setp(ax.get_yticklabels(), visible=False)
        for ax in axs[:, 0]:
            ax.set_ylabel("time scale / ns")
    elif len(axs.shape) == 1:
        axs[0].set_ylabel("time scale / ns")
        for ax in axs:
            ax.set_xlabel(r"lag time $\tau$ / ns")
        for ax in axs[1:]:
            plt.setp(ax.get_yticklabels(), visible=False)

    # Take the legend entries from the first panel: it is always populated,
    # whereas the current axes may be an unused cell of the grid.
    handles, labels = axs.flatten()[0].get_legend_handles_labels()

    # With a reference, the entries are [ref_1..ref_n, t_1..t_n]; interleave
    # them so that the legend is filled column-major. Without a reference
    # there is nothing to interleave.
    if first_ref and len(handles) >= 2 * ntimescales:
        desired_order = [
            k for i in range(ntimescales) for k in (i + ntimescales, i)
        ]
        handles = [handles[k] for k in desired_order]
        labels = [labels[k] for k in desired_order]

    pplt.legend(
        handles=handles, labels=labels, outside="top", frameon=False, ncols=ntimescales
    )

    plt.tight_layout()
    plt.savefig(out)
    plt.close()

270 

271 

272def _plot_impl_times(impl_times, lagtimes, ax, ls="-"): 

273 """Plot the implied timescales""" 

274 colors = ["#264653", "#2A9D8F", "#E9C46A", "#f4a261", "#e76f51"] * 4 

275 for idx, impl_time in enumerate(impl_times.T): 

276 if ls == ":": 276 ↛ 277line 276 didn't jump to line 277 because the condition on line 276 was never true

277 label = f"$t_{{\\mathrm{{ref}},{idx + 1}}}$" 

278 elif ls == "--": 

279 label = f"$t_{{\\mathrm{{mic}},{idx + 1}}}$" 

280 else: 

281 label = f"$t_{idx + 1}$" 

282 ax.plot(lagtimes, impl_time, label=label, color=colors[idx], ls=ls) 

283 

284 xlim = lagtimes[0], lagtimes[-1] 

285 ref_low = int(lagtimes.shape[0] / 4) 

286 ax.set_xlim(xlim) 

287 # highlight diagonal 

288 x_i = np.arange(ref_low, xlim[1]) 

289 ax.fill_between(x_i, x_i, color="pplt:grid") 

290 

291 

def relative_implied_timescales(cl, out):
    """
    Plot histograms of implied timescales relative to the reference.

    Three panels are drawn: the first relative timescale, the mean over all
    relative timescales, and the number of macrostates per clustering.

    cl: object providing ``reference``, ``timescales`` (2D, one row per
        clustering) and ``n_macrostates``; ``cl.timescales`` is divided by
        ``cl.reference.timescales``.
    out (str): file the figure is saved to.
    """
    pplt.use_style(figsize=(8, 2.5), latex=False, colors="pastel_autumn")

    its = cl.timescales / cl.reference.timescales
    xlabel = (
        r"Relative Implied Timescale "
        r"$\left(\frac{t_\mathrm{stoch}}{t_\mathrm{ref}}\right)$"
    )

    fig = plt.figure()
    ax1 = fig.add_subplot(1, 3, 1)
    ax2 = fig.add_subplot(1, 3, 2, sharey=ax1)
    ax3 = fig.add_subplot(1, 3, 3)

    for ax in (ax1, ax2, ax3):
        ax.grid(False)

    ax1.hist(its[:, 0], bins=20)
    ax1.set_title("its 1")
    ax1.set_xlabel(xlabel)
    ax1.set_ylabel("Count of Clusterings")

    ax2.hist(its.mean(axis=1), bins=20)
    # The mean runs over every column, so name the actual range in the title
    # instead of a hard-coded "1-3".
    ax2.set_title(f"Mean its 1-{its.shape[1]}")
    ax2.set_xlabel(xlabel)

    # bin edges centred on the integer macrostate counts
    bins = np.arange(min(cl.n_macrostates) - 1, max(cl.n_macrostates) + 1) + 0.5

    ax3.hist(cl.n_macrostates, bins=bins)
    ax3.set_title("n macrostates")
    ax3.set_xlabel("macrostate count")

    plt.tight_layout()
    plt.savefig(out)
    plt.close()

327 

328 

329### SIMILARITY MATRIX ######################################################## 

330 

331 

def plot_heatmap(a, out, title=""):
    """
    Save a log-scaled heatmap of the matrix ``a``.

    Intended for a similarity matrix as returned from the multiplication of
    two MPT objects.

    a (np.ndarray): 2D matrix to show; one tick is placed per row/column.
    out (str): file the figure is saved to.
    title (str): optional axes title; skipped when empty.
    """
    n_rows, n_cols = a.shape[0], a.shape[1]

    fig, ax = plt.subplots()
    ax.imshow(a, norm="log")
    ax.set_aspect("equal", "box")

    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    if title:
        ax.set_title(title)
    for set_label in (ax.set_xlabel, ax.set_ylabel):
        set_label("Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()

350 

351 

def transition_matrix(a, out, title="Transition Matrix", color_thr=0.01):
    """
    Plot a transition matrix as an annotated heatmap.

    Diagonal and off-diagonal entries use separate logarithmic color scales
    (reds and viridis), each with its own colorbar. Zero entries are white,
    and off-diagonal entries below ``color_thr`` times the largest
    off-diagonal value are light gray. Every non-zero cell is annotated
    with its value in percent.

    a (np.ndarray): square matrix of transition probabilities as fractions;
        it is multiplied by 100 for display.
    out (str): file the figure is saved to.
    title (str): axes title; skipped when empty.
    color_thr (float): relative threshold below which off-diagonal cells
        are drawn gray.
    """
    # Scale a to percent
    a = a * 100
    n_rows, n_cols = a.shape[0], a.shape[1]

    # Diagonal elements: logarithmic Reds, with the pale lower end cut off
    # so that the smallest values are still visibly red
    diagonal_values = np.diag(a)
    diag_norm = LogNorm(vmin=diagonal_values.min(), vmax=diagonal_values.max())
    diag_cmap_custom = ListedColormap(plt.cm.Reds(np.linspace(0.2, 1, 256)))

    # Off-diagonal elements: logarithmic viridis
    off_diag_values = a[~np.eye(n_rows, dtype=bool)]

    # Threshold for light gray
    threshold = color_thr * off_diag_values.max()
    print(f"Threshold for probabilities: {threshold:.3f} %")

    off_diag_norm = LogNorm(
        vmin=threshold * (1 - color_thr), vmax=off_diag_values.max()
    )

    # Off-diagonal colormap whose lowest entries are replaced by light gray
    colors_list = plt.cm.viridis(np.linspace(0, 1, 256))
    gray = np.array([0.9, 0.9, 0.9, 1.0])
    colors_list[: int(color_thr * 256)] = gray
    custom_off_diag_cmap = ListedColormap(colors_list)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect("equal", "box")
    ax.grid(False)

    luminance_weights = np.array([0.299, 0.587, 0.114])
    for i in range(n_rows):
        for j in range(n_cols):
            value = a[i, j]
            if value == 0:
                color = (1, 1, 1, 1)  # Zero probabilities are white
            elif i == j:
                color = diag_cmap_custom(diag_norm(value))
            elif value < threshold:
                color = gray
            else:
                color = custom_off_diag_cmap(off_diag_norm(value))

            ax.add_patch(patches.Rectangle((j - 0.5, i - 0.5), 1, 1, color=color))

            # Annotate non-zero cells; pick the text color by the perceived
            # brightness of the cell so it stays readable
            if value != 0:
                grayscale = np.sum(np.array(color[:3]) * luminance_weights)
                ax.text(
                    j,
                    i,
                    f"{value:.2f}%",
                    ha="center",
                    va="center",
                    color="white" if grayscale < 0.5 else "black",
                    fontsize=10,
                )

    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(np.arange(1, n_cols + 1))
    ax.set_yticklabels(np.arange(1, n_rows + 1))
    ax.set_xlim(-0.5, n_cols - 0.5)
    ax.set_ylim(-0.5, n_rows - 0.5)

    # Colorbar for diagonal values; it uses the same truncated colormap as
    # the cells so that colorbar and cells agree
    cbar_diag = fig.colorbar(
        ScalarMappable(norm=diag_norm, cmap=diag_cmap_custom), ax=ax, shrink=0.5
    )
    cbar_diag.set_label("Self Transition Probabilities / \\%")

    # Colorbar for off-diagonal values
    cbar_off_diag = fig.colorbar(
        ScalarMappable(norm=off_diag_norm, cmap=custom_off_diag_cmap),
        ax=ax,
        shrink=0.5,
    )
    cbar_off_diag.set_label("Transition Probabilities / \\%")

    if title:
        ax.set_title(title)

    ax.set_xlabel("From Macrostate")
    ax.set_ylabel("To Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()

452 

453 

def transition_time(
    a,
    out,
    lagtime=50.0,
    frame_length=0.2,
    title=r"Transition Times $\frac{t_\mathrm{lag}}{P}$",
    color_thr=0.01,
):
    """
    Plot transition times ``lagtime / a * frame_length`` as a heatmap.

    Diagonal and off-diagonal entries use separate logarithmic color scales
    (reversed reds and reversed viridis), each with its own colorbar.
    Entries with zero probability (infinite time) are white, and
    off-diagonal times above the smallest off-diagonal time divided by
    ``color_thr`` are light gray. Every finite cell is annotated with its
    time in ns.

    a (np.ndarray): square matrix of transition probabilities.
    out (str): file the figure is saved to.
    lagtime (float): lag time in frames.
    frame_length (float): frame length in ns.
    title (str): axes title; skipped when empty.
    color_thr (float): relative threshold controlling the gray cut-off.
    """
    # Zero probabilities deliberately become infinite times
    with np.errstate(divide="ignore"):
        a = lagtime / a * frame_length
    n_rows, n_cols = a.shape[0], a.shape[1]

    # Diagonal elements: logarithmic reversed Reds, with the pale upper end
    # cut off
    diagonal_values = np.diag(a)
    diag_norm = LogNorm(vmin=diagonal_values.min(), vmax=diagonal_values.max())
    diag_cmap_custom = ListedColormap(plt.cm.Reds_r(np.linspace(0, 0.8, 256)))

    # Off-diagonal elements: logarithmic reversed viridis
    off_diag_values = a[~np.eye(n_rows, dtype=bool)]

    # Threshold for light gray
    threshold = off_diag_values.min() / color_thr
    print(f"Threshold for probabilities: {threshold:.2f} ns")

    off_diag_norm = LogNorm(
        vmin=off_diag_values.min(), vmax=threshold / (1 - color_thr)
    )

    # Off-diagonal colormap whose highest entries are replaced by light gray
    colors_list = plt.cm.viridis_r(np.linspace(0, 1, 256))
    gray = np.array([0.9, 0.9, 0.9, 1.0])
    colors_list[int((1 - color_thr) * 256) :] = gray
    custom_off_diag_cmap = ListedColormap(colors_list)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect("equal", "box")
    ax.grid(False)

    luminance_weights = np.array([0.299, 0.587, 0.114])
    for i in range(n_rows):
        for j in range(n_cols):
            value = a[i, j]
            if value == np.inf:
                color = (1, 1, 1, 1)  # Zero probabilities are white
            elif i == j:
                color = diag_cmap_custom(diag_norm(value))
            elif value > threshold:
                color = gray
            else:
                color = custom_off_diag_cmap(off_diag_norm(value))

            ax.add_patch(patches.Rectangle((j - 0.5, i - 0.5), 1, 1, color=color))

            # Annotate finite cells; pick the text color by the perceived
            # brightness of the cell so it stays readable
            if value != np.inf:
                grayscale = np.sum(np.array(color[:3]) * luminance_weights)
                text_color = "white" if grayscale < 0.5 else "black"
                if value >= threshold:
                    # Compact one-digit scientific notation, e.g. "5e3".
                    # Splitting mantissa and exponent keeps the sign and all
                    # digits of the exponent.
                    mantissa, exponent = f"{value:.0e}".split("e")
                    text = f"{mantissa}e{int(exponent)}"
                elif value >= 100:
                    text = f"{value:.0f}"
                else:
                    text = f"{value:#.3g}"
                ax.text(
                    j, i, text, ha="center", va="center", color=text_color, fontsize=10
                )

    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(np.arange(1, n_cols + 1))
    ax.set_yticklabels(np.arange(1, n_rows + 1))
    ax.set_xlim(-0.5, n_cols - 0.5)
    ax.set_ylim(-0.5, n_rows - 0.5)

    # Colorbar for diagonal values; it uses the same truncated colormap as
    # the cells so that colorbar and cells agree
    cbar_diag = fig.colorbar(
        ScalarMappable(norm=diag_norm, cmap=diag_cmap_custom), ax=ax, shrink=0.5
    )
    cbar_diag.set_label("Self Transition Times / ns")

    # Colorbar for off-diagonal values
    cbar_off_diag = fig.colorbar(
        ScalarMappable(norm=off_diag_norm, cmap=custom_off_diag_cmap),
        ax=ax,
        shrink=0.5,
    )
    cbar_off_diag.set_label("Transition Times / ns")

    if title:
        ax.set_title(title)

    ax.set_xlabel("From Macrostate")
    ax.set_ylabel("To Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()

564 

565 

566### MACROSTATE FEATURES ###################################################### 

567 

568 

def macro_feature(micro_feature, out, ref=None, pop=None):
    """
    Plot the histogram of the macrostate feature distribution.

    micro_feature (np.ndarray, NxR): N microstates, R runs, holds feature
        values of the respective macrostate.
    out (str): file to save the plot.
    ref: optional clustering shown explicitly; it must provide ``n_i``,
        ``macrostate_assignment`` and ``macrostate_feature``, the latter two
        being indexed with ``n_i``.
    pop: optional weights passed on to ``np.histogram``.
    """
    n_runs = micro_feature.shape[1]
    edges = np.linspace(
        micro_feature.min() * 0.95,
        micro_feature.max() * 1.05,
        101,
    )
    counts, bins = np.histogram(
        micro_feature, bins=edges, weights=pop, density=True
    )
    norm_counts = counts / n_runs
    # lower y limit slightly below the smallest populated bin (log axis)
    y_min = norm_counts[norm_counts > 0].min() * 0.7

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.hist(bins[:-1], bins=bins, weights=norm_counts, label="Stochastic Clustering")
    if ref is not None:
        selected = ref.n_i
        add_ref(
            ref.macrostate_assignment[selected],
            ref.macrostate_feature[selected],
            ax,
        )
    ax.set_xlabel("Fraction of Contacts")
    ax.set_ylabel("Population")
    ax.set_title(f"Macrostate Features, {n_runs} clusterings")
    ax.set_yscale("log")
    _, y_top = ax.get_ylim()
    ax.set_ylim((y_min, y_top))
    plt.legend(loc="lower left")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()

610 

611 

def add_ref(
    macrostate_assignment,
    macrostate_feature,
    ax,
    color="r",
    label="Reference",
    weights=None,
):
    """
    Add a clustering to the histogram.

    Every macrostate is drawn as a vertical line at its feature value and
    numbered next to the top of the line. The line height is the weighted
    sum of the assignment row scaled by 1e-3, hence the " / 1000" suffix of
    the legend entry.

    macrostate_assignment (np.ndarray, MxN): macrostate assignment, M: number
        of macrostates, N: number of microstates.
    macrostate_feature (np.ndarray, M): mean feature for every macrostate.
    ax: axes to draw on.
    color: color of lines and numbers.
    label (str): legend label; attached to the first line only.
    weights (np.ndarray, optional): microstate weights; normalized to sum
        to one. Without weights every microstate counts as 1.
    """
    # The weights do not depend on the macrostate, so prepare them once.
    if weights is None:
        weights = np.array([1])
    else:
        weights = weights / weights.sum()
    weight_sum = weights.sum()

    for i, (ma, mf) in enumerate(zip(macrostate_assignment, macrostate_feature)):
        top = (ma * weights).sum() / weight_sum * 1e-3
        # only the first line gets a legend entry
        line_kwargs = {"label": label + " / 1000"} if i == 0 else {}
        ax.plot([mf, mf], [1e-9, top], c=color, **line_kwargs)
        pplt.text(
            mf + 0.015,
            top * 0.82,
            f"{i + 1:.0f}",
            c=color,
            ax=ax,
            contour=True,
            size="small",
        )

649 

650 

651### CONTACT REPRESENTATION ################################################### 

652 

653 

def contact_rep(contacts, cluster_file, state_trajectory, output, grid, scale=1):
    """
    Adapted from msmhelper.

    Contact representation of states.

    This script creates a contact representation of states, where the states
    are obtained by [MoSAIC](https://github.com/moldyn/MoSAIC) and the contact
    representation was introduced in Nagel et al.[^1].

    [^1]: Nagel et al., **Selecting Features for Markov Modeling: A Case Study
          on HP35.**, *J. Chem. Theory Comput.*, submitted,

    contacts (np.ndarray): one row per frame, one column per contact.
    cluster_file (str): file read with ``load_clusters``.
    state_trajectory (np.ndarray): state of every frame.
    output (str or None): file of the first figure; further figures get
        ``.state<first>-<last>`` inserted before the extension. With
        ``None`` the figures are shown instead of saved.
    grid (tuple): ``(nrows, ncols)`` of panels per figure.
    scale (float): multiplier for the figure size.
    """
    # setup matplotlib
    pplt.use_style(
        figsize=1.2 * scale,
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )

    # load files
    states = np.unique(state_trajectory)
    clusters = load_clusters(cluster_file)

    contact_idxs = np.hstack(clusters)
    n_idxs = len(contact_idxs)
    n_frames = len(contacts)

    # tick at the left edge of every contact cluster
    xtickpos = np.cumsum([0, *[len(clust) for clust in clusters[:-1]]]) - 0.5
    nrows, ncols = grid
    hspace, wspace = 0, 0
    # ignore outliers
    ylims = 0, np.quantile(contacts, 0.999)
    edges = np.arange(n_idxs + 1) - 0.5

    # colors are the same for every panel
    c1, c2, c3 = pplt.categorical_color(3, "C0")
    bands = (
        (c3, ("whislo", "whishi"), r"$Q_{1/3} \pm 1.5\mathrm{IQR}$"),
        (c2, ("q1", "q3"), r"$\mathrm{IQR} = Q_3 - Q_1$"),
    )

    chunks = mh.plot._ck_test._split_array(states, nrows * ncols)
    for counter, chunk in enumerate(chunks):
        fig, axs = plt.subplots(
            int(np.ceil(len(chunk) / ncols)),
            ncols,
            sharex=True,
            sharey=True,
            squeeze=False,
            gridspec_kw={"wspace": wspace, "hspace": hspace},
        )

        for state, ax in zip(chunk, axs.flatten()):
            contacts_state = contacts[state_trajectory == state]
            pop_state = len(contacts_state) / n_frames

            stats = {
                idx: boxplot_stats(contacts_state[:, idx])[0] for idx in contact_idxs
            }

            for color, (key_low, key_high), label in bands:
                ax.stairs(
                    [stats[idx][key_high] for idx in contact_idxs],
                    edges,
                    baseline=[stats[idx][key_low] for idx in contact_idxs],
                    color=color,
                    lw=0,
                    fill=True,
                    label=label,
                )

            ax.hlines(
                [stats[idx]["med"] for idx in contact_idxs],
                xmin=edges[:-1],
                xmax=edges[1:],
                label="median",
                color=c1,
            )

            pplt.text(
                0.5,
                0.95,
                rf"S{state + 1} {pop_state:.1%}",
                ha="center",
                va="top",
                ax=ax,
                transform=ax.transAxes,
                contour=True,
            )

            ax.set_xlim([-0.5, n_idxs - 0.5])
            ax.set_ylim(*ylims)
            ax.set_xticks(xtickpos)
            ax.set_xticklabels(np.arange(len(xtickpos)) + 1)

            ax.grid(False)
            for pos in xtickpos:
                ax.axvline(pos, color="pplt:grid", lw=1.0)

        pplt.hide_empty_axes()
        pplt.legend(
            ax=axs[0, 0],
            outside="top",
            bbox_to_anchor=(
                0,
                1.0,
                axs.shape[1] + wspace * (axs.shape[1] - 1),
                0.01,
            ),
            frameon=False,
            ncol=2,
        )
        pplt.subplot_labels(
            xlabel="contact clusters",
            ylabel="distances [nm]",
        )

        # Without an output file there is nothing to save: show the figure
        # and go on, instead of failing in splitext(None).
        if output is None:
            plt.show()
            plt.close()
            continue

        # the first figure keeps the plain name, later ones get the state
        # range inserted between path and extension
        if counter == 0:
            pplt.savefig(output)
        else:
            path, ext = splitext(output)
            pplt.savefig(f"{path}.state{chunk[0]:.0f}-{chunk[-1]:.0f}{ext}")
        plt.close()

795 

796 

797### SANKEY ################################################################### 

798 

799 

def sankey_diagram(cl, ref, out, ax=None, scale=1):
    """
    Draw a Sankey diagram comparing a clustering with a reference.

    The left side shows the macrostates of ``cl`` (lumping ``cl.n_i``), the
    right side those of ``ref`` (lumping 0); both sides are weighted with
    ``ref.pop``. Colors are taken from the macrostates of ``cl``, ordered by
    descending feature.

    cl, ref: clustering objects providing ``macrostate_map`` and
        ``n_macrostates``; ``cl`` additionally ``tree`` and ``n_i``, ``ref``
        additionally ``pop``.
    out (str): file the figure is saved to; only used when ``ax`` is None.
    ax: optional axes to draw into. If given, nothing is saved or closed.
    scale (float): multiplier for the figure size of a new figure.
    """
    own_figure = ax is None
    selected = cl.n_i
    macrostates = cl.tree[selected].macrostates

    # label "1" belongs to the macrostate with the largest feature
    by_feature = np.argsort([m.feature for m in macrostates])[::-1]
    colorDict = {
        str(rank + 1): macrostates[pos].color
        for rank, pos in enumerate(by_feature)
    }

    if own_figure:
        pplt.use_style(figsize=(1.7 * scale, 3.6 * scale), true_black=True)

    left_labels = np.arange(1, cl.n_macrostates[selected] + 1).astype(str).tolist()
    right_labels = np.arange(1, ref.n_macrostates[0] + 1).astype(str).tolist()
    sankey(
        left=(cl.macrostate_map[selected] + 1).astype(str),
        right=(ref.macrostate_map[0] + 1).astype(str),
        leftWeight=ref.pop,
        rightWeight=ref.pop,
        leftLabels=left_labels,
        rightLabels=right_labels,
        colorDict=colorDict,
        ax=ax,
    )

    if own_figure:
        pplt.savefig(out)
        plt.close()

823 

824 

825### RMSD LINES ############################################################### 

826 

827 

def rmsd(rmsds, pops, helices=None, filename=None):
    """
    Plot one log-scaled line profile per row of ``rmsds``.

    Each row gets three panels: the profile over the column index, a bar
    with the normalized population, and a bar with the sum of the row
    without its first and last two columns. If ``helices`` is given, an
    extra bottom row sketches the secondary structure.

    Parameters:
    - rmsds (np.ndarray): 2D array, one row per macrostate. Values must be
      positive for logarithmic scaling.
    - pops (np.ndarray): one population per row of ``rmsds``; normalized to
      sum to one.
    - helices (iterable of (start, end), optional): blocks indicated in the
      bottom row. A positive ``start`` gives a filled block; otherwise the
      signs of both values are flipped and an open block is drawn.
    - filename (str, optional): if provided, saves the figure to this file,
      otherwise the figure is shown.

    Raises:
    - ValueError: if ``rmsds`` contains non-positive values or its number of
      rows differs from ``len(pops)``.
    """
    # Ensure all values are positive for logarithmic scaling
    if np.any(rmsds <= 0):
        raise ValueError(
            "All values in `rmsds` must be positive for logarithmic scaling."
        )

    if rmsds.shape[0] != len(pops):
        raise ValueError("Length of `pops` must match the number of rows in `rmsds`.")

    has_helices = helices is not None
    n_rows, n_cols = rmsds.shape[0], rmsds.shape[1]
    n_plots = n_rows + 1 if has_helices else n_rows

    w = 0.08 * n_cols + 3
    h = 1 + 0.4 * n_plots
    pplt.use_style(
        figsize=(w, h),
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )
    # squeeze=False keeps ``axs`` two-dimensional even for a single row, so
    # that the ``axs[-1, k]`` lookups below always work.
    fig, axs = plt.subplots(
        n_plots,
        3,
        sharex="col",
        squeeze=False,
        width_ratios=[n_cols, 8, 8],
        gridspec_kw={"wspace": 0, "hspace": 0},
    )

    ylim = 0.5 * rmsds.min(), 2 * rmsds.max()
    pops = pops / pops.sum()
    ylim_hist = 0, 1.05 * pops.max()

    # the first and last two columns are left out of the sums
    rmsd_sums = rmsds[:, 2:-2].sum(axis=1)
    ylim_rmsd = 0, 1.05 * rmsd_sums.max()

    data_axs = axs[:-1] if has_helices else axs
    for i, ((ax, hist_ax, rmsd_ax), row, pop) in enumerate(
        zip(data_axs, rmsds, pops)
    ):
        # horizontal bars for the population and the summed profile
        for bar_ax, length, limits in (
            (hist_ax, pop, ylim_hist),
            (rmsd_ax, row[2:-2].sum(), ylim_rmsd),
        ):
            bar_ax.add_patch(patches.Rectangle((0, 0.3), length, 0.4))
            bar_ax.set_xlim(limits)
            bar_ax.set_yticks([], [])
            bar_ax.grid(False)

        positions = np.arange(row.shape[0]) + 1
        ax.plot(positions, row)
        ax.fill_between(
            positions,
            [ylim[0]] * row.shape[0],
            row,
            alpha=0.5,
        )

        ax.set_yscale("log")
        ax.set_ylabel(f"{i + 1}")
        ax.set_xlim((0.5, row.shape[0] + 0.5))
        ax.set_ylim(ylim)
        ax.grid(True)

    if has_helices:
        line_color = "#264653"
        line_start = 1
        helices_ax = axs[-1, 0]
        for start, end in helices:
            if start > 0:
                # Helices: filled block
                pad, face = 0.3, line_color
            else:
                # Sheets: encoded with flipped sign, drawn as open block
                start, end = -start, -end
                pad, face = 0.5, "white"
            start = start - pad
            end = end + pad
            rect = patches.Rectangle(
                (start, 0.3),  # Position of the block
                end - start,
                0.4,
                fc=face,
                ec=line_color,
                lw=2,
            )
            # connecting line from the previous block to this one
            helices_ax.plot(
                [line_start, start],
                [0.5, 0.5],
                solid_capstyle="butt",
                c=line_color,
                lw=2,
            )
            line_start = end
            helices_ax.add_patch(rect)
        # line from the last block to the end of the sequence
        helices_ax.plot(
            [line_start, n_cols],
            [0.5, 0.5],
            solid_capstyle="butt",
            c=line_color,
            lw=2,
        )

        helices_ax.set_ylim((0, 1))
        helices_ax.set_ylabel("H")
        helices_ax.set_yticks([], [])
        helices_ax.grid(False)

        for side_ax in (axs[-1, 1], axs[-1, 2]):
            side_ax.grid(False)
            side_ax.set_yticks([], [])

    # blank the first tick label of both bar columns
    for side_ax in (axs[-1, 1], axs[-1, 2]):
        ticks = side_ax.get_xticks()
        tick_labels = side_ax.get_xticklabels()
        tick_labels[0] = ""
        side_ax.set_xticks(ticks, tick_labels)

    axs[-1, 0].xaxis.set_major_locator(MultipleLocator(5))
    axs[-1, 0].xaxis.set_minor_locator(MultipleLocator(1))
    axs[-1, 0].set_xlabel("Residue")
    axs[-1, 1].set_xlabel("Population", rotation=20)
    axs[-1, 2].set_xlabel(r"$\sum$ RMSD / nm", rotation=20)
    fig.supylabel("Macrostate; RMSD Variance / nm")

    # Save to file if filename is provided
    plt.tight_layout()
    if filename:
        plt.savefig(filename, dpi=192)
    else:
        plt.show()
    plt.close()

995 

996 

def _draw_structure_track(track_ax, helices, n_residues):
    """
    Draw the secondary-structure track in the bottom-left panel.

    Parameters:
    - track_ax (matplotlib.axes.Axes): Axes to draw the track into.
    - helices (iterable of (start, end)): Block limits. Entries with
      `start > 0` are drawn as filled blocks ("Helices"); all other entries
      are negated first and drawn as white blocks ("Sheets").
    - n_residues (int): x position at which the connecting line ends.
    """
    colour = "#264653"
    line_start = 1
    for start, end in helices:
        if start > 0:
            # Helices
            pad, face = 0.3, colour
        else:
            # Sheets are encoded with negated limits
            start, end = -start, -end
            pad, face = 0.5, "white"
        start = start - pad
        end = end + pad

        # connecting line from the previous block up to this one
        track_ax.plot(
            [line_start, start],
            [0.5, 0.5],
            solid_capstyle="butt",
            c=colour,
            lw=2,
        )
        line_start = end
        track_ax.add_patch(
            patches.Rectangle(
                (start, 0.3),  # Position of the block
                end - start,
                0.4,
                fc=face,
                ec=colour,
                lw=2,
            )
        )

    # line from the last block to the end of the sequence
    track_ax.plot(
        [line_start, n_residues],
        [0.5, 0.5],
        solid_capstyle="butt",
        c=colour,
        lw=2,
    )

    track_ax.set_ylim((0, 1))
    track_ax.set_ylabel("H")
    track_ax.set_yticks([], [])
    track_ax.grid(False)


def delta_rmsd(rmsds, pops, helices=None, filename=None):
    """
    Plot per-state RMSD profiles relative to the first state.

    One row of three panels is drawn per row of `rmsds`: the profile over the
    second axis of `rmsds` (x axis labelled "Residue"), a bar with the
    normalized population, and a bar with the summed absolute profile. The
    first row shows the values of the first state itself on a logarithmic
    y axis; every further row shows the difference to the first row.

    The input arrays are not modified.

    Parameters:
    - rmsds (np.ndarray): 2D array with one row per state. All values must be
      positive because the first row is plotted on a logarithmic scale.
    - pops (np.ndarray): 1D array with one population per row of `rmsds`;
      normalized to sum to one before plotting.
    - helices (np.ndarray, optional): Array with start and end points for
      blocks to be indicated in an additional bottom row. Entries with a
      non-positive start are negated and drawn as white blocks.
    - filename (str, optional): If provided, saves the figure to this file,
      otherwise the figure is shown.

    Raises:
    - ValueError: If `rmsds` contains non-positive values or if the length of
      `pops` does not match the number of rows in `rmsds`.
    """
    rmsds = np.asarray(rmsds)

    # Ensure all values are positive for logarithmic scaling
    if np.any(rmsds <= 0):
        raise ValueError(
            "All values in `rmsds` must be positive for logarithmic scaling."
        )

    if rmsds.shape[0] != len(pops):
        raise ValueError("Length of `pops` must match the number of rows in `rmsds`.")

    n_states = rmsds.shape[0]
    n_residues = rmsds.shape[1]
    has_track = helices is not None
    n_plots = n_states + 1 if has_track else n_states

    w = 0.08 * n_residues + 3  # 8.6
    h = 1 + 0.4 * n_plots  # 6
    pplt.use_style(
        figsize=(w, h),
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )
    fig, axs = plt.subplots(
        n_plots,
        3,
        sharex="col",
        width_ratios=[n_residues, 8, 8],
        gridspec_kw={"wspace": 0, "hspace": 0},
        # keep `axs` two-dimensional even for a single row of panels
        squeeze=False,
    )

    # Work on a copy so the caller's array is left untouched: row 0 keeps the
    # values of the first state, all other rows become differences to it.
    deltas = np.array(rmsds, dtype=float)
    deltas[1:] -= deltas[0]

    delta_max = deltas.max()
    delta_min = deltas.min()
    delta_span = delta_max - delta_min
    ylim = delta_min - delta_span * 0.1, delta_max + delta_span * 0.15

    pops = np.asarray(pops, dtype=float)
    pops = pops / pops.sum()
    xlim_hist = 0, 1.05 * pops.max()

    # The bars show sums of absolute values, so the axis limit has to be
    # derived from the same quantity (signed sums can be much smaller).
    abs_sums = np.abs(deltas).sum(axis=1)
    xlim_sums = 0, 1.05 * abs_sums.max()

    residues = np.arange(n_residues) + 1
    state_rows = axs[:-1] if has_track else axs
    for i, ((ax, hist_ax, rmsd_ax), delta, pop) in enumerate(
        zip(state_rows, deltas, pops)
    ):
        # population bar
        hist_ax.add_patch(patches.Rectangle((0, 0.3), pop, 0.4))
        hist_ax.set_xlim(xlim_hist)
        hist_ax.set_yticks([], [])
        hist_ax.grid(False)

        # summed absolute profile bar
        rmsd_ax.add_patch(patches.Rectangle((0, 0.3), abs_sums[i], 0.4))
        rmsd_ax.set_xlim(xlim_sums)
        rmsd_ax.set_yticks([], [])
        rmsd_ax.grid(False)

        # profile, filled towards zero
        ax.plot(residues, delta)
        ax.fill_between(residues, 0, delta, alpha=0.5)

        ax.set_ylabel(f"{i + 1}")
        ax.set_xlim((0.5, n_residues + 0.5))
        ax.set_ylim(ylim)
        ax.grid(True)

    if has_track:
        _draw_structure_track(axs[-1, 0], helices, n_residues)

        axs[-1, 1].grid(False)
        axs[-1, 1].set_yticks([], [])
        axs[-1, 2].grid(False)
        axs[-1, 2].set_yticks([], [])

    # The first row holds absolute (positive) values: show it on a log scale.
    axs[0, 0].set_ylim(0.5 * deltas[0].min(), 2 * deltas[0].max())
    axs[0, 0].set_yscale("log")

    # Blank the first tick label of the two bar columns.
    hist_ticks = axs[-1, 1].get_xticks()
    hist_labels = axs[-1, 1].get_xticklabels()
    hist_labels[0] = ""
    axs[-1, 1].set_xticks(hist_ticks, hist_labels)

    rmsd_ticks = axs[-1, 2].get_xticks()
    rmsd_labels = axs[-1, 2].get_xticklabels()
    rmsd_labels[0] = ""
    axs[-1, 2].set_xticks(rmsd_ticks, rmsd_labels)

    axs[-1, 0].xaxis.set_major_locator(MultipleLocator(5))
    axs[-1, 0].xaxis.set_minor_locator(MultipleLocator(1))
    axs[-1, 0].set_xlabel("Residue")
    axs[-1, 1].set_xlabel("Population", rotation=20)
    axs[-1, 2].set_xlabel(r"$\sum$ |$\Delta$RMSD| / nm", rotation=20)
    fig.supylabel(r"$\Delta$Macrostate; $\Delta$RMSD Variance / nm wrt. 1st State")

    # Save to file if filename is provided
    plt.tight_layout()
    if filename:
        plt.savefig(filename, dpi=192)  # , bbox_inches="tight"
    else:
        plt.show()
    plt.close()

1180 

1181 

1182### TRAJECTORY ############################################################### 

1183 

1184 

def state_trajectory(trajectory, filename, row_length=0.2, frame_length=0.2):
    """
    Plot state trajectory

    The trajectory is drawn as horizontal line segments (one per visit of a
    state) and wrapped onto several stacked rows that share the time axis.

    trajectory (np.ndarray): state trajectory
    filename (str): file name to save the plot to
    row_length (int|float):
        row_length > 1: number of frames in each row
        0 < row_length <= 1: fraction of total frames per row (1/n_rows)
    frame_length (float): frame length in ns

    Raises ValueError if row_length is not positive.
    """
    if row_length > 1:
        frames_per_row = int(row_length)
    elif row_length > 0:
        frames_per_row = int(np.ceil(trajectory.shape[0] * row_length))
    else:
        raise ValueError("row_length must be > 0")

    # ns -> us, the time axis is labelled in microseconds
    frame_length_us = frame_length / 1000.0

    # Calculate unique states and their lengths
    unique_states, lengths = utils.find_state_lengths(trajectory)
    unique_states += 1  # states are displayed one-based
    lengths = lengths * frame_length_us
    n_rows = int(np.ceil(trajectory.shape[0] / frames_per_row))

    # duration covered by a single row
    x_max = frames_per_row * frame_length_us

    figsize = (11.7, 8.3)

    u_states = np.unique(unique_states)
    cmap = mpl.cm.turbo
    # one colour bin per state, bin edges halfway between state indices
    norm = mpl.colors.BoundaryNorm(np.concatenate([[0], u_states]) + 0.5, cmap.N)

    fig, axs = plt.subplots(
        n_rows,
        1,
        sharex=True,
        figsize=figsize,
        gridspec_kw={"wspace": 0, "hspace": 0},
        squeeze=False,
        layout="compressed",
    )
    axs = axs[:, 0]
    axi = 0  # index of the row currently drawn into

    # Plot each state occurrence as a line segment
    x_start = 0  # Initial x-coordinate for the first segment
    for state, length in zip(unique_states, lengths):
        x_end = x_start + length  # End position of this segment on the x-axis
        color = cmap(norm(state))

        # A segment reaching beyond the row is cut at the row end and
        # continued at the start of the next row.
        while x_end > x_max:
            axs[axi].plot(
                [x_start, x_max],
                [state, state],
                color=color,
                linewidth=3,
                solid_capstyle="butt",
            )
            x_end -= x_max
            x_start = 0
            axi += 1
        # Skip (numerically) empty remainders; this also avoids touching a
        # row past the last one when rounding pushes x_end just over x_max.
        if not np.isclose(x_start, x_end):
            axs[axi].plot(
                [x_start, x_end],
                [state, state],
                color=color,
                linewidth=3,
                solid_capstyle="butt",
            )

        # Move x_start to the end of the current segment for the next one
        x_start = x_end

    # Colorbar with one tick per state
    cbar = fig.colorbar(
        ScalarMappable(norm=norm, cmap=cmap),
        ax=axs,
        orientation="vertical",
        label="Macrostate",
    )
    cbar.set_ticks(ticks=u_states, labels=u_states, minor=False)
    fig.axes[-1].tick_params(length=0)

    fig.supylabel("State Index")
    axs[-1].set_xlabel(r"t / $\mu$s")

    y_low, y_high = unique_states.min() - 1, unique_states.max()
    for ax in axs:
        ax.set_ylim(y_low, y_high)

    # Set axis limits
    plt.xlim(0, x_max)

    # Save the plot to the specified file
    plt.savefig(filename)
    plt.close()  # Close the plot to free memory

1287 

1288 

1289### CHAPMAN-KOLMOGOROV TEST ################################################## 

1290 

1291 

def chapman_kolmogorov(mpt, out, frame_length=0.2):
    """
    Run the Chapman-Kolmogorov test and save the resulting plot.

    mpt: object providing `macrostate_trajectory`, `n_macrostates`, `n_i`
        and `limits`
    out (str): file name the figure is saved to
    frame_length (float): frame length in ns
    """
    multi_trajectory = utils.get_multi_state_trajectory(
        mpt.macrostate_trajectory[mpt.n_i], mpt.limits
    )
    lagtimes = [50] * 5
    # third argument is presumably the maximal time in frames -- TODO confirm
    ck = mh.msm.tests.chapman_kolmogorov_test(multi_trajectory, lagtimes, 4000)

    pplt.use_style(
        figsize=4.8,
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )

    n_states = mpt.n_macrostates[mpt.n_i]
    nrows, ncols = utils.get_grid_format(n_states)
    panels_per_figure = nrows * ncols
    state_chunks = mh.plot._ck_test._split_array(
        np.arange(n_states), panels_per_figure
    )

    # NOTE(review): every chunk is saved to the same `out`; if more than one
    # chunk is ever produced only the last figure survives -- verify.
    for chunk in state_chunks:
        fig = plot_ck_test(
            ck=ck,
            states=chunk,
            frames_per_unit=1 / frame_length,
            unit="ns",
            grid=(ncols, nrows),
        )

        # move the state labels of all panels to a common position
        for panel in fig.axes:
            for label in panel.texts:
                label.set_position((0.15, 0.2))
        plt.savefig(out)
        plt.close()

1326 

1327 

def plot_ck_test(
    ck,
    states=None,
    frames_per_unit=1,
    unit="frames",
    grid=(3, 3),
):
    """Plot CK-Test results.

    This routine is a basic helper function to visualize the results of
    [msmhelper.msm.chapman_kolmogorov_test][].

    Parameters
    ----------
    ck : dict
        Dictionary holding for each lagtime the CK equation and with 'md' the
        reference.
    states : ndarray, optional
        List containing all states to plot the CK-test. Defaults to all
        states found in the 'md' reference.
    frames_per_unit : float, optional
        Number of frames per given unit. This is used to scale the axis
        accordingly.
    unit : ['frames', 'fs', 'ps', 'ns', 'us'], optional
        Unit to use for label.
    grid : (int, int), optional
        The number of `(n_rows, n_cols)` to use for the grid layout. Only as
        many rows as needed for `states` are created.

    Returns
    -------
    fig : matplotlib.Figure
        Figure holding plots.

    Notes
    -----
    Adapted from msmhelper.

    """
    # load colors
    pplt.load_cmaps()
    pplt.load_colors()

    lagtimes = np.array([key for key in ck.keys() if key != "md"])
    if states is None:
        states = np.array(list(ck["md"]["ck"].keys()))

    nrows, ncols = grid
    needed_rows = int(np.ceil(len(states) / ncols))

    # `squeeze=False` always yields a (needed_rows, ncols) array. Relying on
    # `np.atleast_2d` instead would turn the 1D result of a single-column
    # grid into shape (1, needed_rows) and break the `axs[irow, icol]` lookup.
    fig, axs = plt.subplots(
        needed_rows,
        ncols,
        sharex=True,
        sharey="row",
        gridspec_kw={"wspace": 0, "hspace": 0},
        squeeze=False,
    )

    max_time = np.max(ck["md"]["time"])
    xlim = [lagtimes[0] / frames_per_unit, max_time / frames_per_unit]
    last_row = needed_rows - 1

    for irow, states_row in enumerate(mh.plot._ck_test._split_array(states, ncols)):
        for icol, state in enumerate(states_row):
            ax = axs[irow, icol]

            # MD reference
            pplt.plot(
                ck["md"]["time"] / frames_per_unit,
                ck["md"]["ck"][state],
                "--",
                ax=ax,
                color="pplt:gray",
                label="MD",
            )
            # one curve per lag time
            for lagtime in lagtimes:
                pplt.plot(
                    ck[lagtime]["time"] / frames_per_unit,
                    ck[lagtime]["ck"][state],
                    ax=ax,
                    label=lagtime / frames_per_unit,
                )
            pplt.text(
                0.5,
                0.9,
                "S{0}".format(state + 1),
                contour=True,
                va="top",
                transform=ax.transAxes,
                ax=ax,
            )

            # set scale
            ax.set_xscale("log")
            ax.set_xlim(xlim)
            ax.set_ylim([0, 1])
            # rows are stacked without gap: omit the 0 tick on all but the
            # bottom row so it does not collide with the 1 of the row below
            if irow < last_row:
                ax.set_yticks([0.5, 1])
            else:
                ax.set_yticks([0, 0.5, 1])

            ax.grid(True, which="major", linestyle="--")
            ax.grid(True, which="minor", linestyle="dotted")
            ax.set_axisbelow(True)

    # set legend
    if ncols in {1, 2}:
        legend_kw = {
            "outside": "right",
            "bbox_to_anchor": (2.0, (1 - nrows), 0.2, nrows),
        }
    else:
        legend_kw = {
            "outside": "top",
            "bbox_to_anchor": (0.0, 1.0, ncols, 0.01),
        }
    if ncols == 3:
        legend_kw["ncol"] = 3
    pplt.legend(
        ax=axs[0, 0],
        **legend_kw,
        title=rf"$\tau_\mathrm{{lag}}$ [{unit}]",
        frameon=False,
    )

    # the long label only fits next to tall grids
    if nrows >= 3:
        ylabel = r"self-transition probability $P_{i\to i}$"
    else:
        ylabel = r"$P_{i\to i}$"

    pplt.hide_empty_axes()
    pplt.label_outer()
    pplt.subplot_labels(
        ylabel=ylabel,
        xlabel=r"time $t$ [{unit}]".format(unit=unit),
    )
    return fig

1467 

1468 

1469### MACROSTATE GRAPH ######################################################### 

1470 

1471 

def state_network(lumping, out):
    """
    Draw the macrostate network of a lumping via `draw_knetwork`.

    lumping: object providing `macrostate_trajectory`, `n_i`, `lagtime` and
        `mean_feature_trajectory`
    out: forwarded unchanged as last argument to `draw_knetwork`
    """
    # macrostate trajectory selected by the lumping's `n_i` index
    selected_trajectory = lumping.macrostate_trajectory[lumping.n_i]
    lagtime = lumping.lagtime
    mean_features = lumping.mean_feature_trajectory
    draw_knetwork(selected_trajectory, lagtime, mean_features, out)

1478 )