Coverage for MPT/utils.py: 93%

213 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-11 11:01 +0200

1""" 

2utils.py 

3======== 

4 

5Utilities for MPT. 

6""" 

7 

8import os 

9import numpy as np 

10from numba import njit 

11from itertools import combinations 

12from typing import List 

13from numpy.typing import NDArray 

14import scipy as scy 

15import mdtraj as md 

16from tqdm import tqdm 

17 

18 

def translate_traj(traj: NDArray[np.int_], map: NDArray[np.int_]) -> NDArray[np.int_]:
    """
    Transform a state trajectory to other state names.

    traj (NDArray[np.int_]): original state trajectory (any shape)
    map (NDArray[np.int_]): index is original state, value at that position is
        the new value

    Returns the translated trajectory with the same shape as ``traj``. Its
    dtype is the smallest of uint8 / uint16 / uint32 that holds ``map.max()``.
    Entries of ``traj`` outside ``0 .. len(map) - 1`` are set to 0.
    """
    if map.max() < 2**8:
        macrotraj_type = np.uint8
    elif map.max() < 2**16:
        macrotraj_type = np.uint16
    else:
        macrotraj_type = np.uint32

    traj = np.asarray(traj)
    macrotraj = np.zeros(traj.shape, dtype=macrotraj_type)
    # A single vectorised table lookup replaces one np.isin pass per target
    # state. States without an entry in ``map`` keep the value 0.
    valid = (traj >= 0) & (traj < map.shape[0])
    macrotraj[valid] = map[traj[valid]]
    return macrotraj

41 

42 

def macro_tmat(tmat, macrostate_assignment, pop):
    """
    Transform a transition matrix from microstates to macrostates.

    tmat: square microstate transition matrix.
    macrostate_assignment: indexed along its first axis by macrostate; each
        entry selects that macrostate's microstates when used to index the
        rows/columns of ``tmat`` and ``pop`` (e.g. a boolean mask row).
    pop: microstate populations; each row of ``tmat`` is weighted by the
        population of its microstate.

    Entry (i, j) accumulates ``pop[a] * tmat[a, b]`` over microstates ``a`` of
    macrostate i and ``b`` of macrostate j. The matrix is then divided by its
    column sums.

    Returns an (n_macrostates, n_macrostates) array.

    NOTE(review): the weighting uses row (source) populations, but the
    normalisation is over axis 0, so columns rather than rows sum to 1.
    ``merge_states`` forms new rows as population-weighted means of rows. If
    ``tmat`` is row-stochastic, row normalisation may have been intended.
    Confirm before relying on the orientation.
    """
    n_macro = macrostate_assignment.shape[0]
    lumped = np.zeros((n_macro, n_macro), dtype=tmat.dtype.type)
    for row, source in enumerate(macrostate_assignment):
        weighted_rows = tmat[source] * np.expand_dims(pop[source], -1)
        for col, target in enumerate(macrostate_assignment):
            lumped[row, col] = weighted_rows[:, target].sum()
    column_totals = lumped.sum(axis=0)
    return lumped / column_totals

53 

54 

def get_grid_format(n):
    """
    Return grid dimensions ``(x, y)`` with ``x * y >= n`` for ``n`` panels.

    ``y`` is ``floor(sqrt(n))``. ``x`` equals ``y`` when ``n`` is a perfect
    square. Otherwise it is ``y + 1`` if that already gives enough cells,
    and ``y + 2`` if not.
    """
    root = np.sqrt(n)
    rows = int(root)
    if rows >= root:
        # Perfect square: a square grid fits exactly.
        return rows, rows
    if rows * (rows + 1) >= n:
        return rows + 1, rows
    return rows + 2, rows

64 

65 

def gmrq(tmat, k=3):
    """
    Return the generalized matrix Rayleigh quotient (GMRQ) of each matrix.

    tmat: sequence/stack of square transition matrices, iterated along the
        first axis.
    k: number of leading eigenvalues to sum (default 3).

    For each matrix, the real parts of its eigenvalues are sorted in
    descending order and the ``k`` largest are summed. ``np.linalg.eigvals``
    returns eigenvalues in no particular order, so the explicit sort is
    needed. Imaginary parts (possible for non-reversible matrices) are
    discarded explicitly rather than by an implicit complex-to-float cast.

    Returns a float array with one score per matrix.
    """
    q = np.zeros(len(tmat))
    for i, t in enumerate(tmat):
        eigenvalues = np.linalg.eigvals(t)
        q[i] = np.sort(eigenvalues.real)[::-1][:k].sum()
    return q

