Coverage for src/pydal2sql_core/cli_support.py: 99%

382 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-04-22 13:56 +0200

1""" 

2CLI-Agnostic support. 

3""" 

4 

5import ast 

6import contextlib 

7import io 

8import os 

9import re 

10import select 

11import string 

12import sys 

13import textwrap 

14import traceback 

15import typing 

16from dataclasses import dataclass 

17from datetime import datetime 

18from pathlib import Path 

19from typing import Any, Optional 

20 

21import git 

22import gitdb.exc 

23import rich 

24from black.files import find_project_root 

25from git.objects.blob import Blob 

26from git.objects.commit import Commit 

27from git.repo import Repo 

28from witchery import ( 

29 add_function_call, 

30 find_defined_variables, 

31 find_function_to_call, 

32 find_missing_variables, 

33 generate_magic_code, 

34 has_local_imports, 

35 remove_if_falsey_blocks, 

36 remove_import, 

37 remove_local_imports, 

38 remove_specific_variables, 

39) 

40 

41from .core import generate_sql 

42from .helpers import detect_typedal, excl, flatten, uniq 

43from .state import state 

44from .types import ( 

45 _SUPPORTED_OUTPUT_FORMATS, 

46 DEFAULT_OUTPUT_FORMAT, 

47 SUPPORTED_DATABASE_TYPES_WITH_ALIASES, 

48 SUPPORTED_OUTPUT_FORMATS, 

49 DummyDAL, 

50 DummyTypeDAL, 

51) 

52 

# Matches the traceback frame line of top-level code that was run from a string,
# capturing the line number.
IMPORT_IN_STR = re.compile(r'File "<string>", line (\d+), in <module>')


def has_stdin_data() -> bool:  # pragma: no cover
    """
    Tell whether stdin already has data waiting (pipe | or redirect <).

    Returns:
        bool: True if stdin is reported as readable right now, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    # timeout of 0.0: poll once and return immediately instead of blocking
    readable, _writable, _exceptional = select.select([sys.stdin], [], [], 0.0)
    return any(readable)

76 

77 

AnyCallable = typing.Callable[..., Any]


def print_if_interactive(*args: Any, pretty: bool = True, **kwargs: Any) -> None:  # pragma: no cover
    """
    Print to stderr, but only when no data is waiting on stdin (interactive session).

    Args:
        *args: Values to print.
        pretty (bool): Use rich.print when True, the built-in print otherwise.
        **kwargs: Extra keyword arguments for the print function; `file` is always overridden with sys.stderr.

    Returns:
        None
    """
    if has_stdin_data():
        # input is piped/redirected: stay silent
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    kwargs["file"] = sys.stderr
    printer(*args, **kwargs)

101 

102 

def find_git_root(at: str = None) -> Optional[Path]:
    """
    Locate the root directory of the surrounding Git repository.

    Args:
        at (str, optional): Path to start searching from. Falls back to the current working directory.

    Returns:
        Optional[Path]: The folder reported by `find_project_root` when it was selected because of a
            `.git` directory, otherwise None.
    """
    start = at or os.getcwd()
    folder, reason = find_project_root((start,))
    # any other reason means the folder was not chosen because of a .git directory
    return folder if reason == ".git directory" else None

117 

118 

def find_git_repo(repo: Repo = None, at: str = None) -> Repo:
    """
    Return a Git repository instance, creating one when none was supplied.

    Args:
        repo (Repo, optional): An existing repository instance; returned unchanged when truthy.
        at (str, optional): Path to start searching for the repository root. Defaults to the
            current working directory.

    Returns:
        Repo: The given instance, or a new one opened at the detected Git root.
    """
    if not repo:
        # nothing usable was passed in: open the repository at the detected root
        repo = Repo(str(find_git_root(at)))
    return repo

135 

136 

def latest_commit(repo: Repo = None) -> Commit:
    """
    Return the commit HEAD currently points at.

    Args:
        repo (Repo, optional): An existing repository instance. Looked up via `find_git_repo` when omitted.

    Returns:
        Commit: The head commit of the repository.
    """
    return find_git_repo(repo).head.commit

149 

150 

def commit_by_id(commit_hash: str, repo: Repo = None) -> Commit:
    """
    Look up one commit by hash or name.

    Args:
        commit_hash (str): Hash of the commit; can also be e.g. a branch name.
        repo (Repo, optional): An existing repository instance. Looked up via `find_git_repo` when omitted.

    Returns:
        Commit: The commit object for the given identifier.
    """
    return find_git_repo(repo).commit(commit_hash)

164 

165 

@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Context manager exposing the data of a Git Blob as an in-memory byte stream.

    Args:
        file (Blob): The Git Blob object to open.

    Yields:
        io.BytesIO: A BytesIO holding everything read from the blob's data stream.
    """
    raw = file.data_stream.read()
    buffer = io.BytesIO(raw)
    yield buffer

178 

179 

def read_blob(file: Blob) -> str:
    """
    Return the contents of a Git Blob decoded to text.

    Args:
        file (Blob): The Git Blob object to read.

    Returns:
        str: The blob's bytes decoded with the default codec of `bytes.decode`.
    """
    with open_blob(file) as stream:
        raw = stream.read()
    return raw.decode()

192 

193 

