Coverage for src/hatch_ci/tools.py: 92%
201 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-07 07:23 +0000
1# see https://pypi.org/project/setuptools-github
2# copy of setuptools_github.tools
3from __future__ import annotations
5import ast
6import json
7import re
8from pathlib import Path
9from typing import Any
11from . import scm
class ToolsError(Exception):
    """Common base class of ValidationError, InvalidVersionError and
    MissingVariableError."""
class ValidationError(ToolsError):
    """Raised when a module level variable cannot be extracted safely
    (assigned a non-constant expression, or assigned more than once)."""
class InvalidVersionError(ToolsError):
    """Raised when the current version cannot be parsed, or does not match
    the version encoded in a beta/release branch name."""
class MissingVariableError(ToolsError):
    """Raised when a requested module level variable is not found.

    get_module_var raises it with ``(message, path, var)`` as its args.
    """
class AbortExecutionError(Exception):
    """Error carrying a short message plus an optional explanation and hint.

    ``explain`` and ``hint`` are stored as given and normalized on access
    (see ``_strip``); ``str()`` joins the message, the indented explanation
    and an indented ``hint:`` section, leaving out the empty parts.
    """

    @staticmethod
    def _strip(txt):
        # None/"" -> ""; otherwise drop one leading newline, dedent the text
        # (indent() with an empty prefix) and drop one trailing newline
        text = txt if txt else ""
        if text.startswith("\n"):
            text = text[1:]
        text = indent(text, pre="")
        if text.endswith("\n"):
            text = text[:-1]
        return text

    def __init__(
        self, message: str, explain: str | None = None, hint: str | None = None
    ):
        self.message = message.strip()
        self._explain, self._hint = explain, hint

    @property
    def explain(self):
        return self._strip(self._explain)

    @property
    def hint(self):
        return self._strip(self._hint)

    def __str__(self):
        parts = [self.message]
        explanation = self.explain
        if explanation:
            # indent() also prefixes the empty first line produced by the
            # leading "\n"; [2:] removes that two-space prefix again
            parts.append(indent("\n" + explanation, pre="  ")[2:])
        tip = self.hint
        if tip:
            parts += ["\nhint:", indent("\n" + tip, pre="  ")[2:]]
        return "".join(parts)
def urmtree(path: Path):
    """Remove the directory tree at *path* on both Windows and POSIX.

    Raises:
        RuntimeError: if *path* still exists after the removal attempt
    """
    from os import name
    from shutil import rmtree
    from stat import S_IWUSR

    if name == "nt":
        # on Windows chmod only toggles the read-only flag: make every entry
        # writable so rmtree is able to delete it
        for entry in path.rglob("*"):
            entry.chmod(S_IWUSR)
    rmtree(path, ignore_errors=True)
    if not path.exists():
        return
    raise RuntimeError(f"cannot remove {path=}")
def indent(txt: str, pre: str = " " * 2) -> str:
    """Dedent *txt*, then prefix every one of its lines with *pre*.

    A single trailing newline is preserved without being followed by *pre*;
    a result made only of whitespace collapses to "".
    """
    from textwrap import dedent

    body = dedent(txt)
    trailer = ""
    if body.endswith("\n"):
        body, trailer = body[:-1], "\n"

    out = pre + ("\n" + pre).join(body.split("\n")) + trailer
    if not out.strip():
        return ""
    return out
def list_of_paths(paths: str | Path | list[str | Path] | None) -> list[Path]:
    """Normalize *paths* (a single path, an iterable of paths, or nothing)
    into a list of Path objects; falsy input gives an empty list."""
    if not paths:
        return []
    items = [paths] if isinstance(paths, (str, Path)) else paths
    return list(map(Path, items))
def lstrip(txt: str, left: str) -> str:
    """Return *txt* without the *left* prefix (unchanged if it lacks it)."""
    if not txt.startswith(left):
        return txt
    return txt[len(left):]
def apply_fixers(txt: str, fixers: dict[str, str] | None = None) -> str:
    """Apply the *fixers* substitutions to *txt*, in mapping order.

    A key starting with "re:" is a regular expression (the prefix is removed)
    replaced via re.sub; any other key is replaced literally.
    """
    if not fixers:
        return txt
    for pattern, replacement in fixers.items():
        if pattern.startswith("re:"):
            txt = re.sub(pattern[3:], replacement, txt)
        else:
            txt = txt.replace(pattern, replacement)
    return txt
