1from sqlalchemy import cast
2from sqlalchemy import event
3from sqlalchemy import exc
4from sqlalchemy import FetchedValue
5from sqlalchemy import ForeignKey
6from sqlalchemy import func
7from sqlalchemy import Integer
8from sqlalchemy import JSON
9from sqlalchemy import literal
10from sqlalchemy import select
11from sqlalchemy import String
12from sqlalchemy import testing
13from sqlalchemy import text
14from sqlalchemy import util
15from sqlalchemy.orm import attributes
16from sqlalchemy.orm import backref
17from sqlalchemy.orm import create_session
18from sqlalchemy.orm import exc as orm_exc
19from sqlalchemy.orm import mapper
20from sqlalchemy.orm import relationship
21from sqlalchemy.orm import Session
22from sqlalchemy.orm import unitofwork
23from sqlalchemy.testing import assert_raises_message
24from sqlalchemy.testing import config
25from sqlalchemy.testing import engines
26from sqlalchemy.testing import eq_
27from sqlalchemy.testing import fixtures
28from sqlalchemy.testing.assertsql import AllOf
29from sqlalchemy.testing.assertsql import CompiledSQL
30from sqlalchemy.testing.mock import Mock
31from sqlalchemy.testing.mock import patch
32from sqlalchemy.testing.schema import Column
33from sqlalchemy.testing.schema import Table
34from test.orm import _fixtures
35
36
class AssertsUOW(object):
    """Mixin with helpers for inspecting a unit of work without flushing.

    The helpers build a standalone ``unitofwork.UOWTransaction`` from a
    session's pending state so a test can count the actions a flush of
    that state would generate.
    """

    def _get_test_uow(self, session):
        """Return a ``UOWTransaction`` populated from ``session``.

        New and dirty states are registered for save; deleted states are
        registered with ``isdelete=True``.  Dirty states that are also in
        the deleted set are left out of the save registration.
        """
        uow = unitofwork.UOWTransaction(session)
        deleted = set(session._deleted)
        new = set(session._new)
        # a state that is both dirty and deleted is registered only as a
        # delete, never as a save
        dirty = set(session._dirty_states).difference(deleted)
        for s in new.union(dirty):
            uow.register_object(s)
        for d in deleted:
            uow.register_object(d, isdelete=True)
        return uow

    def _assert_uow_size(self, session, expected):
        """Assert that the UOW built from ``session`` has ``expected``
        post-sort actions.

        The generated actions are passed as the third argument to ``eq_``
        (its failure message), so a mismatch reports what was generated
        without printing to stdout on every call.
        """
        uow = self._get_test_uow(session)
        postsort_actions = uow._generate_actions()
        eq_(len(postsort_actions), expected, postsort_actions)
54
55
class UOWTest(
    _fixtures.FixtureTest,
    testing.AssertsExecutionResults,
    AssertsUOW
):
    """Common base for the unit-of-work tests in this module.

    Mixes the ORM fixture tables/classes with SQL execution assertions and
    the ``AssertsUOW`` helpers.
    """

    # presumably turns off the fixture's data inserts so each test starts
    # from empty tables -- confirm against FixtureTest
    run_inserts = None
