Skip to content

Commit

Permalink
refactor: make minor tweaks to improve code coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Zmax0 committed Jul 3, 2024
1 parent 290f11d commit 2bb41a9
Show file tree
Hide file tree
Showing 15 changed files with 112 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public static SslHandler buildSslHandler(Channel ch, ServerConfig config) throws
}

public static WebSocketClientProtocolHandler buildWebSocketHandler(ServerConfig config) throws URISyntaxException {
Optional<WebSocketSetting> ws = Optional.of(config.getWs());
Optional<WebSocketSetting> ws = Optional.ofNullable(config.getWs());
String path = ws.map(WebSocketSetting::getPath).orElseThrow(() -> new IllegalArgumentException("required path not present"));
WebSocketClientProtocolConfig.Builder builder = WebSocketClientProtocolConfig.newBuilder()
.webSocketUri(new URI("ws", null, config.getHost(), config.getPort(), path, null, null));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,7 @@ protected void initChannel(Channel ch) throws URISyntaxException {
new HttpClientCodec(),
new HttpObjectAggregator(0xffff),
ClientSocksInitializer.buildWebSocketHandler(config),
new MessageToMessageCodec<BinaryWebSocketFrame, ByteBuf>() {
private ChannelPromise promise;

@Override
protected void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
BinaryWebSocketFrame frame = new BinaryWebSocketFrame(msg.retain());
if (!promise.isDone()) {
promise.addListener(f -> ctx.writeAndFlush(frame));
} else {
out.add(frame);
}
}

@Override
protected void decode(ChannelHandlerContext ctx, BinaryWebSocketFrame msg, List<Object> out) {
out.add(msg.retain().content());
}

@Override
public void handlerAdded(ChannelHandlerContext ctx) {
promise = ctx.newPromise();
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
promise.setSuccess();
}
ctx.fireUserEventTriggered(evt);
}
}
new WebSocketCodec()
);
}
pipeline.addLast(new ClientAeadCodec(config.getCipher(), RequestCommand.UDP, key.recipient, config.getPassword()));
Expand Down Expand Up @@ -127,4 +97,36 @@ public void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) {
channel.writeAndFlush(new DatagramPacketWrapper(new DatagramPacket(msg, recipient), sender));
}
}

private static class WebSocketCodec extends MessageToMessageCodec<BinaryWebSocketFrame, ByteBuf> {
private ChannelPromise promise;

@Override
public void handlerAdded(ChannelHandlerContext ctx) {
promise = ctx.newPromise();
}

@Override
protected void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
BinaryWebSocketFrame frame = new BinaryWebSocketFrame(msg.retain());
if (!promise.isDone()) {
promise.addListener(f -> ctx.writeAndFlush(frame));
} else {
out.add(frame);
}
}

