import importlib
import json
import os
import platform
import sys
import Xponge
import XpongeLib
import Xponge.forcefield.amber.gaff as gaff
from itertools import combinations, permutations
from Xponge.forcefield.base import angle_base, bond_base, dihedral_base
# Job description passed in by the launching process through an environment variable.
payload = json.loads(os.environ['SPONGE_INTERACTIVE_FRCMOD_PAYLOAD'])
# Import the standard force-field modules listed in the payload (ff14sb when none are given);
# build_molecule later looks up standard residue types with Xponge.ResidueType.get_type.
standard_forcefields = payload.get('standardForcefields') or ['Xponge.forcefield.amber.ff14sb']
for module_name in standard_forcefields:
  importlib.import_module(module_name)
def type_name(atom):
  """Return the force-field type label of *atom* as a string ('' when it has none).

  ``atom.type`` may be an object carrying a ``name`` attribute or a plain value; a
  truthy ``name`` wins, then a truthy ``type`` value itself, otherwise ''.
  """
  atom_type = getattr(atom, 'type', None)
  label = getattr(atom_type, 'name', None)
  if not label:
    label = atom_type if atom_type else ''
  return str(label)
def atom_key(atoms):
  """Join the type labels of *atoms* (in order) with '-', e.g. 'CT-c3'."""
  return '-'.join(map(type_name, atoms))
def all_registered(type_cls):
  """Return the set of names registered for *type_cls*.

  Combines the keys of ``type_cls.get_all_types()`` with the keys of the optional
  ``_types_different_name`` mapping (treated as empty when the attribute is absent).
  """
  primary = type_cls.get_all_types()
  alternates = getattr(type_cls, '_types_different_name', {})
  return set(primary.keys()) | set(alternates.keys())
def term_exists(section, aliases):
  """Return True if any key in *aliases* is already registered for *section*.

  *section* selects the Xponge type class: 'BOND', 'ANGLE', 'DIHE' (proper torsion)
  or 'IMPROPER'; any other value raises KeyError.
  """
  type_class = {
    'BOND': bond_base.BondType,
    'ANGLE': angle_base.AngleType,
    'DIHE': dihedral_base.ProperType,
    'IMPROPER': dihedral_base.ImproperType,
  }[section]
  known_names = all_registered(type_class)
  for alias in aliases:
    if alias in known_names:
      return True
  return False
def add_boundary_term(terms, section, name, aliases, required=True):
  """Append one boundary term to ``terms[section]`` unless that term is already listed.

  Args:
    terms: dict mapping section names to lists of term dicts; mutated in place.
    section: key of *terms* ('BOND', 'ANGLE', 'DIHE' or 'IMPROPER'); also passed to
      ``term_exists``.
    name: display key of the term (a '-'-joined atom-type string).
    aliases: equivalent atom-type keys for the same term (e.g. forward and reversed
      orderings); falsy entries are dropped.
    required: stored as a bool under ``'required'``.

  The new entry is ``{'name', 'aliases', 'baselineMissing', 'required'}``, where
  ``baselineMissing`` is True when none of the aliases is a registered type.
  Nothing is appended when no aliases remain, or when *name* or any alias already
  appears as the name or an alias of an existing entry in the section. Without the
  alias check, the same term reached from the other direction through a different
  link (e.g. 'CT-c3' vs 'c3-CT') would be recorded twice.
  """
  aliases = [alias for alias in aliases if alias]
  if not aliases:
    return
  entries = terms[section]
  known = set()
  for entry in entries:
    known.add(entry.get('name'))
    known.update(entry.get('aliases') or [])
  if name in known or any(alias in known for alias in aliases):
    return
  entries.append({'name': name, 'aliases': aliases, 'baselineMissing': not term_exists(section, aliases), 'required': bool(required)})
def residue_atom(residue, atom_name):
  """Return the atom called *atom_name* from *residue*.

  The first truthy attribute among ``name2atom``, ``Name2Atom`` and ``Name2atom`` is
  called with the name if it is callable; otherwise the atom is read as an attribute
  of the residue (``getattr(residue, atom_name)``).
  """
  lookup = None
  for attr in ('name2atom', 'Name2Atom', 'Name2atom'):
    lookup = getattr(residue, attr, None)
    if lookup:
      break
  if callable(lookup):
    return lookup(atom_name)
  return getattr(residue, atom_name)
# resNames whose GAFF mol2 has already been loaded as a template by build_molecule;
# later builds (e.g. the per-link parmchk2 molecules) skip reloading them.
loaded_mol2_templates = set()
def residue_type_suffix(name):
  """Return *name* as a string with every non-alphanumeric character replaced by '_'.

  A falsy *name* (None, '', 0, ...) yields 'mol'.
  """
  source = str(name or 'mol')
  cleaned = [char if char.isalnum() else '_' for char in source]
  return ''.join(cleaned)
