1import sqlalchemy as sa
2from sqlalchemy import event
3from sqlalchemy import ForeignKey
4from sqlalchemy import Integer
5from sqlalchemy import MetaData
6from sqlalchemy import util
7from sqlalchemy.orm import attributes
8from sqlalchemy.orm import class_mapper
9from sqlalchemy.orm import create_session
10from sqlalchemy.orm import instrumentation
11from sqlalchemy.orm import mapper
12from sqlalchemy.orm import relationship
13from sqlalchemy.testing import assert_raises
14from sqlalchemy.testing import assert_raises_message
15from sqlalchemy.testing import eq_
16from sqlalchemy.testing import fixtures
17from sqlalchemy.testing import ne_
18from sqlalchemy.testing.schema import Column
19from sqlalchemy.testing.schema import Table
20
21
class InitTest(fixtures.ORMTest):
    """Ordering of instrumentation "init" events relative to ``__init__``.

    Test names encode the hierarchy under test, one token per class, base
    first (A is the base, B subclasses A, C subclasses B):

    * uppercase letter -- the class is passed to ``self.register()``;
    * lowercase letter -- the class is never registered;
    * trailing ``i`` -- the class defines its own ``__init__``.

    Each test records events into a shared list and asserts the exact
    sequence, so where each ``register()`` call sits relative to the
    subclass definitions is part of what is being tested.
    """

    def fixture(self):
        """Return a fresh table ``t`` with id/type/x/y integer columns."""
        # NOTE(review): no test in this class calls this method.
        return Table(
            "t",
            MetaData(),
            Column("id", Integer, primary_key=True),
            Column("type", Integer),
            Column("x", Integer),
            Column("y", Integer),
        )

    def register(self, cls, canary):
        """Instrument ``cls`` and log its manager's "init" events.

        Asserts that registration replaced ``cls.__init__``, then attaches
        a raw "init" listener which appends ``(cls, "init", state.class_)``
        to ``canary`` -- the registered class, followed by the class of the
        instance actually being constructed.
        """
        original_init = cls.__init__
        instrumentation.register_class(cls)
        # registration must install a different (instrumented) __init__
        ne_(cls.__init__, original_init)
        manager = instrumentation.manager_of_class(cls)

        def init(state, args, kwargs):
            canary.append((cls, "init", state.class_))

        # raw=True hands the listener the instance's state object rather
        # than the instance itself, hence ``state.class_`` above.
        event.listen(manager, "init", init, raw=True)

    def test_ai(self):
        """Baseline: an unregistered class just runs its own __init__."""
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        A()
        eq_(inits, [(A, "__init__")])

    def test_A(self):
        """A registered class with no __init__ of its own still fires init."""
        inits = []

        class A(object):
            pass

        self.register(A, inits)

        A()
        eq_(inits, [(A, "init", A)])

    def test_Ai(self):
        """The "init" event fires before the class's own __init__ body."""
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        self.register(A, inits)

        A()
        eq_(inits, [(A, "init", A), (A, "__init__")])

    def test_ai_B(self):
        """Registering only the subclass leaves the base class untouched."""
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        class B(A):
            pass

        self.register(B, inits)

        A()
        eq_(inits, [(A, "__init__")])

        del inits[:]

        # B has no __init__ of its own, so A.__init__ runs after B's event
        B()
        eq_(inits, [(B, "init", B), (A, "__init__")])

    def test_ai_Bi(self):
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        class B(A):
            def __init__(self):
                inits.append((B, "__init__"))
                super(B, self).__init__()

        self.register(B, inits)

        A()
        eq_(inits, [(A, "__init__")])

        del inits[:]

        B()
        eq_(inits, [(B, "init", B), (B, "__init__"), (A, "__init__")])

    def test_Ai_bi(self):
        """An unregistered subclass reaches A's event only through super().

        B is defined after A is registered but is not registered itself, so
        B.__init__ runs first and A's "init" event fires when it chains up,
        reporting B as the class of the instance being built.
        """
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        self.register(A, inits)

        class B(A):
            def __init__(self):
                inits.append((B, "__init__"))
                super(B, self).__init__()

        A()
        eq_(inits, [(A, "init", A), (A, "__init__")])

        del inits[:]

        B()
        eq_(inits, [(B, "__init__"), (A, "init", B), (A, "__init__")])

    def test_Ai_Bi(self):
        """With both classes registered, B() fires only B's "init" event.

        Chaining from B.__init__ into A's instrumented __init__ does not
        fire A's event a second time.
        """
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        self.register(A, inits)

        class B(A):
            def __init__(self):
                inits.append((B, "__init__"))
                super(B, self).__init__()

        self.register(B, inits)

        A()
        eq_(inits, [(A, "init", A), (A, "__init__")])

        del inits[:]

        B()
        eq_(inits, [(B, "init", B), (B, "__init__"), (A, "__init__")])

    def test_Ai_B(self):
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        self.register(A, inits)

        class B(A):
            pass

        self.register(B, inits)

        A()
        eq_(inits, [(A, "init", A), (A, "__init__")])

        del inits[:]

        B()
        eq_(inits, [(B, "init", B), (A, "__init__")])

    def test_Ai_Bi_Ci(self):
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        self.register(A, inits)

        class B(A):
            def __init__(self):
                inits.append((B, "__init__"))
                super(B, self).__init__()

        self.register(B, inits)

        class C(B):
            def __init__(self):
                inits.append((C, "__init__"))
                super(C, self).__init__()

        self.register(C, inits)

        A()
        eq_(inits, [(A, "init", A), (A, "__init__")])

        del inits[:]

        B()
        eq_(inits, [(B, "init", B), (B, "__init__"), (A, "__init__")])

        del inits[:]
        # only the most-derived registered class fires an "init" event
        C()
        eq_(
            inits,
            [
                (C, "init", C),
                (C, "__init__"),
                (B, "__init__"),
                (A, "__init__"),
            ],
        )

    def test_Ai_bi_Ci(self):
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        self.register(A, inits)

        class B(A):
            def __init__(self):
                inits.append((B, "__init__"))
                super(B, self).__init__()

        class C(B):
            def __init__(self):
                inits.append((C, "__init__"))
                super(C, self).__init__()

        self.register(C, inits)

        A()
        eq_(inits, [(A, "init", A), (A, "__init__")])

        del inits[:]

        # unregistered B: A's event fires only when B.__init__ chains up
        B()
        eq_(inits, [(B, "__init__"), (A, "init", B), (A, "__init__")])

        del inits[:]
        # registered C: its own event fires first; A's does not fire again
        C()
        eq_(
            inits,
            [
                (C, "init", C),
                (C, "__init__"),
                (B, "__init__"),
                (A, "__init__"),
            ],
        )

    def test_Ai_b_Ci(self):
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        self.register(A, inits)

        class B(A):
            pass

        class C(B):
            def __init__(self):
                inits.append((C, "__init__"))
                super(C, self).__init__()

        self.register(C, inits)

        A()
        eq_(inits, [(A, "init", A), (A, "__init__")])

        del inits[:]

        # B inherits A's instrumented __init__, so A's event reports B
        B()
        eq_(inits, [(A, "init", B), (A, "__init__")])

        del inits[:]
        C()
        eq_(inits, [(C, "init", C), (C, "__init__"), (A, "__init__")])

    def test_Ai_B_Ci(self):
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        self.register(A, inits)

        class B(A):
            pass

        self.register(B, inits)

        class C(B):
            def __init__(self):
                inits.append((C, "__init__"))
                super(C, self).__init__()

        self.register(C, inits)

        A()
        eq_(inits, [(A, "init", A), (A, "__init__")])

        del inits[:]

        B()
        eq_(inits, [(B, "init", B), (A, "__init__")])

        del inits[:]
        C()
        eq_(inits, [(C, "init", C), (C, "__init__"), (A, "__init__")])

    def test_Ai_B_C(self):
        inits = []

        class A(object):
            def __init__(self):
                inits.append((A, "__init__"))

        self.register(A, inits)

        class B(A):
            pass

        self.register(B, inits)

        class C(B):
            pass

        self.register(C, inits)

        A()
        eq_(inits, [(A, "init", A), (A, "__init__")])

        del inits[:]

        B()
        eq_(inits, [(B, "init", B), (A, "__init__")])

        del inits[:]
        C()
        eq_(inits, [(C, "init", C), (A, "__init__")])

    def test_A_Bi_C(self):
        inits = []

        class A(object):
            pass

        self.register(A, inits)

        class B(A):
            def __init__(self):
                inits.append((B, "__init__"))

        self.register(B, inits)

        class C(B):
            pass

        self.register(C, inits)

        A()
        eq_(inits, [(A, "init", A)])

        del inits[:]

        B()
        eq_(inits, [(B, "init", B), (B, "__init__")])

        del inits[:]
        # C inherits B.__init__, which is the only user __init__ recorded
        C()
        eq_(inits, [(C, "init", C), (B, "__init__")])

    def test_A_B_Ci(self):
        inits = []

        class A(object):
            pass

        self.register(A, inits)

        class B(A):
            pass

        self.register(B, inits)

        class C(B):
            def __init__(self):
                inits.append((C, "__init__"))

        self.register(C, inits)

        A()
        eq_(inits, [(A, "init", A)])

        del inits[:]

        B()
        eq_(inits, [(B, "init", B)])

        del inits[:]
        C()
        eq_(inits, [(C, "init", C), (C, "__init__")])

    def test_A_B_C(self):
        inits = []

        class A(object):
            pass

        self.register(A, inits)

        class B(A):
            pass

        self.register(B, inits)

        class C(B):
            pass

        self.register(C, inits)

        A()
        eq_(inits, [(A, "init", A)])

        del inits[:]

        B()
        eq_(inits, [(B, "init", B)])

        del inits[:]
        C()
        eq_(inits, [(C, "init", C)])

    def test_defaulted_init(self):
        """The instrumented __init__ keeps the original default arguments.

        The ``Y`` case checks that defaults are passed through by identity,
        including objects whose ``repr`` would evaluate to something else.
        """
        class X(object):
            def __init__(self_, a, b=123, c="abc"):
                self_.a = a
                self_.b = b
                self_.c = c

        instrumentation.register_class(X)

        o = X("foo")
        eq_(o.a, "foo")
        eq_(o.b, 123)
        eq_(o.c, "abc")

        class Y(object):
            unique = object()

            class OutOfScopeForEval(object):
                def __repr__(self_):
                    # misleading repr
                    return "123"

            outofscope = OutOfScopeForEval()

            def __init__(self_, u=unique, o=outofscope):
                self_.u = u
                self_.o = o

        instrumentation.register_class(Y)

        o = Y()
        # identity checks: the defaults must be the very same objects
        assert o.u is Y.unique
        assert o.o is Y.outofscope
