1import warnings
2from sqlalchemy.testing import eq_, assert_raises, assert_raises_message
3from sqlalchemy import *
4from sqlalchemy import exc as sa_exc, util, event
5from sqlalchemy.orm import *
6from sqlalchemy.orm.util import instance_str
7from sqlalchemy.orm import exc as orm_exc, attributes
8from sqlalchemy.testing.assertsql import AllOf, CompiledSQL, Or
9from sqlalchemy.sql import table, column
10from sqlalchemy import testing
11from sqlalchemy.testing import engines
12from sqlalchemy.testing import fixtures
13from test.orm import _fixtures
14from sqlalchemy.testing.schema import Table, Column
15from sqlalchemy import inspect
16from sqlalchemy.ext.declarative import declarative_base
17from sqlalchemy.testing.util import gc_collect
18
class O2MTest(fixtures.MappedTest):
    """Deals with inheritance and one-to-many relationships.

    ``Blub`` inherits ``Bar`` inherits ``Foo`` (joined-table inheritance);
    ``Blub`` additionally has a ``parent_foo`` relationship back to ``Foo``
    through ``blub.foo_id``.
    """

    @classmethod
    def define_tables(cls, metadata):
        # the tables are also published as module globals; test_basic
        # reads foo / bar / blub from there
        global foo, bar, blub
        foo = Table('foo', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('data', String(20)))

        bar = Table('bar', metadata,
            Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
            Column('bar_data', String(20)))

        blub = Table('blub', metadata,
            Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
            Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
            Column('blub_data', String(20)))

    def test_basic(self):
        class Foo(object):
            def __init__(self, data=None):
                self.data = data

            def __repr__(self):
                return "Foo id %d, data %s" % (self.id, self.data)
        mapper(Foo, foo)

        class Bar(Foo):
            def __repr__(self):
                return "Bar id %d, data %s" % (self.id, self.data)
        mapper(Bar, bar, inherits=Foo)

        class Blub(Bar):
            def __repr__(self):
                return "Blub id %d, data %s" % (self.id, self.data)
        mapper(Blub, blub, inherits=Bar, properties={
            'parent_foo': relationship(Foo)
        })

        def describe(first, second):
            # repr both Blubs, then both of their parents, in that order
            return ','.join([repr(first), repr(second),
                             repr(first.parent_foo),
                             repr(second.parent_foo)])

        sess = create_session()
        b1 = Blub("blub #1")
        b2 = Blub("blub #2")
        f = Foo("foo #1")
        sess.add_all([b1, b2, f])
        b1.parent_foo = f
        b2.parent_foo = f
        sess.flush()
        expected = describe(b1, b2)

        sess.expunge_all()
        loaded = sess.query(Blub).all()
        first, second = loaded[0], loaded[1]
        eq_(expected, describe(first, second))
        for blub_obj in (first, second):
            eq_(blub_obj.parent_foo.data, 'foo #1')
79
class PolymorphicResolutionMultiLevel(fixtures.DeclarativeMappedTest,
                                        testing.AssertsCompiledSQL):
    """Resolve with_polymorphic specs against a multi-level hierarchy.

    Joined-table hierarchy: A -> B -> D and A -> C.  The tests call the
    private ``Mapper._mappers_from_spec`` / ``_with_polymorphic_args``
    helpers directly.
    """
    run_setup_mappers = 'once'
    __dialect__ = 'default'

    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class A(Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)

        class B(A):
            __tablename__ = 'b'
            id = Column(Integer, ForeignKey('a.id'), primary_key=True)

        class C(A):
            __tablename__ = 'c'
            id = Column(Integer, ForeignKey('a.id'), primary_key=True)

        class D(B):
            __tablename__ = 'd'
            id = Column(Integer, ForeignKey('b.id'), primary_key=True)

    def _assert_a_b_d_outerjoin(self, spec):
        """Assert that ``spec`` against A renders the A/B/D outer join."""
        _, selectable = inspect(self.classes.A)._with_polymorphic_args(
            spec=spec)
        self.assert_compile(selectable,
            "a LEFT OUTER JOIN b ON a.id = b.id "
            "LEFT OUTER JOIN d ON b.id = d.id")

    def test_ordered_b_d(self):
        A, B, D = self.classes.A, self.classes.B, self.classes.D
        a_mapper = inspect(A)
        eq_(
            a_mapper._mappers_from_spec([B, D], None),
            [a_mapper, inspect(B), inspect(D)]
        )

    def test_a(self):
        A = self.classes.A
        a_mapper = inspect(A)
        eq_(a_mapper._mappers_from_spec([A], None), [a_mapper])

    def test_b_d_selectable(self):
        B, D = self.classes.B, self.classes.D
        a_mapper = inspect(self.classes.A)
        b_join_d = B.__table__.join(D.__table__)
        eq_(
            a_mapper._mappers_from_spec([D, B], b_join_d),
            [inspect(B), inspect(D)]
        )

    def test_d_selectable(self):
        B, D = self.classes.B, self.classes.D
        a_mapper = inspect(self.classes.A)
        b_join_d = B.__table__.join(D.__table__)
        eq_(
            a_mapper._mappers_from_spec([D], b_join_d),
            [inspect(D)]
        )

    def test_reverse_d_b(self):
        B, D = self.classes.B, self.classes.D
        a_mapper = inspect(self.classes.A)
        spec = [D, B]
        eq_(
            a_mapper._mappers_from_spec(spec, None),
            [a_mapper, inspect(B), inspect(D)]
        )
        self._assert_a_b_d_outerjoin(spec)

    def test_d_b_missing(self):
        # B is absent from the spec but is expected in the result as
        # D's intermediate parent
        B, D = self.classes.B, self.classes.D
        a_mapper = inspect(self.classes.A)
        spec = [D]
        eq_(
            a_mapper._mappers_from_spec(spec, None),
            [a_mapper, inspect(B), inspect(D)]
        )
        self._assert_a_b_d_outerjoin(spec)

    def test_d_c_b(self):
        A, B, C, D = (self.classes.A, self.classes.B,
                      self.classes.C, self.classes.D)
        a_mapper = inspect(A)
        ms = a_mapper._mappers_from_spec([D, C, B], None)

        # A first, D last; the order of B and C in between is not fixed
        eq_(ms[-1], inspect(D))
        eq_(ms[0], a_mapper)
        eq_(set(ms[1:3]), set(a_mapper._inheriting_mappers))
