1# Copyright 2016 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import absolute_import
16
17from apitools.base.py import encoding
18import mock
19import os
20import tempfile
21import unittest2
22from expects import be_false, be_none, be_true, expect, equal, raise_error
23
24from google.api.auth import suppliers
25from google.api.auth import tokens
26from google.api.control import client, messages, report_request, service, wsgi
27
28
29def _dummy_start_response(content, dummy_response_headers):
30    pass
31
32
33_DUMMY_RESPONSE = ('All must answer "here!"',)
34
35
36class _DummyWsgiApp(object):
37
38    def __call__(self, environ, dummy_start_response):
39        return _DUMMY_RESPONSE
40
41
class TestEnvironmentMiddleware(unittest2.TestCase):
    """Verifies the entries ``wsgi.EnvironmentMiddleware`` adds to the environ."""

    def test_should_add_service_et_al_to_environment(self):
        middleware_cls = wsgi.EnvironmentMiddleware
        loaded_service = service.Loaders.SIMPLE.load()
        app = middleware_cls(_DummyWsgiApp(), loaded_service)

        environ = {
            'wsgi.url_scheme': 'http',
            'HTTP_HOST': 'localhost',
            'REQUEST_METHOD': 'GET'
        }
        app(environ, _dummy_start_response)

        # The service and its name must be stored exactly as supplied.
        stored_service = environ.get(middleware_cls.SERVICE)
        expect(stored_service).to(equal(loaded_service))
        stored_name = environ.get(middleware_cls.SERVICE_NAME)
        expect(stored_name).to(equal(loaded_service.name))

        # The remaining entries only need to be present.
        required_keys = (
            middleware_cls.METHOD_REGISTRY,
            middleware_cls.REPORTING_RULES,
            middleware_cls.METHOD_INFO,
        )
        for key in required_keys:
            expect(environ.get(key)).not_to(be_none)
61
62
class TestMiddleware(unittest2.TestCase):
    """Checks when ``wsgi.Middleware`` calls ``check``/``report`` on its client."""
    PROJECT_ID = 'middleware'

    @staticmethod
    def _make_environ():
        # Every test issues the same GET of /any/method; a fresh dict is
        # returned each time because the middleware writes into the environ.
        return {
            'wsgi.url_scheme': 'http',
            'PATH_INFO': '/any/method',
            'REMOTE_ADDR': '192.168.0.3',
            'HTTP_HOST': 'localhost',
            'HTTP_REFERER': 'example.myreferer.com',
            'REQUEST_METHOD': 'GET'
        }

    def test_should_not_send_requests_if_there_is_no_service(self):
        control_client = mock.MagicMock(spec=client.Client)
        # Middleware is invoked directly here, without EnvironmentMiddleware
        # in front of it to add the service entries to the environ.
        wrapped = wsgi.Middleware(_DummyWsgiApp(), self.PROJECT_ID,
                                  control_client)
        wrapped(self._make_environ(), _dummy_start_response)
        expect(control_client.check.called).to(be_false)
        expect(control_client.report.called).to(be_false)

    def test_should_send_requests_using_the_client(self):
        control_client = mock.MagicMock(spec=client.Client)
        with_control = wsgi.Middleware(_DummyWsgiApp(), self.PROJECT_ID,
                                       control_client)
        wrapped = wsgi.EnvironmentMiddleware(with_control,
                                             service.Loaders.SIMPLE.load())
        control_client.check.return_value = messages.CheckResponse(
            operationId='fake_operation_id')
        wrapped(self._make_environ(), _dummy_start_response)
        expect(control_client.check.called).to(be_true)
        expect(control_client.report.called).to(be_true)

    def test_should_send_report_request_if_check_fails(self):
        control_client = mock.MagicMock(spec=client.Client)
        # A check response that carries an error; report must still be sent.
        failed_check = messages.CheckResponse(
            operationId='fake_operation_id',
            checkErrors=[
                messages.CheckError(
                    code=messages.CheckError.CodeValueValuesEnum.PROJECT_DELETED)
            ]
        )
        wrapped = wsgi.add_all(_DummyWsgiApp(),
                               self.PROJECT_ID,
                               control_client,
                               loader=service.Loaders.SIMPLE)
        control_client.check.return_value = failed_check
        wrapped(self._make_environ(), _dummy_start_response)
        expect(control_client.check.called).to(be_true)
        expect(control_client.report.called).to(be_true)

    def test_load_service_failed(self):
        # A loader whose load() yields None must make add_all raise.
        loader = mock.MagicMock(load=lambda: None)
        with self.assertRaisesRegex(ValueError, "Failed to load service config"):
            wsgi.add_all(_DummyWsgiApp(),
                         self.PROJECT_ID,
                         mock.MagicMock(spec=client.Client),
                         loader=loader)
