Skip to content

Commit

Permalink
Fix race condition in SqsWorker when acknowledgements are enabled
Browse files Browse the repository at this point in the history
Signed-off-by: Krishna Kondaka <[email protected]>
  • Loading branch information
Krishna Kondaka committed Jul 11, 2023
1 parent 45b6e55 commit b2f6a8a
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
import org.opensearch.dataprepper.plugins.source.configuration.OnErrorOption;
import org.opensearch.dataprepper.plugins.source.configuration.SqsOptions;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.acknowledgements.DefaultAcknowledgementSetManager;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.JacksonEvent;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import org.junit.jupiter.api.AfterEach;
Expand All @@ -28,6 +32,9 @@
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
Expand All @@ -36,21 +43,34 @@
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.core.StringStartsWith.startsWith;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.lenient;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.doAnswer;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
class SqsWorkerIT {
private SqsClient sqsClient;
@Mock
private S3Service s3Service;
private S3SourceConfig s3SourceConfig;
private PluginMetrics pluginMetrics;
private S3ObjectGenerator s3ObjectGenerator;
private String bucket;
private Backoff backoff;
private AcknowledgementSetManager acknowledgementSetManager;
private Double receivedCount = 0.0;
private Double deletedCount = 0.0;
private Double ackCallbackCount = 0.0;
private Event event;
private AtomicBoolean ready = new AtomicBoolean(false);

@BeforeEach
void setUp() {
Expand All @@ -76,8 +96,8 @@ void setUp() {
final DistributionSummary distributionSummary = mock(DistributionSummary.class);
final Timer sqsMessageDelayTimer = mock(Timer.class);

when(pluginMetrics.counter(anyString())).thenReturn(sharedCounter);
when(pluginMetrics.summary(anyString())).thenReturn(distributionSummary);
lenient().when(pluginMetrics.counter(anyString())).thenReturn(sharedCounter);
lenient().when(pluginMetrics.summary(anyString())).thenReturn(distributionSummary);
when(pluginMetrics.timer(anyString())).thenReturn(sqsMessageDelayTimer);

final SqsOptions sqsOptions = mock(SqsOptions.class);
Expand All @@ -86,7 +106,7 @@ void setUp() {
when(sqsOptions.getMaximumMessages()).thenReturn(10);
when(sqsOptions.getWaitTime()).thenReturn(Duration.ofSeconds(10));
when(s3SourceConfig.getSqsOptions()).thenReturn(sqsOptions);
when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES);
lenient().when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES);
when(s3SourceConfig.getNotificationSource()).thenReturn(NotificationSourceOption.S3);
}

Expand Down Expand Up @@ -127,6 +147,145 @@ void processSqsMessages_should_return_at_least_one_message(final int numberOfObj
assertThat(sqsMessagesProcessed, lessThanOrEqualTo(numberOfObjectsToWrite));
}

@ParameterizedTest
@ValueSource(ints = {1})
void processSqsMessages_should_return_at_least_one_message_with_acks_with_callback_invoked_after_processS3Object_finishes(final int numberOfObjectsToWrite) throws IOException {
writeToS3(numberOfObjectsToWrite);

when(s3SourceConfig.getAcknowledgements()).thenReturn(true);
final Counter receivedCounter = mock(Counter.class);
final Counter deletedCounter = mock(Counter.class);
final Counter ackCallbackCounter = mock(Counter.class);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(receivedCounter);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(deletedCounter);
when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(ackCallbackCounter);
lenient().doAnswer((val) -> {
receivedCount += (double)val.getArgument(0);
return null;
}).when(receivedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
if (val.getArgument(0) != null) {
deletedCount += (double)val.getArgument(0);
}
return null;
}).when(deletedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
ackCallbackCount += 1;
return null;
}).when(ackCallbackCounter).increment();

doAnswer((val) -> {
AcknowledgementSet ackSet = val.getArgument(1);
S3ObjectReference s3ObjectReference = val.getArgument(0);
assertThat(s3ObjectReference.getBucketName(), equalTo(bucket));
assertThat(s3ObjectReference.getKey(), startsWith("s3 source/sqs/"));
event = (Event)JacksonEvent.fromMessage(val.getArgument(0).toString());
ackSet.add(event);
return null;
}).when(s3Service).addS3Object(any(S3ObjectReference.class), any(AcknowledgementSet.class));
ExecutorService executor = Executors.newFixedThreadPool(2);
acknowledgementSetManager = new DefaultAcknowledgementSetManager(executor);
final SqsWorker objectUnderTest = createObjectUnderTest();
Thread sinkThread = new Thread(() -> {
try {
synchronized(this) {
while (!ready.get()) {
Thread.sleep(100);
this.wait();
}
if (event.getEventHandle() != null) {
event.getEventHandle().release(true);
}
}
} catch (Exception e){}
});
sinkThread.start();
final int sqsMessagesProcessed = objectUnderTest.processSqsMessages();
synchronized(this) {
ready.set(true);
this.notify();
}
try {
Thread.sleep(10000);
} catch (Exception e){}

