1from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, ne_
2from sqlalchemy import util
3import sqlalchemy as sa
4from sqlalchemy.orm import class_mapper
5from sqlalchemy.orm import attributes
6from sqlalchemy.orm.attributes import set_attribute, \
7    get_attribute, del_attribute
8from sqlalchemy.orm.instrumentation import is_instrumented
9from sqlalchemy.orm import clear_mappers
10from sqlalchemy.testing import fixtures
11from sqlalchemy.ext import instrumentation
12from sqlalchemy.orm.instrumentation import register_class, manager_of_class
13from sqlalchemy.testing.util import decorator
14from sqlalchemy.orm import events
15from sqlalchemy import event
16
17
@decorator
def modifies_instrumentation_finders(fn, *args, **kw):
    """Run test ``fn`` and afterwards put ``instrumentation_finders`` back.

    The contents of the module-level list are replaced in place (the list
    object itself is kept), so code holding a reference to that list sees
    the restored finders. Restoration happens even if ``fn`` raises.
    """
    snapshot = list(instrumentation.instrumentation_finders)
    try:
        fn(*args, **kw)
    finally:
        # Slice assignment == "clear, then extend with the snapshot".
        instrumentation.instrumentation_finders[:] = snapshot
26
27
class _ExtBase(object):
    """Mixin for extension tests: restore SQLAlchemy's default
    instrumentation lookups once the whole test class has finished."""

    @classmethod
    def teardown_class(cls):
        # Tests using this mixin may install extended-instrumentation
        # hooks; put the defaults back so later test classes are unaffected.
        reinstall_defaults = instrumentation._reinstall_default_lookups
        reinstall_defaults()
32
33
class MyTypesManager(instrumentation.InstrumentationManager):
    """Instrumentation manager that keeps attribute values in a private
    ``_goofy_dict`` and the instance state under ``_my_state``.

    The attribute/descriptor hooks are deliberately no-ops, so this manager
    never places descriptors on the instrumented class.
    """

    def instrument_attribute(self, class_, key, attr):
        """No-op: nothing is instrumented on the class itself."""

    def install_descriptor(self, class_, key, attr):
        """No-op: no descriptor is installed for ``key``."""

    def uninstall_descriptor(self, class_, key):
        """No-op counterpart of :meth:`install_descriptor`."""

    def instrument_collection_class(self, class_, key, collection_class):
        """Always use :class:`MyListLike`, whatever class was requested."""
        return MyListLike

    def get_instance_dict(self, class_, instance):
        """Return the side dictionary that holds the attribute values."""
        side_dict = instance._goofy_dict
        return side_dict

    def initialize_instance_dict(self, class_, instance):
        """Create an empty side dictionary on ``instance``."""
        # Written straight into __dict__ so no custom __setattr__ runs.
        namespace = instance.__dict__
        namespace['_goofy_dict'] = {}

    def install_state(self, class_, instance, state):
        """Store ``state`` on ``instance`` under ``_my_state``."""
        namespace = instance.__dict__
        namespace['_my_state'] = state

    def state_getter(self, class_):
        """Return a callable that fetches the state from ``install_state``."""
        def _fetch_state(instance):
            return instance.__dict__['_my_state']
        return _fetch_state
59
60
class MyListLike(list):
    """List-based collection that implements SQLAlchemy's ``_sa_*``
    collection hooks by hand instead of through decorators.

    ``_sa_adapter`` is not defined here; it is presumably assigned on each
    instance by SQLAlchemy before any event fires (assumption - confirm).
    """
    # add @appender, @remover decorators as needed
    _sa_iterator = list.__iter__
    _sa_linker = None
    _sa_converter = None

    def _sa_appender(self, item, _sa_initiator=None):
        """Append ``item``; fire the append event unless the initiator
        is ``False`` (the "silent" marker)."""
        silent = _sa_initiator is False
        if not silent:
            self._sa_adapter.fire_append_event(item, _sa_initiator)
        list.append(self, item)

    append = _sa_appender

    def _sa_remover(self, item, _sa_initiator=None):
        """Remove ``item``; the pre-remove hook always fires, the remove
        event only when the initiator is not ``False``."""
        adapter = self._sa_adapter
        adapter.fire_pre_remove_event(_sa_initiator)
        silent = _sa_initiator is False
        if not silent:
            adapter.fire_remove_event(item, _sa_initiator)
        list.remove(self, item)

    remove = _sa_remover