177
class PolymorphicOnNotLocalTest(fixtures.MappedTest):
    """Test the accepted and rejected forms of ``polymorphic_on``.

    Covers:

    * an unknown attribute name and a non-expression callable (rejected);
    * a column from a selectable that is, or is not, part of
      ``with_polymorphic``;
    * SQL expressions, labeled or not, with joined or single-table
      inheritance;
    * ``column_property`` objects and property names;
    * a synonym (rejected).
    """

    @classmethod
    def define_tables(cls, metadata):
        # The tables are registered on ``metadata``. Tests fetch them via
        # ``self.tables``, so no local names are kept here.
        Table('t1', metadata,
                Column('id', Integer, primary_key=True,
                            test_needs_autoincrement=True),
                Column('x', String(10)),
                Column('q', String(10)))
        Table('t2', metadata,
                Column('t2id', Integer, primary_key=True,
                            test_needs_autoincrement=True),
                Column('y', String(10)),
                Column('xid', ForeignKey('t1.id')))

    @classmethod
    def setup_classes(cls):
        class Parent(cls.Comparable):
            pass

        class Child(Parent):
            pass

    def test_non_col_polymorphic_on(self):
        Parent = self.classes.Parent
        t2 = self.tables.t2
        assert_raises_message(
            sa_exc.ArgumentError,
            "Can't determine polymorphic_on "
            "value 'im not a column' - no "
            "attribute is mapped to this name.",
            mapper,
            Parent, t2, polymorphic_on="im not a column"
        )

    def test_polymorphic_on_non_expr_prop(self):
        t2 = self.tables.t2
        Parent = self.classes.Parent

        def go():
            # a plain callable is neither a column-mapped property nor a
            # SQL expression
            mapper(Parent, t2,
                   polymorphic_on=lambda: "hi",
                   polymorphic_identity=0)

        assert_raises_message(
            sa_exc.ArgumentError,
            "Only direct column-mapped property or "
            "SQL expression can be passed for polymorphic_on",
            go
        )

    def test_polymorphic_on_not_present_col(self):
        # NOTE(review): currently identical in effect to
        # test_polymorpic_on_not_in_with_poly below.
        t2, t1 = self.tables.t2, self.tables.t1
        Parent = self.classes.Parent
        t1t2_join = select([t1.c.x], from_obj=[t1.join(t2)]).alias()

        def go():
            # the with_polymorphic selectable only carries t1.q, so the
            # 'x' column given as polymorphic_on cannot be located
            t1t2_join_2 = select([t1.c.q], from_obj=[t1.join(t2)]).alias()
            mapper(Parent, t2,
                   polymorphic_on=t1t2_join.c.x,
                   with_polymorphic=('*', t1t2_join_2),
                   polymorphic_identity=0)

        assert_raises_message(
            sa_exc.InvalidRequestError,
            "Could not map polymorphic_on column 'x' to the mapped table - "
            "polymorphic loads will not function properly",
            go
        )

    def test_polymorphic_on_only_in_with_poly(self):
        t2, t1 = self.tables.t2, self.tables.t1
        Parent = self.classes.Parent
        t1t2_join = select([t1.c.x], from_obj=[t1.join(t2)]).alias()
        # if it's in the with_polymorphic selectable, then it's OK
        mapper(Parent, t2,
               polymorphic_on=t1t2_join.c.x,
               with_polymorphic=('*', t1t2_join),
               polymorphic_identity=0)

    def test_polymorpic_on_not_in_with_poly(self):
        t2, t1 = self.tables.t2, self.tables.t1
        Parent = self.classes.Parent

        t1t2_join = select([t1.c.x], from_obj=[t1.join(t2)]).alias()

        # if with_polymorphic is given but the column is not present in
        # it, not OK
        def go():
            t1t2_join_2 = select([t1.c.q], from_obj=[t1.join(t2)]).alias()
            mapper(Parent, t2,
                   polymorphic_on=t1t2_join.c.x,
                   with_polymorphic=('*', t1t2_join_2),
                   polymorphic_identity=0)

        assert_raises_message(
            sa_exc.InvalidRequestError,
            "Could not map polymorphic_on column 'x' "
            "to the mapped table - "
            "polymorphic loads will not function properly",
            go
        )

    def test_polymorphic_on_expr_explicit_map(self):
        t2, t1 = self.tables.t2, self.tables.t1
        Parent, Child = self.classes.Parent, self.classes.Child
        expr = case([
            (t1.c.x == "p", "parent"),
            (t1.c.x == "c", "child"),
        ])
        # the same expression is both mapped as a column_property and
        # passed as polymorphic_on
        mapper(Parent, t1, properties={
            "discriminator": column_property(expr)
        }, polymorphic_identity="parent",
            polymorphic_on=expr)
        mapper(Child, t2, inherits=Parent,
               polymorphic_identity="child")

        self._roundtrip(parent_ident='p', child_ident='c')

    def test_polymorphic_on_expr_implicit_map_no_label_joined(self):
        t2, t1 = self.tables.t2, self.tables.t1
        Parent, Child = self.classes.Parent, self.classes.Child
        expr = case([
            (t1.c.x == "p", "parent"),
            (t1.c.x == "c", "child"),
        ])
        mapper(Parent, t1, polymorphic_identity="parent",
               polymorphic_on=expr)
        mapper(Child, t2, inherits=Parent, polymorphic_identity="child")

        self._roundtrip(parent_ident='p', child_ident='c')

    def test_polymorphic_on_expr_implicit_map_w_label_joined(self):
        t2, t1 = self.tables.t2, self.tables.t1
        Parent, Child = self.classes.Parent, self.classes.Child
        expr = case([
            (t1.c.x == "p", "parent"),
            (t1.c.x == "c", "child"),
        ]).label(None)
        mapper(Parent, t1, polymorphic_identity="parent",
               polymorphic_on=expr)
        mapper(Child, t2, inherits=Parent, polymorphic_identity="child")

        self._roundtrip(parent_ident='p', child_ident='c')

    def test_polymorphic_on_expr_implicit_map_no_label_single(self):
        """test that single_table_criterion is propagated
        with a standalone expr"""
        t1 = self.tables.t1
        Parent, Child = self.classes.Parent, self.classes.Child
        expr = case([
            (t1.c.x == "p", "parent"),
            (t1.c.x == "c", "child"),
        ])
        mapper(Parent, t1, polymorphic_identity="parent",
               polymorphic_on=expr)
        mapper(Child, inherits=Parent, polymorphic_identity="child")

        self._roundtrip(parent_ident='p', child_ident='c')

    def test_polymorphic_on_expr_implicit_map_w_label_single(self):
        """test that single_table_criterion is propagated
        with a standalone expr"""
        t1 = self.tables.t1
        Parent, Child = self.classes.Parent, self.classes.Child
        expr = case([
            (t1.c.x == "p", "parent"),
            (t1.c.x == "c", "child"),
        ]).label(None)
        mapper(Parent, t1, polymorphic_identity="parent",
               polymorphic_on=expr)
        mapper(Child, inherits=Parent, polymorphic_identity="child")

        self._roundtrip(parent_ident='p', child_ident='c')

    def test_polymorphic_on_column_prop(self):
        t2, t1 = self.tables.t2, self.tables.t1
        Parent, Child = self.classes.Parent, self.classes.Child
        expr = case([
            (t1.c.x == "p", "parent"),
            (t1.c.x == "c", "child"),
        ])
        cprop = column_property(expr)
        mapper(Parent, t1, properties={
            "discriminator": cprop
        }, polymorphic_identity="parent",
            polymorphic_on=cprop)
        mapper(Child, t2, inherits=Parent,
               polymorphic_identity="child")

        self._roundtrip(parent_ident='p', child_ident='c')

    def test_polymorphic_on_column_str_prop(self):
        t2, t1 = self.tables.t2, self.tables.t1
        Parent, Child = self.classes.Parent, self.classes.Child
        expr = case([
            (t1.c.x == "p", "parent"),
            (t1.c.x == "c", "child"),
        ])
        cprop = column_property(expr)
        # polymorphic_on names the mapped column_property by key
        mapper(Parent, t1, properties={
            "discriminator": cprop
        }, polymorphic_identity="parent",
            polymorphic_on="discriminator")
        mapper(Child, t2, inherits=Parent,
               polymorphic_identity="child")

        self._roundtrip(parent_ident='p', child_ident='c')

    def test_polymorphic_on_synonym(self):
        t1 = self.tables.t1
        Parent = self.classes.Parent
        cprop = column_property(t1.c.x)
        # a synonym of a column property is rejected as polymorphic_on
        assert_raises_message(
            sa_exc.ArgumentError,
            "Only direct column-mapped property or "
            "SQL expression can be passed for polymorphic_on",
            mapper, Parent, t1, properties={
                "discriminator": cprop,
                "discrim_syn": synonym(cprop)
            }, polymorphic_identity="parent",
            polymorphic_on="discrim_syn")

    def _roundtrip(self, set_event=True, parent_ident='parent',
                   child_ident='child'):
        """Persist Parent/Child/Parent rows, then check polymorphic loading.

        :param set_event: if true, register an ``init`` listener on Parent
          (propagated to subclasses) that assigns ``x`` based on the new
          object's mapper's polymorphic identity.
        :param parent_ident: value assigned to ``x`` for 'parent' objects.
        :param child_ident: value assigned to ``x`` for 'child' objects.
        """
        Parent, Child = self.classes.Parent, self.classes.Child

        if set_event:
            @event.listens_for(Parent, "init", propagate=True)
            def set_identity(instance, *arg, **kw):
                ident = object_mapper(instance).polymorphic_identity
                if ident == 'parent':
                    instance.x = parent_ident
                elif ident == 'child':
                    instance.x = child_ident
                else:
                    assert False, "Got unexpected identity %r" % ident

        s = Session(testing.db)
        s.add_all([
            Parent(q="p1"),
            Child(q="c1", y="c1"),
            Parent(q="p2"),
        ])
        s.commit()
        s.close()

        eq_(
            [type(t) for t in s.query(Parent).order_by(Parent.id)],
            [Parent, Child, Parent]
        )

        eq_(
            [type(t) for t in s.query(Child).all()],
            [Child]
        )
428
class SortOnlyOnImportantFKsTest(fixtures.MappedTest):
    """Flush a ``B`` when tables ``a`` and ``b`` reference each other.

    ``a.b_id -> b.id`` (declared with use_alter) and ``b.id -> a.id`` form a
    cycle. The declarative mapping names ``b.id == a.id`` as the inherit
    condition. Judging by the class name, the test checks that only the
    relevant foreign keys drive the flush ordering.
    """

    @classmethod
    def define_tables(cls, metadata):
        Table('a', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('b_id', Integer,
                     ForeignKey('b.id', use_alter=True, name='b')))
        Table('b', metadata,
              Column('id', Integer, ForeignKey('a.id'), primary_key=True))

    @classmethod
    def setup_classes(cls):
        Base = declarative_base()

        class A(Base):
            __tablename__ = "a"

            id = Column(Integer, primary_key=True,
                        test_needs_autoincrement=True)
            b_id = Column(Integer, ForeignKey('b.id'))

        class B(A):
            __tablename__ = "b"

            id = Column(Integer, ForeignKey('a.id'), primary_key=True)

            # two foreign keys relate a and b; spell out which one is the
            # inheritance join
            __mapper_args__ = {'inherit_condition': id == A.id}

        cls.classes.A = A
        cls.classes.B = B

    def test_flush(self):
        session = Session(testing.db)
        session.add(self.classes.B())
        session.flush()
467
class FalseDiscriminatorTest(fixtures.MappedTest):
    """A Boolean discriminator where ``False`` is a valid polymorphic
    identity, either on the subclass or on the base class."""

    @classmethod
    def define_tables(cls, metadata):
        # also published as a module global; the tests below read ``t1``
        global t1
        t1 = Table('t1', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('type', Boolean, nullable=False))

    def test_false_on_sub(self):
        class Foo(object):
            pass

        class Bar(Foo):
            pass

        mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=True)
        mapper(Bar, inherits=Foo, polymorphic_identity=False)
        sess = create_session()
        b1 = Bar()
        sess.add(b1)
        sess.flush()
        assert b1.type is False
        sess.expunge_all()
        assert isinstance(sess.query(Foo).one(), Bar)

    def test_false_on_base(self):
        class Ding(object):
            pass

        class Bat(Ding):
            pass

        mapper(Ding, t1, polymorphic_on=t1.c.type, polymorphic_identity=False)
        mapper(Bat, inherits=Ding, polymorphic_identity=True)
        sess = create_session()
        d1 = Ding()
        sess.add(d1)
        sess.flush()
        assert d1.type is False
        sess.expunge_all()
        # Query.one() raises instead of returning None, so an "is not None"
        # check is vacuous. Verify instead that the False discriminator
        # resolves to the base class Ding and not to the True-identity
        # subclass Bat.
        loaded = sess.query(Ding).one()
        assert isinstance(loaded, Ding) and not isinstance(loaded, Bat)
503
class PolymorphicSynonymTest(fixtures.MappedTest):
    """A synonym on the base mapper works when querying through a joined
    subclass."""

    @classmethod
    def define_tables(cls, metadata):
        # the tables are also published as module globals for the test
        global t1, t2
        t1 = Table('t1', metadata,
                   Column('id', Integer, primary_key=True,
                          test_needs_autoincrement=True),
                   Column('type', String(10), nullable=False),
                   Column('info', String(255)))
        t2 = Table('t2', metadata,
                   Column('id', Integer, ForeignKey('t1.id'),
                          primary_key=True),
                   Column('data', String(10), nullable=False))

    def test_polymorphic_synonym(self):
        class T1(fixtures.ComparableEntity):
            @property
            def info(self):
                # reads decorate the stored ``_info`` value
                return "THE INFO IS:" + self._info

            @info.setter
            def info(self, x):
                self._info = x

        class T2(T1):
            pass

        # 'info' is a synonym over '_info', which is mapped to the 'info'
        # column via map_column=True
        mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1',
               properties={'info': synonym('_info', map_column=True)})
        mapper(T2, t2, inherits=T1, polymorphic_identity='t2')

        sess = create_session()
        at1 = T1(info='at1')
        at2 = T2(info='at2', data='t2 data')
        sess.add_all([at1, at2])
        sess.flush()
        sess.expunge_all()

        eq_(sess.query(T2).filter(T2.info == 'at2').one(), at2)
        eq_(at2.info, "THE INFO IS:at2")
