Coverage for MPP/utils.py: 93%

206 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-22 11:26 +0200

1""" 

2utils.py 

3======== 

4 

5Utilities for MPP. 

6""" 

7 

8import numpy as np 

9from numpy.typing import NDArray 

10import mdtraj as md 

11from tqdm import tqdm 

12 

13 

def translate_trajectory(
    trajectory: NDArray[np.int_], map: NDArray[np.int_]
) -> NDArray[np.int_]:
    """
    Transform trajectory to other state names.

    trajectory (NDArray[np.int_]): original state trajectory
    map (NDArray[np.int_]): index is original state, value at that position is
        new value (assumed to be one-dimensional)

    returns translated trajectory with the same shape as ``trajectory``, stored
    in the smallest of uint8 / uint16 / uint32 selected from ``map.max()``.
    Entries of ``trajectory`` that are not a valid index into ``map`` are
    translated to 0.
    """
    largest_label = map.max()
    if largest_label < 2**8:
        macrostate_trajectory_type = np.uint8
    elif largest_label < 2**16:
        macrostate_trajectory_type = np.uint16
    else:
        macrostate_trajectory_type = np.uint32

    macrostate_trajectory = np.zeros(trajectory.shape, dtype=macrostate_trajectory_type)
    # One vectorized table lookup instead of one np.isin pass per macrostate.
    # States outside the table keep the initial value 0.
    known = (trajectory >= 0) & (trajectory < map.shape[0])
    macrostate_trajectory[known] = map[trajectory[known]]
    return macrostate_trajectory

40 

41 

def macrostate_tmat(tmat, macrostate_assignment, pop):
    """
    Transform a transition matrix from microstates to macrostates.

    tmat: microstate transition matrix
    macrostate_assignment: one row per macrostate; each row is used as an
        index selecting the microstates of that macrostate
    pop: microstate populations used to weight the rows of ``tmat``

    returns the population-weighted macrostate matrix in which every column
    has been divided by its column sum
    """
    n_macro = macrostate_assignment.shape[0]
    coarse = np.zeros((n_macro, n_macro), dtype=tmat.dtype.type)
    for row, source in enumerate(macrostate_assignment):
        # Rows of the source macrostate and their per-microstate weights.
        source_rows = tmat[source]
        weights = np.expand_dims(pop[source], -1)
        for col, target in enumerate(macrostate_assignment):
            coarse[row, col] = (source_rows[:, target] * weights).sum()
    # NOTE(review): normalisation is per column (axis=0); merge_states in this
    # module looks like it uses a row convention - confirm this is intended.
    column_totals = coarse.sum(axis=0)
    return coarse / column_totals

52 

53 

def get_grid_format(n):
    """
    Return grid dimensions ``(x, y)`` with ``x * y >= n``.

    ``y`` is the integer part of ``sqrt(n)``; ``x`` is ``y`` for perfect
    squares, otherwise ``y + 1`` if that already gives enough cells and
    ``y + 2`` if not.
    """
    root = np.sqrt(n)
    rows = int(root)
    if not rows < root:
        # Perfect square: a rows x rows grid fits exactly.
        return rows, rows
    cols = rows + 1 if (rows + 1) * rows >= n else rows + 2
    return cols, rows

63 

64 

def gmrq(tmat, n_eigenvalues=3):
    """
    Generalized matrix Rayleigh quotient for each transition matrix.

    tmat: iterable of square matrices
    n_eigenvalues: number of leading eigenvalues to sum (default 3)

    returns a 1-D float array holding, per matrix, the sum of the
    ``n_eigenvalues`` largest eigenvalues (real parts)
    """
    q = np.zeros(len(tmat))
    for i, t in enumerate(tmat):
        # The eigen-solver does not return eigenvalues in any guaranteed order,
        # so sort explicitly before taking the leading ones. Only the real part
        # is kept because the result array is real-valued.
        eigenvalues = np.sort(np.linalg.eigvals(t).real)[::-1]
        q[i] = eigenvalues[:n_eigenvalues].sum()
    return q

72 

73 

def Z_to_linkage(Z):
    """
    Convert a Z matrix into a linkage table.

    The first three columns of ``Z`` are copied. Every reference to the
    cluster id created by merge ``i`` (``i + Z.shape[0] + 1``) is replaced by
    the value currently in the second column of row ``i``, and finally both
    state columns are shifted by +1.
    """
    n_merges = Z.shape[0]
    table = Z[:, :3].copy()
    members = table[:, :2]  # view: writes below go straight into ``table``
    for step in range(n_merges):
        cluster_id = step + n_merges + 1
        members[members == cluster_id] = table[step, 1]
    members += 1
    return table

