1"""Redis database engine."""
2
3import time
4import logging
5import datetime
6import functools
7
8try:
9    import redis
10    _has_redis = True
11except ImportError:
12    redis = None
13    _has_redis = False
14
15from pyzor.engines.common import *
16
# Version of the on-disk record layout; it is embedded in every key prefix
# so records written by a different layout version never collide.
VERSION = "1"
NAMESPACE = "pyzord.digest_v" + VERSION
19
20
def encode_date(date):
    """Return *date* as an integer Unix timestamp, or 0 when it is None.

    The datetime is interpreted as local time (``time.mktime``); any
    sub-second part is discarded.
    """
    if date is not None:
        return int(time.mktime(date.timetuple()))
    return 0
26
27
def decode_date(stamp):
    """Return a local ``datetime`` for Unix timestamp *stamp*, or None for 0.

    *stamp* may be anything ``int()`` accepts, e.g. an int or the str/bytes
    value read back from a redis hash.
    """
    seconds = int(stamp)
    return datetime.datetime.fromtimestamp(seconds) if seconds else None
34
35
def safe_call(f):
    """Decorator that wraps a method for handling database operations.

    Any ``redis.exceptions.RedisError`` raised by the wrapped method is logged
    through ``self.log`` and re-raised as ``DatabaseError``.  The wrapper keeps
    the wrapped method's name and docstring (``functools.wraps``).
    """

    @functools.wraps(f)
    def wrapped_f(self, *args, **kwargs):
        # Only log the error and raise the usual DatabaseError for
        # consistency; the redis library takes care of reconnecting and
        # everything else.
        try:
            return f(self, *args, **kwargs)
        except redis.exceptions.RedisError as e:
            self.log.error("Redis error while calling %s: %s",
                           f.__name__, e)
            raise DatabaseError("Database temporarily unavailable.")

    return wrapped_f
50
51
class RedisDBHandle(BaseEngine):
    """Pyzor digest storage backed by a Redis server.

    Each digest is stored as a Redis hash under ``NAMESPACE.<digest>`` whose
    fields are the report/whitelist counters and their timestamps.
    """

    absolute_source = False
    handles_one_step = True

    log = logging.getLogger("pyzord")

    def __init__(self, fn, mode, max_age=None):
        """Create a client for the server described by *fn*.

        :param fn: DSN of the form ``host,port,password,db``.  Empty or
            missing fields fall back to ``localhost``, ``6379``, no password
            and database ``0``.  A host containing ``/`` is used as a unix
            socket path.
        :param mode: ignored.
        :param max_age: when set (non-zero), records are given a Redis
            expiry of this many seconds each time they are written.
        """
        self.max_age = max_age
        # The raw DSN is kept so it can be echoed back to the user
        # (see the migration hint in _check_version).
        self._dsn = fn
        parts = fn.split(",")
        # Pad missing trailing fields so a short DSN such as "host,port"
        # falls back to the defaults instead of raising IndexError.
        parts += [""] * (4 - len(parts))
        self.host = parts[0] or "localhost"
        self.port = parts[1] or "6379"
        self.passwd = parts[2] or None
        self.db_name = parts[3] or "0"
        self.db = self._get_new_connection()
        self._check_version()

    @staticmethod
    def _encode_record(r):
        """Return the hash fields stored for Record *r*.

        Dates are converted to Unix timestamps (0 for None).
        """
        return {"r_count": r.r_count,
                "r_entered": encode_date(r.r_entered),
                "r_updated": encode_date(r.r_updated),
                "wl_count": r.wl_count,
                "wl_entered": encode_date(r.wl_entered),
                "wl_updated": encode_date(r.wl_updated)
                }

    @staticmethod
    def _decode_record(r):
        """Build a Record from an ``hgetall`` result.

        An empty result (missing key) yields a default ``Record()``; missing
        fields default to 0 / None.
        """
        if not r:
            return Record()
        return Record(r_count=int(r.get(b"r_count", 0)),
                      r_entered=decode_date(r.get(b"r_entered", 0)),
                      r_updated=decode_date(r.get(b"r_updated", 0)),
                      wl_count=int(r.get(b"wl_count", 0)),
                      wl_entered=decode_date(r.get(b"wl_entered", 0)),
                      wl_updated=decode_date(r.get(b"wl_updated", 0)))

    def __iter__(self):
        """Yield the digest part of every key in the current namespace."""
        for key in self.db.keys(self._real_key("*")):
            # redis-py returns keys as bytes unless decode_responses is
            # enabled; normalise to text so rsplit(".") works and the digest
            # can be fed back to self[key].
            if isinstance(key, bytes):
                key = key.decode("utf-8")
            yield key.rsplit(".", 1)[-1]

    def _iteritems(self):
        """Yield ``(digest, Record)`` pairs, logging and skipping records
        that fail to load."""
        for key in self:
            try:
                yield key, self[key]
            except Exception as ex:
                self.log.warning("Invalid record %s: %s", key, ex)

    def iteritems(self):
        return self._iteritems()

    def items(self):
        return list(self._iteritems())

    @staticmethod
    def _real_key(key):
        """Return the namespaced Redis key for digest *key*."""
        return "%s.%s" % (NAMESPACE, key)

    @safe_call
    def _get_new_connection(self):
        """Create a StrictRedis client from the parsed DSN fields."""
        if "/" in self.host:
            return redis.StrictRedis(unix_socket_path=self.host,
                                     db=int(self.db_name), password=self.passwd)
        return redis.StrictRedis(host=self.host, port=int(self.port),
                                 db=int(self.db_name), password=self.passwd)

    @safe_call
    def __getitem__(self, key):
        return self._decode_record(self.db.hgetall(self._real_key(key)))

    @safe_call
    def __setitem__(self, key, value):
        real_key = self._real_key(key)
        self.db.hmset(real_key, self._encode_record(value))
        # Same truthiness test as report()/whitelist(): EXPIRE with a
        # timeout of 0 would delete the freshly written record.
        if self.max_age:
            self.db.expire(real_key, self.max_age)

    @safe_call
    def __delitem__(self, key):
        self.db.delete(self._real_key(key))

    def _touch(self, keys, prefix):
        """Bump the ``<prefix>_count`` counter and timestamps of each digest.

        :param keys: iterable of digests.
        :param prefix: field prefix, ``"r"`` (report) or ``"wl"`` (whitelist).
        """
        now = int(time.time())
        for key in keys:
            real_key = self._real_key(key)
            self.db.hincrby(real_key, prefix + "_count")
            # hsetnx only writes a missing field, so the entry time is set
            # once while the update time is refreshed on every call.
            self.db.hsetnx(real_key, prefix + "_entered", now)
            self.db.hset(real_key, prefix + "_updated", now)
            if self.max_age:
                self.db.expire(real_key, self.max_age)

    @safe_call
    def report(self, keys):
        """Increment the report counter (``r_*`` fields) of every digest."""
        self._touch(keys, "r")

    @safe_call
    def whitelist(self, keys):
        """Increment the whitelist counter (``wl_*`` fields) of every digest."""
        self._touch(keys, "wl")

    @classmethod
    def get_prefork_connections(cls, fn, mode, max_age=None):
        """Yields a number of database connections suitable for a Pyzor
        pre-fork server.
        """
        while True:
            yield functools.partial(cls, fn, mode, max_age=max_age)

    @safe_call
    def _check_version(self):
        """Check if there are deprecated records and warn the user.

        This is the first call that talks to the server, so it is wrapped
        like every other DB operation to surface failures as DatabaseError.
        """
        old_keys = len(self.db.keys("pyzord.digest.*"))
        if old_keys:
            cmd = ("pyzor-migrate --delete --se=redis_v0 --sd=%s "
                   "--de=redis --dd=%s" % (self._dsn, self._dsn))
            self.log.critical("You have %s records in the deprecated version "
                              "of the redis engine.", old_keys)
            self.log.critical("Please migrate the records with: %r", cmd)
176
177
class ThreadedRedisDBHandle(RedisDBHandle):
    """Handle registered for the multi-threaded server mode.

    It behaves exactly like RedisDBHandle.  The extra *bound* argument is
    accepted and ignored (presumably for signature parity with other
    engines' threaded handles — verify against callers).
    """

    def __init__(self, fn, mode, max_age=None, bound=None):
        # Explicit base call: 'bound' is intentionally not forwarded.
        RedisDBHandle.__init__(self, fn, mode, max_age)
181
182
if _has_redis:
    # The plain handle serves the single-threaded and pre-fork modes; the
    # threaded subclass serves the multi-threaded mode.
    handle = DBHandle(single_threaded=RedisDBHandle,
                      multi_threaded=ThreadedRedisDBHandle,
                      multi_processing=None,
                      prefork=RedisDBHandle)
else:
    # The redis package is missing: every mode is left as None (presumably
    # marking this engine as unavailable).
    handle = DBHandle(single_threaded=None,
                      multi_threaded=None,
                      multi_processing=None,
                      prefork=None)
193