Coverage for src/pydal2sql_core/cli_support.py: 99%
382 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-04-22 13:56 +0200
1"""
2CLI-Agnostic support.
3"""
5import ast
6import contextlib
7import io
8import os
9import re
10import select
11import string
12import sys
13import textwrap
14import traceback
15import typing
16from dataclasses import dataclass
17from datetime import datetime
18from pathlib import Path
19from typing import Any, Optional
21import git
22import gitdb.exc
23import rich
24from black.files import find_project_root
25from git.objects.blob import Blob
26from git.objects.commit import Commit
27from git.repo import Repo
28from witchery import (
29 add_function_call,
30 find_defined_variables,
31 find_function_to_call,
32 find_missing_variables,
33 generate_magic_code,
34 has_local_imports,
35 remove_if_falsey_blocks,
36 remove_import,
37 remove_local_imports,
38 remove_specific_variables,
39)
41from .core import generate_sql
42from .helpers import detect_typedal, excl, flatten, uniq
43from .state import state
44from .types import (
45 _SUPPORTED_OUTPUT_FORMATS,
46 DEFAULT_OUTPUT_FORMAT,
47 SUPPORTED_DATABASE_TYPES_WITH_ALIASES,
48 SUPPORTED_OUTPUT_FORMATS,
49 DummyDAL,
50 DummyTypeDAL,
51)
# Matches a traceback frame that comes from code run via exec() on a string and
# captures the 1-based line number inside that string (used by _handle_import_error).
IMPORT_IN_STR = re.compile(r'File "<string>", line (\d+), in <module>')
def has_stdin_data() -> bool:  # pragma: no cover
    """
    Report whether stdin already has data waiting (piped `|` or redirected `<` input).

    A zero-timeout `select` is used, so this never blocks waiting for input.

    Returns:
        bool: True when stdin is readable right now, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    readable, _writable, _errored = select.select([sys.stdin], [], [], 0.0)
    return any(readable)
# Loose callable alias; used to cast `rich.print` / builtin `print` to one common type for mypy.
AnyCallable = typing.Callable[..., Any]
def print_if_interactive(*args: Any, pretty: bool = True, **kwargs: Any) -> None:  # pragma: no cover
    """
    Print to stderr, but only when nothing is being piped into stdin.

    Args:
        *args: Values forwarded to the print function.
        pretty (bool): Use `rich.print` when True, the builtin `print` otherwise.
        **kwargs: Extra keyword arguments for the print function. Any `file` given here
            is replaced by sys.stderr.

    Returns:
        None
    """
    if has_stdin_data():
        # stdin carries input data: stay silent.
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    printer(*args, **{**kwargs, "file": sys.stderr})
def find_git_root(at: Optional[str] = None) -> Optional[Path]:
    """
    Find the root directory of the Git repository containing `at`.

    Uses black's project-root discovery and only accepts the result when black
    reports that it stopped at a `.git` directory. Other stopping reasons are
    treated as "not in a git repository".

    Args:
        at (str, optional): File or directory to start searching from.
            Defaults to the current working directory.

    Returns:
        Optional[Path]: The repository root if found, otherwise None.
    """
    folder, reason = find_project_root((at or os.getcwd(),))
    if reason != ".git directory":
        return None
    return folder
def find_git_repo(repo: Optional[Repo] = None, at: Optional[str] = None) -> Repo:
    """
    Return a Git repository instance, locating it on disk when none is given.

    Args:
        repo (Repo, optional): An existing Git repository instance. If provided, it is returned as-is.
        at (str, optional): File or directory to start the search from.
            Defaults to the current working directory.

    Returns:
        Repo: The Git repository instance.

    Raises:
        git.exc.InvalidGitRepositoryError: If no Git repository root can be found from `at`.
            This is a `git.exc.GitError` subclass, so callers catching GitError still handle it.
    """
    if repo:
        return repo

    root = find_git_root(at)
    if root is None:
        # Without this guard str(None) would be passed to Repo, i.e. a relative
        # path literally named "None" -- a confusing error, or even the wrong repo.
        raise git.exc.InvalidGitRepositoryError(at or os.getcwd())
    return Repo(str(root))
def latest_commit(repo: Repo = None) -> Commit:
    """
    Return the commit that HEAD currently points to.

    Args:
        repo (Repo, optional): Repository to inspect. When omitted, the repository is
            located starting from the current working directory.

    Returns:
        Commit: The HEAD commit.
    """
    return find_git_repo(repo).head.commit
def commit_by_id(commit_hash: str, repo: Repo = None) -> Commit:
    """
    Resolve a revision (commit hash, branch name, ...) to a commit object.

    Args:
        commit_hash (str): Any revision string Git understands, e.g. a hash or a branch name.
        repo (Repo, optional): Repository to look in. When omitted, the repository is
            located starting from the current working directory.

    Returns:
        Commit: The commit the revision points to.
    """
    return find_git_repo(repo).commit(commit_hash)
@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Expose a Git blob's contents as an in-memory binary stream.

    The whole blob is read eagerly from its data stream and wrapped in a BytesIO.

    Args:
        file (Blob): The Git blob to read.

    Yields:
        io.BytesIO: A fresh stream positioned at the start of the blob data.
    """
    raw = file.data_stream.read()
    buffer = io.BytesIO(raw)
    yield buffer
def read_blob(file: Blob) -> str:
    """
    Return the text content of a Git blob.

    Args:
        file (Blob): The Git blob to read.

    Returns:
        str: The blob data decoded with the default (UTF-8) codec.
    """
    with open_blob(file) as stream:
        raw_bytes = stream.read()
    return raw_bytes.decode()