60
61
class RudimentaryFlushTest(UOWTest):
    """Flush tests for non-cyclical mappings.

    Each test maps the fixture classes, stages inserts/updates/deletes on a
    session, and asserts the exact SQL (and parameters) emitted by flush.
    """

    def test_one_to_many_save(self):
        """A pending User with two pending Addresses in its one-to-many
        collection is inserted first; each Address INSERT then carries the
        new user id."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users, properties={"addresses": relationship(Address)})
        mapper(Address, addresses)
        sess = create_session()

        a1, a2 = Address(email_address="a1"), Address(email_address="a2")
        u1 = User(name="u1", addresses=[a1, a2])
        sess.add(u1)

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "INSERT INTO users (name) VALUES (:name)", {"name": "u1"}
            ),
            CompiledSQL(
                "INSERT INTO addresses (user_id, email_address) "
                "VALUES (:user_id, :email_address)",
                # params are deferred via lambda because u1.id is not
                # assigned until the users INSERT has run
                lambda ctx: {"email_address": "a1", "user_id": u1.id},
            ),
            CompiledSQL(
                "INSERT INTO addresses (user_id, email_address) "
                "VALUES (:user_id, :email_address)",
                lambda ctx: {"email_address": "a2", "user_id": u1.id},
            ),
        )

    def test_one_to_many_delete_all(self):
        """Deleting the User and both of its Addresses deletes the addresses
        (one statement, two parameter sets) before the user."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users, properties={"addresses": relationship(Address)})
        mapper(Address, addresses)
        sess = create_session()
        a1, a2 = Address(email_address="a1"), Address(email_address="a2")
        u1 = User(name="u1", addresses=[a1, a2])
        sess.add(u1)
        sess.flush()

        sess.delete(u1)
        sess.delete(a1)
        sess.delete(a2)
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "DELETE FROM addresses WHERE addresses.id = :id",
                [{"id": a1.id}, {"id": a2.id}],
            ),
            CompiledSQL(
                "DELETE FROM users WHERE users.id = :id", {"id": u1.id}
            ),
        )

    def test_one_to_many_delete_parent(self):
        """Deleting only the User (the relationship sets no delete cascade)
        nulls out each Address.user_id via UPDATE, then deletes the user."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users, properties={"addresses": relationship(Address)})
        mapper(Address, addresses)
        sess = create_session()
        a1, a2 = Address(email_address="a1"), Address(email_address="a2")
        u1 = User(name="u1", addresses=[a1, a2])
        sess.add(u1)
        sess.flush()

        sess.delete(u1)
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "UPDATE addresses SET user_id=:user_id WHERE "
                "addresses.id = :addresses_id",
                lambda ctx: [
                    {"addresses_id": a1.id, "user_id": None},
                    {"addresses_id": a2.id, "user_id": None},
                ],
            ),
            CompiledSQL(
                "DELETE FROM users WHERE users.id = :id", {"id": u1.id}
            ),
        )

    def test_many_to_one_save(self):
        """Pending Addresses pointing at a pending User via many-to-one:
        only the addresses are added to the session, yet the user row is
        inserted first and its id used for both address rows."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users)
        mapper(Address, addresses, properties={"user": relationship(User)})
        sess = create_session()

        u1 = User(name="u1")
        a1, a2 = (
            Address(email_address="a1", user=u1),
            Address(email_address="a2", user=u1),
        )
        sess.add_all([a1, a2])

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "INSERT INTO users (name) VALUES (:name)", {"name": "u1"}
            ),
            CompiledSQL(
                "INSERT INTO addresses (user_id, email_address) "
                "VALUES (:user_id, :email_address)",
                lambda ctx: {"email_address": "a1", "user_id": u1.id},
            ),
            CompiledSQL(
                "INSERT INTO addresses (user_id, email_address) "
                "VALUES (:user_id, :email_address)",
                lambda ctx: {"email_address": "a2", "user_id": u1.id},
            ),
        )

    def test_many_to_one_delete_all(self):
        """With only the many-to-one side mapped, deleting the User and both
        Addresses deletes the addresses before the user."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users)
        mapper(Address, addresses, properties={"user": relationship(User)})
        sess = create_session()

        u1 = User(name="u1")
        a1, a2 = (
            Address(email_address="a1", user=u1),
            Address(email_address="a2", user=u1),
        )
        sess.add_all([a1, a2])
        sess.flush()

        sess.delete(u1)
        sess.delete(a1)
        sess.delete(a2)
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "DELETE FROM addresses WHERE addresses.id = :id",
                [{"id": a1.id}, {"id": a2.id}],
            ),
            CompiledSQL(
                "DELETE FROM users WHERE users.id = :id", {"id": u1.id}
            ),
        )

    def test_many_to_one_delete_target(self):
        """Deleting the User after detaching both Addresses from it updates
        their user_id to NULL before the user row is deleted."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users)
        mapper(Address, addresses, properties={"user": relationship(User)})
        sess = create_session()

        u1 = User(name="u1")
        a1, a2 = (
            Address(email_address="a1", user=u1),
            Address(email_address="a2", user=u1),
        )
        sess.add_all([a1, a2])
        sess.flush()

        sess.delete(u1)
        a1.user = a2.user = None
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "UPDATE addresses SET user_id=:user_id WHERE "
                "addresses.id = :addresses_id",
                lambda ctx: [
                    {"addresses_id": a1.id, "user_id": None},
                    {"addresses_id": a2.id, "user_id": None},
                ],
            ),
            CompiledSQL(
                "DELETE FROM users WHERE users.id = :id", {"id": u1.id}
            ),
        )

    def test_many_to_one_delete_unloaded(self):
        """Deleting fully expired Addresses together with their expired
        parent User re-SELECTs all three rows before the DELETEs
        ([ticket:2002])."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users)
        mapper(Address, addresses, properties={"parent": relationship(User)})

        parent = User(name="p1")
        c1, c2 = (
            Address(email_address="c1", parent=parent),
            Address(email_address="c2", parent=parent),
        )

        session = Session()
        session.add_all([c1, c2])
        session.add(parent)

        session.flush()

        # capture the primary keys now; the objects are expired below and
        # the assertions must not trigger a refresh themselves
        pid = parent.id
        c1id = c1.id
        c2id = c2.id

        session.expire(parent)
        session.expire(c1)
        session.expire(c2)

        session.delete(c1)
        session.delete(c2)
        session.delete(parent)

        # testing that relationships
        # are loaded even if all ids/references are
        # expired
        self.assert_sql_execution(
            testing.db,
            session.flush,
            AllOf(
                # [ticket:2002] - ensure the m2os are loaded.
                # the selects here are in fact unexpiring
                # each row - the m2o comes from the identity map.
                # the User row might be handled before or after the
                # addresses are loaded, so AllOf is needed
                CompiledSQL(
                    "SELECT addresses.id AS addresses_id, "
                    "addresses.user_id AS "
                    "addresses_user_id, addresses.email_address AS "
                    "addresses_email_address FROM addresses "
                    "WHERE addresses.id = "
                    ":param_1",
                    lambda ctx: {"param_1": c1id},
                ),
                CompiledSQL(
                    "SELECT addresses.id AS addresses_id, "
                    "addresses.user_id AS "
                    "addresses_user_id, addresses.email_address AS "
                    "addresses_email_address FROM addresses "
                    "WHERE addresses.id = "
                    ":param_1",
                    lambda ctx: {"param_1": c2id},
                ),
                CompiledSQL(
                    "SELECT users.id AS users_id, users.name AS users_name "
                    "FROM users WHERE users.id = :param_1",
                    lambda ctx: {"param_1": pid},
                ),
                CompiledSQL(
                    "DELETE FROM addresses WHERE addresses.id = :id",
                    lambda ctx: [{"id": c1id}, {"id": c2id}],
                ),
                CompiledSQL(
                    "DELETE FROM users WHERE users.id = :id",
                    lambda ctx: {"id": pid},
                ),
            ),
        )

    def test_many_to_one_delete_childonly_unloaded(self):
        """Deleting only the expired Addresses re-SELECTs the two address
        rows but emits no SELECT for the (non-expired) parent User
        ([ticket:2049])."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users)
        mapper(Address, addresses, properties={"parent": relationship(User)})

        parent = User(name="p1")
        c1, c2 = (
            Address(email_address="c1", parent=parent),
            Address(email_address="c2", parent=parent),
        )

        session = Session()
        session.add_all([c1, c2])
        session.add(parent)

        session.flush()

        # pid = parent.id
        c1id = c1.id
        c2id = c2.id

        session.expire(c1)
        session.expire(c2)

        session.delete(c1)
        session.delete(c2)

        self.assert_sql_execution(
            testing.db,
            session.flush,
            AllOf(
                # [ticket:2049] - we aren't deleting User,
                # relationship is simple m2o, no SELECT should be emitted for
                # it.
                CompiledSQL(
                    "SELECT addresses.id AS addresses_id, "
                    "addresses.user_id AS "
                    "addresses_user_id, addresses.email_address AS "
                    "addresses_email_address FROM addresses "
                    "WHERE addresses.id = "
                    ":param_1",
                    lambda ctx: {"param_1": c1id},
                ),
                CompiledSQL(
                    "SELECT addresses.id AS addresses_id, "
                    "addresses.user_id AS "
                    "addresses_user_id, addresses.email_address AS "
                    "addresses_email_address FROM addresses "
                    "WHERE addresses.id = "
                    ":param_1",
                    lambda ctx: {"param_1": c2id},
                ),
            ),
            CompiledSQL(
                "DELETE FROM addresses WHERE addresses.id = :id",
                lambda ctx: [{"id": c1id}, {"id": c2id}],
            ),
        )

    def test_many_to_one_delete_childonly_unloaded_expired(self):
        """Same as test_many_to_one_delete_childonly_unloaded, except the
        parent User is expired as well; the expected SQL is still only the
        two address SELECTs followed by the address DELETE."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users)
        mapper(Address, addresses, properties={"parent": relationship(User)})

        parent = User(name="p1")
        c1, c2 = (
            Address(email_address="c1", parent=parent),
            Address(email_address="c2", parent=parent),
        )

        session = Session()
        session.add_all([c1, c2])
        session.add(parent)

        session.flush()

        # pid = parent.id
        c1id = c1.id
        c2id = c2.id

        session.expire(parent)
        session.expire(c1)
        session.expire(c2)

        session.delete(c1)
        session.delete(c2)

        self.assert_sql_execution(
            testing.db,
            session.flush,
            AllOf(
                # the expired Address rows are re-SELECTed here; although
                # the parent User is expired too, no SELECT against users
                # is listed in the expected statements.
                CompiledSQL(
                    "SELECT addresses.id AS addresses_id, "
                    "addresses.user_id AS "
                    "addresses_user_id, addresses.email_address AS "
                    "addresses_email_address FROM addresses "
                    "WHERE addresses.id = "
                    ":param_1",
                    lambda ctx: {"param_1": c1id},
                ),
                CompiledSQL(
                    "SELECT addresses.id AS addresses_id, "
                    "addresses.user_id AS "
                    "addresses_user_id, addresses.email_address AS "
                    "addresses_email_address FROM addresses "
                    "WHERE addresses.id = "
                    ":param_1",
                    lambda ctx: {"param_1": c2id},
                ),
            ),
            CompiledSQL(
                "DELETE FROM addresses WHERE addresses.id = :id",
                lambda ctx: [{"id": c1id}, {"id": c2id}],
            ),
        )

    def test_natural_ordering(self):
        """test that unconnected items take relationship()
        into account regardless.

        ``a1`` refers to ``u1`` only through the raw ``user_id`` column
        value, never through the ``parent`` relationship attribute; the
        users INSERT must still precede the addresses INSERT, and the
        addresses DELETE must precede the users DELETE."""

        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users)
        mapper(Address, addresses, properties={"parent": relationship(User)})

        sess = create_session()

        u1 = User(id=1, name="u1")
        a1 = Address(id=1, user_id=1, email_address="a2")

        sess.add_all([u1, a1])
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "INSERT INTO users (id, name) VALUES (:id, :name)",
                {"id": 1, "name": "u1"},
            ),
            CompiledSQL(
                "INSERT INTO addresses (id, user_id, email_address) "
                "VALUES (:id, :user_id, :email_address)",
                {"email_address": "a2", "user_id": 1, "id": 1},
            ),
        )

        sess.delete(u1)
        sess.delete(a1)
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "DELETE FROM addresses WHERE addresses.id = :id", [{"id": 1}]
            ),
            CompiledSQL("DELETE FROM users WHERE users.id = :id", [{"id": 1}]),
        )

    def test_natural_selfref(self):
        """test that unconnected self-referential items are inserted in a
        usable order.

        The nodes are linked only through raw ``parent_id`` values, not the
        ``children`` relationship; the three rows are expected as a single
        INSERT statement with three parameter sets, in add order."""

        Node, nodes = self.classes.Node, self.tables.nodes

        mapper(Node, nodes, properties={"children": relationship(Node)})

        sess = create_session()

        n1 = Node(id=1)
        n2 = Node(id=2, parent_id=1)
        n3 = Node(id=3, parent_id=2)

        # insert order is determined from add order since they
        # are the same class
        sess.add_all([n1, n2, n3])

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "INSERT INTO nodes (id, parent_id, data) VALUES "
                "(:id, :parent_id, :data)",
                [
                    {"parent_id": None, "data": None, "id": 1},
                    {"parent_id": 1, "data": None, "id": 2},
                    {"parent_id": 2, "data": None, "id": 3},
                ],
            ),
        )

    def test_many_to_many(self):
        """Many-to-many save inserts both endpoint rows (in either order)
        and then the association row; a later scalar UPDATE with the
        collection expired emits only the UPDATE."""
        keywords, items, item_keywords, Keyword, Item = (
            self.tables.keywords,
            self.tables.items,
            self.tables.item_keywords,
            self.classes.Keyword,
            self.classes.Item,
        )

        mapper(
            Item,
            items,
            properties={
                "keywords": relationship(Keyword, secondary=item_keywords)
            },
        )
        mapper(Keyword, keywords)

        sess = create_session()
        k1 = Keyword(name="k1")
        i1 = Item(description="i1", keywords=[k1])
        sess.add(i1)
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            AllOf(
                CompiledSQL(
                    "INSERT INTO keywords (name) VALUES (:name)",
                    {"name": "k1"},
                ),
                CompiledSQL(
                    "INSERT INTO items (description) VALUES (:description)",
                    {"description": "i1"},
                ),
            ),
            CompiledSQL(
                "INSERT INTO item_keywords (item_id, keyword_id) "
                "VALUES (:item_id, :keyword_id)",
                lambda ctx: {"item_id": i1.id, "keyword_id": k1.id},
            ),
        )

        # test that keywords collection isn't loaded: only the items
        # UPDATE is expected, with no SELECT against keywords
        sess.expire(i1, ["keywords"])
        i1.description = "i2"
        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "UPDATE items SET description=:description "
                "WHERE items.id = :items_id",
                lambda ctx: {"description": "i2", "items_id": i1.id},
            ),
        )

    def test_m2o_flush_size(self):
        """A single pending User yields 2 UOW actions when only a
        many-to-one from Address to User is mapped."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users)
        mapper(
            Address,
            addresses,
            properties={"user": relationship(User, passive_updates=True)},
        )
        sess = create_session()
        u1 = User(name="ed")
        sess.add(u1)
        self._assert_uow_size(sess, 2)

    def test_o2m_flush_size(self):
        """Track the UOW action count for a one-to-many mapping as the
        session goes through insert, scalar update, and collection
        changes."""
        users, Address, addresses, User = (
            self.tables.users,
            self.classes.Address,
            self.tables.addresses,
            self.classes.User,
        )

        mapper(User, users, properties={"addresses": relationship(Address)})
        mapper(Address, addresses)

        sess = create_session()
        u1 = User(name="ed")
        sess.add(u1)
        # pending User only
        self._assert_uow_size(sess, 2)

        sess.flush()

        u1.name = "jack"

        # scalar change only; the collection is untouched
        self._assert_uow_size(sess, 2)
        sess.flush()

        a1 = Address(email_address="foo")
        sess.add(a1)
        sess.flush()

        u1.addresses.append(a1)

        # a collection change brings the Address side into the UOW
        self._assert_uow_size(sess, 6)

        sess.flush()

        sess = create_session()
        u1 = sess.query(User).first()
        u1.name = "ed"
        self._assert_uow_size(sess, 2)

        # merely accessing the collection (presumably triggering its lazy
        # load) raises the expected action count to 6
        u1.addresses
        self._assert_uow_size(sess, 6)
