From 2948da66dfedbcdc9f11e24b9e9a533b30b3cb0e Mon Sep 17 00:00:00 2001 From: Mariano Barrios Date: Sun, 5 May 2024 13:32:23 +0200 Subject: [PATCH] Migrate to Java: InteroperabilityTest --- .../tlschannel/InteroperabilityTest.java | 181 ++++++++++++++++++ .../tlschannel/InteroperabilityTest.scala | 162 ---------------- .../tlschannel/helpers/SocketGroups.java | 31 +++ .../helpers/SocketPairFactory.scala | 14 +- 4 files changed, 219 insertions(+), 169 deletions(-) create mode 100644 src/test/scala/tlschannel/InteroperabilityTest.java delete mode 100644 src/test/scala/tlschannel/InteroperabilityTest.scala diff --git a/src/test/scala/tlschannel/InteroperabilityTest.java b/src/test/scala/tlschannel/InteroperabilityTest.java new file mode 100644 index 00000000..33def534 --- /dev/null +++ b/src/test/scala/tlschannel/InteroperabilityTest.java @@ -0,0 +1,181 @@ +package tlschannel; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static tlschannel.util.InteroperabilityUtils.*; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Random; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import scala.Option; +import tlschannel.helpers.*; + +@TestInstance(Lifecycle.PER_CLASS) +public class InteroperabilityTest { + + private final SslContextFactory sslContextFactory = new SslContextFactory(); + private final SocketPairFactory factory = + new SocketPairFactory(sslContextFactory.defaultContext(), SslContextFactory.certificateCommonName()); + + private final Random random = new Random(); + + private final int dataSize = SslContextFactory.tlsMaxDataSize() * 10; + + private final byte[] data = new byte[dataSize]; + + { + random.nextBytes(data); + } + + private final int margin = random.nextInt(100); + + private void writerLoop(Writer writer, boolean renegotiate) { + TestJavaUtil.cannotFail(() -> { + int remaining = dataSize; + while (remaining > 0) { + if (renegotiate) writer.renegotiate(); + int chunkSize = random.nextInt(remaining) + 1; // 1 <= chunkSize <= remaining + writer.write(data, dataSize - remaining, chunkSize); + remaining -= chunkSize; + } + }); + } + + private void readerLoop(Reader reader) { + TestJavaUtil.cannotFail(() -> { + byte[] receivedData = new byte[dataSize + margin]; + int remaining = dataSize; + while (remaining > 0) { + int chunkSize = random.nextInt(remaining + margin) + 1; // 1 <= chunkSize <= remaining + margin + int c = reader.read(receivedData, dataSize - remaining, chunkSize); + assertNotEquals(-1, c, "read must not return -1 when there were bytesProduced remaining"); + assertTrue(c <= remaining); + assertTrue(c > 0, "blocking read must return a positive number"); + remaining -= c; + } + assertEquals(0, remaining); + assertArrayEquals(data, Arrays.copyOfRange(receivedData, 0, dataSize)); + }); + } + + /** Test a half-duplex interaction, with renegotiation before reversing the direction of the flow (as in HTTP) + */ + private void halfDuplexStream(Writer serverWriter, Reader clientReader, Writer clientWriter, Reader serverReader) + throws IOException, InterruptedException { + Thread clientWriterThread = new Thread(() -> writerLoop(clientWriter, true), "client-writer"); + Thread serverWriterThread = new Thread(() -> writerLoop(serverWriter, true), "server-writer"); + Thread clientReaderThread = new Thread(() -> readerLoop(clientReader), "client-reader"); + Thread serverReaderThread = new Thread(() -> readerLoop(serverReader), "server-reader"); + serverReaderThread.start(); + clientWriterThread.start(); + serverReaderThread.join(); + clientWriterThread.join(); + clientReaderThread.start(); + // renegotiate three times, to test idempotency + for (int i = 0; i < 3; i++) { + serverWriter.renegotiate(); + } + serverWriterThread.start(); + clientReaderThread.join(); + serverWriterThread.join(); + serverWriter.close(); + clientWriter.close(); + } + + /** Test a full-duplex interaction, without any renegotiation + */ + private void fullDuplexStream(Writer serverWriter, Reader clientReader, Writer clientWriter, Reader serverReader) + throws IOException, InterruptedException { + Thread clientWriterThread = new Thread(() -> writerLoop(clientWriter, false), "client-writer"); + Thread serverWriterThread = new Thread(() -> writerLoop(serverWriter, false), "server-writer"); + Thread clientReaderThread = new Thread(() -> readerLoop(clientReader), "client-reader"); + Thread serverReaderThread = new Thread(() -> readerLoop(serverReader), "server-reader"); + serverReaderThread.start(); + clientWriterThread.start(); + clientReaderThread.start(); + serverWriterThread.start(); + serverReaderThread.join(); + clientWriterThread.join(); + clientReaderThread.join(); + serverWriterThread.join(); + clientWriter.close(); + serverWriter.close(); + } + + // OLD IO -> OLD IO + + // "old-io -> old-io (half duplex) + @Test + public void testOldToOldHalfDuplex() throws IOException, InterruptedException { + SocketGroups.OldOldSocketPair sockerPair = factory.oldOld(Option.apply(null)); + halfDuplexStream( + new SSLSocketWriter(sockerPair.server), + new SocketReader(sockerPair.client), + new SSLSocketWriter(sockerPair.client), + new SocketReader(sockerPair.server)); + } + + // old-io -> old-io (full duplex) + @Test + public void testOldToOldFullDuplex() throws IOException, InterruptedException { + SocketGroups.OldOldSocketPair sockerPair = factory.oldOld(Option.apply(null)); + fullDuplexStream( + new SSLSocketWriter(sockerPair.server), + new SocketReader(sockerPair.client), + new SSLSocketWriter(sockerPair.client), + new SocketReader(sockerPair.server)); + } + + // NIO -> OLD IO + + // nio -> old-io (half duplex) + @Test + public void testNioToOldHalfDuplex() throws IOException, InterruptedException { + SocketGroups.NioOldSocketPair socketPair = factory.nioOld(Option.apply(null)); + halfDuplexStream( + new SSLSocketWriter(socketPair.server), + new ByteChannelReader(socketPair.client.tls), + new TlsChannelWriter(socketPair.client.tls), + new SocketReader(socketPair.server)); + } + + // nio -> old-io (full duplex) + @Test + public void testNioToOldFullDuplex() throws IOException, InterruptedException { + SocketGroups.NioOldSocketPair socketPair = factory.nioOld(Option.apply(null)); + fullDuplexStream( + new SSLSocketWriter(socketPair.server), + new ByteChannelReader(socketPair.client.tls), + new TlsChannelWriter(socketPair.client.tls), + new SocketReader(socketPair.server)); + } + + // OLD IO -> NIO + + // old-io -> nio (half duplex) + @Test + public void testOldToNioHalfDuplex() throws IOException, InterruptedException { + SocketGroups.OldNioSocketPair socketPair = factory.oldNio(Option.apply(null)); + halfDuplexStream( + new TlsChannelWriter(socketPair.server.tls), + new SocketReader(socketPair.client), + new SSLSocketWriter(socketPair.client), + new ByteChannelReader(socketPair.server.tls)); + } + + // old-io -> nio (full duplex) + @Test + public void testOldToNioFullDuplex() throws IOException, InterruptedException { + SocketGroups.OldNioSocketPair socketPair = factory.oldNio(Option.apply(null)); + fullDuplexStream( + new TlsChannelWriter(socketPair.server.tls), + new SocketReader(socketPair.client), + new SSLSocketWriter(socketPair.client), + new ByteChannelReader(socketPair.server.tls)); + } +} diff --git a/src/test/scala/tlschannel/InteroperabilityTest.scala b/src/test/scala/tlschannel/InteroperabilityTest.scala deleted file mode 100644 index eb84c12a..00000000 --- a/src/test/scala/tlschannel/InteroperabilityTest.scala +++ /dev/null @@ -1,162 +0,0 @@ -package tlschannel - -import scala.util.Random -import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertEquals, assertNotEquals, assertTrue} -import org.junit.jupiter.api.{Test, TestInstance} -import org.junit.jupiter.api.TestInstance.Lifecycle -import tlschannel.helpers.TestUtil -import tlschannel.helpers.SslContextFactory -import tlschannel.helpers.SocketPairFactory -import tlschannel.util.InteroperabilityUtils._ - -@TestInstance(Lifecycle.PER_CLASS) -class InteroperabilityTest { - - val sslContextFactory = new SslContextFactory - val factory = new SocketPairFactory(sslContextFactory.defaultContext, SslContextFactory.certificateCommonName) - - def oldNio() = { - val (client, server) = factory.oldNio(None) - val clientPair = (new SSLSocketWriter(client), new SocketReader(client)) - val serverPair = (new TlsChannelWriter(server.tls), new ByteChannelReader(server.tls)) - (clientPair, serverPair) - } - - def nioOld() = { - val (client, server) = factory.nioOld() - val clientPair = (new TlsChannelWriter(client.tls), new ByteChannelReader(client.tls)) - val serverPair = (new SSLSocketWriter(server), new SocketReader(server)) - (clientPair, serverPair) - } - - def oldOld() = { - val (client, server) = factory.oldOld(None) - val clientPair = (new SSLSocketWriter(client), new SocketReader(client)) - val serverPair = (new SSLSocketWriter(server), new SocketReader(server)) - (clientPair, serverPair) - } - - val dataSize = SslContextFactory.tlsMaxDataSize * 10 - val data = Array.ofDim[Byte](dataSize) - Random.nextBytes(data) - - val margin = Random.nextInt(100) - - def writerLoop(writer: Writer, renegotiate: Boolean = false) = TestUtil.cannotFail { - var remaining = dataSize - while (remaining > 0) { - if (renegotiate) - writer.renegotiate() - val chunkSize = Random.nextInt(remaining) + 1 // 1 <= chunkSize <= remaining - writer.write(data, dataSize - remaining, chunkSize) - remaining -= chunkSize - } - } - - def readerLoop(reader: Reader, idx: Int = 0) = TestUtil.cannotFail { - val receivedData = Array.ofDim[Byte](dataSize + margin) - var remaining = dataSize - while (remaining > 0) { - val chunkSize = Random.nextInt(remaining + margin) + 1 // 1 <= chunkSize <= remaining + margin - val c = reader.read(receivedData, dataSize - remaining, chunkSize) - assertNotEquals(-1, c, "read must not return -1 when there were bytesProduced remaining") - assertTrue(c <= remaining) - assertTrue(c > 0, "blocking read must return a positive number") - remaining -= c - } - assertEquals(0, remaining) - assertArrayEquals(data, receivedData.slice(0, dataSize)) - } - - /** Test a half-duplex interaction, with renegotiation before reversing the direction of the flow (as in HTTP) - */ - def halfDuplexStream( - serverWriter: Writer, - clientReader: Reader, - clientWriter: Writer, - serverReader: Reader - ): Unit = { - val clientWriterThread = new Thread(() => writerLoop(clientWriter, renegotiate = true), "client-writer") - val serverWriterThread = new Thread(() => writerLoop(serverWriter, renegotiate = true), "server-writer") - val clientReaderThread = new Thread(() => readerLoop(clientReader), "client-reader") - val serverReaderThread = new Thread(() => readerLoop(serverReader), "server-reader") - Seq(serverReaderThread, clientWriterThread).foreach(_.start()) - Seq(serverReaderThread, clientWriterThread).foreach(_.join()) - clientReaderThread.start() - // renegotiate three times, to test idempotency - for (_ <- 1 to 3) { - serverWriter.renegotiate() - } - serverWriterThread.start() - Seq(clientReaderThread, serverWriterThread).foreach(_.join()) - serverWriter.close() - clientWriter.close() - } - - /** Test a full-duplex interaction, without any renegotiation - */ - def fullDuplexStream( - serverWriter: Writer, - clientReader: Reader, - clientWriter: Writer, - serverReader: Reader - ): Unit = { - val clientWriterThread = new Thread(() => writerLoop(clientWriter), "client-writer") - val serverWriterThread = new Thread(() => writerLoop(serverWriter), "server-writer") - val clientReaderThread = new Thread(() => readerLoop(clientReader), "client-reader") - val serverReaderThread = new Thread(() => readerLoop(serverReader), "server-reader") - Seq(serverReaderThread, clientWriterThread, clientReaderThread, serverWriterThread).foreach(_.start()) - Seq(serverReaderThread, clientWriterThread, clientReaderThread, serverWriterThread).foreach(_.join()) - clientWriter.close() - serverWriter.close() - } - - // OLD IO -> OLD IO - - // "old-io -> old-io (half duplex) - @Test - def testOldToOldHalfDuplex(): Unit = { - val ((clientWriter, clientReader), (serverWriter, serverReader)) = oldOld() - halfDuplexStream(serverWriter, clientReader, clientWriter, serverReader) - } - - // old-io -> old-io (full duplex) - @Test - def testOldToOldFullDuplex(): Unit = { - val ((clientWriter, clientReader), (serverWriter, serverReader)) = oldOld() - fullDuplexStream(serverWriter, clientReader, clientWriter, serverReader) - } - - // NIO -> OLD IO - - // nio -> old-io (half duplex) - @Test - def testNioToOldHalfDuplex(): Unit = { - val ((clientWriter, clientReader), (serverWriter, serverReader)) = nioOld() - halfDuplexStream(serverWriter, clientReader, clientWriter, serverReader) - } - - // nio -> old-io (full duplex) - @Test - def testNioToOldFullDuplex(): Unit = { - val ((clientWriter, clientReader), (serverWriter, serverReader)) = nioOld() - fullDuplexStream(serverWriter, clientReader, clientWriter, serverReader) - } - - // OLD IO -> NIO - - // old-io -> nio (half duplex) - @Test - def testOldToNioHalfDuplex(): Unit = { - val ((clientWriter, clientReader), (serverWriter, serverReader)) = oldNio() - halfDuplexStream(serverWriter, clientReader, clientWriter, serverReader) - } - - // old-io -> nio (full duplex) - @Test - def testOldToNioFullDuplex(): Unit = { - val ((clientWriter, clientReader), (serverWriter, serverReader)) = oldNio() - fullDuplexStream(serverWriter, clientReader, clientWriter, serverReader) - } - -} diff --git a/src/test/scala/tlschannel/helpers/SocketGroups.java b/src/test/scala/tlschannel/helpers/SocketGroups.java index 4243c480..778a2365 100644 --- a/src/test/scala/tlschannel/helpers/SocketGroups.java +++ b/src/test/scala/tlschannel/helpers/SocketGroups.java @@ -2,11 +2,42 @@ import java.nio.channels.ByteChannel; import java.nio.channels.SocketChannel; +import javax.net.ssl.SSLSocket; import tlschannel.TlsChannel; import tlschannel.async.ExtendedAsynchronousByteChannel; public class SocketGroups { + public static class OldOldSocketPair { + public final SSLSocket client; + public final SSLSocket server; + + public OldOldSocketPair(SSLSocket client, SSLSocket server) { + this.client = client; + this.server = server; + } + } + + public static class OldNioSocketPair { + public final SSLSocket client; + public final SocketGroup server; + + public OldNioSocketPair(SSLSocket client, SocketGroup server) { + this.client = client; + this.server = server; + } + } + + public static class NioOldSocketPair { + public final SocketGroup client; + public final SSLSocket server; + + public NioOldSocketPair(SocketGroup client, SSLSocket server) { + this.client = client; + this.server = server; + } + } + public static class SocketPair { public final SocketGroup client; public final SocketGroup server; diff --git a/src/test/scala/tlschannel/helpers/SocketPairFactory.scala b/src/test/scala/tlschannel/helpers/SocketPairFactory.scala index dd286f3b..2c6e11bf 100644 --- a/src/test/scala/tlschannel/helpers/SocketPairFactory.scala +++ b/src/test/scala/tlschannel/helpers/SocketPairFactory.scala @@ -17,10 +17,10 @@ import javax.net.ssl.SNIServerName import tlschannel._ import tlschannel.async.AsynchronousTlsChannel import tlschannel.async.AsynchronousTlsChannelGroup + import java.util.logging.Logger import scala.jdk.CollectionConverters._ import scala.util.Random - import tlschannel.helpers.SocketGroups._ /** Create pairs of connected sockets (using the loopback interface). Additionally, all the raw (non-encrypted) socket @@ -101,7 +101,7 @@ class SocketPairFactory(val sslContext: SSLContext, val serverName: String) { socket } - def oldOld(cipher: Option[String] = None): (SSLSocket, SSLSocket) = { + def oldOld(cipher: Option[String] = None): OldOldSocketPair = { val serverSocket = createSslServerSocket(cipher) val chosenPort = serverSocket.getLocalPort val client = createSslSocket(cipher, localhost, chosenPort, requestedHost = serverName) @@ -110,10 +110,10 @@ class SocketPairFactory(val sslContext: SSLContext, val serverName: String) { client.setSSLParameters(sslParameters) val server = serverSocket.accept().asInstanceOf[SSLSocket] serverSocket.close() - (client, server) + new OldOldSocketPair(client, server) } - def oldNio(cipher: Option[String] = None): (SSLSocket, SocketGroup) = { + def oldNio(cipher: Option[String] = None): OldNioSocketPair = { val serverSocket = ServerSocketChannel.open() serverSocket.bind(new InetSocketAddress(localhost, 0 /* find free port */ )) val chosenPort = serverSocket.getLocalAddress.asInstanceOf[InetSocketAddress].getPort @@ -127,10 +127,10 @@ class SocketPairFactory(val sslContext: SSLContext, val serverName: String) { .newBuilder(rawServer, nameOpt => sslContextFactory(clientSniHostName, sslContext)(nameOpt)) .withEngineFactory(fixedCipherServerSslEngineFactory(cipher) _) .build() - (client, new SocketGroup(server, server, rawServer)) + new OldNioSocketPair(client, new SocketGroup(server, server, rawServer)) } - def nioOld(cipher: Option[String] = None): (SocketGroup, SSLSocket) = { + def nioOld(cipher: Option[String] = None): NioOldSocketPair = { val serverSocket = createSslServerSocket(cipher) val chosenPort = serverSocket.getLocalPort val address = new InetSocketAddress(localhost, chosenPort) @@ -140,7 +140,7 @@ class SocketPairFactory(val sslContext: SSLContext, val serverName: String) { val client = ClientTlsChannel .newBuilder(rawClient, createClientSslEngine(cipher, chosenPort)) .build() - (new SocketGroup(client, client, rawClient), server) + new NioOldSocketPair(new SocketGroup(client, client, rawClient), server) } def nioNio(