1from __future__ import absolute_import, print_function, division, unicode_literals
2
3import _io
4import inspect
5import json as json_module
6import logging
7import re
8import six
9
10from collections import namedtuple
11from functools import update_wrapper
12from requests.adapters import HTTPAdapter
13from requests.exceptions import ConnectionError
14from requests.sessions import REDIRECT_STATI
15from requests.utils import cookiejar_from_dict
16
17try:
18    from collections.abc import Sequence, Sized
19except ImportError:
20    from collections import Sequence, Sized
21
22try:
23    from requests.packages.urllib3.response import HTTPResponse
24except ImportError:
25    from urllib3.response import HTTPResponse
26
27if six.PY2:
28    from urlparse import urlparse, parse_qsl, urlsplit, urlunsplit
29    from urllib import quote
30else:
31    from urllib.parse import urlparse, parse_qsl, urlsplit, urlunsplit, quote
32
33if six.PY2:
34    try:
35        from six import cStringIO as BufferIO
36    except ImportError:
37        from six import StringIO as BufferIO
38else:
39    from io import BytesIO as BufferIO
40
41try:
42    from unittest import mock as std_mock
43except ImportError:
44    import mock as std_mock
45
# Type of compiled regular expressions. Registered URLs of this type are
# matched with ``Pattern.match`` instead of string comparison.
# ``type(re.compile(""))`` is the class exposed as ``re._pattern_type``
# before Python 3.7 and as ``re.Pattern`` from 3.7 onwards.
Pattern = type(re.compile(""))

# Sentinel meaning "argument not supplied" (used for ``Response``'s content_type).
UNSET = object()

# One recorded interaction: the request and the response or exception it produced.
Call = namedtuple("Call", "request response")

# The unpatched ``HTTPAdapter.send``, captured at import time, so that passthrough
# requests can still be sent for real while the adapter is patched.
_real_send = HTTPAdapter.send

logger = logging.getLogger("responses")
59
60
def _is_string(s):
    """Return True if ``s`` is an instance of ``six.string_types``.

    That is ``basestring`` on Python 2 and ``str`` on Python 3.
    """
    string_types = six.string_types
    return isinstance(s, string_types)
63
64
65def _has_unicode(s):
66    return any(ord(char) > 128 for char in s)
67
68
def _clean_unicode(url):
    """Return ``url`` with non-ASCII characters converted to ASCII-safe forms.

    Host labels that contain non-ASCII characters are punycode-encoded with an
    ``xn--`` prefix. Every remaining non-ASCII character (code point >= 128) is
    percent-encoded with ``quote``.

    :param url: URL string that may contain non-ASCII characters.
    :returns: the cleaned URL.
    """
    # Clean up domain names, which use punycode to handle unicode chars
    urllist = list(urlsplit(url))
    netloc = urllist[1]
    if _has_unicode(netloc):
        domains = netloc.split(".")
        for i, d in enumerate(domains):
            if _has_unicode(d):
                domains[i] = "xn--" + d.encode("punycode").decode("ascii")
        urllist[1] = ".".join(domains)
        url = urlunsplit(urllist)

    # Clean up path/query/params, which use url-encoding to handle unicode chars.
    # On Python 2, ``encode`` yields a byte ``str`` (a string type), so the URL is
    # processed byte by byte. On Python 3 it yields ``bytes`` (not a string type),
    # so the URL stays text and ``quote`` UTF-8-encodes each character.
    if isinstance(url.encode("utf8"), six.string_types):
        url = url.encode("utf8")

    # ASCII ends at 127; 128 itself must be escaped as well.
    return "".join(quote(char) if ord(char) > 127 else char for char in url)
91
92
93def _is_redirect(response):
94    try:
95        # 2.0.0 <= requests <= 2.2
96        return response.is_redirect
97
98    except AttributeError:
99        # requests > 2.2
100        return (
101            # use request.sessions conditional
102            response.status_code in REDIRECT_STATI
103            and "location" in response.headers
104        )
105
106
def _cookies_from_headers(headers):
    """Parse the ``set-cookie`` entry of ``headers`` into a cookie jar.

    Uses ``http.cookies.SimpleCookie`` when it can be imported. Otherwise
    (Python 2) it falls back to the third-party ``cookies`` package.

    :raises KeyError: if ``headers`` has no ``set-cookie`` entry.
    :returns: the jar built by ``requests.utils.cookiejar_from_dict``.
    """
    try:
        from http.cookies import SimpleCookie

        parsed = SimpleCookie()
        parsed.load(headers["set-cookie"])
        jar_contents = dict((name, morsel.value) for name, morsel in parsed.items())
    except ImportError:
        from cookies import Cookies

        parsed = Cookies.from_request(headers["set-cookie"])
        jar_contents = dict((cookie.name, cookie.value) for _, cookie in parsed.items())
    return cookiejar_from_dict(jar_contents)
