Coverage for src/hatch_ci/tools.py: 91%

198 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-03 19:49 +0000

1# see https://pypi.org/project/setuptools-github 

2# copy of setuptools_github.tools 

3from __future__ import annotations 

4 

5import ast 

6import json 

7import re 

8from pathlib import Path 

9from typing import Any 

10 

11from . import scm 

12 

13 

class ToolsError(Exception):
    """Common base class for the validation, version and missing-variable errors."""

16 

17 

class ValidationError(ToolsError):
    """A module variable has a non-constant value or is assigned more than once."""

20 

21 

class InvalidVersionError(ToolsError):
    """The current version cannot be parsed or does not match the branch version."""

24 

25 

class MissingVariableError(ToolsError):
    """A requested module-level variable is not present in the parsed file."""

28 

29 

class AbortExecutionError(Exception):
    """error carrying a short message plus an optional explanation and hint

    Args:
        message (str): one-line summary, stored stripped in ``.message``
        explain (str, None): optional longer (possibly multi-line) explanation
        hint (str, None): optional suggestion shown under a ``hint:`` header

    Notes:
        ``str(error)`` renders the message, then the explanation indented
        by two spaces, then ``hint:`` followed by the indented hint.
    """

    @staticmethod
    def _strip(txt):
        "dedent txt and drop one leading and one trailing newline"
        txt = txt or ""
        txt = txt[1:] if txt.startswith("\n") else txt
        # indent() with an empty prefix only dedents the text
        txt = indent(txt, pre="")
        return txt[:-1] if txt.endswith("\n") else txt

    def __init__(
        self, message: str, explain: str | None = None, hint: str | None = None
    ):
        self.message = message.strip()
        # fill Exception.args so repr(), pickle and copy can rebuild the
        # instance; without this call args stays empty
        super().__init__(self.message)
        self._explain = explain
        self._hint = hint

    @property
    def explain(self):
        "the normalized explanation text ('' when none was given)"
        return self._strip(self._explain)

    @property
    def hint(self):
        "the normalized hint text ('' when none was given)"
        return self._strip(self._hint)

    def __str__(self):
        result = [self.message]
        if self.explain:
            # the leading "\n" makes indent() prefix every text line; [2:]
            # removes the prefix emitted in front of that newline
            result.append(indent("\n" + self.explain, pre=" " * 2)[2:])
        if self.hint:
            result.extend(["\nhint:", indent("\n" + self.hint, pre=" " * 2)[2:]])
        return "".join(result)

60 

61 

def urmtree(path: Path):
    "remove the directory tree at path on both windows and *nix"
    import os
    import shutil
    import stat

    on_windows = os.name == "nt"
    if on_windows:
        # presumably needed so read-only entries do not block the removal
        for entry in path.rglob("*"):
            entry.chmod(stat.S_IWUSR)

    # errors are ignored here; the check below reports a failed removal
    shutil.rmtree(path, ignore_errors=True)
    if not path.exists():
        return
    raise RuntimeError(f"cannot remove {path=}")

74 

75 

def indent(txt: str, pre: str = " " * 2) -> str:
    "dedent txt, then prefix each of its lines with pre"

    import textwrap

    body = textwrap.dedent(txt)

    # a single trailing newline is set aside so it does not get a prefix
    tail = "\n" if body.endswith("\n") else ""
    if tail:
        body = body[:-1]

    out = pre + ("\n" + pre).join(body.split("\n")) + tail
    # whitespace-only results collapse to the empty string
    if not out.strip():
        return ""
    return out

90 

91 

def list_of_paths(paths: str | Path | list[str | Path] | None) -> list[Path]:
    "normalize a single path, a list of paths or a falsy value into a list of Path"
    if not paths:
        return []
    # a lone str/Path is treated as a one-element list
    if isinstance(paths, (str, Path)):
        paths = [paths]
    return list(map(Path, paths))

96 

97 

def lstrip(txt: str, left: str) -> str:
    "remove one leading occurrence of the prefix left from txt, if present"
    if not txt.startswith(left):
        return txt
    prefix_len = len(left)
    return txt[prefix_len:]

100 

101 

