Skip to content

Commit

Permalink
Migrate to Java: AsyncShutdownTest
Browse files Browse the repository at this point in the history
  • Loading branch information
marianobarrios committed May 5, 2024
1 parent 076b6ea commit 29578e0
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 243 deletions.
100 changes: 100 additions & 0 deletions src/test/scala/tlschannel/async/AsyncShutdownTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package tlschannel.async;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import scala.jdk.javaapi.CollectionConverters;
import tlschannel.helpers.SocketGroups;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SslContextFactory;

@TestInstance(Lifecycle.PER_CLASS)
public class AsyncShutdownTest implements AsyncTestBase {

private final SslContextFactory sslContextFactory = new SslContextFactory();
private final SocketPairFactory factory = new SocketPairFactory(sslContextFactory.defaultContext());

int bufferSize = 10;

@Test
public void testImmediateShutdown() throws InterruptedException {
System.out.println("testImmediateShutdown():");
AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup();
int socketPairCount = 50;
List<SocketGroups.AsyncSocketPair> socketPairs =
CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false));
for (SocketGroups.AsyncSocketPair pair : socketPairs) {
ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize);
pair.client.external.write(writeBuffer);
ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize);
pair.server.external.read(readBuffer);
}

assertFalse(channelGroup.isTerminated());

channelGroup.shutdownNow();

// terminated even after a relatively short timeout
boolean terminated = channelGroup.awaitTermination(100, TimeUnit.MILLISECONDS);
assertTrue(terminated);
assertTrue(channelGroup.isTerminated());
assertChannelGroupConsistency(channelGroup);

printChannelGroupStatus(channelGroup);
}

@Test
public void testNonImmediateShutdown() throws InterruptedException, IOException {
System.out.println("testNonImmediateShutdown():");
AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup();
int socketPairCount = 50;
List<SocketGroups.AsyncSocketPair> socketPairs =
CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false));
for (SocketGroups.AsyncSocketPair pair : socketPairs) {
ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize);
pair.client.external.write(writeBuffer);
ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize);
pair.server.external.read(readBuffer);
}

assertFalse(channelGroup.isTerminated());

channelGroup.shutdown();

{
// not terminated even after a relatively long timeout
boolean terminated = channelGroup.awaitTermination(2000, TimeUnit.MILLISECONDS);
assertFalse(terminated);
assertFalse(channelGroup.isTerminated());
}

for (SocketGroups.AsyncSocketPair pair : socketPairs) {
pair.client.external.close();
pair.server.external.close();
}

{
// terminated even after a relatively short timeout
boolean terminated = channelGroup.awaitTermination(100, TimeUnit.MILLISECONDS);
assertTrue(terminated);
assertTrue(channelGroup.isTerminated());
}

assertChannelGroupConsistency(channelGroup);

assertEquals(0, channelGroup.getCancelledReadCount());
assertEquals(0, channelGroup.getCancelledWriteCount());
assertEquals(0, channelGroup.getFailedReadCount());
assertEquals(0, channelGroup.getFailedWriteCount());

printChannelGroupStatus(channelGroup);
}
}
92 changes: 0 additions & 92 deletions src/test/scala/tlschannel/async/AsyncShutdownTest.scala

This file was deleted.

159 changes: 159 additions & 0 deletions src/test/scala/tlschannel/async/AsyncTimeoutTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package tlschannel.async;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.LongAdder;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import scala.jdk.javaapi.CollectionConverters;
import tlschannel.helpers.SocketGroups;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SslContextFactory;

