Coverage for src/pydal2sql_core/cli_support.py: 100%
301 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-11-17 17:33 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2023-11-17 17:33 +0100
1"""
2CLI-Agnostic support.
3"""
4import contextlib
5import io
6import os
7import re
8import select
9import string
10import sys
11import textwrap
12import traceback
13import typing
14import warnings
15from datetime import datetime
16from pathlib import Path
17from typing import Any, Optional
19import git
20import gitdb.exc
21import rich
22from black.files import find_project_root
23from git.objects.blob import Blob
24from git.objects.commit import Commit
25from git.repo import Repo
26from witchery import (
27 add_function_call,
28 find_defined_variables,
29 find_function_to_call,
30 find_missing_variables,
31 generate_magic_code,
32 has_local_imports,
33 remove_if_falsey_blocks,
34 remove_import,
35 remove_local_imports,
36 remove_specific_variables,
37)
39from .helpers import flatten
40from .types import (
41 _SUPPORTED_OUTPUT_FORMATS,
42 DEFAULT_OUTPUT_FORMAT,
43 SUPPORTED_DATABASE_TYPES_WITH_ALIASES,
44 SUPPORTED_OUTPUT_FORMATS,
45 DummyDAL,
46 DummyTypeDAL,
47)
def has_stdin_data() -> bool:  # pragma: no cover
    """
    Tell whether data is already waiting on stdin (pipe | or redirect <).

    stdin is polled with a zero timeout, so this call does not wait for input.

    Returns:
        bool: True if stdin is reported as readable right away, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    readable, _writable, _exceptional = select.select([sys.stdin], [], [], 0.0)
    return any(readable)
# Loose alias for "any callable": arbitrary arguments, arbitrary return value.
AnyCallable = typing.Callable[..., typing.Any]
def print_if_interactive(*args: Any, pretty: bool = True, **kwargs: Any) -> None:  # pragma: no cover
    """
    Print to stderr, but only when no data is waiting on stdin.

    Args:
        *args: Positional arguments handed to the print function.
        pretty (bool): Use rich.print when True, the built-in print function otherwise.
        **kwargs: Keyword arguments handed to the print function; `file` is always replaced by sys.stderr.

    Returns:
        None
    """
    if has_stdin_data():
        # input is piped or redirected in, so this is not an interactive session: stay silent.
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    kwargs["file"] = sys.stderr
    printer(*args, **kwargs)
def find_git_root(at: str = None) -> Optional[Path]:
    """
    Locate the top-level folder of the Git repository.

    Args:
        at (str, optional): Path to start searching from. Falls back to the current working directory.

    Returns:
        Optional[Path]: The folder reported by `find_project_root`, but only if it was identified by
            its ".git directory"; None for any other reason.
    """
    start = at or os.getcwd()
    project_root, found_by = find_project_root((start,))
    return project_root if found_by == ".git directory" else None
def find_git_repo(repo: Repo = None, at: str = None) -> Repo:
    """
    Find the Git repository instance.

    Args:
        repo (Repo, optional): An existing Git repository instance. If provided, returns the same instance.
        at (str, optional): The directory path to start the search. Defaults to the current working directory.

    Returns:
        Repo: The Git repository instance.

    Raises:
        git.exc.NoSuchPathError: If no Git repository root could be found from the starting path.
    """
    if repo:
        return repo

    root = find_git_root(at)
    if root is None:
        # Without this guard `Repo(str(None))` would be attempted, i.e. a lookup of a folder literally
        # named "None". Raise the error for a missing path directly, with a message that says why.
        raise git.exc.NoSuchPathError(f"No Git repository found at or above {at or os.getcwd()}")

    return Repo(str(root))
def latest_commit(repo: Repo = None) -> Commit:
    """
    Return the commit that HEAD of the repository points at.

    Args:
        repo (Repo, optional): Repository to inspect; when omitted it is looked up via `find_git_repo`.

    Returns:
        Commit: The head commit of the repository.
    """
    return find_git_repo(repo).head.commit
def commit_by_id(commit_hash: str, repo: Repo = None) -> Commit:
    """
    Look up one commit of the repository by hash or name.

    Args:
        commit_hash (str): Hash of the wanted commit. Can also be e.g. a branch name.
        repo (Repo, optional): Repository to inspect; when omitted it is looked up via `find_git_repo`.

    Returns:
        Commit: The commit the given identifier resolves to.
    """
    return find_git_repo(repo).commit(commit_hash)
@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Context manager exposing the data of a Git Blob as an in-memory binary stream.

    The blob's data stream is read completely before the stream is handed out.

    Args:
        file (Blob): The Git Blob object to open.

    Yields:
        io.BytesIO: A BytesIO object holding the Blob data.
    """
    raw = file.data_stream.read()
    buffer = io.BytesIO(raw)
    yield buffer
def read_blob(file: Blob) -> str:
    """
    Return the contents of a Git Blob as text.

    Args:
        file (Blob): The Git Blob object to read.

    Returns:
        str: The Blob data, decoded with the default codec of `bytes.decode`.
    """
    with open_blob(file) as stream:
        raw = stream.read()

    return raw.decode()
def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Repo = None) -> str:
    """
    Get the contents of a file in the Git repository at a specific commit version.

    Args:
        filename (str): The path of the file to retrieve.
        commit_version (str, optional): The commit hash or branch name. Defaults to "latest" (latest commit).
        repo (Repo, optional): An existing Git repository instance. If provided, uses the given instance.

    Returns:
        str: The contents of the file as a string.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    absolute_path = Path(filename).resolve()
    working_dir = repo.working_dir

    # The tree lookup below wants a path relative to the work tree, written with forward slashes.
    # Resolving both sides first keeps this working when the work tree is reached through a
    # symlink, and `as_posix` keeps it working where the OS separator is not "/".
    try:
        relative_file_path = absolute_path.relative_to(Path(working_dir).resolve()).as_posix()
    except (TypeError, ValueError):
        # no usable work tree, or the file lives outside of it: fall back to plain prefix removal.
        relative_file_path = str(absolute_path).removeprefix(f"{working_dir}/")

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)
def get_file_for_version(filename: str, version: str, prompt_description: str = "", with_git: bool = True) -> str:
    """
    Get the contents of a file based on the version specified.

    Args:
        filename (str): The path of the file to retrieve.
        version (str): The version specifier, which can be "current", "stdin", or a commit hash/branch name.
        prompt_description (str, optional): A description to display when asking for input from stdin.
        with_git (bool, optional): If False, Git is never consulted and any version other than "stdin"
            is read from the file on disk. Defaults to True.

    Returns:
        str: The contents of the file as a string.

    Raises:
        FileNotFoundError: If the file could not be retrieved from Git at the requested version.
    """
    # stdin is checked first: it has no file on disk, so it must never end up in the branch
    # that reads `filename` (not even when Git is unavailable).
    if version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        result = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return result

    if not with_git or version == "current":
        return Path(filename).read_text()

    try:
        return get_file_for_commit(filename, version)
    except (git.exc.GitError, gitdb.exc.ODBError) as e:
        raise FileNotFoundError(f"{filename}@{version}") from e
241def extract_file_version_and_path(
242 file_path_or_git_tag: Optional[str], default_version: str = "stdin"
243) -> tuple[str, str | None]:
244 """
245 Extract the file version and path from the given input.
247 Args:
248 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag.
249 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin".
251 Returns:
252 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified).
254 Examples:
255 myfile.py (implies @current)
257 myfile.py@latest
258 myfile.py@my-branch
259 myfile.py@b3f24091a9
261 @latest (implies no path, e.g. in case of ALTER to copy previously defined path)
262 """
263 if not file_path_or_git_tag:
264 return default_version, ""
266 if file_path_or_git_tag == "-":
267 return "stdin", "-"
269 if file_path_or_git_tag.startswith("@"):
270 file_version = file_path_or_git_tag.strip("@")
271 file_path = None
272 elif "@" in file_path_or_git_tag:
273 file_path, file_version = file_path_or_git_tag.split("@")
274 else:
275 file_version = default_version # `latest` for before; `current` for after.
276 file_path = file_path_or_git_tag
278 return file_version, file_path
def extract_file_versions_and_paths(
    filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Work out version and path for both the 'before' and the 'after' file.

    When only one of the two carries a path, the other one reuses it.

    Args:
        filename_before (str, optional): File spec for the state before the change (or None).
        filename_after (str, optional): File spec for the state after the change (or None).

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            (version, path) for the before file, followed by (version, path) for the after file.

    Raises:
        ValueError: If neither spec yields a path.
    """
    # two different specs are compared as they are now; otherwise 'before' defaults to the latest commit.
    specs_differ = bool(filename_after and filename_before and filename_after != filename_before)
    default_before = "current" if specs_differ else "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=default_before)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    if not filepath_after:
        filepath_after = filepath_before
    elif not filepath_before:
        filepath_before = filepath_after

    return (version_before, filepath_before), (version_after, filepath_after)
