Coverage for src / molecular_simulations / analysis / constant_pH_analysis.py: 35%

754 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-13 01:26 -0600

1""" 

2Improved constant pH analysis with UWHAM reweighting. 

3 

4This implementation adds multistate analysis capabilities to the basic 

5curve fitting approach. Uses log-space arithmetic for numerical stability. 

6""" 

7 

8from __future__ import annotations 

9 

10import ast 

11import numpy as np 

12from pathlib import Path 

13import polars as pl 

14import re 

15from scipy.optimize import curve_fit, brentq 

16from scipy.special import logsumexp 

17from typing import Optional, Dict, List, Tuple, TYPE_CHECKING 

18import warnings 

19 

20if TYPE_CHECKING: 

21 import matplotlib.pyplot as plt 

22 

23 

class UWHAMSolver:
    """
    Unbinned Weighted Histogram Analysis Method (UWHAM) solver.

    NOTE: This class is NOT currently used because UWHAM/MBAR is designed
    for umbrella sampling and replica exchange, not independent constant pH
    simulations. For standard constant pH MD where each pH is an independent
    equilibrium simulation, simple curve fitting is the correct approach.

    This class is retained for potential use with replica exchange constant
    pH (REX-cpH) simulations where samples ARE correlated across pH values.

    The reduced potential of a sample carrying N protons, evaluated at pH p,
    is taken to be ``ln(10) * p * N``. All sums are evaluated in log space
    (``logsumexp``) for numerical stability with large systems
    (100+ titratable residues).

    Parameters
    ----------
    tol : float
        Convergence threshold on max |f_new - f_old| between iterations.
    maxiter : int
        Maximum number of self-consistent iterations in solve().
    """

    def __init__(self, tol: float = 1e-7, maxiter: int = 10000):
        self.tol = tol
        self.maxiter = maxiter
        # Dimensionless free energies per pH state (normalized so f[0] == 0);
        # None until solve() has run.
        self.f = None
        self.log10 = np.log(10)

    def load_data(self, df: pl.DataFrame, resid_cols: List[str]):
        """
        Load data from polars DataFrame into UWHAM-compatible format.

        Sets ``pH_values`` (sorted ascending), ``nsamples`` (samples per pH),
        ``nstates``, ``states`` (resid -> list of per-pH state arrays),
        ``nprotons_total`` (per-pH arrays of total protons per sample) and
        ``u_kl`` (per-source-state reduced potentials at every pH).

        Parameters
        ----------
        df : pl.DataFrame
            DataFrame with columns: rankid, current_pH, and residue columns
            Residue columns should contain numeric protonation states (0 or 1)
        resid_cols : List[str]
            List of column names corresponding to residue IDs
        """
        # Unique pH values and the number of samples collected at each
        pH_groups = df.group_by('current_pH').agg(pl.len().alias('count'))
        pH_values = pH_groups['current_pH'].to_numpy()
        nsamples = pH_groups['count'].to_numpy().astype(int)

        # Sort by pH so state index order is deterministic
        sort_idx = np.argsort(pH_values)
        self.pH_values = pH_values[sort_idx]
        self.nsamples = nsamples[sort_idx]
        self.nstates = len(self.pH_values)

        self.states = {resid_col: [] for resid_col in resid_cols}  # resid -> list of arrays (one per pH)
        self.nprotons_total = []  # Total protons per sample, one array per pH

        for pH in self.pH_values:
            pH_data = df.filter(pl.col('current_pH') == pH)
            total_protons = np.zeros(len(pH_data))
            for resid_col in resid_cols:
                states = pH_data[resid_col].to_numpy().astype(float)
                self.states[resid_col].append(states)
                total_protons += states
            self.nprotons_total.append(total_protons)

        # u_kl[k][l, n] = ln(10) * pH_l * nprotons(sample n drawn at state k)
        self.u_kl = [
            np.outer(self.log10 * self.pH_values, nprot)
            for nprot in self.nprotons_total
        ]

    def _flat_nprotons(self) -> np.ndarray:
        """Total-proton counts of all samples, concatenated in state order."""
        return np.concatenate(self.nprotons_total)

    def solve(self, verbose: bool = False):
        """
        Solve UWHAM self-consistent equations iteratively.

        Uses the MBAR equation:
            f_k = -log(Σ_n exp(-u_k(x_n)) / Σ_l N_l exp(f_l - u_l(x_n)))

        where the sum over n includes ALL samples from ALL states. Requires
        ``pH_values``, ``nsamples``, ``nstates`` and ``nprotons_total`` (as set
        by load_data()). Emits a warning if not converged within ``maxiter``.

        Parameters
        ----------
        verbose : bool
            Print progress every 100 iterations and on convergence.

        Returns
        -------
        f : np.ndarray
            Free energy offsets for each pH simulation (f[0] == 0)
        """
        log_N = np.log(np.asarray(self.nsamples, dtype=float))
        nprot_flat = self._flat_nprotons()
        total_samples = nprot_flat.size

        # u_all[l, n] = reduced potential of sample n evaluated at state l
        u_all = np.outer(self.log10 * self.pH_values, nprot_flat)
        # Index of the state each (flattened) sample was drawn from
        sample_source = np.repeat(np.arange(self.nstates), self.nsamples)

        def log_denominator(f: np.ndarray) -> np.ndarray:
            # log c_n = logsumexp_l(log N_l + f_l - u_l(x_n)), for every sample n
            return logsumexp((log_N + f)[:, None] - u_all, axis=0)

        f = np.zeros(self.nstates)
        delta = np.inf  # defined even if maxiter == 0
        for iteration in range(self.maxiter):
            f_old = f
            log_c = log_denominator(f)
            # f_k = -logsumexp_n(-u_k(x_n) - log c_n)
            f = -logsumexp(-u_all - log_c[None, :], axis=1)
            # Free energies are defined up to a constant; pin f[0] = 0
            f = f - f[0]

            delta = np.abs(f - f_old).max()
            if verbose and iteration % 100 == 0:
                print(f" Iteration {iteration}: max|Δf| = {delta:.2e}")

            if delta < self.tol:
                if verbose:
                    print(f" Converged after {iteration + 1} iterations")
                break
        else:
            warnings.warn(
                f"UWHAM did not converge after {self.maxiter} iterations "
                f"(final delta = {delta:.2e})"
            )

        self.f = f
        # Recompute with the FINAL f so reweighting is consistent with self.f
        self.log_c = log_denominator(f)
        self.u_all = u_all
        self.sample_source = sample_source
        self.total_samples = total_samples

        return f

    def compute_log_weights(self, target_pH: float) -> Tuple[np.ndarray, float]:
        """
        Compute log weights for reweighting to target pH.

        Uses MBAR formula:
            w_n ∝ exp(-u_target(x_n)) / Σ_l N_l exp(f_l - u_l(x_n))

        Returns
        -------
        log_weights : np.ndarray
            Log weights for all samples (flattened in state order)
        log_norm : float
            Log of the normalization constant

        Raises
        ------
        RuntimeError
            If solve() has not been called.
        """
        if self.f is None:
            raise RuntimeError("Must call solve() before computing weights")

        u_target = self.log10 * target_pH * self._flat_nprotons()
        # log(w_n) = -u_target(x_n) - log(c_n), with log(c_n) from solve()
        log_weights = -u_target - self.log_c
        log_norm = logsumexp(log_weights)

        return log_weights, log_norm

    def compute_expectation_at_pH(
        self,
        observable_by_state: List[np.ndarray],
        target_pH: float
    ) -> float:
        """
        Compute expectation value of observable at arbitrary pH.

        Parameters
        ----------
        observable_by_state : List[np.ndarray]
            Observable values for each sample, organized by state index
        target_pH : float
            pH value at which to compute the expectation

        Returns
        -------
        expectation : float
            Reweighted expectation value at target_pH

        Raises
        ------
        ValueError
            If the total number of observable values differs from the number
            of samples.
        """
        log_weights, log_norm = self.compute_log_weights(target_pH)

        obs_flat = np.concatenate(observable_by_state)
        if obs_flat.shape != log_weights.shape:
            # Prevent silent broadcasting of a mis-sized observable
            raise ValueError(
                f"Observable has {obs_flat.size} values but there are "
                f"{log_weights.size} samples"
            )

        # <A> = Σ_n A_n * exp(log_w_n - log_norm)
        weights = np.exp(log_weights - log_norm)
        return np.sum(obs_flat * weights)

    def get_occupancy_for_resid(self, resid: str) -> List[np.ndarray]:
        """Get occupancy arrays for a specific residue across all pH values."""
        return self.states[resid]

263 

264 

