# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from typing import (  # pylint: disable=unused-import
    Union,
    Optional,
    Any,
    Iterable,
    Dict,
    List,
    Type,
    Tuple,
    TYPE_CHECKING,
)
import logging

try:
    from urllib.parse import parse_qs, quote
except ImportError:
    from urlparse import parse_qs  # type: ignore
    from urllib2 import quote  # type: ignore

import six

from azure.core.configuration import Configuration
from azure.core.exceptions import HttpResponseError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import RequestsTransport, HttpTransport
from azure.core.pipeline.policies import (
    RedirectPolicy,
    ContentDecodePolicy,
    BearerTokenCredentialPolicy,
    ProxyPolicy,
    DistributedTracingPolicy,
    HttpLoggingPolicy,
    UserAgentPolicy
)

from .constants import STORAGE_OAUTH_SCOPE, SERVICE_HOST_BASE, CONNECTION_TIMEOUT, READ_TIMEOUT
from .models import LocationMode
from .authentication import SharedKeyCredentialPolicy
from .shared_access_signature import QueryStringConstants
from .policies import (
    StorageHeadersPolicy,
    StorageContentValidation,
    StorageRequestHook,
    StorageResponseHook,
    StorageLoggingPolicy,
    StorageHosts,
    QueueMessagePolicy,
    ExponentialRetry,
)
from .._version import VERSION
from .._generated.models import StorageErrorException
from .response_handlers import process_storage_error, PartialBatchErrorException


_LOGGER = logging.getLogger(__name__)

# Maps each storage service to the connection-string keys that carry its
# primary/secondary endpoints. "dfs" reuses the blob endpoint for both.
_SERVICE_PARAMS = {
    "blob": {"primary": "BlobEndpoint", "secondary": "BlobSecondaryEndpoint"},
    "queue": {"primary": "QueueEndpoint", "secondary": "QueueSecondaryEndpoint"},
    "file": {"primary": "FileEndpoint", "secondary": "FileSecondaryEndpoint"},
    "dfs": {"primary": "BlobEndpoint", "secondary": "BlobEndpoint"},
}