491
492
class MapperInitTest(fixtures.ORMTest):
    """Mapper-level checks on how mapped classes are constructed."""

    def fixture(self):
        """Build a fresh table ``t`` with id/type/x/y integer columns."""
        columns = [
            Column("id", Integer, primary_key=True),
            Column("type", Integer),
            Column("x", Integer),
            Column("y", Integer),
        ]
        return Table("t", MetaData(), *columns)

    def test_partially_mapped_inheritance(self):
        class A(object):
            pass

        class B(A):
            pass

        class C(B):
            def __init__(self, x):
                pass

        mapper(A, self.fixture())

        # Mapping the base does not give its subclasses mappers of their own
        # in the current implementation: both B and C remain unmapped.
        for unmapped_cls in (B, C):
            assert_raises(
                sa.orm.exc.UnmappedClassError, class_mapper, unmapped_cls
            )

    def test_del_warning(self):
        class A(object):
            def __del__(self):
                pass

        # mapping a class that defines __del__ must emit an SAWarning
        expected_message = (
            r"__del__\(\) method on class "
            r"<class '.*\.A'> will cause "
            r"unreachable cycles and memory leaks, as SQLAlchemy "
            r"instrumentation often creates reference cycles.  "
            r"Please remove this method."
        )
        assert_raises_message(
            sa.exc.SAWarning, expected_message, mapper, A, self.fixture()
        )
