1from contextlib import contextmanager
2import re
3import platform
4import unittest
5import mock
6
7import django
8from django.db import connection
9from django.db import DEFAULT_DB_ALIAS
10from django.test import override_settings
11from django.test.client import RequestFactory
12from django.template import Context, Template
13from django.db.models import F, Count, OuterRef, Sum, Subquery, Exists
14from django.db.models.expressions import RawSQL
15
16from cacheops import invalidate_model, invalidate_obj, \
17                     cached, cached_view, cached_as, cached_view_as
18from cacheops import invalidate_fragment
19from cacheops.templatetags.cacheops import register
20
21decorator_tag = register.decorator_tag
22from .models import *  # noqa
23from .utils import BaseTestCase, make_inc
24
25
class BasicTests(BaseTestCase):
    """Core caching and invalidation checks for ``.cache()`` querysets.

    Most tests follow one pattern: warm the cache, optionally change data,
    then use ``assertNumQueries`` to assert whether the next evaluation hits
    the database (invalidated) or not (served from cache).
    """
    fixtures = ['basic']

    def test_it_works(self):
        # Two identical cached counts must cost a single query in total.
        with self.assertNumQueries(1):
            cnt1 = Category.objects.cache().count()
            cnt2 = Category.objects.cache().count()
            self.assertEqual(cnt1, cnt2)

    def test_empty(self):
        # Filtering by an empty ``id__in`` list must not touch the database.
        with self.assertNumQueries(0):
            list(Category.objects.cache().filter(id__in=[]))

    def test_exact(self):
        # ``pk=1`` and ``pk__exact=1`` are expected to share one cache entry:
        # after the first lookup the second must issue no queries.
        list(Category.objects.filter(pk=1).cache())
        with self.assertNumQueries(0):
            list(Category.objects.filter(pk__exact=1).cache())

    def test_exists(self):
        # With ops='exists' the second exists() call is answered from cache.
        with self.assertNumQueries(1):
            Category.objects.cache(ops='exists').exists()
            Category.objects.cache(ops='exists').exists()

    def test_some(self):
        # Ignoring the SOME condition led to a wrong DNF for this queryset,
        # which in turn led to no invalidation.
        list(Category.objects.exclude(pk__in=range(10), pk__isnull=False).cache())
        c = Category.objects.get(pk=1)
        c.save()
        # The save above must have invalidated the cached queryset.
        with self.assertNumQueries(1):
            list(Category.objects.exclude(pk__in=range(10), pk__isnull=False).cache())

    def test_invalidation(self):
        # Saving an object fetched through the cache must invalidate it:
        # the next cached get() costs one query and sees the new title.
        post = Post.objects.cache().get(pk=1)
        post.title += ' changed'
        post.save()

        with self.assertNumQueries(1):
            changed_post = Post.objects.cache().get(pk=1)
            self.assertEqual(post.title, changed_post.title)

    def test_granular(self):
        # Saving post pk=2 must leave the cached lookup of pk=1 intact.
        Post.objects.cache().get(pk=1)
        Post.objects.get(pk=2).save()

        with self.assertNumQueries(0):
            Post.objects.cache().get(pk=1)

    def test_invalidate_by_foreign_key(self):
        # Creating a post in category 1 must invalidate the cached
        # ``category=1`` queryset, which then contains one more row.
        posts = list(Post.objects.cache().filter(category=1))
        Post.objects.create(title='New Post', category_id=1)

        with self.assertNumQueries(1):
            changed_posts = list(Post.objects.cache().filter(category=1))
            self.assertEqual(len(changed_posts), len(posts) + 1)

    def test_invalidate_by_one_to_one(self):
        # Same as the foreign key case, but filtering Extra by its ``post``.
        extras = list(Extra.objects.cache().filter(post=3))
        Extra.objects.create(post_id=3, tag=0)

        with self.assertNumQueries(1):
            changed_extras = list(Extra.objects.cache().filter(post=3))
            self.assertEqual(len(changed_extras), len(extras) + 1)

    def test_invalidate_by_boolean(self):
        # Flipping ``visible`` to False must invalidate the cached count of
        # visible posts and reduce it by one.
        count = Post.objects.cache().filter(visible=True).count()

        post = Post.objects.get(pk=1, visible=True)
        post.visible = False
        post.save()

        with self.assertNumQueries(1):
            new_count = Post.objects.cache().filter(visible=True).count()
            self.assertEqual(new_count, count - 1)

    def test_bulk_create(self):
        # bulk_create() of two categories must invalidate the cached count.
        cnt = Category.objects.cache().count()
        Category.objects.bulk_create([Category(title='hi'), Category(title='there')])

        with self.assertNumQueries(1):
            cnt2 = Category.objects.cache().count()
            self.assertEqual(cnt2, cnt + 2)

    def test_db_column(self):
        # Smoke test: passes as long as the cached get() and save() do not raise.
        # NOTE(review): presumably Extra.tag uses a custom db_column - confirm
        # against the model definition.
        e = Extra.objects.cache().get(tag=5)
        e.save()

    def test_fk_to_db_column(self):
        # Lookup through ``to_tag__tag``, save, then expect the ``to_tag=5``
        # lookup to go to the database exactly once.
        e = Extra.objects.cache().get(to_tag__tag=5)
        e.save()

        with self.assertNumQueries(1):
            Extra.objects.cache().get(to_tag=5)

    def test_expressions(self):
        # A queryset filtered by F() expressions can still be cached:
        # the second count() issues no queries.
        qs = Extra.objects.cache().filter(tag=F('to_tag') + 1, to_tag=F('tag').bitor(5))
        qs.count()
        with self.assertNumQueries(0):
            qs.count()

    def test_expressions_save(self):
        # Smoke test: saving instances whose field holds an expression
        # must not raise.

        # Check saving a plain F() reference
        extra = Extra.objects.get(pk=1)
        extra.tag = F('tag')
        extra.save()

        # Check saving a combined expression node
        Extra.objects.create(post_id=3, tag=7)
        extra = Extra.objects.get(pk=3)
        extra.tag = F('tag') + 1
        extra.save()

    def test_combine(self):
        # Querysets combined with ``&`` and ``|`` must give the same results
        # cached and uncached.
        qs = Post.objects.filter(pk__in=[1, 2]) & Post.objects.all()
        self.assertEqual(list(qs.cache()), list(qs))

        qs = Post.objects.filter(pk__in=[1, 2]) | Post.objects.none()
        self.assertEqual(list(qs.cache()), list(qs))

    def test_first_and_last(self):
        # With ops='get', first() and last() are cached after the first call.
        qs = Category.objects.cache(ops='get')
        qs.first()
        qs.last()
        with self.assertNumQueries(0):
            qs.first()
            qs.last()

    def test_union(self):
        # A cached union() is invalidated only by rows matching one of its
        # branches (here: categories titled 'Perl').
        qs = Post.objects.filter(category=1).values('id', 'title').union(
                Category.objects.filter(title='Perl').values('id', 'title')).cache()
        list(qs.clone())
        # Invalidated
        Category.objects.create(title='Perl')
        with self.assertNumQueries(1):
            list(qs.clone())
        # Not invalidated
        Category.objects.create(title='Ruby')
        with self.assertNumQueries(0):
            list(qs.clone())

    def test_invalidated_update(self):
        # Warm caches for both the old and the new category value.
        list(Post.objects.filter(category=1).cache())
        list(Post.objects.filter(category=2).cache())

        # Should invalidate both queries
        Post.objects.filter(category=1).invalidated_update(category=2)

        with self.assertNumQueries(2):
            list(Post.objects.filter(category=1).cache())
            list(Post.objects.filter(category=2).cache())

    def test_subquery(self):
        # Smoke test: a cached queryset wrapped in Subquery() can be used
        # inside another cached queryset without raising.
        categories = Category.objects.cache().filter(title='Django').only('id')
        Post.objects.cache().filter(category__in=Subquery(categories)).count()

    def test_rawsql(self):
        # Smoke test: RawSQL inside a filter of a cached queryset must not raise.
        Post.objects.cache().filter(category__in=RawSQL("select 1", ())).count()
