1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for
4# license information.
5# --------------------------------------------------------------------------
7import sys
8import threading
9import warnings
10from io import BytesIO
12from azure.core.exceptions import HttpResponseError
13from azure.core.tracing.common import with_current_context
14from ._shared.encryption import decrypt_blob
15from ._shared.request_handlers import validate_and_format_range_headers
16from ._shared.response_handlers import process_storage_error, parse_length_from_content_range
17from ._deserialize import get_page_ranges_result
def process_range_and_offset(start_range, end_range, length, encryption):
    """Widen a byte range so it can be decrypted by a 16-byte block cipher.

    When neither an encryption key nor a key resolver is configured, the range is
    returned untouched together with zero offsets.

    :param start_range: First byte of the requested range, or None.
    :param end_range: Last byte (inclusive) of the requested range.
    :param length: Any non-None value means the end of the range must be aligned too.
    :param dict encryption: Client-side encryption options ("key" / "resolver" are read).
    :returns: ``((start_range, end_range), (start_offset, end_offset))`` where the
        offsets record how many bytes were added before the start and after the end.
    :rtype: tuple
    """
    encrypted = any(encryption.get(option) is not None for option in ("key", "resolver"))
    if not encrypted:
        return (start_range, end_range), (0, 0)

    start_offset = 0
    if start_range is not None:
        # Snap the start down to a 16 byte boundary...
        start_offset = start_range % 16
        start_range -= start_offset
        # ...and, unless we are at the very beginning, also fetch the preceding
        # 16 byte block, which serves as the IV. start_range is a multiple of 16 here.
        if start_range > 0:
            start_offset += 16
            start_range -= 16

    end_offset = 0
    if length is not None:
        # Push the end up to the last byte of its 16 byte block.
        end_offset = 15 - (end_range % 16)
        end_range += end_offset

    return (start_range, end_range), (start_offset, end_offset)
def process_content(data, start_offset, end_offset, encryption):
    """Collect the body of a download response and decrypt it when required.

    :param data: Download response returned by the generated client. It is
        iterated to read the body bytes, and its ``response`` attribute is attached
        to any ``HttpResponseError`` raised here (and its headers are passed to
        ``decrypt_blob``).
    :param int start_offset: Start offset computed by ``process_range_and_offset``;
        forwarded to ``decrypt_blob``.
    :param int end_offset: End offset computed by ``process_range_and_offset``;
        forwarded to ``decrypt_blob``.
    :param dict encryption: Client-side encryption options. Decryption is attempted
        only when ``"key"`` or ``"resolver"`` is set; ``"required"`` is forwarded.
    :returns: The body bytes, decrypted when encryption is configured.
    :rtype: bytes
    :raises ValueError: If ``data`` is None.
    :raises ~azure.core.exceptions.HttpResponseError: If reading the body or
        decrypting it fails.
    """
    if data is None:
        raise ValueError("Response cannot be None.")
    try:
        content = b"".join(data)
    except Exception as error:
        raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error)

    # Only non-empty bodies are decrypted, regardless of whether a key or a key
    # resolver is configured. (The previous unparenthesized `a and b or c` form
    # bound as `(a and b) or c` and tried to decrypt empty bodies when only a
    # resolver was set.)
    encrypted = encryption.get("key") is not None or encryption.get("resolver") is not None
    if content and encrypted:
        try:
            return decrypt_blob(
                encryption.get("required"),
                encryption.get("key"),
                encryption.get("resolver"),
                content,
                start_offset,
                end_offset,
                data.response.headers,
            )
        except Exception as error:
            raise HttpResponseError(message="Decryption failed.", response=data.response, error=error)
    return content
