Reliable-Transport-Protocol / run
run
Raw
#!/usr/bin/env python3 

import atexit
import re
import sys
import os
import time
import json
import random
import select
import binascii
import signal
import traceback
import socket
import subprocess
import struct
from threading import Thread
from functools import reduce
from collections import defaultdict

#### PARAMETERS

# File names of the two programs under test. They are checked with
# get_executable() and launched relative to the current directory.
SENDER_EXECUTABLE_NAME = "3700send"
# NOTE: the identifier is misspelled ("RECEVIER"), but it is referenced in
# several places in this file, so the name is kept as-is.
RECEVIER_EXECUTABLE_NAME = "3700recv"
# Not referenced anywhere in the visible part of this file.
MAX_WRITE = 10000
# Value returned by the sleep_time() methods when nothing is pending; it ends
# up as the select() timeout (seconds) in the simulator's main loop.
DEFAULT_SLEEP = 1
# Verbosity threshold: log() prints only messages whose level is <= this.
LOG_LEVEL = 1

# The running Simulator instance; stays None until the main program creates it.
simulator = None

def die(msg):
  """Report a fatal error, shut down the simulator if it exists, and exit.

  Args:
    msg: Human-readable description of the failure; printed after "Error: ".

  Raises:
    SystemExit: always (exit status 0, as before).
  """
  print("\nError: %s" % msg)
  # die() is also reached before the Simulator has been constructed (usage
  # error, missing executable, unreadable config). There is nothing to stop
  # in that case, and calling .stop() on None would mask the real error.
  if simulator is not None:
    simulator.stop()
  sys.exit(0)

# Wall-clock time at which the script was loaded; now() is measured from here.
start = time.time()

def log(caller, msg, level=0):
  """Print one timestamped log line, unless it is too verbose.

  Args:
    caller: Label of the component emitting the message (right-aligned, 12 wide).
    msg: Text to print.
    level: Verbosity of this message; it is suppressed when above LOG_LEVEL.
  """
  if level > LOG_LEVEL:
    return

  line = "[%07.4f  %12s]: %s" % (now(), caller, msg)
  print(line)


def get_config(config_file):
  """Read and parse the JSON configuration file.

  Args:
    config_file: Path of the JSON file to load.

  Returns:
    The parsed JSON value.

  Any failure (missing file, read error, invalid JSON) is reported via die().
  """
  found = os.path.exists(config_file)
  if not found:
    die("Could not find config file '%s'" % config_file)

  # Step 1: pull the raw text off disk.
  try:
    with open(config_file) as handle:
      raw_text = handle.read()
  except Exception as read_error:
    die("Unable to read data from config file '%s': %s" % (config_file, read_error))

  # Step 2: turn the text into Python objects.
  try:
    parsed = json.loads(raw_text)
  except Exception as parse_error:
    die("Unable to parse JSON in config file '%s': %s" % (config_file, parse_error))

  return parsed

def get_executable(executable):
  """Verify that a program exists and is executable; die() otherwise.

  Args:
    executable: Path of the program to check.

  Returns:
    None. Despite the name, this only validates the path.
  """
  present = os.path.exists(executable)
  if not present:
    die("Could not find program '%s'" % executable)

  runnable = os.access(executable, os.X_OK)
  if not runnable:
    die("Could not execute program '%s'" % executable)

#### WRAPPER CODE

class FDWrapper:
  """Ties a readable object to the Wrapper that owns it.

  Because it exposes fileno(), an instance can be handed to select.select();
  the `parent` attribute then says which Wrapper should service the read.
  """

  def __init__(self, fd, parent):
    self.parent = parent
    self.fd = fd

  def fileno(self):
    """Return the descriptor number of the wrapped object."""
    wrapped = self.fd
    return wrapped.fileno()

