1from sqlalchemy import ForeignKey
2from sqlalchemy import Integer
3from sqlalchemy import testing
4from sqlalchemy.orm import attributes
5from sqlalchemy.orm import class_mapper
6from sqlalchemy.orm import create_session
7from sqlalchemy.orm import exc as orm_exc
8from sqlalchemy.orm import mapper
9from sqlalchemy.orm import sync
10from sqlalchemy.orm import unitofwork
11from sqlalchemy.testing import assert_raises_message
12from sqlalchemy.testing import eq_
13from sqlalchemy.testing import fixtures
14from sqlalchemy.testing.schema import Column
15from sqlalchemy.testing.schema import Table
16
17
class AssertsUOW(object):
    """Mixin providing a helper that builds a unit-of-work for a session."""

    def _get_test_uow(self, session):
        """Return a ``unitofwork.UOWTransaction`` bound to *session*.

        States found in ``session._new`` and ``session._dirty_states``
        (excluding any that also appear in ``session._deleted``) are
        registered for save; states in ``session._deleted`` are registered
        with ``isdelete=True``.
        """
        transaction = unitofwork.UOWTransaction(session)
        to_delete = set(session._deleted)
        # Dirty states that are also being deleted are only registered
        # once, as deletes.
        to_save = set(session._new) | (
            set(session._dirty_states) - to_delete
        )
        for state in to_save:
            transaction.register_object(state)
        for state in to_delete:
            transaction.register_object(state, isdelete=True)
        return transaction
