# orm/clsregistry.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7"""Routines to handle the string class registry used by declarative.
8
9This system allows specification of classes and expressions used in
10:func:`_orm.relationship` using strings.
11
12"""
13import weakref
14
15from . import attributes
16from . import interfaces
17from .descriptor_props import SynonymProperty
18from .properties import ColumnProperty
19from .util import class_mapper
20from .. import exc
21from .. import inspection
22from .. import util
23from ..sql.schema import _get_table_key
24
# Strong references to the _MultipleClassMarker and _ModuleMarker objects
# placed in a _decl_class_registry, which usually references its values only
# weakly; without this set the markers could be collected while in use.
# The markers refer to classes only through weakrefs and discard themselves
# from this set once everything they contain has been removed (the root
# _ModuleMarker, which has no parent, is never discarded).
_registries = set()
30
31
def add_class(classname, cls, decl_class_registry):
    """Add a class to the _decl_class_registry associated with the
    given declarative class.

    The class is registered in two places:

    * directly under ``classname``; if another class already uses that
      name, the entry becomes (or already is) a
      :class:`._MultipleClassMarker` tracking every class of that name.
    * in the module tree stored under ``"_sa_module_registry"``, beneath
      every trailing suffix of ``cls.__module__``, so that partially
      module-qualified paths can be used to disambiguate.

    :param classname: name under which the class is registered.
    :param cls: the class being registered.
    :param decl_class_registry: the class registry mapping.

    """
    if classname in decl_class_registry:
        # class already exists.
        existing = decl_class_registry[classname]
        if not isinstance(existing, _MultipleClassMarker):
            decl_class_registry[classname] = _MultipleClassMarker(
                [cls, existing]
            )
        else:
            # the marker already tracks other classes of this name; it
            # must track this one too, otherwise a lookup could resolve
            # to one of the others as if the name were unambiguous.
            existing.add_item(cls)
    else:
        decl_class_registry[classname] = cls

    try:
        root_module = decl_class_registry["_sa_module_registry"]
    except KeyError:
        decl_class_registry[
            "_sa_module_registry"
        ] = root_module = _ModuleMarker("_sa_module_registry", None)

    tokens = cls.__module__.split(".")

    # build up a tree like this:
    # modulename:  myapp.snacks.nuts
    #
    # myapp->snack->nuts->(classes)
    # snack->nuts->(classes)
    # nuts->(classes)
    #
    # this allows partial token paths to be used.
    for start in range(len(tokens)):
        module = root_module
        for token in tokens[start:]:
            module = module.get_module(token)
        module.add_class(classname, cls)
70
71
def remove_class(classname, cls, decl_class_registry):
    """Remove a class from the _decl_class_registry, reversing
    :func:`.add_class`.

    The top-level entry for ``classname`` is updated or deleted, and the
    class is removed from every module marker under which
    :func:`.add_class` would have registered it.

    :param classname: name the class was registered under.
    :param cls: the class being removed.
    :param decl_class_registry: the class registry mapping.

    """
    if classname in decl_class_registry:
        existing = decl_class_registry[classname]
        if isinstance(existing, _MultipleClassMarker):
            existing.remove_item(cls)
        elif existing is cls:
            # only delete the entry when it is this very class; a different
            # class registered under the same name must stay resolvable.
            del decl_class_registry[classname]

    try:
        root_module = decl_class_registry["_sa_module_registry"]
    except KeyError:
        return

    tokens = cls.__module__.split(".")

    # visit every trailing suffix of the module path, as add_class does.
    # Only markers that already exist are walked: get_module() would
    # create missing ones and leave empty markers behind.  A missing
    # segment means the class was never registered under that path.
    for start in range(len(tokens)):
        module = root_module
        for token in tokens[start:]:
            if token not in module:
                break
            module = module[token]
        else:
            module.remove_class(classname, cls)
93
94
def _key_is_empty(key, decl_class_registry, test):
    """test if a key is empty of a certain object.

    used for unit tests against the registry to see if garbage collection
    is working.

    ``test`` is a callable that is passed each object registered under
    ``key`` and should return True if it is the object being looked for.
    Classes held by a :class:`._MultipleClassMarker` that have already
    been garbage collected are skipped.

    We can't pass the actual object itself b.c. this is for testing garbage
    collection; the caller will have to have removed references to the
    object itself.

    :return: True if ``key`` is absent or nothing registered under it
     satisfies ``test``; False otherwise.

    """
    if key not in decl_class_registry:
        return True

    thing = decl_class_registry[key]
    if isinstance(thing, _MultipleClassMarker):
        # iterating the marker dereferences its weak references; entries
        # whose class has been collected come back as None.
        for sub_thing in thing:
            if sub_thing is not None and test(sub_thing):
                return False
        return True
    else:
        return not test(thing)
119
120
class _MultipleClassMarker(object):
    """refers to multiple classes of the same name
    within _decl_class_registry.

    Each class is held through a :func:`weakref.ref` whose callback
    removes it from ``contents``.  Once ``contents`` becomes empty the
    marker discards itself from ``_registries`` and calls ``on_remove``,
    if one was given.

    """

    __slots__ = "on_remove", "contents", "__weakref__"

    def __init__(self, classes, on_remove=None):
        self.on_remove = on_remove
        self.contents = set(
            [weakref.ref(item, self._remove_item) for item in classes]
        )
        # the mapping holding this marker may reference it only weakly;
        # _registries keeps it alive until it empties out.
        _registries.add(self)

    def remove_item(self, cls):
        """Stop tracking ``cls``; a no-op if it is not tracked."""
        # a fresh weakref to a live object compares equal to the stored
        # one, so it can be used to discard that entry from the set.
        self._remove_item(weakref.ref(cls))

    def __iter__(self):
        # yields None for any class that has already been collected
        return (ref() for ref in self.contents)

    def attempt_get(self, path, key):
        """Return the one class tracked by this marker.

        :param path: module path tokens, used in the error message.
        :param key: the class name being resolved.
        :raises InvalidRequestError: more than one class is tracked, so
         the name is ambiguous.
        :raises NameError: no live class is tracked.

        """
        if len(self.contents) > 1:
            raise exc.InvalidRequestError(
                'Multiple classes found for path "%s" '
                "in the registry of this declarative "
                "base. Please use a fully module-qualified path."
                % (".".join(path + [key]))
            )

        # an empty marker (all of its classes were removed) used to fail
        # with IndexError here; report it as a missing name instead, the
        # same as a reference whose class has been collected.
        ref = next(iter(self.contents), None)
        cls = ref() if ref is not None else None
        if cls is None:
            raise NameError(key)
        return cls

    def _remove_item(self, ref):
        # weakref callback as well as the implementation of remove_item()
        self.contents.discard(ref)
        if not self.contents:
            _registries.discard(self)
            if self.on_remove:
                self.on_remove()

    def add_item(self, item):
        """Start tracking ``item`` in addition to the current classes."""
        # protect against class registration race condition against
        # asynchronous garbage collection calling _remove_item,
        # [ticket:3208]
        modules = set(
            [
                cls.__module__
                for cls in [ref() for ref in self.contents]
                if cls is not None
            ]
        )
        if item.__module__ in modules:
            util.warn(
                "This declarative base already contains a class with the "
                "same class name and module name as %s.%s, and will "
                "be replaced in the string-lookup table."
                % (item.__module__, item.__name__)
            )
        self.contents.add(weakref.ref(item, self._remove_item))
183
184
class _ModuleMarker(object):
    """Refers to a module name within
    _decl_class_registry.

    ``contents`` maps names either to child :class:`._ModuleMarker`
    objects (submodules) or to :class:`._MultipleClassMarker` objects
    (classes defined under this module path).  ``path`` lists the names
    from the root marker down to this one, not including the root.

    """

    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"

    def __init__(self, name, parent):
        self.parent = parent
        self.name = name
        self.contents = {}
        self.mod_ns = _ModNS(self)
        self.path = parent.path + [name] if parent else []
        _registries.add(self)

    def __contains__(self, name):
        return name in self.contents

    def __getitem__(self, name):
        return self.contents[name]

    def _remove_item(self, name):
        self.contents.pop(name, None)
        if self.contents or self.parent is None:
            # still populated, or the root marker, which always stays
            return
        # nothing left under this module: detach from the parent, which
        # may in turn cascade upward
        self.parent._remove_item(self.name)
        _registries.discard(self)

    def resolve_attr(self, key):
        return getattr(self.mod_ns, key)

    def get_module(self, name):
        """Return the child entry ``name``, creating a submodule marker
        when there is none yet."""
        if name in self.contents:
            return self.contents[name]
        child = _ModuleMarker(name, self)
        self.contents[name] = child
        return child

    def add_class(self, name, cls):
        if name not in self.contents:
            # the marker drops itself from this module once it is empty
            self.contents[name] = _MultipleClassMarker(
                [cls], on_remove=lambda: self._remove_item(name)
            )
        else:
            self.contents[name].add_item(cls)

    def remove_class(self, name, cls):
        if name not in self.contents:
            return
        self.contents[name].remove_item(cls)
240
241
class _ModNS(object):
    """Attribute-style view over the contents of a :class:`._ModuleMarker`.

    An attribute naming a submodule returns that submodule's own
    ``_ModNS``; an attribute naming a class returns it via
    ``attempt_get``; anything else raises ``AttributeError``.

    """

    __slots__ = ("__parent",)

    def __init__(self, parent):
        self.__parent = parent

    def __getattr__(self, key):
        marker = self.__parent
        found = marker.contents.get(key)
        if found is None:
            raise AttributeError(
                "Module %r has no mapped classes "
                "registered under the name %r" % (marker.name, key)
            )
        if isinstance(found, _ModuleMarker):
            return found.mod_ns
        assert isinstance(found, _MultipleClassMarker)
        return found.attempt_get(marker.path, key)
264
265
class _GetColumns(object):
    """Proxy for a mapped class inside a string expression.

    Attribute access is checked against the class' mapper: the name must
    appear in ``all_orm_descriptors``.  A non-extension descriptor must be
    a :class:`.ColumnProperty` or a :class:`.SynonymProperty`; a synonym
    is followed to the attribute it names.  The attribute is then read
    from the class itself.

    """

    __slots__ = ("cls",)

    def __init__(self, cls):
        self.cls = cls

    def __getattr__(self, key):
        target = self.cls
        mapper = class_mapper(target, configure=False)
        if mapper:
            descriptors = mapper.all_orm_descriptors
            if key not in descriptors:
                raise AttributeError(
                    "Class %r does not have a mapped column named %r"
                    % (target, key)
                )

            descriptor = descriptors[key]
            is_plain_property = (
                descriptor.extension_type is interfaces.NOT_EXTENSION
            )
            if is_plain_property:
                prop = descriptor.property
                if isinstance(prop, SynonymProperty):
                    # follow the synonym to the attribute it mirrors
                    key = prop.name
                elif not isinstance(prop, ColumnProperty):
                    raise exc.InvalidRequestError(
                        "Property %r is not an instance of"
                        " ColumnProperty (i.e. does not correspond"
                        " directly to a Column)." % key
                    )
        return getattr(target, key)
293
294
@inspection._inspects(_GetColumns)
def _inspect_get_columns(target):
    # inspecting the proxy is the same as inspecting the class it wraps
    return inspection.inspect(target.cls)
298
299
class _GetTable(object):
    """Namespace for a schema name found in the MetaData.

    Within a string expression, ``<schema>.<name>`` becomes an attribute
    access on this object, which looks the table up in
    ``metadata.tables``; a missing table raises ``KeyError``.

    """

    __slots__ = "key", "metadata"

    def __init__(self, key, metadata):
        self.key = key
        self.metadata = metadata

    def __getattr__(self, key):
        # self.key is a name found in metadata._schemas; _get_table_key
        # presumably produces the schema-qualified key used by
        # metadata.tables
        qualified = _get_table_key(key, self.key)
        return self.metadata.tables[qualified]
309
310
def _determine_container(key, value):
    """Wrap a class registry entry in a :class:`._GetColumns` proxy.

    ``value`` is either a class or a :class:`._MultipleClassMarker`; a
    marker is first narrowed to its single class, which raises if the
    name is ambiguous or the class is gone.

    """
    cls = (
        value.attempt_get([], key)
        if isinstance(value, _MultipleClassMarker)
        else value
    )
    return _GetColumns(cls)
315
316
class _class_resolver(object):
    """Resolve a string argument given to :func:`_orm.relationship` and
    similar constructs.

    A top-level name is looked up in the registry's MetaData tables
    (first when ``favor_tables`` is set), the declarative class registry,
    the MetaData tables (otherwise), the module registry or any extra
    ``_resolvers``, and finally the ``fallback`` mapping.

    """

    __slots__ = (
        "cls",
        "prop",
        "arg",
        "fallback",
        "_dict",
        "_resolvers",
        "favor_tables",
    )

    def __init__(self, cls, prop, fallback, arg, favor_tables=False):
        self.cls = cls
        self.prop = prop
        self.arg = arg
        self.fallback = fallback
        # names missing from _dict are presumably filled in by calling
        # _access_cls (util.PopulateDict)
        self._dict = util.PopulateDict(self._access_cls)
        self._resolvers = ()
        self.favor_tables = favor_tables

    def _access_cls(self, key):
        """Resolve the single top-level name ``key``.

        If nothing matches, whatever ``self.fallback[key]`` raises
        propagates.

        """
        cls = self.cls

        manager = attributes.manager_of_class(cls)
        decl_base = manager.registry
        decl_class_registry = decl_base._class_registry
        metadata = decl_base.metadata

        if self.favor_tables:
            table = self._lookup_table(key, metadata)
            if table is not None:
                return table

        if key in decl_class_registry:
            return _determine_container(key, decl_class_registry[key])

        if not self.favor_tables:
            table = self._lookup_table(key, metadata)
            if table is not None:
                return table

        if (
            "_sa_module_registry" in decl_class_registry
            and key in decl_class_registry["_sa_module_registry"]
        ):
            registry = decl_class_registry["_sa_module_registry"]
            return registry.resolve_attr(key)
        elif self._resolvers:
            for resolv in self._resolvers:
                value = resolv(key)
                if value is not None:
                    return value

        return self.fallback[key]

    def _lookup_table(self, key, metadata):
        """Return the Table named ``key``, a :class:`._GetTable` namespace
        when ``key`` is a schema name, or None."""
        if key in metadata.tables:
            return metadata.tables[key]
        elif key in metadata._schemas:
            # the schema was found in the registry's MetaData; use that
            # when the class has no ``metadata`` attribute of its own
            # (reading ``cls.metadata`` unconditionally raised
            # AttributeError in that case).
            return _GetTable(key, getattr(self.cls, "metadata", metadata))
        return None

    def _raise_for_name(self, name, err):
        """Raise InvalidRequestError for an unresolvable ``name``,
        chained from ``err``."""
        util.raise_(
            exc.InvalidRequestError(
                "When initializing mapper %s, expression %r failed to "
                "locate a name (%r). If this is a class name, consider "
                "adding this relationship() to the %r class after "
                "both dependent classes have been defined."
                % (self.prop.parent, self.arg, name, self.cls)
            ),
            from_=err,
        )

    def _resolve_name(self):
        """Resolve ``self.arg`` as a dotted name without using ``eval``.

        The first token is looked up through ``_dict``; each later token
        is an attribute access on the previous result.  A
        :class:`._GetColumns` result is unwrapped to its class.

        """
        name = self.arg
        d = self._dict
        tokens = name.split(".")
        try:
            # an explicit first-token lookup, rather than a None sentinel,
            # so a first token that resolves to None is not misrouted
            rval = d[tokens[0]]
            for token in tokens[1:]:
                rval = getattr(rval, token)
        except KeyError as err:
            self._raise_for_name(name, err)
        except NameError as n:
            self._raise_for_name(n.args[0], n)
        else:
            if isinstance(rval, _GetColumns):
                return rval.cls
            else:
                return rval

    def __call__(self):
        """Evaluate ``self.arg`` as a Python expression, resolving free
        names through ``_dict``; a :class:`._GetColumns` result is
        unwrapped to its class."""
        # NOTE(security): self.arg is passed to eval(); it must come from
        # application code and never from untrusted input.
        try:
            x = eval(self.arg, globals(), self._dict)

            if isinstance(x, _GetColumns):
                return x.cls
            else:
                return x
        except NameError as n:
            self._raise_for_name(n.args[0], n)
416
417
# fallback namespace shared by every resolver; built lazily by _resolver()
_fallback_dict = None


def _resolver(cls, prop):
    """Return ``(resolve_name, resolve_arg)`` factories bound to ``cls``
    and ``prop``.

    * ``resolve_arg(arg, favor_tables=False)`` returns a
      :class:`._class_resolver`; calling it evaluates ``arg`` as a Python
      expression.
    * ``resolve_name(arg)`` returns the ``_resolve_name`` method of a new
      :class:`._class_resolver`, which resolves a dotted name without
      ``eval``.

    Names found nowhere else fall back to the ``sqlalchemy`` package
    namespace plus ``foreign`` and ``remote``.

    """
    global _fallback_dict

    if _fallback_dict is None:
        # imported here rather than at module level, presumably to avoid
        # an import cycle while the orm package loads
        import sqlalchemy
        from sqlalchemy.orm import foreign
        from sqlalchemy.orm import remote

        orm_helpers = {"foreign": foreign, "remote": remote}
        _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union(
            orm_helpers
        )

    def resolve_arg(arg, favor_tables=False):
        return _class_resolver(
            cls, prop, _fallback_dict, arg, favor_tables=favor_tables
        )

    def resolve_name(arg):
        name_resolver = _class_resolver(cls, prop, _fallback_dict, arg)
        return name_resolver._resolve_name

    return resolve_name, resolve_arg
442