import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy import util
from sqlalchemy.ext import instrumentation
from sqlalchemy.orm import attributes
from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import events
from sqlalchemy.orm.attributes import del_attribute
from sqlalchemy.orm.attributes import get_attribute
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.orm.instrumentation import manager_of_class
from sqlalchemy.orm.instrumentation import register_class
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import ne_
from sqlalchemy.testing.util import decorator


@decorator
def modifies_instrumentation_finders(fn, *args, **kw):
    """Run ``fn`` and then restore ``instrumentation_finders``.

    A copy of ``instrumentation.instrumentation_finders`` is taken before
    ``fn`` runs and written back in place afterwards, whether or not ``fn``
    raises.  The result of ``fn`` is returned to the caller.
    """
    pristine = instrumentation.instrumentation_finders[:]
    try:
        return fn(*args, **kw)
    finally:
        # Restore in place so that the module-level list object itself is
        # kept; only its contents are reset.
        instrumentation.instrumentation_finders[:] = pristine


class _ExtBase(object):
    """Mixin that reinstalls the default lookups on class teardown."""

    @classmethod
    def teardown_class(cls):
        instrumentation._reinstall_default_lookups()


class MyTypesManager(instrumentation.InstrumentationManager):
    """Instrumentation manager that keeps its data in unusual places.

    Attribute values live in ``instance._goofy_dict`` and the state object
    is stored under ``instance.__dict__["_my_state"]``.  The attribute and
    descriptor hooks are deliberately no-ops.
    """

    def instrument_attribute(self, class_, key, attr):
        pass

    def install_descriptor(self, class_, key, attr):
        pass

    def uninstall_descriptor(self, class_, key):
        pass

    def instrument_collection_class(self, class_, key, collection_class):
        # The requested collection class is ignored; MyListLike is always
        # used.
        return MyListLike

    def get_instance_dict(self, class_, instance):
        return instance._goofy_dict

    def initialize_instance_dict(self, class_, instance):
        # Written through __dict__ directly rather than via setattr().
        instance.__dict__["_goofy_dict"] = {}

    def install_state(self, class_, instance, state):
        instance.__dict__["_my_state"] = state

    def state_getter(self, class_):
        return lambda instance: instance.__dict__["_my_state"]


class MyListLike(list):
    """``list`` subclass whose append/remove fire collection events.

    Events are sent through ``self._sa_adapter``.  Passing
    ``_sa_initiator=False`` skips the append/remove event; the pre-remove
    event is fired unconditionally.
    """

    # add @appender, @remover decorators as needed
    _sa_iterator = list.__iter__
    _sa_linker = None
    _sa_converter = None

    def _sa_appender(self, item, _sa_initiator=None):
        if _sa_initiator is not False:
            self._sa_adapter.fire_append_event(item, _sa_initiator)
        list.append(self, item)

    append = _sa_appender

    def _sa_remover(self, item, _sa_initiator=None):
        self._sa_adapter.fire_pre_remove_event(_sa_initiator)
        if _sa_initiator is not False:
            self._sa_adapter.fire_remove_event(item, _sa_initiator)
        list.remove(self, item)

    remove = _sa_remover


# Rebound by UserDefinedExtensionTest.setup_class() via ``global``.
MyBaseClass, MyClass = None, None


