1__author__ = "Johannes Köster" 2__copyright__ = "Copyright 2021, Johannes Köster" 3__email__ = "johannes.koester@tu-dortmund.de" 4__license__ = "MIT" 5 6import base64 7import os 8import re 9import struct 10import time 11 12from snakemake.remote import AbstractRemoteObject, AbstractRemoteProvider 13from snakemake.exceptions import WorkflowError, CheckSumMismatchException 14from snakemake.common import lazy_property 15import snakemake.io 16from snakemake.utils import os_sync 17 18try: 19 import google.cloud 20 from google.cloud import storage 21 from google.api_core import retry 22 from google_crc32c import Checksum 23except ImportError as e: 24 raise WorkflowError( 25 "The Python 3 packages 'google-cloud-sdk' and `google-crc32c` " 26 "need to be installed to use GS remote() file functionality. %s" % e.msg 27 ) 28 29 30def google_cloud_retry_predicate(ex): 31 """Given an exception from Google Cloud, determine if it's one in the 32 listing of transient errors (determined by function 33 google.api_core.retry.if_transient_error(exception)) or determine if 34 triggered by a hash mismatch due to a bad download. This function will 35 return a boolean to indicate if retry should be done, and is typically 36 used with the google.api_core.retry.Retry as a decorator (predicate). 37 38 Arguments: 39 ex (Exception) : the exception passed from the decorated function 40 Returns: boolean to indicate doing retry (True) or not (False) 41 """ 42 from requests.exceptions import ReadTimeout 43 44 # Most likely case is Google API transient error. 45 if retry.if_transient_error(ex): 46 return True 47 # Timeouts should be considered for retry as well. 48 if isinstance(ex, ReadTimeout): 49 return True 50 # Could also be checksum mismatch of download. 51 if isinstance(ex, CheckSumMismatchException): 52 return True 53 return False 54 55 56@retry.Retry(predicate=google_cloud_retry_predicate) 57def download_blob(blob, filename): 58 """A helper function to download a storage Blob to a blob_file (the filename) 59 and validate it using the Crc32cCalculator. 60 61 Arguments: 62 blob (storage.Blob) : the Google storage blob object 63 blob_file (str) : the file path to download to 64 Returns: boolean to indicate doing retry (True) or not (False) 65 """ 66 67 # create parent directories if necessary 68 os.makedirs(os.path.dirname(filename), exist_ok=True) 69 70 # ideally we could calculate hash while streaming to file with provided function 71 # https://github.com/googleapis/python-storage/issues/29 72 with open(filename, "wb") as blob_file: 73 parser = Crc32cCalculator(blob_file) 74 blob.download_to_file(parser) 75 os.sync() 76 77 # **Important** hash can be incorrect or missing if not refreshed 78 blob.reload() 79 80 # Compute local hash and verify correct 81 if parser.hexdigest() != blob.crc32c: 82 os.remove(filename) 83 raise CheckSumMismatchException("The checksum of %s does not match." % filename) 84 return filename 85 86 87class Crc32cCalculator: 88 """The Google Python client doesn't provide a way to stream a file being 89 written, so we can wrap the file object in an additional class to 90 do custom handling. This is so we don't need to download the file 91 and then stream read it again to calculate the hash. 
92 """ 93 94 def __init__(self, fileobj): 95 self._fileobj = fileobj 96 self.checksum = Checksum() 97 98 def write(self, chunk): 99 self._fileobj.write(chunk) 100 self._update(chunk) 101 102 def _update(self, chunk): 103 """Given a chunk from the read in file, update the hexdigest""" 104 self.checksum.update(chunk) 105 106 def hexdigest(self): 107 """Return the hexdigest of the hasher. 108 The Base64 encoded CRC32c is in big-endian byte order. 109 See https://cloud.google.com/storage/docs/hashes-etags 110 """ 111 return base64.b64encode(self.checksum.digest()).decode("utf-8") 112 113 114class RemoteProvider(AbstractRemoteProvider): 115 116 supports_default = True 117 118 def __init__( 119 self, *args, keep_local=False, stay_on_remote=False, is_default=False, **kwargs 120 ): 121 super(RemoteProvider, self).__init__( 122 *args, 123 keep_local=keep_local, 124 stay_on_remote=stay_on_remote, 125 is_default=is_default, 126 **kwargs 127 ) 128 129 self.client = storage.Client(*args, **kwargs) 130 131 def remote_interface(self): 132 return self.client 133 134 @property 135 def default_protocol(self): 136 """The protocol that is prepended to the path when no protocol is specified.""" 137 return "gs://" 138 139 @property 140 def available_protocols(self): 141 """List of valid protocols for this remote provider.""" 142 return ["gs://"] 143 144 145class RemoteObject(AbstractRemoteObject): 146 def __init__( 147 self, *args, keep_local=False, provider=None, user_project=None, **kwargs 148 ): 149 super(RemoteObject, self).__init__( 150 *args, keep_local=keep_local, provider=provider, **kwargs 151 ) 152 153 if provider: 154 self.client = provider.remote_interface() 155 else: 156 self.client = storage.Client(*args, **kwargs) 157 158 # keep user_project available for when bucket is initialized 159 self._user_project = user_project 160 161 self._key = None 162 self._bucket_name = None 163 self._bucket = None 164 self._blob = None 165 166 async def inventory(self, cache: snakemake.io.IOCache): 167 """Using client.list_blobs(), we want to iterate over the objects in 168 the "folder" of a bucket and store information about the IOFiles in the 169 provided cache (snakemake.io.IOCache) indexed by bucket/blob name. 170 This will be called by the first mention of a remote object, and 171 iterate over the entire bucket once (and then not need to again). 172 This includes: 173 - cache.exist_remote 174 - cache_mtime 175 - cache.size 176 """ 177 if cache.remaining_wait_time <= 0: 178 # No more time to create inventory. 179 return 180 181 start_time = time.time() 182 subfolder = os.path.dirname(self.blob.name) 183 for blob in self.client.list_blobs(self.bucket_name, prefix=subfolder): 184 # By way of being listed, it exists. mtime is a datetime object 185 name = "{}/{}".format(blob.bucket.name, blob.name) 186 cache.exists_remote[name] = True 187 cache.mtime[name] = blob.updated.timestamp() 188 cache.size[name] = blob.size 189 190 cache.remaining_wait_time -= time.time() - start_time 191 192 # Mark bucket and prefix as having an inventory, such that this method is 193 # only called once for the subfolder in the bucket. 
class RemoteObject(AbstractRemoteObject):
    def __init__(
        self, *args, keep_local=False, provider=None, user_project=None, **kwargs
    ):
        super(RemoteObject, self).__init__(
            *args, keep_local=keep_local, provider=provider, **kwargs
        )

        if provider:
            self.client = provider.remote_interface()
        else:
            self.client = storage.Client(*args, **kwargs)

        # keep user_project available for when the bucket is initialized
        self._user_project = user_project

        self._key = None
        self._bucket_name = None
        self._bucket = None
        self._blob = None

    async def inventory(self, cache: snakemake.io.IOCache):
        """Using client.list_blobs(), iterate over the objects in the
        "folder" of the bucket that contains this blob and store information
        about them in the provided cache (snakemake.io.IOCache), indexed by
        bucket/blob name. This is called on the first mention of a remote
        object and iterates over the prefix once (so it does not need to be
        repeated). The cached information includes:
        - cache.exists_remote
        - cache.mtime
        - cache.size
        """
        if cache.remaining_wait_time <= 0:
            # No more time to create the inventory.
            return

        start_time = time.time()
        subfolder = os.path.dirname(self.blob.name)
        for blob in self.client.list_blobs(self.bucket_name, prefix=subfolder):
            # By way of being listed, it exists. mtime is a datetime object.
            name = "{}/{}".format(blob.bucket.name, blob.name)
            cache.exists_remote[name] = True
            cache.mtime[name] = blob.updated.timestamp()
            cache.size[name] = blob.size

        cache.remaining_wait_time -= time.time() - start_time

        # Mark bucket and prefix as having an inventory, such that this method is
        # only called once for the subfolder in the bucket.
        cache.exists_remote.has_inventory.add("%s/%s" % (self.bucket_name, subfolder))

    # === Implementations of abstract class members ===

    def get_inventory_parent(self):
        return self.bucket_name

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def exists(self):
        return self.blob.exists()

    def mtime(self):
        if self.exists():
            self.update_blob()
            t = self.blob.updated
            return t.timestamp()
        else:
            raise WorkflowError(
                "The file does not seem to exist remotely: %s" % self.local_file()
            )

    def size(self):
        if self.exists():
            self.update_blob()
            return self.blob.size // 1024
        else:
            return self._iofile.size_local

    @retry.Retry(predicate=google_cloud_retry_predicate, deadline=600)
    def download(self):
        """Download with a maximum retry duration of 600 seconds (10 minutes)."""
        if not self.exists():
            return None

        # Download either a flagged directory (recursively) or a single file
        if snakemake.io.is_flagged(self.local_file(), "directory"):
            return self._download_directory()
        return download_blob(self.blob, self.local_file())

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def _download_directory(self):
        """A 'private' function to handle download of a storage folder,
        including the content found inside it.
        """
        # Create the directory locally
        os.makedirs(self.local_file(), exist_ok=True)

        for blob in self.client.list_blobs(self.bucket_name, prefix=self.key):
            local_name = "{}/{}".format(blob.bucket.name, blob.name)

            # Don't try to create a "directory blob"
            if os.path.exists(local_name) and os.path.isdir(local_name):
                continue

            download_blob(blob, local_name)

        # Return the root directory
        return self.local_file()

    @retry.Retry(predicate=google_cloud_retry_predicate)
    def upload(self):
        try:
            if not self.bucket.exists():
                self.bucket.create()
                self.update_blob()

            # Distinguish between a single file and a folder
            f = self.local_file()
            if os.path.isdir(f):

                # Ensure the "directory" exists
                self.blob.upload_from_string(
                    "", content_type="application/x-www-form-urlencoded;charset=UTF-8"
                )
                for root, _, files in os.walk(f):
                    for filename in files:
                        filename = os.path.join(root, filename)
                        # Strip the bucket name prefix to obtain the blob key
                        bucket_path = filename[len(self.bucket.name):].lstrip("/")
                        blob = self.bucket.blob(bucket_path)
                        blob.upload_from_filename(filename)
            else:
                self.blob.upload_from_filename(f)
        except google.cloud.exceptions.Forbidden as e:
            raise WorkflowError(
                e,
                "When running locally, make sure that you are authenticated "
                "via gcloud (see Snakemake documentation). When running in a "
                "kubernetes cluster, make sure that storage-rw is added to "
                "--scopes (see Snakemake documentation).",
            )
