docs for pattern_lens v0.2.0

pattern_lens.activations

computing and saving activations given a model and prompts

Usage:

from the command line:

python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>

from a script:

from pattern_lens.activations import activations_main
activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
    raw_prompts=True,
    min_chars=100,
    max_chars=1000,
    force=False,
    n_samples=100,
    no_index_html=False,
)
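
The activations are saved as compressed `.npz` files, one per prompt. A minimal sketch of loading one back (the path shape follows `compute_activations`; the `<prompt_hash>` component is whatever `augment_prompt_with_hash` assigned to the prompt):

```python
import numpy as np

# path shape: <save_path>/<model_name>/prompts/<prompt_hash>/activations.npz
with np.load("demo/gpt2/prompts/<prompt_hash>/activations.npz") as npz:
    for name in npz.files:
        # e.g. blocks.0.attn.hook_pattern (1, n_heads, n_tokens, n_tokens)
        print(name, npz[name].shape)
```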

  1"""computing and saving activations given a model and prompts
  2
  3
  4# Usage:
  5
  6from the command line:
  7
  8```bash
  9python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
 10```
 11
 12from a script:
 13
 14```python
 15from pattern_lens.activations import activations_main
 16activations_main(
 17    model_name="gpt2",
 18    save_path="demo/"
 19    prompts_path="data/pile_1k.jsonl",
 20)
 21```
 22
 23"""
 24
 25import argparse
 26import functools
 27import json
 28from dataclasses import asdict
 29from pathlib import Path
 30import re
 31from typing import Callable, Literal, overload
 32
 33import numpy as np
 34import torch
 35import tqdm
 36from muutils.spinner import SpinnerContext
 37from muutils.misc.numerical import shorten_numerical_to_str
 38from muutils.json_serialize import json_serialize
 39from transformer_lens import HookedTransformer, HookedTransformerConfig  # type: ignore[import-untyped]
 40
 41from pattern_lens.consts import (
 42    ATTN_PATTERN_REGEX,
 43    DATA_DIR,
 44    ActivationCacheNp,
 45    SPINNER_KWARGS,
 46    DIVIDER_S1,
 47    DIVIDER_S2,
 48)
 49from pattern_lens.indexes import (
 50    generate_models_jsonl,
 51    generate_prompts_jsonl,
 52    write_html_index,
 53)
 54from pattern_lens.load_activations import (
 55    ActivationsMissingError,
 56    augment_prompt_with_hash,
 57    load_activations,
 58)
 59from pattern_lens.prompts import load_text_data
 60
 61
 62def compute_activations(
 63    prompt: dict,
 64    model: HookedTransformer | None = None,
 65    save_path: Path = Path(DATA_DIR),
 66    return_cache: bool = True,
 67    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
 68) -> tuple[Path, ActivationCacheNp | None]:
 69    """get activations for a given model and prompt, possibly from a cache
 70
 71    if from a cache, prompt_meta must be passed and contain the prompt hash
 72
 73    # Parameters:
 74     - `prompt : dict | None`
 75       (defaults to `None`)
 76     - `model : HookedTransformer`
 77     - `save_path : Path`
 78       (defaults to `Path(DATA_DIR)`)
 79     - `return_cache : bool`
 80       will return `None` as the second element if `False`
 81       (defaults to `True`)
 82     - `names_filter : Callable[[str], bool]|re.Pattern`
 83       a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
 84       (defaults to `ATTN_PATTERN_REGEX`)
 85
 86    # Returns:
 87     - `tuple[Path, ActivationCacheNp|None]`
 88    """
 89    assert model is not None, "model must be passed"
 90    assert "text" in prompt, "prompt must contain 'text' key"
 91    prompt_str: str = prompt["text"]
 92
 93    # compute or get prompt metadata
 94    prompt_tokenized: list[str] = prompt.get(
 95        "tokens",
 96        model.tokenizer.tokenize(prompt_str),
 97    )
 98    prompt.update(
 99        dict(
100            n_tokens=len(prompt_tokenized),
101            tokens=prompt_tokenized,
102        )
103    )
104
105    # save metadata
106    prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"]
107    prompt_dir.mkdir(parents=True, exist_ok=True)
108    with open(prompt_dir / "prompt.json", "w") as f:
109        json.dump(prompt, f)
110
111    # set up names filter
112    names_filter_fn: Callable[[str], bool]
113    if isinstance(names_filter, re.Pattern):
114        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
115    else:
116        names_filter_fn = names_filter
117
118    # compute activations
119    with torch.no_grad():
120        model.eval()
121        # TODO: batching?
122        _, cache = model.run_with_cache(
123            prompt_str,
124            names_filter=names_filter_fn,
125            return_type=None,
126        )
127
128    cache_np: ActivationCacheNp = {
129        k: v.detach().cpu().numpy() for k, v in cache.items()
130    }
131
132    # save activations
133    activations_path: Path = prompt_dir / "activations.npz"
134    np.savez_compressed(
135        activations_path,
136        **cache_np,
137    )
138
139    # return path and cache
140    if return_cache:
141        return activations_path, cache_np
142    else:
143        return activations_path, None
144
145
146@overload
147def get_activations(
148    prompt: dict,
149    model: HookedTransformer | str,
150    save_path: Path = Path(DATA_DIR),
151    allow_disk_cache: bool = True,
152    return_cache: Literal[False] = False,
153) -> tuple[Path, None]: ...
154@overload
155def get_activations(
156    prompt: dict,
157    model: HookedTransformer | str,
158    save_path: Path = Path(DATA_DIR),
159    allow_disk_cache: bool = True,
160    return_cache: Literal[True] = True,
161) -> tuple[Path, ActivationCacheNp]: ...
162def get_activations(
163    prompt: dict,
164    model: HookedTransformer | str,
165    save_path: Path = Path(DATA_DIR),
166    allow_disk_cache: bool = True,
167    return_cache: bool = True,
168) -> tuple[Path, ActivationCacheNp | None]:
169    """given a prompt and a model, save or load activations
170
171    # Parameters:
172     - `prompt : dict`
173        expected to contain the 'text' key
174     - `model : HookedTransformer | str`
175        either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
176     - `save_path : Path`
177        path to save the activations to (and load from)
178       (defaults to `Path(DATA_DIR)`)
179     - `allow_disk_cache : bool`
180        whether to allow loading from disk cache
181       (defaults to `True`)
182     - `return_cache : bool`
183        whether to return the cache. if `False`, will return `None` as the second element
184       (defaults to `True`)
185
186    # Returns:
187     - `tuple[Path, ActivationCacheNp | None]`
188         the path to the activations and the cache if `return_cache` is `True`
189
190    """
191    # add hash to prompt
192    augment_prompt_with_hash(prompt)
193
194    # get the model
195    model_name: str = (
196        model.model_name if isinstance(model, HookedTransformer) else model
197    )
198
199    # from cache
200    if allow_disk_cache:
201        try:
202            path, cache = load_activations(
203                model_name=model_name,
204                prompt=prompt,
205                save_path=save_path,
206            )
207            if return_cache:
208                return path, cache
209            else:
210                return path, None
211        except ActivationsMissingError:
212            pass
213
214    # compute them
215    if isinstance(model, str):
216        model = HookedTransformer.from_pretrained(model_name)
217
218    return compute_activations(
219        prompt=prompt,
220        model=model,
221        save_path=save_path,
222        return_cache=True,
223    )
224
225
226def activations_main(
227    model_name: str,
228    save_path: str,
229    prompts_path: str,
230    raw_prompts: bool,
231    min_chars: int,
232    max_chars: int,
233    force: bool,
234    n_samples: int,
235    no_index_html: bool,
236    shuffle: bool = False,
237    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
238) -> None:
239    """main function for computing activations
240
241    # Parameters:
242     - `model_name : str`
243        name of a model to load with `HookedTransformer.from_pretrained`
244     - `save_path : str`
245        path to save the activations to
246     - `prompts_path : str`
247        path to the prompts file
248     - `raw_prompts : bool`
249        whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
250     - `min_chars : int`
251        minimum number of characters for a prompt
252     - `max_chars : int`
253        maximum number of characters for a prompt
254     - `force : bool`
255        whether to overwrite existing files
256     - `n_samples : int`
257        maximum number of samples to process
258     - `no_index_html : bool`
259        whether to write an index.html file
260     - `shuffle : bool`
261        whether to shuffle the prompts
262       (defaults to `False`)
263     - `device : str | torch.device`
264        the device to use. if a string, will be passed to `torch.device`
265    """
266
267    # figure out the device to use
268    device_: torch.device
269    if isinstance(device, torch.device):
270        device_ = device
271    elif isinstance(device, str):
272        device_ = torch.device(device)
273    else:
274        raise ValueError(f"invalid device: {device}")
275
276    print(f"using device: {device_}")
277
278    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
279        model: HookedTransformer = HookedTransformer.from_pretrained(
280            model_name, device=device_
281        )
282        model.model_name = model_name
283        model.cfg.model_name = model_name
284        n_params: int = sum(p.numel() for p in model.parameters())
285    print(
286        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters"
287    )
288    print(f"\tmodel devices: {set(p.device for p in model.parameters())}")
289
290    save_path_p: Path = Path(save_path)
291    save_path_p.mkdir(parents=True, exist_ok=True)
292    model_path: Path = save_path_p / model_name
293    with SpinnerContext(
294        message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS
295    ):
296        model_cfg: HookedTransformerConfig
297        model_cfg = model.cfg
298        model_path.mkdir(parents=True, exist_ok=True)
299        with open(model_path / "model_cfg.json", "w") as f:
300            json.dump(json_serialize(asdict(model_cfg)), f)
301
302    # load prompts
303    with SpinnerContext(
304        message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS
305    ):
306        prompts: list[dict]
307        if raw_prompts:
308            prompts = load_text_data(
309                Path(prompts_path),
310                min_chars=min_chars,
311                max_chars=max_chars,
312                shuffle=shuffle,
313            )
314        else:
315            with open(model_path / "prompts.jsonl", "r") as f:
316                prompts = [json.loads(line) for line in f.readlines()]
317        # truncate to n_samples
318        prompts = prompts[:n_samples]
319
320    print(f"{len(prompts)} prompts loaded")
321
322    # write index.html
323    with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
324        if not no_index_html:
325            write_html_index(save_path_p)
326
327    # get activations
328    list(
329        tqdm.tqdm(
330            map(
331                functools.partial(
332                    get_activations,
333                    model=model,
334                    save_path=save_path_p,
335                    allow_disk_cache=not force,
336                    return_cache=False,
337                ),
338                prompts,
339            ),
340            total=len(prompts),
341            desc="Computing activations",
342            unit="prompt",
343        )
344    )
345
346    with SpinnerContext(
347        message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS
348    ):
349        generate_models_jsonl(save_path_p)
350        generate_prompts_jsonl(save_path_p / model_name)
351
352
353def main():
354    print(DIVIDER_S1)
355    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
356        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
357        # input and output
358        arg_parser.add_argument(
359            "--model",
360            "-m",
361            type=str,
362            required=True,
363            help="The model name(s) to use. comma separated with no whitespace if multiple",
364        )
365
366        arg_parser.add_argument(
367            "--prompts",
368            "-p",
369            type=str,
370            required=False,
371            help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
372            default=None,
373        )
374
375        arg_parser.add_argument(
376            "--save-path",
377            "-s",
378            type=str,
379            required=False,
380            help="The path to save the attention patterns",
381            default=DATA_DIR,
382        )
383
384        # min and max prompt lengths
385        arg_parser.add_argument(
386            "--min-chars",
387            type=int,
388            required=False,
389            help="The minimum number of characters for a prompt",
390            default=100,
391        )
392        arg_parser.add_argument(
393            "--max-chars",
394            type=int,
395            required=False,
396            help="The maximum number of characters for a prompt",
397            default=1000,
398        )
399
400        # number of samples
401        arg_parser.add_argument(
402            "--n-samples",
403            "-n",
404            type=int,
405            required=False,
406            help="The max number of samples to process, do all in the file if None",
407            default=None,
408        )
409
410        # force overwrite
411        arg_parser.add_argument(
412            "--force",
413            "-f",
414            action="store_true",
415            help="If passed, will overwrite existing files",
416        )
417
418        # no index html
419        arg_parser.add_argument(
420            "--no-index-html",
421            action="store_true",
422            help="If passed, will not write an index.html file for the model",
423        )
424
425        # raw prompts
426        arg_parser.add_argument(
427            "--raw-prompts",
428            "-r",
429            action="store_true",
430            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
431        )
432
433        # shuffle
434        arg_parser.add_argument(
435            "--shuffle",
436            action="store_true",
437            help="If passed, will shuffle the prompts",
438        )
439
440        # device
441        arg_parser.add_argument(
442            "--device",
443            type=str,
444            required=False,
445            help="The device to use for the model",
446            default="cuda" if torch.cuda.is_available() else "cpu",
447        )
448
449        args: argparse.Namespace = arg_parser.parse_args()
450
451    print(f"args parsed: {args}")
452
453    models: list[str]
454    if "," in args.model:
455        models = args.model.split(",")
456    else:
457        models = [args.model]
458
459    n_models: int = len(models)
460    for idx, model in enumerate(models):
461        print(DIVIDER_S2)
462        print(f"processing model {idx+1} / {n_models}: {model}")
463        print(DIVIDER_S2)
464
465        activations_main(
466            model_name=model,
467            save_path=args.save_path,
468            prompts_path=args.prompts,
469            raw_prompts=args.raw_prompts,
470            min_chars=args.min_chars,
471            max_chars=args.max_chars,
472            force=args.force,
473            n_samples=args.n_samples,
474            no_index_html=args.no_index_html,
475            shuffle=args.shuffle,
476            device=args.device,
477        )
478
479    print(DIVIDER_S1)
480
481
482if __name__ == "__main__":
483    main()

def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path("attn_data"),
    return_cache: bool = True,
    names_filter: Callable[[str], bool] | re.Pattern = re.compile(r"blocks\.(\d+)\.attn\.hook_pattern"),
) -> tuple[Path, ActivationCacheNp | None]:

compute and save activations for a given model and prompt

Parameters:

  • prompt : dict must contain the 'text' and 'hash' keys (the hash is normally added by augment_prompt_with_hash)
  • model : HookedTransformer | None must be provided; the None default only satisfies the signature and fails an assertion (defaults to None)
  • save_path : Path (defaults to Path(DATA_DIR))
  • return_cache : bool will return None as the second element if False (defaults to True)
  • names_filter : Callable[[str], bool]|re.Pattern a filter for the names of the activations to return. if an re.Pattern, will use lambda key: names_filter.match(key) is not None (defaults to ATTN_PATTERN_REGEX)

Returns:

  • tuple[Path, ActivationCacheNp|None] the path to the saved activations and, if return_cache is True, the cache
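
As a usage sketch (the model name `gpt2` and the example prompt are placeholders; note that `compute_activations` reads `prompt["hash"]`, which `get_activations` normally adds via `augment_prompt_with_hash`):

```python
import re

from transformer_lens import HookedTransformer

from pattern_lens.activations import compute_activations
from pattern_lens.load_activations import augment_prompt_with_hash

model = HookedTransformer.from_pretrained("gpt2")

# compute_activations expects the prompt hash to already be present
prompt = {"text": "The quick brown fox jumps over the lazy dog."}
augment_prompt_with_hash(prompt)

# cache only attention patterns, mirroring the default ATTN_PATTERN_REGEX
path, cache = compute_activations(
    prompt=prompt,
    model=model,
    names_filter=re.compile(r"blocks\.(\d+)\.attn\.hook_pattern"),
)
print(path)  # <save_path>/gpt2/prompts/<hash>/activations.npz
print(list(cache)[:1])  # e.g. ['blocks.0.attn.hook_pattern']
```

On disk this writes `prompt.json` alongside `activations.npz` under `<save_path>/<model_name>/prompts/<hash>/`.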
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path("attn_data"),
    allow_disk_cache: bool = True,
    return_cache: bool = True,
) -> tuple[Path, ActivationCacheNp | None]:

given a prompt and a model, save or load activations

Parameters:

  • prompt : dict expected to contain the 'text' key
  • model : HookedTransformer | str either a HookedTransformer or a string model name, to be loaded with HookedTransformer.from_pretrained
  • save_path : Path path to save the activations to (and load from) (defaults to Path(DATA_DIR))
  • allow_disk_cache : bool whether to allow loading from disk cache (defaults to True)
  • return_cache : bool whether to return the cache. if False, will return None as the second element (defaults to True)

Returns:

  • tuple[Path, ActivationCacheNp | None] the path to the activations and the cache if return_cache is True
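
A small sketch of the caching behavior (the prompt text is arbitrary and `"gpt2"` is just an example model name):

```python
from pattern_lens.activations import get_activations

prompt = {"text": "Attention is all you need."}

# first call: nothing on disk yet, so the model is loaded from the string
# name and activations are computed and saved
path, cache = get_activations(prompt, model="gpt2")

# second call: load_activations finds the saved .npz, so the model is
# never instantiated
path_again, _ = get_activations(prompt, model="gpt2", return_cache=False)
assert path == path_again
```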
def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int | None,
    no_index_html: bool,
    shuffle: bool = False,
    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
) -> None:

main function for computing activations

Parameters:

  • model_name : str name of a model to load with HookedTransformer.from_pretrained
  • save_path : str path to save the activations to
  • prompts_path : str path to the prompts file
  • raw_prompts : bool whether the prompts are raw, not filtered by length. load_text_data is called on prompts_path if True; otherwise, previously saved prompts are loaded from prompts.jsonl in the model directory
  • min_chars : int minimum number of characters for a prompt
  • max_chars : int maximum number of characters for a prompt
  • force : bool whether to overwrite existing files
  • n_samples : int | None maximum number of samples to process; process all prompts if None
  • no_index_html : bool skip writing an index.html file if True
  • shuffle : bool whether to shuffle the prompts (defaults to False)
  • device : str | torch.device the device to use. if a string, will be passed to torch.device (defaults to "cuda" if available, else "cpu")
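
For illustration, a call with every required argument spelled out (the prompts file and sample count are assumptions; `raw_prompts=True` routes loading through `load_text_data`):

```python
from pattern_lens.activations import activations_main

activations_main(
    model_name="gpt2",
    save_path="attn_data",
    prompts_path="data/pile_1k.jsonl",
    raw_prompts=True,      # filter prompts by length via load_text_data
    min_chars=100,
    max_chars=1000,
    force=False,           # reuse any activations already on disk
    n_samples=50,          # stop after 50 prompts
    no_index_html=False,   # also write index.html under save_path
    shuffle=True,
    device="cpu",
)
```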
def main():
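
`main()` is the CLI entry point: it parses the arguments above, splits `--model` on commas, and runs `activations_main` once per model. Two example invocations (the model names are illustrative; any name accepted by `HookedTransformer.from_pretrained` should work):

```bash
# one model, raw prompts filtered to 100-1000 characters, 50 samples
python -m pattern_lens.activations -m gpt2 -p data/pile_1k.jsonl -r -n 50

# several models in one run: comma separated, no whitespace
python -m pattern_lens.activations -m gpt2,distilgpt2 -p data/pile_1k.jsonl -r -s attn_data --device cpu
```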