Coverage for MPT/plot.py: 90%

643 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-11 11:01 +0200

1#!/usr/bin/env python3 

2""" 

3plot.py 

4================== 

5 

6Various plot functions used in this package. 

7""" 

8 

9from os.path import splitext 

10 

11import numpy as np 

12import prettypyplot as pplt 

13import matplotlib as mpl 

14from matplotlib import pyplot as plt 

15from matplotlib.cm import ScalarMappable 

16from matplotlib.colors import ( 

17 LinearSegmentedColormap, 

18 LogNorm, 

19 ListedColormap, 

20) 

21from matplotlib import colors 

22from matplotlib.cbook import boxplot_stats 

23import matplotlib.animation as animation 

24import matplotlib.patches as patches 

25from matplotlib.ticker import MultipleLocator 

26import msmhelper as mh 

27from msmhelper._cli.contact_rep import load_clusters 

28from scipy.stats import pearsonr 

29import MPT.utils as utils 

30from MPT.sankey_gap import sankey 

31from MPT.graph import draw_knetwork 

32 

# Set the global matplotlib font family to sans-serif when this module is imported.
plt.rcParams["font.family"] = "sans-serif"

34 

35### DENDROGRAM ############################################################### 

36 

37 

def plot_tree(root, macrostate_assignment, output_file, scale=1, offset=0):
    """
    Draw the dendrogram of a state tree above a macrostate assignment strip.

    The upper panel is drawn by ``root.plot_tree`` and gets a colorbar built
    from ``root.feature_norm``. The lower panel shows, per row of
    ``macrostate_assignment``, which leaves belong to that macrostate and
    labels each row with its 1-based index.

    root: tree root providing ``leaves``, ``plot_tree(ax)`` and
        ``feature_norm`` (presumably a BinaryTreeNode; verify against caller).
    macrostate_assignment (np.ndarray): 2D array, one row per macrostate,
        columns indexed by ``leaf.name``; entries equal to 1 mark membership.
    output_file: target passed to ``pplt.savefig``.
    scale: multiplier for the base figure size of 3.2.
    offset: lower y-limit of the dendrogram panel.
    """
    n_states = len(root.leaves)

    # figure style
    pplt.use_style(figsize=3.2 * scale, figratio="golden", true_black=True)
    plt.rcParams["font.family"] = "sans-serif"

    fig, (ax, ax_mat) = plt.subplots(
        2, 1, gridspec_kw={"hspace": 0.05, "height_ratios": [9, 1]}
    )
    for spine in ax_mat.spines.values():
        spine.set_visible(False)

    # dendrogram panel
    ax = root.plot_tree(ax)
    ax.set_ylabel(r"Metastability $T_{ii}$")
    ax.set_xlabel("microstates")
    ax.set_xlim(-0.005 * n_states, 1.005 * n_states)
    ax.set_ylim(offset, 1.05)

    # colorbar for the leaf feature
    feature_cmap = plt.get_cmap("plasma_r", 10)
    feature_mappable = ScalarMappable(root.feature_norm, feature_cmap)
    plt.sca(ax)
    pplt.colorbar(
        feature_mappable,
        width="5%",
        label=r"Fraction of Contacts $q$",
        position="top",
    )

    # reorder the columns so they follow the leaf order of the dendrogram
    leaf_order = [leaf.name for leaf in root.leaves]
    macrostate_assignment = macrostate_assignment[:, leaf_order]

    n_macrostates = macrostate_assignment.shape[0]
    yticks = np.arange(0.5, 1.5 + n_macrostates)
    xticks = np.arange(0, n_states + 1)
    # transparent -> opaque black, so only assigned cells are visible
    membership_cmap = LinearSegmentedColormap.from_list(
        "binary",
        [(0, 0, 0, 0), (0, 0, 0, 1)],
    )

    # label every macrostate row at the median x-position of its members
    cell_centers = 0.5 * (xticks[:-1] + xticks[1:])
    for row, members in enumerate(macrostate_assignment):
        pplt.text(
            np.median(cell_centers[members == 1]),
            yticks[row] - (yticks[1] - yticks[0]),
            f"{row + 1:.0f}",
            ax=ax_mat,
            va="top",
            contour=True,
            size="small",
        )

    # assignment strip
    ax_mat.pcolormesh(
        xticks,
        yticks,
        macrostate_assignment,
        snap=True,
        cmap=membership_cmap,
        vmin=0,
        vmax=1,
    )
    ax_mat.set_yticks(yticks)
    ax_mat.set_yticklabels([])
    ax_mat.grid(visible=True, axis="y", ls="-", lw=0.5)
    ax_mat.tick_params(axis="y", length=0, width=0)
    ax_mat.set_xlim(ax.get_xlim())
    ax.set_xlabel("")
    ax_mat.set_xlabel("Macrostates")
    ax_mat.set_ylabel("")
    fig.align_ylabels([ax, ax_mat])

    ax_mat.set_xticks(np.arange(0.5, 0.5 + n_states))

    # hide all microstate ticks and labels on both panels
    for panel in (ax, ax_mat):
        for minor in (False, True):
            panel.set_xticks([], minor=minor)
        for minor in (False, True):
            panel.set_xticklabels([], minor=minor)

    pplt.savefig(output_file)
    plt.close()

130 

131 

132### SIMILARITY ############################################################### 

133 

134 

def evaluate_stochastic_clustering(mpt1, mpt2, out):
    """
    Plot similarity histograms for a reference and a stochastic clustering.

    ``mpt1 + mpt2`` must yield a triple ``(ref, sto, S)``. ``S`` is unpacked
    into three arrays that are indexed by state; for every state the three
    value sets are drawn as overlaid histograms (green, blue, red) that the
    figure legend names "union", "reference" and "clustering".

    mpt1, mpt2: objects whose sum provides the triple described above
        (presumably MPT instances; verify against caller).
    out: target passed to ``plt.savefig``.
    """
    _, sto, S = mpt1 + mpt2
    s1, s2, s3 = S
    n_states = S.shape[1]
    x, y = utils.get_grid_format(n_states)
    # squeeze=False keeps ``axs`` a 2D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes without ``flatten``.
    fig, axs = plt.subplots(y, x, figsize=(2 * x, 2 * y), squeeze=False)
    for state, ax in enumerate(axs.flatten()[:n_states]):
        # start the bins slightly below the smallest value instead of at 0
        lower = min(min(s1[state]), min(s2[state]), min(s3[state])) - 0.02
        bins = np.linspace(lower, 1, 21)
        for values, color in ((s1[state], "g"), (s2[state], "b"), (s3[state], "r")):
            ax.hist(values, bins=bins, color=color, alpha=0.7)
        ax.set_title(f"state {state + 1}")
    fig.supxlabel("Macrostate similarity")
    fig.supylabel(f"Count of clusterings ({sto.n_runs} clusterings)")
    plt.figlegend(
        ["union", "reference", "clustering"],
        ncols=3,
        loc="lower center",
        bbox_to_anchor=(0.5, 0.05),
    )
    # leave room at the bottom for the figure legend
    plt.tight_layout(rect=(0, 0.04, 1, 1))
    plt.savefig(out)
    plt.close()

164 

165 

