attention_visualiser.base

  1import torch
  2import matplotlib.pyplot as plt
  3import seaborn as sns
  4from transformers import AutoTokenizer, AutoModel
  5from transformers import BatchEncoding
  6from loguru import logger
  7from einops import rearrange
  8from abc import ABC, abstractmethod
  9import numpy as np
 10from typing import Optional
 11
 12
 13class BaseAttentionVisualiser(ABC):
 14    """Base abstract class for visualizing attention weights in transformer models.
 15
 16    This class provides the foundation for visualizing attention weights from
 17    different transformer model implementations. Concrete subclasses must implement
 18    methods for computing attention values and processing attention vectors.
 19
 20    Attributes:
 21        model: A transformer model from the Hugging Face library
 22        tokenizer: A tokenizer matching the model
 23        config: Dictionary containing visualization configuration parameters
 24    """
 25
 26    def __init__(
 27        self,
 28        model: AutoModel,
 29        tokenizer: AutoTokenizer,
 30        config: Optional[dict] = None,
 31    ) -> None:
 32        """Initialize the attention visualizer with a model and tokenizer.
 33
 34        Args:
 35            model: A transformer model from Hugging Face (PyTorch or Flax)
 36            tokenizer: A tokenizer matching the model
 37            config: Optional dictionary with visualization parameters
 38                   Default parameters include:
 39                   - figsize: Tuple specifying figure dimensions
 40                   - cmap: Colormap for the heatmap
 41                   - annot: Whether to annotate heatmap cells with values
 42                   - xlabel: Label for x-axis
 43                   - ylabel: Label for y-axis
 44        """
 45        self.model = model
 46        self.tokenizer = tokenizer
 47
 48        logger.info(f"Model config: {self.model.config}")  # type: ignore
 49
 50        if not config:
 51            self.config = {
 52                "figsize": (15, 15),
 53                "cmap": "viridis",
 54                "annot": True,
 55                "xlabel": "",
 56                "ylabel": "",
 57            }
 58            logger.info(f"Setting default visualiser config: {self.config}")
 59        else:
 60            logger.info(f"Visualiser config: {config}")
 61            self.config = config
 62
 63        # a cache for storing already computed attention vectors
 64        # these need to be updated by the `compute_attentions`
 65        # method
 66        self.current_input = None
 67        self.cache = None
 68
 69    def id_to_tokens(self, encoded_input: BatchEncoding) -> list[str]:
 70        """Convert token IDs to readable token strings.
 71
 72        Args:
 73            encoded_input: The encoded input from the tokenizer
 74
 75        Returns:
 76            List of token strings corresponding to the input IDs
 77        """
 78        tokens = self.tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])  # type: ignore
 79        return tokens
 80
 81    @abstractmethod
 82    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
 83        """Compute attention weights for the given input.
 84
 85        This method must be implemented by concrete subclasses to compute
 86        attention weights specific to the model implementation.
 87
 88        Args:
 89            encoded_input: The encoded input from the tokenizer
 90
 91        Returns:
 92            A tuple containing attention weights
 93        """
 94        pass
 95
 96    @abstractmethod
 97    def get_attention_vector_mean(
 98        self, attention: torch.Tensor, axis: int = 0
 99    ) -> np.ndarray:
