SDSM-for-SDI / examples / doq_server.py
doq_server.py
Raw
import argparse
import asyncio
import logging
from typing import Dict, Optional

from dnslib.dns import DNSRecord
from quic_logger import QuicDirectoryLogger

from aioquic.asyncio import QuicConnectionProtocol, serve
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import ProtocolNegotiated, QuicEvent, StreamDataReceived
from aioquic.tls import SessionTicket

try:
    import uvloop
except ImportError:
    uvloop = None


class DnsConnection:
    def __init__(self, quic: QuicConnection):
        self._quic = quic

    def do_query(self, payload) -> bytes:
        q = DNSRecord.parse(payload)
        return q.send(self.resolver(), 53)

    def resolver(self) -> str:
        return args.resolver

    def handle_event(self, event: QuicEvent) -> None:
        if isinstance(event, StreamDataReceived):
            data = self.do_query(event.data)
            end_stream = False
            self._quic.send_stream_data(event.stream_id, data, end_stream)


class DnsServerProtocol(QuicConnectionProtocol):

    # -00 specifies 'dq', 'doq', and 'doq-h00' (the latter obviously tying to
    # the version of the draft it matches). This is confusing, so we'll just
    # support them all, until future drafts define conflicting behaviour.
    SUPPORTED_ALPNS = ["dq", "doq", "doq-h00"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._dns: Optional[DnsConnection] = None

    def quic_event_received(self, event: QuicEvent):
        if isinstance(event, ProtocolNegotiated):
            if event.alpn_protocol in DnsServerProtocol.SUPPORTED_ALPNS:
                self._dns = DnsConnection(self._quic)
        if self._dns is not None:
            self._dns.handle_event(event)


class SessionTicketStore:
    """
    Simple in-memory store for session tickets.
    """

    def __init__(self) -> None:
        self.tickets: Dict[bytes, SessionTicket] = {}

    def add(self, ticket: SessionTicket) -> None:
        self.tickets[ticket.ticket] = ticket

    def pop(self, label: bytes) -> Optional[SessionTicket]:
        return self.tickets.pop(label, None)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="DNS over QUIC server")
    parser.add_argument(
        "--host",
        type=str,
        default="::",
        help="listen on the specified address (defaults to ::)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=4784,
        help="listen on the specified port (defaults to 4784)",
    )
    parser.add_argument(
        "-k",
        "--private-key",
        type=str,
        required=True,
        help="load the TLS private key from the specified file",
    )
    parser.add_argument(
        "-c",
        "--certificate",
        type=str,
        required=True,
        help="load the TLS certificate from the specified file",
    )
    parser.add_argument(
        "--resolver",
        type=str,
        default="8.8.8.8",
        help="Upstream Classic DNS resolver to use",
    )
    parser.add_argument(
        "--retry", action="store_true", help="send a retry for new connections",
    )
    parser.add_argument(
        "-q",
        "--quic-log",
        type=str,
        help="log QUIC events to QLOG files in the specified directory",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="increase logging verbosity"
    )

    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
    )

    if args.quic_log:
        quic_logger = QuicDirectoryLogger(args.quic_log)
    else:
        quic_logger = None

    configuration = QuicConfiguration(
        alpn_protocols=["dq"],
        is_client=False,
        max_datagram_frame_size=65536,
        quic_logger=quic_logger,
    )

    configuration.load_cert_chain(args.certificate, args.private_key)

    ticket_store = SessionTicketStore()

    if uvloop is not None:
        uvloop.install()
    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        serve(
            args.host,
            args.port,
            configuration=configuration,
            create_protocol=DnsServerProtocol,
            session_ticket_fetcher=ticket_store.pop,
            session_ticket_handler=ticket_store.add,
            retry=args.retry,
        )
    )
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass