docs for pattern_lens v0.3.0

pattern_lens.figures

code for generating figures from attention patterns, using the functions decorated with register_attn_figure_func
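The usual entry point is the command-line interface defined in main() at the bottom of this module. A minimal invocation might look like the following, where gpt2 is a hypothetical model name whose activations were already saved (by the activations step) under the default attn_data directory:

    python -m pattern_lens.figures --model gpt2 --n-samples 10

--save-path changes the base directory, and --force re-generates figures that already exist.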


  1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""
  2
  3import argparse
  4import functools
  5import itertools
  6import json
  7import warnings
  8from collections import defaultdict
  9from pathlib import Path
 10
 11import numpy as np
 12from jaxtyping import Float
 13
 14# custom utils
 15from muutils.json_serialize import json_serialize
 16from muutils.parallel import run_maybe_parallel
 17from muutils.spinner import SpinnerContext
 18
 19# pattern_lens
 20from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
 21from pattern_lens.consts import (
 22	DATA_DIR,
 23	DIVIDER_S1,
 24	DIVIDER_S2,
 25	SPINNER_KWARGS,
 26	ActivationCacheNp,
 27	AttentionMatrix,
 28)
 29from pattern_lens.indexes import (
 30	generate_functions_jsonl,
 31	generate_models_jsonl,
 32	generate_prompts_jsonl,
 33)
 34from pattern_lens.load_activations import load_activations
 35
 36
 37class HTConfigMock:
 38	"""Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json
 39
 40	can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
 41	- `n_layers: int`
 42	- `n_heads: int`
 43	- `model_name: str`
 44
 45	we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
 46	"""
 47
 48	def __init__(self, **kwargs: dict[str, str | int]) -> None:
 49		"will pass all kwargs to `__dict__`"
 50		self.n_layers: int
 51		self.n_heads: int
 52		self.model_name: str
 53		self.__dict__.update(kwargs)
 54
 55	def serialize(self) -> dict:
 56		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
 57		# it's fine, we know it's a dict
 58		return json_serialize(self.__dict__)  # type: ignore[return-value]
 59
 60	@classmethod
 61	def load(cls, data: dict) -> "HTConfigMock":
 62		"try to load a config from a dict, using the `__init__` method"
 63		return cls(**data)
 64
 65
 66def process_single_head(
 67	layer_idx: int,
 68	head_idx: int,
 69	attn_pattern: AttentionMatrix,
 70	save_dir: Path,
 71	force_overwrite: bool = False,
 72) -> dict[str, bool | Exception]:
 73	"""process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern
 74
 75	> [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
 76	> it will skip all figures for that function if any are already saved
 77	> and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures
 78
 79	# Parameters:
 80	- `layer_idx : int`
 81	- `head_idx : int`
 82	- `attn_pattern : AttentionMatrix`
 83		attention pattern for the head
 84	- `save_dir : Path`
 85		directory to save the figures to
 86	- `force_overwrite : bool`
 87		whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
 88		(defaults to `False`)
 89
 90	# Returns:
 91	- `dict[str, bool | Exception]`
 92		a dictionary of the status of each function, with the function name as the key and the status as the value
 93	"""
 94	funcs_status: dict[str, bool | Exception] = dict()
 95
 96	for func in ATTENTION_MATRIX_FIGURE_FUNCS:
 97		func_name: str = func.__name__
 98		fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))
 99
100		if not force_overwrite and len(fig_path) > 0:
101			funcs_status[func_name] = True
102			continue
103
104		try:
105			func(attn_pattern, save_dir)
106			funcs_status[func_name] = True
107
108		# blindly catch any exception
109		except Exception as e:  # noqa: BLE001
110			error_file = save_dir / f"{func.__name__}.error.txt"
111			error_file.write_text(str(e))
112			warnings.warn(
113				f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
114				stacklevel=2,
115			)
116			funcs_status[func_name] = e
117
118	return funcs_status
119
120
121def compute_and_save_figures(
122	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
123	activations_path: Path,
124	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
125	save_path: Path = Path(DATA_DIR),
126	force_overwrite: bool = False,
127	track_results: bool = False,
128) -> None:
129	"""compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
130
131	# Parameters:
132	- `model_cfg : HookedTransformerConfig|HTConfigMock`
133	- `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
134	- `save_path : Path`
135		(defaults to `Path(DATA_DIR)`)
136	- `force_overwrite : bool`
137		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
138		(defaults to `False`)
139	- `track_results : bool`
140		whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
141		(defaults to `False`)
142	"""
143	prompt_dir: Path = activations_path.parent
144
145	if track_results:
146		results: defaultdict[
147			str,  # func name
148			dict[
149				tuple[int, int],  # layer, head
150				bool | Exception,  # success or exception
151			],
152		] = defaultdict(dict)
153
154	for layer_idx, head_idx in itertools.product(
155		range(model_cfg.n_layers),
156		range(model_cfg.n_heads),
157	):
158		attn_pattern: AttentionMatrix
159		if isinstance(cache, dict):
160			attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
161		elif isinstance(cache, np.ndarray):
162			attn_pattern = cache[layer_idx, head_idx]
163		else:
164			msg = (
165				f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
166			)
167			raise TypeError(
168				msg,
169			)
170
171		save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
172		save_dir.mkdir(parents=True, exist_ok=True)
173		head_res: dict[str, bool | Exception] = process_single_head(
174			layer_idx=layer_idx,
175			head_idx=head_idx,
176			attn_pattern=attn_pattern,
177			save_dir=save_dir,
178			force_overwrite=force_overwrite,
179		)
180
181		if track_results:
182			for func_name, status in head_res.items():
183				results[func_name][(layer_idx, head_idx)] = status
184
185	# TODO: do something with results
186
187	generate_prompts_jsonl(save_path / model_cfg.model_name)
188
189
190def process_prompt(
191	prompt: dict,
192	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
193	save_path: Path,
194	force_overwrite: bool = False,
195) -> None:
196	"""process a single prompt, loading the activations and computing and saving the figures
197
198	basically just calls `load_activations` and then `compute_and_save_figures`
199
200	# Parameters:
201	- `prompt : dict`
202	- `model_cfg : HookedTransformerConfig|HTConfigMock`
203	- `force_overwrite : bool`
204		(defaults to `False`)
205	"""
206	activations_path: Path
207	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
208	activations_path, cache = load_activations(
209		model_name=model_cfg.model_name,
210		prompt=prompt,
211		save_path=save_path,
212		return_fmt="numpy",
213	)
214
215	compute_and_save_figures(
216		model_cfg=model_cfg,
217		activations_path=activations_path,
218		cache=cache,
219		save_path=save_path,
220		force_overwrite=force_overwrite,
221	)
222
223
224def figures_main(
225	model_name: str,
226	save_path: str,
227	n_samples: int,
228	force: bool,
229	parallel: bool | int = True,
230) -> None:
231	"""main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
232
233	# Parameters:
234	- `model_name : str`
235		model name to use, used for loading the model config, prompts, activations, and saving the figures
236	- `save_path : str`
237		base path to look in
238	- `n_samples : int`
239		max number of samples to process
240	- `force : bool`
241		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
242	- `parallel : bool | int`
243		whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
244		(defaults to `True`)
245	"""
246	with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
247		# save model info or check if it exists
248		save_path_p: Path = Path(save_path)
249		model_path: Path = save_path_p / model_name
250		with open(model_path / "model_cfg.json", "r") as f:
251			model_cfg = HTConfigMock.load(json.load(f))
252
253	with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
254		# load prompts
255		with open(model_path / "prompts.jsonl", "r") as f:
256			prompts: list[dict] = [json.loads(line) for line in f.readlines()]
257		# truncate to n_samples
258		prompts = prompts[:n_samples]
259
260	print(f"{len(prompts)} prompts loaded")
261
262	print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded")
263	print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS]))
264
265	list(
266		run_maybe_parallel(
267			func=functools.partial(
268				process_prompt,
269				model_cfg=model_cfg,
270				save_path=save_path_p,
271				force_overwrite=force,
272			),
273			iterable=prompts,
274			parallel=parallel,
275			pbar="tqdm",
276			pbar_kwargs=dict(
277				desc="Making figures",
278				unit="prompt",
279			),
280		),
281	)
282
283	with SpinnerContext(
284		message="updating jsonl metadata for models and functions",
285		**SPINNER_KWARGS,
286	):
287		generate_models_jsonl(save_path_p)
288		generate_functions_jsonl(save_path_p)
289
290
291def main() -> None:
292	"generates figures from the activations using the functions decorated with `register_attn_figure_func`"
293	print(DIVIDER_S1)
294	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
295		arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
296		# input and output
297		arg_parser.add_argument(
298			"--model",
299			"-m",
300			type=str,
301			required=True,
302			help="The model name(s) to use. comma separated with no whitespace if multiple",
303		)
304		arg_parser.add_argument(
305			"--save-path",
306			"-s",
307			type=str,
308			required=False,
309			help="The path to save the attention patterns",
310			default=DATA_DIR,
311		)
312		# number of samples
313		arg_parser.add_argument(
314			"--n-samples",
315			"-n",
316			type=int,
317			required=False,
318			help="The max number of samples to process; process all in the file if None",
319			default=None,
320		)
321		# force overwrite of existing figures
322		arg_parser.add_argument(
323			"--force",
324			"-f",
325			type=bool,
326			required=False,
327			help="Force overwrite of existing figures",
328			default=False,
329		)
330
331		args: argparse.Namespace = arg_parser.parse_args()
332
333	print(f"args parsed: {args}")
334
335	models: list[str]
336	if "," in args.model:
337		models = args.model.split(",")
338	else:
339		models = [args.model]
340
341	n_models: int = len(models)
342	for idx, model in enumerate(models):
343		print(DIVIDER_S2)
344		print(f"processing model {idx + 1} / {n_models}: {model}")
345		print(DIVIDER_S2)
346		figures_main(
347			model_name=model,
348			save_path=args.save_path,
349			n_samples=args.n_samples,
350			force=args.force,
351		)
352
353	print(DIVIDER_S1)
354
355
356if __name__ == "__main__":
357	main()

class HTConfigMock:
38class HTConfigMock:
39	"""Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json
40
41	can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
42	- `n_layers: int`
43	- `n_heads: int`
44	- `model_name: str`
45
46	we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
47	"""
48
49	def __init__(self, **kwargs: dict[str, str | int]) -> None:
50		"will pass all kwargs to `__dict__`"
51		self.n_layers: int
52		self.n_heads: int
53		self.model_name: str
54		self.__dict__.update(kwargs)
55
56	def serialize(self) -> dict:
57		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
58		# it's fine, we know it's a dict
59		return json_serialize(self.__dict__)  # type: ignore[return-value]
60
61	@classmethod
62	def load(cls, data: dict) -> "HTConfigMock":
63		"try to load a config from a dict, using the `__init__` method"
64		return cls(**data)

Mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json

can be initialized with any kwargs, and will update its __dict__ with them. does, however, require the following attributes:

  • n_layers: int
  • n_heads: int
  • model_name: str

we do this to avoid having to import torch and transformer_lens, since this would have to be done for each process in the parallelization and probably slows things down significantly

HTConfigMock(**kwargs: dict[str, str | int])
49	def __init__(self, **kwargs: dict[str, str | int]) -> None:
50		"will pass all kwargs to `__dict__`"
51		self.n_layers: int
52		self.n_heads: int
53		self.model_name: str
54		self.__dict__.update(kwargs)

will pass all kwargs to __dict__

n_layers: int
n_heads: int
model_name: str
def serialize(self) -> dict:
56	def serialize(self) -> dict:
57		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
58		# it's fine, we know it's a dict
59		return json_serialize(self.__dict__)  # type: ignore[return-value]

serialize the config to json. values which aren't serializable will be converted via muutils.json_serialize.json_serialize

@classmethod
def load(cls, data: dict) -> HTConfigMock:
61	@classmethod
62	def load(cls, data: dict) -> "HTConfigMock":
63		"try to load a config from a dict, using the `__init__` method"
64		return cls(**data)

try to load a config from a dict, using the __init__ method
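A minimal sketch of the two ways the mock is built; the gpt2 name, the 12-layer/12-head values, and the attn_data path are hypothetical stand-ins:

    import json
    from pathlib import Path

    from pattern_lens.figures import HTConfigMock

    # construct directly from kwargs
    cfg = HTConfigMock(n_layers=12, n_heads=12, model_name="gpt2")
    print(cfg.n_layers, cfg.n_heads, cfg.model_name)
    print(cfg.serialize())  # plain json-serializable dict of the attributes

    # or load the json that figures_main reads from save_path / model_name / "model_cfg.json"
    cfg_path = Path("attn_data") / "gpt2" / "model_cfg.json"  # hypothetical location
    if cfg_path.exists():
        cfg = HTConfigMock.load(json.loads(cfg_path.read_text()))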

def process_single_head( layer_idx: int, head_idx: int, attn_pattern: jaxtyping.Float[ndarray, 'n_ctx n_ctx'], save_dir: pathlib.Path, force_overwrite: bool = False) -> dict[str, bool | Exception]:
 67def process_single_head(
 68	layer_idx: int,
 69	head_idx: int,
 70	attn_pattern: AttentionMatrix,
 71	save_dir: Path,
 72	force_overwrite: bool = False,
 73) -> dict[str, bool | Exception]:
 74	"""process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern
 75
 76	> [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
 77	> it will skip all figures for that function if any are already saved
 78	> and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures
 79
 80	# Parameters:
 81	- `layer_idx : int`
 82	- `head_idx : int`
 83	- `attn_pattern : AttentionMatrix`
 84		attention pattern for the head
 85	- `save_dir : Path`
 86		directory to save the figures to
 87	- `force_overwrite : bool`
 88		whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
 89		(defaults to `False`)
 90
 91	# Returns:
 92	- `dict[str, bool | Exception]`
 93		a dictionary of the status of each function, with the function name as the key and the status as the value
 94	"""
 95	funcs_status: dict[str, bool | Exception] = dict()
 96
 97	for func in ATTENTION_MATRIX_FIGURE_FUNCS:
 98		func_name: str = func.__name__
 99		fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))
100
101		if not force_overwrite and len(fig_path) > 0:
102			funcs_status[func_name] = True
103			continue
104
105		try:
106			func(attn_pattern, save_dir)
107			funcs_status[func_name] = True
108
109		# blindly catch any exception
110		except Exception as e:  # noqa: BLE001
111			error_file = save_dir / f"{func.__name__}.error.txt"
112			error_file.write_text(str(e))
113			warnings.warn(
114				f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
115				stacklevel=2,
116			)
117			funcs_status[func_name] = e
118
119	return funcs_status

process a single head's attention pattern, running all the functions in ATTENTION_MATRIX_FIGURE_FUNCS on the attention pattern

[gotcha:] if force_overwrite is False, and we used a multi-figure function, it will skip all figures for that function if any are already saved and it assumes a format of {func_name}.{figure_name}.{fmt} for the saved figures

Parameters:

  • layer_idx : int
  • head_idx : int
  • attn_pattern : AttentionMatrix attention pattern for the head
  • save_dir : Path directory to save the figures to
  • force_overwrite : bool whether to overwrite existing figures. if False, will skip any functions which have already saved a figure (defaults to False)

Returns:

  • dict[str, bool | Exception] a dictionary of the status of each function, with the function name as the key and the status as the value
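A minimal sketch of calling this directly on a synthetic attention pattern; the save directory below is hypothetical. Registered figure functions write their outputs there, and any per-function error is also written to {func_name}.error.txt:

    import numpy as np
    from pathlib import Path

    from pattern_lens.figures import process_single_head

    n_ctx = 16
    attn = np.random.rand(n_ctx, n_ctx)
    attn /= attn.sum(axis=-1, keepdims=True)  # normalize rows, like a real attention pattern

    save_dir = Path("attn_data/example/prompt/L0/H0")  # hypothetical location
    save_dir.mkdir(parents=True, exist_ok=True)

    status = process_single_head(
        layer_idx=0,
        head_idx=0,
        attn_pattern=attn,
        save_dir=save_dir,
        force_overwrite=True,
    )
    print(status)  # {func_name: True on success, or the raised Exception}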
def compute_and_save_figures( model_cfg: 'HookedTransformerConfig|HTConfigMock', activations_path: pathlib.Path, cache: dict[str, numpy.ndarray] | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx'], save_path: pathlib.Path = WindowsPath('attn_data'), force_overwrite: bool = False, track_results: bool = False) -> None:
122def compute_and_save_figures(
123	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
124	activations_path: Path,
125	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
126	save_path: Path = Path(DATA_DIR),
127	force_overwrite: bool = False,
128	track_results: bool = False,
129) -> None:
130	"""compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
131
132	# Parameters:
133	- `model_cfg : HookedTransformerConfig|HTConfigMock`
134	- `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
135	- `save_path : Path`
136		(defaults to `Path(DATA_DIR)`)
137	- `force_overwrite : bool`
138		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
139		(defaults to `False`)
140	- `track_results : bool`
141		whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
142		(defaults to `False`)
143	"""
144	prompt_dir: Path = activations_path.parent
145
146	if track_results:
147		results: defaultdict[
148			str,  # func name
149			dict[
150				tuple[int, int],  # layer, head
151				bool | Exception,  # success or exception
152			],
153		] = defaultdict(dict)
154
155	for layer_idx, head_idx in itertools.product(
156		range(model_cfg.n_layers),
157		range(model_cfg.n_heads),
158	):
159		attn_pattern: AttentionMatrix
160		if isinstance(cache, dict):
161			attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
162		elif isinstance(cache, np.ndarray):
163			attn_pattern = cache[layer_idx, head_idx]
164		else:
165			msg = (
166				f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
167			)
168			raise TypeError(
169				msg,
170			)
171
172		save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
173		save_dir.mkdir(parents=True, exist_ok=True)
174		head_res: dict[str, bool | Exception] = process_single_head(
175			layer_idx=layer_idx,
176			head_idx=head_idx,
177			attn_pattern=attn_pattern,
178			save_dir=save_dir,
179			force_overwrite=force_overwrite,
180		)
181
182		if track_results:
183			for func_name, status in head_res.items():
184				results[func_name][(layer_idx, head_idx)] = status
185
186	# TODO: do something with results
187
188	generate_prompts_jsonl(save_path / model_cfg.model_name)

compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

  • model_cfg : HookedTransformerConfig|HTConfigMock
  • cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
  • save_path : Path (defaults to Path(DATA_DIR))
  • force_overwrite : bool force overwrite of existing figures. if False, will skip any functions which have already saved a figure (defaults to False)
  • track_results : bool whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO (defaults to False)
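The cache argument accepts either of the two formats handled above. A minimal sketch of the indexing each branch performs; the dict value shape (with a leading batch dimension) is inferred from the [0, head_idx] indexing in the source rather than stated there explicitly:

    import numpy as np

    layer_idx, head_idx = 0, 1
    n_layers, n_heads, n_ctx = 2, 4, 8

    # ActivationCacheNp-style dict: one entry per layer, keyed "blocks.{layer}.attn.hook_pattern",
    # values assumed shaped (1, n_heads, n_ctx, n_ctx); the leading batch dim is indexed with 0
    cache_dict = {
        f"blocks.{i}.attn.hook_pattern": np.random.rand(1, n_heads, n_ctx, n_ctx)
        for i in range(n_layers)
    }
    attn_from_dict = cache_dict[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]

    # stacked array: (n_layers, n_heads, n_ctx, n_ctx), indexed directly by layer and head
    cache_arr = np.random.rand(n_layers, n_heads, n_ctx, n_ctx)
    attn_from_arr = cache_arr[layer_idx, head_idx]

    assert attn_from_dict.shape == attn_from_arr.shape == (n_ctx, n_ctx)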
def process_prompt( prompt: dict, model_cfg: 'HookedTransformerConfig|HTConfigMock', save_path: pathlib.Path, force_overwrite: bool = False) -> None:
191def process_prompt(
192	prompt: dict,
193	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
194	save_path: Path,
195	force_overwrite: bool = False,
196) -> None:
197	"""process a single prompt, loading the activations and computing and saving the figures
198
199	basically just calls `load_activations` and then `compute_and_save_figures`
200
201	# Parameters:
202	- `prompt : dict`
203	- `model_cfg : HookedTransformerConfig|HTConfigMock`
204	- `force_overwrite : bool`
205		(defaults to `False`)
206	"""
207	activations_path: Path
208	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
209	activations_path, cache = load_activations(
210		model_name=model_cfg.model_name,
211		prompt=prompt,
212		save_path=save_path,
213		return_fmt="numpy",
214	)
215
216	compute_and_save_figures(
217		model_cfg=model_cfg,
218		activations_path=activations_path,
219		cache=cache,
220		save_path=save_path,
221		force_overwrite=force_overwrite,
222	)

process a single prompt, loading the activations and computing and saving the figures

basically just calls load_activations and then compute_and_save_figures

Parameters:

  • prompt : dict
  • model_cfg : HookedTransformerConfig|HTConfigMock
  • force_overwrite : bool (defaults to False)
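A minimal sketch of processing one saved prompt by hand, mirroring what figures_main does per prompt; the attn_data directory and gpt2 model name are hypothetical, and the activations for that prompt must already exist on disk:

    import json
    from pathlib import Path

    from pattern_lens.figures import HTConfigMock, process_prompt

    save_path = Path("attn_data")     # hypothetical data dir (the DATA_DIR default)
    model_path = save_path / "gpt2"   # hypothetical model name

    with open(model_path / "model_cfg.json") as f:
        model_cfg = HTConfigMock.load(json.load(f))

    with open(model_path / "prompts.jsonl") as f:
        prompt: dict = json.loads(f.readline())  # one prompt dict per line

    process_prompt(
        prompt=prompt,
        model_cfg=model_cfg,
        save_path=save_path,
        force_overwrite=False,
    )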
def figures_main( model_name: str, save_path: str, n_samples: int, force: bool, parallel: bool | int = True) -> None:
225def figures_main(
226	model_name: str,
227	save_path: str,
228	n_samples: int,
229	force: bool,
230	parallel: bool | int = True,
231) -> None:
232	"""main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
233
234	# Parameters:
235	- `model_name : str`
236		model name to use, used for loading the model config, prompts, activations, and saving the figures
237	- `save_path : str`
238		base path to look in
239	- `n_samples : int`
240		max number of samples to process
241	- `force : bool`
242		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
243	- `parallel : bool | int`
244		whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
245		(defaults to `True`)
246	"""
247	with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
248		# save model info or check if it exists
249		save_path_p: Path = Path(save_path)
250		model_path: Path = save_path_p / model_name
251		with open(model_path / "model_cfg.json", "r") as f:
252			model_cfg = HTConfigMock.load(json.load(f))
253
254	with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
255		# load prompts
256		with open(model_path / "prompts.jsonl", "r") as f:
257			prompts: list[dict] = [json.loads(line) for line in f.readlines()]
258		# truncate to n_samples
259		prompts = prompts[:n_samples]
260
261	print(f"{len(prompts)} prompts loaded")
262
263	print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded")
264	print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS]))
265
266	list(
267		run_maybe_parallel(
268			func=functools.partial(
269				process_prompt,
270				model_cfg=model_cfg,
271				save_path=save_path_p,
272				force_overwrite=force,
273			),
274			iterable=prompts,
275			parallel=parallel,
276			pbar="tqdm",
277			pbar_kwargs=dict(
278				desc="Making figures",
279				unit="prompt",
280			),
281		),
282	)
283
284	with SpinnerContext(
285		message="updating jsonl metadata for models and functions",
286		**SPINNER_KWARGS,
287	):
288		generate_models_jsonl(save_path_p)
289		generate_functions_jsonl(save_path_p)

main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

  • model_name : str model name to use, used for loading the model config, prompts, activations, and saving the figures
  • save_path : str base path to look in
  • n_samples : int max number of samples to process
  • force : bool force overwrite of existing figures. if False, will skip any functions which have already saved a figure
  • parallel : bool | int whether to run in parallel. if True, will use all available cores. if False, will run in serial. if an int, will try to use that many cores (defaults to True)
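A minimal sketch of calling figures_main from Python instead of the CLI; the gpt2 name is hypothetical, and the model config, prompts.jsonl, and activations must already exist under save_path / model_name:

    from pattern_lens.figures import figures_main

    figures_main(
        model_name="gpt2",       # hypothetical; must match an existing directory under save_path
        save_path="attn_data",   # the DATA_DIR default
        n_samples=10,
        force=False,
        parallel=False,          # serial run; True uses all cores, an int caps the worker count
    )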
def main() -> None:
292def main() -> None:
293	"generates figures from the activations using the functions decorated with `register_attn_figure_func`"
294	print(DIVIDER_S1)
295	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
296		arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
297		# input and output
298		arg_parser.add_argument(
299			"--model",
300			"-m",
301			type=str,
302			required=True,
303			help="The model name(s) to use. comma separated with no whitespace if multiple",
304		)
305		arg_parser.add_argument(
306			"--save-path",
307			"-s",
308			type=str,
309			required=False,
310			help="The path to save the attention patterns",
311			default=DATA_DIR,
312		)
313		# number of samples
314		arg_parser.add_argument(
315			"--n-samples",
316			"-n",
317			type=int,
318			required=False,
319			help="The max number of samples to process; process all in the file if None",
320			default=None,
321		)
322		# force overwrite of existing figures
323		arg_parser.add_argument(
324			"--force",
325			"-f",
326			type=bool,
327			required=False,
328			help="Force overwrite of existing figures",
329			default=False,
330		)
331
332		args: argparse.Namespace = arg_parser.parse_args()
333
334	print(f"args parsed: {args}")
335
336	models: list[str]
337	if "," in args.model:
338		models = args.model.split(",")
339	else:
340		models = [args.model]
341
342	n_models: int = len(models)
343	for idx, model in enumerate(models):
344		print(DIVIDER_S2)
345		print(f"processing model {idx + 1} / {n_models}: {model}")
346		print(DIVIDER_S2)
347		figures_main(
348			model_name=model,
349			save_path=args.save_path,
350			n_samples=args.n_samples,
351			force=args.force,
352		)
353
354	print(DIVIDER_S1)

generates figures from the activations using the functions decorated with register_attn_figure_func
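
A hedged CLI sketch for running several models in one go; the model names are hypothetical, and per the --model help text multiple names are passed comma-separated with no whitespace:

    python -m pattern_lens.figures --model gpt2,pythia-70m --save-path attn_data --n-samples 100

Note that --force is declared with type=bool, so any non-empty argument (even the string false) parses as True; leave the flag off entirely to keep the default of False.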