attention_visualiser.base
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel
from transformers import BatchEncoding
from loguru import logger
from einops import rearrange
from abc import ABC, abstractmethod
import numpy as np
from typing import Optional


class BaseAttentionVisualiser(ABC):
    """Base abstract class for visualizing attention weights in transformer models.

    This class provides the foundation for visualizing attention weights from
    different transformer model implementations. Concrete subclasses must implement
    methods for computing attention values and processing attention vectors.

    Attributes:
        model: A transformer model from the Hugging Face library
        tokenizer: A tokenizer matching the model
        config: Dictionary of visualization parameters; user-supplied keys
            override the built-in defaults set in ``__init__``
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        config: Optional[dict] = None,
    ) -> None:
        """Initialize the attention visualizer with a model and tokenizer.

        Args:
            model: A transformer model from Hugging Face (PyTorch or Flax)
            tokenizer: A tokenizer matching the model
            config: Optional dictionary with visualization parameters. Keys
                that are provided override the defaults below; missing keys
                fall back to them:
                - figsize: Tuple specifying figure dimensions, default (15, 15)
                - cmap: Colormap for the heatmap, default "viridis"
                - annot: Whether to annotate heatmap cells with values, default True
                - xlabel: Label for x-axis, default ""
                - ylabel: Label for y-axis, default ""
        """
        self.model = model
        self.tokenizer = tokenizer

        logger.info(f"Model config: {self.model.config}")  # type: ignore

        default_config = {
            "figsize": (15, 15),
            "cmap": "viridis",
            "annot": True,
            "xlabel": "",
            "ylabel": "",
        }
        if not config:
            self.config = default_config
            logger.info(f"Setting default visualiser config: {self.config}")
        else:
            # Layer the caller's values over the defaults so a partial config
            # (e.g. only {"cmap": "magma"}) keeps annotations, axis labels and
            # a figure size. Building a new dict also stops later mutation of
            # the caller's dict from leaking into this instance.
            self.config = {**default_config, **config}
            logger.info(f"Visualiser config: {self.config}")

        # Cache slots for already-computed attentions. This class only
        # initialises them; keeping them current is left to
        # `compute_attentions` implementations.
        self.current_input = None
        self.cache = None

    def id_to_tokens(self, encoded_input: BatchEncoding) -> list[str]:
        """Convert token IDs to readable token strings.

        Only the first sequence (row 0) of ``encoded_input["input_ids"]``
        is converted.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            List of token strings corresponding to the input IDs
        """
        tokens = self.tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])  # type: ignore
        return tokens

    @abstractmethod
    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
        """Compute attention weights for the given input.

        This method must be implemented by concrete subclasses to compute
        attention weights specific to the model implementation.
        ``visualise_attn_layer`` indexes the returned tuple by layer and
        expects each element to have a leading batch dimension of size 1.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            A tuple containing attention weights
        """
        pass

    @abstractmethod
    def get_attention_vector_mean(
        self, attention: torch.Tensor, axis: int = 0
    ) -> np.ndarray:
        """Calculate mean of attention vectors along specified axis.

        This method must be implemented by concrete subclasses to handle
        either PyTorch or JAX tensors appropriately.

        Args:
            attention: Attention tensor from the model
            axis: Axis along which to compute the mean (default: 0)

        Returns:
            NumPy array of mean attention values
        """
        pass

    def visualise_attn_layer(self, idx: int, encoded_input: BatchEncoding) -> None:
        """Visualize attention weights for a specific layer.

        Creates a heatmap of the attention weights for the layer at position
        ``idx`` of the tuple returned by ``compute_attentions``, reduced with
        ``get_attention_vector_mean``. Only a single-sequence batch is
        supported.

        Args:
            idx: Index of the attention layer to visualize.
                Negative indices count from the end (-1 is the last layer).
            encoded_input: The encoded input from the tokenizer

        Raises:
            AssertionError: If idx is outside the range of available attention
                layers, i.e. not in ``[-n_layers, n_layers)``
        """
        tokens = self.id_to_tokens(encoded_input)

        attentions = self.compute_attentions(encoded_input)
        n_attns = len(attentions)

        # Check both bounds. Without the lower bound, an idx below -n_attns
        # would still be negative after the normalisation below and would
        # index from the end again, silently plotting the wrong layer. The
        # check uses an explicit raise rather than `assert`, so it survives
        # `python -O`, and it keeps the documented AssertionError type.
        if not -n_attns <= idx < n_attns:
            raise AssertionError(
                f"index must be in the range [-{n_attns}, {n_attns}), where "
                f"{n_attns} is the number of attention outputs in the model; "
                f"got: {idx}"
            )

        # Normalise a negative idx so the plot title shows the actual layer
        # position (e.g. -1 -> n_attns - 1).
        if idx < 0:
            idx += n_attns

        # Drop the batch dimension; the "1 ..." pattern requires a batch of
        # exactly one sequence.
        attention = rearrange(attentions[idx], "1 a b c -> a b c")
        # Reduce over axis 0 (the default). Presumably the layout is
        # (heads, seq, seq), so this averages over heads — TODO confirm.
        attention = self.get_attention_vector_mean(attention)

        plt.figure(figsize=self.config.get("figsize"))
        sns.heatmap(
            attention,
            cmap=self.config.get("cmap"),
            annot=self.config.get("annot"),
            xticklabels=tokens,
            yticklabels=tokens,
        )

        plt.title(f"Attention Weights for Layer idx: {idx}")
        plt.xlabel(self.config.get("xlabel"))  # type: ignore
        plt.ylabel(self.config.get("ylabel"))  # type: ignore
        plt.show()
class BaseAttentionVisualiser(ABC):
    """Base abstract class for visualizing attention weights in transformer models.

    This class provides the foundation for visualizing attention weights from
    different transformer model implementations. Concrete subclasses must implement
    methods for computing attention values and processing attention vectors.

    Attributes:
        model: A transformer model from the Hugging Face library
        tokenizer: A tokenizer matching the model
        config: Dictionary of visualization parameters; user-supplied keys
            override the built-in defaults set in ``__init__``
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        config: Optional[dict] = None,
    ) -> None:
        """Initialize the attention visualizer with a model and tokenizer.

        Args:
            model: A transformer model from Hugging Face (PyTorch or Flax)
            tokenizer: A tokenizer matching the model
            config: Optional dictionary with visualization parameters. Keys
                that are provided override the defaults below; missing keys
                fall back to them:
                - figsize: Tuple specifying figure dimensions, default (15, 15)
                - cmap: Colormap for the heatmap, default "viridis"
                - annot: Whether to annotate heatmap cells with values, default True
                - xlabel: Label for x-axis, default ""
                - ylabel: Label for y-axis, default ""
        """
        self.model = model
        self.tokenizer = tokenizer

        logger.info(f"Model config: {self.model.config}")  # type: ignore

        default_config = {
            "figsize": (15, 15),
            "cmap": "viridis",
            "annot": True,
            "xlabel": "",
            "ylabel": "",
        }
        if not config:
            self.config = default_config
            logger.info(f"Setting default visualiser config: {self.config}")
        else:
            # Layer the caller's values over the defaults so a partial config
            # (e.g. only {"cmap": "magma"}) keeps annotations, axis labels and
            # a figure size. Building a new dict also stops later mutation of
            # the caller's dict from leaking into this instance.
            self.config = {**default_config, **config}
            logger.info(f"Visualiser config: {self.config}")

        # Cache slots for already-computed attentions. This class only
        # initialises them; keeping them current is left to
        # `compute_attentions` implementations.
        self.current_input = None
        self.cache = None

    def id_to_tokens(self, encoded_input: BatchEncoding) -> list[str]:
        """Convert token IDs to readable token strings.

        Only the first sequence (row 0) of ``encoded_input["input_ids"]``
        is converted.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            List of token strings corresponding to the input IDs
        """
        tokens = self.tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])  # type: ignore
        return tokens

    @abstractmethod
    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
        """Compute attention weights for the given input.

        This method must be implemented by concrete subclasses to compute
        attention weights specific to the model implementation.
        ``visualise_attn_layer`` indexes the returned tuple by layer and
        expects each element to have a leading batch dimension of size 1.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            A tuple containing attention weights
        """
        pass

    @abstractmethod
    def get_attention_vector_mean(
        self, attention: torch.Tensor, axis: int = 0
    ) -> np.ndarray:
        """Calculate mean of attention vectors along specified axis.

        This method must be implemented by concrete subclasses to handle
        either PyTorch or JAX tensors appropriately.

        Args:
            attention: Attention tensor from the model
            axis: Axis along which to compute the mean (default: 0)

        Returns:
            NumPy array of mean attention values
        """
        pass

    def visualise_attn_layer(self, idx: int, encoded_input: BatchEncoding) -> None:
        """Visualize attention weights for a specific layer.

        Creates a heatmap of the attention weights for the layer at position
        ``idx`` of the tuple returned by ``compute_attentions``, reduced with
        ``get_attention_vector_mean``. Only a single-sequence batch is
        supported.

        Args:
            idx: Index of the attention layer to visualize.
                Negative indices count from the end (-1 is the last layer).
            encoded_input: The encoded input from the tokenizer

        Raises:
            AssertionError: If idx is outside the range of available attention
                layers, i.e. not in ``[-n_layers, n_layers)``
        """
        tokens = self.id_to_tokens(encoded_input)

        attentions = self.compute_attentions(encoded_input)
        n_attns = len(attentions)

        # Check both bounds. Without the lower bound, an idx below -n_attns
        # would still be negative after the normalisation below and would
        # index from the end again, silently plotting the wrong layer. The
        # check uses an explicit raise rather than `assert`, so it survives
        # `python -O`, and it keeps the documented AssertionError type.
        if not -n_attns <= idx < n_attns:
            raise AssertionError(
                f"index must be in the range [-{n_attns}, {n_attns}), where "
                f"{n_attns} is the number of attention outputs in the model; "
                f"got: {idx}"
            )

        # Normalise a negative idx so the plot title shows the actual layer
        # position (e.g. -1 -> n_attns - 1).
        if idx < 0:
            idx += n_attns

        # Drop the batch dimension; the "1 ..." pattern requires a batch of
        # exactly one sequence.
        attention = rearrange(attentions[idx], "1 a b c -> a b c")
        # Reduce over axis 0 (the default). Presumably the layout is
        # (heads, seq, seq), so this averages over heads — TODO confirm.
        attention = self.get_attention_vector_mean(attention)

        plt.figure(figsize=self.config.get("figsize"))
        sns.heatmap(
            attention,
            cmap=self.config.get("cmap"),
            annot=self.config.get("annot"),
            xticklabels=tokens,
            yticklabels=tokens,
        )

        plt.title(f"Attention Weights for Layer idx: {idx}")
        plt.xlabel(self.config.get("xlabel"))  # type: ignore
        plt.ylabel(self.config.get("ylabel"))  # type: ignore
        plt.show()
