1from __future__ import unicode_literals
2from distutils.version import StrictVersion
3from itertools import chain
4from time import time
5import errno
6import io
7import os
8import socket
9import threading
10import warnings
12from redis._compat import (xrange, imap, unicode, long,
13                           nativestr, basestring, iteritems,
14                           LifoQueue, Empty, Full, urlparse, parse_qs,
15                           recv, recv_into, unquote, BlockingIOError,
16                           sendall, shutdown, ssl_wrap_socket)
17from redis.exceptions import (
18    AuthenticationError,
19    AuthenticationWrongNumberOfArgsError,
20    BusyLoadingError,
21    ChildDeadlockedError,
22    ConnectionError,
23    DataError,
24    ExecAbortError,
25    InvalidResponse,
26    NoPermissionError,
27    NoScriptError,
28    ReadOnlyError,
29    RedisError,
30    ResponseError,
31    TimeoutError,
33from redis.utils import HIREDIS_AVAILABLE
36    import ssl
37    ssl_available = True
38except ImportError:
39    ssl_available = False
42    BlockingIOError: errno.EWOULDBLOCK,
45if ssl_available:
46    if hasattr(ssl, 'SSLWantReadError'):
49    else:
52# In Python 2.7 a socket.error is raised for a nonblocking read.
53# The _compat module aliases BlockingIOError to socket.error to be
54# Python 2/3 compatible.
55# However this means that all socket.error exceptions need to be handled
56# properly within these exception handlers.
57# We need to make sure socket.error is included in these handlers and
58# provide a dummy error number that will never match a real exception.
60    NONBLOCKING_EXCEPTION_ERROR_NUMBERS[socket.error] = -999999
65    import hiredis
67    hiredis_version = StrictVersion(hiredis.__version__)
69        hiredis_version >= StrictVersion('0.1.3')
71        hiredis_version >= StrictVersion('0.1.4')
73        hiredis_version >= StrictVersion('1.0.0')
76        msg = ("redis-py works best with hiredis >= 0.1.4. You're running "
77               "hiredis %s. Please consider upgrading." % hiredis.__version__)
78        warnings.warn(msg)
81    # only use byte buffer if hiredis supports it
# Redis protocol framing tokens used by Connection.pack_command() and
# SocketBuffer.readline()
SYM_STAR = b'*'
SYM_DOLLAR = b'$'
SYM_CRLF = b'\r\n'
SYM_EMPTY = b''

# message used whenever a read shows the server closed the socket
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."

# unique "argument not supplied" marker; lets callers pass timeout=None
# explicitly, which is a meaningful socket timeout
SENTINEL = object()
class Encoder(object):
    "Encode strings to bytes-like and decode bytes-like to strings"

    def __init__(self, encoding, encoding_errors, decode_responses):
        # codec name and error handler used for every encode/decode
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        # when true, decode() converts bytes even without force=True
        self.decode_responses = decode_responses

    def encode(self, value):
        """Return a bytestring or bytes-like representation of ``value``.

        bytes and memoryview objects are returned untouched, floats and
        integers are rendered as their textual digits, and text is encoded
        with the configured codec. bool and any other type raise DataError.
        """
        if isinstance(value, (bytes, memoryview)):
            return value
        if isinstance(value, bool):
            # bool subclasses int, so it has to be rejected before ints
            raise DataError("Invalid input of type: 'bool'. Convert to a "
                            "bytes, string, int or float first.")
        if isinstance(value, float):
            value = repr(value).encode()
        elif isinstance(value, (int, long)):
            # python 2 repr() on longs is '123L', so use str() instead
            value = str(value).encode()
        elif not isinstance(value, basestring):
            raise DataError("Invalid input of type: '%s'. Convert to a "
                            "bytes, string, int or float first."
                            % type(value).__name__)
        if not isinstance(value, unicode):
            return value
        return value.encode(self.encoding, self.encoding_errors)

    def decode(self, value, force=False):
        """Return a unicode string from the bytes-like representation.

        Decoding only happens when ``decode_responses`` is set or ``force``
        is true; memoryviews are copied to bytes first. Values that are not
        bytes-like are returned unchanged.
        """
        if not (self.decode_responses or force):
            return value
        if isinstance(value, memoryview):
            value = value.tobytes()
        if not isinstance(value, bytes):
            return value
        return value.decode(self.encoding, self.encoding_errors)
class BaseParser(object):
    "Shared error-reply handling for the Python and hiredis parsers"
    # Maps an error code (the first word of an error reply) to the exception
    # class to build. A dict value further maps the rest of the message to a
    # class; messages not listed there become a plain ResponseError.
    EXCEPTION_CLASSES = {
        'ERR': {
            'max number of clients reached': ConnectionError,
            'Client sent AUTH, but no password is set': AuthenticationError,
            'invalid password': AuthenticationError,
            # some Redis server versions report invalid command syntax
            # in lowercase
            'wrong number of arguments for \'auth\' command':
                AuthenticationWrongNumberOfArgsError,
            # some Redis server versions report invalid command syntax
            # in uppercase
            'wrong number of arguments for \'AUTH\' command':
                AuthenticationWrongNumberOfArgsError,
        },
        'EXECABORT': ExecAbortError,
        'LOADING': BusyLoadingError,
        'NOSCRIPT': NoScriptError,
        'READONLY': ReadOnlyError,
        'NOAUTH': AuthenticationError,
        'NOPERM': NoPermissionError,
        # Redis 6+ reports a failed AUTH as
        # "WRONGPASS invalid username-password pair" rather than
        # "ERR invalid password"
        'WRONGPASS': AuthenticationError,
    }

    def parse_error(self, response):
        """Return (not raise) the exception instance for an error reply.

        ``response`` is the error text without its leading '-', e.g.
        ``'NOSCRIPT No matching script'``. For a known error code the
        exception message is the text after the code; unknown codes yield a
        ResponseError carrying the full text.
        """
        error_code = response.split(' ')[0]
        if error_code in self.EXCEPTION_CLASSES:
            response = response[len(error_code) + 1:]
            exception_class = self.EXCEPTION_CLASSES[error_code]
            if isinstance(exception_class, dict):
                exception_class = exception_class.get(response, ResponseError)
            return exception_class(response)
        return ResponseError(response)
class SocketBuffer(object):
    """Accumulates bytes read from a socket so replies can be parsed.

    Incoming data is appended to an in-memory ``io.BytesIO`` at offset
    ``bytes_written`` and handed to the parser from offset ``bytes_read``.
    Whenever the two offsets meet, every buffered byte has been consumed
    and the buffer is reset to empty so it does not grow forever.
    """

    def __init__(self, socket, socket_read_size, socket_timeout):
        self._sock = socket
        # maximum number of bytes requested per recv() call
        self.socket_read_size = socket_read_size
        # timeout restored on the socket after a call that overrides it
        self.socket_timeout = socket_timeout
        self._buffer = io.BytesIO()
        # offset just past the last byte appended from the socket
        self.bytes_written = 0
        # offset just past the last byte returned to the parser
        self.bytes_read = 0

    @property
    def length(self):
        "Number of buffered bytes that have not been consumed yet."
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None, timeout=SENTINEL,
                          raise_on_timeout=True):
        """Append data from the socket to the end of the buffer.

        Reads once or, when ``length`` is given, keeps reading until at
        least ``length`` new bytes have arrived. A supplied ``timeout``
        temporarily replaces the socket timeout for this call only.

        Returns True once data was read. With ``raise_on_timeout`` false, a
        timeout or an expected non-blocking "no data yet" error returns
        False instead of raising.

        Raises ConnectionError when the server closed the socket or the
        read failed, and TimeoutError on a timeout when
        ``raise_on_timeout`` is true.
        """
        sock = self._sock
        chunk_size = self.socket_read_size
        buf = self._buffer
        # new data always goes after what has already been written
        buf.seek(self.bytes_written)
        received = 0
        override_timeout = timeout is not SENTINEL

        try:
            if override_timeout:
                sock.settimeout(timeout)
            while True:
                chunk = recv(sock, chunk_size)
                # an empty read means the server shut down the socket
                if isinstance(chunk, bytes) and not chunk:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(chunk)
                chunk_length = len(chunk)
                self.bytes_written += chunk_length
                received += chunk_length
                if length is None or received >= length:
                    return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as exc:
            # a non-blocking recv with nothing available raises an error
            # whose errno is registered for its class; that only means
            # "no data yet" when the caller asked us not to raise.
            expected_errno = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(
                exc.__class__, -1)
            if raise_on_timeout or exc.errno != expected_errno:
                raise ConnectionError("Error while reading from socket: %s" %
                                      (exc.args,))
            return False
        finally:
            if override_timeout:
                sock.settimeout(self.socket_timeout)

    def can_read(self, timeout):
        "Return True if unread data is buffered or arrives within timeout."
        if self.length:
            return True
        return self._read_from_socket(timeout=timeout,
                                      raise_on_timeout=False)

    def read(self, length):
        "Return the next ``length`` bytes, also consuming the CRLF after them."
        # the payload is followed by a two-byte \r\n terminator
        wanted = length + 2
        shortfall = wanted - self.length
        if shortfall > 0:
            self._read_from_socket(shortfall)
        buf = self._buffer
        buf.seek(self.bytes_read)
        return self._consume(buf.read(wanted))

    def readline(self):
        "Return the next CRLF-terminated line without its terminator."
        buf = self._buffer
        while True:
            buf.seek(self.bytes_read)
            line = buf.readline()
            if line.endswith(SYM_CRLF):
                break
            # the complete line has not arrived yet; pull more data
            self._read_from_socket()
        return self._consume(line)

    def _consume(self, data):
        """Mark ``data`` as read and return it minus its 2-byte terminator.

        Purges the buffer once everything written has been consumed so it
        doesn't grow forever.
        """
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def purge(self):
        "Discard all buffered data and reset both offsets."
        buf = self._buffer
        buf.seek(0)
        buf.truncate()
        self.bytes_written = self.bytes_read = 0

    def close(self):
        "Release the in-memory buffer and drop the socket reference."
        try:
            self.purge()
            self._buffer.close()
        except Exception:
            # issue #633 reports purge/close raising a bad file descriptor
            # error; since the buffer is discarded below anyway, ignoring
            # any error here is considered safe.
            pass
        self._buffer = None
        self._sock = None
