1# Copyright: See the LICENSE file.
2
3import unittest
4
5from factory import base, declarations, enums, errors
6
7
class TestObject:
    """Plain target class used as a factory ``model`` throughout these tests.

    It simply records up to four values, each of which defaults to ``None``.
    """

    def __init__(self, one=None, two=None, three=None, four=None):
        # A single tuple assignment; attributes are bound left to right.
        self.one, self.two, self.three, self.four = one, two, three, four
14
15
class FakeDjangoModel:
    """Minimal stand-in for a Django model, used by the factory tests.

    Instances constructed directly are "unsaved" and have ``id`` set to
    ``None``. Instances produced through :meth:`create` are "saved" and have
    ``id`` set to ``1``.
    """

    @classmethod
    def create(cls, **kwargs):
        """Build an instance from ``kwargs`` and mark it as saved (``id = 1``)."""
        instance = cls(**kwargs)
        instance.id = 1
        return instance

    def __init__(self, **kwargs):
        """Store every keyword argument as an attribute of the instance."""
        # Default the primary key once, before applying kwargs. Previously this
        # assignment sat inside the loop, so an instance built without kwargs
        # had no ``id`` attribute at all, and an explicit ``id=`` kwarg was
        # always clobbered back to ``None``.
        self.id = None
        for name, value in kwargs.items():
            setattr(self, name, value)
27
28
class FakeModelFactory(base.Factory):
    """Abstract base factory whose create step calls ``model_class.create``.

    Concrete subclasses point ``Meta.model`` at a fake model class, such as
    ``TestModel``, that exposes a ``create(**kwargs)`` classmethod.
    """

    class Meta:
        abstract = True

    @classmethod
    def _create(cls, model_class, *args, **kwargs):
        # Positional arguments are intentionally not forwarded: the fake
        # models' ``create`` only accepts keyword arguments.
        creator = model_class.create
        return creator(**kwargs)
36
37
class TestModel(FakeDjangoModel):
    """Concrete fake model targeted by the strategy and creation tests."""
40
41
class SafetyTestCase(unittest.TestCase):
    """Guards against direct misuse of the factory base classes."""

    def test_base_factory(self):
        """Instantiating ``BaseFactory`` itself is refused."""
        self.assertRaises(errors.FactoryError, base.BaseFactory)
46
47
class AbstractFactoryTestCase(unittest.TestCase):
    """How ``Meta.abstract`` is inferred, inherited and enforced."""

    def test_factory_for_optional(self):
        """An explicitly abstract factory does not need a ``model``."""
        class TestObjectFactory(base.Factory):
            class Meta:
                abstract = True

        meta = TestObjectFactory._meta
        self.assertTrue(meta.abstract)
        self.assertIsNone(meta.model)

    def test_factory_for_and_abstract_factory_optional(self):
        """With no ``Meta`` at all, a factory is abstract and has no model."""
        class TestObjectFactory(base.Factory):
            pass

        meta = TestObjectFactory._meta
        self.assertTrue(meta.abstract)
        self.assertIsNone(meta.model)

    def test_abstract_factory_cannot_be_called(self):
        """Both ``build()`` and ``create()`` refuse to run on an abstract factory."""
        class TestObjectFactory(base.Factory):
            pass

        for generate in (TestObjectFactory.build, TestObjectFactory.create):
            with self.assertRaises(errors.FactoryError):
                generate()

    def test_abstract_factory_not_inherited(self):
        """``abstract = True`` applies only to the class that declares it."""
        class TestObjectFactory(base.Factory):
            class Meta:
                abstract = True
                model = TestObject

        class TestObjectChildFactory(TestObjectFactory):
            pass

        self.assertFalse(TestObjectChildFactory._meta.abstract)

    def test_abstract_or_model_is_required(self):
        """A non-abstract factory without a model cannot generate objects."""
        class TestObjectFactory(base.Factory):
            class Meta:
                abstract = False
                model = None

        for generate in (TestObjectFactory.build, TestObjectFactory.create):
            with self.assertRaises(errors.FactoryError):
                generate()
