Coverage for src/pydal2sql_core/cli_support.py: 100%

339 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-08-05 17:27 +0200

1""" 

2CLI-Agnostic support. 

3""" 

4 

5import ast 

6import contextlib 

7import io 

8import os 

9import re 

10import select 

11import string 

12import sys 

13import textwrap 

14import traceback 

15import typing 

16from datetime import datetime 

17from pathlib import Path 

18from typing import Any, Optional 

19 

20import git 

21import gitdb.exc 

22import rich 

23from black.files import find_project_root 

24from git.objects.blob import Blob 

25from git.objects.commit import Commit 

26from git.repo import Repo 

27from witchery import ( 

28 add_function_call, 

29 find_defined_variables, 

30 find_function_to_call, 

31 find_missing_variables, 

32 generate_magic_code, 

33 has_local_imports, 

34 remove_if_falsey_blocks, 

35 remove_import, 

36 remove_local_imports, 

37 remove_specific_variables, 

38) 

39 

40from .helpers import excl, flatten, uniq 

41from .types import ( 

42 _SUPPORTED_OUTPUT_FORMATS, 

43 DEFAULT_OUTPUT_FORMAT, 

44 SUPPORTED_DATABASE_TYPES_WITH_ALIASES, 

45 SUPPORTED_OUTPUT_FORMATS, 

46 DummyDAL, 

47 DummyTypeDAL, 

48) 

49 

50 

def has_stdin_data() -> bool:  # pragma: no cover
    """
    Report whether stdin already has something to read (pipe | or redirect <).

    `sys.stdin` is polled through `select.select` with a timeout of 0.0, so the
    call returns immediately instead of waiting for input to arrive.

    Returns:
        bool: True if stdin is reported as readable right now, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    readable, _writable, _exceptional = select.select([sys.stdin], [], [], 0.0)
    return any(readable)

71 

72 

AnyCallable = typing.Callable[..., Any]


def print_if_interactive(*args: Any, pretty: bool = True, **kwargs: Any) -> None:  # pragma: no cover
    """
    Print a message to stderr, but only when nothing is waiting on stdin.

    Args:
        *args: Positional arguments forwarded to the print function.
        pretty (bool): Use `rich.print` when True, the built-in `print` otherwise.
        **kwargs: Keyword arguments forwarded to the print function.
            Any `file` entry is replaced with `sys.stderr`.

    Returns:
        None
    """
    if has_stdin_data():
        # input is being piped/redirected: this is not an interactive session, stay silent.
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    kwargs["file"] = sys.stderr
    printer(*args, **kwargs)

96 

97 

def find_git_root(at: str = None) -> Optional[Path]:
    """
    Locate the top-level folder of the Git repository that contains `at`.

    Args:
        at (str, optional): Path to start the search from. The current working directory is used when empty.

    Returns:
        Optional[Path]: The folder returned by `find_project_root`, or None when that folder
            was not selected because of a ".git directory".
    """
    start = at or os.getcwd()
    root, found_because = find_project_root((start,))
    return root if found_because == ".git directory" else None

112 

113 

def find_git_repo(repo: Optional[Repo] = None, at: Optional[str] = None) -> Repo:
    """
    Find the Git repository instance.

    Args:
        repo (Repo, optional): An existing Git repository instance. If provided (and truthy), it is returned as-is.
        at (str, optional): The directory path to start the search. Defaults to the current working directory.

    Returns:
        Repo: The Git repository instance.

    Raises:
        git.exc.NoSuchPathError: If no Git root could be found starting from `at`.
    """
    if repo:
        return repo

    root = find_git_root(at)
    if root is None:
        # Without this guard `Repo(str(None))` would be attempted, i.e. a lookup of a folder literally
        # called "None". Fail with the same exception family (GitError) but a usable message.
        raise git.exc.NoSuchPathError(f"No git repository found at or above '{at or os.getcwd()}'")

    return Repo(str(root))

130 

131 

def latest_commit(repo: Repo = None) -> Commit:
    """
    Return the commit that HEAD of the repository points to.

    Args:
        repo (Repo, optional): Repository to inspect. When omitted, `find_git_repo` looks one up.

    Returns:
        Commit: The commit at `repo.head`.
    """
    return find_git_repo(repo).head.commit

144 

145 

def commit_by_id(commit_hash: str, repo: Repo = None) -> Commit:
    """
    Look up a single commit by hash or by another revision name.

    Args:
        commit_hash (str): Hash of the commit to fetch. Can also be e.g. a branch name.
        repo (Repo, optional): Repository to inspect. When omitted, `find_git_repo` looks one up.

    Returns:
        Commit: The commit object `repo.commit` resolves `commit_hash` to.
    """
    repository = find_git_repo(repo)
    return repository.commit(commit_hash)

159 

160 

@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Context manager exposing the contents of a Git Blob as an in-memory stream.

    The blob's data stream is read completely up front and wrapped in a BytesIO.

    Args:
        file (Blob): The Git Blob object to open.

    Yields:
        io.BytesIO: A BytesIO object holding the Blob data.
    """
    raw = file.data_stream.read()
    buffer = io.BytesIO(raw)
    yield buffer