class TitrationCurve:
    """
    Analyze constant pH simulations with multiple fitting methods.

    Available methods:
    - curvefit: Simple least squares fit of Hill equation to per-pH averages
    - weighted: Weighted least squares (weight by 1/variance)
    - bootstrap: Curve fitting with bootstrap confidence intervals

    Note: For independent constant pH simulations (not replica exchange),
    simple curve fitting is the statistically correct approach. UWHAM/MBAR
    is only appropriate for replica exchange constant pH where samples
    are correlated across pH values.
    """

    # Bounds shared by every Hill fit: ([pKa_min, n_min], [pKa_max, n_max])
    _FIT_BOUNDS = ([0., 0.1], [14., 10.])
    # A two-parameter Hill fit needs at least this many pH points
    _MIN_POINTS = 3

    def __init__(
        self,
        log_file: Path | List[Path],
        make_plots: bool = True,
        out: Path = Path('.'),
        method: str = 'curvefit'  # 'curvefit', 'weighted', or 'bootstrap'
    ):
        """
        Parse one or more constant pH log files.

        Parameters
        ----------
        log_file : Path or list of Path
            Log file(s) to parse. When several are given their samples are
            concatenated; all must list the same residue IDs in the same order.
        make_plots : bool
            If True, postprocess() calls plot().
        out : Path
            Output directory (stored as ``self.out``).
        method : str
            Method used by compute_titrations(): 'curvefit', 'weighted' or
            'bootstrap'. (The former default 'uwham' was never accepted by
            compute_titrations().)

        Raises
        ------
        ValueError
            If an empty list is given or the logs disagree on residue IDs.
        """
        if isinstance(log_file, (list, tuple)):
            if not log_file:
                raise ValueError('log_file must contain at least one log file')
            dfs = []
            resids = None
            for log in log_file:
                df, r = self.parse_log(log)
                if resids is None:
                    resids = r
                elif r != resids:
                    raise ValueError(
                        f'Residue IDs in {log} do not match those in {log_file[0]}'
                    )
                dfs.append(df)
            self.df = pl.concat(dfs, how='vertical')
        else:
            self.df, resids = self.parse_log(log_file)

        # Residue IDs as strings, matching the DataFrame column names
        self.resid_cols = [str(r) for r in resids]

        self.make_plots = make_plots
        self.out = out
        self.method = method

        # Populated by compute_titrations() / postprocess()
        self.fits = None
        self.curves = None

    @staticmethod
    def parse_log(log: Path) -> Tuple[pl.DataFrame, List[int]]:
        """Parse OpenMM constant pH log file.

        The first line matching ``cpH: resids <id> <id> ...`` supplies the
        residue IDs; every line matching
        ``rank=<int> ... cpH: pH <float>: [<states>]`` becomes one row.

        Parameters
        ----------
        log : Path or str
            Path to the log file.

        Returns
        -------
        df : pl.DataFrame
            DataFrame with columns: rankid, current_pH, and one column per residue
        resids : List[int]
            List of residue IDs in order

        Raises
        ------
        RuntimeError
            If the residue-ID header line is missing.
        ValueError
            If a state line has a different number of states than residues.
        """
        lines = Path(log).read_text().splitlines()

        # Header format: "cpH: resids 20 76 83 92 ..."
        header_re = re.compile(r'cpH:\s+resids\s+(.+)$')
        resids = None
        for line in lines:
            m = header_re.search(line)
            if m:
                # Residue IDs are whitespace separated (possibly multiple spaces)
                resids = [int(x) for x in m.group(1).split()]
                break

        if resids is None:
            raise RuntimeError(
                'Could not find cpH residue ID header line in log. '
                'Expected line containing "cpH: resids ..."'
            )

        state_re = re.compile(
            r'rank=(\d+).*cpH:\s+pH\s+([0-9.]+):\s+(\[.*\])'
        )

        rows = []
        for line in lines:
            m = state_re.search(line)
            if not m:
                continue

            # literal_eval only evaluates Python literals, never arbitrary code
            states_list = ast.literal_eval(m.group(3))
            if len(states_list) != len(resids):
                raise ValueError(
                    f'Mismatch between number of residues ({len(resids)}) '
                    f'and number of states ({len(states_list)})'
                )

            row = {
                'rankid': int(m.group(1)),
                'current_pH': float(m.group(2)),
            }
            row.update({
                str(resid): state
                for resid, state in zip(resids, states_list)
            })
            rows.append(row)

        return pl.DataFrame(rows), resids

    def prepare(self) -> None:
        """Build long-format data, residue names and per-pH protonation fractions.

        Sets ``df_long``, ``resid_to_resname`` and ``titrations`` (columns:
        resid, current_pH, fraction_protonated, n_samples). States absent
        from ``protonation_mapping`` are dropped.
        """
        self.df_long = self.df.unpivot(
            index=['rankid', 'current_pH'],
            on=self.resid_cols,
            variable_name='resid',
            value_name='state',
        )

        # Canonical resname from the first non-null state of each residue.
        # The mapping properties build a new dict per access, so bind once.
        canonical = self.canonical_resname
        self.resid_to_resname = {}
        for resid_col in self.resid_cols:
            first_state = self.df[resid_col].drop_nulls().head(1).to_list()
            if first_state:
                state = first_state[0]
                self.resid_to_resname[resid_col] = canonical.get(state, state)
            else:
                self.resid_to_resname[resid_col] = 'UNK'

        # Map state names to protonation (1 or 0); unknown states -> null -> dropped
        mapping = self.protonation_mapping
        self.df_long = self.df_long.with_columns(
            pl.col('state').map_elements(
                lambda x: mapping.get(x),
                return_dtype=pl.Int64
            ).alias('prot')
        ).drop_nulls('prot')

        self.titrations = (
            self.df_long.group_by(['resid', 'current_pH'])
            .agg([
                pl.col('prot').mean().alias('fraction_protonated'),
                pl.col('prot').count().alias('n_samples')
            ])
            .sort(['resid', 'current_pH'])
        )

    # ------------------------------------------------------------------
    # Fitting helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _initial_guess(x: np.ndarray, y: np.ndarray) -> List[float]:
        """Initial [pKa, n]: the pH whose fraction is closest to 0.5, and n = 1."""
        return [x[np.argmin(np.abs(y - 0.5))], 1.0]

    @classmethod
    def _fit_hill(
        cls,
        x: np.ndarray,
        y: np.ndarray,
        p0: List[float],
        sigma: Optional[np.ndarray] = None,
        maxfev: int = 5000,
    ):
        """Fit the Hill equation; return (popt, pcov), or None if the fit fails."""
        extra = {} if sigma is None else {'sigma': sigma, 'absolute_sigma': False}
        try:
            return curve_fit(
                cls.hill_equation,
                x, y,
                p0=p0,
                bounds=cls._FIT_BOUNDS,
                maxfev=maxfev,
                **extra,
            )
        except Exception:
            # Deliberate best effort: one residue failing to fit is reported
            # as NaN instead of aborting the analysis of all residues.
            return None

    @classmethod
    def _fit_with_errors(
        cls,
        x: np.ndarray,
        y: np.ndarray,
        sigma: Optional[np.ndarray] = None,
    ) -> Tuple[float, float, float, float]:
        """Fit and return (pKa, n, pKa_err, n_err), all NaN on failure."""
        result = cls._fit_hill(x, y, cls._initial_guess(x, y), sigma=sigma)
        if result is None:
            return np.nan, np.nan, np.nan, np.nan
        popt, pcov = result
        perr = np.sqrt(np.diag(pcov))
        return popt[0], popt[1], perr[0], perr[1]

    @staticmethod
    def _fit_row(
        resid: str,
        resname: str,
        n_points: int,
        method: str,
        pKa: float = np.nan,
        hill_n: float = np.nan,
        pKa_err: float = np.nan,
        n_err: float = np.nan,
    ) -> Dict:
        """One output row of a point-estimate fit (curvefit / weighted)."""
        return {
            'resid': resid,
            'resname': resname,
            'pKa': float(pKa),
            'Hill_n': float(hill_n),
            'pKa_err': float(pKa_err),
            'Hill_n_err': float(n_err),
            'n_points': int(n_points),
            'method': method,
        }

    # ------------------------------------------------------------------
    # Fitting methods
    # ------------------------------------------------------------------

    def compute_titrations_curvefit(self) -> pl.DataFrame:
        """
        Compute pKa and Hill coefficient using scipy curve_fit.

        This is the simple approach that treats each pH independently.
        Residues with fewer than 3 pH points, or whose fit fails, get NaN.
        Requires prepare().
        """
        fit_rows = []

        for key, subdf in self.titrations.group_by('resid', maintain_order=True):
            resid = key[0]  # group keys are tuples
            resname = self.resid_to_resname.get(resid, 'UNK')
            x = subdf['current_pH'].to_numpy().astype(float)
            y = subdf['fraction_protonated'].to_numpy().astype(float)

            if x.size < self._MIN_POINTS:
                fit_rows.append(self._fit_row(resid, resname, x.size, 'curvefit'))
                continue

            pKa, hill_n, pKa_err, n_err = self._fit_with_errors(x, y)
            fit_rows.append(self._fit_row(
                resid, resname, x.size, 'curvefit', pKa, hill_n, pKa_err, n_err
            ))

        return pl.DataFrame(fit_rows)

    def compute_titrations_weighted(self, verbose: bool = False) -> pl.DataFrame:
        """
        Compute pKa and Hill coefficient using weighted least squares.

        Each pH point is weighted by 1/variance with the binomial variance
        p(1-p)/n (p clipped to [0.01, 0.99] to keep weights finite). Points
        with more samples get more influence; because p(1-p) peaks at 0.5,
        intermediate protonation fractions receive LESS weight than nearly
        fully (de)protonated points.

        This is more statistically rigorous than unweighted curve fitting
        when sample sizes vary across pH values. ``verbose`` is accepted for
        API symmetry and currently unused. Requires prepare().
        """
        fit_rows = []

        for key, subdf in self.titrations.group_by('resid', maintain_order=True):
            resid = key[0]
            resname = self.resid_to_resname.get(resid, 'UNK')
            x = subdf['current_pH'].to_numpy().astype(float)
            y = subdf['fraction_protonated'].to_numpy().astype(float)
            n = subdf['n_samples'].to_numpy().astype(float)

            if x.size < self._MIN_POINTS:
                fit_rows.append(self._fit_row(resid, resname, x.size, 'weighted'))
                continue

            # Binomial variance; clip to avoid infinite weights at p = 0 or 1
            eps = 0.01
            y_clipped = np.clip(y, eps, 1 - eps)
            variance = y_clipped * (1 - y_clipped) / n
            weights = 1.0 / variance
            weights = weights / weights.sum()
            # Relative sigma (absolute_sigma=False), so the scale is irrelevant
            sigma = 1.0 / np.sqrt(weights * len(weights))

            pKa, hill_n, pKa_err, n_err = self._fit_with_errors(x, y, sigma=sigma)
            fit_rows.append(self._fit_row(
                resid, resname, x.size, 'weighted', pKa, hill_n, pKa_err, n_err
            ))

        return pl.DataFrame(fit_rows)

    def compute_titrations_bootstrap(
        self,
        n_bootstrap: int = 1000,
        verbose: bool = False,
        seed: Optional[int] = None,
    ) -> pl.DataFrame:
        """
        Compute pKa and Hill coefficient with bootstrap confidence intervals.

        Parametric bootstrap: at each pH the protonated count is redrawn from
        Binomial(n_samples, fraction_protonated) and the Hill equation refit.
        This gives robust error estimates even when the Hill equation
        doesn't perfectly fit the data.

        Parameters
        ----------
        n_bootstrap : int
            Number of bootstrap iterations (default 1000)
        verbose : bool
            Print progress
        seed : int, optional
            Seed for a dedicated random generator, for reproducible intervals.
            If None, NumPy's global random state is used (previous behavior).

        Returns
        -------
        DataFrame with pKa, Hill_n, and 95% confidence intervals (NaN when
        10 or fewer bootstrap fits succeed)
        """
        rng = np.random.default_rng(seed) if seed is not None else np.random
        fit_rows = []

        if verbose:
            print(f"Running bootstrap with {n_bootstrap} iterations...")

        groups = self.titrations.group_by('resid', maintain_order=True)
        for i, (key, subdf) in enumerate(groups):
            resid = key[0]
            resname = self.resid_to_resname.get(resid, 'UNK')
            x = subdf['current_pH'].to_numpy().astype(float)
            y = subdf['fraction_protonated'].to_numpy().astype(float)
            n_samples = subdf['n_samples'].to_numpy().astype(int)

            if verbose and (i + 1) % 20 == 0:
                print(f" {i + 1}/{len(self.resid_cols)} residues...")

            if x.size < self._MIN_POINTS:
                fit_rows.append({
                    'resid': resid,
                    'resname': resname,
                    'pKa': np.nan,
                    'pKa_lo': np.nan,
                    'pKa_hi': np.nan,
                    'Hill_n': np.nan,
                    'Hill_n_lo': np.nan,
                    'Hill_n_hi': np.nan,
                    'n_points': int(x.size),
                    'method': 'bootstrap'
                })
                continue

            # Point estimate; the same initial guess seeds every bootstrap fit
            p0 = self._initial_guess(x, y)
            result = self._fit_hill(x, y, p0, maxfev=5000)
            pKa_point, hill_n_point = result[0] if result is not None else (np.nan, np.nan)

            pKa_boots = []
            hill_n_boots = []
            for _ in range(n_bootstrap):
                y_boot = rng.binomial(n_samples, y) / n_samples
                boot = self._fit_hill(x, y_boot, p0, maxfev=2000)
                if boot is not None:
                    pKa_boots.append(boot[0][0])
                    hill_n_boots.append(boot[0][1])

            if len(pKa_boots) > 10:
                pKa_lo, pKa_hi = np.percentile(pKa_boots, [2.5, 97.5])
                hill_n_lo, hill_n_hi = np.percentile(hill_n_boots, [2.5, 97.5])
            else:
                pKa_lo, pKa_hi = np.nan, np.nan
                hill_n_lo, hill_n_hi = np.nan, np.nan

            fit_rows.append({
                'resid': resid,
                'resname': resname,
                'pKa': float(pKa_point),
                'pKa_lo': float(pKa_lo),
                'pKa_hi': float(pKa_hi),
                'Hill_n': float(hill_n_point),
                'Hill_n_lo': float(hill_n_lo),
                'Hill_n_hi': float(hill_n_hi),
                'n_points': int(x.size),
                'method': 'bootstrap'
            })

        return pl.DataFrame(fit_rows)

    def compute_titrations(
        self,
        verbose: bool = False,
        n_bootstrap: int = 1000,
        seed: Optional[int] = None,
    ) -> None:
        """Compute titrations with ``self.method`` and store them in ``self.fits``.

        Raises
        ------
        ValueError
            If ``self.method`` is not 'curvefit', 'weighted' or 'bootstrap'.
        """
        if self.method == 'curvefit':
            self.fits = self.compute_titrations_curvefit()
        elif self.method == 'weighted':
            self.fits = self.compute_titrations_weighted(verbose=verbose)
        elif self.method == 'bootstrap':
            self.fits = self.compute_titrations_bootstrap(
                n_bootstrap=n_bootstrap, verbose=verbose, seed=seed
            )
        else:
            raise ValueError(f"Unknown method: {self.method}. Use 'curvefit', 'weighted', or 'bootstrap'")

    def postprocess(self) -> None:
        """Generate fitted curves (``self.curves``) on a 200-point pH grid.

        Raises
        ------
        RuntimeError
            If compute_titrations() has not been called.
        """
        if getattr(self, 'fits', None) is None:
            raise RuntimeError("Must call compute_titrations() first")

        pH_grid = np.linspace(
            float(self.df['current_pH'].min()),
            float(self.df['current_pH'].max()),
            200
        )

        curves = []
        for row in self.fits.iter_rows(named=True):
            pKa = row['pKa']
            n = row['Hill_n']
            if np.isnan(pKa) or np.isnan(n):
                continue

            curves.append(
                pl.DataFrame({
                    'resid': [row['resid']] * len(pH_grid),
                    'pH': pH_grid,
                    'fraction_protonated_fit': self.hill_equation(pH_grid, pKa, n),
                })
            )

        self.curves = pl.concat(curves) if curves else None

        if self.make_plots:
            self.plot()

    def plot(self) -> None:
        """Generate plots (to be implemented)."""
        pass

    def diagnose_residue(self, resid: str, verbose: bool = True) -> Dict:
        """
        Diagnose why a residue might have failed pKa determination.

        Parameters
        ----------
        resid : str
            Residue ID to diagnose (an int is converted to str)
        verbose : bool
            Print diagnostic information

        Returns
        -------
        dict with diagnostic info including titration curve data
        """
        # Residue IDs are stored as strings (unpivoted column names)
        resid = str(resid)
        resid_data = self.titrations.filter(pl.col('resid') == resid)

        pH_vals = resid_data['current_pH'].to_numpy()
        frac_prot = resid_data['fraction_protonated'].to_numpy()
        n_samples = resid_data['n_samples'].to_numpy()

        resid_states = self.df_long.filter(pl.col('resid') == resid)
        state_counts = resid_states.group_by('state').agg(pl.len().alias('count'))

        resname = self.resid_to_resname.get(resid, 'UNK')
        has_data = len(frac_prot) > 0

        result = {
            'resid': resid,
            'resname': resname,
            'pH': pH_vals,
            'fraction_protonated': frac_prot,
            'n_samples': n_samples,
            'state_distribution': state_counts.to_dict(),
            'frac_min': frac_prot.min() if has_data else np.nan,
            'frac_max': frac_prot.max() if has_data else np.nan,
        }

        if verbose:
            print(f"\nDiagnostics for residue {resid} ({resname}):")
            print(" State distribution:")
            for row in state_counts.iter_rows(named=True):
                print(f" {row['state']}: {row['count']}")
            print("\n Titration curve (simple average):")
            print(f" {'pH':>6s} {'frac':>6s} {'n':>5s}")
            for pH, f, n in zip(pH_vals, frac_prot, n_samples):
                print(f" {pH:6.2f} {f:6.3f} {n:5d}")
            print(f"\n Fraction range: {result['frac_min']:.3f} - {result['frac_max']:.3f}")

            if result['frac_min'] > 0.5:
                print(f" → Always >50% protonated - pKa likely ABOVE pH {pH_vals.max():.1f}")
            elif result['frac_max'] < 0.5:
                print(f" → Always <50% protonated - pKa likely BELOW pH {pH_vals.min():.1f}")
            elif result['frac_max'] - result['frac_min'] < 0.1:
                print(" → Very little titration observed - may not titrate in this pH range")

        return result

    @staticmethod
    def hill_equation(pH: float, pKa: float, n: float) -> float:
        """
        Hill equation for acid-base equilibrium.

        Returns fraction protonated, 1 / (1 + 10**(n * (pH - pKa))), as a
        function of pH (elementwise for NumPy arrays).
        """
        return 1.0 / (1.0 + 10.0**(n * (pH - pKa)))

    @property
    def protonation_mapping(self) -> Dict[str, int]:
        """Map state names to protonation numbers (1 = protonated, 0 = not)."""
        # NOTE(review): in Amber naming CYM is the deprotonated thiolate and
        # CYX a disulfide-bonded cysteine; confirm which name the cpH log uses.
        return {
            'ASH': 1, 'ASP': 0,
            'GLH': 1, 'GLU': 0,
            'LYS': 1, 'LYN': 0,
            'CYS': 1, 'CYX': 0,
            'HIP': 1, 'HIE': 0, 'HID': 0,
        }

    @property
    def canonical_resname(self) -> Dict[str, str]:
        """Map any state name to canonical residue name."""
        return {
            'ASH': 'ASP', 'ASP': 'ASP',
            'GLH': 'GLU', 'GLU': 'GLU',
            'LYS': 'LYS', 'LYN': 'LYS',
            'CYS': 'CYS', 'CYX': 'CYS',
            'HIP': 'HIS', 'HIE': 'HIS', 'HID': 'HIS',
        }

    def compare_methods(self, resids: Optional[List[str]] = None) -> pl.DataFrame:
        """
        Compare unweighted (curvefit) vs weighted curve-fit results.

        (Previously this called a non-existent UWHAM method and always failed.)

        Parameters
        ----------
        resids : List[str], optional
            Residues to compare. If None, compares all.

        Returns
        -------
        DataFrame with curvefit columns, pKa_weighted / Hill_n_weighted, and
        pKa_diff / Hill_n_diff (curvefit minus weighted)
        """
        fits_cf = self.compute_titrations_curvefit()
        fits_w = self.compute_titrations_weighted()

        comparison = fits_cf.join(
            fits_w.select(['resid', 'pKa', 'Hill_n']),
            on='resid',
            suffix='_weighted'
        ).with_columns([
            (pl.col('pKa') - pl.col('pKa_weighted')).alias('pKa_diff'),
            (pl.col('Hill_n') - pl.col('Hill_n_weighted')).alias('Hill_n_diff'),
        ])

        if resids is not None:
            comparison = comparison.filter(
                pl.col('resid').is_in([str(r) for r in resids])
            )

        return comparison

