docs for pattern_lens v0.2.0
View Source on GitHub

pattern_lens.figures

code for generating figures from attention patterns, using the functions decorated with register_attn_figure_func
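
A minimal end-to-end sketch is shown below. The contract for a figure function is taken from process_single_head further down (each registered function is called as func(attn_pattern, save_dir) and should write a file named <function name>.* so that existing figures can be detected); treating register_attn_figure_func as a bare decorator, and the model name and output path, are assumptions for illustration.

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np

    from pattern_lens.attn_figure_funcs import register_attn_figure_func
    from pattern_lens.figures import figures_main

    @register_attn_figure_func  # assumed usable as a bare decorator
    def demo_heatmap(attn_pattern: np.ndarray, save_dir: Path) -> None:
        # called as func(attn_pattern, save_dir); the saved file should be
        # named "<function name>.*" so existing figures can be skipped later
        plt.matshow(attn_pattern, cmap="viridis")
        plt.savefig(save_dir / "demo_heatmap.png")
        plt.close()

    # generate figures for prompts whose activations were already saved under attn_data/gpt2/
    figures_main(model_name="gpt2", save_path="attn_data", n_samples=10, force=False)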


  1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""
  2
  3import argparse
  4from collections import defaultdict
  5import functools
  6import itertools
  7import json
  8import warnings
  9from pathlib import Path
 10
 11from muutils.json_serialize import json_serialize
 12from muutils.spinner import SpinnerContext
 13from muutils.parallel import run_maybe_parallel
 14
 15from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
 16from pattern_lens.consts import (
 17    DATA_DIR,
 18    AttentionMatrix,
 19    SPINNER_KWARGS,
 20    ActivationCacheNp,
 21    DIVIDER_S1,
 22    DIVIDER_S2,
 23)
 24from pattern_lens.indexes import (
 25    generate_functions_jsonl,
 26    generate_models_jsonl,
 27    generate_prompts_jsonl,
 28)
 29from pattern_lens.load_activations import load_activations
 30
 31
 32class HTConfigMock:
 33    """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json
 34
 35    can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
 36    - `n_layers: int`
 37    - `n_heads: int`
 38    - `model_name: str`
 39    """
 40
 41    def __init__(self, **kwargs):
 42        self.n_layers: int
 43        self.n_heads: int
 44        self.model_name: str
 45        self.__dict__.update(kwargs)
 46
 47    def serialize(self):
 48        """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
 49        return json_serialize(self.__dict__)
 50
 51    @classmethod
 52    def load(cls, data: dict):
 53        "try to load a config from a dict, using the `__init__` method"
 54        return cls(**data)
 55
 56
 57def process_single_head(
 58    layer_idx: int,
 59    head_idx: int,
 60    attn_pattern: AttentionMatrix,
 61    save_dir: Path,
 62    force_overwrite: bool = False,
 63) -> dict[str, bool | Exception]:
 64    """process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern
 65
 66    # Parameters:
 67     - `layer_idx : int`
 68     - `head_idx : int`
 69     - `attn_pattern : AttentionMatrix`
 70        attention pattern for the head
 71     - `save_dir : Path`
 72        directory to save the figures to
 73     - `force_overwrite : bool`
 74        whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
 75       (defaults to `False`)
 76
 77    # Returns:
 78     - `dict[str, bool | Exception]`
 79        a dictionary of the status of each function, with the function name as the key and the status as the value
 80    """
 81    funcs_status: dict[str, bool | Exception] = dict()
 82
 83    for func in ATTENTION_MATRIX_FIGURE_FUNCS:
 84        func_name: str = func.__name__
 85        fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))
 86
 87        if not force_overwrite and len(fig_path) > 0:
 88            funcs_status[func_name] = True
 89            continue
 90
 91        try:
 92            func(attn_pattern, save_dir)
 93            funcs_status[func_name] = True
 94
 95        except Exception as e:
 96            error_file = save_dir / f"{func.__name__}.error.txt"
 97            error_file.write_text(str(e))
 98            warnings.warn(
 99                f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {str(e)}"
100            )
101            funcs_status[func_name] = e
102
103    return funcs_status
104
105
106def compute_and_save_figures(
107    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
108    activations_path: Path,
109    cache: ActivationCacheNp,
110    save_path: Path = Path(DATA_DIR),
111    force_overwrite: bool = False,
112    track_results: bool = False,
113) -> None:
114    """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
115
116    # Parameters:
117     - `model_cfg : HookedTransformerConfig|HTConfigMock`
118     - `cache : ActivationCacheNp`
119     - `save_path : Path`
120       (defaults to `Path(DATA_DIR)`)
121     - `force_overwrite : bool`
122        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
123       (defaults to `False`)
124     - `track_results : bool`
125        whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
126       (defaults to `False`)
127    """
128    prompt_dir: Path = activations_path.parent
129
130    if track_results:
131        results: defaultdict[
132            str,  # func name
133            dict[
134                tuple[int, int],  # layer, head
135                bool | Exception,  # success or exception
136            ],
137        ] = defaultdict(dict)
138
139    for layer_idx, head_idx in itertools.product(
140        range(model_cfg.n_layers),
141        range(model_cfg.n_heads),
142    ):
143        attn_pattern: AttentionMatrix = cache[f"blocks.{layer_idx}.attn.hook_pattern"][
144            0, head_idx
145        ]
146        save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
147        save_dir.mkdir(parents=True, exist_ok=True)
148        head_res: dict[str, bool | Exception] = process_single_head(
149            layer_idx=layer_idx,
150            head_idx=head_idx,
151            attn_pattern=attn_pattern,
152            save_dir=save_dir,
153            force_overwrite=force_overwrite,
154        )
155
156        if track_results:
157            for func_name, status in head_res.items():
158                results[func_name][(layer_idx, head_idx)] = status
159
160    # TODO: do something with results
161
162    generate_prompts_jsonl(save_path / model_cfg.model_name)
163
164
165def process_prompt(
166    prompt: dict,
167    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
168    save_path: Path,
169    force_overwrite: bool = False,
170) -> None:
171    """process a single prompt, loading the activations and computing and saving the figures
172
173    basically just calls `load_activations` and then `compute_and_save_figures`
174
175    # Parameters:
176     - `prompt : dict`
177     - `model_cfg : HookedTransformerConfig|HTConfigMock`
178     - `force_overwrite : bool`
179       (defaults to `False`)
180    """
181    activations_path: Path
182    cache: ActivationCacheNp
183    activations_path, cache = load_activations(
184        model_name=model_cfg.model_name,
185        prompt=prompt,
186        save_path=save_path,
187        return_fmt="numpy",
188    )
189
190    compute_and_save_figures(
191        model_cfg=model_cfg,
192        activations_path=activations_path,
193        cache=cache,
194        save_path=save_path,
195        force_overwrite=force_overwrite,
196    )
197
198
199def figures_main(
200    model_name: str,
201    save_path: str,
202    n_samples: int,
203    force: bool,
204    parallel: bool | int = True,
205) -> None:
206    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
207
208    # Parameters:
209     - `model_name : str`
210        model name to use, used for loading the model config, prompts, activations, and saving the figures
211     - `save_path : str`
212        base path to look in
213     - `n_samples : int`
214        max number of samples to process
215     - `force : bool`
216        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
217     - `parallel : bool | int`
218        whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
219       (defaults to `True`)
220    """
221    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
222        # save model info or check if it exists
223        save_path_p: Path = Path(save_path)
224        model_path: Path = save_path_p / model_name
225        with open(model_path / "model_cfg.json", "r") as f:
226            model_cfg = HTConfigMock.load(json.load(f))
227
228    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
229        # load prompts
230        with open(model_path / "prompts.jsonl", "r") as f:
231            prompts: list[dict] = [json.loads(line) for line in f.readlines()]
232        # truncate to n_samples
233        prompts = prompts[:n_samples]
234
235    print(f"{len(prompts)} prompts loaded")
236
237    print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded")
238    print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS]))
239
240    list(
241        run_maybe_parallel(
242            func=functools.partial(
243                process_prompt,
244                model_cfg=model_cfg,
245                save_path=save_path_p,
246                force_overwrite=force,
247            ),
248            iterable=prompts,
249            parallel=parallel,
250            pbar="tqdm",
251            pbar_kwargs=dict(
252                desc="Making figures",
253                unit="prompt",
254            ),
255        )
256    )
257
258    with SpinnerContext(
259        message="updating jsonl metadata for models and functions", **SPINNER_KWARGS
260    ):
261        generate_models_jsonl(save_path_p)
262        generate_functions_jsonl(save_path_p)
263
264
265def main():
266    print(DIVIDER_S1)
267    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
268        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
269        # input and output
270        arg_parser.add_argument(
271            "--model",
272            "-m",
273            type=str,
274            required=True,
275            help="The model name(s) to use. comma separated with no whitespace if multiple",
276        )
277        arg_parser.add_argument(
278            "--save-path",
279            "-s",
280            type=str,
281            required=False,
282            help="The path to save the attention patterns",
283            default=DATA_DIR,
284        )
285        # number of samples
286        arg_parser.add_argument(
287            "--n-samples",
288            "-n",
289            type=int,
290            required=False,
291            help="The max number of samples to process, do all in the file if None",
292            default=None,
293        )
294        # force overwrite of existing figures
295        arg_parser.add_argument(
296            "--force",
297            "-f",
 298            action="store_true",
299            required=False,
300            help="Force overwrite of existing figures",
301            default=False,
302        )
303
304        args: argparse.Namespace = arg_parser.parse_args()
305
306    print(f"args parsed: {args}")
307
308    models: list[str]
309    if "," in args.model:
310        models = args.model.split(",")
311    else:
312        models = [args.model]
313
314    n_models: int = len(models)
315    for idx, model in enumerate(models):
316        print(DIVIDER_S2)
317        print(f"processing model {idx+1} / {n_models}: {model}")
318        print(DIVIDER_S2)
319        figures_main(
320            model_name=model,
321            save_path=args.save_path,
322            n_samples=args.n_samples,
323            force=args.force,
324        )
325
326    print(DIVIDER_S1)
327
328
329if __name__ == "__main__":
330    main()

class HTConfigMock:

Mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json

can be initialized with any kwargs, and will update its __dict__ with them. does, however, require the following attributes:

  • n_layers: int
  • n_heads: int
  • model_name: str
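
A small construction sketch (the kwargs shown are illustrative):

    from pattern_lens.figures import HTConfigMock

    cfg = HTConfigMock(n_layers=12, n_heads=12, model_name="gpt2", d_model=768)
    print(cfg.n_layers, cfg.n_heads, cfg.model_name)  # 12 12 gpt2
    print(cfg.d_model)  # extra kwargs are kept on the instance as well: 768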
HTConfigMock(**kwargs)
n_layers: int
n_heads: int
model_name: str
def serialize(self):

serialize the config to json. values which aren't serializable will be converted via muutils.json_serialize.json_serialize

@classmethod
def load(cls, data: dict):

try to load a config from a dict, using the __init__ method
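
Together with serialize, this gives a simple round trip (sketch):

    import json

    from pattern_lens.figures import HTConfigMock

    cfg = HTConfigMock(n_layers=2, n_heads=4, model_name="tiny-demo")
    as_json: str = json.dumps(cfg.serialize())      # serialize() output is JSON-safe
    cfg_2 = HTConfigMock.load(json.loads(as_json))  # load() forwards the dict to __init__
    assert cfg_2.model_name == cfg.model_name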

def process_single_head( layer_idx: int, head_idx: int, attn_pattern: jaxtyping.Float[ndarray, 'n_ctx n_ctx'], save_dir: pathlib.Path, force_overwrite: bool = False) -> dict[str, bool | Exception]:

process a single head's attention pattern, running all the functions in ATTENTION_MATRIX_FIGURE_FUNCS on the attention pattern

Parameters:

  • layer_idx : int
  • head_idx : int
  • attn_pattern : AttentionMatrix
    attention pattern for the head
  • save_dir : Path
    directory to save the figures to
  • force_overwrite : bool
    whether to overwrite existing figures. if False, will skip any functions which have already saved a figure (defaults to False)

Returns:

  • dict[str, bool | Exception]
    a dictionary of the status of each function, with the function name as the key and the status as the value
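
A direct-call sketch on a synthetic attention pattern; the output directory here is arbitrary (in normal use it is the L{layer}/H{head} directory created by compute_and_save_figures):

    from pathlib import Path

    import numpy as np

    from pattern_lens.figures import process_single_head

    # synthetic causal attention pattern for a 10-token prompt, rows summing to 1
    n_ctx = 10
    pattern = np.tril(np.random.rand(n_ctx, n_ctx))
    pattern /= pattern.sum(axis=1, keepdims=True)

    save_dir = Path("demo_figures/L0/H0")
    save_dir.mkdir(parents=True, exist_ok=True)

    status = process_single_head(
        layer_idx=0,
        head_idx=0,
        attn_pattern=pattern,
        save_dir=save_dir,
        force_overwrite=True,
    )
    # True means the figure exists (just saved or already present); an Exception means that function failed
    print(status)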
def compute_and_save_figures( model_cfg: 'HookedTransformerConfig|HTConfigMock', activations_path: pathlib.Path, cache: dict[str, numpy.ndarray], save_path: pathlib.Path = Path('attn_data'), force_overwrite: bool = False, track_results: bool = False) -> None:

compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

  • model_cfg : HookedTransformerConfig|HTConfigMock
  • activations_path : Path
    path to the prompt's activations file; figures are written next to it (only its parent directory is used)
  • cache : ActivationCacheNp
  • save_path : Path
    (defaults to Path(DATA_DIR))
  • force_overwrite : bool
    force overwrite of existing figures. if False, will skip any functions which have already saved a figure (defaults to False)
  • track_results : bool
    whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO (defaults to False)
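
A sketch with a synthetic cache. The on-disk layout and model config here are made up; in normal use activations_path and cache come straight from load_activations, and the trailing prompts.jsonl regeneration assumes the usual directory structure under save_path:

    from pathlib import Path

    import numpy as np

    from pattern_lens.figures import HTConfigMock, compute_and_save_figures

    cfg = HTConfigMock(n_layers=2, n_heads=4, model_name="tiny-demo")
    n_ctx = 8

    # one hook_pattern per layer, shaped (batch=1, n_heads, n_ctx, n_ctx)
    cache = {
        f"blocks.{i}.attn.hook_pattern": np.random.rand(1, cfg.n_heads, n_ctx, n_ctx)
        for i in range(cfg.n_layers)
    }

    # figures land next to the activations file, under <parent>/L{layer}/H{head}/
    activations_path = Path("attn_data") / cfg.model_name / "demo_prompt" / "activations.npz"
    activations_path.parent.mkdir(parents=True, exist_ok=True)

    compute_and_save_figures(
        model_cfg=cfg,
        activations_path=activations_path,
        cache=cache,
        save_path=Path("attn_data"),
    )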
def process_prompt( prompt: dict, model_cfg: 'HookedTransformerConfig|HTConfigMock', save_path: pathlib.Path, force_overwrite: bool = False) -> None:

process a single prompt, loading the activations and computing and saving the figures

basically just calls load_activations and then compute_and_save_figures

Parameters:

  • prompt : dict
  • model_cfg : HookedTransformerConfig|HTConfigMock
  • save_path : Path
    base path to load the activations from and save the figures under
  • force_overwrite : bool
    (defaults to False)
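
A sketch mirroring what figures_main does per prompt; the model name is illustrative, and the activations and metadata files are assumed to have been generated already:

    import json
    from pathlib import Path

    from pattern_lens.figures import HTConfigMock, process_prompt

    save_path = Path("attn_data")
    model_path = save_path / "gpt2"

    with open(model_path / "model_cfg.json", "r") as f:
        model_cfg = HTConfigMock.load(json.load(f))

    # take the first prompt recorded alongside the activations
    with open(model_path / "prompts.jsonl", "r") as f:
        prompt: dict = json.loads(f.readline())

    process_prompt(prompt=prompt, model_cfg=model_cfg, save_path=save_path, force_overwrite=False)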
def figures_main( model_name: str, save_path: str, n_samples: int, force: bool, parallel: bool | int = True) -> None:

main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

  • model_name : str
    model name to use, used for loading the model config, prompts, activations, and saving the figures
  • save_path : str
    base path to look in
  • n_samples : int
    max number of samples to process
  • force : bool
    force overwrite of existing figures. if False, will skip any functions which have already saved a figure
  • parallel : bool | int
    whether to run in parallel. if True, will use all available cores. if False, will run in serial. if an int, will try to use that many cores (defaults to True)
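
For example (model name illustrative; the model config, prompts, and activations must already exist under save_path):

    from pattern_lens.figures import figures_main

    figures_main(
        model_name="gpt2",
        save_path="attn_data",
        n_samples=100,
        force=False,
        parallel=4,  # True = all cores, False = serial, int = try to use that many cores
    )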
def main():
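
CLI entry point: parses --model (comma-separated, no whitespace, for multiple models), --save-path, --n-samples, and --force, then calls figures_main once per model. A typical invocation (model names illustrative) might look like:

    python -m pattern_lens.figures --model gpt2,pythia-70m --save-path attn_data --n-samples 50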