1from django.db import transaction
2from django.test import TestCase, override_settings
3
4from cacheops import cached_as, no_invalidation, invalidate_obj, invalidate_model, invalidate_all
5from cacheops.conf import settings
6from cacheops.signals import cache_read, cache_invalidated
7
8from .utils import BaseTestCase, make_inc
9from .models import Post, Category, Local, DbAgnostic, DbBinded
10
11
class SettingsTests(TestCase):
    """Django settings overrides must be visible through ``cacheops.conf.settings``."""

    def test_context_manager(self):
        # Caching is on in the test configuration until overridden.
        self.assertTrue(settings.CACHEOPS_ENABLED)

        disabled = self.settings(CACHEOPS_ENABLED=False)
        with disabled:
            self.assertFalse(settings.CACHEOPS_ENABLED)

    @override_settings(CACHEOPS_ENABLED=False)
    def test_decorator(self):
        enabled = settings.CACHEOPS_ENABLED
        self.assertFalse(enabled)
22
23
@override_settings(CACHEOPS_ENABLED=False)
class ClassOverrideSettingsTests(TestCase):
    """``override_settings`` applied to a whole TestCase class reaches cacheops too."""

    def test_class(self):
        enabled = settings.CACHEOPS_ENABLED
        self.assertFalse(enabled)
28
29
class SignalsTests(BaseTestCase):
    """Check the keyword payloads sent with the ``cache_read`` and
    ``cache_invalidated`` signals."""

    def setUp(self):
        super(SignalsTests, self).setUp()

        def set_signal(signal=None, **kwargs):
            # Drop the signal object itself; keep the remaining payload
            # (sender/func/hit) for comparison in the tests.
            self.signal_calls.append(kwargs)

        self.signal_calls = []
        # weak=False: the receiver is a local closure that would otherwise be
        # garbage-collected as soon as setUp returns.
        cache_read.connect(set_signal, dispatch_uid=1, weak=False)

    def tearDown(self):
        # Disconnect in ``finally`` so a failing base tearDown cannot leave the
        # strongly-referenced receiver attached for later tests.
        try:
            super(SignalsTests, self).tearDown()
        finally:
            cache_read.disconnect(dispatch_uid=1)

    def test_queryset(self):
        # Miss
        test_model = Category.objects.create(title="foo")
        Category.objects.cache().get(id=test_model.id)
        self.assertEqual(self.signal_calls, [{'sender': Category, 'func': None, 'hit': False}])

        # Hit
        self.signal_calls = []
        Category.objects.cache().get(id=test_model.id)
        self.assertEqual(self.signal_calls, [{'sender': Category, 'func': None, 'hit': True}])

    def test_queryset_empty(self):
        # An empty ``pk__in`` filter still reports a (missed) cache read.
        list(Category.objects.cache().filter(pk__in=[]))
        self.assertEqual(self.signal_calls, [{'sender': Category, 'func': None, 'hit': False}])

    def test_cached_as(self):
        get_calls = make_inc(cached_as(Category.objects.filter(title='test')))
        # For ``cached_as`` the signal carries the undecorated function and no sender.
        func = get_calls.__wrapped__

        # Miss
        self.assertEqual(get_calls(), 1)
        self.assertEqual(self.signal_calls, [{'sender': None, 'func': func, 'hit': False}])

        # Hit
        self.signal_calls = []
        self.assertEqual(get_calls(), 1)
        self.assertEqual(self.signal_calls, [{'sender': None, 'func': func, 'hit': True}])

    def test_invalidation_signal(self):
        def set_signal(signal=None, **kwargs):
            signal_calls.append(kwargs)

        signal_calls = []
        cache_invalidated.connect(set_signal, dispatch_uid=1, weak=False)
        # Without this the strongly-referenced receiver stays connected for the
        # rest of the run. Django would then also ignore any later connect()
        # with the same dispatch_uid, so a re-run would see no payloads.
        self.addCleanup(cache_invalidated.disconnect, dispatch_uid=1)

        invalidate_all()
        invalidate_model(Post)
        c = Category.objects.create(title='Hey')
        self.assertEqual(signal_calls, [
            {'sender': None, 'obj_dict': None},
            {'sender': Post, 'obj_dict': None},
            {'sender': Category, 'obj_dict': {'id': c.pk, 'title': 'Hey'}},
        ])
