1import asyncio
2import logging
3import warnings
4from functools import partial, update_wrapper
5from typing import (
6    TYPE_CHECKING,
7    Any,
8    AsyncIterator,
9    Awaitable,
10    Callable,
11    Dict,
12    Iterable,
13    Iterator,
14    List,
15    Mapping,
16    MutableMapping,
17    Optional,
18    Sequence,
19    Tuple,
20    Type,
21    Union,
22    cast,
23)
24
25from . import hdrs
26from .abc import (
27    AbstractAccessLogger,
28    AbstractMatchInfo,
29    AbstractRouter,
30    AbstractStreamWriter,
31)
32from .frozenlist import FrozenList
33from .helpers import DEBUG
34from .http_parser import RawRequestMessage
35from .log import web_logger
36from .signals import Signal
37from .streams import StreamReader
38from .web_log import AccessLogger
39from .web_middlewares import _fix_request_current_app
40from .web_protocol import RequestHandler
41from .web_request import Request
42from .web_response import StreamResponse
43from .web_routedef import AbstractRouteDef
44from .web_server import Server
45from .web_urldispatcher import (
46    AbstractResource,
47    AbstractRoute,
48    Domain,
49    MaskDomain,
50    MatchedSubAppResource,
51    PrefixedSubAppResource,
52    UrlDispatcher,
53)
54
__all__ = ("Application", "CleanupError")


# Type aliases used in the annotations of this module.  Static type checkers
# see the fully parameterized forms; at runtime only the bare classes are bound.
if TYPE_CHECKING:  # pragma: no cover
    # Signal whose receivers are coroutine functions taking the application.
    _AppSignal = Signal[Callable[["Application"], Awaitable[None]]]
    # Signal whose receivers take the request and the response being prepared.
    _RespPrepareSignal = Signal[Callable[[Request, StreamResponse], Awaitable[None]]]
    # A request handler: coroutine function taking a Request, returning a response.
    _Handler = Callable[[Request], Awaitable[StreamResponse]]
    _Middleware = Union[
        # New-style middleware: called with the request and the next handler.
        Callable[[Request, _Handler], Awaitable[StreamResponse]],
        # Old-style middleware factory: awaited to obtain a wrapped handler.
        Callable[["Application", _Handler], Awaitable[_Handler]],  # old-style
    ]
    _Middlewares = FrozenList[_Middleware]
    # (middleware, is_new_style) pairs computed by Application.pre_freeze().
    _MiddlewaresHandlers = Optional[Sequence[Tuple[_Middleware, bool]]]
    _Subapps = List["Application"]
else:
    # At runtime bind the unsubscripted classes; the precise aliases above
    # only matter to static type checkers.
    _AppSignal = Signal
    _RespPrepareSignal = Signal
    _Handler = Callable
    _Middleware = Callable
    _Middlewares = FrozenList
    _MiddlewaresHandlers = Optional[Sequence]
    _Subapps = List
