Coverage for src/pydal2sql/typer_support.py: 100%
153 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-31 16:22 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-31 16:22 +0200
1"""
2Cli-specific support
3"""
4import contextlib
5import functools
6import inspect
7import operator
8import os
9import sys
10import typing
11from dataclasses import dataclass
12from enum import Enum, EnumMeta
13from pathlib import Path
14from typing import Any, Optional
16import configuraptor
17import dotenv
18import rich
19import tomli
20import typer
21from black.files import find_project_root
22from configuraptor.helpers import find_pyproject_toml
23from su6.core import (
24 EXIT_CODE_ERROR,
25 EXIT_CODE_SUCCESS,
26 T_Command,
27 T_Inner_Wrapper,
28 T_Outer_Wrapper,
29)
30from typing_extensions import Never
32from .types import SUPPORTED_DATABASE_TYPES_WITH_ALIASES
34T_Literal = typing._SpecialForm
36LiteralType = typing.TypeVar("LiteralType", str, typing.Union[str, str] | T_Literal)
39class ReprEnumMeta(EnumMeta):
40 def __repr__(cls) -> str: # sourcery skip
41 options = typing.cast(typing.Iterable[Enum], cls.__members__.values()) # for mypy
42 members_repr = ", ".join(f"{m.value!r}" for m in options)
43 return f"{cls.__name__}({members_repr})"
46class DynamicEnum(Enum, metaclass=ReprEnumMeta):
47 ...
50def create_enum_from_literal(name: str, literal_type: LiteralType) -> typing.Type[DynamicEnum]:
51 literals: list[str] = []
53 if hasattr(literal_type, "__args__"):
54 for arg in typing.get_args(literal_type):
55 if hasattr(arg, "__args__"):
56 # e.g. literal_type = typing.Union[typing.Literal['one', 'two']]
57 literals.extend(typing.get_args(arg))
58 else:
59 # e.g. literal_type = typing.Literal['one', 'two']
60 literals.append(arg)
61 else:
62 # e.g. literal_type = 'one'
63 literals.append(str(literal_type))
65 literals.sort()
67 enum_dict = {}
69 for literal in literals:
70 enum_name = literal.replace(" ", "_").upper()
71 enum_value = literal
72 enum_dict[enum_name] = enum_value
74 return DynamicEnum(name, enum_dict) # type: ignore
77class Verbosity(Enum):
78 """
79 Verbosity is used with the --verbose argument of the cli commands.
80 """
82 # typer enum can only be string
83 quiet = "1"
84 normal = "2"
85 verbose = "3"
86 debug = "4" # only for internal use
88 @staticmethod
89 def _compare(
90 self: "Verbosity",
91 other: "Verbosity_Comparable",
92 _operator: typing.Callable[["Verbosity_Comparable", "Verbosity_Comparable"], bool],
93 ) -> bool:
94 """
95 Abstraction using 'operator' to have shared functionality between <, <=, ==, >=, >.
97 This enum can be compared with integers, strings and other Verbosity instances.
99 Args:
100 self: the first Verbosity
101 other: the second Verbosity (or other thing to compare)
102 _operator: a callable operator (from 'operators') that takes two of the same types as input.
103 """
104 match other:
105 case Verbosity():
106 return _operator(self.value, other.value)
107 case int():
108 return _operator(int(self.value), other)
109 case str():
110 return _operator(int(self.value), int(other))
112 def __gt__(self, other: "Verbosity_Comparable") -> bool:
113 """
114 Magic method for self > other.
115 """
116 return self._compare(self, other, operator.gt)
118 def __ge__(self, other: "Verbosity_Comparable") -> bool:
119 """
120 Method magic for self >= other.
121 """
122 return self._compare(self, other, operator.ge)
124 def __lt__(self, other: "Verbosity_Comparable") -> bool:
125 """
126 Magic method for self < other.
127 """
128 return self._compare(self, other, operator.lt)
130 def __le__(self, other: "Verbosity_Comparable") -> bool:
131 """
132 Magic method for self <= other.
133 """
134 return self._compare(self, other, operator.le)
136 def __eq__(self, other: typing.Union["Verbosity", str, int, object]) -> bool:
137 """
138 Magic method for self == other.
140 'eq' is a special case because 'other' MUST be object according to mypy
141 """
142 if other is Ellipsis or other is inspect._empty:
143 # both instances of object; can't use Ellipsis or type(ELlipsis) = ellipsis as a type hint in mypy
144 # special cases where Typer instanciates its cli arguments,
145 # return False or it will crash
146 return False
147 if not isinstance(other, (str, int, Verbosity)):
148 raise TypeError(f"Object of type {type(other)} can not be compared with Verbosity")
149 return self._compare(self, other, operator.eq)
151 def __hash__(self) -> int:
152 """
153 Magic method for `hash(self)`, also required for Typer to work.
154 """
155 return hash(self.value)
158Verbosity_Comparable = Verbosity | str | int
160DEFAULT_VERBOSITY = Verbosity.normal
class AbstractConfig(configuraptor.TypedConfig, configuraptor.Singleton):
    """
    Common base class for state.config and for plugin configs.

    Inherits from configuraptor.TypedConfig and configuraptor.Singleton.
    """

    # NOTE(review): presumably makes configuraptor enforce the annotated field types strictly — confirm
    _strict = True


