# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import sys
import threading
import warnings
from io import BytesIO

from azure.core.exceptions import HttpResponseError
from azure.core.tracing.common import with_current_context
from ._shared.encryption import decrypt_blob
from ._shared.request_handlers import validate_and_format_range_headers
from ._shared.response_handlers import process_storage_error, parse_length_from_content_range
from ._deserialize import get_page_ranges_result


def process_range_and_offset(start_range, end_range, length, encryption):
    start_offset, end_offset = 0, 0
    if encryption.get("key") is not None or encryption.get("resolver") is not None:
        if start_range is not None:
            # Align the start of the range along a 16-byte block
            start_offset = start_range % 16
            start_range -= start_offset

            # Include an extra 16 bytes for the IV if necessary
            # Because of the previous offsetting, start_range will always
            # be a multiple of 16.
            if start_range > 0:
                start_offset += 16
                start_range -= 16

        if length is not None:
            # Align the end of the range along a 16-byte block
            end_offset = 15 - (end_range % 16)
            end_range += end_offset

    return (start_range, end_range), (start_offset, end_offset)
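
# Worked example (illustrative only, not part of the library): with
# client-side encryption enabled, a request for bytes 100-130 is widened to
# whole 16-byte AES blocks, plus one extra block to capture the IV. Roughly:
#   start_offset = 100 % 16 = 4  -> start_range = 96
#   start_range > 0, so include the IV block: start_offset = 20, start_range = 80
#   end_offset = 15 - (130 % 16) = 13 -> end_range = 143
# The service is asked for bytes 80-143; after decryption, start_offset bytes
# are trimmed from the front and end_offset bytes from the back, recovering
# exactly bytes 100-130.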


def process_content(data, start_offset, end_offset, encryption):
    if data is None:
        raise ValueError("Response cannot be None.")
    try:
        content = b"".join(list(data))
    except Exception as error:
        raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error)
    if content and (encryption.get("key") is not None or encryption.get("resolver") is not None):
        try:
            return decrypt_blob(
                encryption.get("required"),
                encryption.get("key"),
                encryption.get("resolver"),
                content,
                start_offset,
                end_offset,
                data.response.headers,
            )
        except Exception as error:
            raise HttpResponseError(message="Decryption failed.", response=data.response, error=error)
    return content


class _ChunkDownloader(object):  # pylint: disable=too-many-instance-attributes
    def __init__(
        self,
        client=None,
        non_empty_ranges=None,
        total_size=None,
        chunk_size=None,
        current_progress=None,
        start_range=None,
        end_range=None,
        stream=None,
        parallel=None,
        validate_content=None,
        encryption_options=None,
        **kwargs
    ):
        self.client = client
        self.non_empty_ranges = non_empty_ranges

        # Information on the download range/chunk size
        self.chunk_size = chunk_size
        self.total_size = total_size
        self.start_index = start_range
        self.end_index = end_range

        # The destination that we will write to
        self.stream = stream
        self.stream_lock = threading.Lock() if parallel else None
        self.progress_lock = threading.Lock() if parallel else None

        # For a parallel download, the stream is always seekable, so we note down the current position
        # in order to seek to the right place when out-of-order chunks come in
        self.stream_start = stream.tell() if parallel else None

        # Download progress so far
        self.progress_total = current_progress

        # Encryption
        self.encryption_options = encryption_options

        # Parameters for each get operation
        self.validate_content = validate_content
        self.request_options = kwargs

    def _calculate_range(self, chunk_start):
        if chunk_start + self.chunk_size > self.end_index:
            chunk_end = self.end_index
        else:
            chunk_end = chunk_start + self.chunk_size
        return chunk_start, chunk_end

    def get_chunk_offsets(self):
        index = self.start_index
        while index < self.end_index:
            yield index
            index += self.chunk_size
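
    # Example (illustrative): with start_index=0, end_index=10 and
    # chunk_size=4, get_chunk_offsets yields 0, 4 and 8, and _calculate_range
    # clamps the final chunk to (8, 10).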

    def process_chunk(self, chunk_start):
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        chunk_data = self._download_chunk(chunk_start, chunk_end - 1)
        length = chunk_end - chunk_start
        if length > 0:
            self._write_to_stream(chunk_data, chunk_start)
            self._update_progress(length)

    def yield_chunk(self, chunk_start):
        chunk_start, chunk_end = self._calculate_range(chunk_start)
        return self._download_chunk(chunk_start, chunk_end - 1)

    def _update_progress(self, length):
        if self.progress_lock:
            with self.progress_lock:  # pylint: disable=not-context-manager
                self.progress_total += length
        else:
            self.progress_total += length

    def _write_to_stream(self, chunk_data, chunk_start):
        if self.stream_lock:
            with self.stream_lock:  # pylint: disable=not-context-manager
                self.stream.seek(self.stream_start + (chunk_start - self.start_index))
                self.stream.write(chunk_data)
        else:
            self.stream.write(chunk_data)

    def _do_optimize(self, given_range_start, given_range_end):
        # If there is no page range list stored, assume there is data everywhere:
        # either this is a page blob with no range information, or a block or
        # append blob.
        if self.non_empty_ranges is None:
            return False

        for source_range in self.non_empty_ranges:
            # Case 1: Since the range list is sorted, reaching a source_range past
            # the given range means every earlier source_range has been checked
            # without finding an overlap. The given range therefore holds no data
            # and the download optimization can be applied.
            # given range:      |   |
            # source range:              |   |
            if given_range_end < source_range['start']:  # pylint:disable=no-else-return
                return True
            # Case 2: The given range comes after source_range; keep checking.
            # given range:               |   |
            # source range:     |   |
            elif source_range['end'] < given_range_start:
                pass
            # Case 3: source_range and the given range overlap somehow; no optimization.
            else:
                return False
        # Checked all source ranges without finding an overlap, so the
        # optimization applies.
        return True
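
    # Example (illustrative): for a sparse page blob where
    # non_empty_ranges == [{'start': 0, 'end': 511}, {'start': 1024, 'end': 1535}],
    # a chunk covering bytes 512-1023 lies entirely in the gap, so _do_optimize
    # returns True and _download_chunk fills the chunk with zeros locally
    # instead of requesting it from the service.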

    def _download_chunk(self, chunk_start, chunk_end):
        download_range, offset = process_range_and_offset(
            chunk_start, chunk_end, chunk_end, self.encryption_options
        )

        # If the chunk contains no data, skip the request to the service and
        # create the empty chunk locally instead.
        if self._do_optimize(download_range[0], download_range[1]):
            chunk_data = b"\x00" * self.chunk_size
        else:
            range_header, range_validation = validate_and_format_range_headers(
                download_range[0],
                download_range[1],
                check_content_md5=self.validate_content
            )

            try:
                _, response = self.client.download(
                    range=range_header,
                    range_get_content_md5=range_validation,
                    validate_content=self.validate_content,
                    data_stream_total=self.total_size,
                    download_stream_current=self.progress_total,
                    **self.request_options
                )
            except HttpResponseError as error:
                process_storage_error(error)

            chunk_data = process_content(response, offset[0], offset[1], self.encryption_options)

            # Make sure that if_match is set so that we can validate
            # that subsequent downloads are to an unmodified blob
            if self.request_options.get("modified_access_conditions"):
                self.request_options["modified_access_conditions"].if_match = response.properties.etag

        return chunk_data


class _ChunkIterator(object):
    """Iterator for chunks in a blob download stream."""

    def __init__(self, size, content, downloader):
        self.size = size
        self._current_content = content
        self._iter_downloader = downloader
        self._iter_chunks = None
        self._complete = (size == 0)

    def __len__(self):
        return self.size

    def __iter__(self):
        return self

    def __next__(self):
        """Iterate through responses."""
        if self._complete:
            raise StopIteration("Download complete")
        if not self._iter_downloader:
            # If no iterator was supplied, the download completed with
            # the initial GET, so we just return that data
            self._complete = True
            return self._current_content

        if not self._iter_chunks:
            self._iter_chunks = self._iter_downloader.get_chunk_offsets()
        else:
            chunk = next(self._iter_chunks)
            self._current_content = self._iter_downloader.yield_chunk(chunk)

        return self._current_content

    next = __next__  # Python 2 compatibility.


class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
    """A streaming object to download from Azure Storage.

    :ivar str name:
        The name of the blob being downloaded.
    :ivar str container:
        The name of the container where the blob is.
    :ivar ~azure.storage.blob.BlobProperties properties:
        The properties of the blob being downloaded. If only a range of the data is being
        downloaded, this will be reflected in the properties.
    :ivar int size:
        The size of the total data in the stream. This will be the byte range if specified,
        otherwise the total size of the blob.
    """

    def __init__(
        self,
        clients=None,
        config=None,
        start_range=None,
        end_range=None,
        validate_content=None,
        encryption_options=None,
        max_concurrency=1,
        name=None,
        container=None,
        encoding=None,
        **kwargs
    ):
        self.name = name
        self.container = container
        self.properties = None
        self.size = None

        self._clients = clients
        self._config = config
        self._start_range = start_range
        self._end_range = end_range
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        self._validate_content = validate_content
        self._encryption_options = encryption_options or {}
        self._request_options = kwargs
        self._location_mode = None
        self._download_complete = False
        self._current_content = None
        self._file_size = None
        self._non_empty_ranges = None
        self._response = None

        # The service only provides transactional MD5s for chunks under 4MB.
        # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first
        # chunk so a transactional MD5 can be retrieved.
        self._first_get_size = (
            self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size
        )
        initial_request_start = self._start_range if self._start_range is not None else 0
        if self._end_range is not None and self._end_range - initial_request_start < self._first_get_size:
            initial_request_end = self._end_range
        else:
            initial_request_end = initial_request_start + self._first_get_size - 1

        self._initial_range, self._initial_offset = process_range_and_offset(
            initial_request_start, initial_request_end, self._end_range, self._encryption_options
        )

        self._response = self._initial_request()
        self.properties = self._response.properties
        self.properties.name = self.name
        self.properties.container = self.container

        # Set the content length to the download size instead of the size of
        # the last range
        self.properties.size = self.size

        # Overwrite the content range to the user requested range
        self.properties.content_range = "bytes {0}-{1}/{2}".format(
            self._start_range,
            self._end_range,
            self._file_size
        )

        # Overwrite the content MD5 as it is the MD5 for the last range instead
        # of the stored MD5
        # TODO: Set to the stored MD5 when the service returns this
        self.properties.content_md5 = None

        if self.size == 0:
            self._current_content = b""
        else:
            self._current_content = process_content(
                self._response,
                self._initial_offset[0],
                self._initial_offset[1],
                self._encryption_options
            )

    def __len__(self):
        return self.size

    def _initial_request(self):
        range_header, range_validation = validate_and_format_range_headers(
            self._initial_range[0],
            self._initial_range[1],
            start_range_required=False,
            end_range_required=False,
            check_content_md5=self._validate_content
        )

        try:
            location_mode, response = self._clients.blob.download(
                range=range_header,
                range_get_content_md5=range_validation,
                validate_content=self._validate_content,
                data_stream_total=None,
                download_stream_current=0,
                **self._request_options
            )

            # Check the location we read from to ensure we use the same one
            # for subsequent requests.
            self._location_mode = location_mode

            # Parse the total file size and adjust the download size if ranges
            # were specified
            self._file_size = parse_length_from_content_range(response.properties.content_range)
            if self._end_range is not None:
                # Use the end range index unless it is over the end of the file
                self.size = min(self._file_size, self._end_range - self._start_range + 1)
            elif self._start_range is not None:
                self.size = self._file_size - self._start_range
            else:
                self.size = self._file_size

        except HttpResponseError as error:
            if self._start_range is None and error.response.status_code == 416:
                # Get range will fail on an empty file. If the user did not
                # request a range, do a regular get request in order to get
                # any properties.
                try:
                    _, response = self._clients.blob.download(
                        validate_content=self._validate_content,
                        data_stream_total=0,
                        download_stream_current=0,
                        **self._request_options
                    )
                except HttpResponseError as error:
                    process_storage_error(error)

                # Set the download size to empty
                self.size = 0
                self._file_size = 0
            else:
                process_storage_error(error)

        # Get page ranges to optimize downloading a sparse page blob
        if response.properties.blob_type == 'PageBlob':
            try:
                page_ranges = self._clients.page_blob.get_page_ranges()
                self._non_empty_ranges = get_page_ranges_result(page_ranges)[0]
            # According to the REST API documentation, a Get Page Ranges request
            # can fail due to an internal server timeout on a highly fragmented
            # page blob with a large number of writes. If the page blob is not
            # sparse, it is acceptable for this call to fail.
            except HttpResponseError:
                pass

        # If the file is small, the download is complete at this point.
        # If file size is large, download the rest of the file in chunks.
        if response.properties.size != self.size:
            # Lock on the etag. This can be overridden by the user by specifying '*'
            if self._request_options.get("modified_access_conditions"):
                if not self._request_options["modified_access_conditions"].if_match:
                    self._request_options["modified_access_conditions"].if_match = response.properties.etag
        else:
            self._download_complete = True
        return response

    def chunks(self):
        """Iterate over chunks in the download stream.

        :rtype: Iterable[bytes]
        """
        if self.size == 0 or self._download_complete:
            iter_downloader = None
        else:
            data_end = self._file_size
            if self._end_range is not None:
                # Use the end range index unless it is over the end of the file
                data_end = min(self._file_size, self._end_range + 1)
            iter_downloader = _ChunkDownloader(
                client=self._clients.blob,
                non_empty_ranges=self._non_empty_ranges,
                total_size=self.size,
                chunk_size=self._config.max_chunk_get_size,
                current_progress=self._first_get_size,
                start_range=self._initial_range[1] + 1,  # Start where the first download ended
                end_range=data_end,
                stream=None,
                parallel=False,
                validate_content=self._validate_content,
                encryption_options=self._encryption_options,
                use_location=self._location_mode,
                **self._request_options
            )
        return _ChunkIterator(
            size=self.size,
            content=self._current_content,
            downloader=iter_downloader)
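
    # Usage sketch (illustrative; assumes blob_client is an
    # azure.storage.blob.BlobClient): stream a blob to disk chunk by chunk
    # without buffering the whole payload in memory.
    #
    #     downloader = blob_client.download_blob()
    #     with open("output.dat", "wb") as handle:
    #         for chunk in downloader.chunks():
    #             handle.write(chunk)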

    def readall(self):
        """Download the contents of this blob.

        This operation is blocking until all data is downloaded.

        :rtype: bytes or str
        """
        stream = BytesIO()
        self.readinto(stream)
        data = stream.getvalue()
        if self._encoding:
            return data.decode(self._encoding)
        return data
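
    # Usage sketch (illustrative; assumes blob_client is an
    # azure.storage.blob.BlobClient): readall returns bytes, or str when an
    # encoding was supplied to download_blob.
    #
    #     raw = blob_client.download_blob().readall()
    #     text = blob_client.download_blob(encoding="UTF-8").readall()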

    def content_as_bytes(self, max_concurrency=1):
        """Download the contents of this blob.

        This operation is blocking until all data is downloaded.

        :param int max_concurrency:
            The number of parallel connections with which to download.
        :rtype: bytes
        """
        warnings.warn(
            "content_as_bytes is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        return self.readall()

    def content_as_text(self, max_concurrency=1, encoding="UTF-8"):
        """Download the contents of this blob, and decode as text.

        This operation is blocking until all data is downloaded.

        :param int max_concurrency:
            The number of parallel connections with which to download.
        :param str encoding:
            The text encoding to use when decoding the downloaded bytes. Default is UTF-8.
        :rtype: str
        """
        warnings.warn(
            "content_as_text is deprecated, use readall instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        self._encoding = encoding
        return self.readall()

    def readinto(self, stream):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :returns: The number of bytes read.
        :rtype: int
        """
        # The stream must be seekable if parallel download is required
        parallel = self._max_concurrency > 1
        if parallel:
            error_message = "Target stream handle must be seekable."
            if sys.version_info >= (3,) and not stream.seekable():
                raise ValueError(error_message)

            try:
                stream.seek(stream.tell())
            except (NotImplementedError, AttributeError):
                raise ValueError(error_message)

        # Write the content to the user stream
        stream.write(self._current_content)
        if self._download_complete:
            return self.size

        data_end = self._file_size
        if self._end_range is not None:
            # Use the length unless it is over the end of the file
            data_end = min(self._file_size, self._end_range + 1)

        downloader = _ChunkDownloader(
            client=self._clients.blob,
            non_empty_ranges=self._non_empty_ranges,
            total_size=self.size,
            chunk_size=self._config.max_chunk_get_size,
            current_progress=self._first_get_size,
            start_range=self._initial_range[1] + 1,  # Start where the first download ended
            end_range=data_end,
            stream=stream,
            parallel=parallel,
            validate_content=self._validate_content,
            encryption_options=self._encryption_options,
            use_location=self._location_mode,
            **self._request_options
        )
        if parallel:
            import concurrent.futures
            executor = concurrent.futures.ThreadPoolExecutor(self._max_concurrency)
            list(executor.map(
                    with_current_context(downloader.process_chunk),
                    downloader.get_chunk_offsets()
                ))
        else:
            for chunk in downloader.get_chunk_offsets():
                downloader.process_chunk(chunk)
        return self.size
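
    # Usage sketch (illustrative; assumes blob_client is an
    # azure.storage.blob.BlobClient): a parallel download needs a seekable
    # target, such as a local file opened in binary mode.
    #
    #     with open("output.dat", "wb") as handle:
    #         blob_client.download_blob(max_concurrency=4).readinto(handle)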

    def download_to_stream(self, stream, max_concurrency=1):
        """Download the contents of this blob to a stream.

        :param stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream. The stream must be seekable if the download
            uses more than one parallel connection.
        :param int max_concurrency:
            The number of parallel connections with which to download.
        :returns: The properties of the downloaded blob.
        :rtype: Any
        """
        warnings.warn(
            "download_to_stream is deprecated, use readinto instead",
            DeprecationWarning
        )
        self._max_concurrency = max_concurrency
        self.readinto(stream)
        return self.properties