package dslab.rmi; import dslab.rmi.channel.EncryptableSocketChannel; import dslab.util.ComponentId; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.net.ServerSocket; import java.net.Socket; import java.util.List; import static dslab.util.Utils.asUnchecked; import static java.util.stream.Collectors.toList; import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class EncryptableSocketChannelTest { @Test public void doesNotEncryptBeforeStartSecure() throws Exception { var socket = mock(Socket.class); when(socket.getInputStream()).thenReturn(mock(InputStream.class)); var out = new ByteArrayOutputStream(); when(socket.getOutputStream()).thenReturn(out); var channel = new EncryptableSocketChannel(socket, new ComponentId("mailbox-univer-ze")); channel.write("hello"); assertEquals(List.of("hello"), listOf(out)); } @Rule public final ErrorCollector collector = new ErrorCollector(); @Test public void decryptsEncryptedTransmissionsCorrectly() throws Exception { var serverSocket = new ServerSocket(1234); var clientSocket = new Socket((String) null, serverSocket.getLocalPort()); var clientChannel = new EncryptableSocketChannel(serverSocket.accept(), new ComponentId("client-arthur")); var serverChannel = new EncryptableSocketChannel(clientSocket, new ComponentId("mailbox-earth-planet")); var clientThread = new Thread(() -> { clientChannel.initiateStartSecure(); asUnchecked(() -> clientChannel.write("hello")); }); clientThread.start(); collector.checkThat("hello", equalTo(serverChannel.readLinesWhileReady())); //cleanup serverSocket.close(); serverChannel.close(); clientChannel.close(); } private List listOf(ByteArrayOutputStream out) { return out.toString().lines().collect(toList()); } }