Coverage for src/pydal2sql_core/cli_support.py: 99%
377 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-04-22 12:18 +0200
« prev ^ index » next coverage.py v7.11.0, created at 2026-04-22 12:18 +0200
1"""
2CLI-Agnostic support.
3"""
5import ast
6import contextlib
7import io
8import os
9import re
10import select
11import string
12import sys
13import textwrap
14import traceback
15import typing
16from dataclasses import dataclass
17from datetime import datetime
18from pathlib import Path
19from typing import Any, Optional
21import git
22import gitdb.exc
23import rich
24from black.files import find_project_root
25from git.objects.blob import Blob
26from git.objects.commit import Commit
27from git.repo import Repo
28from witchery import (
29 add_function_call,
30 find_defined_variables,
31 find_function_to_call,
32 find_missing_variables,
33 generate_magic_code,
34 has_local_imports,
35 remove_if_falsey_blocks,
36 remove_import,
37 remove_local_imports,
38 remove_specific_variables,
39)
41from .core import generate_sql
42from .helpers import detect_typedal, excl, flatten, uniq
43from .state import state
44from .types import (
45 _SUPPORTED_OUTPUT_FORMATS,
46 DEFAULT_OUTPUT_FORMAT,
47 SUPPORTED_DATABASE_TYPES_WITH_ALIASES,
48 SUPPORTED_OUTPUT_FORMATS,
49 DummyDAL,
50 DummyTypeDAL,
51)
# Matches the traceback frame line that points into exec()'d source (reported as "<string>").
# The captured group is the line number within that source.
IMPORT_IN_STR = re.compile(r'File "<string>", line (\d+), in <module>')
def has_stdin_data() -> bool:  # pragma: no cover
    """
    Tell whether data was handed to the program via stdin (pipe | or redirect <).

    Returns:
        bool: True if stdin is reported as readable right now, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    # a timeout of 0.0 makes select() poll instead of block
    readable, _writable, _exceptional = select.select([sys.stdin], [], [], 0.0)
    return any(readable)
# Loosely typed "any callable" alias; used to cast the selected print function.
AnyCallable = typing.Callable[..., Any]
def print_if_interactive(*args: Any, pretty: bool = True, **kwargs: Any) -> None:  # pragma: no cover
    """
    Print to stderr, but only when the session is interactive (nothing was piped into stdin).

    Args:
        *args: Positional arguments forwarded to the print function.
        pretty (bool): Use rich.print when True, the built-in print otherwise.
        **kwargs: Keyword arguments forwarded to the print function; `file` is always replaced by sys.stderr.

    Returns:
        None
    """
    if has_stdin_data():
        # input is piped/redirected, so this is not an interactive session: stay silent
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    kwargs["file"] = sys.stderr
    printer(*args, **kwargs)
def find_git_root(at: str = None) -> Optional[Path]:
    """
    Locate the root folder of the surrounding Git repository.

    Args:
        at (str, optional): Path to start searching from. Defaults to the current working directory.

    Returns:
        Optional[Path]: The project root when it was identified by a .git directory, otherwise None.
    """
    start = at or os.getcwd()
    folder, reason = find_project_root((start,))
    # black also stops at e.g. pyproject.toml; only a .git hit counts here
    return folder if reason == ".git directory" else None
def find_git_repo(repo: Repo = None, at: str = None) -> Repo:
    """
    Find the Git repository instance.

    Args:
        repo (Repo, optional): An existing Git repository instance. If provided, returns the same instance.
        at (str, optional): The directory path to start the search. Defaults to the current working directory.

    Returns:
        Repo: The Git repository instance.

    Raises:
        git.exc.NoSuchPathError: If no Git root could be found from the starting path.
    """
    if repo:
        return repo

    root = find_git_root(at)
    if root is None:
        # Without this guard str(None) is passed on, which makes GitPython look for
        # (or even open) a directory literally named "None".
        raise git.exc.NoSuchPathError(f"No git repository found at or above {at or os.getcwd()}")

    return Repo(str(root))
def latest_commit(repo: Repo = None) -> Commit:
    """
    Return the commit HEAD currently points at.

    Args:
        repo (Repo, optional): Repository to inspect; looked up via find_git_repo when omitted.

    Returns:
        Commit: The head commit of the repository.
    """
    return find_git_repo(repo).head.commit
def commit_by_id(commit_hash: str, repo: Repo = None) -> Commit:
    """
    Look up one commit by hash or by another revision name.

    Args:
        commit_hash (str): Hash of the commit; may also be e.g. a branch name.
        repo (Repo, optional): Repository to inspect; looked up via find_git_repo when omitted.

    Returns:
        Commit: The commit the given revision resolves to.
    """
    return find_git_repo(repo).commit(commit_hash)
@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Context manager exposing the data of a Git Blob as an in-memory binary stream.

    Args:
        file (Blob): The Git Blob object to open.

    Yields:
        io.BytesIO: Buffer holding everything read from the blob's data stream.
    """
    raw = file.data_stream.read()
    buffer = io.BytesIO(raw)
    yield buffer
def read_blob(file: Blob) -> str:
    """
    Return the contents of a Git Blob as text.

    Args:
        file (Blob): The Git Blob object to read.

    Returns:
        str: The blob's bytes decoded with the default codec of bytes.decode().
    """
    with open_blob(file) as stream:
        raw = stream.read()

    return raw.decode()
