import asyncio
import os
from functools import partial
from typing import Callable, Dict, Optional, Text, Union, cast

from ..buffer import Buffer
from ..quic.configuration import QuicConfiguration
from ..quic.connection import NetworkAddress, QuicConnection
from ..quic.packet import (
    PACKET_TYPE_INITIAL,
    encode_quic_retry,
    encode_quic_version_negotiation,
    pull_quic_header,
)
from ..quic.retry import QuicRetryTokenHandler
from ..tls import SessionTicketFetcher, SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler

__all__ = ["serve"]


class QuicServer(asyncio.DatagramProtocol):
    """
    Datagram protocol which routes incoming QUIC datagrams to connections.

    Datagrams are dispatched to a per-connection protocol looked up by the
    destination connection ID of the packet header. When no protocol is
    registered for that ID and the datagram carries an Initial packet of at
    least 1200 bytes, a new connection and protocol are created.
    """

    def __init__(
        self,
        *,
        configuration: QuicConfiguration,
        create_protocol: Callable = QuicConnectionProtocol,
        session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
        session_ticket_handler: Optional[SessionTicketHandler] = None,
        stateless_retry: bool = False,
        stream_handler: Optional[QuicStreamHandler] = None,
    ) -> None:
        """
        Store the server settings; no network activity happens here.

        :param configuration: QUIC configuration shared by all connections.
        :param create_protocol: callable invoked as
            ``create_protocol(connection, stream_handler=...)`` to build the
            protocol for each new connection.
        :param session_ticket_fetcher: passed through to every new
            :class:`QuicConnection`.
        :param session_ticket_handler: passed through to every new
            :class:`QuicConnection`.
        :param stateless_retry: when true, a retry token handler is created
            and Initial packets without a token are answered with a retry.
        :param stream_handler: passed through to ``create_protocol``.
        """
        self._configuration = configuration
        self._create_protocol = create_protocol
        self._loop = asyncio.get_event_loop()
        # Routing table: connection ID -> protocol. One protocol may be
        # registered under several connection IDs.
        self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
        self._session_ticket_fetcher = session_ticket_fetcher
        self._session_ticket_handler = session_ticket_handler
        # Set by connection_made(); None until the endpoint is attached.
        self._transport: Optional[asyncio.DatagramTransport] = None

        self._stream_handler = stream_handler

        # None means stateless retry is disabled.
        self._retry: Optional[QuicRetryTokenHandler]
        if stateless_retry:
            self._retry = QuicRetryTokenHandler()
        else:
            self._retry = None

    def close(self) -> None:
        """
        Close all registered connection protocols and the transport.

        Safe to call before a transport has been attached: in that case only
        the (empty) routing table is cleared.
        """
        # A protocol can be registered under several connection IDs, so
        # de-duplicate before closing.
        for protocol in set(self._protocols.values()):
            protocol.close()
        self._protocols.clear()
        # The transport is only set once connection_made() has run.
        if self._transport is not None:
            self._transport.close()

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Remember the datagram transport used to send replies."""
        self._transport = cast(asyncio.DatagramTransport, transport)

    def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
        """
        Handle one incoming datagram.

        The datagram is silently dropped when its header cannot be parsed,
        when a retry token fails validation, or when it matches no known
        connection and does not qualify to create one.

        :param data: raw datagram payload.
        :param addr: network address of the sender.
        """
        data = cast(bytes, data)
        buf = Buffer(data=data)

        try:
            header = pull_quic_header(
                buf, host_cid_length=self._configuration.connection_id_length
            )
        except ValueError:
            # Not a parseable QUIC header: drop the datagram.
            return

        # version negotiation
        if (
            header.version is not None
            and header.version not in self._configuration.supported_versions
        ):
            # Reply with the versions we support; the connection IDs are
            # swapped relative to the received packet.
            self._transport.sendto(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=self._configuration.supported_versions,
                ),
                addr,
            )
            return

        protocol = self._protocols.get(header.destination_cid, None)
        original_connection_id: Optional[bytes] = None
        # Only a sufficiently large Initial packet for an unknown connection
        # ID may create a connection (QUIC requires client Initial datagrams
        # to be padded to at least 1200 bytes).
        if (
            protocol is None
            and len(data) >= 1200
            and header.packet_type == PACKET_TYPE_INITIAL
        ):
            # stateless retry
            if self._retry is not None:
                if not header.token:
                    # create a retry token
                    self._transport.sendto(
                        encode_quic_retry(
                            version=header.version,
                            source_cid=os.urandom(8),
                            destination_cid=header.source_cid,
                            original_destination_cid=header.destination_cid,
                            retry_token=self._retry.create_token(
                                addr, header.destination_cid
                            ),
                        ),
                        addr,
                    )
                    return
                else:
                    # validate retry token
                    try:
                        original_connection_id = self._retry.validate_token(
                            addr, header.token
                        )
                    except ValueError:
                        # Invalid token: drop the datagram.
                        return

            # create new connection
            connection = QuicConnection(
                configuration=self._configuration,
                logger_connection_id=original_connection_id or header.destination_cid,
                original_connection_id=original_connection_id,
                session_ticket_fetcher=self._session_ticket_fetcher,
                session_ticket_handler=self._session_ticket_handler,
            )
            protocol = self._create_protocol(
                connection, stream_handler=self._stream_handler
            )
            protocol.connection_made(self._transport)

            # register callbacks which keep the routing table in sync with
            # the connection IDs of this protocol
            protocol._connection_id_issued_handler = partial(
                self._connection_id_issued, protocol=protocol
            )
            protocol._connection_id_retired_handler = partial(
                self._connection_id_retired, protocol=protocol
            )
            protocol._connection_terminated_handler = partial(
                self._connection_terminated, protocol=protocol
            )

            # Route both the destination connection ID chosen by the peer
            # and the connection's own host connection ID to this protocol.
            self._protocols[header.destination_cid] = protocol
            self._protocols[connection.host_cid] = protocol

        if protocol is not None:
            protocol.datagram_received(data, addr)

    def _connection_id_issued(
        self, cid: bytes, protocol: QuicConnectionProtocol
    ) -> None:
        """Register `cid` as an additional route to `protocol`."""
        self._protocols[cid] = protocol

    def _connection_id_retired(
        self, cid: bytes, protocol: QuicConnectionProtocol
    ) -> None:
        """Remove the route for `cid`, which must currently map to `protocol`."""
        assert self._protocols[cid] == protocol
        del self._protocols[cid]

    def _connection_terminated(self, protocol: QuicConnectionProtocol) -> None:
        """Remove every connection ID which routes to `protocol`."""
        # Iterate over a snapshot since entries are deleted while looping.
        for cid, proto in list(self._protocols.items()):
            if proto == protocol:
                del self._protocols[cid]


async def serve(
    host: str,
    port: int,
    *,
    configuration: QuicConfiguration,
    create_protocol: Callable = QuicConnectionProtocol,
    session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
    session_ticket_handler: Optional[SessionTicketHandler] = None,
    stateless_retry: bool = False,
    stream_handler: Optional[QuicStreamHandler] = None,
) -> QuicServer:
    """
    Start a QUIC server at the given `host` and `port`.

    :func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration`
    containing TLS certificate and private key as the ``configuration`` argument.

    :func:`serve` also accepts the following optional arguments:

    * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
      manages the connection. It should be a callable or class accepting the same
      arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
      an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
    * ``session_ticket_fetcher`` is a callback which is invoked by the TLS
      engine when a session ticket is presented by the peer. It should return
      the session ticket with the specified ID or `None` if it is not found.
    * ``session_ticket_handler`` is a callback which is invoked by the TLS
      engine when a new session ticket is issued. It should store the session
      ticket for future lookup.
    * ``stateless_retry`` specifies whether a stateless retry should be
      performed prior to handling new connections.
    * ``stream_handler`` is a callback which is invoked whenever a stream is
      created. It must accept two arguments: a :class:`asyncio.StreamReader`
      and a :class:`asyncio.StreamWriter`.

    The returned :class:`QuicServer` is the datagram protocol bound to the
    endpoint; call its ``close()`` method to shut the server down.
    """

    loop = asyncio.get_event_loop()

    _, protocol = await loop.create_datagram_endpoint(
        lambda: QuicServer(
            configuration=configuration,
            create_protocol=create_protocol,
            session_ticket_fetcher=session_ticket_fetcher,
            session_ticket_handler=session_ticket_handler,
            stateless_retry=stateless_retry,
            stream_handler=stream_handler,
        ),
        local_addr=(host, port),
    )
    return cast(QuicServer, protocol)