Skip to content

Commit

Permalink
Server Connection Closing on Exception (#442)
Browse files Browse the repository at this point in the history
* server connection was being closed when the server received an error frame
* added test to test handling a request after a server receives an exception
  • Loading branch information
robertroeser authored and yschimke committed Nov 4, 2017
1 parent 3d58344 commit 1849d08
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 82 deletions.
2 changes: 1 addition & 1 deletion rsocket-core/src/main/java/io/rsocket/RSocketClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ public Frame apply(Payload payload) {
.doOnError(
t -> {
errorConsumer.accept(t);
receiver.cancel();
receiver.dispose();
})
.subscribe();
} else {
Expand Down
9 changes: 8 additions & 1 deletion rsocket-core/src/main/java/io/rsocket/RSocketServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@ class RSocketServer implements RSocket {
this.receiveDisposable =
connection
.receive()
.flatMapSequential(this::handleFrame)
.flatMapSequential(
frame ->
handleFrame(frame)
.onErrorResume(
t -> {
errorConsumer.accept(t);
return Mono.empty();
}))
.doOnError(errorConsumer)
.then()
.subscribe();
Expand Down
124 changes: 62 additions & 62 deletions rsocket-core/src/main/java/io/rsocket/internal/SwitchTransform.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,71 +10,71 @@
import reactor.core.publisher.Operators;

public final class SwitchTransform<T, R> extends Flux<R> {

final Publisher<? extends T> source;
final BiFunction<T, Flux<T>, Publisher<? extends R>> transformer;

public SwitchTransform(
Publisher<? extends T> source, BiFunction<T, Flux<T>, Publisher<? extends R>> transformer) {
this.source = Objects.requireNonNull(source, "source");
this.transformer = Objects.requireNonNull(transformer, "transformer");
}

@Override
public void subscribe(CoreSubscriber<? super R> actual) {
Flux.from(source).subscribe(new SwitchTransformSubscriber<>(actual, transformer));
}

static final class SwitchTransformSubscriber<T, R> implements CoreSubscriber<T> {
@SuppressWarnings("rawtypes")
static final AtomicIntegerFieldUpdater<SwitchTransformSubscriber> ONCE =
AtomicIntegerFieldUpdater.newUpdater(SwitchTransformSubscriber.class, "once");

final CoreSubscriber<? super R> actual;

final Publisher<? extends T> source;
final BiFunction<T, Flux<T>, Publisher<? extends R>> transformer;
final UnboundedProcessor<T> processor = new UnboundedProcessor<>();
Subscription s;
volatile int once;

SwitchTransformSubscriber(
CoreSubscriber<? super R> actual,
BiFunction<T, Flux<T>, Publisher<? extends R>> transformer) {
this.actual = actual;
this.transformer = transformer;

public SwitchTransform(
Publisher<? extends T> source, BiFunction<T, Flux<T>, Publisher<? extends R>> transformer) {
this.source = Objects.requireNonNull(source, "source");
this.transformer = Objects.requireNonNull(transformer, "transformer");
}

@Override
public void onSubscribe(Subscription s) {
if (Operators.validate(this.s, s)) {
this.s = s;
processor.onSubscribe(s);
}
public void subscribe(CoreSubscriber<? super R> actual) {
Flux.from(source).subscribe(new SwitchTransformSubscriber<>(actual, transformer));
}

@Override
public void onNext(T t) {
if (once == 0 && ONCE.compareAndSet(this, 0, 1)) {
try {
Publisher<? extends R> result =
Objects.requireNonNull(
transformer.apply(t, processor), "The transformer returned a null value");
Flux.from(result).subscribe(actual);
} catch (Throwable e) {
onError(Operators.onOperatorError(s, e, t, actual.currentContext()));
return;

static final class SwitchTransformSubscriber<T, R> implements CoreSubscriber<T> {
@SuppressWarnings("rawtypes")
static final AtomicIntegerFieldUpdater<SwitchTransformSubscriber> ONCE =
AtomicIntegerFieldUpdater.newUpdater(SwitchTransformSubscriber.class, "once");

final CoreSubscriber<? super R> actual;
final BiFunction<T, Flux<T>, Publisher<? extends R>> transformer;
final UnboundedProcessor<T> processor = new UnboundedProcessor<>();
Subscription s;
volatile int once;

SwitchTransformSubscriber(
CoreSubscriber<? super R> actual,
BiFunction<T, Flux<T>, Publisher<? extends R>> transformer) {
this.actual = actual;
this.transformer = transformer;
}

@Override
public void onSubscribe(Subscription s) {
if (Operators.validate(this.s, s)) {
this.s = s;
processor.onSubscribe(s);
}
}

@Override
public void onNext(T t) {
if (once == 0 && ONCE.compareAndSet(this, 0, 1)) {
try {
Publisher<? extends R> result =
Objects.requireNonNull(
transformer.apply(t, processor), "The transformer returned a null value");
Flux.from(result).subscribe(actual);
} catch (Throwable e) {
onError(Operators.onOperatorError(s, e, t, actual.currentContext()));
return;
}
}
processor.onNext(t);
}

@Override
public void onError(Throwable t) {
processor.onError(t);
}

@Override
public void onComplete() {
processor.onComplete();
}
}
processor.onNext(t);
}

@Override
public void onError(Throwable t) {
processor.onError(t);
}

@Override
public void onComplete() {
processor.onComplete();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@

package io.rsocket.integration;

import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

import io.rsocket.AbstractRSocket;
import io.rsocket.Payload;
import io.rsocket.RSocket;
Expand All @@ -35,28 +28,33 @@
import io.rsocket.transport.netty.server.TcpServerTransport;
import io.rsocket.util.PayloadImpl;
import io.rsocket.util.RSocketProxy;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class IntegrationTest {
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

private NettyContextCloseable server;
private RSocket client;
private AtomicInteger requestCount;
private CountDownLatch disconnectionCounter;
public static volatile boolean calledClient = false;
public static volatile boolean calledServer = false;
public static volatile boolean calledFrame = false;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

public class IntegrationTest {

private static final RSocketInterceptor clientPlugin;
private static final RSocketInterceptor serverPlugin;
private static final DuplexConnectionInterceptor connectionPlugin;
public static volatile boolean calledClient = false;
public static volatile boolean calledServer = false;
public static volatile boolean calledFrame = false;

static {
clientPlugin =
Expand Down Expand Up @@ -86,8 +84,15 @@ public Mono<Payload> requestResponse(Payload payload) {
};
}

private NettyContextCloseable server;
private RSocket client;
private AtomicInteger requestCount;
private CountDownLatch disconnectionCounter;
private AtomicInteger errorCount;

@Before
public void startup() {
errorCount = new AtomicInteger();
requestCount = new AtomicInteger();
disconnectionCounter = new CountDownLatch(1);

Expand All @@ -97,6 +102,9 @@ public void startup() {
RSocketFactory.receive()
.addServerPlugin(serverPlugin)
.addConnectionPlugin(connectionPlugin)
.errorConsumer(t -> {
errorCount.incrementAndGet();
})
.acceptor(
(setup, sendingSocket) -> {
sendingSocket
Expand All @@ -116,6 +124,11 @@ public Mono<Payload> requestResponse(Payload payload) {
public Flux<Payload> requestStream(Payload payload) {
return Flux.range(1, 10_000).map(i -> new PayloadImpl("data -> " + i));
}

@Override
public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
return Flux.from(payloads);
}
});
})
.transport(serverTransport)
Expand Down Expand Up @@ -145,7 +158,7 @@ public void testRequest() {
assertTrue(calledFrame);
}

@Test
@Test(timeout = 5_000L)
public void testStream() {
Subscriber<Payload> subscriber = TestSubscriber.createCancelling();
client.requestStream(new PayloadImpl("start")).subscribe(subscriber);
Expand All @@ -159,4 +172,16 @@ public void testClose() throws InterruptedException {
client.close().block();
disconnectionCounter.await();
}

@Test // (timeout = 5_000L)
public void testCallRequestWithErrorAndThenRequest() {
try {
client.requestChannel(Mono.error(new Throwable())).blockLast();
} catch (Throwable t) {
}

Assert.assertEquals(1, errorCount.incrementAndGet());

testRequest();
}
}

0 comments on commit 1849d08

Please sign in to comment.