672
673
674class SingleCycleTest(UOWTest):
675    def teardown(self):
676        engines.testing_reaper.rollback_all()
677        # mysql can't handle delete from nodes
678        # since it doesn't deal with the FKs correctly,
679        # so wipe out the parent_id first
680        testing.db.execute(self.tables.nodes.update().values(parent_id=None))
681        super(SingleCycleTest, self).teardown()
682
683    def test_one_to_many_save(self):
684        Node, nodes = self.classes.Node, self.tables.nodes
685
686        mapper(Node, nodes, properties={"children": relationship(Node)})
687        sess = create_session()
688
689        n2, n3 = Node(data="n2"), Node(data="n3")
690        n1 = Node(data="n1", children=[n2, n3])
691
692        sess.add(n1)
693
694        self.assert_sql_execution(
695            testing.db,
696            sess.flush,
697            CompiledSQL(
698                "INSERT INTO nodes (parent_id, data) VALUES "
699                "(:parent_id, :data)",
700                {"parent_id": None, "data": "n1"},
701            ),
702            AllOf(
703                CompiledSQL(
704                    "INSERT INTO nodes (parent_id, data) VALUES "
705                    "(:parent_id, :data)",
706                    lambda ctx: {"parent_id": n1.id, "data": "n2"},
707                ),
708                CompiledSQL(
709                    "INSERT INTO nodes (parent_id, data) VALUES "
710                    "(:parent_id, :data)",
711                    lambda ctx: {"parent_id": n1.id, "data": "n3"},
712                ),
713            ),
714        )
715
716    def test_one_to_many_delete_all(self):
717        Node, nodes = self.classes.Node, self.tables.nodes
718
719        mapper(Node, nodes, properties={"children": relationship(Node)})
720        sess = create_session()
721
722        n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[])
723        n1 = Node(data="n1", children=[n2, n3])
724
725        sess.add(n1)
726        sess.flush()
727
728        sess.delete(n1)
729        sess.delete(n2)
730        sess.delete(n3)
731        self.assert_sql_execution(
732            testing.db,
733            sess.flush,
734            CompiledSQL(
735                "DELETE FROM nodes WHERE nodes.id = :id",
736                lambda ctx: [{"id": n2.id}, {"id": n3.id}],
737            ),
738            CompiledSQL(
739                "DELETE FROM nodes WHERE nodes.id = :id",
740                lambda ctx: {"id": n1.id},
741            ),
742        )
743
744    def test_one_to_many_delete_parent(self):
745        Node, nodes = self.classes.Node, self.tables.nodes
746
747        mapper(Node, nodes, properties={"children": relationship(Node)})
748        sess = create_session()
749
750        n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[])
751        n1 = Node(data="n1", children=[n2, n3])
752
753        sess.add(n1)
754        sess.flush()
755
756        sess.delete(n1)
757        self.assert_sql_execution(
758            testing.db,
759            sess.flush,
760            AllOf(
761                CompiledSQL(
762                    "UPDATE nodes SET parent_id=:parent_id "
763                    "WHERE nodes.id = :nodes_id",
764                    lambda ctx: [
765                        {"nodes_id": n3.id, "parent_id": None},
766                        {"nodes_id": n2.id, "parent_id": None},
767                    ],
768                )
769            ),
770            CompiledSQL(
771                "DELETE FROM nodes WHERE nodes.id = :id",
772                lambda ctx: {"id": n1.id},
773            ),
774        )
775
776    def test_many_to_one_save(self):
777        Node, nodes = self.classes.Node, self.tables.nodes
778
779        mapper(
780            Node,
781            nodes,
782            properties={"parent": relationship(Node, remote_side=nodes.c.id)},
783        )
784        sess = create_session()
785
786        n1 = Node(data="n1")
787        n2, n3 = Node(data="n2", parent=n1), Node(data="n3", parent=n1)
788
789        sess.add_all([n2, n3])
790
791        self.assert_sql_execution(
792            testing.db,
793            sess.flush,
794            CompiledSQL(
795                "INSERT INTO nodes (parent_id, data) VALUES "
796                "(:parent_id, :data)",
797                {"parent_id": None, "data": "n1"},
798            ),
799            AllOf(
800                CompiledSQL(
801                    "INSERT INTO nodes (parent_id, data) VALUES "
802                    "(:parent_id, :data)",
803                    lambda ctx: {"parent_id": n1.id, "data": "n2"},
804                ),
805                CompiledSQL(
806                    "INSERT INTO nodes (parent_id, data) VALUES "
807                    "(:parent_id, :data)",
808                    lambda ctx: {"parent_id": n1.id, "data": "n3"},
809                ),
810            ),
811        )
812
813    def test_many_to_one_delete_all(self):
814        Node, nodes = self.classes.Node, self.tables.nodes
815
816        mapper(
817            Node,
818            nodes,
819            properties={"parent": relationship(Node, remote_side=nodes.c.id)},
820        )
821        sess = create_session()
822
823        n1 = Node(data="n1")
824        n2, n3 = Node(data="n2", parent=n1), Node(data="n3", parent=n1)
825
826        sess.add_all([n2, n3])
827        sess.flush()
828
829        sess.delete(n1)
830        sess.delete(n2)
831        sess.delete(n3)
832        self.assert_sql_execution(
833            testing.db,
834            sess.flush,
835            CompiledSQL(
836                "DELETE FROM nodes WHERE nodes.id = :id",
837                lambda ctx: [{"id": n2.id}, {"id": n3.id}],
838            ),
839            CompiledSQL(
840                "DELETE FROM nodes WHERE nodes.id = :id",
841                lambda ctx: {"id": n1.id},
842            ),
843        )
844
845    def test_many_to_one_set_null_unloaded(self):
846        Node, nodes = self.classes.Node, self.tables.nodes
847
848        mapper(
849            Node,
850            nodes,
851            properties={"parent": relationship(Node, remote_side=nodes.c.id)},
852        )
853        sess = create_session()
854        n1 = Node(data="n1")
855        n2 = Node(data="n2", parent=n1)
856        sess.add_all([n1, n2])
857        sess.flush()
858        sess.close()
859
860        n2 = sess.query(Node).filter_by(data="n2").one()
861        n2.parent = None
862        self.assert_sql_execution(
863            testing.db,
864            sess.flush,
865            CompiledSQL(
866                "UPDATE nodes SET parent_id=:parent_id WHERE "
867                "nodes.id = :nodes_id",
868                lambda ctx: {"parent_id": None, "nodes_id": n2.id},
869            ),
870        )
871
872    def test_cycle_rowswitch(self):
873        Node, nodes = self.classes.Node, self.tables.nodes
874
875        mapper(Node, nodes, properties={"children": relationship(Node)})
876        sess = create_session()
877
878        n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[])
879        n1 = Node(data="n1", children=[n2])
880
881        sess.add(n1)
882        sess.flush()
883        sess.delete(n2)
884        n3.id = n2.id
885        n1.children.append(n3)
886        sess.flush()
887
888    def test_bidirectional_mutations_one(self):
889        Node, nodes = self.classes.Node, self.tables.nodes
890
891        mapper(
892            Node,
893            nodes,
894            properties={
895                "children": relationship(
896                    Node, backref=backref("parent", remote_side=nodes.c.id)
897                )
898            },
899        )
900        sess = create_session()
901
902        n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[])
903        n1 = Node(data="n1", children=[n2])
904        sess.add(n1)
905        sess.flush()
906        sess.delete(n2)
907        n1.children.append(n3)
908        sess.flush()
909
910        sess.delete(n1)
911        sess.delete(n3)
912        sess.flush()
913
914    def test_bidirectional_multilevel_save(self):
915        Node, nodes = self.classes.Node, self.tables.nodes
916
917        mapper(
918            Node,
919            nodes,
920            properties={
921                "children": relationship(
922                    Node, backref=backref("parent", remote_side=nodes.c.id)
923                )
924            },
925        )
926        sess = create_session()
927        n1 = Node(data="n1")
928        n1.children.append(Node(data="n11"))
929        n12 = Node(data="n12")
930        n1.children.append(n12)
931        n1.children.append(Node(data="n13"))
932        n1.children[1].children.append(Node(data="n121"))
933        n1.children[1].children.append(Node(data="n122"))
934        n1.children[1].children.append(Node(data="n123"))
935        sess.add(n1)
936        self.assert_sql_execution(
937            testing.db,
938            sess.flush,
939            CompiledSQL(
940                "INSERT INTO nodes (parent_id, data) VALUES "
941                "(:parent_id, :data)",
942                lambda ctx: {"parent_id": None, "data": "n1"},
943            ),
944            CompiledSQL(
945                "INSERT INTO nodes (parent_id, data) VALUES "
946                "(:parent_id, :data)",
947                lambda ctx: {"parent_id": n1.id, "data": "n11"},
948            ),
949            CompiledSQL(
950                "INSERT INTO nodes (parent_id, data) VALUES "
951                "(:parent_id, :data)",
952                lambda ctx: {"parent_id": n1.id, "data": "n12"},
953            ),
954            CompiledSQL(
955                "INSERT INTO nodes (parent_id, data) VALUES "
956                "(:parent_id, :data)",
957                lambda ctx: {"parent_id": n1.id, "data": "n13"},
958            ),
959            CompiledSQL(
960                "INSERT INTO nodes (parent_id, data) VALUES "
961                "(:parent_id, :data)",
962                lambda ctx: {"parent_id": n12.id, "data": "n121"},
963            ),
964            CompiledSQL(
965                "INSERT INTO nodes (parent_id, data) VALUES "
966                "(:parent_id, :data)",
967                lambda ctx: {"parent_id": n12.id, "data": "n122"},
968            ),
969            CompiledSQL(
970                "INSERT INTO nodes (parent_id, data) VALUES "
971                "(:parent_id, :data)",
972                lambda ctx: {"parent_id": n12.id, "data": "n123"},
973            ),
974        )
975
    def test_singlecycle_flush_size(self):
        """Assert the number of unit-of-work actions (as counted by
        ``_assert_uow_size``) for a self-referential one-to-many "children"
        relationship as objects move through pending, dirty and
        collection-modified states."""
        Node, nodes = self.classes.Node, self.tables.nodes

        mapper(Node, nodes, properties={"children": relationship(Node)})
        sess = create_session()
        # a single pending Node
        n1 = Node(data="ed")
        sess.add(n1)
        self._assert_uow_size(sess, 2)

        sess.flush()

        # a scalar-only change yields the same count as the pending case
        n1.data = "jack"

        self._assert_uow_size(sess, 2)
        sess.flush()

        n2 = Node(data="foo")
        sess.add(n2)
        sess.flush()

        # modifying the "children" collection adds one more action
        n1.children.append(n2)

        self._assert_uow_size(sess, 3)

        sess.flush()

        # fresh session: repeat the checks against a loaded (rather than
        # newly created) object
        sess = create_session()
        n1 = sess.query(Node).first()
        n1.data = "ed"
        self._assert_uow_size(sess, 2)

        # merely loading the collection, without changing it, must not
        # add any actions
        n1.children
        self._assert_uow_size(sess, 2)
1009
    def test_delete_unloaded_m2o(self):
        """Delete two children and their parent after fully expiring all
        three objects, and assert the flush re-SELECTs each expired row
        before emitting the DELETEs."""
        Node, nodes = self.classes.Node, self.tables.nodes

        mapper(
            Node,
            nodes,
            properties={"parent": relationship(Node, remote_side=nodes.c.id)},
        )

        parent = Node()
        c1, c2 = Node(parent=parent), Node(parent=parent)

        session = Session()
        session.add_all([c1, c2])
        session.add(parent)

        session.flush()

        # capture the primary keys before the objects are expired, so the
        # expected-parameter lambdas below don't depend on object state
        pid = parent.id
        c1id = c1.id
        c2id = c2.id

        # expire everything, including the ids and the "parent" reference
        session.expire(parent)
        session.expire(c1)
        session.expire(c2)

        session.delete(c1)
        session.delete(c2)
        session.delete(parent)

        # testing that relationships
        # are loaded even if all ids/references are
        # expired
        # NOTE(review): AllOf presumably accepts its rules in any order,
        # so the SELECT/DELETE ordering here is intentionally loose.
        self.assert_sql_execution(
            testing.db,
            session.flush,
            AllOf(
                # ensure all three m2os are loaded.
                # the selects here are in fact unexpiring
                # each row - the m2o comes from the identity map.
                CompiledSQL(
                    "SELECT nodes.id AS nodes_id, nodes.parent_id AS "
                    "nodes_parent_id, "
                    "nodes.data AS nodes_data FROM nodes "
                    "WHERE nodes.id = :param_1",
                    lambda ctx: {"param_1": pid},
                ),
                CompiledSQL(
                    "SELECT nodes.id AS nodes_id, nodes.parent_id AS "
                    "nodes_parent_id, "
                    "nodes.data AS nodes_data FROM nodes "
                    "WHERE nodes.id = :param_1",
                    lambda ctx: {"param_1": c1id},
                ),
                CompiledSQL(
                    "SELECT nodes.id AS nodes_id, nodes.parent_id AS "
                    "nodes_parent_id, "
                    "nodes.data AS nodes_data FROM nodes "
                    "WHERE nodes.id = :param_1",
                    lambda ctx: {"param_1": c2id},
                ),
                AllOf(
                    # both children are removed in one multi-parameter
                    # DELETE; the parent gets its own statement
                    CompiledSQL(
                        "DELETE FROM nodes WHERE nodes.id = :id",
                        lambda ctx: [{"id": c1id}, {"id": c2id}],
                    ),
                    CompiledSQL(
                        "DELETE FROM nodes WHERE nodes.id = :id",
                        lambda ctx: {"id": pid},
                    ),
                ),
            ),
        )
