1# Copyright 2016 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16import functools
17import os
18import sys
19
20import freezegun
21import mock
22import OpenSSL
23import pytest
24import requests
25import requests.adapters
26from six.moves import http_client
27
28from google.auth import environment_vars
29from google.auth import exceptions
30import google.auth.credentials
31import google.auth.transport._mtls_helper
32import google.auth.transport.requests
33from google.oauth2 import service_account
34from tests.transport import compliance
35
36
@pytest.fixture
def frozen_time():
    """Freeze the clock at the Unix epoch and yield the freezegun handle.

    ``tick=False`` keeps the frozen clock from advancing by itself; tests in
    this module move it explicitly by calling ``tick(delta=...)`` on the
    yielded object.
    """
    freezer = freezegun.freeze_time("1970-01-01 00:00:00", tick=False)
    with freezer as frozen_clock:
        yield frozen_clock
41
42
class TestRequestResponse(compliance.RequestResponseTests):
    """Run the inherited ``RequestResponseTests`` cases against ``Request``."""

    def make_request(self):
        """Return a fresh ``Request`` transport with its default session."""
        return google.auth.transport.requests.Request()

    def test_timeout(self):
        """The ``timeout`` keyword is handed to the session's ``request``."""
        session = mock.create_autospec(requests.Session, instance=True)
        transport = google.auth.transport.requests.Request(session)

        transport(url="http://example.com", method="GET", timeout=5)

        # ``call_args`` unpacks into (positional args, keyword args).
        _, call_kwargs = session.request.call_args
        assert call_kwargs["timeout"] == 5
53
54
class TestTimeoutGuard(object):
    """Tests for the ``TimeoutGuard`` context manager.

    Every test uses the ``frozen_time`` fixture so that elapsed time is fully
    controlled through explicit ``tick()`` calls.
    """

    def make_guard(self, *args, **kwargs):
        """Build a ``TimeoutGuard``, forwarding all arguments unchanged."""
        return google.auth.transport.requests.TimeoutGuard(*args, **kwargs)

    def test_tracks_elapsed_time_w_numeric_timeout(self, frozen_time):
        """A numeric timeout is reduced by the time spent inside the guard."""
        with self.make_guard(timeout=10) as guard:
            frozen_time.tick(delta=datetime.timedelta(seconds=3.8))
        # The remaining timeout is the result of float arithmetic, so compare
        # approximately rather than relying on exact binary representation.
        assert guard.remaining_timeout == pytest.approx(6.2)

    def test_tracks_elapsed_time_w_tuple_timeout(self, frozen_time):
        """Each element of a tuple timeout is reduced by the elapsed time."""
        with self.make_guard(timeout=(16, 19)) as guard:
            frozen_time.tick(delta=datetime.timedelta(seconds=3.8))
        # Approximate comparison for the same reason as in the numeric case.
        assert guard.remaining_timeout == pytest.approx((12.2, 15.2))

    def test_noop_if_no_timeout(self, frozen_time):
        """With ``timeout=None`` the guard never raises and tracks nothing."""
        with self.make_guard(timeout=None) as guard:
            frozen_time.tick(delta=datetime.timedelta(days=3650))
        # NOTE: no timeout error is raised, despite years having passed.
        assert guard.remaining_timeout is None

    def test_timeout_error_w_numeric_timeout(self, frozen_time):
        """Exceeding a numeric timeout raises ``requests.exceptions.Timeout``."""
        with pytest.raises(requests.exceptions.Timeout):
            with self.make_guard(timeout=10) as guard:
                frozen_time.tick(delta=datetime.timedelta(seconds=10.001))
        # ``guard`` was bound before the error, so it can still be inspected.
        assert guard.remaining_timeout == pytest.approx(-0.001)

    def test_timeout_error_w_tuple_timeout(self, frozen_time):
        """A tuple timeout raises once any of its elements is exhausted."""
        with pytest.raises(requests.exceptions.Timeout):
            with self.make_guard(timeout=(11, 10)) as guard:
                frozen_time.tick(delta=datetime.timedelta(seconds=10.001))
        assert guard.remaining_timeout == pytest.approx((0.999, -0.001))

    def test_custom_timeout_error_type(self, frozen_time):
        """``timeout_error_type`` selects the exception class that is raised."""

        class FooError(Exception):
            pass

        with pytest.raises(FooError):
            with self.make_guard(timeout=1, timeout_error_type=FooError):
                frozen_time.tick(delta=datetime.timedelta(seconds=2))

    def test_lets_suite_errors_bubble_up(self, frozen_time):
        """Errors raised inside the guarded suite propagate unchanged."""
        with pytest.raises(IndexError):
            with self.make_guard(timeout=1):
                # Deliberate out-of-range access to trigger an IndexError.
                [1, 2, 3][3]
