1# Copyright 2010 Google Inc.
2#
3# Permission is hereby granted, free of charge, to any person obtaining a
4# copy of this software and associated documentation files (the
5# "Software"), to deal in the Software without restriction, including
6# without limitation the rights to use, copy, modify, merge, publish, dis-
7# tribute, sublicense, and/or sell copies of the Software, and to permit
8# persons to whom the Software is furnished to do so, subject to the fol-
9# lowing conditions:
10#
11# The above copyright notice and this permission notice shall be included
12# in all copies or substantial portions of the Software.
13#
14# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
16# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
17# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
18# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20# IN THE SOFTWARE.
21
22"""
23Provides basic mocks of core storage service classes, for unit testing:
24ACL, Key, Bucket, Connection, and StorageUri. We implement a subset of
25the interfaces defined in the real boto classes, but don't handle most
26of the optional params (which we indicate with the constant "NOT_IMPL").
27"""
28
29import copy
30import boto
31import base64
32import re
33from hashlib import md5
34
35from boto.utils import compute_md5
36from boto.utils import find_matching_headers
37from boto.utils import merge_headers_by_name
38from boto.s3.prefix import Prefix
39from boto.compat import six
40
# Default value for optional parameters that the mocks accept only for
# signature compatibility with the real boto classes; such parameters are
# mostly not handled by the mock implementations.
NOT_IMPL = None
42
43
class MockAcl(object):
    """Stateless placeholder for an ACL object.

    The constructor argument and both SAX-style parsing hooks are accepted
    and discarded; serialisation always yields the same fixed XML string.
    """

    def __init__(self, parent=NOT_IMPL):
        """Accept ``parent`` and ignore it; the mock keeps no state."""

    def startElement(self, name, attrs, connection):
        """Element-start parsing hook; does nothing and returns None."""
        return None

    def endElement(self, name, value, connection):
        """Element-end parsing hook; does nothing and returns None."""
        return None

    def to_xml(self):
        """Return the constant XML stand-in for this ACL."""
        xml = '<mock_ACL_XML/>'
        return xml
