1######################################################################
2#
3# File: b2/download_dest.py
4#
5# Copyright 2018 Backblaze Inc. All Rights Reserved.
6#
7# License https://www.backblaze.com/using_b2_code.html
8#
9######################################################################
10
11import os
12from abc import abstractmethod
13from contextlib import contextmanager
14
15import six
16
17from .utils import B2TraceMetaAbstract, limit_trace_arguments
18from .progress import WritingStreamWithProgress
19
20
@six.add_metaclass(B2TraceMetaAbstract)
class AbstractDownloadDestination(object):
    """
    Interface to a destination for a downloaded file.

    Implementations provide make_file_context(), which hands out a context
    manager that yields a binary file-like object the downloader writes into.
    """

    @abstractmethod
    @limit_trace_arguments(skip=[
        'content_sha1',
    ])
    def make_file_context(
        self,
        file_id,
        file_name,
        content_length,
        content_type,
        content_sha1,
        file_info,
        mod_time_millis,
        range_=None
    ):
        """
        Returns a context manager that yields a binary file-like object to use for
        writing the contents of the file.

        :param file_id: the B2 file ID from the headers
        :param file_name: the B2 file name from the headers
        :param content_length: the content length from the headers
        :param content_type: the content type from the headers
        :param content_sha1: the content sha1 from the headers (or "none" for large files)
        :param file_info: the user file info from the headers
        :param mod_time_millis: the desired file modification date in ms since 1970-01-01
        :param range_: (start, end) offsets of the received file contents, both
                       inclusive. Usually None, which means that the whole file
                       is downloaded.
        :return: a context manager whose ``__enter__`` yields a writable binary
                 file-like object
        """
56
57
class DownloadDestLocalFile(AbstractDownloadDestination):
    """
    Stores a downloaded file into a local file and sets its modification time.
    """
    # Mode passed to open(); 'wb+' creates the file or truncates an existing one.
    MODE = 'wb+'

    def __init__(self, local_file_path):
        """
        :param local_file_path: path of the local file the download is written to
        """
        self.local_file_path = local_file_path

    def make_file_context(
        self,
        file_id,
        file_name,
        content_length,
        content_type,
        content_sha1,
        file_info,
        mod_time_millis,
        range_=None
    ):
        """
        Records the download metadata on this object and returns a context
        manager that yields the opened local file.

        :param file_id: the B2 file ID from the headers
        :param file_name: the B2 file name from the headers
        :param content_length: the content length from the headers
        :param content_type: the content type from the headers
        :param content_sha1: the content sha1 from the headers
        :param file_info: the user file info from the headers
        :param mod_time_millis: modification time to set on the file, in ms since 1970-01-01
        :param range_: (start, end) offsets of the received contents, or None
        :return: a context manager yielding the opened local file
        """
        self.file_id = file_id
        self.file_name = file_name
        self.content_length = content_length
        self.content_type = content_type
        self.content_sha1 = content_sha1
        self.file_info = file_info
        # Stored for consistency with the other destinations that expose it.
        self.mod_time_millis = mod_time_millis
        self.range_ = range_
        return self.write_to_local_file_context(mod_time_millis)

    @contextmanager
    def write_to_local_file_context(self, mod_time_millis):
        """
        Opens the local file with ``self.MODE`` and yields it to the caller.

        After the caller's block completes and the file is closed, the file's
        access and modification times are set from ``mod_time_millis``.
        If anything fails before that point, the partially written file is
        removed on a best-effort basis: removal errors are suppressed so they
        never mask the original exception, and os.devnull is never removed.

        :param mod_time_millis: modification time in ms since 1970-01-01
        """
        completed = False
        try:
            # Open the file and let the caller write it.
            with open(self.local_file_path, self.MODE) as f:
                yield f

            # After it's closed, set the mod time.
            # os.utime on os.devnull can fail; devnull is used as a target by
            # tests, so it is skipped.
            if self.local_file_path != os.devnull:
                mod_time = mod_time_millis / 1000.0
                os.utime(self.local_file_path, (mod_time, mod_time))

            # Set the flag that means to leave the downloaded file on disk.
            completed = True

        finally:
            # This is a best-effort attempt to clean up files that
            # failed to download, so we don't leave partial files
            # sitting on disk.
            if not completed and self.local_file_path != os.devnull:
                try:
                    os.unlink(self.local_file_path)
                except OSError:
                    # The file may never have been created (e.g. open() failed)
                    # or may not be removable; the exception already propagating
                    # is the one the caller needs to see.
                    pass