class Wrapper:
  """Runs one program under test and relays its traffic to the simulator.

  Each Wrapper owns a child process (stdin/stdout/stderr pipes) and a UDP
  socket bound to an OS-chosen localhost port. Datagrams arriving on the
  socket are handed to the simulator; text printed by the child is logged.
  Subclasses supply the command-line arguments via get_args().
  """

  def __init__(self, executable, simulator):
    """Create the UDP socket; the child process is launched later by start().

    Args:
      executable: File name of the program to run from the current directory.
      simulator: Object that receives packet_received()/check_final() calls.
    """
    self.executable = executable
    self.simulator = simulator

    self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    self.socket.bind(('localhost', 0))
    self.port = self.socket.getsockname()[1]
    self.started = False
    # Port used by send() to reach the child; learned in read().
    self.remote_port = None

    self.process = None

    # Bytes collected from the receiver program's stdout.
    self.received_data = bytearray()

    # Traffic counters updated through bytes_sent().
    self.packets = 0
    self.bytes = 0

  def __str__(self):
    return self.executable

  def bytes_sent(self, length):
    """Record one more packet of `length` bytes emitted by this endpoint."""
    self.packets += 1
    self.bytes += length

  def is_started(self):
    """Return True while a child process object is attached."""
    return self.process is not None

  def get_args(self):
    """Return the list of command-line argument strings for the child."""
    raise ValueError("Must be implemented by subclass")

  def start(self):
    """Launch the child process and make its output pipes non-blocking."""
    args = "%s %s" % (os.path.join(".", self.executable), " ".join(self.get_args()))
    log("Simulator", "Starting %s with command '%s'" % (self.executable, args))
    # start_new_session=True makes the child call setsid(), which is what the
    # former preexec_fn=os.setsid did, without the thread-safety caveats that
    # the subprocess documentation attaches to preexec_fn. The child becomes
    # its own process-group leader so stop() can signal the whole group.
    self.process = subprocess.Popen(args,
                                    shell=True,
                                    stdin=subprocess.PIPE,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    start_new_session=True)

    def make_non_blocking(fd):
      try:
        from fcntl import fcntl, F_GETFL, F_SETFL
        flags = fcntl(fd, F_GETFL)
        fcntl(fd, F_SETFL, flags | os.O_NONBLOCK)
      except ImportError:
        print("Warning:  Unable to load fcntl module; things may not work as expected.")

    make_non_blocking(self.process.stdout)
    make_non_blocking(self.process.stderr)

    # Make sure the child is terminated even on an unexpected interpreter exit.
    atexit.register(self.stop)
    self.started = True

  def ready(self):
    """Return True once the port for reaching the child is known."""
    return self.remote_port is not None

  def stop(self):
    """Send SIGTERM to the child's process group if it is still running."""
    if self.process and self.process.poll() is None:
      os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
    self.process = None

  def get_read_fds(self):
    """Return select()-able wrappers for stdout, stderr and the UDP socket.

    Returns an empty list when no child process is attached.
    """
    if self.process:
      return [FDWrapper(self.process.stdout, self), FDWrapper(self.process.stderr, self), FDWrapper(self.socket, self)]
    else:
      return []

  def read(self, fd):
    """Service one readable descriptor previously returned by get_read_fds().

    Args:
      fd: Object whose `.fd` attribute is the readable socket or pipe.
    """
    if fd.fd == self.socket:
      data, addr = fd.fd.recvfrom(65535)
      # The first datagram from the child reveals the port it sends from.
      if not self.remote_port:
        self.remote_port = addr[1]

      # Use the simulator handed to the constructor rather than a module
      # global, so the class does not depend on how the script names it.
      self.simulator.packet_received(self, data)
      return

    data = fd.fd.read(100000)

    # End-of-file on a sender pipe means the sender is done: verify the data.
    if len(data) == 0 and self.executable == SENDER_EXECUTABLE_NAME:
      self.simulator.check_final()

    if self.process.returncode is not None or len(data) == 0:
      die("%s crashed; exiting" % self.executable)

    if self.executable != RECEVIER_EXECUTABLE_NAME or fd.fd != self.process.stdout:
      # Diagnostic text. Decode leniently: the child may print arbitrary
      # bytes, and a read can end in the middle of a multi-byte character;
      # neither should crash the simulator.
      indent = " " * 50 if self.executable == RECEVIER_EXECUTABLE_NAME else ""
      for line in data.decode('utf-8', errors='replace').strip().split("\n"):
        log(self.executable, indent + line)
        m = re.match(r'Bound to port ([0-9]+)', line)
        if m:
          self.remote_port = int(m.group(1))
    else:
      # The receiver's stdout carries the transferred payload itself.
      self.received_data += data

  def send(self, data):
    """Send one datagram to the child on localhost at the learned port."""
    self.socket.sendto(data, ('localhost', self.remote_port))

