"""HTTP/3 client: fetch a URL over QUIC using aioquic and an httpx dispatcher."""

import argparse
import asyncio
import logging
import pickle
import sys
import time
from collections import deque
from typing import Deque, Dict, Optional, cast
from urllib.parse import urlparse

from httpx import AsyncClient
from httpx.config import Timeout
from httpx.dispatch.base import AsyncDispatcher
from httpx.models import Request, Response
from quic_logger import QuicDirectoryLogger

from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.h3.events import DataReceived, H3Event, HeadersReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import ConnectionTerminated, QuicEvent

logger = logging.getLogger("client")


class H3Dispatcher(QuicConnectionProtocol, AsyncDispatcher):
    """
    httpx dispatcher which sends requests over an HTTP/3 (QUIC) connection.

    Each request is sent on a new QUIC stream.  The HTTP/3 events received on
    that stream are buffered until the stream ends, and are then assembled
    into an httpx ``Response``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._http = H3Connection(self._quic)
        # stream ID -> HTTP/3 events received so far for that request
        self._request_events: Dict[int, Deque[H3Event]] = {}
        # stream ID -> future resolved with the buffered events once the
        # response stream ends
        self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {}

    async def send(
        self, request: Request, timeout: Optional[Timeout] = None
    ) -> Response:
        """
        Send ``request`` on a new stream and wait for the complete response.

        :param request: the httpx request to send; its body is read in full.
        :param timeout: accepted for dispatcher-interface compatibility; it is
            not used by this implementation.
        :return: a ``Response`` with ``http_version="HTTP/3"``.
        :raises ConnectionError: if the QUIC connection terminates before the
            response stream ends.
        """
        stream_id = self._quic.get_next_available_stream_id()

        # prepare request: pseudo-headers first, then the regular headers.
        # "connection" and "host" are dropped because ":authority" carries the
        # host and connection-specific headers do not apply to HTTP/3.
        self._http.send_headers(
            stream_id=stream_id,
            headers=[
                (b":method", request.method.encode()),
                (b":scheme", request.url.scheme.encode()),
                (b":authority", str(request.url.authority).encode()),
                (b":path", request.url.full_path.encode()),
            ]
            + [
                (k.encode(), v.encode())
                for (k, v) in request.headers.items()
                if k not in ("connection", "host")
            ],
        )
        self._http.send_data(stream_id=stream_id, data=request.read(), end_stream=True)

        # register the waiter before transmitting so no response event is missed
        waiter = self._loop.create_future()
        self._request_events[stream_id] = deque()
        self._request_waiter[stream_id] = waiter
        self.transmit()

        # process response
        events: Deque[H3Event] = await asyncio.shield(waiter)
        content = b""
        headers = []
        status_code = None
        for event in events:
            if isinstance(event, HeadersReceived):
                for header, value in event.headers:
                    if header == b":status":
                        status_code = int(value.decode())
                    elif header[0:1] != b":":
                        # skip any other pseudo-headers; keep regular ones
                        headers.append((header.decode(), value.decode()))
            elif isinstance(event, DataReceived):
                content += event.data

        return Response(
            status_code=status_code,
            http_version="HTTP/3",
            headers=headers,
            content=content,
            request=request,
        )

    def http_event_received(self, event: H3Event) -> None:
        """
        Buffer header/data events for streams with a pending request.

        When an event marks the end of its stream, the pending future for that
        stream is resolved with all the events buffered for it.
        """
        if isinstance(event, (HeadersReceived, DataReceived)):
            stream_id = event.stream_id
            if stream_id in self._request_events:
                self._request_events[stream_id].append(event)
                if event.stream_ended:
                    request_waiter = self._request_waiter.pop(stream_id)
                    request_waiter.set_result(self._request_events.pop(stream_id))

    def _fail_pending_requests(self, event: ConnectionTerminated) -> None:
        """Fail every request still waiting for a response with ConnectionError."""
        for waiter in self._request_waiter.values():
            if not waiter.done():
                waiter.set_exception(
                    ConnectionError(
                        "QUIC connection terminated (error code %d): %s"
                        % (event.error_code, event.reason_phrase)
                    )
                )
        self._request_waiter.clear()
        self._request_events.clear()

    def quic_event_received(self, event: QuicEvent) -> None:
        """Feed a QUIC event to the HTTP/3 layer and dispatch the results."""
        #  pass event to the HTTP layer
        if self._http is not None:
            for http_event in self._http.handle_event(event):
                self.http_event_received(http_event)

        # without this, requests whose stream never ends would wait forever
        if isinstance(event, ConnectionTerminated):
            self._fail_pending_requests(event)


def save_session_ticket(ticket):
    """
    Callback which is invoked by the TLS engine when a new session ticket is
    received.

    The ticket is pickled to the file named by ``--session-ticket``, if given.
    This reads the module-level ``args`` parsed in the ``__main__`` block, so it
    is only usable when this file is run as a script.
    """
    logger.info("New session ticket received")
    if args.session_ticket:
        with open(args.session_ticket, "wb") as fp:
            pickle.dump(ticket, fp)


async def run(configuration: QuicConfiguration, url: str, data: Optional[str]) -> None:
    """
    Fetch ``url`` over HTTP/3 and write the response to stdout/stderr.

    :param configuration: QUIC configuration used for the connection.
    :param url: an ``https://`` URL, optionally with an explicit port
        (IPv6 literals must be bracketed, e.g. ``https://[::1]:4433/``).
    :param data: if not None, sent as a form-encoded POST body; otherwise a
        GET request is made.
    :raises ValueError: if the URL is not ``https://``, has no host, or has an
        invalid port.
    """
    # parse URL; .hostname/.port understand bracketed IPv6 literals and
    # userinfo, which splitting the netloc on ":" does not
    parsed = urlparse(url)
    if parsed.scheme != "https":
        raise ValueError("Only https:// URLs are supported.")
    host = parsed.hostname
    if not host:
        raise ValueError("URL must include a host.")
    port = parsed.port if parsed.port is not None else 443

    async with connect(
        host,
        port,
        configuration=configuration,
        create_protocol=H3Dispatcher,
        session_ticket_handler=save_session_ticket,
    ) as dispatch:
        client = AsyncClient(dispatch=cast(AsyncDispatcher, dispatch))

        # perform request; perf_counter is monotonic and high-resolution
        start = time.perf_counter()
        if data is not None:
            response = await client.post(
                url,
                data=data.encode(),
                headers={"content-type": "application/x-www-form-urlencoded"},
            )
        else:
            response = await client.get(url)
        elapsed = time.perf_counter() - start

        # print speed (guard against a zero elapsed time on very fast requests)
        octets = len(response.content)
        rate_mbps = octets * 8 / elapsed / 1000000 if elapsed > 0 else float("inf")
        logger.info(
            "Received %d bytes in %.1f s (%.3f Mbps)", octets, elapsed, rate_mbps
        )

        # print response: headers to stderr, body to stdout
        for header, value in response.headers.items():
            sys.stderr.write(header + ": " + value + "\r\n")
        sys.stderr.write("\r\n")
        sys.stdout.buffer.write(response.content)
        sys.stdout.buffer.flush()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="HTTP/3 client")
    parser.add_argument("url", type=str, help="the URL to query (must be HTTPS)")
    parser.add_argument(
        "-d", "--data", type=str, help="send the specified data in a POST request"
    )
    parser.add_argument(
        "-q",
        "--quic-log",
        type=str,
        help="log QUIC events to QLOG files in the specified directory",
    )
    parser.add_argument(
        "-l",
        "--secrets-log",
        type=str,
        help="log secrets to a file, for use with Wireshark",
    )
    parser.add_argument(
        "-s",
        "--session-ticket",
        type=str,
        help="read and write session ticket from the specified file",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="increase logging verbosity"
    )

    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
    )

    # prepare configuration
    configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN)
    if args.quic_log:
        configuration.quic_logger = QuicDirectoryLogger(args.quic_log)
    secrets_log_file = None
    if args.secrets_log:
        secrets_log_file = open(args.secrets_log, "a")
        configuration.secrets_log_file = secrets_log_file
    if args.session_ticket:
        try:
            with open(args.session_ticket, "rb") as fp:
                # NOTE(security): pickle.load can execute arbitrary code; only
                # point --session-ticket at a file this client wrote itself.
                configuration.session_ticket = pickle.load(fp)
        except FileNotFoundError:
            # no ticket saved yet: fall back to a full handshake
            pass

    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(
            run(configuration=configuration, url=args.url, data=args.data)
        )
    finally:
        # close the secrets log even if the request fails
        if secrets_log_file is not None:
            secrets_log_file.close()