79
80
# Module-level placeholders; UserDefinedExtensionTest.setup_class rebinds
# both names (via ``global``) to its fixture classes.
MyBaseClass = None
MyClass = None
82
83
class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
    """Attribute instrumentation on classes that declare their own
    ``__sa_instrumentation_manager__``.

    Most tests repeat the same scenario for three bases: plain ``object``,
    ``MyBaseClass`` (stock ``InstrumentationManager``) and ``MyClass``
    (``MyTypesManager``, which keeps attribute values in ``_goofy_dict``
    rather than in the instance ``__dict__``).
    """

    @classmethod
    def setup_class(cls):
        # Rebind the module-level placeholders so the test methods can
        # subclass these fixtures.
        global MyBaseClass, MyClass

        class MyBaseClass(object):
            __sa_instrumentation_manager__ = \
                instrumentation.InstrumentationManager

        class MyClass(object):

            # This proves that a staticmethod will work here; don't
            # flatten this back to a class assignment!
            def __sa_instrumentation_manager__(cls):
                return MyTypesManager(cls)

            __sa_instrumentation_manager__ = staticmethod(
                __sa_instrumentation_manager__)

            # This proves SA can handle a class with non-string dict keys
            # (in a CPython class body, locals() is the class namespace).
            if not util.pypy and not util.jython:
                locals()[42] = 99   # Don't remove this line!

            def __init__(self, **kwargs):
                for k in kwargs:
                    setattr(self, k, kwargs[k])

            def __getattr__(self, key):
                # Only called when normal lookup fails: instrumented keys
                # go through the attribute API, anything else is read
                # from the side dictionary.
                if is_instrumented(self, key):
                    return get_attribute(self, key)
                else:
                    try:
                        return self._goofy_dict[key]
                    except KeyError:
                        raise AttributeError(key)

            def __setattr__(self, key, value):
                if is_instrumented(self, key):
                    set_attribute(self, key, value)
                else:
                    self._goofy_dict[key] = value

            # NOTE: Python has no __hasattr__ hook; builtin hasattr() never
            # calls this method, it is only reachable when called directly.
            def __hasattr__(self, key):
                if is_instrumented(self, key):
                    return True
                else:
                    return key in self._goofy_dict

            def __delattr__(self, key):
                if is_instrumented(self, key):
                    del_attribute(self, key)
                else:
                    del self._goofy_dict[key]

    def teardown(self):
        # Runs clear_mappers() after each test.
        clear_mappers()

    def test_instance_dict(self):
        """Instrumented values land in ``_goofy_dict``; the real
        ``__dict__`` holds only that dict and the installed state."""
        class User(MyClass):
            pass

        register_class(User)
        attributes.register_attribute(
            User, 'user_id', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'user_name', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'email_address', uselist=False, useobject=False)

        u = User()
        u.user_id = 7
        u.user_name = 'john'
        u.email_address = 'lala@123.com'
        eq_(
            u.__dict__,
            {
                '_my_state': u._my_state,
                '_goofy_dict': {
                    'user_id': 7, 'user_name': 'john',
                    'email_address': 'lala@123.com'}}
        )

    def test_basic(self):
        """Scalar attributes can be set, committed and set again."""
        for base in (object, MyBaseClass, MyClass):
            class User(base):
                pass

            register_class(User)
            attributes.register_attribute(
                User, 'user_id', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'user_name', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'email_address', uselist=False, useobject=False)

            u = User()
            u.user_id = 7
            u.user_name = 'john'
            u.email_address = 'lala@123.com'

            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")
            attributes.instance_state(u)._commit_all(
                attributes.instance_dict(u))
            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")

            u.user_name = 'heythere'
            u.email_address = 'foo@bar.com'
            eq_(u.user_id, 7)
            eq_(u.user_name, "heythere")
            eq_(u.email_address, "foo@bar.com")

    def test_deferred(self):
        """Expired attributes are reloaded through the manager's
        ``deferred_scalar_loader``."""
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            data = {'a': 'this is a', 'b': 12}

            def loader(state, keys):
                for k in keys:
                    state.dict[k] = data[k]
                return attributes.ATTR_WAS_SET

            manager = register_class(Foo)
            manager.deferred_scalar_loader = loader
            attributes.register_attribute(
                Foo, 'a', uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, 'b', uselist=False, useobject=False)

            # Only classes with a custom manager get a state finder entry.
            if base is object:
                assert Foo not in \
                    instrumentation._instrumentation_factory._state_finders
            else:
                assert Foo in \
                    instrumentation._instrumentation_factory._state_finders

            f = Foo()
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            f.a = "this is some new a"
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            f.a = "this is another new a"
            eq_(f.a, "this is another new a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            del f.a
            eq_(f.a, None)
            eq_(f.b, 12)

            attributes.instance_state(f)._commit_all(
                attributes.instance_dict(f))
            eq_(f.a, None)
            eq_(f.b, 12)

    def test_inheritance(self):
        """Attributes are polymorphic: a subclass registration overrides
        the parent's, unoverridden attributes are inherited."""

        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(Foo):
                pass

            register_class(Foo)
            register_class(Bar)

            def func1(state, passive):
                return "this is the foo attr"

            def func2(state, passive):
                return "this is the bar attr"

            def func3(state, passive):
                return "this is the shared attr"
            attributes.register_attribute(Foo, 'element',
                                          uselist=False, callable_=func1,
                                          useobject=True)
            attributes.register_attribute(Foo, 'element2',
                                          uselist=False, callable_=func3,
                                          useobject=True)
            attributes.register_attribute(Bar, 'element',
                                          uselist=False, callable_=func2,
                                          useobject=True)

            x = Foo()
            y = Bar()
            assert x.element == 'this is the foo attr'
            assert y.element == 'this is the bar attr', y.element
            assert x.element2 == 'this is the shared attr'
            assert y.element2 == 'this is the shared attr'

    def test_collection_with_backref(self):
        """A list collection and its scalar backref stay in sync when
        modified from either side."""
        for base in (object, MyBaseClass, MyClass):
            class Post(base):
                pass

            class Blog(base):
                pass

            register_class(Post)
            register_class(Blog)
            attributes.register_attribute(
                Post, 'blog', uselist=False,
                backref='posts', trackparent=True, useobject=True)
            attributes.register_attribute(
                Blog, 'posts', uselist=True,
                backref='blog', trackparent=True, useobject=True)
            b = Blog()
            (p1, p2, p3) = (Post(), Post(), Post())
            b.posts.append(p1)
            b.posts.append(p2)
            b.posts.append(p3)
            eq_(b.posts, [p1, p2, p3])
            assert p2.blog is b

            p3.blog = None
            eq_(b.posts, [p1, p2])
            p4 = Post()
            p4.blog = b
            eq_(b.posts, [p1, p2, p4])

            # Re-assigning the same parent must not duplicate p4.
            p4.blog = b
            p4.blog = b
            eq_(b.posts, [p1, p2, p4])

            # assert no failure removing None
            p5 = Post()
            p5.blog = None
            del p5.blog

    def test_history(self):
        """History reports added / unchanged / deleted values for scalar
        and collection attributes across a commit."""
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(base):
                pass

            register_class(Foo)
            register_class(Bar)
            attributes.register_attribute(
                Foo, "name", uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, "bars", uselist=True, trackparent=True, useobject=True)
            attributes.register_attribute(
                Bar, "name", uselist=False, useobject=False)

            f1 = Foo()
            f1.name = 'f1'

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1'], (), ()))

            b1 = Bar()
            b1.name = 'b1'
            f1.bars.append(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b1], [], []))

            attributes.instance_state(f1)._commit_all(
                attributes.instance_dict(f1))
            attributes.instance_state(b1)._commit_all(
                attributes.instance_dict(b1))

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'name'),
                ((), ['f1'], ()))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'bars'),
                ((), [b1], ()))

            f1.name = 'f1mod'
            b2 = Bar()
            b2.name = 'b2'
            f1.bars.append(b2)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1mod'], (), ['f1']))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [b1], []))
            f1.bars.remove(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [], [b1]))

    def test_null_instrumentation(self):
        """With the stock InstrumentationManager, class-level attribute
        access returns the manager's registered attribute."""
        class Foo(MyBaseClass):
            pass
        register_class(Foo)
        attributes.register_attribute(
            Foo, "name", uselist=False, useobject=False)
        attributes.register_attribute(
            Foo, "bars", uselist=True, trackparent=True, useobject=True)

        assert Foo.name == attributes.manager_of_class(Foo)['name']
        assert Foo.bars == attributes.manager_of_class(Foo)['bars']

    def test_alternate_finders(self):
        """Ensure the generic finder front-end deals with edge cases."""

        class Unknown(object):
            pass

        class Known(MyBaseClass):
            pass

        register_class(Known)
        k, u = Known(), Unknown()

        assert instrumentation.manager_of_class(Unknown) is None
        assert instrumentation.manager_of_class(Known) is not None
        assert instrumentation.manager_of_class(None) is None

        assert attributes.instance_state(k) is not None
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, u)
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, None)

    def test_unmapped_not_type_error(self):
        """Extension version of the same test in test_mapper.

        Fixes #3408.
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            "Class object expected, got '5'.",
            class_mapper, 5
        )

    def test_unmapped_not_type_error_iter_ok(self):
        """Extension version of the same test in test_mapper, with an
        iterable (tuple) argument.

        Fixes #3408.
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            r"Class object expected, got '\(5, 6\)'.",
            class_mapper, (5, 6)
        )
