Coverage for src/pydal2sql_core/cli_support.py: 100%
318 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-11-20 12:23 +0100
1"""
2CLI-Agnostic support.
3"""
4import contextlib
5import io
6import os
7import re
8import select
9import string
10import sys
11import textwrap
12import traceback
13import typing
14from datetime import datetime
15from pathlib import Path
16from typing import Any, Optional
18import git
19import gitdb.exc
20import rich
21from black.files import find_project_root
22from git.objects.blob import Blob
23from git.objects.commit import Commit
24from git.repo import Repo
25from witchery import (
26 add_function_call,
27 find_defined_variables,
28 find_function_to_call,
29 find_missing_variables,
30 generate_magic_code,
31 has_local_imports,
32 remove_if_falsey_blocks,
33 remove_import,
34 remove_local_imports,
35 remove_specific_variables,
36)
38from .helpers import flatten
39from .types import (
40 _SUPPORTED_OUTPUT_FORMATS,
41 DEFAULT_OUTPUT_FORMAT,
42 SUPPORTED_DATABASE_TYPES_WITH_ALIASES,
43 SUPPORTED_OUTPUT_FORMATS,
44 DummyDAL,
45 DummyTypeDAL,
46)
def has_stdin_data() -> bool:  # pragma: no cover
    """
    Tell whether data was piped (|) or redirected (<) into the program's stdin.

    stdin is polled with a zero timeout, so this check does not wait for input.

    Returns:
        bool: True when stdin has data ready to be read, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    ready_to_read, _ready_to_write, _in_error = select.select([sys.stdin], [], [], 0.0)
    return any(ready_to_read)
AnyCallable = typing.Callable[..., Any]


def print_if_interactive(*args: Any, pretty: bool = True, **kwargs: Any) -> None:  # pragma: no cover
    """
    Print to stderr, but only when nothing was piped or redirected into stdin.

    Args:
        *args: Values to print.
        pretty (bool): Use `rich.print` when True, the built-in `print` otherwise.
        **kwargs: Extra keyword arguments for the print function; any `file` passed in is replaced by sys.stderr.

    Returns:
        None
    """
    if has_stdin_data():
        # input is piped/redirected: not an interactive session, stay quiet
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    kwargs["file"] = sys.stderr
    printer(*args, **kwargs)
def find_git_root(at: Optional[str] = None) -> Optional[Path]:
    """
    Find the root directory of the Git repository.

    Uses black's `find_project_root`; only a hit on a `.git` directory counts. Any other reason
    reported by `find_project_root` (e.g. another project marker) yields None.

    Args:
        at (str, optional): The path to start the search from. Defaults to the current working directory.

    Returns:
        Optional[Path]: The root directory of the Git repository if found, otherwise None.
    """
    folder, reason = find_project_root((at or os.getcwd(),))
    if reason != ".git directory":
        return None
    return folder
def find_git_repo(repo: Optional[Repo] = None, at: Optional[str] = None) -> Repo:
    """
    Find the Git repository instance.

    Args:
        repo (Repo, optional): An existing Git repository instance. If provided, returns the same instance.
        at (str, optional): The path to start the search from. Defaults to the current working directory.

    Returns:
        Repo: The Git repository instance.

    Raises:
        git.exc.InvalidGitRepositoryError: If no Git repository root can be found from `at`.
            (A `GitError` subclass, so callers catching `GitError` keep working.)
    """
    if repo:
        return repo

    root = find_git_root(at)
    if root is None:
        # Without this guard, Repo(str(None)) would try to open a directory literally named "None".
        raise git.exc.InvalidGitRepositoryError(at or os.getcwd())
    return Repo(str(root))
def latest_commit(repo: Optional[Repo] = None) -> Commit:
    """
    Get the commit that HEAD currently points to.

    Args:
        repo (Repo, optional): An existing Git repository instance. If omitted, the repository is
            looked up from the current working directory.

    Returns:
        Commit: The latest commit (HEAD) in the Git repository.
    """
    repo = find_git_repo(repo)
    return repo.head.commit
def commit_by_id(commit_hash: str, repo: Optional[Repo] = None) -> Commit:
    """
    Get a specific commit in the Git repository by its hash or name.

    Args:
        commit_hash (str): The hash of the commit to retrieve. Can also be e.g. a branch name.
        repo (Repo, optional): An existing Git repository instance. If omitted, the repository is
            looked up from the current working directory.

    Returns:
        Commit: The commit object corresponding to the given commit hash.
    """
    repo = find_git_repo(repo)
    return repo.commit(commit_hash)
@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Expose a Git blob's data as an in-memory binary stream.

    The blob's data stream is read completely up front and wrapped in a BytesIO.

    Args:
        file (Blob): The Git Blob object to open.

    Yields:
        io.BytesIO: A buffer positioned at the start of the blob's data.
    """
    raw_bytes = file.data_stream.read()
    buffer = io.BytesIO(raw_bytes)
    yield buffer
