# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Radim Rehurek <me@radimrehurek.com>
#
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#

"""Implements file-like objects for reading and writing to/from GCS."""

import io
import logging

try:
    import google.cloud.exceptions
    import google.cloud.storage
    import google.auth.transport.requests
except ImportError:
    MISSING_DEPS = True

import smart_open.bytebuffer
import smart_open.utils

from smart_open import constants

logger = logging.getLogger(__name__)

_BINARY_TYPES = (bytes, bytearray, memoryview)
"""Allowed binary buffer types for writing to the underlying GCS stream"""

_UNKNOWN = '*'

SCHEME = "gs"
"""Supported scheme for GCS"""

_MIN_MIN_PART_SIZE = _REQUIRED_CHUNK_MULTIPLE = 256 * 1024
"""Google requires you to upload in multiples of 256 KB, except for the last part."""

_DEFAULT_MIN_PART_SIZE = 50 * 1024**2
"""Default minimum part size for GCS multipart uploads"""

DEFAULT_BUFFER_SIZE = 256 * 1024
"""Default buffer size for working with GCS"""

_UPLOAD_INCOMPLETE_STATUS_CODES = (308, )
_UPLOAD_COMPLETE_STATUS_CODES = (200, 201)


def _make_range_string(start, stop=None, end=None):
    """Build a Content-Range header value for a resumable upload chunk.

    GCS seems to violate RFC-2616 (see utils.make_range_string), so we
    need a separate implementation here.

    https://cloud.google.com/storage/docs/xml-api/resumable-upload#step_3upload_the_file_blocks

    :param int start: First byte offset of this chunk.
    :param int stop: Last byte offset of this chunk (inclusive), or None
        when unknown.
    :param int end: Total object size, or None when still unknown (GCS
        accepts ``*`` as a placeholder).
    """
    if end is None:
        end = _UNKNOWN
    if stop is None:
        return 'bytes %d-/%s' % (start, end)
    return 'bytes %d-%d/%s' % (start, stop, end)


class UploadFailedError(Exception):
    def __init__(self, message, status_code, text):
        """Raise when a multi-part upload to GCS returns a failed response status code.

        Parameters
        ----------
        message: str
            The error message to display.
        status_code: int
            The status code returned from the upload response.
        text: str
            The text returned from the upload response.

        """
        super(UploadFailedError, self).__init__(message)
        self.status_code = status_code
        self.text = text


def _fail(response, part_num, content_length, total_size, headers):
    """Raise an :class:`UploadFailedError` describing a failed part upload."""
    #
    # Format the diagnostic message explicitly rather than via ``% locals()``,
    # which silently depends on local variable names and is easy to break.
    #
    msg = (
        "upload failed (status code: %d, response text: %s), "
        "part #%d, %d bytes (total %.3fGB), headers: %r"
    ) % (
        response.status_code,
        response.text,
        part_num,
        total_size,
        total_size / 1024.0 ** 3,
        headers,
    )
    raise UploadFailedError(msg, response.status_code, response.text)


def parse_uri(uri_as_string):
    """Parse a ``gs://bucket/blob`` URI into its components.

    :raises ValueError: If the URI scheme is not ``gs``.
    """
    sr = smart_open.utils.safe_urlsplit(uri_as_string)
    #
    # Raise instead of assert: asserts are stripped under ``python -O``.
    #
    if sr.scheme != SCHEME:
        raise ValueError('invalid URI scheme %r, expected %r' % (sr.scheme, SCHEME))
    bucket_id = sr.netloc
    blob_id = sr.path.lstrip('/')
    return dict(scheme=SCHEME, bucket_id=bucket_id, blob_id=blob_id)


def open_uri(uri, mode, transport_params):
    """Open a ``gs://`` URI, passing through only kwargs that ``open`` accepts."""
    parsed_uri = parse_uri(uri)
    kwargs = smart_open.utils.check_kwargs(open, transport_params)
    return open(parsed_uri['bucket_id'], parsed_uri['blob_id'], mode, **kwargs)