def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Repo = None) -> str:
    """
    Get the contents of a file in the Git repository at a specific commit version.

    Args:
        filename (str): The path of the file to retrieve.
        commit_version (str, optional): The commit hash or branch name. Defaults to "latest" (latest commit).
        repo (Repo, optional): An existing Git repository instance. If provided, uses the given instance.

    Returns:
        str: The contents of the file as a string.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    absolute_path = Path(filename).resolve()

    # Git tree entries are addressed relative to the working tree, with forward slashes.
    # Fallback (file outside the working tree, or no working tree at all): keep the absolute
    # path, exactly like before, and let the tree lookup below report the failure.
    relative_file_path = str(absolute_path)
    if repo.working_dir is not None:
        with contextlib.suppress(ValueError):
            # resolve both sides so symlinks and platform separators cannot break the match
            relative_file_path = absolute_path.relative_to(Path(repo.working_dir).resolve()).as_posix()

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)
def get_file_for_version(filename: str, version: str, prompt_description: str = "", with_git: bool = True) -> str:
    """
    Load the source of a file for the requested version.

    Args:
        filename (str): The path of the file to retrieve.
        version (str): "current" (read from disk), "stdin" (read from standard input),
            or a commit hash/branch name (read from Git).
        prompt_description (str, optional): Shown in the prompt when the code is read from stdin.
        with_git (bool, optional): When False, the file is always read from disk, whatever `version` says.

    Returns:
        str: The contents of the file as a string.

    Raises:
        FileNotFoundError: If the Git lookup for `filename@version` fails.
    """
    if version == "current" or not with_git:
        return Path(filename).read_text()

    if version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        pasted = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return pasted

    # only reachable with with_git enabled: treat the version as a git revision
    try:
        return get_file_for_commit(filename, version)
    except (git.exc.GitError, gitdb.exc.ODBError) as e:
        raise FileNotFoundError(f"{filename}@{version}") from e
247def extract_file_version_and_path(
248 file_path_or_git_tag: Optional[str],
249 default_version: str = "stdin",
250) -> tuple[str, str | None]:
251 """
252 Extract the file version and path from the given input.
254 Args:
255 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag.
256 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin".
258 Returns:
259 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified).
261 Examples:
262 myfile.py (implies @current)
264 myfile.py@latest
265 myfile.py@my-branch
266 myfile.py@b3f24091a9
268 @latest (implies no path, e.g. in case of ALTER to copy previously defined path)
269 """
270 if not file_path_or_git_tag:
271 return default_version, ""
273 if file_path_or_git_tag == "-":
274 return "stdin", "-"
276 if file_path_or_git_tag.startswith("@"):
277 file_version = file_path_or_git_tag.strip("@")
278 file_path = None
279 elif "@" in file_path_or_git_tag:
280 file_path, file_version = file_path_or_git_tag.split("@")
281 else:
282 file_version = default_version # `latest` for before; `current` for after.
283 file_path = file_path_or_git_tag
285 return file_version, file_path
def extract_file_versions_and_paths(
    filename_before: Optional[str],
    filename_after: Optional[str],
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Resolve version and path for both the 'before' and the 'after' file specification.

    Args:
        filename_before (str, optional): Specification of the file before the change (or None).
        filename_after (str, optional): Specification of the file after the change (or None).

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            (version, path) for 'before', followed by (version, path) for 'after'.

    Raises:
        ValueError: If neither specification yields a file path.
    """
    # two different files given: compare them as they are on disk; otherwise compare against the latest commit
    two_distinct_files = filename_after and filename_before and filename_after != filename_before
    before_default = "current" if two_distinct_files else "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=before_default)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    # a side without a path borrows the path of the other side
    filepath_after = filepath_after or filepath_before
    filepath_before = filepath_before or filepath_after

    return (version_before, filepath_before), (version_after, filepath_after)
