1import json
2import uuid
3
4import aioredis
5from . import defaults
6from .base_cache import BaseCache, CacheEntry
7
# Workaround: aioredis' Redis.__del__ can raise during interpreter
# shutdown, so replace it with a no-op.
# Remove once fixed: https://github.com/aio-libs/aioredis-py/issues/1115
aioredis.Redis.__del__ = lambda *args: None  # type: ignore
10
def pack_entry(entry):
    """Serialize a cache entry for storage as a Redis ZSET member.

    *entry* is a ``(ts, pol_id, pol_body)`` triple; the timestamp is
    carried separately as the ZSET score, so only ``(pol_id, pol_body)``
    is JSON-encoded here.  A random 16-byte UUID prefix makes every
    member unique, avoiding ZSET member collisions and enabling the
    two-index (member + score) table layout.
    """
    _ts, pol_id, pol_body = entry
    payload = json.dumps((pol_id, pol_body)).encode('utf-8')
    return uuid.uuid4().bytes + payload
18
19
def unpack_entry(packed):
    """Decode a ZSET member produced by :func:`pack_entry`.

    Skips the 16-byte random UUID prefix and JSON-decodes the rest.
    The timestamp is not stored in the member (it lives in the ZSET
    score), so ``ts`` is returned as 0 and must be filled in by the
    caller if needed.
    """
    pol_id, pol_body = json.loads(packed[16:].decode('utf-8'))
    return CacheEntry(ts=0, pol_id=pol_id, pol_body=pol_body)
25
26
class RedisCache(BaseCache):
    """Policy cache backed by Redis.

    Layout: each cache key maps to a ZSET that holds at most one packed
    entry, scored by the entry timestamp (the ZADD + ZREMRANGEBYRANK
    pipeline in :meth:`set` enforces the "at most one" invariant).  A
    separate ``_metadata`` hash stores bookkeeping values such as the
    proactive-fetch timestamp.
    """

    def __init__(self, **opts):
        """Store connection options; no I/O happens until :meth:`setup`.

        Recognized options include ``url`` (required by setup) plus any
        keyword accepted by ``aioredis.from_url``.  Timeouts default to
        the values from :mod:`defaults`.
        """
        # Copy so we never mutate the caller's dict.
        self._opts = dict(opts)
        self._opts.setdefault('socket_timeout', defaults.REDIS_TIMEOUT)
        self._opts.setdefault('socket_connect_timeout',
                              defaults.REDIS_CONNECT_TIMEOUT)
        self._opts['encoding'] = 'utf-8'
        self._pool = None

    async def setup(self):
        """Create the Redis client/connection pool from the 'url' option."""
        url = self._opts['url']
        opts = {k: v for k, v in self._opts.items() if k != 'url'}
        self._pool = aioredis.from_url(url, **opts)

    async def get(self, key):
        """Return the newest :class:`CacheEntry` stored under *key*.

        Returns ``None`` when the key does not exist.
        """
        assert self._pool is not None
        key = key.encode('utf-8')
        # Highest score == newest entry; withscores yields the timestamp.
        res = await self._pool.zrevrange(key, 0, 0, withscores=True)
        if not res:
            return None
        packed, ts = res[0]  # pylint: disable=invalid-name
        entry = unpack_entry(packed)
        return CacheEntry(ts=ts, pol_id=entry.pol_id, pol_body=entry.pol_body)

    async def set(self, key, value):
        """Store *value* (a CacheEntry) under *key*, keeping only the newest.

        The member is ZADDed with the entry timestamp as score, then all
        but the highest-scored member are removed — both steps in one
        MULTI/EXEC transaction so concurrent writers cannot leave more
        than one entry behind.
        """
        assert self._pool is not None
        packed = pack_entry(value)
        ts = value.ts  # pylint: disable=invalid-name
        key = key.encode('utf-8')

        # Write
        async with self._pool.pipeline(transaction=True) as pipe:
            pipe.zadd(key, {packed: ts})
            pipe.zremrangebyrank(key, 0, -2)
            await pipe.execute()

    async def scan(self, token, amount_hint):
        """Iterate the keyspace one SCAN page at a time.

        *token* is the opaque cursor from the previous call (``None`` to
        start); *amount_hint* is passed as SCAN's COUNT hint.  Returns
        ``(new_token, [(key, CacheEntry), ...])`` where ``new_token`` is
        ``None`` once iteration is complete.
        """
        assert self._pool is not None
        if token is None:
            token = b'0'  # start a fresh SCAN cursor

        new_token, keys = await self._pool.scan(cursor=token, count=amount_hint)
        if not new_token:
            # Redis returns cursor 0 when the iteration has finished.
            new_token = None

        result = []
        for key in keys:
            key = key.decode('utf-8')
            if key != '_metadata':
                entry = await self.get(key)
                # The key may have been deleted between SCAN and the
                # follow-up read; skip it instead of yielding (key, None).
                if entry is not None:
                    result.append((key, entry))
        return new_token, result

    async def get_proactive_fetch_ts(self):
        """Return the stored proactive-fetch timestamp, or 0 if unset."""
        assert self._pool is not None
        val = await self._pool.hget('_metadata', 'proactive_fetch_ts')
        return 0 if not val else float(val.decode('utf-8'))

    async def set_proactive_fetch_ts(self, timestamp):
        """Persist *timestamp* in the '_metadata' hash."""
        assert self._pool is not None
        val = str(timestamp).encode('utf-8')
        await self._pool.hset('_metadata', 'proactive_fetch_ts', val)

    async def teardown(self):
        """Close the Redis client and release pooled connections."""
        assert self._pool is not None
        await self._pool.close()
        # NOTE(review): close() alone may leave pool connections open in
        # redis-py/aioredis 2.x; disconnect the pool explicitly so no
        # sockets leak on shutdown.
        await self._pool.connection_pool.disconnect()