1# -*- encoding: utf-8 -*-
2# Copyright 2018 Microsoft Corporation
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Requires Python 2.6+ and Openssl 1.0+
17#
18
19import contextlib
20import json
21import os
22import re
23import socket
24import time
25import unittest
26import uuid
27from datetime import datetime, timedelta
28
29from azurelinuxagent.common import conf
30from azurelinuxagent.common.agent_supported_feature import SupportedFeatureNames, get_supported_feature_by_name, \
31    get_agent_supported_features_list_for_crp
32from azurelinuxagent.common.exception import ResourceGoneError, ProtocolError, \
33    ExtensionDownloadError, HttpError
34from azurelinuxagent.common.protocol.goal_state import ExtensionsConfig, GoalState
35from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol
36from azurelinuxagent.common.protocol.restapi import VMAgentManifestUri
37from azurelinuxagent.common.protocol.wire import WireProtocol, WireClient, \
38    InVMArtifactsProfile, StatusBlob, VMStatus, EXT_CONF_FILE_NAME
39from azurelinuxagent.common.telemetryevent import GuestAgentExtensionEventsSchema, \
40    TelemetryEventParam, TelemetryEvent
41from azurelinuxagent.common.utils import restutil
42from azurelinuxagent.common.exception import IncompleteGoalStateError
43from azurelinuxagent.common.version import CURRENT_VERSION, DISTRO_NAME, DISTRO_VERSION
44from azurelinuxagent.ga.exthandlers import get_exthandlers_handler
45from tests.ga.test_monitor import random_generator
46from tests.protocol import mockwiredata
47from tests.protocol.mocks import mock_wire_protocol, HttpRequestPredicates
48from tests.protocol.mockwiredata import DATA_FILE_NO_EXT, DATA_FILE
49from tests.protocol.mockwiredata import WireProtocolData
50from tests.tools import Mock, patch, AgentTestCase
51
# Bytes payload that begins with the UTF-8 byte-order mark (EF BB BF) followed by b'hehe'.
data_with_bom = b'\xef\xbb\xbf' + b'hehe'
# Placeholder blob URL and blob type handed to the mock protocol by the status/artifacts tests.
testurl = 'http://foo'
testtype = 'BlockBlob'
# Address given to WireProtocol(...) when a test constructs a protocol object directly.
WIRESERVER_URL = '168.63.129.16'
56
57
def get_event(message, duration=30000, evt_type="", is_internal=False, is_success=True,
              name="", op="Unknown", version=CURRENT_VERSION, eventId=1):
    """
    Build a TelemetryEvent populated with GuestAgentExtensionEvents parameters for use in the tests.

    The parameters are appended in a fixed order: Name, Version (stringified), IsInternal, Operation,
    OperationSuccess, Message, Duration, ExtensionType. The provider id is a fixed placeholder GUID.
    """
    telemetry_event = TelemetryEvent(eventId, "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX")
    schema = GuestAgentExtensionEventsSchema
    field_values = (
        (schema.Name, name),
        (schema.Version, str(version)),
        (schema.IsInternal, is_internal),
        (schema.Operation, op),
        (schema.OperationSuccess, is_success),
        (schema.Message, message),
        (schema.Duration, duration),
        (schema.ExtensionType, evt_type),
    )
    for field_name, field_value in field_values:
        telemetry_event.parameters.append(TelemetryEventParam(field_name, field_value))
    return telemetry_event
70
71
@contextlib.contextmanager
def create_mock_protocol(artifacts_profile_blob=None, status_upload_blob=None, status_upload_blob_type=None):
    """
    Yield a mock WireProtocol backed by DATA_FILE_NO_EXT whose goal state holds a hand-built ExtensionsConfig.

    That mock wire data has no extensions, so its extensions config is empty. A fresh ExtensionsConfig is
    therefore created and given the requested artifacts profile blob and status upload blob (URL and type).

    :param artifacts_profile_blob: value stored as ext_conf.artifacts_profile_blob
    :param status_upload_blob: value stored as ext_conf.status_upload_blob
    :param status_upload_blob_type: value stored as ext_conf.status_upload_blob_type
    """
    with mock_wire_protocol(DATA_FILE_NO_EXT) as mock_protocol:
        custom_ext_conf = ExtensionsConfig(None)
        custom_ext_conf.artifacts_profile_blob = artifacts_profile_blob
        custom_ext_conf.status_upload_blob = status_upload_blob
        custom_ext_conf.status_upload_blob_type = status_upload_blob_type

        # Install the config directly on the client's cached goal state.
        mock_protocol.client._goal_state.ext_conf = custom_ext_conf  # pylint: disable=protected-access

        yield mock_protocol
