1import asyncio
2import pytest
3
4from aioredis import ReplyError, MultiExecError, WatchVariableError
5from aioredis import ConnectionClosedError
6
7
async def test_multi_exec(redis):
    """MULTI/EXEC returns per-command replies and resolves each command's future."""
    await redis.delete('foo', 'bar')

    # Round 1: both counters start from zero after the DELETE above.
    transaction = redis.multi_exec()
    incr_foo = transaction.incr('foo')
    incr_bar = transaction.incr('bar')
    exec_result = await transaction.execute()
    assert exec_result == [1, 1]
    # The individual futures must agree with the EXEC reply list.
    futures_result = await asyncio.gather(incr_foo, incr_bar)
    assert exec_result == futures_result

    # Round 2: futures can be awaited one by one after execute().
    transaction = redis.multi_exec()
    incr_foo = transaction.incr('foo')
    incr_bar = transaction.incr('bar')
    await transaction.execute()
    assert (await incr_foo) == 2
    assert (await incr_bar) == 2

    # Round 3: float values through SET and INCRBYFLOAT.
    transaction = redis.multi_exec()
    set_foo = transaction.set('foo', 1.0)
    add_float = transaction.incrbyfloat('foo', 1.2)
    exec_result = await transaction.execute()
    assert exec_result == [True, 2.2]
    futures_result = await asyncio.gather(set_foo, add_float)
    assert exec_result == futures_result

    # Round 4: INCRBY with a float increment makes execute() raise
    # MultiExecError, and the command's own future carries a TypeError.
    transaction = redis.multi_exec()
    bad_incr = transaction.incrby('foo', 1.0)
    with pytest.raises(MultiExecError, match="increment must be .* int"):
        await transaction.execute()
    with pytest.raises(TypeError):
        await bad_incr
40
41
async def test_empty(redis):
    """Executing a transaction with no queued commands yields an empty list."""
    assert (await redis.multi_exec().execute()) == []
46
47
async def test_double_execute(redis):
    """A transaction can be executed only once and is unusable afterwards."""
    transaction = redis.multi_exec()
    await transaction.execute()
    # Both a second EXEC and queueing another command raise AssertionError.
    with pytest.raises(AssertionError):
        await transaction.execute()
    with pytest.raises(AssertionError):
        await transaction.incr('foo')
55
56
async def test_connection_closed(redis):
    """Queueing QUIT in a transaction fails EXEC and every queued future.

    Each future must end up done and holding an exception, never cancelled:
    - QUIT's future fails with a connection-closed error.
    - The valid INCRBY's future fails with a connection-closed error.
    - The float INCRBY's future fails with a TypeError.
    """
    tr = redis.multi_exec()
    fut1 = tr.quit()
    fut2 = tr.incrby('foo', 1.0)  # float increment: expected TypeError
    fut3 = tr.incrby('foo', 1)
    with pytest.raises(MultiExecError):
        await tr.execute()

    for fut in (fut1, fut2, fut3):
        assert fut.done() is True
        # Check cancellation first: Future.exception() raises CancelledError
        # on a cancelled future, which would hide the real failure.
        assert not fut.cancelled()
        assert fut.exception() is not None

    # QUIT does not deliver its reply here; awaiting its future must re-raise
    # a connection-closed error. pytest.raises is used instead of wrapping an
    # assert in `except Exception`, which would also swallow AssertionError.
    with pytest.raises((ConnectionClosedError, ConnectionError)):
        await fut1
    assert isinstance(fut2.exception(), TypeError)
    assert isinstance(fut3.exception(),
                      (ConnectionClosedError, ConnectionError))
86
87
async def test_discard(redis):
    """A failed transaction still settles commands sent on its raw connection.

    A float INCRBY makes execute() fail. Two commands are also sent directly
    through ``transaction.connection``:
    - The raw MULTI fails with ReplyError.
    - The raw INCR still returns its result.
    """
    await redis.delete('foo')
    transaction = redis.multi_exec()
    float_incr = transaction.incrby('foo', 1.0)
    raw_multi = transaction.connection.execute('MULTI')
    raw_incr = transaction.connection.execute('incr', 'foo')

    with pytest.raises(MultiExecError):
        await transaction.execute()
    with pytest.raises(TypeError):
        await float_incr
    # Presumably the server rejects this as a nested MULTI; verify.
    with pytest.raises(ReplyError):
        await raw_multi
    # 'foo' was deleted above, so the raw INCR yields 1.
    assert (await raw_incr) == 1
104
105
async def test_exec_error(redis):
    """Server reply errors surface via MultiExecError or ``return_exceptions``."""
    # Case 1: raw INCRBY with a non-integer increment string.
    transaction = redis.multi_exec()
    raw_incrby = transaction.connection.execute('INCRBY', 'key', '1.0')
    with pytest.raises(MultiExecError):
        await transaction.execute()
    with pytest.raises(ReplyError):
        await raw_incrby

    # Case 2: INCRBYFLOAT on a non-numeric value. With return_exceptions=True
    # the ReplyError is placed in the result list instead of being raised,
    # while the command's own future still raises it.
    await redis.set('foo', 'bar')
    transaction = redis.multi_exec()
    float_incr = transaction.incrbyfloat('foo', 1.1)
    results = await transaction.execute(return_exceptions=True)
    assert isinstance(results[0], ReplyError)
    with pytest.raises(ReplyError):
        await float_incr