542
class PolymorphicAttributeManagementTest(fixtures.MappedTest):
    """Test polymorphic_on can be assigned, can be mirrored, etc.

    Hierarchy:

    * A (table_a);
    * B (table_b), which has its own ``class_name`` column mapped together
      with ``table_a.class_name``;
    * C (table_c), joined under B;
    * D, single-table under B.
    """

    run_setup_mappers = 'once'

    @classmethod
    def define_tables(cls, metadata):
        Table('table_a', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('class_name', String(50)))
        Table('table_b', metadata,
              Column('id', Integer, ForeignKey('table_a.id'),
                     primary_key=True),
              Column('class_name', String(50)))
        Table('table_c', metadata,
              Column('id', Integer, ForeignKey('table_b.id'),
                     primary_key=True),
              Column('data', String(10)))

    @classmethod
    def setup_classes(cls):
        table_a = cls.tables.table_a
        table_b = cls.tables.table_b
        table_c = cls.tables.table_c

        class A(cls.Basic):
            pass

        class B(A):
            pass

        class C(B):
            pass

        class D(B):
            pass

        mapper(A, table_a,
               polymorphic_on=table_a.c.class_name,
               polymorphic_identity='a')
        # B discriminates on its own column; both class_name columns are
        # mapped to the single 'class_name' attribute
        mapper(B, table_b, inherits=A,
               polymorphic_on=table_b.c.class_name,
               polymorphic_identity='b',
               properties={'class_name': [table_a.c.class_name,
                                          table_b.c.class_name]})
        mapper(C, table_c, inherits=B,
               polymorphic_identity='c')
        mapper(D, inherits=B,
               polymorphic_identity='d')

    def _assert_flush_warns(self, sess, obj, identity):
        """Assert that ``sess.flush`` raises the incompatible polymorphic
        identity SAWarning for ``obj`` carrying ``identity``."""
        assert_raises_message(
            sa_exc.SAWarning,
            "Flushing object %s with incompatible "
            "polymorphic identity '%s'; the object may not "
            "refresh and/or load correctly" % (instance_str(obj), identity),
            sess.flush
        )

    def test_poly_configured_immediate(self):
        A, B, C = self.classes.A, self.classes.B, self.classes.C

        # the discriminator attribute is already set right after construction
        a, b, c = A(), B(), C()
        eq_(a.class_name, 'a')
        eq_(b.class_name, 'b')
        eq_(c.class_name, 'c')

    def test_base_class(self):
        A, B, C = self.classes.A, self.classes.B, self.classes.C

        sess = Session()
        c1 = C()
        sess.add(c1)
        sess.commit()

        assert isinstance(sess.query(B).first(), C)

        sess.close()

        assert isinstance(sess.query(A).first(), C)

    def test_valid_assignment_upwards(self):
        """test that we can assign 'd' to a B, since B/D
        both involve the same set of tables.
        """
        B, D = self.classes.B, self.classes.D

        sess = Session()
        b1 = B()
        b1.class_name = 'd'
        sess.add(b1)
        sess.commit()
        sess.close()
        assert isinstance(sess.query(B).first(), D)

    def test_invalid_assignment_downwards(self):
        """test that we warn on assign of 'b' to a C, since this adds
        a row to the C table we'd never load.
        """
        sess = Session()
        c1 = self.classes.C()
        c1.class_name = 'b'
        sess.add(c1)
        self._assert_flush_warns(sess, c1, 'b')

    def test_invalid_assignment_upwards(self):
        """test that we warn on assign of 'c' to a B, since we will have a
        "C" row that has no joined row, which will cause object
        deleted errors.
        """
        sess = Session()
        b1 = self.classes.B()
        b1.class_name = 'c'
        sess.add(b1)
        self._assert_flush_warns(sess, b1, 'c')

    def test_entirely_oob_assignment(self):
        """test warn on an unknown polymorphic identity.
        """
        sess = Session()
        b1 = self.classes.B()
        b1.class_name = 'xyz'
        sess.add(b1)
        self._assert_flush_warns(sess, b1, 'xyz')

    def test_not_set_on_upate(self):
        # an UPDATE of an expired object whose discriminator was not
        # touched should flush without a warning
        C = self.classes.C

        sess = Session()
        c1 = C()
        sess.add(c1)
        sess.commit()
        sess.expire(c1)

        c1.data = 'foo'
        sess.flush()

    def test_validate_on_upate(self):
        C = self.classes.C

        sess = Session()
        c1 = C()
        sess.add(c1)
        sess.commit()
        sess.expire(c1)

        c1.class_name = 'b'
        self._assert_flush_warns(sess, c1, 'b')
718
class CascadeTest(fixtures.MappedTest):
    """Cascades on polymorphic relationships continue along the path of
    the instance's own mapper, not the base mapper."""

    @classmethod
    def define_tables(cls, metadata):
        # the tables are also published as module globals for test_cascade
        global t1, t2, t3, t4
        t1 = Table('t1', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('data', String(30)))

        t2 = Table('t2', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('t1id', Integer, ForeignKey('t1.id')),
            Column('type', String(30)),
            Column('data', String(30)))

        t3 = Table('t3', metadata,
            Column('id', Integer, ForeignKey('t2.id'), primary_key=True),
            Column('moredata', String(30)))

        t4 = Table('t4', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('t3id', Integer, ForeignKey('t3.id')),
            Column('data', String(30)))

    def test_cascade(self):
        class T1(fixtures.BasicEntity):
            pass

        class T2(fixtures.BasicEntity):
            pass

        class T3(T2):
            pass

        class T4(fixtures.BasicEntity):
            pass

        mapper(T1, t1, properties={
            't2s': relationship(T2, cascade="all")
        })
        mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2')
        # 't4s' exists only on the T3 subclass mapper
        mapper(T3, t3, inherits=T2, polymorphic_identity='t3', properties={
            't4s': relationship(T4, cascade="all")
        })
        mapper(T4, t4)

        sess = create_session()
        t1_1 = T1(data='t1')
        t3_1 = T3(data='t3', moredata='t3')
        t2_1 = T2(data='t2')

        t1_1.t2s.append(t2_1)
        t1_1.t2s.append(t3_1)

        t4_1 = T4(data='t4')
        t3_1.t4s.append(t4_1)

        # save-update cascade reaches t4_1 through T3's own relationship
        sess.add(t1_1)
        assert t4_1 in sess.new
        sess.flush()

        # delete cascade follows the same path
        sess.delete(t1_1)
        assert t4_1 in sess.deleted
        sess.flush()
791
class M2OUseGetTest(fixtures.MappedTest):
    """A many-to-one to a joined-inheritance subclass uses the "get"
    (identity-map) lazy load ([ticket:1186])."""

    @classmethod
    def define_tables(cls, metadata):
        Table('base', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('type', String(30)))
        Table('sub', metadata,
              Column('id', Integer, ForeignKey('base.id'), primary_key=True))
        Table('related', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('sub_id', Integer, ForeignKey('sub.id')))

    def test_use_get(self):
        base = self.tables.base
        sub = self.tables.sub
        related = self.tables.related

        # test [ticket:1186]
        class Base(fixtures.BasicEntity):
            pass

        class Sub(Base):
            pass

        class Related(Base):
            pass

        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b')
        mapper(Sub, sub, inherits=Base, polymorphic_identity='s')
        mapper(Related, related, properties={
            # Sub's "get" clause is based on the Base id. Previously this
            # relationship needed an explicit primaryjoin
            # (base.c.id == related.c.sub_id) plus
            # foreign_keys=related.c.sub_id for use_get to be detected;
            # a plain relationship() is now enough.
            'sub': relationship(Sub)
        })

        sub_prop = class_mapper(Related).get_property('sub')
        assert sub_prop.strategy.use_get

        sess = create_session()
        s1 = Sub()
        r1 = Related(sub=s1)
        sess.add(r1)
        sess.flush()
        sess.expunge_all()

        r1 = sess.query(Related).first()
        # Keep s1 bound so the loaded Sub stays in the session's identity
        # map (presumably held weakly). r1.sub can then be resolved via
        # "get" with no SQL.
        s1 = sess.query(Sub).first()

        def go():
            assert r1.sub
        self.assert_sql_count(testing.db, go, 0)
