Coverage for src/pydal2sql/cli_support.py: 97%
154 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-31 16:22 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-31 16:22 +0200
1"""
2CLI-Agnostic support.
3"""
4import contextlib
5import io
6import os
7import select
8import string
9import sys
10import textwrap
11import typing
12from pathlib import Path
13from typing import Optional
15import rich
16from black.files import find_project_root
17from git.objects.blob import Blob
18from git.objects.commit import Commit
19from git.repo import Repo
21from .helpers import flatten
22from .magic import (
23 find_defined_variables,
24 find_local_imports,
25 find_missing_variables,
26 generate_magic_code,
27 remove_import,
28 remove_local_imports,
29 remove_specific_variables, find_function_to_call, add_function_call,
30)
def has_stdin_data() -> bool:  # pragma: no cover
    """
    Tell whether input is already waiting on stdin (the program was fed through a pipe `|` or a redirect `<`).

    stdin is polled with a zero timeout, so this call never blocks.

    See Also:
        https://stackoverflow.com/questions/3762881/how-do-i-check-if-stdin-has-some-data
    """
    readable, _writable, _exceptional = select.select([sys.stdin], [], [], 0.0)
    return any(readable)
52AnyCallable = typing.Callable[..., typing.Any]
def print_if_interactive(*args: typing.Any, pretty: bool = True, **kwargs: typing.Any) -> None:  # pragma: no cover
    """
    Print to stderr, but only when no data is waiting on stdin.

    Args:
        *args: values to print.
        pretty: use `rich.print` (markup-aware) when True, the builtin `print` otherwise.
        **kwargs: forwarded to the print function; any `file` given here is replaced by `sys.stderr`.
    """
    if has_stdin_data():
        # input is being piped in: stay quiet.
        return

    printer = typing.cast(AnyCallable, rich.print if pretty else print)  # make mypy happy
    printer(*args, **{**kwargs, "file": sys.stderr})
def find_git_root(at: Optional[str] = None) -> Optional[Path]:
    """
    Find the root folder of the git repository containing `at`.

    Args:
        at: path to start searching from; defaults to the current working directory.

    Returns:
        The repository root, or None when black's `find_project_root` identified the project
        root by something other than a `.git` directory.
    """
    folder, reason = find_project_root((at or os.getcwd(),))
    if reason != ".git directory":
        return None
    return folder
def find_git_repo(repo: Optional[Repo] = None, at: Optional[str] = None) -> Repo:
    """
    Return `repo` when given, otherwise open the git repository that contains `at`.

    Args:
        repo: an already-opened repository; returned unchanged when truthy.
        at: path to search from; defaults to the current working directory.

    Returns:
        The repository to use.

    Raises:
        git.exc.InvalidGitRepositoryError: when no enclosing git repository can be found.
    """
    from git.exc import InvalidGitRepositoryError

    if repo:
        return repo

    root = find_git_root(at)
    if root is None:
        # Repo(str(None)) would look for a directory literally named "None"; fail clearly instead.
        raise InvalidGitRepositoryError(f"No git repository found at {at or os.getcwd()!r}.")
    return Repo(str(root))
def latest_commit(repo: Optional[Repo] = None) -> Commit:
    """
    Return the commit HEAD points to.

    Args:
        repo: repository to inspect; when omitted, the repository of the current working directory is used.
    """
    repo = find_git_repo(repo)
    return repo.head.commit
def commit_by_id(commit_hash: str, repo: Optional[Repo] = None) -> Commit:
    """
    Look up a commit by revision (e.g. a commit hash or a branch name).

    Args:
        commit_hash: revision passed to `Repo.commit`.
        repo: repository to inspect; when omitted, the repository of the current working directory is used.
    """
    repo = find_git_repo(repo)
    return repo.commit(commit_hash)
@contextlib.contextmanager
def open_blob(file: Blob) -> typing.Generator[io.BytesIO, None, None]:
    """
    Open a git blob as an in-memory binary file.

    The whole blob is read up front. The buffer is closed when the `with` block exits.
    """
    with io.BytesIO(file.data_stream.read()) as buffer:
        yield buffer
