"""
Generic Asynchronous Message-based Protocol Support

This module provides a generic framework for sending and receiving
messages over an asyncio stream. `AsyncProtocol` is an abstract class
that implements the core mechanisms of a simple send/receive protocol,
and is designed to be extended.

In this package, it is used as the implementation for the `QMPClient`
class.
"""

# It's all the docstrings ... ! It's long for a good reason ^_^;
# pylint: disable=too-many-lines

import asyncio
from asyncio import StreamReader, StreamWriter
from enum import Enum
from functools import wraps
import logging
import socket
from ssl import SSLContext
from typing import (
    Any,
    Awaitable,
    Callable,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from .error import QMPError
from .util import (
    bottom_half,
    create_task,
    exception_summary,
    flush,
    is_closing,
    pretty_traceback,
    upper_half,
    wait_closed,
)


T = TypeVar('T')
_U = TypeVar('_U')
_TaskFN = Callable[[], Awaitable[None]]  # aka ``async def func() -> None``

InternetAddrT = Tuple[str, int]
UnixAddrT = str
SocketAddrT = Union[UnixAddrT, InternetAddrT]


class Runstate(Enum):
    """Protocol session runstate."""

    #: Fully quiesced and disconnected.
    IDLE = 0
    #: In the process of connecting or establishing a session.
    CONNECTING = 1
    #: Fully connected and active session.
    RUNNING = 2
    #: In the process of disconnecting.
    #: Runstate may be returned to `IDLE` by calling `disconnect()`.
    DISCONNECTING = 3


class ConnectError(QMPError):
    """
    Raised when the initial connection process has failed.

    This Exception always wraps a "root cause" exception that can be
    interrogated for additional information.

    :param error_message: Human-readable string describing the error.
    :param exc: The root-cause exception.
    """
    def __init__(self, error_message: str, exc: Exception):
        super().__init__(error_message)
        #: Human-readable error string
        self.error_message: str = error_message
        #: Wrapped root cause exception
        self.exc: Exception = exc

    def __str__(self) -> str:
        cause = str(self.exc)
        if not cause:
            # If there's no error string, use the exception name.
            cause = exception_summary(self.exc)
        return f"{self.error_message}: {cause}"


class StateError(QMPError):
    """
    An API command (connect, execute, etc) was issued at an inappropriate time.

    This error is raised when a command like
    :py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
    time.

    :param error_message: Human-readable string describing the state violation.
    :param state: The actual `Runstate` seen at the time of the violation.
    :param required: The `Runstate` required to process this command.
    """
    def __init__(self, error_message: str,
                 state: Runstate, required: Runstate):
        super().__init__(error_message)
        self.error_message = error_message
        self.state = state
        self.required = required


F = TypeVar('F', bound=Callable[..., Any])  # pylint: disable=invalid-name


# Don't Panic.
def require(required_state: Runstate) -> Callable[[F], F]:
    """
    Decorator: protect a method so it can only be run in a certain `Runstate`.

    :param required_state: The `Runstate` required to invoke this method.
    :raise StateError: When the required `Runstate` is not met.
    """
    def _decorator(func: F) -> F:
        # _decorator is the decorator that is built by calling the
        # require() decorator factory; e.g.:
        #
        # @require(Runstate.IDLE) def foo(): ...
        # will replace 'foo' with the result of '_decorator(foo)'.

        @wraps(func)
        def _wrapper(proto: 'AsyncProtocol[Any]',
                     *args: Any, **kwargs: Any) -> Any:
            # _wrapper is the function that gets executed prior to the
            # decorated method.

            name = type(proto).__name__

            if proto.runstate != required_state:
                if proto.runstate == Runstate.CONNECTING:
                    emsg = f"{name} is currently connecting."
                elif proto.runstate == Runstate.DISCONNECTING:
                    emsg = (f"{name} is disconnecting."
                            " Call disconnect() to return to IDLE state.")
                elif proto.runstate == Runstate.RUNNING:
                    emsg = f"{name} is already connected and running."
                elif proto.runstate == Runstate.IDLE:
                    emsg = f"{name} is disconnected and idle."
                else:
                    assert False
                raise StateError(emsg, proto.runstate, required_state)
            # No StateError, so call the wrapped method.
            return func(proto, *args, **kwargs)

        # Return the decorated method;
        # Transforming Func to Decorated[Func].
        return cast(F, _wrapper)

    # Return the decorator instance from the decorator factory. Phew!
    return _decorator


class AsyncProtocol(Generic[T]):
    """
    AsyncProtocol implements a generic async message-based protocol.

    This protocol assumes the basic unit of information transfer between
    client and server is a "message", the details of which are left up
    to the implementation. It assumes the sending and receiving of these
    messages is full-duplex and not necessarily correlated; i.e. it
    supports asynchronous inbound messages.

    It is designed to be extended by a specific protocol which provides
    the implementations for how to read and send messages. These must be
    defined in `_do_recv()` and `_do_send()`, respectively.

    Other callbacks have a default implementation, but are intended to be
    either extended or overridden:

     - `_establish_session`:
         The base implementation starts the reader/writer tasks.
         A protocol implementation can override this call, inserting
         actions to be taken prior to starting the reader/writer tasks
         before the super() call; actions needing to occur afterwards
         can be written after the super() call.
     - `_on_message`:
         Actions to be performed when a message is received.
     - `_cb_outbound`:
         Logging/Filtering hook for all outbound messages.
     - `_cb_inbound`:
         Logging/Filtering hook for all inbound messages.
         This hook runs *before* `_on_message()`.

    :param name:
        Name used for logging messages, if any. By default, messages
        will log to 'qemu.qmp.protocol', but each individual connection
        can be given its own logger by giving it a name; messages will
        then log to 'qemu.qmp.protocol.${name}'.
    """
    # pylint: disable=too-many-instance-attributes

    #: Logger object for debugging messages from this connection.
    logger = logging.getLogger(__name__)

    # Maximum allowable size of read buffer
    _limit = (64 * 1024)

    # -------------------------
    # Section: Public interface
    # -------------------------

    def __init__(self, name: Optional[str] = None) -> None:
        #: The nickname for this connection, if any.
        self.name: Optional[str] = name
        if self.name is not None:
            self.logger = self.logger.getChild(self.name)

        # stream I/O
        self._reader: Optional[StreamReader] = None
        self._writer: Optional[StreamWriter] = None

        # Outbound Message queue
        self._outgoing: asyncio.Queue[T]

        # Special, long-running tasks:
        self._reader_task: Optional[asyncio.Future[None]] = None
        self._writer_task: Optional[asyncio.Future[None]] = None

        # Aggregate of the above two tasks, used for Exception management.
        self._bh_tasks: Optional[asyncio.Future[Tuple[None, None]]] = None

        #: Disconnect task. The disconnect implementation runs in a task
        #: so that asynchronous disconnects (initiated by the
        #: reader/writer) are allowed to wait for the reader/writers to
        #: exit.
        self._dc_task: Optional[asyncio.Future[None]] = None

        self._runstate = Runstate.IDLE
        self._runstate_changed: Optional[asyncio.Event] = None

        # Server state for start_server() and _incoming()
        self._server: Optional[asyncio.AbstractServer] = None
        self._accepted: Optional[asyncio.Event] = None

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        tokens = []
        if self.name is not None:
            tokens.append(f"name={self.name!r}")
        tokens.append(f"runstate={self.runstate.name}")
        return f"<{cls_name} {' '.join(tokens)}>"

    @property  # @upper_half
    def runstate(self) -> Runstate:
        """The current `Runstate` of the connection."""
        return self._runstate

    @upper_half
    async def runstate_changed(self) -> Runstate:
        """
        Wait for the `runstate` to change, then return that runstate.
        """
        await self._runstate_event.wait()
        return self.runstate

    @upper_half
    @require(Runstate.IDLE)
    async def start_server_and_accept(
            self, address: SocketAddrT,
            ssl: Optional[SSLContext] = None
    ) -> None:
        """
        Accept a connection and begin processing message queues.

        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
        This method is precisely equivalent to calling `start_server()`
        followed by `accept()`.

        :param address:
            Address to listen on; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise StateError: When the `Runstate` is not `IDLE`.
        :raise ConnectError:
            When a connection or session cannot be established.

            This exception will wrap a more concrete one. In most cases,
            the wrapped exception will be `OSError` or `EOFError`. If a
            protocol-level failure occurs while establishing a new
            session, the wrapped error may also be an `QMPError`.
        """
        await self.start_server(address, ssl)
        await self.accept()
        assert self.runstate == Runstate.RUNNING

    @upper_half
    @require(Runstate.IDLE)
    async def open_with_socket(self, sock: socket.socket) -> None:
        """
        Start connection with given socket.

        The resulting stream uses the same read-buffer limit
        (``_limit``) as connections made via `connect()` or
        `start_server()`.

        :param sock: A socket.

        :raise StateError: When the `Runstate` is not `IDLE`.
        """
        # Pass the read-buffer limit explicitly, exactly like
        # _do_connect() and _do_start_server() do. Without it, a subclass
        # that raises _limit would silently get asyncio's default limit
        # on this code path only.
        self._reader, self._writer = await asyncio.open_connection(
            sock=sock,
            limit=self._limit,
        )
        self._set_state(Runstate.CONNECTING)

    @upper_half
    @require(Runstate.IDLE)
    async def start_server(self, address: SocketAddrT,
                           ssl: Optional[SSLContext] = None) -> None:
        """
        Start listening for an incoming connection, but do not wait for a peer.

        This method starts listening for an incoming connection, but
        does not block waiting for a peer. This call will return
        immediately after binding and listening on a socket. A later
        call to `accept()` must be made in order to finalize the
        incoming connection.

        :param address:
            Address to listen on; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise StateError: When the `Runstate` is not `IDLE`.
        :raise ConnectError:
            When the server could not start listening on this address.

            This exception will wrap a more concrete one. In most cases,
            the wrapped exception will be `OSError`.
        """
        await self._session_guard(
            self._do_start_server(address, ssl),
            'Failed to establish connection')
        assert self.runstate == Runstate.CONNECTING

    @upper_half
    @require(Runstate.CONNECTING)
    async def accept(self) -> None:
        """
        Accept an incoming connection and begin processing message queues.

        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.

        :raise StateError: When the `Runstate` is not `CONNECTING`.
        :raise QMPError: When `start_server()` was not called yet.
        :raise ConnectError:
            When a connection or session cannot be established.

            This exception will wrap a more concrete one. In most cases,
            the wrapped exception will be `OSError` or `EOFError`. If a
            protocol-level failure occurs while establishing a new
            session, the wrapped error may also be an `QMPError`.
        """
        if not self._reader:
            if self._accepted is None:
                raise QMPError("Cannot call accept() before start_server().")
            await self._session_guard(
                self._do_accept(),
                'Failed to establish connection')
        await self._session_guard(
            self._establish_session(),
            'Failed to establish session')
        assert self.runstate == Runstate.RUNNING

    @upper_half
    @require(Runstate.IDLE)
    async def connect(self, address: SocketAddrT,
                      ssl: Optional[SSLContext] = None) -> None:
        """
        Connect to the server and begin processing message queues.

        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.

        :param address:
            Address to connect to; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise StateError: When the `Runstate` is not `IDLE`.
        :raise ConnectError:
            When a connection or session cannot be established.

            This exception will wrap a more concrete one. In most cases,
            the wrapped exception will be `OSError` or `EOFError`. If a
            protocol-level failure occurs while establishing a new
            session, the wrapped error may also be an `QMPError`.
        """
        await self._session_guard(
            self._do_connect(address, ssl),
            'Failed to establish connection')
        await self._session_guard(
            self._establish_session(),
            'Failed to establish session')
        assert self.runstate == Runstate.RUNNING

    @upper_half
    async def disconnect(self) -> None:
        """
        Disconnect and wait for all tasks to fully stop.

        If there was an exception that caused the reader/writers to
        terminate prematurely, it will be raised here.

        :raise Exception: When the reader or writer terminate unexpectedly.
        """
        self.logger.debug("disconnect() called.")
        self._schedule_disconnect()
        await self._wait_disconnect()

    # --------------------------
    # Section: Session machinery
    # --------------------------

    async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
        """
        Async guard function used to roll back to `IDLE` on any error.

        On any Exception, the state machine will be reset back to
        `IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
        `BaseException` events will be left alone (This includes
        asyncio.CancelledError, even prior to Python 3.8).

        :param coro: The awaitable to run inside the guarded block.
        :param emsg:
            Human-readable string describing what connection phase failed.

        :raise BaseException:
            When `BaseException` occurs in the guarded block.
        :raise ConnectError:
            When any other error is encountered in the guarded block.
        """
        # Note: After Python 3.6 support is removed, this should be an
        # @asynccontextmanager instead of accepting a callback.
        try:
            await coro
        except BaseException as err:
            self.logger.error("%s: %s", emsg, exception_summary(err))
            self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
            try:
                # Reset the runstate back to IDLE.
                await self.disconnect()
            except BaseException:  # pylint: disable=broad-except
                # We don't expect any Exceptions from the disconnect function
                # here, because we failed to connect in the first place.
                # The disconnect() function is intended to perform
                # only cannot-fail cleanup here, but you never know.
                # (A distinct name is used so the 'emsg' parameter is not
                # shadowed.)
                bug_msg = (
                    "Unexpected bottom half exception. "
                    "This is a bug in the QMP library. "
                    "Please report it to <qemu-devel@nongnu.org> and "
                    "CC: John Snow <jsnow@redhat.com>."
                )
                self.logger.critical("%s:\n%s\n", bug_msg, pretty_traceback())
                raise

            # CancelledError is an Exception with special semantic meaning;
            # We do NOT want to wrap it up under ConnectError.
            # NB: Before Python 3.8, CancelledError subclasses Exception
            # rather than BaseException directly, so check for it
            # explicitly before the generic Exception test below.
            if isinstance(err, asyncio.CancelledError):
                raise

            # Any other kind of error can be treated as some kind of connection
            # failure broadly. Inspect the 'exc' field to explore the root
            # cause in greater detail.
            if isinstance(err, Exception):
                raise ConnectError(emsg, err) from err

            # Raise BaseExceptions un-wrapped, they're more important.
            raise

    @property
    def _runstate_event(self) -> asyncio.Event:
        # asyncio.Event() objects should not be created prior to entrance into
        # an event loop, so we can ensure we create it in the correct context.
        # Create it on-demand *only* at the behest of an 'async def' method.
        if not self._runstate_changed:
            self._runstate_changed = asyncio.Event()
        return self._runstate_changed

    @upper_half
    @bottom_half
    def _set_state(self, state: Runstate) -> None:
        """
        Change the `Runstate` of the protocol connection.

        Signals the `runstate_changed` event.
        """
        if state == self._runstate:
            return

        self.logger.debug("Transitioning from '%s' to '%s'.",
                          str(self._runstate), str(state))
        self._runstate = state
        self._runstate_event.set()
        self._runstate_event.clear()

    @bottom_half
    async def _stop_server(self) -> None:
        """
        Stop listening for / accepting new incoming connections.
        """
        if self._server is None:
            return

        try:
            self.logger.debug("Stopping server.")
            self._server.close()
            await self._server.wait_closed()
            self.logger.debug("Server stopped.")
        finally:
            self._server = None

    @bottom_half  # However, it does not run from the R/W tasks.
    async def _incoming(self,
                        reader: asyncio.StreamReader,
                        writer: asyncio.StreamWriter) -> None:
        """
        Accept an incoming connection and signal the upper_half.

        This method does the minimum necessary to accept a single
        incoming connection. It signals back to the upper_half ASAP so
        that any errors during session initialization can occur
        naturally in the caller's stack.

        :param reader: Incoming `asyncio.StreamReader`
        :param writer: Incoming `asyncio.StreamWriter`
        """
        peer = writer.get_extra_info('peername', 'Unknown peer')
        self.logger.debug("Incoming connection from %s", peer)

        if self._reader or self._writer:
            # Sadly, we can have more than one pending connection
            # because of https://bugs.python.org/issue46715
            # Close any extra connections we don't actually want.
            self.logger.warning("Extraneous connection inadvertently accepted")
            writer.close()
            return

        # A connection has been accepted; stop listening for new ones.
        assert self._accepted is not None
        await self._stop_server()
        self._reader, self._writer = (reader, writer)
        self._accepted.set()

    @upper_half
    async def _do_start_server(self, address: SocketAddrT,
                               ssl: Optional[SSLContext] = None) -> None:
        """
        Start listening for an incoming connection, but do not wait for a peer.

        This method starts listening for an incoming connection, but does not
        block waiting for a peer. This call will return immediately after
        binding and listening to a socket. A later call to accept() must be
        made in order to finalize the incoming connection.

        :param address:
            Address to listen on; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise OSError: For stream-related errors.
        """
        assert self.runstate == Runstate.IDLE
        self._set_state(Runstate.CONNECTING)

        self.logger.debug("Awaiting connection on %s ...", address)
        self._accepted = asyncio.Event()

        if isinstance(address, tuple):
            coro = asyncio.start_server(
                self._incoming,
                host=address[0],
                port=address[1],
                ssl=ssl,
                backlog=1,
                limit=self._limit,
            )
        else:
            coro = asyncio.start_unix_server(
                self._incoming,
                path=address,
                ssl=ssl,
                backlog=1,
                limit=self._limit,
            )

        # Allow runstate watchers to witness 'CONNECTING' state; some
        # failures in the streaming layer are synchronous and will not
        # otherwise yield.
        await asyncio.sleep(0)

        # This will start the server (bind(2), listen(2)). It will also
        # call accept(2) if we yield, but we don't block on that here.
        self._server = await coro
        self.logger.debug("Server listening on %s", address)

    @upper_half
    async def _do_accept(self) -> None:
        """
        Wait for and accept an incoming connection.

        Requires that we have not yet accepted an incoming connection
        from the upper_half, but it's OK if the server is no longer
        running because the bottom_half has already accepted the
        connection.
        """
        assert self._accepted is not None
        await self._accepted.wait()
        assert self._server is None
        self._accepted = None

        self.logger.debug("Connection accepted.")

    @upper_half
    async def _do_connect(self, address: SocketAddrT,
                          ssl: Optional[SSLContext] = None) -> None:
        """
        Acting as the transport client, initiate a connection to a server.

        :param address:
            Address to connect to; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise OSError: For stream-related errors.
        """
        assert self.runstate == Runstate.IDLE
        self._set_state(Runstate.CONNECTING)

        # Allow runstate watchers to witness 'CONNECTING' state; some
        # failures in the streaming layer are synchronous and will not
        # otherwise yield.
        await asyncio.sleep(0)

        self.logger.debug("Connecting to %s ...", address)

        if isinstance(address, tuple):
            connect = asyncio.open_connection(
                address[0],
                address[1],
                ssl=ssl,
                limit=self._limit,
            )
        else:
            connect = asyncio.open_unix_connection(
                path=address,
                ssl=ssl,
                limit=self._limit,
            )
        self._reader, self._writer = await connect

        self.logger.debug("Connected.")

    @upper_half
    async def _establish_session(self) -> None:
        """
        Establish a new session.

        Starts the readers/writer tasks; subclasses may perform their
        own negotiations here. The Runstate will be RUNNING upon
        successful conclusion.
        """
        assert self.runstate == Runstate.CONNECTING

        self._outgoing = asyncio.Queue()

        reader_coro = self._bh_loop_forever(self._bh_recv_message, 'Reader')
        writer_coro = self._bh_loop_forever(self._bh_send_message, 'Writer')

        self._reader_task = create_task(reader_coro)
        self._writer_task = create_task(writer_coro)

        self._bh_tasks = asyncio.gather(
            self._reader_task,
            self._writer_task,
        )

        self._set_state(Runstate.RUNNING)
        await asyncio.sleep(0)  # Allow runstate_event to process

    @upper_half
    @bottom_half
    def _schedule_disconnect(self) -> None:
        """
        Initiate a disconnect; idempotent.

        This method is used both in the upper-half as a direct
        consequence of `disconnect()`, and in the bottom-half in the
        case of unhandled exceptions in the reader/writer tasks.

        It can be invoked no matter what the `runstate` is.
        """
        if not self._dc_task:
            self._set_state(Runstate.DISCONNECTING)
            self.logger.debug("Scheduling disconnect.")
            self._dc_task = create_task(self._bh_disconnect())

    @upper_half
    async def _wait_disconnect(self) -> None:
        """
        Waits for a previously scheduled disconnect to finish.

        This method will gather any bottom half exceptions and re-raise
        the one that occurred first; presuming it to be the root cause
        of any subsequent Exceptions. It is intended to be used in the
        upper half of the call chain.

        :raise Exception:
            Arbitrary exception re-raised on behalf of the reader/writer.
        """
        assert self.runstate == Runstate.DISCONNECTING
        assert self._dc_task

        aws: List[Awaitable[object]] = [self._dc_task]
        if self._bh_tasks:
            aws.insert(0, self._bh_tasks)
        all_defined_tasks = asyncio.gather(*aws)

        # Ensure disconnect is done; Exception (if any) is not raised here:
        await asyncio.wait((self._dc_task,))

        try:
            await all_defined_tasks  # Raise Exceptions from the bottom half.
        finally:
            self._cleanup()
            self._set_state(Runstate.IDLE)

    @upper_half
    def _cleanup(self) -> None:
        """
        Fully reset this object's task and stream state.

        All tasks must already be finished. This does not change the
        `Runstate` itself; `_wait_disconnect()` transitions to `IDLE`
        right after calling this.
        """
        def _paranoid_task_erase(task: Optional['asyncio.Future[_U]']
                                 ) -> Optional['asyncio.Future[_U]']:
            # Help to erase a task, ENSURING it is fully quiesced first.
            assert (task is None) or task.done()
            return None if (task and task.done()) else task

        assert self.runstate == Runstate.DISCONNECTING
        self._dc_task = _paranoid_task_erase(self._dc_task)
        self._reader_task = _paranoid_task_erase(self._reader_task)
        self._writer_task = _paranoid_task_erase(self._writer_task)
        self._bh_tasks = _paranoid_task_erase(self._bh_tasks)

        self._reader = None
        self._writer = None
        self._accepted = None

        # NB: _runstate_changed cannot be cleared because we still need it to
        # send the final runstate changed event ...!

    # ----------------------------
    # Section: Bottom Half methods
    # ----------------------------

    @bottom_half
    async def _bh_disconnect(self) -> None:
        """
        Disconnect and cancel all outstanding tasks.

        It is designed to be called from its task context,
        :py:obj:`~AsyncProtocol._dc_task`. By running in its own task,
        it is free to wait on any pending actions that may still need to
        occur in either the reader or writer tasks.
        """
        assert self.runstate == Runstate.DISCONNECTING

        def _done(task: Optional['asyncio.Future[Any]']) -> bool:
            return task is not None and task.done()

        # If the server is running, stop it.
        await self._stop_server()

        # Are we already in an error pathway? If either of the tasks are
        # already done, or if we have no tasks but a reader/writer; we
        # must be.
        #
        # NB: We can't use _bh_tasks to check for premature task
        # completion, because it may not yet have had a chance to run
        # and gather itself.
        tasks = tuple(filter(None, (self._writer_task, self._reader_task)))
        error_pathway = _done(self._reader_task) or _done(self._writer_task)
        if not tasks:
            error_pathway |= bool(self._reader) or bool(self._writer)

        try:
            # Try to flush the writer, if possible.
            # This *may* cause an error and force us over into the error path.
            if not error_pathway:
                await self._bh_flush_writer()
        except BaseException as err:
            error_pathway = True
            emsg = "Failed to flush the writer"
            self.logger.error("%s: %s", emsg, exception_summary(err))
            self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
            raise
        finally:
            # Cancel any still-running tasks (Won't raise):
            if self._writer_task is not None and not self._writer_task.done():
                self.logger.debug("Cancelling writer task.")
                self._writer_task.cancel()
            if self._reader_task is not None and not self._reader_task.done():
                self.logger.debug("Cancelling reader task.")
                self._reader_task.cancel()

            # Close out the tasks entirely (Won't raise):
            if tasks:
                self.logger.debug("Waiting for tasks to complete ...")
                await asyncio.wait(tasks)

            # Lastly, close the stream itself. (*May raise*!):
            await self._bh_close_stream(error_pathway)
            self.logger.debug("Disconnected.")

    @bottom_half
    async def _bh_flush_writer(self) -> None:
        if not self._writer_task:
            return

        self.logger.debug("Draining the outbound queue ...")
        await self._outgoing.join()
        if self._writer is not None:
            self.logger.debug("Flushing the StreamWriter ...")
            await flush(self._writer)

    @bottom_half
    async def _bh_close_stream(self, error_pathway: bool = False) -> None:
        # NB: Closing the writer also implicitly closes the reader.
        if not self._writer:
            return

        if not is_closing(self._writer):
            self.logger.debug("Closing StreamWriter.")
            self._writer.close()

        self.logger.debug("Waiting for StreamWriter to close ...")
        try:
            await wait_closed(self._writer)
        except Exception:  # pylint: disable=broad-except
            # It's hard to tell if the Stream is already closed or
            # not. Even if one of the tasks has failed, it may have
            # failed for a higher-layered protocol reason. The
            # stream could still be open and perfectly fine.
            # I don't know how to discern its health here.

            if error_pathway:
                # We already know that *something* went wrong. Let's
                # just trust that the Exception we already have is the
                # better one to present to the user, even if we don't
                # genuinely *know* the relationship between the two.
                self.logger.debug(
                    "Discarding Exception from wait_closed:\n%s\n",
                    pretty_traceback(),
                )
            else:
                # Oops, this is a brand-new error!
                raise
        finally:
            self.logger.debug("StreamWriter closed.")

    @bottom_half
    async def _bh_loop_forever(self, async_fn: _TaskFN, name: str) -> None:
        """
        Run one of the bottom-half methods in a loop forever.

        If the bottom half ever raises any exception, schedule a
        disconnect that will terminate the entire loop.

        :param async_fn: The bottom-half method to run in a loop.
        :param name: The name of this task, used for logging.
        """
        try:
            while True:
                await async_fn()
        except asyncio.CancelledError:
            # We have been cancelled by _bh_disconnect, exit gracefully.
            self.logger.debug("Task.%s: cancelled.", name)
            return
        except BaseException as err:
            self.logger.log(
                logging.INFO if isinstance(err, EOFError) else logging.ERROR,
                "Task.%s: %s",
                name, exception_summary(err)
            )
            self.logger.debug("Task.%s: failure:\n%s\n",
                              name, pretty_traceback())
            self._schedule_disconnect()
            raise
        finally:
            self.logger.debug("Task.%s: exiting.", name)

    @bottom_half
    async def _bh_send_message(self) -> None:
        """
        Wait for an outgoing message, then send it.

        Designed to be run in `_bh_loop_forever()`.
        """
        msg = await self._outgoing.get()
        try:
            await self._send(msg)
        finally:
            self._outgoing.task_done()

    @bottom_half
    async def _bh_recv_message(self) -> None:
        """
        Wait for an incoming message and call `_on_message` to route it.

        Designed to be run in `_bh_loop_forever()`.
        """
        msg = await self._recv()
        await self._on_message(msg)

    # --------------------
    # Section: Message I/O
    # --------------------

    @upper_half
    @bottom_half
    def _cb_outbound(self, msg: T) -> T:
        """
        Callback: outbound message hook.

        This is intended for subclasses to be able to add arbitrary
        hooks to filter or manipulate outgoing messages. The base
        implementation does nothing but log the message without any
        manipulation of the message.

        :param msg: raw outbound message
        :return: final outbound message
        """
        self.logger.debug("--> %s", str(msg))
        return msg

    @upper_half
    @bottom_half
    def _cb_inbound(self, msg: T) -> T:
        """
        Callback: inbound message hook.

        This is intended for subclasses to be able to add arbitrary
        hooks to filter or manipulate incoming messages. The base
        implementation does nothing but log the message without any
        manipulation of the message.

        This method does not "handle" incoming messages; it is a filter.
        The actual "endpoint" for incoming messages is `_on_message()`.

        :param msg: raw inbound message
        :return: processed inbound message
        """
        self.logger.debug("<-- %s", str(msg))
        return msg

    @upper_half
    @bottom_half
    async def _readline(self) -> bytes:
        """
        Wait for a newline from the incoming reader.

        This method is provided as a convenience for upper-layer
        protocols, as many are line-based.

        This method *may* return a sequence of bytes without a trailing
        newline if EOF occurs, but *some* bytes were received. In this
        case, the next call will raise `EOFError`. It is assumed that
        the layer 5 protocol will decide if there is anything meaningful
        to be done with a partial message.

        :raise OSError: For stream-related errors.
        :raise EOFError:
            If the reader stream is at EOF and there are no bytes to return.
        :return: bytes, including the newline.
        """
        assert self._reader is not None
        msg_bytes = await self._reader.readline()

        if not msg_bytes:
            if self._reader.at_eof():
                raise EOFError

        return msg_bytes

    @upper_half
    @bottom_half
    async def _do_recv(self) -> T:
        """
        Abstract: Read from the stream and return a message.

        Very low-level; intended to only be called by `_recv()`.
        """
        raise NotImplementedError

    @upper_half
    @bottom_half
    async def _recv(self) -> T:
        """
        Read an arbitrary protocol message.

        .. warning::
            This method is intended primarily for `_bh_recv_message()`
            to use in an asynchronous task loop. Using it outside of
            this loop will "steal" messages from the normal routing
            mechanism. It is safe to use prior to `_establish_session()`,
            but should not be used otherwise.

        This method uses `_do_recv()` to retrieve the raw message, and
        then transforms it using `_cb_inbound()`.

        :return: A single (filtered, processed) protocol message.
        """
        message = await self._do_recv()
        return self._cb_inbound(message)

    @upper_half
    @bottom_half
    def _do_send(self, msg: T) -> None:
        """
        Abstract: Write a message to the stream.

        Very low-level; intended to only be called by `_send()`.
        """
        raise NotImplementedError

    @upper_half
    @bottom_half
    async def _send(self, msg: T) -> None:
        """
        Send an arbitrary protocol message.

        This method will transform any outgoing messages according to
        `_cb_outbound()`.

        .. warning::
            Like `_recv()`, this method is intended to be called by
            the writer task loop that processes outgoing
            messages. Calling it directly may circumvent logic
            implemented by the caller meant to correlate outgoing and
            incoming messages.

        :raise OSError: For problems with the underlying stream.
        """
        msg = self._cb_outbound(msg)
        self._do_send(msg)

    @bottom_half
    async def _on_message(self, msg: T) -> None:
        """
        Called to handle the receipt of a new message.

        .. caution::
            This is executed from within the reader loop, so be advised
            that waiting on either the reader or writer task will lead
            to deadlock. Additionally, any unhandled exceptions will
            directly cause the loop to halt, so logic may be best-kept
            to a minimum if at all possible.

        :param msg: The incoming message, already logged/filtered.
        """
        # Nothing to do in the abstract case.