1"""
Find intermediate evaluation results in assert statements through the builtin AST.
3This should replace _assertionold.py eventually.
4"""
5
6import sys
7import ast
8
9import py
10from py._code.assertion import _format_explanation, BuiltinAssertionError
11
12
13def _is_ast_expr(node):
14    return isinstance(node, ast.expr)
15def _is_ast_stmt(node):
16    return isinstance(node, ast.stmt)
17
18
class Failure(Exception):
    """Error found while interpreting AST.

    Intended to be created inside an ``except`` block: the exception being
    handled at construction time is captured in ``cause`` (the
    ``sys.exc_info()`` triple) so it can be reported later, and
    ``explanation`` holds the partial explanation text built so far.
    """

    def __init__(self, explanation=""):
        exc_info = sys.exc_info()
        self.cause = exc_info
        self.explanation = explanation
25
26
def interpret(source, frame, should_fail=False):
    """Re-run *source* in *frame* and explain an assertion failure.

    :param source: source text to re-evaluate (e.g. the offending line
        passed in by ``run``).
    :param frame: frame object handed to ``DebugInterpreter``.
    :param should_fail: set when the caller already saw this code fail.
    :returns: the text produced by ``getfailure`` if re-interpretation
        raised ``Failure``; otherwise a hint message when *should_fail* is
        set, else ``None``.
    """
    tree = ast.parse(source)
    interpreter = DebugInterpreter(frame)
    try:
        interpreter.visit(tree)
    except Failure as failure:
        return getfailure(failure)
    if not should_fail:
        return None
    # The original run failed but the replay passed, so no intermediate
    # values can be shown; suggest workarounds instead.
    return ("(assertion failed, but when it was re-run for "
            "printing intermediate values, it did not fail.  Suggestions: "
            "compute assert expression before the assert or use --no-assert)")
39
def run(offending_line, frame=None):
    """Interpret *offending_line* and return its failure explanation.

    When *frame* is omitted, the caller's frame (wrapped in
    ``py.code.Frame``) is used as the evaluation context.
    """
    if frame is not None:
        return interpret(offending_line, frame)
    # Depth 1 == the function that called run().
    caller_frame = py.code.Frame(sys._getframe(1))
    return interpret(offending_line, caller_frame)
44
def getfailure(failure):
    """Render a ``Failure`` as the text shown to the user.

    The explanation is formatted with ``_format_explanation``. If the
    captured exception has a non-empty string form, it is appended to the
    first line after `` << ``. The whole text is prefixed with the
    exception class name, but a leading ``AssertionError: `` is dropped
    when what follows is an ``assert ...`` explanation.
    """
    explanation = _format_explanation(failure.explanation)
    exc_value = failure.cause[1]
    if str(exc_value):
        # An empty explanation still needs a first line to annotate.
        lines = explanation.splitlines() or [""]
        lines[0] = lines[0] + " << %s" % (exc_value,)
        explanation = "\n".join(lines)
    text = "%s: %s" % (failure.cause[0].__name__, explanation)
    prefix = "AssertionError: "
    if text.startswith(prefix + "assert "):
        text = text[len(prefix):]
    return text
58
59
# Source spelling of each binary/comparison operator node class; used to
# rebuild "__exprinfo_left <op> __exprinfo_right" snippets and explanation
# text in DebugInterpreter.visit_BinOp / visit_Compare.
operator_map = {
    ast.BitOr : "|",
    ast.BitXor : "^",
    ast.BitAnd : "&",
    ast.LShift : "<<",
    ast.RShift : ">>",
    ast.Add : "+",
    ast.Sub : "-",
    ast.Mult : "*",
    ast.Div : "/",
    ast.FloorDiv : "//",
    ast.Mod : "%",
    ast.Eq : "==",
    ast.NotEq : "!=",
    ast.Lt : "<",
    ast.LtE : "<=",
    ast.Gt : ">",
    ast.GtE : ">=",
    ast.Pow : "**",
    ast.Is : "is",
    ast.IsNot : "is not",
    ast.In : "in",
    ast.NotIn : "not in"
}
# The matrix-multiplication operator only exists on Python 3.5+; without
# this entry ``assert a @ b == c`` would die with a KeyError.
if hasattr(ast, "MatMult"):
    operator_map[ast.MatMult] = "@"
