Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Rewrite assertion AST to produce nice error messages""" 

2import ast 

3import errno 

4import functools 

5import importlib.abc 

6import importlib.machinery 

7import importlib.util 

8import io 

9import itertools 

10import marshal 

11import os 

12import struct 

13import sys 

14import tokenize 

15import types 

16from typing import Callable 

17from typing import Dict 

18from typing import IO 

19from typing import List 

20from typing import Optional 

21from typing import Sequence 

22from typing import Set 

23from typing import Tuple 

24from typing import Union 

25 

26import py 

27 

28from _pytest._io.saferepr import saferepr 

29from _pytest._version import version 

30from _pytest.assertion import util 

31from _pytest.assertion.util import ( # noqa: F401 

32 format_explanation as _format_explanation, 

33) 

34from _pytest.compat import fspath 

35from _pytest.compat import TYPE_CHECKING 

36from _pytest.config import Config 

37from _pytest.main import Session 

38from _pytest.pathlib import fnmatch_ex 

39from _pytest.pathlib import Path 

40from _pytest.pathlib import PurePath 

41from _pytest.store import StoreKey 

42 

43if TYPE_CHECKING: 

44 from _pytest.assertion import AssertionState # noqa: F401 

45 

46 

# Key under which the AssertionState is kept in ``config._store``.
assertstate_key = StoreKey["AssertionState"]()


# pytest caches rewritten pycs in pycache dirs. The tag combines the
# interpreter's cache tag with the pytest version, so different interpreters
# and pytest releases use distinct cache file names.
PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version)
# ".pyc" normally, ".pyo" when assertions/debug are disabled (__debug__ False).
PYC_EXT = ".py" + ("c" if __debug__ else "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

54 

55 

class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """PEP302/PEP451 import hook which rewrites asserts.

    As a meta path finder it decides which modules get their asserts rewritten.
    As their loader it compiles the rewritten code, reusing or writing a
    pytest-specific pyc cache when possible.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            # The option is unavailable: fall back to the default test patterns.
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session = None  # type: Optional[Session]
        self._rewritten_names = set()  # type: Set[str]
        self._must_rewrite = set()  # type: Set[str]
        # Set while a pyc is being written; find_spec then declines everything
        # so that writing one pyc cannot recurse into rewriting another (#3506).
        self._writing_pyc = False
        self._basenames_to_check_rewrite = {"conftest"}
        self._marked_for_rewrite_cache = {}  # type: Dict[str, bool]
        self._session_paths_checked = False

    def set_session(self, session: Optional[Session]) -> None:
        # A new session may have other initial paths; rescan them lazily.
        self._session_paths_checked = False
        self.session = session

    # Indirection so we can mock calls to find_spec originated from the hook during testing
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(
        self,
        name: str,
        path: Optional[Sequence[Union[str, bytes]]] = None,
        target: Optional[types.ModuleType] = None,
    ) -> Optional[importlib.machinery.ModuleSpec]:
        """Return a spec loading *name* through this hook, or None to decline."""
        if self._writing_pyc:
            return None
        state = self.config._store[assertstate_key]
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace("find_module called for: %s" % name)

        # Type ignored because mypy is confused about the `self` binding here.
        found = self._find_spec(name, path)  # type: ignore
        if found is None:
            # the import machinery could not find a file to import
            return None
        origin = found.origin
        if origin is None or origin == "namespace":
            # a namespace package (no `__init__.py`): nothing to rewrite.
            # python3.5 - python3.6 report `namespace`, python3.7+ `None`.
            return None
        if not isinstance(found.loader, importlib.machinery.SourceFileLoader):
            # only source files can be rewritten
            return None
        if not os.path.exists(origin):
            # a file that doesn't exist can't be rewritten
            return None
        if not self._should_rewrite(name, origin, state):
            return None

        return importlib.util.spec_from_file_location(
            name,
            origin,
            loader=self,
            submodule_search_locations=found.submodule_search_locations,
        )

    def create_module(
        self, spec: importlib.machinery.ModuleSpec
    ) -> Optional[types.ModuleType]:
        # None lets the import system create the module object as usual.
        return None

    def exec_module(self, module: types.ModuleType) -> None:
        spec = module.__spec__
        assert spec is not None
        assert spec.origin is not None
        fn = Path(spec.origin)
        state = self.config._store[assertstate_key]

        self._rewritten_names.add(module.__name__)

        # The module looks like a test file, so rewrite it: load the source,
        # rewrite the asserts and execute the result. The rewritten code is
        # cached in a special pyc. Several pytest processes may rewrite and
        # load pycs concurrently, so the invariant is that the cached pyc is
        # always complete and valid, i.e. every update to it must be atomic
        # (POSIX's atomic rename comes in handy).
        write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if write and not try_makedirs(cache_dir):
            write = False
            state.trace("read only directory: {}".format(cache_dir))

        pyc = cache_dir / (fn.name[:-3] + PYC_TAIL)
        # A cached pyc is looked up even for read-only directories; this may
        # not be optimal...
        co = _read_pyc(fn, pyc, state.trace)
        if co is not None:
            state.trace("found cached rewritten pyc for {}".format(fn))
        else:
            state.trace("rewriting {!r}".format(fn))
            source_stat, co = _rewrite_test(fn, self.config)
            if write:
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        exec(co, module.__dict__)

    def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
        """Cheaply decide from the module name alone that it won't be rewritten.

        PathFinder.find_spec (used by find_spec) showed up as a major slowdown
        in profiles, so everything that can be ruled out without it is filtered
        here. Returns True when the module can safely be skipped.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for initial_path in self.session._initialpaths:
                # c:/projects/my_project/path.py -> "path.py" -> "path"
                last_part = str(initial_path).split(os.path.sep)[-1]
                self._basenames_to_check_rewrite.add(os.path.splitext(last_part)[0])

        # "conftest" is always among _basenames_to_check_rewrite.
        parts = name.split(".")
        if parts[-1] in self._basenames_to_check_rewrite:
            return False

        # Match the dotted name as if it were a file name.
        path = PurePath(os.path.sep.join(parts) + ".py")
        for pat in self.fnpats:
            # Patterns with directories ("tests/**.py") need the full path,
            # which is not known yet, so no early bail-out is possible.
            if os.path.dirname(pat) or fnmatch_ex(pat, path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace("early skip of rewriting module: {}".format(name))
        return True

    def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
        # conftest files are always rewritten
        if os.path.basename(fn) == "conftest.py":
            state.trace("rewriting conftest file: {!r}".format(fn))
            return True

        # so are files given explicitly on the command line
        if self.session is not None and self.session.isinitpath(py.path.local(fn)):
            state.trace(
                "matched test file (was specified on cmdline): {!r}".format(fn)
            )
            return True

        # anything else only if it follows the test file naming convention
        fn_path = PurePath(fn)
        if any(fnmatch_ex(pat, fn_path) for pat in self.fnpats):
            state.trace("matched test file {!r}".format(fn))
            return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
        """Tell (with memoization) whether *name* is, or is inside, a marked module."""
        cached = self._marked_for_rewrite_cache.get(name)
        if cached is not None:
            return cached

        is_marked = False
        for marked in self._must_rewrite:
            if name == marked or name.startswith(marked + "."):
                state.trace(
                    "matched marked file {!r} (from {!r})".format(name, marked)
                )
                is_marked = True
                break
        self._marked_for_rewrite_cache[name] = is_marked
        return is_marked

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named modules/packages and everything nested below them are
        rewritten on import. Names that were already imported without being
        rewritten (and don't opt out via their docstring) trigger a warning.
        """
        not_rewritten = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in not_rewritten:
            mod = sys.modules[name]
            if AssertionRewriter.is_rewrite_disabled(mod.__doc__ or ""):
                continue
            if isinstance(mod.__loader__, type(self)):
                continue
            self._warn_already_imported(name)
        self._must_rewrite.update(names)
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name: str) -> None:
        from _pytest.warning_types import PytestAssertRewriteWarning
        from _pytest.warnings import _issue_warning_captured

        warning = PytestAssertRewriteWarning(
            "Module already imported so cannot be rewritten: %s" % name
        )
        _issue_warning_captured(warning, self.config.hook, stacklevel=5)

    def get_data(self, pathname: Union[str, bytes]) -> bytes:
        """Optional PEP302 get_data API."""
        with open(pathname, "rb") as data_file:
            return data_file.read()