173 

174 

def read_blob(file: Blob) -> str:
    """
    Return the full contents of a Git Blob as text.

    Args:
        file (Blob): The Git Blob object to read.

    Returns:
        str: The Blob's bytes decoded with `bytes.decode()` defaults.
    """
    with open_blob(file) as stream:
        raw = stream.read()
    return raw.decode()

187 

188 

def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Optional[Repo] = None) -> str:
    """
    Get the contents of a file in the Git repository at a specific commit version.

    Args:
        filename (str): The path of the file to retrieve.
        commit_version (str, optional): The commit hash or branch name. Defaults to "latest" (latest commit).
        repo (Repo, optional): An existing Git repository instance. If provided, uses the given instance.

    Returns:
        str: The contents of the file as a string.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    file_path = Path(filename).resolve()
    # Path inside the commit tree; tree lookups use "/" as separator, hence `as_posix()`.
    tree_path = file_path.as_posix()
    if repo.working_dir is not None:
        try:
            # Compare resolved-vs-resolved so symlinked checkouts work, and avoid a hard-coded
            # "/" prefix strip that breaks with other path separators.
            tree_path = file_path.relative_to(Path(repo.working_dir).resolve()).as_posix()
        except ValueError:
            # File lives outside the work tree: keep the absolute path so the tree lookup fails as before.
            pass

    file_at_commit = commit.tree / tree_path
    return read_blob(file_at_commit)

210 

211 

def get_file_for_version(filename: str, version: str, prompt_description: str = "", with_git: bool = True) -> str:
    """
    Fetch the source text of `filename` for the requested version.

    Args:
        filename (str): The path of the file to retrieve.
        version (str): The version specifier, which can be "current", "stdin", or a commit hash/branch name.
        prompt_description (str, optional): A description to display when asking for input from stdin.
        with_git (bool, optional): When False, the file is always read from disk and `version` is ignored.

    Returns:
        str: The contents of the file as a string.

    Raises:
        FileNotFoundError: If the file could not be loaded from Git for the given version.
    """
    if version == "current" or not with_git:
        # working-copy contents straight from disk
        return Path(filename).read_text()

    if version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        pasted = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return pasted

    # anything else is treated as a Git revision (with_git is necessarily True here)
    try:
        return get_file_for_commit(filename, version)
    except (git.exc.GitError, gitdb.exc.ODBError) as e:
        raise FileNotFoundError(f"{filename}@{version}") from e

240 

241 

242def extract_file_version_and_path( 

243 file_path_or_git_tag: Optional[str], default_version: str = "stdin" 

244) -> tuple[str, str | None]: 

245 """ 

246 Extract the file version and path from the given input. 

247 

248 Args: 

249 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag. 

250 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin". 

251 

252 Returns: 

253 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified). 

254 

255 Examples: 

256 myfile.py (implies @current) 

257 

258 myfile.py@latest 

259 myfile.py@my-branch 

260 myfile.py@b3f24091a9 

261 

262 @latest (implies no path, e.g. in case of ALTER to copy previously defined path) 

263 """ 

264 if not file_path_or_git_tag: 

265 return default_version, "" 

266 

267 if file_path_or_git_tag == "-": 

268 return "stdin", "-" 

269 

270 if file_path_or_git_tag.startswith("@"): 

271 file_version = file_path_or_git_tag.strip("@") 

272 file_path = None 

273 elif "@" in file_path_or_git_tag: 

274 file_path, file_version = file_path_or_git_tag.split("@") 

275 else: 

276 file_version = default_version # `latest` for before; `current` for after. 

277 file_path = file_path_or_git_tag 

278 

279 return file_version, file_path 

280 

281 

def extract_file_versions_and_paths(
    filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Work out (version, path) pairs for the 'before' and the 'after' side of a comparison.

    When only one side names a path, that path is reused for the other side.

    Args:
        filename_before (str, optional): The file specifier before the change (or None).
        filename_after (str, optional): The file specifier after the change (or None).

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            ((version_before, path_before), (version_after, path_after)).

    Raises:
        ValueError: If neither side yields a file path.
    """
    # two different, non-empty specifiers: compare the files as they are on disk;
    # otherwise the 'before' side defaults to the latest commit.
    two_distinct_inputs = bool(filename_after and filename_before and filename_after != filename_before)
    before_default = "current" if two_distinct_inputs else "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=before_default)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    if not filepath_after:
        filepath_after = filepath_before
    elif not filepath_before:
        filepath_before = filepath_after

    before = (version_before, filepath_before)
    after = (version_after, filepath_after)
    return before, after

