1from sqlalchemy import exc
2from sqlalchemy import testing
3from sqlalchemy.engine import default
4from sqlalchemy.orm import joinedload
5from sqlalchemy.orm import relationship
6from sqlalchemy.testing import assert_raises_message
7from sqlalchemy.testing import AssertsCompiledSQL
8from sqlalchemy.testing import eq_
9from sqlalchemy.testing.fixtures import fixture_session
10from test.orm import _fixtures
11
12
class ForUpdateTest(_fixtures.FixtureTest):
    """Check that ``Query.with_for_update()`` options are carried through.

    Each test builds ``query(User).with_for_update(...)`` with a given set
    of options, then inspects ``_for_update_arg`` both on the ``Query`` and
    on the SELECT statement produced by ``Query._compile_state()``.
    """

    @classmethod
    def setup_mappers(cls):
        User, users = cls.classes.User, cls.tables.users
        cls.mapper_registry.map_imperatively(User, users)

    def _assert(
        self,
        read=False,
        nowait=False,
        of=None,
        key_share=None,
        assert_q_of=None,
        assert_sel_of=None,
        skip_locked=False,
    ):
        """Build a FOR UPDATE query for ``User`` and verify its options.

        :param read: passed to ``with_for_update(read=...)``.
        :param nowait: passed to ``with_for_update(nowait=...)``.
        :param of: passed to ``with_for_update(of=...)``.
        :param key_share: passed to ``with_for_update(key_share=...)``.
        :param assert_q_of: expected ``of`` value on the ``Query``.
        :param assert_sel_of: expected ``of`` value on the compiled SELECT.
        :param skip_locked: passed to ``with_for_update(skip_locked=...)``.

        The boolean-ish flags are checked by identity (``is``), so the exact
        object passed in must come back on both the query and the SELECT.
        """
        User = self.classes.User
        s = fixture_session()
        q = s.query(User).with_for_update(
            read=read,
            nowait=nowait,
            of=of,
            key_share=key_share,
            skip_locked=skip_locked,
        )
        sel = q._compile_state().statement

        # the ORM-level Query and the Core SELECT it compiles to must both
        # carry exactly the flags that were requested
        for for_update_arg in (q._for_update_arg, sel._for_update_arg):
            assert for_update_arg.read is read
            assert for_update_arg.nowait is nowait
            assert for_update_arg.key_share is key_share
            assert for_update_arg.skip_locked is skip_locked

        eq_(q._for_update_arg.of, assert_q_of)
        eq_(sel._for_update_arg.of, assert_sel_of)

    def test_key_share(self):
        self._assert(key_share=True)

    def test_read(self):
        self._assert(read=True)

    def test_plain(self):
        self._assert()

    def test_nowait(self):
        self._assert(nowait=True)

    def test_skip_locked(self):
        self._assert(skip_locked=True)

    def test_of_single_col(self):
        User, users = self.classes.User, self.tables.users
        self._assert(
            of=User.id, assert_q_of=[users.c.id], assert_sel_of=[users.c.id]
        )
