Coverage for src/pydal2sql/cli_support.py: 97%

154 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-31 16:22 +0200

1""" 

2CLI-Agnostic support. 

3""" 

4import contextlib 

5import io 

6import os 

7import select 

8import string 

9import sys 

10import textwrap 

11import typing 

12from pathlib import Path 

13from typing import Optional 

14 

15import rich 

16from black.files import find_project_root 

17from git.objects.blob import Blob 

18from git.objects.commit import Commit 

19from git.repo import Repo 

20 

21from .helpers import flatten 

22from .magic import ( 

23 find_defined_variables, 

24 find_local_imports, 

25 find_missing_variables, 

26 generate_magic_code, 

27 remove_import, 

28 remove_local_imports, 

29 remove_specific_variables, find_function_to_call, add_function_call, 

30) 

31 

32 

def has_stdin_data() -> bool:  # pragma: no cover
    """
    Tell whether data is already waiting on stdin (input piped in with | or redirected with <).

    A zero-timeout `select` is used, so the check never blocks.
    NOTE(review): per the `select` module docs, Windows only supports sockets here — confirm this is not
    reached on Windows.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    ready_to_read, _ready_to_write, _in_error = select.select([sys.stdin], [], [], 0.0)
    return any(ready_to_read)

50 

51 

52AnyCallable = typing.Callable[..., typing.Any] 

53 

54 

def print_if_interactive(*args: typing.Any, pretty: bool = True, **kwargs: typing.Any) -> None:  # pragma: no cover
    """
    Print to stderr, but only when nothing is being piped into stdin.

    Args:
        *args: values to print.
        pretty: use `rich.print` (which renders markup such as `[blue]`) instead of the builtin `print`.
        **kwargs: forwarded to the print function; `file` is always overridden with `sys.stderr`.
    """
    if has_stdin_data():
        return  # piped/redirected input: stay quiet

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    printer(*args, **{**kwargs, "file": sys.stderr})

64 

65 

def find_git_root(at: Optional[str] = None) -> Optional[Path]:
    """
    Return the root folder of the git repository containing `at` (default: the current working directory).

    Uses black's `find_project_root`, which stops at the nearest project marker. When the reason it reports
    is anything other than a `.git` directory (presumably another marker such as a `pyproject.toml` — see
    black's docs), None is returned.
    """
    folder, reason = find_project_root((at or os.getcwd(),))
    if reason != ".git directory":
        return None
    return folder

71 

72 

def find_git_repo(repo: Optional[Repo] = None, at: Optional[str] = None) -> Repo:
    """
    Return `repo` when given, otherwise open the git repository containing `at` (default: the cwd).
    """
    if repo:
        return repo

    root = find_git_root(at)
    if root is None:
        # The nearest project marker was not a `.git` directory (e.g. a nested pyproject.toml was found first).
        # Passing str(None) to Repo would fail with a confusing path error; let GitPython walk up the
        # parent directories to find the repository instead.
        return Repo(at or os.getcwd(), search_parent_directories=True)
    return Repo(str(root))

79 

80 

def latest_commit(repo: Optional[Repo] = None) -> Commit:
    """
    Return the commit HEAD points to in `repo` (looked up from the cwd when not given).
    """
    return find_git_repo(repo).head.commit

84 

85 

def commit_by_id(commit_hash: str, repo: Optional[Repo] = None) -> Commit:
    """
    Resolve `commit_hash` with `Repo.commit` in `repo` (looked up from the cwd when not given).

    The value comes from `file@version` CLI arguments, so it may also be a branch name
    (GitPython's `Repo.commit` accepts any revision — see its docs).
    """
    return find_git_repo(repo).commit(commit_hash)

89 

90 

@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Context manager yielding the complete contents of a git blob as an in-memory binary buffer.
    """
    contents = file.data_stream.read()
    yield io.BytesIO(contents)

95 

def read_blob(file: Blob) -> str:
    """
    Return a git blob's contents as text, decoded with `bytes.decode`'s default (UTF-8).
    """
    with open_blob(file) as buffer:
        raw = buffer.read()
    return raw.decode()

99 

100 

def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Optional[Repo] = None) -> str:
    """
    Read the contents of `filename` as it was at `commit_version` in its git repository.

    Args:
        filename: path to the file (absolute or relative to the cwd) inside a git working tree.
        commit_version: "latest" for the commit HEAD points to, otherwise a revision for `commit_by_id`.
        repo: repository to use; by default the repository containing `filename` is looked up.

    Returns:
        The file contents at that commit, decoded by `read_blob`.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    file_path = Path(filename).resolve()
    # Resolve the repository root as well, so symlinked checkouts (e.g. /tmp -> /private/tmp) still match.
    repo_root = Path(repo.working_dir).resolve()
    try:
        # git tree paths are relative to the repository root and always use '/' separators (also on Windows).
        relative_file_path = file_path.relative_to(repo_root).as_posix()
    except ValueError:
        # not inside the working tree: keep the full path, the tree lookup below will fail for it
        relative_file_path = str(file_path)

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)

111 

112 

def get_file_for_version(filename: str, version: str, prompt_description: str = "") -> str:
    """
    Load the source code of `filename` at `version`.

    Args:
        filename: path of the file to read (not used for "stdin").
        version: "current" reads the file on disk, "stdin" reads standard input,
            anything else is passed to `get_file_for_commit` ("latest" or a git revision).
        prompt_description: shown in the stdin prompt, e.g. to tell the 'before' and 'after' code apart.

    Returns:
        The source code as text.
    """
    if version == "current":
        # Python source is UTF-8 by default (PEP 3120) and `read_blob` decodes git blobs as UTF-8 too,
        # so don't depend on the platform's locale encoding here.
        return Path(filename).read_text(encoding="utf-8")
    elif version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        result = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return result
    else:
        return get_file_for_commit(filename, version)

127 

128 

129def extract_file_version_and_path( 

130 file_path_or_git_tag: Optional[str], default_version: str = "stdin" 

131) -> tuple[str, str | None]: 

132 """ 

133 

134 Examples: 

135 myfile.py (implies @current) 

136 

137 myfile.py@latest 

138 myfile.py@my-branch 

139 myfile.py@b3f24091a9 

140 

141 @latest (implies no path, e.g. in case of ALTER to copy previously defined path) 

142 """ 

143 if not file_path_or_git_tag: 

144 return default_version, "" 

145 

146 if file_path_or_git_tag == "-": 

147 return "stdin", "-" 

148 

149 if file_path_or_git_tag.startswith("@"): 

150 file_version = file_path_or_git_tag.strip("@") 

151 file_path = None 

152 elif "@" in file_path_or_git_tag: 

153 file_path, file_version = file_path_or_git_tag.split("@") 

154 else: 

155 file_version = default_version # `latest` for before; `current` for after. 

156 file_path = file_path_or_git_tag 

157 

158 return file_version, file_path 

159 

160 

def extract_file_versions_and_paths(
    filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Turn the 'before' and 'after' CLI arguments into `(version, path)` pairs.

    Without an explicit version, 'before' defaults to `latest`, unless two different files were given,
    in which case it defaults to `current`; 'after' always defaults to `current`.
    When only one side yields a path, that path is used for both sides.

    Raises:
        ValueError: when neither argument yields a file path.
    """
    two_different_files = bool(filename_after and filename_before and filename_after != filename_before)
    before_default = "current" if two_different_files else "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=before_default)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    if not filepath_after:
        filepath_after = filepath_before
    elif not filepath_before:
        filepath_before = filepath_after

    return (version_before, filepath_before), (version_after, filepath_after)

180 

181 

def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Locate `filename` and report `(exists, absolute_path)`.

    Reading from stdin needs no file, so it always counts as existing (with an empty path).
    Otherwise the filename is tried as given (absolute or relative to the cwd) and then relative to
    `git_root`, which defaults to the git repository root or, when none is found, the current directory.

    Returns:
        `(True, resolved path)` for the first candidate that exists, otherwise `(False, "")`.
    """
    if version == "stdin":
        return True, ""
    if filename is None:
        # not stdin and no file: can't deal with this here, the 'file missing' error is shown later.
        return False, ""

    root = git_root if git_root is not None else (find_git_root() or Path(os.getcwd()))
    candidates = (Path(filename), root / filename)

    for candidate in candidates:
        if candidate.exists():
            return True, str(candidate.resolve())
    return False, ""

206 

207 

def ensure_no_migrate_on_real_db(
    code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: typing.Optional[bool] = False
) -> str:
    """
    Guard against running code that could perform actual migrations on a real database.

    Looks for definitions of any of `db_names` (via `find_defined_variables`) and for local imports
    (via `find_local_imports`) in `code`.

    Args:
        code: Python source to check.
        db_names: variable names that must not be defined by the code.
        fix: when truthy, remove the offending variables/imports instead of raising.

    Returns:
        The (possibly cleaned) code.

    Raises:
        ValueError: when `fix` is falsy and a database variable or a local import is found.
    """
    # Materialize once (order kept, duplicates dropped): `db_names` may be a one-shot iterator,
    # and it is needed both for the lookup and for the removal below.
    names = tuple(dict.fromkeys(db_names))

    variables = find_defined_variables(code)
    # a list in `names` order (not a set) so the error message below is deterministic
    found_variables = [db_name for db_name in names if db_name in variables]

    if found_variables and fix:
        # remove_specific_variables already receives every name, so one call suffices
        # (presumably it strips all listed variables — it was previously repeated per match)
        code = remove_specific_variables(code, names)
    elif found_variables:
        if len(found_variables) == 1:
            var = found_variables[0]
            message = f"Variable {var} defined in code! "
        else:  # pragma: no cover
            var = ", ".join(found_variables)
            message = f"Variables {var} defined in code! "
        raise ValueError(
            f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
        )

    if find_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code

240 

241 

# Upper bound on automatic fix-and-retry rounds in `handle_cli` before it gives up.
MAX_RETRIES: int = 20

243 

244 

def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: Optional[bool] = False,
    noop: Optional[bool] = False,
    magic: Optional[bool] = False,
    function_name: Optional[str] = "define_tables",
) -> bool:
    """
    Handle user input: build a script from `code_before`/`code_after`, execute it and print the SQL to stdout.

    The generated script loads `code_before` into one in-memory DAL (`db_old`) and `code_after` into another
    (`db_new`), then prints, per table, the output of `generate_sql` (or a DROP TABLE for removed tables).
    When executing fails, automatic fixes are attempted, at most MAX_RETRIES rounds:
    - no tables found: append a call to `function_name` if the code defines that function;
    - NameError (only with `magic`): insert stub code for the missing variables (`generate_magic_code`);
    - ImportError: remove the failing import (`remove_import`).

    Args:
        code_before: source defining the tables before the change.
        code_after: source defining the tables after the change.
        db_type: dialect passed to `generate_sql` (an empty string when not given).
        tables: table names (possibly nested lists) to limit the output to; all tables when empty.
        verbose: also print the generated script to stderr.
        noop: only print the generated script to stderr, don't execute it.
        magic: automatically fix the code instead of failing (see also `ensure_no_migrate_on_real_db`).
        function_name: function to call when no tables are defined at the top level.

    Returns:
        True on success (or when `noop`), False when the code could not be executed.

    Raises:
        ValueError: from `ensure_no_migrate_on_real_db` when `magic` is off, or any ValueError other than
            'no-tables-found' raised by the executed code. Exceptions other than ValueError, NameError and
            ImportError raised by the executed code propagate unchanged.
    """
    # todo: prefix (e.g. public.)
    # Escape rich markup when echoing code: `db_old[table]` would otherwise be parsed as a style tag and dropped.
    from rich.markup import escape

    # Notes on the template:
    # - `$db_type` is filled in with repr(...), so it is always a valid Python string literal;
    # - the default `tables` keeps definition order (dict.fromkeys) instead of a set's arbitrary order,
    #   so the SQL output is stable between runs.
    to_execute = string.Template(
        textwrap.dedent(
            """
            from pydal import *
            from pydal.objects import *
            from pydal.validators import *

            from pydal2sql import generate_sql

            db = database = DAL(None, migrate=False)

            tables = $tables
            db_type = $db_type

            $extra

            $code_before

            db_old = db
            db_new = db = database = DAL(None, migrate=False)

            $code_after

            if not tables:
                tables = list(dict.fromkeys(db_old._tables + db_new._tables))

            if not tables:
                raise ValueError('no-tables-found')


            for table in tables:
                print('--', table)
                if table in db_old and table in db_new:
                    print(generate_sql(db_old[table], db_new[table], db_type=db_type))
                elif table in db_old:
                    print(f'DROP TABLE {table};')
                else:
                    print(generate_sql(db_new[table], db_type=db_type))
            """
        )
    )

    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

    def _build_code(extra: str) -> str:
        # Fill in the template; `extra` is code placed before the user's code (magic stubs).
        return to_execute.substitute(
            {
                "tables": flatten(tables or []),
                "db_type": repr(db_type or ""),
                "code_before": textwrap.dedent(code_before),
                "code_after": textwrap.dedent(code_after),
                "extra": extra,
            }
        )

    generated_code = _build_code("")
    if verbose or noop:
        rich.print(escape(generated_code), file=sys.stderr)

    if noop:
        return True

    err: typing.Optional[Exception] = None
    # Fixes applied so far, kept so that a script regenerated after a NameError doesn't lose earlier fixes.
    extra_code = ""
    removed_imports: list[typing.Optional[str]] = []
    function_call_added = False

    for _ in range(MAX_RETRIES):
        try:
            # Run in a fresh namespace (a single dict used as globals and locals). With this function's globals
            # plus a locals snapshot, functions defined in the user's code (e.g. `define_tables`) could not see
            # names imported by the script itself, such as `Field`.
            exec(generated_code, {"__name__": __name__})  # nosec: B102
            return True  # success!
        except ValueError as e:
            err = e
            if str(e) != "no-tables-found":
                raise e

            # if the function (e.g. define_tables) is found, add a call to it at the end of the code and retry.
            # (status output goes to stderr only: stdout carries the generated SQL)
            if (
                function_name
                and not function_call_added
                and find_function_to_call(generated_code, function_name) is not None
            ):
                generated_code = add_function_call(generated_code, function_name)
                function_call_added = True
                if verbose:
                    rich.print(escape(generated_code), file=sys.stderr)
                continue

            # else: no define_tables or other method to use found.
            print(f"No tables found in the top-level or {function_name} function!", file=sys.stderr)
            print(
                "Please use `db.define_table` or `database.define_table`, "
                "or if you really need to use an alias like my_db.define_tables, "
                "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                file=sys.stderr,
            )
            print(f"You can also specify a --function to use something else than {function_name}.", file=sys.stderr)
            return False
        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code)
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            if not missing_vars:
                # nothing left to stub: the regenerated script would fail in exactly the same way
                break

            new_magic = generate_magic_code(missing_vars)
            extra_code = f"{extra_code}\n{new_magic}" if extra_code else new_magic
            generated_code = _build_code(extra_code)
            # re-apply the earlier fixes to the freshly generated script
            for module_name in removed_imports:
                generated_code = remove_import(generated_code, module_name)
            if function_call_added:
                generated_code = add_function_call(generated_code, function_name)

            if verbose:
                rich.print(escape(generated_code), file=sys.stderr)
        except ImportError as e:
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            generated_code = remove_import(generated_code, e.name)
            removed_imports.append(e.name)

    rich.print(
        f"[red]Code could not be fixed automagically![/red]. Error: {escape(str(err)) if err else '?'}",
        file=sys.stderr,
    )
    return False