From 2be7d6ca36886e53434f94d628832d6ff270e667 Mon Sep 17 00:00:00 2001 From: Ryland Degnan Date: Tue, 11 Dec 2018 16:05:57 -0800 Subject: [PATCH] Replace for loops with forEach in RSocketClient/Server to fix non-deterministic behavior (#556) Signed-off-by: Ryland Degnan --- .../main/java/io/rsocket/RSocketClient.java | 457 +++++++++--------- .../main/java/io/rsocket/RSocketServer.java | 96 ++-- 2 files changed, 272 insertions(+), 281 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java index 7a4d54a4f..e781c4f3b 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java @@ -39,7 +39,7 @@ /** Client Side of a RSocket socket. Sends {@link Frame}s to a {@link RSocketServer} */ class RSocketClient implements RSocket { - + private final DuplexConnection connection; private final Function frameDecoder; private final Consumer errorConsumer; @@ -49,7 +49,7 @@ class RSocketClient implements RSocket { private final UnboundedProcessor sendProcessor; private final Lifecycle lifecycle = new Lifecycle(); private KeepAliveHandler keepAliveHandler; - + /*server requester*/ RSocketClient( DuplexConnection connection, @@ -59,7 +59,7 @@ class RSocketClient implements RSocket { this( connection, frameDecoder, errorConsumer, streamIdSupplier, Duration.ZERO, Duration.ZERO, 0); } - + /*client requester*/ RSocketClient( DuplexConnection connection, @@ -75,24 +75,24 @@ class RSocketClient implements RSocket { this.streamIdSupplier = streamIdSupplier; this.senders = Collections.synchronizedMap(new IntObjectHashMap<>()); this.receivers = Collections.synchronizedMap(new IntObjectHashMap<>()); - + // DO NOT Change the order here. The Send processor must be subscribed to before receiving this.sendProcessor = new UnboundedProcessor<>(); - + connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer); - + connection .send(sendProcessor) .doFinally(this::handleSendProcessorCancel) .subscribe(null, this::handleSendProcessorError); - + connection.receive().subscribe(this::handleIncomingFrames, errorConsumer); - + if (!Duration.ZERO.equals(tickPeriod)) { this.keepAliveHandler = KeepAliveHandler.ofClient( new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout, missedAcks)); - + keepAliveHandler .timeout() .subscribe( @@ -109,306 +109,297 @@ class RSocketClient implements RSocket { keepAliveHandler = null; } } - + private void handleSendProcessorError(Throwable t) { Throwable terminationError = lifecycle.getTerminationError(); Throwable err = terminationError != null ? terminationError : t; - for (Subscriber subscriber : receivers.values()) { + receivers.values().forEach(subscriber -> { try { subscriber.onError(err); } catch (Throwable e) { errorConsumer.accept(e); } - } - - for (LimitableRequestPublisher p : senders.values()) { - p.cancel(); - } + }); + + senders.values().forEach(LimitableRequestPublisher::cancel); } - + private void handleSendProcessorCancel(SignalType t) { if (SignalType.ON_ERROR == t) { return; } - - for (Subscriber subscriber : receivers.values()) { + + receivers.values().forEach(subscriber -> { try { subscriber.onError(new Throwable("closed connection")); } catch (Throwable e) { errorConsumer.accept(e); } - } - - for (LimitableRequestPublisher p : senders.values()) { - p.cancel(); - } + }); + + senders.values().forEach(LimitableRequestPublisher::cancel); } - + @Override public Mono fireAndForget(Payload payload) { return handleFireAndForget(payload); } - + @Override public Mono requestResponse(Payload payload) { return handleRequestResponse(payload); } - + @Override public Flux requestStream(Payload payload) { return handleRequestStream(payload); } - + @Override public Flux requestChannel(Publisher payloads) { return handleChannel(Flux.from(payloads)); } - + @Override public Mono metadataPush(Payload payload) { return handleMetadataPush(payload); } - + @Override public double availability() { return connection.availability(); } - + @Override public void dispose() { connection.dispose(); } - + @Override public boolean isDisposed() { return connection.isDisposed(); } - + @Override public Mono onClose() { return connection.onClose(); } - + private Mono handleFireAndForget(Payload payload) { return lifecycle - .active() - .then( - Mono.fromRunnable( - () -> { - final int streamId = streamIdSupplier.nextStreamId(); - final Frame requestFrame = - Frame.Request.from(streamId, FrameType.REQUEST_FNF, payload, 1); - payload.release(); - sendProcessor.onNext(requestFrame); - })); + .active() + .then( + Mono.fromRunnable( + () -> { + final int streamId = streamIdSupplier.nextStreamId(); + final Frame requestFrame = + Frame.Request.from(streamId, FrameType.REQUEST_FNF, payload, 1); + payload.release(); + sendProcessor.onNext(requestFrame); + })); } - + private Flux handleRequestStream(final Payload payload) { return lifecycle - .active() - .thenMany( - Flux.defer( - () -> { - int streamId = streamIdSupplier.nextStreamId(); - - UnicastProcessor receiver = UnicastProcessor.create(); - receivers.put(streamId, receiver); - - AtomicBoolean first = new AtomicBoolean(false); - - return receiver - .doOnRequest( - n -> { - if (first.compareAndSet(false, true) && !receiver.isDisposed()) { - final Frame requestFrame = - Frame.Request.from( - streamId, FrameType.REQUEST_STREAM, payload, n); - payload.release(); - sendProcessor.onNext(requestFrame); - } else if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.RequestN.from(streamId, n)); - } - sendProcessor.drain(); - }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.Error.from(streamId, t)); - } - }) - .doOnCancel( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.Cancel.from(streamId)); - } - }) - .doFinally( - s -> { - receivers.remove(streamId); - }); - })); + .active() + .thenMany( + Flux.defer( + () -> { + int streamId = streamIdSupplier.nextStreamId(); + + UnicastProcessor receiver = UnicastProcessor.create(); + receivers.put(streamId, receiver); + + AtomicBoolean first = new AtomicBoolean(false); + + return receiver + .doOnRequest( + n -> { + if (first.compareAndSet(false, true) && !receiver.isDisposed()) { + final Frame requestFrame = + Frame.Request.from( + streamId, FrameType.REQUEST_STREAM, payload, n); + payload.release(); + sendProcessor.onNext(requestFrame); + } else if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.RequestN.from(streamId, n)); + } + sendProcessor.drain(); + }) + .doOnError( + t -> { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.Error.from(streamId, t)); + } + }) + .doOnCancel( + () -> { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.Cancel.from(streamId)); + } + }) + .doFinally( + s -> { + receivers.remove(streamId); + }); + })); } - + private Mono handleRequestResponse(final Payload payload) { return lifecycle - .active() - .then( - Mono.defer( - () -> { - int streamId = streamIdSupplier.nextStreamId(); - final Frame requestFrame = - Frame.Request.from(streamId, FrameType.REQUEST_RESPONSE, payload, 1); - payload.release(); - - UnicastMonoProcessor receiver = UnicastMonoProcessor.create(); - receivers.put(streamId, receiver); - - sendProcessor.onNext(requestFrame); - - return receiver - .doOnError(t -> sendProcessor.onNext(Frame.Error.from(streamId, t))) - .doFinally( - s -> { - if (s == SignalType.CANCEL) { - sendProcessor.onNext(Frame.Cancel.from(streamId)); - } - - receivers.remove(streamId); - }); - })); + .active() + .then( + Mono.defer( + () -> { + int streamId = streamIdSupplier.nextStreamId(); + final Frame requestFrame = + Frame.Request.from(streamId, FrameType.REQUEST_RESPONSE, payload, 1); + payload.release(); + + UnicastMonoProcessor receiver = UnicastMonoProcessor.create(); + receivers.put(streamId, receiver); + + sendProcessor.onNext(requestFrame); + + return receiver + .doOnError(t -> sendProcessor.onNext(Frame.Error.from(streamId, t))) + .doFinally( + s -> { + if (s == SignalType.CANCEL) { + sendProcessor.onNext(Frame.Cancel.from(streamId)); + } + + receivers.remove(streamId); + }); + })); } - + private Flux handleChannel(Flux request) { return lifecycle - .active() - .thenMany( - Flux.defer( - () -> { - final UnicastProcessor receiver = UnicastProcessor.create(); - final int streamId = streamIdSupplier.nextStreamId(); - final AtomicBoolean firstRequest = new AtomicBoolean(true); - - return receiver - .doOnRequest( - n -> { - if (firstRequest.compareAndSet(true, false)) { - final AtomicBoolean firstPayload = new AtomicBoolean(true); - final Flux requestFrames = - request - .transform( - f -> { - LimitableRequestPublisher wrapped = - LimitableRequestPublisher.wrap(f); - // Need to set this to one for first the frame - wrapped.increaseRequestLimit(1); - senders.put(streamId, wrapped); - receivers.put(streamId, receiver); - - return wrapped; - }) - .map( - payload -> { - final Frame requestFrame; - if (firstPayload.compareAndSet(true, false)) { - requestFrame = - Frame.Request.from( - streamId, - FrameType.REQUEST_CHANNEL, - payload, - n); - } else { - requestFrame = - Frame.PayloadFrame.from( - streamId, FrameType.NEXT, payload); - } - payload.release(); - return requestFrame; - }) - .doOnComplete( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - Frame.PayloadFrame.from( - streamId, FrameType.COMPLETE)); - } - if (firstPayload.get()) { - receiver.onComplete(); - } - }); - - requestFrames.subscribe( - sendProcessor::onNext, - t -> { - errorConsumer.accept(t); - receiver.dispose(); - }); - } else { + .active() + .thenMany( + Flux.defer( + () -> { + final UnicastProcessor receiver = UnicastProcessor.create(); + final int streamId = streamIdSupplier.nextStreamId(); + final AtomicBoolean firstRequest = new AtomicBoolean(true); + + return receiver + .doOnRequest( + n -> { + if (firstRequest.compareAndSet(true, false)) { + final AtomicBoolean firstPayload = new AtomicBoolean(true); + final Flux requestFrames = + request + .transform( + f -> { + LimitableRequestPublisher wrapped = + LimitableRequestPublisher.wrap(f); + // Need to set this to one for first the frame + wrapped.increaseRequestLimit(1); + senders.put(streamId, wrapped); + receivers.put(streamId, receiver); + + return wrapped; + }) + .map( + payload -> { + final Frame requestFrame; + if (firstPayload.compareAndSet(true, false)) { + requestFrame = + Frame.Request.from( + streamId, + FrameType.REQUEST_CHANNEL, + payload, + n); + } else { + requestFrame = + Frame.PayloadFrame.from( + streamId, FrameType.NEXT, payload); + } + payload.release(); + return requestFrame; + }) + .doOnComplete( + () -> { if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.RequestN.from(streamId, n)); + sendProcessor.onNext( + Frame.PayloadFrame.from( + streamId, FrameType.COMPLETE)); + } + if (firstPayload.get()) { + receiver.onComplete(); } - } - }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.Error.from(streamId, t)); - } - }) - .doOnCancel( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.Cancel.from(streamId)); - } - }) - .doFinally( - s -> { - receivers.remove(streamId); - LimitableRequestPublisher sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - }); - })); + }); + + requestFrames.subscribe( + sendProcessor::onNext, + t -> { + errorConsumer.accept(t); + receiver.dispose(); + }); + } else { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.RequestN.from(streamId, n)); + } + } + }) + .doOnError( + t -> { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.Error.from(streamId, t)); + } + }) + .doOnCancel( + () -> { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.Cancel.from(streamId)); + } + }) + .doFinally( + s -> { + receivers.remove(streamId); + LimitableRequestPublisher sender = senders.remove(streamId); + if (sender != null) { + sender.cancel(); + } + }); + })); } - + private Mono handleMetadataPush(Payload payload) { return lifecycle - .active() - .then( - Mono.fromRunnable( - () -> { - final Frame requestFrame = - Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1); - payload.release(); - sendProcessor.onNext(requestFrame); - })); + .active() + .then( + Mono.fromRunnable( + () -> { + final Frame requestFrame = + Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1); + payload.release(); + sendProcessor.onNext(requestFrame); + })); } - + private boolean contains(int streamId) { return receivers.containsKey(streamId); } - + protected void terminate() { - lifecycle.setTerminationError(new ClosedChannelException()); - + if (keepAliveHandler != null) { keepAliveHandler.dispose(); } try { - for (Processor subscriber : receivers.values()) { - cleanUpSubscriber(subscriber); - } - for (LimitableRequestPublisher p : senders.values()) { - cleanUpLimitableRequestPublisher(p); - } + receivers.values().forEach(this::cleanUpSubscriber); + senders.values().forEach(this::cleanUpLimitableRequestPublisher); } finally { senders.clear(); receivers.clear(); sendProcessor.dispose(); } } - + private synchronized void cleanUpLimitableRequestPublisher( LimitableRequestPublisher limitableRequestPublisher) { try { @@ -417,7 +408,7 @@ private synchronized void cleanUpLimitableRequestPublisher( errorConsumer.accept(t); } } - + private synchronized void cleanUpSubscriber(Processor subscriber) { try { subscriber.onError(lifecycle.getTerminationError()); @@ -425,7 +416,7 @@ private synchronized void cleanUpSubscriber(Processor subscriber) { errorConsumer.accept(t); } } - + private void handleIncomingFrames(Frame frame) { try { int streamId = frame.getStreamId(); @@ -439,7 +430,7 @@ private void handleIncomingFrames(Frame frame) { frame.release(); } } - + private void handleStreamZero(FrameType type, Frame frame) { switch (type) { case ERROR: @@ -462,7 +453,7 @@ private void handleStreamZero(FrameType type, Frame frame) { "Client received supported frame on stream 0: " + frame.toString())); } } - + private void handleFrame(int streamId, FrameType type, Frame frame) { Subscriber receiver = receivers.get(streamId); if (receiver == null) { @@ -509,14 +500,14 @@ private void handleFrame(int streamId, FrameType type, Frame frame) { } } } - + private void handleMissingResponseProcessor(int streamId, FrameType type, Frame frame) { if (!streamIdSupplier.isBeforeOrCurrent(streamId)) { if (type == FrameType.ERROR) { // message for stream that has never existed, we have a problem with // the overall connection and must tear down String errorMessage = frame.getDataUtf8(); - + throw new IllegalStateException( "Client received error for non-existent stream: " + streamId @@ -533,14 +524,14 @@ private void handleMissingResponseProcessor(int streamId, FrameType type, Frame // receiving a frame after a given stream has been cancelled/completed, // so ignore (cancellation is async so there is a race condition) } - + private static class Lifecycle { - + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = AtomicReferenceFieldUpdater.newUpdater( Lifecycle.class, Throwable.class, "terminationError"); private volatile Throwable terminationError; - + public Mono active() { return Mono.create( sink -> { @@ -551,11 +542,11 @@ public Mono active() { } }); } - + public void setTerminationError(Throwable err) { TERMINATION_ERROR.compareAndSet(this, null, err); } - + public Throwable getTerminationError() { return terminationError; } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java index 3b47d59e3..e29d4d1f1 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java @@ -40,18 +40,18 @@ /** Server side RSocket. Receives {@link Frame}s from a {@link RSocketClient} */ class RSocketServer implements RSocket { - + private final DuplexConnection connection; private final RSocket requestHandler; private final Function frameDecoder; private final Consumer errorConsumer; - + private final Map sendingSubscriptions; private final Map> channelProcessors; - + private final UnboundedProcessor sendProcessor; private KeepAliveHandler keepAliveHandler; - + /*client responder*/ RSocketServer( DuplexConnection connection, @@ -60,7 +60,7 @@ class RSocketServer implements RSocket { Consumer errorConsumer) { this(connection, requestHandler, frameDecoder, errorConsumer, 0, 0); } - + /*server responder*/ RSocketServer( DuplexConnection connection, @@ -75,18 +75,18 @@ class RSocketServer implements RSocket { this.errorConsumer = errorConsumer; this.sendingSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>()); this.channelProcessors = Collections.synchronizedMap(new IntObjectHashMap<>()); - + // DO NOT Change the order here. The Send processor must be subscribed to before receiving // connections this.sendProcessor = new UnboundedProcessor<>(); - + connection .send(sendProcessor) .doFinally(this::handleSendProcessorCancel) .subscribe(null, this::handleSendProcessorError); - + Disposable receiveDisposable = connection.receive().subscribe(this::handleFrame, errorConsumer); - + this.connection .onClose() .doFinally( @@ -95,11 +95,11 @@ class RSocketServer implements RSocket { receiveDisposable.dispose(); }) .subscribe(null, errorConsumer); - + if (tickPeriod != 0) { keepAliveHandler = KeepAliveHandler.ofServer(new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout)); - + keepAliveHandler .timeout() .subscribe( @@ -114,47 +114,47 @@ class RSocketServer implements RSocket { keepAliveHandler = null; } } - + private void handleSendProcessorError(Throwable t) { - for (Subscription subscription : sendingSubscriptions.values()) { + sendingSubscriptions.values().forEach(subscription -> { try { subscription.cancel(); } catch (Throwable e) { errorConsumer.accept(e); } - } - - for (Processor subscription : channelProcessors.values()) { + }); + + channelProcessors.values().forEach(subscription -> { try { subscription.onError(t); } catch (Throwable e) { errorConsumer.accept(e); } - } + }); } - + private void handleSendProcessorCancel(SignalType t) { if (SignalType.ON_ERROR == t) { return; } - - for (Subscription subscription : sendingSubscriptions.values()) { + + sendingSubscriptions.values().forEach(subscription -> { try { subscription.cancel(); } catch (Throwable e) { errorConsumer.accept(e); } - } - - for (Processor subscription : channelProcessors.values()) { + }); + + channelProcessors.values().forEach(subscription -> { try { subscription.onComplete(); } catch (Throwable e) { errorConsumer.accept(e); } - } + }); } - + @Override public Mono fireAndForget(Payload payload) { try { @@ -163,7 +163,7 @@ public Mono fireAndForget(Payload payload) { return Mono.error(t); } } - + @Override public Mono requestResponse(Payload payload) { try { @@ -172,7 +172,7 @@ public Mono requestResponse(Payload payload) { return Mono.error(t); } } - + @Override public Flux requestStream(Payload payload) { try { @@ -181,7 +181,7 @@ public Flux requestStream(Payload payload) { return Flux.error(t); } } - + @Override public Flux requestChannel(Publisher payloads) { try { @@ -190,7 +190,7 @@ public Flux requestChannel(Publisher payloads) { return Flux.error(t); } } - + @Override public Mono metadataPush(Payload payload) { try { @@ -199,45 +199,45 @@ public Mono metadataPush(Payload payload) { return Mono.error(t); } } - + @Override public void dispose() { connection.dispose(); } - + @Override public boolean isDisposed() { return connection.isDisposed(); } - + @Override public Mono onClose() { return connection.onClose(); } - + private void cleanup() { if (keepAliveHandler != null) { keepAliveHandler.dispose(); } cleanUpSendingSubscriptions(); cleanUpChannelProcessors(); - + requestHandler.dispose(); sendProcessor.dispose(); } - + private synchronized void cleanUpSendingSubscriptions() { sendingSubscriptions.values().forEach(Subscription::cancel); sendingSubscriptions.clear(); } - + private synchronized void cleanUpChannelProcessors() { channelProcessors .values() .forEach(Processor::onComplete); channelProcessors.clear(); } - + private void handleFrame(Frame frame) { try { int streamId = frame.getStreamId(); @@ -313,14 +313,14 @@ private void handleFrame(Frame frame) { frame.release(); } } - + private void handleFireAndForget(int streamId, Mono result) { result .doOnSubscribe(subscription -> sendingSubscriptions.put(streamId, subscription)) .doFinally(signalType -> sendingSubscriptions.remove(streamId)) .subscribe(null, errorConsumer); } - + private void handleRequestResponse(int streamId, Mono response) { response .doOnSubscribe(subscription -> sendingSubscriptions.put(streamId, subscription)) @@ -340,7 +340,7 @@ private void handleRequestResponse(int streamId, Mono response) { .doFinally(signalType -> sendingSubscriptions.remove(streamId)) .subscribe(sendProcessor::onNext, t -> handleError(streamId, t)); } - + private void handleStream(int streamId, Flux response, int initialRequestN) { response .transform( @@ -364,44 +364,44 @@ private void handleStream(int streamId, Flux response, int initialReque sendProcessor.onNext(frame); }); } - + private void handleChannel(int streamId, Payload payload, int initialRequestN) { UnicastProcessor frames = UnicastProcessor.create(); channelProcessors.put(streamId, frames); - + Flux payloads = frames .doOnCancel(() -> sendProcessor.onNext(Frame.Cancel.from(streamId))) .doOnError(t -> sendProcessor.onNext(Frame.Error.from(streamId, t))) .doOnRequest(l -> sendProcessor.onNext(Frame.RequestN.from(streamId, l))) .doFinally(signalType -> channelProcessors.remove(streamId)); - + // not chained, as the payload should be enqueued in the Unicast processor before this method // returns // and any later payload can be processed frames.onNext(payload); - + handleStream(streamId, requestChannel(payloads), initialRequestN); } - + private void handleKeepAliveFrame(Frame frame) { if (keepAliveHandler != null) { keepAliveHandler.receive(frame); } } - + private void handleCancelFrame(int streamId) { Subscription subscription = sendingSubscriptions.remove(streamId); if (subscription != null) { subscription.cancel(); } } - + private void handleError(int streamId, Throwable t) { errorConsumer.accept(t); sendProcessor.onNext(Frame.Error.from(streamId, t)); } - + private void handleRequestN(int streamId, Frame frame) { final Subscription subscription = sendingSubscriptions.get(streamId); if (subscription != null) {