Elasticsearch connector: enable pagination when querying data from the record handler. This will allow us to bypass the max_result_window setting.
chngpe authored and nauy2697 committed Sep 26, 2022
1 parent 34a5139 commit cfcc1bc
Showing 3 changed files with 67 additions and 31 deletions.
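For context on the change: the connector previously paginated search results with from/size, which Elasticsearch caps at the index's max_result_window setting (10,000 documents by default). This commit switches to the scroll API, which keeps a server-side search context alive between batches and is not subject to that cap. A minimal sketch of the pattern against the high-level REST client, separate from the connector code (the index name "my-index", the 100-document batch size, and the 60-second keep-alive are illustrative assumptions):

import java.io.IOException;

import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.search.Scroll;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;

public final class ScrollSketch
{
    public static void scrollAll(RestHighLevelClient client) throws IOException
    {
        // Keep the search context alive for 60 seconds between batches.
        Scroll scroll = new Scroll(TimeValue.timeValueSeconds(60));
        SearchRequest request = new SearchRequest("my-index")
                .scroll(scroll)
                .source(new SearchSourceBuilder().size(100));

        // The first response carries a scroll id identifying the live context.
        SearchResponse response = client.search(request, RequestOptions.DEFAULT);
        String scrollId = response.getScrollId();
        try {
            while (response.getHits().getHits().length > 0) {
                for (SearchHit hit : response.getHits()) {
                    // process hit here
                }
                // Each scroll call returns the next batch and renews the keep-alive.
                response = client.scroll(new SearchScrollRequest(scrollId).scroll(scroll), RequestOptions.DEFAULT);
                scrollId = response.getScrollId();
            }
        }
        finally {
            // Release the server-side context instead of waiting for it to expire.
            ClearScrollRequest clear = new ClearScrollRequest();
            clear.addScrollId(scrollId);
            client.clearScroll(clear, RequestOptions.DEFAULT);
        }
    }
}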
5 changes: 5 additions & 0 deletions athena-elasticsearch/athena-elasticsearch.yaml
@@ -59,6 +59,10 @@ Parameters:
     Description: "timeout period (in seconds) for Search queries used in the retrieval of documents from an index (default is 12 minutes)."
     Default: 720
     Type: Number
+  QueryScrollTimeout:
+    Description: "timeout period (in seconds) for the scroll context used in the retrieval of documents (default is 60 seconds)."
+    Default: 60
+    Type: Number
   IsVPCAccess:
     AllowedValues:
       - true
@@ -91,6 +95,7 @@ Resources:
           domain_mapping: !Ref DomainMapping
           query_timeout_cluster: !Ref QueryTimeoutCluster
           query_timeout_search: !Ref QueryTimeoutSearch
+          query_scroll_timeout: !Ref QueryScrollTimeout
       FunctionName: !Sub "${AthenaCatalogName}"
       Handler: "com.amazonaws.athena.connectors.elasticsearch.ElasticsearchCompositeHandler"
       CodeUri: "./target/athena-elasticsearch-2022.38.1.jar"
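A sizing note on the new parameter: the scroll keep-alive is renewed on every scroll call, so QueryScrollTimeout only needs to outlast the processing of one batch of documents, not the whole query. When the variable is unset, the handler falls back to 60 seconds (see the constructor change below); the same idiom, extracted into a sketch for clarity (EnvDefaults and getLong are hypothetical names, not part of this commit):

final class EnvDefaults
{
    private EnvDefaults() {}

    // Hypothetical helper mirroring the handler's fallback: use the configured
    // value when present, otherwise the documented default.
    static long getLong(String name, long defaultValue)
    {
        String value = System.getenv(name);
        return (value == null || value.isEmpty()) ? defaultValue : Long.parseLong(value);
    }
}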
athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java
@@ -34,9 +34,14 @@
 import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
 import org.apache.arrow.util.VisibleForTesting;
 import org.apache.arrow.vector.types.pojo.Field;
+import org.elasticsearch.action.search.ClearScrollRequest;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.SearchScrollRequest;
+import org.elasticsearch.client.RequestOptions;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.search.Scroll;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.slf4j.Logger;
@@ -68,7 +73,11 @@ public class ElasticsearchRecordHandler

     // Env. variable that holds the query timeout period for the Search queries.
     private static final String QUERY_TIMEOUT_SEARCH = "query_timeout_search";
+    // Env. variable that holds the scroll timeout for the Search queries.
+    private static final String SCROLL_TIMEOUT = "query_scroll_timeout";
+
     private final long queryTimeout;
+    private final long scrollTimeout;

     // Pagination batch size (100 documents).
     private static final int QUERY_BATCH_SIZE = 100;
@@ -85,17 +94,19 @@ public ElasticsearchRecordHandler()
         this.clientFactory = new AwsRestHighLevelClientFactory(getEnv(AUTO_DISCOVER_ENDPOINT)
                 .equalsIgnoreCase("true"));
         this.queryTimeout = Long.parseLong(getEnv(QUERY_TIMEOUT_SEARCH));
+        this.scrollTimeout = Strings.isNullOrEmpty(getEnv(SCROLL_TIMEOUT)) ? 60L : Long.parseLong(getEnv(SCROLL_TIMEOUT));
     }

     @VisibleForTesting
     protected ElasticsearchRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena amazonAthena,
-            AwsRestHighLevelClientFactory clientFactory, long queryTimeout)
+            AwsRestHighLevelClientFactory clientFactory, long queryTimeout, long scrollTimeout)
     {
         super(amazonS3, secretsManager, amazonAthena, SOURCE_TYPE);

         this.typeUtils = new ElasticsearchTypeUtils();
         this.clientFactory = clientFactory;
         this.queryTimeout = queryTimeout;
+        this.scrollTimeout = scrollTimeout;
     }

     /**
@@ -147,38 +158,45 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker)
             GeneratedRowWriter rowWriter = createFieldExtractors(recordsRequest);

             // Create a new search-source injected with the projection, predicate, and the pagination batch size.
-            SearchSourceBuilder searchSource = new SearchSourceBuilder().size(QUERY_BATCH_SIZE)
+            SearchSourceBuilder searchSource = new SearchSourceBuilder()
+                    .size(QUERY_BATCH_SIZE)
                     .timeout(new TimeValue(queryTimeout, TimeUnit.SECONDS))
                     .fetchSource(ElasticsearchQueryUtils.getProjection(recordsRequest.getSchema()))
                     .query(ElasticsearchQueryUtils.getQuery(recordsRequest.getConstraints().getSummary()));

+            // Initialize the scroll context.
+            Scroll scroll = new Scroll(TimeValue.timeValueSeconds(this.scrollTimeout));
             // Create a new search-request for the specified index.
-            SearchRequest searchRequest = new SearchRequest(index).preference(shard);
-            int hitsNum;
-            int currPosition = 0;
-            do {
-                // Process the search request injecting the search-source, and setting the from position
-                // used for pagination of results.
-                SearchResponse searchResponse = client
-                        .getDocuments(searchRequest.source(searchSource.from(currPosition)));
-
-                // Throw on query timeout.
+            SearchRequest searchRequest = new SearchRequest(index)
+                    .preference(shard)
+                    .scroll(scroll)
+                    .source(searchSource.from(0));
+
+            // Read the returned scroll id, which points to the search context that is being kept alive and will be needed in the subsequent search-scroll calls.
+            SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
+
+            while (searchResponse.getHits() != null
+                    && searchResponse.getHits().getHits() != null
+                    && searchResponse.getHits().getHits().length > 0
+                    && queryStatusChecker.isQueryRunning()) {
+                Iterator<SearchHit> finalIterator = searchResponse.getHits().iterator();
+                while (finalIterator.hasNext() && queryStatusChecker.isQueryRunning()) {
+                    ++numRows;
+                    spiller.writeRows((Block block, int rowNum) ->
+                            rowWriter.writeRow(block, rowNum, client.getDocument(finalIterator.next())) ? 1 : 0);
+                }
+
+                // Prep for the next batch of hits and keep track of the scroll id.
+                SearchScrollRequest scrollRequest = new SearchScrollRequest(searchResponse.getScrollId()).scroll(scroll);
+                searchResponse = client.scroll(scrollRequest, RequestOptions.DEFAULT);
                 if (searchResponse.isTimedOut()) {
                     throw new RuntimeException("Request for index (" + index + ") " + shard + " timed out.");
                 }
+            }

-                // Increment current position to next batch of results.
-                currPosition += QUERY_BATCH_SIZE;
-                // Process hits.
-                Iterator<SearchHit> hitIterator = searchResponse.getHits().iterator();
-                hitsNum = searchResponse.getHits().getHits().length;
-
-                while (hitIterator.hasNext() && queryStatusChecker.isQueryRunning()) {
-                    ++numRows;
-                    spiller.writeRows((Block block, int rowNum) ->
-                            rowWriter.writeRow(block, rowNum, client.getDocument(hitIterator.next())) ? 1 : 0);
-                }
-                // if hitsNum < QUERY_BATCH_SIZE, then this is the last batch of documents.
-            } while (hitsNum == QUERY_BATCH_SIZE && queryStatusChecker.isQueryRunning());
+            ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
+            clearScrollRequest.addScrollId(searchResponse.getScrollId());
+            client.clearScroll(clearScrollRequest, RequestOptions.DEFAULT);
         }
         catch (IOException error) {
             throw new RuntimeException("Error sending search query: " + error.getMessage(), error);
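One caveat in the new read path: clearScroll runs only when the loop completes normally. If client.scroll throws the IOException caught above, the search context lingers on the cluster until query_scroll_timeout expires. A hedged sketch of a cleanup helper that could be invoked from a finally block (ScrollCleanup and clearQuietly are hypothetical, not part of this commit):

import java.io.IOException;

import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;

final class ScrollCleanup
{
    private ScrollCleanup() {}

    // Releases a scroll context, swallowing cleanup failures so they cannot
    // mask the exception that aborted the read loop.
    static void clearQuietly(RestHighLevelClient client, String scrollId)
    {
        if (scrollId == null) {
            return;
        }
        try {
            ClearScrollRequest request = new ClearScrollRequest();
            request.addScrollId(scrollId);
            client.clearScroll(request, RequestOptions.DEFAULT);
        }
        catch (IOException ignored) {
            // Best effort: the context expires on its own after the scroll timeout.
        }
    }
}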
athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java
@@ -113,6 +113,9 @@ public class ElasticsearchRecordHandlerTest
     @Mock
     private SearchResponse mockResponse;

+    @Mock
+    private SearchResponse mockScrollResponse;
+
     @Mock
     private AmazonS3 amazonS3;
@@ -128,6 +131,9 @@ public class ElasticsearchRecordHandlerTest
     @Mock
     S3Object s3Object;

+    String[] expectedDocuments = {"[mytext : My favorite Sci-Fi movie is Interstellar.], [mykeyword : I love keywords.], [mylong : {11,12,13}], [myinteger : 666115], [myshort : 1972], [mybyte : 5], [mydouble : 47.5], [myscaled : 7], [myfloat : 5.6], [myhalf : 6.2], [mydatemilli : 2020-05-15T06:49:30], [mydatenano : {2020-05-15T06:50:01.457}], [myboolean : true], [mybinary : U29tZSBiaW5hcnkgYmxvYg==], [mynested : {[l1long : 357345987],[l1date : 2020-05-15T06:57:44.123],[l1nested : {[l2short : {1,2,3,4,5,6,7,8,9,10}],[l2binary : U29tZSBiaW5hcnkgYmxvYg==]}]}], [objlistouter : {}]"
+            , "[mytext : My favorite TV comedy is Seinfeld.], [mykeyword : I hate key-values.], [mylong : {14,null,16}], [myinteger : 732765666], [myshort : 1971], [mybyte : 7], [mydouble : 27.6], [myscaled : 10], [myfloat : 7.8], [myhalf : 7.3], [mydatemilli : null], [mydatenano : {2020-05-15T06:49:30.001}], [myboolean : false], [mybinary : U29tZSBiaW5hcnkgYmxvYg==], [mynested : {[l1long : 7322775555],[l1date : 2020-05-15T01:57:44.777],[l1nested : {[l2short : {11,12,13,14,15,16,null,18,19,20}],[l2binary : U29tZSBiaW5hcnkgYmxvYg==]}]}], [objlistouter : {{[objlistinner : {{[title : somebook],[hi : hi]}}],[test2 : title]}}]"};
+
     @Before
     public void setUp()
             throws IOException
@@ -305,10 +311,12 @@ public void setUp()
                 .build();

         when(clientFactory.getOrCreateClient(anyString())).thenReturn(mockClient);
-        when(mockClient.getDocuments(any())).thenReturn(mockResponse);
         when(mockClient.getDocument(any())).thenReturn(document1, document2);
+        when(mockClient.search(any(), any())).thenReturn(mockResponse);
+        when(mockScrollResponse.getHits()).thenReturn(null);
+        when(mockClient.scroll(any(), any())).thenReturn(mockScrollResponse);

-        handler = new ElasticsearchRecordHandler(amazonS3, awsSecretsManager, athena, clientFactory, 720);
+        handler = new ElasticsearchRecordHandler(amazonS3, awsSecretsManager, athena, clientFactory, 720, 60);

         logger.info("setUpBefore - exit");
     }
@@ -331,6 +339,7 @@ public void doReadRecordsNoSpill()
         SearchHits searchHits =
                 new SearchHits(searchHit, new TotalHits(2, TotalHits.Relation.EQUAL_TO), 4);
         when(mockResponse.getHits()).thenReturn(searchHits);
+        when(mockResponse.getScrollId()).thenReturn("123");

         Map<String, ValueSet> constraintsMap = new HashMap<>();
         constraintsMap.put("myshort", SortedRangeSet.copyOf(Types.MinorType.SMALLINT.getType(),
@@ -357,7 +366,7 @@ public void doReadRecordsNoSpill()
         // Capture the SearchRequest object from the call to client.search().
         // The request contains information such as the projection and predicate.
         ArgumentCaptor<SearchRequest> argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class);
-        verify(mockClient).getDocuments(argumentCaptor.capture());
+        verify(mockClient).search(argumentCaptor.capture(), any());
         SearchRequest searchRequest = argumentCaptor.getValue();
         // Get the actual projection and compare to the expected one.
         List<String> actualProjection = ImmutableList.copyOf(searchRequest.source().fetchSource().includes());
@@ -374,6 +383,7 @@ public void doReadRecordsNoSpill()
         assertEquals(2, response.getRecords().getRowCount());
         for (int i = 0; i < response.getRecords().getRowCount(); ++i) {
             logger.info("doReadRecordsNoSpill - Row: {}, {}", i, BlockUtils.rowToString(response.getRecords(), i));
+            assertEquals(expectedDocuments[i], BlockUtils.rowToString(response.getRecords(), i));
         }

         logger.info("doReadRecordsNoSpill: exit");
@@ -424,15 +434,18 @@ public void doReadRecordsSpill()
         try (RemoteReadRecordsResponse response = (RemoteReadRecordsResponse) rawResponse) {
             logger.info("doReadRecordsSpill: remoteBlocks[{}]", response.getRemoteBlocks().size());

-            assertEquals(3, response.getNumberBlocks());
+            assertEquals(1, response.getNumberBlocks());

             int blockNum = 0;
             for (SpillLocation next : response.getRemoteBlocks()) {
                 S3SpillLocation spillLocation = (S3SpillLocation) next;
                 try (Block block = spillReader.read(spillLocation, response.getEncryptionKey(), response.getSchema())) {
                     logger.info("doReadRecordsSpill: blockNum[{}] and recordCount[{}]", blockNum++, block.getRowCount());
-                    logger.info("doReadRecordsSpill: {}", BlockUtils.rowToString(block, 0));
-                    assertNotNull(BlockUtils.rowToString(block, 0));
+                    assertEquals(expectedDocuments.length, block.getRowCount());
+                    for (int rowCount = 0; rowCount < block.getRowCount(); rowCount++) {
+                        logger.info("doReadRecordsSpill: {}", BlockUtils.rowToString(block, rowCount));
+                        assertEquals(expectedDocuments[rowCount], BlockUtils.rowToString(block, rowCount));
+                    }
                 }
             }
         }
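A note on how the tests drive the loop: mockScrollResponse.getHits() is stubbed to return null, so the first client.scroll call ends the while loop after a single batch. To exercise more than one batch, Mockito's consecutive-return stubbing could hand back a populated response before the terminating one; a sketch under the same mocks as above:

// Serve one extra page of hits, then the null-hits response that ends the loop.
when(mockClient.scroll(any(), any())).thenReturn(mockResponse, mockScrollResponse);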
