diff --git a/src/test/scala/tlschannel/NullEngineTest.java b/src/test/scala/tlschannel/NullEngineTest.java index e071956f..f4577d41 100644 --- a/src/test/scala/tlschannel/NullEngineTest.java +++ b/src/test/scala/tlschannel/NullEngineTest.java @@ -34,7 +34,7 @@ public class NullEngineTest { { // heat cache - Loops.expectedBytesHash().apply(dataSize); + Loops.expectedBytesHash.apply(dataSize); } // null engine - half duplex - heap buffers diff --git a/src/test/scala/tlschannel/helpers/AsyncLoops.java b/src/test/scala/tlschannel/helpers/AsyncLoops.java index cc9b1fc4..300e4d17 100644 --- a/src/test/scala/tlschannel/helpers/AsyncLoops.java +++ b/src/test/scala/tlschannel/helpers/AsyncLoops.java @@ -28,8 +28,8 @@ private interface Endpoint { private static class WriterEndpoint implements Endpoint { private final AsyncSocketGroup socketGroup; - private final SplittableRandom random = new SplittableRandom(Loops.seed()); - private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize()); + private final SplittableRandom random = new SplittableRandom(Loops.seed); + private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize); private int remaining; private Optional exception = Optional.empty(); @@ -51,7 +51,7 @@ public Optional exception() { private static class ReaderEndpoint implements Endpoint { private final AsyncSocketGroup socketGroup; - private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize()); + private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize); private final MessageDigest digest; private int remaining; private Optional exception = Optional.empty(); @@ -60,7 +60,7 @@ public ReaderEndpoint(AsyncSocketGroup socketGroup, int remaining) { this.socketGroup = socketGroup; this.remaining = remaining; try { - digest = MessageDigest.getInstance(Loops.hashAlgorithm()); + digest = MessageDigest.getInstance(Loops.hashAlgorithm); } catch (NoSuchAlgorithmException e) { throw new RuntimeException(e); } @@ -112,7 +112,7 @@ public static Report loop(List socketPairs, int dataSize) throw LongAdder failedWrites = new LongAdder(); LinkedBlockingQueue endpointQueue = new LinkedBlockingQueue<>(); - byte[] dataHash = Loops.expectedBytesHash().apply(dataSize); + byte[] dataHash = Loops.expectedBytesHash.apply(dataSize); List clientEndpoints = socketPairs.stream() .map(p -> new WriterEndpoint(p.client, dataSize)) diff --git a/src/test/scala/tlschannel/helpers/Loops.java b/src/test/scala/tlschannel/helpers/Loops.java new file mode 100644 index 00000000..4aa6f7a3 --- /dev/null +++ b/src/test/scala/tlschannel/helpers/Loops.java @@ -0,0 +1,191 @@ +package tlschannel.helpers; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.SplittableRandom; +import java.util.function.Function; +import java.util.logging.Logger; +import tlschannel.helpers.SocketGroups.*; + +public class Loops { + + private static final Logger logger = Logger.getLogger(Loops.class.getName()); + + public static final long seed = 143000953L; + + /* + * Note that it is necessary to use a multiple of 4 as buffer size for writing. + * This is because the bytesProduced to write are generated using Random.nextBytes, that + * always consumes full (4 byte) integers. A multiple of 4 then prevents "holes" + * in the random sequence. + */ + public static final int bufferSize = 4 * 5000; + + private static final int renegotiatePeriod = 10000; + public static final String hashAlgorithm = "MD5"; // for speed + + /** Test a half-duplex interaction, with (optional) renegotiation before reversing the direction of the flow (as in + * HTTP) + */ + public static void halfDuplex(SocketPair socketPair, int dataSize, boolean renegotiation, boolean scattering) { + Thread clientWriterThread = new Thread( + () -> Loops.writerLoop(dataSize, socketPair.client, renegotiation, scattering, false, false), + "client-writer"); + Thread serverReaderThread = new Thread( + () -> Loops.readerLoop(dataSize, socketPair.server, scattering, false, false), "server-reader"); + Thread serverWriterThread = new Thread( + () -> Loops.writerLoop(dataSize, socketPair.server, renegotiation, scattering, true, true), + "server-writer"); + Thread clientReaderThread = new Thread( + () -> Loops.readerLoop(dataSize, socketPair.client, scattering, true, true), "client-reader"); + + try { + serverReaderThread.start(); + clientWriterThread.start(); + + serverReaderThread.join(); + clientWriterThread.join(); + + clientReaderThread.start(); + serverWriterThread.start(); + + clientReaderThread.join(); + serverWriterThread.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + SocketPairFactory.checkDeallocation(socketPair); + } + + public static void fullDuplex(SocketPair socketPair, int dataSize) { + Thread clientWriterThread = new Thread( + () -> Loops.writerLoop(dataSize, socketPair.client, false, false, false, false), "client-writer"); + Thread serverWriterThread = new Thread( + () -> Loops.writerLoop(dataSize, socketPair.server, false, false, false, false), "server-write"); + Thread clientReaderThread = + new Thread(() -> Loops.readerLoop(dataSize, socketPair.client, false, false, false), "client-reader"); + Thread serverReaderThread = + new Thread(() -> Loops.readerLoop(dataSize, socketPair.server, false, false, false), "server-reader"); + + try { + serverReaderThread.start(); + clientWriterThread.start(); + clientReaderThread.start(); + serverWriterThread.start(); + + serverReaderThread.join(); + clientWriterThread.join(); + clientReaderThread.join(); + serverWriterThread.join(); + + socketPair.client.external.close(); + socketPair.server.external.close(); + + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + + SocketPairFactory.checkDeallocation(socketPair); + } + + public static void writerLoop( + int size, + SocketGroup socketGroup, + boolean renegotiate, + boolean scattering, + boolean shutdown, + boolean close) { + TestJavaUtil.cannotFail(() -> { + logger.fine(() -> String.format( + "Starting writer loop, size: %s, scattering: %s, renegotiate: %s", size, scattering, renegotiate)); + SplittableRandom random = new SplittableRandom(seed); + int bytesSinceRenegotiation = 0; + int bytesRemaining = size; + byte[] bufferArray = new byte[bufferSize]; + while (bytesRemaining > 0) { + ByteBuffer buffer = ByteBuffer.wrap(bufferArray, 0, Math.min(bufferSize, bytesRemaining)); + TestUtil.nextBytes(random, buffer.array()); + while (buffer.hasRemaining()) { + if (renegotiate && bytesSinceRenegotiation > renegotiatePeriod) { + socketGroup.tls.renegotiate(); + bytesSinceRenegotiation = 0; + } + int c; + if (scattering) { + c = (int) socketGroup.tls.write(multiWrap(buffer)); + } else { + c = socketGroup.external.write(buffer); + } + assertTrue(c > 0, "blocking write must return a positive number"); + bytesSinceRenegotiation += c; + bytesRemaining -= c; + assertTrue(bytesRemaining >= 0); + } + } + if (shutdown) socketGroup.tls.shutdown(); + if (close) socketGroup.external.close(); + logger.fine("Finalizing writer loop"); + }); + } + + public static void readerLoop( + int size, SocketGroup socketGroup, boolean gathering, boolean readEof, boolean close) { + + TestJavaUtil.cannotFail(() -> { + logger.fine(() -> String.format("Starting reader loop. Size: $size, gathering: %s", gathering)); + byte[] readArray = new byte[bufferSize]; + int bytesRemaining = size; + MessageDigest digest = MessageDigest.getInstance(hashAlgorithm); + while (bytesRemaining > 0) { + ByteBuffer readBuffer = ByteBuffer.wrap(readArray, 0, Math.min(bufferSize, bytesRemaining)); + int c; + if (gathering) { + c = (int) socketGroup.tls.read(multiWrap(readBuffer)); + } else { + c = socketGroup.external.read(readBuffer); + } + assertTrue(c > 0, "blocking read must return a positive number"); + digest.update(readBuffer.array(), 0, readBuffer.position()); + bytesRemaining -= c; + assertTrue(bytesRemaining >= 0); + } + if (readEof) assertEquals(-1, socketGroup.external.read(ByteBuffer.wrap(readArray))); + byte[] actual = digest.digest(); + assertArrayEquals(expectedBytesHash.apply(size), actual); + if (close) socketGroup.external.close(); + logger.fine("Finalizing reader loop"); + }); + } + + private static byte[] hash(int size) { + try { + MessageDigest digest = MessageDigest.getInstance(hashAlgorithm); + SplittableRandom random = new SplittableRandom(seed); + int generated = 0; + int bufferSize = 4 * 1024; + byte[] array = new byte[bufferSize]; + while (generated < size) { + TestUtil.nextBytes(random, array); + int pending = size - generated; + digest.update(array, 0, Math.min(bufferSize, pending)); + generated += bufferSize; + } + return digest.digest(); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + + public static final Function expectedBytesHash = new TestJavaUtil.Memo<>(Loops::hash)::apply; + + private static ByteBuffer[] multiWrap(ByteBuffer buffer) { + return new ByteBuffer[] {ByteBuffer.allocate(0), buffer, ByteBuffer.allocate(0)}; + } +} diff --git a/src/test/scala/tlschannel/helpers/Loops.scala b/src/test/scala/tlschannel/helpers/Loops.scala deleted file mode 100644 index e7402218..00000000 --- a/src/test/scala/tlschannel/helpers/Loops.scala +++ /dev/null @@ -1,157 +0,0 @@ -package tlschannel.helpers - -import java.nio.ByteBuffer -import java.security.MessageDigest -import java.util.SplittableRandom -import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertEquals, assertTrue} -import tlschannel.helpers.TestUtil.Memo -import tlschannel.helpers.SocketGroups._ -import java.util.logging.Logger - -object Loops { - - val logger = Logger.getLogger(Loops.getClass.getName) - - val seed = 143000953L - - /* - * Note that it is necessary to use a multiple of 4 as buffer size for writing. - * This is because the bytesProduced to write are generated using Random.nextBytes, that - * always consumes full (4 byte) integers. A multiple of 4 then prevents "holes" - * in the random sequence. - */ - val bufferSize = 4 * 5000 - - val renegotiatePeriod = 10000 - val hashAlgorithm = "MD5" // for speed - - /** Test a half-duplex interaction, with (optional) renegotiation before reversing the direction of the flow (as in - * HTTP) - */ - def halfDuplex(socketPair: SocketPair, dataSize: Int, renegotiation: Boolean = false, scattering: Boolean = false) = { - val clientWriterThread = - new Thread(() => Loops.writerLoop(dataSize, socketPair.client, renegotiation, scattering), "client-writer") - val serverReaderThread = - new Thread(() => Loops.readerLoop(dataSize, socketPair.server, scattering), "server-reader") - val serverWriterThread = new Thread( - () => Loops.writerLoop(dataSize, socketPair.server, renegotiation, scattering, shutdown = true, close = true), - "server-writer" - ) - val clientReaderThread = new Thread( - () => Loops.readerLoop(dataSize, socketPair.client, scattering, close = true, readEof = true), - "client-reader" - ) - Seq(serverReaderThread, clientWriterThread).foreach(_.start()) - Seq(serverReaderThread, clientWriterThread).foreach(_.join()) - Seq(clientReaderThread, serverWriterThread).foreach(_.start()) - Seq(clientReaderThread, serverWriterThread).foreach(_.join()) - SocketPairFactory.checkDeallocation(socketPair) - } - - def fullDuplex(socketPair: SocketPair, dataSize: Int) = { - val clientWriterThread = new Thread(() => Loops.writerLoop(dataSize, socketPair.client), "client-writer") - val serverWriterThread = new Thread(() => Loops.writerLoop(dataSize, socketPair.server), "server-write") - val clientReaderThread = new Thread(() => Loops.readerLoop(dataSize, socketPair.client), "client-reader") - val serverReaderThread = new Thread(() => Loops.readerLoop(dataSize, socketPair.server), "server-reader") - Seq(serverReaderThread, clientWriterThread, clientReaderThread, serverWriterThread).foreach(_.start()) - Seq(serverReaderThread, clientWriterThread, clientReaderThread, serverWriterThread).foreach(_.join()) - socketPair.client.external.close() - socketPair.server.external.close() - SocketPairFactory.checkDeallocation(socketPair) - } - - def writerLoop( - size: Int, - socketGroup: SocketGroup, - renegotiate: Boolean = false, - scattering: Boolean = false, - shutdown: Boolean = false, - close: Boolean = false - ): Unit = TestUtil.cannotFail { - - logger.fine(() => s"Starting writer loop, size: $size, scattering: $scattering, renegotiate:$renegotiate") - val random = new SplittableRandom(seed) - var bytesSinceRenegotiation = 0 - var bytesRemaining = size - val bufferArray = Array.ofDim[Byte](bufferSize) - while (bytesRemaining > 0) { - val buffer = ByteBuffer.wrap(bufferArray, 0, math.min(bufferSize, bytesRemaining)) - TestUtil.nextBytes(random, buffer.array) - while (buffer.hasRemaining) { - if (renegotiate && bytesSinceRenegotiation > renegotiatePeriod) { - socketGroup.tls.renegotiate() - bytesSinceRenegotiation = 0 - } - val c = - if (scattering) - socketGroup.tls.write(multiWrap(buffer)).toInt - else - socketGroup.external.write(buffer) - assertTrue(c > 0, "blocking write must return a positive number") - bytesSinceRenegotiation += c - bytesRemaining -= c.toInt - assertTrue(bytesRemaining >= 0) - } - } - if (shutdown) - socketGroup.tls.shutdown() - if (close) - socketGroup.external.close() - logger.fine("Finalizing writer loop") - } - - def readerLoop( - size: Int, - socketGroup: SocketGroup, - gathering: Boolean = false, - readEof: Boolean = false, - close: Boolean = false - ): Unit = TestUtil.cannotFail { - - logger.fine(() => s"Starting reader loop. Size: $size, gathering: $gathering") - val readArray = Array.ofDim[Byte](bufferSize) - var bytesRemaining = size - val digest = MessageDigest.getInstance(hashAlgorithm) - while (bytesRemaining > 0) { - val readBuffer = ByteBuffer.wrap(readArray, 0, math.min(bufferSize, bytesRemaining)) - val c = - if (gathering) - socketGroup.tls.read(multiWrap(readBuffer)).toInt - else - socketGroup.external.read(readBuffer) - assertTrue(c > 0, "blocking read must return a positive number") - digest.update(readBuffer.array(), 0, readBuffer.position()) - bytesRemaining -= c - assertTrue(bytesRemaining >= 0) - } - if (readEof) - assertEquals(-1, socketGroup.external.read(ByteBuffer.wrap(readArray))) - val actual = digest.digest() - assertArrayEquals(expectedBytesHash(size), actual) - if (close) - socketGroup.external.close() - logger.fine("Finalizing reader loop") - } - - private def hash(size: Int): Array[Byte] = { - val digest = MessageDigest.getInstance(hashAlgorithm) - val random = new SplittableRandom(seed) - var generated = 0 - val bufferSize = 4 * 1024 - val array = Array.ofDim[Byte](bufferSize) - while (generated < size) { - TestUtil.nextBytes(random, array) - val pending = size - generated - digest.update(array, 0, math.min(bufferSize, pending)) - generated += bufferSize - } - digest.digest() - } - - val expectedBytesHash: Int => Array[Byte] = new Memo(hash).apply - - private def multiWrap(buffer: ByteBuffer) = { - Array(ByteBuffer.allocate(0), buffer, ByteBuffer.allocate(0)) - } - -} diff --git a/src/test/scala/tlschannel/helpers/TestJavaUtil.java b/src/test/scala/tlschannel/helpers/TestJavaUtil.java index 710b66e3..16a3a284 100644 --- a/src/test/scala/tlschannel/helpers/TestJavaUtil.java +++ b/src/test/scala/tlschannel/helpers/TestJavaUtil.java @@ -1,5 +1,7 @@ package tlschannel.helpers; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; @@ -36,4 +38,17 @@ public static Runnable cannotFailRunnable(ExceptionalRunnable exceptionalRunnabl } }; } + + public static class Memo { + private final ConcurrentHashMap cache = new ConcurrentHashMap<>(); + private final Function f; + + public Memo(Function f) { + this.f = f; + } + + public O apply(I i) { + return cache.computeIfAbsent(i, f); + } + } } diff --git a/src/test/scala/tlschannel/helpers/TestUtil.scala b/src/test/scala/tlschannel/helpers/TestUtil.scala index 37785ef2..a615af8f 100644 --- a/src/test/scala/tlschannel/helpers/TestUtil.scala +++ b/src/test/scala/tlschannel/helpers/TestUtil.scala @@ -1,31 +1,12 @@ package tlschannel.helpers import java.util.SplittableRandom -import java.util.concurrent.ConcurrentHashMap -import java.util.logging.{Level, Logger} -import scala.jdk.CollectionConverters._ -import scala.util.control.ControlThrowable +import java.util.logging.Logger object TestUtil { val logger = Logger.getLogger(TestUtil.getClass.getName) - def cannotFail(thunk: => Unit): Unit = { - try thunk - catch { - case r: ControlThrowable => - // pass - case e: Throwable => - val lastMessage = - s"An essential thread (${Thread.currentThread().getName}) failed unexpectedly, terminating process" - logger.log(Level.SEVERE, lastMessage, e) - System.err.println(lastMessage) - e.printStackTrace() // we are committing suicide, assure the reason gets through - Thread.sleep(1000) // give the process some time for flushing logs - System.exit(1) - } - } - def removeAndCollect[A](iterator: java.util.Iterator[A]): Seq[A] = { val builder = Seq.newBuilder[A] while (iterator.hasNext) { @@ -52,16 +33,4 @@ object TestUtil { } } } - - /** @param f - * the function to memoize - * @tparam I - * input to f - * @tparam O - * output of f - */ - class Memo[I, O](f: I => O) extends (I => O) { - val cache = new ConcurrentHashMap[I, O] - override def apply(x: I) = cache.asScala.getOrElseUpdate(x, f(x)) - } }