1import contextlib
2import datetime
3import uuid
4
5import sqlalchemy as sa
6from sqlalchemy import Date
7from sqlalchemy import exc
8from sqlalchemy import ForeignKey
9from sqlalchemy import inspect
10from sqlalchemy import Integer
11from sqlalchemy import orm
12from sqlalchemy import select
13from sqlalchemy import String
14from sqlalchemy import testing
15from sqlalchemy import TypeDecorator
16from sqlalchemy import util
17from sqlalchemy.orm import configure_mappers
18from sqlalchemy.orm import exc as orm_exc
19from sqlalchemy.orm import relationship
20from sqlalchemy.orm import Session
21from sqlalchemy.testing import assert_raises
22from sqlalchemy.testing import assert_raises_message
23from sqlalchemy.testing import config
24from sqlalchemy.testing import engines
25from sqlalchemy.testing import eq_
26from sqlalchemy.testing import expect_warnings
27from sqlalchemy.testing import fixtures
28from sqlalchemy.testing import is_false
29from sqlalchemy.testing import is_true
30from sqlalchemy.testing.assertsql import CompiledSQL
31from sqlalchemy.testing.fixtures import fixture_session
32from sqlalchemy.testing.mock import patch
33from sqlalchemy.testing.schema import Column
34from sqlalchemy.testing.schema import Table
35
36
def make_uuid():
    """Return a fresh random UUID4 rendered as 32 lowercase hex digits.

    Used below as a string-valued ``version_id_generator``.
    """
    token = uuid.uuid4()
    return token.hex
39
40
@contextlib.contextmanager
def conditional_sane_rowcount_warnings(
    update=False, delete=False, only_returning=False
):
    """Expect "versioning cannot be verified" warnings only on dialects
    that lack the relevant rowcount support.

    :param update: expect the "does not support updated rowcount" warning.
    :param delete: expect the "does not support deleted rowcount" warning.
    :param only_returning: decide based on
        ``supports_sane_rowcount_returning`` instead of
        ``supports_sane_rowcount``.

    When the consulted dialect flag is true, the body runs with no
    warning expectations at all; otherwise it runs inside
    ``expect_warnings`` with whichever messages were requested.
    """
    dialect = testing.db.dialect
    if only_returning:
        rowcount_is_sane = dialect.supports_sane_rowcount_returning
    else:
        rowcount_is_sane = dialect.supports_sane_rowcount

    if rowcount_is_sane:
        yield
        return

    expected = []
    if update:
        expected.append(
            "Dialect .* does not support updated rowcount - "
            "versioning cannot be verified."
        )
    if delete:
        expected.append(
            "Dialect .* does not support deleted rowcount - "
            "versioning cannot be verified."
        )

    with expect_warnings(*expected):
        yield
