1from __future__ import absolute_import, print_function, division, unicode_literals 2 3import _io 4import inspect 5import json as json_module 6import logging 7import re 8import six 9 10from collections import namedtuple 11from functools import update_wrapper 12from requests.adapters import HTTPAdapter 13from requests.exceptions import ConnectionError 14from requests.sessions import REDIRECT_STATI 15from requests.utils import cookiejar_from_dict 16 17try: 18 from collections.abc import Sequence, Sized 19except ImportError: 20 from collections import Sequence, Sized 21 22try: 23 from requests.packages.urllib3.response import HTTPResponse 24except ImportError: 25 from urllib3.response import HTTPResponse 26 27if six.PY2: 28 from urlparse import urlparse, parse_qsl, urlsplit, urlunsplit 29 from urllib import quote 30else: 31 from urllib.parse import urlparse, parse_qsl, urlsplit, urlunsplit, quote 32 33if six.PY2: 34 try: 35 from six import cStringIO as BufferIO 36 except ImportError: 37 from six import StringIO as BufferIO 38else: 39 from io import BytesIO as BufferIO 40 41try: 42 from unittest import mock as std_mock 43except ImportError: 44 import mock as std_mock 45 46try: 47 Pattern = re._pattern_type 48except AttributeError: 49 # Python 3.7 50 Pattern = re.Pattern 51 52UNSET = object() 53 54Call = namedtuple("Call", ["request", "response"]) 55 56_real_send = HTTPAdapter.send 57 58logger = logging.getLogger("responses") 59 60 61def _is_string(s): 62 return isinstance(s, six.string_types) 63 64 65def _has_unicode(s): 66 return any(ord(char) > 128 for char in s) 67 68 69def _clean_unicode(url): 70 # Clean up domain names, which use punycode to handle unicode chars 71 urllist = list(urlsplit(url)) 72 netloc = urllist[1] 73 if _has_unicode(netloc): 74 domains = netloc.split(".") 75 for i, d in enumerate(domains): 76 if _has_unicode(d): 77 d = "xn--" + d.encode("punycode").decode("ascii") 78 domains[i] = d 79 urllist[1] = ".".join(domains) 80 url = urlunsplit(urllist) 81 82 # Clean up path/query/params, which use url-encoding to handle unicode chars 83 if isinstance(url.encode("utf8"), six.string_types): 84 url = url.encode("utf8") 85 chars = list(url) 86 for i, x in enumerate(chars): 87 if ord(x) > 128: 88 chars[i] = quote(x) 89 90 return "".join(chars) 91 92 93def _is_redirect(response): 94 try: 95 # 2.0.0 <= requests <= 2.2 96 return response.is_redirect 97 98 except AttributeError: 99 # requests > 2.2 100 return ( 101 # use request.sessions conditional 102 response.status_code in REDIRECT_STATI 103 and "location" in response.headers 104 ) 105 106 107def _cookies_from_headers(headers): 108 try: 109 import http.cookies as cookies 110 111 resp_cookie = cookies.SimpleCookie() 112 resp_cookie.load(headers["set-cookie"]) 113 114 cookies_dict = {name: v.value for name, v in resp_cookie.items()} 115 except ImportError: 116 from cookies import Cookies 117 118 resp_cookies = Cookies.from_request(headers["set-cookie"]) 119 cookies_dict = {v.name: v.value for _, v in resp_cookies.items()} 120 return cookiejar_from_dict(cookies_dict) 121 122 123_wrapper_template = """\ 124def wrapper%(wrapper_args)s: 125 with responses: 126 return func%(func_args)s 127""" 128 129 130def get_wrapped(func, responses): 131 if six.PY2: 132 args, a, kw, defaults = inspect.getargspec(func) 133 wrapper_args = inspect.formatargspec(args, a, kw, defaults) 134 135 # Preserve the argspec for the wrapped function so that testing 136 # tools such as pytest can continue to use their fixture injection. 137 if hasattr(func, "__self__"): 138 args = args[1:] # Omit 'self' 139 func_args = inspect.formatargspec(args, a, kw, None) 140 else: 141 signature = inspect.signature(func) 142 signature = signature.replace(return_annotation=inspect.Signature.empty) 143 # If the function is wrapped, switch to *args, **kwargs for the parameters 144 # as we can't rely on the signature to give us the arguments the function will 145 # be called with. For example unittest.mock.patch uses required args that are 146 # not actually passed to the function when invoked. 147 if hasattr(func, "__wrapped__"): 148 wrapper_params = [ 149 inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL), 150 inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD), 151 ] 152 else: 153 wrapper_params = [ 154 param.replace(annotation=inspect.Parameter.empty) 155 for param in signature.parameters.values() 156 ] 157 signature = signature.replace(parameters=wrapper_params) 158 159 wrapper_args = str(signature) 160 params_without_defaults = [ 161 param.replace( 162 annotation=inspect.Parameter.empty, default=inspect.Parameter.empty 163 ) 164 for param in signature.parameters.values() 165 ] 166 signature = signature.replace(parameters=params_without_defaults) 167 func_args = str(signature) 168 169 evaldict = {"func": func, "responses": responses} 170 six.exec_( 171 _wrapper_template % {"wrapper_args": wrapper_args, "func_args": func_args}, 172 evaldict, 173 ) 174 wrapper = evaldict["wrapper"] 175 update_wrapper(wrapper, func) 176 return wrapper 177 178 179class CallList(Sequence, Sized): 180 def __init__(self): 181 self._calls = [] 182 183 def __iter__(self): 184 return iter(self._calls) 185 186 def __len__(self): 187 return len(self._calls) 188 189 def __getitem__(self, idx): 190 return self._calls[idx] 191 192 def add(self, request, response): 193 self._calls.append(Call(request, response)) 194 195 def reset(self): 196 self._calls = [] 197 198 199def _ensure_url_default_path(url): 200 if _is_string(url): 201 url_parts = list(urlsplit(url)) 202 if url_parts[2] == "": 203 url_parts[2] = "/" 204 url = urlunsplit(url_parts) 205 return url 206 207 208def _handle_body(body): 209 if isinstance(body, six.text_type): 210 body = body.encode("utf-8") 211 if isinstance(body, _io.BufferedReader): 212 return body 213 214 return BufferIO(body) 215 216 217_unspecified = object() 218 219 220class BaseResponse(object): 221 content_type = None 222 headers = None 223 224 stream = False 225 226 def __init__(self, method, url, match_querystring=_unspecified): 227 self.method = method 228 # ensure the url has a default path set if the url is a string 229 self.url = _ensure_url_default_path(url) 230 self.match_querystring = self._should_match_querystring(match_querystring) 231 self.call_count = 0 232 233 def __eq__(self, other): 234 if not isinstance(other, BaseResponse): 235 return False 236 237 if self.method != other.method: 238 return False 239 240 # Can't simply do a equality check on the objects directly here since __eq__ isn't 241 # implemented for regex. It might seem to work as regex is using a cache to return 242 # the same regex instances, but it doesn't in all cases. 243 self_url = self.url.pattern if isinstance(self.url, Pattern) else self.url 244 other_url = other.url.pattern if isinstance(other.url, Pattern) else other.url 245 246 return self_url == other_url 247 248 def __ne__(self, other): 249 return not self.__eq__(other) 250 251 def _url_matches_strict(self, url, other): 252 url_parsed = urlparse(url) 253 other_parsed = urlparse(other) 254 255 if url_parsed[:3] != other_parsed[:3]: 256 return False 257 258 url_qsl = sorted(parse_qsl(url_parsed.query)) 259 other_qsl = sorted(parse_qsl(other_parsed.query)) 260 261 if len(url_qsl) != len(other_qsl): 262 return False 263 264 for (a_k, a_v), (b_k, b_v) in zip(url_qsl, other_qsl): 265 if a_k != b_k: 266 return False 267 268 if a_v != b_v: 269 return False 270 271 return True 272 273 def _should_match_querystring(self, match_querystring_argument): 274 if match_querystring_argument is not _unspecified: 275 return match_querystring_argument 276 277 if isinstance(self.url, Pattern): 278 # the old default from <= 0.9.0 279 return False 280 281 return bool(urlparse(self.url).query) 282 283 def _url_matches(self, url, other, match_querystring=False): 284 if _is_string(url): 285 if _has_unicode(url): 286 url = _clean_unicode(url) 287 if not isinstance(other, six.text_type): 288 other = other.encode("ascii").decode("utf8") 289 if match_querystring: 290 return self._url_matches_strict(url, other) 291 292 else: 293 url_without_qs = url.split("?", 1)[0] 294 other_without_qs = other.split("?", 1)[0] 295 return url_without_qs == other_without_qs 296 297 elif isinstance(url, Pattern) and url.match(other): 298 return True 299 300 else: 301 return False 302 303 def get_headers(self): 304 headers = {} 305 if self.content_type is not None: 306 headers["Content-Type"] = self.content_type 307 if self.headers: 308 headers.update(self.headers) 309 return headers 310 311 def get_response(self, request): 312 raise NotImplementedError 313 314 def matches(self, request): 315 if request.method != self.method: 316 return False 317 318 if not self._url_matches(self.url, request.url, self.match_querystring): 319 return False 320 321 return True 322 323 324class Response(BaseResponse): 325 def __init__( 326 self, 327 method, 328 url, 329 body="", 330 json=None, 331 status=200, 332 headers=None, 333 stream=False, 334 content_type=UNSET, 335 **kwargs 336 ): 337 # if we were passed a `json` argument, 338 # override the body and content_type 339 if json is not None: 340 assert not body 341 body = json_module.dumps(json) 342 if content_type is UNSET: 343 content_type = "application/json" 344 345 if content_type is UNSET: 346 content_type = "text/plain" 347 348 # body must be bytes 349 if isinstance(body, six.text_type): 350 body = body.encode("utf-8") 351 352 self.body = body 353 self.status = status 354 self.headers = headers 355 self.stream = stream 356 self.content_type = content_type 357 super(Response, self).__init__(method, url, **kwargs) 358 359 def get_response(self, request): 360 if self.body and isinstance(self.body, Exception): 361 raise self.body 362 363 headers = self.get_headers() 364 status = self.status 365 body = _handle_body(self.body) 366 367 return HTTPResponse( 368 status=status, 369 reason=six.moves.http_client.responses.get(status), 370 body=body, 371 headers=headers, 372 preload_content=False, 373 ) 374 375 376class CallbackResponse(BaseResponse): 377 def __init__( 378 self, method, url, callback, stream=False, content_type="text/plain", **kwargs 379 ): 380 self.callback = callback 381 self.stream = stream 382 self.content_type = content_type 383 super(CallbackResponse, self).__init__(method, url, **kwargs) 384 385 def get_response(self, request): 386 headers = self.get_headers() 387 388 result = self.callback(request) 389 if isinstance(result, Exception): 390 raise result 391 392 status, r_headers, body = result 393 if isinstance(body, Exception): 394 raise body 395 396 body = _handle_body(body) 397 headers.update(r_headers) 398 399 return HTTPResponse( 400 status=status, 401 reason=six.moves.http_client.responses.get(status), 402 body=body, 403 headers=headers, 404 preload_content=False, 405 ) 406 407 408class RequestsMock(object): 409 DELETE = "DELETE" 410 GET = "GET" 411 HEAD = "HEAD" 412 OPTIONS = "OPTIONS" 413 PATCH = "PATCH" 414 POST = "POST" 415 PUT = "PUT" 416 response_callback = None 417 418 def __init__( 419 self, 420 assert_all_requests_are_fired=True, 421 response_callback=None, 422 passthru_prefixes=(), 423 target="requests.adapters.HTTPAdapter.send", 424 ): 425 self._calls = CallList() 426 self.reset() 427 self.assert_all_requests_are_fired = assert_all_requests_are_fired 428 self.response_callback = response_callback 429 self.passthru_prefixes = tuple(passthru_prefixes) 430 self.target = target 431 432 def reset(self): 433 self._matches = [] 434 self._calls.reset() 435 436 def add( 437 self, 438 method=None, # method or ``Response`` 439 url=None, 440 body="", 441 adding_headers=None, 442 *args, 443 **kwargs 444 ): 445 """ 446 A basic request: 447 448 >>> responses.add(responses.GET, 'http://example.com') 449 450 You can also directly pass an object which implements the 451 ``BaseResponse`` interface: 452 453 >>> responses.add(Response(...)) 454 455 A JSON payload: 456 457 >>> responses.add( 458 >>> method='GET', 459 >>> url='http://example.com', 460 >>> json={'foo': 'bar'}, 461 >>> ) 462 463 Custom headers: 464 465 >>> responses.add( 466 >>> method='GET', 467 >>> url='http://example.com', 468 >>> headers={'X-Header': 'foo'}, 469 >>> ) 470 471 472 Strict query string matching: 473 474 >>> responses.add( 475 >>> method='GET', 476 >>> url='http://example.com?foo=bar', 477 >>> match_querystring=True 478 >>> ) 479 """ 480 if isinstance(method, BaseResponse): 481 self._matches.append(method) 482 return 483 484 if adding_headers is not None: 485 kwargs.setdefault("headers", adding_headers) 486 487 self._matches.append(Response(method=method, url=url, body=body, **kwargs)) 488 489 def add_passthru(self, prefix): 490 """ 491 Register a URL prefix to passthru any non-matching mock requests to. 492 493 For example, to allow any request to 'https://example.com', but require 494 mocks for the remainder, you would add the prefix as so: 495 496 >>> responses.add_passthru('https://example.com') 497 """ 498 if _has_unicode(prefix): 499 prefix = _clean_unicode(prefix) 500 self.passthru_prefixes += (prefix,) 501 502 def remove(self, method_or_response=None, url=None): 503 """ 504 Removes a response previously added using ``add()``, identified 505 either by a response object inheriting ``BaseResponse`` or 506 ``method`` and ``url``. Removes all matching responses. 507 508 >>> response.add(responses.GET, 'http://example.org') 509 >>> response.remove(responses.GET, 'http://example.org') 510 """ 511 if isinstance(method_or_response, BaseResponse): 512 response = method_or_response 513 else: 514 response = BaseResponse(method=method_or_response, url=url) 515 516 while response in self._matches: 517 self._matches.remove(response) 518 519 def replace(self, method_or_response=None, url=None, body="", *args, **kwargs): 520 """ 521 Replaces a response previously added using ``add()``. The signature 522 is identical to ``add()``. The response is identified using ``method`` 523 and ``url``, and the first matching response is replaced. 524 525 >>> responses.add(responses.GET, 'http://example.org', json={'data': 1}) 526 >>> responses.replace(responses.GET, 'http://example.org', json={'data': 2}) 527 """ 528 if isinstance(method_or_response, BaseResponse): 529 response = method_or_response 530 else: 531 response = Response(method=method_or_response, url=url, body=body, **kwargs) 532 533 index = self._matches.index(response) 534 self._matches[index] = response 535 536 def add_callback( 537 self, method, url, callback, match_querystring=False, content_type="text/plain" 538 ): 539 # ensure the url has a default path set if the url is a string 540 # url = _ensure_url_default_path(url, match_querystring) 541 542 self._matches.append( 543 CallbackResponse( 544 url=url, 545 method=method, 546 callback=callback, 547 content_type=content_type, 548 match_querystring=match_querystring, 549 ) 550 ) 551 552 @property 553 def calls(self): 554 return self._calls 555 556 def __enter__(self): 557 self.start() 558 return self 559 560 def __exit__(self, type, value, traceback): 561 success = type is None 562 self.stop(allow_assert=success) 563 self.reset() 564 return success 565 566 def activate(self, func): 567 return get_wrapped(func, self) 568 569 def _find_match(self, request): 570 found = None 571 found_match = None 572 for i, match in enumerate(self._matches): 573 if match.matches(request): 574 if found is None: 575 found = i 576 found_match = match 577 else: 578 # Multiple matches found. Remove & return the first match. 579 return self._matches.pop(found) 580 581 return found_match 582 583 def _on_request(self, adapter, request, **kwargs): 584 match = self._find_match(request) 585 resp_callback = self.response_callback 586 587 if match is None: 588 if request.url.startswith(self.passthru_prefixes): 589 logger.info("request.allowed-passthru", extra={"url": request.url}) 590 return _real_send(adapter, request, **kwargs) 591 592 error_msg = ( 593 "Connection refused by Responses: {0} {1} doesn't " 594 "match Responses Mock".format(request.method, request.url) 595 ) 596 response = ConnectionError(error_msg) 597 response.request = request 598 599 self._calls.add(request, response) 600 response = resp_callback(response) if resp_callback else response 601 raise response 602 603 try: 604 response = adapter.build_response(request, match.get_response(request)) 605 except Exception as response: 606 match.call_count += 1 607 self._calls.add(request, response) 608 response = resp_callback(response) if resp_callback else response 609 raise 610 611 if not match.stream: 612 response.content # NOQA 613 614 try: 615 response.cookies = _cookies_from_headers(response.headers) 616 except (KeyError, TypeError): 617 pass 618 619 response = resp_callback(response) if resp_callback else response 620 match.call_count += 1 621 self._calls.add(request, response) 622 return response 623 624 def start(self): 625 def unbound_on_send(adapter, request, *a, **kwargs): 626 return self._on_request(adapter, request, *a, **kwargs) 627 628 self._patcher = std_mock.patch(target=self.target, new=unbound_on_send) 629 self._patcher.start() 630 631 def stop(self, allow_assert=True): 632 self._patcher.stop() 633 if not self.assert_all_requests_are_fired: 634 return 635 636 if not allow_assert: 637 return 638 639 not_called = [m for m in self._matches if m.call_count == 0] 640 if not_called: 641 raise AssertionError( 642 "Not all requests have been executed {0!r}".format( 643 [(match.method, match.url) for match in not_called] 644 ) 645 ) 646 647 648# expose default mock namespace 649mock = _default_mock = RequestsMock(assert_all_requests_are_fired=False) 650__all__ = ["CallbackResponse", "Response", "RequestsMock"] 651for __attr in (a for a in dir(_default_mock) if not a.startswith("_")): 652 __all__.append(__attr) 653 globals()[__attr] = getattr(_default_mock, __attr) 654