Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Populate indexName for BatchQuery #2999

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ public DispatchQueryResponse submit(
.resultIndex(dataSourceMetadata.getResultIndex())
.datasourceName(dataSourceMetadata.getName())
.jobType(JobType.BATCH)
.indexName(getIndexName(context))
.build();
}

private static String getIndexName(DispatchQueryContext context) {
return context.getIndexQueryDetails() != null
? context.getIndexQueryDetails().openSearchIndexName()
: null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,35 @@ public IndexQueryDetails build() {
}

public String openSearchIndexName() {
if (getIndexType() == null) {
return null;
}
FullyQualifiedTableName fullyQualifiedTableName = getFullyQualifiedTableName();
String indexName = StringUtils.EMPTY;
switch (getIndexType()) {
case COVERING:
indexName =
"flint_"
+ fullyQualifiedTableName.toFlintName()
+ "_"
+ strip(getIndexName(), STRIP_CHARS)
+ "_"
+ getIndexType().getSuffix();
if (getIndexName() != null) { // getIndexName will be null for SHOW INDEX query
indexName =
"flint_"
+ fullyQualifiedTableName.toFlintName()
+ "_"
+ strip(getIndexName(), STRIP_CHARS)
+ "_"
+ getIndexType().getSuffix();
} else {
return null;
}
break;
case SKIPPING:
indexName =
"flint_" + fullyQualifiedTableName.toFlintName() + "_" + getIndexType().getSuffix();
break;
case MATERIALIZED_VIEW:
indexName = "flint_" + new FullyQualifiedTableName(mvName).toFlintName();
if (mvName != null) { // mvName is not available for SHOW MATERIALIZED VIEW query
indexName = "flint_" + new FullyQualifiedTableName(mvName).toFlintName();
} else {
return null;
}
break;
}
return percentEncode(indexName).toLowerCase();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.spark.dispatcher;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
Expand Down Expand Up @@ -191,8 +192,8 @@ void testDispatchSelectQueryCreateNewSession() {

verifyNoInteractions(emrServerlessClient);
verify(sessionManager, never()).getSession(any(), any());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
}

@Test
Expand All @@ -218,8 +219,8 @@ void testDispatchSelectQueryReuseSession() {

verifyNoInteractions(emrServerlessClient);
verify(sessionManager, never()).createSession(any(), any());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
}

@Test
Expand Down Expand Up @@ -275,8 +276,8 @@ void testDispatchCreateAutoRefreshIndexQuery() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -320,8 +321,8 @@ void testDispatchWithPPLQuery() {
asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

Expand All @@ -346,7 +347,7 @@ void testDispatchWithSparkUDFQuery() {
sparkQueryDispatcher.dispatch(
getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(),
asyncQueryRequestContext));
Assertions.assertEquals(
assertEquals(
"Query is not allowed: Creating user-defined functions is not allowed",
illegalArgumentException.getMessage());
verifyNoInteractions(emrServerlessClient);
Expand Down Expand Up @@ -398,8 +399,8 @@ void testDispatchIndexQueryWithoutADatasourceName() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -436,8 +437,46 @@ void testDispatchMaterializedViewQuery() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

@Test
void testManualRefreshMaterializedViewQuery() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
String query =
"CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = false)";
String sparkSubmitParameters =
constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
StartJobRequest expected =
new StartJobRequest(
"TEST_CLUSTER:batch",
null,
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,
sparkSubmitParameters,
tags,
false,
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals("flint_mv_1", dispatchQueryResponse.getIndexName());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -477,8 +516,8 @@ void testRefreshIndexQuery() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(JobType.REFRESH, dispatchQueryResponse.getJobType());
verifyNoInteractions(flintIndexMetadataService);
}
Expand Down Expand Up @@ -522,8 +561,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

Expand All @@ -533,7 +572,6 @@ void testDispatchAlterToManualRefreshIndexQuery() {
sparkQueryDispatcher =
new SparkQueryDispatcher(
dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider);

String query =
"ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH"
+ " (auto_refresh = false)";
Expand All @@ -550,6 +588,7 @@ void testDispatchAlterToManualRefreshIndexQuery() {
flintIndexOpFactory));

sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(queryHandlerFactory, times(1)).getIndexDMLHandler();
}

Expand All @@ -559,7 +598,6 @@ void testDispatchDropIndexQuery() {
sparkQueryDispatcher =
new SparkQueryDispatcher(
dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider);

String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
Expand All @@ -573,7 +611,9 @@ void testDispatchDropIndexQuery() {
indexDMLResultStorageService,
flintIndexOpFactory));

sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
DispatchQueryResponse response =
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(queryHandlerFactory, times(1)).getIndexDMLHandler();
}

Expand All @@ -597,7 +637,7 @@ void testDispatchWithUnSupportedDataSourceType() {
getBaseDispatchQueryRequestBuilder(query).datasource("my_prometheus").build(),
asyncQueryRequestContext));

Assertions.assertEquals(
assertEquals(
"UnSupported datasource type for async queries:: PROMETHEUS",
unsupportedOperationException.getMessage());
}
Expand All @@ -609,7 +649,7 @@ void testCancelJob() {
String queryId =
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);

Assertions.assertEquals(QUERY_ID, queryId);
assertEquals(QUERY_ID, queryId);
}