841 

842 

843class TitrationAnalyzer: 

844 """ 

845 High-level analyzer for constant pH simulations. 

846  

847 Provides a streamlined API that runs both curve fitting and UWHAM analysis, 

848 generates comparisons, and creates publication-quality plots. 

849  

850 Example usage 

851 ------------- 

852 >>> analyzer = TitrationAnalyzer(['cpH.log']) 

853 >>> analyzer.run() 

854 >>> analyzer.summary() 

855 >>> analyzer.plot_residue('145') 

856 >>> analyzer.plot_all(output_dir='plots/') 

857 >>> analyzer.save_results('results/') 

858 """ 

859 

860 def __init__( 

861 self, 

862 log_files: Path | List[Path] | str | List[str], 

863 output_dir: Optional[Path | str] = None, 

864 ): 

865 """ 

866 Initialize the analyzer. 

867  

868 Parameters 

869 ---------- 

870 log_files : Path, str, or list thereof 

871 Path(s) to constant pH log file(s) 

872 output_dir : Path or str, optional 

873 Directory for output files. If None, uses current directory. 

874 """ 

875 if isinstance(log_files, (str, Path)): 

876 log_files = [log_files] 

877 self.log_files = [Path(f) for f in log_files] 

878 

879 self.output_dir = Path(output_dir) if output_dir else Path('.') 

