import math
from unittest import TestCase

from aioquic import tls
from aioquic.quic.packet import PACKET_TYPE_INITIAL, PACKET_TYPE_ONE_RTT
from aioquic.quic.packet_builder import QuicSentPacket
from aioquic.quic.rangeset import RangeSet
from aioquic.quic.recovery import (
    QuicPacketPacer,
    QuicPacketRecovery,
    QuicPacketSpace,
    QuicRttMonitor,
)


def send_probe():
    """No-op ``send_probe`` callback passed to ``QuicPacketRecovery``.

    None of the tests below inspect whether a probe was requested, so the
    callback intentionally does nothing.
    """


class QuicPacketPacerTest(TestCase):
    """Tests for ``QuicPacketPacer``."""

    def setUp(self):
        self.pacer = QuicPacketPacer()

    def test_no_measurement(self):
        """Before ``update_rate`` is called, sends are never delayed."""
        self.assertIsNone(self.pacer.next_send_time(now=0.0))
        self.pacer.update_after_send(now=0.0)

        self.assertIsNone(self.pacer.next_send_time(now=0.0))
        self.pacer.update_after_send(now=0.0)

    def test_with_measurement(self):
        """After ``update_rate``, a burst is allowed, then sends are spaced."""
        self.assertIsNone(self.pacer.next_send_time(now=0.0))
        self.pacer.update_after_send(now=0.0)

        self.pacer.update_rate(congestion_window=1280000, smoothed_rtt=0.05)
        self.assertEqual(self.pacer.bucket_max, 0.0008)
        self.assertEqual(self.pacer.bucket_time, 0.0)
        self.assertEqual(self.pacer.packet_time, 0.00005)

        # 16 packets (bucket_max / packet_time == 0.0008 / 0.00005 == 16)
        for _ in range(16):
            self.assertIsNone(self.pacer.next_send_time(now=1.0))
            self.pacer.update_after_send(now=1.0)
        self.assertAlmostEqual(self.pacer.next_send_time(now=1.0), 1.00005)

        # 2 packets
        for _ in range(2):
            self.assertIsNone(self.pacer.next_send_time(now=1.00005))
            self.pacer.update_after_send(now=1.00005)
        self.assertAlmostEqual(self.pacer.next_send_time(now=1.00005), 1.0001)

        # 1 packet
        self.assertIsNone(self.pacer.next_send_time(now=1.0001))
        self.pacer.update_after_send(now=1.0001)
        self.assertAlmostEqual(self.pacer.next_send_time(now=1.0001), 1.00015)

        # 2 packets
        for _ in range(2):
            self.assertIsNone(self.pacer.next_send_time(now=1.00015))
            self.pacer.update_after_send(now=1.00015)
        self.assertAlmostEqual(self.pacer.next_send_time(now=1.00015), 1.0002)


class QuicPacketRecoveryTest(TestCase):
    """Tests for ``QuicPacketRecovery`` bookkeeping across packet spaces."""

    def setUp(self):
        self.INITIAL_SPACE = QuicPacketSpace()
        self.HANDSHAKE_SPACE = QuicPacketSpace()
        self.ONE_RTT_SPACE = QuicPacketSpace()

        self.recovery = QuicPacketRecovery(
            initial_rtt=0.1,
            peer_completed_address_validation=True,
            send_probe=send_probe,
        )
        self.recovery.spaces = [
            self.INITIAL_SPACE,
            self.HANDSHAKE_SPACE,
            self.ONE_RTT_SPACE,
        ]

    def test_discard_space(self):
        """Discarding an empty space must not raise."""
        self.recovery.discard_space(self.INITIAL_SPACE)

    def test_discard_space_with_packet_in_flight(self):
        """Discarding a space removes its packets from in-flight accounting.

        RFC 9002 ("Upon Dropping Initial or Handshake Keys") requires the
        unacknowledged packets of a discarded space to be removed from
        bytes in flight.
        """
        packet = QuicSentPacket(
            epoch=tls.Epoch.INITIAL,
            in_flight=True,
            is_ack_eliciting=True,
            is_crypto_packet=True,
            packet_number=0,
            packet_type=PACKET_TYPE_INITIAL,
            sent_bytes=1280,
            sent_time=0.0,
        )
        space = self.INITIAL_SPACE

        # packet sent
        self.recovery.on_packet_sent(packet, space)
        self.assertEqual(self.recovery.bytes_in_flight, 1280)
        self.assertEqual(space.ack_eliciting_in_flight, 1)
        self.assertEqual(len(space.sent_packets), 1)

        # space discarded
        self.recovery.discard_space(space)
        self.assertEqual(self.recovery.bytes_in_flight, 0)
        self.assertEqual(space.ack_eliciting_in_flight, 0)
        self.assertEqual(len(space.sent_packets), 0)

    def test_on_ack_received_ack_eliciting(self):
        """Acking an ack-eliciting packet clears it and yields an RTT sample."""
        packet = QuicSentPacket(
            epoch=tls.Epoch.ONE_RTT,
            in_flight=True,
            is_ack_eliciting=True,
            is_crypto_packet=False,
            packet_number=0,
            packet_type=PACKET_TYPE_ONE_RTT,
            sent_bytes=1280,
            sent_time=0.0,
        )
        space = self.ONE_RTT_SPACE

        # packet sent
        self.recovery.on_packet_sent(packet, space)
        self.assertEqual(self.recovery.bytes_in_flight, 1280)
        self.assertEqual(space.ack_eliciting_in_flight, 1)
        self.assertEqual(len(space.sent_packets), 1)

        # packet ack'd
        self.recovery.on_ack_received(
            space, ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0
        )
        self.assertEqual(self.recovery.bytes_in_flight, 0)
        self.assertEqual(space.ack_eliciting_in_flight, 0)
        self.assertEqual(len(space.sent_packets), 0)

        # check RTT: sent at 0.0, acked at 10.0
        self.assertTrue(self.recovery._rtt_initialized)
        self.assertEqual(self.recovery._rtt_latest, 10.0)
        self.assertEqual(self.recovery._rtt_min, 10.0)
        self.assertEqual(self.recovery._rtt_smoothed, 10.0)

    def test_on_ack_received_non_ack_eliciting(self):
        """Acking a non-ack-eliciting packet clears it but leaves RTT unset."""
        packet = QuicSentPacket(
            epoch=tls.Epoch.ONE_RTT,
            in_flight=True,
            is_ack_eliciting=False,
            is_crypto_packet=False,
            packet_number=0,
            packet_type=PACKET_TYPE_ONE_RTT,
            sent_bytes=1280,
            sent_time=123.45,
        )
        space = self.ONE_RTT_SPACE

        # packet sent
        self.recovery.on_packet_sent(packet, space)
        self.assertEqual(self.recovery.bytes_in_flight, 1280)
        self.assertEqual(space.ack_eliciting_in_flight, 0)
        self.assertEqual(len(space.sent_packets), 1)

        # packet ack'd
        self.recovery.on_ack_received(
            space, ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0
        )
        self.assertEqual(self.recovery.bytes_in_flight, 0)
        self.assertEqual(space.ack_eliciting_in_flight, 0)
        self.assertEqual(len(space.sent_packets), 0)

        # check RTT: no sample was taken, so the initial values remain
        self.assertFalse(self.recovery._rtt_initialized)
        self.assertEqual(self.recovery._rtt_latest, 0.0)
        self.assertEqual(self.recovery._rtt_min, math.inf)
        self.assertEqual(self.recovery._rtt_smoothed, 0.0)

    def test_on_packet_lost_crypto(self):
        """An unacknowledged crypto packet is declared lost by ``_detect_loss``."""
        packet = QuicSentPacket(
            epoch=tls.Epoch.INITIAL,
            in_flight=True,
            is_ack_eliciting=True,
            is_crypto_packet=True,
            packet_number=0,
            packet_type=PACKET_TYPE_INITIAL,
            sent_bytes=1280,
            sent_time=0.0,
        )
        space = self.INITIAL_SPACE

        self.recovery.on_packet_sent(packet, space)
        self.assertEqual(self.recovery.bytes_in_flight, 1280)
        self.assertEqual(space.ack_eliciting_in_flight, 1)
        self.assertEqual(len(space.sent_packets), 1)

        self.recovery._detect_loss(space, now=1.0)
        self.assertEqual(self.recovery.bytes_in_flight, 0)
        self.assertEqual(space.ack_eliciting_in_flight, 0)
        self.assertEqual(len(space.sent_packets), 0)