@TestInstance(Lifecycle.PER_CLASS)
public class AsyncTimeoutTest implements AsyncTestBase {

SslContextFactory sslContextFactory = new SslContextFactory();
SocketPairFactory factory = new SocketPairFactory(sslContextFactory.defaultContext());

private static final int bufferSize = 10;

private static final int repetitions = 50;

// scheduled timeout
@Test
public void testScheduledTimeout() throws IOException {
System.out.println("testScheduledTimeout()");
AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup();
LongAdder successWrites = new LongAdder();
LongAdder successReads = new LongAdder();
for (int i = 1; i <= repetitions; i++) {
int socketPairCount = 50;
List<SocketGroups.AsyncSocketPair> socketPairs =
CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false));
CountDownLatch latch = new CountDownLatch(socketPairCount * 2);
for (SocketGroups.AsyncSocketPair pair : socketPairs) {
ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize);
AtomicBoolean clientDone = new AtomicBoolean();
pair.client.external.write(
writeBuffer, 50, TimeUnit.MILLISECONDS, null, new CompletionHandler<Integer, Object>() {
@Override
public void failed(Throwable exc, Object attachment) {
if (!clientDone.compareAndSet(false, true)) {
Assertions.fail();
}
latch.countDown();
}

@Override
public void completed(Integer result, Object attachment) {
if (!clientDone.compareAndSet(false, true)) {
Assertions.fail();
}
latch.countDown();
successWrites.increment();
}
});
ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize);
AtomicBoolean serverDone = new AtomicBoolean();
pair.server.external.read(
readBuffer, 100, TimeUnit.MILLISECONDS, null, new CompletionHandler<Integer, Object>() {
@Override
public void failed(Throwable exc, Object attachment) {
if (!serverDone.compareAndSet(false, true)) {
Assertions.fail();
}
latch.countDown();
}

@Override
public void completed(Integer result, Object attachment) {
if (!serverDone.compareAndSet(false, true)) {
Assertions.fail();
}
latch.countDown();
successReads.increment();
}
});
}
try {
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
for (SocketGroups.AsyncSocketPair pair : socketPairs) {
pair.client.external.close();
pair.server.external.close();
}
}

shutdownChannelGroup(channelGroup);
assertChannelGroupConsistency(channelGroup);

assertEquals(0, channelGroup.getFailedReadCount());
assertEquals(0, channelGroup.getFailedWriteCount());

assertEquals(channelGroup.getSuccessfulWriteCount(), successWrites.longValue());
assertEquals(channelGroup.getSuccessfulReadCount(), successReads.longValue());

System.out.printf("success writes: %8d\n", successWrites.longValue());
System.out.printf("success reads: %8d\n", successReads.longValue());
printChannelGroupStatus(channelGroup);
}

// triggered timeout
@Test
public void testTriggeredTimeout() throws IOException {
System.out.println("testScheduledTimeout()");
AsynchronousTlsChannelGroup channelGroup = new AsynchronousTlsChannelGroup();
int successfulWriteCancellations = 0;
int successfulReadCancellations = 0;
for (int i = 1; i <= repetitions; i++) {
int socketPairCount = 50;
List<SocketGroups.AsyncSocketPair> socketPairs =
CollectionConverters.asJava(factory.asyncN(null, channelGroup, socketPairCount, true, false));

for (SocketGroups.AsyncSocketPair pair : socketPairs) {
ByteBuffer writeBuffer = ByteBuffer.allocate(bufferSize);
Future<Integer> writeFuture = pair.client.external.write(writeBuffer);
if (writeFuture.cancel(true)) {
successfulWriteCancellations += 1;
}
}

for (SocketGroups.AsyncSocketPair pair : socketPairs) {
ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize);
Future<Integer> readFuture = pair.server.external.read(readBuffer);
if (readFuture.cancel(true)) {
successfulReadCancellations += 1;
}
}

for (SocketGroups.AsyncSocketPair pair : socketPairs) {
pair.client.external.close();
pair.server.external.close();
}
}
shutdownChannelGroup(channelGroup);
assertChannelGroupConsistency(channelGroup);

assertEquals(0, channelGroup.getFailedReadCount());
assertEquals(0, channelGroup.getFailedWriteCount());

assertEquals(channelGroup.getCancelledWriteCount(), successfulWriteCancellations);
assertEquals(channelGroup.getCancelledReadCount(), successfulReadCancellations);

System.out.printf("success writes: %8d\n", channelGroup.getSuccessfulWriteCount());
System.out.printf("success reads: %8d\n", channelGroup.getSuccessfulReadCount());
}
}
Loading

0 comments on commit 29578e0

Please sign in to comment.