880 self.output_dir.mkdir(parents=True, exist_ok=True) 

881 

882 # Results storage 

883 self.fits_curvefit: Optional[pl.DataFrame] = None 

884 self.fits_weighted: Optional[pl.DataFrame] = None 

885 self.fits_bootstrap: Optional[pl.DataFrame] = None 

886 self.comparison: Optional[pl.DataFrame] = None 

887 self.titration_data: Optional[pl.DataFrame] = None 

888 

889 # Internal objects 

890 self._tc: Optional[TitrationCurve] = None 

891 

892 # Metadata 

893 self.resid_to_resname: Dict[str, str] = {} 

894 self.resid_cols: List[str] = [] 

895 

896 self._analyzed = False 

897 

898 def run( 

899 self, 

900 methods: List[str] = ['curvefit', 'weighted'], 

901 verbose: bool = True, 

902 n_bootstrap: int = 1000, 

903 ) -> 'TitrationAnalyzer': 

904 """ 

905 Run the analysis with specified methods. 

906  

907 Parameters 

908 ---------- 

909 methods : list of str 

910 Methods to run: 'curvefit', 'weighted', 'bootstrap' 

911 - curvefit: Simple least squares fit of Hill equation 

912 - weighted: Weighted least squares (by 1/variance) 

913 - bootstrap: Curve fit with bootstrap confidence intervals 

914 verbose : bool 

915 Print progress information 

916 n_bootstrap : int 

917 Number of bootstrap iterations (only used if 'bootstrap' in methods) 

918  

919 Returns 

920 ------- 

921 self : for method chaining 

922 """ 

923 if verbose: 

924 print("=" * 60) 

925 print("Constant pH Titration Analysis") 

926 print("=" * 60) 

927 print(f"Log files: {[str(f) for f in self.log_files]}") 

928 

929 # Initialize and prepare 

930 self._tc = TitrationCurve(self.log_files, make_plots=False) 

931 self._tc.prepare() 

932 

933 # Store data for plotting 

934 self.titration_data = self._tc.titrations.clone() 

935 self.resid_to_resname = self._tc.resid_to_resname.copy() 

936 self.resid_cols = self._tc.resid_cols.copy() 

937 

938 if verbose: 

939 n_residues = len(self._tc.resid_cols) 

940 pH_vals = self._tc.df['current_pH'].unique().sort() 

941 print(f"Residues: {n_residues}") 

942 print(f"pH values: {pH_vals.to_list()}") 

943 print(f"Total samples: {len(self._tc.df)}") 

944 

945 # Curve fitting 

946 if 'curvefit' in methods: 

947 if verbose: 

948 print("\n" + "-" * 40) 

949 print("Running curve fitting...") 

950 self.fits_curvefit = self._tc.compute_titrations_curvefit() 

951 if verbose: 

952 n_success = self.fits_curvefit.filter(pl.col('pKa').is_not_nan()).height 

953 print(f" Success: {n_success}/{len(self.fits_curvefit)} residues") 

954 

955 # Weighted fitting 

956 if 'weighted' in methods: 

957 if verbose: 

958 print("\n" + "-" * 40) 

959 print("Running weighted curve fitting...") 

960 self.fits_weighted = self._tc.compute_titrations_weighted(verbose=verbose) 

961 if verbose: 

962 n_success = self.fits_weighted.filter(pl.col('pKa').is_not_nan()).height 

963 print(f" Success: {n_success}/{len(self.fits_weighted)} residues") 

964 

965 # Bootstrap 

966 if 'bootstrap' in methods: 

967 if verbose: 

968 print("\n" + "-" * 40) 

969 print(f"Running bootstrap ({n_bootstrap} iterations)...") 

970 self.fits_bootstrap = self._tc.compute_titrations_bootstrap( 

971 n_bootstrap=n_bootstrap, verbose=verbose 

972 ) 

973 if verbose: 

974 n_success = self.fits_bootstrap.filter(pl.col('pKa').is_not_nan()).height 

975 print(f" Success: {n_success}/{len(self.fits_bootstrap)} residues") 

976 

977 # Generate comparison if multiple methods ran 

978 if self.fits_curvefit is not None and self.fits_weighted is not None: 

979 self._generate_comparison() 

980 

981 self._analyzed = True 

982 

983 if verbose: 

984 print("\n" + "=" * 60) 

985 print("Analysis complete!") 

986 print("=" * 60) 

987 

988 return self 

989 

990 def _generate_comparison(self) -> None: 

991 """Generate comparison DataFrame between curvefit and weighted methods.""" 