845
846
class GetTest(fixtures.MappedTest):
    """Test Query.get() against a three-level joined-table inheritance
    chain (Foo -> Bar -> Blub), with and without polymorphic mappers."""

    @classmethod
    def define_tables(cls, metadata):
        global foo, bar, blub
        foo = Table('foo', metadata,
            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('type', String(30)),
            Column('data', String(20)))

        bar = Table('bar', metadata,
            Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
            Column('bar_data', String(20)))

        blub = Table('blub', metadata,
            Column('blub_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('foo_id', Integer, ForeignKey('foo.id')),
            Column('bar_id', Integer, ForeignKey('bar.id')),
            Column('blub_data', String(20)))

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

        class Bar(Foo):
            pass

        class Blub(Bar):
            pass

    def test_get_polymorphic(self):
        self._do_get_test(True)

    def test_get_nonpolymorphic(self):
        self._do_get_test(False)

    def _do_get_test(self, polymorphic):
        """Persist one Foo, Bar and Blub, then exercise Query.get().

        :param polymorphic: when True the mappers are configured with
          ``polymorphic_on``/``polymorphic_identity`` and every get() must
          be served from the identity map (zero SQL); when False, get()
          through a non-polymorphic mapper is expected to emit 3 statements.
        """
        foo, Bar, Blub, blub, bar, Foo = (self.tables.foo,
                                self.classes.Bar,
                                self.classes.Blub,
                                self.tables.blub,
                                self.tables.bar,
                                self.classes.Foo)

        if polymorphic:
            mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo')
            mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar')
            mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub')
        else:
            mapper(Foo, foo)
            mapper(Bar, bar, inherits=Foo)
            mapper(Blub, blub, inherits=Bar)

        sess = create_session()
        f = Foo()
        b = Bar()
        bl = Blub()
        sess.add(f)
        sess.add(b)
        sess.add(bl)
        sess.flush()

        if polymorphic:
            def go():
                assert sess.query(Foo).get(f.id) is f
                assert sess.query(Foo).get(b.id) is b
                assert sess.query(Foo).get(bl.id) is bl
                assert sess.query(Bar).get(b.id) is b
                assert sess.query(Bar).get(bl.id) is bl
                assert sess.query(Blub).get(bl.id) is bl

                # test class mismatches - item is present
                # in the identity map but we requested a subclass
                assert sess.query(Blub).get(f.id) is None
                assert sess.query(Blub).get(b.id) is None
                assert sess.query(Bar).get(f.id) is None

            self.assert_sql_count(testing.db, go, 0)
        else:
            # this is testing the 'wrong' behavior of using get()
            # polymorphically with mappers that are not configured to be
            # polymorphic.  the important part being that get() always
            # returns an instance of the query's type.
            def go():
                assert sess.query(Foo).get(f.id) is f

                # check the object get() returned, not the original 'b'
                bb = sess.query(Foo).get(b.id)
                assert isinstance(bb, Foo) and bb.id == b.id

                bll = sess.query(Foo).get(bl.id)
                assert isinstance(bll, Foo) and bll.id == bl.id

                assert sess.query(Bar).get(b.id) is b

                bll = sess.query(Bar).get(bl.id)
                assert isinstance(bll, Bar) and bll.id == bl.id

                assert sess.query(Blub).get(bl.id) is bl

            self.assert_sql_count(testing.db, go, 3)
947
948
class EagerLazyTest(fixtures.MappedTest):
    """tests eager load/lazy load of child items off inheritance mappers, tests that
    LazyLoader constructs the right query condition."""

    @classmethod
    def define_tables(cls, metadata):
        global foo, bar, bar_foo
        foo = Table('foo', metadata,
                    Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                    Column('data', String(30)))
        bar = Table('bar', metadata,
                    Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
                    Column('bar_data', String(30)))

        bar_foo = Table('bar_foo', metadata,
                        Column('bar_id', Integer, ForeignKey('bar.id')),
                        Column('foo_id', Integer, ForeignKey('foo.id'))
        )

    def test_basic(self):
        class Foo(object): pass
        class Bar(Foo): pass

        foos = mapper(Foo, foo)
        bars = mapper(Bar, bar, inherits=foos)
        # the same many-to-many (via bar_foo) mapped twice: once lazy,
        # once joined-eager
        bars.add_property('lazy', relationship(foos, bar_foo, lazy='select'))
        bars.add_property('eager', relationship(foos, bar_foo, lazy='joined'))

        # foo rows 1 and 2 are the base rows of the two Bar objects.
        # The bar table's own column is "bar_data"; it has no "data" column.
        foo.insert().execute(data='foo1')
        bar.insert().execute(id=1, bar_data='bar1')

        foo.insert().execute(data='foo2')
        bar.insert().execute(id=2, bar_data='bar2')

        foo.insert().execute(data='foo3') #3
        foo.insert().execute(data='foo4') #4

        bar_foo.insert().execute(bar_id=1, foo_id=3)
        bar_foo.insert().execute(bar_id=2, foo_id=4)

        sess = create_session()
        q = sess.query(Bar)
        eq_(len(q.first().lazy), 1)
        eq_(len(q.first().eager), 1)
993
class EagerTargetingTest(fixtures.MappedTest):
    """Check that a joined-table inheritance join is not mistaken for
    an eagerly loaded joined table when a self-referential relationship
    is present."""

    @classmethod
    def define_tables(cls, metadata):
        Table(
            'a_table', metadata,
            Column('id', Integer, primary_key=True),
            Column('name', String(50)),
            Column('type', String(30), nullable=False),
            Column('parent_id', Integer, ForeignKey('a_table.id')),
        )
        Table(
            'b_table', metadata,
            Column('id', Integer, ForeignKey('a_table.id'), primary_key=True),
            Column('b_data', String(50)),
        )

    def test_adapt_stringency(self):
        a_table = self.tables.a_table
        b_table = self.tables.b_table

        class A(fixtures.ComparableEntity):
            pass

        class B(A):
            pass

        mapper(
            A, a_table,
            polymorphic_on=a_table.c.type,
            polymorphic_identity='A',
            properties={
                'children': relationship(A, order_by=a_table.c.name),
            },
        )
        mapper(
            B, b_table,
            inherits=A,
            polymorphic_identity='B',
            properties={
                'b_derived': column_property(b_table.c.b_data + "DATA"),
            },
        )

        session = create_session()

        # parent and child are flushed separately, parent first
        parent = B(id=1, name='b1', b_data='i')
        session.add(parent)
        session.flush()

        child = B(id=2, name='b2', b_data='l', parent_id=1)
        session.add(child)
        session.flush()

        parent_id = parent.id

        # load the parent once with the default (lazy) children loader and
        # once with joinedload(); both must produce identical results
        for load_options in ((), (joinedload(B.children),)):
            session.expunge_all()
            node = session.query(B).options(*load_options).\
                filter(B.id == parent_id).all()[0]
            eq_(node, B(id=1, name='b1', b_data='i'))
            eq_(node.children[0], B(id=2, name='b2', b_data='l'))
1050
class FlushTest(fixtures.MappedTest):
    """Check dependency sorting in the unit of work for mappers that
    inherit from one another."""

    @classmethod
    def define_tables(cls, metadata):
        Table('users', metadata,
            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('email', String(128)),
            Column('password', String(16)),
        )

        Table('roles', metadata,
            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('description', String(32))
        )

        Table('user_roles', metadata,
            Column('user_id', Integer, ForeignKey('users.id'), primary_key=True),
            Column('role_id', Integer, ForeignKey('roles.id'), primary_key=True)
        )

        Table('admins', metadata,
            Column('admin_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('user_id', Integer, ForeignKey('users.id'))
        )

    def _map_classes(self, User, Role, Admin):
        """Map Role, then User with a joined-eager many-to-many "roles"
        collection through user_roles, then Admin inheriting from User."""
        tables = self.tables
        mapper(Role, tables.roles)
        user_mapper = mapper(User, tables.users, properties={
            'roles': relationship(
                Role, secondary=tables.user_roles, lazy='joined')
        })
        mapper(Admin, tables.admins, inherits=user_mapper)

    def test_one(self):
        user_roles = self.tables.user_roles

        class User(object):
            pass

        class Role(object):
            pass

        class Admin(User):
            pass

        self._map_classes(User, Role, Admin)
        session = create_session()
        admin_role = Role()
        session.add(admin_role)
        session.flush()

        # the "roles" dependency processors of the Admin mapper and of the
        # inherited User mapper must not both insert the same
        # many-to-many row.
        admin = Admin()
        admin.roles.append(admin_role)
        admin.password = 'admin'
        session.add(admin)
        session.flush()

        eq_(user_roles.count().scalar(), 1)

    def test_two(self):
        user_roles = self.tables.user_roles

        class User(object):
            def __init__(self, email=None, password=None):
                self.email = email
                self.password = password

        class Role(object):
            def __init__(self, description=None):
                self.description = description

        class Admin(User):
            pass

        self._map_classes(User, Role, Admin)

        admin_role = Role('admin')

        session = create_session()
        session.add(admin_role)
        session.flush()

        admin = Admin(email='tim', password='admin')
        admin.roles.append(admin_role)
        session.add(admin)
        session.flush()

        # a later UPDATE of the admin must not re-insert the association row
        admin.password = 'sadmin'
        session.flush()
        eq_(user_roles.count().scalar(), 1)
1150
1151
class OptimizedGetOnDeferredTest(fixtures.MappedTest):
    """Ensure the "optimized get" loading path copes with deferred
    columns on a joined-table subclass."""

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "a", metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True)
        )
        Table(
            "b", metadata,
            Column('id', Integer, ForeignKey('a.id'), primary_key=True),
            Column('data', String(10))
        )

    @classmethod
    def setup_classes(cls):
        class A(cls.Basic):
            pass

        class B(A):
            pass

    @classmethod
    def setup_mappers(cls):
        A, B = cls.classes.A, cls.classes.B
        a_table, b_table = cls.tables.a, cls.tables.b

        mapper(A, a_table)
        mapper(B, b_table, inherits=A, properties={
            'data': deferred(b_table.c.data),
            'expr': column_property(b_table.c.data + 'q', deferred=True)
        })

    def _flushed_b(self):
        """Return ``(session, obj)`` where obj is a B(data='x') that was
        just flushed in session.  The session is returned so the caller
        keeps it alive while lazily loading deferred attributes."""
        B = self.classes.B
        session = Session()
        obj = B(data='x')
        session.add(obj)
        session.flush()
        return session, obj

    def test_column_property(self):
        session, b1 = self._flushed_b()
        eq_(b1.expr, 'xq')

    def test_expired_column(self):
        session, b1 = self._flushed_b()
        session.expire(b1, ['data'])
        eq_(b1.data, 'x')
1205
1206
class JoinedNoFKSortingTest(fixtures.MappedTest):
    """Test flush ordering for joined-table subclasses whose tables
    declare no ForeignKey to the base table; the join is supplied
    explicitly through ``inherit_condition`` / ``inherit_foreign_keys``."""

    @classmethod
    def define_tables(cls, metadata):
        Table("a", metadata,
                Column('id', Integer, primary_key=True,
                    test_needs_autoincrement=True)
            )
        # "b" and "c" have no ForeignKey to "a"
        Table("b", metadata,
                Column('id', Integer, primary_key=True)
            )
        Table("c", metadata,
                Column('id', Integer, primary_key=True)
            )

    @classmethod
    def setup_classes(cls):
        class A(cls.Basic):
            pass
        class B(A):
            pass
        class C(A):
            pass

    @classmethod
    def setup_mappers(cls):
        A, B, C = cls.classes.A, cls.classes.B, cls.classes.C
        mapper(A, cls.tables.a)
        # with no FK in the schema, the inherit join and the column to
        # treat as the foreign key are both given by hand
        mapper(B, cls.tables.b, inherits=A,
                    inherit_condition=cls.tables.a.c.id == cls.tables.b.c.id,
                    inherit_foreign_keys=cls.tables.b.c.id)
        mapper(C, cls.tables.c, inherits=A,
                    inherit_condition=cls.tables.a.c.id == cls.tables.c.c.id,
                    inherit_foreign_keys=cls.tables.c.c.id)

    def test_ordering(self):
        B, C = self.classes.B, self.classes.C
        sess = Session()
        sess.add_all([B(), C(), B(), C()])
        # expected: all four base-table rows are inserted first, then the
        # "b" and "c" rows (in either order, per AllOf) reusing the ids
        # generated for "a".  The expected ids 1..4 assume a fresh
        # autoincrement sequence and follow add_all() order.
        self.assert_sql_execution(
                testing.db,
                sess.flush,
                CompiledSQL(
                    "INSERT INTO a () VALUES ()",
                    {}
                ),
                CompiledSQL(
                    "INSERT INTO a () VALUES ()",
                    {}
                ),
                CompiledSQL(
                    "INSERT INTO a () VALUES ()",
                    {}
                ),
                CompiledSQL(
                    "INSERT INTO a () VALUES ()",
                    {}
                ),
                AllOf(
                    CompiledSQL(
                        "INSERT INTO b (id) VALUES (:id)",
                        [{"id": 1}, {"id": 3}]
                    ),
                    CompiledSQL(
                        "INSERT INTO c (id) VALUES (:id)",
                        [{"id": 2}, {"id": 4}]
                    )
                )
        )