class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
    """Attribute instrumentation tests run against several base classes.

    Most tests loop over ``object``, ``MyBaseClass`` (which names the stock
    ``InstrumentationManager``) and ``MyClass`` (which names
    ``MyTypesManager``).
    """

    @classmethod
    def setup_class(cls):
        """Define the module-level ``MyBaseClass`` and ``MyClass``."""
        global MyBaseClass, MyClass

        class MyBaseClass(object):
            __sa_instrumentation_manager__ = (
                instrumentation.InstrumentationManager
            )

        class MyClass(object):

            # This proves that a staticmethod will work here; don't
            # flatten this back to a class assignment!
            def __sa_instrumentation_manager__(cls):
                return MyTypesManager(cls)

            __sa_instrumentation_manager__ = staticmethod(
                __sa_instrumentation_manager__
            )

            # This proves SA can handle a class with non-string dict keys
            if not util.pypy and not util.jython:
                locals()[42] = 99  # Don't remove this line!

            def __init__(self, **kwargs):
                for k in kwargs:
                    setattr(self, k, kwargs[k])

            def __getattr__(self, key):
                if is_instrumented(self, key):
                    return get_attribute(self, key)
                else:
                    try:
                        return self._goofy_dict[key]
                    except KeyError:
                        raise AttributeError(key)

            def __setattr__(self, key, value):
                if is_instrumented(self, key):
                    set_attribute(self, key, value)
                else:
                    self._goofy_dict[key] = value

            # NOTE(review): Python defines no ``__hasattr__`` special
            # method; the built-in hasattr() goes through attribute
            # lookup instead, so this is only reachable by an explicit
            # call.  Kept unchanged.
            def __hasattr__(self, key):
                if is_instrumented(self, key):
                    return True
                else:
                    return key in self._goofy_dict

            def __delattr__(self, key):
                if is_instrumented(self, key):
                    del_attribute(self, key)
                else:
                    del self._goofy_dict[key]

    def teardown(self):
        clear_mappers()

    def test_instance_dict(self):
        """Instrumented values land in ``_goofy_dict``, not ``__dict__``."""

        class User(MyClass):
            pass

        register_class(User)
        attributes.register_attribute(
            User, "user_id", uselist=False, useobject=False
        )
        attributes.register_attribute(
            User, "user_name", uselist=False, useobject=False
        )
        attributes.register_attribute(
            User, "email_address", uselist=False, useobject=False
        )

        u = User()
        u.user_id = 7
        u.user_name = "john"
        u.email_address = "lala@123.com"
        eq_(
            u.__dict__,
            {
                "_my_state": u._my_state,
                "_goofy_dict": {
                    "user_id": 7,
                    "user_name": "john",
                    "email_address": "lala@123.com",
                },
            },
        )

    def test_basic(self):
        """Scalar attributes read back before and after ``_commit_all``."""
        for base in (object, MyBaseClass, MyClass):

            class User(base):
                pass

            register_class(User)
            attributes.register_attribute(
                User, "user_id", uselist=False, useobject=False
            )
            attributes.register_attribute(
                User, "user_name", uselist=False, useobject=False
            )
            attributes.register_attribute(
                User, "email_address", uselist=False, useobject=False
            )

            u = User()
            u.user_id = 7
            u.user_name = "john"
            u.email_address = "lala@123.com"

            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")
            attributes.instance_state(u)._commit_all(
                attributes.instance_dict(u)
            )
            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")

            u.user_name = "heythere"
            u.email_address = "foo@bar.com"
            eq_(u.user_id, 7)
            eq_(u.user_name, "heythere")
            eq_(u.email_address, "foo@bar.com")

    def test_deferred(self):
        """Expired attributes are re-populated through the loader."""
        for base in (object, MyBaseClass, MyClass):

            class Foo(base):
                pass

            data = {"a": "this is a", "b": 12}

            def loader(state, keys):
                for k in keys:
                    state.dict[k] = data[k]
                return attributes.ATTR_WAS_SET

            manager = register_class(Foo)
            manager.deferred_scalar_loader = loader
            attributes.register_attribute(
                Foo, "a", uselist=False, useobject=False
            )
            attributes.register_attribute(
                Foo, "b", uselist=False, useobject=False
            )

            if base is object:
                assert Foo not in (
                    instrumentation._instrumentation_factory._state_finders
                )
            else:
                assert Foo in (
                    instrumentation._instrumentation_factory._state_finders
                )

            f = Foo()
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            f.a = "this is some new a"
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            f.a = "this is another new a"
            eq_(f.a, "this is another new a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            del f.a
            eq_(f.a, None)
            eq_(f.b, 12)

            attributes.instance_state(f)._commit_all(
                attributes.instance_dict(f)
            )
            eq_(f.a, None)
            eq_(f.b, 12)

    def test_inheritance(self):
        """tests that attributes are polymorphic"""

        for base in (object, MyBaseClass, MyClass):

            class Foo(base):
                pass

            class Bar(Foo):
                pass

            register_class(Foo)
            register_class(Bar)

            def func1(state, passive):
                return "this is the foo attr"

            def func2(state, passive):
                return "this is the bar attr"

            def func3(state, passive):
                return "this is the shared attr"

            attributes.register_attribute(
                Foo, "element", uselist=False, callable_=func1, useobject=True
            )
            attributes.register_attribute(
                Foo, "element2", uselist=False, callable_=func3, useobject=True
            )
            attributes.register_attribute(
                Bar, "element", uselist=False, callable_=func2, useobject=True
            )

            x = Foo()
            y = Bar()
            assert x.element == "this is the foo attr"
            assert y.element == "this is the bar attr", y.element
            assert x.element2 == "this is the shared attr"
            assert y.element2 == "this is the shared attr"

    def test_collection_with_backref(self):
        """``Blog.posts`` and ``Post.blog`` stay in sync via backrefs."""
        for base in (object, MyBaseClass, MyClass):

            class Post(base):
                pass

            class Blog(base):
                pass

            register_class(Post)
            register_class(Blog)
            attributes.register_attribute(
                Post,
                "blog",
                uselist=False,
                backref="posts",
                trackparent=True,
                useobject=True,
            )
            attributes.register_attribute(
                Blog,
                "posts",
                uselist=True,
                backref="blog",
                trackparent=True,
                useobject=True,
            )
            b = Blog()
            (p1, p2, p3) = (Post(), Post(), Post())
            b.posts.append(p1)
            b.posts.append(p2)
            b.posts.append(p3)
            self.assert_(b.posts == [p1, p2, p3])
            self.assert_(p2.blog is b)

            p3.blog = None
            self.assert_(b.posts == [p1, p2])
            p4 = Post()
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # Re-assigning the same parent must not duplicate the entry.
            p4.blog = b
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # assert no failure removing None
            p5 = Post()
            p5.blog = None
            del p5.blog

    def test_history(self):
        """Check ``get_state_history`` for scalar and list attributes."""
        for base in (object, MyBaseClass, MyClass):

            class Foo(base):
                pass

            class Bar(base):
                pass

            register_class(Foo)
            register_class(Bar)
            attributes.register_attribute(
                Foo, "name", uselist=False, useobject=False
            )
            attributes.register_attribute(
                Foo, "bars", uselist=True, trackparent=True, useobject=True
            )
            attributes.register_attribute(
                Bar, "name", uselist=False, useobject=False
            )

            f1 = Foo()
            f1.name = "f1"

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "name"
                ),
                (["f1"], (), ()),
            )

            b1 = Bar()
            b1.name = "b1"
            f1.bars.append(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ([b1], [], []),
            )

            attributes.instance_state(f1)._commit_all(
                attributes.instance_dict(f1)
            )
            attributes.instance_state(b1)._commit_all(
                attributes.instance_dict(b1)
            )

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "name"
                ),
                ((), ["f1"], ()),
            )
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ((), [b1], ()),
            )

            f1.name = "f1mod"
            b2 = Bar()
            b2.name = "b2"
            f1.bars.append(b2)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "name"
                ),
                (["f1mod"], (), ["f1"]),
            )
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ([b2], [b1], []),
            )
            f1.bars.remove(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ([b2], [], [b1]),
            )

    def test_null_instrumentation(self):
        """Class attributes equal the manager's entries under MyBaseClass."""

        class Foo(MyBaseClass):
            pass

        register_class(Foo)
        attributes.register_attribute(
            Foo, "name", uselist=False, useobject=False
        )
        attributes.register_attribute(
            Foo, "bars", uselist=True, trackparent=True, useobject=True
        )

        assert Foo.name == attributes.manager_of_class(Foo)["name"]
        assert Foo.bars == attributes.manager_of_class(Foo)["bars"]

    def test_alternate_finders(self):
        """Ensure the generic finder front-end deals with edge cases."""

        class Unknown(object):
            pass

        class Known(MyBaseClass):
            pass

        register_class(Known)
        k, u = Known(), Unknown()

        assert instrumentation.manager_of_class(Unknown) is None
        assert instrumentation.manager_of_class(Known) is not None
        assert instrumentation.manager_of_class(None) is None

        assert attributes.instance_state(k) is not None
        assert_raises((AttributeError, KeyError), attributes.instance_state, u)
        assert_raises(
            (AttributeError, KeyError), attributes.instance_state, None
        )

    def test_unmapped_not_type_error(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            "Class object expected, got '5'.",
            class_mapper,
            5,
        )

    def test_unmapped_not_type_error_iter_ok(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            r"Class object expected, got '\(5, 6\)'.",
            class_mapper,
            (5, 6),
        )


