Coverage for src/pydal2sql_core/cli_support.py: 100%

152 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-31 22:49 +0200

"""
CLI-agnostic support code.

Helpers for reading table definitions from the working tree, stdin or a git revision,
and for turning "before"/"after" pydal code into SQL migration statements.
"""

4import contextlib 

5import io 

6import os 

7import select 

8import string 

9import sys 

10import textwrap 

11import typing 

12from pathlib import Path 

13from typing import Optional 

14 

15import rich 

16from black.files import find_project_root 

17from git.objects.blob import Blob 

18from git.objects.commit import Commit 

19from git.repo import Repo 

20from witchery import ( 

21 add_function_call, 

22 find_defined_variables, 

23 find_function_to_call, 

24 find_missing_variables, 

25 generate_magic_code, 

26 has_local_imports, 

27 remove_import, 

28 remove_local_imports, 

29 remove_specific_variables, 

30) 

31 

32from .helpers import flatten 

33 

34 

def has_stdin_data() -> bool:  # pragma: no cover
    """
    Tell whether data was piped (`|`) or redirected (`<`) into the program's stdin.

    stdin is polled with a zero timeout, so this never blocks waiting for input.
    NOTE(review): per the Python docs, `select` on non-socket file objects is POSIX-only.

    Returns:
        bool: True when stdin is immediately readable, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    ready_to_read, _ready_to_write, _in_error = select.select([sys.stdin], [], [], 0.0)
    return any(ready_to_read)

55 

56 

57AnyCallable = typing.Callable[..., typing.Any] 

58 

59 

def print_if_interactive(*args: typing.Any, pretty: bool = True, **kwargs: typing.Any) -> None:  # pragma: no cover
    """
    Print a message to stderr, but only in an interactive session.

    "Interactive" means nothing was piped or redirected into stdin (as reported by `has_stdin_data`).

    Args:
        *args: Objects to print.
        pretty (bool): Use rich's `rich.print` when True, the builtin `print` otherwise.
        **kwargs: Extra keyword arguments for the print function; `file` is always overridden with sys.stderr.

    Returns:
        None
    """
    if has_stdin_data():
        # input is piped/redirected: stay silent
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    kwargs["file"] = sys.stderr
    printer(*args, **kwargs)

80 

81 

def find_git_root(at: Optional[str] = None) -> Optional[Path]:
    """
    Find the root directory of the Git repository containing `at`.

    Delegates to black's `find_project_root` and only accepts its answer when the reported reason
    is a `.git` directory.
    NOTE(review): black can also stop at other project markers (e.g. `.hg`, `pyproject.toml`), in which case
    this returns None even inside a Git repository — confirm for the installed black version.

    Args:
        at (str, optional): File or directory to start the search from. Defaults to the current working directory.

    Returns:
        Optional[Path]: The root directory of the Git repository if found, otherwise None.
    """
    folder, reason = find_project_root((at or os.getcwd(),))
    if reason != ".git directory":
        return None
    return folder

96 

97 

def find_git_repo(repo: Optional[Repo] = None, at: Optional[str] = None) -> Repo:
    """
    Find the Git repository instance.

    Args:
        repo (Repo, optional): An existing Git repository instance. If provided, returns the same instance.
        at (str, optional): File or directory to start the search from. Defaults to the current working directory.

    Returns:
        Repo: The Git repository instance.

    Raises:
        git.exc.GitError subclass (e.g. InvalidGitRepositoryError / NoSuchPathError): raised by GitPython
        when no repository can be opened from the starting point.
    """
    if repo:
        return repo

    root = find_git_root(at)
    if root is None:
        # Previously this fell through to `Repo(str(None))`, i.e. it tried to open a directory literally
        # named "None". Let GitPython walk up from the starting point itself instead; it raises its own
        # error when there really is no repository.
        start = os.path.abspath(at) if at else os.getcwd()
        return Repo(start, search_parent_directories=True)

    return Repo(str(root))

114 

115 

def latest_commit(repo: Optional[Repo] = None) -> Commit:
    """
    Get the latest commit in the Git repository, i.e. the commit HEAD currently points to.

    Args:
        repo (Repo, optional): An existing Git repository instance. If omitted, it is looked up from the
            current working directory via `find_git_repo`.

    Returns:
        Commit: The commit referenced by HEAD.
    """
    repo = find_git_repo(repo)
    return repo.head.commit

128 

129 

def commit_by_id(commit_hash: str, repo: Optional[Repo] = None) -> Commit:
    """
    Get a specific commit in the Git repository by its hash or name.

    Args:
        commit_hash (str): The hash of the commit to retrieve. Can also be e.g. a branch name.
        repo (Repo, optional): An existing Git repository instance. If omitted, it is looked up from the
            current working directory via `find_git_repo`.

    Returns:
        Commit: The commit object corresponding to the given revision.
    """
    repo = find_git_repo(repo)
    return repo.commit(commit_hash)

143 

144 

@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Expose a Git Blob's data as an in-memory binary file for the duration of a `with` block.

    The blob's data stream is read completely up front, and the bytes are wrapped in a BytesIO.

    Args:
        file (Blob): The Git Blob object to open.

    Yields:
        io.BytesIO: A buffer holding the complete Blob data.
    """
    payload = file.data_stream.read()
    buffer = io.BytesIO(payload)
    yield buffer

157 

158 

def read_blob(file: Blob) -> str:
    """
    Return the contents of a Git Blob object as text.

    The raw bytes are decoded with `bytes.decode()`'s default codec (UTF-8).

    Args:
        file (Blob): The Git Blob object to read.

    Returns:
        str: The decoded contents of the Blob.
    """
    with open_blob(file) as handle:
        raw: bytes = handle.read()
    return raw.decode()

171 

172 

def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Optional[Repo] = None) -> str:
    """
    Get the contents of a file in the Git repository at a specific commit version.

    Args:
        filename (str): The path of the file to retrieve (relative paths are resolved against the cwd).
        commit_version (str, optional): The commit hash or branch name. Defaults to "latest" (HEAD).
        repo (Repo, optional): An existing Git repository instance. If provided, uses the given instance.

    Returns:
        str: The contents of the file as a string.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    file_path = Path(filename).resolve()
    # Resolve the work tree too, so symlinked checkouts (e.g. /tmp -> /private/tmp) still match.
    repo_root = Path(repo.working_dir).resolve()
    try:
        # Git tree paths are relative to the work tree root and always '/'-separated, on every OS.
        relative_file_path = file_path.relative_to(repo_root).as_posix()
    except ValueError:
        # Not inside the work tree: keep the old behaviour and let the tree lookup below fail.
        relative_file_path = str(file_path)

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)

194 

195 

def get_file_for_version(filename: str, version: str, prompt_description: str = "") -> str:
    """
    Get the contents of a file based on the version specified.

    Args:
        filename (str): The path of the file to retrieve.
        version (str): The version specifier: "current" (file on disk), "stdin" (read until EOF),
            or a commit hash/branch name (file as stored in git).
        prompt_description (str, optional): A description to display when asking for input from stdin.

    Returns:
        str: The contents of the file as a string.
    """
    if version == "current":
        # Decode explicitly as UTF-8 (Python's default source encoding) instead of the platform locale,
        # matching `read_blob`, which decodes git blobs as UTF-8.
        return Path(filename).read_text(encoding="utf-8")
    elif version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        result = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return result
    else:
        return get_file_for_commit(filename, version)

221 

222 

223def extract_file_version_and_path( 

224 file_path_or_git_tag: Optional[str], default_version: str = "stdin" 

225) -> tuple[str, str | None]: 

226 """ 

227 Extract the file version and path from the given input. 

228 

229 Args: 

230 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag. 

231 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin". 

232 

233 Returns: 

234 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified). 

235 

236 Examples: 

237 myfile.py (implies @current) 

238 

239 myfile.py@latest 

240 myfile.py@my-branch 

241 myfile.py@b3f24091a9 

242 

243 @latest (implies no path, e.g. in case of ALTER to copy previously defined path) 