99
100
class CredentialsStub(google.auth.credentials.Credentials):
    """Minimal credentials double used by the session tests.

    The current token is written verbatim into the ``authorization`` header,
    and every ``refresh`` appends ``"1"`` to it so tests can tell a refreshed
    token from the original one.
    """

    def __init__(self, token="token"):
        super(CredentialsStub, self).__init__()
        self.token = token

    def apply(self, headers, token=None):
        # The ``token`` argument is accepted but ignored; the stub always
        # uses its own current token.
        headers.update({"authorization": self.token})

    def before_request(self, request, method, url, headers):
        # No validity check or implicit refresh: just stamp the headers.
        self.apply(headers)

    def refresh(self, request):
        # "token" -> "token1" -> "token11" ...
        self.token = self.token + "1"

    def with_quota_project(self, quota_project_id):
        raise NotImplementedError()
117
118
class TimeTickCredentialsStub(CredentialsStub):
    """Credentials that spend some (mocked) time when refreshing a token.

    Args:
        time_tick: Zero-argument callable invoked on every ``refresh`` call;
            the tests pass a callable that advances the frozen clock.
        token: Initial token value, forwarded to ``CredentialsStub``.
    """

    def __init__(self, time_tick, token="token"):
        self._time_tick = time_tick
        super(TimeTickCredentialsStub, self).__init__(token=token)

    def refresh(self, request):
        """Advance the mocked clock, then refresh the token as the parent does."""
        self._time_tick()
        # Forward the ``request`` argument received by this method (the
        # original code mistakenly passed the ``requests`` module instead).
        super(TimeTickCredentialsStub, self).refresh(request)
129
130
class AdapterStub(requests.adapters.BaseAdapter):
    """Transport adapter double that replays canned responses in order.

    Every request passed to ``send`` is recorded in ``self.requests`` and
    answered with the next item taken from the front of ``self.responses``.
    """

    def __init__(self, responses, headers=None):
        super(AdapterStub, self).__init__()
        self.headers = headers or {}
        self.requests = []
        self.responses = responses

    def send(self, request, **kwargs):
        # pylint: disable=arguments-differ
        # Only ``request`` matters to the stub; any other keyword arguments
        # are accepted and ignored.
        self.requests.append(request)
        next_response = self.responses.pop(0)
        return next_response

    def close(self):  # pragma: NO COVER
        # Present only because the method is abstract in the base class
        # (pylint insists on it); requests never actually calls it.
        return None
149
150
class TimeTickAdapterStub(AdapterStub):
    """Adapter stub that lets (mocked) time pass before answering a request."""

    def __init__(self, time_tick, responses, headers=None):
        super(TimeTickAdapterStub, self).__init__(responses, headers=headers)
        # Zero-argument callable invoked once per ``send`` call.
        self._time_tick = time_tick

    def send(self, request, **kwargs):
        # Spend the mocked time first, then answer exactly like the parent.
        self._time_tick()
        response = super(TimeTickAdapterStub, self).send(request, **kwargs)
        return response
161
162
class TestMutualTlsAdapter(object):
    """Tests for the private ``_MutualTlsAdapter`` class."""

    @mock.patch.object(requests.adapters.HTTPAdapter, "init_poolmanager")
    @mock.patch.object(requests.adapters.HTTPAdapter, "proxy_manager_for")
    def test_success(self, mock_proxy_manager_for, mock_init_poolmanager):
        """Each manager factory is called with the adapter's own SSL context."""
        cert, key = pytest.public_cert_bytes, pytest.private_key_bytes
        tls_adapter = google.auth.transport.requests._MutualTlsAdapter(cert, key)

        tls_adapter.init_poolmanager()
        mock_init_poolmanager.assert_called_with(
            ssl_context=tls_adapter._ctx_poolmanager
        )

        tls_adapter.proxy_manager_for()
        mock_proxy_manager_for.assert_called_with(
            ssl_context=tls_adapter._ctx_proxymanager
        )

    def test_invalid_cert_or_key(self):
        """Malformed certificate/key bytes surface as ``OpenSSL.crypto.Error``."""
        bad_cert, bad_key = b"invalid cert", b"invalid key"
        with pytest.raises(OpenSSL.crypto.Error):
            google.auth.transport.requests._MutualTlsAdapter(bad_cert, bad_key)

    @mock.patch.dict("sys.modules", {"OpenSSL.crypto": None})
    def test_import_error(self):
        """Constructing the adapter fails when ``OpenSSL.crypto`` is unimportable."""
        # A ``None`` entry in ``sys.modules`` makes importing that module fail.
        with pytest.raises(ImportError):
            google.auth.transport.requests._MutualTlsAdapter(
                pytest.public_cert_bytes, pytest.private_key_bytes
            )
