Skip to content

Commit

Permalink
Fix bucket ownership validation. Resolves #3005 (#3009)
Browse files: browse the repository at this point in the history
Signed-off-by: David Venable <[email protected]>
  • Loading branch information
dlvenable committed Jul 12, 2023
1 parent 2b7b7da commit decccb9
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import org.apache.parquet.io.SeekableInputStream;
import org.opensearch.dataprepper.model.io.InputFile;
import org.opensearch.dataprepper.plugins.source.ownership.BucketOwnerProvider;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
Expand All @@ -18,17 +19,20 @@ public class S3InputFile implements InputFile {

private final S3ObjectReference s3ObjectReference;

private final BucketOwnerProvider bucketOwnerProvider;
private final S3ObjectPluginMetrics s3ObjectPluginMetrics;

private HeadObjectResponse metadata;

public S3InputFile(
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final S3ObjectPluginMetrics s3ObjectPluginMetrics
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final BucketOwnerProvider bucketOwnerProvider,
final S3ObjectPluginMetrics s3ObjectPluginMetrics
) {
this.s3Client = s3Client;
this.s3ObjectReference = s3ObjectReference;
this.bucketOwnerProvider = bucketOwnerProvider;
this.s3ObjectPluginMetrics = s3ObjectPluginMetrics;
}

Expand All @@ -49,7 +53,7 @@ public long getLength() {
@Override
public SeekableInputStream newStream() {
return new S3InputStream(
s3Client, s3ObjectReference, getMetadata(), s3ObjectPluginMetrics, DEFAULT_RETRY_DELAY, DEFAULT_RETRIES);
s3Client, s3ObjectReference, bucketOwnerProvider, getMetadata(), s3ObjectPluginMetrics, DEFAULT_RETRY_DELAY, DEFAULT_RETRIES);
}

/**
Expand All @@ -58,9 +62,12 @@ public SeekableInputStream newStream() {
*/
private synchronized HeadObjectResponse getMetadata() {
if (metadata == null) {
final HeadObjectRequest request = HeadObjectRequest.builder()
final HeadObjectRequest.Builder headRequestBuilder = HeadObjectRequest.builder()
.bucket(s3ObjectReference.getBucketName())
.key(s3ObjectReference.getKey())
.key(s3ObjectReference.getKey());
bucketOwnerProvider.getBucketOwner(s3ObjectReference.getBucketName())
.ifPresent(headRequestBuilder::expectedBucketOwner);
final HeadObjectRequest request = headRequestBuilder
.build();
metadata = s3Client.headObject(request);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dev.failsafe.function.CheckedSupplier;
import org.apache.http.ConnectionClosedException;
import org.apache.parquet.io.SeekableInputStream;
import org.opensearch.dataprepper.plugins.source.ownership.BucketOwnerProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.sync.ResponseTransformer;
Expand Down Expand Up @@ -73,12 +74,13 @@ class S3InputStream extends SeekableInputStream {
private RetryPolicy<Integer> retryPolicyReturningInteger;

public S3InputStream(
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final HeadObjectResponse metadata,
final S3ObjectPluginMetrics s3ObjectPluginMetrics,
final Duration retryDelay,
final int retries
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final BucketOwnerProvider bucketOwnerProvider,
final HeadObjectResponse metadata,
final S3ObjectPluginMetrics s3ObjectPluginMetrics,
final Duration retryDelay,
final int retries
) {
this.s3Client = s3Client;
this.s3ObjectReference = s3ObjectReference;
Expand All @@ -90,6 +92,9 @@ public S3InputStream(
.bucket(this.s3ObjectReference.getBucketName())
.key(this.s3ObjectReference.getKey());

bucketOwnerProvider.getBucketOwner(this.s3ObjectReference.getBucketName())
.ifPresent(getObjectRequestBuilder::expectedBucketOwner);

this.retryPolicyReturningByteArray = RetryPolicy.<byte[]>builder()
.handle(RETRYABLE_EXCEPTIONS)
.withDelay(retryDelay)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private void doParseObject(final AcknowledgementSet acknowledgementSet, final S3

LOG.info("Read S3 object: {}", s3ObjectReference);

final S3InputFile inputFile = new S3InputFile(s3Client, s3ObjectReference, s3ObjectPluginMetrics);
final S3InputFile inputFile = new S3InputFile(s3Client, s3ObjectReference, bucketOwnerProvider, s3ObjectPluginMetrics);

final CompressionOption fileCompressionOption = compressionOption != CompressionOption.AUTOMATIC ?
compressionOption : CompressionOption.fromFileName(s3ObjectReference.getKey());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
import org.apache.parquet.io.SeekableInputStream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.opensearch.dataprepper.plugins.source.ownership.BucketOwnerProvider;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;

import java.util.Optional;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand All @@ -20,15 +27,24 @@ public class S3InputFileTest {
private S3Client s3Client;
private S3ObjectReference s3ObjectReference;
private S3ObjectPluginMetrics s3ObjectPluginMetrics;
private S3InputFile s3InputFile;
private String bucketName;
private String key;
private BucketOwnerProvider bucketOwnerProvider;

/**
 * Prepares the fixture for each test: a random bucket name and key, fresh mocks
 * for every collaborator, and an object reference stubbed to report that
 * bucket name and key.
 */
@BeforeEach
public void setUp() {
    // Random identifiers, regenerated for every test.
    bucketName = UUID.randomUUID().toString();
    key = UUID.randomUUID().toString();

    s3Client = mock(S3Client.class);
    s3ObjectPluginMetrics = mock(S3ObjectPluginMetrics.class);
    bucketOwnerProvider = mock(BucketOwnerProvider.class);

    s3ObjectReference = mock(S3ObjectReference.class);
    when(s3ObjectReference.getKey()).thenReturn(key);
    when(s3ObjectReference.getBucketName()).thenReturn(bucketName);
}

s3InputFile = new S3InputFile(s3Client, s3ObjectReference, s3ObjectPluginMetrics);
/**
 * Builds a new {@link S3InputFile} from the client, object reference,
 * bucket owner provider, and metrics held in this test's fields.
 *
 * @return a newly constructed {@link S3InputFile}
 */
private S3InputFile createObjectUnderTest() {
    final S3InputFile objectUnderTest =
            new S3InputFile(s3Client, s3ObjectReference, bucketOwnerProvider, s3ObjectPluginMetrics);
    return objectUnderTest;
}

@Test
Expand All @@ -37,18 +53,50 @@ public void testGetLength() {
when(s3Client.headObject(any(HeadObjectRequest.class))).thenReturn(headObjectResponse);
when(headObjectResponse.contentLength()).thenReturn(12345L);

long length = s3InputFile.getLength();
long length = createObjectUnderTest().getLength();

assertThat(length, equalTo(12345L));
final ArgumentCaptor<HeadObjectRequest> headObjectRequestArgumentCaptor = ArgumentCaptor.forClass(HeadObjectRequest.class);
verify(s3Client, times(1)).headObject(headObjectRequestArgumentCaptor.capture());

final HeadObjectRequest actualHeadObjectRequest = headObjectRequestArgumentCaptor.getValue();
assertAll(
() -> assertThat(actualHeadObjectRequest.bucket(), equalTo(bucketName)),
() -> assertThat(actualHeadObjectRequest.key(), equalTo(key)),
() -> assertThat(actualHeadObjectRequest.expectedBucketOwner(), nullValue())
);
}

/**
 * Verifies that when the {@link BucketOwnerProvider} reports an owner for the
 * bucket, {@code getLength()} returns the content length from the HEAD response
 * and issues exactly one HEAD request carrying the bucket name, the key, and the
 * owner as the expected bucket owner.
 */
@Test
public void getLength_requests_head_for_bucket_key_and_owner_when_bucket_has_owner() {
    final HeadObjectResponse headObjectResponse = mock(HeadObjectResponse.class);
    when(s3Client.headObject(any(HeadObjectRequest.class))).thenReturn(headObjectResponse);
    when(headObjectResponse.contentLength()).thenReturn(12345L);

    final String owner = UUID.randomUUID().toString();
    when(bucketOwnerProvider.getBucketOwner(bucketName)).thenReturn(Optional.of(owner));

    final long length = createObjectUnderTest().getLength();

    assertThat(length, equalTo(12345L));

    // A single capturing verification checks both the invocation count and the
    // request contents, so no separate any()-based verify is needed.
    final ArgumentCaptor<HeadObjectRequest> headObjectRequestArgumentCaptor = ArgumentCaptor.forClass(HeadObjectRequest.class);
    verify(s3Client, times(1)).headObject(headObjectRequestArgumentCaptor.capture());

    final HeadObjectRequest actualHeadObjectRequest = headObjectRequestArgumentCaptor.getValue();
    assertAll(
            () -> assertThat(actualHeadObjectRequest.bucket(), equalTo(bucketName)),
            () -> assertThat(actualHeadObjectRequest.key(), equalTo(key)),
            () -> assertThat(actualHeadObjectRequest.expectedBucketOwner(), equalTo(owner))
    );
}

@Test
public void testNewStream() {
HeadObjectResponse headObjectResponse = mock(HeadObjectResponse.class);
when(s3Client.headObject(any(HeadObjectRequest.class))).thenReturn(headObjectResponse);

SeekableInputStream seekableInputStream = s3InputFile.newStream();
SeekableInputStream seekableInputStream = createObjectUnderTest().newStream();

assertThat(seekableInputStream.getClass(), equalTo(S3InputStream.class));
}
Expand Down
Loading

0 comments on commit decccb9

Please sign in to comment.