import asyncio
import ipaddress
import socket
import sys
from typing import AsyncGenerator, Callable, Optional, cast

from ..quic.configuration import QuicConfiguration
from ..quic.connection import QuicConnection
from ..tls import SessionTicketHandler
from .compat import asynccontextmanager
from .protocol import QuicConnectionProtocol, QuicStreamHandler

__all__ = ["connect"]


@asynccontextmanager
async def connect(
    host: str,
    port: int,
    *,
    configuration: Optional[QuicConfiguration] = None,
    create_protocol: Optional[Callable] = QuicConnectionProtocol,
    session_ticket_handler: Optional[SessionTicketHandler] = None,
    stream_handler: Optional[QuicStreamHandler] = None,
    wait_connected: bool = True,
    local_port: int = 0,
) -> AsyncGenerator[QuicConnectionProtocol, None]:
    """
    Connect to a QUIC server at the given `host` and `port`.

    :func:`connect` is an asynchronous context manager. Entering it yields a
    :class:`~aioquic.asyncio.QuicConnectionProtocol` which can be used to
    create streams. Leaving it closes the connection, waits for the close to
    complete and then closes the underlying datagram transport, including
    when the handshake fails or is cancelled.

    :func:`connect` also accepts the following optional arguments:

    * ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration`
      configuration object. If `host` is not an IP address, its
      ``server_name`` is set to `host` to enable SNI.
    * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
      manages the connection. It should be a callable or class accepting the same
      arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
      an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
      Passing ``None`` selects :class:`~aioquic.asyncio.QuicConnectionProtocol`.
    * ``session_ticket_handler`` is a callback which is invoked by the TLS
      engine when a new session ticket is received.
    * ``stream_handler`` is a callback which is invoked whenever a stream is
      created. It must accept two arguments: a :class:`asyncio.StreamReader`
      and a :class:`asyncio.StreamWriter`.
    * ``wait_connected`` controls whether the context manager waits for the
      protocol to report that it is connected before yielding it.
    * ``local_port`` is the UDP port number that this client wants to bind.
    """
    loop = asyncio.get_event_loop()
    local_host = "::"

    # The signature allows None; fall back to the default protocol class
    # instead of failing later with "'NoneType' object is not callable".
    if create_protocol is None:
        create_protocol = QuicConnectionProtocol

    # if host is not an IP address, pass it to enable SNI
    try:
        ipaddress.ip_address(host)
    except ValueError:
        server_name: Optional[str] = host
    else:
        server_name = None

    # lookup remote address
    infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    addr = infos[0][4]
    if len(addr) == 2:
        # a 2-tuple sockaddr means IPv4: determine behaviour for IPv4
        if sys.platform == "win32":
            # on Windows, we must use an IPv4 socket to reach an IPv4 host
            local_host = "0.0.0.0"
        else:
            # other platforms support dual-stack sockets
            addr = ("::ffff:" + addr[0], addr[1], 0, 0)

    # prepare QUIC connection
    if configuration is None:
        configuration = QuicConfiguration(is_client=True)
    if server_name is not None:
        configuration.server_name = server_name
    connection = QuicConnection(
        configuration=configuration, session_ticket_handler=session_ticket_handler
    )

    # connect
    transport, protocol = await loop.create_datagram_endpoint(
        lambda: create_protocol(connection, stream_handler=stream_handler),
        local_addr=(local_host, local_port),
    )
    protocol = cast(QuicConnectionProtocol, protocol)

    # From here on the endpoint exists, so every exit path (handshake error,
    # cancellation, exception in the caller's body) must release it.
    try:
        protocol.connect(addr)
        if wait_connected:
            await protocol.wait_connected()
        yield protocol
    finally:
        try:
            protocol.close()
            await protocol.wait_closed()
        finally:
            # Always release the UDP socket, even if the graceful close
            # above is interrupted.
            transport.close()