import logging

try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

import networkx as nx

import torch
from elfragmentador import encoding_decoding, annotate, constants
from elfragmentador.model import PepTransformerModel
import pandas as pd
import numpy as np


class SelfAttentionExplorer(torch.no_grad):
    """SelfAttentionExplorer lets you explore self-attention with a context manager.

    It is a context manager that takes a PepTransformerModel and wraps the transformer
    layers to save the self-attention matrices during its activity. Once it closes, the
    hooks are removed but the attention matrices are kept.

    Later these matrices can be explored. Check the examples for how to get them.

    Examples
    --------

    >>> model = PepTransformerModel()  # Or load the model from a checkpoint
    >>> _ = model.eval()
    >>> with SelfAttentionExplorer(model) as sea:
    ...     _ = model.predict_from_seq("MYPEPTIDEK", 2, 30)
    ...     _ = model.predict_from_seq("MY[PHOSPHO]PEPTIDEK", 2, 30)
    >>> out = sea.get_encoder_attn(layer=0, index=0)
    >>> type(out)
    <class 'pandas.core.frame.DataFrame'>
    >>> list(out)
    ['n1', 'M2', 'Y3', 'P4', 'E5', 'P6', 'T7', 'I8', 'D9', 'E10', 'K11', 'c12']
    >>> out = sea.get_decoder_attn(layer=0, index=0)
    >>> type(out)
    <class 'pandas.core.frame.DataFrame'>
    >>> list(out)[:5]
    ['z1b1', 'z1b2', 'z1b3', 'z1b4', 'z1b5']
    """

    def __init__(self, model: PepTransformerModel):
        logging.info("Initializing SelfAttentionExplorer")
        super().__init__()

        self.encoder_viz = {}
        self.decoder_viz = {}
        self.aa_seqs = {}
        self.charges = {}
        self.handles = []

        encoder = model.encoder.transformer_encoder
        decoder = model.decoder.trans_decoder

        encoder_hook = self._make_hook_transformer_layer(self.encoder_viz)
        for layer in range(0, len(encoder.layers)):
            logging.info(f"Adding hook to encoder layer: {layer}")
            handle = encoder.layers[layer].self_attn.register_forward_hook(encoder_hook)
            self.handles.append(handle)

        decoder_hook = self._make_hook_transformer_layer(self.decoder_viz)
        for layer in range(0, len(decoder.layers)):
            logging.info(f"Adding hook to decoder layer: {layer}")
            handle = decoder.layers[layer].self_attn.register_forward_hook(decoder_hook)
            self.handles.append(handle)

        aa_hook = self._make_hook_aa_layer(self.aa_seqs)
        handle = model.encoder.aa_encoder.aa_encoder.register_forward_hook(aa_hook)
        self.handles.append(handle)

        charge_hook = self._make_hook_charge(self.charges)
        handle = model.decoder.charge_encoder.register_forward_hook(charge_hook)
        self.handles.append(handle)

    def __enter__(self):
        super().__enter__()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        logging.info("Removing Handles")
        super().__exit__(exc_type, exc_value, exc_traceback)
        for h in self.handles:
            h.remove()

        # TODO consider if all self attention matrices/dataframes should
        # be calculated on exit. Or even the bipartite graphs

    def __repr__(self):
        out = (
            ">>> SelfAttentionExplorer <<<<\n\n"
            ">> AA sequences (aa_seqs):\n"
            f"{self.aa_seqs.__repr__()}\n"
            ">> Charges (charges):\n"
            f"{self.charges.__repr__()}\n"
            f">> Encoder vizs (encoder_viz) {list(self.encoder_viz.values())[0].shape}:\n"
            f"{self.encoder_viz.__repr__()}\n"
            f">> Decoder vizs (decoder_viz) {list(self.decoder_viz.values())[0].shape}:\n"
            f"{self.decoder_viz.__repr__()}"
        )

        return out
    @staticmethod
    def _make_hook_transformer_layer(target):
        # Forward hook for a self-attention module; o is the
        # (attn_output, attn_weights) tuple, so o[1] holds the attention matrix.
        def hook_fn(m, i, o):
            if target.get(m, None) is None:
                target[m] = o[1]
            else:
                target[m] = torch.cat([target[m], o[1]])

        return hook_fn

    @staticmethod
    def _make_hook_aa_layer(target):
        # Forward hook for the amino acid encoder; decodes the integer-encoded
        # input sequence (i[0]) back to a (modified) peptide sequence.
        def hook_fn(m, i, o):
            if target.get(m, None) is None:
                target[m] = []
            target[m].append(
                encoding_decoding.decode_mod_seq(
                    [int(x) for x in i[0]], clip_explicit_term=False
                )
            )

        return hook_fn

    @staticmethod
    def _make_hook_charge(target):
        # Forward hook for the charge encoder; stores the charge values passed
        # as its second positional input (i[1]).
        def hook_fn(m, i, o):
            if target.get(m, None) is None:
                target[m] = []
            target[m].extend([int(x) for x in i[1]])

        return hook_fn

    @staticmethod
    def _norm(attn):
        attn = (attn - np.mean(attn)) / np.std(attn)
        return attn

    def get_encoder_attn(self, layer: int, index: int = 0, norm=False) -> pd.DataFrame:
        seq = list(self.aa_seqs.values())[0][index]
        attn = list(self.encoder_viz.values())[layer][index][: len(seq), : len(seq)]
        if norm:
            attn = self._norm(attn)

        names = [x + str(i + 1) for i, x in enumerate(seq)]
        attn = pd.DataFrame(attn.clone().detach().numpy(), index=names, columns=names)
        return attn[::-1]

    def get_decoder_attn(self, layer: int, index: int = 0, norm=True) -> pd.DataFrame:
        attn = list(self.decoder_viz.values())[layer][index].clone().detach().numpy()
        seq = list(self.aa_seqs.values())[0][index]
        charge = list(self.charges.values())[0][index]
        theo_ions = list(annotate.get_peptide_ions(seq).keys())
        theo_ions = [x for x in theo_ions if int(x[1]) <= charge]
        if norm:
            attn = self._norm(attn)

        names = constants.FRAG_EMBEDING_LABELS
        attn = pd.DataFrame(attn, index=names, columns=names)

        # TODO add charge filtering
        return attn[theo_ions].loc[theo_ions][::-1]
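
# A minimal sketch (not part of the original module) of one way to explore the
# DataFrames returned above: render a square attention DataFrame, e.g. the output
# of SelfAttentionExplorer.get_encoder_attn, as a labelled heatmap. The function
# name is hypothetical and it assumes matplotlib is available (the optional
# module-level `plt` import).
def example_attention_heatmap(attn_df: pd.DataFrame) -> None:
    """Plot a square attention DataFrame as a heatmap with row/column labels."""
    if plt is None:
        raise ImportError(
            "Matplotlib is not installed, please install and re-load elfragmentador"
        )
    fig, ax = plt.subplots()
    im = ax.imshow(attn_df.to_numpy(), cmap="Blues")  # attention weights as colors
    ax.set_xticks(range(len(attn_df.columns)))
    ax.set_xticklabels(attn_df.columns, rotation=90)
    ax.set_yticks(range(len(attn_df.index)))
    ax.set_yticklabels(attn_df.index)
    fig.colorbar(im, ax=ax)
    plt.show()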

def make_bipartite(x):
    """Makes a bipartite graph from a data frame whose column and row indices are the same"""
    B = nx.Graph()
    B.add_nodes_from(
        [x for x in x.index], bipartite=0
    )  # Add the node attribute "bipartite"
    B.add_nodes_from([x + "_" for x in x.index], bipartite=1)

    for index1 in x.index:
        for index2 in x.index:
            B.add_edges_from(
                [
                    (index1, index2 + "_"),
                ],
                weight=x[index1].loc[index2],
            )

    return B
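
# A small sketch (hypothetical, not in the original module) of make_bipartite on a
# toy 2x2 symmetric DataFrame: the row/column labels become the two node sets
# ("M1"/"Y2" on one side, "M1_"/"Y2_" on the other) and each cell value becomes
# the weight of the corresponding edge.
def example_tiny_bipartite() -> nx.Graph:
    """Build a bipartite graph from a toy 2x2 attention-like DataFrame."""
    toy = pd.DataFrame(
        [[0.9, 0.1], [0.2, 0.8]],
        index=["M1", "Y2"],
        columns=["M1", "Y2"],
    )
    # The resulting graph has nodes M1, Y2, M1_, Y2_ and one weighted edge per cell
    return make_bipartite(toy)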

def plot_bipartite_seq(B):
    """Plots a bipartite graph from a sequence self-attention

    expects names to be in the form of X[index] and X[index]_
    """
    if plt is None:
        raise ImportError(
            "Matplotlib is not installed, please install and re-load elfragmentador"
        )

    # Separate by group
    l, r = nx.bipartite.sets(B)
    pos = {}
    # Update position for node from each group
    pos.update((node, (int(node[1:]), 1)) for node in l)
    pos.update((node, (int(node[1:-1]), 2)) for node in r)
    weights = list(nx.get_edge_attributes(B, "weight").values())
    mean_weight = np.array(weights).mean()
    min_weight = np.array(weights).min()
    weights = [x if x > mean_weight else min_weight for x in weights]

    nx.draw(
        B,
        pos=pos,
        edge_color=weights,
        edge_cmap=plt.cm.Blues,
        with_labels=True,
        width=3,
        alpha=0.5,
    )
    plt.show()
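

# A minimal end-to-end sketch (an illustration, not part of the original module):
# run one prediction inside the explorer, turn the first encoder layer's
# self-attention into a bipartite graph, and plot it. In real use the model
# would be loaded from a trained checkpoint rather than freshly initialized.
if __name__ == "__main__":
    model = PepTransformerModel()  # or load weights from a checkpoint
    _ = model.eval()
    with SelfAttentionExplorer(model) as sea:
        _ = model.predict_from_seq("MYPEPTIDEK", 2, 30)
    encoder_attn = sea.get_encoder_attn(layer=0, index=0)
    bipartite_graph = make_bipartite(encoder_attn)
    plot_bipartite_seq(bipartite_graph)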