1import asyncio
2import collections
3import types
4import warnings
5import sys
6
7from .connection import create_connection, _PUBSUB_COMMANDS
8from .log import logger
9from .util import parse_url, CloseEvent
10from .errors import PoolClosedError
11from .abc import AbcPool
12from .locks import Lock
13
14
async def create_pool(address, *, db=None, password=None, ssl=None,
                      encoding=None, minsize=1, maxsize=10,
                      parser=None, loop=None, create_connection_timeout=None,
                      pool_cls=None, connection_cls=None):
    """Create a Redis connections pool and pre-fill it up to *minsize*.

    :param address: server address handed to the pool (and from there to
        ``create_connection``). If it is a ``str`` it is parsed with
        ``parse_url``; options found in the URL (``db``, ``password``,
        ``encoding``, ``timeout``, ``ssl``) take precedence over the
        corresponding keyword arguments.
    :param db: database index for every pooled connection.
    :param password: password for every pooled connection.
    :param ssl: ssl option for every pooled connection. If the URL also
        sets ``ssl``, a falsy URL value combined with a truthy argument is
        rejected with ``AssertionError``.
    :param encoding: codec passed to every pooled connection.
    :param minsize: number of connections opened before returning.
    :param maxsize: upper bound on the number of pooled connections.
    :param parser: parser class passed to every connection.
    :param loop: deprecated; only forwarded to the pool class.
    :param create_connection_timeout: timeout for opening each connection.
    :param pool_cls: pool class to instantiate; must subclass ``AbcPool``
        and accept the same constructor arguments as ``ConnectionsPool``.
        Defaults to ``ConnectionsPool``.
    :param connection_cls: connection class passed to the pool.

    Returns the pool instance. If the initial fill fails (or is cancelled)
    the pool is closed before the exception is re-raised, so no
    connections are leaked.
    """
    if pool_cls:
        assert issubclass(pool_cls, AbcPool),\
                "pool_class does not meet the AbcPool contract"
        cls = pool_cls
    else:
        cls = ConnectionsPool
    if isinstance(address, str):
        address, options = parse_url(address)
        # setdefault() keeps a value already present in the URL, so URL
        # options win over keyword arguments.
        db = options.setdefault('db', db)
        password = options.setdefault('password', password)
        encoding = options.setdefault('encoding', encoding)
        create_connection_timeout = options.setdefault(
            'timeout', create_connection_timeout)
        if 'ssl' in options:
            # A URL that disables ssl conflicts with an explicit ssl argument.
            assert options['ssl'] or not ssl, (
                "Conflicting ssl options are set", options['ssl'], ssl)
            ssl = ssl or options['ssl']
        # TODO: minsize/maxsize

    pool = cls(address, db, password, encoding,
               minsize=minsize, maxsize=maxsize,
               ssl=ssl, parser=parser,
               create_connection_timeout=create_connection_timeout,
               connection_cls=connection_cls,
               loop=loop)
    try:
        await pool._fill_free(override_min=False)
    except (Exception, asyncio.CancelledError):
        # CancelledError derives from BaseException since Python 3.8; it
        # must also trigger cleanup or already-opened connections leak.
        pool.close()
        await pool.wait_closed()
        raise
    return pool
