# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import unittest
import os
import inspect
import json
import shlex
import tempfile
import shutil
import logging
import io
from urllib.parse import urlparse, parse_qs

import vcr

from .patches import patch_time_sleep_api
from .exceptions import CliExecutionError
from .const import (ENV_LIVE_TEST, ENV_SKIP_ASSERT, ENV_TEST_DIAGNOSE)
from .decorators import live_only
from .recording_processors import (GeneralNameReplacer, LargeRequestBodyProcessor,
                                   LargeResponseBodyProcessor, LargeResponseBodyReplacer)
from .util import find_recording_dir, create_random_name

logger = logging.getLogger('clicore.testsdk')


class IntegrationTestBase(unittest.TestCase):
    """Base class for CLI integration tests.

    Runs CLI commands in-process via :meth:`cmd`. Temporary files and directories created
    through the helpers below are removed by ``addCleanup`` hooks when the test finishes.
    """

    def __init__(self, cli, method_name):
        super().__init__(method_name)
        self.cli = cli
        # Diagnose mode is on only when ENV_TEST_DIAGNOSE is exactly the string 'True'.
        self.diagnose = os.environ.get(ENV_TEST_DIAGNOSE, None) == 'True'

    def cmd(self, command, checks=None, expect_failure=False):
        """Run ``command`` in-process and apply ``checks`` to the result.

        :param command: command-line string, optionally prefixed with the CLI name.
        :param checks: a callable or a list of callables, each called with the ExecutionResult.
        :param expect_failure: when True, a zero exit code is treated as an error.
        :return: the ExecutionResult.
        :raises AssertionError: if the exit code contradicts ``expect_failure``. Exceptions
            raised by checks propagate unchanged.
        """
        result = ExecutionResult(self.cli, command, expect_failure=expect_failure)
        return result.assert_with_checks(checks)

    def create_random_name(self, prefix, length):  # pylint: disable=no-self-use
        """Return a random name via the ``util.create_random_name`` helper."""
        return create_random_name(prefix=prefix, length=length)

    def create_temp_file(self, size_kb, full_random=False):
        """
        Create a temporary file for testing. The test harness will delete the file during tearing
        down.

        :param size_kb: number of 1024-byte chunks to write.
        :param full_random: when True, each chunk is the same 1024 bytes taken once from
            ``os.urandom``; otherwise every byte is zero.
        :return: path of the created file.
        """
        fd, path = tempfile.mkstemp()

        def _remove_if_present():
            # The test may already have deleted the file; that must not fail the cleanup.
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

        self.addCleanup(_remove_if_present)

        chunk = os.urandom(1024) if full_random else bytes(1024)
        # Write through the descriptor returned by mkstemp instead of closing and reopening it.
        with os.fdopen(fd, 'wb') as f:
            for _ in range(size_kb):
                f.write(chunk)

        return path

    def create_temp_dir(self):
        """
        Create a temporary directory for testing. The test harness will delete the directory during
        tearing down.
        """
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(temp_dir, ignore_errors=True))

        return temp_dir

    @classmethod
    def set_env(cls, key, val):
        """Set environment variable ``key`` to ``val`` for the current process."""
        os.environ[key] = val

    @classmethod
    def pop_env(cls, key):
        """Remove environment variable ``key``; return its previous value or None."""
        return os.environ.pop(key, None)


@live_only()
class LiveTest(IntegrationTestBase):
    pass