78
79
class Application(MutableMapping[str, Any]):
    """An aiohttp web application.

    Bundles a URL router, a middleware chain, lifecycle signals
    (``on_startup``, ``on_shutdown``, ``on_cleanup``, ``on_response_prepare``),
    a cleanup-context registry and nested sub-applications.  It also acts as a
    ``MutableMapping`` over a private ``_state`` dict, so user code can keep
    shared objects in ``app["key"]``; mutating it after :meth:`freeze`
    emits a ``DeprecationWarning``.
    """

    # Every attribute the class assigns itself.  In DEBUG mode ``__setattr__``
    # warns about assignments to any name not listed here.
    ATTRS = frozenset(
        [
            "logger",
            "_debug",
            "_router",
            "_loop",
            "_handler_args",
            "_middlewares",
            "_middlewares_handlers",
            "_run_middlewares",
            "_state",
            "_frozen",
            "_pre_frozen",
            "_subapps",
            "_on_response_prepare",
            "_on_startup",
            "_on_shutdown",
            "_on_cleanup",
            "_client_max_size",
            "_cleanup_ctx",
        ]
    )

    def __init__(
        self,
        *,
        logger: logging.Logger = web_logger,
        router: Optional[UrlDispatcher] = None,
        middlewares: Iterable[_Middleware] = (),
        handler_args: Optional[Mapping[str, Any]] = None,
        client_max_size: int = 1024 ** 2,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        debug: Any = ...,  # mypy doesn't support ellipsis
    ) -> None:
        """Create the application.

        :param logger: logger stored as ``self.logger``.
        :param router: deprecated; a custom router (must be an
            ``AbstractRouter``).  A fresh ``UrlDispatcher`` is used otherwise.
        :param middlewares: middlewares, turned into the handler chain by
            :meth:`pre_freeze`.
        :param handler_args: extra keyword arguments merged into the
            ``Server`` arguments built by ``_make_handler``.
        :param client_max_size: forwarded to every ``Request`` created by
            ``_make_request``.
        :param loop: deprecated; event loop to bind to.
        :param debug: deprecated; when left as ``...`` the flag is copied
            from the event loop once one is attached by ``_set_loop``.
        """
        if router is None:
            router = UrlDispatcher()
        else:
            warnings.warn(
                "router argument is deprecated", DeprecationWarning, stacklevel=2
            )
        assert isinstance(router, AbstractRouter), router

        if loop is not None:
            warnings.warn(
                "loop argument is deprecated", DeprecationWarning, stacklevel=2
            )

        if debug is not ...:
            warnings.warn(
                "debug argument is deprecated", DeprecationWarning, stacklevel=2
            )
        self._debug = debug
        self._router = router  # type: UrlDispatcher
        self._loop = loop
        self._handler_args = handler_args
        self.logger = logger

        self._middlewares = FrozenList(middlewares)  # type: _Middlewares

        # initialized on freezing
        self._middlewares_handlers = None  # type: _MiddlewaresHandlers
        # initialized on freezing
        self._run_middlewares = None  # type: Optional[bool]

        self._state = {}  # type: Dict[str, Any]
        self._frozen = False
        self._pre_frozen = False
        self._subapps = []  # type: _Subapps

        self._on_response_prepare = Signal(self)  # type: _RespPrepareSignal
        self._on_startup = Signal(self)  # type: _AppSignal
        self._on_shutdown = Signal(self)  # type: _AppSignal
        self._on_cleanup = Signal(self)  # type: _AppSignal
        self._cleanup_ctx = CleanupContext()
        # Cleanup contexts are driven by the regular startup/cleanup signals.
        self._on_startup.append(self._cleanup_ctx._on_startup)
        self._on_cleanup.append(self._cleanup_ctx._on_cleanup)
        self._client_max_size = client_max_size

    def __init_subclass__(cls: Type["Application"]) -> None:
        """Warn that subclassing ``web.Application`` is discouraged."""
        warnings.warn(
            "Inheritance class {} from web.Application "
            "is discouraged".format(cls.__name__),
            DeprecationWarning,
            stacklevel=2,
        )
        # Keep the PEP 487 hook chain intact so base-class hooks (including
        # typing.Generic's, which MutableMapping[...] brings into the MRO)
        # still run for subclasses.
        super().__init_subclass__()

    if DEBUG:  # pragma: no cover

        def __setattr__(self, name: str, val: Any) -> None:
            # Discourage ad-hoc attributes; app["key"] is the intended storage.
            if name not in self.ATTRS:
                warnings.warn(
                    "Setting custom web.Application.{} attribute "
                    "is discouraged".format(name),
                    DeprecationWarning,
                    stacklevel=2,
                )
            super().__setattr__(name, val)

    # MutableMapping API

    def __eq__(self, other: object) -> bool:
        # Applications compare by identity, never by their mapping contents.
        # (Defining __eq__ without __hash__ leaves instances unhashable.)
        return self is other

    def __getitem__(self, key: str) -> Any:
        return self._state[key]

    def _check_frozen(self) -> None:
        """Warn when the state mapping is mutated after :meth:`freeze`.

        ``stacklevel=3`` attributes the warning to the caller of
        ``__setitem__``/``__delitem__``.
        """
        if self._frozen:
            warnings.warn(
                "Changing state of started or joined application is deprecated",
                DeprecationWarning,
                stacklevel=3,
            )

    def __setitem__(self, key: str, value: Any) -> None:
        self._check_frozen()
        self._state[key] = value

    def __delitem__(self, key: str) -> None:
        self._check_frozen()
        del self._state[key]

    def __len__(self) -> int:
        return len(self._state)

    def __iter__(self) -> Iterator[str]:
        return iter(self._state)

    ########
    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Deprecated: the event loop the application is bound to."""
        # Technically the loop can be None
        # but we mask it by explicit type cast
        # to provide more convenient type annotation
        warnings.warn("loop property is deprecated", DeprecationWarning, stacklevel=2)
        return cast(asyncio.AbstractEventLoop, self._loop)

    def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None:
        """Bind this application and, recursively, its sub-apps to *loop*.

        When *loop* is ``None``, ``asyncio.get_event_loop()`` is used.  If the
        debug flag was never set explicitly, it is copied from the loop.

        :raises RuntimeError: if a different loop was bound earlier.
        """
        if loop is None:
            loop = asyncio.get_event_loop()
        if self._loop is not None and self._loop is not loop:
            raise RuntimeError(
                "web.Application instance initialized with different loop"
            )

        self._loop = loop

        # set loop debug
        if self._debug is ...:
            self._debug = loop.get_debug()

        # set loop to sub applications
        for subapp in self._subapps:
            subapp._set_loop(loop)

    @property
    def pre_frozen(self) -> bool:
        """Whether :meth:`pre_freeze` has already run."""
        return self._pre_frozen

    def pre_freeze(self) -> None:
        """Freeze routing/signal configuration and build the middleware chain.

        Idempotent.  Freezes the middleware list, router, signals and cleanup
        contexts, precomputes the ``(middleware, is_new_style)`` handler list,
        and pre-freezes every sub-application.  State mutations only start
        warning after :meth:`freeze`.
        """
        if self._pre_frozen:
            return

        self._pre_frozen = True
        self._middlewares.freeze()
        self._router.freeze()
        self._on_response_prepare.freeze()
        self._cleanup_ctx.freeze()
        self._on_startup.freeze()
        self._on_shutdown.freeze()
        self._on_cleanup.freeze()
        self._middlewares_handlers = tuple(self._prepare_middleware())

        # If current app and any subapp do not have middlewares avoid run all
        # of the code footprint that it implies, which have a middleware
        # hardcoded per app that sets up the current_app attribute. If no
        # middlewares are configured the handler will receive the proper
        # current_app without needing all of this code.
        self._run_middlewares = bool(self.middlewares)

        for subapp in self._subapps:
            subapp.pre_freeze()
            self._run_middlewares = self._run_middlewares or subapp._run_middlewares

    @property
    def frozen(self) -> bool:
        """Whether :meth:`freeze` has already run."""
        return self._frozen

    def freeze(self) -> None:
        """Fully freeze this application and its sub-applications.

        Runs :meth:`pre_freeze` first; afterwards mutating the state mapping
        warns and no sub-application can be added.  Idempotent.
        """
        if self._frozen:
            return

        self.pre_freeze()
        self._frozen = True
        for subapp in self._subapps:
            subapp.freeze()

    @property
    def debug(self) -> bool:
        """Deprecated: the debug flag.

        May still be ``...`` (the "unset" sentinel) if no ``debug`` argument
        was given and no loop has been attached yet.
        """
        warnings.warn("debug property is deprecated", DeprecationWarning, stacklevel=2)
        return self._debug

    def _reg_subapp_signals(self, subapp: "Application") -> None:
        """Forward this app's startup/shutdown/cleanup signals to *subapp*."""

        def reg_handler(signame: str) -> None:
            subsig = getattr(subapp, signame)

            async def handler(app: "Application") -> None:
                # The sub-app's receivers get the sub-app, not the parent.
                await subsig.send(subapp)

            appsig = getattr(self, signame)
            appsig.append(handler)

        reg_handler("on_startup")
        reg_handler("on_shutdown")
        reg_handler("on_cleanup")

    def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource:
        """Mount *subapp* under the URL path *prefix*.

        Trailing slashes are stripped from *prefix*.

        :raises TypeError: if *prefix* is not a ``str``.
        :raises ValueError: if *prefix* is empty after stripping.
        :raises RuntimeError: if either application is already frozen.
        """
        if not isinstance(prefix, str):
            raise TypeError("Prefix must be str")
        prefix = prefix.rstrip("/")
        if not prefix:
            raise ValueError("Prefix cannot be empty")
        factory = partial(PrefixedSubAppResource, prefix, subapp)
        return self._add_subapp(factory, subapp)

    def _add_subapp(
        self, resource_factory: Callable[[], AbstractResource], subapp: "Application"
    ) -> AbstractResource:
        """Register the resource built by *resource_factory* for *subapp*.

        Also wires the lifecycle signals, pre-freezes *subapp* and, if this
        app already has a loop, binds *subapp* to it.

        :raises RuntimeError: if this app or *subapp* is already frozen.
        """
        if self.frozen:
            raise RuntimeError("Cannot add sub application to frozen application")
        if subapp.frozen:
            raise RuntimeError("Cannot add frozen application")
        resource = resource_factory()
        self.router.register_resource(resource)
        self._reg_subapp_signals(subapp)
        self._subapps.append(subapp)
        subapp.pre_freeze()
        if self._loop is not None:
            subapp._set_loop(self._loop)
        return resource

    def add_domain(self, domain: str, subapp: "Application") -> AbstractResource:
        """Mount *subapp* for requests matching the host *domain*.

        A *domain* containing ``*`` is matched as a mask (``MaskDomain``),
        otherwise as a plain ``Domain``.

        :raises TypeError: if *domain* is not a ``str``.
        """
        if not isinstance(domain, str):
            raise TypeError("Domain must be str")
        elif "*" in domain:
            rule = MaskDomain(domain)  # type: Domain
        else:
            rule = Domain(domain)
        factory = partial(MatchedSubAppResource, rule, subapp)
        return self._add_subapp(factory, subapp)

    def add_routes(self, routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]:
        """Register route definitions with the router; return created routes."""
        return self.router.add_routes(routes)

    @property
    def on_response_prepare(self) -> _RespPrepareSignal:
        """Signal sent with ``(request, response)`` when a response is prepared."""
        return self._on_response_prepare

    @property
    def on_startup(self) -> _AppSignal:
        """Signal sent by :meth:`startup`."""
        return self._on_startup

    @property
    def on_shutdown(self) -> _AppSignal:
        """Signal sent by :meth:`shutdown`."""
        return self._on_shutdown

    @property
    def on_cleanup(self) -> _AppSignal:
        """Signal sent by :meth:`cleanup`."""
        return self._on_cleanup

    @property
    def cleanup_ctx(self) -> "CleanupContext":
        """Registry of startup/cleanup async-generator callbacks."""
        return self._cleanup_ctx

    @property
    def router(self) -> UrlDispatcher:
        """The application's URL router."""
        return self._router

    @property
    def middlewares(self) -> _Middlewares:
        """The (freezable) list of registered middlewares."""
        return self._middlewares

    def _make_handler(
        self,
        *,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        **kwargs: Any,
    ) -> Server:
        """Bind the loop, freeze the app and build a low-level ``Server``.

        ``handler_args`` given to the constructor override *kwargs*.

        :raises TypeError: if *access_log_class* is not an
            ``AbstractAccessLogger`` subclass.
        """

        if not issubclass(access_log_class, AbstractAccessLogger):
            raise TypeError(
                "access_log_class must be subclass of "
                "aiohttp.abc.AbstractAccessLogger, got {}".format(access_log_class)
            )

        self._set_loop(loop)
        self.freeze()

        kwargs["debug"] = self._debug
        kwargs["access_log_class"] = access_log_class
        if self._handler_args:
            kwargs.update(self._handler_args)

        return Server(
            self._handle,  # type: ignore
            request_factory=self._make_request,
            loop=self._loop,
            **kwargs,
        )

    def make_handler(
        self,
        *,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        **kwargs: Any,
    ) -> Server:
        """Deprecated public wrapper around ``_make_handler``."""

        warnings.warn(
            "Application.make_handler(...) is deprecated, use AppRunner API instead",
            DeprecationWarning,
            stacklevel=2,
        )

        return self._make_handler(
            loop=loop, access_log_class=access_log_class, **kwargs
        )

    async def startup(self) -> None:
        """Causes on_startup signal

        Should be called in the event loop along with the request handler.
        """
        await self.on_startup.send(self)

    async def shutdown(self) -> None:
        """Causes on_shutdown signal

        Should be called before cleanup()
        """
        await self.on_shutdown.send(self)

    async def cleanup(self) -> None:
        """Causes on_cleanup signal

        Should be called after shutdown()
        """
        await self.on_cleanup.send(self)

    def _make_request(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: RequestHandler,
        writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
        _cls: Type[Request] = Request,
    ) -> Request:
        """Request factory handed to ``Server``.

        Builds a request bound to this app's loop and ``client_max_size``.
        """
        return _cls(
            message,
            payload,
            protocol,
            writer,
            task,
            self._loop,
            client_max_size=self._client_max_size,
        )

    def _prepare_middleware(self) -> Iterator[Tuple[_Middleware, bool]]:
        """Yield ``(middleware, is_new_style)`` pairs for the handler chain.

        Registered middlewares come in reverse registration order, followed
        by this app's ``_fix_request_current_app`` middleware.  Middlewares
        without ``__middleware_version__ == 1`` are treated as old-style and
        trigger a ``DeprecationWarning``.
        """
        for m in reversed(self._middlewares):
            if getattr(m, "__middleware_version__", None) == 1:
                yield m, True
            else:
                warnings.warn(
                    'old-style middleware "{!r}" deprecated, see #2252'.format(m),
                    DeprecationWarning,
                    stacklevel=2,
                )
                yield m, False

        yield _fix_request_current_app(self), True

    async def _handle(self, request: Request) -> StreamResponse:
        """Resolve *request*, run Expect handling and middlewares, respond."""
        loop = asyncio.get_event_loop()
        debug = loop.get_debug()
        match_info = await self._router.resolve(request)
        if debug and not isinstance(match_info, AbstractMatchInfo):  # pragma: no cover
            raise TypeError(
                "match_info should be AbstractMatchInfo "
                "instance, not {!r}".format(match_info)
            )
        match_info.add_app(self)

        match_info.freeze()

        resp = None
        request._match_info = match_info  # type: ignore
        expect = request.headers.get(hdrs.EXPECT)
        if expect:
            # A non-None response from the expect handler short-circuits the
            # regular handler below.
            resp = await match_info.expect_handler(request)
            await request.writer.drain()

        if resp is None:
            handler = match_info.handler

            if self._run_middlewares:
                # Wrap from the innermost app outward; within each app the
                # handler list is reversed, so the outermost app's first
                # registered middleware ends up running first.
                for app in match_info.apps[::-1]:
                    for m, new_style in app._middlewares_handlers:  # type: ignore
                        if new_style:
                            # update_wrapper copies the inner handler's
                            # metadata onto the partial.
                            handler = update_wrapper(
                                partial(m, handler=handler), handler
                            )
                        else:
                            handler = await m(app, handler)  # type: ignore

            resp = await handler(request)

        return resp

    def __call__(self) -> "Application":
        """gunicorn compatibility"""
        return self

    def __repr__(self) -> str:
        return "<Application 0x{:x}>".format(id(self))

    def __bool__(self) -> bool:
        # Always truthy, even when the state mapping is empty.
        return True