def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Optional[Repo] = None) -> str:
    """
    Get the contents of a file in the Git repository at a specific commit version.

    Args:
        filename (str): Path of the file to retrieve (absolute, or relative to the current directory).
        commit_version (str, optional): Commit hash or branch name. Defaults to "latest" (the HEAD commit).
        repo (Repo, optional): An existing Git repository instance. If omitted, it is located from `filename`.

    Returns:
        str: The contents of the file at that commit.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    file_path = Path(filename).resolve()
    working_dir = Path(repo.working_dir).resolve()
    try:
        # Git tree paths are relative to the work tree root and always '/'-separated,
        # independent of the host OS separator (string prefix-stripping broke on Windows).
        relative_file_path = file_path.relative_to(working_dir).as_posix()
    except ValueError:
        # File is outside the work tree: pass the absolute path through unchanged;
        # the tree lookup below is then not expected to find it.
        relative_file_path = str(file_path)

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)
def get_file_for_version(filename: str, version: str, prompt_description: str = "", with_git: bool = True) -> str:
    """
    Load source code for `filename` according to a version specifier.

    Args:
        filename (str): Path of the file to load.
        version (str): "current" (working copy), "stdin" (read pasted code), or a git
            commit hash / branch name.
        prompt_description (str, optional): Shown in the prompt when reading from stdin.
        with_git (bool, optional): When False, the working copy is always read and
            `version` is ignored.

    Returns:
        str: The loaded source code.

    Raises:
        FileNotFoundError: If the file cannot be retrieved from git at `version`.
    """
    if with_git and version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        pasted = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return pasted

    if with_git and version != "current":
        try:
            return get_file_for_commit(filename, version)
        except (git.exc.GitError, gitdb.exc.ODBError) as e:
            raise FileNotFoundError(f"{filename}@{version}") from e

    # git disabled, or the working copy was requested explicitly
    return Path(filename).read_text()
247def extract_file_version_and_path(
248 file_path_or_git_tag: Optional[str],
249 default_version: str = "stdin",
250) -> tuple[str, str | None]:
251 """
252 Extract the file version and path from the given input.
254 Args:
255 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag.
256 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin".
258 Returns:
259 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified).
261 Examples:
262 myfile.py (implies @current)
264 myfile.py@latest
265 myfile.py@my-branch
266 myfile.py@b3f24091a9
268 @latest (implies no path, e.g. in case of ALTER to copy previously defined path)
269 """
270 if not file_path_or_git_tag:
271 return default_version, ""
273 if file_path_or_git_tag == "-":
274 return "stdin", "-"
276 if file_path_or_git_tag.startswith("@"):
277 file_version = file_path_or_git_tag.strip("@")
278 file_path = None
279 elif "@" in file_path_or_git_tag:
280 file_path, file_version = file_path_or_git_tag.split("@")
281 else:
282 file_version = default_version # `latest` for before; `current` for after.
283 file_path = file_path_or_git_tag
285 return file_version, file_path
def extract_file_versions_and_paths(
    filename_before: Optional[str],
    filename_after: Optional[str],
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Resolve the (version, path) pairs for the "before" and "after" inputs.

    If only one side names a file, the other side reuses that path. When both sides name
    different files, the "before" side defaults to the working copy ("current") instead of
    the latest commit.

    Args:
        filename_before (str, optional): Spec of the file before the change (or None).
        filename_after (str, optional): Spec of the file after the change (or None).

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            ((version_before, path_before), (version_after, path_after)).

    Raises:
        ValueError: If neither side yields a file path.
    """
    distinct_files = bool(filename_after and filename_before and filename_after != filename_before)
    before_default = "current" if distinct_files else "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=before_default)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    if not filepath_after:
        filepath_after = filepath_before
    elif not filepath_before:
        filepath_before = filepath_after

    before = (version_before, filepath_before)
    after = (version_after, filepath_after)
    return before, after
