Coverage for src / molecular_simulations / analysis / constant_pH_analysis.py: 35%
754 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-13 01:26 -0600
1"""
2Improved constant pH analysis with UWHAM reweighting.
4This implementation adds multistate analysis capabilities to the basic
5curve fitting approach. Uses log-space arithmetic for numerical stability.
6"""
8from __future__ import annotations
10import ast
11import numpy as np
12from pathlib import Path
13import polars as pl
14import re
15from scipy.optimize import curve_fit, brentq
16from scipy.special import logsumexp
17from typing import Optional, Dict, List, Tuple, TYPE_CHECKING
18import warnings
20if TYPE_CHECKING:
21 import matplotlib.pyplot as plt
24class UWHAMSolver:
25 """
26 Unbinned Weighted Histogram Analysis Method (UWHAM) solver.
28 NOTE: This class is NOT currently used because UWHAM/MBAR is designed
29 for umbrella sampling and replica exchange, not independent constant pH
30 simulations. For standard constant pH MD where each pH is an independent
31 equilibrium simulation, simple curve fitting is the correct approach.
33 This class is retained for potential use with replica exchange constant
34 pH (REX-cpH) simulations where samples ARE correlated across pH values.
36 Uses log-space arithmetic throughout for numerical stability with
37 large systems (100+ titratable residues).
38 """
40 def __init__(self, tol: float = 1e-7, maxiter: int = 10000):
41 self.tol = tol
42 self.maxiter = maxiter
43 self.f = None # Log of normalization constants (will be solved)
44 self.log10 = np.log(10)
46 def load_data(self, df: pl.DataFrame, resid_cols: List[str]):
47 """
48 Load data from polars DataFrame into UWHAM-compatible format.
50 Parameters
51 ----------
52 df : pl.DataFrame
53 DataFrame with columns: rankid, current_pH, and residue columns
54 Residue columns should contain numeric protonation states (0 or 1)
55 resid_cols : List[str]
56 List of column names corresponding to residue IDs
57 """
58 # Get unique pH values and count samples
59 pH_groups = df.group_by('current_pH').agg(pl.len().alias('count'))
60 self.pH_values = pH_groups['current_pH'].to_numpy()
61 self.nsamples = pH_groups['count'].to_numpy().astype(int)
62 self.nstates = len(self.pH_values)
64 # Sort by pH for consistency
65 sort_idx = np.argsort(self.pH_values)
66 self.pH_values = self.pH_values[sort_idx]
67 self.nsamples = self.nsamples[sort_idx]
69 # Store state data for each pH simulation
70 self.states = {} # resid -> list of arrays (one per pH)
71 self.nprotons_total = [] # Total protons for each pH simulation
73 for resid_col in resid_cols:
74 self.states[resid_col] = []
76 # Extract data for each pH
77 for pH in self.pH_values:
78 pH_data = df.filter(pl.col('current_pH') == pH)
80 # Compute total protons for this pH's samples
81 total_protons = np.zeros(len(pH_data))
83 # For each residue, store states
84 for resid_col in resid_cols:
85 states = pH_data[resid_col].to_numpy().astype(float)
86 self.states[resid_col].append(states)
87 total_protons += states
89 self.nprotons_total.append(total_protons)
91 # Precompute reduced potentials for all state pairs
92 # u_kl[k] is shape (nstates, n_k) - reduced potential of samples from k evaluated at all states
93 self.u_kl = []
94 for k in range(self.nstates):
95 n_k = self.nsamples[k]
96 u_k = np.zeros((self.nstates, n_k))
97 for l in range(self.nstates):
98 u_k[l, :] = self.log10 * self.pH_values[l] * self.nprotons_total[k]
99 self.u_kl.append(u_k)
101 def solve(self, verbose: bool = False):
102 """
103 Solve UWHAM self-consistent equations iteratively.
105 Uses the MBAR equation:
106 f_k = -log(Σ_n exp(-u_k(x_n)) / Σ_l N_l exp(f_l - u_l(x_n)))
108 where the sum over n includes ALL samples from ALL states.
110 Returns
111 -------
112 f : np.ndarray
113 Free energy offsets for each pH simulation
114 """
115 # Initialize free energies
116 f = np.zeros(self.nstates)
117 log_N = np.log(self.nsamples.astype(float))
118 total_samples = sum(self.nsamples)
120 # Precompute reduced potentials for all samples at all target states
121 # u_all[target_k, sample_idx] = reduced potential at state k for sample idx
122 # Also store source state for each sample
123 u_all = np.zeros((self.nstates, total_samples))
124 sample_source = np.zeros(total_samples, dtype=int) # which state each sample came from
126 idx = 0
127 for source_i in range(self.nstates):
128 n_i = self.nsamples[source_i]
129 for n in range(n_i):
130 sample_source[idx] = source_i
131 for target_k in range(self.nstates):
132 # u_k(x_n) = log10 * pH_k * nprotons(x_n)
133 u_all[target_k, idx] = self.log10 * self.pH_values[target_k] * self.nprotons_total[source_i][n]
134 idx += 1
136 for iteration in range(self.maxiter):
137 f_old = f.copy()
139 # Compute denominator for each sample: c_n = Σ_l N_l exp(f_l - u_l(x_n))
140 # log(c_n) = logsumexp(log_N + f - u_l(x_n))
141 log_c = np.zeros(total_samples)
142 for n in range(total_samples):
143 source_i = sample_source[n]
144 # u_l(x_n) for all states l - this is stored in u_kl[source_i]
145 log_c[n] = logsumexp(log_N + f - self.u_kl[source_i][:, n % self.nsamples[source_i]])
147 # Wait, that indexing is wrong. Let me redo this.
148 # Actually I need to recompute using proper indexing
150 log_c = np.zeros(total_samples)
151 idx = 0
152 for source_i in range(self.nstates):
153 n_i = self.nsamples[source_i]
154 for local_n in range(n_i):
155 # u_l(x_n) for all states l
156 log_c[idx] = logsumexp(log_N + f - self.u_kl[source_i][:, local_n])
157 idx += 1
159 # Update each free energy
160 for target_k in range(self.nstates):
161 # f_k = -log(Σ_n exp(-u_k(x_n)) / c_n)
162 # = -log(Σ_n exp(-u_k(x_n) - log(c_n)))
163 # = -logsumexp(-u_all[target_k, :] - log_c)
164 log_weights = -u_all[target_k, :] - log_c
165 f[target_k] = -logsumexp(log_weights)
167 # Normalize so f[0] = 0
168 f = f - f[0]
170 # Check convergence
171 delta = np.abs(f - f_old).max()
172 if verbose and iteration % 100 == 0:
173 print(f" Iteration {iteration}: max|Δf| = {delta:.2e}")
175 if delta < self.tol:
176 if verbose:
177 print(f" Converged after {iteration + 1} iterations")
178 break
179 else:
180 warnings.warn(
181 f"UWHAM did not converge after {self.maxiter} iterations "
182 f"(final delta = {delta:.2e})"
183 )
185 self.f = f
186 self.log_c = log_c # Store for weight computation
187 self.u_all = u_all # Store for weight computation
188 self.sample_source = sample_source
189 self.total_samples = total_samples
191 return f
193 def compute_log_weights(self, target_pH: float) -> Tuple[np.ndarray, float]:
194 """
195 Compute log weights for reweighting to target pH.
197 Uses MBAR formula:
198 w_n ∝ exp(-u_target(x_n)) / Σ_l N_l exp(f_l - u_l(x_n))
200 Returns
201 -------
202 log_weights : np.ndarray
203 Log weights for all samples (flattened)
204 log_norm : float
205 Log of the normalization constant
206 """
207 if self.f is None:
208 raise RuntimeError("Must call solve() before computing weights")
210 # Compute reduced potential at target pH for all samples
211 u_target = np.zeros(self.total_samples)
212 idx = 0
213 for source_i in range(self.nstates):
214 n_i = self.nsamples[source_i]
215 for local_n in range(n_i):
216 u_target[idx] = self.log10 * target_pH * self.nprotons_total[source_i][local_n]
217 idx += 1
219 # log(w_n) = -u_target(x_n) - log(c_n)
220 # where log(c_n) was precomputed in solve()
221 log_weights = -u_target - self.log_c
223 # Normalize
224 log_norm = logsumexp(log_weights)
226 return log_weights, log_norm
228 def compute_expectation_at_pH(
229 self,
230 observable_by_state: List[np.ndarray],
231 target_pH: float
232 ) -> float:
233 """
234 Compute expectation value of observable at arbitrary pH.
236 Parameters
237 ----------
238 observable_by_state : List[np.ndarray]
239 Observable values for each sample, organized by state index
240 target_pH : float
241 pH value at which to compute the expectation
243 Returns
244 -------
245 expectation : float
246 Reweighted expectation value at target_pH
247 """
248 log_weights, log_norm = self.compute_log_weights(target_pH)
250 # Flatten observable to match log_weights ordering
251 obs_flat = np.concatenate(observable_by_state)
253 # Compute weighted sum
254 # <A> = Σ_n A_n * w_n / Σ_n w_n
255 # = Σ_n A_n * exp(log_w_n - log_norm)
256 weights = np.exp(log_weights - log_norm)
258 return np.sum(obs_flat * weights)
260 def get_occupancy_for_resid(self, resid: str) -> List[np.ndarray]:
261 """Get occupancy arrays for a specific residue across all pH values."""
262 return self.states[resid]
265class TitrationCurve:
266 """
267 Analyze constant pH simulations with multiple fitting methods.
269 Available methods:
270 - curvefit: Simple least squares fit of Hill equation to per-pH averages
271 - weighted: Weighted least squares (weight by 1/variance)
272 - bootstrap: Curve fitting with bootstrap confidence intervals
274 Note: For independent constant pH simulations (not replica exchange),
275 simple curve fitting is the statistically correct approach. UWHAM/MBAR
276 is only appropriate for replica exchange constant pH where samples
277 are correlated across pH values.
278 """
280 def __init__(
281 self,
282 log_file: Path | List[Path],
283 make_plots: bool = True,
284 out: Path = Path('.'),
285 method: str = 'uwham' # 'curvefit' or 'uwham'
286 ):
287 if isinstance(log_file, list):
288 dfs = []
289 resids = None
290 for log in log_file:
291 df, r = self.parse_log(log)
292 dfs.append(df)
293 if resids is None:
294 resids = r
295 self.df = pl.concat(dfs, how='vertical')
296 else:
297 self.df, resids = self.parse_log(log_file)
299 # Store residue IDs (converted to strings to match column names)
300 self.resid_cols = [str(r) for r in resids]
302 self.make_plots = make_plots
303 self.out = out
304 self.method = method
306 @staticmethod
307 def parse_log(log: Path) -> Tuple[pl.DataFrame, List[int]]:
308 """Parse OpenMM constant pH log file.
310 Returns
311 -------
312 df : pl.DataFrame
313 DataFrame with columns: rankid, current_pH, and one column per residue
314 resids : List[int]
315 List of residue IDs in order
316 """
317 lines = log.read_text().splitlines()
319 resids = None
320 # Header format: "cpH: resids 20 76 83 92 ..."
321 header_re = re.compile(r'cpH:\s+resids\s+(.+)$')
323 # Find header with residue IDs
324 for line in lines:
325 m = header_re.search(line)
326 if m:
327 # Residue IDs are separated by whitespace (possibly multiple spaces)
328 resids = [int(x) for x in m.group(1).split()]
329 break
331 if resids is None:
332 raise RuntimeError(
333 'Could not find cpH residue ID header line in log. '
334 'Expected line containing "cpH: resids ..."'
335 )
337 # Parse state lines
338 state_re = re.compile(
339 r'rank=(\d+).*cpH:\s+pH\s+([0-9.]+):\s+(\[.*\])'
340 )
342 rows = []
343 for line in lines:
344 m = state_re.search(line)
345 if not m:
346 continue
348 rank = int(m.group(1))
349 current_pH = float(m.group(2))
350 states_list = ast.literal_eval(m.group(3))
352 if len(states_list) != len(resids):
353 raise ValueError(
354 f'Mismatch between number of residues ({len(resids)}) '
355 f'and number of states ({len(states_list)})'
356 )
358 # Build row dictionary
359 row = {
360 'rankid': rank,
361 'current_pH': current_pH,
362 }
363 row.update({
364 str(resid): state
365 for resid, state in zip(resids, states_list)
366 })
367 rows.append(row)
369 return pl.DataFrame(rows), resids
    def prepare(self) -> None:
        """Prepare parsed data for analysis.

        Builds three attributes:
        - self.df_long: long-format (resid, state) table melted from self.df
        - self.resid_to_resname: residue ID -> canonical residue name
        - self.titrations: per-(resid, pH) fraction protonated and counts
        """
        # Melt the wide per-residue columns into long format for fitting.
        self.df_long = self.df.unpivot(
            index=['rankid', 'current_pH'],
            on=self.resid_cols,
            variable_name='resid',
            value_name='state',
        )

        # Determine canonical resname for each residue ID by looking at the
        # first observed state (e.g. 'ASH' -> 'ASP').
        self.resid_to_resname = {}
        for resid_col in self.resid_cols:
            # First non-null state value recorded for this residue.
            first_state = self.df[resid_col].drop_nulls().head(1).to_list()
            if first_state:
                state = first_state[0]
                # Fall back to the raw state label if it is not in the map.
                self.resid_to_resname[resid_col] = self.canonical_resname.get(state, state)
            else:
                self.resid_to_resname[resid_col] = 'UNK'

        # Map state labels to protonation counts (1 = protonated, 0 = not);
        # unmapped states become null and are dropped from the analysis.
        self.df_long = self.df_long.with_columns(
            pl.col('state').map_elements(
                lambda x: self.protonation_mapping.get(x),
                return_dtype=pl.Int64
            ).alias('prot')
        ).drop_nulls('prot')

        # Per-(residue, pH) mean protonation and sample counts used by the
        # curve-fitting methods.
        self.titrations = (
            self.df_long.group_by(['resid', 'current_pH'])
            .agg([
                pl.col('prot').mean().alias('fraction_protonated'),
                pl.col('prot').count().alias('n_samples')
            ])
            .sort(['resid', 'current_pH'])
        )
