Coverage for coverage_sh/plugin.py: 100%

201 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-02-03 22:34 +0100

1# SPDX-License-Identifier: MIT 

2# Copyright (c) 2023-2024 Kilian Lackhove 

3 

from __future__ import annotations

import codecs
import contextlib
import inspect
import os
import selectors
import shlex
import stat
import string
import subprocess
import threading
from collections import defaultdict
from pathlib import Path
from random import Random
from socket import gethostname
from time import sleep
from typing import TYPE_CHECKING, Any, Iterable, Iterator

import coverage
import magic
from coverage import CoveragePlugin, FileReporter, FileTracer
from tree_sitter_languages import get_parser

25 

26if TYPE_CHECKING: 

27 from coverage.types import TLineNo 

28 from tree_sitter import Node 

29 

# Directory for the helper script and the trace FIFO: taken from
# $XDG_RUNTIME_DIR, falling back to /tmp when that variable is unset.
TMP_PATH = Path(os.getenv("XDG_RUNTIME_DIR", "/tmp"))  # noqa: S108
TRACEFILE_PREFIX = "shelltrace"
# tree-sitter-bash node types that count as executable statements; the line a
# matching node starts on is reported as an executable line.
EXECUTABLE_NODE_TYPES = {
    # simple commands and assignments
    "command",
    "declaration_command",
    "unset_command",
    "test_command",
    "negated_command",
    "redirected_statement",
    "variable_assignment",
    "variable_assignments",
    # grouping, command combinators and control flow
    "subshell",
    "pipeline",
    "list",
    "for_statement",
    "c_style_for_statement",
    "while_statement",
    "if_statement",
    "case_statement",
}
# libmagic MIME types of files that are treated as shell scripts
SUPPORTED_MIME_TYPES = {"text/x-shellscript"}

# bash parser, created once at import time and shared by all file reporters
parser = get_parser("bash")

53 

54 

