Coverage for src/pydal2sql/helpers.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-21 11:14 +0200

1""" 

2Contains helpers for core. 

3""" 

4 

5import tempfile 

6import types 

7import typing 

8from contextlib import contextmanager 

9 

# Element type of the (possibly nested) iterables handled by `_flatten` / `flatten`.
T = typing.TypeVar("T", bound=typing.Any)
# Recursive type alias: either a bare T, or an iterable whose items are again Recurse[T].
# The inner reference is a string so the alias can refer to itself before it is defined.
Recurse = typing.Union[T, typing.Iterable["Recurse[T]"]]

12 

13 

def _flatten(xs: typing.Iterable[T | Recurse[T]]) -> typing.Generator[T, None, None]:
    """
    Flatten recursively, yielding the leaves of `xs` depth-first, left to right.

    str, bytes and bytearray values are treated as leaves (not iterated into), so text
    and binary blobs stay intact instead of being split into characters / ints.
    """
    for x in xs:
        # bytearray is the mutable sibling of bytes: without it, bytearray(b"ab") would be
        # exploded into [97, 98] while b"ab" is kept whole.
        if isinstance(x, typing.Iterable) and not isinstance(x, (str, bytes, bytearray)):
            yield from _flatten(x)
        else:
            yield typing.cast(T, x)

23 

24 

def flatten(it: Recurse[T]) -> list[T]:
    """
    Collapse an arbitrarily nested iterable into a single flat (1d) list.

    The recursive walk is delegated to `_flatten`; this wrapper only materializes it.

    Example:
        [[[1]], 2] -> [1, 2]
    """
    return [*_flatten(it)]

34 

35 

36ANY_TYPE = type | types.UnionType | typing.Type[typing.Any] | typing._SpecialForm 

37 

38 

def _evaluate_forward_ref(ref: typing.ForwardRef) -> typing.Any:
    """
    Resolve a string forward reference (e.g. ForwardRef('str')) to the object it names.

    Names are looked up in this module's globals, which fall back to builtins.
    """
    # ForwardRef._evaluate is private and its signature changed in Python 3.12.4
    # (recursive_guard became keyword-only), so calling it positionally raises TypeError.
    # typing.get_type_hints is the public API that evaluates forward references, so resolve
    # `ref` as the return annotation of a throwaway function.
    def _holder() -> None:  # pragma: no cover - only used as an annotation carrier
        ...

    _holder.__annotations__ = {"return": ref}
    # include_extras=True keeps Annotated[...] intact, like a direct evaluation would.
    hints = typing.get_type_hints(_holder, globalns=globals(), include_extras=True)
    return hints["return"]


def _get_typing_args_recursive(some: ANY_TYPE) -> list[typing.Any]:
    """
    Recursively extract types from parameterized types such as unions or generics.

    Note:
        The return type is actually a nested list of types and values (e.g. Literal strings)!
        Please use `get_typing_args`, which calls flatten to create a 1D list.

        _get_typing_args_recursive(
            typing.Union["str", typing.Literal["Joe"]]
        )

        -> [[<class 'str'>], [['Joe']]]
    """
    if args := typing.get_args(some):
        # Parameterized (Union, Literal, list[...], ...): recurse into every argument.
        return [_get_typing_args_recursive(arg) for arg in args]

    # else: no args -> it's just a type (or a Literal value)!
    if isinstance(some, typing.ForwardRef):
        # ForwardRef('str') -> str
        return [_evaluate_forward_ref(some)]

    return [typing.cast(type, some)]

64 

65 

def get_typing_args(some: ANY_TYPE) -> list[type | str]:
    """
    Extract typing.get_args for Unions, Literals etc. as one flat list.

    Nested parameterizations are walked recursively by `_get_typing_args_recursive`
    and the resulting nested lists are flattened.
    Useful for e.g. getting the values of Literals.
    """
    nested = _get_typing_args_recursive(some)
    return flatten(nested)

75 

76 

77@contextmanager 

78def TempdirOrExistingDir(folder_path: str = None) -> typing.Generator[str, None, None]: 

79 """ 

80 Either use db_folder or create a tempdir. 

81 

82 The tempdir will be removed on exit, your original folder_path will not be modified in that regard. 

83 

84 Example: 

85 with TempdirOrExistingDir() as my_path: ... 

86 """ 

87 if folder_path is None: 

88 tmp_dir = tempfile.TemporaryDirectory() 

89 yield tmp_dir.name 

90 else: 

91 yield folder_path