1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the ""Software""), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26
27import ast
28import os
29import logging
30import re
31import time
32import warnings
33try:
34    from urlparse import urlparse, parse_qs
35except ImportError:
36    from urllib.parse import urlparse, parse_qs
37
38import adal
39from requests import RequestException, ConnectionError, HTTPError
40import requests
41
42from msrest.authentication import OAuthTokenAuthentication, Authentication, BasicTokenAuthentication
43from msrest.exceptions import TokenExpiredError as Expired
44from msrest.exceptions import AuthenticationError, raise_with_traceback
45
46from msrestazure.azure_cloud import AZURE_CHINA_CLOUD, AZURE_PUBLIC_CLOUD
47from msrestazure.azure_configuration import AzureConfiguration
48from msrestazure.azure_exceptions import MSIAuthenticationTimeoutError
49
# Module-level logger; the MSI token helpers in this module report through it.
_LOGGER = logging.getLogger(name=__name__)
51
class AADMixin(OAuthTokenAuthentication):
    """Mixin for Authentication object.
    Provides some AAD functionality:

    - Token caching and retrieval
    - Default AAD configuration
    """
    # Matches a lowercase letter/digit followed by an uppercase letter; used by
    # _convert_token to rewrite camelCase keys (e.g. "accessToken") as snake_case.
    _case = re.compile('([a-z0-9])([A-Z])')

    def _configure(self, **kwargs):
        """Configure authentication endpoint.

        Optional kwargs may include:

            - cloud_environment (msrestazure.azure_cloud.Cloud): A targeted cloud environment.
              Takes precedence over 'china' when both are given.
            - china (bool): Configure auth for China-based service,
              default is 'False'. Deprecated, use cloud_environment instead.
            - tenant (str): Alternative tenant, default is 'common'.
            - resource (str): Alternative authentication resource, default
              is the cloud environment's active_directory_resource_id
              (e.g. 'https://management.core.windows.net/' on public Azure).
            - verify (bool): Verify secure connection. Default is None, which
              lets ADAL honor the ADAL_PYTHON_SSL_NO_VERIFY setting.
            - timeout (int): Timeout of the request in seconds.
            - proxies (dict): Dictionary mapping protocol or protocol and
              hostname to the URL of the proxy.
            - cache (adal.TokenCache): A adal.TokenCache, see ADAL configuration
              for details. This parameter is not used here and directly passed to ADAL.
        """
        if kwargs.get('china'):
            err_msg = ("china parameter is deprecated, "
                       "please use "
                       "cloud_environment=msrestazure.azure_cloud.AZURE_CHINA_CLOUD")
            warnings.warn(err_msg, DeprecationWarning)
            self._cloud_environment = AZURE_CHINA_CLOUD
        else:
            self._cloud_environment = AZURE_PUBLIC_CLOUD
        # An explicit cloud_environment overrides whatever 'china' selected.
        self._cloud_environment = kwargs.get('cloud_environment', self._cloud_environment)

        auth_endpoint = self._cloud_environment.endpoints.active_directory
        resource = self._cloud_environment.endpoints.active_directory_resource_id

        self._tenant = kwargs.get('tenant', "common")
        self._verify = kwargs.get('verify')  # 'None' will honor ADAL_PYTHON_SSL_NO_VERIFY
        self.resource = kwargs.get('resource', resource)
        self._proxies = kwargs.get('proxies')
        self._timeout = kwargs.get('timeout')
        self._cache = kwargs.get('cache')
        # Namespace the storage key by authority so tokens from different clouds don't collide.
        self.store_key = "{}_{}".format(
            auth_endpoint.strip('/'), self.store_key)
        self.secret = None
        self._context = None  # Future ADAL context, created lazily by set_token()

    def _create_adal_context(self):
        """Build the ADAL AuthenticationContext from the current settings.

        ADFS authorities are used as-is (minus a trailing slash) with authority
        validation disabled; any other authority gets the tenant appended.
        """
        authority_url = self.cloud_environment.endpoints.active_directory
        is_adfs = bool(re.match('.+(/adfs|/adfs/)$', authority_url, re.I))
        if is_adfs:
            authority_url = authority_url.rstrip('/')  # workaround: ADAL is known to reject auth urls with trailing /
        else:
            authority_url = authority_url + '/' + self._tenant

        self._context = adal.AuthenticationContext(
            authority_url,
            timeout=self._timeout,
            verify_ssl=self._verify,
            proxies=self._proxies,
            validate_authority=not is_adfs,
            cache=self._cache,
            api_version=None
        )

    def _destroy_adal_context(self):
        """Drop the ADAL context so the next use rebuilds it with fresh settings."""
        self._context = None

    @property
    def verify(self):
        return self._verify

    @verify.setter
    def verify(self, value):
        self._verify = value
        self._destroy_adal_context()

    @property
    def proxies(self):
        return self._proxies

    @proxies.setter
    def proxies(self, value):
        self._proxies = value
        self._destroy_adal_context()

    @property
    def timeout(self):
        return self._timeout

    @timeout.setter
    def timeout(self, value):
        self._timeout = value
        self._destroy_adal_context()

    @property
    def cloud_environment(self):
        return self._cloud_environment

    @cloud_environment.setter
    def cloud_environment(self, value):
        self._cloud_environment = value
        self._destroy_adal_context()

    def _convert_token(self, token):
        """Convert token fields from camel case.

        :param dict token: An authentication token.
        :rtype: dict
        """
        # Beware that ADAL returns a pointer to its own dict, do
        # NOT change it in place
        token = token.copy()

        # If it's from ADAL, expiresOn will be in ISO form.
        # Bring it back to float, using expiresIn
        if "expiresOn" in token and "expiresIn" in token:
            token["expiresOn"] = token['expiresIn'] + time.time()
        return {self._case.sub(r'\1_\2', k).lower(): v
                for k, v in token.items()}

    def _parse_token(self):
        """Derive 'expires_at'/'expires_in' on self.token from 'expires_on'."""
        # AD answers 'expires_on', and Python oauthlib expects 'expires_at'
        if 'expires_on' in self.token and 'expires_at' not in self.token:
            self.token['expires_at'] = self.token['expires_on']

        if self.token.get('expires_at'):
            countdown = float(self.token['expires_at']) - time.time()
            self.token['expires_in'] = countdown

    def set_token(self):
        """Ensure an ADAL context exists.

        Subclasses call this first and then acquire an actual token.
        """
        if not self._context:
            self._create_adal_context()

    def signed_session(self, session=None):
        """Create token-friendly Requests session, using auto-refresh.
        Used internally when a request is made.

        If a session object is provided, configure it directly. Otherwise,
        create a new session and return it.

        :param session: The session to configure for authentication
        :type session: requests.Session
        """
        self.set_token() # Adal does the caching.
        self._parse_token()
        return super(AADMixin, self).signed_session(session)

    def refresh_session(self, session=None):
        """Return updated session if token has expired, attempts to
        refresh using newly acquired token.

        If a session object is provided, configure it directly. Otherwise,
        create a new session and return it.

        :param session: The session to configure for authentication
        :type session: requests.Session
        :rtype: requests.Session.
        :raises: AuthenticationError if ADAL fails to refresh the token.
        """
        if 'refresh_token' in self.token:
            # The ADAL context is created lazily and is dropped whenever
            # verify/proxies/timeout/cloud_environment change, so it may not
            # exist here (e.g. AADTokenCredentials never created one yet).
            if not self._context:
                self._create_adal_context()
            try:
                token = self._context.acquire_token_with_refresh_token(
                    self.token['refresh_token'],
                    self.id,
                    self.resource,
                    self.secret # This is needed when using Confidential Client
                )
                self.token = self._convert_token(token)
            except adal.AdalError as err:
                raise_with_traceback(AuthenticationError, "", err)
        return self.signed_session(session)