121
122
# Source template for the wrapper function generated by ``get_wrapped``.
# ``%(wrapper_args)s`` is the wrapper's parameter list (including defaults).
# ``%(func_args)s`` is the argument list used to call ``func``.
# The call runs inside ``with responses:``, so the mock is active for its duration.
_wrapper_template = """\
def wrapper%(wrapper_args)s:
    with responses:
        return func%(func_args)s
"""
128
129
def _format_call_args(parameters):
    """Render a call-site argument list that forwards ``parameters`` unchanged.

    ``str(signature)`` is only valid in a ``def``. Its bare ``*`` and ``/``
    markers are syntax errors in a call. A keyword-only parameter placed after
    ``*args`` would also be passed positionally and swallowed by ``*args``.
    Here, keyword-only parameters are passed as ``name=name`` and variadic
    ones keep their ``*``/``**`` prefix.
    """
    rendered = []
    for param in parameters:
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            rendered.append("*" + param.name)
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            rendered.append("**" + param.name)
        elif param.kind == inspect.Parameter.KEYWORD_ONLY:
            rendered.append("{0}={0}".format(param.name))
        else:
            # positional-only and positional-or-keyword: pass positionally
            rendered.append(param.name)
    return "(" + ", ".join(rendered) + ")"


def get_wrapped(func, responses):
    """Return a wrapper that calls ``func`` inside ``with responses:``.

    The wrapper is generated from ``_wrapper_template`` so that it declares the
    same parameters as ``func`` (with annotations dropped). Tools such as pytest
    inspect that signature for fixture injection. Metadata is copied from
    ``func`` with ``functools.update_wrapper``.

    NOTE(review): parameter defaults are written into the generated source via
    ``repr()``. A default whose repr is not valid Python source will make the
    generated code fail to compile.

    :param func: the callable to decorate.
    :param responses: object used as the context manager around each call
        (the ``RequestsMock`` instance when called from ``activate``).
    :returns: the generated wrapper function.
    """
    if six.PY2:
        args, a, kw, defaults = inspect.getargspec(func)
        wrapper_args = inspect.formatargspec(args, a, kw, defaults)

        # Preserve the argspec for the wrapped function so that testing
        # tools such as pytest can continue to use their fixture injection.
        if hasattr(func, "__self__"):
            args = args[1:]  # Omit 'self'
        func_args = inspect.formatargspec(args, a, kw, None)
    else:
        signature = inspect.signature(func)
        signature = signature.replace(return_annotation=inspect.Signature.empty)
        # If the function is wrapped, switch to *args, **kwargs for the parameters
        # as we can't rely on the signature to give us the arguments the function will
        # be called with. For example unittest.mock.patch uses required args that are
        # not actually passed to the function when invoked.
        if hasattr(func, "__wrapped__"):
            wrapper_params = [
                inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL),
                inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD),
            ]
        else:
            wrapper_params = [
                param.replace(annotation=inspect.Parameter.empty)
                for param in signature.parameters.values()
            ]
        signature = signature.replace(parameters=wrapper_params)

        # The ``def`` line keeps defaults; the call site forwards every
        # parameter in a form that is valid call syntax.
        wrapper_args = str(signature)
        func_args = _format_call_args(signature.parameters.values())

    evaldict = {"func": func, "responses": responses}
    six.exec_(
        _wrapper_template % {"wrapper_args": wrapper_args, "func_args": func_args},
        evaldict,
    )
    wrapper = evaldict["wrapper"]
    update_wrapper(wrapper, func)
    return wrapper
177
178
class CallList(Sequence, Sized):
    """Ordered record of ``Call(request, response)`` pairs seen by a mock.

    Exposes the read-only sequence protocol; ``add`` and ``reset`` are the
    mutating operations.
    """

    def __init__(self):
        self._calls = []

    def __len__(self):
        return len(self._calls)

    def __iter__(self):
        return iter(self._calls)

    def __getitem__(self, idx):
        # Delegates straight to the list, so slicing returns a plain list.
        return self._calls[idx]

    def add(self, request, response):
        """Record one request together with its response (or exception)."""
        entry = Call(request=request, response=response)
        self._calls.append(entry)

    def reset(self):
        """Drop every recorded call by rebinding a fresh, empty list."""
        self._calls = list()