100        """Calculate mean of attention vectors along specified axis.
101
102        This method must be implemented by concrete subclasses to handle
103        either PyTorch or JAX tensors appropriately.
104
105        Args:
106            attention: Attention tensor from the model
107            axis: Axis along which to compute the mean (default: 0)
108
109        Returns:
110            NumPy array of mean attention values
111        """
112        pass
113
114    def visualise_attn_layer(self, idx: int, encoded_input: BatchEncoding) -> None:
115        """Visualize attention weights for a specific layer.
116
117        Creates a heatmap visualization of the attention weights for the specified
118        layer index.
119
120        Args:
121            idx: Index of the attention layer to visualize.
122                 Negative indices count from the end (-1 is the last layer).
123            encoded_input: The encoded input from the tokenizer
124
125        Raises:
126            AssertionError: If idx is outside the range of available attention layers
127        """
128        tokens = self.id_to_tokens(encoded_input)
129
130        attentions = self.compute_attentions(encoded_input)
131        n_attns = len(attentions)
132
133        # idx must no exceed attn_heads
134        assert idx < n_attns, (
135            f"index must be less than the number of attention outputs in the model, which is: {n_attns}"
136        )
137
138        # setting idx = -1 will get the last attention layer activations but
139        # the plot title will also show -1
140        if idx < 0:
141            idx = n_attns + idx
142
143        # get rid of the additional dimension since single input
144        attention = rearrange(attentions[idx], "1 a b c -> a b c")
145        # take mean over dim 0
146        attention = self.get_attention_vector_mean(attention)
147
148        plt.figure(figsize=self.config.get("figsize"))
149        sns.heatmap(
150            attention,
151            cmap=self.config.get("cmap"),
152            annot=self.config.get("annot"),
153            xticklabels=tokens,
154            yticklabels=tokens,
155        )
156
157        plt.title(f"Attention Weights for Layer idx: {idx}")
158        plt.xlabel(self.config.get("xlabel"))  # type: ignore
159        plt.ylabel(self.config.get("ylabel"))  # type: ignore
160        plt.show()
class BaseAttentionVisualiser(abc.ABC):
 14class BaseAttentionVisualiser(ABC):
 15    """Base abstract class for visualizing attention weights in transformer models.
 16
 17    This class provides the foundation for visualizing attention weights from
 18    different transformer model implementations. Concrete subclasses must implement
 19    methods for computing attention values and processing attention vectors.
 20
 21    Attributes:
 22        model: A transformer model from the Hugging Face library
 23        tokenizer: A tokenizer matching the model
 24        config: Dictionary containing visualization configuration parameters
 25    """
 26
 27    def __init__(
 28        self,
 29        model: AutoModel,
 30        tokenizer: AutoTokenizer,
 31        config: Optional[dict] = None,
 32    ) -> None:
 33        """Initialize the attention visualizer with a model and tokenizer.
 34
 35        Args:
 36            model: A transformer model from Hugging Face (PyTorch or Flax)
 37            tokenizer: A tokenizer matching the model
 38            config: Optional dictionary with visualization parameters
 39                   Default parameters include:
 40                   - figsize: Tuple specifying figure dimensions
 41                   - cmap: Colormap for the heatmap
 42                   - annot: Whether to annotate heatmap cells with values
 43                   - xlabel: Label for x-axis
 44                   - ylabel: Label for y-axis
 45        """
 46        self.model = model
 47        self.tokenizer = tokenizer
 48
 49        logger.info(f"Model config: {self.model.config}")  # type: ignore
 50
 51        if not config:
 52            self.config = {
 53                "figsize": (15, 15),
 54                "cmap": "viridis",
 55                "annot": True,
 56                "xlabel": "",
 57                "ylabel": "",
 58            }
 59            logger.info(f"Setting default visualiser config: {self.config}")
 60        else:
 61            logger.info(f"Visualiser config: {config}")
 62            self.config = config
 63
 64        # a cache for storing already computed attention vectors
 65        # these need to be updated by the `compute_attentions`
 66        # method
 67        self.current_input = None
 68        self.cache = None
 69
 70    def id_to_tokens(self, encoded_input: BatchEncoding) -> list[str]:
 71        """Convert token IDs to readable token strings.
 72
 73        Args:
 74            encoded_input: The encoded input from the tokenizer
 75
 76        Returns:
 77            List of token strings corresponding to the input IDs
 78        """
 79        tokens = self.tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])  # type: ignore
 80        return tokens
 81
 82    @abstractmethod
 83    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
 84        """Compute attention weights for the given input.
 85
 86        This method must be implemented by concrete subclasses to compute
 87        attention weights specific to the model implementation.
 88
 89        Args:
 90            encoded_input: The encoded input from the tokenizer
 91
 92        Returns:
 93            A tuple containing attention weights
 94        """
 95        pass
 96
 97    @abstractmethod
 98    def get_attention_vector_mean(
 99        self, attention: torch.Tensor, axis: int = 0
100    ) -> np.ndarray:
101        """Calculate mean of attention vectors along specified axis.
102
103        This method must be implemented by concrete subclasses to handle
104        either PyTorch or JAX tensors appropriately.
105
106        Args:
107            attention: Attention tensor from the model
108            axis: Axis along which to compute the mean (default: 0)
109
110        Returns:
111            NumPy array of mean attention values
112        """
113        pass
114
115    def visualise_attn_layer(self, idx: int, encoded_input: BatchEncoding) -> None:
116        """Visualize attention weights for a specific layer.
117
118        Creates a heatmap visualization of the attention weights for the specified
119        layer index.
120
121        Args:
122            idx: Index of the attention layer to visualize.
123                 Negative indices count from the end (-1 is the last layer).
124            encoded_input: The encoded input from the tokenizer
125
126        Raises:
127            AssertionError: If idx is outside the range of available attention layers
128        """
129        tokens = self.id_to_tokens(encoded_input)
130
131        attentions = self.compute_attentions(encoded_input)
132        n_attns = len(attentions)
133
134        # idx must no exceed attn_heads
135        assert idx < n_attns, (
136            f"index must be less than the number of attention outputs in the model, which is: {n_attns}"
137        )
138
139        # setting idx = -1 will get the last attention layer activations but
140        # the plot title will also show -1
141        if idx < 0:
142            idx = n_attns + idx
143
144        # get rid of the additional dimension since single input
145        attention = rearrange(attentions[idx], "1 a b c -> a b c")
146        # take mean over dim 0
147        attention = self.get_attention_vector_mean(attention)
148
149        plt.figure(figsize=self.config.get("figsize"))
150        sns.heatmap(
151            attention,
152            cmap=self.config.get("cmap"),
153            annot=self.config.get("annot"),
154            xticklabels=tokens,
155            yticklabels=tokens,
156        )
157
158        plt.title(f"Attention Weights for Layer idx: {idx}")
159        plt.xlabel(self.config.get("xlabel"))  # type: ignore
160        plt.ylabel(self.config.get("ylabel"))  # type: ignore
161        plt.show()

Base abstract class for visualizing attention weights in transformer models.

This class provides the foundation for visualizing attention weights from different transformer model implementations. Concrete subclasses must implement methods for computing attention values and processing attention vectors.

Attributes:
- model: A transformer model from the Hugging Face library.
- tokenizer: A tokenizer matching the model.
- config: Dictionary containing the visualization configuration parameters.

BaseAttentionVisualiser( model: transformers.models.auto.modeling_auto.AutoModel, tokenizer: transformers.models.auto.tokenization_auto.AutoTokenizer, config: Optional[dict] = None)
27    def __init__(
28        self,
29        model: AutoModel,
30        tokenizer: AutoTokenizer,
31        config: Optional[dict] = None,
32    ) -> None:
33        """Initialize the attention visualizer with a model and tokenizer.
34
35        Args:
36            model: A transformer model from Hugging Face (PyTorch or Flax)
37            tokenizer: A tokenizer matching the model
38            config: Optional dictionary with visualization parameters
39                   Default parameters include:
40                   - figsize: Tuple specifying figure dimensions
41                   - cmap: Colormap for the heatmap
42                   - annot: Whether to annotate heatmap cells with values
43                   - xlabel: Label for x-axis
44                   - ylabel: Label for y-axis
45        """
46        self.model = model
47        self.tokenizer = tokenizer
48
49        logger.info(f"Model config: {self.model.config}")  # type: ignore
50
51        if not config:
52            self.config = {
53                "figsize": (15, 15),
54                "cmap": "viridis",
55                "annot": True,
56                "xlabel": "",
57                "ylabel": "",
58            }
59            logger.info(f"Setting default visualiser config: {self.config}")
60        else:
61            logger.info(f"Visualiser config: {config}")
62            self.config = config
63
64        # a cache for storing already computed attention vectors
65        # these need to be updated by the `compute_attentions`
66        # method
67        self.current_input = None
68        self.cache = None

Initialize the attention visualizer with a model and tokenizer.

Args:
- model: A transformer model from Hugging Face (PyTorch or Flax).
- tokenizer: A tokenizer matching the model.
- config: Optional dictionary with visualization parameters. The default parameters include:
  - figsize: Tuple specifying the figure dimensions.
  - cmap: Colormap for the heatmap.
  - annot: Whether to annotate heatmap cells with values.
  - xlabel: Label for the x-axis.
  - ylabel: Label for the y-axis.