166### IMPLIED TIMESCALES ####################################################### 

167 

168 

def plot_implied_timescales(
    trajs,
    lagtimes,
    out,
    titles="",
    frame_length=0.2,
    first_ref=False,
    scale=1,
    use_ref=True,
    ntimescales=3,
):
    """
    Plot implied timescales against lag time, one panel per trajectory.

    trajs: sequence of state trajectories handed to
        ``mh.msm.implied_timescales``. With ``first_ref`` the first entry is
        taken as reference and drawn into every panel. The sequence itself is
        not modified.
    lagtimes: lag times in frames.
    out: target passed to ``plt.savefig``.
    titles: one title per panel; the empty string means "no titles".
    frame_length: conversion factor in ns / frame.
    first_ref: treat ``trajs[0]`` as the reference trajectory.
    scale: multiplier for the base figure size.
    use_ref: line style of the reference curves, dotted (labelled
        ``t_ref``) if true, dashed (labelled ``t_mic``) otherwise.
    ntimescales: number of timescales requested and number of legend columns.
    """
    if first_ref:
        # split off the reference without mutating the caller's sequence
        ref_traj, trajs = trajs[0], trajs[1:]
    x, y = utils.get_grid_format(len(trajs))
    pplt.use_style(
        figsize=(3.8 * scale, 3.2 * scale), latex=False, colors="pastel_autumn"
    )
    fig, axs = plt.subplots(y, x, sharex=True, sharey=True)
    plt.grid(False)
    if not isinstance(axs, np.ndarray):
        # a 1x1 grid yields a bare Axes; wrap it so the code below can iterate
        axs = np.array([axs])

    if isinstance(titles, str) and titles == "":
        titles = [""] * len(trajs)

    min_it = None
    max_it = None

    if first_ref:
        it_ref = mh.msm.implied_timescales(ref_traj, lagtimes, ntimescales=ntimescales)
        # change from frames to ns
        it_ref *= frame_length
        min_it = it_ref.min()
        max_it = it_ref.max()

    ref_ls = ":" if use_ref else "--"
    tlag = lagtimes[-1] / 4.5 * frame_length
    # asarray lets plain lists of lag times be scaled as well
    lagtimes_ns = np.asarray(lagtimes) * frame_length
    for ax, traj, title in zip(axs.flatten(), trajs, titles):
        ax.axvline(tlag, color="pplt:grid")
        it = mh.msm.implied_timescales(traj, lagtimes, ntimescales=ntimescales)
        # change from frames to ns
        it *= frame_length
        min_it = it.min() if min_it is None else min(it.min(), min_it)
        max_it = it.max() if max_it is None else max(it.max(), max_it)

        if first_ref:
            _plot_impl_times(it_ref, lagtimes_ns, ax, ls=ref_ls)
        _plot_impl_times(it, lagtimes_ns, ax)
        ax.set_yscale("log")
        ax.set_title(title)

    for ax in axs.flatten():
        ax.set_ylim(min(min_it * 0.9, int(lagtimes_ns.shape[0] / 4)), max_it * 1.5)

    if len(axs.shape) == 2:
        for ax in axs[-1]:
            ax.set_xlabel(r"lag time $\tau$ / ns")
        for axx in axs:
            for ax in axx[1:]:
                plt.setp(ax.get_yticklabels(), visible=False)
        for ax in axs[:, 0]:
            ax.set_ylabel("time scale / ns")
    elif len(axs.shape) == 1:
        axs[0].set_ylabel("time scale / ns")
        for ax in axs:
            ax.set_xlabel(r"lag time $\tau$ / ns")
        for ax in axs[1:]:
            plt.setp(ax.get_yticklabels(), visible=False)

    # Take the legend entries from the first panel: it is always populated,
    # whereas the current axes (last panel) is empty if the grid has more
    # cells than trajectories.
    handles, labels = axs.flatten()[0].get_legend_handles_labels()

    if first_ref:
        # Handles arrive as [reference..., trajectory...]. Interleave them so
        # that every legend column holds one trajectory/reference pair. Without
        # a reference there is nothing to interleave.
        half = len(handles) // 2
        order = [k for i in range(half) for k in (i + half, i)]
        handles = [handles[k] for k in order]
        labels = [labels[k] for k in order]

    pplt.legend(
        handles=handles, labels=labels, outside="top", frameon=False, ncols=ntimescales
    )

    plt.tight_layout()
    plt.savefig(out)
    plt.close()

269 

270 

def _plot_impl_times(impl_times, lagtimes, ax, ls="-"):
    """
    Plot implied timescales against lag time on ``ax``.

    impl_times (np.ndarray): 2D array; each column is drawn as one curve.
    lagtimes (np.ndarray): x-values shared by all curves.
    ax: axes to draw on.
    ls: line style. ":" labels the curves ``t_ref,i``, "--" labels them
        ``t_mic,i`` and anything else ``t_i``.
    """
    palette = ["#264653", "#2A9D8F", "#E9C46A", "#f4a261", "#e76f51"]
    for idx, impl_time in enumerate(impl_times.T):
        if ls == ":":
            label = f"$t_{{\\mathrm{{ref}},{idx + 1}}}$"
        elif ls == "--":
            label = f"$t_{{\\mathrm{{mic}},{idx + 1}}}$"
        else:
            # braces keep multi-digit indices entirely inside the subscript
            label = f"$t_{{{idx + 1}}}$"
        # cycle through the palette so any number of curves can be drawn
        color = palette[idx % len(palette)]
        ax.plot(lagtimes, impl_time, label=label, color=color, ls=ls)

    xlim = lagtimes[0], lagtimes[-1]
    ref_low = int(lagtimes.shape[0] / 4)
    ax.set_xlim(xlim)
    # shade the region below the diagonal (timescale < lag time)
    x_i = np.arange(ref_low, xlim[1])
    ax.fill_between(x_i, x_i, color="pplt:grid")

289 

290 

def relative_implied_timescales(cl, out):
    """
    Plot histograms of implied timescales relative to a reference.

    Three panels are drawn: the ratio ``cl.timescales / cl.reference.timescales``
    for the first timescale, the mean of that ratio over all timescales, and
    the distribution of ``cl.n_macrostates``.

    cl: object providing ``reference``, ``timescales`` (2D, timescales along
        axis 1) and ``n_macrostates``.
    out: target passed to ``plt.savefig``.
    """
    pplt.use_style(figsize=(8, 2.5), latex=False, colors="pastel_autumn")

    ref = cl.reference
    its = cl.timescales / ref.timescales

    fig = plt.figure()
    ax1 = fig.add_subplot(1, 3, 1)
    ax2 = fig.add_subplot(1, 3, 2, sharey=ax1)
    ax3 = fig.add_subplot(1, 3, 3)

    for ax in (ax1, ax2, ax3):
        ax.grid(False)

    ratio_label = (
        r"Relative Implied Timescale "
        r"$\left(\frac{t_\mathrm{stoch}}{t_\mathrm{ref}}\right)$"
    )

    ax1.hist(its[:, 0], bins=20)
    ax1.set_title("its 1")
    ax1.set_xlabel(ratio_label)
    ax1.set_ylabel("Count of Clusterings")

    ax2.hist(its.mean(axis=1), bins=20)
    # report the range that is actually averaged instead of a fixed "1-3"
    ax2.set_title(f"Mean its 1-{its.shape[1]}")
    ax2.set_xlabel(ratio_label)

    # half-integer edges centre every bin on an integer macrostate count
    bins = np.arange(min(cl.n_macrostates) - 1, max(cl.n_macrostates) + 1) + 0.5

    ax3.hist(cl.n_macrostates, bins=bins)
    ax3.set_title("n macrostates")
    ax3.set_xlabel("macrostate count")

    plt.tight_layout()
    plt.savefig(out)
    plt.close()