139
140
# JSON service config that TestMiddlewareWithParams.setUp writes to a
# temporary file. It declares:
#   - system parameter rules for "Uvw.Method1" (including an "api_key"
#     parameter bound to the "ApiKeyHeader" header and the "ApiKeyParam"
#     query parameter) and for "Bad.NotConfigured";
#   - http rules mapping three "Uvw.*" selectors to GET paths;
#   - usage rules in which only "Uvw.MethodNeedsApiKey" sets
#     allowUnregisteredCalls to false.
_SYSTEM_PARAMETER_CONFIG_TEST = """
{
    "name": "system-parameter-config",
    "systemParameters": {
       "rules": [{
         "selector": "Uvw.Method1",
         "parameters": [{
            "name": "name1",
            "httpHeader": "Header-Key1",
            "urlQueryParameter": "param_key1"
         }, {
            "name": "name2",
            "httpHeader": "Header-Key2",
            "urlQueryParameter": "param_key2"
         }, {
            "name": "api_key",
            "httpHeader": "ApiKeyHeader",
            "urlQueryParameter": "ApiKeyParam"
         }, {
            "httpHeader": "Ignored-NoName-Key3",
            "urlQueryParameter": "Ignored-NoName-key3"
         }]
       }, {
         "selector": "Bad.NotConfigured",
         "parameters": [{
            "name": "neverUsed",
            "httpHeader": "NeverUsed-Key1",
            "urlQueryParameter": "NeverUsed_key1"
         }]
       }]
    },
    "http": {
        "rules": [{
            "selector": "Uvw.Method1",
            "get": "/uvw/method1/*"
        }, {
            "selector": "Uvw.MethodNeedsApiKey",
            "get": "/uvw/method_needs_api_key/*"
        }, {
            "selector": "Uvw.DefaultParameters",
            "get": "/uvw/default_parameters"
        }]
    },
    "usage": {
        "rules": [{
            "selector" : "Uvw.Method1",
            "allowUnregisteredCalls" : true
        },  {
            "selector": "Uvw.MethodNeedsApiKey",
            "allowUnregisteredCalls" : false
        }, {
            "selector" : "Uvw.DefaultParameters",
            "allowUnregisteredCalls" : true
        }]
    }
}
"""
198
class TestMiddlewareWithParams(unittest2.TestCase):
    """Checks which consumer id the middleware puts on check/report requests.

    The service config is loaded through ``service.Loaders.ENVIRONMENT`` from
    a temporary file holding ``_SYSTEM_PARAMETER_CONFIG_TEST``.
    """
    PROJECT_ID = 'middleware-with-params'

    def setUp(self):
        # Text mode is required: the config is a str, and the default binary
        # mode ('w+b') rejects str writes on Python 3.
        config_file = tempfile.NamedTemporaryFile(mode='w', delete=False)
        self._config_file = config_file.name
        with config_file:
            config_file.write(_SYSTEM_PARAMETER_CONFIG_TEST)

        # patch.dict puts os.environ back when stopped, so the config
        # variable does not leak into tests that run afterwards.
        env_patcher = mock.patch.dict(
            os.environ, {service.CONFIG_VAR: self._config_file})
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

    def tearDown(self):
        if os.path.exists(self._config_file):
            os.remove(self._config_file)

    @staticmethod
    def _make_environ(path_info, **extra):
        # Base GET environ shared by all tests; ``extra`` adds per-test
        # entries such as QUERY_STRING or an HTTP_* header.
        environ = {
            'wsgi.url_scheme': 'http',
            'PATH_INFO': path_info,
            'REMOTE_ADDR': '192.168.0.3',
            'HTTP_HOST': 'localhost',
            'HTTP_REFERER': 'example.myreferer.com',
            'REQUEST_METHOD': 'GET'
        }
        environ.update(extra)
        return environ

    def _send(self, environ):
        # Runs one request through the full middleware stack and returns the
        # mocked control client so the caller can inspect what was sent.
        control_client = mock.MagicMock(spec=client.Client)
        wrapped = wsgi.add_all(_DummyWsgiApp(),
                               self.PROJECT_ID,
                               control_client,
                               loader=service.Loaders.ENVIRONMENT)
        control_client.check.return_value = messages.CheckResponse(
            operationId='fake_operation_id')
        wrapped(environ, _dummy_start_response)
        return control_client

    @staticmethod
    def _check_consumer_id(control_client):
        # Consumer id on the first positional argument of the last check().
        check_req = control_client.check.call_args[0][0]
        return check_req.checkRequest.operation.consumerId

    @staticmethod
    def _report_consumer_id(control_client):
        # Consumer id on the first operation of the last report() request.
        report_req = control_client.report.call_args[0][0]
        return report_req.reportRequest.operations[0].consumerId

    def test_should_send_requests_with_no_param(self):
        control_client = self._send(
            self._make_environ('/uvw/method1/with_no_param'))
        expect(control_client.check.called).to(be_true)
        expect(self._check_consumer_id(control_client)).to(
            equal('project:middleware-with-params'))
        expect(control_client.report.called).to(be_true)

    def test_should_send_requests_with_configured_query_param_api_key(self):
        control_client = self._send(
            self._make_environ('/uvw/method1/with_query_param',
                               QUERY_STRING='ApiKeyParam=my-query-value'))
        expect(control_client.check.called).to(be_true)
        expect(self._check_consumer_id(control_client)).to(
            equal('api_key:my-query-value'))
        expect(control_client.report.called).to(be_true)
        expect(self._report_consumer_id(control_client)).to(
            equal('api_key:my-query-value'))

    def test_should_send_requests_with_configured_header_api_key(self):
        control_client = self._send(
            self._make_environ('/uvw/method1/with_query_param',
                               HTTP_APIKEYHEADER='my-header-value'))
        expect(control_client.check.called).to(be_true)
        expect(self._check_consumer_id(control_client)).to(
            equal('api_key:my-header-value'))
        expect(control_client.report.called).to(be_true)
        expect(self._report_consumer_id(control_client)).to(
            equal('api_key:my-header-value'))

    def test_should_send_requests_with_default_query_param_api_key(self):
        # This method has no configured api_key parameter, so both default
        # query parameter names are tried in turn.
        for default_key in ('key', 'api_key'):
            query = '%s=my-default-api-key-value' % (default_key,)
            control_client = self._send(
                self._make_environ(
                    '/uvw/method_needs_api_key/with_query_param',
                    QUERY_STRING=query))
            expect(control_client.check.called).to(be_true)
            expect(self._check_consumer_id(control_client)).to(
                equal('api_key:my-default-api-key-value'))
            expect(control_client.report.called).to(be_true)
            expect(self._report_consumer_id(control_client)).to(
                equal('api_key:my-default-api-key-value'))

    def test_should_not_perform_check_if_needed_api_key_is_missing(self):
        control_client = self._send(
            self._make_environ('/uvw/method_needs_api_key/more_stuff'))
        # No check is made, but the request is still reported against the
        # project.
        expect(control_client.check.called).to(be_false)
        expect(control_client.report.called).to(be_true)
        expect(self._report_consumer_id(control_client)).to(
            equal('project:middleware-with-params'))
