1from datetime import timedelta
2from unittest.mock import MagicMock
3
4from django.contrib.auth.models import User
5from django.db import models
6from django.db.transaction import get_connection
7from django.utils import timezone
8import reversion
9from test_app.models import TestModel, TestModelRelated, TestModelThrough, TestModelParent, TestMeta
10from test_app.tests.base import TestBase, TestBaseTransaction, TestModelMixin, UserMixin
11
12
class SaveTest(TestModelMixin, TestBase):
    """Checks what a plain model save does when no revision block is active."""

    def testModelSave(self):
        # The object is created outside any create_revision() block, and the
        # test expects that no revision has been recorded afterwards.
        manager = TestModel.objects
        manager.create()
        self.assertNoRevision()
18
19
class IsRegisteredTest(TestModelMixin, TestBase):
    """Checks is_registered() for TestModel in a test case that uses TestModelMixin."""

    def testIsRegistered(self):
        # NOTE(review): presumably TestModelMixin registers TestModel during
        # set-up -- confirm in test_app.tests.base.
        registered = reversion.is_registered(TestModel)
        self.assertTrue(registered)
24
25
class IsRegisterUnregisteredTest(TestBase):
    """Checks is_registered() for TestModel in a test case without TestModelMixin."""

    def testIsRegisteredFalse(self):
        # Nothing in this test case registers TestModel, so the lookup is
        # expected to report it as unregistered.
        registered = reversion.is_registered(TestModel)
        self.assertFalse(registered)
30
31
class GetRegisteredModelsTest(TestModelMixin, TestBase):
    """Checks the collection of models reported by get_registered_models()."""

    def testGetRegisteredModels(self):
        # Compared as sets, so the order of the returned models is irrelevant.
        registered_models = set(reversion.get_registered_models())
        self.assertEqual(registered_models, {TestModel})
36
37
class RegisterTest(TestBase):
    """Exercises reversion.register() as a direct call and as a class decorator."""

    def testRegister(self):
        # Direct call form: the model must be reported as registered afterwards.
        reversion.register(TestModel)
        self.assertTrue(reversion.is_registered(TestModel))

    def testRegisterDecorator(self):
        # Decorator form applied to a model class declared inside the test.
        @reversion.register()
        class TestModelDecorater(models.Model):
            pass
        self.assertTrue(reversion.is_registered(TestModelDecorater))

    def testRegisterAlreadyRegistered(self):
        # A second registration of the same model is expected to be rejected.
        reversion.register(TestModel)
        with self.assertRaises(reversion.RegistrationError):
            reversion.register(TestModel)

    def testRegisterM2MSThroughLazy(self):
        # When register is used as a decorator in models.py, lazy relations have not yet had a chance to be
        # resolved, so the ``through`` argument will still be a string at registration time.
        @reversion.register()
        class TestModelLazy(models.Model):
            related = models.ManyToManyField(
                TestModelRelated,
                through="TestModelThroughLazy",
            )

        class TestModelThroughLazy(models.Model):
            pass

        # Besides not raising, the decorator must actually have registered the
        # model even though its through model was only named by string.
        self.assertTrue(reversion.is_registered(TestModelLazy))
67
68
class UnregisterTest(TestModelMixin, TestBase):
    """Checks that unregister() removes a model from the registry."""

    def testUnregister(self):
        # After unregistering, the registry lookup must report False.
        reversion.unregister(TestModel)
        still_registered = reversion.is_registered(TestModel)
        self.assertFalse(still_registered)
74
75
class UnregisterUnregisteredTest(TestBase):
    """Checks unregister() for a model this test case never registered."""

    def testUnregisterNotRegistered(self):
        # User is not registered anywhere in this test, so unregistering it
        # is expected to raise RegistrationError.
        self.assertRaises(reversion.RegistrationError, reversion.unregister, User)
