1# ------------------------------------
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4# ------------------------------------
5"""Policy implementing Key Vault's challenge authentication protocol.
6
7Normally the protocol is only used for the client's first service request, upon which:
81. The challenge authentication policy sends a copy of the request, without authorization or content.
92. Key Vault responds 401 with a header (the 'challenge') detailing how the client should authenticate such a request.
103. The policy authenticates according to the challenge and sends the original request with authorization.
11
12The policy caches the challenge and thus knows how to authenticate future requests. However, authentication
13requirements can change. For example, a vault may move to a new tenant. In such a case the policy will attempt the
14protocol again.
15"""
16from typing import TYPE_CHECKING
17
18from azure.core.pipeline.policies import AsyncHTTPPolicy
19
20from . import HttpChallengeCache
21from .challenge_auth_policy import _enforce_tls, _get_challenge_request, _update_challenge, ChallengeAuthPolicyBase
22
23if TYPE_CHECKING:
24    from typing import Any
25    from azure.core.credentials_async import AsyncTokenCredential
26    from azure.core.pipeline import PipelineRequest
27    from azure.core.pipeline.transport import AsyncHttpResponse
28    from . import HttpChallenge
29
30
class AsyncChallengeAuthPolicy(ChallengeAuthPolicyBase, AsyncHTTPPolicy):
    """Async pipeline policy that authenticates requests using Key Vault's HTTP challenge protocol.

    Challenges are looked up in ``HttpChallengeCache`` by request URL. When nothing is cached,
    a challenge is obtained first by sending the copy of the request produced by
    ``_get_challenge_request``.

    :param credential: async credential asked for bearer tokens matching the challenge's scope
    """

    def __init__(self, credential: "AsyncTokenCredential", **kwargs: "Any") -> None:
        # keep the credential, then hand every other keyword argument to the base policies
        self._credential = credential
        super().__init__(**kwargs)

    async def send(self, request: "PipelineRequest") -> "AsyncHttpResponse":
        """Authorize ``request`` according to the vault's challenge and send it down the pipeline.

        :param request: the outgoing pipeline request; its ``Authorization`` header is set in place
        :return: the service response. When no usable challenge can be parsed from a response
            (``_update_challenge`` raises ``ValueError``), that response is returned unchanged.
        """
        _enforce_tls(request)

        challenge = HttpChallengeCache.get_challenge_for_url(request.http_request.url)
        if not challenge:
            # nothing cached for this URL -> provoke a challenge with a probe copy of the request
            probe = _get_challenge_request(request)
            probe_response = await self.next.send(probe)
            try:
                challenge = _update_challenge(request, probe_response)
            except ValueError:
                # the probe didn't produce the expected challenge; this policy can't go further
                return probe_response

        response = await self._authorize_and_send(request, challenge)
        if response.http_response.status_code != 401:
            return response

        # rejected even with authorization: whatever token is cached can't be valid anymore
        self._token = None

        # the cached challenge may be stale; see whether this 401 carries a newer one
        try:
            challenge = _update_challenge(request, response)
        except ValueError:
            # a 401 without a legible challenge leaves nothing more to try
            return response

        return await self._authorize_and_send(request, challenge)

    async def _authorize_and_send(
        self, request: "PipelineRequest", challenge: "HttpChallenge"
    ) -> "AsyncHttpResponse":
        """Apply ``challenge``'s authorization to ``request``, then pass it to the next policy."""
        await self._handle_challenge(request, challenge)
        return await self.next.send(request)

    async def _handle_challenge(self, request: "PipelineRequest", challenge: "HttpChallenge") -> None:
        """Set ``request``'s ``Authorization`` header to a bearer token for ``challenge``.

        A token is requested from the credential only when ``self._need_new_token`` (provided by
        the base class) is truthy; otherwise the token already held in ``self._token`` is reused.
        """
        if self._need_new_token:
            scope = challenge.get_scope()
            if not scope:
                # azure-identity credentials expect an AADv2 scope, but the challenge may only
                # name an AADv1 resource -> derive the scope from it
                scope = challenge.get_resource() + "/.default"
            self._token = await self._credential.get_token(scope)

        # mypy sees self._token as Optional; get_token raises rather than returning nothing
        bearer = "Bearer {}".format(self._token.token)  # type: ignore
        request.http_request.headers["Authorization"] = bearer