1import sqlalchemy as sa
2from sqlalchemy import event
3from sqlalchemy import util
4from sqlalchemy.ext import instrumentation
5from sqlalchemy.orm import attributes
6from sqlalchemy.orm import class_mapper
7from sqlalchemy.orm import clear_mappers
8from sqlalchemy.orm import events
9from sqlalchemy.orm.attributes import del_attribute
10from sqlalchemy.orm.attributes import get_attribute
11from sqlalchemy.orm.attributes import set_attribute
12from sqlalchemy.orm.instrumentation import is_instrumented
13from sqlalchemy.orm.instrumentation import manager_of_class
14from sqlalchemy.orm.instrumentation import register_class
15from sqlalchemy.testing import assert_raises
16from sqlalchemy.testing import assert_raises_message
17from sqlalchemy.testing import eq_
18from sqlalchemy.testing import fixtures
19from sqlalchemy.testing import ne_
20from sqlalchemy.testing.util import decorator
21
22
@decorator
def modifies_instrumentation_finders(fn, *args, **kw):
    """Run the wrapped test, then restore ``instrumentation_finders``.

    The contents of ``instrumentation.instrumentation_finders`` are copied
    before *fn* runs and written back afterwards, whether or not *fn*
    raised.  The list is updated in place (slice assignment), so any other
    reference to the same list object sees the restored contents.  The
    return value of *fn* is discarded.
    """
    snapshot = list(instrumentation.instrumentation_finders)
    try:
        fn(*args, **kw)
    finally:
        # Replace the whole contents in one step; equivalent to clearing
        # the list and re-extending it with the snapshot.
        instrumentation.instrumentation_finders[:] = snapshot
31
32
class _ExtBase(object):
    """Mixin for test classes in this module that alter instrumentation.

    Once every test in the class has run, SQLAlchemy's default
    instrumentation lookups are reinstalled via
    ``instrumentation._reinstall_default_lookups``.
    """

    @classmethod
    def teardown_class(cls):
        reinstall_defaults = instrumentation._reinstall_default_lookups
        reinstall_defaults()
37
38
class MyTypesManager(instrumentation.InstrumentationManager):
    """Instrumentation manager that keeps attribute values out of ``__dict__``.

    Attribute values live in a per-instance ``_goofy_dict``.  The
    SQLAlchemy state object is stored under ``_my_state`` in the instance's
    real ``__dict__``.  Descriptor installation is a no-op, and every
    instrumented collection is built as ``MyListLike``.
    """

    # No class-level descriptors are installed; the class using this
    # manager must route attribute access itself.
    def instrument_attribute(self, class_, key, attr):
        pass

    def install_descriptor(self, class_, key, attr):
        pass

    def uninstall_descriptor(self, class_, key):
        pass

    def instrument_collection_class(self, class_, key, collection_class):
        # The requested collection_class is ignored.
        return MyListLike

    def get_instance_dict(self, class_, instance):
        return instance._goofy_dict

    def initialize_instance_dict(self, class_, instance):
        vars(instance)["_goofy_dict"] = {}

    def install_state(self, class_, instance, state):
        vars(instance)["_my_state"] = state

    def state_getter(self, class_):
        def _fetch_state(instance):
            return instance.__dict__["_my_state"]

        return _fetch_state
63
64
class MyListLike(list):
    """A ``list`` subclass that exposes SQLAlchemy collection hooks by name.

    The ``_sa_*`` attributes are set directly; ``append`` and ``remove``
    are aliases of the hook methods, so ordinary list mutation fires the
    adapter events too.  An ``_sa_initiator`` of exactly ``False``
    suppresses the append/remove events.
    """

    # add @appender, @remover decorators as needed
    _sa_iterator = list.__iter__
    _sa_linker = None
    _sa_converter = None

    def _sa_appender(self, item, _sa_initiator=None):
        # The event fires before the item is actually stored.
        should_fire = _sa_initiator is not False
        if should_fire:
            self._sa_adapter.fire_append_event(item, _sa_initiator)
        list.append(self, item)

    def _sa_remover(self, item, _sa_initiator=None):
        adapter = self._sa_adapter
        # The pre-remove hook runs even when the initiator is False.
        adapter.fire_pre_remove_event(_sa_initiator)
        if _sa_initiator is not False:
            adapter.fire_remove_event(item, _sa_initiator)
        list.remove(self, item)

    append = _sa_appender
    remove = _sa_remover