67
68
class NullVersionIdTest(fixtures.MappedTest):
    """Versioned mapping with ``version_id_generator=False``.

    The test code supplies version values itself, and a NULL version must
    be rejected with a FlushError on both INSERT and UPDATE (#3673).
    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # version_id has no NOT NULL constraint here, so any rejection of
        # a NULL version has to come from the ORM's flush-time check
        # rather than from the database.
        Table(
            "version_table",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("version_id", Integer),
            Column("value", String(40), nullable=False),
        )

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

    def _fixture(self):
        """Map Foo with a manually managed version column and return a
        new session."""
        table = self.tables.version_table

        self.mapper_registry.map_imperatively(
            self.classes.Foo,
            table,
            version_id_col=table.c.version_id,
            version_id_generator=False,
        )

        return fixture_session()

    def test_null_version_id_insert(self):
        Foo = self.classes.Foo

        session = self._fixture()
        pending = Foo(value="f1")
        session.add(pending)

        # Before the fix for #3673 this row could be INSERTed with a NULL
        # version_id, and only a later UPDATE would fail, with:
        #
        # A value is required for bind parameter 'version_table_version_id'
        # UPDATE version_table SET value=?
        #    WHERE version_table.id = ?
        #    AND version_table.version_id = ?
        # parameters: [{'version_table_id': 1, 'value': 'f1rev2'}]]
        #
        # The initial INSERT itself must now raise FlushError.
        assert_raises_message(
            sa.orm.exc.FlushError,
            "Instance does not contain a non-NULL version value",
            session.commit,
        )

    def test_null_version_id_update(self):
        Foo = self.classes.Foo

        session = self._fixture()
        persisted = Foo(value="f1", version_id=1)
        session.add(persisted)
        session.commit()

        # Before the fix for #3673 this UPDATE would succeed and leave
        # Foo(id=1, value='f1rev2', version_id=None) behind after commit.
        # Setting the version to NULL must now raise FlushError.
        persisted.value = "f1rev2"

        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            persisted.version_id = None
            assert_raises_message(
                sa.orm.exc.FlushError,
                "Instance does not contain a non-NULL version value",
                session.commit,
            )
150
151
class VersioningTest(fixtures.MappedTest):
    """Optimistic-concurrency tests for a mapper configured with
    ``version_id_col`` and the default integer version counter."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "version_table",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("version_id", Integer, nullable=False),
            Column("value", String(40), nullable=False),
            test_needs_acid=True,
        )

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

    def _fixture(self):
        """Map Foo with ``version_table.c.version_id`` as its version
        column and return a new session."""
        Foo, version_table = self.classes.Foo, self.tables.version_table

        self.mapper_registry.map_imperatively(
            Foo, version_table, version_id_col=version_table.c.version_id
        )
        s1 = fixture_session()
        return s1

    def _fixture_at_version_two(self):
        """Map Foo, then insert one row and update it once.

        :return: ``(session, foo)`` where ``foo``'s committed version_id
            is 2.
        """
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value="f1")
        s1.add(f1)
        s1.commit()

        f1.value = "f2"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.commit()
        return s1, f1

    @engines.close_open_connections
    def test_notsane_warning(self):
        """Committing an UPDATE of a versioned row raises SAWarning when
        the dialect reports no sane rowcount."""
        Foo = self.classes.Foo

        # patch.object restores the original flag on exit, including when
        # the test fails, matching how the other rowcount tests patch it.
        with patch.object(
            testing.db.dialect, "supports_sane_rowcount", False
        ):
            s1 = self._fixture()
            f1 = Foo(value="f1")
            f2 = Foo(value="f2")
            s1.add_all((f1, f2))
            s1.commit()

            f1.value = "f1rev2"
            assert_raises(sa.exc.SAWarning, s1.commit)

    def test_basic(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value="f1")
        f2 = Foo(value="f2")
        s1.add_all((f1, f2))
        s1.commit()

        f1.value = "f1rev2"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.commit()

        # a second session updates the same row, moving it to a version
        # that s1 has not seen
        s2 = fixture_session(autocommit=False)
        f1_s = s2.get(Foo, f1.id)
        f1_s.value = "f1rev3"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s2.commit()

        f1.value = "f1rev3mine"

        # Only dialects with a sane rowcount (with RETURNING) can detect
        # the StaleDataError
        if testing.db.dialect.supports_sane_rowcount_returning:
            assert_raises_message(
                sa.orm.exc.StaleDataError,
                r"UPDATE statement on table 'version_table' expected "
                r"to update 1 row\(s\); 0 were matched.",
                s1.commit,
            )
            s1.rollback()
        else:
            with conditional_sane_rowcount_warnings(
                update=True, only_returning=True
            ):
                s1.commit()

        # s1 is still usable here without an explicit close()
        f1 = s1.get(Foo, f1.id)
        f2 = s1.get(Foo, f2.id)

        f1_s.value = "f1rev4"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s2.commit()

        # s2 just bumped f1's version again, so a DELETE of both rows
        # from s1 matches only f2
        s1.delete(f1)
        s1.delete(f2)

        if testing.db.dialect.supports_sane_multi_rowcount:
            assert_raises_message(
                sa.orm.exc.StaleDataError,
                r"DELETE statement on table 'version_table' expected "
                r"to delete 2 row\(s\); 1 were matched.",
                s1.commit,
            )
        else:
            with conditional_sane_rowcount_warnings(delete=True):
                s1.commit()

    def test_multiple_updates(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value="f1")
        f2 = Foo(value="f2")
        s1.add_all((f1, f2))
        s1.commit()

        f1.value = "f1rev2"
        f2.value = "f2rev2"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.commit()

        eq_(
            s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(),
            [(f1.id, "f1rev2", 2), (f2.id, "f2rev2", 2)],
        )

    def test_bulk_insert(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        s1.bulk_insert_mappings(
            Foo, [{"id": 1, "value": "f1"}, {"id": 2, "value": "f2"}]
        )
        eq_(
            s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(),
            [(1, "f1", 1), (2, "f2", 1)],
        )

    def test_bulk_update(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value="f1")
        f2 = Foo(value="f2")
        s1.add_all((f1, f2))
        s1.commit()

        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.bulk_update_mappings(
                Foo,
                [
                    {"id": f1.id, "value": "f1rev2", "version_id": 1},
                    {"id": f2.id, "value": "f2rev2", "version_id": 1},
                ],
            )
        s1.commit()

        eq_(
            s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(),
            [(f1.id, "f1rev2", 2), (f2.id, "f2rev2", 2)],
        )

    def test_bump_version(self):
        """test that version number can be bumped.

        Ensures that the UPDATE or DELETE is against the
        last committed version of version_id_col, not the modified
        state.

        """

        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value="f1")
        s1.add(f1)
        s1.commit()
        eq_(f1.version_id, 1)
        f1.version_id = 2
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.commit()
        eq_(f1.version_id, 2)

        # skip an id, test that history
        # is honored
        f1.version_id = 4
        f1.value = "something new"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.commit()
        eq_(f1.version_id, 4)

        f1.version_id = 5
        s1.delete(f1)
        with conditional_sane_rowcount_warnings(delete=True):
            s1.commit()
        eq_(s1.query(Foo).count(), 0)

    @engines.close_open_connections
    def test_versioncheck(self):
        """Loading with ``with_for_update`` performs a 'version check'
        against an already-loaded instance."""

        Foo = self.classes.Foo

        s1 = self._fixture()
        f1s1 = Foo(value="f1 value")
        s1.add(f1s1)
        s1.commit()

        s2 = fixture_session(autocommit=False)
        f1s2 = s2.get(Foo, f1s1.id)
        f1s2.value = "f1 new value"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s2.commit()

        # load, version is wrong
        assert_raises_message(
            sa.orm.exc.StaleDataError,
            r"Instance .* has version id '\d+' which does not "
            r"match database-loaded version id '\d+'",
            s1.get,
            Foo,
            f1s1.id,
            with_for_update=dict(read=True),
        )

        # reload it - this expires the old version first
        s1.refresh(f1s1, with_for_update={"read": True})

        # now assert version OK
        s1.get(Foo, f1s1.id, with_for_update=dict(read=True))

        # assert brand new load is OK too
        s1.close()
        s1.get(Foo, f1s1.id, with_for_update=dict(read=True))

    def test_versioncheck_not_versioned(self):
        """ensure the versioncheck logic skips if there isn't a
        version_id_col actually configured"""

        Foo = self.classes.Foo
        version_table = self.tables.version_table

        self.mapper_registry.map_imperatively(Foo, version_table)
        s1 = fixture_session()
        f1s1 = Foo(value="f1 value", version_id=1)
        s1.add(f1s1)
        s1.commit()
        s1.query(Foo).with_for_update(read=True).where(Foo.id == f1s1.id).one()

    @engines.close_open_connections
    @testing.requires.update_nowait
    def test_versioncheck_for_update(self):
        """``refresh()`` with FOR UPDATE NOWAIT fails while another session
        holds the row lock, then succeeds and sees the new version once
        that session commits."""

        Foo = self.classes.Foo

        s1 = self._fixture()
        f1s1 = Foo(value="f1 value")
        s1.add(f1s1)
        s1.commit()

        s2 = fixture_session(autocommit=False)
        f1s2 = s2.get(Foo, f1s1.id)
        # not sure if I like this API
        s2.refresh(f1s2, with_for_update=True)
        f1s2.value = "f1 new value"

        assert_raises(
            exc.DBAPIError, s1.refresh, f1s1, with_for_update={"nowait": True}
        )
        s1.rollback()

        with conditional_sane_rowcount_warnings(update=True):
            s2.commit()
        s1.refresh(f1s1, with_for_update={"nowait": True})
        assert f1s1.version_id == f1s2.version_id

    def test_update_multi_missing_broken_multi_rowcount(self):
        # Simulate a driver whose rowcount is -1 for executions with more
        # than one parameter set, while single-statement rowcount stays
        # accurate.
        @util.memoized_property
        def rowcount(self):
            if len(self.context.compiled_parameters) > 1:
                return -1
            else:
                return self.context.rowcount

        with patch.object(
            config.db.dialect, "supports_sane_multi_rowcount", False
        ), patch("sqlalchemy.engine.cursor.CursorResult.rowcount", rowcount):

            Foo = self.classes.Foo
            s1 = self._fixture()
            f1s1 = Foo(value="f1 value")
            s1.add(f1s1)
            s1.commit()

            f1s1.value = "f2 value"
            with conditional_sane_rowcount_warnings(
                update=True, only_returning=True
            ):
                s1.flush()
            eq_(f1s1.version_id, 2)

    def test_update_delete_no_plain_rowcount(self):

        with patch.object(
            config.db.dialect, "supports_sane_rowcount", False
        ), patch.object(
            config.db.dialect, "supports_sane_multi_rowcount", False
        ):
            Foo = self.classes.Foo
            s1 = self._fixture()
            f1s1 = Foo(value="f1 value")
            s1.add(f1s1)
            s1.commit()

            f1s1.value = "f2 value"

            with expect_warnings(
                "Dialect .* does not support updated rowcount - "
                "versioning cannot be verified."
            ):
                s1.flush()
            eq_(f1s1.version_id, 2)

            s1.delete(f1s1)
            with expect_warnings(
                "Dialect .* does not support deleted rowcount - "
                "versioning cannot be verified."
            ):
                s1.flush()

    @engines.close_open_connections
    def test_noversioncheck(self):
        """``with_for_update()`` loads work when the mapper has no
        version_id_col."""

        Foo, version_table = self.classes.Foo, self.tables.version_table

        s1 = fixture_session(autocommit=False)
        self.mapper_registry.map_imperatively(Foo, version_table)
        f1s1 = Foo(value="foo", version_id=0)
        s1.add(f1s1)
        s1.commit()

        s2 = fixture_session(autocommit=False)
        f1s2 = (
            s2.query(Foo)
            .with_for_update(read=True)
            .where(Foo.id == f1s1.id)
            .one()
        )
        assert f1s2.id == f1s1.id
        assert f1s2.value == f1s1.value

    def test_merge_no_version(self):
        Foo = self.classes.Foo
        s1, f1 = self._fixture_at_version_two()

        # merging without a version value targets the current version
        f2 = Foo(id=f1.id, value="f3")
        f3 = s1.merge(f2)
        assert f3 is f1
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.commit()
        eq_(f3.version_id, 3)

    def test_merge_correct_version(self):
        Foo = self.classes.Foo
        s1, f1 = self._fixture_at_version_two()

        f2 = Foo(id=f1.id, value="f3", version_id=2)
        f3 = s1.merge(f2)
        assert f3 is f1
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.commit()
        eq_(f3.version_id, 3)

    def test_merge_incorrect_version(self):
        Foo = self.classes.Foo
        s1, f1 = self._fixture_at_version_two()

        f2 = Foo(id=f1.id, value="f3", version_id=1)
        assert_raises_message(
            orm_exc.StaleDataError,
            "Version id '1' on merged state "
            "<Foo at .*?> does not match existing version '2'. "
            "Leave the version attribute unset when "
            "merging to update the most recent version.",
            s1.merge,
            f2,
        )

    def test_merge_incorrect_version_not_in_session(self):
        Foo = self.classes.Foo
        s1, f1 = self._fixture_at_version_two()

        f2 = Foo(id=f1.id, value="f3", version_id=1)
        # merge() must load the existing row from the database
        s1.close()

        assert_raises_message(
            orm_exc.StaleDataError,
            "Version id '1' on merged state "
            "<Foo at .*?> does not match existing version '2'. "
            "Leave the version attribute unset when "
            "merging to update the most recent version.",
            s1.merge,
            f2,
        )
