1from __future__ import generators 2 3from bisect import bisect_right 4import sys 5import inspect, tokenize 6import py 7from types import ModuleType 8cpy_compile = compile 9 10try: 11 import _ast 12 from _ast import PyCF_ONLY_AST as _AST_FLAG 13except ImportError: 14 _AST_FLAG = 0 15 _ast = None 16 17 18class Source(object): 19 """ a immutable object holding a source code fragment, 20 possibly deindenting it. 21 """ 22 _compilecounter = 0 23 def __init__(self, *parts, **kwargs): 24 self.lines = lines = [] 25 de = kwargs.get('deindent', True) 26 rstrip = kwargs.get('rstrip', True) 27 for part in parts: 28 if not part: 29 partlines = [] 30 if isinstance(part, Source): 31 partlines = part.lines 32 elif isinstance(part, (tuple, list)): 33 partlines = [x.rstrip("\n") for x in part] 34 elif isinstance(part, py.builtin._basestring): 35 partlines = part.split('\n') 36 if rstrip: 37 while partlines: 38 if partlines[-1].strip(): 39 break 40 partlines.pop() 41 else: 42 partlines = getsource(part, deindent=de).lines 43 if de: 44 partlines = deindent(partlines) 45 lines.extend(partlines) 46 47 def __eq__(self, other): 48 try: 49 return self.lines == other.lines 50 except AttributeError: 51 if isinstance(other, str): 52 return str(self) == other 53 return False 54 55 def __getitem__(self, key): 56 if isinstance(key, int): 57 return self.lines[key] 58 else: 59 if key.step not in (None, 1): 60 raise IndexError("cannot slice a Source with a step") 61 return self.__getslice__(key.start, key.stop) 62 63 def __len__(self): 64 return len(self.lines) 65 66 def __getslice__(self, start, end): 67 newsource = Source() 68 newsource.lines = self.lines[start:end] 69 return newsource 70 71 def strip(self): 72 """ return new source object with trailing 73 and leading blank lines removed. 74 """ 75 start, end = 0, len(self) 76 while start < end and not self.lines[start].strip(): 77 start += 1 78 while end > start and not self.lines[end-1].strip(): 79 end -= 1 80 source = Source() 81 source.lines[:] = self.lines[start:end] 82 return source 83 84 def putaround(self, before='', after='', indent=' ' * 4): 85 """ return a copy of the source object with 86 'before' and 'after' wrapped around it. 87 """ 88 before = Source(before) 89 after = Source(after) 90 newsource = Source() 91 lines = [ (indent + line) for line in self.lines] 92 newsource.lines = before.lines + lines + after.lines 93 return newsource 94 95 def indent(self, indent=' ' * 4): 96 """ return a copy of the source object with 97 all lines indented by the given indent-string. 98 """ 99 newsource = Source() 100 newsource.lines = [(indent+line) for line in self.lines] 101 return newsource 102 103 def getstatement(self, lineno, assertion=False): 104 """ return Source statement which contains the 105 given linenumber (counted from 0). 106 """ 107 start, end = self.getstatementrange(lineno, assertion) 108 return self[start:end] 109 110 def getstatementrange(self, lineno, assertion=False): 111 """ return (start, end) tuple which spans the minimal 112 statement region which containing the given lineno. 113 """ 114 if not (0 <= lineno < len(self)): 115 raise IndexError("lineno out of range") 116 ast, start, end = getstatementrange_ast(lineno, self) 117 return start, end 118 119 def deindent(self, offset=None): 120 """ return a new source object deindented by offset. 121 If offset is None then guess an indentation offset from 122 the first non-blank line. Subsequent lines which have a 123 lower indentation offset will be copied verbatim as 124 they are assumed to be part of multilines. 125 """ 126 # XXX maybe use the tokenizer to properly handle multiline 127 # strings etc.pp? 128 newsource = Source() 129 newsource.lines[:] = deindent(self.lines, offset) 130 return newsource 131 132 def isparseable(self, deindent=True): 133 """ return True if source is parseable, heuristically 134 deindenting it by default. 135 """ 136 try: 137 import parser 138 except ImportError: 139 syntax_checker = lambda x: compile(x, 'asd', 'exec') 140 else: 141 syntax_checker = parser.suite 142 143 if deindent: 144 source = str(self.deindent()) 145 else: 146 source = str(self) 147 try: 148 #compile(source+'\n', "x", "exec") 149 syntax_checker(source+'\n') 150 except KeyboardInterrupt: 151 raise 152 except Exception: 153 return False 154 else: 155 return True 156 157 def __str__(self): 158 return "\n".join(self.lines) 159 160 def compile(self, filename=None, mode='exec', 161 flag=generators.compiler_flag, 162 dont_inherit=0, _genframe=None): 163 """ return compiled code object. if filename is None 164 invent an artificial filename which displays 165 the source/line position of the caller frame. 166 """ 167 if not filename or py.path.local(filename).check(file=0): 168 if _genframe is None: 169 _genframe = sys._getframe(1) # the caller 170 fn,lineno = _genframe.f_code.co_filename, _genframe.f_lineno 171 base = "<%d-codegen " % self._compilecounter 172 self.__class__._compilecounter += 1 173 if not filename: 174 filename = base + '%s:%d>' % (fn, lineno) 175 else: 176 filename = base + '%r %s:%d>' % (filename, fn, lineno) 177 source = "\n".join(self.lines) + '\n' 178 try: 179 co = cpy_compile(source, filename, mode, flag) 180 except SyntaxError: 181 ex = sys.exc_info()[1] 182 # re-represent syntax errors from parsing python strings 183 msglines = self.lines[:ex.lineno] 184 if ex.offset: 185 msglines.append(" "*ex.offset + '^') 186 msglines.append("(code was compiled probably from here: %s)" % filename) 187 newex = SyntaxError('\n'.join(msglines)) 188 newex.offset = ex.offset 189 newex.lineno = ex.lineno 190 newex.text = ex.text 191 raise newex 192 else: 193 if flag & _AST_FLAG: 194 return co 195 lines = [(x + "\n") for x in self.lines] 196 if sys.version_info[0] >= 3: 197 # XXX py3's inspect.getsourcefile() checks for a module 198 # and a pep302 __loader__ ... we don't have a module 199 # at code compile-time so we need to fake it here 200 m = ModuleType("_pycodecompile_pseudo_module") 201 py.std.inspect.modulesbyfile[filename] = None 202 py.std.sys.modules[None] = m 203 m.__loader__ = 1 204 py.std.linecache.cache[filename] = (1, None, lines, filename) 205 return co 206 207# 208# public API shortcut functions 209# 210 211def compile_(source, filename=None, mode='exec', flags= 212 generators.compiler_flag, dont_inherit=0): 213 """ compile the given source to a raw code object, 214 and maintain an internal cache which allows later 215 retrieval of the source code for the code object 216 and any recursively created code objects. 217 """ 218 if _ast is not None and isinstance(source, _ast.AST): 219 # XXX should Source support having AST? 220 return cpy_compile(source, filename, mode, flags, dont_inherit) 221 _genframe = sys._getframe(1) # the caller 222 s = Source(source) 223 co = s.compile(filename, mode, flags, _genframe=_genframe) 224 return co 225 226 227def getfslineno(obj): 228 """ Return source location (path, lineno) for the given object. 229 If the source cannot be determined return ("", -1) 230 """ 231 try: 232 code = py.code.Code(obj) 233 except TypeError: 234 try: 235 fn = (py.std.inspect.getsourcefile(obj) or 236 py.std.inspect.getfile(obj)) 237 except TypeError: 238 return "", -1 239 240 fspath = fn and py.path.local(fn) or None 241 lineno = -1 242 if fspath: 243 try: 244 _, lineno = findsource(obj) 245 except IOError: 246 pass 247 else: 248 fspath = code.path 249 lineno = code.firstlineno 250 assert isinstance(lineno, int) 251 return fspath, lineno 252 253# 254# helper functions 255# 256 257def findsource(obj): 258 try: 259 sourcelines, lineno = py.std.inspect.findsource(obj) 260 except py.builtin._sysex: 261 raise 262 except: 263 return None, -1 264 source = Source() 265 source.lines = [line.rstrip() for line in sourcelines] 266 return source, lineno 267 268def getsource(obj, **kwargs): 269 obj = py.code.getrawcode(obj) 270 try: 271 strsrc = inspect.getsource(obj) 272 except IndentationError: 273 strsrc = "\"Buggy python version consider upgrading, cannot get source\"" 274 assert isinstance(strsrc, str) 275 return Source(strsrc, **kwargs) 276 277def deindent(lines, offset=None): 278 if offset is None: 279 for line in lines: 280 line = line.expandtabs() 281 s = line.lstrip() 282 if s: 283 offset = len(line)-len(s) 284 break 285 else: 286 offset = 0 287 if offset == 0: 288 return list(lines) 289 newlines = [] 290 def readline_generator(lines): 291 for line in lines: 292 yield line + '\n' 293 while True: 294 yield '' 295 296 it = readline_generator(lines) 297 298 try: 299 for _, _, (sline, _), (eline, _), _ in tokenize.generate_tokens(lambda: next(it)): 300 if sline > len(lines): 301 break # End of input reached 302 if sline > len(newlines): 303 line = lines[sline - 1].expandtabs() 304 if line.lstrip() and line[:offset].isspace(): 305 line = line[offset:] # Deindent 306 newlines.append(line) 307 308 for i in range(sline, eline): 309 # Don't deindent continuing lines of 310 # multiline tokens (i.e. multiline strings) 311 newlines.append(lines[i]) 312 except (IndentationError, tokenize.TokenError): 313 pass 314 # Add any lines we didn't see. E.g. if an exception was raised. 315 newlines.extend(lines[len(newlines):]) 316 return newlines 317 318 319def get_statement_startend2(lineno, node): 320 import ast 321 # flatten all statements and except handlers into one lineno-list 322 # AST's line numbers start indexing at 1 323 l = [] 324 for x in ast.walk(node): 325 if isinstance(x, _ast.stmt) or isinstance(x, _ast.ExceptHandler): 326 l.append(x.lineno - 1) 327 for name in "finalbody", "orelse": 328 val = getattr(x, name, None) 329 if val: 330 # treat the finally/orelse part as its own statement 331 l.append(val[0].lineno - 1 - 1) 332 l.sort() 333 insert_index = bisect_right(l, lineno) 334 start = l[insert_index - 1] 335 if insert_index >= len(l): 336 end = None 337 else: 338 end = l[insert_index] 339 return start, end 340 341 342def getstatementrange_ast(lineno, source, assertion=False, astnode=None): 343 if astnode is None: 344 content = str(source) 345 if sys.version_info < (2,7): 346 content += "\n" 347 try: 348 astnode = compile(content, "source", "exec", 1024) # 1024 for AST 349 except ValueError: 350 start, end = getstatementrange_old(lineno, source, assertion) 351 return None, start, end 352 start, end = get_statement_startend2(lineno, astnode) 353 # we need to correct the end: 354 # - ast-parsing strips comments 355 # - there might be empty lines 356 # - we might have lesser indented code blocks at the end 357 if end is None: 358 end = len(source.lines) 359 360 if end > start + 1: 361 # make sure we don't span differently indented code blocks 362 # by using the BlockFinder helper used which inspect.getsource() uses itself 363 block_finder = inspect.BlockFinder() 364 # if we start with an indented line, put blockfinder to "started" mode 365 block_finder.started = source.lines[start][0].isspace() 366 it = ((x + "\n") for x in source.lines[start:end]) 367 try: 368 for tok in tokenize.generate_tokens(lambda: next(it)): 369 block_finder.tokeneater(*tok) 370 except (inspect.EndOfBlock, IndentationError): 371 end = block_finder.last + start 372 except Exception: 373 pass 374 375 # the end might still point to a comment or empty line, correct it 376 while end: 377 line = source.lines[end - 1].lstrip() 378 if line.startswith("#") or not line: 379 end -= 1 380 else: 381 break 382 return astnode, start, end 383 384 385def getstatementrange_old(lineno, source, assertion=False): 386 """ return (start, end) tuple which spans the minimal 387 statement region which containing the given lineno. 388 raise an IndexError if no such statementrange can be found. 389 """ 390 # XXX this logic is only used on python2.4 and below 391 # 1. find the start of the statement 392 from codeop import compile_command 393 for start in range(lineno, -1, -1): 394 if assertion: 395 line = source.lines[start] 396 # the following lines are not fully tested, change with care 397 if 'super' in line and 'self' in line and '__init__' in line: 398 raise IndexError("likely a subclass") 399 if "assert" not in line and "raise" not in line: 400 continue 401 trylines = source.lines[start:lineno+1] 402 # quick hack to prepare parsing an indented line with 403 # compile_command() (which errors on "return" outside defs) 404 trylines.insert(0, 'def xxx():') 405 trysource = '\n '.join(trylines) 406 # ^ space here 407 try: 408 compile_command(trysource) 409 except (SyntaxError, OverflowError, ValueError): 410 continue 411 412 # 2. find the end of the statement 413 for end in range(lineno+1, len(source)+1): 414 trysource = source[start:end] 415 if trysource.isparseable(): 416 return start, end 417 raise SyntaxError("no valid source range around line %d " % (lineno,)) 418 419 420