1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5import os
6import logging
7import sys
8import unittest
9try:
10    import mock
11except ImportError:
12    from unittest import mock
13
14from knack.arguments import ArgumentsContext
15from knack.commands import CLICommandsLoader, CommandGroup
16from tests.util import DummyCLI, redirect_io
17
18
def load_params(_):
    """Dummy callback for arg-parse.

    The single positional argument is accepted only to satisfy the callback
    signature and is deliberately ignored; nothing is returned.
    """
    return None
22
23
def list_foo(my_param):
    """Write ``my_param`` to stdout as text, without a trailing newline.

    This is the handler behind the ``foo list`` test command. Omitting the
    newline lets the tests compare captured output to the raw value exactly.
    """
    rendered = str(my_param)
    print(rendered, end='')
26
27
class TestCommandWithConfiguredDefaults(unittest.TestCase):
    """Check that ``configured_default`` values are applied to command arguments.

    Each test builds a one-command CLI (``foo list --my-param``) whose
    ``my_param`` argument is declared with ``configured_default='param'``.
    Each test then sets or removes the ``CLI_DEFAULTS_PARAM`` environment
    variable and checks one of two things:
    - what the command echoes, or
    - which parse error is reported.
    """

    @classmethod
    def setUpClass(cls):
        # Ensure initialization has occurred correctly
        logging.basicConfig(level=logging.DEBUG)

    @classmethod
    def tearDownClass(cls):
        logging.shutdown()

    def _set_up_command_table(self, required):
        """Store in ``self.cli_ctx`` a DummyCLI that exposes ``foo list``.

        :param required: forwarded as the ``required`` flag of ``--my-param``.
        """

        class TestCommandsLoader(CLICommandsLoader):

            def load_command_table(self, args):
                super().load_command_table(args)
                # The template expands to '<this module>#{}'. Presumably
                # 'list_foo' fills the placeholder, so the command resolves to
                # this module's list_foo function.
                with CommandGroup(self, 'foo', '{}#{{}}'.format(__name__)) as g:
                    g.command('list', 'list_foo')
                return self.command_table

            def load_arguments(self, command):
                with ArgumentsContext(self, 'foo') as c:
                    # configured_default='param' binds the argument to the
                    # 'param' default. The tests below supply it through
                    # the CLI_DEFAULTS_PARAM environment variable.
                    c.argument('my_param', options_list='--my-param',
                               configured_default='param', required=required)
                super().load_arguments(command)

        self.cli_ctx = DummyCLI(commands_loader_cls=TestCommandsLoader)

    @mock.patch.dict(os.environ, {'CLI_DEFAULTS_PARAM': 'myVal'})
    @redirect_io
    def test_apply_configured_defaults_on_required_arg(self):
        self._set_up_command_table(required=True)
        self.cli_ctx.invoke('foo list'.split())
        actual = self.io.getvalue()
        expected = 'myVal'
        self.assertEqual(expected, actual)

    @mock.patch.dict(os.environ)
    @redirect_io
    def test_no_configured_default_on_required_arg(self):
        # A CLI_DEFAULTS_PARAM inherited from the caller's shell would satisfy
        # the required argument and break this test. Remove it here;
        # patch.dict restores the original environment afterwards.
        os.environ.pop('CLI_DEFAULTS_PARAM', None)
        self._set_up_command_table(required=True)
        with self.assertRaises(SystemExit):
            self.cli_ctx.invoke('foo list'.split())
        actual = self.io.getvalue()
        # On Python 3, argparse reports this as
        # "the following arguments are required: --my-param". The module is
        # Python-3-only (zero-arg super(), print(..., end='')), so no
        # Python 2 wording is needed.
        self.assertIn('required: --my-param', actual)

    @mock.patch.dict(os.environ, {'CLI_DEFAULTS_PARAM': 'myVal'})
    @redirect_io
    def test_apply_configured_defaults_on_optional_arg(self):
        self._set_up_command_table(required=False)
        self.cli_ctx.invoke('foo list'.split())
        actual = self.io.getvalue()
        expected = 'myVal'
        self.assertEqual(expected, actual)
84
85
if __name__ == '__main__':
    # Allow running this file directly; unittest.main() runs the TestCase
    # classes defined in this module.
    unittest.main()
88