73 

74 

def Z_to_linkage(Z):
    """
    Convert a SciPy-style Z matrix into a 1-based linkage table.

    Z: (n_merges, >=3) array. Columns 0 and 1 hold the ids of the merged
        nodes; row ``i`` creates node ``n_merges + 1 + i``. Column 2 is the
        merge height.

    Returns an (n_merges, 3) copy. Every reference to a node created by an
    earlier merge is replaced by that merge's surviving label (column 1 of
    the row that created it), and all labels are shifted to start at 1.
    """
    table = Z[:, :3].copy()
    pairs = table[:, :2]  # a view: writes below update ``table``
    first_new_id = Z.shape[0] + 1
    for step in range(table.shape[0]):
        survivor = table[step, 1]
        pairs[pairs == first_new_id + step] = survivor
    pairs += 1
    return table

82 

83 

def linkage_to_Z(linkage, pop):
    """
    Convert a 1-based linkage table into a SciPy-style Z matrix plus the
    populations of all tree nodes.

    linkage: (n_states - 1, >=3) array-like. Columns 0 and 1 are the 1-based
        labels of the merged pair, where a merged cluster keeps the label in
        column 1. Column 2 is the merge height. The label rewriting mirrors
        ``Z_to_linkage``.
    pop: (n_states,) numpy array of leaf populations.

    Returns ``(Z, full_pop)``:
    - ``Z``: (n_states - 1, 4) float array with 0-based node ids. Merge ``i``
      creates node ``n_states + i``, and column 3 holds its population.
    - ``full_pop``: (2 * n_states - 1,) populations of leaves followed by
      merged nodes, with ``pop``'s dtype.
    """
    linkage = np.array(linkage)
    n_states = linkage.shape[0] + 1
    Z = np.zeros((linkage.shape[0], 4))
    Z[:, :3] = linkage[:, :3]
    Z[:, :2] -= 1  # 1-based labels -> 0-based node ids

    full_pop = np.zeros(2 * n_states - 1, dtype=pop.dtype.type)
    full_pop[:n_states] = pop
    for i in range(Z.shape[0] - 1):
        new_state = n_states + i
        old_state = Z[i, 1]
        full_pop[new_state] = full_pop[[int(Z[i, 0]), int(old_state)]].sum()
        # Later rows still refer to this cluster by its surviving label.
        # Rename those references to the new node id.
        later = Z[i + 1 :, :2]
        later[later == old_state] = new_state
    if Z.shape[0]:
        # The final merge creates the root. Read its pair from the last row
        # directly; the loop above never processes that row.
        full_pop[-1] = full_pop[Z[-1, :2].astype(int)].sum()
    Z[:, 3] = full_pop[n_states:]
    return Z, full_pop

101 

102 

def merge_states(tmat, states, new_state, full_pop, reset_states=True):
    """
    Merge ``states`` into ``new_state`` in place and return the updated arrays.

    tmat: square transition matrix; modified in place.
    states: indices of the states to merge (used for fancy indexing).
    new_state: index of the row/column that receives the merged state.
    full_pop: populations; modified in place. ``full_pop[new_state]`` becomes
        the summed population of ``states``.
    reset_states: if True, zero the rows and columns of ``states`` afterwards.

    The new row is the population-weighted mean of the merged rows. The new
    column is the plain sum of the merged columns. The column is written
    after the row, so it also sums that new row's entries.

    Returns ``(tmat, full_pop)`` (the same objects that were passed in).
    """
    full_pop[new_state] = full_pop[states].sum()
    weighted_rows = tmat[states] * full_pop[states, np.newaxis]
    tmat[new_state] = weighted_rows.sum(axis=0) / full_pop[new_state]
    tmat[:, new_state] = tmat[:, states].sum(axis=1)
    if not reset_states:
        return tmat, full_pop
    tmat[:, states] = 0
    tmat[states, :] = 0
    return tmat, full_pop

115 

116 

