1# ------------------------------------
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4# ------------------------------------
5import time
6
7try:
8    from unittest import mock
9except ImportError:
10    import mock
11
12from azure.core.credentials import AccessToken
13import pytest
14
15from azure.identity._constants import DEFAULT_REFRESH_OFFSET
16from azure.identity._internal.get_token_mixin import GetTokenMixin
17
18
class MockCredential(GetTokenMixin):
    """GetTokenMixin subclass whose two token sources are replaceable mocks.

    The mixin's ``_acquire_token_silently`` hook delegates to the
    ``acquire_token_silently`` mock, which returns ``cached_token`` (``None`` by
    default). The ``_request_token`` hook delegates to the ``request_token``
    mock, which returns ``NEW_TOKEN``. Tests assert on these mocks, or replace
    them, to observe how the mixin's ``get_token`` drives them.
    """

    # Expiry 42 is long past; tests only compare the ``.token`` string.
    NEW_TOKEN = AccessToken("new token", 42)

    def __init__(self, cached_token=None):
        super(MockCredential, self).__init__()
        # Exposed as public attributes so tests can inspect calls or swap in side effects.
        self.acquire_token_silently = mock.Mock(return_value=cached_token)
        self.request_token = mock.Mock(return_value=MockCredential.NEW_TOKEN)

    def _acquire_token_silently(self, *scopes):
        """Forward the mixin's silent-acquisition hook to the ``acquire_token_silently`` mock."""
        silent_source = self.acquire_token_silently
        return silent_source(*scopes)

    def _request_token(self, *scopes, **kwargs):
        """Forward the mixin's token-request hook to the ``request_token`` mock."""
        requester = self.request_token
        return requester(*scopes, **kwargs)

    def get_token(self, *scopes, **kwargs):
        """Defer entirely to ``GetTokenMixin.get_token``."""
        return super(MockCredential, self).get_token(*scopes, **kwargs)
35
36
# Shared test values: the token string tests seed into the credential's cache,
# and the single scope every get_token call in this module requests.
CACHED_TOKEN, SCOPE = "cached token", "scope"
39
40
def test_no_cached_token():
    """With nothing cached, get_token should check the cache and then request a new token"""

    cred = MockCredential()
    result = cred.get_token(SCOPE)

    # the cache is consulted first; its empty result forces a token request
    cred.acquire_token_silently.assert_called_once_with(SCOPE)
    cred.request_token.assert_called_once_with(SCOPE)
    assert result.token == MockCredential.NEW_TOKEN.token
50
51
def test_token_acquisition_failure():
    """When nothing is cached and the token request fails, get_token should raise the request's error, and
    every subsequent call should check the cache and try the request again"""

    credential = MockCredential()
    credential.request_token = mock.Mock(side_effect=Exception("whoops"))
    for attempt in range(4):
        # with no cached token to fall back on, the request's own error must propagate;
        # matching the message keeps an unrelated exception from satisfying the test
        with pytest.raises(Exception, match="whoops"):
            credential.get_token(SCOPE)
        # a failed request must not suppress the cache lookup or the request on the next call
        assert credential.acquire_token_silently.call_count == attempt + 1
        assert credential.request_token.call_count == attempt + 1
        credential.request_token.assert_called_with(SCOPE)
62
63
def test_expired_token():
    """An already-expired cached token should be replaced by a newly requested one"""

    expired = AccessToken(CACHED_TOKEN, time.time() - 1)
    credential = MockCredential(cached_token=expired)
    token = credential.get_token(SCOPE)

    # the expired token is found in the cache but not returned; a request supplies the result
    credential.acquire_token_silently.assert_called_once_with(SCOPE)
    credential.request_token.assert_called_once_with(SCOPE)
    assert token.token == MockCredential.NEW_TOKEN.token
74
75
def test_cached_token_outside_refresh_window():
    """A cached token expiring later than the refresh offset should be returned as-is, with no token request"""

    # one second beyond the refresh offset puts the token just outside the refresh window
    expires_on = time.time() + DEFAULT_REFRESH_OFFSET + 1
    credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, expires_on))
    token = credential.get_token(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE)
    assert not credential.request_token.called
    assert token.token == CACHED_TOKEN
85
86
def test_cached_token_within_refresh_window():
    """A cached token expiring sooner than the refresh offset should trigger a proactive token request"""

    # one second short of the refresh offset puts the still-valid token inside the refresh window
    expires_on = time.time() + DEFAULT_REFRESH_OFFSET - 1
    credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, expires_on))
    token = credential.get_token(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE)
    credential.request_token.assert_called_once_with(SCOPE)
    assert token.token == MockCredential.NEW_TOKEN.token
96
97
def test_retry_delay():
    """After a failed proactive refresh, a credential should not retry the request on the immediately following
    get_token calls; meanwhile it should keep serving its still-valid cached token"""

    now = time.time()
    credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, now + DEFAULT_REFRESH_OFFSET - 1))

    # the credential should swallow exceptions during proactive refresh attempts
    credential.request_token = mock.Mock(side_effect=Exception("whoops"))
    for attempt in range(4):
        token = credential.get_token(SCOPE)
        # the refresh failed (or was skipped), so the cached token is returned
        assert token.token == CACHED_TOKEN
        # the cache is consulted on every call...
        assert credential.acquire_token_silently.call_count == attempt + 1
        credential.acquire_token_silently.assert_called_with(SCOPE)
        # ...but only the first call attempts a refresh; the rest happen too soon after that failure
        credential.request_token.assert_called_once_with(SCOPE)
111