def open(
        bucket_id,
        blob_id,
        mode,
        buffer_size=DEFAULT_BUFFER_SIZE,
        min_part_size=_MIN_MIN_PART_SIZE,
        client=None,  # type: google.cloud.storage.Client
        blob_properties=None
        ):
    """Open an GCS blob for reading or writing.

    Parameters
    ----------
    bucket_id: str
        The name of the bucket this object resides in.
    blob_id: str
        The name of the blob within the bucket.
    mode: str
        The mode for opening the object.  Must be either "rb" or "wb".
    buffer_size: int, optional
        The buffer size to use when performing I/O. For reading only.
    min_part_size: int, optional
        The minimum part size for multipart uploads.  For writing only.
    client: google.cloud.storage.Client, optional
        The GCS client to use when working with google-cloud-storage.
    blob_properties: dict, optional
        Set properties on blob before writing.  For writing only.

    Raises
    ------
    NotImplementedError
        If ``mode`` is neither "rb" nor "wb".

    """
    if mode == constants.READ_BINARY:
        fileobj = Reader(
            bucket_id,
            blob_id,
            buffer_size=buffer_size,
            line_terminator=constants.BINARY_NEWLINE,
            client=client,
        )
    elif mode == constants.WRITE_BINARY:
        fileobj = Writer(
            bucket_id,
            blob_id,
            min_part_size=min_part_size,
            client=client,
            blob_properties=blob_properties,
        )
    else:
        raise NotImplementedError('GCS support for mode %r not implemented' % mode)

    fileobj.name = blob_id
    return fileobj


class _RawReader(object):
    """Read an GCS object, tracking the current byte offset."""

    def __init__(self, gcs_blob, size):
        # type: (google.cloud.storage.Blob, int) -> None
        self._blob = gcs_blob
        self._size = size
        self._position = 0

    def seek(self, position):
        """Seek to the specified position (byte offset) in the GCS key.

        :param int position: The byte offset from the beginning of the key.

        Returns the position after seeking.
        """
        self._position = position
        return self._position

    def read(self, size=-1):
        """Read up to ``size`` bytes from the current position.

        Reads everything to the end of the blob when ``size`` is -1.
        Returns an empty byte string at (or past) EOF.
        """
        if self._position >= self._size:
            return b''
        binary = self._download_blob_chunk(size)
        self._position += len(binary)
        return binary

    def _download_blob_chunk(self, size):
        start = position = self._position
        if position == self._size:
            #
            # When reading, we can't seek to the first byte of an empty file.
            # Similarly, we can't seek past the last byte.  Do nothing here.
            #
            binary = b''
        elif size == -1:
            binary = self._blob.download_as_bytes(start=start)
        else:
            end = position + size
            binary = self._blob.download_as_bytes(start=start, end=end)
        return binary