98
99
class OptionsTests(unittest.TestCase):
    """Contents of ``Factory._meta`` for plain and inherited factories."""

    def test_base_attrs(self):
        """A bare ``Factory`` subclass gets the default option values."""
        class AbstractFactory(base.Factory):
            pass

        meta = AbstractFactory._meta

        # Options that can be set through ``class Meta``.
        self.assertTrue(meta.abstract)
        self.assertIsNone(meta.model)
        self.assertEqual((), meta.inline_args)
        self.assertEqual((), meta.exclude)
        self.assertEqual(enums.CREATE_STRATEGY, meta.strategy)

        # Options derived from the class itself rather than from ``Meta``.
        self.assertEqual({}, meta.pre_declarations.as_dict())
        self.assertEqual({}, meta.post_declarations.as_dict())
        self.assertEqual(AbstractFactory, meta.factory)
        self.assertEqual(base.Factory, meta.base_factory)
        self.assertEqual(meta, meta.counter_reference)

    def test_declaration_collecting(self):
        """Class attributes are split into pre- and post-declarations."""
        lazy_fn = declarations.LazyFunction(int)
        lazy_attr = declarations.LazyAttribute(lambda _o: 1)
        post_gen = declarations.PostGenerationDeclaration()

        class AbstractFactory(base.Factory):
            x = 1
            y = lazy_fn
            y2 = lazy_attr
            z = post_gen

        # The declarations remain readable as plain class attributes...
        expected_attrs = {'x': 1, 'y': lazy_fn, 'y2': lazy_attr, 'z': post_gen}
        for name, value in expected_attrs.items():
            self.assertEqual(value, getattr(AbstractFactory, name))

        # ...and are also registered in the matching ``_meta`` collection.
        meta = AbstractFactory._meta
        self.assertEqual(
            {'x': 1, 'y': lazy_fn, 'y2': lazy_attr},
            meta.pre_declarations.as_dict(),
        )
        self.assertEqual(
            {'z': post_gen},
            meta.post_declarations.as_dict(),
        )

    def test_inherited_declaration_collecting(self):
        """A child factory combines its own declarations with its parent's."""
        lazy_fn = declarations.LazyFunction(int)
        lazy_attr = declarations.LazyAttribute(lambda _o: 2)
        parent_post = declarations.PostGenerationDeclaration()
        child_post = declarations.PostGenerationDeclaration()

        class AbstractFactory(base.Factory):
            x = 1
            y = lazy_fn
            z = parent_post

        class OtherFactory(AbstractFactory):
            a = lazy_attr
            b = child_post

        # Both own and inherited declarations stay accessible on the child.
        expected_attrs = {
            'a': lazy_attr,
            'b': child_post,
            'x': 1,
            'y': lazy_fn,
            'z': parent_post,
        }
        for name, value in expected_attrs.items():
            self.assertEqual(value, getattr(OtherFactory, name))

        meta = OtherFactory._meta
        self.assertEqual(
            {'x': 1, 'y': lazy_fn, 'a': lazy_attr},
            meta.pre_declarations.as_dict(),
        )
        self.assertEqual(
            {'z': parent_post, 'b': child_post},
            meta.post_declarations.as_dict(),
        )

    def test_inherited_declaration_shadowing(self):
        """Redeclaring a name in a child replaces the parent's declaration."""
        parent_lazy = declarations.LazyFunction(int)
        child_lazy = declarations.LazyAttribute(lambda _o: 2)
        parent_post = declarations.PostGenerationDeclaration()
        child_post = declarations.PostGenerationDeclaration()

        class AbstractFactory(base.Factory):
            x = 1
            y = parent_lazy
            z = parent_post

        class OtherFactory(AbstractFactory):
            y = child_lazy
            z = child_post

        # The child's values win on attribute access...
        expected_attrs = {'x': 1, 'y': child_lazy, 'z': child_post}
        for name, value in expected_attrs.items():
            self.assertEqual(value, getattr(OtherFactory, name))

        # ...and in the collected declarations.
        meta = OtherFactory._meta
        self.assertEqual(
            {'x': 1, 'y': child_lazy},
            meta.pre_declarations.as_dict(),
        )
        self.assertEqual(
            {'z': child_post},
            meta.post_declarations.as_dict(),
        )

    def test_factory_as_meta_model_raises_exception(self):
        """Using a factory class as another factory's ``model`` is rejected."""
        class FirstFactory(base.Factory):
            pass

        class BadMeta:
            model = FirstFactory

        namespace = {"Meta": BadMeta}
        with self.assertRaises(TypeError):
            type("SecondFactory", (base.Factory,), namespace)
