1######################################################################
2#
3# File: b2sdk/transfer/inbound/downloaded_file.py
4#
5# Copyright 2021 Backblaze Inc. All Rights Reserved.
6#
7# License https://www.backblaze.com/using_b2_code.html
8#
9######################################################################
10
11import io
12import logging
13from typing import Optional, Tuple, TYPE_CHECKING
14
15from requests.models import Response
16
17from ...encryption.setting import EncryptionSetting
18from ...file_version import DownloadVersion
19from ...progress import AbstractProgressListener
20from ...stream.progress import WritingStreamWithProgress
21
22from b2sdk.exception import (
23    ChecksumMismatch,
24    TruncatedOutput,
25)
26from b2sdk.utils import set_file_mtime
27
28if TYPE_CHECKING:
29    from .download_manager import DownloadManager
30
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
32
33
class MtimeUpdatedFile(io.IOBase):
    """
    Helper class that facilitates updating a file's mod_time after closing.

    The local file is opened in ``__enter__`` and closed in ``__exit__``; right after it is
    closed, its modification time is set to ``mod_time_millis`` with ``set_file_mtime``.
    ``__exit__`` does this unconditionally, i.e. also when the ``with`` block is left by an
    exception.

    Usage:

    .. code-block:: python

       downloaded_file = bucket.download_file_by_id('b2_file_id')
       with MtimeUpdatedFile('some_local_path', mod_time_millis=downloaded_file.download_version.mod_time_millis) as file:
           downloaded_file.save(file)
       #  'some_local_path' has the mod_time set according to metadata in B2
    """

    def __init__(self, path_, mod_time_millis: int, mode='wb+'):
        """
        :param path_: path of the local file to open
        :param mod_time_millis: modification time (milliseconds) applied to the file once it is closed
        :param mode: mode passed to :func:`open` when entering the context
        """
        self.path_ = path_
        self.mode = mode
        self.mod_time_to_set = mod_time_millis
        # underlying file object; stays None until __enter__ opens it
        self.file = None

    def write(self, value):
        """
        This method is overwritten (monkey-patched) in __enter__ for performance reasons
        """
        raise NotImplementedError

    def read(self, *a):
        """
        This method is overwritten (monkey-patched) in __enter__ for performance reasons
        """
        raise NotImplementedError

    def seekable(self):
        """
        Return whether the underlying file supports random access.

        :class:`io.IOBase` answers ``False`` by default, which would advertise this wrapper as
        non-seekable even though :meth:`seek` and :meth:`tell` work, so the question is
        delegated to the real file. Before ``__enter__`` there is no file yet and ``False``
        is returned.
        """
        return self.file is not None and self.file.seekable()

    def readable(self):
        """Return whether the underlying file can be read (``False`` before ``__enter__``)."""
        return self.file is not None and self.file.readable()

    def writable(self):
        """Return whether the underlying file can be written (``False`` before ``__enter__``)."""
        return self.file is not None and self.file.writable()

    def seek(self, offset, whence=0):
        return self.file.seek(offset, whence)

    def tell(self):
        return self.file.tell()

    def __enter__(self):
        self.file = open(self.path_, self.mode)
        # bind the file's own methods directly on the instance so every write/read call
        # avoids an extra Python-level delegation hop
        self.write = self.file.write
        self.read = self.file.read
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()
        # presumably ordered after close() so flushing buffered data cannot bump the mtime again
        set_file_mtime(self.path_, self.mod_time_to_set)
