1"""Rewrite assertion AST to produce nice error messages"""
2from __future__ import absolute_import, division, print_function
3import ast
4import errno
5import itertools
6import imp
7import marshal
8import os
9import re
10import six
11import struct
12import sys
13import types
14
15import atomicwrites
16import py
17
18from _pytest.assertion import util
19
20
# pytest caches rewritten pycs in __pycache__, using a "-PYTEST" tag so the
# files are distinct from the interpreter's own cached bytecode.
if hasattr(imp, "get_tag"):
    PYTEST_TAG = imp.get_tag() + "-PYTEST"
else:
    # imp.get_tag() is unavailable: build an equivalent
    # "<implementation>-<major><minor>" tag by hand.
    if hasattr(sys, "pypy_version_info"):
        impl = "pypy"
    elif sys.platform == "java":
        impl = "jython"
    else:
        impl = "cpython"
    ver = sys.version_info
    PYTEST_TAG = "%s-%s%s-PYTEST" % (impl, ver[0], ver[1])
    del ver, impl

# ".pyc" normally, ".pyo" when running with optimizations (__debug__ False).
PYC_EXT = ".py" + ("c" if __debug__ else "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

# On Python 2, source files without a coding declaration are treated as
# ASCII by the real import (see _rewrite_test).
ASCII_IS_DEFAULT_ENCODING = sys.version_info[0] < 3
39
if sys.version_info < (3, 5):
    # Older ast.Call nodes take two extra trailing fields (starargs and
    # kwargs); fill them with None so callers can always use the
    # three-argument (func, args, keywords) form.
    def ast_Call(a, b, c):
        return ast.Call(a, b, c, None, None)

else:
    ast_Call = ast.Call
46
47
class AssertionRewritingHook(object):
    """PEP302 Import hook which rewrites asserts.

    find_module() decides whether a module should be rewritten (conftest
    files, paths given on the command line, files matching the
    ``python_files`` patterns, and names registered via mark_rewrite()) and
    prepares its rewritten code object, reusing a cached pyc from
    ``__pycache__`` when one is valid. load_module() then executes that code
    in a fresh module object.
    """

    def __init__(self, config):
        self.config = config
        # Glob patterns from the "python_files" ini option.
        self.fnpats = config.getini("python_files")
        self.session = None
        # Module name -> (code object, pyc path); filled by find_module()
        # and consumed by load_module().
        self.modules = {}
        # Names this hook has already decided to rewrite.
        self._rewritten_names = set()
        self._register_with_pkg_resources()
        # Names (and their submodules) marked through mark_rewrite().
        self._must_rewrite = set()

    def set_session(self, session):
        """Set the session whose command-line paths are always rewritten."""
        self.session = session

    def find_module(self, name, path=None):
        """PEP302 finder: return self if *name* will be rewritten, else None.

        On success the rewritten code object and its pyc path are stored in
        self.modules for load_module(). None is returned (leaving the module
        to the normal import machinery) when the module cannot be located,
        is not Python source, should not be rewritten, or could not be
        read, parsed or compiled by _rewrite_test().
        """
        state = self.config._assertstate
        state.trace("find_module called for: %s" % name)
        lastname = name.rsplit(".", 1)[-1]
        pth = None
        if path is not None:
            # Starting with Python 3.3, path is a _NamespacePath(), which
            # causes problems if not converted to list.
            path = list(path)
            if len(path) == 1:
                pth = path[0]
        if pth is None:
            try:
                fd, fn, desc = imp.find_module(lastname, path)
            except ImportError:
                return None
            if fd is not None:
                fd.close()
            tp = desc[2]
            if tp == imp.PY_COMPILED:
                if hasattr(imp, "source_from_cache"):
                    try:
                        fn = imp.source_from_cache(fn)
                    except ValueError:
                        # Python 3 doesn't like orphaned but still-importable
                        # .pyc files.
                        fn = fn[:-1]
                else:
                    fn = fn[:-1]
            elif tp != imp.PY_SOURCE:
                # Don't know what this is.
                return None
        else:
            fn = os.path.join(pth, lastname + ".py")

        fn_pypath = py.path.local(fn)
        if not self._should_rewrite(name, fn_pypath, state):
            return None

        self._rewritten_names.add(name)

        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent pytest processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        write = not sys.dont_write_bytecode
        cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
        if write:
            try:
                os.mkdir(cache_dir)
            except OSError as exc:
                err = exc.errno
                if err == errno.EEXIST:
                    # Either the __pycache__ directory already exists (the
                    # common case) or it's blocked by a non-dir node. In the
                    # latter case, we'll ignore it in _write_pyc.
                    pass
                elif err in (errno.ENOENT, errno.ENOTDIR):
                    # One of the path components was not a directory, likely
                    # because we're in a zip file.
                    write = False
                elif err in (errno.EACCES, errno.EROFS, errno.EPERM):
                    state.trace("read only directory: %r" % fn_pypath.dirname)
                    write = False
                else:
                    raise
        cache_name = fn_pypath.basename[:-3] + PYC_TAIL
        pyc = os.path.join(cache_dir, cache_name)
        # Notice that even if we're in a read-only directory, I'm going
        # to check for a cached pyc. This may not be optimal...
        co = _read_pyc(fn_pypath, pyc, state.trace)
        if co is None:
            state.trace("rewriting %r" % (fn,))
            source_stat, co = _rewrite_test(self.config, fn_pypath)
            if co is None:
                # Probably a SyntaxError in the test.
                return None
            if write:
                _write_pyc(state, co, source_stat, pyc)
        else:
            state.trace("found cached rewritten pyc for %r" % (fn,))
        self.modules[name] = co, pyc
        return self

    def _should_rewrite(self, name, fn_pypath, state):
        """Return True if module *name* located at *fn_pypath* is rewritten.

        Checked in order: conftest files, paths given on the command line,
        files matching the ``python_files`` patterns, and names marked via
        mark_rewrite() (including their submodules).
        """
        # always rewrite conftest files
        fn = str(fn_pypath)
        if fn_pypath.basename == "conftest.py":
            state.trace("rewriting conftest file: %r" % (fn,))
            return True

        if self.session is not None:
            if self.session.isinitpath(fn):
                state.trace("matched test file (was specified on cmdline): %r" % (fn,))
                return True

        # modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files
        for pat in self.fnpats:
            if fn_pypath.fnmatch(pat):
                state.trace("matched test file %r" % (fn,))
                return True

        for marked in self._must_rewrite:
            if name == marked or name.startswith(marked + "."):
                state.trace("matched marked file %r (from %r)" % (name, marked))
                return True

        return False

    def mark_rewrite(self, *names):
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import. A warning is emitted for names that were
        already imported without being rewritten by this hook, unless their
        docstring disables rewriting.
        """
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in already_imported:
            if not AssertionRewriter.is_rewrite_disabled(
                sys.modules[name].__doc__ or ""
            ):
                self._warn_already_imported(name)
        self._must_rewrite.update(names)

    def _warn_already_imported(self, name):
        """Warn that module *name* was imported before it could be rewritten."""
        self.config.warn(
            "P1", "Module already imported so cannot be rewritten: %s" % name
        )

    def load_module(self, name):
        """PEP302 loader: execute the rewritten code prepared by find_module()."""
        # If there is an existing module object named 'fullname' in
        # sys.modules, the loader must use that existing module. (Otherwise,
        # the reload() builtin will not work correctly.)
        if name in sys.modules:
            return sys.modules[name]

        co, pyc = self.modules.pop(name)
        # I wish I could just call imp.load_compiled here, but __file__ has to
        # be set properly. In Python 3.2+, this all would be handled correctly
        # by load_compiled.
        mod = sys.modules[name] = imp.new_module(name)
        try:
            mod.__file__ = co.co_filename
            # Normally, this attribute is 3.2+.
            mod.__cached__ = pyc
            mod.__loader__ = self
            py.builtin.exec_(co, mod.__dict__)
        except:  # noqa
            # Deliberately bare: whatever aborted execution (including
            # KeyboardInterrupt), don't leave a half-initialized module
            # in sys.modules. The exception is always re-raised.
            if name in sys.modules:
                del sys.modules[name]
            raise
        return sys.modules[name]

    def is_package(self, name):
        """PEP302 is_package: True if imp.find_module(name) finds a package dir."""
        try:
            fd, fn, desc = imp.find_module(name)
        except ImportError:
            return False
        if fd is not None:
            fd.close()
        tp = desc[2]
        return tp == imp.PKG_DIRECTORY

    @classmethod
    def _register_with_pkg_resources(cls):
        """
        Ensure package resources can be loaded from this loader. May be called
        multiple times, as the operation is idempotent.
        """
        try:
            import pkg_resources

            # access an attribute in case a deferred importer is present
            pkg_resources.__name__
        except ImportError:
            return

        # Since pytest tests are always located in the file system, the
        #  DefaultProvider is appropriate.
        pkg_resources.register_loader_type(cls, pkg_resources.DefaultProvider)

    def get_data(self, pathname):
        """Optional PEP302 get_data API: return the raw bytes of *pathname*.
        """
        with open(pathname, "rb") as f:
            return f.read()
256
257
def _write_pyc(state, co, source_stat, pyc):
    """Atomically write the rewritten code object *co* to the cache file *pyc*.

    Layout: import magic, then the source mtime and size as two 32-bit
    little-endian unsigned integers, then the marshalled code object.
    Returns True on success. Returns False if the file could not be written;
    callers may ignore a failure because the cache is only an optimization.
    """
    # Technically, we don't have to have the same pyc format as
    # (C)Python, since these "pycs" should never be seen by builtin
    # import. However, there's little reason to deviate, and I hope
    # sometime to be able to use imp.load_compiled to load them. (See
    # the comment in load_module above.)
    try:
        with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp:
            fp.write(imp.get_magic())
            # The header only has room for 32-bit values. Truncate both
            # fields and pack them unsigned: the signed "<ll" format raised
            # struct.error for mtimes past 2038 (or sizes >= 2 GiB), and
            # struct.error is not caught below, so the import itself crashed.
            # For in-range values the bytes written are unchanged.
            mtime = int(source_stat.mtime) & 0xFFFFFFFF
            size = source_stat.size & 0xFFFFFFFF
            fp.write(struct.pack("<LL", mtime, size))
            fp.write(marshal.dumps(co))
    except EnvironmentError as e:
        state.trace("error writing pyc file at %s: errno=%s" % (pyc, e.errno))
        # we ignore any failure to write the cache file
        # there are many reasons, permission-denied, __pycache__ being a
        # file etc.
        return False
    return True
278
279
# Line-ending byte strings.
RN = b"\r\n"
N = b"\n"

# PEP 263 coding declaration ("coding cookie") at the start of a line.
cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+")
# UTF-8 byte-order mark; only meaningful on Python 2, where str is bytes.
BOM_UTF8 = "\xef\xbb\xbf"
285
286
def _rewrite_test(config, fn):
    """Try to read and rewrite *fn* and return the code object.

    Returns ``(stat, code)`` on success, where *stat* is ``fn.stat()`` taken
    before the file is read. Returns ``(None, None)`` when the file cannot be
    read, when the Python 2 ASCII check below rejects it (or re-enters this
    function during decoding), or when the source fails to parse or compile.
    In those cases the module is left for the normal import to handle.
    """
    state = config._assertstate
    try:
        stat = fn.stat()
        source = fn.read("rb")
    except EnvironmentError:
        return None, None
    if ASCII_IS_DEFAULT_ENCODING:
        # ASCII is the default encoding in Python 2. Without a coding
        # declaration, Python 2 will complain about any bytes in the file
        # outside the ASCII range. Sadly, this behavior does not extend to
        # compile() or ast.parse(), which prefer to interpret the bytes as
        # latin-1. (At least they properly handle explicit coding cookies.) To
        # preserve this error behavior, we could force ast.parse() to use ASCII
        # as the encoding by inserting a coding cookie. Unfortunately, that
        # messes up line numbers. Thus, we have to check ourselves if anything
        # is outside the ASCII range in the case no encoding is explicitly
        # declared. For more context, see issue #269. Yay for Python 3 which
        # gets this right.
        end1 = source.find("\n")
        if end1 == -1:
            # No newline at all: the whole file is the first line. Leaving
            # -1 here would make the slice below drop its last character.
            end1 = len(source)
        end2 = source.find("\n", end1 + 1)
        if end2 == -1:
            # No second newline: the second line runs to the end of the file.
            end2 = len(source)
        if (
            not source.startswith(BOM_UTF8)
            and cookie_re.match(source[0:end1]) is None
            and cookie_re.match(source[end1 + 1:end2]) is None
        ):
            if hasattr(state, "_indecode"):
                # encodings imported us again, so don't rewrite.
                return None, None
            state._indecode = True
            try:
                try:
                    source.decode("ascii")
                except UnicodeDecodeError:
                    # Let it fail in real import.
                    return None, None
            finally:
                del state._indecode
    try:
        tree = ast.parse(source)
    except SyntaxError:
        # Let this pop up again in the real import.
        state.trace("failed to parse: %r" % (fn,))
        return None, None
    rewrite_asserts(tree, fn, config)
    try:
        co = compile(tree, fn.strpath, "exec", dont_inherit=True)
    except SyntaxError:
        # It's possible that this error is from some bug in the
        # assertion rewriting, but I don't know of a fast way to tell.
        state.trace("failed to compile: %r" % (fn,))
        return None, None
    return stat, co
341
342
343def _read_pyc(source, pyc, trace=lambda x: None):
344    """Possibly read a pytest pyc containing rewritten code.
345
346    Return rewritten code if successful or None if not.
347    """
348    try:
349        fp = open(pyc, "rb")
350    except IOError:
351        return None
352    with fp:
353        try:
354            mtime = int(source.mtime())
355            size = source.size()
356            data = fp.read(12)
357        except EnvironmentError as e:
358            trace("_read_pyc(%s): EnvironmentError %s" % (source, e))
359            return None
360        # Check for invalid or out of date pyc file.
361        if (
362            len(data) != 12
363            or data[:4] != imp.get_magic()
364            or struct.unpack("<ll", data[4:]) != (mtime, size)
365        ):
366            trace("_read_pyc(%s): invalid or out of date pyc" % source)
367            return None
368        try:
369            co = marshal.load(fp)
370        except Exception as e:
371            trace("_read_pyc(%s): marshal.load error %s" % (source, e))
372            return None
373        if not isinstance(co, types.CodeType):
374            trace("_read_pyc(%s): not a code object" % source)
375            return None
376        return co
377
378
def rewrite_asserts(mod, module_path=None, config=None):
    """Rewrite, in place, the assert statements of the AST module *mod*.

    *module_path* and *config* are handed to AssertionRewriter unchanged;
    both may be None.
    """
    rewriter = AssertionRewriter(module_path, config)
    rewriter.run(mod)
382
383
def _saferepr(obj):
    """Return py.io.saferepr(obj) with every newline escaped as ``\\n``.

    util.format_explanation() treats newlines as special markers, so a
    custom repr containing raw newlines (for example JSON with '\\n{' or
    '\\n}') must have them escaped before it is embedded in an assertion
    message.
    """
    text = py.io.saferepr(obj)
    # Build the search/replacement strings with the same string type as the
    # repr itself (text vs. bytes).
    kind = six.text_type if isinstance(text, six.text_type) else six.binary_type
    return text.replace(kind("\n"), kind("\\n"))
401
402
403from _pytest.assertion.util import format_explanation as _format_explanation  # noqa
404
405
def _format_assertmsg(obj):
    """Format the custom assertion message given.

    For strings, each newline is followed by '~' so that
    util.format_explanation() preserves it instead of escaping it. Other
    objects are passed through py.io.saferepr() first. In both cases '%' is
    doubled, because the result becomes part of a %-format template in the
    rewritten assert.
    """
    # reprlib appears to have a bug which means that if a string
    # contains a newline it gets escaped, however if an object has a
    # .__repr__() which contains newlines it does not get escaped.
    # However in either case we want to preserve the newline.
    if isinstance(obj, six.string_types):
        # str/unicode on Python 2, str on Python 3. Python 3 bytes are
        # deliberately routed through saferepr below. Treating them as a
        # string selected t = bytes, and bytes("\n") raises TypeError
        # because it needs an encoding; a bytes message could not be joined
        # with the str explanation anyway.
        s = obj
        is_repr = False
    else:
        s = py.io.saferepr(obj)
        is_repr = True
    if isinstance(s, six.text_type):
        t = six.text_type
    else:
        t = six.binary_type
    s = s.replace(t("\n"), t("\n~")).replace(t("%"), t("%%"))
    if is_repr:
        s = s.replace(t("\\n"), t("\n~"))
    return s
432
433
434def _should_repr_global_name(obj):
435    return not hasattr(obj, "__name__") and not callable(obj)
436
437
438def _format_boolop(explanations, is_or):
439    explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
440    if isinstance(explanation, six.text_type):
441        t = six.text_type
442    else:
443        t = six.binary_type
444    return explanation.replace(t("%"), t("%%"))
445
446
def _call_reprcompare(ops, results, expls, each_obj):
    """Return the explanation for the first failing comparison in a chain.

    Scans *results* for the first falsy entry, or one whose truth value
    raises. If a custom ``util._reprcompare`` hook is installed and returns
    something for that comparison's operator and operands, that is used;
    otherwise the default explanation string is returned.
    """
    for index, (_, outcome, explanation) in enumerate(zip(ops, results, expls)):
        try:
            failed = not outcome
        except Exception:
            # A result whose truthiness cannot be evaluated counts as failing.
            failed = True
        if failed:
            break
    hook = util._reprcompare
    if hook is not None:
        custom = hook(ops[index], each_obj[index], each_obj[index + 1])
        if custom is not None:
            return custom
    return explanation
460
461
462unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
463
464binop_map = {
465    ast.BitOr: "|",
466    ast.BitXor: "^",
467    ast.BitAnd: "&",
468    ast.LShift: "<<",
469    ast.RShift: ">>",
470    ast.Add: "+",
471    ast.Sub: "-",
472    ast.Mult: "*",
473    ast.Div: "/",
474    ast.FloorDiv: "//",
475    ast.Mod: "%%",  # escaped for string formatting
476    ast.Eq: "==",
477    ast.NotEq: "!=",
478    ast.Lt: "<",
479    ast.LtE: "<=",
480    ast.Gt: ">",
481    ast.GtE: ">=",
482    ast.Pow: "**",
483    ast.Is: "is",
484    ast.IsNot: "is not",
485    ast.In: "in",
486    ast.NotIn: "not in",
487}
488# Python 3.5+ compatibility
489try:
490    binop_map[ast.MatMult] = "@"
491except AttributeError:
492    pass
493
494# Python 3.4+ compatibility
495if hasattr(ast, "NameConstant"):
496    _NameConstant = ast.NameConstant
497else:
498
499    def _NameConstant(c):
500        return ast.Name(str(c), ast.Load())
501
502
def set_location(node, lineno, col_offset):
    """Stamp *lineno* and *col_offset* onto *node* and all of its descendants.

    Only nodes whose ``_attributes`` declare the field are touched. Returns
    *node* so the call can be used inline.
    """
    for current in ast.walk(node):
        declared = current._attributes
        if "lineno" in declared:
            current.lineno = lineno
        if "col_offset" in declared:
            current.col_offset = col_offset
    return node
516
517
518class AssertionRewriter(ast.NodeVisitor):
519    """Assertion rewriting implementation.
520
521    The main entrypoint is to call .run() with an ast.Module instance,
522    this will then find all the assert statements and rewrite them to
523    provide intermediate values and a detailed assertion error.  See
524    http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
525    for an overview of how this works.
526
    The entry point here is .run(), which walks all the statements in
    an ast.Module (including nested bodies) and calls .visit() on each
    ast.Assert statement it finds.  Then .visit_Assert() takes over and
530    is responsible for creating new ast statements to replace the
531    original assert statement: it rewrites the test of an assertion
532    to provide intermediate values and replace it with an if statement
533    which raises an assertion error with a detailed explanation in
534    case the expression is false.
535
536    For this .visit_Assert() uses the visitor pattern to visit all the
537    AST nodes of the ast.Assert.test field, each visit call returning
538    an AST node and the corresponding explanation string.  During this
539    state is kept in several instance attributes:
540
541    :statements: All the AST statements which will replace the assert
542       statement.
543
544    :variables: This is populated by .variable() with each variable
545       used by the statements so that they can all be set to None at
546       the end of the statements.
547
548    :variable_counter: Counter to create new unique variables needed
549       by statements.  Variables are created using .variable() and
550       have the form of "@py_assert0".
551
    :on_failure: The AST statements which will be executed if the
       assertion test fails.  This is the code which will construct
       the failure message and raise the AssertionError.
555
556    :explanation_specifiers: A dict filled by .explanation_param()
557       with %-formatting placeholders and their corresponding
558       expressions to use in the building of an assertion message.
559       This is used by .pop_format_context() to build a message.
560
    :stack: A stack of the explanation_specifiers dicts maintained by
       .push_format_context() and .pop_format_context(), which allows
       building another %-formatted string while one is already being built.
564
565    This state is reset on every new assert statement visited and used
566    by the other visitors.
567
568    """
569
570    def __init__(self, module_path, config):
571        super(AssertionRewriter, self).__init__()
572        self.module_path = module_path
573        self.config = config
574
575    def run(self, mod):
576        """Find all assert statements in *mod* and rewrite them."""
577        if not mod.body:
578            # Nothing to do.
579            return
580        # Insert some special imports at the top of the module but after any
581        # docstrings and __future__ imports.
582        aliases = [
583            ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
584            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
585        ]
586        doc = getattr(mod, "docstring", None)
587        expect_docstring = doc is None
588        if doc is not None and self.is_rewrite_disabled(doc):
589            return
590        pos = 0
591        lineno = 1
592        for item in mod.body:
593            if (
594                expect_docstring
595                and isinstance(item, ast.Expr)
596                and isinstance(item.value, ast.Str)
597            ):
598                doc = item.value.s
599                if self.is_rewrite_disabled(doc):
600                    return
601                expect_docstring = False
602            elif (
603                not isinstance(item, ast.ImportFrom)
604                or item.level > 0
605                or item.module != "__future__"
606            ):
607                lineno = item.lineno
608                break
609            pos += 1
610        else:
611            lineno = item.lineno
612        imports = [
613            ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
614        ]
615        mod.body[pos:pos] = imports
616        # Collect asserts.
617        nodes = [mod]
618        while nodes:
619            node = nodes.pop()
620            for name, field in ast.iter_fields(node):
621                if isinstance(field, list):
622                    new = []
623                    for i, child in enumerate(field):
624                        if isinstance(child, ast.Assert):
625                            # Transform assert.
626                            new.extend(self.visit(child))
627                        else:
628                            new.append(child)
629                            if isinstance(child, ast.AST):
630                                nodes.append(child)
631                    setattr(node, name, new)
632                elif (
633                    isinstance(field, ast.AST)
634                    and
635                    # Don't recurse into expressions as they can't contain
636                    # asserts.
637                    not isinstance(field, ast.expr)
638                ):
639                    nodes.append(field)
640
641    @staticmethod
642    def is_rewrite_disabled(docstring):
643        return "PYTEST_DONT_REWRITE" in docstring
644
645    def variable(self):
646        """Get a new variable."""
647        # Use a character invalid in python identifiers to avoid clashing.
648        name = "@py_assert" + str(next(self.variable_counter))
649        self.variables.append(name)
650        return name
651
652    def assign(self, expr):
653        """Give *expr* a name."""
654        name = self.variable()
655        self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
656        return ast.Name(name, ast.Load())
657
658    def display(self, expr):
659        """Call py.io.saferepr on the expression."""
660        return self.helper("saferepr", expr)
661
662    def helper(self, name, *args):
663        """Call a helper in this module."""
664        py_name = ast.Name("@pytest_ar", ast.Load())
665        attr = ast.Attribute(py_name, "_" + name, ast.Load())
666        return ast_Call(attr, list(args), [])
667
668    def builtin(self, name):
669        """Return the builtin called *name*."""
670        builtin_name = ast.Name("@py_builtins", ast.Load())
671        return ast.Attribute(builtin_name, name, ast.Load())
672
673    def explanation_param(self, expr):
674        """Return a new named %-formatting placeholder for expr.
675
676        This creates a %-formatting placeholder for expr in the
677        current formatting context, e.g. ``%(py0)s``.  The placeholder
678        and expr are placed in the current format context so that it
679        can be used on the next call to .pop_format_context().
680
681        """
682        specifier = "py" + str(next(self.variable_counter))
683        self.explanation_specifiers[specifier] = expr
684        return "%(" + specifier + ")s"
685
686    def push_format_context(self):
687        """Create a new formatting context.
688
689        The format context is used for when an explanation wants to
690        have a variable value formatted in the assertion message.  In
691        this case the value required can be added using
692        .explanation_param().  Finally .pop_format_context() is used
693        to format a string of %-formatted values as added by
694        .explanation_param().
695
696        """
697        self.explanation_specifiers = {}
698        self.stack.append(self.explanation_specifiers)
699
700    def pop_format_context(self, expl_expr):
701        """Format the %-formatted string with current format context.
702
703        The expl_expr should be an ast.Str instance constructed from
704        the %-placeholders created by .explanation_param().  This will
705        add the required code to format said string to .on_failure and
706        return the ast.Name instance of the formatted string.
707
708        """
709        current = self.stack.pop()
710        if self.stack:
711            self.explanation_specifiers = self.stack[-1]
712        keys = [ast.Str(key) for key in current.keys()]
713        format_dict = ast.Dict(keys, list(current.values()))
714        form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
715        name = "@py_format" + str(next(self.variable_counter))
716        self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
717        return ast.Name(name, ast.Load())
718
719    def generic_visit(self, node):
720        """Handle expressions we don't have custom code for."""
721        assert isinstance(node, ast.expr)
722        res = self.assign(node)
723        return res, self.explanation_param(self.display(res))
724
725    def visit_Assert(self, assert_):
726        """Return the AST statements to replace the ast.Assert instance.
727
728        This rewrites the test of an assertion to provide
729        intermediate values and replace it with an if statement which
730        raises an assertion error with a detailed explanation in case
731        the expression is false.
732
733        """
734        if isinstance(assert_.test, ast.Tuple) and self.config is not None:
735            fslocation = (self.module_path, assert_.lineno)
736            self.config.warn(
737                "R1",
738                "assertion is always true, perhaps " "remove parentheses?",
739                fslocation=fslocation,
740            )
741        self.statements = []
742        self.variables = []
743        self.variable_counter = itertools.count()
744        self.stack = []
745        self.on_failure = []
746        self.push_format_context()
747        # Rewrite assert into a bunch of statements.
748        top_condition, explanation = self.visit(assert_.test)
749        # Create failure message.
750        body = self.on_failure
751        negation = ast.UnaryOp(ast.Not(), top_condition)
752        self.statements.append(ast.If(negation, body, []))
753        if assert_.msg:
754            assertmsg = self.helper("format_assertmsg", assert_.msg)
755            explanation = "\n>assert " + explanation
756        else:
757            assertmsg = ast.Str("")
758            explanation = "assert " + explanation
759        template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
760        msg = self.pop_format_context(template)
761        fmt = self.helper("format_explanation", msg)
762        err_name = ast.Name("AssertionError", ast.Load())
763        exc = ast_Call(err_name, [fmt], [])
764        if sys.version_info[0] >= 3:
765            raise_ = ast.Raise(exc, None)
766        else:
767            raise_ = ast.Raise(exc, None, None)
768        body.append(raise_)
769        # Clear temporary variables by setting them to None.
770        if self.variables:
771            variables = [ast.Name(name, ast.Store()) for name in self.variables]
772            clear = ast.Assign(variables, _NameConstant(None))
773            self.statements.append(clear)
774        # Fix line numbers.
775        for stmt in self.statements:
776            set_location(stmt, assert_.lineno, assert_.col_offset)
777        return self.statements
778
779    def visit_Name(self, name):
780        # Display the repr of the name if it's a local variable or
781        # _should_repr_global_name() thinks it's acceptable.
782        locs = ast_Call(self.builtin("locals"), [], [])
783        inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
784        dorepr = self.helper("should_repr_global_name", name)
785        test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
786        expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
787        return name, self.explanation_param(expr)
788
    def visit_BoolOp(self, boolop):
        """Rewrite ``a and b ...`` / ``a or b ...`` preserving short-circuiting.

        Each operand's evaluation is nested inside an ``if`` on the previous
        operand, so later operands run only when earlier ones did not
        already decide the result. Every evaluated operand's value is
        assigned to one result variable. On failure, one explanation per
        operand is appended to a runtime list, using the same nesting so
        only operands that were evaluated contribute. The list is then
        rendered by the ``_format_boolop`` helper.
        """
        # Variable receiving each evaluated operand's value (the final result).
        res_var = self.variable()
        # Runtime list collecting one explanation string per operand.
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        # 1 for ``or``, 0 otherwise; passed to _format_boolop at runtime.
        is_or = int(isinstance(boolop.op, ast.Or))
        body = save = self.statements
        fail_save = self.on_failure
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuiting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                # Only explain this operand under the same condition that
                # guarded its evaluation.
                fail_inner = []
                # cond is set in a prior loop iteration below
                self.on_failure.append(ast.If(cond, fail_inner, []))  # noqa
                self.on_failure = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast_Call(app, [expl_format], [])
            self.on_failure.append(ast.Expr(call))
            if i < levels:
                # Go on to the next operand only if this one did not decide
                # the result: truthy for ``and``, falsy for ``or``.
                cond = res
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []
                self.statements.append(ast.If(cond, inner, []))
                self.statements = body = inner
        # Restore the outer statement and failure lists.
        self.statements = save
        self.on_failure = fail_save
        expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
823
824    def visit_UnaryOp(self, unary):
825        pattern = unary_map[unary.op.__class__]
826        operand_res, operand_expl = self.visit(unary.operand)
827        res = self.assign(ast.UnaryOp(unary.op, operand_res))
828        return res, pattern % (operand_expl,)
829
830    def visit_BinOp(self, binop):
831        symbol = binop_map[binop.op.__class__]
832        left_expr, left_expl = self.visit(binop.left)
833        right_expr, right_expl = self.visit(binop.right)
834        explanation = "(%s %s %s)" % (left_expl, symbol, right_expl)
835        res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
836        return res, explanation
837
838    def visit_Call_35(self, call):
839        """
840        visit `ast.Call` nodes on Python3.5 and after
841        """
842        new_func, func_expl = self.visit(call.func)
843        arg_expls = []
844        new_args = []
845        new_kwargs = []
846        for arg in call.args:
847            res, expl = self.visit(arg)
848            arg_expls.append(expl)
849            new_args.append(res)
850        for keyword in call.keywords:
851            res, expl = self.visit(keyword.value)
852            new_kwargs.append(ast.keyword(keyword.arg, res))
853            if keyword.arg:
854                arg_expls.append(keyword.arg + "=" + expl)
855            else:  # **args have `arg` keywords with an .arg of None
856                arg_expls.append("**" + expl)
857
858        expl = "%s(%s)" % (func_expl, ", ".join(arg_expls))
859        new_call = ast.Call(new_func, new_args, new_kwargs)
860        res = self.assign(new_call)
861        res_expl = self.explanation_param(self.display(res))
862        outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
863        return res, outer_expl
864
865    def visit_Starred(self, starred):
866        # From Python 3.5, a Starred node can appear in a function call
867        res, expl = self.visit(starred.value)
868        return starred, "*" + expl
869
870    def visit_Call_legacy(self, call):
871        """
872        visit `ast.Call nodes on 3.4 and below`
873        """
874        new_func, func_expl = self.visit(call.func)
875        arg_expls = []
876        new_args = []
877        new_kwargs = []
878        new_star = new_kwarg = None
879        for arg in call.args:
880            res, expl = self.visit(arg)
881            new_args.append(res)
882            arg_expls.append(expl)
883        for keyword in call.keywords:
884            res, expl = self.visit(keyword.value)
885            new_kwargs.append(ast.keyword(keyword.arg, res))
886            arg_expls.append(keyword.arg + "=" + expl)
887        if call.starargs:
888            new_star, expl = self.visit(call.starargs)
889            arg_expls.append("*" + expl)
890        if call.kwargs:
891            new_kwarg, expl = self.visit(call.kwargs)
892            arg_expls.append("**" + expl)
893        expl = "%s(%s)" % (func_expl, ", ".join(arg_expls))
894        new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg)
895        res = self.assign(new_call)
896        res_expl = self.explanation_param(self.display(res))
897        outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
898        return res, outer_expl
899
900    # ast.Call signature changed on 3.5,
901    # conditionally change  which methods is named
902    # visit_Call depending on Python version
903    if sys.version_info >= (3, 5):
904        visit_Call = visit_Call_35
905    else:
906        visit_Call = visit_Call_legacy
907
908    def visit_Attribute(self, attr):
909        if not isinstance(attr.ctx, ast.Load):
910            return self.generic_visit(attr)
911        value, value_expl = self.visit(attr.value)
912        res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
913        res_expl = self.explanation_param(self.display(res))
914        pat = "%s\n{%s = %s.%s\n}"
915        expl = pat % (res_expl, res_expl, value_expl, attr.attr)
916        return res, expl
917
918    def visit_Compare(self, comp):
919        self.push_format_context()
920        left_res, left_expl = self.visit(comp.left)
921        if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
922            left_expl = "({})".format(left_expl)
923        res_variables = [self.variable() for i in range(len(comp.ops))]
924        load_names = [ast.Name(v, ast.Load()) for v in res_variables]
925        store_names = [ast.Name(v, ast.Store()) for v in res_variables]
926        it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
927        expls = []
928        syms = []
929        results = [left_res]
930        for i, op, next_operand in it:
931            next_res, next_expl = self.visit(next_operand)
932            if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
933                next_expl = "({})".format(next_expl)
934            results.append(next_res)
935            sym = binop_map[op.__class__]
936            syms.append(ast.Str(sym))
937            expl = "%s %s %s" % (left_expl, sym, next_expl)
938            expls.append(ast.Str(expl))
939            res_expr = ast.Compare(left_res, [op], [next_res])
940            self.statements.append(ast.Assign([store_names[i]], res_expr))
941            left_res, left_expl = next_res, next_expl
942        # Use pytest.assertion.util._reprcompare if that's available.
943        expl_call = self.helper(
944            "call_reprcompare",
945            ast.Tuple(syms, ast.Load()),
946            ast.Tuple(load_names, ast.Load()),
947            ast.Tuple(expls, ast.Load()),
948            ast.Tuple(results, ast.Load()),
949        )
950        if len(comp.ops) > 1:
951            res = ast.BoolOp(ast.And(), load_names)
952        else:
953            res = load_names[0]
954        return res, self.explanation_param(self.pop_format_context(expl_call))
955