1083
1084
class SingleCyclePlusAttributeTest(
    fixtures.MappedTest, testing.AssertsExecutionResults, AssertsUOW
):
    """Flush-plan size for a self-referential Node that also has a
    one-to-many to an unrelated FooBar class."""

    @classmethod
    def define_tables(cls, metadata):
        def id_col():
            return Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            )

        def parent_id_col():
            return Column("parent_id", Integer, ForeignKey("nodes.id"))

        Table(
            "nodes",
            metadata,
            id_col(),
            parent_id_col(),
            Column("data", String(30)),
        )
        Table("foobars", metadata, id_col(), parent_id_col())

    def test_flush_size(self):
        """An untouched relationship must not add unit-of-work actions;
        touching it adds its own actions to the plan."""
        nodes = self.tables.nodes
        foobars = self.tables.foobars

        class Node(fixtures.ComparableEntity):
            pass

        class FooBar(fixtures.ComparableEntity):
            pass

        mapper(
            Node,
            nodes,
            properties=dict(
                children=relationship(Node), foobars=relationship(FooBar)
            ),
        )
        mapper(FooBar, foobars)

        session = create_session()
        root = Node(data="n1")
        root.children.append(Node(data="n2"))
        session.add(root)

        # "foobars" is untouched and must not contribute any actions
        self._assert_uow_size(session, 3)

        # appending a FooBar brings in FooBar's saveupdateall/deleteall
        # (the "all" procs currently stay in pairs) plus a per-state
        # process action for Node.foobars
        root.foobars.append(FooBar())
        self._assert_uow_size(session, 6)

        session.flush()
1143
1144
class SingleCycleM2MTest(
    fixtures.MappedTest, testing.AssertsExecutionResults, AssertsUOW
):
    """Self-referential many-to-many ("children" / "parents" through the
    node_to_nodes association table) combined with a self-referential
    many-to-one ("favorite") on the same nodes table."""

    @classmethod
    def define_tables(cls, metadata):
        # nodes.favorite_node_id is a self-referential FK; node_to_nodes
        # links two nodes rows and uses both FK columns as its primary key
        Table(
            "nodes",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(30)),
            Column("favorite_node_id", Integer, ForeignKey("nodes.id")),
        )

        Table(
            "node_to_nodes",
            metadata,
            Column(
                "left_node_id",
                Integer,
                ForeignKey("nodes.id"),
                primary_key=True,
            ),
            Column(
                "right_node_id",
                Integer,
                ForeignKey("nodes.id"),
                primary_key=True,
            ),
        )

    def test_many_to_many_one(self):
        """Flush a graph that uses both relationships, check the resulting
        association rows, then delete the nodes in two rounds and assert
        the DELETE statements emitted for each round."""
        nodes, node_to_nodes = self.tables.nodes, self.tables.node_to_nodes

        class Node(fixtures.ComparableEntity):
            pass

        mapper(
            Node,
            nodes,
            properties={
                "children": relationship(
                    Node,
                    secondary=node_to_nodes,
                    primaryjoin=nodes.c.id == node_to_nodes.c.left_node_id,
                    secondaryjoin=nodes.c.id == node_to_nodes.c.right_node_id,
                    backref="parents",
                ),
                "favorite": relationship(Node, remote_side=nodes.c.id),
            },
        )

        sess = create_session()
        n1 = Node(data="n1")
        n2 = Node(data="n2")
        n3 = Node(data="n3")
        n4 = Node(data="n4")
        n5 = Node(data="n5")

        # many-to-one "favorite" links: n4 -> n3, n1 -> n5, n5 -> n2
        n4.favorite = n3
        n1.favorite = n5
        n5.favorite = n2

        # many-to-many "children" links (stored in node_to_nodes)
        n1.children = [n2, n3, n4]
        n2.children = [n3, n5]
        n3.children = [n5, n4]

        sess.add_all([n1, n2, n3, n4, n5])

        # can't really assert the SQL on this easily
        # since there's too many ways to insert the rows.
        # so check the end result
        sess.flush()
        eq_(
            sess.query(
                node_to_nodes.c.left_node_id, node_to_nodes.c.right_node_id
            )
            .order_by(
                node_to_nodes.c.left_node_id, node_to_nodes.c.right_node_id
            )
            .all(),
            sorted(
                [
                    (n1.id, n2.id),
                    (n1.id, n3.id),
                    (n1.id, n4.id),
                    (n2.id, n3.id),
                    (n2.id, n5.id),
                    (n3.id, n5.id),
                    (n3.id, n4.id),
                ]
            ),
        )

        # first round: delete only n1
        sess.delete(n1)

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            # this is n1.parents firing off, as it should, since
            # passive_deletes is False for n1.parents
            CompiledSQL(
                "SELECT nodes.id AS nodes_id, nodes.data AS nodes_data, "
                "nodes.favorite_node_id AS nodes_favorite_node_id FROM "
                "nodes, node_to_nodes WHERE :param_1 = "
                "node_to_nodes.right_node_id AND nodes.id = "
                "node_to_nodes.left_node_id",
                lambda ctx: {"param_1": n1.id},
            ),
            # n1's association rows (n1 as the "left" side) go first...
            CompiledSQL(
                "DELETE FROM node_to_nodes WHERE "
                "node_to_nodes.left_node_id = :left_node_id AND "
                "node_to_nodes.right_node_id = :right_node_id",
                lambda ctx: [
                    {"right_node_id": n2.id, "left_node_id": n1.id},
                    {"right_node_id": n3.id, "left_node_id": n1.id},
                    {"right_node_id": n4.id, "left_node_id": n1.id},
                ],
            ),
            # ...then the n1 row itself
            CompiledSQL(
                "DELETE FROM nodes WHERE nodes.id = :id",
                lambda ctx: {"id": n1.id},
            ),
        )

        # second round: delete all remaining nodes
        for n in [n2, n3, n4, n5]:
            sess.delete(n)

        # load these collections
        # outside of the flush() below
        n4.children
        n5.children

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "DELETE FROM node_to_nodes WHERE node_to_nodes.left_node_id "
                "= :left_node_id AND node_to_nodes.right_node_id = "
                ":right_node_id",
                lambda ctx: [
                    {"right_node_id": n5.id, "left_node_id": n3.id},
                    {"right_node_id": n4.id, "left_node_id": n3.id},
                    {"right_node_id": n3.id, "left_node_id": n2.id},
                    {"right_node_id": n5.id, "left_node_id": n2.id},
                ],
            ),
            # n4 and n5 are deleted before n2 and n3: their
            # favorite_node_id values point at n3 and n2 respectively
            CompiledSQL(
                "DELETE FROM nodes WHERE nodes.id = :id",
                lambda ctx: [{"id": n4.id}, {"id": n5.id}],
            ),
            CompiledSQL(
                "DELETE FROM nodes WHERE nodes.id = :id",
                lambda ctx: [{"id": n2.id}, {"id": n3.id}],
            ),
        )
1302
1303
class RowswitchAccountingTest(fixtures.MappedTest):
    """Session accounting for flushes where a pending object shares its
    primary key with a persistent object being removed in the same flush
    (a "row switch"), and for deleting an object whose primary key was
    modified in memory."""

    @classmethod
    def define_tables(cls, metadata):
        # child.id is both the primary key and a foreign key to parent.id,
        # so a child row's primary key equals its parent's id
        Table(
            "parent",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", Integer),
        )
        Table(
            "child",
            metadata,
            Column("id", Integer, ForeignKey("parent.id"), primary_key=True),
            Column("data", Integer),
        )

    def _fixture(self):
        """Map Parent with a scalar "child" relationship (delete-orphan
        cascade, "parent" backref) and return ``(Parent, Child)``."""
        parent, child = self.tables.parent, self.tables.child

        class Parent(fixtures.BasicEntity):
            pass

        class Child(fixtures.BasicEntity):
            pass

        mapper(
            Parent,
            parent,
            properties={
                "child": relationship(
                    Child,
                    uselist=False,
                    cascade="all, delete-orphan",
                    backref="parent",
                )
            },
        )
        mapper(Child, child)
        return Parent, Child

    def test_switch_on_update(self):
        """Merging a Parent whose new Child has the same primary key as the
        persisted one keeps both the replaced and the new Child attached
        to the session across flushes."""
        Parent, Child = self._fixture()

        sess = create_session(autocommit=False)

        p1 = Parent(id=1, child=Child())
        sess.add(p1)
        sess.commit()

        # detach everything; the merge below works against an empty session
        sess.close()
        p2 = Parent(id=1, child=Child())
        p3 = sess.merge(p2)

        # the replaced Child is in the "deleted" part ([2]) of the attribute
        # history and must still belong to the session
        old = attributes.get_history(p3, "child")[2][0]
        assert old in sess

        # essentially no SQL should emit here,
        # because we've replaced the row with another identical one
        sess.flush()

        # the new Child is owned by this session after the flush
        assert p3.child._sa_instance_state.session_id == sess.hash_key
        assert p3.child in sess

        # repeat the replacement, this time without closing the session
        p4 = Parent(id=1, child=Child())
        p5 = sess.merge(p4)

        old = attributes.get_history(p5, "child")[2][0]
        assert old in sess

        sess.flush()

    def test_switch_on_delete(self):
        """Deleting an object after changing its primary key in memory
        deletes the row stored under the original key."""
        Parent, Child = self._fixture()

        sess = Session()
        p1 = Parent(id=1, data=2, child=None)
        sess.add(p1)
        sess.flush()

        # change the in-memory PK, then delete; the object keeps id 5
        p1.id = 5
        sess.delete(p1)
        eq_(p1.id, 5)
        sess.flush()

        # the row inserted with id=1 is gone, so the DELETE located it by
        # its persisted primary key rather than by the new value 5
        eq_(
            sess.scalar(
                select([func.count("*")]).select_from(self.tables.parent)
            ),
            0,
        )

        sess.close()
