1# Copyright (c) 2006-2011, 2013-2014 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
2# Copyright (c) 2013 Phil Schaf <flying-sheep@web.de>
3# Copyright (c) 2014-2020 Claudiu Popa <pcmanticore@gmail.com>
4# Copyright (c) 2014-2015 Google, Inc.
5# Copyright (c) 2014 Alexander Presnyakov <flagist0@gmail.com>
6# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
7# Copyright (c) 2016 Derek Gustafson <degustaf@gmail.com>
8# Copyright (c) 2017 Łukasz Rogalski <rogalski.91@gmail.com>
9# Copyright (c) 2018 Anthony Sottile <asottile@umich.edu>
10# Copyright (c) 2020-2021 hippo91 <guillaume.peillex@gmail.com>
11# Copyright (c) 2021 Daniël van Noord <13665637+DanielNoord@users.noreply.github.com>
12# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
13# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
14# Copyright (c) 2021 Andrew Haigh <hello@nelf.in>
15
16# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
17# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
18
"""The AstroidBuilder builds astroid trees from live objects and/or from source code parsed with _ast.
20
21The builder is not thread safe and can't be used to parse different sources
22at the same time.
23"""
24import os
25import textwrap
26import types
27from tokenize import detect_encoding
28from typing import List, Union
29
30from astroid import bases, modutils, nodes, raw_building, rebuilder, util
31from astroid._ast import get_parser_module
32from astroid.exceptions import AstroidBuildingError, AstroidSyntaxError, InferenceError
33from astroid.manager import AstroidManager
34from astroid.nodes.node_classes import NodeNG
35
# Lazily-imported ``astroid.objects`` module; in this file it is only used by
# ``AstroidBuilder.delayed_assattr``. (Presumably imported lazily to avoid a
# circular import — TODO confirm.)
objects = util.lazy_import("objects")

# The name of the transient function that is used to
# wrap expressions to be extracted when calling
# extract_node.
_TRANSIENT_FUNCTION = "__"

# The comment used to select a statement to be extracted
# when calling extract_node.
_STATEMENT_SELECTOR = "#@"
# Message of the SyntaxError raised by the parser for a misplaced type
# comment; _parse_string retries without type-comment support when it sees it.
MISPLACED_TYPE_ANNOTATION_ERROR = "misplaced type annotation"
47
48
def open_source_file(filename):
    """Open a Python source file using the encoding it declares.

    The encoding is detected with :func:`tokenize.detect_encoding` (coding
    cookie or BOM, UTF-8 otherwise). The file is then reopened in text mode
    with universal newlines and its whole content is read.

    :param filename: Path of the source file.
    :returns: A ``(stream, encoding, data)`` tuple. ``stream`` is the still
        open text stream (already read to the end); the caller is
        responsible for closing it.
    :raises OSError: If the file cannot be opened or read.
    :raises SyntaxError: If ``detect_encoding`` rejects the encoding
        declaration.
    :raises UnicodeError: If the content cannot be decoded with the detected
        encoding. The stream is closed before the error propagates.
    """
    # pylint: disable=consider-using-with
    with open(filename, "rb") as byte_stream:
        encoding = detect_encoding(byte_stream.readline)[0]
    stream = open(filename, newline=None, encoding=encoding)
    try:
        data = stream.read()
    except BaseException:
        # The caller only gets the stream (and closes it) on success, so we
        # must not leak the file handle when decoding/reading fails.
        stream.close()
        raise
    return stream, encoding, data
