import argparse
import asyncio
import logging
import pickle
import ssl
from typing import Optional, cast

from dnslib.dns import QTYPE, DNSQuestion, DNSRecord
from quic_logger import QuicDirectoryLogger

from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent, StreamDataReceived

logger = logging.getLogger("client")

# Parsed command-line options. Only populated when this file is executed as a
# script; it stays None when the module is imported, so that
# save_session_ticket() can be used safely in both situations.
args: Optional[argparse.Namespace] = None


class DoQClient(QuicConnectionProtocol):
    """QUIC protocol that sends a DNS query and waits for stream data.

    Only a single pending query is tracked: each call to :meth:`query`
    stores a new waiter, replacing any previous one.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Future resolved by quic_event_received() once stream data arrives;
        # None while no query is pending.
        self._ack_waiter: Optional[asyncio.Future[None]] = None

    async def query(self, query_type: str, dns_query: str) -> None:
        """Send a DNS question and wait until stream data is received.

        :param query_type: name of an attribute of ``QTYPE`` (e.g. ``"A"``).
        :param dns_query: the name to ask about.
        """
        query = DNSRecord(q=DNSQuestion(dns_query, getattr(QTYPE, query_type)))
        stream_id = self._quic.get_next_available_stream_id()
        logger.debug("Stream ID: %s", stream_id)
        end_stream = False
        self._quic.send_stream_data(stream_id, bytes(query.pack()), end_stream)
        waiter = self._loop.create_future()
        self._ack_waiter = waiter
        self.transmit()

        # shield() keeps the waiter itself from being cancelled if the caller
        # is cancelled while awaiting.
        return await asyncio.shield(waiter)

    def quic_event_received(self, event: QuicEvent) -> None:
        """Parse and log received stream data, then wake the pending query.

        Events are ignored unless a query is pending and the event is a
        ``StreamDataReceived``.
        """
        if self._ack_waiter is not None:
            if isinstance(event, StreamDataReceived):
                answer = DNSRecord.parse(event.data)
                logger.info(answer)

                # Clear the slot before resolving so the waiter is only
                # completed once.
                waiter = self._ack_waiter
                self._ack_waiter = None
                waiter.set_result(None)


def save_session_ticket(ticket):
    """
    Callback which is invoked by the TLS engine when a new session ticket
    is received.

    The ticket is pickled to the file named by ``--session-ticket``. Nothing
    is written when that option was not given or when the command line has
    not been parsed (module imported rather than run as a script).
    """
    logger.info("New session ticket received")
    if args is not None and args.session_ticket:
        with open(args.session_ticket, "wb") as fp:
            pickle.dump(ticket, fp)


async def run(
    configuration: QuicConfiguration,
    host: str,
    port: int,
    query_type: str,
    dns_query: str,
) -> None:
    """Connect to ``host:port`` and perform a single DNS query.

    :param configuration: QUIC configuration used for the connection.
    :param host: remote host name or IP address.
    :param port: remote port number.
    :param query_type: DNS record type name passed to ``DoQClient.query``.
    :param dns_query: the name to ask about.
    """
    logger.debug("Connecting to %s:%s", host, port)
    async with connect(
        host,
        port,
        configuration=configuration,
        session_ticket_handler=save_session_ticket,
        create_protocol=DoQClient,
    ) as client:
        client = cast(DoQClient, client)
        logger.debug("Sending DNS query")
        await client.query(query_type, dns_query)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DNS over QUIC client")
    parser.add_argument(
        "-t",
        "--type",
        type=str,
        help="Type of record to query (used when --dns_type is not given)",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="The remote peer's host name or IP address",
    )
    parser.add_argument(
        "--port", type=int, default=784, help="The remote peer's port number"
    )
    parser.add_argument(
        "-k",
        "--insecure",
        action="store_true",
        help="do not validate server certificate",
    )
    parser.add_argument(
        "--ca-certs", type=str, help="load CA certificates from the specified file"
    )
    parser.add_argument(
        "--dns_type", help="The DNS query type to send (default: A)"
    )
    parser.add_argument("--query", help="Domain to query")
    parser.add_argument(
        "-q",
        "--quic-log",
        type=str,
        help="log QUIC events to QLOG files in the specified directory",
    )
    parser.add_argument(
        "-l",
        "--secrets-log",
        type=str,
        help="log secrets to a file, for use with Wireshark",
    )
    parser.add_argument(
        "-s",
        "--session-ticket",
        type=str,
        help="read and write session ticket from the specified file",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="increase logging verbosity"
    )

    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
    )

    # --dns_type wins when both are given, so existing invocations behave as
    # before; without either option the query type falls back to "A" instead
    # of failing inside getattr(QTYPE, None).
    query_type = args.dns_type or args.type or "A"

    configuration = QuicConfiguration(
        alpn_protocols=["dq"], is_client=True, max_datagram_frame_size=65536
    )
    if args.ca_certs:
        configuration.load_verify_locations(args.ca_certs)
    if args.insecure:
        configuration.verify_mode = ssl.CERT_NONE
    if args.quic_log:
        configuration.quic_logger = QuicDirectoryLogger(args.quic_log)

    secrets_log_file = None
    try:
        if args.secrets_log:
            # Kept open for the lifetime of the connection and closed in the
            # finally clause below.
            secrets_log_file = open(args.secrets_log, "a")
            configuration.secrets_log_file = secrets_log_file
        if args.session_ticket:
            try:
                with open(args.session_ticket, "rb") as fp:
                    # NOTE(security): pickle.load can execute arbitrary code;
                    # only point --session-ticket at a file this client wrote.
                    configuration.session_ticket = pickle.load(fp)
            except FileNotFoundError:
                logger.debug(f"Unable to read {args.session_ticket}")
        else:
            logger.debug("No session ticket defined...")

        # asyncio.run() creates and closes the event loop itself; calling
        # asyncio.get_event_loop() without a running loop is deprecated.
        asyncio.run(
            run(
                configuration=configuration,
                host=args.host,
                port=args.port,
                query_type=query_type,
                dns_query=args.query,
            )
        )
    finally:
        if secrets_log_file is not None:
            secrets_log_file.close()