1# ------------------------------------
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4# ------------------------------------
5"""Policy implementing Key Vault's challenge authentication protocol.
6
7Normally the protocol is only used for the client's first service request, upon which:
81. The challenge authentication policy sends a copy of the request, without authorization or content.
92. Key Vault responds 401 with a header (the 'challenge') detailing how the client should authenticate such a request.
103. The policy authenticates according to the challenge and sends the original request with authorization.
11
12The policy caches the challenge and thus knows how to authenticate future requests. However, authentication
13requirements can change. For example, a vault may move to a new tenant. In such a case the policy will attempt the
14protocol again.
15"""
16
17import copy
18import time
19
20from azure.core.exceptions import ServiceRequestError
21from azure.core.pipeline import PipelineContext, PipelineRequest
22from azure.core.pipeline.policies import HTTPPolicy
23from azure.core.pipeline.transport import HttpRequest
24
25from .http_challenge import HttpChallenge
26from . import http_challenge_cache as ChallengeCache
27
28try:
29    from typing import TYPE_CHECKING
30except ImportError:
31    TYPE_CHECKING = False
32
33if TYPE_CHECKING:
34    from typing import Any, Optional
35    from azure.core.credentials import AccessToken, TokenCredential
36    from azure.core.pipeline import PipelineResponse
37
38
def _enforce_tls(request):
    # type: (PipelineRequest) -> None
    """Refuse to attach bearer credentials to a request whose URL is not https.

    :param request: the outgoing pipeline request
    :raises ServiceRequestError: if the request URL does not start with "https" (case-insensitive)
    """
    target_url = request.http_request.url.lower()
    if target_url.startswith("https"):
        return
    # sending a bearer token over plain http would expose it, so fail before any token is acquired
    raise ServiceRequestError(
        "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
    )
45
46
def _get_challenge_request(request):
    # type: (PipelineRequest) -> PipelineRequest
    """Build a body-less copy of ``request`` used to provoke Key Vault's authentication challenge.

    The probe has the same method, URL and headers as the service request, but "Content-Length" is forced to "0".
    Its pipeline context is rebuilt from the original transport and a deep copy of the original context options,
    so nothing the probe does to its options leaks back into the real request.

    :param request: the service request that will eventually be authorized and sent
    :returns: a new pipeline request suitable for the challenge round trip
    """
    service_request = request.http_request
    probe = HttpRequest(service_request.method, service_request.url, headers=service_request.headers)
    # the probe carries no body, so advertise an empty payload
    probe.headers["Content-Length"] = "0"

    probe_context = PipelineContext(request.context.transport, **copy.deepcopy(request.context.options))
    return PipelineRequest(http_request=probe, context=probe_context)
61
62
def _update_challenge(request, challenger):
    # type: (PipelineRequest, PipelineResponse) -> HttpChallenge
    """Parse the challenge carried by ``challenger``, cache it under the request URL, and return it.

    :param request: the request whose URL keys the challenge cache
    :param challenger: the response expected to carry a "WWW-Authenticate" header
    :returns: the parsed challenge
    :raises ValueError: callers in this module treat a ValueError (presumably raised by HttpChallenge when the
        header is missing or unparseable -- confirm against HttpChallenge) as "no usable challenge"
    """
    url = request.http_request.url
    response_headers = challenger.http_response.headers
    parsed = HttpChallenge(
        url,
        response_headers.get("WWW-Authenticate"),
        response_headers=response_headers,
    )
    # the cache is only updated after parsing succeeds
    ChallengeCache.set_challenge_for_url(url, parsed)
    return parsed
74
75
class ChallengeAuthPolicyBase(object):
    """Sans I/O base for challenge authentication policies.

    Keeps the most recently acquired access token and decides when a new one is required.
    """

    def __init__(self, **kwargs):
        self._token = None  # type: Optional[AccessToken]
        # cooperative init so this mixin composes with the concrete pipeline policy base class
        super(ChallengeAuthPolicyBase, self).__init__(**kwargs)

    @property
    def _need_new_token(self):
        # type: () -> bool
        """True when no token is cached or the cached one expires within the next 300 seconds."""
        current = self._token
        if not current:
            return True
        remaining_seconds = current.expires_on - time.time()
        return remaining_seconds < 300
87
88
class ChallengeAuthPolicy(ChallengeAuthPolicyBase, HTTPPolicy):
    """Policy for handling HTTP authentication challenges.

    On the first request to a URL with no cached challenge, a body-less probe is sent to obtain the challenge.
    The real request is then authorized per the challenge. A 401 on the real request triggers one retry with a
    fresh token and, if present, the new challenge from that 401.
    """

    def __init__(self, credential, **kwargs):
        # type: (TokenCredential, **Any) -> None
        self._credential = credential
        super(ChallengeAuthPolicy, self).__init__(**kwargs)

    def send(self, request):
        # type: (PipelineRequest) -> PipelineResponse
        """Authorize ``request`` per Key Vault's challenge protocol and send it down the pipeline.

        :raises ServiceRequestError: if the request URL is not https
        """
        _enforce_tls(request)

        cached = ChallengeCache.get_challenge_for_url(request.http_request.url)
        if cached:
            challenge = cached
        else:
            challenger = self.next.send(_get_challenge_request(request))
            try:
                challenge = _update_challenge(request, challenger)
            except ValueError:
                # the probe did not yield a usable challenge -> hand its response back unchanged
                return challenger

        self._handle_challenge(request, challenge)
        response = self.next.send(request)
        if response.http_response.status_code != 401:
            return response

        # a 401 means the token we just used is no good; drop it so a new one is requested
        self._token = None
        try:
            # the cached challenge may be stale (e.g. the vault moved tenants); prefer the one in this 401
            refreshed = _update_challenge(request, response)
        except ValueError:
            # 401 without a legible challenge -> nothing more this policy can do
            return response

        self._handle_challenge(request, refreshed)
        return self.next.send(request)

    def _handle_challenge(self, request, challenge):
        # type: (PipelineRequest, HttpChallenge) -> None
        """Acquire a token for ``challenge`` if needed and set the request's Authorization header."""
        if self._need_new_token:
            scope = challenge.get_scope()
            if not scope:
                # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
                scope = challenge.get_resource() + "/.default"
            self._token = self._credential.get_token(scope)

        # self._token is Optional for mypy, but get_token raises rather than returning nothing
        bearer = "Bearer {}".format(self._token.token)  # type: ignore
        request.http_request.headers["Authorization"] = bearer
141