621
622
class VersionOnPostUpdateTest(fixtures.MappedTest):
    """Version counter bumps on a self-referential ``Node`` mapping,
    covering one-to-many and many-to-one, with and without
    ``post_update``."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "node",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("version_id", Integer),
            Column("parent_id", ForeignKey("node.id")),
        )

    @classmethod
    def setup_classes(cls):
        class Node(cls.Basic):
            pass

    def _fixture(self, o2m, post_update, insert=True):
        """Map Node with a self-referential ``related`` relationship.

        :param o2m: when true, ``remote_side`` is ``node.parent_id``
            (``related`` is a collection); otherwise it is ``node.id``
            (``related`` is a scalar).
        :param post_update: passed straight through to ``relationship()``.
        :param insert: when true, both nodes are added and flushed before
            returning.
        :return: ``(session, n1, n2)`` holding Nodes with ids 1 and 2.
        """
        Node = self.classes.Node
        node = self.tables.node

        remote = node.c.parent_id if o2m else node.c.id
        self.mapper_registry.map_imperatively(
            Node,
            node,
            properties={
                "related": relationship(
                    Node,
                    remote_side=remote,
                    post_update=post_update,
                )
            },
            version_id_col=node.c.version_id,
        )

        session = fixture_session()
        first = Node(id=1)
        second = Node(id=2)

        if insert:
            session.add_all([first, second])
            session.flush()
        return session, first, second

    def _flush_with_update_warnings(self, session):
        """Flush ``session`` inside
        ``conditional_sane_rowcount_warnings(update=True,
        only_returning=True)``."""
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            session.flush()

    def test_o2m_plain(self):
        session, n1, n2 = self._fixture(o2m=True, post_update=False)

        n1.related.append(n2)
        self._flush_with_update_warnings(session)

        # n2 receives the parent_id change, so only n2 is bumped
        eq_((n1.version_id, n2.version_id), (1, 2))

    def test_m2o_plain(self):
        session, n1, n2 = self._fixture(o2m=False, post_update=False)

        n1.related = n2
        self._flush_with_update_warnings(session)

        # n1 receives the parent_id change, so only n1 is bumped
        eq_((n1.version_id, n2.version_id), (2, 1))

    def test_o2m_post_update(self):
        session, n1, n2 = self._fixture(o2m=True, post_update=True)

        n1.related.append(n2)
        self._flush_with_update_warnings(session)

        eq_((n1.version_id, n2.version_id), (1, 2))

    def test_m2o_post_update(self):
        session, n1, n2 = self._fixture(o2m=False, post_update=True)

        n1.related = n2
        self._flush_with_update_warnings(session)

        eq_((n1.version_id, n2.version_id), (2, 1))

    def test_o2m_post_update_not_assoc_w_insert(self):
        session, n1, n2 = self._fixture(
            o2m=True, post_update=True, insert=False
        )

        n1.related.append(n2)
        session.add_all([n1, n2])
        self._flush_with_update_warnings(session)

        # both rows are inserted in this same flush; neither moves past 1
        eq_((n1.version_id, n2.version_id), (1, 1))

    def test_m2o_post_update_not_assoc_w_insert(self):
        session, n1, n2 = self._fixture(
            o2m=False, post_update=True, insert=False
        )

        n1.related = n2
        session.add_all([n1, n2])
        self._flush_with_update_warnings(session)

        eq_((n1.version_id, n2.version_id), (1, 1))

    @testing.requires.sane_rowcount_w_returning
    def test_o2m_post_update_version_assert(self):
        Node = self.classes.Node
        session, n1, n2 = self._fixture(o2m=True, post_update=True)

        n1.related.append(n2)

        # A second Session bound to the same connection changes the row
        # behind the first Session's back: it shares the transaction (so
        # isolation does not hide the change) and the first Session's
        # state is not expired by it.
        other = Session(bind=session.connection(mapper=Node))
        other.query(Node).filter(Node.id == n2.id).update({"version_id": 3})
        other.commit()

        assert_raises_message(
            orm_exc.StaleDataError,
            "UPDATE statement on table 'node' expected to "
            r"update 1 row\(s\); 0 were matched.",
            session.flush,
        )

    def test_o2m_post_update_no_sane_rowcount(self):
        Node = self.classes.Node
        session, n1, n2 = self._fixture(o2m=True, post_update=True)

        n1.related.append(n2)

        with patch.object(
            config.db.dialect, "supports_sane_rowcount", False
        ), patch.object(
            config.db.dialect, "supports_sane_multi_rowcount", False
        ):
            other = Session(bind=session.connection(mapper=Node))
            other.query(Node).filter(Node.id == n2.id).update(
                {"version_id": 3}
            )
            other.commit()

            # without rowcount support the stale row can't be detected;
            # only a warning is emitted
            with expect_warnings(
                "Dialect .* does not support updated rowcount - "
                "versioning cannot be verified."
            ):
                session.flush()

    @testing.requires.sane_rowcount_w_returning
    def test_m2o_post_update_version_assert(self):
        Node = self.classes.Node

        session, n1, n2 = self._fixture(o2m=False, post_update=True)

        n1.related = n2

        # Same trick as the o2m variant: a second Session on the same
        # connection bumps n1's version out from under the first one.
        other = Session(bind=session.connection(mapper=Node))
        other.query(Node).filter(Node.id == n1.id).update({"version_id": 3})
        other.commit()

        assert_raises_message(
            orm_exc.StaleDataError,
            "UPDATE statement on table 'node' expected to "
            r"update 1 row\(s\); 0 were matched.",
            session.flush,
        )
804
805
class NoBumpOnRelationshipTest(fixtures.MappedTest):
    """Attaching a new B to a versioned A must leave A's version_id at 1,
    whichever ``version_id_generator`` strategy A is mapped with."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "a",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("version_id", Integer),
        )
        Table(
            "b",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("a_id", ForeignKey("a.id")),
        )

    @classmethod
    def setup_classes(cls):
        class A(cls.Basic):
            pass

        class B(cls.Basic):
            pass

    def _map_versioned_a(self, **version_kw):
        """Map A (versioned on ``a.version_id``, with an ``A.bs`` /
        ``B.a`` relationship pair) and B.

        :param version_kw: extra mapper options for A, such as
            ``version_id_generator``.
        """
        A, B = self.classes("A", "B")
        a, b = self.tables("a", "b")

        self.mapper_registry.map_imperatively(
            A,
            a,
            properties={"bs": relationship(B, backref="a")},
            version_id_col=a.c.version_id,
            **version_kw
        )
        self.mapper_registry.map_imperatively(B, b)

    def _run_test(self, auto_version_counter=True):
        """Insert an A, attach a new B to it, and check A stays at
        version 1.

        :param auto_version_counter: when false, the initial version_id
            is supplied explicitly because the mapper does not generate
            one.
        """
        A, B = self.classes("A", "B")
        session = fixture_session(future=True)

        parent = A() if auto_version_counter else A(version_id=1)
        session.add(parent)
        session.commit()
        eq_(parent.version_id, 1)

        child = B()
        child.a = parent
        session.add(child)
        session.commit()

        eq_(parent.version_id, 1)

    def test_plain_counter(self):
        self._map_versioned_a()
        self._run_test()

    def test_functional_counter(self):
        self._map_versioned_a(
            version_id_generator=lambda num: (num or 0) + 1
        )
        self._run_test()

    def test_no_counter(self):
        self._map_versioned_a(version_id_generator=False)
        self._run_test(auto_version_counter=False)