312 

313 

def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Determine whether `filename` exists and what its absolute path is.

    The name is first tried as given, then relative to the Git root.

    Args:
        filename (str, optional): The path of the file to check (or None).
        version (str): The version specifier; "stdin" short-circuits to (True, "").
        git_root (Path, optional): The root directory of the Git repository. If None, it is looked up
            via `find_git_root`, falling back to the current working directory.

    Returns:
        tuple[bool, str]: (exists, absolute path); the path is "" when nothing was found or for stdin.
    """
    if version == "stdin":
        return True, ""

    if filename is None:
        # can't deal with this, not stdin and no file should show file missing error later.
        return False, ""

    if git_root is None:
        git_root = find_git_root() or Path(os.getcwd())

    for candidate in (Path(filename), git_root / filename):
        if candidate.exists():
            return True, str(candidate.resolve())

    return False, ""

349 

350 

def check_indentation(code: str, fix: bool = False) -> str:
    """
    Verify that `code` parses without an indentation problem.

    The code is handed to `ast.parse`. When that raises an `IndentationError`, the code is either
    passed through `textwrap.dedent` (with `fix=True`) or the error is propagated.
    Other parse errors are not handled here.

    Args:
        code (str): The code to check.
        fix (bool): Whether to dedent the code if it is incorrectly indented. Defaults to False.

    Returns:
        str: The original code if it parsed (or was empty), or the dedented code if `fix` is True.

    Raises:
        IndentationError: If the code is incorrectly indented and `fix` is False.
    """
    if not code:
        # no code? perfect indentation!
        return code

    try:
        ast.parse(code)
    except IndentationError:
        if not fix:
            raise
        return textwrap.dedent(code)

    return code

383 

384 

def ensure_no_migrate_on_real_db(
    code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: bool = False
) -> str:
    """
    Ensure that the code does not contain actual migrations on a real database.

    It does this by removing definitions of 'db' and database. This can be changed by customizing `db_names`.
    It also removes local imports to prevent irrelevant code being executed.

    Args:
        code (str): The code to check for database migrations.
        db_names (Iterable[str], optional): Names of variables representing the database.
            Defaults to ("db", "database").
        fix (bool, optional): If True, removes the offending code. Defaults to False.

    Returns:
        str: The modified code if fix=True, otherwise the original code.

    Raises:
        ValueError: If fix is False and the code defines one of `db_names` or uses local imports.
    """
    # Materialize once: `db_names` may be a one-shot iterable and is needed twice below.
    db_names = tuple(db_names)
    variables = find_defined_variables(code)

    # dict.fromkeys keeps the order of `db_names` while dropping duplicates (stable error message).
    found_variables = list(dict.fromkeys(db_name for db_name in db_names if db_name in variables))

    if found_variables:
        if fix:
            # one pass is enough: all `db_names` are removed at once.
            code = remove_specific_variables(code, db_names)
        else:
            if len(found_variables) == 1:
                message = f"Variable {found_variables[0]} defined in code! "
            else:  # pragma: no cover
                message = f"Variables {', '.join(found_variables)} defined in code! "
            raise ValueError(
                f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
            )

    if has_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code

432 

433 

# Upper bound for the number of exec/repair rounds in `handle_cli`.
MAX_RETRIES = 30

# todo: overload more methods

# Script template (pydal flavour) that `handle_cli` fills in with `string.Template` and then exec()s.
# Placeholders: $tables, $db_type, $extra (used twice), $code_before, $code_after.
# The names `DummyDAL`, `_file`, `_uniq`, `_excl` and `_special_tables` are not defined here;
# `handle_cli` injects them as globals for the exec.
# Every table's output is framed by a '-- start <table> --' line and a '-- END OF MIGRATION --' line,
# which the output handling code later splits/strips on.
TEMPLATE_PYDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *

from pydal2sql_core import generate_sql


# from pydal import DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')

for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)
    print('-- END OF MIGRATION --', file=_file)
    """

480 

# TypeDAL flavour of the exec template: same placeholders and injected globals as TEMPLATE_PYDAL,
# with an additional `from typedal import *`. `handle_cli` selects it when `use_typedal` is truthy.
TEMPLATE_TYPEDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *
from typedal import *

from pydal2sql_core import generate_sql


# from typedal import TypeDAL as DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')