1396
1397
class RowswitchM2OTest(fixtures.MappedTest):
    """Tests for #3060 and related issues.

    A persistent B is replaced by a new B with the same primary key that
    sets either its many-to-one ``c`` or its scalar ``data`` to None; the
    replacement's None must end up persisted.
    """

    @classmethod
    def define_tables(cls, metadata):
        Table("a", metadata, Column("id", Integer, primary_key=True))
        Table(
            "b",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("aid", ForeignKey("a.id")),
            Column("cid", ForeignKey("c.id")),
            Column("data", String(50)),
        )
        Table("c", metadata, Column("id", Integer, primary_key=True))

    def _fixture(self):
        """Map A -> B (one-to-many, delete-orphan) and B -> C
        (many-to-one); return the ``(A, B, C)`` classes."""
        tables = self.tables

        class A(fixtures.BasicEntity):
            pass

        class B(fixtures.BasicEntity):
            pass

        class C(fixtures.BasicEntity):
            pass

        bs = relationship(B, cascade="all, delete-orphan")
        mapper(A, tables.a, properties={"bs": bs})
        mapper(B, tables.b, properties={"c": relationship(C)})
        mapper(C, tables.c)
        return A, B, C

    def _commit_initial(self, A, b):
        """Persist ``A(id=1)`` owning the transient ``b`` and commit.

        Returns ``(session, a)`` where ``a`` is the A loaded back from the
        session by query.
        """
        session = Session()
        session.add(A(id=1, bs=[b]))
        session.commit()
        return session, session.query(A).first()

    def test_set_none_replaces_m2o(self):
        # We have to deal with the fact that reading an unset attribute
        # implicitly yields None with no history.  Ideally "b.x = None"
        # would record None as added so the flush could set it explicitly,
        # but a plain read of "b.x" defeats that; avoiding it would mean
        # drastically changing get() so that it creates history, which
        # would add flush-time work for changes that previously showed up
        # as nothing.
        A, B, C = self._fixture()
        sess, a1 = self._commit_initial(A, B(id=1, c=C(id=1)))

        a1.bs = [B(id=1, c=None)]
        sess.commit()
        assert a1.bs[0].c is None

    def test_set_none_w_get_replaces_m2o(self):
        A, B, C = self._fixture()
        sess, a1 = self._commit_initial(A, B(id=1, c=C(id=1)))

        replacement = B(id=1)
        # the read returns None and leaves no history behind
        assert replacement.c is None
        replacement.c = None
        a1.bs = [replacement]
        sess.commit()
        assert a1.bs[0].c is None

    def test_set_none_replaces_scalar(self):
        # This case already worked before #3060, since assigning None to a
        # plain column attribute shows up in history.  As the
        # test_set_none_w_get_* tests show, that can't be relied upon:
        # a prior get of the attribute returns None and blows away the
        # history.
        A, B, C = self._fixture()
        sess, a1 = self._commit_initial(A, B(id=1, data="somedata"))

        a1.bs = [B(id=1, data=None)]
        sess.commit()
        assert a1.bs[0].data is None

    def test_set_none_w_get_replaces_scalar(self):
        A, B, C = self._fixture()
        sess, a1 = self._commit_initial(A, B(id=1, data="somedata"))

        replacement = B(id=1)
        assert replacement.data is None
        replacement.data = None
        a1.bs = [replacement]
        sess.commit()
        assert a1.bs[0].data is None
1501
1502
class BasicStaleChecksTest(fixtures.MappedTest):
    """Rowcount checks applied to flush-time UPDATE and DELETE statements.

    A mismatched UPDATE rowcount raises StaleDataError; a mismatched
    DELETE rowcount emits an SAWarning (asserted here as raised).
    """

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "parent",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", Integer),
        )
        Table(
            "child",
            metadata,
            Column("id", Integer, ForeignKey("parent.id"), primary_key=True),
            Column("data", Integer),
        )

    def _fixture(self, confirm_deleted_rows=True):
        """Map Parent with a scalar "child" relationship and return
        ``(Parent, Child)``.

        :param confirm_deleted_rows: passed to the Parent mapper; with
         False, a DELETE matching fewer rows than expected is not reported
         (see test_delete_multi_missing_allow).
        """
        parent, child = self.tables.parent, self.tables.child

        class Parent(fixtures.BasicEntity):
            pass

        class Child(fixtures.BasicEntity):
            pass

        mapper(
            Parent,
            parent,
            properties={
                "child": relationship(
                    Child,
                    uselist=False,
                    cascade="all, delete-orphan",
                    backref="parent",
                )
            },
            confirm_deleted_rows=confirm_deleted_rows,
        )
        mapper(Child, child)
        return Parent, Child

    def _broken_multi_rowcount(self):
        """Return a context manager that simulates a DBAPI whose rowcount
        is unreliable for multi-parameter executions.

        Inside the context, the test dialect reports
        ``supports_sane_multi_rowcount = False``, and
        ``ResultProxy.rowcount`` returns -1 whenever a statement was run
        with more than one parameter set.  Single-parameter executions
        still report their real rowcount.
        """
        import contextlib

        @util.memoized_property
        def rowcount(self):
            if len(self.context.compiled_parameters) > 1:
                return -1
            else:
                return self.context.rowcount

        @contextlib.contextmanager
        def patched():
            with patch.object(
                config.db.dialect, "supports_sane_multi_rowcount", False
            ):
                with patch(
                    "sqlalchemy.engine.result.ResultProxy.rowcount", rowcount
                ):
                    yield

        return patched()

    @testing.requires.sane_rowcount
    def test_update_single_missing(self):
        Parent, Child = self._fixture()
        sess = Session()
        p1 = Parent(id=1, data=2)
        sess.add(p1)
        sess.flush()

        # remove the row behind the ORM's back
        sess.execute(self.tables.parent.delete())

        p1.data = 3
        assert_raises_message(
            orm_exc.StaleDataError,
            r"UPDATE statement on table 'parent' expected to "
            r"update 1 row\(s\); 0 were matched.",
            sess.flush,
        )

    @testing.requires.sane_rowcount
    def test_update_single_missing_broken_multi_rowcount(self):
        # a single-row UPDATE is still verified even though multi-row
        # rowcounts are unavailable
        with self._broken_multi_rowcount():
            Parent, Child = self._fixture()
            sess = Session()
            p1 = Parent(id=1, data=2)
            sess.add(p1)
            sess.flush()

            sess.execute(self.tables.parent.delete())

            p1.data = 3
            assert_raises_message(
                orm_exc.StaleDataError,
                r"UPDATE statement on table 'parent' expected to "
                r"update 1 row\(s\); 0 were matched.",
                sess.flush,
            )

    def test_update_multi_missing_broken_multi_rowcount(self):
        # with two rows updated together, the missing row goes unreported
        # instead of raising
        with self._broken_multi_rowcount():
            Parent, Child = self._fixture()
            sess = Session()
            p1 = Parent(id=1, data=2)
            p2 = Parent(id=2, data=3)
            sess.add_all([p1, p2])
            sess.flush()

            sess.execute(self.tables.parent.delete().where(Parent.id == 1))

            p1.data = 3
            p2.data = 4
            sess.flush()  # no exception

            # update occurred for remaining row
            eq_(sess.query(Parent.id, Parent.data).all(), [(2, 4)])

    @testing.requires.sane_rowcount
    def test_update_value_missing_broken_multi_rowcount(self):
        # a SQL-expression value; the single-row rowcount check must still
        # fire.  Like test_update_single_missing_broken_multi_rowcount,
        # this relies on the backend reporting sane single-row rowcounts.
        with self._broken_multi_rowcount():
            Parent, Child = self._fixture()
            sess = Session()
            p1 = Parent(id=1, data=1)
            sess.add(p1)
            sess.flush()

            sess.execute(self.tables.parent.delete())

            p1.data = literal(1)
            assert_raises_message(
                orm_exc.StaleDataError,
                r"UPDATE statement on table 'parent' expected to "
                r"update 1 row\(s\); 0 were matched.",
                sess.flush,
            )

    @testing.requires.sane_rowcount
    def test_delete_twice(self):
        Parent, Child = self._fixture()
        sess = Session()
        p1 = Parent(id=1, data=2, child=None)
        sess.add(p1)
        sess.commit()

        sess.delete(p1)
        sess.flush()

        # deleting the already-deleted object again matches no row
        sess.delete(p1)

        assert_raises_message(
            exc.SAWarning,
            r"DELETE statement on table 'parent' expected to "
            r"delete 1 row\(s\); 0 were matched.",
            sess.commit,
        )

    @testing.requires.sane_multi_rowcount
    def test_delete_multi_missing_warning(self):
        Parent, Child = self._fixture()
        sess = Session()
        p1 = Parent(id=1, data=2, child=None)
        p2 = Parent(id=2, data=3, child=None)
        sess.add_all([p1, p2])
        sess.flush()

        sess.execute(self.tables.parent.delete())
        sess.delete(p1)
        sess.delete(p2)

        assert_raises_message(
            exc.SAWarning,
            r"DELETE statement on table 'parent' expected to "
            r"delete 2 row\(s\); 0 were matched.",
            sess.flush,
        )

    def test_delete_multi_missing_allow(self):
        # same scenario as test_delete_multi_missing_warning, but with
        # confirm_deleted_rows=False no warning is expected
        Parent, Child = self._fixture(confirm_deleted_rows=False)
        sess = Session()
        p1 = Parent(id=1, data=2, child=None)
        p2 = Parent(id=2, data=3, child=None)
        sess.add_all([p1, p2])
        sess.flush()

        sess.execute(self.tables.parent.delete())
        sess.delete(p1)
        sess.delete(p2)

        sess.flush()
1706
1707
class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
    """How the unit of work batches INSERT statements."""

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "t",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
            Column("def_", String(50), server_default="def1"),
        )

    def test_batch_interaction(self):
        """Consecutive rows with an explicit primary key and the same
        INSERT shape are sent together.

        Rows without a primary key are INSERTed one at a time.  A SQL
        expression value or a different column set starts a new statement.

        """
        table = self.tables.t

        class T(fixtures.ComparableEntity):
            pass

        mapper(T, table)
        session = Session()

        def plain(*idents):
            return [T(id=ident, data="t%d" % ident) for ident in idents]

        objects = [T(data="t1"), T(data="t2")]
        objects.extend(plain(3, 4, 5))
        objects.append(T(id=6, data=func.lower("t6")))
        objects.extend(plain(7, 8))
        objects.append(T(id=9, data="t9", def_="def2"))
        objects.append(T(id=10, data="t10", def_="def3"))
        objects.extend(plain(11))
        session.add_all(objects)

        data_only = "INSERT INTO t (data) VALUES (:data)"
        with_id = "INSERT INTO t (id, data) VALUES (:id, :data)"

        self.assert_sql_execution(
            testing.db,
            session.flush,
            # no primary key supplied: one statement per row
            CompiledSQL(data_only, {"data": "t1"}),
            CompiledSQL(data_only, {"data": "t2"}),
            CompiledSQL(
                with_id,
                [{"data": "t%d" % i, "id": i} for i in (3, 4, 5)],
            ),
            # the SQL expression value gets its own statement
            CompiledSQL(
                "INSERT INTO t (id, data) VALUES (:id, lower(:lower_1))",
                {"lower_1": "t6", "id": 6},
            ),
            CompiledSQL(
                with_id, [{"data": "t%d" % i, "id": i} for i in (7, 8)]
            ),
            # rows that also set def_ have a different column set
            CompiledSQL(
                "INSERT INTO t (id, data, def_) VALUES (:id, :data, :def_)",
                [
                    {"data": "t9", "id": 9, "def_": "def2"},
                    {"data": "t10", "id": 10, "def_": "def3"},
                ],
            ),
            CompiledSQL(with_id, {"data": "t11", "id": 11}),
        )