897
898
class ColumnTypeTest(fixtures.MappedTest):
    """Versioned UPDATE against a table whose primary key is a
    TypeDecorator that asserts every bound value is a ``datetime.date``."""

    __backend__ = True
    __requires__ = ("sane_rowcount",)

    @classmethod
    def define_tables(cls, metadata):
        class SpecialType(TypeDecorator):
            impl = Date
            cache_ok = True

            def process_bind_param(self, value, dialect):
                # Fail loudly if the ORM ever binds something other than
                # a date for this column (e.g. in the version-check WHERE).
                assert isinstance(value, datetime.date)
                return value

        Table(
            "version_table",
            metadata,
            Column("id", SpecialType, primary_key=True),
            Column("version_id", Integer, nullable=False),
            Column("value", String(40), nullable=False),
        )

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

    def _fixture(self):
        """Map Foo with ``version_table.c.version_id`` as its version
        column and return a new session."""
        table = self.tables.version_table

        self.mapper_registry.map_imperatively(
            self.classes.Foo, table, version_id_col=table.c.version_id
        )
        return fixture_session()

    @engines.close_open_connections
    def test_update(self):
        Foo = self.classes.Foo

        session = self._fixture()
        row = Foo(id=datetime.date.today(), value="f1")
        session.add(row)
        session.commit()

        row.value = "f1rev2"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            session.commit()
