# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import base64
import hashlib
import re
import random
from time import time
from io import SEEK_SET, UnsupportedOperation
import logging
import uuid
import types
from typing import Any, TYPE_CHECKING
from wsgiref.handlers import format_date_time
try:
    from urllib.parse import (
        urlparse,
        parse_qsl,
        urlunparse,
        urlencode,
    )
except ImportError:
    from urllib import urlencode # type: ignore
    from urlparse import ( # type: ignore
        urlparse,
        parse_qsl,
        urlunparse,
    )

from azure.core.pipeline.policies import (
    HeadersPolicy,
    SansIOHTTPPolicy,
    NetworkTraceLoggingPolicy,
    HTTPPolicy,
    RequestHistory
)
from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError

from .models import LocationMode

try:
    _unicode_type = unicode # type: ignore
except NameError:
    _unicode_type = str

if TYPE_CHECKING:
    from azure.core.pipeline import PipelineRequest, PipelineResponse


_LOGGER = logging.getLogger(__name__)


def encode_base64(data):
    if isinstance(data, _unicode_type):
        data = data.encode('utf-8')
    encoded = base64.b64encode(data)
    return encoded.decode('utf-8')


def is_exhausted(settings):
    """Are we out of retries?"""
    retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status'])
    retry_counts = list(filter(None, retry_counts))
    if not retry_counts:
        return False
    return min(retry_counts) < 0
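

# Illustrative sketch only (not part of the original module): the counters below
# mirror the settings dict built by StorageRetryPolicy.configure_retries further
# down; a counter that has gone negative means that retry budget is spent. The
# helper name is hypothetical.
def _example_is_exhausted():
    settings = {'total': 2, 'connect': 3, 'read': 3, 'status': -1}
    return is_exhausted(settings)  # True: the 'status' budget has gone negative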


def retry_hook(settings, **kwargs):
    if settings['hook']:
        settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs)


def is_retry(response, mode):
    """Check whether the response status code warrants a retry.

    Server-side errors (5xx, with the exception of 501 and 505) and request
    timeouts (408) are retryable. A 404 is retried only if the request went to
    the secondary endpoint, which may simply not have replicated the data yet.
    """
    status = response.http_response.status_code
    if 300 <= status < 500:
        # An exception occurred, but in most cases it was expected. Examples could
        # include a 409 Conflict or a 412 Precondition Failed.
        if status == 404 and mode == LocationMode.SECONDARY:
            # Response code 404 should be retried if secondary was used.
            return True
        if status == 408:
            # Response code 408 is a timeout and should be retried.
            return True
        return False
    if status >= 500:
        # Response codes of 500 and above, with the exception of 501 Not Implemented
        # and 505 Version Not Supported, indicate a server issue and should be retried.
        if status in [501, 505]:
            return False
        return True
    return False
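

# Illustrative sketch only (not part of the original module): minimal stand-ins
# for a pipeline response, showing which status codes is_retry treats as
# retryable in each location mode. All names below are hypothetical.
def _example_is_retry():
    class _FakeHttpResponse(object):
        def __init__(self, status_code):
            self.status_code = status_code

    class _FakePipelineResponse(object):
        def __init__(self, status_code):
            self.http_response = _FakeHttpResponse(status_code)

    assert is_retry(_FakePipelineResponse(503), LocationMode.PRIMARY)      # server error
    assert is_retry(_FakePipelineResponse(404), LocationMode.SECONDARY)    # secondary may lag behind
    assert not is_retry(_FakePipelineResponse(404), LocationMode.PRIMARY)  # genuine not-found
    assert not is_retry(_FakePipelineResponse(501), LocationMode.PRIMARY)  # not implemented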


def urljoin(base_url, stub_url):
    parsed = urlparse(base_url)
    parsed = parsed._replace(path=parsed.path + '/' + stub_url)
    return parsed.geturl()


class QueueMessagePolicy(SansIOHTTPPolicy):

    def on_request(self, request):
        message_id = request.context.options.pop('queue_message_id', None)
        if message_id:
            request.http_request.url = urljoin(
                request.http_request.url,
                message_id)


class StorageHeadersPolicy(HeadersPolicy):
    request_id_header_name = 'x-ms-client-request-id'

    def on_request(self, request):
        # type: (PipelineRequest, Any) -> None
        super(StorageHeadersPolicy, self).on_request(request)
        current_time = format_date_time(time())
        request.http_request.headers['x-ms-date'] = current_time

        custom_id = request.context.options.pop('client_request_id', None)
        request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1())

    # def on_response(self, request, response):
    #     # raise exception if the echoed client request id from the service is not identical to the one we sent
    #     if self.request_id_header_name in response.http_response.headers:

    #         client_request_id = request.http_request.headers.get(self.request_id_header_name)

    #         if response.http_response.headers[self.request_id_header_name] != client_request_id:
    #             raise AzureError(
    #                 "Echoed client request ID: {} does not match sent client request ID: {}.  "
    #                 "Service request ID: {}".format(
    #                     response.http_response.headers[self.request_id_header_name], client_request_id,
    #                     response.http_response.headers['x-ms-request-id']),
    #                 response=response.http_response
    #             )
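

# Illustrative sketch only (not part of the original module): format_date_time
# renders the RFC 1123 timestamp placed in the 'x-ms-date' header above, and
# uuid.uuid1() supplies the default 'x-ms-client-request-id'. The helper name
# is hypothetical.
def _example_storage_headers():
    return {
        'x-ms-date': format_date_time(time()),        # e.g. 'Thu, 01 Jan 1970 00:00:00 GMT'
        'x-ms-client-request-id': str(uuid.uuid1()),  # unique per request unless overridden
    }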


class StorageHosts(SansIOHTTPPolicy):

    def __init__(self, hosts=None, **kwargs):  # pylint: disable=unused-argument
        self.hosts = hosts
        super(StorageHosts, self).__init__()

    def on_request(self, request):
        # type: (PipelineRequest, Any) -> None
        request.context.options['hosts'] = self.hosts
        parsed_url = urlparse(request.http_request.url)

        # Detect what location mode we're currently requesting with
        location_mode = LocationMode.PRIMARY
        for key, value in self.hosts.items():
            if parsed_url.netloc == value:
                location_mode = key

        # See if a specific location mode has been specified, and if so, redirect
        use_location = request.context.options.pop('use_location', None)
        if use_location:
            # Lock retries to the specific location
            request.context.options['retry_to_secondary'] = False
            if use_location not in self.hosts:
                raise ValueError("Attempting to use undefined host location {}".format(use_location))
            if use_location != location_mode:
                # Update request URL to use the specified location
                updated = parsed_url._replace(netloc=self.hosts[use_location])
                request.http_request.url = updated.geturl()
                location_mode = use_location

        request.context.options['location_mode'] = location_mode
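

# Illustrative sketch only (not part of the original module): the shape of the
# hosts mapping this policy expects, keyed by location mode. The account name
# and endpoints are placeholders.
_EXAMPLE_HOSTS = {
    LocationMode.PRIMARY: 'myaccount.blob.core.windows.net',
    LocationMode.SECONDARY: 'myaccount-secondary.blob.core.windows.net',
}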


class StorageLoggingPolicy(NetworkTraceLoggingPolicy):
    """A policy that logs HTTP request and response to the DEBUG logger.

    This accepts both global configuration and a per-operation override, in both
    cases via the "logging_enable" keyword.
    """

    def on_request(self, request):
        # type: (PipelineRequest, Any) -> None
        http_request = request.http_request
        options = request.context.options
        if options.pop("logging_enable", self.enable_http_logger):
            request.context["logging_enable"] = True
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                log_url = http_request.url
                query_params = http_request.query
                if 'sig' in query_params:
                    log_url = log_url.replace(query_params['sig'], "*****")
                _LOGGER.debug("Request URL: %r", log_url)
                _LOGGER.debug("Request method: %r", http_request.method)
                _LOGGER.debug("Request headers:")
                for header, value in http_request.headers.items():
                    if header.lower() == 'authorization':
                        value = '*****'
                    elif header.lower() == 'x-ms-copy-source' and 'sig' in value:
                        # take the url apart and scrub away the signed signature
                        scheme, netloc, path, params, query, fragment = urlparse(value)
                        parsed_qs = dict(parse_qsl(query))
                        parsed_qs['sig'] = '*****'

                        # the SAS needs to be put back together
                        value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment))

                    _LOGGER.debug("    %r: %r", header, value)
                _LOGGER.debug("Request body:")

                # We don't want to log the binary data of a file upload.
                if isinstance(http_request.body, types.GeneratorType):
                    _LOGGER.debug("File upload")
                else:
                    _LOGGER.debug(str(http_request.body))
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log request: %r", err)

    def on_response(self, request, response):
        # type: (PipelineRequest, PipelineResponse, Any) -> None
        if response.context.pop("logging_enable", self.enable_http_logger):
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                _LOGGER.debug("Response status: %r", response.http_response.status_code)
                _LOGGER.debug("Response headers:")
                for res_header, value in response.http_response.headers.items():
                    _LOGGER.debug("    %r: %r", res_header, value)

                # We don't want to log binary data if the response is a file.
                _LOGGER.debug("Response content:")
                pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE)
                header = response.http_response.headers.get('content-disposition')

                if header and pattern.match(header):
                    filename = header.partition('=')[2]
                    _LOGGER.debug("File attachments: %s", filename)
                elif response.http_response.headers.get("content-type", "").endswith("octet-stream"):
                    _LOGGER.debug("Body contains binary data.")
                elif response.http_response.headers.get("content-type", "").startswith("image"):
                    _LOGGER.debug("Body contains image data.")
                else:
                    if response.context.options.get('stream', False):
                        _LOGGER.debug("Body is streamable")
                    else:
                        _LOGGER.debug(response.http_response.text())
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log response: %s", repr(err))
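

# Illustrative sketch only (not part of the original module): the policy above
# writes to this module's DEBUG logger, so its output only appears when a
# DEBUG-level handler is attached and "logging_enable" is set, either globally
# or per operation. The helper name and the commented client call are hypothetical.
def _example_enable_http_logging():
    import sys

    _LOGGER.addHandler(logging.StreamHandler(stream=sys.stdout))
    _LOGGER.setLevel(logging.DEBUG)
    # A per-operation override would then look like:
    #   client.some_operation(logging_enable=True)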


class StorageRequestHook(SansIOHTTPPolicy):

    def __init__(self, **kwargs):  # pylint: disable=unused-argument
        self._request_callback = kwargs.get('raw_request_hook')
        super(StorageRequestHook, self).__init__()

    def on_request(self, request):
        # type: (PipelineRequest) -> None
        request_callback = request.context.options.pop('raw_request_hook', self._request_callback)
        if request_callback:
            request_callback(request)


class StorageResponseHook(HTTPPolicy):

    def __init__(self, **kwargs):  # pylint: disable=unused-argument
        self._response_callback = kwargs.get('raw_response_hook')
        super(StorageResponseHook, self).__init__()

    def send(self, request):
        # type: (PipelineRequest) -> PipelineResponse
        data_stream_total = request.context.get('data_stream_total') or \
            request.context.options.pop('data_stream_total', None)
        download_stream_current = request.context.get('download_stream_current') or \
            request.context.options.pop('download_stream_current', None)
        upload_stream_current = request.context.get('upload_stream_current') or \
            request.context.options.pop('upload_stream_current', None)
        response_callback = request.context.get('response_callback') or \
            request.context.options.pop('raw_response_hook', self._response_callback)

        response = self.next.send(request)
        will_retry = is_retry(response, request.context.options.get('mode'))
        if not will_retry and download_stream_current is not None:
            download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
            if data_stream_total is None:
                content_range = response.http_response.headers.get('Content-Range')
                if content_range:
                    data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
                else:
                    data_stream_total = download_stream_current
        elif not will_retry and upload_stream_current is not None:
            upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
        for pipeline_obj in [request, response]:
            pipeline_obj.context['data_stream_total'] = data_stream_total
            pipeline_obj.context['download_stream_current'] = download_stream_current
            pipeline_obj.context['upload_stream_current'] = upload_stream_current
        if response_callback:
            response_callback(response)
            request.context['response_callback'] = response_callback
        return response
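

# Illustrative sketch only (not part of the original module): a callback of the
# shape StorageResponseHook invokes, reading the progress counters the policy
# stores on the pipeline context. Supplying it via the 'raw_response_hook'
# option is how the policy above picks it up; the function name is hypothetical.
def _example_progress_callback(response):
    current = response.context.get('upload_stream_current') or \
        response.context.get('download_stream_current')
    total = response.context.get('data_stream_total')
    if current is not None and total is not None:
        _LOGGER.debug("Transferred %s of %s bytes", current, total)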


class StorageContentValidation(SansIOHTTPPolicy):
    """A policy that provides transport-level validation of message content.

    When the 'validate_content' option is set, a Content-MD5 header computed from
    the request body is sent with non-GET requests (overwriting any Content-MD5
    header already defined on the request), and any content-md5 header returned
    by the service is verified against the received content.
    """
    header_name = 'Content-MD5'

    def __init__(self, **kwargs):  # pylint: disable=unused-argument
        super(StorageContentValidation, self).__init__()

    @staticmethod
    def get_content_md5(data):
        md5 = hashlib.md5() # nosec
        if isinstance(data, bytes):
            md5.update(data)
        elif hasattr(data, 'read'):
            pos = 0
            try:
                pos = data.tell()
            except:  # pylint: disable=bare-except
                pass
            for chunk in iter(lambda: data.read(4096), b""):
                md5.update(chunk)
            try:
                data.seek(pos, SEEK_SET)
            except (AttributeError, IOError):
                raise ValueError("Data should be bytes or a seekable file-like object.")
        else:
            raise ValueError("Data should be bytes or a seekable file-like object.")

        return md5.digest()

    def on_request(self, request):
        # type: (PipelineRequest, Any) -> None
        validate_content = request.context.options.pop('validate_content', False)
        if validate_content and request.http_request.method != 'GET':
            computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data))
            request.http_request.headers[self.header_name] = computed_md5
            request.context['validate_content_md5'] = computed_md5
        request.context['validate_content'] = validate_content

    def on_response(self, request, response):
        if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
            computed_md5 = request.context.get('validate_content_md5') or \
                encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
            if response.http_response.headers['content-md5'] != computed_md5:
                raise AzureError(
                    'MD5 mismatch. Expected value is \'{0}\', computed value is \'{1}\'.'.format(
                        response.http_response.headers['content-md5'], computed_md5),
                    response=response.http_response
                )
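

# Illustrative sketch only (not part of the original module): get_content_md5
# accepts bytes or a seekable stream and restores the stream position, so the
# same body can still be uploaded afterwards. The helper name is hypothetical.
def _example_content_md5():
    from io import BytesIO

    payload = BytesIO(b"hello world")
    digest = encode_base64(StorageContentValidation.get_content_md5(payload))
    assert payload.tell() == 0  # the stream was rewound to its original position
    return digest  # base64-encoded MD5, as sent in the Content-MD5 header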


class StorageRetryPolicy(HTTPPolicy):
    """
    The base class for Exponential and Linear retries containing shared code.
    """

    def __init__(self, **kwargs):
        self.total_retries = kwargs.pop('retry_total', 10)
        self.connect_retries = kwargs.pop('retry_connect', 3)
        self.read_retries = kwargs.pop('retry_read', 3)
        self.status_retries = kwargs.pop('retry_status', 3)
        self.retry_to_secondary = kwargs.pop('retry_to_secondary', False)
        super(StorageRetryPolicy, self).__init__()

    def _set_next_host_location(self, settings, request):  # pylint: disable=no-self-use
        """
        A function which sets the next host location on the request, if applicable.

        :param dict settings:
            The configured retry settings, including the current location mode
            and the available host locations.
        :param request:
            The request to evaluate and possibly modify.
        """
        if settings['hosts'] and all(settings['hosts'].values()):
            url = urlparse(request.url)
            # If there's more than one possible location, retry to the alternative
            if settings['mode'] == LocationMode.PRIMARY:
                settings['mode'] = LocationMode.SECONDARY
            else:
                settings['mode'] = LocationMode.PRIMARY
            updated = url._replace(netloc=settings['hosts'].get(settings['mode']))
            request.url = updated.geturl()

    def configure_retries(self, request):  # pylint: disable=no-self-use
        body_position = None
        if hasattr(request.http_request.body, 'read'):
            try:
                body_position = request.http_request.body.tell()
            except (AttributeError, UnsupportedOperation):
                # if the body position cannot be obtained, then retries will not work
                pass
        options = request.context.options
        return {
            'total': options.pop("retry_total", self.total_retries),
            'connect': options.pop("retry_connect", self.connect_retries),
            'read': options.pop("retry_read", self.read_retries),
            'status': options.pop("retry_status", self.status_retries),
            'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary),
            'mode': options.pop("location_mode", LocationMode.PRIMARY),
            'hosts': options.pop("hosts", None),
            'hook': options.pop("retry_hook", None),
            'body_position': body_position,
            'count': 0,
            'history': []
        }

    def get_backoff_time(self, settings):  # pylint: disable=unused-argument,no-self-use
        """ Formula for computing the current backoff.
        Should be overridden by child classes.

        :rtype: float
        """
        return 0

    def sleep(self, settings, transport):
        backoff = self.get_backoff_time(settings)
        if not backoff or backoff < 0:
            return
        transport.sleep(backoff)

    def increment(self, settings, request, response=None, error=None):
        """Increment the retry counters.

        :param settings: The retry settings produced by configure_retries.
        :param request: The outgoing HTTP request.
        :param response: A pipeline response object.
        :param error: An error encountered during the request, or
            None if the response was received successfully.

        :return: Whether another retry attempt should be made (False if the
            retries are exhausted or the request body cannot be rewound).
        """
        settings['total'] -= 1

        if error and isinstance(error, ServiceRequestError):
            # Errors when we're fairly sure that the server did not receive the
            # request, so it should be safe to retry.
            settings['connect'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        elif error and isinstance(error, ServiceResponseError):
            # Errors that occur after the request has been started, so we should
            # assume that the server began processing it.
            settings['read'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        else:
            # Incrementing because of a server error, such as a retryable
            # 500-level status code.
            if response:
                settings['status'] -= 1
                settings['history'].append(RequestHistory(request, http_response=response))

        if not is_exhausted(settings):
            if request.method not in ['PUT'] and settings['retry_secondary']:
                self._set_next_host_location(settings, request)

            # rewind the request body if it is a stream
            if request.body and hasattr(request.body, 'read'):
                # if no position was saved, the retry cannot rewind the body
                if settings['body_position'] is None:
                    return False
                try:
                    # attempt to rewind the body to the initial position
                    request.body.seek(settings['body_position'], SEEK_SET)
                except (UnsupportedOperation, ValueError):
                    # if the body is not seekable, the retry cannot proceed
                    return False
            settings['count'] += 1
            return True
        return False

    def send(self, request):
        retries_remaining = True
        response = None
        retry_settings = self.configure_retries(request)
        while retries_remaining:
            try:
                response = self.next.send(request)
                if is_retry(response, retry_settings['mode']):
                    retries_remaining = self.increment(
                        retry_settings,
                        request=request.http_request,
                        response=response.http_response)
                    if retries_remaining:
                        retry_hook(
                            retry_settings,
                            request=request.http_request,
                            response=response.http_response,
                            error=None)
                        self.sleep(retry_settings, request.context.transport)
                        continue
                break
            except AzureError as err:
                retries_remaining = self.increment(
                    retry_settings, request=request.http_request, error=err)
                if retries_remaining:
                    retry_hook(
                        retry_settings,
                        request=request.http_request,
                        response=None,
                        error=err)
                    self.sleep(retry_settings, request.context.transport)
                    continue
                raise err
        if retry_settings['history']:
            response.context['history'] = retry_settings['history']
        response.http_response.location_mode = retry_settings['mode']
        return response


class ExponentialRetry(StorageRetryPolicy):
    """Exponential retry."""

    def __init__(self, initial_backoff=15, increment_base=3, retry_total=3,
                 retry_to_secondary=False, random_jitter_range=3, **kwargs):
        '''
        Constructs an Exponential retry object. The initial_backoff is used for
        the first retry. Subsequent retries are retried after initial_backoff +
        increment_base^retry_count seconds. For example, by default the first retry
        occurs after 15 seconds, the second after (15+3^1) = 18 seconds, and the
        third after (15+3^2) = 24 seconds.

        :param int initial_backoff:
            The initial backoff interval, in seconds, for the first retry.
        :param int increment_base:
            The base, in seconds, to increment the initial_backoff by after the
            first retry.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x varying between x-3 and x+3.
        '''
        self.initial_backoff = initial_backoff
        self.increment_base = increment_base
        self.random_jitter_range = random_jitter_range
        super(ExponentialRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings):
        """
        Calculates how long to sleep before retrying.

        :return:
            A float indicating how long to wait, in seconds, before retrying
            the request.
        :rtype: float
        """
        random_generator = random.Random()
        backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
        random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
        random_range_end = backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)
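

# Illustrative sketch only (not part of the original module): shows how the
# backoff grows with the retry count when jitter is disabled, matching the
# 15, 18, 24 second example in the docstring above. The helper name is
# hypothetical.
def _example_exponential_backoff():
    retry = ExponentialRetry(initial_backoff=15, increment_base=3, random_jitter_range=0)
    # count 0 -> 15, count 1 -> 15 + 3**1 = 18, count 2 -> 15 + 3**2 = 24
    return [retry.get_backoff_time({'count': count}) for count in range(3)]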


class LinearRetry(StorageRetryPolicy):
    """Linear retry."""

    def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_jitter_range=3, **kwargs):
        """
        Constructs a Linear retry object.

        :param int backoff:
            The backoff interval, in seconds, between retries.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x varying between x-3 and x+3.
        """
        self.backoff = backoff
        self.random_jitter_range = random_jitter_range
        super(LinearRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings):
        """
        Calculates how long to sleep before retrying.

        :return:
            A float indicating how long to wait, in seconds, before retrying
            the request.
        :rtype: float
        """
        random_generator = random.Random()
        # the backoff interval normally does not change, however there is the possibility
        # that it was modified by accessing the property directly after initializing the object
        random_range_start = self.backoff - self.random_jitter_range \
            if self.backoff > self.random_jitter_range else 0
        random_range_end = self.backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)

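
# Illustrative sketch only (not part of the original module): unlike
# ExponentialRetry, the wait stays flat; with backoff=15 and random_jitter_range=3
# every retry sleeps somewhere between 12 and 18 seconds. The helper name is
# hypothetical.
def _example_linear_backoff():
    retry = LinearRetry(backoff=15, random_jitter_range=3)
    return [retry.get_backoff_time({}) for _ in range(3)]  # three samples from the 12-18s range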