1import collections
2import itertools
3import random
4from threading import Lock
5from threading import Thread
6import time
7from unittest import TestCase
8
9import pytest
10
11from dogpile.cache import CacheRegion
12from dogpile.cache import register_backend
13from dogpile.cache.api import CacheBackend
14from dogpile.cache.api import NO_VALUE
15from dogpile.cache.region import _backend_loader
16from . import assert_raises_message
17from . import eq_
18
19
20class _GenericBackendFixture(object):
21    @classmethod
22    def setup_class(cls):
23        backend_cls = _backend_loader.load(cls.backend)
24        try:
25            arguments = cls.config_args.get("arguments", {})
26            backend = backend_cls(arguments)
27        except ImportError:
28            pytest.skip("Backend %s not installed" % cls.backend)
29        cls._check_backend_available(backend)
30
31    def tearDown(self):
32        if self._region_inst:
33            for key in self._keys:
34                self._region_inst.delete(key)
35            self._keys.clear()
36        elif self._backend_inst:
37            self._backend_inst.delete("some_key")
38
39    @classmethod
40    def _check_backend_available(cls, backend):
41        pass
42
43    region_args = {}
44    config_args = {}
45
46    _region_inst = None
47    _backend_inst = None
48
49    _keys = set()
50
51    def _region(self, backend=None, region_args={}, config_args={}):
52        _region_args = self.region_args.copy()
53        _region_args.update(**region_args)
54        _config_args = self.config_args.copy()
55        _config_args.update(config_args)
56
57        def _store_keys(key):
58            if existing_key_mangler:
59                key = existing_key_mangler(key)
60            self._keys.add(key)
61            return key
62
63        self._region_inst = reg = CacheRegion(**_region_args)
64
65        existing_key_mangler = self._region_inst.key_mangler
66        self._region_inst.key_mangler = _store_keys
67        self._region_inst._user_defined_key_mangler = _store_keys
68
69        reg.configure(backend or self.backend, **_config_args)
70        return reg
71
72    def _backend(self):
73        backend_cls = _backend_loader.load(self.backend)
74        _config_args = self.config_args.copy()
75        arguments = _config_args.get("arguments", {})
76        self._backend_inst = backend_cls(arguments)
77        return self._backend_inst
78
79
class _GenericBackendTest(_GenericBackendFixture, TestCase):
    """Backend-agnostic tests run against every concrete cache backend.

    Subclasses choose the backend through the fixture's ``backend`` /
    ``config_args`` attributes.  The ``test_backend_*`` methods use a bare
    backend from ``_backend()``; all other methods go through a
    ``CacheRegion`` from ``_region()``.
    """

    def test_backend_get_nothing(self):
        backend = self._backend()
        eq_(backend.get("some_key"), NO_VALUE)

    def test_backend_delete_nothing(self):
        # Deleting a missing key must not raise.
        backend = self._backend()
        backend.delete("some_key")

    def test_backend_set_get_value(self):
        backend = self._backend()
        backend.set("some_key", "some value")
        eq_(backend.get("some_key"), "some value")

    def test_backend_delete(self):
        backend = self._backend()
        backend.set("some_key", "some value")
        backend.delete("some_key")
        eq_(backend.get("some_key"), NO_VALUE)

    def test_region_set_get_value(self):
        reg = self._region()
        reg.set("some key", "some value")
        eq_(reg.get("some key"), "some value")

    def test_region_set_multiple_values(self):
        reg = self._region()
        values = {"key1": "value1", "key2": "value2", "key3": "value3"}
        reg.set_multi(values)
        eq_(values["key1"], reg.get("key1"))
        eq_(values["key2"], reg.get("key2"))
        eq_(values["key3"], reg.get("key3"))

    def test_region_get_zero_multiple_values(self):
        reg = self._region()
        eq_(reg.get_multi([]), [])

    def test_region_set_zero_multiple_values(self):
        reg = self._region()
        reg.set_multi({})

    def test_region_set_zero_multiple_values_w_decorator(self):
        reg = self._region()
        values = reg.get_or_create_multi([], lambda: 0)
        eq_(values, [])

    def test_region_get_or_create_multi_w_should_cache_none(self):
        reg = self._region()
        values = reg.get_or_create_multi(
            ["key1", "key2", "key3"],
            lambda *k: [None, None, None],
            should_cache_fn=lambda v: v is not None,
        )
        eq_(values, [None, None, None])

    def test_region_get_multiple_values(self):
        reg = self._region()
        key1 = "value1"
        key2 = "value2"
        key3 = "value3"
        reg.set("key1", key1)
        reg.set("key2", key2)
        reg.set("key3", key3)
        values = reg.get_multi(["key1", "key2", "key3"])
        eq_([key1, key2, key3], values)

    def test_region_get_nothing_multiple(self):
        reg = self._region()
        reg.delete_multi(["key1", "key2", "key3", "key4", "key5"])
        values = {"key1": "value1", "key3": "value3", "key5": "value5"}
        reg.set_multi(values)
        reg_values = reg.get_multi(
            ["key1", "key2", "key3", "key4", "key5", "key6"]
        )
        eq_(
            reg_values,
            ["value1", NO_VALUE, "value3", NO_VALUE, "value5", NO_VALUE],
        )

    def test_region_get_empty_multiple(self):
        reg = self._region()
        reg_values = reg.get_multi([])
        eq_(reg_values, [])

    def test_region_delete_multiple(self):
        reg = self._region()
        values = {"key1": "value1", "key2": "value2", "key3": "value3"}
        reg.set_multi(values)
        # "key10" was never set; deleting it alongside "key2" must be fine.
        reg.delete_multi(["key2", "key10"])
        eq_(values["key1"], reg.get("key1"))
        eq_(NO_VALUE, reg.get("key2"))
        eq_(values["key3"], reg.get("key3"))
        eq_(NO_VALUE, reg.get("key10"))

    def test_region_set_get_nothing(self):
        reg = self._region()
        reg.delete_multi(["some key"])
        eq_(reg.get("some key"), NO_VALUE)

    def test_region_creator(self):
        reg = self._region()

        def creator():
            return "some value"

        eq_(reg.get_or_create("some key", creator), "some value")

    def test_threaded_dogpile(self):
        # run a basic dogpile concurrency test.
        # note the concurrency of dogpile itself
        # is intensively tested as part of dogpile.
        reg = self._region(config_args={"expiration_time": 0.25})
        lock = Lock()
        canary = []

        def creator():
            # A non-blocking acquire that fails means another creator is
            # running at the same time.
            ack = lock.acquire(False)
            canary.append(ack)
            time.sleep(0.25)
            if ack:
                lock.release()
            return "some value"

        def f():
            for x in range(5):
                reg.get_or_create("some key", creator)
                time.sleep(0.5)

        threads = [Thread(target=f) for i in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        assert len(canary) > 2
        # With a lock timeout, concurrent creators are expected; without
        # one, creators must be strictly serialized.
        if not reg.backend.has_lock_timeout():
            assert False not in canary
        else:
            assert False in canary

    def test_threaded_get_multi(self):
        # Exercise get_or_create_multi() with random, overlapping key sets.
        # Each creator call does a non-blocking acquire on a per-key Lock
        # and records the result in ``canary``; a False entry would mean two
        # creators ran for the same key at the same time.
        #
        # NOTE(review): the concurrent run below is currently disabled and
        # only the single-threaded run executes (reason not recorded --
        # TODO confirm before enabling).  Set ``run_concurrently`` to True
        # to also run the threads and their assertions.
        run_concurrently = False

        reg = self._region(config_args={"expiration_time": 0.25})
        locks = {str(i): Lock() for i in range(11)}

        canary = collections.defaultdict(list)

        def creator(*keys):
            assert keys
            acks = [locks[key].acquire(False) for key in keys]

            for ack, key in zip(acks, keys):
                canary[key].append(ack)

            time.sleep(0.5)

            for ack, key in zip(acks, keys):
                if ack:
                    locks[key].release()
            return ["some value %s" % k for k in keys]

        def f():
            for x in range(5):
                reg.get_or_create_multi(
                    [
                        str(random.randint(1, 10))
                        for i in range(random.randint(1, 5))
                    ],
                    creator,
                )
                time.sleep(0.5)

        f()
        if not run_concurrently:
            return

        threads = [Thread(target=f) for i in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert sum(len(acks) for acks in canary.values()) > 10
        for acks in canary.values():
            assert False not in acks

    def test_region_delete(self):
        reg = self._region()
        reg.set("some key", "some value")
        # A second delete of the same key must be a no-op.
        reg.delete("some key")
        reg.delete("some key")
        eq_(reg.get("some key"), NO_VALUE)

    def test_region_expire(self):
        reg = self._region(config_args={"expiration_time": 0.25})
        counter = itertools.count(1)

        def creator():
            return "some value %d" % next(counter)

        eq_(reg.get_or_create("some key", creator), "some value 1")
        time.sleep(0.4)
        # The expired value is still readable with ignore_expiration...
        eq_(reg.get("some key", ignore_expiration=True), "some value 1")
        # ...but get_or_create regenerates it.
        eq_(reg.get_or_create("some key", creator), "some value 2")
        eq_(reg.get("some key"), "some value 2")

    def test_decorated_fn_functionality(self):
        # test for any quirks in the fn decoration that interact
        # with the backend.

        reg = self._region()

        counter = itertools.count(1)

        @reg.cache_on_arguments()
        def my_function(x, y):
            return next(counter) + x + y

        # Start with a clean slate
        my_function.invalidate(3, 4)
        my_function.invalidate(5, 6)
        my_function.invalidate(4, 3)

        eq_(my_function(3, 4), 8)
        eq_(my_function(5, 6), 13)
        eq_(my_function(3, 4), 8)
        eq_(my_function(4, 3), 10)

        my_function.invalidate(4, 3)
        eq_(my_function(4, 3), 11)

    def test_exploding_value_fn(self):
        reg = self._region()

        def boom():
            raise Exception("boom")

        assert_raises_message(
            Exception, "boom", reg.get_or_create, "some_key", boom
        )
322
323
class _GenericMutexTest(_GenericBackendFixture, TestCase):
    """Generic tests for a backend's ``get_mutex()`` implementation.

    Concrete subclasses supply the ``backend`` / ``config_args`` attributes
    consumed by the fixture.  Every successful acquire here is paired with
    a ``finally`` release, so a failing assertion does not leave a mutex
    held.  This matters for backends whose locks may outlive the backend
    object and would hang later acquires.
    """

    def test_mutex(self):
        backend = self._backend()
        mutex = backend.get_mutex("foo")

        ac = mutex.acquire()
        assert ac
        try:
            ac2 = mutex.acquire(False)
            if ac2:
                # Unexpected second hold; undo it so the failure below does
                # not leave the mutex held.
                mutex.release()
            assert not ac2
        finally:
            mutex.release()
        ac3 = mutex.acquire()
        assert ac3
        mutex.release()

    def test_mutex_threaded(self):
        backend = self._backend()
        # Result intentionally unused: the mutex is requested once up front
        # before the worker threads start.
        backend.get_mutex("foo")

        lock = Lock()
        canary = []
        num_threads = 5
        outer_iterations = 5
        inner_iterations = 5

        def f():
            for x in range(outer_iterations):
                mutex = backend.get_mutex("foo")
                mutex.acquire()
                try:
                    for y in range(inner_iterations):
                        # Fails only if another thread is inside the
                        # "foo" mutex at the same time.
                        ack = lock.acquire(False)
                        canary.append(ack)
                        time.sleep(0.002)
                        if ack:
                            lock.release()
                finally:
                    mutex.release()
                time.sleep(0.02)

        threads = [Thread(target=f) for i in range(num_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        # Exceptions raised inside a Thread are not propagated by join();
        # the exact count guards against a worker dying early and the check
        # below passing vacuously.
        eq_(len(canary), num_threads * outer_iterations * inner_iterations)
        assert False not in canary

    def test_mutex_reentrant_across_keys(self):
        # Holding the "foo" mutex must not interfere with acquiring "bar".
        backend = self._backend()
        for _ in range(3):
            m1 = backend.get_mutex("foo")
            m2 = backend.get_mutex("bar")
            # Acquire outside the try: if acquire() itself fails there is
            # nothing to release.
            m1.acquire()
            try:
                for _attempt in range(2):
                    assert m2.acquire(False)
                    try:
                        assert not m2.acquire(False)
                    finally:
                        m2.release()
            finally:
                m1.release()

    def test_reentrant_dogpile(self):
        # A creator for one key may itself call get_or_create for another.
        reg = self._region()

        def create_foo():
            return "foo" + reg.get_or_create("bar", create_bar)

        def create_bar():
            return "bar"

        eq_(reg.get_or_create("foo", create_foo), "foobar")
        eq_(reg.get_or_create("foo", create_foo), "foobar")
393
394
class MockMutex(object):
    """Stand-in mutex that never blocks: every acquire reports success."""

    def __init__(self, key):
        # The key this mutex was requested for; not used otherwise.
        self.key = key

    def acquire(self, blocking=True):
        """Pretend to take the lock; ``blocking`` is accepted and ignored."""
        return True

    def release(self):
        """Pretend to drop the lock (does nothing, returns None)."""
404
405
class MockBackend(CacheBackend):
    """In-process, dict-backed backend used as a test double."""

    def __init__(self, arguments):
        self._cache = {}
        self.arguments = arguments

    def get_mutex(self, key):
        """Return a non-blocking :class:`MockMutex` for ``key``."""
        return MockMutex(key)

    def get(self, key):
        """Return the stored value, or ``NO_VALUE`` when ``key`` is absent."""
        return self._cache.get(key, NO_VALUE)

    def get_multi(self, keys):
        """Look up each key in order, one ``get`` per key."""
        return list(map(self.get, keys))

    def set(self, key, value):
        self._cache[key] = value

    def set_multi(self, mapping):
        """Store every key/value pair from ``mapping`` via ``set``."""
        for pair in mapping.items():
            self.set(*pair)

    def delete(self, key):
        """Drop ``key``; deleting a missing key is silently ignored."""
        if key in self._cache:
            del self._cache[key]

    def delete_multi(self, keys):
        for doomed in keys:
            self.delete(doomed)
436
437
# Register MockBackend under the name "mock", presumably so fixtures can
# select it by name (e.g. ``backend = "mock"``) like the real backends.
register_backend(
    "mock",
    __name__,
    "MockBackend",
)
439