949
950
class RowSwitchTest(fixtures.MappedTest):
    """Versioned parent/child tables where a DELETE of an object plus an
    INSERT of a replacement with the same primary key happen in one
    flush (a "row switch")."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "p",
            metadata,
            Column("id", String(10), primary_key=True),
            Column("version_id", Integer, default=1, nullable=False),
            Column("data", String(50)),
        )
        Table(
            "c",
            metadata,
            Column("id", String(10), ForeignKey("p.id"), primary_key=True),
            Column("version_id", Integer, default=1, nullable=False),
            Column("data", String(50)),
        )

    @classmethod
    def setup_classes(cls):
        class P(cls.Basic):
            pass

        class C(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        P, C = cls.classes.P, cls.classes.C
        p, c = cls.tables.p, cls.tables.c

        # P owns a single C (uselist=False); replacing it orphans the old
        # child, which delete-orphan cascade then deletes.
        cls.mapper_registry.map_imperatively(
            P,
            p,
            version_id_col=p.c.version_id,
            properties={
                "c": relationship(
                    C, uselist=False, cascade="all, delete-orphan"
                )
            },
        )
        cls.mapper_registry.map_imperatively(
            C, c, version_id_col=c.c.version_id
        )

    def test_row_switch(self):
        P = self.classes.P

        session = fixture_session()
        session.add(P(id="P1", data="P version 1"))
        session.commit()
        session.close()

        # delete the loaded P1 and add a brand-new P1 in the same flush;
        # the commit is expected to go through the versioned UPDATE path
        # (hence the update-rowcount warning expectation)
        existing = session.query(P).first()
        session.delete(existing)
        session.add(P(id="P1", data="really a row-switch"))
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            session.commit()

    def test_child_row_switch(self):
        P, C = self.classes.P, self.classes.C

        # precondition: P.c's lazy loader uses a primary-key get()
        assert P.c.property.strategy.use_get

        session = fixture_session()
        session.add(P(id="P1", data="P version 1"))
        session.commit()
        session.close()

        parent = session.query(P).first()
        parent.c = C(data="child version 1")
        session.commit()

        # the replacement child takes the same primary key as the
        # orphaned one, making this a row switch on table "c"
        parent = session.query(P).first()
        parent.c = C(data="child row-switch")
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            session.commit()
1033
1034
class AlternateGeneratorTest(fixtures.MappedTest):
    """Row-switch and concurrent-modification scenarios using a custom
    ``version_id_generator``.

    The generator ignores the current version and returns a fresh
    ``make_uuid()`` value (a 32-character uuid4 hex string), matching the
    ``String(32)`` version columns.
    """

    __backend__ = True
    # presumably skips the whole class on backends lacking this
    # requirement -- TODO confirm against the requirements module
    __requires__ = ("sane_rowcount",)

    @classmethod
    def define_tables(cls, metadata):
        # no column default: version values come from the mapper's generator
        Table(
            "p",
            metadata,
            Column("id", String(10), primary_key=True),
            Column("version_id", String(32), nullable=False),
            Column("data", String(50)),
        )
        # child table; its primary key is also a foreign key to p.id
        Table(
            "c",
            metadata,
            Column("id", String(10), ForeignKey("p.id"), primary_key=True),
            Column("version_id", String(32), nullable=False),
            Column("data", String(50)),
        )

    @classmethod
    def setup_classes(cls):
        class P(cls.Basic):
            pass

        class C(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        """Map P and C with uuid-string version generators.

        ``P.c`` is scalar and cascades deletes to an orphaned child.
        """
        p, c, C, P = cls.tables.p, cls.tables.c, cls.classes.C, cls.classes.P

        cls.mapper_registry.map_imperatively(
            P,
            p,
            version_id_col=p.c.version_id,
            # the previous version value (x) is ignored
            version_id_generator=lambda x: make_uuid(),
            properties={
                "c": relationship(
                    C, uselist=False, cascade="all, delete-orphan"
                )
            },
        )
        cls.mapper_registry.map_imperatively(
            C,
            c,
            version_id_col=c.c.version_id,
            version_id_generator=lambda x: make_uuid(),
        )

    def test_row_switch(self):
        """Delete P1 and add a new P1 in the same flush."""
        P = self.classes.P

        session = fixture_session()
        session.add(P(id="P1", data="P version 1"))
        session.commit()
        session.close()

        p = session.query(P).first()
        session.delete(p)
        # same primary key as the deleted object -> row switch
        session.add(P(id="P1", data="really a row-switch"))
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            session.commit()

    def test_child_row_switch_one(self):
        """Replace P1's child with a new C sharing its primary key."""
        P, C = self.classes.P, self.classes.C

        # sanity check: the P.c loader uses a simple get (use_get)
        assert P.c.property.strategy.use_get

        session = fixture_session()
        session.add(P(id="P1", data="P version 1"))
        session.commit()
        session.close()

        p = session.query(P).first()
        p.c = C(data="child version 1")
        session.commit()

        p = session.query(P).first()
        # old child orphaned and deleted; new child gets the same id
        p.c = C(data="child row-switch")
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            session.commit()

    @testing.requires.sane_rowcount_w_returning
    def test_child_row_switch_two(self):
        """A second session holding a stale copy of P1 tries to UPDATE it
        after another session deleted and re-created P1; the version check
        is expected to raise StaleDataError."""
        P = self.classes.P

        # TODO: not sure this test is
        # testing exactly what its looking for

        sess1 = fixture_session()
        sess1.add(P(id="P1", data="P version 1"))
        sess1.commit()
        sess1.close()

        p1 = sess1.query(P).first()

        # sess2 loads P1 with the original version id
        sess2 = fixture_session()
        p2 = sess2.query(P).first()

        sess1.delete(p1)
        sess1.commit()

        # this can be removed and it still passes
        sess1.add(P(id="P1", data="P version 2"))
        sess1.commit()

        # p2's version no longer matches the row in the database
        p2.data = "P overwritten by concurrent tx"
        # NOTE(review): with __requires__ = ("sane_rowcount",) on the
        # class, the else branch is likely unreachable -- confirm
        if testing.db.dialect.supports_sane_rowcount:
            assert_raises_message(
                orm.exc.StaleDataError,
                r"UPDATE statement on table 'p' expected to update "
                r"1 row\(s\); 0 were matched.",
                sess2.commit,
            )
        else:
            sess2.commit()
1157
1158
class PlainInheritanceTest(fixtures.MappedTest):
    """Joined inheritance where only the base table carries a version
    column.

    Changing a column that lives solely on the subclass table must still
    advance the version counter stored on the base table.
    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # base table: owns the primary key and the version counter
        Table(
            "base",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("version_id", Integer, nullable=True),
            Column("data", String(50)),
        )
        # subclass table: joined to base by primary key, no version column
        Table(
            "sub",
            metadata,
            Column("id", Integer, ForeignKey("base.id"), primary_key=True),
            Column("sub_data", String(50)),
        )

    @classmethod
    def setup_classes(cls):
        class Base(cls.Basic):
            pass

        class Sub(Base):
            pass

    def test_update_child_table_only(self):
        base_cls = self.classes.Base
        sub_cls = self.classes.Sub
        base_table = self.tables.base

        registry = self.mapper_registry
        registry.map_imperatively(
            base_cls, base_table, version_id_col=base_table.c.version_id
        )
        registry.map_imperatively(sub_cls, self.tables.sub, inherits=base_cls)

        sess = fixture_session()
        obj = sub_cls(data="b", sub_data="s")
        sess.add(obj)
        sess.commit()

        # touch only a column stored on the "sub" table
        obj.sub_data = "s2"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            sess.commit()

        # the base-table counter still moved from 1 to 2
        eq_(obj.version_id, 2)
1213
1214
class InheritanceTwoVersionIdsTest(fixtures.MappedTest):
    """Joined inheritance where both the parent ("base") and child ("sub")
    tables have a ``version_id`` column.

    The tests check which table's column gets populated depending on which
    mapper declares ``version_id_col``, and that declaring a different
    column on the subclass mapper emits a warning.
    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # parent table; its version column is nullable so it may stay empty
        Table(
            "base",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("version_id", Integer, nullable=True),
            Column("data", String(50)),
        )
        # child table with its own, non-nullable version column
        Table(
            "sub",
            metadata,
            Column("id", Integer, ForeignKey("base.id"), primary_key=True),
            Column("version_id", Integer, nullable=False),
            Column("sub_data", String(50)),
        )

    @classmethod
    def setup_classes(cls):
        class Base(cls.Basic):
            pass

        class Sub(Base):
            pass

    def _map_with_base_versioning(self):
        """Version Base on base.version_id; Sub inherits with no
        version_id_col of its own."""
        base_cls = self.classes.Base
        base_table = self.tables.base
        self.mapper_registry.map_imperatively(
            base_cls, base_table, version_id_col=base_table.c.version_id
        )
        self.mapper_registry.map_imperatively(
            self.classes.Sub, self.tables.sub, inherits=base_cls
        )

    def _commit_sub(self):
        """Persist one Sub in a new session; return that session."""
        sess = fixture_session()
        sess.add(self.classes.Sub(data="s1", sub_data="s1"))
        sess.commit()
        return sess

    @staticmethod
    def _stored_version(sess, column):
        """Scalar value of *column* as currently stored in the database."""
        return sess.connection().scalar(select(column))

    def test_base_both(self):
        self._map_with_base_versioning()

        sess = fixture_session()
        obj = self.classes.Base(data="b1")
        sess.add(obj)
        sess.commit()
        eq_(obj.version_id, 1)

        # the base table's column holds the value
        eq_(self._stored_version(sess, self.tables.base.c.version_id), 1)

    def test_sub_both(self):
        self._map_with_base_versioning()

        sess = self._commit_sub()

        # the sub table's column is filled in ...
        eq_(self._stored_version(sess, self.tables.sub.c.version_id), 1)

        # ... and so is the base table's
        eq_(self._stored_version(sess, self.tables.base.c.version_id), 1)

    def test_sub_only(self):
        base_cls = self.classes.Base
        sub_table = self.tables.sub

        self.mapper_registry.map_imperatively(base_cls, self.tables.base)
        self.mapper_registry.map_imperatively(
            self.classes.Sub,
            sub_table,
            inherits=base_cls,
            version_id_col=sub_table.c.version_id,
        )

        sess = self._commit_sub()

        # only the sub table's column is filled in
        eq_(self._stored_version(sess, sub_table.c.version_id), 1)

        # the base table's column stays NULL
        eq_(self._stored_version(sess, self.tables.base.c.version_id), None)

    def test_mismatch_version_col_warning(self):
        base_cls = self.classes.Base
        base_table = self.tables.base
        sub_table = self.tables.sub

        self.mapper_registry.map_imperatively(
            base_cls, base_table, version_id_col=base_table.c.version_id
        )

        # declaring a different version column on the subclass warns
        assert_raises_message(
            exc.SAWarning,
            "Inheriting version_id_col 'version_id' does not "
            "match inherited version_id_col 'version_id' and will not "
            "automatically populate the inherited versioning column. "
            "version_id_col should only be specified on "
            "the base-most mapper that includes versioning.",
            self.mapper_registry.map_imperatively,
            self.classes.Sub,
            sub_table,
            inherits=base_cls,
            version_id_col=sub_table.c.version_id,
        )
