1"""Utilities shared by tests."""
2
3import asyncio
4import contextlib
5import functools
6import gc
7import inspect
8import os
9import socket
10import sys
11import unittest
12from abc import ABC, abstractmethod
13from types import TracebackType
14from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Type, Union
15from unittest import mock
16
17from multidict import CIMultiDict, CIMultiDictProxy
18from yarl import URL
19
20import aiohttp
21from aiohttp.client import (
22    ClientResponse,
23    _RequestContextManager,
24    _WSRequestContextManager,
25)
26
27from . import ClientSession, hdrs
28from .abc import AbstractCookieJar
29from .client_reqrep import ClientResponse
30from .client_ws import ClientWebSocketResponse
31from .helpers import sentinel
32from .http import HttpVersion, RawRequestMessage
33from .signals import Signal
34from .web import (
35    Application,
36    AppRunner,
37    BaseRunner,
38    Request,
39    Server,
40    ServerRunner,
41    SockSite,
42    UrlMappingMatchInfo,
43)
44from .web_protocol import _RequestHandler
45
46if TYPE_CHECKING:  # pragma: no cover
47    from ssl import SSLContext
48else:
49    SSLContext = None
50
51
# SO_REUSEADDR is only requested on POSIX platforms other than Cygwin
# (see get_port_socket for why Windows is excluded).
REUSE_ADDRESS = sys.platform != "cygwin" and os.name == "posix"
53
54
def get_unused_port_socket(host: str) -> socket.socket:
    """Return a socket bound to *host* on a port picked by the OS."""
    # Binding to port 0 lets the operating system choose a free port.
    any_free_port = 0
    return get_port_socket(host, any_free_port)
57
58
def get_port_socket(host: str, port: int) -> socket.socket:
    """Create an IPv4 TCP socket bound to ``(host, port)``.

    ``SO_REUSEADDR`` is enabled only when ``REUSE_ADDRESS`` is true. If
    configuring or binding the socket fails, the socket is closed before
    the error propagates, so a failed call does not leak a file descriptor.

    :param host: address to bind to.
    :param port: port to bind to; ``0`` lets the OS choose a free port.
    :returns: the bound (not yet listening) socket; the caller owns it.
    :raises OSError: if the socket cannot be configured or bound.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        # Don't leak the descriptor when the port is taken / host invalid.
        s.close()
        raise
    return s
68
69
def unused_port() -> int:
    """Return a port that is unused on the current host.

    The probing socket is closed before returning, so the port is released
    and another process could claim it before the caller binds to it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        _address, port = probe.getsockname()
    return port
75
76
class BaseTestServer(ABC):
    """Common machinery shared by the test servers in this module.

    Subclasses implement :meth:`_make_runner`. This base class binds a
    listening socket, starts a :class:`SockSite` on it and exposes the
    resulting root URL through :meth:`make_url`.
    """

    __test__ = False  # tell pytest not to collect this class as a test

    def __init__(
        self,
        *,
        scheme: Union[str, object] = sentinel,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        skip_url_asserts: bool = False,
        **kwargs: Any,
    ) -> None:
        self._loop = loop
        self.runner = None  # type: Optional[BaseRunner]
        self._root = None  # type: Optional[URL]
        self.host = host
        self.port = port
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        # Set for real by start_server(); initialised so the attribute
        # always exists.
        self._ssl = None  # type: Optional[SSLContext]

    async def start_server(
        self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any
    ) -> None:
        """Create the runner and start serving; no-op if already started.

        :param loop: event loop to use. When omitted, the loop passed to
            the constructor (if any) is kept rather than discarded.
        :param kwargs: ``ssl`` is popped and used as the site's SSL
            context; all other keywords are forwarded to
            :meth:`_make_runner`.
        """
        if self.runner:
            return
        if loop is not None:
            self._loop = loop
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(**kwargs)
        await self.runner.setup()
        if not self.port:
            self.port = 0
        _sock = get_port_socket(self.host, self.port)
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets
        assert sockets is not None
        # Re-read the port from the live server socket.
        self.port = sockets[0].getsockname()[1]
        if self.scheme is sentinel:
            self.scheme = "https" if self._ssl else "http"
        self._root = URL(f"{self.scheme}://{self.host}:{self.port}")

    @abstractmethod  # pragma: no cover
    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        pass

    def make_url(self, path: str) -> URL:
        """Return the absolute URL of *path* on this server.

        Unless ``skip_url_asserts`` is set, *path* must be relative and is
        joined to the root URL. With it set, *path* is appended verbatim to
        the root URL string.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.is_absolute()
            return self._root.join(url)
        else:
            return URL(str(self._root) + path)

    @property
    def started(self) -> bool:
        return self.runner is not None

    @property
    def closed(self) -> bool:
        return self._closed

    @property
    def handler(self) -> Server:
        # for backward compatibility
        # web.Server instance
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Stop the server and clean up its runner.

        After that point the server cannot be restarted: ``runner`` stays
        set, so a later :meth:`start_server` call returns immediately.

        This is an idempotent function: calling it again, or calling it
        before the server was started, has no effect.

        close is also run on exit when used as an asynchronous context
        manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = None
            self._closed = True

    def __enter__(self) -> None:
        raise TypeError("Use async with instead")

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # __exit__ should exist in pair with __enter__ but never executed
        pass  # pragma: no cover

    async def __aenter__(self) -> "BaseTestServer":
        await self.start_server(loop=self._loop)
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.close()
199
200
class TestServer(BaseTestServer):
    """Test server that serves a :class:`aiohttp.web.Application`."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: Union[str, object] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ):
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)
        # Kept so _make_runner can wrap it in an AppRunner on start.
        self.app = app

    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        app_runner = AppRunner(self.app, **kwargs)
        return app_runner
