1# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"). You
4# may not use this file except in compliance with the License. A copy of
5# the License is located at
6#
7# http://aws.amazon.com/apache2.0/
8#
9# or in the "license" file accompanying this file. This file is
10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11# ANY KIND, either express or implied. See the License for the specific
12# language governing permissions and limitations under the License.
13import glob
14import os
15import time
16import threading
17
18from concurrent.futures import CancelledError
19
20from tests import assert_files_equal
21from tests import skip_if_windows
22from tests import skip_if_using_serial_implementation
23from tests import RecordingSubscriber
24from tests import NonSeekableWriter
25from tests.integration import BaseTransferManagerIntegTest
26from tests.integration import WaitForTransferStart
27from s3transfer.manager import TransferConfig
28from s3transfer.subscribers import BaseSubscriber
29
30
class TestDownload(BaseTransferManagerIntegTest):
    """Integration tests for downloads made through the transfer manager.

    Every test uploads a generated file to ``self.bucket_name`` and then
    downloads it again through a transfer manager built from ``self.config``.
    The fixtures used here (``self.files``, ``self.bucket_name``,
    ``self.upload_file`` and ``self.create_transfer_manager``) are not
    defined in this class; presumably they come from
    ``BaseTransferManagerIntegTest``.
    """

    def setUp(self):
        """Configure a 5 MB multipart threshold for all download tests."""
        super(TestDownload, self).setUp()
        self.multipart_threshold = 5 * 1024 * 1024
        self.config = TransferConfig(
            multipart_threshold=self.multipart_threshold
        )

    def _upload_file_with_size(self, key, filesize):
        """Create a local file of the requested size and upload it as ``key``.

        :param key: Name of the object to create in the test bucket.
        :param filesize: Size passed to ``create_file_with_size``.
        :returns: Path of the local source file, so that callers can compare
            it against the downloaded copy.
        """
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=filesize)
        self.upload_file(filename, key)
        return filename

    def _download_path(self, name):
        """Return the path for ``name`` inside the test's root directory."""
        return os.path.join(self.files.rootdir, name)

    def _assert_exited_quickly(self, start_time, end_time):
        """Assert the transfer manager shut down within the allowed time.

        :param start_time: ``time.time()`` value taken before the exit began.
        :param end_time: ``time.time()`` value taken once the context manager
            had exited.
        """
        # The maximum time, in seconds, allowed for the transfer manager
        # to exit after the exception was raised.
        max_allowed_exit_time = 5
        actual_time_to_exit = end_time - start_time
        self.assertLess(
            actual_time_to_exit, max_allowed_exit_time,
            "Failed to exit under %s. Instead exited in %s." % (
                max_allowed_exit_time, actual_time_to_exit)
        )

    def _assert_cancelled_by_keyboard_interrupt(self):
        """Return a context manager expecting a KeyboardInterrupt cancel.

        The body of the ``with`` block must raise ``CancelledError`` with a
        message containing the literal text ``KeyboardInterrupt()``.
        """
        # unittest on Python 2.7 only offers assertRaisesRegexp, while
        # Python 3.12 removed that deprecated alias in favor of
        # assertRaisesRegex. Look the modern name up first and only touch
        # the old one when it is really needed.
        assert_raises_regex = getattr(self, 'assertRaisesRegex', None)
        if assert_raises_regex is None:
            assert_raises_regex = self.assertRaisesRegexp
        # The parentheses are escaped so they are matched literally instead
        # of being interpreted as an empty regex group.
        # NOTE(review): assumes the cancellation message contains the
        # literal text 'KeyboardInterrupt()' -- confirm against the
        # transfer manager's shutdown code.
        return assert_raises_regex(CancelledError, r'KeyboardInterrupt\(\)')

    def test_below_threshold(self):
        """Download a 1 MB object (below the multipart threshold) by path."""
        transfer_manager = self.create_transfer_manager(self.config)
        filename = self._upload_file_with_size(
            '1mb.txt', filesize=1024 * 1024)

        download_path = self._download_path('1mb.txt')
        future = transfer_manager.download(
            self.bucket_name, '1mb.txt', download_path)
        future.result()
        assert_files_equal(filename, download_path)

    def test_above_threshold(self):
        """Download a 20 MB object (above the multipart threshold) by path."""
        transfer_manager = self.create_transfer_manager(self.config)
        filename = self._upload_file_with_size(
            '20mb.txt', filesize=20 * 1024 * 1024)

        download_path = self._download_path('20mb.txt')
        future = transfer_manager.download(
            self.bucket_name, '20mb.txt', download_path)
        future.result()
        assert_files_equal(filename, download_path)

    @skip_if_using_serial_implementation(
        'Exception is thrown once the transfer is submitted. '
        'However for the serial implementation, transfers are performed '
        'in main thread meaning the transfer will complete before the '
        'KeyboardInterrupt being thrown.'
    )
    def test_large_download_exits_quicky_on_exception(self):
        """Interrupting an in-progress 60 MB download must exit quickly.

        The test also checks that the future reports the cancellation and
        that neither the target file nor any temporary file is left behind.
        """
        transfer_manager = self.create_transfer_manager(self.config)
        self._upload_file_with_size('60mb.txt', filesize=60 * 1024 * 1024)

        download_path = self._download_path('60mb.txt')
        timeout = 10
        bytes_transferring = threading.Event()
        subscriber = WaitForTransferStart(bytes_transferring)
        try:
            with transfer_manager:
                future = transfer_manager.download(
                    self.bucket_name, '60mb.txt', download_path,
                    subscribers=[subscriber]
                )
                # Only interrupt once the subscriber has signalled the
                # event; otherwise give up on the test setup.
                if not bytes_transferring.wait(timeout):
                    future.cancel()
                    raise RuntimeError(
                        "Download transfer did not start after waiting for "
                        "%s seconds." % timeout)
                # Raise an exception which should cause the preceding
                # download to cancel and exit quickly. The clock is started
                # immediately before raising so that only the exit is timed.
                start_time = time.time()
                raise KeyboardInterrupt()
        except KeyboardInterrupt:
            pass
        end_time = time.time()
        self._assert_exited_quickly(start_time, end_time)

        # Make sure the future was cancelled because of the KeyboardInterrupt
        with self._assert_cancelled_by_keyboard_interrupt():
            future.result()

        # Make sure the actual file and the temporary do not exist
        # by globbing for the file and any of its extensions
        possible_matches = glob.glob('%s*' % download_path)
        self.assertEqual(possible_matches, [])

    @skip_if_using_serial_implementation(
        'Exception is thrown once the transfer is submitted. '
        'However for the serial implementation, transfers are performed '
        'in main thread meaning the transfer will complete before the '
        'KeyboardInterrupt being thrown.'
    )
    def test_many_files_exits_quicky_on_exception(self):
        """Interrupting a backed-up queue of downloads must exit quickly."""
        # Set the max request queue size and number of submission threads
        # to something small to simulate having a large queue
        # of transfer requests to complete and it is backed up.
        self.config.max_request_queue_size = 1
        self.config.max_submission_concurrency = 1
        transfer_manager = self.create_transfer_manager(self.config)
        self._upload_file_with_size('1mb.txt', filesize=1024 * 1024)

        download_paths = [
            self._download_path('file' + str(i)) for i in range(10)
        ]
        futures = []

        try:
            with transfer_manager:
                start_time = time.time()
                for download_path in download_paths:
                    futures.append(transfer_manager.download(
                        self.bucket_name, '1mb.txt', download_path))
                # Raise an exception which should cause the preceding
                # transfers to cancel and exit quickly
                raise KeyboardInterrupt()
        except KeyboardInterrupt:
            pass
        end_time = time.time()
        self._assert_exited_quickly(start_time, end_time)

        # Make sure at least one of the futures got cancelled
        with self._assert_cancelled_by_keyboard_interrupt():
            for future in futures:
                future.result()

        # ``future`` is still bound to the future whose result() raised
        # above. For that cancelled transfer, make sure the temporary
        # file got removed.
        possible_matches = glob.glob('%s*' % future.meta.call_args.fileobj)
        self.assertEqual(possible_matches, [])

    def test_progress_subscribers_on_download(self):
        """A subscriber must see every byte of a 20 MB download."""
        subscriber = RecordingSubscriber()
        transfer_manager = self.create_transfer_manager(self.config)
        self._upload_file_with_size('20mb.txt', filesize=20 * 1024 * 1024)

        download_path = self._download_path('20mb.txt')

        future = transfer_manager.download(
            self.bucket_name, '20mb.txt', download_path,
            subscribers=[subscriber])
        future.result()
        self.assertEqual(subscriber.calculate_bytes_seen(), 20 * 1024 * 1024)

    def test_below_threshold_for_fileobj(self):
        """Download a 1 MB object into an open binary file object."""
        transfer_manager = self.create_transfer_manager(self.config)
        filename = self._upload_file_with_size(
            '1mb.txt', filesize=1024 * 1024)

        download_path = self._download_path('1mb.txt')
        with open(download_path, 'wb') as f:
            future = transfer_manager.download(
                self.bucket_name, '1mb.txt', f)
            # Wait inside the ``with`` so the file stays open for the
            # whole transfer.
            future.result()
        assert_files_equal(filename, download_path)

    def test_above_threshold_for_fileobj(self):
        """Download a 20 MB object into an open binary file object."""
        transfer_manager = self.create_transfer_manager(self.config)
        filename = self._upload_file_with_size(
            '20mb.txt', filesize=20 * 1024 * 1024)

        download_path = self._download_path('20mb.txt')
        with open(download_path, 'wb') as f:
            future = transfer_manager.download(
                self.bucket_name, '20mb.txt', f)
            # Wait inside the ``with`` so the file stays open for the
            # whole transfer.
            future.result()
        assert_files_equal(filename, download_path)

    def test_below_threshold_for_nonseekable_fileobj(self):
        """Download a 1 MB object into a ``NonSeekableWriter`` wrapper."""
        transfer_manager = self.create_transfer_manager(self.config)
        filename = self._upload_file_with_size(
            '1mb.txt', filesize=1024 * 1024)

        download_path = self._download_path('1mb.txt')
        with open(download_path, 'wb') as f:
            future = transfer_manager.download(
                self.bucket_name, '1mb.txt', NonSeekableWriter(f))
            # Wait inside the ``with`` so the file stays open for the
            # whole transfer.
            future.result()
        assert_files_equal(filename, download_path)

    def test_above_threshold_for_nonseekable_fileobj(self):
        """Download a 20 MB object into a ``NonSeekableWriter`` wrapper."""
        transfer_manager = self.create_transfer_manager(self.config)
        filename = self._upload_file_with_size(
            '20mb.txt', filesize=20 * 1024 * 1024)

        download_path = self._download_path('20mb.txt')
        with open(download_path, 'wb') as f:
            future = transfer_manager.download(
                self.bucket_name, '20mb.txt', NonSeekableWriter(f))
            # Wait inside the ``with`` so the file stays open for the
            # whole transfer.
            future.result()
        assert_files_equal(filename, download_path)

    @skip_if_windows('Windows does not support UNIX special files')
    def test_download_to_special_file(self):
        """Downloading to ``/dev/null`` must complete without raising."""
        transfer_manager = self.create_transfer_manager(self.config)
        self._upload_file_with_size('1mb.txt', filesize=1024 * 1024)
        future = transfer_manager.download(
            self.bucket_name, '1mb.txt', '/dev/null')
        try:
            future.result()
        except Exception as e:
            # Report any failure as a test failure with a descriptive
            # message rather than as an error.
            self.fail(
                'Should have been able to download to /dev/null but received '
                'following exception %s' % e)
261