import asyncio
import pickle
from urllib.parse import unquote, urlparse
4
5try:
6    import aioredis
7except ImportError:
8    aioredis = None
9
10from .asyncio_pubsub_manager import AsyncPubSubManager
11
12
13def _parse_redis_url(url):
14    p = urlparse(url)
15    if p.scheme not in {'redis', 'rediss'}:
16        raise ValueError('Invalid redis url')
17    ssl = p.scheme == 'rediss'
18    host = p.hostname or 'localhost'
19    port = p.port or 6379
20    password = p.password
21    if p.path:
22        db = int(p.path[1:])
23    else:
24        db = 0
25    return host, port, password, db, ssl
26
27
class AsyncRedisManager(AsyncPubSubManager):  # pragma: no cover
    """Redis based client manager for asyncio servers.

    This class implements a Redis backend for event sharing across multiple
    processes. Only kept here as one more example of how to build a custom
    backend, since the kombu backend is perfectly adequate to support a Redis
    message queue.

    To use a Redis backend, initialize the :class:`AsyncServer` instance as
    follows::

        server = socketio.AsyncServer(
            client_manager=socketio.AsyncRedisManager(
                'redis://hostname:port/0'))

    :param url: The connection URL for the Redis server. For a default Redis
                store running on the same host, use ``redis://``.  To use an
                SSL connection, use ``rediss://``.
    :param channel: The channel name on which the server sends and receives
                    notifications. Must be the same in all the servers.
    :param write_only: If set to ``True``, only initialize to emit events. The
                       default of ``False`` initializes the class for emitting
                       and receiving.
    :param logger: Logger passed through to the base manager.
    """
    name = 'aioredis'

    def __init__(self, url='redis://localhost:6379/0', channel='socketio',
                 write_only=False, logger=None):
        if aioredis is None:
            raise RuntimeError('Redis package is not installed '
                               '(Run "pip install aioredis" in your '
                               'virtualenv).')
        (
            self.host, self.port, self.password, self.db, self.ssl
        ) = _parse_redis_url(url)
        self.pub = None  # publishing connection, created lazily
        self.sub = None  # subscribing connection, created lazily
        self.ch = None  # subscribed channel object, bound to self.sub
        super().__init__(channel=channel, write_only=write_only, logger=logger)

    async def _redis_connect(self):
        """Open a new connection to the configured Redis server."""
        return await aioredis.create_redis(
            (self.host, self.port), db=self.db,
            password=self.password, ssl=self.ssl
        )

    async def _publish(self, data):
        """Pickle ``data`` and publish it on the channel.

        The connection is opened lazily. On a Redis or OS error the connection
        is discarded and the publish is retried once on a fresh connection; if
        that also fails an error is logged and ``None`` is returned.
        """
        retry = True
        while True:
            try:
                if self.pub is None:
                    self.pub = await self._redis_connect()
                return await self.pub.publish(self.channel,
                                              pickle.dumps(data))
            except (aioredis.RedisError, OSError):
                if retry:
                    self._get_logger().error('Cannot publish to redis... '
                                             'retrying')
                    self.pub = None
                    retry = False
                else:
                    self._get_logger().error('Cannot publish to redis... '
                                             'giving up')
                    break

    async def _listen(self):
        """Wait for and return the next raw message from the channel.

        The connection and subscription are created lazily. On a Redis or OS
        error the connection is discarded and re-established after a backoff
        that starts at 1 second, doubles per failure and is capped at 60.
        """
        retry_sleep = 1
        while True:
            try:
                if self.sub is None:
                    self.sub = await self._redis_connect()
                    # Subscribe once per connection; the returned channel is
                    # reused for every later call instead of re-sending
                    # SUBSCRIBE for each received message.
                    self.ch = (await self.sub.subscribe(self.channel))[0]
                    retry_sleep = 1
                return await self.ch.get()
            except (aioredis.RedisError, OSError):
                self._get_logger().error('Cannot receive from redis... '
                                         'retrying in '
                                         '{} secs'.format(retry_sleep))
                # Dropping the connection forces reconnect + resubscribe.
                self.sub = None
                await asyncio.sleep(retry_sleep)
                retry_sleep = min(retry_sleep * 2, 60)
109