189
190
def make_response(status=http_client.OK, data=None):
    """Build a bare ``requests.Response`` with the given status and body.

    Args:
        status: Value stored as the response's ``status_code``.
        data: Value stored directly as the response's private ``_content``.

    Returns:
        The populated ``requests.Response`` instance.
    """
    canned = requests.Response()
    canned._content = data
    canned.status_code = status
    return canned
196
197
class TestAuthorizedSession(object):
    """Tests for ``google.auth.transport.requests.AuthorizedSession``."""

    # URL the stub adapters are mounted on and all test requests are sent to.
    TEST_URL = "http://example.com/"

    def test_constructor(self):
        """The credentials given to the constructor are exposed unchanged."""
        authed_session = google.auth.transport.requests.AuthorizedSession(
            mock.sentinel.credentials
        )

        assert authed_session.credentials == mock.sentinel.credentials

    def test_constructor_with_auth_request(self):
        """A caller-supplied ``auth_request`` is stored as-is."""
        # ``instance=True``: the transport is given a session instance, so the
        # double should carry the instance spec rather than the class spec.
        http = mock.create_autospec(requests.Session, instance=True)
        auth_request = google.auth.transport.requests.Request(http)

        authed_session = google.auth.transport.requests.AuthorizedSession(
            mock.sentinel.credentials, auth_request=auth_request
        )

        assert authed_session._auth_request is auth_request

    def test_request_default_timeout(self):
        """Without an explicit timeout, ``_DEFAULT_TIMEOUT`` is passed down."""
        credentials = mock.Mock(wraps=CredentialsStub())
        response = make_response()
        adapter = AdapterStub([response])

        authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
        authed_session.mount(self.TEST_URL, adapter)

        # Patch the parent class's ``request`` to capture the keyword arguments
        # that AuthorizedSession forwards to it.
        patcher = mock.patch("google.auth.transport.requests.requests.Session.request")
        with patcher as patched_request:
            authed_session.request("GET", self.TEST_URL)

        expected_timeout = google.auth.transport.requests._DEFAULT_TIMEOUT
        assert patched_request.call_args[1]["timeout"] == expected_timeout

    def test_request_no_refresh(self):
        """A successful response is returned without refreshing credentials."""
        credentials = mock.Mock(wraps=CredentialsStub())
        response = make_response()
        adapter = AdapterStub([response])

        authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
        authed_session.mount(self.TEST_URL, adapter)

        result = authed_session.request("GET", self.TEST_URL)

        assert response == result
        assert credentials.before_request.called
        assert not credentials.refresh.called
        assert len(adapter.requests) == 1
        assert adapter.requests[0].url == self.TEST_URL
        assert adapter.requests[0].headers["authorization"] == "token"

    def test_request_refresh(self):
        """A 401 response triggers a refresh and a retry with the new token."""
        credentials = mock.Mock(wraps=CredentialsStub())
        final_response = make_response(status=http_client.OK)
        # First request will 401, second request will succeed.
        adapter = AdapterStub(
            [make_response(status=http_client.UNAUTHORIZED), final_response]
        )

        authed_session = google.auth.transport.requests.AuthorizedSession(
            credentials, refresh_timeout=60
        )
        authed_session.mount(self.TEST_URL, adapter)

        result = authed_session.request("GET", self.TEST_URL)

        assert result == final_response
        assert credentials.before_request.call_count == 2
        assert credentials.refresh.called
        assert len(adapter.requests) == 2

        assert adapter.requests[0].url == self.TEST_URL
        assert adapter.requests[0].headers["authorization"] == "token"

        # CredentialsStub.refresh appends "1" to the token.
        assert adapter.requests[1].url == self.TEST_URL
        assert adapter.requests[1].headers["authorization"] == "token1"

    def test_request_max_allowed_time_timeout_error(self, frozen_time):
        """A request that outlasts ``max_allowed_time`` raises ``Timeout``."""
        tick_one_second = functools.partial(
            frozen_time.tick, delta=datetime.timedelta(seconds=1.0)
        )

        credentials = mock.Mock(
            wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
        )
        adapter = TimeTickAdapterStub(
            time_tick=tick_one_second, responses=[make_response(status=http_client.OK)]
        )

        authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
        authed_session.mount(self.TEST_URL, adapter)

        # Because a request takes a full mocked second, max_allowed_time shorter
        # than that will cause a timeout error.
        with pytest.raises(requests.exceptions.Timeout):
            authed_session.request("GET", self.TEST_URL, max_allowed_time=0.9)

    def test_request_max_allowed_time_w_transport_timeout_no_error(self, frozen_time):
        """A short transport ``timeout`` does not shrink ``max_allowed_time``."""
        tick_one_second = functools.partial(
            frozen_time.tick, delta=datetime.timedelta(seconds=1.0)
        )

        credentials = mock.Mock(
            wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
        )
        adapter = TimeTickAdapterStub(
            time_tick=tick_one_second,
            responses=[
                make_response(status=http_client.UNAUTHORIZED),
                make_response(status=http_client.OK),
            ],
        )

        authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
        authed_session.mount(self.TEST_URL, adapter)

        # A short configured transport timeout does not affect max_allowed_time.
        # The latter is not adjusted to it and is only concerned with the actual
        # execution time. The call below should thus not raise a timeout error.
        authed_session.request("GET", self.TEST_URL, timeout=0.5, max_allowed_time=3.1)

    def test_request_max_allowed_time_w_refresh_timeout_no_error(self, frozen_time):
        """A short ``refresh_timeout`` does not shrink ``max_allowed_time``."""
        tick_one_second = functools.partial(
            frozen_time.tick, delta=datetime.timedelta(seconds=1.0)
        )

        credentials = mock.Mock(
            wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
        )
        adapter = TimeTickAdapterStub(
            time_tick=tick_one_second,
            responses=[
                make_response(status=http_client.UNAUTHORIZED),
                make_response(status=http_client.OK),
            ],
        )

        authed_session = google.auth.transport.requests.AuthorizedSession(
            credentials, refresh_timeout=1.1
        )
        authed_session.mount(self.TEST_URL, adapter)

        # A short configured refresh timeout does not affect max_allowed_time.
        # The latter is not adjusted to it and is only concerned with the actual
        # execution time. The call below should thus not raise a timeout error
        # (and `timeout` does not come into play either, as it's very long).
        authed_session.request("GET", self.TEST_URL, timeout=60, max_allowed_time=3.1)

    def test_request_timeout_w_refresh_timeout_timeout_error(self, frozen_time):
        """Request, refresh and retry together can exceed ``max_allowed_time``."""
        tick_one_second = functools.partial(
            frozen_time.tick, delta=datetime.timedelta(seconds=1.0)
        )

        credentials = mock.Mock(
            wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
        )
        adapter = TimeTickAdapterStub(
            time_tick=tick_one_second,
            responses=[
                make_response(status=http_client.UNAUTHORIZED),
                make_response(status=http_client.OK),
            ],
        )

        authed_session = google.auth.transport.requests.AuthorizedSession(
            credentials, refresh_timeout=100
        )
        authed_session.mount(self.TEST_URL, adapter)

        # An UNAUTHORIZED response triggers a refresh (an extra request), thus
        # the final request that otherwise succeeds results in a timeout error
        # (all three requests together last 3 mocked seconds).
        with pytest.raises(requests.exceptions.Timeout):
            authed_session.request(
                "GET", self.TEST_URL, timeout=60, max_allowed_time=2.9
            )

    def test_authorized_session_without_default_host(self):
        """Without ``default_host`` the self-signed JWT audience is ``None``."""
        credentials = mock.create_autospec(service_account.Credentials)

        authed_session = google.auth.transport.requests.AuthorizedSession(credentials)

        authed_session.credentials._create_self_signed_jwt.assert_called_once_with(None)

    def test_authorized_session_with_default_host(self):
        """``default_host`` is turned into an ``https://<host>/`` audience."""
        default_host = "pubsub.googleapis.com"
        credentials = mock.create_autospec(service_account.Credentials)

        authed_session = google.auth.transport.requests.AuthorizedSession(
            credentials, default_host=default_host
        )

        authed_session.credentials._create_self_signed_jwt.assert_called_once_with(
            "https://{}/".format(default_host)
        )

    def test_configure_mtls_channel_with_callback(self):
        """A cert/key pair returned by the callback enables the mTLS adapter."""
        mock_callback = mock.Mock()
        mock_callback.return_value = (
            pytest.public_cert_bytes,
            pytest.private_key_bytes,
        )

        auth_session = google.auth.transport.requests.AuthorizedSession(
            credentials=mock.Mock()
        )
        with mock.patch.dict(
            os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
        ):
            auth_session.configure_mtls_channel(mock_callback)

        assert auth_session.is_mtls
        assert isinstance(
            auth_session.adapters["https://"],
            google.auth.transport.requests._MutualTlsAdapter,
        )

    @mock.patch(
        "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True
    )
    def test_configure_mtls_channel_with_metadata(self, mock_get_client_cert_and_key):
        """A cert/key pair reported by the helper enables the mTLS adapter."""
        mock_get_client_cert_and_key.return_value = (
            True,
            pytest.public_cert_bytes,
            pytest.private_key_bytes,
        )

        auth_session = google.auth.transport.requests.AuthorizedSession(
            credentials=mock.Mock()
        )
        with mock.patch.dict(
            os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
        ):
            auth_session.configure_mtls_channel()

        assert auth_session.is_mtls
        assert isinstance(
            auth_session.adapters["https://"],
            google.auth.transport.requests._MutualTlsAdapter,
        )

    @mock.patch.object(google.auth.transport.requests._MutualTlsAdapter, "__init__")
    @mock.patch(
        "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True
    )
    def test_configure_mtls_channel_non_mtls(
        self, mock_get_client_cert_and_key, mock_adapter_ctor
    ):
        """No mTLS adapter is built when the helper reports no client cert."""
        mock_get_client_cert_and_key.return_value = (False, None, None)

        auth_session = google.auth.transport.requests.AuthorizedSession(
            credentials=mock.Mock()
        )
        with mock.patch.dict(
            os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
        ):
            auth_session.configure_mtls_channel()

        assert not auth_session.is_mtls

        # Assert _MutualTlsAdapter constructor is not called.
        mock_adapter_ctor.assert_not_called()

    @mock.patch(
        "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True
    )
    def test_configure_mtls_channel_exceptions(self, mock_get_client_cert_and_key):
        """Helper errors and a missing OpenSSL raise ``MutualTLSChannelError``."""
        mock_get_client_cert_and_key.side_effect = exceptions.ClientCertError()

        auth_session = google.auth.transport.requests.AuthorizedSession(
            credentials=mock.Mock()
        )
        with pytest.raises(exceptions.MutualTLSChannelError):
            with mock.patch.dict(
                os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
            ):
                auth_session.configure_mtls_channel()

        # An exception ``side_effect`` takes precedence over ``return_value``,
        # so it has to be cleared for the helper to report "no client cert".
        # With that, the error expected below can only originate from the
        # unimportable OpenSSL module and not from the helper.
        mock_get_client_cert_and_key.side_effect = None
        mock_get_client_cert_and_key.return_value = (False, None, None)
        with mock.patch.dict("sys.modules"):
            # A ``None`` entry makes ``import OpenSSL`` raise ImportError.
            sys.modules["OpenSSL"] = None
            with pytest.raises(exceptions.MutualTLSChannelError):
                with mock.patch.dict(
                    os.environ,
                    {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"},
                ):
                    auth_session.configure_mtls_channel()

    @mock.patch(
        "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True
    )
    def test_configure_mtls_channel_without_client_cert_env(
        self, get_client_cert_and_key
    ):
        """Client certs are ignored when the opt-in env variable is unset."""
        # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE
        # environment variable is not set.
        auth_session = google.auth.transport.requests.AuthorizedSession(
            credentials=mock.Mock()
        )

        # Make sure the variable is really absent, whatever the environment the
        # test suite runs in; ``patch.dict`` restores the original on exit.
        with mock.patch.dict(os.environ):
            os.environ.pop(environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, None)

            auth_session.configure_mtls_channel()
            assert not auth_session.is_mtls
            get_client_cert_and_key.assert_not_called()

            mock_callback = mock.Mock()
            auth_session.configure_mtls_channel(mock_callback)
            assert not auth_session.is_mtls
            mock_callback.assert_not_called()

    def test_close_wo_passed_in_auth_request(self):
        """``close`` also closes the session stored in ``_auth_request_session``."""
        authed_session = google.auth.transport.requests.AuthorizedSession(
            mock.sentinel.credentials
        )
        # Swap in a double so the call to ``close`` can be asserted on.
        authed_session._auth_request_session = mock.Mock(spec=["close"])

        authed_session.close()

        authed_session._auth_request_session.close.assert_called_once_with()

    def test_close_w_passed_in_auth_request(self):
        """``close`` does not fail when the ``auth_request`` was passed in."""
        http = mock.create_autospec(requests.Session, instance=True)
        auth_request = google.auth.transport.requests.Request(http)
        authed_session = google.auth.transport.requests.AuthorizedSession(
            mock.sentinel.credentials, auth_request=auth_request
        )

        authed_session.close()  # no raise
526