"""Tests for extended class instrumentation (``sqlalchemy.ext.instrumentation``).

Covers user-defined instrumentation managers declared through
``__sa_instrumentation_manager__``, custom instrumentation finders, detection
of conflicting instrumentation implementations within a class hierarchy, and
custom ``InstanceEvents`` dispatch on a ``ClassManager`` subclass.
"""
from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, ne_
from sqlalchemy import util
import sqlalchemy as sa
from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import attributes
from sqlalchemy.orm.attributes import set_attribute, \
    get_attribute, del_attribute
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.orm import clear_mappers
from sqlalchemy.testing import fixtures
from sqlalchemy.ext import instrumentation
from sqlalchemy.orm.instrumentation import register_class, manager_of_class
from sqlalchemy.testing.util import decorator
from sqlalchemy.orm import events
from sqlalchemy import event


@decorator
def modifies_instrumentation_finders(fn, *args, **kw):
    """Run ``fn``, then restore ``instrumentation.instrumentation_finders``.

    Tests wrapped with this decorator may insert their own finders. The
    list's contents are snapshotted before the test runs and put back in a
    ``finally`` block, so a failing test cannot leak its finders into later
    tests.
    """
    pristine = instrumentation.instrumentation_finders[:]
    try:
        fn(*args, **kw)
    finally:
        # Restore in place (clear, then extend) rather than rebinding the
        # name, so the same list object stays installed on the module.
        del instrumentation.instrumentation_finders[:]
        instrumentation.instrumentation_finders.extend(pristine)


class _ExtBase(object):
    """Mixin for test classes that exercise extended instrumentation."""

    @classmethod
    def teardown_class(cls):
        """Reinstall the default instrumentation lookups after the class runs.

        NOTE(review): presumably this undoes the global lookup replacement
        made once extended instrumentation is in use; confirm against
        ``sqlalchemy.ext.instrumentation._reinstall_default_lookups``.
        """
        instrumentation._reinstall_default_lookups()


class MyTypesManager(instrumentation.InstrumentationManager):
    """Custom instrumentation manager used by ``MyClass`` and its subclasses.

    - Installs no attribute descriptors: the descriptor hooks are no-ops.
    - Stores instrumented attribute values in a per-instance ``_goofy_dict``.
    - Keeps the instance state under ``__dict__['_my_state']``.
    - Substitutes ``MyListLike`` for every collection class.
    """

    def instrument_attribute(self, class_, key, attr):
        # No-op. MyClass reaches instrumented attributes through its
        # __getattr__/__setattr__/__delattr__ overrides instead.
        pass

    def install_descriptor(self, class_, key, attr):
        # No-op: nothing is placed on the class for ``key``.
        pass

    def uninstall_descriptor(self, class_, key):
        # No-op, mirroring install_descriptor.
        pass

    def instrument_collection_class(self, class_, key, collection_class):
        """Return ``MyListLike`` for every collection, ignoring the input."""
        return MyListLike

    def get_instance_dict(self, class_, instance):
        """Return ``_goofy_dict``, which holds the instrumented values."""
        return instance._goofy_dict

    def initialize_instance_dict(self, class_, instance):
        """Create an empty ``_goofy_dict`` on ``instance``."""
        # Write straight into __dict__. Going through MyClass.__setattr__
        # would try to store into _goofy_dict, which does not exist yet.
        instance.__dict__['_goofy_dict'] = {}

    def install_state(self, class_, instance, state):
        """Store the instance state under ``__dict__['_my_state']``."""
        instance.__dict__['_my_state'] = state

    def state_getter(self, class_):
        """Return a callable that reads back the state stored by install_state."""
        return lambda instance: instance.__dict__['_my_state']