347
# Short alias for the middleware class exercised by TestAuthenticationMiddleware.
AuthMiddleware = wsgi.AuthenticationMiddleware
349
350
class TestAuthenticationMiddleware(unittest2.TestCase):
    """Tests ``wsgi.AuthenticationMiddleware`` with a mocked authenticator."""

    class UserInfoWsgiApp(object):
        """WSGI app whose response is the USER_INFO entry of ``os.environ``."""

        def __call__(self, environ, start_response):
            user_info_key = wsgi.AuthenticationMiddleware.USER_INFO
            return os.environ.get(user_info_key)

    # Replaces ``os.environ`` for the duration of ``test_set_user_info``.
    patched_environ = {}

    def setUp(self):
        self._mock_application = _DummyWsgiApp()
        self._mock_authenticator = mock.MagicMock(spec=tokens.Authenticator)
        self._middleware = AuthMiddleware(self._mock_application,
                                          self._mock_authenticator)

    def test_no_authentication(self):
        # Constructing the middleware without an authenticator is rejected.
        with self.assertRaisesRegex(ValueError, "Invalid authenticator"):
            AuthMiddleware(self._mock_application, None)

    def test_no_method_info(self):
        # With an empty environ the wrapped app's response comes back as-is.
        response = self._middleware({}, _dummy_start_response)
        self.assertEqual(_DUMMY_RESPONSE, response)

    def test_no_auth_token(self):
        app_under_test = AuthMiddleware(self.UserInfoWsgiApp(),
                                        self._mock_authenticator)
        info = mock.MagicMock(auth_info=mock.MagicMock())
        wsgi_environ = {wsgi.EnvironmentMiddleware.METHOD_INFO: info}
        response = app_under_test(wsgi_environ, _dummy_start_response)
        self.assertIsNone(response)

    def test_malformed_authorization_header(self):
        app_under_test = AuthMiddleware(self.UserInfoWsgiApp(),
                                        self._mock_authenticator)
        authenticate = self._mock_authenticator.authenticate
        authenticate.side_effect = suppliers.UnauthenticatedException()
        wsgi_environ = {
            "HTTP_AUTHORIZATION": "malformed-auth-token",
            wsgi.EnvironmentMiddleware.METHOD_INFO: mock.MagicMock(),
            wsgi.EnvironmentMiddleware.SERVICE_NAME: "service-name"
        }
        response = app_under_test(wsgi_environ, _dummy_start_response)
        self.assertIsNone(response)

    def test_successful_authentication(self):
        expected_auth_info = mock.MagicMock()
        expected_service = "test-service-name"
        wsgi_environ = {
            "HTTP_AUTHORIZATION": "Bearer test-bearer-token",
            wsgi.EnvironmentMiddleware.METHOD_INFO:
                mock.MagicMock(auth_info=expected_auth_info),
            wsgi.EnvironmentMiddleware.SERVICE_NAME: expected_service
        }
        authenticated_user = mock.MagicMock()
        authenticate = self._mock_authenticator.authenticate
        authenticate.return_value = authenticated_user

        self._middleware(wsgi_environ, _dummy_start_response)

        self.assertEqual(authenticated_user,
                         wsgi_environ.get(AuthMiddleware.USER_INFO))
        # The "Bearer " prefix must be stripped before authenticating.
        authenticate.assert_called_once_with("test-bearer-token",
                                             expected_auth_info,
                                             expected_service)

    def test_auth_token_in_query(self):
        token = "test-bearer-token"
        expected_auth_info = mock.MagicMock()
        expected_service = "test-service-name"
        wsgi_environ = {
            "QUERY_STRING": "access_token=" + token,
            wsgi.EnvironmentMiddleware.METHOD_INFO:
                mock.MagicMock(auth_info=expected_auth_info),
            wsgi.EnvironmentMiddleware.SERVICE_NAME: expected_service
        }
        authenticated_user = mock.MagicMock()
        authenticate = self._mock_authenticator.authenticate
        authenticate.return_value = authenticated_user

        self._middleware(wsgi_environ, _dummy_start_response)

        self.assertEqual(authenticated_user,
                         wsgi_environ.get(AuthMiddleware.USER_INFO))
        authenticate.assert_called_once_with("test-bearer-token",
                                             expected_auth_info,
                                             expected_service)

    @mock.patch("os.environ", patched_environ)
    def test_set_user_info(self):
        authenticated_user = mock.MagicMock()
        self._mock_authenticator.authenticate.return_value = authenticated_user
        app_under_test = AuthMiddleware(self.UserInfoWsgiApp(),
                                        self._mock_authenticator)
        wsgi_environ = {
            "QUERY_STRING": "access_token=test-token",
            wsgi.EnvironmentMiddleware.METHOD_INFO: mock.MagicMock(),
            wsgi.EnvironmentMiddleware.SERVICE_NAME: "test-service-name"
        }

        # The wrapped app reads USER_INFO from os.environ while it runs...
        response = app_under_test(wsgi_environ, _dummy_start_response)
        self.assertEqual(authenticated_user, response)
        # ...and the patched os.environ must be empty again afterwards.
        self.assertFalse(self.patched_environ)