183
184
class ValuesTests(BaseTestCase):
    """Caching of ``values()`` / ``values_list()`` querysets."""
    fixtures = ['basic']

    def test_it_works(self):
        # The same values() queryset evaluated twice costs one query overall.
        with self.assertNumQueries(1):
            for _ in range(2):
                len(Category.objects.cache().values())

    def test_it_varies_on_class(self):
        # Model instances and values() rows are cached under different keys,
        # so each of the two evaluations goes to the database.
        with self.assertNumQueries(2):
            as_models = Category.objects.cache()
            len(as_models)
            as_rows = Category.objects.cache().values()
            len(as_rows)

    def test_it_varies_on_flat(self):
        # values_list() with and without flat=True must not share a cache entry.
        with self.assertNumQueries(2):
            as_tuples = Category.objects.cache().values_list()
            len(as_tuples)
            as_flat = Category.objects.cache().values_list(flat=True)
            len(as_flat)
202
203
class DecoratorTests(BaseTestCase):
    """Tests for the ``cached_as`` / ``cached_view_as`` function decorators.

    ``make_inc`` comes from ``.utils``. Judging by the assertions below it
    returns a callable, wrapped with the given decorator, whose result grows
    by one every time the underlying function really runs - so a repeated
    value means a cache hit and an incremented value means a miss.
    NOTE(review): confirm against ``utils.make_inc``.
    """

    def test_cached_as_model(self):
        # Cache tied to the whole Category model: any new row invalidates it.
        get_calls = make_inc(cached_as(Category))

        self.assertEqual(get_calls(), 1)      # miss
        self.assertEqual(get_calls(), 1)      # hit
        Category.objects.create(title='test') # invalidate
        self.assertEqual(get_calls(), 2)      # miss

    def test_cached_as_cond(self):
        # Cache tied to a filtered queryset: only matching rows invalidate it.
        get_calls = make_inc(cached_as(Category.objects.filter(title='test')))

        self.assertEqual(get_calls(), 1)      # cache
        Category.objects.create(title='miss') # don't invalidate
        self.assertEqual(get_calls(), 1)      # hit
        Category.objects.create(title='test') # invalidate
        self.assertEqual(get_calls(), 2)      # miss

    def test_cached_as_obj(self):
        # Cache tied to a single instance: only changes to it invalidate.
        c = Category.objects.create(title='test')
        get_calls = make_inc(cached_as(c))

        self.assertEqual(get_calls(), 1)      # cache
        Category.objects.create(title='miss') # don't invalidate
        self.assertEqual(get_calls(), 1)      # hit
        c.title = 'new'
        c.save()                              # invalidate
        self.assertEqual(get_calls(), 2)      # miss

    def test_cached_as_depends_on_args(self):
        # Different call arguments must produce different cache entries.
        get_calls = make_inc(cached_as(Category))

        self.assertEqual(get_calls(1), 1)      # cache
        self.assertEqual(get_calls(1), 1)      # hit
        self.assertEqual(get_calls(2), 2)      # miss

    def test_cached_as_depends_on_two_models(self):
        # A cache declared on two models is invalidated by a change to either.
        get_calls = make_inc(cached_as(Category, Post))
        c = Category.objects.create(title='miss')
        p = Post.objects.create(title='New Post', category=c)

        self.assertEqual(get_calls(1), 1)      # cache
        c.title = 'new title'
        c.save()                               # invalidate by Category
        self.assertEqual(get_calls(1), 2)      # miss and cache
        p.title = 'new title'
        p.save()                               # invalidate by Post
        self.assertEqual(get_calls(1), 3)      # miss and cache

    def test_cached_as_keep_fresh(self):
        # With keep_fresh=True a result computed while an invalidation
        # happened must not be stored.
        c = Category.objects.create(title='test')
        # One-element list used as a mutable counter shared with the closure.
        calls = [0]

        @cached_as(c, keep_fresh=True)
        def get_calls(_=None, **kw):
            # Invalidate during first run
            if calls[0] < 1:
                invalidate_obj(c)
            calls[0] += 1
            return calls[0]

        self.assertEqual(get_calls(), 1)      # miss, stale result not cached.
        self.assertEqual(get_calls(), 2)      # miss and cache
        self.assertEqual(get_calls(), 2)      # hit

    def test_cached_view_as(self):
        # View caching varies on the requested URL but not on other request
        # metadata such as REMOTE_ADDR.
        get_calls = make_inc(cached_view_as(Category))

        factory = RequestFactory()
        r1 = factory.get('/hi')
        r2 = factory.get('/hi')
        r2.META['REMOTE_ADDR'] = '10.10.10.10'
        r3 = factory.get('/bye')

        self.assertEqual(get_calls(r1), 1) # cache
        self.assertEqual(get_calls(r1), 1) # hit
        self.assertEqual(get_calls(r2), 1) # hit, since only url is considered
        self.assertEqual(get_calls(r3), 2) # miss

    def test_cached_view_on_template_response(self):
        # Smoke test: a cached view returning a TemplateResponse must not raise.
        from django.template.response import TemplateResponse
        from django.template import engines
        from_string = engines['django'].from_string

        @cached_view_as(Category)
        def view(request):
            return TemplateResponse(request, from_string('hi'))

        factory = RequestFactory()
        view(factory.get('/hi'))
