"""
Find intermediate evaluation results in assert statements through builtin AST.
This should replace _assertionold.py eventually.
"""

import sys
import ast

import py
from py._code.assertion import _format_explanation, BuiltinAssertionError


def _is_ast_expr(node):
    """Return True if *node* is an expression node."""
    return isinstance(node, ast.expr)


def _is_ast_stmt(node):
    """Return True if *node* is a statement node."""
    return isinstance(node, ast.stmt)


# ``ast.Starred`` does not exist on Python 2; there nothing can be starred
# inline in ``Call.args``, so a missing class simply means "never starred".
_Starred = getattr(ast, "Starred", None)


def _is_starred(node):
    """Return True if *node* is an inline ``*iterable`` call argument."""
    return _Starred is not None and isinstance(node, _Starred)


def _make_module(body):
    """Wrap the statement nodes in *body* in a compilable ``ast.Module``.

    Python 3.8 added a ``type_ignores`` field to ``ast.Module`` that
    ``compile`` requires; older versions have no such field, so it is only
    filled in when the running ``ast`` module declares it.
    """
    mod = ast.Module(body)
    if "type_ignores" in getattr(ast.Module, "_fields", ()):
        mod.type_ignores = []
    return mod


class Failure(Exception):
    """Error found while interpreting AST."""

    def __init__(self, explanation=""):
        # Construct this inside an ``except`` block: the exception currently
        # being handled is recorded and later reported by getfailure().
        self.cause = sys.exc_info()
        self.explanation = explanation


def interpret(source, frame, should_fail=False):
    """Re-run *source* in *frame*, explaining intermediate values on failure.

    Returns the text built by getfailure() when re-running raises a Failure.
    Otherwise returns a hint message when *should_fail* is true (the
    assertion failed originally but passed on re-run), else None.
    """
    mod = ast.parse(source)
    visitor = DebugInterpreter(frame)
    try:
        visitor.visit(mod)
    except Failure:
        failure = sys.exc_info()[1]
        return getfailure(failure)
    if should_fail:
        return ("(assertion failed, but when it was re-run for "
                "printing intermediate values, it did not fail. Suggestions: "
                "compute assert expression before the assert or use --no-assert)")
    return None


def run(offending_line, frame=None):
    """Interpret *offending_line* in *frame* (default: the caller's frame)."""
    if frame is None:
        frame = py.code.Frame(sys._getframe(1))
    return interpret(offending_line, frame)


def getfailure(failure):
    """Format *failure* as ``"<ExceptionName>: <explanation>"``.

    If the causing exception has a non-empty string form it is appended to
    the first explanation line as ``<< value``.  The ``"AssertionError: "``
    prefix is dropped when the explanation itself starts with ``"assert "``.
    """
    explanation = _format_explanation(failure.explanation)
    value = failure.cause[1]
    if str(value):
        lines = explanation.splitlines()
        if not lines:
            lines.append("")
        lines[0] += " << %s" % (value,)
        explanation = "\n".join(lines)
    text = "%s: %s" % (failure.cause[0].__name__, explanation)
    if text.startswith("AssertionError: assert "):
        text = text[len("AssertionError: "):]
    return text


operator_map = {
    ast.BitOr : "|",
    ast.BitXor : "^",
    ast.BitAnd : "&",
    ast.LShift : "<<",
    ast.RShift : ">>",
    ast.Add : "+",
    ast.Sub : "-",
    ast.Mult : "*",
    ast.Div : "/",
    ast.FloorDiv : "//",
    ast.Mod : "%",
    ast.Eq : "==",
    ast.NotEq : "!=",
    ast.Lt : "<",
    ast.LtE : "<=",
    ast.Gt : ">",
    ast.GtE : ">=",
    ast.Pow : "**",
    ast.Is : "is",
    ast.IsNot : "is not",
    ast.In : "in",
    ast.NotIn : "not in"
}

unary_map = {
    ast.Not : "not %s",
    ast.Invert : "~%s",
    ast.USub : "-%s",
    ast.UAdd : "+%s"
}