81
82
class CreateRevisionTest(TestModelMixin, TestBase):
    """Exercises reversion.create_revision() as a context manager and as a decorator."""

    def testCreateRevision(self):
        # A single block with a single save is expected to yield one revision.
        with reversion.create_revision():
            obj = TestModel.objects.create()
        self.assertSingleRevision((obj,))

    def testCreateRevisionNested(self):
        # Nested blocks are expected to collapse into one revision, not two.
        with reversion.create_revision():
            with reversion.create_revision():
                obj = TestModel.objects.create()
        self.assertSingleRevision((obj,))

    def testCreateRevisionEmpty(self):
        # A block in which nothing is saved must not record a revision.
        with reversion.create_revision():
            pass
        self.assertNoRevision()

    def testCreateRevisionException(self):
        # The block must propagate the error and record nothing. Matching on
        # the message ensures that only the deliberate failure is accepted;
        # an unexpected error from the create() call would fail the test
        # instead of being silently swallowed.
        with self.assertRaisesRegex(Exception, "Boom!"):
            with reversion.create_revision():
                TestModel.objects.create()
                raise Exception("Boom!")
        self.assertNoRevision()

    def testCreateRevisionDecorator(self):
        # create_revision() is used here to wrap a callable rather than as a
        # context manager.
        obj = reversion.create_revision()(TestModel.objects.create)()
        self.assertSingleRevision((obj,))

    def testPreRevisionCommitSignal(self):
        _callback = MagicMock()
        signal = reversion.signals.pre_revision_commit
        signal.connect(_callback)
        # Disconnect explicitly so the mock does not stay attached to the
        # signal once this test has finished.
        self.addCleanup(signal.disconnect, _callback)

        with reversion.create_revision():
            TestModel.objects.create()
        self.assertEqual(_callback.call_count, 1)

    def testPostRevisionCommitSignal(self):
        _callback = MagicMock()
        signal = reversion.signals.post_revision_commit
        signal.connect(_callback)
        # Disconnect explicitly so the mock does not stay attached to the
        # signal once this test has finished.
        self.addCleanup(signal.disconnect, _callback)

        with reversion.create_revision():
            TestModel.objects.create()
        self.assertEqual(_callback.call_count, 1)
129
130
class CreateRevisionAtomicTest(TestModelMixin, TestBaseTransaction):
    """Checks how create_revision() interacts with database transactions.

    NOTE(review): TestBaseTransaction presumably does not wrap each test in a
    transaction (the tests assert that no atomic block is open at the start)
    -- confirm in test_app.tests.base.
    """

    def testCreateRevisionAtomic(self):
        # By default the revision block is expected to open an atomic block.
        before = get_connection().in_atomic_block
        self.assertFalse(before)
        with reversion.create_revision():
            inside = get_connection().in_atomic_block
            self.assertTrue(inside)

    def testCreateRevisionNonAtomic(self):
        # With atomic=False no atomic block should be opened by the revision.
        before = get_connection().in_atomic_block
        self.assertFalse(before)
        with reversion.create_revision(atomic=False):
            inside = get_connection().in_atomic_block
            self.assertFalse(inside)

    def testCreateRevisionInOnCommitHandler(self):
        from django.db import transaction
        from reversion.models import Revision

        initial_count = Revision.objects.all().count()
        self.assertEqual(initial_count, 0)

        with reversion.create_revision(atomic=True):
            instance = TestModel.objects.create()

            def save_in_new_revision():
                # Executed as an on_commit hook: opens a second revision
                # block and saves the same object again.
                with reversion.create_revision(atomic=True):
                    instance.name = 'oncommit'
                    instance.save()

            transaction.on_commit(save_in_new_revision)

        # One revision from the outer block plus one from the on_commit hook.
        final_count = Revision.objects.all().count()
        self.assertEqual(final_count, 2)
159
160
class CreateRevisionManageManuallyTest(TestModelMixin, TestBase):
    """Checks create_revision(manage_manually=True)."""

    def testCreateRevisionManageManually(self):
        # With manage_manually=True a plain save is expected to leave no revision.
        manual_block = reversion.create_revision(manage_manually=True)
        with manual_block:
            TestModel.objects.create()
        self.assertNoRevision()

    def testCreateRevisionManageManuallyNested(self):
        # Same expectation when the manual block sits inside a regular one.
        with reversion.create_revision(), reversion.create_revision(manage_manually=True):
            TestModel.objects.create()
        self.assertNoRevision()