1783
1784
1785class LoadersUsingCommittedTest(UOWTest):
1786
1787    """Test that events which occur within a flush()
1788    get the same attribute loading behavior as on the outside
1789    of the flush, and that the unit of work itself uses the
1790    "committed" version of primary/foreign key attributes
1791    when loading a collection for historical purposes (this typically
1792    has importance for when primary key values change).
1793
1794    """
1795
1796    def _mapper_setup(self, passive_updates=True):
1797        users, Address, addresses, User = (
1798            self.tables.users,
1799            self.classes.Address,
1800            self.tables.addresses,
1801            self.classes.User,
1802        )
1803
1804        mapper(
1805            User,
1806            users,
1807            properties={
1808                "addresses": relationship(
1809                    Address,
1810                    order_by=addresses.c.email_address,
1811                    passive_updates=passive_updates,
1812                    backref="user",
1813                )
1814            },
1815        )
1816        mapper(Address, addresses)
1817        return create_session(autocommit=False)
1818
1819    def test_before_update_m2o(self):
1820        """Expect normal many to one attribute load behavior
1821        (should not get committed value)
1822        from within public 'before_update' event"""
1823        sess = self._mapper_setup()
1824
1825        Address, User = self.classes.Address, self.classes.User
1826
1827        def before_update(mapper, connection, target):
1828            # if get committed is used to find target.user, then
1829            # it will be still be u1 instead of u2
1830            assert target.user.id == target.user_id == u2.id
1831
1832        from sqlalchemy import event
1833
1834        event.listen(Address, "before_update", before_update)
1835
1836        a1 = Address(email_address="a1")
1837        u1 = User(name="u1", addresses=[a1])
1838        sess.add(u1)
1839
1840        u2 = User(name="u2")
1841        sess.add(u2)
1842        sess.commit()
1843
1844        sess.expunge_all()
1845        # lookup an address and move it to the other user
1846        a1 = sess.query(Address).get(a1.id)
1847
1848        # move address to another user's fk
1849        assert a1.user_id == u1.id
1850        a1.user_id = u2.id
1851
1852        sess.flush()
1853
1854    def test_before_update_o2m_passive(self):
1855        """Expect normal one to many attribute load behavior
1856        (should not get committed value)
1857        from within public 'before_update' event"""
1858        self._test_before_update_o2m(True)
1859
1860    def test_before_update_o2m_notpassive(self):
1861        """Expect normal one to many attribute load behavior
1862        (should not get committed value)
1863        from within public 'before_update' event with
1864        passive_updates=False
1865
1866        """
1867        self._test_before_update_o2m(False)
1868
1869    def _test_before_update_o2m(self, passive_updates):
1870        sess = self._mapper_setup(passive_updates=passive_updates)
1871
1872        Address, User = self.classes.Address, self.classes.User
1873
1874        class AvoidReferencialError(Exception):
1875
1876            """the test here would require ON UPDATE CASCADE on FKs
1877            for the flush to fully succeed; this exception is used
1878            to cancel the flush before we get that far.
1879
1880            """
1881
1882        def before_update(mapper, connection, target):
1883            if passive_updates:
1884                # we shouldn't be using committed value.
1885                # so, having switched target's primary key,
1886                # we expect no related items in the collection
1887                # since we are using passive_updates
1888                # this is a behavior change since #2350
1889                assert "addresses" not in target.__dict__
1890                eq_(target.addresses, [])
1891            else:
1892                # in contrast with passive_updates=True,
1893                # here we expect the orm to have looked up the addresses
1894                # with the committed value (it needs to in order to
1895                # update the foreign keys).  So we expect addresses
1896                # collection to move with the user,
1897                # (just like they will be after the update)
1898
1899                # collection is already loaded
1900                assert "addresses" in target.__dict__
1901                eq_([a.id for a in target.addresses], [a.id for a in [a1, a2]])
1902            raise AvoidReferencialError()
1903
1904        from sqlalchemy import event
1905
1906        event.listen(User, "before_update", before_update)
1907
1908        a1 = Address(email_address="jack1")
1909        a2 = Address(email_address="jack2")
1910        u1 = User(id=1, name="jack", addresses=[a1, a2])
1911        sess.add(u1)
1912        sess.commit()
1913
1914        sess.expunge_all()
1915        u1 = sess.query(User).get(u1.id)
1916        u1.id = 2
1917        try:
1918            sess.flush()
1919        except AvoidReferencialError:
1920            pass
1921
1922
class NoAttrEventInFlushTest(fixtures.MappedTest):
    """test [ticket:3167].

    Values that the flush itself places on a new object (the generated
    primary key, a Python-side column default, and a server default on an
    ``eager_defaults`` mapper) must not fire attribute "set" events.

    See also RefreshFlushInReturningTest in test/orm/test_events.py which
    tests the positive case for the refresh_flush event, added in
    [ticket:3427].

    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "test",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            # Python-side column default
            Column("prefetch_val", Integer, default=5),
            # server-side default; the mapper below uses eager_defaults=True
            Column("returning_val", Integer, server_default="5"),
        )

    @classmethod
    def setup_classes(cls):
        class Thing(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        Thing = cls.classes.Thing

        mapper(Thing, cls.tables.test, eager_defaults=True)

    def test_no_attr_events_flush(self):
        """Flushing a new Thing populates id, prefetch_val and
        returning_val without invoking any "set" listener."""
        Thing = self.classes.Thing
        mock = Mock()
        # each attribute reports to its own child mock, so a failure
        # identifies exactly which attribute fired
        event.listen(Thing.id, "set", mock.id)
        event.listen(Thing.prefetch_val, "set", mock.prefetch_val)
        event.listen(Thing.returning_val, "set", mock.returning_val)
        t1 = Thing()
        s = Session()
        s.add(t1)
        s.flush()

        # compare the full call list so a failure shows the offending calls
        eq_(mock.mock_calls, [])
        eq_(t1.id, 1)
        eq_(t1.prefetch_val, 5)
        eq_(t1.returning_val, 5)
1972
1973
class EagerDefaultsTest(fixtures.MappedTest):
    """Test the SQL emitted by the flush for ``eager_defaults=True``
    mappers.

    Values the application did not supply are fetched back within the
    flush. These are ``test.foo``'s server default on INSERT,
    ``test2.bar``'s ``server_onupdate`` on UPDATE, and any attribute
    assigned a SQL expression. Dialects with implicit RETURNING use
    RETURNING; other dialects use a follow-up SELECT per affected row.
    Values the application did supply are not fetched again.

    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # "foo" has a server-side default that applies on INSERT
        Table(
            "test",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("foo", Integer, server_default="3"),
        )

        # "foo" has no default at all; "bar" is marked as generated by the
        # server on UPDATE
        Table(
            "test2",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("foo", Integer),
            Column("bar", Integer, server_onupdate=FetchedValue()),
        )

    @classmethod
    def setup_classes(cls):
        class Thing(cls.Basic):
            pass

        class Thing2(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        Thing = cls.classes.Thing

        mapper(Thing, cls.tables.test, eager_defaults=True)

        Thing2 = cls.classes.Thing2

        mapper(Thing2, cls.tables.test2, eager_defaults=True)

    def test_insert_defaults_present(self):
        """When every row supplies "foo", no fetch is needed and both rows
        go out in a single executemany INSERT."""
        Thing = self.classes.Thing
        s = Session()

        t1, t2 = (Thing(id=1, foo=5), Thing(id=2, foo=10))

        s.add_all([t1, t2])

        self.assert_sql_execution(
            testing.db,
            s.flush,
            CompiledSQL(
                "INSERT INTO test (id, foo) VALUES (:id, :foo)",
                [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}],
            ),
        )

        def go():
            eq_(t1.foo, 5)
            eq_(t2.foo, 10)

        # the supplied values remain loaded; reading them emits no SQL
        self.assert_sql_count(testing.db, go, 0)

    def test_insert_defaults_present_as_expr(self):
        """A SQL expression assigned to "foo" is rendered inline, and the
        resulting value is fetched back by RETURNING or a SELECT."""
        Thing = self.classes.Thing
        s = Session()

        t1, t2 = (
            Thing(id=1, foo=text("2 + 5")),
            Thing(id=2, foo=text("5 + 5")),
        )

        s.add_all([t1, t2])

        if testing.db.dialect.implicit_returning:
            self.assert_sql_execution(
                testing.db,
                s.flush,
                CompiledSQL(
                    "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) "
                    "RETURNING test.foo",
                    [{"id": 1}],
                    dialect="postgresql",
                ),
                CompiledSQL(
                    "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) "
                    "RETURNING test.foo",
                    [{"id": 2}],
                    dialect="postgresql",
                ),
            )

        else:
            # no RETURNING: one INSERT per row, then one SELECT per row
            self.assert_sql_execution(
                testing.db,
                s.flush,
                CompiledSQL(
                    "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)",
                    [{"id": 1}],
                ),
                CompiledSQL(
                    "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)",
                    [{"id": 2}],
                ),
                CompiledSQL(
                    "SELECT test.foo AS test_foo FROM test "
                    "WHERE test.id = :param_1",
                    [{"param_1": 1}],
                ),
                CompiledSQL(
                    "SELECT test.foo AS test_foo FROM test "
                    "WHERE test.id = :param_1",
                    [{"param_1": 2}],
                ),
            )

        def go():
            eq_(t1.foo, 7)
            eq_(t2.foo, 10)

        # the fetched values are already loaded; reading them emits no SQL
        self.assert_sql_count(testing.db, go, 0)

    def test_insert_defaults_nonpresent(self):
        """Omitting "foo" causes its server default to be fetched back
        during the commit."""
        Thing = self.classes.Thing
        s = Session()

        t1, t2 = (Thing(id=1), Thing(id=2))

        s.add_all([t1, t2])

        if testing.db.dialect.implicit_returning:
            self.assert_sql_execution(
                testing.db,
                s.commit,
                CompiledSQL(
                    "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
                    [{"id": 1}],
                    dialect="postgresql",
                ),
                CompiledSQL(
                    "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
                    [{"id": 2}],
                    dialect="postgresql",
                ),
            )
        else:
            # no RETURNING: a single executemany INSERT, then one SELECT
            # per row
            self.assert_sql_execution(
                testing.db,
                s.commit,
                CompiledSQL(
                    "INSERT INTO test (id) VALUES (:id)",
                    [{"id": 1}, {"id": 2}],
                ),
                CompiledSQL(
                    "SELECT test.foo AS test_foo FROM test "
                    "WHERE test.id = :param_1",
                    [{"param_1": 1}],
                ),
                CompiledSQL(
                    "SELECT test.foo AS test_foo FROM test "
                    "WHERE test.id = :param_1",
                    [{"param_1": 2}],
                ),
            )

    def test_update_defaults_nonpresent(self):
        """Rows whose UPDATE omits "bar" (server_onupdate) have it
        fetched back; rows that set "bar" explicitly do not."""
        Thing2 = self.classes.Thing2
        s = Session()

        t1, t2, t3, t4 = (
            Thing2(id=1, foo=1, bar=2),
            Thing2(id=2, foo=2, bar=3),
            Thing2(id=3, foo=3, bar=4),
            Thing2(id=4, foo=4, bar=5),
        )

        s.add_all([t1, t2, t3, t4])
        s.flush()

        # t1 / t3 change only "foo"; t2 / t4 also set "bar" explicitly
        t1.foo = 5
        t2.foo = 6
        t2.bar = 10
        t3.foo = 7
        t4.foo = 8
        t4.bar = 12

        if testing.db.dialect.implicit_returning:
            self.assert_sql_execution(
                testing.db,
                s.flush,
                CompiledSQL(
                    "UPDATE test2 SET foo=%(foo)s "
                    "WHERE test2.id = %(test2_id)s "
                    "RETURNING test2.bar",
                    [{"foo": 5, "test2_id": 1}],
                    dialect="postgresql",
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
                    "WHERE test2.id = %(test2_id)s",
                    [{"foo": 6, "bar": 10, "test2_id": 2}],
                    dialect="postgresql",
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=%(foo)s "
                    "WHERE test2.id = %(test2_id)s "
                    "RETURNING test2.bar",
                    [{"foo": 7, "test2_id": 3}],
                    dialect="postgresql",
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
                    "WHERE test2.id = %(test2_id)s",
                    [{"foo": 8, "bar": 12, "test2_id": 4}],
                    dialect="postgresql",
                ),
            )
        else:
            # no RETURNING: all UPDATEs first, then a SELECT of "bar" for
            # each row that did not set it (ids 1 and 3)
            self.assert_sql_execution(
                testing.db,
                s.flush,
                CompiledSQL(
                    "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
                    [{"foo": 5, "test2_id": 1}],
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=:foo, bar=:bar "
                    "WHERE test2.id = :test2_id",
                    [{"foo": 6, "bar": 10, "test2_id": 2}],
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
                    [{"foo": 7, "test2_id": 3}],
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=:foo, bar=:bar "
                    "WHERE test2.id = :test2_id",
                    [{"foo": 8, "bar": 12, "test2_id": 4}],
                ),
                CompiledSQL(
                    "SELECT test2.bar AS test2_bar FROM test2 "
                    "WHERE test2.id = :param_1",
                    [{"param_1": 1}],
                ),
                CompiledSQL(
                    "SELECT test2.bar AS test2_bar FROM test2 "
                    "WHERE test2.id = :param_1",
                    [{"param_1": 3}],
                ),
            )

        def go():
            eq_(t1.bar, 2)
            eq_(t2.bar, 10)
            eq_(t3.bar, 4)
            eq_(t4.bar, 12)

        # every "bar" is loaded after the flush; reading them emits no SQL
        self.assert_sql_count(testing.db, go, 0)

    def test_update_defaults_present_as_expr(self):
        """As test_update_defaults_nonpresent, except that t1 and t4 set
        "bar" to a SQL expression, which also requires fetching it back."""
        Thing2 = self.classes.Thing2
        s = Session()

        t1, t2, t3, t4 = (
            Thing2(id=1, foo=1, bar=2),
            Thing2(id=2, foo=2, bar=3),
            Thing2(id=3, foo=3, bar=4),
            Thing2(id=4, foo=4, bar=5),
        )

        s.add_all([t1, t2, t3, t4])
        s.flush()

        t1.foo = 5
        t1.bar = text("1 + 1")
        t2.foo = 6
        t2.bar = 10
        t3.foo = 7
        t4.foo = 8
        t4.bar = text("5 + 7")

        if testing.db.dialect.implicit_returning:
            self.assert_sql_execution(
                testing.db,
                s.flush,
                CompiledSQL(
                    "UPDATE test2 SET foo=%(foo)s, bar=1 + 1 "
                    "WHERE test2.id = %(test2_id)s "
                    "RETURNING test2.bar",
                    [{"foo": 5, "test2_id": 1}],
                    dialect="postgresql",
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
                    "WHERE test2.id = %(test2_id)s",
                    [{"foo": 6, "bar": 10, "test2_id": 2}],
                    dialect="postgresql",
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=%(foo)s "
                    "WHERE test2.id = %(test2_id)s "
                    "RETURNING test2.bar",
                    [{"foo": 7, "test2_id": 3}],
                    dialect="postgresql",
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=%(foo)s, bar=5 + 7 "
                    "WHERE test2.id = %(test2_id)s RETURNING test2.bar",
                    [{"foo": 8, "test2_id": 4}],
                    dialect="postgresql",
                ),
            )
        else:
            # no RETURNING: all UPDATEs first, then a SELECT of "bar" for
            # every row that did not supply a plain value (ids 1, 3 and 4)
            self.assert_sql_execution(
                testing.db,
                s.flush,
                CompiledSQL(
                    "UPDATE test2 SET foo=:foo, bar=1 + 1 "
                    "WHERE test2.id = :test2_id",
                    [{"foo": 5, "test2_id": 1}],
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=:foo, bar=:bar "
                    "WHERE test2.id = :test2_id",
                    [{"foo": 6, "bar": 10, "test2_id": 2}],
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
                    [{"foo": 7, "test2_id": 3}],
                ),
                CompiledSQL(
                    "UPDATE test2 SET foo=:foo, bar=5 + 7 "
                    "WHERE test2.id = :test2_id",
                    [{"foo": 8, "test2_id": 4}],
                ),
                CompiledSQL(
                    "SELECT test2.bar AS test2_bar FROM test2 "
                    "WHERE test2.id = :param_1",
                    [{"param_1": 1}],
                ),
                CompiledSQL(
                    "SELECT test2.bar AS test2_bar FROM test2 "
                    "WHERE test2.id = :param_1",
                    [{"param_1": 3}],
                ),
                CompiledSQL(
                    "SELECT test2.bar AS test2_bar FROM test2 "
                    "WHERE test2.id = :param_1",
                    [{"param_1": 4}],
                ),
            )

        def go():
            eq_(t1.bar, 2)
            eq_(t2.bar, 10)
            eq_(t3.bar, 4)
            eq_(t4.bar, 12)

        # every "bar" is loaded after the flush; reading them emits no SQL
        self.assert_sql_count(testing.db, go, 0)

    def test_insert_defaults_bulk_insert(self):
        """bulk_insert_mappings() emits only the INSERT; the expected
        execution contains no RETURNING or SELECT of the server
        default."""
        Thing = self.classes.Thing
        s = Session()

        mappings = [{"id": 1}, {"id": 2}]

        self.assert_sql_execution(
            testing.db,
            lambda: s.bulk_insert_mappings(Thing, mappings),
            CompiledSQL(
                "INSERT INTO test (id) VALUES (:id)", [{"id": 1}, {"id": 2}]
            ),
        )

    def test_update_defaults_bulk_update(self):
        """bulk_update_mappings() emits only UPDATEs without fetching
        "bar". Rows that share the same set of keys are batched into a
        single executemany."""
        Thing2 = self.classes.Thing2
        s = Session()

        t1, t2, t3, t4 = (
            Thing2(id=1, foo=1, bar=2),
            Thing2(id=2, foo=2, bar=3),
            Thing2(id=3, foo=3, bar=4),
            Thing2(id=4, foo=4, bar=5),
        )

        s.add_all([t1, t2, t3, t4])
        s.flush()

        mappings = [
            {"id": 1, "foo": 5},
            {"id": 2, "foo": 6, "bar": 10},
            {"id": 3, "foo": 7},
            {"id": 4, "foo": 8},
        ]

        self.assert_sql_execution(
            testing.db,
            lambda: s.bulk_update_mappings(Thing2, mappings),
            CompiledSQL(
                "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
                [{"foo": 5, "test2_id": 1}],
            ),
            CompiledSQL(
                "UPDATE test2 SET foo=:foo, bar=:bar "
                "WHERE test2.id = :test2_id",
                [{"foo": 6, "bar": 10, "test2_id": 2}],
            ),
            CompiledSQL(
                "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
                [{"foo": 7, "test2_id": 3}, {"foo": 8, "test2_id": 4}],
            ),
        )

    def test_update_defaults_present(self):
        """When every UPDATE supplies "bar", nothing is fetched back and
        both rows go out in one executemany."""
        Thing2 = self.classes.Thing2
        s = Session()

        t1, t2 = (Thing2(id=1, foo=1, bar=2), Thing2(id=2, foo=2, bar=3))

        s.add_all([t1, t2])
        s.flush()

        t1.bar = 5
        t2.bar = 10

        self.assert_sql_execution(
            testing.db,
            s.commit,
            CompiledSQL(
                "UPDATE test2 SET bar=%(bar)s WHERE test2.id = %(test2_id)s",
                [{"bar": 5, "test2_id": 1}, {"bar": 10, "test2_id": 2}],
                dialect="postgresql",
            ),
        )

    def test_insert_dont_fetch_nondefaults(self):
        """"foo" has no default of any kind, so the INSERT sends it as
        None and nothing is fetched back."""
        Thing2 = self.classes.Thing2
        s = Session()

        t1 = Thing2(id=1, bar=2)

        s.add(t1)

        self.assert_sql_execution(
            testing.db,
            s.flush,
            CompiledSQL(
                "INSERT INTO test2 (id, foo, bar) " "VALUES (:id, :foo, :bar)",
                [{"id": 1, "foo": None, "bar": 2}],
            ),
        )

    def test_update_dont_fetch_nondefaults(self):
        """An expired "foo" without any server-generated value is not
        fetched back by the UPDATE."""
        Thing2 = self.classes.Thing2
        s = Session()

        t1 = Thing2(id=1, bar=2)

        s.add(t1)
        s.flush()

        s.expire(t1, ["foo"])

        t1.bar = 3

        self.assert_sql_execution(
            testing.db,
            s.flush,
            CompiledSQL(
                "UPDATE test2 SET bar=:bar WHERE test2.id = :test2_id",
                [{"bar": 3, "test2_id": 1}],
            ),
        )
