1"""Tests cyclical mapper relationships.
2
We might want to try an automated generation of much of this: all combos of
T1<->T2, with o2m or m2o between them, and a third T3 with o2m/m2o to one/both
T1/T2.
6
7"""
8from sqlalchemy import testing
9from sqlalchemy import Integer, String, ForeignKey
10from sqlalchemy.testing.schema import Table, Column
11from sqlalchemy.orm import mapper, relationship, backref, \
12    create_session, sessionmaker
13from sqlalchemy.testing import eq_, is_
14from sqlalchemy.testing.assertsql import RegexSQL, CompiledSQL, AllOf
15from sqlalchemy.testing import fixtures
16
17
class SelfReferentialTest(fixtures.MappedTest):
    """Self-referential C1 mapper that additionally owns a list of C2s."""

    @classmethod
    def define_tables(cls, metadata):
        # t1 refers to itself through parent_c1; t2 rows point at t1.
        Table(
            't1', metadata,
            Column('c1', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('parent_c1', Integer, ForeignKey('t1.c1')),
            Column('data', String(20)),
        )
        Table(
            't2', metadata,
            Column('c1', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('c1id', Integer, ForeignKey('t1.c1')),
            Column('data', String(20)),
        )

    @classmethod
    def setup_classes(cls):
        class C1(cls.Basic):
            def __init__(self, data=None):
                self.data = data

        class C2(cls.Basic):
            def __init__(self, data=None):
                self.data = data

    def test_single(self):
        """Flush, then delete, a head C1 that cascades to one child C1."""
        C1 = self.classes.C1
        t1 = self.tables.t1

        mapper(C1, t1, properties=dict(
            c1s=relationship(C1, cascade="all"),
            parent=relationship(
                C1,
                primaryjoin=t1.c.parent_c1 == t1.c.c1,
                remote_side=t1.c.c1,
                lazy='select',
                uselist=False)))

        head = C1('head c1')
        head.c1s.append(C1('another c1'))

        session = create_session()
        session.add(head)
        session.flush()
        session.delete(head)
        session.flush()

    def test_many_to_one_only(self):
        """The circular dependency sort must assemble a many-to-one
        dependency processor even when only the object on the "many" side
        is in the list of modified objects.
        """
        C1 = self.classes.C1
        t1 = self.tables.t1

        mapper(C1, t1, properties=dict(
            parent=relationship(
                C1,
                primaryjoin=t1.c.parent_c1 == t1.c.c1,
                remote_side=t1.c.c1)))

        existing = C1()

        session = create_session()
        session.add(existing)
        session.flush()
        session.expunge_all()

        # reload the parent so it is persistent and unmodified; only the
        # new child is pending in this flush
        existing = session.query(C1).get(existing.c1)
        dependent = C1()
        dependent.parent = existing
        session.add(dependent)
        session.flush()
        assert dependent.parent_c1 == existing.c1

    def test_cycle(self):
        """Flush and delete a C1 tree that also holds delete-orphan C2s."""
        C1, C2 = self.classes.C1, self.classes.C2
        t1, t2 = self.tables.t1, self.tables.t2

        mapper(C1, t1, properties=dict(
            c1s=relationship(C1, cascade="all"),
            c2s=relationship(mapper(C2, t2),
                             cascade="all, delete-orphan")))

        head = C1('head c1')
        first_child = C1('child1')
        head.c1s.append(first_child)
        second_child = C1('child2')
        head.c1s.append(second_child)
        for label in ('subchild1', 'subchild2'):
            first_child.c1s.append(C1(label))
        for label in ('child2 data1', 'child2 data2'):
            second_child.c2s.append(C2(label))

        session = create_session()
        session.add(head)
        session.flush()

        session.delete(head)
        session.flush()

    def test_setnull_ondelete(self):
        """Deleting the parent sets the child's parent_c1 to NULL."""
        C1 = self.classes.C1
        t1 = self.tables.t1

        mapper(C1, t1, properties=dict(children=relationship(C1)))

        session = create_session()
        parent, child = C1(), C1()
        parent.children.append(child)
        session.add(parent)
        session.flush()
        assert child.parent_c1 == parent.c1

        session.delete(parent)
        session.flush()
        assert child.parent_c1 is None

        # the NULL must also be what comes back from the database
        session.expire_all()
        assert child.parent_c1 is None
137
138
class SelfReferentialNoPKTest(fixtures.MappedTest):
    """Self-referential relationship that joins on a column other than the
    primary key column (``uuid`` / ``parent_uuid`` rather than ``id``)."""

    @classmethod
    def define_tables(cls, metadata):
        Table(
            'item', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('uuid', String(32), unique=True, nullable=False),
            Column('parent_uuid', String(32), ForeignKey('item.uuid'),
                   nullable=True),
        )

    @classmethod
    def setup_classes(cls):
        class TT(cls.Basic):
            def __init__(self):
                # per-instance identifier derived from the object's id()
                self.uuid = hex(id(self))

    @classmethod
    def setup_mappers(cls):
        TT = cls.classes.TT
        item = cls.tables.item

        to_parent = backref('parent', remote_side=[item.c.uuid])
        mapper(TT, item, properties=dict(
            children=relationship(
                TT,
                remote_side=[item.c.parent_uuid],
                backref=to_parent)))

    def test_basic(self):
        """Children persisted via the collection carry the parent's uuid."""
        TT = self.classes.TT

        root = TT()
        for _ in range(2):
            root.children.append(TT())

        session = create_session()
        session.add(root)
        session.flush()
        session.expunge_all()

        loaded = session.query(TT).filter_by(id=root.id).one()
        eq_(loaded.children[0].parent_uuid, root.uuid)

    def test_lazy_clause(self):
        """Lazy-loading the 'parent' backref joins on uuid."""
        TT = self.classes.TT

        session = create_session()
        parent, child = TT(), TT()
        parent.children.append(child)
        session.add(parent)
        session.flush()
        session.expunge_all()

        loaded = session.query(TT).filter_by(id=child.id).one()
        eq_(loaded.uuid, child.uuid)
        eq_(loaded.parent.uuid, parent.uuid)
196
197
class InheritTestOne(fixtures.MappedTest):
    """Many-to-one from one joined-inheritance subclass (Child2) to a
    sibling subclass (Child1), through a NOT NULL foreign key."""

    @classmethod
    def define_tables(cls, metadata):
        Table("parent", metadata,
              Column("id", Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column("parent_data", String(50)),
              Column("type", String(10)))

        Table("child1", metadata,
              Column("id", Integer, ForeignKey("parent.id"), primary_key=True),
              Column("child1_data", String(50)))

        Table("child2", metadata,
              Column("id", Integer, ForeignKey("parent.id"), primary_key=True),
              # NOT NULL: the child2 INSERT fails unless child1_id has been
              # populated from the related Child1 first
              Column("child1_id", Integer, ForeignKey("child1.id"),
                     nullable=False),
              Column("child2_data", String(50)))

    @classmethod
    def setup_classes(cls):
        class Parent(cls.Basic):
            pass

        class Child1(Parent):
            pass

        class Child2(Parent):
            pass

    @classmethod
    def setup_mappers(cls):
        Parent, Child1, Child2 = (cls.classes.Parent,
                                  cls.classes.Child1,
                                  cls.classes.Child2)
        parent, child1, child2 = (cls.tables.parent,
                                  cls.tables.child1,
                                  cls.tables.child2)

        mapper(Parent, parent)
        mapper(Child1, child1, inherits=Parent)
        mapper(Child2, child2, inherits=Parent, properties=dict(
            child1=relationship(
                Child1,
                primaryjoin=child2.c.child1_id == child1.c.id)))

    def test_many_to_one_only(self):
        """Test similar to SelfReferentialTest.test_many_to_one_only.

        Only the new Child2 on the "many" side is pending; the Child1 it
        refers to is loaded from the database and left unmodified.
        """

        Child1, Child2 = self.classes.Child1, self.classes.Child2

        session = create_session()

        c1 = Child1()
        c1.child1_data = "qwerty"
        session.add(c1)
        session.flush()
        session.expunge_all()

        c1 = session.query(Child1).filter_by(child1_data="qwerty").one()
        c2 = Child2()
        c2.child1 = c1
        c2.child2_data = "asdfgh"
        session.add(c2)

        # the flush will fail if the UOW does not set up a many-to-one DP
        # attached to a task corresponding to c1, since "child1_id" is not
        # nullable
        session.flush()
267
268
class InheritTestTwo(fixtures.MappedTest):
    """Circular sort containing still-polymorphic UOWTasks.

    The fix in BiDirectionalManyToOneTest raised this issue: the
    'circular sort' could contain UOWTasks that were still polymorphic,
    which could create duplicate entries in the final sort.
    """

    @classmethod
    def define_tables(cls, metadata):
        # a.cid -> c.id and c.aid -> a.id form a foreign key cycle;
        # b joins to a via joined-table inheritance.
        Table(
            'a', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('cid', Integer, ForeignKey('c.id')),
        )
        Table(
            'b', metadata,
            Column('id', Integer, ForeignKey("a.id"), primary_key=True),
        )
        Table(
            'c', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('aid', Integer, ForeignKey('a.id', name="foo")),
        )

    @classmethod
    def setup_classes(cls):
        class A(cls.Basic):
            pass

        class B(A):
            pass

        class C(cls.Basic):
            pass

    def test_flush(self):
        """A B (subclass of A) and an unrelated C flush together."""
        A, B, C = self.classes.A, self.classes.B, self.classes.C
        a, b, c = self.tables.a, self.tables.b, self.tables.c

        mapper(A, a, properties=dict(
            cs=relationship(C, primaryjoin=a.c.cid == c.c.id)))
        mapper(B, b, inherits=A, inherit_condition=b.c.id == a.c.id)
        mapper(C, c, properties=dict(
            arel=relationship(A, primaryjoin=a.c.id == c.c.aid)))

        session = create_session()
        b_obj = B()
        session.add(b_obj)
        c_obj = C()
        session.add(c_obj)
        session.flush()
327
328
class BiDirectionalManyToOneTest(fixtures.MappedTest):
    """T1 and T2 each hold a many-to-one to the other; T3 holds NOT NULL
    many-to-ones to both."""

    run_define_tables = 'each'

    @classmethod
    def define_tables(cls, metadata):
        Table(
            't1', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('data', String(30)),
            Column('t2id', Integer, ForeignKey('t2.id')),
        )
        Table(
            't2', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('data', String(30)),
            Column('t1id', Integer, ForeignKey('t1.id', name="foo_fk")),
        )
        Table(
            't3', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('data', String(30)),
            Column('t1id', Integer, ForeignKey('t1.id'), nullable=False),
            Column('t2id', Integer, ForeignKey('t2.id'), nullable=False),
        )

    @classmethod
    def setup_classes(cls):
        class T1(cls.Basic):
            pass

        class T2(cls.Basic):
            pass

        class T3(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        T1, T2, T3 = cls.classes.T1, cls.classes.T2, cls.classes.T3
        t1, t2, t3 = cls.tables.t1, cls.tables.t2, cls.tables.t3

        mapper(T1, t1, properties=dict(
            t2=relationship(T2, primaryjoin=t1.c.t2id == t2.c.id)))
        mapper(T2, t2, properties=dict(
            t1=relationship(T1, primaryjoin=t2.c.t1id == t1.c.id)))
        mapper(T3, t3, properties=dict(
            t1=relationship(T1),
            t2=relationship(T2)))

    def test_reflush(self):
        """A second flush that only adds a T3 pointing at existing T1/T2."""
        T1, T2, T3 = self.classes.T1, self.classes.T2, self.classes.T3

        first = T1()
        first.t2 = T2()
        session = create_session()
        session.add(first)
        session.flush()

        # the bug here is that the dependency sort comes up with T1/T2 in a
        # cycle, but there are no T1/T2 objects to be saved.  therefore no
        # "cyclical subtree" gets generated, and one or the other of T1/T2
        # gets lost, and processors on T3 don't fire off.  the test will then
        # fail because the FK's on T3 are not nullable.
        third = T3()
        third.t1 = first
        third.t2 = first.t2
        session.add(third)
        session.flush()

    def test_reflush_2(self):
        """A variant on test_reflush()"""
        T1, T2, T3 = self.classes.T1, self.classes.T2, self.classes.T3

        first = T1()
        first.t2 = T2()
        session = create_session()
        session.add(first)
        session.flush()

        # in this case, T1, T2, and T3 tasks will all be in the cyclical
        # tree normally.  the dependency processors for T3 are part of the
        # 'extradeps' collection so they all get assembled into the tree
        # as well.
        new_t1 = T1()
        new_t2 = T2()
        session.add_all([new_t1, new_t2])
        new_t3 = T3()
        new_t3.t1 = new_t1
        new_t3.t2 = new_t2
        session.add(new_t3)

        existing_t3 = T3()
        existing_t3.t1 = first
        existing_t3.t2 = first.t2
        session.add(existing_t3)
        session.flush()
433
434
class BiDirectionalOneToManyTest(fixtures.MappedTest):
    """Two mappers, each with a one-to-many relationship to the other."""

    run_define_tables = 'each'

    @classmethod
    def define_tables(cls, metadata):
        Table(
            't1', metadata,
            Column('c1', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('c2', Integer, ForeignKey('t2.c1')),
        )
        Table(
            't2', metadata,
            Column('c1', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('c2', Integer, ForeignKey('t1.c1', name='t1c1_fk')),
        )

    @classmethod
    def setup_classes(cls):
        class C1(cls.Basic):
            pass

        class C2(cls.Basic):
            pass

    def test_cycle(self):
        """Flush a mix of linked and unlinked C1/C2 rows in one pass."""
        C1, C2 = self.classes.C1, self.classes.C2
        t1, t2 = self.tables.t1, self.tables.t2

        mapper(C2, t2, properties=dict(
            c1s=relationship(C1,
                             primaryjoin=t2.c.c1 == t1.c.c2,
                             uselist=True)))
        mapper(C1, t1, properties=dict(
            c2s=relationship(C2,
                             primaryjoin=t1.c.c1 == t2.c.c2,
                             uselist=True)))

        a, b, c, d, e, f = C1(), C2(), C1(), C2(), C2(), C2()
        a.c2s.append(b)
        d.c1s.append(c)
        b.c1s.append(c)

        session = create_session()
        session.add_all((a, b, c, d, e, f))
        session.flush()
488
489
class BiDirectionalOneToManyTest2(fixtures.MappedTest):
    """Two mappers with a one-to-many relationship to each other,
    with a second one-to-many on one of the mappers."""

    run_define_tables = 'each'

    @classmethod
    def define_tables(cls, metadata):
        # As in the other fixtures in this module, test_needs_autoincrement
        # is given on the primary key Columns only; it has no meaning as a
        # Table-level option.
        Table('t1', metadata,
              Column('c1', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('c2', Integer, ForeignKey('t2.c1')))

        Table('t2', metadata,
              Column('c1', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('c2', Integer,
                     ForeignKey('t1.c1', name='t1c1_fq')))

        Table('t1_data', metadata,
              Column('c1', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('t1id', Integer, ForeignKey('t1.c1')),
              Column('data', String(20)))

    @classmethod
    def setup_classes(cls):
        class C1(cls.Basic):
            pass

        class C2(cls.Basic):
            pass

        class C1Data(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        C1, C2, C1Data = (cls.classes.C1,
                          cls.classes.C2,
                          cls.classes.C1Data)
        t1, t2, t1_data = (cls.tables.t1,
                           cls.tables.t2,
                           cls.tables.t1_data)

        mapper(C2, t2, properties={
            'c1s': relationship(C1,
                                primaryjoin=t2.c.c1 == t1.c.c2,
                                uselist=True)})
        mapper(C1, t1, properties={
            'c2s': relationship(C2,
                                primaryjoin=t1.c.c1 == t2.c.c2,
                                uselist=True),
            'data': relationship(mapper(C1Data, t1_data))})

    def test_cycle(self):
        """Insert the C1/C2 cycle plus C1Data rows, then delete part of it."""
        C2, C1, C1Data = (self.classes.C2,
                          self.classes.C1,
                          self.classes.C1Data)

        a = C1()
        b = C2()
        c = C1()
        d = C2()
        e = C2()
        f = C2()
        a.c2s.append(b)
        d.c1s.append(c)
        b.c1s.append(c)
        a.data.append(C1Data(data='c1data1'))
        a.data.append(C1Data(data='c1data2'))
        c.data.append(C1Data(data='c1data3'))
        sess = create_session()
        sess.add_all((a, b, c, d, e, f))
        sess.flush()

        sess.delete(d)
        sess.delete(c)
        sess.flush()
572
573
class OneToManyManyToOneTest(fixtures.MappedTest):
    """Person and Ball rows that can each depend on the other.

    ball.person_id backs a one-to-many (Person.balls) and
    person.favorite_ball_id backs a separate many-to-one (Person.favorite).
    Two tests have a row for each item that is dependent on the other;
    without the "post_update" flag, such relationships raise an exception
    when dependencies are sorted.
    """
    run_define_tables = 'each'

    @classmethod
    def define_tables(cls, metadata):
        Table('ball', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('person_id', Integer,
                     ForeignKey('person.id', name='fk_person_id')),
              Column('data', String(30)))

        Table('person', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('favorite_ball_id', Integer, ForeignKey('ball.id')),
              Column('data', String(30)))

    @classmethod
    def setup_classes(cls):
        class Person(cls.Basic):
            pass

        class Ball(cls.Basic):
            pass

    def test_cycle(self):
        """
        This test has a peculiar aspect in that it doesn't create as many
        dependent relationships as the other tests, and revealed a small
        glitch in the circular dependency sorting.

        """

        person, ball, Ball, Person = (self.tables.person,
                                      self.tables.ball,
                                      self.classes.Ball,
                                      self.classes.Person)

        mapper(Ball, ball)
        mapper(Person, person, properties=dict(
            balls=relationship(Ball,
                               primaryjoin=ball.c.person_id == person.c.id,
                               remote_side=ball.c.person_id),
            favorite=relationship(
                Ball,
                primaryjoin=person.c.favorite_ball_id == ball.c.id,
                remote_side=ball.c.id)))

        b = Ball()
        p = Person()
        p.balls.append(b)
        sess = create_session()
        sess.add(p)
        sess.flush()

    def test_post_update_m2o_no_cascade(self):
        """Deleting a Person nulls its post_update many-to-one first."""
        person, ball, Ball, Person = (self.tables.person,
                                      self.tables.ball,
                                      self.classes.Ball,
                                      self.classes.Person)

        mapper(Ball, ball)
        mapper(Person, person, properties=dict(
            favorite=relationship(
                Ball, primaryjoin=person.c.favorite_ball_id == ball.c.id,
                post_update=True)))
        b = Ball(data='some data')
        p = Person(data='some data')
        p.favorite = b
        sess = create_session()
        sess.add(b)
        sess.add(p)
        sess.flush()

        sess.delete(p)
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id "
                        "WHERE person.id = :person_id",
                        lambda ctx: {
                            'favorite_ball_id': None,
                            'person_id': p.id}
                        ),
            CompiledSQL("DELETE FROM person WHERE person.id = :id",
                        lambda ctx: {'id': p.id}
                        ),
        )

    def test_post_update_m2o(self):
        """A cycle between two rows, with a post_update on the many-to-one"""

        person, ball, Ball, Person = (self.tables.person,
                                      self.tables.ball,
                                      self.classes.Ball,
                                      self.classes.Person)

        mapper(Ball, ball)
        mapper(Person, person, properties=dict(
            balls=relationship(Ball,
                               primaryjoin=ball.c.person_id == person.c.id,
                               remote_side=ball.c.person_id,
                               post_update=False,
                               cascade="all, delete-orphan"),
            favorite=relationship(
                Ball,
                primaryjoin=person.c.favorite_ball_id == ball.c.id,
                remote_side=person.c.favorite_ball_id,
                post_update=True)))

        b = Ball(data='some data')
        p = Person(data='some data')
        p.balls.append(b)
        p.balls.append(Ball(data='some data'))
        p.balls.append(Ball(data='some data'))
        p.balls.append(Ball(data='some data'))
        p.favorite = b
        sess = create_session()
        sess.add(b)
        sess.add(p)

        # person first, then the balls referencing it, then the
        # post_update setting person.favorite_ball_id
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            RegexSQL("^INSERT INTO person", {'data': 'some data'}),
            RegexSQL("^INSERT INTO ball", lambda c: {
                     'person_id': p.id, 'data': 'some data'}),
            RegexSQL("^INSERT INTO ball", lambda c: {
                     'person_id': p.id, 'data': 'some data'}),
            RegexSQL("^INSERT INTO ball", lambda c: {
                     'person_id': p.id, 'data': 'some data'}),
            RegexSQL("^INSERT INTO ball", lambda c: {
                     'person_id': p.id, 'data': 'some data'}),
            CompiledSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id "
                        "WHERE person.id = :person_id",
                        lambda ctx: {
                            'favorite_ball_id': p.favorite.id,
                            'person_id': p.id}
                        ),
        )

        sess.delete(p)

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id "
                        "WHERE person.id = :person_id",
                        lambda ctx: {'person_id': p.id,
                                     'favorite_ball_id': None}),
            # parameters of the ball DELETE are not checked (None).
            # NOTE(review): presumably because the order in which the four
            # ball rows are deleted is not deterministic -- confirm before
            # tightening this assertion.
            CompiledSQL("DELETE FROM ball WHERE ball.id = :id", None),
            CompiledSQL("DELETE FROM person WHERE person.id = :id",
                        lambda ctx: [{'id': p.id}])
        )

    def test_post_update_backref(self):
        """test bidirectional post_update."""

        person, ball, Ball, Person = (self.tables.person,
                                      self.tables.ball,
                                      self.classes.Ball,
                                      self.classes.Person)

        mapper(Ball, ball)
        mapper(Person, person, properties=dict(
            balls=relationship(Ball,
                               primaryjoin=ball.c.person_id == person.c.id,
                               remote_side=ball.c.person_id, post_update=True,
                               backref=backref('person', post_update=True)
                               ),
            favorite=relationship(
                Ball,
                primaryjoin=person.c.favorite_ball_id == ball.c.id,
                remote_side=person.c.favorite_ball_id)
        ))

        sess = sessionmaker()()
        p1 = Person(data='p1')
        p2 = Person(data='p2')
        p3 = Person(data='p3')

        b1 = Ball(data='b1')

        b1.person = p1
        sess.add_all([p1, p2, p3])
        sess.commit()

        # switch here.  the post_update
        # on ball.person can't get tripped up
        # by the fact that there's a "reverse" prop.
        b1.person = p2
        sess.commit()
        eq_(
            p2, b1.person
        )

        # do it the other way
        p3.balls.append(b1)
        sess.commit()
        eq_(
            p3, b1.person
        )

    def test_post_update_o2m(self):
        """A cycle between two rows, with a post_update on the one-to-many"""

        person, ball, Ball, Person = (self.tables.person,
                                      self.tables.ball,
                                      self.classes.Ball,
                                      self.classes.Person)

        mapper(Ball, ball)
        mapper(Person, person, properties=dict(
            balls=relationship(Ball,
                               primaryjoin=ball.c.person_id == person.c.id,
                               remote_side=ball.c.person_id,
                               cascade="all, delete-orphan",
                               post_update=True,
                               backref='person'),
            favorite=relationship(
                Ball,
                primaryjoin=person.c.favorite_ball_id == ball.c.id,
                remote_side=person.c.favorite_ball_id)))

        b = Ball(data='some data')
        p = Person(data='some data')
        p.balls.append(b)
        b2 = Ball(data='some data')
        p.balls.append(b2)
        b3 = Ball(data='some data')
        p.balls.append(b3)
        b4 = Ball(data='some data')
        p.balls.append(b4)
        p.favorite = b
        sess = create_session()
        sess.add_all((b, p, b2, b3, b4))

        # balls are inserted with a NULL person_id, then the person, then
        # the post_update fills in ball.person_id
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL("INSERT INTO ball (person_id, data) "
                        "VALUES (:person_id, :data)",
                        {'person_id': None, 'data': 'some data'}),

            CompiledSQL("INSERT INTO ball (person_id, data) "
                        "VALUES (:person_id, :data)",
                        {'person_id': None, 'data': 'some data'}),

            CompiledSQL("INSERT INTO ball (person_id, data) "
                        "VALUES (:person_id, :data)",
                        {'person_id': None, 'data': 'some data'}),

            CompiledSQL("INSERT INTO ball (person_id, data) "
                        "VALUES (:person_id, :data)",
                        {'person_id': None, 'data': 'some data'}),

            CompiledSQL("INSERT INTO person (favorite_ball_id, data) "
                        "VALUES (:favorite_ball_id, :data)",
                        lambda ctx: {'favorite_ball_id': b.id,
                                     'data': 'some data'}),

            CompiledSQL("UPDATE ball SET person_id=:person_id "
                        "WHERE ball.id = :ball_id",
                        lambda ctx: [
                            {'person_id': p.id, 'ball_id': b.id},
                            {'person_id': p.id, 'ball_id': b2.id},
                            {'person_id': p.id, 'ball_id': b3.id},
                            {'person_id': p.id, 'ball_id': b4.id}
                        ]),
        )

        sess.delete(p)

        # the post_update nulls ball.person_id before either DELETE runs
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL("UPDATE ball SET person_id=:person_id "
                        "WHERE ball.id = :ball_id",
                        lambda ctx: [
                            {'person_id': None, 'ball_id': b.id},
                            {'person_id': None, 'ball_id': b2.id},
                            {'person_id': None, 'ball_id': b3.id},
                            {'person_id': None, 'ball_id': b4.id}
                        ]),
            CompiledSQL("DELETE FROM person "
                        "WHERE person.id = :id",
                        lambda ctx: [{'id': p.id}]),
            CompiledSQL("DELETE FROM ball WHERE ball.id = :id",
                        lambda ctx: [{'id': b.id},
                                     {'id': b2.id},
                                     {'id': b3.id},
                                     {'id': b4.id}])
        )

    def test_post_update_m2o_detect_none(self):
        """Setting an unloaded post_update many-to-one to None still emits
        the UPDATE that nulls the foreign key."""
        person, ball, Ball, Person = (
            self.tables.person,
            self.tables.ball,
            self.classes.Ball,
            self.classes.Person)

        mapper(Ball, ball, properties={
            'person': relationship(
                Person, post_update=True,
                primaryjoin=person.c.id == ball.c.person_id)
        })
        mapper(Person, person)

        sess = create_session(autocommit=False, expire_on_commit=True)
        sess.add(Ball(person=Person()))
        sess.commit()
        b1 = sess.query(Ball).first()

        # needs to be unloaded
        assert 'person' not in b1.__dict__
        b1.person = None

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "UPDATE ball SET person_id=:person_id "
                "WHERE ball.id = :ball_id",
                lambda ctx: {'person_id': None, 'ball_id': b1.id})
        )

        is_(b1.person, None)