1275
class VersioningTest(fixtures.MappedTest):
    """Test ``version_id_col`` optimistic-concurrency checks on a
    polymorphic joined-table hierarchy (Base / Sub)."""

    @classmethod
    def define_tables(cls, metadata):
        Table('base', metadata,
            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('version_id', Integer, nullable=False),
            Column('value', String(40)),
            Column('discriminator', Integer, nullable=False)
        )
        # type None: the column type is taken from the ForeignKey target
        Table('subtable', metadata,
            Column('id', None, ForeignKey('base.id'), primary_key=True),
            Column('subdata', String(50))
            )
        Table('stuff', metadata,
            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('parent', Integer, ForeignKey('base.id'))
            )

    @testing.emits_warning(r".*updated rowcount")
    @engines.close_open_connections
    def test_save_update(self):
        subtable, base, stuff = (self.tables.subtable,
                                self.tables.base,
                                self.tables.stuff)

        class Base(fixtures.BasicEntity):
            pass
        class Sub(Base):
            pass
        # NOTE(review): Stuff subclasses Base only in Python; it is mapped
        # standalone to "stuff", not with inherits=.
        class Stuff(Base):
            pass
        mapper(Stuff, stuff)
        mapper(Base, base,
                    polymorphic_on=base.c.discriminator,
                    version_id_col=base.c.version_id,
                    polymorphic_identity=1, properties={
            'stuff':relationship(Stuff)
        })
        mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)

        sess = create_session()

        b1 = Base(value='b1')
        s1 = Sub(value='sub1', subdata='some subdata')
        sess.add(b1)
        sess.add(s1)

        sess.flush()

        # sess2 loads the same Sub row; both sessions then modify it
        sess2 = create_session()
        s2 = sess2.query(Base).get(s1.id)
        s2.subdata = 'sess2 subdata'

        s1.subdata = 'sess1 subdata'

        # sess wins: its flush bumps version_id, leaving s2 stale
        sess.flush()

        # re-fetching s2 with a lock mode while it holds the old
        # version_id is expected to raise StaleDataError
        assert_raises(orm_exc.StaleDataError,
                        sess2.query(Base).with_lockmode('read').get,
                        s1.id)

        # dialects without reliable rowcounts cannot detect the stale
        # UPDATE, so the flush is simply allowed to go through there
        if not testing.db.dialect.supports_sane_rowcount:
            sess2.flush()
        else:
            assert_raises(orm_exc.StaleDataError, sess2.flush)

        # after a refresh, s2 carries the current version and can flush
        sess2.refresh(s2)
        if testing.db.dialect.supports_sane_rowcount:
            assert s2.subdata == 'sess1 subdata'
        s2.subdata = 'sess2 subdata'
        sess2.flush()

    @testing.emits_warning(r".*(update|delete)d rowcount")
    def test_delete(self):
        subtable, base = self.tables.subtable, self.tables.base

        class Base(fixtures.BasicEntity):
            pass
        class Sub(Base):
            pass

        mapper(Base, base,
                    polymorphic_on=base.c.discriminator,
                    version_id_col=base.c.version_id, polymorphic_identity=1)
        mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)

        sess = create_session()

        b1 = Base(value='b1')
        s1 = Sub(value='sub1', subdata='some subdata')
        s2 = Sub(value='sub2', subdata='some other subdata')
        sess.add(b1)
        sess.add(s1)
        sess.add(s2)

        sess.flush()

        # a second session deletes s1's row behind sess's back
        sess2 = create_session()
        s3 = sess2.query(Base).get(s1.id)
        sess2.delete(s3)
        sess2.flush()

        # s2 was untouched by sess2, so updating it still succeeds
        s2.subdata = 'some new subdata'
        sess.flush()

        # updating the deleted s1 must raise where rowcounts are reliable
        s1.subdata = 'some new subdata'
        if testing.db.dialect.supports_sane_rowcount:
            assert_raises(
                orm_exc.StaleDataError,
                sess.flush
            )
        else:
            sess.flush()
1389
class DistinctPKTest(fixtures.MappedTest):
    """test the construction of mapper.primary_key when an inheriting relationship
    joins on a column other than primary key column."""

    run_inserts = 'once'
    run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        global person_table, employee_table, Person, Employee

        person_table = Table("persons", metadata,
                Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
                Column("name", String(80)),
                )

        # employees has its own primary key "eid"; it refers to persons
        # only through the non-PK column person_id
        employee_table = Table("employees", metadata,
                Column("eid", Integer, primary_key=True, test_needs_autoincrement=True),
                Column("salary", Integer),
                Column("person_id", Integer, ForeignKey("persons.id")),
                )

        class Person(object):
            def __init__(self, name):
                self.name = name

        class Employee(Person): pass

    @classmethod
    def insert_data(cls):
        person_insert = person_table.insert()
        person_insert.execute(id=1, name='alice')
        person_insert.execute(id=2, name='bob')

        # the employees PK column is "eid" (the table has no "id" column).
        # These values give the composite (persons.id, employees.eid)
        # identities (1, 2) for alice and (2, 3) for bob used by _do_test().
        employee_insert = employee_table.insert()
        employee_insert.execute(eid=2, salary=250, person_id=1) # alice
        employee_insert.execute(eid=3, salary=200, person_id=2) # bob

    def test_implicit(self):
        person_mapper = mapper(Person, person_table)
        mapper(Employee, employee_table, inherits=person_mapper)
        assert list(class_mapper(Employee).primary_key) == [person_table.c.id]

    def test_explicit_props(self):
        person_mapper = mapper(Person, person_table)
        mapper(Employee, employee_table, inherits=person_mapper,
                        properties={'pid':person_table.c.id,
                                    'eid':employee_table.c.eid})
        self._do_test(False)

    def test_explicit_composite_pk(self):
        person_mapper = mapper(Person, person_table)
        mapper(Employee, employee_table,
                    inherits=person_mapper,
                    properties=dict(id=[employee_table.c.eid, person_table.c.id]),
                    primary_key=[person_table.c.id, employee_table.c.eid])
        assert_raises_message(sa_exc.SAWarning,
                                    r"On mapper Mapper\|Employee\|employees, "
                                    "primary key column 'persons.id' is being "
                                    "combined with distinct primary key column 'employees.eid' "
                                    "in attribute 'id'.  Use explicit properties to give "
                                    "each column its own mapped attribute name.",
            self._do_test, True
        )

    def test_explicit_pk(self):
        person_mapper = mapper(Person, person_table)
        mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id])
        self._do_test(False)

    def _do_test(self, composite):
        """Load alice and bob through Query.get() and check their names.

        :param composite: if True, identities are ``[persons.id,
          employees.eid]`` pairs; otherwise the single persons.id value.
        """
        session = create_session()
        query = session.query(Employee)

        if composite:
            alice1 = query.get([1, 2])
            bob = query.get([2, 3])
            alice2 = query.get([1, 2])
        else:
            alice1 = query.get(1)
            bob = query.get(2)
            alice2 = query.get(1)

        # verify the loaded rows on both code paths (previously these
        # checks only ran in the non-composite branch)
        assert alice1.name == alice2.name == 'alice'
        assert bob.name == 'bob'
1475
class SyncCompileTest(fixtures.MappedTest):
    """test that syncrules compile properly on custom inherit conds"""

    @classmethod
    def define_tables(cls, metadata):
        global _a_table, _b_table, _c_table

        _a_table = Table('a', metadata,
           Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
           Column('data1', String(128))
        )

        _b_table = Table('b', metadata,
           Column('a_id', Integer, ForeignKey('a.id'), primary_key=True),
           Column('data2', String(128))
        )

        # NOTE: a previous comment recorded that naming this column "a_id"
        # (the same name as b.a_id) works; "b_a_id" is the variant under test.
        _c_table = Table('c', metadata,
           Column('b_a_id', Integer, ForeignKey('b.a_id'), primary_key=True),
           Column('data3', String(128))
        )

    def test_joins(self):
        # each inherit condition is tried as None (mapper-derived) and
        # written explicitly with its operands in either order
        for j1 in (None, _b_table.c.a_id == _a_table.c.id,
                   _a_table.c.id == _b_table.c.a_id):
            for j2 in (None, _b_table.c.a_id == _c_table.c.b_a_id,
                       _c_table.c.b_a_id == _b_table.c.a_id):
                self._do_test(j1, j2)
                # empty every table, dependents first, before the next combo
                for t in reversed(_a_table.metadata.sorted_tables):
                    t.delete().execute().close()

    def _do_test(self, j1, j2):
        """Map fresh A -> B -> C classes using inherit conditions j1 / j2
        (None lets the mapper derive them), persist one object of each
        class and check the per-class row counts."""
        class A(object):
            def __init__(self, **kwargs):
                for key, value in kwargs.items():
                    setattr(self, key, value)

        class B(A):
            pass

        class C(B):
            pass

        mapper(A, _a_table)
        mapper(B, _b_table, inherits=A,
               inherit_condition=j1
               )
        mapper(C, _c_table, inherits=B,
               inherit_condition=j2
               )

        session = create_session()

        a = A(data1='a1')
        session.add(a)

        b = B(data1='b1', data2='b2')
        session.add(b)

        c = C(data1='c1', data2='c2', data3='c3')
        session.add(c)

        session.flush()
        session.expunge_all()

        assert len(session.query(A).all()) == 3
        assert len(session.query(B).all()) == 2
        assert len(session.query(C).all()) == 1
