diff --git a/src/test/scala/tlschannel/async/AsyncTest.java b/src/test/scala/tlschannel/async/AsyncTest.java index 6f2d7b38..3e31dabc 100644 --- a/src/test/scala/tlschannel/async/AsyncTest.java +++ b/src/test/scala/tlschannel/async/AsyncTest.java @@ -2,11 +2,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; import scala.Option; -import scala.collection.immutable.Seq; +import scala.jdk.javaapi.CollectionConverters; import tlschannel.helpers.AsyncLoops; import tlschannel.helpers.SocketGroups.AsyncSocketPair; import tlschannel.helpers.SocketPairFactory; @@ -21,13 +22,13 @@ public class AsyncTest implements AsyncTestBase { // real engine - run tasks @Test - public void testRunTasks() { + public void testRunTasks() throws Throwable { System.out.println("testRunTasks():"); AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); int dataSize = 5 * 1024 * 1024; System.out.printf("data size: %d\n", dataSize); - Seq socketPairs = - factory.asyncN(Option.apply(null), channelGroup, socketPairCount, true, false); + List socketPairs = CollectionConverters.asJava( + factory.asyncN(Option.apply(null), channelGroup, socketPairCount, true, false)); AsyncLoops.Report report = AsyncLoops.loop(socketPairs, dataSize); shutdownChannelGroup(channelGroup); @@ -41,13 +42,13 @@ public void testRunTasks() { // real engine - do not run tasks @Test - public void testNotRunTasks() { + public void testNotRunTasks() throws Throwable { System.out.println("testNotRunTasks():"); AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); int dataSize = 2 * 1024 * 1024; System.out.printf("data size: %d\n", dataSize); - Seq socketPairs = - factory.asyncN(Option.apply(null), channelGroup, socketPairCount, false, false); + List socketPairs = CollectionConverters.asJava( + factory.asyncN(Option.apply(null), channelGroup, socketPairCount, false, false)); AsyncLoops.Report report = AsyncLoops.loop(socketPairs, dataSize); shutdownChannelGroup(channelGroup); @@ -64,12 +65,13 @@ public void testNotRunTasks() { // null engine @Test - public void testNullEngine() { + public void testNullEngine() throws Throwable { System.out.println("testNullEngine():"); AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup(); int dataSize = 12 * 1024 * 1024; System.out.printf("data size: %d\n", dataSize); - Seq socketPairs = factory.asyncN(null, channelGroup, socketPairCount, true, false); + List socketPairs = + CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false)); AsyncLoops.Report report = AsyncLoops.loop(socketPairs, dataSize); shutdownChannelGroup(channelGroup); diff --git a/src/test/scala/tlschannel/helpers/AsyncLoops.java b/src/test/scala/tlschannel/helpers/AsyncLoops.java new file mode 100644 index 00000000..cc9b1fc4 --- /dev/null +++ b/src/test/scala/tlschannel/helpers/AsyncLoops.java @@ -0,0 +1,232 @@ +package tlschannel.helpers; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.*; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.LongAdder; +import java.util.logging.Logger; +import java.util.stream.Collectors; +import tlschannel.helpers.SocketGroups.AsyncSocketGroup; +import tlschannel.helpers.SocketGroups.AsyncSocketPair; + +public class AsyncLoops { + + private static final Logger logger = Logger.getLogger(AsyncLoops.class.getName()); + + private interface Endpoint { + int remaining(); + + Optional exception(); + } + + private static class WriterEndpoint implements Endpoint { + private final AsyncSocketGroup socketGroup; + private final SplittableRandom random = new SplittableRandom(Loops.seed()); + private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize()); + private int remaining; + private Optional exception = Optional.empty(); + + public WriterEndpoint(AsyncSocketGroup socketGroup, int remaining) { + this.socketGroup = socketGroup; + this.remaining = remaining; + buffer.flip(); + } + + @Override + public int remaining() { + return remaining; + } + + public Optional exception() { + return exception; + } + } + + private static class ReaderEndpoint implements Endpoint { + private final AsyncSocketGroup socketGroup; + private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize()); + private final MessageDigest digest; + private int remaining; + private Optional exception = Optional.empty(); + + public ReaderEndpoint(AsyncSocketGroup socketGroup, int remaining) { + this.socketGroup = socketGroup; + this.remaining = remaining; + try { + digest = MessageDigest.getInstance(Loops.hashAlgorithm()); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + + @Override + public int remaining() { + return remaining; + } + + public Optional exception() { + return exception; + } + } + + public static class Report { + final long dequeueCycles; + final long completedReads; + final long failedReads; + final long completedWrites; + final long failedWrites; + + public Report( + long dequeueCycles, long completedReads, long failedReads, long completedWrites, long failedWrites) { + this.dequeueCycles = dequeueCycles; + this.completedReads = completedReads; + this.failedReads = failedReads; + this.completedWrites = completedWrites; + this.failedWrites = failedWrites; + } + + public void print() { + System.out.print("test loop:\n"); + System.out.printf(" dequeue cycles: %8d\n", dequeueCycles); + System.out.printf(" completed reads: %8d\n", completedReads); + System.out.printf(" failed reads: %8d\n", failedReads); + System.out.printf(" completed writes: %8d\n", completedWrites); + System.out.printf(" failed writes: %8d\n", failedWrites); + } + } + + public static Report loop(List socketPairs, int dataSize) throws Throwable { + logger.fine(() -> "starting async loop - pair count: " + socketPairs.size()); + + int dequeueCycles = 0; + LongAdder completedReads = new LongAdder(); + LongAdder failedReads = new LongAdder(); + LongAdder completedWrites = new LongAdder(); + LongAdder failedWrites = new LongAdder(); + + LinkedBlockingQueue endpointQueue = new LinkedBlockingQueue<>(); + byte[] dataHash = Loops.expectedBytesHash().apply(dataSize); + + List clientEndpoints = socketPairs.stream() + .map(p -> new WriterEndpoint(p.client, dataSize)) + .collect(Collectors.toList()); + + List serverEndpoints = socketPairs.stream() + .map(p -> new ReaderEndpoint(p.server, dataSize)) + .collect(Collectors.toList()); + + for (Endpoint endpoint : clientEndpoints) { + endpointQueue.put(endpoint); + } + for (Endpoint endpoint : serverEndpoints) { + endpointQueue.put(endpoint); + } + + int endpointsFinished = 0; + int totalEndpoints = endpointQueue.size(); + while (true) { + Endpoint endpoint = endpointQueue.take(); // blocks + + dequeueCycles += 1; + + if (endpoint.exception().isPresent()) { + throw endpoint.exception().get(); + } + + if (endpoint.remaining() == 0) { + endpointsFinished += 1; + if (endpointsFinished == totalEndpoints) { + break; + } + } else { + + if (endpoint instanceof WriterEndpoint) { + WriterEndpoint writer = (WriterEndpoint) endpoint; + + if (!writer.buffer.hasRemaining()) { + TestUtil.nextBytes(writer.random, writer.buffer.array()); + writer.buffer.position(0); + writer.buffer.limit(Math.min(writer.buffer.capacity(), writer.remaining)); + } + writer.socketGroup.external.write( + writer.buffer, 1, TimeUnit.DAYS, null, new CompletionHandler() { + @Override + public void completed(Integer c, Object attach) { + assertTrue(c > 0); + writer.remaining -= c; + try { + endpointQueue.put(writer); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + completedWrites.increment(); + } + + @Override + public void failed(Throwable e, Object attach) { + writer.exception = Optional.of(e); + try { + endpointQueue.put(writer); + } catch (InterruptedException ex) { + throw new RuntimeException(ex); + } + failedWrites.increment(); + } + }); + } else if (endpoint instanceof ReaderEndpoint) { + ReaderEndpoint reader = (ReaderEndpoint) endpoint; + reader.buffer.clear(); + reader.socketGroup.external.read( + reader.buffer, 1, TimeUnit.DAYS, null, new CompletionHandler() { + @Override + public void completed(Integer c, Object attach) { + assertTrue(c > 0); + reader.digest.update(reader.buffer.array(), 0, c); + reader.remaining -= c; + try { + endpointQueue.put(reader); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + completedReads.increment(); + } + + @Override + public void failed(Throwable e, Object attach) { + reader.exception = Optional.of(e); + try { + endpointQueue.put(reader); + } catch (InterruptedException ex) { + throw new RuntimeException(ex); + } + failedReads.increment(); + } + }); + } else { + throw new IllegalStateException(); + } + } + } + for (AsyncSocketPair socketPair : socketPairs) { + socketPair.client.external.close(); + socketPair.server.external.close(); + SocketPairFactory.checkDeallocation(socketPair); + } + for (ReaderEndpoint reader : serverEndpoints) { + assertArrayEquals(reader.digest.digest(), dataHash); + } + return new Report( + dequeueCycles, + completedReads.longValue(), + failedReads.longValue(), + completedWrites.longValue(), + failedWrites.longValue()); + } +} diff --git a/src/test/scala/tlschannel/helpers/AsyncLoops.scala b/src/test/scala/tlschannel/helpers/AsyncLoops.scala deleted file mode 100644 index f4328c3a..00000000 --- a/src/test/scala/tlschannel/helpers/AsyncLoops.scala +++ /dev/null @@ -1,163 +0,0 @@ -package tlschannel.helpers - -import java.nio.ByteBuffer -import java.nio.channels.CompletionHandler -import java.security.MessageDigest -import java.util.SplittableRandom -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.LongAdder -import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertTrue} -import tlschannel.helpers.SocketGroups.{AsyncSocketGroup, AsyncSocketPair} - -import java.util.logging.Logger -import scala.util.control.Breaks - -object AsyncLoops { - - val logger = Logger.getLogger(AsyncLoops.getClass.getName) - - trait Endpoint { - def remaining: Int - def exception: Option[Throwable] - } - - case class WriterEndpoint(socketGroup: AsyncSocketGroup, var remaining: Int) extends Endpoint { - var exception: Option[Throwable] = None - val random = new SplittableRandom(Loops.seed) - val buffer = ByteBuffer.allocate(Loops.bufferSize) - buffer.flip() - } - - case class ReaderEndpoint(socketGroup: AsyncSocketGroup, var remaining: Int) extends Endpoint { - var exception: Option[Throwable] = None - val buffer = ByteBuffer.allocate(Loops.bufferSize) - val digest = MessageDigest.getInstance(Loops.hashAlgorithm) - } - - case class Report( - dequeueCycles: Long, - completedReads: Long, - failedReads: Long, - completedWrites: Long, - failedWrites: Long - ) { - def print(): Unit = { - println(f"test loop:") - println(f" dequeue cycles: $dequeueCycles%8d") - println(f" completed reads: $completedReads%8d") - println(f" failed reads: $failedReads%8d") - println(f" completed writes: $completedWrites%8d") - println(f" failed writes: $failedWrites%8d") - } - } - - def loop(socketPairs: Seq[AsyncSocketPair], dataSize: Int): Report = { - - logger.fine(() => s"starting async loop - pair count: ${socketPairs.size}") - - var dequeueCycles = 0 - val completedReads = new LongAdder - val failedReads = new LongAdder - val completedWrites = new LongAdder - val failedWrites = new LongAdder - - val endpointQueue = new LinkedBlockingQueue[Endpoint] - val dataHash = Loops.expectedBytesHash(dataSize) - val endpoints = for (pair <- socketPairs) yield { - val clientEndpoint = WriterEndpoint(pair.client, remaining = dataSize) - val serverEndpoint = ReaderEndpoint(pair.server, remaining = dataSize) - (clientEndpoint, serverEndpoint) - } - - val (writers, readers) = endpoints.unzip - val allEndpoints = writers ++ readers - - for (endpoint <- allEndpoints) { - endpointQueue.put(endpoint) - } - var endpointsFinished = 0 - val totalEndpoints = endpoints.length * 2 - Breaks.breakable { - while (true) { - val endpoint = endpointQueue.take() // blocks - - dequeueCycles += 1 - endpoint.exception.foreach(throw _) - if (endpoint.remaining == 0) { - endpointsFinished += 1 - if (endpointsFinished == totalEndpoints) { - Breaks.break() - } - } else { - endpoint match { - case writer: WriterEndpoint => - if (!writer.buffer.hasRemaining) { - TestUtil.nextBytes(writer.random, writer.buffer.array()) - writer.buffer.position(0) - writer.buffer.limit(math.min(writer.buffer.capacity, writer.remaining)) - } - writer.socketGroup.external.write( - writer.buffer, - 1, - TimeUnit.DAYS, - null, - new CompletionHandler[Integer, Null] { - override def completed(c: Integer, attach: Null) = { - assertTrue(c > 0) - writer.remaining -= c - endpointQueue.put(writer) - completedWrites.increment() - } - - override def failed(e: Throwable, attach: Null) = { - writer.exception = Some(e) - endpointQueue.put(writer) - failedWrites.increment() - } - } - ) - case reader: ReaderEndpoint => - reader.buffer.clear() - reader.socketGroup.external.read( - reader.buffer, - 1, - TimeUnit.DAYS, - null, - new CompletionHandler[Integer, Null] { - override def completed(c: Integer, attach: Null) = { - assertTrue(c > 0) - reader.digest.update(reader.buffer.array, 0, c) - reader.remaining -= c - endpointQueue.put(reader) - completedReads.increment() - } - override def failed(e: Throwable, attach: Null) = { - reader.exception = Some(e) - endpointQueue.put(reader) - failedReads.increment() - } - } - ) - } - } - } - } - for (socketPair <- socketPairs) { - socketPair.client.external.close() - socketPair.server.external.close() - SocketPairFactory.checkDeallocation(socketPair) - } - for (reader <- readers) { - assertArrayEquals(reader.digest.digest(), dataHash) - } - Report( - dequeueCycles, - completedReads.longValue(), - failedReads.longValue(), - completedWrites.longValue(), - failedWrites.longValue() - ) - } - -}