Coverage for src / molecular_simulations / analysis / ipSAE.py: 93%
234 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 10:07 -0600
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 10:07 -0600
1from itertools import permutations
2import numpy as np
3from numpy import vectorize
4from pathlib import Path
5import polars as pl
6from typing import Any, Union
# Anything accepted where a filesystem location is expected.
PathLike = Union[Path, str]
# Same as PathLike, but None is allowed to mean "no path supplied".
OptPath = Union[PathLike, None]
class ipSAE:
    """
    Compute the interaction prediction Score from Aligned Errors for a model.
    Adapted from https://doi.org/10.1101/2025.02.10.637595. Currently supports
    only outputs which provide plddt and pae data which limits us to Boltz and
    AlphaFold.

    Arguments:
        structure_file (PathLike): Path to PDB/CIF model.
        plddt_file (PathLike): Path to the numpy file holding pLDDT values under
            the key 'plddt'.
        pae_file (PathLike): Path to the numpy file holding PAE values under the
            key 'pae'.
        out_path (PathLike | None): Defaults to None. Path for outputs, or if None,
            will use the parent path from the plddt file.
    """
    def __init__(self,
                 structure_file: "PathLike",
                 plddt_file: "PathLike",
                 pae_file: "PathLike",
                 out_path: "OptPath" = None):
        self.parser = ModelParser(structure_file)
        self.plddt_file = Path(plddt_file)
        self.pae_file = Path(pae_file)

        self.path = Path(out_path) if out_path is not None else self.plddt_file.parent
        # parents=True so a nested, not-yet-existing output directory is created
        # instead of raising FileNotFoundError.
        self.path.mkdir(parents=True, exist_ok=True)

    def parse_structure_file(self) -> None:
        """
        Runs parser to read in structure file and extract relevant details.
        Stores the stacked per-residue coordinates in `self.coordinates` and the
        boolean token mask in `self.token_array`.

        Returns:
            None
        """
        self.parser.parse_structure_file()
        self.parser.classify_chains()
        self.coordinates = np.vstack([res['coor'] for res in self.parser.residues])
        self.token_array = np.array(self.parser.token_mask, dtype=bool)

    def prepare_scorer(self) -> None:
        """
        Prepares scorer for computing various scores. Must be called after
        `parse_structure_file` since it reads the parsed chains and residues.

        Returns:
            None
        """
        self.scorer = ScoreCalculator(
            chains=np.array(self.parser.chains),
            chain_pair_type=self.parser.chain_types,
            # The scorer expects a residue count, not the residue names.
            n_residues=len(self.parser.residues),
        )

    def run(self) -> None:
        """
        Main logic of class. Parses structure file, computes distogram, unpacks
        pLDDT and PAE, feeds data to scorer and saves out scores.

        Returns:
            None
        """
        self.parse_structure_file()

        # Pairwise Euclidean distances between all stored residue coordinates.
        deltas = self.coordinates[:, np.newaxis, :] - self.coordinates[np.newaxis, :, :]
        distances = np.sqrt((deltas ** 2).sum(axis=2))
        pLDDT = self.load_pLDDT_file()
        PAE = self.load_PAE_file()

        self.prepare_scorer()
        self.scorer.compute_scores(distances, pLDDT, PAE)

        self.scores = self.scorer.scores
        self.save_scores()

    def save_scores(self) -> None:
        """
        Saves scores dataframe to the parquet file 'ipSAE_scores.parquet' inside
        the output path.

        Returns:
            None
        """
        self.scores.write_parquet(self.path / 'ipSAE_scores.parquet')

    def load_pLDDT_file(self) -> np.ndarray:
        """
        Loads pLDDT file and scales data by 100.

        Returns:
            (np.ndarray): Scaled pLDDT array.
        """
        return np.array(self._read_array(self.plddt_file, 'plddt') * 100.)

    def load_PAE_file(self) -> np.ndarray:
        """
        Loads PAE file and returns data.

        Returns:
            (np.ndarray): Array of PAE values.
        """
        return self._read_array(self.pae_file, 'pae')

    @staticmethod
    def _read_array(file: Path,
                    key: str) -> np.ndarray:
        """
        Reads the entry stored under `key` from a numpy file and releases the
        underlying file handle afterwards.

        Arguments:
            file (Path): Path to the numpy file.
            key (str): Name of the entry to extract.

        Returns:
            (np.ndarray): The extracted array.
        """
        data = np.load(str(file))
        try:
            return np.asarray(data[key])
        finally:
            # Archive objects returned by np.load hold an open file until closed;
            # other return types have no close() and are left alone.
            close = getattr(data, 'close', None)
            if callable(close):
                close()