1344
1345
class ServerVersioningTest(fixtures.MappedTest):
    """Versioning where the version value comes from SQL column defaults
    rather than from the ORM (``version_id_generator=False``).

    ``version_id`` uses a custom SQL element for both ``default`` and
    ``onupdate`` that renders as an incrementing integer literal.  The ORM
    therefore has to read the new version back after INSERT/UPDATE, either
    through RETURNING or through a follow-up SELECT (see the expected
    statements in the tests).
    """

    # tables are defined per test ("each"), so the IncDefault counter
    # created in define_tables starts again at 1 for every test; the
    # expected SQL below (version 1 on INSERT, 2 on UPDATE) relies on that
    run_define_tables = "each"
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        from sqlalchemy.sql import ColumnElement
        from sqlalchemy.ext.compiler import compiles
        import itertools

        # source of the literal values rendered by IncDefault
        counter = itertools.count(1)

        # a bare SQL expression element; its rendering is supplied by
        # the @compiles hook below
        class IncDefault(ColumnElement):
            pass

        @compiles(IncDefault)
        def compile_(element, compiler, **kw):
            # cache the counter value on the statement
            # itself so the assertsql system gets the same
            # value when it compiles the statement a second time
            stmt = compiler.statement
            if hasattr(stmt, "_counter"):
                return stmt._counter
            else:
                stmt._counter = str(next(counter))
                return stmt._counter

        Table(
            "version_table",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column(
                "version_id",
                Integer,
                nullable=False,
                # both INSERT and UPDATE render the next counter value
                default=IncDefault(),
                onupdate=IncDefault(),
            ),
            Column("value", String(40), nullable=False),
        )

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

        # NOTE(review): Bar is not used by any test in this class;
        # possibly a leftover
        class Bar(cls.Basic):
            pass

    def _fixture(self, expire_on_commit=True, eager_defaults=False):
        """Map Foo with SQL-generated versioning and return a new session.

        :param expire_on_commit: passed through to ``fixture_session()``.
        :param eager_defaults: passed through to the mapper.
        """
        Foo, version_table = self.classes.Foo, self.tables.version_table

        self.mapper_registry.map_imperatively(
            Foo,
            version_table,
            version_id_col=version_table.c.version_id,
            # the ORM does not compute versions; the column defaults do
            version_id_generator=False,
            eager_defaults=eager_defaults,
        )

        s1 = fixture_session(expire_on_commit=expire_on_commit)
        return s1

    def test_insert_col(self):
        self._test_insert_col()

    def test_insert_col_eager_defaults(self):
        self._test_insert_col(eager_defaults=True)

    def _test_insert_col(self, **kw):
        """Assert the SQL emitted when flushing a new Foo.

        :param kw: forwarded to :meth:`_fixture`.
        """
        sess = self._fixture(**kw)

        f1 = self.classes.Foo(value="f1")
        sess.add(f1)

        statements = [
            # note that the assertsql tests the rule against
            # "default" - on a "returning" backend, the statement
            # includes "RETURNING"
            CompiledSQL(
                "INSERT INTO version_table (version_id, value) "
                "VALUES (1, :value)",
                lambda ctx: [{"value": "f1"}],
            )
        ]
        if not testing.db.dialect.implicit_returning:
            # DBs without implicit returning, we must immediately
            # SELECT for the new version id
            statements.append(
                CompiledSQL(
                    "SELECT version_table.version_id "
                    "AS version_table_version_id "
                    "FROM version_table WHERE version_table.id = :pk_1",
                    lambda ctx: [{"pk_1": 1}],
                )
            )
        self.assert_sql_execution(testing.db, sess.flush, *statements)

    def test_update_col(self):
        self._test_update_col()

    def test_update_col_eager_defaults(self):
        self._test_update_col(eager_defaults=True)

    def _test_update_col(self, **kw):
        """Assert the versioned UPDATE (and optional SELECT) emitted when
        a persistent Foo is changed.

        :param kw: forwarded to :meth:`_fixture`.
        """
        sess = self._fixture(**kw)

        f1 = self.classes.Foo(value="f1")
        sess.add(f1)
        sess.flush()

        f1.value = "f2"

        statements = [
            # note that the assertsql tests the rule against
            # "default" - on a "returning" backend, the statement
            # includes "RETURNING"
            CompiledSQL(
                "UPDATE version_table SET version_id=2, value=:value "
                "WHERE version_table.id = :version_table_id AND "
                "version_table.version_id = :version_table_version_id",
                lambda ctx: [
                    {
                        "version_table_id": 1,
                        "version_table_version_id": 1,
                        "value": "f2",
                    }
                ],
            )
        ]
        if not testing.db.dialect.implicit_returning:
            # DBs without implicit returning, we must immediately
            # SELECT for the new version id
            statements.append(
                CompiledSQL(
                    "SELECT version_table.version_id "
                    "AS version_table_version_id "
                    "FROM version_table WHERE version_table.id = :pk_1",
                    lambda ctx: [{"pk_1": 1}],
                )
            )
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            self.assert_sql_execution(testing.db, sess.flush, *statements)

    @testing.requires.updateable_autoincrement_pks
    def test_sql_expr_bump(self):
        """Assigning a SQL expression that leaves the PK value unchanged
        (id + 0) still produces an UPDATE that bumps the version."""
        sess = self._fixture()

        f1 = self.classes.Foo(value="f1")
        sess.add(f1)
        sess.flush()

        eq_(f1.version_id, 1)

        f1.id = self.classes.Foo.id + 0

        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            sess.flush()

        eq_(f1.version_id, 2)

    @testing.requires.updateable_autoincrement_pks
    @testing.requires.returning
    def test_sql_expr_w_mods_bump(self):
        """Assigning a SQL expression that changes the PK (2 -> 5) updates
        both the id and the version; RETURNING is required by the
        decorator, presumably to read the computed id back."""
        sess = self._fixture()

        f1 = self.classes.Foo(id=2, value="f1")
        sess.add(f1)
        sess.flush()

        eq_(f1.version_id, 1)

        f1.id = self.classes.Foo.id + 3

        # NOTE(review): unlike the other UPDATE tests this does not pass
        # only_returning=True, so the helper checks supports_sane_rowcount
        # rather than supports_sane_rowcount_returning -- confirm intended
        with conditional_sane_rowcount_warnings(update=True):
            sess.flush()

        eq_(f1.id, 5)
        eq_(f1.version_id, 2)

    def test_multi_update(self):
        """Three modified rows flush as three versioned UPDATEs (plus a
        SELECT each when the dialect lacks implicit returning)."""
        sess = self._fixture()

        f1 = self.classes.Foo(value="f1")
        f2 = self.classes.Foo(value="f2")
        f3 = self.classes.Foo(value="f3")
        sess.add_all([f1, f2, f3])
        sess.flush()

        f1.value = "f1a"
        f2.value = "f2a"
        f3.value = "f3a"

        statements = [
            # note that the assertsql tests the rule against
            # "default" - on a "returning" backend, the statement
            # includes "RETURNING"
            CompiledSQL(
                "UPDATE version_table SET version_id=2, value=:value "
                "WHERE version_table.id = :version_table_id AND "
                "version_table.version_id = :version_table_version_id",
                lambda ctx: [
                    {
                        "version_table_id": 1,
                        "version_table_version_id": 1,
                        "value": "f1a",
                    }
                ],
            ),
            CompiledSQL(
                "UPDATE version_table SET version_id=2, value=:value "
                "WHERE version_table.id = :version_table_id AND "
                "version_table.version_id = :version_table_version_id",
                lambda ctx: [
                    {
                        "version_table_id": 2,
                        "version_table_version_id": 1,
                        "value": "f2a",
                    }
                ],
            ),
            CompiledSQL(
                "UPDATE version_table SET version_id=2, value=:value "
                "WHERE version_table.id = :version_table_id AND "
                "version_table.version_id = :version_table_version_id",
                lambda ctx: [
                    {
                        "version_table_id": 3,
                        "version_table_version_id": 1,
                        "value": "f3a",
                    }
                ],
            ),
        ]
        if not testing.db.dialect.implicit_returning:
            # DBs without implicit returning, we must immediately
            # SELECT for the new version id
            statements.extend(
                [
                    CompiledSQL(
                        "SELECT version_table.version_id "
                        "AS version_table_version_id "
                        "FROM version_table WHERE version_table.id = :pk_1",
                        lambda ctx: [{"pk_1": 1}],
                    ),
                    CompiledSQL(
                        "SELECT version_table.version_id "
                        "AS version_table_version_id "
                        "FROM version_table WHERE version_table.id = :pk_1",
                        lambda ctx: [{"pk_1": 2}],
                    ),
                    CompiledSQL(
                        "SELECT version_table.version_id "
                        "AS version_table_version_id "
                        "FROM version_table WHERE version_table.id = :pk_1",
                        lambda ctx: [{"pk_1": 3}],
                    ),
                ]
            )
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            self.assert_sql_execution(testing.db, sess.flush, *statements)

    def test_delete_col(self):
        """DELETE of a versioned row includes the version in its WHERE."""
        sess = self._fixture()

        f1 = self.classes.Foo(value="f1")
        sess.add(f1)
        sess.flush()

        sess.delete(f1)

        statements = [
            # note that the assertsql tests the rule against
            # "default" - on a "returning" backend, the statement
            # includes "RETURNING"
            CompiledSQL(
                "DELETE FROM version_table "
                "WHERE version_table.id = :id AND "
                "version_table.version_id = :version_id",
                lambda ctx: [{"id": 1, "version_id": 1}],
            )
        ]
        # only_returning is left False: the helper then checks the
        # dialect's plain supports_sane_rowcount flag
        with conditional_sane_rowcount_warnings(delete=True):
            self.assert_sql_execution(testing.db, sess.flush, *statements)

    @testing.requires.sane_rowcount_w_returning
    def test_concurrent_mod_err_expire_on_commit(self):
        """A second session bumps the version; the first session's stale
        UPDATE must raise StaleDataError."""
        sess = self._fixture()

        f1 = self.classes.Foo(value="f1")
        sess.add(f1)
        sess.commit()

        # attribute access reloads f1 (expired by the commit above)
        f1.value

        s2 = fixture_session()
        f2 = s2.query(self.classes.Foo).first()
        f2.value = "f2"
        s2.commit()

        f1.value = "f3"

        assert_raises_message(
            orm.exc.StaleDataError,
            r"UPDATE statement on table 'version_table' expected to "
            r"update 1 row\(s\); 0 were matched.",
            sess.commit,
        )

    @testing.requires.sane_rowcount_w_returning
    def test_concurrent_mod_err_noexpire_on_commit(self):
        """Same as the expire_on_commit variant, but with objects left
        unexpired after commit."""
        sess = self._fixture(expire_on_commit=False)

        f1 = self.classes.Foo(value="f1")
        sess.add(f1)
        sess.commit()

        # here, we're not expired overall, so no load occurs and we
        # stay without a version id, unless we've emitted
        # a SELECT for it within the flush.
        f1.value

        s2 = fixture_session(expire_on_commit=False)
        f2 = s2.query(self.classes.Foo).first()
        f2.value = "f2"
        s2.commit()

        f1.value = "f3"

        assert_raises_message(
            orm.exc.StaleDataError,
            r"UPDATE statement on table 'version_table' expected to "
            r"update 1 row\(s\); 0 were matched.",
            sess.commit,
        )