def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Determine whether a file can be found and what its absolute path is.

    Args:
        filename (str, optional): The path of the file to check (or None).
        version (str): The version specifier; "stdin" short-circuits the lookup.
        git_root (Path, optional): Root of the Git repository. When None it is looked up,
            falling back to the current working directory.

    Returns:
        tuple[bool, str]: (exists, absolute path). The path is "" for stdin and for files that are not found.
    """
    if version == "stdin":
        return True, ""

    if filename is None:
        # can't deal with this, not stdin and no file should show file missing error later.
        return False, ""

    if git_root is None:
        git_root = find_git_root() or Path(os.getcwd())

    # first try the path as given, then relative to the git root
    for candidate in (Path(filename), git_root / filename):
        if candidate.exists():
            return True, str(candidate.resolve())

    return False, ""
def check_indentation(code: str, fix: bool = False) -> str:
    """
    Validate the indentation of a piece of code by parsing it with `ast.parse`.

    Code that parses is returned untouched. When parsing fails with an `IndentationError`,
    the code is either passed through `textwrap.dedent` (`fix=True`) or the error is re-raised.
    The dedented result is not parsed again.

    Args:
        code (str): The code to check.
        fix (bool): Whether to dedent the code when its indentation is rejected. Defaults to False.

    Returns:
        str: The original code, or the dedented code if `fix` is True and the indentation was rejected.

    Raises:
        IndentationError: If the code is incorrectly indented and `fix` is False.
    """
    if not code:
        # no code? perfect indentation!
        return code

    try:
        ast.parse(code)
    except IndentationError:
        if not fix:
            raise
        return textwrap.dedent(code)

    return code
def ensure_no_migrate_on_real_db(
    code: str,
    db_names: typing.Iterable[str] = ("db", "database"),
    fix: bool = False,
) -> str:
    """
    Ensure that the code does not contain actual migrations on a real database.

    It does this by removing definitions of 'db' and database. This can be changed by customizing `db_names`.
    It also removes local imports to prevent irrelevant code being executed.

    Args:
        code (str): The code to check for database migrations.
        db_names (Iterable[str], optional): Names of variables representing the database.
            Defaults to ("db", "database").
        fix (bool, optional): If True, removes the migration code. Defaults to False.

    Returns:
        str: The modified code with migration code removed if fix=True, otherwise the original code.

    Raises:
        ValueError: If `fix` is False and the code defines one of `db_names` or uses local imports.
    """
    # db_names is needed more than once, so materialize it (it may be a one-shot iterable)
    db_names = tuple(db_names)
    variables = find_defined_variables(code)

    # de-duplicated, in the order of db_names, so the error message is deterministic
    found_variables = list(dict.fromkeys(name for name in db_names if name in variables))

    if found_variables and fix:
        # a single call removes every name in db_names
        code = remove_specific_variables(code, db_names)
    elif found_variables:
        if len(found_variables) == 1:
            var = found_variables[0]
            message = f"Variable {var} defined in code! "
        else:  # pragma: no cover
            var = ", ".join(found_variables)
            message = f"Variables {var} defined in code! "
        raise ValueError(
            f"{message} Please remove this or use --magic to prevent performing actual migrations on your database.",
        )

    if has_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code
# Initial value of the retry counter in render_schema_from_code's execute-and-repair loop.
MAX_RETRIES = 30
446# todo: overload more methods
@dataclass
class RenderContext:
    """
    Context passed to migration renderers.

    An instance is built by render_schema_from_code after the generated code was executed,
    and handed to the renderer callback.
    """

    # `db_old` taken from the executed code: the DAL the 'before' code ran against
    db_old: DummyDAL
    # `db_new` taken from the executed code: the DAL the 'after' code ran against
    db_new: DummyDAL
    # names of the tables to render (`_tables` from the executed code)
    tables: list[str]
    # database type; read back from the exec scope, so executed code may have overridden it
    db_type: Optional[str]
    # whether the TypeDAL execution template was used
    use_typedal: bool
    # True when no 'before' code was supplied
    is_create: bool
    # True when 'before' code was supplied (the opposite of is_create)
    is_alter: bool
# Signature of a renderer callback: receives the RenderContext and returns the output text.
Renderer = typing.Callable[[RenderContext], str]
# Skeleton of the code that render_schema_from_code executes for plain pydal.
# It is filled in with string.Template: $tables, $extra, $code_before and $code_after.
# DummyDAL, _uniq, _excl and _special_tables are not defined here; they are injected
# through the globals dict passed to exec().
# 'no-tables-found' is a sentinel message that the caller checks for.
TEMPLATE_EXEC_PYDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *

db = database = DummyDAL(None, migrate=False)

tables = $tables

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')
_tables = tables
    """
# Skeleton of the code that render_schema_from_code executes when TypeDAL is used.
# Same as the pydal template, plus `from typedal import *`.
# It is filled in with string.Template: $tables, $extra, $code_before and $code_after.
# DummyDAL, _uniq, _excl and _special_tables are not defined here; they are injected
# through the globals dict passed to exec().
TEMPLATE_EXEC_TYPEDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *
from typedal import *

db = database = DummyDAL(None, migrate=False)