217
218
class DeclarationParsingTests(unittest.TestCase):
    """Non-declaration members defined in a factory class body."""

    def test_classmethod(self):
        """Classmethods defined on a factory are kept and remain callable."""
        class TestObjectFactory(base.Factory):
            class Meta:
                model = TestObject

            @classmethod
            def some_classmethod(cls):
                return cls.create()

        self.assertTrue(hasattr(TestObjectFactory, 'some_classmethod'))
        produced = TestObjectFactory.some_classmethod()
        self.assertIs(TestObject, produced.__class__)
232
233
class FactoryTestCase(unittest.TestCase):
    """General behaviour of concrete (non-abstract) factories."""

    def test_magic_happens(self):
        """Calling a FooFactory yields a Foo, not a FooFactory instance."""
        class TestObjectFactory(base.Factory):
            class Meta:
                model = TestObject

        self.assertEqual(TestObject, TestObjectFactory._meta.model)
        built = TestObjectFactory.build()
        self.assertFalse(hasattr(built, '_meta'))

    def test_display(self):
        """``str()`` of a factory mentions both the factory and its model."""
        class TestObjectFactory(base.Factory):
            class Meta:
                model = FakeDjangoModel

        rendered = str(TestObjectFactory)
        for fragment in ('TestObjectFactory', 'FakeDjangoModel'):
            self.assertIn(fragment, rendered)

    def test_lazy_attribute_non_existent_param(self):
        """A LazyAttribute that reads an unknown field raises AttributeError."""
        class TestObjectFactory(base.Factory):
            class Meta:
                model = TestObject

            one = declarations.LazyAttribute(lambda a: a.does_not_exist)

        self.assertRaises(AttributeError, TestObjectFactory)

    def test_inheritance_with_sequence(self):
        """A subclass with the same model shares its parent's sequence counter."""
        class TestObjectFactory(base.Factory):
            class Meta:
                model = TestObject

            one = declarations.Sequence(lambda a: a)

        class TestSubFactory(TestObjectFactory):
            class Meta:
                model = TestObject

        # Interleave parent and child builds. With a shared counter, every
        # object gets a distinct sequence value.
        produced = [
            TestObjectFactory.build(),
            TestSubFactory.build(),
            TestObjectFactory.build(),
            TestSubFactory.build(),
        ]
        distinct_ones = {item.one for item in produced}
        self.assertEqual(4, len(distinct_ones))