457
458
class FinderTest(_ExtBase, fixtures.ORMTest):
    """How ``register_class`` chooses the ClassManager type: the default,
    a ``__sa_instrumentation_manager__`` hint, or a custom finder inserted
    at the front of ``instrumentation_finders``."""

    def test_standard(self):
        class A(object):
            pass

        register_class(A)

        # No hint at all: the plain ClassManager is expected.
        manager_type = type(manager_of_class(A))
        eq_(manager_type, instrumentation.ClassManager)

    def test_nativeext_interfaceexact(self):
        class A(object):
            __sa_instrumentation_manager__ = \
                instrumentation.InstrumentationManager

        register_class(A)

        # An InstrumentationManager hint must not produce a plain
        # ClassManager.
        manager_type = type(manager_of_class(A))
        ne_(manager_type, instrumentation.ClassManager)

    def test_nativeext_submanager(self):
        class CustomManager(instrumentation.ClassManager):
            pass

        class A(object):
            __sa_instrumentation_manager__ = CustomManager

        register_class(A)

        manager_type = type(manager_of_class(A))
        eq_(manager_type, CustomManager)

    @modifies_instrumentation_finders
    def test_customfinder_greedy(self):
        class CustomManager(instrumentation.ClassManager):
            pass

        class A(object):
            pass

        def always_custom(cls):
            # Returns CustomManager for every class it is asked about.
            return CustomManager

        instrumentation.instrumentation_finders.insert(0, always_custom)
        register_class(A)

        manager_type = type(manager_of_class(A))
        eq_(manager_type, CustomManager)

    @modifies_instrumentation_finders
    def test_customfinder_pass(self):
        class A(object):
            pass

        def never_claims(cls):
            # Returns None for every class; the default manager is still
            # expected to be chosen.
            return None

        instrumentation.instrumentation_finders.insert(0, never_claims)
        register_class(A)

        manager_type = type(manager_of_class(A))
        eq_(manager_type, instrumentation.ClassManager)