def get_module_var(
    path: Path | str, var: str = "__version__", abort=True
) -> str | None:
    """extract from a python module in path the module level <var> variable

    Args:
        path (str,Path): python module file to parse using ast (no code-execution)
        var (str): module level variable name to extract
        abort (bool): raise MissingVariableError if var is not present

    Returns:
        None or str: the variable value if found or None

    Raises:
        ValidationError: if var is assigned a non-constant expression, or is
            assigned more than once at module level
        MissingVariableError: if the var is not found and abort is True

    Notes:
        this uses ast to parse path, so it doesn't load the module; only plain
        top level ``name = <literal>`` assignments are considered
    """

    class V(ast.NodeVisitor):
        def __init__(self, keys):
            self.keys = keys
            self.result = {}

        def visit_Module(self, node):  # noqa: N802
            # we extract the module level variables (direct children only)
            for subnode in ast.iter_child_nodes(node):
                if not isinstance(subnode, ast.Assign):
                    continue
                for target in subnode.targets:
                    # tuple unpacking (a, b = ...), attribute (x.y = ...) and
                    # subscript (x[0] = ...) targets have no ``.id``
                    if not isinstance(target, ast.Name) or target.id not in self.keys:
                        continue
                    # ast.Constant covers every literal (str, numbers, bytes,
                    # bool, None); ast.Str/ast.Num are deprecated aliases
                    if not isinstance(subnode.value, ast.Constant):
                        raise ValidationError(
                            f"cannot extract non Constant variable "
                            f"{target.id} ({type(subnode.value)})"
                        )
                    if target.id in self.result:
                        raise ValidationError(
                            f"found multiple repeated variables {target.id}"
                        )
                    self.result[target.id] = subnode.value.value
            return self.generic_visit(node)

    v = V({var})
    path = Path(path)
    if path.exists():
        # filename gives SyntaxErrors a meaningful location
        tree = ast.parse(path.read_text(), filename=str(path))
        v.visit(tree)
    if var not in v.result and abort:
        raise MissingVariableError(f"cannot find {var} in {path}", path, var)
    return v.result.get(var, None)
def set_module_var(
    path: str | Path, var: str, value: Any, create: bool = True
) -> tuple[Any, str]:
    """replace var in path with value

    Args:
        path (str,Path): python module file to parse
        var (str): module level variable name to extract
        value (None or Any): if not None replace var in version_file
            (non-str values are converted with str())
        create (bool): create path if not present, and append a
            ``var = "value"`` line when var is not found and value is not None

    Returns:
        (str, str) the (<previous-var-value|None>, <the new text>)

    Raises:
        ValidationError: (from get_module_var) if var is not a constant or is
            assigned more than once
        FileNotFoundError: if path does not exist and create is False
    """

    # validate the var (get_module_var raises on non-constant/repeated
    # assignments); the value itself is not needed here
    get_module_var(path, var, abort=False)

    # module level var: a line starting with var = "..." (or '...')
    expr = re.compile(f"^{var}\\s*=\\s*['\\\"](?P<value>[^\\\"']*)['\\\"]")
    fixed = None
    lines = []

    src = Path(path)
    if not src.exists() and create:
        src.parent.mkdir(parents=True, exist_ok=True)
        src.touch()

    for line in src.read_text().split("\n"):
        # only the first matching line is rewritten, the rest is copied as-is
        if fixed is None and (match := expr.search(line)):
            fixed = match.group("value")
            if value is not None:
                x, y = match.span("value")
                line = line[:x] + str(value) + line[y:]
        lines.append(line)
    txt = "\n".join(lines)

    # var not present: append it (never write a literal "None")
    if fixed is None and create and value is not None:
        if txt and txt[-1] != "\n":
            txt += "\n"
        txt += f'{var} = "{value}"'

    src.write_text(txt)
    return fixed, txt
def bump_version(version: str, mode: str) -> str:
    """given a version str will bump it according to mode

    Arguments:
        version: text in the N.M.O form; missing trailing components are
            taken as 0 (eg. "1.2" is handled as "1.2.0"), while for longer
            versions the last three components are the ones bumped
        mode: major, minor or micro (any other value bumps nothing)

    Returns:
        increased text

    >>> bump_version("1.0.3", "micro")
    '1.0.4'
    >>> bump_version("1.0.3", "minor")
    '1.1.0'
    >>> bump_version("1.0.3", "major")
    '2.0.0'
    """
    newver = [int(n) for n in version.split(".")]
    # pad short versions ("1", "1.2") with zeros so the negative indexes
    # below always address a (major, minor, micro) triplet
    newver += [0] * (3 - len(newver))
    if mode == "major":
        newver[-3] += 1
        newver[-2] = 0
        newver[-1] = 0
    elif mode == "minor":
        newver[-2] += 1
        newver[-1] = 0
    elif mode == "micro":
        newver[-1] += 1
    return ".".join(str(v) for v in newver)
