1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import unittest
7import os
8import inspect
9import json
10import shlex
11import tempfile
12import shutil
13import logging
14import io
15import vcr
16
17from .patches import patch_time_sleep_api
18from .exceptions import CliExecutionError
19from .const import (ENV_LIVE_TEST, ENV_SKIP_ASSERT, ENV_TEST_DIAGNOSE)
20from .decorators import live_only
21from .recording_processors import (GeneralNameReplacer, LargeRequestBodyProcessor,
22                                   LargeResponseBodyProcessor, LargeResponseBodyReplacer)
23from .util import find_recording_dir, create_random_name
# Module-level logger; ExecutionResult below reports every command outcome through it.
logger = logging.getLogger('clicore.testsdk')
25
26
class IntegrationTestBase(unittest.TestCase):
    """
    Common base class for CLI integration tests.

    Keeps a reference to the CLI under test and offers helpers to run commands, generate
    resource names, create self-cleaning temporary files/directories and edit environment
    variables.
    """

    def __init__(self, cli, method_name):
        """
        :param cli: CLI object that `cmd` hands to ExecutionResult.
        :param method_name: name of the test method, forwarded to unittest.TestCase.
        """
        super().__init__(method_name)
        self.cli = cli
        # Diagnose mode is enabled only when the variable is exactly the string 'True'.
        self.diagnose = os.environ.get(ENV_TEST_DIAGNOSE, None) == 'True'

    def cmd(self, command, checks=None, expect_failure=False):
        """
        Execute `command` and apply `checks` to the result.

        :param command: command line as a single string.
        :param checks: a callable, a list of callables, or None.
        :param expect_failure: when True the command is expected to exit with a non-zero code.
        :return: the ExecutionResult of the command.
        """
        result = ExecutionResult(self.cli, command, expect_failure=expect_failure)
        return result.assert_with_checks(checks)

    def create_random_name(self, prefix, length):  # pylint: disable=no-self-use
        """ Return a random name produced by the module-level create_random_name helper. """
        return create_random_name(prefix=prefix, length=length)

    def create_temp_file(self, size_kb, full_random=False):
        """
        Create a temporary file for testing. The test harness will delete the file during tearing
        down.

        :param size_kb: size of the file in KiB; the same 1024-byte chunk is written this many times.
        :param full_random: fill the chunk with os.urandom bytes instead of zero bytes.
        :return: path of the created file.
        """
        fd, path = tempfile.mkstemp()
        os.close(fd)

        def _remove_if_present():
            # A test may legitimately delete or move the file itself; that must not turn the
            # cleanup into an error (create_temp_dir is equally tolerant).
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

        self.addCleanup(_remove_if_present)

        chunk = os.urandom(1024) if full_random else bytes(1024)
        with open(path, mode='r+b') as f:
            for _ in range(size_kb):
                f.write(chunk)

        return path

    def create_temp_dir(self):
        """
        Create a temporary directory for testing. The test harness will delete the directory during
        tearing down.

        :return: path of the created directory.
        """
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(temp_dir, ignore_errors=True))

        return temp_dir

    @classmethod
    def set_env(cls, key, val):
        """ Set environment variable `key` to `val` for the current process. """
        os.environ[key] = val

    @classmethod
    def pop_env(cls, key):
        """ Remove environment variable `key` and return its value, or None when it is not set. """
        return os.environ.pop(key, None)
75
76
@live_only()
class LiveTest(IntegrationTestBase):
    """
    Integration test base wrapped by the `live_only` decorator.

    Adds nothing to IntegrationTestBase. NOTE(review): the decorator presumably restricts
    these tests to live runs - confirm in `.decorators`.
    """