56
57
58def _can_assign_attr(node, attrname):
59    try:
60        slots = node.slots()
61    except NotImplementedError:
62        pass
63    else:
64        if slots and attrname not in {slot.value for slot in slots}:
65            return False
66    return node.qname() != "builtins.object"
67
68
class AstroidBuilder(raw_building.InspectBuilder):
    """Class for building an astroid tree from source code or from a live module.

    The param *manager* specifies the manager class which should be used.
    If no manager is given, then the default one will be used. The
    param *apply_transforms* determines if the transforms should be
    applied after the tree was built from source or from a live object,
    by default being True.
    """

    # pylint: disable=redefined-outer-name
    def __init__(self, manager=None, apply_transforms=True):
        super().__init__(manager)
        self._apply_transforms = apply_transforms

    def module_build(
        self, module: types.ModuleType, modname: str = None
    ) -> nodes.Module:
        """Build an astroid from a living module instance.

        If the module has a ``__file__`` ending in ``.py``/``.pyc``/``.pyo``
        and the matching ``.py`` source exists, the tree is built from that
        source file. Otherwise a partial tree is built by introspecting the
        live module object, and transforms are applied here (when enabled)
        because no rebuilder pass runs for it.
        """
        node = None
        path = getattr(module, "__file__", None)
        if path is not None:
            path_, ext = os.path.splitext(modutils._path_from_filename(path))
            if ext in {".py", ".pyc", ".pyo"} and os.path.exists(path_ + ".py"):
                node = self.file_build(path_ + ".py", modname)
        if node is None:
            # this is a built-in module
            # get a partial representation by introspection
            node = self.inspect_build(module, modname=modname, path=path)
            if self._apply_transforms:
                # We have to handle transformation by ourselves since the
                # rebuilder isn't called for builtin nodes
                node = self._manager.visit_transforms(node)
        return node

    def file_build(self, path, modname=None):
        """Build astroid from a source code file (i.e. from an ast).

        *path* is expected to be a python source file. When *modname* is not
        given, it is derived from *path* with ``modutils.modpath_from_file``,
        falling back to the file's base name without extension.

        :raises AstroidBuildingError: if the file cannot be opened or its
            content cannot be decoded with the detected encoding.
        :raises AstroidSyntaxError: if the encoding declaration is invalid or
            the source cannot be parsed.
        """
        try:
            stream, encoding, data = open_source_file(path)
        except OSError as exc:
            raise AstroidBuildingError(
                "Unable to load file {path}:\n{error}",
                modname=modname,
                path=path,
                error=exc,
            ) from exc
        except (SyntaxError, LookupError) as exc:
            raise AstroidSyntaxError(
                "Python 3 encoding specification error or unknown encoding:\n"
                "{error}",
                modname=modname,
                path=path,
                error=exc,
            ) from exc
        except UnicodeError as exc:  # wrong encoding
            # detect_encoding returns utf-8 if no encoding specified.
            # Attach the same context as the other handlers so callers can
            # inspect modname/path/error uniformly.
            raise AstroidBuildingError(
                "Wrong or no encoding specified for {filename}.",
                filename=path,
                modname=modname,
                path=path,
                error=exc,
            ) from exc
        with stream:
            # get module name if necessary
            if modname is None:
                try:
                    modname = ".".join(modutils.modpath_from_file(path))
                except ImportError:
                    modname = os.path.splitext(os.path.basename(path))[0]
            # build astroid representation
            module = self._data_build(data, modname, path)
            return self._post_build(module, encoding)

    def string_build(self, data, modname="", path=None):
        """Build astroid from source code string.

        The module's ``file_bytes`` is set to the UTF-8 encoding of *data* and
        its ``file_encoding`` to ``"utf-8"``.
        """
        module = self._data_build(data, modname, path)
        module.file_bytes = data.encode("utf-8")
        return self._post_build(module, "utf-8")

    def _post_build(self, module, encoding):
        """Handle encoding and delayed nodes after a module has been built.

        Records *encoding*, caches the module in the manager, registers
        ``from ... import`` names (and ``__future__`` imports), processes the
        delayed attribute assignments, then applies transforms if enabled.
        """
        module.file_encoding = encoding
        self._manager.cache_module(module)
        # post tree building steps after we stored the module in the cache:
        for from_node in module._import_from_nodes:
            if from_node.modname == "__future__":
                for symbol, _ in from_node.names:
                    module.future_imports.add(symbol)
            self.add_from_names_to_locals(from_node)
        # handle delayed assattr nodes
        for delayed in module._delayed_assattr:
            self.delayed_assattr(delayed)

        # Visit the transforms
        if self._apply_transforms:
            module = self._manager.visit_transforms(module)
        return module

    def _data_build(self, data, modname, path):
        """Build the Module node from source *data* and add some information.

        :raises AstroidSyntaxError: if *data* cannot be parsed.
        """
        try:
            node, parser_module = _parse_string(data, type_comments=True)
        except (TypeError, ValueError, SyntaxError) as exc:
            raise AstroidSyntaxError(
                "Parsing Python code failed:\n{error}",
                source=data,
                modname=modname,
                path=path,
                error=exc,
            ) from exc

        if path is not None:
            node_file = os.path.abspath(path)
        else:
            node_file = "<?>"
        if modname.endswith(".__init__"):
            # Strip the 9-character ".__init__" suffix: the module is the
            # package itself.
            modname = modname[:-9]
            package = True
        else:
            package = (
                path is not None
                and os.path.splitext(os.path.basename(path))[0] == "__init__"
            )
        builder = rebuilder.TreeRebuilder(self._manager, parser_module)
        module = builder.visit_module(node, modname, node_file, package)
        # Kept on the module so _post_build can finish them once the module
        # has been cached.
        module._import_from_nodes = builder._import_from_nodes
        module._delayed_assattr = builder._delayed_assattr
        return module

    def add_from_names_to_locals(self, node):
        """Store the names imported by the ``ImportFrom`` *node* in the locals.

        For ``from x import *`` the public names of the imported module are
        used (the import is skipped if the module cannot be built). Each
        affected locals list is re-sorted by ``fromlineno``.
        """

        def _key_func(local_node):
            return local_node.fromlineno

        def sort_locals(my_list):
            my_list.sort(key=_key_func)

        for (name, asname) in node.names:
            if name == "*":
                try:
                    imported = node.do_import_module()
                except AstroidBuildingError:
                    continue
                # Use a distinct name so the outer loop variable is not
                # shadowed.
                for public_name in imported.public_names():
                    node.parent.set_local(public_name, node)
                    sort_locals(node.parent.scope().locals[public_name])
            else:
                node.parent.set_local(asname or name, node)
                sort_locals(node.parent.scope().locals[asname or name])

    def delayed_assattr(self, node):
        """Register a delayed attribute assignment (an ``AssignAttr`` node).

        For each value ``node.expr`` infers to, ``node`` is appended to the
        matching attribute list: ``instance_attrs`` of the class for plain
        ``Instance``/``ExceptionInstance`` results (unless slots forbid the
        attribute), ``instance_attrs`` of functions, or ``locals`` otherwise.
        Other ``Instance`` subclasses (constants, containers) are skipped.
        Inference failures are silently ignored.
        """
        try:
            frame = node.frame()
            for inferred in node.expr.infer():
                if inferred is util.Uninferable:
                    continue
                try:
                    # Exact class check on purpose: subclasses of Instance
                    # are handled by the next branch.
                    cls = inferred.__class__
                    if cls is bases.Instance or cls is objects.ExceptionInstance:
                        inferred = inferred._proxied
                        iattrs = inferred.instance_attrs
                        if not _can_assign_attr(inferred, node.attrname):
                            continue
                    elif isinstance(inferred, bases.Instance):
                        # Const, Tuple or other containers that inherit from
                        # `Instance`
                        continue
                    elif inferred.is_function:
                        iattrs = inferred.instance_attrs
                    else:
                        iattrs = inferred.locals
                except AttributeError:
                    # XXX log error
                    continue
                values = iattrs.setdefault(node.attrname, [])
                if node in values:
                    continue
                # get assign in __init__ first XXX useful ?
                if (
                    frame.name == "__init__"
                    and values
                    and values[0].frame().name != "__init__"
                ):
                    values.insert(0, node)
                else:
                    values.append(node)
        except InferenceError:
            pass
