# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import base64
import hashlib
import re
import random
from time import time
from io import SEEK_SET, UnsupportedOperation
import logging
import uuid
import types
from typing import Any, TYPE_CHECKING
from wsgiref.handlers import format_date_time
try:
    from urllib.parse import (
        urlparse,
        parse_qsl,
        urlunparse,
        urlencode,
    )
except ImportError:
    from urllib import urlencode  # type: ignore
    from urlparse import (  # type: ignore
        urlparse,
        parse_qsl,
        urlunparse,
    )

from azure.core.pipeline.policies import (
    HeadersPolicy,
    SansIOHTTPPolicy,
    NetworkTraceLoggingPolicy,
    HTTPPolicy,
    RequestHistory
)
from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError

from .models import LocationMode

try:
    _unicode_type = unicode  # type: ignore  # Python 2 compatibility
except NameError:
    _unicode_type = str

if TYPE_CHECKING:
    from azure.core.pipeline import PipelineRequest, PipelineResponse


_LOGGER = logging.getLogger(__name__)


def encode_base64(data):
    """Base64-encode *data* and return the result as a str.

    Unicode input is UTF-8 encoded first; bytes are encoded as-is.
    """
    if isinstance(data, _unicode_type):
        data = data.encode('utf-8')
    encoded = base64.b64encode(data)
    return encoded.decode('utf-8')


def is_exhausted(settings):
    """Are we out of retries?

    A counter of None means that retry type is not being tracked; only
    configured counters participate. Exhausted when any counter drops
    below zero.
    """
    retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status'])
    retry_counts = list(filter(None, retry_counts))
    if not retry_counts:
        return False
    return min(retry_counts) < 0


def retry_hook(settings, **kwargs):
    """Invoke the user-supplied retry callback, if one was configured."""
    if settings['hook']:
        settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs)


def is_retry(response, mode):
    """Is this method/status code retryable? (Based on whitelists and control
    variables such as the number of total retries to allow, whether to
    respect the Retry-After header, whether this header is present, and
    whether the returned status code is on the list of status codes to
    be retried upon on the presence of the aforementioned header)
    """
    status = response.http_response.status_code
    if 300 <= status < 500:
        # An exception occurred, but in most cases it was expected. Examples could
        # include a 409 Conflict or 412 Precondition Failed.
        if status == 404 and mode == LocationMode.SECONDARY:
            # Response code 404 should be retried if secondary was used.
            return True
        if status == 408:
            # Response code 408 is a timeout and should be retried.
            return True
        return False
    if status >= 500:
        # Response codes above 500 with the exception of 501 Not Implemented and
        # 505 Version Not Supported indicate a server issue and should be retried.
        if status in [501, 505]:
            return False
        return True
    return False


def urljoin(base_url, stub_url):
    """Append *stub_url* as an extra path segment onto *base_url*, preserving
    scheme, netloc, query and fragment."""
    parsed = urlparse(base_url)
    parsed = parsed._replace(path=parsed.path + '/' + stub_url)
    return parsed.geturl()


class QueueMessagePolicy(SansIOHTTPPolicy):
    """Appends a queue message ID (when supplied via the 'queue_message_id'
    request option) to the request URL path."""

    def on_request(self, request):
        message_id = request.context.options.pop('queue_message_id', None)
        if message_id:
            request.http_request.url = urljoin(
                request.http_request.url,
                message_id)


class StorageHeadersPolicy(HeadersPolicy):
    """Adds the storage-specific date and client request ID headers on top of
    the standard HeadersPolicy behavior."""

    request_id_header_name = 'x-ms-client-request-id'

    def on_request(self, request):
        # type: (PipelineRequest, Any) -> None
        super(StorageHeadersPolicy, self).on_request(request)
        current_time = format_date_time(time())
        request.http_request.headers['x-ms-date'] = current_time

        # Honor a caller-supplied client request ID; otherwise generate one.
        custom_id = request.context.options.pop('client_request_id', None)
        request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1())

    # def on_response(self, request, response):
    #     # raise exception if the echoed client request id from the service is not identical to the one we sent
    #     if self.request_id_header_name in response.http_response.headers:

    #         client_request_id = request.http_request.headers.get(self.request_id_header_name)

    #         if response.http_response.headers[self.request_id_header_name] != client_request_id:
    #             raise AzureError(
    #                 "Echoed client request ID: {} does not match sent client request ID: {}. "
    #                 "Service request ID: {}".format(
    #                     response.http_response.headers[self.request_id_header_name], client_request_id,
    #                     response.http_response.headers['x-ms-request-id']),
    #                 response=response.http_response
    #             )


