attention_visualiser.visualiser
from typing import Optional

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer, BatchEncoding

from attention_visualiser.base import BaseAttentionVisualiser


class AttentionVisualiser(BaseAttentionVisualiser):
    """Attention visualizer for PyTorch-based transformer models.

    This class implements the abstract methods from BaseAttentionVisualiser
    specifically for models implemented in PyTorch. It handles the extraction
    and processing of attention weights from PyTorch transformer models.

    Attributes:
        model: A PyTorch-based transformer model from Hugging Face
        tokenizer: A tokenizer matching the model
        config: Dictionary containing visualization configuration parameters
        current_input: The last encoded input for which attentions were
            computed (initial value comes from the base class, presumably
            ``None`` -- confirm there).
        cache: The attentions computed for ``current_input``.
    """

    def __init__(
        self, model: AutoModel, tokenizer: AutoTokenizer, config: Optional[dict] = None
    ) -> None:
        """Initialize the PyTorch-specific attention visualizer.

        Args:
            model: A PyTorch-based transformer model from Hugging Face
            tokenizer: A tokenizer matching the model
            config: Optional dictionary with visualization parameters
        """
        super().__init__(model, tokenizer, config)

    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
        """Compute attention weights for the given input using a PyTorch model.

        Runs the model under ``torch.no_grad()`` with ``output_attentions=True``
        and returns the ``attentions`` field of its output. The last input and
        its attentions are kept in ``self.current_input`` / ``self.cache``; if
        ``encoded_input`` has the same keys and equal values (tensors compared
        with ``torch.equal``), the cached attentions are returned without
        running the model again.

        Args:
            encoded_input: The encoded input from the tokenizer; its items are
                passed to the model as keyword arguments.

        Returns:
            A tuple containing attention weights from all layers of the model
        """

        def _same_input(new, old) -> bool:
            # ``new == old`` on two BatchEncodings compares tensors element-wise
            # and then calls bool() on the resulting tensor, which raises
            # RuntimeError for any multi-element or shape-mismatched tensor.
            # Compare keys and tensor contents explicitly instead.
            if new is old:
                return True
            if new is None or old is None:
                return False
            if set(new.keys()) != set(old.keys()):
                return False
            for key in new.keys():
                a, b = new[key], old[key]
                a_is_tensor = isinstance(a, torch.Tensor)
                b_is_tensor = isinstance(b, torch.Tensor)
                if a_is_tensor or b_is_tensor:
                    if not (a_is_tensor and b_is_tensor):
                        return False
                    # torch.equal already returns False on shape mismatch; the
                    # dtype/device checks avoid torch.equal possibly raising on
                    # mixed dtypes or devices.
                    if a.dtype != b.dtype or a.device != b.device:
                        return False
                    if not torch.equal(a, b):
                        return False
                elif a != b:
                    return False
            return True

        if _same_input(encoded_input, self.current_input):
            # return from cache
            return self.cache

        # else recompute
        with torch.no_grad():
            output = self.model(**encoded_input, output_attentions=True)  # type: ignore

        attentions = output.attentions

        # update cache and current input
        self.current_input = encoded_input
        self.cache = attentions

        return attentions

    def get_attention_vector_mean(
        self, attention: torch.Tensor, axis: int = 0
    ) -> np.ndarray:
        """Calculate mean of PyTorch attention vectors along specified axis.

        Computes the mean of the attention tensor along ``axis``, detaches the
        result from the autograd graph, moves it to the CPU and converts it to
        a NumPy array. bfloat16 input is upcast to float32 first because NumPy
        has no bfloat16 dtype.

        Args:
            attention: PyTorch tensor containing attention weights
            axis: Axis along which to compute the mean (default: 0)

        Returns:
            NumPy array of mean attention values
        """
        # Tensor.numpy() raises TypeError for bfloat16; upcasting before the
        # mean also avoids accumulating in the low-precision dtype.
        if attention.dtype == torch.bfloat16:
            attention = attention.float()
        return torch.mean(attention, dim=axis).detach().cpu().numpy()
class AttentionVisualiser(BaseAttentionVisualiser):
    """Attention visualizer for PyTorch-based transformer models.

    This class implements the abstract methods from BaseAttentionVisualiser
    specifically for models implemented in PyTorch. It handles the extraction
    and processing of attention weights from PyTorch transformer models.

    Attributes:
        model: A PyTorch-based transformer model from Hugging Face
        tokenizer: A tokenizer matching the model
        config: Dictionary containing visualization configuration parameters
        current_input: The last encoded input for which attentions were
            computed (initial value comes from the base class, presumably
            ``None`` -- confirm there).
        cache: The attentions computed for ``current_input``.
    """

    def __init__(
        self, model: AutoModel, tokenizer: AutoTokenizer, config: Optional[dict] = None
    ) -> None:
        """Initialize the PyTorch-specific attention visualizer.

        Args:
            model: A PyTorch-based transformer model from Hugging Face
            tokenizer: A tokenizer matching the model
            config: Optional dictionary with visualization parameters
        """
        super().__init__(model, tokenizer, config)

    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
        """Compute attention weights for the given input using a PyTorch model.

        Runs the model under ``torch.no_grad()`` with ``output_attentions=True``
        and returns the ``attentions`` field of its output. The last input and
        its attentions are kept in ``self.current_input`` / ``self.cache``; if
        ``encoded_input`` has the same keys and equal values (tensors compared
        with ``torch.equal``), the cached attentions are returned without
        running the model again.

        Args:
            encoded_input: The encoded input from the tokenizer; its items are
                passed to the model as keyword arguments.

        Returns:
            A tuple containing attention weights from all layers of the model
        """

        def _same_input(new, old) -> bool:
            # ``new == old`` on two BatchEncodings compares tensors element-wise
            # and then calls bool() on the resulting tensor, which raises
            # RuntimeError for any multi-element or shape-mismatched tensor.
            # Compare keys and tensor contents explicitly instead.
            if new is old:
                return True
            if new is None or old is None:
                return False
            if set(new.keys()) != set(old.keys()):
                return False
            for key in new.keys():
                a, b = new[key], old[key]
                a_is_tensor = isinstance(a, torch.Tensor)
                b_is_tensor = isinstance(b, torch.Tensor)
                if a_is_tensor or b_is_tensor:
                    if not (a_is_tensor and b_is_tensor):
                        return False
                    # torch.equal already returns False on shape mismatch; the
                    # dtype/device checks avoid torch.equal possibly raising on
                    # mixed dtypes or devices.
                    if a.dtype != b.dtype or a.device != b.device:
                        return False
                    if not torch.equal(a, b):
                        return False
                elif a != b:
                    return False
            return True

        if _same_input(encoded_input, self.current_input):
            # return from cache
            return self.cache

        # else recompute
        with torch.no_grad():
            output = self.model(**encoded_input, output_attentions=True)  # type: ignore

        attentions = output.attentions

        # update cache and current input
        self.current_input = encoded_input
        self.cache = attentions

        return attentions

    def get_attention_vector_mean(
        self, attention: torch.Tensor, axis: int = 0
    ) -> np.ndarray:
        """Calculate mean of PyTorch attention vectors along specified axis.

        Computes the mean of the attention tensor along ``axis``, detaches the
        result from the autograd graph, moves it to the CPU and converts it to
        a NumPy array. bfloat16 input is upcast to float32 first because NumPy
        has no bfloat16 dtype.

        Args:
            attention: PyTorch tensor containing attention weights
            axis: Axis along which to compute the mean (default: 0)

        Returns:
            NumPy array of mean attention values
        """
        # Tensor.numpy() raises TypeError for bfloat16; upcasting before the
        # mean also avoids accumulating in the low-precision dtype.
        if attention.dtype == torch.bfloat16:
            attention = attention.float()
        return torch.mean(attention, dim=axis).detach().cpu().numpy()
Attention visualizer for PyTorch-based transformer models.
This class implements the abstract methods from BaseAttentionVisualiser specifically for models implemented in PyTorch. It handles the extraction and processing of attention weights from PyTorch transformer models.
Attributes: model: A PyTorch-based transformer model from Hugging Face tokenizer: A tokenizer matching the model config: Dictionary containing visualization configuration parameters
def __init__(
    self,
    model: AutoModel,
    tokenizer: AutoTokenizer,
    config: Optional[dict] = None,
) -> None:
    """Create a visualiser bound to a PyTorch Hugging Face model.

    No PyTorch-specific setup happens here: the model, tokenizer and config
    are forwarded unchanged, in that order, to the base class initializer.

    Args:
        model: PyTorch transformer model loaded through Hugging Face.
        tokenizer: Tokenizer that matches ``model``.
        config: Optional visualisation settings, passed through as-is
            (``None`` when omitted).
    """
    super().__init__(model, tokenizer, config)
Initialize the PyTorch-specific attention visualizer.
Args: model: A PyTorch-based transformer model from Hugging Face tokenizer: A tokenizer matching the model config: Optional dictionary with visualization parameters
def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
    """Compute attention weights for the given input using a PyTorch model.

    Runs the model under ``torch.no_grad()`` with ``output_attentions=True``
    and returns the ``attentions`` field of its output. The last input and
    its attentions are kept in ``self.current_input`` / ``self.cache``; if
    ``encoded_input`` has the same keys and equal values (tensors compared
    with ``torch.equal``), the cached attentions are returned without
    running the model again.

    Args:
        encoded_input: The encoded input from the tokenizer; its items are
            passed to the model as keyword arguments.

    Returns:
        A tuple containing attention weights from all layers of the model
    """

    def _same_input(new, old) -> bool:
        # ``new == old`` on two BatchEncodings compares tensors element-wise
        # and then calls bool() on the resulting tensor, which raises
        # RuntimeError for any multi-element or shape-mismatched tensor.
        # Compare keys and tensor contents explicitly instead.
        if new is old:
            return True
        if new is None or old is None:
            return False
        if set(new.keys()) != set(old.keys()):
            return False
        for key in new.keys():
            a, b = new[key], old[key]
            a_is_tensor = isinstance(a, torch.Tensor)
            b_is_tensor = isinstance(b, torch.Tensor)
            if a_is_tensor or b_is_tensor:
                if not (a_is_tensor and b_is_tensor):
                    return False
                # torch.equal already returns False on shape mismatch; the
                # dtype/device checks avoid torch.equal possibly raising on
                # mixed dtypes or devices.
                if a.dtype != b.dtype or a.device != b.device:
                    return False
                if not torch.equal(a, b):
                    return False
            elif a != b:
                return False
        return True

    if _same_input(encoded_input, self.current_input):
        # return from cache
        return self.cache

    # else recompute
    with torch.no_grad():
        output = self.model(**encoded_input, output_attentions=True)  # type: ignore

    attentions = output.attentions

    # update cache and current input
    self.current_input = encoded_input
    self.cache = attentions

    return attentions
Compute attention weights for the given input using a PyTorch model.
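Runs the PyTorch model under `torch.no_grad()` (gradient tracking disabled) with the `output_attentions` flag set to True and extracts the attention weights from the model output. If the input equals the previously processed input, the cached attentions are returned instead of running the model again.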
Args: encoded_input: The encoded input from the tokenizer
Returns: A tuple containing attention weights from all layers of the model
def get_attention_vector_mean(
    self, attention: torch.Tensor, axis: int = 0
) -> np.ndarray:
    """Calculate mean of PyTorch attention vectors along specified axis.

    Computes the mean of the attention tensor along ``axis``, detaches the
    result from the autograd graph, moves it to the CPU and converts it to
    a NumPy array. bfloat16 input is upcast to float32 first because NumPy
    has no bfloat16 dtype.

    Args:
        attention: PyTorch tensor containing attention weights
        axis: Axis along which to compute the mean (default: 0)

    Returns:
        NumPy array of mean attention values
    """
    # Tensor.numpy() raises TypeError for bfloat16; upcasting before the
    # mean also avoids accumulating in the low-precision dtype.
    if attention.dtype == torch.bfloat16:
        attention = attention.float()
    return torch.mean(attention, dim=axis).detach().cpu().numpy()
Calculate mean of PyTorch attention vectors along specified axis.
Computes the mean of the attention tensor along the given axis, then detaches the result and moves it to the CPU before converting it to a NumPy array.
Args: attention: PyTorch tensor containing attention weights axis: Axis along which to compute the mean (default: 0)
Returns: NumPy array of mean attention values