# Enum built from SUPPORTED_DATABASE_TYPES_WITH_ALIASES: member names are the upper-cased values with
# spaces replaced by underscores. Typed as Any — presumably so it can be used as a (Typer) type without mypy noise.
DB_Types: Any = create_enum_from_literal(
    name="DBType",
    literal_type=SUPPORTED_DATABASE_TYPES_WITH_ALIASES,
)
@dataclass
class Config(AbstractConfig):
    """
    Typed version of the [tool.pydal2sql] table of pyproject.toml.

    Also reachable at runtime through state.config.
    """

    # settings go here
    db_type: Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None
    magic: bool = False
    noop: bool = False
    tables: Optional[list[str]] = None
    # path of the toml file the settings were read from (filled in by _get_pydal2sql_config)
    pyproject: Optional[str] = None
    function: str = "define_tables"


# None is used when no config could be loaded (e.g. no pyproject.toml found).
MaybeConfig = Optional[Config]
def _get_pydal2sql_config(overwrites: dict[str, Any], toml_path: Optional[str] = None) -> MaybeConfig:
    """
    Parse the user's pyproject.toml and extract the [tool.pydal2sql] part into a Config.

    The toml values are loaded with configuraptor.load_into into the typed Config class
    (NOTE(review): type checking of the values is presumably done by configuraptor — confirm).

    Args:
        overwrites: cli arguments can overwrite the config toml.
        toml_path: by default, configuraptor's find_pyproject_toml searches for a relevant pyproject.toml.
            If a toml_path is provided, that file will be used instead.

    Returns:
        The loaded Config (with 'pyproject' set to the toml path and the overwrites applied),
        or None when no pyproject.toml could be found.

    Raises:
        KeyError: when the toml has no [tool] table.
        Any error raised by tomli or configuraptor while reading / loading the file.
    """
    if toml_path is None:
        toml_path = find_pyproject_toml()

    if not toml_path:
        return None

    with open(toml_path, "rb") as f:
        full_config = tomli.load(f)

    tool_config = full_config["tool"]

    config = configuraptor.load_into(Config, tool_config, key="pydal2sql")

    # remember where the settings came from, then let cli arguments win over the toml values
    config.update(pyproject=toml_path)
    config.update(**overwrites)

    return config
def get_pydal2sql_config(
    toml_path: Optional[str] = None, verbosity: Verbosity = DEFAULT_VERBOSITY, **overwrites: Any
) -> Config:
    """
    Load the relevant pyproject.toml config settings.

    Args:
        toml_path: --config can be used to use a different file than ./pyproject.toml
        verbosity: if something goes wrong, level 3+ will show a warning and 4+ will raise the exception.
        overwrites (dict[str, Any]): cli arguments can overwrite the config toml.
            If a value is None, the key is not overwritten.

    Returns:
        The loaded Config, or a Config built from only the overwrites when loading failed
        (or no pyproject.toml was found) and verbosity is below debug.
    """
    # strip out any 'overwrites' with None as value
    overwrites = configuraptor.convert_config(overwrites)

    try:
        if config := _get_pydal2sql_config(overwrites, toml_path=toml_path):
            return config
        # no toml found: route through the same fallback/reporting path as a parse error
        raise ValueError("Falsey config?")
    except Exception:
        # something went wrong parsing config, use defaults
        if verbosity > 3:
            # verbosity = debug
            raise
        elif verbosity > 2:
            # verbosity = verbose
            print("Error parsing pyproject.toml, falling back to defaults.", file=sys.stderr)
        return Config(**overwrites)
