Coverage for coverage_sh/plugin.py: 64%
122 statements
« prev ^ index » next coverage.py v7.3.3, created at 2024-01-02 20:32 +0100
« prev ^ index » next coverage.py v7.3.3, created at 2024-01-02 20:32 +0100
1# SPDX-License-Identifier: MIT
2# Copyright (c) 2023-2024 Kilian Lackhove
4from __future__ import annotations
6import atexit
7import inspect
8import os
9import subprocess
10from collections import defaultdict
11from pathlib import Path
12from shutil import which
13from socket import gethostname
14from typing import TYPE_CHECKING, Any, Iterable
16import coverage
17import magic
18from coverage import CoveragePlugin, FileReporter, FileTracer
19from coverage.sqldata import filename_suffix
20from tree_sitter_languages import get_parser
22if TYPE_CHECKING:
23 from coverage.types import TLineNo
24 from tree_sitter import Node
# Program names/paths that identify a subprocess as a shell invocation whose
# execution should be traced. `which()` returns None when the shell is not
# installed; None must not end up in the set, otherwise a membership test with
# an unset (None) executable would match every subprocess.
SUPPORTED_SHELLS = {
    shell
    for shell in (
        "sh",
        "/bin/sh",
        "/usr/bin/sh",
        which("sh"),
        "bash",
        "/bin/bash",
        "/usr/bin/bash",
        which("bash"),
    )
    if shell is not None
}
# Syntax-tree node types whose starting line is reported as an executable line
# (see ShellFileReporter._parse_ast).
EXECUTABLE_NODE_TYPES = {
    # commands
    "command",
    "declaration_command",
    "unset_command",
    "test_command",
    "negated_command",
    # assignments
    "variable_assignment",
    "variable_assignments",
    # compound constructs
    "subshell",
    "redirected_statement",
    "pipeline",
    "list",
    # control flow
    "if_statement",
    "case_statement",
    "for_statement",
    "c_style_for_statement",
    "while_statement",
}

