Skip to content

Commit

Permalink
Merge branch 'main' into dqs/add-query-status
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki MORITA <[email protected]>
  • Loading branch information
ykmr1224 authored Sep 6, 2024
2 parents 763dfbd + b76aa65 commit 4486a44
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ public DispatchQueryResponse submit(
.datasourceName(dataSourceMetadata.getName())
.jobType(JobType.BATCH)
.status(QueryState.WAITING)
.indexName(getIndexName(context))
.build();
}

private static String getIndexName(DispatchQueryContext context) {
return context.getIndexQueryDetails() != null
? context.getIndexQueryDetails().openSearchIndexName()
: null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,35 @@ public IndexQueryDetails build() {
}

public String openSearchIndexName() {
if (getIndexType() == null) {
return null;
}
FullyQualifiedTableName fullyQualifiedTableName = getFullyQualifiedTableName();
String indexName = StringUtils.EMPTY;
switch (getIndexType()) {
case COVERING:
indexName =
"flint_"
+ fullyQualifiedTableName.toFlintName()
+ "_"
+ strip(getIndexName(), STRIP_CHARS)
+ "_"
+ getIndexType().getSuffix();
if (getIndexName() != null) { // getIndexName will be null for SHOW INDEX query
indexName =
"flint_"
+ fullyQualifiedTableName.toFlintName()
+ "_"
+ strip(getIndexName(), STRIP_CHARS)
+ "_"
+ getIndexType().getSuffix();
} else {
return null;
}
break;
case SKIPPING:
indexName =
"flint_" + fullyQualifiedTableName.toFlintName() + "_" + getIndexType().getSuffix();
break;
case MATERIALIZED_VIEW:
indexName = "flint_" + new FullyQualifiedTableName(mvName).toFlintName();
if (mvName != null) { // mvName is not available for SHOW MATERIALIZED VIEW query
indexName = "flint_" + new FullyQualifiedTableName(mvName).toFlintName();
} else {
return null;
}
break;
}
return percentEncode(indexName).toLowerCase();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.spark.dispatcher;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
Expand Down Expand Up @@ -191,8 +192,8 @@ void testDispatchSelectQueryCreateNewSession() {

verifyNoInteractions(emrServerlessClient);
verify(sessionManager, never()).getSession(any(), any());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
}

@Test
Expand All @@ -218,8 +219,8 @@ void testDispatchSelectQueryReuseSession() {

verifyNoInteractions(emrServerlessClient);
verify(sessionManager, never()).createSession(any(), any());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
}

@Test
Expand Down Expand Up @@ -275,8 +276,8 @@ void testDispatchCreateAutoRefreshIndexQuery() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -320,8 +321,8 @@ void testDispatchWithPPLQuery() {
asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

Expand All @@ -346,7 +347,7 @@ void testDispatchWithSparkUDFQuery() {
sparkQueryDispatcher.dispatch(
getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(),
asyncQueryRequestContext));
Assertions.assertEquals(
assertEquals(
"Query is not allowed: Creating user-defined functions is not allowed",
illegalArgumentException.getMessage());
verifyNoInteractions(emrServerlessClient);
Expand Down Expand Up @@ -398,8 +399,8 @@ void testDispatchIndexQueryWithoutADatasourceName() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -436,8 +437,46 @@ void testDispatchMaterializedViewQuery() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

@Test
void testManualRefreshMaterializedViewQuery() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
String query =
"CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = false)";
String sparkSubmitParameters =
constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
StartJobRequest expected =
new StartJobRequest(
"TEST_CLUSTER:batch",
null,
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,
sparkSubmitParameters,
tags,
false,
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals("flint_mv_1", dispatchQueryResponse.getIndexName());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -477,8 +516,8 @@ void testRefreshIndexQuery() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(JobType.REFRESH, dispatchQueryResponse.getJobType());
verifyNoInteractions(flintIndexMetadataService);
}
Expand Down Expand Up @@ -522,8 +561,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() {
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
assertEquals(expected, startJobRequestArgumentCaptor.getValue());
assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

Expand All @@ -533,7 +572,6 @@ void testDispatchAlterToManualRefreshIndexQuery() {
sparkQueryDispatcher =
new SparkQueryDispatcher(
dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider);

String query =
"ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH"
+ " (auto_refresh = false)";
Expand All @@ -550,6 +588,7 @@ void testDispatchAlterToManualRefreshIndexQuery() {
flintIndexOpFactory));

sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(queryHandlerFactory, times(1)).getIndexDMLHandler();
}

Expand All @@ -559,7 +598,6 @@ void testDispatchDropIndexQuery() {
sparkQueryDispatcher =
new SparkQueryDispatcher(
dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider);

String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
Expand All @@ -573,7 +611,9 @@ void testDispatchDropIndexQuery() {
indexDMLResultStorageService,
flintIndexOpFactory));

sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
DispatchQueryResponse response =
sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);

verify(queryHandlerFactory, times(1)).getIndexDMLHandler();
}

Expand All @@ -597,7 +637,7 @@ void testDispatchWithUnSupportedDataSourceType() {
getBaseDispatchQueryRequestBuilder(query).datasource("my_prometheus").build(),
asyncQueryRequestContext));

Assertions.assertEquals(
assertEquals(
"UnSupported datasource type for async queries:: PROMETHEUS",
unsupportedOperationException.getMessage());
}
Expand All @@ -609,7 +649,7 @@ void testCancelJob() {
String queryId =
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);

Assertions.assertEquals(QUERY_ID, queryId);
assertEquals(QUERY_ID, queryId);
}

