import binascii
from typing import Callable, Optional, Tuple

from .._crypto import AEAD, CryptoError, HeaderProtection
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
from .packet import QuicProtocolVersion, decode_packet_number, is_long_header

# Maps each supported TLS cipher suite to the cipher names used for
# (header protection, AEAD payload protection).
CIPHER_SUITES = {
    CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
    CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
    CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
}
# Cipher suite used for Initial keys (see CryptoPair.setup_initial).
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
# Salts passed to HKDF-Extract when deriving Initial secrets; which one is used
# depends on the negotiated version (see CryptoPair.setup_initial).
INITIAL_SALT_DRAFT_23 = binascii.unhexlify("c3eef712c72ebb5a11a7d2432bb46365bef9f502")
INITIAL_SALT_DRAFT_29 = binascii.unhexlify("afbfec289993d24c9e9786f19c6111e04390a899")
# Not referenced in this module; presumably the header-protection sample length
# used by other modules — confirm against callers.
SAMPLE_SIZE = 16


# Signature of setup/teardown callbacks: receives a short trigger string.
Callback = Callable[[str], None]


def NoCallback(trigger: str) -> None:
    """Default callback that ignores the trigger."""
    pass


class KeyUnavailableError(CryptoError):
    """Raised when a packet must be decrypted but no key is installed."""

    pass


def derive_key_iv_hp(
    cipher_suite: CipherSuite, secret: bytes
) -> Tuple[bytes, bytes, bytes]:
    """
    Derive packet protection material from a traffic secret.

    Uses HKDF-Expand-Label with the labels "quic key", "quic iv" and
    "quic hp" and the hash belonging to `cipher_suite`.

    :param cipher_suite: the negotiated cipher suite.
    :param secret: the traffic secret to expand.
    :return: a ``(key, iv, hp)`` tuple. ``key`` and ``hp`` are 32 bytes for
        AES-256-GCM and ChaCha20-Poly1305 and 16 bytes otherwise; ``iv`` is
        always 12 bytes.
    """
    algorithm = cipher_suite_hash(cipher_suite)
    if cipher_suite in (
        CipherSuite.AES_256_GCM_SHA384,
        CipherSuite.CHACHA20_POLY1305_SHA256,
    ):
        key_size = 32
    else:
        key_size = 16
    return (
        hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size),
        hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12),
        hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size),
    )


class CryptoContext:
    """
    Packet protection state for a single direction (receive or send).

    Holds the AEAD used for payload protection, the header protection
    cipher, the traffic secret they were derived from and the current key
    phase bit. Optional callbacks are invoked whenever keys are installed
    (`setup_cb`) or discarded (`teardown_cb`).
    """

    def __init__(
        self,
        key_phase: int = 0,
        setup_cb: Callback = NoCallback,
        teardown_cb: Callback = NoCallback,
    ) -> None:
        self.aead: Optional[AEAD] = None
        self.cipher_suite: Optional[CipherSuite] = None
        self.hp: Optional[HeaderProtection] = None
        self.key_phase = key_phase
        self.secret: Optional[bytes] = None
        self.version: Optional[int] = None

        self._setup_cb = setup_cb
        self._teardown_cb = teardown_cb

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> Tuple[bytes, bytes, int, bool]:
        """
        Remove header and payload protection from `packet`.

        :param packet: the protected packet.
        :param encrypted_offset: offset at which header protection starts,
            passed straight to `HeaderProtection.remove`.
        :param expected_packet_number: used to reconstruct the full packet
            number from its truncated on-the-wire encoding.
        :return: ``(plain_header, payload, packet_number, key_phase_changed)``.
            ``key_phase_changed`` is True when a short-header packet carried a
            key phase bit different from ours and was decrypted with the keys
            of the next key phase.
        :raises KeyUnavailableError: if no key has been installed.
        """
        if self.aead is None:
            raise KeyUnavailableError("Decryption key is not available")

        # header protection
        plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
        first_byte = plain_header[0]

        # packet number: the two low bits of the first byte encode its
        # length minus one
        pn_length = (first_byte & 0x03) + 1
        packet_number = decode_packet_number(
            packet_number, pn_length * 8, expected_packet_number
        )

        # detect key phase change: only short-header packets carry the key
        # phase bit (0x04); a mismatch means the peer moved to the next keys
        crypto = self
        if not is_long_header(first_byte):
            key_phase = (first_byte & 4) >> 2
            if key_phase != self.key_phase:
                crypto = next_key_phase(self)

        # payload protection: the protected header is the associated data
        payload = crypto.aead.decrypt(
            packet[len(plain_header) :], plain_header, packet_number
        )

        return plain_header, payload, packet_number, crypto is not self

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """
        Apply payload protection, then header protection.

        Fails with AssertionError (when assertions are enabled) if no key
        has been installed.
        """
        assert self.is_valid(), "Encryption key is not available"

        # payload protection
        protected_payload = self.aead.encrypt(
            plain_payload, plain_header, packet_number
        )

        # header protection
        return self.hp.apply(plain_header, protected_payload)

    def is_valid(self) -> bool:
        """Return True if a payload protection key is installed."""
        return self.aead is not None

    def setup(self, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
        """
        Install keys derived from `secret` for `cipher_suite`.

        Invokes the setup callback with the trigger ``"tls"``.
        """
        hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]

        key, iv, hp = derive_key_iv_hp(cipher_suite, secret)
        self.aead = AEAD(aead_cipher_name, key, iv)
        self.cipher_suite = cipher_suite
        self.hp = HeaderProtection(hp_cipher_name, hp)
        self.secret = secret
        self.version = version

        # trigger callback
        self._setup_cb("tls")

    def teardown(self) -> None:
        """
        Discard the installed keys and secret.

        Invokes the teardown callback with the trigger ``"tls"``. The key
        phase and version are left unchanged.
        """
        self.aead = None
        self.cipher_suite = None
        self.hp = None
        self.secret = None

        # trigger callback
        self._teardown_cb("tls")