class MyListLike(list):
    """Minimal ``list``-based collection class for the ORM.

    It declares the ``_sa_*`` collection hooks directly instead of using
    decorators. ``append`` and ``remove`` are aliases of the hooks, so plain
    list calls fire the adapter's events.

    NOTE(review): ``self._sa_adapter`` is never set in this class; it is
    assumed to be attached by SQLAlchemy's collection machinery. Confirm.
    """
    # add @appender, @remover decorators as needed
    _sa_iterator = list.__iter__
    _sa_linker = None
    _sa_converter = None

    def _sa_appender(self, item, _sa_initiator=None):
        """Fire the append event (unless suppressed), then append ``item``."""
        # An initiator of exactly False suppresses the event; the list is
        # mutated either way.
        if _sa_initiator is not False:
            self._sa_adapter.fire_append_event(item, _sa_initiator)
        list.append(self, item)
    append = _sa_appender

    def _sa_remover(self, item, _sa_initiator=None):
        """Fire the remove events (unless suppressed), then remove ``item``."""
        # The pre-remove event fires unconditionally, even when the remove
        # event itself is suppressed by _sa_initiator=False.
        self._sa_adapter.fire_pre_remove_event(_sa_initiator)
        if _sa_initiator is not False:
            self._sa_adapter.fire_remove_event(item, _sa_initiator)
        list.remove(self, item)
    remove = _sa_remover


# Placeholders. UserDefinedExtensionTest.setup_class() creates the real
# classes and binds them to these module globals.
MyBaseClass, MyClass = None, None


