1# -*- coding: utf-8 -*-
2import os
3import mock
4import logging
5import importlib
6
7from django.conf import settings
8from django.core.management import call_command, find_commands, load_command_class
9from django.test import TestCase
10from io import StringIO
11
12from django_extensions.management.modelviz import use_model, generate_graph_data
13from django_extensions.management.commands.merge_model_instances import get_model_to_deduplicate, get_field_names, keep_first_or_last_instance
14from . import force_color_support
15from .testapp.models import Person, Name, Note, Personality, Club, Membership, Permission
16from .testapp.jobs.hourly.test_hourly_job import HOURLY_JOB_MOCK
17from .testapp.jobs.daily.test_daily_job import DAILY_JOB_MOCK
18from .testapp.jobs.weekly.test_weekly_job import WEEKLY_JOB_MOCK
19from .testapp.jobs.monthly.test_monthly_job import MONTHLY_JOB_MOCK
20from .testapp.jobs.yearly.test_yearly_job import YEARLY_JOB_MOCK
21
22
class MockLoggingHandler(logging.Handler):
    """Logging handler that records emitted messages, grouped by level, for test assertions.

    ``messages`` maps a lower-cased level name to the list of messages
    (``record.getMessage()``, i.e. msg interpolated with its args) emitted at
    that level, in emission order.
    """

    def __init__(self, *args, **kwargs):
        # reset() runs first so ``messages`` exists before any record can arrive.
        self.reset()
        logging.Handler.__init__(self, *args, **kwargs)

    def emit(self, record):
        """Store the record's message under its lower-cased level name.

        Levels outside the standard buckets (e.g. custom numeric levels, whose
        name is "Level N") get their own list instead of raising KeyError.
        """
        self.messages.setdefault(record.levelname.lower(), []).append(record.getMessage())

    def reset(self):
        """Discard all captured messages and recreate the standard level buckets."""
        self.messages = {
            'debug': [],
            'info': [],
            'warning': [],
            'error': [],
            'critical': [],
        }
41
42
class CommandTest(TestCase):
    def test_error_logging(self):
        """A failing command must log exactly one error and re-raise the exception."""
        from django_extensions.management.base import logger
        handler = MockLoggingHandler()
        logger.addHandler(handler)
        # Detach even if an assertion fails, so the handler does not leak into
        # other tests that share this module-level logger.
        self.addCleanup(logger.removeHandler, handler)
        module_path = "tests.management.commands.error_raising_command"
        module = importlib.import_module(module_path)
        error_raising_command = module.Command()
        self.assertRaises(Exception, error_raising_command.execute)
        # Inspect our own handler rather than logger.handlers[0], which may be a
        # handler that was attached elsewhere.
        self.assertEqual(len(handler.messages['error']), 1)
54
55
class ShowTemplateTagsTests(TestCase):
    """Tests for the `show_template_tags` management command."""

    def test_some_output(self):
        buffer = StringIO()
        call_command('show_template_tags', stdout=buffer)
        rendered = buffer.getvalue()
        # django_extensions is installed for the test run, so the app itself and
        # at least one of its template tag libraries must be listed.
        for expected in ('django_extensions', 'syntax_color'):
            self.assertIn(expected, rendered)
66
67
class DescribeFormTests(TestCase):
    """Tests for the `describe_form` management command."""

    def test_command(self):
        buffer = StringIO()
        call_command('describe_form', 'django_extensions.Secret', stdout=buffer)
        form_source = buffer.getvalue()

        self.assertIn("class SecretForm(forms.Form):", form_source)
        # Each field declaration must carry the expected keyword arguments; the
        # optional `u` prefix tolerates unicode-literal reprs of the labels.
        expected_patterns = (
            r"name = forms.CharField\(.*max_length=255",
            r"name = forms.CharField\(.*required=False",
            r"name = forms.CharField\(.*label=u?'Name'",
            r"text = forms.CharField\(.*required=False",
            r"text = forms.CharField\(.*label=u?'Text'",
        )
        for pattern in expected_patterns:
            self.assertRegex(form_source, pattern)
