import json
import uuid

import aioredis
from . import defaults
from .base_cache import BaseCache, CacheEntry

# Remove once fixed: https://github.com/aio-libs/aioredis-py/issues/1115
aioredis.Redis.__del__ = lambda *args: None  # type: ignore

# Length in bytes of the random prefix pack_entry prepends (uuid4().bytes).
_SEED_LEN = 16


def pack_entry(entry):
    """Serialize a cache entry into a sorted-set member.

    Only ``pol_id`` and ``pol_body`` are stored in the member; the entry's
    timestamp is stored as the member's score instead (see ``RedisCache.set``).

    A random 16-byte UUID prefix is prepended so that two members with the
    same JSON payload never collide in the set.
    """
    _ts, pol_id, pol_body = entry
    payload = json.dumps((pol_id, pol_body)).encode('utf-8')
    return uuid.uuid4().bytes + payload


def unpack_entry(packed):
    """Inverse of ``pack_entry``: strip the random prefix and decode the JSON.

    The timestamp is not part of the member, so the returned entry always has
    ``ts=0``; callers needing the real timestamp must take it from the score.
    """
    pol_id, pol_body = json.loads(packed[_SEED_LEN:].decode('utf-8'))
    return CacheEntry(ts=0, pol_id=pol_id, pol_body=pol_body)


class RedisCache(BaseCache):
    """Cache backed by Redis.

    Each cache key maps to a sorted set whose member is the packed
    ``(pol_id, pol_body)`` payload, scored by the entry timestamp; ``set``
    trims the set so only the highest-scored member survives.  A separate
    ``_metadata`` hash holds the proactive-fetch timestamp.
    """

    def __init__(self, **opts):
        """Remember connection options for ``setup``.

        ``opts`` must contain ``url`` (read in ``setup``); every other option
        is forwarded to ``aioredis.from_url``.  Socket timeouts default to
        ``defaults.REDIS_TIMEOUT`` / ``defaults.REDIS_CONNECT_TIMEOUT`` and the
        encoding is always forced to UTF-8.
        """
        self._opts = dict(opts)
        self._opts.setdefault('socket_timeout', defaults.REDIS_TIMEOUT)
        self._opts.setdefault('socket_connect_timeout',
                              defaults.REDIS_CONNECT_TIMEOUT)
        self._opts['encoding'] = 'utf-8'
        self._pool = None

    async def setup(self):
        """Create the Redis client from the stored options."""
        url = self._opts['url']
        opts = {k: v for k, v in self._opts.items() if k != 'url'}
        self._pool = aioredis.from_url(url, **opts)

    async def get(self, key):
        """Return the stored entry for ``key``, or None if there is none.

        The returned entry's ``ts`` is the member's sorted-set score.
        """
        assert self._pool is not None
        res = await self._pool.zrevrange(key.encode('utf-8'), 0, 0,
                                         withscores=True)
        if not res:
            return None
        packed, ts = res[0]  # pylint: disable=invalid-name
        entry = unpack_entry(packed)
        return CacheEntry(ts=ts, pol_id=entry.pol_id, pol_body=entry.pol_body)

    async def set(self, key, value):
        """Store ``value`` under ``key``, keeping only the highest-ts entry.

        ZADD followed by ZREMRANGEBYRANK(0, -2) removes every member except
        the highest-scored one, so a ``value`` older than the existing entry
        is discarded.  Both commands run in a transactional pipeline.
        """
        assert self._pool is not None
        packed = pack_entry(value)
        ts = value.ts  # pylint: disable=invalid-name
        key = key.encode('utf-8')

        async with self._pool.pipeline(transaction=True) as pipe:
            pipe.zadd(key, {packed: ts})
            pipe.zremrangebyrank(key, 0, -2)
            await pipe.execute()

    async def scan(self, token, amount_hint):
        """Return one batch of cached entries using Redis SCAN.

        :param token: cursor from the previous call, or None to start over.
        :param amount_hint: COUNT hint passed to SCAN (Redis may return more
            or fewer keys).
        :returns: ``(next_token, entries)``; ``next_token`` is None once the
            iteration is complete and ``entries`` is a list of
            ``(key, CacheEntry)`` pairs.  The ``_metadata`` key is skipped,
            as are keys that vanish between SCAN and the follow-up read.
        """
        assert self._pool is not None
        if token is None:
            token = b'0'

        new_token, keys = await self._pool.scan(cursor=token, count=amount_hint)
        # SCAN reports the end of an iteration with cursor 0.
        if not new_token:
            new_token = None

        result = []
        for raw_key in keys:
            key = raw_key.decode('utf-8')
            if key == '_metadata':
                continue
            entry = await self.get(key)
            # The key may have been deleted after SCAN listed it; don't emit
            # a (key, None) pair in that case.
            if entry is not None:
                result.append((key, entry))
        return new_token, result

    async def get_proactive_fetch_ts(self):
        """Return the stored proactive-fetch timestamp, or 0 if unset."""
        assert self._pool is not None
        val = await self._pool.hget('_metadata', 'proactive_fetch_ts')
        return 0 if not val else float(val.decode('utf-8'))

    async def set_proactive_fetch_ts(self, timestamp):
        """Persist ``timestamp`` (as its ``str`` form) in the metadata hash."""
        assert self._pool is not None
        val = str(timestamp).encode('utf-8')
        await self._pool.hset('_metadata', 'proactive_fetch_ts', val)

    async def teardown(self):
        """Close the Redis client created by ``setup``."""
        assert self._pool is not None
        await self._pool.close()