227
228
class AADTokenCredentials(AADMixin):
    """
    Credentials wrapping an AAD token that was acquired by an external
    process, e.g. the Python ADAL library.

    When only "token" is given, refreshing targets public Azure with the
    default public Azure "resource". Pass "cloud_environment", "tenant",
    "resource" and "client_id" to change that.

    Optional kwargs may include:

    - cloud_environment (msrestazure.azure_cloud.Cloud): A targeted cloud environment
    - china (bool): Configure auth for China-based service,
      default is 'False'.
    - tenant (str): Alternative tenant, default is 'common'.
    - resource (str): Alternative authentication resource, default
      is 'https://management.core.windows.net/'.
    - verify (bool): Verify secure connection; left to ADAL when omitted.
    - timeout (int): Timeout of the request in seconds.
    - proxies (dict): Dictionary mapping protocol or protocol and
      hostname to the URL of the proxy.
    - cache (adal.TokenCache): A adal.TokenCache, see ADAL configuration
      for details. This parameter is not used here and directly passed to ADAL.

    :param dict token: Authentication token.
    :param str client_id: Client ID, if not set, Xplat Client ID
     will be used.
    """

    def __init__(self, token, client_id=None, **kwargs):
        # Any falsy client_id (None, '') falls back to the Xplat CLI client ID.
        client_id = client_id or '04b07795-8ddb-461a-bbee-02f9e1bf7b46'
        super(AADTokenCredentials, self).__init__(client_id, None)
        self._configure(**kwargs)
        self.client = None
        # Store the token with snake_case keys, as the rest of the mixin expects.
        self.token = self._convert_token(token)
