Reliable-Transport-Protocol / 3700send
3700send
Raw
#!/usr/bin/env -S python3 -u

import argparse, socket, time, json, select, struct, sys, math, time, hashlib 

DATA_SIZE = 1375

# Sending class that reads data from stdin
# and sends it across the network over UDP
class Sender:
    """Reliable-transport sender that reads stdin and sends it as JSON datagrams.

    Each data packet is ``{"type": "msg", "seqno", "checksum", "data"}``.
    Sent packets are tracked in ``awaiting_ack`` as ``(send_time, seqno, data)``
    tuples until acknowledged. Packets that time out, or that the peer asks to
    have resent, move to ``retransmit_queue`` as ``(seqno, data)`` tuples.

    The window grows on acks: by 3 per ack while below ``thresh``, then by 1
    per receive batch once at or above it. It shrinks by 1 (and ``thresh`` by
    20%) on a detected drop.
    """

    def __init__(self, host, port):
        self.host = host
        self.remote_port = int(port)
        self.log("Sender starting up using port %s" % self.remote_port)
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.bind(('0.0.0.0', 0))
        # True after a burst of sends; stdin is not polled again until a packet arrives
        self.waiting = False

        self.seen_seq_no = set()      # seqnos sent, plus seqnos seen in incoming packets
        self.lps = 0                  # next sequence number to assign
        self.window_size = 1
        self.done_sending = False     # stdin has been exhausted
        self.awaiting_ack = []        # (send_time, seqno, data)
        self.retransmit_queue = []    # (seqno, data)
        self.rtt_log = []             # RTT samples in seconds (each padded by 0.10)
        self.thresh = 20              # window-growth threshold
        self.j = 0                    # NOTE(review): unused in this file
        self.initial_window = 0       # NOTE(review): unused; run() uses a local copy
        self.ls = 0                   # time of the most recent new-packet send
        self.full_window = False      # the last receive batch filled the whole window

    # log some message to stderr
    def log(self, message):
        sys.stderr.write(message + "\n")
        sys.stderr.flush()

    # send a message (a JSON-serializable dict) to the remote host
    def send(self, message):
        self.socket.sendto(json.dumps(message).encode('utf-8'), (self.host, self.remote_port))

    def _average_rtt(self):
        """Return the mean of the recorded RTT samples, or None if there are none."""
        if not self.rtt_log:
            return None
        return sum(self.rtt_log) / len(self.rtt_log)

    def run(self):
        """Main event loop: multiplex the UDP socket and stdin with select().

        Never returns normally; the process exits from check_complete().
        """
        while True:
            # while waiting for feedback on the last burst, do not read more stdin
            sockets = [self.socket] if self.waiting else [self.socket, sys.stdin]
            readable = select.select(sockets, [], [], 0.1)[0]
            for conn in readable:
                if conn == self.socket:
                    self._receive_batch(conn)
                elif conn == sys.stdin:
                    self._send_batch()
            self.idle_dr(time.time())

    def _receive_batch(self, conn):
        """Read up to window_size datagrams from conn, process each, then resize the window."""
        count, increase_w = 0, 0
        initial_window = self.window_size
        self.full_window = False
        batch = set()
        for i in range(self.window_size):
            try:
                k, _addr = conn.recvfrom(65535)
            except OSError:
                break  # timed out / nothing more to receive
            try:
                msg = json.loads(k.decode('utf-8'))
                msg_type, seqno = msg['type'], msg['seqno']
            except (ValueError, KeyError, TypeError):
                # undecodable bytes, invalid JSON, or a payload missing the
                # expected fields: treat as a corrupted message
                continue

            batch, increase_w = self.handle_recv(msg_type, seqno, batch, increase_w)
            self.seen_seq_no.add(seqno)
            self.check_complete()
            self.waiting = False
            count = i + 1
        self.update_window(count, initial_window, increase_w)

    def _send_batch(self):
        """Fill the window: queued retransmissions first, then fresh data from stdin."""
        num_to_retransmit = min(self.window_size, len(self.retransmit_queue))
        self.retransmit(num_to_retransmit)

        for _ in range(self.window_size - num_to_retransmit):
            data = sys.stdin.read(DATA_SIZE)
            if not data:
                # no data left: we finish once all outstanding acks are received
                self.done_sending = True
                self.check_complete()
                break

            seqno = self.lps
            self.seen_seq_no.add(seqno)
            self.send({"type": "msg", "seqno": seqno, "checksum": self.get_checksum(data), "data": data})
            self.lps += 1
            now = time.time()
            self.ls = now
            self.awaiting_ack.append((now, seqno, data))

        self.waiting = True

    def handle_recv(self, msg_type, seqno, batch, increase_w):
        """Process one incoming packet and return the updated ``(batch, increase_w)``.

        An 'ack' carrying seqno N is treated as acknowledging packet N - 1. Any
        other type is treated as a retransmit request for packet ``seqno``.
        ``batch`` collects ack numbers seen in the current receive batch. A
        repeated number within one batch resets every pending timeout.
        """
        if msg_type == 'ack':
            self.log("received ack message")
            self.update_timeout(seqno - 1)
            self.update_awaiting(seqno)
            if seqno in batch:
                self.reset_timeouts()
            batch.add(seqno)
            # accumulate additive growth while the prospective window is below thresh
            if self.window_size + increase_w < self.thresh:
                increase_w += 3
            return batch, increase_w

        self.log("received retransmit request")
        # sample the RTT while the packet is still in awaiting_ack;
        # handle_retransmit_req removes it from that list
        self.update_timeout(seqno)
        self.handle_retransmit_req(seqno)
        # store as seqno + 1 so it lines up with ack numbering in `batch`
        if seqno + 1 in batch:
            self.reset_timeouts()
        batch.add(seqno + 1)
        return batch, increase_w

    def idle_dr(self, endtime):
        """Periodic housekeeping: check for timed-out packets and drain the retransmit queue.

        Drop detection runs once ``endtime - self.ls`` exceeds the mean RTT
        (or 1.25 s before any RTT sample exists).
        """
        avg = self._average_rtt()
        limit = 1.25 if avg is None else avg
        if endtime - self.ls > limit:
            self.log("detecting if packet was dropped...")
            self.detect_drop()
        if self.waiting and self.retransmit_queue:
            self.log("retransmitting from queue...")
            self.retransmit(min(self.window_size, len(self.retransmit_queue)))

    def update_window(self, count, initial_window, increase_w):
        """Grow the window after a receive batch of ``count`` packets."""
        if count == initial_window:
            self.full_window = True
        if self.window_size >= self.thresh:
            self.window_size += 1
        else:
            self.window_size += increase_w

    def reset_timeouts(self):
        """Restamp every packet awaiting an ack with the current time."""
        self.log("Resetting timeouts for awaiting acks")
        now = time.time()
        # slice-assign so the existing list object is updated in place
        self.awaiting_ack[:] = [(now, seqno, data) for _, seqno, data in self.awaiting_ack]

    def update_timeout(self, seqno):
        """Record an RTT sample for packet ``seqno`` and set the socket timeout to the mean RTT.

        Does nothing if ``seqno`` is not awaiting an ack.
        """
        for sent_at, pending_seqno, _ in self.awaiting_ack:
            if pending_seqno == seqno:
                # each sample is padded by 0.10 s (presumably a safety margin)
                self.rtt_log.append(time.time() - sent_at + 0.10)
                self.socket.settimeout(self._average_rtt())
                return

    def handle_retransmit_req(self, seqno):
        """Move packet ``seqno`` from awaiting_ack to the retransmit queue, if present."""
        for p in self.awaiting_ack:
            if p[1] == seqno:
                self.awaiting_ack.remove(p)
                self.retransmit_queue.append((p[1], p[2]))
                break

    def get_checksum(self, data):
        """Return the hex MD5 digest of ``data`` encoded as UTF-8."""
        self.log("Calculating checksum...")
        return hashlib.md5(data.encode('utf-8')).hexdigest()

    def retransmit(self, num_to_retransmit):
        """Resend the first ``num_to_retransmit`` queued packets and track them for acks."""
        for seqno, data in self.retransmit_queue[:num_to_retransmit]:
            msg = {"type": "msg", "seqno": seqno, "checksum": self.get_checksum(data), "data": data}
            self.log("Retransmitting message '%s'" % msg)
            self.send(msg)
            self.awaiting_ack.append((time.time(), seqno, data))
        self.retransmit_queue = self.retransmit_queue[num_to_retransmit:]

    def check_complete(self):
        """Exit the process with status 0 once everything is sent and acknowledged.

        Completion requires all of the following:
        - stdin is exhausted;
        - no packet awaits an ack;
        - nothing is queued for retransmission;
        - the highest sequence number seen equals ``lps``.

        Returns:
            False when the transfer is not complete yet.
        """
        finished = (
            self.done_sending
            and not self.awaiting_ack
            # a NACKed packet sits here, not in awaiting_ack; don't exit early
            and not self.retransmit_queue
            # default=0 handles empty input, where nothing was ever sent
            and self.lps == max(self.seen_seq_no, default=0)
        )
        if finished:
            self.log("Final ack received, exiting")
            sys.exit(0)
        return False

    def update_awaiting(self, ack_no):
        """Stop tracking the packet acknowledged by ``ack_no`` (i.e. seqno ``ack_no - 1``)."""
        acked = ack_no - 1
        self.awaiting_ack = [p for p in self.awaiting_ack if p[1] != acked]

    def detect_drop(self):
        """Move every timed-out packet from awaiting_ack to the retransmit queue.

        A packet times out after the mean RTT, or after 1.25 s if no RTT
        sample exists yet. Each drop goes through handle_drop, which also
        shrinks the window when appropriate.
        """
        now = time.time()
        avg = self._average_rtt()
        no_samples = avg is None
        limit = 1.25 if no_samples else avg
        still_waiting = []
        for p in self.awaiting_ack:
            sent_at, seqno, data = p
            if now - sent_at > limit:
                # ack not received in time: the packet must be retransmitted
                self.handle_drop(seqno, data, no_samples)
                continue
            self.log("still awaiting ack for seqno %d" % (seqno))
            still_waiting.append(p)
        self.awaiting_ack = still_waiting

    def handle_drop(self, seqno, data, empty):
        """Queue a dropped packet for retransmission.

        If ``empty`` is true (no RTT samples yet) or the last batch did not
        fill the window, also shrink the window by 1 and ``thresh`` by 20%,
        never below 1.
        """
        self.log("handling drop...")
        self.retransmit_queue.append((seqno, data))
        if empty or not self.full_window:
            self.thresh = max(1, int(self.thresh * 0.8))
            self.window_size = max(1, self.window_size - 1)

# script entry point: parse the destination host/port, then run the sender loop
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description='send data')
    arg_parser.add_argument('host', type=str, help="Remote host to connect to")
    arg_parser.add_argument('port', type=int, help="UDP port number to connect to")
    cli = arg_parser.parse_args()
    Sender(cli.host, cli.port).run()