1#!/usr/bin/python
2# Copyright 2016 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Unittest for script_retriever.py module."""
17
18import subprocess
19
20from google_compute_engine.compat import urlerror
21from google_compute_engine.metadata_scripts import script_retriever
22from google_compute_engine.test_compat import builtin
23from google_compute_engine.test_compat import mock
24from google_compute_engine.test_compat import unittest
25
26
27class ScriptRetrieverTest(unittest.TestCase):
28
29  def setUp(self):
30    self.script_type = 'test'
31    self.dest_dir = '/tmp'
32    self.dest = '/tmp/file'
33    self.mock_logger = mock.Mock()
34    self.mock_watcher = mock.Mock()
35    self.retriever = script_retriever.ScriptRetriever(
36        self.mock_logger, self.script_type)
37
38  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
39  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlrequest.Request')
40  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlrequest.urlopen')
41  def testDownloadAuthUrl(self, mock_urlopen, mock_request, mock_tempfile):
42    auth_url = 'https://storage.googleapis.com/fake/url'
43    mock_tempfile.return_value = mock_tempfile
44    mock_tempfile.name = self.dest
45    self.retriever.token = 'bar'
46
47    mock_open = mock.mock_open()
48    with mock.patch('%s.open' % builtin, mock_open):
49      self.assertEqual(
50          self.retriever._DownloadAuthUrl(auth_url, self.dest_dir), self.dest)
51
52    mock_tempfile.assert_called_once_with(dir=self.dest_dir, delete=False)
53    mock_tempfile.close.assert_called_once_with()
54
55    self.mock_logger.info.assert_called_once_with(
56        mock.ANY, auth_url, self.dest)
57    mock_request.assert_called_with(auth_url)
58    mocked_request = mock_request()
59    mocked_request.add_unredirected_header.assert_called_with(
60        'Authorization', 'bar')
61    mock_urlopen.assert_called_with(mocked_request)
62    urlopen_read = mock_urlopen().read(return_value=b'foo').decode()
63    self.mock_logger.warning.assert_not_called()
64
65    mock_open.assert_called_once_with(self.dest, 'w')
66    handle = mock_open()
67    handle.write.assert_called_once_with(urlopen_read)
68
69  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
70  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlrequest.Request')
71  @mock.patch('google_compute_engine.metadata_watcher.MetadataWatcher.GetMetadata')
72  def testDownloadAuthUrlExceptionAndToken(
73      self, mock_get_metadata, mock_request, mock_tempfile):
74    auth_url = 'https://storage.googleapis.com/fake/url'
75    metadata_prefix = 'http://metadata.google.internal/computeMetadata/v1/'
76    token_url = metadata_prefix + 'instance/service-accounts/default/token'
77    mock_tempfile.return_value = mock_tempfile
78    mock_tempfile.name = self.dest
79    self.retriever.token = None
80
81    mock_get_metadata.return_value = {
82        'token_type': 'foo', 'access_token': 'bar'}
83    mock_request.return_value = mock_request
84    mock_request.side_effect = urlerror.URLError('Error.')
85
86    self.assertIsNone(self.retriever._DownloadAuthUrl(auth_url, self.dest_dir))
87
88    mock_get_metadata.return_value = mock_get_metadata
89    # GetMetadata includes a prefix, so remove it.
90    stripped_url = token_url.replace(metadata_prefix, '')
91    mock_get_metadata.assert_called_once_with(
92        stripped_url, recursive=False, retry=False)
93
94    self.assertEqual(self.retriever.token, 'foo bar')
95
96    self.mock_logger.info.assert_called_once_with(
97        mock.ANY, auth_url, self.dest)
98    self.assertEqual(self.mock_logger.warning.call_count, 1)
99
100  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
101  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.ScriptRetriever._DownloadUrl')
102  @mock.patch('google_compute_engine.metadata_watcher.MetadataWatcher.GetMetadata')
103  def testDownloadAuthUrlFallback(
104      self, mock_get_metadata, mock_download_url, mock_tempfile):
105    auth_url = 'https://storage.googleapis.com/fake/url'
106    metadata_prefix = 'http://metadata.google.internal/computeMetadata/v1/'
107    token_url = metadata_prefix + 'instance/service-accounts/default/token'
108    mock_tempfile.return_value = mock_tempfile
109    mock_tempfile.name = self.dest
110    self.retriever.token = None
111
112    mock_get_metadata.return_value = None
113    mock_download_url.return_value = None
114
115    self.assertIsNone(self.retriever._DownloadAuthUrl(auth_url, self.dest_dir))
116
117    mock_get_metadata.return_value = mock_get_metadata
118    # GetMetadata includes a prefix, so remove it.
119    prefix = 'http://metadata.google.internal/computeMetadata/v1/'
120    stripped_url = token_url.replace(prefix, '')
121    mock_get_metadata.assert_called_once_with(
122        stripped_url, recursive=False, retry=False)
123    mock_download_url.assert_called_once_with(auth_url, self.dest_dir)
124
125    self.assertIsNone(self.retriever.token)
126
127    expected_calls = [
128        mock.call(mock.ANY, auth_url, self.dest),
129        mock.call(mock.ANY),
130    ]
131    self.assertEqual(self.mock_logger.info.mock_calls, expected_calls)
132
133  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
134  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve')
135  def testDownloadUrl(self, mock_retrieve, mock_tempfile):
136    url = 'http://www.google.com/fake/url'
137    mock_tempfile.return_value = mock_tempfile
138    mock_tempfile.name = self.dest
139    self.assertEqual(
140        self.retriever._DownloadUrl(url, self.dest_dir), self.dest)
141    mock_tempfile.assert_called_once_with(dir=self.dest_dir, delete=False)
142    mock_tempfile.close.assert_called_once_with()
143    self.mock_logger.info.assert_called_once_with(mock.ANY, url, self.dest)
144    mock_retrieve.assert_called_once_with(url, self.dest)
145    self.mock_logger.warning.assert_not_called()
146
147  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.time')
148  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
149  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve')
150  def testDownloadUrlProcessError(self, mock_retrieve, mock_tempfile, mock_time):
151    url = 'http://www.google.com/fake/url'
152    mock_tempfile.return_value = mock_tempfile
153    mock_tempfile.name = self.dest
154    mock_success = mock.Mock()
155    mock_success.getcode.return_value = script_retriever.httpclient.OK
156    # Success after 3 timeout. Since max_retry = 3, the final result is fail.
157    mock_retrieve.side_effect = [
158        script_retriever.socket.timeout(),
159        script_retriever.socket.timeout(),
160        script_retriever.socket.timeout(),
161        mock_success,
162    ]
163    self.assertIsNone(self.retriever._DownloadUrl(url, self.dest_dir))
164    self.assertEqual(self.mock_logger.warning.call_count, 1)
165
166  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.time')
167  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
168  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve')
169  def testDownloadUrlWithRetry(self, mock_retrieve, mock_tempfile, mock_time):
170    url = 'http://www.google.com/fake/url'
171    mock_tempfile.return_value = mock_tempfile
172    mock_tempfile.name = self.dest
173    mock_success = mock.Mock()
174    mock_success.getcode.return_value = script_retriever.httpclient.OK
175    # Success after 2 timeout. Since max_retry = 3, the final result is success.
176    mock_retrieve.side_effect = [
177        script_retriever.socket.timeout(),
178        script_retriever.socket.timeout(),
179        mock_success,
180    ]
181    self.assertIsNotNone(self.retriever._DownloadUrl(url, self.dest_dir))
182
183  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
184  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve')
185  def testDownloadUrlException(self, mock_retrieve, mock_tempfile):
186    url = 'http://www.google.com/fake/url'
187    mock_tempfile.return_value = mock_tempfile
188    mock_tempfile.name = self.dest
189    mock_retrieve.side_effect = Exception('Error.')
190    self.assertIsNone(self.retriever._DownloadUrl(url, self.dest_dir))
191    self.assertEqual(self.mock_logger.warning.call_count, 1)
192
193  def _CreateUrls(self, bucket, obj, gs_match=True):
194    """Creates a URL for each of the supported Google Storage URL formats.
195
196    Args:
197      bucket: string, the Google Storage bucket name.
198      obj: string, the object name in the bucket.
199      gs_match: bool, True if the bucket and object names are valid.
200
201    Returns:
202      (list, dict):
203      list, the URLs to download.
204      dict, a Google Storage URL mapped to the expected 'gs://' format.
205    """
206    gs_url = 'gs://%s/%s' % (bucket, obj)
207    gs_urls = {gs_url: gs_url}
208    url_formats = [
209        'http://%s.storage.googleapis.com/%s',
210        'https://%s.storage.googleapis.com/%s',
211        'http://storage.googleapis.com/%s/%s',
212        'https://storage.googleapis.com/%s/%s',
213        'http://commondatastorage.googleapis.com/%s/%s',
214        'https://commondatastorage.googleapis.com/%s/%s',
215    ]
216    url_formats = [url % (bucket, obj) for url in url_formats]
217    if gs_match:
218      gs_urls.update(dict((url, gs_url) for url in url_formats))
219      return ([], gs_urls)
220    else:
221      return (url_formats, gs_urls)
222
223  def testDownloadScript(self):
224    mock_auth_download = mock.Mock()
225    self.retriever._DownloadAuthUrl = mock_auth_download
226    mock_download = mock.Mock()
227    self.retriever._DownloadUrl = mock_download
228    download_urls = []
229    download_gs_urls = {}
230
231    component_urls = [
232        ('@#$%^', '\n\n\n\n', False),
233        ('///////', '///////', False),
234        ('Abc', 'xyz', False),
235        (' abc', 'xyz', False),
236        ('abc', 'xyz?', False),
237        ('abc', 'xyz*', False),
238        ('', 'xyz', False),
239        ('a', 'xyz', False),
240        ('abc', '', False),
241        ('hello', 'world', True),
242        ('hello', 'world!', True),
243        ('hello', 'world !', True),
244        ('hello', 'w o r l d ', True),
245        ('hello', 'w\no\nr\nl\nd ', True),
246        ('123_hello', '1!@#$%^', True),
247        ('123456', 'hello.world', True),
248    ]
249
250    for bucket, obj, gs_match in component_urls:
251      urls, gs_urls = self._CreateUrls(bucket, obj, gs_match=gs_match)
252      download_urls.extend(urls)
253      download_gs_urls.update(gs_urls)
254
255    # All Google Storage URLs are downloaded with an authentication token.
256    for url, gs_url in download_gs_urls.items():
257      mock_download.reset_mock()
258      mock_auth_download.reset_mock()
259      self.retriever._DownloadScript(gs_url, self.dest_dir)
260      new_gs_url = gs_url.replace('gs://', 'https://storage.googleapis.com/')
261      mock_auth_download.assert_called_once_with(new_gs_url, self.dest_dir)
262      mock_download.assert_not_called()
263
264    for url in download_urls:
265      mock_download.reset_mock()
266      self.retriever._DownloadScript(url, self.dest_dir)
267      mock_download.assert_called_once_with(url, self.dest_dir)
268
269    for url, gs_url in download_gs_urls.items():
270      if url.startswith('gs://'):
271        continue
272      mock_auth_download.reset_mock()
273      mock_auth_download.return_value = None
274      mock_download.reset_mock()
275      self.retriever._DownloadScript(url, self.dest_dir)
276      mock_auth_download.assert_called_once_with(url, self.dest_dir)
277
278  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
279  def testGetAttributeScripts(self, mock_tempfile):
280    script = 'echo Hello World.\n'
281    script_dest = '/tmp/script'
282    script_url = 'gs://fake/url'
283    script_url_dest = '/tmp/script_url'
284    attribute_data = {
285        '%s-script' % self.script_type: '\n%s' % script,
286        '%s-script-url' % self.script_type: script_url,
287    }
288    expected_data = {
289        '%s-script' % self.script_type: script_dest,
290        '%s-script-url' % self.script_type: script_url_dest,
291    }
292    # Mock saving a script to a file.
293    mock_dest = mock.Mock()
294    mock_dest.name = script_dest
295    mock_tempfile.__enter__.return_value = mock_dest
296    mock_tempfile.return_value = mock_tempfile
297    # Mock downloading a script from a URL.
298    mock_download = mock.Mock()
299    mock_download.return_value = script_url_dest
300    self.retriever._DownloadScript = mock_download
301
302    self.assertEqual(
303        self.retriever._GetAttributeScripts(attribute_data, self.dest_dir),
304        expected_data)
305    self.assertEqual(self.mock_logger.info.call_count, 2)
306    mock_dest.write.assert_called_once_with(script)
307    mock_download.assert_called_once_with(script_url, self.dest_dir)
308
309  def testGetAttributeScriptsNone(self):
310    attribute_data = {}
311    expected_data = {}
312    self.assertEqual(
313        self.retriever._GetAttributeScripts(attribute_data, self.dest_dir),
314        expected_data)
315    self.mock_logger.info.assert_not_called()
316
317  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
318  def testGetScripts(self, mock_tempfile):
319    script_dest = '/tmp/script'
320    script_url_dest = '/tmp/script_url'
321    metadata = {
322        'instance': {
323            'attributes': {
324                '%s-script' % self.script_type: 'a',
325                '%s-script-url' % self.script_type: 'b',
326            },
327        },
328        'project': {
329            'attributes': {
330                '%s-script' % self.script_type: 'c',
331                '%s-script-url' % self.script_type: 'd',
332            },
333        },
334    }
335    expected_data = {
336        '%s-script' % self.script_type: script_dest,
337        '%s-script-url' % self.script_type: script_url_dest,
338    }
339    self.mock_watcher.GetMetadata.return_value = metadata
340    self.retriever.watcher = self.mock_watcher
341    # Mock saving a script to a file.
342    mock_dest = mock.Mock()
343    mock_dest.name = script_dest
344    mock_tempfile.__enter__.return_value = mock_dest
345    mock_tempfile.return_value = mock_tempfile
346    # Mock downloading a script from a URL.
347    mock_download = mock.Mock()
348    mock_download.return_value = script_url_dest
349    self.retriever._DownloadScript = mock_download
350
351    self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data)
352    self.assertEqual(self.mock_logger.info.call_count, 2)
353    self.assertEqual(self.mock_logger.warning.call_count, 0)
354    mock_dest.write.assert_called_once_with('a')
355    mock_download.assert_called_once_with('b', self.dest_dir)
356
357  def testGetScriptsNone(self):
358    metadata = {
359        'instance': {
360            'attributes': None,
361        },
362        'project': {
363            'attributes': None,
364        },
365    }
366    expected_data = {}
367    self.mock_watcher.GetMetadata.return_value = metadata
368    self.retriever.watcher = self.mock_watcher
369    self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data)
370    self.mock_logger.info.assert_not_called()
371
372  def testGetScriptsNoMetadata(self):
373    metadata = None
374    expected_data = {}
375    self.mock_watcher.GetMetadata.return_value = metadata
376    self.retriever.watcher = self.mock_watcher
377    self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data)
378    self.mock_logger.info.assert_not_called()
379    self.assertEqual(self.mock_logger.warning.call_count, 2)
380
381  @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile')
382  def testGetScriptsFailed(self, mock_tempfile):
383    script_dest = '/tmp/script'
384    script_url_dest = None
385    metadata = {
386        'instance': {
387            'attributes': {
388                '%s-script' % self.script_type: 'a',
389                '%s-script-url' % self.script_type: 'b',
390            },
391        },
392        'project': {
393            'attributes': {
394                '%s-script' % self.script_type: 'c',
395                '%s-script-url' % self.script_type: 'd',
396            },
397        },
398    }
399    expected_data = {
400        '%s-script' % self.script_type: script_dest,
401        '%s-script-url' % self.script_type: script_url_dest,
402    }
403    self.mock_watcher.GetMetadata.return_value = metadata
404    self.retriever.watcher = self.mock_watcher
405    # Mock saving a script to a file.
406    mock_dest = mock.Mock()
407    mock_dest.name = script_dest
408    mock_tempfile.__enter__.return_value = mock_dest
409    mock_tempfile.return_value = mock_tempfile
410    # Mock downloading a script from a URL.
411    mock_download = mock.Mock()
412    mock_download.return_value = None
413    self.retriever._DownloadScript = mock_download
414
415    self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data)
416    self.assertEqual(self.mock_logger.info.call_count, 2)
417    self.assertEqual(self.mock_logger.warning.call_count, 1)
418
419
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
  unittest.main()
422