58
class MockKey(object):
    """In-memory stand-in for a boto storage key.

    The object's contents live in ``self.data``.  Only a subset of the real
    key interface is provided; optional parameters that default to NOT_IMPL
    exist for signature compatibility and are, for the most part, not acted
    upon.
    """

    def __init__(self, bucket=None, name=None):
        """Create an empty, closed key.

        :param bucket: Bucket-like object owning this key (may be None).
        :param name: Name of the key.
        """
        self.bucket = bucket
        self.name = name
        self.data = None
        self.etag = None
        self.size = None
        self.closed = True
        # Offset of the next item returned by read(); open_read() resets it
        # whenever the key is (re)opened from the closed state.
        self.read_pos = 0
        self.content_encoding = None
        self.content_language = None
        self.content_type = None
        self.last_modified = 'Wed, 06 Oct 2010 05:11:54 GMT'
        # Chunk size used when draining a stream in
        # set_contents_from_stream().
        self.BufferSize = 8192

    def __repr__(self):
        if self.bucket:
            return '<MockKey: %s,%s>' % (self.bucket.name, self.name)
        else:
            return '<MockKey: %s>' % self.name

    def get_contents_as_string(self, headers=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL,
                               torrent=NOT_IMPL,
                               version_id=NOT_IMPL):
        """Return the stored contents (None if nothing has been stored)."""
        return self.data

    def get_contents_to_file(self, fp, headers=NOT_IMPL,
                             cb=NOT_IMPL, num_cb=NOT_IMPL,
                             torrent=NOT_IMPL,
                             version_id=NOT_IMPL,
                             res_download_handler=NOT_IMPL):
        """Write the stored contents to the file-like object ``fp``."""
        fp.write(self.data)

    def get_file(self, fp, headers=NOT_IMPL, cb=NOT_IMPL, num_cb=NOT_IMPL,
                 torrent=NOT_IMPL, version_id=NOT_IMPL,
                 override_num_retries=NOT_IMPL):
        """Write the stored contents to the file-like object ``fp``."""
        fp.write(self.data)

    def _handle_headers(self, headers):
        """Record content headers found in ``headers`` on this key.

        Content-Encoding, Content-Type and Content-Language are looked up
        with the boto.utils header helpers and stored on the matching
        attribute.  Nothing happens when ``headers`` is None or empty.
        """
        if not headers:
            return
        if find_matching_headers('Content-Encoding', headers):
            self.content_encoding = merge_headers_by_name('Content-Encoding',
                                                          headers)
        if find_matching_headers('Content-Type', headers):
            self.content_type = merge_headers_by_name('Content-Type', headers)
        if find_matching_headers('Content-Language', headers):
            self.content_language = merge_headers_by_name('Content-Language',
                                                          headers)

    # Simplistic partial implementation for headers: Just supports range GETs
    # of flavor 'Range: bytes=xyz-'.
    def open_read(self, headers=None, query_args=NOT_IMPL,
                  override_num_retries=NOT_IMPL):
        """Mark the key open for reading.

        The read position is reset to 0 if the key was closed.  A header of
        the form ``Range: bytes=N-`` moves the read position to N; any other
        Range value is ignored.
        """
        if self.closed:
            self.read_pos = 0
        self.closed = False
        if headers and 'Range' in headers:
            match = re.match('bytes=([0-9]+)-$', headers['Range'])
            if match:
                self.read_pos = int(match.group(1))

    def close(self, fast=NOT_IMPL):
        """Mark the key closed; the next read starts again from offset 0."""
        self.closed = True

    def read(self, size=0):
        """Return up to ``size`` items of data from the read position.

        :param size: Maximum amount to return; 0 means everything that
                     remains.
        :return: The requested slice of the contents.  When the slice is
                 empty the key is closed, so a later read starts over.
        """
        self.open_read()
        if size == 0:
            data = self.data[self.read_pos:]
            # Advance to the real end of the contents.  self.size is not
            # used here because it can be unset or stale when ``data`` was
            # assigned directly, which would make the next read return the
            # contents again instead of an empty result.
            self.read_pos = len(self.data)
        else:
            data = self.data[self.read_pos:self.read_pos + size]
            self.read_pos += size
        if not data:
            self.close()
        return data

    def set_contents_from_file(self, fp, headers=None, replace=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL,
                               policy=NOT_IMPL, md5=NOT_IMPL,
                               res_upload_handler=NOT_IMPL):
        """Store everything returned by ``fp.read()`` as the contents.

        Also refreshes the etag and size and records content headers.
        """
        self.data = fp.read()
        self.set_etag()
        self.size = len(self.data)
        self._handle_headers(headers)

    def set_contents_from_stream(self, fp, headers=None, replace=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL, policy=NOT_IMPL,
                               reduced_redundancy=NOT_IMPL, query_args=NOT_IMPL,
                               size=NOT_IMPL):
        """Store the contents of ``fp``, read in BufferSize chunks.

        Works for both text and binary streams.  An empty stream stores
        ``''``.  Also refreshes the etag and size and records content
        headers.
        """
        chunks = []
        chunk = fp.read(self.BufferSize)
        while chunk:
            chunks.append(chunk)
            chunk = fp.read(self.BufferSize)
        if chunks:
            # Join using an empty value of the stream's own type so that
            # bytes chunks are not concatenated onto a str (a TypeError on
            # Python 3); a single join also avoids quadratic copying.
            self.data = chunks[0][:0].join(chunks)
        else:
            self.data = ''
        self.set_etag()
        self.size = len(self.data)
        self._handle_headers(headers)

    def set_contents_from_string(self, s, headers=NOT_IMPL, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL, policy=NOT_IMPL,
                                 md5=NOT_IMPL, reduced_redundancy=NOT_IMPL):
        """Store a copy of ``s`` as the contents.

        Also refreshes the etag and size; ``headers``, if given, are
        recorded like in the other set_contents_* methods.
        """
        self.data = copy.copy(s)
        self.set_etag()
        self.size = len(s)
        self._handle_headers(headers)

    def set_contents_from_filename(self, filename, headers=None,
                                   replace=NOT_IMPL, cb=NOT_IMPL,
                                   num_cb=NOT_IMPL, policy=NOT_IMPL,
                                   md5=NOT_IMPL, res_upload_handler=NOT_IMPL):
        """Store the bytes of the file at ``filename`` as the contents."""
        # The context manager closes the file even if the upload raises.
        with open(filename, 'rb') as fp:
            self.set_contents_from_file(fp, headers, replace, cb, num_cb,
                                        policy, md5, res_upload_handler)

    def copy(self, dst_bucket_name, dst_key, metadata=NOT_IMPL,
             reduced_redundancy=NOT_IMPL, preserve_acl=NOT_IMPL):
        """Copy this key to ``dst_key`` in the bucket ``dst_bucket_name``.

        The destination bucket is looked up through this key's bucket's
        connection and the work is delegated to its ``copy_key``.

        :return: Whatever the destination bucket's ``copy_key`` returns.
        """
        dst_bucket = self.bucket.connection.get_bucket(dst_bucket_name)
        return dst_bucket.copy_key(dst_key, self.bucket.name,
                                   self.name, metadata)

    @property
    def provider(self):
        """Provider of the owning bucket's connection, or None.

        None is returned when the key has no bucket or the bucket has no
        connection.
        """
        provider = None
        if self.bucket and self.bucket.connection:
            provider = self.bucket.connection.provider
        return provider

    def set_etag(self):
        """
        Set etag attribute by generating hex MD5 checksum on current
        contents of mock key.
        """
        m = md5()
        # Non-bytes contents are hashed as their UTF-8 encoding.
        if not isinstance(self.data, bytes):
            m.update(self.data.encode('utf-8'))
        else:
            m.update(self.data)
        hex_md5 = m.hexdigest()
        self.etag = hex_md5

    def compute_md5(self, fp):
        """
        :type fp: file
        :param fp: File pointer to the file to MD5 hash.  The file pointer
                   will be reset to the beginning of the file before the
                   method returns.

        :rtype: tuple
        :return: A tuple containing the hex digest version of the MD5 hash
                 as the first element and the base64 encoded version of the
                 plain digest as the second element.
        """
        # Inside a method the bare name resolves to the module-level
        # compute_md5 imported from boto.utils, not to this method.
        tup = compute_md5(fp)
        # Returned values are MD5 hash, base64 encoded MD5 hash, and file size.
        # The internal implementation of compute_md5() needs to return the
        # file size but we don't want to return that value to the external
        # caller because it changes the class interface (i.e. it might
        # break some code) so we consume the third tuple value here and
        # return the remainder of the tuple to the caller, thereby preserving
        # the existing interface.
        self.size = tup[2]
        return tup[0:2]

