Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to Java: SslContextFactory #223

Merged
merged 1 commit into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/test/scala/tlschannel/BlockingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class BlockingTest {
@TestFactory
public Collection<DynamicTest> testHalfDuplexWireRenegotiations() {
System.out.println("testHalfDuplexWireRenegotiations():");
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize() * 2, x -> x * 2)
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize * 2, x -> x * 2)
.collect(Collectors.toList());
List<Integer> reversedSizes = ListUtils.reversed(sizes);
List<DynamicTest> ret = new ArrayList<>();
Expand Down Expand Up @@ -55,7 +55,7 @@ public Collection<DynamicTest> testHalfDuplexWireRenegotiations() {
@TestFactory
public Collection<DynamicTest> testFullDuplex() {
System.out.println("testFullDuplex():");
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize() * 2, x -> x * 2)
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize * 2, x -> x * 2)
.collect(Collectors.toList());
List<Integer> reversedSizes = ListUtils.reversed(sizes);
List<DynamicTest> ret = new ArrayList<>();
Expand Down
7 changes: 2 additions & 5 deletions src/test/scala/tlschannel/CipherTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.junit.jupiter.api.TestInstance.Lifecycle;
import scala.Option;
import scala.Some;
import scala.jdk.CollectionConverters;
import tlschannel.helpers.Loops;
import tlschannel.helpers.SocketGroups.SocketPair;
import tlschannel.helpers.SocketPairFactory;
Expand Down Expand Up @@ -44,8 +43,7 @@ public Collection<DynamicTest> testHalfDuplexWithRenegotiation() {
List<DynamicTest> tests = new ArrayList<>();
for (String protocol : protocols) {
SslContextFactory ctxFactory = new SslContextFactory(protocol);
for (String cipher :
CollectionConverters.SeqHasAsJava(ctxFactory.allCiphers()).asJava()) {
for (String cipher : ctxFactory.getAllCiphers()) {
tests.add(DynamicTest.dynamicTest(
String.format("testHalfDuplexWithRenegotiation() - protocol: %s, cipher: %s", protocol, cipher),
() -> {
Expand Down Expand Up @@ -73,8 +71,7 @@ public Collection<DynamicTest> testFullDuplex() {
List<DynamicTest> tests = new ArrayList<>();
for (String protocol : protocols) {
SslContextFactory ctxFactory = new SslContextFactory(protocol);
for (String cipher :
CollectionConverters.SeqHasAsJava(ctxFactory.allCiphers()).asJava()) {
for (String cipher : ctxFactory.getAllCiphers()) {
tests.add(DynamicTest.dynamicTest(
String.format("testFullDuplex() - protocol: %s, cipher: %s", protocol, cipher), () -> {
SocketPairFactory socketFactory = new SocketPairFactory(ctxFactory.defaultContext());
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/tlschannel/InteroperabilityTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ public class InteroperabilityTest {

private final SslContextFactory sslContextFactory = new SslContextFactory();
private final SocketPairFactory factory =
new SocketPairFactory(sslContextFactory.defaultContext(), SslContextFactory.certificateCommonName());
new SocketPairFactory(sslContextFactory.defaultContext(), SslContextFactory.certificateCommonName);

private final Random random = new Random();

private final int dataSize = SslContextFactory.tlsMaxDataSize() * 10;
private final int dataSize = SslContextFactory.tlsMaxDataSize * 10;

private final byte[] data = new byte[dataSize];

Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/tlschannel/NonBlockingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class NonBlockingTest {
@TestFactory
public Collection<DynamicTest> testSelectorLoop() {
System.out.println("testSelectorLoop():");
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize() * 2, x -> x * 2)
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize * 2, x -> x * 2)
.collect(Collectors.toList());
List<Integer> reversedSizes = ListUtils.reversed(sizes);
List<DynamicTest> ret = new ArrayList<>();
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/tlschannel/NullEngineTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class NullEngineTest {
@TestFactory
public Collection<DynamicTest> testHalfDuplexHeapBuffers() {
System.out.println("testHalfDuplexHeapBuffers():");
List<Integer> sizes = StreamUtils.iterate(512, x -> x < SslContextFactory.tlsMaxDataSize() * 2, x -> x * 2)
List<Integer> sizes = StreamUtils.iterate(512, x -> x < SslContextFactory.tlsMaxDataSize * 2, x -> x * 2)
.collect(Collectors.toList());
List<DynamicTest> tests = new ArrayList<>();
for (int size1 : sizes) {
Expand All @@ -66,7 +66,7 @@ public Collection<DynamicTest> testHalfDuplexHeapBuffers() {
@TestFactory
public Collection<DynamicTest> testHalfDuplexDirectBuffers() {
System.out.println("testHalfDuplexDirectBuffers():");
List<Integer> sizes = StreamUtils.iterate(512, x -> x < SslContextFactory.tlsMaxDataSize() * 2, x -> x * 2)
List<Integer> sizes = StreamUtils.iterate(512, x -> x < SslContextFactory.tlsMaxDataSize * 2, x -> x * 2)
.collect(Collectors.toList());
List<DynamicTest> tests = new ArrayList<>();
for (int size1 : sizes) {
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/tlschannel/async/PseudoAsyncTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class PseudoAsyncTest {
// test a half-duplex interaction, with renegotiation before reversing the direction of the flow (as in HTTP)
@TestFactory
public Collection<DynamicTest> testHalfDuplex() {
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize() * 2, x -> x * 2)
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize * 2, x -> x * 2)
.collect(Collectors.toList());
List<Integer> reversedSizes = ListUtils.reversed(sizes);
List<DynamicTest> ret = new ArrayList<>();
Expand All @@ -58,7 +58,7 @@ public Collection<DynamicTest> testHalfDuplex() {
// test a full-duplex interaction, without any renegotiation
@TestFactory
public Collection<DynamicTest> testFullDuplex() {
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize() * 2, x -> x * 2)
List<Integer> sizes = StreamUtils.iterate(1, x -> x < SslContextFactory.tlsMaxDataSize * 2, x -> x * 2)
.collect(Collectors.toList());
List<Integer> reversedSizes = ListUtils.reversed(sizes);
List<DynamicTest> ret = new ArrayList<>();
Expand Down
87 changes: 87 additions & 0 deletions src/test/scala/tlschannel/helpers/SslContextFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package tlschannel.helpers;

import java.io.IOException;
import java.io.InputStream;
import java.security.*;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;

public class SslContextFactory {

public static final int tlsMaxDataSize = (int) Math.pow(2, 14);
public static final String certificateCommonName = "name"; // must match what's in the certificates

private final String protocol;
private final SSLContext defaultContext;
private final List<String> allCiphers;

public SslContextFactory(String protocol) {
this.protocol = protocol;
try {
SSLContext sslContext = SSLContext.getInstance(protocol);
KeyStore ks = KeyStore.getInstance("JKS");
try (InputStream keystoreFile = getClass().getClassLoader().getResourceAsStream("keystore.jks")) {
ks.load(keystoreFile, "password".toCharArray());
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(ks);
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(ks, "password".toCharArray());
sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
}
this.defaultContext = sslContext;
} catch (IOException | GeneralSecurityException e) {
throw new RuntimeException(e);
}
this.allCiphers = ciphers(defaultContext).stream().sorted().collect(Collectors.toList());
}

public SslContextFactory() {
this("TLSv1.2");
}

private List<String> ciphers(SSLContext ctx) {
return Arrays.stream(ctx.createSSLEngine().getSupportedCipherSuites())
// this is not a real cipher, but a hack actually
.filter(c -> !Objects.equals(c, "TLS_EMPTY_RENEGOTIATION_INFO_SCSV"))
// disable problematic ciphers
.filter(c -> !Arrays.asList(
"TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5",
"TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA",
"TLS_KRB5_EXPORT_WITH_RC4_40_MD5",
"TLS_KRB5_EXPORT_WITH_RC4_40_SHA",
"TLS_KRB5_WITH_3DES_EDE_CBC_MD5",
"TLS_KRB5_WITH_3DES_EDE_CBC_SHA",
"TLS_KRB5_WITH_DES_CBC_MD5",
"TLS_KRB5_WITH_DES_CBC_SHA",
"TLS_KRB5_WITH_RC4_128_MD5",
"TLS_KRB5_WITH_RC4_128_SHA",
"SSL_RSA_EXPORT_WITH_DES40_CBC_SHA",
"SSL_RSA_EXPORT_WITH_RC4_40_MD5")
.contains(c))
// No SHA-2 with TLS < 1.2
.filter(c -> Arrays.asList("TLSv1.2", "TLSv1.3").contains(protocol)
|| !c.endsWith("_SHA256") && !c.endsWith("_SHA384"))
// Disable cipher only supported in TLS >= 1.3
.filter(c -> protocol.compareTo("TLSv1.3") > 0
|| !Arrays.asList("TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384")
.contains(c))
// https://bugs.openjdk.java.net/browse/JDK-8224997
.filter(c -> !c.endsWith("_CHACHA20_POLY1305_SHA256"))
// Anonymous ciphers are problematic because they are disabled in some VMs
.filter(c -> !c.contains("_anon_"))
.collect(Collectors.toList());
}

public SSLContext defaultContext() {
return defaultContext;
}

public List<String> getAllCiphers() {
return allCiphers;
}
}
82 changes: 0 additions & 82 deletions src/test/scala/tlschannel/helpers/SslContextFactory.scala

This file was deleted.