992 self.comparison = self.fits_curvefit.join( 

993 self.fits_weighted.select(['resid', 'pKa', 'Hill_n']), 

994 on='resid', 

995 suffix='_weighted' 

996 ).with_columns([ 

997 (pl.col('pKa') - pl.col('pKa_weighted')).alias('pKa_diff'), 

998 (pl.col('Hill_n') - pl.col('Hill_n_weighted')).alias('Hill_n_diff'), 

999 ]) 

1000 

1001 def summary(self, show_all: bool = False) -> pl.DataFrame: 

1002 """ 

1003 Print and return summary of results. 

1004  

1005 Parameters 

1006 ---------- 

1007 show_all : bool 

1008 If True, show all residues. Otherwise show first 20. 

1009  

1010 Returns 

1011 ------- 

1012 DataFrame with comparison results 

1013 """ 

1014 if not self._analyzed: 

1015 raise RuntimeError("Must call run() before summary()") 

1016 

1017 if self.comparison is not None: 

1018 successful = self.comparison.filter( 

1019 pl.col('pKa').is_not_nan() & pl.col('pKa_weighted').is_not_nan() 

1020 ) 

1021 

1022 print(f"\nComparison Summary ({len(successful)} residues with both methods successful):") 

1023 print("-" * 60) 

1024 

1025 if len(successful) > 0: 

1026 delta = successful['pKa_diff'].to_numpy() 

1027 print(f"ΔpKa (curvefit - weighted):") 

1028 print(f" Mean: {np.mean(delta):+.3f}") 

1029 print(f" Std: {np.std(delta):.3f}") 

1030 print(f" Median: {np.median(delta):+.3f}") 

1031 print(f" Range: [{np.min(delta):.3f}, {np.max(delta):.3f}]") 

1032 

1033 display_df = successful.select([ 

1034 'resid', 'resname', 'pKa', 'pKa_weighted', 'pKa_diff', 

1035 'Hill_n', 'Hill_n_weighted' 

1036 ]) 

1037 

1038 if not show_all and len(display_df) > 20: 

1039 print(f"\nShowing first 20 of {len(display_df)} residues (use show_all=True for all):") 

1040 print(display_df.head(20)) 

1041 else: 

1042 print(display_df) 

1043 

1044 return self.comparison 

1045 

1046 elif self.fits_curvefit is not None: 

1047 print("\nCurve Fitting Results:") 

1048 print(self.fits_curvefit if show_all else self.fits_curvefit.head(20)) 

1049 return self.fits_curvefit 

1050 

1051 elif self.fits_weighted is not None: 

1052 print("\nWeighted Fitting Results:") 

1053 print(self.fits_weighted if show_all else self.fits_weighted.head(20)) 

1054 return self.fits_weighted 

1055 

1056 elif self.fits_bootstrap is not None: 

1057 print("\nBootstrap Results:") 

1058 print(self.fits_bootstrap if show_all else self.fits_bootstrap.head(20)) 

1059 return self.fits_bootstrap 

1060 

1061 return None 

1062 

1063 def get_results(self, method: str = 'curvefit') -> pl.DataFrame: 

1064 """ 

1065 Get results DataFrame for specified method. 

1066  

1067 Parameters 

1068 ---------- 

1069 method : str 

1070 'curvefit', 'weighted', 'bootstrap', or 'comparison' 

1071 """ 

1072 if method == 'curvefit': 

1073 return self.fits_curvefit 

1074 elif method == 'weighted': 

1075 return self.fits_weighted 

1076 elif method == 'bootstrap': 

1077 return self.fits_bootstrap 

1078 elif method == 'comparison': 

1079 return self.comparison 

1080 else: 

1081 raise ValueError(f"Unknown method: {method}") 

1082 

1083 def plot_residue( 

1084 self, 

1085 resid: str, 

1086 ax: Optional['plt.Axes'] = None, 

1087 show_curvefit: bool = True, 

1088 show_weighted: bool = True, 

1089 show_data: bool = True, 

1090 figsize: Tuple[float, float] = (8, 6), 

1091 save: Optional[str | Path] = None, 

1092 ) -> 'plt.Figure': 

1093 """ 

1094 Plot titration curve for a single residue. 

1095  

1096 Parameters 

1097 ---------- 

1098 resid : str 

1099 Residue ID to plot 

1100 ax : matplotlib Axes, optional 

1101 Axes to plot on. If None, creates new figure. 

1102 show_curvefit : bool 

1103 Show curve fitting result 

1104 show_weighted : bool 

1105 Show weighted fit result 

1106 show_data : bool 

1107 Show raw data points 

1108 figsize : tuple 

1109 Figure size if creating new figure 

1110 save : str or Path, optional 

1111 Path to save figure 

1112  

1113 Returns 

1114 ------- 

1115 matplotlib Figure 

1116 """ 

1117 try: 

1118 import matplotlib.pyplot as plt 

1119 except ImportError: 

1120 raise ImportError("matplotlib required for plotting: pip install matplotlib") 

1121 

1122 if not self._analyzed: 

1123 raise RuntimeError("Must call run() before plotting") 

1124 

1125 if ax is None: 

1126 fig, ax = plt.subplots(figsize=figsize) 

1127 else: 

1128 fig = ax.get_figure() 

1129 

1130 resname = self.resid_to_resname.get(resid, 'UNK') 

1131 

1132 # Raw data 

1133 resid_data = self.titration_data.filter(pl.col('resid') == resid) 

1134 pH_data = resid_data['current_pH'].to_numpy() 

1135 frac_data = resid_data['fraction_protonated'].to_numpy() 

1136 n_samples = resid_data['n_samples'].to_numpy() 

1137 

1138 # Standard error for binomial 

1139 se = np.sqrt(frac_data * (1 - frac_data) / np.maximum(n_samples, 1)) 

1140 

1141 # Plot data points 

1142 if show_data: 

1143 ax.errorbar( 

1144 pH_data, frac_data, yerr=se, 

1145 fmt='o', color='black', markersize=8, 

1146 capsize=3, capthick=1, elinewidth=1, 

1147 label='Data', zorder=10 

1148 ) 

1149 

1150 # pH grid for curves 

1151 pH_grid = np.linspace( 

1152 min(pH_data) - 0.5, 

1153 max(pH_data) + 0.5, 

1154 200 

1155 ) 

1156 

1157 # Curve fit line (unweighted) 

1158 if show_curvefit and self.fits_curvefit is not None: 

1159 cf_row = self.fits_curvefit.filter(pl.col('resid') == resid) 

1160 if len(cf_row) > 0: 

1161 pKa_cf = cf_row['pKa'][0] 

1162 n_cf = cf_row['Hill_n'][0] 

1163 if not np.isnan(pKa_cf) and not np.isnan(n_cf): 

1164 y_cf = TitrationCurve.hill_equation(pH_grid, pKa_cf, n_cf) 

1165 ax.plot( 

1166 pH_grid, y_cf, '-', color='blue', linewidth=2, 

1167 label=f'Curve fit (pKa={pKa_cf:.2f}, n={n_cf:.2f})' 

1168 ) 

1169 ax.axvline(pKa_cf, color='blue', linestyle=':', alpha=0.5) 

1170 

1171 # Weighted fit line 

1172 if show_weighted and self.fits_weighted is not None: 

1173 wt_row = self.fits_weighted.filter(pl.col('resid') == resid) 

1174 if len(wt_row) > 0: 

1175 pKa_wt = wt_row['pKa'][0] 

1176 n_wt = wt_row['Hill_n'][0] 

1177 if not np.isnan(pKa_wt) and not np.isnan(n_wt): 

1178 y_wt = TitrationCurve.hill_equation(pH_grid, pKa_wt, n_wt) 

1179 ax.plot( 

1180 pH_grid, y_wt, '--', color='red', linewidth=2, 

1181 label=f'Weighted (pKa={pKa_wt:.2f}, n={n_wt:.2f})' 

1182 ) 

1183 ax.axvline(pKa_wt, color='red', linestyle=':', alpha=0.5) 

1184 

1185 # Formatting 

1186 ax.set_xlabel('pH', fontsize=12) 

1187 ax.set_ylabel('Fraction Protonated', fontsize=12) 

1188 ax.set_title(f'Residue {resid} ({resname})', fontsize=14) 

1189 ax.set_ylim(-0.05, 1.05) 

1190 ax.axhline(0.5, color='gray', linestyle='--', alpha=0.3) 

1191 ax.legend(loc='best', fontsize=10) 

1192 ax.grid(True, alpha=0.3) 

1193 

1194 plt.tight_layout() 

1195 

1196 if save: 

1197 fig.savefig(save, dpi=150, bbox_inches='tight') 

1198 

1199 return fig 

1200 

1201 def plot_all( 

1202 self, 

1203 output_dir: Optional[str | Path] = None, 

1204 format: str = 'png', 

1205 show_curvefit: bool = True, 

1206 show_weighted: bool = True, 

1207 residues: Optional[List[str]] = None, 

1208 verbose: bool = True, 

1209 ) -> None: 

