# Copyright 2020 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.6+ and Openssl 1.0+
#

import json
import os
import subprocess
import time
import uuid

from azurelinuxagent.common.agent_supported_feature import AgentSupportedFeature
from azurelinuxagent.common.cgroupconfigurator import CGroupConfigurator
from azurelinuxagent.common.event import AGENT_EVENT_FILE_EXTENSION, WALAEventOperation
from azurelinuxagent.common.exception import ExtensionError, ExtensionErrorCodes
from azurelinuxagent.common.protocol.restapi import ExtensionStatus, Extension, ExtHandler, ExtHandlerProperties
from azurelinuxagent.common.protocol.util import ProtocolUtil
from azurelinuxagent.common.protocol.wire import WireProtocol
from azurelinuxagent.common.utils import fileutil
from azurelinuxagent.common.utils.extensionprocessutil import TELEMETRY_MESSAGE_MAX_LEN, format_stdout_stderr, \
    read_output
from azurelinuxagent.ga.exthandlers import parse_ext_status, ExtHandlerInstance, ExtCommandEnvVariable, \
    ExtensionStatusError
from tests.protocol import mockwiredata
from tests.protocol.mocks import mock_wire_protocol
from tests.tools import AgentTestCase, patch, mock_sleep, clear_singleton_instances