264
265
class UserPassCredentials(AADMixin):
    """Credentials object for headless authentication, i.e. AAD
    authentication with a username and a password.

    Headless auth requires an AAD login (not a Live ID) that already has
    permission to access the resource, e.g. an organization account, and
    2-factor auth must be disabled for it.

    Optional kwargs may include:

    - cloud_environment (msrestazure.azure_cloud.Cloud): A targeted cloud environment
    - china (bool): Configure auth for China-based service,
      default is 'False'.
    - tenant (str): Alternative tenant, default is 'common'.
    - resource (str): Alternative authentication resource, default
      is 'https://management.core.windows.net/'.
    - verify (bool): Verify secure connection; left to ADAL when omitted.
    - timeout (int): Timeout of the request in seconds.
    - proxies (dict): Dictionary mapping protocol or protocol and
      hostname to the URL of the proxy.
    - cache (adal.TokenCache): A adal.TokenCache, see ADAL configuration
      for details. This parameter is not used here and directly passed to ADAL.

    :param str username: Account username.
    :param str password: Account password.
    :param str client_id: Client ID, if not set, Xplat Client ID
     will be used.
    :param str secret: Client secret, only if required by server.
    """

    def __init__(self, username, password,
                 client_id=None, secret=None, **kwargs):
        # Any falsy client_id falls back to the Xplat CLI client ID.
        super(UserPassCredentials, self).__init__(
            client_id or '04b07795-8ddb-461a-bbee-02f9e1bf7b46', None)
        self._configure(**kwargs)

        self.username, self.password = username, password
        self.secret = secret
        # Make the storage key user-specific.
        self.store_key += "_{}".format(username)
        # Acquire a token immediately so bad credentials fail at construction.
        self.set_token()

    def set_token(self):
        """Acquire a token through the ADAL username/password flow.

        :raises: AuthenticationError if credentials invalid, or call fails.
        """
        super(UserPassCredentials, self).set_token()  # ensures self._context exists
        try:
            raw_token = self._context.acquire_token_with_username_password(
                self.resource, self.username, self.password, self.id)
        except adal.AdalError as err:
            raise_with_traceback(AuthenticationError, "", err)
        else:
            self.token = self._convert_token(raw_token)