64
65
class BackendTest(_fixtures.FixtureTest):
    """Execute FOR UPDATE queries against the configured test backend.

    The tests only check that each statement runs, or that it fails on
    Oracle with ORA-02014 where that is expected. Returned rows are not
    inspected. Sessions are opened as context managers so they are closed
    even when a query or assertion fails.
    """

    __backend__ = True

    # test against the major backends.   We are naming specific databases
    # here rather than using requirements rules since the behavior of
    # "FOR UPDATE" as well as "OF" is very specific to each DB, and we need
    # to run the query differently based on backend.

    @classmethod
    def setup_mappers(cls):
        User, users = cls.classes.User, cls.tables.users
        Address, addresses = cls.classes.Address, cls.tables.addresses
        cls.mapper_registry.map_imperatively(
            User, users, properties={"addresses": relationship(Address)}
        )
        cls.mapper_registry.map_imperatively(Address, addresses)

    def _run_expecting_ora_02014_on_oracle(self, q):
        """Run ``q.all()``; on Oracle, assert it raises ORA-02014 instead."""
        if testing.against("oracle"):
            # NOTE(review): presumably Oracle refuses FOR UPDATE here because
            # the LIMIT is rendered as a ROWNUM subquery -- confirm.
            assert_raises_message(exc.DatabaseError, "ORA-02014", q.all)
        else:
            q.all()

    def _with_for_update_outer(self, q):
        """Apply FOR UPDATE to a query that outer-joins ``User.addresses``.

        On PostgreSQL the lock is restricted to ``User`` via ``of=``. Every
        other backend gets a plain FOR UPDATE.
        """
        User = self.classes.User
        if testing.against("postgresql"):
            # assumption: PostgreSQL rejects FOR UPDATE on the nullable side
            # of an outer join, hence OF is limited to User -- verify.
            return q.with_for_update(of=User)
        return q.with_for_update()

    def test_inner_joinedload_w_limit(self):
        User = self.classes.User
        with fixture_session() as sess:
            q = (
                sess.query(User)
                .options(joinedload(User.addresses, innerjoin=True))
                .with_for_update()
                .limit(1)
            )
            self._run_expecting_ora_02014_on_oracle(q)

    def test_inner_joinedload_wo_limit(self):
        User = self.classes.User
        with fixture_session() as sess:
            sess.query(User).options(
                joinedload(User.addresses, innerjoin=True)
            ).with_for_update().all()

    def test_outer_joinedload_w_limit(self):
        User = self.classes.User
        with fixture_session() as sess:
            q = sess.query(User).options(
                joinedload(User.addresses, innerjoin=False)
            )
            # FOR UPDATE is applied before LIMIT, matching the call order
            # used by the other tests
            q = self._with_for_update_outer(q).limit(1)
            self._run_expecting_ora_02014_on_oracle(q)

    def test_outer_joinedload_wo_limit(self):
        User = self.classes.User
        with fixture_session() as sess:
            q = sess.query(User).options(
                joinedload(User.addresses, innerjoin=False)
            )
            self._with_for_update_outer(q).all()

    def test_join_w_subquery(self):
        User = self.classes.User
        Address = self.classes.Address
        with fixture_session() as sess:
            q1 = sess.query(User).with_for_update().subquery()
            sess.query(q1).join(Address).all()

    def test_plain(self):
        User = self.classes.User
        with fixture_session() as sess:
            sess.query(User).with_for_update().all()
