1from __future__ import generators
2
3from bisect import bisect_right
4import sys
5import inspect, tokenize
6import py
7from types import ModuleType
8cpy_compile = compile
9
10try:
11    import _ast
12    from _ast import PyCF_ONLY_AST as _AST_FLAG
13except ImportError:
14    _AST_FLAG = 0
15    _ast = None
16
17
class Source(object):
    """ an immutable object holding a source code fragment,
        possibly deindenting it.

        The fragment is kept in the ``lines`` attribute as a list of
        lines without trailing newline characters.
    """
    # running number embedded in the artificial filenames that
    # compile() invents
    _compilecounter = 0

    def __init__(self, *parts, **kwargs):
        """ build the fragment by concatenating the lines of all parts.

            Each part may be another Source, a list/tuple of lines,
            a string, or any other object which is then handed to
            getsource().  Falsy parts contribute no lines.

            keyword arguments:
              deindent -- deindent the lines of each part (default True)
              rstrip   -- drop trailing blank lines of string parts
                          (default True)
        """
        self.lines = lines = []
        de = kwargs.get('deindent', True)
        rstrip = kwargs.get('rstrip', True)
        for part in parts:
            if not part:
                # Falsy parts (None, empty sequences, empty Source
                # objects) have nothing to contribute and must not be
                # handed to getsource().  An empty string counts as one
                # blank line, which only survives if rstrip is disabled.
                if rstrip or not isinstance(part, py.builtin._basestring):
                    partlines = []
                else:
                    partlines = ['']
            elif isinstance(part, Source):
                partlines = part.lines
            elif isinstance(part, (tuple, list)):
                partlines = [x.rstrip("\n") for x in part]
            elif isinstance(part, py.builtin._basestring):
                partlines = part.split('\n')
                if rstrip:
                    # remove trailing blank lines
                    while partlines:
                        if partlines[-1].strip():
                            break
                        partlines.pop()
            else:
                partlines = getsource(part, deindent=de).lines
            if de:
                partlines = deindent(partlines)
            lines.extend(partlines)

    def __eq__(self, other):
        """ compare line-wise with objects having a ``lines`` attribute
            and text-wise with plain strings; anything else is unequal.
        """
        try:
            return self.lines == other.lines
        except AttributeError:
            if isinstance(other, str):
                return str(self) == other
            return False

    def __getitem__(self, key):
        """ return a single line for an integer key and a new
            Source object for a slice (a step other than 1 raises
            an IndexError).
        """
        if isinstance(key, int):
            return self.lines[key]
        else:
            if key.step not in (None, 1):
                raise IndexError("cannot slice a Source with a step")
            return self.__getslice__(key.start, key.stop)

    def __len__(self):
        """ return the number of lines. """
        return len(self.lines)

    def __getslice__(self, start, end):
        """ return a new Source holding the lines [start:end]. """
        newsource = Source()
        newsource.lines = self.lines[start:end]
        return newsource

    def strip(self):
        """ return new source object with trailing
            and leading blank lines removed.
        """
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end-1].strip():
            end -= 1
        source = Source()
        source.lines[:] = self.lines[start:end]
        return source

    def putaround(self, before='', after='', indent=' ' * 4):
        """ return a copy of the source object with
            'before' and 'after' wrapped around it.
            The lines of this object are prefixed with 'indent',
            the lines of 'before' and 'after' are not.
        """
        before = Source(before)
        after = Source(after)
        newsource = Source()
        lines = [ (indent + line) for line in self.lines]
        newsource.lines = before.lines + lines +  after.lines
        return newsource

    def indent(self, indent=' ' * 4):
        """ return a copy of the source object with
            all lines indented by the given indent-string.
        """
        newsource = Source()
        newsource.lines = [(indent+line) for line in self.lines]
        return newsource

    def getstatement(self, lineno, assertion=False):
        """ return Source statement which contains the
            given linenumber (counted from 0).
        """
        start, end = self.getstatementrange(lineno, assertion)
        return self[start:end]

    def getstatementrange(self, lineno, assertion=False):
        """ return (start, end) tuple which spans the minimal
            statement region which contains the given lineno.
            Raise an IndexError if lineno is out of range.

            'assertion' is accepted for compatibility but is not
            forwarded to the AST based lookup.
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        ast, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self, offset=None):
        """ return a new source object deindented by offset.
            If offset is None then guess an indentation offset from
            the first non-blank line.  Subsequent lines which have a
            lower indentation offset will be copied verbatim as
            they are assumed to be part of multilines.
        """
        # XXX maybe use the tokenizer to properly handle multiline
        #     strings etc.pp?
        newsource = Source()
        newsource.lines[:] = deindent(self.lines, offset)
        return newsource

    def isparseable(self, deindent=True):
        """ return True if source is parseable, heuristically
            deindenting it by default.
        """
        # prefer the 'parser' module where it exists and fall back
        # to the builtin compile() otherwise
        try:
            import parser
        except ImportError:
            syntax_checker = lambda x: compile(x, 'asd', 'exec')
        else:
            syntax_checker = parser.suite

        if deindent:
            source = str(self.deindent())
        else:
            source = str(self)
        try:
            syntax_checker(source+'\n')
        except KeyboardInterrupt:
            raise
        except Exception:
            return False
        else:
            return True

    def __str__(self):
        """ return the lines joined by newlines (no trailing newline). """
        return "\n".join(self.lines)

    def compile(self, filename=None, mode='exec',
                flag=generators.compiler_flag,
                dont_inherit=0, _genframe=None):
        """ return compiled code object. if filename is None
            invent an artificial filename which displays
            the source/line position of the caller frame.

            'mode', 'flag' and 'dont_inherit' are passed on to the
            builtin compile().  '_genframe' is the frame whose
            position is shown in the invented filename; it defaults
            to the frame of the caller.

            Unless an AST was requested through 'flag', the lines are
            stored in the linecache under the used filename.

            A SyntaxError is re-raised as a new SyntaxError whose
            message shows the source lines up to the error position.
        """
        # NOTE(review): check(file=0) presumably means "does not name
        # an existing file" -- confirm against py.path.local
        if not filename or py.path.local(filename).check(file=0):
            if _genframe is None:
                _genframe = sys._getframe(1) # the caller
            fn,lineno = _genframe.f_code.co_filename, _genframe.f_lineno
            base = "<%d-codegen " % self._compilecounter
            self.__class__._compilecounter += 1
            if not filename:
                filename = base + '%s:%d>' % (fn, lineno)
            else:
                filename = base + '%r %s:%d>' % (filename, fn, lineno)
        source = "\n".join(self.lines) + '\n'
        try:
            # dont_inherit must be handed over as well, otherwise
            # the caller's choice is silently ignored
            co = cpy_compile(source, filename, mode, flag, dont_inherit)
        except SyntaxError:
            ex = sys.exc_info()[1]
            # re-represent syntax errors from parsing python strings
            msglines = self.lines[:ex.lineno]
            if ex.offset:
                msglines.append(" "*ex.offset + '^')
            msglines.append("(code was compiled probably from here: %s)" % filename)
            newex = SyntaxError('\n'.join(msglines))
            newex.offset = ex.offset
            newex.lineno = ex.lineno
            newex.text = ex.text
            raise newex
        else:
            if flag & _AST_FLAG:
                # an AST was requested, there is no code object
                # to register source lines for
                return co
            lines = [(x + "\n") for x in self.lines]
            if sys.version_info[0] >= 3:
                # XXX py3's inspect.getsourcefile() checks for a module
                # and a pep302 __loader__ ... we don't have a module
                # at code compile-time so we need to fake it here
                m = ModuleType("_pycodecompile_pseudo_module")
                py.std.inspect.modulesbyfile[filename] = None
                py.std.sys.modules[None] = m
                m.__loader__ = 1
            py.std.linecache.cache[filename] = (1, None, lines, filename)
            return co
206
207#
208# public API shortcut functions
209#
210
def compile_(source, filename=None, mode='exec', flags=
            generators.compiler_flag, dont_inherit=0):
    """ compile the given source to a raw code object,
        and maintain an internal cache which allows later
        retrieval of the source code for the code object
        and any recursively created code objects.

        If 'source' is an AST node it is handed directly to the
        builtin compile(); otherwise it is wrapped into a Source
        object whose compile() method does the work, using the
        frame of our caller for inventing a filename.
    """
    if _ast is not None and isinstance(source, _ast.AST):
        # XXX should Source support having AST?
        return cpy_compile(source, filename, mode, flags, dont_inherit)
    _genframe = sys._getframe(1) # the caller
    s = Source(source)
    # hand over dont_inherit as well so that both code paths
    # treat the argument the same way
    co = s.compile(filename, mode, flags,
                   dont_inherit=dont_inherit, _genframe=_genframe)
    return co
225
226
def getfslineno(obj):
    """ Return source location (path, lineno) for the given object.

    Objects accepted by py.code.Code() report the path and first line
    number of that code object.  For all others the file is looked up
    through the inspect module and the line number through
    findsource(); the line number stays -1 if it cannot be found.
    If inspect cannot determine a file at all return ("", -1).
    """
    try:
        code = py.code.Code(obj)
    except TypeError:
        code = None

    if code is not None:
        fspath, lineno = code.path, code.firstlineno
    else:
        inspect_mod = py.std.inspect
        try:
            fn = inspect_mod.getsourcefile(obj) or inspect_mod.getfile(obj)
        except TypeError:
            return "", -1

        fspath = fn and py.path.local(fn) or None
        lineno = -1
        if fspath:
            try:
                lineno = findsource(obj)[1]
            except IOError:
                pass
    assert isinstance(lineno, int)
    return fspath, lineno
252
253#
254# helper functions
255#
256
def findsource(obj):
    """ return a (Source, lineno) tuple for the given object as
        reported by inspect.findsource(), with trailing whitespace
        stripped from every line.  Return (None, -1) if the lookup
        fails; the exceptions in py.builtin._sysex are re-raised.
    """
    try:
        found = py.std.inspect.findsource(obj)
    except py.builtin._sysex:
        raise
    except:
        return None, -1
    rawlines, lineno = found
    result = Source()
    result.lines = [raw.rstrip() for raw in rawlines]
    return result, lineno
267
def getsource(obj, **kwargs):
    """ return a Source object for the given object; keyword
        arguments are passed on to the Source constructor.

        If inspect.getsource() fails with an IndentationError a
        placeholder string is used as source text instead.
    """
    rawcode = py.code.getrawcode(obj)
    try:
        text = inspect.getsource(rawcode)
    except IndentationError:
        text = "\"Buggy python version consider upgrading, cannot get source\""
    assert isinstance(text, str)
    return Source(text, **kwargs)
276
def deindent(lines, offset=None):
    """ return a new list of lines with 'offset' leading columns
        removed.

        If offset is None it is guessed from the indentation of the
        first non-blank line (tabs expanded).  A line is only shifted
        if it is not blank and its first 'offset' columns consist of
        whitespace.  Continuation lines of multiline tokens (i.e.
        multiline strings) are copied unchanged.  If the lines cannot
        be tokenized, the remaining lines are copied unchanged.
    """
    if offset is None:
        for line in lines:
            line = line.expandtabs()
            s = line.lstrip()
            if s:
                offset = len(line)-len(s)
                break
        else:
            offset = 0
    if offset == 0:
        return list(lines)
    newlines = []
    def readline_generator(lines):
        # readline-like input for the tokenizer: every line once,
        # followed by an endless stream of '' signalling EOF
        for line in lines:
            yield line + '\n'
        while True:
            yield ''

    it = readline_generator(lines)

    try:
        for _, _, (sline, _), (eline, _), _ in tokenize.generate_tokens(lambda: next(it)):
            if sline > len(lines):
                break # End of input reached
            # Catch up to the line on which this token starts.  Usually
            # this is a single new line, but lines on which no token
            # starts (e.g. a line holding only a continuation backslash)
            # must be taken over as well, otherwise all following lines
            # would end up at the wrong position.
            while len(newlines) < sline:
                line = lines[len(newlines)].expandtabs()
                if line.lstrip() and line[:offset].isspace():
                    line = line[offset:] # Deindent
                newlines.append(line)

            for i in range(sline, eline):
                # Don't deindent continuing lines of
                # multiline tokens (i.e. multiline strings)
                newlines.append(lines[i])
    except (IndentationError, tokenize.TokenError):
        pass
    # Add any lines we didn't see. E.g. if an exception was raised.
    newlines.extend(lines[len(newlines):])
    return newlines
317
318
def get_statement_startend2(lineno, node):
    """ return a (start, end) tuple of 0-based line numbers: 'start'
        is the last recorded statement start which is not after
        'lineno', 'end' is the following recorded statement start
        or None if there is none.
    """
    import ast
    # collect the 0-based start lines of all statements and except
    # handlers (the AST counts lines from 1)
    starts = []
    for child in ast.walk(node):
        if not isinstance(child, (_ast.stmt, _ast.ExceptHandler)):
            continue
        starts.append(child.lineno - 1)
        for attr in ("finalbody", "orelse"):
            block = getattr(child, attr, None)
            if block:
                # the finally/orelse part counts as a statement of its
                # own, starting on the line before its first statement
                starts.append(block[0].lineno - 2)
    starts.sort()
    pos = bisect_right(starts, lineno)
    start = starts[pos - 1]
    end = starts[pos] if pos < len(starts) else None
    return start, end
340
341
def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
    """ return an (astnode, start, end) tuple where [start:end] is the
        line range of the statement containing 'lineno'.

        If no 'astnode' is given the source is compiled to an AST
        first; if that fails with a ValueError the range is computed
        by getstatementrange_old() and None is returned as astnode.
    """
    if astnode is None:
        text = str(source)
        if sys.version_info < (2,7):
            text += "\n"
        try:
            astnode = compile(text, "source", "exec", 1024)  # 1024 for AST
        except ValueError:
            start, end = getstatementrange_old(lineno, source, assertion)
            return None, start, end
    start, end = get_statement_startend2(lineno, astnode)
    # the end needs correcting because
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # make sure we don't span differently indented code blocks
        # by using the BlockFinder helper of the inspect module
        finder = inspect.BlockFinder()
        # an indented first line puts the finder into "started" mode
        finder.started = source.lines[start][0].isspace()
        chunk = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(lambda: next(chunk)):
                finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = finder.last + start
        except Exception:
            pass

    # step back over trailing comment and empty lines
    while end:
        candidate = source.lines[end - 1].lstrip()
        if candidate and not candidate.startswith("#"):
            break
        end -= 1
    return astnode, start, end
383
384
def getstatementrange_old(lineno, source, assertion=False):
    """ return (start, end) tuple which spans the minimal
        statement region which contains the given lineno.
        raise a SyntaxError if no such statementrange can be found.

        With 'assertion' set only lines mentioning "assert" or "raise"
        are tried as start lines, and an IndexError is raised for a
        line which looks like a super().__init__ call.
    """
    # XXX this logic is only used on python2.4 and below
    from codeop import compile_command
    # walk backwards from lineno looking for a start line from which
    # the text up to lineno is accepted by compile_command()
    for start in range(lineno, -1, -1):
        if assertion:
            line = source.lines[start]
            # the following lines are not fully tested, change with care
            if 'super' in line and 'self' in line and '__init__' in line:
                raise IndexError("likely a subclass")
            if "assert" not in line and "raise" not in line:
                continue
        # quick hack: put the lines into a function body so that
        # compile_command() accepts indented lines and a "return"
        # (note the space after the newline in the join string)
        candidate = ['def xxx():'] + source.lines[start:lineno+1]
        try:
            compile_command('\n '.join(candidate))
        except (SyntaxError, OverflowError, ValueError):
            continue

        # walk forwards looking for the first end which makes
        # the region parseable
        for end in range(lineno+1, len(source)+1):
            if source[start:end].isparseable():
                return start, end
    raise SyntaxError("no valid source range around line %d " % (lineno,))
418
419
420