From 213e8040802ffb1e428f22b1b50cb19e4113024d Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Mon, 28 Oct 2024 23:53:21 -0700 Subject: [PATCH] address PR comments --- .../client/AbstractWindmillStream.java | 185 +++++------------- .../client/ResettableStreamObserver.java | 97 +++++++++ .../windmill/client/StreamDebugMetrics.java | 135 +++++++++++++ .../client/grpc/GrpcGetDataStream.java | 2 +- .../client/ResettableStreamObserverTest.java | 90 +++++++++ 5 files changed, 367 insertions(+), 142 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index ca54ade05f0f..511cf4c1a79f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -25,24 +25,14 @@ import java.util.concurrent.Executors; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.com.google.api.client.util.Sleeper; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; @@ -74,9 +64,9 @@ public abstract class AbstractWindmillStream implements Win // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce // per-chunk overhead, and small enough that we can still perform granular flow-control. protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20; + // Indicates that the logical stream has been half-closed and is waiting for clean server + // shutdown. private static final Status OK_STATUS = Status.fromCode(Status.Code.OK); - - protected final AtomicBoolean clientClosed; protected final Sleeper sleeper; /** @@ -87,30 +77,24 @@ public abstract class AbstractWindmillStream implements Win */ protected final Object shutdownLock = new Object(); - private final AtomicLong lastSendTimeMs; private final ExecutorService executor; private final BackOff backoff; - private final AtomicLong startTimeMs; - private final AtomicLong lastResponseTimeMs; - private final AtomicInteger restartCount; - private final AtomicInteger errorCount; - private final AtomicReference lastRestartReason; - private final AtomicReference lastRestartTime; - private final AtomicLong sleepUntil; private final CountDownLatch finishLatch; private final Set> streamRegistry; private final int logEveryNStreamFailures; private final String backendWorkerToken; - private final ResettableRequestObserver requestObserver; - private final AtomicReference shutdownTime; + private final ResettableStreamObserver requestObserver; + private final StreamDebugMonitor debugMonitor; + private final Logger logger; + protected volatile boolean clientClosed; /** - * Indicates if the current {@link ResettableRequestObserver} was closed by calling {@link - * #halfClose()}. + * Indicates if the current {@link ResettableStreamObserver} was closed by calling {@link + * #halfClose()}. Separate from {@link #clientClosed} as this is specific to the requestObserver + * and is initially false on retry. */ - private final AtomicBoolean streamClosed; + private volatile boolean streamClosed; - private final Logger logger; private volatile boolean isShutdown; private volatile boolean started; @@ -133,28 +117,20 @@ protected AbstractWindmillStream( this.backoff = backoff; this.streamRegistry = streamRegistry; this.logEveryNStreamFailures = logEveryNStreamFailures; - this.clientClosed = new AtomicBoolean(); + this.clientClosed = false; this.isShutdown = false; this.started = false; - this.streamClosed = new AtomicBoolean(false); - this.startTimeMs = new AtomicLong(); - this.lastSendTimeMs = new AtomicLong(); - this.lastResponseTimeMs = new AtomicLong(); - this.restartCount = new AtomicInteger(); - this.errorCount = new AtomicInteger(); - this.lastRestartReason = new AtomicReference<>(); - this.lastRestartTime = new AtomicReference<>(); - this.sleepUntil = new AtomicLong(); + this.streamClosed = false; this.finishLatch = new CountDownLatch(1); this.requestObserver = - new ResettableRequestObserver<>( + new ResettableStreamObserver<>( () -> streamObserverFactory.from( clientFactory, new AbstractWindmillStream.ResponseObserver())); this.sleeper = Sleeper.DEFAULT; this.logger = logger; - this.shutdownTime = new AtomicReference<>(); + this.debugMonitor = new StreamDebugMonitor(); } private static String createThreadName(String streamType, String backendWorkerToken) { @@ -163,10 +139,6 @@ private static String createThreadName(String streamType, String backendWorkerTo : String.format("%s-WindmillStream-thread", streamType); } - private static long debugDuration(long nowMs, long startMs) { - return startMs <= 0 ? -1 : Math.max(0, nowMs - startMs); - } - /** Called on each response from the server. */ protected abstract void onResponse(ResponseT response); @@ -195,13 +167,13 @@ protected final void send(RequestT request) { return; } - if (streamClosed.get()) { + if (streamClosed) { // TODO(m-trieu): throw a more specific exception here (i.e StreamClosedException) throw new IllegalStateException("Send called on a client closed stream."); } try { - lastSendTimeMs.set(Instant.now().getMillis()); + debugMonitor.recordSend(); requestObserver.onNext(request); } catch (StreamObserverCancelledException e) { if (isShutdown) { @@ -239,21 +211,22 @@ private void startStream() { if (isShutdown) { break; } - startTimeMs.set(Instant.now().getMillis()); - lastResponseTimeMs.set(0); - streamClosed.set(false); + debugMonitor.recordStart(); + streamClosed = false; requestObserver.reset(); onNewStream(); - if (clientClosed.get()) { + if (clientClosed) { halfClose(); } return; } + } catch (WindmillStreamShutdownException e) { + logger.debug("Stream was shutdown waiting to start.", e); } catch (Exception e) { logger.error("Failed to create new stream, retrying: ", e); try { long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); + debugMonitor.recordSleep(sleep); sleeper.sleep(sleep); } catch (InterruptedException ie) { Thread.currentThread().interrupt(); @@ -285,7 +258,7 @@ protected final void executeSafely(Runnable runnable) { } public final void maybeSendHealthCheck(Instant lastSendThreshold) { - if (!clientClosed.get() && lastSendTimeMs.get() < lastSendThreshold.getMillis()) { + if (!clientClosed && debugMonitor.lastSendTimeMs() < lastSendThreshold.getMillis()) { try { sendHealthCheck(); } catch (RuntimeException e) { @@ -303,28 +276,19 @@ public final void maybeSendHealthCheck(Instant lastSendThreshold) { */ public final void appendSummaryHtml(PrintWriter writer) { appendSpecificHtml(writer); - if (restartCount.get() > 0) { - writer.format( - ", %d restarts, last restart reason [ %s ] at [%s], %d errors", - restartCount.get(), lastRestartReason.get(), lastRestartTime.get(), errorCount.get()); - } - if (clientClosed.get()) { + debugMonitor.printRestartsHtml(writer); + if (clientClosed) { writer.write(", client closed"); } long nowMs = Instant.now().getMillis(); - long sleepLeft = sleepUntil.get() - nowMs; + long sleepLeft = debugMonitor.sleepLeft(); if (sleepLeft > 0) { writer.format(", %dms backoff remaining", sleepLeft); } + debugMonitor.printSummaryHtml(writer, nowMs); writer.format( - ", current stream is %dms old, last send %dms, last response %dms, closed: %s, " - + "isShutdown: %s, shutdown time: %s", - debugDuration(nowMs, startTimeMs.get()), - debugDuration(nowMs, lastSendTimeMs.get()), - debugDuration(nowMs, lastResponseTimeMs.get()), - streamClosed.get(), - isShutdown, - shutdownTime.get()); + ", closed: %s, " + "isShutdown: %s, shutdown time: %s", + streamClosed, isShutdown, debugMonitor.shutdownTime()); } /** @@ -336,9 +300,9 @@ public final void appendSummaryHtml(PrintWriter writer) { @Override public final synchronized void halfClose() { // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. - clientClosed.set(true); + clientClosed = true; requestObserver.onCompleted(); - streamClosed.set(true); + streamClosed = true; } @Override @@ -348,7 +312,7 @@ public final boolean awaitTermination(int time, TimeUnit unit) throws Interrupte @Override public final Instant startTime() { - return new Instant(startTimeMs.get()); + return new Instant(debugMonitor.startTimeMs()); } @Override @@ -363,71 +327,15 @@ public final void shutdown() { synchronized (shutdownLock) { if (!isShutdown) { isShutdown = true; - shutdownTime.set(DateTime.now()); - if (started) { - // requestObserver is not set until the first startStream() is called. If the stream was - // never started there is nothing to clean up internally. - requestObserver.onError( - new WindmillStreamShutdownException("Explicit call to shutdown stream.")); - shutdownInternal(); - } + debugMonitor.recordShutdown(); + requestObserver.poison(); + shutdownInternal(); } } } - private void recordRestartReason(String error) { - lastRestartReason.set(error); - lastRestartTime.set(DateTime.now()); - } - protected abstract void shutdownInternal(); - /** - * Request observer that allows resetting its internal delegate using the given {@link - * #requestObserverSupplier}. - * - * @implNote {@link StreamObserver}s generated by {@link * #requestObserverSupplier} are expected - * to be {@link ThreadSafe}. - */ - @ThreadSafe - private static class ResettableRequestObserver implements StreamObserver { - - private final Supplier> requestObserverSupplier; - - @GuardedBy("this") - private @Nullable StreamObserver delegateRequestObserver; - - private ResettableRequestObserver(Supplier> requestObserverSupplier) { - this.requestObserverSupplier = requestObserverSupplier; - this.delegateRequestObserver = null; - } - - private synchronized StreamObserver delegate() { - return Preconditions.checkNotNull( - delegateRequestObserver, - "requestObserver cannot be null. Missing a call to startStream() to initialize."); - } - - private synchronized void reset() { - delegateRequestObserver = requestObserverSupplier.get(); - } - - @Override - public void onNext(RequestT requestT) { - delegate().onNext(requestT); - } - - @Override - public void onError(Throwable throwable) { - delegate().onError(throwable); - } - - @Override - public void onCompleted() { - delegate().onCompleted(); - } - } - private class ResponseObserver implements StreamObserver { @Override @@ -437,7 +345,7 @@ public void onNext(ResponseT response) { } catch (IOException e) { // Ignore. } - lastResponseTimeMs.set(Instant.now().getMillis()); + debugMonitor.recordResponse(); onResponse(response); } @@ -451,7 +359,7 @@ public void onError(Throwable t) { try { long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); + debugMonitor.recordSleep(sleep); sleeper.sleep(sleep); } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -473,16 +381,16 @@ public void onCompleted() { } private void recordStreamStatus(Status status) { - int currentRestartCount = restartCount.incrementAndGet(); + int currentRestartCount = debugMonitor.incrementAndGetRestarts(); if (status.isOk()) { String restartReason = "Stream completed successfully but did not complete requested operations, " + "recreating"; logger.warn(restartReason); - recordRestartReason(restartReason); + debugMonitor.recordRestartReason(restartReason); } else { - int currentErrorCount = errorCount.incrementAndGet(); - recordRestartReason(status.toString()); + int currentErrorCount = debugMonitor.incrementAndGetErrors(); + debugMonitor.recordRestartReason(status.toString()); Throwable t = status.getCause(); if (t instanceof StreamObserverCancelledException) { logger.error( @@ -494,11 +402,6 @@ private void recordStreamStatus(Status status) { } else if (currentRestartCount % logEveryNStreamFailures == 0) { // Don't log every restart since it will get noisy, and many errors transient. long nowMillis = Instant.now().getMillis(); - String responseDebug = - lastResponseTimeMs.get() == 0 - ? "never received response" - : "received response " + (nowMillis - lastResponseTimeMs.get()) + "ms ago"; - logger.debug( "{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}" + " with status: {}. created {}ms ago; {}. This is normal with autoscaling.", @@ -507,8 +410,8 @@ private void recordStreamStatus(Status status) { currentErrorCount, t, status, - nowMillis - startTimeMs.get(), - responseDebug); + nowMillis - debugMonitor.startTimeMs(), + debugMonitor.responseDebugString(nowMillis)); } // If the stream was stopped due to a resource exhausted error then we are throttled. @@ -520,7 +423,7 @@ private void recordStreamStatus(Status status) { /** Returns true if the stream was torn down and should not be restarted internally. */ private synchronized boolean maybeTeardownStream() { - if (isShutdown || (clientClosed.get() && !hasPendingRequests())) { + if (isShutdown || (clientClosed && !hasPendingRequests())) { streamRegistry.remove(AbstractWindmillStream.this); finishLatch.countDown(); executor.shutdownNow(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java new file mode 100644 index 000000000000..d31f21117ff8 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; + +/** + * Request observer that allows resetting its internal delegate using the given {@link + * #streamObserverFactory}. + * + * @implNote {@link StreamObserver}s generated by {@link * #requestObserverSupplier} are expected to + * be {@link ThreadSafe}. + */ +@ThreadSafe +@Internal +final class ResettableStreamObserver implements StreamObserver { + private final Supplier> streamObserverFactory; + + @GuardedBy("this") + private @Nullable StreamObserver delegateStreamObserver; + + /** + * Indicates that the request observer should no longer be used. Attempts to perform operations on + * the request observer will throw an {@link WindmillStreamShutdownException}. + */ + @GuardedBy("this") + private boolean isPoisoned; + + ResettableStreamObserver(Supplier> streamObserverFactory) { + this.streamObserverFactory = streamObserverFactory; + this.delegateStreamObserver = null; + this.isPoisoned = false; + } + + private synchronized StreamObserver delegate() { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Explicit call to shutdown stream."); + } + return Preconditions.checkNotNull( + delegateStreamObserver, + "requestObserver cannot be null. Missing a call to startStream() to initialize."); + } + + synchronized void reset() { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Explicit call to shutdown stream."); + } + + delegateStreamObserver = streamObserverFactory.get(); + } + + synchronized void poison() { + if (!isPoisoned) { + isPoisoned = true; + if (delegateStreamObserver != null) { + delegateStreamObserver.onError( + new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + } + } + } + + @Override + public void onNext(T t) { + delegate().onNext(t); + } + + @Override + public void onError(Throwable throwable) { + delegate().onError(throwable); + } + + @Override + public void onCompleted() { + delegate().onCompleted(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java new file mode 100644 index 000000000000..85e2d2402fb3 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.joda.time.DateTime; +import org.joda.time.Instant; + +/** Records stream events for debugging. */ +@ThreadSafe +final class StreamDebugMetrics { + private final AtomicInteger restartCount = new AtomicInteger(); + private final AtomicInteger errorCount = new AtomicInteger(); + + @GuardedBy("this") + private long sleepUntil = 0; + + @GuardedBy("this") + private String lastRestartReason = ""; + + @GuardedBy("this") + private DateTime lastRestartTime = null; + + @GuardedBy("this") + private long lastResponseTimeMs = 0; + + @GuardedBy("this") + private long lastSendTimeMs = 0; + + @GuardedBy("this") + private long startTimeMs = 0; + + @GuardedBy("this") + private DateTime shutdownTime = null; + + private static long debugDuration(long nowMs, long startMs) { + return startMs <= 0 ? -1 : Math.max(0, nowMs - startMs); + } + + private static long nowMs() { + return Instant.now().getMillis(); + } + + synchronized void recordSend() { + lastSendTimeMs = nowMs(); + } + + synchronized void recordStart() { + startTimeMs = nowMs(); + lastResponseTimeMs = 0; + } + + synchronized void recordResponse() { + lastResponseTimeMs = nowMs(); + } + + synchronized void recordRestartReason(String error) { + lastRestartReason = error; + lastRestartTime = DateTime.now(); + } + + synchronized long startTimeMs() { + return startTimeMs; + } + + synchronized long lastSendTimeMs() { + return lastSendTimeMs; + } + + synchronized void recordSleep(long sleepMs) { + sleepUntil = nowMs() + sleepMs; + } + + synchronized long sleepLeft() { + return sleepUntil - nowMs(); + } + + int incrementAndGetRestarts() { + return restartCount.incrementAndGet(); + } + + int incrementAndGetErrors() { + return errorCount.incrementAndGet(); + } + + synchronized void recordShutdown() { + shutdownTime = DateTime.now(); + } + + synchronized String responseDebugString(long nowMillis) { + return lastResponseTimeMs == 0 + ? "never received response" + : "received response " + (nowMillis - lastResponseTimeMs) + "ms ago"; + } + + void printRestartsHtml(PrintWriter writer) { + if (restartCount.get() > 0) { + synchronized (this) { + writer.format( + ", %d restarts, last restart reason [ %s ] at [%s], %d errors", + restartCount.get(), lastRestartReason, lastRestartTime, errorCount.get()); + } + } + } + + synchronized DateTime shutdownTime() { + return shutdownTime; + } + + synchronized void printSummaryHtml(PrintWriter writer, long nowMs) { + writer.format( + ", current stream is %dms old, last send %dms, last response %dms", + debugDuration(nowMs, startTimeMs), + debugDuration(nowMs, lastSendTimeMs), + debugDuration(nowMs, lastResponseTimeMs)); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index a5d5b1882fd7..6a809712bd9f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -161,7 +161,7 @@ protected synchronized void onNewStream() { } send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); - if (clientClosed.get() && !isShutdown()) { + if (clientClosed && !isShutdown()) { // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java new file mode 100644 index 000000000000..538da9607f8b --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ResettableStreamObserverTest { + private final StreamObserver delegate = + spy( + new StreamObserver() { + @Override + public void onNext(Integer integer) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }); + + @Test + public void testPoison_beforeDelegateSet() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + verifyNoInteractions(delegate); + } + + @Test + public void testPoison_afterDelegateSet() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.reset(); + observer.poison(); + verify(delegate).onError(isA(WindmillStreamShutdownException.class)); + } + + @Test + public void testReset_afterPoisonedThrows() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, observer::reset); + } + + @Test + public void onNext_afterPoisonedThrows() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, () -> observer.onNext(1)); + } + + @Test + public void onError_afterPoisonedThrows() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + assertThrows( + WindmillStreamShutdownException.class, + () -> observer.onError(new RuntimeException("something bad happened."))); + } + + @Test + public void onCompleted_afterPoisonedThrows() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, observer::onCompleted); + } +}