Coverage for src / molecular_simulations / analysis / ipSAE.py: 93%

234 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-12 10:07 -0600

1from itertools import permutations 

2import numpy as np 

3from numpy import vectorize 

4from pathlib import Path 

5import polars as pl 

6from typing import Any, Union 

7 

# Anything that can be turned into a filesystem path with ``Path(...)``.
PathLike = Union[Path, str]
# Same as PathLike, but the value may also be omitted (None).
OptPath = Union[PathLike, None]

10 

class ipSAE:
    """
    Compute the interaction prediction Score from Aligned Errors for a model.
    Adapted from https://doi.org/10.1101/2025.02.10.637595. Currently supports
    only outputs which provide plddt and pae data which limits us to Boltz and
    AlphaFold.

    Arguments:
        structure_file (PathLike): Path to PDB/CIF model.
        plddt_file (PathLike): Path to an npz archive holding a 'plddt' array.
        pae_file (PathLike): Path to an npz archive holding a 'pae' array.
        out_path (PathLike | None): Defaults to None. Directory for outputs, or if
            None, the parent directory of the plddt file is used.
    """
    def __init__(self,
                 structure_file: 'PathLike',
                 plddt_file: 'PathLike',
                 pae_file: 'PathLike',
                 out_path: 'OptPath' = None):
        self.parser = ModelParser(structure_file)
        self.plddt_file = Path(plddt_file)
        self.pae_file = Path(pae_file)

        self.path = Path(out_path) if out_path is not None else self.plddt_file.parent
        # parents=True so a nested output directory that does not exist yet works.
        self.path.mkdir(parents=True, exist_ok=True)

    def parse_structure_file(self) -> None:
        """
        Runs parser to read in structure file and extract relevant details:
        per-token coordinates (stacked into ``self.coordinates``) and the token
        mask (``self.token_array``).

        Returns:
            None
        """
        self.parser.parse_structure_file()
        self.parser.classify_chains()
        self.coordinates = np.vstack([res['coor'] for res in self.parser.residues])
        self.token_array = np.array(self.parser.token_mask, dtype=bool)

    def prepare_scorer(self) -> None:
        """
        Builds the ScoreCalculator from the parsed chain IDs and chain types.

        Returns:
            None
        """
        chains = np.array(self.parser.chains)
        residue_types = np.array([res['res'] for res in self.parser.residues])

        # ScoreCalculator expects the residue *count*, not the names themselves.
        self.scorer = ScoreCalculator(chains=chains,
                                      chain_pair_type=self.parser.chain_types,
                                      n_residues=len(residue_types))

    def run(self) -> None:
        """
        Main logic of class. Parses structure file, computes distogram, unpacks
        pLDDT and PAE, feeds data to scorer and saves out scores.

        Returns:
            None
        """
        self.parse_structure_file()

        # Pairwise Euclidean distances between all token coordinates.
        deltas = self.coordinates[:, np.newaxis, :] - self.coordinates[np.newaxis, :, :]
        distances = np.sqrt((deltas ** 2).sum(axis=2))
        pLDDT = self.load_pLDDT_file()
        PAE = self.load_PAE_file()

        self.prepare_scorer()
        self.scorer.compute_scores(distances, pLDDT, PAE)

        self.scores = self.scorer.scores
        self.save_scores()

    def save_scores(self) -> None:
        """
        Saves scores dataframe to ``<out_path>/ipSAE_scores.parquet``.

        Returns:
            None
        """
        self.scores.write_parquet(self.path / 'ipSAE_scores.parquet')

    def load_pLDDT_file(self) -> np.ndarray:
        """
        Loads the 'plddt' array from the pLDDT archive and scales it by 100.

        Returns:
            (np.ndarray): Scaled pLDDT array.
        """
        # The archive is used as a context manager so its file handle is closed.
        with np.load(str(self.plddt_file)) as data:
            pLDDT_arr = np.array(data['plddt'] * 100.)

        return pLDDT_arr

    def load_PAE_file(self) -> np.ndarray:
        """
        Loads the 'pae' array from the PAE archive and returns it unchanged.

        Returns:
            (np.ndarray): Array of PAE values.
        """
        with np.load(str(self.pae_file)) as data:
            pae = np.array(data['pae'])
        return pae

115 