def build_molecule(name, selected_links):
  """Assemble an Xponge molecule from the payload components, joined by *selected_links*.

  Every entry of ``payload['components']`` becomes one residue, in payload order:

  - ``kind == 'standard'``: a deep copy of the registered residue type ``resName``,
    renamed ``<resName>_<index>_<suffix>``. When the component lists ``atomNames``,
    template atoms not in that list are omitted from the copy. Names are compared
    case-insensitively, ignoring surrounding whitespace.
  - any other kind: the residue type obtained by loading ``gaffMol2Path`` as a
    template. Each ``resName`` is loaded only once per run (``loaded_mol2_templates``).

  Args:
    name: molecule name; also sanitised by ``residue_type_suffix`` for residue-type names.
    selected_links: link dicts with 1-based ``componentAIndex``/``componentBIndex`` and
      atom names ``atomA``/``atomB``; None or empty means no links.

  Returns:
    The assembled ``Xponge.Molecule``.

  Raises:
    ValueError: a link's component index is missing/non-positive, or larger than the
      number of components.
  """
  mol = Xponge.Molecule(name)
  suffix = residue_type_suffix(name)
  for idx, component in enumerate(payload.get('components') or [], start=1):
    kind = component.get('kind')
    res_name = component.get('resName')
    if kind == 'standard':
      base_type = Xponge.ResidueType.get_type(res_name)
      res_type = base_type.deepcopy(f'{res_name}_{idx}_{suffix}')
      present = {str(atom_name).strip().upper() for atom_name in (component.get('atomNames') or []) if str(atom_name).strip()}
      if present:
        template_names = {str(getattr(atom, 'name', '') or '') for atom in getattr(base_type, 'atoms', []) or []}
        # Match case-insensitively, but pass the template's own spelling to omit_atoms:
        # an upper-cased copy would not name the atom if the template spells it otherwise.
        omit_atoms = sorted(
          (atom_name for atom_name in template_names if atom_name.strip() and atom_name.strip().upper() not in present),
          key=lambda atom_name: (atom_name.strip().upper(), atom_name))
        if omit_atoms:
          res_type.omit_atoms(omit_atoms, charge=None)
    else:
      if res_name not in loaded_mol2_templates:
        Xponge.load_mol2(component.get('gaffMol2Path'), as_template=True)
        loaded_mol2_templates.add(res_name)
      res_type = Xponge.ResidueType.get_type(res_name)
    mol.add_residue(Xponge.Residue(res_type, directly_copy=True))
  residue_count = len(mol.residues)
  for link in selected_links or []:
    idx_a = int(link.get('componentAIndex') or 0) - 1
    idx_b = int(link.get('componentBIndex') or 0) - 1
    if idx_a < 0 or idx_b < 0:
      raise ValueError('link component index missing')
    # Report indices past the last component as a ValueError too, instead of an
    # opaque IndexError from mol.residues.
    if idx_a >= residue_count or idx_b >= residue_count:
      raise ValueError(f'link component index out of range (have {residue_count} components)')
    atom_a = residue_atom(mol.residues[idx_a], link.get('atomA'))
    atom_b = residue_atom(mol.residues[idx_b], link.get('atomB'))
    mol.add_residue_link(atom_a, atom_b)
  return mol
# Molecule with every payload link applied; build_boundary_terms() and the full-molecule
# mol2 export read this global. Missing link indices make build_molecule raise here,
# before any output file is written.
mol = build_molecule('interactive_frcmod', payload.get('links') or [])
def add_adjacency(adjacency, atom_a, atom_b):
  """Record an undirected edge between *atom_a* and *atom_b* in *adjacency*.

  *adjacency* maps each atom to a list of its neighbours and is mutated in place;
  an existing neighbour is not appended twice. Calls where either atom is None, or
  both are the same object, are ignored.
  """
  if atom_a is None or atom_b is None or atom_a is atom_b:
    return
  for source, target in ((atom_a, atom_b), (atom_b, atom_a)):
    bucket = adjacency.setdefault(source, [])
    if target not in bucket:
      bucket.append(target)
