1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for
4# license information.
5# --------------------------------------------------------------------------
6
7import logging
8import sys
9
10try:
11    from urllib.parse import urlparse, unquote
12except ImportError:
13    from urlparse import urlparse # type: ignore
14    from urllib2 import unquote # type: ignore
15
16try:
17    from yarl import URL
18except ImportError:
19    pass
20
21try:
22    from azure.core.pipeline.transport import AioHttpTransport
23except ImportError:
24    AioHttpTransport = None
25
26from azure.core.exceptions import ClientAuthenticationError
27from azure.core.pipeline.policies import SansIOHTTPPolicy
28
29from . import sign_string
30
31
# Module-level logger, named after this module per the standard ``logging`` convention.
logger = logging.getLogger(__name__)
33
34
35
36# wraps a given exception with the desired exception type
37def _wrap_exception(ex, desired_type):
38    msg = ""
39    if ex.args:
40        msg = ex.args[0]
41    if sys.version_info >= (3,):
42        # Automatic chaining in Python 3 means we keep the trace
43        return desired_type(msg)
44    # There isn't a good solution in 2 for keeping the stack trace
45    # in general, or that will not result in an error in 3
46    # However, we can keep the previous error type and message
47    # TODO: In the future we will log the trace
48    return desired_type('{}: {}'.format(ex.__class__.__name__, msg))
49
50
class AzureSigningError(ClientAuthenticationError):
    """Fatal error raised when a request cannot be signed.

    This is most often caused by user error, for example an account key
    that is not valid. For more information about storage accounts, see
    https://docs.microsoft.com/en-us/azure/storage/common/storage-create-storage-account
    """
57
58
# pylint: disable=no-self-use
class SharedKeyCredentialPolicy(SansIOHTTPPolicy):
    """Pipeline policy that signs each request with a storage account's shared key.

    ``on_request`` concatenates the Shared Key "string to sign" from:
    - the HTTP verb;
    - a fixed list of standard headers;
    - the canonicalized ``x-ms-*`` headers;
    - the canonicalized resource path;
    - the canonicalized query.

    It signs that string with ``sign_string`` and stores
    ``SharedKey <account_name>:<signature>`` in the ``Authorization`` header.

    :param str account_name: Storage account name; prefixed to the resource path.
    :param account_key: Account key handed to ``sign_string``.
    """

    # Standard headers whose values fill fixed, newline-terminated slots in the
    # string to sign, in this exact order. The final slot is the HTTP ``Range``
    # header. Range requests sent via ``x-ms-range`` are instead covered by the
    # canonicalized ``x-ms-*`` headers.
    _SIGNED_HEADERS = (
        'content-encoding', 'content-language', 'content-length',
        'content-md5', 'content-type', 'date', 'if-modified-since',
        'if-match', 'if-none-match', 'if-unmodified-since', 'range',
    )

    def __init__(self, account_name, account_key):
        self.account_name = account_name
        self.account_key = account_key
        super(SharedKeyCredentialPolicy, self).__init__()

    def _get_headers(self, request, headers_to_sign):
        """Return the values of ``headers_to_sign``, each followed by a newline.

        Header names are matched case-insensitively. Absent headers, and
        headers with a falsy value, contribute an empty line. A
        ``Content-Length`` of ``'0'`` is also emitted as an empty line, matching
        the Shared Key scheme for recent service versions.
        """
        headers = {
            name.lower(): value
            for name, value in request.http_request.headers.items()
            if value
        }
        if headers.get('content-length') == '0':
            del headers['content-length']
        return '\n'.join(headers.get(name, '') for name in headers_to_sign) + '\n'

    def _get_verb(self, request):
        """Return the request's HTTP method followed by a newline."""
        return request.http_request.method + '\n'

    def _get_canonicalized_resource(self, request):
        """Return ``/<account_name><url path>`` for the request.

        When the request goes through ``AioHttpTransport``, either directly or
        as the ``_transport`` of a wrapping transport, the path is first passed
        through ``yarl.URL``. Presumably this is so the signed path matches the
        form aiohttp actually sends. If ``AioHttpTransport`` could not be
        imported, the raw path is used.
        """
        uri_path = urlparse(request.http_request.url).path
        if AioHttpTransport is not None:
            transport = request.context.transport
            if isinstance(transport, AioHttpTransport) or \
                    isinstance(getattr(transport, "_transport", None), AioHttpTransport):
                return '/' + self.account_name + str(URL(uri_path))
        return '/' + self.account_name + uri_path

    def _get_canonicalized_headers(self, request):
        """Return the ``x-ms-*`` headers as sorted ``name:value\\n`` lines.

        Names are lowercased before both the ``x-ms-`` prefix test and the
        sort, so headers are recognized regardless of the case they were set
        with. Headers whose value is ``None`` are skipped.
        """
        x_ms_headers = []
        for name, value in request.http_request.headers.items():
            lowered = name.lower()
            if lowered.startswith('x-ms-') and value is not None:
                x_ms_headers.append((lowered, value))
        x_ms_headers.sort()
        return ''.join(name + ':' + value + '\n' for name, value in x_ms_headers)

    def _get_canonicalized_resource_query(self, request):
        """Return the query parameters as ``\\nname:value`` entries.

        Parameter names are lowercased *before* sorting, so ordering does not
        depend on the case the caller used. Values are percent-decoded, and
        parameters whose value is ``None`` are skipped.
        """
        queries = sorted(
            (name.lower(), value)
            for name, value in request.http_request.query.items()
            if value is not None
        )
        return ''.join('\n' + name + ':' + unquote(value) for name, value in queries)

    def _add_authorization_header(self, request, string_to_sign):
        """Sign ``string_to_sign`` and set the ``Authorization`` header.

        :raises AzureSigningError: Any error raised while signing is re-raised
            as ``AzureSigningError`` so the failure is clearly attributed to
            request signing.
        """
        try:
            signature = sign_string(self.account_key, string_to_sign)
            auth_string = 'SharedKey ' + self.account_name + ':' + signature
            request.http_request.headers['Authorization'] = auth_string
        except Exception as ex:
            # Wrap any error that occurred as a signing error; doing so
            # clarifies/locates the source of the problem. ``raise ... from``
            # is avoided because this module still supports Python 2.
            raise _wrap_exception(ex, AzureSigningError)

    def on_request(self, request):
        """Sign ``request`` by adding a Shared Key ``Authorization`` header.

        :raises AzureSigningError: If computing the signature fails.
        """
        string_to_sign = (
            self._get_verb(request)
            + self._get_headers(request, self._SIGNED_HEADERS)
            + self._get_canonicalized_headers(request)
            + self._get_canonicalized_resource(request)
            + self._get_canonicalized_resource_query(request)
        )
        self._add_authorization_header(request, string_to_sign)
137