When running in a " 281 "kubernetes cluster, make sure that storage-rw is added to " 282 "--scopes (see Snakemake documentation).", 283 ) 284 285 @property 286 def name(self): 287 return self.key 288 289 @property 290 def list(self): 291 return [k.name for k in self.bucket.list_blobs()] 292 293 # ========= Helpers =============== 294 295 @retry.Retry(predicate=google_cloud_retry_predicate) 296 def update_blob(self): 297 self._blob = self.bucket.get_blob(self.key) 298 299 @lazy_property 300 def bucket(self): 301 return self.client.bucket(self.bucket_name, user_project=self._user_project) 302 303 @lazy_property 304 def blob(self): 305 return self.bucket.blob(self.key) 306 307 @lazy_property 308 def bucket_name(self): 309 return self.parse().group("bucket") 310 311 @property 312 def key(self): 313 key = self.parse().group("key") 314 f = self.local_file() 315 if snakemake.io.is_flagged(f, "directory"): 316 key = key if f.endswith("/") else key + "/" 317 return key 318 319 def parse(self): 320 m = re.search("(?P<bucket>[^/]*)/(?P<key>.*)", self.local_file()) 321 if len(m.groups()) != 2: 322 raise WorkflowError( 323 "GS remote file {} does not have the form " 324 "<bucket>/<key>.".format(self.local_file()) 325 ) 326 return m 327