155
156
class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
    """Compile-only checks of how ORM ``with_for_update()`` renders per
    dialect.

    These partly duplicate dialect-level compilation tests. They are kept
    here to exercise the ORM ``Query`` path as well.
    """

    # compile-only tests; presumably disables the fixture's data inserts
    run_inserts = None

    @classmethod
    def setup_mappers(cls):
        User, users = cls.classes.User, cls.tables.users
        Address, addresses = cls.classes.Address, cls.tables.addresses
        cls.mapper_registry.map_imperatively(
            User, users, properties={"addresses": relationship(Address)}
        )
        cls.mapper_registry.map_imperatively(Address, addresses)

    def test_default_update(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(),
            "SELECT users.id AS users_id FROM users FOR UPDATE",
            dialect=default.DefaultDialect(),
        )

    def test_not_supported_by_dialect_should_just_use_update(self):
        # the default dialect has no shared-lock form, so read=True still
        # renders a plain FOR UPDATE
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(read=True),
            "SELECT users.id AS users_id FROM users FOR UPDATE",
            dialect=default.DefaultDialect(),
        )

    def test_postgres_read(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(read=True),
            "SELECT users.id AS users_id FROM users FOR SHARE",
            dialect="postgresql",
        )

    def test_postgres_read_nowait(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(read=True, nowait=True),
            "SELECT users.id AS users_id FROM users FOR SHARE NOWAIT",
            dialect="postgresql",
        )

    def test_postgres_update(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(),
            "SELECT users.id AS users_id FROM users FOR UPDATE",
            dialect="postgresql",
        )

    def test_postgres_update_of(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(of=User.id),
            "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
            dialect="postgresql",
        )

    def test_postgres_update_of_entity(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(of=User),
            "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
            dialect="postgresql",
        )

    def test_postgres_update_of_entity_list(self):
        User = self.classes.User
        Address = self.classes.Address

        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id, Address.id).with_for_update(
                of=[User, Address]
            ),
            "SELECT users.id AS users_id, addresses.id AS addresses_id "
            "FROM users, addresses FOR UPDATE OF users, addresses",
            dialect="postgresql",
        )

    def test_postgres_for_no_key_update(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(key_share=True),
            "SELECT users.id AS users_id FROM users FOR NO KEY UPDATE",
            dialect="postgresql",
        )

    def test_postgres_for_no_key_nowait_update(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(key_share=True, nowait=True),
            "SELECT users.id AS users_id FROM users FOR NO KEY UPDATE NOWAIT",
            dialect="postgresql",
        )

    def test_postgres_for_key_share(self):
        # read + key_share selects the weakest PostgreSQL row lock
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(read=True, key_share=True),
            "SELECT users.id AS users_id FROM users FOR KEY SHARE",
            dialect="postgresql",
        )

    def test_postgres_for_key_share_nowait(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(
                read=True, key_share=True, nowait=True
            ),
            "SELECT users.id AS users_id FROM users FOR KEY SHARE NOWAIT",
            dialect="postgresql",
        )

    def test_postgres_update_of_list(self):
        # repeated columns from one table collapse to a single OF entry
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(
                of=[User.id, User.id, User.id]
            ),
            "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
            dialect="postgresql",
        )

    def test_postgres_update_skip_locked(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(skip_locked=True),
            "SELECT users.id AS users_id FROM users FOR UPDATE SKIP LOCKED",
            dialect="postgresql",
        )

    def test_oracle_update(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(),
            "SELECT users.id AS users_id FROM users FOR UPDATE",
            dialect="oracle",
        )

    def test_oracle_update_skip_locked(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(skip_locked=True),
            "SELECT users.id AS users_id FROM users FOR UPDATE SKIP LOCKED",
            dialect="oracle",
        )

    def test_mysql_read(self):
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User.id).with_for_update(read=True),
            "SELECT users.id AS users_id FROM users LOCK IN SHARE MODE",
            dialect="mysql",
        )

    def test_for_update_on_inner_w_joinedload(self):
        # with LIMIT + joinedload the users query becomes a subquery; on
        # MySQL FOR UPDATE is rendered both inside it and on the outer
        # SELECT.  The double space before LIMIT is part of the expected
        # compiler output.
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User)
            .options(joinedload(User.addresses))
            .with_for_update()
            .limit(1),
            "SELECT anon_1.users_id AS anon_1_users_id, anon_1.users_name "
            "AS anon_1_users_name, addresses_1.id AS addresses_1_id, "
            "addresses_1.user_id AS addresses_1_user_id, "
            "addresses_1.email_address AS addresses_1_email_address "
            "FROM (SELECT users.id AS users_id, users.name AS users_name "
            "FROM users  LIMIT %s FOR UPDATE) AS anon_1 "
            "LEFT OUTER JOIN addresses AS addresses_1 "
            "ON anon_1.users_id = addresses_1.user_id FOR UPDATE",
            dialect="mysql",
        )

    def test_for_update_on_inner_w_joinedload_no_render_oracle(self):
        # on Oracle the LIMIT becomes a ROWNUM-filtered nested subquery, and
        # FOR UPDATE is rendered only on the outermost SELECT
        User = self.classes.User
        sess = fixture_session()
        self.assert_compile(
            sess.query(User)
            .options(joinedload(User.addresses))
            .with_for_update()
            .limit(1),
            "SELECT anon_1.users_id AS anon_1_users_id, "
            "anon_1.users_name AS anon_1_users_name, "
            "addresses_1.id AS addresses_1_id, "
            "addresses_1.user_id AS addresses_1_user_id, "
            "addresses_1.email_address AS addresses_1_email_address "
            "FROM (SELECT anon_2.users_id AS users_id, "
            "anon_2.users_name AS users_name FROM "
            "(SELECT users.id AS users_id, users.name AS users_name "
            "FROM users) anon_2 WHERE ROWNUM <= "
            "__[POSTCOMPILE_param_1]) anon_1 "
            "LEFT OUTER JOIN addresses addresses_1 "
            "ON anon_1.users_id = addresses_1.user_id FOR UPDATE",
            dialect="oracle",
        )
354