Coverage for MPT/MPT.py: 86%

402 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-11 11:20 +0200

1import os 

2import datetime 

3import warnings 

4import numpy as np 

5import msmhelper as mh 

6import matplotlib.pyplot as plt 

7import mdtraj as md 

8 

9from pathlib import Path 

10from tqdm import tqdm 

11from typing import Callable, List 

12from numpy.typing import NDArray 

13from collections.abc import Iterable 

14from sklearn.metrics import davies_bouldin_score 

15import pygpcca as gp 

16 

17from MPT import core 

18 

19# import MPT.core as core 

20import MPT.utils as utils 

21import MPT.kernel as kernel_module 

22from MPT.graph import draw_knetwork 

23 

24import MPT.plot as plot 

25 

26# TODO: 

27# - change traj and macrotraj to list - add one dimension. First, mark all places that need adaptation. 

28# - Connect with contacts, check for implications. Float contacts file: /data/PDZ3_Ali/short_ligand/reduction/trans/contacts_analysis/cluster1-7/data/dist_all 

29# - internally change trajectory to 0-based, still support 1-based, issue a warning; same for macrotraj

30 

31 

class MPT(object):
    """Lump a microstate trajectory into macrostates and analyse the lumping.

    On construction a Markov model is estimated from the (0-based) microstate
    trajectory at lag ``tlag`` and a per-microstate feature is derived from
    ``feature_traj``. The lumping itself is produced by :meth:`mpt`
    (``n`` clustering runs), :meth:`from_Z` (reload saved merge matrices) or
    :meth:`gpcca` (GPCCA lumping). Derived quantities (implied timescales,
    GMRQ, Davies-Bouldin index, Shannon entropy, RMSD, ...) are computed
    lazily and cached in ``_``-prefixed attributes; computing a new lumping
    clears those caches.

    Most analysis/plot methods operate on run :attr:`n_i`, the run with the
    longest first implied timescale.
    """

    def __init__(
        self,
        traj: NDArray[np.int_],
        tlag: int,
        feature_traj: NDArray[float] = None,
        contact_threshold=0.45,
        feature_type=np.float64,
        macrostate_thresholds: tuple = (0.005, 0.5),
        limits=None,
        quiet=False,
        frame_length=0.2,
    ):
        """Estimate the microstate Markov model and attach feature data.

        Args:
            traj: microstate trajectory. A 1-based trajectory is shifted to
                0-based (with a warning); the caller's array is not modified.
            tlag: lag time (in frames) of the Markov model.
            feature_traj: frames x features array. If None, a constant
                single-column dummy feature of ones is used.
            contact_threshold: feature values strictly below this threshold
                are counted as True (e.g. a formed contact).
            feature_type: dtype used to store the feature arrays.
            macrostate_thresholds: ``(pop_thr, q_min)`` handed to the tree
                nodes that decide which nodes are macrostates.
            limits: forwarded to ``utils.get_multi_state_traj``; presumably the
                boundaries of concatenated sub-trajectories (TODO confirm).
            quiet: if True, suppress progress messages and progress bars.
            frame_length: duration of one frame in ns.
        """
        self.traj = traj
        self.tlag = tlag
        self.pop_thr, self.q_min = macrostate_thresholds
        self.limits = limits
        self.quiet = quiet
        self._estimate_msm()

        if feature_traj is None:
            # One row per frame; np.ones needs the frame count, not the shape tuple.
            feature_traj = np.ones((self.traj.shape[0], 1))
        self._add_feature(
            feature_traj,
            contact_threshold=contact_threshold,
            feature_type=feature_type,
        )

        self.Z = None
        self._reset_cache()
        self._reference = None
        self._topology_file = None
        self._xtc_trajectory_file = None
        self.xtc_stride = None
        self.frame_length = frame_length

    def _reset_cache(self):
        """Drop every lazily computed quantity that depends on the current lumping."""
        self._timescales = None
        self._linkage = None
        self._macro_pop = None
        self._tree = None
        self._shannon_entropy = None
        self._davies_bouldin_index = None
        self._gmrq = None
        self._rmsd = None
        self._n_i = None
        self._macro_micro_feature = None

    def _estimate_msm(self):
        """Set tmat, pop and n_states from the microstate trajectory."""
        tmat, states = mh.msm.estimate_markov_model(
            utils.get_multi_state_traj(self.traj, self.limits),
            self.tlag,
        )
        self.tmat = tmat.astype(np.float64)
        _, self.pop = np.unique(self.traj, return_counts=True)
        self.n_states = len(states)

    def _progress(self, iterable, message):
        """Return ``iterable``; unless quiet, print ``message`` and wrap it in tqdm."""
        if self.quiet:
            return iterable
        print(message)
        return tqdm(iterable)

    def mpt(
        self,
        kernel: Callable = None,
        feature_kernel=None,
        n: int = 1,
    ) -> None:
        """Perform MPT clustering ``n`` times and assign macrostates.

        Args:
            kernel: merge kernel passed to ``core.cluster``. Defaults to a new
                ``kernel_module.MPTKernel()`` per call rather than one instance
                shared by all calls.
            feature_kernel: optional feature kernel passed to ``core.cluster``.
            n: number of clustering runs.

        Sets ``self.Z`` (runs x (n_states - 1) x 4) and ``self.full_pop``
        (runs x (2 * n_states - 1)), then calls :meth:`assign_macrostates`.
        """
        if kernel is None:
            kernel = kernel_module.MPTKernel()
        # A new lumping invalidates the cached tree and all derived metrics.
        self._reset_cache()
        self.n_runs = n
        self.kernel = kernel
        self.feature_kernel = feature_kernel

        self.Z = np.zeros((self.n_runs, self.n_states - 1, 4), dtype=np.float64)
        self.full_pop = np.zeros(
            (self.n_runs, 2 * self.n_states - 1), dtype=np.uint32
        )
        for run in self._progress(range(self.n_runs), "Clustering ..."):
            self.Z[run], self.full_pop[run] = core.cluster(
                self.tmat,
                self.pop,
                kernel=self.kernel,
                feature_kernel=self.feature_kernel,
            )
        self.assign_macrostates()

    def _add_feature(
        self,
        feature_traj: NDArray[float],
        contact_threshold=0.45,
        feature_type=np.float64,
    ):
        """Attach feature data to the instance.

        Args:
            feature_traj: frames x features array (must be 2-D, same number of
                frames as ``self.traj``).
            contact_threshold: values strictly below it become True in
                ``multi_feature_traj_bool``.
            feature_type: dtype of ``multi_feature_traj`` and ``feature``.

        Sets ``multi_feature_traj``, ``multi_feature_traj_bool``,
        ``feature_traj`` (per-frame fraction of True features) and ``feature``
        (per-microstate mean of ``feature_traj``).

        Raises:
            ValueError: on a frame-count mismatch or a non-2-D array.
        """
        if feature_traj.shape[0] != self.traj.shape[0]:
            raise ValueError(
                "feature_traj must have the same length as the microstate trajectory (mpp.traj)"
            )
        if feature_traj.ndim != 2:
            raise ValueError("feature_traj must be 2 D")
        self.multi_feature_traj = feature_traj.astype(feature_type)

        self.contact_threshold = contact_threshold
        self.multi_feature_traj_bool = self.multi_feature_traj < self.contact_threshold
        self.feature_traj = self.multi_feature_traj_bool.mean(axis=1)
        self.feature = np.zeros(self.n_states, dtype=feature_type)
        for state in range(self.n_states):
            self.feature[state] = self.feature_traj[self.traj == state].mean()

    def assign_macrostates(self, macrotraj_type=np.uint8):
        """Derive macrostates for every run from its tree and collect per-run data.

        Fills one entry per run in: ``macrostate_assignment`` (macrostates x
        microstates), ``macrostates_map`` (microstate -> macrostate),
        ``macro_tmat``, ``macrotraj`` (one row per run, dtype
        ``macrotraj_type``), ``n_macrostates``, ``macrostate_feature`` (mean
        ``feature_traj`` per macrostate) and ``macrostate_multi_feature``
        (per-feature mean of ``multi_feature_traj_bool`` per macrostate).
        """
        self.macrostate_feature = []
        self.macrostate_multi_feature = []
        self.macrostate_assignment = []
        self.macrostates_map = []
        self.macro_tmat = []
        self.macrotraj = np.zeros(
            (self.n_runs, self.traj.shape[0]), dtype=macrotraj_type
        )
        self.n_macrostates = []

        for n_i in self._progress(range(self.n_runs), "Assigning macrostates ..."):
            assignment = utils.get_macrostate_assignment_from_tree(self.tree[n_i])
            self.macrostate_assignment.append(assignment)

            states_map = np.zeros(self.n_states, dtype=self.traj.dtype.type)
            macro_idx, micro_idx = np.where(assignment == 1)
            states_map[micro_idx] = macro_idx
            self.macrostates_map.append(states_map)

            self.macro_tmat.append(utils.macro_tmat(self.tmat, assignment, self.pop))
            self.macrotraj[n_i] = utils.translate_traj(self.traj, states_map)
            n_macro = assignment.shape[0]
            self.n_macrostates.append(n_macro)

            run_traj = self.macrotraj[n_i]
            self.macrostate_feature.append(
                [self.feature_traj[run_traj == i].mean() for i in np.arange(n_macro)]
            )
            self.macrostate_multi_feature.append(
                [
                    self.multi_feature_traj_bool[run_traj == i].mean(axis=0)
                    for i in np.arange(n_macro, dtype=int)
                ]
            )

    def gpcca(self, n_macrostates, macrotraj_type=np.uint8):
        """Lump the microstates into ``n_macrostates`` states with GPCCA.

        Macrostates are relabelled by descending mean ``feature_traj``
        (label 0 has the highest mean feature). A single-run mock ``Z`` and
        ``full_pop`` are built so tree/Sankey based methods work as for
        :meth:`mpt`: merges inside a macrostate get q = 0.2, merges between
        macrostates q = 0.9, and the thresholds are set to ``pop_thr = 0``,
        ``q_min = 0.5`` before the tree is built.
        """
        self._reset_cache()
        # NOTE(review): this instance attribute shadows the method, so gpcca()
        # can only be called once per instance. Kept because callers may read
        # the fitted GPCCA object from ``self.gpcca``.
        self.gpcca = gp.GPCCA(self.tmat, method="krylov")
        self.gpcca.optimize(n_macrostates)

        self.n_runs = 1
        self.n_macrostates = [n_macrostates]
        n_macro = n_macrostates

        gma = self.gpcca.macrostate_assignment
        # 1-based GPCCA macrotrajectory (0 = unassigned) and mean feature per state.
        gmt = np.zeros(self.traj.shape, dtype=self.traj.dtype)
        gmf = np.empty(n_macro)
        for i in range(n_macro):
            gmt[np.isin(self.traj, np.where(gma == i)[0])] = i + 1
            gmf[i] = self.feature_traj[gmt == i + 1].mean()

        order = np.argsort(gmf)[::-1]
        new_states = np.empty(n_macro, dtype=macrotraj_type)
        new_states[order] = np.arange(n_macro, dtype=macrotraj_type)
        states_map = np.empty(gma.shape, dtype=macrotraj_type)
        for i in range(n_macro):
            states_map[gma == i] = new_states[i]
        self.macrostates_map = [states_map]

        assignment = np.full((n_macro, states_map.shape[0]), False)
        assignment[states_map, np.arange(states_map.shape[0], dtype=int)] = True
        self.macrostate_assignment = [assignment]
        self.macrostate_feature = [gmf[order]]
        self.macrotraj = np.empty(
            (self.n_runs, self.traj.shape[0]), dtype=macrotraj_type
        )
        self.macrotraj[0] = utils.translate_traj(self.traj, states_map)
        self.macro_tmat = [utils.macro_tmat(self.tmat, assignment, self.pop)]
        # Same per-macrostate data as assign_macrostates() provides.
        self.macrostate_multi_feature = [
            [
                self.multi_feature_traj_bool[self.macrotraj[0] == i].mean(axis=0)
                for i in np.arange(n_macro, dtype=int)
            ]
        ]

        # Create mock Z and mock full_pop for Sankey plot
        # After implementation remove mock Z.npy file in run.py
        self.Z = np.zeros((self.n_runs, self.n_states - 1, 4), dtype=np.float64)
        self.full_pop = np.zeros(
            (self.n_runs, 2 * self.n_states - 1), dtype=np.uint32
        )
        self.full_pop[0, : self.n_states] = self.pop

        last_merged = self.n_states
        merge = 0
        for macrostate in range(n_macro):
            microstates = np.where(states_map == macrostate)[0]
            origin = microstates[0]
            # Chain all microstates of this macrostate together (q = 0.2).
            for target in microstates[1:]:
                intermediate_state = self.n_states + merge
                self.full_pop[0, intermediate_state] = self.full_pop[
                    0, [origin, target]
                ].sum()
                self.Z[0, merge] = (
                    origin,
                    target,
                    0.2,
                    self.full_pop[0, intermediate_state],
                )
                origin = intermediate_state
                merge += 1

            # Join this macrostate with everything merged so far (q = 0.9).
            if macrostate > 0:
                intermediate_state = self.n_states + merge
                target = last_merged
                self.full_pop[0, intermediate_state] = self.full_pop[
                    0, [origin, target]
                ].sum()
                self.Z[0, merge] = (
                    origin,
                    target,
                    0.9,
                    self.full_pop[0, intermediate_state],
                )
                last_merged = intermediate_state
                merge += 1
            else:
                last_merged = origin

        # Thresholds must be in place before the tree is built: build_tree
        # reads them when creating the nodes.
        self.pop_thr = 0
        self.q_min = 0.5
        self.tree  # accessing the property builds and caches the tree

    def __add__(self, other):
        """'+' operator is used to calculate similarity.

        Returns ``(reference, stochastic, utils.similarity(reference, stochastic))``
        where the reference is the operand with exactly one run.

        Raises:
            ValueError: if neither operand has exactly one run.
        """
        if self.n_runs == 1 and other.n_runs >= 1:
            ref, sto = self, other
        elif other.n_runs == 1 and self.n_runs >= 1:
            ref, sto = other, self
        else:
            raise ValueError("The reference lumping must have exactly one run.")
        return ref, sto, utils.similarity(ref, sto)

    @property
    def n_i(self):
        """Index of the run with the longest first implied timescale (0 if one run)."""
        if self._n_i is None:
            if self.n_runs > 1:
                self._n_i = np.argmax(self.timescales[:, 0])
            else:
                self._n_i = 0
        return self._n_i

    @property
    def timescales(self):
        """Implied timescales per run (runs x ntimescales), computed lazily."""
        if self._timescales is None:
            self.calc_timescales()
        return self._timescales

    def calc_timescales(self, ntimescales=3, dtype=np.float32):
        """Compute the first ``ntimescales`` implied timescales of every macrotraj at ``tlag``."""
        self._timescales = np.zeros((self.n_runs, ntimescales), dtype=dtype)
        for i, traj in enumerate(self.macrotraj):
            self._timescales[i, :] = mh.msm.implied_timescales(
                utils.get_multi_state_traj(traj, self.limits),
                [self.tlag],
                ntimescales=ntimescales,
            )[0]

    def save_macrotraj(self, out):
        """Write the macrotraj of run ``n_i`` as text to ``out`` with a descriptive header."""
        header = (
            f"# Created by MPT class\n"
            f"# Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
            f"# Trajectory contains {self.n_macrostates[self.n_i]} states and {self.macrotraj.shape[1]} frames.\n"
            f"# Trajectory index: {self.n_i}\n"
        )
        np.savetxt(out, self.macrotraj[self.n_i], fmt="%.0f", header=header)

    def save_Z(self, out, n_i="all"):
        """Save Z matrix.

        Args:
            out: output path (str or path-like); ``.npy`` is appended if missing.
            n_i: ``"all"`` for every run, an int (Python or numpy) for one run
                (saved with a leading run axis), or an iterable of run indices.

        Raises:
            ValueError: for any other ``n_i``.
        """
        out = os.fspath(out)
        if not out.endswith(".npy"):
            out += ".npy"

        if isinstance(n_i, str) and n_i == "all":
            np.save(out, self.Z)
        elif isinstance(n_i, (int, np.integer)):
            np.save(out, self.Z[n_i : n_i + 1])
        elif isinstance(n_i, Iterable):
            np.save(out, self.Z[n_i])
        else:
            raise ValueError("n_i must be 'all', Iterable or int.")

    def from_Z(self, Z):
        """Load Z matrix (array or ``.npy`` path), rebuild populations and macrostates.

        Raises:
            ValueError: if ``Z`` is neither an array nor an existing path.
        """
        if isinstance(Z, np.ndarray):
            loaded = Z
        elif os.path.exists(Z):
            loaded = np.load(Z)
        else:
            raise ValueError("Z must be a numpy array or a .npy file.")

        # A new lumping invalidates the cached tree and all derived metrics.
        self._reset_cache()
        self.Z = loaded
        self.n_runs = self.Z.shape[0]
        self._estimate_msm()
        self.full_pop = np.zeros(
            (self.n_runs, 2 * self.n_states - 1), dtype=np.uint32
        )
        self.full_pop[:, : self.n_states] = self.pop
        # Column 3 of Z holds the population of each merged node.
        self.full_pop[:, self.n_states :] = self.Z[:, :, 3]

        self.assign_macrostates()

    @property
    def linkage(self):
        """Linkage matrix of run ``n_i`` (via ``utils.Z_to_linkage``), cached."""
        if self._linkage is None:
            self._linkage = utils.Z_to_linkage(self.Z[self.n_i])
        return self._linkage

    @property
    def macro_pop(self):
        """Per run, the summed microstate population of each macrostate."""
        if self._macro_pop is None:
            self._macro_pop = []
            for j, ma in enumerate(self.macrostate_assignment):
                run_pop = np.zeros(ma.shape[0], dtype=self.full_pop.dtype.type)
                micro_pop = self.full_pop[j, : self.n_states]
                for i, members in enumerate(ma):
                    run_pop[i] = micro_pop[members.astype(bool)].sum()
                self._macro_pop.append(run_pop)
        return self._macro_pop

    @property
    def tree(self):
        """Root node of the merge tree of every run, built lazily from Z/full_pop."""
        if self._tree is None:
            self._tree = [self.build_tree(z, pop) for z, pop in zip(self.Z, self.full_pop)]
        return self._tree

    def build_tree(self, Z, full_pop):
        """Build tree using BinaryTreeNode and return root.

        Each row of ``Z`` is ``(state, target_state, q, pop)``; row ``i``
        creates node ``n + i`` (``n`` = number of leaves) with ``state`` and
        ``target_state`` as children. Leaves get their population from
        ``full_pop``; every leaf below a merged node gets its
        ``self.feature`` value. Returns the node created by the last row.
        """
        macrostate_thresholds = (self.pop_thr, self.q_min)
        n = Z.shape[0] + 1
        nodes = {}
        for i, (state, target_state, q, pop) in enumerate(Z):
            state = int(state)
            target_state = int(target_state)
            for name in (state, target_state):
                if name not in nodes:
                    nodes[name] = core.BinaryTreeNode(
                        name,
                        self.tmat,
                        population=full_pop[name],
                        q=q,
                        macrostate_thresholds=macrostate_thresholds,
                    )
            # NOTE(review): merged nodes get no explicit population here;
            # presumably BinaryTreeNode derives it from its children.
            nodes[n + i] = core.BinaryTreeNode(
                n + i, self.tmat, q=q, macrostate_thresholds=macrostate_thresholds
            )
            nodes[n + i].left = nodes[state]
            nodes[n + i].right = nodes[target_state]
            for node in nodes[n + i].leaves:
                node.feature = self.feature[node.name]
        return nodes[n + i]

    @property
    def shannon_entropy(self):
        """Shannon entropy of the macrostate populations, one value per run."""
        if self._shannon_entropy is None:
            self._shannon_entropy = np.zeros(self.n_runs)
            for i, pop in enumerate(self.macro_pop):
                self._shannon_entropy[i] = utils.shannon_entropy(pop)
        return self._shannon_entropy

    @property
    def davies_bouldin_index(self):
        """Davies-Bouldin score of multi_feature_traj labelled by each run's macrotraj."""
        if self._davies_bouldin_index is None:
            self._davies_bouldin_index = np.zeros(self.n_runs)
            for i in range(self.n_runs):
                self._davies_bouldin_index[i] = davies_bouldin_score(
                    self.multi_feature_traj, self.macrotraj[i]
                )
        return self._davies_bouldin_index

    @property
    def gmrq(self):
        """GMRQ of the macrostate transition matrices (via ``utils.gmrq``)."""
        if self._gmrq is None:
            self._gmrq = utils.gmrq(self.macro_tmat)
        return self._gmrq

    @property
    def reference(self):
        """Single-run MPT lumping with the default kernel and identical settings."""
        if self._reference is None:
            k = kernel_module.MPTKernel()
            self._reference = MPT(
                self.traj,
                self.tlag,
                self.multi_feature_traj,
                contact_threshold=self.contact_threshold,
                macrostate_thresholds=(self.pop_thr, self.q_min),
                limits=self.limits,
                quiet=True,
                frame_length=self.frame_length,
            )
            self._reference.mpt(k)
        return self._reference

    @property
    def traj(self):
        """The microstate trajectory - 0-based."""
        return self._traj

    @traj.setter
    def traj(self, value):
        """Store a 0-based copy of ``value`` in the smallest sufficient unsigned dtype.

        A trajectory whose minimum is 1 is shifted down by one (with a
        warning) without modifying the caller's array.

        Raises:
            ValueError: if the states are not numbered 0..max without gaps.
        """
        value = np.asarray(value)
        if value.min() == 1:
            # Out-of-place: `value -= 1` would silently alter the caller's array.
            value = value - 1
            warnings.warn("1-based trajectory was shifted to 0-based.")
        max_state = int(value.max())
        # Continuous numbering means exactly max_state + 1 distinct states.
        if np.unique(value).shape[0] != max_state + 1:
            raise ValueError("The state numbering in the trajectory is not continuous")
        if max_state < 2**7:
            traj_type = np.uint8
        elif max_state < 2**15:
            traj_type = np.uint16
        else:
            traj_type = np.uint32
        self._traj = value.astype(traj_type)

    def print_rel(self):
        """Print run-0 metrics (timescale, GMRQ, DBI, entropy) relative to the reference."""
        ref = self.reference
        ratios = [
            ("Implied Timescale: ", self.timescales[0, 0] / ref.timescales[0, 0]),
            ("GMRQ: ", self.gmrq[0] / ref.gmrq[0]),
            # davies_bouldin_index is a property, not a method.
            ("DBI: ", self.davies_bouldin_index[0] / ref.davies_bouldin_index[0]),
            ("H: ", self.shannon_entropy[0] / ref.shannon_entropy[0]),
        ]
        for label, value in ratios:
            print(label + f"{value:.2f}")

    @property
    def topology_file(self):
        """The topology_file property; raises ValueError if unset."""
        if self._topology_file is None:
            raise ValueError("No topology file set.")
        return self._topology_file

    @topology_file.setter
    def topology_file(self, value):
        if not os.path.isfile(value):
            raise FileNotFoundError(f"No such file: {value}")
        self._topology_file = value

    @property
    def xtc_trajectory_file(self):
        """The xtc_trajectory_file property; raises ValueError if unset."""
        if self._xtc_trajectory_file is None:
            raise ValueError("No xtc trajectory file set.")
        return self._xtc_trajectory_file

    @xtc_trajectory_file.setter
    def xtc_trajectory_file(self, value):
        if not os.path.isfile(value):
            raise FileNotFoundError(f"No such file: {value}")
        self._xtc_trajectory_file = value

    @property
    def rmsd(self):
        """RMSD from ``utils.calc_rmsd``; also sets ``self.mean_frames``."""
        if self._rmsd is None:
            self._rmsd, self.mean_frames = utils.calc_rmsd(self, quiet=self.quiet)
        return self._rmsd

    @property
    def frame_length(self):
        """The frame_length property. Frame length in ns."""
        return self._frame_length

    @frame_length.setter
    def frame_length(self, value):
        self._frame_length = value

    @property
    def macro_micro_feature(self):
        """Assign macrostate feature values to corresponding microstates (states x runs)."""
        if self._macro_micro_feature is None:
            self._macro_micro_feature = np.zeros(
                (self.n_states, self.n_runs), dtype=self.feature_traj.dtype.type
            )
            for i, (ma, mf) in enumerate(
                zip(self.macrostate_assignment, self.macrostate_feature)
            ):
                for j, members in enumerate(ma.astype(bool)):
                    self._macro_micro_feature[members, i] = mf[j]
        return self._macro_micro_feature

    def save_rmsd(self, out):
        """Save the RMSD array with ``np.save``."""
        np.save(out, self.rmsd)

    def load_rmsd(self, f_name):
        """Load a previously saved RMSD array (``mean_frames`` is not restored)."""
        self._rmsd = np.load(f_name)

    def write_pdbs(self, out):
        """Write PDBs via ``utils.write_pdbs`` using log(RMSD) and ``mean_frames``."""
        utils.write_pdbs(
            out,
            np.log(self.rmsd),
            self.topology_file,
            self.xtc_trajectory_file,
            self.mean_frames,
        )

    def rmsd_sharpness(self):
        """Population-weighted mean of the per-macrostate mean RMSD (run ``n_i``)."""
        pop = self.macro_pop[self.n_i]
        return (self.rmsd.mean(axis=1) * pop).sum() / pop.sum()

    def draw_random_frames_indices(self, out=None, n=20):
        """
        Draw n random frames for each macrostate

        out (str): Path to directory where to save one ``NN.ndx`` index file
            per macrostate; if falsy, the indices are returned instead.
        n (int): number of frames to draw randomly (without replacement)

        Indices are multiplied by ``xtc_stride`` when it is set.
        """
        drawn_frames = np.empty((self.n_macrostates[self.n_i], n), dtype=int)
        for state in np.arange(self.n_macrostates[self.n_i]):
            frames_in_state = np.where(self.macrotraj[self.n_i] == state)[0]
            drawn_frames[state] = np.random.choice(
                frames_in_state, size=n, replace=False
            )
        if self.xtc_stride is not None:
            drawn_frames *= self.xtc_stride
        if not out:
            return drawn_frames
        Path(out).mkdir(parents=True, exist_ok=True)
        for s, frames in enumerate(drawn_frames):
            np.savetxt(
                os.path.join(out, f"{s + 1:02d}.ndx"),
                frames,
                fmt="%.0f",
                header="[frames]",
            )

    def draw_random_frames(self, out, n=20):
        """
        Draw n random frames for each macrostate

        out (str): Path to directory where to save the pdb files (created if
            missing)
        n (int): number of frames to draw randomly (without replacement)

        As in ``draw_random_frames_indices``, frame indices are multiplied by
        ``xtc_stride`` (when set) before reading the xtc file.
        """
        Path(out).mkdir(parents=True, exist_ok=True)
        stride = 1 if self.xtc_stride is None else self.xtc_stride
        for state in np.arange(self.n_macrostates[self.n_i]):
            frames_in_state = np.where(self.macrotraj[self.n_i] == state)[0]
            drawn_frames = np.random.choice(frames_in_state, size=n, replace=False)
            for i, frame in enumerate(drawn_frames):
                f = md.load_xtc(
                    self.xtc_trajectory_file,
                    top=self.topology_file,
                    frame=int(frame) * stride,
                )
                f.save_pdb(os.path.join(out, f"S{state}_{i:02d}.pdb"))

    def get_best_defined_contacts(self, n=3):
        """Return, per macrostate of run ``n_i``, the indices of the ``n``
        features with the lowest variance within that macrostate
        (array of shape macrostates x n)."""
        contacts = np.zeros((self.n_macrostates[self.n_i], n), dtype=int)
        for i in range(self.n_macrostates[self.n_i]):
            contacts[i] = np.argsort(
                np.var(
                    self.multi_feature_traj[self.macrotraj[self.n_i] == i],
                    axis=0,
                )
            )[:n]
        return contacts

    def get_least_moving_residues(self, contact_index_file, n=3):
        """Map each macrostate's ``n`` lowest-variance contacts to the unique
        residue indices listed for them in ``contact_index_file``."""
        contact_indices = np.loadtxt(contact_index_file, dtype=int)
        contacts = self.get_best_defined_contacts(n)
        return [np.unique(contact_indices[c].flatten()) for c in contacts]

    def write_least_moving_residues(self, contact_index_file, out, n=3):
        """Write one line of residue indices per macrostate to ``out``;
        writes an empty file when ``contact_index_file == "none"``."""
        if contact_index_file != "none":
            least_moving_residues = self.get_least_moving_residues(
                contact_index_file, n=n
            )
            with open(out, "w") as f:
                for residues in least_moving_residues:
                    f.write(f"{' '.join(residues.astype(str))}\n")
        else:
            with open(out, "w") as f:
                f.write("")

    ### PLOT METHODS #########################################################

    def plot(self, out: str, scale=1, offset=0):
        """Plot dendrogram"""
        plot.plot_tree(
            self.tree[self.n_i],
            self.macrostate_assignment[self.n_i],
            out,
            scale=scale,
            offset=offset,
        )

    def plot_implied_timescales(self, out, use_ref=True, scale=1):
        """
        out: File to write plot
        use_ref: If True, the reference lumping's macrotraj is plotted as
            reference; otherwise the microstate trajectory is used
        scale: scaling factor for plot
        """
        if use_ref:
            ref_traj = self.reference.macrotraj[0]
        else:
            ref_traj = self.traj

        macrotraj = utils.get_multi_state_traj(self.macrotraj[self.n_i], self.limits)

        # Lag-time step of about 1 ns, expressed in frames (at least 1).
        dtlag = max(1, int(1 / self.frame_length))
        plot.plot_implied_timescales(
            [ref_traj, macrotraj],
            np.arange(1, 4.5 * self.tlag + dtlag, dtlag, dtype=int),
            out,
            frame_length=self.frame_length,
            first_ref=True,
            scale=scale,
            use_ref=use_ref,
            ntimescales=self.timescales.shape[1],
        )

    def plot_macro_feature(self, out, ref=None):
        """
        Plot histogram of feature distribution.

        Uses ``self.macro_micro_feature`` (N microstates x R runs, holding the
        feature value of the respective macrostate).
        out (str): file to save the plot
        ref (list[tuple]): list of
            - macrostate_assignment
            - macrostate_feature
            - color
            - label
            of the clusterings that should be shown explicitly. Defaults to
            ``self.reference``.
        """
        plot.plot_macro_feature(
            self.macro_micro_feature, out, self.reference if ref is None else ref
        )

    def plot_rmsd(self, out, helices=None):
        plot.plot_rmsd(self.rmsd, self.macro_pop[self.n_i], helices, out)

    def plot_delta_rmsd(self, out, helices=None):
        plot.plot_delta_rmsd(self.rmsd, self.macro_pop[self.n_i], helices, out)

    def plot_contact_rep(self, cluster_file, out, scale=1):
        plot.contact_rep(
            self.multi_feature_traj,
            cluster_file,
            self.macrotraj[self.n_i],
            out,
            utils.get_grid_format(self.n_macrostates[self.n_i]),
            scale=scale,
        )

    def plot_relative_implied_timescales(self, out):
        plot.relative_implied_timescales(self, out)

    def plot_ck_test(self, out):
        plot.chapman_kolmogorov(self, out, self.frame_length)

    def plot_state_network(self, out):
        plot.state_network(self, out)

    def plot_stochastic_state_similarity(self, out):
        plot.evaluate_stochastic_clustering(self, self.reference, out)

    def plot_transition_matrix(self, out):
        plot.transition_matrix(self.macro_tmat[self.n_i], out)

    def plot_transition_time(self, out):
        plot.transition_time(
            self.macro_tmat[self.n_i],
            out,
            tlag=self.tlag,
            frame_length=self.frame_length,
        )

    def plot_graph(self, out, u=0, f=0):
        draw_knetwork(
            self.macrotraj[self.n_i], self.tlag, self.feature_traj, out, u=u, f=f
        )

    def plot_sankey(self, out, ax=None, scale=1):
        plot.plot_sankey(self, self.reference, out, ax=ax, scale=scale)

    def plot_macrotraj(self, out, row_length=0.2):
        plot.plot_state_trajectory(
            self.macrotraj[self.n_i],
            out,
            row_length=row_length,
            frame_length=self.frame_length,
        )