# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Tests for the async ``AadClient`` (``azure.identity.aio._internal.aad_client``)."""
import functools
import time
from unittest.mock import Mock, patch
from urllib.parse import urlparse

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
from azure.identity.aio._internal.aad_client import AadClient
from msal import TokenCache
import pytest

from helpers import build_aad_response, mock_response

pytestmark = pytest.mark.asyncio


async def test_error_reporting():
    """ClientAuthenticationError raised for an AAD error response includes AAD's error and description."""
    error_name = "everything's sideways"
    error_description = "something went wrong"
    error_response = {"error": error_name, "error_description": error_description}

    response = mock_response(status_code=403, json_payload=error_response)

    async def send(*_, **__):
        return response

    transport = Mock(send=Mock(wraps=send))
    client = AadClient("tenant id", "client id", transport=transport)

    fns = [
        functools.partial(client.obtain_token_by_authorization_code, ("scope",), "code", "uri"),
        functools.partial(client.obtain_token_by_refresh_token, ("scope",), "refresh token"),
    ]

    # exceptions raised for AAD errors should contain AAD's error description
    for fn in fns:
        with pytest.raises(ClientAuthenticationError) as ex:
            await fn()
        message = str(ex.value)
        assert error_name in message and error_description in message
        assert transport.send.call_count == 1
        transport.send.reset_mock()


async def test_exceptions_do_not_expose_secrets():
    """Neither str() nor repr() of the raised exception may contain tokens from the response body."""
    secret = "secret"
    body = {"error": "bad thing", "access_token": secret, "refresh_token": secret}
    response = mock_response(status_code=403, json_payload=body)

    async def send(*_, **__):
        return response

    transport = Mock(send=Mock(wraps=send))

    client = AadClient("tenant id", "client id", transport=transport)

    # positional order matches the signatures used elsewhere in this module:
    # (scopes, code, redirect_uri) and (scopes, refresh_token)
    fns = [
        functools.partial(client.obtain_token_by_authorization_code, ("scope",), "code", "uri"),
        functools.partial(client.obtain_token_by_refresh_token, ("scope",), "refresh token"),
    ]

    async def assert_secrets_not_exposed():
        for fn in fns:
            with pytest.raises(ClientAuthenticationError) as ex:
                await fn()
            assert secret not in str(ex.value)
            assert secret not in repr(ex.value)
            assert transport.send.call_count == 1
            transport.send.reset_mock()

    # AAD errors shouldn't provoke exceptions exposing secrets
    await assert_secrets_not_exposed()

    # neither should unexpected AAD responses
    del body["error"]
    await assert_secrets_not_exposed()


@pytest.mark.parametrize("secret", (None, "client secret"))
async def test_authorization_code(secret):
    """The authorization code request carries the expected form fields, with client_secret only when given."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    auth_code = "code"
    scope = "scope"
    redirect_uri = "https://localhost"
    access_token = "***"

    async def send(request, **_):
        assert request.data["client_id"] == client_id
        assert request.data["code"] == auth_code
        assert request.data["grant_type"] == "authorization_code"
        assert request.data["redirect_uri"] == redirect_uri
        assert request.data["scope"] == scope
        assert request.data.get("client_secret") == secret

        return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

    transport = Mock(send=Mock(wraps=send))

    client = AadClient(tenant_id, client_id, transport=transport)
    token = await client.obtain_token_by_authorization_code(
        scopes=(scope,), code=auth_code, redirect_uri=redirect_uri, client_secret=secret
    )

    assert token.token == access_token
    assert transport.send.call_count == 1


async def test_client_secret():
    """The client credentials request carries the expected form fields."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    scope = "scope"
    secret = "client-secret"
    access_token = "***"

    async def send(request, **_):
        assert request.data["client_id"] == client_id
        assert request.data["client_secret"] == secret
        assert request.data["grant_type"] == "client_credentials"
        assert request.data["scope"] == scope

        return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

    transport = Mock(send=Mock(wraps=send))

    client = AadClient(tenant_id, client_id, transport=transport)
    token = await client.obtain_token_by_client_secret(scopes=(scope,), secret=secret)

    assert token.token == access_token
    assert transport.send.call_count == 1