327
class ServicePrincipalCredentials(AADMixin):
    """Credentials object for Service Principal authentication,
    i.e. authentication with a Client ID and a Secret.

    Optional kwargs may include:

    - cloud_environment (msrestazure.azure_cloud.Cloud): A targeted cloud environment
    - china (bool): Configure auth for China-based service,
      default is 'False'.
    - tenant (str): Alternative tenant, default is 'common'.
    - resource (str): Alternative authentication resource, default
      is 'https://management.core.windows.net/'.
    - verify (bool): Verify secure connection; left to ADAL when omitted.
    - timeout (int): Timeout of the request in seconds.
    - proxies (dict): Dictionary mapping protocol or protocol and
      hostname to the URL of the proxy.
    - cache (adal.TokenCache): A adal.TokenCache, see ADAL configuration
      for details. This parameter is not used here and directly passed to ADAL.

    :param str client_id: Client ID.
    :param str secret: Client secret.
    """
    def __init__(self, client_id, secret, **kwargs):
        super(ServicePrincipalCredentials, self).__init__(client_id, None)
        self._configure(**kwargs)
        self.secret = secret
        # Acquire a token immediately so bad credentials fail at construction.
        self.set_token()

    def set_token(self):
        """Acquire a token through the ADAL client-credentials flow.

        :raises: AuthenticationError if credentials invalid, or call fails.
        """
        super(ServicePrincipalCredentials, self).set_token()  # ensures self._context exists
        try:
            raw_token = self._context.acquire_token_with_client_credentials(
                self.resource, self.id, self.secret)
        except adal.AdalError as err:
            raise_with_traceback(AuthenticationError, "", err)
        else:
            self.token = self._convert_token(raw_token)
372
# Retained only so that existing imports of this name keep working.
class InteractiveCredentials(object):
    """Removed credentials class, kept for import compatibility.

    Instantiating it always raises NotImplementedError.
    """
    def __init__(self, *_args, **_kwargs):
        message = ("InteractiveCredentials was not functionning and was removed. "
                   "Please use ADAL and device code instead.")
        raise NotImplementedError(message)
379
class AdalAuthentication(Authentication):  # pylint: disable=too-few-public-methods
    """A thin wrapper that lets any ADAL token-acquisition call authenticate Azure requests.

    .. versionadded:: 0.4.5

    Give it an ADAL `acquire_token` method together with its arguments.

    :Example:

    .. code:: python

        context = adal.AuthenticationContext('https://login.microsoftonline.com/ABCDEFGH-1234-1234-1234-ABCDEFGHIJKL')
        RESOURCE = '00000002-0000-0000-c000-000000000000' #AAD graph resource
        token = context.acquire_token_with_client_credentials(
            RESOURCE,
            "http://PythonSDK",
            "Key-Configured-In-Portal")

    can be written here:

    .. code:: python

        context = adal.AuthenticationContext('https://login.microsoftonline.com/ABCDEFGH-1234-1234-1234-ABCDEFGHIJKL')
        RESOURCE = '00000002-0000-0000-c000-000000000000' #AAD graph resource
        credentials = AdalAuthentication(
            context.acquire_token_with_client_credentials,
            RESOURCE,
            "http://PythonSDK",
            "Key-Configured-In-Portal")

    or using a lambda if you prefer:

    .. code:: python

        context = adal.AuthenticationContext('https://login.microsoftonline.com/ABCDEFGH-1234-1234-1234-ABCDEFGHIJKL')
        RESOURCE = '00000002-0000-0000-c000-000000000000' #AAD graph resource
        credentials = AdalAuthentication(
            lambda: context.acquire_token_with_client_credentials(
                RESOURCE,
                "http://PythonSDK",
                "Key-Configured-In-Portal"
            )
        )

    :param callable adal_method: A lambda with no args, or `acquire_token` method with args using args/kwargs
    :param args: Optional positional args for the method
    :param kwargs: Optional kwargs for the method
    """

    def __init__(self, adal_method, *args, **kwargs):
        super(AdalAuthentication, self).__init__()
        self._adal_method = adal_method
        self._args = args
        self._kwargs = kwargs

    def signed_session(self, session=None):
        """Create requests session with any required auth headers applied.

        The ADAL method is invoked on every call and its result is set as
        the Authorization header.

        If a session object is provided, configure it directly. Otherwise,
        create a new session and return it.

        :param session: The session to configure for authentication
        :type session: requests.Session
        :rtype: requests.Session
        :raises: TokenExpiredError if AAD reports AADSTS70008 (inactivity expiry).
        :raises: AuthenticationError on any other ADAL error or a connection error.
        """
        session = super(AdalAuthentication, self).signed_session(session)

        try:
            raw_token = self._adal_method(*self._args, **self._kwargs)
        except adal.AdalError as err:
            # pylint: disable=no-member
            error_response = getattr(err, 'error_response', None) or {}
            description = error_response.get('error_description') or ''
            if 'AADSTS70008:' in description:
                raise Expired("Credentials have expired due to inactivity.")
            raise AuthenticationError(err)
        except ConnectionError as err:
            raise AuthenticationError('Please ensure you have network connection. Error detail: ' + str(err))

        session.headers['Authorization'] = "{} {}".format(
            raw_token['tokenType'], raw_token['accessToken'])
        return session
