Coverage for src/hatch_ci/tools.py: 92%

201 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-06 19:18 +0000

1# see https://pypi.org/project/setuptools-github 

2# copy of setuptools_github.tools 

3from __future__ import annotations 

4 

5import ast 

6import json 

7import re 

8from pathlib import Path 

9from typing import Any 

10 

11from . import scm 

12 

13 

class ToolsError(Exception):
    """Root of the error hierarchy used by the helpers in this module."""

16 

17 

class ValidationError(ToolsError):
    """Raised when parsed input is rejected (e.g. a non-constant or repeated variable)."""

20 

21 

class InvalidVersionError(ToolsError):
    """Raised when a version cannot be parsed or disagrees with the branch name."""

24 

25 

class MissingVariableError(ToolsError):
    """Raised when a requested module-level variable cannot be found."""

28 

29 

class AbortExecutionError(Exception):
    """Fatal error carrying a message plus an optional explanation and hint.

    ``str()`` renders the message, then the indented explanation, then a
    ``hint:`` section; empty parts are skipped.
    """

    @staticmethod
    def _strip(txt):
        "normalize text: None -> '', drop one leading/trailing newline, dedent"
        txt = txt or ""
        txt = txt[1:] if txt.startswith("\n") else txt
        txt = indent(txt, pre="")
        return txt[:-1] if txt.endswith("\n") else txt

    def __init__(
        self, message: str, explain: str | None = None, hint: str | None = None
    ):
        # forward all constructor arguments so ``args`` is always populated:
        # pickle/copy rebuild exceptions as ``cls(*args)``, which fails when the
        # exception was created with ``message=`` as a keyword and args is empty
        super().__init__(message, explain, hint)
        self.message = message.strip()
        self._explain = explain
        self._hint = hint

    @property
    def explain(self):
        "the explanation text, normalized by _strip ('' when missing)"
        return self._strip(self._explain)

    @property
    def hint(self):
        "the hint text, normalized by _strip ('' when missing)"
        return self._strip(self._hint)

    def __str__(self):
        result = [self.message]
        if self.explain:
            # indent every explanation line by two spaces; [2:] removes the
            # indentation added in front of the leading newline
            result.append(indent("\n" + self.explain, pre=" " * 2)[2:])
        if self.hint:
            result.extend(["\nhint:", indent("\n" + self.hint, pre=" " * 2)[2:]])
        return "".join(result)

60 

61 

def urmtree(path: Path):
    "universal (win|*nix) rmtree"
    from os import name as os_name
    from shutil import rmtree
    from stat import S_IWUSR

    # on windows, give every entry owner-write permission before removing
    # (presumably to clear read-only flags that would block deletion)
    if os_name == "nt":
        for entry in path.rglob("*"):
            entry.chmod(S_IWUSR)

    rmtree(path, ignore_errors=True)
    # errors were ignored above: verify the tree is really gone
    if not path.exists():
        return
    raise RuntimeError(f"cannot remove {path=}")

74 

75 

def indent(txt: str, pre: str = " " * 2) -> str:
    """dedent txt, then prefix every line with pre

    A single trailing newline is kept (and not prefixed); a result made only
    of whitespace collapses to the empty string.
    """
    from textwrap import dedent

    body = dedent(txt)
    trailer = ""
    if body.endswith("\n"):
        body, trailer = body[:-1], "\n"

    indented = pre + ("\n" + pre).join(body.split("\n")) + trailer
    return indented if indented.strip() else ""

90 

91 

def list_of_paths(paths: str | Path | list[str | Path] | None) -> list[Path]:
    "normalize nothing, a single path or a list of paths into a list of Path"
    if not paths:
        return []
    items = [paths] if isinstance(paths, (str, Path)) else paths
    return list(map(Path, items))

96 

97 

def lstrip(txt: str, left: str) -> str:
    "remove the exact prefix left from txt (once), if present"
    if not txt.startswith(left):
        return txt
    return txt[len(left):]

100 

101 

