#!/usr/bin/python
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unittest for script_retriever.py module."""

import subprocess

from google_compute_engine.compat import urlerror
from google_compute_engine.metadata_scripts import script_retriever
from google_compute_engine.test_compat import builtin
from google_compute_engine.test_compat import mock
from google_compute_engine.test_compat import unittest


class ScriptRetrieverTest(unittest.TestCase):

  def setUp(self):
    self.script_type = 'test'
    self.dest_dir = '/tmp'
    self.dest = '/tmp/file'
    self.mock_logger = mock.Mock()
    self.mock_watcher = mock.Mock()
    self.retriever = script_retriever.ScriptRetriever(
        self.mock_logger, self.script_type)

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlrequest.Request')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlrequest.urlopen')
  def testDownloadAuthUrl(self, mock_urlopen, mock_request, mock_tempfile):
    """Downloads with a cached token and writes the decoded body to disk."""
    auth_url = 'https://storage.googleapis.com/fake/url'
    mock_tempfile.return_value = mock_tempfile
    mock_tempfile.name = self.dest
    self.retriever.token = 'bar'
    # Configure the response body before the call so the content written to
    # the file can be checked against a concrete value, not a mock object.
    mock_urlopen.return_value.read.return_value = b'foo'

    mock_open = mock.mock_open()
    with mock.patch('%s.open' % builtin, mock_open):
      self.assertEqual(
          self.retriever._DownloadAuthUrl(auth_url, self.dest_dir), self.dest)

    mock_tempfile.assert_called_once_with(dir=self.dest_dir, delete=False)
    mock_tempfile.close.assert_called_once_with()

    self.mock_logger.info.assert_called_once_with(
        mock.ANY, auth_url, self.dest)
    mock_request.assert_called_with(auth_url)
    # Read return_value instead of calling the mock, so no spurious call is
    # recorded on the patched Request constructor.
    mocked_request = mock_request.return_value
    mocked_request.add_unredirected_header.assert_called_with(
        'Authorization', 'bar')
    mock_urlopen.assert_called_with(mocked_request)
    self.mock_logger.warning.assert_not_called()

    mock_open.assert_called_once_with(self.dest, 'w')
    # mock_open() hands out the same handle object as its return_value.
    mock_open.return_value.write.assert_called_once_with('foo')

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlrequest.Request')
  @mock.patch('google_compute_engine.metadata_watcher.MetadataWatcher.GetMetadata')
  def testDownloadAuthUrlExceptionAndToken(
      self, mock_get_metadata, mock_request, mock_tempfile):
    """Fetches and caches a token, then returns None when the request fails."""
    auth_url = 'https://storage.googleapis.com/fake/url'
    metadata_prefix = 'http://metadata.google.internal/computeMetadata/v1/'
    token_url = metadata_prefix + 'instance/service-accounts/default/token'
    mock_tempfile.return_value = mock_tempfile
    mock_tempfile.name = self.dest
    self.retriever.token = None

    mock_get_metadata.return_value = {
        'token_type': 'foo', 'access_token': 'bar'}
    # Building the request raises, so the download itself fails.
    mock_request.side_effect = urlerror.URLError('Error.')

    self.assertIsNone(self.retriever._DownloadAuthUrl(auth_url, self.dest_dir))

    # GetMetadata includes a prefix, so remove it.
    stripped_url = token_url.replace(metadata_prefix, '')
    mock_get_metadata.assert_called_once_with(
        stripped_url, recursive=False, retry=False)

    # The token is cached as '<token_type> <access_token>'.
    self.assertEqual(self.retriever.token, 'foo bar')

    self.mock_logger.info.assert_called_once_with(
        mock.ANY, auth_url, self.dest)
    self.assertEqual(self.mock_logger.warning.call_count, 1)

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.ScriptRetriever._DownloadUrl')
  @mock.patch('google_compute_engine.metadata_watcher.MetadataWatcher.GetMetadata')
  def testDownloadAuthUrlFallback(
      self, mock_get_metadata, mock_download_url, mock_tempfile):
    """Falls back to an unauthenticated download when no token is returned."""
    auth_url = 'https://storage.googleapis.com/fake/url'
    metadata_prefix = 'http://metadata.google.internal/computeMetadata/v1/'
    token_url = metadata_prefix + 'instance/service-accounts/default/token'
    mock_tempfile.return_value = mock_tempfile
    mock_tempfile.name = self.dest
    self.retriever.token = None

    mock_get_metadata.return_value = None
    mock_download_url.return_value = None

    self.assertIsNone(self.retriever._DownloadAuthUrl(auth_url, self.dest_dir))

    # GetMetadata includes a prefix, so remove it.
    stripped_url = token_url.replace(metadata_prefix, '')
    mock_get_metadata.assert_called_once_with(
        stripped_url, recursive=False, retry=False)
    mock_download_url.assert_called_once_with(auth_url, self.dest_dir)

    # No token was obtained, so nothing is cached.
    self.assertIsNone(self.retriever.token)

    expected_calls = [
        mock.call(mock.ANY, auth_url, self.dest),
        mock.call(mock.ANY),
    ]
    self.assertEqual(self.mock_logger.info.mock_calls, expected_calls)

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve')
  def testDownloadUrl(self, mock_retrieve, mock_tempfile):
    """Retrieves an unauthenticated URL into a new temporary file."""
    url = 'http://www.google.com/fake/url'
    mock_tempfile.return_value = mock_tempfile
    mock_tempfile.name = self.dest
    self.assertEqual(
        self.retriever._DownloadUrl(url, self.dest_dir), self.dest)
    mock_tempfile.assert_called_once_with(dir=self.dest_dir, delete=False)
    mock_tempfile.close.assert_called_once_with()
    self.mock_logger.info.assert_called_once_with(mock.ANY, url, self.dest)
    mock_retrieve.assert_called_once_with(url, self.dest)
    self.mock_logger.warning.assert_not_called()

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.time')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve')
  def testDownloadUrlProcessError(self, mock_retrieve, mock_tempfile, mock_time):
    """Returns None and warns once when every retry attempt times out."""
    # mock_time is unused; `time` is patched presumably so retry delays do not
    # actually sleep.
    url = 'http://www.google.com/fake/url'
    mock_tempfile.return_value = mock_tempfile
    mock_tempfile.name = self.dest
    mock_success = mock.Mock()
    mock_success.getcode.return_value = script_retriever.httpclient.OK
    # Success after 3 timeouts. Since max_retry = 3, the final result is fail.
    mock_retrieve.side_effect = [
        script_retriever.socket.timeout(),
        script_retriever.socket.timeout(),
        script_retriever.socket.timeout(),
        mock_success,
    ]
    self.assertIsNone(self.retriever._DownloadUrl(url, self.dest_dir))
    self.assertEqual(self.mock_logger.warning.call_count, 1)

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.time')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve')
  def testDownloadUrlWithRetry(self, mock_retrieve, mock_tempfile, mock_time):
    """Succeeds when a timeout is followed by a success within the limit."""
    url = 'http://www.google.com/fake/url'
    mock_tempfile.return_value = mock_tempfile
    mock_tempfile.name = self.dest
    mock_success = mock.Mock()
    mock_success.getcode.return_value = script_retriever.httpclient.OK
    # Success after 2 timeouts. Since max_retry = 3, the final result is
    # success.
    mock_retrieve.side_effect = [
        script_retriever.socket.timeout(),
        script_retriever.socket.timeout(),
        mock_success,
    ]
    self.assertIsNotNone(self.retriever._DownloadUrl(url, self.dest_dir))

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve')
  def testDownloadUrlException(self, mock_retrieve, mock_tempfile):
    """Returns None and warns once when the retrieval raises."""
    url = 'http://www.google.com/fake/url'
    mock_tempfile.return_value = mock_tempfile
    mock_tempfile.name = self.dest
    mock_retrieve.side_effect = Exception('Error.')
    self.assertIsNone(self.retriever._DownloadUrl(url, self.dest_dir))
    self.assertEqual(self.mock_logger.warning.call_count, 1)

  def _CreateUrls(self, bucket, obj, gs_match=True):
    """Creates a URL for each of the supported Google Storage URL formats.

    Args:
      bucket: string, the Google Storage bucket name.
      obj: string, the object name in the bucket.
      gs_match: bool, True if the bucket and object names are valid.

    Returns:
      (list, dict):
        list, the URLs expected to be downloaded without authentication
          (empty when gs_match is True).
        dict, a Google Storage URL mapped to the expected 'gs://' format.
    """
    gs_url = 'gs://%s/%s' % (bucket, obj)
    gs_urls = {gs_url: gs_url}
    url_formats = [
        'http://%s.storage.googleapis.com/%s',
        'https://%s.storage.googleapis.com/%s',
        'http://storage.googleapis.com/%s/%s',
        'https://storage.googleapis.com/%s/%s',
        'http://commondatastorage.googleapis.com/%s/%s',
        'https://commondatastorage.googleapis.com/%s/%s',
    ]
    url_formats = [url % (bucket, obj) for url in url_formats]
    if gs_match:
      gs_urls.update(dict((url, gs_url) for url in url_formats))
      return ([], gs_urls)
    else:
      return (url_formats, gs_urls)

  def testDownloadScript(self):
    """Routes Google Storage URLs to the authenticated downloader."""
    mock_auth_download = mock.Mock()
    self.retriever._DownloadAuthUrl = mock_auth_download
    mock_download = mock.Mock()
    self.retriever._DownloadUrl = mock_download
    download_urls = []
    download_gs_urls = {}

    # (bucket, object, whether the pair forms a valid Google Storage URL).
    component_urls = [
        ('@#$%^', '\n\n\n\n', False),
        ('///////', '///////', False),
        ('Abc', 'xyz', False),
        (' abc', 'xyz', False),
        ('abc', 'xyz?', False),
        ('abc', 'xyz*', False),
        ('', 'xyz', False),
        ('a', 'xyz', False),
        ('abc', '', False),
        ('hello', 'world', True),
        ('hello', 'world!', True),
        ('hello', 'world !', True),
        ('hello', 'w o r l d ', True),
        ('hello', 'w\no\nr\nl\nd ', True),
        ('123_hello', '1!@#$%^', True),
        ('123456', 'hello.world', True),
    ]

    for bucket, obj, gs_match in component_urls:
      urls, gs_urls = self._CreateUrls(bucket, obj, gs_match=gs_match)
      download_urls.extend(urls)
      download_gs_urls.update(gs_urls)

    # All Google Storage URLs are downloaded with an authentication token.
    # 'gs://' URLs are rewritten to the storage.googleapis.com HTTPS form.
    for gs_url in download_gs_urls.values():
      mock_download.reset_mock()
      mock_auth_download.reset_mock()
      self.retriever._DownloadScript(gs_url, self.dest_dir)
      new_gs_url = gs_url.replace('gs://', 'https://storage.googleapis.com/')
      mock_auth_download.assert_called_once_with(new_gs_url, self.dest_dir)
      mock_download.assert_not_called()

    # URLs that do not match a Google Storage format are downloaded directly.
    for url in download_urls:
      mock_download.reset_mock()
      self.retriever._DownloadScript(url, self.dest_dir)
      mock_download.assert_called_once_with(url, self.dest_dir)

    # HTTP(S) Google Storage URLs are passed to the authenticated downloader
    # unchanged, even when that download returns None.
    for url in download_gs_urls:
      if url.startswith('gs://'):
        continue
      mock_auth_download.reset_mock()
      mock_auth_download.return_value = None
      mock_download.reset_mock()
      self.retriever._DownloadScript(url, self.dest_dir)
      mock_auth_download.assert_called_once_with(url, self.dest_dir)

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  def testGetAttributeScripts(self, mock_tempfile):
    """Saves inline scripts to files and downloads script URLs."""
    script = 'echo Hello World.\n'
    script_dest = '/tmp/script'
    script_url = 'gs://fake/url'
    script_url_dest = '/tmp/script_url'
    attribute_data = {
        '%s-script' % self.script_type: '\n%s' % script,
        '%s-script-url' % self.script_type: script_url,
    }
    expected_data = {
        '%s-script' % self.script_type: script_dest,
        '%s-script-url' % self.script_type: script_url_dest,
    }
    # Mock saving a script to a file.
    mock_dest = mock.Mock()
    mock_dest.name = script_dest
    mock_tempfile.__enter__.return_value = mock_dest
    mock_tempfile.return_value = mock_tempfile
    # Mock downloading a script from a URL.
    mock_download = mock.Mock()
    mock_download.return_value = script_url_dest
    self.retriever._DownloadScript = mock_download

    self.assertEqual(
        self.retriever._GetAttributeScripts(attribute_data, self.dest_dir),
        expected_data)
    self.assertEqual(self.mock_logger.info.call_count, 2)
    # The leading newline of the inline script is not written.
    mock_dest.write.assert_called_once_with(script)
    mock_download.assert_called_once_with(script_url, self.dest_dir)

  def testGetAttributeScriptsNone(self):
    """Returns an empty mapping and logs nothing for empty attributes."""
    attribute_data = {}
    expected_data = {}
    self.assertEqual(
        self.retriever._GetAttributeScripts(attribute_data, self.dest_dir),
        expected_data)
    self.mock_logger.info.assert_not_called()

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  def testGetScripts(self, mock_tempfile):
    """Prefers instance attributes over project attributes."""
    script_dest = '/tmp/script'
    script_url_dest = '/tmp/script_url'
    metadata = {
        'instance': {
            'attributes': {
                '%s-script' % self.script_type: 'a',
                '%s-script-url' % self.script_type: 'b',
            },
        },
        'project': {
            'attributes': {
                '%s-script' % self.script_type: 'c',
                '%s-script-url' % self.script_type: 'd',
            },
        },
    }
    expected_data = {
        '%s-script' % self.script_type: script_dest,
        '%s-script-url' % self.script_type: script_url_dest,
    }
    self.mock_watcher.GetMetadata.return_value = metadata
    self.retriever.watcher = self.mock_watcher
    # Mock saving a script to a file.
    mock_dest = mock.Mock()
    mock_dest.name = script_dest
    mock_tempfile.__enter__.return_value = mock_dest
    mock_tempfile.return_value = mock_tempfile
    # Mock downloading a script from a URL.
    mock_download = mock.Mock()
    mock_download.return_value = script_url_dest
    self.retriever._DownloadScript = mock_download

    self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data)
    self.assertEqual(self.mock_logger.info.call_count, 2)
    self.assertEqual(self.mock_logger.warning.call_count, 0)
    mock_dest.write.assert_called_once_with('a')
    mock_download.assert_called_once_with('b', self.dest_dir)

  def testGetScriptsNone(self):
    """Returns an empty mapping when both attribute sets are None."""
    metadata = {
        'instance': {
            'attributes': None,
        },
        'project': {
            'attributes': None,
        },
    }
    expected_data = {}
    self.mock_watcher.GetMetadata.return_value = metadata
    self.retriever.watcher = self.mock_watcher
    self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data)
    self.mock_logger.info.assert_not_called()

  def testGetScriptsNoMetadata(self):
    """Returns an empty mapping and warns twice when metadata is missing."""
    metadata = None
    expected_data = {}
    self.mock_watcher.GetMetadata.return_value = metadata
    self.retriever.watcher = self.mock_watcher
    self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data)
    self.mock_logger.info.assert_not_called()
    self.assertEqual(self.mock_logger.warning.call_count, 2)

  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
  def testGetScriptsFailed(self, mock_tempfile):
    """Maps a failed script URL download to None and warns once."""
    script_dest = '/tmp/script'
    script_url_dest = None
    metadata = {
        'instance': {
            'attributes': {
                '%s-script' % self.script_type: 'a',
                '%s-script-url' % self.script_type: 'b',
            },
        },
        'project': {
            'attributes': {
                '%s-script' % self.script_type: 'c',
                '%s-script-url' % self.script_type: 'd',
            },
        },
    }
    expected_data = {
        '%s-script' % self.script_type: script_dest,
        '%s-script-url' % self.script_type: script_url_dest,
    }
    self.mock_watcher.GetMetadata.return_value = metadata
    self.retriever.watcher = self.mock_watcher
    # Mock saving a script to a file.
    mock_dest = mock.Mock()
    mock_dest.name = script_dest
    mock_tempfile.__enter__.return_value = mock_dest
    mock_tempfile.return_value = mock_tempfile
    # Mock a failed download of a script from a URL.
    mock_download = mock.Mock()
    mock_download.return_value = None
    self.retriever._DownloadScript = mock_download

    self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data)
    self.assertEqual(self.mock_logger.info.call_count, 2)
    self.assertEqual(self.mock_logger.warning.call_count, 1)


if __name__ == '__main__':
  unittest.main()