def calc_full_tmat(tmat, pop, Z):
    """
    Calculate the full transition matrix and populations for given Z matrices.

    tmat: (n_states, n_states) microstate transition matrix.
    pop: (n_states,) microstate populations.
    Z: a single (n_states - 1, >=2) merge matrix, or a stack
        (n_runs, n_states - 1, >=2) of them. Columns 0 and 1 are the node
        indices merged in each row; merge ``i`` creates node ``n_states + i``.

    Returns ``(full_tmat, full_pop)`` with shapes
    ``(n_runs, 2 * n_states - 1, 2 * n_states - 1)`` and
    ``(n_runs, 2 * n_states - 1)``. Entries ``n_states + i`` belong to the
    node created by merge ``i`` and are filled via ``merge_states``.
    Integer populations are stored as uint32. Floating-point populations
    keep a float dtype instead of being truncated.
    """
    # Treat a single Z matrix as a stack of one run.
    if Z.ndim == 2:
        Z = Z[np.newaxis]

    n_states = tmat.shape[0]
    full_dim = 2 * n_states - 1
    n_runs = Z.shape[0]
    # Zero-initialise: merge_states reads rows/columns of nodes that have not
    # been created yet before they are overwritten, so keep them finite.
    full_tmat = np.zeros((n_runs, full_dim, full_dim))
    if np.issubdtype(np.asarray(pop).dtype, np.floating):
        pop_dtype = np.float64
    else:
        pop_dtype = np.uint32
    full_pop = np.zeros((n_runs, full_dim), dtype=pop_dtype)

    full_tmat[:, :n_states, :n_states] = tmat
    full_pop[:, :n_states] = pop

    for run, z in enumerate(Z):
        for i, (origin, target) in enumerate(z[:, :2].astype(int)):
            full_tmat[run], full_pop[run] = merge_states(
                full_tmat[run],
                [origin, target],
                n_states + i,
                full_pop[run],
                reset_states=False,
            )
    return full_tmat, full_pop

143 

144 

def Z_to_mask(Z):
    """
    Calculate which nodes are present at each lumping step.

    Z (Nx4): Z matrix. Columns 0 and 1 hold the nodes merged in each row,
        and row k creates node N + 1 + k.

    Returns a boolean (N, 2N + 1) array. Row 0 marks the N + 1 original
    states. Row k + 1 is row k with the two nodes merged in row k removed and
    node N + 1 + k added. The last row of Z is therefore never applied.
    """
    n_merges = Z.shape[0]
    n_leaves = n_merges + 1
    mask = np.zeros((n_merges, 2 * n_leaves - 1), dtype=bool)
    mask[0, :n_leaves] = True
    pairs = Z[:-1, :2].astype(int)
    for step in range(len(pairs)):
        row = mask[step].copy()
        row[pairs[step]] = False
        row[n_leaves + step] = True
        mask[step + 1] = row
    return mask

159 

160 

def get_macrostate_assignment_from_tree(tree):
    """
    Build the boolean macrostate-assignment matrix from the leaves of ``tree``.

    Each element of ``tree.leaves`` has a ``name`` (used as its column index,
    so the names must cover 0 .. n_leaves - 1) and an
    ``assigned_macrostate`` with ``name`` and ``feature``. Macrostates are
    ordered by descending ``feature``; row r corresponds to the r-th
    macrostate in that order.

    Returns an array of shape (n_macrostates, n_leaves) where
    ``result[r, c]`` is True iff microstate ``c`` belongs to macrostate ``r``.

    NOTE(review): macrostate names and features are packed into one numpy
    array and so share a dtype (e.g. int names become floats, or
    everything becomes strings if names are strings). Confirm the types.
    """
    leaf_macro_names = [leaf.assigned_macrostate.name for leaf in tree.leaves]
    distinct_macros = {leaf.assigned_macrostate for leaf in tree.leaves}
    name_feature = np.array([(macro.name, macro.feature) for macro in distinct_macros])
    by_feature_desc = np.argsort(name_feature[:, 1])[::-1]
    # Map macrostate name -> row index in descending-feature order.
    row_of = dict(
        (name, row) for row, name in enumerate(name_feature[by_feature_desc, 0])
    )
    n_leaves = len(leaf_macro_names)
    assignment = np.full((len(distinct_macros), n_leaves), False)
    assignment[[row_of[name] for name in leaf_macro_names], np.arange(n_leaves)] = True
    # Columns are currently in leaf order; reorder them by leaf name.
    column_order = np.zeros(n_leaves, dtype=int)
    column_order[[leaf.name for leaf in tree.leaves]] = np.arange(n_leaves)
    return assignment[:, column_order]

