1# -*- coding: utf-8 -*- #
2# Copyright 2020 Google LLC. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Task for daisy-chain copies.
17
18Typically executed in a task iterator:
19googlecloudsdk.command_lib.storage.tasks.task_executor.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import unicode_literals
25
26import collections
27import io
28import os
29import threading
30
31from googlecloudsdk.api_lib.storage import api_factory
32from googlecloudsdk.api_lib.storage import cloud_api
33from googlecloudsdk.command_lib.storage import errors
34from googlecloudsdk.command_lib.storage import progress_callbacks
35from googlecloudsdk.command_lib.storage import storage_url
36from googlecloudsdk.command_lib.storage.tasks import task
37from googlecloudsdk.command_lib.storage.tasks import task_status
38
39
40MAX_ALLOWED_READ_SIZE = 100 * 1024 * 1024  # 100 MiB
41MAX_BUFFER_QUEUE_SIZE = 100
42# TODO(b/174075495) Determine the max size based on the destination scheme.
43QUEUE_ITEM_MAX_SIZE = 8 * 1024  # 8 KiB
44
45
46class _AbruptShutdownError(errors.Error):
47  """Raised if a thread is terminated because of an error in another thread."""
48
49
50class _WritableStream:
51  """A write-only stream class that writes to the buffer queue."""
52
53  def __init__(self, buffer_queue, buffer_condition, shutdown_event):
54    """Initializes WritableStream.
55
56    Args:
57      buffer_queue (collections.deque): A queue where the data gets written.
58      buffer_condition (threading.Condition): The condition object to wait on if
59        the buffer is full.
60      shutdown_event (threading.Event): Used for signaling the thread to
61        terminate.
62    """
63    self._buffer_queue = buffer_queue
64    self._buffer_condition = buffer_condition
65    self._shutdown_event = shutdown_event
66
67  def write(self, data):
68    """Writes data to the buffer queue.
69
70    This method writes the data in chunks of QUEUE_ITEM_MAX_SIZE. In most cases,
71    the read operation is performed with size=QUEUE_ITEM_MAX_SIZE.
72    Splitting the data in QUEUE_ITEM_MAX_SIZE chunks improves the performance.
73
74    This method will be blocked if MAX_BUFFER_QUEUE_SIZE is reached to avoid
75    writing all the data in-memory.
76
77    Args:
78      data (bytes): The bytes that should be added to the queue.
79
80    Raises:
81      _AbruptShutdownError: If self._shudown_event was set.
82    """
83    start = 0
84    end = min(start + QUEUE_ITEM_MAX_SIZE, len(data))
85    while start < len(data):
86      with self._buffer_condition:
87        while (len(self._buffer_queue) >= MAX_BUFFER_QUEUE_SIZE and
88               not self._shutdown_event.is_set()):
89          self._buffer_condition.wait()
90
91        if self._shutdown_event.is_set():
92          raise _AbruptShutdownError()
93
94        self._buffer_queue.append(data[start:end])
95        start = end
96        end = min(start + QUEUE_ITEM_MAX_SIZE, len(data))
97        self._buffer_condition.notify_all()
98
99
100class _ReadableStream:
101  """A read-only stream that reads from the buffer queue."""
102
103  def __init__(self, buffer_queue, buffer_condition, shutdown_event,
104               end_position):
105    """Initializes ReadableStream.
106
107    Args:
108      buffer_queue (collections.deque): The underlying queue from which the data
109        gets read.
110      buffer_condition (threading.Condition): The condition object to wait on if
111        the buffer is empty.
112      shutdown_event (threading.Event): Used for signaling the thread to
113        terminate.
114      end_position (int): Position at which the stream reading stops. This is
115        usually the total size of the data that gets read.
116    """
117    self._buffer_queue = buffer_queue
118    self._buffer_condition = buffer_condition
119    self._end_position = end_position
120    self._shutdown_event = shutdown_event
121    self._position = 0
122    self._unused_data_from_previous_read = b''
123
124  def read(self, size=-1):
125    """Reads size bytes from the buffer queue and returns it.
126
127    This method will be blocked if the buffer_queue is empty.
128    If size > length of data available, the entire data is sent over.
129
130    Args:
131      size (int): The number of bytes to be read.
132
133    Returns:
134      Bytes of length 'size'. May return bytes of length less than the size
135        if there are no more bytes left to be read.
136
137    Raises:
138      _AbruptShutdownError: If self._shudown_event was set.
139      storage.errors.Error: If size is not within the allowed range of
140        [-1, MAX_ALLOWED_READ_SIZE] OR
141        If size is -1 but the object size is greater than MAX_ALLOWED_READ_SIZE.
142    """
143    if size == 0:
144      return b''
145
146    if size > MAX_ALLOWED_READ_SIZE:
147      raise errors.Error(
148          'Invalid HTTP read size {} during daisy chain operation, expected'
149          ' -1 <= size <= {} bytes.'.format(size, MAX_ALLOWED_READ_SIZE))
150
151    if size == -1:
152      # This indicates that we have to read the entire object at once.
153      if self._end_position <= MAX_ALLOWED_READ_SIZE:
154        chunk_size = self._end_position
155      else:
156        raise errors.Error('Read with size=-1 is not allowed for object'
157                           ' size > {} bytes to prevent reading large objects'
158                           ' in-memory.'.format(MAX_ALLOWED_READ_SIZE))
159    else:
160      chunk_size = size
161
162    result = io.BytesIO()
163    bytes_read = 0
164
165    while bytes_read < chunk_size and self._position < self._end_position:
166      if not self._unused_data_from_previous_read:
167        with self._buffer_condition:
168          while not self._buffer_queue and not self._shutdown_event.is_set():
169            self._buffer_condition.wait()
170
171          # The shutdown_event needs to be checked before the data is fetched
172          # from the buffer.
173          if self._shutdown_event.is_set():
174            raise _AbruptShutdownError()
175
176          data = self._buffer_queue.popleft()
177          self._buffer_condition.notify_all()
178      else:
179        # Data is already present from previous read.
180        if self._shutdown_event.is_set():
181          raise _AbruptShutdownError()
182        data = self._unused_data_from_previous_read
183
184      if bytes_read + len(data) > chunk_size:
185        self._unused_data_from_previous_read = data[chunk_size - bytes_read:]
186        data_to_return = data[:chunk_size - bytes_read]
187      else:
188        self._unused_data_from_previous_read = b''
189        data_to_return = data
190      result.write(data_to_return)
191      bytes_read += len(data_to_return)
192      self._position += len(data_to_return)
193
194    return result.getvalue()
195
196  def seek(self, offset, whence=os.SEEK_SET):
197    """Checks if seek was requested for the last position.
198
199    Ideally, seek changes the stream position to the given byte offset. Since
200    this stream is a non-seekable stream, seek is actually not required.
201    We implement this method as a hacky way to provide additional
202    integrity checking.
203
204    Apitools by default calls seek at the end of the transfer
205    https://github.com/google/apitools/blob/ca2094556531d61e741dc2954fdfccbc650cdc32/apitools/base/py/transfer.py#L986
206    to determine if it has read everything from the stream.
207
208    This method will raise an error if seek was called for any position except
209    for the last position and hence will fail if all the bytes were not
210    copied over as expected.
211
212    Args:
213      offset (int): Defines the position realative to the `whence` where the
214        current position of the stream should be moved.
215      whence (int): The reference relative to which offset is interpreted.
216          Values for whence are: os.SEEK_SET or 0 - start of the stream
217            (thedefault). os.SEEK_END or 2 - end of the stream. We do not
218            support other os.SEEK_* constants.
219
220    Returns:
221      (int) The current position.
222
223    Raises:
224      ValueError: If seek is not called for the last position.
225
226    """
227    if self._position != self._end_position:
228      raise ValueError(
229          'Seek called before all the bytes were read. Current positon: {},'
230          ' Last position {}.'.format(self._position, self._end_position))
231    if whence == os.SEEK_END:
232      if offset:
233        raise ValueError('Non-zero offset from os.SEEK_END is not allowed.'
234                         'Offset: {}.'.format(offset))
235    elif whence == os.SEEK_SET:
236      # Relative to the start of the stream, the offset should be the size
237      # of the stream
238      if offset != self._end_position:
239        raise ValueError(
240            'Seek relative to the beginning is only allowed for the last'
241            ' position {}. Offset: {}. Current position: {}.'.format(
242                self._end_position, offset, self._position))
243    else:
244      raise ValueError('Seek is only supported for os.SEEK_END and'
245                       ' os.SEEK_SET.')
246    return self._position
247
248  def seekable(self):
249    """Returns False, since this stream is not meant to be seekable."""
250    del self  # Unused.
251    return False
252
253  def tell(self):
254    """Returns the current position."""
255    return self._position
256
257
258class QueuingStream:
259  """Interface to a bidirectional buffer to read and write simultaneously.
260
261  Attributes:
262    buffer_queue (collections.deque): The underlying queue that acts like a
263      buffer for the streams
264    buffer_condition (threading.Condition): The condition object used for
265      waiting based on the underlying buffer_queue state.
266      All threads waiting on this condition are notified when data is added or
267      removed from buffer_queue. Streams that write to the buffer wait on this
268      condition until the buffer has space, and streams that read from the
269      buffer wait on this condition until the buffer has data.
270    shutdown_event (threading.Event): Used for signaling the operations to
271      terminate.
272    writable_stream (_WritableStream): Stream that writes to the buffer.
273    readable_stream (_ReadableStream): Stream that reads from the buffer.
274    exception_raised (Exception): Stores the Exception instance responsible for
275      termination of the operation.
276  """
277
278  def __init__(self, object_size=None):
279    """Intializes QueuingStream.
280
281    Args:
282      object_size (int): The size of the source object.
283    """
284    self.buffer_queue = collections.deque()
285    self.buffer_condition = threading.Condition()
286    self.shutdown_event = threading.Event()
287    self.writable_stream = _WritableStream(self.buffer_queue,
288                                           self.buffer_condition,
289                                           self.shutdown_event)
290    self.readable_stream = _ReadableStream(
291        self.buffer_queue,
292        self.buffer_condition,
293        self.shutdown_event,
294        object_size,
295    )
296    self.exception_raised = None
297
298  def shutdown(self, error):
299    """Sets the shutdown event and stores the error to re-raise later.
300
301    Args:
302      error (Exception): The error responsible for triggering shutdown.
303    """
304    self.shutdown_event.set()
305    with self.buffer_condition:
306      self.buffer_condition.notify_all()
307      self.exception_raised = error
308
309
310class DaisyChainCopyTask(task.Task):
311  """Represents an operation to copy by downloading and uploading.
312
313  This task downloads from one cloud location and uplaods to another cloud
314  location by keeping an in-memory buffer.
315  """
316
317  def __init__(self, source_resource, destination_resource):
318    """Initializes task.
319
320    Args:
321      source_resource (resource_reference.ObjectResource): Must
322          contain the full object path of existing object.
323          Directories will not be accepted.
324      destination_resource (resource_reference.UnknownResource): Must
325          contain the full object path. Object may not exist yet.
326          Existing objects at the this location will be overwritten.
327          Directories will not be accepted.
328    """
329    super().__init__()
330    if (not isinstance(source_resource.storage_url, storage_url.CloudUrl)
331        or not isinstance(destination_resource.storage_url,
332                          storage_url.CloudUrl)):
333      raise ValueError('DaisyChainCopyTask is for copies between cloud'
334                       ' providers.')
335
336    self._source_resource = source_resource
337    self._destination_resource = destination_resource
338    self.parallel_processing_key = (
339        self._destination_resource.storage_url.url_string)
340
341  def _run_download(self, daisy_chain_stream):
342    """Performs the download operation."""
343    client = api_factory.get_api(self._source_resource.storage_url.scheme)
344    try:
345      client.download_object(self._source_resource,
346                             daisy_chain_stream.writable_stream)
347    except _AbruptShutdownError:
348      # Shutdown caused by interuption from another thread.
349      pass
350    except Exception as e:  # pylint: disable=broad-except
351      # The stack trace of the exception raised in the thread is not visible
352      # in the caller thread. Hence we catch any exception so that we can
353      # re-raise them from the parent thread.
354      daisy_chain_stream.shutdown(e)
355
356  def execute(self, task_status_queue=None):
357    """Copies file by downloading and uploading in parallel."""
358    # TODO (b/168712813): Add option to use the Data Transfer component.
359
360    daisy_chain_stream = QueuingStream(self._source_resource.size)
361
362    # Perform download in a separate thread so that upload can be performed
363    # simultaneously.
364    download_thread = threading.Thread(
365        target=self._run_download, args=(daisy_chain_stream,))
366    download_thread.start()
367
368    destination_client = api_factory.get_api(
369        self._destination_resource.storage_url.scheme)
370    request_config = cloud_api.RequestConfig(size=self._source_resource.size)
371    progress_callback = progress_callbacks.FilesAndBytesProgressCallback(
372        status_queue=task_status_queue,
373        size=self._source_resource.size,
374        source_url=self._source_resource.storage_url,
375        destination_url=self._destination_resource.storage_url,
376        operation_name=task_status.OperationName.DAISY_CHAIN_COPYING,
377        process_id=os.getpid(),
378        thread_id=threading.get_ident(),
379    )
380
381    try:
382      destination_client.upload_object(
383          daisy_chain_stream.readable_stream,
384          self._destination_resource,
385          request_config=request_config,
386          progress_callback=progress_callback)
387    except _AbruptShutdownError:
388      # Not raising daisy_chain_stream.exception_raised here because we want
389      # to wait for the download thread to finish.
390      pass
391    except Exception as e:  # pylint: disable=broad-except
392      # For all the other errors raised during upload, we want to to make
393      # sure that the download thread is terminated before we re-reaise.
394      # Hence we catch any exception and store it to be re-raised later.
395      daisy_chain_stream.shutdown(e)
396
397    download_thread.join()
398    if daisy_chain_stream.exception_raised:
399      raise daisy_chain_stream.exception_raised
400