import importlib
import json
import os
import sys
from collections import Counter

import Xponge
import Xponge.forcefield.amber as amber
# Force-field modules imported by dotted name. The resulting module objects are
# also searched by _find_type when a type name is not a registered ResidueType.
forcefields = [
  'Xponge.forcefield.amber.ff14sb',
  'Xponge.forcefield.amber.gaff',
  'Xponge.forcefield.amber.tip3p',
]
forcefield_modules = [importlib.import_module(module_name) for module_name in forcefields]
def _find_type(name):
  """Resolve a residue/ion type by name, returning None when nothing matches.

  Lookup order: ``Xponge.ResidueType.get_type(name)`` (any exception there is
  treated as "not registered"), then an attribute called ``name`` on the
  ``Xponge`` package, then an attribute on each module in ``forcefield_modules``.
  """
  try:
    return Xponge.ResidueType.get_type(name)
  except Exception:
    pass  # not a registered ResidueType; fall through to attribute lookups
  for namespace in (Xponge, *forcefield_modules):
    if hasattr(namespace, name):
      return getattr(namespace, name)
  return None
# Water residue type; it is passed to add_solvent_box and Solvent_Replace later,
# so its absence is fatal.
WAT = _find_type('WAT')
if WAT is None:
  raise RuntimeError('未找到 WAT 残基')

# Extra AMBER parameter files (frcmod) loaded without a type-name prefix.
frcmod_paths = ['/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/frcmod/interactive.frcmod']
for frcmod_file in frcmod_paths:
  amber.load_parameters_from_frcmod(frcmod_file, prefix=False)

# Non-standard residues: each mol2 file is loaded as a template, then optional
# head/tail values are copied onto the resulting residue type.
nonstandard_residues = [{'mol2_path': '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/edit_struct/CCS_3.gaff.mol2', 'resname': 'CCS', 'head': None, 'tail': None}]
for residue_spec in nonstandard_residues:
  template_file = residue_spec['mol2_path']
  requested_name = residue_spec.get('resname')
  Xponge.load_mol2(template_file, as_template=True)
  # Without an explicit residue name, use the upper-cased file stem.
  template_name = requested_name or os.path.splitext(os.path.basename(template_file))[0].upper()
  if not template_name:
    continue
  template_type = Xponge.ResidueType.get_type(template_name)
  for end_attr in ('head', 'tail'):
    if residue_spec.get(end_attr):
      setattr(template_type, end_attr, residue_spec[end_attr])
# Input structures as (path, format) pairs: 'mol2' entries are read with
# load_mol2, anything else with load_pdb.
structures = [
  ('/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/8RYK_pdbfixer_H_ed.pdb', 'pdb'),
]

# Chain-end settings per structure path. A chain end whose n_terminal /
# c_terminal flag is False is passed to load_pdb as an unterminal residue
# (by res_seq/ins_code) and to _unterminal (by residue index).
chain_terminals = {
  '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/8RYK_pdbfixer_H_ed.pdb': [
    {
      'chain_id': 'B',
      'start_index': 0,
      'end_index': 152,
      'start_res_seq': '1',
      'end_res_seq': '153',
      'start_ins_code': None,
      'end_ins_code': None,
      'n_terminal': True,
      'c_terminal': True,
    },
    {
      'chain_id': 'D',
      'start_index': 153,
      'end_index': 168,
      'start_res_seq': '1',
      'end_res_seq': '16',
      'start_ins_code': None,
      'end_ins_code': None,
      'n_terminal': False,
      'c_terminal': True,
    },
  ],
}

