from __future__ import absolute_import, division, generators, print_function

from bisect import bisect_right
import sys
import inspect, tokenize
import py
# keep a handle on the builtin: ``Source.compile`` shadows the name
# inside the class namespace.
cpy_compile = compile

try:
    import _ast
    from _ast import PyCF_ONLY_AST as _AST_FLAG
except ImportError:
    _AST_FLAG = 0
    _ast = None


class Source(object):
    """ an object holding a source code fragment as a list of lines
        (the ``lines`` attribute, no trailing newlines), possibly
        deindenting it.
    """
    # counter used to build unique artificial filenames in compile()
    _compilecounter = 0

    def __init__(self, *parts, **kwargs):
        """ build the source from the concatenation of ``parts``.

        Each part may be another Source, a tuple/list of lines, a string
        or any other object (whose source is then looked up through
        getsource()).  Falsy parts (None, '', empty sequences, ...)
        contribute no lines.

        Keyword arguments:
          deindent -- deindent each part (default True)
          rstrip   -- drop trailing blank lines of string parts
                      (default True)
        """
        self.lines = lines = []
        de = kwargs.get('deindent', True)
        rstrip = kwargs.get('rstrip', True)
        for part in parts:
            # NOTE: this has to be the first arm of the elif-chain,
            # otherwise falsy non-string parts such as None fall through
            # to the final ``else`` and getsource() is called on them.
            if not part:
                partlines = []
            elif isinstance(part, Source):
                partlines = part.lines
            elif isinstance(part, (tuple, list)):
                partlines = [x.rstrip("\n") for x in part]
            elif isinstance(part, py.builtin._basestring):
                partlines = part.split('\n')
                if rstrip:
                    while partlines:
                        if partlines[-1].strip():
                            break
                        partlines.pop()
            else:
                partlines = getsource(part, deindent=de).lines
            if de:
                partlines = deindent(partlines)
            lines.extend(partlines)

    def __eq__(self, other):
        """ compare line-wise with objects having a ``lines`` attribute;
            a plain ``str`` is compared against ``str(self)``.
        """
        try:
            return self.lines == other.lines
        except AttributeError:
            if isinstance(other, str):
                return str(self) == other
            return False

    # defining __eq__ without a matching hash: mark as unhashable
    __hash__ = None

    def __getitem__(self, key):
        """ return a single line for an integer key or a new Source
            for a slice; slices with a step other than 1 raise IndexError.
        """
        if isinstance(key, int):
            return self.lines[key]
        else:
            if key.step not in (None, 1):
                raise IndexError("cannot slice a Source with a step")
            newsource = Source()
            newsource.lines = self.lines[key.start:key.stop]
            return newsource

    def __len__(self):
        """ return the number of lines. """
        return len(self.lines)

    def strip(self):
        """ return new source object with trailing
            and leading blank lines removed.
        """
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end-1].strip():
            end -= 1
        source = Source()
        source.lines[:] = self.lines[start:end]
        return source

    def putaround(self, before='', after='', indent=' ' * 4):
        """ return a copy of the source object with
            'before' and 'after' wrapped around it.
            The original lines are prefixed with 'indent'.
        """
        before = Source(before)
        after = Source(after)
        newsource = Source()
        lines = [(indent + line) for line in self.lines]
        newsource.lines = before.lines + lines + after.lines
        return newsource

    def indent(self, indent=' ' * 4):
        """ return a copy of the source object with
            all lines indented by the given indent-string.
        """
        newsource = Source()
        newsource.lines = [(indent+line) for line in self.lines]
        return newsource

    def getstatement(self, lineno, assertion=False):
        """ return Source statement which contains the
            given linenumber (counted from 0).
        """
        start, end = self.getstatementrange(lineno, assertion)
        return self[start:end]

    def getstatementrange(self, lineno, assertion=False):
        """ return (start, end) tuple which spans the minimal
            statement region which contains the given lineno.
            raise IndexError if lineno is out of range.

            'assertion' is accepted for backward compatibility; it is
            not forwarded to getstatementrange_ast().
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        ast, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self, offset=None):
        """ return a new source object deindented by offset.
            If offset is None then guess an indentation offset from
            the first non-blank line.  Subsequent lines which have a
            lower indentation offset will be copied verbatim as
            they are assumed to be part of multilines.
        """
        # XXX maybe use the tokenizer to properly handle multiline
        #     strings etc.pp?
        newsource = Source()
        newsource.lines[:] = deindent(self.lines, offset)
        return newsource

    def isparseable(self, deindent=True):
        """ return True if source is parseable, heuristically
            deindenting it by default.
        """
        try:
            import parser
        except ImportError:
            # no 'parser' module available: fall back to the builtin
            # compile() as syntax checker
            syntax_checker = lambda x: compile(x, 'asd', 'exec')
        else:
            syntax_checker = parser.suite

        if deindent:
            source = str(self.deindent())
        else:
            source = str(self)
        try:
            #compile(source+'\n', "x", "exec")
            syntax_checker(source+'\n')
        except KeyboardInterrupt:
            raise
        except Exception:
            return False
        else:
            return True

    def __str__(self):
        """ return the lines joined by newlines (no trailing newline). """
        return "\n".join(self.lines)

    def compile(self, filename=None, mode='exec',
                flag=generators.compiler_flag,
                dont_inherit=0, _genframe=None):
        """ return compiled code object. if filename is None
            invent an artificial filename which displays
            the source/line position of the caller frame.

            'mode', 'flag' and 'dont_inherit' are handed to the builtin
            compile().  If 'flag' contains the AST flag the result is
            returned as is; otherwise the source lines are additionally
            registered in the linecache under the (artificial) filename.
            A SyntaxError is re-raised with a message that shows the
            offending source lines.
        """
        if not filename or py.path.local(filename).check(file=0):
            if _genframe is None:
                _genframe = sys._getframe(1) # the caller
            fn,lineno = _genframe.f_code.co_filename, _genframe.f_lineno
            base = "<%d-codegen " % self._compilecounter
            self.__class__._compilecounter += 1
            if not filename:
                filename = base + '%s:%d>' % (fn, lineno)
            else:
                filename = base + '%r %s:%d>' % (filename, fn, lineno)
        source = "\n".join(self.lines) + '\n'
        try:
            # honour dont_inherit instead of silently dropping it
            co = cpy_compile(source, filename, mode, flag, dont_inherit)
        except SyntaxError:
            ex = sys.exc_info()[1]
            # re-represent syntax errors from parsing python strings
            msglines = self.lines[:ex.lineno]
            if ex.offset:
                msglines.append(" "*ex.offset + '^')
            msglines.append("(code was compiled probably from here: %s)" % filename)
            newex = SyntaxError('\n'.join(msglines))
            newex.offset = ex.offset
            newex.lineno = ex.lineno
            newex.text = ex.text
            raise newex
        else:
            if flag & _AST_FLAG:
                return co
            lines = [(x + "\n") for x in self.lines]
            py.std.linecache.cache[filename] = (1, None, lines, filename)
            return co

#
# public API shortcut functions
#

def compile_(source, filename=None, mode='exec', flags=
            generators.compiler_flag, dont_inherit=0):
    """ compile the given source to a raw code object,
        and maintain an internal cache which allows later
        retrieval of the source code for the code object
        and any recursively created code objects.
    """
    if _ast is not None and isinstance(source, _ast.AST):
        # XXX should Source support having AST?
        return cpy_compile(source, filename, mode, flags, dont_inherit)
    _genframe = sys._getframe(1) # the caller
    s = Source(source)
    # forward dont_inherit so that both branches treat it the same way
    co = s.compile(filename, mode, flags, dont_inherit, _genframe=_genframe)
    return co


def getfslineno(obj):
    """ Return source location (path, lineno) for the given object.
    If the source cannot be determined return ("", -1)
    """
    import _pytest._code
    try:
        code = _pytest._code.Code(obj)
    except TypeError:
        # not something Code() accepts: fall back to the inspect module
        try:
            fn = (py.std.inspect.getsourcefile(obj) or
                  py.std.inspect.getfile(obj))
        except TypeError:
            return "", -1

        fspath = fn and py.path.local(fn) or None
        lineno = -1
        if fspath:
            try:
                _, lineno = findsource(obj)
            except IOError:
                pass
    else:
        fspath = code.path
        lineno = code.firstlineno
    assert isinstance(lineno, int)
    return fspath, lineno

#
# helper functions
#

def findsource(obj):
    """ return (Source, lineno) for the given object as found by
        inspect.findsource(), with the lines right-stripped.
        Return (None, -1) if the lookup fails.
    """
    try:
        sourcelines, lineno = py.std.inspect.findsource(obj)
    except py.builtin._sysex:
        raise
    except Exception:
        # best effort: any lookup failure means "no source available"
        return None, -1
    source = Source()
    source.lines = [line.rstrip() for line in sourcelines]
    return source, lineno


def getsource(obj, **kwargs):
    """ return a Source object for the given object; keyword
        arguments are passed on to the Source constructor.
    """
    import _pytest._code
    obj = _pytest._code.getrawcode(obj)
    try:
        strsrc = inspect.getsource(obj)
    except IndentationError:
        strsrc = "\"Buggy python version consider upgrading, cannot get source\""
    assert isinstance(strsrc, str)
    return Source(strsrc, **kwargs)


def deindent(lines, offset=None):
    """ return a new list with 'lines' deindented by 'offset' columns.

        If offset is None it is taken from the indentation of the first
        non-blank line (tabs expanded).  Lines which are not indented by
        at least 'offset' whitespace characters, and continuation lines
        of multiline tokens, are copied unchanged.
    """
    if offset is None:
        for line in lines:
            line = line.expandtabs()
            s = line.lstrip()
            if s:
                offset = len(line)-len(s)
                break
        else:
            offset = 0
    if offset == 0:
        return list(lines)
    newlines = []

    def readline_generator(lines):
        # readline()-like feed for the tokenizer: the lines first,
        # then empty strings to signal the end of input
        for line in lines:
            yield line + '\n'
        while True:
            yield ''

    it = readline_generator(lines)

    try:
        for _, _, (sline, _), (eline, _), _ in tokenize.generate_tokens(lambda: next(it)):
            if sline > len(lines):
                break # End of input reached
            if sline > len(newlines):
                line = lines[sline - 1].expandtabs()
                if line.lstrip() and line[:offset].isspace():
                    line = line[offset:] # Deindent
                newlines.append(line)

                for i in range(sline, eline):
                    # Don't deindent continuing lines of
                    # multiline tokens (i.e. multiline strings)
                    newlines.append(lines[i])
    except (IndentationError, tokenize.TokenError):
        pass
    # Add any lines we didn't see. E.g. if an exception was raised.
    newlines.extend(lines[len(newlines):])
    return newlines


def get_statement_startend2(lineno, node):
    """ return (start, end) 0-based line indices of the statement of
        the AST 'node' which contains 'lineno'; 'end' is None if no
        statement starts after 'lineno'.
    """
    import ast
    # flatten all statements and except handlers into one lineno-list
    # AST's line numbers start indexing at 1
    l = []
    for x in ast.walk(node):
        if isinstance(x, _ast.stmt) or isinstance(x, _ast.ExceptHandler):
            l.append(x.lineno - 1)
            for name in "finalbody", "orelse":
                val = getattr(x, name, None)
                if val:
                    # treat the finally/orelse part as its own statement
                    l.append(val[0].lineno - 1 - 1)
    l.sort()
    insert_index = bisect_right(l, lineno)
    start = l[insert_index - 1]
    if insert_index >= len(l):
        end = None
    else:
        end = l[insert_index]
    return start, end


def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
    """ return (astnode, start, end) where start/end span the statement
        of 'source' containing 'lineno' (0-based).

        If 'astnode' is None the source is parsed first; if that raises
        ValueError the result of getstatementrange_old() is returned
        with None in place of the AST.
    """
    if astnode is None:
        content = str(source)
        if sys.version_info < (2,7):
            content += "\n"
        try:
            astnode = compile(content, "source", "exec", 1024)  # 1024 for AST
        except ValueError:
            start, end = getstatementrange_old(lineno, source, assertion)
            return None, start, end
    start, end = get_statement_startend2(lineno, astnode)
    # we need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # make sure we don't span differently indented code blocks
        # by using the BlockFinder helper which inspect.getsource() uses itself
        block_finder = inspect.BlockFinder()
        # if we start with an indented line, put blockfinder to "started" mode
        block_finder.started = source.lines[start][0].isspace()
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(lambda: next(it)):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = block_finder.last + start
        except Exception:
            # best effort: keep the AST-derived end if tokenizing fails
            pass

    # the end might still point to a comment or empty line, correct it
    while end:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end


def getstatementrange_old(lineno, source, assertion=False):
    """ return (start, end) tuple which spans the minimal
        statement region which contains the given lineno.
        raise an IndexError if no such statementrange can be found.
    """
    # XXX this logic is only used on python2.4 and below
    # 1. find the start of the statement
    from codeop import compile_command
    for start in range(lineno, -1, -1):
        if assertion:
            line = source.lines[start]
            # the following lines are not fully tested, change with care
            if 'super' in line and 'self' in line and '__init__' in line:
                raise IndexError("likely a subclass")
            if "assert" not in line and "raise" not in line:
                continue
        trylines = source.lines[start:lineno+1]
        # quick hack to prepare parsing an indented line with
        # compile_command() (which errors on "return" outside defs)
        trylines.insert(0, 'def xxx():')
        trysource = '\n '.join(trylines)
        #              ^ space here
        try:
            compile_command(trysource)
        except (SyntaxError, OverflowError, ValueError):
            continue

        # 2. find the end of the statement
        for end in range(lineno+1, len(source)+1):
            trysource = source[start:end]
            if trysource.isparseable():
                return start, end
    raise SyntaxError("no valid source range around line %d " % (lineno,))