283
284
class FactorySequenceTestCase(unittest.TestCase):
    """Behaviour of ``reset_sequence()`` on a factory and its subclasses."""

    def setUp(self):
        super().setUp()

        class TestObjectFactory(base.Factory):
            class Meta:
                model = TestObject
            one = declarations.Sequence(lambda n: n)

        # A fresh factory per test, so sequence counters never leak across tests.
        self.TestObjectFactory = TestObjectFactory

    def test_reset_sequence(self):
        """``reset_sequence()`` with no argument restarts the counter at 0."""
        o1 = self.TestObjectFactory()
        self.assertEqual(0, o1.one)

        o2 = self.TestObjectFactory()
        self.assertEqual(1, o2.one)

        self.TestObjectFactory.reset_sequence()
        o3 = self.TestObjectFactory()
        self.assertEqual(0, o3.one)

    def test_reset_sequence_with_value(self):
        """``reset_sequence(value)`` restarts the counter at ``value``."""
        o1 = self.TestObjectFactory()
        self.assertEqual(0, o1.one)

        o2 = self.TestObjectFactory()
        self.assertEqual(1, o2.one)

        self.TestObjectFactory.reset_sequence(42)
        o3 = self.TestObjectFactory()
        self.assertEqual(42, o3.one)

    def test_reset_sequence_subclass_fails(self):
        """A child factory sharing its parent's sequence cannot reset it."""
        class SubTestObjectFactory(self.TestObjectFactory):
            pass

        with self.assertRaises(ValueError):
            SubTestObjectFactory.reset_sequence()

    def test_reset_sequence_subclass_force(self):
        """``reset_sequence(force=True)`` on a child resets the shared counter."""
        class SubTestObjectFactory(self.TestObjectFactory):
            pass

        o1 = SubTestObjectFactory()
        self.assertEqual(0, o1.one)

        o2 = SubTestObjectFactory()
        self.assertEqual(1, o2.one)

        SubTestObjectFactory.reset_sequence(force=True)
        o3 = SubTestObjectFactory()
        self.assertEqual(0, o3.one)

        # The counter is shared with the parent, so the parent continues from
        # where the child's forced reset left it.
        o4 = self.TestObjectFactory()
        self.assertEqual(1, o4.one)

    def test_reset_sequence_subclass_parent(self):
        """Resetting the parent's sequence also restarts a child sharing it."""
        class SubTestObjectFactory(self.TestObjectFactory):
            pass

        o1 = SubTestObjectFactory()
        self.assertEqual(0, o1.one)

        o2 = SubTestObjectFactory()
        self.assertEqual(1, o2.one)

        self.TestObjectFactory.reset_sequence()
        o3 = SubTestObjectFactory()
        self.assertEqual(0, o3.one)

        # Parent and child draw from the same counter after the reset.
        o4 = self.TestObjectFactory()
        self.assertEqual(1, o4.one)
362
363
class FactoryDefaultStrategyTestCase(unittest.TestCase):
    """Which strategy a bare ``Factory()`` call uses, per ``Meta.strategy``."""

    def test_build_strategy(self):
        """BUILD_STRATEGY produces a model instance with a falsy ``id``."""
        class TestModelFactory(base.Factory):
            class Meta:
                model = TestModel
                strategy = enums.BUILD_STRATEGY

            one = 'one'

        instance = TestModelFactory()
        self.assertEqual(instance.one, 'one')
        self.assertFalse(instance.id)

    def test_create_strategy(self):
        """Without ``Meta.strategy``, the default creates via ``_create``."""
        class TestModelFactory(FakeModelFactory):
            class Meta:
                model = TestModel

            one = 'one'

        instance = TestModelFactory()
        self.assertEqual(instance.one, 'one')
        self.assertTrue(instance.id)

    def test_stub_strategy(self):
        """STUB_STRATEGY returns a stub object rather than a model instance."""
        class TestModelFactory(base.Factory):
            class Meta:
                model = TestModel
                strategy = enums.STUB_STRATEGY

            one = 'one'

        stub = TestModelFactory()
        self.assertEqual(stub.one, 'one')
        # A model instance would carry an ``id``; the stub does not.
        self.assertFalse(hasattr(stub, 'id'))

    def test_unknown_strategy(self):
        """An unrecognised strategy name is rejected when the factory is called."""
        class TestModelFactory(base.Factory):
            class Meta:
                model = TestModel
                strategy = 'unknown'

            one = 'one'

        self.assertRaises(base.Factory.UnknownStrategy, TestModelFactory)

    def test_stub_with_create_strategy(self):
        """A StubFactory cannot use the create strategy."""
        class TestModelFactory(base.StubFactory):
            class Meta:
                model = TestModel
                strategy = enums.CREATE_STRATEGY

            one = 'one'

        self.assertRaises(base.StubFactory.UnsupportedStrategy, TestModelFactory)

    def test_stub_with_build_strategy(self):
        """On a StubFactory, the build strategy yields a stub, not a model."""
        class TestModelFactory(base.StubFactory):
            class Meta:
                model = TestModel
                strategy = enums.BUILD_STRATEGY

            one = 'one'

        result = TestModelFactory()

        # For stubs, build() is an alias of stub().
        self.assertNotIsInstance(result, TestModel)

    def test_change_strategy(self):
        """``Meta.strategy`` on a StubFactory subclass is recorded as declared."""
        class TestModelFactory(base.StubFactory):
            class Meta:
                model = TestModel
                strategy = enums.CREATE_STRATEGY

            one = 'one'

        self.assertEqual(enums.CREATE_STRATEGY, TestModelFactory._meta.strategy)
