1import asyncio
2import ipaddress
3import socket
4import sys
5from typing import AsyncGenerator, Callable, Optional, cast
7from ..quic.configuration import QuicConfiguration
8from ..quic.connection import QuicConnection
9from ..tls import SessionTicketHandler
10from .compat import asynccontextmanager
11from .protocol import QuicConnectionProtocol, QuicStreamHandler
# Names exported by ``from <module> import *``: ``connect`` is the only public API here.
__all__ = ["connect"]
async def connect(
    host: str,
    port: int,
    *,
    configuration: Optional[QuicConfiguration] = None,
    create_protocol: Optional[Callable] = QuicConnectionProtocol,
    session_ticket_handler: Optional[SessionTicketHandler] = None,
    stream_handler: Optional[QuicStreamHandler] = None,
    wait_connected: bool = True,
    local_port: int = 0,
) -> AsyncGenerator[QuicConnectionProtocol, None]:
    """
    Connect to a QUIC server at the given `host` and `port`.

    :meth:`connect()` is an asynchronous generator intended to be used as an
    asynchronous context manager (``async with connect(...) as protocol:``).
    Entering it yields a :class:`~aioquic.asyncio.QuicConnectionProtocol`
    which can be used to create streams.

    On exit, the following happens in order:

    * the QUIC connection is closed;
    * the function waits for the protocol to report that it has closed;
    * the underlying UDP transport is released.

    This cleanup also runs when the handshake fails or when the body of the
    ``async with`` block raises.

    :func:`connect` also accepts the following optional arguments:

    * ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration`
      configuration object. If ``host`` is not an IP address, its
      ``server_name`` attribute is overwritten with ``host`` (used for SNI).
    * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
      manages the connection. It should be a callable or class accepting the same
      arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
      an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
    * ``session_ticket_handler`` is a callback which is invoked by the TLS
      engine when a new session ticket is received.
    * ``stream_handler`` is a callback which is invoked whenever a stream is
      created. It must accept two arguments: a :class:`asyncio.StreamReader`
      and a :class:`asyncio.StreamWriter`.
    * ``wait_connected`` controls whether
      :meth:`~aioquic.asyncio.QuicConnectionProtocol.wait_connected` is awaited
      before the protocol is yielded (defaults to ``True``).
    * ``local_port`` is the UDP port number that this client wants to bind.
      The default of ``0`` lets the operating system choose a free port.
    """
    # Called from inside a coroutine, this returns the running loop. It is
    # kept instead of get_running_loop() because the package ships a compat
    # shim for older Python versions.
    loop = asyncio.get_event_loop()
    local_host = "::"

    # If host is not an IP address, pass it along to enable SNI.
    try:
        ipaddress.ip_address(host)
        server_name = None
    except ValueError:
        server_name = host

    # Look up the remote address; only the first result is used.
    infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    addr = infos[0][4]
    if len(addr) == 2:
        # A 2-tuple sockaddr is an IPv4 (host, port) pair.
        if sys.platform == "win32":
            # On Windows, we must use an IPv4 socket to reach an IPv4 host.
            local_host = ""
        else:
            # Other platforms support dual-stack sockets, so map the IPv4
            # address into the IPv6 space and keep the "::" local socket.
            addr = ("::ffff:" + addr[0], addr[1], 0, 0)

    # Prepare the QUIC connection.
    if configuration is None:
        configuration = QuicConfiguration(is_client=True)
    if server_name is not None:
        # NOTE: this mutates a caller-supplied configuration object.
        configuration.server_name = server_name
    connection = QuicConnection(
        configuration=configuration, session_ticket_handler=session_ticket_handler
    )

    # Create the UDP endpoint. Keep the transport so it can be closed on exit;
    # discarding it leaked the socket.
    transport, protocol = await loop.create_datagram_endpoint(
        lambda: create_protocol(connection, stream_handler=stream_handler),
        local_addr=(local_host, local_port),
    )
    protocol = cast(QuicConnectionProtocol, protocol)
    try:
        # The handshake sits inside the try so a failure still triggers cleanup.
        protocol.connect(addr)
        if wait_connected:
            await protocol.wait_connected()
        yield protocol
    finally:
        # Always shut down gracefully and release the socket, whether the
        # handshake failed, the caller's block raised, or it exited normally.
        protocol.close()
        await protocol.wait_closed()
        transport.close()