Abstract base class for visualizing attention weights in transformer models.
This class provides the foundation for visualizing attention weights from different transformer model implementations. Concrete subclasses must implement methods for computing attention values and processing attention vectors.
Attributes: `model` — a transformer model from the Hugging Face library; `tokenizer` — a tokenizer matching the model; `config` — a dictionary containing visualization configuration parameters.
def __init__(
    self,
    model: AutoModel,
    tokenizer: AutoTokenizer,
    config: Optional[dict] = None,
) -> None:
    """Initialize the attention visualizer with a model and tokenizer.

    Args:
        model: A transformer model from Hugging Face (PyTorch or Flax)
        tokenizer: A tokenizer matching the model
        config: Optional dictionary with visualization parameters. Keys
            that are provided override the defaults below; missing keys
            fall back to them:
            - figsize: Tuple specifying figure dimensions, default (15, 15)
            - cmap: Colormap for the heatmap, default "viridis"
            - annot: Whether to annotate heatmap cells with values, default True
            - xlabel: Label for x-axis, default ""
            - ylabel: Label for y-axis, default ""
    """
    self.model = model
    self.tokenizer = tokenizer

    logger.info(f"Model config: {self.model.config}")  # type: ignore

    default_config = {
        "figsize": (15, 15),
        "cmap": "viridis",
        "annot": True,
        "xlabel": "",
        "ylabel": "",
    }
    if not config:
        self.config = default_config
        logger.info(f"Setting default visualiser config: {self.config}")
    else:
        # Layer the caller's values over the defaults so a partial config
        # (e.g. only {"cmap": "magma"}) keeps annotations, axis labels and a
        # figure size. Building a new dict also stops later mutation of the
        # caller's dict from leaking into this instance.
        self.config = {**default_config, **config}
        logger.info(f"Visualiser config: {self.config}")

    # Cache slots for already-computed attentions. This method only
    # initialises them; keeping them current is left to
    # `compute_attentions` implementations.
    self.current_input = None
    self.cache = None
