Coverage for coverage_sh/plugin.py: 64%

122 statements  

« prev     ^ index     » next       coverage.py v7.3.3, created at 2024-01-02 20:32 +0100

1# SPDX-License-Identifier: MIT 

2# Copyright (c) 2023-2024 Kilian Lackhove 

3 

4from __future__ import annotations 

5 

6import atexit 

7import inspect 

8import os 

9import subprocess 

10from collections import defaultdict 

11from pathlib import Path 

12from shutil import which 

13from socket import gethostname 

14from typing import TYPE_CHECKING, Any, Iterable 

15 

16import coverage 

17import magic 

18from coverage import CoveragePlugin, FileReporter, FileTracer 

19from coverage.sqldata import filename_suffix 

20from tree_sitter_languages import get_parser 

21 

22if TYPE_CHECKING: 

23 from coverage.types import TLineNo 

24 from tree_sitter import Node 

25 

# Names and paths of the shell interpreters whose invocations are traced.
# ``which`` returns None when a shell is not installed; such results are
# dropped so that None never becomes a member of the set. Otherwise a Popen
# call made without an explicit ``executable`` (i.e. ``executable=None``)
# would be mistaken for a shell invocation.
SUPPORTED_SHELLS = {
    "sh",
    "/bin/sh",
    "/usr/bin/sh",
    "bash",
    "/bin/bash",
    "/usr/bin/bash",
    *(
        resolved
        for resolved in (which("sh"), which("bash"))
        if resolved is not None
    ),
}

# Syntax-tree node types whose starting line is reported as an executable
# line of a shell script.
EXECUTABLE_NODE_TYPES = {
    # simple commands and assignments
    "command",
    "declaration_command",
    "unset_command",
    "test_command",
    "negated_command",
    "variable_assignment",
    "variable_assignments",
    # control flow
    "if_statement",
    "case_statement",
    "for_statement",
    "c_style_for_statement",
    "while_statement",
    # command compositions
    "subshell",
    "redirected_statement",
    "pipeline",
    "list",
}

# MIME types that mark a file as a shell script when scanning for sources.
SUPPORTED_MIME_TYPES = ("text/x-shellscript",)

55 

# Bash parser shared by every ShellFileReporter instance.
parser = get_parser("bash")