class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
    """Attribute instrumentation under three kinds of base class.

    Most tests loop over these bases:

    - plain ``object`` (default instrumentation);
    - ``MyBaseClass`` (stock ``InstrumentationManager``);
    - ``MyClass`` (``MyTypesManager``, with custom dict and state storage).
    """

    @classmethod
    def setup_class(cls):
        """Define ``MyBaseClass`` and ``MyClass`` as module globals."""
        global MyBaseClass, MyClass

        class MyBaseClass(object):
            """Base whose subclasses use the stock InstrumentationManager."""
            __sa_instrumentation_manager__ = \
                instrumentation.InstrumentationManager

        class MyClass(object):
            """Base whose subclasses are instrumented by ``MyTypesManager``.

            Attribute access is routed by hand: instrumented keys go through
            the ``attributes`` get/set/del API, and anything else is kept in
            ``_goofy_dict``.
            """

            # This proves that a staticmethod will work here; don't
            # flatten this back to a class assignment!
            def __sa_instrumentation_manager__(cls):
                return MyTypesManager(cls)

            __sa_instrumentation_manager__ = staticmethod(
                __sa_instrumentation_manager__)

            # This proves SA can handle a class with non-string dict keys
            # (skipped on PyPy and Jython).
            if not util.pypy and not util.jython:
                locals()[42] = 99  # Don't remove this line!

            def __init__(self, **kwargs):
                # Every keyword goes through __setattr__ below.
                # NOTE(review): for non-instrumented keys this writes into
                # _goofy_dict, so it assumes the manager's
                # initialize_instance_dict() has already run. Confirm.
                for k in kwargs:
                    setattr(self, k, kwargs[k])

            def __getattr__(self, key):
                # Only reached when normal lookup fails. A miss in
                # _goofy_dict is re-raised as AttributeError so getattr()
                # and hasattr() behave normally.
                if is_instrumented(self, key):
                    return get_attribute(self, key)
                else:
                    try:
                        return self._goofy_dict[key]
                    except KeyError:
                        raise AttributeError(key)

            def __setattr__(self, key, value):
                # Instrumented keys go through set_attribute so the ORM
                # tracks the change (see test_history); anything else goes
                # into _goofy_dict.
                if is_instrumented(self, key):
                    set_attribute(self, key, value)
                else:
                    self._goofy_dict[key] = value

            def __hasattr__(self, key):
                # NOTE(review): Python has no __hasattr__ protocol (hasattr()
                # goes through __getattr__). This appears to run only if
                # called explicitly.
                if is_instrumented(self, key):
                    return True
                else:
                    return key in self._goofy_dict

            def __delattr__(self, key):
                # Mirrors __setattr__: instrumented keys go through
                # del_attribute, everything else is removed from _goofy_dict.
                if is_instrumented(self, key):
                    del_attribute(self, key)
                else:
                    del self._goofy_dict[key]

    def teardown(self):
        """Clear all mappers after each test."""
        clear_mappers()

    def test_instance_dict(self):
        """Instrumented values live in ``_goofy_dict``, not directly in ``__dict__``."""
        class User(MyClass):
            pass

        register_class(User)
        attributes.register_attribute(
            User, 'user_id', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'user_name', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'email_address', uselist=False, useobject=False)

        u = User()
        u.user_id = 7
        u.user_name = 'john'
        u.email_address = 'lala@123.com'
        # __dict__ holds only the two entries MyTypesManager installs.
        eq_(
            u.__dict__,
            {
                '_my_state': u._my_state,
                '_goofy_dict': {
                    'user_id': 7, 'user_name': 'john',
                    'email_address': 'lala@123.com'}}
        )

    def test_basic(self):
        """Set, commit and then modify scalar attributes under each base."""
        for base in (object, MyBaseClass, MyClass):
            class User(base):
                pass

            register_class(User)
            attributes.register_attribute(
                User, 'user_id', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'user_name', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'email_address', uselist=False, useobject=False)

            u = User()
            u.user_id = 7
            u.user_name = 'john'
            u.email_address = 'lala@123.com'

            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")
            # Committing must leave the values readable and unchanged.
            attributes.instance_state(u)._commit_all(
                attributes.instance_dict(u))
            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")

            u.user_name = 'heythere'
            u.email_address = 'foo@bar.com'
            eq_(u.user_id, 7)
            eq_(u.user_name, "heythere")
            eq_(u.email_address, "foo@bar.com")

    def test_deferred(self):
        """Expired attributes are reloaded through ``deferred_scalar_loader``."""
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            data = {'a': 'this is a', 'b': 12}

            def loader(state, keys):
                # Copy the requested keys from ``data`` into the state's
                # dict and report that the values were set.
                for k in keys:
                    state.dict[k] = data[k]
                return attributes.ATTR_WAS_SET

            manager = register_class(Foo)
            manager.deferred_scalar_loader = loader
            attributes.register_attribute(
                Foo, 'a', uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, 'b', uselist=False, useobject=False)

            # Only classes with a custom instrumentation manager get an entry
            # in the factory's _state_finders; plain object subclasses don't.
            if base is object:
                assert Foo not in \
                    instrumentation._instrumentation_factory._state_finders
            else:
                assert Foo in \
                    instrumentation._instrumentation_factory._state_finders

            # A freshly expired instance loads both values from the loader.
            f = Foo()
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # Expiring discards the local change; the loader value returns.
            f.a = "this is some new a"
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # A value set after expiry takes precedence over the loader.
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            f.a = "this is another new a"
            eq_(f.a, "this is another new a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # A deleted attribute reads back as None...
            del f.a
            eq_(f.a, None)
            eq_(f.b, 12)

            # ...and it is still None after committing.
            attributes.instance_state(f)._commit_all(
                attributes.instance_dict(f))
            eq_(f.a, None)
            eq_(f.b, 12)

    def test_inheritance(self):
        """tests that attributes are polymorphic

        Bar overrides Foo's ``element`` loader callable and inherits
        ``element2`` unchanged.
        """

        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(Foo):
                pass

            register_class(Foo)
            register_class(Bar)

            # Loader callables: their return value becomes the attribute
            # value the first time it is read (see the asserts below).
            def func1(state, passive):
                return "this is the foo attr"

            def func2(state, passive):
                return "this is the bar attr"

            def func3(state, passive):
                return "this is the shared attr"
            attributes.register_attribute(Foo, 'element',
                                          uselist=False, callable_=func1,
                                          useobject=True)
            attributes.register_attribute(Foo, 'element2',
                                          uselist=False, callable_=func3,
                                          useobject=True)
            attributes.register_attribute(Bar, 'element',
                                          uselist=False, callable_=func2,
                                          useobject=True)

            x = Foo()
            y = Bar()
            assert x.element == 'this is the foo attr'
            assert y.element == 'this is the bar attr', y.element
            assert x.element2 == 'this is the shared attr'
            assert y.element2 == 'this is the shared attr'

    def test_collection_with_backref(self):
        """``Blog.posts`` and ``Post.blog`` are kept in sync through backrefs."""
        for base in (object, MyBaseClass, MyClass):
            class Post(base):
                pass

            class Blog(base):
                pass

            register_class(Post)
            register_class(Blog)
            attributes.register_attribute(
                Post, 'blog', uselist=False,
                backref='posts', trackparent=True, useobject=True)
            attributes.register_attribute(
                Blog, 'posts', uselist=True,
                backref='blog', trackparent=True, useobject=True)
            b = Blog()
            (p1, p2, p3) = (Post(), Post(), Post())
            # Appending to the collection sets the scalar side.
            b.posts.append(p1)
            b.posts.append(p2)
            b.posts.append(p3)
            self.assert_(b.posts == [p1, p2, p3])
            self.assert_(p2.blog is b)

            # Clearing or setting the scalar side updates the collection.
            p3.blog = None
            self.assert_(b.posts == [p1, p2])
            p4 = Post()
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # Re-assigning the same parent must not duplicate the child.
            p4.blog = b
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # assert no failure removing None
            p5 = Post()
            p5.blog = None
            del p5.blog

    def test_history(self):
        """``get_state_history()`` tracks scalar and collection changes.

        Per the assertions, each history tuple reads as
        (added, unchanged, deleted).
        """
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(base):
                pass

            register_class(Foo)
            register_class(Bar)
            attributes.register_attribute(
                Foo, "name", uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, "bars", uselist=True, trackparent=True, useobject=True)
            attributes.register_attribute(
                Bar, "name", uselist=False, useobject=False)

            # Before any commit, new values show up as "added".
            f1 = Foo()
            f1.name = 'f1'

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1'], (), ()))

            b1 = Bar()
            b1.name = 'b1'
            f1.bars.append(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b1], [], []))

            # After a commit, the same values show up as "unchanged".
            attributes.instance_state(f1)._commit_all(
                attributes.instance_dict(f1))
            attributes.instance_state(b1)._commit_all(
                attributes.instance_dict(b1))

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'name'),
                ((), ['f1'], ()))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'bars'),
                ((), [b1], ()))

            # Replacing a scalar records the old value as "deleted";
            # appending records "added" alongside the committed members.
            f1.name = 'f1mod'
            b2 = Bar()
            b2.name = 'b2'
            f1.bars.append(b2)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1mod'], (), ['f1']))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [b1], []))
            # Removing a committed member moves it to "deleted".
            f1.bars.remove(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [], [b1]))

    def test_null_instrumentation(self):
        """Under the stock InstrumentationManager, class attributes equal the
        manager's own entries for those keys."""
        class Foo(MyBaseClass):
            pass
        register_class(Foo)
        attributes.register_attribute(
            Foo, "name", uselist=False, useobject=False)
        attributes.register_attribute(
            Foo, "bars", uselist=True, trackparent=True, useobject=True)

        assert Foo.name == attributes.manager_of_class(Foo)['name']
        assert Foo.bars == attributes.manager_of_class(Foo)['bars']

    def test_alternate_finders(self):
        """Ensure the generic finder front-end deals with edge cases."""

        class Unknown(object):
            pass

        class Known(MyBaseClass):
            pass

        register_class(Known)
        k, u = Known(), Unknown()

        # manager_of_class() returns None for unregistered classes and for
        # None itself, rather than raising.
        assert instrumentation.manager_of_class(Unknown) is None
        assert instrumentation.manager_of_class(Known) is not None
        assert instrumentation.manager_of_class(None) is None

        # instance_state() raises AttributeError or KeyError for objects
        # with no instrumentation, including None.
        assert attributes.instance_state(k) is not None
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, u)
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, None)

    def test_unmapped_not_type_error(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            "Class object expected, got '5'.",
            class_mapper, 5
        )

    def test_unmapped_not_type_error_iter_ok(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            r"Class object expected, got '\(5, 6\)'.",
            class_mapper, (5, 6)
        )


