import asyncio
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast

from ..quic import events
from ..quic.connection import NetworkAddress, QuicConnection

QuicConnectionIdHandler = Callable[[bytes], None]
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]


class QuicConnectionProtocol(asyncio.DatagramProtocol):
    """
    An asyncio datagram protocol driving a :class:`QuicConnection`.

    Incoming datagrams and timer expirations are handed to the QUIC
    connection, the events it produces are dispatched, and any datagrams it
    wants to send are written to the datagram transport.
    """

    def __init__(
        self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
    ):
        """
        :param quic: the QUIC connection state machine to drive.
        :param stream_handler: optional callback invoked with a
            ``(reader, writer)`` pair whenever the peer sends data on a stream
            this side has not created yet.
        """
        loop = asyncio.get_event_loop()

        self._closed = asyncio.Event()
        self._connected = False
        self._connected_waiter: Optional[asyncio.Future[None]] = None
        self._loop = loop
        self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
        self._quic = quic
        self._stream_readers: Dict[int, asyncio.StreamReader] = {}
        self._timer: Optional[asyncio.TimerHandle] = None
        self._timer_at: Optional[float] = None
        self._transmit_task: Optional[asyncio.Handle] = None
        self._transport: Optional[asyncio.DatagramTransport] = None

        # callbacks (no-ops by default; presumably replaced by the owner of
        # this protocol)
        self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
        self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
        self._connection_terminated_handler: Callable[[], None] = lambda: None
        if stream_handler is not None:
            self._stream_handler = stream_handler
        else:
            self._stream_handler = lambda r, w: None

    def change_connection_id(self) -> None:
        """
        Change the connection ID used to communicate with the peer.

        The previous connection ID will be retired.
        """
        self._quic.change_connection_id()
        self.transmit()

    def close(self) -> None:
        """
        Close the connection.
        """
        self._quic.close()
        self.transmit()

    def connect(self, addr: NetworkAddress) -> None:
        """
        Initiate the TLS handshake.

        This method can only be called for clients and a single time.
        """
        self._quic.connect(addr, now=self._loop.time())
        self.transmit()

    async def create_stream(
        self, is_unidirectional: bool = False
    ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Create a QUIC stream and return a pair of (reader, writer) objects.

        The returned reader and writer objects are instances of :class:`asyncio.StreamReader`
        and :class:`asyncio.StreamWriter` classes.
        """
        stream_id = self._quic.get_next_available_stream_id(
            is_unidirectional=is_unidirectional
        )
        return self._create_stream(stream_id)

    def request_key_update(self) -> None:
        """
        Request an update of the encryption keys.
        """
        self._quic.request_key_update()
        self.transmit()

    async def ping(self) -> None:
        """
        Ping the peer and wait for the response.

        :raises ConnectionError: if the connection is terminated while the
            ping is outstanding.
        """
        waiter = self._loop.create_future()
        # The future's id() is used as the ping identifier, so the matching
        # PingAcknowledged event can be routed back to this waiter.
        uid = id(waiter)
        self._ping_waiters[uid] = waiter
        self._quic.send_ping(uid)
        self.transmit()
        # Shield so that cancelling this coroutine does not cancel the
        # waiter, which _process_events may still resolve later.
        await asyncio.shield(waiter)

    def transmit(self) -> None:
        """
        Send pending datagrams to the peer and arm the timer if needed.
        """
        # Any deferred transmission scheduled by _transmit_soon() is being
        # performed now.
        self._transmit_task = None

        # send datagrams
        for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
            self._transport.sendto(data, addr)

        # re-arm timer: drop the existing timer if the deadline moved, then
        # schedule a new one if the connection still wants a timer
        timer_at = self._quic.get_timer()
        if self._timer is not None and self._timer_at != timer_at:
            self._timer.cancel()
            self._timer = None
        if self._timer is None and timer_at is not None:
            self._timer = self._loop.call_at(timer_at, self._handle_timer)
        self._timer_at = timer_at

    async def wait_closed(self) -> None:
        """
        Wait for the connection to be closed.
        """
        await self._closed.wait()

    async def wait_connected(self) -> None:
        """
        Wait for the TLS handshake to complete.

        Returns immediately if the handshake has already completed.

        :raises ConnectionError: if the connection is terminated before the
            handshake completes.
        """
        assert self._connected_waiter is None, "already awaiting connected"
        if self._connected:
            return
        if self._closed.is_set():
            # The connection already terminated without completing the
            # handshake: no HandshakeCompleted event can arrive any more, so
            # creating a waiter here would hang forever.
            raise ConnectionError
        self._connected_waiter = self._loop.create_future()
        await asyncio.shield(self._connected_waiter)

    # asyncio.Transport

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self._transport = cast(asyncio.DatagramTransport, transport)

    def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
        self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
        self._process_events()
        self.transmit()

    # overridable

    def quic_event_received(self, event: events.QuicEvent) -> None:
        """
        Called when a QUIC event is received.

        Reimplement this in your subclass to handle the events.
        """
        # FIXME: move this to a subclass
        if isinstance(event, events.ConnectionTerminated):
            for reader in self._stream_readers.values():
                reader.feed_eof()
        elif isinstance(event, events.StreamDataReceived):
            reader = self._stream_readers.get(event.stream_id, None)
            if reader is None:
                # First data on a stream we did not create: set up the
                # reader/writer pair and hand it to the stream handler.
                reader, writer = self._create_stream(event.stream_id)
                self._stream_handler(reader, writer)
            reader.feed_data(event.data)
            if event.end_stream:
                reader.feed_eof()

    # private

    def _create_stream(
        self, stream_id: int
    ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Build and register a (reader, writer) pair for ``stream_id``.
        """
        adapter = QuicStreamAdapter(self, stream_id)
        reader = asyncio.StreamReader()
        # NOTE(review): the writer is created without a protocol, so
        # StreamWriter.drain() is presumably unsupported on these streams.
        writer = asyncio.StreamWriter(adapter, None, reader, self._loop)
        self._stream_readers[stream_id] = reader
        return reader, writer

    def _handle_timer(self) -> None:
        # The event loop may fire a timer slightly early; never report a time
        # earlier than the requested deadline.
        now = max(self._timer_at, self._loop.time())
        self._timer = None
        self._timer_at = None
        self._quic.handle_timer(now=now)
        self._process_events()
        self.transmit()

    def _process_events(self) -> None:
        """
        Drain the QUIC connection's event queue.

        Internal bookkeeping and waiters are updated first, then each event is
        passed to :meth:`quic_event_received`.
        """
        event = self._quic.next_event()
        while event is not None:
            if isinstance(event, events.ConnectionIdIssued):
                self._connection_id_issued_handler(event.connection_id)
            elif isinstance(event, events.ConnectionIdRetired):
                self._connection_id_retired_handler(event.connection_id)
            elif isinstance(event, events.ConnectionTerminated):
                self._connection_terminated_handler()

                # abort connection waiter
                if self._connected_waiter is not None:
                    waiter = self._connected_waiter
                    self._connected_waiter = None
                    waiter.set_exception(ConnectionError)

                # abort ping waiters
                for waiter in self._ping_waiters.values():
                    waiter.set_exception(ConnectionError)
                self._ping_waiters.clear()

                self._closed.set()
            elif isinstance(event, events.HandshakeCompleted):
                # Record completion even if nobody is waiting yet, so that a
                # later wait_connected() returns immediately instead of
                # blocking on a future that will never be resolved.
                self._connected = True
                if self._connected_waiter is not None:
                    waiter = self._connected_waiter
                    self._connected_waiter = None
                    waiter.set_result(None)
            elif isinstance(event, events.PingAcknowledged):
                waiter = self._ping_waiters.pop(event.uid, None)
                if waiter is not None:
                    waiter.set_result(None)
            self.quic_event_received(event)
            event = self._quic.next_event()

    def _transmit_soon(self) -> None:
        # Coalesce several writes made in the same loop iteration into a
        # single call to transmit().
        if self._transmit_task is None:
            self._transmit_task = self._loop.call_soon(self.transmit)


class QuicStreamAdapter(asyncio.Transport):
    """
    Minimal transport exposing a single QUIC stream to :class:`asyncio.StreamWriter`.

    Writes are queued on the QUIC connection and flushed via the owning
    protocol's deferred transmit.
    """

    def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
        super().__init__()
        self.protocol = protocol
        self.stream_id = stream_id

    def can_write_eof(self) -> bool:
        return True

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """
        Get information about the underlying QUIC stream.

        Only ``"stream_id"`` is supported; any other name yields ``default``,
        as required by the :class:`asyncio.BaseTransport` contract.
        """
        if name == "stream_id":
            return self.stream_id
        return default

    def write(self, data: bytes) -> None:
        """
        Queue ``data`` on the stream and schedule transmission.
        """
        self.protocol._quic.send_stream_data(self.stream_id, data)
        self.protocol._transmit_soon()

    def write_eof(self) -> None:
        """
        Mark the end of the stream and schedule transmission.
        """
        self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
        self.protocol._transmit_soon()