// Source: DS-Lab/src/test/java/dslab/rmi/EncryptableSocketChannelTest.java
package dslab.rmi;

import dslab.rmi.channel.EncryptableSocketChannel;
import dslab.util.ComponentId;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ErrorCollector;

import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

import static dslab.util.Utils.asUnchecked;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link EncryptableSocketChannel}: plain-text pass-through before the secure
 * handshake, and an encrypted round trip over a real loopback TCP connection.
 */
public class EncryptableSocketChannelTest {

    /** Upper bound for waiting on the writer thread after the reader has returned. */
    private static final long WRITER_JOIN_TIMEOUT_MILLIS = 5_000;

    /** Collects check failures so that cleanup in {@code finally} blocks still runs. */
    @Rule
    public final ErrorCollector collector = new ErrorCollector();

    /**
     * Writing through a channel on which no secure handshake was initiated must put the
     * message on the socket's output stream verbatim.
     */
    @Test
    public void doesNotEncryptBeforeStartSecure() throws Exception {

        var socket = mock(Socket.class);
        when(socket.getInputStream()).thenReturn(mock(InputStream.class));

        var out = new ByteArrayOutputStream();
        when(socket.getOutputStream()).thenReturn(out);

        var channel = new EncryptableSocketChannel(socket, new ComponentId("mailbox-univer-ze"));

        channel.write("hello");
        assertEquals(List.of("hello"), listOf(out));
    }

    /**
     * A client channel that initiates the secure handshake and then writes a message must
     * have that message read back unchanged by the peer (server) channel.
     *
     * <p>The timeout keeps a stuck handshake or blocking read from hanging the build.
     */
    @Test(timeout = 10_000)
    public void decryptsEncryptedTransmissionsCorrectly() throws Exception {

        // Port 0 asks the OS for a free ephemeral port, so the test does not depend on a
        // fixed port (previously 1234) being available. A null host connects to loopback.
        try (var serverSocket = new ServerSocket(0);
             var clientSocket = new Socket((String) null, serverSocket.getLocalPort());
             var acceptedSocket = serverSocket.accept()) {

            // Each channel wraps its own end of the connection: the client channel the
            // connecting socket, the server channel the accepted one.
            var clientChannel = new EncryptableSocketChannel(clientSocket, new ComponentId("client-arthur"));
            var serverChannel = new EncryptableSocketChannel(acceptedSocket, new ComponentId("mailbox-earth-planet"));

            // Failures on the writer thread would otherwise vanish; capture and report them.
            var writerFailure = new AtomicReference<Throwable>();
            var clientThread = new Thread(() -> {
                clientChannel.initiateStartSecure();
                asUnchecked(() -> clientChannel.write("hello"));
            });
            clientThread.setUncaughtExceptionHandler((thread, error) -> writerFailure.set(error));

            clientThread.start();

            try {
                collector.checkThat(serverChannel.readLinesWhileReady(), equalTo("hello"));
            } finally {
                // Let the writer finish before closing the channels underneath it.
                clientThread.join(WRITER_JOIN_TIMEOUT_MILLIS);
                if (clientThread.isAlive()) {
                    collector.addError(new AssertionError("client writer thread did not finish"));
                }
                var failure = writerFailure.get();
                if (failure != null) {
                    collector.addError(failure);
                }

                serverChannel.close();
                clientChannel.close();
            }
        }
    }

    /**
     * Splits everything written to {@code out} into lines.
     *
     * @param out captured output; decoded with the platform default charset
     * @return the lines written, without line terminators
     */
    private static List<String> listOf(ByteArrayOutputStream out) {
        return out.toString().lines().collect(toList());
    }

}