Coverage for src/pydal2sql_core/cli_support.py: 100%
152 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-31 22:49 +0200
1"""
2CLI-Agnostic support.
3"""
4import contextlib
5import io
6import os
7import select
8import string
9import sys
10import textwrap
11import typing
12from pathlib import Path
13from typing import Optional
15import rich
16from black.files import find_project_root
17from git.objects.blob import Blob
18from git.objects.commit import Commit
19from git.repo import Repo
20from witchery import (
21 add_function_call,
22 find_defined_variables,
23 find_function_to_call,
24 find_missing_variables,
25 generate_magic_code,
26 has_local_imports,
27 remove_import,
28 remove_local_imports,
29 remove_specific_variables,
30)
32from .helpers import flatten
def has_stdin_data() -> bool:  # pragma: no cover
    """
    Report whether stdin already has data waiting (input piped with `|` or redirected with `<`).

    A zero-timeout `select` is used, so the check returns immediately instead of waiting for input.
    NOTE(review): `select` on a non-socket such as `sys.stdin` is only supported on POSIX platforms.

    Returns:
        bool: True when stdin is readable right now, False otherwise.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    readable, _writable, _errored = select.select([sys.stdin], [], [], 0.0)
    return any(readable)
# Alias for "any callable with any signature"; used to give `rich.print` and the builtin `print`
# one common type when casting between them (keeps mypy happy).
AnyCallable = typing.Callable[..., typing.Any]
def print_if_interactive(*args: typing.Any, pretty: bool = True, **kwargs: typing.Any) -> None:  # pragma: no cover
    """
    Print to stderr, but only when stdin is interactive (nothing piped or redirected into the program).

    Args:
        *args: Positional values forwarded to the print function.
        pretty (bool): Use `rich.print` when True, the builtin `print` otherwise.
        **kwargs: Extra keyword arguments for the print function; any `file` passed here is
            replaced by `sys.stderr`.

    Returns:
        None
    """
    if has_stdin_data():
        # stdin carries piped/redirected data: print nothing
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    printer(*args, **{**kwargs, "file": sys.stderr})
def find_git_root(at: Optional[str] = None) -> Optional[Path]:
    """
    Find the root directory of the Git repository.

    Starting at `at` (or the current working directory), walk up the directory tree and return the
    first directory that contains a `.git` entry (a directory, or a file in the case of worktrees and
    submodules). The search does not stop at an intermediate `pyproject.toml` or `.hg` directory, so a
    package nested inside a larger repository still resolves to the git root.

    Args:
        at (str, optional): File or directory path to start the search from. Relative paths are
            resolved against the current working directory. Defaults to the current working directory.

    Returns:
        Optional[Path]: The (resolved) root directory of the Git repository if found, otherwise None.
    """
    start = Path(at or os.getcwd()).resolve()
    # `start` itself is probed first; for a file path the `.git` probe simply fails and the walk moves up.
    for directory in (start, *start.parents):
        if (directory / ".git").exists():
            return directory
    return None
def find_git_repo(repo: Optional[Repo] = None, at: Optional[str] = None) -> Repo:
    """
    Find the Git repository instance.

    Args:
        repo (Repo, optional): An existing Git repository instance. If provided, returns the same instance.
        at (str, optional): The path to start the search from. Defaults to the current working directory.

    Returns:
        Repo: The Git repository instance.

    Raises:
        git.exc.NoSuchPathError: If no git repository root can be found (this is also the exception
            `Repo` itself raises for a path that does not exist).
    """
    if repo is not None:
        return repo

    root = find_git_root(at)
    if root is None:
        # Without this guard `str(root)` would be the literal path "None".
        from git.exc import NoSuchPathError

        raise NoSuchPathError(f"No git repository found at or above {at or os.getcwd()}")
    return Repo(str(root))
def latest_commit(repo: Optional[Repo] = None) -> Commit:
    """
    Get the commit that HEAD currently points to.

    Args:
        repo (Repo, optional): An existing Git repository instance. If omitted, the repository is
            located from the current working directory.

    Returns:
        Commit: The commit referenced by HEAD.
    """
    repo = find_git_repo(repo)
    return repo.head.commit
def commit_by_id(commit_hash: str, repo: Optional[Repo] = None) -> Commit:
    """
    Get a specific commit in the Git repository by its hash or name.

    Args:
        commit_hash (str): Any revision `Repo.commit` accepts: a commit hash, a branch name, a tag, ...
        repo (Repo, optional): An existing Git repository instance. If omitted, the repository is
            located from the current working directory.

    Returns:
        Commit: The commit object corresponding to the given revision.
    """
    repo = find_git_repo(repo)
    return repo.commit(commit_hash)
@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Expose the contents of a Git Blob as an in-memory binary stream.

    The blob's data stream is read completely up front and wrapped in a `BytesIO`.

    Args:
        file (Blob): The Git Blob object to read.

    Yields:
        io.BytesIO: An in-memory buffer holding the blob's bytes.
    """
    raw_bytes = file.data_stream.read()
    buffer = io.BytesIO(raw_bytes)
    yield buffer