class Reader(io.BufferedIOBase):
    """Reads bytes from GCS.

    Implements the io.BufferedIOBase interface of the standard library.

    :raises google.cloud.exceptions.NotFound: Raised when the blob to read from does not exist.

    """
    def __init__(
            self,
            bucket,
            key,
            buffer_size=DEFAULT_BUFFER_SIZE,
            line_terminator=constants.BINARY_NEWLINE,
            client=None,  # type: google.cloud.storage.Client
    ):
        if client is None:
            client = google.cloud.storage.Client()

        self._blob = client.bucket(bucket).get_blob(key)  # type: google.cloud.storage.Blob

        if self._blob is None:
            raise google.cloud.exceptions.NotFound('blob %s not found in %s' % (key, bucket))

        self._size = self._blob.size if self._blob.size is not None else 0

        self._raw_reader = _RawReader(self._blob, self._size)
        self._current_pos = 0
        self._current_part_size = buffer_size
        self._current_part = smart_open.bytebuffer.ByteBuffer(buffer_size)
        self._eof = False
        self._line_terminator = line_terminator

        #
        # This member is part of the io.BufferedIOBase interface.
        #
        self.raw = None

    #
    # Override some methods from io.IOBase.
    #
    def close(self):
        """Flush and close this stream.  Safe to call more than once."""
        logger.debug("close: called")
        if self._raw_reader is None:
            return  # already closed; close() must be idempotent
        self._blob = None
        self._current_part = None
        self._raw_reader = None

    def readable(self):
        """Return True if the stream can be read from."""
        return True

    def seekable(self):
        """If False, seek(), tell() and truncate() will raise IOError.

        We offer only seek support, and no truncate support."""
        return True

    #
    # io.BufferedIOBase methods.
    #
    def detach(self):
        """Unsupported."""
        raise io.UnsupportedOperation

    def seek(self, offset, whence=constants.WHENCE_START):
        """Seek to the specified position.

        :param int offset: The offset in bytes.
        :param int whence: Where the offset is from.

        Returns the position after seeking."""
        logger.debug('seeking to offset: %r whence: %r', offset, whence)
        if whence not in constants.WHENCE_CHOICES:
            raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES)

        if whence == constants.WHENCE_START:
            new_position = offset
        elif whence == constants.WHENCE_CURRENT:
            new_position = self._current_pos + offset
        else:
            new_position = self._size + offset
        new_position = smart_open.utils.clamp(new_position, 0, self._size)
        self._current_pos = new_position
        self._raw_reader.seek(new_position)
        logger.debug('current_pos: %r', self._current_pos)

        #
        # Any buffered data is for the old position, so discard it.
        #
        self._current_part.empty()
        self._eof = self._current_pos == self._size
        return self._current_pos

    def tell(self):
        """Return the current position within the file."""
        return self._current_pos

    def truncate(self, size=None):
        """Unsupported."""
        raise io.UnsupportedOperation

    def read(self, size=-1):
        """Read up to size bytes from the object and return them."""
        if size == 0:
            return b''
        elif size < 0:
            #
            # Consume the buffer and the raw reader *before* updating the
            # position.  Setting _current_pos = _size up front is wrong:
            # _read_from_buffer() increments _current_pos by the number of
            # buffered bytes, so tell() would end up reporting a position
            # *past* the end of the blob.
            #
            out = self._read_from_buffer() + self._raw_reader.read()
            self._current_pos = self._size
            return out

        #
        # Return unused data first
        #
        if len(self._current_part) >= size:
            return self._read_from_buffer(size)

        #
        # If the stream is finished, return what we have.
        #
        if self._eof:
            return self._read_from_buffer()

        #
        # Fill our buffer to the required size.
        #
        self._fill_buffer(size)
        return self._read_from_buffer(size)

    def read1(self, size=-1):
        """This is the same as read()."""
        return self.read(size=size)

    def readinto(self, b):
        """Read up to len(b) bytes into b, and return the number of bytes
        read."""
        data = self.read(len(b))
        if not data:
            return 0
        b[:len(data)] = data
        return len(data)

    def readline(self, limit=-1):
        """Read up to and including the next newline.  Returns the bytes read."""
        if limit != -1:
            raise NotImplementedError('limits other than -1 not implemented yet')
        the_line = io.BytesIO()
        while not (self._eof and len(self._current_part) == 0):
            #
            # In the worst case, we're reading the unread part of self._current_part
            # twice here, once in the if condition and once when calling index.
            #
            # This is sub-optimal, but better than the alternative: wrapping
            # .index in a try..except, because that is slower.
            #
            remaining_buffer = self._current_part.peek()
            if self._line_terminator in remaining_buffer:
                next_newline = remaining_buffer.index(self._line_terminator)
                the_line.write(self._read_from_buffer(next_newline + 1))
                break
            else:
                the_line.write(self._read_from_buffer())
                self._fill_buffer()
        return the_line.getvalue()

    #
    # Internal methods.
    #
    def _read_from_buffer(self, size=-1):
        """Remove at most size bytes from our buffer and return them."""
        size = size if size >= 0 else len(self._current_part)
        part = self._current_part.read(size)
        self._current_pos += len(part)
        return part

    def _fill_buffer(self, size=-1):
        """Refill the internal buffer until it holds ``size`` bytes or EOF is hit."""
        #
        # NOTE: reaches into ByteBuffer's private _chunk_size for the default.
        #
        size = size if size >= 0 else self._current_part._chunk_size
        while len(self._current_part) < size and not self._eof:
            bytes_read = self._current_part.fill(self._raw_reader)
            if bytes_read == 0:
                logger.debug('reached EOF while filling buffer')
                self._eof = True

    def __str__(self):
        return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name)

    def __repr__(self):
        return "%s(bucket=%r, blob=%r, buffer_size=%r)" % (
            self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._current_part_size,
        )