84
85
@patch("time.sleep")
@patch("azurelinuxagent.common.protocol.wire.CryptUtil")
@patch("azurelinuxagent.common.protocol.healthservice.HealthService._report")
class TestWireProtocol(AgentTestCase):
    """
    Tests for WireProtocol and WireClient.

    The three class-level patches are applied to every test method. Each test receives the corresponding mocks as
    trailing positional arguments, in the order HealthService._report, CryptUtil, time.sleep. Most tests accept
    them through *args / *_ and ignore them.
    """

    def setUp(self):
        super(TestWireProtocol, self).setUp()
        # A previous test may have made the host plugin the default channel; start every test without it.
        HostPluginProtocol.is_default_channel = False

    def _test_getters(self, test_data, certsMustBePresent, __, MockCryptUtil, _):
        """
        Drive the main WireProtocol getters against the given mock wire data and verify the certificates.

        :param test_data: WireProtocolData providing mock_http_get and mock_crypt_util.
        :param certsMustBePresent: whether the certificate files must exist in tmp_dir after get_certs().
        The remaining positional arguments are the class-level mocks. Only the CryptUtil mock is used.
        """
        MockCryptUtil.side_effect = test_data.mock_crypt_util

        with patch.object(restutil, 'http_get', test_data.mock_http_get):
            protocol = WireProtocol(WIRESERVER_URL)
            protocol.detect()
            protocol.get_vminfo()
            protocol.get_certs()
            ext_handlers = protocol.get_ext_handlers()[0]  # the second item (etag) is not needed here
            for ext_handler in ext_handlers.extHandlers:
                protocol.get_ext_handler_pkgs(ext_handler)

            cert_files = [os.path.join(self.tmp_dir, file_name) for file_name in (
                '33B0ABCE4673538650971C10F7D7397E71561F35.crt',
                '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt',
                '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv')]
            for cert_file in cert_files:
                if certsMustBePresent:
                    self.assertTrue(os.path.isfile(cert_file), "Expected certificate file {0}".format(cert_file))
                else:
                    self.assertFalse(os.path.isfile(cert_file), "Unexpected certificate file {0}".format(cert_file))
            self.assertEqual("1", protocol.get_incarnation())

    @staticmethod
    def _get_telemetry_events_generator(event_list):
        """Return a generator over event_list, mimicking the lazily produced events that report_event() consumes."""
        def _yield_events():
            for telemetry_event in event_list:
                yield telemetry_event

        return _yield_events()

    def test_getters(self, *args):
        """Normal case"""
        test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE)
        self._test_getters(test_data, True, *args)

    def test_getters_no_ext(self, *args):
        """Provision with agent is not checked"""
        test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_NO_EXT)
        self._test_getters(test_data, True, *args)

    def test_getters_ext_no_settings(self, *args):
        """Extensions without any settings"""
        test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_EXT_NO_SETTINGS)
        self._test_getters(test_data, True, *args)

    def test_getters_ext_no_public(self, *args):
        """Extensions without any public settings"""
        test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_EXT_NO_PUBLIC)
        self._test_getters(test_data, True, *args)

    def test_getters_ext_no_cert_format(self, *args):
        """Certificate format not specified"""
        test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_NO_CERT_FORMAT)
        self._test_getters(test_data, True, *args)

    def test_getters_ext_cert_format_not_pfx(self, *args):
        """Certificate format is not Pkcs7BlobWithPfxContents specified"""
        test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE_CERT_FORMAT_NOT_PFX)
        self._test_getters(test_data, False, *args)

    @patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_extension_artifact")
    def test_getters_with_stale_goal_state(self, patch_report, *args):
        test_data = mockwiredata.WireProtocolData(mockwiredata.DATA_FILE)
        test_data.emulate_stale_goal_state = True

        self._test_getters(test_data, True, *args)
        # Ensure HostPlugin was invoked
        self.assertEqual(1, test_data.call_counts["/versions"])
        self.assertEqual(2, test_data.call_counts["extensionArtifact"])
        # Ensure the expected number of HTTP calls were made
        # -- Tracking calls to retrieve GoalState is problematic since it is
        #    fetched often; however, the dependent documents, such as the
        #    HostingEnvironmentConfig, will be retrieved the expected number
        self.assertEqual(1, test_data.call_counts["hostingenvuri"])
        self.assertEqual(1, patch_report.call_count)

    def test_call_storage_kwargs(self, *args):  # pylint: disable=unused-argument
        with patch.object(restutil, 'http_get') as http_patch:
            http_req = restutil.http_get
            url = testurl
            headers = {}

            # no url/headers/kwargs -- Default to True
            WireClient.call_storage_service(http_req)

            # kwargs, no use_proxy -- Default to True
            WireClient.call_storage_service(http_req,
                                            url,
                                            headers)

            # kwargs, use_proxy None -- Default to True
            WireClient.call_storage_service(http_req,
                                            url,
                                            headers,
                                            use_proxy=None)

            # kwargs, use_proxy False -- Keep False
            WireClient.call_storage_service(http_req,
                                            url,
                                            headers,
                                            use_proxy=False)

            # kwargs, use_proxy True -- Keep True
            WireClient.call_storage_service(http_req,
                                            url,
                                            headers,
                                            use_proxy=True)
            # assert
            self.assertEqual(5, http_patch.call_count)
            for call_index, recorded_call in enumerate(http_patch.call_args_list):
                # Only the 4th call (explicit use_proxy=False) keeps use_proxy False; all others default to True.
                # recorded_call[-1] is the kwargs dict of the recorded call.
                self.assertEqual(call_index != 3, recorded_call[-1]['use_proxy'])

    def test_status_blob_parsing(self, *args):  # pylint: disable=unused-argument
        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            self.assertEqual(protocol.client.get_ext_conf().status_upload_blob,
                             'https://test.blob.core.windows.net/vhds/test-cs12.test-cs12.test-cs12.status?'
                             'sr=b&sp=rw&se=9999-01-01&sk=key1&sv=2014-02-14&'
                             'sig=hfRh7gzUE7sUtYwke78IOlZOrTRCYvkec4hGZ9zZzXo')
            self.assertEqual(protocol.client.get_ext_conf().status_upload_blob_type, u'BlockBlob')

    def test_get_host_ga_plugin(self, *args):  # pylint: disable=unused-argument
        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            host_plugin = protocol.client.get_host_plugin()
            goal_state = protocol.client.get_goal_state()
            self.assertEqual(goal_state.container_id, host_plugin.container_id)
            self.assertEqual(goal_state.role_config_name, host_plugin.role_config_name)

    def test_upload_status_blob_should_use_the_host_channel_by_default(self, *_):
        def http_put_handler(url, *_, **__):
            # Answer the status upload sent to the host endpoint; return None for every other PUT (not handled here).
            if protocol.get_endpoint() in url and url.endswith('/status'):
                return MockResponse(body=b'', status_code=200)
            return None

        with mock_wire_protocol(mockwiredata.DATA_FILE, http_put_handler=http_put_handler) as protocol:
            HostPluginProtocol.is_default_channel = False
            protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready")

            protocol.client.upload_status_blob()

            urls = protocol.get_tracked_urls()
            self.assertEqual(len(urls), 1, 'Expected one post request to the host: [{0}]'.format(urls))

    def test_upload_status_blob_host_ga_plugin(self, *_):
        with create_mock_protocol(status_upload_blob=testurl, status_upload_blob_type=testtype) as protocol:
            protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready")

            with patch.object(HostPluginProtocol, "ensure_initialized", return_value=True):
                with patch.object(StatusBlob, "upload", return_value=False) as patch_default_upload:
                    with patch.object(HostPluginProtocol, "_put_block_blob_status") as patch_http:
                        HostPluginProtocol.is_default_channel = False
                        protocol.client.upload_status_blob()
                        patch_default_upload.assert_not_called()
                        patch_http.assert_called_once_with(testurl, protocol.client.status_blob)
                        self.assertFalse(HostPluginProtocol.is_default_channel)

    @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized")
    def test_upload_status_blob_unknown_type_assumes_block(self, *_):
        with create_mock_protocol(status_upload_blob=testurl, status_upload_blob_type="NotALegalType") as protocol:
            protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready")

            with patch.object(StatusBlob, "prepare") as patch_prepare:
                with patch.object(StatusBlob, "upload") as patch_default_upload:
                    HostPluginProtocol.is_default_channel = False
                    protocol.client.upload_status_blob()

                    patch_prepare.assert_called_once_with("BlockBlob")
                    patch_default_upload.assert_called_once_with(testurl)

    def test_upload_status_blob_reports_prepare_error(self, *_):
        with create_mock_protocol(status_upload_blob=testurl, status_upload_blob_type=testtype) as protocol:
            protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready")

            with patch.object(StatusBlob, "prepare", side_effect=Exception) as mock_prepare:
                self.assertRaises(ProtocolError, protocol.client.upload_status_blob)
                self.assertEqual(1, mock_prepare.call_count)

    def test_get_in_vm_artifacts_profile_blob_not_available(self, *_):
        # Test when artifacts_profile_blob is null/None
        with mock_wire_protocol(DATA_FILE_NO_EXT) as protocol:
            protocol.client._goal_state.ext_conf = ExtensionsConfig(None)  # pylint: disable=protected-access

            self.assertEqual(None, protocol.client.get_artifacts_profile())

        # Test when artifacts_profile_blob is whitespace
        with create_mock_protocol(artifacts_profile_blob="  ") as protocol:
            self.assertEqual(None, protocol.client.get_artifacts_profile())

    def test_get_in_vm_artifacts_profile_response_body_not_valid(self, *_):
        with create_mock_protocol(artifacts_profile_blob=testurl) as protocol:
            with patch.object(HostPluginProtocol, "get_artifact_request", return_value=['dummy_url', {}]) as host_plugin_get_artifact_url_and_headers:
                # Test when response body is None
                protocol.client.call_storage_service = Mock(return_value=MockResponse(None, 200))
                in_vm_artifacts_profile = protocol.client.get_artifacts_profile()
                self.assertTrue(in_vm_artifacts_profile is None)

                # Test when response body contains only whitespace
                protocol.client.call_storage_service = Mock(return_value=MockResponse('   '.encode('utf-8'), 200))
                in_vm_artifacts_profile = protocol.client.get_artifacts_profile()
                self.assertTrue(in_vm_artifacts_profile is None)

                # Test when response body is an empty JSON object
                protocol.client.call_storage_service = Mock(return_value=MockResponse('{ }'.encode('utf-8'), 200))
                in_vm_artifacts_profile = protocol.client.get_artifacts_profile()
                self.assertEqual(dict(), in_vm_artifacts_profile.__dict__,
                                 'If artifacts_profile_blob has empty json dictionary, in_vm_artifacts_profile '
                                 'should contain nothing')

                host_plugin_get_artifact_url_and_headers.assert_called_with(testurl)

    @patch("azurelinuxagent.common.event.add_event")
    def test_artifacts_profile_json_parsing(self, patch_event, *args):  # pylint: disable=unused-argument
        with create_mock_protocol(artifacts_profile_blob=testurl) as protocol:
            # response is invalid json
            protocol.client.call_storage_service = Mock(return_value=MockResponse("invalid json".encode('utf-8'), 200))
            in_vm_artifacts_profile = protocol.client.get_artifacts_profile()

            # ensure response is empty
            self.assertEqual(None, in_vm_artifacts_profile)

            # ensure event is logged
            self.assertEqual(1, patch_event.call_count)
            self.assertFalse(patch_event.call_args[1]['is_success'])
            self.assertIn('invalid json', patch_event.call_args[1]['message'])
            self.assertEqual('ArtifactsProfileBlob', patch_event.call_args[1]['op'])

    def test_get_in_vm_artifacts_profile_default(self, *args):  # pylint: disable=unused-argument
        with create_mock_protocol(artifacts_profile_blob=testurl) as protocol:
            protocol.client.call_storage_service = Mock(return_value=MockResponse('{"onHold": "true"}'.encode('utf-8'), 200))
            in_vm_artifacts_profile = protocol.client.get_artifacts_profile()
            self.assertEqual(dict(onHold='true'), in_vm_artifacts_profile.__dict__)
            self.assertTrue(in_vm_artifacts_profile.is_on_hold())

    @patch("socket.gethostname", return_value="hostname")
    @patch("time.gmtime", return_value=time.localtime(1485543256))
    def test_report_vm_status(self, *args):  # pylint: disable=unused-argument
        status = 'status'
        message = 'message'

        client = WireProtocol(WIRESERVER_URL).client
        actual = StatusBlob(client=client)
        actual.set_vm_status(VMStatus(status=status, message=message))
        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

        formatted_msg = {
            'lang': 'en-US',
            'message': message
        }
        v1_ga_status = {
            'version': str(CURRENT_VERSION),
            'status': status,
            'formattedMessage': formatted_msg
        }
        v1_ga_guest_info = {
            'computerName': socket.gethostname(),
            'osName': DISTRO_NAME,
            'osVersion': DISTRO_VERSION,
            'version': str(CURRENT_VERSION),
        }
        v1_agg_status = {
            'guestAgentStatus': v1_ga_status,
            'handlerAggregateStatus': []
        }
        v1_vm_status = {
            'version': '1.1',
            'timestampUTC': timestamp,
            'aggregateStatus': v1_agg_status,
            'guestOSInfo': v1_ga_guest_info
        }
        self.assertEqual(json.dumps(v1_vm_status), actual.to_json())

    def test_it_should_report_supported_features_in_status_blob_if_supported(self, *_):
        with mock_wire_protocol(DATA_FILE) as protocol:

            def mock_http_put(url, *args, **__):
                if not HttpRequestPredicates.is_host_plugin_status_request(url):
                    # Skip reading the HostGA request data as it's encoded
                    protocol.aggregate_status = json.loads(args[0])
                return None

            def is_multi_config_reported():
                # True when the last captured aggregate status lists the MultiConfig feature (name and version).
                multi_config_feature = get_supported_feature_by_name(SupportedFeatureNames.MultiConfig)
                return any(feature['Key'] == multi_config_feature.name and feature['Value'] == multi_config_feature.version
                           for feature in protocol.aggregate_status['supportedFeatures'])

            protocol.aggregate_status = {}
            protocol.set_http_handlers(http_put_handler=mock_http_put)
            exthandlers_handler = get_exthandlers_handler(protocol)

            with patch("azurelinuxagent.common.agent_supported_feature._MultiConfigFeature.is_supported", True):
                exthandlers_handler.run()
                self.assertIsNotNone(protocol.aggregate_status, "Aggregate status should not be None")
                self.assertIn("supportedFeatures", protocol.aggregate_status, "supported features not reported")
                self.assertTrue(is_multi_config_reported(), "Multi-config name should be present in supportedFeatures")

            # Feature should not be reported if it is not supported
            with patch("azurelinuxagent.common.agent_supported_feature._MultiConfigFeature.is_supported", False):
                exthandlers_handler.run()
                self.assertIsNotNone(protocol.aggregate_status, "Aggregate status should not be None")
                if "supportedFeatures" not in protocol.aggregate_status:
                    # In the case Multi-config was the only feature available, 'supportedFeatures' should not be
                    # reported in the status blob as its not supported as of now.
                    # Asserting no other feature was available to report back to crp
                    self.assertEqual(1, len(get_agent_supported_features_list_for_crp()),
                                     "supportedFeatures should be available if there are more features")
                    return

                # If there are other features available, confirm MultiConfig was not reported
                self.assertFalse(is_multi_config_reported(),
                                 "Multi-config name should not be present in supportedFeatures")

    @patch("azurelinuxagent.common.utils.restutil.http_request")
    def test_send_event(self, mock_http_request, *_):
        mock_http_request.return_value = MockResponse("", 200)

        event_str = u'a test string'
        client = WireProtocol(WIRESERVER_URL).client
        client.send_event("foo", event_str.encode('utf-8'))

        # Inspect the first HTTP request that was issued: positional (method, url, body) plus keyword arguments.
        call_args, call_kwargs = mock_http_request.call_args_list[0]
        _, _, body_received = call_args
        headers = call_kwargs['headers']

        # the headers should include utf-8 encoding...
        self.assertIn("utf-8", headers['Content-Type'])
        # ...and the event text should appear in the body that was sent
        self.assertIn(event_str, body_received)

    @patch("azurelinuxagent.common.protocol.wire.WireClient.send_event")
    def test_report_event_small_event(self, patch_send_event, *args):  # pylint: disable=unused-argument
        event_list = []
        client = WireProtocol(WIRESERVER_URL).client

        event_str = random_generator(10)
        event_list.append(get_event(message=event_str))

        event_str = random_generator(100)
        event_list.append(get_event(message=event_str))

        event_str = random_generator(1000)
        event_list.append(get_event(message=event_str))

        event_str = random_generator(10000)
        event_list.append(get_event(message=event_str))

        client.report_event(self._get_telemetry_events_generator(event_list))

        # It merges the messages into one message
        self.assertEqual(patch_send_event.call_count, 1)

    @patch("azurelinuxagent.common.protocol.wire.WireClient.send_event")
    def test_report_event_multiple_events_to_fill_buffer(self, patch_send_event, *args):  # pylint: disable=unused-argument
        event_list = []
        client = WireProtocol(WIRESERVER_URL).client

        event_str = random_generator(2 ** 15)
        event_list.append(get_event(message=event_str))
        event_list.append(get_event(message=event_str))

        client.report_event(self._get_telemetry_events_generator(event_list))

        # The two 2**15-character events are expected to be sent as two separate messages
        self.assertEqual(patch_send_event.call_count, 2)

    @patch("azurelinuxagent.common.protocol.wire.WireClient.send_event")
    def test_report_event_large_event(self, patch_send_event, *args):  # pylint: disable=unused-argument
        event_list = []
        event_str = random_generator(2 ** 18)
        event_list.append(get_event(message=event_str))
        client = WireProtocol(WIRESERVER_URL).client
        client.report_event(self._get_telemetry_events_generator(event_list))

        # A single 2**18-character event is expected not to be sent at all
        self.assertEqual(patch_send_event.call_count, 0)