121
122
async def test_command_errors(redis):
    """A float INCRBY increment fails EXEC and its own future (TypeError)."""
    transaction = redis.multi_exec()
    float_incr = transaction.incrby('key', 1.0)
    with pytest.raises(MultiExecError):
        await transaction.execute()
    with pytest.raises(TypeError):
        await float_incr
130
131
async def test_several_command_errors(redis):
    """Each invalid queued command fails with its own exception type."""
    transaction = redis.multi_exec()
    float_incr = transaction.incrby('key', 1.0)
    self_rename = transaction.rename('bar', 'bar')
    with pytest.raises(MultiExecError):
        await transaction.execute()
    # Float increment -> TypeError; renaming a key onto itself -> ValueError.
    with pytest.raises(TypeError):
        await float_incr
    with pytest.raises(ValueError):
        await self_rename
142
143
async def test_error_in_connection(redis):
    """A None key in MGET fails EXEC and that command's future (TypeError)."""
    await redis.set('foo', 1)
    transaction = redis.multi_exec()
    bad_mget = transaction.mget('foo', None)
    incr_foo = transaction.incr('foo')
    with pytest.raises(MultiExecError):
        await transaction.execute()
    with pytest.raises(TypeError):
        await bad_mget
    # The following command's future must resolve without raising; its
    # value is intentionally not checked here.
    await incr_foo
154
155
async def test_watch_unwatch(redis):
    """WATCH takes one or repeated keys, rejects None keys; UNWATCH succeeds."""
    assert (await redis.watch('key')) is True
    assert (await redis.watch('key', 'key')) is True

    # A None key at any position in the argument list is a TypeError.
    for bad_args in ((None,), ('key', None), ('key', 'key', None)):
        with pytest.raises(TypeError):
            await redis.watch(*bad_args)

    assert (await redis.unwatch()) is True
171
172
async def test_encoding(redis):
    """A per-command ``encoding`` decodes only that command's reply."""
    assert (await redis.set('key', 'value')) is True
    hmset_ok = await redis.hmset('hash-key', 'foo', 'val1', 'bar', 'val2')
    assert hmset_ok is True

    transaction = redis.multi_exec()
    raw_get = transaction.get('key')
    text_get = transaction.get('key', encoding='utf-8')
    text_hash = transaction.hgetall('hash-key', encoding='utf-8')
    await transaction.execute()

    # Without an encoding argument the reply stays bytes.
    assert (await raw_get) == b'value'
    assert (await text_get) == 'value'
    assert (await text_hash) == {'foo': 'val1', 'bar': 'val2'}
191
192
async def test_global_encoding(redis, create_redis, server):
    """A client-wide encoding applies unless a command overrides it."""
    # A dedicated client whose default encoding is utf-8; the ``redis``
    # fixture itself is not used directly in this test.
    utf8_redis = await create_redis(server.tcp_address, encoding='utf-8')
    assert (await utf8_redis.set('key', 'value')) is True
    hmset_ok = await utf8_redis.hmset(
        'hash-key', 'foo', 'val1', 'bar', 'val2')
    assert hmset_ok is True

    transaction = utf8_redis.multi_exec()
    default_get = transaction.get('key')
    explicit_get = transaction.get('key', encoding='utf-8')
    raw_get = transaction.get('key', encoding=None)
    text_hash = transaction.hgetall('hash-key', encoding='utf-8')
    await transaction.execute()

    assert (await default_get) == 'value'
    assert (await explicit_get) == 'value'
    # encoding=None overrides the client default and yields bytes.
    assert (await raw_get) == b'value'
    assert (await text_hash) == {'foo': 'val1', 'bar': 'val2'}
215
216
async def test_transaction__watch_error(redis, create_redis, server):
    """Changing a WATCHed key from another client aborts the transaction."""
    other_client = await create_redis(server.tcp_address)

    assert (await redis.set('foo', 'bar')) is True
    assert (await redis.watch('foo')) is True
    # Modify the watched key through a second connection.
    assert (await other_client.set('foo', 'baz')) is True

    transaction = redis.multi_exec()
    queued = [transaction.set('foo', 'foo'), transaction.get('bar')]
    with pytest.raises(MultiExecError):
        await transaction.execute()
    # Every queued command fails with WatchVariableError, including the
    # one touching the unwatched key 'bar'.
    for future in queued:
        with pytest.raises(WatchVariableError):
            await future
238
239
async def test_multi_exec_and_pool_release(redis):
    """A slow EXEC reply is still delivered to the transaction's futures.

    Covers the case where a pool connection is released before the
    ``exec`` result is received.
    """
    # Lua busy-wait: spins until the seconds field of the server's TIME
    # advances by one (up to about a second). It has no return statement,
    # and the assertions below expect its reply to be None.
    busy_wait_script = """
    local a = tonumber(redis.call('time')[1])
    local b = a + 1
    while (a < b)
    do
        a = tonumber(redis.call('time')[1])
    end
    """

    transaction = redis.multi_exec()
    eval_future = transaction.eval(busy_wait_script)
    (exec_reply,) = await transaction.execute()
    assert exec_reply is None
    assert (await eval_future) is None
258
259
async def test_multi_exec_db_select(redis):
    """Transaction reads see a key written by the same client.

    NOTE(review): no SELECT is issued here. Presumably the ``redis`` fixture
    is bound to a non-default database and this checks that the transaction
    uses it; confirm against the fixture.
    """
    await redis.set('foo', 'bar')

    transaction = redis.multi_exec()
    as_text = transaction.get('foo', encoding='utf-8')
    as_bytes = transaction.get('foo')
    await transaction.execute()
    assert (await as_text) == 'bar'
    assert (await as_bytes) == b'bar'
269