class TestExtHandlers(AgentTestCase):
    """
    Tests for extension status parsing, extension sequence numbers, plugin settings
    version mismatches and truncation of the extension command execution log.
    """

    def setUp(self):
        super(TestExtHandlers, self).setUp()
        # Since ProtocolUtil is a singleton per thread, we need to clear it to ensure that the test cases do not
        # reuse a previous state
        clear_singleton_instances(ProtocolUtil)

    def test_parse_ext_status_should_raise_on_non_array(self):
        status = json.loads('''
            {{
                "status": {{
                    "status": "transitioning",
                    "operation": "Enabling Handler",
                    "code": 0,
                    "name": "Microsoft.Azure.RecoveryServices.SiteRecovery.Linux"
                }},
                "version": 1.0,
                "timestampUTC": "2020-01-14T15:04:43Z",
                "longText": "{0}"
            }}'''.format("*" * 5 * 1024))

        with self.assertRaises(ExtensionStatusError) as context_manager:
            parse_ext_status(ExtensionStatus(seq_no=0), status)
        error_message = str(context_manager.exception)
        self.assertIn("The extension status must be an array", error_message)
        self.assertTrue(0 < len(error_message) - 64 < 4096, "The error message should not be much longer than 4K characters: [{0}]".format(error_message))

    def test_parse_extension_status00(self):
        """
        Parse a status report for a successful execution of an extension.
        """
        s = '''[{
    "status": {
      "status": "success",
      "formattedMessage": {
        "lang": "en-US",
        "message": "Command is finished."
      },
      "operation": "Daemon",
      "code": "0",
      "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"
    },
    "version": "1.0",
    "timestampUTC": "2018-04-20T21:20:24Z"
  }
]'''
        ext_status = ExtensionStatus(seq_no=0)
        parse_ext_status(ext_status, json.loads(s))

        self.assertEqual('0', ext_status.code)
        self.assertEqual(None, ext_status.configurationAppliedTime)
        self.assertEqual('Command is finished.', ext_status.message)
        self.assertEqual('Daemon', ext_status.operation)
        self.assertEqual('success', ext_status.status)
        self.assertEqual(0, ext_status.sequenceNumber)
        self.assertEqual(0, len(ext_status.substatusList))

    def test_parse_extension_status01(self):
        """
        Parse a status report for a failed execution of an extension.

        The extension returned a bad status/status of failed.
        The agent should handle this gracefully, and convert all unknown
        status/status values into an error.
        """
        s = '''[{
    "status": {
      "status": "failed",
      "formattedMessage": {
        "lang": "en-US",
        "message": "Enable failed: Failed with error: commandToExecute is empty or invalid ..."
      },
      "operation": "Enable",
      "code": "0",
      "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"
    },
    "version": "1.0",
    "timestampUTC": "2018-04-20T20:50:22Z"
}]'''
        ext_status = ExtensionStatus(seq_no=0)
        parse_ext_status(ext_status, json.loads(s))

        self.assertEqual('0', ext_status.code)
        self.assertEqual(None, ext_status.configurationAppliedTime)
        self.assertEqual('Enable failed: Failed with error: commandToExecute is empty or invalid ...', ext_status.message)
        self.assertEqual('Enable', ext_status.operation)
        self.assertEqual('error', ext_status.status)
        self.assertEqual(0, ext_status.sequenceNumber)
        self.assertEqual(0, len(ext_status.substatusList))

    def test_parse_ext_status_should_parse_missing_substatus_as_empty(self):
        status = '''[{
        "status": {
            "status": "success",
            "formattedMessage": {
                "lang": "en-US",
                "message": "Command is finished."
            },
            "operation": "Enable",
            "code": "0",
            "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"
        },

        "version": "1.0",
        "timestampUTC": "2018-04-20T21:20:24Z"
    }
    ]'''

        extension_status = ExtensionStatus(seq_no=0)

        parse_ext_status(extension_status, json.loads(status))

        self.assertTrue(isinstance(extension_status.substatusList, list), 'substatus was not parsed correctly')
        self.assertEqual(0, len(extension_status.substatusList))

    def test_parse_ext_status_should_parse_null_substatus_as_empty(self):
        status = '''[{
        "status": {
            "status": "success",
            "formattedMessage": {
                "lang": "en-US",
                "message": "Command is finished."
            },
            "operation": "Enable",
            "code": "0",
            "name": "Microsoft.OSTCExtensions.CustomScriptForLinux",
            "substatus": null
        },

        "version": "1.0",
        "timestampUTC": "2018-04-20T21:20:24Z"
    }
    ]'''

        extension_status = ExtensionStatus(seq_no=0)

        parse_ext_status(extension_status, json.loads(status))

        self.assertTrue(isinstance(extension_status.substatusList, list), 'substatus was not parsed correctly')
        self.assertEqual(0, len(extension_status.substatusList))

    def test_parse_extension_status_with_empty_status(self):
        """
        Parse an empty status report ([]) and a missing one (None); in both cases the status
        fields should be left unset and only the sequence number should be populated.
        """

        # Validating empty status case
        s = '''[]'''
        ext_status = ExtensionStatus(seq_no=0)
        parse_ext_status(ext_status, json.loads(s))

        self.assertEqual(None, ext_status.code)
        self.assertEqual(None, ext_status.configurationAppliedTime)
        self.assertEqual(None, ext_status.message)
        self.assertEqual(None, ext_status.operation)
        self.assertEqual(None, ext_status.status)
        self.assertEqual(0, ext_status.sequenceNumber)
        self.assertEqual(0, len(ext_status.substatusList))

        # Validating None case
        ext_status = ExtensionStatus(seq_no=0)
        parse_ext_status(ext_status, None)

        self.assertEqual(None, ext_status.code)
        self.assertEqual(None, ext_status.configurationAppliedTime)
        self.assertEqual(None, ext_status.message)
        self.assertEqual(None, ext_status.operation)
        self.assertEqual(None, ext_status.status)
        self.assertEqual(0, ext_status.sequenceNumber)
        self.assertEqual(0, len(ext_status.substatusList))

    @patch('azurelinuxagent.common.event.EventLogger.add_event')
    @patch('azurelinuxagent.ga.exthandlers.ExtHandlerInstance._get_last_modified_seq_no_from_config_files')
    def assert_extension_sequence_number(self,
                                         patch_get_largest_seq,
                                         patch_add_event,
                                         goal_state_sequence_number,
                                         disk_sequence_number,
                                         expected_sequence_number):
        """
        Checks the sequence number and status file path returned by ExtHandlerInstance.get_status_file_path.

        The two mocks are injected by the @patch decorators (the bottom decorator supplies the first argument).
        Callers pass the remaining arguments by keyword:
          goal_state_sequence_number -- value assigned to Extension.sequenceNumber (may be a non-numeric string)
          disk_sequence_number       -- value returned by the mocked _get_last_modified_seq_no_from_config_files
          expected_sequence_number   -- sequence number get_status_file_path should return

        When the goal state number parses as an int and differs from the disk number, exactly one
        SequenceNumberMismatch event is expected; otherwise no event should be reported.
        """
        ext = Extension()
        ext.sequenceNumber = goal_state_sequence_number
        patch_get_largest_seq.return_value = disk_sequence_number

        ext_handler_props = ExtHandlerProperties()
        ext_handler_props.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_props

        instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=None)
        seq, path = instance.get_status_file_path(ext)

        try:
            gs_seq_int = int(goal_state_sequence_number)
            gs_int = True
        except ValueError:
            gs_int = False

        if gs_int and gs_seq_int != disk_sequence_number:
            self.assertEqual(1, patch_add_event.call_count)
            args, kw_args = patch_add_event.call_args  # pylint: disable=unused-variable
            self.assertEqual('SequenceNumberMismatch', kw_args['op'])
            self.assertEqual(False, kw_args['is_success'])
            self.assertEqual('Goal state: {0}, disk: {1}'
                             .format(gs_seq_int, disk_sequence_number),
                             kw_args['message'])
        else:
            self.assertEqual(0, patch_add_event.call_count)

        self.assertEqual(expected_sequence_number, seq)
        if seq > -1:
            self.assertTrue(path.endswith('/foo-1.2.3/status/{0}.status'.format(expected_sequence_number)))
        else:
            self.assertIsNone(path)

    def test_extension_sequence_number(self):
        self.assert_extension_sequence_number(goal_state_sequence_number="12",  # pylint: disable=no-value-for-parameter
                                              disk_sequence_number=366,
                                              expected_sequence_number=12)

        self.assert_extension_sequence_number(goal_state_sequence_number=" 12 ",  # pylint: disable=no-value-for-parameter
                                              disk_sequence_number=366,
                                              expected_sequence_number=12)

        self.assert_extension_sequence_number(goal_state_sequence_number=" foo",  # pylint: disable=no-value-for-parameter
                                              disk_sequence_number=3,
                                              expected_sequence_number=3)

        self.assert_extension_sequence_number(goal_state_sequence_number="-1",  # pylint: disable=no-value-for-parameter
                                              disk_sequence_number=3,
                                              expected_sequence_number=-1)

    def test_it_should_report_error_if_plugin_settings_version_mismatch(self):
        with mock_wire_protocol(mockwiredata.DATA_FILE_PLUGIN_SETTINGS_MISMATCH) as protocol:
            with patch("azurelinuxagent.common.protocol.goal_state.add_event") as mock_add_event:
                # Forcing update of GoalState to allow the ExtConfig to report an event
                protocol.mock_wire_data.set_incarnation(2)
                protocol.client.update_goal_state()
                plugin_setting_mismatch_calls = [kw for _, kw in mock_add_event.call_args_list if
                                                 kw['op'] == WALAEventOperation.PluginSettingsVersionMismatch]
                self.assertEqual(1, len(plugin_setting_mismatch_calls),
                                 "PluginSettingsMismatch event should be reported once")
                self.assertIn('ExtHandler PluginSettings Version Mismatch! Expected PluginSettings version: 1.0.0 for Handler: OSTCExtensions.ExampleHandlerLinux'
                              , plugin_setting_mismatch_calls[0]['message'],
                              "Invalid error message with incomplete data detected for PluginSettingsVersionMismatch")
                self.assertTrue("1.0.2" in plugin_setting_mismatch_calls[0]['message'] and "1.0.1" in plugin_setting_mismatch_calls[0]['message'],
                                "Error message should contain the incorrect versions")
                self.assertFalse(plugin_setting_mismatch_calls[0]['is_success'], "The event should be false")

    @patch("azurelinuxagent.common.conf.get_ext_log_dir")
    def test_command_extension_log_truncates_correctly(self, mock_log_dir):
        log_dir_path = os.path.join(self.tmp_dir, "log_directory")
        mock_log_dir.return_value = log_dir_path

        ext_handler_props = ExtHandlerProperties()
        ext_handler_props.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_props

        first_line = "This is the first line!"
        second_line = "This is the second line."
        old_logfile_contents = "{first_line}\n{second_line}\n".format(first_line=first_line, second_line=second_line)

        log_file_path = os.path.join(log_dir_path, "foo", "CommandExecution.log")

        fileutil.mkdir(os.path.join(log_dir_path, "foo"), mode=0o755)
        with open(log_file_path, "a") as log_file:
            log_file.write(old_logfile_contents)

        # Note the precedence: the limit is len(first_line) + (len(second_line) // 2), which is smaller than
        # the existing log contents but larger than the second line on its own.
        _ = ExtHandlerInstance(ext_handler=ext_handler, protocol=None,
                               execution_log_max_size=(len(first_line)+len(second_line)//2))

        with open(log_file_path) as truncated_log_file:
            self.assertEqual(truncated_log_file.read(), "{second_line}\n".format(second_line=second_line))


class LaunchCommandTestCase(AgentTestCase):
    """
    Test cases for launch_command
    """

    def setUp(self):
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        self.ext_handler = ExtHandler(name='foo')
        self.ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=WireProtocol("1.2.3.4"))

        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

        # Every time.sleep() in these tests becomes a short sleep, regardless of the requested duration
        self.mock_sleep = patch("time.sleep", lambda *_: mock_sleep(0.01))
        self.mock_sleep.start()

        # Remember the cgroups state so tearDown can restore it
        self.cgroups_enabled = CGroupConfigurator.get_instance().enabled()
        CGroupConfigurator.get_instance().disable()

    def tearDown(self):
        if self.cgroups_enabled:
            CGroupConfigurator.get_instance().enable()
        else:
            CGroupConfigurator.get_instance().disable()

        self.mock_get_log_dir.stop()
        self.mock_get_base_dir.stop()
        self.mock_sleep.stop()

        AgentTestCase.tearDown(self)

    @staticmethod
    def _output_regex(stdout, stderr):
        """
        Returns a regex matching the "[stdout] ... [stderr] ..." layout used in the output of launch_command
        (and in the messages of the exceptions it raises).
        """
        return r"\[stdout\]\s+{0}\s+\[stderr\]\s+{1}".format(stdout, stderr)

    @staticmethod
    def _find_process(command):
        """
        Returns True if the command line of any process listed under /proc contains the given string.
        """
        for pid in os.listdir('/proc'):
            if not pid.isdigit():
                continue
            try:
                # cmdline holds the raw argv bytes, which are not necessarily valid text in the current locale;
                # read it as binary and decode leniently so that an unrelated process cannot make this check fail.
                with open(os.path.join('/proc', pid, 'cmdline'), 'rb') as cmdline_file:
                    cmdline = cmdline_file.read().decode('utf-8', 'replace')
            except IOError:  # proc has already terminated
                continue
            if command in cmdline:
                return True
        return False

    def test_it_should_capture_the_output_of_the_command(self):
        stdout = "stdout" * 5
        stderr = "stderr" * 5

        command = "produce_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")

'''.format(stdout, stderr))

        def list_directory():
            base_dir = self.ext_handler_instance.get_base_dir()
            return [i for i in os.listdir(base_dir) if not i.endswith(AGENT_EVENT_FILE_EXTENSION)]  # ignore telemetry files

        files_before = list_directory()

        output = self.ext_handler_instance.launch_command(command)

        files_after = list_directory()

        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))

        self.assertListEqual(files_before, files_after, "Not all temporary files were deleted. File list: {0}".format(files_after))

    def test_it_should_raise_an_exception_when_the_command_times_out(self):
        extension_error_code = ExtensionErrorCodes.PluginHandlerScriptTimedout
        stdout = "stdout" * 7
        stderr = "stderr" * 7

        # the signal file is used by the test command to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the test command produces some output then goes into an infinite loop
        command = "produce_output_then_hang.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys
import time

sys.stdout.write("{0}")
sys.stdout.flush()

sys.stderr.write("{1}")
sys.stderr.flush()

with open("{2}", "w") as file:
    while True:
        file.write(".")
        time.sleep(1)

'''.format(stdout, stderr, signal_file))

        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep)
        def sleep(seconds):
            if not os.path.exists(signal_file):
                sleep.original_sleep(seconds)
        # NOTE: setUp has already patched time.sleep, so this captures the short sleep installed there,
        # not the real time.sleep.
        sleep.original_sleep = time.sleep

        timeout = 60

        start_time = time.time()

        with patch("time.sleep", side_effect=sleep, autospec=True) as mock_sleep:  # pylint: disable=redefined-outer-name

            with self.assertRaises(ExtensionError) as context_manager:
                self.ext_handler_instance.launch_command(command, timeout=timeout, extension_error_code=extension_error_code)

            # the command name and its output should be part of the message
            message = str(context_manager.exception)
            command_full_path = os.path.join(self.tmp_dir, command.lstrip(os.path.sep))
            self.assertRegex(message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(command_full_path, LaunchCommandTestCase._output_regex(stdout, stderr)))

            # the exception code should be as specified in the call to launch_command
            self.assertEqual(context_manager.exception.code, extension_error_code)

            # the timeout period should have elapsed
            self.assertGreaterEqual(mock_sleep.call_count, timeout)

            # The command should have been terminated.
            # The /proc file system may still include the process when we do this check so we try a few times after a short delay;
            # time.sleep is mocked by this test, so we call the implementation captured before the mock was installed.
            terminated = False
            i = 0
            while not terminated and i < 4:
                if not LaunchCommandTestCase._find_process(command):
                    terminated = True
                else:
                    sleep.original_sleep(0.25)
                i += 1

            self.assertTrue(terminated, "The command was not terminated")

        # as a check for the test itself, verify it completed in just a few seconds
        self.assertLessEqual(time.time() - start_time, 5)

    def test_it_should_raise_an_exception_when_the_command_fails(self):
        extension_error_code = 2345
        stdout = "stdout" * 3
        stderr = "stderr" * 3
        exit_code = 99

        command = "fail.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")