def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Check whether a file can be found and, if so, resolve its absolute path.

    The file is looked up as given first and relative to the Git root second.

    Args:
        filename (str, optional): The path of the file to check (or None).
        version (str): The version specifier; "stdin" short-circuits the lookup.
        git_root (Path, optional): Root folder of the Git repository. When None it is determined here,
            falling back to the current working directory.

    Returns:
        tuple[bool, str]: Whether the file exists, plus its absolute path ("" for stdin or when missing).
    """
    if version == "stdin":
        return True, ""

    if filename is None:
        # can't deal with this, not stdin and no file should show file missing error later.
        return False, ""

    if git_root is None:
        git_root = find_git_root() or Path(os.getcwd())

    for candidate in (Path(filename), git_root / filename):
        if candidate.exists():
            return True, str(candidate.resolve())

    return False, ""
def ensure_no_migrate_on_real_db(
    code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: typing.Optional[bool] = False
) -> str:
    """
    Guard against code that would define its own database variable or pull in local imports.

    Definitions of the names in `db_names` ('db' and 'database' by default) and local imports are
    either rejected (fix=False) or stripped from the code (fix=True).

    Args:
        code (str): The code to inspect.
        db_names (Iterable[str], optional): Variable names that represent the database.
            Defaults to ("db", "database").
        fix (bool, optional): If True, remove the offending code instead of raising. Defaults to False.

    Returns:
        str: The cleaned-up code if fix=True, otherwise the original code.

    Raises:
        ValueError: If fix is falsey and a database variable is defined or local imports are used.
    """
    defined = find_defined_variables(code)
    offenders = set()

    for name in db_names:
        if name not in defined:
            continue

        if fix:
            code = remove_specific_variables(code, db_names)
        else:
            offenders.add(name)

    if offenders:
        if len(offenders) == 1:
            (only,) = offenders
            message = f"Variable {only} defined in code! "
        else:  # pragma: no cover
            joined = ", ".join(offenders)
            message = f"Variables {joined} defined in code! "

        raise ValueError(
            f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
        )

    if has_local_imports(code):
        if not fix:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

        code = remove_local_imports(code)

    return code
# Upper bound on the number of execute-and-repair rounds `handle_cli` goes through before giving up.
MAX_RETRIES = 30

# todo: overload more methods
# Script template for plain pydal code. `handle_cli` dedents it, fills in the placeholders
# ($tables, $db_type, $extra, $code_before, $code_after) with string.Template and runs the result
# with exec(). The names `_file`, `DummyDAL` and `_special_tables` are not defined in the template:
# `handle_cli` supplies them as globals. Each table's output ends with a '-- END OF MIGRATION --'
# marker, which the output handling splits on.
TEMPLATE_PYDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *

from pydal2sql_core import generate_sql


# from pydal import DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = set(db_old._tables + db_new._tables) - _special_tables

if not tables:
    raise ValueError('no-tables-found')

for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)
    print('-- END OF MIGRATION --', file=_file)
    """