80
81
class DownloadedFile:
    """
    Outcome of a successfully initialized download.

    Exposes the downloaded file's metadata through ``download_version`` and performs the
    actual data transfer when :meth:`save` or :meth:`save_to` is called.
    """

    def __init__(
        self,
        download_version: DownloadVersion,
        download_manager: 'DownloadManager',
        range_: Optional[Tuple[int, int]],
        response: Response,
        encryption: Optional[EncryptionSetting],
        progress_listener: AbstractProgressListener,
    ):
        """
        :param download_version: metadata of the file being downloaded
        :param download_manager: supplies the download strategies and the session used by them
        :param range_: inclusive ``(first_byte, last_byte)`` pair for a ranged download,
                       or ``None`` for the whole file
        :param response: response handed to the chosen download strategy
        :param encryption: encryption setting handed to the chosen download strategy, if any
        :param progress_listener: listener to report progress to; a falsy value disables reporting
        """
        self.download_version = download_version
        self.download_manager = download_manager
        self.range_ = range_
        self.response = response
        self.encryption = encryption
        self.progress_listener = progress_listener
        # set by save() to the strategy that actually performed the download
        self.download_strategy = None

    def _expected_length(self):
        """
        Return how many bytes the download is supposed to produce.

        For a ranged download this is the size of the inclusive range, otherwise the
        ``content_length`` of ``download_version``.
        """
        if self.range_ is None:
            return self.download_version.content_length
        first_byte, last_byte = self.range_[0], self.range_[1]
        return last_byte - first_byte + 1

    def _validate_download(self, bytes_read, actual_sha1):
        """
        Verify the outcome reported by a download strategy.

        :param bytes_read: number of bytes the strategy reported as downloaded
        :param actual_sha1: SHA1 digest the strategy reported for the downloaded data
        :raises TruncatedOutput: if ``bytes_read`` differs from the expected length
        :raises ChecksumMismatch: for a whole-file download whose digest differs from
                                  ``download_version.content_sha1`` (not checked when
                                  that value is ``'none'``)
        """
        expected_length = self._expected_length()
        if bytes_read != expected_length:
            raise TruncatedOutput(bytes_read, expected_length)

        if self.range_ is not None:
            # only part of the file was fetched, so the whole-file SHA1 cannot be compared
            return

        expected_sha1 = self.download_version.content_sha1
        if expected_sha1 == 'none':
            return
        if actual_sha1 != expected_sha1:
            raise ChecksumMismatch(
                checksum_type='sha1',
                expected=expected_sha1,
                actual=actual_sha1,
            )

    def save(self, file, allow_seeking=True):
        """
        Fetch the data from B2 cloud and write it into a file-like object.

        :param file: a file-like object to write into
        :param allow_seeking: if False, strategies that have to seek in ``file`` to write data
                              (parallel strategies) are not considered.
        :raises ValueError: if none of the download manager's strategies is suitable
        :raises TruncatedOutput: if the number of downloaded bytes is not the expected one
        :raises ChecksumMismatch: if the SHA1 of a whole-file download does not match
        """
        listener = self.progress_listener
        if listener:
            # wrap the target stream together with the progress listener
            file = WritingStreamWithProgress(file, listener)
            listener.set_total_bytes(self._expected_length())

        chosen = next(
            (
                candidate for candidate in self.download_manager.strategies
                if candidate.is_suitable(self.download_version, allow_seeking)
            ),
            None,
        )
        if chosen is None:
            raise ValueError('no strategy suitable for download was found!')
        self.download_strategy = chosen

        bytes_read, actual_sha1 = chosen.download(
            file,
            response=self.response,
            download_version=self.download_version,
            session=self.download_manager.services.session,
            encryption=self.encryption,
        )
        self._validate_download(bytes_read, actual_sha1)

    def save_to(self, path_, mode='wb+', allow_seeking=True):
        """
        Open a local file, download the data into it and update its mod_time.

        :param path_: path of the local file to open
        :param mode: mode in which the file is opened
        :param allow_seeking: if False, strategies that have to seek in the file to write data
                              (parallel strategies) are not considered.
        """
        mod_time_millis = self.download_version.mod_time_millis
        with MtimeUpdatedFile(path_, mod_time_millis=mod_time_millis, mode=mode) as target:
            self.save(target, allow_seeking=allow_seeking)
164