1# -*- encoding: utf-8 -*- 2# Copyright 2018 Microsoft Corporation 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# 16# Requires Python 2.6+ and Openssl 1.0+ 17# 18 19import contextlib 20import json 21import os 22import re 23import socket 24import time 25import unittest 26import uuid 27from datetime import datetime, timedelta 28 29from azurelinuxagent.common import conf 30from azurelinuxagent.common.agent_supported_feature import SupportedFeatureNames, get_supported_feature_by_name, \ 31 get_agent_supported_features_list_for_crp 32from azurelinuxagent.common.exception import ResourceGoneError, ProtocolError, \ 33 ExtensionDownloadError, HttpError 34from azurelinuxagent.common.protocol.goal_state import ExtensionsConfig, GoalState 35from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol 36from azurelinuxagent.common.protocol.restapi import VMAgentManifestUri 37from azurelinuxagent.common.protocol.wire import WireProtocol, WireClient, \ 38 InVMArtifactsProfile, StatusBlob, VMStatus, EXT_CONF_FILE_NAME 39from azurelinuxagent.common.telemetryevent import GuestAgentExtensionEventsSchema, \ 40 TelemetryEventParam, TelemetryEvent 41from azurelinuxagent.common.utils import restutil 42from azurelinuxagent.common.exception import IncompleteGoalStateError 43from azurelinuxagent.common.version import CURRENT_VERSION, DISTRO_NAME, DISTRO_VERSION 44from azurelinuxagent.ga.exthandlers import get_exthandlers_handler 45from tests.ga.test_monitor import random_generator 46from tests.protocol import mockwiredata 47from tests.protocol.mocks import mock_wire_protocol, HttpRequestPredicates 48from tests.protocol.mockwiredata import DATA_FILE_NO_EXT, DATA_FILE 49from tests.protocol.mockwiredata import WireProtocolData 50from tests.tools import Mock, patch, AgentTestCase 51 52data_with_bom = b'\xef\xbb\xbfhehe' 53testurl = 'http://foo' 54testtype = 'BlockBlob' 55WIRESERVER_URL = '168.63.129.16' 56 57 58def get_event(message, duration=30000, evt_type="", is_internal=False, is_success=True, 59 name="", op="Unknown", version=CURRENT_VERSION, eventId=1): 60 event = TelemetryEvent(eventId, "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX") 61 event.parameters.append(TelemetryEventParam(GuestAgentExtensionEventsSchema.Name, name)) 62 event.parameters.append(TelemetryEventParam(GuestAgentExtensionEventsSchema.Version, str(version))) 63 event.parameters.append(TelemetryEventParam(GuestAgentExtensionEventsSchema.IsInternal, is_internal)) 64 event.parameters.append(TelemetryEventParam(GuestAgentExtensionEventsSchema.Operation, op)) 65 event.parameters.append(TelemetryEventParam(GuestAgentExtensionEventsSchema.OperationSuccess, is_success)) 66 event.parameters.append(TelemetryEventParam(GuestAgentExtensionEventsSchema.Message, message)) 67 event.parameters.append(TelemetryEventParam(GuestAgentExtensionEventsSchema.Duration, duration)) 68 event.parameters.append(TelemetryEventParam(GuestAgentExtensionEventsSchema.ExtensionType, evt_type)) 69 return event 70 71 72@contextlib.contextmanager 73def create_mock_protocol(artifacts_profile_blob=None, status_upload_blob=None, status_upload_blob_type=None): 74 with mock_wire_protocol(DATA_FILE_NO_EXT) as protocol: 75 # These tests use mock wire data that dont have any extensions (extension config will be empty). 76 # Populate the upload blob and artifacts profile blob. 77 ext_conf = ExtensionsConfig(None) 78 ext_conf.artifacts_profile_blob = artifacts_profile_blob 79 ext_conf.status_upload_blob = status_upload_blob 80 ext_conf.status_upload_blob_type = status_upload_blob_type 81 protocol.client._goal_state.ext_conf = ext_conf # pylint: disable=protected-access 82 83 yield protocol 84 85 86@patch("time.sleep") 87@patch("azurelinuxagent.common.protocol.wire.CryptUtil") 88@patch("azurelinuxagent.common.protocol.healthservice.HealthService._report") 89class TestWireProtocol(AgentTestCase): 90 91 def setUp(self): 92 super(TestWireProtocol, self).setUp() 93 HostPluginProtocol.is_default_channel = False 94 95 def _test_getters(self, test_data, certsMustBePresent, __, MockCryptUtil, _): 96 MockCryptUtil.side_effect = test_data.mock_crypt_util 97 98 with patch.object(restutil, 'http_get', test_data.mock_http_get): 99 protocol = WireProtocol(WIRESERVER_URL) 100 protocol.detect() 101 protocol.get_vminfo() 102 protocol.get_certs() 103 ext_handlers, etag = protocol.get_ext_handlers() # pylint: disable=unused-variable 104 for ext_handler in ext_handlers.extHandlers: 105 protocol.get_ext_handler_pkgs(ext_handler) 106 107 crt1 = os.path.join(self.tmp_dir, 108 '33B0ABCE4673538650971C10F7D7397E71561F35.crt') 109 crt2 = os.path.join(self.tmp_dir, 110 '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt') 111 prv2 = os.path.join(self.tmp_dir, 112 '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv') 113 if certsMustBePresent: 114 self.assertTrue(os.path.isfile(crt1)) 115 self.assertTrue(os.path.isfile(crt2)) 116 self.assertTrue(os.path.isfile(prv2)) 117 else: 118 self.assertFalse(os.path.isfile(crt1)) 119 self.assertFalse(os.path.isfile(crt2)) 120 self.assertFalse(os.path.isfile(prv2)) 121 self.assertEqual("1", protocol.get_incarnation()) 122 123 @staticmethod 124 def _get_telemetry_events_generator(event_list): 125 def _yield_events(): 126 for telemetry_event in event_list: 127 yield telemetry_event 128 129 return _yield_events() 130 131 def test_getters(self, *args): 132 """Normal case""" 133 test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE) 134 self._test_getters(test_data, True, *args) 135 136 def test_getters_no_ext(self, *args): 137 """Provision with agent is not checked""" 138 test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_NO_EXT) 139 self._test_getters(test_data, True, *args) 140 141 def test_getters_ext_no_settings(self, *args): 142 """Extensions without any settings""" 143 test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_EXT_NO_SETTINGS) 144 self._test_getters(test_data, True, *args) 145 146 def test_getters_ext_no_public(self, *args): 147 """Extensions without any public settings""" 148 test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_EXT_NO_PUBLIC) 149 self._test_getters(test_data, True, *args) 150 151 def test_getters_ext_no_cert_format(self, *args): 152 """Certificate format not specified""" 153 test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_NO_CERT_FORMAT) 154 self._test_getters(test_data, True, *args) 155 156 def test_getters_ext_cert_format_not_pfx(self, *args): 157 """Certificate format is not Pkcs7BlobWithPfxContents specified""" 158 test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_CERT_FORMAT_NOT_PFX) 159 self._test_getters(test_data, False, *args) 160 161 @patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_extension_artifact") 162 def test_getters_with_stale_goal_state(self, patch_report, *args): 163 test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE) 164 test_data.emulate_stale_goal_state = True 165 166 self._test_getters(test_data, True, *args) 167 # Ensure HostPlugin was invoked 168 self.assertEqual(1, test_data.call_counts["/versions"]) 169 self.assertEqual(2, test_data.call_counts["extensionArtifact"]) 170 # Ensure the expected number of HTTP calls were made 171 # -- Tracking calls to retrieve GoalState is problematic since it is 172 # fetched often; however, the dependent documents, such as the 173 # HostingEnvironmentConfig, will be retrieved the expected number 174 self.assertEqual(1, test_data.call_counts["hostingenvuri"]) 175 self.assertEqual(1, patch_report.call_count) 176 177 def test_call_storage_kwargs(self, *args): # pylint: disable=unused-argument 178 with patch.object(restutil, 'http_get') as http_patch: 179 http_req = restutil.http_get 180 url = testurl 181 headers = {} 182 183 # no kwargs -- Default to True 184 WireClient.call_storage_service(http_req) 185 186 # kwargs, no use_proxy -- Default to True 187 WireClient.call_storage_service(http_req, 188 url, 189 headers) 190 191 # kwargs, use_proxy None -- Default to True 192 WireClient.call_storage_service(http_req, 193 url, 194 headers, 195 use_proxy=None) 196 197 # kwargs, use_proxy False -- Keep False 198 WireClient.call_storage_service(http_req, 199 url, 200 headers, 201 use_proxy=False) 202 203 # kwargs, use_proxy True -- Keep True 204 WireClient.call_storage_service(http_req, 205 url, 206 headers, 207 use_proxy=True) 208 # assert 209 self.assertTrue(http_patch.call_count == 5) 210 for i in range(0, 5): 211 c = http_patch.call_args_list[i][-1]['use_proxy'] 212 self.assertTrue(c == (True if i != 3 else False)) 213 214 def test_status_blob_parsing(self, *args): # pylint: disable=unused-argument 215 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 216 self.assertEqual(protocol.client.get_ext_conf().status_upload_blob, 217 'https://test.blob.core.windows.net/vhds/test-cs12.test-cs12.test-cs12.status?' 218 'sr=b&sp=rw&se=9999-01-01&sk=key1&sv=2014-02-14&' 219 'sig=hfRh7gzUE7sUtYwke78IOlZOrTRCYvkec4hGZ9zZzXo') 220 self.assertEqual(protocol.client.get_ext_conf().status_upload_blob_type, u'BlockBlob') 221 222 def test_get_host_ga_plugin(self, *args): # pylint: disable=unused-argument 223 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 224 host_plugin = protocol.client.get_host_plugin() 225 goal_state = protocol.client.get_goal_state() 226 self.assertEqual(goal_state.container_id, host_plugin.container_id) 227 self.assertEqual(goal_state.role_config_name, host_plugin.role_config_name) 228 229 def test_upload_status_blob_should_use_the_host_channel_by_default(self, *_): 230 def http_put_handler(url, *_, **__): # pylint: disable=inconsistent-return-statements 231 if protocol.get_endpoint() in url and url.endswith('/status'): 232 return MockResponse(body=b'', status_code=200) 233 234 with mock_wire_protocol(mockwiredata.DATA_FILE, http_put_handler=http_put_handler) as protocol: 235 HostPluginProtocol.is_default_channel = False 236 protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready") 237 238 protocol.client.upload_status_blob() 239 240 urls = protocol.get_tracked_urls() 241 self.assertEqual(len(urls), 1, 'Expected one post request to the host: [{0}]'.format(urls)) 242 243 def test_upload_status_blob_host_ga_plugin(self, *_): 244 with create_mock_protocol(status_upload_blob=testurl, status_upload_blob_type=testtype) as protocol: 245 protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready") 246 247 with patch.object(HostPluginProtocol, "ensure_initialized", return_value=True): 248 with patch.object(StatusBlob, "upload", return_value=False) as patch_default_upload: 249 with patch.object(HostPluginProtocol, "_put_block_blob_status") as patch_http: 250 HostPluginProtocol.is_default_channel = False 251 protocol.client.upload_status_blob() 252 patch_default_upload.assert_not_called() 253 patch_http.assert_called_once_with(testurl, protocol.client.status_blob) 254 self.assertFalse(HostPluginProtocol.is_default_channel) 255 256 @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized") 257 def test_upload_status_blob_unknown_type_assumes_block(self, *_): 258 with create_mock_protocol(status_upload_blob=testurl, status_upload_blob_type="NotALegalType") as protocol: 259 protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready") 260 261 with patch.object(StatusBlob, "prepare") as patch_prepare: 262 with patch.object(StatusBlob, "upload") as patch_default_upload: 263 HostPluginProtocol.is_default_channel = False 264 protocol.client.upload_status_blob() 265 266 patch_prepare.assert_called_once_with("BlockBlob") 267 patch_default_upload.assert_called_once_with(testurl) 268 269 def test_upload_status_blob_reports_prepare_error(self, *_): 270 with create_mock_protocol(status_upload_blob=testurl, status_upload_blob_type=testtype) as protocol: 271 protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready") 272 273 with patch.object(StatusBlob, "prepare", side_effect=Exception) as mock_prepare: 274 self.assertRaises(ProtocolError, protocol.client.upload_status_blob) 275 self.assertEqual(1, mock_prepare.call_count) 276 277 def test_get_in_vm_artifacts_profile_blob_not_available(self, *_): 278 # Test when artifacts_profile_blob is null/None 279 with mock_wire_protocol(DATA_FILE_NO_EXT) as protocol: 280 protocol.client._goal_state.ext_conf = ExtensionsConfig(None) # pylint: disable=protected-access 281 282 self.assertEqual(None, protocol.client.get_artifacts_profile()) 283 284 # Test when artifacts_profile_blob is whitespace 285 with create_mock_protocol(artifacts_profile_blob=" ") as protocol: 286 self.assertEqual(None, protocol.client.get_artifacts_profile()) 287 288 def test_get_in_vm_artifacts_profile_response_body_not_valid(self, *_): 289 with create_mock_protocol(artifacts_profile_blob=testurl) as protocol: 290 with patch.object(HostPluginProtocol, "get_artifact_request", return_value=['dummy_url', {}]) as host_plugin_get_artifact_url_and_headers: 291 # Test when response body is None 292 protocol.client.call_storage_service = Mock(return_value=MockResponse(None, 200)) 293 in_vm_artifacts_profile = protocol.client.get_artifacts_profile() 294 self.assertTrue(in_vm_artifacts_profile is None) 295 296 # Test when response body is None 297 protocol.client.call_storage_service = Mock(return_value=MockResponse(' '.encode('utf-8'), 200)) 298 in_vm_artifacts_profile = protocol.client.get_artifacts_profile() 299 self.assertTrue(in_vm_artifacts_profile is None) 300 301 # Test when response body is None 302 protocol.client.call_storage_service = Mock(return_value=MockResponse('{ }'.encode('utf-8'), 200)) 303 in_vm_artifacts_profile = protocol.client.get_artifacts_profile() 304 self.assertEqual(dict(), in_vm_artifacts_profile.__dict__, 305 'If artifacts_profile_blob has empty json dictionary, in_vm_artifacts_profile ' 306 'should contain nothing') 307 308 host_plugin_get_artifact_url_and_headers.assert_called_with(testurl) 309 310 @patch("azurelinuxagent.common.event.add_event") 311 def test_artifacts_profile_json_parsing(self, patch_event, *args): # pylint: disable=unused-argument 312 with create_mock_protocol(artifacts_profile_blob=testurl) as protocol: 313 # response is invalid json 314 protocol.client.call_storage_service = Mock(return_value=MockResponse("invalid json".encode('utf-8'), 200)) 315 in_vm_artifacts_profile = protocol.client.get_artifacts_profile() 316 317 # ensure response is empty 318 self.assertEqual(None, in_vm_artifacts_profile) 319 320 # ensure event is logged 321 self.assertEqual(1, patch_event.call_count) 322 self.assertFalse(patch_event.call_args[1]['is_success']) 323 self.assertTrue('invalid json' in patch_event.call_args[1]['message']) 324 self.assertEqual('ArtifactsProfileBlob', patch_event.call_args[1]['op']) 325 326 def test_get_in_vm_artifacts_profile_default(self, *args): # pylint: disable=unused-argument 327 with create_mock_protocol(artifacts_profile_blob=testurl) as protocol: 328 protocol.client.call_storage_service = Mock(return_value=MockResponse('{"onHold": "true"}'.encode('utf-8'), 200)) 329 in_vm_artifacts_profile = protocol.client.get_artifacts_profile() 330 self.assertEqual(dict(onHold='true'), in_vm_artifacts_profile.__dict__) 331 self.assertTrue(in_vm_artifacts_profile.is_on_hold()) 332 333 @patch("socket.gethostname", return_value="hostname") 334 @patch("time.gmtime", return_value=time.localtime(1485543256)) 335 def test_report_vm_status(self, *args): # pylint: disable=unused-argument 336 status = 'status' 337 message = 'message' 338 339 client = WireProtocol(WIRESERVER_URL).client 340 actual = StatusBlob(client=client) 341 actual.set_vm_status(VMStatus(status=status, message=message)) 342 timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) 343 344 formatted_msg = { 345 'lang': 'en-US', 346 'message': message 347 } 348 v1_ga_status = { 349 'version': str(CURRENT_VERSION), 350 'status': status, 351 'formattedMessage': formatted_msg 352 } 353 v1_ga_guest_info = { 354 'computerName': socket.gethostname(), 355 'osName': DISTRO_NAME, 356 'osVersion': DISTRO_VERSION, 357 'version': str(CURRENT_VERSION), 358 } 359 v1_agg_status = { 360 'guestAgentStatus': v1_ga_status, 361 'handlerAggregateStatus': [] 362 } 363 v1_vm_status = { 364 'version': '1.1', 365 'timestampUTC': timestamp, 366 'aggregateStatus': v1_agg_status, 367 'guestOSInfo': v1_ga_guest_info 368 } 369 self.assertEqual(json.dumps(v1_vm_status), actual.to_json()) 370 371 def test_it_should_report_supported_features_in_status_blob_if_supported(self, *_): 372 with mock_wire_protocol(DATA_FILE) as protocol: 373 374 def mock_http_put(url, *args, **__): 375 if not HttpRequestPredicates.is_host_plugin_status_request(url): 376 # Skip reading the HostGA request data as its encoded 377 protocol.aggregate_status = json.loads(args[0]) 378 379 protocol.aggregate_status = {} 380 protocol.set_http_handlers(http_put_handler=mock_http_put) 381 exthandlers_handler = get_exthandlers_handler(protocol) 382 383 with patch("azurelinuxagent.common.agent_supported_feature._MultiConfigFeature.is_supported", True): 384 exthandlers_handler.run() 385 self.assertIsNotNone(protocol.aggregate_status, "Aggregate status should not be None") 386 self.assertIn("supportedFeatures", protocol.aggregate_status, "supported features not reported") 387 multi_config_feature = get_supported_feature_by_name(SupportedFeatureNames.MultiConfig) 388 found = False 389 for feature in protocol.aggregate_status['supportedFeatures']: 390 if feature['Key'] == multi_config_feature.name and feature['Value'] == multi_config_feature.version: 391 found = True 392 break 393 self.assertTrue(found, "Multi-config name should be present in supportedFeatures") 394 395 # Feature should not be reported if not present 396 with patch("azurelinuxagent.common.agent_supported_feature._MultiConfigFeature.is_supported", False): 397 exthandlers_handler.run() 398 self.assertIsNotNone(protocol.aggregate_status, "Aggregate status should not be None") 399 if "supportedFeatures" not in protocol.aggregate_status: 400 # In the case Multi-config was the only feature available, 'supportedFeatures' should not be 401 # reported in the status blob as its not supported as of now. 402 # Asserting no other feature was available to report back to crp 403 self.assertEqual(1, len(get_agent_supported_features_list_for_crp()), 404 "supportedFeatures should be available if there are more features") 405 return 406 407 # If there are other features available, confirm MultiConfig was not reported 408 multi_config_feature = get_supported_feature_by_name(SupportedFeatureNames.MultiConfig) 409 found = False 410 for feature in protocol.aggregate_status['supportedFeatures']: 411 if feature['Key'] == multi_config_feature.name and feature['Value'] == multi_config_feature.version: 412 found = True 413 break 414 self.assertFalse(found, "Multi-config name should be present in supportedFeatures") 415 416 @patch("azurelinuxagent.common.utils.restutil.http_request") 417 def test_send_event(self, mock_http_request, *args): 418 mock_http_request.return_value = MockResponse("", 200) 419 420 event_str = u'a test string' 421 client = WireProtocol(WIRESERVER_URL).client 422 client.send_event("foo", event_str.encode('utf-8')) 423 424 first_call = mock_http_request.call_args_list[0] 425 args, kwargs = first_call 426 method, url, body_received = args # pylint: disable=unused-variable 427 headers = kwargs['headers'] 428 429 # the headers should include utf-8 encoding... 430 self.assertTrue("utf-8" in headers['Content-Type']) 431 # the body is not encoded, just check for equality 432 self.assertIn(event_str, body_received) 433 434 @patch("azurelinuxagent.common.protocol.wire.WireClient.send_event") 435 def test_report_event_small_event(self, patch_send_event, *args): # pylint: disable=unused-argument 436 event_list = [] 437 client = WireProtocol(WIRESERVER_URL).client 438 439 event_str = random_generator(10) 440 event_list.append(get_event(message=event_str)) 441 442 event_str = random_generator(100) 443 event_list.append(get_event(message=event_str)) 444 445 event_str = random_generator(1000) 446 event_list.append(get_event(message=event_str)) 447 448 event_str = random_generator(10000) 449 event_list.append(get_event(message=event_str)) 450 451 client.report_event(self._get_telemetry_events_generator(event_list)) 452 453 # It merges the messages into one message 454 self.assertEqual(patch_send_event.call_count, 1) 455 456 @patch("azurelinuxagent.common.protocol.wire.WireClient.send_event") 457 def test_report_event_multiple_events_to_fill_buffer(self, patch_send_event, *args): # pylint: disable=unused-argument 458 event_list = [] 459 client = WireProtocol(WIRESERVER_URL).client 460 461 event_str = random_generator(2 ** 15) 462 event_list.append(get_event(message=event_str)) 463 event_list.append(get_event(message=event_str)) 464 465 client.report_event(self._get_telemetry_events_generator(event_list)) 466 467 # It merges the messages into one message 468 self.assertEqual(patch_send_event.call_count, 2) 469 470 @patch("azurelinuxagent.common.protocol.wire.WireClient.send_event") 471 def test_report_event_large_event(self, patch_send_event, *args): # pylint: disable=unused-argument 472 event_list = [] 473 event_str = random_generator(2 ** 18) 474 event_list.append(get_event(message=event_str)) 475 client = WireProtocol(WIRESERVER_URL).client 476 client.report_event(self._get_telemetry_events_generator(event_list)) 477 478 self.assertEqual(patch_send_event.call_count, 0) 479 480 481class TestWireClient(HttpRequestPredicates, AgentTestCase): 482 def test_get_ext_conf_without_extensions_should_retrieve_vmagent_manifests_info(self, *args): # pylint: disable=unused-argument 483 # Basic test for get_ext_conf() when extensions are not present in the config. The test verifies that 484 # get_ext_conf() fetches the correct data by comparing the returned data with the test data provided the 485 # mock_wire_protocol. 486 with mock_wire_protocol(mockwiredata.DATA_FILE_NO_EXT) as protocol: 487 ext_conf = protocol.client.get_ext_conf() 488 489 ext_handlers_names = [ext_handler.name for ext_handler in ext_conf.ext_handlers.extHandlers] 490 self.assertEqual(0, len(ext_conf.ext_handlers.extHandlers), 491 "Unexpected number of extension handlers in the extension config: [{0}]".format(ext_handlers_names)) 492 vmagent_manifests = [manifest.family for manifest in ext_conf.vmagent_manifests.vmAgentManifests] 493 self.assertEqual(0, len(ext_conf.vmagent_manifests.vmAgentManifests), 494 "Unexpected number of vmagent manifests in the extension config: [{0}]".format(vmagent_manifests)) 495 self.assertIsNone(ext_conf.status_upload_blob, 496 "Status upload blob in the extension config is expected to be None") 497 self.assertIsNone(ext_conf.status_upload_blob_type, 498 "Type of status upload blob in the extension config is expected to be None") 499 self.assertIsNone(ext_conf.artifacts_profile_blob, 500 "Artifacts profile blob in the extensions config is expected to be None") 501 502 def test_get_ext_conf_with_extensions_should_retrieve_ext_handlers_and_vmagent_manifests_info(self): 503 # Basic test for get_ext_conf() when extensions are present in the config. The test verifies that get_ext_conf() 504 # fetches the correct data by comparing the returned data with the test data provided the mock_wire_protocol. 505 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 506 wire_protocol_client = protocol.client 507 ext_conf = wire_protocol_client.get_ext_conf() 508 509 ext_handlers_names = [ext_handler.name for ext_handler in ext_conf.ext_handlers.extHandlers] 510 self.assertEqual(1, len(ext_conf.ext_handlers.extHandlers), 511 "Unexpected number of extension handlers in the extension config: [{0}]".format(ext_handlers_names)) 512 vmagent_manifests = [manifest.family for manifest in ext_conf.vmagent_manifests.vmAgentManifests] 513 self.assertEqual(2, len(ext_conf.vmagent_manifests.vmAgentManifests), 514 "Unexpected number of vmagent manifests in the extension config: [{0}]".format(vmagent_manifests)) 515 self.assertEqual("https://test.blob.core.windows.net/vhds/test-cs12.test-cs12.test-cs12.status?sr=b&sp=rw" 516 "&se=9999-01-01&sk=key1&sv=2014-02-14&sig=hfRh7gzUE7sUtYwke78IOlZOrTRCYvkec4hGZ9zZzXo", 517 ext_conf.status_upload_blob, "Unexpected value for status upload blob URI") 518 self.assertEqual("BlockBlob", ext_conf.status_upload_blob_type, 519 "Unexpected status upload blob type in the extension config") 520 self.assertEqual(None, ext_conf.artifacts_profile_blob, 521 "Artifacts profile blob in the extension config should have been None") 522 523 def test_download_ext_handler_pkg_should_not_invoke_host_channel_when_direct_channel_succeeds(self): 524 extension_url = 'https://fake_host/fake_extension.zip' 525 target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') 526 527 def http_get_handler(url, *_, **__): 528 if url == extension_url: 529 return MockResponse(body=b'', status_code=200) 530 if self.is_host_plugin_extension_artifact_request(url): 531 self.fail('The host channel should not have been used') 532 return None 533 534 with mock_wire_protocol(mockwiredata.DATA_FILE, http_get_handler=http_get_handler) as protocol: 535 HostPluginProtocol.is_default_channel = False 536 537 success = protocol.download_ext_handler_pkg(extension_url, target_file) 538 539 urls = protocol.get_tracked_urls() 540 self.assertTrue(success, "The download should have succeeded") 541 self.assertEqual(len(urls), 1, "Unexpected number of HTTP requests: [{0}]".format(urls)) 542 self.assertEqual(urls[0], extension_url, "The extension should have been downloaded over the direct channel") 543 self.assertTrue(os.path.exists(target_file), "The extension package was not downloaded") 544 self.assertFalse(HostPluginProtocol.is_default_channel, "The host channel should not have been set as the default") 545 546 def test_download_ext_handler_pkg_should_use_host_channel_when_direct_channel_fails_and_set_host_as_default(self): 547 extension_url = 'https://fake_host/fake_extension.zip' 548 target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') 549 550 def http_get_handler(url, *_, **kwargs): 551 if url == extension_url: 552 return HttpError("Exception to fake an error on the direct channel") 553 if self.is_host_plugin_extension_request(url, kwargs, extension_url): 554 return MockResponse(body=b'', status_code=200) 555 return None 556 557 with mock_wire_protocol(mockwiredata.DATA_FILE, http_get_handler=http_get_handler) as protocol: 558 HostPluginProtocol.is_default_channel = False 559 560 success = protocol.download_ext_handler_pkg(extension_url, target_file) 561 562 urls = protocol.get_tracked_urls() 563 self.assertTrue(success, "The download should have succeeded") 564 self.assertEqual(len(urls), 2, "Unexpected number of HTTP requests: [{0}]".format(urls)) 565 self.assertEqual(urls[0], extension_url, "The first attempt should have been over the direct channel") 566 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The retry attempt should have been over the host channel") 567 self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') 568 self.assertTrue(HostPluginProtocol.is_default_channel, "The host channel should have been set as the default") 569 570 def test_download_ext_handler_pkg_should_retry_the_host_channel_after_refreshing_host_plugin(self): 571 extension_url = 'https://fake_host/fake_extension.zip' 572 target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') 573 574 def http_get_handler(url, *_, **kwargs): 575 if url == extension_url: 576 return HttpError("Exception to fake an error on the direct channel") 577 if self.is_host_plugin_extension_request(url, kwargs, extension_url): 578 # fake a stale goal state then succeed once the goal state has been refreshed 579 if http_get_handler.goal_state_requests == 0: 580 http_get_handler.goal_state_requests += 1 581 return ResourceGoneError("Exception to fake a stale goal") 582 return MockResponse(body=b'', status_code=200) 583 if self.is_goal_state_request(url): 584 protocol.track_url(url) # track requests for the goal state 585 return None 586 http_get_handler.goal_state_requests = 0 587 588 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 589 HostPluginProtocol.is_default_channel = False 590 591 try: 592 # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests. 593 protocol.client.get_host_plugin() 594 595 protocol.set_http_handlers(http_get_handler=http_get_handler) 596 597 success = protocol.download_ext_handler_pkg(extension_url, target_file) 598 599 urls = protocol.get_tracked_urls() 600 self.assertTrue(success, "The download should have succeeded") 601 self.assertEqual(len(urls), 4, "Unexpected number of HTTP requests: [{0}]".format(urls)) 602 self.assertEqual(urls[0], extension_url, "The first attempt should have been over the direct channel") 603 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second attempt should have been over the host channel") 604 self.assertTrue(self.is_goal_state_request(urls[2]), "The host channel should have been refreshed the goal state") 605 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The third attempt should have been over the host channel") 606 self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') 607 self.assertTrue(HostPluginProtocol.is_default_channel, "The host channel should have been set as the default") 608 finally: 609 HostPluginProtocol.is_default_channel = False 610 611 def test_download_ext_handler_pkg_should_not_change_default_channel_when_all_channels_fail(self): 612 extension_url = 'https://fake_host/fake_extension.zip' 613 target_file = os.path.join(self.tmp_dir, "fake_extension.zip") 614 615 def http_get_handler(url, *_, **kwargs): 616 if url == extension_url or self.is_host_plugin_extension_request(url, kwargs, extension_url): 617 return MockResponse(body=b"content not found", status_code=404, reason="Not Found") 618 if self.is_goal_state_request(url): 619 protocol.track_url(url) # keep track of goal state requests 620 return None 621 622 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 623 HostPluginProtocol.is_default_channel = False 624 625 # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests. 626 protocol.client.get_host_plugin() 627 628 protocol.set_http_handlers(http_get_handler=http_get_handler) 629 630 success = protocol.download_ext_handler_pkg(extension_url, target_file) 631 632 urls = protocol.get_tracked_urls() 633 self.assertFalse(success, "The download should have failed") 634 self.assertEqual(len(urls), 2, "Unexpected number of HTTP requests: [{0}]".format(urls)) 635 self.assertEqual(urls[0], extension_url, "The first attempt should have been over the direct channel") 636 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second attempt should have been over the host channel") 637 self.assertFalse(os.path.exists(target_file), "The extension package was downloaded and it shouldn't have") 638 self.assertFalse(HostPluginProtocol.is_default_channel, "The host channel should not have been set as the default") 639 640 def test_fetch_manifest_should_not_invoke_host_channel_when_direct_channel_succeeds(self): 641 manifest_url = 'https://fake_host/fake_manifest.xml' 642 manifest_xml = '<?xml version="1.0" encoding="utf-8"?><PluginVersionManifest/>' 643 644 def http_get_handler(url, *_, **__): 645 if url == manifest_url: 646 return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) 647 if url.endswith('/extensionArtifact'): 648 self.fail('The Host GA Plugin should not have been invoked') 649 return None 650 651 with mock_wire_protocol(mockwiredata.DATA_FILE, http_get_handler=http_get_handler) as protocol: 652 HostPluginProtocol.is_default_channel = False 653 654 manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) 655 656 urls = protocol.get_tracked_urls() 657 self.assertEqual(manifest, manifest_xml, 'The expected manifest was not downloaded') 658 self.assertEqual(len(urls), 1, "Unexpected number of HTTP requests: [{0}]".format(urls)) 659 self.assertEqual(urls[0], manifest_url, "The manifest should have been downloaded over the direct channel") 660 self.assertFalse(HostPluginProtocol.is_default_channel, "The default channel should not have changed") 661 662 def test_fetch_manifest_should_use_host_channel_when_direct_channel_fails_and_set_it_to_default(self): 663 manifest_url = 'https://fake_host/fake_manifest.xml' 664 manifest_xml = '<?xml version="1.0" encoding="utf-8"?><PluginVersionManifest/>' 665 666 def http_get_handler(url, *_, **kwargs): 667 if url == manifest_url: 668 return ResourceGoneError("Exception to fake an error on the direct channel") 669 if self.is_host_plugin_extension_request(url, kwargs, manifest_url): 670 return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) 671 return None 672 673 with mock_wire_protocol(mockwiredata.DATA_FILE, http_get_handler=http_get_handler) as protocol: 674 HostPluginProtocol.is_default_channel = False 675 676 try: 677 manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) 678 679 urls = protocol.get_tracked_urls() 680 self.assertEqual(manifest, manifest_xml, 'The expected manifest was not downloaded') 681 self.assertEqual(len(urls), 2, "Unexpected number of HTTP requests: [{0}]".format(urls)) 682 self.assertEqual(urls[0], manifest_url, "The first attempt should have been over the direct channel") 683 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The retry should have been over the host channel") 684 self.assertTrue(HostPluginProtocol.is_default_channel, "The host should have been set as the default channel") 685 finally: 686 HostPluginProtocol.is_default_channel = False # Reset default channel 687 688 def test_fetch_manifest_should_retry_the_host_channel_after_refreshing_the_host_plugin_and_set_the_host_as_default(self): 689 manifest_url = 'https://fake_host/fake_manifest.xml' 690 manifest_xml = '<?xml version="1.0" encoding="utf-8"?><PluginVersionManifest/>' 691 692 def http_get_handler(url, *_, **kwargs): 693 if url == manifest_url: 694 return HttpError("Exception to fake an error on the direct channel") 695 if self.is_host_plugin_extension_request(url, kwargs, manifest_url): 696 # fake a stale goal state then succeed once the goal state has been refreshed 697 if http_get_handler.goal_state_requests == 0: 698 http_get_handler.goal_state_requests += 1 699 return ResourceGoneError("Exception to fake a stale goal state") 700 return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) 701 elif self.is_goal_state_request(url): 702 protocol.track_url(url) # keep track of goal state requests 703 return None 704 http_get_handler.goal_state_requests = 0 705 706 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 707 HostPluginProtocol.is_default_channel = False 708 709 try: 710 # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests. 711 protocol.client.get_host_plugin() 712 713 protocol.set_http_handlers(http_get_handler=http_get_handler) 714 manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) 715 716 urls = protocol.get_tracked_urls() 717 self.assertEqual(manifest, manifest_xml) 718 self.assertEqual(len(urls), 4, "Unexpected number of HTTP requests: [{0}]".format(urls)) 719 self.assertEqual(urls[0], manifest_url, "The first attempt should have been over the direct channel") 720 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second attempt should have been over the host channel") 721 self.assertTrue(self.is_goal_state_request(urls[2]), "The host channel should have been refreshed the goal state") 722 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The third attempt should have been over the host channel") 723 self.assertTrue(HostPluginProtocol.is_default_channel, "The host should have been set as the default channel") 724 finally: 725 HostPluginProtocol.is_default_channel = False # Reset default channel 726 727 def test_fetch_manifest_should_update_goal_state_and_not_change_default_channel_if_host_fails(self): 728 manifest_url = 'https://fake_host/fake_manifest.xml' 729 730 def http_get_handler(url, *_, **kwargs): 731 if url == manifest_url or self.is_host_plugin_extension_request(url, kwargs, manifest_url): 732 return ResourceGoneError("Exception to fake an error on either channel") 733 elif self.is_goal_state_request(url): 734 protocol.track_url(url) # keep track of goal state requests 735 return None 736 737 # Everything fails. Goal state should have been updated and host channel should not have been set as default. 738 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 739 HostPluginProtocol.is_default_channel = False 740 741 # initialization of the host plugin triggers a request for the goal state; do it here before we start 742 # tracking those requests. 743 protocol.client.get_host_plugin() 744 745 protocol.set_http_handlers(http_get_handler=http_get_handler) 746 747 with self.assertRaises(ExtensionDownloadError): 748 protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) 749 750 urls = protocol.get_tracked_urls() 751 self.assertEqual(len(urls), 4, "Unexpected number of HTTP requests: [{0}]".format(urls)) 752 self.assertEqual(urls[0], manifest_url, "The first attempt should have been over the direct channel") 753 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second attempt should have been over the host channel") 754 self.assertTrue(self.is_goal_state_request(urls[2]), "The host channel should have been refreshed the goal state") 755 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The third attempt should have been over the host channel") 756 self.assertFalse(HostPluginProtocol.is_default_channel, "The host should not have been set as the default channel") 757 758 self.assertEqual(HostPluginProtocol.is_default_channel, False) 759 760 def test_get_artifacts_profile_should_not_invoke_host_channel_when_direct_channel_succeeds(self): 761 def http_get_handler(url, *_, **__): 762 if self.is_in_vm_artifacts_profile_request(url): 763 protocol.track_url(url) 764 return None 765 766 with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE, http_get_handler=http_get_handler) as protocol: 767 HostPluginProtocol.is_default_channel = False 768 769 return_value = protocol.client.get_artifacts_profile() 770 771 self.assertIsInstance(return_value, InVMArtifactsProfile, 'The request did not return a valid artifacts profile: {0}'.format(return_value)) 772 urls = protocol.get_tracked_urls() 773 self.assertEqual(len(urls), 1, "Unexpected HTTP requests: [{0}]".format(urls)) 774 self.assertFalse(HostPluginProtocol.is_default_channel, "The host should not have been set as the default channel") 775 776 def test_get_artifacts_profile_should_use_host_channel_when_direct_channel_fails(self): 777 def http_get_handler(url, *_, **kwargs): 778 if self.is_in_vm_artifacts_profile_request(url): 779 return HttpError("Exception to fake an error on the direct channel") 780 if self.is_host_plugin_in_vm_artifacts_profile_request(url, kwargs): 781 protocol.track_url(url) 782 return None 783 784 with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE) as protocol: 785 HostPluginProtocol.is_default_channel = False 786 787 try: 788 protocol.set_http_handlers(http_get_handler=http_get_handler) 789 790 return_value = protocol.client.get_artifacts_profile() 791 792 self.assertIsNotNone(return_value, "The artifacts profile request should have succeeded") 793 self.assertIsInstance(return_value, InVMArtifactsProfile, 'The request did not return a valid artifacts profile: {0}'.format(return_value)) 794 self.assertTrue(return_value.onHold, 'The OnHold property should be True') # pylint: disable=no-member 795 urls = protocol.get_tracked_urls() 796 self.assertEqual(len(urls), 2, "Invalid number of requests: [{0}]".format(urls)) 797 self.assertTrue(self.is_in_vm_artifacts_profile_request(urls[0]), "The first request should have been over the direct channel") 798 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second request should have been over the host channel") 799 self.assertTrue(HostPluginProtocol.is_default_channel, "The default channel should have changed to the host") 800 finally: 801 HostPluginProtocol.is_default_channel = False 802 803 def test_get_artifacts_profile_should_retry_the_host_channel_after_refreshing_the_host_plugin(self): 804 def http_get_handler(url, *_, **kwargs): 805 if self.is_in_vm_artifacts_profile_request(url): 806 return HttpError("Exception to fake an error on the direct channel") 807 if self.is_host_plugin_in_vm_artifacts_profile_request(url, kwargs): 808 if http_get_handler.host_plugin_calls == 0: 809 http_get_handler.host_plugin_calls += 1 810 return ResourceGoneError("Exception to fake a stale goal state") 811 protocol.track_url(url) 812 if self.is_goal_state_request(url): 813 protocol.track_url(url) 814 return None 815 http_get_handler.host_plugin_calls = 0 816 817 with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE) as protocol: 818 HostPluginProtocol.is_default_channel = False 819 820 try: 821 # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests. 822 protocol.client.get_host_plugin() 823 824 protocol.set_http_handlers(http_get_handler=http_get_handler) 825 826 return_value = protocol.client.get_artifacts_profile() 827 828 self.assertIsNotNone(return_value, "The artifacts profile request should have succeeded") 829 self.assertIsInstance(return_value, InVMArtifactsProfile, 'The request did not return a valid artifacts profile: {0}'.format(return_value)) 830 self.assertTrue(return_value.onHold, 'The OnHold property should be True') # pylint: disable=no-member 831 urls = protocol.get_tracked_urls() 832 self.assertEqual(len(urls), 4, "Invalid number of requests: [{0}]".format(urls)) 833 self.assertTrue(self.is_in_vm_artifacts_profile_request(urls[0]), "The first request should have been over the direct channel") 834 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second request should have been over the host channel") 835 self.assertTrue(self.is_goal_state_request(urls[2]), "The goal state should have been refreshed before retrying the host channel") 836 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The retry request should have been over the host channel") 837 self.assertTrue(HostPluginProtocol.is_default_channel, "The default channel should have changed to the host") 838 finally: 839 HostPluginProtocol.is_default_channel = False 840 841 def test_get_artifacts_profile_should_refresh_the_host_plugin_and_not_change_default_channel_if_host_plugin_fails(self): 842 def http_get_handler(url, *_, **kwargs): 843 if self.is_in_vm_artifacts_profile_request(url): 844 return HttpError("Exception to fake an error on the direct channel") 845 if self.is_host_plugin_in_vm_artifacts_profile_request(url, kwargs): 846 return ResourceGoneError("Exception to fake a stale goal state") 847 if self.is_goal_state_request(url): 848 protocol.track_url(url) 849 return None 850 851 with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE) as protocol: 852 HostPluginProtocol.is_default_channel = False 853 854 # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests. 855 protocol.client.get_host_plugin() 856 857 protocol.set_http_handlers(http_get_handler=http_get_handler) 858 859 return_value = protocol.client.get_artifacts_profile() 860 861 self.assertIsNone(return_value, "The artifacts profile request should have failed") 862 urls = protocol.get_tracked_urls() 863 self.assertEqual(len(urls), 4, "Invalid number of requests: [{0}]".format(urls)) 864 self.assertTrue(self.is_in_vm_artifacts_profile_request(urls[0]), "The first request should have been over the direct channel") 865 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second request should have been over the host channel") 866 self.assertTrue(self.is_goal_state_request(urls[2]), "The goal state should have been refreshed before retrying the host channel") 867 self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The retry request should have been over the host channel") 868 self.assertFalse(HostPluginProtocol.is_default_channel, "The default channel should not have changed") 869 870 def test_upload_logs_should_not_refresh_plugin_when_first_attempt_succeeds(self): 871 def http_put_handler(url, *_, **__): # pylint: disable=inconsistent-return-statements 872 if self.is_host_plugin_put_logs_request(url): 873 return MockResponse(body=b'', status_code=200) 874 875 with mock_wire_protocol(mockwiredata.DATA_FILE, http_put_handler=http_put_handler) as protocol: 876 content = b"test" 877 protocol.client.upload_logs(content) 878 879 urls = protocol.get_tracked_urls() 880 self.assertEqual(len(urls), 1, 'Expected one post request to the host: [{0}]'.format(urls)) 881 882 def test_upload_logs_should_retry_the_host_channel_after_refreshing_the_host_plugin(self): 883 def http_put_handler(url, *_, **__): 884 if self.is_host_plugin_put_logs_request(url): 885 if http_put_handler.host_plugin_calls == 0: 886 http_put_handler.host_plugin_calls += 1 887 return ResourceGoneError("Exception to fake a stale goal state") 888 protocol.track_url(url) 889 return None 890 http_put_handler.host_plugin_calls = 0 891 892 with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE, http_put_handler=http_put_handler) \ 893 as protocol: 894 content = b"test" 895 protocol.client.upload_logs(content) 896 897 urls = protocol.get_tracked_urls() 898 self.assertEqual(len(urls), 2, "Invalid number of requests: [{0}]".format(urls)) 899 self.assertTrue(self.is_host_plugin_put_logs_request(urls[0]), "The first request should have been over the host channel") 900 self.assertTrue(self.is_host_plugin_put_logs_request(urls[1]), "The second request should have been over the host channel") 901 902 @staticmethod 903 def _set_and_fail_helper_channel_functions(fail_direct=False, fail_host=False): 904 def direct_func(*_): 905 direct_func.counter += 1 906 if direct_func.fail: 907 return None 908 return "direct" 909 910 def host_func(*_): 911 host_func.counter += 1 912 if host_func.fail: 913 return None 914 return "host" 915 916 direct_func.counter = 0 917 direct_func.fail = fail_direct 918 919 host_func.counter = 0 920 host_func.fail = fail_host 921 922 return direct_func, host_func 923 924 def test_send_request_using_appropriate_channel_should_not_invoke_secondary_when_primary_channel_succeeds(self): 925 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 926 # Scenario #1: Direct channel default 927 HostPluginProtocol.is_default_channel = False 928 929 direct_func, host_func = self._set_and_fail_helper_channel_functions() 930 # Assert we're only calling the primary channel (direct) and that it succeeds. 931 for iteration in range(5): 932 ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func) 933 self.assertEqual("direct", ret) 934 self.assertEqual(iteration + 1, direct_func.counter) 935 self.assertEqual(0, host_func.counter) 936 self.assertFalse(HostPluginProtocol.is_default_channel) 937 938 # Scenario #2: Host channel default 939 HostPluginProtocol.is_default_channel = True 940 direct_func, host_func = self._set_and_fail_helper_channel_functions() 941 942 # Assert we're only calling the primary channel (host) and that it succeeds. 943 for iteration in range(5): 944 ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func) 945 self.assertEqual("host", ret) 946 self.assertEqual(0, direct_func.counter) 947 self.assertEqual(iteration + 1, host_func.counter) 948 self.assertTrue(HostPluginProtocol.is_default_channel) 949 950 def test_send_request_using_appropriate_channel_should_not_change_default_channel_if_none_succeeds(self): 951 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 952 # Scenario #1: Direct channel is default 953 HostPluginProtocol.is_default_channel = False 954 direct_func, host_func = self._set_and_fail_helper_channel_functions(fail_direct=True, fail_host=True) 955 956 # Assert we keep trying both channels, but the default channel doesn't change 957 for iteration in range(5): 958 ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func) 959 self.assertEqual(None, ret) 960 self.assertEqual(iteration + 1, direct_func.counter) 961 self.assertEqual(iteration + 1, host_func.counter) 962 self.assertFalse(HostPluginProtocol.is_default_channel) 963 964 # Scenario #2: Host channel is default 965 HostPluginProtocol.is_default_channel = True 966 direct_func, host_func = self._set_and_fail_helper_channel_functions(fail_direct=True, fail_host=True) 967 968 # Assert we keep trying both channels, but the default channel doesn't change 969 for iteration in range(5): 970 ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func) 971 self.assertEqual(None, ret) 972 self.assertEqual(iteration + 1, direct_func.counter) 973 self.assertEqual(iteration + 1, host_func.counter) 974 self.assertTrue(HostPluginProtocol.is_default_channel) 975 976 def test_send_request_using_appropriate_channel_should_change_default_channel_when_secondary_succeeds(self): 977 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 978 # Scenario #1: Direct channel is default 979 HostPluginProtocol.is_default_channel = False 980 direct_func, host_func = self._set_and_fail_helper_channel_functions(fail_direct=True, fail_host=False) 981 982 # Assert we've called both channels and the default channel changed 983 ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func) 984 self.assertEqual("host", ret) 985 self.assertEqual(1, direct_func.counter) 986 self.assertEqual(1, host_func.counter) 987 self.assertTrue(HostPluginProtocol.is_default_channel) 988 989 # If host keeps succeeding, assert we keep calling only that channel and not changing the default. 990 for iteration in range(5): 991 ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func) 992 self.assertEqual("host", ret) 993 self.assertEqual(1, direct_func.counter) 994 self.assertEqual(1 + iteration + 1, host_func.counter) 995 self.assertTrue(HostPluginProtocol.is_default_channel) 996 997 # Scenario #2: Host channel is default 998 HostPluginProtocol.is_default_channel = True 999 direct_func, host_func = self._set_and_fail_helper_channel_functions(fail_direct=False, fail_host=True) 1000 1001 # Assert we've called both channels and the default channel changed 1002 ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func) 1003 self.assertEqual("direct", ret) 1004 self.assertEqual(1, direct_func.counter) 1005 self.assertEqual(1, host_func.counter) 1006 self.assertFalse(HostPluginProtocol.is_default_channel) 1007 1008 # If direct keeps succeeding, assert we keep calling only that channel and not changing the default. 1009 for iteration in range(5): 1010 ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func) 1011 self.assertEqual("direct", ret) 1012 self.assertEqual(1 + iteration + 1, direct_func.counter) 1013 self.assertEqual(1, host_func.counter) 1014 self.assertFalse(HostPluginProtocol.is_default_channel) 1015 1016 1017class UpdateGoalStateTestCase(AgentTestCase): 1018 """ 1019 Tests for WireClient.update_goal_state() 1020 """ 1021 1022 def test_it_should_update_the_goal_state_and_the_host_plugin_when_the_incarnation_changes(self): 1023 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 1024 protocol.client.get_host_plugin() 1025 1026 # if the incarnation changes the behavior is the same for forced and non-forced updates 1027 for forced in [True, False]: 1028 protocol.mock_wire_data.reload() # start each iteration of the test with fresh mock data 1029 1030 # 1031 # Update the mock data with random values; include at least one field from each of the components 1032 # in the goal state to ensure the entire state was updated. Note that numeric entities, e.g. incarnation, are 1033 # actually represented as strings in the goal state. 1034 # 1035 # Note that the shared config is not parsed by the agent, so we modify the XML data directly. Also, the 1036 # certificates are encrypted and it is hard to update a single field; instead, we update the entire list with 1037 # empty. 1038 # 1039 new_incarnation = str(uuid.uuid4()) 1040 new_container_id = str(uuid.uuid4()) 1041 new_role_config_name = str(uuid.uuid4()) 1042 new_hosting_env_deployment_name = str(uuid.uuid4()) 1043 new_shared_conf = WireProtocolData.replace_xml_attribute_value(protocol.mock_wire_data.shared_config, "Deployment", "name", str(uuid.uuid4())) 1044 new_sequence_number = str(uuid.uuid4()) 1045 1046 if '<Format>Pkcs7BlobWithPfxContents</Format>' not in protocol.mock_wire_data.certs: 1047 raise Exception('This test requires a non-empty certificate list') 1048 1049 protocol.mock_wire_data.set_incarnation(new_incarnation) 1050 protocol.mock_wire_data.set_container_id(new_container_id) 1051 protocol.mock_wire_data.set_role_config_name(new_role_config_name) 1052 protocol.mock_wire_data.set_hosting_env_deployment_name(new_hosting_env_deployment_name) 1053 protocol.mock_wire_data.shared_config = new_shared_conf 1054 protocol.mock_wire_data.set_extensions_config_sequence_number(new_sequence_number) 1055 protocol.mock_wire_data.certs = r'''<?xml version="1.0" encoding="utf-8"?> 1056 <CertificateFile><Version>2012-11-30</Version> 1057 <Incarnation>12</Incarnation> 1058 <Format>CertificatesNonPfxPackage</Format> 1059 <Data>NotPFXData</Data> 1060 </CertificateFile> 1061 ''' 1062 1063 if forced: 1064 protocol.client.update_goal_state(forced=True) 1065 else: 1066 protocol.client.update_goal_state() 1067 1068 sequence_number = protocol.client.get_ext_conf().ext_handlers.extHandlers[0].properties.extensions[0].sequenceNumber 1069 1070 self.assertEqual(protocol.client.get_goal_state().incarnation, new_incarnation) 1071 self.assertEqual(protocol.client.get_hosting_env().deployment_name, new_hosting_env_deployment_name) 1072 self.assertEqual(protocol.client.get_shared_conf().xml_text, new_shared_conf) 1073 self.assertEqual(sequence_number, new_sequence_number) 1074 self.assertEqual(len(protocol.client.get_certs().cert_list.certificates), 0) 1075 1076 self.assertEqual(protocol.client.get_host_plugin().container_id, new_container_id) 1077 self.assertEqual(protocol.client.get_host_plugin().role_config_name, new_role_config_name) 1078 1079 def test_non_forced_update_should_not_update_the_goal_state_nor_the_host_plugin_when_the_incarnation_does_not_change(self): 1080 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 1081 protocol.client.get_host_plugin() 1082 1083 # The container id, role config name and shared config can change without the incarnation changing; capture the initial 1084 # goal state and then change those fields. 1085 goal_state = protocol.client.get_goal_state().xml_text 1086 shared_conf = protocol.client.get_shared_conf().xml_text 1087 container_id = protocol.client.get_host_plugin().container_id 1088 role_config_name = protocol.client.get_host_plugin().role_config_name 1089 1090 protocol.mock_wire_data.set_container_id(str(uuid.uuid4())) 1091 protocol.mock_wire_data.set_role_config_name(str(uuid.uuid4())) 1092 protocol.mock_wire_data.shared_config = WireProtocolData.replace_xml_attribute_value( 1093 protocol.mock_wire_data.shared_config, "Deployment", "name", str(uuid.uuid4())) 1094 1095 protocol.client.update_goal_state() 1096 1097 self.assertEqual(protocol.client.get_goal_state().xml_text, goal_state) 1098 self.assertEqual(protocol.client.get_shared_conf().xml_text, shared_conf) 1099 1100 self.assertEqual(protocol.client.get_host_plugin().container_id, container_id) 1101 self.assertEqual(protocol.client.get_host_plugin().role_config_name, role_config_name) 1102 1103 def test_forced_update_should_update_the_goal_state_and_the_host_plugin_when_the_incarnation_does_not_change(self): 1104 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 1105 protocol.client.get_host_plugin() 1106 1107 # The container id, role config name and shared config can change without the incarnation changing 1108 incarnation = protocol.client.get_goal_state().incarnation 1109 new_container_id = str(uuid.uuid4()) 1110 new_role_config_name = str(uuid.uuid4()) 1111 new_shared_conf = WireProtocolData.replace_xml_attribute_value( 1112 protocol.mock_wire_data.shared_config, "Deployment", "name", str(uuid.uuid4())) 1113 1114 protocol.mock_wire_data.set_container_id(new_container_id) 1115 protocol.mock_wire_data.set_role_config_name(new_role_config_name) 1116 protocol.mock_wire_data.shared_config = new_shared_conf 1117 1118 protocol.client.update_goal_state(forced=True) 1119 1120 self.assertEqual(protocol.client.get_goal_state().incarnation, incarnation) 1121 self.assertEqual(protocol.client.get_shared_conf().xml_text, new_shared_conf) 1122 1123 self.assertEqual(protocol.client.get_host_plugin().container_id, new_container_id) 1124 self.assertEqual(protocol.client.get_host_plugin().role_config_name, new_role_config_name) 1125 1126 def test_update_goal_state_should_archive_last_goal_state(self): 1127 # We use the last modified timestamp of the goal state to be archived to determine the archive's name. 1128 mock_mtime = os.path.getmtime(self.tmp_dir) 1129 with patch("azurelinuxagent.common.utils.archive.os.path.getmtime") as patch_mtime: 1130 first_gs_ms = mock_mtime + timedelta(minutes=5).seconds 1131 second_gs_ms = mock_mtime + timedelta(minutes=10).seconds 1132 third_gs_ms = mock_mtime + timedelta(minutes=15).seconds 1133 1134 patch_mtime.side_effect = [first_gs_ms, second_gs_ms, third_gs_ms] 1135 1136 # The first goal state is created when we instantiate the protocol 1137 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 1138 history_dir = os.path.join(conf.get_lib_dir(), "history") 1139 archives = os.listdir(history_dir) 1140 self.assertEqual(len(archives), 0, "The goal state archive should have been empty since this is the first goal state") 1141 1142 # Create the second new goal state, so the initial one should be archived 1143 protocol.mock_wire_data.set_incarnation("2") 1144 protocol.client.update_goal_state() 1145 1146 # The initial goal state should be in the archive 1147 first_archive_name = datetime.utcfromtimestamp(first_gs_ms).isoformat() + "_incarnation_1" 1148 archives = os.listdir(history_dir) 1149 self.assertEqual(len(archives), 1, "Only one goal state should have been archived") 1150 self.assertEqual(archives[0], first_archive_name, "The name of goal state archive should match the first goal state timestamp and incarnation") 1151 1152 # Create the third goal state, so the second one should be archived too 1153 protocol.mock_wire_data.set_incarnation("3") 1154 protocol.client.update_goal_state() 1155 1156 # The second goal state should be in the archive 1157 second_archive_name = datetime.utcfromtimestamp(second_gs_ms).isoformat() + "_incarnation_2" 1158 archives = os.listdir(history_dir) 1159 archives.sort() 1160 self.assertEqual(len(archives), 2, "Two goal states should have been archived") 1161 self.assertEqual(archives[1], second_archive_name, "The name of goal state archive should match the second goal state timestamp and incarnation") 1162 1163 def test_update_goal_state_should_not_persist_the_protected_settings(self): 1164 with mock_wire_protocol(mockwiredata.DATA_FILE_MULTIPLE_EXT) as protocol: 1165 # instantiating the protocol fetches the goal state, so there is no need to do another call to update_goal_state() 1166 goal_state = protocol.client.get_goal_state() 1167 1168 protected_settings = [] 1169 for ext_handler in goal_state.ext_conf.ext_handlers.extHandlers: 1170 for extension in ext_handler.properties.extensions: 1171 if extension.protectedSettings is not None: 1172 protected_settings.append(extension.protectedSettings) 1173 if len(protected_settings) == 0: 1174 raise Exception("The test goal state does not include any protected settings") 1175 1176 extensions_config_file = os.path.join(conf.get_lib_dir(), EXT_CONF_FILE_NAME.format(goal_state.incarnation)) 1177 if not os.path.exists(extensions_config_file): 1178 raise Exception("Cannot find {0}".format(extensions_config_file)) 1179 1180 with open(extensions_config_file, "r") as stream: 1181 extensions_config = stream.read() 1182 1183 for settings in protected_settings: 1184 self.assertNotIn(settings, extensions_config, "The protectedSettings should not have been saved to {0}".format(extensions_config_file)) 1185 1186 matches = re.findall(r'"protectedSettings"\s*:\s*"\*\*\* REDACTED \*\*\*"', extensions_config) 1187 self.assertEqual( 1188 len(matches), 1189 len(protected_settings), 1190 "Could not find the expected number of redacted settings. Expected {0}.\n{1}".format(len(protected_settings), extensions_config)) 1191 1192 1193class TryUpdateGoalStateTestCase(HttpRequestPredicates, AgentTestCase): 1194 """ 1195 Tests for WireClient.try_update_goal_state() 1196 """ 1197 def test_it_should_return_true_on_success(self): 1198 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 1199 self.assertTrue(protocol.client.try_update_goal_state(), "try_update_goal_state should have succeeded") 1200 1201 def test_incomplete_gs_should_fail(self): 1202 1203 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 1204 GoalState.fetch_full_goal_state(protocol.client) 1205 1206 protocol.mock_wire_data.data_files = mockwiredata.DATA_FILE_NOOP_GS 1207 protocol.mock_wire_data.reload() 1208 protocol.mock_wire_data.set_incarnation(2) 1209 1210 with self.assertRaises(IncompleteGoalStateError): 1211 GoalState.fetch_full_goal_state_if_incarnation_different_than(protocol.client, 1) 1212 1213 def test_it_should_return_false_on_failure(self): 1214 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 1215 def http_get_handler(url, *_, **__): 1216 if self.is_goal_state_request(url): 1217 return HttpError('Exception to fake an error retrieving the goal state') 1218 return None 1219 1220 protocol.set_http_handlers(http_get_handler=http_get_handler) 1221 1222 self.assertFalse(protocol.client.try_update_goal_state(), "try_update_goal_state should have failed") 1223 1224 def test_it_should_log_errors_only_when_the_error_state_changes(self): 1225 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 1226 def http_get_handler(url, *_, **__): 1227 if self.is_goal_state_request(url): 1228 if fail_goal_state_request: 1229 return HttpError('Exception to fake an error retrieving the goal state') 1230 return None 1231 1232 protocol.set_http_handlers(http_get_handler=http_get_handler) 1233 1234 @contextlib.contextmanager 1235 def create_log_and_telemetry_mocks(): 1236 with patch("azurelinuxagent.common.protocol.wire.logger", autospec=True) as logger_patcher: 1237 with patch("azurelinuxagent.common.protocol.wire.add_event") as add_event_patcher: 1238 yield logger_patcher, add_event_patcher 1239 1240 calls_to_strings = lambda calls: (str(c) for c in calls) 1241 filter_calls = lambda calls, regex=None: (c for c in calls_to_strings(calls) if regex is None or re.match(regex, c)) 1242 logger_calls = lambda regex=None: [m for m in filter_calls(logger.method_calls, regex)] # pylint: disable=used-before-assignment,unnecessary-comprehension 1243 warnings = lambda: logger_calls(r'call.warn\(.*An error occurred while retrieving the goal state.*') 1244 periodic_warnings = lambda: logger_calls(r'call.periodic_warn\(.*Attempts to retrieve the goal state are failing.*') 1245 success_messages = lambda: logger_calls(r'call.info\(.*Retrieving the goal state recovered from previous errors.*') 1246 telemetry_calls = lambda regex=None: [m for m in filter_calls(add_event.mock_calls, regex)] # pylint: disable=used-before-assignment,unnecessary-comprehension 1247 goal_state_events = lambda: telemetry_calls(r".*op='FetchGoalState'.*") 1248 1249 # 1250 # Initially calls to retrieve the goal state are successful... 1251 # 1252 fail_goal_state_request = False 1253 with create_log_and_telemetry_mocks() as (logger, add_event): 1254 protocol.client.try_update_goal_state() 1255 1256 lc = logger_calls() 1257 self.assertTrue(len(lc) == 0, "A successful call should not produce any log messages: [{0}]".format(lc)) 1258 1259 tc = telemetry_calls() 1260 self.assertTrue(len(tc) == 0, "A successful call should not produce any telemetry events: [{0}]".format(tc)) 1261 1262 # 1263 # ... then an error happens... 1264 # 1265 fail_goal_state_request = True 1266 with create_log_and_telemetry_mocks() as (logger, add_event): 1267 protocol.client.try_update_goal_state() 1268 1269 w = warnings() 1270 pw = periodic_warnings() 1271 self.assertEqual(len(w), 1, "A failure should have produced a warning: [{0}]".format(w)) 1272 self.assertEqual(len(pw), 1, "A failure should have produced a periodic warning: [{0}]".format(pw)) 1273 1274 gs = goal_state_events() 1275 self.assertTrue(len(gs) == 1 and 'is_success=False' in gs[0], "A failure should produce a telemetry event (success=false): [{0}]".format(gs)) 1276 1277 # 1278 # ... and errors continue happening... 1279 # 1280 with create_log_and_telemetry_mocks() as (logger, add_event): 1281 protocol.client.try_update_goal_state() 1282 protocol.client.try_update_goal_state() 1283 protocol.client.try_update_goal_state() 1284 1285 w = warnings() 1286 pw = periodic_warnings() 1287 self.assertTrue(len(w) == 0, "Subsequent failures should not produce warnings: [{0}]".format(w)) 1288 self.assertEqual(len(pw), 3, "Subsequent failures should produce periodic warnings: [{0}]".format(pw)) 1289 1290 tc = telemetry_calls() 1291 self.assertTrue(len(tc) == 0, "Subsequent failures should not produce any telemetry events: [{0}]".format(tc)) 1292 1293 # 1294 # ... until we finally succeed 1295 # 1296 fail_goal_state_request = False 1297 with create_log_and_telemetry_mocks() as (logger, add_event): 1298 protocol.client.try_update_goal_state() 1299 1300 s = success_messages() 1301 w = warnings() 1302 pw = periodic_warnings() 1303 self.assertEqual(len(s), 1, "Recovering after failures should have produced an info message: [{0}]".format(s)) 1304 self.assertTrue(len(w) == 0 and len(pw) == 0, "Recovering after failures should have not produced any warnings: [{0}] [{1}]".format(w, pw)) 1305 1306 gs = goal_state_events() 1307 self.assertTrue(len(gs) == 1 and 'is_success=True' in gs[0], "Recovering after failures should produce a telemetry event (success=true): [{0}]".format(gs)) 1308 1309 1310class UpdateHostPluginFromGoalStateTestCase(AgentTestCase): 1311 """ 1312 Tests for WireClient.update_host_plugin_from_goal_state() 1313 """ 1314 1315 def test_it_should_update_the_host_plugin_with_or_without_incarnation_changes(self): 1316 with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: 1317 protocol.client.get_host_plugin() 1318 1319 # the behavior should be the same whether the incarnation changes or not 1320 for incarnation_change in [True, False]: 1321 protocol.mock_wire_data.reload() # start each iteration of the test with fresh mock data 1322 1323 new_container_id = str(uuid.uuid4()) 1324 new_role_config_name = str(uuid.uuid4()) 1325 1326 goal_state_xml_text = protocol.mock_wire_data.goal_state 1327 shared_conf_xml_text = protocol.mock_wire_data.shared_config 1328 1329 if incarnation_change: 1330 protocol.mock_wire_data.set_incarnation(str(uuid.uuid4())) 1331 1332 protocol.mock_wire_data.set_container_id(new_container_id) 1333 protocol.mock_wire_data.set_role_config_name(new_role_config_name) 1334 protocol.mock_wire_data.shared_config = WireProtocolData.replace_xml_attribute_value( 1335 protocol.mock_wire_data.shared_config, "Deployment", "name", str(uuid.uuid4())) 1336 1337 protocol.client.update_host_plugin_from_goal_state() 1338 1339 self.assertEqual(protocol.client.get_host_plugin().container_id, new_container_id) 1340 self.assertEqual(protocol.client.get_host_plugin().role_config_name, new_role_config_name) 1341 1342 # it should not update the goal state 1343 self.assertEqual(protocol.client.get_goal_state().xml_text, goal_state_xml_text) 1344 self.assertEqual(protocol.client.get_shared_conf().xml_text, shared_conf_xml_text) 1345 1346 1347class MockResponse: 1348 def __init__(self, body, status_code, reason=None): 1349 self.body = body 1350 self.status = status_code 1351 self.reason = reason 1352 1353 def read(self, *_): 1354 return self.body 1355 1356 1357if __name__ == '__main__': 1358 unittest.main() 1359