512
513
class CleanupError(RuntimeError):
    """Aggregate error raised when several cleanup steps fail.

    Constructed as ``CleanupError(message, errors)``; the individual failures
    are available through :attr:`exceptions`.
    """

    @property
    def exceptions(self) -> List[BaseException]:
        """The collected exceptions (the second constructor argument)."""
        collected = self.args[1]
        return collected
518
519
# Base class of CleanupContext: a FrozenList of callables that take the
# Application and return an async iterator.  The parameterized form is for
# static type checkers only; the bare FrozenList is used at runtime.
if TYPE_CHECKING:  # pragma: no cover
    _CleanupContextBase = FrozenList[Callable[[Application], AsyncIterator[None]]]
else:
    _CleanupContextBase = FrozenList
524
525
class CleanupContext(_CleanupContextBase):
    """Registry of cleanup-context callbacks.

    Each registered callback is called with the application and must return
    an async iterator that yields exactly once.  The code before the
    ``yield`` runs on startup and the code after it runs on cleanup.
    """

    def __init__(self) -> None:
        super().__init__()
        # Iterators started by _on_startup, in startup order; consumed by
        # _on_cleanup from the end.
        self._exits = []  # type: List[AsyncIterator[None]]

    async def _on_startup(self, app: Application) -> None:
        """Advance every callback to its first ``yield``, in registration order.

        Each started iterator is recorded so :meth:`_on_cleanup` can resume it.
        If a callback raises, the exception propagates and the remaining
        callbacks are not started; iterators started before it stay recorded.
        """
        for cb in self:
            it = cb(app).__aiter__()
            await it.__anext__()
            self._exits.append(it)

    async def _on_cleanup(self, app: Application) -> None:
        """Resume every started iterator, in reverse startup order.

        Each iterator is expected to finish (raise ``StopAsyncIteration``).
        ``Exception`` failures are collected so every iterator still gets
        resumed.  A single failure is re-raised as-is; several are wrapped
        in :class:`CleanupError`.

        Iterators are removed from the record as they are resumed.  A later
        startup/cleanup cycle therefore neither keeps finished iterators alive
        nor resumes them a second time.
        """
        errors = []  # type: List[Exception]
        # Popping from the end yields reverse startup order.  If the loop is
        # aborted by a non-``Exception`` error, only not-yet-resumed iterators
        # stay recorded.
        while self._exits:
            it = self._exits.pop()
            try:
                await it.__anext__()
            except StopAsyncIteration:
                pass
            except Exception as exc:
                errors.append(exc)
            else:
                # The iterator yielded again instead of finishing.
                errors.append(RuntimeError(f"{it!r} has more than one 'yield'"))
        if errors:
            if len(errors) == 1:
                raise errors[0]
            else:
                raise CleanupError("Multiple errors on cleanup stage", errors)
553