411 def compute_titrations_curvefit(self) -> pl.DataFrame:
412 """
413 Compute pKa and Hill coefficient using scipy curve_fit.
415 This is the simple approach that treats each pH independently.
416 """
417 fit_rows = []
419 for resid, subdf in self.titrations.group_by('resid', maintain_order=True):
420 resid = resid[0] # Unpack tuple
421 resname = self.resid_to_resname.get(resid, 'UNK')
422 x = subdf['current_pH'].to_numpy().astype(float)
423 y = subdf['fraction_protonated'].to_numpy().astype(float)
425 if x.size < 3:
426 # Not enough data points
427 fit_rows.append({
428 'resid': resid,
429 'resname': resname,
430 'pKa': np.nan,
431 'Hill_n': np.nan,
432 'pKa_err': np.nan,
433 'Hill_n_err': np.nan,
434 'n_points': int(x.size),
435 'method': 'curvefit'
436 })
437 continue
439 # Initial guess: pKa where fraction ~ 0.5
440 idx_mid = np.argmin(np.abs(y - 0.5))
441 pKa0 = x[idx_mid]
442 n0 = 1.0
444 try:
445 popt, pcov = curve_fit(
446 self.hill_equation,
447 x, y,
448 p0=[pKa0, n0],
449 bounds=([0., 0.1], [14., 10.]),
450 maxfev=5000
451 )
452 pKa, n = popt
453 pKa_err = np.sqrt(np.diag(pcov))[0] if pcov is not None else np.nan
454 n_err = np.sqrt(np.diag(pcov))[1] if pcov is not None else np.nan
455 except Exception as e:
456 pKa, n = np.nan, np.nan
457 pKa_err, n_err = np.nan, np.nan
459 fit_rows.append({
460 'resid': resid,
461 'resname': resname,
462 'pKa': float(pKa),
463 'Hill_n': float(n),
464 'pKa_err': float(pKa_err),
465 'Hill_n_err': float(n_err),
466 'n_points': int(x.size),
467 'method': 'curvefit'
468 })
470 return pl.DataFrame(fit_rows)
472 def compute_titrations_weighted(self, verbose: bool = False) -> pl.DataFrame:
473 """
474 Compute pKa and Hill coefficient using weighted least squares.
476 Weights each pH point by 1/variance, giving more influence to
477 points with more samples and intermediate protonation fractions.
479 This is more statistically rigorous than unweighted curve fitting
480 when sample sizes vary across pH values.
481 """
482 fit_rows = []
484 for resid, subdf in self.titrations.group_by('resid', maintain_order=True):
485 resid = resid[0]
486 resname = self.resid_to_resname.get(resid, 'UNK')
487 x = subdf['current_pH'].to_numpy().astype(float)
488 y = subdf['fraction_protonated'].to_numpy().astype(float)
489 n = subdf['n_samples'].to_numpy().astype(float)
491 if x.size < 3:
492 fit_rows.append({
493 'resid': resid,
494 'resname': resname,
495 'pKa': np.nan,
496 'Hill_n': np.nan,
497 'pKa_err': np.nan,
498 'Hill_n_err': np.nan,
499 'n_points': int(x.size),
500 'method': 'weighted'
501 })
502 continue
504 # Compute weights: 1/variance for binomial
505 # Var(p) = p(1-p)/n, but avoid division by zero
506 # Add small epsilon to avoid infinite weights at p=0 or p=1
507 eps = 0.01
508 y_clipped = np.clip(y, eps, 1 - eps)
509 variance = y_clipped * (1 - y_clipped) / n
510 weights = 1.0 / variance
511 # Normalize weights
512 weights = weights / weights.sum()
514 # Initial guess
515 idx_mid = np.argmin(np.abs(y - 0.5))
516 pKa0 = x[idx_mid]
517 n0 = 1.0
519 try:
520 # Weighted curve fit using sigma = 1/sqrt(weight)
521 sigma = 1.0 / np.sqrt(weights * len(weights))
522 popt, pcov = curve_fit(
523 self.hill_equation,
524 x, y,
525 p0=[pKa0, n0],
526 sigma=sigma,
527 absolute_sigma=False,
528 bounds=([0., 0.1], [14., 10.]),
529 maxfev=5000
530 )
531 pKa, hill_n = popt
532 pKa_err = np.sqrt(np.diag(pcov))[0] if pcov is not None else np.nan
533 n_err = np.sqrt(np.diag(pcov))[1] if pcov is not None else np.nan
534 except Exception:
535 pKa, hill_n = np.nan, np.nan
536 pKa_err, n_err = np.nan, np.nan
538 fit_rows.append({
539 'resid': resid,
540 'resname': resname,
541 'pKa': float(pKa),
542 'Hill_n': float(hill_n),
543 'pKa_err': float(pKa_err),
544 'Hill_n_err': float(n_err),
545 'n_points': int(x.size),
546 'method': 'weighted'
547 })
549 return pl.DataFrame(fit_rows)
    def compute_titrations_bootstrap(
        self,
        n_bootstrap: int = 1000,
        verbose: bool = False
    ) -> pl.DataFrame:
        """
        Compute pKa and Hill coefficient with bootstrap confidence intervals.

        Parametric bootstrap: the protonated counts at each pH are redrawn
        from Binomial(n, p) and the Hill equation refit, giving robust 95%
        percentile intervals even when the model fits imperfectly.

        Uses numpy's global RNG (np.random); seed externally for
        reproducibility.

        Parameters
        ----------
        n_bootstrap : int
            Number of bootstrap iterations (default 1000)
        verbose : bool
            Print progress

        Returns
        -------
        DataFrame with pKa, Hill_n, and 95% confidence intervals
        """
        fit_rows = []

        if verbose:
            print(f"Running bootstrap with {n_bootstrap} iterations...")

        for i, (resid, subdf) in enumerate(self.titrations.group_by('resid', maintain_order=True)):
            resid = resid[0]  # group_by yields the key as a 1-tuple
            resname = self.resid_to_resname.get(resid, 'UNK')
            x = subdf['current_pH'].to_numpy().astype(float)
            y = subdf['fraction_protonated'].to_numpy().astype(float)
            n_samples = subdf['n_samples'].to_numpy().astype(int)

            if verbose and (i + 1) % 20 == 0:
                print(f" {i + 1}/{len(self.resid_cols)} residues...")

            if x.size < 3:
                # Too few pH points to constrain the fit; record NaNs.
                fit_rows.append({
                    'resid': resid,
                    'resname': resname,
                    'pKa': np.nan,
                    'pKa_lo': np.nan,
                    'pKa_hi': np.nan,
                    'Hill_n': np.nan,
                    'Hill_n_lo': np.nan,
                    'Hill_n_hi': np.nan,
                    'n_points': int(x.size),
                    'method': 'bootstrap'
                })
                continue

            # Point estimate from the unresampled data; initial pKa guess is
            # the pH closest to half-protonation.
            idx_mid = np.argmin(np.abs(y - 0.5))
            pKa0 = x[idx_mid]

            try:
                popt, _ = curve_fit(
                    self.hill_equation,
                    x, y,
                    p0=[pKa0, 1.0],
                    bounds=([0., 0.1], [14., 10.]),
                    maxfev=5000
                )
                pKa_point, hill_n_point = popt
            except Exception:
                pKa_point, hill_n_point = np.nan, np.nan

            # Bootstrap resampling
            pKa_boots = []
            hill_n_boots = []

            for _ in range(n_bootstrap):
                # Resample: for each pH, draw n_samples from Binomial(n, p).
                y_boot = np.zeros(len(x))
                for j in range(len(x)):
                    # Number of protonated samples in this replicate.
                    n_prot = np.random.binomial(n_samples[j], y[j])
                    y_boot[j] = n_prot / n_samples[j]

                try:
                    popt_boot, _ = curve_fit(
                        self.hill_equation,
                        x, y_boot,
                        p0=[pKa0, 1.0],
                        bounds=([0., 0.1], [14., 10.]),
                        maxfev=2000
                    )
                    pKa_boots.append(popt_boot[0])
                    hill_n_boots.append(popt_boot[1])
                except Exception:
                    # Failed replicate fits are simply skipped.
                    pass

            # 95% percentile intervals, if enough replicates converged.
            if len(pKa_boots) > 10:
                pKa_lo, pKa_hi = np.percentile(pKa_boots, [2.5, 97.5])
                hill_n_lo, hill_n_hi = np.percentile(hill_n_boots, [2.5, 97.5])
            else:
                pKa_lo, pKa_hi = np.nan, np.nan
                hill_n_lo, hill_n_hi = np.nan, np.nan

            fit_rows.append({
                'resid': resid,
                'resname': resname,
                'pKa': float(pKa_point),
                'pKa_lo': float(pKa_lo),
                'pKa_hi': float(pKa_hi),
                'Hill_n': float(hill_n_point),
                'Hill_n_lo': float(hill_n_lo),
                'Hill_n_hi': float(hill_n_hi),
                'n_points': int(x.size),
                'method': 'bootstrap'
            })

        return pl.DataFrame(fit_rows)