class ScoreCalculator:
    """
    Computes various model quality scores including: pDockQ, pDockQ2, LIS, ipTM and
    the ipSAE score.

    Arguments:
        chains (np.ndarray): Array of chainIDs, one entry per residue.
        chain_pair_type (dict[str, str]): Dictionary mapping of chainID to chain type
            ('protein' or 'nucleic_acid').
        n_residues (int): Number of residues total in structure. An array-like with
            one entry per residue is also accepted and replaced by its length.
        pdockq_cutoff (float): Defaults to 8.0 Å.
        pae_cutoff (float): Defaults to 12.0 Å.
        dist_cutoff (float): Defaults to 10.0 Å. Stored on the instance; none of the
            scores computed in this class read it.
    """
    def __init__(self,
                 chains: np.ndarray,
                 chain_pair_type: dict[str, str],
                 n_residues: int,
                 pdockq_cutoff: float=8.,
                 pae_cutoff: float=12.,
                 dist_cutoff: float=10.):
        self.chains = np.asarray(chains)
        self.unique_chains = np.unique(self.chains)
        self.chain_pair_type = chain_pair_type
        # Tolerate callers that hand over a per-residue array instead of a count.
        self.n_res = int(n_residues) if np.ndim(n_residues) == 0 else len(n_residues)
        self.pDockQ_cutoff = pdockq_cutoff
        self.PAE_cutoff = pae_cutoff
        self.dist_cutoff = dist_cutoff

        self.permute_chains()

    def compute_scores(self,
                       distances: np.ndarray,
                       pLDDT: np.ndarray,
                       PAE: np.ndarray) -> None:
        """
        Based on the input distance, pLDDT and PAE matrices, compute the pairwise pDockQ,
        pDockQ2, LIS, ipTM and ipSAE scores. All directed results are stored in `self.df`
        and the per-pair maxima in `self.scores`.

        Arguments:
            distances (np.ndarray): Pairwise residue distance matrix.
            pLDDT (np.ndarray): Per-residue pLDDT values.
            PAE (np.ndarray): Pairwise predicted aligned error matrix.

        Returns:
            None
        """
        self.distances = np.asarray(distances)
        self.pLDDT = np.asarray(pLDDT)
        self.PAE = np.asarray(PAE)

        results = []
        for chain1, chain2 in self.permuted:
            pDockQ, pDockQ2 = self.compute_pDockQ_scores(chain1, chain2)
            LIS = self.compute_LIS(chain1, chain2)
            ipTM, ipSAE = self.compute_ipTM_ipSAE(chain1, chain2)

            results.append([str(chain1), str(chain2), float(pDockQ), float(pDockQ2),
                            float(LIS), float(ipTM), float(ipSAE)])

        # Build the frame from typed rows; funnelling through np.array would
        # turn every float into a string first.
        self.df = pl.DataFrame(results,
                               schema={'chain1': str,
                                       'chain2': str,
                                       'pDockQ': float,
                                       'pDockQ2': float,
                                       'LIS': float,
                                       'ipTM': float,
                                       'ipSAE': float},
                               orient='row')
        self.get_max_values()

    def compute_pDockQ_scores(self,
                              chain1: str,
                              chain2: str) -> tuple[float, float]:
        """
        Computes both the pDockQ and pDockQ2 scores for the interface between two chains.
        pDockQ is dependent solely on the pLDDT matrix while pDockQ2 is dependent on both
        pLDDT and the PAE matrix.

        Arguments:
            chain1 (str): The string name of the first chain.
            chain2 (str): The string name of the second chain.

        Returns:
            (tuple[float, float]): A tuple of the pDockQ and pDockQ2 scores respectively.
                Both are 0.0 when no residue pair lies within the distance cutoff.
        """
        chain2_mask = self.chains == chain2

        n_pairs = 0
        ptm_sum = 0.
        interface = set()
        # Only residues that belong to chain1 can form chain1 -> chain2 contacts.
        for i in np.flatnonzero(self.chains == chain1):
            valid_pairs = chain2_mask & (self.distances[i] <= self.pDockQ_cutoff)
            if not valid_pairs.any():
                continue

            n_pairs += int(valid_pairs.sum())
            interface.add(int(i))
            interface.update(int(j) for j in np.flatnonzero(valid_pairs))
            # pDockQ2 uses a fixed d0 of 10 for the PAE -> pTM transform.
            ptm_sum += float(self.compute_pTM(self.PAE[i][valid_pairs], 10.).sum())

        if n_pairs == 0:
            return 0., 0.

        mean_pLDDT = self.pLDDT[sorted(interface)].mean()
        pDockQ = self.pDockQ_score(mean_pLDDT * np.log10(n_pairs))
        pDockQ2 = self.pDockQ2_score(mean_pLDDT * (ptm_sum / n_pairs))

        return pDockQ, pDockQ2

    def compute_LIS(self,
                    chain1: str,
                    chain2: str) -> float:
        """
        Computes Local Interaction Score (LIS) which is based on a subset of the
        predicted aligned error using a cutoff of 12. Values range in the interval
        (0, 1] and can be interpreted as how accurate a fold is within the error
        cutoff where a mean error of 0 yields a LIS value of 1 and a mean error
        that approaches 12 has a LIS value that approaches 0. If no PAE value of
        the chain pair is below the cutoff the score is 0.
        Adapted from: https://doi.org/10.1101/2024.02.19.580970.

        Arguments:
            chain1 (str): The string name of the first chain.
            chain2 (str): The string name of the second chain.

        Returns:
            (float): The LIS value for both chains.
        """
        mask = (self.chains[:, None] == chain1) & (self.chains[None, :] == chain2)
        selected_pae = self.PAE[mask]
        valid_pae = selected_pae[selected_pae < 12]
        if not valid_pae.size:
            return 0.

        return np.mean((12 - valid_pae) / 12)

    def compute_ipTM_ipSAE(self,
                           chain1: str,
                           chain2: str) -> tuple[float, float]:
        """
        Computes the ipTM and ipSAE scores for a given pair of chains. These operations
        are combined since they rely on very similar processing of the data.

        For every residue of chain1 the pTM-transformed PAE values towards chain2 are
        averaged: over all chain2 residues for ipTM, and only over chain2 residues whose
        PAE is below the PAE cutoff for ipSAE (0 if there are none). The reported score
        is the maximum of these per-residue averages, which makes it directional.

        Arguments:
            chain1 (str): The first chain to compare.
            chain2 (str): The second chain to compare.

        Returns:
            (tuple[float]): A tuple containing the ipTM and ipSAE scores respectively.
        """
        pair_types = {self.chain_pair_type[chain1], self.chain_pair_type[chain2]}
        # Chains are labelled 'nucleic_acid'; 'nucleic' is accepted as an alias.
        if pair_types & {'nucleic_acid', 'nucleic'}:
            pair_type = 'nucleic_acid'
        else:
            pair_type = 'protein'

        rows = self.chains == chain1
        cols = self.chains == chain2
        if not (rows.any() and cols.any()):
            return 0., 0.

        L = int(rows.sum() + cols.sum())
        d0_chain = self.compute_d0(L, pair_type)

        # chain1 residues along axis 0, chain2 residues along axis 1.
        pae_block = self.PAE[np.ix_(rows, cols)]
        ptm_block = self.compute_pTM(pae_block, d0_chain)

        ipTM_byres = ptm_block.mean(axis=1)

        confident = pae_block < self.PAE_cutoff
        n_confident = confident.sum(axis=1)
        ptm_sums = np.where(confident, ptm_block, 0.).sum(axis=1)
        # Rows without any confident partner keep a score of 0.
        ipSAE_byres = np.divide(ptm_sums,
                                n_confident,
                                out=np.zeros_like(ptm_sums, dtype=float),
                                where=n_confident > 0)

        return float(ipTM_byres.max()), float(ipSAE_byres.max())

    def get_max_values(self) -> None:
        """
        Because some scores like ipSAE are not symmetric, meaning A->B != B->A, we
        take the maximal score for either direction to be the undirected score.
        Here we scrape through the internal dataframe and keep, for every unordered
        chain pair, the single row with the maximal ipSAE value.

        Returns:
            None
        """
        rows = []
        processed = set()
        for chain1, chain2 in self.permuted:
            chain1, chain2 = str(chain1), str(chain2)
            # Track unordered pairs, not chains; otherwise (B, C) would be skipped
            # once (A, B) and (A, C) have been handled.
            pair = frozenset((chain1, chain2))
            if pair in processed:
                continue
            processed.add(pair)

            filtered = self.df.filter(
                ((pl.col('chain1') == chain1) & (pl.col('chain2') == chain2)) |
                ((pl.col('chain1') == chain2) & (pl.col('chain2') == chain1))
            )
            max_ipsae = filtered.select('ipSAE').max().item()
            # head(1) keeps one row when both directions tie.
            rows.append(filtered.filter(pl.col('ipSAE') == max_ipsae).head(1))

        # With fewer than two chains there is nothing to concatenate.
        self.scores = pl.concat(rows) if rows else self.df.head(0)

    def permute_chains(self) -> None:
        """
        Helper function that stores all ordered pairs of distinct chainIDs in
        `self.permuted`. Both directions, (A, B) and (B, A), are kept because
        several scores are directional; pairs of a chain with itself never occur.

        Returns:
            None
        """
        self.permuted = [(str(c1), str(c2))
                         for c1, c2 in permutations(self.unique_chains, 2)]

    @staticmethod
    def pDockQ_score(x) -> float:
        r"""
        Computes pDockQ score per the following equation.
        $pDockQ = \frac{0.724}{1 + e^{-0.052 (x - 152.611)}} + 0.018$

        Details on the pDockQ score at: https://doi.org/10.1038/s41467-022-28865-w

        Arguments:
            x (float): Mean pLDDT score scaled by the log10 number of residue pairs
                that meet the distance cutoff.

        Returns:
            (float): pDockQ score
        """
        return 0.724 / (1 + np.exp(-0.052 * (x - 152.611))) + 0.018

    @staticmethod
    def pDockQ2_score(x) -> float:
        r"""
        Computes pDockQ2 score per the following equation.
        $pDockQ2 = \frac{1.31}{1 + e^{-0.075 (x - 84.733)}} + 0.005$

        Details on the pDockQ2 score at: https://doi.org/10.1093/bioinformatics/btad424

        Arguments:
            x (float): Mean pLDDT score scaled by the mean pTM-transformed PAE.

        Returns:
            (float): pDockQ2 score
        """
        return 1.31 / (1 + np.exp(-0.075 * (x - 84.733))) + 0.005

    @staticmethod
    def compute_pTM(x: float,
                    d0: float) -> float:
        r"""
        Computes pTM score per the following equation.
        $pTM = \frac{1.0}{1 + (x / d0)^2}$

        Works element-wise on arrays (including empty ones) as well as on scalars.

        Arguments:
            x (float): Predicted aligned error value(s).
            d0 (float): d0 parameter

        Returns:
            (float): pTM score
        """
        return 1. / (1. + (np.asarray(x, dtype=float) / d0) ** 2)

    @staticmethod
    def compute_d0(L: int,
                   pair_type: str) -> float:
        r"""
        Computes d0 term per the following equation.
        $d0 = max(d0_{min}, 1.24 (L - 15)^{\frac{1}{3}} - 1.8)$
        where $d0_{min}$ is 2.0 for nucleic acid pairs and 1.0 otherwise.

        Arguments:
            L (int): Length of sequence; values below 27 are treated as 27.
            pair_type (str): 'nucleic_acid' if either chain is a nucleic acid.

        Returns:
            (float): d0
        """
        L = max(27, L)
        min_value = 2. if pair_type == 'nucleic_acid' else 1.

        return max(min_value, 1.24 * (L - 15) ** (1/3) - 1.8)