Initialize the attention visualizer with a model and tokenizer.
Args: `model` — a transformer model from Hugging Face (PyTorch or Flax); `tokenizer` — a tokenizer matching the model; `config` — an optional dictionary with visualization parameters. The default parameters are: `figsize` (a tuple specifying the figure dimensions), `cmap` (the colormap for the heatmap), `annot` (whether to annotate heatmap cells with their values), `xlabel` (the x-axis label) and `ylabel` (the y-axis label).
def id_to_tokens(self, encoded_input: BatchEncoding) -> list[str]:
    """Return the token strings for the first sequence in ``encoded_input``.

    Only row 0 of ``encoded_input["input_ids"]`` is converted, using the
    tokenizer held on this instance.

    Args:
        encoded_input: The encoded input from the tokenizer

    Returns:
        List of token strings, one per ID in the first sequence
    """
    first_row_ids = encoded_input["input_ids"][0]
    return self.tokenizer.convert_ids_to_tokens(first_row_ids)  # type: ignore
Convert token IDs to readable token strings.
Args: encoded_input: The encoded input from the tokenizer
Returns: List of token strings corresponding to the input IDs
@abstractmethod
def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
    """Return the model's attention weights for ``encoded_input``.

    Abstract hook: each concrete subclass supplies the model-specific
    computation. ``visualise_attn_layer`` takes ``len()`` of the returned
    tuple and indexes it by layer.

    Args:
        encoded_input: The encoded input from the tokenizer

    Returns:
        A tuple containing attention weights
    """
    ...
