Coverage for pattern_lens\indexes.py: 95%
65 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-03 19:30 -0700
1"""writes indexes to the model directory for the frontend to use or for record keeping"""
3import importlib.metadata
4import importlib.resources
5import inspect
6import itertools
7import json
8from collections.abc import Callable
9from pathlib import Path
11import pattern_lens
12from pattern_lens.attn_figure_funcs import (
13 _FIGURE_NAMES_KEY,
14 ATTENTION_MATRIX_FIGURE_FUNCS,
15)
def generate_prompts_jsonl(model_dir: Path) -> None:
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in every entry of `{model_dir}/prompts` for a `prompt.json` file and
    writes the parsed contents of each one found as a single JSON line of
    `{model_dir}/prompts.jsonl`. entries without a `prompt.json` are skipped.

    # Parameters:
     - `model_dir : Path`
        directory that contains a `prompts/` subdirectory

    # Raises:
     - `FileNotFoundError` : if `{model_dir}/prompts` does not exist
    """
    prompts: list[dict] = list()
    # `iterdir` yields entries in arbitrary, filesystem-dependent order; sort so
    # the generated index is deterministic across machines and runs
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.exists():
            # explicit utf-8: the platform default codec (e.g. cp1252 on
            # Windows) cannot decode non-ASCII prompt text
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompt_data: dict = json.load(f)
            prompts.append(prompt_data)

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        for prompt in prompts:
            f.write(json.dumps(prompt))
            f.write("\n")
def generate_models_jsonl(path: Path) -> None:
    """creates a `models.jsonl` file with all the models

    looks in every entry of `path` for a `model_cfg.json` file and writes the
    parsed contents of each one found as a single JSON line of
    `{path}/models.jsonl`. entries without a `model_cfg.json` (including plain
    files such as a previously written `models.jsonl`) are skipped.

    # Parameters:
     - `path : Path`
        directory whose subdirectories are per-model directories

    # Raises:
     - `FileNotFoundError` : if `path` does not exist
    """
    models: list[dict] = list()
    # `iterdir` order is arbitrary and filesystem-dependent; sort so the
    # generated index is deterministic
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.exists():
            # explicit utf-8 rather than the platform default codec
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                model_cfg: dict = json.load(f)
            models.append(model_cfg)

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        for model in models:
            f.write(json.dumps(model))
            f.write("\n")
def get_func_metadata(func: Callable) -> list[dict[str, str | None]]:
    """get metadata for a function, one entry per figure it produces

    # Parameters:
     - `func : Callable`
        may have a `_FIGURE_NAMES_KEY` (by default `_figure_names`) attribute
        listing the names of the multiple figures it produces

    # Returns:

    `list[dict[str, str | None]]`
    one dictionary per figure produced by `func`, each containing (in order):

     - `name : str` : the name of the figure. if `func` has a non-empty
       `_FIGURE_NAMES_KEY` attribute (a multi-figure function), there is one
       dictionary per entry of that attribute and `name` is that entry
       (NOTE(review): presumably already of the form `{func_name}.{figure_name}`
       where the attribute is set — confirm). otherwise there is a single
       dictionary and `name` is `func.__name__`
     - `func_name : str` : `func.__name__`
     - `doc : str | None` : the docstring of the function, `None` if it has none
     - `figure_save_fmt : str | None` : the `figure_save_fmt` attribute of the
       function, `None` if the attribute does not exist
     - `source : str | None` : the source file of the function as a posix-style
       path, `None` if it cannot be determined
     - `code : str | None` : the source code of the function as a single
       string, `None` if the source cannot be retrieved
    """
    try:
        source_file: str | None = inspect.getsourcefile(func)
    except TypeError:
        # built-in / C-implemented callables have no Python source file
        source_file = None

    output: dict[str, str | None] = dict(
        func_name=func.__name__,
        doc=func.__doc__,
        figure_save_fmt=getattr(func, "figure_save_fmt", None),
        source=Path(source_file).as_posix() if source_file else None,
    )

    try:
        output["code"] = inspect.getsource(func)
    except (OSError, TypeError):
        # OSError: source file unreadable; TypeError: built-in callable
        output["code"] = None

    fig_names: list[str] | None = getattr(func, _FIGURE_NAMES_KEY, None)
    # a single-figure function is reported under its own name
    names: list[str] = list(fig_names) if fig_names else [func.__name__]
    return [{"name": fig_name, **output} for fig_name in names]
def generate_functions_jsonl(path: Path) -> None:
    """unions all functions from `figures.jsonl` and `ATTENTION_MATRIX_FIGURE_FUNCS` into the file

    entries are keyed by their `name` field. metadata for every figure of every
    function currently in `ATTENTION_MATRIX_FIGURE_FUNCS` replaces any existing
    entry with the same name; existing entries with no current counterpart are
    kept. the file `{path}/figures.jsonl` is rewritten with one JSON object per
    line, sorted by `name`.

    # Parameters:
     - `path : Path`
        directory in which `figures.jsonl` is read (if present) and written
    """
    figures_file: Path = path / "figures.jsonl"
    existing_figures: dict[str, dict] = dict()

    if figures_file.exists():
        with open(figures_file, "r", encoding="utf-8") as f:
            for line in f:
                # tolerate blank lines (e.g. from manual edits); json.loads
                # would raise on them
                if not line.strip():
                    continue
                func_data: dict = json.loads(line)
                existing_figures[func_data["name"]] = func_data

    # metadata for every figure of every currently registered figure function
    new_functions: dict[str, dict] = {
        func_meta["name"]: func_meta
        for func_meta in itertools.chain.from_iterable(
            get_func_metadata(func) for func in ATTENTION_MATRIX_FIGURE_FUNCS
        )
    }

    # later mapping wins: fresh metadata overrides stale entries of the same name
    all_functions: dict[str, dict] = {**existing_figures, **new_functions}

    with open(figures_file, "w", encoding="utf-8") as f:
        for name in sorted(all_functions):
            json.dump(all_functions[name], f)
            f.write("\n")
def write_html_index(path: Path) -> None:
    """write the frontend `index.html` bundled with `pattern_lens` into `path`

    the `$$PATTERN_LENS_VERSION$$` placeholder in the bundled template is
    replaced with the installed `pattern-lens` distribution version before the
    result is written to `{path}/index.html` as utf-8.
    """
    template_file = importlib.resources.files(pattern_lens).joinpath(
        "frontend/index.html"
    )
    template_text: str = template_file.read_text(encoding="utf-8")

    installed_version: str = importlib.metadata.version("pattern-lens")
    rendered: str = template_text.replace(
        "$$PATTERN_LENS_VERSION$$", installed_version
    )

    output_file: Path = path / "index.html"
    with output_file.open("w", encoding="utf-8") as out:
        out.write(rendered)