1# ------------------------------------
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4# ------------------------------------
5import time
6from unittest.mock import Mock, patch
7from urllib.parse import urlparse
8
9import pytest
10from azure.core.credentials import AccessToken
11from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
12from azure.identity.aio._authn_client import AsyncAuthnClient
13
14from helpers import mock_response
15from helpers_async import wrap_in_future
16
17
@pytest.mark.asyncio
@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
async def test_request_url(authority):
    """Token requests must go to https://<authority host>/<tenant>...

    The authority is exercised twice: once passed explicitly to the constructor and once
    supplied only through the AZURE_AUTHORITY_HOST environment variable.
    """
    tenant = "expected-tenant"
    # urlparse("localhost") puts the whole string in .path and leaves .netloc empty,
    # so fall back to the raw authority string in that case
    host = urlparse(authority).netloc or authority

    def check_request(request, **kwargs):
        # transport stand-in: validate the outgoing URL, then answer with a minimal token payload
        url = urlparse(request.url)
        assert url.scheme == "https"
        assert url.netloc == host
        assert url.path.startswith("/" + tenant)
        payload = {"token_type": "Bearer", "expires_in": 42, "access_token": "*"}
        return mock_response(json_payload=payload)

    def make_transport():
        return Mock(send=wrap_in_future(check_request))

    # authority given explicitly
    client = AsyncAuthnClient(tenant=tenant, transport=make_transport(), authority=authority)
    await client.request_token(("scope",))

    # authority configured via environment variable (environment otherwise emptied)
    env = {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}
    with patch.dict("os.environ", env, clear=True):
        client = AsyncAuthnClient(tenant=tenant, transport=make_transport())
        await client.request_token(("scope",))
39
40
def test_should_refresh():
    """should_refresh must honor both the refresh offset and the refresh-retry cool-down.

    Refresh is expected only when the token expires within DEFAULT_REFRESH_OFFSET seconds
    and no refresh was attempted within the last DEFAULT_TOKEN_REFRESH_RETRY_DELAY seconds.
    """
    client = AsyncAuthnClient(endpoint="http://foo")
    now = int(time.time())
    # should_refresh takes no "now" argument, so it must consult the clock itself. Keep every
    # expectation this many seconds away from its threshold; otherwise a run that crosses a
    # second boundary between the time.time() call above and the one inside should_refresh
    # could flip an assertion.
    slack = 60

    # expires well after the refresh window opens: no refresh needed
    token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET + slack)
    assert not client.should_refresh(token)

    # expires inside the refresh window, fresh client (no refresh attempted yet): refresh
    token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - slack)
    assert client.should_refresh(token)

    # inside the refresh window, but a refresh was just attempted (cool-down not elapsed): don't refresh
    client._last_refresh_time = now
    assert not client.should_refresh(token)

    # cool-down has elapsed since the last attempt: refreshing is allowed again
    client._last_refresh_time = now - DEFAULT_TOKEN_REFRESH_RETRY_DELAY - slack
    assert client.should_refresh(token)
60