Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2python version compatibility code 

3""" 

4import enum 

5import functools 

6import inspect 

7import os 

8import re 

9import sys 

10from contextlib import contextmanager 

11from inspect import Parameter 

12from inspect import signature 

13from typing import Any 

14from typing import Callable 

15from typing import Generic 

16from typing import Optional 

17from typing import overload 

18from typing import Tuple 

19from typing import TypeVar 

20from typing import Union 

21 

22import attr 

23import py 

24 

25from _pytest._io.saferepr import saferepr 

26from _pytest.outcomes import fail 

27from _pytest.outcomes import TEST_OUTCOME 

28 

if sys.version_info >= (3, 5, 2):
    from typing import TYPE_CHECKING
else:
    # Older interpreters get a plain constant instead of typing.TYPE_CHECKING.
    TYPE_CHECKING = False  # type: bool

33 

34 

35if TYPE_CHECKING: 

36 from typing import NoReturn 

37 from typing import Type 

38 from typing_extensions import Final 

39 

40 

# Generic type variables shared by the helpers below (e.g. cached_property).
_T, _S = TypeVar("_T"), TypeVar("_S")

43 

44 

# Singleton type for NOTSET, as described in:
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
# A one-member Enum (instead of a bare object()) lets type checkers narrow
# ``Union[X, NotSetType]`` through an ``is NOTSET`` comparison.
class NotSetType(enum.Enum):
    token = 0


NOTSET = NotSetType.token  # type: Final

52 

# Name of the exception raised for a missing module on this interpreter.
if sys.version_info[:2] >= (3, 6):
    MODULE_NOT_FOUND_ERROR = "ModuleNotFoundError"
else:
    MODULE_NOT_FOUND_ERROR = "ImportError"

56 

57 

58if sys.version_info >= (3, 8): 

59 from importlib import metadata as importlib_metadata 

60else: 

61 import importlib_metadata # noqa: F401 

62 

63 

def _format_args(func: Callable[..., Any]) -> str:
    """Render the call signature of ``func`` as text, e.g. ``"(a, b=1)"``."""
    sig = signature(func)
    return str(sig)

66 

67 

# The type of re.compile objects is not exposed in Python, so derive it from
# a compiled throwaway pattern.
REGEX_TYPE = re.compile("").__class__

70 

71 

if sys.version_info >= (3, 6):
    fspath = os.fspath
else:

    def fspath(p):
        """Stand-in for os.fspath on Python 3.5: return ``str(p)``.

        Kept as a named helper so call sites are easy to switch to the real
        function once py35 support is dropped.
        """
        return str(p)

83 

84 

def is_generator(func: object) -> bool:
    """Return True if ``func`` is a generator function that is not also
    reported as a coroutine function by ``iscoroutinefunction``."""
    if not inspect.isgeneratorfunction(func):
        return False
    # Generator-based coroutines are excluded here.
    return not iscoroutinefunction(func)

89 

def iscoroutinefunction(func: object) -> bool:
    """Return True if func is a coroutine function.

    That is either a function defined with ``async def`` syntax (and not
    containing ``yield``), or a function decorated with ``@asyncio.coroutine``,
    recognised here by a truthy ``_is_coroutine`` attribute.

    Note: copied and modified from Python 3.5's builtin coroutines.py to avoid
    importing asyncio directly, which in turn also initializes the "logging"
    module as a side-effect (see issue #8).
    """
    # ``_is_coroutine`` may hold an arbitrary truthy marker object rather than
    # a bool; coerce it so the function really returns a bool as annotated.
    return inspect.iscoroutinefunction(func) or bool(
        getattr(func, "_is_coroutine", False)
    )

101 

102 

def is_async_function(func: object) -> bool:
    """Tell whether ``func`` seems to be an async function or an async generator.

    Coroutine functions are detected via ``iscoroutinefunction``; async
    generator functions via ``inspect.isasyncgenfunction`` (Python 3.6+ only).
    """
    coroutine_check = iscoroutinefunction(func)
    if coroutine_check:
        # Hand back the helper's own truthy result unchanged.
        return coroutine_check
    return sys.version_info >= (3, 6) and inspect.isasyncgenfunction(func)

108 

109 

def getlocation(function, curdir=None) -> str:
    """Return ``"<path>:<line>"`` describing where ``function`` is defined.

    ``function`` is unwrapped with ``get_real_func`` first. When ``curdir`` is
    given and ``path.relto(curdir)`` yields a non-empty string, that relative
    path is shown; otherwise the full file path is used.
    """
    real_function = get_real_func(function)
    path = py.path.local(inspect.getfile(real_function))
    first_line = real_function.__code__.co_firstlineno
    shown = path
    if curdir is not None:
        relative = path.relto(curdir)
        if relative:
            shown = relative
    # NOTE(review): co_firstlineno is already 1-based; the +1 presumably steps
    # past a single decorator line (e.g. @pytest.fixture) -- confirm.
    return "%s:%d" % (shown, first_line + 1)

119 

120 

def num_mock_patch_args(function) -> int:
    """Return the number of arguments of ``function`` used up by mock arguments.

    Counts the entries of ``function.patchings`` that have no
    ``attribute_name`` and whose ``new`` is the ``DEFAULT`` sentinel of the
    ``mock`` or ``unittest.mock`` module -- presumably the patches that pass
    a freshly created mock in as an extra argument.
    """
    patchings = getattr(function, "patchings", None)
    if not patchings:
        return 0

    # Only consult modules that are already imported; a fresh object() stands
    # in for a missing module's sentinel and therefore never matches.
    sentinels = [
        getattr(sys.modules.get(module_name), "DEFAULT", object())
        for module_name in ("mock", "unittest.mock")
    ]
    return sum(
        1
        for patching in patchings
        if not patching.attribute_name
        and any(patching.new is sentinel for sentinel in sentinels)
    )

138 

139 

def getfuncargnames(
    function: Callable[..., Any],
    *,
    name: str = "",
    is_method: bool = False,
    cls: Optional[type] = None
) -> Tuple[str, ...]:
    """Return the names of a function's mandatory arguments.

    This returns the names of all function arguments that:
        * Aren't bound to an instance or type as in instance or class methods.
        * Don't have default values.
        * Aren't bound with functools.partial.
        * Aren't replaced with mocks.

    :param function: The callable whose signature is inspected.
    :param name: The name under which the function was collected; defaults to
        ``function.__name__``. It is used to look the function up on ``cls``.
    :param is_method: If True, treat ``function`` as a bound method and drop
        its first parameter, even if it is actually a plain function.
    :param cls: If given, the first parameter is dropped as well, unless
        ``name`` resolves to a ``staticmethod`` on ``cls`` or on one of its
        base classes.

    If the signature cannot be determined, ``fail`` (from
    ``_pytest.outcomes``) is called with ``pytrace=False``.
    """
    # TODO(RonnyPfannschmidt): This function should be refactored when we
    # revisit fixtures. The fixture mechanism should ask the node for
    # the fixture names, and not try to obtain directly from the
    # function object well after collection has occurred.

    # The parameters attribute of a Signature object contains an
    # ordered mapping of parameter names to Parameter instances. This
    # creates a tuple of the names of the parameters that don't have
    # defaults.
    try:
        parameters = signature(function).parameters
    except (ValueError, TypeError) as e:
        fail(
            "Could not determine arguments of {!r}: {}".format(function, e),
            pytrace=False,
        )

    arg_names = tuple(
        p.name
        for p in parameters.values()
        if (
            p.kind is Parameter.POSITIONAL_OR_KEYWORD
            or p.kind is Parameter.KEYWORD_ONLY
        )
        and p.default is Parameter.empty
    )
    if not name:
        name = function.__name__

    # If this function should be treated as a bound method even though
    # it's passed as an unbound method or function, remove the first
    # parameter name.
    if is_method or (
        # inspect.getattr_static searches the whole MRO (so a staticmethod
        # inherited from a base class is found, unlike cls.__dict__) and
        # returns the raw staticmethod descriptor (unlike getattr, which
        # would resolve it to the underlying function).
        cls
        and not isinstance(inspect.getattr_static(cls, name, None), staticmethod)
    ):
        arg_names = arg_names[1:]
    # Remove any names that will be replaced with mocks.
    if hasattr(function, "__wrapped__"):
        arg_names = arg_names[num_mock_patch_args(function) :]
    return arg_names

201 

202 

if sys.version_info >= (3, 7):
    from contextlib import nullcontext  # noqa
else:

    @contextmanager
    def nullcontext():
        """Minimal stand-in for contextlib.nullcontext: a context manager
        that does nothing and yields None."""
        yield

212 

213 

def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
    """Return the names of ``function``'s positional-or-keyword and
    keyword-only parameters that have a default value.

    This intentionally mirrors the filtering at the beginning of
    ``getfuncargnames`` and yields exactly the arguments it leaves out
    because they have defaults.
    """
    wanted_kinds = (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    names = []
    for param in signature(function).parameters.values():
        if param.kind in wanted_kinds and param.default is not Parameter.empty:
            names.append(param.name)
    return tuple(names)

223 

224 

# str.translate table: every non-printable ASCII code point (0-31 and 127)
# maps to a backslash-x hex escape, except tab, carriage return and newline,
# which get their familiar short escapes.
_non_printable_ascii_translate_table = {
    code: "\\x{:02x}".format(code) for code in range(128) if not 32 <= code < 127
}
_non_printable_ascii_translate_table[ord("\t")] = "\\t"
_non_printable_ascii_translate_table[ord("\r")] = "\\r"
_non_printable_ascii_translate_table[ord("\n")] = "\\n"


def _translate_non_printable(s: str) -> str:
    """Replace the non-printable ASCII characters of ``s`` with visible escapes."""
    return s.translate(_non_printable_ascii_translate_table)

235 

236 

# Types treated as "string-like" (usable directly with isinstance).
STRING_TYPES = (bytes, str)

238 

239 

def _bytes_to_ascii(val: bytes) -> str:
    """Decode ``val`` as ASCII, turning each non-ASCII byte into a ``\\xNN`` escape."""
    return val.decode(encoding="ascii", errors="backslashreplace")

242 

243 

def ascii_escaped(val: Union[bytes, str]) -> str:
    r"""Return ``val`` as a printable ASCII ``str``, escaping everything else.

    ``bytes`` are decoded as ASCII, with every non-ASCII byte replaced by a
    ``\xNN`` escape:

        b'\xc3\xb4\xc5\xd6' -> r'\xc3\xb4\xc5\xd6'

    ``str`` values are encoded with the ``unicode_escape`` codec, so non-ASCII
    characters become ``\xNN``, ``\uNNNN`` or ``\UNNNNNNNN`` escapes (and
    backslashes are doubled), e.g.:

        r'4\nV\U00043efa\x0eMXWB\x1e\u3028\u15fd\xcd\U0007d944'

    In both cases any remaining non-printable ASCII characters (control
    characters and DEL) are then replaced by ``\t``, ``\r``, ``\n`` or
    ``\xNN`` escapes, so the result is always printable ASCII.

    Note: the obvious ``v.decode('unicode-escape')`` would return valid UTF-8
    text whenever the bytes happen to form it, but we want every non-ASCII
    byte escaped, even if the bytes match a UTF-8 string.
    """
    if isinstance(val, bytes):
        ret = _bytes_to_ascii(val)
    else:
        ret = val.encode("unicode_escape").decode("ascii")
    return _translate_non_printable(ret)

267 

268 

@attr.s()
class _PytestWrapper:
    """Internal-only holder for an original function object.

    ``@pytest.fixture`` wraps fixture functions itself (to warn when a fixture
    is called directly) and stores one of these on the wrapper. That lets
    ``get_real_func`` stop unwrapping at the right function when fixtures are
    created.
    """

    # The wrapped (original) function object.
    obj = attr.ib()

279 

280 

def get_real_func(obj):
    """Return the real function object behind ``obj``.

    Follows the ``__wrapped__`` chain left by ``functools.wraps`` for at most
    100 levels, stopping early at a ``__pytest_wrapped__`` attribute holding a
    ``_PytestWrapper``. Finally, one ``functools.partial`` layer is unwrapped.

    Raises ValueError if the 100-level limit is reached.
    """
    start_obj = obj
    for _ in range(100):
        # __pytest_wrapped__ is set by @pytest.fixture when wrapping the fixture
        # function to warn about direct calls. Unwrapping beyond it would lose
        # useful wrappings such as @mock.patch (#3774).
        pytest_wrapper = getattr(obj, "__pytest_wrapped__", None)
        if isinstance(pytest_wrapper, _PytestWrapper):
            obj = pytest_wrapper.obj
            break
        inner = getattr(obj, "__wrapped__", None)
        if inner is None:
            break
        obj = inner
    else:
        raise ValueError(
            ("could not find real function of {start}\nstopped at {current}").format(
                start=saferepr(start_obj), current=saferepr(obj)
            )
        )
    return obj.func if isinstance(obj, functools.partial) else obj

307 

308 

def get_real_method(obj, holder):
    """Unwrap ``obj`` with ``get_real_func``, re-binding it to ``holder``.

    The re-binding via the result's ``__get__`` happens only when ``obj`` was
    a bound method (has ``__func__``). If unwrapping raises, ``obj`` is
    returned untouched.
    """
    try:
        was_bound = hasattr(obj, "__func__")
        unwrapped = get_real_func(obj)
    except Exception:  # pragma: no cover
        return obj
    should_rebind = (
        was_bound and hasattr(unwrapped, "__get__") and callable(unwrapped.__get__)
    )
    if not should_rebind:
        return unwrapped
    return unwrapped.__get__(holder)

322 

323 

def getimfunc(func):
    """Return ``func.__func__`` (the plain function behind a bound method),
    or ``func`` itself when that attribute is missing."""
    return getattr(func, "__func__", func)

329 

330 

def safe_getattr(object: Any, name: str, default: Any) -> Any:
    """Like ``getattr(object, name, default)``, but also return ``default``
    when the lookup raises anything in ``TEST_OUTCOME``.

    Attribute access can fail in arbitrary ways for 'evil' Python objects
    (see issue #214). Test outcomes are caught too because of #2490
    (issue #580): new outcomes derive from BaseException instead of
    Exception (for more details see #2707).
    """
    try:
        value = getattr(object, name, default)
    except TEST_OUTCOME:
        value = default
    return value

343 

344 

def safe_isclass(obj: object) -> bool:
    """Return ``inspect.isclass(obj)``, treating any Exception it raises as False."""
    try:
        result = inspect.isclass(obj)
    except Exception:
        result = False
    return result

351 

352 

if sys.version_info < (3, 5, 2):
    # On these old interpreters typing.overload is replaced by a no-op.

    def overload(f):  # noqa: F811
        """Return ``f`` unchanged (no-op replacement for typing.overload)."""
        return f

357 

358 

# Equality keyword name for attrs: "eq" on attrs >= 19.2, otherwise the
# older "cmp".
ATTRS_EQ_FIELD = (
    "eq" if getattr(attr, "__version_info__", ()) >= (19, 2) else "cmp"
)

363 

364 

if sys.version_info < (3, 8):
    # Minimal backport of functools.cached_property. No class docstring is
    # possible here because "__doc__" is listed in __slots__.
    class cached_property(Generic[_S, _T]):
        __slots__ = ("func", "__doc__")

        def __init__(self, func: Callable[[_S], _T]) -> None:
            self.func = func
            self.__doc__ = func.__doc__

        @overload
        def __get__(
            self, instance: None, owner: Optional["Type[_S]"] = ...
        ) -> "cached_property[_S, _T]":
            raise NotImplementedError()

        @overload  # noqa: F811
        def __get__(  # noqa: F811
            self, instance: _S, owner: Optional["Type[_S]"] = ...
        ) -> _T:
            raise NotImplementedError()

        def __get__(self, instance, owner=None):  # noqa: F811
            # Class-level access returns the descriptor itself.
            if instance is None:
                return self
            # Compute once and store in the instance __dict__ under the
            # function's name, so later lookups find the stored value first.
            computed = self.func(instance)
            instance.__dict__[self.func.__name__] = computed
            return computed

else:
    from functools import cached_property

391 value = instance.__dict__[self.func.__name__] = self.func(instance) 

392 return value 

393 

394 

# Some algorithms need a dict that iterates in insertion order. Plain `dict`
# guarantees this since Python 3.7 and is faster and smaller than
# `OrderedDict`, so it is preferred whenever available.
if sys.version_info < (3, 7):
    from collections import OrderedDict

    order_preserving_dict = OrderedDict
else:
    order_preserving_dict = dict

405 

406 

# Perform exhaustiveness checking.
#
# Consider this example:
#
#     MyUnion = Union[int, str]
#
#     def handle(x: MyUnion) -> int:
#         if isinstance(x, int):
#             return 1
#         elif isinstance(x, str):
#             return 2
#         else:
#             raise Exception('unreachable')
#
# Now suppose we add a new variant:
#
#     MyUnion = Union[int, str, bytes]
#
# After doing this, we must remember ourselves to go and update the handle
# function to handle the new variant.
#
# With `assert_never` we can do better. Replace the
# `raise Exception('unreachable')` line with:
#
#     return assert_never(x)
#
# Now, if we forget to handle the new variant, the type-checker will emit a
# compile-time error, instead of the runtime error we would have gotten
# previously.
#
# This also works for Enums (if you use `is` to compare) and Literals.
def assert_never(value: "NoReturn") -> "NoReturn":
    """Mark a branch that must be unreachable once all variants are handled.

    Type checkers flag a call here unless ``value`` has been narrowed to an
    impossible type; at runtime the call always fails.

    :raises AssertionError: always, naming the unhandled value and its type.
    """
    # Raise explicitly instead of ``assert False`` so the check is not
    # stripped when Python runs with -O (which would make this return None).
    raise AssertionError(
        "Unhandled value: {} ({})".format(value, type(value).__name__)
    )