attention_visualiser.visualiser

 1import torch
 2from transformers import AutoTokenizer, AutoModel
 3from transformers import BatchEncoding
 4from attention_visualiser.base import BaseAttentionVisualiser
 5import numpy as np
 6from typing import Optional
 7
 8
 9class AttentionVisualiser(BaseAttentionVisualiser):
10    """Attention visualizer for PyTorch-based transformer models.
11
12    This class implements the abstract methods from BaseAttentionVisualiser
13    specifically for models implemented in PyTorch. It handles the extraction
14    and processing of attention weights from PyTorch transformer models.
15
16    Attributes:
17        model: A PyTorch-based transformer model from Hugging Face
18        tokenizer: A tokenizer matching the model
19        config: Dictionary containing visualization configuration parameters
20    """
21
22    def __init__(
23        self, model: AutoModel, tokenizer: AutoTokenizer, config: Optional[dict] = None
24    ) -> None:
25        """Initialize the PyTorch-specific attention visualizer.
26
27        Args:
28            model: A PyTorch-based transformer model from Hugging Face
29            tokenizer: A tokenizer matching the model
30            config: Optional dictionary with visualization parameters
31        """
32        super().__init__(model, tokenizer, config)
33
34    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
35        """Compute attention weights for the given input using a PyTorch model.
36
37        Runs the PyTorch model in inference mode with output_attentions flag set to True
38        and extracts the attention weights from the model output.
39
40        Args:
41            encoded_input: The encoded input from the tokenizer
42
43        Returns:
44            A tuple containing attention weights from all layers of the model
45        """
46        if encoded_input == self.current_input:
47            # return from cache
48            return self.cache
49
50        # else recompute
51        with torch.no_grad():
52            output = self.model(**encoded_input, output_attentions=True)  # type: ignore
53
54        attentions = output.attentions
55
56        # update cache and current input
57        self.current_input = encoded_input
58        self.cache = attentions
59
60        return attentions
61
62    def get_attention_vector_mean(
63        self, attention: torch.Tensor, axis: int = 0
64    ) -> np.ndarray:
65        """Calculate mean of PyTorch attention vectors along specified axis.
66
67        Computes the mean of the attention tensor and converts it to a NumPy array.
68
69        Args:
70            attention: PyTorch tensor containing attention weights
71            axis: Axis along which to compute the mean (default: 0)
72
73        Returns:
74            NumPy array of mean attention values
75        """
76        return torch.mean(attention, dim=axis).detach().cpu().numpy()
class AttentionVisualiser(attention_visualiser.base.BaseAttentionVisualiser):
class AttentionVisualiser(BaseAttentionVisualiser):
    """Attention visualizer for PyTorch-based transformer models.

    Implements the model-specific methods of BaseAttentionVisualiser for
    models implemented in PyTorch: running the model to obtain attention
    weights and reducing an attention tensor to a NumPy array.

    Attributes:
        model: A PyTorch-based transformer model from Hugging Face
        tokenizer: A tokenizer matching the model
        config: Dictionary containing visualization configuration parameters
    """

    def __init__(
        self, model: AutoModel, tokenizer: AutoTokenizer, config: Optional[dict] = None
    ) -> None:
        """Initialize the PyTorch-specific attention visualizer.

        All arguments are forwarded unchanged to the base-class constructor.

        Args:
            model: A PyTorch-based transformer model from Hugging Face
            tokenizer: A tokenizer matching the model
            config: Optional dictionary with visualization parameters
        """
        super().__init__(model, tokenizer, config)

    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
        """Compute attention weights for the given input using a PyTorch model.

        Calls the model under ``torch.no_grad()`` with ``output_attentions=True``
        and returns the ``attentions`` attribute of the model output. The last
        input and its result are kept in ``self.current_input`` and
        ``self.cache`` so that a repeated call with the same input skips the
        forward pass.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            A tuple containing attention weights from all layers of the model
        """

        def _values_equal(new_value, old_value) -> bool:
            # Identity covers the common case of the very same objects
            # being passed again.
            if new_value is old_value:
                return True
            new_is_tensor = isinstance(new_value, torch.Tensor)
            old_is_tensor = isinstance(old_value, torch.Tensor)
            if new_is_tensor and old_is_tensor:
                # Check shape/dtype/device first so torch.equal only ever
                # sees directly comparable tensors.
                return (
                    new_value.shape == old_value.shape
                    and new_value.dtype == old_value.dtype
                    and new_value.device == old_value.device
                    and torch.equal(new_value, old_value)
                )
            if new_is_tensor or old_is_tensor:
                return False
            return new_value == old_value

        # A plain ``encoded_input == self.current_input`` compares tensor
        # values with ``==`` and then takes their truth value, which raises
        # for multi-element tensors or mismatched shapes. Compare key by key
        # instead.
        cached_input = self.current_input
        is_cached = cached_input is encoded_input
        if (
            not is_cached
            and hasattr(cached_input, "keys")
            and hasattr(encoded_input, "keys")
        ):
            is_cached = cached_input.keys() == encoded_input.keys() and all(
                _values_equal(encoded_input[key], cached_input[key])
                for key in encoded_input.keys()
            )

        if is_cached:
            # return from cache
            return self.cache

        # else recompute
        with torch.no_grad():
            output = self.model(**encoded_input, output_attentions=True)  # type: ignore

        attentions = output.attentions

        # update cache and current input
        self.current_input = encoded_input
        self.cache = attentions

        return attentions

    def get_attention_vector_mean(
        self, attention: torch.Tensor, axis: int = 0
    ) -> np.ndarray:
        """Calculate mean of PyTorch attention vectors along specified axis.

        Computes the mean of the attention tensor, detaches it, moves it to
        the CPU and converts it to a NumPy array.

        Args:
            attention: PyTorch tensor containing attention weights
            axis: Axis along which to compute the mean (default: 0)

        Returns:
            NumPy array of mean attention values
        """
        mean = torch.mean(attention, dim=axis).detach().cpu()
        if mean.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so Tensor.numpy() would raise;
            # upcast to float32 first. Other dtypes are passed through as-is.
            mean = mean.to(torch.float32)
        return mean.numpy()