85
86
# Module-level placeholders; UserDefinedExtensionTest.setup_class rebinds
# both names (through a ``global`` statement) to the real fixture classes.
MyBaseClass = None
MyClass = None
88
89
class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
    """Exercise attribute instrumentation through user-defined managers.

    Most tests run the same scenario three times, against:

    * a plain ``object`` subclass (default instrumentation),
    * ``MyBaseClass``, which uses the stock ``InstrumentationManager``,
    * ``MyClass``, which uses ``MyTypesManager`` and keeps attribute values
      in ``_goofy_dict``.
    """

    @classmethod
    def setup_class(cls):
        """Define the ``MyBaseClass`` and ``MyClass`` fixtures.

        Both are published through ``global``, which rebinds the
        module-level placeholders used by the tests below.
        """
        global MyBaseClass, MyClass

        class MyBaseClass(object):
            __sa_instrumentation_manager__ = (
                instrumentation.InstrumentationManager
            )

        class MyClass(object):

            # This proves that a staticmethod will work here; don't
            # flatten this back to a class assignment!
            def __sa_instrumentation_manager__(cls):
                return MyTypesManager(cls)

            __sa_instrumentation_manager__ = staticmethod(
                __sa_instrumentation_manager__
            )

            # This proves SA can handle a class with non-string dict keys
            if not util.pypy and not util.jython:
                locals()[42] = 99  # Don't remove this line!

            def __init__(self, **kwargs):
                for k in kwargs:
                    setattr(self, k, kwargs[k])

            def __getattr__(self, key):
                # Reached only when normal lookup fails: instrumented keys
                # go through SQLAlchemy, everything else through
                # _goofy_dict.
                if is_instrumented(self, key):
                    return get_attribute(self, key)
                else:
                    try:
                        return self._goofy_dict[key]
                    except KeyError:
                        raise AttributeError(key)

            def __setattr__(self, key, value):
                # Values never land in __dict__ directly; they go either to
                # SQLAlchemy or to _goofy_dict.
                if is_instrumented(self, key):
                    set_attribute(self, key, value)
                else:
                    self._goofy_dict[key] = value

            # NOTE: Python has no ``__hasattr__`` hook (``hasattr()`` goes
            # through ``__getattr__``), so this is never invoked implicitly;
            # it is kept as part of the fixture's surface.
            def __hasattr__(self, key):
                if is_instrumented(self, key):
                    return True
                else:
                    return key in self._goofy_dict

            def __delattr__(self, key):
                if is_instrumented(self, key):
                    del_attribute(self, key)
                else:
                    del self._goofy_dict[key]

    def teardown(self):
        """Call ``clear_mappers()`` after each test."""
        clear_mappers()

    def test_instance_dict(self):
        """MyTypesManager keeps only _my_state and _goofy_dict in __dict__."""

        class User(MyClass):
            pass

        register_class(User)
        attributes.register_attribute(
            User, "user_id", uselist=False, useobject=False
        )
        attributes.register_attribute(
            User, "user_name", uselist=False, useobject=False
        )
        attributes.register_attribute(
            User, "email_address", uselist=False, useobject=False
        )

        u = User()
        u.user_id = 7
        u.user_name = "john"
        u.email_address = "lala@123.com"
        eq_(
            u.__dict__,
            {
                "_my_state": u._my_state,
                "_goofy_dict": {
                    "user_id": 7,
                    "user_name": "john",
                    "email_address": "lala@123.com",
                },
            },
        )

    def test_basic(self):
        """Scalar attributes survive set, commit and re-set for every base."""
        for base in (object, MyBaseClass, MyClass):

            class User(base):
                pass

            register_class(User)
            attributes.register_attribute(
                User, "user_id", uselist=False, useobject=False
            )
            attributes.register_attribute(
                User, "user_name", uselist=False, useobject=False
            )
            attributes.register_attribute(
                User, "email_address", uselist=False, useobject=False
            )

            u = User()
            u.user_id = 7
            u.user_name = "john"
            u.email_address = "lala@123.com"

            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")
            attributes.instance_state(u)._commit_all(
                attributes.instance_dict(u)
            )
            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")

            u.user_name = "heythere"
            u.email_address = "foo@bar.com"
            eq_(u.user_id, 7)
            eq_(u.user_name, "heythere")
            eq_(u.email_address, "foo@bar.com")

    def test_deferred(self):
        """Expired attributes are reloaded via deferred_scalar_loader."""
        for base in (object, MyBaseClass, MyClass):

            class Foo(base):
                pass

            data = {"a": "this is a", "b": 12}

            def loader(state, keys):
                for k in keys:
                    state.dict[k] = data[k]
                return attributes.ATTR_WAS_SET

            manager = register_class(Foo)
            manager.deferred_scalar_loader = loader
            attributes.register_attribute(
                Foo, "a", uselist=False, useobject=False
            )
            attributes.register_attribute(
                Foo, "b", uselist=False, useobject=False
            )

            # Only classes with a custom manager get a dedicated state
            # finder.
            if base is object:
                assert Foo not in (
                    instrumentation._instrumentation_factory._state_finders
                )
            else:
                assert Foo in (
                    instrumentation._instrumentation_factory._state_finders
                )

            f = Foo()
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            f.a = "this is some new a"
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            f.a = "this is another new a"
            eq_(f.a, "this is another new a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            del f.a
            eq_(f.a, None)
            eq_(f.b, 12)

            attributes.instance_state(f)._commit_all(
                attributes.instance_dict(f)
            )
            eq_(f.a, None)
            eq_(f.b, 12)

    def test_inheritance(self):
        """tests that attributes are polymorphic"""

        for base in (object, MyBaseClass, MyClass):

            class Foo(base):
                pass

            class Bar(Foo):
                pass

            register_class(Foo)
            register_class(Bar)

            def func1(state, passive):
                return "this is the foo attr"

            def func2(state, passive):
                return "this is the bar attr"

            def func3(state, passive):
                return "this is the shared attr"

            attributes.register_attribute(
                Foo, "element", uselist=False, callable_=func1, useobject=True
            )
            attributes.register_attribute(
                Foo, "element2", uselist=False, callable_=func3, useobject=True
            )
            attributes.register_attribute(
                Bar, "element", uselist=False, callable_=func2, useobject=True
            )

            x = Foo()
            y = Bar()
            # eq_ reports both values on failure, unlike a bare assert.
            eq_(x.element, "this is the foo attr")
            eq_(y.element, "this is the bar attr")
            eq_(x.element2, "this is the shared attr")
            eq_(y.element2, "this is the shared attr")

    def test_collection_with_backref(self):
        """Backref-linked scalar and collection attributes stay in sync."""
        for base in (object, MyBaseClass, MyClass):

            class Post(base):
                pass

            class Blog(base):
                pass

            register_class(Post)
            register_class(Blog)
            attributes.register_attribute(
                Post,
                "blog",
                uselist=False,
                backref="posts",
                trackparent=True,
                useobject=True,
            )
            attributes.register_attribute(
                Blog,
                "posts",
                uselist=True,
                backref="blog",
                trackparent=True,
                useobject=True,
            )
            b = Blog()
            (p1, p2, p3) = (Post(), Post(), Post())
            b.posts.append(p1)
            b.posts.append(p2)
            b.posts.append(p3)
            eq_(b.posts, [p1, p2, p3])
            assert p2.blog is b

            p3.blog = None
            eq_(b.posts, [p1, p2])
            p4 = Post()
            p4.blog = b
            eq_(b.posts, [p1, p2, p4])

            # Re-assigning the same parent must not duplicate p4.
            p4.blog = b
            p4.blog = b
            eq_(b.posts, [p1, p2, p4])

            # assert no failure removing None
            p5 = Post()
            p5.blog = None
            del p5.blog

    def test_history(self):
        """get_state_history tracks added/unchanged/deleted across commits."""
        for base in (object, MyBaseClass, MyClass):

            class Foo(base):
                pass

            class Bar(base):
                pass

            register_class(Foo)
            register_class(Bar)
            attributes.register_attribute(
                Foo, "name", uselist=False, useobject=False
            )
            attributes.register_attribute(
                Foo, "bars", uselist=True, trackparent=True, useobject=True
            )
            attributes.register_attribute(
                Bar, "name", uselist=False, useobject=False
            )

            f1 = Foo()
            f1.name = "f1"

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "name"
                ),
                (["f1"], (), ()),
            )

            b1 = Bar()
            b1.name = "b1"
            f1.bars.append(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ([b1], [], []),
            )

            attributes.instance_state(f1)._commit_all(
                attributes.instance_dict(f1)
            )
            attributes.instance_state(b1)._commit_all(
                attributes.instance_dict(b1)
            )

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "name"
                ),
                ((), ["f1"], ()),
            )
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ((), [b1], ()),
            )

            f1.name = "f1mod"
            b2 = Bar()
            b2.name = "b2"
            f1.bars.append(b2)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "name"
                ),
                (["f1mod"], (), ["f1"]),
            )
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ([b2], [b1], []),
            )
            f1.bars.remove(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ([b2], [], [b1]),
            )

    def test_null_instrumentation(self):
        """Class-level access returns the manager's attribute objects."""

        class Foo(MyBaseClass):
            pass

        register_class(Foo)
        attributes.register_attribute(
            Foo, "name", uselist=False, useobject=False
        )
        attributes.register_attribute(
            Foo, "bars", uselist=True, trackparent=True, useobject=True
        )

        assert Foo.name == attributes.manager_of_class(Foo)["name"]
        assert Foo.bars == attributes.manager_of_class(Foo)["bars"]

    def test_alternate_finders(self):
        """Ensure the generic finder front-end deals with edge cases."""

        class Unknown(object):
            pass

        class Known(MyBaseClass):
            pass

        register_class(Known)
        k, u = Known(), Unknown()

        assert instrumentation.manager_of_class(Unknown) is None
        assert instrumentation.manager_of_class(Known) is not None
        assert instrumentation.manager_of_class(None) is None

        assert attributes.instance_state(k) is not None
        assert_raises((AttributeError, KeyError), attributes.instance_state, u)
        assert_raises(
            (AttributeError, KeyError), attributes.instance_state, None
        )

    def test_unmapped_not_type_error(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            "Class object expected, got '5'.",
            class_mapper,
            5,
        )

    def test_unmapped_not_type_error_iter_ok(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            r"Class object expected, got '\(5, 6\)'.",
            class_mapper,
            (5, 6),
        )
