From 8210b50787968dc5372dddf13cb63e30be1df186 Mon Sep 17 00:00:00 2001 From: Mariano Barrios Date: Sun, 5 May 2024 22:17:53 +0200 Subject: [PATCH] Migrate to Java: AsyncShutdownTest --- .../tlschannel/async/AsyncShutdownTest.java | 100 +++++++++++ .../tlschannel/async/AsyncShutdownTest.scala | 92 ---------- .../tlschannel/async/AsyncTimeoutTest.java | 159 ++++++++++++++++++ .../tlschannel/async/AsyncTimeoutTest.scala | 151 ----------------- 4 files changed, 259 insertions(+), 243 deletions(-) create mode 100644 src/test/scala/tlschannel/async/AsyncShutdownTest.java delete mode 100644 src/test/scala/tlschannel/async/AsyncShutdownTest.scala create mode 100644 src/test/scala/tlschannel/async/AsyncTimeoutTest.java delete mode 100644 src/test/scala/tlschannel/async/AsyncTimeoutTest.scala diff --git a/src/test/scala/tlschannel/async/AsyncShutdownTest.java b/src/test/scala/tlschannel/async/AsyncShutdownTest.java new file mode 100644 index 00000000..0814f5df --- /dev/null +++ b/src/test/scala/tlschannel/async/AsyncShutdownTest.java @@ -0,0 +1,100 @@ +package tlschannel.async; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import scala.jdk.javaapi.CollectionConverters; +import tlschannel.helpers.SocketGroups; +import tlschannel.helpers.SocketPairFactory; +import tlschannel.helpers.SslContextFactory; + +@TestInstance(Lifecycle.PER_CLASS) +public class AsyncShutdownTest implements AsyncTestBase { + + private final SslContextFactory sslContextFactory = new SslContextFactory(); + private final SocketPairFactory factory = new SocketPairFactory(sslContextFactory.defaultContext()); + + int bufferSize = 10; + + @Test + public void testImmediateShutdown() throws InterruptedException { + System.out.println("testImmediateShutdown():"); + AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); + int socketPairCount = 50; + List socketPairs = + CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); + for (SocketGroups.AsyncSocketPair pair : socketPairs) { + ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize); + pair.client.external.write(writeBuffer); + ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize); + pair.server.external.read(readBuffer); + } + + assertFalse(channelGroup.isTerminated()); + + channelGroup.shutdownNow(); + + // terminated even after a relatively short timeout + boolean terminated = channelGroup.awaitTermination(100, TimeUnit.MILLISECONDS); + assertTrue(terminated); + assertTrue(channelGroup.isTerminated()); + assertChannelGroupConsistency(channelGroup); + + printChannelGroupStatus(channelGroup); + } + + @Test + public void testNonImmediateShutdown() throws InterruptedException, IOException { + System.out.println("testNonImmediateShutdown():"); + AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); + int socketPairCount = 50; + List socketPairs = + CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); + for (SocketGroups.AsyncSocketPair pair : socketPairs) { + ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize); + pair.client.external.write(writeBuffer); + ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize); + pair.server.external.read(readBuffer); + } + + assertFalse(channelGroup.isTerminated()); + + channelGroup.shutdown(); + + { + // not terminated even after a relatively long timeout + boolean terminated = channelGroup.awaitTermination(2000, TimeUnit.MILLISECONDS); + assertFalse(terminated); + assertFalse(channelGroup.isTerminated()); + } + + for (SocketGroups.AsyncSocketPair pair : socketPairs) { + pair.client.external.close(); + pair.server.external.close(); + } + + { + // terminated even after a relatively short timeout + boolean terminated = channelGroup.awaitTermination(100, TimeUnit.MILLISECONDS); + assertTrue(terminated); + assertTrue(channelGroup.isTerminated()); + } + + assertChannelGroupConsistency(channelGroup); + + assertEquals(0, channelGroup.getCancelledReadCount()); + assertEquals(0, channelGroup.getCancelledWriteCount()); + assertEquals(0, channelGroup.getFailedReadCount()); + assertEquals(0, channelGroup.getFailedWriteCount()); + + printChannelGroupStatus(channelGroup); + } +} diff --git a/src/test/scala/tlschannel/async/AsyncShutdownTest.scala b/src/test/scala/tlschannel/async/AsyncShutdownTest.scala deleted file mode 100644 index ff575ffd..00000000 --- a/src/test/scala/tlschannel/async/AsyncShutdownTest.scala +++ /dev/null @@ -1,92 +0,0 @@ -package tlschannel.async - -import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} - -import java.nio.ByteBuffer -import java.util.concurrent.TimeUnit -import tlschannel.helpers.SocketPairFactory -import tlschannel.helpers.SslContextFactory -import org.junit.jupiter.api.{Test, TestInstance} -import org.junit.jupiter.api.TestInstance.Lifecycle - -@TestInstance(Lifecycle.PER_CLASS) -class AsyncShutdownTest extends AsyncTestBase { - - val sslContextFactory = new SslContextFactory - val factory = new SocketPairFactory(sslContextFactory.defaultContext) - - val bufferSize = 10 - - @Test - def testImmediateShutdown(): Unit = { - println("testImmediateShutdown():") - val channelGroup = new AsynchronousTlsChannelGroup() - val socketPairCount = 50 - val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true) - for (pair <- socketPairs) { - val writeBuffer = ByteBuffer.allocate(bufferSize) - pair.client.external.write(writeBuffer) - val readBuffer = ByteBuffer.allocate(bufferSize) - pair.server.external.read(readBuffer) - } - - assertFalse(channelGroup.isTerminated) - - channelGroup.shutdownNow() - - // terminated even after a relatively short timeout - val terminated = channelGroup.awaitTermination(100, TimeUnit.MILLISECONDS) - assertTrue(terminated) - assertTrue(channelGroup.isTerminated) - assertChannelGroupConsistency(channelGroup) - - printChannelGroupStatus(channelGroup) - } - - @Test - def testNonImmediateShutdown(): Unit = { - println("testNonImmediateShutdown():") - val channelGroup = new AsynchronousTlsChannelGroup() - val socketPairCount = 50 - val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true) - for (pair <- socketPairs) { - val writeBuffer = ByteBuffer.allocate(bufferSize) - pair.client.external.write(writeBuffer) - val readBuffer = ByteBuffer.allocate(bufferSize) - pair.server.external.read(readBuffer) - } - - assertFalse(channelGroup.isTerminated) - - channelGroup.shutdown() - - { - // not terminated even after a relatively long timeout - val terminated = channelGroup.awaitTermination(2000, TimeUnit.MILLISECONDS) - assertFalse(terminated) - assertFalse(channelGroup.isTerminated) - } - - for (pair <- socketPairs) { - pair.client.external.close() - pair.server.external.close() - } - - { - // terminated even after a relatively short timeout - val terminated = channelGroup.awaitTermination(100, TimeUnit.MILLISECONDS) - assertTrue(terminated) - assertTrue(channelGroup.isTerminated) - } - - assertChannelGroupConsistency(channelGroup) - - assertEquals(0, channelGroup.getCancelledReadCount) - assertEquals(0, channelGroup.getCancelledWriteCount) - assertEquals(0, channelGroup.getFailedReadCount) - assertEquals(0, channelGroup.getFailedWriteCount) - - printChannelGroupStatus(channelGroup) - } - -} diff --git a/src/test/scala/tlschannel/async/AsyncTimeoutTest.java b/src/test/scala/tlschannel/async/AsyncTimeoutTest.java new file mode 100644 index 00000000..ad9e8d5a --- /dev/null +++ b/src/test/scala/tlschannel/async/AsyncTimeoutTest.java @@ -0,0 +1,159 @@ +package tlschannel.async; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.LongAdder; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import scala.jdk.javaapi.CollectionConverters; +import tlschannel.helpers.SocketGroups; +import tlschannel.helpers.SocketPairFactory; +import tlschannel.helpers.SslContextFactory; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@TestInstance(Lifecycle.PER_CLASS) +public class AsyncTimeoutTest implements AsyncTestBase { + + SslContextFactory sslContextFactory = new SslContextFactory(); + SocketPairFactory factory = new SocketPairFactory(sslContextFactory.defaultContext()); + + private static final int bufferSize = 10; + + private static final int repetitions = 50; + + // scheduled timeout + @Test + public void testScheduledTimeout() throws IOException { + System.out.println("testScheduledTimeout()"); + AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); + LongAdder successWrites = new LongAdder(); + LongAdder successReads = new LongAdder(); + for (int i = 1; i <= repetitions; i++) { + int socketPairCount = 50; + List socketPairs = + CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); + CountDownLatch latch = new CountDownLatch(socketPairCount * 2); + for (SocketGroups.AsyncSocketPair pair : socketPairs) { + ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize); + AtomicBoolean clientDone = new AtomicBoolean(); + pair.client.external.write( + writeBuffer, 50, TimeUnit.MILLISECONDS, null, new CompletionHandler() { + @Override + public void failed(Throwable exc, Object attachment) { + if (!clientDone.compareAndSet(false, true)) { + Assertions.fail(); + } + latch.countDown(); + } + + @Override + public void completed(Integer result, Object attachment) { + if (!clientDone.compareAndSet(false, true)) { + Assertions.fail(); + } + latch.countDown(); + successWrites.increment(); + } + }); + ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize); + AtomicBoolean serverDone = new AtomicBoolean(); + pair.server.external.read( + readBuffer, 100, TimeUnit.MILLISECONDS, null, new CompletionHandler() { + @Override + public void failed(Throwable exc, Object attachment) { + if (!serverDone.compareAndSet(false, true)) { + Assertions.fail(); + } + latch.countDown(); + } + + @Override + public void completed(Integer result, Object attachment) { + if (!serverDone.compareAndSet(false, true)) { + Assertions.fail(); + } + latch.countDown(); + successReads.increment(); + } + }); + } + try { + latch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + for (SocketGroups.AsyncSocketPair pair : socketPairs) { + pair.client.external.close(); + pair.server.external.close(); + } + } + + shutdownChannelGroup(channelGroup); + assertChannelGroupConsistency(channelGroup); + + assertEquals(0, channelGroup.getFailedReadCount()); + assertEquals(0, channelGroup.getFailedWriteCount()); + + assertEquals(channelGroup.getSuccessfulWriteCount(), successWrites.longValue()); + assertEquals(channelGroup.getSuccessfulReadCount(), successReads.longValue()); + + System.out.printf("success writes: %8d\n", successWrites.longValue()); + System.out.printf("success reads: %8d\n", successReads.longValue()); + printChannelGroupStatus(channelGroup); + } + + // triggered timeout + @Test + public void testTriggeredTimeout() throws IOException { + System.out.println("testScheduledTimeout()"); + AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); + int successfulWriteCancellations = 0; + int successfulReadCancellations = 0; + for (int i = 1; i <= repetitions; i++) { + int socketPairCount = 50; + List socketPairs = + CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); + + for (SocketGroups.AsyncSocketPair pair : socketPairs) { + ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize); + Future writeFuture = pair.client.external.write(writeBuffer); + if (writeFuture.cancel(true)) { + successfulWriteCancellations += 1; + } + } + + for (SocketGroups.AsyncSocketPair pair : socketPairs) { + ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize); + Future readFuture = pair.server.external.read(readBuffer); + if (readFuture.cancel(true)) { + successfulReadCancellations += 1; + } + } + + for (SocketGroups.AsyncSocketPair pair : socketPairs) { + pair.client.external.close(); + pair.server.external.close(); + } + } + shutdownChannelGroup(channelGroup); + assertChannelGroupConsistency(channelGroup); + + assertEquals(0, channelGroup.getFailedReadCount()); + assertEquals(0, channelGroup.getFailedWriteCount()); + + assertEquals(channelGroup.getCancelledWriteCount(), successfulWriteCancellations); + assertEquals(channelGroup.getCancelledReadCount(), successfulReadCancellations); + + System.out.printf("success writes: %8d\n", channelGroup.getSuccessfulWriteCount()); + System.out.printf("success reads: %8d\n", channelGroup.getSuccessfulReadCount()); + } +} diff --git a/src/test/scala/tlschannel/async/AsyncTimeoutTest.scala b/src/test/scala/tlschannel/async/AsyncTimeoutTest.scala deleted file mode 100644 index b0fe1f67..00000000 --- a/src/test/scala/tlschannel/async/AsyncTimeoutTest.scala +++ /dev/null @@ -1,151 +0,0 @@ -package tlschannel.async - -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.TestInstance.Lifecycle -import org.junit.jupiter.api.{Assertions, Test, TestInstance} - -import java.nio.ByteBuffer -import java.nio.channels.CompletionHandler -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.LongAdder -import tlschannel.helpers.SocketPairFactory -import tlschannel.helpers.SslContextFactory - -@TestInstance(Lifecycle.PER_CLASS) -class AsyncTimeoutTest extends AsyncTestBase { - - val sslContextFactory = new SslContextFactory - val factory = new SocketPairFactory(sslContextFactory.defaultContext) - - val bufferSize = 10 - - val repetitions = 50 - - // scheduled timeout - @Test - def testScheduledTimeout(): Unit = { - println("testScheduledTimeout()") - val channelGroup = new AsynchronousTlsChannelGroup() - val successWrites = new LongAdder - val successReads = new LongAdder - for (_ <- 1 to repetitions) { - val socketPairCount = 50 - val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true) - val latch = new CountDownLatch(socketPairCount * 2) - for (pair <- socketPairs) { - val writeBuffer = ByteBuffer.allocate(bufferSize) - val clientDone = new AtomicBoolean - pair.client.external.write( - writeBuffer, - 50, - TimeUnit.MILLISECONDS, - null, - new CompletionHandler[Integer, Null] { - override def failed(exc: Throwable, attachment: Null) = { - if (!clientDone.compareAndSet(false, true)) { - Assertions.fail() - } - latch.countDown() - } - - override def completed(result: Integer, attachment: Null) = { - if (!clientDone.compareAndSet(false, true)) { - Assertions.fail() - } - latch.countDown() - successWrites.increment() - } - } - ) - val readBuffer = ByteBuffer.allocate(bufferSize) - val serverDone = new AtomicBoolean - pair.server.external.read( - readBuffer, - 100, - TimeUnit.MILLISECONDS, - null, - new CompletionHandler[Integer, Null] { - override def failed(exc: Throwable, attachment: Null) = { - if (!serverDone.compareAndSet(false, true)) { - Assertions.fail() - } - latch.countDown() - } - - override def completed(result: Integer, attachment: Null) = { - if (!serverDone.compareAndSet(false, true)) { - Assertions.fail() - } - latch.countDown() - successReads.increment() - } - } - ) - } - latch.await() - for (pair <- socketPairs) { - pair.client.external.close() - pair.server.external.close() - } - } - - shutdownChannelGroup(channelGroup) - assertChannelGroupConsistency(channelGroup) - - assertEquals(0, channelGroup.getFailedReadCount) - assertEquals(0, channelGroup.getFailedWriteCount) - - assertEquals(channelGroup.getSuccessfulWriteCount, successWrites.longValue) - assertEquals(channelGroup.getSuccessfulReadCount, successReads.longValue) - - println(f"success writes: ${successWrites.longValue}%8d") - println(f"success reads: ${successReads.longValue}%8d") - printChannelGroupStatus(channelGroup) - } - - // triggered timeout - @Test - def testTriggeredTimeout(): Unit = { - println("testScheduledTimeout()") - val channelGroup = new AsynchronousTlsChannelGroup() - var successfulWriteCancellations = 0 - var successfulReadCancellations = 0 - for (_ <- 1 to repetitions) { - val socketPairCount = 50 - val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true) - val futures = for (pair <- socketPairs) yield { - val writeBuffer = ByteBuffer.allocate(bufferSize) - val writeFuture = pair.client.external.write(writeBuffer) - val readBuffer = ByteBuffer.allocate(bufferSize) - val readFuture = pair.server.external.read(readBuffer) - (writeFuture, readFuture) - } - - for ((writeFuture, readFuture) <- futures) { - if (writeFuture.cancel(true)) { - successfulWriteCancellations += 1 - } - if (readFuture.cancel(true)) { - successfulReadCancellations += 1 - } - } - for (pair <- socketPairs) { - pair.client.external.close() - pair.server.external.close() - } - } - shutdownChannelGroup(channelGroup) - assertChannelGroupConsistency(channelGroup) - - assertEquals(0, channelGroup.getFailedReadCount) - assertEquals(0, channelGroup.getFailedWriteCount) - - assertEquals(channelGroup.getCancelledWriteCount, successfulWriteCancellations) - assertEquals(channelGroup.getCancelledReadCount, successfulReadCancellations) - - println(f"success writes: ${channelGroup.getSuccessfulWriteCount}%8d") - println(f"success reads: ${channelGroup.getSuccessfulReadCount}%8d") - } -}