def read_blob(file: Blob) -> str:
    """
    Return the text stored in a Git Blob.

    Args:
        file (Blob): The Git Blob object to read.

    Returns:
        str: The blob's bytes decoded with `bytes.decode()`'s default codec (UTF-8).
    """
    with open_blob(file) as stream:
        raw = stream.read()
    return raw.decode()
def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Optional[Repo] = None) -> str:
    """
    Get the contents of a file in the Git repository at a specific commit version.

    Args:
        filename (str): The path of the file to retrieve (relative to the current working directory or absolute).
        commit_version (str, optional): The commit hash or branch name. Defaults to "latest" (HEAD's commit).
        repo (Repo, optional): An existing Git repository instance. If omitted, it is located from `filename`.

    Returns:
        str: The contents of the file at that commit, decoded as text.
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    file_path = Path(filename).resolve()
    # Resolve the working dir too, so symlinked checkouts (e.g. /tmp -> /private/tmp) still match.
    working_dir = Path(repo.working_dir).resolve()
    try:
        # Git tree paths are relative to the working tree root and always '/'-separated.
        relative_file_path = file_path.relative_to(working_dir).as_posix()
    except ValueError:
        # The file lies outside the working tree: pass the absolute path through unchanged,
        # so the tree lookup below fails for it.
        relative_file_path = str(file_path)

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)
def get_file_for_version(filename: str, version: str, prompt_description: str = "") -> str:
    """
    Get the contents of a file based on the version specified.

    Args:
        filename (str): The path of the file to retrieve.
        version (str): "current" reads the file from disk, "stdin" reads standard input,
            anything else is passed to `get_file_for_commit` as a commit hash/branch name (or "latest").
        prompt_description (str, optional): Inserted into the prompt printed before reading stdin.

    Returns:
        str: The contents of the file as a string.
    """
    if version == "current":
        # Decode as UTF-8 (the codec `read_blob` uses for git versions) rather than the locale's
        # preferred encoding, so "before" and "after" versions of the same file decode identically.
        return Path(filename).read_text(encoding="utf-8")
    elif version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        result = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return result
    else:
        return get_file_for_commit(filename, version)
223def extract_file_version_and_path(
224 file_path_or_git_tag: Optional[str], default_version: str = "stdin"
225) -> tuple[str, str | None]:
226 """
227 Extract the file version and path from the given input.
229 Args:
230 file_path_or_git_tag (str, optional): The input string containing the file path and/or Git tag.
231 default_version (str, optional): The default version to use if no version is specified. Defaults to "stdin".
233 Returns:
234 tuple[str, str | None]: A tuple containing the extracted version and file path (or None if not specified).
236 Examples:
237 myfile.py (implies @current)
239 myfile.py@latest
240 myfile.py@my-branch
241 myfile.py@b3f24091a9
243 @latest (implies no path, e.g. in case of ALTER to copy previously defined path)
244 """
245 if not file_path_or_git_tag:
246 return default_version, ""
248 if file_path_or_git_tag == "-":
249 return "stdin", "-"
251 if file_path_or_git_tag.startswith("@"):
252 file_version = file_path_or_git_tag.strip("@")
253 file_path = None
254 elif "@" in file_path_or_git_tag:
255 file_path, file_version = file_path_or_git_tag.split("@")
256 else:
257 file_version = default_version # `latest` for before; `current` for after.
258 file_path = file_path_or_git_tag
260 return file_version, file_path
def extract_file_versions_and_paths(
    filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Work out the (version, path) pair for both the "before" and the "after" side of a comparison.

    When two different file names are given, the "before" side defaults to the working copy
    ("current"); otherwise it defaults to the last commit ("latest"). The "after" side always
    defaults to "current". A side without a path borrows the path of the other side.

    Args:
        filename_before (str, optional): The "before" file spec (path and/or @version), or None.
        filename_after (str, optional): The "after" file spec (path and/or @version), or None.

    Returns:
        tuple[tuple[str, str | None], tuple[str, str | None]]:
            ((version_before, path_before), (version_after, path_after)).

    Raises:
        ValueError: If neither side yields a file path.
    """
    comparing_two_files = bool(filename_after and filename_before and filename_after != filename_before)
    before_default = "current" if comparing_two_files else "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=before_default)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    if not filepath_after:
        filepath_after = filepath_before
    elif not filepath_before:
        filepath_before = filepath_after

    return (version_before, filepath_before), (version_after, filepath_after)
