"""Utilities shared by tests."""

import asyncio
import contextlib
import functools
import gc
import inspect
import os
import socket
import sys
import unittest
from abc import ABC, abstractmethod
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Type, Union
from unittest import mock

from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL

import aiohttp
from aiohttp.client import (
    ClientResponse,
    _RequestContextManager,
    _WSRequestContextManager,
)

from . import ClientSession, hdrs
from .abc import AbstractCookieJar
from .client_reqrep import ClientResponse
from .client_ws import ClientWebSocketResponse
from .helpers import sentinel
from .http import HttpVersion, RawRequestMessage
from .signals import Signal
from .web import (
    Application,
    AppRunner,
    BaseRunner,
    Request,
    Server,
    ServerRunner,
    SockSite,
    UrlMappingMatchInfo,
)
from .web_protocol import _RequestHandler

if TYPE_CHECKING:  # pragma: no cover
    from ssl import SSLContext
else:
    SSLContext = None


# Only set SO_REUSEADDR on POSIX platforms (excluding Cygwin); see the
# comment in get_port_socket for why Windows is excluded.
REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"


def get_unused_port_socket(host: str) -> socket.socket:
    """Return a TCP socket bound to *host* on a port chosen by the OS."""
    return get_port_socket(host, 0)


def get_port_socket(host: str, port: int) -> socket.socket:
    """Return an IPv4 TCP socket bound to ``(host, port)``.

    Passing ``port=0`` lets the operating system pick a free port.
    If configuring or binding the socket fails, the socket is closed
    before the exception propagates so no file descriptor is leaked.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        s.close()
        raise
    return s


def unused_port() -> int:
    """Return a port that is unused on the current host.

    The probing socket is closed before returning, so another process
    may grab the port before the caller binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


class BaseTestServer(ABC):
    """Common machinery for test servers listening on a local socket.

    Subclasses implement ``_make_runner`` to build the runner that serves
    requests; this class binds the socket, starts the site and exposes
    the resulting root URL via ``make_url``.
    """

    __test__ = False

    def __init__(
        self,
        *,
        scheme: Union[str, object] = sentinel,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        skip_url_asserts: bool = False,
        **kwargs: Any,
    ) -> None:
        self._loop = loop
        self.runner = None  # type: Optional[BaseRunner]
        self._root = None  # type: Optional[URL]
        self.host = host
        self.port = port
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts

    async def start_server(
        self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any
    ) -> None:
        """Create the runner, bind the socket and start serving.

        Does nothing if a runner already exists. An ``ssl`` keyword, if
        given, is used as the site's SSL context; remaining keywords are
        forwarded to ``_make_runner``.
        """
        if self.runner:
            return
        self._loop = loop
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(**kwargs)
        await self.runner.setup()
        if not self.port:
            # Port 0 asks the OS for any free port.
            self.port = 0
        _sock = get_port_socket(self.host, self.port)
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets
        assert sockets is not None
        self.port = sockets[0].getsockname()[1]
        if self.scheme is sentinel:
            if self._ssl:
                scheme = "https"
            else:
                scheme = "http"
            self.scheme = scheme
        self._root = URL(f"{self.scheme}://{self.host}:{self.port}")

    @abstractmethod  # pragma: no cover
    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        pass

    def make_url(self, path: str) -> URL:
        """Return an absolute URL for *path* on this running server.

        Unless ``skip_url_asserts`` is set, *path* must be relative and is
        joined onto the root URL; otherwise it is appended verbatim.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.is_absolute()
            return self._root.join(url)
        else:
            return URL(str(self._root) + path)

    @property
    def started(self) -> bool:
        return self.runner is not None

    @property
    def closed(self) -> bool:
        return self._closed

    @property
    def handler(self) -> Server:
        # for backward compatibility
        # web.Server instance
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Stop the test server and release its resources.

        After that point, the server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as an asynchronous
        context manager.

        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = None
            self._closed = True

    def __enter__(self) -> None:
        raise TypeError("Use async with instead")

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # __exit__ should exist in pair with __enter__ but never executed
        pass  # pragma: no cover

    async def __aenter__(self) -> "BaseTestServer":
        await self.start_server(loop=self._loop)
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.close()


class TestServer(BaseTestServer):
    """Test server that serves an ``aiohttp.web.Application``."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: Union[str, object] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        self.app = app
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        return AppRunner(self.app, **kwargs)