class PythonParser(BaseParser):
    "Plain Python parsing class"
    def __init__(self, socket_read_size):
        # bytes requested per recv() by the SocketBuffer built on connect
        self.socket_read_size = socket_read_size
        # encoder, socket and buffer are bound by on_connect() and cleared
        # by on_disconnect()
        self.encoder = None
        self._sock = None
        self._buffer = None

    def __del__(self):
        # best-effort cleanup during garbage collection
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        "Called when the socket connects"
        self._sock = connection._sock
        self._buffer = SocketBuffer(self._sock,
                                    self.socket_read_size,
                                    connection.socket_timeout)
        self.encoder = connection.encoder

    def on_disconnect(self):
        "Called when the socket disconnects"
        self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        """Return True if a reply is buffered or arrives within ``timeout``.

        Returns False (rather than None) when not connected.
        """
        return self._buffer is not None and self._buffer.can_read(timeout)

    def read_response(self):
        """Read and parse one reply from the server.

        Returns the parsed value: bytes (decoded by the encoder when
        configured), an integer, a list for multi-bulk replies, None for
        nil replies, or a ResponseError instance for error replies (which
        the caller decides whether to raise).

        Raises ConnectionError if the parser is disconnected, the server
        closed the connection, or the server sent a connection-level error;
        raises InvalidResponse on an unknown reply type byte.
        """
        if self._buffer is None:
            # disconnected: fail the same way HiredisParser does instead of
            # crashing with AttributeError on None.readline()
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        raw = self._buffer.readline()
        if not raw:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        byte, response = raw[:1], raw[1:]

        if byte not in (b'-', b'+', b':', b'$', b'*'):
            raise InvalidResponse("Protocol Error: %r" % raw)

        # server returned an error
        if byte == b'-':
            response = nativestr(response)
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == b'+':
            pass
        # int value
        elif byte == b':':
            response = long(response)
        # bulk response
        elif byte == b'$':
            length = int(response)
            if length == -1:
                return None
            response = self._buffer.read(length)
        # multi-bulk response
        elif byte == b'*':
            length = int(response)
            if length == -1:
                return None
            response = [self.read_response() for i in xrange(length)]
        if isinstance(response, bytes):
            response = self.encoder.decode(response)
        return response