class StorageHosts(SansIOHTTPPolicy):
    """Tracks the primary/secondary host mapping for the account and redirects
    a request to a specific location when 'use_location' is given."""

    def __init__(self, hosts=None, **kwargs):  # pylint: disable=unused-argument
        self.hosts = hosts
        super(StorageHosts, self).__init__()

    def on_request(self, request):
        # type: (PipelineRequest, Any) -> None
        request.context.options['hosts'] = self.hosts
        parsed_url = urlparse(request.http_request.url)

        # Detect what location mode we're currently requesting with
        location_mode = LocationMode.PRIMARY
        for key, value in self.hosts.items():
            if parsed_url.netloc == value:
                location_mode = key

        # See if a specific location mode has been specified, and if so, redirect
        use_location = request.context.options.pop('use_location', None)
        if use_location:
            # Lock retries to the specific location
            request.context.options['retry_to_secondary'] = False
            if use_location not in self.hosts:
                raise ValueError("Attempting to use undefined host location {}".format(use_location))
            if use_location != location_mode:
                # Update request URL to use the specified location
                updated = parsed_url._replace(netloc=self.hosts[use_location])
                request.http_request.url = updated.geturl()
                location_mode = use_location

        request.context.options['location_mode'] = location_mode


class StorageLoggingPolicy(NetworkTraceLoggingPolicy):
    """A policy that logs HTTP request and response to the DEBUG logger.

    This accepts both global configuration, and per-request level with "enable_http_logger"
    """

    def on_request(self, request):
        # type: (PipelineRequest, Any) -> None
        http_request = request.http_request
        options = request.context.options
        if options.pop("logging_enable", self.enable_http_logger):
            request.context["logging_enable"] = True
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                log_url = http_request.url
                query_params = http_request.query
                if 'sig' in query_params:
                    # Scrub the SAS signature value out of the logged URL.
                    # (Replacing the value only - replacing it with "sig=*****"
                    # would log the redundant "sig=sig=*****".)
                    log_url = log_url.replace(query_params['sig'], "*****")
                _LOGGER.debug("Request URL: %r", log_url)
                _LOGGER.debug("Request method: %r", http_request.method)
                _LOGGER.debug("Request headers:")
                for header, value in http_request.headers.items():
                    if header.lower() == 'authorization':
                        value = '*****'
                    elif header.lower() == 'x-ms-copy-source' and 'sig' in value:
                        # take the url apart and scrub away the signed signature
                        scheme, netloc, path, params, query, fragment = urlparse(value)
                        parsed_qs = dict(parse_qsl(query))
                        parsed_qs['sig'] = '*****'

                        # the SAS needs to be put back together
                        value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment))

                    _LOGGER.debug("    %r: %r", header, value)
                _LOGGER.debug("Request body:")

                # We don't want to log the binary data of a file upload.
                if isinstance(http_request.body, types.GeneratorType):
                    _LOGGER.debug("File upload")
                else:
                    _LOGGER.debug(str(http_request.body))
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log request: %r", err)

    def on_response(self, request, response):
        # type: (PipelineRequest, PipelineResponse, Any) -> None
        if response.context.pop("logging_enable", self.enable_http_logger):
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                _LOGGER.debug("Response status: %r", response.http_response.status_code)
                _LOGGER.debug("Response headers:")
                for res_header, value in response.http_response.headers.items():
                    _LOGGER.debug("    %r: %r", res_header, value)

                # We don't want to log binary data if the response is a file.
                _LOGGER.debug("Response content:")
                pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE)
                header = response.http_response.headers.get('content-disposition')

                if header and pattern.match(header):
                    filename = header.partition('=')[2]
                    _LOGGER.debug("File attachments: %s", filename)
                elif response.http_response.headers.get("content-type", "").endswith("octet-stream"):
                    _LOGGER.debug("Body contains binary data.")
                elif response.http_response.headers.get("content-type", "").startswith("image"):
                    _LOGGER.debug("Body contains image data.")
                else:
                    if response.context.options.get('stream', False):
                        _LOGGER.debug("Body is streamable")
                    else:
                        _LOGGER.debug(response.http_response.text())
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log response: %s", repr(err))


