Coverage for src/pydal2sql_core/cli_support.py: 100%
339 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-08-05 17:27 +0200
« prev ^ index » next coverage.py v7.4.4, created at 2024-08-05 17:27 +0200
1"""
2CLI-Agnostic support.
3"""
5import ast
6import contextlib
7import io
8import os
9import re
10import select
11import string
12import sys
13import textwrap
14import traceback
15import typing
16from datetime import datetime
17from pathlib import Path
18from typing import Any, Optional
20import git
21import gitdb.exc
22import rich
23from black.files import find_project_root
24from git.objects.blob import Blob
25from git.objects.commit import Commit
26from git.repo import Repo
27from witchery import (
28 add_function_call,
29 find_defined_variables,
30 find_function_to_call,
31 find_missing_variables,
32 generate_magic_code,
33 has_local_imports,
34 remove_if_falsey_blocks,
35 remove_import,
36 remove_local_imports,
37 remove_specific_variables,
38)
40from .helpers import excl, flatten, uniq
41from .types import (
42 _SUPPORTED_OUTPUT_FORMATS,
43 DEFAULT_OUTPUT_FORMAT,
44 SUPPORTED_DATABASE_TYPES_WITH_ALIASES,
45 SUPPORTED_OUTPUT_FORMATS,
46 DummyDAL,
47 DummyTypeDAL,
48)
def has_stdin_data() -> bool:  # pragma: no cover
    """
    Report whether data is already waiting on standard input (e.g. via a pipe `|` or a redirect `<`).

    A zero-timeout `select` call is used, so this never blocks.

    Returns:
        bool: True when stdin is readable right now, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    readable, _writable, _exceptional = select.select([sys.stdin], [], [], 0.0)
    return any(readable)
AnyCallable = typing.Callable[..., Any]


def print_if_interactive(*args: Any, pretty: bool = True, **kwargs: Any) -> None:  # pragma: no cover
    """
    Print the given arguments, forcing them to stderr when stdin carries no piped/redirected data.

    The arguments are always printed. Only when `has_stdin_data()` reports no pending input
    (an interactive session) is the `file` keyword overridden with `sys.stderr`.

    Args:
        *args: Values to print.
        pretty (bool): Use `rich.print` when True, the built-in `print` otherwise.
        **kwargs: Extra keyword arguments forwarded to the chosen print function.

    Returns:
        None
    """
    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    if not has_stdin_data():
        kwargs["file"] = sys.stderr
    printer(*args, **kwargs)
def find_git_root(at: Optional[str] = None) -> Optional[Path]:
    """
    Find the root directory of the Git repository containing `at`.

    Uses black's `find_project_root`. The folder it finds is only returned when black reports
    that it stopped because of a `.git` directory; any other reason yields None.

    Args:
        at (str, optional): The path to start the search from. Defaults to the current working directory.

    Returns:
        Optional[Path]: The root directory of the Git repository if found, otherwise None.
    """
    folder, reason = find_project_root((at or os.getcwd(),))
    return folder if reason == ".git directory" else None
def find_git_repo(repo: Optional[Repo] = None, at: Optional[str] = None) -> Repo:
    """
    Find the Git repository instance.

    Args:
        repo (Repo, optional): An existing Git repository instance. If provided, it is returned unchanged.
        at (str, optional): The path to start the search from. Defaults to the current working directory.

    Returns:
        Repo: The Git repository instance.

    Raises:
        git.exc.GitError: (GitPython's InvalidGitRepositoryError / NoSuchPathError) when no repository
            can be found for `at`.
    """
    if repo:
        return repo

    root = find_git_root(at)
    if root is not None:
        return Repo(str(root))

    # `find_git_root` returns None when black's project-root detection stops at something other than
    # a `.git` folder (e.g. a pyproject.toml in a sub-directory of the repository). Previously this
    # None was turned into the literal path "None"; let GitPython walk up the tree itself instead.
    return Repo(at or os.getcwd(), search_parent_directories=True)
def latest_commit(repo: Repo = None) -> Commit:
    """
    Return the commit that HEAD currently points at.

    Args:
        repo (Repo, optional): Repository to use; when omitted it is located via `find_git_repo`.

    Returns:
        Commit: The HEAD commit.
    """
    return find_git_repo(repo).head.commit
def commit_by_id(commit_hash: str, repo: Repo = None) -> Commit:
    """
    Resolve a revision (commit hash, branch name, ...) to a commit object.

    Args:
        commit_hash (str): Any revision GitPython's `Repo.commit` accepts, e.g. a hash or branch name.
        repo (Repo, optional): Repository to use; when omitted it is located via `find_git_repo`.

    Returns:
        Commit: The matching commit.
    """
    resolved_repo = find_git_repo(repo)
    found = resolved_repo.commit(commit_hash)
    return found
@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Expose a Git blob's content as an in-memory binary stream.

    The whole blob is read up front and wrapped in a BytesIO.

    Args:
        file (Blob): The blob to read.

    Yields:
        io.BytesIO: A stream positioned at the start of the blob's bytes.
    """
    raw = file.data_stream.read()
    buffer = io.BytesIO(raw)
    yield buffer
def read_blob(file: Blob) -> str:
    """
    Return the content of a Git blob as text.

    Args:
        file (Blob): The blob to read.

    Returns:
        str: The blob bytes decoded with `bytes.decode`'s default codec (UTF-8).
    """
    with open_blob(file) as stream:
        raw_bytes = stream.read()
    return raw_bytes.decode()
def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Optional[Repo] = None) -> str:
    """
    Get the contents of a file in the Git repository at a specific commit version.

    Args:
        filename (str): The path of the file to retrieve.
        commit_version (str, optional): The commit hash or branch name. Defaults to "latest" (HEAD).
        repo (Repo, optional): An existing Git repository instance. If provided, uses the given instance.

    Returns:
        str: The contents of the file as a string.

    Raises:
        KeyError: (from GitPython's tree lookup) when the file does not exist at that commit.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    file_path = Path(filename).resolve()
    try:
        # Git tree paths are relative to the work tree and always use forward slashes.
        # Resolving both sides keeps symlinked locations (e.g. /tmp vs /private/tmp) comparable,
        # which a plain string-prefix removal did not.
        relative_file_path = file_path.relative_to(Path(repo.working_dir).resolve()).as_posix()
    except ValueError:
        # outside the work tree: keep the previous behaviour and let the tree lookup fail
        relative_file_path = str(file_path)

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)
def get_file_for_version(filename: str, version: str, prompt_description: str = "", with_git: bool = True) -> str:
    """
    Get the contents of a file based on the version specified.

    Args:
        filename (str): The path of the file to retrieve.
        version (str): The version specifier, which can be "current", "stdin", or a commit hash/branch name.
        prompt_description (str, optional): A description to display when asking for input from stdin.
        with_git (bool, optional): When False, always read the file from disk, ignoring `version`.

    Returns:
        str: The contents of the file as a string.

    Raises:
        FileNotFoundError: when the file cannot be read at the requested git version.
    """
    if not with_git or version == "current":
        return Path(filename).read_text()

    if version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        result = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return result

    # any other version is a git revision (with_git is necessarily True here)
    try:
        return get_file_for_commit(filename, version)
    except (git.exc.GitError, gitdb.exc.ODBError, KeyError) as e:
        # KeyError: GitPython's tree lookup raises it when the path does not exist at that revision.
        raise FileNotFoundError(f"{filename}@{version}") from e
242def extract_file_version_and_path(
243 file_path_or_git_tag: Optional[str], default_version: str = "stdin"
244) -> tuple[str, str | None]:
245 """
246 Extract the file version and path from the given input.
248 Args:
249 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag.
250 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin".
252 Returns:
253 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified).
255 Examples:
256 myfile.py (implies @current)
258 myfile.py@latest
259 myfile.py@my-branch
260 myfile.py@b3f24091a9
262 @latest (implies no path, e.g. in case of ALTER to copy previously defined path)
263 """
264 if not file_path_or_git_tag:
265 return default_version, ""
267 if file_path_or_git_tag == "-":
268 return "stdin", "-"
270 if file_path_or_git_tag.startswith("@"):
271 file_version = file_path_or_git_tag.strip("@")
272 file_path = None
273 elif "@" in file_path_or_git_tag:
274 file_path, file_version = file_path_or_git_tag.split("@")
275 else:
276 file_version = default_version # `latest` for before; `current` for after.
277 file_path = file_path_or_git_tag
279 return file_version, file_path
def extract_file_versions_and_paths(
    filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Work out (version, path) pairs for the "before" and "after" side of an ALTER.

    When two different file names are given, the "before" side defaults to the working-tree
    version ("current"); otherwise it defaults to "latest". The "after" side always defaults
    to "current". A side without a path borrows the path of the other side.

    Args:
        filename_before (str, optional): Spec of the old state (path and/or @version), or None.
        filename_after (str, optional): Spec of the new state (path and/or @version), or None.

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            ((version_before, path_before), (version_after, path_after)).

    Raises:
        ValueError: when neither side yields a file path.
    """
    distinct_files = bool(filename_before and filename_after and filename_after != filename_before)
    before_default = "current" if distinct_files else "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=before_default)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    # whichever side is missing a path copies it from the other side
    filepath_before = filepath_before or filepath_after
    filepath_after = filepath_after or filepath_before

    return (version_before, filepath_before), (version_after, filepath_after)
