1import random
2import re
3import socket
4from collections import OrderedDict
5from datetime import datetime
6from typing import Any, Dict, Iterator, List, Optional, Union
7
8from django.conf import settings
9from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func
10from django.core.exceptions import ImproperlyConfigured
11from django.utils.module_loading import import_string
12from redis import Redis
13from redis.exceptions import ConnectionError, ResponseError, TimeoutError
14
15from .. import pool
16from ..exceptions import CompressorError, ConnectionInterrupted
17from ..util import CacheKey
18
19_main_exceptions = (TimeoutError, ResponseError, ConnectionError, socket.timeout)
20
# Characters with special meaning in redis glob-style patterns: *, ? and [.
special_re = re.compile("([*?[])")


def glob_escape(s: str) -> str:
    """Escape glob-special characters so *s* matches literally in KEYS/SCAN."""
    return special_re.sub(lambda match: "[{}]".format(match.group(1)), s)
26
27
class DefaultClient:
    """Default django-redis cache client for one (optionally replicated)
    set of redis servers, with pluggable serializer and compressor."""

    def __init__(self, server, params: Dict[str, Any], backend: BaseCache) -> None:
        self._backend = backend
        self._server = server
        self._params = params

        # Inverse of the key function: maps a stored key back to the
        # user-supplied key (used by keys()/iter_keys()).
        self.reverse_key = get_key_func(
            params.get("REVERSE_KEY_FUNCTION")
            or "django_redis.util.default_reverse_key"
        )

        if not self._server:
            raise ImproperlyConfigured("Missing connections string")

        # Accept either an iterable of connection strings or a single
        # comma-separated string (replication setups list several servers).
        if not isinstance(self._server, (list, tuple, set)):
            self._server = self._server.split(",")

        # One lazily-created client per server index; filled by get_client().
        self._clients: List[Optional[Redis]] = [None] * len(self._server)
        self._options = params.get("OPTIONS", {})
        self._replica_read_only = self._options.get("REPLICA_READ_ONLY", True)

        serializer_path = self._options.get(
            "SERIALIZER", "django_redis.serializers.pickle.PickleSerializer"
        )
        serializer_cls = import_string(serializer_path)

        compressor_path = self._options.get(
            "COMPRESSOR", "django_redis.compressors.identity.IdentityCompressor"
        )
        compressor_cls = import_string(compressor_path)

        self._serializer = serializer_cls(options=self._options)
        self._compressor = compressor_cls(options=self._options)

        self.connection_factory = pool.get_connection_factory(options=self._options)
63
    def __contains__(self, key: Any) -> bool:
        # Enables ``key in client`` syntax; delegates to has_key().
        return self.has_key(key)
66
67    def get_next_client_index(
68        self, write: bool = True, tried: Optional[List[int]] = None
69    ) -> int:
70        """
71        Return a next index for read client. This function implements a default
72        behavior for get a next read client for a replication setup.
73
74        Overwrite this function if you want a specific
75        behavior.
76        """
77        if tried is None:
78            tried = list()
79
80        if tried and len(tried) < len(self._server):
81            not_tried = [i for i in range(0, len(self._server)) if i not in tried]
82            return random.choice(not_tried)
83
84        if write or len(self._server) == 1:
85            return 0
86
87        return random.randint(1, len(self._server) - 1)
88
89    def get_client(
90        self,
91        write: bool = True,
92        tried: Optional[List[int]] = None,
93        show_index: bool = False,
94    ):
95        """
96        Method used for obtain a raw redis client.
97
98        This function is used by almost all cache backend
99        operations for obtain a native redis client/connection
100        instance.
101        """
102        index = self.get_next_client_index(write=write, tried=tried)
103
104        if self._clients[index] is None:
105            self._clients[index] = self.connect(index)
106
107        if show_index:
108            return self._clients[index], index
109        else:
110            return self._clients[index]
111
    def connect(self, index: int = 0) -> Redis:
        """
        Given a connection index, returns a new raw redis client/connection
        instance. Index is used for replication setups and indicates that
        connection string should be used. In normal setups, index is 0.
        """
        # Delegates to the configured connection factory (pooled).
        return self.connection_factory.connect(self._server[index])
