Commit 1329e2f

Integration

Signed-off-by: Tomoyuki Morita <[email protected]>
ykmr1224 committed Sep 19, 2024
1 parent 712c4e5 commit 1329e2f
Showing 16 changed files with 257 additions and 220 deletions.
2 changes: 1 addition & 1 deletion async-query-core/build.gradle
@@ -130,7 +130,7 @@ jacocoTestCoverageVerification {
}
limit {
counter = 'BRANCH'
minimum = 1.0
minimum = 0.9
}
}
}
async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java

@@ -5,8 +5,6 @@

package org.opensearch.sql.spark.utils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
@@ -20,8 +18,6 @@
import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream;
import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor;
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer;
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
@@ -84,25 +80,6 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
}
}

public static List<String> validateSparkSqlQuery(DataSource datasource, String sqlQuery) {
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
try {
SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource);
StatementContext statement = sqlBaseParser.statement();
sqlParserBaseVisitor.visit(statement);
return sqlParserBaseVisitor.getValidationErrors();
} catch (SyntaxCheckException e) {
logger.error(
String.format(
"Failed to parse sql statement context while validating sql query %s", sqlQuery),
e);
return Collections.emptyList();
}
}

public static SqlBaseParser getBaseParser(String sqlQuery) {
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
@@ -111,54 +88,6 @@ public static SqlBaseParser getBaseParser(String sqlQuery) {
return sqlBaseParser;
}

private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) {
if (datasource != null
&& datasource.getConnectorType() != null
&& datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) {
return new SparkSqlSecurityLakeValidatorVisitor();
} else {
return new SparkSqlValidatorVisitor();
}
}

/**
* A base class extending SqlBaseParserBaseVisitor for validating Spark SQL queries. The class
* supports accumulating validation errors while visiting a SQL statement.
*/
@Getter
private static class SqlBaseValidatorVisitor<T> extends SqlBaseParserBaseVisitor<T> {
private final List<String> validationErrors = new ArrayList<>();
}

/** A generic validator impl for Spark Sql Queries */
private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor<Void> {
@Override
public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
getValidationErrors().add("Creating user-defined functions is not allowed");
return super.visitCreateFunction(ctx);
}
}

/** A validator impl specific to Security Lake for Spark Sql Queries */
private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor<Void> {

public SparkSqlSecurityLakeValidatorVisitor() {
// Only SELECT statements are allowed, so we add the validation error to all
// statement types by default and remove it only for SELECT statements.
getValidationErrors()
.add(
"Unsupported sql statement for security lake data source. Only select queries are"
+ " allowed");
}

@Override
public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) {
getValidationErrors().clear();
return super.visitStatementDefault(ctx);
}
}

public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {

@Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();
CloudWatchLogsGrammarElementValidator.java

@@ -46,6 +46,7 @@ public class CloudWatchLogsGrammarElementValidator extends DenyListGrammarElementValidator {
MANAGE_RESOURCE,
ANALYZE_TABLE,
CACHE_TABLE,
CLEAR_CACHE,
DESCRIBE_NAMESPACE,
DESCRIBE_FUNCTION,
DESCRIBE_QUERY,
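For orientation, a deny-list validator of this shape rejects exactly the elements in its set. Below is a minimal sketch, assuming DenyListGrammarElementValidator simply wraps the set it is constructed with; the base class itself is not part of this diff, so treat this as illustrative only:

import java.util.Set;

// Hypothetical shape of the base class, for illustration; the actual
// implementation is not shown in this commit.
public class DenyListGrammarElementValidator implements GrammarElementValidator {
  private final Set<GrammarElement> denyList;

  public DenyListGrammarElementValidator(Set<GrammarElement> denyList) {
    this.denyList = denyList;
  }

  @Override
  public boolean isValid(GrammarElement element) {
    // An element is valid unless it is denied; adding CLEAR_CACHE to the set
    // above is what now blocks CLEAR CACHE statements for CloudWatch Logs.
    return !denyList.contains(element);
  }
}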
DefaultGrammarElementValidator.java
@@ -0,0 +1,13 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.validator;

public class DefaultGrammarElementValidator implements GrammarElementValidator {
@Override
public boolean isValid(GrammarElement element) {
return true;
}
}

This file was deleted.

GrammarElementValidatorProvider.java
@@ -0,0 +1,21 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.validator;

import java.util.Map;
import lombok.AllArgsConstructor;
import org.opensearch.sql.datasource.model.DataSourceType;

@AllArgsConstructor
public class GrammarElementValidatorProvider {

private final Map<DataSourceType, GrammarElementValidator> validatorMap;
private final GrammarElementValidator defaultValidator;

public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) {
return validatorMap.getOrDefault(dataSourceType, defaultValidator);
}
}
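
A quick usage sketch (hypothetical call site; the wiring mirrors the test setup further down): the provider resolves a validator per data source type and falls back to the default for any unmapped type.

// Hypothetical call site, for illustration only.
GrammarElementValidatorProvider provider =
    new GrammarElementValidatorProvider(
        ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()),
        new DefaultGrammarElementValidator());

// S3GLUE has an explicit mapping, so its specific validator is returned.
GrammarElementValidator s3glue = provider.getValidatorForDatasource(DataSourceType.S3GLUE);

// An unmapped type such as PROMETHEUS falls back to the permissive default.
GrammarElementValidator fallback = provider.getValidatorForDatasource(DataSourceType.PROMETHEUS);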
SQLQueryValidationVisitor.java

@@ -200,7 +200,7 @@ public Void visitAlterViewSchemaBinding(AlterViewSchemaBindingContext ctx) {
public Void visitRenameTable(RenameTableContext ctx) {
if (ctx.VIEW() != null) {
validateAllowed(GrammarElement.ALTER_VIEW);
} else if (ctx.TABLE() != null) {
} else {
validateAllowed(GrammarElement.ALTER_NAMESPACE);
}

SQLQueryValidator.java

@@ -11,11 +11,11 @@

@AllArgsConstructor
public class SQLQueryValidator {
private final GrammarElementValidatorFactory grammarElementValidatorFactory;
private final GrammarElementValidatorProvider grammarElementValidatorProvider;

public void validate(String sqlQuery, DataSourceType datasourceType) {
GrammarElementValidator grammarElementValidator =
grammarElementValidatorFactory.getValidatorForDatasource(datasourceType);
grammarElementValidatorProvider.getValidatorForDatasource(datasourceType);
SQLQueryValidationVisitor visitor = new SQLQueryValidationVisitor(grammarElementValidator);
visitor.visit(SQLQueryUtils.getBaseParser(sqlQuery).singleStatement());
}
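End to end, the new validation path can be sketched as follows (assumed call pattern combining the pieces above; how validateAllowed reports a violation is not shown in this diff):

// Sketch only: wiring taken from the test setup below.
SQLQueryValidator validator =
    new SQLQueryValidator(
        new GrammarElementValidatorProvider(
            ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()),
            new DefaultGrammarElementValidator()));

// The query is parsed once, then SQLQueryValidationVisitor walks the tree and
// checks each grammar element against the validator chosen for the data source.
validator.validate("SELECT * FROM users WHERE age > 18", DataSourceType.S3GLUE);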
(test file; name not captured)

@@ -85,7 +85,9 @@
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
import org.opensearch.sql.spark.rest.model.LangType;
import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;
import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory;
import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator;
import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider;
import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator;
import org.opensearch.sql.spark.validator.SQLQueryValidator;

@@ -178,7 +180,10 @@ public void setUp() {
metricsService,
new SparkSubmitParametersBuilderProvider(collection));
SQLQueryValidator sqlQueryValidator =
new SQLQueryValidator(new GrammarElementValidatorFactory());
new SQLQueryValidator(
new GrammarElementValidatorProvider(
ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()),
new DefaultGrammarElementValidator()));
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
dataSourceService,
SparkQueryDispatcherTest.java

@@ -42,6 +42,7 @@
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
import com.amazonaws.services.emrserverless.model.JobRun;
import com.amazonaws.services.emrserverless.model.JobRunState;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand Down Expand Up @@ -88,7 +89,9 @@
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.model.LangType;
import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;
import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory;
import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator;
import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider;
import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator;
import org.opensearch.sql.spark.validator.SQLQueryValidator;

@ExtendWith(MockitoExtension.class)
@@ -115,7 +118,10 @@ public class SparkQueryDispatcherTest {
@Mock private AsyncQueryScheduler asyncQueryScheduler;

private final SQLQueryValidator sqlQueryValidator =
new SQLQueryValidator(new GrammarElementValidatorFactory());
new SQLQueryValidator(
new GrammarElementValidatorProvider(
ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()),
new DefaultGrammarElementValidator()));

private DataSourceSparkParameterComposer dataSourceSparkParameterComposer =
(datasourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context) -> {
SQLQueryUtilsTest.java

@@ -10,7 +10,6 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index;
import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv;
import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex;
@@ -22,7 +21,6 @@
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
@@ -444,106 +442,6 @@ void testRecoverIndex() {
assertEquals(IndexQueryActionType.RECOVER, indexDetails.getIndexQueryActionType());
}

@Test
void testValidateSparkSqlQuery_ValidQuery() {
List<String> errors =
validateSparkSqlQueryForDataSourceType(
"DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste'",
DataSourceType.PROMETHEUS);

assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
}

@Test
void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake() {
List<String> errors =
validateSparkSqlQueryForDataSourceType(
"SELECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE);

assertTrue(errors.isEmpty(), "Valid query should not produce any errors ");
}

@Test
void testValidateSparkSqlQuery_SelectQuery_DataSourceTypeNull() {
List<String> errors =
validateSparkSqlQueryForDataSourceType("SELECT * FROM users WHERE age > 18", null);

assertTrue(errors.isEmpty(), "Valid query should not produce any errors ");
}

@Test
void testValidateSparkSqlQuery_InvalidQuery_SyntaxCheckFailureSkippedWithoutValidationError() {
List<String> errors =
validateSparkSqlQueryForDataSourceType(
"SEECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE);

assertTrue(errors.isEmpty(), "Syntax check failure should be skipped without validation errors");
}

@Test
void testValidateSparkSqlQuery_nullDatasource() {
List<String> errors =
SQLQueryUtils.validateSparkSqlQuery(null, "SELECT * FROM users WHERE age > 18");
assertTrue(errors.isEmpty(), "Valid query should not produce any errors ");
}

private List<String> validateSparkSqlQueryForDataSourceType(
String query, DataSourceType dataSourceType) {
when(this.dataSource.getConnectorType()).thenReturn(dataSourceType);

return SQLQueryUtils.validateSparkSqlQuery(this.dataSource, query);
}

@Test
void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake_ValidationFails() {
List<String> errors =
validateSparkSqlQueryForDataSourceType(
"REFRESH INDEX cv1 ON mys3.default.http_logs", DataSourceType.SECURITY_LAKE);

assertFalse(
errors.isEmpty(),
"Invalid query as Security Lake datasource supports only flint queries and SELECT sql"
+ " queries. Given query was REFRESH sql query");
assertEquals(
errors.get(0),
"Unsupported sql statement for security lake data source. Only select queries are allowed");
}

@Test
void
testValidateSparkSqlQuery_NonSelectStatementContainingSelectClause_DataSourceSecurityLake_ValidationFails() {
String query =
"CREATE TABLE AccountSummaryOrWhatever AS "
+ "select taxid, address1, count(address1) from dbo.t "
+ "group by taxid, address1;";

List<String> errors =
validateSparkSqlQueryForDataSourceType(query, DataSourceType.SECURITY_LAKE);

assertFalse(
errors.isEmpty(),
"Invalid query as Security Lake datasource supports only flint queries and SELECT sql"
+ " queries. Given query was REFRESH sql query");
assertEquals(
errors.get(0),
"Unsupported sql statement for security lake data source. Only select queries are allowed");
}

@Test
void testValidateSparkSqlQuery_InvalidQuery() {
when(dataSource.getConnectorType()).thenReturn(DataSourceType.PROMETHEUS);
String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'";

List<String> errors = SQLQueryUtils.validateSparkSqlQuery(dataSource, invalidQuery);

assertFalse(errors.isEmpty(), "Invalid query should produce errors");
assertEquals(1, errors.size(), "Should have one error");
assertEquals(
"Creating user-defined functions is not allowed",
errors.get(0),
"Error message should match");
}

@Getter
protected static class IndexQuery {
private String query;
[Diffs for the remaining changed files were not loaded]