446
447
class TestCreateAuthenticator(unittest2.TestCase):
    """Covers ``wsgi._create_authenticator`` for absent and present auth config."""

    def test_create_without_service(self):
        self.assertRaises(ValueError, wsgi._create_authenticator, None)

    def test_load_service_without_auth(self):
        # A service without an authentication section yields no authenticator.
        empty_service = _read_service_from_json("{}")
        authenticator = wsgi._create_authenticator(empty_service)
        self.assertIsNone(authenticator)

    def test_load_service(self):
        config_json = """{
            "authentication": {
                "providers": [{
                    "issuer": "auth-issuer"
                }]
            }
        }"""
        service_with_auth = _read_service_from_json(config_json)
        authenticator = wsgi._create_authenticator(service_with_auth)
        self.assertIsNotNone(authenticator)
467
468
# Seed values for the os.environ patch below; empty, so with clear=True each
# test starts from an empty environment.
patched_platform_environ = {}


@mock.patch.dict('os.environ', patched_platform_environ, clear=True)
class TestPlatformDetection(unittest2.TestCase):
    """Checks ``wsgi._get_platform`` against different environment set-ups."""

    def _expect_platform(self, expected):
        # Shared assertion: the detected platform equals ``expected``.
        self.assertEqual(expected, wsgi._get_platform())

    def test_development(self):
        os.environ['SERVER_SOFTWARE'] = 'Development/2.0.0'
        self._expect_platform(report_request.ReportedPlatforms.DEVELOPMENT)

    def test_gke(self):
        os.environ['KUBERNETES_SERVICE_HOST'] = 'hostname'
        self._expect_platform(report_request.ReportedPlatforms.GKE)

    @mock.patch.object(wsgi, '_running_on_gce', return_value=True)
    def test_gae_flex(self, _running_on_gce):
        os.environ['GAE_MODULE_NAME'] = 'gae_module'
        self._expect_platform(report_request.ReportedPlatforms.GAE_FLEX)

    @mock.patch.object(wsgi, '_running_on_gce', return_value=True)
    def test_gce(self, _running_on_gce):
        self._expect_platform(report_request.ReportedPlatforms.GCE)

    @mock.patch.object(wsgi, '_running_on_gce', return_value=False)
    def test_gae_standard(self, _running_on_gce):
        os.environ['GAE_MODULE_NAME'] = 'gae_module'
        self._expect_platform(report_request.ReportedPlatforms.GAE_STANDARD)

    @mock.patch.object(wsgi, '_running_on_gce', return_value=False)
    def test_unknown(self, _running_on_gce):
        self._expect_platform(report_request.ReportedPlatforms.UNKNOWN)
504
505
def _read_service_from_json(json):
    """Decode the JSON text ``json`` into a ``messages.Service`` message."""
    service_message = encoding.JsonToMessage(messages.Service, json)
    return service_message
508