521
522
class FinderTest(_ExtBase, fixtures.ORMTest):
    """Check which manager type ``register_class`` picks for a class."""

    def test_standard(self):
        class A(object):
            pass

        register_class(A)

        chosen = type(manager_of_class(A))
        eq_(chosen, instrumentation.ClassManager)

    def test_nativeext_interfaceexact(self):
        class A(object):
            __sa_instrumentation_manager__ = (
                instrumentation.InstrumentationManager
            )

        register_class(A)
        # Declaring the stock InstrumentationManager must not yield a plain
        # ClassManager.
        ne_(type(manager_of_class(A)), instrumentation.ClassManager)

    def test_nativeext_submanager(self):
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            __sa_instrumentation_manager__ = Mine

        register_class(A)
        chosen = type(manager_of_class(A))
        eq_(chosen, Mine)

    @modifies_instrumentation_finders
    def test_customfinder_greedy(self):
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            pass

        # A finder placed first that claims every class.
        instrumentation.instrumentation_finders.insert(0, lambda cls: Mine)
        register_class(A)
        eq_(type(manager_of_class(A)), Mine)

    @modifies_instrumentation_finders
    def test_customfinder_pass(self):
        class A(object):
            pass

        # A finder placed first that declines every class.
        instrumentation.instrumentation_finders.insert(0, lambda cls: None)
        register_class(A)

        eq_(type(manager_of_class(A)), instrumentation.ClassManager)