# Script template for TypeDAL code: same structure as the pydal template, with an additional
# `from typedal import *`. `handle_cli` dedents it, fills in the placeholders ($tables, $db_type,
# $extra, $code_before, $code_after) with string.Template and runs the result with exec().
# The names `_file`, `DummyDAL` and `_special_tables` are supplied by `handle_cli` as globals.
TEMPLATE_TYPEDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *
from typedal import *

from pydal2sql_core import generate_sql


# from typedal import TypeDAL as DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = set(db_old._tables + db_new._tables) - _special_tables

if not tables:
    raise ValueError('no-tables-found')


for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)

    print('-- END OF MIGRATION --', file=_file)
    """
# <action> TABLE [IF [NOT] EXISTS] <name>, where the name may be wrapped in ', ", ` or [ ].
_TABLE_STATEMENT_RE = re.compile(
    r"(create|alter|drop)\s+table\s+(?:if\s+(?:not\s+)?exists\s+)?[`'\"\[]?(\w+)",
    re.IGNORECASE,
)


def sql_to_function_name(sql_statement: str) -> str:
    """
    Extract action (CREATE, ALTER, DROP) and table name from the SQL statement.

    Only the first matching statement is used. Identifiers quoted with single quotes, double quotes,
    backticks or square brackets are supported, and an `IF EXISTS` / `IF NOT EXISTS` clause is
    skipped so that it is not mistaken for the table name.

    Args:
        sql_statement (str): SQL text containing a CREATE/ALTER/DROP TABLE statement.

    Returns:
        str: "<action>_<table>" in lowercase, or "unknown_migration" when no such statement is found.
    """
    match = _TABLE_STATEMENT_RE.search(sql_statement.lower())

    if match is None:
        # raise ValueError("Invalid SQL statement. Unable to extract action and table name.")
        return "unknown_migration"

    action, table_name = match.groups()

    # Generate a function name with the specified format
    return f"{action}_{table_name}"
def _setup_generic_edwh_migrate(file: Path, is_typedal: bool) -> None:
    """
    Start a fresh edwh-migrate file containing only the required imports.

    Any existing content of `file` is overwritten.

    Args:
        file (Path): The migration file to (re)create.
        is_typedal (bool): Import TypeDAL from typedal when True, DAL from pydal otherwise.
    """
    dal_import = "from typedal import TypeDAL" if is_typedal else "from pydal import DAL"
    header = "from edwh_migrate import migration\n" + dal_import + "\n"

    file.write_text(textwrap.dedent(header))

    rich.print(f"[green] New migrate file {file} created [/green]")
def _build_edwh_migration(contents: str, cls: str, date: str) -> str:
    """
    Wrap one migration's SQL in the source code of an `@migration`-decorated function.

    Args:
        contents (str): The SQL of a single migration; also used to derive the function name.
        cls (str): Name used as the type annotation of the generated function's `db` parameter.
        date (str): Date stamp that becomes part of the generated function name.

    Returns:
        str: Python source of a function that passes the SQL to `db.executesql`, calls `db.commit()`
            and returns True.
    """
    func_name = sql_to_function_name(contents)
    # the sequence number is fixed at 001 for now, see the fixme below.
    func_name = f"{func_name}_{date}_001"

    # fixme: on `create` and `drop`, if `func_name` already exists - stop
    #   on `alter`, set a new order number (-> 002)

    # the SQL is indented deeper than the surrounding template lines, so that after the dedent
    # below it still sits inside the generated `db.executesql` call.
    contents = textwrap.indent(contents.strip(), " " * 16)

    # NOTE(review): the template emits a literal `);` line after the SQL of every migration;
    # presumably the SQL already carries its own terminator - confirm this line is intended.
    return textwrap.dedent(
        f'''

        @migration
        def {func_name}(db: {cls}):
            db.executesql("""
{contents}
            );
            """)
            db.commit()

            return True
        '''
    )
def _build_edwh_migrations(contents: str, is_typedal: bool) -> str:
    """
    Turn the collected SQL output into edwh-migrate function definitions.

    The output is cut at every '-- END OF MIGRATION --' marker; blank pieces are skipped and every
    other piece becomes one migration function, all stamped with today's date.

    Args:
        contents (str): The collected SQL output, with a marker after each migration.
        is_typedal (bool): Annotate `db` as TypeDAL when True, as DAL otherwise.

    Returns:
        str: The generated functions, concatenated.
    """
    cls = "TypeDAL" if is_typedal else "DAL"
    today = datetime.now().strftime("%Y%m%d")  # yyyymmdd

    chunks = [chunk for chunk in contents.split("-- END OF MIGRATION --") if chunk.strip()]
    return "".join(_build_edwh_migration(chunk, cls, today) for chunk in chunks)
def _handle_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
) -> None:
    """
    Convert the captured SQL to the requested format and deliver it.

    Args:
        file: Buffer holding the captured output; it is read from the start.
        output_file: Destination. A str or Path is appended to (an empty or missing file first gets
            the edwh-migrate imports when that format is used), a StringIO is written to, and
            anything else means: print to stdout.
        output_format: "edwh-migrate", or "default"/"sql"/empty for plain SQL.
        is_typedal: Passed on to the edwh-migrate builders.

    Raises:
        ValueError: If the output format is not recognised.
    """
    file.seek(0)
    contents = file.read()

    if output_format == "edwh-migrate":
        contents = _build_edwh_migrations(contents, is_typedal)
    elif output_format in {"default", "sql"} or not output_format:
        # plain SQL: every end-of-migration marker turns into a line break.
        contents = contents.replace("-- END OF MIGRATION --", "\n")
    else:
        supported = typing.get_args(_SUPPORTED_OUTPUT_FORMATS)
        raise ValueError(f"Unknown format {output_format}. Please choose one of {supported}")

    target = Path(output_file) if isinstance(output_file, str) else output_file

    if isinstance(target, Path):
        is_edwh = output_format == "edwh-migrate"
        if is_edwh and (not target.exists() or target.stat().st_size == 0):
            _setup_generic_edwh_migrate(target, is_typedal)

        with target.open("a") as handle:
            handle.write(contents)

        rich.print(f"[green] Written migration(s) to {target} [/green]")
    elif isinstance(target, io.StringIO):
        target.write(contents)
    else:
        # no file, just print to stdout:
        print(contents)
591IMPORT_IN_STR = re.compile(r'File "<string>", line (\d+), in <module>')
594def _handle_import_error(code: str, error: ImportError) -> str:
595 # error is deeper in a package, find the related import in the code:
596 tb_lines = traceback.format_exc().splitlines()
598 for line in tb_lines:
599 if matches := IMPORT_IN_STR.findall(line):
600 # 'File "<string>", line 15, in <module>'
601 line_no = int(matches[0]) - 1
602 lines = code.split("\n")
603 return lines[line_no]
605 # I don't know how to trigger this case:
606 raise ValueError("Faulty import could not be automatically deleted") from error # pragma: no cover
609MISSING_RELATIONSHIP = re.compile(r"Cannot resolve reference (\w+) in \w+ definition")
612def _handle_relation_error(error: KeyError) -> tuple[str, str]:
613 if not (table := MISSING_RELATIONSHIP.findall(str(error))):
614 # other error, raise again
615 raise error
617 t = table[0]
619 return (
620 t,
621 """
622 db.define_table('%s', redefine=True)
623 """
624 % t,
625 )
def _render_generated_code(
    template: string.Template,
    tables: Optional[list[str] | list[list[str]]],
    db_type: Optional[str],
    extra_code: str,
    code_before: str,
    code_after: str,
) -> str:
    """Fill the execution template with the current (possibly repaired) code fragments."""
    return template.substitute(
        {
            "tables": flatten(tables or []),
            "db_type": db_type or "",
            "extra": textwrap.dedent(extra_code),
            "code_before": textwrap.dedent(code_before),
            "code_after": textwrap.dedent(code_after),
        }
    )


def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: Optional[bool] = False,
    noop: Optional[bool] = False,
    magic: Optional[bool] = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool | typing.Literal["auto"] = "auto",
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
) -> bool:
    """
    Handle user input for generating SQL migration statements based on before and after code.

    The before/after code is placed in a script template and executed against a dummy DAL. When
    execution fails in a way that can be repaired (missing names, failing imports, unresolved table
    references, tables defined inside a function), the script is adjusted and executed again, up to
    MAX_RETRIES times.

    Args:
        code_before (str): The code representing the state of the database before the change.
        code_after (str): The code representing the state of the database after the change.
        db_type (str, optional): The type of the database (e.g., "postgres", "mysql", etc.). Defaults to None.
        tables (list[str] or list[list[str]], optional): The list of tables to generate SQL for. Defaults to None.
        verbose (bool, optional): If True, print the generated code. Defaults to False.
        noop (bool, optional): If True, only print the generated code but do not execute it. Defaults to False.
        magic (bool, optional): If True, automatically add missing variables for execution. Defaults to False.
        function_name (str or tuple[str, ...], optional): The name(s) of the function(s) where the tables are
            defined. Defaults: "define_tables".
        use_typedal: replace pydal imports with TypeDAL? "auto" looks for "typedal" in the code.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?

    # todo: prefix (e.g. public.)

    Returns:
        bool: True if SQL migration statements are generated and executed successfully, False otherwise.
    """
    # todo: better typedal checking
    if use_typedal == "auto":
        use_typedal = "typedal" in code_before.lower() or "typedal" in code_after.lower()

    if function_name:
        define_table_functions: set[str] = set(function_name) if isinstance(function_name, tuple) else {function_name}
    else:
        define_table_functions = set()

    template = TEMPLATE_TYPEDAL if use_typedal else TEMPLATE_PYDAL

    to_execute = string.Template(textwrap.dedent(template))

    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)
    extra_code = ""

    generated_code = _render_generated_code(to_execute, tables, db_type, extra_code, code_before, code_after)
    if verbose or noop:
        rich.print(generated_code, file=sys.stderr)

    if noop:
        # done
        return True

    err: typing.Optional[Exception] = None
    catch: dict[str, Any] = {}
    retry_counter = MAX_RETRIES

    magic_vars = {"_file", "DummyDAL", "_special_tables"}
    special_tables: set[str] = {"typedal_cache", "typedal_cache_dependency"} if use_typedal else set()

    while retry_counter:
        retry_counter -= 1
        try:
            if verbose:
                rich.print(generated_code, file=sys.stderr)

            # 'catch' is used to add and receive globals from the exec scope.
            # another argument could be added for locals, but adding simply {} changes the behavior negatively.
            # so for now, only globals is passed.
            catch["_file"] = io.StringIO()  # <- every print should go to this file, so we can handle it afterwards
            catch["DummyDAL"] = (
                DummyTypeDAL if use_typedal else DummyDAL
            )  # <- use a fake DAL that doesn't actually run queries
            catch["_special_tables"] = special_tables  # <- e.g. typedal_cache, auth_user
            # note: when adding something to 'catch', also add it to magic_vars!!!

            exec(generated_code, catch)  # nosec: B102
            _handle_output(catch["_file"], output_file, output_format, is_typedal=use_typedal)
            return True  # success!
        except ValueError as e:
            if str(e) != "no-tables-found":  # pragma: no cover
                warnings.warn(str(e), source=e)
                return False

            if define_table_functions:
                any_found = False
                for candidate in define_table_functions:
                    define_tables = find_function_to_call(generated_code, candidate)

                    # if define_tables function is found, add call to it at end of code
                    if define_tables is not None:
                        generated_code = add_function_call(generated_code, candidate, multiple=True)
                        any_found = True

                if any_found:
                    # hurray!
                    continue

            # else: no define_tables or other method to use found.

            # name every function that was tried (sorted, so the message is stable);
            # without any candidates, fall back to whatever was passed in.
            tried_functions = ", ".join(sorted(define_table_functions)) or function_name

            print(f"No tables found in the top-level or {tried_functions} function!", file=sys.stderr)
            if use_typedal:
                print(
                    "Please use `db.define` or `database.define`, "
                    "or if you really need to use an alias like my_db.define, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            else:
                print(
                    "Please use `db.define_table` or `database.define_table`, "
                    "or if you really need to use an alias like my_db.define_tables, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            print(f"You can also specify a --function to use something else than {tried_functions}.", file=sys.stderr)

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code) - magic_vars
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            # postponed: this can possibly also be achieved by updating the 'catch' dict
            #   instead of injecting in the string.
            extra_code = extra_code + "\n" + textwrap.dedent(generate_magic_code(missing_vars))

            code_before = remove_if_falsey_blocks(code_before)
            code_after = remove_if_falsey_blocks(code_after)

            generated_code = _render_generated_code(to_execute, tables, db_type, extra_code, code_before, code_after)
        except ImportError as e:
            # should include ModuleNotFoundError
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            if not e.path:
                # code exists in code itself
                code_before = remove_import(code_before, e.name or "")
                code_after = remove_import(code_after, e.name or "")
            else:
                to_remove = _handle_import_error(generated_code, e)
                code_before = code_before.replace(to_remove, "\n")
                code_after = code_after.replace(to_remove, "\n")

            generated_code = _render_generated_code(to_execute, tables, db_type, extra_code, code_before, code_after)

        except KeyError as e:
            err = e
            table_name, table_definition = _handle_relation_error(e)
            # the stand-in table is remembered as 'special' so that it is left out when tables are auto-detected.
            special_tables.add(table_name)
            extra_code = extra_code + "\n" + textwrap.dedent(table_definition)

            generated_code = _render_generated_code(to_execute, tables, db_type, extra_code, code_before, code_after)
        except Exception as e:
            err = e
            # otherwise: give up
            retry_counter = 0
        finally:
            # reset:
            # NOTE(review): presumably undoes a TYPE_CHECKING override made by the executed code - confirm.
            typing.TYPE_CHECKING = False

    # Every successful or unrecoverable round returns from inside the loop, so getting here means
    # the retries are used up (or were zeroed by an unexpected error).
    rich.print(f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'}", file=sys.stderr)
    return False
def core_create(
    filename: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
) -> bool:
    """
    Generate the SQL for creating one or more tables from the code in a single source file.

    Args:
        filename: Source to parse, optionally as `path@version` and/or `path:function`.
            Without a filename the code is read from stdin.
        tables: Names of the tables to generate SQL for.
            If None, all tables found in the code are processed.
        db_type: The type of the database, passed on to `handle_cli`.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: Name(s) of the function(s) in which the tables are defined.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?

    Returns:
        bool: The result of `handle_cli`: True on success (or in noop mode), False otherwise.

    Raises:
        FileNotFoundError: If the source file cannot be found.
    """
    git_root = find_git_root() or Path(os.getcwd())

    functions: set[str] = set()
    if function:  # pragma: no cover
        functions |= set(function) if isinstance(function, tuple) else {function}

    if filename and ":" in filename:
        # e.g. models.py:define_tables
        filename, _, inline_function = filename.partition(":")
        functions.add(inline_function)

    default_version = "current" if filename else "stdin"
    file_version, file_path = extract_file_version_and_path(filename, default_version=default_version)
    file_exists, file_absolute_path = get_absolute_path_info(file_path, file_version, git_root)

    if not file_exists:
        raise FileNotFoundError(f"Source file {filename} could not be found.")

    text = get_file_for_version(file_absolute_path, file_version, prompt_description="table definition")

    # 'create' is an alter coming from nothing: the before code is empty.
    return handle_cli(
        "",
        text,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
    )
def core_alter(
    filename_before: Optional[str] = None,
    filename_after: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
) -> bool:
    """
    Generates SQL migration statements for altering the database, based on the code in two given source files.

    When the comparison cannot be made (one side is empty, or both sides contain the same code),
    `core_create` is tried on the 'after' file instead.

    Args:
        filename_before: The filename of the source file before changes.
            This code represents the initial state of the database.
        filename_after: The filename of the source file after changes.
            This code represents the final state of the database.
        tables: A list of table names to generate SQL for.
            If None, the function will attempt to process all tables found in the code.
        db_type: The type of the database, passed on to `handle_cli`.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: The name of the function where the tables are defined.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?

    Returns:
        bool: True if SQL migration statements are generated and (if not in noop mode) executed successfully,
            False otherwise (including when the fallback to create fails).

    Raises:
        ValueError: If no file name is supplied at all.
        FileNotFoundError: If either of the source files cannot be found.
    """
    git_root = find_git_root(filename_before) or find_git_root(filename_after)

    functions: set[str] = set()
    if function:  # pragma: no cover
        functions.add(function)

    if filename_before and ":" in filename_before:
        # e.g. models.py:define_tables
        filename_before, _function = filename_before.split(":", 1)
        functions.add(_function)

    if filename_after and ":" in filename_after:
        # e.g. models.py:define_tables
        filename_after, _function = filename_after.split(":", 1)
        functions.add(_function)

    before, after = extract_file_versions_and_paths(filename_before, filename_after)

    version_before, filename_before = before
    version_after, filename_after = after

    # either ./file exists or /file exists (seen from git root):

    before_exists, before_absolute_path = get_absolute_path_info(filename_before, version_before, git_root)
    after_exists, after_absolute_path = get_absolute_path_info(filename_after, version_after, git_root)

    if not (before_exists and after_exists):
        message = ""
        message += "" if before_exists else f"Path {filename_before} does not exist! "
        if filename_before != filename_after:
            message += "" if after_exists else f"Path {filename_after} does not exist!"
        raise FileNotFoundError(message)

    try:
        code_before = get_file_for_version(
            before_absolute_path,
            version_before,
            prompt_description="current table definition",
            with_git=git_root is not None,
        )
        code_after = get_file_for_version(
            after_absolute_path,
            version_after,
            prompt_description="desired table definition",
            with_git=git_root is not None,
        )

        if not (code_before and code_after):
            message = ""
            message += "" if code_before else "Before code is empty (Maybe try `pydal2sql create`)! "
            message += "" if code_after else "After code is empty! "
            raise ValueError(message)

        if code_before == code_after:
            raise ValueError("Both contain the same code!")

    except ValueError as e:
        rich.print(f"[yellow] alter failed ({e}), trying create! [/yellow]", file=sys.stderr)
        try:
            return core_create(
                filename_after or filename_before,
                tables,
                db_type,
                magic,
                noop,
                verbose,
                tuple(functions),
                output_format,
                output_file,
            )
        except Exception as create_error:  # pragma: no cover
            # still best-effort (no crash), but tell the user why the fallback failed as well.
            # plain print: the error text must not be interpreted as rich markup.
            print(f"create failed too ({create_error})", file=sys.stderr)
            return False

    return handle_cli(
        code_before,
        code_after,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        function_name=tuple(functions),
        output_format=output_format,
        output_file=output_file,
    )