def _iter_molecule_bonds(mol):
  """Yield (atom, neighbour) pairs for every bond known to *mol*, in a fixed order.

  For each residue, it first yields the bonds from its residue type's ``connectivity``,
  mapped onto the residue's own atoms by name; pairs that cannot be matched by name are
  skipped. It then yields the pairs from the residue's own ``connectivity``. Last comes
  ``(atom1, atom2)`` for each residue link. Pairs may repeat, and link pairs may
  contain None.
  """
  def label(obj):
    return str(getattr(obj, 'name', '') or '')
  for residue in getattr(mol, 'residues', []) or []:
    by_name = {label(atom): atom for atom in (getattr(residue, 'atoms', []) or [])}
    template_bonds = getattr(getattr(residue, 'type', None), 'connectivity', {}) or {}
    for template_atom, template_neighbors in template_bonds.items():
      atom = by_name.get(label(template_atom))
      if atom is None:
        continue
      for template_neighbor in template_neighbors or []:
        neighbor = by_name.get(label(template_neighbor))
        if neighbor is not None:
          yield atom, neighbor
    for atom, neighbors in (getattr(residue, 'connectivity', {}) or {}).items():
      for neighbor in neighbors or []:
        yield atom, neighbor
  for reslink in getattr(mol, 'residue_links', []) or []:
    yield getattr(reslink, 'atom1', None), getattr(reslink, 'atom2', None)
def build_atom_adjacency(mol):
  """Return an undirected atom -> [neighbour atoms] map for *mol*.

  The map combines residue-type bonds, per-residue connectivity and inter-residue
  links; each pair is added through add_adjacency.
  """
  adjacency = {}
  for atom_a, atom_b in _iter_molecule_bonds(mol):
    add_adjacency(adjacency, atom_a, atom_b)
  return adjacency
def build_boundary_terms():
  """Enumerate bonded terms around every link that joins components of different kinds.

  For each payload link whose two components have different ``kind`` values (same-kind
  links are skipped), collect these terms by atom-type key:
    - BOND: the linked pair A-B (the only term added as required);
    - ANGLE: X-A-B for each other neighbour X of A, and A-B-Y for each neighbour Y of B;
    - DIHE: X-A-B-Y across the link. From each side it also adds Z-X-center-partner,
      where X neighbours the linked atom ``center`` and Z neighbours X;
    - IMPROPER: for each pair of non-partner neighbours of a linked atom, that atom in
      the third slot and its partner as the remaining outer atom.
  Each term is passed to add_boundary_term together with its reversed/permuted alias
  keys.

  Reads the module-level ``payload`` and ``mol``.

  Returns:
    dict mapping 'BOND', 'ANGLE', 'DIHE' and 'IMPROPER' to lists of term dicts.
  """
  terms = {'BOND': [], 'ANGLE': [], 'DIHE': [], 'IMPROPER': []}
  components = payload.get('components') or []
  adjacency = build_atom_adjacency(mol)
  for link in payload.get('links') or []:
    # 1-based payload indices -> 0-based. They are not range-checked here: building the
    # global `mol` from these same links has already rejected bad indices.
    idx_a = int(link.get('componentAIndex') or 0) - 1
    idx_b = int(link.get('componentBIndex') or 0) - 1
    comp_a = components[idx_a]
    comp_b = components[idx_b]
    # Only links between components of different kinds are boundaries.
    if comp_a.get('kind') == comp_b.get('kind'):
      continue
    atom_a = residue_atom(mol.residues[idx_a], link.get('atomA'))
    atom_b = residue_atom(mol.residues[idx_b], link.get('atomB'))
    # Bonded neighbours of each linked atom, excluding its partner across the link.
    neighbors_a = [neighbor for neighbor in adjacency.get(atom_a, []) if neighbor is not atom_b]
    neighbors_b = [neighbor for neighbor in adjacency.get(atom_b, []) if neighbor is not atom_a]
    add_boundary_term(terms, 'BOND', atom_key([atom_a, atom_b]), [atom_key([atom_a, atom_b]), atom_key([atom_b, atom_a])])
    # Angles X-A-B.
    for neighbor in neighbors_a:
      add_boundary_term(terms, 'ANGLE', atom_key([neighbor, atom_a, atom_b]), [atom_key([neighbor, atom_a, atom_b]), atom_key([atom_b, atom_a, neighbor])], required=False)
    # Angles A-B-Y.
    for neighbor in neighbors_b:
      add_boundary_term(terms, 'ANGLE', atom_key([atom_a, atom_b, neighbor]), [atom_key([atom_a, atom_b, neighbor]), atom_key([neighbor, atom_b, atom_a])], required=False)
    # Torsions X-A-B-Y through the link bond.
    for left in neighbors_a:
      for right in neighbors_b:
        add_boundary_term(terms, 'DIHE', atom_key([left, atom_a, atom_b, right]), [atom_key([left, atom_a, atom_b, right]), atom_key([right, atom_b, atom_a, left])], required=False)
    # From each side of the link: torsions ending on the partner atom, then impropers.
    for center, linked, neighbors in ((atom_a, atom_b, neighbors_a), (atom_b, atom_a, neighbors_b)):
      for outer in neighbors:
        for outer2 in adjacency.get(outer, []):
          # Do not walk straight back to the central atom.
          if outer2 is center:
            continue
          add_boundary_term(terms, 'DIHE', atom_key([outer2, outer, center, linked]), [atom_key([outer2, outer, center, linked]), atom_key([linked, center, outer, outer2])], required=False)
      # combinations() would yield nothing for fewer than two neighbours anyway.
      if len(neighbors) >= 2:
        for first, second in combinations(neighbors, 2):
          improper_atoms = [first, second, center, linked]
          # Every ordering of the three outer atoms, always with `center` in the third slot.
          aliases = []
          for perm in permutations([first, second, linked]):
            aliases.append(atom_key([perm[0], perm[1], center, perm[2]]))
          add_boundary_term(terms, 'IMPROPER', atom_key(improper_atoms), aliases, required=False)
  return terms