578
579
class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
    """Mixing instrumentation factories within one hierarchy must fail.

    Each test defines its own ``mgr_factory`` function that builds a plain
    ``ClassManager``.  NOTE(review): presumably the wrapping function is
    what makes it count as a factory distinct from the default one;
    confirm against the conflict check in ``sqlalchemy.ext.instrumentation``.
    The tests also depend on the exact order in which classes are
    registered.
    """

    def test_none(self):
        """Unrelated classes may each use a different factory."""
        class A(object):
            pass

        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(object):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        register_class(B)

        class C(object):
            __sa_instrumentation_manager__ = instrumentation.ClassManager

        register_class(C)

    def test_single_down(self):
        """A subclass with its own factory under a registered base raises."""
        class A(object):
            pass

        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B,
        )

    def test_single_up(self):
        """Registering the subclass first, then the base, also raises."""
        class A(object):
            pass

        # delay registration: A is only registered after B, below

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        register_class(B)

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            A,
        )

    def test_diamond_b1(self):
        """B1 has no factory of its own, yet registering it raises.

        The only custom factory in the hierarchy is on B1's sibling B2,
        reachable through their shared base A.
        """
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        # C subclasses only object and is not used by this test.
        class C(object):
            pass

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B1,
        )

    def test_diamond_b2(self):
        """Registering B2 first, then its sibling B1, also raises."""
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        # C subclasses only object and is not used by this test.
        class C(object):
            pass

        register_class(B2)
        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B1,
        )

    def test_diamond_c_b(self):
        """Registering the unrelated class C first does not prevent the
        conflict when B1 is registered."""
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        # C subclasses only object; it is outside the A/B1/B2 hierarchy.
        class C(object):
            pass

        register_class(C)

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B1,
        )
711
712
class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):

    """A ClassManager subclass may plug in its own Events implementation."""

    @modifies_instrumentation_finders
    def test_subclassed(self):
        class MyEvents(events.InstanceEvents):
            pass

        class MyClassManager(instrumentation.ClassManager):
            dispatch = event.dispatcher(MyEvents)

        def always_my_manager(cls):
            # Claims every class for MyClassManager.
            return MyClassManager

        instrumentation.instrumentation_finders.insert(0, always_my_manager)

        class A(object):
            pass

        register_class(A)
        dispatch = instrumentation.manager_of_class(A).dispatch
        assert issubclass(dispatch._events, MyEvents)
735