283 

284 

def _write_pyc_fp(
    fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
) -> None:
    """Serialize *co* into *fp* using the (C)Python pyc layout.

    Layout: the interpreter's magic number, then the source mtime and size as
    two little-endian unsigned 32-bit integers, then the marshalled code.
    These pycs are not meant for the builtin import machinery, so the format
    doesn't strictly have to match, but there is little reason to deviate.
    """
    # The header fields are 32 bits wide, so larger values are truncated (#4903).
    stamp = struct.pack(
        "<LL",
        int(source_stat.st_mtime) & 0xFFFFFFFF,
        source_stat.st_size & 0xFFFFFFFF,
    )
    fp.write(importlib.util.MAGIC_NUMBER)
    fp.write(stamp)
    fp.write(marshal.dumps(co))

298 

299 

if sys.platform == "win32":
    from atomicwrites import atomic_write

    def _write_pyc(
        state: "AssertionState",
        co: types.CodeType,
        source_stat: os.stat_result,
        pyc: Path,
    ) -> bool:
        """Atomically write rewritten code *co* to *pyc*.

        Returns True on success. OSErrors are traced and reported as False,
        since failing to write the cache is not an error.
        """
        try:
            with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
                _write_pyc_fp(fp, source_stat, co)
        except OSError as e:
            state.trace("error writing pyc file at {}: {}".format(pyc, e))
            # we ignore any failure to write the cache file
            # there are many reasons, permission-denied, pycache dir being a
            # file etc.
            return False
        return True


else:

    def _write_pyc(
        state: "AssertionState",
        co: types.CodeType,
        source_stat: os.stat_result,
        pyc: Path,
    ) -> bool:
        """Atomically write rewritten code *co* to *pyc*.

        The pyc is written to a per-process temporary file, which is then
        renamed over the final name. Concurrent readers (e.g. other pytest
        processes) therefore see either no pyc or a complete one.

        Returns True on success. OSErrors are traced and reported as False,
        since failing to write the cache is not an error.
        """
        proc_pyc = "{}.{}".format(pyc, os.getpid())
        try:
            fp = open(proc_pyc, "wb")
        except OSError as e:
            state.trace(
                "error writing pyc file at {}: errno={}".format(proc_pyc, e.errno)
            )
            return False

        try:
            # Close (and thereby flush) the file *before* renaming it.
            # Renaming an open, partially flushed file would briefly expose an
            # incomplete pyc under its final name.
            with fp:
                _write_pyc_fp(fp, source_stat, co)
            os.replace(proc_pyc, fspath(pyc))
        except OSError as e:
            state.trace("error writing pyc file at {}: {}".format(pyc, e))
            # we ignore any failure to write the cache file
            # there are many reasons, permission-denied, pycache dir being a
            # file etc.
            # Best effort: don't leave the temporary file behind.
            try:
                os.unlink(proc_pyc)
            except OSError:
                pass
            return False
        return True