# Extra inter-residue bonds; each side identifies a residue (by source file
# plus residue_index, or by seq/chain/name) and an atom name inside it.
residue_links = [
  {
    'atom_a': {
      'residue_seq': 1,
      'residue_name': 'PHE',
      'atom_name': 'N',
      'chain_id': 'D',
      'source_path': '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/8RYK_pdbfixer_H_ed.pdb',
      'residue_index': 153,
    },
    'atom_b': {
      'residue_seq': 14,
      'residue_name': 'CCS',
      'atom_name': 'CE',
      'chain_id': 'D',
      'source_path': '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/8RYK_pdbfixer_H_ed.pdb',
      'residue_index': 166,
    },
  },
  {
    'atom_a': {
      'residue_seq': 13,
      'residue_name': 'ASP',
      'atom_name': 'C',
      'chain_id': 'D',
      'source_path': '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/8RYK_pdbfixer_H_ed.pdb',
      'residue_index': 165,
    },
    'atom_b': {
      'residue_seq': 14,
      'residue_name': 'CCS',
      'atom_name': 'N',
      'chain_id': 'D',
      'source_path': '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/8RYK_pdbfixer_H_ed.pdb',
      'residue_index': 166,
    },
  },
  {
    'atom_a': {
      'residue_seq': 14,
      'residue_name': 'CCS',
      'atom_name': 'C',
      'chain_id': 'D',
      'source_path': '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/8RYK_pdbfixer_H_ed.pdb',
      'residue_index': 166,
    },
    'atom_b': {
      'residue_seq': 15,
      'residue_name': 'GLY',
      'atom_name': 'N',
      'chain_id': 'D',
      'source_path': '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/8RYK_pdbfixer_H_ed.pdb',
      'residue_index': 167,
    },
  },
]

only_pdb = False     # True: write only the PDB and skip the SPONGE input files
keep_conect = True   # forwarded to load_pdb as ignore_conect=(not keep_conect)
box_padding = 0.5    # forwarded to set_box_padding; None skips that step
add_solvent = True   # run add_solvent_box with WAT
list_box = 10.0      # third positional argument of add_solvent_box
tolerance = 2.5      # fourth positional argument of add_solvent_box
n_solvent = None     # optional fifth argument of add_solvent_box; omitted when None
add_ions = False     # run the ion-replacement step
def _unterminal(mol, idx):
  """Call the ``unterminal`` hook of residue ``idx`` in ``mol``, if there is one.

  ``idx`` may be negative (counted from the end). The function silently does
  nothing when ``idx`` is None, out of range or not comparable with ints, or
  when ``mol.residues`` cannot be read. Otherwise the first callable among the
  residue's ``unterminal``, ``Unterminal`` and ``UnTerminal`` attributes is
  invoked; exceptions raised by that call propagate to the caller.
  """
  try:
    if idx is None:
      return
    residue_list = mol.residues
    count = len(residue_list)
    position = count + idx if idx < 0 else idx
    if not 0 <= position < count:
      return
    residue = residue_list[position]
  except Exception:
    return
  # Lazily probe the known spellings and stop at the first callable one.
  candidates = (getattr(residue, attr, None) for attr in ('unterminal', 'Unterminal', 'UnTerminal'))
  method = next((fn for fn in candidates if callable(fn)), None)
  if method is not None:
    method()
def _xponge_attr(snake_name, camel_name):
  """Return ``Xponge.<snake_name>``, falling back to ``Xponge.<camel_name>``."""
  return getattr(Xponge, snake_name, None) or getattr(Xponge, camel_name, None)


# Structure readers, accepting either spelling exposed by the installed Xponge.
load_pdb = _xponge_attr('load_pdb', 'Load_Pdb')
load_mol2 = _xponge_attr('load_mol2', 'Load_Mol2')
if any(reader is None for reader in (load_pdb, load_mol2)):
  raise RuntimeError('Xponge 读取函数不可用')
def _unterminal_selector(chain_id, res_seq, ins_code):
  """Build one entry of load_pdb's ``unterminal_residues`` list.

  Produces ``"<chain>:<seq>"`` (just ``"<seq>"`` when there is no chain id),
  followed by the first character of ``ins_code`` when it is non-empty.
  """
  selector = f"{chain_id}:{res_seq}" if chain_id else res_seq
  if ins_code:
    selector += ins_code[0]
  return selector


