diff --git a/src/test/scala/tlschannel/AllocationTest.java b/src/test/scala/tlschannel/AllocationTest.java index 8ddaa2de..5a3d7d24 100644 --- a/src/test/scala/tlschannel/AllocationTest.java +++ b/src/test/scala/tlschannel/AllocationTest.java @@ -2,7 +2,7 @@ import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; -import scala.Option; +import java.util.Optional; import tlschannel.helpers.Loops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -24,12 +24,9 @@ public static void main(String[] args) { MemoryMXBean memoryBean = ManagementFactory.getMemoryMXBean(); - SocketPair socketPair1 = - factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); - SocketPair socketPair2 = - factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); - SocketPair socketPair3 = - factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); + SocketPair socketPair1 = factory.nioNio(Optional.empty(), Optional.empty(), true, false, Optional.empty()); + SocketPair socketPair2 = factory.nioNio(Optional.empty(), Optional.empty(), true, false, Optional.empty()); + SocketPair socketPair3 = factory.nioNio(Optional.empty(), Optional.empty(), true, false, Optional.empty()); // do a "warm-up" loop, in order to not count anything statically allocated Loops.halfDuplex(socketPair1, 10000, false, false); diff --git a/src/test/scala/tlschannel/BlockingTest.java b/src/test/scala/tlschannel/BlockingTest.java index de0b2cd3..b3708e16 100644 --- a/src/test/scala/tlschannel/BlockingTest.java +++ b/src/test/scala/tlschannel/BlockingTest.java @@ -6,7 +6,6 @@ import org.junit.jupiter.api.TestFactory; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; import tlschannel.helpers.Loops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -37,13 +36,13 @@ public Collection testHalfDuplexWireRenegotiations() { ret.add(DynamicTest.dynamicTest( String.format("testHalfDuplexWireRenegotiations() - size1=%d, size2=%d", size1, size2), () -> { SocketPair socketPair = factory.nioNio( - Option.apply(null), - Option.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(Optional.of(size1), Optional.of(size2)), new ChuckSizes(Optional.of(size1), Optional.of(size2)))), true, false, - Option.apply(null)); + Optional.empty()); Loops.halfDuplex(socketPair, dataSize, true, false); System.out.printf("%5d -eng-> %5d -net-> %5d -eng-> %5d\n", size1, size2, size1, size2); })); @@ -65,13 +64,13 @@ public Collection testFullDuplex() { ret.add(DynamicTest.dynamicTest( String.format("testFullDuplex() - size1=%d, size2=%d", size1, size2), () -> { SocketPair socketPair = factory.nioNio( - Option.apply(null), - Option.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(Optional.of(size1), Optional.of(size2)), new ChuckSizes(Optional.of(size1), Optional.of(size2)))), true, false, - Option.apply(null)); + Optional.empty()); Loops.fullDuplex(socketPair, dataSize); System.out.printf("%5d -eng-> %5d -net-> %5d -eng-> %5d\n", size1, size2, size1, size2); })); diff --git a/src/test/scala/tlschannel/CipherTest.java b/src/test/scala/tlschannel/CipherTest.java index 55b2dbb8..0f7b8f7e 100644 --- a/src/test/scala/tlschannel/CipherTest.java +++ b/src/test/scala/tlschannel/CipherTest.java @@ -1,18 +1,13 @@ package tlschannel; import java.security.NoSuchAlgorithmException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; +import java.util.*; import java.util.stream.Collectors; import javax.net.ssl.SSLContext; import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.TestFactory; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; -import scala.Some; import tlschannel.helpers.Loops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -49,7 +44,7 @@ public Collection testHalfDuplexWithRenegotiation() { () -> { SocketPairFactory socketFactory = new SocketPairFactory(ctxFactory.defaultContext()); SocketPair socketPair = socketFactory.nioNio( - Some.apply(cipher), Option.apply(null), true, false, Option.apply(null)); + Optional.of(cipher), Optional.empty(), true, false, Optional.empty()); Loops.halfDuplex(socketPair, dataSize, protocol.compareTo("TLSv1.2") < 0, false); String actualProtocol = socketPair .client @@ -76,7 +71,7 @@ public Collection testFullDuplex() { String.format("testFullDuplex() - protocol: %s, cipher: %s", protocol, cipher), () -> { SocketPairFactory socketFactory = new SocketPairFactory(ctxFactory.defaultContext()); SocketPair socketPair = socketFactory.nioNio( - Some.apply(cipher), Option.apply(null), true, false, Option.apply(null)); + Optional.of(cipher), Optional.empty(), true, false, Optional.empty()); Loops.fullDuplex(socketPair, dataSize); String actualProtocol = socketPair .client diff --git a/src/test/scala/tlschannel/CloseTest.java b/src/test/scala/tlschannel/CloseTest.java index f07ef8c0..d4c2a952 100644 --- a/src/test/scala/tlschannel/CloseTest.java +++ b/src/test/scala/tlschannel/CloseTest.java @@ -14,8 +14,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; -import scala.Some; import tlschannel.helpers.*; import tlschannel.helpers.SocketPairFactory.ChuckSizes; import tlschannel.helpers.SocketPairFactory.ChunkSizeConfig; @@ -36,13 +34,13 @@ public class CloseTest { @Test void testTcpImmediateClose() throws InterruptedException, IOException { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Some.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(internalBufferSize, Optional.empty()), new ChuckSizes(internalBufferSize, Optional.empty()))), true, false, - Option.apply(null)); + Optional.empty()); SocketGroups.SocketGroup clientGroup = socketPair.client; SocketGroups.SocketGroup serverGroup = socketPair.server; ByteChannel client = clientGroup.external; @@ -75,13 +73,13 @@ void testTcpImmediateClose() throws InterruptedException, IOException { @Test void testTcpClose() throws InterruptedException, IOException { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Some.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(internalBufferSize, Optional.empty()), new ChuckSizes(internalBufferSize, Optional.empty()))), true, false, - Option.apply(null)); + Optional.empty()); SocketGroups.SocketGroup clientGroup = socketPair.client; SocketGroups.SocketGroup serverGroup = socketPair.server; ByteChannel client = clientGroup.external; @@ -119,13 +117,13 @@ void testTcpClose() throws InterruptedException, IOException { @Test void testClose() throws InterruptedException { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Some.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(internalBufferSize, Optional.empty()), new ChuckSizes(internalBufferSize, Optional.empty()))), true, false, - Option.apply(null)); + Optional.empty()); SocketGroups.SocketGroup clientGroup = socketPair.client; SocketGroups.SocketGroup serverGroup = socketPair.server; ByteChannel client = clientGroup.external; @@ -163,13 +161,13 @@ void testClose() throws InterruptedException { @Test void testCloseAndWait() throws InterruptedException { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Some.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(internalBufferSize, Optional.empty()), new ChuckSizes(internalBufferSize, Optional.empty()))), true, true, - Option.apply(null)); + Optional.empty()); SocketGroups.SocketGroup clientGroup = socketPair.client; SocketGroups.SocketGroup serverGroup = socketPair.server; ByteChannel client = clientGroup.external; @@ -207,13 +205,13 @@ void testCloseAndWait() throws InterruptedException { @Test void testCloseAndWaitForever() throws IOException, InterruptedException { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Some.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(internalBufferSize, Optional.empty()), new ChuckSizes(internalBufferSize, Optional.empty()))), true, true, - Option.apply(null)); + Optional.empty()); SocketGroups.SocketGroup clientGroup = socketPair.client; SocketGroups.SocketGroup serverGroup = socketPair.server; ByteChannel client = clientGroup.external; @@ -249,13 +247,13 @@ void testCloseAndWaitForever() throws IOException, InterruptedException { @Test void testShutdownAndForget() throws InterruptedException, IOException { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Some.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(internalBufferSize, Optional.empty()), new ChuckSizes(internalBufferSize, Optional.empty()))), true, false, - Option.apply(null)); + Optional.empty()); SocketGroups.SocketGroup clientGroup = socketPair.client; SocketGroups.SocketGroup serverGroup = socketPair.server; ByteChannel client = clientGroup.external; @@ -291,13 +289,13 @@ void testShutdownAndForget() throws InterruptedException, IOException { @Test void testShutdownAndWait() throws IOException, InterruptedException { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Some.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(internalBufferSize, Optional.empty()), new ChuckSizes(internalBufferSize, Optional.empty()))), true, false, - Option.apply(null)); + Optional.empty()); SocketGroups.SocketGroup clientGroup = socketPair.client; SocketGroups.SocketGroup serverGroup = socketPair.server; ByteChannel client = clientGroup.external; @@ -343,13 +341,13 @@ void testShutdownAndWait() throws IOException, InterruptedException { @Test void testShutdownAndWaitForever() throws InterruptedException, IOException { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Some.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(internalBufferSize, Optional.empty()), new ChuckSizes(internalBufferSize, Optional.empty()))), true, false, - Option.apply(null)); + Optional.empty()); SocketGroups.SocketGroup clientGroup = socketPair.client; SocketGroups.SocketGroup serverGroup = socketPair.server; ByteChannel client = clientGroup.external; diff --git a/src/test/scala/tlschannel/ConcurrentTest.java b/src/test/scala/tlschannel/ConcurrentTest.java index 32be69fa..060164bf 100644 --- a/src/test/scala/tlschannel/ConcurrentTest.java +++ b/src/test/scala/tlschannel/ConcurrentTest.java @@ -7,13 +7,13 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Logger; import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; import tlschannel.helpers.*; @TestInstance(Lifecycle.PER_CLASS) @@ -31,7 +31,7 @@ public class ConcurrentTest { // write-side thread safety @Test public void testWriteSide() throws IOException { - SocketPair socketPair = factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); + SocketPair socketPair = factory.nioNio(Optional.empty(), Optional.empty(), true, false, Optional.empty()); Thread clientWriterThread1 = new Thread(() -> writerLoop(dataSize, 'a', socketPair.client), "client-writer-1"); Thread clientWriterThread2 = new Thread(() -> writerLoop(dataSize, 'b', socketPair.client), "client-writer-2"); Thread clientWriterThread3 = new Thread(() -> writerLoop(dataSize, 'c', socketPair.client), "client-writer-3"); @@ -54,7 +54,7 @@ public void testWriteSide() throws IOException { // read-size thread-safety @Test public void testReadSide() throws IOException { - SocketPair socketPair = factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); + SocketPair socketPair = factory.nioNio(Optional.empty(), Optional.empty(), true, false, Optional.empty()); Thread clientWriterThread = new Thread(() -> writerLoop(dataSize, 'a', socketPair.client), "client-writer"); AtomicLong totalRead = new AtomicLong(); Thread serverReaderThread1 = diff --git a/src/test/scala/tlschannel/FailTest.java b/src/test/scala/tlschannel/FailTest.java index 16f73837..55f10359 100644 --- a/src/test/scala/tlschannel/FailTest.java +++ b/src/test/scala/tlschannel/FailTest.java @@ -8,11 +8,11 @@ import java.nio.ByteBuffer; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; +import java.util.Optional; import javax.net.ssl.SSLException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; import tlschannel.helpers.TestJavaUtil; @@ -26,17 +26,17 @@ public class FailTest { @Test public void testPlanToTls() throws IOException, InterruptedException { ServerSocketChannel serverSocket = ServerSocketChannel.open(); - serverSocket.bind(new InetSocketAddress(factory.localhost(), 0 /* find free port */)); + serverSocket.bind(new InetSocketAddress(factory.localhost, 0 /* find free port */)); int chosenPort = ((InetSocketAddress) serverSocket.getLocalAddress()).getPort(); - InetSocketAddress address = new InetSocketAddress(factory.localhost(), chosenPort); + InetSocketAddress address = new InetSocketAddress(factory.localhost, chosenPort); SocketChannel clientChannel = SocketChannel.open(address); SocketChannel rawServer = serverSocket.accept(); - factory.createClientSslEngine(Option.empty(), chosenPort); + factory.createClientSslEngine(Optional.empty(), chosenPort); ServerTlsChannel.Builder serverChannelBuilder = ServerTlsChannel.newBuilder( rawServer, - nameOpt -> - factory.sslContextFactory(factory.clientSniHostName(), factory.sslContext(), nameOpt)) - .withEngineFactory(sslContext -> factory.fixedCipherServerSslEngineFactory(Option.empty(), sslContext)); + nameOpt -> factory.sslContextFactory(factory.clientSniHostName, factory.sslContext, nameOpt)) + .withEngineFactory( + sslContext -> factory.fixedCipherServerSslEngineFactory(Optional.empty(), sslContext)); ServerTlsChannel serverChannel = serverChannelBuilder.build(); diff --git a/src/test/scala/tlschannel/InteroperabilityTest.java b/src/test/scala/tlschannel/InteroperabilityTest.java index 3d97ec32..0437fd0c 100644 --- a/src/test/scala/tlschannel/InteroperabilityTest.java +++ b/src/test/scala/tlschannel/InteroperabilityTest.java @@ -8,11 +8,11 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Optional; import java.util.Random; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; import tlschannel.helpers.*; @TestInstance(Lifecycle.PER_CLASS) @@ -112,7 +112,7 @@ private void fullDuplexStream(Writer serverWriter, Reader clientReader, Writer c // "old-io -> old-io (half duplex) @Test public void testOldToOldHalfDuplex() throws IOException, InterruptedException { - SocketGroups.OldOldSocketPair sockerPair = factory.oldOld(Option.apply(null)); + SocketGroups.OldOldSocketPair sockerPair = factory.oldOld(Optional.empty()); halfDuplexStream( new SSLSocketWriter(sockerPair.server), new SocketReader(sockerPair.client), @@ -123,7 +123,7 @@ public void testOldToOldHalfDuplex() throws IOException, InterruptedException { // old-io -> old-io (full duplex) @Test public void testOldToOldFullDuplex() throws IOException, InterruptedException { - SocketGroups.OldOldSocketPair sockerPair = factory.oldOld(Option.apply(null)); + SocketGroups.OldOldSocketPair sockerPair = factory.oldOld(Optional.empty()); fullDuplexStream( new SSLSocketWriter(sockerPair.server), new SocketReader(sockerPair.client), @@ -136,7 +136,7 @@ public void testOldToOldFullDuplex() throws IOException, InterruptedException { // nio -> old-io (half duplex) @Test public void testNioToOldHalfDuplex() throws IOException, InterruptedException { - SocketGroups.NioOldSocketPair socketPair = factory.nioOld(Option.apply(null)); + SocketGroups.NioOldSocketPair socketPair = factory.nioOld(Optional.empty()); halfDuplexStream( new SSLSocketWriter(socketPair.server), new ByteChannelReader(socketPair.client.tls), @@ -147,7 +147,7 @@ public void testNioToOldHalfDuplex() throws IOException, InterruptedException { // nio -> old-io (full duplex) @Test public void testNioToOldFullDuplex() throws IOException, InterruptedException { - SocketGroups.NioOldSocketPair socketPair = factory.nioOld(Option.apply(null)); + SocketGroups.NioOldSocketPair socketPair = factory.nioOld(Optional.empty()); fullDuplexStream( new SSLSocketWriter(socketPair.server), new ByteChannelReader(socketPair.client.tls), @@ -160,7 +160,7 @@ public void testNioToOldFullDuplex() throws IOException, InterruptedException { // old-io -> nio (half duplex) @Test public void testOldToNioHalfDuplex() throws IOException, InterruptedException { - SocketGroups.OldNioSocketPair socketPair = factory.oldNio(Option.apply(null)); + SocketGroups.OldNioSocketPair socketPair = factory.oldNio(Optional.empty()); halfDuplexStream( new TlsChannelWriter(socketPair.server.tls), new SocketReader(socketPair.client), @@ -171,7 +171,7 @@ public void testOldToNioHalfDuplex() throws IOException, InterruptedException { // old-io -> nio (full duplex) @Test public void testOldToNioFullDuplex() throws IOException, InterruptedException { - SocketGroups.OldNioSocketPair socketPair = factory.oldNio(Option.apply(null)); + SocketGroups.OldNioSocketPair socketPair = factory.oldNio(Optional.empty()); fullDuplexStream( new TlsChannelWriter(socketPair.server.tls), new SocketReader(socketPair.client), diff --git a/src/test/scala/tlschannel/MultiNonBlockingTest.java b/src/test/scala/tlschannel/MultiNonBlockingTest.java index 63efc9d5..b1f7d543 100644 --- a/src/test/scala/tlschannel/MultiNonBlockingTest.java +++ b/src/test/scala/tlschannel/MultiNonBlockingTest.java @@ -3,12 +3,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; -import scala.jdk.CollectionConverters; import tlschannel.helpers.NonBlockingLoops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -26,9 +25,8 @@ public class MultiNonBlockingTest { @Test public void testTaskLoop() { System.out.println("testTasksInExecutorWithRenegotiation():"); - List pairs = CollectionConverters.SeqHasAsJava(factory.nioNioN( - Option.apply(null), totalConnections, Option.apply(null), true, false, Option.apply(null))) - .asJava(); + List pairs = + factory.nioNioN(Optional.empty(), totalConnections, Optional.empty(), true, false, Optional.empty()); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, false); assertEquals(0, report.asyncTasksRun); report.print(); @@ -38,9 +36,8 @@ public void testTaskLoop() { @Test public void testTasksInExecutor() { System.out.println("testTasksInExecutorWithRenegotiation():"); - List pairs = CollectionConverters.SeqHasAsJava(factory.nioNioN( - Option.apply(null), totalConnections, Option.apply(null), false, false, Option.apply(null))) - .asJava(); + List pairs = + factory.nioNioN(Optional.empty(), totalConnections, Optional.empty(), false, false, Optional.empty()); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, false); report.print(); } @@ -49,9 +46,8 @@ public void testTasksInExecutor() { @Test public void testTasksInLoopWithRenegotiation() { System.out.println("testTasksInExecutorWithRenegotiation():"); - List pairs = CollectionConverters.SeqHasAsJava(factory.nioNioN( - Option.apply(null), totalConnections, Option.apply(null), true, false, Option.apply(null))) - .asJava(); + List pairs = + factory.nioNioN(Optional.empty(), totalConnections, Optional.empty(), true, false, Optional.empty()); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, true); assertEquals(0, report.asyncTasksRun); report.print(); @@ -61,9 +57,8 @@ public void testTasksInLoopWithRenegotiation() { @Test public void testTasksInExecutorWithRenegotiation() { System.out.println("testTasksInExecutorWithRenegotiation():"); - List pairs = CollectionConverters.SeqHasAsJava(factory.nioNioN( - Option.apply(null), totalConnections, Option.apply(null), false, false, Option.apply(null))) - .asJava(); + List pairs = + factory.nioNioN(Optional.empty(), totalConnections, Optional.empty(), false, false, Optional.empty()); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, true); report.print(); } diff --git a/src/test/scala/tlschannel/NonBlockingTest.java b/src/test/scala/tlschannel/NonBlockingTest.java index ebf95982..3c4c011e 100644 --- a/src/test/scala/tlschannel/NonBlockingTest.java +++ b/src/test/scala/tlschannel/NonBlockingTest.java @@ -6,8 +6,6 @@ import org.junit.jupiter.api.TestFactory; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; -import scala.Some; import tlschannel.helpers.NonBlockingLoops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -37,13 +35,13 @@ public Collection testSelectorLoop() { ret.add(DynamicTest.dynamicTest( String.format("testSelectorLoop() - size1=%d, size2=%d", size1, size2), () -> { SocketPair socketPair = factory.nioNio( - Option.apply(null), - Some.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(Optional.of(size1), Optional.of(size2)), new ChuckSizes(Optional.of(size1), Optional.of(size2)))), true, false, - Option.apply(null)); + Optional.empty()); NonBlockingLoops.Report report = NonBlockingLoops.loop(Collections.singletonList(socketPair), dataSize, true); diff --git a/src/test/scala/tlschannel/NullEngineTest.java b/src/test/scala/tlschannel/NullEngineTest.java index 1018c16f..30d5ce5d 100644 --- a/src/test/scala/tlschannel/NullEngineTest.java +++ b/src/test/scala/tlschannel/NullEngineTest.java @@ -11,8 +11,6 @@ import org.junit.jupiter.api.TestFactory; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; -import scala.Some; import tlschannel.helpers.Loops; import tlschannel.helpers.SocketGroups; import tlschannel.helpers.SocketPairFactory; @@ -48,12 +46,12 @@ public Collection testHalfDuplexHeapBuffers() { DynamicTest test = DynamicTest.dynamicTest(String.format("Testing sizes: size1=%s", size1), () -> { SocketGroups.SocketPair socketPair = factory.nioNio( null, - Some.apply(new ChunkSizeConfig( + Optional.of(new ChunkSizeConfig( new ChuckSizes(Optional.of(size1), Optional.empty()), new ChuckSizes(Optional.of(size1), Optional.empty()))), true, false, - Option.apply(null)); + Optional.empty()); Loops.halfDuplex(socketPair, dataSize, false, false); System.out.printf("-eng-> %5d -net-> %5d -eng->\n", size1, size1); }); @@ -73,12 +71,12 @@ public Collection testHalfDuplexDirectBuffers() { DynamicTest test = DynamicTest.dynamicTest(String.format("Testing sizes: size1=%s", size1), () -> { SocketGroups.SocketPair socketPair = factory.nioNio( null, - Some.apply(new ChunkSizeConfig( + Optional.of(new ChunkSizeConfig( new ChuckSizes(Optional.of(size1), Optional.empty()), new ChuckSizes(Optional.of(size1), Optional.empty()))), true, false, - Option.apply(null)); + Optional.empty()); Loops.halfDuplex(socketPair, dataSize, false, false); System.out.printf("-eng-> %5d -net-> %5d -eng->\n", size1, size1); }); diff --git a/src/test/scala/tlschannel/NullMultiNonBlockingTest.java b/src/test/scala/tlschannel/NullMultiNonBlockingTest.java index 1cef9716..7b590df0 100644 --- a/src/test/scala/tlschannel/NullMultiNonBlockingTest.java +++ b/src/test/scala/tlschannel/NullMultiNonBlockingTest.java @@ -1,13 +1,12 @@ package tlschannel; import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; -import scala.jdk.CollectionConverters; import tlschannel.helpers.NonBlockingLoops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -26,9 +25,8 @@ public class NullMultiNonBlockingTest { @Test public void testRunTasksInNonBlockingLoop() { - List pairs = CollectionConverters.SeqHasAsJava( - factory.nioNioN(null, totalConnections, Option.apply(null), true, false, Option.apply(null))) - .asJava(); + List pairs = + factory.nioNioN(null, totalConnections, Optional.empty(), true, false, Optional.empty()); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, false); Assertions.assertEquals(0, report.asyncTasksRun); } diff --git a/src/test/scala/tlschannel/ScatteringTest.java b/src/test/scala/tlschannel/ScatteringTest.java index 0f523512..c10c9cd3 100644 --- a/src/test/scala/tlschannel/ScatteringTest.java +++ b/src/test/scala/tlschannel/ScatteringTest.java @@ -1,9 +1,9 @@ package tlschannel; +import java.util.Optional; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; import tlschannel.helpers.Loops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -19,7 +19,7 @@ public class ScatteringTest { @Test public void testHalfDuplex() { - SocketPair socketPair = factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); + SocketPair socketPair = factory.nioNio(Optional.empty(), Optional.empty(), true, false, Optional.empty()); Loops.halfDuplex(socketPair, dataSize, true, false); } } diff --git a/src/test/scala/tlschannel/async/AsyncShutdownTest.java b/src/test/scala/tlschannel/async/AsyncShutdownTest.java index 0814f5df..9b8254de 100644 --- a/src/test/scala/tlschannel/async/AsyncShutdownTest.java +++ b/src/test/scala/tlschannel/async/AsyncShutdownTest.java @@ -11,7 +11,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.jdk.javaapi.CollectionConverters; import tlschannel.helpers.SocketGroups; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; @@ -30,7 +29,7 @@ public void testImmediateShutdown() throws InterruptedException { AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); int socketPairCount = 50; List socketPairs = - CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); + factory.asyncN(null, channelGroup, socketPairCount, true, false); for (SocketGroups.AsyncSocketPair pair : socketPairs) { ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize); pair.client.external.write(writeBuffer); @@ -57,7 +56,7 @@ public void testNonImmediateShutdown() throws InterruptedException, IOException AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); int socketPairCount = 50; List socketPairs = - CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); + factory.asyncN(null, channelGroup, socketPairCount, true, false); for (SocketGroups.AsyncSocketPair pair : socketPairs) { ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize); pair.client.external.write(writeBuffer); diff --git a/src/test/scala/tlschannel/async/AsyncTest.java b/src/test/scala/tlschannel/async/AsyncTest.java index 3e31dabc..e5753338 100644 --- a/src/test/scala/tlschannel/async/AsyncTest.java +++ b/src/test/scala/tlschannel/async/AsyncTest.java @@ -3,11 +3,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; -import scala.jdk.javaapi.CollectionConverters; import tlschannel.helpers.AsyncLoops; import tlschannel.helpers.SocketGroups.AsyncSocketPair; import tlschannel.helpers.SocketPairFactory; @@ -27,8 +26,8 @@ public void testRunTasks() throws Throwable { AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); int dataSize = 5 * 1024 * 1024; System.out.printf("data size: %d\n", dataSize); - List socketPairs = CollectionConverters.asJava( - factory.asyncN(Option.apply(null), channelGroup, socketPairCount, true, false)); + List socketPairs = + factory.asyncN(Optional.empty(), channelGroup, socketPairCount, true, false); AsyncLoops.Report report = AsyncLoops.loop(socketPairs, dataSize); shutdownChannelGroup(channelGroup); @@ -47,8 +46,8 @@ public void testNotRunTasks() throws Throwable { AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); int dataSize = 2 * 1024 * 1024; System.out.printf("data size: %d\n", dataSize); - List socketPairs = CollectionConverters.asJava( - factory.asyncN(Option.apply(null), channelGroup, socketPairCount, false, false)); + List socketPairs = + factory.asyncN(Optional.empty(), channelGroup, socketPairCount, false, false); AsyncLoops.Report report = AsyncLoops.loop(socketPairs, dataSize); shutdownChannelGroup(channelGroup); @@ -70,8 +69,7 @@ public void testNullEngine() throws Throwable { AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); int dataSize = 12 * 1024 * 1024; System.out.printf("data size: %d\n", dataSize); - List socketPairs = - CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); + List socketPairs = factory.asyncN(null, channelGroup, socketPairCount, true, false); AsyncLoops.Report report = AsyncLoops.loop(socketPairs, dataSize); shutdownChannelGroup(channelGroup); diff --git a/src/test/scala/tlschannel/async/AsyncTimeoutTest.java b/src/test/scala/tlschannel/async/AsyncTimeoutTest.java index 52ae2016..76fee420 100644 --- a/src/test/scala/tlschannel/async/AsyncTimeoutTest.java +++ b/src/test/scala/tlschannel/async/AsyncTimeoutTest.java @@ -15,7 +15,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.jdk.javaapi.CollectionConverters; import tlschannel.helpers.SocketGroups; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; @@ -40,7 +39,7 @@ public void testScheduledTimeout() throws IOException { for (int i = 1; i <= repetitions; i++) { int socketPairCount = 50; List socketPairs = - CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); + factory.asyncN(null, channelGroup, socketPairCount, true, false); CountDownLatch latch = new CountDownLatch(socketPairCount * 2); for (SocketGroups.AsyncSocketPair pair : socketPairs) { ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize); @@ -121,7 +120,7 @@ public void testTriggeredTimeout() throws IOException { for (int i = 1; i <= repetitions; i++) { int socketPairCount = 50; List socketPairs = - CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); + factory.asyncN(null, channelGroup, socketPairCount, true, false); for (SocketGroups.AsyncSocketPair pair : socketPairs) { ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize); diff --git a/src/test/scala/tlschannel/async/PseudoAsyncTest.java b/src/test/scala/tlschannel/async/PseudoAsyncTest.java index c8270106..971c52f1 100644 --- a/src/test/scala/tlschannel/async/PseudoAsyncTest.java +++ b/src/test/scala/tlschannel/async/PseudoAsyncTest.java @@ -9,7 +9,6 @@ import org.junit.jupiter.api.TestFactory; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import scala.Option; import tlschannel.helpers.Loops; import tlschannel.helpers.SocketGroups; import tlschannel.helpers.SocketPairFactory; @@ -42,13 +41,13 @@ public Collection testHalfDuplex() { ret.add(DynamicTest.dynamicTest( String.format("testHalfDuplex() - size1=%s, size2=%s", size1, size2), () -> { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Option.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(Optional.of(size1), Optional.of(size2)), new ChuckSizes(Optional.of(size1), Optional.of(size2)))), true, false, - Option.apply(channelGroup)); + Optional.of(channelGroup)); Loops.halfDuplex(socketPair, dataSize, false, false); })); } @@ -68,13 +67,13 @@ public Collection testFullDuplex() { ret.add(DynamicTest.dynamicTest( String.format("testFullDuplex() - size1=%s, size2=%s", size1, size2), () -> { SocketGroups.SocketPair socketPair = factory.nioNio( - Option.apply(null), - Option.apply(new ChunkSizeConfig( + Optional.empty(), + Optional.of(new ChunkSizeConfig( new ChuckSizes(Optional.of(size1), Optional.of(size2)), new ChuckSizes(Optional.of(size1), Optional.of(size2)))), true, false, - Option.apply(channelGroup)); + Optional.of(channelGroup)); Loops.fullDuplex(socketPair, dataSize); })); } diff --git a/src/test/scala/tlschannel/helpers/SocketPairFactory.java b/src/test/scala/tlschannel/helpers/SocketPairFactory.java new file mode 100644 index 00000000..441043bd --- /dev/null +++ b/src/test/scala/tlschannel/helpers/SocketPairFactory.java @@ -0,0 +1,460 @@ +package tlschannel.helpers; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.nio.channels.ByteChannel; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.security.NoSuchAlgorithmException; +import java.util.*; +import java.util.logging.Logger; +import javax.crypto.Cipher; +import javax.net.ssl.*; +import tlschannel.*; +import tlschannel.async.AsynchronousTlsChannel; +import tlschannel.async.AsynchronousTlsChannelGroup; +import tlschannel.helpers.SocketGroups.*; + +/** Create pairs of connected sockets (using the loopback interface). Additionally, all the raw (non-encrypted) socket + * channel are wrapped with a chunking decorator that partitions the bytesProduced of any read or write operation. + */ +public class SocketPairFactory { + + private static final Logger logger = Logger.getLogger(SocketPairFactory.class.getName()); + + private static final int maxAllowedKeyLength; + + static { + try { + maxAllowedKeyLength = Cipher.getMaxAllowedKeyLength("AES"); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + + public static class ChunkSizeConfig { + public final ChuckSizes clientChuckSize; + public final ChuckSizes serverChunkSize; + + public ChunkSizeConfig(ChuckSizes clientChuckSize, ChuckSizes serverChunkSize) { + this.clientChuckSize = clientChuckSize; + this.serverChunkSize = serverChunkSize; + } + } + + public static class ChuckSizes { + public final Optional internalSize; + public final Optional externalSize; + + public ChuckSizes(Optional internalSize, Optional externalSize) { + this.internalSize = internalSize; + this.externalSize = externalSize; + } + } + + public final SSLContext sslContext; + private final String serverName; + private final boolean releaseBuffers = true; + public final SNIHostName clientSniHostName; + private final SNIMatcher expectedSniHostName; + public final InetAddress localhost; + + private final SSLSocketFactory sslSocketFactory; + private final SSLServerSocketFactory sslServerSocketFactory; + + private final TrackingAllocator globalPlainTrackingAllocator = + new TrackingAllocator(TlsChannel.defaultPlainBufferAllocator); + private final TrackingAllocator globalEncryptedTrackingAllocator = + new TrackingAllocator(TlsChannel.defaultEncryptedBufferAllocator); + + public SocketPairFactory(SSLContext sslContext, String serverName) { + this.sslContext = sslContext; + this.serverName = serverName; + this.clientSniHostName = new SNIHostName(serverName); + this.expectedSniHostName = SNIHostName.createSNIMatcher(serverName /* regex! */); + try { + this.localhost = InetAddress.getByName(null); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + this.sslSocketFactory = sslContext.getSocketFactory(); + this.sslServerSocketFactory = sslContext.getServerSocketFactory(); + logger.info(() -> String.format("AES max key length: %s", maxAllowedKeyLength)); + } + + public SocketPairFactory(SSLContext sslContext) { + this(sslContext, SslContextFactory.certificateCommonName); + } + + public SSLEngine fixedCipherServerSslEngineFactory(Optional cipher, SSLContext sslContext) { + SSLEngine engine = sslContext.createSSLEngine(); + engine.setUseClientMode(false); + cipher.ifPresent(c -> engine.setEnabledCipherSuites(new String[] {c})); + return engine; + } + + public Optional sslContextFactory( + SNIServerName expectedName, SSLContext sslContext, Optional name) { + if (name.isPresent()) { + SNIServerName n = name.get(); + logger.warning(() -> "ContextFactory, requested name: " + n); + if (!expectedSniHostName.matches(n)) { + throw new IllegalArgumentException(String.format("Received SNI $n does not match %s", serverName)); + } + return Optional.of(sslContext); + } else { + throw new IllegalArgumentException("SNI expected"); + } + } + + public SSLEngine createClientSslEngine(Optional cipher, int peerPort) { + SSLEngine engine = sslContext.createSSLEngine(serverName, peerPort); + engine.setUseClientMode(true); + cipher.ifPresent(c -> engine.setEnabledCipherSuites(new String[] {c})); + SSLParameters sslParams = engine.getSSLParameters(); // returns a value object + sslParams.setEndpointIdentificationAlgorithm("HTTPS"); + sslParams.setServerNames(Collections.singletonList(clientSniHostName)); + engine.setSSLParameters(sslParams); + return engine; + } + + private SSLServerSocket createSslServerSocket(Optional cipher) { + try { + SSLServerSocket serverSocket = + (SSLServerSocket) sslServerSocketFactory.createServerSocket(0 /* find free port */); + cipher.ifPresent(c -> serverSocket.setEnabledCipherSuites(new String[] {c})); + return serverSocket; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private SSLSocket createSslSocket(Optional cipher, InetAddress host, int port, String requestedHost) { + try { + SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(host, port); + cipher.ifPresent(c -> socket.setEnabledCipherSuites(new String[] {c})); + return socket; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public OldOldSocketPair oldOld(Optional cipher) { + try { + SSLServerSocket serverSocket = createSslServerSocket(cipher); + int chosenPort = serverSocket.getLocalPort(); + SSLSocket client = createSslSocket(cipher, localhost, chosenPort, serverName); + SSLParameters sslParameters = client.getSSLParameters(); // returns a value object + sslParameters.setServerNames(Collections.singletonList(clientSniHostName)); + client.setSSLParameters(sslParameters); + SSLSocket server = (SSLSocket) serverSocket.accept(); + serverSocket.close(); + return new OldOldSocketPair(client, server); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public OldNioSocketPair oldNio(Optional cipher) { + try { + ServerSocketChannel serverSocket = ServerSocketChannel.open(); + serverSocket.bind(new InetSocketAddress(localhost, 0 /* find free port */)); + int chosenPort = ((InetSocketAddress) serverSocket.getLocalAddress()).getPort(); + SSLSocket client = createSslSocket(cipher, localhost, chosenPort, serverName); + SSLParameters sslParameters = client.getSSLParameters(); // returns a value object + sslParameters.setServerNames(Collections.singletonList(clientSniHostName)); + client.setSSLParameters(sslParameters); + SocketChannel rawServer = serverSocket.accept(); + serverSocket.close(); + ServerTlsChannel server = ServerTlsChannel.newBuilder( + rawServer, nameOpt -> sslContextFactory(clientSniHostName, sslContext, nameOpt)) + .withEngineFactory(x -> fixedCipherServerSslEngineFactory(cipher, x)) + .build(); + return new OldNioSocketPair(client, new SocketGroup(server, server, rawServer)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public NioOldSocketPair nioOld(Optional cipher) { + try { + SSLServerSocket serverSocket = createSslServerSocket(cipher); + int chosenPort = serverSocket.getLocalPort(); + InetSocketAddress address = new InetSocketAddress(localhost, chosenPort); + SocketChannel rawClient = SocketChannel.open(address); + SSLSocket server = (SSLSocket) serverSocket.accept(); + serverSocket.close(); + ClientTlsChannel client = ClientTlsChannel.newBuilder(rawClient, createClientSslEngine(cipher, chosenPort)) + .build(); + return new NioOldSocketPair(new SocketGroup(client, client, rawClient), server); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public SocketPair nioNio( + Optional cipher, + Optional chunkSizeConfig, + boolean runTasks, + boolean waitForCloseConfirmation, + Optional pseudoAsyncGroup) { + return nioNioN(cipher, 1, chunkSizeConfig, runTasks, waitForCloseConfirmation, pseudoAsyncGroup) + .get(0); + } + + public List nioNioN( + Optional cipher, + int qtty, + Optional chunkSizeConfig, + boolean runTasks, + boolean waitForCloseConfirmation, + Optional pseudoAsyncGroup) { + try (ServerSocketChannel serverSocket = ServerSocketChannel.open()) { + serverSocket.bind(new InetSocketAddress(localhost, 0 /* find free port */)); + int chosenPort = ((InetSocketAddress) serverSocket.getLocalAddress()).getPort(); + InetSocketAddress address = new InetSocketAddress(localhost, chosenPort); + List pairs = new ArrayList<>(); + for (int i = 0; i < qtty; i++) { + SocketChannel rawClient = SocketChannel.open(address); + SocketChannel rawServer = serverSocket.accept(); + + ByteChannel plainClient; + if (chunkSizeConfig.isPresent()) { + Optional internalSize = chunkSizeConfig.get().clientChuckSize.internalSize; + if (internalSize.isPresent()) { + plainClient = new ChunkingByteChannel(rawClient, internalSize.get()); + } else { + plainClient = rawClient; + } + } else { + plainClient = rawClient; + } + + ByteChannel plainServer; + if (chunkSizeConfig.isPresent()) { + Optional internalSize = chunkSizeConfig.get().serverChunkSize.internalSize; + if (internalSize.isPresent()) { + plainServer = new ChunkingByteChannel(rawServer, internalSize.get()); + } else { + plainServer = rawServer; + } + } else { + plainServer = rawServer; + } + + SSLEngine clientEngine; + if (cipher == null) { + clientEngine = new NullSslEngine(); + } else { + clientEngine = createClientSslEngine(cipher, chosenPort); + } + + ClientTlsChannel clientChannel = ClientTlsChannel.newBuilder(plainClient, clientEngine) + .withRunTasks(runTasks) + .withWaitForCloseConfirmation(waitForCloseConfirmation) + .withPlainBufferAllocator(globalPlainTrackingAllocator) + .withEncryptedBufferAllocator(globalEncryptedTrackingAllocator) + .withReleaseBuffers(releaseBuffers) + .build(); + + ServerTlsChannel.Builder serverChannelBuilder; + if (cipher == null) { + serverChannelBuilder = ServerTlsChannel.newBuilder(plainServer, new NullSslContext()); + } else { + serverChannelBuilder = ServerTlsChannel.newBuilder( + plainServer, nameOpt -> sslContextFactory(clientSniHostName, sslContext, nameOpt)) + .withEngineFactory(ctx -> fixedCipherServerSslEngineFactory(cipher, ctx)); + } + + ServerTlsChannel serverChannel = serverChannelBuilder + .withRunTasks(runTasks) + .withWaitForCloseConfirmation(waitForCloseConfirmation) + .withPlainBufferAllocator(globalPlainTrackingAllocator) + .withEncryptedBufferAllocator(globalEncryptedTrackingAllocator) + .withReleaseBuffers(releaseBuffers) + .build(); + + /* + * Handler executor can be null because BlockerByteChannel will only use Futures, never callbacks. + */ + + ByteChannel clientAsyncChannel; + if (pseudoAsyncGroup.isPresent()) { + rawClient.configureBlocking(false); + clientAsyncChannel = new BlockerByteChannel( + new AsynchronousTlsChannel(pseudoAsyncGroup.get(), clientChannel, rawClient)); + } else { + clientAsyncChannel = clientChannel; + } + + ByteChannel serverAsyncChannel; + if (pseudoAsyncGroup.isPresent()) { + rawServer.configureBlocking(false); + serverAsyncChannel = new BlockerByteChannel( + new AsynchronousTlsChannel(pseudoAsyncGroup.get(), serverChannel, rawServer)); + } else { + serverAsyncChannel = serverChannel; + } + + ByteChannel externalClient; + if (chunkSizeConfig.isPresent()) { + Optional size = chunkSizeConfig.get().clientChuckSize.externalSize; + if (size.isPresent()) { + externalClient = new ChunkingByteChannel(clientAsyncChannel, size.get()); + } else { + externalClient = clientChannel; + } + } else { + externalClient = clientChannel; + } + + ByteChannel externalServer; + if (chunkSizeConfig.isPresent()) { + Optional size = chunkSizeConfig.get().serverChunkSize.externalSize; + if (size.isPresent()) { + externalServer = new ChunkingByteChannel(serverAsyncChannel, size.get()); + } else { + externalServer = serverChannel; + } + } else { + externalServer = serverChannel; + } + + SocketGroup clientPair = new SocketGroup(externalClient, clientChannel, rawClient); + SocketGroup serverPair = new SocketGroup(externalServer, serverChannel, rawServer); + pairs.add(new SocketPair(clientPair, serverPair)); + } + return pairs; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public AsyncSocketPair async( + Optional cipher, + AsynchronousTlsChannelGroup channelGroup, + boolean runTasks, + boolean waitForCloseConfirmation) { + return asyncN(cipher, channelGroup, 1, runTasks, waitForCloseConfirmation) + .get(0); + } + + public List asyncN( + Optional cipher, + AsynchronousTlsChannelGroup channelGroup, + int qtty, + boolean runTasks, + boolean waitForCloseConfirmation) { + try (ServerSocketChannel serverSocket = ServerSocketChannel.open()) { + serverSocket.bind(new InetSocketAddress(localhost, 0 /* find free port */)); + int chosenPort = ((InetSocketAddress) serverSocket.getLocalAddress()).getPort(); + InetSocketAddress address = new InetSocketAddress(localhost, chosenPort); + + List pairs = new ArrayList<>(); + for (int i = 0; i < qtty; i++) { + SocketChannel rawClient = SocketChannel.open(address); + SocketChannel rawServer = serverSocket.accept(); + + rawClient.configureBlocking(false); + rawServer.configureBlocking(false); + + SSLEngine clientEngine; + if (cipher == null) { + clientEngine = new NullSslEngine(); + } else { + clientEngine = createClientSslEngine(cipher, chosenPort); + } + + ClientTlsChannel clientChannel = ClientTlsChannel.newBuilder( + new RandomChunkingByteChannel(rawClient, SocketPairFactory::getChunkingSize), + clientEngine) + .withWaitForCloseConfirmation(waitForCloseConfirmation) + .withPlainBufferAllocator(globalPlainTrackingAllocator) + .withRunTasks(runTasks) + .withEncryptedBufferAllocator(globalEncryptedTrackingAllocator) + .withReleaseBuffers(releaseBuffers) + .build(); + + ServerTlsChannel.Builder serverChannelBuilder; + if (cipher == null) { + serverChannelBuilder = ServerTlsChannel.newBuilder( + new RandomChunkingByteChannel(rawServer, SocketPairFactory::getChunkingSize), + new NullSslContext()); + } else { + serverChannelBuilder = ServerTlsChannel.newBuilder( + new RandomChunkingByteChannel(rawServer, SocketPairFactory::getChunkingSize), + nameOpt -> sslContextFactory(clientSniHostName, sslContext, nameOpt)) + .withEngineFactory(ctx -> fixedCipherServerSslEngineFactory(cipher, ctx)); + } + + ServerTlsChannel serverChannel = serverChannelBuilder + .withWaitForCloseConfirmation(waitForCloseConfirmation) + .withPlainBufferAllocator(globalPlainTrackingAllocator) + .withEncryptedBufferAllocator(globalEncryptedTrackingAllocator) + .withReleaseBuffers(releaseBuffers) + .build(); + + AsynchronousTlsChannel clientAsyncChannel = + new AsynchronousTlsChannel(channelGroup, clientChannel, rawClient); + AsynchronousTlsChannel serverAsyncChannel = + new AsynchronousTlsChannel(channelGroup, serverChannel, rawServer); + + AsyncSocketGroup clientPair = new AsyncSocketGroup(clientAsyncChannel, clientChannel, rawClient); + AsyncSocketGroup serverPair = new AsyncSocketGroup(serverAsyncChannel, serverChannel, rawServer); + pairs.add(new AsyncSocketPair(clientPair, serverPair)); + } + return pairs; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public String getGlobalAllocationReport() { + TrackingAllocator plainAlloc = globalPlainTrackingAllocator; + TrackingAllocator encryptedAlloc = globalEncryptedTrackingAllocator; + long maxPlain = plainAlloc.maxAllocation(); + long maxEncrypted = encryptedAlloc.maxAllocation(); + long totalPlain = plainAlloc.bytesAllocated(); + long totalEncrypted = encryptedAlloc.bytesAllocated(); + long buffersAllocatedPlain = plainAlloc.buffersAllocated(); + long buffersAllocatedEncrypted = encryptedAlloc.buffersAllocated(); + long buffersDeallocatedPlain = plainAlloc.buffersDeallocated(); + long buffersDeallocatedEncrypted = encryptedAlloc.buffersDeallocated(); + return "Allocation report:\n" + + String.format(" max allocation (bytes) - plain: %s - encrypted: %s\n", maxPlain, maxEncrypted) + + String.format(" total allocation (bytes) - plain: %s - encrypted: %s\n", totalPlain, totalEncrypted) + + String.format( + " buffers allocated - plain: %s - encrypted: %s\n", + buffersAllocatedPlain, buffersAllocatedEncrypted) + + String.format( + " buffers deallocated - plain: %s - encrypted: %s\n", + buffersDeallocatedPlain, buffersDeallocatedEncrypted); + } + + public static void checkDeallocation(SocketPair socketPair) { + checkBufferDeallocation(socketPair.client.tls.getPlainBufferAllocator()); + checkBufferDeallocation(socketPair.client.tls.getEncryptedBufferAllocator()); + } + + public static void checkDeallocation(AsyncSocketPair socketPair) { + checkBufferDeallocation(socketPair.client.tls.getPlainBufferAllocator()); + checkBufferDeallocation(socketPair.client.tls.getEncryptedBufferAllocator()); + } + + private static void checkBufferDeallocation(TrackingAllocator allocator) { + logger.fine(() -> String.format("allocator: %s; allocated: %s}", allocator, allocator.bytesAllocated())); + logger.fine(() -> String.format("allocator: %s; deallocated: %s", allocator, allocator.bytesDeallocated())); + assertEquals(allocator.bytesDeallocated(), allocator.bytesAllocated(), " - some buffers were not deallocated"); + } + + private static int getChunkingSize() { + double labmda = 1.0 / SslContextFactory.tlsMaxDataSize; + double uniform = new Random().nextDouble(); + double exp = Math.log(uniform) * (-1 / labmda); + return Math.max((int) exp, 1); + } +} diff --git a/src/test/scala/tlschannel/helpers/SocketPairFactory.scala b/src/test/scala/tlschannel/helpers/SocketPairFactory.scala deleted file mode 100644 index 2c6e11bf..00000000 --- a/src/test/scala/tlschannel/helpers/SocketPairFactory.scala +++ /dev/null @@ -1,420 +0,0 @@ -package tlschannel.helpers - -import java.nio.channels.ServerSocketChannel -import java.nio.channels.SocketChannel -import java.net.InetAddress -import java.net.InetSocketAddress -import javax.net.ssl.SSLContext -import javax.net.ssl.SSLServerSocket -import javax.net.ssl.SSLSocket -import javax.net.ssl.SSLEngine -import org.junit.jupiter.api.Assertions.assertEquals - -import javax.crypto.Cipher -import java.util.Optional -import javax.net.ssl.SNIHostName -import javax.net.ssl.SNIServerName -import tlschannel._ -import tlschannel.async.AsynchronousTlsChannel -import tlschannel.async.AsynchronousTlsChannelGroup - -import java.util.logging.Logger -import scala.jdk.CollectionConverters._ -import scala.util.Random -import tlschannel.helpers.SocketGroups._ - -/** Create pairs of connected sockets (using the loopback interface). Additionally, all the raw (non-encrypted) socket - * channel are wrapped with a chunking decorator that partitions the bytesProduced of any read or write operation. - */ -class SocketPairFactory(val sslContext: SSLContext, val serverName: String) { - - import SocketPairFactory._ - - def this(sslContext: SSLContext) = { - this(sslContext, SslContextFactory.certificateCommonName) - } - - val logger = Logger.getLogger(classOf[SocketPairFactory].getName) - - private val releaseBuffers = true - - val clientSniHostName = new SNIHostName(serverName) - private val expectedSniHostName = SNIHostName.createSNIMatcher(serverName /* regex! */ ) - - def fixedCipherServerSslEngineFactory(cipher: Option[String])(sslContext: SSLContext): SSLEngine = { - val engine = sslContext.createSSLEngine() - engine.setUseClientMode(false) - cipher.foreach(c => engine.setEnabledCipherSuites(Array(c))) - engine - } - - def sslContextFactory(expectedName: SNIServerName, sslContext: SSLContext)( - name: Optional[SNIServerName] - ): Optional[SSLContext] = { - if (name.isPresent) { - val n = name.get - logger.warning(() => "ContextFactory, requested name: " + n) - if (!expectedSniHostName.matches(n)) { - throw new IllegalArgumentException(s"Received SNI $n does not match $serverName") - } - Optional.of(sslContext) - } else { - throw new IllegalArgumentException("SNI expected") - } - } - - val localhost = InetAddress.getByName(null) - - logger.info(s"AES max key length: ${Cipher.getMaxAllowedKeyLength("AES")}") - - val sslSocketFactory = sslContext.getSocketFactory - val sslServerSocketFactory = sslContext.getServerSocketFactory - - val globalPlainTrackingAllocator = new TrackingAllocator(TlsChannel.defaultPlainBufferAllocator) - val globalEncryptedTrackingAllocator = new TrackingAllocator(TlsChannel.defaultEncryptedBufferAllocator) - - def createClientSslEngine(cipher: Option[String], peerPort: Integer): SSLEngine = { - val engine = sslContext.createSSLEngine(serverName, peerPort) - engine.setUseClientMode(true) - cipher.foreach(c => engine.setEnabledCipherSuites(Array(c))) - val sslParams = engine.getSSLParameters() // returns a value object - sslParams.setEndpointIdentificationAlgorithm("HTTPS") - sslParams.setServerNames(Seq[SNIServerName](clientSniHostName).asJava) - engine.setSSLParameters(sslParams) - engine - } - - private def createSslServerSocket(cipher: Option[String]): SSLServerSocket = { - val serverSocket = sslServerSocketFactory.createServerSocket(0 /* find free port */ ).asInstanceOf[SSLServerSocket] - cipher.foreach(c => serverSocket.setEnabledCipherSuites(Array(c))) - serverSocket - } - - private def createSslSocket( - cipher: Option[String], - host: InetAddress, - port: Int, - requestedHost: String - ): SSLSocket = { - val socket = sslSocketFactory.createSocket(host, port).asInstanceOf[SSLSocket] - cipher.foreach(c => socket.setEnabledCipherSuites(Array(c))) - socket - } - - def oldOld(cipher: Option[String] = None): OldOldSocketPair = { - val serverSocket = createSslServerSocket(cipher) - val chosenPort = serverSocket.getLocalPort - val client = createSslSocket(cipher, localhost, chosenPort, requestedHost = serverName) - val sslParameters = client.getSSLParameters // returns a value object - sslParameters.setServerNames(Seq[SNIServerName](clientSniHostName).asJava) - client.setSSLParameters(sslParameters) - val server = serverSocket.accept().asInstanceOf[SSLSocket] - serverSocket.close() - new OldOldSocketPair(client, server) - } - - def oldNio(cipher: Option[String] = None): OldNioSocketPair = { - val serverSocket = ServerSocketChannel.open() - serverSocket.bind(new InetSocketAddress(localhost, 0 /* find free port */ )) - val chosenPort = serverSocket.getLocalAddress.asInstanceOf[InetSocketAddress].getPort - val client = createSslSocket(cipher, localhost, chosenPort, requestedHost = serverName) - val sslParameters = client.getSSLParameters // returns a value object - sslParameters.setServerNames(Seq[SNIServerName](clientSniHostName).asJava) - client.setSSLParameters(sslParameters) - val rawServer = serverSocket.accept() - serverSocket.close() - val server = ServerTlsChannel - .newBuilder(rawServer, nameOpt => sslContextFactory(clientSniHostName, sslContext)(nameOpt)) - .withEngineFactory(fixedCipherServerSslEngineFactory(cipher) _) - .build() - new OldNioSocketPair(client, new SocketGroup(server, server, rawServer)) - } - - def nioOld(cipher: Option[String] = None): NioOldSocketPair = { - val serverSocket = createSslServerSocket(cipher) - val chosenPort = serverSocket.getLocalPort - val address = new InetSocketAddress(localhost, chosenPort) - val rawClient = SocketChannel.open(address) - val server = serverSocket.accept().asInstanceOf[SSLSocket] - serverSocket.close() - val client = ClientTlsChannel - .newBuilder(rawClient, createClientSslEngine(cipher, chosenPort)) - .build() - new NioOldSocketPair(new SocketGroup(client, client, rawClient), server) - } - - def nioNio( - cipher: Option[String] = None, - chunkSizeConfig: Option[ChunkSizeConfig] = None, - runTasks: Boolean = true, - waitForCloseConfirmation: Boolean = false, - pseudoAsyncGroup: Option[AsynchronousTlsChannelGroup] = None - ): SocketPair = { - nioNioN( - cipher, - 1, - chunkSizeConfig, - runTasks, - waitForCloseConfirmation, - pseudoAsyncGroup - ).head - } - - def nioNioN( - cipher: Option[String] = None, - qtty: Int, - chunkSizeConfig: Option[ChunkSizeConfig] = None, - runTasks: Boolean = true, - waitForCloseConfirmation: Boolean = false, - pseudoAsyncGroup: Option[AsynchronousTlsChannelGroup] = None - ): Seq[SocketPair] = { - val serverSocket = ServerSocketChannel.open() - try { - serverSocket.bind(new InetSocketAddress(localhost, 0 /* find free port */ )) - val chosenPort = serverSocket.getLocalAddress.asInstanceOf[InetSocketAddress].getPort - val address = new InetSocketAddress(localhost, chosenPort) - for (_ <- 0 until qtty) yield { - val rawClient = SocketChannel.open(address) - val rawServer = serverSocket.accept() - - val plainClient = chunkSizeConfig match { - case Some(config) => - val internalSize = config.clientChuckSize.internalSize - if (internalSize.isPresent) { - new ChunkingByteChannel(rawClient, internalSize.get) - } else { - rawClient - } - case None => - rawClient - } - - val plainServer = chunkSizeConfig match { - case Some(config) => - val internalSize = config.serverChunkSize.internalSize - if (internalSize.isPresent) { - new ChunkingByteChannel(rawServer, internalSize.get) - } else { - rawServer - } - case None => - rawServer - } - - val clientEngine = if (cipher == null) { - new NullSslEngine - } else { - createClientSslEngine(cipher, chosenPort) - } - - val clientChannel = ClientTlsChannel - .newBuilder(plainClient, clientEngine) - .withRunTasks(runTasks) - .withWaitForCloseConfirmation(waitForCloseConfirmation) - .withPlainBufferAllocator(globalPlainTrackingAllocator) - .withEncryptedBufferAllocator(globalEncryptedTrackingAllocator) - .withReleaseBuffers(releaseBuffers) - .build() - - val serverChannelBuilder = if (cipher == null) { - ServerTlsChannel - .newBuilder(plainServer, new NullSslContext) - } else { - ServerTlsChannel - .newBuilder(plainServer, nameOpt => sslContextFactory(clientSniHostName, sslContext)(nameOpt)) - .withEngineFactory(fixedCipherServerSslEngineFactory(cipher) _) - } - - val serverChannel = serverChannelBuilder - .withRunTasks(runTasks) - .withWaitForCloseConfirmation(waitForCloseConfirmation) - .withPlainBufferAllocator(globalPlainTrackingAllocator) - .withEncryptedBufferAllocator(globalEncryptedTrackingAllocator) - .withReleaseBuffers(releaseBuffers) - .build() - - /* - * Handler executor can be null because BlockerByteChannel will only use Futures, never callbacks. - */ - - val clientAsyncChannel = pseudoAsyncGroup match { - case Some(channelGroup) => - rawClient.configureBlocking(false) - new BlockerByteChannel(new AsynchronousTlsChannel(channelGroup, clientChannel, rawClient)) - case None => - clientChannel - } - - val serverAsyncChannel = pseudoAsyncGroup match { - case Some(channelGroup) => - rawServer.configureBlocking(false) - new BlockerByteChannel(new AsynchronousTlsChannel(channelGroup, serverChannel, rawServer)) - case None => - serverChannel - } - - val externalClient = chunkSizeConfig match { - case Some(config) => - val size = config.clientChuckSize.externalSize - if (size.isPresent) { - new ChunkingByteChannel(clientAsyncChannel, size.get) - } else { - clientChannel - } - case None => - clientChannel - } - - val externalServer = chunkSizeConfig match { - case Some(config) => - val size = config.serverChunkSize.externalSize - if (size.isPresent) { - new ChunkingByteChannel(serverAsyncChannel, size.get) - } else { - serverChannel - } - case None => - serverChannel - } - - val clientPair = new SocketGroup(externalClient, clientChannel, rawClient) - val serverPair = new SocketGroup(externalServer, serverChannel, rawServer) - new SocketPair(clientPair, serverPair) - } - } finally { - serverSocket.close() - } - } - - def async( - cipher: Option[String] = None, - channelGroup: AsynchronousTlsChannelGroup, - runTasks: Boolean, - waitForCloseConfirmation: Boolean = false - ): AsyncSocketPair = { - asyncN(cipher, channelGroup, 1, runTasks, waitForCloseConfirmation).head - } - - def asyncN( - cipher: Option[String] = None, - channelGroup: AsynchronousTlsChannelGroup, - qtty: Int, - runTasks: Boolean, - waitForCloseConfirmation: Boolean = false - ): Seq[AsyncSocketPair] = { - val serverSocket = ServerSocketChannel.open() - - try { - serverSocket.bind(new InetSocketAddress(localhost, 0 /* find free port */ )) - val chosenPort = serverSocket.getLocalAddress.asInstanceOf[InetSocketAddress].getPort - val address = new InetSocketAddress(localhost, chosenPort) - for (_ <- 0 until qtty) yield { - val rawClient = SocketChannel.open(address) - val rawServer = serverSocket.accept() - - rawClient.configureBlocking(false) - rawServer.configureBlocking(false) - - val clientEngine = if (cipher == null) { - new NullSslEngine - } else { - createClientSslEngine(cipher, chosenPort) - } - - val clientChannel = ClientTlsChannel - .newBuilder(new RandomChunkingByteChannel(rawClient, SocketPairFactory.getChunkingSize _), clientEngine) - .withWaitForCloseConfirmation(waitForCloseConfirmation) - .withPlainBufferAllocator(globalPlainTrackingAllocator) - .withRunTasks(runTasks) - .withEncryptedBufferAllocator(globalEncryptedTrackingAllocator) - .withReleaseBuffers(releaseBuffers) - .build() - - val serverChannelBuilder = if (cipher == null) { - ServerTlsChannel - .newBuilder( - new RandomChunkingByteChannel(rawServer, SocketPairFactory.getChunkingSize _), - new NullSslContext - ) - } else { - ServerTlsChannel - .newBuilder( - new RandomChunkingByteChannel(rawServer, SocketPairFactory.getChunkingSize _), - nameOpt => sslContextFactory(clientSniHostName, sslContext)(nameOpt) - ) - .withEngineFactory(fixedCipherServerSslEngineFactory(cipher)) - } - - val serverChannel = serverChannelBuilder - .withWaitForCloseConfirmation(waitForCloseConfirmation) - .withPlainBufferAllocator(globalPlainTrackingAllocator) - .withEncryptedBufferAllocator(globalEncryptedTrackingAllocator) - .withReleaseBuffers(releaseBuffers) - .build() - - val clientAsyncChannel = new AsynchronousTlsChannel(channelGroup, clientChannel, rawClient) - val serverAsyncChannel = new AsynchronousTlsChannel(channelGroup, serverChannel, rawServer) - - val clientPair = new AsyncSocketGroup(clientAsyncChannel, clientChannel, rawClient) - val serverPair = new AsyncSocketGroup(serverAsyncChannel, serverChannel, rawServer) - new AsyncSocketPair(clientPair, serverPair) - } - } finally { - serverSocket.close() - } - } - - def getGlobalAllocationReport(): String = { - val plainAlloc = globalPlainTrackingAllocator - val encryptedAlloc = globalEncryptedTrackingAllocator - val maxPlain = plainAlloc.maxAllocation() - val maxEncrypted = encryptedAlloc.maxAllocation() - val totalPlain = plainAlloc.bytesAllocated() - val totalEncrypted = encryptedAlloc.bytesAllocated() - val buffersAllocatedPlain = plainAlloc.buffersAllocated() - val buffersAllocatedEncrypted = encryptedAlloc.buffersAllocated() - val buffersDeallocatedPlain = plainAlloc.buffersDeallocated() - val buffersDeallocatedEncrypted = encryptedAlloc.buffersDeallocated() - val ret = new StringBuilder - ret ++= s"Allocation report:\n" - ret ++= s" max allocation (bytes) - plain: $maxPlain - encrypted: $maxEncrypted\n" - ret ++= s" total allocation (bytes) - plain: $totalPlain - encrypted: $totalEncrypted\n" - ret ++= s" buffers allocated - plain: $buffersAllocatedPlain - encrypted: $buffersAllocatedEncrypted\n" - ret ++= s" buffers deallocated - plain: $buffersDeallocatedPlain - encrypted: $buffersDeallocatedEncrypted\n" - ret.toString() - } - -} - -object SocketPairFactory { - - val logger = Logger.getLogger(SocketPairFactory.getClass.getName) - - def checkDeallocation(socketPair: SocketPair) = { - checkBufferDeallocation(socketPair.client.tls.getPlainBufferAllocator) - checkBufferDeallocation(socketPair.client.tls.getEncryptedBufferAllocator) - } - - def checkDeallocation(socketPair: AsyncSocketPair) = { - checkBufferDeallocation(socketPair.client.tls.getPlainBufferAllocator) - checkBufferDeallocation(socketPair.client.tls.getEncryptedBufferAllocator) - } - - private def checkBufferDeallocation(allocator: TrackingAllocator) = { - logger.fine(() => s"allocator: $allocator; allocated: ${allocator.bytesAllocated()}") - logger.fine(() => s"allocator: $allocator; deallocated: ${allocator.bytesDeallocated()}") - assertEquals(allocator.bytesDeallocated(), allocator.bytesAllocated(), " - some buffers were not deallocated") - } - - def getChunkingSize(): Int = { - val labmda = 1.0 / SslContextFactory.tlsMaxDataSize - val uniform = Random.nextDouble() - val exp = math.log(uniform) * (-1 / labmda) - math.max(exp.toInt, 1) - } - - case class ChunkSizeConfig(clientChuckSize: ChuckSizes, serverChunkSize: ChuckSizes) - - case class ChuckSizes(internalSize: Optional[Integer], externalSize: Optional[Integer]) - -}