1import asyncio
2from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast
3
4from ..quic import events
5from ..quic.connection import NetworkAddress, QuicConnection
6
# Callback receiving a connection ID as bytes; used below for both the
# "connection ID issued" and "connection ID retired" notifications.
QuicConnectionIdHandler = Callable[[bytes], None]
# Callback receiving the (reader, writer) pair created for a stream on which
# data arrives before any local reader exists for it.
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
9
10
class QuicConnectionProtocol(asyncio.DatagramProtocol):
    """
    Drive a :class:`QuicConnection` from an asyncio datagram endpoint.

    Received datagrams and timer expirations are fed into the QUIC connection,
    the events it produces are dispatched (first to internal bookkeeping, then
    to :meth:`quic_event_received`), and any datagrams the connection has
    pending are sent through the datagram transport.
    """

    def __init__(
        self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
    ):
        """
        :param quic: the QUIC connection to drive.
        :param stream_handler: optional callback invoked with a
            ``(reader, writer)`` pair when data arrives on a stream that has no
            reader yet (presumably a stream opened by the peer).
        """
        loop = asyncio.get_event_loop()

        self._closed = asyncio.Event()
        # becomes True once HandshakeCompleted has been processed
        self._connected = False
        self._connected_waiter: Optional[asyncio.Future[None]] = None
        self._loop = loop
        # ping uid -> future resolved when the ping is acknowledged
        self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
        self._quic = quic
        self._stream_readers: Dict[int, asyncio.StreamReader] = {}
        self._timer: Optional[asyncio.TimerHandle] = None
        self._timer_at: Optional[float] = None
        # pending call_soon(self.transmit) handle, used to coalesce writes
        self._transmit_task: Optional[asyncio.Handle] = None
        self._transport: Optional[asyncio.DatagramTransport] = None

        # callbacks
        self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
        self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
        self._connection_terminated_handler: Callable[[], None] = lambda: None
        if stream_handler is not None:
            self._stream_handler = stream_handler
        else:
            self._stream_handler = lambda r, w: None

    def change_connection_id(self) -> None:
        """
        Change the connection ID used to communicate with the peer.

        The previous connection ID will be retired.
        """
        self._quic.change_connection_id()
        self.transmit()

    def close(self) -> None:
        """
        Close the connection.
        """
        self._quic.close()
        self.transmit()

    def connect(self, addr: NetworkAddress) -> None:
        """
        Initiate the TLS handshake.

        This method can only be called for clients and a single time.
        """
        self._quic.connect(addr, now=self._loop.time())
        self.transmit()

    async def create_stream(
        self, is_unidirectional: bool = False
    ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Create a QUIC stream and return a pair of (reader, writer) objects.

        The returned reader and writer objects are instances of :class:`asyncio.StreamReader`
        and :class:`asyncio.StreamWriter` classes.

        :param is_unidirectional: request a unidirectional stream ID from the
            connection instead of a bidirectional one.
        """
        stream_id = self._quic.get_next_available_stream_id(
            is_unidirectional=is_unidirectional
        )
        return self._create_stream(stream_id)

    def request_key_update(self) -> None:
        """
        Request an update of the encryption keys.
        """
        self._quic.request_key_update()
        self.transmit()

    async def ping(self) -> None:
        """
        Ping the peer and wait for the response.

        :raises ConnectionError: if the connection is terminated before the
            ping is acknowledged, or was already terminated when called.
        """
        if self._closed.is_set():
            # no acknowledgement can arrive any more; fail instead of hanging
            raise ConnectionError
        waiter = self._loop.create_future()
        # the future's id() is unique while it stays referenced by the dict
        uid = id(waiter)
        self._ping_waiters[uid] = waiter
        self._quic.send_ping(uid)
        self.transmit()
        # shield so cancelling the caller does not cancel the stored waiter
        await asyncio.shield(waiter)

    def transmit(self) -> None:
        """
        Send pending datagrams to the peer and arm the timer if needed.

        Requires :meth:`connection_made` to have provided the transport.
        """
        self._transmit_task = None

        # send datagrams
        for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
            self._transport.sendto(data, addr)

        # re-arm timer only if the deadline changed
        timer_at = self._quic.get_timer()
        if self._timer is not None and self._timer_at != timer_at:
            self._timer.cancel()
            self._timer = None
        if self._timer is None and timer_at is not None:
            self._timer = self._loop.call_at(timer_at, self._handle_timer)
        self._timer_at = timer_at

    async def wait_closed(self) -> None:
        """
        Wait for the connection to be closed.
        """
        await self._closed.wait()

    async def wait_connected(self) -> None:
        """
        Wait for the TLS handshake to complete.

        Returns immediately if the handshake has already completed.

        :raises ConnectionError: if the connection is terminated before the
            handshake completes, or was already terminated when called.
        """
        assert self._connected_waiter is None, "already awaiting connected"
        if self._connected:
            return
        if self._closed.is_set():
            # the handshake can no longer complete; fail instead of hanging
            raise ConnectionError
        self._connected_waiter = self._loop.create_future()
        await asyncio.shield(self._connected_waiter)

    # asyncio.Transport

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self._transport = cast(asyncio.DatagramTransport, transport)

    def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
        self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
        self._process_events()
        self.transmit()

    # overridable

    def quic_event_received(self, event: events.QuicEvent) -> None:
        """
        Called when a QUIC event is received.

        Reimplement this in your subclass to handle the events.
        """
        # FIXME: move this to a subclass
        if isinstance(event, events.ConnectionTerminated):
            for reader in self._stream_readers.values():
                reader.feed_eof()
        elif isinstance(event, events.StreamDataReceived):
            reader = self._stream_readers.get(event.stream_id, None)
            if reader is None:
                # first data on an unknown stream: create it and notify
                reader, writer = self._create_stream(event.stream_id)
                self._stream_handler(reader, writer)
            reader.feed_data(event.data)
            if event.end_stream:
                reader.feed_eof()

    # private

    def _create_stream(
        self, stream_id: int
    ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Create and register the reader/writer pair for ``stream_id``."""
        adapter = QuicStreamAdapter(self, stream_id)
        reader = asyncio.StreamReader()
        writer = asyncio.StreamWriter(adapter, None, reader, self._loop)
        self._stream_readers[stream_id] = reader
        return reader, writer

    def _handle_timer(self) -> None:
        """Fire the connection timer, then dispatch events and transmit."""
        # never report a time earlier than the scheduled deadline
        now = max(self._timer_at, self._loop.time())
        self._timer = None
        self._timer_at = None
        self._quic.handle_timer(now=now)
        self._process_events()
        self.transmit()

    def _process_events(self) -> None:
        """Drain the connection's event queue, updating waiters and callbacks."""
        event = self._quic.next_event()
        while event is not None:
            if isinstance(event, events.ConnectionIdIssued):
                self._connection_id_issued_handler(event.connection_id)
            elif isinstance(event, events.ConnectionIdRetired):
                self._connection_id_retired_handler(event.connection_id)
            elif isinstance(event, events.ConnectionTerminated):
                self._connection_terminated_handler()

                # abort connection waiter
                if self._connected_waiter is not None:
                    waiter = self._connected_waiter
                    self._connected_waiter = None
                    waiter.set_exception(ConnectionError)

                # abort ping waiters
                for waiter in self._ping_waiters.values():
                    waiter.set_exception(ConnectionError)
                self._ping_waiters.clear()

                self._closed.set()
            elif isinstance(event, events.HandshakeCompleted):
                # record completion even when nobody is waiting yet, so that a
                # later wait_connected() returns immediately instead of hanging
                self._connected = True
                if self._connected_waiter is not None:
                    waiter = self._connected_waiter
                    self._connected_waiter = None
                    waiter.set_result(None)
            elif isinstance(event, events.PingAcknowledged):
                waiter = self._ping_waiters.pop(event.uid, None)
                if waiter is not None:
                    waiter.set_result(None)
            self.quic_event_received(event)
            event = self._quic.next_event()

    def _transmit_soon(self) -> None:
        """Schedule a single transmit() on the next loop iteration."""
        if self._transmit_task is None:
            self._transmit_task = self._loop.call_soon(self.transmit)
217
218
class QuicStreamAdapter(asyncio.Transport):
    """
    Minimal :class:`asyncio.Transport` that routes a stream writer's output
    onto a single QUIC stream.

    Data is queued on the owning protocol's QUIC connection and a transmission
    is scheduled on the event loop.
    """

    def __init__(self, protocol: "QuicConnectionProtocol", stream_id: int):
        super().__init__()
        self.protocol = protocol
        self.stream_id = stream_id
        # set once the end of stream has been queued
        self._closing = False

    def can_write_eof(self) -> bool:
        return True

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """
        Get information about the underlying QUIC stream.

        Only ``"stream_id"`` is supported; any other name yields *default*.
        """
        if name == "stream_id":
            return self.stream_id
        return default

    def is_closing(self) -> bool:
        """Return True once the end of the stream has been queued."""
        return self._closing

    def close(self) -> None:
        """Close the sending side of the stream (same as :meth:`write_eof`)."""
        self.write_eof()

    def write(self, data) -> None:
        """Queue ``data`` on the stream and schedule a transmission."""
        self.protocol._quic.send_stream_data(self.stream_id, data)
        self.protocol._transmit_soon()

    def write_eof(self) -> None:
        """Queue the end of the stream; later calls are no-ops."""
        if self._closing:
            return
        self._closing = True
        self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
        self.protocol._transmit_soon()
241