294
295
296from datetime import date, datetime, time
297
class WeirdTests(BaseTestCase):
    """Caching and invalidation for less common field types of ``Weird``."""

    def _template(self, field, value):
        """Check that creating a row with ``field=value`` invalidates a cached
        queryset filtered by the same condition and that the value round-trips.
        """
        lookup = {field: value}
        cached_qs = Weird.objects.cache().filter(**lookup)
        before = cached_qs.count()

        created = Weird.objects.create(**lookup)

        # One query for the refreshed count, one for the get() below.
        with self.assertNumQueries(2):
            self.assertEqual(cached_qs.count(), before + 1)
            fetched = cached_qs.get(pk=created.pk)
            self.assertEqual(getattr(fetched, field), value)

    def test_date(self):
        today = date.today()
        self._template('date_field', today)

    def test_datetime(self):
        # NOTE: some databases (mysql) don't store microseconds
        now = datetime.now()
        self._template('datetime_field', now.replace(microsecond=0))

    def test_time(self):
        half_past_ten = time(hour=10, minute=30)
        self._template('time_field', half_past_ten)

    def test_list(self):
        self._template('list_field', [1, 2])

    def test_binary(self):
        # Smoke test: fetching a row with binary data through the cache twice
        # must not raise.
        obj = Weird.objects.create(binary_field=b'12345')
        for _ in range(2):
            Weird.objects.cache().get(pk=obj.pk)

    def test_custom(self):
        custom = CustomValue('some')
        self._template('custom_field', custom)

    def test_weird_custom(self):
        # A subclass whose str() differs from the wrapped value.
        class WeirdCustom(CustomValue):
            def __str__(self):
                return 'other'

        custom = WeirdCustom('some')
        self._template('custom_field', custom)

    def test_custom_query(self):
        # Smoke test for caching through the ``customs`` manager.
        cached_customs = Weird.customs.cache()
        list(cached_customs)
339
340
@unittest.skipUnless(connection.vendor == 'postgresql', "Only for PostgreSQL")
class PostgresTests(BaseTestCase):
    """Smoke tests for PostgreSQL-specific lookups; each passes if no error is raised."""

    def test_array_contains(self):
        containing = TaggedPost.objects.filter(tags__contains=[42])
        list(containing.cache())

    def test_array_len(self):
        by_length = TaggedPost.objects.filter(tags__len=42)
        list(by_length.cache())

    def test_json(self):
        by_author = TaggedPost.objects.filter(meta__author='Suor')
        list(by_author)