64
65
class ConnectionsPool(AbcPool):
    """Redis connections pool.

    Bookkeeping:

    * ``_pool`` -- deque of free connections, bounded by ``maxsize``;
    * ``_used`` -- set of connections handed out by ``acquire()``, plus a
      pub/sub connection taken from the free pool by ``get_connection()``;
    * ``_acquiring`` -- number of connections currently being created;
    * ``_pubsub_conn`` -- connection reserved for (p)subscribe commands.

    ``size`` is ``len(_pool) + len(_used) + _acquiring``.
    """

    def __init__(self, address, db=None, password=None, encoding=None,
                 *, minsize, maxsize, ssl=None, parser=None,
                 create_connection_timeout=None,
                 connection_cls=None,
                 loop=None):
        # ``loop`` is accepted for backward compatibility only: it is not
        # stored, and on Python 3.8+ passing it emits a DeprecationWarning.
        assert isinstance(minsize, int) and minsize >= 0, (
            "minsize must be int >= 0", minsize, type(minsize))
        assert maxsize is not None, "Arbitrary pool size is disallowed."
        assert isinstance(maxsize, int) and maxsize > 0, (
            "maxsize must be int > 0", maxsize, type(maxsize))
        assert minsize <= maxsize, (
            "Invalid pool min/max sizes", minsize, maxsize)
        if loop is not None and sys.version_info >= (3, 8):
            warnings.warn("The loop argument is deprecated",
                          DeprecationWarning)
        self._address = address
        self._db = db
        self._password = password
        self._ssl = ssl
        self._encoding = encoding
        self._parser_class = parser
        self._minsize = minsize
        self._create_connection_timeout = create_connection_timeout
        # maxlen doubles as the pool's maxsize (see the ``maxsize`` property).
        self._pool = collections.deque(maxlen=maxsize)
        self._used = set()
        self._acquiring = 0
        self._cond = asyncio.Condition(lock=Lock())
        self._close_state = CloseEvent(self._do_close)
        self._pubsub_conn = None
        self._connection_cls = connection_cls

    def __repr__(self):
        return '<{} [db:{}, size:[{}:{}], free:{}]>'.format(
            self.__class__.__name__, self.db,
            self.minsize, self.maxsize, self.freesize)

    @property
    def minsize(self):
        """Minimum pool size."""
        return self._minsize

    @property
    def maxsize(self):
        """Maximum pool size."""
        return self._pool.maxlen

    @property
    def size(self):
        """Current pool size (free + used + being created)."""
        return self.freesize + len(self._used) + self._acquiring

    @property
    def freesize(self):
        """Current number of free connections."""
        return len(self._pool)

    @property
    def address(self):
        return self._address

    async def clear(self):
        """Clear pool connections.

        Close and remove all free connections.
        """
        async with self._cond:
            await self._do_clear()

    async def _do_clear(self):
        # Close every free connection and wait for all of them to finish
        # closing; the caller is expected to hold ``self._cond``.
        waiters = []
        while self._pool:
            conn = self._pool.popleft()
            conn.close()
            waiters.append(conn.wait_closed())
        await asyncio.gather(*waiters)

    async def _do_close(self):
        # Close free, used and pub/sub connections; invoked through
        # ``self._close_state`` when the pool is closed.
        async with self._cond:
            assert not self._acquiring, self._acquiring
            waiters = []
            while self._pool:
                conn = self._pool.popleft()
                conn.close()
                waiters.append(conn.wait_closed())
            for conn in self._used:
                conn.close()
                waiters.append(conn.wait_closed())
            # A pub/sub connection created by _wait_execute_pubsub() is not
            # tracked in ``_pool`` or ``_used``; close it explicitly.
            pubsub_conn = self._pubsub_conn
            if pubsub_conn is not None and pubsub_conn not in self._used:
                pubsub_conn.close()
                waiters.append(pubsub_conn.wait_closed())
            await asyncio.gather(*waiters)
            logger.debug("Closed %d connection(s)", len(waiters))

    def close(self):
        """Close all free and in-progress connections and mark pool as closed.
        """
        if not self._close_state.is_set():
            self._close_state.set()

    @property
    def closed(self):
        """True if pool is closed."""
        return self._close_state.is_set()

    async def wait_closed(self):
        """Wait until pool gets closed."""
        await self._close_state.wait()

    @property
    def db(self):
        """Currently selected db index."""
        return self._db or 0

    @property
    def encoding(self):
        """Current set codec or None."""
        return self._encoding

    def execute(self, command, *args, **kw):
        """Executes redis command in a free connection and returns
        future waiting for result.

        Picks connection from free pool and send command through
        that connection.
        If no connection is found, returns coroutine waiting for
        free connection to execute command.
        """
        conn, address = self.get_connection(command, args)
        if conn is not None:
            fut = conn.execute(command, *args, **kw)
            return self._check_result(fut, command, args, kw)
        else:
            coro = self._wait_execute(address, command, args, kw)
            return self._check_result(coro, command, args, kw)

    def execute_pubsub(self, command, *channels):
        """Executes Redis (p)subscribe/(p)unsubscribe commands.

        ConnectionsPool picks separate connection for pub/sub
        and uses it until explicitly closed or disconnected
        (unsubscribing from all channels/patterns will leave connection
         locked for pub/sub use).

        There is no auto-reconnect for this PUB/SUB connection.

        Returns asyncio.gather coroutine waiting for all channels/patterns
        to receive answers.
        """
        conn, address = self.get_connection(command)
        if conn is not None:
            return conn.execute_pubsub(command, *channels)
        else:
            return self._wait_execute_pubsub(address, command, channels, {})

    def get_connection(self, command, args=()):
        """Get free connection from pool.

        Returns a ``(connection, address)`` tuple. ``connection`` is
        ``None`` when no usable free connection exists; the caller must
        then wait for one (``address`` is the pool address in that case).
        For pub/sub commands the chosen connection is moved out of the
        free pool and kept as the dedicated pub/sub connection.
        """
        # TODO: find a better way to determine if connection is free
        #       and not heavily used.
        command = command.upper().strip()
        is_pubsub = command in _PUBSUB_COMMANDS
        if is_pubsub and self._pubsub_conn:
            if not self._pubsub_conn.closed:
                return self._pubsub_conn, self._pubsub_conn.address
            # Dead pub/sub connection. If it came from the free pool it sits
            # in ``_used`` and is never released; drop it from there too,
            # otherwise it keeps counting towards ``size`` forever.
            self._used.discard(self._pubsub_conn)
            self._pubsub_conn = None
        for _ in range(self.freesize):
            conn = self._pool[0]
            # Rotate while scanning so repeated calls spread commands over
            # different free connections.
            self._pool.rotate(1)
            if conn.closed:  # or conn._waiters: (eg: busy connection)
                continue
            if conn.in_pubsub:
                continue
            if is_pubsub:
                self._pubsub_conn = conn
                self._pool.remove(conn)
                self._used.add(conn)
            return conn, conn.address
        return None, self._address  # figure out

    def _check_result(self, fut, *data):
        """Hook to check result or catch exception (like MovedError).

        This method can be coroutine.
        """
        return fut

    async def _wait_execute(self, address, command, args, kw):
        """Acquire connection and execute command."""
        conn = await self.acquire(command, args)
        try:
            return (await conn.execute(command, *args, **kw))
        finally:
            self.release(conn)

    async def _wait_execute_pubsub(self, address, command, args, kw):
        # Create (under the pool lock) a dedicated pub/sub connection if
        # there is no live one, then run the pub/sub command on it.
        if self.closed:
            raise PoolClosedError("Pool is closed")
        assert self._pubsub_conn is None or self._pubsub_conn.closed, (
            "Expected no or closed connection", self._pubsub_conn)
        async with self._cond:
            if self.closed:
                raise PoolClosedError("Pool is closed")
            if self._pubsub_conn is None or self._pubsub_conn.closed:
                if self._pubsub_conn is not None:
                    # Replacing a dead connection: stop counting it as used.
                    self._used.discard(self._pubsub_conn)
                conn = await self._create_new_connection(address)
                self._pubsub_conn = conn
            conn = self._pubsub_conn
            return (await conn.execute_pubsub(command, *args, **kw))

    async def select(self, db):
        """Changes db index for all free connections.

        All previously acquired connections will be closed when released.
        Returns the combined (``and``-ed) result of every ``select`` call,
        or ``True`` when there are no free connections.
        """
        res = True
        async with self._cond:
            # Iterate a snapshot: get_connection()/release() mutate the
            # deque without holding the lock, so indexing it across awaits
            # could skip or repeat connections.
            for conn in list(self._pool):
                ok = await conn.select(db)
                # No short-circuit: every free connection must be switched.
                res = res and ok
            self._db = db
        return res

    async def auth(self, password):
        """Set the password for new connections and authenticate every
        free connection with it."""
        self._password = password
        async with self._cond:
            # Snapshot for the same reason as in select().
            for conn in list(self._pool):
                await conn.auth(password)

    @property
    def in_pubsub(self):
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.in_pubsub
        return 0

    @property
    def pubsub_channels(self):
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.pubsub_channels
        return types.MappingProxyType({})

    @property
    def pubsub_patterns(self):
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.pubsub_patterns
        return types.MappingProxyType({})

    async def acquire(self, command=None, args=()):
        """Acquires a connection from free pool.

        Creates new connection if needed; waits on the pool condition
        while the pool is at ``maxsize`` with no free connection.

        :raises PoolClosedError: if the pool is closed.
        """
        if self.closed:
            raise PoolClosedError("Pool is closed")
        async with self._cond:
            if self.closed:
                raise PoolClosedError("Pool is closed")
            while True:
                await self._fill_free(override_min=True)
                if self.freesize:
                    conn = self._pool.popleft()
                    assert not conn.closed, conn
                    assert conn not in self._used, (conn, self._used)
                    self._used.add(conn)
                    return conn
                else:
                    await self._cond.wait()

    def release(self, conn):
        """Returns used connection back into pool.

        When returned connection has db index that differs from one in pool
        the connection will be closed and dropped.
        When queue of free connections is full the connection will be dropped.
        Connections in a transaction, in subscribe mode or with pending
        commands are closed as well.
        """
        assert conn in self._used, (
            "Invalid connection, maybe from other pool", conn)
        self._used.remove(conn)
        if not conn.closed:
            if conn.in_transaction:
                logger.warning(
                    "Connection %r is in transaction, closing it.", conn)
                conn.close()
            elif conn.in_pubsub:
                logger.warning(
                    "Connection %r is in subscribe mode, closing it.", conn)
                conn.close()
            elif conn._waiters:
                logger.warning(
                    "Connection %r has pending commands, closing it.", conn)
                conn.close()
            elif conn.db == self.db:
                if self.maxsize and self.freesize < self.maxsize:
                    self._pool.append(conn)
                else:
                    # consider this connection as old and close it.
                    conn.close()
            else:
                conn.close()
        # FIXME: check event loop is not closed
        # Wake one acquire() waiter: a slot or a free connection is available.
        asyncio.ensure_future(self._wakeup())

    def _drop_closed(self):
        # Remove closed connections from the free deque; the full rotation
        # leaves the remaining connections in their original order.
        for _ in range(self.freesize):
            conn = self._pool[0]
            if conn.closed:
                self._pool.popleft()
            else:
                self._pool.rotate(-1)

    async def _fill_free(self, *, override_min):
        # Top the pool up to ``minsize``; with ``override_min`` also create
        # one more connection (up to ``maxsize``) when none is free.
        # drop closed connections first
        self._drop_closed()
        while self.size < self.minsize:
            self._acquiring += 1
            try:
                conn = await self._create_new_connection(self._address)
                try:
                    # check the health of that connection; if something
                    # went wrong just propagate the exception
                    await conn.execute('ping')
                except BaseException:
                    # The connection won't be pooled: close it instead of
                    # leaking its socket, then re-raise.
                    conn.close()
                    raise
                self._pool.append(conn)
            finally:
                self._acquiring -= 1
                # connection may be closed at yield point
                self._drop_closed()
        if self.freesize:
            return
        if override_min:
            while not self._pool and self.size < self.maxsize:
                self._acquiring += 1
                try:
                    conn = await self._create_new_connection(self._address)
                    self._pool.append(conn)
                finally:
                    self._acquiring -= 1
                    # connection may be closed at yield point
                    self._drop_closed()

    def _create_new_connection(self, address):
        # Returns the awaitable from create_connection() configured with the
        # pool-wide connection settings.
        return create_connection(address,
                                 db=self._db,
                                 password=self._password,
                                 ssl=self._ssl,
                                 encoding=self._encoding,
                                 parser=self._parser_class,
                                 timeout=self._create_connection_timeout,
                                 connection_cls=self._connection_cls,
                                 )

    async def _wakeup(self, closing_conn=None):
        # Notify one waiter on the pool condition; optionally wait for a
        # connection to finish closing afterwards.
        async with self._cond:
            self._cond.notify()
        if closing_conn is not None:
            await closing_conn.wait_closed()

    def __enter__(self):
        raise RuntimeError(
            "'await' should be used as a context manager expression")

    def __exit__(self, *args):
        pass    # pragma: nocover

    def __await__(self):
        # To make `with await pool` work
        conn = yield from self.acquire().__await__()
        return _ConnectionContextManager(self, conn)

    def get(self):
        '''Return async context manager for working with connection.

        async with pool.get() as conn:
            await conn.execute('get', 'my-key')
        '''
        return _AsyncConnectionContextManager(self)
441
442
443class _ConnectionContextManager:
444
445    __slots__ = ('_pool', '_conn')
446
447    def __init__(self, pool, conn):
448        self._pool = pool
449        self._conn = conn
450
451    def __enter__(self):
452        return self._conn
453
454    def __exit__(self, exc_type, exc_value, tb):
455        try:
456            self._pool.release(self._conn)
457        finally:
458            self._pool = None
459            self._conn = None
460
461
462class _AsyncConnectionContextManager:
463
464    __slots__ = ('_pool', '_conn')
465
466    def __init__(self, pool):
467        self._pool = pool
468        self._conn = None
469
470    async def __aenter__(self):
471        conn = await self._pool.acquire()
472        self._conn = conn
473        return self._conn
474
475    async def __aexit__(self, exc_type, exc_value, tb):
476        try:
477            self._pool.release(self._conn)
478        finally:
479            self._pool = None
480            self._conn = None
481