def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Locate a file on disk and report whether it exists, together with its absolute path.

    The path is first tried as given (relative to the working directory), then relative to the Git root.

    Args:
        filename (str, optional): The path of the file to check (or None).
        version (str): The version specifier, which can be "stdin", "current", or a commit hash/branch name.
        git_root (Path, optional): The root directory of the Git repository. When None, `find_git_root()`
            is used, falling back to the current working directory.

    Returns:
        tuple[bool, str]: (exists, absolute path). "stdin" always gives (True, ""); a missing
        filename or a file that cannot be found gives (False, "").
    """
    if version == "stdin":
        return True, ""
    if filename is None:
        # not stdin and no file: report it as missing so a "file missing" error is shown later.
        return False, ""

    root = git_root if git_root is not None else (find_git_root() or Path(os.getcwd()))

    for candidate in (Path(filename), root / filename):
        if candidate.exists():
            return True, str(candidate.resolve())
    return False, ""
def ensure_no_migrate_on_real_db(
    code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: typing.Optional[bool] = False
) -> str:
    """
    Ensure that the code does not contain actual migrations on a real database.

    It does this by removing definitions of 'db' and 'database'. This can be changed by customizing `db_names`.
    It also removes local imports to prevent irrelevant code being executed.

    Args:
        code (str): The code to check for database migrations.
        db_names (Iterable[str], optional): Names of variables representing the database.
            Any iterable is accepted (it is only consumed once). Defaults to ("db", "database").
        fix (bool, optional): If True, remove the offending definitions/imports instead of raising.
            Defaults to False.

    Returns:
        str: The code with database definitions and local imports removed if fix=True;
        otherwise the unchanged code (when nothing offending was found).

    Raises:
        ValueError: If fix is falsy and the code defines one of `db_names` or uses local imports.
    """
    # Materialize once: `db_names` may be a one-shot iterator, but it is consulted more than once below.
    db_names = tuple(db_names)
    variables = find_defined_variables(code)

    # Ordered and de-duplicated (in db_names order), so the error message below is deterministic.
    found_variables = list(dict.fromkeys(name for name in db_names if name in variables))

    if found_variables:
        if fix:
            # one call removes every name in db_names; no need to repeat it per found name
            code = remove_specific_variables(code, db_names)
        else:
            if len(found_variables) == 1:
                var = found_variables[0]
                message = f"Variable {var} defined in code! "
            else:  # pragma: no cover
                var = ", ".join(found_variables)
                message = f"Variables {var} defined in code! "
            raise ValueError(
                f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
            )

    if has_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code
# Maximum number of generate/exec attempts `handle_cli` makes while trying to repair the generated code.
MAX_RETRIES = 20


# todo: overload more methods
def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: Optional[bool] = False,
    noop: Optional[bool] = False,
    magic: Optional[bool] = False,
    function_name: Optional[str] = "define_tables",
) -> bool:
    """
    Handle user input for generating SQL migration statements based on before and after code.

    The before/after code is pasted into a Python template (`to_execute`). The template defines the tables on
    `DAL(None, migrate=False)` instances and prints, per table, the SQL produced by `generate_sql` (or a
    DROP TABLE for tables that only exist in the "before" code). The template is then run with `exec`, for up
    to `MAX_RETRIES` attempts. Between attempts the code is repaired:
    - a call to `function_name` is appended when no tables were found;
    - with `magic`, code from `generate_magic_code` is added for missing variables;
    - imports that fail are removed.

    Args:
        code_before (str): The code representing the state of the database before the change.
        code_after (str): The code representing the state of the database after the change.
        db_type (str, optional): The type of the database (e.g., "postgres", "mysql", etc.). Defaults to None.
        tables (list[str] or list[list[str]], optional): The tables to generate SQL for. Nested lists are
            flattened; when empty/None, every table defined in either version is used. Defaults to None.
        verbose (bool, optional): If True, print the generated code to stderr. Defaults to False.
        noop (bool, optional): If True, only print the generated code but do not execute it. Defaults to False.
        magic (bool, optional): If True, strip real database definitions and local imports, and add code for
            missing variables so the generated code can run. Defaults to False.
        function_name (str, optional): The name of the function where the tables are defined.
            Defaults to "define_tables".

    Returns:
        bool: True if the generated code ran successfully (or `noop` was set), False otherwise.

    Raises:
        ValueError: From `ensure_no_migrate_on_real_db` when `magic` is off and the code defines a database
            variable or uses local imports, or when the generated code raises a ValueError other than
            'no-tables-found'.
    """
    # todo: prefix (e.g. public.)

    # Template for the script that is exec'd. `$tables`, `$db_type`, `$extra`, `$code_before` and
    # `$code_after` are filled in by `string.Template.substitute` below. The SQL is printed to stdout.
    to_execute = string.Template(
        textwrap.dedent(
            """