class StorageRequestHook(SansIOHTTPPolicy):
    """Invokes an optional user callback with the raw request before it is sent.

    The callback may be set globally ('raw_request_hook' kwarg) or per-request.
    """

    def __init__(self, **kwargs):  # pylint: disable=unused-argument
        self._request_callback = kwargs.get('raw_request_hook')
        super(StorageRequestHook, self).__init__()

    def on_request(self, request):
        # type: (PipelineRequest, **Any) -> PipelineResponse
        request_callback = request.context.options.pop('raw_request_hook', self._request_callback)
        if request_callback:
            request_callback(request)


class StorageResponseHook(HTTPPolicy):
    """Invokes an optional user callback with the raw response, and tracks
    upload/download progress counters across retries."""

    def __init__(self, **kwargs):  # pylint: disable=unused-argument
        self._response_callback = kwargs.get('raw_response_hook')
        super(StorageResponseHook, self).__init__()

    def send(self, request):
        # type: (PipelineRequest) -> PipelineResponse
        # Progress state may come from an earlier pipeline pass (context) or
        # from the per-request options on the first pass.
        data_stream_total = request.context.get('data_stream_total') or \
            request.context.options.pop('data_stream_total', None)
        download_stream_current = request.context.get('download_stream_current') or \
            request.context.options.pop('download_stream_current', None)
        upload_stream_current = request.context.get('upload_stream_current') or \
            request.context.options.pop('upload_stream_current', None)
        response_callback = request.context.get('response_callback') or \
            request.context.options.pop('raw_response_hook', self._response_callback)

        response = self.next.send(request)
        # Only advance progress counters for responses that will not be retried.
        will_retry = is_retry(response, request.context.options.get('mode'))
        if not will_retry and download_stream_current is not None:
            download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
            if data_stream_total is None:
                content_range = response.http_response.headers.get('Content-Range')
                if content_range:
                    # Content-Range: "bytes start-end/total" - take the total.
                    data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
                else:
                    data_stream_total = download_stream_current
        elif not will_retry and upload_stream_current is not None:
            upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
        for pipeline_obj in [request, response]:
            pipeline_obj.context['data_stream_total'] = data_stream_total
            pipeline_obj.context['download_stream_current'] = download_stream_current
            pipeline_obj.context['upload_stream_current'] = upload_stream_current
        if response_callback:
            response_callback(response)
            request.context['response_callback'] = response_callback
        return response


class StorageContentValidation(SansIOHTTPPolicy):
    """Computes and validates Content-MD5 for transactional content integrity.

    When the 'validate_content' request option is set, the MD5 hash of the
    request body is sent in the Content-MD5 header, and the MD5 echoed on the
    response is checked against the locally computed value.
    """
    header_name = 'Content-MD5'

    def __init__(self, **kwargs):  # pylint: disable=unused-argument
        super(StorageContentValidation, self).__init__()

    @staticmethod
    def get_content_md5(data):
        """Return the MD5 digest of *data* (bytes or a seekable file-like
        object). A stream is read in 4KB chunks and rewound to its original
        position afterwards.

        :raises ValueError: if the data is neither bytes nor seekable.
        """
        # MD5 here is a transport integrity check, not a security control.
        md5 = hashlib.md5()  # nosec
        if isinstance(data, bytes):
            md5.update(data)
        elif hasattr(data, 'read'):
            pos = 0
            try:
                pos = data.tell()
            except:  # pylint: disable=bare-except
                pass
            for chunk in iter(lambda: data.read(4096), b""):
                md5.update(chunk)
            try:
                data.seek(pos, SEEK_SET)
            except (AttributeError, IOError):
                raise ValueError("Data should be bytes or a seekable file-like object.")
        else:
            raise ValueError("Data should be bytes or a seekable file-like object.")

        return md5.digest()

    def on_request(self, request):
        # type: (PipelineRequest, Any) -> None
        validate_content = request.context.options.pop('validate_content', False)
        if validate_content and request.http_request.method != 'GET':
            computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data))
            request.http_request.headers[self.header_name] = computed_md5
            request.context['validate_content_md5'] = computed_md5
        request.context['validate_content'] = validate_content

    def on_response(self, request, response):
        if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
            computed_md5 = request.context.get('validate_content_md5') or \
                encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
            if response.http_response.headers['content-md5'] != computed_md5:
                raise AzureError(
                    'MD5 mismatch. Expected value is \'{0}\', computed value is \'{1}\'.'.format(
                        response.http_response.headers['content-md5'], computed_md5),
                    response=response.http_response
                )