177 

178 

def similarity(ref, sto):
    """
    Return similarity of two clusterings.

    ref: reference clustering. Only its first run is used
        (``n_macrostates[0]``, ``macrostate_assignment[0]``, ``full_pop[0]``).
    sto: clustering with ``n_runs`` runs compared against the reference.

    Overlaps are measured as population, using
    ``ref.full_pop[0, :ref.n_states]`` as microstate weights for both
    clusterings. For every reference macrostate ``i`` and run ``n_i``, the
    best value over all macrostates ``j`` of that run is kept for three scores:

    - ``S[0]``: weight(i & j) / weight(i | j)  (population-weighted Jaccard)
    - ``S[1]``: weight(i & j) / weight(i)      (relative to the reference)
    - ``S[2]``: weight(i & j) / weight(j)      (relative to the clustering)

    Returns an array of shape (3, ref.n_macrostates[0], sto.n_runs).
    """
    n_ref = ref.n_macrostates[0]
    S = np.zeros((3, n_ref, sto.n_runs))
    ref_ma = ref.macrostate_assignment[0].astype(bool)
    weights = ref.full_pop[0, : ref.n_states]
    for run in range(sto.n_runs):
        sto_ma = sto.macrostate_assignment[run].astype(bool)
        for i in range(n_ref):
            ref_weight = (ref_ma[i] * weights).sum()
            for j in range(sto.n_macrostates[run]):
                both = (np.logical_and(ref_ma[i], sto_ma[j]) * weights).sum()
                either = (np.logical_or(ref_ma[i], sto_ma[j]) * weights).sum()
                sto_weight = (sto_ma[j] * weights).sum()
                S[0, i, run] = max(S[0, i, run], both / either)
                S[1, i, run] = max(S[1, i, run], both / ref_weight)
                S[2, i, run] = max(S[2, i, run], both / sto_weight)
    return S

210 

211 

def shannon_entropy(p):
    """
    Return the normalised Shannon entropy of a (possibly unnormalised) distribution.

    p: 1-D array-like of non-negative weights; rescaled to sum to one.

    The entropy ``-sum(p * log p)`` is divided by ``log(len(p))``, so a
    uniform distribution over all entries gives 1.0. Zero entries contribute
    0 (the standard ``0 * log 0 = 0`` convention) instead of producing NaN.
    A single-entry distribution returns 0.0 instead of ``0 / 0``.
    """
    p = np.asarray(p, dtype=float)
    p = p / p.sum()
    n = p.shape[0]
    if n < 2:
        return 0.0
    nonzero = p[p > 0]
    return -(nonzero * np.log(nonzero)).sum() / np.log(n)

215 

216 

def weighting_function(dq):
    """
    Return weights for the values in ``dq``.

    dq: 1-D numpy array.

    With exactly one value the weight is ``exp(-dq)``. Otherwise each value
    gets ``exp(-dq**2 / (2 * var(dq)))``: a zero-centred Gaussian kernel whose
    variance is the variance of ``dq``. If all values are equal the variance
    is 0. Zero entries then give nan and non-zero entries give 0, with
    numpy runtime warnings.
    """
    if dq.shape[0] != 1:
        variance = np.var(dq)
        return np.exp(-(dq**2) / (2 * variance))
    return np.exp(-dq)

223 

224 

225### RMSD ##################################################################### 

226 

227 

def load_traj(topfile, trajfile, atom_selection="all", frames=None, stride=None):
    """
    Load an XTC trajectory restricted to the atoms of ``atom_selection``.

    topfile: topology file passed to ``md.load_topology``.
    trajfile: XTC file passed to ``md.load_xtc``.
    atom_selection: MDTraj selection string; only matching atoms are loaded.
    frames: optional iterable of frame indices. If given, each frame is
        loaded separately and the results are joined; ``stride`` is then
        not used.
    stride: stride used when the whole trajectory is loaded.

    Prints "Loading trajectory..." before reading anything.
    """
    print("Loading trajectory...")
    top = md.load_topology(topfile)
    atom_indices = top.select(atom_selection)
    if frames is not None:
        single_frames = [
            md.load_xtc(trajfile, top=top, atom_indices=atom_indices, frame=frame)
            for frame in frames
        ]
        return md.join(single_frames)
    return md.load_xtc(trajfile, top=top, atom_indices=atom_indices, stride=stride)