479
480
481class TestWireClient(HttpRequestPredicates, AgentTestCase):
482    def test_get_ext_conf_without_extensions_should_retrieve_vmagent_manifests_info(self, *args):  # pylint: disable=unused-argument
483        # Basic test for get_ext_conf() when extensions are not present in the config. The test verifies that
484        # get_ext_conf() fetches the correct data by comparing the returned data with the test data provided the
485        # mock_wire_protocol.
486        with mock_wire_protocol(mockwiredata.DATA_FILE_NO_EXT) as protocol:
487            ext_conf = protocol.client.get_ext_conf()
488
489            ext_handlers_names = [ext_handler.name for ext_handler in ext_conf.ext_handlers.extHandlers]
490            self.assertEqual(0, len(ext_conf.ext_handlers.extHandlers),
491                             "Unexpected number of extension handlers in the extension config: [{0}]".format(ext_handlers_names))
492            vmagent_manifests = [manifest.family for manifest in ext_conf.vmagent_manifests.vmAgentManifests]
493            self.assertEqual(0, len(ext_conf.vmagent_manifests.vmAgentManifests),
494                             "Unexpected number of vmagent manifests in the extension config: [{0}]".format(vmagent_manifests))
495            self.assertIsNone(ext_conf.status_upload_blob,
496                              "Status upload blob in the extension config is expected to be None")
497            self.assertIsNone(ext_conf.status_upload_blob_type,
498                              "Type of status upload blob in the extension config is expected to be None")
499            self.assertIsNone(ext_conf.artifacts_profile_blob,
500                              "Artifacts profile blob in the extensions config is expected to be None")
501
502    def test_get_ext_conf_with_extensions_should_retrieve_ext_handlers_and_vmagent_manifests_info(self):
503        # Basic test for get_ext_conf() when extensions are present in the config. The test verifies that get_ext_conf()
504        # fetches the correct data by comparing the returned data with the test data provided the mock_wire_protocol.
505        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
506            wire_protocol_client = protocol.client
507            ext_conf = wire_protocol_client.get_ext_conf()
508
509            ext_handlers_names = [ext_handler.name for ext_handler in ext_conf.ext_handlers.extHandlers]
510            self.assertEqual(1, len(ext_conf.ext_handlers.extHandlers),
511                             "Unexpected number of extension handlers in the extension config: [{0}]".format(ext_handlers_names))
512            vmagent_manifests = [manifest.family for manifest in ext_conf.vmagent_manifests.vmAgentManifests]
513            self.assertEqual(2, len(ext_conf.vmagent_manifests.vmAgentManifests),
514                             "Unexpected number of vmagent manifests in the extension config: [{0}]".format(vmagent_manifests))
515            self.assertEqual("https://test.blob.core.windows.net/vhds/test-cs12.test-cs12.test-cs12.status?sr=b&sp=rw"
516                             "&se=9999-01-01&sk=key1&sv=2014-02-14&sig=hfRh7gzUE7sUtYwke78IOlZOrTRCYvkec4hGZ9zZzXo",
517                             ext_conf.status_upload_blob, "Unexpected value for status upload blob URI")
518            self.assertEqual("BlockBlob", ext_conf.status_upload_blob_type,
519                             "Unexpected status upload blob type in the extension config")
520            self.assertEqual(None, ext_conf.artifacts_profile_blob,
521                             "Artifacts profile blob in the extension config should have been None")
522
523    def test_download_ext_handler_pkg_should_not_invoke_host_channel_when_direct_channel_succeeds(self):
524        extension_url = 'https://fake_host/fake_extension.zip'
525        target_file = os.path.join(self.tmp_dir, 'fake_extension.zip')
526
527        def http_get_handler(url, *_, **__):
528            if url == extension_url:
529                return MockResponse(body=b'', status_code=200)
530            if self.is_host_plugin_extension_artifact_request(url):
531                self.fail('The host channel should not have been used')
532            return None
533
534        with mock_wire_protocol(mockwiredata.DATA_FILE, http_get_handler=http_get_handler) as protocol:
535            HostPluginProtocol.is_default_channel = False
536
537            success = protocol.download_ext_handler_pkg(extension_url, target_file)
538
539            urls = protocol.get_tracked_urls()
540            self.assertTrue(success, "The download should have succeeded")
541            self.assertEqual(len(urls), 1, "Unexpected number of HTTP requests: [{0}]".format(urls))
542            self.assertEqual(urls[0], extension_url, "The extension should have been downloaded over the direct channel")
543            self.assertTrue(os.path.exists(target_file), "The extension package was not downloaded")
544            self.assertFalse(HostPluginProtocol.is_default_channel, "The host channel should not have been set as the default")
545
546    def test_download_ext_handler_pkg_should_use_host_channel_when_direct_channel_fails_and_set_host_as_default(self):
547        extension_url = 'https://fake_host/fake_extension.zip'
548        target_file = os.path.join(self.tmp_dir, 'fake_extension.zip')
549
550        def http_get_handler(url, *_, **kwargs):
551            if url == extension_url:
552                return HttpError("Exception to fake an error on the direct channel")
553            if self.is_host_plugin_extension_request(url, kwargs, extension_url):
554                return MockResponse(body=b'', status_code=200)
555            return None
556
557        with mock_wire_protocol(mockwiredata.DATA_FILE, http_get_handler=http_get_handler) as protocol:
558            HostPluginProtocol.is_default_channel = False
559
560            success = protocol.download_ext_handler_pkg(extension_url, target_file)
561
562            urls = protocol.get_tracked_urls()
563            self.assertTrue(success, "The download should have succeeded")
564            self.assertEqual(len(urls), 2, "Unexpected number of HTTP requests: [{0}]".format(urls))
565            self.assertEqual(urls[0], extension_url, "The first attempt should have been over the direct channel")
566            self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The retry attempt should have been over the host channel")
567            self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded')
568            self.assertTrue(HostPluginProtocol.is_default_channel, "The host channel should have been set as the default")
569
570    def test_download_ext_handler_pkg_should_retry_the_host_channel_after_refreshing_host_plugin(self):
571        extension_url = 'https://fake_host/fake_extension.zip'
572        target_file = os.path.join(self.tmp_dir, 'fake_extension.zip')
573
574        def http_get_handler(url, *_, **kwargs):
575            if url == extension_url:
576                return HttpError("Exception to fake an error on the direct channel")
577            if self.is_host_plugin_extension_request(url, kwargs, extension_url):
578                # fake a stale goal state then succeed once the goal state has been refreshed
579                if http_get_handler.goal_state_requests == 0:
580                    http_get_handler.goal_state_requests += 1
581                    return ResourceGoneError("Exception to fake a stale goal")
582                return MockResponse(body=b'', status_code=200)
583            if self.is_goal_state_request(url):
584                protocol.track_url(url)  # track requests for the goal state
585            return None
586        http_get_handler.goal_state_requests = 0
587
588        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
589            HostPluginProtocol.is_default_channel = False
590
591            try:
592                # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests.
593                protocol.client.get_host_plugin()
594
595                protocol.set_http_handlers(http_get_handler=http_get_handler)
596
597                success = protocol.download_ext_handler_pkg(extension_url, target_file)
598
599                urls = protocol.get_tracked_urls()
600                self.assertTrue(success, "The download should have succeeded")
601                self.assertEqual(len(urls), 4, "Unexpected number of HTTP requests: [{0}]".format(urls))
602                self.assertEqual(urls[0], extension_url, "The first attempt should have been over the direct channel")
603                self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second attempt should have been over the host channel")
604                self.assertTrue(self.is_goal_state_request(urls[2]), "The host channel should have been refreshed the goal state")
605                self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The third attempt should have been over the host channel")
606                self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded')
607                self.assertTrue(HostPluginProtocol.is_default_channel, "The host channel should have been set as the default")
608            finally:
609                HostPluginProtocol.is_default_channel = False
610
611    def test_download_ext_handler_pkg_should_not_change_default_channel_when_all_channels_fail(self):
612        extension_url = 'https://fake_host/fake_extension.zip'
613        target_file = os.path.join(self.tmp_dir, "fake_extension.zip")
614
615        def http_get_handler(url, *_, **kwargs):
616            if url == extension_url or self.is_host_plugin_extension_request(url, kwargs, extension_url):
617                return MockResponse(body=b"content not found", status_code=404, reason="Not Found")
618            if self.is_goal_state_request(url):
619                protocol.track_url(url)  # keep track of goal state requests
620            return None
621
622        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
623            HostPluginProtocol.is_default_channel = False
624
625            # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests.
626            protocol.client.get_host_plugin()
627
628            protocol.set_http_handlers(http_get_handler=http_get_handler)
629
630            success = protocol.download_ext_handler_pkg(extension_url, target_file)
631
632            urls = protocol.get_tracked_urls()
633            self.assertFalse(success, "The download should have failed")
634            self.assertEqual(len(urls), 2, "Unexpected number of HTTP requests: [{0}]".format(urls))
635            self.assertEqual(urls[0], extension_url, "The first attempt should have been over the direct channel")
636            self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second attempt should have been over the host channel")
637            self.assertFalse(os.path.exists(target_file), "The extension package was downloaded and it shouldn't have")
638            self.assertFalse(HostPluginProtocol.is_default_channel, "The host channel should not have been set as the default")
639
640    def test_fetch_manifest_should_not_invoke_host_channel_when_direct_channel_succeeds(self):
641        manifest_url = 'https://fake_host/fake_manifest.xml'
642        manifest_xml = '<?xml version="1.0" encoding="utf-8"?><PluginVersionManifest/>'
643
644        def http_get_handler(url, *_, **__):
645            if url == manifest_url:
646                return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200)
647            if url.endswith('/extensionArtifact'):
648                self.fail('The Host GA Plugin should not have been invoked')
649            return None
650
651        with mock_wire_protocol(mockwiredata.DATA_FILE, http_get_handler=http_get_handler) as protocol:
652            HostPluginProtocol.is_default_channel = False
653
654            manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)])
655
656            urls = protocol.get_tracked_urls()
657            self.assertEqual(manifest, manifest_xml, 'The expected manifest was not downloaded')
658            self.assertEqual(len(urls), 1, "Unexpected number of HTTP requests: [{0}]".format(urls))
659            self.assertEqual(urls[0], manifest_url, "The manifest should have been downloaded over the direct channel")
660            self.assertFalse(HostPluginProtocol.is_default_channel, "The default channel should not have changed")
661
662    def test_fetch_manifest_should_use_host_channel_when_direct_channel_fails_and_set_it_to_default(self):
663        manifest_url = 'https://fake_host/fake_manifest.xml'
664        manifest_xml = '<?xml version="1.0" encoding="utf-8"?><PluginVersionManifest/>'
665
666        def http_get_handler(url, *_, **kwargs):
667            if url == manifest_url:
668                return ResourceGoneError("Exception to fake an error on the direct channel")
669            if self.is_host_plugin_extension_request(url, kwargs, manifest_url):
670                return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200)
671            return None
672
673        with mock_wire_protocol(mockwiredata.DATA_FILE, http_get_handler=http_get_handler) as protocol:
674            HostPluginProtocol.is_default_channel = False
675
676            try:
677                manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)])
678
679                urls = protocol.get_tracked_urls()
680                self.assertEqual(manifest, manifest_xml, 'The expected manifest was not downloaded')
681                self.assertEqual(len(urls), 2, "Unexpected number of HTTP requests: [{0}]".format(urls))
682                self.assertEqual(urls[0], manifest_url, "The first attempt should have been over the direct channel")
683                self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The retry should have been over the host channel")
684                self.assertTrue(HostPluginProtocol.is_default_channel, "The host should have been set as the default channel")
685            finally:
686                HostPluginProtocol.is_default_channel = False  # Reset default channel
687
688    def test_fetch_manifest_should_retry_the_host_channel_after_refreshing_the_host_plugin_and_set_the_host_as_default(self):
689        manifest_url = 'https://fake_host/fake_manifest.xml'
690        manifest_xml = '<?xml version="1.0" encoding="utf-8"?><PluginVersionManifest/>'
691
692        def http_get_handler(url, *_, **kwargs):
693            if url == manifest_url:
694                return HttpError("Exception to fake an error on the direct channel")
695            if self.is_host_plugin_extension_request(url, kwargs, manifest_url):
696                # fake a stale goal state then succeed once the goal state has been refreshed
697                if http_get_handler.goal_state_requests == 0:
698                    http_get_handler.goal_state_requests += 1
699                    return ResourceGoneError("Exception to fake a stale goal state")
700                return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200)
701            elif self.is_goal_state_request(url):
702                protocol.track_url(url)  # keep track of goal state requests
703            return None
704        http_get_handler.goal_state_requests = 0
705
706        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
707            HostPluginProtocol.is_default_channel = False
708
709            try:
710                # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests.
711                protocol.client.get_host_plugin()
712
713                protocol.set_http_handlers(http_get_handler=http_get_handler)
714                manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)])
715
716                urls = protocol.get_tracked_urls()
717                self.assertEqual(manifest, manifest_xml)
718                self.assertEqual(len(urls), 4, "Unexpected number of HTTP requests: [{0}]".format(urls))
719                self.assertEqual(urls[0], manifest_url, "The first attempt should have been over the direct channel")
720                self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second attempt should have been over the host channel")
721                self.assertTrue(self.is_goal_state_request(urls[2]), "The host channel should have been refreshed the goal state")
722                self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The third attempt should have been over the host channel")
723                self.assertTrue(HostPluginProtocol.is_default_channel, "The host should have been set as the default channel")
724            finally:
725                HostPluginProtocol.is_default_channel = False  # Reset default channel
726
727    def test_fetch_manifest_should_update_goal_state_and_not_change_default_channel_if_host_fails(self):
728        manifest_url = 'https://fake_host/fake_manifest.xml'
729
730        def http_get_handler(url, *_, **kwargs):
731            if url == manifest_url or self.is_host_plugin_extension_request(url, kwargs, manifest_url):
732                return ResourceGoneError("Exception to fake an error on either channel")
733            elif self.is_goal_state_request(url):
734                protocol.track_url(url)  # keep track of goal state requests
735            return None
736
737        # Everything fails. Goal state should have been updated and host channel should not have been set as default.
738        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
739            HostPluginProtocol.is_default_channel = False
740
741            # initialization of the host plugin triggers a request for the goal state; do it here before we start
742            # tracking those requests.
743            protocol.client.get_host_plugin()
744
745            protocol.set_http_handlers(http_get_handler=http_get_handler)
746
747            with self.assertRaises(ExtensionDownloadError):
748                protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)])
749
750            urls = protocol.get_tracked_urls()
751            self.assertEqual(len(urls), 4, "Unexpected number of HTTP requests: [{0}]".format(urls))
752            self.assertEqual(urls[0], manifest_url, "The first attempt should have been over the direct channel")
753            self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second attempt should have been over the host channel")
754            self.assertTrue(self.is_goal_state_request(urls[2]), "The host channel should have been refreshed the goal state")
755            self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The third attempt should have been over the host channel")
756            self.assertFalse(HostPluginProtocol.is_default_channel, "The host should not have been set as the default channel")
757
758            self.assertEqual(HostPluginProtocol.is_default_channel, False)
759
760    def test_get_artifacts_profile_should_not_invoke_host_channel_when_direct_channel_succeeds(self):
761        def http_get_handler(url, *_, **__):
762            if self.is_in_vm_artifacts_profile_request(url):
763                protocol.track_url(url)
764            return None
765
766        with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE, http_get_handler=http_get_handler) as protocol:
767            HostPluginProtocol.is_default_channel = False
768
769            return_value = protocol.client.get_artifacts_profile()
770
771            self.assertIsInstance(return_value, InVMArtifactsProfile, 'The request did not return a valid artifacts profile: {0}'.format(return_value))
772            urls = protocol.get_tracked_urls()
773            self.assertEqual(len(urls), 1, "Unexpected HTTP requests: [{0}]".format(urls))
774            self.assertFalse(HostPluginProtocol.is_default_channel, "The host should not have been set as the default channel")
775
776    def test_get_artifacts_profile_should_use_host_channel_when_direct_channel_fails(self):
777        def http_get_handler(url, *_, **kwargs):
778            if self.is_in_vm_artifacts_profile_request(url):
779                return HttpError("Exception to fake an error on the direct channel")
780            if self.is_host_plugin_in_vm_artifacts_profile_request(url, kwargs):
781                protocol.track_url(url)
782            return None
783
784        with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE) as protocol:
785            HostPluginProtocol.is_default_channel = False
786
787            try:
788                protocol.set_http_handlers(http_get_handler=http_get_handler)
789
790                return_value = protocol.client.get_artifacts_profile()
791
792                self.assertIsNotNone(return_value, "The artifacts profile request should have succeeded")
793                self.assertIsInstance(return_value, InVMArtifactsProfile, 'The request did not return a valid artifacts profile: {0}'.format(return_value))
794                self.assertTrue(return_value.onHold, 'The OnHold property should be True')  # pylint: disable=no-member
795                urls = protocol.get_tracked_urls()
796                self.assertEqual(len(urls), 2, "Invalid number of requests: [{0}]".format(urls))
797                self.assertTrue(self.is_in_vm_artifacts_profile_request(urls[0]), "The first request should have been over the direct channel")
798                self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second request should have been over the host channel")
799                self.assertTrue(HostPluginProtocol.is_default_channel, "The default channel should have changed to the host")
800            finally:
801                HostPluginProtocol.is_default_channel = False
802
803    def test_get_artifacts_profile_should_retry_the_host_channel_after_refreshing_the_host_plugin(self):
804        def http_get_handler(url, *_, **kwargs):
805            if self.is_in_vm_artifacts_profile_request(url):
806                return HttpError("Exception to fake an error on the direct channel")
807            if self.is_host_plugin_in_vm_artifacts_profile_request(url, kwargs):
808                if http_get_handler.host_plugin_calls == 0:
809                    http_get_handler.host_plugin_calls += 1
810                    return ResourceGoneError("Exception to fake a stale goal state")
811                protocol.track_url(url)
812            if self.is_goal_state_request(url):
813                protocol.track_url(url)
814            return None
815        http_get_handler.host_plugin_calls = 0
816
817        with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE) as protocol:
818            HostPluginProtocol.is_default_channel = False
819
820            try:
821                # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests.
822                protocol.client.get_host_plugin()
823
824                protocol.set_http_handlers(http_get_handler=http_get_handler)
825
826                return_value = protocol.client.get_artifacts_profile()
827
828                self.assertIsNotNone(return_value, "The artifacts profile request should have succeeded")
829                self.assertIsInstance(return_value, InVMArtifactsProfile, 'The request did not return a valid artifacts profile: {0}'.format(return_value))
830                self.assertTrue(return_value.onHold, 'The OnHold property should be True')  # pylint: disable=no-member
831                urls = protocol.get_tracked_urls()
832                self.assertEqual(len(urls), 4, "Invalid number of requests: [{0}]".format(urls))
833                self.assertTrue(self.is_in_vm_artifacts_profile_request(urls[0]), "The first request should have been over the direct channel")
834                self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second request should have been over the host channel")
835                self.assertTrue(self.is_goal_state_request(urls[2]), "The goal state should have been refreshed before retrying the host channel")
836                self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The retry request should have been over the host channel")
837                self.assertTrue(HostPluginProtocol.is_default_channel, "The default channel should have changed to the host")
838            finally:
839                HostPluginProtocol.is_default_channel = False
840
841    def test_get_artifacts_profile_should_refresh_the_host_plugin_and_not_change_default_channel_if_host_plugin_fails(self):
842        def http_get_handler(url, *_, **kwargs):
843            if self.is_in_vm_artifacts_profile_request(url):
844                return HttpError("Exception to fake an error on the direct channel")
845            if self.is_host_plugin_in_vm_artifacts_profile_request(url, kwargs):
846                return ResourceGoneError("Exception to fake a stale goal state")
847            if self.is_goal_state_request(url):
848                protocol.track_url(url)
849            return None
850
851        with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE) as protocol:
852            HostPluginProtocol.is_default_channel = False
853
854            # initialization of the host plugin triggers a request for the goal state; do it here before we start tracking those requests.
855            protocol.client.get_host_plugin()
856
857            protocol.set_http_handlers(http_get_handler=http_get_handler)
858
859            return_value = protocol.client.get_artifacts_profile()
860
861            self.assertIsNone(return_value, "The artifacts profile request should have failed")
862            urls = protocol.get_tracked_urls()
863            self.assertEqual(len(urls), 4, "Invalid number of requests: [{0}]".format(urls))
864            self.assertTrue(self.is_in_vm_artifacts_profile_request(urls[0]), "The first request should have been over the direct channel")
865            self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[1]), "The second request should have been over the host channel")
866            self.assertTrue(self.is_goal_state_request(urls[2]), "The goal state should have been refreshed before retrying the host channel")
867            self.assertTrue(self.is_host_plugin_extension_artifact_request(urls[3]), "The retry request should have been over the host channel")
868            self.assertFalse(HostPluginProtocol.is_default_channel, "The default channel should not have changed")
869
870    def test_upload_logs_should_not_refresh_plugin_when_first_attempt_succeeds(self):
871        def http_put_handler(url, *_, **__):  # pylint: disable=inconsistent-return-statements
872            if self.is_host_plugin_put_logs_request(url):
873                return MockResponse(body=b'', status_code=200)
874
875        with mock_wire_protocol(mockwiredata.DATA_FILE, http_put_handler=http_put_handler) as protocol:
876            content = b"test"
877            protocol.client.upload_logs(content)
878
879            urls = protocol.get_tracked_urls()
880            self.assertEqual(len(urls), 1, 'Expected one post request to the host: [{0}]'.format(urls))
881
882    def test_upload_logs_should_retry_the_host_channel_after_refreshing_the_host_plugin(self):
883        def http_put_handler(url, *_, **__):
884            if self.is_host_plugin_put_logs_request(url):
885                if http_put_handler.host_plugin_calls == 0:
886                    http_put_handler.host_plugin_calls += 1
887                    return ResourceGoneError("Exception to fake a stale goal state")
888                protocol.track_url(url)
889            return None
890        http_put_handler.host_plugin_calls = 0
891
892        with mock_wire_protocol(mockwiredata.DATA_FILE_IN_VM_ARTIFACTS_PROFILE, http_put_handler=http_put_handler) \
893                as protocol:
894            content = b"test"
895            protocol.client.upload_logs(content)
896
897            urls = protocol.get_tracked_urls()
898            self.assertEqual(len(urls), 2, "Invalid number of requests: [{0}]".format(urls))
899            self.assertTrue(self.is_host_plugin_put_logs_request(urls[0]), "The first request should have been over the host channel")
900            self.assertTrue(self.is_host_plugin_put_logs_request(urls[1]), "The second request should have been over the host channel")
901
902    @staticmethod
903    def _set_and_fail_helper_channel_functions(fail_direct=False, fail_host=False):
904        def direct_func(*_):
905            direct_func.counter += 1
906            if direct_func.fail:
907                return None
908            return "direct"
909
910        def host_func(*_):
911            host_func.counter += 1
912            if host_func.fail:
913                return None
914            return "host"
915
916        direct_func.counter = 0
917        direct_func.fail = fail_direct
918
919        host_func.counter = 0
920        host_func.fail = fail_host
921
922        return direct_func, host_func
923
924    def test_send_request_using_appropriate_channel_should_not_invoke_secondary_when_primary_channel_succeeds(self):
925        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
926            # Scenario #1: Direct channel default
927            HostPluginProtocol.is_default_channel = False
928
929            direct_func, host_func = self._set_and_fail_helper_channel_functions()
930            # Assert we're only calling the primary channel (direct) and that it succeeds.
931            for iteration in range(5):
932                ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func)
933                self.assertEqual("direct", ret)
934                self.assertEqual(iteration + 1, direct_func.counter)
935                self.assertEqual(0, host_func.counter)
936                self.assertFalse(HostPluginProtocol.is_default_channel)
937
938            # Scenario #2: Host channel default
939            HostPluginProtocol.is_default_channel = True
940            direct_func, host_func = self._set_and_fail_helper_channel_functions()
941
942            # Assert we're only calling the primary channel (host) and that it succeeds.
943            for iteration in range(5):
944                ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func)
945                self.assertEqual("host", ret)
946                self.assertEqual(0, direct_func.counter)
947                self.assertEqual(iteration + 1, host_func.counter)
948                self.assertTrue(HostPluginProtocol.is_default_channel)
949
950    def test_send_request_using_appropriate_channel_should_not_change_default_channel_if_none_succeeds(self):
951        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
952            # Scenario #1: Direct channel is default
953            HostPluginProtocol.is_default_channel = False
954            direct_func, host_func = self._set_and_fail_helper_channel_functions(fail_direct=True, fail_host=True)
955
956            # Assert we keep trying both channels, but the default channel doesn't change
957            for iteration in range(5):
958                ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func)
959                self.assertEqual(None, ret)
960                self.assertEqual(iteration + 1, direct_func.counter)
961                self.assertEqual(iteration + 1, host_func.counter)
962                self.assertFalse(HostPluginProtocol.is_default_channel)
963
964            # Scenario #2: Host channel is default
965            HostPluginProtocol.is_default_channel = True
966            direct_func, host_func = self._set_and_fail_helper_channel_functions(fail_direct=True, fail_host=True)
967
968            # Assert we keep trying both channels, but the default channel doesn't change
969            for iteration in range(5):
970                ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func)
971                self.assertEqual(None, ret)
972                self.assertEqual(iteration + 1, direct_func.counter)
973                self.assertEqual(iteration + 1, host_func.counter)
974                self.assertTrue(HostPluginProtocol.is_default_channel)
975
976    def test_send_request_using_appropriate_channel_should_change_default_channel_when_secondary_succeeds(self):
977        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
978            # Scenario #1: Direct channel is default
979            HostPluginProtocol.is_default_channel = False
980            direct_func, host_func = self._set_and_fail_helper_channel_functions(fail_direct=True, fail_host=False)
981
982            # Assert we've called both channels and the default channel changed
983            ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func)
984            self.assertEqual("host", ret)
985            self.assertEqual(1, direct_func.counter)
986            self.assertEqual(1, host_func.counter)
987            self.assertTrue(HostPluginProtocol.is_default_channel)
988
989            # If host keeps succeeding, assert we keep calling only that channel and not changing the default.
990            for iteration in range(5):
991                ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func)
992                self.assertEqual("host", ret)
993                self.assertEqual(1, direct_func.counter)
994                self.assertEqual(1 + iteration + 1, host_func.counter)
995                self.assertTrue(HostPluginProtocol.is_default_channel)
996
997            # Scenario #2: Host channel is default
998            HostPluginProtocol.is_default_channel = True
999            direct_func, host_func = self._set_and_fail_helper_channel_functions(fail_direct=False, fail_host=True)
1000
1001            # Assert we've called both channels and the default channel changed
1002            ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func)
1003            self.assertEqual("direct", ret)
1004            self.assertEqual(1, direct_func.counter)
1005            self.assertEqual(1, host_func.counter)
1006            self.assertFalse(HostPluginProtocol.is_default_channel)
1007
1008            # If direct keeps succeeding, assert we keep calling only that channel and not changing the default.
1009            for iteration in range(5):
1010                ret = protocol.client.send_request_using_appropriate_channel(direct_func, host_func)
1011                self.assertEqual("direct", ret)
1012                self.assertEqual(1 + iteration + 1, direct_func.counter)
1013                self.assertEqual(1, host_func.counter)
1014                self.assertFalse(HostPluginProtocol.is_default_channel)
1015
1016
1017class UpdateGoalStateTestCase(AgentTestCase):
1018    """
1019    Tests for WireClient.update_goal_state()
1020    """
1021
1022    def test_it_should_update_the_goal_state_and_the_host_plugin_when_the_incarnation_changes(self):
1023        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
1024            protocol.client.get_host_plugin()
1025
1026            # if the incarnation changes the behavior is the same for forced and non-forced updates
1027            for forced in [True, False]:
1028                protocol.mock_wire_data.reload()  # start each iteration of the test with fresh mock data
1029
1030                #
1031                # Update the mock data with random values; include at least one field from each of the components
1032                # in the goal state to ensure the entire state was updated. Note that numeric entities, e.g. incarnation, are
1033                # actually represented as strings in the goal state.
1034                #
1035                # Note that the shared config is not parsed by the agent, so we modify the XML data directly. Also, the
1036                # certificates are encrypted and it is hard to update a single field; instead, we update the entire list with
1037                # empty.
1038                #
1039                new_incarnation = str(uuid.uuid4())
1040                new_container_id = str(uuid.uuid4())
1041                new_role_config_name = str(uuid.uuid4())
1042                new_hosting_env_deployment_name = str(uuid.uuid4())
1043                new_shared_conf = WireProtocolData.replace_xml_attribute_value(protocol.mock_wire_data.shared_config, "Deployment", "name", str(uuid.uuid4()))
1044                new_sequence_number = str(uuid.uuid4())
1045
1046                if '<Format>Pkcs7BlobWithPfxContents</Format>' not in protocol.mock_wire_data.certs:
1047                    raise Exception('This test requires a non-empty certificate list')
1048
1049                protocol.mock_wire_data.set_incarnation(new_incarnation)
1050                protocol.mock_wire_data.set_container_id(new_container_id)
1051                protocol.mock_wire_data.set_role_config_name(new_role_config_name)
1052                protocol.mock_wire_data.set_hosting_env_deployment_name(new_hosting_env_deployment_name)
1053                protocol.mock_wire_data.shared_config = new_shared_conf
1054                protocol.mock_wire_data.set_extensions_config_sequence_number(new_sequence_number)
1055                protocol.mock_wire_data.certs = r'''<?xml version="1.0" encoding="utf-8"?>
1056                    <CertificateFile><Version>2012-11-30</Version>
1057                      <Incarnation>12</Incarnation>
1058                      <Format>CertificatesNonPfxPackage</Format>
1059                      <Data>NotPFXData</Data>
1060                    </CertificateFile>
1061                '''
1062
1063                if forced:
1064                    protocol.client.update_goal_state(forced=True)
1065                else:
1066                    protocol.client.update_goal_state()
1067
1068                sequence_number = protocol.client.get_ext_conf().ext_handlers.extHandlers[0].properties.extensions[0].sequenceNumber
1069
1070                self.assertEqual(protocol.client.get_goal_state().incarnation, new_incarnation)
1071                self.assertEqual(protocol.client.get_hosting_env().deployment_name, new_hosting_env_deployment_name)
1072                self.assertEqual(protocol.client.get_shared_conf().xml_text, new_shared_conf)
1073                self.assertEqual(sequence_number, new_sequence_number)
1074                self.assertEqual(len(protocol.client.get_certs().cert_list.certificates), 0)
1075
1076                self.assertEqual(protocol.client.get_host_plugin().container_id, new_container_id)
1077                self.assertEqual(protocol.client.get_host_plugin().role_config_name, new_role_config_name)
1078
1079    def test_non_forced_update_should_not_update_the_goal_state_nor_the_host_plugin_when_the_incarnation_does_not_change(self):
1080        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
1081            protocol.client.get_host_plugin()
1082
1083            # The container id, role config name and shared config can change without the incarnation changing; capture the initial
1084            # goal state and then change those fields.
1085            goal_state = protocol.client.get_goal_state().xml_text
1086            shared_conf = protocol.client.get_shared_conf().xml_text
1087            container_id = protocol.client.get_host_plugin().container_id
1088            role_config_name = protocol.client.get_host_plugin().role_config_name
1089
1090            protocol.mock_wire_data.set_container_id(str(uuid.uuid4()))
1091            protocol.mock_wire_data.set_role_config_name(str(uuid.uuid4()))
1092            protocol.mock_wire_data.shared_config = WireProtocolData.replace_xml_attribute_value(
1093                protocol.mock_wire_data.shared_config, "Deployment", "name", str(uuid.uuid4()))
1094
1095            protocol.client.update_goal_state()
1096
1097            self.assertEqual(protocol.client.get_goal_state().xml_text, goal_state)
1098            self.assertEqual(protocol.client.get_shared_conf().xml_text, shared_conf)
1099
1100            self.assertEqual(protocol.client.get_host_plugin().container_id, container_id)
1101            self.assertEqual(protocol.client.get_host_plugin().role_config_name, role_config_name)
1102
1103    def test_forced_update_should_update_the_goal_state_and_the_host_plugin_when_the_incarnation_does_not_change(self):
1104        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
1105            protocol.client.get_host_plugin()
1106
1107            # The container id, role config name and shared config can change without the incarnation changing
1108            incarnation = protocol.client.get_goal_state().incarnation
1109            new_container_id = str(uuid.uuid4())
1110            new_role_config_name = str(uuid.uuid4())
1111            new_shared_conf = WireProtocolData.replace_xml_attribute_value(
1112                protocol.mock_wire_data.shared_config, "Deployment", "name", str(uuid.uuid4()))
1113
1114            protocol.mock_wire_data.set_container_id(new_container_id)
1115            protocol.mock_wire_data.set_role_config_name(new_role_config_name)
1116            protocol.mock_wire_data.shared_config = new_shared_conf
1117
1118            protocol.client.update_goal_state(forced=True)
1119
1120            self.assertEqual(protocol.client.get_goal_state().incarnation, incarnation)
1121            self.assertEqual(protocol.client.get_shared_conf().xml_text, new_shared_conf)
1122
1123            self.assertEqual(protocol.client.get_host_plugin().container_id, new_container_id)
1124            self.assertEqual(protocol.client.get_host_plugin().role_config_name, new_role_config_name)
1125
1126    def test_update_goal_state_should_archive_last_goal_state(self):
1127        # We use the last modified timestamp of the goal state to be archived to determine the archive's name.
1128        mock_mtime = os.path.getmtime(self.tmp_dir)
1129        with patch("azurelinuxagent.common.utils.archive.os.path.getmtime") as patch_mtime:
1130            first_gs_ms = mock_mtime + timedelta(minutes=5).seconds
1131            second_gs_ms = mock_mtime + timedelta(minutes=10).seconds
1132            third_gs_ms = mock_mtime + timedelta(minutes=15).seconds
1133
1134            patch_mtime.side_effect = [first_gs_ms, second_gs_ms, third_gs_ms]
1135
1136            # The first goal state is created when we instantiate the protocol
1137            with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
1138                history_dir = os.path.join(conf.get_lib_dir(), "history")
1139                archives = os.listdir(history_dir)
1140                self.assertEqual(len(archives), 0, "The goal state archive should have been empty since this is the first goal state")
1141
1142                # Create the second new goal state, so the initial one should be archived
1143                protocol.mock_wire_data.set_incarnation("2")
1144                protocol.client.update_goal_state()
1145
1146                # The initial goal state should be in the archive
1147                first_archive_name = datetime.utcfromtimestamp(first_gs_ms).isoformat() + "_incarnation_1"
1148                archives = os.listdir(history_dir)
1149                self.assertEqual(len(archives), 1, "Only one goal state should have been archived")
1150                self.assertEqual(archives[0], first_archive_name, "The name of goal state archive should match the first goal state timestamp and incarnation")
1151
1152                # Create the third goal state, so the second one should be archived too
1153                protocol.mock_wire_data.set_incarnation("3")
1154                protocol.client.update_goal_state()
1155
1156                # The second goal state should be in the archive
1157                second_archive_name = datetime.utcfromtimestamp(second_gs_ms).isoformat() + "_incarnation_2"
1158                archives = os.listdir(history_dir)
1159                archives.sort()
1160                self.assertEqual(len(archives), 2, "Two goal states should have been archived")
1161                self.assertEqual(archives[1], second_archive_name, "The name of goal state archive should match the second goal state timestamp and incarnation")
1162
1163    def test_update_goal_state_should_not_persist_the_protected_settings(self):
1164        with mock_wire_protocol(mockwiredata.DATA_FILE_MULTIPLE_EXT) as protocol:
1165            # instantiating the protocol fetches the goal state, so there is no need to do another call to update_goal_state()
1166            goal_state = protocol.client.get_goal_state()
1167
1168            protected_settings = []
1169            for ext_handler in goal_state.ext_conf.ext_handlers.extHandlers:
1170                for extension in ext_handler.properties.extensions:
1171                    if extension.protectedSettings is not None:
1172                        protected_settings.append(extension.protectedSettings)
1173            if len(protected_settings) == 0:
1174                raise Exception("The test goal state does not include any protected settings")
1175
1176            extensions_config_file = os.path.join(conf.get_lib_dir(), EXT_CONF_FILE_NAME.format(goal_state.incarnation))
1177            if not os.path.exists(extensions_config_file):
1178                raise Exception("Cannot find {0}".format(extensions_config_file))
1179
1180            with open(extensions_config_file, "r") as stream:
1181                extensions_config = stream.read()
1182
1183                for settings in protected_settings:
1184                    self.assertNotIn(settings, extensions_config, "The protectedSettings should not have been saved to {0}".format(extensions_config_file))
1185
1186                matches = re.findall(r'"protectedSettings"\s*:\s*"\*\*\* REDACTED \*\*\*"', extensions_config)
1187                self.assertEqual(
1188                    len(matches),
1189                    len(protected_settings),
1190                    "Could not find the expected number of redacted settings. Expected {0}.\n{1}".format(len(protected_settings), extensions_config))
1191
1192
class TryUpdateGoalStateTestCase(HttpRequestPredicates, AgentTestCase):
    """
    Tests for WireClient.try_update_goal_state()
    """
    def test_it_should_return_true_on_success(self):
        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            self.assertTrue(protocol.client.try_update_goal_state(), "try_update_goal_state should have succeeded")

    def test_incomplete_gs_should_fail(self):
        """
        fetch_full_goal_state_if_incarnation_different_than() should raise IncompleteGoalStateError when the
        incarnation differs and the goal state served (from DATA_FILE_NOOP_GS, presumably an incomplete goal
        state) cannot be fully fetched.
        """
        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            GoalState.fetch_full_goal_state(protocol.client)

            protocol.mock_wire_data.data_files = mockwiredata.DATA_FILE_NOOP_GS
            protocol.mock_wire_data.reload()
            # Incarnations are represented as strings in the goal state (as in the rest of this file); using strings
            # here keeps the "different than" comparison below between values of the same type, so the test really
            # depends on the incarnations being different rather than on a str-vs-int mismatch.
            protocol.mock_wire_data.set_incarnation("2")

            with self.assertRaises(IncompleteGoalStateError):
                GoalState.fetch_full_goal_state_if_incarnation_different_than(protocol.client, "1")

    def test_it_should_return_false_on_failure(self):
        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            def http_get_handler(url, *_, **__):
                # Returning an exception from the handler makes the mock fail the goal state request
                if self.is_goal_state_request(url):
                    return HttpError('Exception to fake an error retrieving the goal state')
                return None

            protocol.set_http_handlers(http_get_handler=http_get_handler)

            self.assertFalse(protocol.client.try_update_goal_state(), "try_update_goal_state should have failed")

    def test_it_should_log_errors_only_when_the_error_state_changes(self):
        """
        Warnings and telemetry should be emitted only on transitions (success -> failure, failure -> success);
        repeated failures should produce only periodic warnings.
        """
        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            def http_get_handler(url, *_, **__):
                # fail_goal_state_request is read at call time from the enclosing scope, so each phase below can
                # toggle it before calling try_update_goal_state()
                if self.is_goal_state_request(url):
                    if fail_goal_state_request:
                        return HttpError('Exception to fake an error retrieving the goal state')
                return None

            protocol.set_http_handlers(http_get_handler=http_get_handler)

            @contextlib.contextmanager
            def create_log_and_telemetry_mocks():
                with patch("azurelinuxagent.common.protocol.wire.logger", autospec=True) as logger_patcher:
                    with patch("azurelinuxagent.common.protocol.wire.add_event") as add_event_patcher:
                        yield logger_patcher, add_event_patcher

            # The helpers below read 'logger' and 'add_event' lazily; they are bound by each 'with ... as' block
            calls_to_strings = lambda calls: (str(c) for c in calls)
            filter_calls = lambda calls, regex=None: (c for c in calls_to_strings(calls) if regex is None or re.match(regex, c))
            logger_calls = lambda regex=None: [m for m in filter_calls(logger.method_calls, regex)]  # pylint: disable=used-before-assignment,unnecessary-comprehension
            warnings = lambda: logger_calls(r'call.warn\(.*An error occurred while retrieving the goal state.*')
            periodic_warnings = lambda: logger_calls(r'call.periodic_warn\(.*Attempts to retrieve the goal state are failing.*')
            success_messages = lambda: logger_calls(r'call.info\(.*Retrieving the goal state recovered from previous errors.*')
            telemetry_calls = lambda regex=None: [m for m in filter_calls(add_event.mock_calls, regex)]  # pylint: disable=used-before-assignment,unnecessary-comprehension
            goal_state_events = lambda: telemetry_calls(r".*op='FetchGoalState'.*")

            #
            # Initially calls to retrieve the goal state are successful...
            #
            fail_goal_state_request = False
            with create_log_and_telemetry_mocks() as (logger, add_event):
                protocol.client.try_update_goal_state()

                lc = logger_calls()
                self.assertEqual(len(lc), 0, "A successful call should not produce any log messages: [{0}]".format(lc))

                tc = telemetry_calls()
                self.assertEqual(len(tc), 0, "A successful call should not produce any telemetry events: [{0}]".format(tc))

            #
            # ... then an error happens...
            #
            fail_goal_state_request = True
            with create_log_and_telemetry_mocks() as (logger, add_event):
                protocol.client.try_update_goal_state()

                w = warnings()
                pw = periodic_warnings()
                self.assertEqual(len(w), 1, "A failure should have produced a warning: [{0}]".format(w))
                self.assertEqual(len(pw), 1, "A failure should have produced a periodic warning: [{0}]".format(pw))

                gs = goal_state_events()
                self.assertTrue(len(gs) == 1 and 'is_success=False' in gs[0], "A failure should produce a telemetry event (success=false): [{0}]".format(gs))

            #
            # ... and errors continue happening...
            #
            with create_log_and_telemetry_mocks() as (logger, add_event):
                protocol.client.try_update_goal_state()
                protocol.client.try_update_goal_state()
                protocol.client.try_update_goal_state()

                w = warnings()
                pw = periodic_warnings()
                self.assertEqual(len(w), 0, "Subsequent failures should not produce warnings: [{0}]".format(w))
                self.assertEqual(len(pw), 3, "Subsequent failures should produce periodic warnings: [{0}]".format(pw))

                tc = telemetry_calls()
                self.assertEqual(len(tc), 0, "Subsequent failures should not produce any telemetry events: [{0}]".format(tc))

            #
            # ... until we finally succeed
            #
            fail_goal_state_request = False
            with create_log_and_telemetry_mocks() as (logger, add_event):
                protocol.client.try_update_goal_state()

                s = success_messages()
                w = warnings()
                pw = periodic_warnings()
                self.assertEqual(len(s), 1, "Recovering after failures should have produced an info message: [{0}]".format(s))
                self.assertTrue(len(w) == 0 and len(pw) == 0, "Recovering after failures should have not produced any warnings: [{0}] [{1}]".format(w, pw))

                gs = goal_state_events()
                self.assertTrue(len(gs) == 1 and 'is_success=True' in gs[0], "Recovering after failures should produce a telemetry event (success=true): [{0}]".format(gs))
