1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2019 Radim Rehurek <me@radimrehurek.com>
4#
5# This code is distributed under the terms and conditions
6# from the MIT License (MIT).
7#
8
9"""Implements file-like objects for reading and writing to/from GCS."""
10
11import io
12import logging
13
14try:
15    import google.cloud.exceptions
16    import google.cloud.storage
17    import google.auth.transport.requests
18except ImportError:
19    MISSING_DEPS = True
20
21import smart_open.bytebuffer
22import smart_open.utils
23
24from smart_open import constants
25
26logger = logging.getLogger(__name__)
27
28_BINARY_TYPES = (bytes, bytearray, memoryview)
29"""Allowed binary buffer types for writing to the underlying GCS stream"""
30
31_UNKNOWN = '*'
32
33SCHEME = "gs"
34"""Supported scheme for GCS"""
35
36_MIN_MIN_PART_SIZE = _REQUIRED_CHUNK_MULTIPLE = 256 * 1024
37"""Google requires you to upload in multiples of 256 KB, except for the last part."""
38
39_DEFAULT_MIN_PART_SIZE = 50 * 1024**2
40"""Default minimum part size for GCS multipart uploads"""
41
42DEFAULT_BUFFER_SIZE = 256 * 1024
43"""Default buffer size for working with GCS"""
44
45_UPLOAD_INCOMPLETE_STATUS_CODES = (308, )
46_UPLOAD_COMPLETE_STATUS_CODES = (200, 201)
47
48
49def _make_range_string(start, stop=None, end=None):
50    #
51    # GCS seems to violate RFC-2616 (see utils.make_range_string), so we
52    # need a separate implementation.
53    #
54    # https://cloud.google.com/storage/docs/xml-api/resumable-upload#step_3upload_the_file_blocks
55    #
56    if end is None:
57        end = _UNKNOWN
58    if stop is None:
59        return 'bytes %d-/%s' % (start, end)
60    return 'bytes %d-%d/%s' % (start, stop, end)
61
62
63class UploadFailedError(Exception):
64    def __init__(self, message, status_code, text):
65        """Raise when a multi-part upload to GCS returns a failed response status code.
66
67        Parameters
68        ----------
69        message: str
70            The error message to display.
71        status_code: int
72            The status code returned from the upload response.
73        text: str
74            The text returned from the upload response.
75
76        """
77        super(UploadFailedError, self).__init__(message)
78        self.status_code = status_code
79        self.text = text
80
81
82def _fail(response, part_num, content_length, total_size, headers):
83    status_code = response.status_code
84    response_text = response.text
85    total_size_gb = total_size / 1024.0 ** 3
86
87    msg = (
88        "upload failed (status code: %(status_code)d, response text: %(response_text)s), "
89        "part #%(part_num)d, %(total_size)d bytes (total %(total_size_gb).3fGB), headers: %(headers)r"
90    ) % locals()
91    raise UploadFailedError(msg, response.status_code, response.text)
92
93
94def parse_uri(uri_as_string):
95    sr = smart_open.utils.safe_urlsplit(uri_as_string)
96    assert sr.scheme == SCHEME
97    bucket_id = sr.netloc
98    blob_id = sr.path.lstrip('/')
99    return dict(scheme=SCHEME, bucket_id=bucket_id, blob_id=blob_id)
100
101
102def open_uri(uri, mode, transport_params):
103    parsed_uri = parse_uri(uri)
104    kwargs = smart_open.utils.check_kwargs(open, transport_params)
105    return open(parsed_uri['bucket_id'], parsed_uri['blob_id'], mode, **kwargs)
106
107
108def open(
109        bucket_id,
110        blob_id,
111        mode,
112        buffer_size=DEFAULT_BUFFER_SIZE,
113        min_part_size=_MIN_MIN_PART_SIZE,
114        client=None,  # type: google.cloud.storage.Client
115        blob_properties=None
116        ):
117    """Open an GCS blob for reading or writing.
118
119    Parameters
120    ----------
121    bucket_id: str
122        The name of the bucket this object resides in.
123    blob_id: str
124        The name of the blob within the bucket.
125    mode: str
126        The mode for opening the object.  Must be either "rb" or "wb".
127    buffer_size: int, optional
128        The buffer size to use when performing I/O. For reading only.
129    min_part_size: int, optional
130        The minimum part size for multipart uploads.  For writing only.
131    client: google.cloud.storage.Client, optional
132        The GCS client to use when working with google-cloud-storage.
133    blob_properties: dict, optional
134        Set properties on blob before writing.  For writing only.
135
136    """
137    if mode == constants.READ_BINARY:
138        fileobj = Reader(
139            bucket_id,
140            blob_id,
141            buffer_size=buffer_size,
142            line_terminator=constants.BINARY_NEWLINE,
143            client=client,
144        )
145    elif mode == constants.WRITE_BINARY:
146        fileobj = Writer(
147            bucket_id,
148            blob_id,
149            min_part_size=min_part_size,
150            client=client,
151            blob_properties=blob_properties,
152        )
153    else:
154        raise NotImplementedError('GCS support for mode %r not implemented' % mode)
155
156    fileobj.name = blob_id
157    return fileobj
158
159
160class _RawReader(object):
161    """Read an GCS object."""
162
163    def __init__(self, gcs_blob, size):
164        # type: (google.cloud.storage.Blob, int) -> None
165        self._blob = gcs_blob
166        self._size = size
167        self._position = 0
168
169    def seek(self, position):
170        """Seek to the specified position (byte offset) in the GCS key.
171
172        :param int position: The byte offset from the beginning of the key.
173
174        Returns the position after seeking.
175        """
176        self._position = position
177        return self._position
178
179    def read(self, size=-1):
180        if self._position >= self._size:
181            return b''
182        binary = self._download_blob_chunk(size)
183        self._position += len(binary)
184        return binary
185
186    def _download_blob_chunk(self, size):
187        start = position = self._position
188        if position == self._size:
189            #
190            # When reading, we can't seek to the first byte of an empty file.
191            # Similarly, we can't seek past the last byte.  Do nothing here.
192            #
193            binary = b''
194        elif size == -1:
195            binary = self._blob.download_as_bytes(start=start)
196        else:
197            end = position + size
198            binary = self._blob.download_as_bytes(start=start, end=end)
199        return binary
200
201
202class Reader(io.BufferedIOBase):
203    """Reads bytes from GCS.
204
205    Implements the io.BufferedIOBase interface of the standard library.
206
207    :raises google.cloud.exceptions.NotFound: Raised when the blob to read from does not exist.
208
209    """
210    def __init__(
211            self,
212            bucket,
213            key,
214            buffer_size=DEFAULT_BUFFER_SIZE,
215            line_terminator=constants.BINARY_NEWLINE,
216            client=None,  # type: google.cloud.storage.Client
217    ):
218        if client is None:
219            client = google.cloud.storage.Client()
220
221        self._blob = client.bucket(bucket).get_blob(key)  # type: google.cloud.storage.Blob
222
223        if self._blob is None:
224            raise google.cloud.exceptions.NotFound('blob %s not found in %s' % (key, bucket))
225
226        self._size = self._blob.size if self._blob.size is not None else 0
227
228        self._raw_reader = _RawReader(self._blob, self._size)
229        self._current_pos = 0
230        self._current_part_size = buffer_size
231        self._current_part = smart_open.bytebuffer.ByteBuffer(buffer_size)
232        self._eof = False
233        self._line_terminator = line_terminator
234
235        #
236        # This member is part of the io.BufferedIOBase interface.
237        #
238        self.raw = None
239
240    #
241    # Override some methods from io.IOBase.
242    #
243    def close(self):
244        """Flush and close this stream."""
245        logger.debug("close: called")
246        self._blob = None
247        self._current_part = None
248        self._raw_reader = None
249
250    def readable(self):
251        """Return True if the stream can be read from."""
252        return True
253
254    def seekable(self):
255        """If False, seek(), tell() and truncate() will raise IOError.
256
257        We offer only seek support, and no truncate support."""
258        return True
259
260    #
261    # io.BufferedIOBase methods.
262    #
263    def detach(self):
264        """Unsupported."""
265        raise io.UnsupportedOperation
266
267    def seek(self, offset, whence=constants.WHENCE_START):
268        """Seek to the specified position.
269
270        :param int offset: The offset in bytes.
271        :param int whence: Where the offset is from.
272
273        Returns the position after seeking."""
274        logger.debug('seeking to offset: %r whence: %r', offset, whence)
275        if whence not in constants.WHENCE_CHOICES:
276            raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES)
277
278        if whence == constants.WHENCE_START:
279            new_position = offset
280        elif whence == constants.WHENCE_CURRENT:
281            new_position = self._current_pos + offset
282        else:
283            new_position = self._size + offset
284        new_position = smart_open.utils.clamp(new_position, 0, self._size)
285        self._current_pos = new_position
286        self._raw_reader.seek(new_position)
287        logger.debug('current_pos: %r', self._current_pos)
288
289        self._current_part.empty()
290        self._eof = self._current_pos == self._size
291        return self._current_pos
292
293    def tell(self):
294        """Return the current position within the file."""
295        return self._current_pos
296
297    def truncate(self, size=None):
298        """Unsupported."""
299        raise io.UnsupportedOperation
300
301    def read(self, size=-1):
302        """Read up to size bytes from the object and return them."""
303        if size == 0:
304            return b''
305        elif size < 0:
306            self._current_pos = self._size
307            return self._read_from_buffer() + self._raw_reader.read()
308
309        #
310        # Return unused data first
311        #
312        if len(self._current_part) >= size:
313            return self._read_from_buffer(size)
314
315        #
316        # If the stream is finished, return what we have.
317        #
318        if self._eof:
319            return self._read_from_buffer()
320
321        #
322        # Fill our buffer to the required size.
323        #
324        self._fill_buffer(size)
325        return self._read_from_buffer(size)
326
327    def read1(self, size=-1):
328        """This is the same as read()."""
329        return self.read(size=size)
330
331    def readinto(self, b):
332        """Read up to len(b) bytes into b, and return the number of bytes
333        read."""
334        data = self.read(len(b))
335        if not data:
336            return 0
337        b[:len(data)] = data
338        return len(data)
339
340    def readline(self, limit=-1):
341        """Read up to and including the next newline.  Returns the bytes read."""
342        if limit != -1:
343            raise NotImplementedError('limits other than -1 not implemented yet')
344        the_line = io.BytesIO()
345        while not (self._eof and len(self._current_part) == 0):
346            #
347            # In the worst case, we're reading the unread part of self._current_part
348            # twice here, once in the if condition and once when calling index.
349            #
350            # This is sub-optimal, but better than the alternative: wrapping
351            # .index in a try..except, because that is slower.
352            #
353            remaining_buffer = self._current_part.peek()
354            if self._line_terminator in remaining_buffer:
355                next_newline = remaining_buffer.index(self._line_terminator)
356                the_line.write(self._read_from_buffer(next_newline + 1))
357                break
358            else:
359                the_line.write(self._read_from_buffer())
360                self._fill_buffer()
361        return the_line.getvalue()
362
363    #
364    # Internal methods.
365    #
366    def _read_from_buffer(self, size=-1):
367        """Remove at most size bytes from our buffer and return them."""
368        # logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._current_part))
369        size = size if size >= 0 else len(self._current_part)
370        part = self._current_part.read(size)
371        self._current_pos += len(part)
372        # logger.debug('part: %r', part)
373        return part
374
375    def _fill_buffer(self, size=-1):
376        size = size if size >= 0 else self._current_part._chunk_size
377        while len(self._current_part) < size and not self._eof:
378            bytes_read = self._current_part.fill(self._raw_reader)
379            if bytes_read == 0:
380                logger.debug('reached EOF while filling buffer')
381                self._eof = True
382
383    def __str__(self):
384        return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name)
385
386    def __repr__(self):
387        return "%s(bucket=%r, blob=%r, buffer_size=%r)" % (
388            self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._current_part_size,
389        )
390
391
392class Writer(io.BufferedIOBase):
393    """Writes bytes to GCS.
394
395    Implements the io.BufferedIOBase interface of the standard library."""
396
397    def __init__(
398            self,
399            bucket,
400            blob,
401            min_part_size=_DEFAULT_MIN_PART_SIZE,
402            client=None,  # type: google.cloud.storage.Client
403            blob_properties=None,
404    ):
405        if client is None:
406            client = google.cloud.storage.Client()
407        self._client = client
408        self._blob = self._client.bucket(bucket).blob(blob)  # type: google.cloud.storage.Blob
409        assert min_part_size % _REQUIRED_CHUNK_MULTIPLE == 0, 'min part size must be a multiple of 256KB'
410        assert min_part_size >= _MIN_MIN_PART_SIZE, 'min part size must be greater than 256KB'
411        self._min_part_size = min_part_size
412
413        self._total_size = 0
414        self._total_parts = 0
415        self._bytes_uploaded = 0
416        self._current_part = io.BytesIO()
417
418        self._session = google.auth.transport.requests.AuthorizedSession(client._credentials)
419
420        if blob_properties:
421            for k, v in blob_properties.items():
422                setattr(self._blob, k, v)
423
424        #
425        # https://cloud.google.com/storage/docs/json_api/v1/how-tos/resumable-upload#start-resumable
426        #
427        self._resumable_upload_url = self._blob.create_resumable_upload_session()
428
429        #
430        # This member is part of the io.BufferedIOBase interface.
431        #
432        self.raw = None
433
434    def flush(self):
435        pass
436
437    #
438    # Override some methods from io.IOBase.
439    #
440    def close(self):
441        logger.debug("closing")
442        if not self.closed:
443            if self._total_size == 0:  # empty files
444                self._upload_empty_part()
445            else:
446                self._upload_part(is_last=True)
447            self._client = None
448        logger.debug("successfully closed")
449
450    @property
451    def closed(self):
452        return self._client is None
453
454    def writable(self):
455        """Return True if the stream supports writing."""
456        return True
457
458    def seekable(self):
459        """If False, seek(), tell() and truncate() will raise IOError.
460
461        We offer only tell support, and no seek or truncate support."""
462        return True
463
464    def seek(self, offset, whence=constants.WHENCE_START):
465        """Unsupported."""
466        raise io.UnsupportedOperation
467
468    def truncate(self, size=None):
469        """Unsupported."""
470        raise io.UnsupportedOperation
471
472    def tell(self):
473        """Return the current stream position."""
474        return self._total_size
475
476    #
477    # io.BufferedIOBase methods.
478    #
479    def detach(self):
480        raise io.UnsupportedOperation("detach() not supported")
481
482    def write(self, b):
483        """Write the given bytes (binary string) to the GCS file.
484
485        There's buffering happening under the covers, so this may not actually
486        do any HTTP transfer right away."""
487
488        if not isinstance(b, _BINARY_TYPES):
489            raise TypeError("input must be one of %r, got: %r" % (_BINARY_TYPES, type(b)))
490
491        self._current_part.write(b)
492        self._total_size += len(b)
493
494        #
495        # If the size of this part is precisely equal to the minimum part size,
496        # we don't perform the actual write now, and wait until we see more data.
497        # We do this because the very last part of the upload must be handled slightly
498        # differently (see comments in the _upload_part method).
499        #
500        if self._current_part.tell() > self._min_part_size:
501            self._upload_part()
502
503        return len(b)
504
505    def terminate(self):
506        """Cancel the underlying resumable upload."""
507        #
508        # https://cloud.google.com/storage/docs/xml-api/resumable-upload#example_cancelling_an_upload
509        #
510        self._session.delete(self._resumable_upload_url)
511
512    #
513    # Internal methods.
514    #
515    def _upload_part(self, is_last=False):
516        part_num = self._total_parts + 1
517
518        #
519        # Here we upload the largest amount possible given GCS's restriction
520        # of parts being multiples of 256kB, except for the last one.
521        #
522        # A final upload of 0 bytes does not work, so we need to guard against
523        # this edge case. This results in occasionally keeping an additional
524        # 256kB in the buffer after uploading a part, but until this is fixed
525        # on Google's end there is no other option.
526        #
527        # https://stackoverflow.com/questions/60230631/upload-zero-size-final-part-to-google-cloud-storage-resumable-upload
528        #
529        content_length = self._current_part.tell()
530        remainder = content_length % self._min_part_size
531        if is_last:
532            end = self._bytes_uploaded + content_length
533        elif remainder == 0:
534            content_length -= _REQUIRED_CHUNK_MULTIPLE
535            end = None
536        else:
537            content_length -= remainder
538            end = None
539
540        range_stop = self._bytes_uploaded + content_length - 1
541        content_range = _make_range_string(self._bytes_uploaded, range_stop, end=end)
542        headers = {
543            'Content-Length': str(content_length),
544            'Content-Range': content_range,
545        }
546        logger.info(
547            "uploading part #%i, %i bytes (total %.3fGB) headers %r",
548            part_num, content_length, range_stop / 1024.0 ** 3, headers,
549        )
550        self._current_part.seek(0)
551        response = self._session.put(
552            self._resumable_upload_url,
553            data=self._current_part.read(content_length),
554            headers=headers,
555        )
556
557        if is_last:
558            expected = _UPLOAD_COMPLETE_STATUS_CODES
559        else:
560            expected = _UPLOAD_INCOMPLETE_STATUS_CODES
561        if response.status_code not in expected:
562            _fail(response, part_num, content_length, self._total_size, headers)
563        logger.debug("upload of part #%i finished" % part_num)
564
565        self._total_parts += 1
566        self._bytes_uploaded += content_length
567
568        #
569        # For the last part, the below _current_part handling is a NOOP.
570        #
571        self._current_part = io.BytesIO(self._current_part.read())
572        self._current_part.seek(0, io.SEEK_END)
573
574    def _upload_empty_part(self):
575        logger.debug("creating empty file")
576        headers = {'Content-Length': '0'}
577        response = self._session.put(self._resumable_upload_url, headers=headers)
578        if response.status_code not in _UPLOAD_COMPLETE_STATUS_CODES:
579            _fail(response, self._total_parts + 1, 0, self._total_size, headers)
580
581        self._total_parts += 1
582
583    def __enter__(self):
584        return self
585
586    def __exit__(self, exc_type, exc_val, exc_tb):
587        if exc_type is not None:
588            self.terminate()
589        else:
590            self.close()
591
592    def __str__(self):
593        return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name)
594
595    def __repr__(self):
596        return "%s(bucket=%r, blob=%r, min_part_size=%r)" % (
597            self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._min_part_size,
598        )
599