1# Copyright (c) 2006-2011, 2013-2014 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
2# Copyright (c) 2014-2020 Claudiu Popa <pcmanticore@gmail.com>
3# Copyright (c) 2014 BioGeek <jeroen.vangoey@gmail.com>
4# Copyright (c) 2014 Google, Inc.
5# Copyright (c) 2014 Eevee (Alex Munroe) <amunroe@yelp.com>
6# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
7# Copyright (c) 2016 Derek Gustafson <degustaf@gmail.com>
8# Copyright (c) 2017 Iva Miholic <ivamiho@gmail.com>
9# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
10# Copyright (c) 2018 Nick Drozd <nicholasdrozd@gmail.com>
11# Copyright (c) 2019 Raphael Gaschignard <raphael@makeleaps.com>
12# Copyright (c) 2020-2021 hippo91 <guillaume.peillex@gmail.com>
13# Copyright (c) 2020 Raphael Gaschignard <raphael@rtpg.co>
14# Copyright (c) 2020 Anubhav <35621759+anubh-v@users.noreply.github.com>
15# Copyright (c) 2020 Ashley Whetter <ashley@awhetter.co.uk>
16# Copyright (c) 2021 Daniël van Noord <13665637+DanielNoord@users.noreply.github.com>
17# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
18# Copyright (c) 2021 grayjk <grayjk@gmail.com>
19# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
20# Copyright (c) 2021 Andrew Haigh <hello@nelf.in>
21# Copyright (c) 2021 DudeNr33 <3929834+DudeNr33@users.noreply.github.com>
22# Copyright (c) 2021 pre-commit-ci[bot] <bot@noreply.github.com>
23
24# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
25# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
26
"""astroid manager: avoid building the astroid tree of the same module more
than once, where possible, by providing a class that produces astroid
representations from various sources and keeps a cache of built modules.
"""
31
32import os
33import types
34import zipimport
35from typing import TYPE_CHECKING, ClassVar, List
36
37from astroid.exceptions import AstroidBuildingError, AstroidImportError
38from astroid.interpreter._import import spec
39from astroid.modutils import (
40    NoSourceFile,
41    file_info_from_modpath,
42    get_source_file,
43    is_module_name_part_of_extension_package_whitelist,
44    is_python_source,
45    is_standard_module,
46    load_module_from_name,
47    modpath_from_file,
48)
49from astroid.transforms import TransformVisitor
50
51if TYPE_CHECKING:
52    from astroid import nodes
53
# Archive extensions that AstroidManager.zip_import_data() searches for in a
# file path in order to locate the archive containing a module.
ZIP_IMPORT_EXTS = (
    ".zip",
    ".egg",
    ".whl",
    ".pyz",
    ".pyzw",
)
55
56
def safe_repr(obj):
    """Return ``repr(obj)``, or the placeholder ``"???"`` if ``repr`` raises.

    Used when building error messages about arbitrary live objects whose
    ``__repr__`` may itself be broken.
    """
    placeholder = "???"
    try:
        text = repr(obj)
    except Exception:  # pylint: disable=broad-except
        text = placeholder
    return text