class Sender(Wrapper):
  """Wrapper for the sender program; feeds it the test data on stdin."""

  def __init__(self, simulator, data):
    """
    Args:
      simulator: Object notified about packets and completion.
      data: Bytes written to the sender's stdin once it has been started.
    """
    super().__init__(SENDER_EXECUTABLE_NAME, simulator)
    self.data = data
    # Background thread that writes `data` to the child; created by start().
    self.thread = None

  def get_args(self):
    """The sender is told the address and port of this wrapper's socket."""
    return ["127.0.0.1", str(self.port)]

  def start(self):
    """Launch the sender and stream the test data to it from a thread."""
    super().start()
    # Capture the pipe now: stop() resets self.process to None, which would
    # otherwise make the writer thread fail with AttributeError if it is
    # still running at that moment.
    stdin = self.process.stdin

    def go():
      try:
        stdin.write(self.data)
        stdin.close()
      except BrokenPipeError:
        log("Simulator", "Pipe to %s broken" % self.executable)

    self.thread = Thread(target=go)
    self.thread.start()

  def stop(self):
    """Terminate the sender and wait for the writer thread, if any."""
    super().stop()
    # stop() can run before start() ever created the thread (for example
    # when the simulation aborts while the receiver is being launched).
    if self.thread is not None:
      self.thread.join()

class Receiver(Wrapper):
  """Wrapper for the receiver program, which is launched without arguments."""

  def __init__(self, simulator):
    super().__init__(RECEVIER_EXECUTABLE_NAME, simulator)

  def get_args(self):
    """Return the (empty) argument list for the receiver."""
    no_arguments = []
    return no_arguments

#### SIMULATOR

class EnqueuedPacket:
  """A packet payload paired with a timestamp chosen by whoever enqueues it."""

  def __init__(self, data, ts):
    self.data, self.ts = data, ts

class Buffer:
  """Bounded FIFO that transmits one packet at a time at a fixed bandwidth.

  Packets wait in `buffer`; the one currently "on the wire" is held in
  `packet_sending` until `busy_until`, after which it is handed back to the
  caller of ready_to_deliver().
  """

  def __init__(self, name, config):
    self.name = name
    self.config = config
    network = self.config["network"]
    self.buffer = []
    self.bandwidth = network["bandwidth"]
    self.buffer_size = network["buffer"]
    self.busy_until = time.time()
    self.packet_sending = None

  def log(self, message, level=2):
    """Log on behalf of this buffer (verbosity 2 unless stated otherwise)."""
    label = "%s Buffer" % self.name
    log(label, message, level)

  def enqueue(self, data):
    """Append a packet, or drop it if the queued bytes would exceed the limit."""
    queued_bytes = sum(len(ep.data) for ep in self.buffer)
    overflow = queued_bytes + len(data) > self.buffer_size
    if not overflow:
      self.buffer.append(EnqueuedPacket(data, time.time()))
      return

    log("Simulator", "Dropping packet due to router queue full")

  def ready_to_deliver(self, start):
    """Return the packet whose transmission has finished by `start`, if any.

    Also begins transmitting the next queued packet when the link is free.

    Returns:
      A list holding zero or one payloads.
    """
    if self.busy_until > start:
      self.log("Still sending packet; nothing to deliver")
      return []

    delivered = []
    finished = self.packet_sending
    if finished is not None:
      delivered.append(finished)
      self.packet_sending = None
      self.log("Delivering packet %s" % finished)

    if self.buffer:
      head = self.buffer.pop(0)
      self.packet_sending = head.data
      # Transmission time is the payload size divided by the bandwidth.
      transmit_time = len(head.data) * 1.0/self.bandwidth
      self.busy_until = start + transmit_time
      self.log("Starting to send packet %s" % self.packet_sending)
      self.log("Will be done in %.4f" % (self.busy_until - time.time()))

    return delivered

  def sleep_time(self):
    """Return how long until the in-flight packet completes (never negative).

    Falls back to DEFAULT_SLEEP when nothing is being transmitted.
    """
    if self.packet_sending is None:
      self.log("No packet being sent, returning default sleep", 3)
      return DEFAULT_SLEEP

    remaining = self.busy_until - time.time()
    self.log("Returning sleep time of %.4f" % remaining)
    if remaining > 0:
      return remaining
    return 0

