# Copyright 2020 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.6+ and Openssl 1.0+
#
import contextlib
import re
from azurelinuxagent.common.protocol.wire import WireProtocol
from azurelinuxagent.common.utils import restutil
from tests.tools import patch
from tests.protocol import mockwiredata

# regex used to determine whether to use the mock wireserver data
_USE_MOCK_WIRE_DATA_RE = re.compile(
    r'https?://(mock-goal-state|{0}).*'.format(restutil.KNOWN_WIRESERVER_IP.replace(r'.', r'\.')), re.IGNORECASE)


@contextlib.contextmanager
def mock_wire_protocol(mock_wire_data_file, http_get_handler=None, http_post_handler=None, http_put_handler=None, fail_on_unknown_request=True):
    """
    Creates a WireProtocol object that handles requests to the WireServer and the Host GA Plugin (i.e requests on the WireServer endpoint), plus
    some requests to storage (requests on the fake server 'mock-goal-state').

    The data returned by those requests is read from the files specified by 'mock_wire_data_file' (which must follow the structure of the data
    files defined in tests/protocol/mockwiredata.py).

    The caller can also provide handler functions for specific HTTP methods using the http_*_handler arguments. GET handlers are invoked as
    handler(url, **kwargs); POST and PUT handlers are invoked as handler(url, data, **kwargs). The return value of the handler
    function is interpreted similarly to the "return_value" argument of patch(): if it is an exception the exception is raised or, if it is
    any object other than None, the value is returned by the mock. If the handler function returns None the call is handled using the mock
    wireserver data or, when the url does not match the mock servers, it either fails with a ValueError (fail_on_unknown_request=True, the
    default) or is passed to the original restutil.http_request (fail_on_unknown_request=False).

    The returned protocol object maintains a list of "tracked" urls. When a handler function returns a value that is not None the url for the
    request is automatically added to the tracked list. The handler function can add other items to this list using the track_url() method on
    the mock.

    The return value of this function is an instance of WireProtocol augmented with these properties/methods:

        * mock_wire_data - the WireProtocolData constructed from the mock_wire_data_file parameter.
        * start() - starts the patchers for http_request and CryptUtil (patchers that are already started are left untouched)
        * stop() - stops the patchers (it is safe to call it more than once)
        * track_url(url) - adds the given item to the list of tracked urls.
        * get_tracked_urls() - returns the list of tracked urls.
        * set_http_handlers(http_get_handler=None, http_post_handler=None, http_put_handler=None) - replaces the HTTP handlers
          and clears the list of tracked urls.

    NOTE: This function patches common.utils.restutil.http_request and common.protocol.wire.CryptUtil; you need to be aware of this if your
          tests patch those methods or others in the call stack (e.g. restutil.get, restutil._http_request, etc)

    """
    tracked_urls = []

    # use a helper function to keep the HTTP handlers (they need to be modified by set_http_handlers() and
    # Python 2.* does not support nonlocal declarations)
    def http_handlers(get, post, put):
        http_handlers.get = get
        http_handlers.post = post
        http_handlers.put = put
        # a new set of handlers starts with a clean list of tracked urls
        del tracked_urls[:]
    http_handlers(get=http_get_handler, post=http_post_handler, put=http_put_handler)

    #
    # function used to patch restutil.http_request
    #
    original_http_request = restutil.http_request

    def http_request(method, url, data, **kwargs):
        # if there is a handler for the request, use it
        handler = None
        if method == 'GET':
            handler = http_handlers.get
        elif method == 'POST':
            handler = http_handlers.post
        elif method == 'PUT':
            handler = http_handlers.put

        if handler is not None:
            # GET handlers do not receive the request data
            if method == 'GET':
                return_value = handler(url, **kwargs)
            else:
                return_value = handler(url, data, **kwargs)
            if return_value is not None:
                tracked_urls.append(url)
                if isinstance(return_value, Exception):
                    raise return_value
                return return_value

        # if the request was not handled try to use the mock wireserver data
        if _USE_MOCK_WIRE_DATA_RE.match(url) is not None:
            if method == 'GET':
                return protocol.mock_wire_data.mock_http_get(url, **kwargs)
            if method == 'POST':
                return protocol.mock_wire_data.mock_http_post(url, data, **kwargs)
            if method == 'PUT':
                return protocol.mock_wire_data.mock_http_put(url, data, **kwargs)

        # the request was not handled; fail or call the original restutil.http_request
        if fail_on_unknown_request:
            raise ValueError('Unknown HTTP request: {0} [{1}]'.format(url, method))
        return original_http_request(method, url, data, **kwargs)

    #
    # functions to start/stop the mocks
    #
    def start():
        # Only start the patchers that are not already active: starting a second patcher would overwrite the handle of
        # the first one, which could then never be stopped (leaking the mock beyond the lifetime of the protocol).
        if start.http_request_patch is None:
            patched = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request)
            patched.start()
            start.http_request_patch = patched

        if start.crypt_util_patch is None:
            patched = patch("azurelinuxagent.common.protocol.wire.CryptUtil", side_effect=protocol.mock_wire_data.mock_crypt_util)
            patched.start()
            start.crypt_util_patch = patched
    start.http_request_patch = None
    start.crypt_util_patch = None

    def stop():
        # Clear the handles before stopping so that calling stop() again is a no-op (some versions of mock raise
        # when stopping a patcher that is not started).
        crypt_util_patch, start.crypt_util_patch = start.crypt_util_patch, None
        http_request_patch, start.http_request_patch = start.http_request_patch, None
        # stop in the reverse order of start(); the http_request patch is stopped even if stopping CryptUtil fails
        try:
            if crypt_util_patch is not None:
                crypt_util_patch.stop()
        finally:
            if http_request_patch is not None:
                http_request_patch.stop()

    #
    # create the protocol object
    #
    protocol = WireProtocol(restutil.KNOWN_WIRESERVER_IP)
    protocol.mock_wire_data = mockwiredata.WireProtocolData(mock_wire_data_file)
    protocol.start = start
    protocol.stop = stop
    protocol.track_url = lambda url: tracked_urls.append(url)  # pylint: disable=unnecessary-lambda
    protocol.get_tracked_urls = lambda: tracked_urls
    protocol.set_http_handlers = lambda http_get_handler=None, http_post_handler=None, http_put_handler=None:\
        http_handlers(get=http_get_handler, post=http_post_handler, put=http_put_handler)

    # go do it
    try:
        protocol.start()
        protocol.detect()
        yield protocol
    finally:
        protocol.stop()


