Coverage for MPT/utils.py: 93%
213 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 11:01 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 11:01 +0200
1"""
2utils.py
3========
5Utilities for MPT.
6"""
8import os
9import numpy as np
10from numba import njit
11from itertools import combinations
12from typing import List
13from numpy.typing import NDArray
14import scipy as scy
15import mdtraj as md
16from tqdm import tqdm
def translate_traj(traj: NDArray[np.int_], map: NDArray[np.int_]) -> NDArray[np.int_]:
    """
    Relabel the states of a trajectory.

    traj (NDArray[np.int_]): trajectory written with the original state labels
    map (NDArray[np.int_]): lookup table; the entry at position k is the new
        label of original state k

    returns the relabelled trajectory with the shape of ``traj``; it is stored
    as uint8, uint16 or uint32, whichever is the first that holds ``map.max()``.
    Frames whose state does not appear as a position in ``map`` stay 0.
    """
    largest_label = map.max()
    for bits, candidate in ((8, np.uint8), (16, np.uint16)):
        if largest_label < 2**bits:
            label_dtype = candidate
            break
    else:
        label_dtype = np.uint32

    translated = np.zeros(traj.shape, dtype=label_dtype)
    for new_label in np.unique(map):
        # Original states that are mapped onto this new label.
        old_labels = np.nonzero(map == new_label)[0]
        translated[np.isin(traj, old_labels)] = new_label
    return translated
def macro_tmat(tmat, macrostate_assignment, pop):
    """
    Coarse-grain a microstate transition matrix to macrostates.

    tmat: square microstate transition matrix
    macrostate_assignment: one row per macrostate; each row is used to index
        the microstates of that macrostate (presumably a boolean mask - verify
        against the caller)
    pop: microstate populations, used to weight the rows of ``tmat``

    returns the macrostate matrix whose entry (i, j) is the population-weighted
    sum of ``tmat`` over rows in macrostate i and columns in macrostate j,
    with every column divided by its column sum
    """
    n_macro = macrostate_assignment.shape[0]
    coarse = np.zeros((n_macro, n_macro), dtype=tmat.dtype.type)
    for row, members in enumerate(macrostate_assignment):
        member_rows = tmat[members]
        member_weights = pop[members][..., np.newaxis]
        for col, targets in enumerate(macrostate_assignment):
            coarse[row, col] = (member_rows[:, targets] * member_weights).sum()
    column_totals = coarse.sum(axis=0)
    return coarse / column_totals
