Coverage for src/pydal2sql_core/cli_support.py: 100%

319 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-04 18:34 +0100

1""" 

2CLI-Agnostic support. 

3""" 

4import contextlib 

5import io 

6import os 

7import re 

8import select 

9import string 

10import sys 

11import textwrap 

12import traceback 

13import typing 

14from datetime import datetime 

15from pathlib import Path 

16from typing import Any, Optional 

17 

18import git 

19import gitdb.exc 

20import rich 

21from black.files import find_project_root 

22from git.objects.blob import Blob 

23from git.objects.commit import Commit 

24from git.repo import Repo 

25from witchery import ( 

26 add_function_call, 

27 find_defined_variables, 

28 find_function_to_call, 

29 find_missing_variables, 

30 generate_magic_code, 

31 has_local_imports, 

32 remove_if_falsey_blocks, 

33 remove_import, 

34 remove_local_imports, 

35 remove_specific_variables, 

36) 

37 

38from .helpers import excl, flatten, uniq 

39from .types import ( 

40 _SUPPORTED_OUTPUT_FORMATS, 

41 DEFAULT_OUTPUT_FORMAT, 

42 SUPPORTED_DATABASE_TYPES_WITH_ALIASES, 

43 SUPPORTED_OUTPUT_FORMATS, 

44 DummyDAL, 

45 DummyTypeDAL, 

46) 

47 

48 

def has_stdin_data() -> bool:  # pragma: no cover
    """
    Tell whether input is already waiting on stdin (e.g. via a pipe `|` or a redirect `<`).

    `sys.stdin` is polled with `select.select` and a timeout of 0.0, so this call does not wait for input.

    Returns:
        bool: True if stdin is reported as readable right now, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    readable, _writable, _exceptional = select.select([sys.stdin], [], [], 0.0)
    return any(readable)

69 

70 

AnyCallable = typing.Callable[..., Any]


def print_if_interactive(*args: Any, pretty: bool = True, **kwargs: Any) -> None:  # pragma: no cover
    """
    Write a message to stderr, but only when no data is waiting on stdin.

    Args:
        *args: The values to print.
        pretty (bool): Print via `rich.print` when True, via the built-in `print` otherwise.
        **kwargs: Extra keyword arguments forwarded to the chosen print function.
            A `file` entry is always replaced by `sys.stderr`.

    Returns:
        None
    """
    if has_stdin_data():
        # input is piped/redirected in, so this is not an interactive session: stay quiet.
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    kwargs["file"] = sys.stderr
    printer(*args, **kwargs)

94 

95 

def find_git_root(at: str = None) -> Optional[Path]:
    """
    Locate the top-level folder of the Git repository containing `at`.

    Args:
        at (str, optional): Where to start searching. Falls back to the current working directory.

    Returns:
        Optional[Path]: The folder reported by `find_project_root` when it was identified
            through its `.git` directory; None in every other case.
    """
    start = at or os.getcwd()
    root, found_via = find_project_root((start,))
    return root if found_via == ".git directory" else None

110 

111 

def find_git_repo(repo: Repo = None, at: str = None) -> Repo:
    """
    Find the Git repository instance.

    Args:
        repo (Repo, optional): An existing Git repository instance. If provided, returns the same instance.
        at (str, optional): The directory path to start the search. Defaults to the current working directory.

    Returns:
        Repo: The Git repository instance.

    Raises:
        git.exc.NoSuchPathError: If no Git repository root could be found from `at`.
    """
    if repo:
        return repo

    root = find_git_root(at)
    if root is None:
        # Without this guard `Repo(str(None))` would try to open a folder literally called "None".
        # NoSuchPathError is a GitError, so callers that catch GitError keep working.
        raise git.exc.NoSuchPathError(f"No git repository found for {at or os.getcwd()}")

    return Repo(str(root))

128 

129 

def latest_commit(repo: Repo = None) -> Commit:
    """
    Return the commit that HEAD currently points to.

    Args:
        repo (Repo, optional): Repository to inspect. When omitted, it is looked up via `find_git_repo`.

    Returns:
        Commit: The head commit of the repository.
    """
    return find_git_repo(repo).head.commit

142 

143 

def commit_by_id(commit_hash: str, repo: Repo = None) -> Commit:
    """
    Look up a single commit by its hash or by a name such as a branch.

    Args:
        commit_hash (str): Hash of the wanted commit; a name like a branch is accepted as well.
        repo (Repo, optional): Repository to inspect. When omitted, it is looked up via `find_git_repo`.

    Returns:
        Commit: The commit the repository resolves `commit_hash` to.
    """
    return find_git_repo(repo).commit(commit_hash)

157 

158 

@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Context manager exposing the contents of a Git Blob as an in-memory binary stream.

    Args:
        file (Blob): The Git Blob whose data should be read.

    Yields:
        io.BytesIO: A buffer holding everything read from the blob's data stream.
    """
    raw = file.data_stream.read()
    buffer = io.BytesIO(raw)
    yield buffer

