// Source: DS-Lab/src/main/java/dslab/rmi/channel/EncryptableSocketChannel.java
package dslab.rmi.channel;

import dslab.crypto.AES;
import dslab.crypto.ClientChallengeMessage;
import dslab.crypto.RSA;
import dslab.util.ComponentId;

import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.net.Socket;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.function.Function;
import java.util.logging.Logger;

import static dslab.rmi.serialize.dmap.DmapServerSerializer.DmapCommand.startsecure;
import static java.lang.String.join;
import static java.util.function.Function.identity;
import static java.util.logging.Level.SEVERE;

/**
 * Socket channel that can switch from plain-text to AES-encrypted communication.
 *
 * <p>The channel starts out unencrypted. On the server side, when the peer sends the
 * {@code startsecure} command, the channel intercepts it while reading and runs the handshake in
 * {@link #handleStartSecure()}; the command itself is never returned to the caller. On the client
 * side, {@link #initiateStartSecure()} runs the matching handshake. Once a handshake has installed
 * the AES functions, every {@link #write(String)} is encrypted and every read is decrypted.
 *
 * <p>Handshake as implemented by the two methods above:
 * <ol>
 *   <li>client to server: {@code startsecure}</li>
 *   <li>server to client: {@code ok <component-id>}</li>
 *   <li>client to server: RSA-encrypted {@code ok <client-challenge> <aes-key> <iv>} (Base64 fields)</li>
 *   <li>server to client: AES-encrypted {@code ok <client-challenge>}</li>
 *   <li>client to server: AES-encrypted {@code ok}</li>
 * </ol>
 */
public class EncryptableSocketChannel extends SocketChannel {

    private final ComponentId componentId;

    // No-op encryption and decryption until a startsecure handshake installs AES.
    private Function<String, String> encrypt = identity();
    private Function<String, String> decrypt = identity();

    private static final Logger LOG = Logger.getLogger(EncryptableSocketChannel.class.getSimpleName());

    /**
     * Creates an initially unencrypted channel over the given socket.
     *
     * @param socket      the connected socket to communicate over
     * @param componentId identity of this component; sent to the peer during the handshake and
     *                    used to construct the {@code RSA} instance that decrypts the peer's
     *                    challenge message
     * @throws IOException propagated from the {@link SocketChannel} constructor
     */
    public EncryptableSocketChannel(Socket socket, ComponentId componentId) throws IOException {
        super(socket);
        this.componentId = componentId;
    }

    /**
     * Performs the server side of the secure-channel handshake after the peer sent
     * {@code startsecure}.
     *
     * <p>Replies with this component's id, decrypts the peer's RSA-encrypted challenge message,
     * switches the channel to AES, sends back the challenge (already AES-encrypted) and waits for
     * the peer's final {@code ok}.
     *
     * @throws RuntimeException      if the channel closes during the handshake or the peer's final
     *                               answer is not {@code ok}
     * @throws IllegalStateException if the peer's challenge message is malformed
     */
    protected void handleStartSecure() {
        try {
            LOG.info("Encrypting channel " + getId());

            write("ok " + componentId);

            var challengeMessage = deserialize(readLine());
            var aes = new AES(challengeMessage);

            // Switch before replying: the challenge answer must already travel AES-encrypted.
            encrypt = aes::encrypt;
            decrypt = aes::decrypt;

            write("ok " + aes.getBase64ClientChallenge());

            var clientOk = readLine();
            if (!"ok".equals(clientOk)) {
                throw new RuntimeException("Expected client to say 'ok', actual message was '" + clientOk + "'");
            }
        } catch (ClosedChannelException e) {
            throw new RuntimeException("Handling of secure connection request failed", e);
        }
    }

    /**
     * Initiates encryption of the channel by running the client side of the handshake.
     *
     * <p>Generates fresh AES parameters, sends them RSA-encrypted with the key of the component
     * the server identified itself as, switches the channel to AES and verifies that the server
     * echoes the client challenge.
     *
     * @throws RuntimeException      if the channel closes during the handshake
     * @throws IllegalStateException if a server reply is malformed or the challenge is not solved
     */
    public void initiateStartSecure() {
        try {
            write("startsecure");

            var serverComponentId = payloadOfOkReply(readLine(), "component-id");
            var rsa = new RSA(new ComponentId(serverComponentId));
            var aes = new AES();

            write(rsa.encrypt("ok " +
                    aes.getBase64ClientChallenge() + " " +
                    aes.getBase64EncodedKey() + " " +
                    aes.getBase64InitializationVector()));

            // From here on both directions use the AES parameters just sent to the server.
            encrypt = aes::encrypt;
            decrypt = aes::decrypt;

            var solvedClientChallenge = payloadOfOkReply(readLine(), "client-challenge");
            if (!solvedClientChallenge.equals(aes.getBase64ClientChallenge())) {
                throw new IllegalStateException("The server was unable to solve the client challenge");
            }

            write("ok");
        } catch (ClosedChannelException e) {
            throw new RuntimeException("Initialization of secure connection failed", e);
        }
    }

