Coverage for src/pydal2sql/helpers.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-20 15:22 +0200

1""" 

2Contains helpers for core. 

3""" 

4 

5import tempfile 

6import types 

7import typing 

8from contextlib import contextmanager 

9 

10# todo: add infinite recurse? 

11T = typing.TypeVar("T") 

12Recurse = typing.Iterable[typing.Iterable[typing.Iterable[T] | T] | T] 

13 

14 

def _flatten(xs: Recurse[T]) -> typing.Generator[T, None, None]:
    """
    Lazily yield the leaves of a nested iterable, depth-first and in order.

    `str` and `bytes` are iterable but are treated as leaves; recursing into
    them would never terminate, because a one-character string iterates to itself.
    """
    atomic = (str, bytes)
    for element in xs:
        is_nested = isinstance(element, typing.Iterable) and not isinstance(element, atomic)
        if not is_nested:
            yield typing.cast(T, element)
            continue
        yield from _flatten(element)

24 

25 

def flatten(it: Recurse[T]) -> list[T]:
    """
    Collect every leaf of an arbitrarily nested iterable into one flat (1d) list.

    Strings and bytes count as single leaves, not as sequences of characters.

    Example:
        [[[1]], 2] -> [1, 2]
    """
    return [*_flatten(it)]

35 

36 

37ANY_TYPE = type | types.UnionType | typing.Type[typing.Any] | typing._SpecialForm 

38 

39 

def _evaluate_forward_ref(ref: typing.ForwardRef) -> typing.Any:
    """
    Resolve a string forward reference against this module's global namespace.

    Private helper that papers over Python-version differences in how a
    `typing.ForwardRef` can be evaluated.
    """
    namespace = globals()
    public_evaluate = getattr(typing, "evaluate_forward_ref", None)
    if public_evaluate is not None:
        # Python 3.14+: public API replacing the private ForwardRef._evaluate.
        return public_evaluate(ref, globals=namespace)
    try:
        # Newer Pythons take `type_params` (and warn when it is omitted);
        # `recursive_guard` is keyword-only there, so it must not be positional.
        return ref._evaluate(namespace, None, type_params=(), recursive_guard=frozenset())
    except TypeError:
        # Older Pythons (3.10 / 3.11) have no `type_params` parameter.
        return ref._evaluate(namespace, None, recursive_guard=frozenset())


def _get_typing_args_recursive(some: ANY_TYPE) -> list[typing.Any]:
    """
    Recursively extract types from parameterized types such as unions or generics.

    Note:
        The return value is actually a *nested* list of types and literal values (e.g. strings)!
        Please use `get_typing_args`, which calls `flatten` to create a 1D list.

    String forward references (`typing.ForwardRef`) are evaluated against this
    module's global namespace. They can therefore only name builtins or names
    visible in this module (e.g. "str"); other names raise NameError.

    Example:
        _get_typing_args_recursive(
            typing.Union["str", typing.Literal["Joe"]]
        )

        -> [[<class 'str'>], [['Joe']]]
    """
    if args := typing.get_args(some):
        return [_get_typing_args_recursive(arg) for arg in args]

    # No args -> `some` is a leaf: a plain type, a literal value or a forward reference.
    if isinstance(some, typing.ForwardRef):
        # ForwardRef('str') -> str
        return [_evaluate_forward_ref(some)]

    return [typing.cast(type, some)]

65 

66 

def get_typing_args(some: ANY_TYPE) -> list[type | str]:
    """
    Return the flattened `typing.get_args` of a parameterized type.

    Useful for unions, literals and the like, e.g. for reading the allowed
    values of a `Literal`. Nested parameterizations are expanded recursively
    and the result is flattened into a single 1D list.

    Example:
        get_typing_args(typing.Union["str", typing.Literal["Joe"]]) -> [str, 'Joe']
    """
    nested = _get_typing_args_recursive(some)
    return flatten(nested)

76 

77 

78@contextmanager 

79def TempdirOrExistingDir(folder_path: str = None) -> typing.Generator[str, None, None]: 

80 """ 

81 Either use db_folder or create a tempdir. 

82 

83 The tempdir will be removed on exit, your original folder_path will not be modified in that regard. 

84 

85 Example: 

86 with TempdirOrExistingDir() as my_path: ... 

87 """ 

88 if folder_path is None: 

89 tmp_dir = tempfile.TemporaryDirectory() 

90 yield tmp_dir.name 

91 else: 

92 yield folder_path