1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the ""Software""), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26"""
This module contains universal policies that work regardless of the HTTPSender implementation.
28"""
29import json
30import logging
31import os
32import xml.etree.ElementTree as ET
33import platform
34import codecs
35import re
36
37from typing import Mapping, Any, Optional, AnyStr, Union, IO, cast, TYPE_CHECKING  # pylint: disable=unused-import
38
39from ..version import msrest_version as _msrest_version
40from . import SansIOHTTPPolicy
41from ..exceptions import DeserializationError, raise_with_traceback
42from ..http_logger import log_request, log_response
43
44if TYPE_CHECKING:
45    from . import Request, Response  # pylint: disable=unused-import
46
47
# Logger shared by every policy defined in this module.
_LOGGER = logging.getLogger(__name__)

# The UTF-8 byte order mark as text (a single U+FEFF character); it is
# stripped from the start of str payloads before they are parsed.
_BOM = codecs.BOM_UTF8.decode('utf-8')
51
52
class HeadersPolicy(SansIOHTTPPolicy):
    """Policy that adds a fixed set of headers to every outgoing request.

    A header given here replaces any header with the same name that is
    already present on the request.
    """
    def __init__(self, headers):
        # type: (Mapping[str, str]) -> None
        """
        :param headers: Header names and values to set on each request.
        """
        self.headers = headers

    def on_request(self, request, **kwargs):
        # type: (Request, Any) -> None
        """Copy the configured headers onto the request, overwriting existing values."""
        outgoing = request.http_request
        outgoing.headers.update(self.headers)
67
class UserAgentPolicy(SansIOHTTPPolicy):
    """Policy that sets the User-Agent header on requests.

    When no explicit user agent is given, a default one is built from the
    Python version, the platform description and the msrest version.
    If the AZURE_HTTP_USER_AGENT environment variable is set when the policy
    is created, its value is appended (space separated) in every case.
    """
    _USERAGENT = "User-Agent"
    _ENV_ADDITIONAL_USER_AGENT = 'AZURE_HTTP_USER_AGENT'

    def __init__(self, user_agent=None, overwrite=False):
        # type: (Optional[str], bool) -> None
        """
        :param str user_agent: Explicit user agent; None builds the default one.
        :param bool overwrite: If True, replace a User-Agent header that is
         already set on the request; otherwise leave it untouched.
        """
        self._overwrite = overwrite
        if user_agent is not None:
            self._user_agent = user_agent
        else:
            self._user_agent = "python/{} ({}) msrest/{}".format(
                platform.python_version(), platform.platform(), _msrest_version
            )

        # The environment suffix is applied whether or not the caller
        # provided an explicit user agent.
        env_suffix = os.environ.get(self._ENV_ADDITIONAL_USER_AGENT)
        if env_suffix is not None:
            self.add_user_agent(env_suffix)

    @property
    def user_agent(self):
        # type: () -> str
        """The current user agent value."""
        return self._user_agent

    def add_user_agent(self, value):
        # type: (str) -> None
        """Append a value to the current user agent, separated by a space.

        :param str value: value to add to user agent.
        """
        self._user_agent = "{} {}".format(self._user_agent, value)

    def on_request(self, request, **kwargs):
        # type: (Request, Any) -> None
        """Set the User-Agent header unless one exists and overwriting is disabled."""
        target = request.http_request
        already_set = self._USERAGENT in target.headers
        if already_set and not self._overwrite:
            return
        target.headers[self._USERAGENT] = self._user_agent
109
class HTTPLogger(SansIOHTTPPolicy):
    """Policy that logs HTTP requests and responses via the http_logger helpers.

    Logging is controlled by the ``enable_http_logger`` constructor argument.
    A per-call ``enable_http_logger`` keyword argument, when present,
    takes precedence over it.
    """
    def __init__(self, enable_http_logger=False):
        self.enable_http_logger = enable_http_logger

    def _logging_enabled(self, call_kwargs):
        # type: (Mapping[str, Any]) -> Any
        """Per-call "enable_http_logger" wins over the policy-wide setting."""
        return call_kwargs.get("enable_http_logger", self.enable_http_logger)

    def on_request(self, request, **kwargs):
        # type: (Request, Any) -> None
        """Log the outgoing request when logging is enabled."""
        outgoing = request.http_request
        if not self._logging_enabled(kwargs):
            return
        log_request(None, outgoing)

    def on_response(self, request, response, **kwargs):
        # type: (Request, Response, Any) -> None
        """Log the received response (with its request) when logging is enabled."""
        outgoing = request.http_request
        if not self._logging_enabled(kwargs):
            return
        log_response(None, outgoing, response.http_response, result=response)