# MIME types (as reported by `magic`) of files this plugin measures.
SUPPORTED_MIME_TYPES = ("text/x-shellscript",)
# Module-wide bash parser, created once at import time and shared by every
# ShellFileReporter instance.
parser = get_parser("bash")
class ShellFileReporter(FileReporter):
    """Report source text and executable line numbers for one shell script."""

    def __init__(self, filename: str) -> None:
        """Create a reporter for the script at *filename*; nothing is read yet."""
        super().__init__(filename)

        self.path = Path(filename)
        self._content: str | None = None
        self._executable_lines: set[TLineNo] = set()
        # Set once the script has been parsed, so repeated `lines()` calls do
        # not parse the same (cached) source again.
        self._lines_parsed = False

    def source(self) -> str:
        """Return the script's text, reading the file on first use only."""
        if self._content is None:
            self._content = self.path.read_text()

        return self._content

    def _parse_ast(self, node: Node) -> None:
        """Record the 1-based start line of every executable node under *node*.

        The tree is walked with an explicit stack instead of recursion so that
        deeply nested scripts cannot exhaust Python's recursion limit.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current.is_named and current.type in EXECUTABLE_NODE_TYPES:
                # start_point is (row, column) with a 0-based row.
                self._executable_lines.add(current.start_point[0] + 1)
            pending.extend(current.children)

    def lines(self) -> set[TLineNo]:
        """Return the set of line numbers that hold executable statements."""
        if not self._lines_parsed:
            tree = parser.parse(self.source().encode("utf-8"))
            self._parse_ast(tree.root_node)
            self._lines_parsed = True

        return self._executable_lines
# Keep a handle on the genuine Popen: ShellPlugin replaces `subprocess.Popen`
# with PatchedPopen, so the original must be captured before that happens.
OriginalPopen = subprocess.Popen


class PatchedPopen(OriginalPopen):
    """`subprocess.Popen` replacement that runs shell invocations with xtrace.

    When coverage is active and the call launches one of SUPPORTED_SHELLS, the
    shell is started with `-x` and its trace output is directed into a
    per-process file below `tracefiles_dir_path`. Every other call is forwarded
    to the original Popen unchanged.
    """

    # Directory receiving the trace files; assigned by ShellPlugin.
    tracefiles_dir_path: None | Path = None

    # Parameter names of the *original* Popen constructor, in positional order.
    # They must come from OriginalPopen: once `subprocess.Popen` is patched its
    # signature is this class's `(*args, **kwargs)`, which is useless for
    # mapping positional arguments to names.
    _POPEN_PARAMETERS = tuple(inspect.signature(OriginalPopen).parameters)

    def __init__(self, *args, **kwargs):
        """Start the process, enabling shell tracing when applicable."""
        self._tracefile_path = None
        self._tracefile_fd = None

        kwargs = self._bind_arguments(args, kwargs)

        if coverage.Coverage.current() is not None and self._invokes_supported_shell(
            kwargs
        ):
            self._init_trace(kwargs)
        else:
            super().__init__(**kwargs)

    @classmethod
    def _bind_arguments(
        cls, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """Return a keyword-only copy of a Popen call's arguments."""
        bound = dict(kwargs)
        bound.update(zip(cls._POPEN_PARAMETERS, args))
        return bound

    @staticmethod
    def _invokes_supported_shell(kwargs: dict[str, Any]) -> bool:
        """Tell whether the call described by *kwargs* launches a supported shell."""
        executable = kwargs.get("executable")
        if executable is not None and executable in SUPPORTED_SHELLS:
            return True

        args = kwargs.get("args")
        # A single string/path is one program name, not an argv sequence, and
        # an empty or missing argv has no program to inspect.
        if not args or isinstance(args, (str, bytes, os.PathLike)):
            return False
        return args[0] in SUPPORTED_SHELLS

    def _init_trace(self, kwargs: dict[str, Any]) -> None:
        """Open a trace file and start the shell with xtrace writing into it."""
        self._tracefile_path = (
            self.tracefiles_dir_path / f"shelltrace.{filename_suffix(suffix=True)}"
        )
        self._tracefile_path.parent.mkdir(parents=True, exist_ok=True)
        self._tracefile_path.touch()
        self._tracefile_fd = os.open(self._tracefile_path, flags=os.O_RDWR | os.O_CREAT)

        # `env=None` means "inherit the parent's environment". Always work on a
        # copy so the caller's mapping (or os.environ) is never modified.
        base_env = kwargs.get("env")
        env = dict(os.environ if base_env is None else base_env)
        env["BASH_XTRACEFD"] = str(self._tracefile_fd)
        env["PS4"] = "COV:::$BASH_SOURCE:::$LINENO:::"
        kwargs["env"] = env

        call_args = kwargs.get("args", ())
        if isinstance(call_args, (str, bytes, os.PathLike)):
            call_args = [call_args]
        else:
            call_args = list(call_args)
        call_args.insert(1, "-x")
        kwargs["args"] = call_args

        # The child must inherit the trace file descriptor to write to it.
        pass_fds = list(kwargs.get("pass_fds", ()))
        pass_fds.append(self._tracefile_fd)
        kwargs["pass_fds"] = pass_fds

        super().__init__(**kwargs)

    def __del__(self):
        """Close this process's trace file descriptor, then defer to Popen."""
        # getattr: __del__ may run on an instance whose __init__ never got far
        # enough to set the attribute.
        fd = getattr(self, "_tracefile_fd", None)
        try:
            if fd is not None:
                self._tracefile_fd = None
                os.close(fd)
        finally:
            super().__del__()