919
920
class SelfReferentialPostUpdateTest(fixtures.MappedTest):
    """Post_update on a single self-referential mapper."""

    @classmethod
    def define_tables(cls, metadata):
        Table('node', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('path', String(50), nullable=False),
              Column('parent_id', Integer,
                     ForeignKey('node.id'), nullable=True),
              Column('prev_sibling_id', Integer,
                     ForeignKey('node.id'), nullable=True),
              Column('next_sibling_id', Integer,
                     ForeignKey('node.id'), nullable=True))

    @classmethod
    def setup_classes(cls):
        class Node(cls.Basic):
            def __init__(self, path=''):
                self.path = path

    def test_one(self):
        """Post_update only fires off when needed.

        This test case used to produce many superfluous update statements,
        particularly upon delete.

        """

        node, Node = self.tables.node, self.classes.Node

        # Three self-referential relationships over 'node'; only
        # next_sibling is handled via post_update.
        mapper(Node, node, properties={
            'children': relationship(
                Node,
                primaryjoin=node.c.id == node.c.parent_id,
                cascade="all",
                backref=backref("parent", remote_side=node.c.id)
            ),
            'prev_sibling': relationship(
                Node,
                primaryjoin=node.c.prev_sibling_id == node.c.id,
                remote_side=node.c.id,
                uselist=False),
            'next_sibling': relationship(
                Node,
                primaryjoin=node.c.next_sibling_id == node.c.id,
                remote_side=node.c.id,
                uselist=False,
                post_update=True)})

        session = create_session()

        def append_child(parent, child):
            """Append ``child`` to ``parent.children``.

            If ``parent`` already has children, the current last child and
            ``child`` are linked as next/prev siblings first.
            """
            if parent.children:
                parent.children[-1].next_sibling = child
                child.prev_sibling = parent.children[-1]
            parent.children.append(child)

        def remove_child(parent, child):
            """Detach ``child``, splice its siblings together, delete it.

            Assumes ``child`` has both a previous and a next sibling; both
            are dereferenced without a None check.
            """
            child.parent = None
            # named next_node so it does not shadow the 'node' table above
            next_node = child.next_sibling
            next_node.prev_sibling = child.prev_sibling
            child.prev_sibling.next_sibling = next_node
            session.delete(child)

        root = Node('root')

        about = Node('about')
        cats = Node('cats')
        stories = Node('stories')
        bruce = Node('bruce')

        append_child(root, about)
        assert about.prev_sibling is None
        append_child(root, cats)
        assert cats.prev_sibling is about
        assert cats.next_sibling is None
        assert about.next_sibling is cats
        assert about.prev_sibling is None
        append_child(root, stories)
        append_child(root, bruce)
        session.add(root)
        session.flush()

        remove_child(root, cats)

        # pre-trigger lazy loader on 'cats' to make the test easier
        cats.children
        self.assert_sql_execution(
            testing.db,
            session.flush,
            AllOf(
                CompiledSQL("UPDATE node SET prev_sibling_id=:prev_sibling_id "
                            "WHERE node.id = :node_id",
                            lambda ctx: {'prev_sibling_id': about.id,
                                         'node_id': stories.id}),

                CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
                            "WHERE node.id = :node_id",
                            lambda ctx: {'next_sibling_id': stories.id,
                                         'node_id': about.id}),

                CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
                            "WHERE node.id = :node_id",
                            lambda ctx: {'next_sibling_id': None,
                                         'node_id': cats.id}),
            ),

            CompiledSQL("DELETE FROM node WHERE node.id = :id",
                        lambda ctx: [{'id': cats.id}])
        )

        # 'children' cascades "all", so deleting root also deletes the
        # remaining children: their next_sibling links are nulled first,
        # then the children rows are deleted, then root.
        session.delete(root)

        self.assert_sql_execution(
            testing.db,
            session.flush,
            CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
                        "WHERE node.id = :node_id",
                        lambda ctx: [
                            {'node_id': about.id, 'next_sibling_id': None},
                            {'node_id': stories.id, 'next_sibling_id': None}
                        ]
                        ),
            AllOf(
                CompiledSQL("DELETE FROM node WHERE node.id = :id",
                            lambda ctx: {'id': about.id}
                            ),
                CompiledSQL("DELETE FROM node WHERE node.id = :id",
                            lambda ctx: {'id': stories.id}
                            ),
                CompiledSQL("DELETE FROM node WHERE node.id = :id",
                            lambda ctx: {'id': bruce.id}
                            ),
            ),
            CompiledSQL("DELETE FROM node WHERE node.id = :id",
                        lambda ctx: {'id': root.id}
                        ),
        )

        # Smoke test, no SQL assertions: delete a node that is both the
        # post_update next_sibling source and the prev_sibling target of
        # another node (with that reference cleared); the flush must succeed.
        about = Node('about')
        cats = Node('cats')
        about.next_sibling = cats
        cats.prev_sibling = about
        session.add(about)
        session.flush()
        session.delete(about)
        cats.prev_sibling = None
        session.flush()