def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Repo = None) -> str:
    """
    Get the contents of a file in the Git repository at a specific commit version.

    Args:
        filename (str): The path of the file to retrieve.
        commit_version (str, optional): The commit hash or branch name. Defaults to "latest" (latest commit).
        repo (Repo, optional): An existing Git repository instance. If provided, uses the given instance.

    Returns:
        str: The contents of the file as a string.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    file_path = Path(filename).resolve()
    try:
        # Compare fully-resolved paths so a symlinked checkout still matches, and always
        # hand the tree lookup a "/"-separated path regardless of the OS separator.
        relative_file_path = file_path.relative_to(Path(repo.working_dir).resolve()).as_posix()
    except (TypeError, ValueError):
        # No working dir or file outside of it: keep the previous string-based behaviour.
        relative_file_path = str(file_path).removeprefix(f"{repo.working_dir}/")

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)

215 

216 

def get_file_for_version(filename: str, version: str, prompt_description: str = "", with_git: bool = True) -> str:
    """
    Return the contents of a file for the requested version.

    Args:
        filename (str): The path of the file to retrieve.
        version (str): "current" (read from disk), "stdin" (read from standard input),
            or a commit hash/branch name (read from Git).
        prompt_description (str, optional): Text shown in the prompt when reading from stdin.
        with_git (bool, optional): When False the file is always read from disk, whatever `version` says.

    Returns:
        str: The contents of the file as a string.

    Raises:
        FileNotFoundError: If reading the file at the given Git version fails.
    """
    if version == "current" or not with_git:
        return Path(filename).read_text()

    if version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        pasted = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return pasted

    # with_git is necessarily True here: anything else is treated as a commit-ish
    try:
        return get_file_for_commit(filename, version)
    except (git.exc.GitError, gitdb.exc.ODBError) as e:
        raise FileNotFoundError(f"{filename}@{version}") from e

245 

246 

247def extract_file_version_and_path( 

248 file_path_or_git_tag: Optional[str], 

249 default_version: str = "stdin", 

250) -> tuple[str, str | None]: 

251 """ 

252 Extract the file version and path from the given input. 

253 

254 Args: 

255 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag. 

256 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin". 

257 

258 Returns: 

259 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified). 

260 

261 Examples: 

262 myfile.py (implies @current) 

263 

264 myfile.py@latest 

265 myfile.py@my-branch 

266 myfile.py@b3f24091a9 

267 

268 @latest (implies no path, e.g. in case of ALTER to copy previously defined path) 

269 """ 

270 if not file_path_or_git_tag: 

271 return default_version, "" 

272 

273 if file_path_or_git_tag == "-": 

274 return "stdin", "-" 

275 

276 if file_path_or_git_tag.startswith("@"): 

277 file_version = file_path_or_git_tag.strip("@") 

278 file_path = None 

279 elif "@" in file_path_or_git_tag: 

280 file_path, file_version = file_path_or_git_tag.split("@") 

281 else: 

282 file_version = default_version # `latest` for before; `current` for after. 

283 file_path = file_path_or_git_tag 

284 

285 return file_version, file_path 

286 

287 

def extract_file_versions_and_paths(
    filename_before: Optional[str],
    filename_after: Optional[str],
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Work out version and path for both the 'before' and the 'after' file argument.

    Args:
        filename_before (str, optional): The path of the file before the change (or None).
        filename_after (str, optional): The path of the file after the change (or None).

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            (version, path) for the before file, followed by (version, path) for the after file.
            A missing path on one side is copied from the other side.

    Raises:
        ValueError: If neither side yields a file path.
    """
    # two different, explicitly given files are compared as they are on disk;
    # otherwise 'before' defaults to the latest commit.
    if filename_after and filename_before and filename_after != filename_before:
        default_before = "current"
    else:
        default_before = "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=default_before)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    if not filepath_after:
        filepath_after = filepath_before
    elif not filepath_before:
        filepath_before = filepath_after

    return (version_before, filepath_before), (version_after, filepath_after)

319 

320 