@dataclass()
class ApplicationState:
    """
    Application State - global user defined variables.

    State contains generic variables passed BEFORE the subcommand (so --verbosity, --config, ...),
    whereas Config contains settings from the config toml file, updated with arguments AFTER the subcommand
    (e.g. pydal2sql subcommand <directory> --flag), directory and flag will be updated in the config and not the state.

    To summarize: 'state' is applicable to all commands and config only to specific ones.
    """

    verbosity: Verbosity = DEFAULT_VERBOSITY
    config_file: Optional[str] = None  # None: a pyproject.toml is searched for automatically
    config: MaybeConfig = None

    def __post_init__(self) -> None:
        ...

    def load_config(self, **overwrites: Any) -> Config:
        """
        Load the pydal2sql config from pyproject.toml (or other config_file) with optional overwriting settings.

        'verbosity' and 'config_file' in overwrites update the state itself and are not passed on
        as config settings. The state's verbosity decides how config loading errors are reported.
        """
        if "verbosity" in overwrites:
            self.verbosity = overwrites.pop("verbosity")
        if "config_file" in overwrites:
            self.config_file = overwrites.pop("config_file")

        # always forward the state's verbosity; otherwise a verbosity set earlier would be
        # replaced by the default whenever load_config is called without an explicit one
        self.config = get_pydal2sql_config(toml_path=self.config_file, verbosity=self.verbosity, **overwrites)
        return self.config

    def get_config(self) -> Config:
        """
        Get a filled config instance, loading it on first use.
        """
        return self.config or self.load_config()

    def update_config(self, **values: Any) -> Config:
        """
        Overwrite default/toml settings with cli values.

        Example:
            `config = state.update_config(magic=True)`
            This will update the state's config and return the same object with the updated settings.
        """
        existing_config = self.get_config()

        values = configuraptor.convert_config(values)
        existing_config.update(**values)
        return existing_config
def with_exit_code(hide_tb: bool = True) -> T_Outer_Wrapper:
    """
    Convert the return value of an app.command (bool or int) to a typer Exit with that return code.

    - True, None or 0 -> exit code 0 / EXIT_CODE_SUCCESS
    - False -> EXIT_CODE_ERROR
    - any other int -> used as the exit code directly
    If the command raises and hide_tb is True, the error message is printed to stderr and
    EXIT_CODE_ERROR is used; with hide_tb=False the original exception propagates.

    The wrapped command always ends by raising typer.Exit (there is no option to suppress it).

    Usage:
        > @app.command()
        > @with_exit_code()
        def some_command(): ...

    See also:
        github.com:trialandsuccess/su6-checker
    """

    def outer_wrapper(func: T_Command) -> T_Inner_Wrapper:
        @functools.wraps(func)
        def inner_wrapper(*args: Any, **kwargs: Any) -> Never:
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                if not hide_tb:  # pragma: no cover
                    raise
                rich.print(f"[red]{e}[/red]", file=sys.stderr)
                result = EXIT_CODE_ERROR

            if isinstance(result, bool):
                # handled before the int conversion: bool is an int subclass and False must become an error code
                result = EXIT_CODE_SUCCESS if result else EXIT_CODE_ERROR

            raise typer.Exit(code=int(result or 0))

        return inner_wrapper

    return outer_wrapper
def _is_debug() -> bool:  # pragma: no cover
    """
    Load the project's .env file into the environment and check whether IS_DEBUG is "1".

    The project root comes from black's find_project_root (first element of its returned tuple);
    when that is falsy, the current working directory is used instead.
    Side effect: dotenv.load_dotenv updates os.environ.
    """
    cwd = os.getcwd()
    project_root, _reason = find_project_root((cwd,))
    base_dir = project_root if project_root else Path(cwd)

    dotenv.load_dotenv(base_dir / ".env")

    debug_flag = os.getenv("IS_DEBUG")
    return debug_flag == "1"
def is_debug() -> bool:  # pragma: no cover
    """
    Best-effort wrapper around _is_debug: any Exception while detecting debug mode means "not debug".
    """
    debug = False
    with contextlib.suppress(Exception):
        debug = _is_debug()
    return debug


# Evaluated once at import time.
IS_DEBUG = is_debug()