1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for
4# license information.
5# --------------------------------------------------------------------------
6# pylint: disable=invalid-overridden-method
7
8import asyncio
9import sys
10from io import BytesIO
11from itertools import islice
12import warnings
13
14from azure.core.exceptions import HttpResponseError
15from .._shared.encryption import decrypt_blob
16from .._shared.request_handlers import validate_and_format_range_headers
17from .._shared.response_handlers import process_storage_error, parse_length_from_content_range
18from .._deserialize import get_page_ranges_result
19from .._download import process_range_and_offset, _ChunkDownloader
20
21
async def process_content(data, start_offset, end_offset, encryption):
    """Extract the downloaded body from a pipeline response, decrypting it
    when client-side encryption settings are supplied.

    :param data: Pipeline response whose ``response.body()`` holds the bytes.
    :param int start_offset: Leading offset into the decrypted block.
    :param int end_offset: Trailing offset into the decrypted block.
    :param dict encryption: Encryption settings ('required', 'key', 'resolver').
    :returns: The (optionally decrypted) content bytes.
    :raises ValueError: If data is None.
    :raises ~azure.core.exceptions.HttpResponseError:
        If the stream was interrupted or decryption failed.
    """
    if data is None:
        raise ValueError("Response cannot be None.")

    try:
        raw_content = data.response.body()
    except Exception as error:  # pylint: disable=broad-except
        raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error)

    # Decrypt only when a key or a key resolver was configured.
    needs_decryption = encryption.get('key') is not None or encryption.get('resolver') is not None
    if not needs_decryption:
        return raw_content

    try:
        return decrypt_blob(
            encryption.get('required'),
            encryption.get('key'),
            encryption.get('resolver'),
            raw_content,
            start_offset,
            end_offset,
            data.response.headers)
    except Exception as error:  # pylint: disable=broad-except
        raise HttpResponseError(
            message="Decryption failed.",
            response=data.response,
            error=error)
45
46
class _AsyncChunkDownloader(_ChunkDownloader):
    """Asynchronous chunk downloader.

    Extends the shared ``_ChunkDownloader`` with coroutine-based chunk
    processing. When downloading in parallel, asyncio locks guard the
    shared output stream and the progress counter.
    """

    def __init__(self, **kwargs):
        super(_AsyncChunkDownloader, self).__init__(**kwargs)
        # Locks are only needed for parallel downloads; sequential
        # downloads skip the locking overhead entirely (locks are None).
        self.stream_lock = asyncio.Lock() if kwargs.get('parallel') else None
        self.progress_lock = asyncio.Lock() if kwargs.get('parallel') else None

    async def process_chunk(self, chunk_start):
        """Download one chunk, write it to the stream and update progress."""
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        # chunk_end is exclusive; the service range header is inclusive.
        chunk_data = await self._download_chunk(chunk_start, chunk_end - 1)
        length = chunk_end - chunk_start
        if length > 0:
            await self._write_to_stream(chunk_data, chunk_start)
            await self._update_progress(length)

    async def yield_chunk(self, chunk_start):
        """Download and return one chunk without writing it to the stream."""
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        return await self._download_chunk(chunk_start, chunk_end - 1)

    async def _update_progress(self, length):
        # Serialize counter updates only when downloading in parallel.
        if self.progress_lock:
            async with self.progress_lock:  # pylint: disable=not-async-context-manager
                self.progress_total += length
        else:
            self.progress_total += length

    async def _write_to_stream(self, chunk_data, chunk_start):
        if self.stream_lock:
            async with self.stream_lock:  # pylint: disable=not-async-context-manager
                # Parallel chunks may complete out of order: seek to this
                # chunk's absolute position before writing.
                self.stream.seek(self.stream_start + (chunk_start - self.start_index))
                self.stream.write(chunk_data)
        else:
            # Sequential mode: chunks arrive in order, so append directly.
            self.stream.write(chunk_data)

    async def _download_chunk(self, chunk_start, chunk_end):
        """Fetch the inclusive byte range [chunk_start, chunk_end] from the
        service and return the (decrypted, if applicable) chunk bytes."""
        download_range, offset = process_range_and_offset(
            chunk_start, chunk_end, chunk_end, self.encryption_options)

        # No need to download the empty chunk from server if there's no data in the chunk to be downloaded.
        # Do optimize and create empty chunk locally if condition is met.
        if self._do_optimize(download_range[0], download_range[1]):
            # Sparse page-blob range: synthesize zeros without a round trip.
            chunk_data = b"\x00" * self.chunk_size
        else:
            range_header, range_validation = validate_and_format_range_headers(
                download_range[0],
                download_range[1],
                check_content_md5=self.validate_content
            )
            try:
                _, response = await self.client.download(
                    range=range_header,
                    range_get_content_md5=range_validation,
                    validate_content=self.validate_content,
                    data_stream_total=self.total_size,
                    download_stream_current=self.progress_total,
                    **self.request_options
                )
            except HttpResponseError as error:
                # Map the pipeline error to a storage-specific exception (raises).
                process_storage_error(error)

            chunk_data = await process_content(response, offset[0], offset[1], self.encryption_options)

            # This makes sure that if_match is set so that we can validate
            # that subsequent downloads are to an unmodified blob
            if self.request_options.get('modified_access_conditions'):
                self.request_options['modified_access_conditions'].if_match = response.properties.etag

        return chunk_data