351
352
class TemplateTests(BaseTestCase):
    """Tests for the cacheops template tags (Django templates and Jinja2).

    The templates render values produced by ``make_inc()`` callables; a
    repeated number in the output means the fragment came from cache, a
    larger number means the fragment was rendered again.
    NOTE(review): relies on ``utils.make_inc`` returning an incrementing
    counter - confirm there.
    """

    def assertRendersTo(self, template, context, result):
        # Render, drop all whitespace, then compare with the expected string.
        s = template.render(Context(context))
        self.assertEqual(re.sub(r'\s+', '', s), result)

    def test_cached(self):
        inc_a = make_inc()
        inc_b = make_inc()
        t = Template("""
            {% load cacheops %}
            {% cached 60 'a' %}.a{{ a }}{% endcached %}
            {% cached 60 'a' %}.a{{ a }}{% endcached %}
            {% cached 60 'a' 'variant' %}.a{{ a }}{% endcached %}
            {% cached timeout=60 fragment_name='b' %}.b{{ b }}{% endcached %}
        """)

        # Second 'a' fragment is a hit; the extra 'variant' argument makes a
        # separate entry; 'b' is its own fragment.
        self.assertRendersTo(t, {'a': inc_a, 'b': inc_b}, '.a1.a1.a2.b1')

    def test_invalidate_fragment(self):
        inc = make_inc()
        t = Template("""
            {% load cacheops %}
            {% cached 60 'a' %}.{{ inc }}{% endcached %}
        """)

        self.assertRendersTo(t, {'inc': inc}, '.1')

        # After explicit invalidation the fragment is rendered anew.
        invalidate_fragment('a')
        self.assertRendersTo(t, {'inc': inc}, '.2')

    def test_cached_as(self):
        inc = make_inc()
        qs = Post.objects.all()
        t = Template("""
            {% load cacheops %}
            {% cached_as qs None 'a' %}.{{ inc }}{% endcached_as %}
            {% cached_as qs timeout=60 fragment_name='a' %}.{{ inc }}{% endcached_as %}
            {% cached_as qs fragment_name='a' timeout=60 %}.{{ inc }}{% endcached_as %}
        """)

        # All the forms are equivalent
        self.assertRendersTo(t, {'inc': inc, 'qs': qs}, '.1.1.1')

        # Cache works across calls
        self.assertRendersTo(t, {'inc': inc, 'qs': qs}, '.1.1.1')

        # Post invalidation clears cache
        invalidate_model(Post)
        self.assertRendersTo(t, {'inc': inc, 'qs': qs}, '.2.2.2')

    def test_decorator_tag(self):
        # Custom block tag built from a decorator factory: caches when the
        # flag is truthy, otherwise applies an identity decorator.
        @decorator_tag
        def my_cached(flag):
            return cached(timeout=60) if flag else lambda x: x

        inc = make_inc()
        t = Template("""
            {% load cacheops %}
            {% my_cached 1 %}.{{ inc }}{% endmy_cached %}
            {% my_cached 0 %}.{{ inc }}{% endmy_cached %}
            {% my_cached 0 %}.{{ inc }}{% endmy_cached %}
            {% my_cached 1 %}.{{ inc }}{% endmy_cached %}
        """)

        # Uncached blocks keep incrementing; the last cached block repeats
        # the value stored by the first one.
        self.assertRendersTo(t, {'inc': inc}, '.1.2.3.1')

    def test_decorator_tag_context(self):
        # Same as above, but the decision is read from the template context.
        @decorator_tag(takes_context=True)
        def my_cached(context):
            return cached(timeout=60) if context['flag'] else lambda x: x

        inc = make_inc()
        t = Template("""
            {% load cacheops %}
            {% my_cached %}.{{ inc }}{% endmy_cached %}
            {% my_cached %}.{{ inc }}{% endmy_cached %}
        """)

        self.assertRendersTo(t, {'inc': inc, 'flag': True}, '.1.1')
        self.assertRendersTo(t, {'inc': inc, 'flag': False}, '.2.3')

    def test_jinja2(self):
        # Smoke test: the Jinja2 extension loads and a cached block renders.
        from jinja2 import Environment
        env = Environment(extensions=['cacheops.jinja2.cache'])
        t = env.from_string('Hello, {% cached %}{{ name }}{% endcached %}')
        t.render(name='Alex')
