import ast
import inspect
import textwrap
import tokenize
import warnings
from bisect import bisect_right
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from _pytest.compat import overload


class Source:
    """An immutable object holding a source code fragment.

    When using Source(...), the source lines are deindented.
    """

    def __init__(self, obj: object = None) -> None:
        """Build a Source from *obj*.

        - a falsy *obj* gives an empty Source;
        - another Source shares its ``lines`` list;
        - a tuple/list is treated as lines (trailing newlines stripped);
        - a str is split on ``"\\n"``;
        - anything else is resolved to a code object via ``getrawcode`` and
          its text is fetched with ``inspect.getsource``.

        All non-Source inputs are deindented.
        """
        if not obj:
            self.lines = []  # type: List[str]
        elif isinstance(obj, Source):
            self.lines = obj.lines
        elif isinstance(obj, (tuple, list)):
            self.lines = deindent(x.rstrip("\n") for x in obj)
        elif isinstance(obj, str):
            self.lines = deindent(obj.split("\n"))
        else:
            rawcode = getrawcode(obj)
            src = inspect.getsource(rawcode)
            self.lines = deindent(src.split("\n"))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Source):
            return NotImplemented
        return self.lines == other.lines

    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    @overload
    def __getitem__(self, key: int) -> str:
        ...

    @overload  # noqa: F811
    def __getitem__(self, key: slice) -> "Source":  # noqa: F811
        ...

    def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:  # noqa: F811
        """Return a single line for an int key, or a new Source for a slice.

        Slices with a step other than ``None``/``1`` raise IndexError.
        """
        if isinstance(key, int):
            return self.lines[key]
        else:
            if key.step not in (None, 1):
                raise IndexError("cannot slice a Source with a step")
            newsource = Source()
            newsource.lines = self.lines[key.start : key.stop]
            return newsource

    def __iter__(self) -> Iterator[str]:
        return iter(self.lines)

    def __len__(self) -> int:
        return len(self.lines)

    def strip(self) -> "Source":
        """Return new Source object with trailing and leading blank lines removed."""
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end - 1].strip():
            end -= 1
        source = Source()
        source.lines[:] = self.lines[start:end]
        return source

    def indent(self, indent: str = " " * 4) -> "Source":
        """Return a copy of the source object with all lines indented by the
        given indent-string."""
        newsource = Source()
        newsource.lines = [(indent + line) for line in self.lines]
        return newsource

    def getstatement(self, lineno: int) -> "Source":
        """Return Source statement which contains the given linenumber
        (counted from 0)."""
        start, end = self.getstatementrange(lineno)
        return self[start:end]

    def getstatementrange(self, lineno: int) -> Tuple[int, int]:
        """Return (start, end) tuple which spans the minimal statement region
        which contains the given lineno.

        Both values are 0-based; ``end`` is exclusive.

        :raises IndexError: if *lineno* is not a valid index into the lines.
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        # The parsed AST node is not needed here; don't shadow the ``ast`` module.
        _, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self) -> "Source":
        """Return a new Source object deindented."""
        newsource = Source()
        newsource.lines[:] = deindent(self.lines)
        return newsource

    def __str__(self) -> str:
        return "\n".join(self.lines)


#
# helper functions
#


def findsource(obj) -> Tuple[Optional[Source], int]:
    """Return ``(Source, lineno)`` for *obj* as found by ``inspect.findsource``.

    Trailing whitespace is stripped from every line. If ``inspect.findsource``
    raises anything, ``(None, -1)`` is returned instead.
    """
    try:
        sourcelines, lineno = inspect.findsource(obj)
    except Exception:
        return None, -1
    source = Source()
    source.lines = [line.rstrip() for line in sourcelines]
    return source, lineno


def getrawcode(obj, trycall: bool = True):
    """Return code object for given function.

    Tries, in order: ``obj.__code__``; ``obj.f_code`` (frames) and/or its
    ``__code__``; and, when *trycall* is true and no code object was found,
    the code of ``obj.__call__`` for non-class callables. If nothing matches,
    the (possibly unwrapped) object is returned unchanged.
    """
    try:
        return obj.__code__
    except AttributeError:
        obj = getattr(obj, "f_code", obj)
        obj = getattr(obj, "__code__", obj)
        if trycall and not hasattr(obj, "co_firstlineno"):
            if hasattr(obj, "__call__") and not inspect.isclass(obj):
                x = getrawcode(obj.__call__, trycall=False)
                if hasattr(x, "co_firstlineno"):
                    return x
        return obj


def deindent(lines: Iterable[str]) -> List[str]:
    """Return *lines* with their common leading whitespace removed."""
    return textwrap.dedent("\n".join(lines)).splitlines()


def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
    """Return 0-based ``(start, end)`` of the statement containing *lineno*.

    ``start`` is the greatest recorded statement line <= *lineno*; ``end`` is
    the next recorded statement line, or ``None`` when *lineno* falls in the
    last statement. Note: if *lineno* precedes every recorded line, the
    ``insert_index - 1`` lookup wraps around to the last recorded line.
    """
    # Flatten all statements and except handlers into one lineno-list.
    # AST's line numbers start indexing at 1.
    values = []  # type: List[int]
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            # Since Python 3.8, the lineno of a decorated class/def points to the
            # ``class``/``def`` keyword, so record the decorator lines too;
            # otherwise they would be attributed to the previous statement.
            if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                for d in x.decorator_list:
                    values.append(d.lineno - 1)
            values.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)  # type: Optional[List[ast.stmt]]
                if val:
                    # Treat the finally/orelse part as its own statement.
                    values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    start = values[insert_index - 1]
    if insert_index >= len(values):
        end = None
    else:
        end = values[insert_index]
    return start, end


def getstatementrange_ast(
    lineno: int,
    source: Source,
    assertion: bool = False,
    astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
    """Return ``(astnode, start, end)`` for the statement containing *lineno*.

    *lineno*, *start* and *end* are 0-based line indices into ``source.lines``;
    *end* is exclusive. If *astnode* is not given, *source* is parsed to obtain
    it (warnings emitted during parsing are suppressed). *assertion* is
    accepted but not used by this function.
    """
    if astnode is None:
        content = str(source)
        # See #4260:
        # Don't produce duplicate warnings when compiling source to find AST.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            astnode = ast.parse(content, "source", "exec")

    start, end = get_statement_startend2(lineno, astnode)
    # We need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # Make sure we don't span differently indented code blocks
        # by using the BlockFinder helper which inspect.getsource() uses itself.
        block_finder = inspect.BlockFinder()
        # If we start with an indented line, put blockfinder to "started" mode.
        # The start line may be empty (e.g. the line computed for an
        # ``else``/``finally`` part can be blank), so guard the [0] index.
        start_line = source.lines[start]
        block_finder.started = bool(start_line) and start_line[0].isspace()
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(lambda: next(it)):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = block_finder.last + start
        except Exception:
            # Best effort: if tokenizing fails for any other reason, keep the
            # AST-derived end.
            pass

    # The end might still point to a comment or empty line, correct it.
    while end:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end