def get_grid_format(n):
    """
    Choose a near-square grid that offers at least ``n`` cells.

    n: number of cells needed (non-negative)

    returns (x, y) where y is the integer part of sqrt(n) and x is the
    smallest count with x * y >= n. For n == 0 the empty grid (0, 0) is
    returned.

    raises ValueError if n is negative
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if n == 0:
        # Without this guard the adjustment below would yield x == -1.
        return 0, 0

    sqrt = np.sqrt(n)
    y = int(sqrt)
    x = y
    if x < sqrt:
        # n is not a perfect square: y + 2 always suffices, y + 1 sometimes.
        x += 2
        if (x - 1) * y >= n:
            x -= 1
    return x, y
def gmrq(tmat, n_eigenvalues=3):
    """
    Generalized matrix Rayleigh quotient of a stack of matrices.

    tmat: sequence of square matrices (e.g. an array of shape (n, k, k))
    n_eigenvalues: number of leading eigenvalues that are summed (default 3)

    returns a 1-D float array with one score per matrix: the sum of the real
    parts of its ``n_eigenvalues`` largest eigenvalues
    """
    scores = np.zeros(len(tmat))
    for i, matrix in enumerate(tmat):
        eigenvalues = np.linalg.eigvals(matrix)
        # NumPy returns eigenvalues in no particular order, so they have to
        # be sorted before the leading ones can be picked.
        leading = np.sort(eigenvalues.real)[::-1][:n_eigenvalues]
        scores[i] = leading.sum()
    return scores
def Z_to_linkage(Z):
    """
    Convert a Z matrix into a 1-based linkage table.

    Z: array whose first three columns are the two merged cluster ids and a
        third value that is copied unchanged; merge k creates the id
        ``Z.shape[0] + 1 + k``

    returns a copy of the first three columns in which every reference to a
    merged cluster is replaced by the label in the second column of the row
    that created it, and all labels are shifted by +1
    """
    n_merges = Z.shape[0]
    linkage = Z[:, :3].copy()
    labels = linkage[:, :2]  # view: writes go straight into ``linkage``
    for step in range(n_merges):
        merged_id = step + n_merges + 1
        labels[labels == merged_id] = linkage[step, 1]
    labels += 1
    return linkage
def linkage_to_Z(linkage, pop):
    """
    Convert a 1-based linkage table into a Z matrix with populations.

    linkage: sequence of rows ``[a, b, value, ...]`` with 1-based labels; after
        a merge the merged cluster is referred to by the label ``b`` in all
        later rows
    pop: array with the population of each of the ``len(linkage) + 1``
        original states

    returns (Z, full_pop):
        Z: float array of shape (n_merges, 4); columns 0-1 are 0-based cluster
            ids in which merge k is called ``n_states + k``, column 2 is copied
            from ``linkage`` and column 3 is the population of the new cluster
        full_pop: populations of the original states followed by those of the
            merged clusters, with the dtype of ``pop``
    """
    linkage = np.array(linkage)
    pop = np.asarray(pop)
    n_merges = linkage.shape[0]
    n_states = n_merges + 1

    Z = np.zeros((n_merges, 4))
    Z[:, :3] = linkage[:, :3]
    Z[:, :2] -= 1

    full_pop = np.zeros(2 * n_states - 1, dtype=pop.dtype.type)
    full_pop[:n_states] = pop
    # All rows, including the last one, are handled by the same loop; for the
    # last row the relabelling slice is simply empty. This also covers a
    # linkage with a single merge.
    for i in range(n_merges):
        new_state = n_states + i
        kept_label = Z[i, 1]
        full_pop[new_state] = full_pop[Z[i, :2].astype(int)].sum()
        # Later rows refer to the merged cluster by ``kept_label``.
        later_ids = Z[i + 1 :, :2]
        later_ids[later_ids == kept_label] = new_state
    Z[:, 3] = full_pop[n_states:]
    return Z, full_pop
def merge_states(tmat, states, new_state, full_pop, reset_states=True):
    """
    Lump ``states`` into ``new_state``; ``tmat`` and ``full_pop`` are modified
    in place and also returned.

    tmat: square matrix that already has a row and column for ``new_state``
    states: indices of the states that are merged
    new_state: index that receives the merged state
    full_pop: 1-D populations, indexed like ``tmat``
    reset_states: if True, rows and columns of ``states`` are zeroed afterwards

    returns (tmat, full_pop)
    """
    full_pop[new_state] = full_pop[states].sum()

    # Row of the merged state: population-weighted mean of the member rows.
    member_weights = full_pop[states][..., np.newaxis]
    weighted_rows = tmat[states] * member_weights
    tmat[new_state] = weighted_rows.sum(axis=0) / full_pop[new_state]

    # Column of the merged state: sum of the member columns. This runs after
    # the row update so that the new row contributes its own entry too.
    tmat[:, new_state] = tmat[:, states].sum(axis=1)

    if not reset_states:
        return tmat, full_pop
    tmat[:, states] = 0
    tmat[states, :] = 0
    return tmat, full_pop
def calc_full_tmat(tmat, pop, Z):
    """
    Calculate the full transition matrix and populations for given Z matrices.

    tmat: square matrix of the original states
    pop: populations of the original states
    Z: one Z matrix (2-D) or a stack of them (3-D, one per run); the first two
        columns of every row name the two states merged in that step

    returns (full_tmat, full_pop):
        full_tmat: float array of shape (n_runs, 2n - 1, 2n - 1) holding the
            original states followed by one merged state per step
        full_pop: uint32 array of shape (n_runs, 2n - 1) with the matching
            populations (``pop`` is cast to uint32 on assignment)
    """
    # A single Z matrix is handled as a batch of one run.
    if Z.ndim == 2:
        Z = Z[np.newaxis]

    n_states = tmat.shape[0]
    n_runs = Z.shape[0]
    full_dim = 2 * n_states - 1

    full_tmat = np.empty((n_runs, full_dim, full_dim))
    full_tmat[:, :n_states, :n_states] = tmat
    full_pop = np.empty((n_runs, full_dim), dtype=np.uint32)
    full_pop[:, :n_states] = pop

    for run in range(n_runs):
        merge_pairs = Z[run][:, :2].astype(int)
        for step, pair in enumerate(merge_pairs):
            origin, target = pair
            # Original states are kept (reset_states=False) so that every
            # lumping step stays available in the full matrix.
            full_tmat[run], full_pop[run] = merge_states(
                full_tmat[run],
                [origin, target],
                n_states + step,
                full_pop[run],
                reset_states=False,
            )
    return full_tmat, full_pop
def Z_to_mask(Z):
    """
    Calculate which states exist at each lumping step.

    Z (Nx4): Z matrix; the first two columns of a row are the ids merged in
        that step, and step k creates the id ``N + 1 + k``

    returns a boolean array of shape (N, 2N + 1): row 0 marks the N + 1
    original states, and each following row is the previous one with the two
    merged ids switched off and the newly created id switched on
    """
    n_merges = Z.shape[0]
    n_leaves = n_merges + 1
    mask = np.zeros((n_merges, 2 * n_leaves - 1), dtype=bool)
    mask[0, :n_leaves] = True

    # The last merge is not applied: there is no row left to hold its result.
    merged_pairs = Z[:-1, :2].astype(int)
    for step in range(1, n_merges):
        mask[step] = mask[step - 1]
        mask[step, merged_pairs[step - 1]] = False
        mask[step, n_leaves + step - 1] = True
    return mask
def get_macrostate_assignment_from_tree(tree):
    """
    Build the macrostate assignment matrix from the leaves of a tree.

    tree: object with a ``leaves`` sequence; every leaf has a ``name`` that is
        used as its microstate index and an ``assigned_macrostate`` with the
        attributes ``name`` and ``feature``

    returns a boolean array of shape (n_macrostates, n_leaves); rows are the
    macrostates ordered by descending ``feature``, columns are ordered by the
    leaf ``name``, and an entry is True where the leaf belongs to the
    macrostate
    """
    leaves = list(tree.leaves)
    n_leaves = len(leaves)
    leaf_macro_names = [leaf.assigned_macrostate.name for leaf in leaves]
    macrostates = {leaf.assigned_macrostate for leaf in leaves}

    # (name, feature) of every macrostate, then sort by descending feature.
    name_and_feature = np.array([(m.name, m.feature) for m in macrostates])
    by_feature_desc = np.argsort(name_and_feature[:, 1])[::-1]

    # Translate a macrostate name into its row in the result.
    row_of_name = {}
    for row, macro_name in enumerate(name_and_feature[by_feature_desc, 0]):
        row_of_name[macro_name] = row

    assignment = np.full((len(macrostates), n_leaves), False)
    rows = [row_of_name[macro_name] for macro_name in leaf_macro_names]
    assignment[rows, np.arange(n_leaves)] = True

    # Columns are in leaf order so far; reorder them by microstate name.
    leaf_position = np.zeros(n_leaves, dtype=int)
    leaf_position[[leaf.name for leaf in leaves]] = np.arange(n_leaves)
    return assignment[:, leaf_position]
def similarity(ref, sto):
    """
    Return the similarity of a reference clustering to every run of ``sto``.

    ref: object with ``n_macrostates``, ``macrostate_assignment``, ``full_pop``
        and ``n_states``; only entry 0 of the indexed attributes is used
    sto: object with ``n_runs``, ``n_macrostates`` and
        ``macrostate_assignment``, indexed by run

    returns S of shape (3, ref.n_macrostates[0], sto.n_runs). For reference
    macrostate i and run r the entries are the maximum over the macrostates j
    of that run of the population-weighted overlap divided by
        S[0]: the weight of the union of i and j,
        S[1]: the weight of i,
        S[2]: the weight of j.
    All weights come from ``ref.full_pop[0, :ref.n_states]``.
    """
    n_ref = ref.n_macrostates[0]
    S = np.zeros((3, n_ref, sto.n_runs))

    weights = ref.full_pop[0, : ref.n_states]
    ref_ma = ref.macrostate_assignment[0].astype(bool)

    for run in range(sto.n_runs):
        sto_ma = sto.macrostate_assignment[run].astype(bool)
        for i in range(n_ref):
            ref_members = ref_ma[i]
            ref_weight = (ref_members * weights).sum()
            for j in range(sto.n_macrostates[run]):
                sto_members = sto_ma[j]
                intersect = (np.logical_and(ref_members, sto_members) * weights).sum()
                union = (np.logical_or(ref_members, sto_members) * weights).sum()
                sto_weight = (sto_members * weights).sum()
                # Order: normalised by union, by reference, by clustering.
                scores = (
                    intersect / union,
                    intersect / ref_weight,
                    intersect / sto_weight,
                )
                for kind, score in enumerate(scores):
                    S[kind, i, run] = max(S[kind, i, run], score)
    return S
def shannon_entropy(p):
    """
    Normalised Shannon entropy of a weight vector.

    p: 1-D array of non-negative weights; it is normalised to sum to one

    returns the entropy divided by ``log(len(p))``, so a uniform distribution
    gives 1. Zero weights contribute nothing (0 * log 0 is taken as 0), and a
    single-entry input returns 0.0 because there is nothing to be uncertain
    about and the normalisation would be 0 / 0.
    """
    p = np.asarray(p, dtype=float)
    p = p / np.sum(p, axis=0)
    if p.shape[0] < 2:
        return 0.0
    # ``!= 0`` rather than ``> 0`` so that NaN input still propagates.
    contributing = p[p != 0]
    return -(contributing * np.log(contributing)).sum() / np.log(p.shape[0])
def weighting_function(dq):
    """
    Turn differences ``dq`` into weights.

    dq: array of differences

    returns ``exp(-dq)`` if ``dq`` has a single entry along its first axis;
    otherwise ``exp(-dq**2 / (2 * var(dq)))``, i.e. a Gaussian whose width is
    the spread of ``dq`` itself
    """
    if dq.shape[0] != 1:
        variance = np.var(dq)
        return np.exp(-np.square(dq) / (2 * variance))
    # The variance of a single value is zero, so the Gaussian form is unusable.
    return np.exp(-dq)
225### RMSD #####################################################################
def load_traj(topfile, trajfile, atom_selection="all", frames=None, stride=None):
    """
    Load an xtc trajectory, restricted to a selection of atoms.

    topfile: topology file passed to ``md.load_topology``
    trajfile: xtc file passed to ``md.load_xtc``
    atom_selection: selection string evaluated with ``top.select``
    frames: optional iterable of frame indices; if given, each of these frames
        is loaded on its own and the results are joined
    stride: passed to ``md.load_xtc``; only used when ``frames`` is None

    returns the loaded trajectory
    """
    print("Loading trajectory...")
    top = md.load_topology(topfile)

    if frames is not None:
        single_frames = []
        for frame in frames:
            single_frames.append(
                md.load_xtc(
                    trajfile,
                    top=top,
                    atom_indices=top.select(atom_selection),
                    frame=frame,
                )
            )
        return md.join(single_frames)

    return md.load_xtc(
        trajfile, top=top, atom_indices=top.select(atom_selection), stride=stride
    )
def load_mean_frames(topfile, trajfile, mean_frames, dt=0.1):
    """
    Reload the given frames from the xtc file with the full topology.

    topfile: topology file passed to ``md.load_topology``
    trajfile: xtc file passed to ``md.load_xtc``
    mean_frames: iterable of frames; the first entry of each frame's ``time``
        identifies it
    dt: time between two saved frames, in the units of ``time``

    returns the joined trajectory of the reloaded frames
    """
    top = md.load_topology(topfile)
    # ``frame=`` must be an integer index. Plain division yields floats that
    # can sit just next to the intended integer (3 / 0.1 ==
    # 30.000000000000004), so round to the nearest index.
    frame_indices = [int(round(int(frame.time[0]) / dt)) for frame in mean_frames]
    return md.join(
        [md.load_xtc(trajfile, top=top, frame=index) for index in frame_indices]
    )
def find_mean_frame(traj):
    """
    Return the frame of ``traj`` that is closest to all others.

    traj: trajectory that can be iterated frame by frame and indexed

    returns the frame for which ``estimate_rmsd(frame, traj)`` is smallest
    """
    mean_rmsd_per_frame = [estimate_rmsd(candidate, traj) for candidate in traj]
    best_index = np.argmin(np.array(mean_rmsd_per_frame))
    return traj[best_index]
def estimate_rmsd(frame, traj):
    """
    Return the mean of ``md.rmsd(traj, frame)``.

    frame: reference passed as second argument to ``md.rmsd``
    traj: trajectory passed as first argument to ``md.rmsd``

    returns the average over the values ``md.rmsd`` yields (presumably one
    RMSD per frame of ``traj`` - see the mdtraj documentation)
    """
    return np.mean(md.rmsd(traj, frame))
def align_trajectory_to_reference(trajectory, reference):
    """
    Aligns each frame in the trajectory array to the reference frame using the Kabsch algorithm.

    Parameters:
    - trajectory: numpy array of shape (N, n_atoms, 3) where N is the number of frames.
    - reference: numpy array of shape (1, n_atoms, 3) or (n_atoms, 3) with the reference points.

    Returns:
    - aligned_trajectory: numpy array with the shape and dtype of ``trajectory``;
      each frame is rotated and translated for the best least-squares fit to the
      reference.
    """
    # Drop the leading frame axis of the reference, if present.
    reference_frame = reference.squeeze()
    reference_centroid = np.mean(reference_frame, axis=0)
    centered_reference = reference_frame - reference_centroid

    aligned_trajectory = np.zeros_like(trajectory)
    for i, current_frame in enumerate(trajectory):
        centered_frame = current_frame - np.mean(current_frame, axis=0)

        # 3x3 covariance between the frame and the reference coordinates.
        H = np.dot(centered_frame.T, centered_reference)
        U, _, Vt = np.linalg.svd(H)

        # Coordinates are stored as rows, so the optimal rotation is applied
        # from the right and reads U @ Vt. (V @ U.T is the same rotation for
        # column vectors; multiplying rows by it would rotate the wrong way.)
        rotation = np.dot(U, Vt)
        if np.linalg.det(rotation) < 0:
            # Improper rotation (reflection): flip the axis that belongs to
            # the smallest singular value to get the best proper rotation.
            Vt[-1, :] *= -1
            rotation = np.dot(U, Vt)

        # Rotate about the origin, then move onto the reference centroid.
        aligned_trajectory[i] = np.dot(centered_frame, rotation) + reference_centroid

    return aligned_trajectory
def calc_var(ref, traj):
    """
    Calculate the mean squared deviation of every atom from a reference.

    ref: reference coordinates, broadcastable against ``traj``
    traj: coordinates of shape (n_frames, n_atoms, 3)

    returns a 1-D array with one value per atom: the squared distance to the
    reference, averaged over all frames after the frames were aligned with
    ``align_trajectory_to_reference``. No square root is taken.
    """
    aligned = align_trajectory_to_reference(traj, ref)
    squared_distance = np.square(aligned - ref).sum(axis=2)
    return squared_distance.mean(axis=0)
def opt_num_batches(n):
    """
    Number of batches to split ``n`` frames into.

    n: number of frames (any real number type)

    returns ``floor(cbrt(n**2 / 2))`` but never less than 1. The cube root is
    the minimiser of ``n**2 / B + B**2``, which looks like the cost of a
    pairwise search inside B batches followed by one among the B batch
    representatives (as done in ``calc_rmsd``).
    """
    # Work in float: ``n**2`` can overflow fixed-width NumPy integers.
    n = float(n)
    # At least one batch, otherwise callers end up with nothing to process.
    return max(1, int(np.cbrt(n * n / 2)))
def calc_rmsd(mpt, quiet=False):
    """
    Find a representative frame for every macrostate and the per-atom
    deviation of the macrostate's frames from it.

    mpt: object providing ``topology_file``, ``xtc_trajectory_file``,
        ``xtc_stride``, ``n_i``, ``n_macrostates``, ``macrotraj`` and
        ``macro_pop``; the latter three are indexed with ``n_i``
    quiet: if True, neither progress messages nor progress bars are shown

    returns (rmsd, mean_frames):
        rmsd: array of shape (n_macrostates, n_atoms) filled with the result
            of ``calc_var`` for every macrostate
        mean_frames: list with the representative frame of every macrostate
    """
    traj = load_traj(
        mpt.topology_file,
        mpt.xtc_trajectory_file,
        atom_selection="name CA",
        stride=mpt.xtc_stride,
    )
    n_macro = mpt.n_macrostates[mpt.n_i]
    rmsd = np.empty([n_macro, traj.n_atoms])
    mean_frames = []

    for j in range(n_macro):
        if not quiet:
            print(f"Process macrostate {j}")
        state_traj = traj[mpt.macrotraj[mpt.n_i] == j]

        # Two-stage search: the best frame of every interleaved batch first,
        # then the best frame among those batch winners.
        n_batches = opt_num_batches(mpt.macro_pop[mpt.n_i][j])
        batch_ids = range(n_batches)
        if not quiet:
            batch_ids = tqdm(batch_ids)
        batch_winners = [
            find_mean_frame(state_traj[batch::n_batches]) for batch in batch_ids
        ]
        representative = find_mean_frame(md.join(batch_winners))

        mean_frames.append(representative)
        rmsd[j] = calc_var(representative.xyz, state_traj.xyz)
    return rmsd, mean_frames
def find_state_lengths(arr):
    """
    Run-length encode a 1-D sequence of states.

    arr: 1-D sequence of states

    returns (states, lengths): ``states[k]`` is the state of the k-th run of
    identical consecutive values and ``lengths[k]`` is the number of values in
    that run. Both arrays are empty for empty input.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        return np.array([], dtype=arr.dtype), np.array([], dtype=int)

    # A run starts at position 0 and wherever a value differs from the one
    # before it.
    change_points = np.flatnonzero(arr[1:] != arr[:-1]) + 1
    run_starts = np.concatenate(([0], change_points))
    # The distance between consecutive starts (and the end) is the run length.
    lengths = np.diff(np.concatenate((run_starts, [arr.size])))
    return arr[run_starts], lengths
def get_multi_state_traj(trajs: np.ndarray, limits: np.ndarray):
    """
    Split an array of concatenated trajectories into the single trajectories.

    trajs: the concatenated trajectories
    limits: length of every single trajectory, in order, or None

    returns ``trajs`` unchanged if ``limits`` is None, otherwise a list with
    one slice of ``trajs`` per entry of ``limits``
    """
    if limits is None:
        return trajs

    segments = []
    position = 0
    for limit in limits:
        end = position + limit
        # Both bounds are cast: lengths read from a file are often floats,
        # which are not accepted as slice indices.
        segments.append(trajs[int(position) : int(end)])
        position = end
    return segments
def fnc_from_multi_feature_traj(multi_feature_traj, cutoff=0.45):
    """
    Fraction of features per frame that do not exceed a cutoff.

    multi_feature_traj: array of shape (n_frames, n_features), presumably
        contact distances - verify against the caller
    cutoff: a feature counts if its value is <= cutoff (default 0.45)

    returns a 1-D array with one fraction per frame
    """
    return (multi_feature_traj <= cutoff).mean(axis=1)