1210 """ 

1211 Generate plots for all (or selected) residues. 

1212  

1213 Parameters 

1214 ---------- 

1215 output_dir : str or Path, optional 

1216 Directory for plots. Uses self.output_dir / 'plots' if None. 

1217 format : str 

1218 Image format ('png', 'pdf', 'svg') 

1219 show_curvefit : bool 

1220 Include curve fitting results 

1221 show_weighted : bool 

1222 Include weighted fit results 

1223 residues : list of str, optional 

1224 Specific residues to plot. If None, plots all. 

1225 verbose : bool 

1226 Print progress 

1227 """ 

1228 try: 

1229 import matplotlib.pyplot as plt 

1230 except ImportError: 

1231 raise ImportError("matplotlib required for plotting") 

1232 

1233 if not self._analyzed: 

1234 raise RuntimeError("Must call run() before plotting") 

1235 

1236 plot_dir = Path(output_dir) if output_dir else self.output_dir / 'plots' 

1237 plot_dir.mkdir(parents=True, exist_ok=True) 

1238 

1239 if residues is None: 

1240 residues = self.resid_cols 

1241 

1242 if verbose: 

1243 print(f"Generating {len(residues)} plots in {plot_dir}/") 

1244 

1245 for i, resid in enumerate(residues): 

1246 resname = self.resid_to_resname.get(resid, 'UNK') 

1247 filename = plot_dir / f"{resname}_{resid}.{format}" 

1248 

1249 fig = self.plot_residue( 

1250 resid, 

1251 show_curvefit=show_curvefit, 

1252 show_weighted=show_weighted, 

1253 save=filename 

1254 ) 

1255 plt.close(fig) 

1256 

1257 if verbose and (i + 1) % 20 == 0: 

1258 print(f" {i + 1}/{len(residues)} plots generated...") 

1259 

1260 if verbose: 

1261 print(f" All {len(residues)} plots saved to {plot_dir}/") 

1262 

1263 def plot_summary( 

1264 self, 

1265 figsize: Tuple[float, float] = (12, 5), 

1266 save: Optional[str | Path] = None, 

1267 ) -> 'plt.Figure': 

1268 """ 

1269 Generate summary plot comparing methods. 

1270  

1271 Creates a 2-panel figure: 

1272 - Left: pKa comparison scatter plot 

1273 - Right: Distribution of pKa differences 

1274 """ 

1275 try: 

1276 import matplotlib.pyplot as plt 

1277 except ImportError: 

1278 raise ImportError("matplotlib required for plotting") 

1279 

1280 if self.comparison is None: 

1281 raise RuntimeError("Need both curvefit and weighted methods for summary plot") 

1282 

1283 successful = self.comparison.filter( 

1284 pl.col('pKa').is_not_nan() & pl.col('pKa_weighted').is_not_nan() 

1285 ) 

1286 

1287 if len(successful) == 0: 

1288 raise ValueError("No residues with both methods successful") 

1289 

1290 fig, axes = plt.subplots(1, 2, figsize=figsize) 

1291 

1292 pKa_cf = successful['pKa'].to_numpy() 

1293 pKa_wt = successful['pKa_weighted'].to_numpy() 

1294 diff = successful['pKa_diff'].to_numpy() 

1295 

1296 # Scatter plot 

1297 ax = axes[0] 

1298 ax.scatter(pKa_cf, pKa_wt, alpha=0.6, edgecolor='black', linewidth=0.5) 

1299 lims = [ 

1300 min(min(pKa_cf), min(pKa_wt)) - 0.5, 

1301 max(max(pKa_cf), max(pKa_wt)) + 0.5 

1302 ] 

1303 ax.plot(lims, lims, 'k--', alpha=0.5) 

1304 ax.set_xlim(lims) 

1305 ax.set_ylim(lims) 

1306 ax.set_xlabel('pKa (Curve Fit)', fontsize=12) 

1307 ax.set_ylabel('pKa (Weighted)', fontsize=12) 

1308 ax.set_title('Method Comparison', fontsize=14) 

1309 ax.set_aspect('equal') 

1310 ax.grid(True, alpha=0.3) 

1311 

1312 corr = np.corrcoef(pKa_cf, pKa_wt)[0, 1] 

1313 ax.text( 

1314 0.05, 0.95, f'r = {corr:.3f}', 

1315 transform=ax.transAxes, fontsize=11, 

1316 verticalalignment='top', 

1317 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8) 

1318 ) 

1319 

1320 # Histogram 

1321 ax = axes[1] 

1322 ax.hist(diff, bins=20, edgecolor='black', alpha=0.7) 

1323 ax.axvline(0, color='red', linestyle='--', linewidth=2) 

1324 ax.axvline(np.mean(diff), color='blue', linestyle='-', linewidth=2, 

1325 label=f'Mean = {np.mean(diff):.3f}') 

1326 ax.set_xlabel('ΔpKa (Curve Fit - Weighted)', fontsize=12) 

1327 ax.set_ylabel('Count', fontsize=12) 

1328 ax.set_title('pKa Difference Distribution', fontsize=14) 

1329 ax.legend(fontsize=10) 

1330 ax.grid(True, alpha=0.3) 

1331 

1332 plt.tight_layout() 

1333 

1334 if save: 

1335 fig.savefig(save, dpi=150, bbox_inches='tight') 

1336 

1337 return fig 

1338 

1339 def save_results( 

1340 self, 

1341 output_dir: Optional[str | Path] = None, 

1342 prefix: str = '', 

1343 formats: List[str] = ['csv'], 

1344 ) -> None: 

1345 """ 

1346 Save all results to files. 

1347  

1348 Parameters 

1349 ---------- 

1350 output_dir : str or Path, optional 

1351 Output directory. Uses self.output_dir if None. 

1352 prefix : str 

1353 Prefix for filenames 

1354 formats : list of str 

1355 Output formats: 'csv', 'parquet', 'json' 

1356 """ 

1357 out_dir = Path(output_dir) if output_dir else self.output_dir 

1358 out_dir.mkdir(parents=True, exist_ok=True) 

1359 

1360 prefix = f"{prefix}_" if prefix else "" 

1361 

1362 def save_df(df: pl.DataFrame, name: str): 

1363 for fmt in formats: 

1364 filepath = out_dir / f"{prefix}{name}.{fmt}" 

1365 if fmt == 'csv': 

1366 df.write_csv(filepath) 

1367 elif fmt == 'parquet': 

1368 df.write_parquet(filepath) 

1369 elif fmt == 'json': 

1370 df.write_json(filepath) 

1371 print(f" Saved {filepath}") 

1372 

1373 print(f"Saving results to {out_dir}/") 

1374 

1375 if self.fits_curvefit is not None: 

1376 save_df(self.fits_curvefit, 'pKa_curvefit') 

1377 if self.fits_weighted is not None: 

1378 save_df(self.fits_weighted, 'pKa_weighted') 

1379 if self.fits_bootstrap is not None: 

1380 save_df(self.fits_bootstrap, 'pKa_bootstrap') 

1381 if self.comparison is not None: 

1382 save_df(self.comparison, 'pKa_comparison') 

1383 if self.titration_data is not None: 

1384 save_df(self.titration_data, 'titration_data') 

1385 

1386 def diagnose(self, resid: str) -> Dict: 

1387 """Get diagnostic information for a residue.""" 

1388 if self._tc is None: 

1389 raise RuntimeError("Must call run() first") 

1390 return self._tc.diagnose_residue(resid, verbose=True) 

1391 

1392 def recommend_protonation( 

1393 self, 

1394 target_pH: float, 

1395 confidence_threshold: float = 0.7, 

1396 use_bootstrap: bool = False, 

1397 verbose: bool = True, 

1398 ) -> pl.DataFrame: 

1399 """ 

1400 Recommend protonation states for a target pH. 

1401  

1402 Uses the fitted titration curves to predict which residues are 

1403 protonated vs deprotonated at the specified pH, with confidence 

1404 estimates based on distance from pKa. 

1405  

1406 Parameters 

1407 ---------- 

1408 target_pH : float 

1409 pH value to make predictions for (e.g., 3.0, 7.4) 

1410 confidence_threshold : float 

1411 Probability threshold for "confident" predictions (default 0.7) 

1412 Residues with P(protonated) between (1-threshold) and threshold 

1413 are marked as "uncertain" 

1414 use_bootstrap : bool 

1415 If True and bootstrap results available, use bootstrap CI for 

1416 uncertainty estimation 

1417 verbose : bool 

1418 Print summary of recommendations 

1419  

1420 Returns 

1421 ------- 

1422 DataFrame with columns: 

1423 - resid: residue ID 

1424 - resname: canonical residue name (ASP, GLU, HIS, LYS, CYS) 

1425 - pKa: fitted pKa value 

1426 - prob_protonated: probability of being protonated at target pH 

1427 - recommendation: 'protonated', 'deprotonated', or 'uncertain' 

1428 - confidence: 'high', 'medium', or 'low' 

1429 - state_name: recommended state name (e.g., 'ASH' or 'ASP') 

1430 """ 