29
30
class SyncTest(
    fixtures.MappedTest, testing.AssertsExecutionResults, AssertsUOW
):
    """Direct unit tests for the helpers in ``sqlalchemy.orm.sync``.

    Table ``t1`` is mapped to ``A`` and ``t2`` to ``B``; both ``t2.id``
    and ``t2.t1id`` carry a foreign key to ``t1.id``.  Each test obtains
    fresh ``A``/``B`` instances (never added to the session) through
    ``_fixture()`` and calls the sync functions with explicit
    ``(source_column, dest_column)`` pairs.
    """

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "t1",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("foo", Integer),
        )
        Table(
            "t2",
            metadata,
            Column("id", Integer, ForeignKey("t1.id"), primary_key=True),
            Column("t1id", Integer, ForeignKey("t1.id")),
        )

    @classmethod
    def setup_classes(cls):
        class A(cls.Basic):
            pass

        class B(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        mapper(cls.classes.A, cls.tables.t1)
        mapper(cls.classes.B, cls.tables.t2)

    def _fixture(self):
        """Build the per-test fixture.

        Returns ``(uowcommit, a1_state, b1_state, a_mapper, b_mapper)``,
        where the two states are the ``InstanceState`` objects of a new
        ``A`` and a new ``B`` (the instances themselves are also kept on
        ``self.a1`` / ``self.b1``).
        """
        A, B = self.classes.A, self.classes.B
        session = create_session()
        a_mapper = class_mapper(A)
        b_mapper = class_mapper(B)
        self.a1 = a1 = A()
        self.b1 = b1 = B()
        # A single unit-of-work is enough; the tests only read and write
        # its ``attributes`` dict and pass it through to sync functions.
        uowcommit = self._get_test_uow(session)
        return (
            uowcommit,
            attributes.instance_state(a1),
            attributes.instance_state(b1),
            a_mapper,
            b_mapper,
        )

    def test_populate(self):
        """populate() copies the source value; no pk_cascaded flag."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        pairs = [(a_mapper.c.id, b_mapper.c.id)]
        a1.obj().id = 7
        assert "id" not in b1.obj().__dict__
        sync.populate(a1, a_mapper, b1, b_mapper, pairs, uowcommit, False)
        eq_(b1.obj().id, 7)
        eq_(b1.obj().__dict__["id"], 7)
        # With flag_cascaded_pks=False nothing is recorded on the uow.
        assert ("pk_cascaded", b1, b_mapper.c.id) not in uowcommit.attributes

    def test_populate_flag_cascaded(self):
        """populate() with the cascade flag records pk_cascaded on uow."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        pairs = [(a_mapper.c.id, b_mapper.c.id)]
        a1.obj().id = 7
        assert "id" not in b1.obj().__dict__
        sync.populate(a1, a_mapper, b1, b_mapper, pairs, uowcommit, True)
        eq_(b1.obj().id, 7)
        eq_(b1.obj().__dict__["id"], 7)
        eq_(uowcommit.attributes[("pk_cascaded", b1, b_mapper.c.id)], True)

    def test_populate_unmapped_source(self):
        """A source column not mapped by the source mapper is rejected."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        pairs = [(b_mapper.c.id, b_mapper.c.id)]
        assert_raises_message(
            orm_exc.UnmappedColumnError,
            r"Can't execute sync rule for source column 't2.id'; "
            r"mapper 'mapped class A->t1' does not map this column.",
            sync.populate,
            a1,
            a_mapper,
            b1,
            b_mapper,
            pairs,
            uowcommit,
            False,
        )

    def test_populate_unmapped_dest(self):
        """A destination column not mapped by the dest mapper is rejected."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        pairs = [(a_mapper.c.id, a_mapper.c.id)]
        assert_raises_message(
            orm_exc.UnmappedColumnError,
            r"Can't execute sync rule for destination "
            r"column 't1.id'; "
            r"mapper 'mapped class B->t2' does not map this column.",
            sync.populate,
            a1,
            a_mapper,
            b1,
            b_mapper,
            pairs,
            uowcommit,
            False,
        )

    def test_clear(self):
        """clear() sets a non-primary-key destination column to None."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        pairs = [(a_mapper.c.id, b_mapper.c.t1id)]
        b1.obj().t1id = 8
        eq_(b1.obj().__dict__["t1id"], 8)
        sync.clear(b1, b_mapper, pairs)
        eq_(b1.obj().__dict__["t1id"], None)

    def test_clear_pk(self):
        """clear() refuses to blank out a primary-key column."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        pairs = [(a_mapper.c.id, b_mapper.c.id)]
        b1.obj().id = 8
        eq_(b1.obj().__dict__["id"], 8)
        assert_raises_message(
            AssertionError,
            r"Dependency rule tried to blank-out primary key "
            r"column 't2.id' on instance '<B",
            sync.clear,
            b1,
            b_mapper,
            pairs,
        )

    def test_clear_unmapped(self):
        """clear() rejects a destination column the mapper doesn't map."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        pairs = [(a_mapper.c.id, a_mapper.c.foo)]
        assert_raises_message(
            orm_exc.UnmappedColumnError,
            r"Can't execute sync rule for destination "
            r"column 't1.foo'; mapper 'mapped class B->t2' does not "
            r"map this column.",
            sync.clear,
            b1,
            b_mapper,
            pairs,
        )

    def test_update(self):
        """update() writes the current value and the prefixed old value."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        a1.obj().id = 10
        a1._commit_all(a1.dict)
        a1.obj().id = 12
        pairs = [(a_mapper.c.id, b_mapper.c.id)]
        dest = {}
        sync.update(a1, a_mapper, dest, "old_", pairs)
        eq_(dest, {"id": 12, "old_id": 10})

    def test_update_unmapped(self):
        """update() rejects a source column the mapper doesn't map."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        pairs = [(b_mapper.c.id, b_mapper.c.id)]
        dest = {}
        assert_raises_message(
            orm_exc.UnmappedColumnError,
            r"Can't execute sync rule for source column 't2.id'; "
            r"mapper 'mapped class A->t1' does not map this column.",
            sync.update,
            a1,
            a_mapper,
            dest,
            "old_",
            pairs,
        )

    def test_populate_dict(self):
        """populate_dict() copies source values keyed by dest column key."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        a1.obj().id = 10
        pairs = [(a_mapper.c.id, b_mapper.c.id)]
        dest = {}
        sync.populate_dict(a1, a_mapper, dest, pairs)
        eq_(dest, {"id": 10})

    def test_populate_dict_unmapped(self):
        """populate_dict() rejects a source column the mapper doesn't map."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        a1.obj().id = 10
        pairs = [(b_mapper.c.id, b_mapper.c.id)]
        dest = {}
        assert_raises_message(
            orm_exc.UnmappedColumnError,
            r"Can't execute sync rule for source column 't2.id'; "
            r"mapper 'mapped class A->t1' does not map this column.",
            sync.populate_dict,
            a1,
            a_mapper,
            dest,
            pairs,
        )

    def test_source_modified_unmodified(self):
        """A value set on a new object with no committed value is not
        reported as modified."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        a1.obj().id = 10
        pairs = [(a_mapper.c.id, b_mapper.c.id)]
        eq_(sync.source_modified(uowcommit, a1, a_mapper, pairs), False)

    def test_source_modified_no_pairs(self):
        """With no pairs there is nothing that can be modified."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        eq_(sync.source_modified(uowcommit, a1, a_mapper, []), False)

    def test_source_modified_modified(self):
        """Changing a committed source value is reported as modified."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        a1.obj().id = 10
        a1._commit_all(a1.dict)
        a1.obj().id = 12
        pairs = [(a_mapper.c.id, b_mapper.c.id)]
        eq_(sync.source_modified(uowcommit, a1, a_mapper, pairs), True)

    def test_source_modified_composite(self):
        """A change in any one column of a composite pair list counts."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        a1.obj().foo = 10
        a1._commit_all(a1.dict)
        # Only ``foo`` changes; ``id`` is left untouched.
        a1.obj().foo = 12
        pairs = [
            (a_mapper.c.id, b_mapper.c.id),
            (a_mapper.c.foo, b_mapper.c.id),
        ]
        eq_(sync.source_modified(uowcommit, a1, a_mapper, pairs), True)

    def test_source_modified_composite_unmodified(self):
        """Composite pairs with no changes since commit are unmodified."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        a1.obj().foo = 10
        a1._commit_all(a1.dict)
        pairs = [
            (a_mapper.c.id, b_mapper.c.id),
            (a_mapper.c.foo, b_mapper.c.id),
        ]
        eq_(sync.source_modified(uowcommit, a1, a_mapper, pairs), False)

    def test_source_modified_no_unmapped(self):
        """source_modified() rejects a source column the mapper doesn't
        map."""
        uowcommit, a1, b1, a_mapper, b_mapper = self._fixture()
        pairs = [(b_mapper.c.id, b_mapper.c.id)]
        assert_raises_message(
            orm_exc.UnmappedColumnError,
            r"Can't execute sync rule for source column 't2.id'; "
            r"mapper 'mapped class A->t1' does not map this column.",
            sync.source_modified,
            uowcommit,
            a1,
            a_mapper,
            pairs,
        )
273