1from __future__ import with_statement
2import pytest
3import time
4
5from redis.exceptions import LockError, ResponseError
6from redis.lock import Lock, LuaLock
7
8
class TestLock(object):
    """Behavioral tests for locks created through ``redis.lock()``.

    ``lock_class`` selects the lock implementation under test; a subclass
    can override it to run every test here against another implementation.
    """
    lock_class = Lock

    def get_lock(self, redis, *args, **kwargs):
        "Create a lock via ``redis.lock`` using this class's ``lock_class``"
        kwargs['lock_class'] = self.lock_class
        return redis.lock(*args, **kwargs)

    def test_lock(self, sr):
        lock = self.get_lock(sr, 'foo')
        assert lock.acquire(blocking=False)
        # the key holds this lock's token; with no timeout it has no expiry
        assert sr.get('foo') == lock.local.token
        assert sr.ttl('foo') == -1
        lock.release()
        assert sr.get('foo') is None

    def test_competing_locks(self, sr):
        lock1 = self.get_lock(sr, 'foo')
        lock2 = self.get_lock(sr, 'foo')
        assert lock1.acquire(blocking=False)
        assert not lock2.acquire(blocking=False)
        lock1.release()
        assert lock2.acquire(blocking=False)
        assert not lock1.acquire(blocking=False)
        lock2.release()

    def test_timeout(self, sr):
        lock = self.get_lock(sr, 'foo', timeout=10)
        assert lock.acquire(blocking=False)
        # TTL is reported in whole seconds
        assert 8 < sr.ttl('foo') <= 10
        lock.release()

    def test_float_timeout(self, sr):
        lock = self.get_lock(sr, 'foo', timeout=9.5)
        assert lock.acquire(blocking=False)
        # PTTL is reported in milliseconds, so both bounds must be in ms;
        # a lower bound of 8 would accept almost any positive expiry
        assert 8000 < sr.pttl('foo') <= 9500
        lock.release()

    def test_blocking_timeout(self, sr):
        lock1 = self.get_lock(sr, 'foo')
        assert lock1.acquire(blocking=False)
        lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2)
        start = time.time()
        # lock2 must give up (return False) only after blocking_timeout
        assert not lock2.acquire()
        assert (time.time() - start) > 0.2
        lock1.release()

    def test_context_manager(self, sr):
        # blocking_timeout prevents a deadlock if the lock can't be acquired
        # for some reason
        with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock:
            assert sr.get('foo') == lock.local.token
        assert sr.get('foo') is None

    def test_high_sleep_raises_error(self, sr):
        "If sleep is higher than timeout, it should raise an error"
        with pytest.raises(LockError):
            self.get_lock(sr, 'foo', timeout=1, sleep=2)

    def test_releasing_unlocked_lock_raises_error(self, sr):
        lock = self.get_lock(sr, 'foo')
        with pytest.raises(LockError):
            lock.release()

    def test_releasing_lock_no_longer_owned_raises_error(self, sr):
        lock = self.get_lock(sr, 'foo')
        # make sure we really held the lock before stealing it
        assert lock.acquire(blocking=False)
        # manually change the token
        sr.set('foo', 'a')
        with pytest.raises(LockError):
            lock.release()
        # even though we errored, the token is still cleared
        assert lock.local.token is None

    def test_extend_lock(self, sr):
        lock = self.get_lock(sr, 'foo', timeout=10)
        assert lock.acquire(blocking=False)
        assert 8000 < sr.pttl('foo') <= 10000
        # extend() adds to the remaining time rather than resetting it
        assert lock.extend(10)
        assert 16000 < sr.pttl('foo') <= 20000
        lock.release()

    def test_extend_lock_float(self, sr):
        lock = self.get_lock(sr, 'foo', timeout=10.0)
        assert lock.acquire(blocking=False)
        assert 8000 < sr.pttl('foo') <= 10000
        assert lock.extend(10.0)
        assert 16000 < sr.pttl('foo') <= 20000
        lock.release()

    def test_extending_unlocked_lock_raises_error(self, sr):
        lock = self.get_lock(sr, 'foo', timeout=10)
        with pytest.raises(LockError):
            lock.extend(10)

    def test_extending_lock_with_no_timeout_raises_error(self, sr):
        lock = self.get_lock(sr, 'foo')
        assert lock.acquire(blocking=False)
        with pytest.raises(LockError):
            lock.extend(10)
        lock.release()

    def test_extending_lock_no_longer_owned_raises_error(self, sr):
        # The lock needs a timeout: without one, extend() raises LockError
        # for the missing timeout (as test_extending_lock_with_no_timeout_
        # raises_error shows) and the ownership check is never exercised.
        lock = self.get_lock(sr, 'foo', timeout=10)
        assert lock.acquire(blocking=False)
        # manually change the token so the lock is no longer ours
        sr.set('foo', 'a')
        with pytest.raises(LockError):
            lock.extend(10)
116
117
class TestLuaLock(TestLock):
    """Inherit every TestLock test, run against the LuaLock implementation.

    TestLock.get_lock builds locks from ``lock_class``, so overriding it
    here is all that is needed to switch implementations.
    """
    lock_class = LuaLock
120
121
class TestLockClassSelection(object):
    """Tests for how ``sr.lock()`` chooses between Lock and LuaLock."""

    def test_lock_class_argument(self, sr):
        # an explicitly passed lock_class determines the returned lock's type
        for wanted in (Lock, LuaLock):
            assert type(sr.lock('foo', lock_class=wanted)) is wanted

    def _check_cached_flag(self, sr, flag, wanted):
        "Seed sr._use_lua_lock with flag, check sr.lock() returns wanted"
        try:
            sr._use_lua_lock = flag
            assert type(sr.lock('foo')) is wanted
        finally:
            # always clear the cached choice so later tests start fresh
            sr._use_lua_lock = None

    def test_cached_lualock_flag(self, sr):
        self._check_cached_flag(sr, True, LuaLock)

    def test_cached_lock_flag(self, sr):
        self._check_cached_flag(sr, False, Lock)

    def test_lua_compatible_server(self, sr, monkeypatch):
        # pretend the server accepts the Lua script registration
        def register_ok(cls, redis):
            return None
        monkeypatch.setattr(LuaLock, 'register_scripts',
                            classmethod(register_ok))
        try:
            chosen = sr.lock('foo')
            assert type(chosen) is LuaLock
            # the outcome of the probe is remembered on the client
            assert sr._use_lua_lock is True
        finally:
            sr._use_lua_lock = None

    def test_lua_unavailable(self, sr, monkeypatch):
        # pretend the server rejects the Lua script registration
        def register_fails(cls, redis):
            raise ResponseError()
        monkeypatch.setattr(LuaLock, 'register_scripts',
                            classmethod(register_fails))
        try:
            chosen = sr.lock('foo')
            assert type(chosen) is Lock
            assert sr._use_lua_lock is False
        finally:
            sr._use_lua_lock = None
168