# Boundary-term report plus toolchain provenance, written before any parmchk2 run below.
boundary_terms = build_boundary_terms()
environment = {
  'platform': platform.platform(),
  'python': sys.version,
  'xponge': getattr(Xponge, '__version__', None),
  'xpongeLibPath': getattr(XpongeLib, '__file__', ''),
}
metadata = {'boundaryTerms': boundary_terms, 'environment': environment}
with open(payload.get('boundaryMetadataPath'), 'w', encoding='utf-8') as handle:
  json.dump(metadata, handle, indent=2, sort_keys=True)
def is_boundary_link(link):
  """Return True if *link* joins two payload components of different ``kind``.

  Component indices in *link* are 1-based. A missing, non-positive or out-of-range
  index makes the link a non-boundary (False) instead of raising.
  """
  components = payload.get('components') or []
  count = len(components)
  positions = [int(link.get(key) or 0) - 1 for key in ('componentAIndex', 'componentBIndex')]
  if not all(0 <= pos < count for pos in positions):
    return False
  kind_a, kind_b = (components[pos].get('kind') for pos in positions)
  return kind_a != kind_b
# Links joining components of different kinds are the ones parmchk2 must parameterise.
boundary_links = list(filter(is_boundary_link, payload.get('links') or []))
raw_mol2_path = payload.get('rawMol2Path')
raw_output_frcmod = payload.get('rawOutputFrcmod')
# Export the complete molecule first.
# NOTE(review): the single-run branch below overwrites this same path with the parmchk2
# input, which keeps only the boundary link when there is exactly one.
Xponge.Save_Mol2(mol, raw_mol2_path)
if len(boundary_links) > 1:
  # One parmchk2 run per boundary link, each on a molecule carrying only that link.
  # The per-link frcmod files are then concatenated under a single remark line.
  raw_mol2_root, raw_mol2_ext = os.path.splitext(raw_mol2_path)
  raw_frcmod_root, raw_frcmod_ext = os.path.splitext(raw_output_frcmod)
  raw_mol2_ext = raw_mol2_ext or '.mol2'
  raw_frcmod_ext = raw_frcmod_ext or '.frcmod'
  raw_parts = []
  for link_index, boundary_link in enumerate(boundary_links, start=1):
    link_mol = build_molecule(f'interactive_frcmod_link_{link_index}', [boundary_link])
    link_mol2 = f'{raw_mol2_root}.link{link_index}{raw_mol2_ext}'
    link_frcmod = f'{raw_frcmod_root}.link{link_index}{raw_frcmod_ext}'
    Xponge.Save_Mol2(link_mol, link_mol2)
    gaff.parmchk2_gaff(link_mol2, link_frcmod, direct_load=False, keep=True)
    raw_parts.append(link_frcmod)
  with open(raw_output_frcmod, 'w', encoding='utf-8') as output_handle:
    output_handle.write('Remark generated by Mokda from per-boundary parmchk2 runs\n')
    for raw_part in raw_parts:
      output_handle.write('\n')
      with open(raw_part, 'r', encoding='utf-8', errors='ignore') as input_handle:
        output_handle.write(input_handle.read())
else:
  # Zero or one boundary link: a single parmchk2 run. With no boundary link at all,
  # every payload link is kept in the parmchk2 input.
  parmchk_links = boundary_links or (payload.get('links') or [])
  parmchk_mol = build_molecule('interactive_frcmod_parmchk', parmchk_links)
  Xponge.Save_Mol2(parmchk_mol, raw_mol2_path)
  gaff.parmchk2_gaff(raw_mol2_path, raw_output_frcmod, direct_load=False, keep=True)
print(raw_output_frcmod)