"""Redis database engine."""

import time
import logging
import datetime
import functools

try:
    import redis
    _has_redis = True
except ImportError:
    redis = None
    _has_redis = False

from pyzor.engines.common import *

VERSION = "1"
NAMESPACE = "pyzord.digest_v%s" % VERSION


def encode_date(date):
    """Convert the date to a Unix timestamp (interpreted as local time).

    ``None`` is encoded as ``0``; :func:`decode_date` maps ``0`` back to
    ``None``.
    """
    if date is None:
        return 0
    return int(time.mktime(date.timetuple()))


def decode_date(stamp):
    """Return a datetime object from a Unix timestamp.

    *stamp* may be anything ``int()`` accepts (int, str, or the raw bytes
    returned by redis). A stamp of ``0`` decodes to ``None``.
    """
    stamp = int(stamp)
    if stamp == 0:
        return None
    return datetime.datetime.fromtimestamp(stamp)


def safe_call(f):
    """Decorator that wraps a method for handling database operations.

    Any ``redis.exceptions.RedisError`` raised by the wrapped method is logged
    and re-raised as ``DatabaseError`` so callers only deal with the
    engine-agnostic exception.
    """

    @functools.wraps(f)
    def wrapped_f(self, *args, **kwargs):
        # This only logs the error and raise the usual Error for consistency,
        # the redis library takes care of reconnecting and everything else.
        try:
            return f(self, *args, **kwargs)
        except redis.exceptions.RedisError as e:
            self.log.error("Redis error while calling %s: %s",
                           f.__name__, e)
            raise DatabaseError("Database temporarily unavailable.")

    return wrapped_f