class QuicRttMonitorTest(TestCase):
    """Tests for ``QuicRttMonitor`` sample collection and increase detection."""

    def test_monitor(self):
        monitor = QuicRttMonitor()

        self.assertFalse(monitor.is_rtt_increasing(rtt=10, now=1000))
        self.assertEqual(monitor._samples, [10, 0.0, 0.0, 0.0, 0.0])
        self.assertFalse(monitor._ready)

        # not taken into account (same `now` as the previous sample)
        self.assertFalse(monitor.is_rtt_increasing(rtt=11, now=1000))
        self.assertEqual(monitor._samples, [10, 0.0, 0.0, 0.0, 0.0])
        self.assertFalse(monitor._ready)

        self.assertFalse(monitor.is_rtt_increasing(rtt=11, now=1001))
        self.assertEqual(monitor._samples, [10, 11, 0.0, 0.0, 0.0])
        self.assertFalse(monitor._ready)

        self.assertFalse(monitor.is_rtt_increasing(rtt=12, now=1002))
        self.assertEqual(monitor._samples, [10, 11, 12, 0.0, 0.0])
        self.assertFalse(monitor._ready)

        self.assertFalse(monitor.is_rtt_increasing(rtt=13, now=1003))
        self.assertEqual(monitor._samples, [10, 11, 12, 13, 0.0])
        self.assertFalse(monitor._ready)

        # we now have enough samples
        self.assertFalse(monitor.is_rtt_increasing(rtt=14, now=1004))
        self.assertEqual(monitor._samples, [10, 11, 12, 13, 14])
        self.assertTrue(monitor._ready)

        self.assertFalse(monitor.is_rtt_increasing(rtt=20, now=1005))
        self.assertEqual(monitor._increases, 0)

        self.assertFalse(monitor.is_rtt_increasing(rtt=30, now=1006))
        self.assertEqual(monitor._increases, 0)

        self.assertFalse(monitor.is_rtt_increasing(rtt=40, now=1007))
        self.assertEqual(monitor._increases, 0)

        self.assertFalse(monitor.is_rtt_increasing(rtt=50, now=1008))
        self.assertEqual(monitor._increases, 0)

        self.assertFalse(monitor.is_rtt_increasing(rtt=60, now=1009))
        self.assertEqual(monitor._increases, 1)

        self.assertFalse(monitor.is_rtt_increasing(rtt=70, now=1010))
        self.assertEqual(monitor._increases, 2)

        self.assertFalse(monitor.is_rtt_increasing(rtt=80, now=1011))
        self.assertEqual(monitor._increases, 3)

        self.assertFalse(monitor.is_rtt_increasing(rtt=90, now=1012))
        self.assertEqual(monitor._increases, 4)

        # in this sequence the monitor first reports an increase at count 5
        self.assertTrue(monitor.is_rtt_increasing(rtt=100, now=1013))
        self.assertEqual(monitor._increases, 5)