import math
from typing import Callable, Dict, Iterable, List, Optional
from .logger import QuicLoggerTrace
from .packet_builder import QuicDeliveryState, QuicSentPacket
from .rangeset import RangeSet
# loss detection
K_PACKET_THRESHOLD = 3  # packet reordering threshold (RFC 9002)
K_GRANULARITY = 0.001  # timer granularity, in seconds
K_TIME_THRESHOLD = 9 / 8  # time reordering threshold (RFC 9002)
K_MICRO_SECOND = 0.000001
K_SECOND = 1.0
# congestion control
K_MAX_DATAGRAM_SIZE = 1280  # bytes
K_INITIAL_WINDOW = 10 * K_MAX_DATAGRAM_SIZE
K_MINIMUM_WINDOW = 2 * K_MAX_DATAGRAM_SIZE
K_LOSS_REDUCTION_FACTOR = 0.5
class QuicPacketSpace:
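    """
    Per-packet-number-space state: acknowledgements owed to the peer and
    sent packets tracked for ACK processing and loss detection.
    """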
def __init__(self) -> None:
self.ack_at: Optional[float] = None
self.ack_queue = RangeSet()
self.discarded = False
self.expected_packet_number = 0
self.largest_received_packet = -1
self.largest_received_time: Optional[float] = None
# sent packets and loss
self.ack_eliciting_in_flight = 0
self.largest_acked_packet = 0
self.loss_time: Optional[float] = None
self.sent_packets: Dict[int, QuicSentPacket] = {}
class QuicPacketPacer:
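    """
    Token-bucket pacer which spreads datagrams over the round-trip time
    instead of releasing a full congestion window in one burst.
    """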
def __init__(self) -> None:
self.bucket_max: float = 0.0
self.bucket_time: float = 0.0
self.evaluation_time: float = 0.0
self.packet_time: Optional[float] = None
    def next_send_time(self, now: float) -> Optional[float]:
if self.packet_time is not None:
self.update_bucket(now=now)
if self.bucket_time <= 0:
return now + self.packet_time
return None
def update_after_send(self, now: float) -> None:
if self.packet_time is not None:
self.update_bucket(now=now)
if self.bucket_time < self.packet_time:
self.bucket_time = 0.0
else:
self.bucket_time -= self.packet_time
def update_bucket(self, now: float) -> None:
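        # refill the bucket with the time elapsed since the last evaluation,
        # capped at bucket_max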
if now > self.evaluation_time:
self.bucket_time = min(
self.bucket_time + (now - self.evaluation_time), self.bucket_max
)
self.evaluation_time = now
def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None:
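        # pace roughly one congestion window per smoothed RTT; packet_time is
        # the interval between two full-size datagrams at that rate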
pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND)
self.packet_time = max(
K_MICRO_SECOND, min(K_MAX_DATAGRAM_SIZE / pacing_rate, K_SECOND)
)
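        # allow a burst worth a quarter of the congestion window, clamped to
        # between 2 and 16 datagrams, expressed as sending credit in seconds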
self.bucket_max = (
max(
2 * K_MAX_DATAGRAM_SIZE,
min(congestion_window // 4, 16 * K_MAX_DATAGRAM_SIZE),
)
/ pacing_rate
)
if self.bucket_time > self.bucket_max:
self.bucket_time = self.bucket_max
class QuicCongestionControl:
"""
New Reno congestion control.
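
    The window grows by the number of acknowledged bytes in slow start, by
    one datagram per full window of acknowledged bytes in congestion
    avoidance, and is halved (down to K_MINIMUM_WINDOW) on a new congestion
    event.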
"""
def __init__(self) -> None:
self.bytes_in_flight = 0
self.congestion_window = K_INITIAL_WINDOW
self._congestion_recovery_start_time = 0.0
self._congestion_stash = 0
self._rtt_monitor = QuicRttMonitor()
self.ssthresh: Optional[int] = None
def on_packet_acked(self, packet: QuicSentPacket) -> None:
self.bytes_in_flight -= packet.sent_bytes
# don't increase window in congestion recovery
if packet.sent_time <= self._congestion_recovery_start_time:
return
if self.ssthresh is None or self.congestion_window < self.ssthresh:
# slow start
self.congestion_window += packet.sent_bytes
else:
            # congestion avoidance: grow by one datagram per congestion
            # window's worth of acknowledged bytes
self._congestion_stash += packet.sent_bytes
count = self._congestion_stash // self.congestion_window
if count:
self._congestion_stash -= count * self.congestion_window
self.congestion_window += count * K_MAX_DATAGRAM_SIZE
def on_packet_sent(self, packet: QuicSentPacket) -> None:
self.bytes_in_flight += packet.sent_bytes
def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None:
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None:
lost_largest_time = 0.0
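        # packets are supplied in ascending packet number order, so the last
        # entry carries the most recent send time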
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
lost_largest_time = packet.sent_time
# start a new congestion event if packet was sent after the
# start of the previous congestion recovery period.
if lost_largest_time > self._congestion_recovery_start_time:
self._congestion_recovery_start_time = now
self.congestion_window = max(
int(self.congestion_window * K_LOSS_REDUCTION_FACTOR), K_MINIMUM_WINDOW
)
self.ssthresh = self.congestion_window
        # TODO: collapse congestion window if persistent congestion
def on_rtt_measurement(self, latest_rtt: float, now: float) -> None:
# check whether we should exit slow start
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
latest_rtt, now
):
self.ssthresh = self.congestion_window
class QuicPacketRecovery:
"""
Packet loss and congestion controller.
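
    Typical driving sequence (a sketch of intended use, not prescribed by
    this module): register outgoing packets with on_packet_sent(), feed
    received ACK frames through on_ack_received(), arm a timer based on
    get_loss_detection_time() and call on_loss_detection_timeout() when it
    fires.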
"""
def __init__(
self,
initial_rtt: float,
peer_completed_address_validation: bool,
send_probe: Callable[[], None],
quic_logger: Optional[QuicLoggerTrace] = None,
) -> None:
self.max_ack_delay = 0.025
self.peer_completed_address_validation = peer_completed_address_validation
self.spaces: List[QuicPacketSpace] = []
# callbacks
self._quic_logger = quic_logger
self._send_probe = send_probe
# loss detection
self._pto_count = 0
self._rtt_initial = initial_rtt
self._rtt_initialized = False
self._rtt_latest = 0.0
self._rtt_min = math.inf
self._rtt_smoothed = 0.0
self._rtt_variance = 0.0
self._time_of_last_sent_ack_eliciting_packet = 0.0
# congestion control
self._cc = QuicCongestionControl()
self._pacer = QuicPacketPacer()
@property
def bytes_in_flight(self) -> int:
return self._cc.bytes_in_flight
@property
def congestion_window(self) -> int:
return self._cc.congestion_window
def discard_space(self, space: QuicPacketSpace) -> None:
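        """
        Stop tracking a packet-number space, releasing its in-flight packets
        from the congestion controller.
        """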
assert space in self.spaces
self._cc.on_packets_expired(
filter(lambda x: x.in_flight, space.sent_packets.values())
)
space.sent_packets.clear()
space.ack_at = None
space.ack_eliciting_in_flight = 0
space.loss_time = None
# reset PTO count
self._pto_count = 0
if self._quic_logger is not None:
self._log_metrics_updated()
    def get_loss_detection_time(self) -> Optional[float]:
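        """
        Return the time at which the loss detection timer should fire,
        or None if no timer needs to be armed.
        """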
# loss timer
loss_space = self._get_loss_space()
if loss_space is not None:
return loss_space.loss_time
# packet timer
if (
not self.peer_completed_address_validation
or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
):
timeout = self.get_probe_timeout() * (2 ** self._pto_count)
return self._time_of_last_sent_ack_eliciting_packet + timeout
return None
def get_probe_timeout(self) -> float:
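        """
        Return the probe timeout (PTO): smoothed RTT, plus four times the RTT
        variance (at least K_GRANULARITY), plus the maximum ACK delay.
        """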
if not self._rtt_initialized:
return 2 * self._rtt_initial
return (
self._rtt_smoothed
+ max(4 * self._rtt_variance, K_GRANULARITY)
+ self.max_ack_delay
)
def on_ack_received(
self,
space: QuicPacketSpace,
ack_rangeset: RangeSet,
ack_delay: float,
now: float,
) -> None:
"""
Update metrics as the result of an ACK being received.
"""
is_ack_eliciting = False
largest_acked = ack_rangeset.bounds().stop - 1
largest_newly_acked = None
largest_sent_time = None
if largest_acked > space.largest_acked_packet:
space.largest_acked_packet = largest_acked
for packet_number in sorted(space.sent_packets.keys()):
if packet_number > largest_acked:
break
if packet_number in ack_rangeset:
# remove packet and update counters
packet = space.sent_packets.pop(packet_number)
if packet.is_ack_eliciting:
is_ack_eliciting = True
space.ack_eliciting_in_flight -= 1
if packet.in_flight:
self._cc.on_packet_acked(packet)
largest_newly_acked = packet_number
largest_sent_time = packet.sent_time
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.ACKED, *args)
# nothing to do if there are no newly acked packets
if largest_newly_acked is None:
return
if largest_acked == largest_newly_acked and is_ack_eliciting:
latest_rtt = now - largest_sent_time
log_rtt = True
# limit ACK delay to max_ack_delay
ack_delay = min(ack_delay, self.max_ack_delay)
# update RTT estimate, which cannot be < 1 ms
self._rtt_latest = max(latest_rtt, 0.001)
if self._rtt_latest < self._rtt_min:
self._rtt_min = self._rtt_latest
if self._rtt_latest > self._rtt_min + ack_delay:
self._rtt_latest -= ack_delay
if not self._rtt_initialized:
self._rtt_initialized = True
self._rtt_variance = latest_rtt / 2
self._rtt_smoothed = latest_rtt
else:
self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
self._rtt_min - self._rtt_latest
)
self._rtt_smoothed = (
7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
)
# inform congestion controller
self._cc.on_rtt_measurement(latest_rtt, now=now)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
else:
log_rtt = False
self._detect_loss(space, now=now)
# reset PTO count
self._pto_count = 0
if self._quic_logger is not None:
self._log_metrics_updated(log_rtt=log_rtt)
def on_loss_detection_timeout(self, now: float) -> None:
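        """
        Handle the loss detection timer firing: declare packets lost if a
        loss time was set, otherwise bump the PTO count, reschedule any
        in-flight crypto data and request that a probe be sent.
        """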
loss_space = self._get_loss_space()
if loss_space is not None:
self._detect_loss(loss_space, now=now)
else:
self._pto_count += 1
# reschedule some data
for space in self.spaces:
self._on_packets_lost(
tuple(
filter(
lambda i: i.is_crypto_packet, space.sent_packets.values()
)
),
space=space,
now=now,
)
self._send_probe()
def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
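        """
        Record a freshly sent packet with its packet-number space and, if it
        counts towards bytes in flight, with the congestion controller.
        """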
space.sent_packets[packet.packet_number] = packet
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight += 1
if packet.in_flight:
if packet.is_ack_eliciting:
self._time_of_last_sent_ack_eliciting_packet = packet.sent_time
# add packet to bytes in flight
self._cc.on_packet_sent(packet)
if self._quic_logger is not None:
self._log_metrics_updated()
def _detect_loss(self, space: QuicPacketSpace, now: float) -> None:
"""
Check whether any packets should be declared lost.
"""
loss_delay = K_TIME_THRESHOLD * (
max(self._rtt_latest, self._rtt_smoothed)
if self._rtt_initialized
else self._rtt_initial
)
packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
time_threshold = now - loss_delay
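        # a packet is lost if it was sent K_PACKET_THRESHOLD or more packets
        # before the largest acknowledged packet, or more than loss_delay ago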
lost_packets = []
space.loss_time = None
for packet_number, packet in space.sent_packets.items():
if packet_number > space.largest_acked_packet:
break
if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
lost_packets.append(packet)
else:
packet_loss_time = packet.sent_time + loss_delay
if space.loss_time is None or space.loss_time > packet_loss_time:
space.loss_time = packet_loss_time
self._on_packets_lost(lost_packets, space=space, now=now)
def _get_loss_space(self) -> Optional[QuicPacketSpace]:
loss_space = None
for space in self.spaces:
if space.loss_time is not None and (
loss_space is None or space.loss_time < loss_space.loss_time
):
loss_space = space
return loss_space
    def _log_metrics_updated(self, log_rtt: bool = False) -> None:
data = {
"bytes_in_flight": self._cc.bytes_in_flight,
"cwnd": self._cc.congestion_window,
}
if self._cc.ssthresh is not None:
data["ssthresh"] = self._cc.ssthresh
if log_rtt:
data.update(
{
"latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
"min_rtt": self._quic_logger.encode_time(self._rtt_min),
"smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
"rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
}
)
self._quic_logger.log_event(
category="recovery", event="metrics_updated", data=data
)
def _on_packets_lost(
self, packets: Iterable[QuicSentPacket], space: QuicPacketSpace, now: float
) -> None:
lost_packets_cc = []
for packet in packets:
del space.sent_packets[packet.packet_number]
if packet.in_flight:
lost_packets_cc.append(packet)
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight -= 1
if self._quic_logger is not None:
self._quic_logger.log_event(
category="recovery",
event="packet_lost",
data={
"type": self._quic_logger.packet_type(packet.packet_type),
"packet_number": str(packet.packet_number),
},
)
self._log_metrics_updated()
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.LOST, *args)
# inform congestion controller
if lost_packets_cc:
self._cc.on_packets_lost(lost_packets_cc, now=now)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
if self._quic_logger is not None:
self._log_metrics_updated()
class QuicRttMonitor:
"""
Roundtrip time monitor for HyStart.
"""
def __init__(self) -> None:
self._increases = 0
self._last_time = None
self._ready = False
self._size = 5
self._filtered_min: Optional[float] = None
self._sample_idx = 0
self._sample_max: Optional[float] = None
self._sample_min: Optional[float] = None
self._sample_time = 0.0
        self._samples = [0.0 for _ in range(self._size)]
def add_rtt(self, rtt: float) -> None:
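        # store the sample in a circular buffer; once the buffer is full,
        # track the min / max over the most recent samples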
self._samples[self._sample_idx] = rtt
self._sample_idx += 1
if self._sample_idx >= self._size:
self._sample_idx = 0
self._ready = True
if self._ready:
self._sample_max = self._samples[0]
self._sample_min = self._samples[0]
for sample in self._samples[1:]:
if sample < self._sample_min:
self._sample_min = sample
elif sample > self._sample_max:
self._sample_max = sample
def is_rtt_increasing(self, rtt: float, now: float) -> bool:
if now > self._sample_time + K_GRANULARITY:
self.add_rtt(rtt)
self._sample_time = now
if self._ready:
if self._filtered_min is None or self._filtered_min > self._sample_max:
self._filtered_min = self._sample_max
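                # count an increase when the recent minimum exceeds the lowest
                # observed maximum by 25% or more; report True once `_size`
                # such increases have accumulated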
delta = self._sample_min - self._filtered_min
if delta * 4 >= self._filtered_min:
self._increases += 1
if self._increases >= self._size:
return True
elif delta > 0:
self._increases = 0
return False