async def test_refresh_token():
    """The refresh token request carries the expected form fields."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    scope = "scope"
    refresh_token = "refresh-token"
    access_token = "***"

    async def send(request, **_):
        assert request.data["client_id"] == client_id
        assert request.data["grant_type"] == "refresh_token"
        assert request.data["refresh_token"] == refresh_token
        assert request.data["scope"] == scope

        return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

    transport = Mock(send=Mock(wraps=send))

    client = AadClient(tenant_id, client_id, transport=transport)
    token = await client.obtain_token_by_refresh_token(scopes=(scope,), refresh_token=refresh_token)

    assert token.token == access_token
    assert transport.send.call_count == 1


@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
async def test_request_url(authority):
    """Requests use https, the authority's host, and a path starting with the tenant.

    The authority may be passed to the constructor or set via the AZURE_AUTHORITY_HOST
    environment variable.
    """
    tenant_id = "expected-tenant"
    parsed_authority = urlparse(authority)
    expected_netloc = parsed_authority.netloc or authority  # "localhost" parses to netloc "", path "localhost"

    async def send(request, **_):
        actual = urlparse(request.url)
        assert actual.scheme == "https"
        assert actual.netloc == expected_netloc
        assert actual.path.startswith("/" + tenant_id)
        return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"})

    # wrap send so we can verify requests were actually sent; otherwise the URL
    # assertions above could silently never run
    transport = Mock(send=Mock(wraps=send))
    client = AadClient(tenant_id, "client id", transport=transport, authority=authority)

    await client.obtain_token_by_authorization_code(("scope",), "code", "uri")
    await client.obtain_token_by_refresh_token(("scope",), "refresh token")
    assert transport.send.call_count == 2

    # authority can be configured via environment variable
    with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
        transport = Mock(send=Mock(wraps=send))
        client = AadClient(tenant_id=tenant_id, client_id="client id", transport=transport)
        await client.obtain_token_by_authorization_code(("scope",), "code", "uri")
        await client.obtain_token_by_refresh_token(("scope",), "refresh token")
    assert transport.send.call_count == 2


async def test_evicts_invalid_refresh_token():
    """when AAD rejects a refresh token, the client should evict that token from its cache"""

    tenant_id = "tenant-id"
    client_id = "client-id"
    invalid_token = "invalid-refresh-token"

    cache = TokenCache()
    cache.add({"response": build_aad_response(uid="id1", utid="tid1", access_token="*", refresh_token=invalid_token)})
    cache.add({"response": build_aad_response(uid="id2", utid="tid2", access_token="*", refresh_token="...")})
    assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 2
    assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token})) == 1

    async def send(request, **_):
        assert request.data["refresh_token"] == invalid_token
        return mock_response(json_payload={"error": "invalid_grant"}, status_code=400)

    transport = Mock(send=Mock(wraps=send))

    client = AadClient(tenant_id, client_id, transport=transport, cache=cache)
    with pytest.raises(ClientAuthenticationError):
        await client.obtain_token_by_refresh_token(scopes=("scope",), refresh_token=invalid_token)

    assert transport.send.call_count == 1
    # only the rejected token is removed; the other account's refresh token survives
    assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 1
    assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token})) == 0


async def test_should_refresh():
    """should_refresh honors both the refresh offset and the retry cool-down."""
    client = AadClient("test", "test")
    now = int(time.time())

    # Margins are wider than one second on purpose. ``now`` is truncated to a whole second, and
    # should_refresh presumably reads the clock itself. An exact +/-1 boundary can therefore flip
    # if the clock ticks over between the two reads.

    # do not need refresh: expiry is well beyond the refresh offset
    token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET + 10)
    assert not client.should_refresh(token)

    # need refresh: expiry is within the refresh offset
    token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - 1)
    assert client.should_refresh(token)

    # within the refresh offset, but a refresh was just attempted: cool-down applies, do not refresh
    client._last_refresh_time = now
    assert not client.should_refresh(token)

    # cool-down has elapsed: refresh again
    client._last_refresh_time = now - DEFAULT_TOKEN_REFRESH_RETRY_DELAY - 1
    assert client.should_refresh(token)