171 

172 

def read_blob(file: Blob) -> str:
    """
    Return the text stored in a Git Blob.

    Args:
        file (Blob): The Git Blob to read.

    Returns:
        str: The blob's bytes decoded with the default codec of `bytes.decode`.
    """
    with open_blob(file) as stream:
        raw = stream.read()

    return raw.decode()

185 

186 

def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Repo = None) -> str:
    """
    Read a file as it was stored in a given commit of the Git repository.

    Args:
        filename (str): Path of the file to read.
        commit_version (str, optional): Commit hash or branch name; "latest" (the default) means the head commit.
        repo (Repo, optional): Repository to read from. When omitted, it is looked up starting at `filename`.

    Returns:
        str: The file's contents at that commit.
    """
    repo = find_git_repo(repo, at=filename)

    if commit_version == "latest":
        commit = latest_commit(repo)
    else:
        commit = commit_by_id(commit_version, repo)

    absolute = str(Path(filename).resolve())
    # the commit tree is addressed with paths relative to the repository's working directory:
    path_in_repo = absolute.removeprefix(f"{repo.working_dir}/")

    return read_blob(commit.tree / path_in_repo)

208 

209 

def get_file_for_version(filename: str, version: str, prompt_description: str = "", with_git: bool = True) -> str:
    """
    Load the source of a file for the requested version.

    Args:
        filename (str): Path of the file to load.
        version (str): "current" (read from disk), "stdin" (read from standard input),
            or a commit hash/branch name (read from Git).
        prompt_description (str, optional): Text shown in the prompt when the code is read from stdin.
        with_git (bool, optional): When False, the file is always read from disk, whatever `version` says.

    Returns:
        str: The contents of the file as a string.

    Raises:
        FileNotFoundError: If the file could not be loaded from Git for the given version.
    """
    if version == "current" or not with_git:
        return Path(filename).read_text()

    if version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        pasted = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return pasted

    # anything else is treated as a git reference (with_git is necessarily True at this point)
    try:
        return get_file_for_commit(filename, version)
    except (git.exc.GitError, gitdb.exc.ODBError) as e:
        raise FileNotFoundError(f"{filename}@{version}") from e

238 

239 

240def extract_file_version_and_path( 

241 file_path_or_git_tag: Optional[str], default_version: str = "stdin" 

242) -> tuple[str, str | None]: 

243 """ 

244 Extract the file version and path from the given input. 

245 

246 Args: 

247 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag. 

248 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin". 

249 

250 Returns: 

251 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified). 

252 

253 Examples: 

254 myfile.py (implies @current) 

255 

256 myfile.py@latest 

257 myfile.py@my-branch 

258 myfile.py@b3f24091a9 

259 

260 @latest (implies no path, e.g. in case of ALTER to copy previously defined path) 

261 """ 

262 if not file_path_or_git_tag: 

263 return default_version, "" 

264 

265 if file_path_or_git_tag == "-": 

266 return "stdin", "-" 

267 

268 if file_path_or_git_tag.startswith("@"): 

269 file_version = file_path_or_git_tag.strip("@") 

270 file_path = None 

271 elif "@" in file_path_or_git_tag: 

272 file_path, file_version = file_path_or_git_tag.split("@") 

273 else: 

274 file_version = default_version # `latest` for before; `current` for after. 

275 file_path = file_path_or_git_tag 

276 

277 return file_version, file_path 

278 

279 

