Coverage for visualization.py: 79%


import logging

try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

import networkx as nx

import torch
from elfragmentador import encoding_decoding, annotate, constants
from elfragmentador.model import PepTransformerModel
import pandas as pd
import numpy as np


class SelfAttentionExplorer(torch.no_grad):
    """SelfAttentionExplorer lets you explore self-attention with a context manager.

    It is a context manager that takes a PepTransformerModel and wraps the transformer
    layers to save the self-attention matrices during its activity. Once it closes, the
    hooks are removed but the attention matrices are kept.

    Later these matrices can be explored. Check the examples for how to get them.

    Examples
    --------
    >>> model = PepTransformerModel()  # Or load the model from a checkpoint
    >>> _ = model.eval()
    >>> with SelfAttentionExplorer(model) as sea:
    ...     _ = model.predict_from_seq("MYPEPTIDEK", 2, 30)
    ...     _ = model.predict_from_seq("MY[PHOSPHO]PEPTIDEK", 2, 30)
    >>> out = sea.get_encoder_attn(layer=0, index=0)
    >>> type(out)
    <class 'pandas.core.frame.DataFrame'>
    >>> list(out)
    ['n1', 'M2', 'Y3', 'P4', 'E5', 'P6', 'T7', 'I8', 'D9', 'E10', 'K11', 'c12']
    >>> out = sea.get_decoder_attn(layer=0, index=0)
    >>> type(out)
    <class 'pandas.core.frame.DataFrame'>
    >>> list(out)[:5]
    ['z1b1', 'z1b2', 'z1b3', 'z1b4', 'z1b5']
    """

    def __init__(self, model: PepTransformerModel):
        logging.info("Initializing SelfAttentionExplorer")
        super().__init__()

        self.encoder_viz = {}
        self.decoder_viz = {}
        self.aa_seqs = {}
        self.charges = {}
        self.handles = []

        encoder = model.encoder.transformer_encoder
        decoder = model.decoder.trans_decoder

        encoder_hook = self._make_hook_transformer_layer(self.encoder_viz)
        for layer in range(0, len(encoder.layers)):
            logging.info(f"Adding hook to encoder layer: {layer}")
            handle = encoder.layers[layer].self_attn.register_forward_hook(encoder_hook)
            self.handles.append(handle)

        decoder_hook = self._make_hook_transformer_layer(self.decoder_viz)
        for layer in range(0, len(decoder.layers)):
            logging.info(f"Adding hook to decoder layer: {layer}")
            handle = decoder.layers[layer].self_attn.register_forward_hook(decoder_hook)
            self.handles.append(handle)

        aa_hook = self._make_hook_aa_layer(self.aa_seqs)
        handle = model.encoder.aa_encoder.aa_encoder.register_forward_hook(aa_hook)
        self.handles.append(handle)

        charge_hook = self._make_hook_charge(self.charges)
        handle = model.decoder.charge_encoder.register_forward_hook(charge_hook)
        self.handles.append(handle)

    def __enter__(self):
        super().__enter__()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        logging.info("Removing Handles")
        super().__exit__(exc_type, exc_value, exc_traceback)
        for h in self.handles:
            h.remove()

        # TODO consider whether all self-attention matrices/dataframes should
        # be calculated on exit, or even the bipartite graphs.

    def __repr__(self):
        out = (
            ">>> SelfAttentionExplorer <<<\n\n"
            ">> AA sequences (aa_seqs):\n"
            f"{self.aa_seqs.__repr__()}\n"
            ">> Charges (charges):\n"
            f"{self.charges.__repr__()}\n"
            f">> Encoder vizs (encoder_viz) {list(self.encoder_viz.values())[0].shape}:\n"
            f"{self.encoder_viz.__repr__()}\n"
            f">> Decoder vizs (decoder_viz) {list(self.decoder_viz.values())[0].shape}:\n"
            f"{self.decoder_viz.__repr__()}"
        )

        return out

    @staticmethod
    def _make_hook_transformer_layer(target):
        # Forward hook that stores the self-attention weights (second element of the
        # attention module's output) in `target`, keyed by the hooked module.
        def hook_fn(m, i, o):
            if target.get(m, None) is None:
                target[m] = o[1]
            else:
                target[m] = torch.cat([target[m], o[1]])

        return hook_fn

    @staticmethod
    def _make_hook_aa_layer(target):
        # Forward hook that decodes the integer-encoded input back into a (modified)
        # peptide sequence and appends it to `target`.
        def hook_fn(m, i, o):
            if target.get(m, None) is None:
                target[m] = []
            target[m].append(
                encoding_decoding.decode_mod_seq(
                    [int(x) for x in i[0]], clip_explicit_term=False
                )
            )

        return hook_fn

    @staticmethod
    def _make_hook_charge(target):
        # Forward hook that records the precursor charges (second positional input
        # of the charge encoder) in `target`.
        def hook_fn(m, i, o):
            if target.get(m, None) is None:
                target[m] = []
            target[m].extend([int(x) for x in i[1]])

        return hook_fn

    @staticmethod
    def _norm(attn):
        attn = (attn - np.mean(attn)) / np.std(attn)
        return attn

    def get_encoder_attn(self, layer: int, index: int = 0, norm=False) -> pd.DataFrame:
        seq = list(self.aa_seqs.values())[0][index]
        attn = list(self.encoder_viz.values())[layer][index][: len(seq), : len(seq)]
        if norm:
            attn = self._norm(attn)

        names = [x + str(i + 1) for i, x in enumerate(seq)]
        attn = pd.DataFrame(attn.clone().detach().numpy(), index=names, columns=names)
        return attn[::-1]

    def get_decoder_attn(self, layer: int, index: int = 0, norm=True) -> pd.DataFrame:
        attn = list(self.decoder_viz.values())[layer][index].clone().detach().numpy()
        seq = list(self.aa_seqs.values())[0][index]
        charge = list(self.charges.values())[0][index]
        theo_ions = list(annotate.get_peptide_ions(seq).keys())
        theo_ions = [x for x in theo_ions if int(x[1]) <= charge]
        if norm:
            attn = self._norm(attn)

        names = constants.FRAG_EMBEDING_LABELS
        attn = pd.DataFrame(attn, index=names, columns=names)

        # TODO add charge filtering
        return attn[theo_ions].loc[theo_ions][::-1]
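

# Background sketch (not part of the original module): SelfAttentionExplorer relies on
# torch.nn.Module.register_forward_hook to capture what flows through each layer. A
# stripped-down, self-contained version of that pattern on a toy module looks like
# this; the function name and toy layer are illustrative only.
def _forward_hook_example() -> list:
    captured = []

    def hook_fn(module, inputs, output):
        # Runs after every forward pass of `layer`; here we simply stash the output.
        captured.append(output.detach())

    layer = torch.nn.Linear(4, 2)
    handle = layer.register_forward_hook(hook_fn)
    layer(torch.randn(3, 4))  # Triggers the hook once.
    handle.remove()  # Mirrors what __exit__ does with self.handles.
    return captured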


def make_bipartite(x):
    """Makes a bipartite graph from a data frame whose column and row indices are the same"""
    B = nx.Graph()
    B.add_nodes_from(list(x.index), bipartite=0)  # Add the node attribute "bipartite"
    B.add_nodes_from([i + "_" for i in x.index], bipartite=1)

    for index1 in x.index:
        for index2 in x.index:
            B.add_edges_from(
                [
                    (index1, index2 + "_"),
                ],
                weight=x[index1].loc[index2],
            )

    return B
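

# Minimal usage sketch (not part of the original module): make_bipartite expects a
# square DataFrame whose index and columns are identical, such as the frame returned
# by SelfAttentionExplorer.get_encoder_attn. The toy labels and values below are
# illustrative only.
def _make_bipartite_example() -> nx.Graph:
    toy_attn = pd.DataFrame(
        [[0.6, 0.4], [0.1, 0.9]],
        index=["n1", "M2"],
        columns=["n1", "M2"],
    )
    graph = make_bipartite(toy_attn)
    # One partition keeps the labels ("n1", "M2"); the other gets a trailing
    # underscore ("n1_", "M2_"). Edge weights come straight from the DataFrame.
    assert graph["n1"]["n1_"]["weight"] == 0.6
    return graph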


def plot_bipartite_seq(B):
    """Plots a bipartite graph built from sequence self-attention.

    Expects node names to be of the form X[index] and X[index]_
    """
    if plt is None:
        raise ImportError(
            "Matplotlib is not installed, please install and re-load elfragmentador"
        )

    # Separate by group
    l, r = nx.bipartite.sets(B)
    pos = {}
    # Update position for node from each group
    pos.update((node, (int(node[1:]), 1)) for node in l)
    pos.update((node, (int(node[1:-1]), 2)) for node in r)
    weights = list(nx.get_edge_attributes(B, "weight").values())
    mean_weight = np.array(weights).mean()
    min_weight = np.array(weights).min()
    weights = [x if x > mean_weight else min_weight for x in weights]

    nx.draw(
        B,
        pos=pos,
        edge_color=weights,
        edge_cmap=plt.cm.Blues,
        with_labels=True,
        width=3,
        alpha=0.5,
    )
    plt.show()
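

# End-to-end usage sketch (not part of the original module), tying the pieces above
# together. It assumes a freshly initialized PepTransformerModel is acceptable for a
# demonstration (in practice one would load a trained checkpoint) and that the third
# positional argument of predict_from_seq is the collision energy, as in the class
# docstring example; the function name and defaults here are illustrative only.
def _demo_attention_plot(sequence: str = "MYPEPTIDEK", charge: int = 2, nce: float = 30.0):
    model = PepTransformerModel()
    model.eval()

    with SelfAttentionExplorer(model) as sea:
        model.predict_from_seq(sequence, charge, nce)

    # First encoder layer, first (and only) recorded peptide.
    attn_df = sea.get_encoder_attn(layer=0, index=0)
    graph = make_bipartite(attn_df)
    plot_bipartite_seq(graph)  # Requires matplotlib to be installed.


if __name__ == "__main__":
    _demo_attention_plot()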