79
80
class CommandSignalTests(TestCase):
    """Checks that pre_command / post_command fire around a management command."""

    # Keyword arguments captured by the receivers in test_works; None until a signal fires.
    pre = None
    post = None

    def test_works(self):
        from django_extensions.management.signals import post_command, pre_command
        from django_extensions.management.commands.show_template_tags import Command

        # Start from a clean slate so captured values from an earlier run
        # cannot satisfy the assertions below.
        CommandSignalTests.pre = None
        CommandSignalTests.post = None

        def pre(sender, **kwargs):
            CommandSignalTests.pre = dict(**kwargs)

        def post(sender, **kwargs):
            CommandSignalTests.post = dict(**kwargs)

        pre_command.connect(pre, Command)
        post_command.connect(post, Command)
        # Disconnect explicitly so these receivers never outlive this test.
        self.addCleanup(pre_command.disconnect, pre, Command)
        self.addCleanup(post_command.disconnect, post, Command)

        out = StringIO()
        call_command('show_template_tags', stdout=out)

        # Fail clearly (rather than with a TypeError from assertIn) if a signal never fired.
        self.assertIsNotNone(CommandSignalTests.pre)
        self.assertIsNotNone(CommandSignalTests.post)

        self.assertIn('args', CommandSignalTests.pre)
        self.assertIn('kwargs', CommandSignalTests.pre)

        self.assertIn('args', CommandSignalTests.post)
        self.assertIn('kwargs', CommandSignalTests.post)
        self.assertIn('outcome', CommandSignalTests.post)
107
108
class CommandClassTests(TestCase):
    """Smoke test: every django_extensions management command module can be loaded."""

    def setUp(self):
        import django_extensions
        # Resolve the management directory from the installed package, not the
        # current working directory. A cwd-relative path finds no commands when
        # tests are launched elsewhere, making test_load_commands pass vacuously.
        package_dir = os.path.dirname(django_extensions.__file__)
        management_dir = os.path.join(package_dir, 'management')
        self.commands = find_commands(management_dir)

    def test_load_commands(self):
        """Try to load every management command to catch exceptions."""
        self.assertTrue(self.commands, "No django_extensions management commands were found")
        for command in self.commands:
            try:
                load_command_class('django_extensions', command)
            except Exception as e:
                self.fail("Can't load command class of {0}\n{1}".format(command, e))
121
122
class GraphModelsTests(TestCase):
    """
    Tests for the `graph_models` management command.
    """
    def test_use_model(self):
        include_models = [
            'NoWildcardInclude',
            'Wildcard*InsideInclude',
            '*WildcardPrefixInclude',
            'WildcardSuffixInclude*',
            '*WildcardBothInclude*',
        ]
        exclude_models = [
            'NoWildcardExclude',
            'Wildcard*InsideExclude',
            '*WildcardPrefixExclude',
            'WildcardSuffixExclude*',
            '*WildcardBothExclude*',
            '*Include',
        ]
        # Each row: (model name, include list, exclude list, expected truthiness).
        cases = [
            # With neither list defined, any model name is used.
            ('SomeModel', None, None, True),
            # A bare `*` in the include list admits every model.
            ('SomeModel', ['OtherModel', '*', 'Wildcard*Model'], None, True),
            # A bare `*` in the exclude list rejects every model.
            ('SomeModel', None, ['OtherModel', '*', 'Wildcard*Model'], False),
            # Include list only.
            ('SomeModel', include_models, None, False),
            ('NoWildcardInclude', include_models, None, True),
            ('WildcardSomewhereInsideInclude', include_models, None, True),
            ('MyWildcardPrefixInclude', include_models, None, True),
            ('WildcardSuffixIncludeModel', include_models, None, True),
            ('MyWildcardBothIncludeModel', include_models, None, True),
            # Exclude list only.
            ('SomeModel', None, exclude_models, True),
            ('NoWildcardExclude', None, exclude_models, False),
            ('WildcardSomewhereInsideExclude', None, exclude_models, False),
            ('MyWildcardPrefixExclude', None, exclude_models, False),
            ('WildcardSuffixExcludeModel', None, exclude_models, False),
            ('MyWildcardBothExcludeModel', None, exclude_models, False),
            # Both lists: a wildcard exclusion (`*Include`) must not override an
            # explicit inclusion, but still rejects names not otherwise included.
            ('MyWildcardPrefixInclude', include_models, exclude_models, True),
            ('MyInclude', include_models, exclude_models, False),
        ]
        for model_name, include, exclude, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(use_model(model_name, include, exclude))

    def test_no_models_dot_py(self):
        graph_data = generate_graph_data(['testapp_with_no_models_file'])
        graphs = graph_data['graphs']
        self.assertEqual(len(graphs), 1)

        first_model = graphs[0]['models'][0]
        self.assertEqual(first_model['name'], 'TeslaCar')