class Writer(io.BufferedIOBase):
    """Writes bytes to GCS.

    Implements the io.BufferedIOBase interface of the standard library."""

    def __init__(
            self,
            bucket,
            blob,
            min_part_size=_DEFAULT_MIN_PART_SIZE,
            client=None,  # type: google.cloud.storage.Client
            blob_properties=None,
    ):
        if client is None:
            client = google.cloud.storage.Client()
        self._client = client
        self._blob = self._client.bucket(bucket).blob(blob)  # type: google.cloud.storage.Blob
        #
        # Raise instead of assert: asserts are stripped under ``python -O``,
        # and an invalid part size would otherwise fail much later, mid-upload.
        #
        if min_part_size % _REQUIRED_CHUNK_MULTIPLE != 0:
            raise ValueError('min part size must be a multiple of 256KB')
        if min_part_size < _MIN_MIN_PART_SIZE:
            raise ValueError('min part size must be greater than 256KB')
        self._min_part_size = min_part_size

        self._total_size = 0
        self._total_parts = 0
        self._bytes_uploaded = 0
        self._current_part = io.BytesIO()

        self._session = google.auth.transport.requests.AuthorizedSession(client._credentials)

        if blob_properties:
            for k, v in blob_properties.items():
                setattr(self._blob, k, v)

        #
        # https://cloud.google.com/storage/docs/json_api/v1/how-tos/resumable-upload#start-resumable
        #
        self._resumable_upload_url = self._blob.create_resumable_upload_session()

        #
        # This member is part of the io.BufferedIOBase interface.
        #
        self.raw = None

    def flush(self):
        pass

    #
    # Override some methods from io.IOBase.
    #
    def close(self):
        """Upload any buffered data and finalize the resumable upload."""
        logger.debug("closing")
        if not self.closed:
            if self._total_size == 0:  # empty files
                self._upload_empty_part()
            else:
                self._upload_part(is_last=True)
            self._client = None
        logger.debug("successfully closed")

    @property
    def closed(self):
        return self._client is None

    def writable(self):
        """Return True if the stream supports writing."""
        return True

    def seekable(self):
        """If False, seek(), tell() and truncate() will raise IOError.

        We offer only tell support, and no seek or truncate support."""
        return True

    def seek(self, offset, whence=constants.WHENCE_START):
        """Unsupported."""
        raise io.UnsupportedOperation

    def truncate(self, size=None):
        """Unsupported."""
        raise io.UnsupportedOperation

    def tell(self):
        """Return the current stream position."""
        return self._total_size

    #
    # io.BufferedIOBase methods.
    #
    def detach(self):
        raise io.UnsupportedOperation("detach() not supported")

    def write(self, b):
        """Write the given bytes (binary string) to the GCS file.

        There's buffering happening under the covers, so this may not actually
        do any HTTP transfer right away."""

        if not isinstance(b, _BINARY_TYPES):
            raise TypeError("input must be one of %r, got: %r" % (_BINARY_TYPES, type(b)))

        self._current_part.write(b)
        self._total_size += len(b)

        #
        # If the size of this part is precisely equal to the minimum part size,
        # we don't perform the actual write now, and wait until we see more data.
        # We do this because the very last part of the upload must be handled slightly
        # differently (see comments in the _upload_part method).
        #
        if self._current_part.tell() > self._min_part_size:
            self._upload_part()

        return len(b)

    def terminate(self):
        """Cancel the underlying resumable upload."""
        #
        # https://cloud.google.com/storage/docs/xml-api/resumable-upload#example_cancelling_an_upload
        #
        self._session.delete(self._resumable_upload_url)

    #
    # Internal methods.
    #
    def _upload_part(self, is_last=False):
        part_num = self._total_parts + 1

        #
        # Here we upload the largest amount possible given GCS's restriction
        # of parts being multiples of 256kB, except for the last one.
        #
        # A final upload of 0 bytes does not work, so we need to guard against
        # this edge case. This results in occasionally keeping an additional
        # 256kB in the buffer after uploading a part, but until this is fixed
        # on Google's end there is no other option.
        #
        # https://stackoverflow.com/questions/60230631/upload-zero-size-final-part-to-google-cloud-storage-resumable-upload
        #
        content_length = self._current_part.tell()
        remainder = content_length % self._min_part_size
        if is_last:
            end = self._bytes_uploaded + content_length
        elif remainder == 0:
            content_length -= _REQUIRED_CHUNK_MULTIPLE
            end = None
        else:
            content_length -= remainder
            end = None

        range_stop = self._bytes_uploaded + content_length - 1
        content_range = _make_range_string(self._bytes_uploaded, range_stop, end=end)
        headers = {
            'Content-Length': str(content_length),
            'Content-Range': content_range,
        }
        logger.info(
            "uploading part #%i, %i bytes (total %.3fGB) headers %r",
            part_num, content_length, range_stop / 1024.0 ** 3, headers,
        )
        self._current_part.seek(0)
        response = self._session.put(
            self._resumable_upload_url,
            data=self._current_part.read(content_length),
            headers=headers,
        )

        if is_last:
            expected = _UPLOAD_COMPLETE_STATUS_CODES
        else:
            expected = _UPLOAD_INCOMPLETE_STATUS_CODES
        if response.status_code not in expected:
            _fail(response, part_num, content_length, self._total_size, headers)
        #
        # Lazy %-args: don't format the message unless DEBUG is enabled.
        #
        logger.debug("upload of part #%i finished", part_num)

        self._total_parts += 1
        self._bytes_uploaded += content_length

        #
        # For the last part, the below _current_part handling is a NOOP.
        #
        self._current_part = io.BytesIO(self._current_part.read())
        self._current_part.seek(0, io.SEEK_END)

    def _upload_empty_part(self):
        logger.debug("creating empty file")
        headers = {'Content-Length': '0'}
        response = self._session.put(self._resumable_upload_url, headers=headers)
        if response.status_code not in _UPLOAD_COMPLETE_STATUS_CODES:
            _fail(response, self._total_parts + 1, 0, self._total_size, headers)

        self._total_parts += 1

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            self.terminate()
        else:
            self.close()

    def __str__(self):
        return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name)

    def __repr__(self):
        return "%s(bucket=%r, blob=%r, min_part_size=%r)" % (
            self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._min_part_size,
        )