def get_module_var(
    path: Path | str, var: str = "__version__", abort=True
) -> str | None:
    """extract from a python module in path the module level <var> variable

    Args:
        path (str,Path): python module file to parse using ast (no code-execution)
        var (str): module level variable name to extract
        abort (bool): raise MissingVariableError if var is not present

    Returns:
        None or str: the variable value if found or None

    Raises:
        MissingVariableError: if the var is not found and abort is True
        ValidationError: if var is assigned a non constant value or is
            assigned more than once at module level

    Notes:
        this uses ast to parse path, so it doesn't load the module;
        a non-existing path is treated as a module without the variable
    """
    found = {}
    path = Path(path)
    if path.exists():
        tree = ast.parse(path.read_text())
        # only the module level statements are inspected
        for node in tree.body:
            if not isinstance(node, ast.Assign):
                continue
            for target in node.targets:
                # tuple, attribute and subscript targets carry no `.id`
                if not isinstance(target, ast.Name) or target.id != var:
                    continue
                # ast.Constant covers str, numbers and the other literals;
                # the ast.Str/ast.Num aliases are gone in python 3.12
                if not isinstance(node.value, ast.Constant):
                    raise ValidationError(
                        f"cannot extract non Constant variable "
                        f"{target.id} ({type(node.value)})"
                    )
                if target.id in found:
                    raise ValidationError(
                        f"found multiple repeated variables {target.id}"
                    )
                found[target.id] = node.value.value
    if var not in found and abort:
        raise MissingVariableError(f"cannot find {var} in {path}", path, var)
    return found.get(var, None)

163 

164 

def set_module_var(
    path: str | Path, var: str, value: Any, create: bool = True
) -> tuple[Any, str]:
    """replace var in path with value

    Args:
        path (str,Path): python module file to parse
        var (str): module level variable name to replace
        value (None or Any): if not None replace the quoted value of var;
            non-str values are converted with str()
        create (bool): create path if not present, and append the
            assignment when var is not found

    Returns:
        (str, str) the (<previous-var-value|None>, <the new text>)

    Notes:
        only the first line starting with a quoted assignment to var is
        rewritten; the file is always written back
    """

    # validate the var
    get_module_var(path, var, abort=False)

    # module level var; var is escaped so it is matched literally
    expr = re.compile(rf"^{re.escape(var)}\s*=\s*['\"](?P<value>[^\"']*)['\"]")

    src = Path(path)
    if not src.exists() and create:
        src.parent.mkdir(parents=True, exist_ok=True)
        src.touch()

    fixed = None
    lines = []
    for line in src.read_text().split("\n"):
        # once the first assignment is handled the rest is copied verbatim
        if fixed is None:
            match = expr.search(line)
            if match:
                fixed = match.group("value")
                if value is not None:
                    start, end = match.span("value")
                    # str() keeps this path consistent with the append
                    # path below, which formats value through an f-string
                    line = line[:start] + str(value) + line[end:]
        lines.append(line)

    txt = "\n".join(lines)
    if (fixed is None) and create:
        if txt and txt[-1] != "\n":
            txt += "\n"
        txt += f'{var} = "{value}"'

    src.write_text(txt)
    return fixed, txt

214 

215 

def bump_version(version: str, mode: str) -> str:
    """increase one component of a dotted numeric version

    Arguments:
        version: text in the N.M.O form
        mode: major, minor or micro (any other value leaves version as is)

    Returns:
        the version text with the selected component increased and the
        components to its right reset to 0

    >>> bump_version("1.0.3", "micro")
    '1.0.4'
    >>> bump_version("1.0.3", "minor")
    '1.1.0'
    >>> bump_version("1.0.3", "major")
    '2.0.0'
    """
    # position (counted from the right) of the component each mode bumps
    positions = {"major": -3, "minor": -2, "micro": -1}

    parts = list(map(int, version.split(".")))
    pos = positions.get(mode)
    if pos is not None:
        parts[pos] += 1
        # reset every component to the right of the bumped one
        for idx in range(pos + 1, 0):
            parts[idx] = 0
    return ".".join(map(str, parts))

242 

243 

