Coverage for pattern_lens\indexes.py: 95%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-03 19:30 -0700

1"""writes indexes to the model directory for the frontend to use or for record keeping""" 

2 

3import importlib.metadata 

4import importlib.resources 

5import inspect 

6import itertools 

7import json 

8from collections.abc import Callable 

9from pathlib import Path 

10 

11import pattern_lens 

12from pattern_lens.attn_figure_funcs import ( 

13 _FIGURE_NAMES_KEY, 

14 ATTENTION_MATRIX_FIGURE_FUNCS, 

15) 

16 

17 

def generate_prompts_jsonl(model_dir: Path) -> None:
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in all entries of `{model_dir}/prompts` for a `prompt.json` file,
    parses each one found, and writes it as one JSON object per line to
    `{model_dir}/prompts.jsonl`. entries without a `prompt.json` are skipped.

    # Parameters:
    - `model_dir : Path`
        directory containing a `prompts` subdirectory; `prompts.jsonl` is written directly inside it

    # Raises:
    - `FileNotFoundError` : if `{model_dir}/prompts` does not exist
    - `json.JSONDecodeError` : if a `prompt.json` file does not contain valid JSON
    """
    prompts: list[dict] = []
    # `iterdir()` yields entries in arbitrary order; sort so the index is reproducible
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.exists():
            # explicit utf-8: the platform default encoding is locale-dependent
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompts.append(json.load(f))

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(prompt) + "\n" for prompt in prompts)

35 

36 

def generate_models_jsonl(path: Path) -> None:
    """creates a `models.jsonl` file with all the models

    looks in all entries of `path` for a `model_cfg.json` file, parses each
    one found, and writes it as one JSON object per line to `{path}/models.jsonl`.
    entries without a `model_cfg.json` are skipped.

    # Parameters:
    - `path : Path`
        directory whose entries are model directories; `models.jsonl` is written directly inside it

    # Raises:
    - `FileNotFoundError` : if `path` does not exist
    - `json.JSONDecodeError` : if a `model_cfg.json` file does not contain valid JSON
    """
    models: list[dict] = []
    # `iterdir()` yields entries in arbitrary order; sort so the index is reproducible
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.exists():
            # explicit utf-8: the platform default encoding is locale-dependent
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                models.append(json.load(f))

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(model) + "\n" for model in models)

51 

52 

def get_func_metadata(func: Callable) -> list[dict[str, str | None]]:
    """collect metadata records for a figure function

    # Parameters:
    - `func : Callable`
        the function to describe. it may carry a `_FIGURE_NAMES_KEY` attribute
        listing the figures it produces

    # Returns:

    `list[dict[str, str | None]]`
        one dictionary per figure. if `func` has a non-empty `_FIGURE_NAMES_KEY`
        attribute there is one entry per name in it, otherwise a single entry.
        each dictionary contains:

        - `name : str`
            the figure name. taken as-is from the `_FIGURE_NAMES_KEY` attribute when present,
            otherwise identical to `func_name`
        - `func_name : str` : `func.__name__`
        - `doc : str | None` : `func.__doc__`
        - `figure_save_fmt : str | None`
            the `figure_save_fmt` attribute of the function, `None` if the attribute does not exist
        - `source : str | None`
            posix-style path of the source file of the function, `None` if it cannot be determined
        - `code : str | None`
            the source code of the function as a single string. `None` if reading it raises `OSError`
    """
    src_path: str | None = inspect.getsourcefile(func)

    # fields shared by every figure this function produces
    shared: dict[str, str | None] = {
        "func_name": func.__name__,
        "doc": func.__doc__,
        "figure_save_fmt": getattr(func, "figure_save_fmt", None),
        "source": None if not src_path else Path(src_path).as_posix(),
    }

    try:
        shared["code"] = inspect.getsource(func)
    except OSError:
        shared["code"] = None

    # a missing or empty list of figure names means this is a single-figure function
    declared_names: list[str] | None = getattr(func, _FIGURE_NAMES_KEY, None)
    figure_names: list[str] = declared_names if declared_names else [func.__name__]

    return [{"name": figure_name, **shared} for figure_name in figure_names]

103 

104 

def generate_functions_jsonl(path: Path) -> None:
    """unions all functions from `figures.jsonl` and `ATTENTION_MATRIX_FIGURE_FUNCS` into the file

    reads any existing entries from `{path}/figures.jsonl`, merges in the metadata
    of every function in `ATTENTION_MATRIX_FIGURE_FUNCS` (as returned by `get_func_metadata`),
    and rewrites the file with one JSON object per line, sorted by `name`.
    when a name is present in both, the freshly generated metadata replaces the existing entry.

    # Parameters:
    - `path : Path`
        directory containing (or to contain) `figures.jsonl`

    # Raises:
    - `json.JSONDecodeError` : if a non-blank line of the existing file is not valid JSON
    - `KeyError` : if an existing entry has no `"name"` key
    """
    figures_file: Path = path / "figures.jsonl"
    existing_figures: dict[str, dict] = {}

    if figures_file.exists():
        with open(figures_file, "r", encoding="utf-8") as f:
            for line in f:
                # tolerate blank lines (e.g. left over from hand-editing) instead of failing to parse them
                if not line.strip():
                    continue
                func_data: dict = json.loads(line)
                existing_figures[func_data["name"]] = func_data

    # metadata for every currently registered figure function, keyed by figure name
    new_functions: dict[str, dict] = {
        func_meta["name"]: func_meta
        for func in ATTENTION_MATRIX_FIGURE_FUNCS
        for func_meta in get_func_metadata(func)
    }

    # later keys win, so registered functions override stale entries from the file
    all_functions: dict[str, dict] = {**existing_figures, **new_functions}

    with open(figures_file, "w", encoding="utf-8") as f:
        for func_meta in sorted(all_functions.values(), key=lambda x: x["name"]):
            json.dump(func_meta, f)
            f.write("\n")

135 

136 

def write_html_index(path: Path) -> None:
    """writes an `index.html` file to the path

    loads `frontend/index.html` from the installed `pattern_lens` package,
    replaces the `$$PATTERN_LENS_VERSION$$` placeholder with the installed version
    of the `pattern-lens` distribution, and saves the result as `{path}/index.html`.

    # Parameters:
    - `path : Path`
        directory in which `index.html` is written
    """
    template_resource = importlib.resources.files(pattern_lens).joinpath(
        "frontend/index.html",
    )
    template: str = template_resource.read_text(encoding="utf-8")

    version: str = importlib.metadata.version("pattern-lens")
    rendered: str = template.replace("$$PATTERN_LENS_VERSION$$", version)

    (path / "index.html").write_text(rendered, encoding="utf-8")