1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for
4# license information.
5# --------------------------------------------------------------------------
6
7from typing import (  # pylint: disable=unused-import
8    Union,
9    Optional,
10    Any,
11    Iterable,
12    Dict,
13    List,
14    Type,
15    Tuple,
16    TYPE_CHECKING,
17)
18import logging
19
20try:
21    from urllib.parse import parse_qs, quote
22except ImportError:
23    from urlparse import parse_qs  # type: ignore
24    from urllib2 import quote  # type: ignore
25
26import six
27
28from azure.core.configuration import Configuration
29from azure.core.exceptions import HttpResponseError
30from azure.core.pipeline import Pipeline
31from azure.core.pipeline.transport import RequestsTransport, HttpTransport
32from azure.core.pipeline.policies import (
33    RedirectPolicy,
34    ContentDecodePolicy,
35    BearerTokenCredentialPolicy,
36    ProxyPolicy,
37    DistributedTracingPolicy,
38    HttpLoggingPolicy,
39    UserAgentPolicy
40)
41
42from .constants import STORAGE_OAUTH_SCOPE, SERVICE_HOST_BASE, CONNECTION_TIMEOUT, READ_TIMEOUT
43from .models import LocationMode
44from .authentication import SharedKeyCredentialPolicy
45from .shared_access_signature import QueryStringConstants
46from .policies import (
47    StorageHeadersPolicy,
48    StorageContentValidation,
49    StorageRequestHook,
50    StorageResponseHook,
51    StorageLoggingPolicy,
52    StorageHosts,
53    QueueMessagePolicy,
54    ExponentialRetry,
55)
56from .._version import VERSION
57from .._generated.models import StorageErrorException
58from .response_handlers import process_storage_error, PartialBatchErrorException
59
60
_LOGGER = logging.getLogger(__name__)