265
266
def build_namespace_package_module(name: str, path: List[str]) -> nodes.Module:
    """Return an empty package Module node called *name* spanning *path*.

    The node has an empty docstring and no body (presumably used for
    namespace packages, as the name suggests — confirm with callers).
    """
    package_module = nodes.Module(name, doc="", path=path, package=True)
    return package_module
269
270
def parse(code, module_name="", path=None, apply_transforms=True):
    """Parse a source string into an astroid Module tree.

    The code is passed through ``textwrap.dedent`` first, then built with a
    fresh ``AstroidBuilder`` bound to ``AstroidManager()``.

    :param str code: The code for the module.
    :param str module_name: The name for the module, if any.
    :param str path: The path for the module.
    :param bool apply_transforms:
        Whether the manager's transforms are applied to the built tree.
        Pass False to get the tree without any transforms.
    """
    source = textwrap.dedent(code)
    manager = AstroidManager()
    builder = AstroidBuilder(manager=manager, apply_transforms=apply_transforms)
    return builder.string_build(source, modname=module_name, path=path)
286
287
def _replace_child_node(parent, old, new):
    """Replace every reference to *old* in *parent*'s astroid fields by *new*.

    A field may hold a node directly or a list/tuple of nodes. List items may
    themselves be tuples of nodes (e.g. ``Dict.items`` presumably stores
    ``(key, value)`` pairs — confirm), so those are searched too. Tuples are
    immutable, so they are rebuilt instead of being modified in place.
    """
    for name in parent._astroid_fields:
        child = getattr(parent, name)
        if child is old:
            setattr(parent, name, new)
        elif isinstance(child, list):
            for idx, item in enumerate(child):
                if item is old:
                    child[idx] = new
                elif isinstance(item, tuple) and any(elt is old for elt in item):
                    child[idx] = tuple(new if elt is old else elt for elt in item)
        elif isinstance(child, tuple) and any(elt is old for elt in child):
            setattr(
                parent, name, tuple(new if elt is old else elt for elt in child)
            )