def read_blob(file: Blob) -> str:
    """Return the content of a git blob, decoded as UTF-8 text."""
    with open_blob(file) as handle:
        raw = handle.read()
    return raw.decode()
def get_file_for_commit(filename: str, commit_version: str = "latest", repo: Optional[Repo] = None) -> str:
    """
    Read `filename` as it is stored in a given commit.

    Args:
        filename: path to the file (absolute, or relative to the current working directory).
        commit_version: "latest" for the commit HEAD points to; otherwise a revision such as a
            commit hash or branch name.
        repo: repository to use; when omitted, the repository containing `filename` is used.

    Returns:
        The file content at that commit, decoded to text.

    Raises:
        Whatever GitPython raises when the path is not in that commit's tree (presumably KeyError).
    """
    repo = find_git_repo(repo, at=filename)
    commit = latest_commit(repo) if commit_version == "latest" else commit_by_id(commit_version, repo)

    file_path = Path(filename).resolve()
    repo_root = Path(str(repo.working_dir)).resolve()
    try:
        # Tree paths are relative to the repository root and always use forward slashes.
        # Resolving both sides also copes with symlinked checkouts (a plain string prefix does not).
        relative_file_path = file_path.relative_to(repo_root).as_posix()
    except ValueError:
        # The file lies outside the repository: keep the old behaviour and let the tree lookup fail.
        relative_file_path = file_path.as_posix()

    file_at_commit = commit.tree / relative_file_path
    return read_blob(file_at_commit)
def get_file_for_version(filename: str, version: str, prompt_description: str = "") -> str:
    """
    Load the source of `filename` for a given version.

    Args:
        filename: path of the file to load.
        version: "current" reads the file from disk, "stdin" reads pasted/piped code from stdin,
            anything else is treated as a git revision ("latest", a branch, a commit hash, ...).
        prompt_description: shown in the interactive stdin prompt to tell the user which code to paste.

    Returns:
        The requested source code.
    """
    if version == "current":
        # Read as UTF-8 (Python's default source encoding, and what the git path decodes with)
        # instead of the platform-dependent locale encoding.
        return Path(filename).read_text(encoding="utf-8")
    elif version == "stdin":  # pragma: no cover
        print_if_interactive(
            f"[blue]Please paste your define tables ({prompt_description}) code below "
            f"and press ctrl-D when finished.[/blue]",
            file=sys.stderr,
        )
        result = sys.stdin.read()
        print_if_interactive("[blue]---[/blue]", file=sys.stderr)
        return result
    else:
        return get_file_for_commit(filename, version)