class HttpRequestPredicates(object):
    """
    Utility functions to check the urls used by tests
    """
    @staticmethod
    def is_goal_state_request(url):
        return url.lower() == 'http://{0}/machine/?comp=goalstate'.format(restutil.KNOWN_WIRESERVER_IP)

    @staticmethod
    def is_telemetry_request(url):
        return url.lower() == 'http://{0}/machine?comp=telemetrydata'.format(restutil.KNOWN_WIRESERVER_IP)

    @staticmethod
    def is_health_service_request(url):
        return url.lower() == 'http://{0}:80/healthservice'.format(restutil.KNOWN_WIRESERVER_IP)

    @staticmethod
    def is_in_vm_artifacts_profile_request(url):
        return re.match(r'https://.+\.blob\.core\.windows\.net/\$system/.+\.(vmSettings|settings)\?.+', url) is not None

    @staticmethod
    def _get_host_plugin_request_artifact_location(url, request_kwargs):
        # the artifact location is passed to the host plugin in the x-ms-artifact-location header of the request
        if 'headers' not in request_kwargs:
            raise ValueError('Host plugin request is missing HTTP headers ({0})'.format(url))
        headers = request_kwargs['headers']
        if 'x-ms-artifact-location' not in headers:
            raise ValueError('Host plugin request is missing the x-ms-artifact-location header ({0})'.format(url))
        return headers['x-ms-artifact-location']

    @staticmethod
    def is_host_plugin_health_request(url):
        return url.lower() == 'http://{0}:{1}/health'.format(restutil.KNOWN_WIRESERVER_IP, restutil.HOST_PLUGIN_PORT)

    @staticmethod
    def is_host_plugin_extension_artifact_request(url):
        return url.lower() == 'http://{0}:{1}/extensionartifact'.format(restutil.KNOWN_WIRESERVER_IP, restutil.HOST_PLUGIN_PORT)

    @staticmethod
    def is_host_plugin_status_request(url):
        return url.lower() == 'http://{0}:{1}/status'.format(restutil.KNOWN_WIRESERVER_IP, restutil.HOST_PLUGIN_PORT)

    @staticmethod
    def is_host_plugin_extension_request(request_url, request_kwargs, extension_url):
        if not HttpRequestPredicates.is_host_plugin_extension_artifact_request(request_url):
            return False
        artifact_location = HttpRequestPredicates._get_host_plugin_request_artifact_location(request_url, request_kwargs)
        return artifact_location == extension_url

    @staticmethod
    def is_host_plugin_in_vm_artifacts_profile_request(url, request_kwargs):
        if not HttpRequestPredicates.is_host_plugin_extension_artifact_request(url):
            return False
        artifact_location = HttpRequestPredicates._get_host_plugin_request_artifact_location(url, request_kwargs)
        return HttpRequestPredicates.is_in_vm_artifacts_profile_request(artifact_location)

    @staticmethod
    def is_host_plugin_put_logs_request(url):
        return url.lower() == 'http://{0}:{1}/vmagentlog'.format(restutil.KNOWN_WIRESERVER_IP,
                                                                  restutil.HOST_PLUGIN_PORT)


class MockHttpResponse(object):
    """
    Minimal stand-in for an HTTP response: exposes a 'status' attribute and a read() method that returns the body
    """
    def __init__(self, status, body=''):
        self.body = body
        self.status = status

    def read(self, *_):
        # any arguments (e.g. the number of bytes to read) are ignored; the whole body is always returned
        return self.body