class HiredisParser(BaseParser):
    "Parser class for connections using Hiredis"
    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        # Start in the same state on_disconnect() leaves behind, so that
        # can_read()/read_response() before on_connect() raise
        # ConnectionError instead of AttributeError.
        self._sock = None
        self._reader = None
        self._next_response = False

        # reusable receive buffer for recv_into(); only needed when the
        # installed hiredis can be fed from a byte buffer
        if HIREDIS_USE_BYTE_BUFFER:
            self._buffer = bytearray(socket_read_size)

    def __del__(self):
        # best-effort cleanup during garbage collection
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        "Bind to the connection's socket and create a fresh hiredis Reader."
        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        kwargs = {
            'protocolError': InvalidResponse,
            'replyError': self.parse_error,
        }

        # hiredis < 0.1.3 doesn't support functions that create exceptions
        if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
            kwargs['replyError'] = ResponseError

        if connection.encoder.decode_responses:
            kwargs['encoding'] = connection.encoder.encoding
        if HIREDIS_SUPPORTS_ENCODING_ERRORS:
            kwargs['errors'] = connection.encoder.encoding_errors
        self._reader = hiredis.Reader(**kwargs)
        self._next_response = False

    def on_disconnect(self):
        "Drop the socket and reader, discarding any cached reply."
        self._sock = None
        self._reader = None
        self._next_response = False

    def can_read(self, timeout):
        """Return True if a complete reply is available.

        A reply already parsed by the reader is cached in _next_response so
        read_response() can return it. Otherwise waits up to ``timeout`` for
        more socket data without raising on timeout.

        Raises ConnectionError when not connected.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is False:
            self._next_response = self._reader.gets()
            if self._next_response is False:
                return self.read_from_socket(timeout=timeout,
                                             raise_on_timeout=False)
        return True

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        """Read one chunk from the socket and feed it to the hiredis reader.

        A supplied ``timeout`` temporarily replaces the socket timeout.
        Returns True when data was read; with ``raise_on_timeout`` false, a
        timeout or expected non-blocking "no data" error returns False.

        Raises ConnectionError when the server closed the socket or the read
        failed, and TimeoutError on a timeout when ``raise_on_timeout`` is
        true.
        """
        sock = self._sock
        custom_timeout = timeout is not SENTINEL
        try:
            if custom_timeout:
                sock.settimeout(timeout)
            if HIREDIS_USE_BYTE_BUFFER:
                bufflen = recv_into(self._sock, self._buffer)
                if bufflen == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                self._reader.feed(self._buffer, 0, bufflen)
            else:
                buffer = recv(self._sock, self.socket_read_size)
                # an empty string indicates the server shutdown the socket
                if not isinstance(buffer, bytes) or len(buffer) == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                self._reader.feed(buffer)
            # data was read from the socket and added to the buffer.
            # return True to indicate that data was read.
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError("Error while reading from socket: %s" %
                                  (ex.args,))
        finally:
            if custom_timeout:
                sock.settimeout(self._socket_timeout)

    def read_response(self):
        """Return the next reply, reading from the socket until one is ready.

        Error replies are returned as exception instances, except
        ConnectionErrors (alone or as the first element of a list), which
        are raised. Raises ConnectionError when not connected.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        # _next_response might be cached from a can_read() call
        if self._next_response is not False:
            response = self._next_response
            self._next_response = False
            return response

        response = self._reader.gets()
        while response is False:
            self.read_from_socket()
            response = self._reader.gets()
        # if an older version of hiredis is installed, we need to attempt
        # to convert ResponseErrors to their appropriate types.
        if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
            if isinstance(response, ResponseError):
                response = self.parse_error(response.args[0])
            elif isinstance(response, list) and response and \
                    isinstance(response[0], ResponseError):
                response[0] = self.parse_error(response[0].args[0])
        # if the response is a ConnectionError or the response is a list and
        # the first item is a ConnectionError, raise it as something bad
        # happened
        if isinstance(response, ConnectionError):
            raise response
        elif isinstance(response, list) and response and \
                isinstance(response[0], ConnectionError):
            raise response[0]
        return response
