# ext/declarative/clsregistry.py
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Routines to handle the string class registry used by declarative.

This system allows specification of classes and expressions used in
:func:`_orm.relationship` using strings.

"""
import weakref

from . import attributes
from . import interfaces
from .descriptor_props import SynonymProperty
from .properties import ColumnProperty
from .util import class_mapper
from .. import exc
from .. import inspection
from .. import util
from ..sql.schema import _get_table_key

# strong references to registries which we place in
# the _decl_class_registry, which is usually weak referencing.
# the internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
_registries = set()


def add_class(classname, cls, decl_class_registry):
    """Add a class to the _decl_class_registry associated with the
    given declarative class.

    The class is stored under ``classname`` at the top level of the
    registry, and additionally under every suffix of its dotted
    ``__module__`` path inside the ``"_sa_module_registry"`` tree, which
    is created on first use.

    :param classname: string name the class is registered under.
    :param cls: the class being registered.
    :param decl_class_registry: dictionary-like registry to populate.

    """
    if classname in decl_class_registry:
        # class already exists.
        existing = decl_class_registry[classname]
        if not isinstance(existing, _MultipleClassMarker):
            decl_class_registry[classname] = _MultipleClassMarker(
                [cls, existing]
            )
        # NOTE(review): when the existing entry is already a
        # _MultipleClassMarker, ``cls`` is not added to it here; only the
        # per-module markers below receive it.  confirm this is intended.
    else:
        decl_class_registry[classname] = cls

    try:
        root_module = decl_class_registry["_sa_module_registry"]
    except KeyError:
        decl_class_registry[
            "_sa_module_registry"
        ] = root_module = _ModuleMarker("_sa_module_registry", None)

    tokens = cls.__module__.split(".")

    # build up a tree like this:
    # modulename:  myapp.snacks.nuts
    #
    # myapp->snack->nuts->(classes)
    # snack->nuts->(classes)
    # nuts->(classes)
    #
    # this allows partial token paths to be used.
    while tokens:
        token = tokens.pop(0)
        module = root_module.get_module(token)
        for sub_token in tokens:
            module = module.get_module(sub_token)
        module.add_class(classname, cls)


def remove_class(classname, cls, decl_class_registry):
    """Remove a class from the given registry.

    Reverses :func:`add_class`: the top level entry for ``classname`` is
    deleted (or, if it is a :class:`_MultipleClassMarker`, ``cls`` is
    removed from that marker), and ``cls`` is removed from each module
    path it was registered under.  If the registry has no
    ``"_sa_module_registry"`` entry, the module tree step is skipped.

    :param classname: string name the class was registered under.
    :param cls: the class being removed.
    :param decl_class_registry: dictionary-like registry to modify.

    """
    if classname in decl_class_registry:
        existing = decl_class_registry[classname]
        if isinstance(existing, _MultipleClassMarker):
            existing.remove_item(cls)
        else:
            del decl_class_registry[classname]

    try:
        root_module = decl_class_registry["_sa_module_registry"]
    except KeyError:
        return

    tokens = cls.__module__.split(".")

    # walk the same set of partial module paths that add_class() built.
    while tokens:
        token = tokens.pop(0)
        module = root_module.get_module(token)
        for sub_token in tokens:
            module = module.get_module(sub_token)
        module.remove_class(classname, cls)


def _key_is_empty(key, decl_class_registry, test):
    """test if a key is empty of a certain object.

    used for unit tests against the registry to see if garbage collection
    is working.

    "test" is a callable that will be passed an object should return True
    if the given object is the one we were looking for.

    We can't pass the actual object itself b.c. this is for testing garbage
    collection; the caller will have to have removed references to the
    object itself.

    :return: True if the key is absent or no entry under it satisfies
     ``test``; False otherwise.

    """
    if key not in decl_class_registry:
        return True

    thing = decl_class_registry[key]
    if isinstance(thing, _MultipleClassMarker):
        # note the members of .contents are the weakref.ref objects held
        # by the marker, and are passed to test() as such.
        for sub_thing in thing.contents:
            if test(sub_thing):
                return False
        # nothing matched; previously this fell through and returned
        # None, which reads as "not empty" to a caller checking truth.
        return True
    else:
        return not test(thing)


class _MultipleClassMarker(object):
    """refers to multiple classes of the same name
    within _decl_class_registry.

    Classes are held as weak references in ``contents``; each reference
    has :meth:`_remove_item` as its callback so that an entry is
    discarded when its class is garbage collected.

    """

    __slots__ = "on_remove", "contents", "__weakref__"

    def __init__(self, classes, on_remove=None):
        """Construct a marker for the given classes.

        :param classes: iterable of classes to refer to.
        :param on_remove: optional callable invoked with no arguments
         when the last class has been removed from this marker.

        """
        self.on_remove = on_remove
        self.contents = set(
            [weakref.ref(item, self._remove_item) for item in classes]
        )
        # keep a strong reference to this marker for as long as it
        # contains something; see _remove_item().
        _registries.add(self)

    def remove_item(self, cls):
        """Remove the given class from this marker."""
        # a weakref to a live object hashes and compares by its referent,
        # so a fresh ref matches the one stored in .contents.
        self._remove_item(weakref.ref(cls))

    def __iter__(self):
        """Iterate the referred-to classes, dereferencing each weakref.

        A reference whose class is gone yields ``None``.

        """
        return (ref() for ref in self.contents)

    def attempt_get(self, path, key):
        """Return the single class referred to by this marker.

        :param path: list of module tokens leading to this marker; used
         for the error message only.
        :param key: the class name being looked up.
        :raises InvalidRequestError: if more than one class is present.
        :raises NameError: if the single weak reference is dead.

        """
        if len(self.contents) > 1:
            raise exc.InvalidRequestError(
                'Multiple classes found for path "%s" '
                "in the registry of this declarative "
                "base. Please use a fully module-qualified path."
                % (".".join(path + [key]))
            )
        else:
            ref = list(self.contents)[0]
            cls = ref()
            if cls is None:
                raise NameError(key)
            return cls

    def _remove_item(self, ref):
        """Discard a weak reference; also serves as the weakref callback.

        When the marker becomes empty it drops itself from
        ``_registries`` and invokes ``on_remove`` if one was given.

        """
        self.contents.discard(ref)
        if not self.contents:
            _registries.discard(self)
            if self.on_remove:
                self.on_remove()

    def add_item(self, item):
        """Add a class to this marker.

        Emits a warning if a live class with the same ``__module__`` is
        already present.

        """
        # protect against class registration race condition against
        # asynchronous garbage collection calling _remove_item,
        # [ticket:3208]
        modules = set(
            [
                cls.__module__
                for cls in [ref() for ref in self.contents]
                if cls is not None
            ]
        )
        if item.__module__ in modules:
            util.warn(
                "This declarative base already contains a class with the "
                "same class name and module name as %s.%s, and will "
                "be replaced in the string-lookup table."
                % (item.__module__, item.__name__)
            )
        self.contents.add(weakref.ref(item, self._remove_item))


class _ModuleMarker(object):
    """Refers to a module name within
    _decl_class_registry.

    ``contents`` maps a name either to a child :class:`_ModuleMarker`
    or to a :class:`_MultipleClassMarker` holding classes of that name.

    """

    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"

    def __init__(self, name, parent):
        """Construct a marker.

        :param name: the module token this marker represents.
        :param parent: the enclosing :class:`_ModuleMarker`, or None for
         the root.

        """
        self.parent = parent
        self.name = name
        self.contents = {}
        self.mod_ns = _ModNS(self)
        if self.parent:
            self.path = self.parent.path + [self.name]
        else:
            # the root marker has an empty path; its own name is not
            # part of any dotted path.
            self.path = []
        _registries.add(self)

    def __contains__(self, name):
        return name in self.contents

    def __getitem__(self, name):
        return self.contents[name]

    def _remove_item(self, name):
        """Remove ``name`` from this marker.

        If the marker is then empty and has a parent, it removes itself
        from the parent and from ``_registries``.

        """
        self.contents.pop(name, None)
        if not self.contents and self.parent is not None:
            self.parent._remove_item(self.name)
            _registries.discard(self)

    def resolve_attr(self, key):
        """Look up ``key`` through this marker's :class:`_ModNS`."""
        return getattr(self.mod_ns, key)

    def get_module(self, name):
        """Return the child marker for ``name``, creating it if absent."""
        if name not in self.contents:
            marker = _ModuleMarker(name, self)
            self.contents[name] = marker
        else:
            marker = self.contents[name]
        return marker

    def add_class(self, name, cls):
        """Register ``cls`` under ``name`` within this module marker."""
        if name in self.contents:
            existing = self.contents[name]
            existing.add_item(cls)
        else:
            # when the new marker empties out, remove its key from here.
            self.contents[name] = _MultipleClassMarker(
                [cls], on_remove=lambda: self._remove_item(name)
            )

    def remove_class(self, name, cls):
        """Remove ``cls`` from the entry for ``name``, if one exists."""
        if name in self.contents:
            existing = self.contents[name]
            existing.remove_item(cls)