def _extract_expressions(node):
    """Find expressions in a call to _TRANSIENT_FUNCTION and extract them.

    The function walks the AST recursively to search for expressions that
    are wrapped into a call to _TRANSIENT_FUNCTION. If it finds such an
    expression, it completely removes the function call node from the tree,
    replacing it by the wrapped expression inside the parent.

    :param node: An astroid node.
    :type node: astroid.nodes.NodeNG
    :yields: Each unwrapped expression, in traversal order, after it has been
        spliced into the tree in place of its wrapping call.
    """
    if (
        isinstance(node, nodes.Call)
        and isinstance(node.func, nodes.Name)
        and node.func.name == _TRANSIENT_FUNCTION
    ):
        real_expr = node.args[0]
        real_expr.parent = node.parent
        # Replace the call in the parent's _astroid_fields (the fields
        # checked when get_children is called) so that the AST looks like
        # no call to _TRANSIENT_FUNCTION ever took place.
        _replace_child_node(node.parent, node, real_expr)
        yield real_expr
    else:
        for child in node.get_children():
            yield from _extract_expressions(child)
325
326
def _find_statement_by_line(node, line):
    """Extracts the statement on a specific line from an AST.

    If the line number of node matches line, it will be returned;
    otherwise its children are searched depth-first, in order, and the
    first match found is returned.

    :param node: An astroid node.
    :type node: astroid.nodes.NodeNG
    :param line: The line number of the statement to extract.
    :type line: int
    :returns: The statement on the line, or None if no statement for the line
      can be found.
    :rtype: astroid.nodes.NodeNG or None
    """
    if isinstance(node, (nodes.ClassDef, nodes.FunctionDef, nodes.MatchCase)):
        # This is an inaccuracy in the AST: the nodes that can be
        # decorated do not carry explicit information on which line
        # the actual definition (class/def) starts, but .fromlineno seems
        # to be close enough. MatchCase gets the same treatment
        # (presumably because match_case has no line attributes in the
        # stdlib AST — confirm).
        node_line = node.fromlineno
    else:
        node_line = node.lineno

    if node_line == line:
        return node

    for child in node.get_children():
        result = _find_statement_by_line(child, line)
        # Compare against None explicitly: a found node must never be
        # discarded because of its truthiness.
        if result is not None:
            return result

    return None