539
540
class OnLoadTest(fixtures.ORMTest):
    """Check that the "load" event is not fired by regular attribute
    operations, such as constructing and pickling/unpickling an instance."""

    def test_basic(self):
        """Construct, pickle and unpickle an instance without firing "load"."""
        import pickle

        # pickle locates classes by module-level name, so A must be a
        # module global while the test runs; it is removed in ``finally``.
        global A

        class A(object):
            pass

        def canary(*arg):
            # Accept any signature: the "load" event passes more than one
            # argument, and a fixed one-argument listener would fail with a
            # confusing TypeError instead of this message.  Raised explicitly
            # (not ``assert False``) so the check survives ``python -O``.
            raise AssertionError("the 'load' event should not be fired")

        try:
            instrumentation.register_class(A)
            manager = instrumentation.manager_of_class(A)
            event.listen(manager, "load", canary)

            a = A()
            p_a = pickle.dumps(a)
            pickle.loads(p_a)
        finally:
            del A
565
566
class NativeInstrumentationTest(fixtures.ORMTest):
    """ClassManager refuses to overwrite its own reserved attribute names."""

    def test_register_reserved_attribute(self):
        class T(object):
            pass

        instrumentation.register_class(T)
        manager = instrumentation.manager_of_class(T)

        # Deliberately not named ``sa``: that would shadow the module-level
        # ``sqlalchemy as sa`` import.
        state_attr = instrumentation.ClassManager.STATE_ATTR
        manager_attr = instrumentation.ClassManager.MANAGER_ATTR

        # Installing anything under either reserved name, through either
        # installation method, must raise KeyError.
        for method_name in ("install_member", "install_descriptor"):
            for attr in (state_attr, manager_attr):
                assert_raises(
                    KeyError, getattr(manager, method_name), attr, property()
                )

    def _assert_column_name_rejected(self, colname):
        """Mapping a table that has a column named ``colname`` raises KeyError.

        :param colname: the column name to try, e.g. one of the
            ``ClassManager`` reserved attribute names.
        """
        t = Table(
            "t",
            MetaData(),
            Column("id", Integer, primary_key=True),
            Column(colname, Integer),
        )

        class T(object):
            pass

        assert_raises(KeyError, mapper, T, t)

    def test_mapped_stateattr(self):
        self._assert_column_name_rejected(
            instrumentation.ClassManager.STATE_ATTR
        )

    def test_mapped_managerattr(self):
        self._assert_column_name_rejected(
            instrumentation.ClassManager.MANAGER_ATTR
        )