# Maps each storage service to the connection-string keys naming its primary
# and secondary endpoints. Note "dfs" (Data Lake) deliberately reuses the
# BlobEndpoint key for both entries, since it is served via the blob account.
_SERVICE_PARAMS = {
    "blob": {"primary": "BlobEndpoint", "secondary": "BlobSecondaryEndpoint"},
    "queue": {"primary": "QueueEndpoint", "secondary": "QueueSecondaryEndpoint"},
    "file": {"primary": "FileEndpoint", "secondary": "FileSecondaryEndpoint"},
    "dfs": {"primary": "BlobEndpoint", "secondary": "BlobEndpoint"},
}
68
69
class StorageAccountHostsMixin(object):  # pylint: disable=too-many-instance-attributes
    """Mixin giving storage service clients account/host resolution, pipeline
    construction, location-mode (primary/secondary) switching, and multipart
    batch support. Subclasses are expected to provide ``self._client`` (the
    generated client) and ``self._format_url``.
    """

    def __init__(
        self,
        parsed_url,  # type: Any
        service,  # type: str
        credential=None,  # type: Optional[Any]
        **kwargs  # type: Any
    ):
        # type: (...) -> None
        """
        :param parsed_url: The account URL, already parsed by urlparse.
        :param str service: One of "blob", "queue", "file-share" or "dfs".
        :param credential: Optional credential: an account key (str or dict),
            a SharedKeyCredentialPolicy, a SAS string, or an object exposing
            ``get_token`` (OAuth token credential).
        :raises ValueError: If ``service`` is unknown, or a token credential
            is used with a non-HTTPS URL.
        """
        self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
        self._hosts = kwargs.get("_hosts")
        self.scheme = parsed_url.scheme

        if service not in ["blob", "queue", "file-share", "dfs"]:
            raise ValueError("Invalid service: {}".format(service))
        service_name = service.split('-')[0]
        # Standard public-cloud hosts look like "<account>.<service>.core.<suffix>";
        # splitting on ".<service>.core." yields [account, suffix] for those,
        # and a single element for anything else (custom domains, emulator).
        account = parsed_url.netloc.split(".{}.core.".format(service_name))

        self.account_name = account[0] if len(account) > 1 else None
        # BUGFIX: the original condition was
        #     not self.account_name and A or B
        # which, due to operator precedence, took the emulator path for any
        # 127.0.0.1 host regardless of whether the account name was already
        # resolved. Parenthesizing the 'or' restricts the emulator fallback
        # (account name taken from the URL path) to unresolved local hosts.
        if not self.account_name and (
                parsed_url.netloc.startswith("localhost")
                or parsed_url.netloc.startswith("127.0.0.1")):
            self.account_name = parsed_url.path.strip("/")

        self.credential = _format_shared_key_credential(self.account_name, credential)
        if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"):
            raise ValueError("Token credential is only supported with HTTPS.")

        secondary_hostname = None
        if hasattr(self.credential, "account_name"):
            # A shared key credential carries the authoritative account name;
            # prefer it and derive the default secondary host from it.
            self.account_name = self.credential.account_name
            secondary_hostname = "{}-secondary.{}.{}".format(
                self.credential.account_name, service_name, SERVICE_HOST_BASE)

        if not self._hosts:
            if len(account) > 1:
                # Public-cloud convention: "<account>-secondary.<service>...".
                secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary")
            # An explicit keyword always overrides the derived secondary host.
            if kwargs.get("secondary_hostname"):
                secondary_hostname = kwargs["secondary_hostname"]
            primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/')
            self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname}

        # Client-side encryption settings, read by upload/download helpers.
        self.require_encryption = kwargs.get("require_encryption", False)
        self.key_encryption_key = kwargs.get("key_encryption_key")
        self.key_resolver_function = kwargs.get("key_resolver_function")
        self._config, self._pipeline = self._create_pipeline(self.credential, storage_sdk=service, **kwargs)

    def __enter__(self):
        # Delegate context management to the generated client (opens transport).
        self._client.__enter__()
        return self

    def __exit__(self, *args):
        self._client.__exit__(*args)

    def close(self):
        """ This method is to close the sockets opened by the client.
        It need not be used when using with a context manager.
        """
        self._client.close()

    @property
    def url(self):
        """The full endpoint URL to this entity, including SAS token if used.

        This could be either the primary endpoint,
        or the secondary endpoint depending on the current :func:`location_mode`.
        """
        return self._format_url(self._hosts[self._location_mode])

    @property
    def primary_endpoint(self):
        """The full primary endpoint URL.

        :type: str
        """
        return self._format_url(self._hosts[LocationMode.PRIMARY])

    @property
    def primary_hostname(self):
        """The hostname of the primary endpoint.

        :type: str
        """
        return self._hosts[LocationMode.PRIMARY]

    @property
    def secondary_endpoint(self):
        """The full secondary endpoint URL if configured.

        If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional
        `secondary_hostname` keyword argument on instantiation.

        :type: str
        :raise ValueError:
        """
        if not self._hosts[LocationMode.SECONDARY]:
            raise ValueError("No secondary host configured.")
        return self._format_url(self._hosts[LocationMode.SECONDARY])

    @property
    def secondary_hostname(self):
        """The hostname of the secondary endpoint.

        If not available this will be None. To explicitly specify a secondary hostname, use the optional
        `secondary_hostname` keyword argument on instantiation.

        :type: str or None
        """
        return self._hosts[LocationMode.SECONDARY]

    @property
    def location_mode(self):
        """The location mode that the client is currently using.

        By default this will be "primary". Options include "primary" and "secondary".

        :type: str
        """

        return self._location_mode

    @location_mode.setter
    def location_mode(self, value):
        # Only allow switching to a mode for which a host is configured, and
        # keep the generated client's base URL in sync with the new mode.
        if self._hosts.get(value):
            self._location_mode = value
            self._client._config.url = self.url  # pylint: disable=protected-access
        else:
            raise ValueError("No host URL for location mode: {}".format(value))

    @property
    def api_version(self):
        """The version of the Storage API used for requests.

        :type: str
        """
        return self._client._config.version  # pylint: disable=protected-access

    def _format_query_string(self, sas_token, credential, snapshot=None, share_snapshot=None):
        """Build the query string to append to the client URL and the
        (possibly cleared) credential.

        NOTE(review): the snapshot values are read from ``self.snapshot``;
        the ``snapshot``/``share_snapshot`` parameters act only as flags.
        Callers currently pass ``snapshot=self.snapshot`` so the output is
        the same either way - confirm before relying on the parameter values.
        """
        query_str = "?"
        if snapshot:
            query_str += "snapshot={}&".format(self.snapshot)
        if share_snapshot:
            query_str += "sharesnapshot={}&".format(self.snapshot)
        if sas_token and not credential:
            query_str += sas_token
        elif is_credential_sastoken(credential):
            # The credential itself is a SAS token: fold it into the URL and
            # clear it so it is not also installed as a pipeline credential.
            query_str += credential.lstrip("?")
            credential = None
        return query_str.rstrip("?&"), credential

    def _create_pipeline(self, credential, **kwargs):
        # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
        """Build the (Configuration, Pipeline) pair for this client.

        :param credential: A token credential (``get_token``), a
            SharedKeyCredentialPolicy, or None.
        :raises TypeError: For any other non-None credential type.
        """
        self._credential_policy = None
        if hasattr(credential, "get_token"):
            self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
        elif isinstance(credential, SharedKeyCredentialPolicy):
            self._credential_policy = credential
        elif credential is not None:
            raise TypeError("Unsupported credential: {}".format(credential))

        config = kwargs.get("_configuration") or create_configuration(**kwargs)
        if kwargs.get("_pipeline"):
            # Reuse an existing pipeline (e.g. a sub-client sharing its parent's).
            return config, kwargs["_pipeline"]
        config.transport = kwargs.get("transport")  # type: ignore
        kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
        kwargs.setdefault("read_timeout", READ_TIMEOUT)
        if not config.transport:
            config.transport = RequestsTransport(**kwargs)
        # Order matters: hooks/validation run before auth and retry.
        policies = [
            QueueMessagePolicy(),
            config.headers_policy,
            config.proxy_policy,
            config.user_agent_policy,
            StorageContentValidation(),
            StorageRequestHook(**kwargs),
            self._credential_policy,
            ContentDecodePolicy(response_encoding="utf-8"),
            RedirectPolicy(**kwargs),
            StorageHosts(hosts=self._hosts, **kwargs),
            config.retry_policy,
            config.logging_policy,
            StorageResponseHook(**kwargs),
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs)
        ]
        if kwargs.get("_additional_pipeline_policies"):
            policies = policies + kwargs.get("_additional_pipeline_policies")
        return config, Pipeline(config.transport, policies=policies)

    def _batch_send(
        self, *reqs,  # type: HttpRequest
        **kwargs
    ):
        """Given a series of request, do a Storage batch call.

        :param reqs: The sub-requests to send as one multipart/mixed batch.
        :keyword bool raise_on_any_failure: When True (the default), raise
            PartialBatchErrorException if any sub-response is not 2xx.
        :returns: An iterator over the sub-responses.
        :raises HttpResponseError: If the batch itself is not accepted (202).
        """
        # Pop it here, so requests doesn't feel bad about additional kwarg
        raise_on_any_failure = kwargs.pop("raise_on_any_failure", True)
        request = self._client._client.post(  # pylint: disable=protected-access
            url='{}://{}/?comp=batch{}{}'.format(
                self.scheme,
                self.primary_hostname,
                kwargs.pop('sas', ""),
                kwargs.pop('timeout', "")
            ),
            headers={
                'x-ms-version': self.api_version
            }
        )

        # Sub-requests are signed individually, so they need their own
        # headers and (when present) credential policy.
        policies = [StorageHeadersPolicy()]
        if self._credential_policy:
            policies.append(self._credential_policy)

        request.set_multipart_mixed(
            *reqs,
            policies=policies,
            enforce_https=False
        )

        pipeline_response = self._pipeline.run(
            request, **kwargs
        )
        response = pipeline_response.http_response

        try:
            if response.status_code not in [202]:
                raise HttpResponseError(response=response)
            parts = response.parts()
            if raise_on_any_failure:
                # BUGFIX: consume the iterator we already have instead of
                # calling response.parts() a second time (the original parsed
                # the multipart body twice).
                parts = list(parts)
                if any(p for p in parts if not 200 <= p.status_code < 300):
                    error = PartialBatchErrorException(
                        message="There is a partial failure in the batch operation.",
                        response=response, parts=parts
                    )
                    raise error
                return iter(parts)
            return parts
        except StorageErrorException as error:
            process_storage_error(error)
