Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Nov 13, 2024
1 parent 95a19f4 commit 42836f4
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownExce
// If the delegate above was already terminated via onError or onComplete from another
// thread.
logger.warn("StreamObserver was previously cancelled.", e);
} catch (RuntimeException ignored) {
logger.warn("StreamObserver was unexpectedly cancelled.", e);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,14 @@ protected void onResponse(StreamingGetDataResponse chunk) {
onHeartbeatResponse(chunk.getComputationHeartbeatResponseList());

for (int i = 0; i < chunk.getRequestIdCount(); ++i) {
AppendableInputStream responseStream = pending.get(chunk.getRequestId(i));
synchronized (this) {
verify(responseStream != null || isShutdown, "No pending response stream");
@Nullable AppendableInputStream responseStream = pending.get(chunk.getRequestId(i));
if (responseStream == null) {
synchronized (this) {
// shutdown()/shutdownInternal() cleans up pending, else we expect a pending
// responseStream for every response.
verify(isShutdown, "No pending response stream");
}
continue;
}
responseStream.append(chunk.getSerializedResponse(i).newInput());
if (chunk.getRemainingBytesForResponse() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,28 @@
* becomes ready.
*/
@ThreadSafe
public final class DirectStreamObserver<T> implements TerminatingStreamObserver<T> {
final class DirectStreamObserver<T> implements TerminatingStreamObserver<T> {
private static final Logger LOG = LoggerFactory.getLogger(DirectStreamObserver.class);
private static final long OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS = 30;

private final Phaser isReadyNotifier;

private final long deadlineSeconds;
private final int messagesBetweenIsReadyChecks;
private final Object lock = new Object();

@GuardedBy("lock")
private final CallStreamObserver<T> outboundObserver;

private final long deadlineSeconds;
private final int messagesBetweenIsReadyChecks;

@GuardedBy("lock")
private boolean isClosed = false;

@GuardedBy("lock")
private boolean isUserClosed = false;

@GuardedBy("lock")
private int messagesSinceReady = 0;

public DirectStreamObserver(
DirectStreamObserver(
Phaser isReadyNotifier,
CallStreamObserver<T> outboundObserver,
long deadlineSeconds,
Expand Down Expand Up @@ -89,6 +90,9 @@ public void onNext(T value) throws StreamObserverCancelledException {
throw new StreamObserverCancelledException("StreamObserver was terminated.");
}

// We close under "lock", so this should never happen.
assert !isClosed;

// If we awaited previously and timed out, wait for the same phase. Otherwise we're
// careful to observe the phase before observing isReady.
if (awaitPhase < 0) {
Expand Down Expand Up @@ -131,6 +135,10 @@ public void onNext(T value) throws StreamObserverCancelledException {
if (currentPhase < 0) {
throw new StreamObserverCancelledException("StreamObserver was terminated.");
}

// We close under "lock", so this should never happen.
assert !isClosed;

messagesSinceReady = 0;
outboundObserver.onNext(value);
return;
Expand Down Expand Up @@ -162,24 +170,23 @@ public void onNext(T value) throws StreamObserverCancelledException {
public void onError(Throwable t) {
isReadyNotifier.forceTermination();
synchronized (lock) {
markClosedOrThrow();
outboundObserver.onError(t);
if (!isClosed) {
Preconditions.checkState(!isUserClosed);
outboundObserver.onError(t);
isClosed = true;
}
}
}

@Override
public void onCompleted() {
isReadyNotifier.forceTermination();
synchronized (lock) {
markClosedOrThrow();
outboundObserver.onCompleted();
}
}

private void markClosedOrThrow() {
synchronized (lock) {
Preconditions.checkState(!isClosed);
isClosed = true;
if (!isClosed) {
Preconditions.checkState(!isUserClosed);
outboundObserver.onCompleted();
isClosed = true;
}
}
}

Expand All @@ -188,8 +195,9 @@ public void terminate(Throwable terminationException) {
// Free the blocked threads in onNext().
isReadyNotifier.forceTermination();
synchronized (lock) {
if (!isClosed) {
if (!isUserClosed) {
onError(terminationException);
isUserClosed = true;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public void setMessageCompression(boolean b) {}
testStream.shutdown();

// Sleep a bit to give sendExecutor time to execute the send().
Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS);
Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);

sendBlocker.countDown();
assertThat(sendFuture.get()).isInstanceOf(WindmillStreamShutdownException.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.io.IOException;
import java.util.HashSet;
Expand Down Expand Up @@ -53,12 +51,9 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.InOrder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@RunWith(JUnit4.class)
public class GrpcCommitWorkStreamTest {
private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStreamTest.class);
private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest";
private static final Windmill.JobHeader TEST_JOB_HEADER =
Windmill.JobHeader.newBuilder()
Expand Down Expand Up @@ -126,6 +121,7 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException {
spy(new TestCommitWorkStreamRequestObserver());
CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver);
GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
InOrder requestObserverVerifier = inOrder(requestObserver);
try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
for (int i = 0; i < numCommits; i++) {
batcher.commitWorkItem(
Expand All @@ -140,21 +136,14 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException {
}

// Verify that we sent the commits above in a request + the initial header.
verify(requestObserver, times(2))
.onNext(
argThat(
request -> {
if (request.getHeader().equals(TEST_JOB_HEADER)) {
LOG.info("Header received.");
return true;
} else if (!request.getCommitChunkList().isEmpty()) {
LOG.info("Chunk received.");
return true;
} else {
LOG.error("Incorrect request.");
return false;
}
}));
requestObserverVerifier
.verify(requestObserver)
.onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER)));
requestObserverVerifier
.verify(requestObserver)
.onNext(argThat(request -> !request.getCommitChunkList().isEmpty()));
requestObserverVerifier.verifyNoMoreInteractions();

// We won't get responses so we will have some pending requests.
assertTrue(commitWorkStream.hasPendingRequests());
commitWorkStream.shutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void testQueuedBatch_notifyFailed_throwsWindmillStreamShutdownExceptionOn
WindmillStreamShutdownException.class,
queuedBatch::waitForSendOrFailNotification));
// Wait a few seconds for the above future to get scheduled and run.
Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS);
Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
queuedBatch.notifyFailed();
waitFuture.join();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public void testRequestKeyedData() {
});

// Sleep a bit to allow future to run.
Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS);
Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);

Windmill.KeyedGetDataResponse response =
Windmill.KeyedGetDataResponse.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ private void flushResponse() {
done.countDown();
});
}
while (done.await(5, TimeUnit.SECONDS)) {}
done.await();
stream.halfClose();
assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS));
executor.shutdown();
Expand Down
Loading

0 comments on commit 42836f4

Please sign in to comment.