class ShellPlugin(CoveragePlugin):
    """Coverage plugin that measures shell scripts run through `subprocess`.

    On construction it replaces `subprocess.Popen` with PatchedPopen and
    registers an exit handler that converts the collected trace files into
    coverage data.
    """

    def __init__(self, options: dict[str, str]):
        """Store *options*, create the trace directory and install the patch."""
        self.options = options
        self._data: coverage.CoverageData | None = None
        self.tracefiles_dir_path = Path.cwd() / ".coverage-sh"
        self.tracefiles_dir_path.mkdir(parents=True, exist_ok=True)

        atexit.register(self._convert_traces)
        # TODO: Does this work with multithreading? Isn't there an easier way of finding the base path?
        PatchedPopen.tracefiles_dir_path = self.tracefiles_dir_path
        subprocess.Popen = PatchedPopen

    def _init_data(self) -> None:
        """Create the CoverageData object on first use."""
        if self._data is None:
            self._data = coverage.CoverageData(
                # TODO: make basename configurable
                basename=self.tracefiles_dir_path.parent / ".coverage",
                suffix="sh." + filename_suffix(suffix=True),
                # TODO: set warn, debug and no_disk
            )

    def _convert_traces(self) -> None:
        """Turn this process's trace files into coverage data and remove them."""
        self._init_data()

        for tracefile_path in self.tracefiles_dir_path.glob(
            f"shelltrace.{gethostname()}.{os.getpid()}.*"
        ):
            line_data = self._parse_tracefile(tracefile_path)
            self._write_trace(line_data)

            tracefile_path.unlink(missing_ok=True)

        # Best-effort cleanup: rmdir only succeeds on an empty directory, so
        # trace files of other processes keep it in place. Attempting the
        # removal directly avoids a check-then-remove race.
        try:
            self.tracefiles_dir_path.rmdir()
        except OSError:
            pass

    @staticmethod
    def _parse_tracefile(tracefile_path: Path) -> dict[str, set[int]]:
        """Collect the executed line numbers per script from an xtrace file.

        Trace lines follow the PS4 prompt set by PatchedPopen:
        `COV:::<script>:::<lineno>:::<command>`. Anything else, such as the
        continuation lines of a multi-line command, is skipped.

        Returns a mapping of absolute script path to executed line numbers;
        it is empty when the file does not exist.
        """
        try:
            fd = tracefile_path.open("r", errors="replace")
        except FileNotFoundError:
            return {}

        line_data: dict[str, set[int]] = defaultdict(set)
        with fd:
            for line in fd:
                # maxsplit keeps a ":::" inside the traced command from
                # producing additional fields.
                fields = line.split(":::", 3)
                # bash repeats the first PS4 character per nesting level, so
                # the marker may read "CCOV", "CCCOV", ...
                if len(fields) != 4 or not fields[0].endswith("COV"):
                    continue

                _, path, lineno, _ = fields
                if not path:
                    # No source file to attribute the line to.
                    continue
                try:
                    executed_line = int(lineno)
                except ValueError:
                    continue

                line_data[str(Path(path).absolute())].add(executed_line)

        return line_data

    def _write_trace(self, line_data: dict[str, set[int]]) -> None:
        """Add *line_data* to the coverage data file; no-op when it is empty."""
        if not line_data:
            return

        self._data.add_file_tracers({f: "coverage_sh.ShellPlugin" for f in line_data})
        self._data.add_lines(line_data)
        self._data.write()

    @staticmethod
    def _is_relevant(path: Path) -> bool:
        """Tell whether the MIME type of *path* is one of SUPPORTED_MIME_TYPES."""
        return magic.from_file(path, mime=True) in SUPPORTED_MIME_TYPES

    def file_tracer(self, filename: str) -> FileTracer | None:  # noqa: ARG002
        """Return None: this plugin provides no FileTracer."""
        return None

    def file_reporter(
        self,
        filename: str,
    ) -> ShellFileReporter | str:
        """Return the reporter used for *filename*."""
        return ShellFileReporter(filename)

    def find_executable_files(
        self,
        src_dir: str,
    ) -> Iterable[str]:
        """Yield the shell scripts found below *src_dir*.

        Files located in hidden directories, or hidden themselves, are
        skipped. Only the part of the path below *src_dir* is examined, so a
        *src_dir* that itself contains a dot-directory or `..` is still
        searched.
        """
        root = Path(src_dir)
        for candidate in root.rglob("*"):
            if not candidate.is_file():
                continue
            if any(part.startswith(".") for part in candidate.relative_to(root).parts):
                continue

            if self._is_relevant(candidate):
                yield str(candidate)