def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Report whether a file exists and, if so, what its absolute path is.

    Args:
        filename (str, optional): The path of the file to check (or None).
        version (str): The version specifier; "stdin" short-circuits to (True, "").
        git_root (Path, optional): Root directory used as a second lookup location. When None it is
            determined via `find_git_root`, falling back to the current working directory.

    Returns:
        tuple[bool, str]: (exists, absolute path). The path is "" for stdin and for missing files.
    """
    if version == "stdin":
        return True, ""
    if filename is None:
        # can't deal with this, not stdin and no file should show file missing error later.
        return False, ""

    root = git_root if git_root is not None else (find_git_root() or Path(os.getcwd()))

    # first as given, then relative to the git root
    for candidate in (Path(filename), root / filename):
        if candidate.exists():
            return True, str(candidate.resolve())

    return False, ""

356 

357 

def check_indentation(code: str, fix: bool = False) -> str:
    """
    Validate the indentation of a piece of code by parsing it.

    The code is handed to `ast.parse`. When that succeeds (or when there is no code at all)
    the input is returned unchanged. When parsing fails with an `IndentationError`, the code is
    either run through `textwrap.dedent` (`fix=True`) or the error is propagated (`fix=False`).

    Args:
        code (str): The code to check.
        fix (bool): Whether to dedent the code when its indentation is wrong. Defaults to False.

    Returns:
        str: The original code, or the dedented code if `fix` is True and parsing hit an IndentationError.

    Raises:
        IndentationError: If the code is incorrectly indented and `fix` is False.
    """
    if not code:
        # no code? perfect indentation!
        return code

    try:
        ast.parse(code)
    except IndentationError:
        if not fix:
            raise
        return textwrap.dedent(code)

    return code

390 

391 

def ensure_no_migrate_on_real_db(
    code: str,
    db_names: typing.Iterable[str] = ("db", "database"),
    fix: bool = False,
) -> str:
    """
    Ensure that the code does not contain actual migrations on a real database.

    It does this by removing definitions of 'db' and database. This can be changed by customizing `db_names`.
    It also removes local imports to prevent irrelevant code being executed.

    Args:
        code (str): The code to check for database migrations.
        db_names (Iterable[str], optional): Names of variables representing the database.
            Defaults to ("db", "database").
        fix (bool, optional): If True, removes the offending definitions/imports. Defaults to False.

    Returns:
        str: The modified code if fix=True, otherwise the original code.

    Raises:
        ValueError: If fix is False and the code defines one of `db_names` or uses local imports.
    """
    # Materialise once: db_names may be a one-shot iterable, and it is needed both for the
    # membership check and for the removal call.
    names = tuple(db_names)
    variables = find_defined_variables(code)

    # ordered and de-duplicated, so the error message is deterministic
    found_variables = list(dict.fromkeys(name for name in names if name in variables))

    if found_variables:
        if fix:
            # a single call removes every database name at once
            code = remove_specific_variables(code, names)
        else:
            if len(found_variables) == 1:
                message = f"Variable {found_variables[0]} defined in code! "
            else:  # pragma: no cover
                message = f"Variables {', '.join(found_variables)} defined in code! "
            raise ValueError(
                f"{message} Please remove this or use --magic to prevent performing actual migrations on your database.",
            )

    if has_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code

441 

442 

# Initial value of the retry counter in render_schema_from_code's execute-and-repair loop.
MAX_RETRIES = 30


# todo: overload more methods


@dataclass
class RenderContext:
    """
    Context passed to migration renderers.

    render_schema_from_code builds one instance after executing the generated code
    and hands it to the renderer callback.
    """

    # database object holding the table definitions from `code_before`
    db_old: DummyDAL
    # database object holding the table definitions from `code_after`
    db_new: DummyDAL
    # names of the tables to render
    tables: list[str]
    # database type hint (may be overridden by the executed code)
    db_type: Optional[str]
    # whether the TypeDAL execution template was used
    use_typedal: bool
    # True when there was no (non-blank) `code_before`
    is_create: bool
    # True when there was (non-blank) `code_before`
    is_alter: bool


# A renderer receives a RenderContext and returns the final output text.
Renderer = typing.Callable[[RenderContext], str]

465 

# Source text executed for plain pydal definitions. The `$tables`, `$extra`, `$code_before`
# and `$code_after` placeholders are filled in via `string.Template` by render_schema_from_code;
# `DummyDAL`, `_uniq`, `_excl` and `_special_tables` are supplied through the exec globals.
# After execution the globals hold `db_old`, `db_new` and `_tables`.
TEMPLATE_EXEC_PYDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *

db = database = DummyDAL(None, migrate=False)

tables = $tables

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')
_tables = tables
    """

494 

# Source text executed for TypeDAL definitions: same layout as the pydal template, plus
# `from typedal import *`. The `$tables`, `$extra`, `$code_before` and `$code_after`
# placeholders are filled in via `string.Template` by render_schema_from_code.
TEMPLATE_EXEC_TYPEDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *
from typedal import *

db = database = DummyDAL(None, migrate=False)