326 

327 

328### SIMILARITY MATRIX ######################################################## 

329 

330 

def plot_heatmap(a, out, title=""):
    """
    Save a log-scaled image plot of the matrix ``a``.

    Intended for a similarity matrix as returned from the multiplication of
    two MPT objects. Both axes are labelled "Macrostate" and carry one tick
    per row/column.

    a (np.ndarray): matrix to show.
    out: target passed to ``plt.savefig``.
    title: optional axes title; skipped when empty.
    """
    _, axes = plt.subplots()
    axes.imshow(a, norm="log")
    axes.set_aspect("equal", "box")

    column_ticks = np.arange(a.shape[1])
    row_ticks = np.arange(a.shape[0])
    axes.set_xticks(column_ticks)
    axes.set_yticks(row_ticks)

    if title:
        axes.set_title(title)
    for set_label in (axes.set_xlabel, axes.set_ylabel):
        set_label("Macrostate")

    plt.tight_layout()
    plt.savefig(out)
    plt.close()

349 

350 

def transition_matrix(a, out, title="Transition Matrix", color_thr=0.01):
    """
    Plot a transition matrix as an annotated, colour-coded grid.

    The entries of ``a`` are scaled to percent. Diagonal cells use a
    logarithmic red scale, off-diagonal cells a logarithmic viridis scale.
    Off-diagonal values below ``color_thr`` times the largest off-diagonal
    value are drawn light gray, and exact zeros are left white and unlabelled.

    a (np.ndarray): square matrix of transition probabilities as fractions.
    out: target passed to ``plt.savefig``.
    title: axes title; skipped when empty.
    color_thr: relative threshold below which off-diagonal cells turn gray.
    """
    # Scale a to percent
    a = a * 100

    # Diagonal: logarithmic Reds, truncated so the lowest value is still red
    diagonal_values = np.diag(a)
    diag_norm = LogNorm(vmin=diagonal_values.min(), vmax=diagonal_values.max())
    diag_cmap_custom = ListedColormap(plt.cm.Reds(np.linspace(0.2, 1, 256)))

    # Off-diagonal: logarithmic viridis
    off_diag_mask = ~np.eye(a.shape[0], dtype=bool)
    off_diag_values = a[off_diag_mask]

    # Threshold for light gray
    threshold = color_thr * off_diag_values.max()
    print(f"Threshold for probabilities: {threshold:.3f} %")

    off_diag_norm = LogNorm(
        vmin=threshold * (1 - color_thr), vmax=off_diag_values.max()
    )

    # viridis with its lowest entries replaced by light gray
    colors_list = plt.cm.viridis(np.linspace(0, 1, 256))
    gray = np.array([0.9, 0.9, 0.9, 1.0])
    colors_list[: int(color_thr * 256)] = gray
    custom_off_diag_cmap = ListedColormap(colors_list)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect("equal", "box")
    ax.grid(False)

    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            value = a[i, j]
            if value == 0:
                color = (1, 1, 1, 1)  # Zero probabilities are white
            elif i == j:
                color = diag_cmap_custom(diag_norm(value))
            elif value < threshold:
                color = gray
            else:
                color = custom_off_diag_cmap(off_diag_norm(value))

            ax.add_patch(patches.Rectangle((j - 0.5, i - 0.5), 1, 1, color=color))

            # Add text with transition probabilities
            if value != 0:
                # perceived brightness decides between white and black text
                grayscale = np.sum(
                    np.array(color[:3]) * np.array([0.299, 0.587, 0.114])
                )
                text_color = "white" if grayscale < 0.5 else "black"
                ax.text(
                    j,
                    i,
                    f"{value:.2f}%",
                    ha="center",
                    va="center",
                    color=text_color,
                    fontsize=10,
                )

    ax.set_xticks(np.arange(a.shape[1]))
    ax.set_yticks(np.arange(a.shape[0]))
    ax.set_xticklabels(np.arange(1, a.shape[1] + 1))
    ax.set_yticklabels(np.arange(1, a.shape[0] + 1))
    ax.set_xlim(-0.5, a.shape[1] - 0.5)
    ax.set_ylim(-0.5, a.shape[0] - 0.5)

    # Colorbar for diagonal values; it must use the same truncated colormap
    # as the cells, otherwise bar and cells show different colours.
    cbar_diag = fig.colorbar(
        plt.cm.ScalarMappable(norm=diag_norm, cmap=diag_cmap_custom),
        ax=ax,
        shrink=0.5,
    )
    cbar_diag.set_label("Self Transition Probabilities / \\%")

    # Colorbar for off-diagonal values
    cbar_off_diag = fig.colorbar(
        plt.cm.ScalarMappable(norm=off_diag_norm, cmap=custom_off_diag_cmap),
        ax=ax,
        shrink=0.5,
    )
    cbar_off_diag.set_label("Transition Probabilities / \\%")

    if title:
        ax.set_title(title)

    ax.set_xlabel("From Macrostate")
    ax.set_ylabel("To Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()

451 

452 

def _format_transition_time(value, threshold):
    """Return the short cell label for a transition time ``value`` in ns."""
    if value >= threshold:
        # one significant digit; compress "1e+03" to "1e3" and keep sign and
        # all digits of the exponent ("1e-05" -> "1e-5", "5e+12" -> "5e12")
        coarse = f"{value:.1g}"
        mantissa, sep, exponent = coarse.partition("e")
        return f"{mantissa}e{int(exponent)}" if sep else coarse
    if value >= 100:
        return f"{value:.0f}"
    return f"{value:#.3g}"


def transition_time(
    a,
    out,
    tlag=50.0,
    frame_length=0.2,
    title=r"Transition Times $\frac{t_\mathrm{lag}}{P}$",
    color_thr=0.01,
):
    """
    Plot transition times as an annotated, colour-coded grid.

    The times are computed as ``tlag / a * frame_length``. Diagonal cells use
    a logarithmic reversed red scale, off-diagonal cells a logarithmic
    reversed viridis scale. Off-diagonal times above the smallest
    off-diagonal time divided by ``color_thr`` are drawn light gray, and
    infinite times (zero probability) are left white and unlabelled.

    a (np.ndarray): square matrix of transition probabilities.
    out: target passed to ``plt.savefig``.
    tlag: lag time (presumably in frames, as it is multiplied by
        ``frame_length``; verify against caller).
    frame_length: conversion factor in ns.
    title: axes title; skipped when empty.
    color_thr: relative threshold above which off-diagonal cells turn gray.
    """
    # zero probabilities become infinite times; silence the division warning
    with np.errstate(divide="ignore"):
        a = tlag / a * frame_length

    # Diagonal: logarithmic reversed Reds, truncated at the pale end
    diagonal_values = np.diag(a)
    diag_norm = LogNorm(vmin=diagonal_values.min(), vmax=diagonal_values.max())
    diag_cmap_custom = ListedColormap(plt.cm.Reds_r(np.linspace(0, 0.8, 256)))

    # Off-diagonal: logarithmic reversed viridis
    off_diag_mask = ~np.eye(a.shape[0], dtype=bool)
    off_diag_values = a[off_diag_mask]

    # Threshold for light gray
    threshold = off_diag_values.min() / color_thr
    print(f"Threshold for transition times: {threshold:.2f} ns")

    off_diag_norm = LogNorm(
        vmin=off_diag_values.min(), vmax=threshold / (1 - color_thr)
    )

    # reversed viridis with its highest entries replaced by light gray
    colors_list = plt.cm.viridis_r(np.linspace(0, 1, 256))
    gray = np.array([0.9, 0.9, 0.9, 1.0])
    colors_list[int((1 - color_thr) * 256) :] = gray
    custom_off_diag_cmap = ListedColormap(colors_list)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect("equal", "box")
    ax.grid(False)

    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            value = a[i, j]
            if value == np.inf:
                color = (1, 1, 1, 1)  # Zero probabilities are white
            elif i == j:
                color = diag_cmap_custom(diag_norm(value))
            elif value > threshold:
                color = gray
            else:
                color = custom_off_diag_cmap(off_diag_norm(value))

            ax.add_patch(patches.Rectangle((j - 0.5, i - 0.5), 1, 1, color=color))

            # Add text with transition times
            if value != np.inf:
                # perceived brightness decides between white and black text
                grayscale = np.sum(
                    np.array(color[:3]) * np.array([0.299, 0.587, 0.114])
                )
                text_color = "white" if grayscale < 0.5 else "black"
                ax.text(
                    j,
                    i,
                    _format_transition_time(value, threshold),
                    ha="center",
                    va="center",
                    color=text_color,
                    fontsize=10,
                )

    ax.set_xticks(np.arange(a.shape[1]))
    ax.set_yticks(np.arange(a.shape[0]))
    ax.set_xticklabels(np.arange(1, a.shape[1] + 1))
    ax.set_yticklabels(np.arange(1, a.shape[0] + 1))
    ax.set_xlim(-0.5, a.shape[1] - 0.5)
    ax.set_ylim(-0.5, a.shape[0] - 0.5)

    # Colorbar for diagonal values; it must use the same truncated colormap
    # as the cells, otherwise bar and cells show different colours.
    cbar_diag = fig.colorbar(
        plt.cm.ScalarMappable(norm=diag_norm, cmap=diag_cmap_custom),
        ax=ax,
        shrink=0.5,
    )
    cbar_diag.set_label("Self Transition Times / ns")

    # Colorbar for off-diagonal values
    cbar_off_diag = fig.colorbar(
        plt.cm.ScalarMappable(norm=off_diag_norm, cmap=custom_off_diag_cmap),
        ax=ax,
        shrink=0.5,
    )
    cbar_off_diag.set_label("Transition Times / ns")

    if title:
        ax.set_title(title)

    ax.set_xlabel("From Macrostate")
    ax.set_ylabel("To Macrostate")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()

563 

564 

565### MACROSTATE FEATURES ###################################################### 

566 

567 

def plot_macro_feature(micro_feature, out, ref=None, pop=None):
    """
    Plot a log-scaled histogram of the feature distribution.

    micro_feature (np.ndarray, NxR): N microstates, R runs, holds feature
        values of respective macrostate
    out (str): file to save the plot
    ref: optional object providing ``macrostate_assignment``,
        ``macrostate_feature`` and ``n_i``; the entries selected by ``n_i``
        are drawn on top of the histogram via ``add_ref``.
    pop: optional weights forwarded to ``np.histogram`` (presumably the
        microstate populations; verify against caller).
    """
    n_runs = micro_feature.shape[1]

    # 100 equally wide bins spanning 95 % of the minimum to 105 % of the maximum
    edges = np.linspace(
        micro_feature.min() * 0.95,
        micro_feature.max() * 1.05,
        101,
    )
    density, bins = np.histogram(
        micro_feature, bins=edges, weights=pop, density=True
    )
    norm_counts = density / n_runs
    # lowest non-empty bin sets the bottom of the log axis
    populated = norm_counts[norm_counts > 0]
    y_min = populated.min() * 0.7

    _, ax = plt.subplots(figsize=(8, 6))
    ax.hist(bins[:-1], bins=bins, weights=norm_counts, label="Stochastic Clustering")
    if ref is not None:
        selected = ref.n_i
        add_ref(
            ref.macrostate_assignment[selected],
            ref.macrostate_feature[selected],
            ax,
        )

    ax.set_xlabel("Fraction of Contacts")
    ax.set_ylabel("Population")
    ax.set_title(f"Macrostate Features, {n_runs} clusterings")
    ax.set_yscale("log")
    _, y_top = ax.get_ylim()
    ax.set_ylim((y_min, y_top))

    plt.legend(loc="lower left")
    plt.tight_layout()
    plt.savefig(out)
    plt.close()

609 

610 

def add_ref(
    macrostate_assignment,
    macrostate_feature,
    ax,
    color="r",
    label="Reference",
    weights=None,
):
    """
    Add a clustering to the histogram.

    Every macrostate is drawn as a vertical line at its feature value. The
    line height is the weighted share of its microstates scaled by 1e-3, and
    the line is annotated with the 1-based macrostate index. Only the first
    line carries a legend label.

    macrostate_assignment (np.ndarray, MxN): macrostate assignment, M: number
        of macrostates, N: number of microstates.
    macrostate_feature (np.ndarray, M): mean feature for every macrostate.
    ax: axes to draw on.
    color: colour of lines and annotations.
    label: legend label; " / 1000" is appended.
    weights (np.ndarray, optional): weights per microstate; uniform if None.
    """
    # The weights do not depend on the macrostate, so prepare them once.
    if weights is None:
        weights = np.array([1])
    else:
        weights = weights / weights.sum()
    total = weights.sum()

    for i, (ma, mf) in enumerate(zip(macrostate_assignment, macrostate_feature)):
        x = [mf, mf]
        y = [1e-9, (ma * weights).sum() / total * 1e-3]
        if i == 0:
            ax.plot(x, y, c=color, label=label + " / 1000")
        else:
            ax.plot(x, y, c=color)
        pplt.text(
            mf + 0.015,
            y[1] * 0.82,
            f"{i + 1:.0f}",
            c=color,
            ax=ax,
            contour=True,
            size="small",
        )

648 

649 

650### CONTACT REPRESENTATION ################################################### 

651 

652 

def contact_rep(contacts, cluster_file, state_traj, output, grid, scale=1):
    """
    Adapted from msmhelper.

    Contact representation of states.

    This script creates a contact representation of states. Were the states
    are obtained by [MoSAIC](https://github.com/moldyn/MoSAIC) and the contact
    representation was introduced in Nagel et al.[^1].

    [^1]: Nagel et al., **Selecting Features for Markov Modeling: A Case Study
          on HP35.**, *J. Chem. Theory Comput.*, submitted,

    contacts (np.ndarray): 2D array, one row per frame, one column per contact.
    cluster_file: file handed to ``load_clusters``; the clusters hold column
        indices into ``contacts``.
    state_traj (np.ndarray): state of every frame.
    output: file to save the first figure to; further figures get
        ``.state<first>-<last>`` inserted before the extension. If None, the
        figures are shown instead of saved.
    grid: ``(nrows, ncols)``, at most ``nrows * ncols`` states per figure.
    scale: multiplier for the base figure size of 1.2.
    """
    # setup matplotlib
    pplt.use_style(
        figsize=1.2 * scale,
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )

    # load files
    states = np.unique(state_traj)
    clusters = load_clusters(cluster_file)

    contact_idxs = np.hstack(clusters)
    n_idxs = len(contact_idxs)
    n_frames = len(contacts)

    # left edge of every cluster on the x-axis
    xtickpos = np.cumsum([0, *[len(clust) for clust in clusters[:-1]]]) - 0.5
    nrows, ncols = grid
    hspace, wspace = 0, 0
    # ignore outliers
    ylims = 0, np.quantile(contacts, 0.999)

    # the palette is the same for every state, so look it up once
    c1, c2, c3 = pplt.categorical_color(3, "C0")
    bands = (
        (c3, ("whislo", "whishi"), r"$Q_{1/3} \pm 1.5\mathrm{IQR}$"),
        (c2, ("q1", "q3"), r"$\mathrm{IQR} = Q_3 - Q_1$"),
    )

    chunks = mh.plot._ck_test._split_array(states, nrows * ncols)
    for counter, chunk in enumerate(chunks):
        fig, axs = plt.subplots(
            int(np.ceil(len(chunk) / ncols)),
            ncols,
            sharex=True,
            sharey=True,
            squeeze=False,
            gridspec_kw={"wspace": wspace, "hspace": hspace},
        )

        for state, ax in zip(chunk, axs.flatten()):
            contacts_state = contacts[state_traj == state]
            pop_state = len(contacts_state) / n_frames

            stats = {
                idx: boxplot_stats(contacts_state[:, idx])[0] for idx in contact_idxs
            }

            for color, (key_low, key_high), label in bands:
                ymax = [stats[idx][key_high] for idx in contact_idxs]
                ymin = [stats[idx][key_low] for idx in contact_idxs]
                ax.stairs(
                    ymax,
                    np.arange(n_idxs + 1) - 0.5,
                    baseline=ymin,
                    color=color,
                    lw=0,
                    fill=True,
                    label=label,
                )

            ax.hlines(
                [stats[idx]["med"] for idx in contact_idxs],
                xmin=np.arange(n_idxs) - 0.5,
                xmax=np.arange(n_idxs) + 0.5,
                label="median",
                color=c1,
            )

            pplt.text(
                0.5,
                0.95,
                rf"S{state + 1} {pop_state:.1%}",
                ha="center",
                va="top",
                ax=ax,
                transform=ax.transAxes,
                contour=True,
            )

            ax.set_xlim([-0.5, n_idxs - 0.5])
            ax.set_ylim(*ylims)
            ax.set_xticks(xtickpos)
            ax.set_xticklabels(np.arange(len(xtickpos)) + 1)

            ax.grid(False)
            for pos in xtickpos:
                ax.axvline(pos, color="pplt:grid", lw=1.0)

        pplt.hide_empty_axes()
        pplt.legend(
            ax=axs[0, 0],
            outside="top",
            bbox_to_anchor=(
                0,
                1.0,
                axs.shape[1] + wspace * (axs.shape[1] - 1),
                0.01,
            ),
            frameon=False,
            ncol=2,
        )
        pplt.subplot_labels(
            xlabel="contact clusters",
            ylabel="distances [nm]",
        )

        if output is None:
            # Nothing to save; without this early exit ``splitext(None)``
            # below would raise a TypeError.
            plt.show()
            plt.close()
            continue

        # the first figure keeps the plain name, later ones get the state
        # range inserted between pathname and extension
        if counter == 0:
            pplt.savefig(output)
        else:
            path, ext = splitext(output)
            pplt.savefig(f"{path}.state{chunk[0]:.0f}-{chunk[-1]:.0f}{ext}")
        plt.close()

794 

795 

796### SANKEY ################################################################### 

797 

798 

def plot_sankey(cl, ref, out, ax=None, scale=1):
    """
    Draw a Sankey diagram between a clustering and a reference clustering.

    The left side uses the macrostate map of ``cl`` selected by ``cl.n_i``,
    the right side the first macrostate map of ``ref``; both sides are
    weighted with ``ref.pop``. Colours are taken from the macrostates of
    ``cl`` sorted by descending feature, keyed by their 1-based rank.

    cl, ref: clustering objects providing the attributes used above.
    out: target passed to ``pplt.savefig``; only used when ``ax`` is None.
    ax: axes to draw into. If None, the style is set up here and the figure
        is saved and closed afterwards.
    scale: multiplier for the base figure size; only used when ``ax`` is None.
    """
    selected = cl.n_i
    features = [macrostate.feature for macrostate in cl.tree[selected].macrostates]
    descending = np.argsort(features)[::-1]
    colorDict = {
        str(rank + 1): cl.tree[selected].macrostates[position].color
        for rank, position in enumerate(descending)
    }

    standalone = ax is None
    if standalone:
        pplt.use_style(figsize=(1.7 * scale, 3.6 * scale), true_black=True)

    left_names = np.arange(1, cl.n_macrostates[selected] + 1).astype(str).tolist()
    right_names = np.arange(1, ref.n_macrostates[0] + 1).astype(str).tolist()
    sankey(
        left=(cl.macrostates_map[selected] + 1).astype(str),
        right=(ref.macrostates_map[0] + 1).astype(str),
        leftWeight=ref.pop,
        rightWeight=ref.pop,
        leftLabels=left_names,
        rightLabels=right_names,
        colorDict=colorDict,
        ax=ax,
    )

    if standalone:
        pplt.savefig(out)
        plt.close()

822 

823 

824### RMSD LINES ############################################################### 

825 

826 

def plot_rmsd(rmsds, pops, helices=None, filename=None):
    """
    Plot one row of per-residue values for every row of ``rmsds``.

    Each row has three panels: the values of that row as a filled line on a
    logarithmic y-axis, a horizontal bar with the normalized population, and
    a horizontal bar with the row sum excluding the first and last two
    entries. With ``helices`` an extra bottom row marks the given segments.

    Parameters:
    - rmsds (np.ndarray): 2D array, one row per macrostate. Values must be
      positive for logarithmic scaling.
    - pops (np.ndarray): 1D array with one population per row of ``rmsds``;
      normalized to sum to one before plotting.
    - helices (iterable of (start, end), optional): segments for the bottom
      row. A positive ``start`` gives a filled block; otherwise both values
      are negated and an outlined block is drawn.
    - filename (str, optional): If provided, saves the figure to this file,
      otherwise the figure is shown.

    Raises:
    - ValueError: if ``rmsds`` holds non-positive values or its number of rows
      differs from ``len(pops)``.
    """
    if np.any(rmsds <= 0):
        raise ValueError(
            "All values in `rmsds` must be positive for logarithmic scaling."
        )
    if rmsds.shape[0] != len(pops):
        raise ValueError("Length of `pops` must match the number of rows in `rmsds`.")

    with_structure_row = helices is not None
    n_plots = rmsds.shape[0] + 1 if with_structure_row else rmsds.shape[0]

    # figure size grows with the number of residues and rows
    width = 0.08 * rmsds.shape[1] + 3
    height = 1 + 0.4 * n_plots
    pplt.use_style(
        figsize=(width, height),
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )
    fig, axs = plt.subplots(
        n_plots,
        3,
        sharex="col",
        width_ratios=[rmsds.shape[1], 8, 8],
        gridspec_kw={"wspace": 0, "hspace": 0},
    )

    ylim = 0.5 * rmsds.min(), 2 * rmsds.max()
    pops = pops / pops.sum()
    ylim_hist = 0, 1.05 * pops.max()

    rmsd_sums = rmsds[:, 2:-2].sum(axis=1)
    ylim_rmsd = 0, 1.05 * rmsd_sums.max()

    state_rows = axs[:-1] if with_structure_row else axs
    for i, ((ax, hist_ax, rmsd_ax), rmsd, pop) in enumerate(
        zip(state_rows, rmsds, pops)
    ):
        # the two bar panels differ only in bar length and x-range
        bar_panels = (
            (hist_ax, pop, ylim_hist),
            (rmsd_ax, rmsd[2:-2].sum(), ylim_rmsd),
        )
        for bar_ax, length, limits in bar_panels:
            bar_ax.add_patch(patches.Rectangle((0, 0.3), length, 0.4))
            bar_ax.set_xlim(limits)
            bar_ax.set_yticks([], [])
            bar_ax.grid(False)

        n_residues = rmsd.shape[0]
        residues = np.arange(n_residues) + 1
        ax.plot(residues, rmsd)
        ax.fill_between(residues, [ylim[0]] * n_residues, rmsd, alpha=0.5)

        ax.set_yscale("log")
        ax.set_ylabel(f"{i + 1}")
        ax.set_xlim((0.5, n_residues + 0.5))
        ax.set_ylim(ylim)
        ax.grid(True)

    if with_structure_row:
        connector = dict(solid_capstyle="butt", c="#264653", lw=2)
        line_start = 1
        helices_ax = axs[-1, 0]
        for start, end in helices:
            if start > 0:
                # Helices: filled block
                pad, face = 0.3, "#264653"
            else:
                # Sheets: stored with negated indices, outlined block
                start, end = -start, -end
                pad, face = 0.5, "white"
            start -= pad
            end += pad
            block = patches.Rectangle(
                (start, 0.3),
                end - start,
                0.4,
                fc=face,
                ec="#264653",
                lw=2,
            )
            # line from the previous block to this one
            helices_ax.plot([line_start, start], [0.5, 0.5], **connector)
            line_start = end
            helices_ax.add_patch(block)
        # line from the last block to the final residue
        helices_ax.plot([line_start, rmsds.shape[1]], [0.5, 0.5], **connector)

        helices_ax.set_ylim((0, 1))
        helices_ax.set_ylabel("H")
        helices_ax.set_yticks([], [])
        helices_ax.grid(False)

        for col in (1, 2):
            axs[-1, col].grid(False)
            axs[-1, col].set_yticks([], [])

    # blank the first tick label of both bar columns
    for col in (1, 2):
        bottom = axs[-1, col]
        ticks = bottom.get_xticks()
        tick_labels = bottom.get_xticklabels()
        tick_labels[0] = ""
        bottom.set_xticks(ticks, tick_labels)

    residue_axis = axs[-1, 0].xaxis
    residue_axis.set_major_locator(MultipleLocator(5))
    residue_axis.set_minor_locator(MultipleLocator(1))
    axs[-1, 0].set_xlabel("Residue")
    axs[-1, 1].set_xlabel("Population", rotation=20)
    axs[-1, 2].set_xlabel(r"$\sum$ RMSD / nm", rotation=20)
    fig.supylabel("Macrostate; RMSD Variance / nm")

    # Save to file if filename is provided
    plt.tight_layout()
    if filename:
        plt.savefig(filename, dpi=192)
    else:
        plt.show()
    plt.close()

994 

995 

def plot_delta_rmsd(rmsds, pops, helices=None, filename=None):
    """
    Plot per-residue RMSD profiles of all states relative to the first state.

    The first row shows ``rmsds[0]`` on a logarithmic y-axis. Every further
    row ``i`` shows the difference ``rmsds[i] - rmsds[0]``. Two narrow side
    columns show, per row, the normalized population and the sum of the
    absolute plotted values as horizontal bars. If ``helices`` is given, an
    additional bottom row sketches the listed segments as blocks on a line.

    Parameters:
    - rmsds (np.ndarray): 2D array with one row per state and one column per
      residue. All values must be positive. The input is not modified.
    - pops (np.ndarray): 1D array with one population per row of `rmsds`;
      it is normalized to sum to one before plotting.
    - helices (np.ndarray, optional): Pairs ``(start, end)`` of blocks to be
      indicated in the bottom row. Pairs with a positive start are drawn as
      filled blocks ("Helices"); all other pairs are negated and drawn as
      open blocks ("Sheets").
    - filename (str, optional): If provided, saves the figure to this file,
      otherwise the figure is shown.

    Raises:
    - ValueError: If `rmsds` contains non-positive values or if the length of
      `pops` does not match the number of rows in `rmsds`.
    """
    # Work on a float copy: the rows are replaced by differences further down
    # and the caller's array must stay untouched.
    rmsds = np.array(rmsds, dtype=float)
    pops = np.asarray(pops, dtype=float)

    # Ensure all values are positive for logarithmic scaling
    if np.any(rmsds <= 0):
        raise ValueError(
            "All values in `rmsds` must be positive for logarithmic scaling."
        )

    if rmsds.shape[0] != len(pops):
        raise ValueError("Length of `pops` must match the number of rows in `rmsds`.")

    n_states, n_residues = rmsds.shape
    has_helices = helices is not None
    n_plots = n_states + 1 if has_helices else n_states

    pplt.use_style(
        figsize=(0.08 * n_residues + 3, 1 + 0.4 * n_plots),
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )
    # squeeze=False keeps `axs` two-dimensional even for a single row, so the
    # `axs[row, col]` indexing below always works.
    fig, axs = plt.subplots(
        n_plots,
        3,
        sharex="col",
        squeeze=False,
        width_ratios=[n_residues, 8, 8],
        gridspec_kw={"wspace": 0, "hspace": 0},
    )

    # Express every state but the first relative to the first state.
    rmsds[1:] -= rmsds[0]

    rmsd_max = rmsds.max()
    rmsd_min = rmsds.min()
    rmsd_span = rmsd_max - rmsd_min
    ylim = rmsd_min - rmsd_span * 0.1, rmsd_max + rmsd_span * 0.15

    pops = pops / pops.sum()
    xlim_pops = 0, 1.05 * pops.max()

    # The bars show sums of absolute values, so the axis limit has to be
    # derived from the same quantity (signed sums may be much smaller).
    abs_sums = np.abs(rmsds).sum(axis=1)
    xlim_sums = 0, 1.05 * abs_sums.max()

    residues = np.arange(n_residues) + 1
    state_rows = axs[:-1] if has_helices else axs
    for i, ((ax, hist_ax, rmsd_ax), rmsd, pop, abs_sum) in enumerate(
        zip(state_rows, rmsds, pops, abs_sums)
    ):
        # Population bar
        hist_ax.add_patch(patches.Rectangle((0, 0.3), pop, 0.4))
        hist_ax.set_xlim(xlim_pops)
        hist_ax.set_yticks([], [])
        hist_ax.grid(False)

        # Bar with the summed absolute values of this row
        rmsd_ax.add_patch(patches.Rectangle((0, 0.3), abs_sum, 0.4))
        rmsd_ax.set_xlim(xlim_sums)
        rmsd_ax.set_yticks([], [])
        rmsd_ax.grid(False)

        # Profile along the residues, filled towards zero
        ax.plot(residues, rmsd)
        ax.fill_between(residues, 0, rmsd, alpha=0.5)
        ax.set_ylabel(f"{i + 1}")
        ax.set_xlim((0.5, n_residues + 0.5))
        ax.set_ylim(ylim)
        ax.grid(True)

    if has_helices:
        helices_ax = axs[-1, 0]
        line_start = 1
        for start, end in helices:
            if start > 0:
                # Helices: filled block
                pad, face = 0.3, "#264653"
            else:
                # Sheets: encoded with negated indices, drawn as open block
                start, end = -start, -end
                pad, face = 0.5, "white"
            start = start - pad
            end = end + pad
            rect = patches.Rectangle(
                (start, 0.3),  # Position of the block
                end - start,
                0.4,
                fc=face,
                ec="#264653",
                lw=2,
            )
            # Connecting line from the previous block to this one
            helices_ax.plot(
                [line_start, start],
                [0.5, 0.5],
                solid_capstyle="butt",
                c="#264653",
                lw=2,
            )
            line_start = end
            helices_ax.add_patch(rect)
        # Line from the last block to the final residue
        helices_ax.plot(
            [line_start, n_residues],
            [0.5, 0.5],
            solid_capstyle="butt",
            c="#264653",
            lw=2,
        )

        helices_ax.set_ylim((0, 1))
        helices_ax.set_ylabel("H")
        helices_ax.set_yticks([], [])
        helices_ax.grid(False)

        axs[-1, 1].grid(False)
        axs[-1, 1].set_yticks([], [])
        axs[-1, 2].grid(False)
        axs[-1, 2].set_yticks([], [])

    # The first row holds the untouched (positive) reference profile.
    axs[0, 0].set_ylim(0.5 * rmsds[0].min(), 2 * rmsds[0].max())
    axs[0, 0].set_yscale("log")

    # Blank the first tick label of both side columns.
    for side_ax in (axs[-1, 1], axs[-1, 2]):
        ticks = side_ax.get_xticks()
        labels = side_ax.get_xticklabels()
        labels[0] = ""
        side_ax.set_xticks(ticks, labels)

    axs[-1, 0].xaxis.set_major_locator(MultipleLocator(5))
    axs[-1, 0].xaxis.set_minor_locator(MultipleLocator(1))
    axs[-1, 0].set_xlabel("Residue")
    axs[-1, 1].set_xlabel("Population", rotation=20)
    axs[-1, 2].set_xlabel(r"$\sum$ |$\Delta$RMSD| / nm", rotation=20)
    fig.supylabel(r"$\Delta$Macrostate; $\Delta$RMSD Variance / nm wrt. 1st State")

    # Save to file if filename is provided
    fig.tight_layout()
    if filename:
        fig.savefig(filename, dpi=192)
    else:
        plt.show()
    plt.close(fig)

1179 

1180 

1181### TRAJECTORY ############################################################### 

1182 

1183 

def plot_state_trajectory(trajectory, filename, row_length=0.2, frame_length=0.2):
    """
    Plot a state trajectory as colored line segments, wrapped over several rows.

    trajectory (np.ndarray): state trajectory
    filename (str): file name to save the plot to
    row_length (int|float):
        row_length > 1: number of frames in each row
        0 < row_length <= 1: fraction of total frames per row (1/n_rows)
    frame_length (float): frame length in ns

    Raises ValueError if row_length is not positive.
    """
    if row_length > 1:
        frames_per_row = int(row_length)
    elif row_length > 0:
        frames_per_row = int(np.ceil(trajectory.shape[0] * row_length))
    else:
        raise ValueError("row_length must be > 0")

    # The time axis is labelled in microseconds, the frame length is given in ns.
    frame_length /= 1000.0

    # Consecutive runs of the same state and their lengths
    # (assumes `lengths` is a NumPy array counted in frames - TODO confirm).
    states, lengths = utils.find_state_lengths(trajectory)
    # Shift to 1-based state labels without modifying the helper's result.
    states = states + 1
    lengths = lengths * frame_length
    n_rows = int(np.ceil(trajectory.shape[0] / frames_per_row))

    # Width of one row in time units
    x_max = frames_per_row * frame_length

    figsize = (11.7, 8.3)

    u_states = np.unique(states)
    cmap = mpl.cm.turbo
    norm = mpl.colors.BoundaryNorm(np.concatenate([[0], u_states]) + 0.5, cmap.N)

    fig, axs = plt.subplots(
        n_rows,
        1,
        sharex=True,
        figsize=figsize,
        gridspec_kw={"wspace": 0, "hspace": 0},
        squeeze=False,
        layout="compressed",
    )
    axs = axs[:, 0]
    axi = 0  # index of the row currently drawn into

    # Plot each state occurrence as a line segment
    x_start = 0  # Initial x-coordinate for the first segment
    for state, length in zip(states, lengths):
        x_end = x_start + length  # End position of this segment on the x-axis
        color = cmap(norm(state))

        # A segment running past the end of the row is cut there and
        # continued at the start of the next row.
        while x_end > x_max:
            axs[axi].plot(
                [x_start, x_max],
                [state, state],
                color=color,
                linewidth=3,
                solid_capstyle="butt",
            )
            x_end -= x_max
            x_start = 0
            axi += 1
        # Skip (numerically) empty remainders, e.g. right after a row wrap.
        if not np.isclose(x_start, x_end):
            axs[axi].plot(
                [x_start, x_end],
                [state, state],
                color=color,
                linewidth=3,
                solid_capstyle="butt",
            )

        # Move x_start to the end of the current segment for the next one
        x_start = x_end

    # Colorbar with one tick per occurring state
    cbar = fig.colorbar(
        ScalarMappable(norm=norm, cmap=cmap),
        ax=axs,
        orientation="vertical",
        label="Macrostate",
    )
    cbar.set_ticks(ticks=u_states, labels=u_states, minor=False)
    cbar.ax.tick_params(length=0)

    fig.supylabel("State Index")
    axs[-1].set_xlabel(r"t / $\mu$s")

    for ax in axs:
        ax.set_ylim(states.min() - 1, states.max())

    # Set axis limits (all rows share the x-axis)
    axs[-1].set_xlim(0, x_max)

    # Save the plot to the specified file
    fig.savefig(filename)
    plt.close(fig)  # Close the plot to free memory

1286 

1287 

1288### CHAPMAN-KOLMOGOROV TEST ################################################## 

1289 

1290 

def chapman_kolmogorov(mpt, out, frame_length=0.2):
    """Chapman-Kolmogorov Test. Frame length in ns

    Runs msmhelper's Chapman-Kolmogorov test on the macrostate trajectory
    selected by ``mpt.n_i`` and plots the result with `plot_ck_test`.

    The states are split into chunks of one grid each. The figure of the
    first chunk is written to ``out``; if more chunks are needed, their
    figures are written next to it with a numeric suffix inserted before the
    file extension (``name_2.ext``, ``name_3.ext``, ...), so that no figure
    overwrites another.
    """
    ck = mh.msm.tests.chapman_kolmogorov_test(
        utils.get_multi_state_traj(mpt.macrotraj[mpt.n_i], mpt.limits),
        # NOTE(review): presumably the lag times and the maximum time, both
        # in frames - confirm against the msmhelper documentation.
        [50, 50, 50, 50, 50],
        4000,
    )
    pplt.use_style(
        figsize=4.8,
        colors="pastel_autumn",
        true_black=True,
        latex=False,
    )

    n_states = mpt.n_macrostates[mpt.n_i]
    nrows, ncols = utils.get_grid_format(n_states)
    chunks = mh.plot._ck_test._split_array(np.arange(n_states), nrows * ncols)
    for i_chunk, chunk in enumerate(chunks):
        fig = plot_ck_test(
            ck=ck,
            states=chunk,
            frames_per_unit=1 / frame_length,
            unit="ns",
            grid=(ncols, nrows),
        )

        # Move the state labels of every panel to a common position.
        for ax in fig.axes:
            for text in ax.texts:
                text.set_position((0.15, 0.2))

        if i_chunk == 0:
            target = out
        else:
            # Further chunks get their own file instead of replacing `out`.
            base, ext = splitext(out)
            target = f"{base}_{i_chunk + 1}{ext}"
        fig.savefig(target)
        plt.close(fig)

1323 

1324 

def plot_ck_test(
    ck,
    states=None,
    frames_per_unit=1,
    unit="frames",
    grid=(3, 3),
):
    """Plot CK-Test results.

    This routine is a basic helper function to visualize the results of
    [msmhelper.msm.chapman_kolmogorov_test][].

    Parameters
    ----------
    ck : dict
        Dictionary holding for each lagtime the CK equation and with 'md' the
        reference.
    states : ndarray, optional
        List containing all states to plot the CK-test. Defaults to all
        states found in the 'md' reference.
    frames_per_unit : float, optional
        Number of frames per given unit. This is used to scale the axis
        accordingly.
    unit : ['frames', 'fs', 'ps', 'ns', 'us'], optional
        Unit to use for label.
    grid : (int, int), optional
        The number of `(n_rows, n_cols)` to use for the grid layout. Only as
        many rows as needed for `states` are created; `n_rows` itself only
        affects the legend placement and the length of the y-label.

    Returns
    -------
    fig : matplotlib.Figure
        Figure holding plots.

    Notes
    -----
    Adapted from msmhelper.

    """
    # load colors
    pplt.load_cmaps()
    pplt.load_colors()

    lagtimes = np.array([key for key in ck if key != "md"])
    if states is None:
        states = np.array(list(ck["md"]["ck"].keys()))

    nrows, ncols = grid
    needed_rows = int(np.ceil(len(states) / ncols))

    # squeeze=False always yields an array of shape (needed_rows, ncols).
    # Squeezing and re-expanding would turn a single-column grid with several
    # rows into shape (1, needed_rows) and break the indexing below.
    fig, axs = plt.subplots(
        needed_rows,
        ncols,
        sharex=True,
        sharey="row",
        squeeze=False,
        gridspec_kw={"wspace": 0, "hspace": 0},
    )

    max_time = np.max(ck["md"]["time"])
    x_limits = [
        lagtimes[0] / frames_per_unit,
        max_time / frames_per_unit,
    ]
    for irow, states_row in enumerate(mh.plot._ck_test._split_array(states, ncols)):
        for icol, state in enumerate(states_row):
            ax = axs[irow, icol]

            # reference curve
            pplt.plot(
                ck["md"]["time"] / frames_per_unit,
                ck["md"]["ck"][state],
                "--",
                ax=ax,
                color="pplt:gray",
                label="MD",
            )
            # one curve per lag time
            for lagtime in lagtimes:
                pplt.plot(
                    ck[lagtime]["time"] / frames_per_unit,
                    ck[lagtime]["ck"][state],
                    ax=ax,
                    label=lagtime / frames_per_unit,
                )
            # panel label with the 1-based state index
            pplt.text(
                0.5,
                0.9,
                f"S{state + 1}",
                contour=True,
                va="top",
                transform=ax.transAxes,
                ax=ax,
            )

            # set scale
            ax.set_xscale("log")
            ax.set_xlim(x_limits)
            ax.set_ylim([0, 1])
            # With hspace=0 the "0" tick of a row would collide with the "1"
            # tick of the row below, so only the last row shows it.
            if irow < needed_rows - 1:
                ax.set_yticks([0.5, 1])
            else:
                ax.set_yticks([0, 0.5, 1])

            ax.grid(True, which="major", linestyle="--")
            ax.grid(True, which="minor", linestyle="dotted")
            ax.set_axisbelow(True)

    # set legend
    if ncols in {1, 2}:
        legend_kw = {
            "outside": "right",
            "bbox_to_anchor": (2.0, (1 - nrows), 0.2, nrows),
        }
    else:
        legend_kw = {
            "outside": "top",
            "bbox_to_anchor": (0.0, 1.0, ncols, 0.01),
        }
    if ncols == 3:
        legend_kw["ncol"] = 3
    pplt.legend(
        ax=axs[0, 0],
        **legend_kw,
        title=rf"$\tau_\mathrm{{lag}}$ [{unit}]",
        frameon=False,
    )

    # Use the long label only if the grid is tall enough.
    if nrows >= 3:
        ylabel = r"self-transition probability $P_{i\to i}$"
    else:
        ylabel = r"$P_{i\to i}$"

    pplt.hide_empty_axes()
    pplt.label_outer()
    pplt.subplot_labels(
        ylabel=ylabel,
        xlabel=rf"time $t$ [{unit}]",
    )
    return fig

1464 

1465 

1466### MACROSTATE GRAPH ######################################################### 

1467 

1468 

def state_network(lumping, out):
    """
    Draw the state network for the lumping selected by ``lumping.n_i``.

    Forwards the selected macrostate trajectory together with
    ``lumping.tlag``, ``lumping.feature_traj`` and ``out`` to
    `draw_knetwork` (``out`` is presumably the output file - see
    `MPT.graph.draw_knetwork`).
    """
    selected_traj = lumping.macrotraj[lumping.n_i]
    draw_knetwork(selected_traj, lumping.tlag, lumping.feature_traj, out)