__author__ = "Johannes Köster"
__copyright__ = "Copyright 2021, Johannes Köster"
__email__ = "johannes.koester@tu-dortmund.de"
__license__ = "MIT"

import base64
import os
import re
import struct
import time

from snakemake.remote import AbstractRemoteObject, AbstractRemoteProvider
from snakemake.exceptions import WorkflowError, CheckSumMismatchException
from snakemake.common import lazy_property
import snakemake.io
from snakemake.utils import os_sync

try:
    import google.cloud
    from google.cloud import storage
    from google.api_core import retry
    from google_crc32c import Checksum
except ImportError as e:
    raise WorkflowError(
        "The Python 3 packages 'google-cloud-storage' and 'google-crc32c' "
        "need to be installed to use GS remote() file functionality. %s" % e.msg
    )


def google_cloud_retry_predicate(ex):
    """Given an exception from Google Cloud, determine whether it is in the
    listing of transient errors (as determined by
    google.api_core.retry.if_transient_error(exception)) or was triggered
    by a hash mismatch due to a bad download. The function returns a boolean
    indicating whether a retry should be attempted, and is typically used
    as the predicate for the google.api_core.retry.Retry decorator.

    Arguments:
      ex (Exception) : the exception passed from the decorated function
    Returns: boolean indicating whether to retry (True) or not (False)
    """
    from requests.exceptions import ReadTimeout

    # Most likely case is Google API transient error.
    if retry.if_transient_error(ex):
        return True
    # Timeouts should be considered for retry as well.
    if isinstance(ex, ReadTimeout):
        return True
    # Could also be checksum mismatch of download.
    if isinstance(ex, CheckSumMismatchException):
        return True
    return False


@retry.Retry(predicate=google_cloud_retry_predicate)
def download_blob(blob, filename):
    """A helper function to download a storage Blob to a local file (filename)
    and validate it using the Crc32cCalculator.

    Arguments:
      blob (storage.Blob) : the Google storage blob object
      filename (str)      : the file path to download to
    Returns: the filename, after the download has been checksum-validated
    """

    # create parent directories if necessary
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # ideally we could calculate hash while streaming to file with provided function
    # https://github.com/googleapis/python-storage/issues/29
    with open(filename, "wb") as blob_file:
        parser = Crc32cCalculator(blob_file)
        blob.download_to_file(parser)
    # use the portable os_sync wrapper (os.sync is not available on all platforms)
    os_sync()

    # **Important** hash can be incorrect or missing if not refreshed
    blob.reload()

    # Compute local hash and verify correct
    if parser.hexdigest() != blob.crc32c:
        os.remove(filename)
        raise CheckSumMismatchException("The checksum of %s does not match." % filename)
    return filename

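# A minimal usage sketch for download_blob (the bucket and object names below
# are placeholders, not part of this module):
#
#     client = storage.Client()
#     blob = client.bucket("my-bucket").blob("path/to/object.txt")
#     download_blob(blob, "local/path/to/object.txt")
#
# The call returns the local filename after a successful checksum check, or
# raises CheckSumMismatchException, which google_cloud_retry_predicate treats
# as retryable.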

class Crc32cCalculator:
    """The Google Python client doesn't provide a way to stream a file being
    written, so we wrap the file object in an additional class that updates
    a CRC32C checksum with every chunk written. This way we don't need to
    download the file and then read it again just to calculate the hash.
    """

    def __init__(self, fileobj):
        self._fileobj = fileobj
        self.checksum = Checksum()

    def write(self, chunk):
        self._fileobj.write(chunk)
        self._update(chunk)

    def _update(self, chunk):
        """Update the running checksum with a chunk of downloaded data."""
        self.checksum.update(chunk)

    def hexdigest(self):
        """Return the Base64-encoded CRC32c digest of the data written so far,
        in big-endian byte order, matching the format of blob.crc32c.
        See https://cloud.google.com/storage/docs/hashes-etags
        """
        return base64.b64encode(self.checksum.digest()).decode("utf-8")

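# A minimal sketch of how Crc32cCalculator is used (the local path is a
# placeholder): the wrapper is passed to download_to_file in place of the raw
# file object, so the checksum is computed while the blob streams to disk.
#
#     with open("local/path", "wb") as fh:
#         calculator = Crc32cCalculator(fh)
#         blob.download_to_file(calculator)
#     assert calculator.hexdigest() == blob.crc32c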

class RemoteProvider(AbstractRemoteProvider):

    supports_default = True

    def __init__(
        self, *args, keep_local=False, stay_on_remote=False, is_default=False, **kwargs
    ):
        super(RemoteProvider, self).__init__(
            *args,
            keep_local=keep_local,
            stay_on_remote=stay_on_remote,
            is_default=is_default,
            **kwargs
        )

        self.client = storage.Client(*args, **kwargs)

    def remote_interface(self):
        return self.client

    @property
    def default_protocol(self):
        """The protocol that is prepended to the path when no protocol is specified."""
        return "gs://"

    @property
    def available_protocols(self):
        """List of valid protocols for this remote provider."""
        return ["gs://"]

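# A minimal Snakefile sketch using this provider (the bucket and file names
# are placeholders), assuming this module is importable as snakemake.remote.GS:
#
#     from snakemake.remote.GS import RemoteProvider as GSRemoteProvider
#
#     GS = GSRemoteProvider()
#
#     rule all:
#         input:
#             GS.remote("my-bucket/path/to/input.txt")
#
# Credentials are resolved by storage.Client(), e.g. via gcloud application
# default credentials, unless explicit arguments are passed to the provider.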

class RemoteObject(AbstractRemoteObject):
    def __init__(
        self, *args, keep_local=False, provider=None, user_project=None, **kwargs
    ):
        super(RemoteObject, self).__init__(
            *args, keep_local=keep_local, provider=provider, **kwargs
        )

        if provider:
            self.client = provider.remote_interface()
        else:
            self.client = storage.Client(*args, **kwargs)

        # keep user_project available for when bucket is initialized
        self._user_project = user_project

        self._key = None
        self._bucket_name = None
        self._bucket = None
        self._blob = None

    async def inventory(self, cache: snakemake.io.IOCache):
        """Using client.list_blobs(), iterate over the objects under the
        "folder" (prefix) of a bucket and store information about the IOFiles
        in the provided cache (snakemake.io.IOCache), indexed by bucket/blob
        name. This is triggered by the first mention of a remote object under
        a given prefix, which then only needs to be listed once.
        This includes:
         - cache.exists_remote
         - cache.mtime
         - cache.size
        """
        if cache.remaining_wait_time <= 0:
            # No more time to create inventory.
            return

        start_time = time.time()
        subfolder = os.path.dirname(self.blob.name)
        for blob in self.client.list_blobs(self.bucket_name, prefix=subfolder):
            # By way of being listed, it exists. mtime is a datetime object
            name = "{}/{}".format(blob.bucket.name, blob.name)
            cache.exists_remote[name] = True
            cache.mtime[name] = blob.updated.timestamp()
            cache.size[name] = blob.size

        cache.remaining_wait_time -= time.time() - start_time

        # Mark bucket and prefix as having an inventory, such that this method is
        # only called once for the subfolder in the bucket.
        cache.exists_remote.has_inventory.add("%s/%s" % (self.bucket_name, subfolder))

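    # For illustration (bucket and object names are placeholders): after
    # inventory() runs for gs://my-bucket/data/sample.txt, the cache holds
    # entries keyed as "my-bucket/data/sample.txt" in cache.exists_remote,
    # cache.mtime, and cache.size, and "my-bucket/data" is added to
    # cache.exists_remote.has_inventory so the prefix is not listed again.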
    # === Implementations of abstract class members ===

    def get_inventory_parent(self):
        return self.bucket_name

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def exists(self):
        return self.blob.exists()

    def mtime(self):
        if self.exists():
            self.update_blob()
            t = self.blob.updated
            return t.timestamp()
        else:
            raise WorkflowError(
                "The file does not seem to exist remotely: %s" % self.local_file()
            )

    def size(self):
        if self.exists():
            self.update_blob()
            return self.blob.size // 1024
        else:
            return self._iofile.size_local

    @retry.Retry(predicate=google_cloud_retry_predicate, deadline=600)
    def download(self):
        """Download with maximum retry duration of 600 seconds (10 minutes)"""
        if not self.exists():
            return None

        # The remote object may represent a directory or a single file
        if snakemake.io.is_flagged(self.local_file(), "directory"):
            return self._download_directory()
        return download_blob(self.blob, self.local_file())

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def _download_directory(self):
        """A 'private' function to handle download of a storage "folder",
        including all of the content found inside it.
        """
        # Create the directory locally
        os.makedirs(self.local_file(), exist_ok=True)

        for blob in self.client.list_blobs(self.bucket_name, prefix=self.key):
            local_name = "{}/{}".format(blob.bucket.name, blob.name)

            # Skip "directory blobs" that already exist locally as directories
            if os.path.exists(local_name) and os.path.isdir(local_name):
                continue

            download_blob(blob, local_name)

        # Return the root directory
        return self.local_file()

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def upload(self):
        try:
            if not self.bucket.exists():
                self.bucket.create()
                self.update_blob()

            # Distinguish between a single file and a folder
            f = self.local_file()
            if os.path.isdir(f):
                # Ensure the "directory" exists
                self.blob.upload_from_string(
                    "", content_type="application/x-www-form-urlencoded;charset=UTF-8"
                )
                for root, _, files in os.walk(f):
                    for filename in files:
                        filename = os.path.join(root, filename)
                        # strip the bucket name prefix from the local path
                        # (str.lstrip removes characters, not a prefix)
                        bucket_path = filename[len(self.bucket.name):].lstrip("/")
                        blob = self.bucket.blob(bucket_path)
                        blob.upload_from_filename(filename)
            else:
                self.blob.upload_from_filename(f)
        except google.cloud.exceptions.Forbidden as e:
            raise WorkflowError(
                e,
                "When running locally, make sure that you are authenticated "
                "via gcloud (see Snakemake documentation). When running in a "
                "kubernetes cluster, make sure that storage-rw is added to "
                "--scopes (see Snakemake documentation).",
            )

    @property
    def name(self):
        return self.key

    @property
    def list(self):
        return [k.name for k in self.bucket.list_blobs()]

    # ========= Helpers ===============

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def update_blob(self):
        self._blob = self.bucket.get_blob(self.key)

    @lazy_property
    def bucket(self):
        return self.client.bucket(self.bucket_name, user_project=self._user_project)

    @lazy_property
    def blob(self):
        return self.bucket.blob(self.key)

    @lazy_property
    def bucket_name(self):
        return self.parse().group("bucket")

    @property
    def key(self):
        key = self.parse().group("key")
        f = self.local_file()
        if snakemake.io.is_flagged(f, "directory"):
            key = key if f.endswith("/") else key + "/"
        return key

    def parse(self):
        m = re.search("(?P<bucket>[^/]*)/(?P<key>.*)", self.local_file())
        if m is None:
            raise WorkflowError(
                "GS remote file {} does not have the form "
                "<bucket>/<key>.".format(self.local_file())
            )
        return m
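
    # For illustration (the path is a placeholder): parsing
    # "my-bucket/path/to/file.txt" yields group("bucket") == "my-bucket" and
    # group("key") == "path/to/file.txt"; the key property appends a trailing
    # "/" when the file is flagged as a directory.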