class MockBucket(object):
    """In-memory stand-in for a boto storage bucket.

    Keys are kept in the ``keys`` dict and ACLs in the ``acls`` dict.
    Optional parameters that default to NOT_IMPL exist for signature
    compatibility and are, for the most part, not acted upon.
    """

    def __init__(self, connection=None, name=None, key_class=NOT_IMPL):
        """Create an empty bucket.

        :param connection: Connection-like object owning this bucket.
        :param name: Name of the bucket.
        """
        self.name = name
        # Maps key name -> MockKey.
        self.keys = {}
        # Maps key name -> ACL; the bucket's own ACL is stored in the same
        # dict under the bucket name.
        self.acls = {name: MockAcl()}
        # default object ACLs are one per bucket and not supported for keys
        self.def_acl = MockAcl()
        self.subresources = {}
        self.connection = connection
        self.logging = False

    def __repr__(self):
        return 'MockBucket: %s' % self.name

    def copy_key(self, new_key_name, src_bucket_name,
                 src_key_name, metadata=NOT_IMPL, src_version_id=NOT_IMPL,
                 storage_class=NOT_IMPL, preserve_acl=NOT_IMPL,
                 encrypt_key=NOT_IMPL, headers=NOT_IMPL, query_args=NOT_IMPL):
        """Copy a key's contents into a new key in this bucket.

        :param new_key_name: Name of the key to create in this bucket.
        :param src_bucket_name: Name of the bucket holding the source key.
        :param src_key_name: Name of the source key.
        :return: The newly created MockKey, with ``data`` and ``size`` set.
        :raises boto.exception.StorageResponseError: 404 if the source key
            does not exist (the source bucket lookup may raise as well).
        """
        src_key = self.connection.get_bucket(
            src_bucket_name).get_key(src_key_name)
        if src_key is None:
            raise boto.exception.StorageResponseError(404, 'Not Found')
        # Take the copy before creating the destination: new_key() replaces
        # any existing entry of that name, which would otherwise wipe the
        # source when a key is copied onto itself, and would leave an empty
        # orphan key behind if the source could not be read.
        data = copy.copy(src_key.data)
        size = len(data)
        new_key = self.new_key(key_name=new_key_name)
        new_key.data = data
        new_key.size = size
        return new_key

    def disable_logging(self):
        """Turn the bucket's logging flag off."""
        self.logging = False

    def enable_logging(self, target_bucket_prefix):
        """Turn the bucket's logging flag on; the argument is not stored."""
        self.logging = True

    def get_logging_config(self):
        """Return a fixed, empty logging configuration dict."""
        return {"Logging": {}}

    def get_versioning_status(self, headers=NOT_IMPL):
        """Always return False."""
        return False

    def get_acl(self, key_name='', headers=NOT_IMPL, version_id=NOT_IMPL):
        """Return the ACL for ``key_name``, or the bucket ACL if it is empty.

        :raises KeyError: if no ACL is recorded under ``key_name``.
        """
        if key_name:
            # Return ACL for the key.
            return self.acls[key_name]
        else:
            # Return ACL for the bucket.
            return self.acls[self.name]

    def get_def_acl(self, key_name=NOT_IMPL, headers=NOT_IMPL,
                    version_id=NOT_IMPL):
        """Return the bucket's default object ACL."""
        # Return default ACL for the bucket.
        return self.def_acl

    def get_subresource(self, subresource, key_name=NOT_IMPL, headers=NOT_IMPL,
                        version_id=NOT_IMPL):
        """Return the stored value for ``subresource``.

        If nothing was stored under that name, ``'<Subresource/>'`` is
        returned instead.
        """
        if subresource in self.subresources:
            return self.subresources[subresource]
        else:
            return '<Subresource/>'

    def new_key(self, key_name=None):
        """Create, register and return an empty MockKey named ``key_name``.

        Any existing key and ACL stored under that name are replaced.
        """
        mock_key = MockKey(self, key_name)
        self.keys[key_name] = mock_key
        self.acls[key_name] = MockAcl()
        return mock_key

    def delete_key(self, key_name, headers=NOT_IMPL,
                   version_id=NOT_IMPL, mfa_token=NOT_IMPL):
        """Remove the key named ``key_name`` from the bucket.

        :raises boto.exception.StorageResponseError: 404 if there is no
            such key.
        """
        if key_name not in self.keys:
            raise boto.exception.StorageResponseError(404, 'Not Found')
        del self.keys[key_name]

    def get_all_keys(self, headers=NOT_IMPL):
        """Return an iterator over the bucket's MockKey objects."""
        # Iterate over a snapshot so that callers may delete keys while
        # looping without hitting 'dictionary changed size during
        # iteration'.
        return iter(list(six.itervalues(self.keys)))

    def get_key(self, key_name, headers=NOT_IMPL, version_id=NOT_IMPL):
        """Return the key named ``key_name``, or None if it does not exist."""
        # Emulate behavior of boto when get_key called with non-existent key.
        if key_name not in self.keys:
            return None
        return self.keys[key_name]

    def list(self, prefix='', delimiter='', marker=NOT_IMPL,
             headers=NOT_IMPL):
        """List keys whose names start with ``prefix``.

        :param prefix: Only names starting with this string are considered;
                       None is treated as ''.
        :param delimiter: If non-empty, names containing the delimiter after
                          the prefix are rolled up into a single Prefix
                          entry ending at (and including) the first such
                          delimiter.
        :return: A list of Prefix objects and of fresh MockKey objects that
                 carry only the bucket and name, without duplicate names.
        """
        prefix = prefix or ''  # Turn None into '' for prefix match.
        # Return list instead of using a generator so we don't get
        # 'dictionary changed size during iteration' error when performing
        # deletions while iterating (e.g., during test cleanup).
        result = []
        key_name_set = set()
        for k in six.itervalues(self.keys):
            if not k.name.startswith(prefix):
                continue
            if delimiter:
                # Absolute position of the first delimiter past the prefix.
                pos = k.name.find(delimiter, len(prefix))
            else:
                pos = -1
            if pos != -1:
                # Include the whole delimiter, whatever its length.
                key_or_prefix = Prefix(
                    bucket=self, name=k.name[:pos + len(delimiter)])
            else:
                key_or_prefix = MockKey(bucket=self, name=k.name)
            if key_or_prefix.name not in key_name_set:
                key_name_set.add(key_or_prefix.name)
                result.append(key_or_prefix)
        return result

    def set_acl(self, acl_or_str, key_name='', headers=NOT_IMPL,
                version_id=NOT_IMPL):
        """Replace the ACL of ``key_name``, or of the bucket if it is empty.

        The stored value is a new MockAcl built from ``acl_or_str``.
        """
        if key_name:
            # Set ACL for the key.
            self.acls[key_name] = MockAcl(acl_or_str)
        else:
            # Set ACL for the bucket.
            self.acls[self.name] = MockAcl(acl_or_str)

    def set_def_acl(self, acl_or_str, key_name=NOT_IMPL, headers=NOT_IMPL,
                    version_id=NOT_IMPL):
        """Store ``acl_or_str`` as the bucket's default object ACL.

        The value is stored as given, so get_def_acl() returns exactly what
        was passed here (e.g. a canned ACL name string).
        """
        # Set default ACL for the bucket.
        self.def_acl = acl_or_str

    def set_subresource(self, subresource, value, key_name=NOT_IMPL,
                        headers=NOT_IMPL, version_id=NOT_IMPL):
        """Store ``value`` under the name ``subresource``."""
        self.subresources[subresource] = value