class _ModNS(object):
    """Attribute-access namespace over a :class:`_ModuleMarker`.

    Attribute access resolves to a nested :class:`_ModNS` for a
    sub-module name, or to a class for a class name.

    """

    __slots__ = ("__parent",)

    def __init__(self, parent):
        self.__parent = parent

    def __getattr__(self, key):
        try:
            value = self.__parent.contents[key]
        except KeyError:
            pass
        else:
            if value is not None:
                if isinstance(value, _ModuleMarker):
                    return value.mod_ns
                else:
                    assert isinstance(value, _MultipleClassMarker)
                    return value.attempt_get(self.__parent.path, key)
        raise AttributeError(
            "Module %r has no mapped classes "
            "registered under the name %r" % (self.__parent.name, key)
        )


class _GetColumns(object):
    """Attribute-access wrapper around a mapped class.

    Attribute access validates the name against the class's mapper
    before delegating to ``getattr()`` on the class itself.

    """

    __slots__ = ("cls",)

    def __init__(self, cls):
        self.cls = cls

    def __getattr__(self, key):
        mp = class_mapper(self.cls, configure=False)
        if mp:
            if key not in mp.all_orm_descriptors:
                raise AttributeError(
                    "Class %r does not have a mapped column named %r"
                    % (self.cls, key)
                )

            desc = mp.all_orm_descriptors[key]
            if desc.extension_type is interfaces.NOT_EXTENSION:
                prop = desc.property
                if isinstance(prop, SynonymProperty):
                    # follow the synonym to the name it points to.
                    key = prop.name
                elif not isinstance(prop, ColumnProperty):
                    raise exc.InvalidRequestError(
                        "Property %r is not an instance of"
                        " ColumnProperty (i.e. does not correspond"
                        " directly to a Column)." % key
                    )
        return getattr(self.cls, key)


