1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6try:
7    import mock
8except ImportError:
9    from unittest import mock
10import logging
11import os
12import re
13import shutil
14import sys
15import tempfile
16from io import StringIO
17
18from knack.cli import CLI, CLICommandsLoader, CommandInvoker
19from knack.log import CLI_LOGGER_NAME
20
# Name of the folder, inside tempfile.gettempdir(), that new_temp_folder() deletes and
# recreates; every MockContext/DummyCLI uses it as its config dir.
TEMP_FOLDER_NAME = "knack_temp"
22
23
def redirect_io(func):
    """Decorate a test method so everything written to stdout/stderr lands in ``self.io``.

    Before the wrapped method runs, all handlers on the root logger and on the CLI logger
    are removed. Both ``sys.stdout`` and ``sys.stderr`` are then pointed at a single
    ``StringIO`` exposed as ``self.io``.

    Afterwards, even if the wrapped method raises:

    - the buffer is closed;
    - the streams that were active when the method was called are restored;
    - the logger handlers are cleared again.

    :param func: Test method taking ``self`` (plus optional extra arguments).
    :return: Wrapper that forwards all arguments and returns ``func``'s result.
    """
    import functools

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        root_logger = logging.getLogger()
        cli_logger = logging.getLogger(CLI_LOGGER_NAME)
        # Ensure a clean startup - no log handlers left over from a previous test
        root_logger.handlers.clear()
        cli_logger.handlers.clear()

        # Capture the streams at call time rather than at decoration time: a test runner
        # may have replaced sys.stdout/sys.stderr after this module was imported.
        original_stdout = sys.stdout
        original_stderr = sys.stderr
        buffer = StringIO()
        sys.stdout = sys.stderr = self.io = buffer
        try:
            return func(self, *args, **kwargs)
        finally:
            # Runs on failure too, so a failing test cannot leave the process writing
            # into this buffer.
            buffer.close()
            sys.stdout = original_stdout
            sys.stderr = original_stderr

            # Remove the handlers added by CLI, so that the next invoke call inits them again
            # with the new stderr. Otherwise, the handlers would write to the closed StringIO
            # from a previous test.
            root_logger.handlers.clear()
            cli_logger.handlers.clear()
    return wrapper
47
48
def disable_color(func):
    """Decorate a test method so it runs with ``self.cli_ctx.enable_color`` set to False.

    The value ``enable_color`` had before the call is restored afterwards, even if the
    wrapped method raises, so one failing test cannot leak the setting into the next.

    :param func: Test method taking ``self`` (plus optional extra arguments); ``self``
        must have a ``cli_ctx`` attribute.
    :return: Wrapper that forwards all arguments and returns ``func``'s result.
    """
    import functools

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        previous = self.cli_ctx.enable_color
        self.cli_ctx.enable_color = False
        try:
            return func(self, *args, **kwargs)
        finally:
            self.cli_ctx.enable_color = previous
    return wrapper
55
56
57def _remove_control_sequence(string):
58    return re.sub(r'\x1b[^m]+m', '', string)
59
60
61def _remove_whitespace(string):
62    return re.sub(r'\s', '', string)
63
64
def assert_in_multi_line(sub_string, string):
    """Assert that *sub_string* occurs in *string*, ignoring whitespace and control sequences.

    All whitespace (including line breaks) is removed from both arguments. *string* is
    additionally stripped of ``ESC ... m`` control sequences before the containment check.

    :param sub_string: Expected text.
    :param string: Actual output, possibly multi-line and colourised.
    :raises AssertionError: if the normalized *sub_string* is not contained in the
        normalized *string*. The message shows both normalized values so a failing test
        is easy to diagnose.
    """
    needle = _remove_whitespace(sub_string)
    haystack = _remove_control_sequence(_remove_whitespace(string))
    assert needle in haystack, f"{needle!r} not found in {haystack!r}"
68
69
class MockContext(CLI):
    """Lightweight ``CLI`` for unit tests.

    Uses a recreated temp folder (from ``new_temp_folder``) as config dir and gets a real
    ``CLICommandsLoader`` bound to itself. It also gets a ``MagicMock`` invocation,
    restricted to the ``CommandInvoker`` spec, whose ``data`` starts as an empty dict.
    """

    def __init__(self):
        super().__init__(config_dir=new_temp_folder())
        commands_loader = CLICommandsLoader(cli_ctx=self)
        fake_invocation = mock.MagicMock(spec=CommandInvoker)
        fake_invocation.data = {}
        self.commands_loader = commands_loader
        self.invocation = fake_invocation
79
80
class DummyCLI(CLI):
    """``CLI`` subclass for tests: fixed version string, temp config dir, colour forced on."""

    def __init__(self, **kwargs):
        # Any caller-supplied config_dir is overwritten with a freshly recreated temp folder.
        options = dict(kwargs, config_dir=new_temp_folder())
        super().__init__(**options)
        # Force to enable color regardless of what the base initializer chose
        self.enable_color = True

    def get_cli_version(self):
        return '0.1.0'
91
92
def new_temp_folder(folder_name=None):
    """Create an empty folder inside ``tempfile.gettempdir()`` and return its path.

    Whatever already exists at that location is removed first:

    - a directory tree is deleted recursively;
    - a regular file or a symlink is unlinked. ``shutil.rmtree`` would raise on these,
      and on a symlink it must not follow the link.

    :param folder_name: Name of the folder to (re)create. Defaults to ``TEMP_FOLDER_NAME``
        when omitted, which preserves the original behaviour.
    :return: Path of the newly created, empty folder.
    """
    if folder_name is None:
        # Resolved at call time so the module-level constant stays the single default.
        folder_name = TEMP_FOLDER_NAME
    temp_dir = os.path.join(tempfile.gettempdir(), folder_name)
    if os.path.islink(temp_dir) or os.path.isfile(temp_dir):
        os.remove(temp_dir)
    elif os.path.isdir(temp_dir):
        shutil.rmtree(temp_dir)
    os.mkdir(temp_dir)
    return temp_dir
99