tables = $tables

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')
_tables = tables
    """

524 

525 

def default_sql_renderer(context: RenderContext) -> str:
    """
    Default SQL renderer used by handle_cli.

    For every table in the context this emits a start marker, then either the SQL to go from the
    old to the new definition (table on both sides), a DROP TABLE statement (table only on the old
    side) or the SQL for the new definition, followed by an end-of-migration marker.
    """
    old, new = context.db_old, context.db_new
    pieces: list[str] = []

    for table in context.tables:
        pieces.append(f"-- start  {table} --\n")

        if table in old:
            if table in new:
                statement = str(generate_sql(old[table], new[table], db_type=context.db_type))
            else:
                statement = f"DROP TABLE {table};"
        else:
            statement = str(generate_sql(new[table], db_type=context.db_type))

        pieces.append(statement + "\n")
        pieces.append("-- END OF MIGRATION --\n")

    return "".join(pieces)

544 

545 

def sql_to_function_name(sql_statement: str, default: Optional[str] = None) -> str:
    """
    Extract action (CREATE, ALTER, DROP) and table name from the SQL statement.

    The first `<action> TABLE <name>` occurrence is used. An optional `IF [NOT] EXISTS` clause is
    skipped and the name may be wrapped in single quotes, double quotes or backticks.

    Args:
        sql_statement (str): SQL text to inspect.
        default (str, optional): Name to return when no such statement is found.

    Returns:
        str: `<action>_<table>` in lower case, or `default` (falling back to "unknown_migration").
    """
    found = re.search(
        # without the optional IF [NOT] EXISTS group, the word 'if' would be taken as the table name
        r"(create|alter|drop)\s+table\s+(?:if\s+(?:not\s+)?exists\s+)?[`'\"]?(\w+)",
        sql_statement.lower(),
    )

    if found is None:
        return default or "unknown_migration"

    action, table_name = found.groups()

    # Generate a function name with the specified format
    return f"{action}_{table_name}"

560 

561 

def _setup_generic_edwh_migrate(file: Path, is_typedal: bool) -> None:
    """
    Write the import header of a new edwh-migrate file (overwriting `file`) and report it.

    Args:
        file: The migration file to (re)create.
        is_typedal: Import `TypeDAL` from typedal when True, `DAL` from pydal otherwise.
    """
    dal_import = "from typedal import TypeDAL" if is_typedal else "from pydal import DAL"
    header = "from edwh_migrate import migration\n" + dal_import + "\n"

    with file.open("w") as handle:
        handle.write(textwrap.dedent(header))

    rich.print(f"[green] New migrate file {file} created [/green]")

573 

574 

# Matches the "-- start  <table> --" marker line that default_sql_renderer puts before each table's SQL.
START_RE = re.compile(r"-- start\s+\w+\s--\n")


def _build_edwh_migration(
    contents: str,
    cls: str,
    date: str,
    existing: Optional[str] = None,
    default_migration_name: Optional[str] = None,
) -> str:
    """
    Wrap one SQL migration in an `@migration`-decorated Python function definition.

    Args:
        contents: SQL of a single migration (may still contain the start marker).
        cls: Class name used as type annotation of the generated function's `db` parameter.
        date: Date string that becomes part of the generated function name.
        existing: Current text of the output file, used to detect already-present migrations.
        default_migration_name: Function name prefix used when no action/table can be derived from the SQL.

    Returns:
        str: Source code of the generated function, or "" when the migration is skipped
            (already exists, name clash with different contents, or nothing to execute).
    """
    sql_func_name = sql_to_function_name(contents, default=default_migration_name)
    func_name = "_placeholder_"
    # drop the start marker(s); only the SQL itself goes into the function
    contents = START_RE.sub("", contents)

    # try suffixes 001..999 until a usable function name is found
    for n in range(1, 1000):
        func_name = f"{sql_func_name}_{date}_{str(n).zfill(3)}"

        if existing and f"def {func_name}" in existing:
            # compare with spaces and newlines removed, so formatting differences don't matter
            if contents.replace(" ", "").replace("\n", "") in existing.replace(" ", "").replace("\n", ""):
                rich.print(f"[yellow] migration {func_name} already exists, skipping! [/yellow]")
                return ""
            elif func_name.startswith(("alter", default_migration_name or "unknown")):
                # bump number because alter migrations are different
                continue
            else:
                rich.print(
                    f"[red] migration {func_name} already exists "
                    f"[bold]with different contents[/bold], skipping! [/red]",
                )
                return ""
        else:
            # okay function name, stop incrementing
            break

    # 16 spaces here minus the 8 removed by the dedent below leaves the SQL at 8 spaces
    contents = textwrap.indent(contents.strip(), " " * 16)

    if not contents.strip():
        # no real migration!
        return ""

    return textwrap.dedent(
        f'''

        @migration
        def {func_name}(db: {cls}):
            db.executesql("""
{contents}
            """)
            db.commit()

            return True
        ''',
    )

628 

629 

def _build_edwh_migrations(
    contents: str,
    is_typedal: bool,
    output: Optional[Path] = None,
    default_migration_name: Optional[str] = None,
) -> str:
    """
    Turn rendered SQL into the source of edwh-migrate functions, one per migration.

    Args:
        contents: SQL of all migrations, separated by "-- END OF MIGRATION --" markers.
        is_typedal: Annotate the generated functions with `TypeDAL` (True) or `DAL` (False).
        output: Existing output file, read so that known migrations can be detected.
        default_migration_name: Fallback name passed on for migrations without a recognisable action/table.

    Returns:
        str: The concatenated source of all generated migration functions.
    """
    dal_class = "TypeDAL" if is_typedal else "DAL"
    today = datetime.now().strftime("%Y%m%d")  # yyyymmdd

    existing = output.read_text() if output and output.exists() else None

    functions: list[str] = []
    for chunk in contents.split("-- END OF MIGRATION --"):
        if not chunk.strip():
            # blank piece between/after markers
            continue
        functions.append(
            _build_edwh_migration(chunk, dal_class, today, existing, default_migration_name=default_migration_name),
        )

    return "".join(functions)

646 

647 

def _format_and_write_sql_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
    default_migration_name: Optional[str] = None,
) -> bool:
    """
    Format generated migration SQL and write it to its destination.

    Args:
        file (io.StringIO): File-like object holding the rendered migration content.
        output_file (Union[Path, str, io.StringIO, None]): Where to write. None or the string '-' means stdout.
        output_format (edwh-migrate, default): The format in which to output the migration.
            Defaults to 'default'.
        is_typedal (bool): Format for TypeDAL instead of pydal. Only relevant for the edwh-migrate format.
        default_migration_name (Optional[str]): Fallback migration name used in the output. Defaults to None.

    Returns:
        bool: The result of `_write_output`.

    Raises:
        ValueError: If the provided output_format is unknown.
    """
    file.seek(0)
    raw = file.read()

    if isinstance(output_file, str):
        # `--output-file -` will print to stdout
        output_file = None if output_file == "-" else Path(output_file)

    as_edwh = output_format == "edwh-migrate"
    target_path = output_file if isinstance(output_file, Path) else None

    if as_edwh:
        rendered = _build_edwh_migrations(
            raw,
            is_typedal,
            target_path,
            default_migration_name=default_migration_name,
        )
    elif output_format in {"default", "sql"} or not output_format:
        rendered = "\n".join(raw.split("-- END OF MIGRATION --"))
    else:
        raise ValueError(
            f"Unknown format {output_format}. Please choose one of {typing.get_args(_SUPPORTED_OUTPUT_FORMATS)}",
        )

    if target_path is not None and as_edwh:
        # a missing or empty migration file first gets its import header
        if not target_path.exists() or target_path.stat().st_size == 0:
            _setup_generic_edwh_migrate(target_path, is_typedal)

    return _write_output(rendered, output_file, output_description="migration(s)")

703 

704 

def try_format_and_write_sql_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
    default_migration_name: Optional[str] = None,
) -> bool:
    """
    Non-raising variant of `_format_and_write_sql_output`.

    A ValueError (e.g. an unknown output format) is printed to stderr in yellow and turned
    into a False return value; otherwise the result of the wrapped call is returned.
    """
    try:
        succeeded = _format_and_write_sql_output(
            file,
            output_file,
            output_format=output_format,
            is_typedal=is_typedal,
            default_migration_name=default_migration_name,
        )
    except ValueError as problem:
        rich.print(f"[yellow]{problem}[/yellow]", file=sys.stderr)
        return False

    return succeeded

726 

727 

728def _write_output( 

729 contents: str, 

730 output_file: Path | str | io.StringIO | None, 

731 output_description: str = "output", 

732) -> bool: 

733 """ 

734 Write already-rendered output to destination. 

735 """ 

736 if isinstance(output_file, str): 

737 output_file = None if output_file == "-" else Path(output_file) 

738 

739 if isinstance(output_file, Path): 

740 if contents.strip(): 

741 with output_file.open("a") as f: 

742 f.write(contents) 

743 

744 rich.print(f"[green] Written {output_description} to {output_file} [/green]") 

745 else: 

746 rich.print(f"[yellow] Nothing to write to {output_file} [/yellow]") 

747 elif isinstance(output_file, io.StringIO): 

748 output_file.write(contents) 

749 else: 

750 print(contents.strip()) 

751 

752 return True 

753 

754 

def _handle_import_error(code: str, error: ImportError) -> str:
    """
    Find the line of `code` that triggered an import error raised deeper in a package.

    The traceback of the exception currently being handled is scanned for the first
    top-level `<string>` frame; the line of `code` with that line number is returned.

    Raises:
        ValueError: If no such frame is found in the traceback.
    """
    for tb_line in traceback.format_exc().splitlines():
        hits = IMPORT_IN_STR.findall(tb_line)
        if not hits:
            continue

        # 'File "<string>", line 15, in <module>' -> traceback line numbers are 1-based
        index = int(hits[0]) - 1
        return code.split("\n")[index]

    # I don't know how to trigger this case:
    raise ValueError("Faulty import could not be automatically deleted") from error  # pragma: no cover

768 

769 

770MISSING_RELATIONSHIP = re.compile(r"Cannot resolve reference (\w+) in \w+ definition") 

771 

772 

773def _handle_relation_error(error: KeyError) -> tuple[str, str]: 

774 if not (table := MISSING_RELATIONSHIP.findall(str(error))): 

775 # other error, raise again 

776 raise error 

777 

778 t = table[0] 

779 

780 return ( 

781 t, 

782 """ 

783 db.define_table('%s', redefine=True) 