360
361
def extract_node(code: str, module_name: str = "") -> Union[NodeNG, List[NodeNG]]:
    """Parses some Python code as a module and extracts a designated AST node.

    Statements:
     To extract one or more statement nodes, append #@ to the end of the line

     Examples:
       >>> def x():
       >>>   def y():
       >>>     return 1 #@

       The return statement will be extracted.

       >>> class X(object):
       >>>   def meth(self): #@
       >>>     pass

       The function object 'meth' will be extracted.

    Expressions:
     To extract arbitrary expressions, surround them with the fake
     function call __(...). After parsing, the surrounded expression
     will be returned and the whole AST (accessible via the returned
     node's parent attribute) will look like the function call was
     never there in the first place.

     Examples:
       >>> a = __(1)

       The const node will be extracted.

       >>> def x(d=__(foo.bar)): pass

       The node containing the default argument will be extracted.

       >>> def foo(a, b):
       >>>   return 0 < __(len(a)) < b

       The node containing the function call 'len' will be extracted.

    If no statements or expressions are selected, the last toplevel
    statement will be returned.

    A #@ marker on a line where no statement starts (for example a
    comment-only line) matches nothing and is left out of the result. It
    still counts as a selection, so the last-statement fallback does not
    apply in that case.

    If the selected statement is a discard statement, (i.e. an expression
    turned into a statement), the wrapped expression is returned instead.

    For convenience, singleton lists are unpacked.

    :param str code: A piece of Python code that is parsed as
        a module. Will be passed through textwrap.dedent first.
    :param str module_name: The name of the module.
    :returns: The designated node from the parse tree, or a list of nodes.
    :raises ValueError: If the parsed module has an empty body.
    """

    def _extract(node):
        # An expression statement is unwrapped to the expression it holds.
        if isinstance(node, nodes.Expr):
            return node.value

        return node

    requested_lines = [
        lineno
        for lineno, line in enumerate(code.splitlines(), start=1)
        if line.strip().endswith(_STATEMENT_SELECTOR)
    ]

    tree = parse(code, module_name=module_name)
    if not tree.body:
        raise ValueError("Empty tree, cannot extract from it")

    extracted = []
    if requested_lines:
        extracted = [_find_statement_by_line(tree, line) for line in requested_lines]

    # Modifies the tree.
    extracted.extend(_extract_expressions(tree))

    if not extracted:
        extracted.append(tree.body[-1])

    # _find_statement_by_line returns None for marked lines on which no
    # statement starts; never hand those None values back to the caller.
    extracted = [_extract(node) for node in extracted if node is not None]
    if len(extracted) == 1:
        return extracted[0]
    return extracted
445
446
def _parse_string(data, type_comments=True):
    """Parse source text, retrying without type comments when one is misplaced.

    A newline is appended to *data* before parsing. If parsing with type
    comments enabled fails only with the misplaced-type-annotation error,
    the text is parsed again with type-comment support disabled, rather
    than failing the whole file. Any other SyntaxError propagates.

    :returns: ``(tree, parser_module)`` — the parsed tree and the parser
        module that produced it.
    """
    parser_module = get_parser_module(type_comments=type_comments)
    source = data + "\n"
    try:
        tree = parser_module.parse(source, type_comments=type_comments)
    except SyntaxError as exc:
        # Only a misplaced type annotation, hit while type comments were
        # requested, is recoverable.
        retry_without_type_comments = (
            exc.args[0] == MISPLACED_TYPE_ANNOTATION_ERROR and type_comments
        )
        if not retry_without_type_comments:
            raise
    else:
        return tree, parser_module

    fallback_module = get_parser_module(type_comments=False)
    fallback_tree = fallback_module.parse(source, type_comments=False)
    return fallback_tree, fallback_module
461