Skip to content

Commit

Permalink
Migrate to Java: SocketGroups
Browse files Browse the repository at this point in the history
  • Loading branch information
marianobarrios committed Apr 21, 2024
1 parent 2c49608 commit 45c4dfd
Show file tree
Hide file tree
Showing 18 changed files with 134 additions and 92 deletions.
2 changes: 1 addition & 1 deletion src/test/scala/tlschannel/AllocationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import java.lang.management.MemoryMXBean;
import scala.Option;
import tlschannel.helpers.Loops;
import tlschannel.helpers.SocketPair;
import tlschannel.helpers.SocketGroups.SocketPair;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SslContextFactory;

Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/tlschannel/BlockingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import org.junit.jupiter.api.TestInstance.Lifecycle;
import scala.Option;
import tlschannel.helpers.Loops;
import tlschannel.helpers.SocketPair;
import tlschannel.helpers.SocketGroups.SocketPair;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SocketPairFactory.ChuckSizes;
import tlschannel.helpers.SocketPairFactory.ChunkSizeConfig;
Expand Down
10 changes: 5 additions & 5 deletions src/test/scala/tlschannel/CipherTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import scala.Some;
import scala.jdk.CollectionConverters;
import tlschannel.helpers.Loops;
import tlschannel.helpers.SocketPair;
import tlschannel.helpers.SocketGroups.SocketPair;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SslContextFactory;

Expand Down Expand Up @@ -54,8 +54,8 @@ public Collection<DynamicTest> testHalfDuplexWithRenegotiation() {
Some.apply(cipher), Option.apply(null), true, false, Option.apply(null));
Loops.halfDuplex(socketPair, dataSize, protocol.compareTo("TLSv1.2") < 0, false);
String actualProtocol = socketPair
.client()
.tls()
.client
.tls
.getSslEngine()
.getSession()
.getProtocol();
Expand All @@ -82,8 +82,8 @@ public Collection<DynamicTest> testFullDuplex() {
Some.apply(cipher), Option.apply(null), true, false, Option.apply(null));
Loops.fullDuplex(socketPair, dataSize);
String actualProtocol = socketPair
.client()
.tls()
.client
.tls
.getSslEngine()
.getSession()
.getProtocol();
Expand Down
31 changes: 14 additions & 17 deletions src/test/scala/tlschannel/ConcurrentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static tlschannel.helpers.SocketGroups.*;

