Skip to content

Commit

Permalink
fix(ss): incorrect udp relay behavior
Browse files Browse the repository at this point in the history
1. packet id filter
2. peer address mapping
  • Loading branch information
Zmax0 committed Sep 14, 2024
1 parent f0e7764 commit b4bdcd4
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
public class AttributeKeys {

public static final AttributeKey<ServerConfig> SERVER_CONFIG = AttributeKey.newInstance("SERVER_CONFIG");
public static final AttributeKey<Object> SERVER_UDP_RELAY_WORKER = AttributeKey.newInstance("SERVER_UDP_RELAY_WORKER");

private AttributeKeys() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.urbanspork.common.protocol.shadowsocks.aead.AEAD;
import com.urbanspork.common.protocol.socks.Address;
import com.urbanspork.common.transport.udp.RelayingPacket;
import com.urbanspork.common.util.Dice;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.bouncycastle.crypto.InvalidCipherTextException;
Expand All @@ -30,7 +31,7 @@ public AeadCipherCodecImpl(CipherMethod cipherMethod, Keys keys) {
@Override
public void encode(Context context, ByteBuf msg, ByteBuf out) throws InvalidCipherTextException {
InetSocketAddress address = context.address();
byte[] salt = context.control().salt();
byte[] salt = Dice.rollBytes(cipherMethod.keySize());
out.writeBytes(salt);
ByteBuf temp = Unpooled.buffer(Address.getLength(address));
Address.encode(address, temp);
Expand All @@ -40,7 +41,7 @@ public void encode(Context context, ByteBuf msg, ByteBuf out) throws InvalidCiph

@Override
public RelayingPacket<ByteBuf> decode(Context context, ByteBuf in) throws InvalidCipherTextException {
byte[] salt = context.control().salt();
byte[] salt = new byte[cipherMethod.keySize()];
in.readBytes(salt);
ByteBuf packet = AEAD.UDP.newPayloadDecoder(cipherMethod, keys.encKey(), salt).decodePacket(in);
InetSocketAddress address = Address.decode(packet);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,41 +1,43 @@
package com.urbanspork.common.codec.shadowsocks.udp;

import com.urbanspork.common.channel.AttributeKeys;
import com.urbanspork.common.codec.shadowsocks.Mode;
import com.urbanspork.common.config.ServerConfig;
import com.urbanspork.common.manage.shadowsocks.ServerUserManager;
import com.urbanspork.common.protocol.shadowsocks.Control;
import com.urbanspork.common.protocol.shadowsocks.replay.PacketWindowFilter;
import com.urbanspork.common.transport.udp.DatagramPacketWrapper;
import com.urbanspork.common.transport.udp.RelayingPacket;
import com.urbanspork.common.util.LruCache;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.DatagramPacket;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import io.netty.util.AttributeKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.List;

@ChannelHandler.Sharable
public class UdpRelayCodec extends MessageToMessageCodec<DatagramPacket, DatagramPacketWrapper> {
private static final Logger logger = LoggerFactory.getLogger(UdpRelayCodec.class);
private final ServerConfig config;
private static final AttributeKey<Control> CONTROL = AttributeKey.valueOf(UdpRelayCodec.class, Control.class.getSimpleName());
private final Mode mode;
private final ServerUserManager userManager;
private final AeadCipherCodec cipher;
private final LruCache<InetSocketAddress, Control> controlMap;
private final Control control0;
private final LruCache<Long, PacketWindowFilter> filters = new LruCache<>(1024, Duration.ofMinutes(5), (k, v) -> logger.trace("{} expire", k));

public UdpRelayCodec(ServerConfig config, Mode mode, ServerUserManager userManager) {
this.config = config;
this.mode = mode;
this.userManager = userManager;
this.cipher = AeadCipherCodecs.get(config);
this.controlMap = new LruCache<>(1024, Duration.ofMinutes(5), (k, v) -> logger.info("[udp]control map expire {}={}", k, v));
this.control0 = new Control(config.getCipher());
}

@Override
Expand All @@ -46,7 +48,10 @@ protected void encode(ChannelHandlerContext ctx, DatagramPacketWrapper msg, List
}
ByteBuf in = Unpooled.buffer();
DatagramPacket data = msg.packet();
Control control = getControl(proxy);
Control control = ctx.channel().attr(CONTROL).get();
if (control == null) {
control = new Control();
}
control.increasePacketId(1);
logger.trace("[udp][{}][encode]{}|{}", mode, proxy, control);
cipher.encode(new Context(mode, control, data.recipient(), userManager), data.content(), in);
Expand All @@ -55,27 +60,26 @@ protected void encode(ChannelHandlerContext ctx, DatagramPacketWrapper msg, List

@Override
protected void decode(ChannelHandlerContext ctx, DatagramPacket msg, List<Object> out) throws Exception {
Control control = getControl(msg.sender());
Control control = new Control(0, 0, 0);
Context context = new Context(mode, control, null, userManager);
RelayingPacket<ByteBuf> packet = cipher.decode(context, msg.content());
logger.trace("[udp][{}][decode]{}|{}", mode, msg.sender(), control);
if (cipher instanceof Aead2022CipherCodecImpl && !control.validatePacketId()) {
logger.error("[udp][{}→]{} packet_id {} out of window", mode, msg.sender(), control.getPacketId());
return;
Channel channel = ctx.channel();
channel.attr(CONTROL).set(control);
if (mode == Mode.Server) {
long clientSessionId = control.getClientSessionId();
channel.attr(AttributeKeys.SERVER_UDP_RELAY_WORKER).set(clientSessionId);
PacketWindowFilter filter = filters.computeIfAbsent(clientSessionId, k -> new PacketWindowFilter());
if (cipher instanceof Aead2022CipherCodecImpl && !filter.validatePacketId(control.getPacketId(), Long.MAX_VALUE)) {
logger.error("packet id out of window, {}→{}|{}", msg.sender(), packet.address(), control);
return;
}
}
out.add(new DatagramPacket(packet.content(), packet.address(), msg.sender()));
}

@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
controlMap.release();
}

private Control getControl(InetSocketAddress key) {
if (Mode.Client == mode) {
return control0;
} else {
return controlMap.computeIfAbsent(key, k -> new Control(config.getCipher()));
}
filters.release();
}
}
Original file line number Diff line number Diff line change
@@ -1,30 +1,23 @@
package com.urbanspork.common.protocol.shadowsocks;

import com.urbanspork.common.codec.CipherKind;
import com.urbanspork.common.manage.shadowsocks.ServerUser;
import com.urbanspork.common.protocol.shadowsocks.replay.PacketWindowFilter;
import com.urbanspork.common.util.Dice;

import java.util.concurrent.ThreadLocalRandom;

public class Control {
private final byte[] salt;
private long clientSessionId;
private long serverSessionId;
private long packetId;
private ServerUser user;
private final PacketWindowFilter packetWindowFilter;

public Control(CipherKind kind) {
this(Dice.rollBytes(kind.keySize()), ThreadLocalRandom.current().nextLong(), ThreadLocalRandom.current().nextLong(), 0);
public Control() {
this(ThreadLocalRandom.current().nextLong(), ThreadLocalRandom.current().nextLong(), 0);
}

Control(byte[] salt, long clientSessionId, long serverSessionId, long packetId) {
this.salt = salt;
public Control(long clientSessionId, long serverSessionId, long packetId) {
this.clientSessionId = clientSessionId;
this.serverSessionId = serverSessionId;
this.packetId = packetId;
this.packetWindowFilter = new PacketWindowFilter();
}

public void increasePacketId(long i) {
Expand All @@ -40,14 +33,6 @@ public void increasePacketId(long i) {
}
}

public boolean validatePacketId() {
return packetWindowFilter.validatePacketId(packetId, Long.MAX_VALUE);
}

public byte[] salt() {
return salt;
}

public long getPacketId() {
return packetId;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.urbanspork.server;

import com.urbanspork.common.channel.AttributeKeys;
import com.urbanspork.common.transport.udp.DatagramPacketWrapper;
import com.urbanspork.common.transport.udp.PacketEncoding;
import io.netty.bootstrap.Bootstrap;
Expand All @@ -19,14 +20,15 @@
import java.net.InetSocketAddress;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

public class ServerUdpRelayHandler extends SimpleChannelInboundHandler<DatagramPacket> {

private static final Logger logger = LoggerFactory.getLogger(ServerUdpRelayHandler.class);
private final Map<InetSocketAddress, Channel> workerChannels = new ConcurrentHashMap<>();
private final Map<Object, Channel> workerChannels = new ConcurrentHashMap<>();
private final EventLoopGroup workerGroup;
private final PacketEncoding packetEncoding;
private Channel packetWorkerChannel;

public ServerUdpRelayHandler(PacketEncoding packetEncoding, EventLoopGroup workerGroup) {
super(false);
Expand All @@ -39,7 +41,7 @@ public void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) {
Channel channel = ctx.channel();
InetSocketAddress sender = msg.sender();
InetSocketAddress recipient = msg.recipient();
Channel workerChannel = workerChannel(sender, channel);
Channel workerChannel = workerChannel(channel);
logger.info("[udp][relay]{}→{}~{}→{}", sender, recipient, channel.localAddress(), workerChannel.localAddress());
workerChannel.writeAndFlush(msg);
}
Expand All @@ -49,17 +51,36 @@ public void handlerRemoved(ChannelHandlerContext ctx) {
for (Map.Entry<?, Channel> entry : workerChannels.entrySet()) {
entry.getValue().close();
}
if (packetWorkerChannel != null) {
packetWorkerChannel.close();
}
}

Channel workerChannel(InetSocketAddress key, Channel inboundChannel) {
Channel workerChannel(Channel inboundChannel) {
if (PacketEncoding.Packet == packetEncoding) {
key = PacketEncoding.Packet.seqPacketMagicAddress();
if (packetWorkerChannel == null) {
packetWorkerChannel = newWorkerChannel0(inboundChannel);
}
return packetWorkerChannel;
}
Object key = Objects.requireNonNull(inboundChannel.attr(AttributeKeys.SERVER_UDP_RELAY_WORKER).get(), "require channel attribute: " + AttributeKeys.SERVER_UDP_RELAY_WORKER);
return workerChannels.computeIfAbsent(key, k -> newWorkerChannel(k, inboundChannel));
}

private Channel newWorkerChannel(InetSocketAddress key, Channel channel) {
Channel workerChannel = new Bootstrap().group(workerGroup).channel(NioDatagramChannel.class)
private Channel newWorkerChannel(Object key, Channel channel) {
Channel workerChannel = newWorkerChannel0(channel);
workerChannel.closeFuture().addListener(future -> {
Channel removed = workerChannels.remove(key);
logger.info("[udp][binding]{} != {}", key, removed);
});
logger.info("[udp][binding]{} == {}", key, workerChannel);
return workerChannel;
}

// callback->server->client
private Channel newWorkerChannel0(Channel channel) {
// automatically assigned port now, may have security implications
return new Bootstrap().group(workerGroup).channel(NioDatagramChannel.class)
.handler(new ChannelInitializer<>() {
@Override
protected void initChannel(Channel ch) {
Expand All @@ -68,16 +89,10 @@ protected void initChannel(Channel ch) {
new InboundHandler(channel)
);
}
})// callback->server->client
}) // callback->server->client
.bind(0) // automatically assigned port now, may have security implications
.syncUninterruptibly()
.channel();
workerChannel.closeFuture().addListener(future -> {
Channel removed = workerChannels.remove(key);
logger.info("[udp][binding]{} != {}", key, removed);
});
logger.info("[udp][binding]{} == {}", key, workerChannel);
return workerChannel;
}

private static class InboundHandler extends MessageToMessageCodec<DatagramPacket, DatagramPacket> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.urbanspork.server.trojan;

import com.urbanspork.common.channel.AttributeKeys;
import com.urbanspork.common.protocol.socks.Address;
import com.urbanspork.common.protocol.trojan.Trojan;
import com.urbanspork.common.transport.udp.DatagramPacketWrapper;
Expand Down Expand Up @@ -34,6 +35,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
short length = msg.readShort();
msg.skipBytes(Trojan.CRLF.length);
ByteBuf content = msg.readBytes(length);
ctx.channel().attr(AttributeKeys.SERVER_UDP_RELAY_WORKER).set(address);
out.add(new DatagramPacket(content, recipient, address));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.urbanspork.common.transport.udp.DatagramPacketWrapper;
import com.urbanspork.common.util.Dice;
import com.urbanspork.test.TestDice;
import com.urbanspork.test.template.TraceLevelLoggerTestTemplate;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
Expand All @@ -23,7 +24,7 @@
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;

class EmbeddedChannelTest {
class EmbeddedChannelTest extends TraceLevelLoggerTestTemplate {
@Test
void testTcpRelayChannel() {
int port = TestDice.rollPort();
Expand Down Expand Up @@ -109,4 +110,9 @@ void testAead2022UdpAntiReplay() {
client.close();
server.close();
}

@Override
protected Class<?> loggerClass() {
return UdpRelayCodec.class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ void testIncorrectPassword() {
void testTooShortHeader() {
AeadCipherCodec codec = newAEADCipherCodec();
ByteBuf in = Unpooled.wrappedBuffer(Dice.rollBytes(3));
Context context = new Context(Mode.Client, new Control(TestDice.rollCipher()), null, ServerUserManager.empty());
Context context = new Context(Mode.Client, new Control(), null, ServerUserManager.empty());
Assertions.assertThrows(DecoderException.class, () -> codec.decode(context, in));
}

Expand All @@ -51,10 +51,9 @@ void testEmptyMsg(Mode from, Mode to) throws InvalidCipherTextException {
AeadCipherCodec codec = newAEADCipherCodec();
InetSocketAddress address = InetSocketAddress.createUnresolved(TestDice.rollHost(), TestDice.rollPort());
ByteBuf in = Unpooled.buffer();
CipherKind kind = TestDice.rollCipher();
codec.encode(new Context(from, new Control(kind), address, ServerUserManager.empty()), Unpooled.EMPTY_BUFFER, in);
codec.encode(new Context(from, new Control(), address, ServerUserManager.empty()), Unpooled.EMPTY_BUFFER, in);
Assertions.assertTrue(in.isReadable());
RelayingPacket<ByteBuf> pocket = codec.decode(new Context(to, new Control(kind), address, ServerUserManager.empty()), in);
RelayingPacket<ByteBuf> pocket = codec.decode(new Context(to, new Control(), address, ServerUserManager.empty()), in);
Assertions.assertFalse(in.isReadable());
Assertions.assertNotNull(pocket);
}
Expand All @@ -63,18 +62,18 @@ void testEmptyMsg(Mode from, Mode to) throws InvalidCipherTextException {
void testTooShortPacket() {
AeadCipherCodec codec = newAEADCipherCodec();
ByteBuf in = Unpooled.buffer();
Context c1 = new Context(Mode.Client, new Control(CipherKind.aead2022_blake3_aes_128_gcm), null, ServerUserManager.empty());
Context c1 = new Context(Mode.Client, new Control(), null, ServerUserManager.empty());
Assertions.assertThrows(DecoderException.class, () -> codec.decode(c1, in));
Context c2 = new Context(Mode.Server, new Control(CipherKind.aead2022_blake3_aes_128_gcm), null, ServerUserManager.empty());
Context c2 = new Context(Mode.Server, new Control(), null, ServerUserManager.empty());
Assertions.assertThrows(DecoderException.class, () -> codec.decode(c2, in));
}

@Test
void testInvalidSocketType() throws InvalidCipherTextException {
InetSocketAddress address = InetSocketAddress.createUnresolved(TestDice.rollHost(), TestDice.rollPort());
Context c1 = new Context(Mode.Client, new Control(CipherKind.aead2022_blake3_aes_128_gcm), address, ServerUserManager.empty());
Context c1 = new Context(Mode.Client, new Control(), address, ServerUserManager.empty());
testInvalidSocketType(c1);
Context c2 = new Context(Mode.Server, new Control(CipherKind.aead2022_blake3_aes_128_gcm), address, ServerUserManager.empty());
Context c2 = new Context(Mode.Server, new Control(), address, ServerUserManager.empty());
testInvalidSocketType(c2);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void testByKind(CipherKind kind) throws Exception {
int port = TestDice.rollPort();
String host = TestDice.rollHost();
InetSocketAddress address = InetSocketAddress.createUnresolved(host, port);
cipherTest(new Context(Mode.Client, new Control(kind), address, ServerUserManager.empty()), new Context(Mode.Server, new Control(kind), null, ServerUserManager.empty()));
cipherTest(new Context(Mode.Client, new Control(), address, ServerUserManager.empty()), new Context(Mode.Server, new Control(), null, ServerUserManager.empty()));
}

private void cipherTest(Context request, Context response) throws Exception {
Expand Down
Loading

0 comments on commit b4bdcd4

Please sign in to comment.