Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in SqsWorker when acknowledgements are enabled #3001

Merged
merged 8 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,12 @@ public interface AcknowledgementSet {
* @since 2.2
*/
public boolean release(final EventHandle eventHandle, final boolean result);

/**
* Indicates that the addition of initial set of events to
* the acknowledgement set is completed.
* It is possible that more events are added to the set as the
* initial events are going through the pipeline line.
*/
public void complete();
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,18 @@ public void run() {
while (!isStopped) {
try {
final List<Record<Event>> records = inMemorySourceAccessor.read(testingKey);
if (records.size() == 0) {
Thread.sleep(1000);
continue;
}
AcknowledgementSet ackSet =
acknowledgementSetManager.create((result) ->
{
inMemorySourceAccessor.setAckReceived(result);
},
Duration.ofSeconds(15));
records.stream().forEach((record) -> { ackSet.add(record.getData()); });
ackSet.complete();
writeToBuffer(records);
} catch (final Exception ex) {
LOG.error("Error during source loop.", ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class DefaultAcknowledgementSet implements AcknowledgementSet {
private final Map<EventHandle, AtomicInteger> pendingAcknowledgments;
private Future<?> callbackFuture;
private final DefaultAcknowledgementSetMetrics metrics;
private boolean completed;

public DefaultAcknowledgementSet(final ExecutorService executor, final Consumer<Boolean> callback, final Duration expiryTime, final DefaultAcknowledgementSetMetrics metrics) {
this.callback = callback;
Expand All @@ -42,6 +43,7 @@ public DefaultAcknowledgementSet(final ExecutorService executor, final Consumer<
this.expiryTime = Instant.now().plusMillis(expiryTime.toMillis());
this.callbackFuture = null;
this.metrics = metrics;
this.completed = false;
pendingAcknowledgments = new HashMap<>();
lock = new ReentrantLock(true);
}
Expand Down Expand Up @@ -84,6 +86,8 @@ public boolean isDone() {
if (Instant.now().isAfter(expiryTime)) {
if (callbackFuture != null) {
callbackFuture.cancel(true);
callbackFuture = null;
LOG.warn("AcknowledgementSet expired");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any other data we can provide to make this useful if it happens? I'd imagine there may be a lot of these and they could be hard to understand.

Perhaps include the expiry time? Could we include the overall time?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am expecting there shouldn't be a lot of these. Since we only support one source, the acknowledgement set and expiry time should be obvious. We could pass a name to AcknowledgementSetManager.create when we support multiple sources.

}
metrics.increment(DefaultAcknowledgementSetMetrics.EXPIRED_METRIC_NAME);
return true;
Expand All @@ -98,6 +102,19 @@ public Instant getExpiryTime() {
return expiryTime;
}

@Override
public void complete() {
lock.lock();
try {
completed = true;
if (pendingAcknowledgments.size() == 0) {
callbackFuture = executor.submit(() -> callback.accept(this.result));
}
} finally {
lock.unlock();
}
}

@Override
public boolean release(final EventHandle eventHandle, final boolean result) {
lock.lock();
Expand All @@ -114,9 +131,11 @@ public boolean release(final EventHandle eventHandle, final boolean result) {
}
if (pendingAcknowledgments.get(eventHandle).decrementAndGet() == 0) {
pendingAcknowledgments.remove(eventHandle);
if (pendingAcknowledgments.size() == 0) {
if (completed && pendingAcknowledgments.size() == 0) {
callbackFuture = executor.submit(() -> callback.accept(this.result));
return true;
} else if (pendingAcknowledgments.size() == 0) {
LOG.warn("Acknowledgement set is not completed. Delaying callback until it is completed");
}
}
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ void setup() {
AcknowledgementSet acknowledgementSet1 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT_MS);
acknowledgementSet1.add(event1);
acknowledgementSet1.add(event2);
acknowledgementSet1.complete();
}

DefaultAcknowledgementSetManager createObjectUnderTest() {
Expand Down Expand Up @@ -98,6 +99,7 @@ void testMultipleAcknowledgementSets() throws InterruptedException {

AcknowledgementSet acknowledgementSet2 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT_MS);
acknowledgementSet2.add(event3);
acknowledgementSet2.complete();

acknowledgementSetManager.releaseEventReference(eventHandle2, true);
acknowledgementSetManager.releaseEventReference(eventHandle3, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void setupEvent() {
@Test
void testDefaultAcknowledgementSetBasic() throws Exception {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true));
Expand All @@ -97,6 +98,7 @@ void testDefaultAcknowledgementSetBasic() throws Exception {
@Test
void testDefaultAcknowledgementSetMultipleAcquireAndRelease() throws Exception {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
defaultAcknowledgementSet.acquire(handle);
Expand All @@ -111,6 +113,7 @@ void testDefaultAcknowledgementSetMultipleAcquireAndRelease() throws Exception {
@Test
void testDefaultAcknowledgementInvalidAcquire() {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
DefaultAcknowledgementSet secondAcknowledgementSet = createObjectUnderTest();
DefaultEventHandle handle2 = new DefaultEventHandle(secondAcknowledgementSet);
defaultAcknowledgementSet.acquire(handle2);
Expand All @@ -120,6 +123,7 @@ void testDefaultAcknowledgementInvalidAcquire() {
@Test
void testDefaultAcknowledgementInvalidRelease() {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
DefaultAcknowledgementSet secondAcknowledgementSet = createObjectUnderTest();
DefaultEventHandle handle2 = new DefaultEventHandle(secondAcknowledgementSet);
assertThat(defaultAcknowledgementSet.release(handle2, true), equalTo(false));
Expand All @@ -129,6 +133,7 @@ void testDefaultAcknowledgementInvalidRelease() {
@Test
void testDefaultAcknowledgementDuplicateReleaseError() throws Exception {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true));
Expand All @@ -144,6 +149,7 @@ void testDefaultAcknowledgementSetWithCustomCallback() throws Exception {
}
);
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true));
Expand All @@ -162,6 +168,7 @@ void testDefaultAcknowledgementSetNegativeAcknowledgements() throws Exception {
}
);
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
defaultAcknowledgementSet.acquire(handle);
Expand Down Expand Up @@ -190,6 +197,7 @@ void testDefaultAcknowledgementSetExpirations() throws Exception {
}
);
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ public <T> void consumeRecords() throws Exception {
if (!acknowledgementsEnabled) {
offsets.forEach((partition, offsetRange) ->
updateOffsetsToCommit(partition, new OffsetAndMetadata(offsetRange.getMaximum() + 1)));
} else {
acknowledgementSet.complete();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
import org.opensearch.dataprepper.plugins.source.configuration.OnErrorOption;
import org.opensearch.dataprepper.plugins.source.configuration.SqsOptions;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.acknowledgements.DefaultAcknowledgementSetManager;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.JacksonEvent;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import org.junit.jupiter.api.AfterEach;
Expand All @@ -28,6 +32,9 @@
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
Expand All @@ -36,21 +43,34 @@
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.core.StringStartsWith.startsWith;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.lenient;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.doAnswer;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
class SqsWorkerIT {
private SqsClient sqsClient;
@Mock
private S3Service s3Service;
private S3SourceConfig s3SourceConfig;
private PluginMetrics pluginMetrics;
private S3ObjectGenerator s3ObjectGenerator;
private String bucket;
private Backoff backoff;
private AcknowledgementSetManager acknowledgementSetManager;
private Double receivedCount = 0.0;
private Double deletedCount = 0.0;
private Double ackCallbackCount = 0.0;
private Event event;
private AtomicBoolean ready = new AtomicBoolean(false);

@BeforeEach
void setUp() {
Expand All @@ -76,8 +96,8 @@ void setUp() {
final DistributionSummary distributionSummary = mock(DistributionSummary.class);
final Timer sqsMessageDelayTimer = mock(Timer.class);

when(pluginMetrics.counter(anyString())).thenReturn(sharedCounter);
when(pluginMetrics.summary(anyString())).thenReturn(distributionSummary);
lenient().when(pluginMetrics.counter(anyString())).thenReturn(sharedCounter);
lenient().when(pluginMetrics.summary(anyString())).thenReturn(distributionSummary);
when(pluginMetrics.timer(anyString())).thenReturn(sqsMessageDelayTimer);

final SqsOptions sqsOptions = mock(SqsOptions.class);
Expand All @@ -86,7 +106,7 @@ void setUp() {
when(sqsOptions.getMaximumMessages()).thenReturn(10);
when(sqsOptions.getWaitTime()).thenReturn(Duration.ofSeconds(10));
when(s3SourceConfig.getSqsOptions()).thenReturn(sqsOptions);
when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES);
lenient().when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES);
when(s3SourceConfig.getNotificationSource()).thenReturn(NotificationSourceOption.S3);
}

Expand Down Expand Up @@ -127,6 +147,140 @@ void processSqsMessages_should_return_at_least_one_message(final int numberOfObj
assertThat(sqsMessagesProcessed, lessThanOrEqualTo(numberOfObjectsToWrite));
}

@ParameterizedTest
@ValueSource(ints = {1})
void processSqsMessages_should_return_at_least_one_message_with_acks_with_callback_invoked_after_processS3Object_finishes(final int numberOfObjectsToWrite) throws IOException, InterruptedException {
writeToS3(numberOfObjectsToWrite);

when(s3SourceConfig.getAcknowledgements()).thenReturn(true);
final Counter receivedCounter = mock(Counter.class);
final Counter deletedCounter = mock(Counter.class);
final Counter ackCallbackCounter = mock(Counter.class);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(receivedCounter);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(deletedCounter);
when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(ackCallbackCounter);
lenient().doAnswer((val) -> {
receivedCount += (double)val.getArgument(0);
return null;
}).when(receivedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
if (val.getArgument(0) != null) {
deletedCount += (double)val.getArgument(0);
}
return null;
}).when(deletedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
ackCallbackCount += 1;
return null;
}).when(ackCallbackCounter).increment();

doAnswer((val) -> {
AcknowledgementSet ackSet = val.getArgument(1);
S3ObjectReference s3ObjectReference = val.getArgument(0);
assertThat(s3ObjectReference.getBucketName(), equalTo(bucket));
assertThat(s3ObjectReference.getKey(), startsWith("s3 source/sqs/"));
event = (Event)JacksonEvent.fromMessage(val.getArgument(0).toString());
ackSet.add(event);
return null;
}).when(s3Service).addS3Object(any(S3ObjectReference.class), any(AcknowledgementSet.class));
ExecutorService executor = Executors.newFixedThreadPool(2);
acknowledgementSetManager = new DefaultAcknowledgementSetManager(executor);
final SqsWorker objectUnderTest = createObjectUnderTest();
Thread sinkThread = new Thread(() -> {
try {
synchronized(this) {
while (!ready.get()) {
Thread.sleep(100);
this.wait();
}
if (event.getEventHandle() != null) {
event.getEventHandle().release(true);
}
}
} catch (Exception e){}
});
sinkThread.start();
final int sqsMessagesProcessed = objectUnderTest.processSqsMessages();
synchronized(this) {
ready.set(true);
this.notify();
}
Thread.sleep(10000);

assertThat(deletedCount, equalTo((double)1.0));
assertThat(ackCallbackCount, equalTo((double)1.0));
}

@ParameterizedTest
@ValueSource(ints = {1})
void processSqsMessages_should_return_at_least_one_message_with_acks_with_callback_invoked_before_processS3Object_finishes(final int numberOfObjectsToWrite) throws IOException, InterruptedException {
writeToS3(numberOfObjectsToWrite);

when(s3SourceConfig.getAcknowledgements()).thenReturn(true);
final Counter receivedCounter = mock(Counter.class);
final Counter deletedCounter = mock(Counter.class);
final Counter ackCallbackCounter = mock(Counter.class);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(receivedCounter);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(deletedCounter);
when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(ackCallbackCounter);
lenient().doAnswer((val) -> {
receivedCount += (double)val.getArgument(0);
return null;
}).when(receivedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
if (val.getArgument(0) != null) {
deletedCount += (double)val.getArgument(0);
}
return null;
}).when(deletedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
ackCallbackCount += 1;
return null;
}).when(ackCallbackCounter).increment();

doAnswer((val) -> {
AcknowledgementSet ackSet = val.getArgument(1);
S3ObjectReference s3ObjectReference = val.getArgument(0);
assertThat(s3ObjectReference.getBucketName(), equalTo(bucket));
assertThat(s3ObjectReference.getKey(), startsWith("s3 source/sqs/"));
event = (Event)JacksonEvent.fromMessage(val.getArgument(0).toString());

ackSet.add(event);
synchronized(this) {
ready.set(true);
this.notify();
}
try {
Thread.sleep(4000);
} catch (Exception e){}

return null;
}).when(s3Service).addS3Object(any(S3ObjectReference.class), any(AcknowledgementSet.class));
ExecutorService executor = Executors.newFixedThreadPool(2);
acknowledgementSetManager = new DefaultAcknowledgementSetManager(executor);
final SqsWorker objectUnderTest = createObjectUnderTest();
Thread sinkThread = new Thread(() -> {
try {
synchronized(this) {
while (!ready.get()) {
Thread.sleep(100);
this.wait();
}
if (event.getEventHandle() != null) {
event.getEventHandle().release(true);
}
}
} catch (Exception e){}
});
sinkThread.start();
final int sqsMessagesProcessed = objectUnderTest.processSqsMessages();

Thread.sleep(10000);

assertThat(deletedCount, equalTo((double)1.0));
assertThat(ackCallbackCount, equalTo((double)1.0));
}

/** The EventBridge test is disabled by default
* To run this test run only this one test with S3 bucket configured to use EventBridge to send notifications to SQS
*/
Expand Down
Loading
Loading