350 

351 

def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
    """Read and rewrite *fn*; return its stat result and the compiled code object.

    The file is stat'ed before it is read, and the source is parsed, passed
    through rewrite_asserts(), then compiled without inheriting this module's
    compiler flags.
    """
    filename = fspath(fn)
    source_stat = os.stat(filename)
    with open(filename, "rb") as source_file:
        source = source_file.read()
    tree = ast.parse(source, filename=filename)
    rewrite_asserts(tree, source, filename, config)
    return source_stat, compile(tree, filename, "exec", dont_inherit=True)

362 

363 

def _read_pyc(
    source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
) -> Optional[types.CodeType]:
    """Possibly read a pytest pyc containing rewritten code.

    The pyc is only accepted when its magic number matches this interpreter's
    and its recorded mtime/size (truncated to 32 bits) match *source*'s.
    Return the code object on success, otherwise None. Reasons for rejection
    are reported via *trace*.
    """
    try:
        fp = open(fspath(pyc), "rb")
    except OSError:
        return None
    with fp:
        try:
            source_stat = os.stat(fspath(source))
            header = fp.read(12)
        except OSError as e:
            trace("_read_pyc({}): OSError {}".format(source, e))
            return None
        expected_stamp = (
            int(source_stat.st_mtime) & 0xFFFFFFFF,
            source_stat.st_size & 0xFFFFFFFF,
        )
        # Check for invalid or out of date pyc file.
        is_stale = (
            len(header) != 12
            or header[:4] != importlib.util.MAGIC_NUMBER
            or struct.unpack("<LL", header[4:]) != expected_stamp
        )
        if is_stale:
            trace("_read_pyc(%s): invalid or out of date pyc" % source)
            return None
        try:
            co = marshal.load(fp)
        except Exception as e:
            trace("_read_pyc({}): marshal.load error {}".format(source, e))
            return None
        if isinstance(co, types.CodeType):
            return co
        trace("_read_pyc(%s): not a code object" % source)
        return None

401 

402 