def extract_file_versions_and_paths(
    filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Work out (version, path) for both sides of a before/after comparison.

    A side without a path borrows the path of the other side.

    Args:
        filename_before (str, optional): File (and/or version) describing the state before the change.
        filename_after (str, optional): File (and/or version) describing the state after the change.

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            `(version_before, path_before), (version_after, path_after)`.

    Raises:
        ValueError: If neither side yields a file path.
    """
    # Two different names given: 'before' defaults to the file as it is on disk.
    # Otherwise 'before' defaults to the latest commit.
    two_distinct_names = filename_after and filename_before and filename_after != filename_before
    before_default = "current" if two_distinct_names else "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=before_default)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    if not filepath_after:
        filepath_after = filepath_before
    elif not filepath_before:
        filepath_before = filepath_after

    return (version_before, filepath_before), (version_after, filepath_after)

310 

311 

def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Check whether a file can be found and resolve it to an absolute path.

    The file name is tried as given first and, failing that, relative to the Git root.

    Args:
        filename (str, optional): The path of the file to look for (or None).
        version (str): "stdin", "current", or a commit hash/branch name.
        git_root (Path, optional): Root of the Git repository. When None, `find_git_root()` is used,
            falling back to the current working directory.

    Returns:
        tuple[bool, str]: `(exists, absolute_path)`. For "stdin" this is `(True, "")`;
            when the file is missing or no file name was given it is `(False, "")`.
    """
    if version == "stdin":
        return True, ""

    if filename is None:
        # can't deal with this, not stdin and no file should show file missing error later.
        return False, ""

    if git_root is None:
        git_root = find_git_root() or Path(os.getcwd())

    for candidate in (Path(filename), git_root / filename):
        if candidate.exists():
            return True, str(candidate.resolve())

    return False, ""

347 

348 

def ensure_no_migrate_on_real_db(
    code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: typing.Optional[bool] = False
) -> str:
    """
    Ensure that the code does not contain actual migrations on a real database.

    It does this by removing definitions of 'db' and database. This can be changed by customizing `db_names`.
    It also removes local imports to prevent irrelevant code being executed.

    Args:
        code (str): The code to check for database migrations.
        db_names (Iterable[str], optional): Names of variables representing the database.
            Defaults to ("db", "database").
        fix (bool, optional): If True, removes the offending code. Defaults to False.

    Returns:
        str: The modified code if fix=True, otherwise the original code.

    Raises:
        ValueError: If fix is falsey and the code defines one of `db_names` or uses local imports.
    """
    # `db_names` may be a one-shot iterable and is needed more than once below,
    # so materialise it first (deduplicated, original order kept).
    names = tuple(dict.fromkeys(db_names))

    variables = find_defined_variables(code)
    found_variables = [name for name in names if name in variables]

    if found_variables and fix:
        # a single pass removes every database variable, no need to repeat it per name
        code = remove_specific_variables(code, names)
    elif found_variables:
        if len(found_variables) == 1:
            message = f"Variable {found_variables[0]} defined in code! "
        else:  # pragma: no cover
            message = f"Variables {', '.join(found_variables)} defined in code! "
        raise ValueError(
            f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
        )

    if has_local_imports(code):
        if not fix:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")
        code = remove_local_imports(code)

    return code

396 

397 

# Upper bound on the number of attempts `handle_cli` makes at executing (and patching) the generated code.
MAX_RETRIES = 30

# todo: overload more methods

# Script template that `handle_cli` fills in with `string.Template` and then runs with `exec`.
# Placeholders: $tables, $db_type, $extra (inserted before both code fragments), $code_before, $code_after.
# `DummyDAL`, `_uniq`, `_excl`, `_special_tables` and `_file` are not defined in the template itself:
# `handle_cli` passes them in through the globals dict given to `exec`.
# Each table's output is closed with a '-- END OF MIGRATION --' line, which the output handling splits on.
TEMPLATE_PYDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *

from pydal2sql_core import generate_sql


# from pydal import DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')

for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)
    print('-- END OF MIGRATION --', file=_file)
    """

# Same script as TEMPLATE_PYDAL, with an additional `from typedal import *`;
# `handle_cli` selects it when TypeDAL is in use.
TEMPLATE_TYPEDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *
from typedal import *

from pydal2sql_core import generate_sql


# from typedal import TypeDAL as DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')


for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)

    print('-- END OF MIGRATION --', file=_file)
    """

490 

491 