2445
2446
class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults):
    """test support for custom datatypes that return a non-__bool__ value
    when compared via __eq__(), eg. ticket 3469"""

    @classmethod
    def define_tables(cls, metadata):
        from sqlalchemy import TypeDecorator

        class NoBool(object):
            """Comparison result that refuses truth-testing, so any code
            path coercing ``MyWidget == other`` to a boolean fails loudly."""

            def __nonzero__(self):
                raise NotImplementedError("not supported")

            # Python 3 consults __bool__ rather than __nonzero__; without
            # this alias NoBool() is simply truthy there and the test
            # would not exercise anything.
            __bool__ = __nonzero__

        class MyWidget(object):
            """Value object whose ``==`` always yields a NoBool."""

            def __init__(self, text):
                self.text = text

            def __eq__(self, other):
                return NoBool()

        cls.MyWidget = MyWidget

        class MyType(TypeDecorator):
            """Stores a MyWidget as its ``text`` in a String(50) column."""

            impl = String(50)

            def process_bind_param(self, value, dialect):
                if value is not None:
                    value = value.text
                return value

            def process_result_value(self, value, dialect):
                if value is not None:
                    value = MyWidget(value)
                return value

        Table(
            "test",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("value", MyType),
            Column("unrelated", String(50)),
        )

    @classmethod
    def setup_classes(cls):
        class Thing(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        Thing = cls.classes.Thing

        mapper(Thing, cls.tables.test)

    def test_update_against_none(self):
        """Replacing a widget with None persists NULL."""
        Thing = self.classes.Thing

        s = Session()
        s.add(Thing(value=self.MyWidget("foo")))
        s.commit()

        t1 = s.query(Thing).first()
        t1.value = None
        s.commit()

        eq_(s.query(Thing.value).scalar(), None)

    def test_update_against_something_else(self):
        """Replacing a widget with another widget persists the new one."""
        Thing = self.classes.Thing

        s = Session()
        s.add(Thing(value=self.MyWidget("foo")))
        s.commit()

        t1 = s.query(Thing).first()
        t1.value = self.MyWidget("bar")
        s.commit()

        eq_(s.query(Thing.value).scalar().text, "bar")

    def test_no_update_no_change(self):
        """Changing an unrelated column leaves "value" out of the UPDATE."""
        Thing = self.classes.Thing

        s = Session()
        s.add(Thing(value=self.MyWidget("foo"), unrelated="unrelated"))
        s.commit()

        t1 = s.query(Thing).first()
        t1.unrelated = "something else"

        self.assert_sql_execution(
            testing.db,
            s.commit,
            CompiledSQL(
                "UPDATE test SET unrelated=:unrelated "
                "WHERE test.id = :test_id",
                [{"test_id": 1, "unrelated": "something else"}],
            ),
        )

        eq_(s.query(Thing.value).scalar().text, "foo")
2549
2550
class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults):
    """Check what gets stored when a column receives ``None`` (or no value
    at all), across types that do / do not evaluate None and columns with /
    without a Python-side default, via both the unit of work and the bulk
    insert API."""

    @classmethod
    def define_tables(cls, metadata):
        from sqlalchemy import TypeDecorator

        class EvalsNull(TypeDecorator):
            impl = String(50)

            # an explicitly assigned None reaches process_bind_param
            should_evaluate_none = True

            def process_bind_param(self, value, dialect):
                return "nothing" if value is None else value

        def null_eval_columns():
            # fresh Column / type objects on every call, since a Column
            # can belong to only one Table
            return [
                Column(
                    "id",
                    Integer,
                    primary_key=True,
                    test_needs_autoincrement=True,
                ),
                Column("evals_null_no_default", EvalsNull()),
                Column(
                    "evals_null_default", EvalsNull(), default="default_val"
                ),
                Column("no_eval_null_no_default", String(50)),
                Column(
                    "no_eval_null_default", String(50), default="default_val"
                ),
                Column(
                    "builtin_evals_null_no_default",
                    String(50).evaluates_none(),
                ),
                Column(
                    "builtin_evals_null_default",
                    String(50).evaluates_none(),
                    default="default_val",
                ),
            ]

        # two structurally identical tables; the second is mapped with a
        # column prefix
        for table_name in ("test", "test_w_renames"):
            Table(table_name, metadata, *null_eval_columns())

        if testing.requires.json_type.enabled:
            Table(
                "test_has_json",
                metadata,
                Column(
                    "id",
                    Integer,
                    primary_key=True,
                    test_needs_autoincrement=True,
                ),
                Column("data", JSON(none_as_null=True).evaluates_none()),
                Column("data_null", JSON(none_as_null=True)),
            )

    @classmethod
    def setup_classes(cls):
        class Thing(cls.Basic):
            pass

        class AltNameThing(cls.Basic):
            pass

        class JSONThing(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        classes, tables = cls.classes, cls.tables

        mapper(classes.Thing, tables.test)

        # every AltNameThing attribute is named "_foo_" + <column name>
        mapper(
            classes.AltNameThing, tables.test_w_renames, column_prefix="_foo_"
        )

        if testing.requires.json_type.enabled:
            mapper(classes.JSONThing, tables.test_has_json)

    def _assert_col(self, name, value):
        """Assert that column ``name`` holds ``value`` in exactly one row of
        both mapped tables."""
        s = Session()

        for entity, attr_name in (
            (self.classes.Thing, name),
            (self.classes.AltNameThing, "_foo_" + name),
        ):
            col = getattr(entity, attr_name)
            row = s.query(col).filter(col == value).one()
            eq_(row[0], value)

    def _test_insert(self, attr, expected):
        """Unit-of-work INSERT with ``attr`` explicitly set to None."""
        Thing, AltNameThing = self.classes.Thing, self.classes.AltNameThing

        s = Session()
        s.add_all(
            [Thing(**{attr: None}), AltNameThing(**{"_foo_" + attr: None})]
        )
        s.commit()

        self._assert_col(attr, expected)

    def _test_bulk_insert(self, attr, expected):
        """bulk_insert_mappings() with ``attr`` explicitly set to None."""
        s = Session()
        for entity, key in (
            (self.classes.Thing, attr),
            (self.classes.AltNameThing, "_foo_" + attr),
        ):
            s.bulk_insert_mappings(entity, [{key: None}])
        s.commit()

        self._assert_col(attr, expected)

    def _test_insert_novalue(self, attr, expected):
        """Unit-of-work INSERT that never assigns ``attr``."""
        Thing, AltNameThing = self.classes.Thing, self.classes.AltNameThing

        s = Session()
        s.add_all([Thing(), AltNameThing()])
        s.commit()

        self._assert_col(attr, expected)

    def _test_bulk_insert_novalue(self, attr, expected):
        """bulk_insert_mappings() with an empty mapping."""
        s = Session()
        for entity in (self.classes.Thing, self.classes.AltNameThing):
            s.bulk_insert_mappings(entity, [{}])
        s.commit()

        self._assert_col(attr, expected)

    # --- custom type with should_evaluate_none, no default ---

    def test_evalnull_nodefault_insert(self):
        self._test_insert("evals_null_no_default", "nothing")

    def test_evalnull_nodefault_bulk_insert(self):
        self._test_bulk_insert("evals_null_no_default", "nothing")

    def test_evalnull_nodefault_insert_novalue(self):
        self._test_insert_novalue("evals_null_no_default", None)

    def test_evalnull_nodefault_bulk_insert_novalue(self):
        self._test_bulk_insert_novalue("evals_null_no_default", None)

    # --- custom type with should_evaluate_none, with default ---

    def test_evalnull_default_insert(self):
        self._test_insert("evals_null_default", "nothing")

    def test_evalnull_default_bulk_insert(self):
        self._test_bulk_insert("evals_null_default", "nothing")

    def test_evalnull_default_insert_novalue(self):
        self._test_insert_novalue("evals_null_default", "default_val")

    def test_evalnull_default_bulk_insert_novalue(self):
        self._test_bulk_insert_novalue("evals_null_default", "default_val")

    # --- plain String, no default ---

    def test_no_evalnull_nodefault_insert(self):
        self._test_insert("no_eval_null_no_default", None)

    def test_no_evalnull_nodefault_bulk_insert(self):
        self._test_bulk_insert("no_eval_null_no_default", None)

    def test_no_evalnull_nodefault_insert_novalue(self):
        self._test_insert_novalue("no_eval_null_no_default", None)

    def test_no_evalnull_nodefault_bulk_insert_novalue(self):
        self._test_bulk_insert_novalue("no_eval_null_no_default", None)

    # --- plain String, with default ---

    def test_no_evalnull_default_insert(self):
        self._test_insert("no_eval_null_default", "default_val")

    def test_no_evalnull_default_bulk_insert(self):
        self._test_bulk_insert("no_eval_null_default", "default_val")

    def test_no_evalnull_default_insert_novalue(self):
        self._test_insert_novalue("no_eval_null_default", "default_val")

    def test_no_evalnull_default_bulk_insert_novalue(self):
        self._test_bulk_insert_novalue("no_eval_null_default", "default_val")

    # --- String(50).evaluates_none(), no default ---

    def test_builtin_evalnull_nodefault_insert(self):
        self._test_insert("builtin_evals_null_no_default", None)

    def test_builtin_evalnull_nodefault_bulk_insert(self):
        self._test_bulk_insert("builtin_evals_null_no_default", None)

    def test_builtin_evalnull_nodefault_insert_novalue(self):
        self._test_insert_novalue("builtin_evals_null_no_default", None)

    def test_builtin_evalnull_nodefault_bulk_insert_novalue(self):
        self._test_bulk_insert_novalue("builtin_evals_null_no_default", None)

    # --- String(50).evaluates_none(), with default ---

    def test_builtin_evalnull_default_insert(self):
        self._test_insert("builtin_evals_null_default", None)

    def test_builtin_evalnull_default_bulk_insert(self):
        self._test_bulk_insert("builtin_evals_null_default", None)

    def test_builtin_evalnull_default_insert_novalue(self):
        self._test_insert_novalue("builtin_evals_null_default", "default_val")

    def test_builtin_evalnull_default_bulk_insert_novalue(self):
        self._test_bulk_insert_novalue(
            "builtin_evals_null_default", "default_val"
        )

    @testing.requires.json_type
    def test_json_none_as_null(self):
        JSONThing = self.classes.JSONThing

        s = Session()
        s.add(JSONThing(data=None, data_null=None))
        s.commit()

        # "data" evaluates None and is stored as the JSON 'null' value;
        # "data_null" does not, and ends up as SQL NULL
        eq_(s.query(cast(JSONThing.data, String)).scalar(), "null")
        eq_(s.query(cast(JSONThing.data_null, String)).scalar(), None)
2787