def read_blob(file: Blob) -> str:
    """
    Return the contents of a Git blob decoded with the default (UTF-8) codec.

    Args:
        file (Blob): The Git Blob object to read.

    Returns:
        str: The decoded contents of the blob.
    """
    with open_blob(file) as stream:
        payload = stream.read()
    return payload.decode()
def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Optional[Repo] = None) -> str:
    """
    Get the contents of a file in the Git repository at a specific commit version.

    Args:
        filename (str): The path of the file to retrieve (absolute, or relative to the working directory).
        commit_version (str, optional): The commit hash or branch name. Defaults to "latest" (latest commit).
        repo (Repo, optional): An existing Git repository instance. If omitted, it is looked up from `filename`.

    Returns:
        str: The contents of the file as a string.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    file_path = Path(filename).resolve()
    # resolve the repo root too, so symlinked / non-normalized working dirs still match
    repo_root = Path(repo.working_dir).resolve()

    # Git tree paths are relative to the repository root and always use forward slashes,
    # independent of the OS path separator.
    try:
        relative_file_path = file_path.relative_to(repo_root).as_posix()
    except ValueError:
        # The file lies outside this repository: keep the full path so the tree lookup below
        # fails the same way it does for any other unknown path.
        relative_file_path = file_path.as_posix()

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)
def get_file_for_version(filename: str, version: str, prompt_description: str = "", with_git: bool = True) -> str:
    """
    Get the contents of a file based on the version specified.

    Args:
        filename (str): The path of the file to retrieve.
        version (str): The version specifier, which can be "current", "stdin", or a commit hash/branch name.
        prompt_description (str, optional): A description to display when asking for input from stdin.
        with_git (bool, optional): When False, git is never consulted: any non-stdin version is read
            from the file on disk. Defaults to True.

    Returns:
        str: The contents of the file as a string.

    Raises:
        FileNotFoundError: If the requested version of the file cannot be retrieved from git.
    """
    # stdin is handled first: it does not involve git at all, and callers pass an empty filename
    # for it, so falling through to the disk read below would try to read Path("").
    if version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        result = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return result

    if not with_git or version == "current":
        return Path(filename).read_text()

    try:
        return get_file_for_commit(filename, version)
    except (git.exc.GitError, gitdb.exc.ODBError, KeyError) as e:
        # KeyError: GitPython's `tree / path` lookup raises it when the path does not exist in that commit.
        raise FileNotFoundError(f"{filename}@{version}") from e
240def extract_file_version_and_path(
241 file_path_or_git_tag: Optional[str], default_version: str = "stdin"
242) -> tuple[str, str | None]:
243 """
244 Extract the file version and path from the given input.
246 Args:
247 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag.
248 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin".
250 Returns:
251 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified).
253 Examples:
254 myfile.py (implies @current)
256 myfile.py@latest
257 myfile.py@my-branch
258 myfile.py@b3f24091a9
260 @latest (implies no path, e.g. in case of ALTER to copy previously defined path)
261 """
262 if not file_path_or_git_tag:
263 return default_version, ""
265 if file_path_or_git_tag == "-":
266 return "stdin", "-"
268 if file_path_or_git_tag.startswith("@"):
269 file_version = file_path_or_git_tag.strip("@")
270 file_path = None
271 elif "@" in file_path_or_git_tag:
272 file_path, file_version = file_path_or_git_tag.split("@")
273 else:
274 file_version = default_version # `latest` for before; `current` for after.
275 file_path = file_path_or_git_tag
277 return file_version, file_path
def extract_file_versions_and_paths(
    filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Work out (version, path) pairs for the "before" and "after" side of an ALTER.

    Without an explicit version, the before side defaults to "current" when two different
    file names were given, and to "latest" otherwise. The after side defaults to "current".
    When only one side names a file, the other side reuses that path.

    Args:
        filename_before (str, optional): The path (optionally with @version) of the file before the change.
        filename_after (str, optional): The path (optionally with @version) of the file after the change.

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            ((version_before, path_before), (version_after, path_after)).

    Raises:
        ValueError: If neither side provides a file name.
    """
    two_distinct_files = filename_before and filename_after and filename_before != filename_after
    before_default = "current" if two_distinct_files else "latest"

    version_before, path_before = extract_file_version_and_path(filename_before, default_version=before_default)
    version_after, path_after = extract_file_version_and_path(filename_after, default_version="current")

    if not path_before and not path_after:
        raise ValueError("Please supply at least one file name.")

    # borrow the path from the other side when one is missing
    path_after = path_after or path_before
    path_before = path_before or path_after

    return (version_before, path_before), (version_after, path_after)