439
440
class IssueTests(BaseTestCase):
    """Regression tests, one per reported issue.

    The numeric suffix of each test name is presumably the number of the
    corresponding issue in the project's tracker. Tests without assertions
    are smoke tests that pass as long as no exception is raised.
    """
    databases = ('default', 'slave')
    fixtures = ['basic']

    def setUp(self):
        # Objects are created before the base setUp() runs.
        self.user = User.objects.create(pk=1, username='Suor')
        Profile.objects.create(pk=2, user=self.user, tag=10)
        super(IssueTests, self).setUp()

    def test_16(self):
        # Saving a profile found via ``user__id__exact`` must invalidate
        # the ``user=1`` lookup.
        p = Profile.objects.cache().get(user__id__exact=1)
        p.save()

        with self.assertNumQueries(1):
            Profile.objects.cache().get(user=1)

    def test_29(self):
        Brand.objects.exclude(labels__in=[1, 2, 3]).cache().count()

    def test_45(self):
        # The instance is expected to be in cache right after save(),
        # so the get() below must not query.
        m = CacheOnSaveModel(title="test")
        m.save()

        with self.assertNumQueries(0):
            CacheOnSaveModel.objects.cache().get(pk=m.pk)

    def test_57(self):
        list(Post.objects.filter(category__in=Category.objects.nocache()).cache())

    def test_114(self):
        list(Category.objects.cache().filter(title=u'ó'))

    def test_145(self):
        # Create One with boolean False
        one = One.objects.create(boolean=False)

        # Update boolean to True
        one = One.objects.cache().get(id=one.id)
        one.boolean = True
        one.save()  # An error was in post_save signal handler

    def test_159(self):
        brand = Brand.objects.create(pk=1)
        label = Label.objects.create(pk=2)
        brand.labels.add(label)

        # Create another brand with the same pk as label.
        # This will trigger a bug invalidating brands querying them by label id.
        another_brand = Brand.objects.create(pk=2)

        list(brand.labels.cache())
        list(another_brand.labels.cache())

        # Clear brands for label linked to brand, but not another_brand.
        label.brands.clear()

        # Cache must stay for another_brand
        with self.assertNumQueries(0):
            list(another_brand.labels.cache())

    def test_161(self):
        categories = Category.objects.using('slave').filter(title='Python')
        list(Post.objects.using('slave').filter(category__in=categories).cache())

    @unittest.skipIf(connection.vendor == 'mysql', 'MySQL fails with encodings')
    def test_161_non_ascii(self):
        # Non ascii text in non-unicode str literal
        list(Category.objects.filter(title='фыва').cache())
        list(Category.objects.filter(title='фыва', title__startswith='фыва').cache())

    def test_169(self):
        c = Category.objects.prefetch_related('posts').get(pk=3)
        c.posts.get(visible=1)  # this used to fail

    def test_173(self):
        extra = Extra.objects.get(pk=1)
        title = extra.post.category.title

        # Cache
        list(Extra.objects.filter(post__category__title=title).cache())

        # Break the link
        extra.post.category_id = 2
        extra.post.save()

        # Used to fail because neither Extra nor Category changed,
        # but something in between did
        self.assertEqual([], list(Extra.objects.filter(post__category__title=title).cache()))

    def test_177(self):
        c = Category.objects.get(pk=1)
        c.posts_copy = c.posts.cache()
        bool(c.posts_copy)

    def test_217(self):
        # Destroy and recreate model manager
        Post.objects.__class__().contribute_to_class(Post, 'objects')

        # Test invalidation
        post = Post.objects.cache().get(pk=1)
        post.title += ' changed'
        post.save()

        with self.assertNumQueries(1):
            changed_post = Post.objects.cache().get(pk=1)
            self.assertEqual(post.title, changed_post.title)

    def test_232(self):
        list(Post.objects.cache().filter(category__in=[None, 1]).filter(category=1))

    @unittest.skipIf(connection.vendor == 'mysql', 'In MySQL DDL is not transaction safe')
    def test_265(self):
        # Databases must have different structure,
        # so an exception other than DoesNotExist would be raised.
        # Let's delete tests_video from default database
        # and try working with it in slave database with using.
        # Table is not restored automatically in MySQL, so I disabled this test in MySQL.
        # The cursor is used as a context manager so it is closed right after
        # the statement instead of being left for garbage collection.
        with connection.cursor() as cursor:
            cursor.execute("DROP TABLE tests_video;")

        # Works fine
        c = Video.objects.db_manager('slave').create(title='test_265')
        self.assertTrue(Video.objects.using('slave').filter(title='test_265').exists())

        # Fails with "no such table: tests_video"
        # Fixed by adding .using(instance._state.db) in query.ManagerMixin._pre_save() method
        c.title = 'test_265_1'
        c.save()
        self.assertTrue(Video.objects.using('slave').filter(title='test_265_1').exists())

        # This also didn't work before fix above. Test that it works.
        c.title = 'test_265_2'
        c.save(using='slave')
        self.assertTrue(Video.objects.using('slave').filter(title='test_265_2').exists())

        # Same bug in other method
        # Fixed by adding .using(self._db) in query.QuerySetMixin.invalidated_update() method
        Video.objects.using('slave').invalidated_update(title='test_265_3')
        self.assertTrue(Video.objects.using('slave').filter(title='test_265_3').exists())

    @unittest.skipIf(django.VERSION < (3, 0), "Fixed in Django 3.0")
    def test_312(self):
        device = Device.objects.create()

        # query by 32bytes uuid
        d = Device.objects.cache().get(uid=device.uid.hex)

        # test invalidation
        d.model = 'new model'
        d.save()

        with self.assertNumQueries(1):
            changed_device = Device.objects.cache().get(uid=device.uid.hex)
            self.assertEqual(d.model, changed_device.model)

    def test_316(self):
        Category.objects.cache().annotate(num=Count('posts')).aggregate(total=Sum('num'))

    @unittest.expectedFailure
    def test_348(self):
        # Known open problem: kept as an expected failure.
        foo = Foo.objects.create()
        bar = Bar.objects.create(foo=foo)

        bar = Bar.objects.cache().get(pk=bar.pk)
        bar.foo.delete()

        bar = Bar.objects.cache().get(pk=bar.pk)
        bar.foo  # fails here since we try to fetch Foo instance by cached id

    def test_352(self):
        CombinedFieldModel.objects.create()
        list(CombinedFieldModel.objects.cache().all())

    def test_353(self):
        # Linking bar to foo must invalidate the cached ``bar__isnull=True``
        # count on Foo.
        foo = Foo.objects.create()
        bar = Bar.objects.create()

        self.assertEqual(Foo.objects.cache().filter(bar__isnull=True).count(), 1)
        bar.foo = foo
        bar.save()
        self.assertEqual(Foo.objects.cache().filter(bar__isnull=True).count(), 0)

    @unittest.skipIf(django.VERSION < (3, 0), "Supported from Django 3.0")
    def test_359(self):
        post_filter = Exists(Post.objects.all())
        len(Category.objects.filter(post_filter).cache())

    def test_365(self):
        """
        Check that an annotated Subquery is automatically invalidated.
        """
        # Retrieve all Categories and annotate the ID of the most recent Post for each
        newest_post = Post.objects.filter(category=OuterRef('pk')).order_by('-pk').values('pk')
        categories = Category.objects.cache().annotate(newest_post=Subquery(newest_post[:1]))

        # Create a new Post in the first Category
        post = Post(category=categories[0], title='Foo')
        post.save()

        # Retrieve Categories again, and check that the newest post ID is correct
        categories = Category.objects.cache().annotate(newest_post=Subquery(newest_post[:1]))
        self.assertEqual(categories[0].newest_post, post.pk)

    @unittest.skipIf(platform.python_implementation() == "PyPy", "dill doesn't do that in PyPy")
    def test_385(self):
        Client.objects.create(name='Client Name')

        # With the default serializer, caching a Client is expected to fail
        # with this exact pickling error.
        with self.assertRaises(AttributeError) as e:
            Client.objects.filter(name='Client Name').cache().first()
        self.assertEqual(
            str(e.exception),
            "Can't pickle local object 'Client.__init__.<locals>.curry.<locals>._curried'")

        invalidate_model(Client)

        # With the dill serializer the same queryset is cached:
        # two calls, one query.
        with override_settings(CACHEOPS_SERIALIZER='dill'):
            with self.assertNumQueries(1):
                Client.objects.filter(name='Client Name').cache().first()
                Client.objects.filter(name='Client Name').cache().first()

    def test_387(self):
        # Deleting an instance loaded with a deferred field must not raise.
        post = Post.objects.defer("visible").last()
        post.delete()