def sql_to_function_name(sql_statement: str) -> str:
    """
    Extract action (CREATE, ALTER, DROP) and table name from the SQL statement.

    Only the first `<ACTION> TABLE <name>` occurrence is used. The table name may be bare or wrapped in
    single quotes, double quotes or backticks (MySQL style).

    Args:
        sql_statement (str): The SQL to inspect.

    Returns:
        str: `"<action>_<table>"` in lower case (e.g. "create_users"),
            or "unknown_migration" if no such statement could be found.
    """
    found = re.search(
        r"(CREATE|ALTER|DROP)\s+TABLE\s+['\"`]?(\w+)['\"`]?",
        sql_statement.lower(),
        re.IGNORECASE,
    )

    if found is None:
        # raise ValueError("Invalid SQL statement. Unable to extract action and table name.")
        return "unknown_migration"

    action, table_name = found.groups()

    # Generate a function name with the specified format
    return f"{action}_{table_name}"

506 

507 

def _setup_generic_edwh_migrate(file: Path, is_typedal: bool) -> None:
    """
    (Re)create `file` with only the imports that an edwh-migrate migration file starts with.

    Args:
        file (Path): The migration file to write; existing contents are overwritten.
        is_typedal (bool): Import `TypeDAL` from typedal when True, `DAL` from pydal otherwise.
    """
    dal_import = "from typedal import TypeDAL" if is_typedal else "from pydal import DAL"
    header = "from edwh_migrate import migration\n" + dal_import + "\n"

    with file.open("w") as handle:
        handle.write(textwrap.dedent(header))

    rich.print(f"[green] New migrate file {file} created [/green]")

519 

520 

# Matches the '-- start <table> --' marker line that the execution templates print before each table's SQL.
START_RE = re.compile(r"-- start\s+\w+\s--\n")


def _build_edwh_migration(contents: str, cls: str, date: str, existing: Optional[str] = None) -> str:
    """
    Wrap the SQL of one migration in an edwh-migrate `@migration` function.

    The function is named `<action>_<table>_<date>_<nnn>`, where `<nnn>` is a zero-padded counter starting at 001.

    Args:
        contents (str): The SQL of a single migration (may still start with its '-- start ... --' marker).
        cls (str): Name of the class used to annotate the `db` argument of the generated function.
        date (str): Date part of the generated function name.
        existing (str, optional): Current contents of the target migration file, used to detect duplicates.

    Returns:
        str: Source code of the migration function, or an empty string when there is nothing (new) to write.
    """
    sql_func_name = sql_to_function_name(contents)
    func_name = "_placeholder_"
    contents = START_RE.sub("", contents)

    for n in range(1, 1000):
        func_name = f"{sql_func_name}_{date}_{str(n).zfill(3)}"

        if existing and f"def {func_name}" in existing:
            # compare with all spaces and newlines removed, so formatting differences don't matter
            if contents.replace(" ", "").replace("\n", "") in existing.replace(" ", "").replace("\n", ""):
                rich.print(f"[yellow] migration {func_name} already exists, skipping! [/yellow]")
                return ""
            elif func_name.startswith("alter"):
                # bump number because alter migrations are different
                continue
            else:
                rich.print(
                    f"[red] migration {func_name} already exists [bold]with different contents[/bold], skipping! [/red]"
                )
                return ""
        else:
            # okay function name, stop incrementing
            break

    # NOTE(review): the padding literal below is reconstructed from a whitespace-collapsed dump — confirm.
    contents = textwrap.indent(contents.strip(), " " * 16)

    if not contents.strip():
        # no real migration!
        return ""

    # NOTE(review): the template's internal indentation is likewise reconstructed — confirm against the real file.
    return textwrap.dedent(
        f'''

        @migration
        def {func_name}(db: {cls}):
            db.executesql("""
{contents}
            """)
            db.commit()

            return True
        '''
    )

567 

568 

