pattern_lens.activations

computing and saving activations given a model and prompts

Usage:

from the command line:

```bash
python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
```

from a script:

```python
from pattern_lens.activations import activations_main

activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
    raw_prompts=True,
    min_chars=100,
    max_chars=1000,
    force=False,
    n_samples=None,
    no_index_html=False,
)
```
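The prompts file is JSON Lines, one record per line. Below is a minimal sketch of the expected record shape, assuming only what the CLI help states (`--prompts` wants a `'text'` key on each line, and `--raw-prompts` additionally expects a `'meta'` key); the example texts and the `"source"` field are purely illustrative:

```python
import json

# hypothetical records: "text" is required; "meta" is expected when using --raw-prompts
records = [
    {"text": "The quick brown fox jumps over the lazy dog.", "meta": {"source": "example"}},
    {"text": "Call me Ishmael. Some years ago, never mind how long precisely.", "meta": {"source": "example"}},
]
with open("data/pile_1k.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")
```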
1"""computing and saving activations given a model and prompts 2 3 4# Usage: 5 6from the command line: 7 8```bash 9python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples> 10``` 11 12from a script: 13 14```python 15from pattern_lens.activations import activations_main 16activations_main( 17 model_name="gpt2", 18 save_path="demo/" 19 prompts_path="data/pile_1k.jsonl", 20) 21``` 22 23""" 24 25import argparse 26import functools 27import json 28from dataclasses import asdict 29from pathlib import Path 30import re 31from typing import Callable, Literal, overload 32 33import numpy as np 34import torch 35import tqdm 36from muutils.spinner import SpinnerContext 37from muutils.misc.numerical import shorten_numerical_to_str 38from muutils.json_serialize import json_serialize 39from transformer_lens import HookedTransformer, HookedTransformerConfig # type: ignore[import-untyped] 40 41from pattern_lens.consts import ( 42 ATTN_PATTERN_REGEX, 43 DATA_DIR, 44 ActivationCacheNp, 45 SPINNER_KWARGS, 46 DIVIDER_S1, 47 DIVIDER_S2, 48) 49from pattern_lens.indexes import ( 50 generate_models_jsonl, 51 generate_prompts_jsonl, 52 write_html_index, 53) 54from pattern_lens.load_activations import ( 55 ActivationsMissingError, 56 augment_prompt_with_hash, 57 load_activations, 58) 59from pattern_lens.prompts import load_text_data 60 61 62def compute_activations( 63 prompt: dict, 64 model: HookedTransformer | None = None, 65 save_path: Path = Path(DATA_DIR), 66 return_cache: bool = True, 67 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 68) -> tuple[Path, ActivationCacheNp | None]: 69 """get activations for a given model and prompt, possibly from a cache 70 71 if from a cache, prompt_meta must be passed and contain the prompt hash 72 73 # Parameters: 74 - `prompt : dict | None` 75 (defaults to `None`) 76 - `model : HookedTransformer` 77 - `save_path : Path` 78 (defaults to `Path(DATA_DIR)`) 79 - `return_cache : bool` 80 will return `None` as the second element if `False` 81 (defaults to `True`) 82 - `names_filter : Callable[[str], bool]|re.Pattern` 83 a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None` 84 (defaults to `ATTN_PATTERN_REGEX`) 85 86 # Returns: 87 - `tuple[Path, ActivationCacheNp|None]` 88 """ 89 assert model is not None, "model must be passed" 90 assert "text" in prompt, "prompt must contain 'text' key" 91 prompt_str: str = prompt["text"] 92 93 # compute or get prompt metadata 94 prompt_tokenized: list[str] = prompt.get( 95 "tokens", 96 model.tokenizer.tokenize(prompt_str), 97 ) 98 prompt.update( 99 dict( 100 n_tokens=len(prompt_tokenized), 101 tokens=prompt_tokenized, 102 ) 103 ) 104 105 # save metadata 106 prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"] 107 prompt_dir.mkdir(parents=True, exist_ok=True) 108 with open(prompt_dir / "prompt.json", "w") as f: 109 json.dump(prompt, f) 110 111 # set up names filter 112 names_filter_fn: Callable[[str], bool] 113 if isinstance(names_filter, re.Pattern): 114 names_filter_fn = lambda key: names_filter.match(key) is not None # noqa: E731 115 else: 116 names_filter_fn = names_filter 117 118 # compute activations 119 with torch.no_grad(): 120 model.eval() 121 # TODO: batching? 
122 _, cache = model.run_with_cache( 123 prompt_str, 124 names_filter=names_filter_fn, 125 return_type=None, 126 ) 127 128 cache_np: ActivationCacheNp = { 129 k: v.detach().cpu().numpy() for k, v in cache.items() 130 } 131 132 # save activations 133 activations_path: Path = prompt_dir / "activations.npz" 134 np.savez_compressed( 135 activations_path, 136 **cache_np, 137 ) 138 139 # return path and cache 140 if return_cache: 141 return activations_path, cache_np 142 else: 143 return activations_path, None 144 145 146@overload 147def get_activations( 148 prompt: dict, 149 model: HookedTransformer | str, 150 save_path: Path = Path(DATA_DIR), 151 allow_disk_cache: bool = True, 152 return_cache: Literal[False] = False, 153) -> tuple[Path, None]: ... 154@overload 155def get_activations( 156 prompt: dict, 157 model: HookedTransformer | str, 158 save_path: Path = Path(DATA_DIR), 159 allow_disk_cache: bool = True, 160 return_cache: Literal[True] = True, 161) -> tuple[Path, ActivationCacheNp]: ... 162def get_activations( 163 prompt: dict, 164 model: HookedTransformer | str, 165 save_path: Path = Path(DATA_DIR), 166 allow_disk_cache: bool = True, 167 return_cache: bool = True, 168) -> tuple[Path, ActivationCacheNp | None]: 169 """given a prompt and a model, save or load activations 170 171 # Parameters: 172 - `prompt : dict` 173 expected to contain the 'text' key 174 - `model : HookedTransformer | str` 175 either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained` 176 - `save_path : Path` 177 path to save the activations to (and load from) 178 (defaults to `Path(DATA_DIR)`) 179 - `allow_disk_cache : bool` 180 whether to allow loading from disk cache 181 (defaults to `True`) 182 - `return_cache : bool` 183 whether to return the cache. if `False`, will return `None` as the second element 184 (defaults to `True`) 185 186 # Returns: 187 - `tuple[Path, ActivationCacheNp | None]` 188 the path to the activations and the cache if `return_cache` is `True` 189 190 """ 191 # add hash to prompt 192 augment_prompt_with_hash(prompt) 193 194 # get the model 195 model_name: str = ( 196 model.model_name if isinstance(model, HookedTransformer) else model 197 ) 198 199 # from cache 200 if allow_disk_cache: 201 try: 202 path, cache = load_activations( 203 model_name=model_name, 204 prompt=prompt, 205 save_path=save_path, 206 ) 207 if return_cache: 208 return path, cache 209 else: 210 return path, None 211 except ActivationsMissingError: 212 pass 213 214 # compute them 215 if isinstance(model, str): 216 model = HookedTransformer.from_pretrained(model_name) 217 218 return compute_activations( 219 prompt=prompt, 220 model=model, 221 save_path=save_path, 222 return_cache=True, 223 ) 224 225 226def activations_main( 227 model_name: str, 228 save_path: str, 229 prompts_path: str, 230 raw_prompts: bool, 231 min_chars: int, 232 max_chars: int, 233 force: bool, 234 n_samples: int, 235 no_index_html: bool, 236 shuffle: bool = False, 237 device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu", 238) -> None: 239 """main function for computing activations 240 241 # Parameters: 242 - `model_name : str` 243 name of a model to load with `HookedTransformer.from_pretrained` 244 - `save_path : str` 245 path to save the activations to 246 - `prompts_path : str` 247 path to the prompts file 248 - `raw_prompts : bool` 249 whether the prompts are raw, not filtered by length. 
`load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path` 250 - `min_chars : int` 251 minimum number of characters for a prompt 252 - `max_chars : int` 253 maximum number of characters for a prompt 254 - `force : bool` 255 whether to overwrite existing files 256 - `n_samples : int` 257 maximum number of samples to process 258 - `no_index_html : bool` 259 whether to write an index.html file 260 - `shuffle : bool` 261 whether to shuffle the prompts 262 (defaults to `False`) 263 - `device : str | torch.device` 264 the device to use. if a string, will be passed to `torch.device` 265 """ 266 267 # figure out the device to use 268 device_: torch.device 269 if isinstance(device, torch.device): 270 device_ = device 271 elif isinstance(device, str): 272 device_ = torch.device(device) 273 else: 274 raise ValueError(f"invalid device: {device}") 275 276 print(f"using device: {device_}") 277 278 with SpinnerContext(message="loading model", **SPINNER_KWARGS): 279 model: HookedTransformer = HookedTransformer.from_pretrained( 280 model_name, device=device_ 281 ) 282 model.model_name = model_name 283 model.cfg.model_name = model_name 284 n_params: int = sum(p.numel() for p in model.parameters()) 285 print( 286 f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters" 287 ) 288 print(f"\tmodel devices: {set(p.device for p in model.parameters())}") 289 290 save_path_p: Path = Path(save_path) 291 save_path_p.mkdir(parents=True, exist_ok=True) 292 model_path: Path = save_path_p / model_name 293 with SpinnerContext( 294 message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS 295 ): 296 model_cfg: HookedTransformerConfig 297 model_cfg = model.cfg 298 model_path.mkdir(parents=True, exist_ok=True) 299 with open(model_path / "model_cfg.json", "w") as f: 300 json.dump(json_serialize(asdict(model_cfg)), f) 301 302 # load prompts 303 with SpinnerContext( 304 message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS 305 ): 306 prompts: list[dict] 307 if raw_prompts: 308 prompts = load_text_data( 309 Path(prompts_path), 310 min_chars=min_chars, 311 max_chars=max_chars, 312 shuffle=shuffle, 313 ) 314 else: 315 with open(model_path / "prompts.jsonl", "r") as f: 316 prompts = [json.loads(line) for line in f.readlines()] 317 # truncate to n_samples 318 prompts = prompts[:n_samples] 319 320 print(f"{len(prompts)} prompts loaded") 321 322 # write index.html 323 with SpinnerContext(message="writing index.html", **SPINNER_KWARGS): 324 if not no_index_html: 325 write_html_index(save_path_p) 326 327 # get activations 328 list( 329 tqdm.tqdm( 330 map( 331 functools.partial( 332 get_activations, 333 model=model, 334 save_path=save_path_p, 335 allow_disk_cache=not force, 336 return_cache=False, 337 ), 338 prompts, 339 ), 340 total=len(prompts), 341 desc="Computing activations", 342 unit="prompt", 343 ) 344 ) 345 346 with SpinnerContext( 347 message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS 348 ): 349 generate_models_jsonl(save_path_p) 350 generate_prompts_jsonl(save_path_p / model_name) 351 352 353def main(): 354 print(DIVIDER_S1) 355 with SpinnerContext(message="parsing args", **SPINNER_KWARGS): 356 arg_parser: argparse.ArgumentParser = argparse.ArgumentParser() 357 # input and output 358 arg_parser.add_argument( 359 "--model", 360 "-m", 361 type=str, 362 required=True, 363 help="The model name(s) to use. 
comma separated with no whitespace if multiple", 364 ) 365 366 arg_parser.add_argument( 367 "--prompts", 368 "-p", 369 type=str, 370 required=False, 371 help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory", 372 default=None, 373 ) 374 375 arg_parser.add_argument( 376 "--save-path", 377 "-s", 378 type=str, 379 required=False, 380 help="The path to save the attention patterns", 381 default=DATA_DIR, 382 ) 383 384 # min and max prompt lengths 385 arg_parser.add_argument( 386 "--min-chars", 387 type=int, 388 required=False, 389 help="The minimum number of characters for a prompt", 390 default=100, 391 ) 392 arg_parser.add_argument( 393 "--max-chars", 394 type=int, 395 required=False, 396 help="The maximum number of characters for a prompt", 397 default=1000, 398 ) 399 400 # number of samples 401 arg_parser.add_argument( 402 "--n-samples", 403 "-n", 404 type=int, 405 required=False, 406 help="The max number of samples to process, do all in the file if None", 407 default=None, 408 ) 409 410 # force overwrite 411 arg_parser.add_argument( 412 "--force", 413 "-f", 414 action="store_true", 415 help="If passed, will overwrite existing files", 416 ) 417 418 # no index html 419 arg_parser.add_argument( 420 "--no-index-html", 421 action="store_true", 422 help="If passed, will not write an index.html file for the model", 423 ) 424 425 # raw prompts 426 arg_parser.add_argument( 427 "--raw-prompts", 428 "-r", 429 action="store_true", 430 help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)", 431 ) 432 433 # shuffle 434 arg_parser.add_argument( 435 "--shuffle", 436 action="store_true", 437 help="If passed, will shuffle the prompts", 438 ) 439 440 # device 441 arg_parser.add_argument( 442 "--device", 443 type=str, 444 required=False, 445 help="The device to use for the model", 446 default="cuda" if torch.cuda.is_available() else "cpu", 447 ) 448 449 args: argparse.Namespace = arg_parser.parse_args() 450 451 print(f"args parsed: {args}") 452 453 models: list[str] 454 if "," in args.model: 455 models = args.model.split(",") 456 else: 457 models = [args.model] 458 459 n_models: int = len(models) 460 for idx, model in enumerate(models): 461 print(DIVIDER_S2) 462 print(f"processing model {idx+1} / {n_models}: {model}") 463 print(DIVIDER_S2) 464 465 activations_main( 466 model_name=model, 467 save_path=args.save_path, 468 prompts_path=args.prompts, 469 raw_prompts=args.raw_prompts, 470 min_chars=args.min_chars, 471 max_chars=args.max_chars, 472 force=args.force, 473 n_samples=args.n_samples, 474 no_index_html=args.no_index_html, 475 shuffle=args.shuffle, 476 device=args.device, 477 ) 478 479 print(DIVIDER_S1) 480 481 482if __name__ == "__main__": 483 main()
`def compute_activations(prompt: dict, model: HookedTransformer | None = None, save_path: Path = Path(DATA_DIR), return_cache: bool = True, names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX) -> tuple[Path, ActivationCacheNp | None]:`
get activations for a given model and prompt

the prompt must already contain its hash (`prompt["hash"]`), which determines the save directory; `get_activations` adds it via `augment_prompt_with_hash`

Parameters:

- `prompt : dict`
    must contain the 'text' key
- `model : HookedTransformer | None`
    required at call time; an assertion rejects `None`
    (defaults to `None`)
- `save_path : Path`
    (defaults to `Path(DATA_DIR)`)
- `return_cache : bool`
    will return `None` as the second element if `False`
    (defaults to `True`)
- `names_filter : Callable[[str], bool] | re.Pattern`
    a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
    (defaults to `ATTN_PATTERN_REGEX`)

Returns:

- `tuple[Path, ActivationCacheNp | None]`
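A minimal usage sketch (hedged: the explicit `model.model_name` assignment mirrors what `activations_main` does before computing, and the prompt must be hashed first since `compute_activations` itself never hashes):

```python
from pathlib import Path

from transformer_lens import HookedTransformer

from pattern_lens.activations import compute_activations
from pattern_lens.load_activations import augment_prompt_with_hash

model = HookedTransformer.from_pretrained("gpt2")
model.model_name = "gpt2"  # used to build the save path, as in activations_main

prompt = {"text": "The quick brown fox jumps over the lazy dog."}
augment_prompt_with_hash(prompt)  # compute_activations reads prompt["hash"]

# writes prompt.json and activations.npz under attn_data/gpt2/prompts/<hash>/
path, cache = compute_activations(prompt=prompt, model=model, save_path=Path("attn_data"))
```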
`def get_activations(prompt: dict, model: HookedTransformer | str, save_path: Path = Path(DATA_DIR), allow_disk_cache: bool = True, return_cache: bool = True) -> tuple[Path, ActivationCacheNp | None]:`
given a prompt and a model, save or load activations

Parameters:

- `prompt : dict`
    expected to contain the 'text' key
- `model : HookedTransformer | str`
    either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
- `save_path : Path`
    path to save the activations to (and load from)
    (defaults to `Path(DATA_DIR)`)
- `allow_disk_cache : bool`
    whether to allow loading from disk cache
    (defaults to `True`)
- `return_cache : bool`
    whether to return the cache. if `False`, will return `None` as the second element
    (defaults to `True`)

Returns:

- `tuple[Path, ActivationCacheNp | None]`
    the path to the activations, and the cache if `return_cache` is `True`
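A hedged sketch of the caching behavior: the first call computes and saves, an identical second call is served from disk, and `allow_disk_cache=False` (what the CLI's `--force` flag maps to) forces recomputation:

```python
from transformer_lens import HookedTransformer

from pattern_lens.activations import get_activations

model = HookedTransformer.from_pretrained("gpt2")
model.model_name = "gpt2"  # mirrors activations_main; used for the save path

prompt = {"text": "The quick brown fox jumps over the lazy dog."}

path, cache = get_activations(prompt=prompt, model=model)  # computes and saves
path, cache = get_activations(prompt=prompt, model=model)  # loaded from disk cache
path, _ = get_activations(prompt=prompt, model=model, allow_disk_cache=False)  # recomputes
```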
`def activations_main(model_name: str, save_path: str, prompts_path: str, raw_prompts: bool, min_chars: int, max_chars: int, force: bool, n_samples: int, no_index_html: bool, shuffle: bool = False, device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu") -> None:`
main function for computing activations

Parameters:

- `model_name : str`
    name of a model to load with `HookedTransformer.from_pretrained`
- `save_path : str`
    path to save the activations to
- `prompts_path : str`
    path to the prompts file
- `raw_prompts : bool`
    whether the prompts are raw, not yet filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
- `min_chars : int`
    minimum number of characters for a prompt
- `max_chars : int`
    maximum number of characters for a prompt
- `force : bool`
    whether to overwrite existing files
- `n_samples : int`
    maximum number of samples to process (if `None`, process all)
- `no_index_html : bool`
    whether to skip writing an index.html file
- `shuffle : bool`
    whether to shuffle the prompts
    (defaults to `False`)
- `device : str | torch.device`
    the device to use. if a string, will be passed to `torch.device`
    (defaults to `"cuda"` if available, else `"cpu"`)
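Since the activations land in a plain `.npz` archive, downstream analysis needs only numpy. A sketch, assuming the default save path and a previously computed prompt (`<hash>` stands for the prompt hash added by `augment_prompt_with_hash`):

```python
import numpy as np

with np.load("attn_data/gpt2/prompts/<hash>/activations.npz") as npz:
    for name, pattern in npz.items():
        # with the default ATTN_PATTERN_REGEX filter, each entry is one layer's
        # attention pattern, of shape (1, n_heads, n_tokens, n_tokens)
        print(name, pattern.shape)
```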
`def main():`

command-line entry point: parses the flags described in the module usage above and runs `activations_main` once per comma-separated model name
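A hypothetical end-to-end invocation (model names and paths are placeholders); comma-separated model names are processed sequentially, each saving under its own subdirectory of `--save-path`:

```bash
python -m pattern_lens.activations \
    --model gpt2,distilgpt2 \
    --prompts data/pile_1k.jsonl \
    --save-path attn_data \
    --raw-prompts --shuffle \
    --min-chars 100 --max-chars 1000 \
    --n-samples 100
```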