#!/usr/bin/env -S python3 -u
import argparse, socket, time, json, select, struct, sys, math, time, hashlib
DATA_SIZE = 1375
class Sender:
    """Sender that pushes stdin to a remote host over UDP with retransmission.

    Data is read from stdin in DATA_SIZE-character chunks and sent as JSON
    ``msg`` packets carrying a sequence number and an MD5 checksum.

    Replies of type ``ack`` are treated as acknowledging ``seqno - 1``. Any
    other reply type is treated as a retransmit request for ``seqno``.

    Packets whose ack does not arrive within the mean RTT (or DEFAULT_TIMEOUT
    before any RTT sample exists) are queued for retransmission. The send
    window grows on acks and shrinks on detected drops. The process exits via
    ``sys.exit(0)`` once stdin is exhausted and nothing remains unacknowledged.
    """

    # Fallback drop-detection / receive timeout (seconds), used until the
    # first RTT sample has been recorded.
    DEFAULT_TIMEOUT = 1.25

    def __init__(self, host, port):
        """Create and bind the UDP socket and initialise sender state.

        :param host: destination host name or address.
        :param port: destination UDP port (anything ``int()`` accepts).
        """
        self.host = host
        self.remote_port = int(port)
        self.log("Sender starting up using port %s" % self.remote_port)
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.bind(('0.0.0.0', 0))
        # Bound recvfrom() until update_timeout() installs an RTT-based value.
        # A fully blocking socket could otherwise hang the single-threaded loop
        # when a batch read asks for more replies than have arrived.
        self.socket.settimeout(self.DEFAULT_TIMEOUT)
        self.waiting = False        # True: stdin is not polled until a reply arrives
        self.seen_seq_no = set()    # seqnos sent plus seqnos carried by replies
        self.lps = 0                # next sequence number to assign
        self.window_size = 1
        self.done_sending = False   # stdin has reached EOF
        self.awaiting_ack = []      # (send_time, seqno, data) for in-flight packets
        self.retransmit_queue = []  # (seqno, data) waiting to be resent
        self.rtt_log = []           # RTT samples, each padded by 0.10 s
        self.thresh = 20            # at/above this window size, growth is +1 per full window
        self.j = 0                  # not used within this class
        self.initial_window = 0     # not used within this class (run uses a local)
        self.ls = 0                 # time of the most recent new-data send
        self.full_window = False    # whether the last receive batch filled the window

    def log(self, message):
        """Write *message* plus a newline to stderr and flush immediately."""
        sys.stderr.write(message + "\n")
        sys.stderr.flush()

    def send(self, message):
        """JSON-encode *message* and send it as one datagram to the remote host."""
        self.socket.sendto(json.dumps(message).encode('utf-8'), (self.host, self.remote_port))

    def run(self):
        """Event loop: multiplex replies on the socket with data on stdin.

        stdin is only polled while ``self.waiting`` is False. After each
        select() round, idle_dr() checks for timed-out packets. This method
        never returns normally; the process ends in check_complete().
        """
        while True:
            sockets = [self.socket] if self.waiting else [self.socket, sys.stdin]
            readable = select.select(sockets, [], [], 0.1)[0]
            for conn in readable:
                if conn == self.socket:
                    self._receive_batch(conn)
                elif conn == sys.stdin:
                    self._send_batch()
            self.idle_dr(time.time())

    def _receive_batch(self, conn):
        """Read up to one window of replies from *conn*, then adapt the window."""
        count, increase_w = 0, 0
        initial_window = self.window_size
        self.full_window = False
        batch = set()
        for i in range(self.window_size):
            try:
                raw, _addr = conn.recvfrom(65535)
            except OSError:  # socket timeout / nothing more to receive
                break
            parsed = self._parse_msg(raw)
            if parsed is None:
                continue  # reply was corrupted or malformed
            msg_type, seqno = parsed
            batch, increase_w = self.handle_recv(msg_type, seqno, batch, increase_w)
            self.seen_seq_no.add(seqno)
            self.check_complete()
            self.waiting = False
            count = i + 1
        self.update_window(count, initial_window, increase_w)

    def _send_batch(self):
        """Resend queued packets, then fill the rest of the window from stdin."""
        num_to_retransmit = min(self.window_size, len(self.retransmit_queue))
        self.retransmit(num_to_retransmit)
        for _ in range(self.window_size - num_to_retransmit):
            data = sys.stdin.read(DATA_SIZE)
            if not data:
                # stdin exhausted: we finish once every packet is acknowledged
                self.done_sending = True
                self.check_complete()
                break
            seqno = self.lps
            self.seen_seq_no.add(seqno)
            self.send({"type": "msg", "seqno": seqno, "checksum": self.get_checksum(data), "data": data})
            self.lps += 1
            self.ls = time.time()
            self.awaiting_ack.append((time.time(), seqno, data))
        self.waiting = True

    @staticmethod
    def _parse_msg(raw):
        """Decode a reply datagram into ``(type, seqno)``; return None if malformed.

        Corruption may produce bytes that are not UTF-8, text that is not JSON,
        or JSON without the expected fields or with a non-integer seqno. All of
        these are dropped instead of crashing the sender.
        """
        try:
            msg = json.loads(raw.decode('utf-8'))
            msg_type, seqno = msg['type'], msg['seqno']
        except (ValueError, KeyError, TypeError):  # ValueError covers decode/JSON errors
            return None
        if not isinstance(seqno, int) or isinstance(seqno, bool):
            return None
        return msg_type, seqno

    def handle_recv(self, msg_type, seqno, batch, increase_w):
        """Apply one reply to the sender state.

        An ``ack`` reply is treated as acknowledging packet ``seqno - 1``. Any
        other type is treated as a retransmit request for packet ``seqno``.

        :param batch: "next" seqnos already seen in this receive batch; seeing
            one again resets the timeouts of all in-flight packets.
        :param increase_w: window growth accumulated so far in this batch.
        :returns: the updated ``(batch, increase_w)``.
        """
        if msg_type == 'ack':
            self.log("received ack message")
            self.update_timeout(seqno - 1)
            self.update_awaiting(seqno)
            if seqno in batch:
                self.reset_timeouts()
            batch.add(seqno)
            # grow by 3 per ack while the projected window is below thresh
            if self.window_size + increase_w < self.thresh:
                increase_w += 3
            return batch, increase_w
        self.log("received retransmit request")
        # Sample the RTT before the packet leaves awaiting_ack; afterwards the
        # lookup in update_timeout() could never find it.
        self.update_timeout(seqno)
        self.handle_retransmit_req(seqno)
        if seqno + 1 in batch:
            self.reset_timeouts()
        batch.add(seqno + 1)
        return batch, increase_w

    def _avg_rtt(self):
        """Return the mean of rtt_log, or None when no sample exists yet."""
        if not self.rtt_log:
            return None
        return sum(self.rtt_log) / len(self.rtt_log)

    def idle_dr(self, endtime):
        """Periodic housekeeping: detect drops, and resend queued packets while waiting.

        :param endtime: current time. Drop detection runs when more than the
            mean RTT (DEFAULT_TIMEOUT before any sample) has passed since the
            last new-data send.
        """
        avg = self._avg_rtt()
        limit = self.DEFAULT_TIMEOUT if avg is None else avg
        if endtime - self.ls > limit:
            self.log("detecting if packet was dropped...")
            self.detect_drop()
        if self.waiting and self.retransmit_queue:
            self.log("retransmitting from queue...")
            self.retransmit(min(self.window_size, len(self.retransmit_queue)))

    def update_window(self, count, initial_window, increase_w):
        """Grow the window after a receive batch that filled the whole window.

        :param count: 1-based index of the last reply successfully handled.
        :param initial_window: window size when the batch started.
        :param increase_w: growth accumulated by handle_recv(); used below thresh.
        """
        if count == initial_window:
            self.full_window = True
            if self.window_size >= self.thresh:
                self.window_size += 1
            else:
                self.window_size += increase_w

    def reset_timeouts(self):
        """Restart the timeout clock of every packet still awaiting an ack."""
        self.log("Resetting timeouts for awaiting acks")
        now = time.time()
        self.awaiting_ack[:] = [(now, seqno, data) for _, seqno, data in self.awaiting_ack]

    def update_timeout(self, seqno):
        """Record an RTT sample for in-flight packet *seqno* and retune the socket timeout.

        The sample is the elapsed time since the packet was (re)sent plus 0.10 s.
        The socket timeout becomes the mean of all samples. Does nothing if
        *seqno* is not in awaiting_ack.
        """
        for sent_at, p_seqno, _ in self.awaiting_ack:
            if p_seqno == seqno:
                self.rtt_log.append(time.time() - sent_at + 0.10)
                self.socket.settimeout(self._avg_rtt())
                return

    def handle_retransmit_req(self, seqno):
        """Move in-flight packet *seqno* from awaiting_ack to the retransmit queue."""
        for p in self.awaiting_ack:
            if p[1] == seqno:
                self.awaiting_ack.remove(p)  # safe: we stop iterating right after
                self.retransmit_queue.append((p[1], p[2]))
                break

    def get_checksum(self, data):
        """Return the MD5 hex digest of *data* encoded as UTF-8.

        The digest is deterministic for a given payload. Presumably the
        receiver recomputes it to detect corruption (not visible here).
        """
        self.log("Calculating checksum...")
        return hashlib.md5(data.encode('utf-8')).hexdigest()

    def retransmit(self, num_to_retransmit):
        """Resend the first *num_to_retransmit* queued packets and track them as in flight."""
        for seqno, data in self.retransmit_queue[:num_to_retransmit]:
            msg = {"type": "msg", "seqno": seqno, "checksum": self.get_checksum(data), "data": data}
            self.log("Retransmitting message '%s'" % msg)
            self.send(msg)
            self.awaiting_ack.append((time.time(), seqno, data))
        self.retransmit_queue = self.retransmit_queue[num_to_retransmit:]

    def check_complete(self):
        """Exit with status 0 once all data has been sent and acknowledged.

        Completion requires four things. stdin must be exhausted. No packet may
        be awaiting an ack or queued for retransmission. A reply seqno of at
        least ``lps`` (the final ack) must have been seen. With empty input
        nothing was sent, so the last condition holds trivially.

        :returns: False when not yet complete (otherwise does not return).
        """
        if (self.done_sending
                and not self.awaiting_ack
                and not self.retransmit_queue
                and max(self.seen_seq_no, default=0) >= self.lps):
            self.log("Final ack received, exiting")
            sys.exit(0)
        return False

    def update_awaiting(self, ack_no):
        """Forget packet ``ack_no - 1``: it has been acknowledged.

        The packet is removed from awaiting_ack, and also from the retransmit
        queue in case a drop was presumed earlier. Resending a delivered packet
        would only delay completion.
        """
        acked = ack_no - 1
        self.awaiting_ack = [p for p in self.awaiting_ack if p[1] != acked]
        self.retransmit_queue = [q for q in self.retransmit_queue if q[0] != acked]

    def detect_drop(self):
        """Queue for retransmission every in-flight packet older than the timeout.

        The timeout is the mean RTT, or DEFAULT_TIMEOUT before any RTT sample
        exists. For each drop, thresh is cut to 80 % and the window shrinks by
        one (both floored at 1). With RTT samples present, this shrink only
        happens when the last receive batch was not a full window.
        """
        now = time.time()
        avg = self._avg_rtt()
        no_samples = avg is None
        limit = self.DEFAULT_TIMEOUT if no_samples else avg
        still_waiting = []
        for p in self.awaiting_ack:
            sent_at, seqno, data = p
            if now - sent_at > limit:
                # ack not received in time - packet must be retransmitted
                self.log("adding to retransmit queue...")
                self.retransmit_queue.append((seqno, data))
                if no_samples or not self.full_window:
                    self.thresh = max(1, int(self.thresh * 0.8))
                    self.window_size = max(1, self.window_size - 1)
                continue
            self.log("still awaiting ack for seqno %d" % (seqno))
            still_waiting.append(p)
        self.awaiting_ack = still_waiting

    def handle_drop(self, seqno, data, empty):
        """Queue (seqno, data) for retransmission and shrink thresh/window.

        The shrink happens when *empty* is true or the last batch was not a
        full window. Not called from within this class.
        """
        self.log("handling drop...")
        self.retransmit_queue.append((seqno, data))
        if empty or not self.full_window:
            self.thresh = max(1, int(self.thresh * 0.8))
            self.window_size = max(1, self.window_size - 1)
# Script entry point: parse the destination host/port and start the sender loop.
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description='send data')
    cli.add_argument('host', type=str, help="Remote host to connect to")
    cli.add_argument('port', type=int, help="UDP port number to connect to")
    options = cli.parse_args()
    Sender(options.host, options.port).run()