def get_data(
    version_file: str | Path, github_dump: str | None = None, abort: bool = True
) -> dict[str, str | None]:
    """extracts version information from github_dump (or the local git repo)

    The version_file is only read, never modified.

    Args:
        version_file (str, Path): path to a file with a __version__ variable
        github_dump (str): the os.getenv("GITHUB_DUMP") value, a JSON string
            (an already decoded mapping is accepted as well)
        abort (bool): raise if neither a git repo nor github_dump is available

    Returns:
        dict[str,str|None]: a dict with the current config (keys: version,
        current, branch, hash, build, runid, workflow)

    Raises:
        MissingVariableError: if version_file has no __version__ (regardless
            of abort)
        scm.InvalidGitRepoError: if there is no repo and no github_dump and
            abort is True
        InvalidVersionError: if building from a beta/release branch whose
            version doesn't match (or can't be compared with) __version__
    """
    # parse the version file once: "version" may be rewritten below (beta
    # builds) while "current" always keeps the value found in the file
    current = get_module_var(version_file, "__version__")
    result = {
        "version": current,
        "current": current,
        "branch": None,
        "hash": None,
        "build": None,
        "runid": None,
        "workflow": None,
    }

    path = Path(version_file)
    repo = scm.lookup(path)

    if not (repo or github_dump):
        if abort:
            raise scm.InvalidGitRepoError(f"cannot find a valid git repo for {path}")
        return result

    if not github_dump and repo:
        # local build: ref/sha come from the checked-out HEAD
        gdata = {
            "ref": repo.head.name,
            "sha": repo.head.target.hex[:7],
            "run_number": 0,
            "run_id": 0,
        }
        dirty = repo.dirty()
    else:
        # presumably a JSON dump of the GitHub Actions "github" context
        # (needs ref, sha, run_number, run_id keys) - TODO confirm
        gdata = json.loads(github_dump) if isinstance(github_dump, str) else github_dump
        dirty = False

    # a ref ending in "/beta/N.M.O" or "/release/N.M.O"
    expr = re.compile(r"/(?P<what>beta|release)/(?P<version>\d+([.]\d+)*)$")
    # a version "N.M.O" optionally followed by a "bN" suffix
    expr1 = re.compile(r"(?P<version>\d+([.]\d+)*)(?P<num>b\d+)?$")

    result["branch"] = lstrip(gdata["ref"], "refs/heads/")
    # a trailing "*" marks a dirty local checkout
    result["hash"] = gdata["sha"] + ("*" if dirty else "")
    result["build"] = gdata["run_number"]
    result["runid"] = gdata["run_id"]
    result["workflow"] = result["branch"]

    if match := expr.search(gdata["ref"]):
        # setuptools double calls the update_version,
        # this fixes the issue
        match1 = expr1.search(current or "")
        if not match1:
            raise InvalidVersionError(f"cannot parse current version '{current}'")
        if match1.group("version") != match.group("version"):
            raise InvalidVersionError(
                f"building package for {current} from '{gdata['ref']}' "
                f"branch ({match.groupdict()} mismatch {match1.groupdict()})"
            )
        if match.group("what") == "beta":
            result["version"] = f"{match1.group('version')}b{gdata['run_number']}"
            result["workflow"] = "beta"
        else:
            result["workflow"] = "tags"
    return result
def update_version(
    version_file: str | Path, github_dump: str | None = None, abort: bool = True
) -> str | None:
    """Compute the version data and write it back into version_file.

    Args:
        version_file (str, Path): path to a file with a __version__ variable
        github_dump (str): the os.getenv("GITHUB_DUMP") value
        abort (bool): forwarded to get_data

    Returns:
        str: the new version for the package (the value written to
        __version__; __hash__ is refreshed as well)
    """
    info = get_data(version_file, github_dump, abort)
    # __version__ first, then __hash__
    for key in ("version", "hash"):
        set_module_var(version_file, f"__{key}__", info[key])
    return info["version"]
def process(
    version_file: str | Path,
    github_dump: str | None = None,
    paths: str | Path | list[str | Path] | None = None,
    fixers: dict[str, str] | None = None,
    abort: bool = True,
) -> dict[str, str | None]:
    """get version from github_dump and updates version_file/paths

    Args:
        version_file (str, Path): path to a file with __version__ variable;
            its __version__ and __hash__ are rewritten in place
        github_dump (str): the os.getenv("GITHUB_DUMP") value
        paths (str, Path): path(s) to jinja2 templates, each rendered in place
            with the version data available as ``ctx``
        fixers (dict[str,str]): substitutions applied (see apply_fixers) to
            each template text before rendering
        abort (bool): forwarded to get_data

    Returns:
        dict[str,str|None]: the version data computed by get_data

    Example:
        {'branch': 'beta/0.3.1',
         'build': 0,
         'current': '0.3.1',
         'hash': 'c9e484a*',
         'version': '0.3.1b0',
         'runid': 0,
         'workflow': 'beta'
        }
    """
    from argparse import Namespace
    from functools import partial
    from urllib.parse import quote

    from jinja2 import Environment

    class Context(Namespace):
        # attribute access in templates (ctx.version) plus ctx.items(),
        # yielding the public (name, value) pairs
        def items(self):
            for name, value in self.__dict__.items():
                if name.startswith("_"):
                    continue
                yield (name, value)

    data = get_data(version_file, github_dump, abort)
    set_module_var(version_file, "__version__", data["version"])
    set_module_var(version_file, "__hash__", data["hash"])

    # NOTE: autoescape=True HTML-escapes every substituted value
    env = Environment(autoescape=True)
    # {{ value | urlquote }} percent-encodes everything, "/" included
    env.filters["urlquote"] = partial(quote, safe="")
    ctx = Context(**data)  # same context for every template
    for path in list_of_paths(paths):
        txt = apply_fixers(path.read_text(), fixers)
        tmpl = env.from_string(txt)
        path.write_text(tmpl.render(ctx=ctx))
    return data