diff --git a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/acknowledgements/AcknowledgementSet.java b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/acknowledgements/AcknowledgementSet.java index edd36f4ee7..c95c2e5f88 100644 --- a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/acknowledgements/AcknowledgementSet.java +++ b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/acknowledgements/AcknowledgementSet.java @@ -50,4 +50,12 @@ public interface AcknowledgementSet { * @since 2.2 */ public boolean release(final EventHandle eventHandle, final boolean result); + + /** + * Indicates that the addition of initial set of events to + * the acknowledgement set is completed. + * It is possible that more events are added to the set as the + * initial events are going through the pipeline line. + */ + public void complete(); } diff --git a/data-prepper-core/src/integrationTest/java/org/opensearch/dataprepper/plugins/InMemorySource.java b/data-prepper-core/src/integrationTest/java/org/opensearch/dataprepper/plugins/InMemorySource.java index e28a9a3029..133340cc39 100644 --- a/data-prepper-core/src/integrationTest/java/org/opensearch/dataprepper/plugins/InMemorySource.java +++ b/data-prepper-core/src/integrationTest/java/org/opensearch/dataprepper/plugins/InMemorySource.java @@ -114,6 +114,10 @@ public void run() { while (!isStopped) { try { final List> records = inMemorySourceAccessor.read(testingKey); + if (records.size() == 0) { + Thread.sleep(1000); + continue; + } AcknowledgementSet ackSet = acknowledgementSetManager.create((result) -> { @@ -121,6 +125,7 @@ public void run() { }, Duration.ofSeconds(15)); records.stream().forEach((record) -> { ackSet.add(record.getData()); }); + ackSet.complete(); writeToBuffer(records); } catch (final Exception ex) { LOG.error("Error during source loop.", ex); diff --git a/data-prepper-core/src/main/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSet.java b/data-prepper-core/src/main/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSet.java index c5bd269105..3c8fe12159 100644 --- a/data-prepper-core/src/main/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSet.java +++ b/data-prepper-core/src/main/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSet.java @@ -34,6 +34,7 @@ public class DefaultAcknowledgementSet implements AcknowledgementSet { private final Map pendingAcknowledgments; private Future callbackFuture; private final DefaultAcknowledgementSetMetrics metrics; + private boolean completed; public DefaultAcknowledgementSet(final ExecutorService executor, final Consumer callback, final Duration expiryTime, final DefaultAcknowledgementSetMetrics metrics) { this.callback = callback; @@ -42,6 +43,7 @@ public DefaultAcknowledgementSet(final ExecutorService executor, final Consumer< this.expiryTime = Instant.now().plusMillis(expiryTime.toMillis()); this.callbackFuture = null; this.metrics = metrics; + this.completed = false; pendingAcknowledgments = new HashMap<>(); lock = new ReentrantLock(true); } @@ -84,6 +86,8 @@ public boolean isDone() { if (Instant.now().isAfter(expiryTime)) { if (callbackFuture != null) { callbackFuture.cancel(true); + callbackFuture = null; + LOG.warn("AcknowledgementSet expired"); } metrics.increment(DefaultAcknowledgementSetMetrics.EXPIRED_METRIC_NAME); return true; @@ -98,6 +102,19 @@ public Instant getExpiryTime() { return expiryTime; } + @Override + public void complete() { + lock.lock(); + try { + completed = true; + if (pendingAcknowledgments.size() == 0) { + callbackFuture = executor.submit(() -> callback.accept(this.result)); + } + } finally { + lock.unlock(); + } + } + @Override public boolean release(final EventHandle eventHandle, final boolean result) { lock.lock(); @@ -114,9 +131,11 @@ public boolean release(final EventHandle eventHandle, final boolean result) { } if (pendingAcknowledgments.get(eventHandle).decrementAndGet() == 0) { pendingAcknowledgments.remove(eventHandle); - if (pendingAcknowledgments.size() == 0) { + if (completed && pendingAcknowledgments.size() == 0) { callbackFuture = executor.submit(() -> callback.accept(this.result)); return true; + } else if (pendingAcknowledgments.size() == 0) { + LOG.warn("Acknowledgement set is not completed. Delaying callback until it is completed"); } } } finally { diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetManagerTests.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetManagerTests.java index 6538e5c89b..9b015aea72 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetManagerTests.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetManagerTests.java @@ -64,6 +64,7 @@ void setup() { AcknowledgementSet acknowledgementSet1 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT_MS); acknowledgementSet1.add(event1); acknowledgementSet1.add(event2); + acknowledgementSet1.complete(); } DefaultAcknowledgementSetManager createObjectUnderTest() { @@ -98,6 +99,7 @@ void testMultipleAcknowledgementSets() throws InterruptedException { AcknowledgementSet acknowledgementSet2 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT_MS); acknowledgementSet2.add(event3); + acknowledgementSet2.complete(); acknowledgementSetManager.releaseEventReference(eventHandle2, true); acknowledgementSetManager.releaseEventReference(eventHandle3, true); diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetTests.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetTests.java index afc5d5683e..c4403e4b2f 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetTests.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetTests.java @@ -89,6 +89,7 @@ void setupEvent() { @Test void testDefaultAcknowledgementSetBasic() throws Exception { defaultAcknowledgementSet.add(event); + defaultAcknowledgementSet.complete(); assertThat(handle, not(equalTo(null))); assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet)); assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true)); @@ -97,6 +98,7 @@ void testDefaultAcknowledgementSetBasic() throws Exception { @Test void testDefaultAcknowledgementSetMultipleAcquireAndRelease() throws Exception { defaultAcknowledgementSet.add(event); + defaultAcknowledgementSet.complete(); assertThat(handle, not(equalTo(null))); assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet)); defaultAcknowledgementSet.acquire(handle); @@ -111,6 +113,7 @@ void testDefaultAcknowledgementSetMultipleAcquireAndRelease() throws Exception { @Test void testDefaultAcknowledgementInvalidAcquire() { defaultAcknowledgementSet.add(event); + defaultAcknowledgementSet.complete(); DefaultAcknowledgementSet secondAcknowledgementSet = createObjectUnderTest(); DefaultEventHandle handle2 = new DefaultEventHandle(secondAcknowledgementSet); defaultAcknowledgementSet.acquire(handle2); @@ -120,6 +123,7 @@ void testDefaultAcknowledgementInvalidAcquire() { @Test void testDefaultAcknowledgementInvalidRelease() { defaultAcknowledgementSet.add(event); + defaultAcknowledgementSet.complete(); DefaultAcknowledgementSet secondAcknowledgementSet = createObjectUnderTest(); DefaultEventHandle handle2 = new DefaultEventHandle(secondAcknowledgementSet); assertThat(defaultAcknowledgementSet.release(handle2, true), equalTo(false)); @@ -129,6 +133,7 @@ void testDefaultAcknowledgementInvalidRelease() { @Test void testDefaultAcknowledgementDuplicateReleaseError() throws Exception { defaultAcknowledgementSet.add(event); + defaultAcknowledgementSet.complete(); assertThat(handle, not(equalTo(null))); assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet)); assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true)); @@ -144,6 +149,7 @@ void testDefaultAcknowledgementSetWithCustomCallback() throws Exception { } ); defaultAcknowledgementSet.add(event); + defaultAcknowledgementSet.complete(); assertThat(handle, not(equalTo(null))); assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet)); assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true)); @@ -162,6 +168,7 @@ void testDefaultAcknowledgementSetNegativeAcknowledgements() throws Exception { } ); defaultAcknowledgementSet.add(event); + defaultAcknowledgementSet.complete(); assertThat(handle, not(equalTo(null))); assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet)); defaultAcknowledgementSet.acquire(handle); @@ -190,6 +197,7 @@ void testDefaultAcknowledgementSetExpirations() throws Exception { } ); defaultAcknowledgementSet.add(event); + defaultAcknowledgementSet.complete(); assertThat(handle, not(equalTo(null))); assertThat(handle.getAcknowledgementSet(), equalTo(defaultAcknowledgementSet)); assertThat(defaultAcknowledgementSet.release(handle, true), equalTo(true)); diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaSourceCustomConsumer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaSourceCustomConsumer.java index 6bd35391b9..67505a67f4 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaSourceCustomConsumer.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaSourceCustomConsumer.java @@ -147,6 +147,8 @@ public void consumeRecords() throws Exception { if (!acknowledgementsEnabled) { offsets.forEach((partition, offsetRange) -> updateOffsetsToCommit(partition, new OffsetAndMetadata(offsetRange.getMaximum() + 1))); + } else { + acknowledgementSet.complete(); } } } diff --git a/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/SqsWorkerIT.java b/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/SqsWorkerIT.java index 144bb49b22..23c365ec45 100644 --- a/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/SqsWorkerIT.java +++ b/data-prepper-plugins/s3-source/src/integrationTest/java/org/opensearch/dataprepper/plugins/source/SqsWorkerIT.java @@ -13,6 +13,10 @@ import org.opensearch.dataprepper.plugins.source.configuration.OnErrorOption; import org.opensearch.dataprepper.plugins.source.configuration.SqsOptions; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.acknowledgements.DefaultAcknowledgementSetManager; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Timer; import org.junit.jupiter.api.AfterEach; @@ -28,6 +32,9 @@ import java.io.IOException; import java.time.Duration; import java.time.Instant; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.UUID; import static org.hamcrest.CoreMatchers.equalTo; @@ -36,14 +43,22 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.core.StringStartsWith.startsWith; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.doAnswer; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +@ExtendWith(MockitoExtension.class) class SqsWorkerIT { private SqsClient sqsClient; + @Mock private S3Service s3Service; private S3SourceConfig s3SourceConfig; private PluginMetrics pluginMetrics; @@ -51,6 +66,11 @@ class SqsWorkerIT { private String bucket; private Backoff backoff; private AcknowledgementSetManager acknowledgementSetManager; + private Double receivedCount = 0.0; + private Double deletedCount = 0.0; + private Double ackCallbackCount = 0.0; + private Event event; + private AtomicBoolean ready = new AtomicBoolean(false); @BeforeEach void setUp() { @@ -76,8 +96,8 @@ void setUp() { final DistributionSummary distributionSummary = mock(DistributionSummary.class); final Timer sqsMessageDelayTimer = mock(Timer.class); - when(pluginMetrics.counter(anyString())).thenReturn(sharedCounter); - when(pluginMetrics.summary(anyString())).thenReturn(distributionSummary); + lenient().when(pluginMetrics.counter(anyString())).thenReturn(sharedCounter); + lenient().when(pluginMetrics.summary(anyString())).thenReturn(distributionSummary); when(pluginMetrics.timer(anyString())).thenReturn(sqsMessageDelayTimer); final SqsOptions sqsOptions = mock(SqsOptions.class); @@ -86,7 +106,7 @@ void setUp() { when(sqsOptions.getMaximumMessages()).thenReturn(10); when(sqsOptions.getWaitTime()).thenReturn(Duration.ofSeconds(10)); when(s3SourceConfig.getSqsOptions()).thenReturn(sqsOptions); - when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES); + lenient().when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES); when(s3SourceConfig.getNotificationSource()).thenReturn(NotificationSourceOption.S3); } @@ -127,6 +147,140 @@ void processSqsMessages_should_return_at_least_one_message(final int numberOfObj assertThat(sqsMessagesProcessed, lessThanOrEqualTo(numberOfObjectsToWrite)); } + @ParameterizedTest + @ValueSource(ints = {1}) + void processSqsMessages_should_return_at_least_one_message_with_acks_with_callback_invoked_after_processS3Object_finishes(final int numberOfObjectsToWrite) throws IOException, InterruptedException { + writeToS3(numberOfObjectsToWrite); + + when(s3SourceConfig.getAcknowledgements()).thenReturn(true); + final Counter receivedCounter = mock(Counter.class); + final Counter deletedCounter = mock(Counter.class); + final Counter ackCallbackCounter = mock(Counter.class); + when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(receivedCounter); + when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(deletedCounter); + when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(ackCallbackCounter); + lenient().doAnswer((val) -> { + receivedCount += (double)val.getArgument(0); + return null; + }).when(receivedCounter).increment(any(Double.class)); + lenient().doAnswer((val) -> { + if (val.getArgument(0) != null) { + deletedCount += (double)val.getArgument(0); + } + return null; + }).when(deletedCounter).increment(any(Double.class)); + lenient().doAnswer((val) -> { + ackCallbackCount += 1; + return null; + }).when(ackCallbackCounter).increment(); + + doAnswer((val) -> { + AcknowledgementSet ackSet = val.getArgument(1); + S3ObjectReference s3ObjectReference = val.getArgument(0); + assertThat(s3ObjectReference.getBucketName(), equalTo(bucket)); + assertThat(s3ObjectReference.getKey(), startsWith("s3 source/sqs/")); + event = (Event)JacksonEvent.fromMessage(val.getArgument(0).toString()); + ackSet.add(event); + return null; + }).when(s3Service).addS3Object(any(S3ObjectReference.class), any(AcknowledgementSet.class)); + ExecutorService executor = Executors.newFixedThreadPool(2); + acknowledgementSetManager = new DefaultAcknowledgementSetManager(executor); + final SqsWorker objectUnderTest = createObjectUnderTest(); + Thread sinkThread = new Thread(() -> { + try { + synchronized(this) { + while (!ready.get()) { + Thread.sleep(100); + this.wait(); + } + if (event.getEventHandle() != null) { + event.getEventHandle().release(true); + } + } + } catch (Exception e){} + }); + sinkThread.start(); + final int sqsMessagesProcessed = objectUnderTest.processSqsMessages(); + synchronized(this) { + ready.set(true); + this.notify(); + } + Thread.sleep(10000); + + assertThat(deletedCount, equalTo((double)1.0)); + assertThat(ackCallbackCount, equalTo((double)1.0)); + } + + @ParameterizedTest + @ValueSource(ints = {1}) + void processSqsMessages_should_return_at_least_one_message_with_acks_with_callback_invoked_before_processS3Object_finishes(final int numberOfObjectsToWrite) throws IOException, InterruptedException { + writeToS3(numberOfObjectsToWrite); + + when(s3SourceConfig.getAcknowledgements()).thenReturn(true); + final Counter receivedCounter = mock(Counter.class); + final Counter deletedCounter = mock(Counter.class); + final Counter ackCallbackCounter = mock(Counter.class); + when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(receivedCounter); + when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(deletedCounter); + when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(ackCallbackCounter); + lenient().doAnswer((val) -> { + receivedCount += (double)val.getArgument(0); + return null; + }).when(receivedCounter).increment(any(Double.class)); + lenient().doAnswer((val) -> { + if (val.getArgument(0) != null) { + deletedCount += (double)val.getArgument(0); + } + return null; + }).when(deletedCounter).increment(any(Double.class)); + lenient().doAnswer((val) -> { + ackCallbackCount += 1; + return null; + }).when(ackCallbackCounter).increment(); + + doAnswer((val) -> { + AcknowledgementSet ackSet = val.getArgument(1); + S3ObjectReference s3ObjectReference = val.getArgument(0); + assertThat(s3ObjectReference.getBucketName(), equalTo(bucket)); + assertThat(s3ObjectReference.getKey(), startsWith("s3 source/sqs/")); + event = (Event)JacksonEvent.fromMessage(val.getArgument(0).toString()); + + ackSet.add(event); + synchronized(this) { + ready.set(true); + this.notify(); + } + try { + Thread.sleep(4000); + } catch (Exception e){} + + return null; + }).when(s3Service).addS3Object(any(S3ObjectReference.class), any(AcknowledgementSet.class)); + ExecutorService executor = Executors.newFixedThreadPool(2); + acknowledgementSetManager = new DefaultAcknowledgementSetManager(executor); + final SqsWorker objectUnderTest = createObjectUnderTest(); + Thread sinkThread = new Thread(() -> { + try { + synchronized(this) { + while (!ready.get()) { + Thread.sleep(100); + this.wait(); + } + if (event.getEventHandle() != null) { + event.getEventHandle().release(true); + } + } + } catch (Exception e){} + }); + sinkThread.start(); + final int sqsMessagesProcessed = objectUnderTest.processSqsMessages(); + + Thread.sleep(10000); + + assertThat(deletedCount, equalTo((double)1.0)); + assertThat(ackCallbackCount, equalTo((double)1.0)); + } + /** The EventBridge test is disabled by default * To run this test run only this one test with S3 bucket configured to use EventBridge to send notifications to SQS */ diff --git a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/SqsWorker.java b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/SqsWorker.java index 4a1beecf61..847a70c5bb 100644 --- a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/SqsWorker.java +++ b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/SqsWorker.java @@ -241,6 +241,7 @@ && isEventBridgeEventTypeCreated(parsedMessage)) { final Optional deleteMessageBatchRequestEntry = processS3Object(parsedMessage, s3ObjectReference, acknowledgementSet); if (endToEndAcknowledgementsEnabled) { deleteMessageBatchRequestEntry.ifPresent(waitingForAcknowledgements::add); + acknowledgementSet.complete(); } else { deleteMessageBatchRequestEntry.ifPresent(deleteMessageBatchRequestEntryCollection::add); }