def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Locate `filename` either as given or relative to the Git root.

    Args:
        filename (str, optional): The file to look for (or None).
        version (str): "stdin", "current", or a commit hash/branch name. "stdin" needs no file.
        git_root (Path, optional): Repository root; when None it is detected, falling back to the cwd.

    Returns:
        tuple[bool, str]: (exists, absolute path). The path is "" for stdin or when nothing was found.
    """
    if version == "stdin":
        return True, ""
    if filename is None:
        # can't deal with this, not stdin and no file should show file missing error later.
        return False, ""

    root = git_root if git_root is not None else (find_git_root() or Path(os.getcwd()))

    # first the path as given, then the same path relative to the repository root
    for candidate in (Path(filename), root / filename):
        if candidate.exists():
            return True, str(candidate.resolve())

    return False, ""
def check_indentation(code: str, fix: bool = False) -> str:
    """
    Verify that `code` parses without an IndentationError, optionally dedenting it if it does not.

    Empty code is returned as is. Other syntax errors raised by `ast.parse` are not handled here.

    Args:
        code (str): The code to check.
        fix (bool): When True, return `textwrap.dedent(code)` instead of raising. Defaults to False.

    Returns:
        str: `code` unchanged when it parses, or its dedented form when it is mis-indented and `fix` is True.

    Raises:
        IndentationError: If the code is incorrectly indented and `fix` is False.
    """
    if not code:
        # no code? perfect indentation!
        return code

    try:
        ast.parse(code)
    except IndentationError:
        if not fix:
            raise
        return textwrap.dedent(code)

    return code
def ensure_no_migrate_on_real_db(
    code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: bool = False
) -> str:
    """
    Ensure that the code does not contain actual migrations on a real database.

    It does this by removing definitions of 'db' and 'database'. This can be changed by customizing `db_names`.
    It also removes local imports to prevent irrelevant code being executed.

    Args:
        code (str): The code to check for database migrations.
        db_names (Iterable[str], optional): Names of variables representing the database.
            Defaults to ("db", "database"). Any iterable (including a generator) is accepted.
        fix (bool, optional): If True, removes the offending code instead of raising. Defaults to False.

    Returns:
        str: The (possibly cleaned) code.

    Raises:
        ValueError: If `fix` is False and one of `db_names` is defined, or local imports are used.
    """
    # Materialize once: `db_names` is iterated here AND handed to `remove_specific_variables`,
    # which would receive an already-exhausted iterator if a generator was passed.
    db_names = tuple(db_names)
    variables = find_defined_variables(code)

    # dict.fromkeys: unique names, in `db_names` order, so the error message is deterministic
    found_variables = list(dict.fromkeys(db_name for db_name in db_names if db_name in variables))

    if found_variables:
        if fix:
            # a single call with all of `db_names`, as each of the previous per-name calls did
            code = remove_specific_variables(code, db_names)
        else:
            if len(found_variables) == 1:
                message = f"Variable {found_variables[0]} defined in code! "
            else:  # pragma: no cover
                var = ", ".join(found_variables)
                message = f"Variables {var} defined in code! "
            raise ValueError(
                f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
            )

    if has_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code
# Maximum number of generate-and-execute attempts `handle_cli` makes while trying to
# automatically repair the generated migration script (it counts down `retry_counter`).
MAX_RETRIES = 30

# todo: overload more methods

# Migration script for plain pydal code. `handle_cli` fills in $tables, $db_type, $extra,
# $code_before and $code_after via `string.Template`, then executes it.
# DummyDAL, _uniq, _excl, _special_tables and _file are not defined here: `handle_cli` injects
# them into the exec globals. For each table, the SQL is printed to `_file` after a
# '-- start <table> --' line and followed by an '-- END OF MIGRATION --' marker, which the
# output handling later splits on.
TEMPLATE_PYDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *

from pydal2sql_core import generate_sql


# from pydal import DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')

for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)
    print('-- END OF MIGRATION --', file=_file)
    """