class ScenarioTest(IntegrationTestBase):  # pylint: disable=too-many-instance-attributes
    """Integration test that records HTTP traffic to a per-test YAML cassette using vcrpy.

    When the cassette is missing, or ENV_LIVE_TEST is 'True', the test runs in recording mode
    and the recording processors are applied. Otherwise the existing cassette is replayed and
    the replay processors are applied.
    """

    def __init__(self, cli, method_name, filter_headers=None):
        super().__init__(cli, method_name)
        self.name_replacer = GeneralNameReplacer()
        self.recording_processors = [LargeRequestBodyProcessor(),
                                     LargeResponseBodyProcessor(),
                                     self.name_replacer]
        self.replay_processors = [LargeResponseBodyReplacer()]
        self.filter_headers = filter_headers or []

        test_file_path = inspect.getfile(self.__class__)
        recordings_dir = find_recording_dir(test_file_path)
        live_test = os.environ.get(ENV_LIVE_TEST, None) == 'True'

        self.vcr = vcr.VCR(
            cassette_library_dir=recordings_dir,
            before_record_request=self._process_request_recording,
            before_record_response=self._process_response_recording,
            decode_compressed_response=True,
            # Live runs use vcrpy's 'all' mode (always hit the service and record). Otherwise
            # 'once' replays an existing cassette, or records one if it does not exist yet.
            record_mode='all' if live_test else 'once',
            filter_headers=self.filter_headers
        )
        self.vcr.register_matcher('query', self._custom_request_query_matcher)

        self.recording_file = os.path.join(recordings_dir, '{}.yaml'.format(method_name))
        # A live run always re-records, so discard any stale cassette first.
        if live_test and os.path.exists(self.recording_file):
            os.remove(self.recording_file)

        self.in_recording = live_test or not os.path.exists(self.recording_file)
        self.test_resources_count = 0
        # Snapshot of the environment taken at construction time; restored in tearDown.
        self.original_env = os.environ.copy()

    def setUp(self):
        super().setUp()

        # Enter the cassette by hand so it stays active for the whole test. Its exit is
        # registered as a cleanup so it runs even when the test fails.
        cm = self.vcr.use_cassette(self.recording_file)
        self.cassette = cm.__enter__()
        self.addCleanup(cm.__exit__)

        if not self.in_recording:
            # Presumably skips real waits during playback -- see patch_time_sleep_api.
            patch_time_sleep_api(self)

    def tearDown(self):
        """Restore the environment snapshot taken when the test was constructed."""
        # Restore in place. os.environ.copy() returns a plain dict, and rebinding os.environ
        # to it would stop later changes from reaching the real process environment.
        os.environ.clear()
        os.environ.update(self.original_env)
        super().tearDown()

    def create_random_name(self, prefix, length):
        """Return a name that is stable between recording and playback.

        The moniker is ``prefix`` followed by a zero-padded 6-digit per-test counter.
        - While recording, a real random name is returned and registered with the name
          replacer against that moniker.
        - During playback, the moniker itself is returned.
        """
        self.test_resources_count += 1
        moniker = '{}{:06}'.format(prefix, self.test_resources_count)

        if not self.in_recording:
            return moniker

        name = create_random_name(prefix, length)
        self.name_replacer.register_name_pair(name, moniker)
        return name

    @staticmethod
    def _run_processors(processors, hook, payload):
        """Pass ``payload`` through ``hook`` of each processor in order.

        Stops at the first falsy result and returns it.
        """
        for processor in processors:
            payload = getattr(processor, hook)(payload)
            if not payload:
                break
        return payload

    def _process_request_recording(self, request):
        """vcrpy ``before_record_request`` hook: run the mode-appropriate request processors."""
        processors = self.recording_processors if self.in_recording else self.replay_processors
        return self._run_processors(processors, 'process_request', request)

    def _process_response_recording(self, response):
        """vcrpy ``before_record_response`` hook.

        While recording, it lower-cases header names, drops filtered headers, and decodes a
        bytes body as UTF-8 before running the recording processors. During playback it only
        runs the replay processors.
        """
        if not self.in_recording:
            return self._run_processors(self.replay_processors, 'process_response', response)

        # HTTP header names are case-insensitive, so compare lower-cased names on both sides.
        # Non-string entries are skipped, as before they could never match a header name.
        filtered = {h.lower() for h in self.filter_headers if isinstance(h, str)}
        response['headers'] = {key.lower(): value
                               for key, value in response['headers'].items()
                               if key.lower() not in filtered}

        body = response['body']['string']
        if body and not isinstance(body, str):
            response['body']['string'] = body.decode('utf-8')

        return self._run_processors(self.recording_processors, 'process_response', response)

    @classmethod
    def _custom_request_query_matcher(cls, r1, r2):
        """ Ensure the query strings of two requests have the same parameter names and that,
        for each name, the first values match case-insensitively. Extra values of repeated
        parameters are not compared. """
        q1 = parse_qs(urlparse(r1.uri).query)
        q2 = parse_qs(urlparse(r2.uri).query)

        if q1.keys() != q2.keys():
            return False

        return all(q1[key][0].lower() == q2[key][0].lower() for key in q1)