for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)

    print('-- END OF MIGRATION --', file=_file)
    """

526 

527 

# Matched against the lower-cased statement. Accepts an optional `IF [NOT] EXISTS` clause and an
# optional opening quote (', " or backtick) in front of the table name.
_SQL_TABLE_ACTION_RE = re.compile(r"(create|alter|drop)\s+table\s+(?:if\s+(?:not\s+)?exists\s+)?['\"`]?(\w+)")


def sql_to_function_name(sql_statement: str, default: Optional[str] = None) -> str:
    """
    Extract action (CREATE, ALTER, DROP) and table name from the SQL statement.

    Only the first `<action> TABLE <name>` occurrence is used.

    Args:
        sql_statement (str): SQL text to inspect; matching is case-insensitive.
        default (str, optional): Name to return when no table statement is found.

    Returns:
        str: "<action>_<table>" in lower case (e.g. "create_users"),
            or `default` (falling back to "unknown_migration") when nothing matched.
    """
    found = _SQL_TABLE_ACTION_RE.search(sql_statement.lower())

    if found is None:
        # raise ValueError("Invalid SQL statement. Unable to extract action and table name.")
        return default or "unknown_migration"

    action, table_name = found.groups()

    # Generate a function name with the specified format
    return f"{action}_{table_name}"

542 

543 

def _setup_generic_edwh_migrate(file: Path, is_typedal: bool) -> None:
    """
    Start a fresh edwh-migrate file containing only the required import lines.

    `file` is opened in "w" mode, so existing contents are replaced.

    Args:
        file (Path): Migration file to (over)write.
        is_typedal (bool): Import `TypeDAL` from typedal when True, `DAL` from pydal otherwise.
    """
    dal_import = "from typedal import TypeDAL" if is_typedal else "from pydal import DAL"
    header = "from edwh_migrate import migration\n" + dal_import + "\n"

    with file.open("w") as handle:
        handle.write(textwrap.dedent(header))

    rich.print(f"[green] New migrate file {file} created [/green]")

555 

556 

# Matches the '-- start <table> --' marker line (including its newline) that the exec templates
# print in front of every table's SQL.
START_RE = re.compile(r"-- start\s+\w+\s--\n")


def _build_edwh_migration(
    contents: str, cls: str, date: str, existing: Optional[str] = None, default_migration_name: Optional[str] = None
) -> str:
    """
    Turn the SQL of one migration into the source of an `@migration` function.

    Args:
        contents: SQL for a single migration, possibly still containing its '-- start ... --' marker.
        cls: Class name used to annotate the `db` parameter of the generated function.
        date: Date string that becomes part of the generated function name.
        existing: Current contents of the target file (if any), used to detect duplicates.
        default_migration_name: Name prefix to use when no action/table can be derived from the SQL.

    Returns:
        Source code of the migration function, or "" when the migration is skipped
        (already present, name clash with different contents, or no SQL left).
    """
    sql_func_name = sql_to_function_name(contents, default=default_migration_name)
    func_name = "_placeholder_"
    contents = START_RE.sub("", contents)

    # Try suffixes _001 ... _999 until a function name is found that is not defined in `existing`.
    # NOTE(review): if every suffix is taken, the loop ends without `break` and the _999 name is reused.
    for n in range(1, 1000):
        func_name = f"{sql_func_name}_{date}_{str(n).zfill(3)}"

        if existing and f"def {func_name}" in existing:
            # duplicate detection ignores spaces and newlines on both sides.
            if contents.replace(" ", "").replace("\n", "") in existing.replace(" ", "").replace("\n", ""):
                rich.print(f"[yellow] migration {func_name} already exists, skipping! [/yellow]")
                return ""
            elif func_name.startswith(("alter", default_migration_name or "unknown")):
                # bump number because alter migrations are different
                continue
            else:
                rich.print(
                    f"[red] migration {func_name} already exists [bold]with different contents[/bold], skipping! [/red]"
                )
                return ""
        else:
            # okay function name, stop incrementing
            break

    # Pre-indent the SQL so it survives the `textwrap.dedent` below as part of the function body.
    # NOTE(review): the 16 presumably corresponds to the indentation inside the f-string template - confirm.
    contents = textwrap.indent(contents.strip(), " " * 16)

    if not contents.strip():
        # no real migration!
        return ""

    return textwrap.dedent(
        f'''

        @migration
        def {func_name}(db: {cls}):
            db.executesql("""
{contents}
            """)
            db.commit()

            return True
        '''
    )

605 

606 

def _build_edwh_migrations(
    contents: str, is_typedal: bool, output: Optional[Path] = None, default_migration_name: Optional[str] = None
) -> str:
    """
    Convert the combined SQL output into edwh-migrate function source code.

    The input is cut at every '-- END OF MIGRATION --' marker; each non-blank part is converted with
    `_build_edwh_migration` and the results are concatenated.

    Args:
        contents: SQL output containing zero or more migrations.
        is_typedal: Annotate the generated functions with "TypeDAL" instead of "DAL".
        output: Target file; when it exists its text is used to detect already-present migrations.
        default_migration_name: Fallback name prefix passed on to `_build_edwh_migration`.

    Returns:
        The concatenated source of all generated migration functions (may be empty).
    """
    cls = "TypeDAL" if is_typedal else "DAL"
    date = datetime.now().strftime("%Y%m%d")  # yyyymmdd

    existing = None
    if output and output.exists():
        existing = output.read_text()

    pieces = []
    for migration in contents.split("-- END OF MIGRATION --"):
        if not migration.strip():
            continue
        pieces.append(
            _build_edwh_migration(migration, cls, date, existing, default_migration_name=default_migration_name)
        )

    return "".join(pieces)

620 

621 

def _handle_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
    default_migration_name: Optional[str] = None,
) -> None:
    """
    Post-process the captured SQL and deliver it to the requested destination.

    Args:
        file: Buffer holding everything the executed template printed.
        output_file: Where to send the result: a path (str or Path) to append to, "-" or None for stdout,
            or a StringIO to write into.
        output_format: "edwh-migrate" builds migration functions; "default", "sql" or a falsy value
            only removes the '-- END OF MIGRATION --' markers.
        is_typedal: Passed on to the edwh-migrate helpers.
        default_migration_name: Fallback migration name for the edwh-migrate format.

    Raises:
        ValueError: If `output_format` is not one of the supported formats.
    """
    file.seek(0)
    contents = file.read()

    if isinstance(output_file, str):
        # `--output-file -` will print to stdout
        output_file = None if output_file == "-" else Path(output_file)

    writes_to_path = isinstance(output_file, Path)

    if output_format == "edwh-migrate":
        contents = _build_edwh_migrations(
            contents,
            is_typedal,
            output_file if writes_to_path else None,
            default_migration_name=default_migration_name,
        )
    elif output_format in {"default", "sql"} or not output_format:
        contents = "\n".join(contents.split("-- END OF MIGRATION --"))
    else:
        raise ValueError(
            f"Unknown format {output_format}. " f"Please choose one of {typing.get_args(_SUPPORTED_OUTPUT_FORMATS)}"
        )

    if isinstance(output_file, io.StringIO):
        output_file.write(contents)
        return

    if not writes_to_path:
        # no file, just print to stdout:
        print(contents.strip())
        return

    # a missing or empty edwh-migrate file first gets its import header
    if output_format == "edwh-migrate" and (not output_file.exists() or output_file.stat().st_size == 0):
        _setup_generic_edwh_migrate(output_file, is_typedal)

    if not contents.strip():
        rich.print(f"[yellow] Nothing to write to {output_file} [/yellow]")
        return

    with output_file.open("a") as f:
        f.write(contents)

    rich.print(f"[green] Written migration(s) to {output_file} [/green]")

667 

668 

669IMPORT_IN_STR = re.compile(r'File "<string>", line (\d+), in <module>') 

670 

671 

672def _handle_import_error(code: str, error: ImportError) -> str: 

673 # error is deeper in a package, find the related import in the code: 

674 tb_lines = traceback.format_exc().splitlines() 

675 

676 for line in tb_lines: 

677 if matches := IMPORT_IN_STR.findall(line): 

678 # 'File "<string>", line 15, in <module>' 

679 line_no = int(matches[0]) - 1 

680 lines = code.split("\n") 

681 return lines[line_no] 

682 

683 # I don't know how to trigger this case: 

684 raise ValueError("Faulty import could not be automatically deleted") from error # pragma: no cover 

685 

686 

687MISSING_RELATIONSHIP = re.compile(r"Cannot resolve reference (\w+) in \w+ definition") 

688 

689 

690def _handle_relation_error(error: KeyError) -> tuple[str, str]: 

691 if not (table := MISSING_RELATIONSHIP.findall(str(error))): 

692 # other error, raise again 

693 raise error 

694 

695 t = table[0] 

696 

697 return ( 

698 t, 

699 """ 