class FinderTest(_ExtBase, fixtures.ORMTest):
    """How ``register_class()`` picks the manager type for a class.

    The sources tested are the default, a class-level
    ``__sa_instrumentation_manager__``, and a custom function in
    ``instrumentation.instrumentation_finders``.
    """

    def test_standard(self):
        """A plain class gets exactly ``ClassManager``."""
        class A(object):
            pass

        register_class(A)

        eq_(
            type(manager_of_class(A)),
            instrumentation.ClassManager)

    def test_nativeext_interfaceexact(self):
        """Declaring the InstrumentationManager interface yields a manager
        whose type is not exactly ``ClassManager``.

        NOTE(review): presumably an adapter wrapping the interface; confirm.
        """
        class A(object):
            __sa_instrumentation_manager__ = \
                instrumentation.InstrumentationManager

        register_class(A)
        ne_(
            type(manager_of_class(A)),
            instrumentation.ClassManager)

    def test_nativeext_submanager(self):
        """A ``ClassManager`` subclass named on the class is used as-is."""
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            __sa_instrumentation_manager__ = Mine

        register_class(A)
        eq_(type(manager_of_class(A)), Mine)

    @modifies_instrumentation_finders
    def test_customfinder_greedy(self):
        """A finder at the front of the list that always answers wins."""
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            pass

        def find(cls):
            return Mine

        instrumentation.instrumentation_finders.insert(0, find)
        register_class(A)
        eq_(type(manager_of_class(A)), Mine)

    @modifies_instrumentation_finders
    def test_customfinder_pass(self):
        """A finder that returns None falls through to ``ClassManager``."""
        class A(object):
            pass

        def find(cls):
            return None

        instrumentation.instrumentation_finders.insert(0, find)
        register_class(A)

        eq_(
            type(manager_of_class(A)),
            instrumentation.ClassManager)