class ScoreCalculator:
    """
    Computes various model quality scores including: pDockQ, pDockQ2, LIS, ipTM and
    the ipSAE score for every ordered pair of distinct chains.

    Arguments:
        chains (np.ndarray): Per-residue array of chainIDs (one entry per token).
        chain_pair_type (dict[str, str]): Dictionary mapping of chainID to chain type
            ('protein' or 'nucleic_acid').
        n_residues (int): Number of residues total in structure. A per-residue
            sequence is also accepted, in which case its length is used.
        pdockq_cutoff (float): Defaults to 8.0 Å. Contact distance cutoff used by
            pDockQ and pDockQ2.
        pae_cutoff (float): Defaults to 12.0 Å. PAE cutoff used by ipSAE.
        dist_cutoff (float): Defaults to 10.0 Å. Stored on the instance.
    """
    def __init__(self,
                 chains: np.ndarray,
                 chain_pair_type: dict[str, str],
                 n_residues: int,
                 pdockq_cutoff: float=8.,
                 pae_cutoff: float=12.,
                 dist_cutoff: float=10.):
        # Elementwise comparisons (``self.chains == chain``) require an array.
        self.chains = np.asarray(chains)
        self.unique_chains = np.unique(self.chains)
        self.chain_pair_type = chain_pair_type
        # Accept either a count or a per-residue sequence so n_res is always an int.
        self.n_res = int(n_residues) if np.ndim(n_residues) == 0 else len(n_residues)
        self.pDockQ_cutoff = pdockq_cutoff
        self.PAE_cutoff = pae_cutoff
        self.dist_cutoff = dist_cutoff

        self.permute_chains()

    def compute_scores(self,
                       distances: np.ndarray,
                       pLDDT: np.ndarray,
                       PAE: np.ndarray) -> None:
        """
        Based on the input distance, pLDDT and PAE matrices, compute the pairwise pDockQ,
        pDockQ2, LIS, ipTM and ipSAE scores for every ordered chain pair. The directed
        results are stored in ``self.df``; the per-pair maxima in ``self.scores``.

        Arguments:
            distances (np.ndarray): (n, n) pairwise token distances.
            pLDDT (np.ndarray): Per-token pLDDT values.
            PAE (np.ndarray): (n, n) predicted aligned error matrix.

        Returns:
            None
        """
        self.distances = np.asarray(distances)
        self.pLDDT = np.asarray(pLDDT)
        self.PAE = np.asarray(PAE)

        results = []
        for chain1, chain2 in self.permuted:
            pDockQ, pDockQ2 = self.compute_pDockQ_scores(chain1, chain2)
            LIS = self.compute_LIS(chain1, chain2)
            ipTM, ipSAE = self.compute_ipTM_ipSAE(chain1, chain2)

            results.append((str(chain1), str(chain2), float(pDockQ), float(pDockQ2),
                            float(LIS), float(ipTM), float(ipSAE)))

        # Rows are passed directly: np.array(results) would coerce every value to str.
        self.df = pl.DataFrame(results,
                               schema={'chain1': str,
                                       'chain2': str,
                                       'pDockQ': float,
                                       'pDockQ2': float,
                                       'LIS': float,
                                       'ipTM': float,
                                       'ipSAE': float},
                               orient='row')
        self.get_max_values()

    def compute_pDockQ_scores(self,
                              chain1: str,
                              chain2: str) -> tuple[float, float]:
        """
        Computes both the pDockQ and pDockQ2 scores for the interface between two chains.
        A contact is any (chain1 residue, chain2 residue) pair within ``pDockQ_cutoff``.
        pDockQ depends on the pLDDT of interface residues and the contact count, while
        pDockQ2 additionally uses the PAE of the contacts.

        Arguments:
            chain1 (str): The string name of the first chain.
            chain2 (str): The string name of the second chain.

        Returns:
            (tuple[float, float]): A tuple of the pDockQ and pDockQ2 scores respectively;
                (0.0, 0.0) when the chains have no contacts.
        """
        in_chain1 = self.chains == chain1
        in_chain2 = self.chains == chain2

        # Rows restricted to chain1 and columns to chain2, within the distance cutoff.
        contacts = (in_chain1[:, np.newaxis] & in_chain2[np.newaxis, :]
                    & (self.distances <= self.pDockQ_cutoff))
        n_pairs = int(contacts.sum())
        if n_pairs == 0:
            return 0., 0.

        # Every residue taking part in at least one contact, on either side.
        rows, cols = np.nonzero(contacts)
        interface = np.union1d(rows, cols)
        mean_pLDDT = self.pLDDT[interface].mean()

        pDockQ = self.pDockQ_score(mean_pLDDT * np.log10(n_pairs))

        mean_pTM = self.compute_pTM(self.PAE[contacts], 10.).sum() / n_pairs
        pDockQ2 = self.pDockQ2_score(mean_pLDDT * mean_pTM)

        return pDockQ, pDockQ2

    def compute_LIS(self,
                    chain1: str,
                    chain2: str) -> float:
        """
        Computes Local Interaction Score (LIS) which is based on a subset of the
        predicted aligned error using a cutoff of 12. Values range in the interval
        [0, 1] and can be interpreted as how accurate a fold is within the error
        cutoff where a mean error of 0 yields a LIS value of 1 and a mean error
        that approaches 12 has a LIS value that approaches 0. A value of 0 is
        returned when no chain1->chain2 PAE entry is below 12.
        Adapted from: https://doi.org/10.1101/2024.02.19.580970.

        Arguments:
            chain1 (str): The string name of the first chain.
            chain2 (str): The string name of the second chain.

        Returns:
            (float): The LIS value for the chain1->chain2 block of the PAE matrix.
        """
        mask = (self.chains[:, None] == chain1) & (self.chains[None, :] == chain2)
        selected_pae = self.PAE[mask]

        valid_pae = selected_pae[selected_pae < 12]
        if valid_pae.size == 0:
            return 0.

        return float(np.mean((12 - valid_pae) / 12))

    def compute_ipTM_ipSAE(self,
                           chain1: str,
                           chain2: str) -> tuple[float, float]:
        """
        Computes the ipTM and ipSAE scores for the chain1->chain2 direction. Both are
        computed per aligned residue i of chain1 from pTM-transformed PAE values
        against chain2, then the maximum over i is reported:

        * ipTM: mean over all chain2 residues.
        * ipSAE: mean over chain2 residues whose PAE is below ``PAE_cutoff``
          (0 for residues with no such partner).

        d0 is derived from the combined length of both chains.
        NOTE(review): the reference ipSAE implementation's headline score uses a
        per-residue d0 (d0res); this is the chain-pair d0 variant — confirm intended.

        Arguments:
            chain1 (str): The first chain to compare (aligned residues).
            chain2 (str): The second chain to compare (scored residues).

        Returns:
            (tuple[float, float]): A tuple containing the ipTM and ipSAE scores respectively.
        """
        in_chain1 = self.chains == chain1
        in_chain2 = self.chains == chain2

        types = (self.chain_pair_type[chain1], self.chain_pair_type[chain2])
        pair_type = 'nucleic_acid' if 'nucleic_acid' in types else 'protein'

        L = int(in_chain1.sum() + in_chain2.sum())
        d0_chain = self.compute_d0(L, pair_type)

        # Only the chain1 x chain2 block of the PAE matrix is needed.
        block_pae = self.PAE[np.ix_(in_chain1, in_chain2)]
        if block_pae.size == 0:
            return 0., 0.
        block_pTM = self.compute_pTM(block_pae, d0_chain)

        ipTM_byres = block_pTM.mean(axis=1)

        valid = block_pae < self.PAE_cutoff
        n_valid = valid.sum(axis=1)
        valid_sum = np.where(valid, block_pTM, 0.).sum(axis=1)
        ipSAE_byres = np.divide(valid_sum, n_valid,
                                out=np.zeros_like(valid_sum),
                                where=n_valid > 0)

        return float(ipTM_byres.max()), float(ipSAE_byres.max())

    def get_max_values(self) -> None:
        """
        Because some scores like ipSAE are not symmetric, meaning A->B != B->A, we
        take the maximal score for either direction to be the undirected score.
        For every unordered chain pair, keep exactly one row of ``self.df``: the
        direction with the larger ipSAE (the first listed direction on ties).
        The result is stored in ``self.scores``.

        Returns:
            None
        """
        rows = []
        seen = set()
        for chain1, chain2 in self.permuted:
            c1, c2 = str(chain1), str(chain2)
            # Track unordered pairs, not individual chains, so no pair is skipped.
            key = frozenset((c1, c2))
            if key in seen:
                continue
            seen.add(key)

            pair = self.df.filter(
                ((pl.col('chain1') == c1) & (pl.col('chain2') == c2)) |
                ((pl.col('chain1') == c2) & (pl.col('chain2') == c1))
            )
            max_ipsae = pair.get_column('ipSAE').max()
            rows.append(pair.filter(pl.col('ipSAE') == max_ipsae).head(1))

        # A single-chain model has no pairs; keep an empty frame with the schema.
        self.scores = pl.concat(rows) if rows else self.df.head(0)

    def permute_chains(self) -> None:
        """
        Helper function that lists every ordered pair of distinct chainIDs.
        Both (A, B) and (B, A) are kept because the directed scores (ipTM, ipSAE,
        LIS) differ by direction. The order is deterministic (sorted chainIDs).

        Returns:
            None
        """
        self.permuted = list(permutations(self.unique_chains, 2))

    @staticmethod
    def pDockQ_score(x) -> float:
        r"""
        Computes pDockQ score per the following equation.
        $pDockQ = \frac{0.724}{1 + e^{-0.052 (x - 152.611)}} + 0.018$

        Details on the pDockQ score at: https://doi.org/10.1038/s41467-022-28865-w

        Arguments:
            x (float): Mean pLDDT of interface residues multiplied by the log10
                number of residue pairs within the distance cutoff.

        Returns:
            (float): pDockQ score
        """
        return 0.724 / (1 + np.exp(-0.052 * (x - 152.611))) + 0.018

    @staticmethod
    def pDockQ2_score(x) -> float:
        r"""
        Computes pDockQ2 score per the following equation.
        $pDockQ2 = \frac{1.31}{1 + e^{-0.075 (x - 84.733)}} + 0.005$

        Details on the pDockQ2 score at: https://doi.org/10.1093/bioinformatics/btad424

        Arguments:
            x (float): Mean pLDDT of interface residues multiplied by the mean
                pTM-transformed PAE (d0 = 10) of the interface contacts.

        Returns:
            (float): pDockQ2 score
        """
        return 1.31 / (1 + np.exp(-0.075 * (x - 84.733))) + 0.005

    @staticmethod
    def compute_pTM(x,
                    d0: float) -> np.ndarray:
        r"""
        Computes pTM score elementwise per the following equation.
        $pTM = \frac{1}{1 + (x / d_0)^2}$

        Plain NumPy arithmetic replaces ``np.vectorize``: it is much faster on
        large PAE matrices and also accepts empty input.

        Arguments:
            x (float | np.ndarray): PAE value(s).
            d0 (float): d0 parameter

        Returns:
            (np.ndarray | float): pTM score(s), same shape as ``x``.
        """
        return 1. / (1. + (np.asarray(x, dtype=float) / d0) ** 2)

    @staticmethod
    def compute_d0(L: int,
                   pair_type: str) -> float:
        r"""
        Computes d0 term per the following equation, with L floored at 27.
        $d_0 = \max(d_{min}, 1.24 (L - 15)^{1/3} - 1.8)$
        where $d_{min}$ is 2.0 for 'nucleic_acid' pairs and 1.0 otherwise.

        Arguments:
            L (int): Number of residues; values below 27 are treated as 27.
            pair_type (str): 'nucleic_acid' if either chain is a nucleic acid.

        Returns:
            (float): d0
        """
        L = max(27, L)

        min_value = 1.
        if pair_type == 'nucleic_acid':
            min_value = 2.

        return max(min_value, 1.24 * (L - 15) ** (1/3) - 1.8)