class Queue:
  """Holds packets until their release time (propagation delay plus jitter)."""

  def __init__(self, name, config):
    self.name = name
    self.config = config
    self.buffer = []
    self.delay = self.config["network"]["delay"]

  def log(self, message, level=2):
    """Log on behalf of this queue (verbosity 2 unless stated otherwise)."""
    label = "%s Queue" % self.name
    log(label, message, level)

  def enqueue(self, data, jitter):
    """Store a packet that becomes due `delay + jitter` seconds from now."""
    release_at = time.time() + self.delay + jitter
    self.buffer.append(EnqueuedPacket(data, release_at))

  def ready_to_move_to_buffer(self, start):
    """Remove and return the payloads of all packets due at or before `start`.

    Payloads are returned in the order the packets were enqueued.
    """
    dequeued = [ep for ep in self.buffer if ep.ts <= start]
    self.buffer = [ep for ep in self.buffer if ep.ts > start]
    self.log("Dequeuing messages: %s" % dequeued, 3)

    return [ep.data for ep in dequeued]

  def sleep_time(self):
    """Return the time until the earliest packet is due (never negative).

    Falls back to DEFAULT_SLEEP when the queue is empty.
    """
    if not self.buffer:
      self.log("Empty buffer, returning default sleep", 3)
      return DEFAULT_SLEEP

    earliest = min(ep.ts for ep in self.buffer)
    result = earliest - time.time()
    self.log("Returning sleep time of %.4f" % result)
    if result > 0:
      return result
    return 0

class Path:
  """One direction of the simulated link: a delay Queue feeding a Buffer."""

  def __init__(self, name, config):
    self.name = name
    self.config = config
    self.queue = Queue(name, config)
    self.buffer = Buffer(name, config)

  def enqueue(self, data, jitter):
    """Inject a packet at the start of the path (the delay stage)."""
    self.queue.enqueue(data, jitter)

  def sleep_time(self):
    """Return the shorter of the two stages' waiting times."""
    queue_wait = self.queue.sleep_time()
    buffer_wait = self.buffer.sleep_time()
    return min(queue_wait, buffer_wait)

  def ready_to_deliver(self, start):
    """Advance due packets into the buffer, then return what it releases."""
    due = self.queue.ready_to_move_to_buffer(start)
    for payload in due:
      self.buffer.enqueue(payload)

    return self.buffer.ready_to_deliver(start)