173
174
class CreateRevisionDbTest(TestModelMixin, TestBase):
    """Checks create_revision() with explicit ``using`` database aliases."""

    def testCreateRevisionMultiDb(self):
        # Two blocks target the "mysql" and "postgres" aliases; the object is
        # saved once while both are active.
        with reversion.create_revision(using="mysql"):
            with reversion.create_revision(using="postgres"):
                obj = TestModel.objects.create()
        # No revision is expected when no alias is passed, and exactly one
        # under each alias that was named.
        self.assertNoRevision()
        for alias in ("mysql", "postgres"):
            self.assertSingleRevision((obj,), using=alias)
183
184
class CreateRevisionFollowTest(TestBase):
    """Checks the ``follow`` option of reversion.register()."""

    def testCreateRevisionFollow(self):
        reversion.register(TestModel, follow=("related",))
        reversion.register(TestModelRelated)
        # The related object is created before the revision block starts.
        related = TestModelRelated.objects.create()
        with reversion.create_revision():
            parent = TestModel.objects.create()
            parent.related.add(related)
        # Both the saved object and the followed one are expected in the revision.
        self.assertSingleRevision((parent, related))

    def testCreateRevisionFollowThrough(self):
        reversion.register(TestModel, follow=("related_through",))
        reversion.register(TestModelThrough, follow=("test_model", "test_model_related",))
        reversion.register(TestModelRelated)
        related = TestModelRelated.objects.create()
        with reversion.create_revision():
            parent = TestModel.objects.create()
            link = TestModelThrough.objects.create(test_model=parent, test_model_related=related)
        # Parent, through row and related object are all expected in one revision.
        expected = (parent, link, related)
        self.assertSingleRevision(expected)

    def testCreateRevisionFollowInvalid(self):
        # "name" is passed as a follow target; saving inside a revision block
        # is then expected to raise RegistrationError.
        reversion.register(TestModel, follow=("name",))
        with reversion.create_revision(), self.assertRaises(reversion.RegistrationError):
            TestModel.objects.create()
214
215
class CreateRevisionIgnoreDuplicatesTest(TestBase):
    """Checks the ``ignore_duplicates`` option of reversion.register()."""

    def testCreateRevisionIgnoreDuplicates(self):
        reversion.register(TestModel, ignore_duplicates=True)
        # First block creates the object.
        with reversion.create_revision():
            instance = TestModel.objects.create()
        # Second block saves it again with no field changes.
        with reversion.create_revision():
            instance.save()
        # Only a single revision is expected despite the two blocks.
        self.assertSingleRevision((instance,))
225
226
class CreateRevisionInheritanceTest(TestModelMixin, TestBase):
    """Checks revisions for a model registered to follow its ``testmodel_ptr`` link."""

    def testCreateRevisionInheritance(self):
        reversion.register(TestModelParent, follow=("testmodel_ptr",))
        with reversion.create_revision():
            child = TestModelParent.objects.create()
        # The revision is expected to hold both the created object and the
        # object reached through its testmodel_ptr attribute.
        linked = child.testmodel_ptr
        self.assertSingleRevision((child, linked))
234
235
class SetCommentTest(TestModelMixin, TestBase):
    """Checks reversion.set_comment() inside and outside a revision block."""

    def testSetComment(self):
        comment = "comment v1"
        with reversion.create_revision():
            reversion.set_comment(comment)
            obj = TestModel.objects.create()
        # The comment set inside the block is expected on the stored revision.
        self.assertSingleRevision((obj,), comment=comment)

    def testSetCommentNoBlock(self):
        # Called with no active revision block, so an error is expected.
        self.assertRaises(reversion.RevisionManagementError, reversion.set_comment, "comment v1")
247
248
class GetCommentTest(TestBase):
    """Checks reversion.get_comment() inside and outside a revision block."""

    def testGetComment(self):
        # A comment set in the block must be readable back within it.
        with reversion.create_revision():
            reversion.set_comment("comment v1")
            current = reversion.get_comment()
            self.assertEqual(current, "comment v1")

    def testGetCommentDefault(self):
        # With nothing set, the comment is expected to be the empty string.
        with reversion.create_revision():
            current = reversion.get_comment()
            self.assertEqual(current, "")

    def testGetCommentNoBlock(self):
        # Called with no active revision block, so an error is expected.
        self.assertRaises(reversion.RevisionManagementError, reversion.get_comment)