1689
1690
class ManualVersionTest(fixtures.MappedTest):
    """Versioning where the application assigns ``vid`` itself
    (``version_id_generator=False``).

    The version column is still used in the UPDATE criteria, so a
    concurrent change to ``vid`` is detected, but the ORM never increments
    it on its own.
    """

    run_define_tables = "each"
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # "vid" is nullable and has no default; tests set it explicitly
        Table(
            "a",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(30)),
            Column("vid", Integer),
        )

    @classmethod
    def setup_classes(cls):
        class A(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        a_table = cls.tables.a
        cls.mapper_registry.map_imperatively(
            cls.classes.A,
            a_table,
            version_id_col=a_table.c.vid,
            version_id_generator=False,
        )

    def _commit_new_a(self, **attrs):
        """Create an A with *attrs* (set in the given order), add it to a
        fresh session and commit; return ``(session, obj)``."""
        sess = fixture_session()
        obj = self.classes.A(**attrs)
        sess.add(obj)
        sess.commit()
        return sess, obj

    def test_insert(self):
        _, obj = self._commit_new_a(vid=1)

        eq_(obj.vid, 1)

    def test_update(self):
        sess, obj = self._commit_new_a(vid=1, data="d1")

        obj.vid = 2
        obj.data = "d2"

        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            sess.commit()

        eq_(obj.vid, 2)

    @testing.requires.sane_rowcount_w_returning
    def test_update_concurrent_check(self):
        sess, obj = self._commit_new_a(vid=1, data="d1")

        obj.vid = 2
        # move the row's version out from under the pending change
        sess.execute(self.tables.a.update().values(vid=3))
        obj.data = "d2"
        assert_raises(orm_exc.StaleDataError, sess.commit)

    def test_update_version_conditional(self):
        sess, obj = self._commit_new_a(vid=1, data="d1")

        # step 1 changes data only, leaving vid alone;
        # step 2 changes data and bumps vid by hand
        steps = (
            ({"data": "d2"}, 1),
            ({"data": "d3", "vid": 2}, 2),
        )
        for changes, expected_vid in steps:
            for attr_name, attr_value in changes.items():
                setattr(obj, attr_name, attr_value)
            with conditional_sane_rowcount_warnings(
                update=True, only_returning=True
            ):
                sess.commit()

            eq_(obj.vid, expected_vid)
1792
1793
class ManualInheritanceVersionTest(fixtures.MappedTest):
    """Manually assigned versions (``version_id_generator=False``) on a
    joined-inheritance hierarchy where only the base table ``a`` has the
    version column ``vid``."""

    run_define_tables = "each"
    __backend__ = True
    # presumably skips the class on backends lacking this requirement
    __requires__ = ("sane_rowcount",)

    @classmethod
    def define_tables(cls, metadata):
        # base table carrying the manually managed version column
        Table(
            "a",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(30)),
            Column("vid", Integer, nullable=False),
        )

        # subclass table, joined to "a" by primary key, no version column
        Table(
            "b",
            metadata,
            Column("id", Integer, ForeignKey("a.id"), primary_key=True),
            Column("b_data", String(30)),
        )

    @classmethod
    def setup_classes(cls):
        class A(cls.Basic):
            pass

        class B(A):
            pass

    @classmethod
    def setup_mappers(cls):
        """Version A on a.vid with no generator; B inherits from A."""
        cls.mapper_registry.map_imperatively(
            cls.classes.A,
            cls.tables.a,
            version_id_col=cls.tables.a.c.vid,
            version_id_generator=False,
        )

        cls.mapper_registry.map_imperatively(
            cls.classes.B, cls.tables.b, inherits=cls.classes.A
        )

    def test_no_increment(self):
        """Changing only subclass-table data leaves vid untouched; setting
        vid explicitly alongside the change stores the new value."""
        sess = fixture_session()
        b1 = self.classes.B()

        b1.vid = 1
        b1.data = "d1"
        sess.add(b1)
        sess.commit()

        # change col on subtable only without
        # incrementing version id
        b1.b_data = "bd2"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            sess.commit()

        eq_(b1.vid, 1)

        # now bump the version by hand together with a subtable change
        b1.b_data = "d3"
        b1.vid = 2
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            sess.commit()

        eq_(b1.vid, 2)