244 """ 

245 if not file_path_or_git_tag: 

246 return default_version, "" 

247 

248 if file_path_or_git_tag == "-": 

249 return "stdin", "-" 

250 

251 if file_path_or_git_tag.startswith("@"): 

252 file_version = file_path_or_git_tag.strip("@") 

253 file_path = None 

254 elif "@" in file_path_or_git_tag: 

255 file_path, file_version = file_path_or_git_tag.split("@") 

256 else: 

257 file_version = default_version # `latest` for before; `current` for after. 

258 file_path = file_path_or_git_tag 

259 

260 return file_version, file_path 

261 

262 

def extract_file_versions_and_paths(
    filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Work out (version, path) pairs for the "before" and "after" inputs.

    When two different files are given, the "before" file defaults to its current version; otherwise it
    defaults to the latest commit. The "after" file always defaults to its current version. If only one
    side names a path, the other side reuses it.

    Args:
        filename_before (str, optional): The path of the file before the change (or None).
        filename_after (str, optional): The path of the file after the change (or None).

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            A tuple of two tuples, each containing the version and path of the before and after files.

    Raises:
        ValueError: when neither side yields a file path.
    """
    comparing_two_files = bool(filename_after and filename_before and filename_after != filename_before)
    version_before, filepath_before = extract_file_version_and_path(
        filename_before,
        default_version="current" if comparing_two_files else "latest",
    )
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    # whichever side lacks a path borrows it from the other one
    filepath_before = filepath_before or filepath_after
    filepath_after = filepath_after or filepath_before

    return (version_before, filepath_before), (version_after, filepath_after)

293 

294 

def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Get absolute path information for the file based on the version and Git root.

    The path is tried as given first (absolute, or relative to the cwd), then relative to the Git root.

    Args:
        filename (str, optional): The path of the file to check (or None).
        version (str): The version specifier, which can be "stdin", "current", or a commit hash/branch name.
        git_root (Path, optional): The root directory of the Git repository. If None, it is determined via
            `find_git_root`, falling back to the current working directory.

    Returns:
        tuple[bool, str]: Whether the file exists, and its absolute path ("" when it does not exist or
        when reading from stdin).
    """
    if version == "stdin":
        return True, ""
    if not filename:
        # Can't deal with this: not stdin and no file name, so report it as missing (shows a file-missing
        # error later). An empty string must not reach Path(): Path("") is the cwd, which "exists".
        return False, ""

    if git_root is None:
        git_root = find_git_root() or Path(os.getcwd())

    for candidate in (Path(filename), git_root / filename):
        if candidate.exists():
            return True, str(candidate.resolve())

    return False, ""

330 

331 

def ensure_no_migrate_on_real_db(
    code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: typing.Optional[bool] = False
) -> str:
    """
    Ensure that the code does not contain actual migrations on a real database.

    It does this by removing definitions of 'db' and database. This can be changed by customizing `db_names`.
    It also removes local imports to prevent irrelevant code being executed.

    Args:
        code (str): The code to check for database migrations.
        db_names (Iterable[str], optional): Names of variables representing the database.
            Defaults to ("db", "database").
        fix (bool, optional): If True, removes the offending definitions/imports instead of raising.
            Defaults to False.

    Returns:
        str: The modified code with migration code removed if fix=True, otherwise the original code.

    Raises:
        ValueError: if fix is False and one of `db_names` is defined, or local imports are used.
    """
    # Materialize once: `db_names` may be a one-shot iterator, and it is used twice below.
    db_names = tuple(db_names)
    variables = find_defined_variables(code)

    # Ordered and de-duplicated, so the error message is deterministic (a set's order is not).
    found_variables = list(dict.fromkeys(name for name in db_names if name in variables))

    if found_variables and fix:
        # A single call removes every name in `db_names`; repeating it per match is redundant.
        code = remove_specific_variables(code, db_names)
    elif found_variables:
        if len(found_variables) == 1:
            message = f"Variable {found_variables[0]} defined in code! "
        else:  # pragma: no cover
            var = ", ".join(found_variables)
            message = f"Variables {var} defined in code! "
        raise ValueError(
            f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
        )

    if has_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code

379 

380 

# Upper bound on how many times `handle_cli` executes the generated code while trying to repair it.
MAX_RETRIES: int = 20

382 

383 

384 

385# todo: overload more methods 

386 

387 

def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: Optional[bool] = False,
    noop: Optional[bool] = False,
    magic: Optional[bool] = False,
    function_name: Optional[str] = "define_tables",
) -> bool:
    """
    Handle user input for generating SQL migration statements based on before and after code.

    Both snippets are pasted into a Python template that defines them against two separate
    `DAL(None, migrate=False)` instances and prints SQL for every table. The template is then run with `exec`.

    SECURITY: this executes the user's own code — only use it on trusted input.

    When execution fails, repairs are attempted (at most MAX_RETRIES executions in total):
      - no tables found: append a call to `function_name` if the code defines such a function;
      - NameError (only with `magic`): prepend code produced by `generate_magic_code` for the missing names;
      - ImportError: remove the failing import.

    Args:
        code_before (str): The code representing the state of the database before the change.
        code_after (str): The code representing the state of the database after the change.
        db_type (str, optional): The type of the database (e.g., "postgres", "mysql", etc.). Defaults to None.
        tables (list[str] or list[list[str]], optional): The list of tables to generate SQL for. Defaults to None
            (all tables defined in either snippet, in definition order).
        verbose (bool, optional): If True, print the generated code before each attempt. Defaults to False.
        noop (bool, optional): If True, only print the generated code but do not execute it. Defaults to False.
        magic (bool, optional): If True, automatically add missing variables for execution. Defaults to False.
        function_name (str, optional): The name of the function where the tables are defined. Defaults: "define_tables".

    Returns:
        bool: True if the SQL was generated (or, with `noop`, the code was printed);
        False if no tables were found, variables are missing without `magic`, or the code could not be
        repaired within MAX_RETRIES attempts.

    Raises:
        ValueError: from `ensure_no_migrate_on_real_db` (real db / local imports without `magic`), or any
            ValueError other than 'no-tables-found' raised by the executed code.
    """
    # todo: prefix (e.g. public.)

    # Tables are de-duplicated with dict.fromkeys rather than set(): a set of str iterates in a different
    # order on every run (hash randomization), which shuffled the emitted SQL; definition order is stable.
    to_execute = string.Template(
        textwrap.dedent(
            """