520
521
class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
    """``register_class`` must reject a class hierarchy in which more than
    one instrumentation implementation is in play, while unrelated classes
    may each use their own."""

    def test_none(self):
        # Three unrelated classes, each with a different (or no) factory:
        # no collision is expected.
        class A(object):
            pass
        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(object):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
        register_class(B)

        class C(object):
            __sa_instrumentation_manager__ = instrumentation.ClassManager
        register_class(C)

    def test_single_down(self):
        # The base is registered first; a subclass asking for a different
        # factory is then rejected.
        class A(object):
            pass
        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B)

    def test_single_up(self):
        # The subclass with a custom factory is registered first;
        # registering its hint-less base afterwards is rejected.

        class A(object):
            pass
        # delay registration

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
        register_class(B)

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, A)

    def test_diamond_b1(self):
        # B1 has no hint of its own; the expected TypeError presumably
        # comes from sibling B2's factory, reached via the shared base A.
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        # C does not derive from A, B1 or B2.
        class C(object):
            pass

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B1)

    def test_diamond_b2(self):
        # Same hierarchy as test_diamond_b1, but B2 (custom factory) is
        # registered before B1 is attempted.
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        # C does not derive from A, B1 or B2.
        class C(object):
            pass

        register_class(B2)
        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B1)

    def test_diamond_c_b(self):
        # Registering the unrelated class C first does not prevent the
        # B1 / B2 conflict from being detected.
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        # C does not derive from A, B1 or B2.
        class C(object):
            pass

        register_class(C)

        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, B1)
628
629
class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):

    """Allow custom Events implementations: a ClassManager subclass may
    carry a dispatcher built from its own InstanceEvents subclass."""

    @modifies_instrumentation_finders
    def test_subclassed(self):
        class MyEvents(events.InstanceEvents):
            pass

        class MyClassManager(instrumentation.ClassManager):
            dispatch = event.dispatcher(MyEvents)

        def use_my_manager(cls):
            # Returns MyClassManager for every class.
            return MyClassManager

        instrumentation.instrumentation_finders.insert(0, use_my_manager)

        class A(object):
            pass

        register_class(A)
        manager = instrumentation.manager_of_class(A)
        dispatch_events = manager.dispatch._events
        assert issubclass(dispatch_events, MyEvents)
651