216
217
class RawTestServer(BaseTestServer):
    """Test server that runs a bare request handler via ``web.Server``."""

    def __init__(
        self,
        handler: _RequestHandler,
        *,
        scheme: Union[str, object] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)
        self._handler = handler

    async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner:
        # The same extra keywords are forwarded to both the low-level
        # Server and the ServerRunner wrapping it.
        web_server = Server(self._handler, loop=self._loop, debug=debug, **kwargs)
        server_runner = ServerRunner(web_server, debug=debug, **kwargs)
        return server_runner
234
235
class TestClient:
    """
    A test client implementation.

    To write functional tests for aiohttp based servers.

    Requests go through an internal ClientSession to URLs built by the
    wrapped server's ``make_url``. Every response and websocket opened this
    way is remembered and closed by :meth:`close`.
    """

    __test__ = False  # tell pytest not to collect this class as a test

    def __init__(
        self,
        server: BaseTestServer,
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs: Any,
    ) -> None:
        """
        :param server: the test server to send requests to.
        :param cookie_jar: cookie jar for the session; defaults to
            ``aiohttp.CookieJar(unsafe=True, loop=loop)``.
        :param loop: event loop passed to the session and the server.
        :param kwargs: forwarded to ``ClientSession``.
        :raises TypeError: if *server* is not a BaseTestServer.
        """
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer " "instance, found type: %r" % type(server)
            )
        self._server = server
        self._loop = loop
        if cookie_jar is None:
            cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop)
        self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs)
        self._closed = False
        self._responses = []  # type: List[ClientResponse]
        self._websockets = []  # type: List[ClientWebSocketResponse]

    async def start_server(self) -> None:
        await self._server.start_server(loop=self._loop)

    @property
    def host(self) -> str:
        return self._server.host

    @property
    def port(self) -> Optional[int]:
        return self._server.port

    @property
    def server(self) -> BaseTestServer:
        return self._server

    @property
    def app(self) -> Optional[Application]:
        """The server's ``app``, or None for servers without one
        (such as RawTestServer)."""
        return getattr(self._server, "app", None)

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: str) -> URL:
        return self._server.make_url(path)

    async def _request(self, method: str, path: str, **kwargs: Any) -> ClientResponse:
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    def request(self, method: str, path: str, **kwargs: Any) -> _RequestContextManager:
        """Routes a request to tested http server.

        The interface is identical to aiohttp.ClientSession.request,
        except that *path* is resolved against the test server's root
        URL (see make_url) and the response is closed by close().

        """
        return _RequestContextManager(self._request(method, path, **kwargs))

    def get(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
        return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

    def post(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
        return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

    def options(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
        return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))

    def head(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
        return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

    def put(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
        return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

    def patch(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
        return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))

    def delete(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
        return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))

    def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect.

        """
        return _WSRequestContextManager(self._ws_connect(path, **kwargs))

    async def _ws_connect(self, path: str, **kwargs: Any) -> ClientWebSocketResponse:
        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        # save it to close later
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as a(n) (asynchronous)
        context manager.

        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    def __enter__(self) -> None:
        raise TypeError("Use async with instead")

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        # __exit__ should exist in pair with __enter__ but never executed
        pass  # pragma: no cover

    async def __aenter__(self) -> "TestClient":
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()
401
402
class AioHTTPTestCase(unittest.TestCase):
    """A base class to allow for unittest web applications using
    aiohttp.

    Provides the following:

    * self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
    * self.loop (asyncio.BaseEventLoop): the event loop in which the
        application and server are running.
    * self.app (aiohttp.web.Application): the application returned by
        self.get_application()

    Note that the TestClient's methods are asynchronous: they must be
    awaited, e.g. from a test method decorated with unittest_run_loop.
    """

    async def get_application(self) -> Application:
        """
        This method should be overridden
        to return the aiohttp.web.Application
        object to test.

        """
        return self.get_app()

    def get_app(self) -> Application:
        """Obsolete method used to constructing web application.

        Use .get_application() coroutine instead

        """
        raise RuntimeError("Did you forget to define get_application()?")

    def setUp(self) -> None:
        self.loop = setup_test_loop()
        client = None  # type: Optional[TestClient]
        try:
            self.app = self.loop.run_until_complete(self.get_application())
            self.server = self.loop.run_until_complete(self.get_server(self.app))
            client = self.loop.run_until_complete(self.get_client(self.server))
            self.client = client

            self.loop.run_until_complete(client.start_server())

            self.loop.run_until_complete(self.setUpAsync())
        except BaseException:
            # unittest does not call tearDown() when setUp() raises, so
            # release what was created here before propagating. A failure
            # while closing the client must not mask the original error.
            if client is not None:
                with contextlib.suppress(Exception):
                    self.loop.run_until_complete(client.close())
            teardown_test_loop(self.loop)
            raise

    async def setUpAsync(self) -> None:
        pass

    def tearDown(self) -> None:
        # Close the client and tear down the loop even if
        # tearDownAsync() (or client.close()) raises.
        try:
            self.loop.run_until_complete(self.tearDownAsync())
        finally:
            try:
                self.loop.run_until_complete(self.client.close())
            finally:
                teardown_test_loop(self.loop)

    async def tearDownAsync(self) -> None:
        pass

    async def get_server(self, app: Application) -> TestServer:
        """Return a TestServer instance."""
        return TestServer(app, loop=self.loop)

    async def get_client(self, server: TestServer) -> TestClient:
        """Return a TestClient instance."""
        return TestClient(server, loop=self.loop)
465
466
def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
    """Decorate an ``async`` test method of an AioHTTPTestCase.

    The returned synchronous method runs the coroutine to completion on the
    test case's ``self.loop`` and returns its result. Any extra positional
    or keyword arguments are forwarded to :func:`functools.wraps`.
    """
    decorate = functools.wraps(func, *args, **kwargs)

    def run_on_test_loop(self: Any, *call_args: Any, **call_kwargs: Any) -> Any:
        coroutine = func(self, *call_args, **call_kwargs)
        return self.loop.run_until_complete(coroutine)

    return decorate(run_on_test_loop)
480
481
# Type of a zero-argument callable that creates an event loop, e.g.
# ``asyncio.new_event_loop`` (the default used by loop_context and
# setup_test_loop).
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]
483
484
@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop. The loop is torn down
    even when the body of the ``with`` statement raises.

    :param loop_factory: callable creating the loop, passed to
        setup_test_loop.
    :param fast: passed to teardown_test_loop; when true the final
        ``gc.collect()`` is skipped.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        teardown_test_loop(loop, fast=fast)
496
497
def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create and return an asyncio.BaseEventLoop
    instance.

    The new loop is installed as the current event loop. On non-Windows
    platforms, and unless the loop class comes from uvloop, a
    ``SafeChildWatcher`` is attached to it. This only happens when the
    running Python still provides that class; it was removed in 3.14.

    The caller should also call teardown_test_loop,
    once they are done with the loop.
    """
    loop = loop_factory()
    try:
        module = loop.__class__.__module__
        skip_watcher = "uvloop" in module
    except AttributeError:  # pragma: no cover
        # Just in case
        skip_watcher = True
    asyncio.set_event_loop(loop)
    # asyncio's child-watcher classes are gone in newer Pythons; skip the
    # watcher setup there instead of failing with AttributeError.
    watcher_factory = getattr(asyncio, "SafeChildWatcher", None)
    use_watcher = (
        sys.platform != "win32" and not skip_watcher and watcher_factory is not None
    )
    if use_watcher:
        policy = asyncio.get_event_loop_policy()
        watcher = watcher_factory()
        watcher.attach_loop(loop)
        with contextlib.suppress(NotImplementedError):
            policy.set_child_watcher(watcher)
    return loop
522
523
def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Stop and close a loop made by setup_test_loop and unset it.

    An open loop is run for one more iteration, so callbacks already queued
    get executed, and is then closed. An already-closed loop is left as is.
    Unless *fast* is true, a garbage collection pass follows. Finally the
    current event loop is reset to None.
    """
    if loop.is_closed():
        pass
    else:
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    run_gc = not fast
    if run_gc:
        gc.collect()

    asyncio.set_event_loop(None)
539
540
def _create_app_mock() -> mock.MagicMock:
    """Build a MagicMock standing in for a ``web.Application``.

    ``app[key]`` and ``app[key] = value`` are backed by a real dict kept on
    the mock. ``_debug`` is False, and ``on_response_prepare`` is an empty,
    frozen Signal.
    """
    app = mock.MagicMock()
    app.__app_dict = {}

    def _getitem(mocked: Any, key: str) -> Any:
        return mocked.__app_dict[key]

    def _setitem(mocked: Any, key: str, value: Any) -> None:
        mocked.__app_dict[key] = value

    app.__getitem__ = _getitem
    app.__setitem__ = _setitem

    app._debug = False
    prepare_signal = Signal(app)
    prepare_signal.freeze()
    app.on_response_prepare = prepare_signal
    return app
557
558
559def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
560    transport = mock.Mock()
561
562    def get_extra_info(key: str) -> Optional[SSLContext]:
563        if key == "sslcontext":
564            return sslcontext
565        else:
566            return None
567
568    transport.get_extra_info.side_effect = get_extra_info
569    return transport
570
571
def make_mocked_request(
    method: str,
    path: str,
    headers: Any = None,
    *,
    match_info: Any = sentinel,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Any = None,
    writer: Any = sentinel,
    protocol: Any = sentinel,
    transport: Any = sentinel,
    payload: Any = sentinel,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024 ** 2,
    loop: Any = ...,
) -> Request:
    """Creates mocked web.Request testing purposes.

    Useful in unit tests, when spinning full web server is overkill or
    specific conditions and errors are hard to trigger.

    Any of *app*, *transport*, *protocol*, *writer* and *payload* left at
    its default is replaced by a mock. The writer mock's ``write_headers``,
    ``write``, ``write_eof`` and ``drain`` are mocked coroutines. *loop*
    defaults to a Mock whose ``create_future()`` returns ``()``. HTTP
    versions older than 1.1 force *closing* to True.
    """
    task = mock.Mock()
    if loop is ...:
        loop = mock.Mock()
        loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if not headers:
        header_proxy = CIMultiDictProxy(CIMultiDict())
        raw_headers = ()
    else:
        header_proxy = CIMultiDictProxy(CIMultiDict(headers))
        raw_headers = tuple(
            (name.encode("utf-8"), value.encode("utf-8"))
            for name, value in header_proxy.items()
        )

    transfer_encoding = header_proxy.get(hdrs.TRANSFER_ENCODING, "").lower()
    is_chunked = "chunked" in transfer_encoding

    message = RawRequestMessage(
        method,
        path,
        version,
        header_proxy,
        raw_headers,
        closing,
        False,
        False,
        is_chunked,
        URL(path),
    )

    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()

    if writer is sentinel:
        writer = mock.Mock()
        for coro_attr in ("write_headers", "write", "write_eof", "drain"):
            setattr(writer, coro_attr, make_mocked_coro(None))
        writer.transport = transport

    # Wire the (possibly user-supplied) protocol to transport and writer.
    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    route_params = {} if match_info is sentinel else match_info
    mapping = UrlMappingMatchInfo(route_params, mock.Mock())
    mapping.add_app(app)
    req._match_info = mapping

    return req
662
663
def make_mocked_coro(
    return_value: Any = sentinel, raise_exception: Any = sentinel
) -> Any:
    """Creates a coroutine mock.

    The result is a ``mock.Mock`` wrapping a coroutine function, so calls
    can be inspected as usual. Awaiting a call then does one of three
    things:

    * raises *raise_exception* if one was given;
    * otherwise, if *return_value* is awaitable, awaits it and returns its
      result;
    * otherwise, returns *return_value* unchanged.

    NOTE: a coroutine object can only be awaited once. If *return_value* is
    a coroutine, the mock can therefore only be awaited once.
    """

    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
        if raise_exception is not sentinel:
            raise raise_exception
        if not inspect.isawaitable(return_value):
            return return_value
        # Propagate the awaited result instead of silently dropping it.
        return await return_value

    return mock.Mock(wraps=mock_coro)
677