1from __future__ import absolute_import, division, generators, print_function
2
3from bisect import bisect_right
4import sys
5import inspect, tokenize
6import py
7cpy_compile = compile
8
9try:
10    import _ast
11    from _ast import PyCF_ONLY_AST as _AST_FLAG
12except ImportError:
13    _AST_FLAG = 0
14    _ast = None
15
16
class Source(object):
    """ an immutable object holding a source code fragment,
        possibly deindenting it.

        The fragment is stored in ``lines`` as a list of strings without
        trailing newlines.  Transforming methods (strip, indent, putaround,
        deindent, slicing) return a new Source instead of modifying this one.
    """
    # incremented by compile() so that generated filenames are unique
    _compilecounter = 0

    def __init__(self, *parts, **kwargs):
        """ Concatenate ``parts`` into one fragment.

            Each part may be a falsy value (contributes no lines), another
            Source, a list/tuple of lines (trailing "\\n" removed), a string
            (split on "\\n"), or any other object, whose source is fetched
            with getsource().

            Keyword arguments:
            deindent -- deindent every part (default True).
            rstrip -- drop trailing blank lines from string parts
                      (default True).
        """
        self.lines = lines = []
        de = kwargs.get('deindent', True)
        rstrip = kwargs.get('rstrip', True)
        for part in parts:
            # The chain must start with ``elif`` after this check: with a
            # plain ``if`` a falsy part such as None fell through to
            # getsource() and crashed, and "" with rstrip=False produced a
            # spurious empty line.
            if not part:
                partlines = []
            elif isinstance(part, Source):
                partlines = part.lines
            elif isinstance(part, (tuple, list)):
                partlines = [x.rstrip("\n") for x in part]
            elif isinstance(part, py.builtin._basestring):
                partlines = part.split('\n')
                if rstrip:
                    while partlines:
                        if partlines[-1].strip():
                            break
                        partlines.pop()
            else:
                partlines = getsource(part, deindent=de).lines
            if de:
                partlines = deindent(partlines)
            lines.extend(partlines)

    def __eq__(self, other):
        """ Equal to another object with the same ``lines``, or to a plain
            ``str`` equal to ``str(self)``.
        """
        try:
            return self.lines == other.lines
        except AttributeError:
            if isinstance(other, str):
                return str(self) == other
            return False

    # mutable ``lines`` plus custom __eq__: instances are unhashable
    __hash__ = None

    def __getitem__(self, key):
        """ An int key returns that line; a slice (step None or 1) returns
            a new Source holding the selected lines.
        """
        if isinstance(key, int):
            return self.lines[key]
        else:
            if key.step not in (None, 1):
                raise IndexError("cannot slice a Source with a step")
            newsource = Source()
            newsource.lines = self.lines[key.start:key.stop]
            return newsource

    def __len__(self):
        return len(self.lines)

    def strip(self):
        """ return new source object with trailing
            and leading blank lines removed.
        """
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end-1].strip():
            end -= 1
        source = Source()
        source.lines[:] = self.lines[start:end]
        return source

    def putaround(self, before='', after='', indent=' ' * 4):
        """ return a copy of the source object with
            'before' and 'after' wrapped around it; the original lines
            are prefixed with ``indent``.
        """
        before = Source(before)
        after = Source(after)
        newsource = Source()
        lines = [(indent + line) for line in self.lines]
        newsource.lines = before.lines + lines + after.lines
        return newsource

    def indent(self, indent=' ' * 4):
        """ return a copy of the source object with
            all lines indented by the given indent-string.
        """
        newsource = Source()
        newsource.lines = [(indent + line) for line in self.lines]
        return newsource

    def getstatement(self, lineno, assertion=False):
        """ return Source statement which contains the
            given linenumber (counted from 0).
        """
        start, end = self.getstatementrange(lineno, assertion)
        return self[start:end]

    def getstatementrange(self, lineno, assertion=False):
        """ return (start, end) tuple which spans the minimal
            statement region containing the given lineno (counted from 0).

            Raises IndexError if lineno is outside the fragment.
            ``assertion`` is accepted for API compatibility but is not
            forwarded to getstatementrange_ast().
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        ast, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self, offset=None):
        """ return a new source object deindented by offset.
            If offset is None then guess an indentation offset from
            the first non-blank line.  Subsequent lines which have a
            lower indentation offset will be copied verbatim as
            they are assumed to be part of multilines.
        """
        # XXX maybe use the tokenizer to properly handle multiline
        #     strings etc.pp?
        newsource = Source()
        newsource.lines[:] = deindent(self.lines, offset)
        return newsource

    def isparseable(self, deindent=True):
        """ return True if source is parseable, heuristically
            deindenting it by default.

            Uses ``parser.suite`` when the ``parser`` module is importable,
            otherwise falls back to the builtin compile().
        """
        try:
            import parser
        except ImportError:
            syntax_checker = lambda x: compile(x, 'asd', 'exec')
        else:
            syntax_checker = parser.suite

        if deindent:
            source = str(self.deindent())
        else:
            source = str(self)
        try:
            syntax_checker(source + '\n')
        except KeyboardInterrupt:
            raise
        except Exception:
            return False
        else:
            return True

    def __str__(self):
        return "\n".join(self.lines)

    def compile(self, filename=None, mode='exec',
                flag=generators.compiler_flag,
                dont_inherit=0, _genframe=None):
        """ Compile the fragment with the builtin compile() and return the
            result.

            If ``filename`` is empty, or ``py.path.local(filename).check(file=0)``
            is true, an artificial ``<N-codegen ...>`` filename is invented
            which records the caller frame's file and line (``_genframe``
            overrides which frame is treated as the caller).

            ``flag`` and ``dont_inherit`` are passed through to compile().
            Unless ``flag`` includes the AST-only flag, the lines are stored
            in linecache under the filename so tracebacks can show them.
            A SyntaxError is re-raised with the offending lines, a caret and
            the filename in its message.
        """
        if not filename or py.path.local(filename).check(file=0):
            if _genframe is None:
                _genframe = sys._getframe(1)  # the caller
            fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno
            base = "<%d-codegen " % self._compilecounter
            self.__class__._compilecounter += 1
            if not filename:
                filename = base + '%s:%d>' % (fn, lineno)
            else:
                filename = base + '%r %s:%d>' % (filename, fn, lineno)
        source = "\n".join(self.lines) + '\n'
        try:
            # dont_inherit used to be accepted but silently ignored
            co = cpy_compile(source, filename, mode, flag, dont_inherit)
        except SyntaxError as ex:
            # re-represent syntax errors from parsing python strings
            msglines = self.lines[:ex.lineno]
            if ex.offset:
                msglines.append(" " * ex.offset + '^')
            msglines.append("(code was compiled probably from here: %s)" % filename)
            newex = SyntaxError('\n'.join(msglines))
            newex.offset = ex.offset
            newex.lineno = ex.lineno
            newex.text = ex.text
            raise newex
        else:
            if flag & _AST_FLAG:
                # an AST was requested: nothing to register in linecache
                return co
            lines = [(x + "\n") for x in self.lines]
            py.std.linecache.cache[filename] = (1, None, lines, filename)
            return co
196
197#
198# public API shortcut functions
199#
200
def compile_(source, filename=None, mode='exec', flags=
            generators.compiler_flag, dont_inherit=0):
    """ Compile ``source`` and return the result of compilation.

        An ``_ast.AST`` node is handed straight to the builtin compile().
        Anything else is wrapped in ``Source(source)`` and compiled via
        ``Source.compile``, which (for code objects) registers the lines in
        linecache under the given or a generated filename so the source can
        be retrieved later.  ``flags`` and ``dont_inherit`` are forwarded on
        both paths.
    """
    if _ast is not None and isinstance(source, _ast.AST):
        # XXX should Source support having AST?
        return cpy_compile(source, filename, mode, flags, dont_inherit)
    # the caller's frame, used by Source.compile to name generated code
    _genframe = sys._getframe(1)
    s = Source(source)
    # dont_inherit was previously dropped on this path
    return s.compile(filename, mode, flags, dont_inherit=dont_inherit,
                     _genframe=_genframe)
215
216
def getfslineno(obj):
    """ Return ``(fspath, lineno)`` locating the source of ``obj``.

        If ``_pytest._code.Code(obj)`` succeeds, its ``path`` and
        ``firstlineno`` are returned.  Otherwise the file comes from
        ``inspect.getsourcefile``/``inspect.getfile`` and the line from
        findsource(); ``lineno`` stays -1 when it cannot be determined, and
        ``fspath`` is None when inspect reports no file name.  Returns
        ``("", -1)`` when inspect raises TypeError for ``obj``.
    """
    import _pytest._code
    try:
        code = _pytest._code.Code(obj)
    except TypeError:
        code = None

    if code is not None:
        fspath = code.path
        lineno = code.firstlineno
    else:
        inspect_mod = py.std.inspect
        try:
            fn = inspect_mod.getsourcefile(obj) or inspect_mod.getfile(obj)
        except TypeError:
            return "", -1

        fspath = None
        if fn:
            fspath = py.path.local(fn) or None
        lineno = -1
        if fspath:
            try:
                lineno = findsource(obj)[1]
            except IOError:
                pass
    assert isinstance(lineno, int)
    return fspath, lineno
243
244#
245# helper functions
246#
247
def findsource(obj):
    """ Return ``(Source, lineno)`` for the file that contains ``obj``.

        Wraps ``inspect.findsource``: the Source holds the lines it returns
        with trailing whitespace stripped, and ``lineno`` is the 0-based
        index it reports.  Exceptions in ``py.builtin._sysex`` are re-raised;
        any other failure yields ``(None, -1)``.
    """
    try:
        sourcelines, lineno = py.std.inspect.findsource(obj)
    except py.builtin._sysex:
        raise
    except Exception:
        # best effort: no source available (IOError/OSError, TypeError, ...)
        return None, -1
    source = Source()
    source.lines = [line.rstrip() for line in sourcelines]
    return source, lineno
258
259
def getsource(obj, **kwargs):
    """ Return a Source built from the source text of ``obj``.

        ``obj`` is first normalized with ``_pytest._code.getrawcode`` and
        its text fetched with ``inspect.getsource``; keyword arguments are
        forwarded to the Source constructor.  If ``inspect.getsource``
        raises IndentationError, a placeholder string literal is used as
        the source text instead.
    """
    import _pytest._code
    rawcode = _pytest._code.getrawcode(obj)
    try:
        text = inspect.getsource(rawcode)
    except IndentationError:
        text = "\"Buggy python version consider upgrading, cannot get source\""
    assert isinstance(text, str)
    return Source(text, **kwargs)
269
270
def deindent(lines, offset=None):
    """ Return a new list of ``lines`` with ``offset`` leading columns removed.

        When ``offset`` is None it is taken from the first non-blank line
        after tab expansion (0 if every line is blank).  With an offset of 0
        a plain copy is returned.

        Otherwise the lines are tokenized.  A line reached as the first line
        of a token is tab-expanded and loses its first ``offset`` columns
        when it is non-blank and those columns are all whitespace.
        Continuation lines of multi-line tokens (e.g. triple-quoted strings)
        are copied verbatim.  If tokenizing raises IndentationError or
        TokenError, the lines not yet processed are appended unchanged.
    """
    if offset is None:
        offset = 0
        for candidate in lines:
            expanded = candidate.expandtabs()
            body = expanded.lstrip()
            if body:
                offset = len(expanded) - len(body)
                break
    if offset == 0:
        return list(lines)

    def feed():
        # readline source: every line with "\n", then EOF ("") forever
        for text in lines:
            yield text + '\n'
        while True:
            yield ''

    reader = feed()
    result = []
    try:
        tokens = tokenize.generate_tokens(lambda: next(reader))
        for _kind, _text, (first_row, _), (last_row, _), _src in tokens:
            if first_row > len(lines):
                break  # only end-of-input tokens remain
            if first_row > len(result):
                current = lines[first_row - 1].expandtabs()
                if current.lstrip() and current[:offset].isspace():
                    current = current[offset:]
                result.append(current)
            # keep the indentation of lines inside multi-line tokens
            for row in range(first_row, last_row):
                result.append(lines[row])
    except (IndentationError, tokenize.TokenError):
        pass
    # whatever the tokenizer did not reach is copied as-is
    result.extend(lines[len(result):])
    return result
312
313
def get_statement_startend2(lineno, node):
    """ Return ``(start, end)`` 0-based line indices bracketing ``lineno``.

        Collects the first line of every statement and except handler in
        ``node``, every decorator line, and the line preceding each
        ``finally``/``else`` body.  ``start`` is the greatest collected line
        <= ``lineno``; ``end`` is the next collected line, or None when
        there is none.
    """
    import ast
    # flatten all statements and except handlers into one lineno-list
    # AST's line numbers start indexing at 1
    values = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            # Since Python 3.8 a decorated def/class has its lineno on the
            # ``def``/``class`` line, not the first decorator; record the
            # decorator lines too, otherwise a line inside a decorator is
            # attributed to the preceding statement.
            for deco in getattr(x, "decorator_list", ()):
                values.append(deco.lineno - 1)
            values.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # treat the finally/orelse part as its own statement
                    values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    # NOTE(review): if lineno precedes every collected line, insert_index is
    # 0 and start wraps around to values[-1]; an empty list raises IndexError.
    start = values[insert_index - 1]
    if insert_index >= len(values):
        end = None
    else:
        end = values[insert_index]
    return start, end
335
336
def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
    """ Return ``(astnode, start, end)`` for the statement containing ``lineno``.

        ``lineno``, ``start`` and ``end`` are 0-based indices into
        ``source.lines``; ``end`` is exclusive.  A previously parsed
        ``astnode`` may be passed in; otherwise ``str(source)`` is compiled
        to an AST.  If that compile raises ValueError, the result of
        getstatementrange_old() is returned with ``None`` as the node.

        The raw range from get_statement_startend2() is then narrowed with
        ``inspect.BlockFinder`` so it does not run into a differently
        indented block, and trailing blank/comment lines are dropped.
    """
    if astnode is None:
        text = str(source)
        if sys.version_info < (2, 7):
            text += "\n"
        try:
            astnode = compile(text, "source", "exec", 1024)  # 1024 == PyCF_ONLY_AST
        except ValueError:
            first, last = getstatementrange_old(lineno, source, assertion)
            return None, first, last
    first, last = get_statement_startend2(lineno, astnode)
    # The AST does not see comments or blank lines, and the final statement
    # has no successor, so the raw end needs correcting.
    if last is None:
        last = len(source.lines)

    if last - first > 1:
        # stop at the end of the block the way inspect.getsource() does
        finder = inspect.BlockFinder()
        # an indented first line puts the finder directly into "started" mode
        finder.started = source.lines[first][0].isspace()
        chunk = ((text_line + "\n") for text_line in source.lines[first:last])
        try:
            for token in tokenize.generate_tokens(lambda: next(chunk)):
                finder.tokeneater(*token)
        except (inspect.EndOfBlock, IndentationError):
            last = finder.last + first
        except Exception:
            pass

    # trailing comment or empty lines do not belong to the statement
    while last:
        stripped = source.lines[last - 1].lstrip()
        if stripped and not stripped.startswith("#"):
            break
        last -= 1
    return astnode, first, last
378
379
def getstatementrange_old(lineno, source, assertion=False):
    """ Fallback finder for the ``(start, end)`` statement range around ``lineno``.

        Walks candidate start lines backwards from ``lineno``.  A candidate
        is accepted when ``codeop.compile_command`` can compile the lines
        ``start..lineno`` wrapped in a dummy ``def``.  Then the first
        ``end`` for which ``source[start:end].isparseable()`` holds is
        returned.  With ``assertion`` set, only lines mentioning
        ``assert``/``raise`` are candidates, and a line that looks like a
        ``super(...).__init__`` call raises IndexError.  Raises SyntaxError
        if no range is found.
    """
    # XXX this logic is only used on python2.4 and below
    from codeop import compile_command

    def _could_start_statement(candidate):
        # heuristic: may line ``candidate`` begin the statement at lineno?
        if assertion:
            text = source.lines[candidate]
            # the following lines are not fully tested, change with care
            if 'super' in text and 'self' in text and '__init__' in text:
                raise IndexError("likely a subclass")
            if "assert" not in text and "raise" not in text:
                return False
        # the dummy def lets e.g. "return" compile; every following line
        # gets one extra leading space so it becomes the def's body
        wrapped = '\n '.join(['def xxx():'] + source.lines[candidate:lineno + 1])
        try:
            compile_command(wrapped)
        except (SyntaxError, OverflowError, ValueError):
            return False
        return True

    for start in reversed(range(lineno + 1)):
        if not _could_start_statement(start):
            continue
        # grow the range until it forms parseable source
        for end in range(lineno + 1, len(source) + 1):
            if source[start:end].isparseable():
                return start, end
    raise SyntaxError("no valid source range around line %d " % (lineno,))
413
414
415