# Same script as TEMPLATE_PYDAL, but it additionally star-imports `typedal`.
TEMPLATE_TYPEDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *
from typedal import *

from pydal2sql_core import generate_sql


# from typedal import TypeDAL as DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = _uniq(db_old._tables + db_new._tables)
    tables = _excl(tables, _special_tables)

if not tables:
    raise ValueError('no-tables-found')


for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)

    print('-- END OF MIGRATION --', file=_file)
    """
# First CREATE/ALTER/DROP TABLE statement. The name may be quoted with ', " or ` (MySQL),
# and may be preceded by IF [NOT] EXISTS.
SQL_ACTION_RE = re.compile(
    r"(CREATE|ALTER|DROP)\s+TABLE\s+(?:IF\s+(?:NOT\s+)?EXISTS\s+)?[`'\"]?(\w+)[`'\"]?",
    re.IGNORECASE,
)


def sql_to_function_name(sql_statement: str, default: Optional[str] = None) -> str:
    """
    Build a migration function name such as ``create_users`` from SQL.

    The action (CREATE, ALTER, DROP) and table name of the first matching TABLE statement are
    combined and lower-cased.

    Args:
        sql_statement (str): SQL text to inspect.
        default (str, optional): Name to return when no TABLE statement is found.

    Returns:
        str: "<action>_<table>", or `default` (falling back to "unknown_migration") when nothing matches.
    """
    match = SQL_ACTION_RE.search(sql_statement)
    if match is None:
        # raise ValueError("Invalid SQL statement. Unable to extract action and table name.")
        return default or "unknown_migration"

    action, table_name = match.groups()
    return f"{action}_{table_name}".lower()
def _setup_generic_edwh_migrate(file: Path, is_typedal: bool) -> None:
    """
    Create (or overwrite) `file` with the import header of an edwh-migrate migrations file.

    Args:
        file (Path): The migrations file to (re)create.
        is_typedal (bool): Import TypeDAL instead of pydal's DAL.
    """
    dal_import = "from typedal import TypeDAL" if is_typedal else "from pydal import DAL"
    header = f"from edwh_migrate import migration\n{dal_import}\n"

    with file.open("w") as f:
        f.write(header)

    rich.print(f"[green] New migrate file {file} created [/green]")
# Matches the '-- start <table> --' header line printed by the TEMPLATE_* scripts
# (print() separates '-- start ', table and '--' with spaces).
START_RE = re.compile(r"-- start\s+\w+\s--\n")


def _build_edwh_migration(
    contents: str, cls: str, date: str, existing: Optional[str] = None, default_migration_name: Optional[str] = None
) -> str:
    """
    Render one edwh-migrate `@migration` function for a single chunk of generated SQL.

    Args:
        contents: SQL for one migration chunk; its '-- start <table> --' header line is stripped here.
        cls: Type name used in the generated `db` annotation (e.g. "DAL" or "TypeDAL").
        date: Date string (yyyymmdd) embedded in the function name.
        existing: Current contents of the output file, used to detect already-present functions.
        default_migration_name: Name prefix passed to `sql_to_function_name` when no TABLE statement is found.

    Returns:
        str: Python source of the migration function, or "" when it is skipped (same SQL already
        present, name taken with different SQL, or no SQL at all).
    """
    sql_func_name = sql_to_function_name(contents, default=default_migration_name)
    func_name = "_placeholder_"
    contents = START_RE.sub("", contents)

    # try suffixes 001..999 until a function name is found that is not in `existing` yet
    for n in range(1, 1000):
        func_name = f"{sql_func_name}_{date}_{str(n).zfill(3)}"

        if existing and f"def {func_name}" in existing:
            # whitespace-insensitive check whether this exact SQL was already written
            if contents.replace(" ", "").replace("\n", "") in existing.replace(" ", "").replace("\n", ""):
                rich.print(f"[yellow] migration {func_name} already exists, skipping! [/yellow]")
                return ""
            elif func_name.startswith(("alter", default_migration_name or "unknown")):
                # bump number because alter migrations are different
                continue
            else:
                rich.print(
                    f"[red] migration {func_name} already exists [bold]with different contents[/bold], skipping! [/red]"
                )
                return ""
        else:
            # okay function name, stop incrementing
            break
    # NOTE(review): if all 999 numbers are taken, `func_name` ends on an already-existing name
    # and a duplicate definition would be generated.

    # 16 spaces: after the dedent below (8-space margin) the SQL sits 4 columns inside executesql
    contents = textwrap.indent(contents.strip(), " " * 16)

    if not contents.strip():
        # no real migration!
        return ""

    return textwrap.dedent(
        f'''

        @migration
        def {func_name}(db: {cls}):
            db.executesql("""
{contents}
            """)
            db.commit()

            return True
        '''
    )
def _build_edwh_migrations(
    contents: str, is_typedal: bool, output: Optional[Path] = None, default_migration_name: Optional[str] = None
) -> str:
    """
    Turn the generated SQL into edwh-migrate functions, one per '-- END OF MIGRATION --' chunk.

    Args:
        contents (str): Output of the executed migration template.
        is_typedal (bool): Annotate `db` as TypeDAL instead of DAL.
        output (Path, optional): Target file; when it exists its text is used to avoid duplicates.
        default_migration_name (str, optional): Name used for chunks without a recognizable TABLE statement.

    Returns:
        str: The concatenated migration functions (empty chunks are skipped).
    """
    dal_class = "TypeDAL" if is_typedal else "DAL"
    today = datetime.now().strftime("%Y%m%d")  # yyyymmdd
    existing_code = output.read_text() if output and output.exists() else None

    chunks = [chunk for chunk in contents.split("-- END OF MIGRATION --") if chunk.strip()]
    rendered = [
        _build_edwh_migration(chunk, dal_class, today, existing_code, default_migration_name=default_migration_name)
        for chunk in chunks
    ]
    return "".join(rendered)
622def _handle_output(
623 file: io.StringIO,
624 output_file: Path | str | io.StringIO | None,
625 output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
626 is_typedal: bool = False,
627 default_migration_name: Optional[str] = None,
628) -> None:
629 file.seek(0)
630 contents = file.read()
632 if isinstance(output_file, str):
633 # `--output-file -` will print to stdout
634 output_file = None if output_file == "-" else Path(output_file)
636 if output_format == "edwh-migrate":
637 contents = _build_edwh_migrations(
638 contents,
639 is_typedal,
640 output_file if isinstance(output_file, Path) else None,
641 default_migration_name=default_migration_name,
642 )
643 elif output_format in {"default", "sql"} or not output_format:
644 contents = "\n".join(contents.split("-- END OF MIGRATION --"))
645 else:
646 raise ValueError(
647 f"Unknown format {output_format}. " f"Please choose one of {typing.get_args(_SUPPORTED_OUTPUT_FORMATS)}"
648 )
650 if isinstance(output_file, Path):
651 if output_format == "edwh-migrate" and (not output_file.exists() or output_file.stat().st_size == 0):
652 _setup_generic_edwh_migrate(output_file, is_typedal)
654 if contents.strip():
655 with output_file.open("a") as f:
656 f.write(contents)
658 rich.print(f"[green] Written migration(s) to {output_file} [/green]")
659 else:
660 rich.print(f"[yellow] Nothing to write to {output_file} [/yellow]")
662 elif isinstance(output_file, io.StringIO):
663 output_file.write(contents)
664 else:
665 # no file, just print to stdout:
666 print(contents.strip())
669IMPORT_IN_STR = re.compile(r'File "<string>", line (\d+), in <module>')
672def _handle_import_error(code: str, error: ImportError) -> str:
673 # error is deeper in a package, find the related import in the code:
674 tb_lines = traceback.format_exc().splitlines()
676 for line in tb_lines:
677 if matches := IMPORT_IN_STR.findall(line):
678 # 'File "<string>", line 15, in <module>'
679 line_no = int(matches[0]) - 1
680 lines = code.split("\n")
681 return lines[line_no]
683 # I don't know how to trigger this case:
684 raise ValueError("Faulty import could not be automatically deleted") from error # pragma: no cover
687MISSING_RELATIONSHIP = re.compile(r"Cannot resolve reference (\w+) in \w+ definition")
690def _handle_relation_error(error: KeyError) -> tuple[str, str]:
691 if not (table := MISSING_RELATIONSHIP.findall(str(error))):
692 # other error, raise again
693 raise error
695 t = table[0]
697 return (
698 t,
699 """
700 db.define_table('%s', redefine=True)
701 """
702 % t,
703 )
def _render_migration_code(
    template: string.Template,
    tables: Optional[list[str] | list[list[str]]],
    db_type: Optional[str],
    extra_code: str,
    code_before: str,
    code_after: str,
) -> str:
    """
    Fill a TEMPLATE_* migration script with the user's code and settings.

    All code fragments are dedented; `tables` is flattened and inserted through its `str()` form.
    """
    return template.substitute(
        {
            "tables": flatten(tables or []),
            "db_type": db_type or "",
            "extra": textwrap.dedent(extra_code),
            "code_before": textwrap.dedent(code_before),
            "code_after": textwrap.dedent(code_after),
        }
    )


def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: bool = False,
    noop: bool = False,
    magic: bool = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool | typing.Literal["auto"] = "auto",
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    _update_path: bool = True,
) -> bool:
    """
    Handle user input for generating SQL migration statements based on before and after code.

    The code is pasted into a migration script template that is executed against a dummy DAL.
    On failure, up to `MAX_RETRIES` attempts are made to repair the script:
    - calling the table-defining function(s)
    - adding code for missing names via `generate_magic_code` (only with `magic`)
    - removing failing imports
    - defining placeholder tables for unresolved references

    Args:
        code_before (str): The code representing the state of the database before the change.
        code_after (str): The code representing the state of the database after the change.
        db_type (str, optional): The type of the database (e.g., "postgres", "mysql", etc.). Defaults to None.
        tables (list[str] or list[list[str]], optional): The list of tables to generate SQL for. Defaults to None.
        verbose (bool, optional): If True, print the generated code before every attempt. Defaults to False.
        noop (bool, optional): If True, only print the generated code but do not execute it. Defaults to False.
        magic (bool, optional): If True, automatically add missing variables for execution. Defaults to False.
        function_name (str or tuple[str, ...], optional): Name(s) of the function(s) where the tables are
            defined. Defaults to "define_tables".
        use_typedal: replace pydal imports with TypeDAL? "auto" checks whether "typedal" occurs in the code.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?
        _update_path: add the current working directory to `sys.path` after a failed attempt,
            which may help resolve local imports. Pass False to leave `sys.path` untouched.

    # todo: prefix (e.g. public.)

    Returns:
        bool: True if SQL migration statements are generated successfully (or printed with `noop`),
        False otherwise.
    """
    # todo: better typedal checking
    if use_typedal == "auto":
        use_typedal = "typedal" in code_before.lower() or "typedal" in code_after.lower()

    if function_name:
        define_table_functions: set[str] = set(function_name) if isinstance(function_name, tuple) else {function_name}
    else:
        define_table_functions = set()

    # Used in error messages. Previously the loop variable below shadowed `function_name`, so only an
    # arbitrary one of several functions was mentioned.
    function_description = ", ".join(sorted(define_table_functions)) if define_table_functions else str(function_name)

    template = TEMPLATE_TYPEDAL if use_typedal else TEMPLATE_PYDAL
    to_execute = string.Template(textwrap.dedent(template))

    code_before = check_indentation(code_before, fix=magic)
    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)

    code_after = check_indentation(code_after, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

    extra_code = ""
    generated_code = _render_migration_code(to_execute, tables, db_type, extra_code, code_before, code_after)

    if noop:
        # done: only show what would be executed. (`verbose` alone prints once per attempt below;
        # previously the first attempt was printed twice.)
        rich.print(generated_code, file=sys.stderr)
        return True

    err: typing.Optional[Exception] = None
    catch: dict[str, Any] = {}
    retry_counter = MAX_RETRIES

    magic_vars = {"_file", "DummyDAL", "_special_tables", "_uniq", "_excl"}
    special_tables: set[str] = {"typedal_cache", "typedal_cache_dependency"} if use_typedal else set()

    cwd = os.getcwd()
    while retry_counter:
        if _update_path and err and cwd not in sys.path:
            # add current path to PYTHONPATH to possibly resolve imports
            sys.path.append(cwd)

        retry_counter -= 1
        try:
            if verbose:
                rich.print(generated_code, file=sys.stderr)

            # 'catch' is used to add and receive globals from the exec scope.
            # another argument could be added for locals, but adding simply {} changes the behavior negatively.
            # so for now, only globals is passed.
            catch["_file"] = io.StringIO()  # <- every print should go to this file, so we can handle it afterwards
            catch["DummyDAL"] = DummyTypeDAL if use_typedal else DummyDAL  # <- fake DAL that doesn't run queries
            catch["_special_tables"] = special_tables  # <- e.g. typedal_cache, auth_user
            # note: when adding something to 'catch', also add it to magic_vars!!!

            catch["_uniq"] = uniq  # function to make a list unique without changing order
            catch["_excl"] = excl  # function to exclude items from a list

            exec(generated_code, catch)  # nosec: B102
            _handle_output(catch["_file"], output_file, output_format, is_typedal=use_typedal)
            return True  # success!
        except ValueError as e:
            if str(e) != "no-tables-found":  # pragma: no cover
                rich.print(f"[yellow]{e}[/yellow]", file=sys.stderr)
                return False

            if define_table_functions:
                any_found = False
                for define_function in define_table_functions:
                    # if the function is found, add a call to it at the end of the code
                    if find_function_to_call(generated_code, define_function) is not None:
                        generated_code = add_function_call(generated_code, define_function, multiple=True)
                        any_found = True

                if any_found:
                    # hurray! retry with the added call(s)
                    continue

            # else: no define_tables or other method to use found.
            print(f"No tables found in the top-level or {function_description} function!", file=sys.stderr)
            if use_typedal:
                print(
                    "Please use `db.define` or `database.define`, "
                    "or if you really need to use an alias like my_db.define, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            else:
                print(
                    "Please use `db.define_table` or `database.define_table`, "
                    "or if you really need to use an alias like my_db.define_tables, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            print(
                f"You can also specify a --function to use something else than {function_description}.",
                file=sys.stderr,
            )

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code) - magic_vars
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            # postponed: this can possibly also be achieved by updating the 'catch' dict
            # instead of injecting in the string.
            extra_code = extra_code + "\n" + textwrap.dedent(generate_magic_code(missing_vars))

            code_before = remove_if_falsey_blocks(code_before)
            code_after = remove_if_falsey_blocks(code_after)

            generated_code = _render_migration_code(to_execute, tables, db_type, extra_code, code_before, code_after)
        except ImportError as e:
            # should include ModuleNotFoundError
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            if not e.path:
                # code exists in code itself
                code_before = remove_import(code_before, e.name or "")
                code_after = remove_import(code_after, e.name or "")
            else:
                to_remove = _handle_import_error(generated_code, e)
                code_before = code_before.replace(to_remove, "\n")
                code_after = code_after.replace(to_remove, "\n")

            generated_code = _render_migration_code(to_execute, tables, db_type, extra_code, code_before, code_after)
        except KeyError as e:
            err = e
            # unresolved reference: define an empty placeholder table (re-raises other KeyErrors)
            table_name, table_definition = _handle_relation_error(e)
            special_tables.add(table_name)
            extra_code = extra_code + "\n" + textwrap.dedent(table_definition)

            generated_code = _render_migration_code(to_execute, tables, db_type, extra_code, code_before, code_after)
        except Exception as e:
            err = e
            # otherwise: give up
            retry_counter = 0
        finally:
            # reset:
            typing.TYPE_CHECKING = False

        if retry_counter < 1:  # pragma: no cover
            rich.print(
                f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'} ({type(err)})", file=sys.stderr
            )
            return False

    # idk when this would happen, but something definitely went wrong here:
    return False  # pragma: no cover
def core_create(
    filename: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Generates SQL migration statements for creating one or more tables, based on the code in a given source file.

    Args:
        filename: The filename of the source file to parse. This code represents the final state of the database.
            May carry a version (`models.py@latest`) and/or a function (`models.py:define_tables`).
        tables: A list of table names to generate SQL for.
                If None, the function will attempt to process all tables found in the code.
        db_type: The type of the database. If None, the function will attempt to infer it from the code.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: The name of the function where the tables are defined.
                  If None, the function will use 'define_tables'.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?
        _update_path: try adding cwd to PYTHONPATH if failing?

    Returns:
        bool: True if SQL migration statements are generated and (if not in noop mode) executed successfully,
              False otherwise.

    Raises:
        FileNotFoundError: If the source file cannot be found.
        ValueError: From the pre-execution checks, e.g. when `db`/`database` is defined in the code
            and `magic` is off.
    """
    git_root = find_git_root() or Path(os.getcwd())

    functions: set[str] = set()
    if function:  # pragma: no cover
        if isinstance(function, tuple):
            functions.update(function)
        else:
            functions.add(function)

    if filename and ":" in filename:
        # e.g. models.py:define_tables
        # Split on the LAST colon, and only when the suffix looks like a (dotted) function name, so
        # Windows paths like `C:\project\models.py` are not split into `C` + `\project\models.py`.
        path_part, _, function_part = filename.rpartition(":")
        if not function_part or all(part.isidentifier() for part in function_part.split(".")):
            filename = path_part
            if function_part:
                functions.add(function_part)

    file_version, file_path = extract_file_version_and_path(
        filename, default_version="current" if filename else "stdin"
    )
    file_exists, file_absolute_path = get_absolute_path_info(file_path, file_version, git_root)

    if not file_exists:
        raise FileNotFoundError(f"Source file {filename} could not be found.")

    text = get_file_for_version(file_absolute_path, file_version, prompt_description="table definition")

    return handle_cli(
        "",
        text,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
        _update_path=_update_path,
    )