668 def compute_titrations(self, verbose: bool = False, n_bootstrap: int = 1000) -> None:
669 """Compute titrations using selected method."""
670 if self.method == 'curvefit':
671 self.fits = self.compute_titrations_curvefit()
672 elif self.method == 'weighted':
673 self.fits = self.compute_titrations_weighted(verbose=verbose)
674 elif self.method == 'bootstrap':
675 self.fits = self.compute_titrations_bootstrap(n_bootstrap=n_bootstrap, verbose=verbose)
676 else:
677 raise ValueError(f"Unknown method: {self.method}. Use 'curvefit', 'weighted', or 'bootstrap'")
679 def postprocess(self) -> None:
680 """Generate fitted curves for plotting."""
681 if self.fits is None:
682 raise RuntimeError("Must call compute_titrations() first")
684 pH_grid = np.linspace(
685 float(self.df['current_pH'].min()),
686 float(self.df['current_pH'].max()),
687 200
688 )
690 curves = []
691 for row in self.fits.iter_rows(named=True):
692 resid = row['resid']
693 pKa = row['pKa']
694 n = row['Hill_n']
696 if np.isnan(pKa) or np.isnan(n):
697 continue
699 y_fit = self.hill_equation(pH_grid, pKa, n)
700 curves.append(
701 pl.DataFrame({
702 'resid': [resid] * len(pH_grid),
703 'pH': pH_grid,
704 'fraction_protonated_fit': y_fit,
705 })
706 )
708 self.curves = pl.concat(curves) if curves else None
710 if self.make_plots:
711 self.plot()
    def plot(self) -> None:
        """Generate titration plots.

        Placeholder: currently a no-op. postprocess() calls this when
        make_plots is True.
        """
        pass
    def diagnose_residue(self, resid: str, verbose: bool = True) -> Dict:
        """
        Diagnose why a residue might have failed pKa determination.

        Parameters
        ----------
        resid : str
            Residue ID to diagnose
        verbose : bool
            Print diagnostic information

        Returns
        -------
        Dict
            Diagnostic info: per-pH titration data, sample counts, the raw
            state distribution, and the min/max protonated fraction seen.
        """
        # Per-pH fraction protonated from simple averaging.
        resid_data = self.titrations.filter(pl.col('resid') == resid)

        pH_vals = resid_data['current_pH'].to_numpy()
        frac_prot = resid_data['fraction_protonated'].to_numpy()
        n_samples = resid_data['n_samples'].to_numpy()

        # Distribution of raw state labels observed for this residue.
        resid_states = self.df_long.filter(pl.col('resid') == resid)
        state_counts = resid_states.group_by('state').agg(pl.len().alias('count'))

        resname = self.resid_to_resname.get(resid, 'UNK')

        result = {
            'resid': resid,
            'resname': resname,
            'pH': pH_vals,
            'fraction_protonated': frac_prot,
            'n_samples': n_samples,
            'state_distribution': state_counts.to_dict(),
            'frac_min': frac_prot.min() if len(frac_prot) > 0 else np.nan,
            'frac_max': frac_prot.max() if len(frac_prot) > 0 else np.nan,
        }

        if verbose:
            print(f"\nDiagnostics for residue {resid} ({resname}):")
            print(f" State distribution:")
            for row in state_counts.iter_rows(named=True):
                print(f" {row['state']}: {row['count']}")
            print(f"\n Titration curve (simple average):")
            print(f" {'pH':>6s} {'frac':>6s} {'n':>5s}")
            for pH, f, n in zip(pH_vals, frac_prot, n_samples):
                print(f" {pH:6.2f} {f:6.3f} {n:5d}")
            print(f"\n Fraction range: {result['frac_min']:.3f} - {result['frac_max']:.3f}")

            # Heuristic hints about why the Hill fit may have failed.
            if result['frac_min'] > 0.5:
                print(f" → Always >50% protonated - pKa likely ABOVE pH {pH_vals.max():.1f}")
            elif result['frac_max'] < 0.5:
                print(f" → Always <50% protonated - pKa likely BELOW pH {pH_vals.min():.1f}")
            elif result['frac_max'] - result['frac_min'] < 0.1:
                print(f" → Very little titration observed - may not titrate in this pH range")

        return result