462
def get_msi_token(resource, port=50342, msi_conf=None):
    """Get MSI token if MSI_ENDPOINT is set.

    IF MSI_ENDPOINT is not set, will try legacy access through 'http://localhost:{}/oauth2/token'.format(port).

    If msi_conf is used, must be a dict of one key in ["client_id", "object_id", "msi_res_id"]

    :param str resource: The resource the token will be used for.
    :param int port: The port if not the default 50342 is used. Ignored if MSI_ENDPOINT is set.
    :param dict[str,str] msi_conf: msi_conf if to request a token through a User Assigned Identity (if not specified, assume System Assigned)
    :return: A (token_type, access_token, full_token_entry) tuple.
    :raises ValueError: if msi_conf holds more than one key.
    :raises: whatever the HTTP call raises (logged as a warning, then re-raised).
    """
    request_uri = os.environ.get("MSI_ENDPOINT", 'http://localhost:{}/oauth2/token'.format(port))
    payload = {
        'resource': resource
    }
    if msi_conf:
        if len(msi_conf) > 1:
            raise ValueError("{} are mutually exclusive".format(list(msi_conf.keys())))
        payload.update(msi_conf)

    # Log before the call so that failed attempts are traceable too.
    _LOGGER.debug("MSI: Retrieving a token from %s, with payload %s", request_uri, payload)
    try:
        result = requests.post(request_uri, data=payload, headers={'Metadata': 'true'})
        result.raise_for_status()
    except Exception as ex:  # pylint: disable=broad-except
        _LOGGER.warning("MSI: Failed to retrieve a token from '%s' with an error of '%s'. This could be caused "
                        "by the MSI extension not yet fully provisioned.",
                        request_uri, ex)
        raise
    token_entry = result.json()
    return token_entry['token_type'], token_entry['access_token'], token_entry
494
def get_msi_token_webapp(resource, msi_conf=None):
    """Get a MSI token from inside a webapp or functions.

    Env variable will look like:

    - MSI_ENDPOINT = http://127.0.0.1:41741/MSI/token/
    - MSI_SECRET = 69418689F1E342DD946CB82994CDA3CB

    :param str resource: The resource the token will be used for.
    :param dict[str,str] msi_conf: msi_conf if to request a token through a User Assigned Identity (if not specified, assume System Assigned).
     Only "client_id" is supported here.
    :return: A (token_type, access_token, full_token_entry) tuple.
    :raises RuntimeError: if MSI_ENDPOINT/MSI_SECRET are missing or the token request fails.
    :raises ValueError: if msi_conf has several keys, or a key other than "client_id".
    """
    try:
        from urllib import urlencode  # Python 2
    except ImportError:
        from urllib.parse import urlencode  # Python 3

    try:
        msi_endpoint = os.environ['MSI_ENDPOINT']
        msi_secret = os.environ['MSI_SECRET']
    except KeyError as err:
        err_msg = "{} required env variable was not found. You might need to restart your app/function.".format(err)
        _LOGGER.critical(err_msg)
        raise RuntimeError(err_msg)

    query = [('resource', resource), ('api-version', '2017-09-01')]
    if msi_conf:
        if len(msi_conf) > 1:
            raise ValueError("{} are mutually exclusive".format(list(msi_conf.keys())))
        if 'client_id' not in msi_conf:
            raise ValueError('"client_id" is the only supported explicit identity option on WebApp')
        query.append(('clientid', msi_conf['client_id']))

    # Percent-encode the query values: resource and client id come from the
    # caller and must not be able to break or extend the query string.
    request_uri = '{}/?{}'.format(msi_endpoint, urlencode(query))

    headers = {
        'secret': msi_secret
    }

    err = None
    # Log before the call so that failed attempts are traceable too.
    _LOGGER.debug("MSI: Retrieving a token from %s", request_uri)
    try:
        result = requests.get(request_uri, headers=headers)
        if result.status_code != 200:
            err = result.text
        # Workaround since not all failures are != 200
        if 'ExceptionMessage' in result.text:
            err = result.text
    except Exception as ex:  # pylint: disable=broad-except
        err = str(ex)

    if err:
        err_msg = "MSI: Failed to retrieve a token from '{}' with an error of '{}'.".format(
            request_uri, err
        )
        _LOGGER.critical(err_msg)
        raise RuntimeError(err_msg)
    _LOGGER.debug('MSI: token retrieved')
    token_entry = result.json()
    return token_entry['token_type'], token_entry['access_token'], token_entry