def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
    """
    Adopt the payload protection key, key phase and secret of `crypto`.

    The header protection cipher is deliberately not copied: only the AEAD,
    key phase and secret are taken from `crypto`. Invokes the setup callback
    of `self` with `trigger`.
    """
    self.aead = crypto.aead
    self.key_phase = crypto.key_phase
    self.secret = crypto.secret

    # trigger callback
    self._setup_cb(trigger)


def next_key_phase(self: CryptoContext) -> CryptoContext:
    """
    Build a new context for the next key phase of `self`.

    The new secret is HKDF-Expand-Label(secret, "quic ku") and the key phase
    bit is flipped. The returned context has no callbacks; `self` is not
    modified.
    """
    algorithm = cipher_suite_hash(self.cipher_suite)

    crypto = CryptoContext(key_phase=int(not self.key_phase))
    crypto.setup(
        cipher_suite=self.cipher_suite,
        secret=hkdf_expand_label(
            algorithm, self.secret, b"quic ku", b"", algorithm.digest_size
        ),
        version=self.version,
    )
    return crypto


class CryptoPair:
    """
    Receive and send crypto contexts that are keyed and updated together.

    A key update, whether requested locally via `update_key` or detected on
    a received packet, rotates both directions to the next key phase.
    """

    def __init__(
        self,
        recv_setup_cb: Callback = NoCallback,
        recv_teardown_cb: Callback = NoCallback,
        send_setup_cb: Callback = NoCallback,
        send_teardown_cb: Callback = NoCallback,
    ) -> None:
        self.aead_tag_size = 16
        self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
        self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
        self._update_key_requested = False

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> Tuple[bytes, bytes, int]:
        """
        Decrypt `packet` with the receive context.

        If the packet was protected with the next key phase, both directions
        are rotated with the trigger ``"remote_update"``.

        :return: ``(plain_header, payload, packet_number)``.
        :raises KeyUnavailableError: if no receive key is installed.
        """
        plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
            packet, encrypted_offset, expected_packet_number
        )
        if update_key:
            self._update_key("remote_update")
        return plain_header, payload, packet_number

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """
        Encrypt a packet with the send context.

        A pending local key update is applied first, with the trigger
        ``"local_update"``.
        """
        if self._update_key_requested:
            self._update_key("local_update")

        return self.send.encrypt_packet(plain_header, plain_payload, packet_number)

    def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
        """
        Install Initial keys derived from the connection ID `cid`.

        The client receives with the "server in" secret and sends with the
        "client in" secret; the server does the opposite.
        """
        if is_client:
            recv_label, send_label = b"server in", b"client in"
        else:
            recv_label, send_label = b"client in", b"server in"

        # NOTE(review): this is a numeric comparison, so any version value
        # below DRAFT_29 gets the draft-23 salt — confirm this holds for every
        # version QuicProtocolVersion defines.
        if version < QuicProtocolVersion.DRAFT_29:
            initial_salt = INITIAL_SALT_DRAFT_23
        else:
            initial_salt = INITIAL_SALT_DRAFT_29

        algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
        initial_secret = hkdf_extract(algorithm, initial_salt, cid)
        self.recv.setup(
            cipher_suite=INITIAL_CIPHER_SUITE,
            secret=hkdf_expand_label(
                algorithm, initial_secret, recv_label, b"", algorithm.digest_size
            ),
            version=version,
        )
        self.send.setup(
            cipher_suite=INITIAL_CIPHER_SUITE,
            secret=hkdf_expand_label(
                algorithm, initial_secret, send_label, b"", algorithm.digest_size
            ),
            version=version,
        )

    def teardown(self) -> None:
        """Discard the keys of both directions."""
        self.recv.teardown()
        self.send.teardown()

    def update_key(self) -> None:
        """Request a local key update, applied on the next `encrypt_packet`."""
        self._update_key_requested = True

    @property
    def key_phase(self) -> int:
        """
        Current key phase bit.

        This is the receive context's phase, flipped if a local key update
        is pending.
        """
        if self._update_key_requested:
            return int(not self.recv.key_phase)
        else:
            return self.recv.key_phase

    def _update_key(self, trigger: str) -> None:
        """
        Rotate both directions to their next key phase.

        Each context's setup callback is invoked with `trigger`, and any
        pending local update request is cleared.
        """
        apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger)
        apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger)
        self._update_key_requested = False