# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# pylint: disable=invalid-overridden-method

import asyncio
import sys
from io import BytesIO
from itertools import islice
import warnings

from azure.core.exceptions import HttpResponseError
from .._shared.encryption import decrypt_blob
from .._shared.request_handlers import validate_and_format_range_headers
from .._shared.response_handlers import process_storage_error, parse_length_from_content_range
from .._deserialize import get_page_ranges_result
from .._download import process_range_and_offset, _ChunkDownloader


async def process_content(data, start_offset, end_offset, encryption):
    """Extract the body bytes from a download response, decrypting them if
    client-side encryption is configured.

    :param data: The pipeline response object; its ``response.body()`` holds
        the downloaded bytes. Must not be None.
    :param int start_offset: Leading offset passed through to ``decrypt_blob``
        to trim the decrypted data back to the range the caller asked for.
    :param int end_offset: Trailing offset passed through to ``decrypt_blob``.
    :param dict encryption: Encryption options; keys ``'required'``, ``'key'``
        and ``'resolver'`` are read here. Decryption is attempted whenever a
        key or resolver is present.
    :returns: The raw or decrypted content bytes.
    :raises ValueError: If ``data`` is None.
    :raises HttpResponseError: If reading the body fails ("Download stream
        interrupted.") or decryption fails ("Decryption failed.").
    """
    # NOTE(review): declared async but contains no await; callers await it,
    # so this is harmless — confirm before changing.
    if data is None:
        raise ValueError("Response cannot be None.")
    try:
        content = data.response.body()
    except Exception as error:
        # Wrap any transport-level failure so callers only need to catch
        # HttpResponseError.
        raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error)
    if encryption.get('key') is not None or encryption.get('resolver') is not None:
        try:
            return decrypt_blob(
                encryption.get('required'),
                encryption.get('key'),
                encryption.get('resolver'),
                content,
                start_offset,
                end_offset,
                data.response.headers)
        except Exception as error:
            raise HttpResponseError(
                message="Decryption failed.",
                response=data.response,
                error=error)
    return content


class _AsyncChunkDownloader(_ChunkDownloader):
    """Asynchronous chunk downloader.

    Extends the shared ``_ChunkDownloader`` (which supplies range
    bookkeeping such as ``_calculate_range``, ``_do_optimize`` and
    ``get_chunk_offsets``) with coroutine-based chunk fetching and, when
    running in parallel, asyncio locks serializing stream writes and
    progress updates.
    """

    def __init__(self, **kwargs):
        super(_AsyncChunkDownloader, self).__init__(**kwargs)
        # Locks are only needed when multiple chunk coroutines run
        # concurrently; in the sequential case they stay None and the
        # lock-free code paths below are taken.
        self.stream_lock = asyncio.Lock() if kwargs.get('parallel') else None
        self.progress_lock = asyncio.Lock() if kwargs.get('parallel') else None

    async def process_chunk(self, chunk_start):
        """Download one chunk and write it into the target stream,
        updating the progress counter."""
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        # chunk_end is exclusive; the HTTP range header is inclusive.
        chunk_data = await self._download_chunk(chunk_start, chunk_end - 1)
        length = chunk_end - chunk_start
        if length > 0:
            await self._write_to_stream(chunk_data, chunk_start)
            await self._update_progress(length)

    async def yield_chunk(self, chunk_start):
        """Download one chunk and return its bytes instead of writing it
        to a stream (used by the chunk iterator)."""
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        return await self._download_chunk(chunk_start, chunk_end - 1)

    async def _update_progress(self, length):
        # Take the progress lock only when downloading in parallel.
        if self.progress_lock:
            async with self.progress_lock:  # pylint: disable=not-async-context-manager
                self.progress_total += length
        else:
            self.progress_total += length

    async def _write_to_stream(self, chunk_data, chunk_start):
        if self.stream_lock:
            # Parallel chunks can complete out of order, so seek to this
            # chunk's position before writing, under the stream lock.
            async with self.stream_lock:  # pylint: disable=not-async-context-manager
                self.stream.seek(self.stream_start + (chunk_start - self.start_index))
                self.stream.write(chunk_data)
        else:
            # Sequential download: chunks arrive in order, append directly.
            self.stream.write(chunk_data)

    async def _download_chunk(self, chunk_start, chunk_end):
        """Fetch the inclusive byte range [chunk_start, chunk_end] from the
        service (or synthesize zeros for an empty sparse-blob range) and
        return the processed content bytes."""
        download_range, offset = process_range_and_offset(
            chunk_start, chunk_end, chunk_end, self.encryption_options)

        # No need to download the empty chunk from server if there's no data in the chunk to be downloaded.
        # Do optimize and create empty chunk locally if condition is met.
        if self._do_optimize(download_range[0], download_range[1]):
            # Sparse page-blob optimization: the range holds no data, so
            # fabricate a zero-filled chunk without a network round trip.
            chunk_data = b"\x00" * self.chunk_size
        else:
            range_header, range_validation = validate_and_format_range_headers(
                download_range[0],
                download_range[1],
                check_content_md5=self.validate_content
            )
            try:
                _, response = await self.client.download(
                    range=range_header,
                    range_get_content_md5=range_validation,
                    validate_content=self.validate_content,
                    data_stream_total=self.total_size,
                    download_stream_current=self.progress_total,
                    **self.request_options
                )
            except HttpResponseError as error:
                process_storage_error(error)

            chunk_data = await process_content(response, offset[0], offset[1], self.encryption_options)

            # This makes sure that if_match is set so that we can validate
            # that subsequent downloads are to an unmodified blob
            if self.request_options.get('modified_access_conditions'):
                self.request_options['modified_access_conditions'].if_match = response.properties.etag

        return chunk_data