def rewrite_asserts(
    mod: ast.Module,
    source: bytes,
    module_path: Optional[str] = None,
    config: Optional[Config] = None,
) -> None:
    """Rewrite the assert statements in *mod* in place.

    *source* is the raw module source, which is used to recover the original
    text of each assertion.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)

411 

412 

def _saferepr(obj: object) -> str:
    """Get a safe repr of an object for assertion error messages.

    util.format_explanation() treats newlines specially, so every newline in
    the repr is replaced by a literal backslash-n. Plain reprs are normally
    escaped by format_explanation() itself. A custom repr, however, can
    contain a newline followed by '{' or '}' (common in JSON-like reprs),
    which would otherwise be read as a formatting directive.
    """
    representation = saferepr(obj)
    return representation.replace("\n", "\\n")

425 

426 

def _format_assertmsg(obj: object) -> str:
    """Format the custom assertion message given.

    Newlines become '\\n~' so that util.format_explanation() keeps them
    instead of escaping them, and '%' is doubled because the message is
    embedded in a %-formatted template. Non-string messages go through
    saferepr() first, and any escaped '\\n' sequences in that repr are
    turned into kept newlines as well.
    """
    # reprlib escapes newlines inside strings but not newlines produced by an
    # object's own __repr__(); in either case the newline should be preserved.
    if isinstance(obj, str):
        text = obj
        replacements = [("\n", "\n~"), ("%", "%%")]
    else:
        text = saferepr(obj)
        replacements = [("\n", "\n~"), ("%", "%%"), ("\\n", "\n~")]
    for old, new in replacements:
        text = text.replace(old, new)
    return text

448 

449 

def _should_repr_global_name(obj: object) -> bool:
    """Tell whether a non-local name's value should be shown by its repr.

    Callables and objects with a ``__name__`` are shown by name instead.
    If probing for ``__name__`` raises, the repr is used.
    """
    if callable(obj):
        return False
    try:
        has_name = hasattr(obj, "__name__")
    except Exception:
        return True
    return not has_name

458 

459 

def _format_boolop(explanations, is_or: bool):
    """Join the sub-explanations of a boolean operation into one explanation.

    The parts are joined with " or " (*is_or* true) or " and ", then wrapped
    in parentheses. Every '%' is doubled so the result is safe to embed in a
    %-format template.
    """
    # "(" + str.join(...) is always a str (joining non-str items raises), so
    # no bytes handling is needed.
    joiner = " or " if is_or else " and "
    explanation = "(" + joiner.join(explanations) + ")"
    return explanation.replace("%", "%%")

466 

467 

def _call_reprcompare(
    ops: Sequence[str],
    results: Sequence[bool],
    expls: Sequence[str],
    each_obj: Sequence[object],
) -> str:
    """Pick the explanation for the failing link of a comparison chain.

    Walks the chain until a result is falsy or its truthiness cannot be
    evaluated; if none fails, the last link is used. If ``util._reprcompare``
    is installed and returns something other than None for that link, its
    result replaces the default explanation.
    """
    for i, res, expl in zip(range(len(ops)), results, expls):
        try:
            failed = not res
        except Exception:
            # bool(res) raised: treat this link as the failing one.
            failed = True
        if failed:
            break
    reprcompare = util._reprcompare
    if reprcompare is not None:
        custom = reprcompare(ops[i], each_obj[i], each_obj[i + 1])
        if custom is not None:
            return custom
    return expl

486 

487 

def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
    """Forward a passed assertion to ``util._assertion_pass`` when it is set."""
    callback = util._assertion_pass
    if callback is None:
        return
    callback(lineno, orig, expl)

491 

492 

def _check_if_assertion_pass_impl() -> bool:
    """Check whether any plugin implements the pytest_assertion_pass hook.

    Used so that the explanation is not generated unnecessarily (which might
    be expensive).
    """
    return bool(util._assertion_pass)

497 

498 

# Display templates for unary operators, keyed by AST operator class; "%s"
# is presumably filled with the operand's explanation.
UNARY_MAP = dict(
    [
        (ast.Not, "not %s"),
        (ast.Invert, "~%s"),
        (ast.USub, "-%s"),
        (ast.UAdd, "+%s"),
    ]
)

# Display symbols for binary and comparison operators, keyed by AST class.
BINOP_MAP = dict(
    [
        # bitwise
        (ast.BitOr, "|"),
        (ast.BitXor, "^"),
        (ast.BitAnd, "&"),
        (ast.LShift, "<<"),
        (ast.RShift, ">>"),
        # arithmetic
        (ast.Add, "+"),
        (ast.Sub, "-"),
        (ast.Mult, "*"),
        (ast.Div, "/"),
        (ast.FloorDiv, "//"),
        (ast.Mod, "%%"),  # escaped for string formatting
        # comparisons
        (ast.Eq, "=="),
        (ast.NotEq, "!="),
        (ast.Lt, "<"),
        (ast.LtE, "<="),
        (ast.Gt, ">"),
        (ast.GtE, ">="),
        (ast.Pow, "**"),
        (ast.Is, "is"),
        (ast.IsNot, "is not"),
        (ast.In, "in"),
        (ast.NotIn, "not in"),
        (ast.MatMult, "@"),
    ]
)

526 

527 

def set_location(node, lineno, col_offset):
    """Set node location information recursively.

    Every node in the subtree rooted at *node* that has ``lineno`` /
    ``col_offset`` attributes gets the given values. Returns *node*.
    """
    for descendant in ast.walk(node):
        fields = descendant._attributes
        if "lineno" in fields:
            descendant.lineno = lineno
        if "col_offset" in fields:
            descendant.col_offset = col_offset
    return node

541 

542 

def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
    """Return a mapping from {lineno: "assertion test expression"}.

    *src* is tokenized rather than parsed, so the original spelling of each
    assert's test (including bracketed line continuations) is kept. Any
    message following a top-level comma is cut off. Keys are the line numbers
    of the ``assert`` keywords.
    """
    result = {}  # type: Dict[int, str]

    nesting = 0
    chunks = []  # type: List[str]
    start_line = None  # type: Optional[int]
    visited = set()  # type: Set[int]

    def _flush() -> None:
        # Record the text collected for the current assert and start over.
        nonlocal nesting, chunks, start_line, visited
        assert start_line is not None
        result[start_line] = "".join(chunks).rstrip().rstrip("\\")
        nesting = 0
        chunks = []
        start_line = None
        visited = set()

    token_stream = tokenize.tokenize(io.BytesIO(src).readline)
    for tok_type, text, (row, col), _, physical in token_stream:
        if tok_type == tokenize.NAME and text == "assert":
            start_line = row
            continue
        if start_line is None:
            continue

        is_op = tok_type == tokenize.OP
        # Track bracket depth so a comma inside a call/tuple/etc. is not
        # mistaken for the separator between test and message.
        if is_op and text in "([{":
            nesting += 1
        elif is_op and text in ")]}":
            nesting -= 1

        if not chunks:
            # First token of the test: keep the rest of its physical line.
            chunks.append(physical[col:])
            visited.add(row)
        elif nesting == 0 and is_op and text == ",":
            # Top-level comma: everything from here on is the message.
            if row not in visited:
                # multi line assert with escaped newline before message
                chunks.append(physical[:col])
            elif len(chunks) == 1:
                # one line assert with message; the chunk was trimmed at the
                # test's start, so shift the column accordingly
                trimmed_col = col + len(chunks[-1]) - len(physical)
                chunks[-1] = chunks[-1][:trimmed_col]
            else:
                # multi-line assert with message
                chunks[-1] = chunks[-1][:col]
            _flush()
        elif tok_type in {tokenize.NEWLINE, tokenize.ENDMARKER}:
            _flush()
        elif row not in visited:
            chunks.append(physical)
            visited.add(row)

    return result

595 

596 

597class AssertionRewriter(ast.NodeVisitor): 

598 """Assertion rewriting implementation. 

599 

600 The main entrypoint is to call .run() with an ast.Module instance, 

601 this will then find all the assert statements and rewrite them to 

602 provide intermediate values and a detailed assertion error. See 

603 http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html 

604 for an overview of how this works. 

605 

606 The entry point here is .run() which will iterate over all the 

607 statements in an ast.Module and for each ast.Assert statement it 

608 finds call .visit() with it. Then .visit_Assert() takes over and 

609 is responsible for creating new ast statements to replace the 

610 original assert statement: it rewrites the test of an assertion 

611 to provide intermediate values and replace it with an if statement 

612 which raises an assertion error with a detailed explanation in 

613 case the expression is false and calls pytest_assertion_pass hook 

614 if expression is true. 

615 

616 For this .visit_Assert() uses the visitor pattern to visit all the 