def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Resolve where a source file lives on disk.

    The path is tried as given first (relative to the working directory), then relative to the Git root.

    Args:
        filename (str, optional): The path of the file to check (or None).
        version (str): The version specifier, which can be "stdin", "current", or a commit hash/branch name.
        git_root (Path, optional): The root directory of the Git repository. If None, it is looked up,
            falling back to the current working directory.

    Returns:
        tuple[bool, str]: (exists, absolute path). The path is "" when reading from stdin (which always
        counts as existing) or when the file could not be found.
    """
    if version == "stdin":
        return True, ""
    if filename is None:
        # can't deal with this, not stdin and no file should show file missing error later.
        return False, ""

    base = git_root if git_root is not None else (find_git_root() or Path(os.getcwd()))

    for candidate in (Path(filename), base / filename):
        if candidate.exists():
            return True, str(candidate.resolve())

    return False, ""
def ensure_no_migrate_on_real_db(
    code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: typing.Optional[bool] = False
) -> str:
    """
    Ensure that the code does not contain actual migrations on a real database.

    It does this by removing definitions of 'db' and database. This can be changed by customizing `db_names`.
    It also removes local imports to prevent irrelevant code being executed.

    Args:
        code (str): The code to check for database migrations.
        db_names (Iterable[str], optional): Names of variables representing the database.
            Defaults to ("db", "database"). Any iterable is accepted, including one-shot iterators.
        fix (bool, optional): If True, removes the offending code instead of raising. Defaults to False.

    Returns:
        str: The modified code with database definitions / local imports removed if fix=True,
        otherwise the original code.

    Raises:
        ValueError: If fix is False and a database variable is defined or local imports are used.
    """
    # Materialize once: the names are used both for the lookup and for the removal, and a
    # generator would already be (partially) consumed by the time it reached the second use.
    names = tuple(db_names)
    variables = find_defined_variables(code)

    # ordered + de-duplicated, so the error message is deterministic
    found_variables = list(dict.fromkeys(name for name in names if name in variables))

    if found_variables:
        if fix:
            # a single call removes every listed name
            code = remove_specific_variables(code, names)
        else:
            if len(found_variables) == 1:
                message = f"Variable {found_variables[0]} defined in code! "
            else:  # pragma: no cover
                var = ", ".join(found_variables)
                message = f"Variables {var} defined in code! "
            raise ValueError(
                f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
            )

    if has_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code
# Upper bound on the number of execute-and-auto-fix rounds `handle_cli` attempts.
MAX_RETRIES = 30

# todo: overload more methods

# Migration script skeleton for plain pydal code, filled in with `string.Template`.
# Placeholders: $tables, $db_type, $extra (injected twice: once per database), $code_before, $code_after.
# The script does not define `DummyDAL`, `_special_tables` or `_file`; it expects them as globals of the
# exec scope. For every table it prints a '-- start <table> --' marker, the SQL, and an
# '-- END OF MIGRATION --' marker to `_file`.
TEMPLATE_PYDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *

from pydal2sql_core import generate_sql


# from pydal import DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = set(db_old._tables + db_new._tables) - _special_tables

if not tables:
    raise ValueError('no-tables-found')

for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)
    print('-- END OF MIGRATION --', file=_file)
    """
# Same skeleton as TEMPLATE_PYDAL, plus `from typedal import *`, for code written with TypeDAL.
# Placeholders: $tables, $db_type, $extra (injected twice), $code_before, $code_after.
# `DummyDAL`, `_special_tables` and `_file` must be supplied as globals of the exec scope.
TEMPLATE_TYPEDAL = """
from pydal import *
from pydal.objects import *
from pydal.validators import *
from typedal import *

from pydal2sql_core import generate_sql


# from typedal import TypeDAL as DAL
db = database = DummyDAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DummyDAL(None, migrate=False)

$extra

$code_after

if not tables:
    tables = set(db_old._tables + db_new._tables) - _special_tables

if not tables:
    raise ValueError('no-tables-found')


for table in tables:
    print('-- start ', table, '--', file=_file)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type), file=_file)
    elif table in db_old:
        print(f'DROP TABLE {table};', file=_file)
    else:
        print(generate_sql(db_new[table], db_type=db_type), file=_file)

    print('-- END OF MIGRATION --', file=_file)
    """