model
tokenizer
current_input
cache
def id_to_tokens( self, encoded_input: transformers.tokenization_utils_base.BatchEncoding) -> list[str]:
70    def id_to_tokens(self, encoded_input: BatchEncoding) -> list[str]:
71        """Convert token IDs to readable token strings.
72
73        Args:
74            encoded_input: The encoded input from the tokenizer
75
76        Returns:
77            List of token strings corresponding to the input IDs
78        """
79        tokens = self.tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])  # type: ignore
80        return tokens

Convert token IDs to readable token strings.

Args: encoded_input: The encoded input from the tokenizer

Returns: List of token strings corresponding to the input IDs

@abstractmethod
def compute_attentions( self, encoded_input: transformers.tokenization_utils_base.BatchEncoding) -> tuple:
82    @abstractmethod
83    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
84        """Compute attention weights for the given input.
85
86        This method must be implemented by concrete subclasses to compute
87        attention weights specific to the model implementation.
88
89        Args:
90            encoded_input: The encoded input from the tokenizer
91
92        Returns:
93            A tuple containing attention weights
94        """
95        pass

Compute attention weights for the given input.

This method must be implemented by concrete subclasses to compute attention weights specific to the model implementation.

Args: encoded_input: The encoded input from the tokenizer

Returns: A tuple containing attention weights

@abstractmethod
def get_attention_vector_mean(self, attention: torch.Tensor, axis: int = 0) -> numpy.ndarray:
 97    @abstractmethod
 98    def get_attention_vector_mean(
 99        self, attention: torch.Tensor, axis: int = 0
100    ) -> np.ndarray:
101        """Calculate mean of attention vectors along specified axis.
102
103        This method must be implemented by concrete subclasses to handle
104        either PyTorch or JAX tensors appropriately.
105
106        Args:
107            attention: Attention tensor from the model
108            axis: Axis along which to compute the mean (default: 0)
109
110        Returns:
111            NumPy array of mean attention values
112        """
113        pass

Calculate mean of attention vectors along specified axis.

This method must be implemented by concrete subclasses to handle either PyTorch or JAX tensors appropriately.

Args:
- attention: Attention tensor from the model.
- axis: Axis along which to compute the mean (default: 0).

Returns: NumPy array of mean attention values

def visualise_attn_layer( self, idx: int, encoded_input: transformers.tokenization_utils_base.BatchEncoding) -> None:
115    def visualise_attn_layer(self, idx: int, encoded_input: BatchEncoding) -> None:
116        """Visualize attention weights for a specific layer.
117
118        Creates a heatmap visualization of the attention weights for the specified
119        layer index.
120
121        Args:
122            idx: Index of the attention layer to visualize.
123                 Negative indices count from the end (-1 is the last layer).
124            encoded_input: The encoded input from the tokenizer
125
126        Raises:
127            AssertionError: If idx is outside the range of available attention layers
128        """
129        tokens = self.id_to_tokens(encoded_input)
130
131        attentions = self.compute_attentions(encoded_input)
132        n_attns = len(attentions)
133
134        # idx must no exceed attn_heads
135        assert idx < n_attns, (
136            f"index must be less than the number of attention outputs in the model, which is: {n_attns}"
137        )
138
139        # setting idx = -1 will get the last attention layer activations but
140        # the plot title will also show -1
141        if idx < 0:
142            idx = n_attns + idx
143
144        # get rid of the additional dimension since single input
145        attention = rearrange(attentions[idx], "1 a b c -> a b c")
146        # take mean over dim 0
147        attention = self.get_attention_vector_mean(attention)
148
149        plt.figure(figsize=self.config.get("figsize"))
150        sns.heatmap(
151            attention,
152            cmap=self.config.get("cmap"),
153            annot=self.config.get("annot"),
154            xticklabels=tokens,
155            yticklabels=tokens,
156        )
157
158        plt.title(f"Attention Weights for Layer idx: {idx}")
159        plt.xlabel(self.config.get("xlabel"))  # type: ignore
160        plt.ylabel(self.config.get("ylabel"))  # type: ignore
161        plt.show()

Visualize attention weights for a specific layer.

Creates a heatmap visualization of the attention weights for the specified layer index.

Args:
- idx: Index of the attention layer to visualize. Negative indices count from the end (-1 is the last layer).
- encoded_input: The encoded input from the tokenizer.

Raises: AssertionError: If idx is outside the range of available attention layers