1544
class OverrideColKeyTest(fixtures.MappedTest):
    """test overriding of column attributes, and the interplay of class
    descriptors with inherited column attributes, between a base mapper
    and a joined-table subclass mapper."""

    @classmethod
    def define_tables(cls, metadata):
        global base, subtable, subtable_two

        base = Table('base', metadata,
            Column('base_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('data', String(255)),
            Column('sqlite_fixer', String(10))
            )

        subtable = Table('subtable', metadata,
            Column('base_id', Integer, ForeignKey('base.base_id'), primary_key=True),
            Column('subdata', String(255))
        )
        # subtable_two's primary key "base_id" is NOT the foreign key to
        # base; the FK lives in "fk_base_id" (see test_pk_fk_different)
        subtable_two = Table('subtable_two', metadata,
            Column('base_id', Integer, primary_key=True),
            Column('fk_base_id', Integer, ForeignKey('base.base_id')),
            Column('subdata', String(255))
        )


    def test_plain(self):
        # control case
        class Base(object):
            pass
        class Sub(Base):
            pass

        mapper(Base, base)
        mapper(Sub, subtable, inherits=Base)

        # Sub gets a "base_id" property using the "base_id"
        # column of both tables.
        eq_(
            class_mapper(Sub).get_property('base_id').columns,
            [subtable.c.base_id, base.c.base_id]
        )

    def test_override_explicit(self):
        # this pattern is what you see when using declarative
        # in particular, here we do a "manual" version of
        # what we'd like the mapper to do.

        class Base(object):
            pass
        class Sub(Base):
            pass

        mapper(Base, base, properties={
            'id':base.c.base_id
        })
        mapper(Sub, subtable, inherits=Base, properties={
            # this is the manual way to do it, is not really
            # possible in declarative
            'id':[base.c.base_id, subtable.c.base_id]
        })

        eq_(
            class_mapper(Sub).get_property('id').columns,
            [base.c.base_id, subtable.c.base_id]
        )

        s1 = Sub()
        s1.id = 10
        sess = create_session()
        sess.add(s1)
        sess.flush()
        assert sess.query(Sub).get(10) is s1

    def test_override_onlyinparent(self):
        # only the parent renames base.base_id to "id"; the subclass
        # column subtable.base_id keeps the attribute name "base_id"
        class Base(object):
            pass
        class Sub(Base):
            pass

        mapper(Base, base, properties={
            'id':base.c.base_id
        })
        mapper(Sub, subtable, inherits=Base)

        eq_(
            class_mapper(Sub).get_property('id').columns,
            [base.c.base_id]
        )

        eq_(
            class_mapper(Sub).get_property('base_id').columns,
            [subtable.c.base_id]
        )

        s1 = Sub()
        s1.id = 10

        s2 = Sub()
        s2.base_id = 15

        sess = create_session()
        sess.add_all([s1, s2])
        sess.flush()

        # s1 gets '10'
        assert sess.query(Sub).get(10) is s1

        # s2 gets a new id, base_id is overwritten by the ultimate
        # PK col
        assert s2.id == s2.base_id != 15

    def test_override_implicit(self):
        # this is originally [ticket:1111].
        # the pattern here is now disallowed by [ticket:1892]

        class Base(object):
            pass
        class Sub(Base):
            pass

        mapper(Base, base, properties={
            'id':base.c.base_id
        })

        def go():
            mapper(Sub, subtable, inherits=Base, properties={
                'id':subtable.c.base_id
            })
        # Sub mapper compilation needs to detect that "base.c.base_id"
        # is renamed in the inherited mapper as "id", even though
        # it has its own "id" property.  It then generates
        # an exception in 0.7 due to the implicit conflict.
        assert_raises(sa_exc.InvalidRequestError, go)

    def test_pk_fk_different(self):
        # subtable_two.base_id shares a name with base.base_id but is not
        # the FK column, so combining them implicitly must warn
        class Base(object):
            pass
        class Sub(Base):
            pass

        mapper(Base, base)

        def go():
            mapper(Sub, subtable_two, inherits=Base)
        assert_raises_message(
            sa_exc.SAWarning,
            "Implicitly combining column base.base_id with "
            "column subtable_two.base_id under attribute 'base_id'",
            go
        )

    def test_plain_descriptor(self):
        """test that descriptors prevent inheritance from propagating properties to subclasses."""

        class Base(object):
            pass
        class Sub(Base):
            @property
            def data(self):
                return "im the data"

        mapper(Base, base)
        mapper(Sub, subtable, inherits=Base)

        s1 = Sub()
        sess = create_session()
        sess.add(s1)
        sess.flush()
        assert sess.query(Sub).one().data == "im the data"

    def test_custom_descriptor(self):
        """test that descriptors prevent inheritance from propagating properties to subclasses."""

        # a hand-written (non-property) descriptor; returns itself when
        # accessed on the class
        class MyDesc(object):
            def __get__(self, instance, owner):
                if instance is None:
                    return self
                return "im the data"

        class Base(object):
            pass
        class Sub(Base):
            data = MyDesc()

        mapper(Base, base)
        mapper(Sub, subtable, inherits=Base)

        s1 = Sub()
        sess = create_session()
        sess.add(s1)
        sess.flush()
        assert sess.query(Sub).one().data == "im the data"

    def test_sub_columns_over_base_descriptors(self):
        # a subclass column ("subtable.subdata") takes precedence over a
        # same-named descriptor defined on the base class
        class Base(object):
            @property
            def subdata(self):
                return "this is base"

        class Sub(Base):
            pass

        mapper(Base, base)
        mapper(Sub, subtable, inherits=Base)

        sess = create_session()
        b1 = Base()
        assert b1.subdata == "this is base"
        s1 = Sub()
        s1.subdata = "this is sub"
        assert s1.subdata == "this is sub"

        sess.add_all([s1, b1])
        sess.flush()
        sess.expunge_all()

        assert sess.query(Base).get(b1.base_id).subdata == "this is base"
        assert sess.query(Sub).get(s1.base_id).subdata == "this is sub"

    def test_base_descriptors_over_base_cols(self):
        # a descriptor on the base class shadows the base table's "data"
        # column for both Base and Sub
        class Base(object):
            @property
            def data(self):
                return "this is base"

        class Sub(Base):
            pass

        mapper(Base, base)
        mapper(Sub, subtable, inherits=Base)

        sess = create_session()
        b1 = Base()
        assert b1.data == "this is base"
        s1 = Sub()
        assert s1.data == "this is base"

        sess.add_all([s1, b1])
        sess.flush()
        sess.expunge_all()

        assert sess.query(Base).get(b1.base_id).data == "this is base"
        assert sess.query(Sub).get(s1.base_id).data == "this is base"
1787
class OptimizedLoadTest(fixtures.MappedTest):
    """tests for the "optimized load" routine.

    For a joined-inheritance instance with unloaded attributes, the mapper
    may emit a reduced SELECT covering only the needed columns (see
    ``Mapper._optimized_get_statement``).  These tests check when that
    statement is and is not used, and that loading falls back to the
    regular full-row load when it cannot be generated.
    """

    @classmethod
    def define_tables(cls, metadata):
        Table('base', metadata,
            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('data', String(50)),
            Column('type', String(50)),
            Column('counter', Integer, server_default="1")
        )
        Table('sub', metadata,
            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
            Column('sub', String(50)),
            Column('subcounter', Integer, server_default="1"),
            Column('subcounter2', Integer, server_default="1")
        )
        Table('subsub', metadata,
            Column('id', Integer, ForeignKey('sub.id'), primary_key=True),
            Column('subsubcounter2', Integer, server_default="1")
        )
        Table('with_comp', metadata,
            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
            Column('a', String(10)),
            Column('b', String(10))
        )

    def test_no_optimize_on_map_to_join(self):
        """A class mapped to a join (not discrete tables) must reload
        expired attributes with the full joined SELECT."""
        base, sub = self.tables.base, self.tables.sub

        class Base(fixtures.ComparableEntity):
            pass

        class JoinBase(fixtures.ComparableEntity):
            pass

        class SubJoinBase(JoinBase):
            pass

        mapper(Base, base)
        mapper(JoinBase, base.outerjoin(sub), properties=util.OrderedDict(
                [('id', [base.c.id, sub.c.id]),
                ('counter', [base.c.counter, sub.c.subcounter])])
            )
        mapper(SubJoinBase, inherits=JoinBase)

        sess = Session()
        sess.add(Base(data='data'))
        sess.commit()

        sjb = sess.query(SubJoinBase).one()
        sjb_id = sjb.id
        sess.expire(sjb)

        # this should not use the optimized load,
        # which assumes discrete tables
        def go():
            eq_(sjb.data, 'data')

        self.assert_sql_execution(
            testing.db,
            go,
            CompiledSQL(
                "SELECT base.id AS base_id, sub.id AS sub_id, "
                "base.counter AS base_counter, sub.subcounter AS sub_subcounter, "
                "base.data AS base_data, base.type AS base_type, "
                "sub.sub AS sub_sub, sub.subcounter2 AS sub_subcounter2 "
                "FROM base LEFT OUTER JOIN sub ON base.id = sub.id "
                "WHERE base.id = :param_1",
                {'param_1': sjb_id}
            ),
        )

    def test_optimized_passes(self):
        """test that the 'optimized load' routine doesn't crash when
        a column in the join condition is not available."""

        base, sub = self.tables.base, self.tables.sub

        class Base(fixtures.ComparableEntity):
            pass

        class Sub(Base):
            pass

        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')

        # redefine Sub's "id" to favor the "id" col in the subtable.
        # "id" is also part of the primary join condition
        mapper(Sub, sub, inherits=Base,
                        polymorphic_identity='sub',
                        properties={'id':[sub.c.id, base.c.id]})
        sess = sessionmaker()()
        s1 = Sub(data='s1data', sub='s1sub')
        sess.add(s1)
        sess.commit()
        sess.expunge_all()

        # load s1 via Base.  s1.id won't populate since it's relative to
        # the "sub" table.  The optimized load kicks in and tries to
        # generate on the primary join, but cannot since "id" is itself unloaded.
        # the optimized load needs to return "None" so regular full-row loading proceeds
        s1 = sess.query(Base).first()
        assert s1.sub == 's1sub'

    def test_column_expression(self):
        """A column_property built only from subclass-table columns loads
        when the instance was queried through the base class."""
        base, sub = self.tables.base, self.tables.sub

        class Base(fixtures.ComparableEntity):
            pass

        class Sub(Base):
            pass

        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={
            'concat': column_property(sub.c.sub + "|" + sub.c.sub)
        })
        sess = sessionmaker()()
        s1 = Sub(data='s1data', sub='s1sub')
        sess.add(s1)
        sess.commit()
        sess.expunge_all()
        s1 = sess.query(Base).first()
        assert s1.concat == 's1sub|s1sub'

    def test_column_expression_joined(self):
        """A column_property mixing base- and subclass-table columns loads
        correctly for several rows queried through the base class."""
        base, sub = self.tables.base, self.tables.sub

        class Base(fixtures.ComparableEntity):
            pass

        class Sub(Base):
            pass

        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={
            'concat': column_property(base.c.data + "|" + sub.c.sub)
        })
        sess = sessionmaker()()
        s1 = Sub(data='s1data', sub='s1sub')
        s2 = Sub(data='s2data', sub='s2sub')
        s3 = Sub(data='s3data', sub='s3sub')
        sess.add_all([s1, s2, s3])
        sess.commit()
        sess.expunge_all()
        # query a bunch of rows to ensure there's no cartesian
        # product against "base" occurring, it is in fact
        # detecting that "base" needs to be in the join
        # criterion
        eq_(
            sess.query(Base).order_by(Base.id).all(),
            [
                Sub(data='s1data', sub='s1sub', concat='s1data|s1sub'),
                Sub(data='s2data', sub='s2sub', concat='s2data|s2sub'),
                Sub(data='s3data', sub='s3sub', concat='s3data|s3sub')
            ]
        )

    def test_composite_column_joined(self):
        """A composite() mapped on subclass-table columns loads for
        instances queried through the base class."""
        base, with_comp = self.tables.base, self.tables.with_comp

        class Base(fixtures.BasicEntity):
            pass

        class WithComp(Base):
            pass

        class Comp(object):
            def __init__(self, a, b):
                self.a = a
                self.b = b

            def __composite_values__(self):
                return self.a, self.b

            def __eq__(self, other):
                # only another Comp can be equal; comparing against None
                # or an unrelated type returns False instead of raising
                # AttributeError on other.a / other.b
                return isinstance(other, Comp) and \
                    self.a == other.a and self.b == other.b

            def __ne__(self, other):
                # Python 2 does not derive __ne__ from __eq__
                return not self.__eq__(other)

        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
        mapper(WithComp, with_comp, inherits=Base, polymorphic_identity='wc', properties={
            'comp': composite(Comp, with_comp.c.a, with_comp.c.b)
        })
        sess = sessionmaker()()
        s1 = WithComp(data='s1data', comp=Comp('ham', 'cheese'))
        s2 = WithComp(data='s2data', comp=Comp('bacon', 'eggs'))
        sess.add_all([s1, s2])
        sess.commit()
        sess.expunge_all()
        s1test, s2test = sess.query(Base).order_by(Base.id).all()
        assert s1test.comp
        assert s2test.comp
        eq_(s1test.comp, Comp('ham', 'cheese'))
        eq_(s2test.comp, Comp('bacon', 'eggs'))

    def test_load_expired_on_pending(self):
        """After flushing a pending Sub, a server-default column is fetched
        with a single SELECT over the base/sub join."""
        base, sub = self.tables.base, self.tables.sub

        class Base(fixtures.BasicEntity):
            pass

        class Sub(Base):
            pass

        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub')
        sess = Session()
        s1 = Sub(data='s1')
        sess.add(s1)
        self.assert_sql_execution(
                testing.db,
                sess.flush,
                CompiledSQL(
                    "INSERT INTO base (data, type) VALUES (:data, :type)",
                    [{'data':'s1','type':'sub'}]
                ),
                CompiledSQL(
                    "INSERT INTO sub (id, sub) VALUES (:id, :sub)",
                    lambda ctx:{'id':s1.id, 'sub':None}
                ),
        )

        def go():
            eq_(s1.subcounter2, 1)

        self.assert_sql_execution(
            testing.db,
            go,
            CompiledSQL(
                "SELECT base.counter AS base_counter, sub.subcounter AS sub_subcounter, "
                "sub.subcounter2 AS sub_subcounter2 FROM base JOIN sub "
                "ON base.id = sub.id WHERE base.id = :param_1",
                lambda ctx:{'param_1': s1.id}
            ),
        )

    def test_dont_generate_on_none(self):
        """_optimized_get_statement() returns None while the identifying
        column is unloaded or None, and a statement once it is committed."""
        base, sub = self.tables.base, self.tables.sub

        class Base(fixtures.BasicEntity):
            pass

        class Sub(Base):
            pass

        mapper(Base, base, polymorphic_on=base.c.type,
                            polymorphic_identity='base')
        m = mapper(Sub, sub, inherits=Base, polymorphic_identity='sub')

        s1 = Sub()
        assert m._optimized_get_statement(attributes.instance_state(s1),
                                ['subcounter2']) is None

        # loads s1.id as None
        eq_(s1.id, None)

        # this now will come up with a value of None for id - should reject
        assert m._optimized_get_statement(attributes.instance_state(s1),
                                ['subcounter2']) is None

        s1.id = 1
        attributes.instance_state(s1)._commit_all(s1.__dict__, None)
        assert m._optimized_get_statement(attributes.instance_state(s1),
                                ['subcounter2']) is not None

    def test_load_expired_on_pending_twolevel(self):
        """With two levels of joined inheritance, the post-flush load of
        server defaults only touches the sub and subsub tables."""
        base, sub, subsub = (self.tables.base,
                                self.tables.sub,
                                self.tables.subsub)

        class Base(fixtures.BasicEntity):
            pass

        class Sub(Base):
            pass

        class SubSub(Sub):
            pass

        mapper(Base, base, polymorphic_on=base.c.type,
                    polymorphic_identity='base')
        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub')
        mapper(SubSub, subsub, inherits=Sub, polymorphic_identity='subsub')
        sess = Session()
        s1 = SubSub(data='s1', counter=1, subcounter=2)
        sess.add(s1)
        self.assert_sql_execution(
                testing.db,
                sess.flush,
                CompiledSQL(
                    "INSERT INTO base (data, type, counter) VALUES "
                    "(:data, :type, :counter)",
                    [{'data':'s1','type':'subsub','counter':1}]
                ),
                CompiledSQL(
                    "INSERT INTO sub (id, sub, subcounter) VALUES "
                    "(:id, :sub, :subcounter)",
                    lambda ctx:[{'subcounter': 2, 'sub': None, 'id': s1.id}]
                ),
                CompiledSQL(
                    "INSERT INTO subsub (id) VALUES (:id)",
                    lambda ctx:{'id':s1.id}
                ),
        )

        def go():
            eq_(
                s1.subcounter2, 1
            )

        # the FROM order of the two tables is not deterministic, so either
        # rendering is accepted
        self.assert_sql_execution(
            testing.db,
            go,
            Or(
                CompiledSQL(
                    "SELECT subsub.subsubcounter2 AS subsub_subsubcounter2, "
                    "sub.subcounter2 AS sub_subcounter2 FROM subsub, sub "
                    "WHERE :param_1 = sub.id AND sub.id = subsub.id",
                    lambda ctx: {'param_1': s1.id}
                ),
                CompiledSQL(
                    "SELECT sub.subcounter2 AS sub_subcounter2, "
                    "subsub.subsubcounter2 AS subsub_subsubcounter2 "
                    "FROM sub, subsub "
                    "WHERE :param_1 = sub.id AND sub.id = subsub.id",
                    lambda ctx: {'param_1': s1.id}
                ),
            )
        )
