Coverage for src/pydal2sql_core/helpers.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-08-05 17:25 +0200

1""" 

2Contains helpers for core. 

3""" 

4 

5import tempfile 

6import types 

7import typing 

8from contextlib import contextmanager 

9from pathlib import Path 

10 

11T = typing.TypeVar("T", bound=typing.Any) 

12Recurse = typing.Union[T, typing.Iterable["Recurse[T]"]] 

13 

14 

def _flatten(xs: typing.Iterable[T | Recurse[T]]) -> typing.Generator[T, None, None]:
    """
    Flatten recursively.

    Every nested iterable is descended into, except str and bytes, which are yielded whole.
    Items are produced lazily and in depth-first order.

    An explicit stack of iterators is used instead of recursion, so arbitrarily deep nesting
    does not run into Python's recursion limit.
    """
    stack: list[typing.Iterator[typing.Any]] = [iter(xs)]
    while stack:
        # Resume the innermost iterator where it left off.
        for x in stack[-1]:
            if isinstance(x, typing.Iterable) and not isinstance(x, (str, bytes)):
                # Descend: process the nested iterable before continuing with the current one.
                stack.append(iter(x))
                break
            yield typing.cast(T, x)
        else:
            # Innermost iterator is exhausted -> return to its parent.
            stack.pop()

24 

25 

def flatten(it: Recurse[T]) -> list[T]:
    """
    Collapse an arbitrarily nested iterable into a single flat (1d) list.

    Nested str/bytes elements are kept whole rather than split into characters.

    Example:
        [[[1]], 2] -> [1, 2]
    """
    return [*_flatten(it)]

35 

36 

37ANY_TYPE = type | types.UnionType | typing.Type[typing.Any] | typing._SpecialForm 

38 

39 

def _evaluate_forward_ref(ref: typing.ForwardRef) -> typing.Any:
    """
    Resolve a string forward reference (e.g. ForwardRef('str') -> str) against this module's globals.

    `ForwardRef._evaluate` is private and its signature changed in Python 3.12.4 / 3.13
    (a `type_params` parameter was added and `recursive_guard` became keyword-only), so the old
    three-positional-argument call raises TypeError there. Python 3.14+ provides the public
    `typing.evaluate_forward_ref`, which is preferred when available.
    """
    import sys

    namespace = globals()

    public_evaluate = getattr(typing, "evaluate_forward_ref", None)
    if public_evaluate is not None:  # Python 3.14+
        return public_evaluate(ref, globals=namespace)

    if sys.version_info >= (3, 12, 4):
        return ref._evaluate(namespace, None, type_params=(), recursive_guard=frozenset())

    # Python < 3.12.4: signature is (globalns, localns, recursive_guard).
    return ref._evaluate(namespace, None, recursive_guard=frozenset())


def _get_typing_args_recursive(some: ANY_TYPE) -> list[typing.Any]:
    """
    Recursively extract types from parameterized types such as unions or generics.

    Note:
        The return value is a nested list of types and (Literal) values, not a flat list of types!
        Please use `get_typing_args`, which calls flatten to create a 1D list of types and values.

        _get_typing_args_recursive(
            typing.Union["str", typing.Literal["Joe"]]
        )

        -> [[<class 'str'>], [['Joe']]]

    String forward references (like "str" above) are resolved against this module's globals,
    so only builtins and names visible in this module can be resolved.
    """
    if args := typing.get_args(some):
        return [_get_typing_args_recursive(arg) for arg in args]

    # else: no args -> it's just a type (or a Literal value)!
    if isinstance(some, typing.ForwardRef):
        # ForwardRef('str') -> str
        return [_evaluate_forward_ref(some)]

    return [typing.cast(type, some)]

65 

66 

def uniq(some_list: list[T]) -> list[T]:
    """
    Return a copy of some_list without duplicates.

    The first occurrence of each item is kept, in its original position. Items must be hashable.
    """
    seen: set[typing.Any] = set()
    result: list[T] = []
    for item in some_list:
        if item in seen:
            continue
        seen.add(item)
        result.append(item)
    return result

72 

73 

def excl(some_list: list[T], without: typing.Iterable[T] | typing.Any) -> list[T]:
    """
    Remove 'without' from 'some_list'.

    Without can be an iterable of items to remove or a single value.
    A str/bytes 'without' counts as a single value (not as a collection of characters),
    so `excl(["a", "ab"], "ab")` only removes "ab".

    Returns a new list; some_list itself is not modified.
    """
    if isinstance(without, typing.Iterable) and not isinstance(without, (str, bytes)):
        # Materialize once: a one-shot iterator (e.g. a generator) would otherwise be
        # consumed by the first membership tests and stop matching later items.
        to_remove = tuple(without)
        return [item for item in some_list if item not in to_remove]

    return [item for item in some_list if item != without]

84 

85 

def get_typing_args(some: ANY_TYPE) -> list[type | str]:
    """
    Flatten the (possibly nested) typing arguments of `some` into a single list.

    Handy for e.g. reading the allowed values of a Literal, or the members of a Union.
    """
    nested = _get_typing_args_recursive(some)
    return flatten(nested)

95 

96 

97@contextmanager 

98def TempdirOrExistingDir(folder_path: typing.Optional[str | Path] = None) -> typing.Generator[str, None, None]: 

99 """ 

100 Either use db_folder or create a tempdir. 

101 

102 The tempdir will be removed on exit, your original folder_path will not be modified in that regard. 

103 

104 Example: 

105 with TempdirOrExistingDir() as my_path: ... 

106 """ 

107 if folder_path is None: 

108 tmp_dir = tempfile.TemporaryDirectory() 

109 yield tmp_dir.name 

110 elif isinstance(folder_path, Path): 

111 yield str(folder_path) 

112 else: 

113 yield folder_path