Skip to content

Commit

Permalink
Http2WebSocketHandler: propagate messages down the channel pipeline i…
Browse files Browse the repository at this point in the history
…f this handler is removed
  • Loading branch information
mostroverkhov committed Dec 16, 2023
1 parent 4af4539 commit 2679193
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ Http2FrameListener webSocketOrNext(int streamId) {
}
return webSocket;
}
return next;
return next();
}

void registerWebSocket(int streamId, Http2WebSocketChannel webSocket) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public abstract class Http2WebSocketHandler extends ChannelDuplexHandler
static final AsciiString HEADER_WEBSOCKET_ENDOFSTREAM_VALUE_FALSE = AsciiString.of("false");

Http2ConnectionHandler http2Handler;
Http2FrameListener next;
HandlerListener handlerListener;

Http2WebSocketHandler() {}

Expand All @@ -40,10 +40,24 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
Http2ConnectionHandler http2Handler =
this.http2Handler =
Preconditions.requireHandler(ctx.channel(), Http2ConnectionHandler.class);
Http2ConnectionDecoder decoder = http2Handler.decoder();
Http2FrameListener next = decoder.frameListener();
decoder.frameListener(this);
this.next = next;
HandlerListener listener = handlerListener;
if (listener == null) {
Http2ConnectionDecoder decoder = http2Handler.decoder();
Http2FrameListener next = decoder.frameListener();
listener = handlerListener = new HandlerListener().current(this).next(next);
decoder.frameListener(listener);
} else {
listener.current(this);
}
}

@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
HandlerListener listener = handlerListener;
if (listener != null) {
listener.current(null);
}
super.handlerRemoved(ctx);
}

@Override
Expand Down Expand Up @@ -150,7 +164,7 @@ public void onUnknownFrame(
}

final Http2FrameListener next() {
return next;
return handlerListener.next;
}

static AsciiString endOfStreamName() {
Expand All @@ -162,4 +176,134 @@ static AsciiString endOfStreamValue(boolean endOfStream) {
? HEADER_WEBSOCKET_ENDOFSTREAM_VALUE_TRUE
: HEADER_WEBSOCKET_ENDOFSTREAM_VALUE_FALSE;
}

static final class HandlerListener implements Http2FrameListener {
Http2FrameListener cur;
Http2FrameListener next;

public HandlerListener current(Http2FrameListener cur) {
this.cur = cur;
return this;
}

public HandlerListener next(Http2FrameListener next) {
this.next = next;
return this;
}

Http2FrameListener next() {
return next;
}

private Http2FrameListener listener() {
Http2FrameListener c = cur;
if (c != null) {
return c;
}
return next;
}

@Override
public int onDataRead(
ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream)
throws Http2Exception {
return listener().onDataRead(ctx, streamId, data, padding, endOfStream);
}

@Override
public void onHeadersRead(
ChannelHandlerContext ctx,
int streamId,
Http2Headers headers,
int padding,
boolean endOfStream)
throws Http2Exception {
listener().onHeadersRead(ctx, streamId, headers, padding, endOfStream);
}

@Override
public void onHeadersRead(
ChannelHandlerContext ctx,
int streamId,
Http2Headers headers,
int streamDependency,
short weight,
boolean exclusive,
int padding,
boolean endOfStream)
throws Http2Exception {
listener()
.onHeadersRead(
ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream);
}

@Override
public void onPriorityRead(
ChannelHandlerContext ctx,
int streamId,
int streamDependency,
short weight,
boolean exclusive)
throws Http2Exception {
listener().onPriorityRead(ctx, streamId, streamDependency, weight, exclusive);
}

@Override
public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode)
throws Http2Exception {
listener().onRstStreamRead(ctx, streamId, errorCode);
}

@Override
public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception {
listener().onSettingsAckRead(ctx);
}

@Override
public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings)
throws Http2Exception {
listener().onSettingsRead(ctx, settings);
}

@Override
public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception {
listener().onPingRead(ctx, data);
}

@Override
public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception {
listener().onPingAckRead(ctx, data);
}

@Override
public void onPushPromiseRead(
ChannelHandlerContext ctx,
int streamId,
int promisedStreamId,
Http2Headers headers,
int padding)
throws Http2Exception {
listener().onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding);
}

@Override
public void onGoAwayRead(
ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
throws Http2Exception {
listener().onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
}

@Override
public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement)
throws Http2Exception {
listener().onWindowUpdateRead(ctx, streamId, windowSizeIncrement);
}

@Override
public void onUnknownFrame(
ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload)
throws Http2Exception {
listener().onUnknownFrame(ctx, frameType, streamId, flags, payload);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http2.Http2FrameCodec;
import io.netty.handler.codec.http2.Http2FrameCodecBuilder;
import io.netty.handler.codec.http2.Http2HeadersFrame;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.ReferenceCountUtil;
import java.net.SocketAddress;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
Expand All @@ -32,6 +35,7 @@
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

Expand Down Expand Up @@ -108,6 +112,70 @@ void websocketFramesExchange(WebSocketsConfigurer webSocketsConfigurer) throws E
}
}

