1import asyncio
2import ipaddress
3import socket
4import sys
5from typing import AsyncGenerator, Callable, Optional, cast
6
7from ..quic.configuration import QuicConfiguration
8from ..quic.connection import QuicConnection
9from ..tls import SessionTicketHandler
10from .compat import asynccontextmanager
11from .protocol import QuicConnectionProtocol, QuicStreamHandler
12
13__all__ = ["connect"]
14
15
@asynccontextmanager
async def connect(
    host: str,
    port: int,
    *,
    configuration: Optional[QuicConfiguration] = None,
    create_protocol: Optional[Callable] = QuicConnectionProtocol,
    session_ticket_handler: Optional[SessionTicketHandler] = None,
    stream_handler: Optional[QuicStreamHandler] = None,
    wait_connected: bool = True,
    local_port: int = 0,
) -> AsyncGenerator[QuicConnectionProtocol, None]:
    """
    Connect to a QUIC server at the given `host` and `port`.

    :meth:`connect()` returns an asynchronous context manager. Entering it
    with ``async with`` yields a :class:`~aioquic.asyncio.QuicConnectionProtocol`
    which can be used to create streams. When the ``async with`` block is
    left, whether normally or because of an exception, the connection is
    closed, its closure is awaited and the local UDP endpoint is released.

    :func:`connect` also accepts the following optional arguments:

    * ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration`
      configuration object. If omitted, a client configuration is created.
      If ``host`` is not an IP address, the configuration's ``server_name``
      is set to ``host`` (this modifies the object passed in).
    * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
      manages the connection. It should be a callable or class accepting the same
      arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
      an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
    * ``session_ticket_handler`` is a callback which is invoked by the TLS
      engine when a new session ticket is received.
    * ``stream_handler`` is a callback which is invoked whenever a stream is
      created. It must accept two arguments: a :class:`asyncio.StreamReader`
      and a :class:`asyncio.StreamWriter`.
    * ``wait_connected`` controls whether entering the context waits for the
      connection to be established (``True``, the default) or yields the
      protocol immediately after initiating the connection.
    * ``local_port`` is the UDP port number that this client wants to bind.
    """
    loop = asyncio.get_event_loop()
    local_host = "::"

    # if host is not an IP address, pass it to enable SNI
    try:
        ipaddress.ip_address(host)
        server_name = None
    except ValueError:
        server_name = host

    # lookup remote address
    infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    addr = infos[0][4]
    if len(addr) == 2:
        # determine behaviour for IPv4
        if sys.platform == "win32":
            # on Windows, we must use an IPv4 socket to reach an IPv4 host
            local_host = "0.0.0.0"
        else:
            # other platforms support dual-stack sockets: the local socket is
            # bound to "::", so express the IPv4 peer as an IPv4-mapped IPv6
            # address (4-tuple form expected for AF_INET6 sockets)
            addr = ("::ffff:" + addr[0], addr[1], 0, 0)

    # prepare QUIC connection
    if configuration is None:
        configuration = QuicConfiguration(is_client=True)
    if server_name is not None:
        configuration.server_name = server_name
    connection = QuicConnection(
        configuration=configuration, session_ticket_handler=session_ticket_handler
    )

    # connect; keep the transport so the UDP socket can be released on exit
    transport, protocol = await loop.create_datagram_endpoint(
        lambda: create_protocol(connection, stream_handler=stream_handler),
        local_addr=(local_host, local_port),
    )
    protocol = cast(QuicConnectionProtocol, protocol)
    try:
        # connecting and waiting for the handshake live inside the try so a
        # failed / cancelled handshake still closes the connection and socket
        protocol.connect(addr)
        if wait_connected:
            await protocol.wait_connected()
        yield protocol
    finally:
        try:
            protocol.close()
            await protocol.wait_closed()
        finally:
            # always release the datagram endpoint, even if waiting for the
            # connection to close is interrupted
            transport.close()
95