class ShellFileReporter(FileReporter):
    """Report the source and the executable lines of one shell script.

    Executable lines are found by parsing the script with the tree-sitter bash
    grammar and taking the first line of every named node whose type is listed
    in ``EXECUTABLE_NODE_TYPES``.
    """

    def __init__(self, filename: str) -> None:
        super().__init__(filename)

        self.path = Path(filename)
        # file text, read lazily by source()
        self._content: str | None = None
        # 1-based executable line numbers, filled by _parse_ast()
        self._executable_lines: set[int] = set()
        # lines() parses the script once and reuses the result afterwards
        self._lines_parsed = False

    def source(self) -> str:
        """Return the script text, reading and caching it on first use.

        The file is decoded as UTF-8 regardless of the locale, matching the
        encoding used to hand the text to tree-sitter in ``lines()``.
        Undecodable bytes become U+FFFD, so one stray byte does not abort the
        whole report. Line breaks, and therefore line numbers, are unaffected.
        """
        if self._content is None:
            self._content = self.path.read_text(encoding="utf-8", errors="replace")

        return self._content

    def _parse_ast(self, node: Node) -> None:
        """Record the start line of every executable node in the subtree at ``node``.

        The tree is walked with an explicit stack instead of recursion, because
        deeply nested syntax trees could exceed Python's recursion limit.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current.is_named and current.type in EXECUTABLE_NODE_TYPES:
                # start_point is (row, column) with a 0-based row
                self._executable_lines.add(current.start_point[0] + 1)
            pending.extend(current.children)

    def lines(self) -> set[TLineNo]:
        """Return the set of 1-based executable line numbers of the script."""
        if not self._lines_parsed:
            tree = parser.parse(self.source().encode("utf-8"))
            self._parse_ast(tree.root_node)
            self._lines_parsed = True

        return self._executable_lines

81 

82 

def filename_suffix() -> str:
    """Build a suffix of the form ``<hostname>.<pid>.X<six random letters>x``.

    The random part is drawn from a generator seeded with ``os.urandom``, so
    concurrent processes on the same host get different suffixes.
    """
    rng = Random(os.urandom(8))
    alphabet = string.ascii_uppercase + string.ascii_lowercase
    tag = "".join([rng.choice(alphabet) for _ in range(6)])
    return ".".join((gethostname(), str(os.getpid()), f"X{tag}x"))

88 

89 

# absolute file path -> set of executed 1-based line numbers
LineData = dict[str, set[int]]


class CovLineParser:
    """Incrementally turn bash xtrace output into executed lines per file.

    Trace lines of interest look like ``COV:::<file>:::<lineno>:::<command>``
    (the ``PS4`` set by the helper script). Anything before the first
    ``COV:::`` is ignored, so a repeated leading ``PS4`` character is
    tolerated. Input arrives in arbitrary byte chunks. Text after the last
    newline is buffered until the next chunk or ``flush()``.
    """

    def __init__(self) -> None:
        # text after the last "\n" seen so far; completed by the next chunk
        self._last_line_fragment = ""
        # Incremental UTF-8 decoder: a multi-byte character split across two
        # reads is held back instead of raising UnicodeDecodeError. Invalid
        # bytes (traced commands may carry arbitrary data) become U+FFFD
        # rather than killing the reader thread.
        self._decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        self.line_data: LineData = defaultdict(set)

    def parse(self, buf: bytes) -> None:
        """Consume one chunk of raw trace output.

        Raises:
            ValueError: if a line containing ``COV:::`` cannot be parsed.
        """
        self._report_lines(list(self._buf_to_lines(buf)))

    def _buf_to_lines(self, buf: bytes, *, final: bool = False) -> Iterator[str]:
        """Yield the complete, non-empty lines (without "\\n") now available.

        Args:
            buf: next chunk of raw bytes.
            final: also flush bytes held back by the decoder (end of stream).
        """
        raw = self._last_line_fragment + self._decoder.decode(buf, final=final)
        # Split on "\n" only. str.splitlines() would also break on \r, \f,
        # \x1c-\x1e, \x85, ... which may legitimately appear inside a traced
        # command and would corrupt the fragment bookkeeping.
        *complete, self._last_line_fragment = raw.split("\n")
        for line in complete:
            if line:
                yield line

    def _report_lines(self, lines: list[str]) -> None:
        """Record the (file, line) pair of every ``COV:::`` trace line."""
        for line in lines:
            if "COV:::" not in line:
                continue

            try:
                _, path_, lineno_, _ = line.split(":::", maxsplit=3)
                lineno = int(lineno_)
                # NOTE: a relative BASH_SOURCE is resolved against this
                # process's cwd, which may differ from the traced shell's cwd.
                path = Path(path_).absolute()
            except ValueError as e:
                raise ValueError(f"could not parse line {line}") from e

            self.line_data[str(path)].add(lineno)

    def flush(self) -> None:
        """Treat any buffered partial line as complete and record it.

        This also finalizes the decoder, so a truncated multi-byte sequence at
        the end of one stream does not leak into the next one.
        """
        self._report_lines(list(self._buf_to_lines(b"\n", final=True)))

132 

133 

class CoverageWriter:
    """Write collected shell line data into a separate, suffixed coverage data file."""

    def __init__(self, coverage_data_path: Path):
        # base name of the data file; write() appends a fresh suffix per call
        self._coverage_data_path = coverage_data_path

    def write(self, line_data: LineData) -> None:
        """Store ``line_data`` (file -> executed lines) in a newly suffixed data file.

        Each file is tagged with the ``coverage_sh.ShellPlugin`` file tracer
        name (presumably so coverage.py asks this plugin for the file reporter).
        """
        data_file = coverage.CoverageData(
            # TODO: This probably won't work with pytest-cov
            basename=self._coverage_data_path,
            suffix=f"sh.{filename_suffix()}",
            # TODO: set warn, debug and no_disk
        )

        plugin_of_file = dict.fromkeys(line_data, "coverage_sh.ShellPlugin")
        data_file.add_file_tracers(plugin_of_file)
        data_file.add_lines(line_data)
        data_file.write()

152 

153 

class CoverageParserThread(threading.Thread):
    """Background thread that collects bash xtrace output from a FIFO.

    The FIFO is created in ``__init__``. Everything read from it is fed to a
    ``CovLineParser``. After ``stop()`` the thread finishes draining, hands
    the collected line data to the ``CoverageWriter`` and removes the FIFO.
    """

    def __init__(
        self,
        coverage_writer: CoverageWriter,
        name: str | None = None,
        parser: CovLineParser | None = None,
    ) -> None:
        """Create the FIFO (mode 0600) that traced shells will append to.

        Args:
            coverage_writer: receives the collected line data when the thread ends.
            name: thread name.
            parser: line parser to feed; a fresh ``CovLineParser`` by default.
        """
        super().__init__(name=name)
        self._keep_running = True
        # set by run() once the FIFO is open for reading; start() waits for it
        self._listening = False
        self._parser = parser or CovLineParser()
        self._coverage_writer = coverage_writer

        self.fifo_path = TMP_PATH / f"coverage-sh.{filename_suffix()}.pipe"
        with contextlib.suppress(FileNotFoundError):
            self.fifo_path.unlink()
        os.mkfifo(self.fifo_path, mode=stat.S_IRUSR | stat.S_IWUSR)

    def start(self) -> None:
        """Start the thread and block until the FIFO is open for reading.

        Raises:
            RuntimeError: if the thread terminated before it started listening
                (e.g. opening the FIFO failed), instead of waiting forever.
        """
        super().start()
        while not self._listening:
            # is_alive() is checked before _listening: run() sets the flag
            # before it can finish, so "dead and not listening" is definitive
            if not self.is_alive() and not self._listening:
                raise RuntimeError(
                    f"{self.name} terminated before listening on {self.fifo_path}"
                )
            sleep(0.0001)

    def stop(self) -> None:
        """Ask the thread to finish once the FIFO has been idle or reached EOF."""
        self._keep_running = False

    def run(self) -> None:
        """Read the FIFO until stopped, then write the data and remove the FIFO."""
        try:
            with selectors.DefaultSelector() as sel:
                while self._keep_running:
                    # we need to keep reopening the fifo as long as the subprocess
                    # is running because multiple bash processes might write EOFs to it
                    self._drain_fifo(sel)

            self._coverage_writer.write(self._parser.line_data)
        finally:
            # remove the FIFO even if reading or parsing failed
            with contextlib.suppress(FileNotFoundError):
                self.fifo_path.unlink()

    def _drain_fifo(self, sel: selectors.BaseSelector) -> None:
        """Open the FIFO once and parse its data until EOF, or until stopped and idle."""
        fifo = os.open(self.fifo_path, flags=os.O_RDONLY | os.O_NONBLOCK)
        try:
            sel.register(fifo, selectors.EVENT_READ)
            try:
                self._listening = True

                eof = False
                data_incoming = True
                # after stop(), keep reading until one select() times out empty
                while not eof and (data_incoming or self._keep_running):
                    events = sel.select(timeout=1)
                    data_incoming = len(events) > 0
                    for key, _ in events:
                        buf = os.read(key.fd, 2**10)
                        if not buf:
                            eof = True
                            break
                        self._parser.parse(buf)

                self._parser.flush()
            finally:
                sel.unregister(fifo)
        finally:
            os.close(fifo)

209 

210 

# the unpatched Popen, captured at import time before ShellPlugin may replace
# subprocess.Popen with PatchedPopen
OriginalPopen = subprocess.Popen


def init_helper(fifo_path: Path) -> Path:
    """Write the shell snippet that enables line tracing into ``fifo_path``.

    The snippet is meant to be sourced via ``BASH_ENV``/``ENV``. It does four things:

    * sets ``PS4`` so every xtrace line starts with ``COV:::<file>:::<line>:::``;
    * opens the FIFO for appending on a descriptor chosen by bash, whose number
      is stored in ``BASH_XTRACEFD`` so xtrace output goes there;
    * exports ``BASH_XTRACEFD``;
    * enables ``set -x``.

    Args:
        fifo_path: FIFO that the trace output should be written to.

    Returns:
        Path of the new helper file (mode 0600) inside ``TMP_PATH``.

    Raises:
        FileExistsError: if something already exists at the generated path.
    """
    helper_path = Path(TMP_PATH, f"coverage-sh.{filename_suffix()}.sh")
    # Quote for the shell: the path derives from $XDG_RUNTIME_DIR, and a " $ or
    # backtick in it would otherwise break, or be expanded by, the sourced script.
    quoted_fifo = shlex.quote(str(fifo_path))
    script = rf"""#!/bin/sh