81 

82 

def linkage_to_Z(linkage, pop):
    """
    Convert a 1-indexed linkage table into a Z matrix plus populations.

    linkage: sequence of rows whose first three columns are copied into Z;
        the first two (state) columns are shifted by -1
    pop: array with the populations of the original states (its dtype is
        reused for the returned populations)

    returns ``(Z, full_pop)``. For merge ``i`` the new cluster gets index
    ``n_states + i``; later references to the second state of that merge are
    renamed to the new index. ``full_pop`` holds the original populations
    followed by the population of every merged cluster, and ``Z[:, 3]`` is set
    to those merged populations.
    """
    linkage = np.array(linkage)
    n_merges = linkage.shape[0]
    n_states = n_merges + 1
    Z = np.zeros((n_merges, 4))
    Z[:, :3] = linkage[:, :3]
    Z[:, :2] -= 1

    full_pop = np.zeros(2 * n_states - 1, dtype=pop.dtype.type)
    full_pop[:n_states] = pop
    # Handle every merge, including the last one, inside the loop so that a
    # linkage with a single row works as well.
    for i in range(n_merges):
        new_state = n_states + i
        old_state = Z[i, 1]
        full_pop[new_state] = full_pop[Z[i, :2].astype(int)].sum()
        # Rename the surviving state in all later merges (view into Z; empty
        # for the final merge).
        later_pairs = Z[i + 1 :, :2]
        later_pairs[later_pairs == old_state] = new_state
    Z[:, 3] = full_pop[n_states:]
    return Z, full_pop

100 

101 

def merge_states(tmat, states, new_state, full_pop, reset_states=True):
    """
    Merge ``states`` into ``new_state``; ``tmat`` and ``full_pop`` are modified in place.

    The population of ``new_state`` becomes the sum over ``states``. Its row
    in ``tmat`` is the population-weighted average of the rows of ``states``
    and its column is the sum of their columns. With ``reset_states`` the rows
    and columns of ``states`` are zeroed afterwards.

    returns the (modified) ``tmat`` and ``full_pop``
    """
    full_pop[new_state] = full_pop[states].sum()
    merged_pop = full_pop[new_state]

    # The row must be written before the column: the column sum below also
    # covers the freshly written row of new_state.
    weights = full_pop[states, np.newaxis]
    tmat[new_state] = (tmat[states] * weights).sum(axis=0) / merged_pop
    tmat[:, new_state] = tmat[:, states].sum(axis=1)

    if not reset_states:
        return tmat, full_pop
    tmat[:, states] = 0
    tmat[states, :] = 0
    return tmat, full_pop

114 

115 

def calc_full_tmat(tmat, pop, Z):
    """
    Calculate the full tmat for a given Z matrix.

    tmat: microstate transition matrix (its first dimension defines n_states)
    pop: microstate populations (stored as uint32 in the result)
    Z: Z matrix or stack of Z matrices; a 2-D Z is treated as a single run

    returns ``(full_tmat, full_pop)`` with shapes
    ``(n_runs, 2 * n_states - 1, 2 * n_states - 1)`` and
    ``(n_runs, 2 * n_states - 1)``, filled by applying ``merge_states`` for
    every row of every run.
    """
    # Ensure that Z is 3D
    if Z.ndim == 2:
        Z = Z.reshape((1, *Z.shape))

    # Initialize full_tmat and full_pop
    n_states = tmat.shape[0]
    full_dim = 2 * n_states - 1
    n_runs = Z.shape[0]
    # Zero-initialised on purpose: merge_states reads rows and columns of
    # clusters that have not been created yet, so uninitialised memory would
    # take part in the arithmetic (and stay in the result if Z is incomplete).
    full_tmat = np.zeros((n_runs, full_dim, full_dim))
    full_pop = np.zeros((n_runs, full_dim), dtype=np.uint32)

    full_tmat[:, :n_states, :n_states] = tmat
    full_pop[:, :n_states] = pop

    for run, z in enumerate(Z):
        for i, (origin, target) in enumerate(z[:, :2].astype(int)):
            full_tmat[run], full_pop[run] = merge_states(
                full_tmat[run],
                [origin, target],
                n_states + i,
                full_pop[run],
                reset_states=False,
            )
    return full_tmat, full_pop

142 

143 