550
551
552def _is_app_service():
553    # Might be discussed if we think it's not robust enough
554    return 'APPSETTING_WEBSITE_SITE_NAME' in os.environ
555
556
class MSIAuthentication(BasicTokenAuthentication):
    """Credentials object for MSI authentication.

    Optional kwargs may include:

    - timeout: If provided, must be in seconds and indicates the maximum time we'll try to get a token before raising MSIAuthenticationTimeout
    - client_id: Identifies, by Azure AD client id, a specific explicit identity to use when authenticating to Azure AD. Mutually exclusive with object_id and msi_res_id.
    - object_id: Identifies, by Azure AD object id, a specific explicit identity to use when authenticating to Azure AD. Mutually exclusive with client_id and msi_res_id.
    - msi_res_id: Identifies, by ARM resource id, a specific explicit identity to use when authenticating to Azure AD. Mutually exclusive with client_id and object_id.
    - cloud_environment (msrestazure.azure_cloud.Cloud): A targeted cloud environment
    - resource (str): Alternative authentication resource, default
      is 'https://management.core.windows.net/'.

    .. versionadded:: 0.4.14
    """

    def __init__(self, port=50342, **kwargs):
        super(MSIAuthentication, self).__init__(None)

        if port != 50342:
            warnings.warn("The 'port' argument is no longer used, and will be removed in a future release", DeprecationWarning)
        self.port = port

        # Keep only the identity selectors, in the order the caller passed them.
        identity_keys = ("client_id", "object_id", "msi_res_id")
        self.msi_conf = dict(
            (key, value) for key, value in kwargs.items() if key in identity_keys
        )

        self.cloud_environment = kwargs.get('cloud_environment', AZURE_PUBLIC_CLOUD)
        default_resource = self.cloud_environment.endpoints.active_directory_resource_id
        self.resource = kwargs.get('resource', default_resource)

        # Neither App Service nor an MSI_ENDPOINT: fall back to IMDS.
        if not (_is_app_service() or "MSI_ENDPOINT" in os.environ):
            self._vm_msi = _ImdsTokenProvider(self.msi_conf, timeout=kwargs.get("timeout"))
        # Follow the same convention as all Credentials class to check for the token at creation time #106
        self.set_token()

    def set_token(self):
        """Fetch a token from the MSI source matching the current environment."""
        if _is_app_service():
            self.scheme, _, self.token = get_msi_token_webapp(self.resource, self.msi_conf)
            return
        if "MSI_ENDPOINT" in os.environ:
            self.scheme, _, self.token = get_msi_token(self.resource, self.port, self.msi_conf)
            return
        imds_entry = self._vm_msi.get_token(self.resource)
        self.scheme = imds_entry['token_type']
        self.token = imds_entry

    def signed_session(self, session=None):
        """Create requests session with any required auth headers applied.

        If a session object is provided, configure it directly. Otherwise,
        create a new session and return it.

        :param session: The session to configure for authentication
        :type session: requests.Session
        :rtype: requests.Session
        """
        # Token cache is handled by the VM extension, call each time to avoid expiration
        self.set_token()
        return super(MSIAuthentication, self).signed_session(session)