700 db.define_table('%s', redefine=True) 

701 """ 

702 % t, 

703 ) 

704 

705 

def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: bool = False,
    noop: bool = False,
    magic: bool = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool | typing.Literal["auto"] = "auto",
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    _update_path: bool = True,
) -> bool:
    """
    Handle user input for generating SQL migration statements based on before and after code.

    Both code snippets are placed into an exec template, which is executed against a dummy DAL.
    When execution fails in a known way (missing names, failing imports, unresolved table references,
    tables defined inside a function), the generated code is adjusted and executed again,
    up to MAX_RETRIES times.

    Args:
        code_before (str): The code representing the state of the database before the change.
        code_after (str): The code representing the state of the database after the change.
        db_type (str, optional): The type of the database (e.g., "postgres", "mysql", etc.). Defaults to None.
        tables (list[str] or list[list[str]], optional): The list of tables to generate SQL for. Defaults to None.
        verbose (bool, optional): If True, print the generated code. Defaults to False.
        noop (bool, optional): If True, only print the generated code but do not execute it. Defaults to False.
        magic (bool, optional): If True, automatically add missing variables for execution. Defaults to False.
        function_name (str or tuple[str, ...], optional): The name(s) of the function(s) where the tables
            are defined. Defaults: "define_tables".
        use_typedal: replace pydal imports with TypeDAL? "auto" checks whether either snippet mentions "typedal".
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?
        _update_path: try adding cwd to PYTHONPATH if failing?
            NOTE(review): this flag is not read anywhere in this function.

    # todo: prefix (e.g. public.)

    Returns:
        bool: True if the SQL was generated and handed to the output handler (or in noop mode),
            False otherwise.
    """
    # todo: better typedal checking
    if use_typedal == "auto":
        use_typedal = "typedal" in code_before.lower() or "typedal" in code_after.lower()

    # normalize `function_name` (None / str / tuple) into a set of candidate function names
    if function_name:
        define_table_functions: set[str] = set(function_name) if isinstance(function_name, tuple) else {function_name}
    else:
        define_table_functions = set()

    template = TEMPLATE_TYPEDAL if use_typedal else TEMPLATE_PYDAL

    to_execute = string.Template(textwrap.dedent(template))

    # with magic enabled, problems are repaired; without it these helpers raise instead.
    code_before = check_indentation(code_before, fix=magic)
    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)

    code_after = check_indentation(code_after, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

    # code injected at the $extra placeholders; grows while retrying (magic variables, stub tables)
    extra_code = ""

    generated_code = to_execute.substitute(
        {
            "tables": flatten(tables or []),
            "db_type": db_type or "",
            "code_before": textwrap.dedent(code_before),
            "code_after": textwrap.dedent(code_after),
            "extra": extra_code,
        }
    )
    if verbose or noop:
        rich.print(generated_code, file=sys.stderr)

    if noop:
        # done
        return True

    err: typing.Optional[Exception] = None  # last exception seen; reported when giving up
    catch: dict[str, Any] = {}
    retry_counter = MAX_RETRIES

    # names provided through `catch`; they must not be reported/generated as 'missing variables'
    magic_vars = {"_file", "DummyDAL", "_special_tables", "_uniq", "_excl"}
    special_tables: set[str] = {"typedal_cache", "typedal_cache_dependency"} if use_typedal else set()

    cwd = os.getcwd()
    while retry_counter:
        if err and cwd not in sys.path:
            # add current path to PYTHONPATH to possibly resolve imports
            # NOTE(review): the entry is not removed from sys.path afterwards.
            sys.path.append(os.getcwd())

        retry_counter -= 1
        try:
            # NOTE(review): with verbose=True the code was already printed once before the loop.
            if verbose:
                rich.print(generated_code, file=sys.stderr)

            # 'catch' is used to add and receive globals from the exec scope.
            # another argument could be added for locals, but adding simply {} changes the behavior negatively.
            # so for now, only globals is passed.
            catch["_file"] = io.StringIO()  # <- every print should go to this file, so we can handle it afterwards
            catch["DummyDAL"] = (
                DummyTypeDAL if use_typedal else DummyDAL
            )  # <- use a fake DAL that doesn't actually run queries
            catch["_special_tables"] = special_tables  # <- e.g. typedal_cache, auth_user
            # note: when adding something to 'catch', also add it to magic_vars!!!

            catch["_uniq"] = uniq  # function to make a list unique without changing order
            catch["_excl"] = excl  # function to exclude items from a list

            # SECURITY: this executes the supplied model code in-process; only use on trusted input.
            exec(generated_code, catch)  # nosec: B102
            # NOTE(review): no default_migration_name is passed on here, so the helper's default applies.
            _handle_output(catch["_file"], output_file, output_format, is_typedal=use_typedal)
            return True  # success!
        except ValueError as e:
            # the template raises ValueError('no-tables-found') when no tables were defined
            if str(e) != "no-tables-found":  # pragma: no cover
                rich.print(f"[yellow]{e}[/yellow]", file=sys.stderr)
                return False

            if define_table_functions:
                any_found = False
                # NOTE(review): this loop rebinds the `function_name` parameter, so the messages below
                # show the last name iterated instead of the value that was passed in.
                for function_name in define_table_functions:
                    define_tables = find_function_to_call(generated_code, function_name)

                    # if define_tables function is found, add call to it at end of code
                    if define_tables is not None:
                        generated_code = add_function_call(generated_code, function_name, multiple=True)
                        any_found = True

                if any_found:
                    # hurray!
                    continue

            # else: no define_tables or other method to use found.

            print(f"No tables found in the top-level or {function_name} function!", file=sys.stderr)
            if use_typedal:
                print(
                    "Please use `db.define` or `database.define`, "
                    "or if you really need to use an alias like my_db.define, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            else:
                print(
                    "Please use `db.define_table` or `database.define_table`, "
                    "or if you really need to use an alias like my_db.define_tables, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            print(f"You can also specify a --function to use something else than {function_name}.", file=sys.stderr)

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code) - magic_vars
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            # postponed: this can possibly also be achieved by updating the 'catch' dict
            #   instead of injecting in the string.
            extra_code = extra_code + "\n" + textwrap.dedent(generate_magic_code(missing_vars))

            code_before = remove_if_falsey_blocks(code_before)
            code_after = remove_if_falsey_blocks(code_after)

            # rebuild from the template; a function call appended to `generated_code` earlier is not kept.
            generated_code = to_execute.substitute(
                {
                    "tables": flatten(tables or []),
                    "db_type": db_type or "",
                    "extra": textwrap.dedent(extra_code),
                    "code_before": textwrap.dedent(code_before),
                    "code_after": textwrap.dedent(code_after),
                }
            )
        except ImportError as e:
            # should include ModuleNotFoundError
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            if not e.path:
                # code exists in code itself
                code_before = remove_import(code_before, e.name or "")
                code_after = remove_import(code_after, e.name or "")
            else:
                # failure inside an imported package: drop the source line that triggered the import
                to_remove = _handle_import_error(generated_code, e)
                code_before = code_before.replace(to_remove, "\n")
                code_after = code_after.replace(to_remove, "\n")

            generated_code = to_execute.substitute(
                {
                    "tables": flatten(tables or []),
                    "db_type": db_type or "",
                    "extra": textwrap.dedent(extra_code),
                    "code_before": textwrap.dedent(code_before),
                    "code_after": textwrap.dedent(code_after),
                }
            )

        except KeyError as e:
            err = e
            # unresolved table reference: define an empty stub table (re-raises for other KeyErrors)
            table_name, table_definition = _handle_relation_error(e)
            # stub tables are excluded from the automatically collected table list
            special_tables.add(table_name)
            extra_code = extra_code + "\n" + textwrap.dedent(table_definition)

            generated_code = to_execute.substitute(
                {
                    "tables": flatten(tables or []),
                    "db_type": db_type or "",
                    "extra": textwrap.dedent(extra_code),
                    "code_before": textwrap.dedent(code_before),
                    "code_after": textwrap.dedent(code_after),
                }
            )
        except Exception as e:
            err = e
            # otherwise: give up
            retry_counter = 0
        finally:
            # reset:
            # NOTE(review): presumably guards against executed code having flipped this flag - confirm.
            typing.TYPE_CHECKING = False

    if retry_counter < 1:  # pragma: no cover
        rich.print(
            f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'} ({type(err)})", file=sys.stderr
        )
        return False

    # idk when this would happen, but something definitely went wrong here:
    return False  # pragma: no cover

934 

935 

def core_create(
    filename: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Produce CREATE statements for the table definitions found in one source file.

    The source is resolved (relative to the git root, or the current working directory when
    no git root is found), loaded for the requested version and handed to `handle_cli` with an
    empty "before" state.

    Args:
        filename: Source file holding the desired table definitions.
            A `file:function` form (e.g. `models.py:define_tables`) also selects the function to call.
            When empty, the version defaults to "stdin" instead of "current".
        tables: Names of the tables to generate SQL for (forwarded to `handle_cli`).
        db_type: Target database type (forwarded to `handle_cli`).
        magic: Automatically add missing variables for execution.
        noop: Only print the generated code, do not execute it.
        verbose: Print the generated code and additional debug information.
        function: One function name, or a tuple of names, in which the tables are defined.
        output_format: Defaults to plain SQL; edwh-migrate migration syntax is also supported.
        output_file: Append the output to this file instead of printing it.
        _update_path: Try adding cwd to PYTHONPATH if failing? (forwarded to `handle_cli`).

    Returns:
        bool: whatever `handle_cli` reports for the generated code.

    Raises:
        FileNotFoundError: If the source file could not be located.
    """
    # no git repository? then paths are resolved from the current working directory
    git_root = find_git_root() or Path(os.getcwd())

    functions: set[str] = set()
    if function:  # pragma: no cover
        # normalise "single name" and "tuple of names" to one iterable
        functions.update(function if isinstance(function, tuple) else (function,))

    if filename and ":" in filename:
        # e.g. models.py:define_tables -> ("models.py", "define_tables")
        filename, _, inline_function = filename.partition(":")
        functions.add(inline_function)

    # without a filename the code is expected on stdin
    default_version = "current" if filename else "stdin"
    file_version, file_path = extract_file_version_and_path(filename, default_version=default_version)

    file_exists, file_absolute_path = get_absolute_path_info(file_path, file_version, git_root)
    if not file_exists:
        raise FileNotFoundError(f"Source file {filename} could not be found.")

    text = get_file_for_version(file_absolute_path, file_version, prompt_description="table definition")

    # 'create' has no previous state, so the "before" code is empty
    return handle_cli(
        "",
        text,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
        _update_path=_update_path,
    )

