import asyncio
from typing import Optional, cast

from .tcp_helpers import tcp_nodelay


class BaseProtocol(asyncio.Protocol):
    """asyncio protocol base with write flow control and read pausing.

    The transport calls ``pause_writing`` / ``resume_writing`` when its write
    buffer crosses its high/low water marks.  Coroutines that want to wait
    for the buffer to drain call :meth:`_drain_helper`.  ``pause_reading`` /
    ``resume_reading`` are idempotent, best-effort wrappers around the
    corresponding transport methods.
    """

    __slots__ = (
        "_loop",
        "_paused",
        "_drain_waiter",
        "_connection_lost",
        "_reading_paused",
        "transport",
    )

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop  # type: asyncio.AbstractEventLoop
        # True between pause_writing() and resume_writing().
        self._paused = False
        # Future shared by every coroutine waiting in _drain_helper() while
        # writing is paused; None when nobody is waiting.
        self._drain_waiter = None  # type: Optional[asyncio.Future[None]]
        # Set once connection_lost() has been called; never reset.
        self._connection_lost = False
        # Tracks whether pause_reading() has been requested on the transport.
        self._reading_paused = False

        self.transport = None  # type: Optional[asyncio.Transport]

    def pause_writing(self) -> None:
        """Record that the transport asked the protocol to stop writing."""
        assert not self._paused
        self._paused = True

    def resume_writing(self) -> None:
        """Record that writing may resume and wake any pending drainers."""
        assert self._paused
        self._paused = False

        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def pause_reading(self) -> None:
        """Ask the transport to stop delivering data, if not already paused."""
        if not self._reading_paused and self.transport is not None:
            try:
                self.transport.pause_reading()
            except (AttributeError, NotImplementedError, RuntimeError):
                # Deliberately best-effort: the transport may not support
                # pausing or may be in a state where it refuses to; the
                # paused flag is recorded regardless.
                pass
            self._reading_paused = True

    def resume_reading(self) -> None:
        """Ask the transport to resume delivering data, if it was paused."""
        if self._reading_paused and self.transport is not None:
            try:
                self.transport.resume_reading()
            except (AttributeError, NotImplementedError, RuntimeError):
                # Deliberately best-effort, mirroring pause_reading().
                pass
            self._reading_paused = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store the transport and call ``tcp_nodelay(tr, True)`` on it.

        NOTE(review): ``tcp_nodelay`` is presumably what sets TCP_NODELAY on
        the underlying socket; its behavior is defined outside this module.
        """
        tr = cast(asyncio.Transport, transport)
        tcp_nodelay(tr, True)
        self.transport = tr

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Mark the connection as lost and release a paused writer.

        If writing is paused and drainers are waiting, their shared future
        is resolved: with ``exc`` as its exception when one is given,
        otherwise with ``None``.
        """
        self._connection_lost = True
        self.transport = None
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    async def _drain_helper(self) -> None:
        """Wait until the transport lets writing resume.

        Returns immediately when writing is not paused.  Any number of
        coroutines may wait at once; they all share one future.

        Raises:
            ConnectionResetError: if the connection has already been lost.
            BaseException: whatever exception ``connection_lost`` received
                while this coroutine was waiting.
        """
        if self._connection_lost:
            raise ConnectionResetError("Connection lost")
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None or waiter.cancelled():
            # Reuse a pending waiter instead of asserting there is none:
            # concurrent drainers are legitimate, and replacing the stored
            # future would orphan the earlier drainer (it would never be
            # woken by resume_writing/connection_lost).
            waiter = self._loop.create_future()
            self._drain_waiter = waiter
        # Shield the shared future so that cancelling one drainer does not
        # cancel it for every other coroutine awaiting the same drain.
        await asyncio.shield(waiter)