@Override
protected void decode(ChannelHandlerContext ctx, BinaryWebSocketFrame msg, List<Object> out) {
out.add(msg.retain().content());
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
promise.setSuccess();
}
ctx.fireUserEventTriggered(evt);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (Mode.Server == session.mode() && cause instanceof DecoderException && ctx.channel() instanceof SocketChannel socketChannel) {
String transLog = ExceptionHandler.transLog(ctx.channel());
if (Mode.Server == session.mode() && cause instanceof DecoderException) {
SocketChannel channel = (SocketChannel) ctx.channel();
String transLog = ExceptionHandler.transLog(channel);
logger.error("[tcp][{}] {}", transLog, cause.getMessage());
ctx.deregister();
socketChannel.config().setSoLinger(0);
socketChannel.shutdownOutput().addListener(future -> socketChannel.unsafe().beginRead());
channel.config().setSoLinger(0);
channel.shutdownOutput().addListener(future -> channel.unsafe().beginRead());
} else {
ctx.fireExceptionCaught(cause);
}
Expand Down
2 changes: 1 addition & 1 deletion urban-spork-server/resource/logback.xml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
</appender>

<logger name="io.netty" level="DEBUG"/>
<logger name="io.netty.handler.ssl" level="INFO"/>
<logger name="io.netty.handler" level="INFO"/>
<logger name="com.urbanspork" level="INFO"/>

<root level="WARN">
Expand Down
24 changes: 12 additions & 12 deletions urban-spork-server/src/com/urbanspork/server/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,18 @@ private static Instance startup(EventLoopGroup bossGroup, EventLoopGroup workerG
}
}
Context context = Context.newCheckReplayInstance();
ServerSocketChannel tcp;
try {
tcp = (ServerSocketChannel) new ServerBootstrap().group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.childHandler(new ServerInitializer(config, context))
.childOption(ChannelOption.ALLOW_HALF_CLOSURE, true)
.bind(config.getPort()).sync().addListener(future -> logger.info("Startup tcp server => {}", config)).channel()
.closeFuture().addListener(future -> context.release()).channel();
} catch (Exception e) {
context.release();
throw e;
}
ServerSocketChannel tcp = (ServerSocketChannel) new ServerBootstrap()
.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.childHandler(new ServerInitializer(config, context))
.childOption(ChannelOption.ALLOW_HALF_CLOSURE, true)
.bind(config.getPort()).addListener(future -> {
if (!future.isSuccess()) {
context.release();
}
})
.sync().addListener(future -> logger.info("Startup tcp server => {}", config))
.channel().closeFuture().addListener(future -> context.release()).channel();
config.setPort(tcp.localAddress().getPort());
Optional<DatagramChannel> udp = startupUdp(bossGroup, workerGroup, config);
return new Instance(tcp, udp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ protected void initChannel(Channel c) throws SSLException {
case vmess -> c.pipeline().addLast(new ServerAeadCodec(config), new ExceptionHandler(config), new ServerRelayHandler(config));
case trojan -> {
String serverName = config.getHost();
SslSetting sslSetting = Optional.ofNullable(config.getSsl())
.orElseThrow(() -> new IllegalArgumentException("required security setting not present"));
SslSetting sslSetting = Optional.ofNullable(config.getSsl()).orElseThrow(() -> new IllegalArgumentException("required security setting not present"));
SslContext sslContext = SslContextBuilder.forServer(new File(sslSetting.getCertificateFile()), new File(sslSetting.getKeyFile()), sslSetting.getKeyPassword()).build();
if (sslSetting.getServerName() != null) {
serverName = sslSetting.getServerName();
Expand All @@ -60,7 +59,7 @@ protected void initChannel(Channel c) throws SSLException {
}

private void enableWebSocket(Channel channel) {
String path = Optional.of(config.getWs()).map(WebSocketSetting::getPath).orElseThrow(() -> new IllegalArgumentException("required path not present"));
String path = Optional.ofNullable(config.getWs()).map(WebSocketSetting::getPath).orElseThrow(() -> new IllegalArgumentException("required path not present"));
channel.pipeline().addLast(
new HttpServerCodec(),
new HttpObjectAggregator(0xffff),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
public class DelayedEchoTestServer {

public static final int PORT = 16801;
public static final int MAX_DELAYED_SECOND = 5;
public static final int MAX_DELAYED_SECOND = 2;

public static void main(String[] args) throws IOException {
launch(PORT, new CompletableFuture<>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ void testBuildSslHandler() {

@Test
void testBuildWebSocketHandler() {
ServerConfig config = ServerConfigTest.testConfig(0);
Assertions.assertThrows(NullPointerException.class, () -> ClientSocksInitializer.buildWebSocketHandler(config));
WebSocketSetting webSocket = new WebSocketSetting();
ServerConfig config = ServerConfigTest.testConfig(0);
config.setWs(webSocket);
Assertions.assertThrows(IllegalArgumentException.class, () -> ClientSocksInitializer.buildWebSocketHandler(config));
webSocket.setPath("/ws");
webSocket.setHeader(Map.of("Host", "localhost"));
config.setWs(webSocket);
Assertions.assertDoesNotThrow(() -> ClientSocksInitializer.buildWebSocketHandler(config));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.DatagramPacket;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.socksx.v5.DefaultSocks5CommandRequest;
import io.netty.handler.codec.socksx.v5.Socks5AddressType;
Expand Down Expand Up @@ -112,32 +111,4 @@ void testAead2022UdpAntiReplay() {
client.close();
server.close();
}

@Test
void testAead2022TcpAntiReplay() {
EmbeddedChannel server1 = new EmbeddedChannel();
EmbeddedChannel server2 = new EmbeddedChannel();
EmbeddedChannel client = new EmbeddedChannel();
CipherKind kind = CipherKind.aead2022_blake3_aes_256_gcm;
ServerConfig config = new ServerConfig();
config.setPassword(TestDice.rollPassword(Protocol.shadowsocks, kind));
config.setCipher(kind);
Context context = Context.newCheckReplayInstance();
server1.pipeline().addLast(new TcpRelayCodec(context, config, Mode.Server));
server2.pipeline().addLast(new TcpRelayCodec(context, config, Mode.Server));
DefaultSocks5CommandRequest request = new DefaultSocks5CommandRequest(Socks5CommandType.CONNECT, Socks5AddressType.DOMAIN, "localhost", 16800);
client.pipeline().addLast(new TcpRelayCodec(context, config, request, Mode.Client));
client.writeOutbound(Unpooled.wrappedBuffer(Dice.rollBytes(10)));
ByteBuf msg1 = client.readOutbound();
ByteBuf msg2 = msg1.copy();
Assertions.assertTrue(msg1.isReadable());
Assertions.assertTrue(msg2.isReadable());
server1.writeInbound(msg1);
ByteBuf tooShortMsg = Unpooled.wrappedBuffer(Dice.rollBytes(33));
Assertions.assertThrows(DecoderException.class, () -> server2.writeInbound(tooShortMsg));
Assertions.assertThrows(DecoderException.class, () -> server2.writeInbound(msg2));
client.close();
server1.close();
server2.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import io.netty.handler.codec.socksx.v5.Socks5CommandType;
import org.bouncycastle.crypto.InvalidCipherTextException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Base64;
import java.util.List;

@DisplayName("tcp.AeadCipherCodecTest")
class AeadCipherCodecTest extends TraceLevelLoggerTestTemplate {

@Test
Expand Down Expand Up @@ -70,6 +73,32 @@ void testUnexpectedStreamType() throws InvalidCipherTextException {
Assertions.assertThrows(DecoderException.class, () -> codec.decode(session, temp, out));
}

@Test
void testAead2022TcpAntiReplay() {
CipherKind kind = CipherKind.aead2022_blake3_aes_256_gcm;
ServerConfig config = new ServerConfig();
config.setPassword(TestDice.rollPassword(Protocol.shadowsocks, kind));
config.setCipher(kind);
Context context = Context.newCheckReplayInstance();
DefaultSocks5CommandRequest request = new DefaultSocks5CommandRequest(Socks5CommandType.CONNECT, Socks5AddressType.DOMAIN, "localhost", 16800);
ByteBuf msg1 = Unpooled.buffer();
Identity identity = new Identity(kind);
Session clientSession = new Session(Mode.Client, identity, request, ServerUserManager.EMPTY, context);
AeadCipherCodec clientCodec = AeadCipherCodecs.get(config);
Assertions.assertDoesNotThrow(() -> clientCodec.encode(clientSession, Unpooled.wrappedBuffer(Dice.rollBytes(10)), msg1));
ByteBuf msg2 = msg1.copy();
Assertions.assertTrue(msg1.isReadable());
Assertions.assertTrue(msg2.isReadable());
ByteBuf tooShortMsg = Unpooled.wrappedBuffer(Dice.rollBytes(33));
Session serverSession = new Session(Mode.Server, identity, request, ServerUserManager.EMPTY, context);
List<Object> out = new ArrayList<>();
AeadCipherCodec serverCodec1 = AeadCipherCodecs.get(config);
Assertions.assertThrows(IndexOutOfBoundsException.class, () -> serverCodec1.decode(serverSession, tooShortMsg, out));
Assertions.assertDoesNotThrow(() -> serverCodec1.decode(serverSession, msg1, out));
AeadCipherCodec serverCodec2 = AeadCipherCodecs.get(config);
Assertions.assertThrows(DecoderException.class, () -> serverCodec2.decode(serverSession, msg2, out));
}

@Override
protected Class<?> loggerClass() {
return AeadCipherCodec.class;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -33,8 +32,8 @@
import java.util.concurrent.ThreadLocalRandom;

@TestInstance(Lifecycle.PER_CLASS)
@DisplayName("tcp.AeadCipherCodecsTest")
class AeadCipherCodecsTest {

private final byte[] in = Dice.rollBytes(0xffff * 10);
private CipherKind kind;
private String password;
Expand All @@ -46,17 +45,9 @@ void beforeAll() {
logger.setLevel(Level.TRACE);
}

@DisplayName("Single cipher")
@Test
void test() throws Exception {
parameterizedTest(CipherKind.chacha20_poly1305);
parameterizedTest(CipherKind.aead2022_blake3_aes_256_gcm);
}

@ParameterizedTest
@DisplayName("All supported cipher iterate")
@EnumSource(CipherKind.class)
void parameterizedTest(CipherKind kind) throws Exception {
void testByKind(CipherKind kind) throws Exception {
this.password = TestDice.rollPassword(Protocol.shadowsocks, kind);
this.kind = kind;
int port = TestDice.rollPort();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
package com.urbanspork.common.codec.shadowsocks.tcp;

import com.urbanspork.common.codec.shadowsocks.Mode;
import com.urbanspork.common.config.ServerConfig;
import com.urbanspork.common.config.ServerConfigTest;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

class TcpRelayCodecTest {
@Test
void testCaughtOtherException() {
ServerConfig config = ServerConfigTest.testConfig(0);
TcpRelayCodec codec = new TcpRelayCodec(new Context(), config, Mode.Server);
testCaughtOtherException(Mode.Server, new RuntimeException());
testCaughtOtherException(Mode.Client, new DecoderException());
}

void testCaughtOtherException(Mode mode, Throwable throwable) {
TcpRelayCodec codec = new TcpRelayCodec(new Context(), ServerConfigTest.testConfig(0), mode);
EmbeddedChannel channel = new EmbeddedChannel(codec);
codec.exceptionCaught(channel.pipeline().context(codec), new RuntimeException());
Assertions.assertThrows(RuntimeException.class, channel::checkException);
codec.exceptionCaught(channel.pipeline().context(codec), throwable);
Assertions.assertThrows(throwable.getClass(), channel::checkException);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import io.netty.handler.codec.DecoderException;
import org.bouncycastle.crypto.InvalidCipherTextException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.net.InetSocketAddress;
import java.util.Base64;

@DisplayName("udp.AeadCipherCodecTest")
class AeadCipherCodecTest extends TraceLevelLoggerTestTemplate {

@Test
Expand All @@ -41,13 +43,18 @@ void testTooShortHeader() {

@Test
void testEmptyMsg() throws InvalidCipherTextException {
testEmptyMsg(Mode.Client, Mode.Server);
testEmptyMsg(Mode.Server, Mode.Client);
}

void testEmptyMsg(Mode from, Mode to) throws InvalidCipherTextException {
AeadCipherCodec codec = newAEADCipherCodec();
InetSocketAddress address = InetSocketAddress.createUnresolved(TestDice.rollHost(), TestDice.rollPort());
ByteBuf in = Unpooled.buffer();
CipherKind kind = TestDice.rollCipher();
codec.encode(new Context(Mode.Client, new Control(kind), address, ServerUserManager.EMPTY), Unpooled.EMPTY_BUFFER, in);
codec.encode(new Context(from, new Control(kind), address, ServerUserManager.EMPTY), Unpooled.EMPTY_BUFFER, in);
Assertions.assertTrue(in.isReadable());
RelayingPacket<ByteBuf> pocket = codec.decode(new Context(Mode.Server, new Control(kind), address, ServerUserManager.EMPTY), in);
RelayingPacket<ByteBuf> pocket = codec.decode(new Context(to, new Control(kind), address, ServerUserManager.EMPTY), in);
Assertions.assertFalse(in.isReadable());
Assertions.assertNotNull(pocket);
}
Expand Down
Loading

0 comments on commit 2bb41a9

Please sign in to comment.