87
88
class LockingTests(BaseTestCase):
    """``cached_as(..., lock=True)`` must make a concurrent miss wait for and
    reuse the value computed by the first caller."""

    def test_lock(self):
        import random
        import threading
        from .utils import ThreadWithReturnValue
        from before_after import before

        @cached_as(Post, lock=True, timeout=60)
        def func():
            return random.random()

        results = []
        locked = threading.Event()
        thread = [None]

        def second_thread():
            # Runs just before the first call computes its value: start a
            # competing call and wait until it reaches brpoplpush.
            def _target():
                try:
                    with before('redis.StrictRedis.brpoplpush', lambda *a, **kw: locked.set()):
                        results.append(func())
                except Exception:
                    # Release the waiting main thread so the error surfaces
                    # instead of a timeout.
                    locked.set()
                    raise

            thread[0] = ThreadWithReturnValue(target=_target)
            thread[0].start()
            assert locked.wait(1)  # Wait until right before the block

        with before('random.random', second_thread):
            results.append(func())

        # If func() never reached random.random (e.g. a stale cache hit), the
        # competing thread was never started; fail clearly instead of calling
        # None.join().
        self.assertIsNotNone(thread[0], 'second thread was never started')
        thread[0].join()

        # Both calls must have produced a result, and it must be the same value.
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0], results[1])
123
124
class NoInvalidationTests(BaseTestCase):
    """Invalidation requested under ``no_invalidation`` must leave the cache intact."""
    fixtures = ['basic']

    def _template(self, invalidate):
        """Cache post #1, pass it to ``invalidate``, then require a query-free read."""
        cached_post = Post.objects.cache().get(pk=1)
        invalidate(cached_post)

        # Zero queries means the cached entry survived the call above.
        with self.assertNumQueries(0):
            Post.objects.cache().get(pk=1)

    def test_context_manager(self):
        def _silent_invalidate(obj):
            with no_invalidation:
                invalidate_obj(obj)

        self._template(_silent_invalidate)

    def test_decorator(self):
        wrapped = no_invalidation(invalidate_obj)
        self._template(wrapped)

    def test_nested(self):
        # Leaving an inner ``no_invalidation`` must not re-enable invalidation
        # while the outer one is still active.
        def _silent_invalidate(obj):
            with no_invalidation:
                with no_invalidation:
                    pass
                invalidate_obj(obj)

        self._template(_silent_invalidate)

    def test_in_transaction(self):
        with transaction.atomic():
            cached_post = Post.objects.cache().get(pk=1)
            with no_invalidation:
                cached_post.save()

        # The save inside the transaction must not have dropped the cached read.
        with self.assertNumQueries(0):
            Post.objects.cache().get(pk=1)
161
162
class LocalGetTests(BaseTestCase):
    """Cached ``get()`` on the ``Local`` model."""

    def setUp(self):
        # The row is created before the base setUp runs; keep this order.
        # NOTE(review): confirm whether BaseTestCase.setUp relies on it.
        Local.objects.create(pk=1)
        super().setUp()

    def test_unhashable_args(self):
        # A list lookup value is unhashable; the test passes if no exception is raised.
        lookup = {'pk__in': [1, 2]}
        Local.objects.cache().get(**lookup)
170
171
class DbAgnosticTests(BaseTestCase):
    """Whether a cached queryset is reused across database aliases."""
    databases = ('default', 'slave')

    def test_db_agnostic_by_default(self):
        # Fill the cache via the default database...
        list(DbAgnostic.objects.cache())

        # ...so the same query routed to 'slave' is served without a query.
        slave_qs = DbAgnostic.objects.cache().using('slave')
        with self.assertNumQueries(0, using='slave'):
            list(slave_qs)

    def test_db_agnostic_disabled(self):
        list(DbBinded.objects.cache())

        # The entry cached above is not reused for 'slave': exactly one query hits it.
        slave_qs = DbBinded.objects.cache().using('slave')
        with self.assertNumQueries(1, using='slave'):
            list(slave_qs)
186