1import asyncio
2from typing import Optional, cast
3
4from .tcp_helpers import tcp_nodelay
5
6
class BaseProtocol(asyncio.Protocol):
    """Protocol base class with write flow control and read pausing helpers.

    The class tracks whether the transport asked the writer to pause
    (``pause_writing`` / ``resume_writing``), lets writers wait for the
    pause to end via ``_drain_helper``, and offers idempotent
    ``pause_reading`` / ``resume_reading`` wrappers around the transport.
    """

    __slots__ = (
        "_loop",
        "_paused",
        "_drain_waiter",
        "_connection_lost",
        "_reading_paused",
        "transport",
    )

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        """Initialise the protocol state.

        Args:
            loop: Event loop used to create the drain waiter future.
        """
        self._loop = loop  # type: asyncio.AbstractEventLoop
        # True between pause_writing() and resume_writing().
        self._paused = False
        # Future shared by every coroutine currently blocked in
        # _drain_helper(); None when nobody is waiting.
        self._drain_waiter = None  # type: Optional[asyncio.Future[None]]
        self._connection_lost = False
        self._reading_paused = False

        # Set in connection_made(), reset to None in connection_lost().
        self.transport = None  # type: Optional[asyncio.Transport]

    def pause_writing(self) -> None:
        """Record that writing is paused; must not already be paused."""
        assert not self._paused
        self._paused = True

    def resume_writing(self) -> None:
        """Clear the paused flag and wake every coroutine waiting to drain."""
        assert self._paused
        self._paused = False

        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def pause_reading(self) -> None:
        """Ask the transport to stop reading, if not already paused.

        Does nothing when there is no transport or reading is already
        paused.
        """
        if not self._reading_paused and self.transport is not None:
            try:
                self.transport.pause_reading()
            except (AttributeError, NotImplementedError, RuntimeError):
                # Deliberate best effort: these errors from the transport
                # are ignored and the flag is still flipped below.
                pass
            self._reading_paused = True

    def resume_reading(self) -> None:
        """Ask the transport to resume reading, if currently paused.

        Does nothing when there is no transport or reading is not paused.
        """
        if self._reading_paused and self.transport is not None:
            try:
                self.transport.resume_reading()
            except (AttributeError, NotImplementedError, RuntimeError):
                # Deliberate best effort, mirroring pause_reading().
                pass
            self._reading_paused = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store the transport after passing it to ``tcp_nodelay(tr, True)``.

        Args:
            transport: The transport handed over by the event loop; it is
                treated as an ``asyncio.Transport``.
        """
        tr = cast(asyncio.Transport, transport)
        # Presumably enables TCP_NODELAY on the underlying socket; see
        # tcp_helpers.tcp_nodelay for the exact behaviour.
        tcp_nodelay(tr, True)
        self.transport = tr

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Mark the connection as lost and release any pending drain waiter.

        Args:
            exc: The error that closed the connection, or ``None``. When
                given, it is propagated to coroutines blocked in
                ``_drain_helper``; otherwise they are resumed normally.
        """
        self._connection_lost = True
        self.transport = None
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    async def _drain_helper(self) -> None:
        """Wait until writing is resumed.

        Returns immediately when writing is not paused. Any number of
        coroutines may wait concurrently; they all share one future.

        Raises:
            ConnectionResetError: If the connection has already been lost.
            BaseException: The exception passed to ``connection_lost`` if
                the connection drops while waiting.
        """
        if self._connection_lost:
            raise ConnectionResetError("Connection lost")
        if not self._paused:
            return
        waiter = self._drain_waiter
        # Reuse the pending waiter so that concurrent callers are all woken
        # together; replace it only if it is missing or already finished
        # (e.g. cancelled externally), since awaiting a finished future
        # would return or raise immediately.
        if waiter is None or waiter.done():
            waiter = self._loop.create_future()
            self._drain_waiter = waiter
        # shield() keeps one caller's cancellation from cancelling the
        # future the other callers are still waiting on.
        await asyncio.shield(waiter)
88