776 @staticmethod
777 def hill_equation(pH: float, pKa: float, n: float) -> float:
778 """
779 Hill equation for acid-base equilibrium.
781 Returns fraction protonated as function of pH.
782 """
783 return 1.0 / (1.0 + 10.0**(n * (pH - pKa)))
785 @property
786 def protonation_mapping(self) -> Dict[str, int]:
787 """Map state names to protonation numbers (1 = protonated, 0 = not)."""
788 return {
789 'ASH': 1, 'ASP': 0,
790 'GLH': 1, 'GLU': 0,
791 'LYS': 1, 'LYN': 0,
792 'CYS': 1, 'CYX': 0,
793 'HIP': 1, 'HIE': 0, 'HID': 0,
794 }
796 @property
797 def canonical_resname(self) -> Dict[str, str]:
798 """Map any state name to canonical residue name."""
799 return {
800 'ASH': 'ASP', 'ASP': 'ASP',
801 'GLH': 'GLU', 'GLU': 'GLU',
802 'LYS': 'LYS', 'LYN': 'LYS',
803 'CYS': 'CYS', 'CYX': 'CYS',
804 'HIP': 'HIS', 'HIE': 'HIS', 'HID': 'HIS',
805 }
807 def compare_methods(self, resids: Optional[List[str]] = None) -> pl.DataFrame:
808 """
809 Compare curve fit vs UWHAM results for specified residues.
811 Parameters
812 ----------
813 resids : List[str], optional
814 Residues to compare. If None, compares all.
816 Returns
817 -------
818 DataFrame with both methods' results side by side
819 """
820 # Run both methods
821 fits_cf = self.compute_titrations_curvefit()
822 fits_uw = self.compute_titrations_uwham(verbose=False)
824 # Join on resid
825 comparison = fits_cf.join(
826 fits_uw.select(['resid', 'pKa', 'Hill_n', 'status']),
827 on='resid',
828 suffix='_uwham'
829 )
831 # Add difference columns
832 comparison = comparison.with_columns([
833 (pl.col('pKa') - pl.col('pKa_uwham')).alias('pKa_diff'),
834 (pl.col('Hill_n') - pl.col('Hill_n_uwham')).alias('Hill_n_diff'),
835 ])
837 if resids is not None:
838 comparison = comparison.filter(pl.col('resid').is_in(resids))
840 return comparison
843class TitrationAnalyzer:
844 """
845 High-level analyzer for constant pH simulations.
847 Provides a streamlined API that runs both curve fitting and UWHAM analysis,
848 generates comparisons, and creates publication-quality plots.
850 Example usage
851 -------------
852 >>> analyzer = TitrationAnalyzer(['cpH.log'])
853 >>> analyzer.run()
854 >>> analyzer.summary()
855 >>> analyzer.plot_residue('145')
856 >>> analyzer.plot_all(output_dir='plots/')
857 >>> analyzer.save_results('results/')
858 """
860 def __init__(
861 self,
862 log_files: Path | List[Path] | str | List[str],
863 output_dir: Optional[Path | str] = None,
864 ):
865 """
866 Initialize the analyzer.
868 Parameters
869 ----------
870 log_files : Path, str, or list thereof
871 Path(s) to constant pH log file(s)
872 output_dir : Path or str, optional
873 Directory for output files. If None, uses current directory.
874 """
875 if isinstance(log_files, (str, Path)):
876 log_files = [log_files]
877 self.log_files = [Path(f) for f in log_files]
879 self.output_dir = Path(output_dir) if output_dir else Path('.')
880 self.output_dir.mkdir(parents=True, exist_ok=True)
882 # Results storage
883 self.fits_curvefit: Optional[pl.DataFrame] = None
884 self.fits_weighted: Optional[pl.DataFrame] = None
885 self.fits_bootstrap: Optional[pl.DataFrame] = None
886 self.comparison: Optional[pl.DataFrame] = None
887 self.titration_data: Optional[pl.DataFrame] = None
889 # Internal objects
890 self._tc: Optional[TitrationCurve] = None
892 # Metadata
893 self.resid_to_resname: Dict[str, str] = {}
894 self.resid_cols: List[str] = []
896 self._analyzed = False
898 def run(
899 self,
900 methods: List[str] = ['curvefit', 'weighted'],
901 verbose: bool = True,
902 n_bootstrap: int = 1000,
903 ) -> 'TitrationAnalyzer':
904 """
905 Run the analysis with specified methods.
907 Parameters
908 ----------
909 methods : list of str
910 Methods to run: 'curvefit', 'weighted', 'bootstrap'
911 - curvefit: Simple least squares fit of Hill equation
912 - weighted: Weighted least squares (by 1/variance)
913 - bootstrap: Curve fit with bootstrap confidence intervals
914 verbose : bool
915 Print progress information
916 n_bootstrap : int
917 Number of bootstrap iterations (only used if 'bootstrap' in methods)
919 Returns
920 -------
921 self : for method chaining
922 """
923 if verbose:
924 print("=" * 60)
925 print("Constant pH Titration Analysis")
926 print("=" * 60)
927 print(f"Log files: {[str(f) for f in self.log_files]}")
929 # Initialize and prepare
930 self._tc = TitrationCurve(self.log_files, make_plots=False)
931 self._tc.prepare()
933 # Store data for plotting
934 self.titration_data = self._tc.titrations.clone()
935 self.resid_to_resname = self._tc.resid_to_resname.copy()
936 self.resid_cols = self._tc.resid_cols.copy()
938 if verbose:
939 n_residues = len(self._tc.resid_cols)
940 pH_vals = self._tc.df['current_pH'].unique().sort()
941 print(f"Residues: {n_residues}")
942 print(f"pH values: {pH_vals.to_list()}")
943 print(f"Total samples: {len(self._tc.df)}")
945 # Curve fitting
946 if 'curvefit' in methods:
947 if verbose:
948 print("\n" + "-" * 40)
949 print("Running curve fitting...")
950 self.fits_curvefit = self._tc.compute_titrations_curvefit()
951 if verbose:
952 n_success = self.fits_curvefit.filter(pl.col('pKa').is_not_nan()).height
953 print(f" Success: {n_success}/{len(self.fits_curvefit)} residues")
955 # Weighted fitting
956 if 'weighted' in methods:
957 if verbose:
958 print("\n" + "-" * 40)
959 print("Running weighted curve fitting...")
960 self.fits_weighted = self._tc.compute_titrations_weighted(verbose=verbose)
961 if verbose:
962 n_success = self.fits_weighted.filter(pl.col('pKa').is_not_nan()).height
963 print(f" Success: {n_success}/{len(self.fits_weighted)} residues")
965 # Bootstrap
966 if 'bootstrap' in methods:
967 if verbose:
968 print("\n" + "-" * 40)
969 print(f"Running bootstrap ({n_bootstrap} iterations)...")
970 self.fits_bootstrap = self._tc.compute_titrations_bootstrap(
971 n_bootstrap=n_bootstrap, verbose=verbose
972 )
973 if verbose:
974 n_success = self.fits_bootstrap.filter(pl.col('pKa').is_not_nan()).height
975 print(f" Success: {n_success}/{len(self.fits_bootstrap)} residues")
977 # Generate comparison if multiple methods ran
978 if self.fits_curvefit is not None and self.fits_weighted is not None:
979 self._generate_comparison()
981 self._analyzed = True
983 if verbose:
984 print("\n" + "=" * 60)
985 print("Analysis complete!")
986 print("=" * 60)
988 return self
990 def _generate_comparison(self) -> None:
991 """Generate comparison DataFrame between curvefit and weighted methods."""
992 self.comparison = self.fits_curvefit.join(
993 self.fits_weighted.select(['resid', 'pKa', 'Hill_n']),
994 on='resid',
995 suffix='_weighted'
996 ).with_columns([
997 (pl.col('pKa') - pl.col('pKa_weighted')).alias('pKa_diff'),
998 (pl.col('Hill_n') - pl.col('Hill_n_weighted')).alias('Hill_n_diff'),
999 ])
    def summary(self, show_all: bool = False) -> Optional[pl.DataFrame]:
        """
        Print and return summary of results.

        Prefers the comparison table when both curvefit and weighted ran;
        otherwise falls back to whichever single method's results exist.

        Parameters
        ----------
        show_all : bool
            If True, show all residues. Otherwise show first 20.

        Returns
        -------
        Optional[pl.DataFrame]
            Comparison (or single-method) results; None if no method ran.

        Raises
        ------
        RuntimeError
            If run() has not been called yet.
        """
        if not self._analyzed:
            raise RuntimeError("Must call run() before summary()")

        if self.comparison is not None:
            # Residues for which both fits produced a finite pKa.
            successful = self.comparison.filter(
                pl.col('pKa').is_not_nan() & pl.col('pKa_weighted').is_not_nan()
            )

            print(f"\nComparison Summary ({len(successful)} residues with both methods successful):")
            print("-" * 60)

            if len(successful) > 0:
                delta = successful['pKa_diff'].to_numpy()
                print(f"ΔpKa (curvefit - weighted):")
                print(f" Mean: {np.mean(delta):+.3f}")
                print(f" Std: {np.std(delta):.3f}")
                print(f" Median: {np.median(delta):+.3f}")
                print(f" Range: [{np.min(delta):.3f}, {np.max(delta):.3f}]")

            display_df = successful.select([
                'resid', 'resname', 'pKa', 'pKa_weighted', 'pKa_diff',
                'Hill_n', 'Hill_n_weighted'
            ])

            if not show_all and len(display_df) > 20:
                print(f"\nShowing first 20 of {len(display_df)} residues (use show_all=True for all):")
                print(display_df.head(20))
            else:
                print(display_df)

            return self.comparison

        elif self.fits_curvefit is not None:
            print("\nCurve Fitting Results:")
            print(self.fits_curvefit if show_all else self.fits_curvefit.head(20))
            return self.fits_curvefit

        elif self.fits_weighted is not None:
            print("\nWeighted Fitting Results:")
            print(self.fits_weighted if show_all else self.fits_weighted.head(20))
            return self.fits_weighted

        elif self.fits_bootstrap is not None:
            print("\nBootstrap Results:")
            print(self.fits_bootstrap if show_all else self.fits_bootstrap.head(20))
            return self.fits_bootstrap

        return None
