Coverage for pattern_lens / indexes.py: 90%

73 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

1"""writes indexes to the model directory for the frontend to use or for record keeping""" 

2 

3import importlib.resources 

4import inspect 

5import itertools 

6import json 

7from collections.abc import Callable 

8from pathlib import Path 

9 

10import pattern_lens 

11from pattern_lens.attn_figure_funcs import ( 

12 _FIGURE_NAMES_KEY, 

13 ATTENTION_MATRIX_FIGURE_FUNCS, 

14) 

15 

16 

def generate_prompts_jsonl(model_dir: Path) -> None:
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in every entry of `{model_dir}/prompts` for a `prompt.json` file,
    and writes one JSON object per line to `{model_dir}/prompts.jsonl`.
    entries without a `prompt.json` are skipped.

    # Parameters:
    - `model_dir : Path` directory which contains a `prompts` subdirectory

    # Raises:
    - `FileNotFoundError` (from `iterdir`) if `{model_dir}/prompts` does not exist
    - `json.JSONDecodeError` if a `prompt.json` file is not valid JSON
    """
    prompts: list[dict] = []
    # sorted so that the index is reproducible: `iterdir` order depends on the filesystem
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.exists():
            # explicit encoding, so reading does not depend on the platform locale
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompts.append(json.load(f))

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        for prompt in prompts:
            f.write(json.dumps(prompt))
            f.write("\n")

34 

35 

def generate_models_jsonl(path: Path) -> None:
    """creates a `models.jsonl` file with all the models

    looks in every entry of `path` for a `model_cfg.json` file,
    and writes one JSON object per line to `{path}/models.jsonl`.
    entries without a `model_cfg.json` (including plain files) are skipped.

    # Parameters:
    - `path : Path` directory whose subdirectories may contain `model_cfg.json`

    # Raises:
    - `FileNotFoundError` (from `iterdir`) if `path` does not exist
    - `json.JSONDecodeError` if a `model_cfg.json` file is not valid JSON
    """
    models: list[dict] = []
    # sorted so that the index is reproducible: `iterdir` order depends on the filesystem
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.exists():
            # explicit encoding, so reading does not depend on the platform locale
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                models.append(json.load(f))

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        for model in models:
            f.write(json.dumps(model))
            f.write("\n")

50 

51 

def get_func_metadata(func: Callable) -> list[dict[str, str | None]]:
    """get metadata for a function

    # Parameters:
    - `func : Callable` which may carry a `_FIGURE_NAMES_KEY` attribute listing the names of the figures it produces

    # Returns:

    `list[dict[str, str | None]]`
    one dictionary per figure name (a single dictionary if `func` has no figure names), containing:

    - `name : str` : the name of the figure.
        taken verbatim from the `_FIGURE_NAMES_KEY` attribute if it is present and non-empty,
        otherwise identical to `func_name`
    - `func_name : str` : the `__name__` of the function, or `"<unknown>"` if it has none
    - `doc : str | None` : the docstring of the function
    - `figure_save_fmt : str | None` : the `figure_save_fmt` attribute of the function. `None` if the attribute does not exist
    - `source : str | None` : posix path of the source file of the function. `None` if it cannot be determined
    - `code : str | None` : the source code of the function, as a single string. `None` if the source cannot be retrieved

    """
    func_name: str = getattr(func, "__name__", "<unknown>")

    # `inspect` raises `TypeError` for objects without python source (e.g. builtins);
    # report those as "no source" rather than failing
    source_file: str | None
    try:
        source_file = inspect.getsourcefile(func)
    except TypeError:
        source_file = None

    # `OSError` is raised when the source exists in principle but cannot be read
    code: str | None
    try:
        code = inspect.getsource(func)
    except (OSError, TypeError):
        code = None

    output: dict[str, str | None] = dict(
        func_name=func_name,
        doc=getattr(func, "__doc__", None),
        figure_save_fmt=getattr(func, "figure_save_fmt", None),
        source=Path(source_file).as_posix() if source_file else None,
        code=code,
    )

    # a missing or empty list of figure names means a single figure, named after the function
    fig_names: list[str] | None = getattr(func, _FIGURE_NAMES_KEY, None)
    names: list[str] = list(fig_names) if fig_names else [func_name]
    return [{"name": name, **output} for name in names]

103 

104 

def generate_functions_jsonl(path: Path) -> None:
    """unions all functions from `figures.jsonl` and `ATTENTION_MATRIX_FIGURE_FUNCS` into the file

    reads any existing `{path}/figures.jsonl`, adds the metadata of every function in
    `ATTENTION_MATRIX_FIGURE_FUNCS` (see `get_func_metadata`), and rewrites the file
    with one JSON object per line, sorted by `name`.
    if a name is present in both, the freshly generated metadata replaces the existing entry.

    # Parameters:
    - `path : Path` directory in which `figures.jsonl` is read and written

    # Raises:
    - `json.JSONDecodeError` if a non-blank line of the existing file is not valid JSON
    - `KeyError` if an existing entry has no `"name"` key
    """
    figures_file: Path = path / "figures.jsonl"
    existing_figures: dict[str, dict] = {}

    if figures_file.exists():
        with open(figures_file, "r", encoding="utf-8") as f:
            for line in f:
                # tolerate blank lines instead of failing in `json.loads`
                if not line.strip():
                    continue
                func_data: dict = json.loads(line)
                existing_figures[func_data["name"]] = func_data

    # metadata for every figure of every currently registered function
    new_functions: dict[str, dict] = {
        func_meta["name"]: func_meta
        for func in ATTENTION_MATRIX_FIGURE_FUNCS
        for func_meta in get_func_metadata(func)
    }

    # later entries win: current metadata overrides what was on disk
    all_functions: dict[str, dict] = {**existing_figures, **new_functions}

    with open(figures_file, "w", encoding="utf-8") as f:
        # keys are the `name` fields, so this sorts the entries by name
        for name in sorted(all_functions):
            json.dump(all_functions[name], f)
            f.write("\n")

135 

136 

def write_html_index(
    path: Path,
    cfg_single: dict | None = None,
    cfg_patternlens: dict | None = None,
) -> None:
    """writes index.html and single.html files to the path

    copies the `patternlens.html` (as `index.html`) and `single.html` files from the
    `frontend` resources of the `pattern_lens` package into `path`.

    # Parameters:
    - `path : Path` directory to write the files to
    - `cfg_single : dict | None` if not `None`, written to `{path}/sg_cfg.json`
        (defaults to `None`)
    - `cfg_patternlens : dict | None` if not `None`, written to `{path}/pl_cfg.json`
        (defaults to `None`)
    """
    # TYPING: error: Argument 1 to "Path" has incompatible type "Traversable"; expected "str | PathLike[str]"  [arg-type]
    frontend_resources_path: Path = Path(
        importlib.resources.files(pattern_lens).joinpath("frontend"),  # type: ignore[arg-type]
    )

    # read as utf-8 to match the encoding used for writing below,
    # rather than depending on the platform locale
    pl_index_html: str = (frontend_resources_path / "patternlens.html").read_text(
        encoding="utf-8",
    )
    sg_html: str = (frontend_resources_path / "single.html").read_text(
        encoding="utf-8",
    )

    # Write both html files
    with open(path / "index.html", "w", encoding="utf-8") as f:
        f.write(pl_index_html)

    with open(path / "single.html", "w", encoding="utf-8") as f:
        f.write(sg_html)

    # write the config files if they are provided
    if cfg_single is not None:
        with open(path / "sg_cfg.json", "w", encoding="utf-8") as f:
            json.dump(cfg_single, f, indent="\t")

    if cfg_patternlens is not None:
        with open(path / "pl_cfg.json", "w", encoding="utf-8") as f:
            json.dump(cfg_patternlens, f, indent="\t")