617 AST nodes of the ast.Assert.test field, each visit call returning 

618 an AST node and the corresponding explanation string. During this 

619 state is kept in several instance attributes: 

620 

621 :statements: All the AST statements which will replace the assert 

622 statement. 

623 

624 :variables: This is populated by .variable() with each variable 

625 used by the statements so that they can all be set to None at 

626 the end of the statements. 

627 

628 :variable_counter: Counter to create new unique variables needed 

629 by statements. Variables are created using .variable() and 

630 have the form of "@py_assert0". 

631 

632 :expl_stmts: The AST statements which will be executed to get 

633 data from the assertion. This is the code which will construct 

634 the detailed assertion message that is used in the AssertionError 

635 or for the pytest_assertion_pass hook. 

636 

637 :explanation_specifiers: A dict filled by .explanation_param() 

638 with %-formatting placeholders and their corresponding 

639 expressions to use in the building of an assertion message. 

640 This is used by .pop_format_context() to build a message. 

641 

642 :stack: A stack of the explanation_specifiers dicts maintained by 

643 .push_format_context() and .pop_format_context() which allows 

644 to build another %-formatted string while already building one. 

645 

646 This state is reset on every new assert statement visited and used 

647 by the other visitors. 

648 

649 """ 

650 

651 def __init__( 

652 self, module_path: Optional[str], config: Optional[Config], source: bytes 

653 ) -> None: 

654 super().__init__() 

655 self.module_path = module_path 

656 self.config = config 

657 if config is not None: 

658 self.enable_assertion_pass_hook = config.getini( 

659 "enable_assertion_pass_hook" 

660 ) 

661 else: 

662 self.enable_assertion_pass_hook = False 

663 self.source = source 

664 

665 @functools.lru_cache(maxsize=1) 

666 def _assert_expr_to_lineno(self) -> Dict[int, str]: 

667 return _get_assertion_exprs(self.source) 

668 

    def run(self, mod: ast.Module) -> None:
        """Find all assert statements in *mod* and rewrite them.

        *mod* is modified in place:
        * ``import builtins as @py_builtins`` and
          ``import _pytest.assertion.rewrite as @pytest_ar`` are inserted after
          the module docstring and any ``from __future__`` imports;
        * every ``ast.Assert`` found in a list field of a visited node is
          replaced by the statements returned from ``self.visit(assert)``.

        Nothing happens if the body is empty or the module docstring contains
        PYTEST_DONT_REWRITE (see is_rewrite_disabled()).
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [
            ast.alias("builtins", "@py_builtins"),
            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
        ]
        # NOTE(review): presumably some Python versions expose the docstring as
        # a Module attribute; when absent, the first string expression
        # statement in the body is treated as the docstring below.
        doc = getattr(mod, "docstring", None)
        expect_docstring = doc is None
        if doc is not None and self.is_rewrite_disabled(doc):
            return
        # pos: index in mod.body where the helper imports are inserted.
        # lineno: line number given to the inserted import nodes.
        pos = 0
        lineno = 1
        for item in mod.body:
            if (
                expect_docstring
                and isinstance(item, ast.Expr)
                and isinstance(item.value, ast.Str)
            ):
                doc = item.value.s
                if self.is_rewrite_disabled(doc):
                    return
                expect_docstring = False
            elif (
                not isinstance(item, ast.ImportFrom)
                or item.level > 0
                or item.module != "__future__"
            ):
                # First statement that is neither the docstring nor a
                # __future__ import: the helper imports go right before it.
                lineno = item.lineno
                break
            pos += 1
        else:
            # The body holds only the docstring and/or __future__ imports.
            lineno = item.lineno
        imports = [
            ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
        ]
        mod.body[pos:pos] = imports
        # Collect asserts.
        # Stack-based walk: each list field is rebuilt with every Assert
        # replaced by its rewritten statements. Other AST items in lists are
        # pushed for further traversal; single-node fields are only descended
        # into when they are not expressions.
        nodes = [mod]  # type: List[ast.AST]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []  # type: List
                    for i, child in enumerate(field):
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (
                    isinstance(field, ast.AST)
                    # Don't recurse into expressions as they can't contain
                    # asserts.
                    and not isinstance(field, ast.expr)
                ):
                    nodes.append(field)

733 

734 @staticmethod 

735 def is_rewrite_disabled(docstring: str) -> bool: 

736 return "PYTEST_DONT_REWRITE" in docstring 

737 

738 def variable(self) -> str: 

739 """Get a new variable.""" 

740 # Use a character invalid in python identifiers to avoid clashing. 

741 name = "@py_assert" + str(next(self.variable_counter)) 

742 self.variables.append(name) 

743 return name 

744 

745 def assign(self, expr: ast.expr) -> ast.Name: 

746 """Give *expr* a name.""" 

747 name = self.variable() 

748 self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) 

749 return ast.Name(name, ast.Load()) 

750 

751 def display(self, expr: ast.expr) -> ast.expr: 

752 """Call saferepr on the expression.""" 

753 return self.helper("_saferepr", expr) 

754 

755 def helper(self, name: str, *args: ast.expr) -> ast.expr: 

756 """Call a helper in this module.""" 

757 py_name = ast.Name("@pytest_ar", ast.Load()) 

758 attr = ast.Attribute(py_name, name, ast.Load()) 

759 return ast.Call(attr, list(args), []) 

760 

761 def builtin(self, name: str) -> ast.Attribute: 

762 """Return the builtin called *name*.""" 

763 builtin_name = ast.Name("@py_builtins", ast.Load()) 