Attention visualizer for PyTorch-based transformer models.

This class implements the abstract methods from BaseAttentionVisualiser specifically for models implemented in PyTorch. It handles the extraction and processing of attention weights from PyTorch transformer models.

Attributes:
- model: A PyTorch-based transformer model from Hugging Face
- tokenizer: A tokenizer matching the model
- config: Dictionary containing visualization configuration parameters

AttentionVisualiser( model: transformers.models.auto.modeling_auto.AutoModel, tokenizer: transformers.models.auto.tokenization_auto.AutoTokenizer, config: Optional[dict] = None)
23    def __init__(
24        self, model: AutoModel, tokenizer: AutoTokenizer, config: Optional[dict] = None
25    ) -> None:
26        """Initialize the PyTorch-specific attention visualizer.
27
28        Args:
29            model: A PyTorch-based transformer model from Hugging Face
30            tokenizer: A tokenizer matching the model
31            config: Optional dictionary with visualization parameters
32        """
33        super().__init__(model, tokenizer, config)

Initialize the PyTorch-specific attention visualizer.

Args:
- model: A PyTorch-based transformer model from Hugging Face
- tokenizer: A tokenizer matching the model
- config: Optional dictionary with visualization parameters

def compute_attentions( self, encoded_input: transformers.tokenization_utils_base.BatchEncoding) -> tuple:
35    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
36        """Compute attention weights for the given input using a PyTorch model.
37
38        Runs the PyTorch model in inference mode with output_attentions flag set to True
39        and extracts the attention weights from the model output.
40
41        Args:
42            encoded_input: The encoded input from the tokenizer
43
44        Returns:
45            A tuple containing attention weights from all layers of the model
46        """
47        if encoded_input == self.current_input:
48            # return from cache
49            return self.cache
50
51        # else recompute
52        with torch.no_grad():
53            output = self.model(**encoded_input, output_attentions=True)  # type: ignore
54
55        attentions = output.attentions
56
57        # update cache and current input
58        self.current_input = encoded_input
59        self.cache = attentions
60
61        return attentions

Compute attention weights for the given input using a PyTorch model.

Runs the PyTorch model in inference mode with output_attentions flag set to True and extracts the attention weights from the model output.

Args: encoded_input: The encoded input from the tokenizer

Returns: A tuple containing attention weights from all layers of the model

def get_attention_vector_mean(self, attention: torch.Tensor, axis: int = 0) -> numpy.ndarray:
63    def get_attention_vector_mean(
64        self, attention: torch.Tensor, axis: int = 0
65    ) -> np.ndarray:
66        """Calculate mean of PyTorch attention vectors along specified axis.
67
68        Computes the mean of the attention tensor and converts it to a NumPy array.
69
70        Args:
71            attention: PyTorch tensor containing attention weights
72            axis: Axis along which to compute the mean (default: 0)
73
74        Returns:
75            NumPy array of mean attention values
76        """
77        return torch.mean(attention, dim=axis).detach().cpu().numpy()

Calculate mean of PyTorch attention vectors along specified axis.

Computes the mean of the attention tensor and converts it to a NumPy array.

Args:
- attention: PyTorch tensor containing attention weights
- axis: Axis along which to compute the mean (default: 0)

Returns: NumPy array of mean attention values