    /**
     * Returns the token following {@code ok} in a handshake reply of the form {@code "ok <payload>"}.
     *
     * @param reply    line received from the peer (already decrypted, if applicable)
     * @param expected name of the expected payload, used only in the error message
     * @return the second space-separated token of {@code reply}
     * @throws IllegalStateException if the reply is {@code null}, does not start with {@code ok},
     *                               or has no payload
     */
    private static String payloadOfOkReply(String reply, String expected) {
        var tokens = reply == null ? new String[0] : reply.split(" ");
        if (tokens.length < 2 || !"ok".equals(tokens[0])) {
            throw new IllegalStateException(
                    "Expected 'ok <" + expected + ">' from peer, actual message was '" + reply + "'");
        }
        return tokens[1];
    }

    /**
     * Decrypts and parses the client's challenge message {@code "ok <challenge> <aes-key> <iv>"}.
     *
     * @param rawChallengeMessage RSA-encrypted message as received from the peer
     * @return the decoded client challenge, AES secret key and initialization vector
     * @throws IllegalStateException    if the decrypted message does not have the expected form
     * @throws IllegalArgumentException if a field is not valid Base64 or the key is empty
     */
    private ClientChallengeMessage deserialize(String rawChallengeMessage) {
        var rsa = new RSA(componentId);
        var tokens = rsa.decrypt(rawChallengeMessage).split(" ");

        // Deliberately do not echo the decrypted message in the error: it carries the AES key.
        if (tokens.length < 4 || !"ok".equals(tokens[0])) {
            throw new IllegalStateException(
                    "Malformed challenge message: expected 'ok <client-challenge> <secret-key> <iv>'");
        }

        var decoder = Base64.getDecoder();
        byte[] clientChallenge = decoder.decode(tokens[1]);
        var secretKey = new SecretKeySpec(decoder.decode(tokens[2]), "AES");
        var iv = new IvParameterSpec(decoder.decode(tokens[3]));

        return new ClientChallengeMessage(clientChallenge, secretKey, iv);
    }

    /**
     * Reads one line, then keeps reading while the underlying reader reports more input is ready,
     * and returns all lines joined with {@code "\n"}.
     *
     * @throws ClosedChannelException if reading fails; the channel is closed first and the
     *                                original {@link IOException} is attached as the cause
     */
    @Override
    public String readLinesWhileReady() throws ClosedChannelException {
        var lines = new ArrayList<String>();

        try {
            do {
                lines.add(readLineInternal());
            } while (in.ready());
        } catch (IOException e) {
            LOG.log(SEVERE, "IOException during conversation with client " + getId() + ". ", e);
            close();
            var closed = new ClosedChannelException();
            closed.initCause(e); // ClosedChannelException has no cause-taking constructor
            throw closed;
        }

        var result = join("\n", lines);
        // NOTE(review): logs decrypted message contents at INFO level; confirm that is acceptable.
        LOG.info(getId() + " said to me: " + result);
        return result;
    }

    /**
     * Reads and decrypts one line, transparently handling a {@code startsecure} request, and
     * logs it.
     */
    @Override
    public String readLine() throws ClosedChannelException {
        var result = readLineInternal();
        LOG.info(getId() + " said to me: " + result);
        return result;
    }

    /**
     * Reads and decrypts one line without logging it.
     *
     * <p>If the decrypted line is the {@code startsecure} command, the handshake is performed and
     * the next line is returned instead, so callers never see the command.
     */
    private String readLineInternal() throws ClosedChannelException {
        var line = decrypt.apply(super.readLine());
        if (startsecure.name().equals(line)) {
            handleStartSecure();
            // Use the non-logging variant: the public caller logs the returned line exactly once.
            return readLineInternal();
        }
        return line;
    }

    /**
     * Logs the plain-text message and writes it, encrypted if a handshake has completed.
     */
    @Override
    public void write(String string) throws ClosedChannelException {
        // Same connection identifier as the read-side log lines.
        LOG.info("I said to " + getId() + ": " + string);
        super.write(encrypt.apply(string));
    }
}