411 

412 

class ModelParser:
    """
    Helper class to read in and process a structure file for downstream
    scoring tasks. Capable of reading both PDB and CIF formats.

    Arguments:
        structure (PathLike): Path to PDB or CIF file. A '.pdb' suffix (any case)
            selects the fixed-column PDB parser; anything else is read as mmCIF.
    """
    def __init__(self,
                 structure: 'PathLike'):
        self.structure = Path(structure)

        # Parallel per-token lists: one entry for every 'CA' atom and every atom
        # whose name contains 'C1'.
        self.token_mask = []
        self.residues = []
        self.chains = []
        # 'CB' atoms, atoms whose name contains 'C3', and glycine 'CA' atoms.
        self.cb_residues = []

    def parse_structure_file(self) -> None:
        """
        Identify filetype, and parses line by line, storing relevant data.
        'CA' atoms and atoms whose name contains 'C1' become tokens (appended to
        ``residues``, ``chains`` and ``token_mask``). 'CB' atoms, atoms whose name
        contains 'C3', and glycine 'CA' atoms are appended to ``cb_residues``.
        CIF records without a polymer sequence id ('.') are skipped.

        Returns:
            None
        """
        if self.structure.suffix.lower() == '.pdb':
            line_parser = self.parse_pdb_line
        else:
            line_parser = self.parse_cif_line

        # Column index of each `_atom_site.<field>` header, in declaration order
        # (only consulted by the CIF parser).
        fields = {}
        field_num = 0
        with open(self.structure) as handle:
            for line in handle:
                if line.startswith('_atom_site.'):
                    field_name = line.strip().split('.', 1)[1]
                    fields[field_name] = field_num
                    field_num += 1
                    continue

                if not line.startswith(('ATOM', 'HETATM')):
                    continue

                atom = line_parser(line, fields)
                if atom is None:
                    # e.g. ligands / waters in CIF (label_seq_id == '.')
                    continue

                name = atom['atom_name']
                if name == 'CA':
                    self.token_mask.append(1)
                    self.residues.append(atom)
                    self.chains.append(atom['chain_id'])
                    if atom['res'] == 'GLY':
                        # Glycine has no CB, so its CA stands in.
                        self.cb_residues.append(atom)

                elif 'C1' in name:
                    self.token_mask.append(1)
                    self.residues.append(atom)
                    self.chains.append(atom['chain_id'])

                elif name == 'CB' or 'C3' in name:
                    self.cb_residues.append(atom)

    def classify_chains(self) -> None:
        """
        Reads through residue data to assign the identity of each chain as
        either 'protein' (by default) or 'nucleic_acid' if any residue of that
        chain has a nucleic acid resname. Results go to ``self.chain_types``.

        Returns:
            None
        """
        self.residue_types = np.array([res['res'] for res in self.residues])
        residue_chains = np.asarray(self.chains)
        nucleic = set(self.nucleic_acids)

        self.chain_types = {}
        for chain in np.unique(residue_chains):
            # Select this chain's residues via the per-residue chain IDs.
            chain_residues = set(self.residue_types[residue_chains == chain])
            self.chain_types[chain] = 'nucleic_acid' if chain_residues & nucleic else 'protein'

    @property
    def nucleic_acids(self) -> list[str]:
        """
        Stores the canonical resnames for RNA and DNA residues.

        Returns:
            (list[str]): List of nucleic acid resnames.
        """
        return ['DA', 'DC', 'DT', 'DG', 'A', 'C', 'U', 'G']

    @staticmethod
    def parse_pdb_line(line: str,
                       *args) -> dict[str, Any]:
        """
        Parses a single fixed-column line of a PDB file, extracting atom and residue
        information. Processes this into a dictionary and returns the dict.

        Arguments:
            line (str): Actual line from PDB file.
            *args: Just here so we can use the same API for PDB and CIF.

        Returns:
            (dict[str, Any]): Dictionary representation of data.
        """
        atom_num = line[6:11].strip()
        atom_name = line[12:16].strip()
        residue_name = line[17:20].strip()
        chain_id = line[21]
        residue_id = line[22:26].strip()
        x = line[30:38].strip()
        y = line[38:46].strip()
        z = line[46:54].strip()

        return ModelParser.package_line(atom_num, atom_name, residue_name, chain_id, residue_id, x, y, z)

    @staticmethod
    def parse_cif_line(line: str,
                       fields: dict[str, int]) -> Union[dict[str, Any], None]:
        """
        Parses a single whitespace-separated line of a CIF file, extracting atom and
        residue information. Processes this into a dictionary and returns the dict.

        Arguments:
            line (str): Actual line from CIF file.
            fields (dict[str, int]): Definition of where each field is found.

        Returns:
            (dict[str, Any] | None): Dictionary representation of data, or None when
                the record has no polymer sequence id ('.').
        """
        _split = line.split()
        atom_num = _split[fields['id']]
        atom_name = _split[fields['label_atom_id']]
        residue_name = _split[fields['label_comp_id']]
        chain_id = _split[fields['label_asym_id']]
        residue_id = _split[fields['label_seq_id']]
        x = _split[fields['Cartn_x']]
        y = _split[fields['Cartn_y']]
        z = _split[fields['Cartn_z']]

        if residue_id == '.':
            return None

        return ModelParser.package_line(atom_num, atom_name, residue_name, chain_id, residue_id, x, y, z)

    @staticmethod
    def package_line(atom_num: str,
                     atom_name: str,
                     residue_name: str,
                     chain_id: str,
                     residue_id: str,
                     x: str,
                     y: str,
                     z: str) -> dict[str, Any]:
        """
        Packs various information from a single line of a structure file into
        a dictionary to maintain consistency.

        Arguments:
            atom_num (str): Atom index.
            atom_name (str): Atom name.
            residue_name (str): Resname.
            chain_id (str): ChainID.
            residue_id (str): ResID.
            x (str): X coordinate.
            y (str): Y coordinate.
            z (str): Z coordinate.

        Returns:
            (dict[str, Any]): Dictionary representation of data.
        """
        return {
            'atom_num': int(atom_num),
            'atom_name': atom_name,
            'coor': np.array([float(i) for i in [x, y, z]]),
            'res': residue_name,
            'chain_id': chain_id,
            'resid': int(residue_id),
        }