1import asyncio
2import asyncio.streams
3import traceback
4import warnings
5from collections import deque
6from contextlib import suppress
7from html import escape as html_escape
8from http import HTTPStatus
9from logging import Logger
10from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, cast
11
12import yarl
13
14from .abc import AbstractAccessLogger, AbstractStreamWriter
15from .base_protocol import BaseProtocol
16from .helpers import CeilTimeout, current_task
17from .http import (
18    HttpProcessingError,
19    HttpRequestParser,
20    HttpVersion10,
21    RawRequestMessage,
22    StreamWriter,
23)
24from .log import access_logger, server_logger
25from .streams import EMPTY_PAYLOAD, StreamReader
26from .tcp_helpers import tcp_keepalive
27from .web_exceptions import HTTPException
28from .web_log import AccessLogger
29from .web_request import BaseRequest
30from .web_response import Response, StreamResponse
31
32__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")
33
34if TYPE_CHECKING:  # pragma: no cover
35    from .web_server import Server
36
37
# Signature of the factory that builds the request object for one parsed
# message. Arguments, in order: the raw message, its payload stream, the
# protocol instance, the response writer and the task running the handler.
_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    BaseRequest,
]

# Signature of the coroutine function that turns a request into a response.
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]


# Placeholder message used to build a request object when the incoming data
# could not be parsed (see RequestHandler.handle_parse_error).
ERROR = RawRequestMessage(
    "UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/")
)
55
56
class RequestPayloadError(Exception):
    """Raised when the payload of a request cannot be parsed."""
59
60
class PayloadAccessError(Exception):
    """Raised when the payload is read after the response has been sent."""