1431 if not self._analyzed: 

1432 raise RuntimeError("Must call run() before recommend_protonation()") 

1433 

1434 # Use curvefit results (or weighted if available) 

1435 fits = self.fits_curvefit 

1436 if fits is None: 

1437 fits = self.fits_weighted 

1438 if fits is None: 

1439 raise RuntimeError("No fitting results available") 

1440 

1441 # State name mappings 

1442 protonated_state = { 

1443 'ASP': 'ASH', 'GLU': 'GLH', 'HIS': 'HIP', 

1444 'LYS': 'LYS', 'CYS': 'CYS' 

1445 } 

1446 deprotonated_state = { 

1447 'ASP': 'ASP', 'GLU': 'GLU', 'HIS': 'HIE', 

1448 'LYS': 'LYN', 'CYS': 'CYX' 

1449 } 

1450 

1451 # Reference pKa values for sanity checking 

1452 reference_pKa = { 

1453 'ASP': 3.9, 'GLU': 4.3, 'HIS': 6.0, 

1454 'LYS': 10.5, 'CYS': 8.3 

1455 } 

1456 

1457 recommendations = [] 

1458 

1459 for row in fits.iter_rows(named=True): 

1460 resid = row['resid'] 

1461 resname = row['resname'] 

1462 pKa = row['pKa'] 

1463 hill_n = row['Hill_n'] 

1464 

1465 # Compute probability of being protonated at target pH 

1466 if np.isnan(pKa) or np.isnan(hill_n): 

1467 # No fit available - use reference pKa 

1468 ref_pKa = reference_pKa.get(resname, 7.0) 

1469 prob_prot = 1.0 / (1.0 + 10**(target_pH - ref_pKa)) 

1470 pKa_used = ref_pKa 

1471 fit_quality = 'reference' 

1472 else: 

1473 # Use fitted Hill equation 

1474 prob_prot = TitrationCurve.hill_equation(target_pH, pKa, hill_n) 

1475 pKa_used = pKa 

1476 fit_quality = 'fitted' 

1477 

1478 # Determine recommendation 

1479 if prob_prot >= confidence_threshold: 

1480 recommendation = 'protonated' 

1481 state_name = protonated_state.get(resname, resname) 

1482 elif prob_prot <= (1 - confidence_threshold): 

1483 recommendation = 'deprotonated' 

1484 state_name = deprotonated_state.get(resname, resname) 

1485 else: 

1486 recommendation = 'uncertain' 

1487 # For uncertain cases, go with majority 

1488 if prob_prot >= 0.5: 

1489 state_name = protonated_state.get(resname, resname) 

1490 else: 

1491 state_name = deprotonated_state.get(resname, resname) 

1492 

1493 # Confidence based on distance from 0.5 

1494 prob_distance = abs(prob_prot - 0.5) 

1495 if prob_distance > 0.4: # >90% or <10% 

1496 confidence = 'high' 

1497 elif prob_distance > 0.2: # >70% or <30% 

1498 confidence = 'medium' 

1499 else: 

1500 confidence = 'low' 

1501 

1502 recommendations.append({ 

1503 'resid': resid, 

1504 'resname': resname, 

1505 'pKa': pKa_used, 

1506 'pKa_source': fit_quality, 

1507 'prob_protonated': prob_prot, 

1508 'recommendation': recommendation, 

1509 'confidence': confidence, 

1510 'state_name': state_name, 

1511 }) 

1512 

1513 result = pl.DataFrame(recommendations) 

1514 

1515 if verbose: 

1516 print(f"\n{'='*60}") 

1517 print(f"Protonation Recommendations at pH {target_pH}") 

1518 print(f"{'='*60}") 

1519 

1520 # Summary counts 

1521 n_prot = result.filter(pl.col('recommendation') == 'protonated').height 

1522 n_deprot = result.filter(pl.col('recommendation') == 'deprotonated').height 

1523 n_uncertain = result.filter(pl.col('recommendation') == 'uncertain').height 

1524 

1525 print(f"\nSummary:") 

1526 print(f" Protonated: {n_prot:3d} residues") 

1527 print(f" Deprotonated: {n_deprot:3d} residues") 

1528 print(f" Uncertain: {n_uncertain:3d} residues") 

1529 

1530 # Group by residue type 

1531 print(f"\nBy residue type:") 

1532 for restype in ['ASP', 'GLU', 'HIS', 'LYS', 'CYS']: 

1533 subset = result.filter(pl.col('resname') == restype) 

1534 if len(subset) > 0: 

1535 n_p = subset.filter(pl.col('recommendation') == 'protonated').height 

1536 n_d = subset.filter(pl.col('recommendation') == 'deprotonated').height 

1537 n_u = subset.filter(pl.col('recommendation') == 'uncertain').height 

1538 ref = reference_pKa.get(restype, '?') 

1539 print(f" {restype} (ref pKa={ref}): {n_p} prot, {n_d} deprot, {n_u} uncertain") 

1540 

1541 # Show uncertain residues (most important to check) 

1542 uncertain = result.filter(pl.col('recommendation') == 'uncertain') 

1543 if len(uncertain) > 0: 

1544 print(f"\n⚠️ Uncertain residues (prob between {1-confidence_threshold:.0%}-{confidence_threshold:.0%}):") 

1545 for row in uncertain.sort('prob_protonated', descending=True).iter_rows(named=True): 

1546 print(f" {row['resname']} {row['resid']}: " 

1547 f"P(prot)={row['prob_protonated']:.1%}, " 

1548 f"pKa={row['pKa']:.1f}{row['state_name']}") 

1549 

1550 # Show residues with pKa near target pH 

1551 near_pKa = result.filter( 

1552 (pl.col('pKa') > target_pH - 1.5) & 

1553 (pl.col('pKa') < target_pH + 1.5) & 

1554 (pl.col('pKa_source') == 'fitted') 

1555 ) 

1556 if len(near_pKa) > 0: 

1557 print(f"\n📍 Residues with pKa near pH {target_pH} (±1.5 units):") 

1558 for row in near_pKa.sort('pKa').iter_rows(named=True): 

1559 print(f" {row['resname']} {row['resid']}: " 

1560 f"pKa={row['pKa']:.2f}, " 

1561 f"P(prot)={row['prob_protonated']:.1%}{row['state_name']}") 

1562 

1563 return result 

1564 

1565 def get_protonation_string( 

1566 self, 

1567 target_pH: float, 

1568 confidence_threshold: float = 0.7, 

1569 ) -> str: 

1570 """ 

1571 Get a simple string of recommended protonation states. 

1572  

1573 Useful for setting up simulations. 

1574  

1575 Parameters 

1576 ---------- 

1577 target_pH : float 

1578 pH value to make predictions for 

1579 confidence_threshold : float 

1580 Probability threshold for confident predictions 

1581  

1582 Returns 

1583 ------- 

1584 String with format: "resid:state,resid:state,..." 

1585 """ 

1586 recs = self.recommend_protonation( 

1587 target_pH, 

1588 confidence_threshold=confidence_threshold, 

1589 verbose=False 

1590 ) 

1591 

1592 parts = [] 

1593 for row in recs.iter_rows(named=True): 

1594 parts.append(f"{row['resid']}:{row['state_name']}") 

1595 

1596 return ','.join(parts) 

1597 

1598 def export_protonation_states( 

1599 self, 

1600 target_pH: float, 

1601 output_file: Optional[str | Path] = None, 

1602 format: str = 'csv', 

1603 confidence_threshold: float = 0.7, 

1604 ) -> pl.DataFrame: 

1605 """ 

1606 Export protonation state recommendations to file. 

1607  

1608 Parameters 

1609 ---------- 

1610 target_pH : float 

1611 pH value to make predictions for 

1612 output_file : str or Path, optional 

1613 Output file path. If None, uses output_dir/protonation_pH{pH}.{format} 

1614 format : str 

1615 Output format: 'csv', 'json', or 'txt' 

1616 confidence_threshold : float 

1617 Probability threshold for confident predictions 

1618  

1619 Returns 

1620 ------- 

1621 DataFrame with recommendations 

1622 """ 

1623 recs = self.recommend_protonation( 

1624 target_pH, 

1625 confidence_threshold=confidence_threshold, 

1626 verbose=False 

1627 ) 

1628 

1629 if output_file is None: 

1630 output_file = self.output_dir / f"protonation_pH{target_pH:.1f}.{format}" 

1631 else: 

1632 output_file = Path(output_file) 

1633 

1634 if format == 'csv': 

1635 recs.write_csv(output_file) 

1636 elif format == 'json': 

1637 recs.write_json(output_file) 

1638 elif format == 'txt': 

1639 # Simple text format for easy reading 

1640 with open(output_file, 'w') as f: 

1641 f.write(f"# Protonation states at pH {target_pH}\n") 

1642 f.write(f"# confidence_threshold = {confidence_threshold}\n") 

