Coverage for pattern_lens/indexes.py: 94%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-06 15:09 -0600

1"""writes indexes to the model directory for the frontend to use or for record keeping""" 

2 

3import importlib.metadata 

4import importlib.resources 

5import inspect 

6import itertools 

7import json 

8from collections.abc import Callable 

9from pathlib import Path 

10from typing import Literal 

11 

12import pattern_lens 

13from pattern_lens.attn_figure_funcs import ( 

14 _FIGURE_NAMES_KEY, 

15 ATTENTION_MATRIX_FIGURE_FUNCS, 

16) 

17 

18 

def generate_prompts_jsonl(model_dir: Path) -> None:
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in every entry of `{model_dir}/prompts` for a `prompt.json` file; each one found
    is parsed and written as one JSON object per line of `{model_dir}/prompts.jsonl`.
    entries without a `prompt.json` are skipped. prompts are emitted in sorted order of
    their directory paths so the output does not depend on filesystem listing order.

    # Parameters:
     - `model_dir : Path`
        model directory containing a `prompts/` subdirectory

    # Raises:
     - `FileNotFoundError` : if `{model_dir}/prompts` does not exist (raised by `Path.iterdir`)
    """
    prompts: list[dict] = list()
    # `iterdir` order is filesystem-dependent; sort for reproducible output
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.exists():
            # JSON is UTF-8; don't rely on the platform's locale encoding
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompt_data: dict = json.load(f)
            prompts.append(prompt_data)

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        for prompt in prompts:
            f.write(json.dumps(prompt))
            f.write("\n")

36 

37 

def generate_models_jsonl(path: Path) -> None:
    """creates a `models.jsonl` file with all the models found in `path`

    every entry of `path` that contains a `model_cfg.json` contributes that file's parsed
    JSON as one line of `{path}/models.jsonl`; entries without a `model_cfg.json` are
    skipped. models are emitted in sorted order of their paths so the output does not
    depend on filesystem listing order.

    # Parameters:
     - `path : Path`
        directory whose subdirectories are model directories
    """
    models: list[dict] = list()
    # `iterdir` order is filesystem-dependent; sort for reproducible output
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.exists():
            # JSON is UTF-8; don't rely on the platform's locale encoding
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                model_cfg: dict = json.load(f)
            models.append(model_cfg)

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        for model in models:
            f.write(json.dumps(model))
            f.write("\n")

52 

53 

def get_func_metadata(func: Callable) -> list[dict[str, str | None]]:
    """get metadata for a figure function, one entry per figure it produces

    # Parameters:
     - `func : Callable`
        a figure function. if it has a non-empty `_FIGURE_NAMES_KEY` (by default
        `_figure_names`) attribute, it is treated as a multi-figure function

    # Returns:
     - `list[dict[str, str | None]]`
        one dictionary per figure (a one-element list for single-figure functions).
        all entries for the same `func` share every field except `name`:

        - `name : str` : the name of the figure. for a single-figure function this is
          `func.__name__`; for a multi-figure function each entry of the
          `_FIGURE_NAMES_KEY` attribute is used verbatim (presumably already formatted as
          `{func_name}.{figure_name}` by whatever sets it -- not enforced here)
        - `func_name : str` : `func.__name__`
        - `doc : str | None` : the docstring of the function, `None` if it has none
        - `figure_save_fmt : str | None` : the `figure_save_fmt` attribute of the
          function, `None` if the attribute does not exist
        - `source : str | None` : the source file of the function as a POSIX-style path,
          `None` if it cannot be determined
        - `code : str | None` : the source code of the function as a single string,
          `None` if it cannot be retrieved
    """
    try:
        source_file: str | None = inspect.getsourcefile(func)
    except TypeError:
        # built-in / C-implemented callables have no Python source file
        source_file = None

    # fields shared by every figure this function produces
    shared: dict[str, str | None] = dict(
        func_name=func.__name__,
        doc=func.__doc__,
        figure_save_fmt=getattr(func, "figure_save_fmt", None),
        source=Path(source_file).as_posix() if source_file else None,
    )

    try:
        shared["code"] = inspect.getsource(func)
    except (OSError, TypeError):
        # OSError: source not retrievable (e.g. defined interactively);
        # TypeError: built-in / C-implemented callable
        shared["code"] = None

    fig_names: list[str] | None = getattr(func, _FIGURE_NAMES_KEY, None)
    if not fig_names:
        return [{"name": func.__name__, **shared}]
    return [{"name": fig_name, **shared} for fig_name in fig_names]

104 

105 

def generate_functions_jsonl(path: Path) -> None:
    """unions all functions from `figures.jsonl` and `ATTENTION_MATRIX_FIGURE_FUNCS` into the file

    existing entries of `{path}/figures.jsonl` (if the file exists) are read first and
    keyed by their `"name"`. metadata for every function in
    `ATTENTION_MATRIX_FIGURE_FUNCS` is then computed with `get_func_metadata`, which may
    yield several entries per function. when a name appears in both, the freshly
    computed entry replaces the old one; names only present in the old file are kept.
    the result is written back sorted by `"name"`, one JSON object per line.

    # Parameters:
     - `path : Path`
        directory containing (or that will contain) `figures.jsonl`
    """
    figures_file: Path = path / "figures.jsonl"

    # entries recorded by earlier runs -- kept even if no longer registered
    existing_figures: dict[str, dict] = dict()
    if figures_file.exists():
        with open(figures_file, "r", encoding="utf-8") as f:
            for line in f:
                # tolerate blank lines (e.g. left by manual edits) instead of
                # crashing in `json.loads`
                if not line.strip():
                    continue
                func_data: dict = json.loads(line)
                existing_figures[func_data["name"]] = func_data

    # metadata for the currently registered figure functions
    new_functions: dict[str, dict] = {
        meta["name"]: meta
        for meta in itertools.chain.from_iterable(
            get_func_metadata(func) for func in ATTENTION_MATRIX_FIGURE_FUNCS
        )
    }

    # later mapping wins: current metadata overrides stale entries of the same name
    all_functions: list[dict] = list({**existing_figures, **new_functions}.values())

    with open(figures_file, "w", encoding="utf-8") as f:
        for func_meta in sorted(all_functions, key=lambda x: x["name"]):
            json.dump(func_meta, f)
            f.write("\n")

136 

137 

def inline_assets(
    html: str,
    assets: list[tuple[Literal["script", "style"], str]],
    base_path: Path,
) -> str:
    """Inline specified local CSS/JS files into an HTML document.

    Each entry in `assets` should be a tuple like `("script", "app.js")` or `("style", "style.css")`.
    For each entry, the exact placeholder `<{tag_type} src="{filename}"></{tag_type}>`
    must occur exactly once in `html`. It is replaced by `<{tag_type}>`, a newline, the
    file's contents, a newline, and `</{tag_type}>`.

    # Parameters:
     - `html : str`
        input HTML content.
     - `assets : list[tuple[Literal["script", "style"], str]]`
        List of (tag_type, filename) tuples to inline, processed in order.
     - `base_path : Path`
        directory each `filename` is resolved against; files are read as UTF-8.

    # Returns:
     - `str` : Modified HTML content with inlined assets.

    # Raises:
     - `ValueError` : if a `tag_type` is not `"style"` or `"script"`
     - `AssertionError` : if a placeholder does not occur exactly once in `html`
    """
    for tag_type, filename in assets:
        if tag_type not in ("style", "script"):
            err_msg: str = f"Unsupported tag type: {tag_type}"
            raise ValueError(err_msg)

        # exact placeholder text used by the template (a plain substring, not a regex)
        pattern: str = f'<{tag_type} src="{filename}"></{tag_type}>'
        n_found: int = html.count(pattern)
        # explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept so existing callers/tests see the same exception type
        if n_found != 1:
            raise AssertionError(
                f"Pattern {pattern} should be in the html exactly once, found html.count(pattern) = {n_found}"
            )
        # explicit UTF-8: the default is the locale encoding, which mangles or
        # rejects non-ASCII CSS/JS on e.g. Windows
        content: str = (base_path / filename).read_text(encoding="utf-8")
        replacement: str = f"<{tag_type}>\n{content}\n</{tag_type}>"
        html = html.replace(pattern, replacement)

    return html

174 

175 

def write_html_index(path: Path) -> None:
    """writes an index.html file to the path

    reads `frontend/index.template.html` from the `pattern_lens` package resources,
    inlines `style.css`, `util.js` and `app.js` from that same directory via
    `inline_assets`, substitutes the installed `pattern-lens` distribution version for
    the `$$PATTERN_LENS_VERSION$$` placeholder, and writes the result to
    `{path}/index.html` as UTF-8.

    # Parameters:
     - `path : Path`
        directory to write `index.html` into
    """
    # TYPING: error: Argument 1 to "Path" has incompatible type "Traversable"; expected "str | PathLike[str]" [arg-type]
    frontend_dir: Path = Path(
        importlib.resources.files(pattern_lens).joinpath("frontend"),  # type: ignore[arg-type]
    )
    template_text: str = (frontend_dir / "index.template.html").read_text(encoding="utf-8")

    # local assets the template references as `<tag src="..."></tag>` placeholders
    asset_specs: list[tuple[Literal["script", "style"], str]] = [
        ("style", "style.css"),
        ("script", "util.js"),
        ("script", "app.js"),
    ]
    rendered: str = inline_assets(template_text, asset_specs, base_path=frontend_dir)

    # stamp the installed package version into the page
    version_str: str = importlib.metadata.version("pattern-lens")
    rendered = rendered.replace("$$PATTERN_LENS_VERSION$$", version_str)

    (path / "index.html").write_text(rendered, encoding="utf-8")