tables = $tables

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')
_tables = tables
    """
def default_sql_renderer(context: RenderContext) -> str:
    """
    Render plain SQL for every table in the context (the renderer used by handle_cli).

    Each table yields a '-- start' marker line, its SQL (ALTER when the table is in both
    databases, DROP when it is only in the old one, CREATE otherwise) and an
    '-- END OF MIGRATION --' marker line.
    """
    buffer = io.StringIO()

    for table in context.tables:
        print("-- start ", table, "--", file=buffer)

        if table not in context.db_old:
            # brand new table
            print(generate_sql(context.db_new[table], db_type=context.db_type), file=buffer)
        elif table in context.db_new:
            # known on both sides: diff old against new
            print(generate_sql(context.db_old[table], context.db_new[table], db_type=context.db_type), file=buffer)
        else:
            # only known in the old state
            print(f"DROP TABLE {table};", file=buffer)

        print("-- END OF MIGRATION --", file=buffer)

    return buffer.getvalue()
def sql_to_function_name(sql_statement: str, default: Optional[str] = None) -> str:
    """
    Extract action (CREATE, ALTER, DROP) and table name from the SQL statement.

    The statement is lowercased first, so the result is always lowercase.
    An optional `IF [NOT] EXISTS` clause is skipped, and the table name may be unquoted
    or quoted with ', ", ` or [.

    Args:
        sql_statement (str): SQL text; only the first CREATE/ALTER/DROP TABLE in it is used.
        default (str, optional): Name to return if no such statement is found.

    Returns:
        str: "<action>_<table>", or `default` (falling back to "unknown_migration") when nothing matches.
    """
    match = re.search(
        r"(create|alter|drop)\s+table\s+(?:if\s+(?:not\s+)?exists\s+)?[`'\"\[]?(\w+)",
        sql_statement.lower(),
    )

    if match is None:
        return default or "unknown_migration"

    action, table_name = match.groups()

    # Generate a function name with the specified format
    return f"{action}_{table_name}"
def _setup_generic_edwh_migrate(file: Path, is_typedal: bool) -> None:
    """
    (Over)write `file` with the import header of an edwh-migrate migrations module.

    Args:
        file (Path): The migration file to initialize.
        is_typedal (bool): Import TypeDAL when True, pydal's DAL otherwise.
    """
    dal_import = "from typedal import TypeDAL" if is_typedal else "from pydal import DAL"
    header = "from edwh_migrate import migration\n" + dal_import + "\n"

    with file.open("w") as handle:
        handle.write(textwrap.dedent(header))

    rich.print(f"[green] New migrate file {file} created [/green]")
# Matches the "-- start <table> --" marker line that default_sql_renderer writes before each table's SQL.
START_RE = re.compile(r"-- start\s+\w+\s--\n")
def _build_edwh_migration(
    contents: str,
    cls: str,
    date: str,
    existing: Optional[str] = None,
    default_migration_name: Optional[str] = None,
) -> str:
    """
    Wrap one chunk of SQL in an edwh-migrate `@migration` function.

    The function is named `<action>_<table>_<date>_<NNN>`, where the first part comes from
    sql_to_function_name and NNN is a zero-padded counter starting at 001.

    Args:
        contents: SQL of a single migration, possibly still including its '-- start' marker line.
        cls: Class name used to annotate the `db` parameter of the generated function.
        date: Date string that becomes part of the function name.
        existing: Current contents of the migration file (if any), used to detect duplicates.
        default_migration_name: Name prefix to use when no CREATE/ALTER/DROP TABLE is found in the SQL.

    Returns:
        The source of the migration function, or "" when the migration is skipped
        (it already exists, or there is no SQL left).
    """
    sql_func_name = sql_to_function_name(contents, default=default_migration_name)
    func_name = "_placeholder_"
    contents = START_RE.sub("", contents)

    for n in range(1, 1000):
        func_name = f"{sql_func_name}_{date}_{str(n).zfill(3)}"

        if existing and f"def {func_name}" in existing:
            # name is taken; compare the SQL with spaces and newlines removed
            if contents.replace(" ", "").replace("\n", "") in existing.replace(" ", "").replace("\n", ""):
                rich.print(f"[yellow] migration {func_name} already exists, skipping! [/yellow]")
                return ""
            elif func_name.startswith(("alter", default_migration_name or "unknown")):
                # bump number because alter migrations are different
                continue
            else:
                rich.print(
                    f"[red] migration {func_name} already exists "
                    f"[bold]with different contents[/bold], skipping! [/red]",
                )
                return ""
        else:
            # okay function name, stop incrementing
            break

    # 16 spaces: after the dedent below the SQL ends up one level deeper than the executesql call
    contents = textwrap.indent(contents.strip(), " " * 16)

    if not contents.strip():
        # no real migration!
        return ""

    return textwrap.dedent(
        f'''

        @migration
        def {func_name}(db: {cls}):
            db.executesql("""
{contents}
            """)
            db.commit()

            return True
        ''',
    )
def _build_edwh_migrations(
    contents: str,
    is_typedal: bool,
    output: Optional[Path] = None,
    default_migration_name: Optional[str] = None,
) -> str:
    """
    Convert rendered SQL into the source of one edwh-migrate function per migration.

    Args:
        contents: Rendered SQL in which migrations are separated by '-- END OF MIGRATION --'.
        is_typedal: Annotate the generated functions with TypeDAL when True, DAL otherwise.
        output: Migration file; if it exists, its contents are used to detect existing migrations.
        default_migration_name: Fallback name prefix for the generated functions.

    Returns:
        The concatenated source of all generated migration functions.
    """
    cls = "TypeDAL" if is_typedal else "DAL"
    date = datetime.now().strftime("%Y%m%d")  # yyyymmdd

    existing = None
    if output and output.exists():
        existing = output.read_text()

    functions = []
    for chunk in contents.split("-- END OF MIGRATION --"):
        if not chunk.strip():
            # nothing between two markers
            continue
        functions.append(
            _build_edwh_migration(chunk, cls, date, existing, default_migration_name=default_migration_name),
        )

    return "".join(functions)
def _format_and_write_sql_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
    default_migration_name: Optional[str] = None,
) -> bool:
    """
    Format generated migration code and send it to its destination.

    The content is read from `file` (from the start), converted to the requested output format
    and passed to `_write_output`.

    Args:
        file (io.StringIO): The file-like object containing the migration content.
        output_file (Union[Path, str, io.StringIO, None]): Where the result should go.
            None or the string '-' means stdout.
        output_format (edwh-migrate, default): The format in which to output the migration.
            'default', 'sql' and an empty value produce plain SQL.
        is_typedal (bool): Format for TypeDAL instead of pydal. Only relevant for the edwh-migrate format.
        default_migration_name (Optional[str]): Fallback migration name for the edwh-migrate format.

    Returns:
        bool: The result of `_write_output`.

    Raises:
        ValueError: If the provided output_format is unknown.
    """
    file.seek(0)
    contents = file.read()

    if isinstance(output_file, str):
        # `--output-file -` will print to stdout
        output_file = None if output_file == "-" else Path(output_file)

    target_path = output_file if isinstance(output_file, Path) else None
    wants_edwh_migrate = output_format == "edwh-migrate"

    if wants_edwh_migrate:
        contents = _build_edwh_migrations(
            contents,
            is_typedal,
            target_path,
            default_migration_name=default_migration_name,
        )
    elif output_format in {"default", "sql"} or not output_format:
        # plain sql: every end-of-migration marker becomes a newline
        contents = contents.replace("-- END OF MIGRATION --", "\n")
    else:
        raise ValueError(
            f"Unknown format {output_format}. Please choose one of {typing.get_args(_SUPPORTED_OUTPUT_FORMATS)}",
        )

    if target_path is not None and wants_edwh_migrate:
        # a missing or empty migration file first gets its import header
        if not target_path.exists() or target_path.stat().st_size == 0:
            _setup_generic_edwh_migrate(target_path, is_typedal)

    return _write_output(contents, output_file, output_description="migration(s)")
def try_format_and_write_sql_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
    default_migration_name: Optional[str] = None,
) -> bool:
    """
    Non-raising variant of `_format_and_write_sql_output`.

    A ValueError (e.g. an unknown output format) is printed to stderr in yellow and reported as False.
    """
    try:
        written = _format_and_write_sql_output(
            file,
            output_file,
            output_format=output_format,
            is_typedal=is_typedal,
            default_migration_name=default_migration_name,
        )
    except ValueError as e:
        rich.print(f"[yellow]{e}[/yellow]", file=sys.stderr)
        return False
    else:
        return written
728def _write_output(
729 contents: str,
730 output_file: Path | str | io.StringIO | None,
731 output_description: str = "output",
732) -> bool:
733 """
734 Write already-rendered output to destination.
735 """
736 if isinstance(output_file, str):
737 output_file = None if output_file == "-" else Path(output_file)
739 if isinstance(output_file, Path):
740 if contents.strip():
741 with output_file.open("a") as f:
742 f.write(contents)
744 rich.print(f"[green] Written {output_description} to {output_file} [/green]")
745 else:
746 rich.print(f"[yellow] Nothing to write to {output_file} [/yellow]")
747 elif isinstance(output_file, io.StringIO):
748 output_file.write(contents)
749 else:
750 print(contents.strip())
752 return True
def _handle_import_error(code: str, error: ImportError) -> str:
    """
    Find the line of `code` that triggered the ImportError currently being handled.

    Must be called while the exception is being handled: the line number is taken from
    the formatted traceback of the active exception.

    Args:
        code: The source that was executed.
        error: The import error; only used as the cause if no line can be found.

    Returns:
        The text of the offending line in `code`.

    Raises:
        ValueError: If the traceback contains no frame pointing into the executed source.
    """
    # error is deeper in a package, find the related import in the code:
    # e.g. 'File "<string>", line 15, in <module>'
    frame = IMPORT_IN_STR.search(traceback.format_exc())

    if frame is None:
        # I don't know how to trigger this case:
        raise ValueError("Faulty import could not be automatically deleted") from error  # pragma: no cover

    line_index = int(frame.group(1)) - 1  # traceback line numbers start at 1
    return code.split("\n")[line_index]
# Matches the KeyError text for an unresolved table reference; the captured group is the name of the missing table.
MISSING_RELATIONSHIP = re.compile(r"Cannot resolve reference (\w+) in \w+ definition")
def _handle_relation_error(error: KeyError) -> tuple[str, str]:
    """
    Turn an 'unresolved reference' KeyError into a stub definition for the missing table.

    Args:
        error: The KeyError that was raised while executing the table definitions.

    Returns:
        tuple[str, str]: The name of the missing table and a code snippet that defines it (without fields).

    Raises:
        KeyError: The error itself, if it is not about an unresolved reference.
    """
    found = MISSING_RELATIONSHIP.search(str(error))
    if found is None:
        # other error, raise again
        raise error

    t = found.group(1)
    stub = (
        """
        db.define_table('%s', redefine=True)
        """
        % t
    )
    return t, stub
def render_schema_from_code(
    code_after: str,
    output_file: Optional[str | Path | io.StringIO],
    renderer: Renderer,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: bool = False,
    noop: bool = False,
    magic: bool = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool = False,
    code_before: str = "",
    _update_path: bool = True,
) -> bool:
    """
    Execute migration code and render output with a custom renderer callback.

    The 'before' and 'after' code are substituted into an execution template and run with exec().
    When execution fails with a NameError, ImportError or KeyError, the generated code is patched
    and executed again, for at most MAX_RETRIES attempts. A 'no-tables-found' ValueError triggers
    a retry with a call to the configured define-tables function(s) appended.

    Args:
        code_before: Source code representing the initial table definitions.
        code_after: Source code representing the desired table definitions.
        renderer: Callback receiving RenderContext and returning final output text.
        db_type: Optional database type hint.
        tables: Explicit table selection.
        verbose: Print generated execution code to stderr.
        noop: Only print generated execution code, skip execution.
        magic: Automatically inject missing names/import fallbacks.
        function_name: Optional function(s) to call when no top-level tables are found.
        use_typedal: Use the TypeDAL execution template (`True`) or the pydal one (`False`).
        output_file: Output destination (path, StringIO, stdout).
        _update_path: Reserved for compatibility; not used in this function.

    Returns:
        bool: True on success, False on failure.

    Raises:
        IndentationError, ValueError: From the up-front checks of the supplied code when `magic` is off.
    """
    # normalize function_name (None, one name or a tuple of names) to a set of names
    if function_name:
        define_table_functions: set[str] = set(function_name) if isinstance(function_name, tuple) else {function_name}
    else:
        define_table_functions = set()

    template = TEMPLATE_EXEC_TYPEDAL if use_typedal else TEMPLATE_EXEC_PYDAL

    to_execute = string.Template(textwrap.dedent(template))

    # these checks happen before the retry loop, so their errors propagate to the caller
    code_before = check_indentation(code_before, fix=magic)
    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)

    code_after = check_indentation(code_after, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

    # code injected at both $extra placeholders; grows while repairing NameErrors/KeyErrors
    extra_code = ""

    generated_code = to_execute.substitute(
        {
            "tables": flatten(tables or []),
            "code_before": textwrap.dedent(code_before),
            "code_after": textwrap.dedent(code_after),
            "extra": extra_code,
        },
    )
    if verbose or noop:
        rich.print(generated_code, file=sys.stderr)

    if noop:
        # done
        return True

    # last handled exception; also decides whether cwd gets added to sys.path
    err: typing.Optional[Exception] = None
    # globals dict of the exec() call; kept across retries
    catch: dict[str, Any] = {}
    retry_counter = MAX_RETRIES

    # names provided via 'catch', which must not be reported or generated as missing
    magic_vars = {"DummyDAL", "_special_tables", "_uniq", "_excl"}
    # tables the template excludes when it collects the table list itself
    special_tables: set[str] = {"typedal_cache", "typedal_cache_dependency"} if use_typedal else set()

    cwd = os.getcwd()
    while retry_counter:
        if err and cwd not in sys.path:
            # add current path to PYTHONPATH to possibly resolve imports
            sys.path.append(os.getcwd())

        retry_counter -= 1
        try:
            if verbose:
                rich.print(generated_code, file=sys.stderr)

            # 'catch' is used to add and receive globals from the exec scope.
            # another argument could be added for locals, but adding simply {} changes the behavior negatively.
            # so for now, only globals is passed.
            catch["DummyDAL"] = (
                DummyTypeDAL if use_typedal else DummyDAL
            )  # <- use a fake DAL that doesn't actually run queries
            catch["_special_tables"] = special_tables  # <- e.g. typedal_cache, auth_user
            catch["db_type"] = db_type or ""  # allow source code to override db_type during exec
            # note: when adding something to 'catch', also add it to magic_vars!!!

            catch["_uniq"] = uniq  # function to make a list unique without changing order
            catch["_excl"] = excl  # function to exclude items from a list

            exec(generated_code, catch)  # nosec: B102

            # execution succeeded: read the results back from the exec globals
            context = RenderContext(
                db_old=typing.cast(DummyDAL, catch["db_old"]),
                db_new=typing.cast(DummyDAL, catch["db_new"]),
                tables=list(catch["_tables"]),
                db_type=catch.get("db_type", None),
                use_typedal=use_typedal,
                is_create=not bool(code_before.strip()),
                is_alter=bool(code_before.strip()),
            )
            rendered = renderer(context)
            _write_output(rendered, output_file)
            return True  # success!
        except ValueError as e:
            # 'no-tables-found' is the sentinel raised by the template; anything else is a real error
            if str(e) != "no-tables-found":  # pragma: no cover
                rich.print(f"[yellow]{e}[/yellow]", file=sys.stderr)
                return False

            if define_table_functions:
                any_found = False
                # NOTE: this loop rebinds the `function_name` parameter; the messages below
                # therefore show the last name that was tried.
                for function_name in define_table_functions:
                    define_tables = find_function_to_call(generated_code, function_name)

                    # if define_tables function is found, add call to it at end of code
                    if define_tables is not None:
                        generated_code = add_function_call(generated_code, function_name, multiple=True)
                        any_found = True

                if any_found:
                    # hurray!
                    continue

            # else: no define_tables or other method to use found.

            print(f"No tables found in the top-level or {function_name} function!", file=sys.stderr)
            if use_typedal:
                print(
                    "Please use `db.define` or `database.define`, "
                    "or if you really need to use an alias like my_db.define, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            else:
                print(
                    "Please use `db.define_table` or `database.define_table`, "
                    "or if you really need to use an alias like my_db.define_tables, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            print(f"You can also specify a --function to use something else than {function_name}.", file=sys.stderr)

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code) - magic_vars
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            # postponed: this can possibly also be achieved by updating the 'catch' dict
            # instead of injecting in the string.
            extra_code = extra_code + "\n" + textwrap.dedent(generate_magic_code(missing_vars))

            code_before = remove_if_falsey_blocks(code_before)
            code_after = remove_if_falsey_blocks(code_after)

            # rebuild the code with the generated definitions injected, then retry
            generated_code = to_execute.substitute(
                {
                    "tables": flatten(tables or []),
                    "extra": textwrap.dedent(extra_code),
                    "code_before": textwrap.dedent(code_before),
                    "code_after": textwrap.dedent(code_after),
                },
            )
        except ImportError as e:
            # should include ModuleNotFoundError
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            if not e.path:
                # code exists in code itself
                code_before = remove_import(code_before, e.name or "")
                code_after = remove_import(code_after, e.name or "")
            else:
                # look up the failing line of the generated code and blank it in both sources
                to_remove = _handle_import_error(generated_code, e)
                code_before = code_before.replace(to_remove, "\n")
                code_after = code_after.replace(to_remove, "\n")

            generated_code = to_execute.substitute(
                {
                    "tables": flatten(tables or []),
                    "extra": textwrap.dedent(extra_code),
                    "code_before": textwrap.dedent(code_before),
                    "code_after": textwrap.dedent(code_after),
                },
            )

        except KeyError as e:
            err = e
            # a referenced table is missing: define a stub for it (errors that are not
            # about an unresolved reference are re-raised by the helper)
            table_name, table_definition = _handle_relation_error(e)
            # the stub must not show up in the automatically collected table list
            special_tables.add(table_name)
            extra_code = extra_code + "\n" + textwrap.dedent(table_definition)

            generated_code = to_execute.substitute(
                {
                    "tables": flatten(tables or []),
                    "extra": textwrap.dedent(extra_code),
                    "code_before": textwrap.dedent(code_before),
                    "code_after": textwrap.dedent(code_after),
                },
            )
        except Exception as e:
            err = e
            # otherwise: give up
            retry_counter = 0
        finally:
            # reset:
            # NOTE(review): presumably guards against the flag being changed while executing
            # the generated code - confirm.
            typing.TYPE_CHECKING = False

        if retry_counter < 1:  # pragma: no cover
            # out of attempts (or gave up above): report the last error
            rich.print(
                f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'} ({type(err)})",
                file=sys.stderr,
            )
            if state.verbosity > 2:
                traceback.print_tb(err.__traceback__)
            return False

    # idk when this would happen, but something definitely went wrong here:
    return False  # pragma: no cover
def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: bool = False,
    noop: bool = False,
    magic: bool = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool | typing.Literal["auto"] = "auto",
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    _update_path: bool = True,
) -> bool:
    """
    SQL-specific wrapper around render_schema_from_code.

    Renders the SQL into an in-memory buffer with default_sql_renderer and then formats and
    writes that buffer according to `output_format` and `output_file`.
    With `use_typedal="auto"`, TypeDAL usage is detected from the supplied code.

    Returns:
        bool: False if rendering or writing failed, True otherwise (including a successful noop run).
    """
    if use_typedal == "auto":
        is_typedal = detect_typedal(code_before) or detect_typedal(code_after)
    else:
        is_typedal = use_typedal

    # render into memory first; formatting and the real destination are handled afterwards
    raw_output = io.StringIO()
    success = render_schema_from_code(
        code_after,
        code_before=code_before,
        output_file=raw_output,
        renderer=default_sql_renderer,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=function_name,
        use_typedal=is_typedal,
        _update_path=_update_path,
    )

    if not success:
        return False

    if noop:
        # nothing was executed, so there is nothing to write
        return True

    return try_format_and_write_sql_output(
        raw_output,
        output_file,
        output_format=output_format,
        is_typedal=is_typedal,
    )
def core_create(
    filename: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Generate SQL statements for creating the tables defined in a single source file.

    The file's contents are passed to `handle_cli` as the 'after' code, with an empty 'before' code.

    Args:
        filename: The source file to parse; this code represents the final state of the database.
                  May carry a function suffix, e.g. `models.py:define_tables`.
                  If empty, the default version passed to the version/path extraction is "stdin"
                  instead of "current".
        tables: A list of table names to generate SQL for.
                If None, the function will attempt to process all tables found in the code.
        db_type: The type of the database. If None, the function will attempt to infer it from the code.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: Name (or tuple of names) of the function(s) where the tables are defined.
                  Combined with any `:function` suffix found on `filename`.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?
        _update_path: try adding cwd to PYTHONPATH if failing?

    Returns:
        bool: the result of `handle_cli` (True on success, False otherwise).

    Raises:
        FileNotFoundError: If the source file could not be found.
    """
    # without a git repository, paths are resolved against the current working directory
    git_root = find_git_root() or Path(os.getcwd())

    functions: set[str] = set()
    if function:  # pragma: no cover
        # normalise a single name to a 1-tuple so both shapes go through the same update
        functions.update(function if isinstance(function, tuple) else (function,))

    if filename and ":" in filename:
        # e.g. models.py:define_tables
        filename, _, inline_function = filename.partition(":")
        functions.add(inline_function)

    fallback_version = "current" if filename else "stdin"
    file_version, file_path = extract_file_version_and_path(
        filename,
        default_version=fallback_version,
    )
    file_exists, file_absolute_path = get_absolute_path_info(file_path, file_version, git_root)

    if not file_exists:
        raise FileNotFoundError(f"Source file {filename} could not be found.")

    text = get_file_for_version(file_absolute_path, file_version, prompt_description="table definition")

    # 'create' has no previous state, hence the empty code_before
    return handle_cli(
        "",
        text,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
        _update_path=_update_path,
    )