613
614
class Py3KFunctionInstTest(fixtures.ORMTest):
    """Instrumented ``__init__`` with Python 3 keyword-only arguments.

    The ``_kw_only_fixture``, ``_kw_plus_posn_fixture`` and
    ``_kw_opt_fixture`` methods are attached to this class after its
    definition, from source compiled under a ``util.py3k`` guard, because
    keyword-only argument syntax does not parse on Python 2.
    """

    __requires__ = ("python3",)

    def _instrument(self, cls):
        """Register ``cls`` and record the arguments of each "init" event.

        :return: ``(cls, canary)``, where ``canary`` is a list to which the
            listener appends an ``(args, kwargs)`` tuple per event.
        """
        manager = instrumentation.register_class(cls)
        canary = []

        def check(target, args, kwargs):
            canary.append((args, kwargs))

        event.listen(manager, "init", check)
        return cls, canary

    def test_kw_only_args(self):
        cls, canary = self._kw_only_fixture()

        cls("a", b="b", c="c")
        eq_(canary, [(("a",), {"b": "b", "c": "c"})])

    def test_kw_plus_posn_args(self):
        cls, canary = self._kw_plus_posn_fixture()

        cls("a", 1, 2, 3, b="b", c="c")
        eq_(canary, [(("a", 1, 2, 3), {"b": "b", "c": "c"})])

    def test_kw_only_args_plus_opt(self):
        cls, canary = self._kw_opt_fixture()

        # an omitted keyword-only default still shows up in the event kwargs
        cls("a", b="b")
        eq_(canary, [(("a",), {"b": "b", "c": "c"})])

        canary[:] = []
        cls("a", b="b", c="d")
        eq_(canary, [(("a",), {"b": "b", "c": "d"})])

    def test_kw_only_sig(self):
        cls, canary = self._kw_only_fixture()
        assert_raises(TypeError, cls, "a", "b", "c")

    def test_kw_plus_opt_sig(self):
        # Use the fixture whose keyword-only ``c`` has a default, so this
        # checks a different signature from test_kw_only_sig.
        cls, canary = self._kw_opt_fixture()

        # b and c are keyword-only; supplying either positionally must fail
        assert_raises(TypeError, cls, "a", "b", "c")

        assert_raises(TypeError, cls, "a", "b", c="c")