1308
1309
class UpdateHostPluginFromGoalStateTestCase(AgentTestCase):
    """
    Tests for WireClient.update_host_plugin_from_goal_state()
    """

    def test_it_should_update_the_host_plugin_with_or_without_incarnation_changes(self):
        """
        update_host_plugin_from_goal_state() should refresh the host plugin's container id and role config name
        without touching the client's goal state or shared config, regardless of whether the incarnation changed.
        """
        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            client = protocol.client
            wire_data = protocol.mock_wire_data

            client.get_host_plugin()

            # Run the same scenario twice: once with a new incarnation and once with the incarnation unchanged
            for change_incarnation in (True, False):
                wire_data.reload()  # each pass starts from pristine mock data

                container_id = str(uuid.uuid4())
                role_config_name = str(uuid.uuid4())

                # Capture what the mock serves before the changes below; the client's goal state and shared config
                # are expected to still match these after the host plugin refresh.
                expected_goal_state = wire_data.goal_state
                expected_shared_conf = wire_data.shared_config

                if change_incarnation:
                    wire_data.set_incarnation(str(uuid.uuid4()))

                wire_data.set_container_id(container_id)
                wire_data.set_role_config_name(role_config_name)
                wire_data.shared_config = WireProtocolData.replace_xml_attribute_value(
                    wire_data.shared_config, "Deployment", "name", str(uuid.uuid4()))

                client.update_host_plugin_from_goal_state()

                self.assertEqual(client.get_host_plugin().container_id, container_id)
                self.assertEqual(client.get_host_plugin().role_config_name, role_config_name)

                # The goal state itself must not have been updated
                self.assertEqual(client.get_goal_state().xml_text, expected_goal_state)
                self.assertEqual(client.get_shared_conf().xml_text, expected_shared_conf)
1345
1346
class MockResponse:
    """
    Minimal stand-in for an HTTP response: exposes 'status' and 'reason' attributes and a read() method.
    """
    def __init__(self, body, status_code, reason=None):
        self.body, self.status, self.reason = body, status_code, reason

    def read(self, *_):
        # Any positional arguments (e.g. a size) are ignored; the whole body is always returned
        return self.body
1355
1356
if __name__ == "__main__":
    # Allow running this test module directly, e.g. "python <this file>"
    unittest.main()
1359