Coverage for pattern_lens / indexes.py: 90%
73 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1"""writes indexes to the model directory for the frontend to use or for record keeping"""
3import importlib.resources
4import inspect
5import itertools
6import json
7from collections.abc import Callable
8from pathlib import Path
10import pattern_lens
11from pattern_lens.attn_figure_funcs import (
12 _FIGURE_NAMES_KEY,
13 ATTENTION_MATRIX_FIGURE_FUNCS,
14)
def generate_prompts_jsonl(model_dir: Path) -> None:
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in every entry of `{model_dir}/prompts` for a `prompt.json` file and
    writes the contents of each one found as a single line of
    `{model_dir}/prompts.jsonl`. entries without a `prompt.json` are skipped.

    # Parameters:
     - `model_dir : Path`
        directory containing a `prompts` subdirectory

    # Raises:
     - `OSError` (e.g. `FileNotFoundError`) if `{model_dir}/prompts` cannot be listed
    """
    prompts: list[dict] = []
    # sorted: `iterdir` yields entries in arbitrary order, and the index should
    # be reproducible between runs and file systems
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.exists():
            # explicit encoding so parsing does not depend on the platform locale
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompts.append(json.load(f))

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(prompt) + "\n" for prompt in prompts)
def generate_models_jsonl(path: Path) -> None:
    """creates a `models.jsonl` file with all the models

    looks in every entry of `path` for a `model_cfg.json` file and writes the
    contents of each one found as a single line of `{path}/models.jsonl`.
    entries without a `model_cfg.json` are skipped.

    # Parameters:
     - `path : Path`
        directory whose subdirectories may each hold a `model_cfg.json`

    # Raises:
     - `OSError` (e.g. `FileNotFoundError`) if `path` cannot be listed
    """
    models: list[dict] = []
    # sorted: `iterdir` yields entries in arbitrary order, and the index should
    # be reproducible between runs and file systems
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.exists():
            # explicit encoding so parsing does not depend on the platform locale
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                models.append(json.load(f))

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(model) + "\n" for model in models)
def get_func_metadata(func: Callable) -> list[dict[str, str | None]]:
    """get metadata for a function

    # Parameters:
     - `func : Callable` which may have a `_FIGURE_NAMES_KEY` (by default `_figure_names`) attribute

    # Returns:

    `list[dict[str, str | None]]`
    one dictionary per figure the function produces (a single dictionary if
    the function has no, or an empty, `_FIGURE_NAMES_KEY` attribute), each containing:

     - `name : str` : the name of the figure
     - `func_name : str`
        the name of the function. if not a multi-figure function, this is identical to `name`
        if it is a multi-figure function, each entry of the `_FIGURE_NAMES_KEY` attribute
        is used verbatim as `name` (expected to look like `{func_name}.{figure_name}`)
     - `doc : str | None` : the docstring of the function
     - `figure_save_fmt : str | None` : the format of the figure that the function saves, using the `figure_save_fmt` attribute of the function. `None` if the attribute does not exist
     - `source : str | None` : the source file of the function, as a posix path. `None` if it cannot be determined
     - `code : str | None` : the source code of the function, as a single string. `None` if the source cannot be retrieved

    """
    # `inspect` raises `TypeError` for callables without Python source
    # (built-ins, callable objects such as `functools.partial`) and `OSError`
    # when the source file cannot be read; both map to `None` here instead
    # of aborting the whole index.
    try:
        source_file: str | None = inspect.getsourcefile(func)
    except TypeError:
        source_file = None

    code: str | None
    try:
        code = inspect.getsource(func)
    except (OSError, TypeError):
        code = None

    func_name: str = getattr(func, "__name__", "<unknown>")
    output: dict[str, str | None] = dict(
        func_name=func_name,
        doc=getattr(func, "__doc__", None),
        figure_save_fmt=getattr(func, "figure_save_fmt", None),
        source=Path(source_file).as_posix() if source_file else None,
        code=code,
    )

    # multi-figure functions advertise their figure names; everything else
    # is a single figure named after the function itself
    fig_names: list[str] | None = getattr(func, _FIGURE_NAMES_KEY, None)
    names: list[str] = fig_names if fig_names else [func_name]
    return [{"name": name, **output} for name in names]
def generate_functions_jsonl(path: Path) -> None:
    """unions all functions from `figures.jsonl` and `ATTENTION_MATRIX_FIGURE_FUNCS` into the file

    reads `{path}/figures.jsonl` if it exists, adds the metadata of every
    function in `ATTENTION_MATRIX_FIGURE_FUNCS` (entries with the same `name`
    replace the ones read from the file), and rewrites the file with one JSON
    object per line, sorted by `name`.

    # Parameters:
     - `path : Path`
        directory in which `figures.jsonl` is read and written
    """
    figures_file: Path = path / "figures.jsonl"
    existing_figures: dict[str, dict] = {}

    if figures_file.exists():
        with open(figures_file, "r", encoding="utf-8") as f:
            for line in f:
                # tolerate blank lines (e.g. a trailing newline added by hand)
                if not line.strip():
                    continue
                func_data: dict = json.loads(line)
                existing_figures[func_data["name"]] = func_data

    # metadata for every function in ATTENTION_MATRIX_FIGURE_FUNCS
    new_functions: dict[str, dict] = {
        meta["name"]: meta
        for func in ATTENTION_MATRIX_FIGURE_FUNCS
        for meta in get_func_metadata(func)
    }

    # later keys win, so fresh metadata overrides what was on disk
    all_functions: dict[str, dict] = {**existing_figures, **new_functions}

    with open(figures_file, "w", encoding="utf-8") as f:
        # the dict is keyed by each entry's `name`, so this sorts by name
        for name in sorted(all_functions):
            json.dump(all_functions[name], f)
            f.write("\n")
def write_html_index(
    path: Path,
    cfg_single: dict | None = None,
    cfg_patternlens: dict | None = None,
) -> None:
    """writes index.html and single.html files to the path

    the html is copied from the `frontend` resources bundled with the
    `pattern_lens` package: `patternlens.html` becomes `index.html`, and
    `single.html` keeps its name.

    # Parameters:
     - `path : Path`
        directory to write the files into
     - `cfg_single : dict | None`
        if not `None`, dumped as json to `{path}/sg_cfg.json`
        (defaults to `None`)
     - `cfg_patternlens : dict | None`
        if not `None`, dumped as json to `{path}/pl_cfg.json`
        (defaults to `None`)
    """
    # use the resource (Traversable) API directly rather than converting to
    # `Path`: no type-ignore needed, and it does not assume the package is
    # unpacked on the file system
    frontend_resources = importlib.resources.files(pattern_lens).joinpath("frontend")

    # read as utf-8 to match the encoding the files are written with below;
    # the locale default would corrupt non-ascii characters on some platforms
    pl_index_html: str = frontend_resources.joinpath("patternlens.html").read_text(
        encoding="utf-8",
    )
    sg_html: str = frontend_resources.joinpath("single.html").read_text(
        encoding="utf-8",
    )

    # Write both html files
    with open(path / "index.html", "w", encoding="utf-8") as f:
        f.write(pl_index_html)

    with open(path / "single.html", "w", encoding="utf-8") as f:
        f.write(sg_html)

    # write the config files if they are provided
    if cfg_single is not None:
        with open(path / "sg_cfg.json", "w", encoding="utf-8") as f:
            json.dump(cfg_single, f, indent="\t")

    if cfg_patternlens is not None:
        with open(path / "pl_cfg.json", "w", encoding="utf-8") as f:
            json.dump(cfg_patternlens, f, indent="\t")