1009 

1010 

def core_alter(
    filename_before: Optional[str] = None,
    filename_after: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Generates SQL migration statements for altering the database, based on the code in two given source files.

    Args:
        filename_before: The filename of the source file before changes.
            This code represents the initial state of the database.
            May use the `file:function` form (e.g. `models.py:define_tables`).
        filename_after: The filename of the source file after changes.
            This code represents the final state of the database.
            May use the `file:function` form (e.g. `models.py:define_tables`).
        tables: A list of table names to generate SQL for (forwarded to `handle_cli`).
        db_type: The type of the database (forwarded to `handle_cli`).
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: The name (or a tuple of names) of the function where the tables are defined.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?
        _update_path: try adding cwd to PYTHONPATH if failing? (forwarded to `handle_cli`)

    Returns:
        bool: False if a ValueError occurred while loading or comparing the two versions
            (e.g. one side is empty, or both sides contain the same code); the error is printed to stderr.
            Otherwise, the result of `handle_cli`.

    Raises:
        FileNotFoundError: If either of the source files cannot be found.
    """
    # may be None when neither file lives in a git repository (see `with_git` below)
    git_root = find_git_root(filename_before) or find_git_root(filename_after)

    functions: set[str] = set()
    if function:  # pragma: no cover
        # accept both a single name and a tuple of names, like `core_create` does
        if isinstance(function, tuple):
            functions.update(function)
        else:
            functions.add(function)

    if filename_before and ":" in filename_before:
        # e.g. models.py:define_tables
        filename_before, _function = filename_before.split(":", 1)
        functions.add(_function)

    if filename_after and ":" in filename_after:
        # e.g. models.py:define_tables
        filename_after, _function = filename_after.split(":", 1)
        functions.add(_function)

    before, after = extract_file_versions_and_paths(filename_before, filename_after)

    version_before, filename_before = before
    version_after, filename_after = after

    # either ./file exists or /file exists (seen from git root):
    before_exists, before_absolute_path = get_absolute_path_info(filename_before, version_before, git_root)
    after_exists, after_absolute_path = get_absolute_path_info(filename_after, version_after, git_root)

    if not (before_exists and after_exists):
        message = ""
        message += "" if before_exists else f"Path {filename_before} does not exist! "
        if filename_before != filename_after:
            # don't report the same path twice
            message += "" if after_exists else f"Path {filename_after} does not exist!"
        if not message:
            # same path on both sides and only the 'after' version is missing:
            # never raise with an empty message
            message = f"Path {filename_after} does not exist!"
        raise FileNotFoundError(message)

    try:
        code_before = get_file_for_version(
            before_absolute_path,
            version_before,
            prompt_description="current table definition",
            with_git=git_root is not None,
        )
        code_after = get_file_for_version(
            after_absolute_path,
            version_after,
            prompt_description="desired table definition",
            with_git=git_root is not None,
        )

        if not (code_before and code_after):
            message = ""
            message += "" if code_before else "Before code is empty (Maybe try `pydal2sql create`)! "
            message += "" if code_after else "After code is empty! "
            raise ValueError(message)

        if code_before == code_after:
            raise ValueError("Both contain the same code - nothing to alter!")

    except ValueError as e:
        # nothing useful to generate: warn the user and report failure instead of crashing
        rich.print(f"[yellow] {e} [/yellow]", file=sys.stderr)
        return False

    return handle_cli(
        code_before,
        code_after,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
        # previously accepted but silently dropped; forward it like `core_create` does
        _update_path=_update_path,
    )

1124 

1125 

def core_stub(
    migration_name: str,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    dry_run: bool = False,
    is_typedal: bool = False,
):
    """
    Emit an empty placeholder migration.

    The migration body consists of a single SQL comment containing the migration name;
    it is passed to `_handle_output` together with the requested format and target.
    (mostly useful for usage with output_format=edwh-migrate)

    Args:
        migration_name: name of the migration; also used as the default migration name.
        output_format: output format forwarded to `_handle_output`.
        output_file: target forwarded to `_handle_output`; ignored when `dry_run` is set.
        dry_run: if True, the output target is replaced by None.
        is_typedal: forwarded to `_handle_output`.
    """
    # a dry run drops the requested target: None will print it to stdout
    destination = None if dry_run else output_file
    stub_body = io.StringIO(f"-- {migration_name}\n")

    _handle_output(
        stub_body,
        destination,
        output_format=output_format,
        is_typedal=is_typedal,
        default_migration_name=migration_name,
    )

1148 )