764 return ast.Attribute(builtin_name, name, ast.Load()) 

765 

766 def explanation_param(self, expr: ast.expr) -> str: 

767 """Return a new named %-formatting placeholder for expr. 

768 

769 This creates a %-formatting placeholder for expr in the 

770 current formatting context, e.g. ``%(py0)s``. The placeholder 

771 and expr are placed in the current format context so that it 

772 can be used on the next call to .pop_format_context(). 

773 

774 """ 

775 specifier = "py" + str(next(self.variable_counter)) 

776 self.explanation_specifiers[specifier] = expr 

777 return "%(" + specifier + ")s" 

778 

779 def push_format_context(self) -> None: 

780 """Create a new formatting context. 

781 

782 The format context is used for when an explanation wants to 

783 have a variable value formatted in the assertion message. In 

784 this case the value required can be added using 

785 .explanation_param(). Finally .pop_format_context() is used 

786 to format a string of %-formatted values as added by 

787 .explanation_param(). 

788 

789 """ 

790 self.explanation_specifiers = {} # type: Dict[str, ast.expr] 

791 self.stack.append(self.explanation_specifiers) 

792 

793 def pop_format_context(self, expl_expr: ast.expr) -> ast.Name: 

794 """Format the %-formatted string with current format context. 

795 

796 The expl_expr should be an str ast.expr instance constructed from 

797 the %-placeholders created by .explanation_param(). This will 

798 add the required code to format said string to .expl_stmts and 

799 return the ast.Name instance of the formatted string. 

800 

801 """ 

802 current = self.stack.pop() 

803 if self.stack: 

804 self.explanation_specifiers = self.stack[-1] 

805 keys = [ast.Str(key) for key in current.keys()] 

806 format_dict = ast.Dict(keys, list(current.values())) 

807 form = ast.BinOp(expl_expr, ast.Mod(), format_dict) 

808 name = "@py_format" + str(next(self.variable_counter)) 

809 if self.enable_assertion_pass_hook: 

810 self.format_variables.append(name) 

811 self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) 

812 return ast.Name(name, ast.Load()) 

813 

814 def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]: 

815 """Handle expressions we don't have custom code for.""" 

816 assert isinstance(node, ast.expr) 

817 res = self.assign(node) 

818 return res, self.explanation_param(self.display(res)) 