80
81
class ScenarioTest(IntegrationTestBase):  # pylint: disable=too-many-instance-attributes
    """
    Integration test base whose HTTP traffic goes through a VCR cassette.

    A test is "in recording" when the ENV_LIVE_TEST environment variable is 'True' or when no
    recording file exists for the test method; otherwise the existing recording is used.
    """

    def __init__(self, cli, method_name, filter_headers=None):
        """
        :param cli: CLI object under test.
        :param method_name: name of the test method; also the base name of the recording file.
        :param filter_headers: header names handed to VCR's `filter_headers` and additionally
                               stripped from recorded responses.
        """
        super().__init__(cli, method_name)
        self.name_replacer = GeneralNameReplacer()
        self.recording_processors = [LargeRequestBodyProcessor(),
                                     LargeResponseBodyProcessor(),
                                     self.name_replacer]
        self.replay_processors = [LargeResponseBodyReplacer()]
        self.filter_headers = filter_headers or []

        test_file_path = inspect.getfile(self.__class__)
        recordings_dir = find_recording_dir(test_file_path)
        live_test = os.environ.get(ENV_LIVE_TEST, None) == 'True'

        self.vcr = vcr.VCR(
            cassette_library_dir=recordings_dir,
            before_record_request=self._process_request_recording,
            before_record_response=self._process_response_recording,
            decode_compressed_response=True,
            record_mode='all' if live_test else 'once',
            filter_headers=self.filter_headers
        )
        self.vcr.register_matcher('query', self._custom_request_query_matcher)

        self.recording_file = os.path.join(recordings_dir, '{}.yaml'.format(method_name))
        # A live run always starts from a clean slate: drop any previous recording.
        if live_test and os.path.exists(self.recording_file):
            os.remove(self.recording_file)

        self.in_recording = live_test or not os.path.exists(self.recording_file)
        self.test_resources_count = 0
        # Snapshot restored by tearDown so that env changes made by a test do not leak.
        self.original_env = os.environ.copy()

    def setUp(self):
        """ Enter the cassette for this test and register its exit as a cleanup. """
        super().setUp()

        # set up cassette
        cm = self.vcr.use_cassette(self.recording_file)
        self.cassette = cm.__enter__()
        # Follow the context manager protocol: a normal exit passes (None, None, None).
        self.addCleanup(cm.__exit__, None, None, None)

        if not self.in_recording:
            # NOTE(review): presumably neutralizes time.sleep when not recording - see .patches.
            patch_time_sleep_api(self)

    def tearDown(self):
        """
        Restore the environment variables captured when the test case was constructed.

        os.environ is updated in place. Rebinding it to the plain dict returned by
        os.environ.copy() would discard the real environ mapping, so later assignments would no
        longer reach the process environment.
        """
        for key in set(os.environ) - set(self.original_env):
            del os.environ[key]
        os.environ.update(self.original_env)
        super().tearDown()

    def create_random_name(self, prefix, length):
        """
        Return a resource name for the test.

        While recording, a random name is generated and registered with the name replacer
        together with a deterministic moniker (`prefix` plus a six-digit counter). Otherwise the
        moniker itself is returned.
        """
        self.test_resources_count += 1
        moniker = '{}{:06}'.format(prefix, self.test_resources_count)

        if not self.in_recording:
            return moniker

        name = create_random_name(prefix, length)
        self.name_replacer.register_name_pair(name, moniker)
        return name

    def _process_request_recording(self, request):
        """
        VCR `before_record_request` hook.

        Runs the request through the recording or replay processors, stopping as soon as a
        processor returns a falsy value.
        """
        processors = self.recording_processors if self.in_recording else self.replay_processors
        for processor in processors:
            request = processor.process_request(request)
            if not request:
                break

        return request

    def _process_response_recording(self, response):
        """
        VCR `before_record_response` hook.

        While recording, header names are lower-cased, filtered headers are dropped and a
        non-str body is decoded as UTF-8 before the recording processors run. Otherwise only the
        replay processors run. Processing stops once a processor returns a falsy value.
        """
        if self.in_recording:
            # Compare case-insensitively on both sides so that a mixed-case entry in
            # filter_headers still removes the header.
            ignored = {name.lower() for name in self.filter_headers}
            response['headers'] = {key.lower(): value
                                   for key, value in response['headers'].items()
                                   if key.lower() not in ignored}

            body = response['body']['string']
            if body and not isinstance(body, str):
                response['body']['string'] = body.decode('utf-8')

            processors = self.recording_processors
        else:
            processors = self.replay_processors

        for processor in processors:
            response = processor.process_response(response)
            if not response:
                break

        return response

    @classmethod
    def _custom_request_query_matcher(cls, r1, r2):
        """
        Compare the query strings of two requests.

        The requests match when both have the same set of query parameter names and, for every
        name, the first values are equal ignoring case. Only the query string is inspected;
        method and path are not compared here.
        """
        from urllib.parse import urlparse, parse_qs  # pylint: disable=useless-suppression

        q1 = parse_qs(urlparse(r1.uri).query)
        q2 = parse_qs(urlparse(r2.uri).query)

        if q1.keys() != q2.keys():
            return False

        return all(q1[key][0].lower() == q2[key][0].lower() for key in q1)