# Load every input structure into `molecules`, in the order of `structures`.
molecules = []
for path, ext in structures:
  if ext == 'mol2':
    molecules.append(load_mol2(path))
    continue
  chain_settings = chain_terminals.get(path) or []
  # Selectors for chain ends explicitly flagged as non-terminal
  # (n_terminal / c_terminal is False), handed to load_pdb.
  unterminal_residues = []
  for setting in chain_settings:
    chain_id = (setting.get('chain_id') or '').strip()[:1]
    for side, flag in (('start', 'n_terminal'), ('end', 'c_terminal')):
      if setting.get(flag) is not False:
        continue
      res_seq = str(setting.get(f'{side}_res_seq') or '').strip()
      if not res_seq:
        continue
      ins_code = str(setting.get(f'{side}_ins_code') or '').strip()
      unterminal_residues.append(_unterminal_selector(chain_id, res_seq, ins_code))
  load_kwargs = {'ignore_conect': not keep_conect, 'read_cryst1': False}
  if unterminal_residues:
    try:
      mol = load_pdb(path, unterminal_residues=unterminal_residues, **load_kwargs)
    except TypeError as exc:
      # Only a rejected `unterminal_residues` keyword means "Xponge too old";
      # any other TypeError comes from elsewhere and must not be misreported.
      if 'unterminal_residues' not in str(exc):
        raise
      raise RuntimeError('当前 Xponge 版本不支持 unterminal_residues 参数，请升级 Xponge 到最新版本') from exc
  else:
    mol = load_pdb(path, **load_kwargs)
  # Additionally apply _unterminal to the same chain ends by residue index.
  for setting in chain_settings:
    if setting.get('n_terminal') is False:
      _unterminal(mol, setting.get('start_index'))
    if setting.get('c_terminal') is False:
      _unterminal(mol, setting.get('end_index'))
  molecules.append(mol)
# Concatenate all loaded molecules (left to right) into one system.
if not molecules:
  raise RuntimeError('未加载到结构')
mol_out, *remaining = molecules
for other in remaining:
  mol_out = mol_out + other


def _residue_count(molecule):
  """Return ``len(molecule.residues)``, or 0 if that cannot be determined."""
  try:
    return len(molecule.residues)
  except Exception:
    return 0


# Inclusive residue-index range [start, end] that each input file occupies in
# mol_out; an input with no residues gets end == start - 1.
source_ranges = []
cursor = 0
for (src_path, _src_ext), molecule in zip(structures, molecules):
  n_res = _residue_count(molecule)
  source_ranges.append({'path': src_path, 'start': cursor, 'end': cursor + n_res - 1})
  cursor += n_res
if residue_links:
  # Resolve the residue list and link API first so a missing feature fails
  # before any link is added.
  residues = getattr(mol_out, 'residues', None)
  if residues is None:
    raise RuntimeError('Xponge residues 不可用')
  add_link = getattr(mol_out, 'add_residue_link', None) or getattr(mol_out, 'Add_Residue_Link', None)
  if add_link is None:
    raise RuntimeError('Xponge add_residue_link 不可用')

  def _get_res_info(res):
    """Return ``(name, chain_id, seq)`` for ``res``; any field may be None.

    The name is ``res.name`` or else ``res.type.name``. The chain id is
    ``res.chain_id`` or else ``res.chain.id`` / ``res.chain.name``; a blank id
    becomes None. ``seq`` is the first of ``residue_seq``/``resseq``/``seq``/
    ``id``/``index`` that converts to ``int``.
    """
    name = getattr(res, 'name', None)
    if not name:
      res_type = getattr(res, 'type', None)
      name = getattr(res_type, 'name', None) if res_type is not None else None
    chain_id = getattr(res, 'chain_id', None)
    if chain_id is None:
      chain_obj = getattr(res, 'chain', None)
      if chain_obj is not None:
        chain_id = getattr(chain_obj, 'id', None) or getattr(chain_obj, 'name', None)
    if chain_id is not None and str(chain_id).strip() == '':
      chain_id = None
    seq = None
    for key in ('residue_seq', 'resseq', 'seq', 'id', 'index'):
      if hasattr(res, key):
        try:
          seq = int(getattr(res, key))
          break
        except Exception:
          pass
    return name, chain_id, seq

  def _match_residue(atom):
    """Return the index in ``residues`` of the residue described by ``atom``.

    ``atom`` is one side of a ``residue_links`` entry. With both
    ``source_path`` and ``residue_index``, the index is taken relative to that
    source's range in ``source_ranges`` and used directly if it stays inside
    that range. Otherwise residues in the range are filtered by sequence
    number, chain id and residue name.

    Raises:
      RuntimeError: if no residue, or more than one, matches.
    """
    target_seq = atom.get('residue_seq')
    target_chain = (atom.get('chain_id') or '').strip()
    target_name = (atom.get('residue_name') or '').strip().upper()
    target_path = (atom.get('source_path') or '').strip()
    target_index = atom.get('residue_index')
    start = 0
    end = len(residues) - 1
    if target_path:
      for rng in source_ranges:
        if rng.get('path') == target_path:
          start = rng.get('start', 0)
          end = rng.get('end', end)
          break
    if target_index is not None and target_path:
      try:
        idx = start + int(target_index)
      except Exception:
        idx = None
      # Trust the stored index only inside this source's own range; otherwise
      # it could silently select a residue belonging to another structure.
      if idx is not None and start <= idx <= end and 0 <= idx < len(residues):
        return idx
    matches = []
    for idx in range(start, end + 1):
      if idx < 0 or idx >= len(residues):
        continue
      name, chain_id, seq = _get_res_info(residues[idx])
      if target_seq is not None and seq is None:
        continue
      if target_seq is not None and seq is not None and int(seq) != int(target_seq):
        continue
      if target_chain and chain_id is not None and str(chain_id).strip() != target_chain:
        continue
      if target_name and name and str(name).strip().upper() != target_name:
        continue
      matches.append(idx)
    if not matches:
      raise RuntimeError(f"未找到匹配残基: {atom}")
    if len(matches) > 1:
      raise RuntimeError(f"残基匹配不唯一: {atom}")
    return matches[0]

  def _get_atom(res, atom_name):
    """Return the atom named ``atom_name`` in residue ``res``.

    ``name2atom``/``Name2Atom`` is used when present, either as a callable or
    as a mapping; otherwise the atom is read as an attribute of the residue.

    Raises:
      RuntimeError: if the atom cannot be found.
    """
    lookup = getattr(res, 'name2atom', None) or getattr(res, 'Name2Atom', None)
    try:
      if callable(lookup):
        return lookup(atom_name)
      if lookup is not None and hasattr(lookup, '__getitem__'):
        return lookup[atom_name]
      return getattr(res, atom_name)
    except (KeyError, AttributeError, TypeError) as exc:
      raise RuntimeError(f"残基 {res} 中未找到原子 {atom_name}") from exc

  for link in residue_links:
    atom_a = link.get('atom_a') or {}
    atom_b = link.get('atom_b') or {}
    res_a = residues[_match_residue(atom_a)]
    res_b = residues[_match_residue(atom_b)]
    atom_a_obj = _get_atom(res_a, atom_a.get('atom_name'))
    atom_b_obj = _get_atom(res_b, atom_b.get('atom_name'))
    add_link(atom_a_obj, atom_b_obj)