def apply_fixers(txt: str, fixers: dict[str, str] | None = None) -> str:
    """apply the fixers substitutions to txt, in the dict iteration order

    Keys prefixed with "re:" are regular expressions (the rest of the key is
    passed to re.sub); any other key is replaced literally.
    """
    prefix = "re:"
    for pattern, replacement in (fixers or {}).items():
        if pattern.startswith(prefix):
            txt = re.sub(pattern[len(prefix):], replacement, txt)
        else:
            txt = txt.replace(pattern, replacement)
    return txt

110 

111 

def get_module_var(
    path: Path | str, var: str = "__version__", abort=True
) -> str | None:
    """extract from a python module in path the module level <var> variable

    Args:
        path (str,Path): python module file to parse using ast (no code-execution)
        var (str): module level variable name to extract
        abort (bool): raise MissingVariable if var is not present

    Returns:
        None or str: the variable value if found or None (the value is the
        python constant itself, so it can also be an int/float/...)

    Raises:
        MissingVariable: if the var is not found and abort is True
        ValidationError: if var is assigned a non constant value, or is
            assigned more than once at module level

    Notes:
        this uses ast to parse path, so it doesn't load the module; only
        plain ``name = <constant>`` assignments are considered (tuple
        unpacking, attribute and subscript targets are ignored)
    """

    class V(ast.NodeVisitor):
        def __init__(self, keys):
            self.keys = keys
            self.result = {}

        def visit_Module(self, node):  # noqa: N802
            # we extract the module level variables
            for subnode in ast.iter_child_nodes(node):
                if not isinstance(subnode, ast.Assign):
                    continue
                for target in subnode.targets:
                    # only simple names have an .id: skip `a, b = ...`,
                    # `obj.attr = ...` and `d[k] = ...` targets
                    if not isinstance(target, ast.Name):
                        continue
                    if target.id not in self.keys:
                        continue
                    # since python 3.8 every literal parses to ast.Constant
                    # (ast.Num/ast.Str are deprecated aliases of it)
                    if not isinstance(subnode.value, ast.Constant):
                        raise ValidationError(
                            f"cannot extract non Constant variable "
                            f"{target.id} ({type(subnode.value)})"
                        )
                    if target.id in self.result:
                        raise ValidationError(
                            f"found multiple repeated variables {target.id}"
                        )
                    self.result[target.id] = subnode.value.value
            return self.generic_visit(node)

    v = V({var})
    path = Path(path)
    if path.exists():
        # parse the raw bytes so the interpreter's source decoding applies
        # (utf-8 default, BOM, PEP 263 coding cookies) instead of the locale codec
        tree = ast.parse(path.read_bytes())
        v.visit(tree)
    if var not in v.result and abort:
        raise MissingVariableError(f"cannot find {var} in {path}", path, var)
    return v.result.get(var, None)

171 

172 

def set_module_var(
    path: str | Path, var: str, value: Any, create: bool = True
) -> tuple[Any, str]:
    """replace var in path with value

    Only the first line matching ``<var> = "<text>"`` (single or double quotes)
    is rewritten. If no such line exists and create is True, a new
    ``<var> = "<value>"`` line is appended at the end of the file.

    Args:
        path (str,Path): python module file to parse
        var (str): module level variable name to extract
        value (None or Any): if not None replace var in version_file (non str
            values are converted with str()); None leaves an existing
            assignment untouched and appends nothing
        create (bool): create path if not present

    Returns:
        (str, str) the (<previous-var-value|None>, <the new text>)

    Raises:
        ValidationError: (from get_module_var) if var is assigned a non
            constant value or more than once
    """

    # validate the var (raises on non-constant / repeated assignments)
    get_module_var(path, var, abort=False)

    # module level var; re.escape so var is always matched literally
    expr = re.compile(
        f"^{re.escape(var)}\\s*=\\s*['\\\"](?P<value>[^\\\"']*)['\\\"]"
    )
    fixed = None
    lines = []

    src = Path(path)
    if not src.exists() and create:
        src.parent.mkdir(parents=True, exist_ok=True)
        src.touch()

    for line in src.read_text().split("\n"):
        # only the first matching assignment is considered
        if fixed is None and (match := expr.search(line)):
            fixed = match.group("value")
            if value is not None:
                x, y = match.span("value")
                line = line[:x] + str(value) + line[y:]
        lines.append(line)
    txt = "\n".join(lines)

    # append a new assignment only when there is a real value to store:
    # writing the literal text "None" would corrupt the variable
    if (fixed is None) and create and value is not None:
        if txt and txt[-1] != "\n":
            txt += "\n"
        txt += f'{var} = "{value}"'

    src.write_text(txt)
    return fixed, txt

