Coverage for src/pydal2sql_core/helpers.py: 100%

42 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-04-22 11:38 +0200

1""" 

2Contains helpers for core. 

3""" 

4 

5import tempfile 

6import types 

7import typing 

8from contextlib import contextmanager 

9from pathlib import Path 

10 

11import witchery 

12from typedal.core import evaluate_forward_reference 

13 

14T = typing.TypeVar("T", bound=typing.Any) 

15Recurse = typing.Union[T, typing.Iterable["Recurse[T]"]] 

16 

17 

def _flatten(xs: typing.Iterable[T | Recurse[T]]) -> typing.Generator[T, None, None]:
    """
    Lazily yield the leaves of a nested iterable, depth-first and left to right.

    str and bytes objects are treated as single leaves, not iterated per character.
    """
    for item in xs:
        is_leaf = isinstance(item, (str, bytes)) or not isinstance(item, typing.Iterable)
        if is_leaf:
            yield typing.cast(T, item)
            continue
        # Nested iterable: descend into it.
        yield from _flatten(item)

27 

28 

def flatten(it: Recurse[T]) -> list[T]:
    """
    Turn an arbitrarily nested iterable into a flat (1d) list.

    str and bytes are treated as single values, not as iterables of characters.
    This applies at every level, including the top one. A bare value, which the
    ``Recurse[T]`` annotation allows, therefore produces a one-element list.

    Example:
        [[[1]], 2] -> [1, 2]
        "abc" -> ["abc"]
        5 -> [5]
    """
    # Wrap the input so the top level gets the same leaf-vs-iterable check as nested levels.
    # Without the wrap, a bare non-iterable raises TypeError and a bare str is split per character.
    return list(_flatten([it]))

38 

39 

40ANY_TYPE = type | types.UnionType | typing.Type[typing.Any] | typing._SpecialForm 

41 

42 

def _get_typing_args_recursive(some: ANY_TYPE) -> list[type]:
    """
    Recursively unpack the arguments of parameterized types such as unions or generics.

    Note:
        Despite the annotation, the result is a *nested* list. It can contain both types
        and literal values such as strings. Prefer `get_typing_args`, which flattens the
        result into a 1D list.

        _get_typing_args_recursive(
            typing.Union["str", typing.Literal["Joe"]]
        )

        -> [[<class 'str'>], [['Joe']]]
    """
    args = typing.get_args(some)
    if not args:
        # Leaf: there is nothing left to unpack.
        if isinstance(some, typing.ForwardRef):
            # A string annotation such as "str" arrives as ForwardRef('str'); resolve it via typedal.
            return [evaluate_forward_reference(some)]
        return [typing.cast(type, some)]

    return [_get_typing_args_recursive(arg) for arg in args]  # type: ignore # due to recursion

68 

69 

def uniq(some_list: list[T]) -> list[T]:
    """
    Return a new list without duplicates, keeping each item's first occurrence and order.

    Items must be hashable, because they are used as dict keys.
    """
    first_seen: dict[T, None] = {}
    for item in some_list:
        # setdefault keeps the first inserted key and its position.
        first_seen.setdefault(item, None)
    return list(first_seen)

75 

76 

def excl(some_list: list[T], without: typing.Iterable[T] | typing.Any) -> list[T]:
    """
    Remove 'without' from 'some_list', preserving the order of the remaining items.

    Without can be an iterable of items to remove or a single value.
    str and bytes always count as a single value. For example,
    ``excl(["id", "i", "name"], "id")`` removes only ``"id"``.

    Args:
        some_list: the items to filter; it is not modified.
        without: a single value to drop, or an iterable of values to drop.

    Returns:
        A new list with every matching item removed.
    """
    if isinstance(without, (str, bytes)) or not isinstance(without, typing.Iterable):
        # Single value: drop items equal to it.
        return [item for item in some_list if item != without]

    if not isinstance(without, typing.Container):
        # One-shot iterables (generators, iterators) have no __contains__.
        # Each `in` test would iterate and consume them, so items checked later
        # would never match. Materialize them once.
        without = list(without)
    return [item for item in some_list if item not in without]

87 

88 

def get_typing_args(some: ANY_TYPE) -> list[type | str]:
    """
    Extract the recursively flattened `typing.get_args` of Unions, Literals, etc.

    Useful for e.g. getting the values of a Literal.
    """
    nested = _get_typing_args_recursive(some)
    return flatten(nested)

98 

99 

100@contextmanager 

101def TempdirOrExistingDir(folder_path: typing.Optional[str | Path] = None) -> typing.Generator[str, None, None]: 

102 """ 

103 Either use db_folder or create a tempdir. 

104 

105 The tempdir will be removed on exit, your original folder_path will not be modified in that regard. 

106 

107 Example: 

108 with TempdirOrExistingDir() as my_path: ... 

109 """ 

110 if folder_path is None: 

111 tmp_dir = tempfile.TemporaryDirectory() 

112 yield tmp_dir.name 

113 elif isinstance(folder_path, Path): 

114 yield str(folder_path) 

115 else: 

116 yield folder_path 

117 

118 

def detect_typedal(code: str):
    """
    Report whether the given source code imports TypeDAL anywhere.

    The check succeeds when "typedal" is among the module names that
    `witchery.find_imported_modules` reports for `code`.
    """
    imported_modules = witchery.find_imported_modules(code)
    return "typedal" in imported_modules