@Test
Expand All @@ -625,7 +665,7 @@ void testCancelQueryWithSession() {

verifyNoInteractions(emrServerlessClient);
verify(statement, times(1)).cancel();
Assertions.assertEquals(MOCK_STATEMENT_ID, queryId);
assertEquals(MOCK_STATEMENT_ID, queryId);
}

@Test
Expand All @@ -642,7 +682,7 @@ void testCancelQueryWithInvalidSession() {

verifyNoInteractions(emrServerlessClient);
verifyNoInteractions(session);
Assertions.assertEquals("no session found. invalid", exception.getMessage());
assertEquals("no session found. invalid", exception.getMessage());
}

@Test
Expand All @@ -659,8 +699,7 @@ void testCancelQueryWithInvalidStatementId() {

verifyNoInteractions(emrServerlessClient);
verifyNoInteractions(statement);
Assertions.assertEquals(
"no statement found. " + new StatementId("invalid"), exception.getMessage());
assertEquals("no statement found. " + new StatementId("invalid"), exception.getMessage());
}

@Test
Expand Down Expand Up @@ -705,7 +744,7 @@ void testGetQueryResponse() {
JSONObject result =
sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata(), asyncQueryRequestContext);

Assertions.assertEquals("PENDING", result.get("status"));
assertEquals("PENDING", result.get("status"));
}

@Test
Expand All @@ -724,7 +763,7 @@ void testGetQueryResponseWithSession() {
asyncQueryRequestContext);

verifyNoInteractions(emrServerlessClient);
Assertions.assertEquals("waiting", result.get("status"));
assertEquals("waiting", result.get("status"));
}

@Test
Expand All @@ -743,7 +782,7 @@ void testGetQueryResponseWithInvalidSession() {
asyncQueryRequestContext));

verifyNoInteractions(emrServerlessClient);
Assertions.assertEquals("no session found. " + MOCK_SESSION_ID, exception.getMessage());
assertEquals("no session found. " + MOCK_SESSION_ID, exception.getMessage());
}

@Test
Expand All @@ -763,7 +802,7 @@ void testGetQueryResponseWithStatementNotExist() {
asyncQueryRequestContext));

verifyNoInteractions(emrServerlessClient);
Assertions.assertEquals(
assertEquals(
"no statement found. " + new StatementId(MOCK_STATEMENT_ID), exception.getMessage());
}

Expand All @@ -780,7 +819,7 @@ void testGetQueryResponseWithSuccess() {
sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata(), asyncQueryRequestContext);

verify(jobExecutionResponseReader, times(1)).getResultWithJobId(EMR_JOB_ID, null);
Assertions.assertEquals(
assertEquals(
new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet());
JSONObject dataJson = new JSONObject();
dataJson.put(ERROR_FIELD, "");
Expand All @@ -791,7 +830,7 @@ void testGetQueryResponseWithSuccess() {
// the same order.
// We need similar.
Assertions.assertTrue(dataJson.similar(result.get(DATA_FIELD)));
Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD));
assertEquals("SUCCESS", result.get(STATUS_FIELD));
verifyNoInteractions(emrServerlessClient);
}

Expand Down
Loading

0 comments on commit 4486a44

Please sign in to comment.