PS4="COV:::\${{BASH_SOURCE}}:::\${{LINENO}}:::"
exec {{BASH_XTRACEFD}}>>{quoted_fifo}
export BASH_XTRACEFD
set -x
"""
    owner_rw = stat.S_IRUSR | stat.S_IWUSR
    # Every traced shell sources this file. Create it exclusively ("x" mode:
    # never follow or overwrite an existing file/symlink) with owner-only
    # permissions from the start.
    with open(
        helper_path, "x", opener=lambda path, flags: os.open(path, flags, owner_rw)
    ) as fh:
        fh.write(script)
    # the umask may have stripped bits from the creation mode; force exactly 0600
    helper_path.chmod(mode=owner_rw)
    return helper_path

226 

227 

class PatchedPopen(OriginalPopen):
    """``subprocess.Popen`` replacement that records shell coverage of the child.

    While a coverage measurement is active, every instance starts its own
    ``CoverageParserThread`` and a helper script. The helper is injected into
    the child through ``BASH_ENV``/``ENV``. Recording ends in ``wait()``.
    """

    # base path of the coverage data file; set by ShellPlugin before patching
    data_file_path: Path = Path.cwd()

    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        if coverage.Coverage.current() is None:
            # we are not recording coverage, so just act like the original Popen
            self._parser_thread = None
            super().__init__(*args, **kwargs)
            return

        # Convert positional args into kwargs. This must use the signature of
        # the *original* Popen: once patched, ``subprocess.Popen`` is this class,
        # whose signature is only ``(*args, **kwargs)``.
        sig = inspect.signature(OriginalPopen)
        kwargs.update(dict(zip(sig.parameters.keys(), args)))

        self._helper_path: Path | None = None
        self._parser_thread = CoverageParserThread(
            coverage_writer=CoverageWriter(coverage_data_path=self.data_file_path),
            name="CoverageParserThread(None)",
        )
        self._parser_thread.start()

        try:
            self._helper_path = init_helper(self._parser_thread.fifo_path)

            # env=None (the Popen default) means "inherit os.environ". Copy in
            # either case so the caller's mapping is not modified.
            caller_env = kwargs.get("env")
            env = dict(os.environ if caller_env is None else caller_env)
            env["BASH_ENV"] = str(self._helper_path)
            env["ENV"] = str(self._helper_path)
            kwargs["env"] = env

            super().__init__(**kwargs)
        except BaseException:
            # Popen (or the helper) failed: without this, the non-daemon parser
            # thread would never be stopped and would block interpreter exit
            self._stop_recording()
            raise

    def wait(self, timeout: float | None = None) -> int:
        """Wait for the child, then finish recording its shell coverage.

        If ``timeout`` expires, ``TimeoutExpired`` propagates and recording
        continues until a later ``wait()`` succeeds.
        """
        retval = super().wait(timeout)
        if self._parser_thread is None:
            # no coverage recording was active during __init__
            return retval

        self._stop_recording()
        return retval

    def _stop_recording(self) -> None:
        """Stop and join the parser thread (it writes the data file) and delete the helper."""
        if self._parser_thread is not None:
            self._parser_thread.stop()
            self._parser_thread.join()
        if self._helper_path is not None:
            with contextlib.suppress(FileNotFoundError):
                self._helper_path.unlink()

268 

269 

class MonitorThread(threading.Thread):
    """Stop a parser thread as soon as a watched thread has finished.

    By default the watched thread is the interpreter's main thread, so the
    parser is stopped (and writes its data) when the main program ends.
    """

    def __init__(
        self,
        parser_thread: CoverageParserThread,
        main_thread: threading.Thread | None = None,
        name: str | None = None,
    ) -> None:
        super().__init__(name=name)
        self.parser_thread = parser_thread
        # the thread whose termination triggers the shutdown
        self._main_thread = main_thread or threading.main_thread()

    def run(self) -> None:
        self._main_thread.join()
        # stop() only requests shutdown; join() waits until the parser's run() returns
        target = self.parser_thread
        target.stop()
        target.join()

285 

286 

class ShellPlugin(CoveragePlugin):
    """coverage.py plugin that measures and reports line coverage of shell scripts."""

    def __init__(self, options: dict[str, Any]):
        self.options = options
        # helper script exported via BASH_ENV/ENV in cover_always mode; removed in __del__
        self._helper_path: Path | None = None

        coverage_data_path = Path(coverage.Coverage().config.data_file).absolute()
        if self.options.get("cover_always", False):
            # Trace every shell started by this process (or its children, through
            # the inherited environment) until the main thread exits.
            parser_thread = CoverageParserThread(
                coverage_writer=CoverageWriter(coverage_data_path),
                name=f"CoverageParserThread({coverage_data_path!s})",
            )
            parser_thread.start()

            # stops the parser thread, which then writes its data, once the main thread ends
            monitor_thread = MonitorThread(
                parser_thread=parser_thread, name="MonitorThread"
            )
            monitor_thread.start()

            self._helper_path = init_helper(parser_thread.fifo_path)
            os.environ["BASH_ENV"] = str(self._helper_path)
            os.environ["ENV"] = str(self._helper_path)
        else:
            # only trace shells started through subprocess.Popen while coverage is active
            PatchedPopen.data_file_path = coverage_data_path
            subprocess.Popen = PatchedPopen

    def __del__(self) -> None:
        if self._helper_path is not None:
            with contextlib.suppress(FileNotFoundError):
                self._helper_path.unlink()

    @staticmethod
    def _is_relevant(path: Path) -> bool:
        """Return True if libmagic classifies ``path`` as a shell script."""
        return magic.from_file(path.resolve(), mime=True) in SUPPORTED_MIME_TYPES

    def file_tracer(self, filename: str) -> FileTracer | None:  # noqa: ARG002
        # no per-file Python tracing: shell line data is collected by the parser threads
        return None

    def file_reporter(
        self,
        filename: str,
    ) -> ShellFileReporter | str:
        return ShellFileReporter(filename)

    def find_executable_files(
        self,
        src_dir: str,
    ) -> Iterable[str]:
        """Yield every shell script below ``src_dir``, skipping hidden entries.

        Only path components *below* ``src_dir`` count as hidden. ``src_dir``
        may itself be absolute and live under a dot-directory (e.g. ``~/.local``)
        or be a relative ``../x``. Checking the full path would then skip
        every file.
        """
        root = Path(src_dir)
        for f in root.rglob("*"):
            # TODO: Use coverage's logic for figuring out if a file should be excluded
            if any(p.startswith(".") for p in f.relative_to(root).parts):
                continue

            # is_file() follows symlinks: links to regular files are kept, while
            # dangling links (which libmagic cannot open) and links to
            # directories or FIFOs are skipped
            if not f.is_file():
                continue

            if self._is_relevant(f):
                yield str(f)