Filament
In [ ]:
Copied!
import numpy as np
import mdtraj as md
import matplotlib.pyplot as plt
try:
import nglview as nv
except ImportError:
nv = None
try:
import networkx as nx
except ImportError:
nx = None
from pathlib import Path
import numpy as np
import mdtraj as md
import matplotlib.pyplot as plt
try:
import nglview as nv
except ImportError:
nv = None
try:
import networkx as nx
except ImportError:
nx = None
from pathlib import Path
In [ ]:
Copied!
# https://biopython.org/docs/1.74/api/Bio.SVDSuperimposer.html
# conda install conda-forge::biopython
from Bio.SVDSuperimposer import SVDSuperimposer
# https://biopython.org/docs/1.74/api/Bio.SVDSuperimposer.html
# conda install conda-forge::biopython
from Bio.SVDSuperimposer import SVDSuperimposer
from numpy import dot
In [ ]:
Copied!
In [ ]:
Copied!
class SiteMapper:
def __init__(self, s1s1, s2s2, segments=dict, k=100):
self.segments = segments
self.s1s1 = s1s1[::k]
self.s2s2 = s2s2[::k]
#self.get_site_map()
def get_site_map(self):
# get structures of the different sites
s1 = self.get_site_structures(self.s1s1,site='s1')
s2_from_s1 = self.get_site_structures(self.s1s1, site='s2')
s2 = self.get_site_structures(self.s2s2,site='s2')
h3_s1s1 = self.get_segment_structures(self.s1s1,site='h3')
h3_s2s2 = self.get_segment_structures(self.s2s2,site='h3')
h3 = md.join([h3_s1s1,h3_s2s2])
l2_s1s1 = self.get_segment_structures(self.s1s1,site='l2')
l2_s2s2 = self.get_segment_structures(self.s2s2,site='l2')
l2 = md.join([l2_s1s1,l2_s2s2])
dbd_s1s1 = self.get_segment_structures(self.s1s1,site='dbd')
dbd_s2s2 = self.get_segment_structures(self.s2s2,site='dbd')
dbd = md.join([dbd_s1s1,dbd_s2s2])
# map of the different sites
site_map = {'s1':s1,
'h3':h3,
's2':s2,
's2_end' : s2_from_s1,
'l2':l2,
'dbd':dbd}
# fix resSeq numbering for second chain of s1 and s2
for c in site_map['s1'].top.chains:
if c.index == 1:
for res in c.residues:
res.resSeq = res.resSeq - 137
for c in site_map['s2'].top.chains:
if c.index == 1:
for res in c.residues:
res.resSeq = res.resSeq - 137
return site_map
def check_selection(self,top,selection):
if selection == 'CA':
indices = top.select('name CA')
elif selection == 'backbone':
indices = top.select('backbone')
elif selection == 'sidechain':
indices = top.select('sidechain')
else:
indices = top.select('all')
return indices
def get_monomer_domain_indices(self,top,domain,chain=0,selection=None):
residues = np.array(top._chains[chain]._residues)
indices = self.check_selection(top,selection)
return [at.index for res in residues[domain] for at in res.atoms if at.index in indices]
def get_segment_structures(self,traj,site='dbd'):
chain_a = self.get_monomer_domain_indices(top=traj.top, domain=self.segments[site], chain=0, selection=None)
chain_b = self.get_monomer_domain_indices(top=traj.top, domain=self.segments[site], chain=1, selection=None)
A = traj.atom_slice(chain_a)
B = traj.atom_slice(chain_b)
return md.join([A,B])
def get_site_structures(self, traj,site='s1'):
chain_a = self.get_monomer_domain_indices(top=traj.top, domain=self.segments[site], chain=0, selection=None)
chain_b = self.get_monomer_domain_indices(top=traj.top, domain=self.segments[site], chain=1, selection=None)
return traj.atom_slice(np.sort(chain_a+chain_b))
def show_domain(self, system, domains, domain):
""""Not working yet, need to fix the selection of the atoms in the domain."""
if nv is None:
raise ImportError('nglview is required for show_domain but is not installed.')
# shows first frame
top = system.top
view = nv.show_mdtraj(system[0])
view.clear()
indices = self.get_monomer_domain_indices(top, domains[domain], chain=0)
view.add_representation('cartoon',selection=[i for i in top.select('all') if i not in indices],color='cornflowerblue')
top = system.topology
chain_id = 0
indices = self.get_monomer_domain_indices(top, domains[domain], chain=chain_id)
view.add_representation('cartoon',selection=indices,color='gold')
top = system.topology
chain_id = 1
indices = self.get_monomer_domain_indices(top, domains[domain], chain=chain_id)
view.add_representation('cartoon',selection=indices,color='red')
return view
class SiteMapper:
def __init__(self, s1s1, s2s2, segments=dict, k=100):
self.segments = segments
self.s1s1 = s1s1[::k]
self.s2s2 = s2s2[::k]
#self.get_site_map()
def get_site_map(self):
# get structures of the different sites
s1 = self.get_site_structures(self.s1s1,site='s1')
s2_from_s1 = self.get_site_structures(self.s1s1, site='s2')
s2 = self.get_site_structures(self.s2s2,site='s2')
h3_s1s1 = self.get_segment_structures(self.s1s1,site='h3')
h3_s2s2 = self.get_segment_structures(self.s2s2,site='h3')
h3 = md.join([h3_s1s1,h3_s2s2])
l2_s1s1 = self.get_segment_structures(self.s1s1,site='l2')
l2_s2s2 = self.get_segment_structures(self.s2s2,site='l2')
l2 = md.join([l2_s1s1,l2_s2s2])
dbd_s1s1 = self.get_segment_structures(self.s1s1,site='dbd')
dbd_s2s2 = self.get_segment_structures(self.s2s2,site='dbd')
dbd = md.join([dbd_s1s1,dbd_s2s2])
# map of the different sites
site_map = {'s1':s1,
'h3':h3,
's2':s2,
's2_end' : s2_from_s1,
'l2':l2,
'dbd':dbd}
# fix resSeq numbering for second chain of s1 and s2
for c in site_map['s1'].top.chains:
if c.index == 1:
for res in c.residues:
res.resSeq = res.resSeq - 137
for c in site_map['s2'].top.chains:
if c.index == 1:
for res in c.residues:
res.resSeq = res.resSeq - 137
return site_map
def check_selection(self,top,selection):
if selection == 'CA':
indices = top.select('name CA')
elif selection == 'backbone':
indices = top.select('backbone')
elif selection == 'sidechain':
indices = top.select('sidechain')
else:
indices = top.select('all')
return indices
def get_monomer_domain_indices(self,top,domain,chain=0,selection=None):
residues = np.array(top._chains[chain]._residues)
indices = self.check_selection(top,selection)
return [at.index for res in residues[domain] for at in res.atoms if at.index in indices]
def get_segment_structures(self,traj,site='dbd'):
chain_a = self.get_monomer_domain_indices(top=traj.top, domain=self.segments[site], chain=0, selection=None)
chain_b = self.get_monomer_domain_indices(top=traj.top, domain=self.segments[site], chain=1, selection=None)
A = traj.atom_slice(chain_a)
B = traj.atom_slice(chain_b)
return md.join([A,B])
def get_site_structures(self, traj,site='s1'):
chain_a = self.get_monomer_domain_indices(top=traj.top, domain=self.segments[site], chain=0, selection=None)
chain_b = self.get_monomer_domain_indices(top=traj.top, domain=self.segments[site], chain=1, selection=None)
return traj.atom_slice(np.sort(chain_a+chain_b))
def show_domain(self, system, domains, domain):
""""Not working yet, need to fix the selection of the atoms in the domain."""
if nv is None:
raise ImportError('nglview is required for show_domain but is not installed.')
# shows first frame
top = system.top
view = nv.show_mdtraj(system[0])
view.clear()
indices = self.get_monomer_domain_indices(top, domains[domain], chain=0)
view.add_representation('cartoon',selection=[i for i in top.select('all') if i not in indices],color='cornflowerblue')
top = system.topology
chain_id = 0
indices = self.get_monomer_domain_indices(top, domains[domain], chain=chain_id)
view.add_representation('cartoon',selection=indices,color='gold')
top = system.topology
chain_id = 1
indices = self.get_monomer_domain_indices(top, domains[domain], chain=chain_id)
view.add_representation('cartoon',selection=indices,color='red')
return view
In [ ]:
Copied!
class Superimposer:
def __init__(self, A, B, overlap_A, overlap_B):
self.A = A
self.B = B
self.overlap_A = overlap_A
self.overlap_B = overlap_B
def get_rot_and_trans(self, subtraj_A,subtraj_B):
""" fit only works now on a single frame (mdtraj returns xyz with shape (n_frames, atoms, xyz)
even for single frame trajs so hence the xyz[0]"""
# load super imposer
sup = SVDSuperimposer()
# Set the coords, y will be rotated and translated on x
x = subtraj_A.xyz[0]
y = subtraj_B.xyz[0]
sup.set(x, y)
# Do the leastsquared fit
sup.run()
# Get the rms
rms = sup.get_rms()
# Get rotation (right multiplying!) and the translation
rot, tran = sup.get_rotran()
# now we have the instructions to rotate B on A
return rot,tran,rms
def apply_superimposition(self, traj, rot, tran):
# get xyz coordinates
xyz = traj.xyz[0]
# rotate subject on target
new_xyz = np.dot(xyz, rot) + tran
# replace coordinates of traj
traj.xyz = new_xyz
return traj
def fit_B_on_A(self):
# create trajs containing only the selections
subtraj_A = self.A.atom_slice(self.overlap_A)
subtraj_B = self.B.atom_slice(self.overlap_B)
# obtain instructions to rotate and translate B on A based on substraj structures
rot, tran, _ = self.get_rot_and_trans(subtraj_A,subtraj_B)
# do the superimposition of B on A and subsitute old with new xyz of B
return self.apply_superimposition(self.B, rot, tran)
class Superimposer:
def __init__(self, A, B, overlap_A, overlap_B):
self.A = A
self.B = B
self.overlap_A = overlap_A
self.overlap_B = overlap_B
def get_rot_and_trans(self, subtraj_A,subtraj_B):
""" fit only works now on a single frame (mdtraj returns xyz with shape (n_frames, atoms, xyz)
even for single frame trajs so hence the xyz[0]"""
# load super imposer
sup = SVDSuperimposer()
# Set the coords, y will be rotated and translated on x
x = subtraj_A.xyz[0]
y = subtraj_B.xyz[0]
sup.set(x, y)
# Do the leastsquared fit
sup.run()
# Get the rms
rms = sup.get_rms()
# Get rotation (right multiplying!) and the translation
rot, tran = sup.get_rotran()
# now we have the instructions to rotate B on A
return rot,tran,rms
def apply_superimposition(self, traj, rot, tran):
# get xyz coordinates
xyz = traj.xyz[0]
# rotate subject on target
new_xyz = np.dot(xyz, rot) + tran
# replace coordinates of traj
traj.xyz = new_xyz
return traj
def fit_B_on_A(self):
# create trajs containing only the selections
subtraj_A = self.A.atom_slice(self.overlap_A)
subtraj_B = self.B.atom_slice(self.overlap_B)
# obtain instructions to rotate and translate B on A based on substraj structures
rot, tran, _ = self.get_rot_and_trans(subtraj_A,subtraj_B)
# do the superimposition of B on A and subsitute old with new xyz of B
return self.apply_superimposition(self.B, rot, tran)
In [ ]:
Copied!
class Helper:
@staticmethod
def check_if_dimerization(site):
if 's' in site:
return True
else:
return False
@staticmethod
def get_termini(site_x,site_y):
chain_order = np.array(['s1','h3','s2','l2','dbd'])
x = np.argwhere(chain_order==site_x)
y = np.argwhere(chain_order==site_y)
if x < y:
return ['N_terminus','C_terminus']
elif x > y:
return ['C_terminus','N_terminus']
@staticmethod
def get_overlap_indices(top,n,chain=0,terminus=None):
residues = np.array(top._chains[chain]._residues)
if terminus == 'N_terminus': # get residues at end of chain
s = residues[len(residues)-n*2:len(residues)]
return [at.index for res in s for at in res.atoms]
elif terminus == 'C_terminus': # get residues at beginning of chain
s = residues[:n*2]
return [at.index for res in s for at in res.atoms]
else:
print('No terminus')
@staticmethod
def check_overlaps(overlap_A,overlap_B):
if len(overlap_A) != len(overlap_B):
print(len(overlap_A),len(overlap_B))
print('Something went wrong with finding the overlaps')
else:
False
@staticmethod
def remove_overlap_old(traj, overlap):
"""Works fine but is slow for large number of atoms... because top and traj get reinitialized"""
return traj.atom_slice([at.index for at in traj.top.atoms if at.index not in overlap])
@staticmethod
def remove_overlap(traj, overlap):
indices_to_keep = [at.index for at in traj.top.atoms if at.index not in overlap]
xyz = np.array(traj.xyz[:, indices_to_keep], order='C')
for _,index in enumerate(overlap):
traj.top.delete_atom_by_index(index-_)
traj.xyz = xyz
return traj
@staticmethod
def split_chain_topology(traj, leading_chain):
# split part of A in chain that is being extended and that is not
traj_active = traj.atom_slice(traj.top.select(f'chainid {leading_chain}'))
traj_passive = traj.atom_slice(traj.top.select(f'not chainid {leading_chain}'))
return traj_active, traj_passive
@staticmethod
def merge_chain_topology(A, B, keep_resSeq=True):
C = A.stack(B,keep_resSeq=keep_resSeq)
top = C.top
# Merge two tops (with two chains or more) to a top of one chain
out = md.Topology()
c = out.add_chain()
for chain in top.chains:
for residue in chain.residues:
r = out.add_residue(residue.name, c, residue.resSeq, residue.segment_id)
for atom in residue.atoms:
out.add_atom(atom.name, atom.element, r, serial=atom.serial)
# for bond in top.bonds:
# a1, a2 = bond
# out.add_bond(a1, a2, type=bond.type, order=bond.order)
out.create_standard_bonds() #rare manier om bonds te maken, maar werkt
C.top = out
return C
class Helper:
@staticmethod
def check_if_dimerization(site):
if 's' in site:
return True
else:
return False
@staticmethod
def get_termini(site_x,site_y):
chain_order = np.array(['s1','h3','s2','l2','dbd'])
x = np.argwhere(chain_order==site_x)
y = np.argwhere(chain_order==site_y)
if x < y:
return ['N_terminus','C_terminus']
elif x > y:
return ['C_terminus','N_terminus']
@staticmethod
def get_overlap_indices(top,n,chain=0,terminus=None):
residues = np.array(top._chains[chain]._residues)
if terminus == 'N_terminus': # get residues at end of chain
s = residues[len(residues)-n*2:len(residues)]
return [at.index for res in s for at in res.atoms]
elif terminus == 'C_terminus': # get residues at beginning of chain
s = residues[:n*2]
return [at.index for res in s for at in res.atoms]
else:
print('No terminus')
@staticmethod
def check_overlaps(overlap_A,overlap_B):
if len(overlap_A) != len(overlap_B):
print(len(overlap_A),len(overlap_B))
print('Something went wrong with finding the overlaps')
else:
False
@staticmethod
def remove_overlap_old(traj, overlap):
"""Works fine but is slow for large number of atoms... because top and traj get reinitialized"""
return traj.atom_slice([at.index for at in traj.top.atoms if at.index not in overlap])
@staticmethod
def remove_overlap(traj, overlap):
indices_to_keep = [at.index for at in traj.top.atoms if at.index not in overlap]
xyz = np.array(traj.xyz[:, indices_to_keep], order='C')
for _,index in enumerate(overlap):
traj.top.delete_atom_by_index(index-_)
traj.xyz = xyz
return traj
@staticmethod
def split_chain_topology(traj, leading_chain):
# split part of A in chain that is being extended and that is not
traj_active = traj.atom_slice(traj.top.select(f'chainid {leading_chain}'))
traj_passive = traj.atom_slice(traj.top.select(f'not chainid {leading_chain}'))
return traj_active, traj_passive
@staticmethod
def merge_chain_topology(A, B, keep_resSeq=True):
C = A.stack(B,keep_resSeq=keep_resSeq)
top = C.top
# Merge two tops (with two chains or more) to a top of one chain
out = md.Topology()
c = out.add_chain()
for chain in top.chains:
for residue in chain.residues:
r = out.add_residue(residue.name, c, residue.resSeq, residue.segment_id)
for atom in residue.atoms:
out.add_atom(atom.name, atom.element, r, serial=atom.serial)
# for bond in top.bonds:
# a1, a2 = bond
# out.add_bond(a1, a2, type=bond.type, order=bond.order)
out.create_standard_bonds() #rare manier om bonds te maken, maar werkt
C.top = out
return C
In [ ]:
Copied!
class Fixer:
def __init__(self, traj):
if nx is None:
raise ImportError('networkx is required for Fixer but is not installed.')
segments = {'s1':np.arange(0,41),
's2':np.arange(53,82)}
# Figure out which chains are connected
G = self.compute_interaction_graph(traj, segments)
# Traverse over graph for new chain assignements
chain_mapping = self.traverse_from_endpoint(G)
# Update chain order in topology
new_topology, atom_mapping = self.update_chain_topology(traj, chain_mapping)
# Update xyz coordinates
new_xyz = traj.xyz[:,atom_mapping]
# Create new trajectory with corrected order and adjust xyz as well
new_traj = md.Trajectory(new_xyz,new_topology)
self.new_traj = new_traj
def get_updated_traj(self):
return self.new_traj
def traverse_from_endpoint(self, G):
# Find all nodes with degree 1 (endpoints)
endpoints = [node for node, degree in G.degree() if degree == 1]
# Choose the first endpoint as the start node
start_node = endpoints[0] if endpoints else None
# Initialize a dictionary to store the number of steps to each node
# chain_mapping = {node: float('inf') for node in G.nodes()}
D = []
for node in G.nodes:
d = nx.shortest_path_length(G, source=start_node, target=node)
D.append(d)
chain_mapping = {i:j for i,j in zip(G.nodes, np.argsort(D))}
return chain_mapping
def compute_chain_centers(self, traj, domain):
top = traj.top
coms = []
lens = []
ids = []
for c in top.chains:
try:
selection = top.select(f'chainid {c.index} and resSeq {domain[0]} to {domain[-1]}')
com = md.compute_center_of_mass(traj.atom_slice(selection))
coms.append(com)
lens.append(len(selection))
ids.append(c.index)
except:
pass
coms = np.squeeze(np.array(coms))
lens = np.array(lens)
ids = np.array(ids)
return coms, lens, ids
def compute_interaction_graph(self, traj, segments):
# Compute COMS of each chain domain and get chain labels
s1_centers = self.compute_chain_centers(traj, segments['s1'])
s2_centers = self.compute_chain_centers(traj, segments['s2'])
# Initialize graph
G = nx.Graph()
labels = {}
for c in traj.top.chains:
G.add_node(c.index,label=c.index)
labels[c.index] = c.index
# Loop over centers
for idx,center in enumerate([s1_centers, s2_centers]):
coms = center[0]
ids = center[2]
# Computer distance between coms
D = np.zeros((len(coms),len(coms)))
for i,ci in enumerate(coms):
for j,cj in enumerate(coms):
d = np.linalg.norm(ci-cj)
D[i,j] = d
# Use closest pairs to collect edges
edges = []
for _,d in enumerate(D):
pair = np.sort([ids[_],ids[np.argsort(d)[1]]])
edges.append(pair)
# Filter pairs for redudancy
edges = np.unique(edges,axis=0)
for pair in edges:
G.add_edge(pair[0],pair[1])
return G
def update_chain_topology(self, traj, chain_mapping):
# Initialize empty top
new_top = md.Topology()
# Collect current chains
chains = list(traj.top.chains)
atom_mapping = []
# Loop over chain mapping
for new, current in chain_mapping.items():
new_chain = new_top.add_chain() # add empty chain
chain = chains[current]
for res in chain.residues: # fill chain with resdues
new_res = new_top.add_residue(res.name, new_chain, res.resSeq,res.segment_id)
for atom in res.atoms: # fill residue with atoms
new_top.add_atom(atom.name, atom.element, new_res, serial=atom.serial)
atom_mapping.append(atom.index) # keep track of new index order
# Return mapping and top
return new_top, atom_mapping
class Fixer:
def __init__(self, traj):
if nx is None:
raise ImportError('networkx is required for Fixer but is not installed.')
segments = {'s1':np.arange(0,41),
's2':np.arange(53,82)}
# Figure out which chains are connected
G = self.compute_interaction_graph(traj, segments)
# Traverse over graph for new chain assignements
chain_mapping = self.traverse_from_endpoint(G)
# Update chain order in topology
new_topology, atom_mapping = self.update_chain_topology(traj, chain_mapping)
# Update xyz coordinates
new_xyz = traj.xyz[:,atom_mapping]
# Create new trajectory with corrected order and adjust xyz as well
new_traj = md.Trajectory(new_xyz,new_topology)
self.new_traj = new_traj
def get_updated_traj(self):
return self.new_traj
def traverse_from_endpoint(self, G):
# Find all nodes with degree 1 (endpoints)
endpoints = [node for node, degree in G.degree() if degree == 1]
# Choose the first endpoint as the start node
start_node = endpoints[0] if endpoints else None
# Initialize a dictionary to store the number of steps to each node
# chain_mapping = {node: float('inf') for node in G.nodes()}
D = []
for node in G.nodes:
d = nx.shortest_path_length(G, source=start_node, target=node)
D.append(d)
chain_mapping = {i:j for i,j in zip(G.nodes, np.argsort(D))}
return chain_mapping
def compute_chain_centers(self, traj, domain):
top = traj.top
coms = []
lens = []
ids = []
for c in top.chains:
try:
selection = top.select(f'chainid {c.index} and resSeq {domain[0]} to {domain[-1]}')
com = md.compute_center_of_mass(traj.atom_slice(selection))
coms.append(com)
lens.append(len(selection))
ids.append(c.index)
except:
pass
coms = np.squeeze(np.array(coms))
lens = np.array(lens)
ids = np.array(ids)
return coms, lens, ids
def compute_interaction_graph(self, traj, segments):
# Compute COMS of each chain domain and get chain labels
s1_centers = self.compute_chain_centers(traj, segments['s1'])
s2_centers = self.compute_chain_centers(traj, segments['s2'])
# Initialize graph
G = nx.Graph()
labels = {}
for c in traj.top.chains:
G.add_node(c.index,label=c.index)
labels[c.index] = c.index
# Loop over centers
for idx,center in enumerate([s1_centers, s2_centers]):
coms = center[0]
ids = center[2]
# Computer distance between coms
D = np.zeros((len(coms),len(coms)))
for i,ci in enumerate(coms):
for j,cj in enumerate(coms):
d = np.linalg.norm(ci-cj)
D[i,j] = d
# Use closest pairs to collect edges
edges = []
for _,d in enumerate(D):
pair = np.sort([ids[_],ids[np.argsort(d)[1]]])
edges.append(pair)
# Filter pairs for redudancy
edges = np.unique(edges,axis=0)
for pair in edges:
G.add_edge(pair[0],pair[1])
return G
def update_chain_topology(self, traj, chain_mapping):
# Initialize empty top
new_top = md.Topology()
# Collect current chains
chains = list(traj.top.chains)
atom_mapping = []
# Loop over chain mapping
for new, current in chain_mapping.items():
new_chain = new_top.add_chain() # add empty chain
chain = chains[current]
for res in chain.residues: # fill chain with resdues
new_res = new_top.add_residue(res.name, new_chain, res.resSeq,res.segment_id)
for atom in res.atoms: # fill residue with atoms
new_top.add_atom(atom.name, atom.element, new_res, serial=atom.serial)
atom_mapping.append(atom.index) # keep track of new index order
# Return mapping and top
return new_top, atom_mapping
In [ ]:
Copied!
class Assembler:
def __init__(self, site_map, n_overlap : int = 2):
self.traj = None
self.chain_id = 0
self.site_map = site_map
self.n = n_overlap
self.n_dimers = 0
self.n_dna = 0
self.first = True
self.s1_pairs = [['s1','h3'],['h3','s2'],['s2','l2'],['l2','dbd']]
self.s2_pairs = [['s2','h3'],['h3','s1'],['s2','l2'],['l2','dbd']]
self.traj_history = []
self.cleaned = False
@staticmethod
def default_segments(n_overlap=2):
return {
's1': np.arange(0, 41 + n_overlap),
'h3': np.arange(41 - n_overlap, 53 + n_overlap),
's2': np.arange(53 - n_overlap, 82 + n_overlap),
'l2': np.arange(82 - n_overlap, 95 + n_overlap),
'dbd': np.arange(95 - n_overlap, 137),
}
@staticmethod
def _pick_site_frame(site_traj, site_name, source_name):
if site_name in {'s1', 's2', 's2_end'}:
return site_traj[0]
if source_name == 's1s1':
return site_traj[0]
if source_name == 's2s2':
return site_traj[-1]
raise ValueError(f'Unknown source name {source_name} for site {site_name}.')
@staticmethod
def load_minimal_site_map(minimal_dir, segments=None, n_overlap=2):
"""
Load minimal source PDBs and build start/extend segment maps on demand.
Expected files in ``minimal_dir``:
- ``s1s1_start.pdb``
- ``s2s2_start.pdb``
- ``s1s1_extend.pdb``
- ``s2s2_extend.pdb``
- ``complex_frame_<idx>.pdb``
"""
base = Path(minimal_dir)
if segments is None:
segments = Assembler.default_segments(n_overlap=n_overlap)
required_files = [
's1s1_start.pdb',
's2s2_start.pdb',
's1s1_extend.pdb',
's2s2_extend.pdb',
]
missing = [filename for filename in required_files if not (base / filename).exists()]
if missing:
raise FileNotFoundError(f'Missing required minimal source files: {missing}')
manifest = {}
manifest_path = base / 'manifest.json'
if manifest_path.exists():
import json
with open(manifest_path, 'r', encoding='utf-8') as handle:
manifest = json.load(handle)
start_site_sources = manifest.get('start_site_sources', {'s2': 's2s2', 'h3': 's1s1'})
extend_site_sources = manifest.get(
'extend_site_sources',
{'s1': 's1s1', 'h3': 's1s1', 's2': 's2s2', 'l2': 's1s1', 'dbd': 's1s1'},
)
s1s1_start = md.load(str(base / 's1s1_start.pdb'))
s2s2_start = md.load(str(base / 's2s2_start.pdb'))
s1s1_extend = md.load(str(base / 's1s1_extend.pdb'))
s2s2_extend = md.load(str(base / 's2s2_extend.pdb'))
start_mapper = SiteMapper(s1s1_start, s2s2_start, segments=segments, k=1)
extend_mapper = SiteMapper(s1s1_extend, s2s2_extend, segments=segments, k=1)
start_site_map = start_mapper.get_site_map()
extend_site_map = extend_mapper.get_site_map()
minimal_start = {}
for site_name, source_name in start_site_sources.items():
if site_name not in start_site_map:
raise KeyError(f'Start site {site_name} missing from generated site map.')
minimal_start[site_name] = Assembler._pick_site_frame(start_site_map[site_name], site_name, source_name)
minimal_extend = {}
for site_name, source_name in extend_site_sources.items():
if site_name not in extend_site_map:
raise KeyError(f'Extend site {site_name} missing from generated site map.')
minimal_extend[site_name] = Assembler._pick_site_frame(extend_site_map[site_name], site_name, source_name)
site_map = {'minimal_start': minimal_start, 'minimal_extend': minimal_extend}
dna_frame_idx = manifest.get('dna_frame_index', 1)
dna_candidate = base / f'complex_frame_{dna_frame_idx}.pdb'
if not dna_candidate.exists():
candidates = sorted(base.glob('complex_frame_*.pdb'))
if candidates:
dna_candidate = candidates[0]
if dna_candidate.exists():
site_map['complex'] = md.load(str(dna_candidate))
return site_map
@staticmethod
def load_fixed_minimal_site_map(minimal_dir):
return Assembler.load_minimal_site_map(minimal_dir=minimal_dir)
def get_traj(self):
if self.cleaned:
return self.traj
else:
self.clean_traj()
return self.traj
# if self.n_dna != 0:
# subtraj_dna = self.traj.atom_slice(self.traj.top.select('resname DG DC DA DT'))
# subtraj_protein = self.traj.atom_slice(self.traj.top.select('protein'))
# traj = subtraj_protein.atom_slice(subtraj_protein.top.select(f'chainid 0 to {(self.n_dimers*2)-1}'))
# return traj.stack(subtraj_dna)
# else:
# return self.traj.atom_slice(self.traj.top.select(f'chainid 0 to {(self.n_dimers*2)-1}'))
def add_dimer(self, verbose: bool = False, segment: str = 'random'):
"""
Adds a dimer to the trajectory by iterating over s2_pairs first and then s1_pairs, with specific increments to chain_id.
"""
self.n_dimers += 1
# print('Start processing s2 pairs')
# print(self.s2_pairs)
self.traj = self.process_pairs(self.s2_pairs, self.chain_id, verbose, segment)
# print('Start processing s1 pairs')
# print(self.s1_pairs)
self.traj = self.process_pairs(self.s1_pairs, self.chain_id + 2,verbose, segment)
self.chain_id += 2
return self.traj
def process_pairs(self, pairs, chain_id, verbose, segment):
"""
Processes a sequence of pairs, adding each to the trajectory.
"""
for idx, pair in enumerate(pairs):
leading_chain = chain_id if idx == 0 else 0
self.traj = self.add_pair(pair, leading_chain=leading_chain, verbose=verbose, segment=segment)
return self.traj
# Get segments based on 'fixed' or 'random' segment criteria
def get_segments(self, site_a, site_b, segment):
if segment == 'minimal' and 'minimal_start' in self.site_map and 'minimal_extend' in self.site_map:
if not self.traj:
start_map = self.site_map['minimal_start']
fallback_map = self.site_map['minimal_extend']
A = start_map.get(site_a, fallback_map.get(site_a))
B = start_map.get(site_b, fallback_map.get(site_b))
if A is None or B is None:
raise KeyError(f'Missing minimal-start entries for {site_a} and/or {site_b}.')
else:
extension_map = self.site_map['minimal_extend']
if site_b not in extension_map:
raise KeyError(f'Missing minimal-extension entry for site {site_b}.')
A = self.traj
B = extension_map[site_b]
return A, B
if not self.traj:
if segment == 'fixed':
x, y = 40, 90
elif segment == 'random':
k = len(self.site_map[site_a])
l = len(self.site_map[site_b])
x, y = np.random.randint(0, k), np.random.randint(0, l)
A = self.site_map[site_a][x]
B = self.site_map[site_b][y]
else:
if segment == 'fixed':
z = 20
elif segment == 'random':
k = len(self.site_map[site_b])
z = np.random.randint(0, k)
A = self.traj
B = self.site_map[site_b][z]
return A, B
# Check for dimerization and print if verbose
def check_dimerization(self, site_a, site_b):
dimer_a = Helper.check_if_dimerization(site_a)
dimer_b = Helper.check_if_dimerization(site_b)
return dimer_a, dimer_b
# Determine growth direction based on terminus
def determine_terminus_direction(self, site_a, site_b):
terminus_a, terminus_b = Helper.get_termini(site_a, site_b)
reverse = terminus_a == 'C_terminus'
return reverse, terminus_a, terminus_b
# Manage overlaps and perform superimposition
def manage_overlaps(self, A, B, leading_chain, adding_chain, terminus_a, terminus_b):
overlap_A = Helper.get_overlap_indices(A.top, self.n, chain=leading_chain, terminus=terminus_a)
overlap_B = Helper.get_overlap_indices(B.top, self.n, chain=adding_chain, terminus=terminus_b)
check = Helper.check_overlaps(overlap_A, overlap_B)
if check:
return check
superimposer = Superimposer(A, B, overlap_A, overlap_B)
new_B = superimposer.fit_B_on_A()
# Instead of atom slice I should use pop/delete to remove the atoms?
new_A = Helper.remove_overlap(A, overlap_A)
return new_A, new_B
# Manipulate topologies: split, merge, and stack components
def manipulate_topology(self, A, B, leading_chain, adding_chain, reverse, keep_resSeq, dimer_b):
A_active, A_passive = Helper.split_chain_topology(A, leading_chain)
if dimer_b:
B_active, B_passive = Helper.split_chain_topology(B, adding_chain)
temp = Helper.merge_chain_topology(B_active if reverse else A_active, A_active if reverse else B_active, keep_resSeq)
C_temp = temp.stack(A_passive, keep_resSeq)
C = C_temp.stack(B_passive, keep_resSeq)
else:
temp = Helper.merge_chain_topology(B if reverse else A_active, A_active if reverse else B, keep_resSeq)
C = temp.stack(A_passive, keep_resSeq)
return C
# Refactored method to add a pair of sites
def add_pair(self, pair, leading_chain=0, adding_chain=0, verbose=False, reverse=False, segment='fixed'):
if verbose:
#print(f'Adding pair: {pair} of dimer {self.n_dimers}')
pass
site_a, site_b = pair
# Get segments based on 'fixed' or 'random' segment criteria
A, B = self.get_segments(site_a, site_b, segment)
# Check for dimerization (aka if site is s1 or s2)
dimer_a, dimer_b = self.check_dimerization(site_a, site_b)
# Determine growth direction based on terminus
reverse, terminus_a, terminus_b = self.determine_terminus_direction(site_a, site_b)
# Manage overlaps and perform superimposition
new_A, new_B = self.manage_overlaps(A, B, leading_chain, adding_chain, terminus_a, terminus_b)
# Manipulate topologies: split, merge, and stack components, and return new trajectory
C = self.manipulate_topology(A=new_A, B=new_B, leading_chain=leading_chain, adding_chain=adding_chain, reverse=reverse, keep_resSeq=True, dimer_b=dimer_b)
self.traj_history.append(C)
return C
def clean_traj(self):
self.raw_traj = self.traj
fixer = Fixer(self.traj)
self.traj = fixer.get_updated_traj()
# Remove leftover s2 segment domains at the ends of the filament
self.traj = self.traj.atom_slice(self.traj.top.select(f'chainid 1 to {(self.n_dimers*2)}'))
self.cleaned = True
def add_dna(self, chainid=0, frame_idx=1):
if not self.cleaned:
self.clean_traj()
# Select frame of DNA - DBD complex
if 'complex' not in self.site_map:
raise KeyError('No DNA complex found in site_map. Provide site_map["complex"] before calling add_dna.')
complex_traj = self.site_map['complex']
dna_complex = complex_traj[0] if len(complex_traj) == 1 else complex_traj[frame_idx]
# Get selection of dbd residues for fit of only backbone
indices_dbd_complex = dna_complex.top.select(f'resSeq 95 to 137 and backbone') # SALMONELA
indices_dbd_traj = self.traj.top.select(f'(chainid {chainid} and resSeq 95 to 137) and backbone') # Ecoli
# Fit the dbd with DNA to the loction of the dbd in the filament at chainid
imposer = Superimposer(self.traj,dna_complex,indices_dbd_traj,indices_dbd_complex)
new_dbd_complex = imposer.fit_B_on_A()
# Remove indices of DBD from complex
dna = new_dbd_complex.atom_slice(new_dbd_complex.top.select('not protein'))
# Add DNA to the filament traj
self.traj = self.traj.stack(dna)
self.n_dna += 1
class Assembler:
def __init__(self, site_map, n_overlap : int = 2):
self.traj = None
self.chain_id = 0
self.site_map = site_map
self.n = n_overlap
self.n_dimers = 0
self.n_dna = 0
self.first = True
self.s1_pairs = [['s1','h3'],['h3','s2'],['s2','l2'],['l2','dbd']]
self.s2_pairs = [['s2','h3'],['h3','s1'],['s2','l2'],['l2','dbd']]
self.traj_history = []
self.cleaned = False
@staticmethod
def default_segments(n_overlap=2):
return {
's1': np.arange(0, 41 + n_overlap),
'h3': np.arange(41 - n_overlap, 53 + n_overlap),
's2': np.arange(53 - n_overlap, 82 + n_overlap),
'l2': np.arange(82 - n_overlap, 95 + n_overlap),
'dbd': np.arange(95 - n_overlap, 137),
}
@staticmethod
def _pick_site_frame(site_traj, site_name, source_name):
if site_name in {'s1', 's2', 's2_end'}:
return site_traj[0]
if source_name == 's1s1':
return site_traj[0]
if source_name == 's2s2':
return site_traj[-1]
raise ValueError(f'Unknown source name {source_name} for site {site_name}.')
@staticmethod
def load_minimal_site_map(minimal_dir, segments=None, n_overlap=2):
"""
Load minimal source PDBs and build start/extend segment maps on demand.
Expected files in ``minimal_dir``:
- ``s1s1_start.pdb``
- ``s2s2_start.pdb``
- ``s1s1_extend.pdb``
- ``s2s2_extend.pdb``
- ``complex_frame_<idx>.pdb``
"""
base = Path(minimal_dir)
if segments is None:
segments = Assembler.default_segments(n_overlap=n_overlap)
required_files = [
's1s1_start.pdb',
's2s2_start.pdb',
's1s1_extend.pdb',
's2s2_extend.pdb',
]
missing = [filename for filename in required_files if not (base / filename).exists()]
if missing:
raise FileNotFoundError(f'Missing required minimal source files: {missing}')
manifest = {}
manifest_path = base / 'manifest.json'
if manifest_path.exists():
import json
with open(manifest_path, 'r', encoding='utf-8') as handle:
manifest = json.load(handle)
start_site_sources = manifest.get('start_site_sources', {'s2': 's2s2', 'h3': 's1s1'})
extend_site_sources = manifest.get(
'extend_site_sources',
{'s1': 's1s1', 'h3': 's1s1', 's2': 's2s2', 'l2': 's1s1', 'dbd': 's1s1'},
)
s1s1_start = md.load(str(base / 's1s1_start.pdb'))
s2s2_start = md.load(str(base / 's2s2_start.pdb'))
s1s1_extend = md.load(str(base / 's1s1_extend.pdb'))
s2s2_extend = md.load(str(base / 's2s2_extend.pdb'))
start_mapper = SiteMapper(s1s1_start, s2s2_start, segments=segments, k=1)
extend_mapper = SiteMapper(s1s1_extend, s2s2_extend, segments=segments, k=1)
start_site_map = start_mapper.get_site_map()
extend_site_map = extend_mapper.get_site_map()
minimal_start = {}
for site_name, source_name in start_site_sources.items():
if site_name not in start_site_map:
raise KeyError(f'Start site {site_name} missing from generated site map.')
minimal_start[site_name] = Assembler._pick_site_frame(start_site_map[site_name], site_name, source_name)
minimal_extend = {}
for site_name, source_name in extend_site_sources.items():
if site_name not in extend_site_map:
raise KeyError(f'Extend site {site_name} missing from generated site map.')
minimal_extend[site_name] = Assembler._pick_site_frame(extend_site_map[site_name], site_name, source_name)
site_map = {'minimal_start': minimal_start, 'minimal_extend': minimal_extend}
dna_frame_idx = manifest.get('dna_frame_index', 1)
dna_candidate = base / f'complex_frame_{dna_frame_idx}.pdb'
if not dna_candidate.exists():
candidates = sorted(base.glob('complex_frame_*.pdb'))
if candidates:
dna_candidate = candidates[0]
if dna_candidate.exists():
site_map['complex'] = md.load(str(dna_candidate))
return site_map
@staticmethod
def load_fixed_minimal_site_map(minimal_dir):
return Assembler.load_minimal_site_map(minimal_dir=minimal_dir)
def get_traj(self):
if self.cleaned:
return self.traj
else:
self.clean_traj()
return self.traj
# if self.n_dna != 0:
# subtraj_dna = self.traj.atom_slice(self.traj.top.select('resname DG DC DA DT'))
# subtraj_protein = self.traj.atom_slice(self.traj.top.select('protein'))
# traj = subtraj_protein.atom_slice(subtraj_protein.top.select(f'chainid 0 to {(self.n_dimers*2)-1}'))
# return traj.stack(subtraj_dna)
# else:
# return self.traj.atom_slice(self.traj.top.select(f'chainid 0 to {(self.n_dimers*2)-1}'))
def add_dimer(self, verbose: bool = False, segment: str = 'random'):
"""
Adds a dimer to the trajectory by iterating over s2_pairs first and then s1_pairs, with specific increments to chain_id.
"""
self.n_dimers += 1
# print('Start processing s2 pairs')
# print(self.s2_pairs)
self.traj = self.process_pairs(self.s2_pairs, self.chain_id, verbose, segment)
# print('Start processing s1 pairs')
# print(self.s1_pairs)
self.traj = self.process_pairs(self.s1_pairs, self.chain_id + 2,verbose, segment)
self.chain_id += 2
return self.traj
def process_pairs(self, pairs, chain_id, verbose, segment):
"""
Processes a sequence of pairs, adding each to the trajectory.
"""
for idx, pair in enumerate(pairs):
leading_chain = chain_id if idx == 0 else 0
self.traj = self.add_pair(pair, leading_chain=leading_chain, verbose=verbose, segment=segment)
return self.traj
# Get segments based on 'fixed' or 'random' segment criteria
def get_segments(self, site_a, site_b, segment):
if segment == 'minimal' and 'minimal_start' in self.site_map and 'minimal_extend' in self.site_map:
if not self.traj:
start_map = self.site_map['minimal_start']
fallback_map = self.site_map['minimal_extend']
A = start_map.get(site_a, fallback_map.get(site_a))
B = start_map.get(site_b, fallback_map.get(site_b))
if A is None or B is None:
raise KeyError(f'Missing minimal-start entries for {site_a} and/or {site_b}.')
else:
extension_map = self.site_map['minimal_extend']
if site_b not in extension_map:
raise KeyError(f'Missing minimal-extension entry for site {site_b}.')
A = self.traj
B = extension_map[site_b]
return A, B
if not self.traj:
if segment == 'fixed':
x, y = 40, 90
elif segment == 'random':
k = len(self.site_map[site_a])
l = len(self.site_map[site_b])
x, y = np.random.randint(0, k), np.random.randint(0, l)
A = self.site_map[site_a][x]
B = self.site_map[site_b][y]
else:
if segment == 'fixed':
z = 20
elif segment == 'random':
k = len(self.site_map[site_b])
z = np.random.randint(0, k)
A = self.traj
B = self.site_map[site_b][z]
return A, B
# Check for dimerization and print if verbose
def check_dimerization(self, site_a, site_b):
dimer_a = Helper.check_if_dimerization(site_a)
dimer_b = Helper.check_if_dimerization(site_b)
return dimer_a, dimer_b
# Determine growth direction based on terminus
def determine_terminus_direction(self, site_a, site_b):
terminus_a, terminus_b = Helper.get_termini(site_a, site_b)
reverse = terminus_a == 'C_terminus'
return reverse, terminus_a, terminus_b
# Manage overlaps and perform superimposition
def manage_overlaps(self, A, B, leading_chain, adding_chain, terminus_a, terminus_b):
overlap_A = Helper.get_overlap_indices(A.top, self.n, chain=leading_chain, terminus=terminus_a)
overlap_B = Helper.get_overlap_indices(B.top, self.n, chain=adding_chain, terminus=terminus_b)
check = Helper.check_overlaps(overlap_A, overlap_B)
if check:
return check
superimposer = Superimposer(A, B, overlap_A, overlap_B)
new_B = superimposer.fit_B_on_A()
# Instead of atom slice I should use pop/delete to remove the atoms?
new_A = Helper.remove_overlap(A, overlap_A)
return new_A, new_B
# Manipulate topologies: split, merge, and stack components
def manipulate_topology(self, A, B, leading_chain, adding_chain, reverse, keep_resSeq, dimer_b):
A_active, A_passive = Helper.split_chain_topology(A, leading_chain)
if dimer_b:
B_active, B_passive = Helper.split_chain_topology(B, adding_chain)
temp = Helper.merge_chain_topology(B_active if reverse else A_active, A_active if reverse else B_active, keep_resSeq)
C_temp = temp.stack(A_passive, keep_resSeq)
C = C_temp.stack(B_passive, keep_resSeq)
else:
temp = Helper.merge_chain_topology(B if reverse else A_active, A_active if reverse else B, keep_resSeq)
C = temp.stack(A_passive, keep_resSeq)
return C
# Refactored method to add a pair of sites
def add_pair(self, pair, leading_chain=0, adding_chain=0, verbose=False, reverse=False, segment='fixed'):
if verbose:
#print(f'Adding pair: {pair} of dimer {self.n_dimers}')
pass
site_a, site_b = pair
# Get segments based on 'fixed' or 'random' segment criteria
A, B = self.get_segments(site_a, site_b, segment)
# Check for dimerization (aka if site is s1 or s2)
dimer_a, dimer_b = self.check_dimerization(site_a, site_b)
# Determine growth direction based on terminus
reverse, terminus_a, terminus_b = self.determine_terminus_direction(site_a, site_b)
# Manage overlaps and perform superimposition
new_A, new_B = self.manage_overlaps(A, B, leading_chain, adding_chain, terminus_a, terminus_b)
# Manipulate topologies: split, merge, and stack components, and return new trajectory
C = self.manipulate_topology(A=new_A, B=new_B, leading_chain=leading_chain, adding_chain=adding_chain, reverse=reverse, keep_resSeq=True, dimer_b=dimer_b)
self.traj_history.append(C)
return C
def clean_traj(self):
self.raw_traj = self.traj
fixer = Fixer(self.traj)
self.traj = fixer.get_updated_traj()
# Remove leftover s2 segment domains at the ends of the filament
self.traj = self.traj.atom_slice(self.traj.top.select(f'chainid 1 to {(self.n_dimers*2)}'))
self.cleaned = True
def add_dna(self, chainid=0, frame_idx=1):
if not self.cleaned:
self.clean_traj()
# Select frame of DNA - DBD complex
if 'complex' not in self.site_map:
raise KeyError('No DNA complex found in site_map. Provide site_map["complex"] before calling add_dna.')
complex_traj = self.site_map['complex']
dna_complex = complex_traj[0] if len(complex_traj) == 1 else complex_traj[frame_idx]
# Get selection of dbd residues for fit of only backbone
indices_dbd_complex = dna_complex.top.select(f'resSeq 95 to 137 and backbone') # SALMONELA
indices_dbd_traj = self.traj.top.select(f'(chainid {chainid} and resSeq 95 to 137) and backbone') # Ecoli
# Fit the dbd with DNA to the loction of the dbd in the filament at chainid
imposer = Superimposer(self.traj,dna_complex,indices_dbd_traj,indices_dbd_complex)
new_dbd_complex = imposer.fit_B_on_A()
# Remove indices of DBD from complex
dna = new_dbd_complex.atom_slice(new_dbd_complex.top.select('not protein'))
# Add DNA to the filament traj
self.traj = self.traj.stack(dna)
self.n_dna += 1