1072
1073
class SelfReferentialPostUpdateTest2(fixtures.MappedTest):
    """Self-referential many-to-one on ``a_table`` using post_update."""

    @classmethod
    def define_tables(cls, metadata):
        id_col = Column("id", Integer(), primary_key=True,
                        test_needs_autoincrement=True)
        fui_col = Column("fui", String(128))
        # 'b' refers back to a_table.id, making the table self-referential.
        b_col = Column("b", Integer(), ForeignKey("a_table.id"))
        Table("a_table", metadata, id_col, fui_col, b_col)

    @classmethod
    def setup_classes(cls):
        class A(cls.Basic):
            pass

    def test_one(self):
        """post_update also takes part in UPDATE-time flushes.

        post_update replaces the normal dependency processing completely,
        so it must still fire when the referenced row was already inserted
        by an earlier flush.  [ticket:413]
        """
        A = self.classes.A
        a_table = self.tables.a_table

        mapper(A, a_table, properties={
            'foo': relationship(
                A,
                remote_side=[a_table.c.id],
                post_update=True,
            ),
        })

        session = create_session()

        target = A(fui="f1")
        session.add(target)
        session.flush()

        # ``target`` is already persistent here, yet the post_update step
        # still has to run for the new referrer.
        referrer = A(fui="f2", foo=target)
        session.add(referrer)
        session.flush()
        session.expunge_all()

        reloaded_target = session.query(A).get(target.id)
        reloaded_referrer = session.query(A).get(referrer.id)
        assert reloaded_referrer.foo is reloaded_target
