From 45c4dfd4462dd9f0325257cbfd9a04d9741df53a Mon Sep 17 00:00:00 2001 From: Mariano Barrios Date: Sun, 21 Apr 2024 14:12:46 +0200 Subject: [PATCH] Migrate to Java: SocketGroups --- src/test/scala/tlschannel/AllocationTest.java | 2 +- src/test/scala/tlschannel/BlockingTest.java | 2 +- src/test/scala/tlschannel/CipherTest.java | 10 ++-- src/test/scala/tlschannel/ConcurrentTest.java | 31 +++++------ .../tlschannel/MultiNonBlockingTest.java | 6 +-- .../scala/tlschannel/NonBlockingTest.java | 2 +- src/test/scala/tlschannel/NullEngineTest.java | 5 +- .../tlschannel/NullMultiNonBlockingTest.java | 2 +- src/test/scala/tlschannel/ScatteringTest.java | 2 +- .../tlschannel/async/AsyncCloseTest.java | 14 ++--- .../tlschannel/async/AsyncShutdownTest.scala | 19 ++++--- .../scala/tlschannel/async/AsyncTest.java | 2 +- .../tlschannel/async/AsyncTimeoutTest.scala | 25 +++++---- .../scala/tlschannel/helpers/AsyncLoops.scala | 7 +-- src/test/scala/tlschannel/helpers/Loops.scala | 2 +- .../tlschannel/helpers/NonBlockingLoops.scala | 15 +++--- .../tlschannel/helpers/SocketGroups.java | 53 +++++++++++++++++++ .../helpers/SocketPairFactory.scala | 27 ++++------ 18 files changed, 134 insertions(+), 92 deletions(-) create mode 100644 src/test/scala/tlschannel/helpers/SocketGroups.java diff --git a/src/test/scala/tlschannel/AllocationTest.java b/src/test/scala/tlschannel/AllocationTest.java index efaf43cc..8ddaa2de 100644 --- a/src/test/scala/tlschannel/AllocationTest.java +++ b/src/test/scala/tlschannel/AllocationTest.java @@ -4,7 +4,7 @@ import java.lang.management.MemoryMXBean; import scala.Option; import tlschannel.helpers.Loops; -import tlschannel.helpers.SocketPair; +import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; diff --git a/src/test/scala/tlschannel/BlockingTest.java b/src/test/scala/tlschannel/BlockingTest.java index 8cc8836f..67114e53 100644 --- a/src/test/scala/tlschannel/BlockingTest.java +++ b/src/test/scala/tlschannel/BlockingTest.java @@ -8,7 +8,7 @@ import org.junit.jupiter.api.TestInstance.Lifecycle; import scala.Option; import tlschannel.helpers.Loops; -import tlschannel.helpers.SocketPair; +import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SocketPairFactory.ChuckSizes; import tlschannel.helpers.SocketPairFactory.ChunkSizeConfig; diff --git a/src/test/scala/tlschannel/CipherTest.java b/src/test/scala/tlschannel/CipherTest.java index 3a696206..ac77a944 100644 --- a/src/test/scala/tlschannel/CipherTest.java +++ b/src/test/scala/tlschannel/CipherTest.java @@ -15,7 +15,7 @@ import scala.Some; import scala.jdk.CollectionConverters; import tlschannel.helpers.Loops; -import tlschannel.helpers.SocketPair; +import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; @@ -54,8 +54,8 @@ public Collection testHalfDuplexWithRenegotiation() { Some.apply(cipher), Option.apply(null), true, false, Option.apply(null)); Loops.halfDuplex(socketPair, dataSize, protocol.compareTo("TLSv1.2") < 0, false); String actualProtocol = socketPair - .client() - .tls() + .client + .tls .getSslEngine() .getSession() .getProtocol(); @@ -82,8 +82,8 @@ public Collection testFullDuplex() { Some.apply(cipher), Option.apply(null), true, false, Option.apply(null)); Loops.fullDuplex(socketPair, dataSize); String actualProtocol = socketPair - .client() - .tls() + .client + .tls .getSslEngine() .getSession() .getProtocol(); diff --git a/src/test/scala/tlschannel/ConcurrentTest.java b/src/test/scala/tlschannel/ConcurrentTest.java index 393e7be1..7b7f85e9 100644 --- a/src/test/scala/tlschannel/ConcurrentTest.java +++ b/src/test/scala/tlschannel/ConcurrentTest.java @@ -2,6 +2,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static tlschannel.helpers.SocketGroups.*; import java.io.IOException; import java.nio.ByteBuffer; @@ -31,15 +32,11 @@ public class ConcurrentTest { @Test public void testWriteSide() throws IOException { SocketPair socketPair = factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); - Thread clientWriterThread1 = - new Thread(() -> writerLoop(dataSize, 'a', socketPair.client()), "client-writer-1"); - Thread clientWriterThread2 = - new Thread(() -> writerLoop(dataSize, 'b', socketPair.client()), "client-writer-2"); - Thread clientWriterThread3 = - new Thread(() -> writerLoop(dataSize, 'c', socketPair.client()), "client-writer-3"); - Thread clientWriterThread4 = - new Thread(() -> writerLoop(dataSize, 'd', socketPair.client()), "client-writer-4"); - Thread serverReaderThread = new Thread(() -> readerLoop(dataSize * 4, socketPair.server()), "server-reader"); + Thread clientWriterThread1 = new Thread(() -> writerLoop(dataSize, 'a', socketPair.client), "client-writer-1"); + Thread clientWriterThread2 = new Thread(() -> writerLoop(dataSize, 'b', socketPair.client), "client-writer-2"); + Thread clientWriterThread3 = new Thread(() -> writerLoop(dataSize, 'c', socketPair.client), "client-writer-3"); + Thread clientWriterThread4 = new Thread(() -> writerLoop(dataSize, 'd', socketPair.client), "client-writer-4"); + Thread serverReaderThread = new Thread(() -> readerLoop(dataSize * 4, socketPair.server), "server-reader"); Stream.of( serverReaderThread, clientWriterThread1, @@ -49,7 +46,7 @@ public void testWriteSide() throws IOException { .forEach(t -> t.start()); Stream.of(clientWriterThread1, clientWriterThread2, clientWriterThread3, clientWriterThread4) .forEach(t -> joinInterruptible(t)); - socketPair.client().external().close(); + socketPair.client.external.close(); joinInterruptible(serverReaderThread); SocketPairFactory.checkDeallocation(socketPair); } @@ -58,15 +55,15 @@ public void testWriteSide() throws IOException { @Test public void testReadSide() throws IOException { SocketPair socketPair = factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); - Thread clientWriterThread = new Thread(() -> writerLoop(dataSize, 'a', socketPair.client()), "client-writer"); + Thread clientWriterThread = new Thread(() -> writerLoop(dataSize, 'a', socketPair.client), "client-writer"); AtomicLong totalRead = new AtomicLong(); Thread serverReaderThread1 = - new Thread(() -> readerLoopUntilEof(socketPair.server(), totalRead), "server-reader-1"); + new Thread(() -> readerLoopUntilEof(socketPair.server, totalRead), "server-reader-1"); Thread serverReaderThread2 = - new Thread(() -> readerLoopUntilEof(socketPair.server(), totalRead), "server-reader-2"); + new Thread(() -> readerLoopUntilEof(socketPair.server, totalRead), "server-reader-2"); Stream.of(serverReaderThread1, serverReaderThread2, clientWriterThread).forEach(t -> t.start()); joinInterruptible(clientWriterThread); - socketPair.client().external().close(); + socketPair.client.external.close(); Stream.of(serverReaderThread1, serverReaderThread2).forEach(t -> joinInterruptible(t)); SocketPairFactory.checkDeallocation(socketPair); assertEquals(dataSize, totalRead.get()); @@ -82,7 +79,7 @@ private void writerLoop(int size, char ch, SocketGroup socketGroup) { while (bytesRemaining > 0) { ByteBuffer buffer = ByteBuffer.wrap(bufferArray, 0, Math.min(bufferSize, bytesRemaining)); while (buffer.hasRemaining()) { - int c = socketGroup.external().write(buffer); + int c = socketGroup.external.write(buffer); assertTrue(c > 0, "blocking write must return a positive number"); bytesRemaining -= c; assertTrue(bytesRemaining >= 0); @@ -104,7 +101,7 @@ private void readerLoop(int size, SocketGroup socketGroup) { int bytesRemaining = size; while (bytesRemaining > 0) { ByteBuffer readBuffer = ByteBuffer.wrap(readArray, 0, Math.min(bufferSize, bytesRemaining)); - int c = socketGroup.external().read(readBuffer); + int c = socketGroup.external.read(readBuffer); assertTrue(c > 0, "blocking read must return a positive number"); bytesRemaining -= c; assertTrue(bytesRemaining >= 0); @@ -124,7 +121,7 @@ private void readerLoopUntilEof(SocketGroup socketGroup, AtomicLong accumulator) byte[] readArray = new byte[bufferSize]; while (true) { ByteBuffer readBuffer = ByteBuffer.wrap(readArray, 0, bufferSize); - int c = socketGroup.external().read(readBuffer); + int c = socketGroup.external.read(readBuffer); if (c == -1) { logger.fine("Finalizing reader loop"); return null; diff --git a/src/test/scala/tlschannel/MultiNonBlockingTest.java b/src/test/scala/tlschannel/MultiNonBlockingTest.java index e2d77bdc..65a075c2 100644 --- a/src/test/scala/tlschannel/MultiNonBlockingTest.java +++ b/src/test/scala/tlschannel/MultiNonBlockingTest.java @@ -9,7 +9,7 @@ import scala.Option; import scala.collection.immutable.Seq; import tlschannel.helpers.NonBlockingLoops; -import tlschannel.helpers.SocketPair; +import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; @@ -46,7 +46,7 @@ public void testTasksInExecutor() { @Test public void testTasksInLoopWithRenegotiation() { System.out.println("testTasksInExecutorWithRenegotiation():"); - Seq pairs = factory.nioNioN( + Seq pairs = factory.nioNioN( Option.apply(null), totalConnections, Option.apply(null), true, false, Option.apply(null)); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, true); assertEquals(0, report.asyncTasksRun()); @@ -57,7 +57,7 @@ public void testTasksInLoopWithRenegotiation() { @Test public void testTasksInExecutorWithRenegotiation() { System.out.println("testTasksInExecutorWithRenegotiation():"); - Seq pairs = factory.nioNioN( + Seq pairs = factory.nioNioN( Option.apply(null), totalConnections, Option.apply(null), false, false, Option.apply(null)); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, true); report.print(); diff --git a/src/test/scala/tlschannel/NonBlockingTest.java b/src/test/scala/tlschannel/NonBlockingTest.java index 4c805ccd..5847e269 100644 --- a/src/test/scala/tlschannel/NonBlockingTest.java +++ b/src/test/scala/tlschannel/NonBlockingTest.java @@ -13,7 +13,7 @@ import scala.Some; import scala.jdk.javaapi.CollectionConverters; import tlschannel.helpers.NonBlockingLoops; -import tlschannel.helpers.SocketPair; +import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SocketPairFactory.ChuckSizes; import tlschannel.helpers.SocketPairFactory.ChunkSizeConfig; diff --git a/src/test/scala/tlschannel/NullEngineTest.java b/src/test/scala/tlschannel/NullEngineTest.java index ba7cf292..cf6e0c28 100644 --- a/src/test/scala/tlschannel/NullEngineTest.java +++ b/src/test/scala/tlschannel/NullEngineTest.java @@ -13,6 +13,7 @@ import scala.Option; import scala.Some; import tlschannel.helpers.Loops; +import tlschannel.helpers.SocketGroups; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SocketPairFactory.ChuckSizes; import tlschannel.helpers.SocketPairFactory.ChunkSizeConfig; @@ -44,7 +45,7 @@ public Collection testHalfDuplexHeapBuffers() { List tests = new ArrayList<>(); for (int size1 : sizes) { DynamicTest test = DynamicTest.dynamicTest(String.format("Testing sizes: size1=%s", size1), () -> { - tlschannel.helpers.SocketPair socketPair = factory.nioNio( + SocketGroups.SocketPair socketPair = factory.nioNio( null, Some.apply(new ChunkSizeConfig( new ChuckSizes(Some.apply(size1), Option.apply(null)), @@ -69,7 +70,7 @@ public Collection testHalfDuplexDirectBuffers() { List tests = new ArrayList<>(); for (int size1 : sizes) { DynamicTest test = DynamicTest.dynamicTest(String.format("Testing sizes: size1=%s", size1), () -> { - tlschannel.helpers.SocketPair socketPair = factory.nioNio( + SocketGroups.SocketPair socketPair = factory.nioNio( null, Some.apply(new ChunkSizeConfig( new ChuckSizes(Some.apply(size1), Option.apply(null)), diff --git a/src/test/scala/tlschannel/NullMultiNonBlockingTest.java b/src/test/scala/tlschannel/NullMultiNonBlockingTest.java index 30ae58c3..0db9349f 100644 --- a/src/test/scala/tlschannel/NullMultiNonBlockingTest.java +++ b/src/test/scala/tlschannel/NullMultiNonBlockingTest.java @@ -8,7 +8,7 @@ import scala.Option; import scala.collection.immutable.Seq; import tlschannel.helpers.NonBlockingLoops; -import tlschannel.helpers.SocketPair; +import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; diff --git a/src/test/scala/tlschannel/ScatteringTest.java b/src/test/scala/tlschannel/ScatteringTest.java index d7a2f84a..0f523512 100644 --- a/src/test/scala/tlschannel/ScatteringTest.java +++ b/src/test/scala/tlschannel/ScatteringTest.java @@ -5,7 +5,7 @@ import org.junit.jupiter.api.TestInstance.Lifecycle; import scala.Option; import tlschannel.helpers.Loops; -import tlschannel.helpers.SocketPair; +import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; diff --git a/src/test/scala/tlschannel/async/AsyncCloseTest.java b/src/test/scala/tlschannel/async/AsyncCloseTest.java index b8ff2cb3..3553e009 100644 --- a/src/test/scala/tlschannel/async/AsyncCloseTest.java +++ b/src/test/scala/tlschannel/async/AsyncCloseTest.java @@ -14,7 +14,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import tlschannel.helpers.AsyncSocketPair; +import tlschannel.helpers.SocketGroups.AsyncSocketPair; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; @@ -40,9 +40,9 @@ public void testClosingWhileReading() throws IOException, InterruptedException { AsyncSocketPair socketPair = factory.async(null, channelGroup, true, false); ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize); - Future readFuture = socketPair.server().external().read(readBuffer); + Future readFuture = socketPair.server.external.read(readBuffer); - socketPair.server().external().close(); + socketPair.server.external.close(); try { readFuture.get(1000, TimeUnit.MILLISECONDS); @@ -61,7 +61,7 @@ public void testClosingWhileReading() throws IOException, InterruptedException { Assertions.fail(e); } - socketPair.client().external().close(); + socketPair.client.external.close(); shutdownChannelGroup(channelGroup); assertChannelGroupConsistency(channelGroup); assertEquals(0, channelGroup.getSuccessfulReadCount()); @@ -78,10 +78,10 @@ public void testRawClosingWhileReading() throws IOException, InterruptedExceptio AsyncSocketPair socketPair = factory.async(null, channelGroup, true, false); ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize); - Future readFuture = socketPair.server().external().read(readBuffer); + Future readFuture = socketPair.server.external.read(readBuffer); // important: closing the raw socket - socketPair.server().plain().close(); + socketPair.server.plain.close(); try { readFuture.get(1000, TimeUnit.MILLISECONDS); @@ -100,7 +100,7 @@ public void testRawClosingWhileReading() throws IOException, InterruptedExceptio Assertions.fail(e); } - socketPair.client().external().close(); + socketPair.client.external.close(); shutdownChannelGroup(channelGroup); assertChannelGroupConsistency(channelGroup); assertEquals(0, channelGroup.getSuccessfulReadCount()); diff --git a/src/test/scala/tlschannel/async/AsyncShutdownTest.scala b/src/test/scala/tlschannel/async/AsyncShutdownTest.scala index 5284388d..ff575ffd 100644 --- a/src/test/scala/tlschannel/async/AsyncShutdownTest.scala +++ b/src/test/scala/tlschannel/async/AsyncShutdownTest.scala @@ -4,7 +4,6 @@ import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} import java.nio.ByteBuffer import java.util.concurrent.TimeUnit -import tlschannel.helpers.AsyncSocketPair import tlschannel.helpers.SocketPairFactory import tlschannel.helpers.SslContextFactory import org.junit.jupiter.api.{Test, TestInstance} @@ -24,11 +23,11 @@ class AsyncShutdownTest extends AsyncTestBase { val channelGroup = new AsynchronousTlsChannelGroup() val socketPairCount = 50 val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true) - for (AsyncSocketPair(client, server) <- socketPairs) { + for (pair <- socketPairs) { val writeBuffer = ByteBuffer.allocate(bufferSize) - client.external.write(writeBuffer) + pair.client.external.write(writeBuffer) val readBuffer = ByteBuffer.allocate(bufferSize) - server.external.read(readBuffer) + pair.server.external.read(readBuffer) } assertFalse(channelGroup.isTerminated) @@ -50,11 +49,11 @@ class AsyncShutdownTest extends AsyncTestBase { val channelGroup = new AsynchronousTlsChannelGroup() val socketPairCount = 50 val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true) - for (AsyncSocketPair(client, server) <- socketPairs) { + for (pair <- socketPairs) { val writeBuffer = ByteBuffer.allocate(bufferSize) - client.external.write(writeBuffer) + pair.client.external.write(writeBuffer) val readBuffer = ByteBuffer.allocate(bufferSize) - server.external.read(readBuffer) + pair.server.external.read(readBuffer) } assertFalse(channelGroup.isTerminated) @@ -68,9 +67,9 @@ class AsyncShutdownTest extends AsyncTestBase { assertFalse(channelGroup.isTerminated) } - for (AsyncSocketPair(client, server) <- socketPairs) { - client.external.close() - server.external.close() + for (pair <- socketPairs) { + pair.client.external.close() + pair.server.external.close() } { diff --git a/src/test/scala/tlschannel/async/AsyncTest.java b/src/test/scala/tlschannel/async/AsyncTest.java index 87c30875..6f2d7b38 100644 --- a/src/test/scala/tlschannel/async/AsyncTest.java +++ b/src/test/scala/tlschannel/async/AsyncTest.java @@ -8,7 +8,7 @@ import scala.Option; import scala.collection.immutable.Seq; import tlschannel.helpers.AsyncLoops; -import tlschannel.helpers.AsyncSocketPair; +import tlschannel.helpers.SocketGroups.AsyncSocketPair; import tlschannel.helpers.SocketPairFactory; import tlschannel.helpers.SslContextFactory; diff --git a/src/test/scala/tlschannel/async/AsyncTimeoutTest.scala b/src/test/scala/tlschannel/async/AsyncTimeoutTest.scala index ae971e8b..b0fe1f67 100644 --- a/src/test/scala/tlschannel/async/AsyncTimeoutTest.scala +++ b/src/test/scala/tlschannel/async/AsyncTimeoutTest.scala @@ -10,7 +10,6 @@ import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.LongAdder -import tlschannel.helpers.AsyncSocketPair import tlschannel.helpers.SocketPairFactory import tlschannel.helpers.SslContextFactory @@ -35,10 +34,10 @@ class AsyncTimeoutTest extends AsyncTestBase { val socketPairCount = 50 val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true) val latch = new CountDownLatch(socketPairCount * 2) - for (AsyncSocketPair(client, server) <- socketPairs) { + for (pair <- socketPairs) { val writeBuffer = ByteBuffer.allocate(bufferSize) val clientDone = new AtomicBoolean - client.external.write( + pair.client.external.write( writeBuffer, 50, TimeUnit.MILLISECONDS, @@ -62,7 +61,7 @@ class AsyncTimeoutTest extends AsyncTestBase { ) val readBuffer = ByteBuffer.allocate(bufferSize) val serverDone = new AtomicBoolean - server.external.read( + pair.server.external.read( readBuffer, 100, TimeUnit.MILLISECONDS, @@ -86,9 +85,9 @@ class AsyncTimeoutTest extends AsyncTestBase { ) } latch.await() - for (AsyncSocketPair(client, server) <- socketPairs) { - client.external.close() - server.external.close() + for (pair <- socketPairs) { + pair.client.external.close() + pair.server.external.close() } } @@ -116,11 +115,11 @@ class AsyncTimeoutTest extends AsyncTestBase { for (_ <- 1 to repetitions) { val socketPairCount = 50 val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true) - val futures = for (AsyncSocketPair(client, server) <- socketPairs) yield { + val futures = for (pair <- socketPairs) yield { val writeBuffer = ByteBuffer.allocate(bufferSize) - val writeFuture = client.external.write(writeBuffer) + val writeFuture = pair.client.external.write(writeBuffer) val readBuffer = ByteBuffer.allocate(bufferSize) - val readFuture = server.external.read(readBuffer) + val readFuture = pair.server.external.read(readBuffer) (writeFuture, readFuture) } @@ -132,9 +131,9 @@ class AsyncTimeoutTest extends AsyncTestBase { successfulReadCancellations += 1 } } - for (AsyncSocketPair(client, server) <- socketPairs) { - client.external.close() - server.external.close() + for (pair <- socketPairs) { + pair.client.external.close() + pair.server.external.close() } } shutdownChannelGroup(channelGroup) diff --git a/src/test/scala/tlschannel/helpers/AsyncLoops.scala b/src/test/scala/tlschannel/helpers/AsyncLoops.scala index c90a7602..f4328c3a 100644 --- a/src/test/scala/tlschannel/helpers/AsyncLoops.scala +++ b/src/test/scala/tlschannel/helpers/AsyncLoops.scala @@ -8,6 +8,7 @@ import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.LongAdder import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertTrue} +import tlschannel.helpers.SocketGroups.{AsyncSocketGroup, AsyncSocketPair} import java.util.logging.Logger import scala.util.control.Breaks @@ -63,9 +64,9 @@ object AsyncLoops { val endpointQueue = new LinkedBlockingQueue[Endpoint] val dataHash = Loops.expectedBytesHash(dataSize) - val endpoints = for (AsyncSocketPair(client, server) <- socketPairs) yield { - val clientEndpoint = WriterEndpoint(client, remaining = dataSize) - val serverEndpoint = ReaderEndpoint(server, remaining = dataSize) + val endpoints = for (pair <- socketPairs) yield { + val clientEndpoint = WriterEndpoint(pair.client, remaining = dataSize) + val serverEndpoint = ReaderEndpoint(pair.server, remaining = dataSize) (clientEndpoint, serverEndpoint) } diff --git a/src/test/scala/tlschannel/helpers/Loops.scala b/src/test/scala/tlschannel/helpers/Loops.scala index ee372df3..e7402218 100644 --- a/src/test/scala/tlschannel/helpers/Loops.scala +++ b/src/test/scala/tlschannel/helpers/Loops.scala @@ -5,7 +5,7 @@ import java.security.MessageDigest import java.util.SplittableRandom import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertEquals, assertTrue} import tlschannel.helpers.TestUtil.Memo - +import tlschannel.helpers.SocketGroups._ import java.util.logging.Logger object Loops { diff --git a/src/test/scala/tlschannel/helpers/NonBlockingLoops.scala b/src/test/scala/tlschannel/helpers/NonBlockingLoops.scala index eebaaea7..00610ab9 100644 --- a/src/test/scala/tlschannel/helpers/NonBlockingLoops.scala +++ b/src/test/scala/tlschannel/helpers/NonBlockingLoops.scala @@ -4,6 +4,7 @@ import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertTrue} import tlschannel.NeedsWriteException import tlschannel.NeedsReadException import tlschannel.NeedsTaskException +import tlschannel.helpers.SocketGroups.{SocketGroup, SocketPair} import java.util.concurrent.atomic.LongAdder import scala.util.Random @@ -61,15 +62,15 @@ object NonBlockingLoops { val readyTaskSockets = new ConcurrentLinkedQueue[Endpoint] - val endpoints = for (SocketPair(client, server) <- socketPairs) yield { - client.plain.configureBlocking(false) - server.plain.configureBlocking(false) + val endpoints = for (pair <- socketPairs) yield { + pair.client.plain.configureBlocking(false) + pair.server.plain.configureBlocking(false) - val clientEndpoint = WriterEndpoint(client, key = null, remaining = dataSize) - val serverEndpoint = ReaderEndpoint(server, key = null, remaining = dataSize) + val clientEndpoint = WriterEndpoint(pair.client, key = null, remaining = dataSize) + val serverEndpoint = ReaderEndpoint(pair.server, key = null, remaining = dataSize) - clientEndpoint.key = client.plain.register(selector, SelectionKey.OP_WRITE, clientEndpoint) - serverEndpoint.key = server.plain.register(selector, SelectionKey.OP_READ, serverEndpoint) + clientEndpoint.key = pair.client.plain.register(selector, SelectionKey.OP_WRITE, clientEndpoint) + serverEndpoint.key = pair.server.plain.register(selector, SelectionKey.OP_READ, serverEndpoint) (clientEndpoint, serverEndpoint) } diff --git a/src/test/scala/tlschannel/helpers/SocketGroups.java b/src/test/scala/tlschannel/helpers/SocketGroups.java new file mode 100644 index 00000000..4243c480 --- /dev/null +++ b/src/test/scala/tlschannel/helpers/SocketGroups.java @@ -0,0 +1,53 @@ +package tlschannel.helpers; + +import java.nio.channels.ByteChannel; +import java.nio.channels.SocketChannel; +import tlschannel.TlsChannel; +import tlschannel.async.ExtendedAsynchronousByteChannel; + +public class SocketGroups { + + public static class SocketPair { + public final SocketGroup client; + public final SocketGroup server; + + public SocketPair(SocketGroup client, SocketGroup server) { + this.client = client; + this.server = server; + } + } + + public static class AsyncSocketPair { + public final AsyncSocketGroup client; + public final AsyncSocketGroup server; + + public AsyncSocketPair(AsyncSocketGroup client, AsyncSocketGroup server) { + this.client = client; + this.server = server; + } + } + + public static class SocketGroup { + public final ByteChannel external; + public final TlsChannel tls; + public final SocketChannel plain; + + public SocketGroup(ByteChannel external, TlsChannel tls, SocketChannel plain) { + this.external = external; + this.tls = tls; + this.plain = plain; + } + } + + public static class AsyncSocketGroup { + public final ExtendedAsynchronousByteChannel external; + public final TlsChannel tls; + public final SocketChannel plain; + + public AsyncSocketGroup(ExtendedAsynchronousByteChannel external, TlsChannel tls, SocketChannel plain) { + this.external = external; + this.tls = tls; + this.plain = plain; + } + } +} diff --git a/src/test/scala/tlschannel/helpers/SocketPairFactory.scala b/src/test/scala/tlschannel/helpers/SocketPairFactory.scala index 5081a7fc..01ecb119 100644 --- a/src/test/scala/tlschannel/helpers/SocketPairFactory.scala +++ b/src/test/scala/tlschannel/helpers/SocketPairFactory.scala @@ -11,26 +11,17 @@ import javax.net.ssl.SSLEngine import org.junit.jupiter.api.Assertions.assertEquals import javax.crypto.Cipher -import java.nio.channels.ByteChannel import java.util.Optional import javax.net.ssl.SNIHostName import javax.net.ssl.SNIServerName import tlschannel._ import tlschannel.async.AsynchronousTlsChannel import tlschannel.async.AsynchronousTlsChannelGroup -import tlschannel.async.ExtendedAsynchronousByteChannel - import java.util.logging.Logger import scala.jdk.CollectionConverters._ import scala.util.Random -case class SocketPair(client: SocketGroup, server: SocketGroup) - -case class AsyncSocketPair(client: AsyncSocketGroup, server: AsyncSocketGroup) - -case class SocketGroup(external: ByteChannel, tls: TlsChannel, plain: SocketChannel) - -case class AsyncSocketGroup(external: ExtendedAsynchronousByteChannel, tls: TlsChannel, plain: SocketChannel) +import tlschannel.helpers.SocketGroups._ /** Create pairs of connected sockets (using the loopback interface). Additionally, all the raw (non-encrypted) socket * channel are wrapped with a chunking decorator that partitions the bytesProduced of any read or write operation. @@ -136,7 +127,7 @@ class SocketPairFactory(val sslContext: SSLContext, val serverName: String) { .newBuilder(rawServer, nameOpt => sslContextFactory(clientSniHostName, sslContext)(nameOpt)) .withEngineFactory(fixedCipherServerSslEngineFactory(cipher) _) .build() - (client, SocketGroup(server, server, rawServer)) + (client, new SocketGroup(server, server, rawServer)) } def nioOld(cipher: Option[String] = None): (SocketGroup, SSLSocket) = { @@ -149,7 +140,7 @@ class SocketPairFactory(val sslContext: SSLContext, val serverName: String) { val client = ClientTlsChannel .newBuilder(rawClient, createClientSslEngine(cipher, chosenPort)) .build() - (SocketGroup(client, client, rawClient), server) + (new SocketGroup(client, client, rawClient), server) } def nioNio( @@ -278,9 +269,9 @@ class SocketPairFactory(val sslContext: SSLContext, val serverName: String) { serverChannel } - val clientPair = SocketGroup(externalClient, clientChannel, rawClient) - val serverPair = SocketGroup(externalServer, serverChannel, rawServer) - SocketPair(clientPair, serverPair) + val clientPair = new SocketGroup(externalClient, clientChannel, rawClient) + val serverPair = new SocketGroup(externalServer, serverChannel, rawServer) + new SocketPair(clientPair, serverPair) } } finally { serverSocket.close() @@ -356,9 +347,9 @@ class SocketPairFactory(val sslContext: SSLContext, val serverName: String) { val clientAsyncChannel = new AsynchronousTlsChannel(channelGroup, clientChannel, rawClient) val serverAsyncChannel = new AsynchronousTlsChannel(channelGroup, serverChannel, rawServer) - val clientPair = AsyncSocketGroup(clientAsyncChannel, clientChannel, rawClient) - val serverPair = AsyncSocketGroup(serverAsyncChannel, serverChannel, rawServer) - AsyncSocketPair(clientPair, serverPair) + val clientPair = new AsyncSocketGroup(clientAsyncChannel, clientChannel, rawClient) + val serverPair = new AsyncSocketGroup(serverAsyncChannel, serverChannel, rawServer) + new AsyncSocketPair(clientPair, serverPair) } } finally { serverSocket.close()