import asyncio
import time #CARLO
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast
from ..quic import events
from ..quic.connection import NetworkAddress, QuicConnection
QuicConnectionIdHandler = Callable[[bytes], None]
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
class QuicConnectionProtocol(asyncio.DatagramProtocol):
def __init__(
self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
):
loop = asyncio.get_event_loop()
self._closed = asyncio.Event()
self._connected = False
self._connected_waiter: Optional[asyncio.Future[None]] = None
self._loop = loop
self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
self._quic = quic
self._stream_readers: Dict[int, asyncio.StreamReader] = {}
self._timer: Optional[asyncio.TimerHandle] = None
self._timer_at: Optional[float] = None
self._transmit_task: Optional[asyncio.Handle] = None
self._transport: Optional[asyncio.DatagramTransport] = None
# callbacks
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
self._connection_terminated_handler: Callable[[], None] = lambda: None
if stream_handler is not None:
self._stream_handler = stream_handler
else:
self._stream_handler = lambda r, w: None
def change_connection_id(self) -> None:
"""
Change the connection ID used to communicate with the peer.
The previous connection ID will be retired.
"""
self._quic.change_connection_id()
self.transmit()
def close(self) -> None:
"""
Close the connection.
"""
self._quic.close()
self.transmit()
def connect(self, addr: NetworkAddress) -> None:
"""
Initiate the TLS handshake.
This method can only be called for clients and a single time.
"""
self._quic.connect(addr, now=self._loop.time())
self.transmit()
async def create_stream(
self, is_unidirectional: bool = False
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""
Create a QUIC stream and return a pair of (reader, writer) objects.
The returned reader and writer objects are instances of :class:`asyncio.StreamReader`
and :class:`asyncio.StreamWriter` classes.
"""
stream_id = self._quic.get_next_available_stream_id(
is_unidirectional=is_unidirectional
)
return self._create_stream(stream_id)
def request_key_update(self) -> None:
"""
Request an update of the encryption keys.
"""
self._quic.request_key_update()
self.transmit()
async def ping(self) -> None:
"""
Ping the peer and wait for the response.
"""
waiter = self._loop.create_future()
uid = id(waiter)
self._ping_waiters[uid] = waiter
self._quic.send_ping(uid)
self.transmit()
await asyncio.shield(waiter)
def transmit(self, counter = 0, hmstrategy = 0, n_request_migration = 0, interval_migration = 0) -> None: #DEBUG2 TEST* DEBUG V2* PERF EV AUTOMATION* DEBUG V3*
"""
Send pending datagrams to the peer and arm the timer if needed.
"""
self._transmit_task = None
#now = time.time() #CARLO
#print("Inside protocol.transmit at: " + str(now)) #CARLO
# send datagrams
for data, addr in self._quic.datagrams_to_send(counter, hmstrategy, n_request_migration, interval_migration, now=self._loop.time()): #DEBUG2 TEST* DEBUG V2* PERF EV AUTOMATION* DEBUG V3*
self._transport.sendto(data, addr)
# re-arm timer
timer_at = self._quic.get_timer()
if self._timer is not None and self._timer_at != timer_at:
self._timer.cancel()
self._timer = None
if self._timer is None and timer_at is not None:
self._timer = self._loop.call_at(timer_at, self._handle_timer)
self._timer_at = timer_at
async def wait_closed(self) -> None:
"""
Wait for the connection to be closed.
"""
await self._closed.wait()
async def wait_connected(self) -> None:
"""
Wait for the TLS handshake to complete.
"""
assert self._connected_waiter is None, "already awaiting connected"
if not self._connected:
self._connected_waiter = self._loop.create_future()
await asyncio.shield(self._connected_waiter)
# asyncio.Transport
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
self._process_events()
self.transmit()
# overridable
def quic_event_received(self, event: events.QuicEvent) -> None:
"""
Called when a QUIC event is received.
Reimplement this in your subclass to handle the events.
"""
# FIXME: move this to a subclass
if isinstance(event, events.ConnectionTerminated):
for reader in self._stream_readers.values():
reader.feed_eof()
elif isinstance(event, events.StreamDataReceived):
reader = self._stream_readers.get(event.stream_id, None)
if reader is None:
reader, writer = self._create_stream(event.stream_id)
self._stream_handler(reader, writer)
reader.feed_data(event.data)
if event.end_stream:
reader.feed_eof()
# private
def _create_stream(
self, stream_id: int
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
adapter = QuicStreamAdapter(self, stream_id)
reader = asyncio.StreamReader()
writer = asyncio.StreamWriter(adapter, None, reader, self._loop)
self._stream_readers[stream_id] = reader
return reader, writer
def _handle_timer(self) -> None:
now = max(self._timer_at, self._loop.time())
self._timer = None
self._timer_at = None
self._quic.handle_timer(now=now)
self._process_events()
self.transmit()
def _process_events(self) -> None:
event = self._quic.next_event()
while event is not None:
if isinstance(event, events.ConnectionIdIssued):
self._connection_id_issued_handler(event.connection_id)
elif isinstance(event, events.ConnectionIdRetired):
self._connection_id_retired_handler(event.connection_id)
elif isinstance(event, events.ConnectionTerminated):
self._connection_terminated_handler()
# abort connection waiter
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected_waiter = None
waiter.set_exception(ConnectionError)
# abort ping waiters
for waiter in self._ping_waiters.values():
waiter.set_exception(ConnectionError)
self._ping_waiters.clear()
self._closed.set()
elif isinstance(event, events.HandshakeCompleted):
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected = True
self._connected_waiter = None
waiter.set_result(None)
elif isinstance(event, events.PingAcknowledged):
waiter = self._ping_waiters.pop(event.uid, None)
if waiter is not None:
waiter.set_result(None)
self.quic_event_received(event)
event = self._quic.next_event()
def _transmit_soon(self) -> None:
if self._transmit_task is None:
self._transmit_task = self._loop.call_soon(self.transmit)
class QuicStreamAdapter(asyncio.Transport):
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
self.protocol = protocol
self.stream_id = stream_id
def can_write_eof(self) -> bool:
return True
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""
Get information about the underlying QUIC stream.
"""
if name == "stream_id":
return self.stream_id
def write(self, data):
self.protocol._quic.send_stream_data(self.stream_id, data)
self.protocol._transmit_soon()
def write_eof(self):
self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
self.protocol._transmit_soon()