class ExecutionResult(object):
    """Runs a CLI command in-process and stores its outcome.

    Attributes: ``exit_code``, ``output`` (captured stdout text), ``process_error`` (exception
    swallowed during execution, or None), ``json_value`` (cached parse of ``output``), and
    ``skip_assert`` (True when ENV_SKIP_ASSERT is 'True').
    """

    def __init__(self, cli, command, expect_failure=False):
        self.cli = cli
        self.process_error = None
        self._in_process_execute(command)

        if expect_failure and self.exit_code == 0:
            logger.error('Command "%s" => %d. (It did not fail as expected) Output: %s', command,
                         self.exit_code, self.output)
            raise AssertionError('The command did not fail as it was expected.')
        if not expect_failure and self.exit_code != 0:
            logger.error('Command "%s" => %d. Output: %s', command, self.exit_code, self.output)
            raise AssertionError('The command failed. Exit code: {}'.format(self.exit_code))

        logger.info('Command "%s" => %d. Output: %s', command, self.exit_code, self.output)

        self.json_value = None
        self.skip_assert = os.environ.get(ENV_SKIP_ASSERT, None) == 'True'

    def assert_with_checks(self, *args):
        """Call each check with this result unless ``skip_assert`` is set.

        Each argument may be a callable or a list of callables. Anything else (e.g. None) is
        ignored.

        :return: self, for chaining.
        """
        checks = []
        for each in args:
            if isinstance(each, list):
                checks.extend(each)
            elif callable(each):
                checks.append(each)

        if not self.skip_assert:
            for c in checks:
                c(self)

        return self

    def get_output_in_json(self):
        """Return ``output`` parsed as JSON, caching the parsed value.

        :raises AssertionError: if the output parses to JSON ``null``.
        Parse errors from ``json.loads`` propagate unchanged.
        """
        # Compare against None, not truthiness, so falsy JSON such as [] or {} is cached too.
        if self.json_value is None:
            self.json_value = json.loads(self.output)

        if self.json_value is None:
            raise AssertionError('The command output cannot be parsed in json.')

        return self.json_value

    def _in_process_execute(self, command):
        """Invoke the CLI in-process and set ``exit_code``, ``output`` and ``process_error``."""
        cli_name_prefixed = '{} '.format(self.cli.name)
        if command.startswith(cli_name_prefixed):
            command = command[len(cli_name_prefixed):]

        out_buffer = io.StringIO()
        try:
            # issue: stderr cannot be redirect in this form, as a result some failure information
            # is lost when command fails.
            self.exit_code = self.cli.invoke(shlex.split(command), out_file=out_buffer) or 0
            self.output = out_buffer.getvalue()
        except vcr.errors.CannotOverwriteExistingCassetteException as ex:
            # A request with no recorded match during playback becomes a plain test failure.
            raise AssertionError(ex) from ex
        except CliExecutionError as ex:
            # Surface the wrapped exception when there is one.
            if ex.exception:
                raise ex.exception
            raise
        except Exception as ex:  # pylint: disable=broad-except
            # Any other error is recorded as a failed command rather than propagated.
            self.exit_code = 1
            self.output = out_buffer.getvalue()
            self.process_error = ex
        finally:
            out_buffer.close()