# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time
from unittest.mock import Mock, patch
from urllib.parse import urlparse

import pytest
from azure.core.credentials import AccessToken
from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
from azure.identity.aio._authn_client import AsyncAuthnClient

from helpers import mock_response
from helpers_async import wrap_in_future


@pytest.mark.asyncio
@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
async def test_request_url(authority):
    """Token requests should target https://<authority host>/<tenant>...

    The authority is supplied first as a keyword argument and then through the
    AZURE_AUTHORITY_HOST environment variable; the same URL checks apply to both.
    """
    tenant_id = "expected-tenant"
    # A bare host such as "localhost" parses to netloc "" and path "localhost",
    # so fall back to the raw authority value in that case.
    expected_netloc = urlparse(authority).netloc or authority

    def mock_send(request, **kwargs):
        # Validate the outgoing request's URL, then answer with a minimal token payload.
        url = urlparse(request.url)
        assert url.scheme == "https"
        assert url.netloc == expected_netloc
        assert url.path.startswith("/" + tenant_id)
        return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "*"})

    def new_transport():
        # Each client gets its own mock transport wrapping the same validating sender.
        return Mock(send=wrap_in_future(mock_send))

    # authority passed explicitly
    explicit_client = AsyncAuthnClient(tenant=tenant_id, transport=new_transport(), authority=authority)
    await explicit_client.request_token(("scope",))

    # authority can be configured via environment variable
    environment = {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}
    with patch.dict("os.environ", environment, clear=True):
        env_client = AsyncAuthnClient(tenant=tenant_id, transport=new_transport())
        await env_client.request_token(("scope",))


def test_should_refresh():
    """Exercise should_refresh against the refresh offset and the retry delay."""
    client = AsyncAuthnClient(endpoint="http://foo")
    now = int(time.time())

    def token_expiring(delta):
        # Token whose expiry sits `delta` seconds past the refresh-offset boundary.
        return AccessToken("token", now + DEFAULT_REFRESH_OFFSET + delta)

    # expiry is beyond the refresh offset: do not need refresh
    assert not client.should_refresh(token_expiring(1))

    # expiry is inside the refresh offset: need refresh
    assert client.should_refresh(token_expiring(-1))

    # last refresh attempt is still within the cool down time: do not refresh
    expiring_soon = token_expiring(-1)
    client._last_refresh_time = now - DEFAULT_TOKEN_REFRESH_RETRY_DELAY + 1
    assert not client.should_refresh(expiring_soon)