1import asyncio
2import pytest
3
4import aioredis
5
6
@pytest.fixture
def pool_or_redis(_closable, server):
    """Provide a coroutine that builds a pool-like aioredis client.

    ``await pool_or_redis(maxsize)`` connects to ``server.tcp_address`` with
    ``minsize=1`` and the requested ``maxsize``. aioredis releases from 1.0
    on are driven through ``create_redis_pool``; older releases use
    ``create_pool``. Each created client is passed to ``_closable``
    (presumably so it gets closed at teardown) before being returned.
    """
    # Only the first two dotted components (major, minor) matter here.
    version = tuple(int(part) for part in aioredis.__version__.split('.')[:2])
    factory = (aioredis.create_redis_pool if version >= (1, 0)
               else aioredis.create_pool)

    async def redis_factory(maxsize):
        client = await factory(server.tcp_address,
                               minsize=1, maxsize=maxsize)
        _closable(client)
        return client

    return redis_factory
21
22
async def simple_get_set(pool, idx):
    """Check that a Redis(pool) client can be used like the old Pool(Redis).

    Acquires a client via ``with await pool``, stores ``'val:<idx>'`` under
    ``'key'`` and reads it back decoded as UTF-8. Concurrent tasks write
    the same key, so the read is only required to be one of the
    ``'val:<n>'`` strings these helpers write, not necessarily this
    task's own value.
    """
    val = 'val:{}'.format(idx)
    with await pool as redis:
        assert await redis.set('key', val)
        # Previously the GET reply was discarded; verify it actually came
        # back as a decoded value instead of only checking it didn't raise.
        res = await redis.get('key', encoding='utf-8')
        assert isinstance(res, str) and res.startswith('val:'), res
30
31
async def pipeline(pool, val):
    """Pipeline a SET and a GET on one acquired client and check both replies.

    Both commands are issued before either is awaited and are then
    collected together with ``asyncio.gather``. Concurrent tasks write the
    same ``'key'``, so the GET reply is only required to be one of the
    ``'val:<n>'`` strings these helpers write.
    """
    val = 'val:{}'.format(val)
    with await pool as redis:
        f1 = redis.set('key', val)
        f2 = redis.get('key', encoding='utf-8')
        ok, res = await asyncio.gather(f1, f2)
        # Both replies used to be unpacked and ignored; check them, matching
        # the SET assertion in simple_get_set.
        assert ok, ok
        assert isinstance(res, str) and res.startswith('val:'), res
38
39
async def transaction(pool, val):
    """Run SET followed by GET inside MULTI/EXEC and verify the round trip.

    Because both commands execute atomically, the GET reply must equal the
    exact ``'val:<val>'`` string this call wrote.
    """
    expected = 'val:{}'.format(val)
    with await pool as redis:
        multi = redis.multi_exec()
        multi.set('key', expected)
        multi.get('key', encoding='utf-8')
        set_reply, get_reply = await multi.execute()
        assert set_reply, set_reply
        assert get_reply == expected
49
50
async def blocking_pop(pool, val):
    """Pair a BLPOP with an LPUSH issued ~0.1s later and check the popped item.

    The consumer is started first and the producer second. ``val`` is
    unused; it only keeps the ``(pool, index)`` signature shared with the
    other cases run by ``test_operations``.
    """
    async def consumer():
        with await pool as redis:
            # Per the original note: aioredis v0.3 binds a connection here,
            # v1.0 does not.
            popped = await redis.blpop(
                'list-key', timeout=2, encoding='utf-8')
            assert popped == ['list-key', 'val'], popped

    async def producer():
        with await pool as redis:
            # Same connection-binding caveat as in consumer().
            await asyncio.sleep(.1)
            await redis.lpush('list-key', 'val')

    await asyncio.gather(consumer(), producer())
66
67
@pytest.mark.parametrize('test_case,pool_size', [
    (simple_get_set, 1),
    (pipeline, 1),
    (transaction, 1),
    pytest.param(
        blocking_pop, 1,
        marks=pytest.mark.xfail(
            reason="blpop gets connection first and blocks"),
    ),
    (simple_get_set, 10),
    (pipeline, 10),
    (transaction, 10),
    (blocking_pop, 10),
], ids=lambda o: getattr(o, '__name__', repr(o)))
async def test_operations(pool_or_redis, test_case, pool_size):
    """Run ``test_case`` 100 times concurrently against one pool-like client.

    Every scheduled task must complete without raising. If any fail, the
    collected exceptions are shown in the assertion message.
    """
    repeat = 100
    client = await pool_or_redis(pool_size)
    tasks = [asyncio.ensure_future(test_case(client, idx))
             for idx in range(repeat)]
    done, pending = await asyncio.wait(tasks)

    assert not pending
    outcomes = [fut.exception() for fut in done]
    failures = [exc for exc in outcomes if exc is not None]
    success = len(outcomes) - len(failures)
    # With nothing pending, full success also implies an empty failure list.
    assert repeat == success, failures
100