111
112
class PreSeekedDownloadDest(DownloadDestLocalFile):
    """
    Local-file destination that writes into an existing file at a fixed offset.

    The target is opened in 'rb+' mode, so it is not truncated, and the file
    position is moved to ``seek_target`` right after opening, before the
    file object is handed to the caller.
    """
    MODE = 'rb+'

    def __init__(self, local_file_path, seek_target):
        """
        :param local_file_path: path of an existing local file to write into
        :param seek_target: offset to seek to before the caller starts writing
        """
        super(PreSeekedDownloadDest, self).__init__(local_file_path)
        self._seek_target = seek_target

    @contextmanager
    def write_to_local_file_context(self, *args, **kwargs):
        """
        Wraps the parent context: opens the file, seeks, then yields it.
        All arguments are forwarded unchanged to the parent implementation.
        """
        parent_context = super(PreSeekedDownloadDest, self).write_to_local_file_context(
            *args, **kwargs
        )
        with parent_context as opened_file:
            # Position the write cursor before the caller writes anything.
            opened_file.seek(self._seek_target)
            yield opened_file
130
131
class DownloadDestBytes(AbstractDownloadDestination):
    """
    Stores a downloaded file into bytes in memory.
    """

    def __init__(self):
        # None until a download has been fully captured.
        self.bytes_written = None

    def make_file_context(
        self,
        file_id,
        file_name,
        content_length,
        content_type,
        content_sha1,
        file_info,
        mod_time_millis,
        range_=None
    ):
        """
        Records the download metadata on this object and returns a context
        manager that yields an in-memory binary buffer.

        :param file_id: the B2 file ID from the headers
        :param file_name: the B2 file name from the headers
        :param content_length: the content length from the headers
        :param content_type: the content type from the headers
        :param content_sha1: the content sha1 from the headers
        :param file_info: the user file info from the headers
        :param mod_time_millis: the file modification date in ms since 1970-01-01
        :param range_: (start, end) offsets of the received contents, or None
        :return: a context manager yielding a writable in-memory buffer
        """
        self.file_id = file_id
        self.file_name = file_name
        self.content_length = content_length
        self.content_type = content_type
        self.content_sha1 = content_sha1
        self.file_info = file_info
        self.mod_time_millis = mod_time_millis
        self.range_ = range_
        return self.capture_bytes_context()

    @contextmanager
    def capture_bytes_context(self):
        """
        Remembers the bytes written in self.bytes_written.

        self.bytes_written is reset to None on entry, so a failed download
        does not leave the data of a previous download in place. The buffer
        is always closed, even if the caller's block raises.
        """
        self.bytes_written = None

        # Make a place to store the data written
        bytes_io = six.BytesIO()
        try:
            # Let the caller write it
            yield bytes_io

            # Capture the result.  The BytesIO object won't let you grab
            # the data after it's closed
            self.bytes_written = bytes_io.getvalue()
        finally:
            bytes_io.close()

    def get_bytes_written(self):
        """
        :return: the bytes captured by the last completed download
        :raises Exception: if no download has completed yet
        """
        if self.bytes_written is None:
            raise Exception('data not written yet')
        return self.bytes_written
181
182
class DownloadDestProgressWrapper(AbstractDownloadDestination):
    """
    Decorates another download destination so that the file object it yields
    is wrapped in WritingStreamWithProgress, reporting to a progress listener.
    """

    def __init__(self, download_dest, progress_listener):
        """
        :param download_dest: the destination that actually receives the data
        :param progress_listener: listener given the expected total size and
                                  held open (as a context manager) while writing
        """
        self.progress_listener = progress_listener
        self.download_dest = download_dest

    def make_file_context(
        self,
        file_id,
        file_name,
        content_length,
        content_type,
        content_sha1,
        file_info,
        mod_time_millis,
        range_=None
    ):
        """
        Delegates to write_file_and_report_progress_context with the same
        arguments, in the same order.
        """
        progress_context = self.write_file_and_report_progress_context(
            file_id,
            file_name,
            content_length,
            content_type,
            content_sha1,
            file_info,
            mod_time_millis,
            range_,
        )
        return progress_context

    @contextmanager
    def write_file_and_report_progress_context(
        self, file_id, file_name, content_length, content_type, content_sha1, file_info,
        mod_time_millis, range_
    ):
        """
        Opens the wrapped destination, tells the listener how many bytes to
        expect and yields the destination file wrapped for progress reporting.

        The expected size is content_length for a whole-file download, or the
        length of the inclusive range (end - start + 1) when range_ is given.
        """
        inner_context = self.download_dest.make_file_context(
            file_id, file_name, content_length, content_type, content_sha1, file_info,
            mod_time_millis, range_
        )
        with inner_context as target_file:
            if range_ is None:
                expected_size = content_length
            else:
                first_byte, last_byte = range_[0], range_[1]
                expected_size = last_byte - first_byte + 1
            self.progress_listener.set_total_bytes(expected_size)
            with self.progress_listener:
                yield WritingStreamWithProgress(target_file, self.progress_listener)
223