119
120    def disconnect(self, index=0, client=None):
121        """delegates the connection factory to disconnect the client"""
122        if not client:
123            client = self._clients[index]
124        return self.connection_factory.disconnect(client) if client else None
125
126    def set(
127        self,
128        key: Any,
129        value: Any,
130        timeout: Optional[float] = DEFAULT_TIMEOUT,
131        version: Optional[int] = None,
132        client: Optional[Redis] = None,
133        nx: bool = False,
134        xx: bool = False,
135    ) -> bool:
136        """
137        Persist a value to the cache, and set an optional expiration time.
138
139        Also supports optional nx parameter. If set to True - will use redis
140        setnx instead of set.
141        """
142        nkey = self.make_key(key, version=version)
143        nvalue = self.encode(value)
144
145        if timeout is DEFAULT_TIMEOUT:
146            timeout = self._backend.default_timeout
147
148        original_client = client
149        tried: List[int] = []
150        while True:
151            try:
152                if client is None:
153                    client, index = self.get_client(
154                        write=True, tried=tried, show_index=True
155                    )
156
157                if timeout is not None:
158                    # Convert to milliseconds
159                    timeout = int(timeout * 1000)
160
161                    if timeout <= 0:
162                        if nx:
163                            # Using negative timeouts when nx is True should
164                            # not expire (in our case delete) the value if it exists.
165                            # Obviously expire not existent value is noop.
166                            return not self.has_key(key, version=version, client=client)
167                        else:
168                            # redis doesn't support negative timeouts in ex flags
169                            # so it seems that it's better to just delete the key
170                            # than to set it and than expire in a pipeline
171                            return bool(
172                                self.delete(key, client=client, version=version)
173                            )
174
175                return bool(client.set(nkey, nvalue, nx=nx, px=timeout, xx=xx))
176            except _main_exceptions as e:
177                if (
178                    not original_client
179                    and not self._replica_read_only
180                    and len(tried) < len(self._server)
181                ):
182                    tried.append(index)
183                    client = None
184                    continue
185                raise ConnectionInterrupted(connection=client) from e
186
    def incr_version(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """
        Adds delta to the cache version for the supplied key. Returns the
        new version.

        Implemented as copy-then-delete: the value is re-stored under the
        new versioned key (keeping its TTL) and the old key removed.
        Raises ValueError when the key does not exist.
        """

        if client is None:
            client = self.get_client(write=True)

        if version is None:
            version = self._backend.version

        old_key = self.make_key(key, version)
        value = self.get(old_key, version=version, client=client)

        try:
            ttl = self.ttl(old_key, version=version, client=client)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

        if value is None:
            raise ValueError("Key '%s' not found" % key)

        if isinstance(key, CacheKey):
            # Already a composed key: rebuild from the original user key so
            # the new version number is embedded correctly.
            new_key = self.make_key(key.original_key(), version=version + delta)
        else:
            new_key = self.make_key(key, version=version + delta)

        self.set(new_key, value, timeout=ttl, client=client)
        self.delete(old_key, client=client)
        return version + delta
224
    def add(
        self,
        key: Any,
        value: Any,
        timeout: Any = DEFAULT_TIMEOUT,
        version: Optional[Any] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """
        Add a value to the cache, failing if the key already exists.

        Returns ``True`` if the object was added, ``False`` if not.
        """
        # Thin wrapper over set() with the redis NX (set-if-not-exists) flag.
        return self.set(key, value, timeout, version=version, client=client, nx=True)
239
240    def get(
241        self,
242        key: Any,
243        default=None,
244        version: Optional[int] = None,
245        client: Optional[Redis] = None,
246    ) -> Any:
247        """
248        Retrieve a value from the cache.
249
250        Returns decoded value if key is found, the default if not.
251        """
252        if client is None:
253            client = self.get_client(write=False)
254
255        key = self.make_key(key, version=version)
256
257        try:
258            value = client.get(key)
259        except _main_exceptions as e:
260            raise ConnectionInterrupted(connection=client) from e
261
262        if value is None:
263            return default
264
265        return self.decode(value)
266
267    def persist(
268        self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None
269    ) -> bool:
270        if client is None:
271            client = self.get_client(write=True)
272
273        key = self.make_key(key, version=version)
274
275        return client.persist(key)
276
277    def expire(
278        self,
279        key: Any,
280        timeout,
281        version: Optional[int] = None,
282        client: Optional[Redis] = None,
283    ) -> bool:
284        if client is None:
285            client = self.get_client(write=True)
286
287        key = self.make_key(key, version=version)
288
289        return client.expire(key, timeout)
290
291    def pexpire(self, key, timeout, version=None, client=None) -> bool:
292        if client is None:
293            client = self.get_client(write=True)
294
295        key = self.make_key(key, version=version)
296
297        # Temporary casting until https://github.com/redis/redis-py/issues/1664
298        # is fixed.
299        return bool(client.pexpire(key, timeout))
300
301    def pexpire_at(
302        self,
303        key: Any,
304        when: Union[datetime, int],
305        version: Optional[int] = None,
306        client: Optional[Redis] = None,
307    ) -> bool:
308        """
309        Set an expire flag on a ``key`` to ``when``, which can be represented
310        as an integer indicating unix time or a Python datetime object.
311        """
312        if client is None:
313            client = self.get_client(write=True)
314
315        key = self.make_key(key, version=version)
316
317        return bool(client.pexpireat(key, when))
318
319    def expire_at(
320        self,
321        key: Any,
322        when: Union[datetime, int],
323        version: Optional[int] = None,
324        client: Optional[Redis] = None,
325    ) -> bool:
326        """
327        Set an expire flag on a ``key`` to ``when``, which can be represented
328        as an integer indicating unix time or a Python datetime object.
329        """
330        if client is None:
331            client = self.get_client(write=True)
332
333        key = self.make_key(key, version=version)
334
335        return client.expireat(key, when)
336
337    def lock(
338        self,
339        key,
340        version: Optional[int] = None,
341        timeout=None,
342        sleep=0.1,
343        blocking_timeout=None,
344        client: Optional[Redis] = None,
345        thread_local=True,
346    ):
347        if client is None:
348            client = self.get_client(write=True)
349
350        key = self.make_key(key, version=version)
351        return client.lock(
352            key,
353            timeout=timeout,
354            sleep=sleep,
355            blocking_timeout=blocking_timeout,
356            thread_local=thread_local,
357        )
358
359    def delete(
360        self,
361        key: Any,
362        version: Optional[int] = None,
363        prefix: Optional[str] = None,
364        client: Optional[Redis] = None,
365    ) -> int:
366        """
367        Remove a key from the cache.
368        """
369        if client is None:
370            client = self.get_client(write=True)
371
372        try:
373            return client.delete(self.make_key(key, version=version, prefix=prefix))
374        except _main_exceptions as e:
375            raise ConnectionInterrupted(connection=client) from e
376
377    def delete_pattern(
378        self,
379        pattern: str,
380        version: Optional[int] = None,
381        prefix: Optional[str] = None,
382        client: Optional[Redis] = None,
383        itersize: Optional[int] = None,
384    ) -> int:
385        """
386        Remove all keys matching pattern.
387        """
388
389        if client is None:
390            client = self.get_client(write=True)
391
392        pattern = self.make_pattern(pattern, version=version, prefix=prefix)
393
394        try:
395            count = 0
396            for key in client.scan_iter(match=pattern, count=itersize):
397                client.delete(key)
398                count += 1
399            return count
400        except _main_exceptions as e:
401            raise ConnectionInterrupted(connection=client) from e
402
403    def delete_many(
404        self, keys, version: Optional[int] = None, client: Optional[Redis] = None
405    ):
406        """
407        Remove multiple keys at once.
408        """
409
410        if client is None:
411            client = self.get_client(write=True)
412
413        keys = [self.make_key(k, version=version) for k in keys]
414
415        if not keys:
416            return
417
418        try:
419            return client.delete(*keys)
420        except _main_exceptions as e:
421            raise ConnectionInterrupted(connection=client) from e
422
    def clear(self, client: Optional[Redis] = None) -> None:
        """
        Flush all cache keys.

        NOTE(review): this issues FLUSHDB, which wipes the WHOLE redis
        database, not only keys created through this cache — confirm the
        database is dedicated to the cache before relying on this.
        """

        if client is None:
            client = self.get_client(write=True)

        try:
            client.flushdb()
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e
435
    def decode(self, value: Union[bytes, int]) -> Any:
        """
        Decode the given value.

        Inverse of encode(): plain integers were stored as ascii digits and
        are returned directly; anything else is decompressed then
        deserialized.
        """
        try:
            # Fast path for ints stored raw by encode().
            # NOTE(review): a serialized payload consisting only of digits
            # would also take this branch — presumably no serializer emits
            # such output; confirm when adding a new serializer.
            value = int(value)
        except (ValueError, TypeError):
            try:
                value = self._compressor.decompress(value)
            except CompressorError:
                # Handle little values, chosen to be not compressed
                pass
            value = self._serializer.loads(value)
        return value
450
451    def encode(self, value: Any) -> Union[bytes, Any]:
452        """
453        Encode the given value.
454        """
455
456        if isinstance(value, bool) or not isinstance(value, int):
457            value = self._serializer.dumps(value)
458            value = self._compressor.compress(value)
459            return value
460
461        return value
462
463    def get_many(
464        self, keys, version: Optional[int] = None, client: Optional[Redis] = None
465    ) -> OrderedDict:
466        """
467        Retrieve many keys.
468        """
469
470        if client is None:
471            client = self.get_client(write=False)
472
473        if not keys:
474            return OrderedDict()
475
476        recovered_data = OrderedDict()
477
478        map_keys = OrderedDict((self.make_key(k, version=version), k) for k in keys)
479
480        try:
481            results = client.mget(*map_keys)
482        except _main_exceptions as e:
483            raise ConnectionInterrupted(connection=client) from e
484
485        for key, value in zip(map_keys, results):
486            if value is None:
487                continue
488            recovered_data[map_keys[key]] = self.decode(value)
489        return recovered_data
490
491    def set_many(
492        self,
493        data: Dict[Any, Any],
494        timeout: Optional[float] = DEFAULT_TIMEOUT,
495        version: Optional[int] = None,
496        client: Optional[Redis] = None,
497    ) -> None:
498        """
499        Set a bunch of values in the cache at once from a dict of key/value
500        pairs. This is much more efficient than calling set() multiple times.
501
502        If timeout is given, that timeout will be used for the key; otherwise
503        the default cache timeout will be used.
504        """
505        if client is None:
506            client = self.get_client(write=True)
507
508        try:
509            pipeline = client.pipeline()
510            for key, value in data.items():
511                self.set(key, value, timeout, version=version, client=pipeline)
512            pipeline.execute()
513        except _main_exceptions as e:
514            raise ConnectionInterrupted(connection=client) from e
515
    def _incr(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
        ignore_key_check: bool = False,
    ) -> int:
        """
        Shared implementation behind incr()/decr().

        Atomically adds ``delta`` to ``key`` via a lua script. Raises
        ValueError when the key is missing, unless ``ignore_key_check`` is
        true, in which case redis creates the key implicitly (INCRBY on a
        missing key starts from 0).
        """
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)

        try:
            try:
                # if key expired after exists check, then we get
                # key with wrong value and ttl -1.
                # use lua script for atomicity
                if not ignore_key_check:
                    lua = """
                    local exists = redis.call('EXISTS', KEYS[1])
                    if (exists == 1) then
                        return redis.call('INCRBY', KEYS[1], ARGV[1])
                    else return false end
                    """
                else:
                    lua = """
                    return redis.call('INCRBY', KEYS[1], ARGV[1])
                    """
                value = client.eval(lua, 1, key, delta)
                # eval returns nil (None) when the key did not exist.
                if value is None:
                    raise ValueError("Key '%s' not found" % key)
            except ResponseError:
                # if cached value or total value is greater than 64 bit signed
                # integer.
                # elif int is encoded. so redis sees the data as string.
                # In this situations redis will throw ResponseError

                # try to keep TTL of key

                # Non-atomic fallback: read, add, write back with same TTL.
                timeout = self.ttl(key, version=version, client=client)

                # returns -2 if the key does not exist
                # means, that key have expired
                if timeout == -2:
                    raise ValueError("Key '%s' not found" % key)
                value = self.get(key, version=version, client=client) + delta
                self.set(key, value, version=version, timeout=timeout, client=client)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

        return value
567
568    def incr(
569        self,
570        key: Any,
571        delta: int = 1,
572        version: Optional[int] = None,
573        client: Optional[Redis] = None,
574        ignore_key_check: bool = False,
575    ) -> int:
576        """
577        Add delta to value in the cache. If the key does not exist, raise a
578        ValueError exception. if ignore_key_check=True then the key will be
579        created and set to the delta value by default.
580        """
581        return self._incr(
582            key=key,
583            delta=delta,
584            version=version,
585            client=client,
586            ignore_key_check=ignore_key_check,
587        )
588
    def decr(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """
        Subtract delta from the value in the cache. If the key does not
        exist, raise a ValueError exception.
        """
        # Implemented as incr with a negated delta.
        return self._incr(key=key, delta=-delta, version=version, client=client)
601
602    def ttl(
603        self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None
604    ) -> Optional[int]:
605        """
606        Executes TTL redis command and return the "time-to-live" of specified key.
607        If key is a non volatile key, it returns None.
608        """
609        if client is None:
610            client = self.get_client(write=False)
611
612        key = self.make_key(key, version=version)
613        if not client.exists(key):
614            return 0
615
616        t = client.ttl(key)
617
618        if t >= 0:
619            return t
620        elif t == -1:
621            return None
622        elif t == -2:
623            return 0
624        else:
625            # Should never reach here
626            return None
627
628    def pttl(self, key, version=None, client=None):
629        """
630        Executes PTTL redis command and return the "time-to-live" of specified key.
631        If key is a non volatile key, it returns None.
632        """
633        if client is None:
634            client = self.get_client(write=False)
635
636        key = self.make_key(key, version=version)
637        if not client.exists(key):
638            return 0
639
640        t = client.pttl(key)
641
642        if t >= 0:
643            return t
644        elif t == -1:
645            return None
646        elif t == -2:
647            return 0
648        else:
649            # Should never reach here
650            return None
651
652    def has_key(
653        self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None
654    ) -> bool:
655        """
656        Test if key exists.
657        """
658
659        if client is None:
660            client = self.get_client(write=False)
661
662        key = self.make_key(key, version=version)
663        try:
664            return client.exists(key) == 1
665        except _main_exceptions as e:
666            raise ConnectionInterrupted(connection=client) from e
667
668    def iter_keys(
669        self,
670        search: str,
671        itersize: Optional[int] = None,
672        client: Optional[Redis] = None,
673        version: Optional[int] = None,
674    ) -> Iterator[str]:
675        """
676        Same as keys, but uses redis >= 2.8 cursors
677        for make memory efficient keys iteration.
678        """
679
680        if client is None:
681            client = self.get_client(write=False)
682
683        pattern = self.make_pattern(search, version=version)
684        for item in client.scan_iter(match=pattern, count=itersize):
685            yield self.reverse_key(item.decode())
686
687    def keys(
688        self, search: str, version: Optional[int] = None, client: Optional[Redis] = None
689    ) -> List[Any]:
690        """
691        Execute KEYS command and return matched results.
692        Warning: this can return huge number of results, in
693        this case, it strongly recommended use iter_keys
694        for it.
695        """
696
697        if client is None:
698            client = self.get_client(write=False)
699
700        pattern = self.make_pattern(search, version=version)
701        try:
702            return [self.reverse_key(k.decode()) for k in client.keys(pattern)]
703        except _main_exceptions as e:
704            raise ConnectionInterrupted(connection=client) from e
705
706    def make_key(
707        self, key: Any, version: Optional[Any] = None, prefix: Optional[str] = None
708    ) -> CacheKey:
709        if isinstance(key, CacheKey):
710            return key
711
712        if prefix is None:
713            prefix = self._backend.key_prefix
714
715        if version is None:
716            version = self._backend.version
717
718        return CacheKey(self._backend.key_func(key, prefix, version))
719
720    def make_pattern(
721        self, pattern: str, version: Optional[int] = None, prefix: Optional[str] = None
722    ) -> CacheKey:
723        if isinstance(pattern, CacheKey):
724            return pattern
725
726        if prefix is None:
727            prefix = self._backend.key_prefix
728        prefix = glob_escape(prefix)
729
730        if version is None:
731            version = self._backend.version
732        version_str = glob_escape(str(version))
733
734        return CacheKey(self._backend.key_func(pattern, prefix, version_str))
735
736    def close(self, **kwargs):
737        close_flag = self._options.get(
738            "CLOSE_CONNECTION",
739            getattr(settings, "DJANGO_REDIS_CLOSE_CONNECTION", False),
740        )
741        if close_flag:
742            self.do_close_clients()
743
744    def do_close_clients(self):
745        """default implementation: Override in custom client"""
746        num_clients = len(self._clients)
747        for idx in range(num_clients):
748            self.disconnect(index=idx)
749        self._clients = [None] * num_clients
750
751    def touch(
752        self,
753        key: Any,
754        timeout: Optional[float] = DEFAULT_TIMEOUT,
755        version: Optional[int] = None,
756        client: Optional[Redis] = None,
757    ) -> bool:
758        """
759        Sets a new expiration for a key.
760        """
761
762        if timeout is DEFAULT_TIMEOUT:
763            timeout = self._backend.default_timeout
764
765        if client is None:
766            client = self.get_client(write=True)
767
768        key = self.make_key(key, version=version)
769        if timeout is None:
770            return bool(client.persist(key))
771        else:
772            # Convert to milliseconds
773            timeout = int(timeout * 1000)
774            return bool(client.pexpire(key, timeout))
775