def Z_to_mask(Z):
    """
    Calculate the mask for each lumping step.

    Z (Nx4): Z matrix

    returns a boolean array of shape ``(N, 2 * (N + 1) - 1)``. Row 0 marks the
    first ``N + 1`` states; each following row is the previous one with the two
    states merged in the preceding step switched off and the cluster created
    by that step switched on.
    """
    n_merges = Z.shape[0]
    n_states = n_merges + 1
    mask = np.zeros((n_merges, 2 * n_states - 1), dtype=bool)

    # Running set of states that exist at the current step.
    active = np.zeros(2 * n_states - 1, dtype=bool)
    active[:n_states] = True
    mask[0] = active
    merged_pairs = Z[:-1, :2].astype(int)
    for step, (first, second) in enumerate(merged_pairs):
        active[[first, second]] = False
        active[step + n_states] = True
        mask[step + 1] = active
    return mask

158 

159 

def get_macrostate_assignment_from_tree(tree):
    """
    Build a boolean macrostate-assignment matrix from the leaves of ``tree``.

    Every leaf provides ``name`` (used as the microstate column index) and
    ``assigned_macrostate`` (an object with ``name`` and ``feature``).
    Macrostates are ordered by descending ``feature``.

    returns an array of shape ``(n_macrostates, n_leaves)`` where entry
    ``[r, c]`` is True if the leaf named ``c`` belongs to the macrostate with
    rank ``r``
    """
    leaves = tuple(tree.leaves)
    n_micro = len(leaves)
    leaf_macro_names = [leaf.assigned_macrostate.name for leaf in leaves]
    unique_macrostates = {leaf.assigned_macrostate for leaf in leaves}

    name_and_feature = np.array([(ms.name, ms.feature) for ms in unique_macrostates])
    by_feature_desc = np.argsort(name_and_feature[:, 1])[::-1]
    # Dict to translate from n+i numbering to actual macrostate numbers.
    rank_of = {
        name: rank for rank, name in enumerate(name_and_feature[by_feature_desc, 0])
    }

    # First fill the matrix with columns in leaf order ...
    leaf_columns = np.arange(n_micro)
    assignment = np.full((len(unique_macrostates), n_micro), False)
    assignment[[rank_of[name] for name in leaf_macro_names], leaf_columns] = True

    # ... then reorder the columns so that column c belongs to the leaf named c.
    column_of_microstate = np.zeros(n_micro, dtype=int)
    column_of_microstate[[leaf.name for leaf in leaves]] = leaf_columns
    return assignment[:, column_of_microstate]

176 

177 

178def similarity(ref, sto): 

179 """Return similarity of two clusterings""" 

180 # Similarity matrix 

181 S = np.zeros((3, ref.n_macrostates[0], sto.n_runs)) 

182 

183 for n_i in range(sto.n_runs): 

184 ref_ma = ref.macrostate_assignment[0].astype(bool) 

185 sto_ma = sto.macrostate_assignment[n_i].astype(bool) 

186 for i in range(ref.n_macrostates[0]): 

187 for j in range(sto.n_macrostates[n_i]): 

188 intersect = ( 

189 np.logical_and(ref_ma[i], sto_ma[j]) 

190 * ref.full_pop[0, : ref.n_states] 

191 ).sum() 

192 union = ( 

193 np.logical_or(ref_ma[i], sto_ma[j]) 

194 * ref.full_pop[0, : ref.n_states] 

195 ).sum() 

196 # union 

197 S[0, i, n_i] = max(S[0, i, n_i], intersect / union) 

198 # reference 

199 S[1, i, n_i] = max( 

200 S[1, i, n_i], 

201 intersect / (ref_ma[i] * ref.full_pop[0, : ref.n_states]).sum(), 

202 ) 

203 # lumping 

204 S[2, i, n_i] = max( 

205 S[2, i, n_i], 

206 intersect / (sto_ma[j] * ref.full_pop[0, : ref.n_states]).sum(), 

207 ) 

208 return S 

209 

210 

def shannon_entropy(p):
    """
    Return the Shannon entropy of ``p`` normalised by ``log(len(p))``.

    p: array of non-negative weights; it is normalised to sum to one first

    Zero entries contribute nothing (the usual ``0 * log(0) = 0`` convention)
    instead of turning the result into NaN. A distribution with fewer than two
    outcomes has entropy 0.0 (the normalisation ``log(1)`` would otherwise
    give 0 / 0).
    """
    p = np.asarray(p, dtype=float)
    p = p / np.sum(p, axis=0)
    n_outcomes = p.shape[0]
    if n_outcomes < 2:
        return 0.0
    occupied = p[p > 0]
    return -(occupied * np.log(occupied)).sum() / np.log(n_outcomes)

