docs for pattern_lens v0.5.0

Contents


Demo views: inspect patterns across models, heads, and prompts; inspect a single pattern.

pattern-lens

visualization of LLM attention patterns and things computed about them

pattern-lens makes it easy to generate attention patterns from prompts, compute figures from those patterns, and explore the results in a web UI.

Installation

pip install pattern-lens

Usage

The pipeline is as follows:

1. generate attention pattern activations with pattern_lens.activations
2. compute figures from those activations with pattern_lens.figures
3. explore the results in the web UI served by pattern_lens.server

Basic CLI

Generate attention patterns and default visualizations:

# generate activations
python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --save-path attn_data
# create visualizations
python -m pattern_lens.figures --model gpt2 --save-path attn_data

Serve the web UI:

python -m pattern_lens.server --path attn_data

Web UI

pattern-lens provides two complementary web interfaces for exploring attention patterns: one for inspecting patterns across models, heads, and prompts, and one for inspecting a single pattern in detail.

View a demo of the web UI at miv.name/pattern-lens/demo.

Much of this web UI is inspired by CircuitsVis, but with a focus on just attention patterns and figures computed from them. I have also tried to make the interface a bit simpler, more flexible, and faster.

Custom Figures

Add custom visualization functions by decorating them with @register_attn_figure_func. You should still generate the activations first:

python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --save-path attn_data

and then write+run a script/notebook that looks something like this:

from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd

# these functions simplify writing a function which saves a figure
from pattern_lens.figure_util import matplotlib_figure_saver, save_matrix_wrapper
# decorator to register your function, such that it will be run by `figures_main`
from pattern_lens.attn_figure_funcs import register_attn_figure_func
# runs the actual figure generation pipeline
from pattern_lens.figures import figures_main

# define your own functions
# this one uses `matplotlib_figure_saver` -- define a function that takes matrix and `plt.Axes`, modify the axes
@register_attn_figure_func
@matplotlib_figure_saver(fmt="svgz")
def svd_spectra(attn_matrix: np.ndarray, ax: plt.Axes) -> None:
    # Perform SVD
    U, s, Vh = svd(attn_matrix)

    # Plot singular values
    ax.plot(s, "o-")
    ax.set_yscale("log")
    ax.set_xlabel("Singular Value Index")
    ax.set_ylabel("Singular Value")
    ax.set_title("Singular Value Spectrum of Attention Matrix")


# run the figure generation pipeline
figures_main(
    model_name="pythia-14m",
    save_path=Path("docs/demo/"),
    n_samples=5,
    force=False,
)

See demo.ipynb for a full example.

Submodules

View Source on GitHub



pattern_lens.activations

computing and saving activations given a model and prompts

Usage:

from the command line:

python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>

from a script:

from pattern_lens.activations import activations_main
activations_main(
        model_name="gpt2",
        save_path="demo/",
        prompts_path="data/pile_1k.jsonl",
)

View Source on GitHub

def compute_activations

(
    prompt: dict,
    model: transformer_lens.HookedTransformer.HookedTransformer | None = None,
    save_path: pathlib.Path = PosixPath('attn_data'),
    names_filter: Callable[[str], bool] | re.Pattern = re.compile('blocks\\.(\\d+)\\.attn\\.hook_pattern'),
    return_cache: Literal[None, 'numpy', 'torch'] = 'torch',
    stack_heads: bool = False
) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | transformer_lens.ActivationCache.ActivationCache | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx'] | jaxtyping.Float[Tensor, 'n_layers n_heads n_ctx n_ctx'] | None]

View Source on GitHub

get activations for a given model and prompt, possibly from a cache

if from a cache, prompt_meta must be passed and contain the prompt hash

Parameters:

Returns:

tuple[
        Path,
        Union[
                None,
                ActivationCacheNp, ActivationCache,
                Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
        ]
]
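With stack_heads=True, the per-layer pattern arrays are combined into a single [n_layers, n_heads, n_ctx, n_ctx] array. A minimal numpy sketch of that stacking, using the blocks.{i}.attn.hook_pattern key naming from the default names_filter (everything else here is illustrative, not pattern-lens's actual code):

```python
import numpy as np

n_layers, n_heads, n_ctx = 2, 4, 8

# a fake cache keyed like transformer_lens: blocks.{i}.attn.hook_pattern,
# each entry holding one layer's (n_heads, n_ctx, n_ctx) attention patterns
cache = {
    f"blocks.{i}.attn.hook_pattern": np.random.rand(n_heads, n_ctx, n_ctx)
    for i in range(n_layers)
}

# stack the per-layer arrays into one (n_layers, n_heads, n_ctx, n_ctx)
# array, in layer order
stacked = np.stack(
    [cache[f"blocks.{i}.attn.hook_pattern"] for i in range(n_layers)],
    axis=0,
)
assert stacked.shape == (n_layers, n_heads, n_ctx, n_ctx)
```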

def get_activations

(
    prompt: dict,
    model: transformer_lens.HookedTransformer.HookedTransformer | str,
    save_path: pathlib.Path = PosixPath('attn_data'),
    allow_disk_cache: bool = True,
    return_cache: Literal[None, 'numpy', 'torch'] = 'numpy'
) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | transformer_lens.ActivationCache.ActivationCache | None]

View Source on GitHub

given a prompt and a model, save or load activations

Parameters:

Returns:

def activations_main

(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
    stacked_heads: bool = False,
    device: str | torch.device = device(type='cuda')
) -> None

View Source on GitHub

main function for computing activations

Parameters:

def main

() -> None

View Source on GitHub

generate attention pattern activations for a model and prompts


pattern_lens.attn_figure_funcs

default figure functions

note that for pattern_lens.figures to recognize your function, you need to use the register_attn_figure_func decorator which adds your function to ATTENTION_MATRIX_FIGURE_FUNCS

View Source on GitHub

def get_all_figure_names

() -> list[str]

View Source on GitHub

get all figure names

def register_attn_figure_func

(
    func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]
) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]

View Source on GitHub

decorator for registering attention matrix figure function

if you want to add a new figure function, you should use this decorator

Parameters:

Returns:

Usage:

@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)

def register_attn_figure_multifunc

(
    names: Sequence[str]
) -> Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

View Source on GitHub

decorator which registers a function as a multi-figure function
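The registry idea behind such a decorator factory can be sketched in plain Python; the FIGURE_FUNCS list and attribute name below are illustrative stand-ins, not pattern-lens's actual internals:

```python
from typing import Callable, Sequence

# hypothetical stand-in for the module-level registry
FIGURE_FUNCS: list[Callable] = []

def register_multifunc(names: Sequence[str]) -> Callable:
    """Decorator factory: register one function that produces a figure per name."""
    def decorator(func: Callable) -> Callable:
        # record the figure names the function will produce, then register it
        func.figure_names = list(names)
        FIGURE_FUNCS.append(func)
        return func
    return decorator

@register_multifunc(names=["mean", "max"])
def stats_figures(attn_matrix, save_dir):
    ...  # would save one figure per registered name

assert stats_figures in FIGURE_FUNCS
assert stats_figures.figure_names == ["mean", "max"]
```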

def raw

(
    attn_matrix: jaxtyping.Float[ndarray, 'n_ctx n_ctx']
) -> jaxtyping.Float[ndarray, 'n m']

View Source on GitHub

raw attention matrix


pattern_lens.consts

implements some constants and types

View Source on GitHub

type alias for attention matrix

type alias for a cache of activations, like a transformer_lens.ActivationCache

type alias for a cache of activations, like a transformer_lens.ActivationCache but without the extras. useful for when loading from an npz file

default directory for attention data

regex for finding attention patterns in model state dicts

default kwargs for muutils.spinner.Spinner

divider string for separating sections

divider string for separating subsections

return type for a cache of activations


pattern_lens.figure_util

implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators matplotlib_figure_saver, save_matrix_wrapper to make your functions save figures

View Source on GitHub

Type alias for a function that, given an attention matrix, saves one or more figures

Type alias for a 2D matrix (plottable)

Type alias for a 2D matrix with 3 channels (RGB)

Type alias for a function that, given an attention matrix, returns a 2D matrix

format for saving matplotlib figures

Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure

default for whether to normalize the matrix to range [0, 1]

default colormap for saving matrices

default format for saving matrices

template for saving an n by m matrix as an svg/svgz

def matplotlib_figure_saver

(
    func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], None] | None = None,
    fmt: str = 'svgz'
) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None] | Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], None], str], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

View Source on GitHub

decorator for functions which take an attention matrix and predefined ax object, making it save a figure

Parameters:

Returns:

Usage:

@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")

def matplotlib_multifigure_saver

(
    names: Sequence[str],
    fmt: str = 'svgz'
) -> Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], dict[str, matplotlib.axes._axes.Axes]], None]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

View Source on GitHub

decorate a function such that it saves multiple figures, one for each name in names

Parameters:

Returns:

def matrix_to_image_preprocess

(
    matrix: jaxtyping.Float[ndarray, 'n m'],
    normalize: bool = False,
    cmap: str | matplotlib.colors.Colormap = 'viridis',
    diverging_colormap: bool = False,
    normalize_min: float | None = None
) -> jaxtyping.UInt8[ndarray, 'n m rgb=3']

View Source on GitHub

preprocess a 2D matrix into a plottable heatmap image

Parameters:

Returns:
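A simplified sketch of the normalize-then-colormap idea: a grayscale ramp stands in for the real viridis lookup, so this is illustrative rather than the actual implementation:

```python
import numpy as np

def to_rgb_sketch(matrix: np.ndarray, normalize: bool = False) -> np.ndarray:
    """Map a 2D float matrix to a uint8 (n, m, 3) image.

    Simplified: a grayscale ramp stands in for the real colormap lookup.
    """
    m = matrix.astype(np.float64)
    if normalize:
        lo, hi = m.min(), m.max()
        # rescale to [0, 1]; degenerate (constant) matrices map to zero
        m = (m - lo) / (hi - lo) if hi > lo else np.zeros_like(m)
    m = np.clip(m, 0.0, 1.0)
    gray = (m * 255).astype(np.uint8)
    # replicate the gray channel into RGB, giving shape (n, m, 3)
    return np.stack([gray, gray, gray], axis=-1)

img = to_rgb_sketch(np.random.rand(4, 5) * 10, normalize=True)
assert img.shape == (4, 5, 3) and img.dtype == np.uint8
```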

def matrix2drgb_to_png_bytes

(
    matrix: jaxtyping.UInt8[ndarray, 'n m rgb=3'],
    buffer: _io.BytesIO | None = None
) -> bytes | None

View Source on GitHub

Convert a Matrix2Drgb to valid PNG bytes via PIL

Parameters:

Returns:

def matrix_as_svg

(
    matrix: jaxtyping.Float[ndarray, 'n m'],
    normalize: bool = False,
    cmap: str | matplotlib.colors.Colormap = 'viridis',
    diverging_colormap: bool = False,
    normalize_min: float | None = None
) -> str

View Source on GitHub

quickly convert a 2D matrix to an SVG image, without matplotlib

Parameters:

Returns:

def save_matrix_wrapper

(
    func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']] | None = None,
    *args,
    fmt: Literal['png', 'svg', 'svgz'] = 'svgz',
    normalize: bool = False,
    cmap: str | matplotlib.colors.Colormap = 'viridis',
    diverging_colormap: bool = False,
    normalize_min: float | None = None
) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None] | Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

View Source on GitHub

Decorator for functions that process an attention matrix and save it as an SVGZ image.

Can handle both argumentless usage and with arguments.

Parameters:

Returns:

AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]

Usage:

@save_matrix_wrapper
def identity_matrix(matrix):
        return matrix

@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
        return matrix * 2

@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix(matrix):
        return matrix * 2
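The dual calling convention shown above (bare decorator vs. decorator with arguments) can be sketched with a plain Python decorator; everything below is illustrative, not pattern-lens's actual code:

```python
from functools import wraps

def wrapper_sketch(func=None, *, fmt="svgz"):
    """Decorator usable both bare (@wrapper_sketch) and with
    keyword arguments (@wrapper_sketch(fmt="png"))."""
    def decorate(f):
        @wraps(f)
        def inner(matrix):
            # the real wrapper would save f(matrix) to disk in `fmt`;
            # here we just return the pair for demonstration
            return (fmt, f(matrix))
        return inner
    if func is not None:
        # called bare: func is the decorated function itself
        return decorate(func)
    # called with arguments: return the decorator
    return decorate

@wrapper_sketch
def identity(m):
    return m

@wrapper_sketch(fmt="png")
def doubled(m):
    return m * 2

assert identity(3) == ("svgz", 3)
assert doubled(3) == ("png", 6)
```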


pattern_lens.figures

code for generating figures from attention patterns, using the functions decorated with register_attn_figure_func

View Source on GitHub

class HTConfigMock:

View Source on GitHub

Mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json

can be initialized with any kwargs, and will update its __dict__ with them. It does, however, require the following attributes:

- n_layers: int
- n_heads: int
- model_name: str

we do this to avoid having to import torch and transformer_lens, since this would have to be done for each process in the parallelization and probably slows things down significantly

HTConfigMock

(**kwargs: dict[str, str | int])

View Source on GitHub

will pass all kwargs to __dict__

def serialize

(self) -> dict

View Source on GitHub

serialize the config to json. values which aren’t serializable will be converted via muutils.json_serialize.json_serialize

def load

(cls, data: dict) -> pattern_lens.figures.HTConfigMock

View Source on GitHub

try to load a config from a dict, using the __init__ method
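The mock-config idea can be sketched in plain Python; the class name and error handling below are illustrative, not the actual HTConfigMock implementation:

```python
class ConfigMockSketch:
    """Minimal mock config: accepts any kwargs, requires a few attributes."""

    REQUIRED = ("n_layers", "n_heads", "model_name")

    def __init__(self, **kwargs):
        # accept arbitrary kwargs by updating __dict__ directly
        self.__dict__.update(kwargs)
        for attr in self.REQUIRED:
            if not hasattr(self, attr):
                raise ValueError(f"missing required config attribute: {attr}")

    def serialize(self) -> dict:
        # the real serialize falls back to muutils' json_serialize for
        # non-serializable values; a plain dict copy suffices here
        return dict(self.__dict__)

    @classmethod
    def load(cls, data: dict) -> "ConfigMockSketch":
        # round-trip through __init__ so required attributes are checked
        return cls(**data)

cfg = ConfigMockSketch(n_layers=6, n_heads=8, model_name="gpt2", extra="ok")
assert cfg.n_heads == 8 and cfg.extra == "ok"
assert ConfigMockSketch.load(cfg.serialize()).model_name == "gpt2"
```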

def process_single_head

(
    layer_idx: int,
    head_idx: int,
    attn_pattern: jaxtyping.Float[ndarray, 'n_ctx n_ctx'],
    save_dir: pathlib.Path,
    figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]],
    force_overwrite: bool = False
) -> dict[str, bool | Exception]

View Source on GitHub

process a single head’s attention pattern, running all the functions in figure_funcs on the attention pattern

Gotcha: if force_overwrite is False and a multi-figure function was used, all of that function's figures are skipped if any one of them is already saved; a filename format of {func_name}.{figure_name}.{fmt} is assumed for the saved figures.

Parameters:

Returns:
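The skip check described in the gotcha can be sketched as follows; the helper name is hypothetical, and only the {func_name}.{figure_name}.{fmt} naming comes from the description above:

```python
from pathlib import Path
import tempfile

def should_skip(save_dir: Path, func_name: str, figure_names: list[str], fmt: str) -> bool:
    """Skip the whole multi-figure function if ANY of its figures already
    exists, assuming saved files are named {func_name}.{figure_name}.{fmt}."""
    return any(
        (save_dir / f"{func_name}.{name}.{fmt}").exists()
        for name in figure_names
    )

with tempfile.TemporaryDirectory() as d:
    save_dir = Path(d)
    # nothing saved yet: run the function
    assert not should_skip(save_dir, "stats", ["mean", "max"], "svgz")
    # one of the two figures exists: the whole function is skipped
    (save_dir / "stats.mean.svgz").touch()
    assert should_skip(save_dir, "stats", ["mean", "max"], "svgz")
```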

def compute_and_save_figures

(
    model_cfg: 'HookedTransformerConfig|HTConfigMock',
    activations_path: pathlib.Path,
    cache: dict[str, numpy.ndarray] | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx'],
    figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]],
    save_path: pathlib.Path = PosixPath('attn_data'),
    force_overwrite: bool = False,
    track_results: bool = False
) -> None

View Source on GitHub

compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

def process_prompt

(
    prompt: dict,
    model_cfg: 'HookedTransformerConfig|HTConfigMock',
    save_path: pathlib.Path,
    figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]],
    force_overwrite: bool = False
) -> None

View Source on GitHub

process a single prompt, loading the activations and computing and saving the figures

basically just calls load_activations and then compute_and_save_figures

Parameters:

def select_attn_figure_funcs

(
    figure_funcs_select: set[str] | str | None = None
) -> list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

View Source on GitHub

given a selector, figure out which functions from ATTENTION_MATRIX_FIGURE_FUNCS to use
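A sketch of such name-based selection; treating a string selector as comma-separated names is an assumption here, not confirmed behavior:

```python
def select_funcs(registry, select=None):
    """Pick figure functions by name: None selects all; a set of names, or a
    comma-separated string of names (an assumed convention), selects a subset."""
    if select is None:
        return list(registry)
    if isinstance(select, str):
        select = {name.strip() for name in select.split(",")}
    return [f for f in registry if f.__name__ in select]

# stand-in figure functions
def raw(attn_matrix, path): ...
def svd_spectra(attn_matrix, path): ...

registry = [raw, svd_spectra]
assert select_funcs(registry) == registry
assert select_funcs(registry, "svd_spectra") == [svd_spectra]
assert select_funcs(registry, {"raw"}) == [raw]
```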

def figures_main

(
    model_name: str,
    save_path: str,
    n_samples: int,
    force: bool,
    figure_funcs_select: set[str] | str | None = None,
    parallel: bool | int = True
) -> None

View Source on GitHub

main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

def main

() -> None

View Source on GitHub

generates figures from the activations using the functions decorated with register_attn_figure_func


pattern_lens.indexes

writes indexes to the model directory for the frontend to use or for record keeping

View Source on GitHub

def generate_prompts_jsonl

(model_dir: pathlib.Path) -> None

View Source on GitHub

creates a prompts.jsonl file with all the prompts in the model directory

looks in all directories in {model_dir}/prompts for a prompt.json file

def generate_models_jsonl

(path: pathlib.Path) -> None

View Source on GitHub

creates a models.jsonl file with all the models

def get_func_metadata

(func: Callable) -> list[dict[str, str | None]]

View Source on GitHub

get metadata for a function

Parameters:

Returns:

list[dict[str, str | None]] each dictionary is for a function, containing:
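A sketch of collecting per-function metadata with the stdlib inspect module; the exact fields pattern-lens records are not listed above, so the dictionary keys below are illustrative:

```python
import inspect

def func_metadata_sketch(func) -> dict:
    """Collect basic metadata about a function via the stdlib inspect module."""
    doc = inspect.getdoc(func)
    return {
        "name": func.__name__,
        # first line of the docstring, if any
        "doc": doc.splitlines()[0] if doc else None,
        "source_module": getattr(func, "__module__", None),
    }

def raw(attn_matrix, path):
    """raw attention matrix"""

meta = func_metadata_sketch(raw)
assert meta["name"] == "raw"
assert meta["doc"] == "raw attention matrix"
```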

def generate_functions_jsonl

(path: pathlib.Path) -> None

View Source on GitHub

writes the union of the functions listed in figures.jsonl and those in ATTENTION_MATRIX_FIGURE_FUNCS into the file

def write_html_index

(path: pathlib.Path) -> None

View Source on GitHub

writes index.html and single.html files to the path (version replacement handled by makefile)


pattern_lens.load_activations

loading activations from .npz on disk. implements some custom Exception classes

View Source on GitHub

class GetActivationsError(builtins.ValueError):

View Source on GitHub

base class for errors in getting activations

Inherited Members

class ActivationsMissingError(GetActivationsError, builtins.FileNotFoundError):

View Source on GitHub

error for missing activations – can’t find the activations file

Inherited Members

class ActivationsMismatchError(GetActivationsError):

View Source on GitHub

error for mismatched activations – the prompt text or hash do not match

raised by compare_prompt_to_loaded

Inherited Members

class InvalidPromptError(GetActivationsError):

View Source on GitHub

error for invalid prompt – the prompt does not have fields “hash” or “text”

raised by augment_prompt_with_hash

Inherited Members

def compare_prompt_to_loaded

(prompt: dict, prompt_loaded: dict) -> None

View Source on GitHub

compare a prompt to a loaded prompt, raise an error if they do not match

Parameters:

Returns:

Raises:

def augment_prompt_with_hash

(prompt: dict) -> dict

View Source on GitHub

if a prompt does not have a hash, add one

not having a “text” field is allowed, but only if “hash” is present

Parameters:

Returns:

Modifies:

the input prompt dictionary, if it does not have a "hash" key
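A sketch of this augment step; the choice of SHA-256 over the prompt text is an assumption, and pattern-lens may hash differently:

```python
import hashlib

def augment_with_hash_sketch(prompt: dict) -> dict:
    """Add a hash to a prompt dict if absent.

    ASSUMPTION: SHA-256 of the text, hex-encoded; the real scheme may differ.
    """
    if "hash" not in prompt:
        # a missing "text" field is only allowed when "hash" is present
        if "text" not in prompt:
            raise ValueError("prompt needs 'text' when 'hash' is missing")
        prompt["hash"] = hashlib.sha256(prompt["text"].encode()).hexdigest()
    return prompt

p = augment_with_hash_sketch({"text": "hello"})
assert "hash" in p and len(p["hash"]) == 64
# a prompt with only a hash is left untouched
assert augment_with_hash_sketch({"hash": "abc"})["hash"] == "abc"
```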

def load_activations

(
    model_name: str,
    prompt: dict,
    save_path: pathlib.Path,
    return_fmt: Literal[None, 'numpy', 'torch'] = 'torch'
) -> tuple[pathlib.Path, dict[str, torch.Tensor] | dict[str, numpy.ndarray]]

View Source on GitHub

load activations for a prompt and model, from an npz file

Parameters:

Returns:

Raises:


pattern_lens.prompts

implements load_text_data for loading prompts

View Source on GitHub

def load_text_data

(
    fname: pathlib.Path,
    min_chars: int | None = None,
    max_chars: int | None = None,
    shuffle: bool = False
) -> list[dict]

View Source on GitHub

given fname, the path to a jsonl file, load the prompts and split them up into more reasonable sizes

Parameters:

Returns:
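A sketch of the load-and-filter behavior described above; the truncation strategy here is an assumption (the real function may split long prompts differently):

```python
import json
import tempfile
from pathlib import Path

def load_text_data_sketch(fname: Path, min_chars=None, max_chars=None) -> list[dict]:
    """Read a jsonl file of {"text": ...} records, drop prompts shorter than
    min_chars, and truncate prompts longer than max_chars (ASSUMPTION: the
    real function may split rather than truncate)."""
    out = []
    for line in fname.read_text().splitlines():
        if not line.strip():
            continue
        record = json.loads(line)
        text = record["text"]
        if min_chars is not None and len(text) < min_chars:
            continue  # too short: drop
        if max_chars is not None and len(text) > max_chars:
            record["text"] = text[:max_chars]  # too long: truncate
        out.append(record)
    return out

with tempfile.TemporaryDirectory() as tmp:
    f = Path(tmp) / "prompts.jsonl"
    f.write_text('{"text": "hi"}\n{"text": "a longer prompt"}\n')
    prompts = load_text_data_sketch(f, min_chars=5, max_chars=8)
    assert prompts == [{"text": "a longer"}]
```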


pattern_lens.server

cli for starting the server to show the web ui.

can also run with --rewrite-index to update the index.html file. this is useful for working on the ui.

View Source on GitHub

def main

(path: str | None = None, port: int = 8000) -> None

View Source on GitHub

move to the given path and start the server