class ShellFileReporter(FileReporter):
    """Report the source text and the executable lines of a shell script.

    Executable lines are found by parsing the script with the module-level
    bash ``parser`` and collecting the first line of every named syntax node
    whose type is listed in ``EXECUTABLE_NODE_TYPES``.
    """

    def __init__(self, filename: str) -> None:
        super().__init__(filename)

        self.path = Path(filename)
        # File content, loaded on first call to source().
        self._content = None
        # 1-based numbers of the executable lines, filled by lines().
        self._executable_lines = set()
        # True once the source has been parsed, so that repeated calls to
        # lines() do not parse the whole file again.
        self._lines_parsed = False

    def source(self) -> str:
        """Return the text of the script, reading it from disk on first use."""
        if self._content is None:
            # Decode explicitly as UTF-8 instead of the locale encoding:
            # lines() re-encodes this text as UTF-8 before parsing it.
            self._content = self.path.read_text(encoding="utf-8")

        return self._content

    def _parse_ast(self, node: Node) -> None:
        """Record the first line of every executable node at or below *node*.

        The tree is walked with an explicit stack rather than recursion so
        that deeply nested scripts cannot exceed the recursion limit.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current.is_named and current.type in EXECUTABLE_NODE_TYPES:
                # start_point[0] is the 0-based row; line numbers are 1-based.
                self._executable_lines.add(current.start_point[0] + 1)
            pending.extend(current.children)

    def lines(self) -> set[TLineNo]:
        """Return the set of executable line numbers of the script."""
        if not self._lines_parsed:
            tree = parser.parse(self.source().encode("utf-8"))
            self._parse_ast(tree.root_node)
            self._lines_parsed = True

        return self._executable_lines

85 

86 

# Reference to the unpatched class. ShellPlugin rebinds subprocess.Popen to
# PatchedPopen, so the original must be captured here.
OriginalPopen = subprocess.Popen


class PatchedPopen(OriginalPopen):
    """Popen replacement that records an execution trace of shell subprocesses.

    When coverage measurement is active and the command (or ``executable``)
    is one of ``SUPPORTED_SHELLS``, the child is started with ``-x`` and with
    ``BASH_XTRACEFD``/``PS4`` set so that its trace is written to a file below
    ``tracefiles_dir_path``. Every other call is passed to the original Popen
    with its arguments unchanged.
    """

    # Directory that receives the trace files; must be assigned before a
    # traced process is started.
    tracefiles_dir_path: None | Path = None

    # Parameter names of the original constructor, in positional order. They
    # are taken from OriginalPopen because subprocess.Popen may already refer
    # to this class, whose own signature is only (*args, **kwargs).
    _popen_parameters = tuple(inspect.signature(OriginalPopen).parameters)

    def __init__(self, *args, **kwargs):
        self._tracefile_path = None
        self._tracefile_fd = None

        # Convert positional arguments into keyword arguments.
        kwargs.update(zip(self._popen_parameters, args))

        argv = self._as_argv(kwargs.get("args"))
        program = argv[0] if argv else None
        executable = kwargs.get("executable")

        # With shell=True the items after the first one become positional
        # parameters of the shell, so inserting "-x" would not enable tracing;
        # such calls are therefore left untraced.
        if (
            argv
            and not kwargs.get("shell")
            and coverage.Coverage.current() is not None
            and any(
                candidate is not None and candidate in SUPPORTED_SHELLS
                for candidate in (program, executable)
            )
        ):
            self._init_trace(kwargs)
        else:
            super().__init__(**kwargs)

    @staticmethod
    def _as_argv(command):
        """Return Popen's *args* value as a list, or None when it is missing."""
        if command is None:
            return None
        if isinstance(command, (str, bytes, os.PathLike)):
            # A lone string or path names the program itself; it must not be
            # split into characters.
            return [command]
        return list(command)

    def _init_trace(self, kwargs: dict[str, Any]) -> None:
        """Start the child process with xtrace output redirected to a trace file.

        *kwargs* are the keyword arguments for the original Popen constructor;
        ``env``, ``args`` and ``pass_fds`` are replaced by extended copies.
        """
        self._tracefile_path = (
            self.tracefiles_dir_path / f"shelltrace.{filename_suffix(suffix=True)}"
        )
        self._tracefile_path.parent.mkdir(parents=True, exist_ok=True)
        self._tracefile_path.touch()
        self._tracefile_fd = os.open(self._tracefile_path, flags=os.O_RDWR | os.O_CREAT)

        try:
            # Work on a copy so that neither os.environ nor a mapping owned by
            # the caller is modified. An explicit env=None means "inherit".
            base_env = kwargs.get("env")
            env = dict(os.environ if base_env is None else base_env)
            env["BASH_XTRACEFD"] = str(self._tracefile_fd)
            env["PS4"] = "COV:::$BASH_SOURCE:::$LINENO:::"
            kwargs["env"] = env

            argv = self._as_argv(kwargs.get("args")) or []
            argv.insert(1, "-x")
            kwargs["args"] = argv

            pass_fds = list(kwargs.get("pass_fds", ()))
            pass_fds.append(self._tracefile_fd)
            kwargs["pass_fds"] = pass_fds

            super().__init__(**kwargs)
        finally:
            # The child has its own copy of the descriptor once it is spawned
            # (and none if spawning failed), so the parent's copy is released
            # right away instead of lingering until garbage collection.
            fd, self._tracefile_fd = self._tracefile_fd, None
            os.close(fd)

    def __del__(self):
        # Safety net: release the descriptor if it is still open.
        fd = getattr(self, "_tracefile_fd", None)
        if fd is not None:
            self._tracefile_fd = None
            os.close(fd)
        super().__del__()

138 

139 