129
130
class RawDeserializer(SansIOHTTPPolicy):
    """Policy that deserializes the raw body of a response as JSON or XML.

    The parsed result is stored in ``response.context`` under
    ``CONTEXT_NAME``.
    """

    # Accept "text" because we're open minded people...
    # The optional structured-syntax prefix (e.g. "vnd.api+" or "merge-patch+")
    # may contain digits and hyphens, as media type names allow.
    JSON_REGEXP = re.compile(r'^(application|text)/([0-9a-z+.-]+\+)?json$')

    # Name used in context
    CONTEXT_NAME = "deserialized_data"

    @classmethod
    def deserialize_from_text(cls, data, content_type=None):
        # type: (Optional[Union[AnyStr, IO]], Optional[str]) -> Any
        """Decode data according to content-type.

        Accept a stream of data as well, but it will be loaded at once in memory for now.

        If no content-type is given, return the text version of the data
        (bytes are decoded as UTF-8 and a leading BOM is removed; never bytes, never a stream).

        :param data: Input, could be bytes or stream (will be decoded with UTF8) or text
        :type data: str or bytes or IO
        :param str content_type: The content type.
        :raises DeserializationError: If the payload is not valid for the given
         content type, or the content type is neither JSON nor XML.
        """
        if hasattr(data, 'read'):
            # Assume a stream
            data = cast(IO, data).read()

        if isinstance(data, bytes):
            # "utf-8-sig" silently drops a leading BOM if present
            data_as_str = data.decode(encoding='utf-8-sig')
        else:
            # Explain to mypy the correct type.
            data_as_str = cast(str, data)

            # Remove Byte Order Mark if present in string
            data_as_str = data_as_str.lstrip(_BOM)

        if content_type is None:
            # Documented contract: the decoded text, not the raw bytes/stream content.
            return data_as_str

        if cls.JSON_REGEXP.match(content_type):
            try:
                return json.loads(data_as_str)
            except ValueError as err:
                raise DeserializationError("JSON is invalid: {}".format(err), err)
        elif "xml" in content_type:
            try:

                try:
                    if isinstance(data, unicode):  # type: ignore
                        # If I'm Python 2.7 and unicode XML will scream if I try a "fromstring" on unicode string
                        data_as_str = data_as_str.encode(encoding="utf-8")  # type: ignore
                except NameError:
                    pass

                # NOTE(security): the stdlib XML parser is not hardened against
                # malicious payloads (e.g. entity expansion); only parse
                # responses from trusted services.
                return ET.fromstring(data_as_str)
            except ET.ParseError:
                # It might be because the server has an issue, and returned JSON with
                # content-type XML....
                # So let's try a JSON load, and if it's still broken
                # let's flow the initial exception
                def _json_attempt(text):
                    try:
                        return True, json.loads(text)
                    except ValueError:
                        return False, None  # Don't care about this one
                # Use the decoded, BOM-stripped text: json.loads rejects a str
                # that still starts with a BOM.
                success, json_result = _json_attempt(data_as_str)
                if success:
                    return json_result
                # If i'm here, it's not JSON, it's not XML, let's scream
                # and raise the last context in this block (the XML exception)
                # The function hack is because Py2.7 messes up with exception
                # context otherwise.
                _LOGGER.critical("Wasn't XML not JSON, failing")
                raise_with_traceback(DeserializationError, "XML is invalid")
        raise DeserializationError("Cannot deserialize content-type: {}".format(content_type))

    @classmethod
    def deserialize_from_http_generics(cls, body_bytes, headers):
        # type: (Optional[Union[AnyStr, IO]], Mapping) -> Any
        """Deserialize from HTTP response.

        Use bytes and headers to NOT use any requests/aiohttp or whatever
        specific implementation.
        Headers will be tested for "content-type"; parameters such as
        "; charset=utf-8" are dropped and the media type is lowercased.

        :param body_bytes: The response body.
        :param headers: The response headers.
        :return: The deserialized body, or None if the body is empty.
        :raises DeserializationError: If the body cannot be deserialized.
        """
        if 'content-type' in headers:
            content_type = headers['content-type'].split(";")[0].strip().lower()
        else:
            # Ouch, this server did not declare what it sent...
            # Let's guess it's JSON...
            content_type = "application/json"

        # Since Autorest was considering that an empty body was a valid JSON,
        # an empty body deserializes to None instead of raising.
        if body_bytes:
            return cls.deserialize_from_text(body_bytes, content_type)
        return None

    def on_response(self, request, response, **kwargs):
        # type: (Request, Response, Any) -> None
        """Extract data from the body of a REST response object.

        This will load the entire payload in memory.

        Will follow Content-Type to parse.
        We assume everything is UTF8 (BOM acceptable).

        Deserialization only happens when the caller explicitly passes
        ``stream=False``; a missing "stream" keyword is treated as streaming.
        The result is stored in ``response.context[CONTEXT_NAME]``.

        :param request: The pipeline request.
        :param response: The pipeline response whose body is parsed.
        :raises DeserializationError: If the body is not valid for its declared
         content type.
        :raises UnicodeDecodeError: If the body is bytes that are not UTF8.
        """
        # If response was asked as stream, do NOT read anything and quit now
        if kwargs.get("stream", True):
            return

        http_response = response.http_response

        response.context[self.CONTEXT_NAME] = self.deserialize_from_http_generics(
            http_response.text(),
            http_response.headers
        )
254