import argparse import asyncio import logging from typing import Dict, Optional from dnslib.dns import DNSRecord from quic_logger import QuicDirectoryLogger from aioquic.asyncio import QuicConnectionProtocol, serve from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.connection import QuicConnection from aioquic.quic.events import ProtocolNegotiated, QuicEvent, StreamDataReceived from aioquic.tls import SessionTicket try: import uvloop except ImportError: uvloop = None class DnsConnection: def __init__(self, quic: QuicConnection): self._quic = quic def do_query(self, payload) -> bytes: q = DNSRecord.parse(payload) return q.send(self.resolver(), 53) def resolver(self) -> str: return args.resolver def handle_event(self, event: QuicEvent) -> None: if isinstance(event, StreamDataReceived): data = self.do_query(event.data) end_stream = False self._quic.send_stream_data(event.stream_id, data, end_stream) class DnsServerProtocol(QuicConnectionProtocol): # -00 specifies 'dq', 'doq', and 'doq-h00' (the latter obviously tying to # the version of the draft it matches). This is confusing, so we'll just # support them all, until future drafts define conflicting behaviour. SUPPORTED_ALPNS = ["dq", "doq", "doq-h00"] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._dns: Optional[DnsConnection] = None def quic_event_received(self, event: QuicEvent): if isinstance(event, ProtocolNegotiated): if event.alpn_protocol in DnsServerProtocol.SUPPORTED_ALPNS: self._dns = DnsConnection(self._quic) if self._dns is not None: self._dns.handle_event(event) class SessionTicketStore: """ Simple in-memory store for session tickets. """ def __init__(self) -> None: self.tickets: Dict[bytes, SessionTicket] = {} def add(self, ticket: SessionTicket) -> None: self.tickets[ticket.ticket] = ticket def pop(self, label: bytes) -> Optional[SessionTicket]: return self.tickets.pop(label, None) if __name__ == "__main__": parser = argparse.ArgumentParser(description="DNS over QUIC server") parser.add_argument( "--host", type=str, default="::", help="listen on the specified address (defaults to ::)", ) parser.add_argument( "--port", type=int, default=4784, help="listen on the specified port (defaults to 4784)", ) parser.add_argument( "-k", "--private-key", type=str, required=True, help="load the TLS private key from the specified file", ) parser.add_argument( "-c", "--certificate", type=str, required=True, help="load the TLS certificate from the specified file", ) parser.add_argument( "--resolver", type=str, default="8.8.8.8", help="Upstream Classic DNS resolver to use", ) parser.add_argument( "--retry", action="store_true", help="send a retry for new connections", ) parser.add_argument( "-q", "--quic-log", type=str, help="log QUIC events to QLOG files in the specified directory", ) parser.add_argument( "-v", "--verbose", action="store_true", help="increase logging verbosity" ) args = parser.parse_args() logging.basicConfig( format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG if args.verbose else logging.INFO, ) if args.quic_log: quic_logger = QuicDirectoryLogger(args.quic_log) else: quic_logger = None configuration = QuicConfiguration( alpn_protocols=["dq"], is_client=False, max_datagram_frame_size=65536, quic_logger=quic_logger, ) configuration.load_cert_chain(args.certificate, args.private_key) ticket_store = SessionTicketStore() if uvloop is not None: uvloop.install() loop = asyncio.get_event_loop() loop.run_until_complete( serve( args.host, args.port, configuration=configuration, create_protocol=DnsServerProtocol, session_ticket_fetcher=ticket_store.pop, session_ticket_handler=ticket_store.add, retry=args.retry, ) ) try: loop.run_forever() except KeyboardInterrupt: pass