1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for
4# license information.
5# --------------------------------------------------------------------------
6# pylint: disable=invalid-overridden-method
8import asyncio
9import sys
10from io import BytesIO
11from itertools import islice
12import warnings
14from azure.core.exceptions import HttpResponseError
15from .._shared.encryption import decrypt_blob
16from .._shared.request_handlers import validate_and_format_range_headers
17from .._shared.response_handlers import process_storage_error, parse_length_from_content_range
18from .._deserialize import get_page_ranges_result
19from .._download import process_range_and_offset, _ChunkDownloader
async def process_content(data, start_offset, end_offset, encryption):
    """Extract the body of a download response, decrypting it when configured.

    :param data:
        The object returned by the generated client's ``download`` call. Its
        ``response`` attribute must provide ``body()`` and ``headers``.
    :param int start_offset:
        Start offset forwarded to ``decrypt_blob`` (only used when decrypting).
    :param int end_offset:
        End offset forwarded to ``decrypt_blob`` (only used when decrypting).
    :param dict encryption:
        Encryption options. Decryption is attempted when either ``'key'`` or
        ``'resolver'`` is set; ``'required'`` is forwarded to ``decrypt_blob``.
    :returns: The response content, decrypted if encryption options apply.
    :rtype: bytes
    :raises ValueError: If ``data`` is None.
    :raises ~azure.core.exceptions.HttpResponseError:
        If reading the response body or decrypting it fails. The original
        exception is chained as ``__cause__``.
    """
    if data is None:
        raise ValueError("Response cannot be None.")
    try:
        content = data.response.body()
    except Exception as error:
        raise HttpResponseError(
            message="Download stream interrupted.", response=data.response, error=error) from error
    if encryption.get('key') is not None or encryption.get('resolver') is not None:
        try:
            return decrypt_blob(
                encryption.get('required'),
                encryption.get('key'),
                encryption.get('resolver'),
                content,
                start_offset,
                end_offset,
                data.response.headers)
        except Exception as error:
            raise HttpResponseError(
                message="Decryption failed.",
                response=data.response,
                error=error) from error
    return content
class _AsyncChunkDownloader(_ChunkDownloader):
    """Asynchronous counterpart of ``_ChunkDownloader``.

    Downloads individual ranges of a blob through awaited client calls. When the
    ``parallel`` keyword argument is truthy, ``asyncio.Lock`` objects guard the
    shared output stream and the progress counter, because several
    ``process_chunk`` tasks may then be running at the same time.
    """

    def __init__(self, **kwargs):
        super(_AsyncChunkDownloader, self).__init__(**kwargs)
        self.stream_lock = asyncio.Lock() if kwargs.get('parallel') else None
        self.progress_lock = asyncio.Lock() if kwargs.get('parallel') else None

    async def process_chunk(self, chunk_start):
        """Download the chunk beginning at ``chunk_start`` and write it to the stream.

        :param int chunk_start: Offset of the chunk within the blob.
        """
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        # chunk_end from _calculate_range is exclusive (length = chunk_end - chunk_start),
        # while _download_chunk expects an inclusive end offset.
        chunk_data = await self._download_chunk(chunk_start, chunk_end - 1)
        length = chunk_end - chunk_start
        if length > 0:
            await self._write_to_stream(chunk_data, chunk_start)
            await self._update_progress(length)

    async def yield_chunk(self, chunk_start):
        """Download and return the chunk beginning at ``chunk_start``.

        :param int chunk_start: Offset of the chunk within the blob.
        :rtype: bytes
        """
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        return await self._download_chunk(chunk_start, chunk_end - 1)

    async def _update_progress(self, length):
        """Add ``length`` to ``progress_total``, holding the lock in parallel mode."""
        if self.progress_lock:
            async with self.progress_lock:  # pylint: disable=not-async-context-manager
                self.progress_total += length
        else:
            self.progress_total += length

    async def _write_to_stream(self, chunk_data, chunk_start):
        """Write ``chunk_data`` to the output stream.

        In parallel mode chunks can complete out of order, so the stream is first
        positioned at ``stream_start + (chunk_start - start_index)`` (both set by
        the base class). In sequential mode the data is simply appended.
        """
        if self.stream_lock:
            async with self.stream_lock:  # pylint: disable=not-async-context-manager
                self.stream.seek(self.stream_start + (chunk_start - self.start_index))
                self.stream.write(chunk_data)
        else:
            self.stream.write(chunk_data)

    async def _download_chunk(self, chunk_start, chunk_end):
        """Fetch the inclusive byte range ``[chunk_start, chunk_end]`` of the blob.

        :param int chunk_start: First byte offset to download.
        :param int chunk_end: Last byte offset to download (inclusive).
        :returns: The chunk content, decrypted when encryption options are set.
        :rtype: bytes
        """
        download_range, offset = process_range_and_offset(
            chunk_start, chunk_end, chunk_end, self.encryption_options)

        # No need to download the empty chunk from server if there's no data in the chunk to be downloaded.
        # Do optimize and create empty chunk locally if condition is met.
        if self._do_optimize(download_range[0], download_range[1]):
            # Size the zero-filled chunk to the requested range rather than self.chunk_size:
            # the final chunk can be shorter than chunk_size, and padding it would write
            # extra bytes past the end of the requested data.
            chunk_data = b"\x00" * (chunk_end - chunk_start + 1)
        else:
            range_header, range_validation = validate_and_format_range_headers(
                download_range[0],
                download_range[1],
                check_content_md5=self.validate_content
            )
            try:
                _, response = await self.client.download(
                    range=range_header,
                    range_get_content_md5=range_validation,
                    validate_content=self.validate_content,
                    data_stream_total=self.total_size,
                    download_stream_current=self.progress_total,
                    **self.request_options
                )
            except HttpResponseError as error:
                # process_storage_error is relied on to raise; `response` is unbound otherwise.
                process_storage_error(error)

            chunk_data = await process_content(response, offset[0], offset[1], self.encryption_options)

            # This makes sure that if_match is set so that we can validate
            # that subsequent downloads are to an unmodified blob
            if self.request_options.get('modified_access_conditions'):
                self.request_options['modified_access_conditions'].if_match = response.properties.etag

        return chunk_data