def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Check whether `filename` exists and resolve it to an absolute path.

    The path is tried as given first, then relative to the Git root.

    Args:
        filename (str, optional): The path of the file to check (or None).
        version (str): "stdin", "current", or a commit hash/branch name.
        git_root (Path, optional): Repository root. When None, it is discovered from the
            current directory, falling back to the current directory itself.

    Returns:
        tuple[bool, str]: (exists, absolute_path). For "stdin" this is (True, ""); for a
        missing file or a None filename it is (False, "").
    """
    if version == "stdin":
        return True, ""
    if filename is None:
        # not stdin and no file: the caller reports the missing file later
        return False, ""

    root = git_root if git_root is not None else (find_git_root() or Path(os.getcwd()))

    for candidate in (Path(filename), root / filename):
        if candidate.exists():
            return True, str(candidate.resolve())

    return False, ""
def check_indentation(code: str, fix: bool = False) -> str:
    """
    Verify that `code` parses without an IndentationError, optionally dedenting it.

    Empty code is returned unchanged. Otherwise the code is parsed with `ast.parse`. If
    that raises an IndentationError, the code is either dedented with `textwrap.dedent`
    (when `fix` is True) or the error is re-raised.

    Args:
        code (str): The code to check.
        fix (bool): Dedent instead of raising on indentation errors. Defaults to False.

    Returns:
        str: The original code, or its dedented form when it was mis-indented and `fix` is True.

    Raises:
        IndentationError: If the code is mis-indented and `fix` is False.
    """
    if not code:
        # no code? perfect indentation!
        return code

    try:
        ast.parse(code)
    except IndentationError:
        if not fix:
            raise
        return textwrap.dedent(code)

    return code
def ensure_no_migrate_on_real_db(
    code: str,
    db_names: typing.Iterable[str] = ("db", "database"),
    fix: bool = False,
) -> str:
    """
    Ensure that the code cannot perform actual migrations on a real database.

    Top-level definitions of the database variables (`db` and `database` by default,
    configurable via `db_names`) are either removed (`fix=True`) or reported. Local
    imports are handled the same way, to avoid executing unrelated project code.

    Args:
        code (str): The code to check.
        db_names (Iterable[str], optional): Names of variables representing the database.
            Defaults to ("db", "database").
        fix (bool, optional): Remove offending definitions/imports instead of raising. Defaults to False.

    Returns:
        str: The code, with database definitions and local imports removed when `fix` is True.

    Raises:
        ValueError: If `fix` is False and the code defines a database variable or uses local imports.
    """
    # Materialize once: the names are checked here AND passed to remove_specific_variables,
    # so a generator argument would otherwise be (partially) exhausted.
    db_names = tuple(db_names)
    variables = find_defined_variables(code)

    # Preserve db_names order (and drop duplicates) so the error message is deterministic.
    found_variables = list(dict.fromkeys(name for name in db_names if name in variables))

    if found_variables:
        if fix:
            # A single call removes every configured db name.
            code = remove_specific_variables(code, db_names)
        else:
            if len(found_variables) == 1:
                message = f"Variable {found_variables[0]} defined in code! "
            else:  # pragma: no cover
                var = ", ".join(found_variables)
                message = f"Variables {var} defined in code! "
            raise ValueError(
                f"{message} Please remove this or use --magic to prevent performing actual migrations on your database.",
            )

    if has_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code
# Maximum number of exec attempts in render_schema_from_code. Between attempts the
# generated code may be patched (magic variables, removed imports, stub tables).
MAX_RETRIES = 30


# todo: overload more methods
# NOTE(review): unclear from this file which methods this refers to -- confirm or remove.
@dataclass
class RenderContext:
    """
    Context passed to migration renderers.

    render_schema_from_code builds one after executing the generated code and
    hands it to a `Renderer` callback.
    """

    # fake DAL holding the tables defined by `code_before`
    db_old: DummyDAL
    # fake DAL holding the tables defined by `code_after`
    db_new: DummyDAL
    # tables to render: the explicit selection, or every discovered non-special table
    tables: list[str]
    # database type hint; the executed code may override it via a `db_type` global
    db_type: Optional[str]
    # whether the TypeDAL execution template was used
    use_typedal: bool
    # True when `code_before` is empty/blank (only CREATE statements are expected)
    is_create: bool
    # True when `code_before` is non-empty (changes relative to a previous definition)
    is_alter: bool
# A renderer turns a RenderContext into the final output text (e.g. SQL statements).
Renderer = typing.Callable[[RenderContext], str]

# string.Template that render_schema_from_code dedents, fills in and exec()s (pydal flavour).
# Placeholders: $tables (explicit table list), $extra (magic/stub code added on retries; injected
# before both code snippets), $code_before / $code_after (old and new table definitions).
# The definitions run against two separate fake DALs, ending up in `db_old` and `db_new`.
# `DummyDAL`, `_uniq`, `_excl` and `_special_tables` are not imported here: they are injected
# through the exec globals. The 'no-tables-found' ValueError is caught by the caller.
TEMPLATE_EXEC_PYDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *

db = database = DummyDAL(None, migrate=False)

tables = $tables

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')
_tables = tables
    """

# Same as TEMPLATE_EXEC_PYDAL, plus `from typedal import *`. The injected `DummyDAL` global
# is the TypeDAL variant when this template is used.
TEMPLATE_EXEC_TYPEDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *
from typedal import *

db = database = DummyDAL(None, migrate=False)