def _build_edwh_migrations(contents: str, is_typedal: bool, output: Optional[Path] = None) -> str:
    """
    Convert the combined SQL output into edwh-migrate migration functions.

    Args:
        contents (str): Generated SQL, with a '-- END OF MIGRATION --' line after every migration.
        is_typedal (bool): Annotate the generated functions with `TypeDAL` (True) or `DAL` (False).
        output (Path, optional): Target file; if it exists, its contents are passed along as `existing`.

    Returns:
        str: The source code of all generated migration functions, concatenated.
    """
    dal_class = "TypeDAL" if is_typedal else "DAL"
    today = datetime.now().strftime("%Y%m%d")  # yyyymmdd

    already_written: Optional[str] = None
    if output and output.exists():
        already_written = output.read_text()

    rendered = [
        _build_edwh_migration(chunk, dal_class, today, already_written)
        for chunk in contents.split("-- END OF MIGRATION --")
        if chunk.strip()
    ]
    return "".join(rendered)

580 

581 

def _handle_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
) -> None:
    """
    Format the captured output and send it to a file, a buffer, or stdout.

    Args:
        file: Buffer holding everything the executed code printed.
        output_file: Destination: a path (given as Path or str) to append to, a StringIO to write into,
            or None / "-" to print to stdout.
        output_format: "edwh-migrate", "default" or "sql"; a falsey value is treated like the default.
        is_typedal: Passed along to the edwh-migrate helpers.

    Raises:
        ValueError: If `output_format` is not one of the known formats.
    """
    file.seek(0)
    contents = file.read()

    if isinstance(output_file, str):
        # `--output-file -` will print to stdout
        output_file = Path(output_file) if output_file != "-" else None

    if output_format == "edwh-migrate":
        existing_target = output_file if isinstance(output_file, Path) else None
        contents = _build_edwh_migrations(contents, is_typedal, existing_target)
    elif output_format in {"default", "sql"} or not output_format:
        parts = contents.split("-- END OF MIGRATION --")
        contents = "\n".join(parts)
    else:
        raise ValueError(
            f"Unknown format {output_format}. " f"Please choose one of {typing.get_args(_SUPPORTED_OUTPUT_FORMATS)}"
        )

    if isinstance(output_file, io.StringIO):
        output_file.write(contents)
        return

    if not isinstance(output_file, Path):
        # no file, just print to stdout:
        print(contents.strip())
        return

    is_missing_or_empty = not output_file.exists() or output_file.stat().st_size == 0
    if output_format == "edwh-migrate" and is_missing_or_empty:
        _setup_generic_edwh_migrate(output_file, is_typedal)

    if not contents.strip():
        rich.print(f"[yellow] Nothing to write to {output_file} [/yellow]")
        return

    with output_file.open("a") as handle:
        handle.write(contents)

    rich.print(f"[green] Written migration(s) to {output_file} [/green]")

621 

622 

IMPORT_IN_STR = re.compile(r'File "<string>", line (\d+), in <module>')


def _handle_import_error(code: str, error: ImportError) -> str:
    """
    Find the line of `code` that the currently handled exception's traceback points at.

    Must be called while the exception is being handled, because the traceback is taken
    from `traceback.format_exc()`.

    Args:
        code (str): The source that was executed from a string.
        error (ImportError): The import error being handled; used as the cause if no line is found.

    Returns:
        str: The line of `code` named by the first '<string>' module-level frame in the traceback.

    Raises:
        ValueError: If the traceback contains no such frame.
    """
    # error is deeper in a package, find the related import in the code:
    for tb_line in traceback.format_exc().splitlines():
        found = IMPORT_IN_STR.search(tb_line)
        if found:
            # 'File "<string>", line 15, in <module>'
            index = int(found.group(1)) - 1
            return code.split("\n")[index]

    # I don't know how to trigger this case:
    raise ValueError("Faulty import could not be automatically deleted") from error  # pragma: no cover

639 

640 

# Error text this module looks for in a KeyError to detect a reference to a table that is not defined.
MISSING_RELATIONSHIP = re.compile(r"Cannot resolve reference (\w+) in \w+ definition")


def _handle_relation_error(error: KeyError) -> tuple[str, str]:
    """
    Turn a 'Cannot resolve reference' KeyError into code that defines the missing table.

    Args:
        error (KeyError): The error raised while executing the generated code.

    Returns:
        tuple[str, str]: The name of the missing table and a code snippet that defines it
            (as an empty table, with redefine=True).

    Raises:
        KeyError: The original error, if its message is not about an unresolved reference.
    """
    if not (table := MISSING_RELATIONSHIP.findall(str(error))):
        # other error, raise again
        raise error

    t = table[0]

    # NOTE(review): whitespace inside the snippet is reconstructed from a whitespace-collapsed dump — confirm.
    return (
        t,
        """
        db.define_table('%s', redefine=True)
        """
        % t,
    )