# Run Xponge's add_missing_atoms on the combined molecule (required step).
add_missing_atoms = getattr(mol_out, 'add_missing_atoms', None) or getattr(mol_out, 'Add_Missing_Atoms', None)
if add_missing_atoms is None:
  raise RuntimeError('Xponge add_missing_atoms 不可用')
add_missing_atoms()

# Optionally surround the system with WAT molecules.
if add_solvent:
  add_solvent_box = getattr(Xponge, 'add_solvent_box', None) or getattr(Xponge, 'Add_Solvent_Box', None)
  if add_solvent_box is None:
    raise RuntimeError('Xponge add_solvent_box 不可用')
  # n_solvent is forwarded only when it is set; otherwise the call omits it.
  solvent_args = [mol_out, WAT, list_box, tolerance]
  if n_solvent is not None:
    solvent_args.append(n_solvent)
  add_solvent_box(*solvent_args)
if add_ions:
  # The ion settings used below are not assigned together with the other
  # settings (only_pdb ... add_ions), so verify them explicitly rather than
  # failing later with a bare NameError.
  _missing_ion_settings = [
    key
    for key in ('cation_name', 'anion_name', 'cation_charge', 'anion_charge', 'salt_concentration')
    if key not in globals()
  ]
  if _missing_ion_settings:
    raise RuntimeError('启用 add_ions 时缺少离子参数: ' + ', '.join(_missing_ion_settings))
  cation_type = _find_type(cation_name)
  anion_type = _find_type(anion_name)
  if cation_type is None or anion_type is None:
    raise RuntimeError('未找到离子残基类型')
  counts = Counter(res.type.name for res in mol_out.residues)
  num_wat = counts.get('WAT', 0)
  tot_charge = int(round(mol_out.charge))
  # Salt pairs from the concentration and water count; 0.0180687 is presumably
  # the molar volume of water in L/mol, with salt_concentration in mM — confirm.
  n_salt = float(salt_concentration) / 1000.0 * num_wat * 0.0180687
  n_cation = int(round(n_salt))
  if anion_charge == 0:
    raise RuntimeError('阴离子电荷不能为 0')
  # Enough anions to offset the system charge plus the added cations
  # (assumes anion_charge is negative and cation_charge positive).
  n_anion = int(round((tot_charge + n_cation * cation_charge) / abs(anion_charge)))
  if n_anion < 0:
    raise RuntimeError('离子浓度不足以配平总电荷')
  solvent_replace = getattr(Xponge, 'Solvent_Replace', None) or getattr(Xponge, 'solvent_replace', None)
  if solvent_replace is None:
    raise RuntimeError('Xponge Solvent_Replace 不可用')
  solvent_replace(mol_out, WAT, {cation_type: n_cation, anion_type: n_anion})
  meta = {
    'total_charge': tot_charge,
    'num_wat': num_wat,
    'n_cation': n_cation,
    'n_anion': n_anion
  }
  # Machine-readable summary line, emitted on stdout with a fixed prefix.
  print('__SPONGE_META__' + json.dumps(meta))