214 

215 

def weighting_function(dq):
    """
    Return weights for the values in ``dq``.

    A single value is weighted with ``exp(-dq)``. Otherwise a Gaussian
    ``exp(-dq**2 / (2 * var))`` is used, where ``var`` is the variance of
    ``dq``.
    """
    is_single_value = dq.shape[0] == 1
    if is_single_value:
        return np.exp(-dq)
    variance = np.var(dq)
    exponent = -(dq**2) / (2 * variance)
    return np.exp(exponent)

222 

223 

224### RMSD ##################################################################### 

225 

226 

def load_trajectory(
    topfile, trajectoryfile, atom_selection="all", frames=None, stride=None
):
    """
    Load an xtc trajectory restricted to ``atom_selection``.

    topfile: topology file passed to ``md.load_topology``
    trajectoryfile: xtc file passed to ``md.load_xtc``
    atom_selection: selection string evaluated on the topology
    frames: if given, only these frames are loaded (one by one) and joined;
        ``stride`` is not used in that case
    stride: passed to ``md.load_xtc`` when the whole trajectory is loaded
    """
    print("Loading trajectory...")
    top = md.load_topology(topfile)

    def read_xtc(**options):
        # Shared call; the selection is evaluated for every load.
        return md.load_xtc(
            trajectoryfile,
            top=top,
            atom_indices=top.select(atom_selection),
            **options,
        )

    if frames is not None:
        return md.join([read_xtc(frame=frame) for frame in frames])
    return read_xtc(stride=stride)

251 

252 

def load_mean_frames(topfile, trajectoryfile, mean_frames, dt=0.1):
    """
    Reload the frames in ``mean_frames`` from the xtc file with the full topology.

    topfile: topology file passed to ``md.load_topology``
    trajectoryfile: xtc file the frames are read from
    mean_frames: iterable of frames; the first entry of each ``time`` attribute
        identifies the frame
    dt: time between two consecutive frames of the xtc file (same unit as
        ``time``), used to turn a time into a frame index

    returns the joined trajectory of the reloaded frames
    """
    top = md.load_topology(topfile)
    # Divide first and round afterwards: truncating the time before the
    # division drops its fractional part and yields a float frame index.
    frame_indices = [int(round(float(frame.time[0]) / dt)) for frame in mean_frames]
    return md.join(
        [md.load_xtc(trajectoryfile, top=top, frame=index) for index in frame_indices]
    )

260 

261 

def find_mean_frame(trajectory):
    """Return the frame of ``trajectory`` with the smallest mean RMSD to all frames."""
    rmsd_per_frame = [estimate_rmsd(candidate, trajectory) for candidate in trajectory]
    best = np.argmin(np.array(rmsd_per_frame))
    return trajectory[best]

266 

267 

def estimate_rmsd(frame, trajectory):
    """Return the mean of ``md.rmsd(trajectory, frame)``, i.e. the mean RMSD to ``frame``."""
    return np.mean(md.rmsd(trajectory, frame))

274 

275 

def align_trajectory_to_reference(trajectory, reference):
    """
    Align each frame of ``trajectory`` to ``reference`` using the Kabsch algorithm.

    Parameters:
    - trajectory: numpy array of shape (N, n_atoms, 3) where N is the number of frames.
    - reference: numpy array of shape (1, n_atoms, 3) (or (n_atoms, 3)) with the reference points.

    Returns:
    - aligned_trajectory: numpy array of shape (N, n_atoms, 3) and the dtype of
      ``trajectory``; every frame is centred, optimally rotated onto the centred
      reference and moved to the reference centroid.
    """
    # Nothing to align; also avoids a decomposition on an empty stack.
    if trajectory.shape[0] == 0:
        return np.zeros_like(trajectory)

    # Reference as (n_atoms, 3), centred on its centroid.
    reference_frame = reference.squeeze()
    reference_centroid = np.mean(reference_frame, axis=0)
    centered_reference = reference_frame - reference_centroid

    # Centre every frame on its own centroid.
    centered_frames = trajectory - np.mean(trajectory, axis=1, keepdims=True)

    # Covariance matrix of each frame with the reference, shape (N, 3, 3).
    H = centered_frames.transpose(0, 2, 1) @ centered_reference
    U, _, Vt = np.linalg.svd(H)

    # Coordinates are row vectors, so the optimal rotation is applied from the
    # right and equals U @ Vt (the transpose of the column-vector form V @ U.T).
    # Where that would be a reflection (determinant -1), flip the direction
    # belonging to the smallest singular value.
    improper = np.linalg.det(U @ Vt) < 0
    U[improper, :, -1] *= -1
    rotation = U @ Vt

    # Rotate and move into the reference coordinate system.
    aligned_trajectory = centered_frames @ rotation + reference_centroid
    return aligned_trajectory.astype(trajectory.dtype, copy=False)

