1from __future__ import generators
2
3from bisect import bisect_right
4import sys
5import inspect, tokenize
6import py
7from types import ModuleType
8cpy_compile = compile
9
10try:
11    import _ast
12    from _ast import PyCF_ONLY_AST as _AST_FLAG
13except ImportError:
14    _AST_FLAG = 0
15    _ast = None
16
17
class Source(object):
    """ an immutable object holding a source code fragment,
        possibly deindenting it.

        The fragment is kept as a list of lines (without trailing
        newlines) in ``self.lines``; methods that transform it return
        new Source objects.
    """
    _compilecounter = 0

    def __init__(self, *parts, **kwargs):
        """ build the fragment by concatenating the lines of all *parts*.

            A part may be a Source, a tuple/list of lines (a trailing
            newline on each item is stripped), a string (split on
            newlines) or any other object whose source is looked up via
            ``getsource()``.  Other falsy parts such as ``None``
            contribute no lines.

            Keyword arguments:
              deindent -- deindent each part (default True)
              rstrip   -- drop trailing blank lines of string parts
                          (default True)
        """
        self.lines = lines = []
        de = kwargs.get('deindent', True)
        rstrip = kwargs.get('rstrip', True)
        for part in parts:
            if isinstance(part, Source):
                partlines = part.lines
            elif isinstance(part, (tuple, list)):
                partlines = [x.rstrip("\n") for x in part]
            elif isinstance(part, py.builtin._basestring):
                partlines = part.split('\n')
                if rstrip:
                    while partlines:
                        if partlines[-1].strip():
                            break
                        partlines.pop()
            elif not part:
                # falsy non-text parts (e.g. None) used to fall through to
                # getsource() and fail there; they simply add no lines
                partlines = []
            else:
                partlines = getsource(part, deindent=de).lines
            if de:
                partlines = deindent(partlines)
            lines.extend(partlines)

    def __eq__(self, other):
        """ compare line lists with another Source-like object, or the
            joined text with a plain ``str``; anything else is unequal.
        """
        try:
            return self.lines == other.lines
        except AttributeError:
            if isinstance(other, str):
                return str(self) == other
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        """ return a single line for an int key, or a new Source for a
            slice (slices with a step other than 1 raise IndexError).
        """
        if isinstance(key, int):
            return self.lines[key]
        else:
            if key.step not in (None, 1):
                raise IndexError("cannot slice a Source with a step")
            return self.__getslice__(key.start, key.stop)

    def __len__(self):
        return len(self.lines)

    def __getslice__(self, start, end):
        newsource = Source()
        newsource.lines = self.lines[start:end]
        return newsource

    def strip(self):
        """ return new source object with trailing
            and leading blank lines removed.
        """
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end-1].strip():
            end -= 1
        source = Source()
        source.lines[:] = self.lines[start:end]
        return source

    def putaround(self, before='', after='', indent=' ' * 4):
        """ return a copy of the source object with
            'before' and 'after' wrapped around it.
            The wrapped lines are indented by *indent*.
        """
        before = Source(before)
        after = Source(after)
        newsource = Source()
        lines = [(indent + line) for line in self.lines]
        newsource.lines = before.lines + lines + after.lines
        return newsource

    def indent(self, indent=' ' * 4):
        """ return a copy of the source object with
            all lines indented by the given indent-string.
        """
        newsource = Source()
        newsource.lines = [(indent+line) for line in self.lines]
        return newsource

    def getstatement(self, lineno, assertion=False):
        """ return Source statement which contains the
            given linenumber (counted from 0).
        """
        start, end = self.getstatementrange(lineno, assertion)
        return self[start:end]

    def getstatementrange(self, lineno, assertion=False):
        """ return (start, end) tuple which spans the minimal
            statement region which contains the given lineno.

            Raises IndexError if lineno is out of range.  *assertion* is
            forwarded to the range finder (it only affects the line
            based fallback used when AST compilation fails).
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        ast, start, end = getstatementrange_ast(lineno, self, assertion)
        return start, end

    def deindent(self, offset=None):
        """ return a new source object deindented by offset.
            If offset is None then guess an indentation offset from
            the first non-blank line.  Subsequent lines which have a
            lower indentation offset will be copied verbatim as
            they are assumed to be part of multilines.
        """
        # XXX maybe use the tokenizer to properly handle multiline
        #     strings etc.pp?
        newsource = Source()
        newsource.lines[:] = deindent(self.lines, offset)
        return newsource

    def isparseable(self, deindent=True):
        """ return True if source is parseable, heuristically
            deindenting it by default.
        """
        try:
            import parser
        except ImportError:
            # the parser module is not available everywhere
            syntax_checker = lambda x: compile(x, 'asd', 'exec')
        else:
            syntax_checker = parser.suite

        if deindent:
            source = str(self.deindent())
        else:
            source = str(self)
        try:
            syntax_checker(source+'\n')
        except KeyboardInterrupt:
            raise
        except Exception:
            return False
        else:
            return True

    def __str__(self):
        return "\n".join(self.lines)

    def compile(self, filename=None, mode='exec',
                flag=generators.compiler_flag,
                dont_inherit=0, _genframe=None):
        """ return compiled code object (or an AST when *flag* contains
            the AST-only compile flag).

            If filename is None, or does not name an existing file,
            invent an artificial filename which displays the source/line
            position of the caller frame.  For code objects the source
            lines are registered in ``linecache`` under that filename so
            they can be looked up later.  SyntaxErrors are re-raised with
            the offending source lines in the message.
        """
        if not filename or py.path.local(filename).check(file=0):
            if _genframe is None:
                _genframe = sys._getframe(1) # the caller
            fn,lineno = _genframe.f_code.co_filename, _genframe.f_lineno
            base = "<%d-codegen " % self._compilecounter
            self.__class__._compilecounter += 1
            if not filename:
                filename = base + '%s:%d>' % (fn, lineno)
            else:
                filename = base + '%r %s:%d>' % (filename, fn, lineno)
        source = "\n".join(self.lines) + '\n'
        try:
            co = cpy_compile(source, filename, mode, flag, dont_inherit)
        except SyntaxError:
            ex = sys.exc_info()[1]
            # re-represent syntax errors from parsing python strings
            msglines = self.lines[:ex.lineno]
            if ex.offset:
                msglines.append(" "*ex.offset + '^')
            msglines.append("(code was compiled probably from here: %s)" % filename)
            newex = SyntaxError('\n'.join(msglines))
            newex.offset = ex.offset
            newex.lineno = ex.lineno
            newex.text = ex.text
            raise newex
        else:
            if flag & _AST_FLAG:
                return co
            lines = [(x + "\n") for x in self.lines]
            import linecache
            linecache.cache[filename] = (1, None, lines, filename)
            return co
199
200#
201# public API shortcut functions
202#
203
def compile_(source, filename=None, mode='exec', flags=
            generators.compiler_flag, dont_inherit=0):
    """ compile the given source to a raw code object,
        and maintain an internal cache which allows later
        retrieval of the source code for the code object
        and any recursively created code objects.

        An AST node is compiled directly with the builtin compile();
        anything else is wrapped in a Source and compiled through
        Source.compile(), which receives the caller's frame so it can
        invent a descriptive filename.  *flags* and *dont_inherit* are
        forwarded on both paths.
    """
    if _ast is not None and isinstance(source, _ast.AST):
        # XXX should Source support having AST?
        return cpy_compile(source, filename, mode, flags, dont_inherit)
    _genframe = sys._getframe(1) # the caller
    s = Source(source)
    # forward dont_inherit too, so it is not silently dropped here
    co = s.compile(filename, mode, flags, dont_inherit, _genframe=_genframe)
    return co
218
219
def getfslineno(obj):
    """ Return the source location (path, lineno) of *obj*.

    Objects accepted by py.code.Code report their own path and first
    line number.  For everything else the file is looked up through
    inspect; if that lookup raises TypeError, ("", -1) is returned.
    When a file is found but its line cannot be determined, lineno
    is -1.
    """
    try:
        code = py.code.Code(obj)
    except TypeError:
        code = None

    if code is not None:
        fspath, lineno = code.path, code.firstlineno
    else:
        try:
            filename = inspect.getsourcefile(obj) or inspect.getfile(obj)
        except TypeError:
            return "", -1

        fspath = None
        if filename:
            fspath = py.path.local(filename) or None

        lineno = -1
        if fspath:
            try:
                _unused_source, lineno = findsource(obj)
            except IOError:
                pass

    assert isinstance(lineno, int)
    return fspath, lineno
245
246#
247# helper functions
248#
249
def findsource(obj):
    """ return (Source, lineno) built from inspect.findsource(obj).

        The Source holds the lines returned by inspect.findsource() with
        trailing whitespace stripped; lineno is the line index it reports.
        Returns (None, -1) if the source cannot be found.  Exceptions
        matching py.builtin._sysex are re-raised.
    """
    try:
        sourcelines, lineno = inspect.findsource(obj)
    except py.builtin._sysex:
        raise
    except Exception:
        # best effort: any ordinary failure means "no source available"
        return None, -1
    source = Source()
    source.lines = [line.rstrip() for line in sourcelines]
    return source, lineno
260
def getsource(obj, **kwargs):
    """ return a Source object holding the source code of *obj*.

        *obj* is first normalized with py.code.getrawcode(); any keyword
        arguments (e.g. ``deindent``) are passed on to Source().  When
        inspect raises IndentationError, a placeholder string literal is
        used as the source instead.
    """
    rawcode = py.code.getrawcode(obj)
    try:
        text = inspect.getsource(rawcode)
    except IndentationError:
        # placeholder source for interpreters whose inspect fails here
        text = "\"Buggy python version consider upgrading, cannot get source\""
    assert isinstance(text, str)
    return Source(text, **kwargs)
269
def deindent(lines, offset=None):
    """ return a new list with *offset* columns of leading whitespace
        removed from the given *lines*.

        If *offset* is None it is taken from the first non-blank line
        (measured after tab expansion); if every line is blank it is 0.
        With an offset of 0 a plain copy of *lines* is returned.

        Otherwise the lines are fed through the tokenizer so that
        continuation lines of multi-line tokens (e.g. triple-quoted
        strings) are copied verbatim.  A line on which a token starts is
        tab-expanded and, if it is non-blank and its first *offset*
        characters are all whitespace, deindented.  If tokenizing stops
        early (IndentationError, tokenize.TokenError) the remaining lines
        are appended unchanged.
    """
    if offset is None:
        for line in lines:
            line = line.expandtabs()
            s = line.lstrip()
            if s:
                offset = len(line)-len(s)
                break
        else:
            offset = 0
    if offset == 0:
        return list(lines)
    newlines = []
    def readline_generator(lines):
        # hand the lines to the tokenizer with their newline restored;
        # afterwards keep returning '' (end of input for tokenize)
        for line in lines:
            yield line + '\n'
        while True:
            yield ''

    it = readline_generator(lines)

    try:
        # sline/eline are the 1-based start/end rows of each token
        for _, _, (sline, _), (eline, _), _ in tokenize.generate_tokens(lambda: next(it)):
            if sline > len(lines):
                break # End of input reached
            if sline > len(newlines):
                # first token on a row not emitted yet: emit that row.
                # NOTE(review): this assumes rows are reached in order
                # without gaps, since the row is appended at index
                # len(newlines) rather than at sline - 1.
                line = lines[sline - 1].expandtabs()
                if line.lstrip() and line[:offset].isspace():
                    line = line[offset:] # Deindent
                newlines.append(line)

            for i in range(sline, eline):
                # Don't deindent continuing lines of
                # multiline tokens (e.g. multiline strings)
                newlines.append(lines[i])
    except (IndentationError, tokenize.TokenError):
        pass
    # Add any lines we didn't see. E.g. if an exception was raised.
    newlines.extend(lines[len(newlines):])
    return newlines
310
311
def get_statement_startend2(lineno, node):
    """ return (start, end) 0-based line numbers of the statement region
        in the AST *node* that contains the 0-based *lineno*.

        Every statement and except handler starts a new region.  Decorator
        lines and the line before a ``finally``/``else`` block also start
        regions.  *end* is the start of the following region, or None
        for the last one.  Lines before the first statement (or any line
        when there is no statement at all) map to a region starting at 0.
    """
    import ast
    # flatten all region starts into one sorted lineno-list;
    # AST line numbers start counting at 1
    starts = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            # on Python 3.8+ a decorated def/class reports the line of the
            # "def"/"class" keyword, so add the decorator lines explicitly
            for deco in getattr(x, "decorator_list", None) or ():
                starts.append(deco.lineno - 1)
            starts.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # treat the finally/orelse part as its own statement;
                    # assumes the keyword sits on the line right before
                    # the block's first statement
                    starts.append(val[0].lineno - 1 - 1)
    starts.sort()
    insert_index = bisect_right(starts, lineno)
    if insert_index == 0:
        # lineno precedes every statement (or there are none); do not
        # wrap around to starts[-1]
        start = 0
    else:
        start = starts[insert_index - 1]
    end = starts[insert_index] if insert_index < len(starts) else None
    return start, end
333
334
def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
    """ return (astnode, start, end) for the statement containing *lineno*.

        *lineno* is a 0-based index into source.lines and the statement is
        source.lines[start:end].  If *astnode* is None the source is parsed
        here.  When compile() raises ValueError, the line based
        getstatementrange_old() is used instead (with *assertion*), and
        astnode is None in the result.
    """
    if astnode is None:
        content = str(source)
        try:
            astnode = compile(content, "source", "exec", 1024)  # 1024 for AST
        except ValueError:
            start, end = getstatementrange_old(lineno, source, assertion)
            return None, start, end
    start, end = get_statement_startend2(lineno, astnode)
    # we need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # make sure we don't span differently indented code blocks
        # by using the BlockFinder helper which inspect.getsource() uses itself
        block_finder = inspect.BlockFinder()
        # if we start with an indented line, put blockfinder to "started"
        # mode; an empty first line has no first character to inspect
        first = source.lines[start]
        block_finder.started = bool(first) and first[0].isspace()
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(lambda: next(it)):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = block_finder.last + start
        except Exception:
            # best effort: keep the unrefined end on any other failure
            pass

    # the end might still point to a comment or empty line, correct it,
    # but never move it before the start of the statement
    while end > start:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end
374
375
def getstatementrange_old(lineno, source, assertion=False):
    """ return (start, end) tuple which spans the minimal
        statement region which contains the given lineno.

        Line based fallback for getstatementrange_ast().  Candidate
        start lines are tried from *lineno* upwards; for each plausible
        start the end is grown until the slice parses.  With *assertion*
        only lines mentioning "assert" or "raise" are candidates, and an
        IndexError is raised for a line that looks like a
        super(...).__init__ call.  Raises SyntaxError when no range is
        found.
    """
    # XXX originally written for python2.4 and below
    from codeop import compile_command
    for start in range(lineno, -1, -1):
        if assertion:
            candidate = source.lines[start]
            # the following heuristics are not fully tested, change with care
            if ('super' in candidate and 'self' in candidate
                    and '__init__' in candidate):
                raise IndexError("likely a subclass")
            if not ("assert" in candidate or "raise" in candidate):
                continue
        # compile_command() errors on "return" outside defs, so wrap the
        # candidate lines into the body of a dummy function (note the
        # space after each newline in the join)
        wrapped = '\n '.join(['def xxx():'] + source.lines[start:lineno + 1])
        try:
            compile_command(wrapped)
        except (SyntaxError, OverflowError, ValueError):
            continue

        # plausible start found: grow the end until the slice parses
        for end in range(lineno + 1, len(source) + 1):
            if source[start:end].isparseable():
                return start, end
    raise SyntaxError("no valid source range around line %d " % (lineno,))
409
410
411