2099
class TransientInheritingGCTest(fixtures.TestBase):
    """Check that an ad-hoc subclass of a declaratively mapped class is
    garbage collected once every reference to it (class, instances,
    session) has been dropped."""
    # restricted to cpython without coverage, presumably because the
    # refcount/GC behavior asserted here is only reliable there
    __requires__ = ('cpython', 'no_coverage')

    def _fixture(self):
        """Create a fresh declarative Base with one mapped class ``A``
        (table 'a'); store ``A`` on ``self.A`` and return the Base."""
        Base = declarative_base()

        class A(Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True,
                                    test_needs_autoincrement=True)
            data = Column(String(10))
        self.A = A
        return Base

    def setUp(self):
        self.Base = self._fixture()

    def tearDown(self):
        self.Base.metadata.drop_all(testing.db)
        # NOTE(review): clear_mappers() is left disabled here; the reason is
        # not visible from this file - confirm before re-enabling.
        #clear_mappers()
        self.Base = None

    def _do_test(self, go):
        """Create tables, persist and reload an instance of the subclass
        returned by ``go``, drop all local references, then assert that
        ``self.A`` has no surviving subclasses after a GC pass.

        :param go: zero-argument callable returning a new subclass of
          ``self.A``.
        """
        B = go()
        self.Base.metadata.create_all(testing.db)
        sess = Session(testing.db)
        sess.add(B(data='some b'))
        sess.commit()

        b1 = sess.query(B).one()
        assert isinstance(b1, B)
        # release the session, the loaded instance and the class itself so
        # that nothing in this frame keeps B alive
        sess.close()
        del sess
        del b1
        del B

        gc_collect()

        # type.__subclasses__() only returns subclasses that are still
        # alive, so an empty list means B was collected
        eq_(
            len(self.A.__subclasses__()),
            0)

    def test_single(self):
        # single-table inheritance: B adds no table of its own
        def go():
            class B(self.A):
                pass
            return B
        self._do_test(go)

    # expected to fail unconditionally (the predicate always returns True)
    @testing.fails_if(lambda: True,
                "not supported for joined inh right now.")
    def test_joined(self):
        # joined-table inheritance: B has its own table 'b'
        def go():
            class B(self.A):
                __tablename__ = 'b'
                id = Column(Integer, ForeignKey('a.id'),
                        primary_key=True)
            return B
        self._do_test(go)
2159
class NoPKOnSubTableWarningTest(fixtures.TestBase):
    """Mapping a subclass onto a table with no primary key columns of its
    own warns, unless ``primary_key`` is passed to the mapper explicitly."""

    def _fixture(self):
        """Return fresh (parent, child) tables; 'child' has only an FK
        column and no primary key."""
        meta = MetaData()
        parent_table = Table('parent', meta,
            Column('id', Integer, primary_key=True)
        )
        child_table = Table('child', meta,
            Column('id', Integer, ForeignKey('parent.id'))
        )
        return parent_table, child_table

    def tearDown(self):
        clear_mappers()

    def test_warning_on_sub(self):
        parent_table, child_table = self._fixture()

        class P(object):
            pass

        class C(P):
            pass

        mapper(P, parent_table)
        # asserting on SAWarning as an exception relies on the test harness
        # turning warnings into errors (presumably configured globally)
        expected = ("Could not assemble any primary keys for locally mapped "
                    "table 'child' - no rows will be persisted in this Table.")
        assert_raises_message(
            sa_exc.SAWarning,
            expected,
            mapper, C, child_table, inherits=P
        )

    def test_no_warning_with_explicit(self):
        parent_table, child_table = self._fixture()

        class P(object):
            pass

        class C(P):
            pass

        mapper(P, parent_table)
        child_mapper = mapper(C, child_table, inherits=P,
                              primary_key=[parent_table.c.id])
        eq_(child_mapper.primary_key, (parent_table.c.id,))
2202
class InhCondTest(fixtures.TestBase):
    """How the mapper derives, or fails to derive, the inherit_condition
    between a base table and a derived table from their foreign keys."""

    @staticmethod
    def _classes():
        """Return a fresh, unmapped (Base, Derived) class pair."""
        class Base(object):
            pass

        class Derived(Base):
            pass

        return Base, Derived

    def test_inh_cond_nonexistent_table_unrelated(self):
        meta = MetaData()
        base_table = Table("base", meta,
            Column("id", Integer, primary_key=True)
        )
        derived_table = Table("derived", meta,
            Column("id", Integer, ForeignKey("base.id"), primary_key=True),
            Column("owner_id", Integer, ForeignKey("owner.owner_id"))
        )
        Base, Derived = self._classes()

        mapper(Base, base_table)
        # mapping succeeds even though the "owner" table is not configured
        derived_mapper = mapper(Derived, derived_table, inherits=Base)
        expected = base_table.c.id == derived_table.c.id
        assert derived_mapper.inherit_condition.compare(expected)

    def test_inh_cond_nonexistent_col_unrelated(self):
        meta = MetaData()
        base_table = Table("base", meta,
            Column("id", Integer, primary_key=True)
        )
        derived_table = Table("derived", meta,
            Column("id", Integer, ForeignKey('base.id'),
                primary_key=True),
            Column('order_id', Integer, ForeignKey('order.foo'))
        )
        # "order" exists in the metadata but has no "foo" column
        Table('order', meta, Column('id', Integer, primary_key=True))
        Base, Derived = self._classes()

        mapper(Base, base_table)

        # mapping succeeds even though "order.foo" doesn't exist
        derived_mapper = mapper(Derived, derived_table, inherits=Base)
        expected = base_table.c.id == derived_table.c.id
        assert derived_mapper.inherit_condition.compare(expected)

    def test_inh_cond_no_fk(self):
        meta = MetaData()
        base_table = Table("base", meta,
            Column("id", Integer, primary_key=True)
        )
        derived_table = Table("derived", meta,
            Column("id", Integer, primary_key=True),
        )
        Base, Derived = self._classes()

        mapper(Base, base_table)
        assert_raises_message(
            sa_exc.ArgumentError,
            "Can't find any foreign key relationships between "
            "'base' and 'derived'.",
            mapper,
            Derived, derived_table, inherits=Base
        )

    def test_inh_cond_nonexistent_table_related(self):
        base_meta = MetaData()
        derived_meta = MetaData()
        base_table = Table("base", base_meta,
            Column("id", Integer, primary_key=True)
        )
        derived_table = Table("derived", derived_meta,
            Column("id", Integer, ForeignKey('base.id'),
                primary_key=True),
        )
        Base, Derived = self._classes()

        mapper(Base, base_table)

        # The ForeignKey definition is correct, but the two tables live in
        # different MetaData objects.  The traditional "noreferencedtable"
        # error should be raised so the user is pointed at the FK in
        # question.
        assert_raises_message(
            sa_exc.NoReferencedTableError,
            "Foreign key associated with column 'derived.id' "
            "could not find table 'base' with which to generate "
            "a foreign key to target column 'id'",
            mapper,
            Derived, derived_table, inherits=Base
        )

    def test_inh_cond_nonexistent_col_related(self):
        meta = MetaData()
        base_table = Table("base", meta,
            Column("id", Integer, primary_key=True)
        )
        derived_table = Table("derived", meta,
            Column("id", Integer, ForeignKey('base.q'),
                primary_key=True),
        )
        Base, Derived = self._classes()

        mapper(Base, base_table)

        assert_raises_message(
            sa_exc.NoReferencedColumnError,
            "Could not initialize target column for ForeignKey "
            "'base.q' on table "
            "'derived': table 'base' has no column named 'q'",
            mapper,
            Derived, derived_table, inherits=Base
        )