659
660
if util.py3k:
    # Keyword-only argument syntax only parses on Python 3, so these fixture
    # methods are compiled from source here and then attached to
    # Py3KFunctionInstTest.
    _locals = {}
    exec(
        """
def _kw_only_fixture(self):
    class A(object):
        def __init__(self, a, *, b, c):
            self.a = a
            self.b = b
            self.c = c
    return self._instrument(A)

def _kw_plus_posn_fixture(self):
    class A(object):
        def __init__(self, a, *args, b, c):
            self.a = a
            self.b = b
            self.c = c
    return self._instrument(A)

def _kw_opt_fixture(self):
    class A(object):
        def __init__(self, a, *, b, c="c"):
            self.a = a
            self.b = b
            self.c = c
    return self._instrument(A)
""",
        _locals,
    )
    for k in _locals:
        # exec() inserts a "__builtins__" entry into the globals dict it is
        # given; copy only the fixture functions onto the test class.
        if k.startswith("__"):
            continue
        setattr(Py3KFunctionInstTest, k, _locals[k])
693
694
class MiscTest(fixtures.ORMTest):
    """Seems basic, but not directly covered elsewhere!"""

    def _parent_child_tables(self):
        """Return ``(t1, t2)`` on one fresh MetaData.

        ``t1`` has ``id`` and ``x`` columns; ``t2`` has ``id`` and a
        ``t1_id`` column with a foreign key to ``t1.id``.
        """
        metadata = MetaData()
        parent_table = Table(
            "t1",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("x", Integer),
        )
        child_table = Table(
            "t2",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("t1_id", Integer, ForeignKey("t1.id")),
        )
        return parent_table, child_table

    def test_compileonattr(self):
        table = Table(
            "t",
            MetaData(),
            Column("id", Integer, primary_key=True),
            Column("x", Integer),
        )

        class A(object):
            pass

        mapper(A, table)

        # an unset column attribute on a brand-new instance reads as None
        instance = A()
        assert instance.id is None

    def test_compileonattr_rel(self):
        t1, t2 = self._parent_child_tables()

        class A(object):
            pass

        class B(object):
            pass

        mapper(A, t1, properties={"bs": relationship(B)})
        mapper(B, t2)

        # a collection relationship on a new instance starts out empty
        parent = A()
        assert not parent.bs

    def test_uninstrument(self):
        class A(object):
            pass

        manager = instrumentation.register_class(A)
        attributes.register_attribute(A, "x", uselist=False, useobject=False)

        assert instrumentation.manager_of_class(A) is manager
        instrumentation.unregister_class(A)

        # unregistering removes the manager, the attribute and the
        # instrumented __init__
        assert instrumentation.manager_of_class(A) is None
        assert not hasattr(A, "x")

        # identity ("is") would be the stricter check, but on PyPy only
        # equality holds here
        assert A.__init__ == object.__init__

    def test_compileonattr_rel_backref_a(self):
        t1, t2 = self._parent_child_tables()

        class Base(object):
            def __init__(self, *args, **kwargs):
                pass

        for base in (object, Base):

            class A(base):
                pass

            class B(base):
                pass

            mapper(A, t1, properties={"bs": relationship(B, backref="a")})
            mapper(B, t2)

            child = B()
            assert child.a is None
            parent = A()
            child.a = parent

            # adding the child cascades the backref'd parent into the session
            session = create_session()
            session.add(child)
            assert parent in session, "base is %s" % base

    def test_compileonattr_rel_backref_b(self):
        t1, t2 = self._parent_child_tables()

        class Base(object):
            def __init__(self):
                pass

        class Base_AKW(object):
            def __init__(self, *args, **kwargs):
                pass

        for base in (object, Base, Base_AKW):

            class A(base):
                pass

            class B(base):
                pass

            mapper(A, t1)
            mapper(B, t2, properties={"a": relationship(A, backref="bs")})

            parent = A()
            child = B()
            child.a = parent

            # adding the parent cascades the child in via the "bs" backref
            session = create_session()
            session.add(parent)
            assert child in session, "base: %s" % base
837