class StorageRetryPolicy(HTTPPolicy):
    """
    The base class for Exponential and Linear retries containing shared code.
    """

    def __init__(self, **kwargs):
        self.total_retries = kwargs.pop('retry_total', 10)
        self.connect_retries = kwargs.pop('retry_connect', 3)
        self.read_retries = kwargs.pop('retry_read', 3)
        self.status_retries = kwargs.pop('retry_status', 3)
        self.retry_to_secondary = kwargs.pop('retry_to_secondary', False)
        super(StorageRetryPolicy, self).__init__()

    def _set_next_host_location(self, settings, request):  # pylint: disable=no-self-use
        """
        A function which sets the next host location on the request, if applicable.

        :param ~azure.storage.models.RetryContext context:
            The retry context containing the previous host location and the request
            to evaluate and possibly modify.
        """
        # Only alternate when both primary and secondary hosts are known.
        if settings['hosts'] and all(settings['hosts'].values()):
            url = urlparse(request.url)
            # If there's more than one possible location, retry to the alternative
            if settings['mode'] == LocationMode.PRIMARY:
                settings['mode'] = LocationMode.SECONDARY
            else:
                settings['mode'] = LocationMode.PRIMARY
            updated = url._replace(netloc=settings['hosts'].get(settings['mode']))
            request.url = updated.geturl()

    def configure_retries(self, request):  # pylint: disable=no-self-use
        """Build the per-request retry settings dict, recording the stream
        body position (if any) so the body can be rewound before a retry."""
        body_position = None
        if hasattr(request.http_request.body, 'read'):
            try:
                body_position = request.http_request.body.tell()
            except (AttributeError, UnsupportedOperation):
                # if body position cannot be obtained, then retries will not work
                pass
        options = request.context.options
        return {
            'total': options.pop("retry_total", self.total_retries),
            'connect': options.pop("retry_connect", self.connect_retries),
            'read': options.pop("retry_read", self.read_retries),
            'status': options.pop("retry_status", self.status_retries),
            'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary),
            'mode': options.pop("location_mode", LocationMode.PRIMARY),
            'hosts': options.pop("hosts", None),
            'hook': options.pop("retry_hook", None),
            'body_position': body_position,
            'count': 0,
            'history': []
        }

    def get_backoff_time(self, settings):  # pylint: disable=unused-argument,no-self-use
        """ Formula for computing the current backoff.
        Should be calculated by child class.

        :rtype: float
        """
        return 0

    def sleep(self, settings, transport):
        """Sleep on the transport for the computed backoff, if positive."""
        backoff = self.get_backoff_time(settings)
        if not backoff or backoff < 0:
            return
        transport.sleep(backoff)

    def increment(self, settings, request, response=None, error=None):
        """Increment the retry counters.

        :param response: A pipeline response object.
        :param error: An error encountered during the request, or
            None if the response was received successfully.

        :return: True if another retry attempt should be made, False if
            retries are exhausted or the request body cannot be rewound.
        """
        settings['total'] -= 1

        if error and isinstance(error, ServiceRequestError):
            # Errors when we're fairly sure that the server did not receive the
            # request, so it should be safe to retry.
            settings['connect'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        elif error and isinstance(error, ServiceResponseError):
            # Errors that occur after the request has been started, so we should
            # assume that the server began processing it.
            settings['read'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        else:
            # Incrementing because of a server error like a 500 in
            # status_forcelist and the given method is in the whitelist
            if response:
                settings['status'] -= 1
                settings['history'].append(RequestHistory(request, http_response=response))

        if not is_exhausted(settings):
            # PUT requests are never redirected to the secondary (writes only
            # go to primary); other methods may fail over when allowed.
            if request.method not in ['PUT'] and settings['retry_secondary']:
                self._set_next_host_location(settings, request)

            # rewind the request body if it is a stream
            if request.body and hasattr(request.body, 'read'):
                # no position was saved, then retry would not work
                if settings['body_position'] is None:
                    return False
                try:
                    # attempt to rewind the body to the initial position
                    request.body.seek(settings['body_position'], SEEK_SET)
                except (UnsupportedOperation, ValueError):
                    # if body is not seekable, then retry would not work
                    return False
            settings['count'] += 1
            return True
        return False

    def send(self, request):
        """Send the request, retrying on retryable status codes and on
        AzureError exceptions until the retry budget is exhausted."""
        retries_remaining = True
        response = None
        retry_settings = self.configure_retries(request)
        while retries_remaining:
            try:
                response = self.next.send(request)
                if is_retry(response, retry_settings['mode']):
                    retries_remaining = self.increment(
                        retry_settings,
                        request=request.http_request,
                        response=response.http_response)
                    if retries_remaining:
                        retry_hook(
                            retry_settings,
                            request=request.http_request,
                            response=response.http_response,
                            error=None)
                        self.sleep(retry_settings, request.context.transport)
                        continue
                break
            except AzureError as err:
                retries_remaining = self.increment(
                    retry_settings, request=request.http_request, error=err)
                if retries_remaining:
                    retry_hook(
                        retry_settings,
                        request=request.http_request,
                        response=None,
                        error=err)
                    self.sleep(retry_settings, request.context.transport)
                    continue
                raise err
        if retry_settings['history']:
            response.context['history'] = retry_settings['history']
        response.http_response.location_mode = retry_settings['mode']
        return response


class ExponentialRetry(StorageRetryPolicy):
    """Exponential retry."""

    def __init__(self, initial_backoff=15, increment_base=3, retry_total=3,
                 retry_to_secondary=False, random_jitter_range=3, **kwargs):
        '''
        Constructs an Exponential retry object. The initial_backoff is used for
        the first retry. Subsequent retries are retried after initial_backoff +
        increment_power^retry_count seconds. For example, by default the first retry
        occurs after 15 seconds, the second after (15+3^1) = 18 seconds, and the
        third after (15+3^2) = 24 seconds.

        :param int initial_backoff:
            The initial backoff interval, in seconds, for the first retry.
        :param int increment_base:
            The base, in seconds, to increment the initial_backoff by after the
            first retry.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
        '''
        self.initial_backoff = initial_backoff
        self.increment_base = increment_base
        self.random_jitter_range = random_jitter_range
        super(ExponentialRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings):
        """
        Calculates how long to sleep before retrying.

        :return:
            An integer indicating how long to wait before retrying the request,
            or None to indicate no retry should be performed.
        :rtype: int or None
        """
        random_generator = random.Random()
        backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
        random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
        random_range_end = backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)


class LinearRetry(StorageRetryPolicy):
    """Linear retry."""

    def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_jitter_range=3, **kwargs):
        """
        Constructs a Linear retry object.

        :param int backoff:
            The backoff interval, in seconds, between retries.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
        """
        self.backoff = backoff
        self.random_jitter_range = random_jitter_range
        super(LinearRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings):
        """
        Calculates how long to sleep before retrying.

        :return:
            An integer indicating how long to wait before retrying the request,
            or None to indicate no retry should be performed.
        :rtype: int or None
        """
        random_generator = random.Random()
        # the backoff interval normally does not change, however there is the possibility
        # that it was modified by accessing the property directly after initializing the object
        random_range_start = self.backoff - self.random_jitter_range \
            if self.backoff > self.random_jitter_range else 0
        random_range_end = self.backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)