Coverage for pattern_lens\indexes.py: 97%
60 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
1"""writes indexes to the model directory for the frontend to use or for record keeping"""
3import inspect
4import json
5from pathlib import Path
6import importlib.resources
7import importlib.metadata
8from typing import Callable
10import pattern_lens
11from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
def generate_prompts_jsonl(model_dir: Path):
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in every entry of `{model_dir}/prompts` for a `prompt.json` file and
    writes the parsed contents of each one, as a single line of JSON, to
    `{model_dir}/prompts.jsonl`. entries without a `prompt.json` are skipped.
    entries are visited in sorted order so the output is deterministic across
    platforms and filesystems.

    # Parameters:
     - `model_dir : Path`
        directory containing a `prompts/` subdirectory

    # Raises:
     - `FileNotFoundError` : if `{model_dir}/prompts` does not exist
     - `json.JSONDecodeError` : if a `prompt.json` is not valid JSON
    """
    prompts: list[dict] = list()
    # `iterdir()` order is arbitrary (filesystem-dependent); sort for stable output
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.exists():
            # JSON is UTF-8; don't depend on the platform's locale encoding
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompt_data: dict = json.load(f)
            prompts.append(prompt_data)

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        for prompt in prompts:
            f.write(json.dumps(prompt))
            f.write("\n")
def generate_models_jsonl(path: Path):
    """creates a `models.jsonl` file with all the models

    looks in every entry of `path` for a `model_cfg.json` file and writes the
    parsed contents of each one, as a single line of JSON, to
    `{path}/models.jsonl`. entries without a `model_cfg.json` (including plain
    files such as a previously written `models.jsonl`) are skipped. entries are
    visited in sorted order so the output is deterministic.

    # Parameters:
     - `path : Path`
        directory to scan; each subdirectory holding a `model_cfg.json` is
        treated as one model

    # Raises:
     - `FileNotFoundError` : if `path` does not exist
     - `json.JSONDecodeError` : if a `model_cfg.json` is not valid JSON
    """
    models: list[dict] = list()
    # `iterdir()` order is arbitrary (filesystem-dependent); sort for stable output
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.exists():
            # JSON is UTF-8; don't depend on the platform's locale encoding
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                model_cfg: dict = json.load(f)
            models.append(model_cfg)

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        for model in models:
            f.write(json.dumps(model))
            f.write("\n")
49def get_func_metadata(func: Callable) -> dict[str, str | None]:
50 """get metadata for a function
52 # Parameters:
53 - `func : Callable`
55 # Returns:
57 `dict[str, str | None]`
58 dictionary:
60 - `name : str` : the name of the function
61 - `doc : str` : the docstring of the function
62 - `figure_save_fmt : str | None` : the format of the figure that the function saves, using the `figure_save_fmt` attribute of the function. `None` if the attribute does not exist
63 - `source : str | None` : the source file of the function
64 - `code : str | None` : the source code of the function, split by line. `None` if the source file cannot be read
66 """
67 source_file: str | None = inspect.getsourcefile(func)
68 output: dict[str, str | None] = dict(
69 name=func.__name__,
70 doc=func.__doc__,
71 figure_save_fmt=getattr(func, "figure_save_fmt", None),
72 source=Path(source_file).as_posix() if source_file else None,
73 )
75 try:
76 output["code"] = inspect.getsource(func)
77 except OSError:
78 output["code"] = None
80 return output
def generate_functions_jsonl(path: Path):
    """unions all functions from file and current `ATTENTION_MATRIX_FIGURE_FUNCS` into a `functions.jsonl` file

    reads any existing `{path}/functions.jsonl` (one JSON object per line,
    keyed by its `"name"` field) and merges in fresh metadata from
    `get_func_metadata` for every function in `ATTENTION_MATRIX_FIGURE_FUNCS`.
    entries for currently registered functions replace any entry of the same
    name read from the file; entries present only in the file are kept. the
    file is then rewritten with one JSON object per line, sorted by name.

    # Parameters:
     - `path : Path`
        directory in which `functions.jsonl` is read and written
    """
    functions_file: Path = path / "functions.jsonl"
    existing_functions: dict[str, dict] = dict()

    if functions_file.exists():
        with open(functions_file, "r", encoding="utf-8") as f:
            for line in f:
                # skip blank / whitespace-only lines, which `json.loads` rejects
                if not line.strip():
                    continue
                func_data: dict = json.loads(line)
                existing_functions[func_data["name"]] = func_data

    # metadata for the currently registered figure functions; these override
    # any entry of the same name read from the file
    new_functions: dict[str, dict] = {
        func.__name__: get_func_metadata(func) for func in ATTENTION_MATRIX_FIGURE_FUNCS
    }

    all_functions: dict[str, dict] = {**existing_functions, **new_functions}

    with open(functions_file, "w", encoding="utf-8") as f:
        for func_meta in sorted(all_functions.values(), key=lambda x: x["name"]):
            json.dump(func_meta, f)
            f.write("\n")
def write_html_index(path: Path):
    """render the bundled frontend `index.html` into `path`

    reads `frontend/index.html` from the `pattern_lens` package resources,
    replaces every `$$PATTERN_LENS_VERSION$$` placeholder with the installed
    `pattern-lens` distribution version, and writes the result to
    `{path}/index.html` as UTF-8.

    # Parameters:
     - `path : Path`
        output directory

    # Raises:
     - `importlib.metadata.PackageNotFoundError` : if the `pattern-lens`
        distribution is not installed
    """
    template_resource = importlib.resources.files(pattern_lens).joinpath(
        "frontend/index.html"
    )
    template: str = template_resource.read_text(encoding="utf-8")

    version: str = importlib.metadata.version("pattern-lens")
    rendered: str = template.replace("$$PATTERN_LENS_VERSION$$", version)

    (path / "index.html").write_text(rendered, encoding="utf-8")