import java.io.IOException;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -31,15 +32,11 @@ public class ConcurrentTest {
@Test
public void testWriteSide() throws IOException {
SocketPair socketPair = factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null));
Thread clientWriterThread1 =
new Thread(() -> writerLoop(dataSize, 'a', socketPair.client()), "client-writer-1");
Thread clientWriterThread2 =
new Thread(() -> writerLoop(dataSize, 'b', socketPair.client()), "client-writer-2");
Thread clientWriterThread3 =
new Thread(() -> writerLoop(dataSize, 'c', socketPair.client()), "client-writer-3");
Thread clientWriterThread4 =
new Thread(() -> writerLoop(dataSize, 'd', socketPair.client()), "client-writer-4");
Thread serverReaderThread = new Thread(() -> readerLoop(dataSize * 4, socketPair.server()), "server-reader");
Thread clientWriterThread1 = new Thread(() -> writerLoop(dataSize, 'a', socketPair.client), "client-writer-1");
Thread clientWriterThread2 = new Thread(() -> writerLoop(dataSize, 'b', socketPair.client), "client-writer-2");
Thread clientWriterThread3 = new Thread(() -> writerLoop(dataSize, 'c', socketPair.client), "client-writer-3");
Thread clientWriterThread4 = new Thread(() -> writerLoop(dataSize, 'd', socketPair.client), "client-writer-4");
Thread serverReaderThread = new Thread(() -> readerLoop(dataSize * 4, socketPair.server), "server-reader");
Stream.of(
serverReaderThread,
clientWriterThread1,
Expand All @@ -49,7 +46,7 @@ public void testWriteSide() throws IOException {
.forEach(t -> t.start());
Stream.of(clientWriterThread1, clientWriterThread2, clientWriterThread3, clientWriterThread4)
.forEach(t -> joinInterruptible(t));
socketPair.client().external().close();
socketPair.client.external.close();
joinInterruptible(serverReaderThread);
SocketPairFactory.checkDeallocation(socketPair);
}
Expand All @@ -58,15 +55,15 @@ public void testWriteSide() throws IOException {
@Test
public void testReadSide() throws IOException {
SocketPair socketPair = factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null));
Thread clientWriterThread = new Thread(() -> writerLoop(dataSize, 'a', socketPair.client()), "client-writer");
Thread clientWriterThread = new Thread(() -> writerLoop(dataSize, 'a', socketPair.client), "client-writer");
AtomicLong totalRead = new AtomicLong();
Thread serverReaderThread1 =
new Thread(() -> readerLoopUntilEof(socketPair.server(), totalRead), "server-reader-1");
new Thread(() -> readerLoopUntilEof(socketPair.server, totalRead), "server-reader-1");
Thread serverReaderThread2 =
new Thread(() -> readerLoopUntilEof(socketPair.server(), totalRead), "server-reader-2");
new Thread(() -> readerLoopUntilEof(socketPair.server, totalRead), "server-reader-2");
Stream.of(serverReaderThread1, serverReaderThread2, clientWriterThread).forEach(t -> t.start());
joinInterruptible(clientWriterThread);
socketPair.client().external().close();
socketPair.client.external.close();
Stream.of(serverReaderThread1, serverReaderThread2).forEach(t -> joinInterruptible(t));
SocketPairFactory.checkDeallocation(socketPair);
assertEquals(dataSize, totalRead.get());
Expand All @@ -82,7 +79,7 @@ private void writerLoop(int size, char ch, SocketGroup socketGroup) {
while (bytesRemaining > 0) {
ByteBuffer buffer = ByteBuffer.wrap(bufferArray, 0, Math.min(bufferSize, bytesRemaining));
while (buffer.hasRemaining()) {
int c = socketGroup.external().write(buffer);
int c = socketGroup.external.write(buffer);
assertTrue(c > 0, "blocking write must return a positive number");
bytesRemaining -= c;
assertTrue(bytesRemaining >= 0);
Expand All @@ -104,7 +101,7 @@ private void readerLoop(int size, SocketGroup socketGroup) {
int bytesRemaining = size;
while (bytesRemaining > 0) {
ByteBuffer readBuffer = ByteBuffer.wrap(readArray, 0, Math.min(bufferSize, bytesRemaining));
int c = socketGroup.external().read(readBuffer);
int c = socketGroup.external.read(readBuffer);
assertTrue(c > 0, "blocking read must return a positive number");
bytesRemaining -= c;
assertTrue(bytesRemaining >= 0);
Expand All @@ -124,7 +121,7 @@ private void readerLoopUntilEof(SocketGroup socketGroup, AtomicLong accumulator)
byte[] readArray = new byte[bufferSize];
while (true) {
ByteBuffer readBuffer = ByteBuffer.wrap(readArray, 0, bufferSize);
int c = socketGroup.external().read(readBuffer);
int c = socketGroup.external.read(readBuffer);
if (c == -1) {
logger.fine("Finalizing reader loop");
return null;
Expand Down
6 changes: 3 additions & 3 deletions src/test/scala/tlschannel/MultiNonBlockingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import scala.Option;
import scala.collection.immutable.Seq;
import tlschannel.helpers.NonBlockingLoops;
import tlschannel.helpers.SocketPair;
import tlschannel.helpers.SocketGroups.SocketPair;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SslContextFactory;

Expand Down Expand Up @@ -46,7 +46,7 @@ public void testTasksInExecutor() {
@Test
public void testTasksInLoopWithRenegotiation() {
System.out.println("testTasksInExecutorWithRenegotiation():");
Seq<tlschannel.helpers.SocketPair> pairs = factory.nioNioN(
Seq<SocketPair> pairs = factory.nioNioN(
Option.apply(null), totalConnections, Option.apply(null), true, false, Option.apply(null));
NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, true);
assertEquals(0, report.asyncTasksRun());
Expand All @@ -57,7 +57,7 @@ public void testTasksInLoopWithRenegotiation() {
@Test
public void testTasksInExecutorWithRenegotiation() {
System.out.println("testTasksInExecutorWithRenegotiation():");
Seq<tlschannel.helpers.SocketPair> pairs = factory.nioNioN(
Seq<SocketPair> pairs = factory.nioNioN(
Option.apply(null), totalConnections, Option.apply(null), false, false, Option.apply(null));
NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, true);
report.print();
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/tlschannel/NonBlockingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import scala.Some;
import scala.jdk.javaapi.CollectionConverters;
import tlschannel.helpers.NonBlockingLoops;
import tlschannel.helpers.SocketPair;
import tlschannel.helpers.SocketGroups.SocketPair;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SocketPairFactory.ChuckSizes;
import tlschannel.helpers.SocketPairFactory.ChunkSizeConfig;
Expand Down
5 changes: 3 additions & 2 deletions src/test/scala/tlschannel/NullEngineTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import scala.Option;
import scala.Some;
import tlschannel.helpers.Loops;
import tlschannel.helpers.SocketGroups;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SocketPairFactory.ChuckSizes;
import tlschannel.helpers.SocketPairFactory.ChunkSizeConfig;
Expand Down Expand Up @@ -44,7 +45,7 @@ public Collection<DynamicTest> testHalfDuplexHeapBuffers() {
List<DynamicTest> tests = new ArrayList<>();
for (int size1 : sizes) {
DynamicTest test = DynamicTest.dynamicTest(String.format("Testing sizes: size1=%s", size1), () -> {
tlschannel.helpers.SocketPair socketPair = factory.nioNio(
SocketGroups.SocketPair socketPair = factory.nioNio(
null,
Some.apply(new ChunkSizeConfig(
new ChuckSizes(Some.apply(size1), Option.apply(null)),
Expand All @@ -69,7 +70,7 @@ public Collection<DynamicTest> testHalfDuplexDirectBuffers() {
List<DynamicTest> tests = new ArrayList<>();
for (int size1 : sizes) {
DynamicTest test = DynamicTest.dynamicTest(String.format("Testing sizes: size1=%s", size1), () -> {
tlschannel.helpers.SocketPair socketPair = factory.nioNio(
SocketGroups.SocketPair socketPair = factory.nioNio(
null,
Some.apply(new ChunkSizeConfig(
new ChuckSizes(Some.apply(size1), Option.apply(null)),
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/tlschannel/NullMultiNonBlockingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import scala.Option;
import scala.collection.immutable.Seq;
import tlschannel.helpers.NonBlockingLoops;
import tlschannel.helpers.SocketPair;
import tlschannel.helpers.SocketGroups.SocketPair;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SslContextFactory;

Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/tlschannel/ScatteringTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import org.junit.jupiter.api.TestInstance.Lifecycle;
import scala.Option;
import tlschannel.helpers.Loops;
import tlschannel.helpers.SocketPair;
import tlschannel.helpers.SocketGroups.SocketPair;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SslContextFactory;

Expand Down
14 changes: 7 additions & 7 deletions src/test/scala/tlschannel/async/AsyncCloseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import tlschannel.helpers.AsyncSocketPair;
import tlschannel.helpers.SocketGroups.AsyncSocketPair;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SslContextFactory;

Expand All @@ -40,9 +40,9 @@ public void testClosingWhileReading() throws IOException, InterruptedException {
AsyncSocketPair socketPair = factory.async(null, channelGroup, true, false);

ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize);
Future<Integer> readFuture = socketPair.server().external().read(readBuffer);
Future<Integer> readFuture = socketPair.server.external.read(readBuffer);

socketPair.server().external().close();
socketPair.server.external.close();

try {
readFuture.get(1000, TimeUnit.MILLISECONDS);
Expand All @@ -61,7 +61,7 @@ public void testClosingWhileReading() throws IOException, InterruptedException {
Assertions.fail(e);
}

socketPair.client().external().close();
socketPair.client.external.close();
shutdownChannelGroup(channelGroup);
assertChannelGroupConsistency(channelGroup);
assertEquals(0, channelGroup.getSuccessfulReadCount());
Expand All @@ -78,10 +78,10 @@ public void testRawClosingWhileReading() throws IOException, InterruptedExceptio
AsyncSocketPair socketPair = factory.async(null, channelGroup, true, false);

ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize);
Future<Integer> readFuture = socketPair.server().external().read(readBuffer);
Future<Integer> readFuture = socketPair.server.external.read(readBuffer);

// important: closing the raw socket
socketPair.server().plain().close();
socketPair.server.plain.close();

try {
readFuture.get(1000, TimeUnit.MILLISECONDS);
Expand All @@ -100,7 +100,7 @@ public void testRawClosingWhileReading() throws IOException, InterruptedExceptio
Assertions.fail(e);
}

socketPair.client().external().close();
socketPair.client.external.close();
shutdownChannelGroup(channelGroup);
assertChannelGroupConsistency(channelGroup);
assertEquals(0, channelGroup.getSuccessfulReadCount());
Expand Down
19 changes: 9 additions & 10 deletions src/test/scala/tlschannel/async/AsyncShutdownTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}

import java.nio.ByteBuffer
import java.util.concurrent.TimeUnit
import tlschannel.helpers.AsyncSocketPair
import tlschannel.helpers.SocketPairFactory
import tlschannel.helpers.SslContextFactory
import org.junit.jupiter.api.{Test, TestInstance}
Expand All @@ -24,11 +23,11 @@ class AsyncShutdownTest extends AsyncTestBase {
val channelGroup = new AsynchronousTlsChannelGroup()
val socketPairCount = 50
val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true)
for (AsyncSocketPair(client, server) <- socketPairs) {
for (pair <- socketPairs) {
val writeBuffer = ByteBuffer.allocate(bufferSize)
client.external.write(writeBuffer)
pair.client.external.write(writeBuffer)
val readBuffer = ByteBuffer.allocate(bufferSize)
server.external.read(readBuffer)
pair.server.external.read(readBuffer)
}

assertFalse(channelGroup.isTerminated)
Expand All @@ -50,11 +49,11 @@ class AsyncShutdownTest extends AsyncTestBase {
val channelGroup = new AsynchronousTlsChannelGroup()
val socketPairCount = 50
val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true)
for (AsyncSocketPair(client, server) <- socketPairs) {
for (pair <- socketPairs) {
val writeBuffer = ByteBuffer.allocate(bufferSize)
client.external.write(writeBuffer)
pair.client.external.write(writeBuffer)
val readBuffer = ByteBuffer.allocate(bufferSize)
server.external.read(readBuffer)
pair.server.external.read(readBuffer)
}

assertFalse(channelGroup.isTerminated)
Expand All @@ -68,9 +67,9 @@ class AsyncShutdownTest extends AsyncTestBase {
assertFalse(channelGroup.isTerminated)
}

for (AsyncSocketPair(client, server) <- socketPairs) {
client.external.close()
server.external.close()
for (pair <- socketPairs) {
pair.client.external.close()
pair.server.external.close()
}

{
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/tlschannel/async/AsyncTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import scala.Option;
import scala.collection.immutable.Seq;
import tlschannel.helpers.AsyncLoops;
import tlschannel.helpers.AsyncSocketPair;
import tlschannel.helpers.SocketGroups.AsyncSocketPair;
import tlschannel.helpers.SocketPairFactory;
import tlschannel.helpers.SslContextFactory;

Expand Down
25 changes: 12 additions & 13 deletions src/test/scala/tlschannel/async/AsyncTimeoutTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.LongAdder
import tlschannel.helpers.AsyncSocketPair
import tlschannel.helpers.SocketPairFactory
import tlschannel.helpers.SslContextFactory

Expand All @@ -35,10 +34,10 @@ class AsyncTimeoutTest extends AsyncTestBase {
val socketPairCount = 50
val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true)
val latch = new CountDownLatch(socketPairCount * 2)
for (AsyncSocketPair(client, server) <- socketPairs) {
for (pair <- socketPairs) {
val writeBuffer = ByteBuffer.allocate(bufferSize)
val clientDone = new AtomicBoolean
client.external.write(
pair.client.external.write(
writeBuffer,
50,
TimeUnit.MILLISECONDS,
Expand All @@ -62,7 +61,7 @@ class AsyncTimeoutTest extends AsyncTestBase {
)
val readBuffer = ByteBuffer.allocate(bufferSize)
val serverDone = new AtomicBoolean
server.external.read(
pair.server.external.read(
readBuffer,
100,
TimeUnit.MILLISECONDS,
Expand All @@ -86,9 +85,9 @@ class AsyncTimeoutTest extends AsyncTestBase {
)
}
latch.await()
for (AsyncSocketPair(client, server) <- socketPairs) {
client.external.close()
server.external.close()
for (pair <- socketPairs) {
pair.client.external.close()
pair.server.external.close()
}
}

Expand Down Expand Up @@ -116,11 +115,11 @@ class AsyncTimeoutTest extends AsyncTestBase {
for (_ <- 1 to repetitions) {
val socketPairCount = 50
val socketPairs = factory.asyncN(null, channelGroup, socketPairCount, runTasks = true)
val futures = for (AsyncSocketPair(client, server) <- socketPairs) yield {
val futures = for (pair <- socketPairs) yield {
val writeBuffer = ByteBuffer.allocate(bufferSize)
val writeFuture = client.external.write(writeBuffer)
val writeFuture = pair.client.external.write(writeBuffer)
val readBuffer = ByteBuffer.allocate(bufferSize)
val readFuture = server.external.read(readBuffer)
val readFuture = pair.server.external.read(readBuffer)
(writeFuture, readFuture)
}

Expand All @@ -132,9 +131,9 @@ class AsyncTimeoutTest extends AsyncTestBase {
successfulReadCancellations += 1
}
}
for (AsyncSocketPair(client, server) <- socketPairs) {
client.external.close()
server.external.close()
for (pair <- socketPairs) {
pair.client.external.close()
pair.server.external.close()
}
}
shutdownChannelGroup(channelGroup)
Expand Down
Loading

0 comments on commit 45c4dfd

Please sign in to comment.