244
245
class ShowUrlsTests(TestCase):
    """
    Tests for the `show_urls` management command.
    """
    def _run_show_urls(self, *extra_args):
        """Run show_urls with *extra_args*, store its output on self.output and return it."""
        buffer = StringIO()
        call_command('show_urls', *extra_args, stdout=buffer)
        self.output = buffer.getvalue()
        return self.output

    def test_color(self):
        with force_color_support:
            rendered = self._run_show_urls()
            # '\x1b' (ESC) introduces ANSI color sequences.
            self.assertIn('\x1b', rendered)

    def test_no_color(self):
        with force_color_support:
            rendered = self._run_show_urls('--no-color')
            self.assertNotIn('\x1b', rendered)
263
264
class ListModelInfoTests(TestCase):
    """
    Tests for the `list_model_info` management command.
    """
    def _list_model_info(self, *options):
        """Run list_model_info on the test model with extra *options*.

        The output is stored on ``self.output`` (as the individual tests
        previously did) and also returned.
        """
        out = StringIO()
        call_command('list_model_info', '--model', 'django_extensions.MultipleFieldsAndMethods', *options, stdout=out)
        self.output = out.getvalue()
        return self.output

    def test_plain(self):
        output = self._list_model_info()
        for expected in (
            'id',
            'char_field',
            'integer_field',
            'foreign_key_field',
            'has_self_only()',
            'has_one_extra_argument()',
            'has_two_extra_arguments()',
            'has_args_kwargs()',
            'has_defaults()',
        ):
            self.assertIn(expected, output)
        # Dunder and inherited Model methods are hidden without --all.
        self.assertNotIn('__class__()', output)
        self.assertNotIn('validate_unique()', output)

    def test_all(self):
        output = self._list_model_info('--all')
        self.assertIn('id', output)
        self.assertIn('__class__()', output)
        self.assertIn('validate_unique()', output)

    def test_signature(self):
        output = self._list_model_info('--signature')
        self.assertIn('has_self_only(self)', output)
        self.assertIn('has_one_extra_argument(self, arg_one)', output)
        self.assertIn('has_two_extra_arguments(self, arg_one, arg_two)', output)
        self.assertIn('has_args_kwargs(self, *args, **kwargs)', output)
        self.assertIn("has_defaults(self, one=1, two='Two', true=True, false=False, none=None)", output)

    def test_db_type(self):
        # Prefix match so the legacy 'postgresql_psycopg2' engine alias is
        # treated as PostgreSQL too, not just the exact 'postgresql' name.
        engine = settings.DATABASES['default']['ENGINE']
        if engine.startswith('django.db.backends.postgresql'):
            id_type = 'serial'
        else:
            id_type = 'integer'

        output = self._list_model_info('--db-type')
        self.assertIn('id - %s' % id_type, output)
        self.assertIn('char_field - varchar(10)', output)
        self.assertIn('integer_field - integer', output)
        self.assertIn('foreign_key_field - integer', output)

    def test_field_class(self):
        output = self._list_model_info('--field-class')
        self.assertIn('id - AutoField', output)
        self.assertIn('char_field - CharField', output)
        self.assertIn('integer_field - IntegerField', output)
        self.assertIn('foreign_key_field - ForeignKey', output)
325
326
class MergeModelInstancesTests(TestCase):
    """
    Tests for the `merge_model_instances` management command.
    """

    @mock.patch('django_extensions.management.commands.merge_model_instances.apps.get_models')
    @mock.patch('django_extensions.management.commands.merge_model_instances.input')
    def test_get_model_to_merge(self, test_input, get_models):
        """Choosing option 2 at the prompt selects the second model returned by apps.get_models."""
        class Model:
            __name__ = ""

        return_value = []
        for v in ["one", "two", "three"]:
            instance = Model()
            instance.__name__ = v
            return_value.append(instance)
        get_models.return_value = return_value
        test_input.return_value = 2
        model_to_deduplicate = get_model_to_deduplicate()
        self.assertEqual(model_to_deduplicate.__name__, "two")

    @mock.patch('django_extensions.management.commands.merge_model_instances.input')
    def test_get_field_names(self, test_input):
        """Picking field 2 and then answering "C" yields just that field's name."""

        class Field:
            name = ""

            def __init__(self, name):
                self.name = name

        class Model:
            __name__ = ""
            one = Field(name="one")
            two = Field(name="two")
            three = Field(name="three")

        # Non-dunder attributes in dir() order (alphabetical): one, three, two.
        model_instance = Model()
        return_value = [getattr(model_instance, field) for field in dir(model_instance) if not field.startswith("__")]
        Model._meta = mock.MagicMock()
        Model._meta.get_fields = mock.MagicMock(return_value=return_value)

        # Choose the second return_value
        test_input.side_effect = [2, "C"]
        field_names = get_field_names(Model())
        # Test that the second return_value returned
        self.assertEqual(field_names, [return_value[1].name])

    @mock.patch('django_extensions.management.commands.merge_model_instances.input')
    def test_keep_first_or_last_instance(self, test_input):
        """An unrecognized answer is re-prompted; "first" and "last" are returned as given."""
        test_input.side_effect = ["xxxx", "first", "last"]
        first_or_last = keep_first_or_last_instance()
        self.assertEqual(first_or_last, "first")
        first_or_last = keep_first_or_last_instance()
        self.assertEqual(first_or_last, "last")

    @mock.patch('django_extensions.management.commands.merge_model_instances.get_model_to_deduplicate')
    @mock.patch('django_extensions.management.commands.merge_model_instances.get_field_names')
    @mock.patch('django_extensions.management.commands.merge_model_instances.keep_first_or_last_instance')
    def test_merge_model_instances(self, keep_first_or_last_instance, get_field_names, get_model_to_deduplicate):
        """Merging Person rows on `name`, keeping the first, folds related objects into the survivor."""
        get_model_to_deduplicate.return_value = Person
        get_field_names.return_value = ["name"]
        keep_first_or_last_instance.return_value = "first"

        name = Name.objects.create(name="Name")
        note = Note.objects.create(note="This is a note.")
        personality_1 = Personality.objects.create(description="Child 1's personality.")
        personality_2 = Personality.objects.create(description="Child 2's personality.")
        child_1 = Person.objects.create(
            name=Name.objects.create(name="Child1"),
            age=10,
            personality=personality_1,
        )
        child_1.notes.add(note)
        child_2 = Person.objects.create(
            name=Name.objects.create(name="Child2"),
            age=10,
            personality=personality_2,
        )
        child_2.notes.add(note)

        club1 = Club.objects.create(name="Club one")
        club2 = Club.objects.create(name="Club two")
        # Three people share the same Name row, so they are duplicates on "name".
        person_1 = Person.objects.create(
            name=name,
            age=50,
            personality=Personality.objects.create(description="First personality"),
        )
        person_1.children.add(child_1)
        person_1.notes.add(note)
        Permission.objects.create(text="Permission", person=person_1)

        person_2 = Person.objects.create(
            name=name,
            age=50,
            personality=Personality.objects.create(description="Second personality"),
        )
        person_2.children.add(child_2)
        new_note = Note.objects.create(note="This is a new note")
        person_2.notes.add(new_note)
        Membership.objects.create(club=club1, person=person_2)
        Membership.objects.create(club=club1, person=person_2)
        Permission.objects.create(text="Permission", person=person_2)

        person_3 = Person.objects.create(
            name=name,
            age=50,
            personality=Personality.objects.create(description="Third personality"),
        )
        person_3.children.add(child_2)
        person_3.notes.add(new_note)
        Membership.objects.create(club=club2, person=person_3)
        Membership.objects.create(club=club2, person=person_3)
        Permission.objects.create(text="Permission", person=person_3)

        self.assertEqual(Person.objects.count(), 5)
        self.assertEqual(Membership.objects.count(), 4)
        out = StringIO()
        call_command('merge_model_instances', stdout=out)
        self.output = out.getvalue()
        # Two children plus the single surviving "Name" person.
        self.assertEqual(Person.objects.count(), 3)
        person = Person.objects.get(name__name="Name")
        with self.assertRaises(Person.DoesNotExist):
            Person.objects.get(personality__description="Second personality")
        self.assertEqual(person.notes.count(), 2)
        self.assertEqual(person.clubs.distinct().count(), 2)
        self.assertEqual(person.permission_set.count(), 3)
        with self.assertRaises(Personality.DoesNotExist):
            Personality.objects.get(description="Second personality")
458
459
class RunJobsTests(TestCase):
    """
    Tests for the `runjobs` management command.
    """

    @mock.patch('django_extensions.management.commands.runjobs.Command.runjobs_by_signals')
    @mock.patch('django_extensions.management.commands.runjobs.Command.runjobs')
    @mock.patch('django_extensions.management.commands.runjobs.Command.usage_msg')
    def test_runjobs_management_command(
            self, usage_msg, runjobs, runjobs_by_signals):
        """A valid `when` runs jobs (directly and via signals) without printing usage."""
        when = 'daily'
        call_command('runjobs', when)
        usage_msg.assert_not_called()
        runjobs.assert_called_once()
        runjobs_by_signals.assert_called_once()
        self.assertEqual(runjobs.call_args[0][0], when)

    @mock.patch('django_extensions.management.commands.runjobs.Command.runjobs_by_signals')
    @mock.patch('django_extensions.management.commands.runjobs.Command.runjobs')
    @mock.patch('django_extensions.management.commands.runjobs.Command.usage_msg')
    def test_runjobs_management_command_invalid_when(
            self, usage_msg, runjobs, runjobs_by_signals):
        """An unknown `when` prints usage and runs nothing."""
        when = 'invalid'
        call_command('runjobs', when)
        usage_msg.assert_called_once_with()
        runjobs.assert_not_called()
        runjobs_by_signals.assert_not_called()

    def _assert_jobs_run_in_order(self, command, jobs):
        """Run *command* once per (argument, mock) pair in *jobs*, in order.

        After each run, every mock whose argument has been passed so far must
        have been called exactly once with no arguments, and every remaining
        mock must not have been called at all.
        """
        # Reset all mocks in case they have been called elsewhere.
        for _, job_mock in jobs:
            job_mock.reset_mock()

        for position, (argument, _) in enumerate(jobs, start=1):
            call_command(command, argument, verbosity=2)
            for _, already_called in jobs[:position]:
                already_called.assert_called_once_with()
            for _, not_yet_called in jobs[position:]:
                not_yet_called.assert_not_called()

    def test_runjobs_integration_test(self):
        self._assert_jobs_run_in_order('runjobs', [
            ("hourly", HOURLY_JOB_MOCK),
            ("daily", DAILY_JOB_MOCK),
            ("monthly", MONTHLY_JOB_MOCK),
            ("weekly", WEEKLY_JOB_MOCK),
            ("yearly", YEARLY_JOB_MOCK),
        ])

    def test_runjob_integration_test(self):
        self._assert_jobs_run_in_order('runjob', [
            ("test_hourly_job", HOURLY_JOB_MOCK),
            ("test_daily_job", DAILY_JOB_MOCK),
            ("test_monthly_job", MONTHLY_JOB_MOCK),
            ("test_weekly_job", WEEKLY_JOB_MOCK),
            ("test_yearly_job", YEARLY_JOB_MOCK),
        ])
531