# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# pylint: disable=no-self-use

from concurrent import futures
from io import (BytesIO, IOBase, SEEK_CUR, SEEK_END, SEEK_SET, UnsupportedOperation)
from threading import Lock
from itertools import islice
from math import ceil

import six

from azure.core.tracing.common import with_current_context

from . import encode_base64, url_quote
from .request_handlers import get_length
from .response_handlers import return_response_headers
from .encryption import get_blob_encryptor_and_padder


_LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE = 4 * 1024 * 1024
_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM = "{0} should be a seekable file-like/io.IOBase type stream object."


def _parallel_uploads(executor, uploader, pending, running):
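    """Run chunk uploads on the executor, keeping a fixed window in flight.

    As each submitted future completes, the next pending chunk is submitted,
    so at most the initial number of uploads is running at any one time.
    Returns the results of all completed uploads.
    """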
    range_ids = []
    while True:
        # Wait for some upload to finish before adding a new one.
        done, running = futures.wait(running, return_when=futures.FIRST_COMPLETED)
        range_ids.extend([chunk.result() for chunk in done])
        try:
            for _ in range(len(done)):
                next_chunk = next(pending)
                running.add(executor.submit(with_current_context(uploader), next_chunk))
        except StopIteration:
            break

    # Wait for the remaining uploads to finish.
    done, _running = futures.wait(running)
    range_ids.extend([chunk.result() for chunk in done])
    return range_ids


def upload_data_chunks(
        service=None,
        uploader_class=None,
        total_size=None,
        chunk_size=None,
        max_concurrency=None,
        stream=None,
        validate_content=None,
        encryption_options=None,
        **kwargs):
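    """Read `stream` in `chunk_size` pieces and upload each chunk with the
    given uploader class (e.g. BlockBlobChunkUploader or PageBlobChunkUploader).

    When `max_concurrency` is greater than 1, chunks are uploaded on a thread
    pool; otherwise they are uploaded sequentially. If encryption options are
    provided, each chunk is passed through the configured encryptor/padder
    before upload. Returns the range ids sorted by offset, or the uploader's
    response headers when no range ids are produced.
    """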

    if encryption_options:
        encryptor, padder = get_blob_encryptor_and_padder(
            encryption_options.get('cek'),
            encryption_options.get('vector'),
            uploader_class is not PageBlobChunkUploader)
        kwargs['encryptor'] = encryptor
        kwargs['padder'] = padder

    parallel = max_concurrency > 1
    if parallel and 'modified_access_conditions' in kwargs:
        # Access conditions do not work with parallelism
        kwargs['modified_access_conditions'] = None

    uploader = uploader_class(
        service=service,
        total_size=total_size,
        chunk_size=chunk_size,
        stream=stream,
        parallel=parallel,
        validate_content=validate_content,
        **kwargs)
    if parallel:
        executor = futures.ThreadPoolExecutor(max_concurrency)
        upload_tasks = uploader.get_chunk_streams()
        running_futures = [
            executor.submit(with_current_context(uploader.process_chunk), u)
            for u in islice(upload_tasks, 0, max_concurrency)
        ]
        range_ids = _parallel_uploads(executor, uploader.process_chunk, upload_tasks, running_futures)
    else:
        range_ids = [uploader.process_chunk(result) for result in uploader.get_chunk_streams()]
    if any(range_ids):
        return [r[1] for r in sorted(range_ids, key=lambda r: r[0])]
    return uploader.response_headers


def upload_substream_blocks(
        service=None,
        uploader_class=None,
        total_size=None,
        chunk_size=None,
        max_concurrency=None,
        stream=None,
        **kwargs):
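    """Split `stream` into SubStream block views and upload each block with
    the given uploader class.

    Blocks are uploaded on a thread pool when `max_concurrency` is greater
    than 1, otherwise sequentially. Returns the sorted list of block ids.
    """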
    parallel = max_concurrency > 1
    if parallel and 'modified_access_conditions' in kwargs:
        # Access conditions do not work with parallelism
        kwargs['modified_access_conditions'] = None
    uploader = uploader_class(
        service=service,
        total_size=total_size,
        chunk_size=chunk_size,
        stream=stream,
        parallel=parallel,
        **kwargs)

    if parallel:
        executor = futures.ThreadPoolExecutor(max_concurrency)
        upload_tasks = uploader.get_substream_blocks()
        running_futures = [
            executor.submit(with_current_context(uploader.process_substream_block), u)
            for u in islice(upload_tasks, 0, max_concurrency)
        ]
        range_ids = _parallel_uploads(executor, uploader.process_substream_block, upload_tasks, running_futures)
    else:
        range_ids = [uploader.process_substream_block(b) for b in uploader.get_substream_blocks()]
    return sorted(range_ids)


class _ChunkUploader(object):  # pylint: disable=too-many-instance-attributes
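    """Base class for chunk uploaders.

    Splits a stream into chunks (or substream blocks), uploads them via the
    subclass-specific ``_upload_chunk``/``_upload_substream_block`` methods,
    and tracks upload progress. Locks protect the shared stream and the
    progress counter when uploads run in parallel.
    """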

    def __init__(self, service, total_size, chunk_size, stream, parallel, encryptor=None, padder=None, **kwargs):
        self.service = service
        self.total_size = total_size
        self.chunk_size = chunk_size
        self.stream = stream
        self.parallel = parallel

        # Stream management
        self.stream_start = stream.tell() if parallel else None
        self.stream_lock = Lock() if parallel else None

        # Progress feedback
        self.progress_total = 0
        self.progress_lock = Lock() if parallel else None

        # Encryption
        self.encryptor = encryptor
        self.padder = padder
        self.response_headers = None
        self.etag = None
        self.last_modified = None
        self.request_options = kwargs

    def get_chunk_streams(self):
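        """Yield ``(offset, data)`` tuples of at most ``chunk_size`` bytes read
        from the stream, applying the padder/encryptor to each chunk when
        encryption is configured. The final (short) chunk is finalized and
        yielded only if it is non-empty.
        """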
        index = 0
        while True:
            data = b""
            read_size = self.chunk_size

            # Buffer until we either reach the end of the stream or get a whole chunk.
            while True:
                if self.total_size:
                    read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data)))
                temp = self.stream.read(read_size)
                if not isinstance(temp, six.binary_type):
                    raise TypeError("Blob data should be of type bytes.")
                data += temp or b""

                # We have either read an empty string, and so are at the end
                # of the stream, or we have read a full chunk.
                if temp == b"" or len(data) == self.chunk_size:
                    break

            if len(data) == self.chunk_size:
                if self.padder:
                    data = self.padder.update(data)
                if self.encryptor:
                    data = self.encryptor.update(data)
                yield index, data
            else:
                if self.padder:
                    data = self.padder.update(data) + self.padder.finalize()
                if self.encryptor:
                    data = self.encryptor.update(data) + self.encryptor.finalize()
                if data:
                    yield index, data
                break
            index += len(data)

    def process_chunk(self, chunk_data):
        chunk_bytes = chunk_data[1]
        chunk_offset = chunk_data[0]
        return self._upload_chunk_with_progress(chunk_offset, chunk_bytes)

    def _update_progress(self, length):
        if self.progress_lock is not None:
            with self.progress_lock:
                self.progress_total += length
        else:
            self.progress_total += length

    def _upload_chunk(self, chunk_offset, chunk_data):
        raise NotImplementedError("Must be implemented by child class.")

    def _upload_chunk_with_progress(self, chunk_offset, chunk_data):
        range_id = self._upload_chunk(chunk_offset, chunk_data)
        self._update_progress(len(chunk_data))
        return range_id

    def get_substream_blocks(self):
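        """Yield ``(block_id, SubStream)`` tuples covering the stream in
        ``chunk_size`` slices; the last slice holds whatever remains. The
        shared stream lock is passed to each SubStream so parallel readers
        can seek and read the wrapped stream safely.
        """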
        assert self.chunk_size is not None
        lock = self.stream_lock
        blob_length = self.total_size

        if blob_length is None:
            blob_length = get_length(self.stream)
            if blob_length is None:
                raise ValueError("Unable to determine content length of upload data.")

        blocks = int(ceil(blob_length / (self.chunk_size * 1.0)))
        last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size

        for i in range(blocks):
            index = i * self.chunk_size
            length = last_block_size if i == blocks - 1 else self.chunk_size
            yield ('BlockId{}'.format("%05d" % i), SubStream(self.stream, index, length, lock))

    def process_substream_block(self, block_data):
        return self._upload_substream_block_with_progress(block_data[0], block_data[1])

    def _upload_substream_block(self, block_id, block_stream):
        raise NotImplementedError("Must be implemented by child class.")

    def _upload_substream_block_with_progress(self, block_id, block_stream):
        range_id = self._upload_substream_block(block_id, block_stream)
        self._update_progress(len(block_stream))
        return range_id

    def set_response_properties(self, resp):
        self.etag = resp.etag
        self.last_modified = resp.last_modified


class BlockBlobChunkUploader(_ChunkUploader):
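    """Chunk uploader for block blobs: each chunk or substream block is staged
    with ``stage_block`` and identified by a generated block id.
    """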

    def __init__(self, *args, **kwargs):
        kwargs.pop("modified_access_conditions", None)
        super(BlockBlobChunkUploader, self).__init__(*args, **kwargs)
        self.current_length = None

    def _upload_chunk(self, chunk_offset, chunk_data):
        # TODO: This is incorrect, but works with recording.
        index = '{0:032d}'.format(chunk_offset)
        block_id = encode_base64(url_quote(encode_base64(index)))
        self.service.stage_block(
            block_id,
            len(chunk_data),
            chunk_data,
            data_stream_total=self.total_size,
            upload_stream_current=self.progress_total,
            **self.request_options
        )
        return index, block_id

    def _upload_substream_block(self, block_id, block_stream):
        try:
            self.service.stage_block(
                block_id,
                len(block_stream),
                block_stream,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )
        finally:
            block_stream.close()
        return block_id


class PageBlobChunkUploader(_ChunkUploader):  # pylint: disable=abstract-method
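    """Chunk uploader for page blobs: chunks are written to their byte range
    with ``upload_pages``, and all-zero chunks are skipped since empty pages
    do not need to be uploaded.
    """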

    def _is_chunk_empty(self, chunk_data):
        # Read until a non-zero byte is encountered. If the end is reached
        # without finding one, chunk_data consists entirely of zeros.
        return not any(bytearray(chunk_data))

    def _upload_chunk(self, chunk_offset, chunk_data):
        # Avoid uploading pages that are entirely empty.
        if not self._is_chunk_empty(chunk_data):
            chunk_end = chunk_offset + len(chunk_data) - 1
            content_range = "bytes={0}-{1}".format(chunk_offset, chunk_end)
            computed_md5 = None
            self.response_headers = self.service.upload_pages(
                chunk_data,
                content_length=len(chunk_data),
                transactional_content_md5=computed_md5,
                range=content_range,
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )

            if not self.parallel and self.request_options.get('modified_access_conditions'):
                self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']


class AppendBlobChunkUploader(_ChunkUploader):  # pylint: disable=abstract-method
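    """Chunk uploader for append blobs: chunks are appended with
    ``append_block``, and the append position condition is updated from the
    first response so later appends target the expected offset.
    """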

    def __init__(self, *args, **kwargs):
        super(AppendBlobChunkUploader, self).__init__(*args, **kwargs)
        self.current_length = None

    def _upload_chunk(self, chunk_offset, chunk_data):
        if self.current_length is None:
            self.response_headers = self.service.append_block(
                chunk_data,
                content_length=len(chunk_data),
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )
            self.current_length = int(self.response_headers["blob_append_offset"])
        else:
            self.request_options['append_position_access_conditions'].append_position = \
                self.current_length + chunk_offset
            self.response_headers = self.service.append_block(
                chunk_data,
                content_length=len(chunk_data),
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )


class FileChunkUploader(_ChunkUploader):  # pylint: disable=abstract-method
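    """Chunk uploader for file shares: each chunk is written to its byte range
    with ``upload_range`` and the resulting range string is returned as the
    range id.
    """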

    def _upload_chunk(self, chunk_offset, chunk_data):
        length = len(chunk_data)
        chunk_end = chunk_offset + length - 1
        response = self.service.upload_range(
            chunk_data,
            chunk_offset,
            length,
            data_stream_total=self.total_size,
            upload_stream_current=self.progress_total,
            **self.request_options
        )
        return 'bytes={0}-{1}'.format(chunk_offset, chunk_end), response


class SubStream(IOBase):
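    """A read-only, seekable view over a window of a wrapped seekable stream.

    Each SubStream tracks its own position within its window and buffers up to
    4MB of data at a time. When a lock is provided (parallel uploads), the
    seek-and-read against the wrapped stream is performed atomically so that
    multiple SubStreams can safely share the same underlying stream.
    """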

    def __init__(self, wrapped_stream, stream_begin_index, length, lockObj):
        # Python 2.7: file-like objects created with open() typically support seek(), but are not
        # derivations of io.IOBase and thus do not implement seekable().
        # Python 3.x: file-like objects created with open() are derived from io.IOBase.
        try:
            # Only the main thread runs this, so there is no need to grab the lock.
            wrapped_stream.seek(0, SEEK_CUR)
        except Exception:
            raise ValueError("Wrapped stream must support seek().")

        self._lock = lockObj
        self._wrapped_stream = wrapped_stream
        self._position = 0
        self._stream_begin_index = stream_begin_index
        self._length = length
        self._buffer = BytesIO()

        # We must avoid buffering more than necessary, and also not use up too much memory,
        # so the max buffer size is capped at 4MB.
        self._max_buffer_size = (
            length if length < _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE else _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE
        )
        self._current_buffer_start = 0
        self._current_buffer_size = 0
        super(SubStream, self).__init__()

    def __len__(self):
        return self._length

    def close(self):
        if self._buffer:
            self._buffer.close()
        self._wrapped_stream = None
        IOBase.close(self)

    def fileno(self):
        return self._wrapped_stream.fileno()

    def flush(self):
        pass

    def read(self, size=None):
        if self.closed:  # pylint: disable=using-constant-test
            raise ValueError("Stream is closed.")

        if size is None:
            size = self._length - self._position

        # adjust if out of bounds
        if size + self._position >= self._length:
            size = self._length - self._position

        # return fast
        if size == 0 or self._buffer.closed:
            return b""

        # attempt first read from the read buffer and update position
        read_buffer = self._buffer.read(size)
        bytes_read = len(read_buffer)
        bytes_remaining = size - bytes_read
        self._position += bytes_read

        # repopulate the read buffer from the underlying stream to fulfill the request
        # ensure the seek and read operations are done atomically (only if a lock is provided)
        if bytes_remaining > 0:
            with self._buffer:
                # either read in the max buffer size specified on the class
                # or read in just enough data for the current block/sub stream
                current_max_buffer_size = min(self._max_buffer_size, self._length - self._position)

                # lock is only defined if max_concurrency > 1 (parallel uploads)
                if self._lock:
                    with self._lock:
                        # reposition the underlying stream to match the start of the data to read
                        absolute_position = self._stream_begin_index + self._position
                        self._wrapped_stream.seek(absolute_position, SEEK_SET)
                        # If we can't seek to the right location, our read will be corrupted so fail fast.
                        if self._wrapped_stream.tell() != absolute_position:
                            raise IOError("Stream failed to seek to the desired location.")
                        buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)
                else:
                    buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)

            if buffer_from_stream:
                # update the buffer with new data from the wrapped stream
                # we need to note down the start position and size of the buffer, in case seek is performed later
                self._buffer = BytesIO(buffer_from_stream)
                self._current_buffer_start = self._position
                self._current_buffer_size = len(buffer_from_stream)

                # read the remaining bytes from the new buffer and update position
                second_read_buffer = self._buffer.read(bytes_remaining)
                read_buffer += second_read_buffer
                self._position += len(second_read_buffer)

        return read_buffer

    def readable(self):
        return True

    def readinto(self, b):
        raise UnsupportedOperation

    def seek(self, offset, whence=0):
        if whence == SEEK_SET:
            start_index = 0
        elif whence == SEEK_CUR:
            start_index = self._position
        elif whence == SEEK_END:
            start_index = self._length
            offset = -offset
        else:
            raise ValueError("Invalid argument for the 'whence' parameter.")

        pos = start_index + offset

        if pos > self._length:
            pos = self._length
        elif pos < 0:
            pos = 0

        # Check whether the buffered data still covers the new position;
        # if not, drop the buffer.
        if pos < self._current_buffer_start or pos >= self._current_buffer_start + self._current_buffer_size:
            self._buffer.close()
            self._buffer = BytesIO()
        else:  # otherwise seek to the corresponding position within the buffer
            delta = pos - self._current_buffer_start
            self._buffer.seek(delta, SEEK_SET)

        self._position = pos
        return pos

    def seekable(self):
        return True

    def tell(self):
        return self._position

    def write(self, b):
        raise UnsupportedOperation

    def writelines(self, lines):
        raise UnsupportedOperation

    def writable(self):
        return False


class IterStreamer(object):
    """
    File-like streaming iterator.
    """

    def __init__(self, generator, encoding="UTF-8"):
        self.generator = generator
        self.iterator = iter(generator)
        self.leftover = b""
        self.encoding = encoding

    def __len__(self):
        return self.generator.__len__()

    def __iter__(self):
        return self.iterator

    def seekable(self):
        return False

    def __next__(self):
        return next(self.iterator)

    next = __next__  # Python 2 compatibility.

    def tell(self, *args, **kwargs):
        raise UnsupportedOperation("Data generator does not support tell.")

    def seek(self, *args, **kwargs):
        raise UnsupportedOperation("Data generator is unseekable.")

    def read(self, size):
        data = self.leftover
        count = len(self.leftover)
        try:
            while count < size:
                chunk = self.__next__()
                if isinstance(chunk, six.text_type):
                    chunk = chunk.encode(self.encoding)
                data += chunk
                count += len(chunk)
        except StopIteration:
            # The generator is exhausted; clear the leftover so repeated reads
            # do not return stale data.
            self.leftover = b""

        if count > size:
            self.leftover = data[size:]

        return data[:size]