662
663
class RelatedTests(BaseTestCase):
    """Invalidation of cached querysets that filter across relations."""
    fixtures = ['basic']

    def _template(self, qs, change, should_invalidate=True):
        """Warm the cache for ``qs``, run ``change()``, then assert that
        re-evaluating ``qs`` costs one query if invalidation was expected
        and none otherwise.
        """
        expected_queries = 1 if should_invalidate else 0

        list(qs._clone().cache())
        change()
        with self.assertNumQueries(expected_queries):
            list(qs.cache())

    def test_related_invalidation(self):
        def resave_category():
            Category.objects.get(title='Django').save()

        posts_in_django = Post.objects.filter(category__title='Django')
        self._template(posts_in_django, resave_category)

    def test_reverse_fk(self):
        def resave_post():
            Post.objects.get(title='Cacheops').save()

        categories = Category.objects.filter(posts__title='Cacheops')
        self._template(categories, resave_post)

    def test_reverse_fk_same(self):
        title = "Implicit variable as pronoun"

        def resave_visible_post():
            Post.objects.get(title=title, visible=True).save()

        # Both conditions in a single filter() call.
        self._template(
            Category.objects.filter(posts__title=title, posts__visible=True),
            resave_visible_post,
        )
        self._template(
            Category.objects.filter(posts__title=title, posts__visible=False),
            resave_visible_post,
            should_invalidate=False,
        )

    def test_reverse_fk_separate(self):
        title = "Implicit variable as pronoun"

        def resave_visible_post():
            Post.objects.get(title=title, visible=True).save()

        # Conditions split over chained filter() calls: invalidation is
        # expected in both cases.
        self._template(
            Category.objects.filter(posts__title=title).filter(posts__visible=True),
            resave_visible_post,
        )
        self._template(
            Category.objects.filter(posts__title=title).filter(posts__visible=False),
            resave_visible_post,
        )
707
708
class AggregationTests(BaseTestCase):
    """Caching and invalidation of ``annotate()`` / ``aggregate()`` results."""
    fixtures = ['basic']

    def test_annotate(self):
        annotated = Category.objects.annotate(posts_count=Count('posts')).cache()
        list(annotated._clone())

        # A new post changes the per-category counts, so the cache must go.
        first_category = Category.objects.all()[0]
        Post.objects.create(title='New One', category=first_category)
        with self.assertNumQueries(1):
            list(annotated._clone())

    def test_aggregate(self):
        cached_categories = Category.objects.cache()

        def count_posts():
            return cached_categories.aggregate(posts_count=Count('posts'))

        count_posts()
        # Test caching
        with self.assertNumQueries(0):
            count_posts()
        # Test invalidation
        first_category = Category.objects.all()[0]
        Post.objects.create(title='New One', category=first_category)
        with self.assertNumQueries(1):
            count_posts()
729
730
class M2MTests(BaseTestCase):
    """Invalidation triggered by many-to-many changes (clear/add/remove).

    ``brand_cls`` and ``label_cls`` are class attributes so that subclasses
    can run the same tests against other model pairs.
    """
    brand_cls = Brand
    label_cls = Label

    def setUp(self):
        # Two brands and three labels, created in the same order as listed.
        self.bf, self.bs = (self.brand_cls.objects.create() for _ in range(2))
        self.fast, self.slow, self.furious = (
            self.label_cls.objects.create(text=text)
            for text in ('fast', 'slow', 'furios')
        )

        self.setup_m2m()
        super(M2MTests, self).setUp()

    def setup_m2m(self):
        """Link brands to labels; overridable by subclasses."""
        self.bf.labels.add(self.fast, self.furious)
        self.bs.labels.add(self.slow, self.furious)

    def _template(self, qs_or_action, change, should_invalidate=True):
        """Run the action, apply ``change()``, then run the action again and
        assert it costs one query if invalidation was expected, else none.

        ``qs_or_action`` is either something with an ``all()`` method (then
        the action is evaluating its cached ``all()``) or a plain callable.
        """
        if hasattr(qs_or_action, 'all'):
            def action():
                return list(qs_or_action.all().cache())
        else:
            action = qs_or_action

        expected_queries = 1 if should_invalidate else 0

        action()
        change()
        with self.assertNumQueries(expected_queries):
            action()

    def test_target_invalidates_on_clear(self):
        def clear_bf_labels():
            return self.bf.labels.clear()

        self._template(self.bf.labels, clear_bf_labels)

    def test_base_invalidates_on_clear(self):
        def clear_bf_labels():
            return self.bf.labels.clear()

        self._template(self.furious.brands, clear_bf_labels)

    def test_granular_through_on_clear(self):
        through_manager = self.brand_cls.labels.through.objects
        through_qs = through_manager.cache().filter(brand=self.bs, label=self.slow)

        def fetch_through_row():
            return through_qs.get()

        def clear_bf_labels():
            return self.bf.labels.clear()

        # Clearing bf's labels must not touch the cached bs/slow link.
        self._template(fetch_through_row, clear_bf_labels, should_invalidate=False)

    def test_granular_target_on_clear(self):
        def fetch_slow_label():
            return self.label_cls.objects.cache().get(pk=self.slow.pk)

        def clear_bf_labels():
            return self.bf.labels.clear()

        # The 'slow' label is not linked to bf, so its cache entry must survive.
        self._template(fetch_slow_label, clear_bf_labels, should_invalidate=False)

    def test_target_invalidates_on_add(self):
        def add_slow_to_bf():
            return self.bf.labels.add(self.slow)

        self._template(self.bf.labels, add_slow_to_bf)

    def test_base_invalidates_on_add(self):
        def add_slow_to_bf():
            return self.bf.labels.add(self.slow)

        self._template(self.slow.brands, add_slow_to_bf)

    def test_target_invalidates_on_remove(self):
        def remove_furious_from_bf():
            return self.bf.labels.remove(self.furious)

        self._template(self.bf.labels, remove_furious_from_bf)

    def test_base_invalidates_on_remove(self):
        def remove_furious_from_bf():
            return self.bf.labels.remove(self.furious)

        self._template(self.furious.brands, remove_furious_from_bf)
