1import asyncio
2import os
3from functools import partial
4from typing import Callable, Dict, Optional, Text, Union, cast
6from ..buffer import Buffer
7from ..quic.configuration import QuicConfiguration
8from ..quic.connection import NetworkAddress, QuicConnection
9from ..quic.packet import (
11    encode_quic_retry,
12    encode_quic_version_negotiation,
13    pull_quic_header,
15from ..quic.retry import QuicRetryTokenHandler
16from ..tls import SessionTicketFetcher, SessionTicketHandler
17from .protocol import QuicConnectionProtocol, QuicStreamHandler
# Only serve() is exported by name; the QuicServer instance it returns is not
# listed in this module's public API.
__all__ = ["serve"]
class QuicServer(asyncio.DatagramProtocol):
    """
    UDP endpoint which dispatches incoming QUIC datagrams to per-connection
    protocols.

    Datagrams are routed by the destination connection ID found in their QUIC
    header. Every connection ID in use by a connection is registered in
    ``self._protocols`` so that all of them route to the same protocol.
    """

    def __init__(
        self,
        *,
        configuration: QuicConfiguration,
        create_protocol: Callable = QuicConnectionProtocol,
        session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
        session_ticket_handler: Optional[SessionTicketHandler] = None,
        stateless_retry: bool = False,
        stream_handler: Optional[QuicStreamHandler] = None,
    ) -> None:
        """
        :param configuration: QUIC configuration shared by every connection.
        :param create_protocol: factory called as
            ``create_protocol(connection, stream_handler=...)`` for each new
            connection.
        :param session_ticket_fetcher: passed through to each QuicConnection.
        :param session_ticket_handler: passed through to each QuicConnection.
        :param stateless_retry: when true, new connections must first echo a
            retry token issued by this server.
        :param stream_handler: passed to ``create_protocol`` for each connection.
        """
        self._configuration = configuration
        self._create_protocol = create_protocol
        self._loop = asyncio.get_event_loop()
        # connection ID -> protocol; several IDs may map to the same protocol
        self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
        self._session_ticket_fetcher = session_ticket_fetcher
        self._session_ticket_handler = session_ticket_handler
        # set by connection_made() once the datagram endpoint is bound
        self._transport: Optional[asyncio.DatagramTransport] = None

        self._stream_handler = stream_handler

        # retry token handler, only present when stateless retry is enabled
        self._retry: Optional[QuicRetryTokenHandler]
        if stateless_retry:
            self._retry = QuicRetryTokenHandler()
        else:
            self._retry = None

    def close(self) -> None:
        """
        Close every connection, forget all connection IDs and close the
        transport (if one has been attached).
        """
        # several connection IDs can map to one protocol, so de-duplicate first
        for protocol in set(self._protocols.values()):
            protocol.close()
        self._protocols.clear()
        # the transport only exists once connection_made() has been called
        if self._transport is not None:
            self._transport.close()

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Remember the datagram transport used to send replies."""
        self._transport = cast(asyncio.DatagramTransport, transport)

    def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
        """
        Handle a datagram received on the server socket.

        Undecodable datagrams are dropped. Packets for an unsupported version
        may be answered with a Version Negotiation packet. Packets whose
        destination connection ID is known go to the matching protocol. An
        unknown Initial packet of at least 1200 bytes creates a new connection,
        after a stateless retry round-trip when retry is enabled. Anything else
        is dropped.
        """
        data = cast(bytes, data)
        buf = Buffer(data=data)

        try:
            header = pull_quic_header(
                buf, host_cid_length=self._configuration.connection_id_length
            )
        except ValueError:
            # not a header we can parse: drop the datagram
            return

        # version negotiation
        if (
            header.version is not None
            and header.version not in self._configuration.supported_versions
        ):
            # Never answer a Version Negotiation packet (version 0) with another
            # one, which could make two endpoints bounce VN packets forever
            # (RFC 9000 section 6.1). Drop datagrams too small to initiate a
            # connection instead of answering them (RFC 9000 section 5.2.2).
            if header.version != 0 and len(data) >= 1200:
                self._transport.sendto(
                    encode_quic_version_negotiation(
                        source_cid=header.destination_cid,
                        destination_cid=header.source_cid,
                        supported_versions=self._configuration.supported_versions,
                    ),
                    addr,
                )
            return

        protocol = self._protocols.get(header.destination_cid, None)
        original_connection_id: Optional[bytes] = None
        if (
            protocol is None
            and len(data) >= 1200
            and header.packet_type == PACKET_TYPE_INITIAL
        ):
            # stateless retry
            if self._retry is not None:
                if not header.token:
                    # create a retry token
                    self._transport.sendto(
                        encode_quic_retry(
                            version=header.version,
                            source_cid=os.urandom(8),
                            destination_cid=header.source_cid,
                            original_destination_cid=header.destination_cid,
                            retry_token=self._retry.create_token(
                                addr, header.destination_cid
                            ),
                        ),
                        addr,
                    )
                    return
                else:
                    # validate retry token; invalid tokens are dropped
                    try:
                        original_connection_id = self._retry.validate_token(
                            addr, header.token
                        )
                    except ValueError:
                        return

            # create new connection
            connection = QuicConnection(
                configuration=self._configuration,
                logger_connection_id=original_connection_id or header.destination_cid,
                original_connection_id=original_connection_id,
                session_ticket_fetcher=self._session_ticket_fetcher,
                session_ticket_handler=self._session_ticket_handler,
            )
            protocol = self._create_protocol(
                connection, stream_handler=self._stream_handler
            )
            protocol.connection_made(self._transport)

            # register callbacks so the routing table follows connection ID changes
            protocol._connection_id_issued_handler = partial(
                self._connection_id_issued, protocol=protocol
            )
            protocol._connection_id_retired_handler = partial(
                self._connection_id_retired, protocol=protocol
            )
            protocol._connection_terminated_handler = partial(
                self._connection_terminated, protocol=protocol
            )

            # route both the client-chosen ID and our own host ID to this protocol
            self._protocols[header.destination_cid] = protocol
            self._protocols[connection.host_cid] = protocol

        if protocol is not None:
            protocol.datagram_received(data, addr)

    def _connection_id_issued(
        self, cid: bytes, protocol: QuicConnectionProtocol
    ) -> None:
        """Route a newly issued connection ID to ``protocol``."""
        self._protocols[cid] = protocol

    def _connection_id_retired(
        self, cid: bytes, protocol: QuicConnectionProtocol
    ) -> None:
        """Stop routing a retired connection ID; it must belong to ``protocol``."""
        assert self._protocols[cid] == protocol
        del self._protocols[cid]

    def _connection_terminated(self, protocol: QuicConnectionProtocol) -> None:
        """Remove every connection ID routed to a terminated ``protocol``."""
        # iterate over a snapshot because entries are deleted during the loop
        for cid, proto in list(self._protocols.items()):
            if proto == protocol:
                del self._protocols[cid]
async def serve(
    host: str,
    port: int,
    *,
    configuration: QuicConfiguration,
    create_protocol: Callable = QuicConnectionProtocol,
    session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
    session_ticket_handler: Optional[SessionTicketHandler] = None,
    stateless_retry: bool = False,
    stream_handler: Optional[QuicStreamHandler] = None,
) -> QuicServer:
    """
    Start a QUIC server at the given `host` and `port`.

    :func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration`
    containing TLS certificate and private key as the ``configuration`` argument.

    :func:`serve` also accepts the following optional arguments:

    * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
      manages the connection. It should be a callable or class accepting the same
      arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
      an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
    * ``session_ticket_fetcher`` is a callback which is invoked by the TLS
      engine when a session ticket is presented by the peer. It should return
      the session ticket with the specified ID or `None` if it is not found.
    * ``session_ticket_handler`` is a callback which is invoked by the TLS
      engine when a new session ticket is issued. It should store the session
      ticket for future lookup.
    * ``stateless_retry`` specifies whether a stateless retry should be
      performed prior to handling new connections.
    * ``stream_handler`` is a callback which is invoked whenever a stream is
      created. It must accept two arguments: a :class:`asyncio.StreamReader`
      and a :class:`asyncio.StreamWriter`.

    The return value is the :class:`QuicServer` datagram protocol bound to
    ``(host, port)``; call its ``close()`` method to shut the server down.
    """

    loop = asyncio.get_event_loop()

    # create_datagram_endpoint() expects a zero-argument protocol factory
    _, protocol = await loop.create_datagram_endpoint(
        lambda: QuicServer(
            configuration=configuration,
            create_protocol=create_protocol,
            session_ticket_fetcher=session_ticket_fetcher,
            session_ticket_handler=session_ticket_handler,
            stateless_retry=stateless_retry,
            stream_handler=stream_handler,
        ),
        local_addr=(host, port),
    )
    return cast(QuicServer, protocol)