222 

223 

def bump_version(version: str, mode: str) -> str:
    """given a version str will bump it according to mode

    Arguments:
        version: text in the N.M.O form (the last three components are used)
        mode: major, minor or micro; any other value returns the version
            unchanged (apart from int normalization, eg. "1.02.3" -> "1.2.3")

    Returns:
        increased text

    Raises:
        ValueError: if a version component is not an integer
        IndexError: if mode needs more components than version has
            (eg. "major" on "1.0")

    >>> bump_version("1.0.3", "micro")
    '1.0.4'
    >>> bump_version("1.0.3", "minor")
    '1.1.0'
    >>> bump_version("1.0.3", "major")
    '2.0.0'
    """
    newver = [int(n) for n in version.split(".")]
    if mode == "major":
        newver[-3] += 1
        newver[-2] = 0
        newver[-1] = 0
    elif mode == "minor":
        newver[-2] += 1
        newver[-1] = 0
    elif mode == "micro":
        newver[-1] += 1
    return ".".join(str(v) for v in newver)

250 

251 

def get_data(
    version_file: str | Path, github_dump: str | None = None, abort: bool = True
) -> dict[str, str | None]:
    """collect the build information for the package (no file is modified)

    The current ``__version__`` is read from version_file. The ref, commit hash
    and run numbers come from github_dump when given, otherwise from the git
    repository found for version_file. On refs ending in ``/beta/<version>``
    the version gets a ``b<run_number>`` suffix (workflow "beta"); on
    ``/release/<version>`` refs it is kept as is (workflow "tags"); on any
    other ref the workflow is the branch name.

    Args:
        version_file (str, Path): path to a file with a __version__ variable
        github_dump (str): the os.getenv("GITHUB_DUMP") value (a json text,
            or an already decoded mapping)
        abort (bool): raise if there is neither a git repo nor a github_dump;
            when False the partially filled result is returned instead

    Returns:
        dict[str,str|None]: a dict with the current config (keys: version,
        current, branch, hash, build, runid, workflow)

    Raises:
        MissingVariableError: (from get_module_var) no __version__ in version_file
        scm.InvalidGitRepoError: no repo and no github_dump, and abort is True
        InvalidVersionError: the beta/release ref doesn't match the current version
    """
    # parse version_file once: "version" starts as the current value and is
    # suffixed below for beta builds
    current = get_module_var(version_file, "__version__")
    result = {
        "version": current,
        "current": current,
        "branch": None,
        "hash": None,
        "build": None,
        "runid": None,
        "workflow": None,
    }

    path = Path(version_file)
    repo = scm.lookup(path)

    if not (repo or github_dump):
        if abort:
            raise scm.InvalidGitRepoError(f"cannot find a valid git repo for {path}")
        return result

    if not github_dump and repo:
        # no github_dump: take ref/sha from the git checkout, run numbers are 0
        gdata = {
            "ref": repo.head.name,
            "sha": repo.head.target.hex[:7],
            "run_number": 0,
            "run_id": 0,
        }
        dirty = repo.dirty()
    else:
        gdata = json.loads(github_dump) if isinstance(github_dump, str) else github_dump
        dirty = False

    # refs ending in /beta/<N.M.O> or /release/<N.M.O>
    expr = re.compile(r"/(?P<what>beta|release)/(?P<version>\d+([.]\d+)*)$")
    # the current version, possibly already carrying a b<num> suffix
    expr1 = re.compile(r"(?P<version>\d+([.]\d+)*)(?P<num>b\d+)?$")

    result["branch"] = lstrip(gdata["ref"], "refs/heads/")
    # a trailing "*" flags a dirty checkout
    result["hash"] = gdata["sha"] + ("*" if dirty else "")
    result["build"] = gdata["run_number"]
    result["runid"] = gdata["run_id"]
    result["workflow"] = result["branch"]

    if match := expr.search(gdata["ref"]):
        # setuptools double calls the update_version,
        # this fixes the issue
        match1 = expr1.search(current or "")
        if not match1:
            raise InvalidVersionError(f"cannot parse current version '{current}'")
        if match1.group("version") != match.group("version"):
            raise InvalidVersionError(
                f"building package for {current} from '{gdata['ref']}' "
                f"branch ({match.groupdict()} mismatch {match1.groupdict()})"
            )
        if match.group("what") == "beta":
            result["version"] = f"{match1.group('version')}b{gdata['run_number']}"
            result["workflow"] = "beta"
        else:
            result["workflow"] = "tags"
    return result