197
198
def _ensure_url_default_path(url):
    """Return ``url`` with an empty path replaced by ``/``.

    Non-string URLs (e.g. compiled regular expressions) are returned as-is.
    String URLs are always round-tripped through ``urlsplit``/``urlunsplit``.
    """
    if not _is_string(url):
        return url
    scheme, netloc, path, query, fragment = urlsplit(url)
    return urlunsplit((scheme, netloc, path or "/", query, fragment))
206
207
def _handle_body(body):
    """Return a file-like object for ``body``.

    An ``_io.BufferedReader`` is returned unchanged. Text is UTF-8 encoded,
    then wrapped (like any other value) in ``BufferIO``.
    """
    if isinstance(body, _io.BufferedReader):
        return body
    if isinstance(body, six.text_type):
        body = body.encode("utf-8")
    return BufferIO(body)
215
216
# Sentinel default for ``match_querystring``. It lets BaseResponse tell "not
# passed" apart from an explicit True/False and derive a default from the URL.
_unspecified = object()
218
219
class BaseResponse(object):
    """Base class for registered responses.

    Stores the method and URL a response is registered for and decides whether
    an incoming request matches it. Subclasses implement ``get_response``.
    """

    content_type = None
    headers = None

    stream = False

    def __init__(self, method, url, match_querystring=_unspecified):
        """
        :param method: HTTP method the request must use (compared with ``==``).
        :param url: URL string or compiled regular expression to match against.
        :param match_querystring: whether the query string must match as well.
            When omitted it is derived from ``url`` (see
            ``_should_match_querystring``).
        """
        self.method = method
        # ensure the url has a default path set if the url is a string
        self.url = _ensure_url_default_path(url)
        self.match_querystring = self._should_match_querystring(match_querystring)
        # Incremented by RequestsMock each time this response is served or raised.
        self.call_count = 0

    def __eq__(self, other):
        """Responses are equal when their method and URL (or regex pattern) agree.

        RequestsMock relies on this in ``remove`` and ``replace``.
        """
        if not isinstance(other, BaseResponse):
            return False

        if self.method != other.method:
            return False

        # Can't simply do an equality check on the objects directly here since
        # __eq__ isn't implemented for regex. It might seem to work as regex is
        # using a cache to return the same regex instances, but it doesn't in
        # all cases.
        self_url = self.url.pattern if isinstance(self.url, Pattern) else self.url
        other_url = other.url.pattern if isinstance(other.url, Pattern) else other.url

        return self_url == other_url

    def __ne__(self, other):
        return not self.__eq__(other)

    def _url_matches_strict(self, url, other):
        """Return True if both URLs have the same scheme, host and path and
        carry the same query parameters, ignoring parameter order."""
        url_parsed = urlparse(url)
        other_parsed = urlparse(other)

        # scheme, netloc and path must be identical
        if url_parsed[:3] != other_parsed[:3]:
            return False

        # Sorting makes the comparison order-insensitive while still counting
        # repeated keys. Blank values are dropped by parse_qsl's defaults.
        url_qsl = sorted(parse_qsl(url_parsed.query))
        other_qsl = sorted(parse_qsl(other_parsed.query))
        return url_qsl == other_qsl

    def _should_match_querystring(self, match_querystring_argument):
        """Resolve the effective ``match_querystring`` setting.

        An explicitly passed argument wins. Otherwise regex URLs default to
        False, and string URLs match the query string only if they have one.
        """
        if match_querystring_argument is not _unspecified:
            return match_querystring_argument

        if isinstance(self.url, Pattern):
            # the old default from <= 0.9.0
            return False

        return bool(urlparse(self.url).query)

    def _url_matches(self, url, other, match_querystring=False):
        """Return True if the registered ``url`` matches the request URL ``other``.

        String URLs containing non-ASCII characters are normalised with
        ``_clean_unicode`` first. Without ``match_querystring``, only the part
        before ``?`` is compared. Regex URLs must match at the start of
        ``other`` (``Pattern.match``). Any other kind of ``url`` never matches.
        """
        if _is_string(url):
            if _has_unicode(url):
                url = _clean_unicode(url)
                if not isinstance(other, six.text_type):
                    other = other.encode("ascii").decode("utf8")
            if match_querystring:
                return self._url_matches_strict(url, other)

            url_without_qs = url.split("?", 1)[0]
            other_without_qs = other.split("?", 1)[0]
            return url_without_qs == other_without_qs

        return isinstance(url, Pattern) and bool(url.match(other))

    def get_headers(self):
        """Return a new dict of response headers.

        ``Content-Type`` comes from ``content_type`` when it is set, and is
        then overridden or extended by ``headers``.
        """
        headers = {}
        if self.content_type is not None:
            headers["Content-Type"] = self.content_type
        if self.headers:
            headers.update(self.headers)
        return headers

    def get_response(self, request):
        """Build the response for ``request``; implemented by subclasses."""
        raise NotImplementedError

    def matches(self, request):
        """Return True if ``request`` uses this response's method and a matching URL."""
        if request.method != self.method:
            return False

        return bool(self._url_matches(self.url, request.url, self.match_querystring))
