Skip to content

Commit

Permalink
Migrate to Java: AsyncLoops
Browse files Browse the repository at this point in the history
  • Loading branch information
marianobarrios committed May 5, 2024
1 parent 519a6b9 commit 076b6ea
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 172 deletions.
20 changes: 11 additions & 9 deletions src/test/scala/tlschannel/async/AsyncTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import scala.Option;
import scala.collection.immutable.Seq;
import scala.jdk.javaapi.CollectionConverters;
import tlschannel.helpers.AsyncLoops;
import tlschannel.helpers.SocketGroups.AsyncSocketPair;
import tlschannel.helpers.SocketPairFactory;
Expand All @@ -21,13 +22,13 @@ public class AsyncTest implements AsyncTestBase {

// real engine - run tasks
@Test
public void testRunTasks() {
public void testRunTasks() throws Throwable {
System.out.println("testRunTasks():");
AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup();
int dataSize = 5 * 1024 * 1024;
System.out.printf("data size: %d\n", dataSize);
Seq<AsyncSocketPair> socketPairs =
factory.asyncN(Option.apply(null), channelGroup, socketPairCount, true, false);
List<AsyncSocketPair> socketPairs = CollectionConverters.asJava(
factory.asyncN(Option.apply(null), channelGroup, socketPairCount, true, false));
AsyncLoops.Report report = AsyncLoops.loop(socketPairs, dataSize);

shutdownChannelGroup(channelGroup);
Expand All @@ -41,13 +42,13 @@ public void testRunTasks() {

// real engine - do not run tasks
@Test
public void testNotRunTasks() {
public void testNotRunTasks() throws Throwable {
System.out.println("testNotRunTasks():");
AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup();
int dataSize = 2 * 1024 * 1024;
System.out.printf("data size: %d\n", dataSize);
Seq<AsyncSocketPair> socketPairs =
factory.asyncN(Option.apply(null), channelGroup, socketPairCount, false, false);
List<AsyncSocketPair> socketPairs = CollectionConverters.asJava(
factory.asyncN(Option.apply(null), channelGroup, socketPairCount, false, false));
AsyncLoops.Report report = AsyncLoops.loop(socketPairs, dataSize);

shutdownChannelGroup(channelGroup);
Expand All @@ -64,12 +65,13 @@ public void testNotRunTasks() {

// null engine
@Test
public void testNullEngine() {
public void testNullEngine() throws Throwable {
System.out.println("testNullEngine():");
AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup();
int dataSize = 12 * 1024 * 1024;
System.out.printf("data size: %d\n", dataSize);
Seq<AsyncSocketPair> socketPairs = factory.asyncN(null, channelGroup, socketPairCount, true, false);
List<AsyncSocketPair> socketPairs =
CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false));
AsyncLoops.Report report = AsyncLoops.loop(socketPairs, dataSize);

shutdownChannelGroup(channelGroup);
Expand Down
232 changes: 232 additions & 0 deletions src/test/scala/tlschannel/helpers/AsyncLoops.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package tlschannel.helpers;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAdder;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import tlschannel.helpers.SocketGroups.AsyncSocketGroup;
import tlschannel.helpers.SocketGroups.AsyncSocketPair;

public class AsyncLoops {

private static final Logger logger = Logger.getLogger(AsyncLoops.class.getName());

private interface Endpoint {
int remaining();

Optional<Throwable> exception();
}

private static class WriterEndpoint implements Endpoint {
private final AsyncSocketGroup socketGroup;
private final SplittableRandom random = new SplittableRandom(Loops.seed());
private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize());
private int remaining;
private Optional<Throwable> exception = Optional.empty();

public WriterEndpoint(AsyncSocketGroup socketGroup, int remaining) {
this.socketGroup = socketGroup;
this.remaining = remaining;
buffer.flip();
}

@Override
public int remaining() {
return remaining;
}

public Optional<Throwable> exception() {
return exception;
}
}

private static class ReaderEndpoint implements Endpoint {
private final AsyncSocketGroup socketGroup;
private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize());
private final MessageDigest digest;
private int remaining;
private Optional<Throwable> exception = Optional.empty();

public ReaderEndpoint(AsyncSocketGroup socketGroup, int remaining) {
this.socketGroup = socketGroup;
this.remaining = remaining;
try {
digest = MessageDigest.getInstance(Loops.hashAlgorithm());
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}

@Override
public int remaining() {
return remaining;
}

public Optional<Throwable> exception() {
return exception;
}
}

public static class Report {
final long dequeueCycles;
final long completedReads;
final long failedReads;
final long completedWrites;
final long failedWrites;

public Report(
long dequeueCycles, long completedReads, long failedReads, long completedWrites, long failedWrites) {
this.dequeueCycles = dequeueCycles;
this.completedReads = completedReads;
this.failedReads = failedReads;
this.completedWrites = completedWrites;
this.failedWrites = failedWrites;
}

public void print() {
System.out.print("test loop:\n");
System.out.printf(" dequeue cycles: %8d\n", dequeueCycles);
System.out.printf(" completed reads: %8d\n", completedReads);
System.out.printf(" failed reads: %8d\n", failedReads);
System.out.printf(" completed writes: %8d\n", completedWrites);
System.out.printf(" failed writes: %8d\n", failedWrites);
}
}

public static Report loop(List<AsyncSocketPair> socketPairs, int dataSize) throws Throwable {
logger.fine(() -> "starting async loop - pair count: " + socketPairs.size());

int dequeueCycles = 0;
LongAdder completedReads = new LongAdder();
LongAdder failedReads = new LongAdder();
LongAdder completedWrites = new LongAdder();
LongAdder failedWrites = new LongAdder();

LinkedBlockingQueue<Endpoint> endpointQueue = new LinkedBlockingQueue<>();
byte[] dataHash = Loops.expectedBytesHash().apply(dataSize);

List<WriterEndpoint> clientEndpoints = socketPairs.stream()
.map(p -> new WriterEndpoint(p.client, dataSize))
.collect(Collectors.toList());

List<ReaderEndpoint> serverEndpoints = socketPairs.stream()
.map(p -> new ReaderEndpoint(p.server, dataSize))
.collect(Collectors.toList());

for (Endpoint endpoint : clientEndpoints) {
endpointQueue.put(endpoint);
}
for (Endpoint endpoint : serverEndpoints) {
endpointQueue.put(endpoint);
}

int endpointsFinished = 0;
int totalEndpoints = endpointQueue.size();
while (true) {
Endpoint endpoint = endpointQueue.take(); // blocks

dequeueCycles += 1;

if (endpoint.exception().isPresent()) {
throw endpoint.exception().get();
}

if (endpoint.remaining() == 0) {
endpointsFinished += 1;
if (endpointsFinished == totalEndpoints) {
break;
}
} else {

if (endpoint instanceof WriterEndpoint) {
WriterEndpoint writer = (WriterEndpoint) endpoint;

if (!writer.buffer.hasRemaining()) {
TestUtil.nextBytes(writer.random, writer.buffer.array());
writer.buffer.position(0);
writer.buffer.limit(Math.min(writer.buffer.capacity(), writer.remaining));
}
writer.socketGroup.external.write(
writer.buffer, 1, TimeUnit.DAYS, null, new CompletionHandler<Integer, Object>() {
@Override
public void completed(Integer c, Object attach) {
assertTrue(c > 0);
writer.remaining -= c;
try {
endpointQueue.put(writer);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
completedWrites.increment();
}

@Override
public void failed(Throwable e, Object attach) {
writer.exception = Optional.of(e);
try {
endpointQueue.put(writer);
} catch (InterruptedException ex) {
throw new RuntimeException(ex);
}
failedWrites.increment();
}
});
} else if (endpoint instanceof ReaderEndpoint) {
ReaderEndpoint reader = (ReaderEndpoint) endpoint;
reader.buffer.clear();
reader.socketGroup.external.read(
reader.buffer, 1, TimeUnit.DAYS, null, new CompletionHandler<Integer, Object>() {
@Override
public void completed(Integer c, Object attach) {
assertTrue(c > 0);
reader.digest.update(reader.buffer.array(), 0, c);
reader.remaining -= c;
try {
endpointQueue.put(reader);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
completedReads.increment();
}

@Override
public void failed(Throwable e, Object attach) {
reader.exception = Optional.of(e);
try {
endpointQueue.put(reader);
} catch (InterruptedException ex) {
throw new RuntimeException(ex);
}
failedReads.increment();
}
});
} else {
throw new IllegalStateException();
}
}
}
for (AsyncSocketPair socketPair : socketPairs) {
socketPair.client.external.close();
socketPair.server.external.close();
SocketPairFactory.checkDeallocation(socketPair);
}
for (ReaderEndpoint reader : serverEndpoints) {
assertArrayEquals(reader.digest.digest(), dataHash);
}
return new Report(
dequeueCycles,
completedReads.longValue(),
failedReads.longValue(),
completedWrites.longValue(),
failedWrites.longValue());
}
}
Loading

0 comments on commit 076b6ea

Please sign in to comment.