321 

322 

def update_version(
    version_file: str | Path, github_dump: str | None = None, abort: bool = True
) -> str | None:
    """compute the build data and write __version__/__hash__ into version_file

    Args:
        version_file (str, Path): path to a file with a __version__ variable
        github_dump (str): the os.getenv("GITHUB_DUMP") value
        abort (bool): forwarded to get_data

    Returns:
        str: the new version for the package
    """
    info = get_data(version_file, github_dump, abort)
    for module_var, key in (("__version__", "version"), ("__hash__", "hash")):
        set_module_var(version_file, module_var, info[key])
    return info["version"]

340 

341 

def process(
    version_file: str | Path,
    github_dump: str | None = None,
    paths: str | Path | list[str | Path] | None = None,
    fixers: dict[str, str] | None = None,
    abort: bool = True,
) -> dict[str, str | None]:
    """get version from github_dump and updates version_file/paths

    __version__ and __hash__ are written into version_file. Then every file in
    paths is rewritten in place: fixers are applied first (see apply_fixers),
    and the text is rendered as a jinja2 template with the build data exposed
    as ``ctx`` (eg. ``{{ ctx.version }}``, ``{{ ctx.branch | urlquote }}``).

    Args:
        version_file (str, Path): path to a file with __version__ variable
        github_dump (str): the os.getenv("GITHUB_DUMP") value
        paths (str, Path): path(s) to files jinja2 processeable
        fixers (dict): text substitutions applied before rendering
        abort (bool): forwarded to get_data

    Returns:
        dict[str, str | None]: the build data from get_data, eg.

            {'branch': 'beta/0.3.1',
             'build': 0,
             'current': '0.3.1',
             'hash': 'c9e484a*',
             'version': '0.3.1b0',
             'runid': 0,
             'workflow': 'beta'}
    """
    from argparse import Namespace
    from functools import partial
    from urllib.parse import quote

    from jinja2 import Environment

    class Context(Namespace):
        "template context: attribute access plus a dict-like items()"

        def items(self):
            # underscore-prefixed attributes are not exposed to templates
            for name, value in self.__dict__.items():
                if name.startswith("_"):
                    continue
                yield (name, value)

    data = get_data(version_file, github_dump, abort)
    set_module_var(version_file, "__version__", data["version"])
    set_module_var(version_file, "__hash__", data["hash"])

    # keep_trailing_newline: jinja2 drops a template's final newline by
    # default, which would strip it from every processed file
    env = Environment(autoescape=True, keep_trailing_newline=True)
    env.filters["urlquote"] = partial(quote, safe="")
    ctx = Context(**data)
    for path in list_of_paths(paths):
        txt = apply_fixers(path.read_text(), fixers)
        tmpl = env.from_string(txt)
        path.write_text(tmpl.render(ctx=ctx))
    return data