322
323
class Response(BaseResponse):
    """A registered response with a fixed status, headers and body.

    ``body`` may be text (stored UTF-8 encoded), bytes, an
    ``_io.BufferedReader``, or an Exception instance, which is raised instead
    of producing a response.

    Passing ``json`` serialises it into the body and defaults the content type
    to ``application/json``. Without it, the content type defaults to
    ``text/plain``. Remaining keyword arguments go to ``BaseResponse``.
    """

    def __init__(
        self,
        method,
        url,
        body="",
        json=None,
        status=200,
        headers=None,
        stream=False,
        content_type=UNSET,
        **kwargs
    ):
        if json is None:
            default_content_type = "text/plain"
        else:
            # A JSON payload replaces the body, so both must not be supplied.
            assert not body
            body = json_module.dumps(json)
            default_content_type = "application/json"

        if content_type is UNSET:
            content_type = default_content_type

        # Text bodies are stored as UTF-8 bytes.
        if isinstance(body, six.text_type):
            body = body.encode("utf-8")

        self.body = body
        self.status = status
        self.headers = headers
        self.stream = stream
        self.content_type = content_type
        super(Response, self).__init__(method, url, **kwargs)

    def get_response(self, request):
        """Build an unread ``HTTPResponse`` from the stored status/headers/body.

        :raises Exception: the stored body, when it is an Exception instance.
        """
        stored_body = self.body
        if stored_body and isinstance(stored_body, Exception):
            raise stored_body

        response_headers = self.get_headers()
        wrapped_body = _handle_body(stored_body)

        return HTTPResponse(
            status=self.status,
            reason=six.moves.http_client.responses.get(self.status),
            body=wrapped_body,
            headers=response_headers,
            preload_content=False,
        )
374
375
class CallbackResponse(BaseResponse):
    """A registered response whose content is computed per request.

    ``callback(request)`` must return ``(status, headers, body)`` or an
    Exception instance, which is raised. A returned ``body`` that is an
    Exception instance is raised too.
    """

    def __init__(
        self, method, url, callback, stream=False, content_type="text/plain", **kwargs
    ):
        self.callback = callback
        self.stream = stream
        self.content_type = content_type
        super(CallbackResponse, self).__init__(method, url, **kwargs)

    def get_response(self, request):
        """Invoke the callback and wrap its result in an unread ``HTTPResponse``.

        Headers returned by the callback override the defaults from
        ``get_headers``.
        """
        merged_headers = self.get_headers()

        outcome = self.callback(request)
        if isinstance(outcome, Exception):
            raise outcome

        status, callback_headers, raw_body = outcome
        if isinstance(raw_body, Exception):
            raise raw_body

        wrapped_body = _handle_body(raw_body)
        merged_headers.update(callback_headers)

        return HTTPResponse(
            status=status,
            reason=six.moves.http_client.responses.get(status),
            body=wrapped_body,
            headers=merged_headers,
            preload_content=False,
        )