784 """ 

785 % t, 

786 ) 

787 

788 

def render_schema_from_code(
    code_after: str,
    output_file: Optional[str | Path | io.StringIO],
    renderer: Renderer,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: bool = False,
    noop: bool = False,
    magic: bool = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool = False,
    code_before: str = "",
    _update_path: bool = True,
) -> bool:
    """
    Execute migration code and render output with a custom renderer callback.

    The table-definition code is substituted into an execution template and run with `exec`
    against dummy database objects. When execution fails with a NameError, ImportError or
    KeyError, the code is patched and executed again, for at most MAX_RETRIES attempts.

    Args:
        code_before: Source code representing the initial table definitions.
        code_after: Source code representing the desired table definitions.
        renderer: Callback receiving RenderContext and returning final output text.
        db_type: Optional database type hint.
        tables: Explicit table selection.
        verbose: Print generated execution code to stderr.
        noop: Only print generated execution code, skip execution.
        magic: Automatically inject missing names/import fallbacks.
        function_name: Optional function(s) to call when they are defined in the generated code.
        use_typedal: Use the TypeDAL execution template (`True`) or the pydal one (`False`).
        output_file: Output destination (path, StringIO, stdout).
        _update_path: Reserved for compatibility; not read by this function.

    Returns:
        bool: True on success, False on failure.
    """
    # normalise function_name (None, one name or a tuple of names) to a set of names
    if function_name:
        define_table_functions: set[str] = set(function_name) if isinstance(function_name, tuple) else {function_name}
    else:
        define_table_functions = set()

    template = TEMPLATE_EXEC_TYPEDAL if use_typedal else TEMPLATE_EXEC_PYDAL

    to_execute = string.Template(textwrap.dedent(template))

    # with magic=False these two checks raise instead of repairing the code
    code_before = check_indentation(code_before, fix=magic)
    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)

    code_after = check_indentation(code_after, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

    # code injected at both `$extra` placeholders; grows while errors are being repaired
    extra_code = ""

    def _render_exec_code(_code_before: str, _code_after: str, _extra: str) -> str:
        # fill the template placeholders with the (dedented) pieces of code
        generated = to_execute.substitute(
            {
                "tables": flatten(tables or []),
                "code_before": textwrap.dedent(_code_before),
                "code_after": textwrap.dedent(_code_after),
                "extra": textwrap.dedent(_extra),
            },
        )

        # Function-defined tables (e.g. define_tables/define_td_tables) should be executed
        # regardless of whether an explicit --tables filter is provided.
        for _function_name in define_table_functions:
            if find_function_to_call(generated, _function_name) is not None:
                generated = add_function_call(generated, _function_name, multiple=True)

        return generated

    generated_code = _render_exec_code(code_before, code_after, extra_code)
    if verbose or noop:
        rich.print(generated_code, file=sys.stderr)

    if noop:
        # done
        return True

    err: typing.Optional[Exception] = None
    catch: dict[str, Any] = {}
    retry_counter = MAX_RETRIES

    # names provided through `catch`; never reported or generated as missing variables
    magic_vars = {"DummyDAL", "_special_tables", "_uniq", "_excl"}
    special_tables: set[str] = {"typedal_cache", "typedal_cache_dependency"} if use_typedal else set()

    cwd = os.getcwd()
    while retry_counter:
        if err and cwd not in sys.path:
            # add current path to PYTHONPATH to possibly resolve imports
            sys.path.append(os.getcwd())

        retry_counter -= 1
        try:
            if verbose:
                rich.print(generated_code, file=sys.stderr)

            # 'catch' is used to add and receive globals from the exec scope.
            # another argument could be added for locals, but adding simply {} changes the behavior negatively.
            # so for now, only globals is passed.
            catch["DummyDAL"] = (
                DummyTypeDAL if use_typedal else DummyDAL
            )  # <- use a fake DAL that doesn't actually run queries
            catch["_special_tables"] = special_tables  # <- e.g. typedal_cache, auth_user
            catch["db_type"] = db_type or ""  # allow source code to override db_type during exec
            # note: when adding something to 'catch', also add it to magic_vars!!!

            catch["_uniq"] = uniq  # function to make a list unique without changing order
            catch["_excl"] = excl  # function to exclude items from a list

            exec(generated_code, catch)  # nosec: B102

            # the template leaves db_old, db_new and _tables behind in the exec globals
            context = RenderContext(
                db_old=typing.cast(DummyDAL, catch["db_old"]),
                db_new=typing.cast(DummyDAL, catch["db_new"]),
                tables=list(catch["_tables"]),
                db_type=catch.get("db_type", db_type),
                use_typedal=use_typedal,
                is_create=not bool(code_before.strip()),
                is_alter=bool(code_before.strip()),
            )
            rendered = renderer(context)
            _write_output(rendered, output_file)
            return True  # success!
        except ValueError as e:
            # 'no-tables-found' is the marker raised by the execution template itself
            if str(e) != "no-tables-found":  # pragma: no cover
                rich.print(f"[yellow]{e}[/yellow]", file=sys.stderr)
                return False

            print(f"No tables found in the top-level or {function_name} function!", file=sys.stderr)
            if use_typedal:
                print(
                    "Please use `db.define` or `database.define`, "
                    "or if you really need to use an alias like my_db.define, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            else:
                print(
                    "Please use `db.define_table` or `database.define_table`, "
                    "or if you really need to use an alias like my_db.define_tables, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            print(f"You can also specify a --function to use something else than {function_name}.", file=sys.stderr)

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code) - magic_vars
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            # postponed: this can possibly also be achieved by updating the 'catch' dict
            #   instead of injecting in the string.
            extra_code = extra_code + "\n" + textwrap.dedent(generate_magic_code(missing_vars))

            code_before = remove_if_falsey_blocks(code_before)
            code_after = remove_if_falsey_blocks(code_after)

            # regenerate and let the loop try again
            generated_code = _render_exec_code(code_before, code_after, extra_code)
        except ImportError as e:
            # should include ModuleNotFoundError
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            if not e.path:
                # code exists in code itself
                code_before = remove_import(code_before, e.name or "")
                code_after = remove_import(code_after, e.name or "")
            else:
                # the error was raised deeper in a package: blank out the source line that triggered it
                to_remove = _handle_import_error(generated_code, e)
                code_before = code_before.replace(to_remove, "\n")
                code_after = code_after.replace(to_remove, "\n")

            generated_code = _render_exec_code(code_before, code_after, extra_code)

        except KeyError as e:
            err = e
            # unresolved table reference: inject a stand-in definition (other KeyErrors are re-raised)
            table_name, table_definition = _handle_relation_error(e)
            # added to the special tables, which the template excludes from the automatic table list
            special_tables.add(table_name)
            extra_code = extra_code + "\n" + textwrap.dedent(table_definition)

            generated_code = _render_exec_code(code_before, code_after, extra_code)
        except Exception as e:
            err = e
            # otherwise: give up
            retry_counter = 0
        finally:
            # reset:
            # NOTE(review): presumably undoes a TYPE_CHECKING override made while executing the
            # generated code - confirm where it gets set.
            typing.TYPE_CHECKING = False

        if retry_counter < 1:  # pragma: no cover
            rich.print(
                f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'} ({type(err)})",
                file=sys.stderr,
            )
            if state.verbosity > 2:
                traceback.print_tb(err.__traceback__)
            return False

    # idk when this would happen, but something definitely went wrong here:
    return False  # pragma: no cover

995 

996 

def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: bool = False,
    noop: bool = False,
    magic: bool = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool | typing.Literal["auto"] = "auto",
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    _update_path: bool = True,
) -> bool:
    """
    SQL-specific wrapper around render_schema_from_code.

    The schema is rendered with `default_sql_renderer` into an in-memory buffer first;
    only when that succeeds (and `noop` is off) is the buffer handed to
    `try_format_and_write_sql_output` together with the real `output_file`.

    Args:
        code_before: Source code describing the old state (empty for a plain 'create').
        code_after: Source code describing the desired state.
        db_type: Database type, forwarded to render_schema_from_code.
        tables: Table names, forwarded to render_schema_from_code.
        verbose: Forwarded to render_schema_from_code.
        noop: Forwarded to render_schema_from_code; when set, nothing is formatted or written here.
        magic: Forwarded to render_schema_from_code.
        function_name: Function name(s), forwarded to render_schema_from_code.
        use_typedal: True/False to force the flavour, or "auto" to run `detect_typedal` on both code versions.
        output_format: Output format passed to the SQL formatter.
        output_file: Final destination passed to the SQL formatter.
        _update_path: Forwarded to render_schema_from_code.

    Returns:
        False when rendering failed, True on a successful noop run,
        otherwise the result of `try_format_and_write_sql_output`.
    """
    if use_typedal == "auto":
        # sniff both versions; a hit in either one is enough
        is_typedal = detect_typedal(code_before) or detect_typedal(code_after)
    else:
        is_typedal = use_typedal

    sql_buffer = io.StringIO()
    rendered_ok = render_schema_from_code(
        code_after,
        code_before=code_before,
        output_file=sql_buffer,
        renderer=default_sql_renderer,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=function_name,
        use_typedal=is_typedal,
        _update_path=_update_path,
    )

    if not rendered_ok:
        return False

    if noop:
        # nothing to format or write in noop mode
        return True

    return try_format_and_write_sql_output(
        sql_buffer,
        output_file,
        output_format=output_format,
        is_typedal=is_typedal,
    )

1042 

1043 

def find_file_contents(
    filename: Optional[str] = None,
    function: Optional[str | tuple[str, ...]] = None,
    prompt_description="table definition",
    found_functions: list[str] | None = None,
    default_version: str = "stdin",
    file_version: Optional[str] = None,
    git_root: Optional[Path] = None,
    with_git: bool = True,
) -> str:
    """
    Turn a source file spec into the source code it points at.

    Accepted specs:
        - `<path>`
        - `<path>:<function>`
        - `<path>@<git-ref>`
        - `-` / stdin

    Args:
        filename: File path (optionally including `:<function>` or `@<git-ref>`).
        function: Optional function name(s); combined with a `:<function>` suffix of `filename`, if any.
        prompt_description: Description shown when reading from stdin.
        found_functions: Optional mutable list; every collected function name is appended to it.
        default_version: Version to assume when `filename` carries no `@<git-ref>`.
        file_version: Explicit file version. When given, `filename` is not parsed for a version.
        git_root: Optional git root for path resolution. When omitted it is looked up via
            `find_git_root`, falling back to the current working directory.
        with_git: Whether `@<git-ref>` lookups should use git history.

    Returns:
        The loaded source code.

    Raises:
        FileNotFoundError: If the resolved source file does not exist.
    """
    if git_root is None:
        git_root = find_git_root(filename) or find_git_root() or Path(os.getcwd())

    wanted: set[str] = set()
    if isinstance(function, tuple):  # pragma: no cover
        wanted.update(function)
    elif function:
        wanted.add(function)

    if filename and ":" in filename:
        # e.g. models.py:define_tables -> only the first colon separates path and function
        filename, _, suffix_function = filename.partition(":")
        wanted.add(suffix_function)

    if file_version is not None:
        # caller already decided on a version: take the path as-is
        file_path = filename
    else:
        file_version, file_path = extract_file_version_and_path(
            filename,
            default_version=default_version,
        )

    file_exists, file_absolute_path = get_absolute_path_info(file_path, file_version, git_root)
    if not file_exists:
        raise FileNotFoundError(f"Source file {filename} could not be found.")

    if found_functions is not None:
        # report the collected names back through the caller's list
        found_functions.extend(wanted)

    return get_file_for_version(
        file_absolute_path,
        file_version,
        prompt_description=prompt_description,
        with_git=with_git,
    )

1115 

1116 

def core_create(
    filename: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Generate SQL for creating one or more tables from the code in a single source file.

    The source is loaded through `find_file_contents` and passed to `handle_cli`
    as the 'after' code, with an empty 'before' code.

    Args:
        filename: Source file spec describing the final state of the database.
            Without a filename, the source is read with version "stdin".
        tables: A list of table names to generate SQL for.
        db_type: The type of the database.
        magic: If True, automatically add missing variables for execution.
        noop: Passed on to `handle_cli`; in noop mode no SQL output is formatted or written.
        verbose: If True, print the generated code and additional debug information.
        function: Name(s) of the function(s) where the tables are defined; combined with
            a `:<function>` suffix in `filename`, if present.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?
        _update_path: try adding cwd to PYTHONPATH if failing?

    Returns:
        bool: The result of `handle_cli` (True on success, False otherwise).

    Raises:
        FileNotFoundError: If the source file cannot be found (raised by `find_file_contents`).
    """
    # a concrete filename means "the file as it is now"; no filename means piped input
    source_version = "current" if filename else "stdin"

    # filled by find_file_contents with every function name it came across
    collected_functions = []
    code_after = find_file_contents(
        filename,
        function,
        found_functions=collected_functions,
        default_version=source_version,
    )

    return handle_cli(
        "",  # nothing existed before a 'create'
        code_after,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(collected_functions),
        output_format=output_format,
        output_file=output_file,
        _update_path=_update_path,
    )

