1# ------------------------------------
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4# ------------------------------------
5import functools
6from unittest.mock import Mock, patch
7from urllib.parse import urlparse
8import time
9from azure.core.exceptions import ClientAuthenticationError
10from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
11from azure.identity.aio._internal.aad_client import AadClient
12from azure.core.credentials import AccessToken
13from msal import TokenCache
14import pytest
15
16from helpers import build_aad_response, mock_response
17
# run every test coroutine in this module under the asyncio plugin
pytestmark = [pytest.mark.asyncio]
19
20
async def test_error_reporting():
    """AAD error responses should surface as ClientAuthenticationError carrying AAD's error name and description."""
    error_name = "everything's sideways"
    error_description = "something went wrong"
    payload = {"error": error_name, "error_description": error_description}

    # one canned 403 response shared by every request
    canned_response = mock_response(status_code=403, json_payload=payload)

    async def send(*_, **__):
        return canned_response

    transport = Mock(send=Mock(wraps=send))
    client = AadClient("tenant id", "client id", transport=transport)

    token_requests = (
        functools.partial(client.obtain_token_by_authorization_code, ("scope",), "code", "uri"),
        functools.partial(client.obtain_token_by_refresh_token, ("scope",), "refresh token"),
    )

    # exceptions raised for AAD errors should contain AAD's error description
    for request_token in token_requests:
        with pytest.raises(ClientAuthenticationError) as exc_info:
            await request_token()
        text = str(exc_info.value)
        assert error_name in text
        assert error_description in text
        # exactly one HTTP round trip per attempt
        assert transport.send.call_count == 1
        transport.send.reset_mock()
47
48
async def test_exceptions_do_not_expose_secrets():
    """Exceptions from failed token requests must not leak tokens contained in the response body.

    Checked twice: once for a well-formed AAD error response, and once for an unexpected
    response that has no "error" field.
    """
    secret = "secret"
    body = {"error": "bad thing", "access_token": secret, "refresh_token": secret}
    response = mock_response(status_code=403, json_payload=body)

    async def send(*_, **__):
        return response

    transport = Mock(send=Mock(wraps=send))

    client = AadClient("tenant id", "client id", transport=transport)

    # Pass arguments by keyword: positionally these methods take scopes first, and the
    # previous positional calls put the scope tuple into redirect_uri / refresh_token.
    fns = [
        functools.partial(
            client.obtain_token_by_authorization_code, scopes=("scope",), code="code", redirect_uri="uri"
        ),
        functools.partial(client.obtain_token_by_refresh_token, scopes=("scope",), refresh_token="refresh token"),
    ]

    async def assert_secrets_not_exposed():
        for fn in fns:
            with pytest.raises(ClientAuthenticationError) as ex:
                await fn()
            assert secret not in str(ex.value)
            assert secret not in repr(ex.value)
            assert transport.send.call_count == 1
            transport.send.reset_mock()

    # AAD errors shouldn't provoke exceptions exposing secrets
    await assert_secrets_not_exposed()

    # neither should unexpected AAD responses; this mutates the dict the mock response was
    # built from, presumably serialized lazily by mock_response (verify in helpers)
    del body["error"]
    await assert_secrets_not_exposed()
81
82
@pytest.mark.parametrize("secret", (None, "client secret"))
async def test_authorization_code(secret):
    """Redeeming an authorization code should send the expected form fields, plus the client secret when given."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    auth_code = "code"
    scope = "scope"
    redirect_uri = "https://localhost"
    access_token = "***"

    expected_fields = {
        "client_id": client_id,
        "code": auth_code,
        "grant_type": "authorization_code",
        "redirect_uri": redirect_uri,
        "scope": scope,
    }

    async def send(request, **_):
        for field, value in expected_fields.items():
            assert request.data[field] == value
        # must equal the parametrized value; for None this accepts an absent (or None) field
        assert request.data.get("client_secret") == secret

        return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

    transport = Mock(send=Mock(wraps=send))
    client = AadClient(tenant_id, client_id, transport=transport)

    token = await client.obtain_token_by_authorization_code(
        scopes=(scope,), code=auth_code, redirect_uri=redirect_uri, client_secret=secret
    )

    assert transport.send.call_count == 1
    assert token.token == access_token
111
112
async def test_client_secret():
    """A client-credentials request should send the client secret and return the issued access token."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    scope = "scope"
    # a value that can't be confused with refresh-token test data
    secret = "client-secret"
    access_token = "***"

    async def send(request, **_):
        assert request.data["client_id"] == client_id
        assert request.data["client_secret"] == secret
        assert request.data["grant_type"] == "client_credentials"
        assert request.data["scope"] == scope

        return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

    transport = Mock(send=Mock(wraps=send))

    client = AadClient(tenant_id, client_id, transport=transport)
    token = await client.obtain_token_by_client_secret(scopes=(scope,), secret=secret)

    assert token.token == access_token
    assert transport.send.call_count == 1