def sql_to_function_name(sql_statement: str) -> str:
    """
    Build a migration function name from the first CREATE/ALTER/DROP TABLE statement in the SQL.

    The SQL is lower-cased first, so both the action and the table name come out lower-case
    (`_build_edwh_migration` checks for names starting with "alter"). An optional
    `IF [NOT] EXISTS` clause is skipped. The table name may be quoted with ', " or `.

    Args:
        sql_statement: SQL text, possibly with several statements; only the first match is used.

    Returns:
        str: "<action>_<table>" (e.g. "create_users"), or "unknown_migration" if nothing matches.
    """
    match = re.findall(
        r"(create|alter|drop)\s+table\s+(?:if\s+(?:not\s+)?exists\s+)?[`'\"]?(\w+)[`'\"]?",
        sql_statement.lower(),
    )

    if not match:
        # raise ValueError("Invalid SQL statement. Unable to extract action and table name.")
        return "unknown_migration"

    action, table_name = match[0]

    # Generate a function name with the specified format
    return f"{action}_{table_name}"
def _setup_generic_edwh_migrate(file: Path, is_typedal: bool) -> None:
    """
    Create (or overwrite) `file` with the import header an edwh-migrate migrations file needs.

    Args:
        file: Path of the migrations file to write.
        is_typedal: Import TypeDAL instead of pydal's DAL.
    """
    dal_import = "from typedal import TypeDAL" if is_typedal else "from pydal import DAL"
    header = f"from edwh_migrate import migration\n{dal_import}\n"

    with file.open("w") as handle:
        handle.write(textwrap.dedent(header))

    rich.print(f"[green] New migrate file {file} created [/green]")
START_RE = re.compile(r"-- start\s+\w+\s--\n")


def _build_edwh_migration(contents: str, cls: str, date: str, existing: Optional[str] = None) -> str:
    """
    Wrap one table's SQL in an edwh-migrate `@migration` function.

    The function is named "<action>_<table>_<date>_<nnn>". When that name already appears in
    `existing`, the migration is skipped if its SQL is already present (whitespace-insensitive).
    ALTER migrations with different SQL bump the counter instead. Other migrations with
    different SQL are skipped with a warning.

    Args:
        contents: One migration section from the template output (may start with a '-- start ... --' marker).
        cls: DAL class name used in the generated type hint ("DAL" or "TypeDAL").
        date: Date stamp (yyyymmdd) embedded in the function name.
        existing: Current contents of the output file, if any.

    Returns:
        str: The generated Python source, or "" when there is nothing (new) to write.
    """
    sql_func_name = sql_to_function_name(contents)
    contents = START_RE.sub("", contents)

    if not contents.strip():
        # No real migration! Checked before the name search below: an empty body is trivially
        # "contained" in any existing file and would log a misleading "already exists" message.
        return ""

    # whitespace-insensitive fingerprints for the duplicate check (computed once, not per iteration)
    normalized_contents = contents.replace(" ", "").replace("\n", "")
    normalized_existing = existing.replace(" ", "").replace("\n", "") if existing else ""

    func_name = "_placeholder_"
    for n in range(1, 1000):
        func_name = f"{sql_func_name}_{date}_{str(n).zfill(3)}"

        if not (existing and f"def {func_name}" in existing):
            # okay function name, stop incrementing
            break

        if normalized_contents in normalized_existing:
            rich.print(f"[yellow] migration {func_name} already exists, skipping! [/yellow]")
            return ""
        if func_name.startswith("alter"):
            # bump number because alter migrations are different
            continue
        rich.print(
            f"[red] migration {func_name} already exists [bold]with different contents[/bold], skipping! [/red]"
        )
        return ""

    # 16 spaces: after the 8-space dedent below, the SQL sits 8 spaces deep inside executesql("""...""")
    contents = textwrap.indent(contents.strip(), " " * 16)

    return textwrap.dedent(
        f'''

        @migration
        def {func_name}(db: {cls}):
            db.executesql("""
{contents}
            """)
            db.commit()

            return True
        '''
    )