1643 f.write("#\n") 

1644 f.write("# resid resname state prob_prot confidence\n") 

1645 for row in recs.iter_rows(named=True): 

1646 f.write(f"{row['resid']:>6s} {row['resname']:>7s} " 

1647 f"{row['state_name']:>5s} {row['prob_protonated']:>9.3f} " 

1648 f"{row['confidence']}\n") 

1649 

1650 print(f"Saved protonation recommendations to {output_file}") 

1651 return recs 

1652 

1653 def plot_protonation_summary( 

1654 self, 

1655 target_pH: float, 

1656 figsize: Tuple[float, float] = (12, 6), 

1657 save: Optional[str | Path] = None, 

1658 ) -> 'plt.Figure': 

1659 """ 

1660 Visualize protonation probabilities at target pH. 

1661  

1662 Creates a bar plot showing P(protonated) for each residue, 

1663 colored by residue type. 

1664  

1665 Parameters 

1666 ---------- 

1667 target_pH : float 

1668 pH value to visualize 

1669 figsize : tuple 

1670 Figure size 

1671 save : str or Path, optional 

1672 Path to save figure 

1673  

1674 Returns 

1675 ------- 

1676 matplotlib Figure 

1677 """ 

1678 try: 

1679 import matplotlib.pyplot as plt 

1680 from matplotlib.patches import Patch 

1681 except ImportError: 

1682 raise ImportError("matplotlib required for plotting") 

1683 

1684 recs = self.recommend_protonation(target_pH, verbose=False) 

1685 

1686 # Sort by probability 

1687 recs_sorted = recs.sort('prob_protonated', descending=True) 

1688 

1689 fig, ax = plt.subplots(figsize=figsize) 

1690 

1691 # Colors for each residue type 

1692 colors = { 

1693 'ASP': '#e41a1c', # red 

1694 'GLU': '#ff7f00', # orange 

1695 'HIS': '#4daf4a', # green 

1696 'LYS': '#377eb8', # blue 

1697 'CYS': '#984ea3', # purple 

1698 } 

1699 

1700 x = np.arange(len(recs_sorted)) 

1701 probs = recs_sorted['prob_protonated'].to_numpy() 

1702 resnames = recs_sorted['resname'].to_list() 

1703 resids = recs_sorted['resid'].to_list() 

1704 

1705 bar_colors = [colors.get(rn, 'gray') for rn in resnames] 

1706 

1707 bars = ax.bar(x, probs, color=bar_colors, edgecolor='black', linewidth=0.5) 

1708 

1709 # Add 0.5 line 

1710 ax.axhline(0.5, color='black', linestyle='--', linewidth=2, alpha=0.7) 

1711 ax.axhline(0.7, color='gray', linestyle=':', linewidth=1, alpha=0.5) 

1712 ax.axhline(0.3, color='gray', linestyle=':', linewidth=1, alpha=0.5) 

1713 

1714 # Labels 

1715 ax.set_xlabel('Residue', fontsize=12) 

1716 ax.set_ylabel('P(protonated)', fontsize=12) 

1717 ax.set_title(f'Protonation Probabilities at pH {target_pH}', fontsize=14) 

1718 ax.set_ylim(0, 1.05) 

1719 

1720 # X-axis labels (show every Nth label if too many) 

1721 n_residues = len(x) 

1722 if n_residues > 50: 

1723 # Show fewer labels 

1724 step = n_residues // 20 

1725 ax.set_xticks(x[::step]) 

1726 labels = [f"{resnames[i]}{resids[i]}" for i in range(0, n_residues, step)] 

1727 ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8) 

1728 else: 

1729 ax.set_xticks(x) 

1730 labels = [f"{rn}{ri}" for rn, ri in zip(resnames, resids)] 

1731 ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8) 

1732 

1733 # Legend 

1734 legend_elements = [Patch(facecolor=c, edgecolor='black', label=n) 

1735 for n, c in colors.items() if n in resnames] 

1736 ax.legend(handles=legend_elements, loc='upper right', fontsize=10) 

1737 

1738 # Add text annotations for counts 

1739 n_prot = sum(1 for p in probs if p >= 0.5) 

1740 n_deprot = sum(1 for p in probs if p < 0.5) 

1741 ax.text(0.02, 0.98, f'Protonated: {n_prot}\nDeprotonated: {n_deprot}', 

1742 transform=ax.transAxes, fontsize=10, verticalalignment='top', 

1743 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) 

1744 

1745 plt.tight_layout() 

1746 

1747 if save: 

1748 fig.savefig(save, dpi=150, bbox_inches='tight') 

1749 

1750 return fig 

1751 

1752 def __repr__(self) -> str: 

1753 status = "analyzed" if self._analyzed else "not analyzed" 

1754 return f"TitrationAnalyzer({len(self.log_files)} log files, {status})" 

1755 

1756 

def analyze_cph(
    log_files: Path | List[Path] | str | List[str],
    output_dir: Optional[str | Path] = None,
    methods: Optional[List[str]] = None,
    plot: bool = True,
    verbose: bool = True,
) -> TitrationAnalyzer:
    """
    Convenience function to run complete constant pH analysis.

    Runs the requested fitting methods and optionally writes per-residue
    plots plus a method-comparison summary plot. It then saves all result
    tables. A missing matplotlib, or a summary plot that cannot be drawn,
    does not prevent the results from being saved.

    Parameters
    ----------
    log_files : path(s) to log files
    output_dir : output directory
    methods : list of methods to run ('curvefit', 'weighted', 'bootstrap');
        defaults to ['curvefit', 'weighted']
    plot : whether to generate plots
    verbose : print progress

    Returns
    -------
    TitrationAnalyzer with results

    Example
    -------
    >>> results = analyze_cph('cpH.log', output_dir='analysis/')
    >>> results.summary()
    >>> results.plot_residue('145')
    """
    if methods is None:
        methods = ['curvefit', 'weighted']

    analyzer = TitrationAnalyzer(log_files, output_dir=output_dir)
    analyzer.run(methods=methods, verbose=verbose)

    if plot:
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            if verbose:
                print("matplotlib not available, skipping plots")
        else:
            analyzer.plot_all(verbose=verbose)
            try:
                summary_fig = analyzer.plot_summary(save=analyzer.output_dir / 'summary.png')
            except (RuntimeError, ValueError) as err:
                # plot_summary needs both curvefit and weighted results
                # (RuntimeError) and at least one residue fitted by both
                # (ValueError). Neither should stop results being saved.
                if verbose:
                    print(f"Skipping summary plot: {err}")
            else:
                plt.close(summary_fig)

    analyzer.save_results()

    return analyzer

1802 

1803 

if __name__ == '__main__':
    import sys

    # Log files come from the command line; default to ./cpH.log.
    log_paths = [Path(p) for p in sys.argv[1:]] or [Path('cpH.log')]

    # =========================================================================
    # STREAMLINED API - TitrationAnalyzer
    # =========================================================================
    #
    # Available methods:
    # - curvefit: Simple least squares fit (default)
    # - weighted: Weighted least squares (by 1/variance)
    # - bootstrap: Curve fit with bootstrap confidence intervals
    #
    # Basic usage:
    # analyzer = TitrationAnalyzer(log_paths)
    # analyzer.run()
    # analyzer.summary()
    #
    # Protonation recommendations:
    # recs = analyzer.recommend_protonation(target_pH=3.0)
    # analyzer.plot_protonation_summary(target_pH=3.0)
    #
    # =========================================================================

    output_dir = Path('cph_analysis')
    target_pH = 3.0

    analyzer = TitrationAnalyzer(log_paths, output_dir=output_dir)

    # Run curve fitting and weighted fitting, then print the comparison.
    analyzer.run(methods=['curvefit', 'weighted'], verbose=True)
    analyzer.summary()

    # Plots are optional: skip them cleanly when matplotlib is missing.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        plt = None
        print("\nSkipping plots (matplotlib not installed)")

    if plt is not None:
        analyzer.plot_all(verbose=True)
        try:
            summary_fig = analyzer.plot_summary(save=output_dir / 'summary.png')
        except (RuntimeError, ValueError) as err:
            # A summary plot that cannot be drawn (e.g. no residue fitted by
            # both methods) must not stop the results from being saved.
            print(f"\nSkipping summary plot: {err}")
        else:
            plt.close(summary_fig)
        print(f"\nPlots saved to {output_dir}/plots/")

    analyzer.save_results()

    # =========================================================================
    # PROTONATION RECOMMENDATIONS
    # =========================================================================

    print("\n")
    analyzer.recommend_protonation(target_pH=target_pH)

    analyzer.export_protonation_states(target_pH=target_pH, format='csv')

    if plt is not None:
        protonation_fig = analyzer.plot_protonation_summary(
            target_pH=target_pH,
            save=output_dir / f'protonation_pH{target_pH:.1f}.png',
        )
        plt.close(protonation_fig)

    # Can also get recommendations for physiological pH
    # analyzer.recommend_protonation(target_pH=7.4)