309
class TransportWrapper(HttpTransport):
    """Wrapper class that ensures that an inner client created
    by a `get_client` method does not close the outer transport for the parent
    when used in a context manager.
    """
    def __init__(self, transport):
        # The shared transport owned by the parent client.
        self._transport = transport

    def send(self, request, **kwargs):
        # Requests pass straight through to the wrapped transport.
        return self._transport.send(request, **kwargs)

    def open(self):
        # Deliberate no-op: the parent client controls the transport lifetime.
        pass

    def close(self):
        # Deliberate no-op: closing an inner client must not tear down the
        # parent's shared transport/session.
        pass

    def __enter__(self):
        # Deliberate no-op: entering an inner client's context must not
        # (re)open or take ownership of the shared transport.
        pass

    def __exit__(self, *args):  # pylint: disable=arguments-differ
        # Deliberate no-op: see close().
        pass
332
333
def _format_shared_key_credential(account_name, credential):
    """Normalize a shared-key credential into a SharedKeyCredentialPolicy.

    A bare key string is paired with ``account_name``; a dict with
    ``account_name``/``account_key`` entries is expanded into the policy.
    Any other credential (SAS string excluded - those are not dicts or raw
    keys here, token credentials, None) is returned unchanged.

    :param account_name: Account name to pair with a bare key string.
    :param credential: str key, dict, policy, token credential, or None.
    :raises ValueError: If a bare key is given without an account name, or a
        dict credential is missing a required entry.
    """
    if isinstance(credential, six.string_types):
        if not account_name:
            raise ValueError("Unable to determine account name for shared key credential.")
        credential = {"account_name": account_name, "account_key": credential}
    if isinstance(credential, dict):
        # BUGFIX: the original error messages left the quote around the key
        # name unterminated ("missing 'account_name").
        if "account_name" not in credential:
            raise ValueError("Shared key credential missing 'account_name'")
        if "account_key" not in credential:
            raise ValueError("Shared key credential missing 'account_key'")
        return SharedKeyCredentialPolicy(**credential)
    return credential