332 

333 

def calc_var(ref, trajectory):
    """
    Calculate the per-atom mean squared deviation from ``ref``.

    The trajectory is aligned to ``ref`` first; the squared distance of every
    atom to its reference position is then averaged over the frames (no
    square root is taken).
    """
    displacement = align_trajectory_to_reference(trajectory, ref) - ref
    squared_distance = (displacement**2).sum(axis=2)
    return squared_distance.mean(axis=0)

339 

340 

def opt_num_batches(n):
    """
    Return the number of batches used for ``n`` frames: ``cbrt(n**2 / 2)``, at least 1.

    n: number of frames (Python or NumPy number)
    """
    # float() first: squaring a fixed-width NumPy integer can wrap around
    # silently for large populations.
    n_batches = int(np.cbrt(float(n) ** 2 / 2))
    # Never return 0 batches; a population of one frame would otherwise leave
    # the caller with nothing to iterate over.
    return max(1, n_batches)

343 

344 

def calc_rmsd(lumping, quiet=False):
    """
    Find a mean frame for every macrostate of ``lumping`` and the spread around it.

    The C-alpha atoms of the xtc trajectory are loaded. For each macrostate of
    run ``lumping.n_i`` its frames are split into interleaved batches, the mean
    frame of every batch is determined, and the mean frame of those batch
    means is taken as the macrostate's mean frame.

    quiet: suppress the progress output

    returns ``(rmsd, mean_frames)`` where ``rmsd[j]`` is the result of
    ``calc_var`` for macrostate ``j`` (one value per atom) and ``mean_frames``
    is the list of mean frames.
    """
    traj = load_trajectory(
        lumping.topology_file,
        lumping.xtc_trajectory_file,
        atom_selection="name CA",
        stride=lumping.xtc_stride,
    )
    run = lumping.n_i
    n_macro = lumping.n_macrostates[run]

    mean_frames = []
    rmsd = np.empty([n_macro, traj.n_atoms])
    for j in range(n_macro):
        if not quiet:
            print(f"Process macrostate {j}")
        state_frames = traj[lumping.macrostate_trajectory[run] == j]
        n_batches = opt_num_batches(lumping.macrostate_population[run][j])
        batch_ids = range(n_batches) if quiet else tqdm(range(n_batches))
        # Batch b consists of every n_batches-th frame starting at frame b.
        batch_means = [
            find_mean_frame(state_frames[batch::n_batches]) for batch in batch_ids
        ]
        mean_frames.append(find_mean_frame(md.join(batch_means)))
        rmsd[j] = calc_var(mean_frames[j].xyz, state_frames.xyz)
    return rmsd, mean_frames

366 

367 

def find_state_lengths(arr):
    """
    Run-length encode a one-dimensional state sequence.

    arr: one-dimensional sequence of states

    returns ``(states, lengths)``: the state of every run of consecutive equal
    values and the number of elements in that run. Empty input gives two empty
    arrays.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        return arr.copy(), np.array([], dtype=int)

    # Positions at which the value differs from its predecessor start a new run.
    change_points = np.flatnonzero(arr[1:] != arr[:-1]) + 1
    run_starts = np.concatenate(([0], change_points))
    run_ends = np.concatenate((change_points, [arr.shape[0]]))
    return arr[run_starts], run_ends - run_starts

395 

396 

def get_multi_state_trajectory(trajectories: np.ndarray, limits: np.ndarray):
    """
    Split a concatenated trajectory into its individual trajectories.

    trajectories: the concatenated trajectory
    limits: length of every individual trajectory, in order; ``None`` means
        there is nothing to split

    returns ``trajectories`` unchanged if ``limits`` is None, otherwise a list
    with one slice of ``trajectories`` per entry of ``limits``
    """
    if limits is None:
        return trajectories
    trajectory_collection = []
    current_position = 0
    for limit in limits:
        end_position = current_position + limit
        # Cast both bounds: limits may be floats (e.g. read from a text file)
        # and float slice bounds are rejected.
        trajectory_collection.append(
            trajectories[int(current_position) : int(end_position)]
        )
        current_position = end_position
    return trajectory_collection