116class _AsyncChunkIterator(object):
117    """Async iterator for chunks in blob download stream."""
119    def __init__(self, size, content, downloader):
120        self.size = size
121        self._current_content = content
122        self._iter_downloader = downloader
123        self._iter_chunks = None
124        self._complete = (size == 0)
126    def __len__(self):
127        return self.size
129    def __iter__(self):
130        raise TypeError("Async stream must be iterated asynchronously.")
132    def __aiter__(self):
133        return self
135    async def __anext__(self):
136        """Iterate through responses."""
137        if self._complete:
138            raise StopAsyncIteration("Download complete")
139        if not self._iter_downloader:
140            # If no iterator was supplied, the download completed with
141            # the initial GET, so we just return that data
142            self._complete = True
143            return self._current_content
145        if not self._iter_chunks:
146            self._iter_chunks = self._iter_downloader.get_chunk_offsets()
147        else:
148            try:
149                chunk = next(self._iter_chunks)
150            except StopIteration:
151                raise StopAsyncIteration("Download complete")
152            self._current_content = await self._iter_downloader.yield_chunk(chunk)
154        return self._current_content
class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
    """A streaming object to download from Azure Storage.

    :ivar str name:
        The name of the blob being downloaded.
    :ivar str container:
        The name of the container where the blob is.
    :ivar ~azure.storage.blob.BlobProperties properties:
        The properties of the blob being downloaded. If only a range of the data is being
        downloaded, this will be reflected in the properties.
    :ivar int size:
        The size of the total data in the stream. This will be the byte range if specified,
        otherwise the total size of the blob.
    """

    def __init__(
            self,
            clients=None,
            config=None,
            start_range=None,
            end_range=None,
            validate_content=None,
            encryption_options=None,
            max_concurrency=1,
            name=None,
            container=None,
            encoding=None,
            **kwargs
    ):
        self.name = name
        self.container = container
        self.properties = None
        self.size = None

        self._clients = clients
        self._config = config
        self._start_range = start_range
        self._end_range = end_range
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        self._validate_content = validate_content
        self._encryption_options = encryption_options or {}
        self._request_options = kwargs
        self._location_mode = None
        self._download_complete = False
        self._current_content = None
        self._file_size = None
        self._non_empty_ranges = None
        self._response = None

        # The service only provides transactional MD5s for chunks under 4MB.
        # If validate_content is on, get only config.max_chunk_get_size bytes for the
        # first chunk so a transactional MD5 can be retrieved.
        self._first_get_size = self._config.max_single_get_size if not self._validate_content \
            else self._config.max_chunk_get_size
        initial_request_start = self._start_range if self._start_range is not None else 0
        if self._end_range is not None and self._end_range - self._start_range < self._first_get_size:
            initial_request_end = self._end_range
        else:
            initial_request_end = initial_request_start + self._first_get_size - 1

        self._initial_range, self._initial_offset = process_range_and_offset(
            initial_request_start, initial_request_end, self._end_range, self._encryption_options
        )

    def __len__(self):
        return self.size

    async def _setup(self):
        """Issue the initial request, then populate ``properties`` and the first content chunk."""
        self._response = await self._initial_request()
        self.properties = self._response.properties
        self.properties.name = self.name
        self.properties.container = self.container

        # Set the content length to the download size instead of the size of
        # the last range
        self.properties.size = self.size

        # Overwrite the content range to the user requested range
        self.properties.content_range = 'bytes {0}-{1}/{2}'.format(
            self._start_range,
            self._end_range,
            self._file_size
        )

        # Overwrite the content MD5 as it is the MD5 for the last range instead
        # of the stored MD5
        # TODO: Set to the stored MD5 when the service returns this
        self.properties.content_md5 = None

        if self.size == 0:
            self._current_content = b""
        else:
            self._current_content = await process_content(
                self._response,
                self._initial_offset[0],
                self._initial_offset[1],
                self._encryption_options
            )

    async def _initial_request(self):
        """Download the first range and derive the total file size and download size.

        Also records the location mode for follow-up requests, fetches page ranges for
        page blobs (best effort), and pins the etag when more chunks are needed.

        :returns: The response of the initial download call.
        """
        range_header, range_validation = validate_and_format_range_headers(
            self._initial_range[0],
            self._initial_range[1],
            start_range_required=False,
            end_range_required=False,
            check_content_md5=self._validate_content)

        try:
            location_mode, response = await self._clients.blob.download(
                range=range_header,
                range_get_content_md5=range_validation,
                validate_content=self._validate_content,
                data_stream_total=None,
                download_stream_current=0,
                **self._request_options)

            # Check the location we read from to ensure we use the same one
            # for subsequent requests.
            self._location_mode = location_mode

            # Parse the total file size and adjust the download size if ranges
            # were specified
            self._file_size = parse_length_from_content_range(response.properties.content_range)
            if self._end_range is not None:
                # Use the length unless it is over the end of the file
                self.size = min(self._file_size, self._end_range - self._start_range + 1)
            elif self._start_range is not None:
                self.size = self._file_size - self._start_range
            else:
                self.size = self._file_size

        except HttpResponseError as error:
            if self._start_range is None and error.response.status_code == 416:
                # Get range will fail on an empty file. If the user did not
                # request a range, do a regular get request in order to get
                # any properties.
                try:
                    _, response = await self._clients.blob.download(
                        validate_content=self._validate_content,
                        data_stream_total=0,
                        download_stream_current=0,
                        **self._request_options)
                except HttpResponseError as fallback_error:
                    process_storage_error(fallback_error)

                # Set the download size to empty
                self.size = 0
                self._file_size = 0
            else:
                process_storage_error(error)

        # get page ranges to optimize downloading sparse page blob
        if response.properties.blob_type == 'PageBlob':
            try:
                page_ranges = await self._clients.page_blob.get_page_ranges()
                self._non_empty_ranges = get_page_ranges_result(page_ranges)[0]
            except HttpResponseError:
                # Best effort: without page ranges every chunk is simply downloaded.
                pass

        # If the file is small, the download is complete at this point.
        # If file size is large, download the rest of the file in chunks.
        if response.properties.size != self.size:
            # Lock on the etag. This can be overriden by the user by specifying '*'
            if self._request_options.get('modified_access_conditions'):
                if not self._request_options['modified_access_conditions'].if_match:
                    self._request_options['modified_access_conditions'].if_match = response.properties.etag
        else:
            self._download_complete = True
        return response

    def _make_downloader(self, stream, parallel):
        """Build a chunk downloader that resumes right after the initial request's range."""
        data_end = self._file_size
        if self._end_range is not None:
            # Use the length unless it is over the end of the file
            data_end = min(self._file_size, self._end_range + 1)
        return _AsyncChunkDownloader(
            client=self._clients.blob,
            non_empty_ranges=self._non_empty_ranges,
            total_size=self.size,
            chunk_size=self._config.max_chunk_get_size,
            current_progress=self._first_get_size,
            start_range=self._initial_range[1] + 1,  # Start where the first download ended
            end_range=data_end,
            stream=stream,
            parallel=parallel,
            validate_content=self._validate_content,
            encryption_options=self._encryption_options,
            use_location=self._location_mode,
            **self._request_options)

    def chunks(self):
        """Iterate over chunks in the download stream.

        Must be consumed with ``async for``.

        :rtype: AsyncIterator[bytes]
        """
        if self.size == 0 or self._download_complete:
            iter_downloader = None
        else:
            iter_downloader = self._make_downloader(stream=None, parallel=False)
        return _AsyncChunkIterator(
            size=self.size,
            content=self._current_content,
            downloader=iter_downloader)

    async def readall(self):
        """Download the contents of this blob.

        The coroutine completes once all data has been downloaded.
        Returns ``str`` when an encoding was configured, otherwise ``bytes``.

        :rtype: bytes or str
        """
        stream = BytesIO()
        await self.readinto(stream)
        data = stream.getvalue()
        if self._encoding:
            return data.decode(self._encoding)
        return data

    async def content_as_bytes(self, max_concurrency=1):
        """Download the contents of this file.

        The coroutine completes once all data has been downloaded.
        Deprecated: use ``readall`` instead.

        :param int max_concurrency:
            The number of parallel connections with which to download.
        :rtype: bytes
        """
        warnings.warn(
            "content_as_bytes is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        return await self.readall()

    async def content_as_text(self, max_concurrency=1, encoding="UTF-8"):
        """Download the contents of this blob, and decode as text.

        The coroutine completes once all data has been downloaded.
        Deprecated: use ``readall`` instead.

        :param int max_concurrency:
            The number of parallel connections with which to download.
        :param str encoding:
            Text encoding to decode the downloaded bytes. Default is UTF-8.
        :rtype: str
        """
        warnings.warn(
            "content_as_text is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        return await self.readall()

    async def readinto(self, stream):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :returns: The number of bytes read.
        :rtype: int
        :raises ValueError: If parallel download is requested and ``stream`` is not seekable.
        """
        # the stream must be seekable if parallel download is required
        parallel = self._max_concurrency > 1
        if parallel:
            error_message = "Target stream handle must be seekable."
            if sys.version_info >= (3,) and not stream.seekable():
                raise ValueError(error_message)

            try:
                stream.seek(stream.tell())
            except (NotImplementedError, AttributeError):
                raise ValueError(error_message)

        # Write the content to the user stream
        stream.write(self._current_content)
        if self._download_complete:
            return self.size

        downloader = self._make_downloader(stream, parallel)
        dl_tasks = downloader.get_chunk_offsets()
        # Always keep at least one request in flight so a non-positive
        # max_concurrency cannot silently truncate the download.
        running_futures = {
            asyncio.ensure_future(downloader.process_chunk(d))
            for d in islice(dl_tasks, 0, max(self._max_concurrency, 1))
        }
        try:
            while running_futures:
                # Wait for some download to finish before adding new ones
                done, running_futures = await asyncio.wait(
                    running_futures, return_when=asyncio.FIRST_COMPLETED)
                # Re-raise chunk failures instead of returning a stream with a gap in it.
                for task in done:
                    task.result()
                # Start one new chunk per finished one to keep the concurrency level.
                for next_chunk in islice(dl_tasks, len(done)):
                    running_futures.add(asyncio.ensure_future(downloader.process_chunk(next_chunk)))
        except BaseException:
            # After a failure (or cancellation of this coroutine), stop the sibling
            # downloads so they do not keep writing into the caller's stream.
            for task in running_futures:
                task.cancel()
            raise
        return self.size

    async def download_to_stream(self, stream, max_concurrency=1):
        """Download the contents of this blob to a stream.

        Deprecated: use ``readinto`` instead.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :param int max_concurrency:
            The number of parallel connections with which to download.
        :returns: The properties of the downloaded blob.
        :rtype: Any
        """
        warnings.warn(
            "download_to_stream is deprecated, use readinto instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        await self.readinto(stream)
        return self.properties