class _AsyncChunkIterator(object):
    """Async iterator for chunks in blob download stream."""

    def __init__(self, size, content, downloader):
        # Total number of bytes the stream will yield.
        self.size = size
        # Content already fetched by the initial GET; yielded first.
        self._current_content = content
        # Downloader for the remaining chunks, or None if the initial GET
        # retrieved everything.
        self._iter_downloader = downloader
        # Lazily-created iterator of chunk start offsets.
        self._iter_chunks = None
        self._complete = (size == 0)

    def __len__(self):
        return self.size

    def __iter__(self):
        # This is an async-only stream; refuse synchronous iteration.
        raise TypeError("Async stream must be iterated asynchronously.")

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Iterate through responses."""
        if self._complete:
            raise StopAsyncIteration("Download complete")
        if not self._iter_downloader:
            # If no iterator was supplied, the download completed with
            # the initial GET, so we just return that data
            self._complete = True
            return self._current_content

        if not self._iter_chunks:
            # First call: prime the chunk-offset iterator and return the
            # content from the initial GET.
            self._iter_chunks = self._iter_downloader.get_chunk_offsets()
        else:
            try:
                chunk = next(self._iter_chunks)
            except StopIteration:
                raise StopAsyncIteration("Download complete")
            self._current_content = await self._iter_downloader.yield_chunk(chunk)

        return self._current_content


class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
    """A streaming object to download from Azure Storage.

    :ivar str name:
        The name of the blob being downloaded.
    :ivar str container:
        The name of the container where the blob is.
    :ivar ~azure.storage.blob.BlobProperties properties:
        The properties of the blob being downloaded. If only a range of the data is being
        downloaded, this will be reflected in the properties.
    :ivar int size:
        The size of the total data in the stream. This will be the byte range if specified,
        otherwise the total size of the blob.
    """

    def __init__(
        self,
        clients=None,
        config=None,
        start_range=None,
        end_range=None,
        validate_content=None,
        encryption_options=None,
        max_concurrency=1,
        name=None,
        container=None,
        encoding=None,
        **kwargs
    ):
        self.name = name
        self.container = container
        # Populated by _setup() from the initial GET response.
        self.properties = None
        self.size = None

        self._clients = clients
        self._config = config
        self._start_range = start_range
        self._end_range = end_range
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        self._validate_content = validate_content
        self._encryption_options = encryption_options or {}
        # Remaining keyword arguments are forwarded verbatim to every
        # service download call.
        self._request_options = kwargs
        self._location_mode = None
        self._download_complete = False
        self._current_content = None
        self._file_size = None
        self._non_empty_ranges = None
        self._response = None

        # The service only provides transactional MD5s for chunks under 4MB.
        # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first
        # chunk so a transactional MD5 can be retrieved.
        self._first_get_size = self._config.max_single_get_size if not self._validate_content \
            else self._config.max_chunk_get_size
        initial_request_start = self._start_range if self._start_range is not None else 0
        # NOTE(review): if end_range is set while start_range is None, the
        # subtraction below is None-minus-int — confirm callers always pass
        # start_range alongside end_range.
        if self._end_range is not None and self._end_range - self._start_range < self._first_get_size:
            initial_request_end = self._end_range
        else:
            initial_request_end = initial_request_start + self._first_get_size - 1

        # Adjust the request range for client-side encryption block
        # alignment and remember the trim offsets for the first chunk.
        self._initial_range, self._initial_offset = process_range_and_offset(
            initial_request_start, initial_request_end, self._end_range, self._encryption_options
        )

    def __len__(self):
        return self.size

    async def _setup(self):
        """Issue the initial GET and populate properties, size and the
        first chunk of content. Must be awaited before reading."""
        self._response = await self._initial_request()
        self.properties = self._response.properties
        self.properties.name = self.name
        self.properties.container = self.container

        # Set the content length to the download size instead of the size of
        # the last range
        self.properties.size = self.size

        # Overwrite the content range to the user requested range
        self.properties.content_range = 'bytes {0}-{1}/{2}'.format(
            self._start_range,
            self._end_range,
            self._file_size
        )

        # Overwrite the content MD5 as it is the MD5 for the last range instead
        # of the stored MD5
        # TODO: Set to the stored MD5 when the service returns this
        self.properties.content_md5 = None

        if self.size == 0:
            self._current_content = b""
        else:
            self._current_content = await process_content(
                self._response,
                self._initial_offset[0],
                self._initial_offset[1],
                self._encryption_options
            )

    async def _initial_request(self):
        """Perform the first ranged GET.

        Determines the total file size, computes the effective download
        size, handles the empty-blob 416 case, fetches sparse page ranges
        for page blobs, and locks subsequent requests to the blob's etag
        when more chunks remain.

        :returns: The initial download response.
        """
        range_header, range_validation = validate_and_format_range_headers(
            self._initial_range[0],
            self._initial_range[1],
            start_range_required=False,
            end_range_required=False,
            check_content_md5=self._validate_content)

        try:
            location_mode, response = await self._clients.blob.download(
                range=range_header,
                range_get_content_md5=range_validation,
                validate_content=self._validate_content,
                data_stream_total=None,
                download_stream_current=0,
                **self._request_options)

            # Check the location we read from to ensure we use the same one
            # for subsequent requests.
            self._location_mode = location_mode

            # Parse the total file size and adjust the download size if ranges
            # were specified
            self._file_size = parse_length_from_content_range(response.properties.content_range)
            if self._end_range is not None:
                # Use the length unless it is over the end of the file
                self.size = min(self._file_size, self._end_range - self._start_range + 1)
            elif self._start_range is not None:
                self.size = self._file_size - self._start_range
            else:
                self.size = self._file_size

        except HttpResponseError as error:
            if self._start_range is None and error.response.status_code == 416:
                # Get range will fail on an empty file. If the user did not
                # request a range, do a regular get request in order to get
                # any properties.
                # NOTE(review): self._location_mode is not updated on this
                # fallback path — confirm whether that is intentional.
                try:
                    _, response = await self._clients.blob.download(
                        validate_content=self._validate_content,
                        data_stream_total=0,
                        download_stream_current=0,
                        **self._request_options)
                except HttpResponseError as error:
                    process_storage_error(error)

                # Set the download size to empty
                self.size = 0
                self._file_size = 0
            else:
                process_storage_error(error)

        # get page ranges to optimize downloading sparse page blob
        if response.properties.blob_type == 'PageBlob':
            try:
                page_ranges = await self._clients.page_blob.get_page_ranges()
                self._non_empty_ranges = get_page_ranges_result(page_ranges)[0]
            except HttpResponseError:
                # Best effort only: without page ranges every chunk is
                # downloaded normally.
                pass

        # If the file is small, the download is complete at this point.
        # If file size is large, download the rest of the file in chunks.
        if response.properties.size != self.size:
            # Lock on the etag. This can be overridden by the user by specifying '*'
            if self._request_options.get('modified_access_conditions'):
                if not self._request_options['modified_access_conditions'].if_match:
                    self._request_options['modified_access_conditions'].if_match = response.properties.etag
        else:
            self._download_complete = True
        return response

    def chunks(self):
        """Iterate over chunks in the download stream.

        :rtype: Iterable[bytes]
        """
        if self.size == 0 or self._download_complete:
            # Everything arrived with the initial GET; no chunk downloader
            # is needed.
            iter_downloader = None
        else:
            data_end = self._file_size
            if self._end_range is not None:
                # Use the length unless it is over the end of the file
                data_end = min(self._file_size, self._end_range + 1)
            iter_downloader = _AsyncChunkDownloader(
                client=self._clients.blob,
                non_empty_ranges=self._non_empty_ranges,
                total_size=self.size,
                chunk_size=self._config.max_chunk_get_size,
                current_progress=self._first_get_size,
                start_range=self._initial_range[1] + 1,  # Start where the first download ended
                end_range=data_end,
                stream=None,
                parallel=False,
                validate_content=self._validate_content,
                encryption_options=self._encryption_options,
                use_location=self._location_mode,
                **self._request_options)
        return _AsyncChunkIterator(
            size=self.size,
            content=self._current_content,
            downloader=iter_downloader)

    async def readall(self):
        """Download the contents of this blob.

        This operation is blocking until all data is downloaded.
        :rtype: bytes or str
        """
        # Buffer the whole blob in memory, then decode if an encoding was
        # requested at construction time.
        stream = BytesIO()
        await self.readinto(stream)
        data = stream.getvalue()
        if self._encoding:
            return data.decode(self._encoding)
        return data

    async def content_as_bytes(self, max_concurrency=1):
        """Download the contents of this file.

        This operation is blocking until all data is downloaded.

        :keyword int max_concurrency:
            The number of parallel connections with which to download.
        :rtype: bytes

        .. deprecated:: use :func:`readall` instead.
        """
        warnings.warn(
            "content_as_bytes is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        return await self.readall()

    async def content_as_text(self, max_concurrency=1, encoding="UTF-8"):
        """Download the contents of this blob, and decode as text.

        This operation is blocking until all data is downloaded.

        :keyword int max_concurrency:
            The number of parallel connections with which to download.
        :param str encoding:
            Test encoding to decode the downloaded bytes. Default is UTF-8.
        :rtype: str

        .. deprecated:: use :func:`readall` instead.
        """
        warnings.warn(
            "content_as_text is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        return await self.readall()

    async def readinto(self, stream):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :returns: The number of bytes read.
        :rtype: int
        """
        # the stream must be seekable if parallel download is required
        parallel = self._max_concurrency > 1
        if parallel:
            error_message = "Target stream handle must be seekable."
            if sys.version_info >= (3,) and not stream.seekable():
                raise ValueError(error_message)

            try:
                # Python 2 fallback: streams have no seekable(); probe with
                # a no-op seek instead.
                stream.seek(stream.tell())
            except (NotImplementedError, AttributeError):
                raise ValueError(error_message)

        # Write the content to the user stream
        stream.write(self._current_content)
        if self._download_complete:
            return self.size

        data_end = self._file_size
        if self._end_range is not None:
            # Use the length unless it is over the end of the file
            data_end = min(self._file_size, self._end_range + 1)

        downloader = _AsyncChunkDownloader(
            client=self._clients.blob,
            non_empty_ranges=self._non_empty_ranges,
            total_size=self.size,
            chunk_size=self._config.max_chunk_get_size,
            current_progress=self._first_get_size,
            start_range=self._initial_range[1] + 1,  # start where the first download ended
            end_range=data_end,
            stream=stream,
            parallel=parallel,
            validate_content=self._validate_content,
            encryption_options=self._encryption_options,
            use_location=self._location_mode,
            **self._request_options)

        # Keep at most _max_concurrency chunk downloads in flight: start
        # the first batch, then top up one task per completion.
        dl_tasks = downloader.get_chunk_offsets()
        running_futures = [
            asyncio.ensure_future(downloader.process_chunk(d))
            for d in islice(dl_tasks, 0, self._max_concurrency)
        ]
        while running_futures:
            # Wait for some download to finish before adding a new one
            _done, running_futures = await asyncio.wait(
                running_futures, return_when=asyncio.FIRST_COMPLETED)
            try:
                next_chunk = next(dl_tasks)
            except StopIteration:
                break
            else:
                # asyncio.wait returned the pending futures as a set, so
                # add() is available here.
                running_futures.add(asyncio.ensure_future(downloader.process_chunk(next_chunk)))

        if running_futures:
            # Wait for the remaining downloads to finish
            await asyncio.wait(running_futures)
        return self.size

    async def download_to_stream(self, stream, max_concurrency=1):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :returns: The properties of the downloaded blob.
        :rtype: Any

        .. deprecated:: use :func:`readinto` instead.
        """
        warnings.warn(
            "download_to_stream is deprecated, use readinto instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        await self.readinto(stream)
        return self.properties