tables = $tables

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')
_tables = tables
    """
def default_sql_renderer(context: RenderContext) -> str:
    """
    Default SQL renderer used by handle_cli.

    Each table gets its own section, in `context.tables` order, framed by a
    "-- start <table> --" line and a "-- END OF MIGRATION --" line:
      - in both schemas: the ALTER SQL from generate_sql(old, new);
      - only in the old schema: a DROP TABLE statement;
      - otherwise: the CREATE SQL from generate_sql(new).
    """
    buffer = io.StringIO()
    old_db, new_db, dialect = context.db_old, context.db_new, context.db_type

    for name in context.tables:
        print("-- start ", name, "--", file=buffer)

        in_old = name in old_db
        in_new = name in new_db
        if in_old and in_new:
            section = generate_sql(old_db[name], new_db[name], db_type=dialect)
        elif in_old:
            section = f"DROP TABLE {name};"
        else:
            section = generate_sql(new_db[name], db_type=dialect)
        print(section, file=buffer)

        print("-- END OF MIGRATION --", file=buffer)

    return buffer.getvalue()
def sql_to_function_name(sql_statement: str, default: Optional[str] = None) -> str:
    """
    Derive a function name like `create_person` from the first CREATE/ALTER/DROP TABLE statement.

    Supports table names quoted with double quotes, single quotes, backticks (MySQL) or
    square brackets (MSSQL), and skips an optional `IF [NOT] EXISTS` clause.

    Args:
        sql_statement: SQL text to inspect. Only the first matching statement is used.
        default: Name returned when no statement matches. Falls back to "unknown_migration".

    Returns:
        str: `<action>_<table>` in lowercase, or the default name.
    """
    match = re.search(
        r"(create|alter|drop)\s+table\s+(?:if\s+(?:not\s+)?exists\s+)?[`'\"\[]?(\w+)",
        sql_statement.lower(),  # lowercase input -> lowercase action/table in the name
    )

    if match is None:
        return default or "unknown_migration"

    action, table_name = match.groups()
    return f"{action}_{table_name}"
def _setup_generic_edwh_migrate(file: Path, is_typedal: bool) -> None:
    """
    Create (or overwrite) `file` with the import header an edwh-migrate migrations file needs.

    Args:
        file: Destination migrations file. Opened in "w" mode, so existing contents are replaced.
        is_typedal: Import TypeDAL instead of pydal's DAL.
    """
    dal_import = "from typedal import TypeDAL" if is_typedal else "from pydal import DAL"
    header = "\n".join(["from edwh_migrate import migration", dal_import, ""])

    with file.open("w") as handle:
        handle.write(textwrap.dedent(header))

    rich.print(f"[green] New migrate file {file} created [/green]")
# Matches the "-- start <table> --" marker line written by default_sql_renderer so it can be
# stripped from each migration body before wrapping it in an edwh-migrate function.
START_RE = re.compile(r"-- start\s+\w+\s--\n")
def _build_edwh_migration(
    contents: str,
    cls: str,
    date: str,
    existing: Optional[str] = None,
    default_migration_name: Optional[str] = None,
) -> str:
    """
    Wrap one SQL migration chunk in an edwh-migrate `@migration` function.

    The function name is `<action>_<table>_<date>_<NNN>`. NNN starts at 001 and is bumped
    while a same-named function with different SQL already exists in `existing` and the
    name starts with "alter" or the default migration name.

    Args:
        contents: One chunk of generated SQL, possibly starting with a "-- start <table> --" marker.
        cls: Type name used in the generated function's `db` annotation (e.g. "DAL" or "TypeDAL").
        date: Date stamp embedded in the function name.
        existing: Current contents of the output file, used to detect duplicates.
        default_migration_name: Name prefix used when no CREATE/ALTER/DROP TABLE is detected.

    Returns:
        str: Python source for the migration function, or "" when the SQL is empty or the
        migration is skipped as a duplicate/conflict.
    """
    sql_func_name = sql_to_function_name(contents, default=default_migration_name)
    func_name = "_placeholder_"
    contents = START_RE.sub("", contents)

    for n in range(1, 1000):
        func_name = f"{sql_func_name}_{date}_{str(n).zfill(3)}"

        if existing and f"def {func_name}" in existing:
            # whitespace-insensitive comparison: the same SQL already exists somewhere in the file
            if contents.replace(" ", "").replace("\n", "") in existing.replace(" ", "").replace("\n", ""):
                rich.print(f"[yellow] migration {func_name} already exists, skipping! [/yellow]")
                return ""
            elif func_name.startswith(("alter", default_migration_name or "unknown")):
                # bump number because alter migrations are different
                continue
            else:
                rich.print(
                    f"[red] migration {func_name} already exists "
                    f"[bold]with different contents[/bold], skipping! [/red]",
                )
                return ""
        else:
            # okay function name, stop incrementing
            break

    # Indent the SQL so that, after the dedent below, it sits inside the
    # db.executesql("""...""") string of the generated function body.
    contents = textwrap.indent(contents.strip(), " " * 16)

    if not contents.strip():
        # no real migration!
        return ""

    return textwrap.dedent(
        f'''

        @migration
        def {func_name}(db: {cls}):
            db.executesql("""
{contents}
            """)
            db.commit()

            return True
        ''',
    )
def _build_edwh_migrations(
    contents: str,
    is_typedal: bool,
    output: Optional[Path] = None,
    default_migration_name: Optional[str] = None,
) -> str:
    """
    Convert raw generated SQL into concatenated edwh-migrate migration functions.

    `contents` is split on the "-- END OF MIGRATION --" separator. Each non-blank chunk
    becomes one migration stamped with today's date (yyyymmdd). Duplicates are checked
    against the current contents of `output`, if that file exists.

    Args:
        contents: Raw SQL output from the renderer.
        is_typedal: Annotate generated functions with TypeDAL instead of DAL.
        output: Target migrations file, used to detect already-written migrations.
        default_migration_name: Fallback name for chunks without a recognizable table statement.

    Returns:
        str: Source code of all new migration functions (may be empty).
    """
    dal_class = "TypeDAL" if is_typedal else "DAL"
    today = datetime.now().strftime("%Y%m%d")  # yyyymmdd

    previous_source = None
    if output and output.exists():
        previous_source = output.read_text()

    pieces = []
    for chunk in contents.split("-- END OF MIGRATION --"):
        if not chunk.strip():
            continue
        pieces.append(
            _build_edwh_migration(chunk, dal_class, today, previous_source, default_migration_name=default_migration_name),
        )
    return "".join(pieces)
def _format_and_write_sql_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
    default_migration_name: Optional[str] = None,
) -> bool:
    """
    Format generated migration SQL and write it to its destination.

    Args:
        file (io.StringIO): Buffer holding the raw migration output (read from the start).
        output_file (Union[Path, str, io.StringIO, None]): Destination. None or "-" prints to stdout;
            any other string is treated as a file path.
        output_format (edwh-migrate, default): "edwh-migrate" wraps each migration in a
            `@migration` function. "default"/"sql"/empty joins the raw SQL. Defaults to 'default'.
        is_typedal (bool): Generate TypeDAL instead of pydal code. Only relevant for edwh-migrate.
        default_migration_name (Optional[str]): Fallback migration name for edwh-migrate output.

    Returns:
        bool: The result of writing the output (True).

    Raises:
        ValueError: If the provided output_format is unknown.

    A new or empty edwh-migrate file first gets an import header before the migrations are appended.
    """
    file.seek(0)
    raw = file.read()

    destination = output_file
    if isinstance(destination, str):
        # `--output-file -` will print to stdout
        destination = Path(destination) if destination != "-" else None

    target_path = destination if isinstance(destination, Path) else None

    if output_format == "edwh-migrate":
        rendered = _build_edwh_migrations(
            raw,
            is_typedal,
            target_path,
            default_migration_name=default_migration_name,
        )
    elif not output_format or output_format in {"default", "sql"}:
        rendered = "\n".join(raw.split("-- END OF MIGRATION --"))
    else:
        raise ValueError(
            f"Unknown format {output_format}. Please choose one of {typing.get_args(_SUPPORTED_OUTPUT_FORMATS)}",
        )

    needs_header = (
        target_path is not None
        and output_format == "edwh-migrate"
        and (not target_path.exists() or target_path.stat().st_size == 0)
    )
    if needs_header:
        _setup_generic_edwh_migrate(target_path, is_typedal)

    return _write_output(rendered, destination, output_description="migration(s)")
def try_format_and_write_sql_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
    default_migration_name: Optional[str] = None,
) -> bool:
    """
    Like _format_and_write_sql_output, but reports ValueErrors instead of raising.

    Any ValueError (e.g. an unknown output format) is printed in yellow to stderr
    and turned into a False return value.
    """
    try:
        written = _format_and_write_sql_output(
            file,
            output_file,
            output_format=output_format,
            is_typedal=is_typedal,
            default_migration_name=default_migration_name,
        )
    except ValueError as exc:
        rich.print(f"[yellow]{exc}[/yellow]", file=sys.stderr)
        return False
    return written
728def _write_output(
729 contents: str,
730 output_file: Path | str | io.StringIO | None,
731 output_description: str = "output",
732) -> bool:
733 """
734 Write already-rendered output to destination.
735 """
736 if isinstance(output_file, str):
737 output_file = None if output_file == "-" else Path(output_file)
739 if isinstance(output_file, Path):
740 if contents.strip():
741 with output_file.open("a") as f:
742 f.write(contents)
744 rich.print(f"[green] Written {output_description} to {output_file} [/green]")
745 else:
746 rich.print(f"[yellow] Nothing to write to {output_file} [/yellow]")
747 elif isinstance(output_file, io.StringIO):
748 output_file.write(contents)
749 else:
750 print(contents.strip())
752 return True
def _handle_import_error(code: str, error: ImportError) -> str:
    """
    Find the line of `code` whose import (directly or transitively) raised `error`.

    The error is often raised deeper inside a package. Its own traceback is scanned for
    frames from exec'd string code (`File "<string>", line N, in <module>`), and the first
    such frame with a valid line number identifies the offending top-level line.

    Args:
        code: The source that was passed to exec().
        error: The ImportError raised while executing `code`.

    Returns:
        str: The offending source line, verbatim.

    Raises:
        ValueError: If no usable frame is found in the traceback.
    """
    # Format the traceback attached to `error` itself instead of traceback.format_exc(),
    # which depends on that exception still being the one currently handled.
    formatted = "".join(traceback.format_exception(type(error), error, error.__traceback__))
    lines = code.split("\n")

    for tb_line in formatted.splitlines():
        if matches := IMPORT_IN_STR.findall(tb_line):
            # 'File "<string>", line 15, in <module>' -> 0-based index 14
            line_no = int(matches[0]) - 1
            if 0 <= line_no < len(lines):
                return lines[line_no]

    # I don't know how to trigger this case:
    raise ValueError("Faulty import could not be automatically deleted") from error  # pragma: no cover
# Matches the KeyError message raised when a reference field points at a table that has not
# been defined (presumably from pydal's define_table -- confirm); captures the missing table name.
MISSING_RELATIONSHIP = re.compile(r"Cannot resolve reference (\w+) in \w+ definition")
def _handle_relation_error(error: KeyError) -> tuple[str, str]:
    """
    Turn a missing-reference KeyError into a stub table definition.

    Args:
        error: KeyError raised while executing the table definitions.

    Returns:
        tuple[str, str]: The missing table name, and a `db.define_table('<name>', redefine=True)`
        snippet (not yet dedented) that defines a field-less placeholder for it.

    Raises:
        KeyError: The original error, re-raised when its message does not match MISSING_RELATIONSHIP.
    """
    if not (table := MISSING_RELATIONSHIP.findall(str(error))):
        # other error, raise again
        raise error

    t = table[0]

    return (
        t,
        """
        db.define_table('%s', redefine=True)
        """
        % t,
    )
def render_schema_from_code(
    code_after: str,
    output_file: Optional[str | Path | io.StringIO],
    renderer: Renderer,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: bool = False,
    noop: bool = False,
    magic: bool = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool = False,
    code_before: str = "",
    _update_path: bool = True,
) -> bool:
    """
    Execute migration code and render output with a custom renderer callback.

    The before/after definitions are spliced into an exec template and run against fake
    DAL objects. The resulting schemas are then passed to `renderer`, and its text is
    written to `output_file`. Execution is retried up to MAX_RETRIES times, patching the code:
      - NameError (only with `magic`): missing names get generated placeholders;
      - ImportError: the failing import is removed;
      - missing-reference KeyError: a stub table is defined.

    Args:
        code_before: Source code representing the initial table definitions.
        code_after: Source code representing the desired table definitions.
        renderer: Callback receiving RenderContext and returning final output text.
        db_type: Optional database type hint.
        tables: Explicit table selection (nested lists are flattened).
        verbose: Print generated execution code to stderr.
        noop: Only print generated execution code, skip execution.
        magic: Fix indentation, strip real-db definitions/local imports, and inject missing names.
        function_name: Function(s) whose definition, if present, gets a call appended.
        use_typedal: Use the TypeDAL execution template (`True`) or the pydal one (`False`).
        output_file: Output destination (path, StringIO, stdout).
        _update_path: Reserved for compatibility (unused here).

    Returns:
        bool: True on success, False on failure.

    Raises:
        IndentationError: If a snippet is mis-indented and `magic` is False.
        ValueError: If a snippet defines a db variable or uses local imports and `magic` is False.
        KeyError: If execution raises a KeyError that is not a missing table reference.
    """
    if function_name:
        define_table_functions: set[str] = set(function_name) if isinstance(function_name, tuple) else {function_name}
    else:
        define_table_functions = set()

    template = TEMPLATE_EXEC_TYPEDAL if use_typedal else TEMPLATE_EXEC_PYDAL

    to_execute = string.Template(textwrap.dedent(template))

    code_before = check_indentation(code_before, fix=magic)
    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)

    code_after = check_indentation(code_after, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

    # accumulates generated patch code (magic variables, stub tables) across retries
    extra_code = ""

    def _render_exec_code(_code_before: str, _code_after: str, _extra: str) -> str:
        # fill the template with the current (possibly patched) snippets
        generated = to_execute.substitute(
            {
                "tables": flatten(tables or []),
                "code_before": textwrap.dedent(_code_before),
                "code_after": textwrap.dedent(_code_after),
                "extra": textwrap.dedent(_extra),
            },
        )

        # Function-defined tables (e.g. define_tables/define_td_tables) should be executed
        # regardless of whether an explicit --tables filter is provided.
        for _function_name in define_table_functions:
            if find_function_to_call(generated, _function_name) is not None:
                generated = add_function_call(generated, _function_name, multiple=True)

        return generated

    generated_code = _render_exec_code(code_before, code_after, extra_code)
    if verbose or noop:
        rich.print(generated_code, file=sys.stderr)

    if noop:
        # done
        return True

    err: typing.Optional[Exception] = None
    # exec globals; the same dict is reused across retries, so globals from a failed attempt persist
    catch: dict[str, Any] = {}
    retry_counter = MAX_RETRIES

    # names injected via 'catch', never to be generated as "missing" by magic
    magic_vars = {"DummyDAL", "_special_tables", "_uniq", "_excl"}
    special_tables: set[str] = {"typedal_cache", "typedal_cache_dependency"} if use_typedal else set()

    cwd = os.getcwd()
    while retry_counter:
        if err and cwd not in sys.path:
            # add current path to PYTHONPATH to possibly resolve imports
            # (this sys.path entry is not removed afterwards)
            sys.path.append(os.getcwd())

        retry_counter -= 1
        try:
            if verbose:
                rich.print(generated_code, file=sys.stderr)

            # 'catch' is used to add and receive globals from the exec scope.
            # another argument could be added for locals, but adding simply {} changes the behavior negatively.
            # so for now, only globals is passed.
            catch["DummyDAL"] = (
                DummyTypeDAL if use_typedal else DummyDAL
            )  # <- use a fake DAL that doesn't actually run queries
            catch["_special_tables"] = special_tables  # <- e.g. typedal_cache, auth_user
            catch["db_type"] = db_type or ""  # allow source code to override db_type during exec
            # note: when adding something to 'catch', also add it to magic_vars!!!

            catch["_uniq"] = uniq  # function to make a list unique without changing order
            catch["_excl"] = excl  # function to exclude items from a list

            exec(generated_code, catch)  # nosec: B102

            # db_old / db_new / _tables are globals assigned by the exec template
            context = RenderContext(
                db_old=typing.cast(DummyDAL, catch["db_old"]),
                db_new=typing.cast(DummyDAL, catch["db_new"]),
                tables=list(catch["_tables"]),
                db_type=catch.get("db_type", db_type),
                use_typedal=use_typedal,
                is_create=not bool(code_before.strip()),
                is_alter=bool(code_before.strip()),
            )
            rendered = renderer(context)
            _write_output(rendered, output_file)
            return True  # success!
        except ValueError as e:
            # also catches ValueErrors raised by the renderer or while writing output
            if str(e) != "no-tables-found":  # pragma: no cover
                rich.print(f"[yellow]{e}[/yellow]", file=sys.stderr)
                return False

            print(f"No tables found in the top-level or {function_name} function!", file=sys.stderr)
            if use_typedal:
                print(
                    "Please use `db.define` or `database.define`, "
                    "or if you really need to use an alias like my_db.define, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            else:
                print(
                    "Please use `db.define_table` or `database.define_table`, "
                    "or if you really need to use an alias like my_db.define_tables, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            print(f"You can also specify a --function to use something else than {function_name}.", file=sys.stderr)

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code) - magic_vars
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            # postponed: this can possibly also be achieved by updating the 'catch' dict
            # instead of injecting in the string.
            extra_code = extra_code + "\n" + textwrap.dedent(generate_magic_code(missing_vars))

            code_before = remove_if_falsey_blocks(code_before)
            code_after = remove_if_falsey_blocks(code_after)

            generated_code = _render_exec_code(code_before, code_after, extra_code)
        except ImportError as e:
            # should include ModuleNotFoundError
            # note: unlike NameError, this patching happens even without --magic
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            if not e.path:
                # code exists in code itself
                code_before = remove_import(code_before, e.name or "")
                code_after = remove_import(code_after, e.name or "")
            else:
                # error is raised inside an imported module: locate and drop the importing line
                to_remove = _handle_import_error(generated_code, e)
                code_before = code_before.replace(to_remove, "\n")
                code_after = code_after.replace(to_remove, "\n")

            generated_code = _render_exec_code(code_before, code_after, extra_code)

        except KeyError as e:
            err = e
            # re-raises (ending the loop) unless it is a missing table reference
            table_name, table_definition = _handle_relation_error(e)
            # the stub table is excluded from the auto-discovered table list
            special_tables.add(table_name)
            extra_code = extra_code + "\n" + textwrap.dedent(table_definition)

            generated_code = _render_exec_code(code_before, code_after, extra_code)
        except Exception as e:
            err = e
            # otherwise: give up
            retry_counter = 0
        finally:
            # reset:
            # NOTE(review): presumably undoes exec'd code setting typing.TYPE_CHECKING = True
            # (exec shares the real `typing` module with this process) -- confirm intent.
            typing.TYPE_CHECKING = False

        if retry_counter < 1:  # pragma: no cover
            rich.print(
                f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'} ({type(err)})",
                file=sys.stderr,
            )
            if state.verbosity > 2:
                traceback.print_tb(err.__traceback__)
            return False

    # idk when this would happen, but something definitely went wrong here:
    return False  # pragma: no cover
def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: bool = False,
    noop: bool = False,
    magic: bool = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool | typing.Literal["auto"] = "auto",
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    _update_path: bool = True,
) -> bool:
    """
    SQL-specific wrapper around render_schema_from_code.

    SQL is rendered into an in-memory buffer with default_sql_renderer, then formatted
    (plain SQL or edwh-migrate) and written to `output_file`. When `use_typedal` is
    "auto", TypeDAL usage is detected from either code snippet.

    Returns:
        bool: False if rendering or formatting failed. True on success, or right after
        rendering when `noop` is set.
    """
    if use_typedal == "auto":
        is_typedal = detect_typedal(code_before) or detect_typedal(code_after)
    else:
        is_typedal = use_typedal

    sql_buffer = io.StringIO()
    rendered_ok = render_schema_from_code(
        code_after,
        code_before=code_before,
        output_file=sql_buffer,
        renderer=default_sql_renderer,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=function_name,
        use_typedal=is_typedal,
        _update_path=_update_path,
    )

    # failure -> False; noop has nothing to write -> True
    if not rendered_ok or noop:
        return rendered_ok

    return try_format_and_write_sql_output(
        sql_buffer,
        output_file,
        output_format=output_format,
        is_typedal=is_typedal,
    )
def find_file_contents(
    filename: Optional[str] = None,
    function: Optional[str | tuple[str, ...]] = None,
    prompt_description: str = "table definition",
    found_functions: list[str] | None = None,
    default_version: str = "stdin",
    file_version: Optional[str] = None,
    git_root: Optional[Path] = None,
    with_git: bool = True,
) -> str:
    """
    Resolve a source file spec and return its contents.

    Supports:
    - `<path>`
    - `<path>:<function>`
    - `<path>@<git-ref>`
    - `-` / stdin

    Args:
        filename: File path (optionally including `:<function>` or `@<git-ref>`).
        function: Optional function name(s) to merge with any `:<function>` from `filename`.
        prompt_description: Description shown when reading from stdin.
        found_functions: Optional mutable list that receives the discovered function names,
            de-duplicated and in a deterministic order (explicit `function` names first,
            then the `:<function>` suffix).
        default_version: Default version when `filename` has no `@<git-ref>`.
        file_version: Explicit file version. When provided, `filename` is treated as a plain path.
        git_root: Optional git root for path resolution. If omitted, it is looked up from
            `filename`, then from the current directory, and finally falls back to the cwd.
        with_git: Whether `@<git-ref>` lookups should use git history.

    Returns:
        The loaded source code.

    Raises:
        FileNotFoundError: If the resolved source file does not exist.
    """
    if git_root is None:
        git_root = find_git_root(filename) or find_git_root() or Path(os.getcwd())

    # A dict is used as an insertion-ordered set. A plain `set[str]` iterates in
    # hash order, which changes between interpreter runs (str hash randomization).
    # That would make the order of `found_functions` nondeterministic.
    functions: dict[str, None] = {}
    if function:  # pragma: no cover
        if isinstance(function, tuple):
            functions.update(dict.fromkeys(function))
        else:
            functions[function] = None

    if filename and ":" in filename:
        # e.g. models.py:define_tables
        filename, _function = filename.split(":", 1)
        functions[_function] = None

    if file_version is None:
        file_version, file_path = extract_file_version_and_path(
            filename,
            default_version=default_version,
        )
    else:
        file_path = filename
    file_exists, file_absolute_path = get_absolute_path_info(file_path, file_version, git_root)

    if not file_exists:
        raise FileNotFoundError(f"Source file {filename} could not be found.")

    if found_functions is not None:
        found_functions.extend(functions)

    return get_file_for_version(
        file_absolute_path,
        file_version,
        prompt_description=prompt_description,
        with_git=with_git,
    )
def core_create(
    filename: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Generate SQL migration statements that create tables from the code in one source file.

    Args:
        filename: Source file spec (see `find_file_contents`: `<path>`, `<path>:<function>`,
            `<path>@<git-ref>`). When omitted, the code is read from stdin.
            This code represents the final state of the database.
        tables: Table names to generate SQL for (None: every table found in the code).
        db_type: The type of the database. Forwarded as-is (None when not given).
        magic: If True, automatically add missing variables for execution.
        noop: If True, render only; do not format or write the output.
        verbose: If True, print the generated code and additional debug information.
        function: Name(s) of the function(s) where the tables are defined. These are merged with
            any `:<function>` suffix in `filename`. NOTE(review): when neither is given, an empty
            tuple is forwarded as `function_name`; presumably the renderer then uses its default.
            Confirm this.
        output_format: Defaults to plain SQL; edwh-migrate migration syntax is also supported.
        output_file: Append the output to this file instead of printing it.
        _update_path: Try adding cwd to PYTHONPATH if execution fails.

    Returns:
        bool: The result of `handle_cli`: False when rendering fails, True otherwise
        (including noop mode) as reported by the output writer.

    Raises:
        FileNotFoundError: If the source file cannot be found (raised by `find_file_contents`).
    """
    discovered_functions: list[str] = []

    # without a filename, the code is read from stdin
    initial_version = "current" if filename else "stdin"
    source_code = find_file_contents(
        filename,
        function,
        found_functions=discovered_functions,
        default_version=initial_version,
    )

    # a "create" migration has no previous state, hence the empty before-code
    return handle_cli(
        "",
        source_code,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(discovered_functions),
        output_format=output_format,
        output_file=output_file,
        _update_path=_update_path,
    )