exit({2})

'''.format(stdout, stderr, exit_code))

        # the output is captured as part of the exception message
        with self.assertRaises(ExtensionError) as context_manager:
            self.ext_handler_instance.launch_command(command, extension_error_code=extension_error_code)

        message = str(context_manager.exception)
        self.assertRegex(message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(exit_code, command, LaunchCommandTestCase._output_regex(stdout, stderr)))

        self.assertEqual(context_manager.exception.code, extension_error_code)

    def test_it_should_not_wait_for_child_process(self):
        stdout = "stdout"
        stderr = "stderr"

        command = "start_child_process.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    time.sleep(60)
else:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")

'''.format(stdout, stderr))

        start_time = time.time()

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(time.time() - start_time, 5)

        # Also check that we capture the parent's output
        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))

    def test_it_should_capture_the_output_of_child_process(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"
        more_parent_stdout = "MORE PARENT STDOUT"
        more_parent_stderr = "MORE PARENT STDERR"

        # the child process uses the signal file to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = "start_child_with_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import sys
import time

sys.stdout.write("{0}")
sys.stderr.write("{1}")

pid = os.fork()

if pid == 0:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

    open("{6}", "w").close()
else:
    sys.stdout.write("{4}")
    sys.stderr.write("{5}")

    while not os.path.exists("{6}"):
        time.sleep(0.5)

'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr, more_parent_stdout, more_parent_stderr, signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

        self.assertIn(more_parent_stdout, output)
        self.assertIn(more_parent_stderr, output)

    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"

        command = "start_child_that_fails.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    exit(1)
else:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

    def test_it_should_execute_commands_with_no_output(self):
        # file used to verify the command completed successfully
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = "create_file.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
open("{0}", "w").close()

'''.format(signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertTrue(os.path.exists(signal_file))
        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(self):
        # the test script redirects its output to this file
        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
        stdout = "STDOUT"
        stderr = "STDERR"

        # the test script mimics the redirection done by the Custom Script extension
        command = "produce_output"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
exec &> {0}
echo {1}
>&2 echo {2}

'''.format(command_output_file, stdout, stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

        with open(command_output_file, "r") as command_output:
            output = command_output.read()
            self.assertEqual(output, "{0}\n{1}\n".format(stdout, stderr))

    def test_it_should_truncate_the_command_output(self):
        stdout = "STDOUT"
        stderr = "STDERR"

        command = "produce_long_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write( "{0}" * {1})
sys.stderr.write( "{2}" * {3})
'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr, int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
        self.assertIn(stdout, output)
        self.assertIn(stderr, output)

    def test_it_should_read_only_the_head_of_large_outputs(self):
        command = "produce_long_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("O" * 5 * 1024 * 1024)
sys.stderr.write("E" * 5 * 1024 * 1024)
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
        # more than a few KB of data from the files used to capture stdout/stderr
        with patch('azurelinuxagent.common.utils.extensionprocessutil.format_stdout_stderr', side_effect=format_stdout_stderr) as mock_format:
            output = self.ext_handler_instance.launch_command(command)

            self.assertGreaterEqual(len(output), 1024)
            self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)

            mock_format.assert_called_once()

            args, kwargs = mock_format.call_args  # pylint: disable=unused-variable
            stdout, stderr = args

            self.assertGreaterEqual(len(stdout), 1024)
            self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)

            self.assertGreaterEqual(len(stderr), 1024)
            self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)

    def test_it_should_handle_errors_while_reading_the_command_output(self):
        command = "produce_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("STDOUT")
