1# Copyright 2020 Microsoft Corporation
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15# Requires Python 2.6+ and Openssl 1.0+
16#
17import contextlib
18import re
19from azurelinuxagent.common.protocol.wire import WireProtocol
20from azurelinuxagent.common.utils import restutil
21from tests.tools import patch
22from tests.protocol import mockwiredata
23
# Regex deciding whether a request is served from the mock wireserver data: it matches (case-insensitively) http/https
# urls whose host is either the fake storage server 'mock-goal-state' or the WireServer IP address.
_USE_MOCK_WIRE_DATA_RE = re.compile(
    r'https?://(mock-goal-state|{0}).*'.format(re.escape(restutil.KNOWN_WIRESERVER_IP)),
    re.IGNORECASE
)
27
28
@contextlib.contextmanager
def mock_wire_protocol(mock_wire_data_file, http_get_handler=None, http_post_handler=None, http_put_handler=None, fail_on_unknown_request=True):
    """
    Creates a WireProtocol object that handles requests to the WireServer and the Host GA Plugin (i.e requests on the WireServer endpoint), plus
    some requests to storage (requests on the fake server 'mock-goal-state').

    The data returned by those requests is read from the files specified by 'mock_wire_data_file' (which must follow the structure of the data
    files defined in tests/protocol/mockwiredata.py).

    The caller can also provide handler functions for specific HTTP methods using the http_*_handler arguments. GET handlers are invoked as
    handler(url, **kwargs); POST and PUT handlers are invoked as handler(url, data, **kwargs). The return value of the handler function is
    interpreted similarly to the "return_value" argument of patch(): if it is an exception the exception is raised or, if it is any object
    other than None, the value is returned by the mock. If the handler function returns None the call is handled using the mock wireserver
    data (for urls on the WireServer or on 'mock-goal-state') or, otherwise, it is either rejected with a ValueError or passed to the original
    restutil.http_request, depending on 'fail_on_unknown_request'.

    The returned protocol object maintains a list of "tracked" urls. When a handler function returns a value that is not None the url for the
    request is automatically added to the tracked list. The handler function can add other items to this list using the track_url() method on
    the mock.

    The return value of this function is an instance of WireProtocol augmented with these properties/methods:

        * mock_wire_data - the WireProtocolData constructed from the mock_wire_data_file parameter.
        * start() - starts the patchers for http_request and CryptUtil
        * stop() - stops the patchers; it is safe to call it more than once (e.g. explicitly from a test and then again when the
          context exits).
        * track_url(url) - adds the given item to the list of tracked urls.
        * get_tracked_urls() - returns the list of tracked urls.
        * set_http_handlers(http_get_handler=None, http_post_handler=None, http_put_handler=None) - replaces the HTTP handlers and
          clears the list of tracked urls.

    :param mock_wire_data_file: data file specification passed to mockwiredata.WireProtocolData.
    :param http_get_handler: optional handler for GET requests (see above).
    :param http_post_handler: optional handler for POST requests (see above).
    :param http_put_handler: optional handler for PUT requests (see above).
    :param fail_on_unknown_request: when True (the default), a request that is neither handled by a handler nor served by the mock
        wireserver data raises ValueError; when False it is passed to the original restutil.http_request.

    NOTE: This function patches common.utils.restutil.http_request and common.protocol.wire.CryptUtil; you need to be aware of this if your
          tests patch those methods or others in the call stack (e.g. restutil.get, restutil._http_request, etc)

    """
    tracked_urls = []

    # use a helper function to keep the HTTP handlers (they need to be modified by set_http_handlers() and
    # Python 2.* does not support nonlocal declarations); replacing the handlers also resets the tracked urls
    def http_handlers(get, post, put):
        http_handlers.get = get
        http_handlers.post = post
        http_handlers.put = put
        del tracked_urls[:]
    http_handlers(get=http_get_handler, post=http_post_handler, put=http_put_handler)

    #
    # function used to patch restutil.http_request
    #
    original_http_request = restutil.http_request

    def http_request(method, url, data, **kwargs):
        # if there is a handler for the request, use it; the handlers are looked up on every call because
        # set_http_handlers() can replace them while the mock is active
        handler = {'GET': http_handlers.get, 'POST': http_handlers.post, 'PUT': http_handlers.put}.get(method)

        if handler is not None:
            # GET handlers are called without the 'data' argument
            if method == 'GET':
                return_value = handler(url, **kwargs)
            else:
                return_value = handler(url, data, **kwargs)
            if return_value is not None:
                tracked_urls.append(url)
                if isinstance(return_value, Exception):
                    raise return_value
                return return_value

        # if the request was not handled try to use the mock wireserver data
        if _USE_MOCK_WIRE_DATA_RE.match(url) is not None:
            if method == 'GET':
                return protocol.mock_wire_data.mock_http_get(url, **kwargs)
            if method == 'POST':
                return protocol.mock_wire_data.mock_http_post(url, data, **kwargs)
            if method == 'PUT':
                return protocol.mock_wire_data.mock_http_put(url, data, **kwargs)

        # the request was not handled; fail or call the original restutil.http_request
        if fail_on_unknown_request:
            raise ValueError('Unknown HTTP request: {0} [{1}]'.format(url, method))
        return original_http_request(method, url, data, **kwargs)

    #
    # functions to start/stop the mocks
    #
    def start():
        patched = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request)
        patched.start()
        start.http_request_patch = patched

        patched = patch("azurelinuxagent.common.protocol.wire.CryptUtil", side_effect=protocol.mock_wire_data.mock_crypt_util)
        patched.start()
        start.crypt_util_patch = patched
    start.http_request_patch = None
    start.crypt_util_patch = None

    def stop():
        # Detach the patchers before stopping them so that a second call to stop() (e.g. the 'finally' below, after a test
        # already stopped the mock) does not stop the same patcher twice. NOTE(review): some versions of mock raise
        # RuntimeError when stopping a patcher that is not active.
        crypt_util_patch, start.crypt_util_patch = start.crypt_util_patch, None
        http_request_patch, start.http_request_patch = start.http_request_patch, None
        # stop in reverse order of start(); the http_request patch is stopped even if stopping the CryptUtil patch fails,
        # so it does not leak into other tests
        try:
            if crypt_util_patch is not None:
                crypt_util_patch.stop()
        finally:
            if http_request_patch is not None:
                http_request_patch.stop()

    #
    # create the protocol object
    #
    protocol = WireProtocol(restutil.KNOWN_WIRESERVER_IP)
    protocol.mock_wire_data = mockwiredata.WireProtocolData(mock_wire_data_file)
    protocol.start = start
    protocol.stop = stop
    protocol.track_url = lambda url: tracked_urls.append(url)  # pylint: disable=unnecessary-lambda
    protocol.get_tracked_urls = lambda: tracked_urls
    protocol.set_http_handlers = lambda http_get_handler=None, http_post_handler=None, http_put_handler=None:\
        http_handlers(get=http_get_handler, post=http_post_handler, put=http_put_handler)

    # go do it
    try:
        protocol.start()
        protocol.detect()
        yield protocol
    finally:
        protocol.stop()