406
407
class RequestsMock(object):
    """Intercept ``requests`` traffic and answer it from registered responses.

    While active (via ``start``/``stop``, the context-manager protocol, or the
    ``activate`` decorator), the callable named by ``target`` is patched so
    that every outgoing request is matched against the registered responses.
    """

    DELETE = "DELETE"
    GET = "GET"
    HEAD = "HEAD"
    OPTIONS = "OPTIONS"
    PATCH = "PATCH"
    POST = "POST"
    PUT = "PUT"
    response_callback = None

    def __init__(
        self,
        assert_all_requests_are_fired=True,
        response_callback=None,
        passthru_prefixes=(),
        target="requests.adapters.HTTPAdapter.send",
    ):
        """
        :param assert_all_requests_are_fired: if True, ``stop()`` raises
            ``AssertionError`` when a registered response was never used.
        :param response_callback: optional callable invoked with every response,
            or error, produced by the mock (see ``_on_request``).
        :param passthru_prefixes: URL prefixes that are sent through the real
            adapter when no registered response matches.
        :param target: dotted path of the callable replaced by ``start()``.
        """
        self._calls = CallList()
        self.reset()
        self.assert_all_requests_are_fired = assert_all_requests_are_fired
        self.response_callback = response_callback
        self.passthru_prefixes = tuple(passthru_prefixes)
        self.target = target

    def reset(self):
        """Forget all registered responses and recorded calls."""
        self._matches = []
        self._calls.reset()

    def add(
        self,
        method=None,  # method or ``Response``
        url=None,
        body="",
        adding_headers=None,
        *args,
        **kwargs
    ):
        """
        A basic request:

        >>> responses.add(responses.GET, 'http://example.com')

        You can also directly pass an object which implements the
        ``BaseResponse`` interface:

        >>> responses.add(Response(...))

        A JSON payload:

        >>> responses.add(
        >>>     method='GET',
        >>>     url='http://example.com',
        >>>     json={'foo': 'bar'},
        >>> )

        Custom headers:

        >>> responses.add(
        >>>     method='GET',
        >>>     url='http://example.com',
        >>>     headers={'X-Header': 'foo'},
        >>> )


        Strict query string matching:

        >>> responses.add(
        >>>     method='GET',
        >>>     url='http://example.com?foo=bar',
        >>>     match_querystring=True
        >>> )

        ``adding_headers`` is used as ``headers`` unless ``headers`` is also
        given. Extra positional arguments are accepted but ignored; remaining
        keyword arguments are passed to ``Response``.
        """
        if isinstance(method, BaseResponse):
            self._matches.append(method)
            return

        if adding_headers is not None:
            kwargs.setdefault("headers", adding_headers)

        self._matches.append(Response(method=method, url=url, body=body, **kwargs))

    def add_passthru(self, prefix):
        """
        Register a URL prefix to passthru any non-matching mock requests to.

        For example, to allow any request to 'https://example.com', but require
        mocks for the remainder, you would add the prefix as so:

        >>> responses.add_passthru('https://example.com')
        """
        if _has_unicode(prefix):
            prefix = _clean_unicode(prefix)
        self.passthru_prefixes += (prefix,)

    def remove(self, method_or_response=None, url=None):
        """
        Removes a response previously added using ``add()``, identified
        either by a response object inheriting ``BaseResponse`` or
        ``method`` and ``url``. Removes all matching responses.

        >>> responses.add(responses.GET, 'http://example.org')
        >>> responses.remove(responses.GET, 'http://example.org')
        """
        if isinstance(method_or_response, BaseResponse):
            response = method_or_response
        else:
            response = BaseResponse(method=method_or_response, url=url)

        # Equality is BaseResponse.__eq__ (method + URL/pattern).
        while response in self._matches:
            self._matches.remove(response)

    def replace(self, method_or_response=None, url=None, body="", *args, **kwargs):
        """
        Replaces a response previously added using ``add()``. The signature
        is identical to ``add()``. The response is identified using ``method``
        and ``url``, and the first matching response is replaced.

        >>> responses.add(responses.GET, 'http://example.org', json={'data': 1})
        >>> responses.replace(responses.GET, 'http://example.org', json={'data': 2})

        :raises ValueError: if no registered response has that method and URL.
        """
        if isinstance(method_or_response, BaseResponse):
            response = method_or_response
        else:
            response = Response(method=method_or_response, url=url, body=body, **kwargs)

        index = self._matches.index(response)
        self._matches[index] = response

    def add_callback(
        self, method, url, callback, match_querystring=False, content_type="text/plain"
    ):
        """Register a ``CallbackResponse`` whose content is computed per request.

        ``callback(request)`` must return ``(status, headers, body)`` or an
        Exception instance to raise. Default-path handling of ``url`` is done
        by ``BaseResponse.__init__``.
        """
        self._matches.append(
            CallbackResponse(
                url=url,
                method=method,
                callback=callback,
                content_type=content_type,
                match_querystring=match_querystring,
            )
        )

    @property
    def calls(self):
        """The ``CallList`` of requests seen since the last ``reset()``."""
        return self._calls

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        """Stop patching and reset state.

        Unfired responses are only asserted when the block exited normally.
        Returns False when an exception is in flight, so it propagates.
        """
        success = type is None
        self.stop(allow_assert=success)
        self.reset()
        return success

    def activate(self, func):
        """Decorator: run ``func`` with this mock active (see ``get_wrapped``)."""
        return get_wrapped(func, self)

    def _find_match(self, request):
        """Return the first registered response matching ``request``, or None.

        When more than one registered response matches, the first one is
        removed from the registry before being returned, so the next identical
        request is served by a later response. A sole match stays registered.
        """
        found = None
        found_match = None
        for i, match in enumerate(self._matches):
            if match.matches(request):
                if found is None:
                    found = i
                    found_match = match
                else:
                    # Multiple matches found.  Remove & return the first match.
                    return self._matches.pop(found)

        return found_match

    def _on_request(self, adapter, request, *args, **kwargs):
        """Serve ``request`` in place of the patched send.

        ``args``/``kwargs`` are whatever was passed to the patched send after
        ``request``; they are only used when the request is passed through to
        the real ``HTTPAdapter.send``.

        :raises ConnectionError: when nothing matches and the URL is not under a
            passthru prefix (or whatever ``response_callback`` returns for it).
        """
        match = self._find_match(request)
        resp_callback = self.response_callback

        if match is None:
            if request.url.startswith(self.passthru_prefixes):
                logger.info("request.allowed-passthru", extra={"url": request.url})
                return _real_send(adapter, request, *args, **kwargs)

            error_msg = (
                "Connection refused by Responses: {0} {1} doesn't "
                "match Responses Mock".format(request.method, request.url)
            )
            response = ConnectionError(error_msg)
            response.request = request

            self._calls.add(request, response)
            response = resp_callback(response) if resp_callback else response
            raise response

        try:
            response = adapter.build_response(request, match.get_response(request))
        except Exception as exc:
            match.call_count += 1
            self._calls.add(request, exc)
            # The callback is notified, but the original exception is re-raised
            # unchanged (its return value is not used here).
            if resp_callback:
                resp_callback(exc)
            raise

        if not match.stream:
            response.content  # NOQA  (reads the body eagerly)

        try:
            response.cookies = _cookies_from_headers(response.headers)
        except (KeyError, TypeError):
            # No parsable set-cookie header: keep the adapter's cookies.
            pass

        response = resp_callback(response) if resp_callback else response
        match.call_count += 1
        self._calls.add(request, response)
        return response

    def start(self):
        """Patch ``self.target`` so that sends are routed to ``_on_request``."""

        def unbound_on_send(adapter, request, *a, **kwargs):
            return self._on_request(adapter, request, *a, **kwargs)

        self._patcher = std_mock.patch(target=self.target, new=unbound_on_send)
        self._patcher.start()

    def stop(self, allow_assert=True):
        """Undo the patch applied by ``start()``.

        :raises AssertionError: if ``assert_all_requests_are_fired`` and
            ``allow_assert`` are both true and some registered response has
            ``call_count == 0``.
        """
        self._patcher.stop()
        if not self.assert_all_requests_are_fired:
            return

        if not allow_assert:
            return

        not_called = [m for m in self._matches if m.call_count == 0]
        if not_called:
            raise AssertionError(
                "Not all requests have been executed {0!r}".format(
                    [(match.method, match.url) for match in not_called]
                )
            )
646
647
# expose default mock namespace
# A module-wide RequestsMock that, unlike a fresh ``RequestsMock()``, does not
# assert that every registered response fired. Its public attributes are
# re-exported below as module-level names.
mock = _default_mock = RequestsMock(assert_all_requests_are_fired=False)
__all__ = ["CallbackResponse", "Response", "RequestsMock"]
# Copy each public attribute of the default mock into this module's globals and
# public API. Values are captured once, at import time; methods are bound to
# ``_default_mock``, so module-level calls act on that instance.
# NOTE(review): the loop variable ``__attr`` remains defined at module level.
for __attr in (a for a in dir(_default_mock) if not a.startswith("_")):
    __all__.append(__attr)
    globals()[__attr] = getattr(_default_mock, __attr)
654