# SDSM-for-SDI/src/aioquic/asyncio/protocol.py
import asyncio
import time #CARLO
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast

from ..quic import events
from ..quic.connection import NetworkAddress, QuicConnection

# Callback receiving a connection ID; the same signature serves both the
# "issued" and the "retired" notifications of QuicConnectionProtocol.
QuicConnectionIdHandler = Callable[[bytes], None]

# Callback receiving the (reader, writer) pair of a stream that first shows
# up in data received from the peer.
QuicStreamHandler = Callable[
    [asyncio.StreamReader, asyncio.StreamWriter],
    None,
]


class QuicConnectionProtocol(asyncio.DatagramProtocol):
    """
    asyncio datagram protocol that drives a :class:`QuicConnection`.

    Received datagrams are fed into the QUIC state machine, the events it
    produces are dispatched (see :meth:`quic_event_received`), and any
    datagrams it wants to send are written to the datagram transport.
    """

    def __init__(
        self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
    ):
        """
        :param quic: the QUIC connection state machine driven by this protocol.
        :param stream_handler: optional callback invoked with a
            ``(reader, writer)`` pair when data arrives on a stream that has
            no reader yet (i.e. one not created through :meth:`create_stream`).
        """
        loop = asyncio.get_event_loop()

        self._closed = asyncio.Event()
        # True once a HandshakeCompleted event has been processed.
        self._connected = False
        self._connected_waiter: Optional[asyncio.Future[None]] = None
        self._loop = loop
        # Pending ping() futures, keyed by the uid passed to send_ping().
        self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
        self._quic = quic
        self._stream_readers: Dict[int, asyncio.StreamReader] = {}
        self._timer: Optional[asyncio.TimerHandle] = None
        self._timer_at: Optional[float] = None
        self._transmit_task: Optional[asyncio.Handle] = None
        self._transport: Optional[asyncio.DatagramTransport] = None

        # callbacks
        self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
        self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
        self._connection_terminated_handler: Callable[[], None] = lambda: None
        if stream_handler is not None:
            self._stream_handler = stream_handler
        else:
            self._stream_handler = lambda r, w: None

    def change_connection_id(self) -> None:
        """
        Change the connection ID used to communicate with the peer.

        The previous connection ID will be retired.
        """
        self._quic.change_connection_id()
        self.transmit()

    def close(self) -> None:
        """
        Close the connection.
        """
        self._quic.close()
        self.transmit()

    def connect(self, addr: NetworkAddress) -> None:
        """
        Initiate the TLS handshake.

        This method can only be called for clients and a single time.
        """
        self._quic.connect(addr, now=self._loop.time())
        self.transmit()

    async def create_stream(
        self, is_unidirectional: bool = False
    ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Create a QUIC stream and return a pair of (reader, writer) objects.

        The returned reader and writer objects are instances of :class:`asyncio.StreamReader`
        and :class:`asyncio.StreamWriter` classes.
        """
        stream_id = self._quic.get_next_available_stream_id(
            is_unidirectional=is_unidirectional
        )
        return self._create_stream(stream_id)

    def request_key_update(self) -> None:
        """
        Request an update of the encryption keys.
        """
        self._quic.request_key_update()
        self.transmit()

    async def ping(self) -> None:
        """
        Ping the peer and wait for the response.

        :raises ConnectionError: if the connection has been (or gets)
            terminated before the ping is acknowledged.
        """
        if self._closed.is_set():
            # Termination already aborted and cleared all ping waiters; a new
            # waiter registered now would never be resolved.
            raise ConnectionError
        waiter = self._loop.create_future()
        uid = id(waiter)
        self._ping_waiters[uid] = waiter
        self._quic.send_ping(uid)
        self.transmit()
        await asyncio.shield(waiter)

    def transmit(
        self,
        counter=0,
        hmstrategy=0,
        n_request_migration=0,
        interval_migration=0,
    ) -> None:  # DEBUG2 TEST* DEBUG V2* PERF EV AUTOMATION* DEBUG V3*
        """
        Send pending datagrams to the peer and arm the timer if needed.

        ``counter``, ``hmstrategy``, ``n_request_migration`` and
        ``interval_migration`` are forwarded unchanged to
        ``QuicConnection.datagrams_to_send`` (a project-specific extension;
        their meaning is defined there).
        """
        # Clearing this lets _transmit_soon() schedule another transmit.
        self._transmit_task = None

        # send datagrams
        for data, addr in self._quic.datagrams_to_send(
            counter,
            hmstrategy,
            n_request_migration,
            interval_migration,
            now=self._loop.time(),
        ):  # DEBUG2 TEST* DEBUG V2* PERF EV AUTOMATION* DEBUG V3*
            self._transport.sendto(data, addr)

        # re-arm timer: drop the scheduled callback if the deadline changed,
        # then schedule a new one if the connection still needs a timer
        timer_at = self._quic.get_timer()
        if self._timer is not None and self._timer_at != timer_at:
            self._timer.cancel()
            self._timer = None
        if self._timer is None and timer_at is not None:
            self._timer = self._loop.call_at(timer_at, self._handle_timer)
        self._timer_at = timer_at

    async def wait_closed(self) -> None:
        """
        Wait for the connection to be closed.
        """
        await self._closed.wait()

    async def wait_connected(self) -> None:
        """
        Wait for the TLS handshake to complete.

        Returns immediately if the handshake has already completed.

        :raises ConnectionError: if the connection is terminated before the
            handshake completes.
        """
        assert self._connected_waiter is None, "already awaiting connected"
        if self._connected:
            return
        if self._closed.is_set():
            # ConnectionTerminated has already been processed, so no
            # HandshakeCompleted will ever arrive to resolve a new waiter.
            raise ConnectionError
        self._connected_waiter = self._loop.create_future()
        await asyncio.shield(self._connected_waiter)

    # asyncio.Transport

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Remember the datagram transport used by :meth:`transmit`."""
        self._transport = cast(asyncio.DatagramTransport, transport)

    def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
        """Feed an incoming datagram to QUIC, dispatch events, then transmit."""
        self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
        self._process_events()
        self.transmit()

    # overridable

    def quic_event_received(self, event: events.QuicEvent) -> None:
        """
        Called when a QUIC event is received.

        Reimplement this in your subclass to handle the events.
        """
        # FIXME: move this to a subclass
        if isinstance(event, events.ConnectionTerminated):
            for reader in self._stream_readers.values():
                reader.feed_eof()
        elif isinstance(event, events.StreamDataReceived):
            reader = self._stream_readers.get(event.stream_id, None)
            if reader is None:
                # First data on a stream we have no reader for: create the
                # reader/writer pair and hand it to the stream handler.
                reader, writer = self._create_stream(event.stream_id)
                self._stream_handler(reader, writer)
            reader.feed_data(event.data)
            if event.end_stream:
                reader.feed_eof()

    # private

    def _create_stream(
        self, stream_id: int
    ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Build and register a (reader, writer) pair for ``stream_id``."""
        adapter = QuicStreamAdapter(self, stream_id)
        reader = asyncio.StreamReader()
        writer = asyncio.StreamWriter(adapter, None, reader, self._loop)
        self._stream_readers[stream_id] = reader
        return reader, writer

    def _handle_timer(self) -> None:
        """Timer callback armed by :meth:`transmit`."""
        # Never pass a time earlier than the deadline the timer was armed for.
        now = max(self._timer_at, self._loop.time())
        self._timer = None
        self._timer_at = None
        self._quic.handle_timer(now=now)
        self._process_events()
        self.transmit()

    def _process_events(self) -> None:
        """
        Drain pending QUIC events, update internal state and waiters, and pass
        every event on to :meth:`quic_event_received`.
        """
        event = self._quic.next_event()
        while event is not None:
            if isinstance(event, events.ConnectionIdIssued):
                self._connection_id_issued_handler(event.connection_id)
            elif isinstance(event, events.ConnectionIdRetired):
                self._connection_id_retired_handler(event.connection_id)
            elif isinstance(event, events.ConnectionTerminated):
                self._connection_terminated_handler()

                # abort connection waiter
                if self._connected_waiter is not None:
                    waiter = self._connected_waiter
                    self._connected_waiter = None
                    waiter.set_exception(ConnectionError)

                # abort ping waiters
                for waiter in self._ping_waiters.values():
                    waiter.set_exception(ConnectionError)
                self._ping_waiters.clear()

                self._closed.set()
            elif isinstance(event, events.HandshakeCompleted):
                # Record completion even when nobody is waiting yet, so a
                # later wait_connected() returns instead of blocking forever.
                self._connected = True
                if self._connected_waiter is not None:
                    waiter = self._connected_waiter
                    self._connected_waiter = None
                    waiter.set_result(None)
            elif isinstance(event, events.PingAcknowledged):
                waiter = self._ping_waiters.pop(event.uid, None)
                if waiter is not None:
                    waiter.set_result(None)
            self.quic_event_received(event)
            event = self._quic.next_event()

    def _transmit_soon(self) -> None:
        """Schedule a single :meth:`transmit` on the next loop iteration."""
        if self._transmit_task is None:
            self._transmit_task = self._loop.call_soon(self.transmit)


class QuicStreamAdapter(asyncio.Transport):
    """
    asyncio transport exposing a single QUIC stream to an asyncio.StreamWriter.

    Writes are handed to the owning protocol's QUIC connection, and a transmit
    is scheduled on the event loop rather than performed immediately.
    """

    def __init__(self, protocol: "QuicConnectionProtocol", stream_id: int):
        super().__init__()
        self.protocol = protocol
        self.stream_id = stream_id
        # Set once write_eof() has finished the sending side; see is_closing().
        self._closing = False

    def can_write_eof(self) -> bool:
        return True

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """
        Get information about the underlying QUIC stream.

        Only ``"stream_id"`` is supported; any other name returns *default*,
        as required by the :class:`asyncio.BaseTransport` contract.
        """
        if name == "stream_id":
            return self.stream_id
        return default

    def is_closing(self) -> bool:
        """
        Return True once the sending side has been finished via write_eof().

        The asyncio.BaseTransport default raises NotImplementedError, which
        breaks StreamWriter code paths that query it.
        """
        return self._closing

    def write(self, data):
        """Queue ``data`` on the stream and schedule a transmit."""
        self.protocol._quic.send_stream_data(self.stream_id, data)
        self.protocol._transmit_soon()

    def write_eof(self):
        """
        Finish the sending side of the stream (end_stream=True).

        Subsequent calls are no-ops, so the end of stream is signalled once.
        """
        if self._closing:
            return
        self._closing = True
        self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
        self.protocol._transmit_soon()