class RedisDBHandle(BaseEngine):
    """Store digest records as redis hashes under ``NAMESPACE``.

    Each digest maps to the key ``"<NAMESPACE>.<digest>"``. The hash holds the
    ``r_*`` (report) and ``wl_*`` (whitelist) count and timestamp fields.
    """
    absolute_source = False
    handles_one_step = True

    log = logging.getLogger("pyzord")

    def __init__(self, fn, mode, max_age=None):
        """Connect using the DSN *fn* (``host,port,password,db``).

        *mode* is ignored. Empty or missing DSN fields fall back to
        ``localhost``, ``6379``, no password and database ``0``. If *max_age*
        is set (non-zero), stored records expire after that many seconds.
        """
        self.max_age = max_age
        # We store the authentication details so that we can reconnect if
        # necessary.
        self._dsn = fn
        # Pad with empty fields so a short DSN uses the defaults below
        # instead of raising IndexError; extra fields are ignored.
        host, port, passwd, db_name = (fn.split(",") + [""] * 4)[:4]
        self.host = host or "localhost"
        self.port = port or "6379"
        self.passwd = passwd or None
        self.db_name = db_name or "0"
        self.db = self._get_new_connection()
        self._check_version()

    @staticmethod
    def _encode_record(r):
        """Flatten a ``Record`` into a dict suitable for a redis hash."""
        return {"r_count": r.r_count,
                "r_entered": encode_date(r.r_entered),
                "r_updated": encode_date(r.r_updated),
                "wl_count": r.wl_count,
                "wl_entered": encode_date(r.wl_entered),
                "wl_updated": encode_date(r.wl_updated)
                }

    @staticmethod
    def _decode_record(r):
        """Build a ``Record`` from an ``HGETALL`` reply (bytes field names).

        An empty reply (missing key) yields an empty ``Record``.
        """
        if not r:
            return Record()
        return Record(r_count=int(r.get(b"r_count", 0)),
                      r_entered=decode_date(r.get(b"r_entered", 0)),
                      r_updated=decode_date(r.get(b"r_updated", 0)),
                      wl_count=int(r.get(b"wl_count", 0)),
                      wl_entered=decode_date(r.get(b"wl_entered", 0)),
                      wl_updated=decode_date(r.get(b"wl_updated", 0)))

    def __iter__(self):
        """Yield the digest part of every key in this engine's namespace."""
        for key in self.db.keys(self._real_key("*")):
            # redis-py returns bytes on Python 3 (no decode_responses);
            # splitting bytes with a str separator would raise TypeError.
            if isinstance(key, bytes):
                key = key.decode("utf-8")
            yield key.rsplit(".", 1)[-1]

    def _iteritems(self):
        """Yield ``(digest, Record)`` pairs, skipping unreadable records."""
        for key in self:
            try:
                yield key, self[key]
            except Exception as ex:
                # Best effort: one corrupt record must not abort the scan.
                self.log.warning("Invalid record %s: %s", key, ex)

    def iteritems(self):
        return self._iteritems()

    def items(self):
        return list(self._iteritems())

    @staticmethod
    def _real_key(key):
        """Return the namespaced redis key for the digest *key*."""
        return "%s.%s" % (NAMESPACE, key)

    @safe_call
    def _get_new_connection(self):
        """Create a client; a host containing ``/`` is a unix socket path."""
        if "/" in self.host:
            return redis.StrictRedis(unix_socket_path=self.host,
                                     db=int(self.db_name),
                                     password=self.passwd)
        return redis.StrictRedis(host=self.host, port=int(self.port),
                                 db=int(self.db_name), password=self.passwd)

    @safe_call
    def __getitem__(self, key):
        return self._decode_record(self.db.hgetall(self._real_key(key)))

    @safe_call
    def __setitem__(self, key, value):
        real_key = self._real_key(key)
        # Send the write and the expiry together in one round trip.
        pipe = self.db.pipeline()
        pipe.hmset(real_key, self._encode_record(value))
        # Same check as report()/whitelist(): EXPIRE with 0 would delete the
        # record right after writing it.
        if self.max_age:
            pipe.expire(real_key, self.max_age)
        pipe.execute()

    @safe_call
    def __delitem__(self, key):
        self.db.delete(self._real_key(key))

    def _touch(self, keys, prefix):
        """Bump ``<prefix>_count`` and timestamps for *keys* in one pipeline.

        ``<prefix>_entered`` is only set the first time a key is seen;
        ``<prefix>_updated`` is always refreshed to the current time.
        """
        now = int(time.time())
        pipe = self.db.pipeline()
        for key in keys:
            real_key = self._real_key(key)
            pipe.hincrby(real_key, prefix + "_count")
            pipe.hsetnx(real_key, prefix + "_entered", now)
            pipe.hset(real_key, prefix + "_updated", now)
            if self.max_age:
                pipe.expire(real_key, self.max_age)
        pipe.execute()

    @safe_call
    def report(self, keys):
        """Record a spam report for every digest in *keys*."""
        self._touch(keys, "r")

    @safe_call
    def whitelist(self, keys):
        """Record a whitelist entry for every digest in *keys*."""
        self._touch(keys, "wl")

    @classmethod
    def get_prefork_connections(cls, fn, mode, max_age=None):
        """Yields a number of database connections suitable for a Pyzor
        pre-fork server.
        """
        while True:
            yield functools.partial(cls, fn, mode, max_age=max_age)

    @safe_call
    def _check_version(self):
        """Check if there are deprecated records and warn the user.

        This is the first call in ``__init__`` that talks to the server, so it
        is wrapped to surface connection failures as ``DatabaseError``.
        """
        old_keys = len(self.db.keys("pyzord.digest.*"))
        if old_keys:
            cmd = ("pyzor-migrate --delete --se=redis_v0 --sd=%s "
                   "--de=redis --dd=%s" % (self._dsn, self._dsn))
            self.log.critical("You have %s records in the deprecated version "
                              "of the redis engine.", old_keys)
            self.log.critical("Please migrate the records with: %r", cmd)


class ThreadedRedisDBHandle(RedisDBHandle):
    """Multi-threaded variant; the redis client manages its own pool."""

    def __init__(self, fn, mode, max_age=None, bound=None):
        # *bound* is accepted for interface compatibility and ignored here.
        RedisDBHandle.__init__(self, fn, mode, max_age=max_age)


if not _has_redis:
    handle = DBHandle(single_threaded=None,
                      multi_threaded=None,
                      multi_processing=None,
                      prefork=None)
else:
    handle = DBHandle(single_threaded=RedisDBHandle,
                      multi_threaded=ThreadedRedisDBHandle,
                      multi_processing=None,
                      prefork=RedisDBHandle)