def core_alter(
    filename_before: Optional[str] = None,
    filename_after: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
    _update_path: bool = True,
) -> bool:
    """
    Generates SQL migration statements for altering the database, based on the code in two given source files.

    Args:
        filename_before: The filename of the source file before changes.
            This code represents the initial state of the database.
            A `:function` suffix (e.g. `models.py:define_tables`) adds that function name to the
            functions searched for table definitions.
        filename_after: The filename of the source file after changes.
            This code represents the final state of the database.
            Accepts the same `:function` suffix as `filename_before`.
        tables: A list of table names to generate SQL for.
            If None, the function will attempt to process all tables found in the code.
        db_type: The type of the database. If None, the function will attempt to infer it from the code.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: The name of the function (or a tuple of names) where the tables are defined.
            If None, the function will use 'define_tables'.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?
        _update_path: try adding cwd to PYTHONPATH if failing? (forwarded to `handle_cli`)

    Returns:
        bool: True if SQL migration statements are generated and (if not in noop mode) executed successfully,
        False otherwise. Also False (with a warning on stderr) when either source is empty or when
        both sources contain identical code.

    Raises:
        FileNotFoundError: If either of the source files cannot be found.
        Any exception raised by `handle_cli` propagates unchanged.
    """
    git_root = find_git_root(filename_before) or find_git_root(filename_after)

    functions: set[str] = set()
    if function:  # pragma: no cover
        # accept either a single function name or a tuple of names
        if isinstance(function, tuple):
            functions.update(function)
        else:
            functions.add(function)

    if filename_before and ":" in filename_before:
        # e.g. models.py:define_tables
        filename_before, _function = filename_before.split(":", 1)
        functions.add(_function)

    if filename_after and ":" in filename_after:
        # e.g. models.py:define_tables
        filename_after, _function = filename_after.split(":", 1)
        functions.add(_function)

    before, after = extract_file_versions_and_paths(filename_before, filename_after)

    version_before, filename_before = before
    version_after, filename_after = after

    # either ./file exists or /file exists (seen from git root):
    before_exists, before_absolute_path = get_absolute_path_info(filename_before, version_before, git_root)
    after_exists, after_absolute_path = get_absolute_path_info(filename_after, version_after, git_root)

    if not (before_exists and after_exists):
        # Report each missing path exactly once. The same path may appear on both sides
        # (two versions of one file), so only skip the 'after' entry when it was
        # already reported as the missing 'before' path. Otherwise, a missing 'after'
        # version of an existing 'before' file would produce an empty message.
        missing = []
        if not before_exists:
            missing.append(filename_before)
        if not after_exists and (before_exists or filename_after != filename_before):
            missing.append(filename_after)
        raise FileNotFoundError(" ".join(f"Path {path} does not exist!" for path in missing))

    try:
        code_before = get_file_for_version(
            before_absolute_path,
            version_before,
            prompt_description="current table definition",
            with_git=git_root is not None,
        )
        code_after = get_file_for_version(
            after_absolute_path,
            version_after,
            prompt_description="desired table definition",
            with_git=git_root is not None,
        )

        if not (code_before and code_after):
            message = ""
            message += "" if code_before else "Before code is empty (Maybe try `pydal2sql create`)! "
            message += "" if code_after else "After code is empty! "
            raise ValueError(message)

        if code_before == code_after:
            raise ValueError("Both contain the same code - nothing to alter!")

    except ValueError as e:
        # Empty or identical inputs are user-facing warnings, not crashes.
        rich.print(f"[yellow] {e} [/yellow]", file=sys.stderr)
        return False

    return handle_cli(
        code_before,
        code_after,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
        _update_path=_update_path,
    )
def core_stub(
    migration_name: str,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
    dry_run: bool = False,
    is_typedal: bool = False,
):
    """
    Generate a dummy migration whose only content is a `-- <migration_name>` SQL comment line.

    (mostly useful for usage with output_format=edwh-migrate)

    Args:
        migration_name: Text placed in the placeholder comment. It is also passed to
            `_handle_output` as `default_migration_name`.
        output_format: Output format forwarded to `_handle_output`.
        output_file: Destination forwarded to `_handle_output`. It is ignored when `dry_run` is True.
        dry_run: If True, `None` is passed as the destination instead of `output_file`.
        is_typedal: Forwarded to `_handle_output`.
    """
    # A dry run drops the requested destination; None will print it to stdout.
    destination = None if dry_run else output_file
    placeholder_sql = io.StringIO(f"-- {migration_name}\n")

    _handle_output(
        placeholder_sql,
        destination,
        output_format=output_format,
        is_typedal=is_typedal,
        default_migration_name=migration_name,
    )