1import ast
2import inspect
3import textwrap
4import tokenize
5import warnings
6from bisect import bisect_right
7from typing import Iterable
8from typing import Iterator
9from typing import List
10from typing import Optional
11from typing import Tuple
12from typing import Union
13
14from _pytest.compat import overload
15
16
class Source:
    """An immutable object holding a source code fragment.

    When using Source(...), the source lines are deindented.

    The fragment is stored in ``lines``, a list with one string per line;
    ``str(source)`` joins those strings with newlines.
    """

    def __init__(self, obj: object = None) -> None:
        """Build the fragment from ``obj``.

        ``obj`` may be:

        * any falsy value: the fragment is empty;
        * another Source: its ``lines`` list is reused as-is (not copied);
        * a list/tuple of lines: trailing newlines are stripped from each
          item, then the lines are deindented;
        * a string: split on newlines and deindented;
        * anything else: resolved with getrawcode() and looked up with
          inspect.getsource(), then deindented.
        """
        if not obj:
            self.lines = []  # type: List[str]
        elif isinstance(obj, Source):
            self.lines = obj.lines
        elif isinstance(obj, (tuple, list)):
            self.lines = deindent(x.rstrip("\n") for x in obj)
        elif isinstance(obj, str):
            self.lines = deindent(obj.split("\n"))
        else:
            rawcode = getrawcode(obj)
            src = inspect.getsource(rawcode)
            self.lines = deindent(src.split("\n"))

    def __eq__(self, other: object) -> bool:
        """Two Source objects are equal when their line lists are equal."""
        if not isinstance(other, Source):
            return NotImplemented
        return self.lines == other.lines

    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    @overload
    def __getitem__(self, key: int) -> str:
        ...

    @overload  # noqa: F811
    def __getitem__(self, key: slice) -> "Source":  # noqa: F811
        ...

    def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:  # noqa: F811
        """Return one line for an int key, or a new Source for a slice.

        Raises IndexError for a slice whose step is neither None nor 1.
        """
        if isinstance(key, int):
            return self.lines[key]
        if key.step not in (None, 1):
            raise IndexError("cannot slice a Source with a step")
        newsource = Source()
        newsource.lines = self.lines[key.start : key.stop]
        return newsource

    def __iter__(self) -> Iterator[str]:
        """Iterate over the individual lines."""
        return iter(self.lines)

    def __len__(self) -> int:
        """Return the number of lines."""
        return len(self.lines)

    def strip(self) -> "Source":
        """Return new Source object with trailing and leading blank lines removed."""
        start, end = 0, len(self)
        # A line counts as blank when it is empty or whitespace-only.
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end - 1].strip():
            end -= 1
        source = Source()
        source.lines = self.lines[start:end]
        return source

    def indent(self, indent: str = " " * 4) -> "Source":
        """Return a copy of the source object with all lines indented by the
        given indent-string."""
        newsource = Source()
        newsource.lines = [(indent + line) for line in self.lines]
        return newsource

    def getstatement(self, lineno: int) -> "Source":
        """Return Source statement which contains the given linenumber
        (counted from 0)."""
        start, end = self.getstatementrange(lineno)
        return self[start:end]

    def getstatementrange(self, lineno: int) -> Tuple[int, int]:
        """Return (start, end) tuple which spans the minimal statement region
        containing the given lineno.

        Raises IndexError when ``lineno`` is not a valid 0-based index into
        this source.
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        # The parsed AST node returned first is not needed here.
        _, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self) -> "Source":
        """Return a new Source object deindented."""
        newsource = Source()
        # Inside a method the bare name refers to the module-level deindent().
        newsource.lines = deindent(self.lines)
        return newsource

    def __str__(self) -> str:
        """Return the fragment as a single newline-joined string."""
        return "\n".join(self.lines)
110
111#
112# helper functions
113#
114
115
def findsource(obj) -> Tuple[Optional[Source], int]:
    """Locate the source containing ``obj`` via inspect.findsource().

    Returns ``(source, lineno)`` where ``source`` holds the lines reported
    by inspect.findsource() with trailing whitespace removed from each line
    (no deindenting is applied) and ``lineno`` is the line number reported
    alongside them.  If the lookup raises any Exception the result is
    ``(None, -1)``.
    """
    try:
        raw_lines, first_lineno = inspect.findsource(obj)
    except Exception:
        # Best effort: any lookup failure is reported as "no source".
        return None, -1
    result = Source()
    stripped_lines = []
    for text in raw_lines:
        stripped_lines.append(text.rstrip())
    result.lines = stripped_lines
    return result, first_lineno
124
125
def getrawcode(obj, trycall: bool = True):
    """Return code object for given function.

    ``obj.__code__`` is returned when it exists.  Otherwise ``f_code`` and
    then ``__code__`` are followed when present.  If the result still has no
    ``co_firstlineno`` and ``trycall`` is true, the ``__call__`` attribute of
    a non-class object is tried once (without further recursion into
    ``__call__``).  When nothing code-like is found the last candidate is
    returned unchanged.
    """
    try:
        return obj.__code__
    except AttributeError:
        pass

    candidate = getattr(obj, "f_code", obj)
    candidate = getattr(candidate, "__code__", candidate)

    # Nothing more to try: either the caller disabled the __call__ lookup
    # or the candidate already looks like a code object.
    if not trycall or hasattr(candidate, "co_firstlineno"):
        return candidate

    if hasattr(candidate, "__call__") and not inspect.isclass(candidate):
        via_call = getrawcode(candidate.__call__, trycall=False)
        if hasattr(via_call, "co_firstlineno"):
            return via_call
    return candidate
139
140
def deindent(lines: Iterable[str]) -> List[str]:
    """Remove the common leading whitespace from ``lines``.

    The lines are joined with newlines, passed through textwrap.dedent()
    and split back into a list with str.splitlines().
    """
    joined = "\n".join(lines)
    dedented = textwrap.dedent(joined)
    return dedented.splitlines()
143
144
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
    """Return the 0-based ``(start, end)`` lines of the statement at ``lineno``.

    Every statement and except handler found under ``node`` contributes its
    starting line; ``start`` is the closest recorded line that is
    ``<= lineno`` and ``end`` is the next recorded line, or None when no
    later one exists.

    A decorated function or class starts at its first decorator line, so
    that decorator lines are never attributed to the preceding statement.

    Raises IndexError when ``node`` contains no statements at all.

    NOTE(review): when ``lineno`` precedes every recorded line the negative
    index below wraps around and ``start`` becomes the last recorded line;
    confirm whether any caller relies on that.
    """
    # Flatten all statements and except handlers into one lineno-list.
    # AST's line numbers start indexing at 1.
    values = []  # type: List[int]
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            first_line = x.lineno
            if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                # Before Python 3.8 the lineno of a decorated class or function
                # pointed at the first decorator; since 3.8 it points at the
                # "class"/"def" line.  Use the earliest decorator line so the
                # statement start is the same on every Python version.
                for decorator in x.decorator_list:
                    first_line = min(first_line, decorator.lineno)
            values.append(first_line - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)  # type: Optional[List[ast.stmt]]
                if val:
                    # Treat the finally/orelse part as its own statement.
                    # The extra "- 1" targets the line just before the first
                    # statement of that part.
                    values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    start = values[insert_index - 1]
    if insert_index >= len(values):
        end = None
    else:
        end = values[insert_index]
    return start, end
165
166
def getstatementrange_ast(
    lineno: int,
    source: Source,
    assertion: bool = False,
    astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
    """Return ``(astnode, start, end)`` for the statement containing ``lineno``.

    ``start`` and ``end`` are 0-based indexes into ``source.lines`` and are
    used as a half-open slice (``source.lines[start:end]``).

    :param lineno: 0-based line number to look up.
    :param source: The source whose lines are inspected.
    :param assertion: Not used by this function.
    :param astnode:
        An already parsed tree for ``source``; when None the source is
        parsed here, and any exception raised by ast.parse() (for example
        SyntaxError) propagates to the caller.
    """
    if astnode is None:
        content = str(source)
        # See #4260:
        # Don't produce duplicate warnings when compiling source to find AST.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            astnode = ast.parse(content, "source", "exec")

    start, end = get_statement_startend2(lineno, astnode)
    # We need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # Make sure we don't span differently indented code blocks
        # by using the BlockFinder helper which inspect.getsource() uses itself.
        block_finder = inspect.BlockFinder()
        # If we start with an indented line, put blockfinder to "started" mode.
        # The start line may be empty (e.g. a blank line right after "else:"),
        # so slice instead of indexing: "".isspace() is False whereas ""[0]
        # would raise IndexError.
        block_finder.started = source.lines[start][:1].isspace()
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(lambda: next(it)):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = block_finder.last + start
        except Exception:
            # Deliberate best effort: if the fragment cannot be tokenized
            # keep the end computed from the AST.
            pass

    # The end might still point to a comment or empty line, correct it.
    while end:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end
212