616
617
618class _ImdsTokenProvider(object):
619    """A help class handling token acquisitions through Azure IMDS plugin.
620    """
621
622    def __init__(self, msi_conf=None, timeout=None):
623        self._user_agent = AzureConfiguration(None).user_agent
624        self.identity_type, self.identity_id = None, None
625        if msi_conf:
626            if len(msi_conf.keys()) > 1:
627                raise ValueError('"client_id", "object_id", "msi_res_id" are mutually exclusive')
628            elif len(msi_conf.keys()) == 1:
629                self.identity_type, self.identity_id = next(iter(msi_conf.items()))
630        # default to system assigned identity on an empty configuration object
631
632        self.cache = {}
633        self.timeout = timeout
634
635    def get_token(self, resource):
636        import datetime
637        # let us hit the cache first
638        token_entry = self.cache.get(resource, None)
639        if token_entry:
640            expires_on = int(token_entry['expires_on'])
641            expires_on_datetime = datetime.datetime.fromtimestamp(expires_on)
642            expiration_margin = 5  # in minutes
643            if datetime.datetime.now() + datetime.timedelta(minutes=expiration_margin) <= expires_on_datetime:
644                _LOGGER.info("MSI: token is found in cache.")
645                return token_entry
646            _LOGGER.info("MSI: cache is found but expired within %s minutes, so getting a new one.", expiration_margin)
647            self.cache.pop(resource)
648
649        token_entry = self._retrieve_token_from_imds_with_retry(resource)
650        self.cache[resource] = token_entry
651        return token_entry
652
653    def _sleep(self, time_to_wait, start_time):
654        """Sleep for time_to_wait or time remaining until timeout reached.
655
656        :param float time: Time to sleep in seconds
657        :param float start_time: Absolute time where polling started
658        :rtype: bool
659        :returns: True if timeout was used
660        """
661        if self.timeout is not None:  # 0 is acceptable value, so we really want to test None
662            time_to_sleep = max(0, min(time_to_wait, start_time + self.timeout - time.time()))
663        else:
664            time_to_sleep = time_to_wait
665        time.sleep(time_to_sleep)
666        return time_to_sleep != time_to_wait
667
668    def _retrieve_token_from_imds_with_retry(self, resource):
669        import random
670        import json
671        # 169.254.169.254 is a well known ip address hosting the web service that provides the Azure IMDS metadata
672        request_uri = 'http://169.254.169.254/metadata/identity/oauth2/token'
673        payload = {
674            'resource': resource,
675            'api-version': '2018-02-01'
676        }
677        if self.identity_id:
678            payload[self.identity_type] = self.identity_id
679
680        retry, max_retry, start_time = 1, 12, time.time()
681        # simplified version of https://en.wikipedia.org/wiki/Exponential_backoff
682        slots = [100 * ((2 << x) - 1) / 1000 for x in range(max_retry)]
683        has_timed_out = self.timeout == 0 # Assume a 0 timeout means "no more than one try"
684        while True:
685            result = requests.get(request_uri, params=payload, headers={'Metadata': 'true', 'User-Agent':self._user_agent})
686            _LOGGER.debug("MSI: Retrieving a token from %s, with payload %s", request_uri, payload)
687            if result.status_code in [404, 410, 429] or (499 < result.status_code < 600):
688                if has_timed_out:  # It was the last try, and we still don't get a good status code, die
689                    raise MSIAuthenticationTimeoutError('MSI: Failed to acquired tokens before timeout {}'.format(self.timeout))
690                elif retry <= max_retry:
691                    wait = random.choice(slots[:retry])
692                    _LOGGER.warning("MSI: wait: %ss and retry: %s", wait, retry)
693                    has_timed_out = self._sleep(wait, start_time)
694                    retry += 1
695                else:
696                    if result.status_code == 410:  # For IMDS upgrading, we wait up to 70s
697                        gap = 70 - (time.time() - start_time)
698                        if gap > 0:
699                            _LOGGER.warning("MSI: wait till 70 seconds when IMDS is upgrading")
700                            has_timed_out = self._sleep(gap, start_time)
701                            continue
702                    break
703            elif result.status_code != 200:
704                raise HTTPError(request=result.request, response=result.raw)
705            else:
706                break
707
708        if result.status_code != 200:
709            raise MSIAuthenticationTimeoutError('MSI: Failed to acquire tokens after {} times'.format(max_retry))
710
711        _LOGGER.debug('MSI: Token retrieved')
712        token_entry = json.loads(result.content.decode())
713        return token_entry
714