446
447
class FactoryCreationTestCase(unittest.TestCase):
    """Declaring factories and generating objects from them."""

    def test_factory_for(self):
        """``build()`` returns an instance of the factory's ``Meta.model``."""
        class TestFactory(base.Factory):
            class Meta:
                model = TestObject

        self.assertIsInstance(TestFactory.build(), TestObject)

    def test_stub(self):
        """StubFactory defaults to the stub strategy."""
        class TestFactory(base.StubFactory):
            pass

        self.assertEqual(TestFactory._meta.strategy, enums.STUB_STRATEGY)

    def test_inheritance_with_stub(self):
        """Subclasses of a StubFactory keep the stub strategy."""
        class TestObjectFactory(base.StubFactory):
            class Meta:
                model = TestObject

        class TestFactory(TestObjectFactory):
            pass

        self.assertEqual(TestFactory._meta.strategy, enums.STUB_STRATEGY)

    def test_stub_and_subfactory(self):
        """A SubFactory inside a stub factory receives its own overrides."""
        class StubA(base.StubFactory):
            class Meta:
                model = TestObject

            one = 'blah'

        class StubB(base.StubFactory):
            class Meta:
                model = TestObject

            stubbed = declarations.SubFactory(StubA, two='two')

        nested = StubB().stubbed
        self.assertEqual('blah', nested.one)
        self.assertEqual('two', nested.two)

    def test_custom_creation(self):
        """An overridden ``_generate()`` applies to both build and create."""
        class TestModelFactory(FakeModelFactory):
            class Meta:
                model = TestModel

            @classmethod
            def _generate(cls, create, attrs):
                # Inject an extra attribute, then defer to the default logic.
                attrs['four'] = 4
                return super()._generate(create, attrs)

        built = TestModelFactory.build(one=1)
        self.assertEqual(1, built.one)
        self.assertEqual(4, built.four)
        self.assertIsNone(built.id)

        created = TestModelFactory(one=1)
        self.assertEqual(1, created.one)
        self.assertEqual(4, created.four)
        self.assertEqual(1, created.id)

    # Factories without a model

    def test_no_associated_class(self):
        """Declaring a factory with no model is allowed; it is simply abstract."""
        class Test(base.Factory):
            pass

        self.assertTrue(Test._meta.abstract)
518
519
class PostGenerationParsingTestCase(unittest.TestCase):
    """Detection of post-generation declarations in a factory body."""

    def test_extraction(self):
        """A PostGenerationDeclaration is registered as a post-declaration."""
        class TestObjectFactory(base.Factory):
            class Meta:
                model = TestObject

            foo = declarations.PostGenerationDeclaration()

        post = TestObjectFactory._meta.post_declarations.as_dict()
        self.assertIn('foo', post)

    def test_classlevel_extraction(self):
        """A ``foo__bar`` class attribute is listed with the ``foo`` hook."""
        class TestObjectFactory(base.Factory):
            class Meta:
                model = TestObject

            foo = declarations.PostGenerationDeclaration()
            foo__bar = 42

        post = TestObjectFactory._meta.post_declarations.as_dict()
        for key in ('foo', 'foo__bar'):
            self.assertIn(key, post)
541