149
150
class HttpRequestPredicates(object):
    """
    Utility functions to check the urls used by tests
    """
    # Matches https urls of blobs in the '$system' container of Azure storage whose name ends in '.vmSettings' or '.settings'
    # and that carry a query string; compiled once and reused by is_in_vm_artifacts_profile_request().
    _IN_VM_ARTIFACTS_PROFILE_URL_RE = re.compile(r'https://.+\.blob\.core\.windows\.net/\$system/.+\.(vmSettings|settings)\?.+')

    @staticmethod
    def is_goal_state_request(url):
        """True if 'url' is (case-insensitively) the WireServer goal state url."""
        return url.lower() == 'http://{0}/machine/?comp=goalstate'.format(restutil.KNOWN_WIRESERVER_IP)

    @staticmethod
    def is_telemetry_request(url):
        """True if 'url' is (case-insensitively) the WireServer telemetry url."""
        return url.lower() == 'http://{0}/machine?comp=telemetrydata'.format(restutil.KNOWN_WIRESERVER_IP)

    @staticmethod
    def is_health_service_request(url):
        """True if 'url' is (case-insensitively) the health service url on port 80 of the WireServer."""
        return url.lower() == 'http://{0}:80/healthservice'.format(restutil.KNOWN_WIRESERVER_IP)

    @staticmethod
    def is_in_vm_artifacts_profile_request(url):
        """True if 'url' is a direct (storage) request for the in-VM artifacts profile (case-sensitive)."""
        return HttpRequestPredicates._IN_VM_ARTIFACTS_PROFILE_URL_RE.match(url) is not None

    @staticmethod
    def _get_host_plugin_request_artifact_location(url, request_kwargs):
        """
        Returns the value of the 'x-ms-artifact-location' header of a Host GA Plugin request.

        :param url: the request url (used only in error messages).
        :param request_kwargs: the keyword arguments of the request; the headers are expected under 'headers'.
        :raises ValueError: if the request has no headers (missing or None) or lacks the 'x-ms-artifact-location' header.
        """
        # 'headers' may be present but None (e.g. http_request(..., headers=None)); treat that as missing headers
        # rather than failing with a TypeError on the membership test below
        headers = request_kwargs.get('headers')
        if headers is None:
            raise ValueError('Host plugin request is missing HTTP headers ({0})'.format(url))
        if 'x-ms-artifact-location' not in headers:
            raise ValueError('Host plugin request is missing the x-ms-artifact-location header ({0})'.format(url))
        return headers['x-ms-artifact-location']

    @staticmethod
    def is_host_plugin_health_request(url):
        """True if 'url' is (case-insensitively) the Host GA Plugin health url."""
        return url.lower() == 'http://{0}:{1}/health'.format(restutil.KNOWN_WIRESERVER_IP, restutil.HOST_PLUGIN_PORT)

    @staticmethod
    def is_host_plugin_extension_artifact_request(url):
        """True if 'url' is (case-insensitively) the Host GA Plugin extension artifact url."""
        return url.lower() == 'http://{0}:{1}/extensionartifact'.format(restutil.KNOWN_WIRESERVER_IP, restutil.HOST_PLUGIN_PORT)

    @staticmethod
    def is_host_plugin_status_request(url):
        """True if 'url' is (case-insensitively) the Host GA Plugin status url."""
        return url.lower() == 'http://{0}:{1}/status'.format(restutil.KNOWN_WIRESERVER_IP, restutil.HOST_PLUGIN_PORT)

    @staticmethod
    def is_host_plugin_extension_request(request_url, request_kwargs, extension_url):
        """
        True if the request is a Host GA Plugin extension artifact request whose artifact location is exactly 'extension_url'.

        :raises ValueError: if it is an extension artifact request but the artifact location header is missing.
        """
        if not HttpRequestPredicates.is_host_plugin_extension_artifact_request(request_url):
            return False
        artifact_location = HttpRequestPredicates._get_host_plugin_request_artifact_location(request_url, request_kwargs)
        return artifact_location == extension_url

    @staticmethod
    def is_host_plugin_in_vm_artifacts_profile_request(url, request_kwargs):
        """
        True if the request is a Host GA Plugin extension artifact request whose artifact location is an in-VM artifacts profile url.

        :raises ValueError: if it is an extension artifact request but the artifact location header is missing.
        """
        if not HttpRequestPredicates.is_host_plugin_extension_artifact_request(url):
            return False
        artifact_location = HttpRequestPredicates._get_host_plugin_request_artifact_location(url, request_kwargs)
        return HttpRequestPredicates.is_in_vm_artifacts_profile_request(artifact_location)

    @staticmethod
    def is_host_plugin_put_logs_request(url):
        """True if 'url' is (case-insensitively) the Host GA Plugin url used to upload the agent logs."""
        return url.lower() == 'http://{0}:{1}/vmagentlog'.format(restutil.KNOWN_WIRESERVER_IP, restutil.HOST_PLUGIN_PORT)
210
211
class MockHttpResponse(object):
    """
    Mock HTTP response exposing a 'status' attribute and a read() method that returns 'body'.
    """
    def __init__(self, status, body=''):
        self.body = body
        self.status = status

    def read(self, *_):
        # any arguments (e.g. a number of bytes to read) are ignored: the whole body is returned on every call
        return self.body
219