def _build_edwh_migrations(contents: str, is_typedal: bool, output: Optional[Path] = None) -> str:
    """
    Convert the template's raw output into edwh-migrate functions, one per migration section.

    Args:
        contents: Text printed by the migration script, with sections ended by '-- END OF MIGRATION --'.
        is_typedal: Annotate the generated functions with TypeDAL instead of DAL.
        output: Existing output file; if it exists, its contents are used to skip duplicate migrations.

    Returns:
        str: The concatenated migration functions (possibly empty).
    """
    dal_class = "TypeDAL" if is_typedal else "DAL"
    today = datetime.now().strftime("%Y%m%d")  # yyyymmdd

    already_written = None
    if output and output.exists():
        already_written = output.read_text()

    pieces = []
    for chunk in contents.split("-- END OF MIGRATION --"):
        if not chunk.strip():
            continue
        pieces.append(_build_edwh_migration(chunk, dal_class, today, already_written))
    return "".join(pieces)
def _handle_output(
    file: io.StringIO,
    output_file: Path | str | io.StringIO | None,
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    is_typedal: bool = False,
) -> None:
    """
    Format the captured migration output and send it to its destination.

    Args:
        file: Buffer the migration script printed into; it is rewound and read completely.
        output_file: A path (str or Path) to append to, a StringIO to write into, or None for stdout.
        output_format: "edwh-migrate" for migration functions; "default"/"sql" (or empty) for plain SQL.
        is_typedal: Passed on to the edwh-migrate builder.

    Raises:
        ValueError: For an unknown output format.
    """
    file.seek(0)
    contents = file.read()

    target = Path(output_file) if isinstance(output_file, str) else output_file

    if output_format == "edwh-migrate":
        contents = _build_edwh_migrations(contents, is_typedal, target if isinstance(target, Path) else None)
    elif not output_format or output_format in {"default", "sql"}:
        contents = "\n".join(contents.split("-- END OF MIGRATION --"))
    else:
        raise ValueError(
            f"Unknown format {output_format}. " f"Please choose one of {typing.get_args(_SUPPORTED_OUTPUT_FORMATS)}"
        )

    if isinstance(target, io.StringIO):
        target.write(contents)
        return

    if not isinstance(target, Path):
        # no file, just print to stdout:
        print(contents.strip())
        return

    needs_header = output_format == "edwh-migrate" and (not target.exists() or target.stat().st_size == 0)
    if needs_header:
        _setup_generic_edwh_migrate(target, is_typedal)

    if not contents.strip():
        rich.print(f"[yellow] Nothing to write to {target} [/yellow]")
        return

    with target.open("a") as handle:
        handle.write(contents)
    rich.print(f"[green] Written migration(s) to {target} [/green]")
620IMPORT_IN_STR = re.compile(r'File "<string>", line (\d+), in <module>')
623def _handle_import_error(code: str, error: ImportError) -> str:
624 # error is deeper in a package, find the related import in the code:
625 tb_lines = traceback.format_exc().splitlines()
627 for line in tb_lines:
628 if matches := IMPORT_IN_STR.findall(line):
629 # 'File "<string>", line 15, in <module>'
630 line_no = int(matches[0]) - 1
631 lines = code.split("\n")
632 return lines[line_no]
634 # I don't know how to trigger this case:
635 raise ValueError("Faulty import could not be automatically deleted") from error # pragma: no cover
638MISSING_RELATIONSHIP = re.compile(r"Cannot resolve reference (\w+) in \w+ definition")
641def _handle_relation_error(error: KeyError) -> tuple[str, str]:
642 if not (table := MISSING_RELATIONSHIP.findall(str(error))):
643 # other error, raise again
644 raise error
646 t = table[0]
648 return (
649 t,
650 """
651 db.define_table('%s', redefine=True)
652 """
653 % t,
654 )
def _render_migration_script(
    to_execute: string.Template,
    tables: Optional[list[str] | list[list[str]]],
    db_type: Optional[str],
    extra_code: str,
    code_before: str,
    code_after: str,
) -> str:
    """
    Fill the migration script template with the current (possibly auto-patched) user code.

    Args:
        to_execute: The parsed TEMPLATE_PYDAL / TEMPLATE_TYPEDAL template.
        tables: Table names to restrict the migration to (flattened into the script).
        db_type: Database dialect name; None becomes an empty string.
        extra_code: Code injected at each `$extra` placeholder (magic variables, stub tables).
        code_before: Code describing the old state of the database.
        code_after: Code describing the new state of the database.

    Returns:
        str: Python source ready to be passed to `exec`.
    """
    return to_execute.substitute(
        {
            "tables": flatten(tables or []),
            "db_type": db_type or "",
            "extra": textwrap.dedent(extra_code),
            "code_before": textwrap.dedent(code_before),
            "code_after": textwrap.dedent(code_after),
        }
    )


