1"""Rewrite assertion AST to produce nice error messages""" 2from __future__ import absolute_import, division, print_function 3import ast 4import errno 5import itertools 6import imp 7import marshal 8import os 9import re 10import six 11import struct 12import sys 13import types 14 15import atomicwrites 16import py 17 18from _pytest.assertion import util 19 20 21# pytest caches rewritten pycs in __pycache__. 22if hasattr(imp, "get_tag"): 23 PYTEST_TAG = imp.get_tag() + "-PYTEST" 24else: 25 if hasattr(sys, "pypy_version_info"): 26 impl = "pypy" 27 elif sys.platform == "java": 28 impl = "jython" 29 else: 30 impl = "cpython" 31 ver = sys.version_info 32 PYTEST_TAG = "%s-%s%s-PYTEST" % (impl, ver[0], ver[1]) 33 del ver, impl 34 35PYC_EXT = ".py" + (__debug__ and "c" or "o") 36PYC_TAIL = "." + PYTEST_TAG + PYC_EXT 37 38ASCII_IS_DEFAULT_ENCODING = sys.version_info[0] < 3 39 40if sys.version_info >= (3, 5): 41 ast_Call = ast.Call 42else: 43 44 def ast_Call(a, b, c): 45 return ast.Call(a, b, c, None, None) 46 47 48class AssertionRewritingHook(object): 49 """PEP302 Import hook which rewrites asserts.""" 50 51 def __init__(self, config): 52 self.config = config 53 self.fnpats = config.getini("python_files") 54 self.session = None 55 self.modules = {} 56 self._rewritten_names = set() 57 self._register_with_pkg_resources() 58 self._must_rewrite = set() 59 60 def set_session(self, session): 61 self.session = session 62 63 def find_module(self, name, path=None): 64 state = self.config._assertstate 65 state.trace("find_module called for: %s" % name) 66 names = name.rsplit(".", 1) 67 lastname = names[-1] 68 pth = None 69 if path is not None: 70 # Starting with Python 3.3, path is a _NamespacePath(), which 71 # causes problems if not converted to list. 72 path = list(path) 73 if len(path) == 1: 74 pth = path[0] 75 if pth is None: 76 try: 77 fd, fn, desc = imp.find_module(lastname, path) 78 except ImportError: 79 return None 80 if fd is not None: 81 fd.close() 82 tp = desc[2] 83 if tp == imp.PY_COMPILED: 84 if hasattr(imp, "source_from_cache"): 85 try: 86 fn = imp.source_from_cache(fn) 87 except ValueError: 88 # Python 3 doesn't like orphaned but still-importable 89 # .pyc files. 90 fn = fn[:-1] 91 else: 92 fn = fn[:-1] 93 elif tp != imp.PY_SOURCE: 94 # Don't know what this is. 95 return None 96 else: 97 fn = os.path.join(pth, name.rpartition(".")[2] + ".py") 98 99 fn_pypath = py.path.local(fn) 100 if not self._should_rewrite(name, fn_pypath, state): 101 return None 102 103 self._rewritten_names.add(name) 104 105 # The requested module looks like a test file, so rewrite it. This is 106 # the most magical part of the process: load the source, rewrite the 107 # asserts, and load the rewritten source. We also cache the rewritten 108 # module code in a special pyc. We must be aware of the possibility of 109 # concurrent pytest processes rewriting and loading pycs. To avoid 110 # tricky race conditions, we maintain the following invariant: The 111 # cached pyc is always a complete, valid pyc. Operations on it must be 112 # atomic. POSIX's atomic rename comes in handy. 113 write = not sys.dont_write_bytecode 114 cache_dir = os.path.join(fn_pypath.dirname, "__pycache__") 115 if write: 116 try: 117 os.mkdir(cache_dir) 118 except OSError: 119 e = sys.exc_info()[1].errno 120 if e == errno.EEXIST: 121 # Either the __pycache__ directory already exists (the 122 # common case) or it's blocked by a non-dir node. In the 123 # latter case, we'll ignore it in _write_pyc. 124 pass 125 elif e in [errno.ENOENT, errno.ENOTDIR]: 126 # One of the path components was not a directory, likely 127 # because we're in a zip file. 128 write = False 129 elif e in [errno.EACCES, errno.EROFS, errno.EPERM]: 130 state.trace("read only directory: %r" % fn_pypath.dirname) 131 write = False 132 else: 133 raise 134 cache_name = fn_pypath.basename[:-3] + PYC_TAIL 135 pyc = os.path.join(cache_dir, cache_name) 136 # Notice that even if we're in a read-only directory, I'm going 137 # to check for a cached pyc. This may not be optimal... 138 co = _read_pyc(fn_pypath, pyc, state.trace) 139 if co is None: 140 state.trace("rewriting %r" % (fn,)) 141 source_stat, co = _rewrite_test(self.config, fn_pypath) 142 if co is None: 143 # Probably a SyntaxError in the test. 144 return None 145 if write: 146 _write_pyc(state, co, source_stat, pyc) 147 else: 148 state.trace("found cached rewritten pyc for %r" % (fn,)) 149 self.modules[name] = co, pyc 150 return self 151 152 def _should_rewrite(self, name, fn_pypath, state): 153 # always rewrite conftest files 154 fn = str(fn_pypath) 155 if fn_pypath.basename == "conftest.py": 156 state.trace("rewriting conftest file: %r" % (fn,)) 157 return True 158 159 if self.session is not None: 160 if self.session.isinitpath(fn): 161 state.trace("matched test file (was specified on cmdline): %r" % (fn,)) 162 return True 163 164 # modules not passed explicitly on the command line are only 165 # rewritten if they match the naming convention for test files 166 for pat in self.fnpats: 167 if fn_pypath.fnmatch(pat): 168 state.trace("matched test file %r" % (fn,)) 169 return True 170 171 for marked in self._must_rewrite: 172 if name == marked or name.startswith(marked + "."): 173 state.trace("matched marked file %r (from %r)" % (name, marked)) 174 return True 175 176 return False 177 178 def mark_rewrite(self, *names): 179 """Mark import names as needing to be rewritten. 180 181 The named module or package as well as any nested modules will 182 be rewritten on import. 183 """ 184 already_imported = ( 185 set(names).intersection(sys.modules).difference(self._rewritten_names) 186 ) 187 for name in already_imported: 188 if not AssertionRewriter.is_rewrite_disabled( 189 sys.modules[name].__doc__ or "" 190 ): 191 self._warn_already_imported(name) 192 self._must_rewrite.update(names) 193 194 def _warn_already_imported(self, name): 195 self.config.warn( 196 "P1", "Module already imported so cannot be rewritten: %s" % name 197 ) 198 199 def load_module(self, name): 200 # If there is an existing module object named 'fullname' in 201 # sys.modules, the loader must use that existing module. (Otherwise, 202 # the reload() builtin will not work correctly.) 203 if name in sys.modules: 204 return sys.modules[name] 205 206 co, pyc = self.modules.pop(name) 207 # I wish I could just call imp.load_compiled here, but __file__ has to 208 # be set properly. In Python 3.2+, this all would be handled correctly 209 # by load_compiled. 210 mod = sys.modules[name] = imp.new_module(name) 211 try: 212 mod.__file__ = co.co_filename 213 # Normally, this attribute is 3.2+. 214 mod.__cached__ = pyc 215 mod.__loader__ = self 216 py.builtin.exec_(co, mod.__dict__) 217 except: # noqa 218 if name in sys.modules: 219 del sys.modules[name] 220 raise 221 return sys.modules[name] 222 223 def is_package(self, name): 224 try: 225 fd, fn, desc = imp.find_module(name) 226 except ImportError: 227 return False 228 if fd is not None: 229 fd.close() 230 tp = desc[2] 231 return tp == imp.PKG_DIRECTORY 232 233 @classmethod 234 def _register_with_pkg_resources(cls): 235 """ 236 Ensure package resources can be loaded from this loader. May be called 237 multiple times, as the operation is idempotent. 238 """ 239 try: 240 import pkg_resources 241 242 # access an attribute in case a deferred importer is present 243 pkg_resources.__name__ 244 except ImportError: 245 return 246 247 # Since pytest tests are always located in the file system, the 248 # DefaultProvider is appropriate. 249 pkg_resources.register_loader_type(cls, pkg_resources.DefaultProvider) 250 251 def get_data(self, pathname): 252 """Optional PEP302 get_data API. 253 """ 254 with open(pathname, "rb") as f: 255 return f.read() 256 257 258def _write_pyc(state, co, source_stat, pyc): 259 # Technically, we don't have to have the same pyc format as 260 # (C)Python, since these "pycs" should never be seen by builtin 261 # import. However, there's little reason deviate, and I hope 262 # sometime to be able to use imp.load_compiled to load them. (See 263 # the comment in load_module above.) 264 try: 265 with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp: 266 fp.write(imp.get_magic()) 267 mtime = int(source_stat.mtime) 268 size = source_stat.size & 0xFFFFFFFF 269 fp.write(struct.pack("<ll", mtime, size)) 270 fp.write(marshal.dumps(co)) 271 except EnvironmentError as e: 272 state.trace("error writing pyc file at %s: errno=%s" % (pyc, e.errno)) 273 # we ignore any failure to write the cache file 274 # there are many reasons, permission-denied, __pycache__ being a 275 # file etc. 276 return False 277 return True 278 279 280RN = "\r\n".encode("utf-8") 281N = "\n".encode("utf-8") 282 283cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+") 284BOM_UTF8 = "\xef\xbb\xbf" 285 286 287def _rewrite_test(config, fn): 288 """Try to read and rewrite *fn* and return the code object.""" 289 state = config._assertstate 290 try: 291 stat = fn.stat() 292 source = fn.read("rb") 293 except EnvironmentError: 294 return None, None 295 if ASCII_IS_DEFAULT_ENCODING: 296 # ASCII is the default encoding in Python 2. Without a coding 297 # declaration, Python 2 will complain about any bytes in the file 298 # outside the ASCII range. Sadly, this behavior does not extend to 299 # compile() or ast.parse(), which prefer to interpret the bytes as 300 # latin-1. (At least they properly handle explicit coding cookies.) To 301 # preserve this error behavior, we could force ast.parse() to use ASCII 302 # as the encoding by inserting a coding cookie. Unfortunately, that 303 # messes up line numbers. Thus, we have to check ourselves if anything 304 # is outside the ASCII range in the case no encoding is explicitly 305 # declared. For more context, see issue #269. Yay for Python 3 which 306 # gets this right. 307 end1 = source.find("\n") 308 end2 = source.find("\n", end1 + 1) 309 if ( 310 not source.startswith(BOM_UTF8) 311 and cookie_re.match(source[0:end1]) is None 312 and cookie_re.match(source[end1 + 1:end2]) is None 313 ): 314 if hasattr(state, "_indecode"): 315 # encodings imported us again, so don't rewrite. 316 return None, None 317 state._indecode = True 318 try: 319 try: 320 source.decode("ascii") 321 except UnicodeDecodeError: 322 # Let it fail in real import. 323 return None, None 324 finally: 325 del state._indecode 326 try: 327 tree = ast.parse(source) 328 except SyntaxError: 329 # Let this pop up again in the real import. 330 state.trace("failed to parse: %r" % (fn,)) 331 return None, None 332 rewrite_asserts(tree, fn, config) 333 try: 334 co = compile(tree, fn.strpath, "exec", dont_inherit=True) 335 except SyntaxError: 336 # It's possible that this error is from some bug in the 337 # assertion rewriting, but I don't know of a fast way to tell. 338 state.trace("failed to compile: %r" % (fn,)) 339 return None, None 340 return stat, co 341 342 343def _read_pyc(source, pyc, trace=lambda x: None): 344 """Possibly read a pytest pyc containing rewritten code. 345 346 Return rewritten code if successful or None if not. 347 """ 348 try: 349 fp = open(pyc, "rb") 350 except IOError: 351 return None 352 with fp: 353 try: 354 mtime = int(source.mtime()) 355 size = source.size() 356 data = fp.read(12) 357 except EnvironmentError as e: 358 trace("_read_pyc(%s): EnvironmentError %s" % (source, e)) 359 return None 360 # Check for invalid or out of date pyc file. 361 if ( 362 len(data) != 12 363 or data[:4] != imp.get_magic() 364 or struct.unpack("<ll", data[4:]) != (mtime, size) 365 ): 366 trace("_read_pyc(%s): invalid or out of date pyc" % source) 367 return None 368 try: 369 co = marshal.load(fp) 370 except Exception as e: 371 trace("_read_pyc(%s): marshal.load error %s" % (source, e)) 372 return None 373 if not isinstance(co, types.CodeType): 374 trace("_read_pyc(%s): not a code object" % source) 375 return None 376 return co 377 378 379def rewrite_asserts(mod, module_path=None, config=None): 380 """Rewrite the assert statements in mod.""" 381 AssertionRewriter(module_path, config).run(mod) 382 383 384def _saferepr(obj): 385 """Get a safe repr of an object for assertion error messages. 386 387 The assertion formatting (util.format_explanation()) requires 388 newlines to be escaped since they are a special character for it. 389 Normally assertion.util.format_explanation() does this but for a 390 custom repr it is possible to contain one of the special escape 391 sequences, especially '\n{' and '\n}' are likely to be present in 392 JSON reprs. 393 394 """ 395 repr = py.io.saferepr(obj) 396 if isinstance(repr, six.text_type): 397 t = six.text_type 398 else: 399 t = six.binary_type 400 return repr.replace(t("\n"), t("\\n")) 401 402 403from _pytest.assertion.util import format_explanation as _format_explanation # noqa 404 405 406def _format_assertmsg(obj): 407 """Format the custom assertion message given. 408 409 For strings this simply replaces newlines with '\n~' so that 410 util.format_explanation() will preserve them instead of escaping 411 newlines. For other objects py.io.saferepr() is used first. 412 413 """ 414 # reprlib appears to have a bug which means that if a string 415 # contains a newline it gets escaped, however if an object has a 416 # .__repr__() which contains newlines it does not get escaped. 417 # However in either case we want to preserve the newline. 418 if isinstance(obj, six.text_type) or isinstance(obj, six.binary_type): 419 s = obj 420 is_repr = False 421 else: 422 s = py.io.saferepr(obj) 423 is_repr = True 424 if isinstance(s, six.text_type): 425 t = six.text_type 426 else: 427 t = six.binary_type 428 s = s.replace(t("\n"), t("\n~")).replace(t("%"), t("%%")) 429 if is_repr: 430 s = s.replace(t("\\n"), t("\n~")) 431 return s 432 433 434def _should_repr_global_name(obj): 435 return not hasattr(obj, "__name__") and not callable(obj) 436 437 438def _format_boolop(explanations, is_or): 439 explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")" 440 if isinstance(explanation, six.text_type): 441 t = six.text_type 442 else: 443 t = six.binary_type 444 return explanation.replace(t("%"), t("%%")) 445 446 447def _call_reprcompare(ops, results, expls, each_obj): 448 for i, res, expl in zip(range(len(ops)), results, expls): 449 try: 450 done = not res 451 except Exception: 452 done = True 453 if done: 454 break 455 if util._reprcompare is not None: 456 custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) 457 if custom is not None: 458 return custom 459 return expl 460 461 462unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} 463 464binop_map = { 465 ast.BitOr: "|", 466 ast.BitXor: "^", 467 ast.BitAnd: "&", 468 ast.LShift: "<<", 469 ast.RShift: ">>", 470 ast.Add: "+", 471 ast.Sub: "-", 472 ast.Mult: "*", 473 ast.Div: "/", 474 ast.FloorDiv: "//", 475 ast.Mod: "%%", # escaped for string formatting 476 ast.Eq: "==", 477 ast.NotEq: "!=", 478 ast.Lt: "<", 479 ast.LtE: "<=", 480 ast.Gt: ">", 481 ast.GtE: ">=", 482 ast.Pow: "**", 483 ast.Is: "is", 484 ast.IsNot: "is not", 485 ast.In: "in", 486 ast.NotIn: "not in", 487} 488# Python 3.5+ compatibility 489try: 490 binop_map[ast.MatMult] = "@" 491except AttributeError: 492 pass 493 494# Python 3.4+ compatibility 495if hasattr(ast, "NameConstant"): 496 _NameConstant = ast.NameConstant 497else: 498 499 def _NameConstant(c): 500 return ast.Name(str(c), ast.Load()) 501 502 503def set_location(node, lineno, col_offset): 504 """Set node location information recursively.""" 505 506 def _fix(node, lineno, col_offset): 507 if "lineno" in node._attributes: 508 node.lineno = lineno 509 if "col_offset" in node._attributes: 510 node.col_offset = col_offset 511 for child in ast.iter_child_nodes(node): 512 _fix(child, lineno, col_offset) 513 514 _fix(node, lineno, col_offset) 515 return node 516 517 518class AssertionRewriter(ast.NodeVisitor): 519 """Assertion rewriting implementation. 520 521 The main entrypoint is to call .run() with an ast.Module instance, 522 this will then find all the assert statements and rewrite them to 523 provide intermediate values and a detailed assertion error. See 524 http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html 525 for an overview of how this works. 526 527 The entry point here is .run() which will iterate over all the 528 statements in an ast.Module and for each ast.Assert statement it 529 finds call .visit() with it. Then .visit_Assert() takes over and 530 is responsible for creating new ast statements to replace the 531 original assert statement: it rewrites the test of an assertion 532 to provide intermediate values and replace it with an if statement 533 which raises an assertion error with a detailed explanation in 534 case the expression is false. 535 536 For this .visit_Assert() uses the visitor pattern to visit all the 537 AST nodes of the ast.Assert.test field, each visit call returning 538 an AST node and the corresponding explanation string. During this 539 state is kept in several instance attributes: 540 541 :statements: All the AST statements which will replace the assert 542 statement. 543 544 :variables: This is populated by .variable() with each variable 545 used by the statements so that they can all be set to None at 546 the end of the statements. 547 548 :variable_counter: Counter to create new unique variables needed 549 by statements. Variables are created using .variable() and 550 have the form of "@py_assert0". 551 552 :on_failure: The AST statements which will be executed if the 553 assertion test fails. This is the code which will construct 554 the failure message and raises the AssertionError. 555 556 :explanation_specifiers: A dict filled by .explanation_param() 557 with %-formatting placeholders and their corresponding 558 expressions to use in the building of an assertion message. 559 This is used by .pop_format_context() to build a message. 560 561 :stack: A stack of the explanation_specifiers dicts maintained by 562 .push_format_context() and .pop_format_context() which allows 563 to build another %-formatted string while already building one. 564 565 This state is reset on every new assert statement visited and used 566 by the other visitors. 567 568 """ 569 570 def __init__(self, module_path, config): 571 super(AssertionRewriter, self).__init__() 572 self.module_path = module_path 573 self.config = config 574 575 def run(self, mod): 576 """Find all assert statements in *mod* and rewrite them.""" 577 if not mod.body: 578 # Nothing to do. 579 return 580 # Insert some special imports at the top of the module but after any 581 # docstrings and __future__ imports. 582 aliases = [ 583 ast.alias(py.builtin.builtins.__name__, "@py_builtins"), 584 ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), 585 ] 586 doc = getattr(mod, "docstring", None) 587 expect_docstring = doc is None 588 if doc is not None and self.is_rewrite_disabled(doc): 589 return 590 pos = 0 591 lineno = 1 592 for item in mod.body: 593 if ( 594 expect_docstring 595 and isinstance(item, ast.Expr) 596 and isinstance(item.value, ast.Str) 597 ): 598 doc = item.value.s 599 if self.is_rewrite_disabled(doc): 600 return 601 expect_docstring = False 602 elif ( 603 not isinstance(item, ast.ImportFrom) 604 or item.level > 0 605 or item.module != "__future__" 606 ): 607 lineno = item.lineno 608 break 609 pos += 1 610 else: 611 lineno = item.lineno 612 imports = [ 613 ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases 614 ] 615 mod.body[pos:pos] = imports 616 # Collect asserts. 617 nodes = [mod] 618 while nodes: 619 node = nodes.pop() 620 for name, field in ast.iter_fields(node): 621 if isinstance(field, list): 622 new = [] 623 for i, child in enumerate(field): 624 if isinstance(child, ast.Assert): 625 # Transform assert. 626 new.extend(self.visit(child)) 627 else: 628 new.append(child) 629 if isinstance(child, ast.AST): 630 nodes.append(child) 631 setattr(node, name, new) 632 elif ( 633 isinstance(field, ast.AST) 634 and 635 # Don't recurse into expressions as they can't contain 636 # asserts. 637 not isinstance(field, ast.expr) 638 ): 639 nodes.append(field) 640 641 @staticmethod 642 def is_rewrite_disabled(docstring): 643 return "PYTEST_DONT_REWRITE" in docstring 644 645 def variable(self): 646 """Get a new variable.""" 647 # Use a character invalid in python identifiers to avoid clashing. 648 name = "@py_assert" + str(next(self.variable_counter)) 649 self.variables.append(name) 650 return name 651 652 def assign(self, expr): 653 """Give *expr* a name.""" 654 name = self.variable() 655 self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) 656 return ast.Name(name, ast.Load()) 657 658 def display(self, expr): 659 """Call py.io.saferepr on the expression.""" 660 return self.helper("saferepr", expr) 661 662 def helper(self, name, *args): 663 """Call a helper in this module.""" 664 py_name = ast.Name("@pytest_ar", ast.Load()) 665 attr = ast.Attribute(py_name, "_" + name, ast.Load()) 666 return ast_Call(attr, list(args), []) 667 668 def builtin(self, name): 669 """Return the builtin called *name*.""" 670 builtin_name = ast.Name("@py_builtins", ast.Load()) 671 return ast.Attribute(builtin_name, name, ast.Load()) 672 673 def explanation_param(self, expr): 674 """Return a new named %-formatting placeholder for expr. 675 676 This creates a %-formatting placeholder for expr in the 677 current formatting context, e.g. ``%(py0)s``. The placeholder 678 and expr are placed in the current format context so that it 679 can be used on the next call to .pop_format_context(). 680 681 """ 682 specifier = "py" + str(next(self.variable_counter)) 683 self.explanation_specifiers[specifier] = expr 684 return "%(" + specifier + ")s" 685 686 def push_format_context(self): 687 """Create a new formatting context. 688 689 The format context is used for when an explanation wants to 690 have a variable value formatted in the assertion message. In 691 this case the value required can be added using 692 .explanation_param(). Finally .pop_format_context() is used 693 to format a string of %-formatted values as added by 694 .explanation_param(). 695 696 """ 697 self.explanation_specifiers = {} 698 self.stack.append(self.explanation_specifiers) 699 700 def pop_format_context(self, expl_expr): 701 """Format the %-formatted string with current format context. 702 703 The expl_expr should be an ast.Str instance constructed from 704 the %-placeholders created by .explanation_param(). This will 705 add the required code to format said string to .on_failure and 706 return the ast.Name instance of the formatted string. 707 708 """ 709 current = self.stack.pop() 710 if self.stack: 711 self.explanation_specifiers = self.stack[-1] 712 keys = [ast.Str(key) for key in current.keys()] 713 format_dict = ast.Dict(keys, list(current.values())) 714 form = ast.BinOp(expl_expr, ast.Mod(), format_dict) 715 name = "@py_format" + str(next(self.variable_counter)) 716 self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form)) 717 return ast.Name(name, ast.Load()) 718 719 def generic_visit(self, node): 720 """Handle expressions we don't have custom code for.""" 721 assert isinstance(node, ast.expr) 722 res = self.assign(node) 723 return res, self.explanation_param(self.display(res)) 724 725 def visit_Assert(self, assert_): 726 """Return the AST statements to replace the ast.Assert instance. 727 728 This rewrites the test of an assertion to provide 729 intermediate values and replace it with an if statement which 730 raises an assertion error with a detailed explanation in case 731 the expression is false. 732 733 """ 734 if isinstance(assert_.test, ast.Tuple) and self.config is not None: 735 fslocation = (self.module_path, assert_.lineno) 736 self.config.warn( 737 "R1", 738 "assertion is always true, perhaps " "remove parentheses?", 739 fslocation=fslocation, 740 ) 741 self.statements = [] 742 self.variables = [] 743 self.variable_counter = itertools.count() 744 self.stack = [] 745 self.on_failure = [] 746 self.push_format_context() 747 # Rewrite assert into a bunch of statements. 748 top_condition, explanation = self.visit(assert_.test) 749 # Create failure message. 750 body = self.on_failure 751 negation = ast.UnaryOp(ast.Not(), top_condition) 752 self.statements.append(ast.If(negation, body, [])) 753 if assert_.msg: 754 assertmsg = self.helper("format_assertmsg", assert_.msg) 755 explanation = "\n>assert " + explanation 756 else: 757 assertmsg = ast.Str("") 758 explanation = "assert " + explanation 759 template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) 760 msg = self.pop_format_context(template) 761 fmt = self.helper("format_explanation", msg) 762 err_name = ast.Name("AssertionError", ast.Load()) 763 exc = ast_Call(err_name, [fmt], []) 764 if sys.version_info[0] >= 3: 765 raise_ = ast.Raise(exc, None) 766 else: 767 raise_ = ast.Raise(exc, None, None) 768 body.append(raise_) 769 # Clear temporary variables by setting them to None. 770 if self.variables: 771 variables = [ast.Name(name, ast.Store()) for name in self.variables] 772 clear = ast.Assign(variables, _NameConstant(None)) 773 self.statements.append(clear) 774 # Fix line numbers. 775 for stmt in self.statements: 776 set_location(stmt, assert_.lineno, assert_.col_offset) 777 return self.statements 778 779 def visit_Name(self, name): 780 # Display the repr of the name if it's a local variable or 781 # _should_repr_global_name() thinks it's acceptable. 782 locs = ast_Call(self.builtin("locals"), [], []) 783 inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs]) 784 dorepr = self.helper("should_repr_global_name", name) 785 test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) 786 expr = ast.IfExp(test, self.display(name), ast.Str(name.id)) 787 return name, self.explanation_param(expr) 788 789 def visit_BoolOp(self, boolop): 790 res_var = self.variable() 791 expl_list = self.assign(ast.List([], ast.Load())) 792 app = ast.Attribute(expl_list, "append", ast.Load()) 793 is_or = int(isinstance(boolop.op, ast.Or)) 794 body = save = self.statements 795 fail_save = self.on_failure 796 levels = len(boolop.values) - 1 797 self.push_format_context() 798 # Process each operand, short-circuting if needed. 799 for i, v in enumerate(boolop.values): 800 if i: 801 fail_inner = [] 802 # cond is set in a prior loop iteration below 803 self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa 804 self.on_failure = fail_inner 805 self.push_format_context() 806 res, expl = self.visit(v) 807 body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) 808 expl_format = self.pop_format_context(ast.Str(expl)) 809 call = ast_Call(app, [expl_format], []) 810 self.on_failure.append(ast.Expr(call)) 811 if i < levels: 812 cond = res 813 if is_or: 814 cond = ast.UnaryOp(ast.Not(), cond) 815 inner = [] 816 self.statements.append(ast.If(cond, inner, [])) 817 self.statements = body = inner 818 self.statements = save 819 self.on_failure = fail_save 820 expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or)) 821 expl = self.pop_format_context(expl_template) 822 return ast.Name(res_var, ast.Load()), self.explanation_param(expl) 823 824 def visit_UnaryOp(self, unary): 825 pattern = unary_map[unary.op.__class__] 826 operand_res, operand_expl = self.visit(unary.operand) 827 res = self.assign(ast.UnaryOp(unary.op, operand_res)) 828 return res, pattern % (operand_expl,) 829 830 def visit_BinOp(self, binop): 831 symbol = binop_map[binop.op.__class__] 832 left_expr, left_expl = self.visit(binop.left) 833 right_expr, right_expl = self.visit(binop.right) 834 explanation = "(%s %s %s)" % (left_expl, symbol, right_expl) 835 res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) 836 return res, explanation 837 838 def visit_Call_35(self, call): 839 """ 840 visit `ast.Call` nodes on Python3.5 and after 841 """ 842 new_func, func_expl = self.visit(call.func) 843 arg_expls = [] 844 new_args = [] 845 new_kwargs = [] 846 for arg in call.args: 847 res, expl = self.visit(arg) 848 arg_expls.append(expl) 849 new_args.append(res) 850 for keyword in call.keywords: 851 res, expl = self.visit(keyword.value) 852 new_kwargs.append(ast.keyword(keyword.arg, res)) 853 if keyword.arg: 854 arg_expls.append(keyword.arg + "=" + expl) 855 else: # **args have `arg` keywords with an .arg of None 856 arg_expls.append("**" + expl) 857 858 expl = "%s(%s)" % (func_expl, ", ".join(arg_expls)) 859 new_call = ast.Call(new_func, new_args, new_kwargs) 860 res = self.assign(new_call) 861 res_expl = self.explanation_param(self.display(res)) 862 outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl) 863 return res, outer_expl 864 865 def visit_Starred(self, starred): 866 # From Python 3.5, a Starred node can appear in a function call 867 res, expl = self.visit(starred.value) 868 return starred, "*" + expl 869 870 def visit_Call_legacy(self, call): 871 """ 872 visit `ast.Call nodes on 3.4 and below` 873 """ 874 new_func, func_expl = self.visit(call.func) 875 arg_expls = [] 876 new_args = [] 877 new_kwargs = [] 878 new_star = new_kwarg = None 879 for arg in call.args: 880 res, expl = self.visit(arg) 881 new_args.append(res) 882 arg_expls.append(expl) 883 for keyword in call.keywords: 884 res, expl = self.visit(keyword.value) 885 new_kwargs.append(ast.keyword(keyword.arg, res)) 886 arg_expls.append(keyword.arg + "=" + expl) 887 if call.starargs: 888 new_star, expl = self.visit(call.starargs) 889 arg_expls.append("*" + expl) 890 if call.kwargs: 891 new_kwarg, expl = self.visit(call.kwargs) 892 arg_expls.append("**" + expl) 893 expl = "%s(%s)" % (func_expl, ", ".join(arg_expls)) 894 new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg) 895 res = self.assign(new_call) 896 res_expl = self.explanation_param(self.display(res)) 897 outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl) 898 return res, outer_expl 899 900 # ast.Call signature changed on 3.5, 901 # conditionally change which methods is named 902 # visit_Call depending on Python version 903 if sys.version_info >= (3, 5): 904 visit_Call = visit_Call_35 905 else: 906 visit_Call = visit_Call_legacy 907 908 def visit_Attribute(self, attr): 909 if not isinstance(attr.ctx, ast.Load): 910 return self.generic_visit(attr) 911 value, value_expl = self.visit(attr.value) 912 res = self.assign(ast.Attribute(value, attr.attr, ast.Load())) 913 res_expl = self.explanation_param(self.display(res)) 914 pat = "%s\n{%s = %s.%s\n}" 915 expl = pat % (res_expl, res_expl, value_expl, attr.attr) 916 return res, expl 917 918 def visit_Compare(self, comp): 919 self.push_format_context() 920 left_res, left_expl = self.visit(comp.left) 921 if isinstance(comp.left, (ast.Compare, ast.BoolOp)): 922 left_expl = "({})".format(left_expl) 923 res_variables = [self.variable() for i in range(len(comp.ops))] 924 load_names = [ast.Name(v, ast.Load()) for v in res_variables] 925 store_names = [ast.Name(v, ast.Store()) for v in res_variables] 926 it = zip(range(len(comp.ops)), comp.ops, comp.comparators) 927 expls = [] 928 syms = [] 929 results = [left_res] 930 for i, op, next_operand in it: 931 next_res, next_expl = self.visit(next_operand) 932 if isinstance(next_operand, (ast.Compare, ast.BoolOp)): 933 next_expl = "({})".format(next_expl) 934 results.append(next_res) 935 sym = binop_map[op.__class__] 936 syms.append(ast.Str(sym)) 937 expl = "%s %s %s" % (left_expl, sym, next_expl) 938 expls.append(ast.Str(expl)) 939 res_expr = ast.Compare(left_res, [op], [next_res]) 940 self.statements.append(ast.Assign([store_names[i]], res_expr)) 941 left_res, left_expl = next_res, next_expl 942 # Use pytest.assertion.util._reprcompare if that's available. 943 expl_call = self.helper( 944 "call_reprcompare", 945 ast.Tuple(syms, ast.Load()), 946 ast.Tuple(load_names, ast.Load()), 947 ast.Tuple(expls, ast.Load()), 948 ast.Tuple(results, ast.Load()), 949 ) 950 if len(comp.ops) > 1: 951 res = ast.BoolOp(ast.And(), load_names) 952 else: 953 res = load_names[0] 954 return res, self.explanation_param(self.pop_format_context(expl_call)) 955