package dslab.rmi.channel; import dslab.crypto.AES; import dslab.crypto.ClientChallengeMessage; import dslab.crypto.RSA; import dslab.util.ComponentId; import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.SecretKeySpec; import java.io.IOException; import java.net.Socket; import java.nio.channels.ClosedChannelException; import java.util.ArrayList; import java.util.Base64; import java.util.function.Function; import java.util.logging.Logger; import static dslab.rmi.serialize.dmap.DmapServerSerializer.DmapCommand.startsecure; import static java.lang.String.join; import static java.util.function.Function.identity; import static java.util.logging.Level.SEVERE; /** * Allows for encrypted communication between client and server. The channel starts out unencrypted. * If the client sends the 'startsecure' command, the channel intercepts it and negotiates the parameters * for AES encryption with the client, then continues to operate in encrypted mode until closed. */ public class EncryptableSocketChannel extends SocketChannel { private final ComponentId componentId; //no-op encryption and decryption as default values private Function<String,String> encrypt = identity(); private Function<String,String> decrypt = identity(); private static final Logger LOG = Logger.getLogger(EncryptableSocketChannel.class.getSimpleName()); public EncryptableSocketChannel(Socket socket, ComponentId componentId) throws IOException { super(socket); this.componentId = componentId; } /** * Handles the setup of the encrypted channel if initiated by the client */ protected void handleStartSecure() { try { LOG.info("Encrypting channel " + getId()); write("ok " + componentId); var challengeMessage = deserialize(readLine()); var aes = new AES(challengeMessage); encrypt = aes::encrypt; decrypt = aes::decrypt; write("ok " + aes.getBase64ClientChallenge()); var clientOk = readLine(); if (!clientOk.equals("ok")) throw new RuntimeException("Expected client to say 'ok', actual message was '" + clientOk + "'"); } catch (ClosedChannelException e) { throw new RuntimeException(e); } } /** * Initiate encryption of the channel */ public void initiateStartSecure() { try { write("startsecure"); var ok_componentId = readLine(); var componentId = ok_componentId.split(" ")[1]; var rsa = new RSA(new ComponentId(componentId)); var aes = new AES(); write(rsa.encrypt("ok " + aes.getBase64ClientChallenge() + " " + aes.getBase64EncodedKey() + " " + aes.getBase64InitializationVector())); encrypt = aes::encrypt; decrypt = aes::decrypt; var solvedClientChallenge = readLine().split(" ")[1]; if (!solvedClientChallenge.equals(aes.getBase64ClientChallenge())) throw new IllegalStateException("The server was unable to solve the client challenge"); write("ok"); } catch (ClosedChannelException e) { throw new RuntimeException("Initialization of secure connection failed", e); } } private ClientChallengeMessage deserialize(String rawChallengeMessage) { var rsa = new RSA(componentId); var decrypted = rsa.decrypt(rawChallengeMessage); var array = decrypted.split(" "); var secretKeyString = Base64.getDecoder().decode(array[2]); byte[] clientChallenge = Base64.getDecoder().decode(array[1]); var secretKey = new SecretKeySpec(secretKeyString, 0, secretKeyString.length, "AES"); var iv = new IvParameterSpec(Base64.getDecoder().decode(array[3])); return new ClientChallengeMessage(clientChallenge, secretKey, iv); } @Override public String readLinesWhileReady() throws ClosedChannelException { var buffer = new ArrayList<String>(); try { do buffer.add(this.readLineInternal()); while (in.ready()); } catch (IOException e) { LOG.log(SEVERE, "IOException during conversation with client " + getId() + ". ", e); close(); throw new ClosedChannelException(); } var result = join("\n",buffer); LOG.info(getId() + " said to me: " + result); return result; } @Override public String readLine() throws ClosedChannelException { var result = readLineInternal(); LOG.info(getId() + " said to me: " + result); return result; } private String readLineInternal() throws ClosedChannelException { String result = decrypt.apply(super.readLine()); if (startsecure.name().equals(result)) { handleStartSecure(); return readLine(); } return result; } @Override public void write(String string) throws ClosedChannelException { LOG.info("I said to " + socket.hashCode() + ": " + string); super.write(encrypt.apply(string)); } }