class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
    """``register_class()`` rejects hierarchies that mix instrumentation
    implementations, raising ``TypeError``."""

    def test_none(self):
        """Unrelated classes may each use a different manager source."""
        class A(object):
            pass
        register_class(A)

        mgr_factory = lambda cls: instrumentation.ClassManager(cls)

        class B(object):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
        register_class(B)

        class C(object):
            __sa_instrumentation_manager__ = instrumentation.ClassManager
        register_class(C)

    def test_single_down(self):
        """Base registered with the default; subclass with a custom factory."""
        class A(object):
            pass
        register_class(A)

        mgr_factory = lambda cls: instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B)

    def test_single_up(self):
        """Subclass registered first with a custom factory; then its base."""

        class A(object):
            pass
        # delay registration

        mgr_factory = lambda cls: instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
        register_class(B)

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, A)

    def test_diamond_b1(self):
        """Registering B1 conflicts with its sibling B2's declared manager.

        Nothing is registered beforehand, and B2 is not in B1's MRO. The
        assertion therefore shows that the conflict check also looks at
        classes sharing base A. C is unrelated and never registered here.
        """
        mgr_factory = lambda cls: instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B1)

    def test_diamond_b2(self):
        """With B2 registered under its custom factory, B1 then conflicts."""
        mgr_factory = lambda cls: instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        register_class(B2)
        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B1)

    def test_diamond_c_b(self):
        """Registering the unrelated class C succeeds; B1 still conflicts."""
        mgr_factory = lambda cls: instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        register_class(C)

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B1)


class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):

    """Allow custom Events implementations."""

    @modifies_instrumentation_finders
    def test_subclassed(self):
        """A ``ClassManager`` subclass can supply its own ``InstanceEvents``
        subclass through its ``dispatch`` attribute."""
        class MyEvents(events.InstanceEvents):
            pass

        class MyClassManager(instrumentation.ClassManager):
            dispatch = event.dispatcher(MyEvents)

        # Route every class to MyClassManager via a front-of-list finder.
        instrumentation.instrumentation_finders.insert(
            0, lambda cls: MyClassManager)

        class A(object):
            pass

        register_class(A)
        manager = instrumentation.manager_of_class(A)
        assert issubclass(manager.dispatch._events, MyEvents)