492    DefaultParser = HiredisParser
494    DefaultParser = PythonParser
class Connection(object):
    "Manages TCP communication to and from a Redis server"

    def __init__(self, host='localhost', port=6379, db=0, password=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=False, socket_keepalive_options=None,
                 socket_type=0, retry_on_timeout=False, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser, socket_read_size=65536,
                 health_check_interval=0, client_name=None, username=None):
        # remember the creating process; disconnect() only shuts the socket
        # down from this process
        self.pid = os.getpid()
        self.host = host
        self.port = int(port)
        self.db = db
        self.username = username
        self.client_name = client_name
        self.password = password
        self.socket_timeout = socket_timeout
        # the connect timeout falls back to the regular socket timeout
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        self.retry_on_timeout = retry_on_timeout
        # seconds after the last reply before a PING is sent ahead of the
        # next command; 0 disables health checks
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        self._connect_callbacks = []
        # values/buffers longer than this are sent as separate chunks
        # instead of being concatenated (see pack_command/pack_commands)
        self._buffer_cutoff = 6000

    def __repr__(self):
        repr_args = ','.join(['%s=%s' % (k, v) for k, v in self.repr_pieces()])
        return '%s<%s>' % (self.__class__.__name__, repr_args)

    def repr_pieces(self):
        "Return the (name, value) pairs shown by __repr__."
        pieces = [
            ('host', self.host),
            ('port', self.port),
            ('db', self.db)
        ]
        if self.client_name:
            pieces.append(('client_name', self.client_name))
        return pieces

    def __del__(self):
        # best-effort cleanup during garbage collection
        try:
            self.disconnect()
        except Exception:
            pass

    def register_connect_callback(self, callback):
        "Register ``callback(connection)`` to run after every connect."
        self._connect_callbacks.append(callback)

    def clear_connect_callbacks(self):
        self._connect_callbacks = []

    def connect(self):
        """Connects to the Redis server if not already connected.

        Raises TimeoutError / ConnectionError if the socket cannot be opened;
        any RedisError from on_connect() disconnects and is re-raised.
        """
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except socket.error as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            self.on_connect()
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        for callback in self._connect_callbacks:
            callback(self)

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None
        for res in socket.getaddrinfo(self.host, self.port, self.socket_type,
                                      socket.SOCK_STREAM):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in iteritems(self.socket_keepalive_options):
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except socket.error as e:
                # remember the last failure and try the next address
                err = e
                if sock is not None:
                    sock.close()

        if err is not None:
            raise err
        raise socket.error("socket.getaddrinfo returned an empty list")

    def _error_message(self, exception):
        # args for socket.error can either be (errno, "message")
        # or just "message"
        if len(exception.args) == 1:
            return "Error connecting to %s:%s. %s." % \
                (self.host, self.port, exception.args[0])
        else:
            return "Error %s connecting to %s:%s. %s." % \
                (exception.args[0], self.host, self.port, exception.args[1])

    def on_connect(self):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)

        # if username and/or password are set, authenticate
        if self.username or self.password:
            if self.username:
                auth_args = (self.username, self.password or '')
            else:
                auth_args = (self.password,)
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command('AUTH', *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command('AUTH', self.password, check_health=False)
                auth_response = self.read_response()

            if nativestr(auth_response) != 'OK':
                raise AuthenticationError('Invalid Username or Password')

        # if a client_name is given, set it
        if self.client_name:
            self.send_command('CLIENT', 'SETNAME', self.client_name)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Error setting client name')

        # if a database is specified, switch to it
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def disconnect(self):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()
        if self._sock is None:
            return
        try:
            # only the creating process shuts the connection down; another
            # process (e.g. a forked child) just closes its own handle
            if os.getpid() == self.pid:
                shutdown(self._sock, socket.SHUT_RDWR)
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def check_health(self):
        """Check the health of the connection with a PING/PONG.

        Only runs when health checks are enabled and the interval has
        elapsed. On a ConnectionError/TimeoutError the connection is dropped
        and the PING is retried once on a fresh connection.
        """
        if self.health_check_interval and time() > self.next_health_check:
            try:
                self.send_command('PING', check_health=False)
                if nativestr(self.read_response()) != 'PONG':
                    raise ConnectionError(
                        'Bad response from PING health check')
            except (ConnectionError, TimeoutError):
                self.disconnect()
                self.send_command('PING', check_health=False)
                if nativestr(self.read_response()) != 'PONG':
                    raise ConnectionError(
                        'Bad response from PING health check')

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server.

        ``command`` is either a single packed buffer or an iterable of
        buffers as returned by pack_command()/pack_commands(). Any error
        while writing disconnects before (re-)raising: TimeoutError on a
        socket timeout, ConnectionError on other socket errors.
        """
        if not self._sock:
            self.connect()
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            # a single pre-packed buffer must be wrapped; iterating bytes on
            # Python 3 would yield ints that sendall() cannot send
            if isinstance(command, (bytes, str)):
                command = [command]
            for item in command:
                sendall(self._sock, item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except socket.error as e:
            self.disconnect()
            # args for socket.error can either be (errno, "message")
            # or just "message"
            if len(e.args) == 1:
                err_no, errmsg = 'UNKNOWN', e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError("Error %s while writing to socket. %s." %
                                  (err_no, errmsg))
        except BaseException:
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        "Pack and send a command to the Redis server"
        self.send_packed_command(self.pack_command(*args),
                                 check_health=kwargs.get('check_health', True))

    def can_read(self, timeout=0):
        "Poll the socket to see if there's data that can be read."
        if not self._sock:
            self.connect()
        return self._parser.can_read(timeout)

    def read_response(self):
        """Read the response from a previously sent command.

        Error replies are raised as their ResponseError subclass; socket
        failures disconnect and raise TimeoutError/ConnectionError.
        """
        try:
            response = self._parser.read_response()
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout reading from %s:%s" %
                               (self.host, self.port))
        except socket.error as e:
            self.disconnect()
            raise ConnectionError("Error while reading from %s:%s : %s" %
                                  (self.host, self.port, e.args))
        except BaseException:
            self.disconnect()
            raise

        if self.health_check_interval:
            self.next_health_check = time() + self.health_check_interval

        if isinstance(response, ResponseError):
            raise response
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol.

        Returns a list of byte chunks; large values and memoryviews are kept
        as separate chunks to avoid copying them into one big buffer.
        """
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        if isinstance(args[0], unicode):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b' ' in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in imap(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (len(buff) > buffer_cutoff or arg_length > buffer_cutoff
                    or isinstance(arg, memoryview)):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF))
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(),
                     SYM_CRLF, arg, SYM_CRLF))
        output.append(buff)
        return output

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol.

        Small chunks are coalesced into buffers of roughly _buffer_cutoff
        bytes; large chunks and memoryviews are emitted on their own.
        """
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (buffer_length > buffer_cutoff or chunklen > buffer_cutoff
                        or isinstance(chunk, memoryview)):
                    output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
class SSLConnection(Connection):
    """
    A ``Connection`` whose TCP socket is wrapped with SSL/TLS.

    Accepts the same keyword arguments as ``Connection`` plus the ``ssl_*``
    options handled in ``__init__``.
    """

    def __init__(self, ssl_keyfile=None, ssl_certfile=None,
                 ssl_cert_reqs='required', ssl_ca_certs=None,
                 ssl_check_hostname=False, **kwargs):
        """
        :param ssl_keyfile: client private key file. It may be omitted when
            the key is bundled into ``ssl_certfile``.
        :param ssl_certfile: client certificate (chain) file.
        :param ssl_cert_reqs: ``None`` (no verification), one of the strings
            ``'none'``, ``'optional'`` or ``'required'``, or an ``ssl.CERT_*``
            value, which is used as-is.
        :param ssl_ca_certs: CA certificates file used to verify the server.
        :param ssl_check_hostname: whether to check the server hostname. It
            is only honoured on the ``ssl.create_default_context`` code path.
        :raises RedisError: if Python lacks SSL support, or if
            ``ssl_cert_reqs`` is an unrecognized string.
        """
        if not ssl_available:
            raise RedisError("Python wasn't built with SSL support")

        super(SSLConnection, self).__init__(**kwargs)

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, basestring):
            CERT_REQS = {
                'none': ssl.CERT_NONE,
                'optional': ssl.CERT_OPTIONAL,
                'required': ssl.CERT_REQUIRED
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    "Invalid SSL Certificate Requirements Flag: %s" %
                    ssl_cert_reqs)
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.check_hostname = ssl_check_hostname

    def _connect(self):
        "Wrap the socket with SSL support"
        sock = super(SSLConnection, self)._connect()
        try:
            if hasattr(ssl, "create_default_context"):
                context = ssl.create_default_context()
                # check_hostname must be switched off before verify_mode can
                # be lowered to CERT_NONE, hence this ordering.
                context.check_hostname = self.check_hostname
                context.verify_mode = self.cert_reqs
                # A certfile on its own is valid (the private key may be
                # bundled in the same PEM). Requiring both files used to
                # silently drop the client certificate in that case.
                if self.certfile:
                    context.load_cert_chain(certfile=self.certfile,
                                            keyfile=self.keyfile)
                if self.ca_certs:
                    context.load_verify_locations(self.ca_certs)
                sock = ssl_wrap_socket(context, sock,
                                       server_hostname=self.host)
            else:
                # In case this code runs in a version which is older than
                # 2.7.9, we want to fall back to old code
                sock = ssl_wrap_socket(ssl,
                                       sock,
                                       cert_reqs=self.cert_reqs,
                                       keyfile=self.keyfile,
                                       certfile=self.certfile,
                                       ca_certs=self.ca_certs)
        except BaseException:
            # don't leak the plain TCP socket when the TLS setup fails
            # (bad cert/key/CA file, handshake error, ...). `sock` still
            # refers to the unwrapped socket here because the assignment
            # only happens on success.
            sock.close()
            raise
        return sock
class UnixDomainSocketConnection(Connection):
    "A Connection that reaches Redis through a Unix domain socket"

    def __init__(self, path='', db=0, username=None, password=None,
                 socket_timeout=None, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 retry_on_timeout=False,
                 parser_class=DefaultParser, socket_read_size=65536,
                 health_check_interval=0, client_name=None):
        """
        :param path: filesystem path of the Unix socket to connect to.
        :param socket_timeout: timeout applied to the socket via
            ``settimeout`` before connecting.

        The encoding options build this connection's ``Encoder``.
        ``parser_class`` is instantiated with ``socket_read_size``. The
        remaining values are stored on the instance.

        Note: this does not call ``Connection.__init__``; every attribute
        is initialized here directly.
        """
        self.pid = os.getpid()
        self.path = path
        self.db = db
        self.username = username
        self.client_name = client_name
        self.password = password
        self.socket_timeout = socket_timeout
        self.retry_on_timeout = retry_on_timeout
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        self._connect_callbacks = []
        self._buffer_cutoff = 6000

    def repr_pieces(self):
        "Return (name, value) pairs describing this connection for repr()"
        pieces = [
            ('path', self.path),
            ('db', self.db),
        ]
        if self.client_name:
            pieces.append(('client_name', self.client_name))
        return pieces

    def _connect(self):
        "Create a Unix domain socket connection"
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.socket_timeout)
            sock.connect(self.path)
        except BaseException:
            # close the socket before propagating the error (missing socket
            # file, permission denied, timeout, ...) so its file
            # descriptor is not leaked
            sock.close()
            raise
        return sock

    def _error_message(self, exception):
        "Build a human readable message for a failed connection attempt"
        # args for socket.error can either be (errno, "message")
        # or just "message"
        if len(exception.args) == 1:
            return "Error connecting to unix socket: %s. %s." % \
                (self.path, exception.args[0])
        else:
            return "Error %s connecting to unix socket: %s. %s." % \
                (exception.args[0], self.path, exception.args[1])
FALSE_STRINGS = ('0', 'F', 'FALSE', 'N', 'NO')


def to_bool(value):
    """
    Interpret a (typically querystring) value as a boolean.

    ``None`` and the empty string mean "not given" and produce ``None``.
    A string whose upper-cased form is listed in ``FALSE_STRINGS`` produces
    ``False``. Any other value falls back to Python truthiness via
    ``bool(value)``.
    """
    missing = value is None or value == ''
    if missing:
        return None
    spelled_false = (isinstance(value, basestring)
                     and value.upper() in FALSE_STRINGS)
    return False if spelled_false else bool(value)
937    'socket_timeout': float,
938    'socket_connect_timeout': float,
939    'socket_keepalive': to_bool,
940    'retry_on_timeout': to_bool,
941    'max_connections': int,
942    'health_check_interval': int,
943    'ssl_check_hostname': to_bool,
class ConnectionPool(object):
    "Generic connection pool"
    @classmethod
    def from_url(cls, url, db=None, decode_components=False, **kwargs):
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[[username]:[password]]@/path/to/socket.sock?db=0

        Three URL schemes are supported:

        - ```redis://``
          <https://www.iana.org/assignments/uri-schemes/prov/redis>`_ creates a
          normal TCP socket connection
        - ```rediss://``
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>`_ creates
          a SSL wrapped TCP socket connection
        - ``unix://`` creates a Unix Domain Socket connection

        There are several ways to specify a database number. The parse function
        will return the first specified option:
            1. A ``db`` querystring option, e.g. redis://localhost?db=0
            2. If using the redis:// or rediss:// scheme, the path argument
               of the url, e.g. redis://localhost/0
            3. The ``db`` argument to this function.

        If none of these options are specified, db=0 is used.

        The ``decode_components`` argument allows this function to work with
        percent-encoded URLs. If this argument is set to ``True`` all ``%xx``
        escapes will be replaced by their single-character equivalents after
        the URL has been parsed. This only applies to the ``hostname``,
        ``path``, ``username`` and ``password`` components.

        Any additional querystring arguments and keyword arguments will be
        passed along to the ConnectionPool class's initializer. The querystring
        arguments ``socket_connect_timeout`` and ``socket_timeout`` if supplied
        are parsed as float values. The arguments ``socket_keepalive``,
        ``retry_on_timeout`` and ``ssl_check_hostname`` are parsed to boolean
        values that accept True/False, Yes/No values to indicate state.
        Unparseable values cause a ``UserWarning`` to be emitted and that
        querystring value to be ignored. In the case of conflicting
        arguments, querystring arguments always win.

        """
        url = urlparse(url)
        url_options = {}

        for name, value in iteritems(parse_qs(url.query)):
            if value and len(value) > 0:
                parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
                if parser:
                    try:
                        url_options[name] = parser(value[0])
                    except (TypeError, ValueError):
                        warnings.warn(UserWarning(
                            "Invalid value for `%s` in connection URL." % name
                        ))
                else:
                    url_options[name] = value[0]

        if decode_components:
            username = unquote(url.username) if url.username else None
            password = unquote(url.password) if url.password else None
            path = unquote(url.path) if url.path else None
            hostname = unquote(url.hostname) if url.hostname else None
        else:
            username = url.username or None
            password = url.password or None
            path = url.path
            hostname = url.hostname

        # We only support redis://, rediss:// and unix:// schemes.
        if url.scheme == 'unix':
            url_options.update({
                'username': username,
                'password': password,
                'path': path,
                'connection_class': UnixDomainSocketConnection,
            })

        elif url.scheme in ('redis', 'rediss'):
            url_options.update({
                'host': hostname,
                'port': int(url.port or 6379),
                'username': username,
                'password': password,
            })

            # If there's a path argument, use it as the db argument if a
            # querystring value wasn't specified
            if 'db' not in url_options and path:
                try:
                    url_options['db'] = int(path.replace('/', ''))
                except (AttributeError, ValueError):
                    pass

            if url.scheme == 'rediss':
                url_options['connection_class'] = SSLConnection
        else:
            valid_schemes = ', '.join(('redis://', 'rediss://', 'unix://'))
            raise ValueError('Redis URL must specify one of the following '
                             'schemes (%s)' % valid_schemes)

        # last shot at the db value
        url_options['db'] = int(url_options.get('db', db or 0))

        # update the arguments from the URL values
        kwargs.update(url_options)

        # backwards compatibility
        if 'charset' in kwargs:
            warnings.warn(DeprecationWarning(
                '"charset" is deprecated. Use "encoding" instead'))
            kwargs['encoding'] = kwargs.pop('charset')
        if 'errors' in kwargs:
            warnings.warn(DeprecationWarning(
                '"errors" is deprecated. Use "encoding_errors" instead'))
            kwargs['encoding_errors'] = kwargs.pop('errors')

        return cls(**kwargs)

    def __init__(self, connection_class=Connection, max_connections=None,
                 **connection_kwargs):
        """
        Create a connection pool. If max_connections is set, then this
        object raises redis.ConnectionError when the pool's limit is reached.

        By default, TCP connections are created unless connection_class is
        specified. Use redis.UnixDomainSocketConnection for unix sockets.

        Any additional keyword arguments are passed to the constructor of
        connection_class.
        """
        max_connections = max_connections or 2 ** 31
        if not isinstance(max_connections, (int, long)) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.
        self._fork_lock = threading.Lock()
        self.reset()

    def __repr__(self):
        return "%s<%s>" % (
            type(self).__name__,
            repr(self.connection_class(**self.connection_kwargs)),
        )

    def reset(self):
        self._lock = threading.Lock()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self):
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            # python 2.7 doesn't support a timeout option to lock.acquire()
            # we have to mimic lock timeouts ourselves.
            timeout_at = time() + 5
            acquired = False
            while time() < timeout_at:
                acquired = self._fork_lock.acquire(False)
                if acquired:
                    break
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    def get_connection(self, command_name, *keys, **options):
        "Get a connection from the pool"
        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError('Connection has data')
            except ConnectionError:
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError('Connection not ready')
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        return connection

    def get_encoder(self):
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get('encoding', 'utf-8'),
            encoding_errors=kwargs.get('encoding_errors', 'strict'),
            decode_responses=kwargs.get('decode_responses', False)
        )

    def make_connection(self):
        "Create a new connection"
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def release(self, connection):
        """
        Releases the connection back to the pool.

        Only connections currently checked out from this pool are returned
        to the idle list. A connection that is not in use here may have been
        released twice, come from another pool, or predate a fork-triggered
        ``reset()``. Adding such a connection could hand the same connection
        to two callers, so it is never added. If the pool does not own it,
        it is disconnected.
        """
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own (or that was already
                # released). Its slot is either still accounted for or was
                # never counted by this pool, so _created_connections is
                # left untouched.
                if not self.owns_connection(connection):
                    connection.disconnect()
                return

            if self.owns_connection(connection):
                self._available_connections.append(connection)
            else:
                # pool doesn't own this connection. do not add it back
                # to the pool and decrement the count so that another
                # connection can take its place if needed
                self._created_connections -= 1
                connection.disconnect()
                return

    def owns_connection(self, connection):
        "Return True if ``connection`` was created in this pool's process"
        return connection.pid == self.pid

    def disconnect(self, inuse_connections=True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(self._available_connections,
                                    self._in_use_connections)
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    ``:py:class: ~redis.connection.ConnectionPool`` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default
    ``:py:class: ~redis.connection.ConnectionPool`` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        # Raise a ``ConnectionError`` after five seconds if a connection is
        # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """
    def __init__(self, max_connections=50, timeout=20,
                 connection_class=Connection, queue_class=LifoQueue,
                 **connection_kwargs):

        self.queue_class = queue_class
        self.timeout = timeout
        super(BlockingConnectionPool, self).__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs)

    def reset(self):
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    def get_connection(self, command_name, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. If you want to wait
            # for a connection indefinitely instead, create the pool with
            # ``timeout=None``.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            connection = self.make_connection()

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError('Connection has data')
            except ConnectionError:
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError('Connection not ready')
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed.
            connection.disconnect()
            try:
                self.pool.put_nowait(None)
            except Full:
                # the queue is already full, e.g. because _checkpid() just
                # reset() the pool after a fork and refilled it with None
                # placeholders. there is no free slot to give back.
                pass
            return

        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        self._checkpid()
        for connection in self._connections:
            connection.disconnect()