def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: Optional[bool] = False,
    noop: Optional[bool] = False,
    magic: Optional[bool] = False,
    function_name: Optional[str | tuple[str, ...]] = "define_tables",
    use_typedal: bool | typing.Literal["auto"] = "auto",
    output_format: SUPPORTED_OUTPUT_FORMATS = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path | io.StringIO] = None,
) -> bool:
    """
    Handle user input for generating SQL migration statements based on before and after code.

    The code is embedded in a migration script template and executed against a dummy DAL.
    Common failures trigger an automatic fix and a retry, at most MAX_RETRIES times:
    - no tables found: call the requested define-tables function(s);
    - NameError: inject placeholder variables (only with `magic`);
    - ImportError: remove the failing import;
    - unresolved table reference: define a stub table.

    Args:
        code_before (str): The code representing the state of the database before the change.
        code_after (str): The code representing the state of the database after the change.
        db_type (str, optional): The type of the database (e.g., "postgres", "mysql", etc.). Defaults to None.
        tables (list[str] or list[list[str]], optional): The list of tables to generate SQL for. Defaults to None.
        verbose (bool, optional): If True, print the generated code. Defaults to False.
        noop (bool, optional): If True, only print the generated code but do not execute it. Defaults to False.
        magic (bool, optional): If True, automatically add missing variables for execution. Defaults to False.
        function_name (str or tuple[str, ...], optional): Name(s) of the function(s) where the tables are defined.
            Defaults to "define_tables".
        use_typedal: replace pydal imports with TypeDAL? "auto" detects "typedal" in the code.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?

    # todo: prefix (e.g. public.)

    Returns:
        bool: True if SQL migration statements are generated and executed successfully, False otherwise.
    """
    # todo: better typedal checking
    if use_typedal == "auto":
        use_typedal = "typedal" in code_before.lower() or "typedal" in code_after.lower()

    if function_name:
        define_table_functions: set[str] = set(function_name) if isinstance(function_name, tuple) else {function_name}
    else:
        define_table_functions = set()

    # Human-readable list of the searched functions for the error messages below. (These messages
    # used to interpolate `function_name` directly, i.e. a raw tuple such as "()" or a loop leftover.)
    searched_functions = ", ".join(sorted(define_table_functions))

    template = TEMPLATE_TYPEDAL if use_typedal else TEMPLATE_PYDAL

    to_execute = string.Template(textwrap.dedent(template))

    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)
    extra_code = ""

    generated_code = _render_migration_script(to_execute, tables, db_type, extra_code, code_before, code_after)
    if verbose or noop:
        rich.print(generated_code, file=sys.stderr)

    if noop:
        # done
        return True

    err: typing.Optional[Exception] = None
    catch: dict[str, Any] = {}
    retry_counter = MAX_RETRIES

    # names injected into the exec globals below; never report them as missing user variables
    magic_vars = {"_file", "DummyDAL", "_special_tables"}
    # tables the template leaves out of the automatic "all tables" selection
    special_tables: set[str] = {"typedal_cache", "typedal_cache_dependency"} if use_typedal else set()

    while retry_counter:
        retry_counter -= 1
        try:
            if verbose:
                rich.print(generated_code, file=sys.stderr)

            # 'catch' is used to add and receive globals from the exec scope.
            # another argument could be added for locals, but adding simply {} changes the behavior negatively.
            # so for now, only globals is passed.
            catch["_file"] = io.StringIO()  # <- every print should go to this file, so we can handle it afterwards
            catch["DummyDAL"] = DummyTypeDAL if use_typedal else DummyDAL  # <- fake DAL that doesn't run queries
            catch["_special_tables"] = special_tables  # <- e.g. typedal_cache, auth_user
            # note: when adding something to 'catch', also add it to magic_vars!!!

            exec(generated_code, catch)  # nosec: B102
            _handle_output(catch["_file"], output_file, output_format, is_typedal=use_typedal)
            return True  # success!
        except ValueError as e:
            if str(e) != "no-tables-found":  # pragma: no cover
                rich.print(f"[yellow]{e}[/yellow]", file=sys.stderr)
                return False

            # No tables at the top level: call every requested define-tables function present in the code.
            any_found = False
            for define_function in define_table_functions:
                if find_function_to_call(generated_code, define_function) is not None:
                    generated_code = add_function_call(generated_code, define_function, multiple=True)
                    any_found = True

            if any_found:
                # hurray! retry with the call(s) appended
                continue

            # else: no define_tables or other method to use found.
            if searched_functions:
                print(f"No tables found in the top-level or {searched_functions} function!", file=sys.stderr)
            else:
                print("No tables found in the top-level!", file=sys.stderr)

            if use_typedal:
                print(
                    "Please use `db.define` or `database.define`, "
                    "or if you really need to use an alias like my_db.define, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
            else:
                print(
                    "Please use `db.define_table` or `database.define_table`, "
                    "or if you really need to use an alias like my_db.define_tables, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )

            if searched_functions:
                print(
                    f"You can also specify a --function to use something else than {searched_functions}.",
                    file=sys.stderr,
                )
            else:
                print("You can also specify a --function to look for tables inside a function.", file=sys.stderr)

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code) - magic_vars
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            # postponed: this can possibly also be achieved by updating the 'catch' dict
            # instead of injecting in the string.
            extra_code = extra_code + "\n" + textwrap.dedent(generate_magic_code(missing_vars))

            code_before = remove_if_falsey_blocks(code_before)
            code_after = remove_if_falsey_blocks(code_after)

            generated_code = _render_migration_script(to_execute, tables, db_type, extra_code, code_before, code_after)
        except ImportError as e:
            # should include ModuleNotFoundError
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            if not e.path:
                # code exists in code itself
                code_before = remove_import(code_before, e.name or "")
                code_after = remove_import(code_after, e.name or "")
            else:
                to_remove = _handle_import_error(generated_code, e)
                code_before = code_before.replace(to_remove, "\n")
                code_after = code_after.replace(to_remove, "\n")

            generated_code = _render_migration_script(to_execute, tables, db_type, extra_code, code_before, code_after)

        except KeyError as e:
            err = e
            # re-raises `e` unchanged when it is not an unresolved-reference error
            table_name, table_definition = _handle_relation_error(e)
            # the stub table must not end up in the generated migration itself
            special_tables.add(table_name)
            extra_code = extra_code + "\n" + textwrap.dedent(table_definition)

            generated_code = _render_migration_script(to_execute, tables, db_type, extra_code, code_before, code_after)
        except Exception as e:
            err = e
            # otherwise: give up
            retry_counter = 0
        finally:
            # reset: NOTE(review) presumably guards against the exec'd code having flipped
            # typing.TYPE_CHECKING for the rest of the process.
            typing.TYPE_CHECKING = False

        if retry_counter < 1:  # pragma: no cover
            rich.print(f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'}", file=sys.stderr)
            return False

    # idk when this would happen, but something definitely went wrong here:
    return False  # pragma: no cover
def core_create(
    filename: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
) -> bool:
    """
    Generates SQL migration statements for creating one or more tables, based on the code in a given source file.

    Args:
        filename: The filename of the source file to parse. This code represents the final state of the database.
            May carry a version (`models.py@main`) and/or a function (`models.py:define_tables`).
            Without a filename, the code is read from stdin.
        tables: A list of table names to generate SQL for.
            If None, the function will attempt to process all tables found in the code.
        db_type: The type of the database. If None, an empty db_type is passed on to the generated script.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: The name(s) of the function(s) where the tables are defined.
            If None (and none is given via `filename`), 'define_tables' is used.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?

    Returns:
        bool: True if SQL migration statements are generated and (if not in noop mode) executed successfully,
            False otherwise.

    Raises:
        FileNotFoundError: If the source file cannot be found.
    """
    git_root = find_git_root() or Path(os.getcwd())

    functions: set[str] = set()
    if function:  # pragma: no cover
        if isinstance(function, tuple):
            functions.update(function)
        else:
            functions.add(function)

    if filename and ":" in filename:
        # e.g. models.py:define_tables
        filename, _function = filename.split(":", 1)
        functions.add(_function)

    file_version, file_path = extract_file_version_and_path(
        filename, default_version="current" if filename else "stdin"
    )
    file_exists, file_absolute_path = get_absolute_path_info(file_path, file_version, git_root)

    if not file_exists:
        raise FileNotFoundError(f"Source file {filename} could not be found.")

    text = get_file_for_version(file_absolute_path, file_version, prompt_description="table definition")

    return handle_cli(
        "",
        text,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        # As documented: fall back to `define_tables` when no function was requested. An empty
        # tuple would disable the function search in handle_cli altogether.
        function_name=tuple(functions) or "define_tables",
        output_format=output_format,
        output_file=output_file,
    )
def core_alter(
    filename_before: Optional[str] = None,
    filename_after: Optional[str] = None,
    tables: Optional[list[str]] = None,
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    magic: bool = False,
    noop: bool = False,
    verbose: bool = False,
    function: Optional[str | tuple[str, ...]] = None,
    output_format: Optional[SUPPORTED_OUTPUT_FORMATS] = DEFAULT_OUTPUT_FORMAT,
    output_file: Optional[str | Path] = None,
) -> bool:
    """
    Generates SQL migration statements for altering the database, based on the code in two given source files.

    If the code on either side is empty, or both sides are identical, this falls back to `core_create`
    for the after (or before) file.

    Args:
        filename_before: The filename of the source file before changes.
            This code represents the initial state of the database.
        filename_after: The filename of the source file after changes.
            This code represents the final state of the database.
        tables: A list of table names to generate SQL for.
            If None, the function will attempt to process all tables found in the code.
        db_type: The type of the database. If None, an empty db_type is passed on to the generated script.
        magic: If True, automatically add missing variables for execution.
        noop: If True, only print the generated code but do not execute it.
        verbose: If True, print the generated code and additional debug information.
        function: The name(s) of the function(s) where the tables are defined.
            If None (and none is given via the filenames), 'define_tables' is used.
        output_format: defaults to just SQL, edwh-migrate migration syntax also supported
        output_file: append the output to a file instead of printing it?

    Returns:
        bool: True if SQL migration statements are generated and (if not in noop mode) executed successfully,
            False otherwise.

    Raises:
        FileNotFoundError: If either of the source files cannot be found.
        ValueError: If no file name was supplied at all.
    """
    git_root = find_git_root(filename_before) or find_git_root(filename_after)

    functions: set[str] = set()
    if function:  # pragma: no cover
        if isinstance(function, tuple):
            functions.update(function)
        else:
            functions.add(function)

    if filename_before and ":" in filename_before:
        # e.g. models.py:define_tables
        filename_before, _function = filename_before.split(":", 1)
        functions.add(_function)

    if filename_after and ":" in filename_after:
        # e.g. models.py:define_tables
        filename_after, _function = filename_after.split(":", 1)
        functions.add(_function)

    before, after = extract_file_versions_and_paths(filename_before, filename_after)

    version_before, filename_before = before
    version_after, filename_after = after

    # either ./file exists or /file exists (seen from git root):

    before_exists, before_absolute_path = get_absolute_path_info(filename_before, version_before, git_root)
    after_exists, after_absolute_path = get_absolute_path_info(filename_after, version_after, git_root)

    if not (before_exists and after_exists):
        message = ""
        message += "" if before_exists else f"Path {filename_before} does not exist! "
        if filename_before != filename_after:
            message += "" if after_exists else f"Path {filename_after} does not exist!"
        raise FileNotFoundError(message)

    try:
        code_before = get_file_for_version(
            before_absolute_path,
            version_before,
            prompt_description="current table definition",
            with_git=git_root is not None,
        )
        code_after = get_file_for_version(
            after_absolute_path,
            version_after,
            prompt_description="desired table definition",
            with_git=git_root is not None,
        )

        if not (code_before and code_after):
            message = ""
            message += "" if code_before else "Before code is empty (Maybe try `pydal2sql create`)! "
            message += "" if code_after else "After code is empty! "
            raise ValueError(message)

        if code_before == code_after:
            raise ValueError("Both contain the same code!")

    except ValueError as e:
        rich.print(f"[yellow] alter failed ({e}), trying create! [/yellow]", file=sys.stderr)
        try:
            return core_create(
                filename_after or filename_before,
                tables,
                db_type,
                magic,
                noop,
                verbose,
                tuple(functions),
                output_format,
                output_file,
            )
        except Exception:  # pragma: no cover
            return False

    return handle_cli(
        code_before,
        code_after,
        db_type=db_type,
        tables=tables,
        verbose=verbose,
        noop=noop,
        magic=magic,
        # As documented: fall back to `define_tables` when no function was requested. An empty
        # tuple would disable the function search in handle_cli altogether.
        function_name=tuple(functions) or "define_tables",
        output_format=output_format,
        output_file=output_file,
    )