@Test
Expand All @@ -625,7 +665,7 @@ void testCancelQueryWithSession() {

verifyNoInteractions(emrServerlessClient);
verify(statement, times(1)).cancel();
Assertions.assertEquals(MOCK_STATEMENT_ID, queryId);
assertEquals(MOCK_STATEMENT_ID, queryId);
}

@Test
Expand All @@ -642,7 +682,7 @@ void testCancelQueryWithInvalidSession() {

verifyNoInteractions(emrServerlessClient);
verifyNoInteractions(session);
Assertions.assertEquals("no session found. invalid", exception.getMessage());
assertEquals("no session found. invalid", exception.getMessage());
}

@Test
Expand All @@ -659,8 +699,7 @@ void testCancelQueryWithInvalidStatementId() {

verifyNoInteractions(emrServerlessClient);
verifyNoInteractions(statement);
Assertions.assertEquals(
"no statement found. " + new StatementId("invalid"), exception.getMessage());
assertEquals("no statement found. " + new StatementId("invalid"), exception.getMessage());
}

@Test
Expand Down Expand Up @@ -705,7 +744,7 @@ void testGetQueryResponse() {
JSONObject result =
sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata(), asyncQueryRequestContext);

Assertions.assertEquals("PENDING", result.get("status"));
assertEquals("PENDING", result.get("status"));
}

@Test
Expand All @@ -724,7 +763,7 @@ void testGetQueryResponseWithSession() {
asyncQueryRequestContext);

verifyNoInteractions(emrServerlessClient);
Assertions.assertEquals("waiting", result.get("status"));
assertEquals("waiting", result.get("status"));
}

@Test
Expand All @@ -743,7 +782,7 @@ void testGetQueryResponseWithInvalidSession() {
asyncQueryRequestContext));

verifyNoInteractions(emrServerlessClient);
Assertions.assertEquals("no session found. " + MOCK_SESSION_ID, exception.getMessage());
assertEquals("no session found. " + MOCK_SESSION_ID, exception.getMessage());
}

@Test
Expand All @@ -763,7 +802,7 @@ void testGetQueryResponseWithStatementNotExist() {
asyncQueryRequestContext));

verifyNoInteractions(emrServerlessClient);
Assertions.assertEquals(
assertEquals(
"no statement found. " + new StatementId(MOCK_STATEMENT_ID), exception.getMessage());
}

Expand All @@ -780,7 +819,7 @@ void testGetQueryResponseWithSuccess() {
sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata(), asyncQueryRequestContext);

verify(jobExecutionResponseReader, times(1)).getResultWithJobId(EMR_JOB_ID, null);
Assertions.assertEquals(
assertEquals(
new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet());
JSONObject dataJson = new JSONObject();
dataJson.put(ERROR_FIELD, "");
Expand All @@ -791,7 +830,7 @@ void testGetQueryResponseWithSuccess() {
// the same order.
// We need similar.
Assertions.assertTrue(dataJson.similar(result.get(DATA_FIELD)));
Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD));
assertEquals("SUCCESS", result.get(STATUS_FIELD));
verifyNoInteractions(emrServerlessClient);
}

Expand Down
Loading
Loading