1866
1867
class VersioningMappedSelectTest(fixtures.MappedTest):
    # test for #4193, see also #4194 for related notes
    #
    # Foo is mapped against an aliased SELECT of version_table rather than
    # the table itself, while version_id_col names the underlying table's
    # column.

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "version_table",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("version_id", Integer, nullable=False),
            Column("value", String(40), nullable=False),
        )

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

    def _implicit_version_fixture(self):
        """Map Foo onto ``current_table`` (a SELECT of version_table
        filtered on ``id > 0``) with the mapper's default version generator;
        return a new session."""
        Foo, version_table = self.classes.Foo, self.tables.version_table

        current = (
            version_table.select()
            .where(version_table.c.id > 0)
            .alias("current_table")
        )

        # version_id_col refers to the base table column, not current.c
        self.mapper_registry.map_imperatively(
            Foo, current, version_id_col=version_table.c.version_id
        )
        s1 = fixture_session()
        return s1

    def _explicit_version_fixture(self):
        """Like the implicit fixture, but with ``version_id_generator=False``
        so tests assign version_id themselves; return a new session."""
        Foo, version_table = self.classes.Foo, self.tables.version_table

        current = (
            version_table.select()
            .where(version_table.c.id > 0)
            .alias("current_table")
        )

        self.mapper_registry.map_imperatively(
            Foo,
            current,
            version_id_col=version_table.c.version_id,
            version_id_generator=False,
        )
        s1 = fixture_session()
        return s1

    def test_implicit(self):
        """Two rows updated in one commit both advance from version 1 to 2."""
        Foo = self.classes.Foo

        s1 = self._implicit_version_fixture()
        f1 = Foo(value="f1")
        f2 = Foo(value="f2")
        s1.add_all((f1, f2))
        s1.commit()

        f1.value = "f1rev2"
        f2.value = "f2rev2"
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.commit()

        eq_(
            s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(),
            [(f1.id, "f1rev2", 2), (f2.id, "f2rev2", 2)],
        )

    def test_explicit(self):
        """Manually assigned versions are written for both rows."""
        Foo = self.classes.Foo

        s1 = self._explicit_version_fixture()
        f1 = Foo(value="f1", version_id=1)
        f2 = Foo(value="f2", version_id=1)
        s1.add_all((f1, f2))
        s1.flush()

        # note this requires that the Session was not expired until
        # we fix #4195
        f1.value = "f1rev2"
        f1.version_id = 2
        f2.value = "f2rev2"
        f2.version_id = 2
        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            s1.flush()

        eq_(
            s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(),
            [(f1.id, "f1rev2", 2), (f2.id, "f2rev2", 2)],
        )

    def test_implicit_no_readonly(self):
        """The mapper has no read-only properties, and version_id is
        available after flush without emitting SQL."""
        # test issue 4194
        Foo = self.classes.Foo

        s1 = self._implicit_version_fixture()
        f1 = Foo(value="f1")
        s1.add(f1)
        s1.flush()

        is_false(bool(inspect(Foo)._readonly_props))

        def go():
            eq_(f1.version_id, 1)

        # zero statements: the value is already loaded on f1
        self.assert_sql_count(testing.db, go, 0)

    def test_explicit_assign_from_expired(self):
        """Assigning a new version_id to an expired object flushes
        (version_id uses active history)."""
        # test issue 4195
        Foo = self.classes.Foo

        s1 = self._explicit_version_fixture()

        configure_mappers()
        is_true(Foo.version_id.impl.active_history)

        f1 = Foo(value="f1", version_id=1)
        s1.add(f1)

        s1.flush()

        s1.expire_all()

        with conditional_sane_rowcount_warnings(
            update=True, only_returning=True
        ):
            f1.value = "f2"
            f1.version_id = 2
            s1.flush()
2007