351
class MockProvider(object):
    """Minimal provider object that only remembers the value it was given."""

    def __init__(self, provider):
        """Store ``provider`` for later retrieval."""
        self.provider = provider

    def get_provider_name(self):
        """Return the value passed to the constructor."""
        name = self.provider
        return name
360
class MockConnection(object):
    """In-memory stand-in for a boto storage connection.

    Buckets are kept in the ``buckets`` dict, keyed by bucket name.
    Optional parameters that default to NOT_IMPL exist for signature
    compatibility and are not acted upon.
    """

    def __init__(self, aws_access_key_id=NOT_IMPL,
                 aws_secret_access_key=NOT_IMPL, is_secure=NOT_IMPL,
                 port=NOT_IMPL, proxy=NOT_IMPL, proxy_port=NOT_IMPL,
                 proxy_user=NOT_IMPL, proxy_pass=NOT_IMPL,
                 host=NOT_IMPL, debug=NOT_IMPL,
                 https_connection_factory=NOT_IMPL,
                 calling_format=NOT_IMPL,
                 path=NOT_IMPL, provider='s3',
                 bucket_class=NOT_IMPL):
        """Create a connection with no buckets.

        :param provider: Value wrapped in a MockProvider and exposed as
                         ``self.provider``.
        """
        self.buckets = {}
        self.provider = MockProvider(provider)

    def create_bucket(self, bucket_name, headers=NOT_IMPL, location=NOT_IMPL,
                      policy=NOT_IMPL, storage_class=NOT_IMPL):
        """Create, register and return a new MockBucket.

        :raises boto.exception.StorageCreateError: 409 if a bucket of that
            name already exists on this connection.
        """
        if bucket_name in self.buckets:
            raise boto.exception.StorageCreateError(
                409, 'BucketAlreadyOwnedByYou',
                "<Message>Your previous request to create the named bucket "
                "succeeded and you already own it.</Message>")
        mock_bucket = MockBucket(name=bucket_name, connection=self)
        self.buckets[bucket_name] = mock_bucket
        return mock_bucket

    def delete_bucket(self, bucket, headers=NOT_IMPL):
        """Remove the bucket registered under the name ``bucket``.

        The bucket is removed whether or not it still holds keys.

        :raises boto.exception.StorageResponseError: 404 if there is no
            such bucket.
        """
        if bucket not in self.buckets:
            raise boto.exception.StorageResponseError(
                404, 'NoSuchBucket', '<Message>no such bucket</Message>')
        del self.buckets[bucket]

    def get_bucket(self, bucket_name, validate=NOT_IMPL, headers=NOT_IMPL):
        """Return the bucket named ``bucket_name``.

        :raises boto.exception.StorageResponseError: 404 if there is no
            such bucket.
        """
        if bucket_name not in self.buckets:
            raise boto.exception.StorageResponseError(404, 'NoSuchBucket',
                                                 'Not Found')
        return self.buckets[bucket_name]

    def get_all_buckets(self, headers=NOT_IMPL):
        """Return an iterator over all buckets on this connection."""
        # Iterate over a snapshot so that callers may create or delete
        # buckets while looping (e.g. during test cleanup) without hitting
        # 'dictionary changed size during iteration'.
        return iter(list(six.itervalues(self.buckets)))