812
813
class MultiTableInheritanceWithM2MTest(M2MTests):
    # Re-runs every inherited M2MTests test with PremiumBrand as the brand model.
    # NOTE(review): the class name suggests PremiumBrand inherits from Brand
    # via multi-table inheritance - confirm in models.
    brand_cls = PremiumBrand
816
817
class M2MThroughTests(M2MTests):
    """M2M invalidation tests for models linked via the ``Labeling`` model."""
    brand_cls = BrandT
    label_cls = LabelT

    def setup_m2m(self):
        # Links are created through Labeling rows, in this exact order.
        links = (
            (self.bf, self.fast, 10),
            (self.bf, self.furious, 11),
            (self.bs, self.slow, 20),
            (self.bs, self.furious, 21),
        )
        for brand, label, tag in links:
            Labeling.objects.create(brand=brand, label=label, tag=tag)

    # No add and remove methods for explicit through models
    test_target_invalidates_on_add = None
    test_base_invalidates_on_add = None
    test_target_invalidates_on_remove = None
    test_base_invalidates_on_remove = None

    def test_target_invalidates_on_create(self):
        def link_slow_to_bf():
            return Labeling.objects.create(brand=self.bf, label=self.slow, tag=1)

        self._template(self.bf.labels, link_slow_to_bf)

    def test_base_invalidates_on_create(self):
        def link_slow_to_bf():
            return Labeling.objects.create(brand=self.bf, label=self.slow, tag=1)

        self._template(self.slow.brands, link_slow_to_bf)

    def test_target_invalidates_on_delete(self):
        def unlink_furious_from_bf():
            return Labeling.objects.get(brand=self.bf, label=self.furious).delete()

        self._template(self.bf.labels, unlink_furious_from_bf)

    def test_base_invalidates_on_delete(self):
        def unlink_furious_from_bf():
            # lambda: Labeling.objects.filter(brand=self.bf, label=self.furious).delete()
            return Labeling.objects.get(brand=self.bf, label=self.furious).delete()

        self._template(self.furious.brands, unlink_furious_from_bf)
858
859
class ProxyTests(BaseTestCase):
    """Caching and invalidation checks involving models and their proxies."""

    def test_30(self):
        def load_proxies():
            # Build a fresh queryset on every call
            return list(VideoProxy.objects.cache())

        load_proxies()
        Video.objects.create(title='Pulp Fiction')

        # After a create through Video the proxy list must cost one query
        with self.assertNumQueries(1):
            load_proxies()

    def test_30_reversed(self):
        def load_videos():
            # Build a fresh queryset on every call
            return list(Video.objects.cache())

        load_videos()
        VideoProxy.objects.create(title='Pulp Fiction')

        # After a create through the proxy the Video list must cost one query
        with self.assertNumQueries(1):
            load_videos()

    @unittest.expectedFailure
    def test_interchange(self):
        # Warm the cache through Video, then require the proxy query to
        # cost no queries; currently marked as an expected failure.
        list(Video.objects.cache())

        with self.assertNumQueries(0):
            list(VideoProxy.objects.cache())

    def test_148_invalidate_from_non_cached_proxy(self):
        video = Video.objects.create(title='Pulp Fiction')

        def cached_lookup():
            return Video.objects.cache().get(title=video.title)

        cached_lookup()
        NonCachedVideoProxy.objects.get(id=video.id).delete()

        # The cached lookup must no longer find the deleted row
        with self.assertRaises(Video.DoesNotExist):
            cached_lookup()

    def test_148_reverse(self):
        media = NonCachedMedia.objects.create(title='Pulp Fiction')

        def cached_lookup():
            return MediaProxy.objects.cache().get(title=media.title)

        cached_lookup()
        NonCachedMedia.objects.get(id=media.id).delete()

        # The cached proxy lookup must no longer find the deleted row
        with self.assertRaises(NonCachedMedia.DoesNotExist):
            cached_lookup()

    def test_proxy_caching(self):
        video = Video.objects.create(title='Pulp Fiction')
        # Each manager must hand back instances of its own class,
        # concrete model first, proxy second.
        for model in (Video, VideoProxy):
            fetched = model.objects.cache().get(pk=video.pk)
            self.assertEqual(type(fetched), model)

    def test_proxy_caching_reversed(self):
        video = Video.objects.create(title='Pulp Fiction')
        # Same check with the proxy queried first, concrete model second.
        for model in (VideoProxy, Video):
            fetched = model.objects.cache().get(pk=video.pk)
            self.assertEqual(type(fetched), model)