def core_alter(
    filename_before: Optional[str] = None,
    filename_after: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Generate SQL migration statements that alter the database, based on the code in two source files.

    Args:
        filename_before: Source file spec of the code before the changes (may carry a
            `:<function>` suffix and/or an `@<git-ref>`).
            This code represents the initial state of the database.
        filename_after: Source file spec of the code after the changes (same syntax).
            This code represents the final state of the database.
        tables: Table names to generate SQL for (None: every table found in the code).
        db_type: The type of the database. Forwarded as-is (None when not given).
        magic: If True, automatically add missing variables for execution.
        noop: If True, render only; do not format or write the output.
        verbose: If True, print the generated code and additional debug information.
        function: Name(s) of the function(s) where the tables are defined. These are merged
            with any `:<function>` suffixes in the filenames. A tuple is accepted for parity
            with `core_create`.
        output_format: Defaults to plain SQL; edwh-migrate migration syntax is also supported.
        output_file: Append the output to this file instead of printing it.
        _update_path: Try adding cwd to PYTHONPATH if execution fails.

    Returns:
        bool: False if either code is empty or both are identical (a warning is printed to
        stderr), or if rendering fails. Otherwise the result of `handle_cli`.

    Raises:
        FileNotFoundError: If the before and/or after file does not exist for its version.
    """
    git_root = find_git_root(filename_before) or find_git_root(filename_after)

    # Insertion-ordered "set" of function names. A plain set of str iterates in
    # hash-randomized order, which would make the forwarded tuple nondeterministic.
    functions: dict[str, None] = {}
    if function:  # pragma: no cover
        if isinstance(function, tuple):
            functions.update(dict.fromkeys(function))
        else:
            functions[function] = None

    if filename_before and ":" in filename_before:
        # e.g. models.py:define_tables
        filename_before, _function = filename_before.split(":", 1)
        functions[_function] = None

    if filename_after and ":" in filename_after:
        # e.g. models.py:define_tables
        filename_after, _function = filename_after.split(":", 1)
        functions[_function] = None

    before, after = extract_file_versions_and_paths(filename_before, filename_after)

    version_before, filename_before = before
    version_after, filename_after = after

    # either ./file exists or /file exists (seen from git root):
    before_exists, _ = get_absolute_path_info(filename_before, version_before, git_root)
    after_exists, _ = get_absolute_path_info(filename_after, version_after, git_root)

    if not (before_exists and after_exists):
        message = ""
        if not before_exists:
            message += f"Path {filename_before} does not exist! "
        # Only skip the "after" message when it would repeat the "before" one.
        # The same path can exist in one version (git ref) but not in the other,
        # and previously that case produced an empty error message.
        if not after_exists and (filename_after != filename_before or before_exists):
            message += f"Path {filename_after} does not exist!"
        raise FileNotFoundError(message)

    try:
        code_before = find_file_contents(
            filename_before,
            prompt_description="current table definition",
            file_version=version_before,
            git_root=git_root,
            with_git=git_root is not None,
        )
        code_after = find_file_contents(
            filename_after,
            prompt_description="desired table definition",
            file_version=version_after,
            git_root=git_root,
            with_git=git_root is not None,
        )

        if not (code_before and code_after):
            message = ""
            message += "" if code_before else "Before code is empty (Maybe try `pydal2sql create`)! "
            message += "" if code_after else "After code is empty! "
            raise ValueError(message)

        if code_before == code_after:
            raise ValueError("Both contain the same code - nothing to alter!")

    except ValueError as e:
        rich.print(f"[yellow] {e} [/yellow]", file=sys.stderr)
        return False

    return handle_cli(
        code_before,
        code_after,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
        # previously accepted but silently dropped (core_create does forward it)
        _update_path=_update_path,
    )
def core_stub(
    migration_name: str,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    dry_run: bool = False,
    is_typedal: bool = False,
) -> bool:
    """
    Generate an empty (stub) migration.

    The stub body is a single SQL comment line, `-- <migration_name>`. It is passed through
    the regular formatter/writer, so it is mostly useful with output_format='edwh-migrate',
    where it yields an empty migration skeleton named after `migration_name`.

    Args:
        migration_name (str): Name of the migration; used both in the SQL comment and as the
            default migration name for the formatter.
        output_format: The format in which to output the migration (e.g. edwh-migrate or plain SQL).
        output_file: Where to write the migration. If None, it is printed to stdout.
        dry_run (bool): If True, `output_file` is ignored and the migration is printed to stdout.
        is_typedal (bool): Generate the migration for TypeDAL instead of pydal.
            Only relevant if output_format=edwh-migrate.

    Returns:
        bool: The result reported by the output writer (presumably True on success).

    Example:
        >>> core_stub("some_change", output_format="sql", output_file="migration.sql", dry_run=False, is_typedal=False)
        True
    """
    # a None target makes the writer print to stdout
    target = None if dry_run else output_file

    stub_sql = io.StringIO(f"-- {migration_name}\n")
    return _format_and_write_sql_output(
        stub_sql,
        target,
        output_format=output_format,
        is_typedal=is_typedal,
        default_migration_name=migration_name,
    )