135
136
async def test_refresh_token():
    """Redeeming a refresh token should send a refresh_token grant and return the issued access token."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    scope = "scope"
    refresh_token = "refresh-token"
    access_token = "***"

    async def send(request, **_):
        form = request.data
        assert form["client_id"] == client_id
        assert form["grant_type"] == "refresh_token"
        assert form["refresh_token"] == refresh_token
        assert form["scope"] == scope
        return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

    transport = Mock(send=Mock(wraps=send))
    client = AadClient(tenant_id, client_id, transport=transport)

    token = await client.obtain_token_by_refresh_token(scopes=(scope,), refresh_token=refresh_token)

    assert transport.send.call_count == 1
    assert token.token == access_token
159
160
@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
async def test_request_url(authority):
    """Token requests should target https://<authority>/<tenant>..., whether the authority is
    passed to the client or configured via the AZURE_AUTHORITY_HOST environment variable."""
    tenant_id = "expected-tenant"
    parsed_authority = urlparse(authority)
    expected_netloc = parsed_authority.netloc or authority  # "localhost" parses to netloc "", path "localhost"

    async def send(request, **_):
        actual = urlparse(request.url)
        assert actual.scheme == "https"
        assert actual.netloc == expected_netloc
        assert actual.path.startswith("/" + tenant_id)
        return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"})

    async def request_tokens(client, transport):
        """Issue one request of each kind and confirm both reached the transport."""
        # scopes as a tuple of scope strings, not a bare str
        await client.obtain_token_by_authorization_code(("scope",), "code", "uri")
        await client.obtain_token_by_refresh_token(("scope",), "refresh token")
        # without this, a client that never called send would pass the URL assertions vacuously
        assert transport.send.call_count == 2

    transport = Mock(send=Mock(wraps=send))
    client = AadClient(tenant_id, "client id", transport=transport, authority=authority)
    await request_tokens(client, transport)

    # authority can be configured via environment variable
    transport = Mock(send=Mock(wraps=send))
    with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
        client = AadClient(tenant_id=tenant_id, client_id="client id", transport=transport)
    # NOTE: requests run after the environment patch is undone, so this presumes the client
    # resolves its authority at construction time
    await request_tokens(client, transport)
184
185
async def test_evicts_invalid_refresh_token():
    """A refresh token that AAD rejects should be removed from the client's cache while other refresh tokens remain."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    rejected = "invalid-refresh-token"

    cache = TokenCache()
    for uid, utid, refresh_token in (("id1", "tid1", rejected), ("id2", "tid2", "...")):
        cache.add({"response": build_aad_response(uid=uid, utid=utid, access_token="*", refresh_token=refresh_token)})

    def refresh_token_count(**find_kwargs):
        """Number of cached refresh tokens matching the optional find() arguments."""
        return len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, **find_kwargs))

    assert refresh_token_count() == 2
    assert refresh_token_count(query={"secret": rejected}) == 1

    async def send(request, **_):
        assert request.data["refresh_token"] == rejected
        return mock_response(json_payload={"error": "invalid_grant"}, status_code=400)

    transport = Mock(send=Mock(wraps=send))
    client = AadClient(tenant_id, client_id, transport=transport, cache=cache)

    with pytest.raises(ClientAuthenticationError):
        await client.obtain_token_by_refresh_token(scopes=("scope",), refresh_token=rejected)

    assert transport.send.call_count == 1
    # only the rejected token is gone; the other account's refresh token survives
    assert refresh_token_count() == 1
    assert refresh_token_count(query={"secret": rejected}) == 0
212
213
async def test_should_refresh():
    """should_refresh should ask for a new token only when the current one is within
    DEFAULT_REFRESH_OFFSET of expiring and the retry cool-down since the last refresh has passed."""
    client = AadClient("test", "test")
    now = int(time.time())
    # should_refresh presumably reads the clock itself; keep every case several seconds from
    # its threshold so a clock tick between computing `now` and the call can't flip the result
    # (the previous +/-1 margins made this test flaky at second boundaries)
    margin = 5

    # do not need refresh
    token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET + margin)
    assert not client.should_refresh(token)

    # need refresh
    token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - margin)
    assert client.should_refresh(token)

    # not exceed cool down time, do not refresh
    client._last_refresh_time = now - DEFAULT_TOKEN_REFRESH_RETRY_DELAY + margin
    assert not client.should_refresh(token)

    # cool down time exceeded, refresh again
    client._last_refresh_time = now - DEFAULT_TOKEN_REFRESH_RETRY_DELAY - margin
    assert client.should_refresh(token)
233