1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for
4# license information.
5# --------------------------------------------------------------------------
6# pylint: disable=no-self-use
7
8from concurrent import futures
9from io import (BytesIO, IOBase, SEEK_CUR, SEEK_END, SEEK_SET, UnsupportedOperation)
10from threading import Lock
11from itertools import islice
12from math import ceil
13
14import six
15
16from azure.core.tracing.common import with_current_context
17
18from . import encode_base64, url_quote
19from .request_handlers import get_length
20from .response_handlers import return_response_headers
21from .encryption import get_blob_encryptor_and_padder
22
23
24_LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE = 4 * 1024 * 1024
25_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM = "{0} should be a seekable file-like/io.IOBase type stream object."
26
27
28def _parallel_uploads(executor, uploader, pending, running):
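    """Drain the ``pending`` iterator of chunk uploads through ``executor``.

    As each running future completes, a new chunk is submitted so the number
    of in-flight uploads stays bounded; once ``pending`` is exhausted, the
    remaining futures are awaited and the collected results returned.
    """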
29    range_ids = []
30    while True:
        # Wait for some upload to finish before adding a new one.
32        done, running = futures.wait(running, return_when=futures.FIRST_COMPLETED)
33        range_ids.extend([chunk.result() for chunk in done])
34        try:
35            for _ in range(0, len(done)):
36                next_chunk = next(pending)
37                running.add(executor.submit(with_current_context(uploader), next_chunk))
38        except StopIteration:
39            break
40
41    # Wait for the remaining uploads to finish
42    done, _running = futures.wait(running)
43    range_ids.extend([chunk.result() for chunk in done])
44    return range_ids
45
46
47def upload_data_chunks(
48        service=None,
49        uploader_class=None,
50        total_size=None,
51        chunk_size=None,
52        max_concurrency=None,
53        stream=None,
54        validate_content=None,
55        encryption_options=None,
56        **kwargs):
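    """Upload a stream in chunks using the given uploader class.

    Chunks are uploaded sequentially, or on a thread pool when
    ``max_concurrency`` is greater than one. Returns the uploaded range ids
    (e.g. block ids) sorted by offset when the uploader produces them,
    otherwise the response headers of the last request.
    """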
57
58    if encryption_options:
59        encryptor, padder = get_blob_encryptor_and_padder(
60            encryption_options.get('cek'),
61            encryption_options.get('vector'),
62            uploader_class is not PageBlobChunkUploader)
63        kwargs['encryptor'] = encryptor
64        kwargs['padder'] = padder
65
66    parallel = max_concurrency > 1
67    if parallel and 'modified_access_conditions' in kwargs:
68        # Access conditions do not work with parallelism
69        kwargs['modified_access_conditions'] = None
70
71    uploader = uploader_class(
72        service=service,
73        total_size=total_size,
74        chunk_size=chunk_size,
75        stream=stream,
76        parallel=parallel,
77        validate_content=validate_content,
78        **kwargs)
79    if parallel:
80        executor = futures.ThreadPoolExecutor(max_concurrency)
81        upload_tasks = uploader.get_chunk_streams()
82        running_futures = [
83            executor.submit(with_current_context(uploader.process_chunk), u)
84            for u in islice(upload_tasks, 0, max_concurrency)
85        ]
86        range_ids = _parallel_uploads(executor, uploader.process_chunk, upload_tasks, running_futures)
87    else:
88        range_ids = [uploader.process_chunk(result) for result in uploader.get_chunk_streams()]
89    if any(range_ids):
90        return [r[1] for r in sorted(range_ids, key=lambda r: r[0])]
91    return uploader.response_headers
92
93
94def upload_substream_blocks(
95        service=None,
96        uploader_class=None,
97        total_size=None,
98        chunk_size=None,
99        max_concurrency=None,
100        stream=None,
101        **kwargs):
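    """Upload a seekable stream as a series of substream blocks.

    Each block wraps a slice of ``stream`` in a :class:`SubStream`, avoiding
    buffering whole chunks in memory. Returns the sorted results of
    ``process_substream_block`` for every block.
    """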
102    parallel = max_concurrency > 1
103    if parallel and 'modified_access_conditions' in kwargs:
104        # Access conditions do not work with parallelism
105        kwargs['modified_access_conditions'] = None
106    uploader = uploader_class(
107        service=service,
108        total_size=total_size,
109        chunk_size=chunk_size,
110        stream=stream,
111        parallel=parallel,
112        **kwargs)
113
114    if parallel:
115        executor = futures.ThreadPoolExecutor(max_concurrency)
116        upload_tasks = uploader.get_substream_blocks()
117        running_futures = [
118            executor.submit(with_current_context(uploader.process_substream_block), u)
119            for u in islice(upload_tasks, 0, max_concurrency)
120        ]
121        range_ids = _parallel_uploads(executor, uploader.process_substream_block, upload_tasks, running_futures)
122    else:
123        range_ids = [uploader.process_substream_block(b) for b in uploader.get_substream_blocks()]
124    return sorted(range_ids)
125
126
127class _ChunkUploader(object):  # pylint: disable=too-many-instance-attributes
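    """Base class that reads a stream in chunks and uploads each chunk.

    Subclasses implement ``_upload_chunk`` and/or ``_upload_substream_block``
    for a specific service operation. This class handles chunk buffering,
    optional client-side encryption, progress tracking and the locking needed
    when chunks are uploaded in parallel.
    """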
128
129    def __init__(self, service, total_size, chunk_size, stream, parallel, encryptor=None, padder=None, **kwargs):
130        self.service = service
131        self.total_size = total_size
132        self.chunk_size = chunk_size
133        self.stream = stream
134        self.parallel = parallel
135
136        # Stream management
137        self.stream_start = stream.tell() if parallel else None
138        self.stream_lock = Lock() if parallel else None
139
140        # Progress feedback
141        self.progress_total = 0
142        self.progress_lock = Lock() if parallel else None
143
144        # Encryption
145        self.encryptor = encryptor
146        self.padder = padder
147        self.response_headers = None
148        self.etag = None
149        self.last_modified = None
150        self.request_options = kwargs
151
152    def get_chunk_streams(self):
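        """Yield ``(offset, data)`` chunks of at most ``chunk_size`` bytes read
        sequentially from the stream, applying the configured padder and
        encryptor to each chunk.
        """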
153        index = 0
154        while True:
155            data = b""
156            read_size = self.chunk_size
157
158            # Buffer until we either reach the end of the stream or get a whole chunk.
159            while True:
160                if self.total_size:
161                    read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data)))
162                temp = self.stream.read(read_size)
163                if not isinstance(temp, six.binary_type):
164                    raise TypeError("Blob data should be of type bytes.")
165                data += temp or b""
166
                # An empty read means we have reached the end of the stream;
                # otherwise stop once we have read a full chunk.
169                if temp == b"" or len(data) == self.chunk_size:
170                    break
171
172            if len(data) == self.chunk_size:
173                if self.padder:
174                    data = self.padder.update(data)
175                if self.encryptor:
176                    data = self.encryptor.update(data)
177                yield index, data
178            else:
179                if self.padder:
180                    data = self.padder.update(data) + self.padder.finalize()
181                if self.encryptor:
182                    data = self.encryptor.update(data) + self.encryptor.finalize()
183                if data:
184                    yield index, data
185                break
186            index += len(data)
187
188    def process_chunk(self, chunk_data):
189        chunk_bytes = chunk_data[1]
190        chunk_offset = chunk_data[0]
191        return self._upload_chunk_with_progress(chunk_offset, chunk_bytes)
192
193    def _update_progress(self, length):
194        if self.progress_lock is not None:
195            with self.progress_lock:
196                self.progress_total += length
197        else:
198            self.progress_total += length
199
200    def _upload_chunk(self, chunk_offset, chunk_data):
201        raise NotImplementedError("Must be implemented by child class.")
202
203    def _upload_chunk_with_progress(self, chunk_offset, chunk_data):
204        range_id = self._upload_chunk(chunk_offset, chunk_data)
205        self._update_progress(len(chunk_data))
206        return range_id
207
208    def get_substream_blocks(self):
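        """Yield ``(block_id, SubStream)`` pairs covering the whole stream,
        one per ``chunk_size`` slice; the final block may be shorter.
        """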
209        assert self.chunk_size is not None
210        lock = self.stream_lock
211        blob_length = self.total_size
212
213        if blob_length is None:
214            blob_length = get_length(self.stream)
215            if blob_length is None:
216                raise ValueError("Unable to determine content length of upload data.")
217
218        blocks = int(ceil(blob_length / (self.chunk_size * 1.0)))
219        last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size
220
221        for i in range(blocks):
222            index = i * self.chunk_size
223            length = last_block_size if i == blocks - 1 else self.chunk_size
            yield ('BlockId{0:05d}'.format(i), SubStream(self.stream, index, length, lock))
225
226    def process_substream_block(self, block_data):
227        return self._upload_substream_block_with_progress(block_data[0], block_data[1])
228
229    def _upload_substream_block(self, block_id, block_stream):
230        raise NotImplementedError("Must be implemented by child class.")
231
232    def _upload_substream_block_with_progress(self, block_id, block_stream):
233        range_id = self._upload_substream_block(block_id, block_stream)
234        self._update_progress(len(block_stream))
235        return range_id
236
237    def set_response_properties(self, resp):
238        self.etag = resp.etag
239        self.last_modified = resp.last_modified
240
241
242class BlockBlobChunkUploader(_ChunkUploader):
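    """Uploads each chunk as a staged block of a block blob."""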
243
244    def __init__(self, *args, **kwargs):
245        kwargs.pop("modified_access_conditions", None)
246        super(BlockBlobChunkUploader, self).__init__(*args, **kwargs)
247        self.current_length = None
248
249    def _upload_chunk(self, chunk_offset, chunk_data):
250        # TODO: This is incorrect, but works with recording.
251        index = '{0:032d}'.format(chunk_offset)
252        block_id = encode_base64(url_quote(encode_base64(index)))
253        self.service.stage_block(
254            block_id,
255            len(chunk_data),
256            chunk_data,
257            data_stream_total=self.total_size,
258            upload_stream_current=self.progress_total,
259            **self.request_options
260        )
261        return index, block_id
262
263    def _upload_substream_block(self, block_id, block_stream):
264        try:
265            self.service.stage_block(
266                block_id,
267                len(block_stream),
268                block_stream,
269                data_stream_total=self.total_size,
270                upload_stream_current=self.progress_total,
271                **self.request_options
272            )
273        finally:
274            block_stream.close()
275        return block_id
276
277
278class PageBlobChunkUploader(_ChunkUploader):  # pylint: disable=abstract-method
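    """Uploads each non-empty chunk as a page range of a page blob."""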
279
280    def _is_chunk_empty(self, chunk_data):
        # any() short-circuits at the first non-zero byte; if none is found,
        # the chunk contains only zeros and is considered empty.
283        return not any(bytearray(chunk_data))
284
285    def _upload_chunk(self, chunk_offset, chunk_data):
        # Skip pages that contain only zeros; they do not need to be uploaded.
287        if not self._is_chunk_empty(chunk_data):
288            chunk_end = chunk_offset + len(chunk_data) - 1
289            content_range = "bytes={0}-{1}".format(chunk_offset, chunk_end)
290            computed_md5 = None
291            self.response_headers = self.service.upload_pages(
292                chunk_data,
293                content_length=len(chunk_data),
294                transactional_content_md5=computed_md5,
295                range=content_range,
296                cls=return_response_headers,
297                data_stream_total=self.total_size,
298                upload_stream_current=self.progress_total,
299                **self.request_options
300            )
301
302            if not self.parallel and self.request_options.get('modified_access_conditions'):
303                self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
304
305
306class AppendBlobChunkUploader(_ChunkUploader):  # pylint: disable=abstract-method
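    """Appends each chunk to an append blob, tracking the append offset."""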
307
308    def __init__(self, *args, **kwargs):
309        super(AppendBlobChunkUploader, self).__init__(*args, **kwargs)
310        self.current_length = None
311
312    def _upload_chunk(self, chunk_offset, chunk_data):
313        if self.current_length is None:
314            self.response_headers = self.service.append_block(
315                chunk_data,
316                content_length=len(chunk_data),
317                cls=return_response_headers,
318                data_stream_total=self.total_size,
319                upload_stream_current=self.progress_total,
320                **self.request_options
321            )
322            self.current_length = int(self.response_headers["blob_append_offset"])
323        else:
324            self.request_options['append_position_access_conditions'].append_position = \
325                self.current_length + chunk_offset
326            self.response_headers = self.service.append_block(
327                chunk_data,
328                content_length=len(chunk_data),
329                cls=return_response_headers,
330                data_stream_total=self.total_size,
331                upload_stream_current=self.progress_total,
332                **self.request_options
333            )
334
335
336class DataLakeFileChunkUploader(_ChunkUploader):  # pylint: disable=abstract-method
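    """Appends each chunk to a Data Lake file at the chunk's offset."""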
337
338    def _upload_chunk(self, chunk_offset, chunk_data):
340        self.response_headers = self.service.append_data(
341            body=chunk_data,
342            position=chunk_offset,
343            content_length=len(chunk_data),
344            cls=return_response_headers,
345            data_stream_total=self.total_size,
346            upload_stream_current=self.progress_total,
347            **self.request_options
348        )
349
350        if not self.parallel and self.request_options.get('modified_access_conditions'):
351            self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
352
353
354class FileChunkUploader(_ChunkUploader):  # pylint: disable=abstract-method
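    """Uploads each chunk as a file range via the service's ``upload_range``."""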
355
356    def _upload_chunk(self, chunk_offset, chunk_data):
357        length = len(chunk_data)
358        chunk_end = chunk_offset + length - 1
359        response = self.service.upload_range(
360            chunk_data,
361            chunk_offset,
362            length,
363            data_stream_total=self.total_size,
364            upload_stream_current=self.progress_total,
365            **self.request_options
366        )
367        return 'bytes={0}-{1}'.format(chunk_offset, chunk_end), response
368
369
370class SubStream(IOBase):
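    """A read-only, seekable view over a slice of a wrapped stream.

    Reads are buffered up to 4MB at a time and, when a lock is provided, the
    seek and read on the wrapped stream are performed atomically so that
    multiple SubStream instances can safely share one underlying stream
    across threads.
    """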
371
372    def __init__(self, wrapped_stream, stream_begin_index, length, lockObj):
        # Python 2.7: file-like objects created with open() typically support seek(), but are not
        # derived from io.IOBase and thus do not implement seekable().
        # Python 3: file-like objects created with open() are derived from io.IOBase.
        try:
            # Only the main thread runs this, so there is no need to grab the lock.
            wrapped_stream.seek(0, SEEK_CUR)
        except Exception:
            raise ValueError("Wrapped stream must support seek().")
381
382        self._lock = lockObj
383        self._wrapped_stream = wrapped_stream
384        self._position = 0
385        self._stream_begin_index = stream_begin_index
386        self._length = length
387        self._buffer = BytesIO()
388
        # Avoid buffering more than necessary (and using excessive memory)
        # by capping the read buffer size at 4MB.
391        self._max_buffer_size = (
392            length if length < _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE else _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE
393        )
394        self._current_buffer_start = 0
395        self._current_buffer_size = 0
396        super(SubStream, self).__init__()
397
398    def __len__(self):
399        return self._length
400
401    def close(self):
402        if self._buffer:
403            self._buffer.close()
404        self._wrapped_stream = None
405        IOBase.close(self)
406
407    def fileno(self):
408        return self._wrapped_stream.fileno()
409
410    def flush(self):
411        pass
412
413    def read(self, size=None):
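        """Read up to ``size`` bytes from the substream, refilling the internal
        buffer from the wrapped stream as needed.
        """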
414        if self.closed:  # pylint: disable=using-constant-test
415            raise ValueError("Stream is closed.")
416
417        if size is None:
418            size = self._length - self._position
419
420        # adjust if out of bounds
421        if size + self._position >= self._length:
422            size = self._length - self._position
423
424        # return fast
425        if size == 0 or self._buffer.closed:
426            return b""
427
428        # attempt first read from the read buffer and update position
429        read_buffer = self._buffer.read(size)
430        bytes_read = len(read_buffer)
431        bytes_remaining = size - bytes_read
432        self._position += bytes_read
433
434        # repopulate the read buffer from the underlying stream to fulfill the request
435        # ensure the seek and read operations are done atomically (only if a lock is provided)
436        if bytes_remaining > 0:
437            with self._buffer:
438                # either read in the max buffer size specified on the class
439                # or read in just enough data for the current block/sub stream
440                current_max_buffer_size = min(self._max_buffer_size, self._length - self._position)
441
442                # lock is only defined if max_concurrency > 1 (parallel uploads)
443                if self._lock:
444                    with self._lock:
445                        # reposition the underlying stream to match the start of the data to read
446                        absolute_position = self._stream_begin_index + self._position
447                        self._wrapped_stream.seek(absolute_position, SEEK_SET)
448                        # If we can't seek to the right location, our read will be corrupted so fail fast.
449                        if self._wrapped_stream.tell() != absolute_position:
450                            raise IOError("Stream failed to seek to the desired location.")
451                        buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)
452                else:
453                    buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)
454
455            if buffer_from_stream:
456                # update the buffer with new data from the wrapped stream
457                # we need to note down the start position and size of the buffer, in case seek is performed later
458                self._buffer = BytesIO(buffer_from_stream)
459                self._current_buffer_start = self._position
460                self._current_buffer_size = len(buffer_from_stream)
461
462                # read the remaining bytes from the new buffer and update position
463                second_read_buffer = self._buffer.read(bytes_remaining)
464                read_buffer += second_read_buffer
465                self._position += len(second_read_buffer)
466
467        return read_buffer
468
469    def readable(self):
470        return True
471
472    def readinto(self, b):
473        raise UnsupportedOperation
474
475    def seek(self, offset, whence=0):
476        if whence is SEEK_SET:
477            start_index = 0
478        elif whence is SEEK_CUR:
479            start_index = self._position
480        elif whence is SEEK_END:
481            start_index = self._length
482            offset = -offset
483        else:
484            raise ValueError("Invalid argument for the 'whence' parameter.")
485
486        pos = start_index + offset
487
488        if pos > self._length:
489            pos = self._length
490        elif pos < 0:
491            pos = 0
492
493        # check if buffer is still valid
494        # if not, drop buffer
495        if pos < self._current_buffer_start or pos >= self._current_buffer_start + self._current_buffer_size:
496            self._buffer.close()
497            self._buffer = BytesIO()
498        else:  # if yes seek to correct position
499            delta = pos - self._current_buffer_start
500            self._buffer.seek(delta, SEEK_SET)
501
502        self._position = pos
503        return pos
504
505    def seekable(self):
506        return True
507
508    def tell(self):
509        return self._position
510
    def write(self, b):
        raise UnsupportedOperation

    def writelines(self, lines):
        raise UnsupportedOperation

    def writable(self):
        return False
519
520
521class IterStreamer(object):
522    """
523    File-like streaming iterator.
524    """
525
526    def __init__(self, generator, encoding="UTF-8"):
527        self.generator = generator
528        self.iterator = iter(generator)
529        self.leftover = b""
530        self.encoding = encoding
531
532    def __len__(self):
533        return self.generator.__len__()
534
535    def __iter__(self):
536        return self.iterator
537
538    def seekable(self):
539        return False
540
541    def __next__(self):
542        return next(self.iterator)
543
544    next = __next__  # Python 2 compatibility.
545
546    def tell(self, *args, **kwargs):
547        raise UnsupportedOperation("Data generator does not support tell.")
548
549    def seek(self, *args, **kwargs):
550        raise UnsupportedOperation("Data generator is unseekable.")
551
552    def read(self, size):
553        data = self.leftover
554        count = len(self.leftover)
555        try:
556            while count < size:
557                chunk = self.__next__()
558                if isinstance(chunk, six.text_type):
559                    chunk = chunk.encode(self.encoding)
560                data += chunk
561                count += len(chunk)
        except StopIteration:
            # The generator is exhausted; whatever has been read so far is
            # returned below, so clear the leftover buffer.
            self.leftover = b""

        if count >= size:
            self.leftover = data[size:]
567
568        return data[:size]
569