@Test
void handlerRemove() throws InterruptedException {
Http2FramesHandler serverHttp2FramesHandler = new Http2FramesHandler();
server =
createServer(
ch -> {
SslHandler sslHandler = serverSslContext.newHandler(ch.alloc());
Http2FrameCodecBuilder http2FrameCodecBuilder =
Http2FrameCodecBuilder.forServer().validateHeaders(false);
Http2Settings settings = http2FrameCodecBuilder.initialSettings();
settings.put(Http2WebSocketProtocol.SETTINGS_ENABLE_CONNECT_PROTOCOL, (Long) 1L);
settings.initialWindowSize(INITIAL_WINDOW_SIZE);
Http2FrameCodec http2frameCodec = http2FrameCodecBuilder.build();

Http2WebSocketServerHandler http2webSocketHandler =
Http2WebSocketServerBuilder.create()
.acceptor(new PathAcceptor("/test", new RemoveHttp2WebSocketHandler()))
.build();
ch.pipeline()
.addLast(
sslHandler,
http2frameCodec,
http2webSocketHandler,
serverHttp2FramesHandler);
})
.sync()
.channel();

SocketAddress address = server.localAddress();

client =
createClient(
address,
ch -> {
SslHandler sslHandler = clientSslContext.newHandler(ch.alloc());
Http2FrameCodecBuilder http2FrameCodecBuilder =
Http2FrameCodecBuilder.forClient();
Http2Settings settings = http2FrameCodecBuilder.initialSettings();
settings.initialWindowSize(INITIAL_WINDOW_SIZE);
Http2FrameCodec http2FrameCodec = http2FrameCodecBuilder.build();
Http2WebSocketClientHandler http2WebSocketClientHandler =
Http2WebSocketClientBuilder.create().handshakeTimeoutMillis(5_000).build();
ch.pipeline().addLast(sslHandler, http2FrameCodec, http2WebSocketClientHandler);
})
.sync()
.channel();

NoopClientWebSocketHandler clientWebSocketHandler = new NoopClientWebSocketHandler(client);
Http2WebSocketClientHandshaker handshaker = Http2WebSocketClientHandshaker.create(client);
ChannelFuture clientHandshake = handshaker.handshake("/test", clientWebSocketHandler);

clientHandshake.await(6, TimeUnit.SECONDS);
Assertions.assertThat(clientHandshake.isSuccess()).isTrue();
ChannelPromise frameReceived = clientWebSocketHandler.frameReceived;
frameReceived.await(6, TimeUnit.SECONDS);
Assertions.assertThat(frameReceived.isSuccess()).isTrue();

NoopClientWebSocketHandler nextClientWebSocketHandler = new NoopClientWebSocketHandler(client);
ChannelFuture nextClientHandshake = handshaker.handshake("/test", nextClientWebSocketHandler);
nextClientHandshake.await(6, TimeUnit.SECONDS);
Assertions.assertThat(nextClientHandshake.isSuccess()).isFalse();
Assertions.assertThat(serverHttp2FramesHandler.requestsReceived).isEqualTo(1);
}

@BeforeEach
void setUp() throws Exception {
serverSslContext = serverSslContext();
Expand All @@ -128,6 +196,40 @@ void tearDown() throws Exception {
}
}

private static class NoopClientWebSocketHandler
extends SimpleChannelInboundHandler<TextWebSocketFrame> {

volatile ChannelPromise frameReceived;

NoopClientWebSocketHandler(Channel channel) {
this.frameReceived = channel.newPromise();
}

@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
ctx.writeAndFlush(new TextWebSocketFrame("test"));
}

@Override
protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) {
msg.release();
frameReceived.trySuccess();
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
frameReceived.tryFailure(new ClosedChannelException());
super.channelInactive(ctx);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
frameReceived.tryFailure(cause);
super.exceptionCaught(ctx, cause);
}
}

private static class ClientWebSocketHandler
extends SimpleChannelInboundHandler<TextWebSocketFrame> {
private volatile ChannelPromise allFramesReceived;
Expand Down Expand Up @@ -180,6 +282,39 @@ private TextWebSocketFrame nextWebSocketFrame() {
}
}

private static class Http2FramesHandler extends ChannelInboundHandlerAdapter {
volatile int requestsReceived;

@SuppressWarnings("NonAtomicOperationOnVolatileField")
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
if (msg instanceof Http2HeadersFrame) {
/*called on eventloop thread only*/
requestsReceived++;
ctx.close();
}
} finally {
ReferenceCountUtil.safeRelease(msg);
}
}
}

private static class RemoveHttp2WebSocketHandler
extends SimpleChannelInboundHandler<TextWebSocketFrame> {

@Override
protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame webSocketFrame) {
ctx.writeAndFlush(webSocketFrame.retain());
ctx.channel().parent().pipeline().remove(Http2WebSocketServerHandler.class);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
ctx.close();
}
}

private static class ServerWebSocketHandler
extends SimpleChannelInboundHandler<TextWebSocketFrame> {

Expand Down

0 comments on commit 2679193

Please sign in to comment.