def get_data(
    initfile: str | Path, github_dump: str | None = None, abort: bool = True
) -> dict[str, str | None]:
    """collects version/build information for initfile (initfile is only read)

    Args:
        initfile (str, Path): path to the __init__.py file with a __version__ variable
        github_dump (str): the os.getenv("GITHUB_DUMP") value; when not
            given the data is taken from the git repo found for initfile
        abort (bool): raise if neither a git repo nor github_dump is available

    Returns:
        dict[str,str|None]: a dict with the current config, with keys
            version, current, branch, hash, build, runid and workflow

    Raises:
        scm.InvalidGitRepoError: no repo found, no github_dump and abort is True
        InvalidVersionError: on a beta/release ref, the __version__ cannot
            be parsed or does not match the version in the ref
    """
    # parse initfile once: the value seeds both "version" and "current"
    current = get_module_var(initfile, "__version__")
    result = {
        "version": current,
        "current": current,
        "branch": None,
        "hash": None,
        "build": None,
        "runid": None,
        "workflow": None,
    }

    path = Path(initfile)
    repo = scm.lookup(path)

    if not (repo or github_dump):
        if abort:
            raise scm.InvalidGitRepoError(f"cannot find a valid git repo for {path}")
        return result

    if not github_dump and repo:
        # no github data: build the equivalent record from the local repo
        gdata = {
            "ref": repo.head.name,
            "sha": repo.head.target.hex[:7],
            "run_number": 0,
            "run_id": 0,
        }
        dirty = repo.dirty()
    else:
        gdata = json.loads(github_dump) if isinstance(github_dump, str) else github_dump
        dirty = False

    # expr matches refs ending in /beta/<version> or /release/<version>
    expr = re.compile(r"/(?P<what>beta|release)/(?P<version>\d+([.]\d+)*)$")
    # expr1 matches a version with an optional trailing b<N> marker
    expr1 = re.compile(r"(?P<version>\d+([.]\d+)*)(?P<num>b\d+)?$")

    result["branch"] = lstrip(gdata["ref"], "refs/heads/")
    # a trailing "*" flags a dirty working tree
    result["hash"] = gdata["sha"] + ("*" if dirty else "")
    result["build"] = gdata["run_number"]
    result["runid"] = gdata["run_id"]
    result["workflow"] = result["branch"]

    if match := expr.search(gdata["ref"]):
        # setuptools double calls the update_version,
        # this fixes the issue
        match1 = expr1.search(current or "")
        if not match1:
            raise InvalidVersionError(f"cannot parse current version '{current}'")
        if match1.group("version") != match.group("version"):
            raise InvalidVersionError(
                f"building package for {current} from '{gdata['ref']}' "
                f"branch ({match.groupdict()} mismatch {match1.groupdict()})"
            )
        if match.group("what") == "beta":
            result["version"] = f"{match1.group('version')}b{gdata['run_number']}"
            result["workflow"] = "beta"
        else:
            result["workflow"] = "tags"
    return result

313 

314 

def update_version(
    initfile: str | Path, github_dump: str | None = None, abort: bool = True
) -> str | None:
    """computes the version data and writes __version__/__hash__ into initfile

    Args:
        initfile (str, Path): path to the __init__.py file with a __version__ variable
        github_dump (str): the os.getenv("GITHUB_DUMP") value
        abort (bool): forwarded to get_data

    Returns:
        str: the new version for the package
    """

    info = get_data(initfile, github_dump, abort)
    # (module variable, key in the collected data) pairs, written in order
    for variable, key in (("__version__", "version"), ("__hash__", "hash")):
        set_module_var(initfile, variable, info[key])
    return info["version"]

332 

333 

def process(
    initfile: str | Path,
    github_dump: str | None = None,
    paths: str | Path | list[str | Path] | None = None,
    fixers: dict[str, str] | None = None,
    abort: bool = True,
) -> dict[str, str | None]:
    """get version from github_dump and updates initfile/paths

    Args:
        initfile (str, Path): path to the __init__.py file with a __version__ variable
        github_dump (str): the os.getenv("GITHUB_DUMP") value
        paths (str, Path): path(s) to files jinja2 processeable, rendered
            in place with the collected data available as ``ctx``
        fixers (dict): {old: new} text replacements, each applied once to
            a file's text before it is rendered
        abort (bool): forwarded to get_data

    Returns:
        dict: the data collected by get_data

    Example:
        {'branch': 'beta/0.3.1',
         'build': 0,
         'current': '0.3.1',
         'hash': 'c9e484a*',
         'version': '0.3.1b0',
         'runid': 0
        }
    """
    from argparse import Namespace
    from functools import partial
    from urllib.parse import quote

    from jinja2 import Environment

    class Context(Namespace):
        def items(self):
            # yield only the attributes not starting with an underscore
            for key, val in self.__dict__.items():
                if not key.startswith("_"):
                    yield (key, val)

    data = get_data(initfile, github_dump, abort)
    set_module_var(initfile, "__version__", data["version"])
    set_module_var(initfile, "__hash__", data["hash"])

    env = Environment(autoescape=True)
    env.filters["urlquote"] = partial(quote, safe="")

    replacements = fixers or {}
    for target in list_of_paths(paths):
        source = target.read_text()
        for old, new in replacements.items():
            source = source.replace(old, new, 1)
        rendered = env.from_string(source).render(ctx=Context(**data))
        target.write_text(rendered)
    return data