1from __future__ import print_function
2
3from django.contrib.contenttypes.models import ContentType
4from django.db.models import Q
5from django.test import TestCase
6
7from polymorphic.tests.models import (
8    Base,
9    BlogA,
10    BlogEntry,
11    Model2A,
12    Model2B,
13    Model2C,
14    Model2D,
15    ModelX,
16    ModelY,
17    One2OneRelatingModel,
18    RelatingModel,
19)
20
21
class MultipleDatabasesTests(TestCase):
    """Polymorphic querysets and related-object descriptors on a second database.

    The tests write to and read from the ``"secondary"`` database alias, so
    that alias must be configured in the test settings alongside ``"default"``.
    """

    # Django >= 2.2 reads ``databases`` to decide which connections a test may
    # query (``multi_db`` was deprecated in 2.2 and removed in 3.1, after which
    # queries to "secondary" would be refused).  ``multi_db`` is kept so older
    # Django versions still allow both databases; newer versions ignore it once
    # ``databases`` is defined on the class.
    databases = {"default", "secondary"}
    multi_db = True

    def _assert_queryset_classes(self, queryset, classes, ordered=True):
        """Assert that ``queryset`` yields objects whose classes are ``classes``.

        Objects are compared by their concrete class, which is how these tests
        check that polymorphic querysets return the expected subclass instances.
        Pass ``ordered=False`` when the queryset has no defined ordering.
        """
        # Django 4.2 renamed assertQuerysetEqual() to assertQuerySetEqual() and
        # Django 5.1 removed the old spelling; use whichever one is available.
        assert_equal = getattr(self, "assertQuerySetEqual", None)
        if assert_equal is None:
            assert_equal = self.assertQuerysetEqual
        assert_equal(
            queryset,
            classes,
            transform=lambda obj: obj.__class__,
            ordered=ordered,
        )

    def test_save_to_non_default_database(self):
        """Objects saved with an explicit alias are stored only in that database."""
        Model2A.objects.db_manager("secondary").create(field1="A1")
        Model2C(field1="C1", field2="C2", field3="C3").save(using="secondary")
        Model2B.objects.create(field1="B1", field2="B2")
        Model2D(field1="D1", field2="D2", field3="D3", field4="D4").save()

        # Querying through the base model yields the concrete subclasses; the
        # default database holds only the objects saved without an alias...
        self._assert_queryset_classes(
            Model2A.objects.order_by("id"),
            [Model2B, Model2D],
        )
        # ...and the secondary database only those saved via "secondary".
        self._assert_queryset_classes(
            Model2A.objects.db_manager("secondary").order_by("id"),
            [Model2A, Model2C],
        )

    def test_instance_of_filter_on_non_default_database(self):
        """``instance_of`` filters work on a manager bound to "secondary"."""
        Base.objects.db_manager("secondary").create(field_b="B1")
        ModelX.objects.db_manager("secondary").create(field_b="B", field_x="X")
        ModelY.objects.db_manager("secondary").create(field_b="Y", field_y="Y")

        secondary_base = Base.objects.db_manager("secondary")

        # These querysets apply no ordering, so multi-object results are
        # compared without regard to order.
        self._assert_queryset_classes(
            secondary_base.filter(instance_of=Base),
            [Base, ModelX, ModelY],
            ordered=False,
        )
        self._assert_queryset_classes(
            secondary_base.filter(instance_of=ModelX),
            [ModelX],
        )
        self._assert_queryset_classes(
            secondary_base.filter(instance_of=ModelY),
            [ModelY],
        )
        self._assert_queryset_classes(
            secondary_base.filter(Q(instance_of=ModelX) | Q(instance_of=ModelY)),
            [ModelX, ModelY],
            ordered=False,
        )

    def test_forward_many_to_one_descriptor_on_non_default_database(self):
        """``entry.blog`` on an entry loaded from "secondary" stays on "secondary"."""
        # assertNumQueries() watches only the default connection, so zero
        # queries means nothing in this block touched the default database.
        with self.assertNumQueries(0):
            blog = BlogA.objects.db_manager("secondary").create(
                name="Blog", info="Info"
            )
            entry = BlogEntry.objects.db_manager("secondary").create(
                blog=blog, text="Text"
            )
            # Drop cached ContentTypes so any content-type lookup made while
            # loading below has to hit a database; it must not be the default.
            ContentType.objects.clear_cache()
            entry = BlogEntry.objects.db_manager("secondary").get(pk=entry.id)
            self.assertEqual(blog, entry.blog)

    def test_reverse_many_to_one_descriptor_on_non_default_database(self):
        """``blog.blogentry_set`` on a blog loaded from "secondary" stays there."""
        # assertNumQueries() watches only the default connection, so zero
        # queries means nothing in this block touched the default database.
        with self.assertNumQueries(0):
            blog = BlogA.objects.db_manager("secondary").create(
                name="Blog", info="Info"
            )
            entry = BlogEntry.objects.db_manager("secondary").create(
                blog=blog, text="Text"
            )
            # Drop cached ContentTypes so any content-type lookup made while
            # loading below has to hit a database; it must not be the default.
            ContentType.objects.clear_cache()
            blog = BlogA.objects.db_manager("secondary").get(pk=blog.id)
            self.assertEqual(entry, blog.blogentry_set.using("secondary").get())

    def test_reverse_one_to_one_descriptor_on_non_default_database(self):
        """The reverse one-to-one accessor on a "secondary" object stays there."""
        # assertNumQueries() watches only the default connection, so zero
        # queries means nothing in this block touched the default database.
        with self.assertNumQueries(0):
            m2a = Model2A.objects.db_manager("secondary").create(field1="A1")
            one2one = One2OneRelatingModel.objects.db_manager("secondary").create(
                one2one=m2a, field1="121"
            )
            # Drop cached ContentTypes so any content-type lookup made while
            # loading below has to hit a database; it must not be the default.
            ContentType.objects.clear_cache()
            m2a = Model2A.objects.db_manager("secondary").get(pk=m2a.id)
            self.assertEqual(one2one, m2a.one2onerelatingmodel)

    def test_many_to_many_descriptor_on_non_default_database(self):
        """The reverse many-to-many accessor on a "secondary" object stays there."""
        # assertNumQueries() watches only the default connection, so zero
        # queries means nothing in this block touched the default database.
        with self.assertNumQueries(0):
            m2a = Model2A.objects.db_manager("secondary").create(field1="A1")
            rm = RelatingModel.objects.db_manager("secondary").create()
            rm.many2many.add(m2a)
            # Drop cached ContentTypes so any content-type lookup made while
            # loading below has to hit a database; it must not be the default.
            ContentType.objects.clear_cache()
            m2a = Model2A.objects.db_manager("secondary").get(pk=m2a.id)
            self.assertEqual(rm, m2a.relatingmodel_set.using("secondary").get())
131