import random
import re
import socket
from collections import OrderedDict
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional, Union

from django.conf import settings
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func
from django.core.exceptions import ImproperlyConfigured
from django.utils.module_loading import import_string
from redis import Redis
from redis.exceptions import ConnectionError, ResponseError, TimeoutError

from .. import pool
from ..exceptions import CompressorError, ConnectionInterrupted
from ..util import CacheKey

# Redis/socket errors that the client methods below translate into
# ``ConnectionInterrupted``.
_main_exceptions = (TimeoutError, ResponseError, ConnectionError, socket.timeout)

# Characters with special meaning in Redis glob-style patterns.
special_re = re.compile("([*?[])")


def glob_escape(s: str) -> str:
    """Escape glob metacharacters (``*``, ``?``, ``[``) by wrapping each in ``[...]``."""
    return special_re.sub(r"[\1]", s)


class DefaultClient:
    """
    Default django-redis client: maps Django cache operations onto Redis commands.

    ``server`` is a single connection string, a comma-separated string of them,
    or a list/tuple/set of them. Index 0 is used for writes; when several
    servers are configured, reads go to a randomly chosen other index (see
    ``get_next_client_index``). Redis clients are created lazily per index.
    """

    def __init__(self, server, params: Dict[str, Any], backend: BaseCache) -> None:
        self._backend = backend
        self._server = server
        self._params = params

        self.reverse_key = get_key_func(
            params.get("REVERSE_KEY_FUNCTION")
            or "django_redis.util.default_reverse_key"
        )

        if not self._server:
            raise ImproperlyConfigured("Missing connections string")

        if isinstance(self._server, set):
            # Sets cannot be indexed, and ``connect`` does ``self._server[index]``.
            # NOTE: a set has no defined order, so which server ends up as the
            # primary (index 0) is arbitrary -- prefer a list or tuple.
            self._server = list(self._server)
        elif not isinstance(self._server, (list, tuple)):
            self._server = self._server.split(",")

        # One lazily-created Redis client per configured server.
        self._clients: List[Optional[Redis]] = [None] * len(self._server)
        self._options = params.get("OPTIONS", {})
        self._replica_read_only = self._options.get("REPLICA_READ_ONLY", True)

        serializer_path = self._options.get(
            "SERIALIZER", "django_redis.serializers.pickle.PickleSerializer"
        )
        serializer_cls = import_string(serializer_path)

        compressor_path = self._options.get(
            "COMPRESSOR", "django_redis.compressors.identity.IdentityCompressor"
        )
        compressor_cls = import_string(compressor_path)

        self._serializer = serializer_cls(options=self._options)
        self._compressor = compressor_cls(options=self._options)

        self.connection_factory = pool.get_connection_factory(options=self._options)

    def __contains__(self, key: Any) -> bool:
        return self.has_key(key)

    def get_next_client_index(
        self, write: bool = True, tried: Optional[List[int]] = None
    ) -> int:
        """
        Return the index of the next client to use in a replication setup.

        If some (but not all) indices have already been ``tried``, one of the
        untried indices is picked at random. Otherwise writes, and any call on
        a single-server setup, use index 0; reads pick a random index in
        ``1 .. len(servers) - 1``.

        Override this method for a different selection strategy.
        """
        if tried is None:
            tried = list()

        if tried and len(tried) < len(self._server):
            not_tried = [i for i in range(0, len(self._server)) if i not in tried]
            return random.choice(not_tried)

        if write or len(self._server) == 1:
            return 0

        return random.randint(1, len(self._server) - 1)

    def get_client(
        self,
        write: bool = True,
        tried: Optional[List[int]] = None,
        show_index: bool = False,
    ):
        """
        Return a raw Redis client, creating and caching it on first use.

        With ``show_index=True`` a ``(client, index)`` tuple is returned
        instead, so callers can record which server they tried.
        """
        index = self.get_next_client_index(write=write, tried=tried)

        if self._clients[index] is None:
            self._clients[index] = self.connect(index)

        if show_index:
            return self._clients[index], index
        else:
            return self._clients[index]

    def connect(self, index: int = 0) -> Redis:
        """
        Return a new raw Redis client for the connection string at ``index``.

        The index selects the server in replication setups; in a plain setup
        it is 0.
        """
        return self.connection_factory.connect(self._server[index])

    def disconnect(self, index=0, client=None):
        """Delegate disconnecting ``client`` (default: the cached client at ``index``) to the connection factory."""
        if not client:
            client = self._clients[index]
        return self.connection_factory.disconnect(client) if client else None

    def set(
        self,
        key: Any,
        value: Any,
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
        nx: bool = False,
        xx: bool = False,
    ) -> bool:
        """
        Persist a value to the cache, with an optional expiration time.

        ``timeout`` is in seconds: ``DEFAULT_TIMEOUT`` means the backend default
        and ``None`` means "never expire". A timeout that truncates to ``<= 0``
        milliseconds deletes the key instead of storing it, or, with
        ``nx=True``, leaves any existing value untouched.

        ``nx=True`` only sets a missing key; ``xx=True`` only sets an existing one.

        When no ``client`` is given and ``REPLICA_READ_ONLY`` is disabled, a
        failed write is retried on the other configured servers.

        Raises ``ConnectionInterrupted`` on Redis connection/response errors.
        """
        nkey = self.make_key(key, version=version)
        nvalue = self.encode(value)

        if timeout is DEFAULT_TIMEOUT:
            timeout = self._backend.default_timeout

        if timeout is not None:
            # Convert to milliseconds exactly once, before the retry loop.
            # Converting inside the loop re-multiplied by 1000 on every retry.
            timeout = int(timeout * 1000)

        original_client = client
        tried: List[int] = []
        while True:
            # Only set when this iteration picked a client itself.
            index: Optional[int] = None
            try:
                if client is None:
                    client, index = self.get_client(
                        write=True, tried=tried, show_index=True
                    )

                if timeout is not None and timeout <= 0:
                    if nx:
                        # A non-positive timeout with nx must not expire (i.e.
                        # delete) an existing value, and expiring a missing
                        # value is a no-op.
                        return not self.has_key(key, version=version, client=client)
                    # Redis doesn't accept non-positive expirations in SET, so
                    # delete the key rather than set it and expire it.
                    return bool(self.delete(key, client=client, version=version))

                return bool(client.set(nkey, nvalue, nx=nx, px=timeout, xx=xx))
            except _main_exceptions as e:
                if (
                    not original_client
                    and not self._replica_read_only
                    and index is not None
                    and len(tried) < len(self._server)
                ):
                    tried.append(index)
                    client = None
                    continue
                raise ConnectionInterrupted(connection=client) from e

    def incr_version(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """
        Move ``key`` to version ``version + delta`` and return the new version.

        The value is copied to the new versioned key with the remaining TTL
        (millisecond precision), then the old key is deleted.

        Raises ``ValueError`` if the key does not exist.
        """
        if client is None:
            client = self.get_client(write=True)

        if version is None:
            version = self._backend.version

        old_key = self.make_key(key, version)
        value = self.get(old_key, version=version, client=client)
        if value is None:
            raise ValueError("Key '%s' not found" % key)

        try:
            # Millisecond precision: a whole-second TTL truncates a sub-second
            # remainder to 0, and ``set`` treats 0 as "delete".
            remaining_ms = self.pttl(old_key, version=version, client=client)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e
        timeout = None if remaining_ms is None else remaining_ms / 1000

        original_key = key.original_key() if isinstance(key, CacheKey) else key
        new_key = self.make_key(original_key, version=version + delta)

        self.set(new_key, value, timeout=timeout, client=client)
        self.delete(old_key, client=client)
        return version + delta

    def add(
        self,
        key: Any,
        value: Any,
        timeout: Any = DEFAULT_TIMEOUT,
        version: Optional[Any] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """
        Add a value to the cache, failing if the key already exists.

        Returns ``True`` if the object was added, ``False`` if not.
        """
        return self.set(key, value, timeout, version=version, client=client, nx=True)

    def get(
        self,
        key: Any,
        default=None,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> Any:
        """
        Retrieve a value from the cache.

        Returns the decoded value if the key is found, ``default`` otherwise.
        Raises ``ConnectionInterrupted`` on Redis connection/response errors.
        """
        if client is None:
            client = self.get_client(write=False)

        key = self.make_key(key, version=version)

        try:
            value = client.get(key)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

        if value is None:
            return default

        return self.decode(value)

    def persist(
        self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> bool:
        """Remove the expiration from ``key`` (Redis ``PERSIST``)."""
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)

        return client.persist(key)

    def expire(
        self,
        key: Any,
        timeout,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """Set the expiration of ``key`` to ``timeout`` seconds (Redis ``EXPIRE``)."""
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)

        return client.expire(key, timeout)

    def pexpire(self, key, timeout, version=None, client=None) -> bool:
        """Set the expiration of ``key`` to ``timeout`` milliseconds (Redis ``PEXPIRE``)."""
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)

        # Temporary casting until https://github.com/redis/redis-py/issues/1664
        # is fixed.
        return bool(client.pexpire(key, timeout))

    def pexpire_at(
        self,
        key: Any,
        when: Union[datetime, int],
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """
        Set an expire flag on a ``key`` to ``when``, which can be represented
        as an integer indicating unix time or a Python datetime object.
        """
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)

        return bool(client.pexpireat(key, when))

    def expire_at(
        self,
        key: Any,
        when: Union[datetime, int],
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """
        Set an expire flag on a ``key`` to ``when``, which can be represented
        as an integer indicating unix time or a Python datetime object.
        """
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)

        return client.expireat(key, when)

    def lock(
        self,
        key,
        version: Optional[int] = None,
        timeout=None,
        sleep=0.1,
        blocking_timeout=None,
        client: Optional[Redis] = None,
        thread_local=True,
    ):
        """Return the redis-py lock object (``client.lock``) for the prefixed/versioned ``key``."""
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)
        return client.lock(
            key,
            timeout=timeout,
            sleep=sleep,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
        )

    def delete(
        self,
        key: Any,
        version: Optional[int] = None,
        prefix: Optional[str] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """
        Remove a key from the cache and return the Redis ``DEL`` reply.

        Raises ``ConnectionInterrupted`` on Redis connection/response errors.
        """
        if client is None:
            client = self.get_client(write=True)

        try:
            return client.delete(self.make_key(key, version=version, prefix=prefix))
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def delete_pattern(
        self,
        pattern: str,
        version: Optional[int] = None,
        prefix: Optional[str] = None,
        client: Optional[Redis] = None,
        itersize: Optional[int] = None,
    ) -> int:
        """
        Remove all keys matching ``pattern``, found incrementally with ``SCAN``.

        Returns the number of ``DEL`` commands issued (one per scanned key).
        """
        if client is None:
            client = self.get_client(write=True)

        pattern = self.make_pattern(pattern, version=version, prefix=prefix)

        try:
            count = 0
            for key in client.scan_iter(match=pattern, count=itersize):
                client.delete(key)
                count += 1
            return count
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def delete_many(
        self, keys, version: Optional[int] = None, client: Optional[Redis] = None
    ):
        """
        Remove multiple keys with a single ``DEL``.

        Returns the Redis reply, or ``None`` when ``keys`` is empty.
        """
        keys = [self.make_key(k, version=version) for k in keys]

        if not keys:
            return

        if client is None:
            client = self.get_client(write=True)

        try:
            return client.delete(*keys)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def clear(self, client: Optional[Redis] = None) -> None:
        """
        Flush all cache keys (Redis ``FLUSHDB`` on the write client).
        """
        if client is None:
            client = self.get_client(write=True)

        try:
            client.flushdb()
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def decode(self, value: Union[bytes, int]) -> Any:
        """
        Decode a raw value read from Redis.

        Values that parse as an ``int`` are returned as ints, because ``encode``
        stores plain ints unserialized. Everything else is decompressed and
        then deserialized. If the compressor raises ``CompressorError``, the
        value is treated as uncompressed.
        """
        try:
            value = int(value)
        except (ValueError, TypeError):
            try:
                value = self._compressor.decompress(value)
            except CompressorError:
                # Handle little values, chosen to be not compressed
                pass
            value = self._serializer.loads(value)
        return value

    def encode(self, value: Any) -> Union[bytes, Any]:
        """
        Encode a value for storage in Redis.

        Plain ints (but not bools) are stored as-is, presumably so Redis
        ``INCRBY`` (used by ``_incr``) can operate on them. Everything else is
        serialized, then compressed.
        """
        if isinstance(value, bool) or not isinstance(value, int):
            value = self._serializer.dumps(value)
            value = self._compressor.compress(value)
        return value

    def get_many(
        self, keys, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> OrderedDict:
        """
        Retrieve many keys with a single ``MGET``.

        Returns an ``OrderedDict`` mapping each found original key to its
        decoded value; missing keys are omitted.
        """
        if not keys:
            return OrderedDict()

        if client is None:
            client = self.get_client(write=False)

        recovered_data = OrderedDict()

        # Redis key -> caller's key, so results can be reported by original key.
        map_keys = OrderedDict((self.make_key(k, version=version), k) for k in keys)

        try:
            results = client.mget(*map_keys)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

        for key, value in zip(map_keys, results):
            if value is None:
                continue
            recovered_data[map_keys[key]] = self.decode(value)
        return recovered_data

    def set_many(
        self,
        data: Dict[Any, Any],
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> None:
        """
        Set a bunch of values in the cache at once from a dict of key/value
        pairs. This is much more efficient than calling set() multiple times.

        If timeout is given, that timeout will be used for the key; otherwise
        the default cache timeout will be used.
        """
        if client is None:
            client = self.get_client(write=True)

        try:
            pipeline = client.pipeline()
            for key, value in data.items():
                self.set(key, value, timeout, version=version, client=pipeline)
            pipeline.execute()
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def _incr(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
        ignore_key_check: bool = False,
    ) -> int:
        """
        Add ``delta`` to the value at ``key`` and return the new value.

        Raises ``ValueError`` if the key does not exist, unless
        ``ignore_key_check`` is true, in which case a missing key is created.
        """
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)

        try:
            try:
                # If the key expired between a separate EXISTS check and the
                # INCRBY, INCRBY would recreate it with the wrong value and no
                # TTL, so both run inside one Lua script for atomicity.
                if not ignore_key_check:
                    lua = """
                    local exists = redis.call('EXISTS', KEYS[1])
                    if (exists == 1) then
                        return redis.call('INCRBY', KEYS[1], ARGV[1])
                    else return false end
                    """
                else:
                    lua = """
                    return redis.call('INCRBY', KEYS[1], ARGV[1])
                    """
                value = client.eval(lua, 1, key, delta)
                if value is None:
                    raise ValueError("Key '%s' not found" % key)
            except ResponseError:
                # INCRBY fails when the stored value or the result doesn't fit
                # a signed 64-bit integer, or when the value was stored
                # encoded (Redis sees a non-numeric string). Fall back to a
                # read-modify-write in Python that keeps the key's TTL.
                current = self.get(key, version=version, client=client)
                if current is None:
                    # The key expired or was deleted before the fallback ran.
                    # (``self.ttl`` never returns -2, so a TTL check can't
                    # detect this case.)
                    raise ValueError("Key '%s' not found" % key)

                # Millisecond precision, so a sub-second remaining TTL isn't
                # truncated to 0 (which ``set`` treats as "delete").
                remaining_ms = self.pttl(key, version=version, client=client)
                timeout = None if remaining_ms is None else remaining_ms / 1000

                value = current + delta
                self.set(key, value, version=version, timeout=timeout, client=client)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

        return value

    def incr(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
        ignore_key_check: bool = False,
    ) -> int:
        """
        Add delta to value in the cache. If the key does not exist, raise a
        ValueError exception. If ignore_key_check=True then the key will be
        created and set to the delta value by default.
        """
        return self._incr(
            key=key,
            delta=delta,
            version=version,
            client=client,
            ignore_key_check=ignore_key_check,
        )

    def decr(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """
        Subtract delta from the value in the cache. If the key does not exist,
        raise a ValueError exception.
        """
        return self._incr(key=key, delta=-delta, version=version, client=client)

    @staticmethod
    def _normalize_ttl(t: int) -> Optional[int]:
        """Map a raw Redis TTL/PTTL reply to the value returned by ``ttl``/``pttl``."""
        if t >= 0:
            return t
        if t == -2:
            # The key does not exist.
            return 0
        # -1: the key exists but has no expiry. Any other negative reply is
        # not expected from Redis and is also reported as "no expiry".
        return None

    def ttl(
        self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> Optional[int]:
        """
        Return the remaining time-to-live of ``key`` in seconds (Redis ``TTL``).

        Returns ``None`` for a key without an expiry and ``0`` for a missing
        key. Relies on Redis >= 2.8 replying -2 for missing keys.
        """
        if client is None:
            client = self.get_client(write=False)

        key = self.make_key(key, version=version)
        return self._normalize_ttl(client.ttl(key))

    def pttl(self, key, version=None, client=None):
        """
        Return the remaining time-to-live of ``key`` in milliseconds (Redis ``PTTL``).

        Returns ``None`` for a key without an expiry and ``0`` for a missing
        key. Relies on Redis >= 2.8 replying -2 for missing keys.
        """
        if client is None:
            client = self.get_client(write=False)

        key = self.make_key(key, version=version)
        return self._normalize_ttl(client.pttl(key))

    def has_key(
        self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> bool:
        """
        Test if key exists.
        """
        if client is None:
            client = self.get_client(write=False)

        key = self.make_key(key, version=version)
        try:
            return client.exists(key) == 1
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def iter_keys(
        self,
        search: str,
        itersize: Optional[int] = None,
        client: Optional[Redis] = None,
        version: Optional[int] = None,
    ) -> Iterator[str]:
        """
        Lazily yield keys matching ``search``, using ``SCAN`` cursors
        (Redis >= 2.8) for memory-efficient iteration.

        Yielded keys are mapped back through ``reverse_key``.
        """
        if client is None:
            client = self.get_client(write=False)

        pattern = self.make_pattern(search, version=version)
        for item in client.scan_iter(match=pattern, count=itersize):
            yield self.reverse_key(item.decode())

    def keys(
        self, search: str, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> List[Any]:
        """
        Execute the ``KEYS`` command and return the matched keys.

        Warning: this can return a huge number of results; in that case
        ``iter_keys`` is strongly recommended.
        """
        if client is None:
            client = self.get_client(write=False)

        pattern = self.make_pattern(search, version=version)
        try:
            return [self.reverse_key(k.decode()) for k in client.keys(pattern)]
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def make_key(
        self, key: Any, version: Optional[Any] = None, prefix: Optional[str] = None
    ) -> CacheKey:
        """
        Build the full Redis key via the backend's ``key_func``.

        A ``CacheKey`` is returned unchanged, so already-built keys can be
        passed back in.
        """
        if isinstance(key, CacheKey):
            return key

        if prefix is None:
            prefix = self._backend.key_prefix

        if version is None:
            version = self._backend.version

        return CacheKey(self._backend.key_func(key, prefix, version))

    def make_pattern(
        self, pattern: str, version: Optional[int] = None, prefix: Optional[str] = None
    ) -> CacheKey:
        """
        Build a Redis glob pattern via the backend's ``key_func``.

        The prefix and version are glob-escaped so only ``pattern`` itself can
        contain wildcards.
        """
        if isinstance(pattern, CacheKey):
            return pattern

        if prefix is None:
            prefix = self._backend.key_prefix
        prefix = glob_escape(prefix)

        if version is None:
            version = self._backend.version
        version_str = glob_escape(str(version))

        return CacheKey(self._backend.key_func(pattern, prefix, version_str))

    def close(self, **kwargs):
        """
        Close all clients if enabled.

        Closing happens when the ``CLOSE_CONNECTION`` option is truthy,
        falling back to ``settings.DJANGO_REDIS_CLOSE_CONNECTION``
        (default ``False``). ``kwargs`` are ignored.
        """
        close_flag = self._options.get(
            "CLOSE_CONNECTION",
            getattr(settings, "DJANGO_REDIS_CLOSE_CONNECTION", False),
        )
        if close_flag:
            self.do_close_clients()

    def do_close_clients(self):
        """Disconnect every cached client and reset the cache (override in custom clients)."""
        num_clients = len(self._clients)
        for idx in range(num_clients):
            self.disconnect(index=idx)
        self._clients = [None] * num_clients

    def touch(
        self,
        key: Any,
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """
        Set a new expiration (in seconds) for a key.

        ``None`` removes the expiration (``PERSIST``); otherwise ``PEXPIRE`` is
        used with the timeout converted to milliseconds. Returns the truthiness
        of the Redis reply.
        """
        if timeout is DEFAULT_TIMEOUT:
            timeout = self._backend.default_timeout

        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)
        if timeout is None:
            return bool(client.persist(key))
        else:
            # Convert to milliseconds
            timeout = int(timeout * 1000)
            return bool(client.pexpire(key, timeout))