Compute attention weights for the given input.
This method must be implemented by concrete subclasses to compute attention weights specific to the model implementation.
Args: encoded_input: The encoded input from the tokenizer
Returns: A tuple containing attention weights
@abstractmethod
def get_attention_vector_mean(self, attention: torch.Tensor, axis: int = 0) -> np.ndarray:
    """Average ``attention`` along ``axis``, returning a NumPy array.

    Abstract hook: subclasses perform the reduction with whichever
    framework (PyTorch or JAX) produced the tensor.

    Args:
        attention: Attention tensor from the model
        axis: Axis along which to compute the mean (default: 0)

    Returns:
        NumPy array of mean attention values
    """
    ...
Calculate mean of attention vectors along specified axis.
This method must be implemented by concrete subclasses to handle either PyTorch or JAX tensors appropriately.
Args: `attention` — the attention tensor from the model; `axis` — the axis along which to compute the mean (default: 0).
Returns: NumPy array of mean attention values
def visualise_attn_layer(self, idx: int, encoded_input: BatchEncoding) -> None:
    """Visualize attention weights for a specific layer.

    Creates a heatmap of the attention weights for the layer at position
    ``idx`` of the tuple returned by ``compute_attentions``, reduced with
    ``get_attention_vector_mean``. Only a single-sequence batch is supported.

    Args:
        idx: Index of the attention layer to visualize.
            Negative indices count from the end (-1 is the last layer).
        encoded_input: The encoded input from the tokenizer

    Raises:
        AssertionError: If idx is outside the range of available attention
            layers, i.e. not in ``[-n_layers, n_layers)``
    """
    tokens = self.id_to_tokens(encoded_input)

    attentions = self.compute_attentions(encoded_input)
    n_attns = len(attentions)

    # Check both bounds. Without the lower bound, an idx below -n_attns would
    # still be negative after the normalisation below and would index from
    # the end again, silently plotting the wrong layer. The check uses an
    # explicit raise rather than `assert`, so it survives `python -O`, and it
    # keeps the documented AssertionError type.
    if not -n_attns <= idx < n_attns:
        raise AssertionError(
            f"index must be in the range [-{n_attns}, {n_attns}), where "
            f"{n_attns} is the number of attention outputs in the model; "
            f"got: {idx}"
        )

    # Normalise a negative idx so the plot title shows the actual layer
    # position (e.g. -1 -> n_attns - 1).
    if idx < 0:
        idx += n_attns

    # Drop the batch dimension; the "1 ..." pattern requires a batch of
    # exactly one sequence.
    attention = rearrange(attentions[idx], "1 a b c -> a b c")
    # Reduce over axis 0 (the default). Presumably the layout is
    # (heads, seq, seq), so this averages over heads — TODO confirm.
    attention = self.get_attention_vector_mean(attention)

    plt.figure(figsize=self.config.get("figsize"))
    sns.heatmap(
        attention,
        cmap=self.config.get("cmap"),
        annot=self.config.get("annot"),
        xticklabels=tokens,
        yticklabels=tokens,
    )

    plt.title(f"Attention Weights for Layer idx: {idx}")
    plt.xlabel(self.config.get("xlabel"))  # type: ignore
    plt.ylabel(self.config.get("ylabel"))  # type: ignore
    plt.show()
Visualize attention weights for a specific layer.
Creates a heatmap visualization of the attention weights for the specified layer index.
Args: `idx` — the index of the attention layer to visualize; negative indices count from the end (-1 is the last layer). `encoded_input` — the encoded input from the tokenizer.
Raises: AssertionError: If idx is outside the range of available attention layers