1# Copyright 2020 Microsoft Corporation
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15# Requires Python 2.6+ and Openssl 1.0+
16#
17
18import json
19import os
20import subprocess
21import time
22import uuid
23
24from azurelinuxagent.common.agent_supported_feature import AgentSupportedFeature
25from azurelinuxagent.common.cgroupconfigurator import CGroupConfigurator
26from azurelinuxagent.common.event import AGENT_EVENT_FILE_EXTENSION, WALAEventOperation
27from azurelinuxagent.common.exception import ExtensionError, ExtensionErrorCodes
28from azurelinuxagent.common.protocol.restapi import ExtensionStatus, Extension, ExtHandler, ExtHandlerProperties
29from azurelinuxagent.common.protocol.util import ProtocolUtil
30from azurelinuxagent.common.protocol.wire import WireProtocol
31from azurelinuxagent.common.utils import fileutil
32from azurelinuxagent.common.utils.extensionprocessutil import TELEMETRY_MESSAGE_MAX_LEN, format_stdout_stderr, \
33    read_output
34from azurelinuxagent.ga.exthandlers import parse_ext_status, ExtHandlerInstance, ExtCommandEnvVariable, \
35    ExtensionStatusError
36from tests.protocol import mockwiredata
37from tests.protocol.mocks import mock_wire_protocol
38from tests.tools import AgentTestCase, patch, mock_sleep, clear_singleton_instances
39
40
class TestExtHandlers(AgentTestCase):
    """
    Tests for parse_ext_status, the sequence number used to name extension status files, PluginSettings
    version validation and truncation of the extension command execution log.
    """

    def setUp(self):
        super(TestExtHandlers, self).setUp()
        # Since ProtocolUtil is a singleton per thread, we need to clear it to ensure that the test cases do not
        # reuse a previous state
        clear_singleton_instances(ProtocolUtil)

    def _assert_status_is_empty(self, ext_status):
        """
        Assert that ext_status holds no parsed data: every field is None except the sequence number (0)
        and the substatus list (empty).
        """
        self.assertEqual(None, ext_status.code)
        self.assertEqual(None, ext_status.configurationAppliedTime)
        self.assertEqual(None, ext_status.message)
        self.assertEqual(None, ext_status.operation)
        self.assertEqual(None, ext_status.status)
        self.assertEqual(0, ext_status.sequenceNumber)
        self.assertEqual(0, len(ext_status.substatusList))

    def test_parse_ext_status_should_raise_on_non_array(self):
        """
        A status report that is a JSON object (instead of an array) is rejected with an ExtensionStatusError whose
        message stays bounded even though the report carries a 5K "longText" field.
        """
        status = json.loads('''
            {{
                "status": {{
                    "status": "transitioning",
                    "operation": "Enabling Handler",
                    "code": 0,
                    "name": "Microsoft.Azure.RecoveryServices.SiteRecovery.Linux"
                }},
                "version": 1.0,
                "timestampUTC": "2020-01-14T15:04:43Z",
                "longText": "{0}"
            }}'''.format("*" * 5 * 1024))

        with self.assertRaises(ExtensionStatusError) as context_manager:
            parse_ext_status(ExtensionStatus(seq_no=0), status)
        error_message = str(context_manager.exception)
        self.assertIn("The extension status must be an array", error_message)
        # i.e. the message must be longer than 64 characters and shorter than 64 + 4096 characters
        self.assertTrue(0 < len(error_message) - 64 < 4096, "The error message should not be much longer than 4K characters: [{0}]".format(error_message))

    def test_parse_extension_status00(self):
        """
        Parse a status report for a successful execution of an extension.
        """
        s = '''[{
    "status": {
      "status": "success",
      "formattedMessage": {
        "lang": "en-US",
        "message": "Command is finished."
      },
      "operation": "Daemon",
      "code": "0",
      "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"
    },
    "version": "1.0",
    "timestampUTC": "2018-04-20T21:20:24Z"
  }
]'''
        ext_status = ExtensionStatus(seq_no=0)
        parse_ext_status(ext_status, json.loads(s))

        self.assertEqual('0', ext_status.code)
        self.assertEqual(None, ext_status.configurationAppliedTime)
        self.assertEqual('Command is finished.', ext_status.message)
        self.assertEqual('Daemon', ext_status.operation)
        self.assertEqual('success', ext_status.status)
        self.assertEqual(0, ext_status.sequenceNumber)
        self.assertEqual(0, len(ext_status.substatusList))

    def test_parse_extension_status01(self):
        """
        Parse a status report for a failed execution of an extension.

        The extension returned a bad status/status of failed.
        The agent should handle this gracefully, and convert all unknown
        status/status values into an error.
        """
        s = '''[{
    "status": {
      "status": "failed",
      "formattedMessage": {
        "lang": "en-US",
        "message": "Enable failed: Failed with error: commandToExecute is empty or invalid ..."
      },
      "operation": "Enable",
      "code": "0",
      "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"
    },
    "version": "1.0",
    "timestampUTC": "2018-04-20T20:50:22Z"
}]'''
        ext_status = ExtensionStatus(seq_no=0)
        parse_ext_status(ext_status, json.loads(s))

        self.assertEqual('0', ext_status.code)
        self.assertEqual(None, ext_status.configurationAppliedTime)
        self.assertEqual('Enable failed: Failed with error: commandToExecute is empty or invalid ...', ext_status.message)
        self.assertEqual('Enable', ext_status.operation)
        # "failed" is not a recognized status value, so it is reported as "error"
        self.assertEqual('error', ext_status.status)
        self.assertEqual(0, ext_status.sequenceNumber)
        self.assertEqual(0, len(ext_status.substatusList))

    def test_parse_ext_status_should_parse_missing_substatus_as_empty(self):
        """
        A status report without a "substatus" field produces an empty substatus list.
        """
        status = '''[{
            "status": {
              "status": "success",
              "formattedMessage": {
                "lang": "en-US",
                "message": "Command is finished."
              },
              "operation": "Enable",
              "code": "0",
              "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"
            },

            "version": "1.0",
            "timestampUTC": "2018-04-20T21:20:24Z"
          }
        ]'''

        extension_status = ExtensionStatus(seq_no=0)

        parse_ext_status(extension_status, json.loads(status))

        self.assertTrue(isinstance(extension_status.substatusList, list), 'substatus was not parsed correctly')
        self.assertEqual(0, len(extension_status.substatusList))

    def test_parse_ext_status_should_parse_null_substatus_as_empty(self):
        """
        A status report whose "substatus" field is null produces an empty substatus list.
        """
        status = '''[{
            "status": {
              "status": "success",
              "formattedMessage": {
                "lang": "en-US",
                "message": "Command is finished."
              },
              "operation": "Enable",
              "code": "0",
              "name": "Microsoft.OSTCExtensions.CustomScriptForLinux",
              "substatus": null
            },

            "version": "1.0",
            "timestampUTC": "2018-04-20T21:20:24Z"
          }
        ]'''

        extension_status = ExtensionStatus(seq_no=0)

        parse_ext_status(extension_status, json.loads(status))

        self.assertTrue(isinstance(extension_status.substatusList, list), 'substatus was not parsed correctly')
        self.assertEqual(0, len(extension_status.substatusList))

    def test_parse_extension_status_with_empty_status(self):
        """
        Parse an empty status report (an empty JSON array) and a missing one (None); neither should
        populate any field of the extension status.
        """
        # Validating empty status case
        ext_status = ExtensionStatus(seq_no=0)
        parse_ext_status(ext_status, json.loads('''[]'''))
        self._assert_status_is_empty(ext_status)

        # Validating None case
        ext_status = ExtensionStatus(seq_no=0)
        parse_ext_status(ext_status, None)
        self._assert_status_is_empty(ext_status)

    @patch('azurelinuxagent.common.event.EventLogger.add_event')
    @patch('azurelinuxagent.ga.exthandlers.ExtHandlerInstance._get_last_modified_seq_no_from_config_files')
    def assert_extension_sequence_number(self,
                                         patch_get_largest_seq,
                                         patch_add_event,
                                         goal_state_sequence_number,
                                         disk_sequence_number,
                                         expected_sequence_number):
        """
        Check the sequence number and path returned by ExtHandlerInstance.get_status_file_path.

        The sequence number found on disk is mocked to be disk_sequence_number. The check verifies that:
          * a SequenceNumberMismatch event is emitted only when goal_state_sequence_number parses as an int
            that differs from disk_sequence_number;
          * the returned sequence number is expected_sequence_number;
          * the path ends with '/foo-1.2.3/status/<seq>.status' when the sequence number is > -1, and is
            None otherwise.
        """
        ext = Extension()
        ext.sequenceNumber = goal_state_sequence_number
        patch_get_largest_seq.return_value = disk_sequence_number

        ext_handler_props = ExtHandlerProperties()
        ext_handler_props.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_props

        instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=None)
        seq, path = instance.get_status_file_path(ext)

        try:
            gs_seq_int = int(goal_state_sequence_number)
        except ValueError:
            gs_seq_int = None  # the goal state sequence number is not numeric

        if gs_seq_int is not None and gs_seq_int != disk_sequence_number:
            self.assertEqual(1, patch_add_event.call_count)
            kw_args = patch_add_event.call_args[1]
            self.assertEqual('SequenceNumberMismatch', kw_args['op'])
            self.assertEqual(False, kw_args['is_success'])
            self.assertEqual('Goal state: {0}, disk: {1}'
                             .format(gs_seq_int, disk_sequence_number),
                             kw_args['message'])
        else:
            self.assertEqual(0, patch_add_event.call_count)

        self.assertEqual(expected_sequence_number, seq)
        if seq > -1:
            self.assertTrue(path.endswith('/foo-1.2.3/status/{0}.status'.format(expected_sequence_number)))
        else:
            self.assertIsNone(path)

    def test_extension_sequence_number(self):
        # (goal state sequence number, sequence number on disk, expected sequence number)
        test_cases = [
            ("12", 366, 12),
            (" 12 ", 366, 12),
            (" foo", 3, 3),
            ("-1", 3, -1),
        ]
        for goal_state_sequence_number, disk_sequence_number, expected_sequence_number in test_cases:
            self.assert_extension_sequence_number(goal_state_sequence_number=goal_state_sequence_number,  # pylint: disable=no-value-for-parameter
                                                  disk_sequence_number=disk_sequence_number,
                                                  expected_sequence_number=expected_sequence_number)

    def test_it_should_report_error_if_plugin_settings_version_mismatch(self):
        with mock_wire_protocol(mockwiredata.DATA_FILE_PLUGIN_SETTINGS_MISMATCH) as protocol:
            with patch("azurelinuxagent.common.protocol.goal_state.add_event") as mock_add_event:
                # Forcing update of GoalState to allow the ExtConfig to report an event
                protocol.mock_wire_data.set_incarnation(2)
                protocol.client.update_goal_state()
                plugin_setting_mismatch_calls = [kw for _, kw in mock_add_event.call_args_list if
                                                 kw['op'] == WALAEventOperation.PluginSettingsVersionMismatch]
                self.assertEqual(1, len(plugin_setting_mismatch_calls),
                                 "PluginSettingsMismatch event should be reported once")

                event = plugin_setting_mismatch_calls[0]
                message = event['message']
                self.assertIn('ExtHandler PluginSettings Version Mismatch! Expected PluginSettings version: 1.0.0 for Handler: OSTCExtensions.ExampleHandlerLinux',
                              message,
                              "Invalid error message with incomplete data detected for PluginSettingsVersionMismatch")
                # separate checks so a failure shows which of the two versions is missing
                self.assertIn("1.0.2", message, "Error message should contain the incorrect versions")
                self.assertIn("1.0.1", message, "Error message should contain the incorrect versions")
                self.assertFalse(event['is_success'], "The event should be false")

    @patch("azurelinuxagent.common.conf.get_ext_log_dir")
    def test_command_extension_log_truncates_correctly(self, mock_log_dir):
        """
        Creating an ExtHandlerInstance with an execution log larger than execution_log_max_size drops the oldest
        content of CommandExecution.log, keeping only the second line in this scenario.
        """
        log_dir_path = os.path.join(self.tmp_dir, "log_directory")
        mock_log_dir.return_value = log_dir_path

        ext_handler_props = ExtHandlerProperties()
        ext_handler_props.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_props

        first_line = "This is the first line!"
        second_line = "This is the second line."
        old_logfile_contents = "{first_line}\n{second_line}\n".format(first_line=first_line, second_line=second_line)

        log_file_path = os.path.join(log_dir_path, "foo", "CommandExecution.log")

        fileutil.mkdir(os.path.join(log_dir_path, "foo"), mode=0o755)
        with open(log_file_path, "a") as log_file:
            log_file.write(old_logfile_contents)

        # the whole first line plus half of the second one: smaller than the current log, larger than the first line
        max_size = len(first_line) + (len(second_line) // 2)
        _ = ExtHandlerInstance(ext_handler=ext_handler, protocol=None, execution_log_max_size=max_size)

        with open(log_file_path) as truncated_log_file:
            self.assertEqual(truncated_log_file.read(), "{second_line}\n".format(second_line=second_line))
315
316class LaunchCommandTestCase(AgentTestCase):
317    """
318    Test cases for launch_command
319    """
320
321    def setUp(self):
322        AgentTestCase.setUp(self)
323
324        ext_handler_properties = ExtHandlerProperties()
325        ext_handler_properties.version = "1.2.3"
326        self.ext_handler = ExtHandler(name='foo')
327        self.ext_handler.properties = ext_handler_properties
328        self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=WireProtocol("1.2.3.4"))
329
330        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir)
331        self.mock_get_base_dir.start()
332
333        self.log_dir = os.path.join(self.tmp_dir, "log")
334        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir)
335        self.mock_get_log_dir.start()
336
337        self.mock_sleep = patch("time.sleep", lambda *_: mock_sleep(0.01))
338        self.mock_sleep.start()
339
340        self.cgroups_enabled = CGroupConfigurator.get_instance().enabled()
341        CGroupConfigurator.get_instance().disable()
342
343    def tearDown(self):
344        if self.cgroups_enabled:
345            CGroupConfigurator.get_instance().enable()
346        else:
347            CGroupConfigurator.get_instance().disable()
348
349        self.mock_get_log_dir.stop()
350        self.mock_get_base_dir.stop()
351        self.mock_sleep.stop()
352
353        AgentTestCase.tearDown(self)
354
355    @staticmethod
356    def _output_regex(stdout, stderr):
357        return r"\[stdout\]\s+{0}\s+\[stderr\]\s+{1}".format(stdout, stderr)
358
359    @staticmethod
360    def _find_process(command):
361        for pid in [pid for pid in os.listdir('/proc') if pid.isdigit()]:
362            try:
363                with open(os.path.join('/proc', pid, 'cmdline'), 'r') as cmdline:
364                    for line in cmdline.readlines():
365                        if command in line:
366                            return True
367            except IOError:  # proc has already terminated
368                continue
369        return False
370
371    def test_it_should_capture_the_output_of_the_command(self):
372        stdout = "stdout" * 5
373        stderr = "stderr" * 5
374
375        command = "produce_output.py"
376        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
377import sys
378
379sys.stdout.write("{0}")
380sys.stderr.write("{1}")
381
382'''.format(stdout, stderr))
383
384        def list_directory():
385            base_dir = self.ext_handler_instance.get_base_dir()
386            return [i for i in os.listdir(base_dir) if not i.endswith(AGENT_EVENT_FILE_EXTENSION)] # ignore telemetry files
387
388        files_before = list_directory()
389
390        output = self.ext_handler_instance.launch_command(command)
391
392        files_after = list_directory()
393
394        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))
395
396        self.assertListEqual(files_before, files_after, "Not all temporary files were deleted. File list: {0}".format(files_after))
397
398    def test_it_should_raise_an_exception_when_the_command_times_out(self):
399        extension_error_code = ExtensionErrorCodes.PluginHandlerScriptTimedout
400        stdout = "stdout" * 7
401        stderr = "stderr" * 7
402
403        # the signal file is used by the test command to indicate it has produced output
404        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")
405
406        # the test command produces some output then goes into an infinite loop
407        command = "produce_output_then_hang.py"
408        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
409import sys
410import time
411
412sys.stdout.write("{0}")
413sys.stdout.flush()
414
415sys.stderr.write("{1}")
416sys.stderr.flush()
417
418with open("{2}", "w") as file:
419    while True:
420        file.write(".")
421        time.sleep(1)
422
423'''.format(stdout, stderr, signal_file))
424
425        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep)
426        def sleep(seconds):
427            if not os.path.exists(signal_file):
428                sleep.original_sleep(seconds)
429        sleep.original_sleep = time.sleep
430
431        timeout = 60
432
433        start_time = time.time()
434
435        with patch("time.sleep", side_effect=sleep, autospec=True) as mock_sleep:  # pylint: disable=redefined-outer-name
436
437            with self.assertRaises(ExtensionError) as context_manager:
438                self.ext_handler_instance.launch_command(command, timeout=timeout, extension_error_code=extension_error_code)
439
440            # the command name and its output should be part of the message
441            message = str(context_manager.exception)
442            command_full_path = os.path.join(self.tmp_dir, command.lstrip(os.path.sep))
443            self.assertRegex(message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(command_full_path, LaunchCommandTestCase._output_regex(stdout, stderr)))
444
445            # the exception code should be as specified in the call to launch_command
446            self.assertEqual(context_manager.exception.code, extension_error_code)
447
448            # the timeout period should have elapsed
449            self.assertGreaterEqual(mock_sleep.call_count, timeout)
450
451            # The command should have been terminated.
452            # The /proc file system may still include the process when we do this check so we try a few times after a short delay; note that we
453            # are mocking sleep, so we need to use the original implementation.
454            terminated = False
455            i = 0
456            while not terminated and i < 4:
457                if not LaunchCommandTestCase._find_process(command):
458                    terminated = True
459                else:
460                    sleep.original_sleep(0.25)
461                i += 1
462
463            self.assertTrue(terminated, "The command was not terminated")
464
465        # as a check for the test itself, verify it completed in just a few seconds
466        self.assertLessEqual(time.time() - start_time, 5)
467
468    def test_it_should_raise_an_exception_when_the_command_fails(self):
469        extension_error_code = 2345
470        stdout = "stdout" * 3
471        stderr = "stderr" * 3
472        exit_code = 99
473
474        command = "fail.py"
475        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
476import sys
477
478sys.stdout.write("{0}")
479sys.stderr.write("{1}")
480exit({2})
481
482'''.format(stdout, stderr, exit_code))
483
484        # the output is captured as part of the exception message
485        with self.assertRaises(ExtensionError) as context_manager:
486            self.ext_handler_instance.launch_command(command, extension_error_code=extension_error_code)
487
488        message = str(context_manager.exception)
489        self.assertRegex(message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(exit_code, command, LaunchCommandTestCase._output_regex(stdout, stderr)))
490
491        self.assertEqual(context_manager.exception.code, extension_error_code)
492
493    def test_it_should_not_wait_for_child_process(self):
494        stdout = "stdout"
495        stderr = "stderr"
496
497        command = "start_child_process.py"
498        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
499import os
500import sys
501import time
502
503pid = os.fork()
504
505if pid == 0:
506    time.sleep(60)
507else:
508    sys.stdout.write("{0}")
509    sys.stderr.write("{1}")
510
511'''.format(stdout, stderr))
512
513        start_time = time.time()
514
515        output = self.ext_handler_instance.launch_command(command)
516
517        self.assertLessEqual(time.time() - start_time, 5)
518
519        # Also check that we capture the parent's output
520        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))
521
522    def test_it_should_capture_the_output_of_child_process(self):
523        parent_stdout = "PARENT STDOUT"
524        parent_stderr = "PARENT STDERR"
525        child_stdout = "CHILD STDOUT"
526        child_stderr = "CHILD STDERR"
527        more_parent_stdout = "MORE PARENT STDOUT"
528        more_parent_stderr = "MORE PARENT STDERR"
529
530        # the child process uses the signal file to indicate it has produced output
531        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")
532
533        command = "start_child_with_output.py"
534        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
535import os
536import sys
537import time
538
539sys.stdout.write("{0}")
540sys.stderr.write("{1}")
541
542pid = os.fork()
543
544if pid == 0:
545    sys.stdout.write("{2}")
546    sys.stderr.write("{3}")
547
548    open("{6}", "w").close()
549else:
550    sys.stdout.write("{4}")
551    sys.stderr.write("{5}")
552
553    while not os.path.exists("{6}"):
554        time.sleep(0.5)
555
556'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr, more_parent_stdout, more_parent_stderr, signal_file))
557
558        output = self.ext_handler_instance.launch_command(command)
559
560        self.assertIn(parent_stdout, output)
561        self.assertIn(parent_stderr, output)
562
563        self.assertIn(child_stdout, output)
564        self.assertIn(child_stderr, output)
565
566        self.assertIn(more_parent_stdout, output)
567        self.assertIn(more_parent_stderr, output)
568
569    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(self):
570        parent_stdout = "PARENT STDOUT"
571        parent_stderr = "PARENT STDERR"
572        child_stdout = "CHILD STDOUT"
573        child_stderr = "CHILD STDERR"
574
575        command = "start_child_that_fails.py"
576        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
577import os
578import sys
579import time
580
581pid = os.fork()
582
583if pid == 0:
584    sys.stdout.write("{0}")
585    sys.stderr.write("{1}")
586    exit(1)
587else:
588    sys.stdout.write("{2}")
589    sys.stderr.write("{3}")
590
591'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))
592
593        output = self.ext_handler_instance.launch_command(command)
594
595        self.assertIn(parent_stdout, output)
596        self.assertIn(parent_stderr, output)
597
598        self.assertIn(child_stdout, output)
599        self.assertIn(child_stderr, output)
600
601    def test_it_should_execute_commands_with_no_output(self):
602        # file used to verify the command completed successfully
603        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")
604
605        command = "create_file.py"
606        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
607open("{0}", "w").close()
608
609'''.format(signal_file))
610
611        output = self.ext_handler_instance.launch_command(command)
612
613        self.assertTrue(os.path.exists(signal_file))
614        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))
615
616    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(self):
617        # the test script redirects its output to this file
618        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
619        stdout = "STDOUT"
620        stderr = "STDERR"
621
622        # the test script mimics the redirection done by the Custom Script extension
623        command = "produce_output"
624        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
625exec &> {0}
626echo {1}
627>&2 echo {2}
628
629'''.format(command_output_file, stdout, stderr))
630
631        output = self.ext_handler_instance.launch_command(command)
632
633        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))
634
635        with open(command_output_file, "r") as command_output:
636            output = command_output.read()
637            self.assertEqual(output, "{0}\n{1}\n".format(stdout, stderr))
638
639    def test_it_should_truncate_the_command_output(self):
640        stdout = "STDOUT"
641        stderr = "STDERR"
642
643        command = "produce_long_output.py"
644        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
645import sys
646
647sys.stdout.write( "{0}" * {1})
648sys.stderr.write( "{2}" * {3})
649'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr, int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))
650
651        output = self.ext_handler_instance.launch_command(command)
652
653        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
654        self.assertIn(stdout, output)
655        self.assertIn(stderr, output)
656
657    def test_it_should_read_only_the_head_of_large_outputs(self):
658        command = "produce_long_output.py"
659        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
660import sys
661
662sys.stdout.write("O" * 5 * 1024 * 1024)
663sys.stderr.write("E" * 5 * 1024 * 1024)
664''')
665
666        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
667        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
668        # more than a few KB of data from the files used to capture stdout/stderr
669        with patch('azurelinuxagent.common.utils.extensionprocessutil.format_stdout_stderr', side_effect=format_stdout_stderr) as mock_format:
670            output = self.ext_handler_instance.launch_command(command)
671
672        self.assertGreaterEqual(len(output), 1024)
673        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
674
675        mock_format.assert_called_once()
676
677        args, kwargs = mock_format.call_args  # pylint: disable=unused-variable
678        stdout, stderr = args
679
680        self.assertGreaterEqual(len(stdout), 1024)
681        self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)
682
683        self.assertGreaterEqual(len(stderr), 1024)
684        self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)
685
686    def test_it_should_handle_errors_while_reading_the_command_output(self):
687        command = "produce_output.py"
688        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
689import sys
690
691sys.stdout.write("STDOUT")
692sys.stderr.write("STDERR")
693''')
694        # Mocking the call to file.read() is difficult, so instead we mock the call to_capture_process_output,
695        # which will call file.read() and we force stdout/stderr to be None; this will produce an exception when
696        # trying to use these files.
697        original_capture_process_output = read_output
698
699        def capture_process_output(stdout_file, stderr_file):  # pylint: disable=unused-argument
700            return original_capture_process_output(None, None)
701
702        with patch('azurelinuxagent.common.utils.extensionprocessutil.read_output', side_effect=capture_process_output):
703            output = self.ext_handler_instance.launch_command(command)
704
705        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)
706
707    def test_it_should_contain_all_helper_environment_variables(self):
708
709        wire_ip = str(uuid.uuid4())
710        ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=WireProtocol(wire_ip))
711
712        helper_env_vars = {ExtCommandEnvVariable.ExtensionSeqNumber: ext_handler_instance.get_seq_no(),
713                           ExtCommandEnvVariable.ExtensionPath: self.tmp_dir,
714                           ExtCommandEnvVariable.ExtensionVersion: ext_handler_instance.ext_handler.properties.version,
715                           ExtCommandEnvVariable.WireProtocolAddress: wire_ip}
716
717        command = """
718            printenv | grep -E '(%s)'
719        """ % '|'.join(helper_env_vars.keys())
720
721        test_file = 'printHelperEnvironments.sh'
722        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), test_file), command)
723
724        with patch("subprocess.Popen", wraps=subprocess.Popen) as patch_popen:
725            # Returning empty list for get_agent_supported_features_list_for_extensions as we have a separate test for it
726            with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
727                       return_value={}):
728                output = ext_handler_instance.launch_command(test_file)
729
730            args, kwagrs = patch_popen.call_args  # pylint: disable=unused-variable
731            without_os_env = dict((k, v) for (k, v) in kwagrs['env'].items() if k not in os.environ)
732
733            # This check will fail if any helper environment variables are added/removed later on
734            self.assertEqual(helper_env_vars, without_os_env)
735
736            # This check is checking if the expected values are set for the extension commands
737            for helper_var in helper_env_vars:
738                self.assertIn("%s=%s" % (helper_var, helper_env_vars[helper_var]), output)
739
740    def test_it_should_pass_supported_features_list_as_environment_variables(self):
741
742        class TestFeature(AgentSupportedFeature):
743
744            def __init__(self, name, version, supported):
745                super(TestFeature, self).__init__(name=name,
746                                                  version=version,
747                                                  supported=supported)
748
749        test_name = str(uuid.uuid4())
750        test_version = str(uuid.uuid4())
751
752        command = "check_env.py"
753        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
754import os
755import json
756import sys
757
758features = os.getenv("{0}")
759if not features:
760    print("{0} not found in environment")
761    sys.exit(0)
762l = json.loads(features)
763found = False
764for feature in l:
765    if feature['Key'] == "{1}" and feature['Value'] == "{2}":
766        found = True
767        break
768
769print("Found Feature %s: %s" % ("{1}", found))
770'''.format(ExtCommandEnvVariable.ExtensionSupportedFeatures, test_name, test_version))
771
772        # It should include all supported features and pass it as Environment Variable to extensions
773        test_supported_features = {test_name: TestFeature(name=test_name, version=test_version, supported=True)}
774        with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
775                   return_value=test_supported_features):
776            output = self.ext_handler_instance.launch_command(command)
777
778            self.assertIn("[stdout]\nFound Feature {0}: True".format(test_name), output, "Feature not found")
779
780        # It should not include the feature if feature not supported
781        test_supported_features = {
782            test_name: TestFeature(name=test_name, version=test_version, supported=False),
783            "testFeature": TestFeature(name="testFeature", version="1.2.1", supported=True)
784        }
785        with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
786                   return_value=test_supported_features):
787            output = self.ext_handler_instance.launch_command(command)
788
789            self.assertIn("[stdout]\nFound Feature {0}: False".format(test_name), output, "Feature wrongfully found")
790
791        # It should not include the SupportedFeatures Key in Environment variables if no features supported
792        test_supported_features = {test_name: TestFeature(name=test_name, version=test_version, supported=False)}
793        with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
794                   return_value=test_supported_features):
795            output = self.ext_handler_instance.launch_command(command)
796
797            self.assertIn(
798                "[stdout]\n{0} not found in environment".format(ExtCommandEnvVariable.ExtensionSupportedFeatures),
799                output, "Environment variable should not be found")
800