import collections
import socket
import time
import logging
import six

from pymemcache.client.base import (
    Client,
    PooledClient,
    check_key_helper,
    normalize_server_spec,
)
from pymemcache.client.rendezvous import RendezvousHash
from pymemcache.exceptions import MemcacheError

logger = logging.getLogger(__name__)


class HashClient(object):
    """
    A client for communicating with a cluster of memcached servers
    """
    #: :class:`Client` class used to create new clients
    client_class = Client

    def __init__(
        self,
        servers,
        hasher=RendezvousHash,
        serde=None,
        serializer=None,
        deserializer=None,
        connect_timeout=None,
        timeout=None,
        no_delay=False,
        socket_module=socket,
        socket_keepalive=None,
        key_prefix=b'',
        max_pool_size=None,
        pool_idle_timeout=0,
        lock_generator=None,
        retry_attempts=2,
        retry_timeout=1,
        dead_timeout=60,
        use_pooling=False,
        ignore_exc=False,
        allow_unicode_keys=False,
        default_noreply=True,
        encoding='ascii',
        tls_context=None
    ):
        """
        Constructor.

        Args:
          servers: list() of tuple(hostname, port) or string containing a UNIX
                   socket path.
          hasher: optional class implementing the three functions
                  ``get_node``, ``add_node`` and ``remove_node``;
                  defaults to the Rendezvous (HRW) hash.

          use_pooling: use :py:class:`.PooledClient` as the default underlying
                       class. ``max_pool_size`` and ``lock_generator`` can
                       be used with this. default: False

          retry_attempts: Number of times a client should be tried before it
                          is marked dead and removed from the pool.
          retry_timeout (float): Time in seconds that should pass between retry
                                 attempts.
          dead_timeout (float): Time in seconds before attempting to add a node
                                back into the pool.
          encoding: optional str, controls data encoding (defaults to 'ascii').

        Further arguments are interpreted as for the :py:class:`.Client`
        constructor.
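
        Example (a minimal sketch; the hostnames and ports below are
        placeholders)::

            client = HashClient([('memcached-1', 11211),
                                 ('memcached-2', 11211)])
            client.set('some_key', 'some value')
            result = client.get('some_key')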
        """
        self.clients = {}
        self.retry_attempts = retry_attempts
        self.retry_timeout = retry_timeout
        self.dead_timeout = dead_timeout
        self.use_pooling = use_pooling
        self.key_prefix = key_prefix
        self.ignore_exc = ignore_exc
        self.allow_unicode_keys = allow_unicode_keys
        self._failed_clients = {}
        self._dead_clients = {}
        self._last_dead_check_time = time.time()

        self.hasher = hasher()

        self.default_kwargs = {
            'connect_timeout': connect_timeout,
            'timeout': timeout,
            'no_delay': no_delay,
            'socket_module': socket_module,
            'socket_keepalive': socket_keepalive,
            'key_prefix': key_prefix,
            'serde': serde,
            'serializer': serializer,
            'deserializer': deserializer,
            'allow_unicode_keys': allow_unicode_keys,
            'default_noreply': default_noreply,
            'encoding': encoding,
            'tls_context': tls_context,
        }

        if use_pooling is True:
            self.default_kwargs.update({
                'max_pool_size': max_pool_size,
                'pool_idle_timeout': pool_idle_timeout,
                'lock_generator': lock_generator
            })

        for server in servers:
            self.add_server(normalize_server_spec(server))
        self.encoding = encoding
        self.tls_context = tls_context

    def _make_client_key(self, server):
        if isinstance(server, (list, tuple)) and len(server) == 2:
            return '%s:%s' % server
        return server

    def add_server(self, server, port=None):
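        """Create a client for ``server`` and register it with the hasher.

        ``server`` may be a (host, port) tuple or a string spec; passing a
        separate ``port`` is supported for backwards compatibility and then
        requires ``server`` to be a string.
        """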
        # To maintain backward compatibility, if a port is provided, assume
        # that server wasn't provided as a (host, port) tuple.
        if port is not None:
            if not isinstance(server, six.string_types):
                raise TypeError('Server must be a string when passing port.')
            server = (server, port)

        _class = PooledClient if self.use_pooling else self.client_class
        client = _class(server, **self.default_kwargs)
        if self.use_pooling:
            client.client_class = self.client_class

        key = self._make_client_key(server)
        self.clients[key] = client
        self.hasher.add_node(key)

    def remove_server(self, server, port=None):
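        """Remove ``server`` from the hash ring and mark it as dead.

        The server is recorded in ``_dead_clients`` so that ``_retry_dead``
        can bring it back into rotation after ``dead_timeout`` seconds.
        """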
        # To maintain backward compatibility, if a port is provided, assume
        # that server wasn't provided as a (host, port) tuple.
        if port is not None:
            if not isinstance(server, six.string_types):
                raise TypeError('Server must be a string when passing port.')
            server = (server, port)

        key = self._make_client_key(server)
        dead_time = time.time()
        self._failed_clients.pop(server)
        self._dead_clients[server] = dead_time
        self.hasher.remove_node(key)

    def _retry_dead(self):
        current_time = time.time()
        ldc = self._last_dead_check_time
        # Only look for revivable servers every dead_timeout seconds
        if current_time - ldc > self.dead_timeout:
            candidates = []
            for server, dead_time in self._dead_clients.items():
                if current_time - dead_time > self.dead_timeout:
                    candidates.append(server)
            for server in candidates:
                logger.debug(
                    'bringing server back into rotation %s',
                    server
                )
                self.add_server(server)
                del self._dead_clients[server]
            self._last_dead_check_time = current_time

    def _get_client(self, key):
        check_key_helper(key, self.allow_unicode_keys, self.key_prefix)
        if self._dead_clients:
            self._retry_dead()

        server = self.hasher.get_node(key)
        # We've run out of servers to try
        if server is None:
            if self.ignore_exc is True:
                return
            raise MemcacheError('All servers seem to be down right now')

        return self.clients[server]

    def _safely_run_func(self, client, func, default_val, *args, **kwargs):
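        """Run ``func``, tracking failures for the client's server.

        A server already in ``_failed_clients`` is only retried once
        ``retry_timeout`` seconds have passed, and is marked dead after
        ``retry_attempts`` failed attempts. A ``socket.error`` marks the
        server as failed; any exception is re-raised unless ``ignore_exc``
        is set, in which case ``default_val`` is returned instead.
        """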
        try:
            if client.server in self._failed_clients:
                # This server is currently failing; check whether it is due
                # for a retry or should be marked as dead
                failed_metadata = self._failed_clients[client.server]

                # We haven't reached the maximum number of attempts yet;
                # if enough time has passed, retry this server
                if failed_metadata['attempts'] < self.retry_attempts:
                    failed_time = failed_metadata['failed_time']
                    if time.time() - failed_time > self.retry_timeout:
                        logger.debug(
                            'retrying failed server: %s', client.server
                        )
                        result = func(*args, **kwargs)
                        # We were successful; remove it from the failed
                        # clients
                        self._failed_clients.pop(client.server)
                        return result
                    return default_val
                else:
                    # We've reached our max retry attempts; mark the server
                    # as dead
                    logger.debug('marking server as dead: %s', client.server)
                    self.remove_server(client.server)

            result = func(*args, **kwargs)
            return result

        # Connecting to the server failed, so enter
        # retry mode
        except socket.error:
            self._mark_failed_server(client.server)

            # If ignore_exc isn't enabled, don't move on gracefully; just
            # raise the exception
            if not self.ignore_exc:
                raise

            return default_val
        except Exception:
            # Any exceptions that aren't socket.error also need to be
            # handled gracefully
            if not self.ignore_exc:
                raise

            return default_val

    def _safely_run_set_many(self, client, values, *args, **kwargs):
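        """Like ``_safely_run_func`` but for ``set_many``.

        Returns the keys that could not be stored; when the server is still
        waiting for a retry or an exception is swallowed via ``ignore_exc``,
        all keys that did not succeed are reported as failed.
        """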
        failed = []
        succeeded = []
        try:
            if client.server in self._failed_clients:
                # This server is currently failing; check whether it is due
                # for a retry or should be marked as dead
                failed_metadata = self._failed_clients[client.server]

                # We haven't reached the maximum number of attempts yet;
                # if enough time has passed, retry this server
                if failed_metadata['attempts'] < self.retry_attempts:
                    failed_time = failed_metadata['failed_time']
                    if time.time() - failed_time > self.retry_timeout:
                        logger.debug(
                            'retrying failed server: %s', client.server
                        )
                        succeeded, failed, err = self._set_many(
                            client, values, *args, **kwargs)
                        if err is not None:
                            raise err
                        # We were successful; remove it from the failed
                        # clients
                        self._failed_clients.pop(client.server)
                        return failed
                    return values.keys()
                else:
                    # We've reached our max retry attempts; mark the server
                    # as dead
                    logger.debug('marking server as dead: %s', client.server)
                    self.remove_server(client.server)

            succeeded, failed, err = self._set_many(
                client, values, *args, **kwargs
            )
            if err is not None:
                raise err

            return failed

        # Connecting to the server failed, so enter
        # retry mode
        except socket.error:
            self._mark_failed_server(client.server)

            # If ignore_exc isn't enabled, don't move on gracefully; just
            # raise the exception
            if not self.ignore_exc:
                raise

            return list(set(values.keys()) - set(succeeded))
        except Exception:
            # Any exceptions that aren't socket.error also need to be
            # handled gracefully
            if not self.ignore_exc:
                raise

            return list(set(values.keys()) - set(succeeded))

    def _mark_failed_server(self, server):
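        """Record a failure for ``server``.

        A first failure adds the server to ``_failed_clients``; if retries
        are disabled (``retry_attempts <= 0``) the server is removed from
        the pool immediately. Subsequent failures bump the attempt count
        and the failure timestamp.
        """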
        # This server has never failed before; mark it as failing
        if (
                server not in self._failed_clients and
                self.retry_attempts > 0
        ):
            self._failed_clients[server] = {
                'failed_time': time.time(),
                'attempts': 0,
            }
        # We aren't allowing any retries; mark the server as
        # dead immediately
        elif (
            server not in self._failed_clients and
            self.retry_attempts <= 0
        ):
            self._failed_clients[server] = {
                'failed_time': time.time(),
                'attempts': 0,
            }
            logger.debug("marking server as dead %s", server)
            self.remove_server(server)
        # This server has failed before; update the metadata
        # to reflect another failed attempt
        else:
            failed_metadata = self._failed_clients[server]
            failed_metadata['attempts'] += 1
            failed_metadata['failed_time'] = time.time()
            self._failed_clients[server] = failed_metadata

    def _run_cmd(self, cmd, key, default_val, *args, **kwargs):
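        """Dispatch a single-key command to the server that owns ``key``.

        Returns ``default_val`` when no server is available (with
        ``ignore_exc`` set) or when the call fails gracefully.
        """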
        client = self._get_client(key)

        if client is None:
            return default_val

        func = getattr(client, cmd)
        args = list(args)
        args.insert(0, key)
        return self._safely_run_func(
            client, func, default_val, *args, **kwargs
        )

    def _set_many(self, client, values, *args, **kwargs):
        failed = []
        succeeded = []

        try:
            failed = client.set_many(values, *args, **kwargs)
        except Exception as e:
            if not self.ignore_exc:
                return succeeded, failed, e

        succeeded = [key for key in six.iterkeys(values) if key not in failed]
        return succeeded, failed, None

    def close(self):
        for client in self.clients.values():
            self._safely_run_func(client, client.close, False)

    disconnect_all = close

    def set(self, key, *args, **kwargs):
        return self._run_cmd('set', key, False, *args, **kwargs)

    def get(self, key, *args, **kwargs):
        return self._run_cmd('get', key, None, *args, **kwargs)

    def incr(self, key, *args, **kwargs):
        return self._run_cmd('incr', key, False, *args, **kwargs)

    def decr(self, key, *args, **kwargs):
        return self._run_cmd('decr', key, False, *args, **kwargs)

    def set_many(self, values, *args, **kwargs):
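        """Set multiple keys, batching them per backing server.

        Returns a list of keys that failed to be stored, including keys
        for which no server could be found.
        """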
        client_batches = collections.defaultdict(dict)
        failed = []

        for key, value in six.iteritems(values):
            client = self._get_client(key)

            if client is None:
                failed.append(key)
                continue

            client_batches[client.server][key] = value

        for server, values in client_batches.items():
            client = self.clients[self._make_client_key(server)]
            failed += self._safely_run_set_many(
                client, values, *args, **kwargs
            )

        return failed

    set_multi = set_many

    def get_many(self, keys, gets=False, *args, **kwargs):
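        """Get multiple keys, batching them per backing server.

        Keys with no available server are silently skipped. When ``gets``
        is true, ``gets_many`` is used instead of ``get_many`` on the
        underlying clients.
        """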
        client_batches = collections.defaultdict(list)
        end = {}

        for key in keys:
            client = self._get_client(key)

            if client is None:
                continue

            client_batches[client.server].append(key)

        for server, keys in client_batches.items():
            client = self.clients[self._make_client_key(server)]
            new_args = list(args)
            new_args.insert(0, keys)

            if gets:
                get_func = client.gets_many
            else:
                get_func = client.get_many

            result = self._safely_run_func(
                client,
                get_func, {}, *new_args, **kwargs
            )
            end.update(result)

        return end

    get_multi = get_many

    def gets(self, key, *args, **kwargs):
        return self._run_cmd('gets', key, None, *args, **kwargs)

    def gets_many(self, keys, *args, **kwargs):
        return self.get_many(keys, gets=True, *args, **kwargs)

    gets_multi = gets_many

    def add(self, key, *args, **kwargs):
        return self._run_cmd('add', key, False, *args, **kwargs)

    def prepend(self, key, *args, **kwargs):
        return self._run_cmd('prepend', key, False, *args, **kwargs)

    def append(self, key, *args, **kwargs):
        return self._run_cmd('append', key, False, *args, **kwargs)

    def delete(self, key, *args, **kwargs):
        return self._run_cmd('delete', key, False, *args, **kwargs)

    def delete_many(self, keys, *args, **kwargs):
        for key in keys:
            self._run_cmd('delete', key, False, *args, **kwargs)
        return True

    delete_multi = delete_many

    def cas(self, key, *args, **kwargs):
        return self._run_cmd('cas', key, False, *args, **kwargs)

    def replace(self, key, *args, **kwargs):
        return self._run_cmd('replace', key, False, *args, **kwargs)

    def touch(self, key, *args, **kwargs):
        return self._run_cmd('touch', key, False, *args, **kwargs)

    def flush_all(self):
        for client in self.clients.values():
            self._safely_run_func(client, client.flush_all, False)

    def quit(self):
        for client in self.clients.values():
            self._safely_run_func(client, client.quit, False)