1121
1122
class SelfReferentialPostUpdateTest3(fixtures.MappedTest):
    """Post_update many-to-one alongside a self-referential Child mapper.

    ``parent.child_id`` references ``child`` and ``child.parent_id``
    references ``parent``, forming a foreign key cycle; ``Parent.child``
    uses ``post_update=True`` to break it.  ``child`` also refers to
    itself through ``child.child_id``.
    """

    @classmethod
    def define_tables(cls, metadata):
        # The parent -> child constraint is explicitly named 'c1';
        # presumably so the cyclic schema can be created/dropped with that
        # constraint handled separately -- TODO confirm.
        Table('parent', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('name', String(50), nullable=False),
              Column('child_id', Integer,
                     ForeignKey('child.id', name='c1'), nullable=True))

        # child.child_id is the only child -> child foreign key; it backs
        # the self-referential Child.parent relationship below.
        Table('child', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('name', String(50), nullable=False),
              Column('child_id', Integer,
                     ForeignKey('child.id')),
              Column('parent_id', Integer,
                     ForeignKey('parent.id'), nullable=True))

    @classmethod
    def setup_classes(cls):
        class Parent(cls.Basic):
            def __init__(self, name=''):
                self.name = name

        class Child(cls.Basic):
            def __init__(self, name=''):
                self.name = name

    def test_one(self):
        """Flushes involving the cycle succeed across inserts and deletes.

        No SQL is asserted; the test passes if every flush completes.
        """
        Child, Parent, parent, child = (self.classes.Child,
                                        self.classes.Parent,
                                        self.tables.parent,
                                        self.tables.child)

        mapper(Parent, parent, properties={
            # one-to-many over child.parent_id
            'children': relationship(
                Child,
                primaryjoin=parent.c.id == child.c.parent_id),
            # many-to-one over parent.child_id, written via post_update
            'child': relationship(
                Child,
                primaryjoin=parent.c.child_id == child.c.id, post_update=True)
        })
        mapper(Child, child, properties={
            # self-referential many-to-one (Child -> Child); despite the
            # name it does not point at Parent
            'parent': relationship(Child, remote_side=child.c.id)
        })

        session = create_session()
        # First flush: p1 owns c1 and c2, c2 refers to c1, and p1.child
        # points at c2.
        p1 = Parent('p1')
        c1 = Child('c1')
        c2 = Child('c2')
        p1.children = [c1, c2]
        c2.parent = c1
        p1.child = c2

        session.add_all([p1, c1, c2])
        session.flush()

        # Second flush: insert p2/c3 (p2.child -> c3) while deleting c2,
        # the row p1.child referenced, and clearing p1.child.
        p2 = Parent('p2')
        c3 = Child('c3')
        p2.children = [c3]
        p2.child = c3
        session.add(p2)

        session.delete(c2)
        p1.children.remove(c2)
        p1.child = None
        session.flush()

        # Third flush: clear the remaining post_update reference.
        p2.child = None
        session.flush()