# inspecting a _GetColumns inspects the class it wraps.
inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls)
)


class _GetTable(object):
    """Attribute-access wrapper that looks up tables in a MetaData.

    ``key`` is combined with the accessed attribute name via
    ``_get_table_key()`` to form the key into ``metadata.tables``.

    """

    __slots__ = "key", "metadata"

    def __init__(self, key, metadata):
        self.key = key
        self.metadata = metadata

    def __getattr__(self, key):
        return self.metadata.tables[_get_table_key(key, self.key)]


def _determine_container(key, value):
    """Wrap a registry entry in :class:`_GetColumns`.

    A :class:`_MultipleClassMarker` is first resolved to its single
    class via ``attempt_get()``, which raises if it is ambiguous.

    """
    if isinstance(value, _MultipleClassMarker):
        value = value.attempt_get([], key)
    return _GetColumns(value)


class _class_resolver(object):
    """Resolve a string argument against the class registry.

    Name lookup is performed lazily through ``_dict``, a
    ``util.PopulateDict`` that calls :meth:`_access_cls` for names it
    does not yet contain.

    """

    __slots__ = (
        "cls",
        "prop",
        "arg",
        "fallback",
        "_dict",
        "_resolvers",
        "favor_tables",
    )

    def __init__(self, cls, prop, fallback, arg, favor_tables=False):
        """Construct a resolver.

        :param cls: the mapped class whose registry is consulted.
        :param prop: the property being configured; used in error
         messages.
        :param fallback: mapping consulted last for a name.
        :param arg: the string to be resolved.
        :param favor_tables: if True, table names are tried before
         class names rather than after.

        """
        self.cls = cls
        self.prop = prop
        self.arg = arg
        self.fallback = fallback
        self._dict = util.PopulateDict(self._access_cls)
        self._resolvers = ()
        self.favor_tables = favor_tables

    def _access_cls(self, key):
        """Locate a single name.

        Lookup order: tables and schemas (first if ``favor_tables``,
        otherwise after the class registry), the class registry, the
        module registry, any ``_resolvers``, then ``fallback``.

        :raises KeyError: from ``fallback`` if nothing matched.

        """
        cls = self.cls

        manager = attributes.manager_of_class(cls)
        decl_base = manager.registry
        decl_class_registry = decl_base._class_registry
        metadata = decl_base.metadata

        if self.favor_tables:
            if key in metadata.tables:
                return metadata.tables[key]
            elif key in metadata._schemas:
                # use the registry's metadata when the class itself
                # carries no .metadata attribute.
                return _GetTable(key, getattr(cls, "metadata", metadata))

        if key in decl_class_registry:
            return _determine_container(key, decl_class_registry[key])

        if not self.favor_tables:
            if key in metadata.tables:
                return metadata.tables[key]
            elif key in metadata._schemas:
                return _GetTable(key, getattr(cls, "metadata", metadata))

        if (
            "_sa_module_registry" in decl_class_registry
            and key in decl_class_registry["_sa_module_registry"]
        ):
            registry = decl_class_registry["_sa_module_registry"]
            return registry.resolve_attr(key)
        elif self._resolvers:
            for resolv in self._resolvers:
                value = resolv(key)
                if value is not None:
                    return value

        return self.fallback[key]

    def _raise_for_name(self, name, err):
        """Raise InvalidRequestError for an unlocatable name, chained
        from ``err``."""
        util.raise_(
            exc.InvalidRequestError(
                "When initializing mapper %s, expression %r failed to "
                "locate a name (%r). If this is a class name, consider "
                "adding this relationship() to the %r class after "
                "both dependent classes have been defined."
                % (self.prop.parent, self.arg, name, self.cls)
            ),
            from_=err,
        )

    def _resolve_name(self):
        """Resolve ``arg`` as a plain dotted name, without ``eval()``.

        The first token is looked up in ``_dict``; each further token is
        an attribute access on the previous result.

        """
        name = self.arg
        d = self._dict
        rval = None
        try:
            for token in name.split("."):
                if rval is None:
                    rval = d[token]
                else:
                    rval = getattr(rval, token)
        except KeyError as err:
            self._raise_for_name(name, err)
        except NameError as n:
            self._raise_for_name(n.args[0], n)
        else:
            if isinstance(rval, _GetColumns):
                return rval.cls
            else:
                return rval

    def __call__(self):
        """Evaluate ``arg`` as a Python expression and return the result.

        A :class:`_GetColumns` result is unwrapped to its class.

        """
        try:
            # NOTE(review): ``arg`` is passed to eval(); this is only
            # safe if the string is developer-supplied and never derived
            # from untrusted input.  confirm against callers.
            x = eval(self.arg, globals(), self._dict)

            if isinstance(x, _GetColumns):
                return x.cls
            else:
                return x
        except NameError as n:
            self._raise_for_name(n.args[0], n)


# built on first use by _resolver(); see there.
_fallback_dict = None


def _resolver(cls, prop):
    """Return ``(resolve_name, resolve_arg)`` callables for a property.

    ``resolve_arg(arg, favor_tables=False)`` returns a
    :class:`_class_resolver`, which evaluates ``arg`` when called.
    ``resolve_name(arg)`` returns the bound ``_resolve_name`` method of
    such a resolver, which resolves a dotted name without ``eval()``.

    """

    global _fallback_dict

    if _fallback_dict is None:
        # imported here rather than at module level; the fallback
        # namespace is the sqlalchemy package namespace plus
        # foreign() / remote().
        import sqlalchemy
        from sqlalchemy.orm import foreign, remote

        _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union(
            {"foreign": foreign, "remote": remote}
        )

    def resolve_arg(arg, favor_tables=False):
        return _class_resolver(
            cls, prop, _fallback_dict, arg, favor_tables=favor_tables
        )

    def resolve_name(arg):
        return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name

    return resolve_name, resolve_arg