# Output writers: the PDB writer is always required, the SPONGE input writer
# only when only_pdb is False.
save_pdb = getattr(Xponge, 'save_pdb', None) or getattr(Xponge, 'Save_Pdb', None)
save_input = getattr(Xponge, 'Save_Sponge_Input', None) or getattr(Xponge, 'save_sponge_input', None)
writer_missing = save_pdb is None or (not only_pdb and save_input is None)
if writer_missing:
  raise RuntimeError('Xponge 保存函数不可用')

# Apply centered box padding unless box_padding is None.
if box_padding is not None:
  set_box_padding = getattr(mol_out, 'set_box_padding', None) or getattr(mol_out, 'Set_Box_Padding', None)
  if set_box_padding is None:
    raise RuntimeError('Xponge set_box_padding 不可用')
  set_box_padding(padding=box_padding, center=True)
# PDB output; the CONECT records for residue links are appended to this file.
pdb_path = '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/sponge/input.pdb'

# Detach residue links while writing the PDB (presumably so save_pdb does not
# handle them itself); they are re-emitted as CONECT records below and put back
# before the SPONGE input is written.
saved_links = None
if hasattr(mol_out, 'residue_links'):
  try:
    saved_links = list(mol_out.residue_links)
    mol_out.residue_links = []
  except Exception:
    saved_links = None
save_pdb(mol_out, pdb_path)
if saved_links:
  try:
    atom_index = getattr(mol_out, 'atom_index', {})
    # Symmetric adjacency: atom index -> set of linked atom indices.
    connect_map = {}
    for link in saved_links:
      a = getattr(link, 'atom1', None)
      b = getattr(link, 'atom2', None)
      if a is None or b is None:
        continue
      idx_a = atom_index.get(a)
      idx_b = atom_index.get(b)
      if idx_a is None or idx_b is None:
        continue
      connect_map.setdefault(idx_a, set()).add(idx_b)
      connect_map.setdefault(idx_b, set()).add(idx_a)
    if connect_map:
      with open(pdb_path, 'r', encoding='utf-8', errors='ignore') as handle:
        lines = handle.read().splitlines()
      # CONECT records go before the final END record (but not before ENDMDL).
      end_line = lines.pop() if lines and lines[-1].rstrip() == 'END' else None
      # NOTE(review): serials assume save_pdb numbers atoms as atom_index + 1
      # with no gaps (e.g. from TER records) — confirm against Xponge output.
      for idx in sorted(connect_map):
        targets = sorted(connect_map[idx])
        # A CONECT record lists at most four bonded partners; split longer lists.
        for chunk_start in range(0, len(targets), 4):
          chunk = targets[chunk_start:chunk_start + 4]
          lines.append('CONECT' + f"{idx + 1:5d}" + ''.join(f"{j + 1:5d}" for j in chunk))
      if end_line is not None:
        lines.append(end_line)
      with open(pdb_path, 'w', encoding='utf-8') as handle:
        handle.write('\n'.join(lines) + '\n')
  except Exception as exc:
    # Best effort: the PDB itself is written; report the failure instead of hiding it.
    print(f'警告: 写入 CONECT 记录失败: {exc!r}', file=sys.stderr)
if not only_pdb:
  if saved_links is not None:
    try:
      mol_out.residue_links = saved_links
    except Exception as exc:
      # Writing the SPONGE input without the links would silently drop the
      # inter-residue bonds, so fail loudly instead.
      raise RuntimeError('无法恢复 residue_links，SPONGE 输入将缺少残基间连接') from exc
  save_input(mol_out, '/media/yuh/BCDC9249DC91FDB8/Data/Mokda-FEP/8RYK/sponge/input')