Skip to content

Commit

Permalink
Migrate to Java: Loops
Browse files Browse the repository at this point in the history
  • Loading branch information
marianobarrios committed May 8, 2024
1 parent 7bc83ea commit f6f1a6a
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 195 deletions.
2 changes: 1 addition & 1 deletion src/test/scala/tlschannel/NullEngineTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class NullEngineTest {

{
// heat cache
Loops.expectedBytesHash().apply(dataSize);
Loops.expectedBytesHash.apply(dataSize);
}

// null engine - half duplex - heap buffers
Expand Down
10 changes: 5 additions & 5 deletions src/test/scala/tlschannel/helpers/AsyncLoops.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ private interface Endpoint {

private static class WriterEndpoint implements Endpoint {
private final AsyncSocketGroup socketGroup;
private final SplittableRandom random = new SplittableRandom(Loops.seed());
private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize());
private final SplittableRandom random = new SplittableRandom(Loops.seed);
private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize);
private int remaining;
private Optional<Throwable> exception = Optional.empty();

Expand All @@ -51,7 +51,7 @@ public Optional<Throwable> exception() {

private static class ReaderEndpoint implements Endpoint {
private final AsyncSocketGroup socketGroup;
private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize());
private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize);
private final MessageDigest digest;
private int remaining;
private Optional<Throwable> exception = Optional.empty();
Expand All @@ -60,7 +60,7 @@ public ReaderEndpoint(AsyncSocketGroup socketGroup, int remaining) {
this.socketGroup = socketGroup;
this.remaining = remaining;
try {
digest = MessageDigest.getInstance(Loops.hashAlgorithm());
digest = MessageDigest.getInstance(Loops.hashAlgorithm);
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -112,7 +112,7 @@ public static Report loop(List<AsyncSocketPair> socketPairs, int dataSize) throw
LongAdder failedWrites = new LongAdder();

LinkedBlockingQueue<Endpoint> endpointQueue = new LinkedBlockingQueue<>();
byte[] dataHash = Loops.expectedBytesHash().apply(dataSize);
byte[] dataHash = Loops.expectedBytesHash.apply(dataSize);

List<WriterEndpoint> clientEndpoints = socketPairs.stream()
.map(p -> new WriterEndpoint(p.client, dataSize))
Expand Down
191 changes: 191 additions & 0 deletions src/test/scala/tlschannel/helpers/Loops.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package tlschannel.helpers;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.SplittableRandom;
import java.util.function.Function;
import java.util.logging.Logger;
import tlschannel.helpers.SocketGroups.*;

public class Loops {

private static final Logger logger = Logger.getLogger(Loops.class.getName());

public static final long seed = 143000953L;

/*
* Note that it is necessary to use a multiple of 4 as buffer size for writing.
* This is because the bytesProduced to write are generated using Random.nextBytes, that
* always consumes full (4 byte) integers. A multiple of 4 then prevents "holes"
* in the random sequence.
*/
public static final int bufferSize = 4 * 5000;

private static final int renegotiatePeriod = 10000;
public static final String hashAlgorithm = "MD5"; // for speed

/** Test a half-duplex interaction, with (optional) renegotiation before reversing the direction of the flow (as in
* HTTP)
*/
public static void halfDuplex(SocketPair socketPair, int dataSize, boolean renegotiation, boolean scattering) {
Thread clientWriterThread = new Thread(
() -> Loops.writerLoop(dataSize, socketPair.client, renegotiation, scattering, false, false),
"client-writer");
Thread serverReaderThread = new Thread(
() -> Loops.readerLoop(dataSize, socketPair.server, scattering, false, false), "server-reader");
Thread serverWriterThread = new Thread(
() -> Loops.writerLoop(dataSize, socketPair.server, renegotiation, scattering, true, true),
"server-writer");
Thread clientReaderThread = new Thread(
() -> Loops.readerLoop(dataSize, socketPair.client, scattering, true, true), "client-reader");

try {
serverReaderThread.start();
clientWriterThread.start();

serverReaderThread.join();
clientWriterThread.join();

clientReaderThread.start();
serverWriterThread.start();

clientReaderThread.join();
serverWriterThread.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}

SocketPairFactory.checkDeallocation(socketPair);
}

public static void fullDuplex(SocketPair socketPair, int dataSize) {
Thread clientWriterThread = new Thread(
() -> Loops.writerLoop(dataSize, socketPair.client, false, false, false, false), "client-writer");
Thread serverWriterThread = new Thread(
() -> Loops.writerLoop(dataSize, socketPair.server, false, false, false, false), "server-write");
Thread clientReaderThread =
new Thread(() -> Loops.readerLoop(dataSize, socketPair.client, false, false, false), "client-reader");
Thread serverReaderThread =
new Thread(() -> Loops.readerLoop(dataSize, socketPair.server, false, false, false), "server-reader");

try {
serverReaderThread.start();
clientWriterThread.start();
clientReaderThread.start();
serverWriterThread.start();

serverReaderThread.join();
clientWriterThread.join();
clientReaderThread.join();
serverWriterThread.join();

socketPair.client.external.close();
socketPair.server.external.close();

} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}

SocketPairFactory.checkDeallocation(socketPair);
}

public static void writerLoop(
int size,
SocketGroup socketGroup,
boolean renegotiate,
boolean scattering,
boolean shutdown,
boolean close) {
TestJavaUtil.cannotFail(() -> {
logger.fine(() -> String.format(
"Starting writer loop, size: %s, scattering: %s, renegotiate: %s", size, scattering, renegotiate));
SplittableRandom random = new SplittableRandom(seed);
int bytesSinceRenegotiation = 0;
int bytesRemaining = size;
byte[] bufferArray = new byte[bufferSize];
while (bytesRemaining > 0) {
ByteBuffer buffer = ByteBuffer.wrap(bufferArray, 0, Math.min(bufferSize, bytesRemaining));
TestUtil.nextBytes(random, buffer.array());
while (buffer.hasRemaining()) {
if (renegotiate && bytesSinceRenegotiation > renegotiatePeriod) {
socketGroup.tls.renegotiate();
bytesSinceRenegotiation = 0;
}
int c;
if (scattering) {
c = (int) socketGroup.tls.write(multiWrap(buffer));
} else {
c = socketGroup.external.write(buffer);
}
assertTrue(c > 0, "blocking write must return a positive number");
bytesSinceRenegotiation += c;
bytesRemaining -= c;
assertTrue(bytesRemaining >= 0);
}
}
if (shutdown) socketGroup.tls.shutdown();
if (close) socketGroup.external.close();
logger.fine("Finalizing writer loop");
});
}

public static void readerLoop(
int size, SocketGroup socketGroup, boolean gathering, boolean readEof, boolean close) {

TestJavaUtil.cannotFail(() -> {
logger.fine(() -> String.format("Starting reader loop. Size: $size, gathering: %s", gathering));
byte[] readArray = new byte[bufferSize];
int bytesRemaining = size;
MessageDigest digest = MessageDigest.getInstance(hashAlgorithm);
while (bytesRemaining > 0) {
ByteBuffer readBuffer = ByteBuffer.wrap(readArray, 0, Math.min(bufferSize, bytesRemaining));
int c;
if (gathering) {
c = (int) socketGroup.tls.read(multiWrap(readBuffer));
} else {
c = socketGroup.external.read(readBuffer);
}
assertTrue(c > 0, "blocking read must return a positive number");
digest.update(readBuffer.array(), 0, readBuffer.position());
bytesRemaining -= c;
assertTrue(bytesRemaining >= 0);
}
if (readEof) assertEquals(-1, socketGroup.external.read(ByteBuffer.wrap(readArray)));
byte[] actual = digest.digest();
assertArrayEquals(expectedBytesHash.apply(size), actual);
if (close) socketGroup.external.close();
logger.fine("Finalizing reader loop");
});
}

private static byte[] hash(int size) {
try {
MessageDigest digest = MessageDigest.getInstance(hashAlgorithm);
SplittableRandom random = new SplittableRandom(seed);
int generated = 0;
int bufferSize = 4 * 1024;
byte[] array = new byte[bufferSize];
while (generated < size) {
TestUtil.nextBytes(random, array);
int pending = size - generated;
digest.update(array, 0, Math.min(bufferSize, pending));
generated += bufferSize;
}
return digest.digest();
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}

public static final Function<Integer, byte[]> expectedBytesHash = new TestJavaUtil.Memo<>(Loops::hash)::apply;

private static ByteBuffer[] multiWrap(ByteBuffer buffer) {
return new ByteBuffer[] {ByteBuffer.allocate(0), buffer, ByteBuffer.allocate(0)};
}
}
157 changes: 0 additions & 157 deletions src/test/scala/tlschannel/helpers/Loops.scala

This file was deleted.

Loading

0 comments on commit f6f1a6a

Please sign in to comment.