from pydal import *
from pydal.objects import *
from pydal.validators import *

from pydal2sql import generate_sql


from pydal import DAL
db = database = DAL(None, migrate=False)

tables = $tables
db_type = $db_type

$extra

$code_before

db_old = db
db_new = db = database = DAL(None, migrate=False)

$code_after

if not tables:
    tables = list(dict.fromkeys(db_old._tables + db_new._tables))

if not tables:
    raise ValueError('no-tables-found')


for table in tables:
    print('--', table)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type))
    elif table in db_old:
        print(f'DROP TABLE {table};')
    else:
        print(generate_sql(db_new[table], db_type=db_type))
"""
        )
    )

    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

    def render(extra: str = "") -> str:
        """Fill in the template; `extra` is code inserted before the user's code (used by --magic)."""
        return to_execute.substitute(
            {
                "tables": flatten(tables or []),
                # repr() produces a properly quoted and escaped string literal, so a db_type containing
                # quotes can no longer break out of the generated code.
                "db_type": repr(db_type or ""),
                "code_before": textwrap.dedent(code_before),
                "code_after": textwrap.dedent(code_after),
                "extra": extra,
            }
        )

    generated_code = render()

    if noop:
        # dry run: show the code once, don't execute it
        rich.print(generated_code, file=sys.stderr)
        return True

    err: typing.Optional[Exception] = None
    catch: dict[str, typing.Any] = {}
    retry_counter = MAX_RETRIES
    while retry_counter > 0:
        retry_counter -= 1
        try:
            if verbose:
                # printed once per attempt (previously the first attempt was printed twice)
                rich.print(generated_code, file=sys.stderr)

            # 'catch' is used to add and receive globals from the exec scope.
            # another argument could be added for locals, but adding simply {} changes the behavior negatively.
            # so for now, only globals is passed.
            exec(generated_code, catch)  # nosec: B102
            return True  # success!
        except ValueError as e:
            err = e

            if str(e) != "no-tables-found":  # pragma: no cover
                raise

            if function_name:
                define_tables = find_function_to_call(generated_code, function_name)

                # if define_tables function is found, add call to it at end of code
                if define_tables is not None:
                    generated_code = add_function_call(generated_code, function_name, multiple=True)
                    continue

            # else: no define_tables or other method to use found.

            print(f"No tables found in the top-level or {function_name} function!", file=sys.stderr)
            print(
                "Please use `db.define_table` or `database.define_table`, "
                "or if you really need to use an alias like my_db.define_tables, "
                "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                file=sys.stderr,
            )
            print(f"You can also specify a --function to use something else than {function_name}.", file=sys.stderr)

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code)
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            # postponed: this can possibly also be achieved by updating 'catch'.
            # NOTE: re-rendering from the template drops any previously appended function call; the
            # 'no-tables-found' branch above re-adds it on the next attempt.
            extra_code = generate_magic_code(missing_vars)
            generated_code = render(textwrap.dedent(extra_code))
        except ImportError as e:
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            generated_code = remove_import(generated_code, e.name)

    # Only reached when every attempt failed and each "repair" just led to another failure.
    rich.print(  # pragma: no cover
        f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'}",
        file=sys.stderr,
    )
    return False  # pragma: no cover