import asyncio
import asyncio.streams
import traceback
import warnings
from collections import deque
from contextlib import suppress
from html import escape as html_escape
from http import HTTPStatus
from logging import Logger
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, cast

import yarl

from .abc import AbstractAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import CeilTimeout, current_task
from .http import (
    HttpProcessingError,
    HttpRequestParser,
    HttpVersion10,
    RawRequestMessage,
    StreamWriter,
)
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
from .web_exceptions import HTTPException
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse

__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")

if TYPE_CHECKING:  # pragma: no cover
    from .web_server import Server


_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    BaseRequest,
]

_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]


# Sentinel request message used when parsing failed before a real request
# line could be read; handle_parse_error() builds a BaseRequest around it.
ERROR = RawRequestMessage(
    "UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/")
)


class RequestPayloadError(Exception):
    """Payload parsing error."""


class PayloadAccessError(Exception):
    """Payload was accessed after response was sent."""


class RequestHandler(BaseProtocol):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP request. It reads request line,
    request headers and request payload and calls handle_request() method.
    By default it always returns with 404 response.

    RequestHandler handles errors in incoming request, like bad
    status line, bad headers or incomplete payload. If any error occurs,
    connection gets closed.

    :param keepalive_timeout: number of seconds before closing
                              keep-alive connection
    :type keepalive_timeout: int or None

    :param bool tcp_keepalive: TCP keep-alive is on, default is on

    :param bool debug: enable debug mode

    :param logger: custom logger object
    :type logger: aiohttp.log.server_logger

    :param access_log_class: custom class for access_logger
    :type access_log_class: aiohttp.abc.AbstractAccessLogger

    :param access_log: custom logging object
    :type access_log: aiohttp.log.server_logger

    :param str access_log_format: access log format string

    :param loop: Optional event loop

    :param int max_line_size: Optional maximum header line size

    :param int max_field_size: Optional maximum header field size

    :param int max_headers: Optional maximum header size

    """

    KEEPALIVE_RESCHEDULE_DELAY = 1

    __slots__ = (
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_keepalive_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_waiter",
        "_error_handler",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_request_parser",
        "_reading_paused",
        "logger",
        "debug",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
    )

    def __init__(
        self,
        manager: "Server",
        *,
        loop: asyncio.AbstractEventLoop,
        keepalive_timeout: float = 75.0,  # NGINX default is 75 secs
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        access_log: Logger = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        debug: bool = False,
        max_line_size: int = 8190,
        max_headers: int = 32768,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2 ** 16,
    ):

        super().__init__(loop)

        self._request_count = 0
        self._keepalive = False
        self._current_request = None  # type: Optional[BaseRequest]
        self._manager = manager  # type: Optional[Server]
        self._request_handler = (
            manager.request_handler
        )  # type: Optional[_RequestHandler]
        self._request_factory = (
            manager.request_factory
        )  # type: Optional[_RequestFactory]

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._keepalive_time = 0.0
        self._keepalive_handle = None  # type: Optional[asyncio.Handle]
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        self._messages = deque()  # type: Any  # Python 3.5 has no typing.Deque
        self._message_tail = b""

        self._waiter = None  # type: Optional[asyncio.Future[None]]
        self._error_handler = None  # type: Optional[asyncio.Task[None]]
        self._task_handler = None  # type: Optional[asyncio.Task[None]]

        self._upgrade = False
        self._payload_parser = None  # type: Any
        self._request_parser = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            max_headers=max_headers,
            payload_exception=RequestPayloadError,
        )  # type: Optional[HttpRequestParser]

        self.logger = logger
        self.debug = debug
        self.access_log = access_log
        if access_log:
            self.access_logger = access_log_class(
                access_log, access_log_format
            )  # type: Optional[AbstractAccessLogger]
        else:
            self.access_logger = None

        self._close = False
        self._force_close = False

    def __repr__(self) -> str:
        return "<{} {}>".format(
            self.__class__.__name__,
            "connected" if self.transport is not None else "disconnected",
        )

    @property
    def keepalive_timeout(self) -> float:
        return self._keepalive_timeout

    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Worker process is about to exit, we need cleanup everything and
        stop accepting requests. It is especially important for keep-alive
        connections."""
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._waiter:
            self._waiter.cancel()

        # wait for handlers
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            with CeilTimeout(timeout, loop=self._loop):
                if self._error_handler is not None and not self._error_handler.done():
                    await self._error_handler

                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    await self._task_handler

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        self._task_handler = self._loop.create_task(self.start())
        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        if self._manager is None:
            return
        self._manager.connection_lost(self, exc)

        super().connection_lost(exc)

        # drop references so the protocol object can be garbage collected
        self._manager = None
        self._force_close = True
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        if self._error_handler is not None:
            self._error_handler.cancel()
        if self._task_handler is not None:
            self._task_handler.cancel()
        if self._waiter is not None:
            self._waiter.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser

        # replay any bytes that arrived before the parser was installed
        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b""

    def eof_received(self) -> None:
        pass

    def data_received(self, data: bytes) -> None:
        if self._force_close or self._close:
            return
        # parse http messages
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                # something happened during parsing
                self._error_handler = self._loop.create_task(
                    self.handle_parse_error(
                        StreamWriter(self, self._loop), 400, exc, exc.message
                    )
                )
                self.close()
            except Exception as exc:
                # 500: internal error
                self._error_handler = self._loop.create_task(
                    self.handle_parse_error(StreamWriter(self, self._loop), 500, exc)
                )
                self.close()
            else:
                if messages:
                    # sometimes the parser returns no messages
                    for (msg, payload) in messages:
                        self._request_count += 1
                        self._messages.append((msg, payload))

                    waiter = self._waiter
                    if waiter is not None:
                        if not waiter.done():
                            # don't set result twice
                            waiter.set_result(None)

                self._upgrade = upgraded
                if upgraded and tail:
                    self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        :param bool val: new state.
        """
        self._keepalive = val
        if self._keepalive_handle:
            self._keepalive_handle.cancel()
            self._keepalive_handle = None

    def close(self) -> None:
        """Stop accepting new pipelinig messages and close
        connection when handlers done processing messages"""
        self._close = True
        if self._waiter:
            self._waiter.cancel()

    def force_close(self) -> None:
        """Force close connection"""
        self._force_close = True
        if self._waiter:
            self._waiter.cancel()
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def log_access(
        self, request: BaseRequest, response: StreamResponse, time: float
    ) -> None:
        if self.access_logger is not None:
            self.access_logger.log(request, response, self._loop.time() - time)

    def log_debug(self, *args: Any, **kw: Any) -> None:
        if self.debug:
            self.logger.debug(*args, **kw)

    def log_exception(self, *args: Any, **kw: Any) -> None:
        self.logger.exception(*args, **kw)

    def _process_keepalive(self) -> None:
        if self._force_close or not self._keepalive:
            return

        next_t = self._keepalive_time + self._keepalive_timeout

        # handler in idle state
        if self._waiter:
            if self._loop.time() > next_t:
                self.force_close()
                return

        # not all request handlers are done,
        # reschedule itself to next second
        self._keepalive_handle = self._loop.call_later(
            self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive
        )

    async def _handle_request(
        self,
        request: BaseRequest,
        start_time: float,
    ) -> Tuple[StreamResponse, bool]:
        assert self._request_handler is not None
        try:
            try:
                self._current_request = request
                resp = await self._request_handler(request)
            finally:
                self._current_request = None
        except HTTPException as exc:
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            reset = await self.finish_response(request, resp, start_time)
        else:
            reset = await self.finish_response(request, resp, start_time)

        return resp, reset

    async def start(self) -> None:
        """Process incoming request.

        It reads request line, request headers and request payload, then
        calls handle_request() method. Subclass has to override
        handle_request(). start() handles various exceptions in request
        or response handling. Connection is being closed always unless
        keep_alive(True) specified.
        """
        loop = self._loop
        handler = self._task_handler
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                except asyncio.CancelledError:
                    break
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            start = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            request = self._request_factory(message, payload, self, writer, handler)
            try:
                # a new task is used for copy context vars (#3406)
                task = self._loop.create_task(self._handle_request(request, start))
                try:
                    resp, reset = await task
                except (asyncio.CancelledError, ConnectionError):
                    self.log_debug("Ignored premature client disconnection")
                    break
                # Deprecation warning (See #2415)
                if getattr(resp, "__http_exception__", False):
                    warnings.warn(
                        "returning HTTPException object is deprecated "
                        "(#2415) and will be removed, "
                        "please raise the exception instead",
                        DeprecationWarning,
                    )

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                if reset:
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    if not self._force_close and lingering_time:
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        with suppress(asyncio.TimeoutError, asyncio.CancelledError):
                            while not payload.is_eof() and now < end_t:
                                with CeilTimeout(end_t - now, loop=loop):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                payload.set_exception(PayloadAccessError())

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection ")
                break
            except RuntimeError as exc:
                if self.debug:
                    self.log_exception("Unhandled runtime exception", exc_info=exc)
                self.force_close()
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            finally:
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # start keep-alive timer
                        if keepalive_timeout is not None:
                            now = self._loop.time()
                            self._keepalive_time = now
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    now + keepalive_timeout, self._process_keepalive
                                )
                    else:
                        break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None and self._error_handler is None:
                self.transport.close()

    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: float
    ) -> bool:
        """
        Prepare the response and write_eof, then log access. This has to
        be called within the context of any exception so the access logger
        can get exception information. Returns True if the client disconnects
        prematurely.
        """
        if self._request_parser is not None:
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            if resp is None:
                raise RuntimeError("Missing return statement on request handler")
            else:
                raise RuntimeError(
                    "Web-handler should return "
                    "a response instance, "
                    "got {!r}".format(resp)
                )
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            self.log_access(request, resp, start_time)
            return True
        else:
            self.log_access(request, resp, start_time)
            return False

    def handle_error(
        self,
        request: BaseRequest,
        status: int = 500,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> StreamResponse:
        """Handle errors.

        Returns HTTP response with specific status code. Logs additional
        information. It always closes current connection."""
        self.log_exception("Error handling request", exc_info=exc)

        ct = "text/plain"
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            if self.debug:
                with suppress(Exception):
                    tb = traceback.format_exc()

            if "text/html" in request.headers.get("Accept", ""):
                if tb:
                    tb = html_escape(tb)
                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
                message = (
                    "<html><head>"
                    "<title>{title}</title>"
                    "</head><body>\n<h1>{title}</h1>"
                    "\n{msg}\n</body></html>\n"
                ).format(title=title, msg=msg)
                ct = "text/html"
            else:
                if tb:
                    msg = tb
                message = title + "\n\n" + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        # some data already got sent, connection is broken
        if request.writer.output_size > 0 or self.transport is None:
            self.force_close()

        return resp

    async def handle_parse_error(
        self,
        writer: AbstractStreamWriter,
        status: int,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> None:
        task = current_task()
        assert task is not None
        request = BaseRequest(
            ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop  # type: ignore
        )

        resp = self.handle_error(request, status, exc, message)
        await resp.prepare(request)
        await resp.write_eof()

        if self.transport is not None:
            self.transport.close()

        self._error_handler = None