class DebugInterpreter(ast.NodeVisitor):
    """Interpret AST nodes to gleam useful debugging information.

    Every ``visit_*`` method for an expression returns a pair
    ``(explanation, result)``: a human-readable rendering of the
    sub-expression and its value, evaluated in ``self.frame``.
    """

    def __init__(self, frame):
        # Frame object providing eval(code, **ns), exec_(code, **ns) and
        # repr(obj), as used below (a py.code.Frame when built by run()).
        self.frame = frame

    def generic_visit(self, node):
        """Evaluate/execute *node* as-is when no specialised visitor exists."""
        if _is_ast_expr(node):
            mod = ast.Expression(node)
            co = self._compile(mod)
            try:
                result = self.frame.eval(co)
            except Exception:
                raise Failure()
            explanation = self.frame.repr(result)
            return explanation, result
        elif _is_ast_stmt(node):
            mod = _make_module([node])
            co = self._compile(mod, "exec")
            try:
                self.frame.exec_(co)
            except Exception:
                raise Failure()
            return None, None
        else:
            raise AssertionError("can't handle %s" % (node,))

    def _compile(self, source, mode="eval"):
        """Compile *source* (string or AST) under a recognisable filename."""
        return compile(source, "<assertion interpretation>", mode)

    def visit_Expr(self, expr):
        return self.visit(expr.value)

    def visit_Module(self, mod):
        for stmt in mod.body:
            self.visit(stmt)

    def visit_Name(self, name):
        """Show local names by their repr; other names by their identifier."""
        explanation, result = self.generic_visit(name)
        # Chained comparison: the name is in locals() AND locals() is not
        # globals() (i.e. we are not evaluating at module level).
        source = "%r in locals() is not globals()" % (name.id,)
        co = self._compile(source)
        try:
            local = self.frame.eval(co)
        except Exception:
            # have to assume it isn't
            local = False
        if not local:
            return name.id, result
        return explanation, result

    def visit_Compare(self, comp):
        """Evaluate a (possibly chained) comparison pair by pair.

        Stops at the first pair whose result is falsy (or whose truth test
        raises), mirroring Python's chained-comparison short-circuit.
        """
        left_explanation, left_result = self.visit(comp.left)
        for op, next_op in zip(comp.ops, comp.comparators):
            next_explanation, next_result = self.visit(next_op)
            op_symbol = operator_map[op.__class__]
            explanation = "%s %s %s" % (left_explanation, op_symbol,
                                        next_explanation)
            source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,)
            co = self._compile(source)
            try:
                result = self.frame.eval(co, __exprinfo_left=left_result,
                                         __exprinfo_right=next_result)
            except Exception:
                raise Failure(explanation)
            # Remember the left operand of the pair that produced `result`;
            # `left_result` is advanced below when the chain continues.
            compared_left = left_result
            try:
                if not result:
                    break
            except Exception:
                # An object whose truth value cannot be determined is
                # treated as a failed comparison.
                break
            left_explanation, left_result = next_explanation, next_result

        rcomp = py.code._reprcompare
        if rcomp:
            res = rcomp(op_symbol, compared_left, next_result)
            if res:
                explanation = res
        return explanation, result

    def visit_BoolOp(self, boolop):
        """Evaluate ``and``/``or`` operands with Python's short-circuiting."""
        is_or = isinstance(boolop.op, ast.Or)
        joiner = " or " if is_or else " and "
        explanations = []
        for operand in boolop.values:
            explanation, result = self.visit(operand)
            explanations.append(explanation)
            try:
                truth = bool(result)
            except Exception:
                # The real expression would have raised here as well.
                raise Failure("(" + joiner.join(explanations) + ")")
            # ``or`` stops at the first truthy operand, ``and`` at the first
            # falsy one; the stopping operand is the expression's value.
            if truth == is_or:
                break
        explanation = "(" + joiner.join(explanations) + ")"
        return explanation, result

    def visit_UnaryOp(self, unary):
        pattern = unary_map[unary.op.__class__]
        operand_explanation, operand_result = self.visit(unary.operand)
        explanation = pattern % (operand_explanation,)
        co = self._compile(pattern % ("__exprinfo_expr",))
        try:
            result = self.frame.eval(co, __exprinfo_expr=operand_result)
        except Exception:
            raise Failure(explanation)
        return explanation, result

    def visit_BinOp(self, binop):
        left_explanation, left_result = self.visit(binop.left)
        right_explanation, right_result = self.visit(binop.right)
        symbol = operator_map[binop.op.__class__]
        explanation = "(%s %s %s)" % (left_explanation, symbol,
                                      right_explanation)
        source = "__exprinfo_left %s __exprinfo_right" % (symbol,)
        co = self._compile(source)
        try:
            result = self.frame.eval(co, __exprinfo_left=left_result,
                                     __exprinfo_right=right_result)
        except Exception:
            raise Failure(explanation)
        return explanation, result

    def visit_Call(self, call):
        """Evaluate a call from pre-evaluated function and argument values.

        Supports both AST layouts: ``*args``/``**kwargs`` stored in
        ``call.starargs``/``call.kwargs`` (Python < 3.5) and inline
        ``ast.Starred`` arguments / ``keyword(arg=None)`` (Python 3.5+).
        """
        func_explanation, func = self.visit(call.func)
        arg_explanations = []
        ns = {"__exprinfo_func" : func}
        arguments = []
        for arg in call.args:
            if _is_starred(arg):
                arg_explanation, arg_result = self.visit(arg.value)
                prefix = "*"
            else:
                arg_explanation, arg_result = self.visit(arg)
                prefix = ""
            arg_name = "__exprinfo_%s" % (len(ns),)
            ns[arg_name] = arg_result
            arguments.append(prefix + arg_name)
            arg_explanations.append(prefix + arg_explanation)
        for keyword in call.keywords:
            arg_explanation, arg_result = self.visit(keyword.value)
            arg_name = "__exprinfo_%s" % (len(ns),)
            ns[arg_name] = arg_result
            if keyword.arg is None:
                # ``**mapping`` unpacking has no keyword name.
                keyword_source = "**%s"
            else:
                keyword_source = "%s=%%s" % (keyword.arg,)
            arguments.append(keyword_source % (arg_name,))
            arg_explanations.append(keyword_source % (arg_explanation,))
        starargs = getattr(call, "starargs", None)
        if starargs is not None:
            arg_explanation, arg_result = self.visit(starargs)
            arg_name = "__exprinfo_star"
            ns[arg_name] = arg_result
            arguments.append("*%s" % (arg_name,))
            arg_explanations.append("*%s" % (arg_explanation,))
        kwargs = getattr(call, "kwargs", None)
        if kwargs is not None:
            arg_explanation, arg_result = self.visit(kwargs)
            arg_name = "__exprinfo_kwds"
            ns[arg_name] = arg_result
            arguments.append("**%s" % (arg_name,))
            arg_explanations.append("**%s" % (arg_explanation,))
        args_explained = ", ".join(arg_explanations)
        explanation = "%s(%s)" % (func_explanation, args_explained)
        args = ", ".join(arguments)
        source = "__exprinfo_func(%s)" % (args,)
        co = self._compile(source)
        try:
            result = self.frame.eval(co, **ns)
        except Exception:
            raise Failure(explanation)
        pattern = "%s\n{%s = %s\n}"
        rep = self.frame.repr(result)
        explanation = pattern % (rep, rep, explanation)
        return explanation, result

    def _is_builtin_name(self, name):
        """Return True if *name* is in neither globals() nor locals()."""
        pattern = "%r not in globals() and %r not in locals()"
        source = pattern % (name.id, name.id)
        co = self._compile(source)
        try:
            return self.frame.eval(co)
        except Exception:
            return False

    def visit_Attribute(self, attr):
        """Evaluate ``value.attr``; instance attributes get an extra layer."""
        if not isinstance(attr.ctx, ast.Load):
            return self.generic_visit(attr)
        source_explanation, source_result = self.visit(attr.value)
        explanation = "%s.%s" % (source_explanation, attr.attr)
        source = "__exprinfo_expr.%s" % (attr.attr,)
        co = self._compile(source)
        try:
            result = self.frame.eval(co, __exprinfo_expr=source_result)
        except Exception:
            raise Failure(explanation)
        explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result),
                                              self.frame.repr(result),
                                              source_explanation, attr.attr)
        # Check if the attr is from an instance.
        source = "%r in getattr(__exprinfo_expr, '__dict__', {})"
        source = source % (attr.attr,)
        co = self._compile(source)
        try:
            from_instance = self.frame.eval(co, __exprinfo_expr=source_result)
        except Exception:
            from_instance = True
        if from_instance:
            rep = self.frame.repr(result)
            pattern = "%s\n{%s = %s\n}"
            explanation = pattern % (rep, rep, explanation)
        return explanation, result

    def visit_Assert(self, assrt):
        """Re-evaluate the assert test; raise Failure if it is falsy."""
        test_explanation, test_result = self.visit(assrt.test)
        # NOTE(review): the "...\n{... = ...\n}" wrappers built in this class
        # end with "\n}", not "\n", so this strip appears never to fire for
        # them — confirm intent before changing it.
        if test_explanation.startswith("False\n{False =") and \
                test_explanation.endswith("\n"):
            test_explanation = test_explanation[15:-2]
        explanation = "assert %s" % (test_explanation,)
        if not test_result:
            try:
                # Raised and caught so that Failure records it as the cause.
                raise BuiltinAssertionError
            except Exception:
                raise Failure(explanation)
        return explanation, test_result

    def visit_Assign(self, assign):
        """Evaluate the right-hand side, then bind it to the original targets."""
        value_explanation, value_result = self.visit(assign.value)
        explanation = "... = %s" % (value_explanation,)
        name = ast.Name("__exprinfo_expr", ast.Load(),
                        lineno=assign.value.lineno,
                        col_offset=assign.value.col_offset)
        new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno,
                                col_offset=assign.col_offset)
        mod = _make_module([new_assign])
        co = self._compile(mod, "exec")
        try:
            self.frame.exec_(co, __exprinfo_expr=value_result)
        except Exception:
            raise Failure(explanation)
        return explanation, value_result