84
# %-pattern per unary operator node class: substituting the operand's text
# rebuilds the expression (see DebugInterpreter.visit_UnaryOp).
unary_map = dict([
    (ast.Not, "not %s"),
    (ast.Invert, "~%s"),
    (ast.USub, "-%s"),
    (ast.UAdd, "+%s"),
])
91
92
class DebugInterpreter(ast.NodeVisitor):
    """Interpret AST nodes to glean useful debugging information.

    The ``visit_*`` methods for expressions return an ``(explanation,
    result)`` pair. ``result`` is the value of the sub-expression evaluated
    in ``self.frame``. ``explanation`` is text (using the ``{``/``}``
    nesting markup later rendered by ``_format_explanation``) describing how
    that value was obtained. If evaluating a node raises, ``Failure`` is
    raised carrying the explanation built so far.
    """

    def __init__(self, frame):
        # Frame object whose eval(), exec_() and repr() methods are used to
        # re-evaluate code in the scope of the failing assertion.
        self.frame = frame

    @staticmethod
    def _module(body):
        """Wrap the statement nodes *body* in an ``ast.Module``.

        ``compile()`` requires the ``type_ignores`` field on Python 3.8+;
        older versions simply store the extra attribute and ignore it.
        """
        return ast.Module(body=body, type_ignores=[])

    def generic_visit(self, node):
        # Fallback when we don't have a special implementation: evaluate an
        # expression (explained by its repr) or execute a statement as-is.
        if _is_ast_expr(node):
            co = self._compile(ast.Expression(node))
            try:
                result = self.frame.eval(co)
            except Exception:
                raise Failure()
            explanation = self.frame.repr(result)
            return explanation, result
        elif _is_ast_stmt(node):
            co = self._compile(self._module([node]), "exec")
            try:
                self.frame.exec_(co)
            except Exception:
                raise Failure()
            return None, None
        else:
            raise AssertionError("can't handle %s" %(node,))

    def _compile(self, source, mode="eval"):
        # ``source`` is either source text or an AST root node.
        return compile(source, "<assertion interpretation>", mode)

    def visit_Expr(self, expr):
        return self.visit(expr.value)

    def visit_Module(self, mod):
        for stmt in mod.body:
            self.visit(stmt)

    def visit_Name(self, name):
        """Explain a local by its value's repr, anything else by its name."""
        explanation, result = self.generic_visit(name)
        # See if the name is local.
        source = "%r in locals() is not globals()" % (name.id,)
        co = self._compile(source)
        try:
            local = self.frame.eval(co)
        except Exception:
            # have to assume it isn't
            local = False
        if not local:
            return name.id, result
        return explanation, result

    def visit_Compare(self, comp):
        """Evaluate a (possibly chained) comparison pair by pair.

        Stops at the first false comparison; the explanation describes the
        last comparison evaluated and may be replaced by the output of
        ``py.code._reprcompare`` when that hook is set.
        """
        left_explanation, left_result = self.visit(comp.left)
        for op, next_op in zip(comp.ops, comp.comparators):
            next_explanation, next_result = self.visit(next_op)
            op_symbol = operator_map[op.__class__]
            explanation = "%s %s %s" % (left_explanation, op_symbol,
                                        next_explanation)
            source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,)
            co = self._compile(source)
            try:
                result = self.frame.eval(co, __exprinfo_left=left_result,
                                         __exprinfo_right=next_result)
            except Exception:
                raise Failure(explanation)
            # Remember the operands of the comparison just performed: the
            # rebinding of left_result below must not leak into the
            # _reprcompare call after the loop.
            compared_left, compared_right = left_result, next_result
            try:
                if not result:
                    break
            except Exception:
                # The result's truth value can't be determined; stop and
                # report this comparison (KeyboardInterrupt/SystemExit are
                # not Exception subclasses and still propagate).
                break
            left_explanation, left_result = next_explanation, next_result

        rcomp = py.code._reprcompare
        if rcomp:
            res = rcomp(op_symbol, compared_left, compared_right)
            if res:
                explanation = res
        return explanation, result

    def visit_BoolOp(self, boolop):
        """Evaluate ``and``/``or`` operands with Python's short-circuiting."""
        is_or = isinstance(boolop.op, ast.Or)
        name = " or " if is_or else " and "
        explanations = []
        for operand in boolop.values:
            explanation, result = self.visit(operand)
            explanations.append(explanation)
            try:
                truth = bool(result)
            except Exception:
                raise Failure("(" + name.join(explanations) + ")")
            # Stop at the first truthy operand of ``or`` / the first falsy
            # operand of ``and``. Comparing ``result == is_or`` instead would
            # only match real bools, so e.g. ``x and x.attr`` would evaluate
            # ``x.attr`` even when ``x`` is None.
            if truth == is_or:
                break
        explanation = "(" + name.join(explanations) + ")"
        return explanation, result

    def visit_UnaryOp(self, unary):
        pattern = unary_map[unary.op.__class__]
        operand_explanation, operand_result = self.visit(unary.operand)
        explanation = pattern % (operand_explanation,)
        co = self._compile(pattern % ("__exprinfo_expr",))
        try:
            result = self.frame.eval(co, __exprinfo_expr=operand_result)
        except Exception:
            raise Failure(explanation)
        return explanation, result

    def visit_BinOp(self, binop):
        symbol = operator_map.get(binop.op.__class__)
        if symbol is None:
            # Operator missing from operator_map: evaluate the whole
            # expression at once rather than crashing with a KeyError.
            return self.generic_visit(binop)
        left_explanation, left_result = self.visit(binop.left)
        right_explanation, right_result = self.visit(binop.right)
        explanation = "(%s %s %s)" % (left_explanation, symbol,
                                      right_explanation)
        source = "__exprinfo_left %s __exprinfo_right" % (symbol,)
        co = self._compile(source)
        try:
            result = self.frame.eval(co, __exprinfo_left=left_result,
                                     __exprinfo_right=right_result)
        except Exception:
            raise Failure(explanation)
        return explanation, result

    def visit_Call(self, call):
        """Evaluate callee and arguments separately, then perform the call.

        Supports both the pre-3.5 AST (``starargs``/``kwargs`` fields) and
        the 3.5+ AST (``ast.Starred`` in ``args``, ``**`` as a keyword
        whose ``arg`` is None).
        """
        func_explanation, func = self.visit(call.func)
        # Evaluated callee and arguments reach frame.eval() through this
        # namespace under placeholder names.
        ns = {"__exprinfo_func" : func}
        arguments = []          # argument-list source of the rebuilt call
        arg_explanations = []   # matching explanation text

        def bind(value):
            # Store *value* under a fresh placeholder name and return it.
            arg_name = "__exprinfo_%s" % (len(ns),)
            ns[arg_name] = value
            return arg_name

        starred_type = getattr(ast, "Starred", None)
        for arg in call.args:
            if starred_type is not None and isinstance(arg, starred_type):
                arg_explanation, arg_result = self.visit(arg.value)
                arguments.append("*%s" % (bind(arg_result),))
                arg_explanations.append("*%s" % (arg_explanation,))
            else:
                arg_explanation, arg_result = self.visit(arg)
                arguments.append(bind(arg_result))
                arg_explanations.append(arg_explanation)
        for keyword in call.keywords:
            arg_explanation, arg_result = self.visit(keyword.value)
            arg_name = bind(arg_result)
            if keyword.arg is None:
                arguments.append("**%s" % (arg_name,))
                arg_explanations.append("**%s" % (arg_explanation,))
            else:
                keyword_source = "%s=%%s" % (keyword.arg,)
                arguments.append(keyword_source % (arg_name,))
                arg_explanations.append(keyword_source % (arg_explanation,))
        # These fields exist only before Python 3.5, hence getattr.
        starargs = getattr(call, "starargs", None)
        if starargs:
            arg_explanation, arg_result = self.visit(starargs)
            arg_name = "__exprinfo_star"
            ns[arg_name] = arg_result
            arguments.append("*%s" % (arg_name,))
            arg_explanations.append("*%s" % (arg_explanation,))
        kwargs = getattr(call, "kwargs", None)
        if kwargs:
            arg_explanation, arg_result = self.visit(kwargs)
            arg_name = "__exprinfo_kwds"
            ns[arg_name] = arg_result
            arguments.append("**%s" % (arg_name,))
            arg_explanations.append("**%s" % (arg_explanation,))
        args_explained = ", ".join(arg_explanations)
        explanation = "%s(%s)" % (func_explanation, args_explained)
        source = "__exprinfo_func(%s)" % (", ".join(arguments),)
        co = self._compile(source)
        try:
            result = self.frame.eval(co, **ns)
        except Exception:
            raise Failure(explanation)
        pattern = "%s\n{%s = %s\n}"
        rep = self.frame.repr(result)
        explanation = pattern % (rep, rep, explanation)
        return explanation, result

    def _is_builtin_name(self, name):
        # True when ``name.id`` is neither a global nor a local of the frame
        # (presumably a builtin); False if the check itself fails.
        pattern = "%r not in globals() and %r not in locals()"
        source = pattern % (name.id, name.id)
        co = self._compile(source)
        try:
            return self.frame.eval(co)
        except Exception:
            return False

    def visit_Attribute(self, attr):
        if not isinstance(attr.ctx, ast.Load):
            return self.generic_visit(attr)
        source_explanation, source_result = self.visit(attr.value)
        explanation = "%s.%s" % (source_explanation, attr.attr)
        source = "__exprinfo_expr.%s" % (attr.attr,)
        co = self._compile(source)
        try:
            result = self.frame.eval(co, __exprinfo_expr=source_result)
        except Exception:
            raise Failure(explanation)
        explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result),
                                              self.frame.repr(result),
                                              source_explanation, attr.attr)
        # Check if the attr is from an instance.
        source = "%r in getattr(__exprinfo_expr, '__dict__', {})"
        source = source % (attr.attr,)
        co = self._compile(source)
        try:
            from_instance = self.frame.eval(co, __exprinfo_expr=source_result)
        except Exception:
            from_instance = True
        if from_instance:
            rep = self.frame.repr(result)
            pattern = "%s\n{%s = %s\n}"
            explanation = pattern % (rep, rep, explanation)
        return explanation, result

    def visit_Assert(self, assrt):
        test_explanation, test_result = self.visit(assrt.test)
        explanation = "assert %s" % (test_explanation,)
        if not test_result:
            try:
                # Raise and catch so Failure.cause records an assertion
                # exception (BuiltinAssertionError) for getfailure() to name.
                raise BuiltinAssertionError
            except Exception:
                raise Failure(explanation)
        return explanation, test_result

    def visit_Assign(self, assign):
        """Evaluate the right-hand side, then perform the assignment in the
        frame with the value injected as ``__exprinfo_expr``."""
        value_explanation, value_result = self.visit(assign.value)
        explanation = "... = %s" % (value_explanation,)
        name = ast.Name("__exprinfo_expr", ast.Load(),
                        lineno=assign.value.lineno,
                        col_offset=assign.value.col_offset)
        new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno,
                                col_offset=assign.col_offset)
        co = self._compile(self._module([new_assign]), "exec")
        try:
            self.frame.exec_(co, __exprinfo_expr=value_result)
        except Exception:
            raise Failure(explanation)
        return explanation, value_result
323