Coverage for pattern_lens\indexes.py: 97%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-16 20:39 -0700

1"""writes indexes to the model directory for the frontend to use or for record keeping""" 

2 

3import inspect 

4import json 

5from pathlib import Path 

6import importlib.resources 

7import importlib.metadata 

8from typing import Callable 

9 

10import pattern_lens 

11from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS 

12 

13 

def generate_prompts_jsonl(model_dir: Path):
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in every subdirectory of `{model_dir}/prompts` for a `prompt.json` file
    and writes each one, in sorted directory order, as one JSON line of
    `{model_dir}/prompts.jsonl`. Subdirectories without a `prompt.json` are skipped.

    # Parameters:
     - `model_dir : Path`
       directory containing a `prompts/` subdirectory

    # Raises:
     - `FileNotFoundError` : if `{model_dir}/prompts` does not exist
    """
    prompts: list[dict] = list()
    # sort so the output file is reproducible; `iterdir()` order is arbitrary
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.is_file():
            # explicit utf-8: the platform default encoding (e.g. cp1252 on
            # Windows) can fail to decode non-ASCII prompt text
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompt_data: dict = json.load(f)
            prompts.append(prompt_data)

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        for prompt in prompts:
            f.write(json.dumps(prompt))
            f.write("\n")

31 

32 

def generate_models_jsonl(path: Path):
    """creates a `models.jsonl` file with all the models

    looks in every subdirectory of `path` for a `model_cfg.json` file and writes
    each one, in sorted directory order, as one JSON line of `{path}/models.jsonl`.
    Entries without a `model_cfg.json` (including plain files) are skipped.

    # Parameters:
     - `path : Path`
       directory whose subdirectories are per-model directories
    """
    models: list[dict] = list()
    # sort so the output file is reproducible; `iterdir()` order is arbitrary
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.is_file():
            # explicit utf-8 rather than the platform default encoding
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                model_cfg: dict = json.load(f)
            models.append(model_cfg)

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        for model in models:
            f.write(json.dumps(model))
            f.write("\n")

47 

48 

49def get_func_metadata(func: Callable) -> dict[str, str | None]: 

50 """get metadata for a function 

51 

52 # Parameters: 

53 - `func : Callable` 

54 

55 # Returns: 

56 

57 `dict[str, str | None]` 

58 dictionary: 

59 

60 - `name : str` : the name of the function 

61 - `doc : str` : the docstring of the function 

62 - `figure_save_fmt : str | None` : the format of the figure that the function saves, using the `figure_save_fmt` attribute of the function. `None` if the attribute does not exist 

63 - `source : str | None` : the source file of the function 

64 - `code : str | None` : the source code of the function, split by line. `None` if the source file cannot be read 

65 

66 """ 

67 source_file: str | None = inspect.getsourcefile(func) 

68 output: dict[str, str | None] = dict( 

69 name=func.__name__, 

70 doc=func.__doc__, 

71 figure_save_fmt=getattr(func, "figure_save_fmt", None), 

72 source=Path(source_file).as_posix() if source_file else None, 

73 ) 

74 

75 try: 

76 output["code"] = inspect.getsource(func) 

77 except OSError: 

78 output["code"] = None 

79 

80 return output 

81 

82 

def generate_functions_jsonl(path: Path):
    """unions all functions from file and current `ATTENTION_MATRIX_FIGURE_FUNCS` into a `functions.jsonl` file

    entries already present in `{path}/functions.jsonl` are kept; entries for
    functions in `ATTENTION_MATRIX_FIGURE_FUNCS` are (re)computed with
    `get_func_metadata` and replace any existing entry with the same name.
    The file is rewritten sorted by function name, one JSON object per line.

    # Parameters:
     - `path : Path`
       directory in which `functions.jsonl` is read and written
    """
    functions_file: Path = path / "functions.jsonl"
    existing_functions: dict[str, dict] = dict()

    if functions_file.exists():
        with open(functions_file, "r", encoding="utf-8") as f:
            for line in f:
                # tolerate blank lines (e.g. a trailing newline or a hand edit),
                # which `json.loads` would otherwise reject
                if not line.strip():
                    continue
                func_data: dict = json.loads(line)
                existing_functions[func_data["name"]] = func_data

    # metadata for the currently registered figure functions
    new_functions: dict[str, dict] = {
        func.__name__: get_func_metadata(func) for func in ATTENTION_MATRIX_FIGURE_FUNCS
    }

    # later dict wins: current registry entries override stale file entries
    all_functions: list[dict] = sorted(
        {**existing_functions, **new_functions}.values(),
        key=lambda x: x["name"],
    )

    with open(functions_file, "w", encoding="utf-8") as f:
        for func_meta in all_functions:
            json.dump(func_meta, f)
            f.write("\n")

110 

111 

def write_html_index(path: Path):
    """writes an index.html file to the path

    reads the `frontend/index.html` template bundled in the `pattern_lens`
    package, replaces every `$$PATTERN_LENS_VERSION$$` placeholder with the
    installed `pattern-lens` distribution version, and writes the result to
    `{path}/index.html` as utf-8.
    """
    template_resource = importlib.resources.files(pattern_lens).joinpath(
        "frontend/index.html"
    )
    template_text: str = template_resource.read_text(encoding="utf-8")

    installed_version: str = importlib.metadata.version("pattern-lens")
    rendered: str = template_text.replace("$$PATTERN_LENS_VERSION$$", installed_version)

    output_file: Path = path / "index.html"
    with open(output_file, "w", encoding="utf-8") as out:
        out.write(rendered)