911
912
class MultitableInheritanceTests(BaseTestCase):
    """Invalidation between Media and Movie; both tests are expected failures."""

    @unittest.expectedFailure
    def test_sub_added(self):
        count_before = Media.objects.cache().count()
        Movie.objects.create(name="Matrix", year=1999)
        expected = count_before + 1

        # The cached Media count should be refreshed with a single query
        with self.assertNumQueries(1):
            count_after = Media.objects.cache().count()
            self.assertEqual(count_after, expected)

    @unittest.expectedFailure
    def test_base_changed(self):
        def load_movies():
            # Build a fresh queryset on every call
            return list(Movie.objects.cache())

        movie = Movie.objects.create(name="Matrix", year=1999)
        load_movies()

        # Rename the row with the same pk through the Media model
        base = Media.objects.get(pk=movie.pk)
        base.name = "Matrix (original)"
        base.save()

        # The cached Movie list should be refreshed with a single query
        with self.assertNumQueries(1):
            load_movies()
933
934
class SimpleCacheTests(BaseTestCase):
    """Checks for the `cached` and `cached_view` decorators."""

    def test_cached(self):
        counter = make_inc(cached(timeout=100))

        # (argument, expected result), in call order
        for arg, expected in ((1, 1), (1, 1), (2, 2)):
            self.assertEqual(counter(arg), expected)

        # Invalidating by arguments yields a new result
        counter.invalidate(2)
        self.assertEqual(counter(2), 3)

        # Deleting through the key object yields a new result as well
        counter.key(2).delete()
        self.assertEqual(counter(2), 4)

        # A value set through the key object is what the next call returns
        counter.key(2).set(42)
        self.assertEqual(counter(2), 42)

    def test_cached_view(self):
        counter = make_inc(cached_view(timeout=100))

        requests = RequestFactory()
        hi = requests.get('/hi')
        hi_other_addr = requests.get('/hi')
        hi_other_addr.META['REMOTE_ADDR'] = '10.10.10.10'
        bye = requests.get('/bye')

        expectations = (
            (hi, 1),             # cache
            (hi, 1),             # hit
            (hi_other_addr, 1),  # hit, since only url is considered
            (bye, 2),            # miss
        )
        for request, expected in expectations:
            self.assertEqual(counter(request), expected)

        counter.invalidate(hi)
        self.assertEqual(counter(hi), 3)  # miss

        # Can pass uri to invalidate
        hi_uri = hi.build_absolute_uri()
        counter.invalidate(hi_uri)
        self.assertEqual(counter(hi), 4)  # miss
971
972
@unittest.skipIf(
    connection.settings_dict['ENGINE'] != 'django.contrib.gis.db.backends.postgis',
    "Only for PostGIS",
)
class GISTests(BaseTestCase):
    """Invalidation check that is skipped unless the db engine is PostGIS."""

    def test_invalidate_model_with_geometry(self):
        shape = Geometry()
        shape.save()
        # A ValueError is raised here if this doesn't work
        invalidate_obj(shape)
981
982
# NOTE: the cache prefix is overridden (with `q.db`) to separate
#       invalidation sets by db.
@override_settings(CACHEOPS_PREFIX=lambda q: q.db)
class MultiDBInvalidationTests(BaseTestCase):
    """Invalidation checks run against the 'default' and 'slave' databases."""
    databases = ('default', 'slave')
    fixtures = ['basic']

    @contextmanager
    def _control_counts(self):
        # Warm the cached Category count on both databases first
        Category.objects.cache().count()
        Category.objects.using('slave').cache().count()

        yield

        # After the wrapped block: the default count costs no queries,
        # while the slave count costs exactly one.
        default_untouched = self.assertNumQueries(0)
        with default_untouched:
            Category.objects.cache().count()

        slave_refetched = self.assertNumQueries(1, using='slave')
        with slave_refetched:
            Category.objects.cache().using('slave').count()

    def test_save(self):
        # NOTE: not testing when old db != new db,
        #       how cacheops works in that situation is undefined at the moment
        with self._control_counts():
            Category().save(using='slave')

    def test_delete(self):
        doomed = Category.objects.using('slave').create()
        with self._control_counts():
            doomed.delete(using='slave')

    def test_bulk_create(self):
        with self._control_counts():
            slave_categories = Category.objects.using('slave')
            slave_categories.bulk_create([Category(title='New')])

    def test_invalidated_update(self):
        # NOTE: not testing router-based routing
        with self._control_counts():
            slave_categories = Category.objects.using('slave')
            slave_categories.invalidated_update(title='update')

    @mock.patch('cacheops.invalidation.invalidate_dict')
    def test_m2m_changed_call_invalidate(self, mock_invalidate_dict):
        def assert_last_call_used(alias):
            # Only the `using` keyword of the latest call is pinned down
            mock_invalidate_dict.assert_called_with(mock.ANY, mock.ANY, using=alias)

        # Objects created without an explicit database
        plain_label = Label.objects.create()
        plain_brand = Brand.objects.create()
        plain_brand.labels.add(plain_label)
        assert_last_call_used(DEFAULT_DB_ALIAS)

        # Objects created on the 'slave' database
        slave_label = Label.objects.using('slave').create()
        slave_brand = Brand.objects.using('slave').create()
        slave_brand.labels.add(slave_label)
        assert_last_call_used('slave')
1032