class FinderTest(_ExtBase, fixtures.ORMTest):
    """Check which manager type ``register_class`` ends up using."""

    def test_standard(self):
        class A(object):
            pass

        register_class(A)

        eq_(type(manager_of_class(A)), instrumentation.ClassManager)

    def test_nativeext_interfaceexact(self):
        class A(object):
            __sa_instrumentation_manager__ = (
                instrumentation.InstrumentationManager
            )

        register_class(A)
        ne_(type(manager_of_class(A)), instrumentation.ClassManager)

    def test_nativeext_submanager(self):
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            __sa_instrumentation_manager__ = Mine

        register_class(A)
        eq_(type(manager_of_class(A)), Mine)

    @modifies_instrumentation_finders
    def test_customfinder_greedy(self):
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            pass

        def find(cls):
            return Mine

        # A finder placed first that always answers wins.
        instrumentation.instrumentation_finders.insert(0, find)
        register_class(A)
        eq_(type(manager_of_class(A)), Mine)

    @modifies_instrumentation_finders
    def test_customfinder_pass(self):
        class A(object):
            pass

        def find(cls):
            return None

        # A finder that returns None leaves the stock ClassManager in use.
        instrumentation.instrumentation_finders.insert(0, find)
        register_class(A)

        eq_(type(manager_of_class(A)), instrumentation.ClassManager)


class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
    """Check detection of mixed instrumentation within one hierarchy."""

    def test_none(self):
        """Unrelated classes may each use a different manager source."""

        class A(object):
            pass

        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(object):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        register_class(B)

        class C(object):
            __sa_instrumentation_manager__ = instrumentation.ClassManager

        register_class(C)

    def test_single_down(self):
        class A(object):
            pass

        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B,
        )

    def test_single_up(self):
        class A(object):
            pass

        # delay registration

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        register_class(B)

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            A,
        )

    def test_diamond_b1(self):
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B1,
        )

    def test_diamond_b2(self):
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        register_class(B2)
        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B1,
        )

    def test_diamond_c_b(self):
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        register_class(C)

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B1,
        )


class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):

    """Allow custom Events implementations."""

    @modifies_instrumentation_finders
    def test_subclassed(self):
        class MyEvents(events.InstanceEvents):
            pass

        class MyClassManager(instrumentation.ClassManager):
            dispatch = event.dispatcher(MyEvents)

        instrumentation.instrumentation_finders.insert(
            0, lambda cls: MyClassManager
        )

        class A(object):
            pass

        register_class(A)
        manager = instrumentation.manager_of_class(A)
        assert issubclass(manager.dispatch._events, MyEvents)