sys.stderr.write("STDERR")
''')
        # Mocking the call to file.read() is difficult, so instead we mock the call to read_output,
        # which will call file.read() and we force stdout/stderr to be None; this will produce an exception when
        # trying to use these files.
        original_capture_process_output = read_output

        def capture_process_output(stdout_file, stderr_file):  # pylint: disable=unused-argument
            return original_capture_process_output(None, None)

        with patch('azurelinuxagent.common.utils.extensionprocessutil.read_output', side_effect=capture_process_output):
            output = self.ext_handler_instance.launch_command(command)

        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)

    def test_it_should_contain_all_helper_environment_variables(self):

        wire_ip = str(uuid.uuid4())
        ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=WireProtocol(wire_ip))

        helper_env_vars = {ExtCommandEnvVariable.ExtensionSeqNumber: ext_handler_instance.get_seq_no(),
                           ExtCommandEnvVariable.ExtensionPath: self.tmp_dir,
                           ExtCommandEnvVariable.ExtensionVersion: ext_handler_instance.ext_handler.properties.version,
                           ExtCommandEnvVariable.WireProtocolAddress: wire_ip}

        command = """
            printenv | grep -E '(%s)'
        """ % '|'.join(helper_env_vars.keys())

        test_file = 'printHelperEnvironments.sh'
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), test_file), command)

        with patch("subprocess.Popen", wraps=subprocess.Popen) as patch_popen:
            # Returning an empty dict for get_agent_supported_features_list_for_extensions as we have a separate test for it
            with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
                       return_value={}):
                output = ext_handler_instance.launch_command(test_file)

            args, kwargs = patch_popen.call_args  # pylint: disable=unused-variable
            # Keep only the variables the agent added on top of its own environment
            without_os_env = dict((k, v) for (k, v) in kwargs['env'].items() if k not in os.environ)

            # This check will fail if any helper environment variables are added/removed later on
            self.assertEqual(helper_env_vars, without_os_env)

            # This check is checking if the expected values are set for the extension commands
            for helper_var in helper_env_vars:
                self.assertIn("%s=%s" % (helper_var, helper_env_vars[helper_var]), output)

    def test_it_should_pass_supported_features_list_as_environment_variables(self):

        class TestFeature(AgentSupportedFeature):

            def __init__(self, name, version, supported):
                super(TestFeature, self).__init__(name=name,
                                                  version=version,
                                                  supported=supported)

        test_name = str(uuid.uuid4())
        test_version = str(uuid.uuid4())

        command = "check_env.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import json
import sys

features = os.getenv("{0}")
if not features:
    print("{0} not found in environment")
    sys.exit(0)
l = json.loads(features)
found = False
for feature in l:
    if feature['Key'] == "{1}" and feature['Value'] == "{2}":
        found = True
        break

print("Found Feature %s: %s" % ("{1}", found))
'''.format(ExtCommandEnvVariable.ExtensionSupportedFeatures, test_name, test_version))

        # It should include all supported features and pass it as Environment Variable to extensions
        test_supported_features = {test_name: TestFeature(name=test_name, version=test_version, supported=True)}
        with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
                   return_value=test_supported_features):
            output = self.ext_handler_instance.launch_command(command)

            self.assertIn("[stdout]\nFound Feature {0}: True".format(test_name), output, "Feature not found")

        # It should not include the feature if feature not supported
        test_supported_features = {
            test_name: TestFeature(name=test_name, version=test_version, supported=False),
            "testFeature": TestFeature(name="testFeature", version="1.2.1", supported=True)
        }
        with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
                   return_value=test_supported_features):
            output = self.ext_handler_instance.launch_command(command)

            self.assertIn("[stdout]\nFound Feature {0}: False".format(test_name), output, "Feature wrongfully found")

        # It should not include the SupportedFeatures Key in Environment variables if no features supported
        test_supported_features = {test_name: TestFeature(name=test_name, version=test_version, supported=False)}
        with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
                   return_value=test_supported_features):
            output = self.ext_handler_instance.launch_command(command)

            self.assertIn(
                "[stdout]\n{0} not found in environment".format(ExtCommandEnvVariable.ExtensionSupportedFeatures),
                output, "Environment variable should not be found")