class Simulator:
  """Drives the sender and receiver programs across an emulated network.

  Packets emitted by either program are (optionally) dropped, mangled,
  duplicated and jittered according to config["network"], then travel
  through a Path before being forwarded to the other program.
  """

  def __init__(self, config):
    """
    Args:
      config: Parsed configuration. The keys read by this class are "data",
        "lifetime" and "network" (optional sub-keys "drop", "duplicate",
        "mangle", "jitter").
    """
    self.config = config
    self.data = self.generate_data(config["data"])
    self.sender = Sender(self, self.data)
    self.receiver = Receiver(self)

    self.s_to_r = Path("S->R", self.config)
    self.r_to_s = Path("R->S", self.config)

  def generate_data(self, length):
    """Build the payload to transfer.

    Whole blocks (a numbered text header followed by 1350 random hex
    characters) are appended until at least `length` bytes exist, so the
    result may be longer than `length`.

    Returns:
      A bytearray containing the generated data.
    """
    data = bytearray()
    i = 0
    while len(data) < length:
      blob = ("----- Block %07d -----" % i) + binascii.b2a_hex(os.urandom(675)).decode('utf-8')
      data += bytearray(blob.encode('utf-8'))
      i += 1

    return data

  def start(self):
    """Launch both programs and run the event loop.

    The loop ends only through die() or check_final(), both of which exit.
    """
    log("Simulator", "Beginning simulation")
    self.receiver.start()
    self.sender.start()

    while True:
      read_fds = self.receiver.get_read_fds() + self.sender.get_read_fds()

      # Wake up no later than the next moment either path has work to do.
      sleep_time = min(self.r_to_s.sleep_time(), self.s_to_r.sleep_time())

      readable, _, _ = select.select(read_fds, [], [], sleep_time)
      tick = time.time()

      for r in readable:
        r.parent.read(r)

      # NOTE(review): nothing actually sleeps here despite the message, and
      # the elif branch looks unreachable because the sender is started
      # unconditionally above -- confirm the intended start-up sequence.
      if not self.receiver.ready():
        log("Simulator", "Sleeping for 100ms to let receiver come up")
      elif not self.sender.started:
        self.sender.start()

      # Read the limit from this instance's own configuration.
      if now() > self.config["lifetime"]:
        die("Simulation time exceeded, and %s did not exit" % SENDER_EXECUTABLE_NAME)

      for data in self.r_to_s.ready_to_deliver(tick):
        self.sender.send(data)

      for data in self.s_to_r.ready_to_deliver(tick):
        self.receiver.send(data)

  def stop(self):
    """Stop both programs under test."""
    self.receiver.stop()
    self.sender.stop()

  def packet_received(self, endpoint, data):
    """Apply the configured impairments to a packet and enqueue it.

    Args:
      endpoint: The Wrapper whose program emitted the packet.
      data: The datagram payload (bytes).
    """
    endpoint.bytes_sent(len(data))

    # Use this instance's configuration rather than a module-level global.
    network = self.config["network"]

    def config_fraction(kind):
      # True with probability network[kind]; False when the key is absent.
      return kind in network and random.uniform(0, 1) < network[kind]

    def drop():
      return config_fraction("drop")

    def duplicate():
      return config_fraction("duplicate")

    def mangle():
      return config_fraction("mangle")

    def jitter():
      if "jitter" in network:
        return random.uniform(-1 * network["jitter"], network["jitter"])
      else:
        return 0

    if len(data) > 1500:
      log("Simulator", "Dropping too-big packet (%d) sent by %s" % (len(data), endpoint))
      return

    if drop():
      log("Simulator", "Dropping packet sent by %s" % endpoint)
      return

    if mangle():
      log("Simulator", "Mangling packet sent by %s" % endpoint)
      tmp = bytearray(data)
      # A zero-length datagram has no byte to overwrite, and
      # random.randint(0, -1) would raise ValueError.
      if tmp:
        for i in range(0, 5):
          tmp[random.randint(0, len(tmp)-1)] = 0x58
      data = bytes(tmp)

    if duplicate():
      log("Simulator", "Duplicating packet sent by %s" % endpoint)
      self.enqueue_packet(endpoint, data, jitter())

    self.enqueue_packet(endpoint, data, jitter())

  def enqueue_packet(self, endpoint, data, jitter):
    """Put a packet on the path leading away from `endpoint`."""
    if endpoint == self.receiver:
      self.r_to_s.enqueue(data, jitter)
    else:
      self.s_to_r.enqueue(data, jitter)

  def check_final(self):
    """Compare sent and received data, print the verdict and stats, and exit.

    Raises:
      SystemExit: always (exit status 0).
    """
    if self.data == self.receiver.received_data:
      print("\nSuccess!  Data was transmitted correctly.")
    else:
      print("\nError -- data was not transmitted correctly.")
      print("Sent:\n%s\n\nReceived:\n%s\n" % (self.data, self.receiver.received_data))

    print("\nStats: %.4f total time, %d bytes/%d packets sent (%d/%d sender -> receiver, %d/%d receiver -> sender)" % (now(), self.sender.bytes + self.receiver.bytes, self.sender.packets + self.receiver.packets, self.sender.bytes, self.sender.packets, self.receiver.bytes, self.receiver.packets))

    sys.exit(0)

#### MAIN PROGRAM

def now():
  """Return the number of seconds elapsed since the script was loaded."""
  return time.time() - start

# Exactly one argument is expected: the path of the JSON config file.
if len(sys.argv) != 2:
  die("Usage: ./run config-file")

# Both programs under test must be present and executable before starting.
get_executable(SENDER_EXECUTABLE_NAME)
get_executable(RECEVIER_EXECUTABLE_NAME)
config = get_config(sys.argv[1])

# An optional seed makes the simulator's random choices repeatable.
if "seed" in config:
  random.seed(config["seed"])

# Build the simulator (this also generates the data to transfer).
simulator = Simulator(config)

try:
  simulator.start()
except KeyboardInterrupt:
  die("Received keyboard interrupt")
except Exception as err:
  traceback.print_exc()
  die("Got exception %s" % err)