# Copyright (C) 2015-2021 Regents of the University of California
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import bz2
import errno
import logging
import os
import socket
import types
from ssl import SSLError
from typing import Optional

from boto.exception import (
    BotoServerError,
    SDBResponseError,
    S3ResponseError
)
from boto3.s3.transfer import TransferConfig
from botocore.exceptions import ClientError

from toil.lib.compatibility import compat_bytes
from toil.lib.retry import (
    old_retry,
    retry,
    ErrorCondition
)

logger = logging.getLogger(__name__)


class SDBHelper(object):
    """
    A mixin with methods for storing limited amounts of binary data in an SDB item.

    >>> import os
    >>> H=SDBHelper
    >>> H.presenceIndicator() # doctest: +ALLOW_UNICODE
    u'numChunks'
    >>> H.binaryToAttributes(None)['numChunks']
    0
    >>> H.attributesToBinary({u'numChunks': 0})
    (None, 0)
    >>> H.binaryToAttributes(b'') # doctest: +ALLOW_UNICODE +ALLOW_BYTES
    {u'000': b'VQ==', u'numChunks': 1}
    >>> H.attributesToBinary({u'numChunks': 1, u'000': b'VQ=='}) # doctest: +ALLOW_BYTES
    (b'', 1)

    Good pseudo-random data is very likely smaller than its bzip2ed form, so it will be stored
    uncompressed. Subtract 1 for the type character, i.e. 'C' or 'U', with which the string is
    prefixed. We should get one full chunk:

    >>> s = os.urandom(H.maxRawValueSize-1)
    >>> d = H.binaryToAttributes(s)
    >>> len(d), len(d['000'])
    (2, 1024)
    >>> H.attributesToBinary(d) == (s, 1)
    True

    One byte more and we should overflow four bytes into the second chunk, two bytes for
    base64-encoding the additional character and two bytes for base64-padding to the next quartet.

    >>> s += s[0:1]
    >>> d = H.binaryToAttributes(s)
    >>> len(d), len(d['000']), len(d['001'])
    (3, 1024, 4)
    >>> H.attributesToBinary(d) == (s, 2)
    True

    """
    # The SDB documentation is not clear as to whether the attribute value size limit of 1024
    # applies to the base64-encoded value or the raw value. It suggests that responses are
    # automatically encoded, from which I conclude that the limit should apply to the raw,
    # unencoded value. However, there seems to be a discrepancy between how Boto computes the
    # request signature if a value contains binary data, and how SDB does it. This causes
    # requests to fail signature verification, resulting in a 403. We therefore have to
    # base64-encode values ourselves even if that means we lose a quarter of the capacity.

    maxAttributesPerItem = 256
    maxValueSize = 1024
    maxRawValueSize = maxValueSize * 3 // 4
    # Just make sure we don't have a problem with padding or integer truncation:
    assert len(base64.b64encode(b' ' * maxRawValueSize)) == 1024
    assert len(base64.b64encode(b' ' * (1 + maxRawValueSize))) > 1024

    @classmethod
    def _reservedAttributes(cls):
        """
        Override in subclass to reserve a certain number of attributes that can't be used for
        chunks.
        """
        return 1

    @classmethod
    def _maxChunks(cls):
        return cls.maxAttributesPerItem - cls._reservedAttributes()

    @classmethod
    def maxBinarySize(cls, extraReservedChunks=0):
        return (cls._maxChunks() - extraReservedChunks) * cls.maxRawValueSize - 1  # for the 'C' or 'U' prefix

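    # For orientation (a sketch of the arithmetic, assuming the defaults above and the single
    # reserved attribute): maxBinarySize() evaluates to
    #
    #   (256 - 1) * (1024 * 3 // 4) - 1 == 195839
    #
    # i.e. just under 192 KiB of raw payload per item.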

    @classmethod
    def _maxEncodedSize(cls):
        return cls._maxChunks() * cls.maxValueSize

    @classmethod
    def binaryToAttributes(cls, binary):
        """
        Turn a bytestring, or None, into SimpleDB attributes.
        """
        if binary is None: return {u'numChunks': 0}
        assert isinstance(binary, bytes)
        assert len(binary) <= cls.maxBinarySize()
        # The use of compression is just an optimization. We can't include it in the maxValueSize
        # computation because the compression ratio depends on the input.
        compressed = bz2.compress(binary)
        if len(compressed) > len(binary):
            compressed = b'U' + binary
        else:
            compressed = b'C' + compressed
        encoded = base64.b64encode(compressed)
        assert len(encoded) <= cls._maxEncodedSize()
        n = cls.maxValueSize
        chunks = (encoded[i:i + n] for i in range(0, len(encoded), n))
        attributes = {cls._chunkName(i): chunk for i, chunk in enumerate(chunks)}
        attributes.update({u'numChunks': len(attributes)})
        return attributes

    @classmethod
    def _chunkName(cls, i):
        return str(i).zfill(3)

    @classmethod
    def _isValidChunkName(cls, s):
        return len(s) == 3 and s.isdigit()

    @classmethod
    def presenceIndicator(cls):
        """
        The key that is guaranteed to be present in the return value of binaryToAttributes().
        Assuming that binaryToAttributes() is used with SDB's PutAttributes, the return value of
        this method could be used to detect the presence/absence of an item in SDB.
        """
        return u'numChunks'

    @classmethod
    def attributesToBinary(cls, attributes):
        """
        :rtype: (bytes|None, int)
        :return: the binary data and the number of chunks it was composed from
        """
        chunks = [(int(k), v) for k, v in attributes.items() if cls._isValidChunkName(k)]
        chunks.sort()
        numChunks = int(attributes[u'numChunks'])
        if numChunks:
            # SDB returns str values, but the doctests above pass bytes, so accept both.
            serializedJob = b''.join(v.encode() if isinstance(v, str) else v for k, v in chunks)
            compressed = base64.b64decode(serializedJob)
            if compressed[0] == b'C'[0]:
                binary = bz2.decompress(compressed[1:])
            elif compressed[0] == b'U'[0]:
                binary = compressed[1:]
            else:
                raise RuntimeError('Unexpected prefix {}'.format(compressed[0]))
        else:
            binary = None
        return binary, numChunks


def fileSizeAndTime(localFilePath):
    file_stat = os.stat(localFilePath)
    return file_stat.st_size, file_stat.st_mtime


@retry(errors=[ErrorCondition(
    error=ClientError,
    error_codes=[404, 500, 502, 503, 504]
)])
def uploadFromPath(localFilePath: str,
                   resource,
                   bucketName: str,
                   fileID: str,
                   headerArgs: Optional[dict] = None,
                   partSize: int = 50 << 20):
    """
    Upload a file to S3, using multipart uploading if applicable.

    :param str localFilePath: Path of the file to upload to S3
    :param S3.Resource resource: boto3 resource
    :param str bucketName: name of the bucket to upload to
    :param str fileID: the key (name) to store the file under in the bucket
    :param dict headerArgs: http headers to use when uploading - generally used for encryption purposes
    :param int partSize: max size of each part in the multipart upload, in bytes

    :return: version of the newly uploaded file
    """
    if headerArgs is None:
        headerArgs = {}

    client = resource.meta.client
    file_size, file_time = fileSizeAndTime(localFilePath)

    version = uploadFile(localFilePath, resource, bucketName, fileID, headerArgs, partSize)
    info = client.head_object(Bucket=bucketName, Key=compat_bytes(fileID), VersionId=version, **headerArgs)
    size = info.get('ContentLength')

    assert size == file_size

    # Make reasonably sure that the file wasn't touched during the upload
    assert fileSizeAndTime(localFilePath) == (file_size, file_time)
    return version

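# Illustrative usage sketch, not part of the module's API: the bucket and key names below are
# hypothetical, and a real call needs AWS credentials and an existing (ideally versioned) bucket.
#
#   import boto3
#   s3 = boto3.resource('s3')
#   version = uploadFromPath('/tmp/example.bin', s3, 'my-bucket', 'some/key')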

@retry(errors=[ErrorCondition(
    error=ClientError,
    error_codes=[404, 500, 502, 503, 504]
)])
def uploadFile(readable,
               resource,
               bucketName: str,
               fileID: str,
               headerArgs: Optional[dict] = None,
               partSize: int = 50 << 20):
    """
    Upload a readable object to S3, using multipart uploading if applicable.

    :param readable: a readable stream or a file path to upload to S3
    :param S3.Resource resource: boto3 resource
    :param str bucketName: name of the bucket to upload to
    :param str fileID: the key (name) to store the file under in the bucket
    :param dict headerArgs: http headers to use when uploading - generally used for encryption purposes
    :param int partSize: max size of each part in the multipart upload, in bytes
    :return: version of the newly uploaded file
    """
    if headerArgs is None:
        headerArgs = {}

    client = resource.meta.client
    config = TransferConfig(
        multipart_threshold=partSize,
        multipart_chunksize=partSize,
        use_threads=True
    )
    if isinstance(readable, str):
        client.upload_file(Filename=readable,
                           Bucket=bucketName,
                           Key=compat_bytes(fileID),
                           ExtraArgs=headerArgs,
                           Config=config)
    else:
        client.upload_fileobj(Fileobj=readable,
                              Bucket=bucketName,
                              Key=compat_bytes(fileID),
                              ExtraArgs=headerArgs,
                              Config=config)

        # Wait until the object exists before calling head_object
        object_summary = resource.ObjectSummary(bucketName, compat_bytes(fileID))
        object_summary.wait_until_exists(**headerArgs)

    info = client.head_object(Bucket=bucketName, Key=compat_bytes(fileID), **headerArgs)
    return info.get('VersionId', None)

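# Illustrative sketch of a streaming upload (hypothetical names; requires credentials and an
# existing bucket):
#
#   import io
#   import boto3
#   s3 = boto3.resource('s3')
#   version = uploadFile(io.BytesIO(b'example payload'), s3, 'my-bucket', 'some/key')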

@retry(errors=[ErrorCondition(
    error=ClientError,
    error_codes=[404, 500, 502, 503, 504]
)])
def copyKeyMultipart(resource,
                     srcBucketName: str,
                     srcKeyName: str,
                     srcKeyVersion: str,
                     dstBucketName: str,
                     dstKeyName: str,
                     sseAlgorithm: Optional[str] = None,
                     sseKey: Optional[str] = None,
                     copySourceSseAlgorithm: Optional[str] = None,
                     copySourceSseKey: Optional[str] = None):
    """
    Copies a source key to a destination key in multiple parts. Note that if the destination
    key exists it will be overwritten implicitly, and if it does not exist a new key will be
    created. If the destination bucket does not exist an error will be raised.

    :param S3.Resource resource: boto3 resource
    :param str srcBucketName: The name of the bucket to be copied from.
    :param str srcKeyName: The name of the key to be copied from.
    :param str srcKeyVersion: The version of the key to be copied from.
    :param str dstBucketName: The name of the destination bucket for the copy.
    :param str dstKeyName: The name of the destination key that will be created or overwritten.
    :param str sseAlgorithm: Server-side encryption algorithm for the destination.
    :param str sseKey: Server-side encryption key for the destination.
    :param str copySourceSseAlgorithm: Server-side encryption algorithm for the source.
    :param str copySourceSseKey: Server-side encryption key for the source.

    :rtype: str
    :return: The version of the copied file (or None if versioning is not enabled for dstBucket).
    """
    dstBucket = resource.Bucket(compat_bytes(dstBucketName))
    dstObject = dstBucket.Object(compat_bytes(dstKeyName))
    copySource = {'Bucket': compat_bytes(srcBucketName), 'Key': compat_bytes(srcKeyName)}
    if srcKeyVersion is not None:
        copySource['VersionId'] = compat_bytes(srcKeyVersion)

    # The boto3 functions don't allow passing parameters as None to
    # indicate they weren't provided. So we have to do a bit of work
    # to ensure we only provide the parameters when they are actually
    # required.
    destEncryptionArgs = {}
    if sseKey is not None:
        destEncryptionArgs.update({'SSECustomerAlgorithm': sseAlgorithm,
                                   'SSECustomerKey': sseKey})
    copyEncryptionArgs = {}
    if copySourceSseKey is not None:
        copyEncryptionArgs.update({'CopySourceSSECustomerAlgorithm': copySourceSseAlgorithm,
                                   'CopySourceSSECustomerKey': copySourceSseKey})
    copyEncryptionArgs.update(destEncryptionArgs)

    dstObject.copy(copySource, ExtraArgs=copyEncryptionArgs)

    # Wait until the object exists before calling head_object
    object_summary = resource.ObjectSummary(dstObject.bucket_name, dstObject.key)
    object_summary.wait_until_exists(**destEncryptionArgs)

    # Unfortunately, boto3's managed copy doesn't return the version
    # that it actually copied to. So we have to check immediately
    # after, leaving open the possibility that it may have been
    # modified again in the few seconds since the copy finished. There
    # isn't much we can do about it.
    info = resource.meta.client.head_object(Bucket=dstObject.bucket_name,
                                            Key=dstObject.key,
                                            **destEncryptionArgs)
    return info.get('VersionId', None)

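# Illustrative sketch (hypothetical bucket and key names; passing None for srcKeyVersion copies
# the current version of the source key):
#
#   import boto3
#   s3 = boto3.resource('s3')
#   new_version = copyKeyMultipart(s3, 'src-bucket', 'src/key', None, 'dst-bucket', 'dst/key')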

def _put_attributes_using_post(self, domain_or_name, item_name, attributes,
                               replace=True, expected_value=None):
    """
    Monkey-patched version of SDBConnection.put_attributes that uses POST instead of GET.

    The GET version is subject to the URL length limit, which kicks in before the 256 x 1024
    limit for attribute values. Using POST prevents that.

    https://github.com/BD2KGenomics/toil/issues/502
    """
    domain, domain_name = self.get_domain_and_name(domain_or_name)
    params = {'DomainName': domain_name,
              'ItemName': item_name}
    self._build_name_value_list(params, attributes, replace)
    if expected_value:
        self._build_expected_value(params, expected_value)
    # The addition of the verb keyword argument is the only difference from put_attributes (Hannes)
    return self.get_status('PutAttributes', params, verb='POST')


def monkeyPatchSdbConnection(sdb):
    """
    Replace put_attributes on the given connection with the POST-based version above.

    :type sdb: SDBConnection
    """
    sdb.put_attributes = types.MethodType(_put_attributes_using_post, sdb)

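# Illustrative sketch (hypothetical region and domain names; boto must be configured with
# credentials):
#
#   import boto.sdb
#   sdb = boto.sdb.connect_to_region('us-west-2')
#   monkeyPatchSdbConnection(sdb)
#   sdb.put_attributes('my-domain', 'item-0', {'numChunks': '0'})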

default_delays = (0, 1, 1, 4, 16, 64)
default_timeout = 300


def connection_reset(e):
    # The 'Connection reset by peer' errors we see report errno 104 (the Linux value for
    # ECONNRESET), while errno.ECONNRESET may be 54 on the platform running this code
    # (e.g. macOS/BSD). To be safe, we check for both:
    return isinstance(e, socket.error) and e.errno in (errno.ECONNRESET, 104)


def sdb_unavailable(e):
    return isinstance(e, BotoServerError) and e.status in (500, 503)


def no_such_sdb_domain(e):
    return (isinstance(e, SDBResponseError)
            and e.error_code
            and e.error_code.endswith('NoSuchDomain'))


def retryable_ssl_error(e):
    # https://github.com/BD2KGenomics/toil/issues/978
    return isinstance(e, SSLError) and e.reason == 'DECRYPTION_FAILED_OR_BAD_RECORD_MAC'


def retryable_sdb_errors(e):
    return (sdb_unavailable(e)
            or no_such_sdb_domain(e)
            or connection_reset(e)
            or retryable_ssl_error(e))


def retry_sdb(delays=default_delays, timeout=default_timeout, predicate=retryable_sdb_errors):
    return old_retry(delays=delays, timeout=timeout, predicate=predicate)

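# Illustrative sketch of how the old_retry-based helpers are typically used (the 'domain'
# object and attribute values are placeholders):
#
#   for attempt in retry_sdb():
#       with attempt:
#           domain.put_attributes('item-0', {'numChunks': '0'})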


# https://github.com/boto/botocore/blob/49f87350d54f55b687969ec8bf204df785975077/botocore/retries/standard.py#L316
THROTTLED_ERROR_CODES = [
    'Throttling',
    'ThrottlingException',
    'ThrottledException',
    'RequestThrottledException',
    'TooManyRequestsException',
    'ProvisionedThroughputExceededException',
    'TransactionInProgressException',
    'RequestLimitExceeded',
    'BandwidthLimitExceeded',
    'LimitExceededException',
    'RequestThrottled',
    'SlowDown',
    'PriorRequestNotComplete',
    'EC2ThrottledException',
]


# TODO: Replace with: @retry and ErrorCondition
def retryable_s3_errors(e):
    return (connection_reset(e)
            or (isinstance(e, BotoServerError) and e.status in (429, 500))
            or (isinstance(e, BotoServerError) and e.code in THROTTLED_ERROR_CODES)
            or (isinstance(e, S3ResponseError) and e.error_code in THROTTLED_ERROR_CODES)
            # boto3 errors
            or (isinstance(e, ClientError) and 'BucketNotEmpty' in str(e))
            or (isinstance(e, ClientError) and e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 409 and 'try again' in str(e))
            or (isinstance(e, ClientError) and e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') in (404, 429, 500, 502, 503, 504)))


def retry_s3(delays=default_delays, timeout=default_timeout, predicate=retryable_s3_errors):
    return old_retry(delays=delays, timeout=timeout, predicate=predicate)

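# Illustrative sketch of the corresponding S3 retry pattern (bucket and key names are
# placeholders; 'resource' is a boto3 S3 resource):
#
#   for attempt in retry_s3():
#       with attempt:
#           resource.meta.client.delete_object(Bucket='my-bucket', Key='some/key')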

def region_to_bucket_location(region):
    return '' if region == 'us-east-1' else region


def bucket_location_to_region(location):
    return 'us-east-1' if location == '' else location
