Coverage for muutils\validate_type.py: 77%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-15 20:56 -0600

1"""experimental utility for validating types in python, see `validate_type`""" 

2 

3from __future__ import annotations 

4 

5import types 

6import typing 

7import functools 

8 

9# this is also for python <3.10 compatibility 

10_GenericAliasTypeNames: typing.List[str] = [ 

11 "GenericAlias", 

12 "_GenericAlias", 

13 "_UnionGenericAlias", 

14 "_BaseGenericAlias", 

15] 

16 

17_GenericAliasTypesList: list = [ 

18 getattr(typing, name, None) for name in _GenericAliasTypeNames 

19] 

20 

21GenericAliasTypes: tuple = tuple([t for t in _GenericAliasTypesList if t is not None]) 

22 

23 

class IncorrectTypeException(TypeError):
    """Raised by `validate_type` when `do_except=True` and the value does not match the expected type."""

26 

27 

class TypeHintNotImplementedError(NotImplementedError):
    """Raised by `validate_type` when it meets a type hint it does not know how to check."""

30 

31 

class InvalidGenericAliasError(TypeError):
    """Raised by `validate_type` when a generic alias carries the wrong number of type arguments."""

34 

35 

36def _return_validation_except( 

37 return_val: bool, value: typing.Any, expected_type: typing.Any 

38) -> bool: 

39 if return_val: 

40 return True 

41 else: 

42 raise IncorrectTypeException( 

43 f"Expected {expected_type = } for {value = }", 

44 f"{type(value) = }", 

45 f"{type(value).__mro__ = }", 

46 f"{typing.get_origin(expected_type) = }", 

47 f"{typing.get_args(expected_type) = }", 

48 "\ndo --tb=long in pytest to see full trace", 

49 ) 

50 return False 

51 

52 

53def _return_validation_bool(return_val: bool) -> bool: 

54 return return_val 

55 

56 

def validate_type(
    value: typing.Any, expected_type: typing.Any, do_except: bool = False
) -> bool:
    """Validate that a `value` is of the `expected_type`

    # Parameters
     - `value`: the value to check the type of
     - `expected_type`: the type to check against. Not all types are supported
     - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`)
        (default: `False`)

    # Returns
     - `bool`: `True` if the value is of the expected type, `False` otherwise.

    # Raises
     - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True`
     - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented
     - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid

    # Supported hints
     - `typing.Any`
     - plain classes (checked with `isinstance`)
     - `typing.Union[...]` / `typing.Optional[...]` / `X | Y`
     - `list[T]`, `dict[K, V]`, `set[T]` and their `typing` equivalents (items checked recursively)
     - `tuple[A, B, C]` (fixed length) and `tuple[T, ...]` (variable length)
     - `type[T]`: the value must be a class with `T` in its MRO (`type[Any]` accepts any class)

    When `do_except` is `True`, only the outermost check raises: nested checks
    run in boolean mode and their combined result is passed to the outer check.

    use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard
    """
    if expected_type is typing.Any:
        return True

    # set up the return function depending on `do_except`
    _return_func: typing.Callable[[bool], bool] = (
        # functools.partial doesn't hint the function signature
        functools.partial(  # type: ignore[assignment]
            _return_validation_except, value=value, expected_type=expected_type
        )
        if do_except
        else _return_validation_bool
    )

    # base type without args
    if isinstance(expected_type, type):
        try:
            # if you use args on a type like `dict[str, int]`, this will fail
            return _return_func(isinstance(value, expected_type))
        except TypeError as e:
            # `IncorrectTypeException` subclasses `TypeError`, so a genuine mismatch
            # reported by `_return_func` must propagate; any other `TypeError` comes
            # from `isinstance` rejecting the hint, so fall through to the
            # generic-alias handling below.
            if isinstance(e, IncorrectTypeException):
                raise e

    origin: typing.Any = typing.get_origin(expected_type)
    args: tuple = typing.get_args(expected_type)

    # useful for debugging
    # print(f"{value = }, {expected_type = }, {origin = }, {args = }")

    # `types.UnionType` (the class of `X | Y`) only exists on python >=3.10
    UnionType = getattr(types, "UnionType", None)

    if origin is typing.Union or (UnionType is not None and origin is UnionType):
        return _return_func(any(validate_type(value, arg) for arg in args))

    # generic alias, more complicated
    item_type: type
    if isinstance(expected_type, GenericAliasTypes):
        if origin is list:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, list))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is list
            if not isinstance(value, list):
                return _return_func(False)
            # check all items in list are of the correct type; routed through
            # `_return_func` like every other branch so `do_except=True` raises
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))

        if origin is dict:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, dict))
            # incorrect number of args
            if len(args) != 2:
                raise InvalidGenericAliasError(
                    f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is dict
            if not isinstance(value, dict):
                return _return_func(False)
            # check all items in dict are of the correct type
            key_type: type = args[0]
            value_type: type = args[1]
            return _return_func(
                all(
                    validate_type(key, key_type) and validate_type(val, value_type)
                    for key, val in value.items()
                )
            )

        if origin is set:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, set))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is set
            if not isinstance(value, set):
                return _return_func(False)
            # check all items in set are of the correct type
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))

        if origin is tuple:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, tuple))
            # check is tuple
            if not isinstance(value, tuple):
                return _return_func(False)
            # variable-length tuple `tuple[T, ...]`: any length, every item a `T`
            if len(args) == 2 and args[1] is Ellipsis:
                item_type = args[0]
                return _return_func(
                    all(validate_type(item, item_type) for item in value)
                )
            # fixed-length tuple: check correct number of items in tuple
            if len(value) != len(args):
                return _return_func(False)
            # check each item against the type at the same position
            return _return_func(
                all(validate_type(item, arg) for item, arg in zip(value, args))
            )

        if origin is type:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, type))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # the value must itself be a class; non-classes have no `__mro__`
            if not isinstance(value, type):
                return _return_func(False)
            item_type = args[0]
            # `type[Any]` accepts any class
            if item_type is typing.Any:
                return _return_func(True)
            # check `item_type` appears in the value's method resolution order
            return _return_func(item_type in value.__mro__)

        # TODO: Callables, etc.

        raise TypeHintNotImplementedError(
            f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )

    else:
        raise TypeHintNotImplementedError(
            f"Unsupported type hint {expected_type = } for {value = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )