Reliable-Transport-Protocol / 3700recv
3700recv
Raw
#!/usr/bin/env -S python3 -u

import argparse, socket, time, json, select, struct, sys, math, hashlib

# Receiving class that receives data packets over UDP, acknowledges each one,
# and prints their payloads to stdout in sequence-number order.
class Receiver:
    """UDP receiver that ACKs every valid packet and prints payloads in order.

    Each packet is a JSON object carrying at least ``seqno`` (int), ``data``
    (str) and ``checksum`` (hex MD5 digest of ``data``). Sequence numbers are
    treated as consecutive integers starting at 0: ``last_printed`` starts at
    -1, and a payload is written only once every lower seqno has been written.
    """

    def __init__(self):
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.bind(('0.0.0.0', 0))  # all interfaces, OS-chosen port
        self.port = self.socket.getsockname()[1]
        self.log("Bound to port %d" % self.port)

        # Peer address, latched from the first datagram received in run().
        self.remote_host = None
        self.remote_port = None

        self.seen_seq_no = set()  # every seqno already handed to print_in_order
        self.print_queue = []     # (seqno, data) tuples awaiting output, sorted by seqno
        self.last_printed = -1    # highest seqno written to stdout so far

    def send(self, message):
        """Serialize *message* as UTF-8 JSON and send it to the latched peer."""
        self.socket.sendto(json.dumps(message).encode('utf-8'), (self.remote_host, self.remote_port))

    def log(self, message):
        """Write *message* plus a newline to stderr and flush immediately."""
        sys.stderr.write(message + "\n")
        sys.stderr.flush()

    @staticmethod
    def _is_well_formed(msg):
        """Return True if *msg* carries the fields run() reads, with usable types."""
        return (
            isinstance(msg, dict)
            and isinstance(msg.get('seqno'), int)
            and isinstance(msg.get('data'), str)
            and 'checksum' in msg
        )

    def run(self):
        """Receive packets forever: verify, queue/print in order, and ACK.

        Undecodable or malformed datagrams are dropped. Packets with a bad
        checksum trigger a retransmit request (see check_checksum) and are not
        ACKed. Every other packet, including duplicates, is ACKed with
        ``seqno + 1``.
        """
        while True:
            readable = select.select([self.socket], [], [])[0]
            for conn in readable:
                data, addr = conn.recvfrom(65535)
                # Latch onto the first peer we hear from; all replies go there.
                if self.remote_host is None:
                    self.remote_host = addr[0]
                    self.remote_port = addr[1]
                try:
                    msg = json.loads(data.decode('utf-8'))
                except ValueError:
                    # Covers both UnicodeDecodeError and json.JSONDecodeError.
                    continue
                if not self._is_well_formed(msg):
                    self.log("Dropping malformed message")
                    continue
                # check_checksum returns True when the packet is corrupt.
                if self.check_checksum(msg):
                    continue
                seqno = msg['seqno']
                if seqno not in self.seen_seq_no:
                    self.seen_seq_no.add(seqno)
                    self.print_in_order(msg)
                # ACK duplicates too, in case an earlier ACK was lost.
                self.log("sending ack...")
                self.send({"type": "ack", "seqno": seqno + 1})

    def check_checksum(self, msg):
        """Report whether *msg* FAILED checksum verification.

        Despite the name, this returns True when the MD5 of ``msg['data']``
        differs from ``msg['checksum']`` (the packet is corrupt), and False when
        they match. On a mismatch it also sends a best-effort
        ``{"type": "retransmit", "seqno": ...}`` request to the peer.
        """
        recv_checksum = self.calculate_checksum(msg['data'])
        if recv_checksum == msg['checksum']:
            return False
        self.log("Mismatched checksums. Sender checksum: %s vs. Receiver checksum: %s" % (msg['checksum'], recv_checksum))
        try:
            self.send({"type": "retransmit", "seqno": msg["seqno"]})
        except OSError as err:
            # Best effort: a failed retransmit request must not stop the receive loop.
            self.log("Failed to send retransmit request: %s" % err)
        return True

    def print_in_order(self, msg):
        """Queue msg's payload, then write every payload that is now in sequence.

        After queuing, the run of queued seqnos that continues directly from
        ``last_printed`` is written to stdout and removed from the queue.
        Entries after a gap stay queued until the gap is filled. A seqno that
        was already printed or is already queued is ignored.
        """
        seqno = msg['seqno']
        if seqno <= self.last_printed:
            return  # already delivered; queuing it would block the head forever
        if all(queued != seqno for queued, _ in self.print_queue):
            self.insert_sorted(msg)

        # Flush the contiguous prefix that continues the printed stream.
        flushed = 0
        for queued, data in self.print_queue:
            if queued != self.last_printed + 1:
                break
            sys.stdout.write(data)
            self.last_printed = queued
            flushed += 1
        del self.print_queue[:flushed]

    def calculate_checksum(self, data):
        """Return the hex MD5 digest of the UTF-8 encoding of *data*."""
        self.log("calculating expected checksum...")
        return hashlib.md5(data.encode('utf-8')).hexdigest()

    def insert_sorted(self, msg):
        """Insert (seqno, data) into print_queue, keeping it sorted by seqno.

        The entry goes before the first queued item with a larger seqno, i.e.
        after any existing items with an equal seqno.
        """
        self.log("inserting message into print queue")
        seqno = msg['seqno']
        index = next(
            (i for i, (queued, _) in enumerate(self.print_queue) if queued > seqno),
            len(self.print_queue),
        )
        self.print_queue.insert(index, (seqno, msg['data']))

# Script entry point: validate the (currently empty) command line, then start
# the receiver and block in its receive loop.
if __name__ == "__main__":
    argparse.ArgumentParser(description='receive data').parse_args()
    receiver = Receiver()
    receiver.run()