2335
2336
class PKDiscriminatorTest(fixtures.MappedTest):
    """A polymorphic discriminator column that is also part of the child
    table's composite primary key."""

    @classmethod
    def define_tables(cls, metadata):
        Table('parents', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('name', String(60)))

        Table('children', metadata,
              Column('id', Integer, ForeignKey('parents.id'),
                     primary_key=True),
              Column('type', Integer, primary_key=True),
              Column('name', String(60)))

    def test_pk_as_discriminator(self):
        parents = self.tables.parents
        children = self.tables.children

        class Parent(object):
            def __init__(self, name=None):
                self.name = name

        class Child(object):
            def __init__(self, name=None):
                self.name = name

        class A(Child):
            pass

        mapper(Parent, parents, properties={
            'children': relationship(Child, backref='parent'),
        })
        mapper(Child, children, polymorphic_on=children.c.type,
               polymorphic_identity=1)
        mapper(A, inherits=Child, polymorphic_identity=2)

        session = create_session()
        owner = Parent('p1')
        member = A('a1')
        owner.children.append(member)
        session.add(owner)
        session.flush()

        # both primary key columns got populated on flush: 'id' through
        # the relationship and 'type' from A's polymorphic identity
        assert member.id
        assert member.type == 2

        owner.name = 'p1new'
        member.name = 'a1new'
        session.flush()

        session.expire_all()
        assert member.name == 'a1new'
        assert owner.name == 'p1new'
2390
class NoPolyIdentInMiddleTest(fixtures.MappedTest):
    """Single-table hierarchy A -> B -> {C, D} and A -> E in which neither
    the root A nor the intermediate class B has a polymorphic_identity."""

    @classmethod
    def define_tables(cls, metadata):
        Table('base', metadata,
            Column('id', Integer, primary_key=True,
                            test_needs_autoincrement=True),
            Column('type', String(50), nullable=False),
        )

    @classmethod
    def setup_classes(cls):
        class A(cls.Comparable):
            pass

        class B(A):
            pass

        class C(B):
            pass

        class D(B):
            pass

        class E(A):
            pass

    @classmethod
    def setup_mappers(cls):
        registry = cls.classes
        base = cls.tables.base

        mapper(registry.A, base, polymorphic_on=base.c.type)
        mapper(registry.B, inherits=registry.A)
        mapper(registry.C, inherits=registry.B, polymorphic_identity='c')
        mapper(registry.D, inherits=registry.B, polymorphic_identity='d')
        mapper(registry.E, inherits=registry.A, polymorphic_identity='e')

    def test_load_from_middle(self):
        C = self.classes.C
        B = self.classes.B

        session = Session()
        session.add(C())
        loaded = session.query(B).first()
        eq_(loaded.type, 'c')
        assert isinstance(loaded, C)

    def test_load_from_base(self):
        A = self.classes.A
        C = self.classes.C

        session = Session()
        session.add(C())
        loaded = session.query(A).first()
        eq_(loaded.type, 'c')
        assert isinstance(loaded, C)

    def test_discriminator(self):
        discriminator = self.tables.base.c.type

        for cls_ in (self.classes.B, self.classes.C):
            assert class_mapper(cls_).polymorphic_on is discriminator

    def test_load_multiple_from_middle(self):
        classes = self.classes
        base = self.tables.base

        session = Session()
        session.add_all([classes.C(), classes.D(), classes.E()])
        # querying through B returns only C and D, never E
        eq_(
            session.query(classes.B).order_by(base.c.type).all(),
            [classes.C(), classes.D()]
        )
2467
class DeleteOrphanTest(fixtures.MappedTest):
    """Flushing an orphan (a child with no parent) raises an error.

    The row is rejected by the database because ``single.parent_id`` is
    NOT NULL.  Older SQLAlchemy versions checked this constraint in memory,
    which was the original reason for this test.
    """

    @classmethod
    def define_tables(cls, metadata):
        # published as module-level globals for use by the test method
        global single, parent
        single = Table('single', metadata,
            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('type', String(50), nullable=False),
            Column('data', String(50)),
            Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False),
            )

        parent = Table('parent', metadata,
                Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                Column('data', String(50))
            )

    def test_orphan_message(self):
        class Base(fixtures.BasicEntity):
            pass

        class SubClass(Base):
            pass

        class Parent(fixtures.BasicEntity):
            pass

        mapper(Base, single, polymorphic_on=single.c.type,
               polymorphic_identity='base')
        mapper(SubClass, inherits=Base, polymorphic_identity='sub')
        related = relationship(Base, cascade="all, delete-orphan")
        mapper(Parent, parent, properties={'related': related})

        session = create_session()
        orphan = SubClass(data='s1')
        session.add(orphan)
        assert_raises(sa_exc.DBAPIError, session.flush)
2512
class PolymorphicUnionTest(fixtures.TestBase, testing.AssertsCompiledSQL):
    """SQL rendered by polymorphic_union() for three tables with
    partially overlapping columns."""
    __dialect__ = 'default'

    def _fixture(self):
        """Return three lightweight tables: t1(c1, c2, c3),
        t2(c1, c2, c3, c4) and t3(c1, c3, c5)."""
        t1 = table('t1', column('c1', Integer),
                         column('c2', Integer),
                         column('c3', Integer))
        t2 = table('t2', column('c1', Integer),
                         column('c2', Integer),
                         column('c3', Integer),
                         column('c4', Integer))
        t3 = table('t3', column('c1', Integer),
                         column('c3', Integer),
                         column('c5', Integer))
        return t1, t2, t3

    def _union(self, typecolname, **kw):
        """Build polymorphic_union() over the fixture tables, mapped in
        order to the identities 'a', 'b' and 'c'."""
        t1, t2, t3 = self._fixture()
        table_map = util.OrderedDict([("a", t1), ("b", t2), ("c", t3)])
        return polymorphic_union(table_map, typecolname, **kw)

    def test_type_col_present(self):
        self.assert_compile(
            self._union('q1'),
            "SELECT t1.c1, t1.c2, t1.c3, CAST(NULL AS INTEGER) AS c4, "
            "CAST(NULL AS INTEGER) AS c5, 'a' AS q1 FROM t1 UNION ALL "
            "SELECT t2.c1, t2.c2, t2.c3, t2.c4, CAST(NULL AS INTEGER) AS c5, "
            "'b' AS q1 FROM t2 UNION ALL SELECT t3.c1, "
            "CAST(NULL AS INTEGER) AS c2, t3.c3, CAST(NULL AS INTEGER) AS c4, "
            "t3.c5, 'c' AS q1 FROM t3"
        )

    def test_type_col_non_present(self):
        self.assert_compile(
            self._union(None),
            "SELECT t1.c1, t1.c2, t1.c3, CAST(NULL AS INTEGER) AS c4, "
            "CAST(NULL AS INTEGER) AS c5 FROM t1 UNION ALL SELECT t2.c1, "
            "t2.c2, t2.c3, t2.c4, CAST(NULL AS INTEGER) AS c5 FROM t2 "
            "UNION ALL SELECT t3.c1, CAST(NULL AS INTEGER) AS c2, t3.c3, "
            "CAST(NULL AS INTEGER) AS c4, t3.c5 FROM t3"
        )

    def test_no_cast_null(self):
        self.assert_compile(
            self._union('q1', cast_nulls=False),
            "SELECT t1.c1, t1.c2, t1.c3, NULL AS c4, NULL AS c5, 'a' AS q1 "
            "FROM t1 UNION ALL SELECT t2.c1, t2.c2, t2.c3, t2.c4, NULL AS c5, "
            "'b' AS q1 FROM t2 UNION ALL SELECT t3.c1, NULL AS c2, t3.c3, "
            "NULL AS c4, t3.c5, 'c' AS q1 FROM t3"
        )
2569
2570
class NameConflictTest(fixtures.MappedTest):
    """The subclass table 'foo' has a column named 'content_type'.  That is
    the same name that <table>_<column> labeling produces for the base
    table's 'content.type' discriminator.  The subclass column must still
    round-trip correctly."""

    @classmethod
    def define_tables(cls, metadata):
        Table('content', metadata,
            Column('id', Integer, primary_key=True,
                    test_needs_autoincrement=True),
            Column('type', String(30))
        )
        Table('foo', metadata,
            Column('id', Integer, ForeignKey('content.id'),
                        primary_key=True),
            Column('content_type', String(30))
        )

    def test_name_conflict(self):
        content_table = self.tables.content
        foo_table = self.tables.foo

        class Content(object):
            pass

        class Foo(Content):
            pass

        mapper(Content, content_table,
               polymorphic_on=content_table.c.type)
        mapper(Foo, foo_table, inherits=Content,
               polymorphic_identity='foo')

        session = create_session()
        foo_obj = Foo()
        foo_obj.content_type = 'bar'
        session.add(foo_obj)
        session.flush()
        foo_id = foo_obj.id
        session.expunge_all()

        reloaded = session.query(Content).get(foo_id)
        assert reloaded.content_type == 'bar'
2602