import java.io.*; import java.net.*; import java.util.*; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; public class Controller { private ServerSocket serverSocket; private int cport, R, timeout, rebalance_period; private Index index; private Boolean rebalancing; private ArrayList<Object> currentStores = new ArrayList<>(); private ArrayList<Object> currentRemoves = new ArrayList<>(); public static void main(String[] args) { int cport = Integer.parseInt(args[0]); int R = Integer.parseInt(args[1]); int timeout = Integer.parseInt(args[2]); int rebalance_period = Integer.parseInt(args[3]); try { Controller c = new Controller(cport, R, timeout, rebalance_period); c.start(); } catch (Exception e) { System.out.println("ah fuck: " + e); } } Controller(int cport, int R, int timeout, int rebalance_period) throws IOException { this.cport = cport; this.timeout = timeout; this.R = R; this.rebalance_period = rebalance_period; this.index = new Index(); this.rebalancing = false; ControllerLogger.init(Logger.LoggingType.ON_FILE_AND_TERMINAL); } class Index { public HashMap<String, String> status = new HashMap<>(); public HashMap<Integer, Socket> dstores = new HashMap<>(); public HashMap<String, Integer> fileSizes = new HashMap<>(); public HashMap<String, ArrayList<Integer>> fileLocations = new HashMap<>(); // public HashMap<String, ArrayList<Integer>> fileAcks = new HashMap<>(); public HashMap<String, CountDownLatch> fileAcks = new HashMap<>(); public HashMap<Socket, ArrayList<Integer>> locationsTried = new HashMap<>(); // public HashMap<String, ArrayList<Integer>> fileRemoveAcks = new HashMap<>(); public HashMap<String, CountDownLatch> fileRemoveAcks = new HashMap<>(); public HashMap<Integer, String> listResponses = new HashMap<>(); // public ArrayList<Integer> rebalanceAcks = new ArrayList<>(); public CountDownLatch rebalanceAcks; public String getFileList() { synchronized (this) { ArrayList<String> toSend = new ArrayList<>(); for(String filename : fileLocations.keySet()) { if(status.get(filename).equals("store_complete")) { toSend.add(filename); } } return String.join(" ", toSend); } } public void removeDStore(Integer port) { Socket client = dstores.get(port); dstores.remove(port); ArrayList<String> toRemove = new ArrayList<>(); for(Map.Entry<String, ArrayList<Integer>> entries : fileLocations.entrySet()) { entries.getValue().remove(port); if(entries.getValue().size() <= 0) { toRemove.add(entries.getKey()); } } for(String entry : toRemove) { removeEntry(entry); } try { client.close(); } catch (IOException e) { e.printStackTrace(); } } public void removeEntry(String filename) { status.remove(filename); fileSizes.remove(filename); fileLocations.remove(filename); fileAcks.remove(filename); fileRemoveAcks.remove(filename); } public Integer getFile(String filename) { Random r = new Random(); return fileLocations.get(filename).get(r.nextInt(fileLocations.get(filename).size())); } } private void sendMessage(Socket socket, PrintWriter writer, String message) { ControllerLogger.getInstance().messageSent(socket, message); writer.println(message); } class Rebalances implements Runnable { int rebalance_period; Rebalances(int r) { rebalance_period = r; } @Override public void run() { while(true) { try { Thread.sleep(rebalance_period); while(rebalancing) { Thread.sleep(5); } synchronized (rebalancing) { rebalancing = true; rebalance(); rebalancing = false; } } catch (InterruptedException | IOException e) { e.printStackTrace(); } } } } class HandleConnection implements Runnable { Socket client; HandleConnection(Socket client) { this.client = client; } public void run() { try { while(true) { BufferedReader reader = new BufferedReader(new InputStreamReader(client.getInputStream())); OutputStream output = client.getOutputStream(); PrintWriter writer = new PrintWriter(output, true); String input = reader.readLine(); if(input == null) { client.close(); break; } while(rebalancing) { Thread.sleep(5); } ControllerLogger.getInstance().messageReceived(client, input); String[] args = input.split(" "); if(R > index.dstores.size() && !args[0].equals("JOIN")) { sendMessage(client, writer, Protocol.ERROR_NOT_ENOUGH_DSTORES_TOKEN); continue; } try{ switch (args[0]) { case Protocol.JOIN_TOKEN: int port = Integer.parseInt(args[1]); loadDStore(port, client); return; case Protocol.STORE_TOKEN: Object t = new Object(); synchronized (currentStores) { currentStores.add(t); } try { String filename = args[1]; int fileSize = Integer.parseInt(args[2]); store(client, filename, fileSize); } catch (NumberFormatException e) { System.err.println("Malformed message in store received"); } synchronized (currentStores) { currentStores.remove(t); } break; case Protocol.LOAD_TOKEN: if(index.fileLocations.containsKey(args[1])) { index.locationsTried.put(client, new ArrayList<>(index.fileLocations.get(args[1]))); } load(client, args[1]); break; case Protocol.REMOVE_TOKEN: Object r = new Object(); synchronized (currentRemoves) { currentRemoves.add(r); } remove(client, args[1]); synchronized (currentRemoves) { currentRemoves.remove(r); } break; case Protocol.LIST_TOKEN: sendMessage(client, writer, "LIST " + index.getFileList()); break; case Protocol.RELOAD_TOKEN: if(index.locationsTried.containsKey(client)) { load(client, args[1]); } else { System.err.println("Client attempting to reload before attempting to load"); } break; default: System.err.println("Message malformed, did not understand"); break; } } catch (Exception e) { System.err.println("Error processing message"); System.err.println(e); } } } catch (Exception e) { e.printStackTrace(); } } } public void start() throws IOException { serverSocket = new ServerSocket(cport); Rebalances r = new Rebalances(rebalance_period); new Thread(r).start(); while (true) { Socket client = serverSocket.accept(); HandleConnection con = new HandleConnection(client); new Thread(con).start(); } } private void remove(Socket client, String filename) throws IOException, InterruptedException { OutputStream output = client.getOutputStream(); PrintWriter writer = new PrintWriter(output, true); InputStream in = client.getInputStream(); BufferedReader reader = new BufferedReader(new InputStreamReader(in)); if(!index.fileLocations.containsKey(filename) || index.status.get(filename).equals("store_in_progress") || index.status.get(filename).equals("removal_in_progress")) { sendMessage(client, writer, Protocol.ERROR_FILE_DOES_NOT_EXIST_TOKEN); return; } synchronized (index) { index.status.put(filename, "removal_in_progress"); } index.fileRemoveAcks.put(filename, new CountDownLatch(index.fileLocations.get(filename).size())); for(int port : index.fileLocations.get(filename)) { Socket dStore = index.dstores.get(port); OutputStream dOutput = dStore.getOutputStream(); PrintWriter dWriter = new PrintWriter(dOutput, true); sendMessage(dStore, dWriter, Protocol.REMOVE_TOKEN + " " + filename); } Boolean gotAllAcks = index.fileRemoveAcks.get(filename).await(timeout, TimeUnit.MILLISECONDS); // long start = System.currentTimeMillis(); // while(index.fileRemoveAcks.get(filename).size() < index.fileLocations.get(filename).size() && (System.currentTimeMillis() - start) < timeout) { // Thread.sleep(1); // } // // if(index.fileRemoveAcks.get(filename).size() != index.fileLocations.get(filename).size()) { // System.err.println("ERROR_REMOVE_ACK_NOT_RECEIVED"); // } if(!gotAllAcks) { System.err.println("Not all file remove acks received"); } synchronized (index) { index.status.put(filename, "removal_complete"); index.removeEntry(filename); } sendMessage(client, writer, Protocol.REMOVE_COMPLETE_TOKEN); } private void load(Socket client, String filename) throws IOException, InterruptedException { OutputStream output = client.getOutputStream(); PrintWriter writer = new PrintWriter(output, true); InputStream in = client.getInputStream(); BufferedReader reader = new BufferedReader(new InputStreamReader(in)); if(index.locationsTried.size() == 0) { sendMessage(client, writer, Protocol.ERROR_LOAD_TOKEN); return; } if(!index.fileLocations.containsKey(filename) || index.status.get(filename).equals("store_in_progress") || index.status.get(filename).equals("removal_in_progress")) { sendMessage(client, writer, Protocol.ERROR_FILE_DOES_NOT_EXIST_TOKEN); return; } int port = index.getFile(filename); index.locationsTried.get(client).remove(Integer.valueOf(port)); sendMessage(client, writer, Protocol.LOAD_FROM_TOKEN + " " + port + " " + index.fileSizes.get(filename)); } private ArrayList<Integer> getRSmallestDStores() { HashMap<Integer, Integer> temp = new HashMap<>(); synchronized (index) { for (int port : index.dstores.keySet()) { temp.put(port, 0); } for (ArrayList<Integer> ports : index.fileLocations.values()) { for (int port : ports) { if (temp.containsKey(port)) { temp.put(port, temp.get(port) + 1); } else { temp.put(port, 1); } } } } ArrayList<Integer> ports = new ArrayList<>(); while (ports.size() < R) { int port = Integer.MAX_VALUE; int minValue = Integer.MAX_VALUE; for (Map.Entry<Integer, Integer> entries : temp.entrySet()) { if (entries.getValue() < minValue) { port = entries.getKey(); minValue = entries.getValue(); } } ports.add(port); temp.remove(port); } return ports; } private void store(Socket client, String filename, int fileSize) throws IOException, InterruptedException { OutputStream output = client.getOutputStream(); PrintWriter writer = new PrintWriter(output, true); InputStream in = client.getInputStream(); BufferedReader reader = new BufferedReader(new InputStreamReader(in)); if(index.fileLocations.containsKey(filename) || (index.status.containsKey(filename) && index.status.get(filename).equals("store_in_progress"))) { sendMessage(client, writer, Protocol.ERROR_FILE_ALREADY_EXISTS_TOKEN); return; } ArrayList<Integer> ports; synchronized (index) { index.status.put(filename, "store_in_progress"); ports = getRSmallestDStores(); index.fileLocations.put(filename, ports); index.fileAcks.put(filename, new CountDownLatch(ports.size())); index.fileSizes.put(filename, fileSize); } sendMessage(client, writer, Protocol.STORE_TO_TOKEN + " " + ports.stream().map(Object::toString).collect(Collectors.joining(" "))); Boolean acksReceived = index.fileAcks.get(filename).await(timeout, TimeUnit.MILLISECONDS); if(acksReceived) { synchronized (index) { index.status.put(filename, "store_complete"); } sendMessage(client, writer, Protocol.STORE_COMPLETE_TOKEN); } else { System.err.println("Not all acks received"); synchronized (index) { index.removeEntry(filename); } // long start = System.currentTimeMillis(); // while(index.fileAcks.get(filename).size() < index.fileLocations.get(filename).size() && (System.currentTimeMillis() - start) < timeout) { // Thread.sleep(1); // } // if(index.fileAcks.get(filename).size() < index.fileLocations.get(filename).size()) { // System.err.println("Not all acks received"); // synchronized (index) { // index.removeEntry(filename); // } // } else { // synchronized (index) { // index.status.put(filename, "store_complete"); // } // sendMessage(client, writer, Protocol.STORE_COMPLETE_TOKEN); // } } } private void loadDStore(int port, Socket client) throws IOException, InterruptedException { synchronized (index) { index.dstores.put(port, client); } ControllerLogger.getInstance().dstoreJoined(client, port); OutputStream output = client.getOutputStream(); PrintWriter writer = new PrintWriter(output, true); InputStream in = client.getInputStream(); BufferedReader reader = new BufferedReader(new InputStreamReader(in)); new Thread(() -> { try { while(rebalancing) { Thread.sleep(5); } synchronized (rebalancing) { rebalancing = true; rebalance(); rebalancing = false; } } catch (IOException | InterruptedException e) { e.printStackTrace(); } }).start(); while(true) { String ping = reader.readLine(); if(ping == null) { System.err.println("DStore " + port + " disconnected"); synchronized (index) { index.removeDStore(port); } break; } ControllerLogger.getInstance().messageReceived(client, ping); String[] args = ping.split(" "); switch (args[0]) { case Protocol.STORE_ACK_TOKEN: synchronized (index) { if (index.fileAcks.containsKey(args[1])) { // index.fileAcks.get(args[1]).add(port); index.fileAcks.get(args[1]).countDown(); } } break; case Protocol.REMOVE_ACK_TOKEN: synchronized (index) { if (index.fileRemoveAcks.containsKey(args[1])) { // index.fileRemoveAcks.get(args[1]).add(port); index.fileRemoveAcks.get(args[1]).countDown(); } } break; case Protocol.LIST_TOKEN: index.listResponses.put(port, String.join(" ", Arrays.copyOfRange(args, 1, args.length))); break; case Protocol.REBALANCE_COMPLETE_TOKEN: // index.rebalanceAcks.add(client.getPort()); if (index.rebalanceAcks != null) { index.rebalanceAcks.countDown(); } else { System.err.println("Received rebalance ack without ongoing rebalance"); } break; case Protocol.ERROR_FILE_DOES_NOT_EXIST_TOKEN: if (index.fileRemoveAcks.containsKey(args[1])) { // index.fileRemoveAcks.get(args[1]).add(port); index.fileRemoveAcks.get(args[1]).countDown(); } break; default: System.err.println("Malformed message received"); System.err.println(args[0]); break; } } } private void rebalance() throws IOException, InterruptedException { if(index.dstores.size() < R) { return; } while(currentStores.size() > 0 || currentRemoves.size() > 0) { Thread.sleep(5); } HashMap<String, ArrayList<Integer>> tempLocations = new HashMap<>(); HashMap<String, ArrayList<Integer>> oldLocations = new HashMap<>(index.fileLocations); HashMap<Integer, Integer> tempNoOfFiles = new HashMap<>(); ArrayList<Integer> connectedDstores = new ArrayList<>(); ArrayList<Integer> disconnectedDstores = new ArrayList<>(); synchronized (index.dstores) { index.listResponses = new HashMap<>(); for (Map.Entry<Integer, Socket> entry : index.dstores.entrySet()) { int port = entry.getKey(); Socket client = entry.getValue(); OutputStream output = client.getOutputStream(); PrintWriter writer = new PrintWriter(output, true); InputStream in = client.getInputStream(); BufferedReader reader = new BufferedReader(new InputStreamReader(in)); sendMessage(client, writer, Protocol.LIST_TOKEN); int oldSize = index.listResponses.size(); long start = System.currentTimeMillis(); while (index.listResponses.size() == oldSize && (System.currentTimeMillis() - start) < timeout) { Thread.sleep(5); } if (index.listResponses.size() == oldSize) { System.err.println("DStore disconnected while waiting for LIST response in rebalance"); disconnectedDstores.add(port); continue; } else { connectedDstores.add(port); } String fileListS = index.listResponses.get(port); String[] fileList = fileListS.split(" "); tempNoOfFiles.put(port, fileList.length); if (!fileListS.equals("")) { for (String filename : fileList) { if (!tempLocations.containsKey(filename)) { tempLocations.put(filename, new ArrayList<>()); } tempLocations.get(filename).add(port); } } } } synchronized (index) { for (Integer port : disconnectedDstores) { index.removeDStore(port); } } if(index.dstores.size() < R) { System.err.println("No longer enough Dstores to continue rebalancing"); return; } HashMap<String, ArrayList<Integer>> toRemoveOld = new HashMap<>(); for(String filename : tempLocations.keySet()) { if(!oldLocations.containsKey(filename)) { toRemoveOld.put(filename, tempLocations.get(filename)); } } for(String filename : toRemoveOld.keySet()) { tempLocations.remove(filename); } for (String filename : tempLocations.keySet()) { Integer lD = R - tempLocations.get(filename).size(); if (lD > 0) { tempLocations.get(filename).addAll(getNSmallestDstoresWithoutFile(lD, filename, tempLocations, connectedDstores)); } else if (lD < 0) { tempLocations.get(filename).addAll(getNLargestDstoresWithFile(lD, filename, tempLocations, connectedDstores)); } } HashMap<Integer, ArrayList<String>> dstores = new HashMap<>(); for (Integer port : connectedDstores) { dstores.put(port, new ArrayList<>()); for (Map.Entry<String, ArrayList<Integer>> entries : tempLocations.entrySet()) { if (entries.getValue().contains(port)) { dstores.get(port).add(entries.getKey()); } } } double cap = R * tempLocations.size() / (double) dstores.size(); double maxSize = Math.ceil(cap); double minSize = Math.floor(cap); while (!dstoresWithinLimits(tempLocations, dstores)) { String toMove = null; Integer moveTo = null, moveFrom = null; for (Map.Entry<Integer, ArrayList<String>> entries : dstores.entrySet()) { if (entries.getValue().size() > maxSize) { for (String filename : entries.getValue()) { for (Map.Entry<Integer, ArrayList<String>> entries2 : dstores.entrySet()) { if (entries2.getValue().size() < minSize && !entries2.getValue().contains(filename)) { toMove = filename; moveTo = entries2.getKey(); moveFrom = entries.getKey(); break; } else if (entries2.getValue().size() < maxSize && !entries2.getValue().contains(filename)) { toMove = filename; moveTo = entries2.getKey(); moveFrom = entries.getKey(); break; } } } } if(entries.getValue().size() < minSize) { for (Map.Entry<Integer, ArrayList<String>> entries2 : dstores.entrySet()) { for (String filename : entries2.getValue()) { if(!entries.getValue().contains(filename) && entries2.getValue().size() - 1 >= minSize) { toMove = filename; moveTo = entries.getKey(); moveFrom = entries2.getKey(); break; } } } } } if (moveFrom != null && moveTo != null && toMove != null) { dstores.get(moveFrom).remove(toMove); dstores.get(moveTo).add(toMove); tempLocations.get(toMove).remove(moveFrom); tempLocations.get(toMove).add(moveTo); } } HashMap<String, ArrayList<Integer>> toSend = new HashMap<>(); HashMap<String, ArrayList<Integer>> toRemove = new HashMap<>(toRemoveOld); for (Map.Entry<String, ArrayList<Integer>> entries : tempLocations.entrySet()) { toSend.put(entries.getKey(), new ArrayList<>()); toRemove.put(entries.getKey(), new ArrayList<>()); for (Integer toPort : entries.getValue()) { if(oldLocations.containsKey(entries.getKey())) { if (!oldLocations.get(entries.getKey()).contains(toPort)) { toSend.get(entries.getKey()).add(toPort); } } else { if(tempLocations.get(entries.getKey()).contains(toPort)) { toRemove.get(entries.getKey()).add(toPort); } } } } for (Map.Entry<String, ArrayList<Integer>> entries : oldLocations.entrySet()) { for (Integer fromPort : entries.getValue()) { if (!tempLocations.get(entries.getKey()).contains(fromPort)) { toRemove.get(entries.getKey()).add(fromPort); } } } index.rebalanceAcks = new CountDownLatch(connectedDstores.size()); for (Integer port : connectedDstores) { HashMap<String, ArrayList<Integer>> sendFilesTo = new HashMap<>(); ArrayList<String> removeFilesFrom = new ArrayList<>(); for (Map.Entry<String, ArrayList<Integer>> sendFiles : toSend.entrySet()) { if (oldLocations.containsKey(sendFiles.getKey()) && oldLocations.get(sendFiles.getKey()).contains(port)) { sendFilesTo.put(sendFiles.getKey(), sendFiles.getValue()); } } for (String filename : sendFilesTo.keySet()) { toSend.remove(filename); } for (Map.Entry<String, ArrayList<Integer>> removeFiles : toRemove.entrySet()) { if (removeFiles.getValue().contains(port)) { removeFilesFrom.add(removeFiles.getKey()); } } ArrayList<String> sendString = new ArrayList<>(); for (Map.Entry<String, ArrayList<Integer>> entries : sendFilesTo.entrySet()) { if (entries.getValue().size() > 0) { String sendy = entries.getKey() + " " + entries.getValue().size(); sendy += " " + entries.getValue().stream().map(Object::toString).collect(Collectors.joining(" ")); sendString.add(sendy); } } String files_to_send = sendString.size() + ""; if (sendString.size() > 0) { files_to_send += " " + sendString.stream().map(Object::toString).collect(Collectors.joining(" ")); } String files_to_remove = removeFilesFrom.size() + ""; if (removeFilesFrom.size() > 0) { files_to_remove += " " + String.join(" ", removeFilesFrom); } String message = Protocol.REBALANCE_TOKEN + " " + files_to_send + " " + files_to_remove; if (sendString.size() <= 0 && removeFilesFrom.size() <= 0) { index.rebalanceAcks.countDown(); continue; } Socket client = index.dstores.get(port); OutputStream output = client.getOutputStream(); PrintWriter writer = new PrintWriter(output, true); sendMessage(client, writer, message); // long start = System.currentTimeMillis(); // int startSize = index.rebalanceAcks.size(); // while (index.rebalanceAcks.size() == startSize && (System.currentTimeMillis() - start) < timeout) { // Thread.sleep(5); // } // // if (index.rebalanceAcks.size() == startSize) { // System.err.println("Rebalance on dstore " + port + "unsuccessful"); // continue; // } } Boolean acksReceived = index.rebalanceAcks.await(timeout, TimeUnit.MILLISECONDS); if(acksReceived) { System.out.println("Rebalance successful"); } else { System.err.println("Rebalance unsuccessful, ack not received from " + index.rebalanceAcks.getCount() + " dstores"); } synchronized (index) { index.fileLocations = tempLocations; } } private boolean dstoresWithinLimits(HashMap<String, ArrayList<Integer>> tempLocations, HashMap<Integer, ArrayList<String>> dstores) { int f = tempLocations.size(); int N = dstores.size(); double cap = R * f / (double) N; double maxSize = Math.ceil(cap); double minSize = Math.floor(cap); for(ArrayList<String> stores : dstores.values()) { if(!(maxSize >= stores.size() && stores.size() >= minSize)) { return false; } } return true; } private ArrayList<Integer> getNLargestDstoresWithFile(int N, String filename, HashMap<String, ArrayList<Integer>> tempLocations, ArrayList<Integer> connectedDstores) { HashMap<Integer, Integer> dstoreSizes = new HashMap<>(); ArrayList<Integer> containsFile = new ArrayList<>(); for(Integer port : connectedDstores) { if(tempLocations.get(filename).contains(port)) { containsFile.add(port); } } for (ArrayList<Integer> ports : tempLocations.values()) { for (int port : ports) { if(containsFile.contains(port)) { if (dstoreSizes.containsKey(port)) { dstoreSizes.put(port, dstoreSizes.get(port) + 1); } else { dstoreSizes.put(port, 1); } } } } ArrayList<Integer> ports = new ArrayList<>(); while (ports.size() < N) { int port = Integer.MIN_VALUE; int maxValue = Integer.MIN_VALUE; for (Map.Entry<Integer, Integer> entries : dstoreSizes.entrySet()) { if (entries.getValue() > maxValue) { port = entries.getKey(); maxValue = entries.getValue(); } } ports.add(port); dstoreSizes.remove(port); } return ports; } private ArrayList<Integer> getNSmallestDstoresWithoutFile(int N, String filename, HashMap<String, ArrayList<Integer>> tempLocations, ArrayList<Integer> connectedDstores) { HashMap<Integer, Integer> dstoreSizes = new HashMap<>(); ArrayList<Integer> doesntContainFile = new ArrayList<>(); for(Integer port : connectedDstores) { if(!tempLocations.get(filename).contains(port)) { doesntContainFile.add(port); } } for (ArrayList<Integer> ports : tempLocations.values()) { for (Integer port : ports) { if(doesntContainFile.contains(port)) { if (dstoreSizes.containsKey(port)) { dstoreSizes.put(port, dstoreSizes.get(port) + 1); } else { dstoreSizes.put(port, 1); } } } } for(Integer port : doesntContainFile) { if(!dstoreSizes.containsKey(port)) { dstoreSizes.put(port, 0); } } ArrayList<Integer> ports = new ArrayList<>(); while (ports.size() < N) { int port = Integer.MAX_VALUE; int minValue = Integer.MAX_VALUE; for (Map.Entry<Integer, Integer> entries : dstoreSizes.entrySet()) { if (entries.getValue() < minValue) { port = entries.getKey(); minValue = entries.getValue(); } } ports.add(port); dstoreSizes.remove(port); } return ports; } }