Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.3] Fix bucket ownership validation in S3 source #3011

Merged
merged 1 commit into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import org.apache.parquet.io.SeekableInputStream;
import org.opensearch.dataprepper.model.io.InputFile;
import org.opensearch.dataprepper.plugins.source.ownership.BucketOwnerProvider;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
Expand All @@ -18,17 +19,20 @@ public class S3InputFile implements InputFile {

private final S3ObjectReference s3ObjectReference;

private final BucketOwnerProvider bucketOwnerProvider;
private final S3ObjectPluginMetrics s3ObjectPluginMetrics;

private HeadObjectResponse metadata;

public S3InputFile(
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final S3ObjectPluginMetrics s3ObjectPluginMetrics
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final BucketOwnerProvider bucketOwnerProvider,
final S3ObjectPluginMetrics s3ObjectPluginMetrics
) {
this.s3Client = s3Client;
this.s3ObjectReference = s3ObjectReference;
this.bucketOwnerProvider = bucketOwnerProvider;
this.s3ObjectPluginMetrics = s3ObjectPluginMetrics;
}

Expand All @@ -49,7 +53,7 @@ public long getLength() {
@Override
public SeekableInputStream newStream() {
return new S3InputStream(
s3Client, s3ObjectReference, getMetadata(), s3ObjectPluginMetrics, DEFAULT_RETRY_DELAY, DEFAULT_RETRIES);
s3Client, s3ObjectReference, bucketOwnerProvider, getMetadata(), s3ObjectPluginMetrics, DEFAULT_RETRY_DELAY, DEFAULT_RETRIES);
}

/**
Expand All @@ -58,9 +62,12 @@ public SeekableInputStream newStream() {
*/
private synchronized HeadObjectResponse getMetadata() {
if (metadata == null) {
final HeadObjectRequest request = HeadObjectRequest.builder()
final HeadObjectRequest.Builder headRequestBuilder = HeadObjectRequest.builder()
.bucket(s3ObjectReference.getBucketName())
.key(s3ObjectReference.getKey())
.key(s3ObjectReference.getKey());
bucketOwnerProvider.getBucketOwner(s3ObjectReference.getBucketName())
.ifPresent(headRequestBuilder::expectedBucketOwner);
final HeadObjectRequest request = headRequestBuilder
.build();
metadata = s3Client.headObject(request);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dev.failsafe.function.CheckedSupplier;
import org.apache.http.ConnectionClosedException;
import org.apache.parquet.io.SeekableInputStream;
import org.opensearch.dataprepper.plugins.source.ownership.BucketOwnerProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.sync.ResponseTransformer;
Expand Down Expand Up @@ -73,12 +74,13 @@ class S3InputStream extends SeekableInputStream {
private RetryPolicy<Integer> retryPolicyReturningInteger;

public S3InputStream(
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final HeadObjectResponse metadata,
final S3ObjectPluginMetrics s3ObjectPluginMetrics,
final Duration retryDelay,
final int retries
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final BucketOwnerProvider bucketOwnerProvider,
final HeadObjectResponse metadata,
final S3ObjectPluginMetrics s3ObjectPluginMetrics,
final Duration retryDelay,
final int retries
) {
this.s3Client = s3Client;
this.s3ObjectReference = s3ObjectReference;
Expand All @@ -90,6 +92,9 @@ public S3InputStream(
.bucket(this.s3ObjectReference.getBucketName())
.key(this.s3ObjectReference.getKey());

bucketOwnerProvider.getBucketOwner(this.s3ObjectReference.getBucketName())
.ifPresent(getObjectRequestBuilder::expectedBucketOwner);

this.retryPolicyReturningByteArray = RetryPolicy.<byte[]>builder()
.handle(RETRYABLE_EXCEPTIONS)
.withDelay(retryDelay)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ private void doParseObject(final AcknowledgementSet acknowledgementSet, final S3

LOG.info("Read S3 object: {}", s3ObjectReference);

final S3InputFile inputFile = new S3InputFile(s3Client, s3ObjectReference, s3ObjectPluginMetrics);
final S3InputFile inputFile = new S3InputFile(s3Client, s3ObjectReference, bucketOwnerProvider, s3ObjectPluginMetrics);

final CompressionOption fileCompressionOption = compressionOption != CompressionOption.AUTOMATIC ?
compressionOption : CompressionOption.fromFileName(s3ObjectReference.getKey());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
import org.apache.parquet.io.SeekableInputStream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.opensearch.dataprepper.plugins.source.ownership.BucketOwnerProvider;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;

import java.util.Optional;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand All @@ -20,15 +27,24 @@ public class S3InputFileTest {
private S3Client s3Client;
private S3ObjectReference s3ObjectReference;
private S3ObjectPluginMetrics s3ObjectPluginMetrics;
private S3InputFile s3InputFile;
private String bucketName;
private String key;
private BucketOwnerProvider bucketOwnerProvider;

/**
 * Prepares fresh collaborators before every test: a random bucket name and key,
 * mocks for the S3 client, plugin metrics and bucket-owner provider, and an
 * object reference mock stubbed to report that bucket name and key.
 */
@BeforeEach
public void setUp() {
    // Random identifiers so assertions cannot pass by accident on a fixed value.
    bucketName = UUID.randomUUID().toString();
    key = UUID.randomUUID().toString();

    s3Client = mock(S3Client.class);
    s3ObjectPluginMetrics = mock(S3ObjectPluginMetrics.class);
    bucketOwnerProvider = mock(BucketOwnerProvider.class);

    s3ObjectReference = mock(S3ObjectReference.class);
    when(s3ObjectReference.getKey()).thenReturn(key);
    when(s3ObjectReference.getBucketName()).thenReturn(bucketName);
}

s3InputFile = new S3InputFile(s3Client, s3ObjectReference, s3ObjectPluginMetrics);
/**
 * Creates the {@code S3InputFile} under test from the test's current
 * client, object reference, bucket-owner provider and plugin metrics fields.
 */
private S3InputFile createObjectUnderTest() {
    final S3InputFile objectUnderTest =
            new S3InputFile(s3Client, s3ObjectReference, bucketOwnerProvider, s3ObjectPluginMetrics);
    return objectUnderTest;
}

@Test
Expand All @@ -37,18 +53,50 @@ public void testGetLength() {
when(s3Client.headObject(any(HeadObjectRequest.class))).thenReturn(headObjectResponse);
when(headObjectResponse.contentLength()).thenReturn(12345L);

long length = s3InputFile.getLength();
long length = createObjectUnderTest().getLength();

assertThat(length, equalTo(12345L));
final ArgumentCaptor<HeadObjectRequest> headObjectRequestArgumentCaptor = ArgumentCaptor.forClass(HeadObjectRequest.class);
verify(s3Client, times(1)).headObject(headObjectRequestArgumentCaptor.capture());

final HeadObjectRequest actualHeadObjectRequest = headObjectRequestArgumentCaptor.getValue();
assertAll(
() -> assertThat(actualHeadObjectRequest.bucket(), equalTo(bucketName)),
() -> assertThat(actualHeadObjectRequest.key(), equalTo(key)),
() -> assertThat(actualHeadObjectRequest.expectedBucketOwner(), nullValue())
);
}

/**
 * Verifies that when the bucket-owner provider reports an owner for the bucket,
 * {@code getLength()} issues exactly one HeadObject request carrying the bucket,
 * the key and that owner as the expected bucket owner.
 */
@Test
public void getLength_requests_head_for_bucket_key_and_owner_when_bucket_has_owner() {
    final HeadObjectResponse headObjectResponse = mock(HeadObjectResponse.class);
    when(s3Client.headObject(any(HeadObjectRequest.class))).thenReturn(headObjectResponse);
    when(headObjectResponse.contentLength()).thenReturn(12345L);

    final String owner = UUID.randomUUID().toString();
    when(bucketOwnerProvider.getBucketOwner(bucketName)).thenReturn(Optional.of(owner));

    final long length = createObjectUnderTest().getLength();

    assertThat(length, equalTo(12345L));

    // A single capturing verify both checks the call count and exposes the request;
    // a separate verify with any(...) beforehand would only repeat the same check.
    final ArgumentCaptor<HeadObjectRequest> headObjectRequestArgumentCaptor = ArgumentCaptor.forClass(HeadObjectRequest.class);
    verify(s3Client, times(1)).headObject(headObjectRequestArgumentCaptor.capture());

    final HeadObjectRequest actualHeadObjectRequest = headObjectRequestArgumentCaptor.getValue();
    assertAll(
            () -> assertThat(actualHeadObjectRequest.bucket(), equalTo(bucketName)),
            () -> assertThat(actualHeadObjectRequest.key(), equalTo(key)),
            () -> assertThat(actualHeadObjectRequest.expectedBucketOwner(), equalTo(owner))
    );
}

@Test
public void testNewStream() {
HeadObjectResponse headObjectResponse = mock(HeadObjectResponse.class);
when(s3Client.headObject(any(HeadObjectRequest.class))).thenReturn(headObjectResponse);

SeekableInputStream seekableInputStream = s3InputFile.newStream();
SeekableInputStream seekableInputStream = createObjectUnderTest().newStream();

assertThat(seekableInputStream.getClass(), equalTo(S3InputStream.class));
}
Expand Down
Loading
Loading