247 

248 

def load_mean_frames(topfile, trajfile, mean_frames, dt=0.1):
    """
    Reload the given frames, with all atoms of the topology, from ``trajfile``.

    topfile: topology file passed to ``md.load_topology``.
    trajfile: XTC trajectory file.
    mean_frames: iterable of trajectories; the frame index of each is derived
        from its first time stamp as ``round(time / dt)``.
    dt: time between consecutive frames of ``trajfile``, in the units of
        ``frame.time``.

    NOTE(review): assumes ``trajfile`` starts at time 0 and that ``dt``
    matches its frame spacing. Confirm against the input data.

    Returns the joined trajectory, one frame per entry of ``mean_frames``.
    """
    top = md.load_topology(topfile)
    # Divide before rounding to the nearest integer. Truncating the time
    # stamp first drops sub-unit times, and ``load_xtc`` needs an integer
    # frame index.
    idxs = [int(round(float(frame.time[0]) / dt)) for frame in mean_frames]
    return md.join([md.load_xtc(trajfile, top=top, frame=idx) for idx in idxs])

254 

255 

def find_mean_frame(traj):
    """
    Return the frame of ``traj`` with the lowest mean RMSD to all its frames.

    Every frame is scored against the whole trajectory via ``estimate_rmsd``,
    so the cost grows quadratically with the number of frames. On ties the
    first such frame is returned.
    """
    mean_rmsd_per_frame = np.fromiter(
        (estimate_rmsd(candidate, traj) for candidate in traj), dtype=float
    )
    best = int(np.argmin(mean_rmsd_per_frame))
    return traj[best]

260 

261 

def estimate_rmsd(frame, traj):
    """
    Return the mean of ``md.rmsd(traj, frame)``: the average RMSD of the
    frames of ``traj`` relative to ``frame`` (MDTraj uses the first frame of
    ``frame`` as reference by default).
    """
    per_frame_rmsd = md.rmsd(traj, frame)
    return np.mean(per_frame_rmsd)

268 

269 

def align_trajectory_to_reference(trajectory, reference):
    """
    Align each frame of the trajectory onto the reference using the Kabsch algorithm.

    Parameters:
    - trajectory: numpy array of shape (N, n_atoms, 3), where N is the number
      of frames.
    - reference: numpy array of shape (1, n_atoms, 3) or (n_atoms, 3).

    Returns:
    - aligned_trajectory: array with the shape and dtype of ``trajectory``.
      Each frame is centred, optimally rotated onto the centred reference,
      and moved to the reference centroid.
    """
    # Drop the leading singleton axis. reshape (not squeeze) also keeps a
    # single-atom reference 2-D.
    reference_frame = np.reshape(reference, np.shape(reference)[-2:])
    reference_centroid = np.mean(reference_frame, axis=0)
    centered_reference = reference_frame - reference_centroid

    aligned_trajectory = np.zeros_like(trajectory)
    for i, current_frame in enumerate(trajectory):
        centered_frame = current_frame - np.mean(current_frame, axis=0)

        # Cross-covariance H = P^T Q of mobile points P and reference points Q
        # (points stored as rows).
        H = centered_frame.T @ centered_reference
        U, _, Vt = np.linalg.svd(H)

        # With row-vector coordinates the optimal rotation is applied as
        # P @ R with R = U @ Vt. (The column-vector form Vt.T @ U.T would have
        # to be applied as P @ R.T.)
        R = U @ Vt
        if np.linalg.det(R) < 0:
            # Enforce a proper rotation: flip the axis belonging to the
            # smallest singular value.
            Vt[-1, :] *= -1
            R = U @ Vt

        aligned_trajectory[i] = centered_frame @ R + reference_centroid

    return aligned_trajectory

326 

327 

def calc_var(ref, traj):
    """
    Return the per-atom mean squared deviation of ``traj`` from ``ref``.

    ref: reference coordinates, broadcast against every aligned frame
        (presumably shape (1, n_atoms, 3) — confirm with callers).
    traj: coordinates of shape (N, n_atoms, 3).

    Every frame is first superimposed onto ``ref`` with
    ``align_trajectory_to_reference``. The squared distance to the reference
    is summed over x/y/z and averaged over frames. No square root is taken,
    so this is a variance-like quantity rather than an RMSD.

    Returns an (n_atoms,) array.
    """
    aligned = align_trajectory_to_reference(traj, ref)
    squared_distance = np.sum(np.square(aligned - ref), axis=2)
    return np.mean(squared_distance, axis=0)