from pydal import *
from pydal.objects import *
from pydal.validators import *

from pydal2sql import generate_sql


from pydal import DAL
db = database = DAL(None, migrate=False)

tables = $tables
db_type = '$db_type'

$extra

$code_before

db_old = db
db_new = db = database = DAL(None, migrate=False)

$code_after

if not tables:
    tables = set(db_old._tables + db_new._tables)

if not tables:
    raise ValueError('no-tables-found')


for table in tables:
    print('--', table)
    if table in db_old and table in db_new:
        print(generate_sql(db_old[table], db_new[table], db_type=db_type))
    elif table in db_old:
        print(f'DROP TABLE {table};')
    else:
        print(generate_sql(db_new[table], db_type=db_type))
            """
        )
    )

    # Without `magic` these raise ValueError on real db definitions / local imports; with it they strip them.
    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

    generated_code = to_execute.substitute(
        {
            "tables": flatten(tables or []),
            "db_type": db_type or "",
            "code_before": textwrap.dedent(code_before),
            "code_after": textwrap.dedent(code_after),
            "extra": "",
        }
    )
    # NOTE(review): with verbose=True the code is printed here and again at the top of the first attempt below.
    if verbose or noop:
        rich.print(generated_code, file=sys.stderr)

    if not noop:
        err: typing.Optional[Exception] = None  # last handled error, reported if all retries are used up
        catch = {}
        retry_counter = MAX_RETRIES
        while retry_counter > 0:
            retry_counter -= 1
            try:
                if verbose:
                    rich.print(generated_code, file=sys.stderr)

                # 'catch' is used to add and receive globals from the exec scope.
                # another argument could be added for locals, but adding simply {} changes the behavior negatively.
                # so for now, only globals is passed.
                # NOTE(review): this executes the user's (file/stdin) code with full privileges.
                exec(generated_code, catch)  # nosec: B102
                return True  # success!
            except ValueError as e:
                err = e

                # only the template's own 'no-tables-found' signal is handled; anything else propagates
                if str(e) != "no-tables-found":  # pragma: no cover
                    raise e

                if function_name:
                    define_tables = find_function_to_call(generated_code, function_name)

                    # if define_tables function is found, add call to it at end of code
                    if define_tables is not None:
                        generated_code = add_function_call(generated_code, function_name, multiple=True)
                        continue

                # else: no define_tables or other method to use found.

                print(f"No tables found in the top-level or {function_name} function!", file=sys.stderr)
                print(
                    "Please use `db.define_table` or `database.define_table`, "
                    "or if you really need to use an alias like my_db.define_tables, "
                    "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                    file=sys.stderr,
                )
                print(f"You can also specify a --function to use something else than {function_name}.", file=sys.stderr)

                return False

            except NameError as e:
                err = e
                # something is missing!
                missing_vars = find_missing_variables(generated_code)
                if not magic:
                    rich.print(
                        f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                        file=sys.stderr,
                    )
                    return False

                # postponed: this can possibly also be achieved by updating 'catch'.
                extra_code = generate_magic_code(missing_vars)

                # Rebuilt from the template, so edits made in earlier attempts (an appended function call,
                # removed imports) are dropped here and have to be re-applied by later attempts.
                generated_code = to_execute.substitute(
                    {
                        "tables": flatten(tables or []),
                        "db_type": db_type or "",
                        "extra": textwrap.dedent(extra_code),
                        "code_before": textwrap.dedent(code_before),
                        "code_after": textwrap.dedent(code_after),
                    }
                )
            except ImportError as e:
                err = e
                # if we catch an ImportError, we try to remove the import and retry
                generated_code = remove_import(generated_code, e.name)

        # Every attempt failed: the loop only ends without returning when the counter is exhausted.
        if retry_counter < 1:  # pragma: no cover
            rich.print(f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'}", file=sys.stderr)
            return False

    return True