docs for pattern_lens v0.5.0
[screenshots: inspect patterns across models, heads, and prompts; inspect a single pattern]
visualization of LLM attention patterns and things computed about them

pattern-lens makes it easy to generate attention patterns over many prompts, compute visualizations of them, and explore the results in a web UI.

pip install pattern-lens

The pipeline is as follows:

- pattern_lens.activations.activations_main() – compute activations for a model over a set of prompts, saving them in npz files
- pattern_lens.figures.figures_main() – read the npz files, pass each attention pattern to each visualization function, and save the resulting figures
- pattern_lens.server – web interface reads metadata in json/jsonl files, then lets the user select figures to show

Generate attention patterns and default visualizations:
# generate activations
python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --save-path attn_data
# create visualizations
python -m pattern_lens.figures --model gpt2 --save-path attn_data

Serve the web UI:

python -m pattern_lens.server --path attn_data

pattern-lens provides two complementary web interfaces for exploring attention patterns: one for inspecting patterns across models, heads, and prompts, and one for inspecting a single pattern.
View a demo of the web UI at miv.name/pattern-lens/demo.
Much of this web UI is inspired by CircuitsVis,
but with a focus on just attention patterns and figures computed from
them. I have also tried to make the interface a bit simpler, more
flexible, and faster.
Add custom visualization functions by decorating them with
@register_attn_figure_func. You should still generate the
activations first:
python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --save-path attn_data
and then write+run a script/notebook that looks something like this:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd
# these functions simplify writing a function which saves a figure
from pattern_lens.figure_util import matplotlib_figure_saver, save_matrix_wrapper
# decorator to register your function, such that it will be run by `figures_main`
from pattern_lens.attn_figure_funcs import register_attn_figure_func
# runs the actual figure generation pipeline
from pattern_lens.figures import figures_main
# define your own functions
# this one uses `matplotlib_figure_saver` -- define a function that takes matrix and `plt.Axes`, modify the axes
@register_attn_figure_func
@matplotlib_figure_saver(fmt="svgz")
def svd_spectra(attn_matrix: np.ndarray, ax: plt.Axes) -> None:
# Perform SVD
U, s, Vh = svd(attn_matrix)
# Plot singular values
ax.plot(s, "o-")
ax.set_yscale("log")
ax.set_xlabel("Singular Value Index")
ax.set_ylabel("Singular Value")
ax.set_title("Singular Value Spectrum of Attention Matrix")
# run the figure generation pipeline
figures_main(
model_name="pythia-14m",
save_path=Path("docs/demo/"),
n_samples=5,
force=False,
)

See demo.ipynb for a full example.
pattern_lens.activations
computing and saving activations given a model and prompts
from the command line:
python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>

from a script:

from pattern_lens.activations import activations_main

activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
)

def compute_activations(
prompt: dict,
model: transformer_lens.HookedTransformer.HookedTransformer | None = None,
save_path: pathlib.Path = PosixPath('attn_data'),
names_filter: Callable[[str], bool] | re.Pattern = re.compile('blocks\\.(\\d+)\\.attn\\.hook_pattern'),
return_cache: Literal[None, 'numpy', 'torch'] = 'torch',
stack_heads: bool = False
) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | transformer_lens.ActivationCache.ActivationCache | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx'] | jaxtyping.Float[Tensor, 'n_layers n_heads n_ctx n_ctx'] | None]

get activations for a given model and prompt, possibly from a cache. if from a cache, prompt_meta must be passed and contain the prompt hash

Parameters:

- prompt : dict | None (defaults to None)
- model : HookedTransformer
- save_path : Path (defaults to Path(DATA_DIR))
- names_filter : Callable[[str], bool] | re.Pattern – a filter for the names of the activations to return. if an re.Pattern, will use lambda key: names_filter.match(key) is not None (defaults to ATTN_PATTERN_REGEX)
- return_cache : Literal[None, "numpy", "torch"] – will return None as the second element if None, otherwise will return the cache in the specified tensor format. stack_heads still affects whether it will be a dict (False) or a single tensor (True)
- stack_heads : bool – whether the heads should be stacked in the output. this causes a number of changes: an npy file with a single (n_layers, n_heads, n_ctx, n_ctx) tensor is saved for each prompt instead of an npz file with a dict by layer; the cache will be a single (n_layers, n_heads, n_ctx, n_ctx) tensor instead of a dict by layer if return_cache is set; will assert that everything in the activation cache is only attention patterns, and is all of the attention patterns, raising an exception if not (defaults to False)

Returns:

tuple[
    Path,
    Union[
        None,
        ActivationCacheNp, ActivationCache,
        Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
    ]
]
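The two on-disk layouts controlled by stack_heads can be sketched with plain numpy (no pattern_lens needed; the file names here are illustrative, not the library's exact layout):

```python
import tempfile
from pathlib import Path

import numpy as np

n_layers, n_heads, n_ctx = 2, 4, 8
rng = np.random.default_rng(0)
# fake per-layer attention patterns, each shaped (n_heads, n_ctx, n_ctx)
patterns = {
    f"blocks.{i}.attn.hook_pattern": rng.random((n_heads, n_ctx, n_ctx))
    for i in range(n_layers)
}

save_dir = Path(tempfile.mkdtemp())

# stack_heads=False style: one npz file holding a dict keyed by layer
npz_path = save_dir / "activations.npz"
np.savez(npz_path, **patterns)

# stack_heads=True style: a single (n_layers, n_heads, n_ctx, n_ctx) tensor in an npy file
npy_path = save_dir / "activations.npy"
stacked = np.stack([patterns[f"blocks.{i}.attn.hook_pattern"] for i in range(n_layers)])
np.save(npy_path, stacked)

# both layouts round-trip to the same data
loaded_npz = np.load(npz_path)
loaded_npy = np.load(npy_path)
assert loaded_npy.shape == (n_layers, n_heads, n_ctx, n_ctx)
assert np.allclose(loaded_npy[1], loaded_npz["blocks.1.attn.hook_pattern"])
```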
def get_activations(
prompt: dict,
model: transformer_lens.HookedTransformer.HookedTransformer | str,
save_path: pathlib.Path = PosixPath('attn_data'),
allow_disk_cache: bool = True,
return_cache: Literal[None, 'numpy', 'torch'] = 'numpy'
) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | transformer_lens.ActivationCache.ActivationCache | None]

given a prompt and a model, save or load activations

Parameters:

- prompt : dict – expected to contain the 'text' key
- model : HookedTransformer | str – either a HookedTransformer or a string model name, to be loaded with HookedTransformer.from_pretrained
- save_path : Path – path to save the activations to (and load from) (defaults to Path(DATA_DIR))
- allow_disk_cache : bool – whether to allow loading from disk cache (defaults to True)
- return_cache : Literal[None, "numpy", "torch"] – whether to return the cache, and in what format (defaults to "numpy")

Returns:

- tuple[Path, ActivationCacheNp | ActivationCache | None] – the path to the activations and the cache if return_cache is not None
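The save-or-load behavior of get_activations can be pictured as a small disk-memoization helper. This is a sketch, not the library's implementation; the file layout and the compute callback are made up for illustration:

```python
import tempfile
from pathlib import Path

import numpy as np

def cached_activations(prompt_hash: str, save_dir: Path, compute, allow_disk_cache: bool = True):
    """load activations for prompt_hash from disk if present, else compute and save them."""
    path = save_dir / f"{prompt_hash}.npz"
    if allow_disk_cache and path.exists():
        with np.load(path) as data:
            return path, dict(data)
    cache = compute()
    np.savez(path, **cache)
    return path, cache

save_dir = Path(tempfile.mkdtemp())
calls = []

def fake_compute():
    calls.append(1)
    return {"blocks.0.attn.hook_pattern": np.eye(4)}

# first call computes and saves; second call hits the disk cache
path1, cache1 = cached_activations("abc123", save_dir, fake_compute)
path2, cache2 = cached_activations("abc123", save_dir, fake_compute)
assert len(calls) == 1 and path1 == path2
```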
DEFAULT_DEVICE: torch.device = device(type='cuda')
def activations_main(
model_name: str,
save_path: str,
prompts_path: str,
raw_prompts: bool,
min_chars: int,
max_chars: int,
force: bool,
n_samples: int,
no_index_html: bool,
shuffle: bool = False,
stacked_heads: bool = False,
device: str | torch.device = device(type='cuda')
) -> None

main function for computing activations

Parameters:

- model_name : str – name of a model to load with HookedTransformer.from_pretrained
- save_path : str – path to save the activations to
- prompts_path : str – path to the prompts file
- raw_prompts : bool – whether the prompts are raw, not filtered by length. load_text_data will be called if True, otherwise just load the "text" field from each line in prompts_path
- min_chars : int – minimum number of characters for a prompt
- max_chars : int – maximum number of characters for a prompt
- force : bool – whether to overwrite existing files
- n_samples : int – maximum number of samples to process
- no_index_html : bool – whether to write an index.html file
- shuffle : bool – whether to shuffle the prompts (defaults to False)
- stacked_heads : bool – whether to stack the heads in the output tensor. will save as .npy instead of .npz if True (defaults to False)
- device : str | torch.device – the device to use. if a string, will be passed to torch.device

def main() -> None

generate attention pattern activations for a model and prompts
pattern_lens.attn_figure_funcs
default figure functions
note that for pattern_lens.figures to recognize your
function, you need to use the register_attn_figure_func
decorator which adds your function to
ATTENTION_MATRIX_FIGURE_FUNCS
ATTENTION_MATRIX_FIGURE_FUNCS: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]] = [<function raw>]

def get_all_figure_names() -> list[str]

get all figure names
def register_attn_figure_func(
func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]
) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]

decorator for registering an attention matrix figure function. if you want to add a new figure function, you should use this decorator

Parameters:

- func : AttentionMatrixFigureFunc – your function, which should take an attention matrix and a path

Returns:

- AttentionMatrixFigureFunc – your function, after we add it to ATTENTION_MATRIX_FIGURE_FUNCS

Example:

@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("My New Figure Function")
    ax.axis("off")
    plt.savefig(path / "my_new_figure_func", format="svgz")
    plt.close(fig)

def register_attn_figure_multifunc(
names: Sequence[str]
) -> Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

decorator which registers a function as a multi-figure function
def raw(
attn_matrix: jaxtyping.Float[ndarray, 'n_ctx n_ctx']
) -> jaxtyping.Float[ndarray, 'n m']

raw attention matrix
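The registration mechanism amounts to a decorator that appends the function to a module-level list. A minimal stand-in (pure Python, independent of pattern_lens; the names FIGURE_FUNCS and register_figure_func are illustrative):

```python
from typing import Callable

# stand-in for ATTENTION_MATRIX_FIGURE_FUNCS
FIGURE_FUNCS: list[Callable] = []

def register_figure_func(func: Callable) -> Callable:
    """append func to the registry and return it unchanged, so it stays usable directly."""
    FIGURE_FUNCS.append(func)
    return func

@register_figure_func
def raw(attn_matrix, path):
    # a real figure function would save an image of attn_matrix under path
    pass

assert raw in FIGURE_FUNCS
assert [f.__name__ for f in FIGURE_FUNCS] == ["raw"]
```

Because the decorator returns the function unchanged, a registered function can still be called directly in a notebook.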
pattern_lens.consts
implements some constants and types
AttentionMatrix = <class 'jaxtyping.Float[ndarray, 'n_ctx n_ctx']'>
type alias for an attention matrix

ActivationCacheNp = dict[str, numpy.ndarray]
type alias for a cache of activations, like a transformer_lens.ActivationCache

ActivationCacheTorch = dict[str, torch.Tensor]
type alias for a cache of activations, like a transformer_lens.ActivationCache but without the extras. useful for when loading from an npz file

DATA_DIR: str = 'attn_data'
default directory for attention data

ATTN_PATTERN_REGEX: re.Pattern = re.compile('blocks\\.(\\d+)\\.attn\\.hook_pattern')
regex for finding attention patterns in model state dicts

SPINNER_KWARGS: dict = {'config': {'success': '✔️ '}}
default kwargs for muutils.spinner.Spinner

DIVIDER_S1: str = '======================================================================'
divider string for separating sections

DIVIDER_S2: str = '--------------------------------------------------'
divider string for separating subsections

ReturnCache = typing.Literal[None, 'numpy', 'torch']
return type for a cache of activations
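ATTN_PATTERN_REGEX can be exercised directly to see which hook names it selects and how a names_filter is built from an re.Pattern, as described for compute_activations:

```python
import re

ATTN_PATTERN_REGEX = re.compile(r"blocks\.(\d+)\.attn\.hook_pattern")

# a names_filter in the style compute_activations builds from an re.Pattern
names_filter = lambda key: ATTN_PATTERN_REGEX.match(key) is not None

assert names_filter("blocks.3.attn.hook_pattern")
assert not names_filter("blocks.3.mlp.hook_post")

# the capture group recovers the layer index
match = ATTN_PATTERN_REGEX.match("blocks.11.attn.hook_pattern")
layer_idx = int(match.group(1))
assert layer_idx == 11
```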
pattern_lens.figure_util
implements a bunch of types, default values, and templates which are useful for figure functions
notably, you can use the decorators
matplotlib_figure_saver, save_matrix_wrapper
to make your functions save figures
AttentionMatrixFigureFunc = collections.abc.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]
type alias for a function that, given an attention matrix, saves one or more figures

Matrix2D = <class 'jaxtyping.Float[ndarray, 'n m']'>
type alias for a 2D matrix (plottable)

Matrix2Drgb = <class 'jaxtyping.UInt8[ndarray, 'n m rgb=3']'>
type alias for a 2D matrix with 3 channels (RGB)

AttentionMatrixToMatrixFunc = collections.abc.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]
type alias for a function that, given an attention matrix, returns a 2D matrix

MATPLOTLIB_FIGURE_FMT: str = 'svgz'
format for saving matplotlib figures

MatrixSaveFormat = typing.Literal['png', 'svg', 'svgz']
type alias for the format to save a matrix as, when saving a raw matrix rather than a matplotlib figure

MATRIX_SAVE_NORMALIZE: bool = False
default for whether to normalize the matrix to the range [0, 1]

MATRIX_SAVE_CMAP: str = 'viridis'
default colormap for saving matrices

MATRIX_SAVE_FMT: Literal['png', 'svg', 'svgz'] = 'svgz'
default format for saving matrices

MATRIX_SAVE_SVG_TEMPLATE: str = '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>'
template for saving an n by m matrix as an svg/svgz
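Filling MATRIX_SAVE_SVG_TEMPLATE is plain string formatting: base64-encode the PNG bytes of the heatmap and substitute the dimensions. In this sketch the PNG payload is a placeholder, not a real image:

```python
import base64

MATRIX_SAVE_SVG_TEMPLATE = (
    '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" '
    'viewBox="0 0 {m} {n}" image-rendering="pixelated"> '
    '<image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>'
)

n, m = 4, 6  # matrix is n rows by m columns
fake_png_bytes = b"\x89PNG placeholder"  # stand-in for real PNG bytes of the rgb matrix
svg = MATRIX_SAVE_SVG_TEMPLATE.format(
    n=n, m=m, png_base64=base64.b64encode(fake_png_bytes).decode("ascii")
)

assert svg.startswith("<svg") and svg.endswith("</svg>")
assert 'width="6"' in svg and 'height="4"' in svg
```

The image-rendering="pixelated" attribute is what keeps each matrix cell a crisp square when the browser scales the tiny embedded PNG up.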
def matplotlib_figure_saver(
func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], None] | None = None,
fmt: str = 'svgz'
) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None] | Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], None], str], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

decorator for functions which take an attention matrix and a predefined ax object, making them save a figure

Parameters:

- func : Callable[[AttentionMatrix, plt.Axes], None] – your function, which should take an attention matrix and a predefined ax object
- fmt : str – format for saving matplotlib figures (defaults to MATPLOTLIB_FIGURE_FMT)

Returns:

- AttentionMatrixFigureFunc – your function, after we wrap it to save a figure

Example:

@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("Raw Attention Pattern")
    ax.axis("off")

def matplotlib_multifigure_saver(
names: Sequence[str],
fmt: str = 'svgz'
) -> Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], dict[str, matplotlib.axes._axes.Axes]], None]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

decorate a function such that it saves multiple figures, one for each name in names

Parameters:

- names : Sequence[str] – the names of the figures to save
- fmt : str – format for saving matplotlib figures (defaults to MATPLOTLIB_FIGURE_FMT)

Returns:

- Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]], AttentionMatrixFigureFunc] – the decorator, which will then be applied to the function. we expect the decorated function to take an attention pattern and a dict of axes corresponding to the names

def matrix_to_image_preprocess(
matrix: jaxtyping.Float[ndarray, 'n m'],
normalize: bool = False,
cmap: str | matplotlib.colors.Colormap = 'viridis',
diverging_colormap: bool = False,
normalize_min: float | None = None
) -> jaxtyping.UInt8[ndarray, 'n m rgb=3']

preprocess a 2D matrix into a plottable heatmap image

Parameters:

- matrix : Matrix2D – input matrix
- normalize : bool – whether to normalize the matrix to the range [0, 1] (defaults to MATRIX_SAVE_NORMALIZE)
- cmap : str | Colormap – the colormap to use for the matrix (defaults to MATRIX_SAVE_CMAP)
- diverging_colormap : bool – if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- normalize_min : float | None – if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero). if None, then the minimum value of the matrix is used. if diverging_colormap=True OR normalize=False, this must be None (defaults to None)

Returns:

- Matrix2Drgb

def matrix2drgb_to_png_bytes(
matrix: jaxtyping.UInt8[ndarray, 'n m rgb=3'],
buffer: _io.BytesIO | None = None
) -> bytes | None

convert a Matrix2Drgb to valid PNG bytes via PIL. if buffer is provided, it will write the PNG bytes to the buffer and return None; if buffer is not provided, it will return the PNG bytes

Parameters:

- matrix : Matrix2Drgb
- buffer : io.BytesIO | None (defaults to None, in which case it will return the PNG bytes)

Returns:

- bytes | None – bytes if buffer is None, otherwise None

def matrix_as_svg(
matrix: jaxtyping.Float[ndarray, 'n m'],
normalize: bool = False,
cmap: str | matplotlib.colors.Colormap = 'viridis',
diverging_colormap: bool = False,
normalize_min: float | None = None
) -> str

quickly convert a 2D matrix to an SVG image, without matplotlib

Parameters:

- matrix : Float[np.ndarray, 'n m'] – a 2D matrix to convert to an SVG image
- normalize : bool – whether to normalize the matrix to the range [0, 1]. if it's not in the range [0, 1], this must be True or it will raise an AssertionError (defaults to False)
- cmap : str – the colormap to use for the matrix – will look up in matplotlib.colormaps if it's a string (defaults to "viridis")
- diverging_colormap : bool – if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- normalize_min : float | None – if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero). if None, then the minimum value of the matrix is used. if diverging_colormap=True OR normalize=False, this must be None (defaults to None)

Returns:

- str – the SVG content for the matrix

def save_matrix_wrapper(
func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']] | None = None,
*args,
fmt: Literal['png', 'svg', 'svgz'] = 'svgz',
normalize: bool = False,
cmap: str | matplotlib.colors.Colormap = 'viridis',
diverging_colormap: bool = False,
normalize_min: float | None = None
) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None] | Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

decorator for functions that process an attention matrix and save the result as an image. can handle both argumentless usage and usage with arguments

Parameters:

- func : AttentionMatrixToMatrixFunc | None – either the function to decorate (in the no-arguments case) or None when used with arguments
- fmt : MatrixSaveFormat, keyword-only – the format to save the matrix as. defaults to MATRIX_SAVE_FMT
- normalize : bool, keyword-only – whether to normalize the matrix to the range [0, 1]. defaults to False
- cmap : str, keyword-only – the colormap to use for the matrix. defaults to MATRIX_SAVE_CMAP
- diverging_colormap : bool – if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- normalize_min : float | None – if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero). if None, then the minimum value of the matrix is used. if diverging_colormap=True OR normalize=False, this must be None (defaults to None)

Returns:

- AttentionMatrixFigureFunc – if func is an AttentionMatrixToMatrixFunc (no-arguments case)
- Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc] – if func is None (with-arguments case); returns the decorator, which will then be applied to the function

Example:

@save_matrix_wrapper
def identity_matrix(matrix):
    return matrix

@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
    return matrix * 2

@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix(matrix):
    return matrix * 2

pattern_lens.figures
code for generating figures from attention patterns, using the
functions decorated with register_attn_figure_func
class HTConfigMock:

mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json

can be initialized with any kwargs, and will update its __dict__ with them. does, however, require the following attributes:

- n_layers: int
- n_heads: int
- model_name: str

we do this to avoid having to import torch and transformer_lens, since this would have to be done for each process in the parallelization, and probably slows things down significantly

HTConfigMock(**kwargs: dict[str, str | int])

will pass all kwargs to __dict__

n_layers: int
n_heads: int
model_name: str

def serialize(self) -> dict

serialize the config to json. values which aren't serializable will be converted via muutils.json_serialize.json_serialize

def load(cls, data: dict) -> pattern_lens.figures.HTConfigMock

try to load a config from a dict, using the __init__ method
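The kwargs-to-__dict__ trick that HTConfigMock describes is small enough to sketch standalone. The class below is a stand-in for illustration, not the library code (it omits the muutils-based serialization):

```python
class ConfigMock:
    """minimal mock config: accepts arbitrary kwargs, exposes them as attributes."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def serialize(self) -> dict:
        # return a plain dict snapshot of the attributes
        return dict(self.__dict__)

    @classmethod
    def load(cls, data: dict) -> "ConfigMock":
        # rebuild a mock from a dict, e.g. one parsed from config json
        return cls(**data)

cfg = ConfigMock(n_layers=12, n_heads=12, model_name="gpt2")
assert cfg.n_layers == 12 and cfg.model_name == "gpt2"

# round-trips through a plain dict, like loading config json
cfg2 = ConfigMock.load(cfg.serialize())
assert cfg2.n_heads == 12
```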
def process_single_head(
layer_idx: int,
head_idx: int,
attn_pattern: jaxtyping.Float[ndarray, 'n_ctx n_ctx'],
save_dir: pathlib.Path,
figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]],
force_overwrite: bool = False
) -> dict[str, bool | Exception]

process a single head's attention pattern, running all the functions in figure_funcs on the attention pattern

gotcha: if force_overwrite is False and we used a multi-figure function, it will skip all figures for that function if any are already saved. it assumes a format of {func_name}.{figure_name}.{fmt} for the saved figures

Parameters:

- layer_idx : int
- head_idx : int
- attn_pattern : AttentionMatrix – attention pattern for the head
- save_dir : Path – directory to save the figures to
- figure_funcs : list[AttentionMatrixFigureFunc] – list of functions to run
- force_overwrite : bool – whether to overwrite existing figures. if False, will skip any functions which have already saved a figure (defaults to False)

Returns:

- dict[str, bool | Exception] – a dictionary of the status of each function, with the function name as the key and the status as the value

def compute_and_save_figures(
model_cfg: 'HookedTransformerConfig|HTConfigMock',
activations_path: pathlib.Path,
cache: dict[str, numpy.ndarray] | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx'],
figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]],
save_path: pathlib.Path = PosixPath('attn_data'),
force_overwrite: bool = False,
track_results: bool = False
) -> None

compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

- model_cfg : HookedTransformerConfig | HTConfigMock – configuration of the model, used for loading the activations
- cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] – activation cache containing actual patterns for the prompt we are processing
- figure_funcs : list[AttentionMatrixFigureFunc] – list of functions to run
- save_path : Path – directory to save the figures to (defaults to Path(DATA_DIR))
- force_overwrite : bool – force overwrite of existing figures. if False, will skip any functions which have already saved a figure (defaults to False)
- track_results : bool – whether to track the results of each function for each head. isn't used for anything yet, but this is a TODO (defaults to False)

def process_prompt(
prompt: dict,
model_cfg: 'HookedTransformerConfig|HTConfigMock',
save_path: pathlib.Path,
figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]],
force_overwrite: bool = False
) -> None

process a single prompt, loading the activations and computing and saving the figures

basically just calls load_activations and then compute_and_save_figures

Parameters:

- prompt : dict – prompt to process, should be a dict with the following keys: "text" (the prompt string) and "hash" (the hash of the prompt)
- model_cfg : HookedTransformerConfig | HTConfigMock – configuration of the model, used for figuring out where to save
- save_path : Path – directory to save the figures to
- figure_funcs : list[AttentionMatrixFigureFunc] – list of functions to run
- force_overwrite : bool (defaults to False)

def select_attn_figure_funcs(
figure_funcs_select: set[str] | str | None = None
) -> list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]

given a selector, figure out which functions from ATTENTION_MATRIX_FIGURE_FUNCS to use. if None, will use all functions

def figures_main(
model_name: str,
save_path: str,
n_samples: int,
force: bool,
figure_funcs_select: set[str] | str | None = None,
parallel: bool | int = True
) -> None

main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

- model_name : str – model name to use, used for loading the model config, prompts, activations, and saving the figures
- save_path : str – base path to look in
- n_samples : int – max number of samples to process
- force : bool – force overwrite of existing figures. if False, will skip any functions which have already saved a figure
- figure_funcs_select : set[str] | str | None – figure functions to use. if None, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set (defaults to None)
- parallel : bool | int – whether to run in parallel. if True, will use all available cores. if False, will run in serial. if an int, will try to use that many cores (defaults to True)

def main() -> None

generates figures from the activations using the functions decorated with register_attn_figure_func
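The force/skip behavior described for process_single_head and figures_main boils down to an existence check on the expected output filename. A sketch under the {func_name}.{fmt} naming assumption (should_run is a hypothetical helper, not a library function):

```python
import tempfile
from pathlib import Path

def should_run(func_name: str, save_dir: Path, fmt: str = "svgz", force_overwrite: bool = False) -> bool:
    """run the figure function only if forced or its output file is missing."""
    return force_overwrite or not (save_dir / f"{func_name}.{fmt}").exists()

save_dir = Path(tempfile.mkdtemp())
assert should_run("raw", save_dir)  # nothing saved yet

# simulate a previously saved figure
(save_dir / "raw.svgz").write_bytes(b"")
assert not should_run("raw", save_dir)                     # skipped on re-run
assert should_run("raw", save_dir, force_overwrite=True)   # force overrides the skip
```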
pattern_lens.indexes
writes indexes to the model directory for the frontend to use or for record keeping
def generate_prompts_jsonl(model_dir: pathlib.Path) -> None

creates a prompts.jsonl file with all the prompts in the model directory
looks in all directories in {model_dir}/prompts for a
prompt.json file
def generate_models_jsonl(path: pathlib.Path) -> None

creates a models.jsonl file with all the models

def get_func_metadata(func: Callable) -> list[dict[str, str | None]]

get metadata for a function
Parameters:

- func : Callable – which has a _FIGURE_NAMES_KEY (by default _figure_names) attribute

Returns:

- list[dict[str, str | None]] – each dictionary is for a function, containing:
  - name : str – the name of the figure
  - func_name : str – the name of the function. if not a multi-figure function, this is identical to name; if it is a multi-figure function, then name is {func_name}.{figure_name}
  - doc : str – the docstring of the function
  - figure_save_fmt : str | None – the format of the figure that the function saves, using the figure_save_fmt attribute of the function. None if the attribute does not exist
  - source : str | None – the source file of the function
  - code : str | None – the source code of the function, split by line. None if the source file cannot be read

def generate_functions_jsonl(path: pathlib.Path) -> None

unions all functions from figures.jsonl and ATTENTION_MATRIX_FIGURE_FUNCS into the file

def write_html_index(path: pathlib.Path) -> None

writes index.html and single.html files to the path (version replacement handled by makefile)
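Function metadata of the kind get_func_metadata collects (name, docstring, source) can be pulled with the standard inspect module. A simplified sketch (the real function also handles multi-figure names and the figure_save_fmt attribute):

```python
import inspect

def basic_func_metadata(func) -> dict:
    """collect a subset of the metadata fields: name, doc, source file, and code lines."""
    try:
        source = inspect.getsourcefile(func)
        code = inspect.getsource(func).splitlines()
    except (OSError, TypeError):
        # source may be unavailable, e.g. for builtins or interactively defined functions
        source, code = None, None
    return {
        "name": func.__name__,
        "doc": inspect.getdoc(func),
        "source": source,
        "code": code,
    }

def example_figure(attn_matrix, path):
    """save a figure for attn_matrix under path."""
    pass

meta = basic_func_metadata(example_figure)
assert meta["name"] == "example_figure"
assert meta["doc"] == "save a figure for attn_matrix under path."
```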
pattern_lens.load_activations
loading activations from .npz on disk. implements some custom Exception classes
class GetActivationsError(builtins.ValueError):

base class for errors in getting activations

class ActivationsMissingError(GetActivationsError, builtins.FileNotFoundError):

error for missing activations – can't find the activations file

class ActivationsMismatchError(GetActivationsError):

error for mismatched activations – the prompt text or hash do not match. raised by compare_prompt_to_loaded

class InvalidPromptError(GetActivationsError):

error for an invalid prompt – the prompt does not have the fields "hash" or "text". raised by augment_prompt_with_hash
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None

compare a prompt to a loaded prompt, raise an error if they do not match

Parameters:

- prompt : dict
- prompt_loaded : dict

Returns: None

Raises:

- ActivationsMismatchError – if the prompt text or hash do not match

def augment_prompt_with_hash(prompt: dict) -> dict

if a prompt does not have a hash, add one. not having a "text" field is allowed, but only if "hash" is present

Parameters:

- prompt : dict

Returns:

- dict – the input prompt dictionary, with a "hash" key added if it did not have one
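The augmentation step can be sketched with hashlib. The actual hash function pattern_lens uses is not specified here, so the SHA-256 below is an assumption for illustration, and the function is a stand-in rather than the library's implementation:

```python
import hashlib

def augment_prompt_with_hash_sketch(prompt: dict) -> dict:
    """add a 'hash' key derived from the prompt text, if one is not already present."""
    if "hash" not in prompt:
        # assumption: hash the text field; the library's real hash scheme may differ
        prompt["hash"] = hashlib.sha256(prompt["text"].encode("utf-8")).hexdigest()
    return prompt

p = augment_prompt_with_hash_sketch({"text": "hello world"})
assert "hash" in p and len(p["hash"]) == 64

# an existing hash is left untouched, even without a "text" field
q = augment_prompt_with_hash_sketch({"hash": "abc"})
assert q["hash"] == "abc"
```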
def load_activations(
model_name: str,
prompt: dict,
save_path: pathlib.Path,
return_fmt: Literal[None, 'numpy', 'torch'] = 'torch'
) -> tuple[pathlib.Path, dict[str, torch.Tensor] | dict[str, numpy.ndarray]]

load activations for a prompt and model, from an npz file

Parameters:

- model_name : str
- prompt : dict
- save_path : Path
- return_fmt : Literal["torch", "numpy"] (defaults to "torch")

Returns:

- tuple[Path, dict[str, torch.Tensor] | dict[str, np.ndarray]] – the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on return_fmt

Raises:

- ActivationsMissingError – if the activations file is missing
- ValueError – if return_fmt is not "torch" or "numpy"

pattern_lens.prompts
implements load_text_data for loading prompts
def load_text_data(
fname: pathlib.Path,
min_chars: int | None = None,
max_chars: int | None = None,
shuffle: bool = False
) -> list[dict]

given fname, the path to a jsonl file, split prompts up into more reasonable sizes

Parameters:

- fname : Path – jsonl file with prompts. expects a list of dicts with a "text" key
- min_chars : int | None (defaults to None)
- max_chars : int | None (defaults to None)
- shuffle : bool (defaults to False)

Returns:

- list[dict] – processed list of prompts. each prompt has a "text" key with a string value and some metadata. this is not guaranteed to be the same length as the input list!

pattern_lens.server
cli for starting the server to show the web ui

can also run with --rewrite-index to update the index.html file. this is useful for working on the ui.

def main(path: str | None = None, port: int = 8000) -> None

move to the given path and start the server