401
# We only mock a single provider/connection: this one module-level instance
# is what MockBucketStorageUri.connect() hands back on every call, so all
# mock URIs share the same set of buckets.
mock_connection = MockConnection()
404
405
class MockBucketStorageUri(object):
    """In-memory stand-in for a boto bucket storage URI.

    Every operation is routed through the connection returned by
    connect() to the mock bucket and key named by this URI.  Optional
    parameters that default to NOT_IMPL exist for signature compatibility
    and are, for the most part, not acted upon.
    """

    delim = '/'

    def __init__(self, scheme, bucket_name=None, object_name=None,
                 debug=NOT_IMPL, suppress_consec_slashes=NOT_IMPL,
                 version_id=None, generation=None, is_latest=False):
        """Build the URI from its parts.

        :param scheme: URI scheme, used as the leading ``scheme://``.
        :param bucket_name: Bucket name, or None for a provider-only URI.
        :param object_name: Object name, or None for a bucket URI.
        :param version_id: Optional version id recorded on the URI.
        :param generation: Optional generation; converted with int() when
                           truthy.
        :param is_latest: Stored as given.
        """
        self.scheme = scheme
        self.bucket_name = bucket_name
        self.object_name = object_name
        self.suppress_consec_slashes = suppress_consec_slashes
        if self.bucket_name and self.object_name:
            self.uri = ('%s://%s/%s' % (self.scheme, self.bucket_name,
                                        self.object_name))
        elif self.bucket_name:
            self.uri = ('%s://%s/' % (self.scheme, self.bucket_name))
        else:
            self.uri = ('%s://' % self.scheme)
        # The mock never embeds version information in ``uri``, so the
        # versionless form is the same string.  It is set unconditionally so
        # that the attribute also exists on bucket and provider URIs.
        self.versionless_uri = self.uri

        self.version_id = version_id
        self.generation = generation and int(generation)
        self.is_version_specific = (bool(self.generation)
                                    or bool(self.version_id))
        self.is_latest = is_latest

    def __repr__(self):
        """Returns string representation of URI."""
        return self.uri

    def acl_class(self):
        """Return the ACL class used by the mocks (MockAcl)."""
        return MockAcl

    def canned_acls(self):
        """Return the canned ACLs of boto's 'aws' provider."""
        return boto.provider.Provider('aws').canned_acls

    def clone_replace_name(self, new_name):
        """Return a new URI with the same scheme and bucket, named ``new_name``."""
        return self.__class__(self.scheme, self.bucket_name, new_name)

    def clone_replace_key(self, key):
        """Return a new URI that points at ``key``.

        The scheme comes from the key's provider name; version_id,
        generation and is_latest are copied from the key when present.
        """
        return self.__class__(
                key.provider.get_provider_name(),
                bucket_name=key.bucket.name,
                object_name=key.name,
                suppress_consec_slashes=self.suppress_consec_slashes,
                version_id=getattr(key, 'version_id', None),
                generation=getattr(key, 'generation', None),
                is_latest=getattr(key, 'is_latest', None))

    def connect(self, access_key_id=NOT_IMPL, secret_access_key=NOT_IMPL):
        """Return the shared module-level mock connection."""
        return mock_connection

    def create_bucket(self, headers=NOT_IMPL, location=NOT_IMPL,
                      policy=NOT_IMPL, storage_class=NOT_IMPL):
        """Create and return the bucket named by this URI."""
        return self.connect().create_bucket(self.bucket_name)

    def delete_bucket(self, headers=NOT_IMPL):
        """Delete the bucket named by this URI."""
        return self.connect().delete_bucket(self.bucket_name)

    def get_versioning_config(self, headers=NOT_IMPL):
        """Return the versioning status reported by the bucket."""
        # The result must be returned; without the return statement callers
        # always received None.
        return self.get_bucket().get_versioning_status(headers)

    def has_version(self):
        """Return True if a version_id or generation is set on this URI."""
        return (issubclass(type(self), MockBucketStorageUri)
                and ((self.version_id is not None)
                     or (self.generation is not None)))

    def delete_key(self, validate=NOT_IMPL, headers=NOT_IMPL,
                   version_id=NOT_IMPL, mfa_token=NOT_IMPL):
        """Delete the object named by this URI from its bucket."""
        self.get_bucket().delete_key(self.object_name)

    def disable_logging(self, validate=NOT_IMPL, headers=NOT_IMPL,
                        version_id=NOT_IMPL):
        """Disable logging on the bucket."""
        self.get_bucket().disable_logging()

    def enable_logging(self, target_bucket, target_prefix, validate=NOT_IMPL,
                       headers=NOT_IMPL, version_id=NOT_IMPL):
        """Enable logging on the bucket; ``target_prefix`` is not used."""
        self.get_bucket().enable_logging(target_bucket)

    def get_logging_config(self, validate=NOT_IMPL, headers=NOT_IMPL,
                           version_id=NOT_IMPL):
        """Return the bucket's logging configuration."""
        return self.get_bucket().get_logging_config()

    def equals(self, uri):
        """Return True if ``uri`` has the same URI string as this one."""
        return self.uri == uri.uri

    def get_acl(self, validate=NOT_IMPL, headers=NOT_IMPL, version_id=NOT_IMPL):
        """Return the ACL of the object, or of the bucket if no object is named."""
        return self.get_bucket().get_acl(self.object_name)

    def get_def_acl(self, validate=NOT_IMPL, headers=NOT_IMPL,
                    version_id=NOT_IMPL):
        """Return the bucket's default object ACL."""
        return self.get_bucket().get_def_acl(self.object_name)

    def get_subresource(self, subresource, validate=NOT_IMPL, headers=NOT_IMPL,
                        version_id=NOT_IMPL):
        """Return the named subresource as reported by the bucket."""
        return self.get_bucket().get_subresource(subresource, self.object_name)

    def get_all_buckets(self, headers=NOT_IMPL):
        """Return all buckets known to the connection."""
        return self.connect().get_all_buckets()

    def get_all_keys(self, validate=NOT_IMPL, headers=NOT_IMPL):
        """Return all keys of the bucket."""
        # Do not pass ``self`` along: the bucket method's only parameter is
        # ``headers``, and a URI object is not a headers value.
        return self.get_bucket().get_all_keys()

    def list_bucket(self, prefix='', delimiter='', headers=NOT_IMPL,
                    all_versions=NOT_IMPL):
        """Return the bucket listing for ``prefix`` and ``delimiter``."""
        return self.get_bucket().list(prefix=prefix, delimiter=delimiter)

    def get_bucket(self, validate=NOT_IMPL, headers=NOT_IMPL):
        """Return the bucket named by this URI from the connection."""
        return self.connect().get_bucket(self.bucket_name)

    def get_key(self, validate=NOT_IMPL, headers=NOT_IMPL,
                version_id=NOT_IMPL):
        """Return the key named by this URI as reported by the bucket."""
        return self.get_bucket().get_key(self.object_name)

    def is_file_uri(self):
        """Always False: this URI never names a local file."""
        return False

    def is_cloud_uri(self):
        """Always True."""
        return True

    def names_container(self):
        """Return True if no object name is set."""
        return bool(not self.object_name)

    def names_singleton(self):
        """Return True if an object name is set."""
        return bool(self.object_name)

    def names_directory(self):
        """Always False."""
        return False

    def names_provider(self):
        """Return True if no bucket name is set."""
        return bool(not self.bucket_name)

    def names_bucket(self):
        """Same result as names_container()."""
        return self.names_container()

    def names_file(self):
        """Always False."""
        return False

    def names_object(self):
        """Return True if an object name is set."""
        return not self.names_container()

    def is_stream(self):
        """Always False."""
        return False

    def new_key(self, validate=NOT_IMPL, headers=NOT_IMPL):
        """Create and return a new key for this URI's object in its bucket."""
        bucket = self.get_bucket()
        return bucket.new_key(self.object_name)

    def set_acl(self, acl_or_str, key_name='', validate=NOT_IMPL,
                headers=NOT_IMPL, version_id=NOT_IMPL):
        """Set the ACL of ``key_name``, or of the bucket if it is empty."""
        self.get_bucket().set_acl(acl_or_str, key_name)

    def set_def_acl(self, acl_or_str, key_name=NOT_IMPL, validate=NOT_IMPL,
                    headers=NOT_IMPL, version_id=NOT_IMPL):
        """Set the bucket's default object ACL."""
        self.get_bucket().set_def_acl(acl_or_str)

    def set_subresource(self, subresource, value, validate=NOT_IMPL,
                        headers=NOT_IMPL, version_id=NOT_IMPL):
        """Store ``value`` as the named subresource on the bucket."""
        self.get_bucket().set_subresource(subresource, value, self.object_name)

    def copy_key(self, src_bucket_name, src_key_name, metadata=NOT_IMPL,
                 src_version_id=NOT_IMPL, storage_class=NOT_IMPL,
                 preserve_acl=NOT_IMPL, encrypt_key=NOT_IMPL, headers=NOT_IMPL,
                 query_args=NOT_IMPL, src_generation=NOT_IMPL):
        """Copy the given source key to the object named by this URI.

        :return: Whatever the destination bucket's ``copy_key`` returns.
        """
        dst_bucket = self.get_bucket()
        return dst_bucket.copy_key(new_key_name=self.object_name,
                                   src_bucket_name=src_bucket_name,
                                   src_key_name=src_key_name)

    def set_contents_from_string(self, s, headers=NOT_IMPL, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL, policy=NOT_IMPL,
                                 md5=NOT_IMPL, reduced_redundancy=NOT_IMPL):
        """Create a new key for this URI and store ``s`` in it."""
        key = self.new_key()
        key.set_contents_from_string(s)

    def set_contents_from_file(self, fp, headers=None, replace=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL, policy=NOT_IMPL,
                               md5=NOT_IMPL, size=NOT_IMPL, rewind=NOT_IMPL,
                               res_upload_handler=NOT_IMPL):
        """Create a new key for this URI and fill it from ``fp``."""
        key = self.new_key()
        return key.set_contents_from_file(fp, headers=headers)

    def set_contents_from_stream(self, fp, headers=NOT_IMPL, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL, policy=NOT_IMPL,
                                 reduced_redundancy=NOT_IMPL,
                                 query_args=NOT_IMPL, size=NOT_IMPL):
        """Create a new key for this URI and fill it from the stream ``fp``."""
        # The destination key has to be created first; it was previously
        # referenced without ever being defined, raising NameError.
        dst_key = self.new_key()
        return dst_key.set_contents_from_stream(fp)

    def get_contents_to_file(self, fp, headers=NOT_IMPL, cb=NOT_IMPL,
                             num_cb=NOT_IMPL, torrent=NOT_IMPL,
                             version_id=NOT_IMPL, res_download_handler=NOT_IMPL,
                             response_headers=NOT_IMPL):
        """Write the contents of this URI's key to ``fp``."""
        key = self.get_key()
        key.get_contents_to_file(fp)

    def get_contents_to_stream(self, fp, headers=NOT_IMPL, cb=NOT_IMPL,
                               num_cb=NOT_IMPL, version_id=NOT_IMPL):
        """Write the contents of this URI's key to the stream ``fp``."""
        key = self.get_key()
        return key.get_contents_to_file(fp)