1import threading
2import time as mod_time
3import uuid
4from redis.exceptions import LockError, WatchError
5from redis.utils import dummy
6from redis._compat import b
7
8
class Lock(object):
    """
    A shared, distributed Lock. Using Redis for locking allows the Lock
    to be shared across processes and/or machines.

    It's left to the user to resolve deadlock issues and make sure
    multiple clients play nicely together.
    """
    def __init__(self, redis, name, timeout=None, sleep=0.1,
                 blocking=True, blocking_timeout=None, thread_local=True):
        """
        Create a new Lock instance named ``name`` using the Redis client
        supplied by ``redis``.

        ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.
        ``timeout`` can be specified as a float or integer, both representing
        the number of seconds the lock may live.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage.

        Raises ``LockError`` if ``timeout`` is set and ``sleep`` is greater
        than it.
        """
        self.redis = redis
        self.name = name
        self.timeout = timeout
        self.sleep = sleep
        self.blocking = blocking
        self.blocking_timeout = blocking_timeout
        self.thread_local = bool(thread_local)
        self.local = threading.local() if self.thread_local else dummy()
        self.local.token = None
        if self.timeout and self.sleep > self.timeout:
            raise LockError("'sleep' must be less than 'timeout'")

    def __enter__(self):
        """
        Acquire the lock on entering a ``with`` block and return the Lock.

        Raises ``LockError`` if the lock could not be acquired (which can
        only happen when a ``blocking_timeout`` is configured and elapses).
        """
        # force blocking, as otherwise the user would have to check whether
        # the lock was actually acquired or not.
        if self.acquire(blocking=True):
            return self
        # acquire() gave up: refuse to run the guarded body without the lock
        # instead of failing later in __exit__ with a misleading error.
        raise LockError("Unable to acquire lock within the time specified")

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the lock on leaving the ``with`` block."""
        self.release()

    def acquire(self, blocking=None, blocking_timeout=None):
        """
        Use Redis to hold a shared, distributed lock named ``name``.
        Returns True once the lock is acquired.

        If ``blocking`` is False, always return immediately. If the lock
        was acquired, return True, otherwise return False.

        ``blocking_timeout`` specifies the maximum number of seconds to
        wait trying to acquire the lock.

        Both arguments fall back to the values given to the constructor
        when left as ``None``.
        """
        sleep = self.sleep
        # a fresh token per attempt identifies this particular acquisition
        token = b(uuid.uuid1().hex)
        if blocking is None:
            blocking = self.blocking
        if blocking_timeout is None:
            blocking_timeout = self.blocking_timeout
        stop_trying_at = None
        if blocking_timeout is not None:
            stop_trying_at = mod_time.time() + blocking_timeout
        while True:
            if self.do_acquire(token):
                # only remember the token once Redis confirmed ownership
                self.local.token = token
                return True
            if not blocking:
                return False
            if stop_trying_at is not None and mod_time.time() > stop_trying_at:
                return False
            mod_time.sleep(sleep)

    def do_acquire(self, token):
        """
        Try once to store ``token`` under the lock name.

        Returns True if the key was created (lock acquired), False if the
        key already exists.
        """
        if not self.timeout:
            return bool(self.redis.setnx(self.name, token))
        # Set the value and its expiry in ONE command. A separate
        # SETNX followed by PEXPIRE would leave a lock that never expires
        # if this client died between the two calls.
        # Clamp to 1ms: a sub-millisecond timeout truncates to 0, which
        # Redis rejects as an expire time.
        timeout_ms = max(1, int(self.timeout * 1000))
        return bool(self.redis.set(self.name, token,
                                   nx=True, px=timeout_ms))

    def release(self):
        """
        Releases the already acquired lock.

        Raises ``LockError`` if this Lock holds no token, or if the key in
        Redis no longer carries this Lock's token.
        """
        expected_token = self.local.token
        if expected_token is None:
            raise LockError("Cannot release an unlocked lock")
        # forget the token first so the Lock reads as unlocked even if the
        # release below raises
        self.local.token = None
        self.do_release(expected_token)

    def do_release(self, expected_token):
        """
        Delete the lock key, but only while it still holds
        ``expected_token``. Raises ``LockError`` otherwise.
        """
        name = self.name

        def execute_release(pipe):
            # runs with ``name`` WATCHed; this GET executes immediately
            lock_value = pipe.get(name)
            if lock_value != expected_token:
                raise LockError("Cannot release a lock that's no longer owned")
            # Queue the DELETE inside MULTI/EXEC. Without multi() the
            # pipeline would run it immediately and the WATCH would not
            # protect against the key changing after the GET above.
            pipe.multi()
            pipe.delete(name)

        self.redis.transaction(execute_release, name)

    def extend(self, additional_time):
        """
        Adds more time to an already acquired lock.

        ``additional_time`` can be specified as an integer or a float, both
        representing the number of seconds to add.

        Raises ``LockError`` if the lock is not held by this instance or
        was created without a timeout.
        """
        if self.local.token is None:
            raise LockError("Cannot extend an unlocked lock")
        if self.timeout is None:
            raise LockError("Cannot extend a lock with no timeout")
        return self.do_extend(additional_time)

    def do_extend(self, additional_time):
        """
        Add ``additional_time`` seconds to the remaining life of the lock
        key. Returns True on success, raises ``LockError`` if the lock is
        no longer owned.
        """
        pipe = self.redis.pipeline()
        try:
            pipe.watch(self.name)
            lock_value = pipe.get(self.name)
            if lock_value != self.local.token:
                raise LockError("Cannot extend a lock that's no longer owned")
            expiration = pipe.pttl(self.name)
            if expiration is None or expiration < 0:
                # Redis evicted the lock key between the previous get() and
                # now we'll handle this when we call pexpire()
                expiration = 0
            pipe.multi()
            pipe.pexpire(self.name, expiration + int(additional_time * 1000))

            try:
                response = pipe.execute()
            except WatchError:
                # someone else acquired the lock
                raise LockError("Cannot extend a lock that's no longer owned")
        finally:
            # always drop the WATCH and hand the connection back, including
            # on the early LockError above that never reaches execute()
            pipe.reset()
        if not response[0]:
            # pexpire returns False if the key doesn't exist
            raise LockError("Cannot extend a lock that's no longer owned")
        return True
184
185
class LuaLock(Lock):
    """
    A lock implementation that uses Lua scripts rather than pipelines
    and watches.
    """
    # Script objects, created on first instantiation by register_scripts()
    # and then shared by every LuaLock instance.
    lua_acquire = None
    lua_release = None
    lua_extend = None

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - timeout in milliseconds, or '' for no expiry
    # return 1 if lock was acquired, otherwise 0
    LUA_ACQUIRE_SCRIPT = """
        if redis.call('setnx', KEYS[1], ARGV[1]) == 1 then
            if ARGV[2] ~= '' then
                redis.call('pexpire', KEYS[1], ARGV[2])
            end
            return 1
        end
        return 0
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # return 1 if the lock was released, otherwise 0
    LUA_RELEASE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('del', KEYS[1])
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - additional milliseconds
    # return 1 if the lock's time was extended, otherwise 0
    LUA_EXTEND_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        local expiration = redis.call('pttl', KEYS[1])
        if not expiration then
            expiration = 0
        end
        if expiration < 0 then
            return 0
        end
        redis.call('pexpire', KEYS[1], expiration + ARGV[2])
        return 1
    """

    def __init__(self, *args, **kwargs):
        """
        Create the lock exactly like ``Lock`` and make sure the Lua scripts
        have been registered.
        """
        super(LuaLock, self).__init__(*args, **kwargs)
        LuaLock.register_scripts(self.redis)

    @classmethod
    def register_scripts(cls, redis):
        """
        Register each Lua script with ``redis`` unless the class already
        holds a script object for it.
        """
        if cls.lua_acquire is None:
            cls.lua_acquire = redis.register_script(cls.LUA_ACQUIRE_SCRIPT)
        if cls.lua_release is None:
            cls.lua_release = redis.register_script(cls.LUA_RELEASE_SCRIPT)
        if cls.lua_extend is None:
            cls.lua_extend = redis.register_script(cls.LUA_EXTEND_SCRIPT)

    def do_acquire(self, token):
        """
        Try once to take the lock via the acquire script.

        Returns True if the lock was acquired, False otherwise.
        """
        if self.timeout:
            # Clamp to 1ms: a sub-millisecond timeout truncates to 0, and a
            # falsy value here used to be sent as '' ("no expiry"), which
            # silently produced a lock that never expires.
            timeout = max(1, int(self.timeout * 1000))
        else:
            # '' tells the script to skip the pexpire call
            timeout = ''
        return bool(self.lua_acquire(keys=[self.name],
                                     args=[token, timeout],
                                     client=self.redis))

    def do_release(self, expected_token):
        """
        Release the lock via the release script.

        Raises ``LockError`` if the script reports the lock was not
        released (key missing or holding a different token).
        """
        if not bool(self.lua_release(keys=[self.name],
                                     args=[expected_token],
                                     client=self.redis)):
            raise LockError("Cannot release a lock that's no longer owned")

    def do_extend(self, additional_time):
        """
        Add ``additional_time`` seconds to the lock via the extend script.

        Returns True on success; raises ``LockError`` if the script reports
        the lock was not extended.
        """
        additional_time = int(additional_time * 1000)
        if not bool(self.lua_extend(keys=[self.name],
                                    args=[self.local.token, additional_time],
                                    client=self.redis)):
            raise LockError("Cannot extend a lock that's no longer owned")
        return True
273