class StorageAccountHostsMixin(object):  # pylint: disable=too-many-instance-attributes
    """Shared base for storage service clients.

    Parses the account URL, resolves the primary/secondary hosts, normalizes
    the credential, and builds the HTTP pipeline used by subclasses.
    """

    def __init__(
        self,
        parsed_url,  # type: Any
        service,  # type: str
        credential=None,  # type: Optional[Any]
        **kwargs  # type: Any
    ):
        # type: (...) -> None
        self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
        self._hosts = kwargs.get("_hosts")
        self.scheme = parsed_url.scheme

        if service not in ["blob", "queue", "file-share", "dfs"]:
            raise ValueError("Invalid service: {}".format(service))
        service_name = service.split('-')[0]
        account = parsed_url.netloc.split(".{}.core.".format(service_name))

        self.account_name = account[0] if len(account) > 1 else None
        # FIX: the original condition was `not a and localhost or 127...`,
        # which (by operator precedence) ignored `account_name` whenever the
        # netloc started with "127.0.0.1". For emulator-style URLs the account
        # name is the first path segment instead of part of the hostname.
        if not self.account_name and (
                parsed_url.netloc.startswith("localhost")
                or parsed_url.netloc.startswith("127.0.0.1")):
            self.account_name = parsed_url.path.strip("/")

        self.credential = _format_shared_key_credential(self.account_name, credential)
        if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"):
            raise ValueError("Token credential is only supported with HTTPS.")

        secondary_hostname = None
        if hasattr(self.credential, "account_name"):
            self.account_name = self.credential.account_name
            secondary_hostname = "{}-secondary.{}.{}".format(
                self.credential.account_name, service_name, SERVICE_HOST_BASE)

        if not self._hosts:
            if len(account) > 1:
                secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary")
            # An explicit keyword always wins over the derived hostname.
            if kwargs.get("secondary_hostname"):
                secondary_hostname = kwargs["secondary_hostname"]
            primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/')
            self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname}

        # Client-side encryption settings (used by blob/queue uploads).
        self.require_encryption = kwargs.get("require_encryption", False)
        self.key_encryption_key = kwargs.get("key_encryption_key")
        self.key_resolver_function = kwargs.get("key_resolver_function")
        self._config, self._pipeline = self._create_pipeline(self.credential, storage_sdk=service, **kwargs)

    def __enter__(self):
        self._client.__enter__()
        return self

    def __exit__(self, *args):
        self._client.__exit__(*args)

    def close(self):
        """ This method is to close the sockets opened by the client.
        It need not be used when using with a context manager.
        """
        self._client.close()

    @property
    def url(self):
        """The full endpoint URL to this entity, including SAS token if used.

        This could be either the primary endpoint,
        or the secondary endpoint depending on the current :func:`location_mode`.
        """
        return self._format_url(self._hosts[self._location_mode])

    @property
    def primary_endpoint(self):
        """The full primary endpoint URL.

        :type: str
        """
        return self._format_url(self._hosts[LocationMode.PRIMARY])

    @property
    def primary_hostname(self):
        """The hostname of the primary endpoint.

        :type: str
        """
        return self._hosts[LocationMode.PRIMARY]

    @property
    def secondary_endpoint(self):
        """The full secondary endpoint URL if configured.

        If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional
        `secondary_hostname` keyword argument on instantiation.

        :type: str
        :raise ValueError:
        """
        if not self._hosts[LocationMode.SECONDARY]:
            raise ValueError("No secondary host configured.")
        return self._format_url(self._hosts[LocationMode.SECONDARY])

    @property
    def secondary_hostname(self):
        """The hostname of the secondary endpoint.

        If not available this will be None. To explicitly specify a secondary hostname, use the optional
        `secondary_hostname` keyword argument on instantiation.

        :type: str or None
        """
        return self._hosts[LocationMode.SECONDARY]

    @property
    def location_mode(self):
        """The location mode that the client is currently using.

        By default this will be "primary". Options include "primary" and "secondary".

        :type: str
        """
        return self._location_mode

    @location_mode.setter
    def location_mode(self, value):
        if self._hosts.get(value):
            self._location_mode = value
            self._client._config.url = self.url  # pylint: disable=protected-access
        else:
            raise ValueError("No host URL for location mode: {}".format(value))

    @property
    def api_version(self):
        """The version of the Storage API used for requests.

        :type: str
        """
        return self._client._config.version  # pylint: disable=protected-access

    def _format_query_string(self, sas_token, credential, snapshot=None, share_snapshot=None):
        """Build the query string for the client URL.

        Returns a ``(query_string, credential)`` tuple: the credential is
        cleared (returned as None) when it is itself a SAS token, since the
        token is then embedded in the query string instead.
        """
        query_str = "?"
        # FIX: the original formatted `self.snapshot` for both parameters,
        # ignoring the arguments actually passed in (and failing with
        # AttributeError on clients without a `snapshot` attribute).
        if snapshot:
            query_str += "snapshot={}&".format(snapshot)
        if share_snapshot:
            query_str += "sharesnapshot={}&".format(share_snapshot)
        if sas_token and not credential:
            query_str += sas_token
        elif is_credential_sastoken(credential):
            query_str += credential.lstrip("?")
            credential = None
        return query_str.rstrip("?&"), credential

    def _create_pipeline(self, credential, **kwargs):
        # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
        """Create the policy pipeline, resolving the credential policy first.

        :raises TypeError: for credentials that are neither token credentials,
            shared key credentials, nor None.
        """
        self._credential_policy = None
        if hasattr(credential, "get_token"):
            self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
        elif isinstance(credential, SharedKeyCredentialPolicy):
            self._credential_policy = credential
        elif credential is not None:
            raise TypeError("Unsupported credential: {}".format(credential))

        config = kwargs.get("_configuration") or create_configuration(**kwargs)
        if kwargs.get("_pipeline"):
            return config, kwargs["_pipeline"]
        config.transport = kwargs.get("transport")  # type: ignore
        kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
        kwargs.setdefault("read_timeout", READ_TIMEOUT)
        if not config.transport:
            config.transport = RequestsTransport(**kwargs)
        # Order matters: hooks/validation run before retry so they see every
        # attempt; tracing/logging run last, closest to the wire.
        policies = [
            QueueMessagePolicy(),
            config.headers_policy,
            config.proxy_policy,
            config.user_agent_policy,
            StorageContentValidation(),
            StorageRequestHook(**kwargs),
            self._credential_policy,
            ContentDecodePolicy(response_encoding="utf-8"),
            RedirectPolicy(**kwargs),
            StorageHosts(hosts=self._hosts, **kwargs),
            config.retry_policy,
            config.logging_policy,
            StorageResponseHook(**kwargs),
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs)
        ]
        if kwargs.get("_additional_pipeline_policies"):
            policies = policies + kwargs.get("_additional_pipeline_policies")
        return config, Pipeline(config.transport, policies=policies)

    def _batch_send(
        self, *reqs,  # type: HttpRequest
        **kwargs
    ):
        """Given a series of request, do a Storage batch call.

        :raises HttpResponseError: if the batch request itself is rejected.
        :raises PartialBatchErrorException: if ``raise_on_any_failure`` is
            True (the default) and any sub-request failed.
        """
        # Pop it here, so requests doesn't feel bad about additional kwarg
        raise_on_any_failure = kwargs.pop("raise_on_any_failure", True)
        request = self._client._client.post(  # pylint: disable=protected-access
            url='{}://{}/?comp=batch{}{}'.format(
                self.scheme,
                self.primary_hostname,
                kwargs.pop('sas', ""),
                kwargs.pop('timeout', "")
            ),
            headers={
                'x-ms-version': self.api_version
            }
        )

        policies = [StorageHeadersPolicy()]
        if self._credential_policy:
            policies.append(self._credential_policy)

        request.set_multipart_mixed(
            *reqs,
            policies=policies,
            enforce_https=False
        )

        pipeline_response = self._pipeline.run(
            request, **kwargs
        )
        response = pipeline_response.http_response

        try:
            # A batch request is only accepted with 202; anything else means
            # the batch envelope itself failed.
            if response.status_code not in [202]:
                raise HttpResponseError(response=response)
            # FIX: parse the multipart body exactly once (the original called
            # response.parts() a second time when checking sub-statuses).
            parts = response.parts()
            if raise_on_any_failure:
                parts = list(parts)
                if any(p for p in parts if not 200 <= p.status_code < 300):
                    error = PartialBatchErrorException(
                        message="There is a partial failure in the batch operation.",
                        response=response, parts=parts
                    )
                    raise error
                return iter(parts)
            return parts
        except StorageErrorException as error:
            process_storage_error(error)