class ModelParser:
    """
    Helper class to read in and process a structure file for downstream
    scoring tasks. Capable of reading both PDB and CIF formats.

    Arguments:
        structure (PathLike): Path to PDB or CIF file.
    """
    def __init__(self,
                 structure: "PathLike"):
        self.structure = Path(structure)

        self.token_mask = []
        self.residues = []
        self.cb_residues = []
        self.chains = []

    def parse_structure_file(self) -> None:
        """
        Identify filetype, and parses line by line, storing relevant data
        for all C-alpha, C-beta and C1, C3 atoms for proteins and nucleic
        acids alike. Files with a '.pdb' suffix (any case) are read as PDB,
        everything else as CIF. Records the CIF parser cannot place in the
        sequence (it returns None for them) are skipped.

        Returns:
            None
        """
        if self.structure.suffix.lower() == '.pdb':
            line_parser = self.parse_pdb_line
        else:
            line_parser = self.parse_cif_line

        field_num = 0
        fields = dict()
        with open(self.structure) as handle:
            for line in handle:
                if line.startswith('_atom_site.'):
                    # Column order of the CIF atom table, in order of appearance.
                    field_name = line.strip().split('.', 1)[1]
                    fields[field_name] = field_num
                    field_num += 1
                    continue

                if not line.startswith(('ATOM', 'HETATM')):
                    continue

                atom = line_parser(line, fields)
                if atom is None:
                    continue

                name = atom['atom_name']
                if name == 'CA':
                    self._add_token(atom)
                    # Glycine has no C-beta, so its C-alpha stands in for it.
                    if atom['res'] == 'GLY':
                        self.cb_residues.append(atom)
                elif 'C1' in name:
                    self._add_token(atom)
                elif name == 'CB' or 'C3' in name:
                    self.cb_residues.append(atom)

    def _add_token(self, atom: dict[str, Any]) -> None:
        """
        Registers an atom as the representative of one residue.

        Arguments:
            atom (dict[str, Any]): Atom record as produced by `package_line`.

        Returns:
            None
        """
        self.token_mask.append(1)
        self.residues.append(atom)
        self.chains.append(atom['chain_id'])

    def classify_chains(self) -> None:
        """
        Reads through residue data to assign the identity of each chain as
        either protein (by default) or nucleic acid if an NA residue is detected.
        Results are stored in `self.chain_types`.

        Returns:
            None
        """
        self.residue_types = np.array([res['res'] for res in self.residues])
        # Per-residue chain labels; the mask below must be built from these and
        # not from the list of unique chains.
        chain_ids = np.array(self.chains)
        nucleic = set(self.nucleic_acids)

        self.chain_types = {}
        for chain in np.unique(chain_ids):
            chain_residues = self.residue_types[chain_ids == chain]
            is_nucleic = not nucleic.isdisjoint(chain_residues.tolist())
            self.chain_types[str(chain)] = 'nucleic_acid' if is_nucleic else 'protein'

    @property
    def nucleic_acids(self) -> list[str]:
        """
        Stores the canonical resnames for RNA and DNA residues.

        Returns:
            (list[str]): List of nucleic acid resnames.
        """
        return ['DA', 'DC', 'DT', 'DG', 'A', 'C', 'U', 'G']

    @staticmethod
    def parse_pdb_line(line: str,
                       *args) -> dict[str, Any]:
        """
        Parses a single line of a PDB file, extracting atom and residue information
        from its fixed-width columns. Processes this into a dictionary and returns
        the dict.

        Arguments:
            line (str): Actual line from PDB file.
            *args: Just here so we can use the same API for PDB and CIF.

        Returns:
            (dict[str, Any]): Dictionary representation of data.
        """
        return ModelParser.package_line(
            atom_num=line[6:11].strip(),
            atom_name=line[12:16].strip(),
            residue_name=line[17:20].strip(),
            chain_id=line[21],
            residue_id=line[22:26].strip(),
            x=line[30:38].strip(),
            y=line[38:46].strip(),
            z=line[46:54].strip(),
        )

    @staticmethod
    def parse_cif_line(line: str,
                       fields: dict[str, int]) -> Union[dict[str, Any], None]:
        """
        Parses a single line of a CIF file, extracting atom and residue information.
        Processes this into a dictionary and returns the dict.

        Arguments:
            line (str): Actual line from CIF file.
            fields (dict[str, int]): Definition of where each field is found.

        Returns:
            (dict[str, Any] | None): Dictionary representation of data, or None
                when the record carries no sequence id ('.' or '?').
        """
        _split = line.split()
        residue_id = _split[fields['label_seq_id']]
        # Records without a sequence id cannot be assigned to a residue position.
        if residue_id in ('.', '?'):
            return None

        return ModelParser.package_line(
            atom_num=_split[fields['id']],
            atom_name=_split[fields['label_atom_id']],
            residue_name=_split[fields['label_comp_id']],
            chain_id=_split[fields['label_asym_id']],
            residue_id=residue_id,
            x=_split[fields['Cartn_x']],
            y=_split[fields['Cartn_y']],
            z=_split[fields['Cartn_z']],
        )

    @staticmethod
    def package_line(atom_num: str,
                     atom_name: str,
                     residue_name: str,
                     chain_id: str,
                     residue_id: str,
                     x: str,
                     y: str,
                     z: str) -> dict[str, Any]:
        """
        Packs various information from a single line of a structure file into
        a dictionary to maintain consistency. Numeric fields are converted from
        their string form.

        Arguments:
            atom_num (str): Atom index.
            atom_name (str): Atom name.
            residue_name (str): Resname.
            chain_id (str): ChainID.
            residue_id (str): ResID.
            x (str): X coordinate.
            y (str): Y coordinate.
            z (str): Z coordinate.

        Returns:
            (dict[str, Any]): Dictionary representation of data.
        """
        return {
            'atom_num': int(atom_num),
            'atom_name': atom_name,
            'coor': np.array([float(i) for i in [x, y, z]]),
            'res': residue_name,
            'chain_id': chain_id,
            'resid': int(residue_id),
        }