assertThat(deletedCount, equalTo((double)1.0));
assertThat(ackCallbackCount, equalTo((double)1.0));
}

@ParameterizedTest
@ValueSource(ints = {1})
void processSqsMessages_should_return_at_least_one_message_with_acks_with_callback_invoked_before_processS3Object_finishes(final int numberOfObjectsToWrite) throws IOException {
writeToS3(numberOfObjectsToWrite);

when(s3SourceConfig.getAcknowledgements()).thenReturn(true);
final Counter receivedCounter = mock(Counter.class);
final Counter deletedCounter = mock(Counter.class);
final Counter ackCallbackCounter = mock(Counter.class);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(receivedCounter);
when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(deletedCounter);
when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(ackCallbackCounter);
lenient().doAnswer((val) -> {
receivedCount += (double)val.getArgument(0);
return null;
}).when(receivedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
if (val.getArgument(0) != null) {
deletedCount += (double)val.getArgument(0);
}
return null;
}).when(deletedCounter).increment(any(Double.class));
lenient().doAnswer((val) -> {
ackCallbackCount += 1;
return null;
}).when(ackCallbackCounter).increment();

doAnswer((val) -> {
AcknowledgementSet ackSet = val.getArgument(1);
S3ObjectReference s3ObjectReference = val.getArgument(0);
assertThat(s3ObjectReference.getBucketName(), equalTo(bucket));
assertThat(s3ObjectReference.getKey(), startsWith("s3 source/sqs/"));
event = (Event)JacksonEvent.fromMessage(val.getArgument(0).toString());

ackSet.add(event);
synchronized(this) {
ready.set(true);
this.notify();
}
try {
Thread.sleep(4000);
} catch (Exception e){}

return null;
}).when(s3Service).addS3Object(any(S3ObjectReference.class), any(AcknowledgementSet.class));
ExecutorService executor = Executors.newFixedThreadPool(2);
acknowledgementSetManager = new DefaultAcknowledgementSetManager(executor);
final SqsWorker objectUnderTest = createObjectUnderTest();
Thread sinkThread = new Thread(() -> {
try {
synchronized(this) {
while (!ready.get()) {
Thread.sleep(100);
this.wait();
}
if (event.getEventHandle() != null) {
event.getEventHandle().release(true);
}
}
} catch (Exception e){}
});
sinkThread.start();
final int sqsMessagesProcessed = objectUnderTest.processSqsMessages();

try {
Thread.sleep(10000);
} catch (Exception e){}


assertThat(deletedCount, equalTo((double)1.0));
assertThat(ackCallbackCount, equalTo((double)1.0));
}

/** The EventBridge test is disabled by default
* To run this test run only this one test with S3 bucket configured to use EventBridge to send notifications to SQS
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.concurrent.atomic.AtomicBoolean;

public class SqsWorker implements Runnable {
private static final Logger LOG = LoggerFactory.getLogger(SqsWorker.class);
Expand Down Expand Up @@ -226,21 +227,33 @@ && isEventBridgeEventTypeCreated(parsedMessage)) {
for (ParsedMessage parsedMessage : parsedMessagesToRead) {
List<DeleteMessageBatchRequestEntry> waitingForAcknowledgements = new ArrayList<>();
AcknowledgementSet acknowledgementSet = null;
AtomicBoolean acknowledgementSetReady = new AtomicBoolean(false);
if (endToEndAcknowledgementsEnabled) {
// Acknowledgement Set timeout is slightly smaller than the visibility timeout;
int timeout = (int) sqsOptions.getVisibilityTimeout().getSeconds() - 2;
acknowledgementSet = acknowledgementSetManager.create((result) -> {
acknowledgementSetCallbackCounter.increment();
// Delete only if this is positive acknowledgement
if (result == true) {
deleteSqsMessages(waitingForAcknowledgements);
synchronized (waitingForAcknowledgements) {
while (!acknowledgementSetReady.get()) {
try {
waitingForAcknowledgements.wait();
} catch (InterruptedException e){}
}
acknowledgementSetCallbackCounter.increment();
// Delete only if this is positive acknowledgement
if (result == true) {
deleteSqsMessages(waitingForAcknowledgements);
}
}
}, Duration.ofSeconds(timeout));
}
final S3ObjectReference s3ObjectReference = populateS3Reference(parsedMessage.getBucketName(), parsedMessage.getObjectKey());
final Optional<DeleteMessageBatchRequestEntry> deleteMessageBatchRequestEntry = processS3Object(parsedMessage, s3ObjectReference, acknowledgementSet);
if (endToEndAcknowledgementsEnabled) {
deleteMessageBatchRequestEntry.ifPresent(waitingForAcknowledgements::add);
synchronized (waitingForAcknowledgements) {
acknowledgementSetReady.set(true);
waitingForAcknowledgements.notify();
}
} else {
deleteMessageBatchRequestEntry.ifPresent(deleteMessageBatchRequestEntryCollection::add);
}
Expand Down

0 comments on commit b2f6a8a

Please sign in to comment.