From 13da52495f10ea2d1ed9ac5ac1a91d838eeef09f Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Mon, 13 Nov 2023 10:34:18 -0800 Subject: [PATCH] Revert "Create new session if client provided session is invalid (#2368) (#2371)" This reverts commit 5ab7858cb9e7ec69555545a8e7a5675e2f73e9e4. --- .../dispatcher/SparkQueryDispatcher.java | 5 +++-- .../execution/statement/StatementModel.java | 2 +- ...AsyncQueryExecutorServiceImplSpecTest.java | 22 +++++++++---------- .../dispatcher/SparkQueryDispatcherTest.java | 20 +++++++++++++++++ 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 5e80259e09..8feeddcafc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -219,9 +219,10 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ // get session from request SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); Optional createdSession = sessionManager.getSession(sessionId); - if (createdSession.isPresent()) { - session = createdSession.get(); + if (createdSession.isEmpty()) { + throw new IllegalArgumentException("no session found. " + sessionId); } + session = createdSession.get(); } if (session == null || !session.isReady()) { // create session if not exist or session dead/fail diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java index adc147c905..2a1043bf73 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -36,7 +36,7 @@ public class StatementModel extends StateModel { public static final String QUERY_ID = "queryId"; public static final String SUBMIT_TIME = "submitTime"; public static final String ERROR = "error"; - public static final String UNKNOWN = ""; + public static final String UNKNOWN = "unknown"; public static final String STATEMENT_DOC_TYPE = "statement"; private final String version; diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index cf638effc6..6bc40c009b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -45,7 +45,6 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; -import org.opensearch.core.common.Strings; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.plugins.Plugin; @@ -228,7 +227,6 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { // 2. fetch async query result. AsyncQueryExecutionResponse asyncQueryResults = asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); - assertTrue(Strings.isEmpty(asyncQueryResults.getError())); assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus()); // 3. cancel async query. @@ -462,7 +460,7 @@ public void recreateSessionIfNotReady() { } @Test - public void submitQueryInInvalidSessionWillCreateNewSession() { + public void submitQueryInInvalidSessionThrowException() { LocalEMRSClient emrsClient = new LocalEMRSClient(); AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrsClient); @@ -470,14 +468,16 @@ public void submitQueryInInvalidSessionWillCreateNewSession() { // enable session enableSession(true); - // 1. create async query with invalid sessionId - SessionId invalidSessionId = SessionId.newSessionId(DATASOURCE); - CreateAsyncQueryResponse asyncQuery = - asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest( - "select 1", DATASOURCE, LangType.SQL, invalidSessionId.getSessionId())); - assertNotNull(asyncQuery.getSessionId()); - assertNotEquals(invalidSessionId.getSessionId(), asyncQuery.getSessionId()); + // 1. create async query. + SessionId sessionId = SessionId.newSessionId(DATASOURCE); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, sessionId.getSessionId()))); + assertEquals("no session found. " + sessionId, exception.getMessage()); } private DataSourceServiceImpl createDataSourceService() { diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 95b6033d12..743274d46c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -327,6 +327,26 @@ void testDispatchSelectQueryReuseSession() { Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); } + @Test + void testDispatchSelectQueryInvalidSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, "invalid"); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.empty()).when(sessionManager).getSession(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).createSession(any()); + Assertions.assertEquals( + "no session found. " + new SessionId("invalid"), exception.getMessage()); + } + @Test void testDispatchSelectQueryFailedCreateSession() { String query = "select * from my_glue.default.http_logs";