66class _ChunkDownloader(object):  # pylint: disable=too-many-instance-attributes
67    def __init__(
68        self,
69        client=None,
70        non_empty_ranges=None,
71        total_size=None,
72        chunk_size=None,
73        current_progress=None,
74        start_range=None,
75        end_range=None,
76        stream=None,
77        parallel=None,
78        validate_content=None,
79        encryption_options=None,
80        **kwargs
81    ):
82        self.client = client
83        self.non_empty_ranges = non_empty_ranges
85        # Information on the download range/chunk size
86        self.chunk_size = chunk_size
87        self.total_size = total_size
88        self.start_index = start_range
89        self.end_index = end_range
91        # The destination that we will write to
92        self.stream = stream
93        self.stream_lock = threading.Lock() if parallel else None
94        self.progress_lock = threading.Lock() if parallel else None
96        # For a parallel download, the stream is always seekable, so we note down the current position
97        # in order to seek to the right place when out-of-order chunks come in
98        self.stream_start = stream.tell() if parallel else None
100        # Download progress so far
101        self.progress_total = current_progress
103        # Encryption
104        self.encryption_options = encryption_options
106        # Parameters for each get operation
107        self.validate_content = validate_content
108        self.request_options = kwargs
110    def _calculate_range(self, chunk_start):
111        if chunk_start + self.chunk_size > self.end_index:
112            chunk_end = self.end_index
113        else:
114            chunk_end = chunk_start + self.chunk_size
115        return chunk_start, chunk_end
117    def get_chunk_offsets(self):
118        index = self.start_index
119        while index < self.end_index:
120            yield index
121            index += self.chunk_size
123    def process_chunk(self, chunk_start):
124        chunk_start, chunk_end = self._calculate_range(chunk_start)
125        chunk_data = self._download_chunk(chunk_start, chunk_end - 1)
126        length = chunk_end - chunk_start
127        if length > 0:
128            self._write_to_stream(chunk_data, chunk_start)
129            self._update_progress(length)
131    def yield_chunk(self, chunk_start):
132        chunk_start, chunk_end = self._calculate_range(chunk_start)
133        return self._download_chunk(chunk_start, chunk_end - 1)
135    def _update_progress(self, length):
136        if self.progress_lock:
137            with self.progress_lock:  # pylint: disable=not-context-manager
138                self.progress_total += length
139        else:
140            self.progress_total += length
142    def _write_to_stream(self, chunk_data, chunk_start):
143        if self.stream_lock:
144            with self.stream_lock:  # pylint: disable=not-context-manager
145                self.stream.seek(self.stream_start + (chunk_start - self.start_index))
146                self.stream.write(chunk_data)
147        else:
148            self.stream.write(chunk_data)
150    def _do_optimize(self, given_range_start, given_range_end):
151        # If we have no page range list stored, then assume there's data everywhere for that page blob
152        # or it's a block blob or append blob
153        if self.non_empty_ranges is None:
154            return False
156        for source_range in self.non_empty_ranges:
157            # Case 1: As the range list is sorted, if we've reached such a source_range
158            # we've checked all the appropriate source_range already and haven't found any overlapping.
159            # so the given range doesn't have any data and download optimization could be applied.
160            # given range:		|   |
161            # source range:			       |   |
162            if given_range_end < source_range['start']:  # pylint:disable=no-else-return
163                return True
164            # Case 2: the given range comes after source_range, continue checking.
165            # given range:				|   |
166            # source range:	|   |
167            elif source_range['end'] < given_range_start:
168                pass
169            # Case 3: source_range and given range overlap somehow, no need to optimize.
170            else:
171                return False
172        # Went through all src_ranges, but nothing overlapped. Optimization will be applied.
173        return True
175    def _download_chunk(self, chunk_start, chunk_end):
176        download_range, offset = process_range_and_offset(
177            chunk_start, chunk_end, chunk_end, self.encryption_options
178        )
180        # No need to download the empty chunk from server if there's no data in the chunk to be downloaded.
181        # Do optimize and create empty chunk locally if condition is met.
182        if self._do_optimize(download_range[0], download_range[1]):
183            chunk_data = b"\x00" * self.chunk_size
184        else:
185            range_header, range_validation = validate_and_format_range_headers(
186                download_range[0],
187                download_range[1],
188                check_content_md5=self.validate_content
189            )
191            try:
192                _, response = self.client.download(
193                    range=range_header,
194                    range_get_content_md5=range_validation,
195                    validate_content=self.validate_content,
196                    data_stream_total=self.total_size,
197                    download_stream_current=self.progress_total,
198                    **self.request_options
199                )
200            except HttpResponseError as error:
201                process_storage_error(error)
203            chunk_data = process_content(response, offset[0], offset[1], self.encryption_options)
205            # This makes sure that if_match is set so that we can validate
206            # that subsequent downloads are to an unmodified blob
207            if self.request_options.get("modified_access_conditions"):
208                self.request_options["modified_access_conditions"].if_match = response.properties.etag
210        return chunk_data
213class _ChunkIterator(object):
214    """Async iterator for chunks in blob download stream."""
216    def __init__(self, size, content, downloader):
217        self.size = size
218        self._current_content = content
219        self._iter_downloader = downloader
220        self._iter_chunks = None
221        self._complete = (size == 0)
223    def __len__(self):
224        return self.size
226    def __iter__(self):
227        return self
229    def __next__(self):
230        """Iterate through responses."""
231        if self._complete:
232            raise StopIteration("Download complete")
233        if not self._iter_downloader:
234            # If no iterator was supplied, the download completed with
235            # the initial GET, so we just return that data
236            self._complete = True
237            return self._current_content
239        if not self._iter_chunks:
240            self._iter_chunks = self._iter_downloader.get_chunk_offsets()
241        else:
242            chunk = next(self._iter_chunks)
243            self._current_content = self._iter_downloader.yield_chunk(chunk)
245        return self._current_content
247    next = __next__  # Python 2 compatibility.
class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
    """A streaming object to download from Azure Storage.

    :ivar str name:
        The name of the blob being downloaded.
    :ivar str container:
        The name of the container where the blob is.
    :ivar ~azure.storage.blob.BlobProperties properties:
        The properties of the blob being downloaded. If only a range of the data is being
        downloaded, this will be reflected in the properties.
    :ivar int size:
        The size of the total data in the stream. This will be the byte range if specified,
        otherwise the total size of the blob.
    """

    def __init__(
        self,
        clients=None,
        config=None,
        start_range=None,
        end_range=None,
        validate_content=None,
        encryption_options=None,
        max_concurrency=1,
        name=None,
        container=None,
        encoding=None,
        **kwargs
    ):
        self.name = name
        self.container = container
        self.properties = None
        self.size = None

        self._clients = clients
        self._config = config
        self._start_range = start_range
        self._end_range = end_range
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        self._validate_content = validate_content
        self._encryption_options = encryption_options or {}
        self._request_options = kwargs
        self._location_mode = None
        self._download_complete = False
        self._current_content = None
        self._file_size = None
        self._non_empty_ranges = None
        self._response = None

        # The service only provides transactional MD5s for chunks under 4MB.
        # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first
        # chunk so a transactional MD5 can be retrieved.
        self._first_get_size = (
            self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size
        )
        initial_request_start = self._start_range if self._start_range is not None else 0
        if self._end_range is not None and self._end_range - self._start_range < self._first_get_size:
            initial_request_end = self._end_range
        else:
            initial_request_end = initial_request_start + self._first_get_size - 1

        self._initial_range, self._initial_offset = process_range_and_offset(
            initial_request_start, initial_request_end, self._end_range, self._encryption_options
        )

        # The first GET happens eagerly, during construction.
        self._response = self._initial_request()
        self.properties = self._response.properties
        self.properties.name = self.name
        self.properties.container = self.container

        # Set the content length to the download size instead of the size of
        # the last range
        self.properties.size = self.size

        # Overwrite the content range to the user requested range
        self.properties.content_range = "bytes {0}-{1}/{2}".format(
            self._start_range,
            self._end_range,
            self._file_size
        )

        # Overwrite the content MD5 as it is the MD5 for the last range instead
        # of the stored MD5
        # TODO: Set to the stored MD5 when the service returns this
        self.properties.content_md5 = None

        if self.size == 0:
            self._current_content = b""
        else:
            self._current_content = process_content(
                self._response,
                self._initial_offset[0],
                self._initial_offset[1],
                self._encryption_options
            )

    def __len__(self):
        return self.size

    def _initial_request(self):
        """Issue the first ranged GET and derive the download size from it.

        Also records the location mode, fetches page ranges for page blobs (to
        skip empty chunks later), and either marks the download complete or pins
        the etag for the follow-up chunk requests.

        :returns: The response of the initial download request.
        """
        range_header, range_validation = validate_and_format_range_headers(
            self._initial_range[0],
            self._initial_range[1],
            start_range_required=False,
            end_range_required=False,
            check_content_md5=self._validate_content
        )

        try:
            location_mode, response = self._clients.blob.download(
                range=range_header,
                range_get_content_md5=range_validation,
                validate_content=self._validate_content,
                data_stream_total=None,
                download_stream_current=0,
                **self._request_options
            )

            # Check the location we read from to ensure we use the same one
            # for subsequent requests.
            self._location_mode = location_mode

            # Parse the total file size and adjust the download size if ranges
            # were specified
            self._file_size = parse_length_from_content_range(response.properties.content_range)
            if self._end_range is not None:
                # Use the end range index unless it is over the end of the file
                self.size = min(self._file_size, self._end_range - self._start_range + 1)
            elif self._start_range is not None:
                self.size = self._file_size - self._start_range
            else:
                self.size = self._file_size

        except HttpResponseError as error:
            if self._start_range is None and error.response.status_code == 416:
                # Get range will fail on an empty file. If the user did not
                # request a range, do a regular get request in order to get
                # any properties.
                try:
                    _, response = self._clients.blob.download(
                        validate_content=self._validate_content,
                        data_stream_total=0,
                        download_stream_current=0,
                        **self._request_options
                    )
                # Distinct name so the outer handler's `error` is not shadowed.
                except HttpResponseError as fallback_error:
                    process_storage_error(fallback_error)

                # Set the download size to empty
                self.size = 0
                self._file_size = 0
            else:
                process_storage_error(error)

        # get page ranges to optimize downloading sparse page blob
        if response.properties.blob_type == 'PageBlob':
            try:
                page_ranges = self._clients.page_blob.get_page_ranges()
                self._non_empty_ranges = get_page_ranges_result(page_ranges)[0]
            # according to the REST API documentation:
            # in a highly fragmented page blob with a large number of writes,
            # a Get Page Ranges request can fail due to an internal server timeout.
            # thus, if the page blob is not sparse, it's ok for it to fail
            except HttpResponseError:
                pass

        # If the file is small, the download is complete at this point.
        # If file size is large, download the rest of the file in chunks.
        if response.properties.size != self.size:
            # Lock on the etag. This can be overriden by the user by specifying '*'
            if self._request_options.get("modified_access_conditions"):
                if not self._request_options["modified_access_conditions"].if_match:
                    self._request_options["modified_access_conditions"].if_match = response.properties.etag
        else:
            self._download_complete = True
        return response

    def chunks(self):
        """Iterate over the downloaded data chunk by chunk.

        The first chunk is the content of the initial request; the rest are
        downloaded sequentially as the iterator advances.

        :rtype: Iterator[bytes]
        """
        if self.size == 0 or self._download_complete:
            iter_downloader = None
        else:
            data_end = self._file_size
            if self._end_range is not None:
                # Use the end range index unless it is over the end of the file
                data_end = min(self._file_size, self._end_range + 1)
            iter_downloader = _ChunkDownloader(
                client=self._clients.blob,
                non_empty_ranges=self._non_empty_ranges,
                total_size=self.size,
                chunk_size=self._config.max_chunk_get_size,
                current_progress=self._first_get_size,
                start_range=self._initial_range[1] + 1,  # start where the first download ended
                end_range=data_end,
                stream=None,
                parallel=False,
                validate_content=self._validate_content,
                encryption_options=self._encryption_options,
                use_location=self._location_mode,
                **self._request_options
            )
        return _ChunkIterator(
            size=self.size,
            content=self._current_content,
            downloader=iter_downloader)

    def readall(self):
        """Download the contents of this blob.

        This operation is blocking until all data is downloaded.

        :returns: The blob content, decoded with the configured encoding if one was given.
        :rtype: bytes or str
        """
        stream = BytesIO()
        self.readinto(stream)
        data = stream.getvalue()
        if self._encoding:
            return data.decode(self._encoding)
        return data

    def content_as_bytes(self, max_concurrency=1):
        """Download the contents of this file.

        This operation is blocking until all data is downloaded.

        :param int max_concurrency:
            The number of parallel connections with which to download.
        :rtype: bytes
        """
        warnings.warn(
            "content_as_bytes is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        return self.readall()

    def content_as_text(self, max_concurrency=1, encoding="UTF-8"):
        """Download the contents of this blob, and decode as text.

        This operation is blocking until all data is downloaded.

        :param int max_concurrency:
            The number of parallel connections with which to download.
        :param str encoding:
            Text encoding to decode the downloaded bytes. Default is UTF-8.
        :rtype: str
        """
        warnings.warn(
            "content_as_text is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        return self.readall()

    def readinto(self, stream):
        """Download the contents of this file to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :returns: The number of bytes read.
        :rtype: int
        :raises ValueError: If a parallel download is requested and the stream is not seekable.
        """
        # The stream must be seekable if parallel download is required
        parallel = self._max_concurrency > 1
        if parallel:
            error_message = "Target stream handle must be seekable."
            if sys.version_info >= (3,) and not stream.seekable():
                raise ValueError(error_message)

            try:
                stream.seek(stream.tell())
            except (NotImplementedError, AttributeError):
                raise ValueError(error_message)

        # Write the content to the user stream
        stream.write(self._current_content)
        if self._download_complete:
            return self.size

        data_end = self._file_size
        if self._end_range is not None:
            # Use the length unless it is over the end of the file
            data_end = min(self._file_size, self._end_range + 1)

        downloader = _ChunkDownloader(
            client=self._clients.blob,
            non_empty_ranges=self._non_empty_ranges,
            total_size=self.size,
            chunk_size=self._config.max_chunk_get_size,
            current_progress=self._first_get_size,
            start_range=self._initial_range[1] + 1,  # Start where the first download ended
            end_range=data_end,
            stream=stream,
            parallel=parallel,
            validate_content=self._validate_content,
            encryption_options=self._encryption_options,
            use_location=self._location_mode,
            **self._request_options
        )
        if parallel:
            import concurrent.futures
            # Use the executor as a context manager so its worker threads are shut
            # down once all chunks are processed (or one of them fails).
            with concurrent.futures.ThreadPoolExecutor(self._max_concurrency) as executor:
                list(executor.map(
                    with_current_context(downloader.process_chunk),
                    downloader.get_chunk_offsets()
                ))
        else:
            for chunk in downloader.get_chunk_offsets():
                downloader.process_chunk(chunk)
        return self.size

    def download_to_stream(self, stream, max_concurrency=1):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :param int max_concurrency:
            The number of parallel connections with which to download.
        :returns: The properties of the downloaded blob.
        :rtype: Any
        """
        warnings.warn(
            "download_to_stream is deprecated, use readinto instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        self.readinto(stream)
        return self.properties