1063 def get_results(self, method: str = 'curvefit') -> pl.DataFrame:
1064 """
1065 Get results DataFrame for specified method.
1067 Parameters
1068 ----------
1069 method : str
1070 'curvefit', 'weighted', 'bootstrap', or 'comparison'
1071 """
1072 if method == 'curvefit':
1073 return self.fits_curvefit
1074 elif method == 'weighted':
1075 return self.fits_weighted
1076 elif method == 'bootstrap':
1077 return self.fits_bootstrap
1078 elif method == 'comparison':
1079 return self.comparison
1080 else:
1081 raise ValueError(f"Unknown method: {method}")
1083 def plot_residue(
1084 self,
1085 resid: str,
1086 ax: Optional['plt.Axes'] = None,
1087 show_curvefit: bool = True,
1088 show_weighted: bool = True,
1089 show_data: bool = True,
1090 figsize: Tuple[float, float] = (8, 6),
1091 save: Optional[str | Path] = None,
1092 ) -> 'plt.Figure':
1093 """
1094 Plot titration curve for a single residue.
1096 Parameters
1097 ----------
1098 resid : str
1099 Residue ID to plot
1100 ax : matplotlib Axes, optional
1101 Axes to plot on. If None, creates new figure.
1102 show_curvefit : bool
1103 Show curve fitting result
1104 show_weighted : bool
1105 Show weighted fit result
1106 show_data : bool
1107 Show raw data points
1108 figsize : tuple
1109 Figure size if creating new figure
1110 save : str or Path, optional
1111 Path to save figure
1113 Returns
1114 -------
1115 matplotlib Figure
1116 """
1117 try:
1118 import matplotlib.pyplot as plt
1119 except ImportError:
1120 raise ImportError("matplotlib required for plotting: pip install matplotlib")
1122 if not self._analyzed:
1123 raise RuntimeError("Must call run() before plotting")
1125 if ax is None:
1126 fig, ax = plt.subplots(figsize=figsize)
1127 else:
1128 fig = ax.get_figure()
1130 resname = self.resid_to_resname.get(resid, 'UNK')
1132 # Raw data
1133 resid_data = self.titration_data.filter(pl.col('resid') == resid)
1134 pH_data = resid_data['current_pH'].to_numpy()
1135 frac_data = resid_data['fraction_protonated'].to_numpy()
1136 n_samples = resid_data['n_samples'].to_numpy()
1138 # Standard error for binomial
1139 se = np.sqrt(frac_data * (1 - frac_data) / np.maximum(n_samples, 1))
1141 # Plot data points
1142 if show_data:
1143 ax.errorbar(
1144 pH_data, frac_data, yerr=se,
1145 fmt='o', color='black', markersize=8,
1146 capsize=3, capthick=1, elinewidth=1,
1147 label='Data', zorder=10
1148 )
1150 # pH grid for curves
1151 pH_grid = np.linspace(
1152 min(pH_data) - 0.5,
1153 max(pH_data) + 0.5,
1154 200
1155 )
1157 # Curve fit line (unweighted)
1158 if show_curvefit and self.fits_curvefit is not None:
1159 cf_row = self.fits_curvefit.filter(pl.col('resid') == resid)
1160 if len(cf_row) > 0:
1161 pKa_cf = cf_row['pKa'][0]
1162 n_cf = cf_row['Hill_n'][0]
1163 if not np.isnan(pKa_cf) and not np.isnan(n_cf):
1164 y_cf = TitrationCurve.hill_equation(pH_grid, pKa_cf, n_cf)
1165 ax.plot(
1166 pH_grid, y_cf, '-', color='blue', linewidth=2,
1167 label=f'Curve fit (pKa={pKa_cf:.2f}, n={n_cf:.2f})'
1168 )
1169 ax.axvline(pKa_cf, color='blue', linestyle=':', alpha=0.5)
1171 # Weighted fit line
1172 if show_weighted and self.fits_weighted is not None:
1173 wt_row = self.fits_weighted.filter(pl.col('resid') == resid)
1174 if len(wt_row) > 0:
1175 pKa_wt = wt_row['pKa'][0]
1176 n_wt = wt_row['Hill_n'][0]
1177 if not np.isnan(pKa_wt) and not np.isnan(n_wt):
1178 y_wt = TitrationCurve.hill_equation(pH_grid, pKa_wt, n_wt)
1179 ax.plot(
1180 pH_grid, y_wt, '--', color='red', linewidth=2,
1181 label=f'Weighted (pKa={pKa_wt:.2f}, n={n_wt:.2f})'
1182 )
1183 ax.axvline(pKa_wt, color='red', linestyle=':', alpha=0.5)
1185 # Formatting
1186 ax.set_xlabel('pH', fontsize=12)
1187 ax.set_ylabel('Fraction Protonated', fontsize=12)
1188 ax.set_title(f'Residue {resid} ({resname})', fontsize=14)
1189 ax.set_ylim(-0.05, 1.05)
1190 ax.axhline(0.5, color='gray', linestyle='--', alpha=0.3)
1191 ax.legend(loc='best', fontsize=10)
1192 ax.grid(True, alpha=0.3)
1194 plt.tight_layout()
1196 if save:
1197 fig.savefig(save, dpi=150, bbox_inches='tight')
1199 return fig
1201 def plot_all(
1202 self,
1203 output_dir: Optional[str | Path] = None,
1204 format: str = 'png',
1205 show_curvefit: bool = True,
1206 show_weighted: bool = True,
1207 residues: Optional[List[str]] = None,
1208 verbose: bool = True,
1209 ) -> None:
1210 """
1211 Generate plots for all (or selected) residues.
1213 Parameters
1214 ----------
1215 output_dir : str or Path, optional
1216 Directory for plots. Uses self.output_dir / 'plots' if None.
1217 format : str
1218 Image format ('png', 'pdf', 'svg')
1219 show_curvefit : bool
1220 Include curve fitting results
1221 show_weighted : bool
1222 Include weighted fit results
1223 residues : list of str, optional
1224 Specific residues to plot. If None, plots all.
1225 verbose : bool
1226 Print progress
1227 """
1228 try:
1229 import matplotlib.pyplot as plt
1230 except ImportError:
1231 raise ImportError("matplotlib required for plotting")
1233 if not self._analyzed:
1234 raise RuntimeError("Must call run() before plotting")
1236 plot_dir = Path(output_dir) if output_dir else self.output_dir / 'plots'
1237 plot_dir.mkdir(parents=True, exist_ok=True)
1239 if residues is None:
1240 residues = self.resid_cols
1242 if verbose:
1243 print(f"Generating {len(residues)} plots in {plot_dir}/")
1245 for i, resid in enumerate(residues):
1246 resname = self.resid_to_resname.get(resid, 'UNK')
1247 filename = plot_dir / f"{resname}_{resid}.{format}"
1249 fig = self.plot_residue(
1250 resid,
1251 show_curvefit=show_curvefit,
1252 show_weighted=show_weighted,
1253 save=filename
1254 )
1255 plt.close(fig)
1257 if verbose and (i + 1) % 20 == 0:
1258 print(f" {i + 1}/{len(residues)} plots generated...")
1260 if verbose:
1261 print(f" All {len(residues)} plots saved to {plot_dir}/")
    def plot_summary(
        self,
        figsize: Tuple[float, float] = (12, 5),
        save: Optional[str | Path] = None,
    ) -> 'plt.Figure':
        """
        Generate summary plot comparing methods.

        Creates a 2-panel figure:
        - Left: pKa comparison scatter plot
        - Right: Distribution of pKa differences

        Parameters
        ----------
        figsize : tuple
            Overall figure size (width, height) in inches
        save : str or Path, optional
            Path to save the figure to

        Returns
        -------
        matplotlib Figure

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        RuntimeError
            If the comparison table is missing (requires both methods).
        ValueError
            If no residue has a finite pKa from both methods.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            raise ImportError("matplotlib required for plotting")

        if self.comparison is None:
            raise RuntimeError("Need both curvefit and weighted methods for summary plot")

        # Restrict to residues where both fits produced a finite pKa
        successful = self.comparison.filter(
            pl.col('pKa').is_not_nan() & pl.col('pKa_weighted').is_not_nan()
        )

        if len(successful) == 0:
            raise ValueError("No residues with both methods successful")

        fig, axes = plt.subplots(1, 2, figsize=figsize)

        pKa_cf = successful['pKa'].to_numpy()
        pKa_wt = successful['pKa_weighted'].to_numpy()
        diff = successful['pKa_diff'].to_numpy()

        # Scatter plot: curvefit vs weighted pKa, with y=x reference line
        ax = axes[0]
        ax.scatter(pKa_cf, pKa_wt, alpha=0.6, edgecolor='black', linewidth=0.5)
        # Shared axis limits padded by 0.5 so the identity line spans the data
        lims = [
            min(min(pKa_cf), min(pKa_wt)) - 0.5,
            max(max(pKa_cf), max(pKa_wt)) + 0.5
        ]
        ax.plot(lims, lims, 'k--', alpha=0.5)
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        ax.set_xlabel('pKa (Curve Fit)', fontsize=12)
        ax.set_ylabel('pKa (Weighted)', fontsize=12)
        ax.set_title('Method Comparison', fontsize=14)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)

        # Annotate with the Pearson correlation between the two methods
        corr = np.corrcoef(pKa_cf, pKa_wt)[0, 1]
        ax.text(
            0.05, 0.95, f'r = {corr:.3f}',
            transform=ax.transAxes, fontsize=11,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )

        # Histogram of per-residue pKa differences (curvefit - weighted)
        ax = axes[1]
        ax.hist(diff, bins=20, edgecolor='black', alpha=0.7)
        # Red dashed line at zero (perfect agreement); blue solid at the mean
        ax.axvline(0, color='red', linestyle='--', linewidth=2)
        ax.axvline(np.mean(diff), color='blue', linestyle='-', linewidth=2,
                   label=f'Mean = {np.mean(diff):.3f}')
        ax.set_xlabel('ΔpKa (Curve Fit - Weighted)', fontsize=12)
        ax.set_ylabel('Count', fontsize=12)
        ax.set_title('pKa Difference Distribution', fontsize=14)
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

        plt.tight_layout()

        if save:
            fig.savefig(save, dpi=150, bbox_inches='tight')

        return fig
1339 def save_results(
1340 self,
1341 output_dir: Optional[str | Path] = None,
1342 prefix: str = '',
1343 formats: List[str] = ['csv'],
1344 ) -> None:
1345 """
1346 Save all results to files.
1348 Parameters
1349 ----------
1350 output_dir : str or Path, optional
1351 Output directory. Uses self.output_dir if None.
1352 prefix : str
1353 Prefix for filenames
1354 formats : list of str
1355 Output formats: 'csv', 'parquet', 'json'
1356 """
1357 out_dir = Path(output_dir) if output_dir else self.output_dir
1358 out_dir.mkdir(parents=True, exist_ok=True)
1360 prefix = f"{prefix}_" if prefix else ""
1362 def save_df(df: pl.DataFrame, name: str):
1363 for fmt in formats:
1364 filepath = out_dir / f"{prefix}{name}.{fmt}"
1365 if fmt == 'csv':
1366 df.write_csv(filepath)
1367 elif fmt == 'parquet':
1368 df.write_parquet(filepath)
1369 elif fmt == 'json':
1370 df.write_json(filepath)
1371 print(f" Saved {filepath}")
1373 print(f"Saving results to {out_dir}/")
1375 if self.fits_curvefit is not None:
1376 save_df(self.fits_curvefit, 'pKa_curvefit')
1377 if self.fits_weighted is not None:
1378 save_df(self.fits_weighted, 'pKa_weighted')
1379 if self.fits_bootstrap is not None:
1380 save_df(self.fits_bootstrap, 'pKa_bootstrap')
1381 if self.comparison is not None:
1382 save_df(self.comparison, 'pKa_comparison')
1383 if self.titration_data is not None:
1384 save_df(self.titration_data, 'titration_data')
1386 def diagnose(self, resid: str) -> Dict:
1387 """Get diagnostic information for a residue."""
1388 if self._tc is None:
1389 raise RuntimeError("Must call run() first")
1390 return self._tc.diagnose_residue(resid, verbose=True)
    def recommend_protonation(
        self,
        target_pH: float,
        confidence_threshold: float = 0.7,
        use_bootstrap: bool = False,
        verbose: bool = True,
    ) -> pl.DataFrame:
        """
        Recommend protonation states for a target pH.

        Uses the fitted titration curves to predict which residues are
        protonated vs deprotonated at the specified pH, with confidence
        estimates based on distance from pKa.

        Parameters
        ----------
        target_pH : float
            pH value to make predictions for (e.g., 3.0, 7.4)
        confidence_threshold : float
            Probability threshold for "confident" predictions (default 0.7)
            Residues with P(protonated) between (1-threshold) and threshold
            are marked as "uncertain"
        use_bootstrap : bool
            If True and bootstrap results available, use bootstrap CI for
            uncertainty estimation
        verbose : bool
            Print summary of recommendations

        Returns
        -------
        DataFrame with columns:
            - resid: residue ID
            - resname: canonical residue name (ASP, GLU, HIS, LYS, CYS)
            - pKa: fitted pKa value
            - prob_protonated: probability of being protonated at target pH
            - recommendation: 'protonated', 'deprotonated', or 'uncertain'
            - confidence: 'high', 'medium', or 'low'
            - state_name: recommended state name (e.g., 'ASH' or 'ASP')

        Notes
        -----
        NOTE(review): ``use_bootstrap`` is accepted but never read in the
        body below — bootstrap confidence intervals are not currently
        folded into the output. TODO: implement or deprecate the flag.
        """
        if not self._analyzed:
            raise RuntimeError("Must call run() before recommend_protonation()")

        # Use curvefit results (or weighted if available)
        fits = self.fits_curvefit
        if fits is None:
            fits = self.fits_weighted
        if fits is None:
            raise RuntimeError("No fitting results available")

        # State name mappings (Amber-style residue state names)
        protonated_state = {
            'ASP': 'ASH', 'GLU': 'GLH', 'HIS': 'HIP',
            'LYS': 'LYS', 'CYS': 'CYS'
        }
        deprotonated_state = {
            'ASP': 'ASP', 'GLU': 'GLU', 'HIS': 'HIE',
            'LYS': 'LYN', 'CYS': 'CYX'
        }

        # Reference pKa values for sanity checking (and as fallback when
        # a residue has no usable fit)
        reference_pKa = {
            'ASP': 3.9, 'GLU': 4.3, 'HIS': 6.0,
            'LYS': 10.5, 'CYS': 8.3
        }

        recommendations = []

        for row in fits.iter_rows(named=True):
            resid = row['resid']
            resname = row['resname']
            pKa = row['pKa']
            hill_n = row['Hill_n']

            # Compute probability of being protonated at target pH
            if np.isnan(pKa) or np.isnan(hill_n):
                # No fit available - use reference pKa with the standard
                # Henderson-Hasselbalch form (Hill coefficient of 1)
                ref_pKa = reference_pKa.get(resname, 7.0)
                prob_prot = 1.0 / (1.0 + 10**(target_pH - ref_pKa))
                pKa_used = ref_pKa
                fit_quality = 'reference'
            else:
                # Use fitted Hill equation
                prob_prot = TitrationCurve.hill_equation(target_pH, pKa, hill_n)
                pKa_used = pKa
                fit_quality = 'fitted'

            # Determine recommendation
            if prob_prot >= confidence_threshold:
                recommendation = 'protonated'
                state_name = protonated_state.get(resname, resname)
            elif prob_prot <= (1 - confidence_threshold):
                recommendation = 'deprotonated'
                state_name = deprotonated_state.get(resname, resname)
            else:
                recommendation = 'uncertain'
                # For uncertain cases, go with majority
                if prob_prot >= 0.5:
                    state_name = protonated_state.get(resname, resname)
                else:
                    state_name = deprotonated_state.get(resname, resname)

            # Confidence based on distance from 0.5
            prob_distance = abs(prob_prot - 0.5)
            if prob_distance > 0.4:  # >90% or <10%
                confidence = 'high'
            elif prob_distance > 0.2:  # >70% or <30%
                confidence = 'medium'
            else:
                confidence = 'low'

            recommendations.append({
                'resid': resid,
                'resname': resname,
                'pKa': pKa_used,
                'pKa_source': fit_quality,
                'prob_protonated': prob_prot,
                'recommendation': recommendation,
                'confidence': confidence,
                'state_name': state_name,
            })

        result = pl.DataFrame(recommendations)

        if verbose:
            print(f"\n{'='*60}")
            print(f"Protonation Recommendations at pH {target_pH}")
            print(f"{'='*60}")

            # Summary counts
            n_prot = result.filter(pl.col('recommendation') == 'protonated').height
            n_deprot = result.filter(pl.col('recommendation') == 'deprotonated').height
            n_uncertain = result.filter(pl.col('recommendation') == 'uncertain').height

            print(f"\nSummary:")
            print(f" Protonated: {n_prot:3d} residues")
            print(f" Deprotonated: {n_deprot:3d} residues")
            print(f" Uncertain: {n_uncertain:3d} residues")

            # Group by residue type
            print(f"\nBy residue type:")
            for restype in ['ASP', 'GLU', 'HIS', 'LYS', 'CYS']:
                subset = result.filter(pl.col('resname') == restype)
                if len(subset) > 0:
                    n_p = subset.filter(pl.col('recommendation') == 'protonated').height
                    n_d = subset.filter(pl.col('recommendation') == 'deprotonated').height
                    n_u = subset.filter(pl.col('recommendation') == 'uncertain').height
                    ref = reference_pKa.get(restype, '?')
                    print(f" {restype} (ref pKa={ref}): {n_p} prot, {n_d} deprot, {n_u} uncertain")

            # Show uncertain residues (most important to check)
            uncertain = result.filter(pl.col('recommendation') == 'uncertain')
            if len(uncertain) > 0:
                print(f"\n⚠️ Uncertain residues (prob between {1-confidence_threshold:.0%}-{confidence_threshold:.0%}):")
                for row in uncertain.sort('prob_protonated', descending=True).iter_rows(named=True):
                    print(f" {row['resname']} {row['resid']}: "
                          f"P(prot)={row['prob_protonated']:.1%}, "
                          f"pKa={row['pKa']:.1f} → {row['state_name']}")

            # Show residues with pKa near target pH (these flip most easily)
            near_pKa = result.filter(
                (pl.col('pKa') > target_pH - 1.5) &
                (pl.col('pKa') < target_pH + 1.5) &
                (pl.col('pKa_source') == 'fitted')
            )
            if len(near_pKa) > 0:
                print(f"\n📍 Residues with pKa near pH {target_pH} (±1.5 units):")
                for row in near_pKa.sort('pKa').iter_rows(named=True):
                    print(f" {row['resname']} {row['resid']}: "
                          f"pKa={row['pKa']:.2f}, "
                          f"P(prot)={row['prob_protonated']:.1%} → {row['state_name']}")

        return result
1565 def get_protonation_string(
1566 self,
1567 target_pH: float,
1568 confidence_threshold: float = 0.7,
1569 ) -> str:
1570 """
1571 Get a simple string of recommended protonation states.
1573 Useful for setting up simulations.
1575 Parameters
1576 ----------
1577 target_pH : float
1578 pH value to make predictions for
1579 confidence_threshold : float
1580 Probability threshold for confident predictions
1582 Returns
1583 -------
1584 String with format: "resid:state,resid:state,..."
1585 """
1586 recs = self.recommend_protonation(
1587 target_pH,
1588 confidence_threshold=confidence_threshold,
1589 verbose=False
1590 )
1592 parts = []
1593 for row in recs.iter_rows(named=True):
1594 parts.append(f"{row['resid']}:{row['state_name']}")
1596 return ','.join(parts)
1598 def export_protonation_states(
1599 self,
1600 target_pH: float,
1601 output_file: Optional[str | Path] = None,
1602 format: str = 'csv',
1603 confidence_threshold: float = 0.7,
1604 ) -> pl.DataFrame:
1605 """
1606 Export protonation state recommendations to file.
1608 Parameters
1609 ----------
1610 target_pH : float
1611 pH value to make predictions for
1612 output_file : str or Path, optional
1613 Output file path. If None, uses output_dir/protonation_pH{pH}.{format}
1614 format : str
1615 Output format: 'csv', 'json', or 'txt'
1616 confidence_threshold : float
1617 Probability threshold for confident predictions
1619 Returns
1620 -------
1621 DataFrame with recommendations
1622 """
1623 recs = self.recommend_protonation(
1624 target_pH,
1625 confidence_threshold=confidence_threshold,
1626 verbose=False
1627 )
1629 if output_file is None:
1630 output_file = self.output_dir / f"protonation_pH{target_pH:.1f}.{format}"
1631 else:
1632 output_file = Path(output_file)
1634 if format == 'csv':
1635 recs.write_csv(output_file)
1636 elif format == 'json':
1637 recs.write_json(output_file)
1638 elif format == 'txt':
1639 # Simple text format for easy reading
1640 with open(output_file, 'w') as f:
1641 f.write(f"# Protonation states at pH {target_pH}\n")
1642 f.write(f"# confidence_threshold = {confidence_threshold}\n")
1643 f.write("#\n")
1644 f.write("# resid resname state prob_prot confidence\n")
1645 for row in recs.iter_rows(named=True):
1646 f.write(f"{row['resid']:>6s} {row['resname']:>7s} "
1647 f"{row['state_name']:>5s} {row['prob_protonated']:>9.3f} "
1648 f"{row['confidence']}\n")
1650 print(f"Saved protonation recommendations to {output_file}")
1651 return recs
1653 def plot_protonation_summary(
1654 self,
1655 target_pH: float,
1656 figsize: Tuple[float, float] = (12, 6),
1657 save: Optional[str | Path] = None,
1658 ) -> 'plt.Figure':
1659 """
1660 Visualize protonation probabilities at target pH.
1662 Creates a bar plot showing P(protonated) for each residue,
1663 colored by residue type.
1665 Parameters
1666 ----------
1667 target_pH : float
1668 pH value to visualize
1669 figsize : tuple
1670 Figure size
1671 save : str or Path, optional
1672 Path to save figure
1674 Returns
1675 -------
1676 matplotlib Figure
1677 """
1678 try:
1679 import matplotlib.pyplot as plt
1680 from matplotlib.patches import Patch
1681 except ImportError:
1682 raise ImportError("matplotlib required for plotting")
1684 recs = self.recommend_protonation(target_pH, verbose=False)
1686 # Sort by probability
1687 recs_sorted = recs.sort('prob_protonated', descending=True)
1689 fig, ax = plt.subplots(figsize=figsize)
1691 # Colors for each residue type
1692 colors = {
1693 'ASP': '#e41a1c', # red
1694 'GLU': '#ff7f00', # orange
1695 'HIS': '#4daf4a', # green
1696 'LYS': '#377eb8', # blue
1697 'CYS': '#984ea3', # purple
1698 }
1700 x = np.arange(len(recs_sorted))
1701 probs = recs_sorted['prob_protonated'].to_numpy()
1702 resnames = recs_sorted['resname'].to_list()
1703 resids = recs_sorted['resid'].to_list()
1705 bar_colors = [colors.get(rn, 'gray') for rn in resnames]
1707 bars = ax.bar(x, probs, color=bar_colors, edgecolor='black', linewidth=0.5)
1709 # Add 0.5 line
1710 ax.axhline(0.5, color='black', linestyle='--', linewidth=2, alpha=0.7)
1711 ax.axhline(0.7, color='gray', linestyle=':', linewidth=1, alpha=0.5)
1712 ax.axhline(0.3, color='gray', linestyle=':', linewidth=1, alpha=0.5)
1714 # Labels
1715 ax.set_xlabel('Residue', fontsize=12)
1716 ax.set_ylabel('P(protonated)', fontsize=12)
1717 ax.set_title(f'Protonation Probabilities at pH {target_pH}', fontsize=14)
1718 ax.set_ylim(0, 1.05)
1720 # X-axis labels (show every Nth label if too many)
1721 n_residues = len(x)
1722 if n_residues > 50:
1723 # Show fewer labels
1724 step = n_residues // 20
1725 ax.set_xticks(x[::step])
1726 labels = [f"{resnames[i]}{resids[i]}" for i in range(0, n_residues, step)]
1727 ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)
1728 else:
1729 ax.set_xticks(x)
1730 labels = [f"{rn}{ri}" for rn, ri in zip(resnames, resids)]
1731 ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)
1733 # Legend
1734 legend_elements = [Patch(facecolor=c, edgecolor='black', label=n)
1735 for n, c in colors.items() if n in resnames]
1736 ax.legend(handles=legend_elements, loc='upper right', fontsize=10)
1738 # Add text annotations for counts
1739 n_prot = sum(1 for p in probs if p >= 0.5)
1740 n_deprot = sum(1 for p in probs if p < 0.5)
1741 ax.text(0.02, 0.98, f'Protonated: {n_prot}\nDeprotonated: {n_deprot}',
1742 transform=ax.transAxes, fontsize=10, verticalalignment='top',
1743 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
1745 plt.tight_layout()
1747 if save:
1748 fig.savefig(save, dpi=150, bbox_inches='tight')
1750 return fig
1752 def __repr__(self) -> str:
1753 status = "analyzed" if self._analyzed else "not analyzed"
1754 return f"TitrationAnalyzer({len(self.log_files)} log files, {status})"
def analyze_cph(
    log_files: Path | List[Path] | str | List[str],
    output_dir: Optional[str | Path] = None,
    methods: Optional[List[str]] = None,
    plot: bool = True,
    verbose: bool = True,
) -> TitrationAnalyzer:
    """
    Convenience function to run complete constant pH analysis.

    Parameters
    ----------
    log_files : path(s) to log files
    output_dir : output directory
    methods : list of methods to run ('curvefit', 'weighted', 'bootstrap').
        Defaults to ['curvefit', 'weighted'].
    plot : whether to generate plots (silently skipped without matplotlib)
    verbose : print progress

    Returns
    -------
    TitrationAnalyzer with results

    Example
    -------
    >>> results = analyze_cph('cpH.log', output_dir='analysis/')
    >>> results.summary()
    >>> results.plot_residue('145')
    """
    # Avoid the mutable-default-argument pitfall for `methods`.
    if methods is None:
        methods = ['curvefit', 'weighted']

    analyzer = TitrationAnalyzer(log_files, output_dir=output_dir)
    analyzer.run(methods=methods, verbose=verbose)

    if plot:
        try:
            analyzer.plot_all(verbose=verbose)
            analyzer.plot_summary(save=analyzer.output_dir / 'summary.png')
        except ImportError:
            if verbose:
                print("matplotlib not available, skipping plots")
        except RuntimeError:
            # plot_summary requires both methods; plotting is best-effort
            pass

    analyzer.save_results()

    return analyzer
if __name__ == '__main__':
    import sys

    # Get log files from command line or use default (cpH.log in cwd)
    log_paths = [Path('cpH.log')]
    if len(sys.argv) > 1:
        log_paths = [Path(p) for p in sys.argv[1:]]

    # =========================================================================
    # STREAMLINED API - TitrationAnalyzer
    # =========================================================================
    #
    # Available methods:
    #   - curvefit: Simple least squares fit (default)
    #   - weighted: Weighted least squares (by 1/variance)
    #   - bootstrap: Curve fit with bootstrap confidence intervals
    #
    # Basic usage:
    #   analyzer = TitrationAnalyzer(log_paths)
    #   analyzer.run()
    #   analyzer.summary()
    #
    # Protonation recommendations:
    #   recs = analyzer.recommend_protonation(target_pH=3.0)
    #   analyzer.plot_protonation_summary(target_pH=3.0)
    #
    # =========================================================================

    # Create analyzer (results land under cph_analysis/)
    analyzer = TitrationAnalyzer(log_paths, output_dir='cph_analysis')

    # Run curve fitting and weighted fitting
    analyzer.run(methods=['curvefit', 'weighted'], verbose=True)

    # Print summary
    analyzer.summary()

    # Generate all plots (if matplotlib available)
    try:
        analyzer.plot_all(verbose=True)
        analyzer.plot_summary(save='cph_analysis/summary.png')
        print("\nPlots saved to cph_analysis/plots/")
    except ImportError:
        print("\nSkipping plots (matplotlib not installed)")

    # Save results (CSV by default)
    analyzer.save_results()

    # =========================================================================
    # PROTONATION RECOMMENDATIONS
    # =========================================================================

    # Get recommendations for pH 3.0 (prints a verbose report)
    print("\n")
    recs = analyzer.recommend_protonation(target_pH=3.0)

    # Export to file
    analyzer.export_protonation_states(target_pH=3.0, format='csv')

    # Visualize (skipped quietly if matplotlib is missing)
    try:
        analyzer.plot_protonation_summary(
            target_pH=3.0,
            save='cph_analysis/protonation_pH3.0.png'
        )
    except ImportError:
        pass

    # Can also get recommendations for physiological pH
    # analyzer.recommend_protonation(target_pH=7.4)