263
264
class SetUserTest(UserMixin, TestModelMixin, TestBase):
    """Checks reversion.set_user() inside and outside a revision block.

    NOTE(review): ``self.user`` is presumably provided by UserMixin -- confirm
    in test_app.tests.base.
    """

    def testSetUser(self):
        author = self.user
        with reversion.create_revision():
            reversion.set_user(author)
            obj = TestModel.objects.create()
        # The user set inside the block is expected on the stored revision.
        self.assertSingleRevision((obj,), user=author)

    def testSetUserNoBlock(self):
        # Called with no active revision block, so an error is expected.
        self.assertRaises(reversion.RevisionManagementError, reversion.set_user, self.user)
276
277
class GetUserTest(UserMixin, TestBase):
    """Checks reversion.get_user() inside and outside a revision block."""

    def testGetUser(self):
        # A user set in the block must be readable back within it.
        with reversion.create_revision():
            reversion.set_user(self.user)
            current = reversion.get_user()
            self.assertEqual(current, self.user)

    def testGetUserDefault(self):
        # With nothing set, the user is expected to compare equal to None.
        with reversion.create_revision():
            current = reversion.get_user()
            self.assertEqual(current, None)

    def testGetUserNoBlock(self):
        # Called with no active revision block, so an error is expected.
        self.assertRaises(reversion.RevisionManagementError, reversion.get_user)
292
293
class SetDateCreatedTest(TestModelMixin, TestBase):
    """Checks reversion.set_date_created() inside and outside a revision block."""

    def testSetDateCreated(self):
        # Use a timestamp 20 days in the past so it is clearly distinct from "now".
        offset = timedelta(days=20)
        backdated = timezone.now() - offset
        with reversion.create_revision():
            reversion.set_date_created(backdated)
            obj = TestModel.objects.create()
        self.assertSingleRevision((obj,), date_created=backdated)

    def testDateCreatedNoBlock(self):
        # Called with no active revision block, so an error is expected.
        self.assertRaises(reversion.RevisionManagementError, reversion.set_date_created, timezone.now())
306
307
class GetDateCreatedTest(TestBase):
    """Checks reversion.get_date_created() inside and outside a revision block."""

    def testGetDateCreated(self):
        # A backdated timestamp set in the block must be readable back unchanged.
        offset = timedelta(days=20)
        backdated = timezone.now() - offset
        with reversion.create_revision():
            reversion.set_date_created(backdated)
            current = reversion.get_date_created()
            self.assertEqual(current, backdated)

    def testGetDateCreatedDefault(self):
        # With nothing set, the value is expected to be within one second of now.
        with reversion.create_revision():
            current = reversion.get_date_created()
            now = timezone.now()
            tolerance = timedelta(seconds=1)
            self.assertAlmostEqual(current, now, delta=tolerance)

    def testGetDateCreatedNoBlock(self):
        # Called with no active revision block, so an error is expected.
        self.assertRaises(reversion.RevisionManagementError, reversion.get_date_created)
323
324
class AddMetaTest(TestModelMixin, TestBase):
    """Checks reversion.add_meta() inside and outside a revision block."""

    def testAddMeta(self):
        with reversion.create_revision():
            reversion.add_meta(TestMeta, name="meta v1")
            obj = TestModel.objects.create()
        # The meta name added inside the block is expected on the stored revision.
        expected_names = ("meta v1",)
        self.assertSingleRevision((obj,), meta_names=expected_names)

    def testAddMetaNoBlock(self):
        # Called with no active revision block, so an error is expected.
        self.assertRaises(reversion.RevisionManagementError, reversion.add_meta, TestMeta, name="meta v1")

    def testAddMetaMultDb(self):
        # Two blocks target the "mysql" and "postgres" aliases; the meta is
        # added once while both are active.
        with reversion.create_revision(using="mysql"):
            with reversion.create_revision(using="postgres"):
                obj = TestModel.objects.create()
                reversion.add_meta(TestMeta, name="meta v1")
        # No revision is expected when no alias is passed; each named alias
        # is expected to hold one revision carrying the meta.
        self.assertNoRevision()
        for alias in ("mysql", "postgres"):
            self.assertSingleRevision((obj,), meta_names=("meta v1",), using=alias)
344