346
347
def parse_connection_str(conn_str, credential, service):
    """Split an Azure Storage connection string into its parts.

    :param str conn_str: The ``Key=Value;...`` connection string.
    :param credential: An explicit credential; when falsy, one is derived
        from the AccountName/AccountKey or SharedAccessSignature entries.
    :param str service: The service key into _SERVICE_PARAMS (e.g. "blob").
    :returns: A (primary_endpoint, secondary_endpoint, credential) tuple.
    :raises ValueError: On a blank/malformed string, a secondary endpoint
        without a primary, or insufficient detail to build a primary URL.
    """
    pairs = [segment.split("=", 1) for segment in conn_str.rstrip(";").split(";")]
    if any(len(pair) != 2 for pair in pairs):
        raise ValueError("Connection string is either blank or malformed.")
    settings = dict(pairs)
    endpoints = _SERVICE_PARAMS[service]

    if not credential:
        if "AccountName" in settings and "AccountKey" in settings:
            credential = {
                "account_name": settings["AccountName"],
                "account_key": settings["AccountKey"],
            }
        else:
            # Fall back to a SAS entry, if present at all.
            credential = settings.get("SharedAccessSignature")

    primary = None
    secondary = None
    if endpoints["primary"] in settings:
        # Explicit endpoints win over anything derived from the account name.
        primary = settings[endpoints["primary"]]
        secondary = settings.get(endpoints["secondary"])
    elif endpoints["secondary"] in settings:
        raise ValueError("Connection string specifies only secondary endpoint.")
    else:
        # Derive both endpoints from the standard account fields; leave them
        # unset if any of the required entries is missing.
        try:
            primary = "{}://{}.{}.{}".format(
                settings["DefaultEndpointsProtocol"],
                settings["AccountName"],
                service,
                settings["EndpointSuffix"],
            )
            secondary = "{}-secondary.{}.{}".format(
                settings["AccountName"], service, settings["EndpointSuffix"]
            )
        except KeyError:
            pass

    if not primary:
        # Last resort: assume HTTPS and the default public-cloud suffix.
        try:
            primary = "https://{}.{}.{}".format(
                settings["AccountName"], service, settings.get("EndpointSuffix", SERVICE_HOST_BASE)
            )
        except KeyError:
            raise ValueError("Connection string missing required connection details.")
    return primary, secondary, credential
