1import asyncio
2import os
3from functools import partial
4from typing import Callable, Dict, Optional, Text, Union, cast
5
6from ..buffer import Buffer
7from ..quic.configuration import QuicConfiguration
8from ..quic.connection import NetworkAddress, QuicConnection
9from ..quic.packet import (
10    PACKET_TYPE_INITIAL,
11    encode_quic_retry,
12    encode_quic_version_negotiation,
13    pull_quic_header,
14)
15from ..quic.retry import QuicRetryTokenHandler
16from ..tls import SessionTicketFetcher, SessionTicketHandler
17from .protocol import QuicConnectionProtocol, QuicStreamHandler
18
# Only the ``serve`` coroutine is exported as this module's public API.
__all__ = ["serve"]
20
21
class QuicServer(asyncio.DatagramProtocol):
    """
    UDP datagram protocol which dispatches incoming QUIC packets to
    per-connection protocol instances, looked up by destination connection ID.

    A new connection is created when an Initial packet carried in a datagram of
    at least 1200 bytes arrives for an unknown connection ID. If
    ``stateless_retry`` is enabled, a Retry packet carrying an address
    validation token is sent first, and the connection is only created once
    the client echoes a valid token.
    """

    def __init__(
        self,
        *,
        configuration: QuicConfiguration,
        create_protocol: Callable = QuicConnectionProtocol,
        session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
        session_ticket_handler: Optional[SessionTicketHandler] = None,
        stateless_retry: bool = False,
        stream_handler: Optional[QuicStreamHandler] = None,
    ) -> None:
        self._configuration = configuration
        self._create_protocol = create_protocol
        self._loop = asyncio.get_event_loop()
        # Maps each connection ID in use (the client-chosen initial destination
        # CID, the connection's host CID and any CIDs issued later) to the
        # protocol handling that connection. One protocol appears under
        # several keys.
        self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
        self._session_ticket_fetcher = session_ticket_fetcher
        self._session_ticket_handler = session_ticket_handler
        self._transport: Optional[asyncio.DatagramTransport] = None

        self._stream_handler = stream_handler

        self._retry: Optional[QuicRetryTokenHandler]
        if stateless_retry:
            self._retry = QuicRetryTokenHandler()
        else:
            self._retry = None

    def close(self) -> None:
        """
        Close every active connection, then the underlying UDP transport.

        Safe to call before :meth:`connection_made` has run, in which case
        there is no transport to close.
        """
        # A protocol is registered under several connection IDs: close each once.
        for protocol in set(self._protocols.values()):
            protocol.close()
        self._protocols.clear()
        if self._transport is not None:
            self._transport.close()

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store the datagram transport used to send packets."""
        self._transport = cast(asyncio.DatagramTransport, transport)

    def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
        """
        Route an incoming datagram to the connection it belongs to.

        - Datagrams whose QUIC header cannot be parsed are silently dropped.
        - Long-header packets with an unsupported version get a Version
          Negotiation reply. This only happens if the datagram is large enough
          to start a connection and is not itself a Version Negotiation packet.
        - Initial packets for unknown connection IDs create a new connection,
          optionally after a stateless retry.
        """
        data = cast(bytes, data)
        buf = Buffer(data=data)

        try:
            header = pull_quic_header(
                buf, host_cid_length=self._configuration.connection_id_length
            )
        except ValueError:
            return

        # version negotiation
        if (
            header.version is not None
            and header.version not in self._configuration.supported_versions
        ):
            # Reply only to datagrams large enough to initiate a connection, so
            # tiny spoofed packets cannot be amplified towards a victim.
            # Never reply to a Version Negotiation packet (version 0), which
            # would let two endpoints bounce packets back and forth forever.
            if header.version != 0 and len(data) >= 1200:
                self._transport.sendto(
                    encode_quic_version_negotiation(
                        source_cid=header.destination_cid,
                        destination_cid=header.source_cid,
                        supported_versions=self._configuration.supported_versions,
                    ),
                    addr,
                )
            return

        protocol = self._protocols.get(header.destination_cid, None)
        original_connection_id: Optional[bytes] = None
        if (
            protocol is None
            and len(data) >= 1200
            and header.packet_type == PACKET_TYPE_INITIAL
        ):
            # stateless retry
            if self._retry is not None:
                if not header.token:
                    # No token yet: ask the client to retry with a token bound
                    # to its address and original destination CID.
                    self._transport.sendto(
                        encode_quic_retry(
                            version=header.version,
                            source_cid=os.urandom(8),
                            destination_cid=header.source_cid,
                            original_destination_cid=header.destination_cid,
                            retry_token=self._retry.create_token(
                                addr, header.destination_cid
                            ),
                        ),
                        addr,
                    )
                    return

                # Validate the echoed token; drop the packet if it is invalid.
                try:
                    original_connection_id = self._retry.validate_token(
                        addr, header.token
                    )
                except ValueError:
                    return

            # create new connection
            connection = QuicConnection(
                configuration=self._configuration,
                logger_connection_id=original_connection_id or header.destination_cid,
                original_connection_id=original_connection_id,
                session_ticket_fetcher=self._session_ticket_fetcher,
                session_ticket_handler=self._session_ticket_handler,
            )
            protocol = self._create_protocol(
                connection, stream_handler=self._stream_handler
            )
            protocol.connection_made(self._transport)

            # Keep the routing table in sync with the connection's CIDs.
            protocol._connection_id_issued_handler = partial(
                self._connection_id_issued, protocol=protocol
            )
            protocol._connection_id_retired_handler = partial(
                self._connection_id_retired, protocol=protocol
            )
            protocol._connection_terminated_handler = partial(
                self._connection_terminated, protocol=protocol
            )

            self._protocols[header.destination_cid] = protocol
            self._protocols[connection.host_cid] = protocol

        if protocol is not None:
            protocol.datagram_received(data, addr)

    def _connection_id_issued(
        self, cid: bytes, protocol: QuicConnectionProtocol
    ) -> None:
        """Route datagrams addressed to a newly issued ``cid`` to ``protocol``."""
        self._protocols[cid] = protocol

    def _connection_id_retired(
        self, cid: bytes, protocol: QuicConnectionProtocol
    ) -> None:
        """
        Stop routing a retired ``cid``.

        Raises KeyError if ``cid`` is not registered. An assertion checks that
        ``cid`` belongs to ``protocol``; this check is skipped under ``-O``.
        """
        assert self._protocols[cid] == protocol
        del self._protocols[cid]

    def _connection_terminated(self, protocol: QuicConnectionProtocol) -> None:
        """Remove every connection ID routed to a terminated ``protocol``."""
        # Snapshot the items since entries are deleted while scanning.
        for cid, proto in list(self._protocols.items()):
            if proto == protocol:
                del self._protocols[cid]
159
160
async def serve(
    host: str,
    port: int,
    *,
    configuration: QuicConfiguration,
    create_protocol: Callable = QuicConnectionProtocol,
    session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
    session_ticket_handler: Optional[SessionTicketHandler] = None,
    stateless_retry: bool = False,
    stream_handler: Optional[QuicStreamHandler] = None,
) -> QuicServer:
    """
    Start a QUIC server at the given `host` and `port`.

    :func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration`
    containing TLS certificate and private key as the ``configuration`` argument.

    :func:`serve` also accepts the following optional arguments:

    * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
      manages the connection. It should be a callable or class accepting the same
      arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
      an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
    * ``session_ticket_fetcher`` is a callback which is invoked by the TLS
      engine when a session ticket is presented by the peer. It should return
      the session ticket with the specified ID or `None` if it is not found.
    * ``session_ticket_handler`` is a callback which is invoked by the TLS
      engine when a new session ticket is issued. It should store the session
      ticket for future lookup.
    * ``stateless_retry`` specifies whether a stateless retry should be
      performed prior to handling new connections.
    * ``stream_handler`` is a callback which is invoked whenever a stream is
      created. It must accept two arguments: a :class:`asyncio.StreamReader`
      and a :class:`asyncio.StreamWriter`.

    The returned :class:`QuicServer` is the datagram protocol bound to
    ``(host, port)``; call its ``close()`` method to shut the server down.
    """

    loop = asyncio.get_event_loop()

    # The factory builds one QuicServer for the UDP endpoint; that single
    # instance multiplexes all client connections.
    _, protocol = await loop.create_datagram_endpoint(
        lambda: QuicServer(
            configuration=configuration,
            create_protocol=create_protocol,
            session_ticket_fetcher=session_ticket_fetcher,
            session_ticket_handler=session_ticket_handler,
            stateless_retry=stateless_retry,
            stream_handler=stream_handler,
        ),
        local_addr=(host, port),
    )
    return cast(QuicServer, protocol)
211