SDSM-for-SDI / src / aioquic / quic / retry.py
retry.py
Raw
import ipaddress
from typing import Tuple

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

from ..buffer import Buffer
from ..tls import pull_opaque, push_opaque
from .connection import NetworkAddress


def encode_address(addr: NetworkAddress) -> bytes:
    """Serialize a network address as the packed IP followed by the port.

    ``addr[0]`` is parsed with :func:`ipaddress.ip_address` and contributes
    its packed form; ``addr[1]`` is appended as two bytes, high byte first.
    Only the first two elements of ``addr`` are read.
    """
    packed_host = ipaddress.ip_address(addr[0]).packed
    port = addr[1]
    # bytes() rejects values outside 0..255, so an out-of-range port raises.
    port_bytes = bytes((port >> 8, port & 0xFF))
    return packed_host + port_bytes


class QuicRetryTokenHandler:
    """Create and validate retry tokens bound to a remote address.

    A token is the RSA-OAEP encryption of three length-prefixed fields: the
    encoded remote address, the original destination connection ID and the
    retry source connection ID. The key pair is generated per instance, so a
    token can only be validated by the handler that created it.
    """

    def __init__(self) -> None:
        # Use a 2048-bit modulus: 1024-bit RSA is below the currently
        # recommended minimum key size. With OAEP/SHA-256 a 1024-bit key
        # also caps the plaintext at 62 bytes, which an IPv6 address plus two
        # 20-byte connection IDs (61 bytes) only just fits; 2048 bits raises
        # the cap to 190 bytes.
        self._key = rsa.generate_private_key(
            public_exponent=65537, key_size=2048, backend=default_backend()
        )

    @staticmethod
    def _oaep() -> padding.OAEP:
        """Return the OAEP padding used for both encryption and decryption."""
        return padding.OAEP(
            mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None
        )

    def create_token(
        self,
        addr: NetworkAddress,
        original_destination_connection_id: bytes,
        retry_source_connection_id: bytes,
    ) -> bytes:
        """Build an encrypted token for the given address and connection IDs.

        :param addr: Remote address the token is bound to.
        :param original_destination_connection_id: Stored in the token and
            returned by :meth:`validate_token`.
        :param retry_source_connection_id: Stored in the token and returned
            by :meth:`validate_token`.
        :return: The token, encrypted with this handler's public key.
        """
        buf = Buffer(capacity=512)
        # Each field carries a one-byte length prefix.
        push_opaque(buf, 1, encode_address(addr))
        push_opaque(buf, 1, original_destination_connection_id)
        push_opaque(buf, 1, retry_source_connection_id)
        return self._key.public_key().encrypt(buf.data, self._oaep())

    def validate_token(self, addr: NetworkAddress, token: bytes) -> Tuple[bytes, bytes]:
        """Decrypt a token and check it was issued for ``addr``.

        :param addr: Remote address the token is expected to be bound to.
        :param token: Token previously returned by :meth:`create_token`.
        :return: ``(original_destination_connection_id,
            retry_source_connection_id)`` as stored in the token.
        :raises ValueError: If the address stored in the token differs from
            ``addr``. Errors raised while decrypting or parsing the token
            propagate unchanged.
        """
        buf = Buffer(data=self._key.decrypt(token, self._oaep()))
        # Fields are read back in the order create_token wrote them.
        encoded_addr = pull_opaque(buf, 1)
        original_destination_connection_id = pull_opaque(buf, 1)
        retry_source_connection_id = pull_opaque(buf, 1)
        if encoded_addr != encode_address(addr):
            raise ValueError("Remote address does not match.")
        return original_destination_connection_id, retry_source_connection_id