class RawTestServer(BaseTestServer):
    """Test server that dispatches every request to a bare handler."""

    def __init__(
        self,
        handler: _RequestHandler,
        *,
        scheme: Union[str, object] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        self._handler = handler
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner:
        srv = Server(self._handler, loop=self._loop, debug=debug, **kwargs)
        return ServerRunner(srv, debug=debug, **kwargs)


class TestClient:
    """
    A test client implementation.

    To write functional tests for aiohttp based servers.

    """

    __test__ = False

    def __init__(
        self,
        server: BaseTestServer,
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs: Any,
    ) -> None:
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer " "instance, found type: %r" % type(server)
            )
        self._server = server
        self._loop = loop
        if cookie_jar is None:
            # unsafe=True so cookies set by IP-address hosts (127.0.0.1)
            # are still accepted by the jar.
            cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop)
        self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs)
        self._closed = False
        self._responses = []  # type: List[ClientResponse]
        self._websockets = []  # type: List[ClientWebSocketResponse]

    async def start_server(self) -> None:
        await self._server.start_server(loop=self._loop)

    @property
    def host(self) -> str:
        return self._server.host

    @property
    def port(self) -> Optional[int]:
        return self._server.port

    @property
    def server(self) -> BaseTestServer:
        return self._server

    @property
    def app(self) -> Optional[Application]:
        """The tested application, or None if the server has no ``app``."""
        return getattr(self._server, "app", None)

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: str) -> URL:
        return self._server.make_url(path)

    async def _request(self, method: str, path: str, **kwargs: Any) -> ClientResponse:
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    def request(self, method: str, path: str, **kwargs: Any) -> _RequestContextManager:
        """Routes a request to tested http server.

        The interface is identical to aiohttp.ClientSession.request,
        except the loop kwarg is overridden by the instance used by the
        test server.

        """
        return _RequestContextManager(self._request(method, path, **kwargs))

    def get(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
        return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

    def post(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
        return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

    def options(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
        return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))

    def head(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
        return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

    def put(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
        return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

    def patch(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
        return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))

    def delete(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
        return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))

    def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect.

        """
        return _WSRequestContextManager(self._ws_connect(path, **kwargs))

    async def _ws_connect(self, path: str, **kwargs: Any) -> ClientWebSocketResponse:
        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        # save it to close later
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as a(n) (asynchronous)
        context manager.

        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    def __enter__(self) -> None:
        raise TypeError("Use async with instead")

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        # __exit__ should exist in pair with __enter__ but never executed
        pass  # pragma: no cover

    async def __aenter__(self) -> "TestClient":
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()


class AioHTTPTestCase(unittest.TestCase):
    """A base class to allow for unittest web applications using
    aiohttp.

    Provides the following:

    * self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
    * self.loop (asyncio.BaseEventLoop): the event loop in which the
        application and server are running.
    * self.app (aiohttp.web.Application): the application returned by
        self.get_application()

    Note that the TestClient's methods are asynchronous: you have to
    execute function on the test client using asynchronous methods.
    """

    async def get_application(self) -> Application:
        """
        This method should be overridden
        to return the aiohttp.web.Application
        object to test.

        """
        return self.get_app()

    def get_app(self) -> Application:
        """Obsolete method used to constructing web application.

        Use .get_application() coroutine instead

        """
        raise RuntimeError("Did you forget to define get_application()?")

    def setUp(self) -> None:
        self.loop = setup_test_loop()

        self.app = self.loop.run_until_complete(self.get_application())
        self.server = self.loop.run_until_complete(self.get_server(self.app))
        self.client = self.loop.run_until_complete(self.get_client(self.server))

        self.loop.run_until_complete(self.client.start_server())

        self.loop.run_until_complete(self.setUpAsync())

    async def setUpAsync(self) -> None:
        pass

    def tearDown(self) -> None:
        # Each cleanup step runs even if an earlier one raised, so a failing
        # tearDownAsync() cannot leak the client, server or event loop.
        try:
            self.loop.run_until_complete(self.tearDownAsync())
        finally:
            try:
                self.loop.run_until_complete(self.client.close())
            finally:
                teardown_test_loop(self.loop)

    async def tearDownAsync(self) -> None:
        pass

    async def get_server(self, app: Application) -> TestServer:
        """Return a TestServer instance."""
        return TestServer(app, loop=self.loop)

    async def get_client(self, server: TestServer) -> TestClient:
        """Return a TestClient instance."""
        return TestClient(server, loop=self.loop)


def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
    """A decorator dedicated to use with asynchronous methods of an
    AioHTTPTestCase.

    Handles executing an asynchronous function, using
    the self.loop of the AioHTTPTestCase.
    """

    # Extra positional/keyword arguments are forwarded to functools.wraps
    # (i.e. its ``assigned`` / ``updated`` parameters).
    @functools.wraps(func, *args, **kwargs)
    def new_func(self: Any, *inner_args: Any, **inner_kwargs: Any) -> Any:
        return self.loop.run_until_complete(func(self, *inner_args, **inner_kwargs))

    return new_func


_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]


@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop. The loop is torn
    down even if the body of the ``with`` statement raises.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        teardown_test_loop(loop, fast=fast)


def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create and return an asyncio.BaseEventLoop
    instance.

    The loop is also installed as the current event loop. On non-Windows
    platforms (and non-uvloop loops) a child watcher is attached to it
    when the running Python still provides ``asyncio.SafeChildWatcher``.

    The caller should also call teardown_test_loop,
    once they are done with the loop.
    """
    loop = loop_factory()
    try:
        module = loop.__class__.__module__
        skip_watcher = "uvloop" in module
    except AttributeError:  # pragma: no cover
        # Just in case
        skip_watcher = True
    asyncio.set_event_loop(loop)
    if sys.platform != "win32" and not skip_watcher:
        # Child watchers were deprecated in Python 3.12 and removed in 3.14;
        # skip the setup instead of crashing when the class no longer exists.
        watcher_cls = getattr(asyncio, "SafeChildWatcher", None)
        if watcher_cls is not None:
            policy = asyncio.get_event_loop_policy()
            watcher = watcher_cls()
            watcher.attach_loop(loop)
            with contextlib.suppress(NotImplementedError):
                policy.set_child_watcher(watcher)
    return loop


def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Teardown and cleanup an event_loop created
    by setup_test_loop.

    Unless *fast* is true, a full garbage collection is run afterwards.
    The current event loop is reset to None in every case.
    """
    closed = loop.is_closed()
    if not closed:
        # Run the loop one final time so callbacks already scheduled get a
        # chance to execute before stop() takes effect and the loop closes.
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if not fast:
        gc.collect()

    asyncio.set_event_loop(None)


def _create_app_mock() -> mock.MagicMock:
    """Return a MagicMock standing in for an Application.

    Item access is backed by a plain dict, and ``on_response_prepare`` is
    a real, frozen Signal.
    """

    def get_dict(app: Any, key: str) -> Any:
        return app.__app_dict[key]

    def set_dict(app: Any, key: str, value: Any) -> None:
        app.__app_dict[key] = value

    app = mock.MagicMock()
    app.__app_dict = {}
    app.__getitem__ = get_dict
    app.__setitem__ = set_dict

    app._debug = False
    app.on_response_prepare = Signal(app)
    app.on_response_prepare.freeze()
    return app


def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
    """Return a mock transport whose get_extra_info("sslcontext") yields
    *sslcontext* and which returns None for every other key."""
    transport = mock.Mock()

    def get_extra_info(key: str) -> Optional[SSLContext]:
        if key == "sslcontext":
            return sslcontext
        else:
            return None

    transport.get_extra_info.side_effect = get_extra_info
    return transport


def make_mocked_request(
    method: str,
    path: str,
    headers: Any = None,
    *,
    match_info: Any = sentinel,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Any = None,
    writer: Any = sentinel,
    protocol: Any = sentinel,
    transport: Any = sentinel,
    payload: Any = sentinel,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024 ** 2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning full web server is overkill or
    specific conditions and errors are hard to trigger.

    Any collaborator not supplied (loop, app, transport, protocol, writer,
    payload) is replaced by a mock. Requests with an HTTP version below
    1.1 are always marked as closing.

    """

    task = mock.Mock()
    if loop is ...:
        loop = mock.Mock()
        loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        False,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()
        protocol.transport = transport

    if writer is sentinel:
        writer = mock.Mock()
        writer.write_headers = make_mocked_coro(None)
        writer.write = make_mocked_coro(None)
        writer.write_eof = make_mocked_coro(None)
        writer.drain = make_mocked_coro(None)
        writer.transport = transport

    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    match_info = UrlMappingMatchInfo(
        {} if match_info is sentinel else match_info, mock.Mock()
    )
    match_info.add_app(app)
    req._match_info = match_info

    return req


def make_mocked_coro(
    return_value: Any = sentinel, raise_exception: Any = sentinel
) -> Any:
    """Creates a coroutine mock.

    The returned Mock wraps a coroutine function. Each call raises
    *raise_exception* if it was given. Otherwise it returns *return_value*,
    unless that value is awaitable. In that case it is awaited and the
    coroutine returns None, discarding the awaited result.
    """

    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
        if raise_exception is not sentinel:
            raise raise_exception
        if not inspect.isawaitable(return_value):
            return return_value
        await return_value

    return mock.Mock(wraps=mock_coro)