1import threading
2import time as mod_time
3import uuid
4from redis.exceptions import LockError, LockNotOwnedError
5from redis.utils import dummy
6
7
8class Lock(object):
9    """
10    A shared, distributed Lock. Using Redis for locking allows the Lock
11    to be shared across processes and/or machines.
12
13    It's left to the user to resolve deadlock issues and make sure
14    multiple clients play nicely together.
15    """
16
17    lua_release = None
18    lua_extend = None
19    lua_reacquire = None
20
21    # KEYS[1] - lock name
22    # ARGV[1] - token
23    # return 1 if the lock was released, otherwise 0
24    LUA_RELEASE_SCRIPT = """
25        local token = redis.call('get', KEYS[1])
26        if not token or token ~= ARGV[1] then
27            return 0
28        end
29        redis.call('del', KEYS[1])
30        return 1
31    """
32
33    # KEYS[1] - lock name
34    # ARGV[1] - token
35    # ARGV[2] - additional milliseconds
36    # ARGV[3] - "0" if the additional time should be added to the lock's
37    #           existing ttl or "1" if the existing ttl should be replaced
38    # return 1 if the locks time was extended, otherwise 0
39    LUA_EXTEND_SCRIPT = """
40        local token = redis.call('get', KEYS[1])
41        if not token or token ~= ARGV[1] then
42            return 0
43        end
44        local expiration = redis.call('pttl', KEYS[1])
45        if not expiration then
46            expiration = 0
47        end
48        if expiration < 0 then
49            return 0
50        end
51
52        local newttl = ARGV[2]
53        if ARGV[3] == "0" then
54            newttl = ARGV[2] + expiration
55        end
56        redis.call('pexpire', KEYS[1], newttl)
57        return 1
58    """
59
60    # KEYS[1] - lock name
61    # ARGV[1] - token
62    # ARGV[2] - milliseconds
63    # return 1 if the locks time was reacquired, otherwise 0
64    LUA_REACQUIRE_SCRIPT = """
65        local token = redis.call('get', KEYS[1])
66        if not token or token ~= ARGV[1] then
67            return 0
68        end
69        redis.call('pexpire', KEYS[1], ARGV[2])
70        return 1
71    """
72
73    def __init__(self, redis, name, timeout=None, sleep=0.1,
74                 blocking=True, blocking_timeout=None, thread_local=True):
75        """
76        Create a new Lock instance named ``name`` using the Redis client
77        supplied by ``redis``.
78
79        ``timeout`` indicates a maximum life for the lock.
80        By default, it will remain locked until release() is called.
81        ``timeout`` can be specified as a float or integer, both representing
82        the number of seconds to wait.
83
84        ``sleep`` indicates the amount of time to sleep per loop iteration
85        when the lock is in blocking mode and another client is currently
86        holding the lock.
87
88        ``blocking`` indicates whether calling ``acquire`` should block until
89        the lock has been acquired or to fail immediately, causing ``acquire``
90        to return False and the lock not being acquired. Defaults to True.
91        Note this value can be overridden by passing a ``blocking``
92        argument to ``acquire``.
93
94        ``blocking_timeout`` indicates the maximum amount of time in seconds to
95        spend trying to acquire the lock. A value of ``None`` indicates
96        continue trying forever. ``blocking_timeout`` can be specified as a
97        float or integer, both representing the number of seconds to wait.
98
99        ``thread_local`` indicates whether the lock token is placed in
100        thread-local storage. By default, the token is placed in thread local
101        storage so that a thread only sees its token, not a token set by
102        another thread. Consider the following timeline:
103
104            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
105                     thread-1 sets the token to "abc"
106            time: 1, thread-2 blocks trying to acquire `my-lock` using the
107                     Lock instance.
108            time: 5, thread-1 has not yet completed. redis expires the lock
109                     key.
110            time: 5, thread-2 acquired `my-lock` now that it's available.
111                     thread-2 sets the token to "xyz"
112            time: 6, thread-1 finishes its work and calls release(). if the
113                     token is *not* stored in thread local storage, then
114                     thread-1 would see the token value as "xyz" and would be
115                     able to successfully release the thread-2's lock.
116
117        In some use cases it's necessary to disable thread local storage. For
118        example, if you have code where one thread acquires a lock and passes
119        that lock instance to a worker thread to release later. If thread
120        local storage isn't disabled in this case, the worker thread won't see
121        the token set by the thread that acquired the lock. Our assumption
122        is that these cases aren't common and as such default to using
123        thread local storage.
124        """
125        self.redis = redis
126        self.name = name
127        self.timeout = timeout
128        self.sleep = sleep
129        self.blocking = blocking
130        self.blocking_timeout = blocking_timeout
131        self.thread_local = bool(thread_local)
132        self.local = threading.local() if self.thread_local else dummy()
133        self.local.token = None
134        self.register_scripts()
135
136    def register_scripts(self):
137        cls = self.__class__
138        client = self.redis
139        if cls.lua_release is None:
140            cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
141        if cls.lua_extend is None:
142            cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
143        if cls.lua_reacquire is None:
144            cls.lua_reacquire = \
145                client.register_script(cls.LUA_REACQUIRE_SCRIPT)
146
147    def __enter__(self):
148        # force blocking, as otherwise the user would have to check whether
149        # the lock was actually acquired or not.
150        if self.acquire(blocking=True):
151            return self
152        raise LockError("Unable to acquire lock within the time specified")
153
154    def __exit__(self, exc_type, exc_value, traceback):
155        self.release()
156
157    def acquire(self, blocking=None, blocking_timeout=None, token=None):
158        """
159        Use Redis to hold a shared, distributed lock named ``name``.
160        Returns True once the lock is acquired.
161
162        If ``blocking`` is False, always return immediately. If the lock
163        was acquired, return True, otherwise return False.
164
165        ``blocking_timeout`` specifies the maximum number of seconds to
166        wait trying to acquire the lock.
167
168        ``token`` specifies the token value to be used. If provided, token
169        must be a bytes object or a string that can be encoded to a bytes
170        object with the default encoding. If a token isn't specified, a UUID
171        will be generated.
172        """
173        sleep = self.sleep
174        if token is None:
175            token = uuid.uuid1().hex.encode()
176        else:
177            encoder = self.redis.connection_pool.get_encoder()
178            token = encoder.encode(token)
179        if blocking is None:
180            blocking = self.blocking
181        if blocking_timeout is None:
182            blocking_timeout = self.blocking_timeout
183        stop_trying_at = None
184        if blocking_timeout is not None:
185            stop_trying_at = mod_time.time() + blocking_timeout
186        while True:
187            if self.do_acquire(token):
188                self.local.token = token
189                return True
190            if not blocking:
191                return False
192            next_try_at = mod_time.time() + sleep
193            if stop_trying_at is not None and next_try_at > stop_trying_at:
194                return False
195            mod_time.sleep(sleep)
196
197    def do_acquire(self, token):
198        if self.timeout:
199            # convert to milliseconds
200            timeout = int(self.timeout * 1000)
201        else:
202            timeout = None
203        if self.redis.set(self.name, token, nx=True, px=timeout):
204            return True
205        return False
206
207    def locked(self):
208        """
209        Returns True if this key is locked by any process, otherwise False.
210        """
211        return self.redis.get(self.name) is not None
212
213    def owned(self):
214        """
215        Returns True if this key is locked by this lock, otherwise False.
216        """
217        stored_token = self.redis.get(self.name)
218        # need to always compare bytes to bytes
219        # TODO: this can be simplified when the context manager is finished
220        if stored_token and not isinstance(stored_token, bytes):
221            encoder = self.redis.connection_pool.get_encoder()
222            stored_token = encoder.encode(stored_token)
223        return self.local.token is not None and \
224            stored_token == self.local.token
225
226    def release(self):
227        "Releases the already acquired lock"
228        expected_token = self.local.token
229        if expected_token is None:
230            raise LockError("Cannot release an unlocked lock")
231        self.local.token = None
232        self.do_release(expected_token)
233
234    def do_release(self, expected_token):
235        if not bool(self.lua_release(keys=[self.name],
236                                     args=[expected_token],
237                                     client=self.redis)):
238            raise LockNotOwnedError("Cannot release a lock"
239                                    " that's no longer owned")
240
241    def extend(self, additional_time, replace_ttl=False):
242        """
243        Adds more time to an already acquired lock.
244
245        ``additional_time`` can be specified as an integer or a float, both
246        representing the number of seconds to add.
247
248        ``replace_ttl`` if False (the default), add `additional_time` to
249        the lock's existing ttl. If True, replace the lock's ttl with
250        `additional_time`.
251        """
252        if self.local.token is None:
253            raise LockError("Cannot extend an unlocked lock")
254        if self.timeout is None:
255            raise LockError("Cannot extend a lock with no timeout")
256        return self.do_extend(additional_time, replace_ttl)
257
258    def do_extend(self, additional_time, replace_ttl):
259        additional_time = int(additional_time * 1000)
260        if not bool(
261            self.lua_extend(
262                keys=[self.name],
263                args=[
264                    self.local.token,
265                    additional_time,
266                    replace_ttl and "1" or "0"
267                ],
268                client=self.redis,
269            )
270        ):
271            raise LockNotOwnedError(
272                "Cannot extend a lock that's" " no longer owned"
273            )
274        return True
275
276    def reacquire(self):
277        """
278        Resets a TTL of an already acquired lock back to a timeout value.
279        """
280        if self.local.token is None:
281            raise LockError("Cannot reacquire an unlocked lock")
282        if self.timeout is None:
283            raise LockError("Cannot reacquire a lock with no timeout")
284        return self.do_reacquire()
285
286    def do_reacquire(self):
287        timeout = int(self.timeout * 1000)
288        if not bool(self.lua_reacquire(keys=[self.name],
289                                       args=[self.local.token, timeout],
290                                       client=self.redis)):
291            raise LockNotOwnedError("Cannot reacquire a lock that's"
292                                    " no longer owned")
293        return True
294