199
200
class ExecutionResult(object):
    """
    Result of one in-process CLI invocation.

    Exposes `exit_code`, the captured `output`, the lazily parsed `json_value` and, when the
    invocation raised an unexpected exception, that exception as `process_error`.
    """

    def __init__(self, cli, command, expect_failure=False):
        """
        Run `command` and verify that it succeeded or failed as expected.

        :param cli: CLI object providing `name` and `invoke`.
        :param command: command line as a single string.
        :param expect_failure: when True a zero exit code is treated as an error.
        :raises AssertionError: when the exit code contradicts `expect_failure`.
        """
        self.cli = cli
        self.json_value = None
        self.skip_assert = os.environ.get(ENV_SKIP_ASSERT, None) == 'True'

        self._in_process_execute(command)

        if expect_failure and self.exit_code == 0:
            logger.error('Command "%s" => %d. (It did not fail as expected) Output: %s', command,
                         self.exit_code, self.output)
            raise AssertionError('The command did not fail as it was expected.')
        if not expect_failure and self.exit_code != 0:
            logger.error('Command "%s" => %d. Output: %s', command, self.exit_code, self.output)
            raise AssertionError('The command failed. Exit code: {}'.format(self.exit_code))

        logger.info('Command "%s" => %d. Output: %s', command, self.exit_code, self.output)

    def assert_with_checks(self, *args):
        """
        Run the given checks against this result unless assertions are skipped.

        Each argument may be a callable or a list of callables; anything else (e.g. None) is
        ignored. Every check is called with this result as its only argument.

        :return: this result, to allow chaining.
        """
        checks = []
        for each in args:
            if isinstance(each, list):
                checks.extend(each)
            elif callable(each):
                checks.append(each)

        if not self.skip_assert:
            for check in checks:
                check(self)

        return self

    def get_output_in_json(self):
        """
        Return the command output parsed as JSON, parsing it at most once.

        :raises AssertionError: when the output is not valid JSON or parses to None.
        """
        # Test against None rather than truthiness so that falsy results such as [] or {} are
        # cached instead of being parsed again on every call.
        if self.json_value is None:
            try:
                self.json_value = json.loads(self.output)
            except ValueError as ex:  # json.JSONDecodeError is a subclass of ValueError
                raise AssertionError('The command output cannot be parsed in json.') from ex

        if self.json_value is None:
            raise AssertionError('The command output cannot be parsed in json.')

        return self.json_value

    def _in_process_execute(self, command):
        """
        Invoke the CLI in process, capturing its output and exit code.

        A leading "<cli name> " is stripped from `command` before it is split with shlex. An
        unexpected exception is recorded in `process_error` and reported as exit code 1.

        :raises AssertionError: when VCR refuses to overwrite an existing cassette.
        :raises Exception: the exception wrapped by a CliExecutionError, or the
                           CliExecutionError itself when it wraps nothing.
        """
        cli_name_prefixed = '{} '.format(self.cli.name)
        if command.startswith(cli_name_prefixed):
            command = command[len(cli_name_prefixed):]

        # Always defined, so callers can inspect it without checking for the attribute first.
        self.process_error = None

        out_buffer = io.StringIO()
        try:
            # Known limitation: stderr cannot be redirected this way, so part of the failure
            # information is lost when the command fails.
            self.exit_code = self.cli.invoke(shlex.split(command), out_file=out_buffer) or 0
            self.output = out_buffer.getvalue()
        except vcr.errors.CannotOverwriteExistingCassetteException as ex:
            raise AssertionError(ex) from ex
        except CliExecutionError as ex:
            if ex.exception:
                raise ex.exception
            raise
        except Exception as ex:  # pylint: disable=broad-except
            self.exit_code = 1
            self.output = out_buffer.getvalue()
            self.process_error = ex
        finally:
            out_buffer.close()
265