333 

334 

def opt_num_batches(n):
    """
    Return the number of batches used to split ``n`` frames.

    The value is ``floor(cbrt(n**2 / 2))``, but at least 1. ``n`` is
    converted to float first, so numpy integer inputs (e.g. uint32
    populations) cannot overflow in ``n**2``. The lower bound means callers
    looping over ``range(opt_num_batches(n))`` always get at least one batch
    (n = 1 would otherwise give 0).
    """
    return max(1, int(np.cbrt(float(n) ** 2 / 2)))

337 

338 

def calc_rmsd(mpt, quiet=False):
    """
    Find a representative frame and per-atom deviations for every macrostate
    of run ``mpt.n_i``.

    Atoms matching "name CA" are loaded from ``mpt.xtc_trajectory_file``
    (stride ``mpt.xtc_stride``). For macrostate j, its frames are split into
    ``opt_num_batches(mpt.macro_pop[mpt.n_i][j])`` interleaved batches. The
    mean frame of each batch is found, and the mean frame of those batch means
    becomes the macrostate's representative. ``rmsd[j]`` is ``calc_var`` of
    all macrostate frames against that representative.

    quiet: if False, print per-macrostate progress and show a tqdm bar.

    Returns ``(rmsd, mean_frames)``: an (n_macrostates, n_atoms) array and the
    list of representative frames.
    """
    traj = load_traj(
        mpt.topology_file,
        mpt.xtc_trajectory_file,
        atom_selection="name CA",
        stride=mpt.xtc_stride,
    )
    run = mpt.n_i
    n_macro = mpt.n_macrostates[run]
    rmsd = np.empty([n_macro, traj.n_atoms])
    mean_frames = []
    for state in range(n_macro):
        if not quiet:
            print(f"Process macrostate {state}")
        state_frames = traj[mpt.macrotraj[run] == state]
        n_batches = opt_num_batches(mpt.macro_pop[run][state])
        batch_ids = range(n_batches) if quiet else tqdm(range(n_batches))
        batch_means = [
            find_mean_frame(state_frames[b::n_batches]) for b in batch_ids
        ]
        representative = find_mean_frame(md.join(batch_means))
        mean_frames.append(representative)
        rmsd[state] = calc_var(representative.xyz, state_frames.xyz)
    return rmsd, mean_frames

360 

361 

def find_state_lengths(arr):
    """
    Run-length encode a 1-D state sequence.

    arr: 1-D array-like of states.

    Returns ``(states, lengths)``: the state of each run of identical
    consecutive values and the length of that run. An empty input returns
    two empty arrays instead of raising IndexError.
    """
    arr = np.asarray(arr)
    if arr.shape[0] == 0:
        return arr[:0].copy(), np.zeros(0, dtype=int)
    # Index of the first element of every run.
    run_starts = np.concatenate(([0], np.flatnonzero(arr[1:] != arr[:-1]) + 1))
    lengths = np.diff(np.append(run_starts, arr.shape[0]))
    return arr[run_starts], lengths

389 

390 

def get_multi_state_traj(trajs: np.ndarray, limits: np.ndarray):
    """
    Split a trajectory that contains several concatenated trajectories.

    trajs: the concatenated trajectory (anything supporting slicing).
    limits: lengths of the consecutive parts, or None.

    Returns ``trajs`` unchanged when ``limits`` is None. Otherwise returns a
    list with one slice per entry of ``limits``; frames after the last part
    are not returned.
    """
    if limits is None:
        return trajs
    pieces = []
    start = 0
    for length in limits:
        stop = start + length
        pieces.append(trajs[start : int(stop)])
        start = stop
    return pieces

401 

402 

def fnc_from_multi_feature_traj(multi_feature_traj, cutoff=0.45):
    """
    Return, per frame, the fraction of features at or below ``cutoff``.

    multi_feature_traj: array-like of shape (n_frames, n_features). The name
        suggests contact distances, giving a fraction of native contacts.
        The units are not visible here (presumably nm — confirm).
    cutoff: a feature counts when its value is ``<= cutoff`` (default 0.45).

    Returns an (n_frames,) float array with values in [0, 1].
    """
    return (np.asarray(multi_feature_traj) <= cutoff).mean(axis=1)