Coverage for pattern_lens/indexes.py: 94%
79 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
1"""writes indexes to the model directory for the frontend to use or for record keeping"""
3import importlib.metadata
4import importlib.resources
5import inspect
6import itertools
7import json
8from collections.abc import Callable
9from pathlib import Path
10from typing import Literal
12import pattern_lens
13from pattern_lens.attn_figure_funcs import (
14 _FIGURE_NAMES_KEY,
15 ATTENTION_MATRIX_FIGURE_FUNCS,
16)
def generate_prompts_jsonl(model_dir: Path) -> None:
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in all directories in `{model_dir}/prompts` for a `prompt.json` file.
    entries that do not contain a `prompt.json` are skipped.

    # Parameters:
    - `model_dir : Path`
        directory containing a `prompts` subdirectory. `prompts.jsonl` is written
        directly into `model_dir`, one JSON object per line.
    """
    prompts: list[dict] = []
    # sorted so that the output order is reproducible instead of depending on
    # the (arbitrary) order in which the filesystem lists the directory
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.exists():
            # explicit utf-8: prompt text may be non-ASCII, and the default
            # encoding is locale-dependent
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompts.append(json.load(f))

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(prompt) + "\n" for prompt in prompts)
def generate_models_jsonl(path: Path) -> None:
    """creates a `models.jsonl` file with all the models

    looks in every entry of `path` for a `model_cfg.json` file.
    entries that do not contain one are skipped.

    # Parameters:
    - `path : Path`
        directory whose entries are model directories. `models.jsonl` is written
        directly into `path`, one JSON object per line.
    """
    models: list[dict] = []
    # sorted so that the output order is reproducible instead of depending on
    # the (arbitrary) order in which the filesystem lists the directory
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.exists():
            # explicit utf-8 so that reading does not depend on the locale
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                models.append(json.load(f))

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(model) + "\n" for model in models)
def get_func_metadata(func: Callable) -> list[dict[str, str | None]]:
    """get metadata for a function

    # Parameters:
    - `func : Callable` which may have a `_FIGURE_NAMES_KEY` attribute listing figure names

    # Returns:

    `list[dict[str, str | None]]`
    one dictionary per figure name if `func` has a non-empty `_FIGURE_NAMES_KEY`
    attribute, otherwise a single dictionary. each dictionary contains:

    - `name : str` : the name of the figure. taken from the `_FIGURE_NAMES_KEY`
        attribute when present, otherwise identical to `func_name`
    - `func_name : str` : `func.__name__`
    - `doc : str | None` : the docstring of the function
    - `figure_save_fmt : str | None` : the format of the figure that the function saves, using the `figure_save_fmt` attribute of the function. `None` if the attribute does not exist
    - `source : str | None` : the source file of the function, as a posix path. `None` if it cannot be determined
    - `code : str | None` : the source code of the function as a single string. `None` if the source cannot be retrieved

    """
    # `inspect` raises `TypeError` for built-ins and other objects without Python
    # source, and `OSError` when the source file cannot be read. both mean
    # "no source available" here
    try:
        source_file: str | None = inspect.getsourcefile(func)
    except TypeError:
        source_file = None

    try:
        code: str | None = inspect.getsource(func)
    except (OSError, TypeError):
        code = None

    output: dict[str, str | None] = dict(
        func_name=func.__name__,
        doc=func.__doc__,
        figure_save_fmt=getattr(func, "figure_save_fmt", None),
        source=Path(source_file).as_posix() if source_file else None,
        code=code,
    )

    fig_names: list[str] | None = getattr(func, _FIGURE_NAMES_KEY, None)
    if fig_names:
        # multi-figure function: one entry per figure, all sharing the same metadata
        return [{"name": fig_name, **output} for fig_name in fig_names]

    return [{"name": func.__name__, **output}]
def generate_functions_jsonl(path: Path) -> None:
    """unions all functions from `figures.jsonl` and `ATTENTION_MATRIX_FIGURE_FUNCS` into the file

    entries are keyed by their `"name"`. when a name is present both in the existing
    file and in `ATTENTION_MATRIX_FIGURE_FUNCS`, the freshly generated metadata wins.
    the result is written back to `{path}/figures.jsonl`, sorted by name.

    # Parameters:
    - `path : Path`
        directory containing (or which will contain) `figures.jsonl`
    """
    figures_file: Path = path / "figures.jsonl"
    existing_figures: dict[str, dict] = {}

    if figures_file.exists():
        with open(figures_file, "r", encoding="utf-8") as f:
            for line in f:
                # tolerate blank lines (e.g. from hand edits), which
                # `json.loads` would otherwise reject
                if not line.strip():
                    continue
                func_data: dict = json.loads(line)
                existing_figures[func_data["name"]] = func_data

    # metadata for every function in ATTENTION_MATRIX_FIGURE_FUNCS
    # (a function may yield several entries)
    new_functions: dict[str, dict] = {
        func_meta["name"]: func_meta
        for func_meta in itertools.chain.from_iterable(
            get_func_metadata(func) for func in ATTENTION_MATRIX_FIGURE_FUNCS
        )
    }

    # later keys override earlier ones, so new metadata replaces stale entries
    all_functions: dict[str, dict] = {**existing_figures, **new_functions}

    with open(figures_file, "w", encoding="utf-8") as f:
        for func_meta in sorted(all_functions.values(), key=lambda x: x["name"]):
            json.dump(func_meta, f)
            f.write("\n")
def inline_assets(
    html: str,
    assets: list[tuple[Literal["script", "style"], str]],
    base_path: Path,
) -> str:
    """Inline specified local CSS/JS files into an HTML document.

    Each entry in `assets` should be a tuple like `("script", "app.js")` or `("style", "style.css")`.
    For each entry, the tag `<{tag_type} src="{filename}"></{tag_type}>` is replaced by
    `<{tag_type}>`, the file content, and `</{tag_type}>`, separated by newlines.
    Assets are processed in order, each on the result of the previous replacement.

    # Parameters:
    - `html : str`
        input HTML content.
    - `assets : list[tuple[Literal["script", "style"], str]]`
        List of (tag_type, filename) tuples to inline.
    - `base_path : Path`
        directory the asset filenames are resolved against.

    # Returns:
    `str` : Modified HTML content with inlined assets.

    # Raises:
    - `ValueError` : if a tag type is not `"script"` or `"style"`
    - `AssertionError` : if the tag for an asset is not in the html exactly once
    """
    for tag_type, filename in assets:
        if tag_type not in ("style", "script"):
            err_msg: str = f"Unsupported tag type: {tag_type}"
            raise ValueError(err_msg)

        # the exact tag text expected in the template for this asset
        pattern: str = f'<{tag_type} src="{filename}"></{tag_type}>'

        # explicit raise rather than an `assert` statement so that the check
        # is not stripped when running under `python -O`
        n_found: int = html.count(pattern)
        if n_found != 1:
            err_msg = f"Pattern {pattern} should be in the html exactly once, found html.count(pattern) = {n_found}"
            raise AssertionError(err_msg)

        # read as utf-8, matching how the html template itself is read and written
        content: str = (base_path / filename).read_text(encoding="utf-8")
        html = html.replace(pattern, f"<{tag_type}>\n{content}\n</{tag_type}>")

    return html
def write_html_index(path: Path) -> None:
    """writes an index.html file to the path

    reads `index.template.html` from the `frontend` resources of the `pattern_lens`
    package, inlines `style.css`, `util.js` and `app.js` from the same directory,
    substitutes the installed `pattern-lens` version for `$$PATTERN_LENS_VERSION$$`,
    and writes the result to `{path}/index.html`.

    # Parameters:
    - `path : Path`
        directory in which `index.html` is written
    """
    # TYPING: error: Argument 1 to "Path" has incompatible type "Traversable"; expected "str | PathLike[str]"  [arg-type]
    frontend_dir: Path = Path(
        importlib.resources.files(pattern_lens).joinpath("frontend"),  # type: ignore[arg-type]
    )

    # load the template
    template_file: Path = frontend_dir / "index.template.html"
    page: str = template_file.read_text(encoding="utf-8")

    # replace the stylesheet and script tags with the file contents
    to_inline: list[tuple[Literal["script", "style"], str]] = [
        ("style", "style.css"),
        ("script", "util.js"),
        ("script", "app.js"),
    ]
    page = inline_assets(page, to_inline, base_path=frontend_dir)

    # stamp the installed package version into the page
    version: str = importlib.metadata.version("pattern-lens")
    page = page.replace("$$PATTERN_LENS_VERSION$$", version)

    # write the finished page
    (path / "index.html").write_text(page, encoding="utf-8")