63
64
65class RequestHandler(BaseProtocol):
66    """HTTP protocol implementation.
67
    RequestHandler handles incoming HTTP requests. It reads the request line,
    request headers and request payload and calls the handle_request() method.
    By default it always returns a 404 response.
71
    RequestHandler handles errors in incoming requests, such as a bad
    status line, bad headers or an incomplete payload. If any error occurs,
    the connection gets closed.
75
76    :param keepalive_timeout: number of seconds before closing
77                              keep-alive connection
78    :type keepalive_timeout: int or None
79
80    :param bool tcp_keepalive: TCP keep-alive is on, default is on
81
82    :param bool debug: enable debug mode
83
84    :param logger: custom logger object
85    :type logger: aiohttp.log.server_logger
86
87    :param access_log_class: custom class for access_logger
88    :type access_log_class: aiohttp.abc.AbstractAccessLogger
89
90    :param access_log: custom logging object
91    :type access_log: aiohttp.log.server_logger
92
93    :param str access_log_format: access log format string
94
    :param loop: event loop (required keyword-only argument)
96
97    :param int max_line_size: Optional maximum header line size
98
99    :param int max_field_size: Optional maximum header field size
100
101    :param int max_headers: Optional maximum header size
102
103    """
104
105    KEEPALIVE_RESCHEDULE_DELAY = 1
106
107    __slots__ = (
108        "_request_count",
109        "_keepalive",
110        "_manager",
111        "_request_handler",
112        "_request_factory",
113        "_tcp_keepalive",
114        "_keepalive_time",
115        "_keepalive_handle",
116        "_keepalive_timeout",
117        "_lingering_time",
118        "_messages",
119        "_message_tail",
120        "_waiter",
121        "_error_handler",
122        "_task_handler",
123        "_upgrade",
124        "_payload_parser",
125        "_request_parser",
126        "_reading_paused",
127        "logger",
128        "debug",
129        "access_log",
130        "access_logger",
131        "_close",
132        "_force_close",
133        "_current_request",
134    )
135
136    def __init__(
137        self,
138        manager: "Server",
139        *,
140        loop: asyncio.AbstractEventLoop,
141        keepalive_timeout: float = 75.0,  # NGINX default is 75 secs
142        tcp_keepalive: bool = True,
143        logger: Logger = server_logger,
144        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
145        access_log: Logger = access_logger,
146        access_log_format: str = AccessLogger.LOG_FORMAT,
147        debug: bool = False,
148        max_line_size: int = 8190,
149        max_headers: int = 32768,
150        max_field_size: int = 8190,
151        lingering_time: float = 10.0,
152        read_bufsize: int = 2 ** 16,
153    ):
154
155        super().__init__(loop)
156
157        self._request_count = 0
158        self._keepalive = False
159        self._current_request = None  # type: Optional[BaseRequest]
160        self._manager = manager  # type: Optional[Server]
161        self._request_handler = (
162            manager.request_handler
163        )  # type: Optional[_RequestHandler]
164        self._request_factory = (
165            manager.request_factory
166        )  # type: Optional[_RequestFactory]
167
168        self._tcp_keepalive = tcp_keepalive
169        # placeholder to be replaced on keepalive timeout setup
170        self._keepalive_time = 0.0
171        self._keepalive_handle = None  # type: Optional[asyncio.Handle]
172        self._keepalive_timeout = keepalive_timeout
173        self._lingering_time = float(lingering_time)
174
175        self._messages = deque()  # type: Any  # Python 3.5 has no typing.Deque
176        self._message_tail = b""
177
178        self._waiter = None  # type: Optional[asyncio.Future[None]]
179        self._error_handler = None  # type: Optional[asyncio.Task[None]]
180        self._task_handler = None  # type: Optional[asyncio.Task[None]]
181
182        self._upgrade = False
183        self._payload_parser = None  # type: Any
184        self._request_parser = HttpRequestParser(
185            self,
186            loop,
187            read_bufsize,
188            max_line_size=max_line_size,
189            max_field_size=max_field_size,
190            max_headers=max_headers,
191            payload_exception=RequestPayloadError,
192        )  # type: Optional[HttpRequestParser]
193
194        self.logger = logger
195        self.debug = debug
196        self.access_log = access_log
197        if access_log:
198            self.access_logger = access_log_class(
199                access_log, access_log_format
200            )  # type: Optional[AbstractAccessLogger]
201        else:
202            self.access_logger = None
203
204        self._close = False
205        self._force_close = False
206
207    def __repr__(self) -> str:
208        return "<{} {}>".format(
209            self.__class__.__name__,
210            "connected" if self.transport is not None else "disconnected",
211        )
212
213    @property
214    def keepalive_timeout(self) -> float:
215        return self._keepalive_timeout
216
    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Worker process is about to exit, we need cleanup everything and
        stop accepting requests. It is especially important for keep-alive
        connections.

        Marks the connection as force-closed, cancels the keep-alive timer
        and the idle waiter, then waits for the error handler and the main
        handler task. Whatever is still running afterwards is cancelled and
        the transport is closed.

        :param timeout: limit handed to ``CeilTimeout`` for the wait on the
                        handlers (presumably seconds -- confirm against
                        ``CeilTimeout``).
        """
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        # wake start() if it is idle, waiting for the next request
        if self._waiter:
            self._waiter.cancel()

        # wait for handlers
        # a timeout or a cancellation while waiting is not an error here;
        # the force-close below runs either way
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            with CeilTimeout(timeout, loop=self._loop):
                if self._error_handler is not None and not self._error_handler.done():
                    await self._error_handler

                # signal the request that is currently being handled, if any
                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    await self._task_handler

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        if self.transport is not None:
            self.transport.close()
            self.transport = None
248
249    def connection_made(self, transport: asyncio.BaseTransport) -> None:
250        super().connection_made(transport)
251
252        real_transport = cast(asyncio.Transport, transport)
253        if self._tcp_keepalive:
254            tcp_keepalive(real_transport)
255
256        self._task_handler = self._loop.create_task(self.start())
257        assert self._manager is not None
258        self._manager.connection_made(self, real_transport)
259
260    def connection_lost(self, exc: Optional[BaseException]) -> None:
261        if self._manager is None:
262            return
263        self._manager.connection_lost(self, exc)
264
265        super().connection_lost(exc)
266
267        self._manager = None
268        self._force_close = True
269        self._request_factory = None
270        self._request_handler = None
271        self._request_parser = None
272
273        if self._keepalive_handle is not None:
274            self._keepalive_handle.cancel()
275
276        if self._current_request is not None:
277            if exc is None:
278                exc = ConnectionResetError("Connection lost")
279            self._current_request._cancel(exc)
280
281        if self._error_handler is not None:
282            self._error_handler.cancel()
283        if self._task_handler is not None:
284            self._task_handler.cancel()
285        if self._waiter is not None:
286            self._waiter.cancel()
287
288        self._task_handler = None
289
290        if self._payload_parser is not None:
291            self._payload_parser.feed_eof()
292            self._payload_parser = None
293
294    def set_parser(self, parser: Any) -> None:
295        # Actual type is WebReader
296        assert self._payload_parser is None
297
298        self._payload_parser = parser
299
300        if self._message_tail:
301            self._payload_parser.feed_data(self._message_tail)
302            self._message_tail = b""
303
304    def eof_received(self) -> None:
305        pass
306
307    def data_received(self, data: bytes) -> None:
308        if self._force_close or self._close:
309            return
310        # parse http messages
311        if self._payload_parser is None and not self._upgrade:
312            assert self._request_parser is not None
313            try:
314                messages, upgraded, tail = self._request_parser.feed_data(data)
315            except HttpProcessingError as exc:
316                # something happened during parsing
317                self._error_handler = self._loop.create_task(
318                    self.handle_parse_error(
319                        StreamWriter(self, self._loop), 400, exc, exc.message
320                    )
321                )
322                self.close()
323            except Exception as exc:
324                # 500: internal error
325                self._error_handler = self._loop.create_task(
326                    self.handle_parse_error(StreamWriter(self, self._loop), 500, exc)
327                )
328                self.close()
329            else:
330                if messages:
331                    # sometimes the parser returns no messages
332                    for (msg, payload) in messages:
333                        self._request_count += 1
334                        self._messages.append((msg, payload))
335
336                    waiter = self._waiter
337                    if waiter is not None:
338                        if not waiter.done():
339                            # don't set result twice
340                            waiter.set_result(None)
341
342                self._upgrade = upgraded
343                if upgraded and tail:
344                    self._message_tail = tail
345
346        # no parser, just store
347        elif self._payload_parser is None and self._upgrade and data:
348            self._message_tail += data
349
350        # feed payload
351        elif data:
352            eof, tail = self._payload_parser.feed_data(data)
353            if eof:
354                self.close()
355
356    def keep_alive(self, val: bool) -> None:
357        """Set keep-alive connection mode.
358
359        :param bool val: new state.
360        """
361        self._keepalive = val
362        if self._keepalive_handle:
363            self._keepalive_handle.cancel()
364            self._keepalive_handle = None
365
366    def close(self) -> None:
367        """Stop accepting new pipelinig messages and close
368        connection when handlers done processing messages"""
369        self._close = True
370        if self._waiter:
371            self._waiter.cancel()
372
373    def force_close(self) -> None:
374        """Force close connection"""
375        self._force_close = True
376        if self._waiter:
377            self._waiter.cancel()
378        if self.transport is not None:
379            self.transport.close()
380            self.transport = None
381
382    def log_access(
383        self, request: BaseRequest, response: StreamResponse, time: float
384    ) -> None:
385        if self.access_logger is not None:
386            self.access_logger.log(request, response, self._loop.time() - time)
387
388    def log_debug(self, *args: Any, **kw: Any) -> None:
389        if self.debug:
390            self.logger.debug(*args, **kw)
391
392    def log_exception(self, *args: Any, **kw: Any) -> None:
393        self.logger.exception(*args, **kw)
394
395    def _process_keepalive(self) -> None:
396        if self._force_close or not self._keepalive:
397            return
398
399        next = self._keepalive_time + self._keepalive_timeout
400
401        # handler in idle state
402        if self._waiter:
403            if self._loop.time() > next:
404                self.force_close()
405                return
406
407        # not all request handlers are done,
408        # reschedule itself to next second
409        self._keepalive_handle = self._loop.call_later(
410            self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive
411        )
412
    async def _handle_request(
        self,
        request: BaseRequest,
        start_time: float,
    ) -> Tuple[StreamResponse, bool]:
        """Run the request handler and send its response.

        The request is exposed as ``_current_request`` while the handler
        runs. An HTTPException raised by the handler is converted into a
        Response, a timeout into a 504 and any other exception into a 500
        (both via handle_error()); cancellation is re-raised untouched.

        :param request: request object passed to the handler.
        :param start_time: value forwarded to finish_response() for access
                           logging.
        :return: ``(response, reset)`` where *reset* is the value returned
                 by finish_response().
        """
        assert self._request_handler is not None
        try:
            try:
                self._current_request = request
                resp = await self._request_handler(request)
            finally:
                # the request is no longer "current" once the handler returns
                self._current_request = None
        except HTTPException as exc:
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            # finish_response() is called inside each except block so that
            # it runs in the context of the exception being handled
            reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            # never turn a cancellation into an error response
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            reset = await self.finish_response(request, resp, start_time)
        else:
            reset = await self.finish_response(request, resp, start_time)

        return resp, reset
443
    async def start(self) -> None:
        """Process incoming requests for this connection.

        Waits for parsed messages queued by data_received(), builds a
        request object for each one with the request factory and runs
        _handle_request() for it in a separate task. After a response has
        been sent, any unread request payload is read and discarded for up
        to ``lingering_time`` seconds, and the keep-alive timer is armed
        when the response allows keep-alive. The loop ends when the
        connection is closed or force-closed, when keep-alive is off, or
        when the client disconnects prematurely.
        """
        loop = self._loop
        handler = self._task_handler
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                except asyncio.CancelledError:
                    # the waiter is cancelled by close(), force_close(),
                    # shutdown() and connection_lost()
                    break
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            # loop time at which handling of this request started
            start = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            request = self._request_factory(message, payload, self, writer, handler)
            try:
                # a new task is used for copy context vars (#3406)
                task = self._loop.create_task(self._handle_request(request, start))
                try:
                    resp, reset = await task
                except (asyncio.CancelledError, ConnectionError):
                    self.log_debug("Ignored premature client disconnection")
                    break
                # Deprecation warning (See #2415)
                if getattr(resp, "__http_exception__", False):
                    warnings.warn(
                        "returning HTTPException object is deprecated "
                        "(#2415) and will be removed, "
                        "please raise the exception instead",
                        DeprecationWarning,
                    )

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                # finish_response() reported a connection error while sending
                if reset:
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    if not self._force_close and lingering_time:
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        # drain what the handler left unread, but never for
                        # longer than lingering_time in total
                        with suppress(asyncio.TimeoutError, asyncio.CancelledError):
                            while not payload.is_eof() and now < end_t:
                                with CeilTimeout(end_t - now, loop=loop):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                # any later access to this payload raises PayloadAccessError
                payload.set_exception(PayloadAccessError())

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection ")
                break
            except RuntimeError as exc:
                # runtime errors are only logged in debug mode
                if self.debug:
                    self.log_exception("Unhandled runtime exception", exc_info=exc)
                self.force_close()
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            finally:
                # runs after every request, including the ``break`` paths above
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # start keep-alive timer
                        if keepalive_timeout is not None:
                            now = self._loop.time()
                            self._keepalive_time = now
                            # an already scheduled timer is left in place;
                            # it compares against the refreshed
                            # _keepalive_time when it fires
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    now + keepalive_timeout, self._process_keepalive
                                )
                    else:
                        # keep-alive is off or close() was requested
                        break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            # while an error handler exists, closing the transport is left
            # to handle_parse_error()
            if self.transport is not None and self._error_handler is None:
                self.transport.close()
563
564    async def finish_response(
565        self, request: BaseRequest, resp: StreamResponse, start_time: float
566    ) -> bool:
567        """
568        Prepare the response and write_eof, then log access. This has to
569        be called within the context of any exception so the access logger
570        can get exception information. Returns True if the client disconnects
571        prematurely.
572        """
573        if self._request_parser is not None:
574            self._request_parser.set_upgraded(False)
575            self._upgrade = False
576            if self._message_tail:
577                self._request_parser.feed_data(self._message_tail)
578                self._message_tail = b""
579        try:
580            prepare_meth = resp.prepare
581        except AttributeError:
582            if resp is None:
583                raise RuntimeError("Missing return " "statement on request handler")
584            else:
585                raise RuntimeError(
586                    "Web-handler should return "
587                    "a response instance, "
588                    "got {!r}".format(resp)
589                )
590        try:
591            await prepare_meth(request)
592            await resp.write_eof()
593        except ConnectionError:
594            self.log_access(request, resp, start_time)
595            return True
596        else:
597            self.log_access(request, resp, start_time)
598            return False
599
600    def handle_error(
601        self,
602        request: BaseRequest,
603        status: int = 500,
604        exc: Optional[BaseException] = None,
605        message: Optional[str] = None,
606    ) -> StreamResponse:
607        """Handle errors.
608
609        Returns HTTP response with specific status code. Logs additional
610        information. It always closes current connection."""
611        self.log_exception("Error handling request", exc_info=exc)
612
613        ct = "text/plain"
614        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
615            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
616            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
617            tb = None
618            if self.debug:
619                with suppress(Exception):
620                    tb = traceback.format_exc()
621
622            if "text/html" in request.headers.get("Accept", ""):
623                if tb:
624                    tb = html_escape(tb)
625                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
626                message = (
627                    "<html><head>"
628                    "<title>{title}</title>"
629                    "</head><body>\n<h1>{title}</h1>"
630                    "\n{msg}\n</body></html>\n"
631                ).format(title=title, msg=msg)
632                ct = "text/html"
633            else:
634                if tb:
635                    msg = tb
636                message = title + "\n\n" + msg
637
638        resp = Response(status=status, text=message, content_type=ct)
639        resp.force_close()
640
641        # some data already got sent, connection is broken
642        if request.writer.output_size > 0 or self.transport is None:
643            self.force_close()
644
645        return resp
646
647    async def handle_parse_error(
648        self,
649        writer: AbstractStreamWriter,
650        status: int,
651        exc: Optional[BaseException] = None,
652        message: Optional[str] = None,
653    ) -> None:
654        task = current_task()
655        assert task is not None
656        request = BaseRequest(
657            ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop  # type: ignore
658        )
659
660        resp = self.handle_error(request, status, exc, message)
661        await resp.prepare(request)
662        await resp.write_eof()
663
664        if self.transport is not None:
665            self.transport.close()
666
667        self._error_handler = None
668