1174 

1175 

def core_alter(
    filename_before: Optional[str] = None,
    filename_after: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Generates SQL migration statements for altering the database, based on the code in two given source files.

    Args:
        filename_before: The filename of the source file before changes.
            This code represents the initial state of the database.
        filename_after: The filename of the source file after changes.
            This code represents the final state of the database.
        tables: A list of table names to generate SQL for.
        db_type: The type of the database.
        magic: If True, automatically add missing variables for execution.
        noop: Passed on to `handle_cli`; in noop mode no SQL output is formatted or written.
        verbose: If True, print the generated code and additional debug information.
        function: Name(s) of the function(s) where the tables are defined; combined with
            any `:<function>` suffix found in either filename.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?
        _update_path: try adding cwd to PYTHONPATH if failing?

    Returns:
        bool: The result of `handle_cli`, or False (after printing a warning to stderr) when
            either version of the code is empty or both versions are identical.

    Raises:
        FileNotFoundError: If either of the source files cannot be found.
    """
    git_root = find_git_root(filename_before) or find_git_root(filename_after)

    functions: set[str] = set()
    if function:  # pragma: no cover
        # accept a single name as well as a tuple of names, like core_create does
        if isinstance(function, tuple):
            functions.update(function)
        else:
            functions.add(function)

    if filename_before and ":" in filename_before:
        # e.g. models.py:define_tables
        filename_before, _function = filename_before.split(":", 1)
        functions.add(_function)

    if filename_after and ":" in filename_after:
        # e.g. models.py:define_tables
        filename_after, _function = filename_after.split(":", 1)
        functions.add(_function)

    before, after = extract_file_versions_and_paths(filename_before, filename_after)

    version_before, filename_before = before
    version_after, filename_after = after

    # either ./file exists or /file exists (seen from git root):
    before_exists, _ = get_absolute_path_info(filename_before, version_before, git_root)
    after_exists, _ = get_absolute_path_info(filename_after, version_after, git_root)

    if not (before_exists and after_exists):
        message = "" if before_exists else f"Path {filename_before} does not exist! "
        # Report the 'after' path too, unless that would merely repeat the line above
        # (same path, both missing). Previously a missing 'after' version of the same
        # path produced an empty error message.
        if not after_exists and (before_exists or filename_before != filename_after):
            message += f"Path {filename_after} does not exist!"
        raise FileNotFoundError(message)

    try:
        code_before = find_file_contents(
            filename_before,
            prompt_description="current table definition",
            file_version=version_before,
            git_root=git_root,
            with_git=git_root is not None,
        )
        code_after = find_file_contents(
            filename_after,
            prompt_description="desired table definition",
            file_version=version_after,
            git_root=git_root,
            with_git=git_root is not None,
        )

        if not (code_before and code_after):
            message = ""
            message += "" if code_before else "Before code is empty (Maybe try `pydal2sql create`)! "
            message += "" if code_after else "After code is empty! "
            raise ValueError(message)

        if code_before == code_after:
            raise ValueError("Both contain the same code - nothing to alter!")

    except ValueError as e:
        # user-facing problem, not a crash: warn and report failure
        rich.print(f"[yellow] {e} [/yellow]", file=sys.stderr)
        return False

    return handle_cli(
        code_before,
        code_after,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
        # was silently dropped before, so callers could not disable path updating
        _update_path=_update_path,
    )

1291 

1292 

def core_stub(
    migration_name: str,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    dry_run: bool = False,
    is_typedal: bool = False,
) -> bool:
    """
    Generate a dummy (placeholder) migration.

    The migration body is a single SQL comment holding the migration name; it is passed
    to `_format_and_write_sql_output` with `migration_name` as the default migration name.

    Args:
        migration_name (str): The name of the migration.
        output_format: The format in which to output the migration.
            Defaults to DEFAULT_OUTPUT_FORMAT.
        output_file: The file to which the migration should be output.
            If None, the migration will be printed to stdout. Defaults to None.
        dry_run (bool): If True, `output_file` is ignored and the migration is only printed to stdout.
            Defaults to False.
        is_typedal (bool): If True, the migration will be generated for Typedal.
            Defaults to False, which will generate the migration for pydal.
            Only relevant if output_format=edwh-migrate

    Returns:
        bool: The result of `_format_and_write_sql_output`.

    Example:
        >>> core_stub("some_change", output_format="sql", output_file="migration.sql", dry_run=False, is_typedal=False)
        True

    This function is mostly useful for usage with output_format='edwh-migrate'.
    """
    # a dry run never touches the requested file: a `None` target means stdout
    destination = None if dry_run else output_file
    placeholder_sql = io.StringIO(f"-- {migration_name}\n")

    return _format_and_write_sql_output(
        placeholder_sql,
        destination,
        output_format=output_format,
        is_typedal=is_typedal,
        default_migration_name=migration_name,
    )