class TransportWrapper(HttpTransport):
    """Wrapper class that ensures that an inner client created
    by a `get_client` method does not close the outer transport for the parent
    when used in a context manager.
    """
    def __init__(self, transport):
        self._transport = transport

    def send(self, request, **kwargs):
        return self._transport.send(request, **kwargs)

    def open(self):
        pass

    def close(self):
        pass

    def __enter__(self):
        pass

    def __exit__(self, *args):  # pylint: disable=arguments-differ
        pass


def _format_shared_key_credential(account_name, credential):
    """Normalize a string/dict shared key credential into a policy.

    Non-shared-key credentials are passed through unchanged.
    """
    if isinstance(credential, six.string_types):
        if not account_name:
            raise ValueError("Unable to determine account name for shared key credential.")
        credential = {"account_name": account_name, "account_key": credential}
    if isinstance(credential, dict):
        # FIX: messages previously had unbalanced quotes around the key names.
        if "account_name" not in credential:
            raise ValueError("Shared key credential missing 'account_name'")
        if "account_key" not in credential:
            raise ValueError("Shared key credential missing 'account_key'")
        return SharedKeyCredentialPolicy(**credential)
    return credential


def parse_connection_str(conn_str, credential, service):
    """Parse a storage connection string into (primary, secondary, credential).

    :raises ValueError: on a blank/malformed string, on a secondary endpoint
        without a primary, or when no endpoint can be derived.
    """
    conn_str = conn_str.rstrip(";")
    conn_settings = [s.split("=", 1) for s in conn_str.split(";")]
    if any(len(tup) != 2 for tup in conn_settings):
        raise ValueError("Connection string is either blank or malformed.")
    conn_settings = dict(conn_settings)
    endpoints = _SERVICE_PARAMS[service]
    primary = None
    secondary = None
    if not credential:
        try:
            credential = {"account_name": conn_settings["AccountName"], "account_key": conn_settings["AccountKey"]}
        except KeyError:
            # No shared key pair; fall back to an embedded SAS, if any.
            credential = conn_settings.get("SharedAccessSignature")
    if endpoints["primary"] in conn_settings:
        primary = conn_settings[endpoints["primary"]]
        if endpoints["secondary"] in conn_settings:
            secondary = conn_settings[endpoints["secondary"]]
    else:
        if endpoints["secondary"] in conn_settings:
            raise ValueError("Connection string specifies only secondary endpoint.")
        try:
            primary = "{}://{}.{}.{}".format(
                conn_settings["DefaultEndpointsProtocol"],
                conn_settings["AccountName"],
                service,
                conn_settings["EndpointSuffix"],
            )
            secondary = "{}-secondary.{}.{}".format(
                conn_settings["AccountName"], service, conn_settings["EndpointSuffix"]
            )
        except KeyError:
            pass

    if not primary:
        try:
            primary = "https://{}.{}.{}".format(
                conn_settings["AccountName"], service, conn_settings.get("EndpointSuffix", SERVICE_HOST_BASE)
            )
        except KeyError:
            raise ValueError("Connection string missing required connection details.")
    return primary, secondary, credential