62
63
class AstroidManager:
    """Build astroid trees from files, module names, live modules or strings.

    Built modules are kept in :attr:`astroid_cache`. All instances share a
    single ``__dict__`` (the class-level :attr:`brain`), i.e. the Borg
    pattern: state set through one instance is visible through every other.
    """

    name = "astroid loader"
    # Shared state of every instance (Borg pattern): __init__ rebinds each
    # instance's __dict__ to this very dict.
    brain = {}
    max_inferable_values: ClassVar[int] = 100

    def __init__(self):
        self.__dict__ = AstroidManager.brain
        # Only the first instance populates the shared state; later instances
        # find it already initialised and leave it untouched.
        if not self.__dict__:
            # NOTE: cache entries are added by the [re]builder
            self.astroid_cache = {}
            self._mod_file_cache = {}
            self._failed_import_hooks = []
            self.always_load_extensions = False
            self.optimize_ast = False
            self.extension_package_whitelist = set()
            self._transform = TransformVisitor()

    @property
    def register_transform(self):
        # This and unregister_transform below are exported for convenience
        return self._transform.register_transform

    @property
    def unregister_transform(self):
        return self._transform.unregister_transform

    @property
    def builtins_module(self):
        """The cached ``builtins`` module.

        Raises KeyError if it is not in the cache yet (see :meth:`bootstrap`).
        """
        return self.astroid_cache["builtins"]

    def visit_transforms(self, node):
        """Visit the transforms and apply them to the given *node*."""
        return self._transform.visit(node)

    def ast_from_file(self, filepath, modname=None, fallback=True, source=False):
        """Given a file path, return the astroid module built from it.

        :param filepath: path of the module file. If ``get_source_file`` finds
            a matching Python source file, that file is used instead and is
            built as source.
        :param modname: dotted module name. When omitted it is derived from
            *filepath*, or is the path itself if that derivation fails.
        :param fallback: when no source file is available, resolve the module
            by name through :meth:`ast_from_module_name` instead.
        :param source: build *filepath* as source even if no source file was
            located for it.
        :raises AstroidBuildingError: if no tree can be built.
        """
        try:
            filepath = get_source_file(filepath, include_no_ext=True)
            source = True
        except NoSourceFile:
            pass
        if modname is None:
            try:
                modname = ".".join(modpath_from_file(filepath))
            except ImportError:
                modname = filepath
        # Reuse a cached tree only when it was built from this very file.
        if (
            modname in self.astroid_cache
            and self.astroid_cache[modname].file == filepath
        ):
            return self.astroid_cache[modname]
        if source:
            # pylint: disable=import-outside-toplevel; circular import
            from astroid.builder import AstroidBuilder

            return AstroidBuilder(self).file_build(filepath, modname)
        if fallback and modname:
            return self.ast_from_module_name(modname)
        raise AstroidBuildingError("Unable to build an AST for {path}.", path=filepath)

    def ast_from_string(self, data, modname="", filepath=None):
        """Given some source code as a string, return its corresponding astroid object"""
        # pylint: disable=import-outside-toplevel; circular import
        from astroid.builder import AstroidBuilder

        return AstroidBuilder(self).string_build(data, modname, filepath)

    def _build_stub_module(self, modname):
        """Build an empty module named *modname*.

        Used for ``__main__``, frozen modules and C extensions that may not
        be loaded.
        """
        # pylint: disable=import-outside-toplevel; circular import
        from astroid.builder import AstroidBuilder

        return AstroidBuilder(self).string_build("", modname)

    def _build_namespace_module(self, modname: str, path: List[str]) -> "nodes.Module":
        """Build a namespace-package module *modname* over the directories *path*."""
        # pylint: disable=import-outside-toplevel; circular import
        from astroid.builder import build_namespace_package_module

        return build_namespace_package_module(modname, path)

    def _can_load_extension(self, modname: str) -> bool:
        """Return whether the C extension *modname* may be imported and inspected.

        This is allowed when ``always_load_extensions`` is set, when the module
        is part of the standard library, or when it belongs to
        ``extension_package_whitelist``.
        """
        if self.always_load_extensions:
            return True
        if is_standard_module(modname):
            return True
        return is_module_name_part_of_extension_package_whitelist(
            modname, self.extension_package_whitelist
        )

    @staticmethod
    def _context_directory(context_file):
        """Return the directory to switch to while resolving imports of *context_file*.

        ``os.path.dirname`` returns ``""`` for a bare file name, and
        ``os.chdir("")`` raises. In that case the current directory is used.
        """
        return os.path.dirname(context_file) or os.curdir

    def ast_from_module_name(self, modname, context_file=None):
        """Given a module name, return the astroid object.

        :param modname: dotted name of the module to resolve.
        :param context_file: file the import is made from. Resolution is done
            relative to it, and the working directory is temporarily switched
            to its directory (and restored afterwards).
        :raises AstroidBuildingError: (or its subclass AstroidImportError) when
            the module cannot be found or built and no registered
            failed-import hook can provide it.
        """
        if modname in self.astroid_cache:
            return self.astroid_cache[modname]
        if modname == "__main__":
            return self._build_stub_module(modname)
        if context_file:
            old_cwd = os.getcwd()
            os.chdir(self._context_directory(context_file))
        try:
            found_spec = self.file_from_module_name(modname, context_file)
            if found_spec.type == spec.ModuleType.PY_ZIPMODULE:
                module = self.zip_import_data(found_spec.location)
                if module is not None:
                    return module

            elif found_spec.type in (
                spec.ModuleType.C_BUILTIN,
                spec.ModuleType.C_EXTENSION,
            ):
                if (
                    found_spec.type == spec.ModuleType.C_EXTENSION
                    and not self._can_load_extension(modname)
                ):
                    return self._build_stub_module(modname)
                try:
                    module = load_module_from_name(modname)
                except Exception as e:
                    raise AstroidImportError(
                        "Loading {modname} failed with:\n{error}",
                        modname=modname,
                        path=found_spec.location,
                        # The message formats {error}; it must be supplied.
                        error=e,
                    ) from e
                return self.ast_from_module(module, modname)

            elif found_spec.type == spec.ModuleType.PY_COMPILED:
                raise AstroidImportError(
                    "Unable to load compiled module {modname}.",
                    modname=modname,
                    path=found_spec.location,
                )

            elif found_spec.type == spec.ModuleType.PY_NAMESPACE:
                return self._build_namespace_module(
                    modname, found_spec.submodule_search_locations
                )
            elif found_spec.type == spec.ModuleType.PY_FROZEN:
                return self._build_stub_module(modname)

            if found_spec.location is None:
                raise AstroidImportError(
                    "Can't find a file for module {modname}.", modname=modname
                )

            return self.ast_from_file(found_spec.location, modname, fallback=False)
        except AstroidBuildingError as e:
            # Give registered hooks a chance to provide the module; the first
            # one that does not raise wins.
            for hook in self._failed_import_hooks:
                try:
                    return hook(modname)
                except AstroidBuildingError:
                    pass
            raise e
        finally:
            if context_file:
                os.chdir(old_cwd)

    def zip_import_data(self, filepath):
        """Build the module stored at *filepath* inside a zip-like archive.

        *filepath* is split at the last ``<ext><os.sep>`` occurrence for each
        extension in ``ZIP_IMPORT_EXTS``. The part before it names the archive,
        and the part after it names the resource inside the archive.

        :returns: the built module, or None if no extension matches or
            building fails for every candidate.
        """
        # pylint: disable=import-outside-toplevel; circular import
        from astroid.builder import AstroidBuilder

        builder = AstroidBuilder(self)
        for ext in ZIP_IMPORT_EXTS:
            try:
                eggpath, resource = filepath.rsplit(ext + os.path.sep, 1)
            except ValueError:
                continue
            try:
                importer = zipimport.zipimporter(eggpath + ext)
                zmodname = resource.replace(os.path.sep, ".")
                if importer.is_package(resource):
                    zmodname = zmodname + ".__init__"
                return builder.string_build(
                    importer.get_source(resource), zmodname, filepath
                )
            except Exception:  # pylint: disable=broad-except
                # Best effort: try the next archive extension.
                continue
        return None

    def file_from_module_name(self, modname, contextfile):
        """Return the module spec of *modname* as resolved from *contextfile*.

        Results, including failures, are memoized in ``_mod_file_cache`` under
        the key ``(modname, contextfile)``.

        :raises AstroidImportError: if the module cannot be located (the
            cached exception is re-raised on later calls).
        """
        try:
            value = self._mod_file_cache[(modname, contextfile)]
        except KeyError:
            try:
                value = file_info_from_modpath(
                    modname.split("."), context_file=contextfile
                )
            except ImportError as e:
                value = AstroidImportError(
                    "Failed to import module {modname} with error:\n{error}.",
                    modname=modname,
                    # we remove the traceback here to save on memory usage (since these exceptions are cached)
                    error=e.with_traceback(None),
                )
            self._mod_file_cache[(modname, contextfile)] = value
        if isinstance(value, AstroidBuildingError):
            # we remove the traceback here to save on memory usage (since these exceptions are cached)
            raise value.with_traceback(None)
        return value

    def ast_from_module(self, module: types.ModuleType, modname: str = None):
        """Given an imported module, return the astroid object.

        If the module was loaded from a Python source file, the tree is built
        from that file. Otherwise it is built by inspecting the live module.
        """
        modname = modname or module.__name__
        if modname in self.astroid_cache:
            return self.astroid_cache[modname]
        # Some builtin modules have no __file__ attribute, and it may also be
        # None. In both cases fall back to building from the live module.
        # Only the attribute lookup is guarded, so errors raised while
        # building from the file are not silently swallowed.
        filepath = getattr(module, "__file__", None)
        if filepath is not None and is_python_source(filepath):
            return self.ast_from_file(filepath, modname)

        # pylint: disable=import-outside-toplevel; circular import
        from astroid.builder import AstroidBuilder

        return AstroidBuilder(self).module_build(module, modname)

    def ast_from_class(self, klass, modname=None):
        """Get the astroid ClassDef for the live class *klass*.

        :param modname: module defining *klass*; defaults to ``klass.__module__``.
        :raises AstroidBuildingError: if the module cannot be determined.
        """
        if modname is None:
            try:
                modname = klass.__module__
            except AttributeError as exc:
                raise AstroidBuildingError(
                    # Only class_repr is supplied, so the message must use it.
                    "Unable to get module for class {class_repr}.",
                    cls=klass,
                    class_repr=safe_repr(klass),
                    modname=modname,
                ) from exc
        modastroid = self.ast_from_module_name(modname)
        return modastroid.getattr(klass.__name__)[0]  # XXX

    def infer_ast_from_something(self, obj, context=None):
        """Infer astroid nodes for the live object *obj*.

        If *obj* is a class, the inferred class nodes are yielded. Otherwise
        *obj*'s class is looked up and instances of the inferred classes are
        yielded.

        :raises AstroidBuildingError: if the class's module or name is missing.
        :raises AstroidImportError: if retrieving either raises something else.
        """
        if hasattr(obj, "__class__") and not isinstance(obj, type):
            klass = obj.__class__
        else:
            klass = obj
        try:
            modname = klass.__module__
        except AttributeError as exc:
            raise AstroidBuildingError(
                "Unable to get module for {class_repr}.",
                cls=klass,
                class_repr=safe_repr(klass),
            ) from exc
        except Exception as exc:
            raise AstroidImportError(
                "Unexpected error while retrieving module for {class_repr}:\n{error}",
                cls=klass,
                class_repr=safe_repr(klass),
                # The message formats {error}; it must be supplied.
                error=exc,
            ) from exc
        try:
            name = klass.__name__
        except AttributeError as exc:
            raise AstroidBuildingError(
                "Unable to get name for {class_repr}:\n",
                cls=klass,
                class_repr=safe_repr(klass),
            ) from exc
        except Exception as exc:
            raise AstroidImportError(
                "Unexpected error while retrieving name for {class_repr}:\n{error}",
                cls=klass,
                class_repr=safe_repr(klass),
                error=exc,
            ) from exc
        # take care, on living object __module__ is regularly wrong :(
        modastroid = self.ast_from_module_name(modname)
        if klass is obj:
            for inferred in modastroid.igetattr(name, context):
                yield inferred
        else:
            for inferred in modastroid.igetattr(name, context):
                yield inferred.instantiate_class()

    def register_failed_import_hook(self, hook):
        """Registers a hook to resolve imports that cannot be found otherwise.

        `hook` must be a function that accepts a single argument `modname` which
        contains the name of the module or package that could not be imported.
        If `hook` can resolve the import, must return a node of type `astroid.Module`,
        otherwise, it must raise `AstroidBuildingError`.
        """
        self._failed_import_hooks.append(hook)

    def cache_module(self, module):
        """Cache a module if no module with the same name is known yet."""
        self.astroid_cache.setdefault(module.name, module)

    def bootstrap(self):
        """Bootstrap the required AST modules needed for the manager to work

        The bootstrap usually involves building the AST for the builtins
        module, which is required by the rest of astroid to work correctly.
        """
        from astroid import raw_building  # pylint: disable=import-outside-toplevel

        raw_building._astroid_bootstrapping()

    def clear_cache(self):
        """Clear the underlying cache. Also bootstraps the builtins module."""
        self.astroid_cache.clear()
        self.bootstrap()
376