# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import sys
import threading
import warnings
from io import BytesIO

from azure.core.exceptions import HttpResponseError
from azure.core.tracing.common import with_current_context
from ._shared.encryption import decrypt_blob
from ._shared.request_handlers import validate_and_format_range_headers
from ._shared.response_handlers import process_storage_error, parse_length_from_content_range
from ._deserialize import get_page_ranges_result


def process_range_and_offset(start_range, end_range, length, encryption):
    start_offset, end_offset = 0, 0
    if encryption.get("key") is not None or encryption.get("resolver") is not None:
        if start_range is not None:
            # Align the start of the range along a 16 byte block
            start_offset = start_range % 16
            start_range -= start_offset

            # Include an extra 16 bytes for the IV if necessary
            # Because of the previous offsetting, start_range will always
            # be a multiple of 16.
            if start_range > 0:
                start_offset += 16
                start_range -= 16

        if length is not None:
            # Align the end of the range along a 16 byte block
            end_offset = 15 - (end_range % 16)
            end_range += end_offset

    return (start_range, end_range), (start_offset, end_offset)


def process_content(data, start_offset, end_offset, encryption):
    if data is None:
        raise ValueError("Response cannot be None.")
    try:
        content = b"".join(list(data))
    except Exception as error:
        raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error)
    if content and (encryption.get("key") is not None or encryption.get("resolver") is not None):
        try:
            return decrypt_blob(
                encryption.get("required"),
                encryption.get("key"),
                encryption.get("resolver"),
                content,
                start_offset,
                end_offset,
                data.response.headers,
            )
        except Exception as error:
            raise HttpResponseError(message="Decryption failed.", response=data.response, error=error)
    return content


class _ChunkDownloader(object):  # pylint: disable=too-many-instance-attributes
    def __init__(
        self,
        client=None,
        non_empty_ranges=None,
        total_size=None,
        chunk_size=None,
        current_progress=None,
        start_range=None,
        end_range=None,
        stream=None,
        parallel=None,
        validate_content=None,
        encryption_options=None,
        **kwargs
    ):
        self.client = client
        self.non_empty_ranges = non_empty_ranges

        # Information on the download range/chunk size
        self.chunk_size = chunk_size
        self.total_size = total_size
        self.start_index = start_range
        self.end_index = end_range

        # The destination that we will write to
        self.stream = stream
        self.stream_lock = threading.Lock() if parallel else None
        self.progress_lock = threading.Lock() if parallel else None

        # For a parallel download, the stream is always seekable, so we note down the current position
        # in order to seek to the right place when out-of-order chunks come in
        self.stream_start = stream.tell() if parallel else None

        # Download progress so far
        self.progress_total = current_progress

        # Encryption
        self.encryption_options = encryption_options

        # Parameters for each get operation
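        # When validate_content is set, each chunk GET asks the service for a
        # transactional MD5 (range_get_content_md5) so the chunk body can be
        # verified; the remaining kwargs are forwarded to every download call.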
        self.validate_content = validate_content
        self.request_options = kwargs

    def _calculate_range(self, chunk_start):
        if chunk_start + self.chunk_size > self.end_index:
            chunk_end = self.end_index
        else:
            chunk_end = chunk_start + self.chunk_size
        return chunk_start, chunk_end

    def get_chunk_offsets(self):
        index = self.start_index
        while index < self.end_index:
            yield index
            index += self.chunk_size

    def process_chunk(self, chunk_start):
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        chunk_data = self._download_chunk(chunk_start, chunk_end - 1)
        length = chunk_end - chunk_start
        if length > 0:
            self._write_to_stream(chunk_data, chunk_start)
            self._update_progress(length)

    def yield_chunk(self, chunk_start):
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        return self._download_chunk(chunk_start, chunk_end - 1)

    def _update_progress(self, length):
        if self.progress_lock:
            with self.progress_lock:  # pylint: disable=not-context-manager
                self.progress_total += length
        else:
            self.progress_total += length

    def _write_to_stream(self, chunk_data, chunk_start):
        if self.stream_lock:
            with self.stream_lock:  # pylint: disable=not-context-manager
                self.stream.seek(self.stream_start + (chunk_start - self.start_index))
                self.stream.write(chunk_data)
        else:
            self.stream.write(chunk_data)

    def _do_optimize(self, given_range_start, given_range_end):
        # If we have no page range list stored, then assume there's data everywhere for that page blob
        # or it's a block blob or append blob
        if self.non_empty_ranges is None:
            return False

        for source_range in self.non_empty_ranges:
            # Case 1: As the range list is sorted, reaching a source_range that starts after the
            # given range ends means no earlier source_range overlapped either, so the given
            # range holds no data and the download optimization can be applied.
            # given range:        |   |
            # source range:                 |   |
            if given_range_end < source_range['start']:  # pylint:disable=no-else-return
                return True
            # Case 2: the given range comes after source_range, continue checking.
            # given range:                  |   |
            # source range:       |   |
            elif source_range['end'] < given_range_start:
                pass
            # Case 3: source_range and the given range overlap somehow, no need to optimize.
            else:
                return False
        # Went through all source ranges, but nothing overlapped. Optimization will be applied.
        return True

    def _download_chunk(self, chunk_start, chunk_end):
        download_range, offset = process_range_and_offset(
            chunk_start, chunk_end, chunk_end, self.encryption_options
        )

        # No need to download the empty chunk from the server if there's no data in the chunk to be downloaded.
        # Do optimize and create an empty chunk locally if the condition is met.
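        # (For sparse page blobs, non_empty_ranges lists the regions that hold
        # data; chunks falling entirely outside those regions are synthesized
        # locally as zero bytes instead of being requested from the service.)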
        if self._do_optimize(download_range[0], download_range[1]):
            chunk_data = b"\x00" * self.chunk_size
        else:
            range_header, range_validation = validate_and_format_range_headers(
                download_range[0],
                download_range[1],
                check_content_md5=self.validate_content
            )

            try:
                _, response = self.client.download(
                    range=range_header,
                    range_get_content_md5=range_validation,
                    validate_content=self.validate_content,
                    data_stream_total=self.total_size,
                    download_stream_current=self.progress_total,
                    **self.request_options
                )
            except HttpResponseError as error:
                process_storage_error(error)

            chunk_data = process_content(response, offset[0], offset[1], self.encryption_options)

            # This makes sure that if_match is set so that we can validate
            # that subsequent downloads are to an unmodified blob
            if self.request_options.get("modified_access_conditions"):
                self.request_options["modified_access_conditions"].if_match = response.properties.etag

        return chunk_data


class _ChunkIterator(object):
    """Iterator over chunks in a blob download stream."""

    def __init__(self, size, content, downloader):
        self.size = size
        self._current_content = content
        self._iter_downloader = downloader
        self._iter_chunks = None
        self._complete = (size == 0)

    def __len__(self):
        return self.size

    def __iter__(self):
        return self

    def __next__(self):
        """Iterate through responses."""
        if self._complete:
            raise StopIteration("Download complete")
        if not self._iter_downloader:
            # If no iterator was supplied, the download completed with
            # the initial GET, so we just return that data
            self._complete = True
            return self._current_content

        if not self._iter_chunks:
            self._iter_chunks = self._iter_downloader.get_chunk_offsets()
        else:
            chunk = next(self._iter_chunks)
            self._current_content = self._iter_downloader.yield_chunk(chunk)

        return self._current_content

    next = __next__  # Python 2 compatibility.


class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
    """A streaming object to download from Azure Storage.

    :ivar str name:
        The name of the blob being downloaded.
    :ivar str container:
        The name of the container where the blob is.
    :ivar ~azure.storage.blob.BlobProperties properties:
        The properties of the blob being downloaded. If only a range of the data is being
        downloaded, this will be reflected in the properties.
    :ivar int size:
        The size of the total data in the stream. This will be the byte range if specified,
        otherwise the total size of the blob.
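
    A minimal usage sketch (illustrative only; it assumes a ``BlobClient`` named
    ``blob_client`` created elsewhere with ``azure.storage.blob``)::

        # download_blob() returns a StorageStreamDownloader
        downloader = blob_client.download_blob()
        data = downloader.readall()  # bytes, or str if an encoding was supplied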
    """

    def __init__(
        self,
        clients=None,
        config=None,
        start_range=None,
        end_range=None,
        validate_content=None,
        encryption_options=None,
        max_concurrency=1,
        name=None,
        container=None,
        encoding=None,
        **kwargs
    ):
        self.name = name
        self.container = container
        self.properties = None
        self.size = None

        self._clients = clients
        self._config = config
        self._start_range = start_range
        self._end_range = end_range
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        self._validate_content = validate_content
        self._encryption_options = encryption_options or {}
        self._request_options = kwargs
        self._location_mode = None
        self._download_complete = False
        self._current_content = None
        self._file_size = None
        self._non_empty_ranges = None
        self._response = None

        # The service only provides transactional MD5s for chunks under 4MB.
        # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first
        # chunk so a transactional MD5 can be retrieved.
        self._first_get_size = (
            self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size
        )
        initial_request_start = self._start_range if self._start_range is not None else 0
        if self._end_range is not None and self._end_range - self._start_range < self._first_get_size:
            initial_request_end = self._end_range
        else:
            initial_request_end = initial_request_start + self._first_get_size - 1

        self._initial_range, self._initial_offset = process_range_and_offset(
            initial_request_start, initial_request_end, self._end_range, self._encryption_options
        )

        self._response = self._initial_request()
        self.properties = self._response.properties
        self.properties.name = self.name
        self.properties.container = self.container

        # Set the content length to the download size instead of the size of
        # the last range
        self.properties.size = self.size

        # Overwrite the content range to the user requested range
        self.properties.content_range = "bytes {0}-{1}/{2}".format(
            self._start_range,
            self._end_range,
            self._file_size
        )

        # Overwrite the content MD5 as it is the MD5 for the last range instead
        # of the stored MD5
        # TODO: Set to the stored MD5 when the service returns this
        self.properties.content_md5 = None

        if self.size == 0:
            self._current_content = b""
        else:
            self._current_content = process_content(
                self._response,
                self._initial_offset[0],
                self._initial_offset[1],
                self._encryption_options
            )

    def __len__(self):
        return self.size

    def _initial_request(self):
        range_header, range_validation = validate_and_format_range_headers(
            self._initial_range[0],
            self._initial_range[1],
            start_range_required=False,
            end_range_required=False,
            check_content_md5=self._validate_content
        )

        try:
            location_mode, response = self._clients.blob.download(
                range=range_header,
                range_get_content_md5=range_validation,
                validate_content=self._validate_content,
                data_stream_total=None,
                download_stream_current=0,
                **self._request_options
            )

            # Check the location we read from to ensure we use the same one
            # for subsequent requests.
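            # (This value is later forwarded to the chunk downloader as
            # use_location, so follow-up chunk requests target the same endpoint.)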
            self._location_mode = location_mode

            # Parse the total file size and adjust the download size if ranges
            # were specified
            self._file_size = parse_length_from_content_range(response.properties.content_range)
            if self._end_range is not None:
                # Use the end range index unless it is over the end of the file
                self.size = min(self._file_size, self._end_range - self._start_range + 1)
            elif self._start_range is not None:
                self.size = self._file_size - self._start_range
            else:
                self.size = self._file_size

        except HttpResponseError as error:
            if self._start_range is None and error.response.status_code == 416:
                # Get range will fail on an empty file. If the user did not
                # request a range, do a regular get request in order to get
                # any properties.
                try:
                    _, response = self._clients.blob.download(
                        validate_content=self._validate_content,
                        data_stream_total=0,
                        download_stream_current=0,
                        **self._request_options
                    )
                except HttpResponseError as error:
                    process_storage_error(error)

                # Set the download size to empty
                self.size = 0
                self._file_size = 0
            else:
                process_storage_error(error)

        # Get page ranges to optimize downloading sparse page blobs
        if response.properties.blob_type == 'PageBlob':
            try:
                page_ranges = self._clients.page_blob.get_page_ranges()
                self._non_empty_ranges = get_page_ranges_result(page_ranges)[0]
            # According to the REST API documentation, a Get Page Ranges request on a
            # highly fragmented page blob with a large number of writes can fail due to
            # an internal server timeout. Thus, if the page blob is not sparse, it is
            # acceptable for this call to fail.
            except HttpResponseError:
                pass

        # If the file is small, the download is complete at this point.
        # If the file size is large, download the rest of the file in chunks.
        if response.properties.size != self.size:
            # Lock on the etag. This can be overridden by the user by specifying '*'
            if self._request_options.get("modified_access_conditions"):
                if not self._request_options["modified_access_conditions"].if_match:
                    self._request_options["modified_access_conditions"].if_match = response.properties.etag
        else:
            self._download_complete = True
        return response

    def chunks(self):
        if self.size == 0 or self._download_complete:
            iter_downloader = None
        else:
            data_end = self._file_size
            if self._end_range is not None:
                # Use the end range index unless it is over the end of the file
                data_end = min(self._file_size, self._end_range + 1)
            iter_downloader = _ChunkDownloader(
                client=self._clients.blob,
                non_empty_ranges=self._non_empty_ranges,
                total_size=self.size,
                chunk_size=self._config.max_chunk_get_size,
                current_progress=self._first_get_size,
                start_range=self._initial_range[1] + 1,  # Start where the first download ended
                end_range=data_end,
                stream=None,
                parallel=False,
                validate_content=self._validate_content,
                encryption_options=self._encryption_options,
                use_location=self._location_mode,
                **self._request_options
            )
        return _ChunkIterator(
            size=self.size,
            content=self._current_content,
            downloader=iter_downloader)

    def readall(self):
        """Download the contents of this blob.

        This operation is blocking until all data is downloaded.
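
        If an ``encoding`` was specified for the download, the downloaded bytes
        are decoded and returned as ``str``; otherwise raw ``bytes`` are returned.
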
        :rtype: bytes or str
        """
        stream = BytesIO()
        self.readinto(stream)
        data = stream.getvalue()
        if self._encoding:
            return data.decode(self._encoding)
        return data

    def content_as_bytes(self, max_concurrency=1):
        """Download the contents of this blob.

        This operation is blocking until all data is downloaded.

        :keyword int max_concurrency:
            The number of parallel connections with which to download.
        :rtype: bytes
        """
        warnings.warn(
            "content_as_bytes is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        return self.readall()

    def content_as_text(self, max_concurrency=1, encoding="UTF-8"):
        """Download the contents of this blob, and decode as text.

        This operation is blocking until all data is downloaded.

        :keyword int max_concurrency:
            The number of parallel connections with which to download.
        :param str encoding:
            Text encoding to decode the downloaded bytes. Default is UTF-8.
        :rtype: str
        """
        warnings.warn(
            "content_as_text is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        return self.readall()

    def readinto(self, stream):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :returns: The number of bytes read.
        :rtype: int
        """
        # The stream must be seekable if a parallel download is required
        parallel = self._max_concurrency > 1
        if parallel:
            error_message = "Target stream handle must be seekable."
            if sys.version_info >= (3,) and not stream.seekable():
                raise ValueError(error_message)

            try:
                stream.seek(stream.tell())
            except (NotImplementedError, AttributeError):
                raise ValueError(error_message)

        # Write the content to the user stream
        stream.write(self._current_content)
        if self._download_complete:
            return self.size

        data_end = self._file_size
        if self._end_range is not None:
            # Use the length unless it is over the end of the file
            data_end = min(self._file_size, self._end_range + 1)

        downloader = _ChunkDownloader(
            client=self._clients.blob,
            non_empty_ranges=self._non_empty_ranges,
            total_size=self.size,
            chunk_size=self._config.max_chunk_get_size,
            current_progress=self._first_get_size,
            start_range=self._initial_range[1] + 1,  # Start where the first download ended
            end_range=data_end,
            stream=stream,
            parallel=parallel,
            validate_content=self._validate_content,
            encryption_options=self._encryption_options,
            use_location=self._location_mode,
            **self._request_options
        )
        if parallel:
            import concurrent.futures
            executor = concurrent.futures.ThreadPoolExecutor(self._max_concurrency)
            list(executor.map(
                with_current_context(downloader.process_chunk),
                downloader.get_chunk_offsets()
            ))
        else:
            for chunk in downloader.get_chunk_offsets():
                downloader.process_chunk(chunk)
        return self.size

    def download_to_stream(self, stream, max_concurrency=1):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream.
            The stream must be seekable if the download uses more than one
            parallel connection.
        :returns: The properties of the downloaded blob.
        :rtype: Any
        """
        warnings.warn(
            "download_to_stream is deprecated, use readinto instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        self.readinto(stream)
        return self.properties
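

# Illustrative usage sketch (not part of this module's API; assumes a
# ``BlobClient`` named ``blob_client`` from ``azure.storage.blob``): download a
# large blob to a local file using parallel chunk downloads.
#
#     downloader = blob_client.download_blob(max_concurrency=4)
#     with open("large_blob.bin", "wb") as handle:
#         downloader.readinto(handle)  # chunks are fetched on 4 worker threads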