From cfcc1bc602bf642843e59da38bb6e8079582e76a Mon Sep 17 00:00:00 2001 From: chngpe Date: Wed, 21 Sep 2022 20:02:03 -0400 Subject: [PATCH] Elasticsearch connector: enable pagination when querying data in the record handler. This will allow us to bypass max_result_window settings --- .../athena-elasticsearch.yaml | 5 ++ .../ElasticsearchRecordHandler.java | 68 ++++++++++++------- .../ElasticsearchRecordHandlerTest.java | 25 +++++-- 3 files changed, 67 insertions(+), 31 deletions(-) diff --git a/athena-elasticsearch/athena-elasticsearch.yaml b/athena-elasticsearch/athena-elasticsearch.yaml index e2c338eac7..fdb071c31f 100644 --- a/athena-elasticsearch/athena-elasticsearch.yaml +++ b/athena-elasticsearch/athena-elasticsearch.yaml @@ -59,6 +59,10 @@ Parameters: Description: "timeout period (in seconds) for Search queries used in the retrieval of documents from an index (default is 12 minutes)." Default: 720 Type: Number + QueryScrollTimeout: + Description: "timeout period (in seconds) for the scroll context used in the retrieval of documents (default is 60 seconds)." 
+ Default: 60 + Type: Number IsVPCAccess: AllowedValues: - true @@ -91,6 +95,7 @@ Resources: domain_mapping: !Ref DomainMapping query_timeout_cluster: !Ref QueryTimeoutCluster query_timeout_search: !Ref QueryTimeoutSearch + query_scroll_timeout: !Ref QueryScrollTimeout FunctionName: !Sub "${AthenaCatalogName}" Handler: "com.amazonaws.athena.connectors.elasticsearch.ElasticsearchCompositeHandler" CodeUri: "./target/athena-elasticsearch-2022.38.1.jar" diff --git a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java index fef9f02347..810cf0e255 100644 --- a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java +++ b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java @@ -34,9 +34,14 @@ import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; +import org.elasticsearch.action.search.ClearScrollRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchScrollRequest; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.search.Scroll; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.slf4j.Logger; @@ -68,7 +73,11 @@ public class ElasticsearchRecordHandler // Env. variable that holds the query timeout period for the Search queries. private static final String QUERY_TIMEOUT_SEARCH = "query_timeout_search"; + // Env. variable that holds the scroll timeout for the Search queries. 
+ private static final String SCROLL_TIMEOUT = "query_scroll_timeout"; + private final long queryTimeout; + private final long scrollTimeout; // Pagination batch size (100 documents). private static final int QUERY_BATCH_SIZE = 100; @@ -85,17 +94,19 @@ public ElasticsearchRecordHandler() this.clientFactory = new AwsRestHighLevelClientFactory(getEnv(AUTO_DISCOVER_ENDPOINT) .equalsIgnoreCase("true")); this.queryTimeout = Long.parseLong(getEnv(QUERY_TIMEOUT_SEARCH)); + this.scrollTimeout = Strings.isNullOrEmpty(getEnv(SCROLL_TIMEOUT)) ? 60L : Long.parseLong(getEnv(SCROLL_TIMEOUT)); } @VisibleForTesting protected ElasticsearchRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena amazonAthena, - AwsRestHighLevelClientFactory clientFactory, long queryTimeout) + AwsRestHighLevelClientFactory clientFactory, long queryTimeout, long scrollTimeout) { super(amazonS3, secretsManager, amazonAthena, SOURCE_TYPE); this.typeUtils = new ElasticsearchTypeUtils(); this.clientFactory = clientFactory; this.queryTimeout = queryTimeout; + this.scrollTimeout = scrollTimeout; } /** @@ -147,38 +158,45 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor GeneratedRowWriter rowWriter = createFieldExtractors(recordsRequest); // Create a new search-source injected with the projection, predicate, and the pagination batch size. - SearchSourceBuilder searchSource = new SearchSourceBuilder().size(QUERY_BATCH_SIZE) + SearchSourceBuilder searchSource = new SearchSourceBuilder() + .size(QUERY_BATCH_SIZE) .timeout(new TimeValue(queryTimeout, TimeUnit.SECONDS)) .fetchSource(ElasticsearchQueryUtils.getProjection(recordsRequest.getSchema())) .query(ElasticsearchQueryUtils.getQuery(recordsRequest.getConstraints().getSummary())); + + //init scroll + Scroll scroll = new Scroll(TimeValue.timeValueSeconds(this.scrollTimeout)); // Create a new search-request for the specified index. 
- SearchRequest searchRequest = new SearchRequest(index).preference(shard); - int hitsNum; - int currPosition = 0; - do { - // Process the search request injecting the search-source, and setting the from position - // used for pagination of results. - SearchResponse searchResponse = client - .getDocuments(searchRequest.source(searchSource.from(currPosition))); - - // Throw on query timeout. + SearchRequest searchRequest = new SearchRequest(index) + .preference(shard) + .scroll(scroll) + .source(searchSource.from(0)); + + //Read the returned scroll id, which points to the search context that’s being kept alive and will be needed in the following search scroll call + SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT); + + while (searchResponse.getHits() != null + && searchResponse.getHits().getHits() != null + && searchResponse.getHits().getHits().length > 0 + && queryStatusChecker.isQueryRunning()) { + Iterator finalIterator = searchResponse.getHits().iterator(); + while (finalIterator.hasNext() && queryStatusChecker.isQueryRunning()) { + ++numRows; + spiller.writeRows((Block block, int rowNum) -> + rowWriter.writeRow(block, rowNum, client.getDocument(finalIterator.next())) ? 1 : 0); + } + + //prep for next hits and keep track of scroll id. + SearchScrollRequest scrollRequest = new SearchScrollRequest(searchResponse.getScrollId()).scroll(scroll); + searchResponse = client.scroll(scrollRequest, RequestOptions.DEFAULT); if (searchResponse.isTimedOut()) { throw new RuntimeException("Request for index (" + index + ") " + shard + " timed out."); } + } - // Increment current position to next batch of results. - currPosition += QUERY_BATCH_SIZE; - // Process hits. 
- Iterator hitIterator = searchResponse.getHits().iterator(); - hitsNum = searchResponse.getHits().getHits().length; - - while (hitIterator.hasNext() && queryStatusChecker.isQueryRunning()) { - ++numRows; - spiller.writeRows((Block block, int rowNum) -> - rowWriter.writeRow(block, rowNum, client.getDocument(hitIterator.next())) ? 1 : 0); - } - // if hitsNum < QUERY_BATCH_SIZE, then this is the last batch of documents. - } while (hitsNum == QUERY_BATCH_SIZE && queryStatusChecker.isQueryRunning()); + ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); + clearScrollRequest.addScrollId(searchResponse.getScrollId()); + client.clearScroll(clearScrollRequest, RequestOptions.DEFAULT); } catch (IOException error) { throw new RuntimeException("Error sending search query: " + error.getMessage(), error); diff --git a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java index bfb844cfbd..7e0a773484 100644 --- a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java +++ b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java @@ -113,6 +113,9 @@ public class ElasticsearchRecordHandlerTest @Mock private SearchResponse mockResponse; + @Mock + private SearchResponse mockScrollResponse; + @Mock private AmazonS3 amazonS3; @@ -128,6 +131,9 @@ public class ElasticsearchRecordHandlerTest @Mock S3Object s3Object; + String[] expectedDocuments = {"[mytext : My favorite Sci-Fi movie is Interstellar.], [mykeyword : I love keywords.], [mylong : {11,12,13}], [myinteger : 666115], [myshort : 1972], [mybyte : 5], [mydouble : 47.5], [myscaled : 7], [myfloat : 5.6], [myhalf : 6.2], [mydatemilli : 2020-05-15T06:49:30], [mydatenano : {2020-05-15T06:50:01.457}], [myboolean : true], 
[mybinary : U29tZSBiaW5hcnkgYmxvYg==], [mynested : {[l1long : 357345987],[l1date : 2020-05-15T06:57:44.123],[l1nested : {[l2short : {1,2,3,4,5,6,7,8,9,10}],[l2binary : U29tZSBiaW5hcnkgYmxvYg==]}]}], [objlistouter : {}]" + ,"[mytext : My favorite TV comedy is Seinfeld.], [mykeyword : I hate key-values.], [mylong : {14,null,16}], [myinteger : 732765666], [myshort : 1971], [mybyte : 7], [mydouble : 27.6], [myscaled : 10], [myfloat : 7.8], [myhalf : 7.3], [mydatemilli : null], [mydatenano : {2020-05-15T06:49:30.001}], [myboolean : false], [mybinary : U29tZSBiaW5hcnkgYmxvYg==], [mynested : {[l1long : 7322775555],[l1date : 2020-05-15T01:57:44.777],[l1nested : {[l2short : {11,12,13,14,15,16,null,18,19,20}],[l2binary : U29tZSBiaW5hcnkgYmxvYg==]}]}], [objlistouter : {{[objlistinner : {{[title : somebook],[hi : hi]}}],[test2 : title]}}]"}; + @Before public void setUp() throws IOException @@ -305,10 +311,12 @@ public void setUp() .build(); when(clientFactory.getOrCreateClient(anyString())).thenReturn(mockClient); - when(mockClient.getDocuments(any())).thenReturn(mockResponse); when(mockClient.getDocument(any())).thenReturn(document1, document2); + when(mockClient.search(any(), any())).thenReturn(mockResponse); + when(mockScrollResponse.getHits()).thenReturn(null); + when(mockClient.scroll(any(), any())).thenReturn(mockScrollResponse); - handler = new ElasticsearchRecordHandler(amazonS3, awsSecretsManager, athena, clientFactory, 720); + handler = new ElasticsearchRecordHandler(amazonS3, awsSecretsManager, athena, clientFactory, 720, 60); logger.info("setUpBefore - exit"); } @@ -331,6 +339,7 @@ public void doReadRecordsNoSpill() SearchHits searchHits = new SearchHits(searchHit, new TotalHits(2, TotalHits.Relation.EQUAL_TO), 4); when(mockResponse.getHits()).thenReturn(searchHits); + when(mockResponse.getScrollId()).thenReturn("123"); Map constraintsMap = new HashMap<>(); constraintsMap.put("myshort", SortedRangeSet.copyOf(Types.MinorType.SMALLINT.getType(), @@ -357,7 +366,7 @@ 
public void doReadRecordsNoSpill() // Capture the SearchRequest object from the call to client.getDocuments(). // The former contains information such as the projection and predicate. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); - verify(mockClient).getDocuments(argumentCaptor.capture()); + verify(mockClient).search(argumentCaptor.capture(), any()); SearchRequest searchRequest = argumentCaptor.getValue(); // Get the actual projection and compare to the expected one. List actualProjection = ImmutableList.copyOf(searchRequest.source().fetchSource().includes()); @@ -374,6 +383,7 @@ public void doReadRecordsNoSpill() assertEquals(2, response.getRecords().getRowCount()); for (int i = 0; i < response.getRecords().getRowCount(); ++i) { logger.info("doReadRecordsNoSpill - Row: {}, {}", i, BlockUtils.rowToString(response.getRecords(), i)); + assertEquals(expectedDocuments[i], BlockUtils.rowToString(response.getRecords(), i)); } logger.info("doReadRecordsNoSpill: exit"); @@ -424,15 +434,18 @@ public void doReadRecordsSpill() try (RemoteReadRecordsResponse response = (RemoteReadRecordsResponse) rawResponse) { logger.info("doReadRecordsSpill: remoteBlocks[{}]", response.getRemoteBlocks().size()); - assertEquals(3, response.getNumberBlocks()); + assertEquals(1, response.getNumberBlocks()); int blockNum = 0; for (SpillLocation next : response.getRemoteBlocks()) { S3SpillLocation spillLocation = (S3SpillLocation) next; try (Block block = spillReader.read(spillLocation, response.getEncryptionKey(), response.getSchema())) { logger.info("doReadRecordsSpill: blockNum[{}] and recordCount[{}]", blockNum++, block.getRowCount()); - logger.info("doReadRecordsSpill: {}", BlockUtils.rowToString(block, 0)); - assertNotNull(BlockUtils.rowToString(block, 0)); + assertEquals(expectedDocuments.length, block.getRowCount()); + for (int rowCount = 0; rowCount < block.getRowCount(); rowCount++) { + logger.info("doReadRecordsSpill: {}", BlockUtils.rowToString(block, 
rowCount)); + assertEquals(expectedDocuments[rowCount], BlockUtils.rowToString(block, rowCount)); + } } } }