class ShellPlugin(CoveragePlugin):
    """Coverage plugin that measures shell scripts run as subprocesses.

    Creating the plugin replaces ``subprocess.Popen`` with ``PatchedPopen``
    and registers an ``atexit`` handler that converts the trace files written
    by this process into a coverage data file.
    """

    def __init__(self, options: dict[str, str]):
        self.options = options
        self._data = None
        self.tracefiles_dir_path = Path.cwd() / ".coverage-sh"
        self.tracefiles_dir_path.mkdir(parents=True, exist_ok=True)

        atexit.register(self._convert_traces)
        # TODO: Does this work with multithreading? Isn't there an easier way of finding the base path?
        PatchedPopen.tracefiles_dir_path = self.tracefiles_dir_path
        subprocess.Popen = PatchedPopen

    def _init_data(self) -> None:
        """Create the coverage data object on first use."""
        if self._data is None:
            self._data = coverage.CoverageData(
                # TODO: make basename configurable
                basename=self.tracefiles_dir_path.parent / ".coverage",
                suffix="sh." + filename_suffix(suffix=True),
                # TODO: set warn, debug and no_disk
            )

    def _convert_traces(self) -> None:
        """Convert and delete the trace files written by this process."""
        self._init_data()

        own_traces = f"shelltrace.{gethostname()}.{os.getpid()}.*"
        for tracefile_path in self.tracefiles_dir_path.glob(own_traces):
            line_data = self._parse_tracefile(tracefile_path)
            self._write_trace(line_data)

            tracefile_path.unlink(missing_ok=True)

        # Remove the directory only if nothing is left in it. rmdir() fails
        # with OSError when the directory still has entries or no longer
        # exists; both are fine here, and trying directly avoids the race of
        # checking for emptiness first.
        try:
            self.tracefiles_dir_path.rmdir()
        except OSError:
            pass

    @staticmethod
    def _parse_tracefile(tracefile_path: Path) -> dict[str, set[int]]:
        """Return the executed line numbers per absolute script path.

        Each usable trace line has the form ``<prefix>:::<path>:::<lineno>:::<rest>``.
        Lines that do not match are skipped, as are lines with an empty path
        or a non-numeric line number. A missing file yields an empty dict.
        """
        line_data = defaultdict(set)
        try:
            # Undecodable bytes are replaced rather than aborting the whole
            # conversion.
            fd = tracefile_path.open("r", errors="replace")
        except FileNotFoundError:
            return {}

        with fd:
            for line in fd:
                # Split at most three times so that ":::" inside the traced
                # command text stays in the last field.
                fields = line.split(":::", 3)
                if len(fields) != 4:
                    continue
                _, path, lineno, _ = fields
                if not path or not lineno.isdecimal():
                    continue
                line_data[str(Path(path).absolute())].add(int(lineno))

        return dict(line_data)

    def _write_trace(self, line_data: dict[str, set[int]]) -> None:
        """Add *line_data* to the coverage data and write it out."""
        self._data.add_file_tracers({f: "coverage_sh.ShellPlugin" for f in line_data})
        self._data.add_lines(line_data)
        self._data.write()

    @staticmethod
    def _is_relevant(path: Path) -> bool:
        """Return True if the MIME type of *path* is in ``SUPPORTED_MIME_TYPES``."""
        return magic.from_file(path, mime=True) in SUPPORTED_MIME_TYPES

    def file_tracer(self, filename: str) -> FileTracer | None:  # noqa: ARG002
        """Return None: this plugin provides no in-process file tracer."""
        return None

    def file_reporter(
        self,
        filename: str,
    ) -> ShellFileReporter | str:
        """Return the reporter for *filename*."""
        return ShellFileReporter(filename)

    def find_executable_files(
        self,
        src_dir: str,
    ) -> Iterable[str]:
        """Yield the paths of all shell scripts below *src_dir*.

        Files inside hidden directories and hidden files are skipped.
        """
        root = Path(src_dir)
        for f in root.rglob("*"):
            if not f.is_file():
                continue

            # Only components below src_dir are checked; a dot in the path
            # leading to src_dir itself must not exclude every file.
            if any(part.startswith(".") for part in f.relative_to(root).parts):
                continue

            if self._is_relevant(f):
                yield str(f)

214 

215 if self._is_relevant(f): 

216 yield str(f)