def core_alter(
    filename_before: Optional[str] = None,
    filename_after: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Generates SQL migration statements for altering the database, based on the code in two given source files.

    Args:
        filename_before: The filename of the source file before changes.
                       This code represents the initial state of the database.
                       May carry a function suffix, e.g. `models.py:define_tables`.
        filename_after: The filename of the source file after changes.
                       This code represents the final state of the database.
                       May carry a function suffix, e.g. `models.py:define_tables`.
        tables: A list of table names to generate SQL for.
                If None, the function will attempt to process all tables found in the code.
        db_type: The type of the database. If None, the function will attempt to infer it from the code.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: Name (or tuple of names) of the function(s) where the tables are defined.
                  Combined with any `:function` suffix found on the filenames.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?
        _update_path: try adding cwd to PYTHONPATH if failing? (forwarded to `handle_cli`)

    Returns:
        bool: False if either version of the code is empty, if both versions are identical,
              or if a ValueError is raised while loading them (the reason is printed to stderr);
              otherwise the result of `handle_cli`.

    Raises:
        FileNotFoundError: If either of the source files cannot be found.
    """
    git_root = find_git_root(filename_before) or find_git_root(filename_after)

    functions: set[str] = set()
    if function:  # pragma: no cover
        # accept a single name or a tuple of names, mirroring core_create
        if isinstance(function, tuple):
            functions.update(function)
        else:
            functions.add(function)

    if filename_before and ":" in filename_before:
        # e.g. models.py:define_tables
        filename_before, _function = filename_before.split(":", 1)
        functions.add(_function)

    if filename_after and ":" in filename_after:
        # e.g. models.py:define_tables
        filename_after, _function = filename_after.split(":", 1)
        functions.add(_function)

    before, after = extract_file_versions_and_paths(filename_before, filename_after)

    version_before, filename_before = before
    version_after, filename_after = after

    # either ./file exists or /file exists (seen from git root):

    before_exists, before_absolute_path = get_absolute_path_info(filename_before, version_before, git_root)
    after_exists, after_absolute_path = get_absolute_path_info(filename_after, version_after, git_root)

    if not (before_exists and after_exists):
        message = ""
        if not before_exists:
            message += f"Path {filename_before} does not exist! "
        # Skip the 'after' path only when the very same path was already reported above;
        # otherwise the error could end up with an empty message.
        if not after_exists and (before_exists or filename_before != filename_after):
            message += f"Path {filename_after} does not exist!"
        raise FileNotFoundError(message)

    try:
        code_before = get_file_for_version(
            before_absolute_path,
            version_before,
            prompt_description="current table definition",
            with_git=git_root is not None,
        )
        code_after = get_file_for_version(
            after_absolute_path,
            version_after,
            prompt_description="desired table definition",
            with_git=git_root is not None,
        )

        if not (code_before and code_after):
            message = ""
            message += "" if code_before else "Before code is empty (Maybe try `pydal2sql create`)! "
            message += "" if code_after else "After code is empty! "
            raise ValueError(message)

        if code_before == code_after:
            raise ValueError("Both contain the same code - nothing to alter!")

    except ValueError as e:
        # problems with the input are reported as a warning + False instead of a crash
        rich.print(f"[yellow] {e} [/yellow]", file=sys.stderr)
        return False

    return handle_cli(
        code_before,
        code_after,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
        # was previously dropped, so _update_path=False had no effect for 'alter'
        _update_path=_update_path,
    )
def core_stub(
    migration_name: str,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    dry_run: bool = False,
    is_typedal: bool = False,
) -> bool:
    """
    Generate a dummy (stub) migration.

    The stub consists of a single SQL comment line (`-- <migration_name>`), which is passed through
    `_format_and_write_sql_output` with `migration_name` as the default migration name.
    This function is mostly useful for usage with output_format='edwh-migrate'.

    Args:
        migration_name (str): The name of the migration.
        output_format: The format in which to output the migration.
        output_file: The file to which the migration should be output.
                     If None, the migration will be printed to stdout. Defaults to None.
        dry_run (bool): If True, `output_file` is ignored and the migration is only printed to stdout.
                        Defaults to False.
        is_typedal (bool): If True, the migration will be generated for Typedal.
                           Defaults to False, which will generate the migration for pydal.
                           Only relevant if output_format=edwh-migrate

    Returns:
        bool: the result of `_format_and_write_sql_output`.
    """
    # None will print it to stdout, so a dry run never touches the requested file
    destination = None if dry_run else output_file
    stub_sql = io.StringIO(f"-- {migration_name}\n")

    return _format_and_write_sql_output(
        stub_sql,
        destination,
        output_format=output_format,
        is_typedal=is_typedal,
        default_migration_name=migration_name,
    )