def create_configuration(**kwargs):
    # type: (**Any) -> Configuration
    """Build the pipeline Configuration with storage-specific defaults."""
    config = Configuration(**kwargs)
    config.headers_policy = StorageHeadersPolicy(**kwargs)
    config.user_agent_policy = UserAgentPolicy(
        sdk_moniker="storage-{}/{}".format(kwargs.pop('storage_sdk'), VERSION), **kwargs)
    config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs)
    config.logging_policy = StorageLoggingPolicy(**kwargs)
    config.proxy_policy = ProxyPolicy(**kwargs)

    # Storage settings
    config.max_single_put_size = kwargs.get("max_single_put_size", 64 * 1024 * 1024)
    config.copy_polling_interval = 15

    # Block blob uploads
    config.max_block_size = kwargs.get("max_block_size", 4 * 1024 * 1024)
    config.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1)
    config.use_byte_buffer = kwargs.get("use_byte_buffer", False)

    # Page blob uploads
    config.max_page_size = kwargs.get("max_page_size", 4 * 1024 * 1024)

    # Blob downloads
    config.max_single_get_size = kwargs.get("max_single_get_size", 32 * 1024 * 1024)
    config.max_chunk_get_size = kwargs.get("max_chunk_get_size", 4 * 1024 * 1024)

    # File uploads
    config.max_range_size = kwargs.get("max_range_size", 4 * 1024 * 1024)
    return config


def parse_query(query_str):
    """Split a URL query string into (snapshot, sas_token).

    Only recognized SAS parameters are folded (re-quoted) into the token.
    """
    sas_values = QueryStringConstants.to_list()
    parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
    sas_params = ["{}={}".format(k, quote(v, safe='')) for k, v in parsed_query.items() if k in sas_values]
    sas_token = None
    if sas_params:
        sas_token = "&".join(sas_params)

    snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot")
    return snapshot, sas_token


def is_credential_sastoken(credential):
    """Return True if `credential` is a string consisting solely of SAS params."""
    if not credential or not isinstance(credential, six.string_types):
        return False

    sas_values = QueryStringConstants.to_list()
    parsed_query = parse_qs(credential.lstrip("?"))
    if parsed_query and all([k in sas_values for k in parsed_query.keys()]):
        return True
    return False