1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for
4# license information.
5# --------------------------------------------------------------------------
6
7import logging
8import sys
9
10try:
11    from urllib.parse import urlparse, unquote
12except ImportError:
13    from urlparse import urlparse # type: ignore
14    from urllib2 import unquote # type: ignore
15
16try:
17    from yarl import URL
18except ImportError:
19    pass
20
21try:
22    from azure.core.pipeline.transport import AioHttpTransport
23except ImportError:
24    AioHttpTransport = None
25
26from azure.core.exceptions import ClientAuthenticationError
27from azure.core.pipeline.policies import SansIOHTTPPolicy
28
29from . import sign_string
30
31
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
33
34
35
36# wraps a given exception with the desired exception type
37def _wrap_exception(ex, desired_type):
38    msg = ""
39    if ex.args:
40        msg = ex.args[0]
41    if sys.version_info >= (3,):
42        # Automatic chaining in Python 3 means we keep the trace
43        return desired_type(msg)
44    # There isn't a good solution in 2 for keeping the stack trace
45    # in general, or that will not result in an error in 3
46    # However, we can keep the previous error type and message
47    # TODO: In the future we will log the trace
48    return desired_type('{}: {}'.format(ex.__class__.__name__, msg))
49
50
class AzureSigningError(ClientAuthenticationError):
    """
    Raised when a request cannot be signed, which is treated as a fatal error.

    This is usually caused by user error, for example an account key that is not valid.
    For more information, see
    https://docs.microsoft.com/en-us/azure/storage/common/storage-create-storage-account
    """
57
58
# pylint: disable=no-self-use
class SharedKeyCredentialPolicy(SansIOHTTPPolicy):
    """Pipeline policy that signs every request with a storage account's shared key.

    ``on_request`` concatenates the HTTP verb, a fixed list of standard header
    values, the canonicalized ``x-ms-*`` headers, the canonicalized resource and
    the canonicalized query string into a string-to-sign. It signs that string
    with ``sign_string`` and stores ``SharedKey <account_name>:<signature>`` in the
    request's ``Authorization`` header.

    :param account_name: Storage account name; used in the canonicalized resource
        and in the ``Authorization`` header.
    :param account_key: Key passed to ``sign_string`` to produce the signature.
    """

    # Standard headers whose values appear, one per line and in this exact order,
    # in the Shared Key string-to-sign. The last entry is the HTTP ``Range`` header.
    _SIGNED_HEADERS = (
        'content-encoding', 'content-language', 'content-length',
        'content-md5', 'content-type', 'date', 'if-modified-since',
        'if-match', 'if-none-match', 'if-unmodified-since', 'range',
    )

    def __init__(self, account_name, account_key):
        self.account_name = account_name
        self.account_key = account_key
        super(SharedKeyCredentialPolicy, self).__init__()

    @staticmethod
    def _get_headers(request, headers_to_sign):
        """Return the values of *headers_to_sign*, one per line, each line ending in ``\\n``.

        Header names are matched case-insensitively. Headers with an empty or falsy
        value count as absent, and every absent header contributes an empty line.
        """
        headers = dict((name.lower(), value) for name, value in request.http_request.headers.items() if value)
        # A zero Content-Length must be signed as an empty string, so drop it.
        if headers.get('content-length') == '0':
            del headers['content-length']
        return '\n'.join(headers.get(x, '') for x in headers_to_sign) + '\n'

    @staticmethod
    def _get_verb(request):
        """Return the request's HTTP method followed by a newline."""
        return request.http_request.method + '\n'

    def _get_canonicalized_resource(self, request):
        """Return ``/<account_name><url path>`` for the request.

        An empty URL path is treated as ``/``, which is the path actually sent on
        the wire. When the request goes through ``AioHttpTransport`` (directly, or
        through a wrapper exposing it as ``_transport``), the path is passed
        through ``yarl.URL`` first. Presumably this makes the signed path match
        aiohttp's own path encoding; confirm if this logic is changed.
        """
        uri_path = urlparse(request.http_request.url).path or '/'
        # AioHttpTransport is None when its import failed; skip the aiohttp path then.
        if AioHttpTransport is not None:
            transport = request.context.transport
            if isinstance(transport, AioHttpTransport) or \
                    isinstance(getattr(transport, "_transport", None), AioHttpTransport):
                uri_path = str(URL(uri_path))
        return '/' + self.account_name + uri_path

    @staticmethod
    def _get_canonicalized_headers(request):
        """Return every ``x-ms-*`` header as a ``name:value\\n`` line.

        Header names are matched and emitted in lowercase, because HTTP header
        names are case-insensitive. Lines are sorted by name, and headers whose
        value is ``None`` are skipped.
        """
        x_ms_headers = sorted(
            (name.lower(), value)
            for name, value in request.http_request.headers.items()
            if name.lower().startswith('x-ms-') and value is not None
        )
        return ''.join(name + ':' + value + '\n' for name, value in x_ms_headers)

    @staticmethod
    def _get_canonicalized_resource_query(request):
        """Return the query parameters as ``\\nname:value`` entries.

        Names are lowercased before sorting, so ordering does not depend on the
        caller's casing. Values are URL-decoded with ``unquote``, and parameters
        whose value is ``None`` are omitted.
        """
        sorted_queries = sorted(
            (name.lower(), value)
            for name, value in request.http_request.query.items()
            if value is not None
        )
        return ''.join('\n' + name + ':' + unquote(value) for name, value in sorted_queries)

    def _add_authorization_header(self, request, string_to_sign):
        """Sign *string_to_sign* and set the request's ``Authorization`` header.

        :raises AzureSigningError: if signing fails for any reason.
        """
        try:
            signature = sign_string(self.account_key, string_to_sign)
            auth_string = 'SharedKey ' + self.account_name + ':' + signature
            request.http_request.headers['Authorization'] = auth_string
        except Exception as ex:
            # Wrap any error that occurred as signing error
            # Doing so will clarify/locate the source of problem
            raise _wrap_exception(ex, AzureSigningError)

    def on_request(self, request):
        """Build the Shared Key string-to-sign for *request* and attach the signature."""
        string_to_sign = \
            self._get_verb(request) + \
            self._get_headers(request, self._SIGNED_HEADERS) + \
            self._get_canonicalized_headers(request) + \
            self._get_canonicalized_resource(request) + \
            self._get_canonicalized_resource_query(request)

        self._add_authorization_header(request, string_to_sign)
141