819 

    def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false.

        With ``enable_assertion_pass_hook`` set, the if statement also gets an
        else branch. At runtime, when ``_check_if_assertion_pass_impl()`` is
        true, that branch builds the explanation and hands it to
        ``_call_assertion_pass`` together with the assertion's source text.
        Temporary ``@py_assert*`` (and, with the hook, ``@py_format*``)
        variables are reset to None afterwards. All generated statements get
        the assert's line number and column offset.
        """
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            # e.g. ``assert (x, "msg")``: a non-empty tuple is always true.
            from _pytest.warning_types import PytestAssertRewriteWarning
            import warnings

            # TODO: This assert should not be needed.
            assert self.module_path is not None
            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=fspath(self.module_path),
                lineno=assert_.lineno,
            )

        # Reset the per-assert state (described in the class docstring).
        self.statements = []  # type: List[ast.stmt]
        self.variables = []  # type: List[str]
        self.variable_counter = itertools.count()

        if self.enable_assertion_pass_hook:
            self.format_variables = []  # type: List[str]

        self.stack = []  # type: List[Dict[str, ast.expr]]
        self.expl_stmts = []  # type: List[ast.stmt]
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)

        negation = ast.UnaryOp(ast.Not(), top_condition)

        if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
            # The formatted explanation (without the custom message) is shared
            # by the failure and the pass branches.
            msg = self.pop_format_context(ast.Str(explanation))

            # Failed
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                gluestr = "\n>assert "
            else:
                assertmsg = ast.Str("")
                gluestr = "assert "
            err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
            err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
            err_name = ast.Name("AssertionError", ast.Load())
            fmt = self.helper("_format_explanation", err_msg)
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)
            statements_fail = []
            statements_fail.extend(self.expl_stmts)
            statements_fail.append(raise_)

            # Passed
            fmt_pass = self.helper("_format_explanation", msg)
            # Original source text of the assertion test, keyed by the
            # assert's line number.
            orig = self._assert_expr_to_lineno()[assert_.lineno]
            hook_call_pass = ast.Expr(
                self.helper(
                    "_call_assertion_pass",
                    ast.Num(assert_.lineno),
                    ast.Str(orig),
                    fmt_pass,
                )
            )
            # If any hooks implement assert_pass hook
            hook_impl_test = ast.If(
                self.helper("_check_if_assertion_pass_impl"),
                self.expl_stmts + [hook_call_pass],
                [],
            )
            statements_pass = [hook_impl_test]

            # Test for assertion condition
            main_test = ast.If(negation, statements_fail, statements_pass)
            self.statements.append(main_test)
            if self.format_variables:
                variables = [
                    ast.Name(name, ast.Store()) for name in self.format_variables
                ]
                clear_format = ast.Assign(variables, ast.NameConstant(None))
                self.statements.append(clear_format)

        else:  # Original assertion rewriting
            # Create failure message.
            # `body` is the same list object as self.expl_stmts, so the Assign
            # appended by pop_format_context() below also ends up inside the
            # If body, before the raise.
            body = self.expl_stmts
            self.statements.append(ast.If(negation, body, []))
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                explanation = "\n>assert " + explanation
            else:
                assertmsg = ast.Str("")
                explanation = "assert " + explanation
            template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
            msg = self.pop_format_context(template)
            fmt = self.helper("_format_explanation", msg)
            err_name = ast.Name("AssertionError", ast.Load())
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)

            body.append(raise_)

        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, ast.NameConstant(None))
            self.statements.append(clear)
        # Fix line numbers.
        for stmt in self.statements:
            set_location(stmt, assert_.lineno, assert_.col_offset)
        return self.statements

936 

def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
    """Explain a bare name reference.

    The explanation shows the runtime repr of the name when the name is a
    local variable or when ``_should_repr_global_name`` accepts it;
    otherwise the plain identifier text is shown.  The original ``name``
    node itself is returned as the result expression.
    """
    # Runtime test: "<id>" in locals() or _should_repr_global_name(<name>)
    is_local = ast.Compare(
        ast.Str(name.id),
        [ast.In()],
        [ast.Call(self.builtin("locals"), [], [])],
    )
    repr_allowed = self.helper("_should_repr_global_name", name)
    shown = ast.IfExp(
        ast.BoolOp(ast.Or(), [is_local, repr_allowed]),
        self.display(name),
        ast.Str(name.id),
    )
    return name, self.explanation_param(shown)

946 

def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
    """Rewrite an ``and``/``or`` chain while preserving short-circuiting.

    Every evaluated operand is assigned to one shared temporary
    (``res_var``).  Evaluation of each later operand is nested inside an
    ``if`` that only runs while the chain has not been decided yet (previous
    operand truthy for ``and``, falsy for ``or``).  The explanation
    statements in ``self.expl_stmts`` are nested the same way, so that only
    operands that were actually evaluated append their explanation to a
    runtime list.  That list is then passed to the ``_format_boolop`` helper
    together with the ``is_or`` flag.

    Returns a load of ``res_var`` and the explanation placeholder.
    """
    res_var = self.variable()
    # Runtime list that collects the explanation of each evaluated operand.
    expl_list = self.assign(ast.List([], ast.Load()))
    app = ast.Attribute(expl_list, "append", ast.Load())
    is_or = int(isinstance(boolop.op, ast.Or))
    # Remember the outer statement lists; they are swapped for nested `if`
    # bodies inside the loop and restored afterwards.
    body = save = self.statements
    fail_save = self.expl_stmts
    levels = len(boolop.values) - 1
    self.push_format_context()
    # Process each operand, short-circuiting if needed.
    for i, v in enumerate(boolop.values):
        if i:
            fail_inner = []  # type: List[ast.stmt]
            # cond is set in a prior loop iteration below
            self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
            self.expl_stmts = fail_inner
        self.push_format_context()
        res, expl = self.visit(v)
        body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
        expl_format = self.pop_format_context(ast.Str(expl))
        call = ast.Call(app, [expl_format], [])
        self.expl_stmts.append(ast.Expr(call))
        if i < levels:
            # Continue with the next operand only if this one did not decide
            # the result: truthy for `and`, falsy for `or` (hence the Not).
            cond = res  # type: ast.expr
            if is_or:
                cond = ast.UnaryOp(ast.Not(), cond)
            inner = []  # type: List[ast.stmt]
            self.statements.append(ast.If(cond, inner, []))
            self.statements = body = inner
    # Restore the outer statement lists saved before the loop.
    self.statements = save
    self.expl_stmts = fail_save
    expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
    expl = self.pop_format_context(expl_template)
    return ast.Name(res_var, ast.Load()), self.explanation_param(expl)

981 

def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
    """Rewrite a unary operation such as ``not x`` or ``-x``.

    The operand is rewritten first.  The operation applied to the rewritten
    operand is stored via ``self.assign``.  The explanation is the operator
    pattern from ``UNARY_MAP`` filled in with the operand's explanation.
    """
    fmt = UNARY_MAP[unary.op.__class__]
    inner, inner_text = self.visit(unary.operand)
    return self.assign(ast.UnaryOp(unary.op, inner)), fmt % (inner_text,)

987 

def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
    """Rewrite a binary operation.

    The left operand is rewritten before the right one (source order).  The
    operation on the rewritten operands is stored via ``self.assign``, and
    the explanation reads ``(left <symbol> right)``, where the symbol comes
    from ``BINOP_MAP``.
    """
    op_symbol = BINOP_MAP[binop.op.__class__]
    lhs, lhs_text = self.visit(binop.left)
    rhs, rhs_text = self.visit(binop.right)
    result = self.assign(ast.BinOp(lhs, binop.op, rhs))
    return result, "(%s %s %s)" % (lhs_text, op_symbol, rhs_text)

995 

def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
    """Rewrite an ``ast.Call`` node.

    The callee is rewritten first, then the positional arguments, then the
    keyword arguments, all in source order.  The rebuilt call is stored via
    ``self.assign``.  The explanation shows the displayed result followed
    by a nested ``{result = func(arg, key=value, **mapping)}`` section.
    """
    func_res, func_text = self.visit(call.func)
    rendered_args = []  # type: List[str]
    positional = []
    for node in call.args:
        value, text = self.visit(node)
        positional.append(value)
        rendered_args.append(text)
    keywords = []
    for kw in call.keywords:
        value, text = self.visit(kw.value)
        keywords.append(ast.keyword(kw.arg, value))
        # ``**mapping`` arguments are keywords whose ``arg`` is None.
        prefix = kw.arg + "=" if kw.arg else "**"
        rendered_args.append(prefix + text)
    call_text = "{}({})".format(func_text, ", ".join(rendered_args))
    result = self.assign(ast.Call(func_res, positional, keywords))
    result_text = self.explanation_param(self.display(result))
    return result, "{0}\n{{{0} = {1}\n}}".format(result_text, call_text)

1022 

def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
    """Rewrite ``*value``, which may appear in call arguments since Python 3.5.

    The wrapped value is rewritten and re-wrapped in a new ``Starred`` node
    with the same context.  Its explanation is prefixed with ``*``.
    """
    inner, inner_text = self.visit(starred.value)
    return ast.Starred(inner, starred.ctx), "*" + inner_text

1028 

def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
    """Rewrite an attribute access ``value.attr``.

    Attributes not in a ``Load`` context are handed to ``generic_visit``.
    Otherwise the base is rewritten and the attribute lookup is stored via
    ``self.assign``.  The explanation shows the displayed result followed
    by a nested ``{result = value.attr}`` section.
    """
    if isinstance(attr.ctx, ast.Load):
        base, base_text = self.visit(attr.value)
        result = self.assign(ast.Attribute(base, attr.attr, ast.Load()))
        shown = self.explanation_param(self.display(result))
        return result, "{0}\n{{{0} = {1}.{2}\n}}".format(shown, base_text, attr.attr)
    return self.generic_visit(attr)

1038 

def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
    """Rewrite a (possibly chained) comparison.

    Every link ``a op b`` of the chain is stored in its own temporary.  The
    explanation is produced by the ``_call_reprcompare`` helper, which
    receives the operator symbols, the link results, the per-link source
    text and the operand values.  Nested comparisons and boolean operations
    used as operands are parenthesised in the text.

    Returns the single link temporary, or an ``and`` of all link
    temporaries for a chain, together with the explanation placeholder.
    """
    self.push_format_context()
    prev_res, prev_text = self.visit(comp.left)
    if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
        prev_text = "({})".format(prev_text)
    link_vars = [self.variable() for _ in comp.ops]
    loads = [ast.Name(var, ast.Load()) for var in link_vars]
    stores = [ast.Name(var, ast.Store()) for var in link_vars]
    symbols = []  # type: List[ast.expr]
    link_texts = []  # type: List[ast.expr]
    operands = [prev_res]
    for idx, (op, operand) in enumerate(zip(comp.ops, comp.comparators)):
        cur_res, cur_text = self.visit(operand)
        if isinstance(operand, (ast.Compare, ast.BoolOp)):
            cur_text = "({})".format(cur_text)
        operands.append(cur_res)
        symbol = BINOP_MAP[op.__class__]
        symbols.append(ast.Str(symbol))
        link_texts.append(ast.Str("{} {} {}".format(prev_text, symbol, cur_text)))
        # Store the result of this single link in its own temporary.
        link = ast.Compare(prev_res, [op], [cur_res])
        self.statements.append(ast.Assign([stores[idx]], link))
        prev_res, prev_text = cur_res, cur_text
    # Use pytest.assertion.util._reprcompare if that's available.
    explanation = self.helper(
        "_call_reprcompare",
        ast.Tuple(symbols, ast.Load()),
        ast.Tuple(loads, ast.Load()),
        ast.Tuple(link_texts, ast.Load()),
        ast.Tuple(operands, ast.Load()),
    )
    if len(comp.ops) > 1:
        res = ast.BoolOp(ast.And(), loads)  # type: ast.expr
    else:
        res = loads[0]
    return res, self.explanation_param(self.pop_format_context(explanation))

1076 

1077 

def try_makedirs(cache_dir: Path) -> bool:
    """Create ``cache_dir`` and any missing parent directories.

    Returns True if the directory was created or already existed.

    Returns False instead of raising in these cases:

    * a path component is not a directory (e.g. inside a zip file, or the
      path is an existing regular file);
    * permission is denied;
    * the filesystem is read-only.

    Any other ``OSError`` propagates.
    """
    try:
        os.makedirs(fspath(cache_dir), exist_ok=True)
    except (
        FileNotFoundError,
        NotADirectoryError,
        FileExistsError,
        PermissionError,
    ):
        # A path component is not a directory (zip file / regular file),
        # or we are not allowed to create it: report failure, don't raise.
        return False
    except OSError as exc:
        # EROFS (read-only filesystem) has no dedicated OSError subclass,
        # so it can only be recognised through its errno.
        if exc.errno != errno.EROFS:
            raise
        return False
    else:
        return True

1096 

1097 

def get_cache_dir(file_path: Path) -> Path:
    """Return the directory in which the .pyc for the given .py file is written.

    If ``sys.pycache_prefix`` is set (Python 3.8+), the source file's
    directory, with its anchor (root / drive) removed, is mirrored beneath
    the prefix.  For example, prefix ``/tmp/pycs`` and path
    ``/home/user/proj/test_app.py`` give ``/tmp/pycs/home/user/proj``.
    Otherwise the classic ``__pycache__`` directory next to the source file
    is returned.

    A relative ``file_path`` is first made absolute against the current
    working directory, as the interpreter itself does for
    ``pycache_prefix``.  Without this, the first directory component would
    be dropped as if it were the anchor, and e.g. ``a/t.py`` and ``b/t.py``
    would share one cache directory.
    """
    if sys.version_info >= (3, 8) and sys.pycache_prefix:
        if not file_path.is_absolute():
            file_path = Path.cwd() / file_path
        # parts[0] is the anchor and parts[-1] the file name; keep the
        # directories in between.
        return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1])
    else:
        # classic pycache directory
        return file_path.parent / "__pycache__"