658 

659 

def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: Optional[bool] = False,
    noop: Optional[bool] = False,
    magic: Optional[bool] = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool | typing.Literal["auto"] = "auto",
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
) -> bool:
    """
    Handle user input for generating SQL migration statements based on before and after code.

    The before/after code is placed in a template script, which is then executed with `exec`.
    When execution fails with a NameError, ImportError or KeyError, the code is patched and executed again,
    for at most MAX_RETRIES attempts.

    Args:
        code_before (str): The code representing the state of the database before the change.
        code_after (str): The code representing the state of the database after the change.
        db_type (str, optional): The type of the database (e.g., "postgres", "mysql", etc.). Defaults to None.
        tables (list[str] or list[list[str]], optional): The list of tables to generate SQL for. Defaults to None.
        verbose (bool, optional): If True, print the generated code. Defaults to False.
        noop (bool, optional): If True, only print the generated code but do not execute it. Defaults to False.
        magic (bool, optional): If True, automatically add missing variables for execution. Defaults to False.
        function_name (str or tuple[str, ...], optional): The name(s) of the function(s) where the tables are
            defined. Defaults: "define_tables".
        use_typedal: replace pydal imports with TypeDAL? "auto" looks for 'typedal' in the supplied code.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?

    # todo: prefix (e.g. public.)

    Returns:
        bool: True if SQL migration statements are generated and executed successfully
            (or if `noop` is set), False otherwise.
    """
    # todo: better typedal checking
    if use_typedal == "auto":
        use_typedal = "typedal" in code_before.lower() or "typedal" in code_after.lower()

    # normalise `function_name` (None, a single name or a tuple of names) to a set of names
    if function_name:
        define_table_functions: set[str] = set(function_name) if isinstance(function_name, tuple) else {function_name}
    else:
        define_table_functions = set()

    template = TEMPLATE_TYPEDAL if use_typedal else TEMPLATE_PYDAL

    to_execute = string.Template(textwrap.dedent(template))

    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)
    # code that gets injected before both code_before and code_after; grows while errors are being fixed
    extra_code = ""

    generated_code = to_execute.substitute(
        {
            "tables": flatten(tables or []),
            "db_type": db_type or "",
            "code_before": textwrap.dedent(code_before),
            "code_after": textwrap.dedent(code_after),
            "extra": extra_code,
        }
    )
    if verbose or noop:
        rich.print(generated_code, file=sys.stderr)

    if noop:
        # done
        return True

    err: typing.Optional[Exception] = None
    catch: dict[str, Any] = {}
    retry_counter = MAX_RETRIES

    # names injected via 'catch'; these must not be reported (or generated) as missing variables
    magic_vars = {"_file", "DummyDAL", "_special_tables", "_uniq", "_excl"}
    special_tables: set[str] = {"typedal_cache", "typedal_cache_dependency"} if use_typedal else set()

    while retry_counter:
        retry_counter -= 1
        try:
            if verbose:
                rich.print(generated_code, file=sys.stderr)

            # 'catch' is used to add and receive globals from the exec scope.
            # another argument could be added for locals, but adding simply {} changes the behavior negatively.
            # so for now, only globals is passed.
            catch["_file"] = io.StringIO()  # <- every print should go to this file, so we can handle it afterwards
            catch["DummyDAL"] = (
                DummyTypeDAL if use_typedal else DummyDAL
            )  # <- use a fake DAL that doesn't actually run queries
            catch["_special_tables"] = special_tables  # <- e.g. typedal_cache, auth_user
            # note: when adding something to 'catch', also add it to magic_vars!!!

            catch["_uniq"] = uniq  # function to make a list unique without changing order
            catch["_excl"] = excl  # function to exclude items from a list

            exec(generated_code, catch)  # nosec: B102
            _handle_output(catch["_file"], output_file, output_format, is_typedal=use_typedal)
            return True  # success!
        except ValueError as e:
            # the template raises ValueError('no-tables-found') when it ends up without tables;
            # any other ValueError is reported and ends the run.
            if str(e) != "no-tables-found":  # pragma: no cover
                rich.print(f"[yellow]{e}[/yellow]", file=sys.stderr)
                return False

            if define_table_functions:
                any_found = False
                for function_name in define_table_functions:
                    define_tables = find_function_to_call(generated_code, function_name)

                    # if define_tables function is found, add call to it at end of code
                    if define_tables is not None:
                        generated_code = add_function_call(generated_code, function_name, multiple=True)
                        any_found = True

                if any_found:
                    # hurray!
                    continue

            # else: no define_tables or other method to use found.

            print(f"No tables found in the top-level or {function_name} function!", file=sys.stderr)
            if use_typedal:
                print(
                    "Please use `db.define` or `database.define`, "
                    "or if you really need to use an alias like my_db.define, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            else:
                print(
                    "Please use `db.define_table` or `database.define_table`, "
                    "or if you really need to use an alias like my_db.define_tables, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            print(f"You can also specify a --function to use something else than {function_name}.", file=sys.stderr)

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code) - magic_vars
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            # postponed: this can possibly also be achieved by updating the 'catch' dict
            #   instead of injecting in the string.
            extra_code = extra_code + "\n" + textwrap.dedent(generate_magic_code(missing_vars))

            code_before = remove_if_falsey_blocks(code_before)
            code_after = remove_if_falsey_blocks(code_after)

            # re-render the script with the extra code; the next loop iteration executes it
            generated_code = to_execute.substitute(
                {
                    "tables": flatten(tables or []),
                    "db_type": db_type or "",
                    "extra": textwrap.dedent(extra_code),
                    "code_before": textwrap.dedent(code_before),
                    "code_after": textwrap.dedent(code_after),
                }
            )
        except ImportError as e:
            # should include ModuleNotFoundError
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            if not e.path:
                # code exists in code itself
                code_before = remove_import(code_before, e.name or "")
                code_after = remove_import(code_after, e.name or "")
            else:
                # look up the offending line via the traceback and blank it out in both code fragments
                to_remove = _handle_import_error(generated_code, e)
                code_before = code_before.replace(to_remove, "\n")
                code_after = code_after.replace(to_remove, "\n")

            generated_code = to_execute.substitute(
                {
                    "tables": flatten(tables or []),
                    "db_type": db_type or "",
                    "extra": textwrap.dedent(extra_code),
                    "code_before": textwrap.dedent(code_before),
                    "code_after": textwrap.dedent(code_after),
                }
            )

        except KeyError as e:
            err = e
            # a referenced table is not defined: inject a definition for it and leave it out of the output
            # (any other KeyError is re-raised by _handle_relation_error)
            table_name, table_definition = _handle_relation_error(e)
            special_tables.add(table_name)
            extra_code = extra_code + "\n" + textwrap.dedent(table_definition)

            generated_code = to_execute.substitute(
                {
                    "tables": flatten(tables or []),
                    "db_type": db_type or "",
                    "extra": textwrap.dedent(extra_code),
                    "code_before": textwrap.dedent(code_before),
                    "code_after": textwrap.dedent(code_after),
                }
            )
        except Exception as e:
            err = e
            # otherwise: give up
            retry_counter = 0
        finally:
            # reset:
            typing.TYPE_CHECKING = False

        if retry_counter < 1:  # pragma: no cover
            rich.print(f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'}", file=sys.stderr)
            return False

    # idk when this would happen, but something definitely went wrong here:
    return False  # pragma: no cover

875 

876 

def core_create(
    filename: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
) -> bool:
    """
    Generate CREATE statements for the tables defined in a source file.

    The source is loaded (from disk, Git or stdin) and handed to `handle_cli` as the 'after' state,
    with an empty 'before' state.

    Args:
        filename: The source file holding the final state of the database. May carry a version
            (`models.py@latest`) and/or a function name (`models.py:define_tables`).
            When omitted, the code is read from stdin.
        tables: Names of the tables to generate SQL for; passed on to `handle_cli`.
        db_type: The type of the database; passed on to `handle_cli`.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code.
        function: Name(s) of the function(s) in which the tables are defined. Combined with a function
            given through `filename`, if any.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?

    Returns:
        bool: The result of `handle_cli`: True on success, False otherwise.

    Raises:
        FileNotFoundError: If the source file cannot be found.
    """
    repo_root = find_git_root() or Path(os.getcwd())

    if not function:
        table_functions: set[str] = set()
    elif isinstance(function, tuple):  # pragma: no cover
        table_functions = set(function)
    else:  # pragma: no cover
        table_functions = {function}

    if filename and ":" in filename:
        # e.g. models.py:define_tables
        filename, _, inline_function = filename.partition(":")
        table_functions.add(inline_function)

    fallback_version = "current" if filename else "stdin"
    version, path = extract_file_version_and_path(filename, default_version=fallback_version)

    exists, absolute_path = get_absolute_path_info(path, version, repo_root)
    if not exists:
        raise FileNotFoundError(f"Source file {filename} could not be found.")

    source = get_file_for_version(absolute_path, version, prompt_description="table definition")

    return handle_cli(
        "",
        source,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(table_functions),
        output_format=output_format,
        output_file=output_file,
    )

947 

948 

def core_alter(
    filename_before: Optional[str] = None,
    filename_after: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
) -> bool:
    """
    Generates SQL migration statements for altering the database, based on the code in two given source files.

    Args:
        filename_before: The filename of the source file before changes.
            This code represents the initial state of the database.
            May carry a `:function` suffix (e.g. `models.py:define_tables`).
        filename_after: The filename of the source file after changes.
            This code represents the final state of the database.
            May carry a `:function` suffix (e.g. `models.py:define_tables`).
        tables: A list of table names to generate SQL for.
            If None, the function will attempt to process all tables found in the code.
        db_type: The type of the database. If None, the function will attempt to infer it from the code.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: The name (or a tuple of names) of the function(s) where the tables are defined.
            Names given via a `:function` filename suffix are added to these.
            If None, the function will use 'define_tables'.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?

    Returns:
        bool: False if loading the code raised a ValueError, if either version of the code is empty or
            if both versions are identical (the reason is printed to stderr);
            otherwise the result of `handle_cli`.

    Raises:
        FileNotFoundError: If either of the source files cannot be found.
    """
    # NOTE(review): the git root is looked up before any ':function' suffix is stripped from the filenames;
    #  presumably find_git_root copes with that - confirm.
    git_root = find_git_root(filename_before) or find_git_root(filename_after)

    functions: set[str] = set()
    if function:  # pragma: no cover
        # accept a single name or a tuple of names, same as core_create
        if isinstance(function, tuple):
            functions.update(function)
        else:
            functions.add(function)

    if filename_before and ":" in filename_before:
        # e.g. models.py:define_tables
        filename_before, _function = filename_before.split(":", 1)
        functions.add(_function)

    if filename_after and ":" in filename_after:
        # e.g. models.py:define_tables
        filename_after, _function = filename_after.split(":", 1)
        functions.add(_function)

    before, after = extract_file_versions_and_paths(filename_before, filename_after)

    version_before, filename_before = before
    version_after, filename_after = after

    # either ./file exists or /file exists (seen from git root):
    before_exists, before_absolute_path = get_absolute_path_info(filename_before, version_before, git_root)
    after_exists, after_absolute_path = get_absolute_path_info(filename_after, version_after, git_root)

    if not (before_exists and after_exists):
        message = "" if before_exists else f"Path {filename_before} does not exist! "
        # Don't report the same path twice, but never leave the message empty:
        # if only the 'after' version is missing it must be reported, even when both filenames are equal.
        if not after_exists and (before_exists or filename_before != filename_after):
            message += f"Path {filename_after} does not exist!"
        raise FileNotFoundError(message)

    try:
        code_before = get_file_for_version(
            before_absolute_path,
            version_before,
            prompt_description="current table definition",
            with_git=git_root is not None,
        )
        code_after = get_file_for_version(
            after_absolute_path,
            version_after,
            prompt_description="desired table definition",
            with_git=git_root is not None,
        )

        if not (code_before and code_after):
            message = ""
            message += "" if code_before else "Before code is empty (Maybe try `pydal2sql create`)! "
            message += "" if code_after else "After code is empty! "
            raise ValueError(message)

        if code_before == code_after:
            raise ValueError("Both contain the same code - nothing to alter!")

    except ValueError as e:
        # ValueErrors (raised above or while loading the code) are reported as a warning, not propagated
        rich.print(f"[yellow] {e} [/yellow]", file=sys.stderr)
        return False

    return handle_cli(
        code_before,
        code_after,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
    )