1194
1195
class PostUpdateBatchingTest(fixtures.MappedTest):
    """Several post_update columns on one row are batched into one UPDATE.
    """

    @classmethod
    def define_tables(cls, metadata):
        # parent has one nullable, explicitly named FK per child table
        # (c1_id -> child1.id as 'c1', and so on).
        parent_columns = [
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('name', String(50), nullable=False),
        ]
        for num in (1, 2, 3):
            parent_columns.append(
                Column('c%d_id' % num, Integer,
                       ForeignKey('child%d.id' % num, name='c%d' % num),
                       nullable=True))
        Table('parent', metadata, *parent_columns)

        # child1..child3 share an identical shape, each pointing back at
        # parent through a non-nullable parent_id.
        for num in (1, 2, 3):
            Table('child%d' % num, metadata,
                  Column('id', Integer, primary_key=True,
                         test_needs_autoincrement=True),
                  Column('name', String(50), nullable=False),
                  Column('parent_id', Integer,
                         ForeignKey('parent.id'), nullable=False))

    @classmethod
    def setup_classes(cls):
        class Parent(cls.Basic):
            def __init__(self, name=''):
                self.name = name

        class Child1(cls.Basic):
            def __init__(self, name=''):
                self.name = name

        class Child2(cls.Basic):
            def __init__(self, name=''):
                self.name = name

        class Child3(cls.Basic):
            def __init__(self, name=''):
                self.name = name

    def test_one(self):
        """Setting, then clearing, three post_update m2os: one UPDATE each."""
        Parent = self.classes.Parent
        Child1 = self.classes.Child1
        Child2 = self.classes.Child2
        Child3 = self.classes.Child3
        parent = self.tables.parent
        child1 = self.tables.child1
        child2 = self.tables.child2
        child3 = self.tables.child3

        mapper(Parent, parent, properties={
            # one-to-many collections, one per child table
            'c1s': relationship(
                Child1, primaryjoin=child1.c.parent_id == parent.c.id),
            'c2s': relationship(
                Child2, primaryjoin=child2.c.parent_id == parent.c.id),
            'c3s': relationship(
                Child3, primaryjoin=child3.c.parent_id == parent.c.id),

            # many-to-one references, each written via post_update
            'c1': relationship(
                Child1, primaryjoin=child1.c.id == parent.c.c1_id,
                post_update=True),
            'c2': relationship(
                Child2, primaryjoin=child2.c.id == parent.c.c2_id,
                post_update=True),
            'c3': relationship(
                Child3, primaryjoin=child3.c.id == parent.c.c3_id,
                post_update=True),
        })
        mapper(Child1, child1)
        mapper(Child2, child2)
        mapper(Child3, child3)

        def expect_batched_update(params):
            """Build the expected single UPDATE covering all three FKs."""
            return CompiledSQL(
                "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, c3_id=:c3_id "
                "WHERE parent.id = :parent_id",
                params)

        sess = create_session()

        p1 = Parent('p1')
        names = ('c1', 'c2', 'c3')
        c11, c12, c13 = [Child1(n) for n in names]
        c21, c22, c23 = [Child2(n) for n in names]
        c31, c32, c33 = [Child3(n) for n in names]

        p1.c1s = [c11, c12, c13]
        p1.c2s = [c21, c22, c23]
        p1.c3s = [c31, c32, c33]
        sess.add(p1)
        sess.flush()

        # Point all three post_update references at existing children.
        p1.c1 = c12
        p1.c2 = c23
        p1.c3 = c31

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            expect_batched_update(
                lambda ctx: {'c1_id': c12.id, 'c2_id': c23.id,
                             'c3_id': c31.id, 'parent_id': p1.id})
        )

        # Clearing them all is batched into one UPDATE as well.
        p1.c1 = p1.c2 = p1.c3 = None

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            expect_batched_update(
                lambda ctx: {'c1_id': None, 'c2_id': None,
                             'c3_id': None, 'parent_id': p1.id})
        )
1328