114
115
116class _AsyncChunkIterator(object):
117    """Async iterator for chunks in blob download stream."""
118
119    def __init__(self, size, content, downloader):
120        self.size = size
121        self._current_content = content
122        self._iter_downloader = downloader
123        self._iter_chunks = None
124        self._complete = (size == 0)
125
126    def __len__(self):
127        return self.size
128
129    def __iter__(self):
130        raise TypeError("Async stream must be iterated asynchronously.")
131
132    def __aiter__(self):
133        return self
134
135    async def __anext__(self):
136        """Iterate through responses."""
137        if self._complete:
138            raise StopAsyncIteration("Download complete")
139        if not self._iter_downloader:
140            # If no iterator was supplied, the download completed with
141            # the initial GET, so we just return that data
142            self._complete = True
143            return self._current_content
144
145        if not self._iter_chunks:
146            self._iter_chunks = self._iter_downloader.get_chunk_offsets()
147        else:
148            try:
149                chunk = next(self._iter_chunks)
150            except StopIteration:
151                raise StopAsyncIteration("Download complete")
152            self._current_content = await self._iter_downloader.yield_chunk(chunk)
153
154        return self._current_content
155
156
class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
    """A streaming object to download from Azure Storage.

    :ivar str name:
        The name of the blob being downloaded.
    :ivar str container:
        The name of the container where the blob is.
    :ivar ~azure.storage.blob.BlobProperties properties:
        The properties of the blob being downloaded. If only a range of the data is being
        downloaded, this will be reflected in the properties.
    :ivar int size:
        The size of the total data in the stream. This will be the byte range if specified,
        otherwise the total size of the blob.
    """

    def __init__(
            self,
            clients=None,
            config=None,
            start_range=None,
            end_range=None,
            validate_content=None,
            encryption_options=None,
            max_concurrency=1,
            name=None,
            container=None,
            encoding=None,
            **kwargs
    ):
        self.name = name
        self.container = container
        self.properties = None
        self.size = None

        self._clients = clients
        self._config = config
        self._start_range = start_range
        self._end_range = end_range
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        self._validate_content = validate_content
        self._encryption_options = encryption_options or {}
        self._request_options = kwargs
        self._location_mode = None
        self._download_complete = False
        self._current_content = None
        self._file_size = None
        self._non_empty_ranges = None
        self._response = None

        # The service only provides transactional MD5s for chunks under 4MB.
        # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first
        # chunk so a transactional MD5 can be retrieved.
        self._first_get_size = self._config.max_single_get_size if not self._validate_content \
            else self._config.max_chunk_get_size
        initial_request_start = self._start_range if self._start_range is not None else 0
        # Compare against initial_request_start (never None) rather than
        # self._start_range, so a download that specifies only an end range
        # does not raise a TypeError on None arithmetic.
        if self._end_range is not None and self._end_range - initial_request_start < self._first_get_size:
            initial_request_end = self._end_range
        else:
            initial_request_end = initial_request_start + self._first_get_size - 1

        self._initial_range, self._initial_offset = process_range_and_offset(
            initial_request_start, initial_request_end, self._end_range, self._encryption_options
        )

    def __len__(self):
        # NOTE: valid only after _setup() has populated self.size.
        return self.size

    async def _setup(self):
        """Issue the initial request and populate size, properties and the
        first chunk of content. Must be awaited before reading."""
        self._response = await self._initial_request()
        self.properties = self._response.properties
        self.properties.name = self.name
        self.properties.container = self.container

        # Set the content length to the download size instead of the size of
        # the last range
        self.properties.size = self.size

        # Overwrite the content range to the user requested range
        self.properties.content_range = 'bytes {0}-{1}/{2}'.format(
            self._start_range,
            self._end_range,
            self._file_size
        )

        # Overwrite the content MD5 as it is the MD5 for the last range instead
        # of the stored MD5
        # TODO: Set to the stored MD5 when the service returns this
        self.properties.content_md5 = None

        if self.size == 0:
            self._current_content = b""
        else:
            self._current_content = await process_content(
                self._response,
                self._initial_offset[0],
                self._initial_offset[1],
                self._encryption_options
            )

    async def _initial_request(self):
        """Perform the first GET, establishing total size, location mode and
        (for small blobs) completing the whole download."""
        range_header, range_validation = validate_and_format_range_headers(
            self._initial_range[0],
            self._initial_range[1],
            start_range_required=False,
            end_range_required=False,
            check_content_md5=self._validate_content)

        try:
            location_mode, response = await self._clients.blob.download(
                range=range_header,
                range_get_content_md5=range_validation,
                validate_content=self._validate_content,
                data_stream_total=None,
                download_stream_current=0,
                **self._request_options)

            # Check the location we read from to ensure we use the same one
            # for subsequent requests.
            self._location_mode = location_mode

            # Parse the total file size and adjust the download size if ranges
            # were specified
            self._file_size = parse_length_from_content_range(response.properties.content_range)
            if self._end_range is not None:
                # Use the length unless it is over the end of the file.
                # Default a missing start range to 0 so an end-only range
                # does not raise a TypeError on None arithmetic.
                self.size = min(self._file_size, self._end_range - (self._start_range or 0) + 1)
            elif self._start_range is not None:
                self.size = self._file_size - self._start_range
            else:
                self.size = self._file_size

        except HttpResponseError as error:
            if self._start_range is None and error.response.status_code == 416:
                # Get range will fail on an empty file. If the user did not
                # request a range, do a regular get request in order to get
                # any properties.
                try:
                    _, response = await self._clients.blob.download(
                        validate_content=self._validate_content,
                        data_stream_total=0,
                        download_stream_current=0,
                        **self._request_options)
                except HttpResponseError as retry_error:
                    # The plain GET failed as well - surface that error (raises).
                    process_storage_error(retry_error)

                # Set the download size to empty
                self.size = 0
                self._file_size = 0
            else:
                process_storage_error(error)

        # get page ranges to optimize downloading sparse page blob
        if response.properties.blob_type == 'PageBlob':
            try:
                page_ranges = await self._clients.page_blob.get_page_ranges()
                self._non_empty_ranges = get_page_ranges_result(page_ranges)[0]
            except HttpResponseError:
                # Best-effort optimization only; fall back to downloading
                # every chunk if page ranges cannot be fetched.
                pass

        # If the file is small, the download is complete at this point.
        # If file size is large, download the rest of the file in chunks.
        if response.properties.size != self.size:
            # Lock on the etag. This can be overriden by the user by specifying '*'
            if self._request_options.get('modified_access_conditions'):
                if not self._request_options['modified_access_conditions'].if_match:
                    self._request_options['modified_access_conditions'].if_match = response.properties.etag
        else:
            self._download_complete = True
        return response

    @staticmethod
    def _propagate_task_errors(done, pending):
        """Re-raise the first exception raised by any completed download task,
        cancelling all still-pending tasks first so they do not leak."""
        for task in done:
            if task.cancelled():
                continue
            error = task.exception()
            if error is not None:
                for future in pending:
                    future.cancel()
                raise error

    def chunks(self):
        """Iterate over chunks in the download stream.

        :rtype: AsyncIterable[bytes]
        """
        if self.size == 0 or self._download_complete:
            # Everything arrived with the initial GET; no chunk downloader needed.
            iter_downloader = None
        else:
            data_end = self._file_size
            if self._end_range is not None:
                # Use the length unless it is over the end of the file
                data_end = min(self._file_size, self._end_range + 1)
            iter_downloader = _AsyncChunkDownloader(
                client=self._clients.blob,
                non_empty_ranges=self._non_empty_ranges,
                total_size=self.size,
                chunk_size=self._config.max_chunk_get_size,
                current_progress=self._first_get_size,
                start_range=self._initial_range[1] + 1,  # Start where the first download ended
                end_range=data_end,
                stream=None,
                parallel=False,
                validate_content=self._validate_content,
                encryption_options=self._encryption_options,
                use_location=self._location_mode,
                **self._request_options)
        return _AsyncChunkIterator(
            size=self.size,
            content=self._current_content,
            downloader=iter_downloader)

    async def readall(self):
        """Download the contents of this blob.

        This operation is blocking until all data is downloaded.

        :returns: The full blob content, decoded if an encoding was set.
        :rtype: bytes or str
        """
        stream = BytesIO()
        await self.readinto(stream)
        data = stream.getvalue()
        if self._encoding:
            return data.decode(self._encoding)
        return data

    async def content_as_bytes(self, max_concurrency=1):
        """Download the contents of this file.

        This operation is blocking until all data is downloaded.

        :keyword int max_concurrency:
            The number of parallel connections with which to download.
        :rtype: bytes
        """
        warnings.warn(
            "content_as_bytes is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        return await self.readall()

    async def content_as_text(self, max_concurrency=1, encoding="UTF-8"):
        """Download the contents of this blob, and decode as text.

        This operation is blocking until all data is downloaded.

        :keyword int max_concurrency:
            The number of parallel connections with which to download.
        :param str encoding:
            Test encoding to decode the downloaded bytes. Default is UTF-8.
        :rtype: str
        """
        warnings.warn(
            "content_as_text is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        return await self.readall()

    async def readinto(self, stream):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :returns: The number of bytes read.
        :rtype: int
        :raises ValueError: If a parallel download is requested and the target
            stream is not seekable.
        """
        # the stream must be seekable if parallel download is required
        parallel = self._max_concurrency > 1
        if parallel:
            error_message = "Target stream handle must be seekable."
            if sys.version_info >= (3,) and not stream.seekable():
                raise ValueError(error_message)

            try:
                stream.seek(stream.tell())
            except (NotImplementedError, AttributeError):
                raise ValueError(error_message)

        # Write the content to the user stream
        stream.write(self._current_content)
        if self._download_complete:
            return self.size

        data_end = self._file_size
        if self._end_range is not None:
            # Use the length unless it is over the end of the file
            data_end = min(self._file_size, self._end_range + 1)

        downloader = _AsyncChunkDownloader(
            client=self._clients.blob,
            non_empty_ranges=self._non_empty_ranges,
            total_size=self.size,
            chunk_size=self._config.max_chunk_get_size,
            current_progress=self._first_get_size,
            start_range=self._initial_range[1] + 1,  # start where the first download ended
            end_range=data_end,
            stream=stream,
            parallel=parallel,
            validate_content=self._validate_content,
            encryption_options=self._encryption_options,
            use_location=self._location_mode,
            **self._request_options)

        dl_tasks = downloader.get_chunk_offsets()
        running_futures = [
            asyncio.ensure_future(downloader.process_chunk(d))
            for d in islice(dl_tasks, 0, self._max_concurrency)
        ]
        while running_futures:
            # Wait for some download to finish before adding a new one
            done, running_futures = await asyncio.wait(
                running_futures, return_when=asyncio.FIRST_COMPLETED)
            # Bug fix: exceptions raised by completed chunk downloads were
            # previously never retrieved, silently yielding partial data.
            self._propagate_task_errors(done, running_futures)
            try:
                next_chunk = next(dl_tasks)
            except StopIteration:
                break
            else:
                running_futures.add(asyncio.ensure_future(downloader.process_chunk(next_chunk)))

        if running_futures:
            # Wait for the remaining downloads to finish and surface any errors.
            done, _ = await asyncio.wait(running_futures)
            self._propagate_task_errors(done, ())
        return self.size

    async def download_to_stream(self, stream, max_concurrency=1):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :returns: The properties of the downloaded blob.
        :rtype: Any
        """
        warnings.warn(
            "download_to_stream is deprecated, use readinto instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        await self.readinto(stream)
        return self.properties
492