Skip to content

Commit

Permalink
Fix race condition in SqsWorker when acknowledgements are enabled (#3001
Browse files Browse the repository at this point in the history
) (#3010)

* Fix race condition in SqsWorker when acknowledgements are enabled

Signed-off-by: Krishna Kondaka <[email protected]>

* Modified to do the synchronization in the acknowledgement set framework

Signed-off-by: Krishna Kondaka <[email protected]>

* Fixed failing tests

Signed-off-by: Krishna Kondaka <[email protected]>

* Removed unused variable

Signed-off-by: Krishna Kondaka <[email protected]>

* Addressed review comment and fixed failing tests

Signed-off-by: Krishna Kondaka <[email protected]>

* Addressed review comments

Signed-off-by: Krishna Kondaka <[email protected]>

* Fixed failing tests

Signed-off-by: Krishna Kondaka <[email protected]>

* Fixed checkStyle failure

Signed-off-by: Krishna Kondaka <[email protected]>

---------

Signed-off-by: Krishna Kondaka <[email protected]>
Co-authored-by: Krishna Kondaka <[email protected]>
(cherry picked from commit 515cf61)
  • Loading branch information
kkondaka committed Jul 12, 2023
1 parent 62933f6 commit bb44726
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,12 @@ public interface AcknowledgementSet {
* @since 2.2
*/
public boolean release(final EventHandle eventHandle, final boolean result);

/**
* Indicates that the addition of initial set of events to
* the acknowledgement set is completed.
* It is possible that more events are added to the set as the
* initial events are going through the pipeline line.
*/
public void complete();
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,18 @@ public void run() {
while (!isStopped) {
try {
final List<Record<Event>> records = inMemorySourceAccessor.read(testingKey);
if (records.size() == 0) {
Thread.sleep(1000);
continue;
}
AcknowledgementSet ackSet =
acknowledgementSetManager.create((result) ->
{
inMemorySourceAccessor.setAckReceived(result);
},
Duration.ofSeconds(15));
records.stream().forEach((record) -> { ackSet.add(record.getData()); });
ackSet.complete();
writeToBuffer(records);
} catch (final Exception ex) {
LOG.error("Error during source loop.", ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class DefaultAcknowledgementSet implements AcknowledgementSet {
private final Map<EventHandle, AtomicInteger> pendingAcknowledgments;
private Future<?> callbackFuture;
private final DefaultAcknowledgementSetMetrics metrics;
private boolean completed;

public DefaultAcknowledgementSet(final ExecutorService executor, final Consumer<Boolean> callback, final Duration expiryTime, final DefaultAcknowledgementSetMetrics metrics) {
this.callback = callback;
Expand All @@ -42,6 +43,7 @@ public DefaultAcknowledgementSet(final ExecutorService executor, final Consumer<
this.expiryTime = Instant.now().plusMillis(expiryTime.toMillis());
this.callbackFuture = null;
this.metrics = metrics;
this.completed = false;
pendingAcknowledgments = new HashMap<>();
lock = new ReentrantLock(true);
}
Expand Down Expand Up @@ -84,6 +86,8 @@ public boolean isDone() {
if (Instant.now().isAfter(expiryTime)) {
if (callbackFuture != null) {
callbackFuture.cancel(true);
callbackFuture = null;
LOG.warn("AcknowledgementSet expired");
}
metrics.increment(DefaultAcknowledgementSetMetrics.EXPIRED_METRIC_NAME);
return true;
Expand All @@ -98,6 +102,19 @@ public Instant getExpiryTime() {
return expiryTime;
}

@Override
public void complete() {
lock.lock();
try {
completed = true;
if (pendingAcknowledgments.size() == 0) {
callbackFuture = executor.submit(() -> callback.accept(this.result));
}
} finally {
lock.unlock();
}
}

@Override
public boolean release(final EventHandle eventHandle, final boolean result) {
lock.lock();
Expand All @@ -114,9 +131,11 @@ public boolean release(final EventHandle eventHandle, final boolean result) {
}
if (pendingAcknowledgments.get(eventHandle).decrementAndGet() == 0) {
pendingAcknowledgments.remove(eventHandle);
if (pendingAcknowledgments.size() == 0) {
if (completed && pendingAcknowledgments.size() == 0) {
callbackFuture = executor.submit(() -> callback.accept(this.result));
return true;
} else if (pendingAcknowledgments.size() == 0) {
LOG.warn("Acknowledgement set is not completed. Delaying callback until it is completed");
}
}
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ void setup() {
AcknowledgementSet acknowledgementSet1 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT_MS);
acknowledgementSet1.add(event1);
acknowledgementSet1.add(event2);
acknowledgementSet1.complete();
}

DefaultAcknowledgementSetManager createObjectUnderTest() {
Expand Down Expand Up @@ -98,6 +99,7 @@ void testMultipleAcknowledgementSets() throws InterruptedException {

AcknowledgementSet acknowledgementSet2 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT_MS);
acknowledgementSet2.add(event3);
acknowledgementSet2.complete();

acknowledgementSetManager.releaseEventReference(eventHandle2, true);
acknowledgementSetManager.releaseEventReference(eventHandle3, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void setupEvent() {
@Test
void testDefaultAcknowledgementSetBasic() throws Exception {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true));
Expand All @@ -97,6 +98,7 @@ void testDefaultAcknowledgementSetBasic() throws Exception {
@Test
void testDefaultAcknowledgementSetMultipleAcquireAndRelease() throws Exception {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
defaultAcknowledgementSet.acquire(handle);
Expand All @@ -111,6 +113,7 @@ void testDefaultAcknowledgementSetMultipleAcquireAndRelease() throws Exception {
@Test
void testDefaultAcknowledgementInvalidAcquire() {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
DefaultAcknowledgementSet secondAcknowledgementSet = createObjectUnderTest();
DefaultEventHandle handle2 = new DefaultEventHandle(secondAcknowledgementSet);
defaultAcknowledgementSet.acquire(handle2);
Expand All @@ -120,6 +123,7 @@ void testDefaultAcknowledgementInvalidAcquire() {
@Test
void testDefaultAcknowledgementInvalidRelease() {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
DefaultAcknowledgementSet secondAcknowledgementSet = createObjectUnderTest();
DefaultEventHandle handle2 = new DefaultEventHandle(secondAcknowledgementSet);
assertThat(defaultAcknowledgementSet.release(handle2, true), equalTo(false));
Expand All @@ -129,6 +133,7 @@ void testDefaultAcknowledgementInvalidRelease() {
@Test
void testDefaultAcknowledgementDuplicateReleaseError() throws Exception {
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true));
Expand All @@ -144,6 +149,7 @@ void testDefaultAcknowledgementSetWithCustomCallback() throws Exception {
}
);
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true));
Expand All @@ -162,6 +168,7 @@ void testDefaultAcknowledgementSetNegativeAcknowledgements() throws Exception {
}
);
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
defaultAcknowledgementSet.acquire(handle);
Expand Down Expand Up @@ -190,6 +197,7 @@ void testDefaultAcknowledgementSetExpirations() throws Exception {
}
);
defaultAcknowledgementSet.add(event);
defaultAcknowledgementSet.complete();
assertThat(handle, not(equalTo(null)));
assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet));
assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
import org.opensearch.dataprepper.plugins.source.configuration.OnErrorOption;
import org.opensearch.dataprepper.plugins.source.configuration.SqsOptions;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.acknowledgements.DefaultAcknowledgementSetManager;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.JacksonEvent;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import org.junit.jupiter.api.AfterEach;
Expand All @@ -28,6 +32,9 @@
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
Expand All @@ -36,21 +43,34 @@
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.core.StringStartsWith.startsWith;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.lenient;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.doAnswer;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
class SqsWorkerIT {
private SqsClient sqsClient;
@Mock
private S3Service s3Service;
private S3SourceConfig s3SourceConfig;
private PluginMetrics pluginMetrics;
private S3ObjectGenerator s3ObjectGenerator;
private String bucket;
private Backoff backoff;
private AcknowledgementSetManager acknowledgementSetManager;
private Double receivedCount = 0.0;
private Double deletedCount = 0.0;
private Double ackCallbackCount = 0.0;
private Event event;
private AtomicBoolean ready = new AtomicBoolean(false);

@BeforeEach
void setUp() {
Expand All @@ -76,8 +96,8 @@ void setUp() {
final DistributionSummary distributionSummary = mock(DistributionSummary.class);
final Timer sqsMessageDelayTimer = mock(Timer.class);

when(pluginMetrics.counter(anyString())).thenReturn(sharedCounter);
when(pluginMetrics.summary(anyString())).thenReturn(distributionSummary);
lenient().when(pluginMetrics.counter(anyString())).thenReturn(sharedCounter);
lenient().when(pluginMetrics.summary(anyString())).thenReturn(distributionSummary);
when(pluginMetrics.timer(anyString())).thenReturn(sqsMessageDelayTimer);

final SqsOptions sqsOptions = mock(SqsOptions.class);
Expand All @@ -86,7 +106,7 @@ void setUp() {
when(sqsOptions.getMaximumMessages()).thenReturn(10);
when(sqsOptions.getWaitTime()).thenReturn(Duration.ofSeconds(10));
when(s3SourceConfig.getSqsOptions()).thenReturn(sqsOptions);
when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES);
lenient().when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES);
when(s3SourceConfig.getNotificationSource()).thenReturn(NotificationSourceOption.S3);
}

Expand Down Expand Up @@ -127,6 +147,140 @@ void processSqsMessages_should_return_at_least_one_message(final int numberOfObj
assertThat(sqsMessagesProcessed, lessThanOrEqualTo(numberOfObjectsToWrite));
}

@ParameterizedTest
@ValueSource(ints = {1})
void processSqsMessages_should_return_at_least_one_message_with_acks_with_callback_invoked_after_processS3Object_finishes(final int numberOfObjectsToWrite) throws IOException, InterruptedException {
writeToS3(numberOfObjectsToWrite);

when(s3SourceConfig.getAcknowledgements()).thenReturn(true);
final Counter receivedCounter = mock(Counter.class);
final Counter deletedCounter = mock(Counter.class);
final Counter ackCallbackCounter = mock(Counter.class);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(receivedCounter);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(deletedCounter);
when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(ackCallbackCounter);
lenient().doAnswer((val) -> {
receivedCount += (double)val.getArgument(0);
return null;
}).when(receivedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
if (val.getArgument(0) != null) {
deletedCount += (double)val.getArgument(0);
}
return null;
}).when(deletedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
ackCallbackCount += 1;
return null;
}).when(ackCallbackCounter).increment();

doAnswer((val) -> {
AcknowledgementSet ackSet = val.getArgument(1);
S3ObjectReference s3ObjectReference = val.getArgument(0);
assertThat(s3ObjectReference.getBucketName(), equalTo(bucket));
assertThat(s3ObjectReference.getKey(), startsWith("s3 source/sqs/"));
event = (Event)JacksonEvent.fromMessage(val.getArgument(0).toString());
ackSet.add(event);
return null;
}).when(s3Service).addS3Object(any(S3ObjectReference.class), any(AcknowledgementSet.class));
ExecutorService executor = Executors.newFixedThreadPool(2);
acknowledgementSetManager = new DefaultAcknowledgementSetManager(executor);
final SqsWorker objectUnderTest = createObjectUnderTest();
Thread sinkThread = new Thread(() -> {
try {
synchronized(this) {
while (!ready.get()) {
Thread.sleep(100);
this.wait();
}
if (event.getEventHandle() != null) {
event.getEventHandle().release(true);
}
}
} catch (Exception e){}
});
sinkThread.start();
final int sqsMessagesProcessed = objectUnderTest.processSqsMessages();
synchronized(this) {
ready.set(true);
this.notify();
}
Thread.sleep(10000);

assertThat(deletedCount, equalTo((double)1.0));
assertThat(ackCallbackCount, equalTo((double)1.0));
}

@ParameterizedTest
@ValueSource(ints = {1})
void processSqsMessages_should_return_at_least_one_message_with_acks_with_callback_invoked_before_processS3Object_finishes(final int numberOfObjectsToWrite) throws IOException, InterruptedException {
writeToS3(numberOfObjectsToWrite);

when(s3SourceConfig.getAcknowledgements()).thenReturn(true);
final Counter receivedCounter = mock(Counter.class);
final Counter deletedCounter = mock(Counter.class);
final Counter ackCallbackCounter = mock(Counter.class);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(receivedCounter);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(deletedCounter);
when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(ackCallbackCounter);
lenient().doAnswer((val) -> {
receivedCount += (double)val.getArgument(0);
return null;
}).when(receivedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
if (val.getArgument(0) != null) {
deletedCount += (double)val.getArgument(0);
}
return null;
}).when(deletedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
ackCallbackCount += 1;
return null;
}).when(ackCallbackCounter).increment();

doAnswer((val) -> {
AcknowledgementSet ackSet = val.getArgument(1);
S3ObjectReference s3ObjectReference = val.getArgument(0);
assertThat(s3ObjectReference.getBucketName(), equalTo(bucket));
assertThat(s3ObjectReference.getKey(), startsWith("s3 source/sqs/"));
event = (Event)JacksonEvent.fromMessage(val.getArgument(0).toString());

ackSet.add(event);
synchronized(this) {
ready.set(true);
this.notify();
}
try {
Thread.sleep(4000);
} catch (Exception e){}

return null;
}).when(s3Service).addS3Object(any(S3ObjectReference.class), any(AcknowledgementSet.class));
ExecutorService executor = Executors.newFixedThreadPool(2);
acknowledgementSetManager = new DefaultAcknowledgementSetManager(executor);
final SqsWorker objectUnderTest = createObjectUnderTest();
Thread sinkThread = new Thread(() -> {
try {
synchronized(this) {
while (!ready.get()) {
Thread.sleep(100);
this.wait();
}
if (event.getEventHandle() != null) {
event.getEventHandle().release(true);
}
}
} catch (Exception e){}
});
sinkThread.start();
final int sqsMessagesProcessed = objectUnderTest.processSqsMessages();

Thread.sleep(10000);

assertThat(deletedCount, equalTo((double)1.0));
assertThat(ackCallbackCount, equalTo((double)1.0));
}

/** The EventBridge test is disabled by default
* To run this test run only this one test with S3 bucket configured to use EventBridge to send notifications to SQS
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ && isEventBridgeEventTypeCreated(parsedMessage)) {
final Optional<DeleteMessageBatchRequestEntry> deleteMessageBatchRequestEntry = processS3Object(parsedMessage, s3ObjectReference, acknowledgementSet);
if (endToEndAcknowledgementsEnabled) {
deleteMessageBatchRequestEntry.ifPresent(waitingForAcknowledgements::add);
acknowledgementSet.complete();
} else {
deleteMessageBatchRequestEntry.ifPresent(deleteMessageBatchRequestEntryCollection::add);
}
Expand Down

0 comments on commit bb44726

Please sign in to comment.