# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
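"""Machinery for recording and displaying the results of S3 transfers.

Subscribers defined in this module put ``*Result`` namedtuples onto a result
queue as transfers are queued, make progress, succeed, or fail.  A
``ResultProcessor`` thread consumes that queue and dispatches each result to
handlers such as ``ResultRecorder`` (statistics) and ``ResultPrinter``
(console output), while ``CommandResultRecorder`` ties the pieces together
for a single command.

A rough sketch of how the pieces fit together (the exact wiring lives in the
command code that uses this module)::

    result_queue = queue.Queue()
    result_recorder = ResultRecorder()
    result_printer = ResultPrinter(result_recorder)
    result_processor = ResultProcessor(
        result_queue, [result_recorder, result_printer])
    with CommandResultRecorder(
            result_queue, result_recorder, result_processor) as recorder:
        ...  # submit transfers whose subscribers feed result_queue
    command_result = recorder.get_command_result()
"""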
from __future__ import division
import logging
import sys
import threading
import time
from collections import namedtuple
from collections import defaultdict

from s3transfer.exceptions import CancelledError
from s3transfer.exceptions import FatalError
from s3transfer.subscribers import BaseSubscriber

from awscli.compat import queue, ensure_text_type
from awscli.customizations.s3.utils import relative_path
from awscli.customizations.s3.utils import human_readable_size
from awscli.customizations.utils import uni_print
from awscli.customizations.s3.utils import WarningResult
from awscli.customizations.s3.utils import OnDoneFilteredSubscriber


LOGGER = logging.getLogger(__name__)


BaseResult = namedtuple('BaseResult', ['transfer_type', 'src', 'dest'])


def _create_new_result_cls(name, extra_fields=None, base_cls=BaseResult):
    # Creates a new namedtuple class that subclasses from BaseResult for the
    # benefit of filtering by type and ensuring particular base attrs.

    # NOTE: _fields is a public attribute that has an underscore to avoid
    # naming collisions for namedtuples:
    # https://docs.python.org/2/library/collections.html#collections.somenamedtuple._fields
    fields = list(base_cls._fields)
    if extra_fields:
        fields += extra_fields
    return type(name, (namedtuple(name, fields), base_cls), {})


QueuedResult = _create_new_result_cls('QueuedResult', ['total_transfer_size'])

ProgressResult = _create_new_result_cls(
    'ProgressResult', ['bytes_transferred', 'total_transfer_size',
                       'timestamp'])

SuccessResult = _create_new_result_cls('SuccessResult')

FailureResult = _create_new_result_cls('FailureResult', ['exception'])

DryRunResult = _create_new_result_cls('DryRunResult')

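# Unlike FailureResult, which is tied to an individual transfer, ErrorResult
# represents an error that is fatal to the entire command.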
ErrorResult = namedtuple('ErrorResult', ['exception'])

CtrlCResult = _create_new_result_cls('CtrlCResult', base_cls=ErrorResult)

CommandResult = namedtuple(
    'CommandResult', ['num_tasks_failed', 'num_tasks_warned'])

FinalTotalSubmissionsResult = namedtuple(
    'FinalTotalSubmissionsResult', ['total_submissions'])


class ShutdownThreadRequest(object):
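    """Sentinel placed on the result queue to tell the processing thread to
    shut down."""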
    pass


class BaseResultSubscriber(OnDoneFilteredSubscriber):
    TRANSFER_TYPE = None

    def __init__(self, result_queue, transfer_type=None):
        """Subscriber to send result notifications during the transfer process

        :param result_queue: The queue to place results to be processed later
            on.

        :param transfer_type: The type of transfer this subscriber reports
            results for. If not provided, the class-level ``TRANSFER_TYPE``
            is used.
        """
        self._result_queue = result_queue
        self._result_kwargs_cache = {}
        self._transfer_type = transfer_type
        if transfer_type is None:
            self._transfer_type = self.TRANSFER_TYPE

    def on_queued(self, future, **kwargs):
        self._add_to_result_kwargs_cache(future)
        result_kwargs = self._result_kwargs_cache[future.meta.transfer_id]
        queued_result = QueuedResult(**result_kwargs)
        self._result_queue.put(queued_result)

    def on_progress(self, future, bytes_transferred, **kwargs):
        result_kwargs = self._result_kwargs_cache[future.meta.transfer_id]
        progress_result = ProgressResult(
            bytes_transferred=bytes_transferred, timestamp=time.time(),
            **result_kwargs)
        self._result_queue.put(progress_result)

    def _on_success(self, future):
        result_kwargs = self._on_done_pop_from_result_kwargs_cache(future)
        self._result_queue.put(SuccessResult(**result_kwargs))

    def _on_failure(self, future, e):
        result_kwargs = self._on_done_pop_from_result_kwargs_cache(future)
        if isinstance(e, CancelledError):
            error_result_cls = CtrlCResult
            if isinstance(e, FatalError):
                error_result_cls = ErrorResult
            self._result_queue.put(error_result_cls(exception=e))
        else:
            self._result_queue.put(FailureResult(exception=e, **result_kwargs))

    def _add_to_result_kwargs_cache(self, future):
        src, dest = self._get_src_dest(future)
        result_kwargs = {
            'transfer_type': self._transfer_type,
            'src': src,
            'dest': dest,
            'total_transfer_size': future.meta.size
        }
        self._result_kwargs_cache[future.meta.transfer_id] = result_kwargs

    def _on_done_pop_from_result_kwargs_cache(self, future):
        result_kwargs = self._result_kwargs_cache.pop(future.meta.transfer_id)
        result_kwargs.pop('total_transfer_size')
        return result_kwargs

    def _get_src_dest(self, future):
        raise NotImplementedError('_get_src_dest()')


class UploadResultSubscriber(BaseResultSubscriber):
    TRANSFER_TYPE = 'upload'

    def _get_src_dest(self, future):
        call_args = future.meta.call_args
        src = self._get_src(call_args.fileobj)
        dest = 's3://' + call_args.bucket + '/' + call_args.key
        return src, dest

    def _get_src(self, fileobj):
        return relative_path(fileobj)


class UploadStreamResultSubscriber(UploadResultSubscriber):
    def _get_src(self, fileobj):
        return '-'


class DownloadResultSubscriber(BaseResultSubscriber):
    TRANSFER_TYPE = 'download'

    def _get_src_dest(self, future):
        call_args = future.meta.call_args
        src = 's3://' + call_args.bucket + '/' + call_args.key
        dest = self._get_dest(call_args.fileobj)
        return src, dest

    def _get_dest(self, fileobj):
        return relative_path(fileobj)


class DownloadStreamResultSubscriber(DownloadResultSubscriber):
    def _get_dest(self, fileobj):
        return '-'


class CopyResultSubscriber(BaseResultSubscriber):
    TRANSFER_TYPE = 'copy'

    def _get_src_dest(self, future):
        call_args = future.meta.call_args
        copy_source = call_args.copy_source
        src = 's3://' + copy_source['Bucket'] + '/' + copy_source['Key']
        dest = 's3://' + call_args.bucket + '/' + call_args.key
        return src, dest


class DeleteResultSubscriber(BaseResultSubscriber):
    TRANSFER_TYPE = 'delete'

    def _get_src_dest(self, future):
        call_args = future.meta.call_args
        src = 's3://' + call_args.bucket + '/' + call_args.key
        return src, None


class BaseResultHandler(object):
    """Base handler class to be called in the ResultProcessor"""
    def __call__(self, result):
        raise NotImplementedError('__call__()')


class ResultRecorder(BaseResultHandler):
203    """Records and track transfer statistics based on results received"""
    def __init__(self):
        self.bytes_transferred = 0
        self.bytes_failed_to_transfer = 0
        self.files_transferred = 0
        self.files_failed = 0
        self.files_warned = 0
        self.errors = 0
        self.expected_bytes_transferred = 0
        self.expected_files_transferred = 0
        self.final_expected_files_transferred = None

        self.start_time = None
        self.bytes_transfer_speed = 0

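        # Bytes seen so far and total sizes for in-flight transfers, keyed by
        # the string returned from _get_ongoing_dict_key().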
        self._ongoing_progress = defaultdict(int)
        self._ongoing_total_sizes = {}

        self._result_handler_map = {
            QueuedResult: self._record_queued_result,
            ProgressResult: self._record_progress_result,
            SuccessResult: self._record_success_result,
            FailureResult: self._record_failure_result,
            WarningResult: self._record_warning_result,
            ErrorResult: self._record_error_result,
            CtrlCResult: self._record_error_result,
            FinalTotalSubmissionsResult: self._record_final_expected_files,
        }

    def expected_totals_are_final(self):
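        # final_expected_files_transferred stays None until a
        # FinalTotalSubmissionsResult has been recorded, so this only returns
        # True once the final submission count is known and every expected
        # file has been recorded as queued.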
        return (
            self.final_expected_files_transferred ==
            self.expected_files_transferred
        )

    def __call__(self, result):
        """Record the result of an individual Result object"""
        self._result_handler_map.get(type(result), self._record_noop)(
            result=result)

    def _get_ongoing_dict_key(self, result):
        if not isinstance(result, BaseResult):
            raise ValueError(
                'Any result using _get_ongoing_dict_key must subclass from '
                'BaseResult. Provided result is of type: %s' % type(result)
            )
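        # Build a key of the form 'transfer_type:src:dest' (omitting any
        # components that are None) so that progress and total size can be
        # tracked per individual transfer.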
        key_parts = []
        for result_property in [result.transfer_type, result.src, result.dest]:
            if result_property is not None:
                key_parts.append(ensure_text_type(result_property))
        return u':'.join(key_parts)

    def _pop_result_from_ongoing_dicts(self, result):
        ongoing_key = self._get_ongoing_dict_key(result)
        total_progress = self._ongoing_progress.pop(ongoing_key, 0)
        total_file_size = self._ongoing_total_sizes.pop(ongoing_key, None)
        return total_progress, total_file_size

    def _record_noop(self, **kwargs):
        # If the result does not have a handler, then do nothing with it.
        pass

    def _record_queued_result(self, result, **kwargs):
        if self.start_time is None:
            self.start_time = time.time()
        total_transfer_size = result.total_transfer_size
        self._ongoing_total_sizes[
            self._get_ongoing_dict_key(result)] = total_transfer_size
        # The total transfer size can be None if we do not know the size
        # immediately so do not add to the total right away.
        if total_transfer_size:
            self.expected_bytes_transferred += total_transfer_size
        self.expected_files_transferred += 1

    def _record_progress_result(self, result, **kwargs):
        bytes_transferred = result.bytes_transferred
        self._update_ongoing_transfer_size_if_unknown(result)
        self._ongoing_progress[
            self._get_ongoing_dict_key(result)] += bytes_transferred
        self.bytes_transferred += bytes_transferred
        # Since the start time is captured in the result recorder while
        # timestamps are captured in the subscriber, there is a chance that
        # if a progress result gets created right after the queued result,
        # the timestamp on the progress result is earlier than when the
        # result processor actually processes that initial queued result.
        # The check below avoids displaying negative progress or dividing
        # by zero.
        if result.timestamp > self.start_time:
            self.bytes_transfer_speed = self.bytes_transferred / (
                result.timestamp - self.start_time)

    def _update_ongoing_transfer_size_if_unknown(self, result):
        # This is a special case where the transfer size was previously not
        # known but was provided in a progress result.
        ongoing_key = self._get_ongoing_dict_key(result)

        # First, check if the total size is None, meaning its size is
        # currently unknown.
        if self._ongoing_total_sizes[ongoing_key] is None:
            total_transfer_size = result.total_transfer_size
            # If the total size is no longer None, that means we just learned
            # of the size, so update the appropriate places with this
            # knowledge.
            if result.total_transfer_size is not None:
                self._ongoing_total_sizes[ongoing_key] = total_transfer_size
                # Figure out how many bytes have been unaccounted for as
                # the recorder has been keeping track of how many bytes
                # it has seen so far and add it to the total expected amount.
                ongoing_progress = self._ongoing_progress[ongoing_key]
                unaccounted_bytes = total_transfer_size - ongoing_progress
                self.expected_bytes_transferred += unaccounted_bytes
            # If we still do not know what the total transfer size is,
            # just update the expected bytes with the known bytes transferred,
            # as we know at the very least those bytes are expected.
            else:
                self.expected_bytes_transferred += result.bytes_transferred

    def _record_success_result(self, result, **kwargs):
        self._pop_result_from_ongoing_dicts(result)
        self.files_transferred += 1

    def _record_failure_result(self, result, **kwargs):
        # If there was a failure, we want to account for the failure in
        # the count for bytes transferred by just adding on the remaining bytes
        # that did not get transferred.
        total_progress, total_file_size = self._pop_result_from_ongoing_dicts(
            result)
        if total_file_size is not None:
            progress_left = total_file_size - total_progress
            self.bytes_failed_to_transfer += progress_left

        self.files_failed += 1
        self.files_transferred += 1

    def _record_warning_result(self, **kwargs):
        self.files_warned += 1

    def _record_error_result(self, **kwargs):
        self.errors += 1

    def _record_final_expected_files(self, result, **kwargs):
        self.final_expected_files_transferred = result.total_submissions


class ResultPrinter(BaseResultHandler):
    _FILES_REMAINING = "{remaining_files} file(s) remaining"
    _ESTIMATED_EXPECTED_TOTAL = "~{expected_total}"
    _STILL_CALCULATING_TOTALS = " (calculating...)"
    BYTE_PROGRESS_FORMAT = (
        'Completed {bytes_completed}/{expected_bytes_completed} '
        '({transfer_speed}) with ' + _FILES_REMAINING
    )
    FILE_PROGRESS_FORMAT = (
        'Completed {files_completed} file(s) with ' + _FILES_REMAINING
    )
    SUCCESS_FORMAT = (
        u'{transfer_type}: {transfer_location}'
    )
    DRY_RUN_FORMAT = u'(dryrun) ' + SUCCESS_FORMAT
    FAILURE_FORMAT = (
        u'{transfer_type} failed: {transfer_location} {exception}'
    )
    # TODO: Add "warning: " prefix once all commands are converted to using
    # result printer and remove "warning: " prefix from ``create_warning``.
    WARNING_FORMAT = (
        u'{message}'
    )
    ERROR_FORMAT = (
        u'fatal error: {exception}'
    )
    CTRL_C_MSG = 'cancelled: ctrl-c received'

    SRC_DEST_TRANSFER_LOCATION_FORMAT = u'{src} to {dest}'
    SRC_TRANSFER_LOCATION_FORMAT = u'{src}'

    def __init__(self, result_recorder, out_file=None, error_file=None):
        """Prints status of ongoing transfer

        :type result_recorder: ResultRecorder
        :param result_recorder: The associated result recorder

        :type out_file: file-like obj
        :param out_file: Location to write progress and success statements.
            By default, the location is sys.stdout.

        :type error_file: file-like obj
        :param error_file: Location to write warnings and errors.
            By default, the location is sys.stderr.
        """
        self._result_recorder = result_recorder
        self._out_file = out_file
        if self._out_file is None:
            self._out_file = sys.stdout
        self._error_file = error_file
        if self._error_file is None:
            self._error_file = sys.stderr
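        # Length of the most recently printed progress line; statements are
        # padded to this length so that they fully overwrite the progress
        # line.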
        self._progress_length = 0
        self._result_handler_map = {
            ProgressResult: self._print_progress,
            SuccessResult: self._print_success,
            FailureResult: self._print_failure,
            WarningResult: self._print_warning,
            ErrorResult: self._print_error,
            CtrlCResult: self._print_ctrl_c,
            DryRunResult: self._print_dry_run,
            FinalTotalSubmissionsResult:
                self._clear_progress_if_no_more_expected_transfers,
        }

    def __call__(self, result):
        """Print the progress of the ongoing transfer based on a result"""
        self._result_handler_map.get(type(result), self._print_noop)(
            result=result)

    def _print_noop(self, **kwargs):
        # If the result does not have a handler, then do nothing with it.
        pass

    def _print_dry_run(self, result, **kwargs):
        statement = self.DRY_RUN_FORMAT.format(
            transfer_type=result.transfer_type,
            transfer_location=self._get_transfer_location(result)
        )
        statement = self._adjust_statement_padding(statement)
        self._print_to_out_file(statement)

    def _print_success(self, result, **kwargs):
        success_statement = self.SUCCESS_FORMAT.format(
            transfer_type=result.transfer_type,
            transfer_location=self._get_transfer_location(result)
        )
        success_statement = self._adjust_statement_padding(success_statement)
        self._print_to_out_file(success_statement)
        self._redisplay_progress()

    def _print_failure(self, result, **kwargs):
        failure_statement = self.FAILURE_FORMAT.format(
            transfer_type=result.transfer_type,
            transfer_location=self._get_transfer_location(result),
            exception=result.exception
        )
        failure_statement = self._adjust_statement_padding(failure_statement)
        self._print_to_error_file(failure_statement)
        self._redisplay_progress()

    def _print_warning(self, result, **kwargs):
        warning_statement = self.WARNING_FORMAT.format(message=result.message)
        warning_statement = self._adjust_statement_padding(warning_statement)
        self._print_to_error_file(warning_statement)
        self._redisplay_progress()

    def _print_error(self, result, **kwargs):
        self._flush_error_statement(
            self.ERROR_FORMAT.format(exception=result.exception))

    def _print_ctrl_c(self, result, **kwargs):
        self._flush_error_statement(self.CTRL_C_MSG)

    def _flush_error_statement(self, error_statement):
        error_statement = self._adjust_statement_padding(error_statement)
        self._print_to_error_file(error_statement)

    def _get_transfer_location(self, result):
        if result.dest is None:
            return self.SRC_TRANSFER_LOCATION_FORMAT.format(src=result.src)
        return self.SRC_DEST_TRANSFER_LOCATION_FORMAT.format(
            src=result.src, dest=result.dest)

    def _redisplay_progress(self):
        # Reset the progress length to zero because done statements are
        # printed with newlines, so there is no carriage return to take into
        # account when printing the next line.
        self._progress_length = 0
        self._add_progress_if_needed()

    def _add_progress_if_needed(self):
        if self._has_remaining_progress():
            self._print_progress()

    def _print_progress(self, **kwargs):
        # Get all of the statistics in the correct form.
        remaining_files = self._get_expected_total(
            str(self._result_recorder.expected_files_transferred -
                self._result_recorder.files_transferred)
        )

        # Create the display statement.
        if self._result_recorder.expected_bytes_transferred > 0:
            bytes_completed = human_readable_size(
                self._result_recorder.bytes_transferred +
                self._result_recorder.bytes_failed_to_transfer
            )
            expected_bytes_completed = self._get_expected_total(
                human_readable_size(
                    self._result_recorder.expected_bytes_transferred))

            transfer_speed = human_readable_size(
                self._result_recorder.bytes_transfer_speed) + '/s'
            progress_statement = self.BYTE_PROGRESS_FORMAT.format(
                bytes_completed=bytes_completed,
                expected_bytes_completed=expected_bytes_completed,
                transfer_speed=transfer_speed,
                remaining_files=remaining_files
            )
        else:
            # We're not expecting any bytes to be transferred, so we should
            # only print information about the number of files transferred.
            progress_statement = self.FILE_PROGRESS_FORMAT.format(
                files_completed=self._result_recorder.files_transferred,
                remaining_files=remaining_files
            )

        if not self._result_recorder.expected_totals_are_final():
            progress_statement += self._STILL_CALCULATING_TOTALS

        # Make sure that it overrides any previous progress bar.
        progress_statement = self._adjust_statement_padding(
                progress_statement, ending_char='\r')
        # We do not want to include the carriage return in this calculation
        # as progress length is used for determining whitespace padding.
        # So we subtract one off of the length.
        self._progress_length = len(progress_statement) - 1

        # Print the progress out.
        self._print_to_out_file(progress_statement)

    def _get_expected_total(self, expected_total):
        if not self._result_recorder.expected_totals_are_final():
            return self._ESTIMATED_EXPECTED_TOTAL.format(
                expected_total=expected_total)
        return expected_total

    def _adjust_statement_padding(self, print_statement, ending_char='\n'):
        print_statement = print_statement.ljust(self._progress_length, ' ')
        return print_statement + ending_char

    def _has_remaining_progress(self):
        if not self._result_recorder.expected_totals_are_final():
            return True
        actual = self._result_recorder.files_transferred
        expected = self._result_recorder.expected_files_transferred
        return actual != expected

    def _print_to_out_file(self, statement):
        uni_print(statement, self._out_file)

    def _print_to_error_file(self, statement):
        uni_print(statement, self._error_file)

    def _clear_progress_if_no_more_expected_transfers(self, **kwargs):
        if self._progress_length and not self._has_remaining_progress():
            uni_print(self._adjust_statement_padding(''), self._out_file)


class NoProgressResultPrinter(ResultPrinter):
    """A result printer that doesn't print progress"""
    def _print_progress(self, **kwargs):
        pass


class OnlyShowErrorsResultPrinter(ResultPrinter):
    """A result printer that only prints out errors"""
    def _print_progress(self, **kwargs):
        pass

    def _print_success(self, result, **kwargs):
        pass


class ResultProcessor(threading.Thread):
    def __init__(self, result_queue, result_handlers=None):
        """Thread to process results from result queue

        This includes recording statistics and printing transfer status

        :param result_queue: The result queue to process results from
        :param result_handlers: A list of callables that take a result in as
            a parameter to process the result for that handler.
        """
        threading.Thread.__init__(self)
        self._result_queue = result_queue
        self._result_handlers = result_handlers
        if self._result_handlers is None:
            self._result_handlers = []
        self._result_handlers_enabled = True

    def run(self):
        while True:
            try:
                result = self._result_queue.get(True)
                if isinstance(result, ShutdownThreadRequest):
                    LOGGER.debug(
                        'Shutdown request received in result processing '
                        'thread, shutting down result thread.')
                    break
                if self._result_handlers_enabled:
                    self._process_result(result)
                # ErrorResults are fatal to the command. If a fatal error
                # is seen, we know that the command is trying to shut down,
                # so disable all of the handlers and quickly consume all
                # of the results in the result queue in order to get to
                # the shutdown request and clean up the process.
                if isinstance(result, ErrorResult):
                    self._result_handlers_enabled = False
            except queue.Empty:
                pass

    def _process_result(self, result):
        for result_handler in self._result_handlers:
            try:
                result_handler(result)
            except Exception as e:
                LOGGER.debug(
                    'Error processing result %s with handler %s: %s',
                    result, result_handler, e, exc_info=True)


class CommandResultRecorder(object):
    def __init__(self, result_queue, result_recorder, result_processor):
        """Records the result for an entire command

        It will fully process all results in a result queue and determine
        a CommandResult representing the entire command.

        :type result_queue: queue.Queue
        :param result_queue: The result queue that results are placed on
            and processed from

        :type result_recorder: ResultRecorder
        :param result_recorder: The result recorder to track the various
            results sent through the result queue

        :type result_processor: ResultProcessor
        :param result_processor: The result processor to process results
            placed on the queue
        """
        self.result_queue = result_queue
        self._result_recorder = result_recorder
        self._result_processor = result_processor

    def start(self):
        self._result_processor.start()

    def shutdown(self):
        self.result_queue.put(ShutdownThreadRequest())
        self._result_processor.join()

    def get_command_result(self):
        """Get the CommandResult representing the result of a command

        :rtype: CommandResult
        :returns: The CommandResult representing the total result from running
            a particular command
        """
        return CommandResult(
            self._result_recorder.files_failed + self._result_recorder.errors,
            self._result_recorder.files_warned
        )

    def notify_total_submissions(self, total):
        self.result_queue.put(FinalTotalSubmissionsResult(total))

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, *args):
        if exc_type:
            LOGGER.debug('Exception caught during command execution: %s',
                         exc_value, exc_info=True)
            self.result_queue.put(ErrorResult(exception=exc_value))
            self.shutdown()
            return True
        self.shutdown()