Skip to content

Commit

Permalink
Retry s3 reads on socket exceptions.
Browse files Browse the repository at this point in the history
S3 will reset the conenction on their end frequently. To not lose data,
data prepper should retry all socket exceptions by attempting to re-open
the stream.

Signed-off-by: Adi Suresh <[email protected]>
  • Loading branch information
asuresh8 committed Jul 7, 2023
1 parent 45b6e55 commit 74878a1
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 30 deletions.
1 change: 1 addition & 0 deletions data-prepper-plugins/s3-source/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies {
implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.15.2'
implementation 'org.xerial.snappy:snappy-java:1.1.10.1'
implementation 'org.apache.parquet:parquet-common:1.12.3'
implementation 'dev.failsafe:failsafe:3.3.2'
testImplementation 'org.apache.commons:commons-lang3:3.12.0'
testImplementation 'com.github.tomakehurst:wiremock:3.0.0-beta-8'
testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;

import java.time.Duration;

public class S3InputFile implements InputFile {

private static final Duration DEFAULT_RETRY_DELAY = Duration.ofSeconds(6);

private static final int DEFAULT_RETRIES = 10;

private final S3Client s3Client;

private final S3ObjectReference s3ObjectReference;
Expand Down Expand Up @@ -42,8 +48,8 @@ public long getLength() {
*/
@Override
public SeekableInputStream newStream() {

return new S3InputStream(s3Client, s3ObjectReference, getMetadata(), s3ObjectPluginMetrics);
return new S3InputStream(
s3Client, s3ObjectReference, getMetadata(), s3ObjectPluginMetrics, DEFAULT_RETRY_DELAY, DEFAULT_RETRIES);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package org.opensearch.dataprepper.plugins.source;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.LongAdder;

import com.google.common.base.Preconditions;
import com.google.common.io.ByteStreams;
import dev.failsafe.Failsafe;
import dev.failsafe.FailsafeException;
import dev.failsafe.RetryPolicy;
import dev.failsafe.function.CheckedSupplier;
import org.apache.parquet.io.SeekableInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -19,6 +17,16 @@
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
import software.amazon.awssdk.services.s3.model.S3Exception;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.atomic.LongAdder;

class S3InputStream extends SeekableInputStream {

private static final int COPY_BUFFER_SIZE = 8192;
Expand All @@ -27,6 +35,12 @@ class S3InputStream extends SeekableInputStream {

private static final int SKIP_SIZE = 1024 * 1024;

private static final List<Class<? extends Throwable>> RETRYABLE_EXCEPTIONS = List.of(
EOFException.class,
SocketException.class,
SocketTimeoutException.class
);

private final S3Client s3Client;

private final S3ObjectReference s3ObjectReference;
Expand All @@ -52,11 +66,17 @@ class S3InputStream extends SeekableInputStream {

private boolean closed = false;

private RetryPolicy<byte[]> retryPolicyReturningByteArray;

private RetryPolicy<Integer> retryPolicyReturningInteger;

public S3InputStream(
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final HeadObjectResponse metadata,
final S3ObjectPluginMetrics s3ObjectPluginMetrics
final S3Client s3Client,
final S3ObjectReference s3ObjectReference,
final HeadObjectResponse metadata,
final S3ObjectPluginMetrics s3ObjectPluginMetrics,
final Duration retryDelay,
final int retries
) {
this.s3Client = s3Client;
this.s3ObjectReference = s3ObjectReference;
Expand All @@ -65,10 +85,23 @@ public S3InputStream(
this.bytesCounter = new LongAdder();

this.getObjectRequestBuilder = GetObjectRequest.builder()
.bucket(this.s3ObjectReference.getBucketName())
.key(this.s3ObjectReference.getKey());
.bucket(this.s3ObjectReference.getBucketName())
.key(this.s3ObjectReference.getKey());

this.retryPolicyReturningByteArray = RetryPolicy.<byte[]>builder()
.handle(RETRYABLE_EXCEPTIONS)
.withDelay(retryDelay)
.withMaxRetries(retries)
.build();

this.retryPolicyReturningInteger = RetryPolicy.<Integer>builder()
.handle(RETRYABLE_EXCEPTIONS)
.withDelay(retryDelay)
.withMaxRetries(retries)
.build();
}


// Implement all InputStream methods first:

/**
Expand Down Expand Up @@ -129,7 +162,7 @@ public int read() throws IOException {
Preconditions.checkState(!closed, "Cannot read: already closed");
positionStream();

final int byteRead = stream.read();
final int byteRead = executeWithRetriesAndReturnInt(() -> stream.read());

if (byteRead != -1) {
pos += 1;
Expand Down Expand Up @@ -165,7 +198,7 @@ public int read(byte[] b, int off, int len) throws IOException {
Preconditions.checkState(!closed, "Cannot read: already closed");
positionStream();

final int bytesRead = stream.read(b, off, len);
final int bytesRead = executeWithRetriesAndReturnInt(() -> stream.read(b, off, len));

if (bytesRead > 0) {
pos += bytesRead;
Expand All @@ -186,7 +219,7 @@ public byte[] readAllBytes() throws IOException {
Preconditions.checkState(!closed, "Cannot read: already closed");
positionStream();

final byte[] bytesRead = stream.readAllBytes();
final byte[] bytesRead = executeWithRetriesAndReturnByteArray(() -> stream.readAllBytes());

pos += bytesRead.length;
next += bytesRead.length;
Expand All @@ -208,7 +241,7 @@ public int readNBytes(byte[] b, int off, int len) throws IOException {
Preconditions.checkState(!closed, "Cannot read: already closed");
positionStream();

final int bytesRead = stream.readNBytes(b, off, len);
final int bytesRead = executeWithRetriesAndReturnInt(() -> stream.readNBytes(b, off, len));

if (bytesRead > 0) {
pos += bytesRead;
Expand All @@ -229,7 +262,7 @@ public byte[] readNBytes(int len) throws IOException {
Preconditions.checkState(!closed, "Cannot read: already closed");
positionStream();

final byte[] bytesRead = stream.readNBytes(len);
final byte[] bytesRead = executeWithRetriesAndReturnByteArray(() -> stream.readNBytes(len));

pos += bytesRead.length;
next += bytesRead.length;
Expand Down Expand Up @@ -332,7 +365,7 @@ public void readFully(byte[] bytes, int start, int len) throws IOException {
Preconditions.checkState(!closed, "Cannot read: already closed");
positionStream();

int bytesRead = readFully(stream, bytes, start, len);
final int bytesRead = executeWithRetriesAndReturnInt(() -> readFully(stream, bytes, start, len));

if (bytesRead > 0) {
this.pos += bytesRead;
Expand Down Expand Up @@ -360,9 +393,9 @@ public int read(ByteBuffer buf) throws IOException {

int bytesRead = 0;
if (buf.hasArray()) {
bytesRead = readHeapBuffer(stream, buf);
bytesRead = executeWithRetriesAndReturnInt(() -> readHeapBuffer(stream, buf));
} else {
bytesRead = readDirectBuffer(stream, buf, temp);
bytesRead = executeWithRetriesAndReturnInt(() -> readDirectBuffer(stream, buf, temp));
}

if (bytesRead > 0) {
Expand Down Expand Up @@ -393,9 +426,9 @@ public void readFully(ByteBuffer buf) throws IOException {

int bytesRead = 0;
if (buf.hasArray()) {
bytesRead = readFullyHeapBuffer(stream, buf);
bytesRead = executeWithRetriesAndReturnInt(() -> readFullyHeapBuffer(stream, buf));
} else {
bytesRead = readFullyDirectBuffer(stream, buf, temp);
bytesRead = executeWithRetriesAndReturnInt(() -> readFullyDirectBuffer(stream, buf, temp));
}

if (bytesRead > 0) {
Expand Down Expand Up @@ -612,9 +645,7 @@ static int readFullyDirectBuffer(InputStream f, ByteBuffer buf, byte[] temp) thr
while (nextReadLength > 0 && (bytesRead = f.read(temp, 0, nextReadLength)) >= 0) {
buf.put(temp, 0, bytesRead);
nextReadLength = Math.min(buf.remaining(), temp.length);
if (bytesRead >= 0) {
totalBytesRead += bytesRead;
}
totalBytesRead += bytesRead;
}

if (bytesRead < 0 && buf.remaining() > 0) {
Expand All @@ -632,4 +663,32 @@ private void recordS3Exception(final S3Exception ex) {
s3ObjectPluginMetrics.getS3ObjectsFailedAccessDeniedCounter().increment();
}
}

private int executeWithRetriesAndReturnInt(CheckedSupplier<Integer> supplier) throws IOException {
return executeWithRetries(retryPolicyReturningInteger, supplier);
}

private byte[] executeWithRetriesAndReturnByteArray(CheckedSupplier<byte[]> supplier) throws IOException {
return executeWithRetries(retryPolicyReturningByteArray, supplier);
}


private <T> T executeWithRetries(RetryPolicy<T> retryPolicy, CheckedSupplier<T> supplier) throws IOException {
try {
return Failsafe.with(retryPolicy).get(() -> {
try {
return supplier.get();
} catch (SocketException e) {
LOG.warn("Resetting stream due to underlying socket exception", e);
openStream();
throw e;
}
});
} catch (FailsafeException e) {
LOG.error("Failed to read with Retries", e);
throw new IOException(e.getCause());
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,28 @@
import software.amazon.awssdk.services.s3.model.S3Exception;

import java.io.ByteArrayInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.time.Duration;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class S3InputStreamTest {

private static final Duration RETRY_DELAY = Duration.ofMillis(10);

private static final int RETRIES = 3;

@Mock(lenient = true)
private S3Client s3Client;
@Mock(lenient = true)
Expand Down Expand Up @@ -57,7 +64,8 @@ void setUp() {
when(s3ObjectPluginMetrics.getS3ObjectsFailedNotFoundCounter()).thenReturn(s3ObjectsFailedNotFoundCounter);
when(s3ObjectPluginMetrics.getS3ObjectsFailedAccessDeniedCounter()).thenReturn(s3ObjectsFailedAccessDeniedCounter);

s3InputStream = new S3InputStream(s3Client, s3ObjectReference, metadata, s3ObjectPluginMetrics);
s3InputStream = new S3InputStream(
s3Client, s3ObjectReference, metadata, s3ObjectPluginMetrics, RETRY_DELAY, RETRIES);
}

@Test
Expand Down Expand Up @@ -117,6 +125,47 @@ void testReadEndOfFile() throws IOException {
verify(s3ObjectSizeProcessedSummary).record(0.0);
}

@Test
void testReadSucceedsAfterRetries() throws IOException {
InputStream mockInputStream = mock(InputStream.class);
when(s3Client.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class)))
.thenReturn(mockInputStream);

when(mockInputStream.read())
.thenThrow(SocketException.class)
.thenReturn(1);

int firstByte = s3InputStream.read();
assertEquals(1, firstByte);

s3InputStream.close();

verify(s3ObjectSizeProcessedSummary).record(1.0);

verify(mockInputStream, times(2)).read();
verify(mockInputStream, times(2)).close();
verify(s3Client, times(2))
.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class));
}

@Test
void testReadFailsAfterRetries() throws IOException {
InputStream mockInputStream = mock(InputStream.class);
when(s3Client.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class)))
.thenReturn(mockInputStream);

when(mockInputStream.read()).thenThrow(SocketException.class);

assertThrows(IOException.class, () -> s3InputStream.read());

s3InputStream.close();

verify(mockInputStream, times(RETRIES + 1)).read();
verify(mockInputStream, times(RETRIES + 2)).close();
verify(s3Client, times(RETRIES + 2))
.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class));
}

@Test
void testReadByteArray() throws IOException {
InputStream inputStream = new ByteArrayInputStream("Test data".getBytes());
Expand Down Expand Up @@ -148,6 +197,47 @@ void testReadAllBytes() throws IOException {
verify(s3ObjectSizeProcessedSummary).record(9.0);
}

@Test
void testReadAllBytesSucceedsAfterRetries() throws IOException {
InputStream mockInputStream = mock(InputStream.class);
when(s3Client.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class)))
.thenReturn(mockInputStream);

when(mockInputStream.readAllBytes())
.thenThrow(SocketException.class)
.thenReturn("Test data".getBytes());

final byte[] buffer = s3InputStream.readAllBytes();

assertArrayEquals("Test data".getBytes(), buffer);

s3InputStream.close();

verify(s3ObjectSizeProcessedSummary).record(9.0);
verify(mockInputStream, times(2)).readAllBytes();
verify(mockInputStream, times(2)).close();
verify(s3Client, times(2))
.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class));
}

@Test
void testReadAllBytesFailsAfterRetries() throws IOException {
InputStream mockInputStream = mock(InputStream.class);
when(s3Client.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class)))
.thenReturn(mockInputStream);

when(mockInputStream.readAllBytes()).thenThrow(SocketException.class);

assertThrows(IOException.class, () -> s3InputStream.readAllBytes());

s3InputStream.close();

verify(mockInputStream, times(RETRIES + 1)).readAllBytes();
verify(mockInputStream, times(RETRIES + 2)).close();
verify(s3Client, times(RETRIES + 2))
.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class));
}

@Test
void testReadNBytes_intoArray() throws Exception {
InputStream inputStream = new ByteArrayInputStream("Test data".getBytes());
Expand Down Expand Up @@ -277,7 +367,7 @@ void testReadFullyByteBuffer_endOfFile() throws IOException {
s3InputStream.seek(0); // Force opening the stream

ByteBuffer buffer = ByteBuffer.allocate(4);
assertThrows(EOFException.class, () -> s3InputStream.readFully(buffer));
assertThrows(IOException.class, () -> s3InputStream.readFully(buffer));

s3InputStream.close();
verify(s3ObjectSizeProcessedSummary).record(0.0);
Expand Down

0 comments on commit 74878a1

Please sign in to comment.