390
391
def create_configuration(**kwargs):
    # type: (**Any) -> Configuration
    """Build the azure-core Configuration: per-client policies plus the
    storage transfer-tuning attributes, each overridable via keyword args.
    """
    one_mb = 1024 * 1024

    config = Configuration(**kwargs)
    config.headers_policy = StorageHeadersPolicy(**kwargs)
    sdk_name = kwargs.pop('storage_sdk')
    config.user_agent_policy = UserAgentPolicy(
        sdk_moniker="storage-{}/{}".format(sdk_name, VERSION), **kwargs)
    config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs)
    config.logging_policy = StorageLoggingPolicy(**kwargs)
    config.proxy_policy = ProxyPolicy(**kwargs)

    # Storage transfer settings and their defaults.
    tunables = (
        ("max_single_put_size", 64 * one_mb),               # single-shot upload cutoff
        ("max_block_size", 4 * one_mb),                     # block blob chunk size
        ("min_large_block_upload_threshold", 4 * one_mb + 1),
        ("use_byte_buffer", False),
        ("max_page_size", 4 * one_mb),                      # page blob chunk size
        ("max_single_get_size", 32 * one_mb),               # first download request
        ("max_chunk_get_size", 4 * one_mb),                 # subsequent download chunks
        ("max_range_size", 4 * one_mb),                     # file share range size
    )
    for name, default in tunables:
        setattr(config, name, kwargs.get(name, default))
    config.copy_polling_interval = 15
    return config
421
422
def parse_query(query_str):
    """Split a URL query string into a (snapshot, sas_token) pair.

    Recognized SAS parameters are re-quoted and re-joined into ``sas_token``
    (None when absent); the snapshot is taken from either the ``snapshot``
    or ``sharesnapshot`` parameter.
    """
    sas_values = QueryStringConstants.to_list()
    parsed_query = {}
    for key, values in parse_qs(query_str).items():
        # parse_qs returns a list per key; only the first value is used.
        parsed_query[key] = values[0]

    sas_params = []
    for key, value in parsed_query.items():
        if key in sas_values:
            sas_params.append("{}={}".format(key, quote(value, safe='')))
    sas_token = "&".join(sas_params) if sas_params else None

    snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot")
    return snapshot, sas_token
433
434
def is_credential_sastoken(credential):
    """Return True when *credential* is a string that parses as a query
    string made up exclusively of known SAS parameters.
    """
    if not credential or not isinstance(credential, six.string_types):
        return False

    known_params = QueryStringConstants.to_list()
    parsed = parse_qs(credential.lstrip("?"))
    if not parsed:
        # Empty / unparseable strings are not SAS tokens.
        return False
    for key in parsed:
        if key not in known_params:
            return False
    return True
444