129def extract_file_version_and_path(
130 file_path_or_git_tag: Optional[str], default_version: str = "stdin"
131) -> tuple[str, str | None]:
132 """
134 Examples:
135 myfile.py (implies @current)
137 myfile.py@latest
138 myfile.py@my-branch
139 myfile.py@b3f24091a9
141 @latest (implies no path, e.g. in case of ALTER to copy previously defined path)
142 """
143 if not file_path_or_git_tag:
144 return default_version, ""
146 if file_path_or_git_tag == "-":
147 return "stdin", "-"
149 if file_path_or_git_tag.startswith("@"):
150 file_version = file_path_or_git_tag.strip("@")
151 file_path = None
152 elif "@" in file_path_or_git_tag:
153 file_path, file_version = file_path_or_git_tag.split("@")
154 else:
155 file_version = default_version # `latest` for before; `current` for after.
156 file_path = file_path_or_git_tag
158 return file_version, file_path
def extract_file_versions_and_paths(
    filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
    """
    Parse the 'before' and 'after' CLI arguments into (version, path) pairs.

    'before' defaults to the current version when both arguments are given and differ, and to the
    latest commit otherwise. 'after' defaults to the current version. When only one side names a
    file, the other side reuses that path.

    Raises:
        ValueError: when neither argument contains a file path.
    """
    arguments_differ = bool(filename_before and filename_after and filename_before != filename_after)
    before_default = "current" if arguments_differ else "latest"

    version_before, filepath_before = extract_file_version_and_path(filename_before, default_version=before_default)
    version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

    if not filepath_before and not filepath_after:
        raise ValueError("Please supply at least one file name.")

    # at least one side has a path: share it with the side that lacks one
    filepath_before = filepath_before or filepath_after
    filepath_after = filepath_after or filepath_before

    return (version_before, filepath_before), (version_after, filepath_after)
def get_absolute_path_info(filename: Optional[str], version: str, git_root: Optional[Path] = None) -> tuple[bool, str]:
    """
    Locate `filename` on disk.

    The path is tried as given first (relative to the current working directory), then relative to
    `git_root`, which defaults to the git root of the current directory or, failing that, the
    current directory itself.

    Returns:
        (exists, absolute_path). stdin input counts as existing and has no path; a file that cannot
        be found (or no filename at all) yields (False, "").
    """
    if version == "stdin":
        return True, ""
    if filename is None:
        # Not stdin and no file: report it as missing so the caller can show the error later.
        return False, ""

    base = git_root if git_root is not None else (find_git_root() or Path(os.getcwd()))

    for candidate in (Path(filename), base / filename):
        if candidate.exists():
            return True, str(candidate.resolve())
    return False, ""
def ensure_no_migrate_on_real_db(
    code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: typing.Optional[bool] = False
) -> str:
    """
    Guard against executing code that defines its own database connection or uses local imports.

    Args:
        code: Python source to check.
        db_names: variable names `code` must not define (as reported by `find_defined_variables`).
        fix: when truthy, remove the offending definitions / local imports instead of raising.

    Returns:
        `code`, with offending definitions and local imports removed when `fix` is set.

    Raises:
        ValueError: when `fix` is falsy and `code` defines one of `db_names` or uses local imports.
    """
    # Materialize once: the names are both iterated here and handed to remove_specific_variables.
    # A one-shot iterable (e.g. a generator) would otherwise be consumed by the first of the two.
    names = tuple(db_names)
    variables = find_defined_variables(code)
    # dict.fromkeys de-duplicates while keeping a deterministic order for the error message.
    found_variables = list(dict.fromkeys(name for name in names if name in variables))

    if found_variables:
        if fix:
            # remove_specific_variables receives all names at once, so a single call is enough
            # (assumes it removes every definition of the given names).
            code = remove_specific_variables(code, names)
        else:
            if len(found_variables) == 1:
                message = f"Variable {found_variables[0]} defined in code! "
            else:  # pragma: no cover
                var = ", ".join(found_variables)
                message = f"Variables {var} defined in code! "
            raise ValueError(
                f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
            )

    if find_local_imports(code):
        if fix:
            code = remove_local_imports(code)
        else:
            raise ValueError("Local imports are used in this file! Please remove these or use --magic.")

    return code
# Upper bound on how many times handle_cli patches and re-executes the generated code before giving up.
MAX_RETRIES: typing.Final = 20
def handle_cli(
    code_before: str,
    code_after: str,
    db_type: Optional[str] = None,
    tables: Optional[list[str] | list[list[str]]] = None,
    verbose: Optional[bool] = False,
    noop: Optional[bool] = False,
    magic: Optional[bool] = False,
    function_name: Optional[str] = "define_tables",
) -> bool:
    """
    Handle user input.

    Generates a Python script that loads `code_before` and `code_after` into two separate DAL
    instances (created with `DAL(None, migrate=False)`). For each table the script prints the SQL
    from `generate_sql`, or `DROP TABLE` for tables that only exist in the old code.

    Unless `noop` is set, the script is executed. Known failures are patched and retried, at most
    MAX_RETRIES times:
    - ValueError('no-tables-found'): if `function_name` is defined in the code, a call to it is
      added (once) and the script is retried;
    - NameError: with `magic`, generated definitions for the missing names are inserted before the
      user code; without `magic`, the missing names are reported;
    - ImportError: the failing import is removed.

    Args:
        code_before: source defining the old tables.
        code_after: source defining the new tables.
        db_type: SQL dialect passed on to `generate_sql`.
        tables: tables to process (flattened); when empty, all tables found in either version.
        verbose: print the generated script to stderr.
        noop: print the generated script to stderr without executing it.
        magic: automatically fix up the code (strip db definitions/local imports, define missing names).
        function_name: function to call when no tables are defined at the top level.

    Returns:
        True when the script ran successfully (or `noop` is set), False when it could not be run.

    Raises:
        ValueError: from ensure_no_migrate_on_real_db when `magic` is off, or any ValueError other
            than 'no-tables-found' raised by the executed code.
    """
    # todo: prefix (e.g. public.)

    to_execute = string.Template(
        textwrap.dedent(
            """
            from pydal import *
            from pydal.objects import *
            from pydal.validators import *

            from pydal2sql import generate_sql

            db = database = DAL(None, migrate=False)

            tables = $tables
            db_type = '$db_type'

            $extra

            $code_before

            db_old = db
            db_new = db = database = DAL(None, migrate=False)

            $code_after

            if not tables:
                tables = set(db_old._tables + db_new._tables)

            if not tables:
                raise ValueError('no-tables-found')


            for table in tables:
                print('--', table)
                if table in db_old and table in db_new:
                    print(generate_sql(db_old[table], db_new[table], db_type=db_type))
                elif table in db_old:
                    print(f'DROP TABLE {table};')
                else:
                    print(generate_sql(db_new[table], db_type=db_type))
            """
        )
    )

    code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)
    code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

    def render(extra: str = "") -> str:
        """Fill the template with the (sanitized) user code and optional extra code."""
        return to_execute.substitute(
            {
                "tables": flatten(tables or []),
                "db_type": db_type or "",
                "code_before": textwrap.dedent(code_before),
                "code_after": textwrap.dedent(code_after),
                "extra": extra,
            }
        )

    generated_code = render()
    if verbose or noop:
        rich.print(generated_code, file=sys.stderr)

    if noop:
        return True

    err: typing.Optional[Exception] = None
    function_call_added = False
    for _ in range(MAX_RETRIES):
        try:
            # Use a single fresh dict as both globals and locals. Functions defined by the user code
            # (e.g. `define_tables`) then resolve the star-imported pydal names (Field, ...).
            # With a bare exec() inside this function they would look them up in this module's
            # globals instead.
            exec(generated_code, {"__name__": __name__})  # nosec: B102
            return True  # success!
        except ValueError as e:
            err = e

            if str(e) != "no-tables-found":
                raise e

            # The tables may be defined inside a function: add a call to it once and retry.
            # Adding it again would not help and only burned the remaining retries.
            if (
                not function_call_added
                and function_name
                and find_function_to_call(generated_code, function_name) is not None
            ):
                generated_code = add_function_call(generated_code, function_name)
                function_call_added = True
                if verbose:
                    # stderr only: stdout carries the generated SQL.
                    rich.print(generated_code, file=sys.stderr)
                continue

            # else: no define_tables or other method to use found.

            print(f"No tables found in the top-level or {function_name} function!", file=sys.stderr)
            print(
                "Please use `db.define_table` or `database.define_table`, "
                "or if you really need to use an alias like my_db.define_tables, "
                "add `my_db = db` at the top of the file or pass `--db-name mydb`.",
                file=sys.stderr,
            )
            print(f"You can also specify a --function to use something else than {function_name}.", file=sys.stderr)

            return False

        except NameError as e:
            err = e
            # something is missing!
            missing_vars = find_missing_variables(generated_code)
            if not magic:
                rich.print(
                    f"Your code is missing some variables: {missing_vars}. Add these or try --magic",
                    file=sys.stderr,
                )
                return False

            generated_code = render(generate_magic_code(missing_vars))
            if function_call_added:
                # re-rendering starts from the template again; keep the call that was already added.
                generated_code = add_function_call(generated_code, function_name)

            if verbose:
                rich.print(generated_code, file=sys.stderr)
        except ImportError as e:
            err = e
            # if we catch an ImportError, we try to remove the import and retry
            generated_code = remove_import(generated_code, e.name)

    # every retry was used up without producing runnable code
    rich.print(  # pragma: no cover
        f"[red]Code could not be fixed automagically![/red]. Error: {err or '?'}", file=sys.stderr
    )
    return False  # pragma: no cover