From 571c83fdf99a5f3215d2d0143fb8887bc4e8e8ee Mon Sep 17 00:00:00 2001
From: AbdulRehman
Date: Thu, 17 Oct 2024 15:07:45 -0400
Subject: [PATCH 1/7] Feature/postgresql disable predicate pushdown (#2337)

Co-authored-by: AbdulRehman Faraj
Co-authored-by: ejeffrli <144148373+ejeffrli@users.noreply.github.com>
---
 .../jdbc/manager/JdbcSplitQueryBuilder.java   |  2 +-
 .../postgresql/PostGreSqlMetadataHandler.java | 12 +++
 .../PostGreSqlQueryStringBuilder.java         | 59 +-----------
 .../postgresql/PostGreSqlRecordHandler.java   |  3 +-
 .../PostGreSqlRecordHandlerTest.java          | 96 +------------------
 .../redshift/RedshiftRecordHandler.java       |  4 +-
 .../redshift/RedshiftRecordHandlerTest.java   |  2 +-
 .../saphana/SaphanaQueryStringBuilder.java    |  7 ++
 8 files changed, 30 insertions(+), 155 deletions(-)

diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcSplitQueryBuilder.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcSplitQueryBuilder.java
index 104f271017..a532d02aa2 100644
--- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcSplitQueryBuilder.java
+++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcSplitQueryBuilder.java
@@ -336,7 +336,7 @@ else if (singleValues.size() > 1) {
         return "(" + Joiner.on(" OR ").join(disjuncts) + ")";
     }
 
-    protected String toPredicate(String columnName, String operator, Object value, ArrowType type,
+    private String toPredicate(String columnName, String operator, Object value, ArrowType type,
             List<TypeAndValue> accumulator)
     {
         accumulator.add(new TypeAndValue(type, value));
diff --git a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandler.java b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandler.java
index 03d5d84109..a765aa4e0f 100644
--- a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandler.java
+++ b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandler.java
@@ -39,6 +39,7 @@
 import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
 import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.ComplexExpressionPushdownSubType;
 import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.FilterPushdownSubType;
+import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.HintsSubtype;
 import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
 import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
 import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
@@ -97,6 +98,9 @@ public class PostGreSqlMetadataHandler
     static final String LIST_PAGINATED_TABLES_QUERY = "SELECT a.\"TABLE_NAME\", a.\"TABLE_SCHEM\" FROM ((SELECT table_name as \"TABLE_NAME\", table_schema as \"TABLE_SCHEM\" FROM information_schema.tables WHERE table_schema = ?) UNION (SELECT matviewname as \"TABLE_NAME\", schemaname as \"TABLE_SCHEM\" from pg_catalog.pg_matviews mv where has_table_privilege(format('%I.%I', mv.schemaname, mv.matviewname), 'select') and schemaname = ?)) AS a ORDER BY a.\"TABLE_NAME\" LIMIT ? OFFSET ?";
 
+    //Session property flag that hints to the engine that the data source is using a non-default collation
+    protected static final String NON_DEFAULT_COLLATE = "non_default_collate";
+
     /**
      * Instantiates handler to be used by Lambda function directly.
      *
@@ -143,6 +147,14 @@ public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAlloca
         ));
 
         jdbcQueryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
+
+        //Provide a hint to the engine that postgresql is using non-default collate settings,
+        //which do not match Athena's engine collation; this disables predicate pushdown
+        boolean nonDefaultCollate = Boolean.valueOf(this.configOptions.getOrDefault(NON_DEFAULT_COLLATE, "false"));
+        if (nonDefaultCollate) {
+            capabilities.put(DataSourceOptimizations.DATA_SOURCE_HINTS.withSupportedSubTypes(HintsSubtype.NON_DEFAULT_COLLATE));
+        }
+
         return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
     }
 
diff --git a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlQueryStringBuilder.java b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlQueryStringBuilder.java
index 046d7049ca..65b6c772d8 100644
--- a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlQueryStringBuilder.java
+++ b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlQueryStringBuilder.java
@@ -1,7 +1,4 @@
-/*-
- * #%L
- * athena-postgresql
- * %%
+/*- #%L athena-postgresql %%
  * Copyright (C) 2019 Amazon Web Services
  * %%
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,10 +20,7 @@
 import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
 import com.amazonaws.athena.connectors.jdbc.manager.FederationExpressionParser;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
-import com.amazonaws.athena.connectors.jdbc.manager.TypeAndValue;
 import com.google.common.base.Strings;
-import org.apache.arrow.vector.types.Types;
-import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Schema;
 
 import java.sql.Connection;
@@ -37,8 +31,6 @@
 import java.util.Objects;
 import java.util.stream.Collectors;
 
-import static java.lang.String.format;
-
 /**
  * Extends {@link JdbcSplitQueryBuilder} and implements PostGreSql specific SQL clauses for split.
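A note for readers tracing the new flag: configOptions is populated from the connector's Lambda environment, so non_default_collate is a deployment-time switch rather than a per-query setting. Below is a minimal sketch (not part of the patch) of the gating logic the hunk above adds, with a plain Map standing in for the Lambda environment; the capability and hint types are the Federation SDK classes imported by this patch.

    import java.util.Map;

    // Sketch only: mirrors the flag handling added to doGetDataSourceCapabilities.
    public class CollateHintSketch
    {
        public static void main(String[] args)
        {
            // Stand-in for the Lambda environment variables that back configOptions.
            Map<String, String> configOptions = Map.of("non_default_collate", "true");

            // Same parsing as the patch: anything other than an explicit "true" leaves
            // the capability map untouched and predicate pushdown enabled.
            boolean nonDefaultCollate = Boolean.valueOf(configOptions.getOrDefault("non_default_collate", "false"));
            System.out.println("advertise NON_DEFAULT_COLLATE hint: " + nonDefaultCollate);
        }
    }

When the hint is advertised, the engine treats the source's string ordering as incompatible with its own and stops pushing the affected predicates down to PostgreSQL.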
* @@ -47,13 +39,9 @@ public class PostGreSqlQueryStringBuilder extends JdbcSplitQueryBuilder { - private final java.util.Map configOptions; - private final String postgresqlCollateExperimentalFlag = "postgresql_collate_experimental_flag"; - - public PostGreSqlQueryStringBuilder(final String quoteCharacters, final FederationExpressionParser federationExpressionParser, final java.util.Map configOptions) + public PostGreSqlQueryStringBuilder(final String quoteCharacters, final FederationExpressionParser federationExpressionParser) { super(quoteCharacters, federationExpressionParser); - this.configOptions = configOptions; } @Override @@ -106,10 +94,10 @@ protected String getFromClauseWithSplit(String catalog, String schema, String ta if (PostGreSqlMetadataHandler.ALL_PARTITIONS.equals(partitionSchemaName) || PostGreSqlMetadataHandler.ALL_PARTITIONS.equals(partitionName)) { // No partitions - return format(" FROM %s ", tableName); + return String.format(" FROM %s ", tableName); } - return format(" FROM %s.%s ", quote(partitionSchemaName), quote(partitionName)); + return String.format(" FROM %s.%s ", quote(partitionSchemaName), quote(partitionName)); } @Override @@ -122,43 +110,4 @@ protected List getPartitionWhereClauses(final Split split) return Collections.emptyList(); } - - protected String toPredicate(String columnName, String operator, Object value, ArrowType type, - List accumulator) - { - if (isPostgresqlCollateExperimentalFlagEnabled()) { - Types.MinorType minorType = Types.getMinorTypeForArrowType(type); - //Only check for varchar; as it's the only collate-able type - //Only a range that is applicable - if (minorType.equals(Types.MinorType.VARCHAR) && isOperatorARange(operator)) { - accumulator.add(new TypeAndValue(type, value)); - return format("%s %s ? 
COLLATE \"C\"", quote(columnName), operator); - } - } - // Default to parent's behavior - return super.toPredicate(columnName, operator, value, type, accumulator); - } - - /** - * Flags to check if experimental flag to allow different collate for postgresql - * @return true if a flag is set; default otherwise to false; - */ - private boolean isPostgresqlCollateExperimentalFlagEnabled() - { - String flag = configOptions.getOrDefault(postgresqlCollateExperimentalFlag, "false"); - return flag.equalsIgnoreCase("true"); - } - - private boolean isOperatorARange(String operator) - { - switch (operator) { - case ">": - case "<": - case ">=": - case "<=": - return true; - default: - return false; - } - } } diff --git a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java index 9f5fb19f2f..877c450a05 100644 --- a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java +++ b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java @@ -72,8 +72,7 @@ public PostGreSqlRecordHandler(java.util.Map configOptions) public PostGreSqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), - new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(POSTGRESQL_DRIVER_CLASS, POSTGRESQL_DEFAULT_PORT)), - new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER), configOptions), configOptions); + new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(POSTGRESQL_DRIVER_CLASS, POSTGRESQL_DEFAULT_PORT)), new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting diff --git a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java index e95fb985e8..c9827a8464 100644 --- a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java +++ b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java @@ -56,7 +56,6 @@ import java.sql.PreparedStatement; import java.sql.SQLException; import java.util.Collections; -import java.util.Map; import java.util.concurrent.TimeUnit; import static com.amazonaws.athena.connectors.postgresql.PostGreSqlConstants.POSTGRES_NAME; @@ -76,7 +75,7 @@ public class PostGreSqlRecordHandlerTest extends TestBase private AWSSecretsManager secretsManager; private AmazonAthena athena; private MockedStatic mockedPostGreSqlMetadataHandler; - private DatabaseConnectionConfig databaseConnectionConfig; + @Before public void setup() throws Exception @@ -87,8 +86,8 @@ public void setup() this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); 
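Context for the removal above: Athena's engine compares VARCHAR values byte-wise, while PostgreSQL compares them according to the column's collation (commonly a locale such as en_US.UTF-8), so a pushed-down range predicate could select different rows than Athena itself would. The deleted experimental path compensated only for VARCHAR range operators, because equality does not depend on sort order, by pinning the comparison to the byte-wise "C" collation. A sketch of the predicate shape it emitted, with hypothetical column and operator values:

    // Sketch only: the predicate the removed toPredicate override produced for a
    // VARCHAR range operator (>, <, >= or <=). COLLATE "C" forces PostgreSQL into
    // byte-wise comparison so the range agrees with Athena's binary ordering.
    public class CollatePredicateSketch
    {
        public static void main(String[] args)
        {
            String column = "\"testCol2\"";  // already quoted, as quote(columnName) returns
            String operator = ">=";
            System.out.println(String.format("%s %s ? COLLATE \"C\"", column, operator));
            // prints: "testCol2" >= ? COLLATE "C"
        }
    }

The patch drops this per-predicate rewrite in favor of the coarser NON_DEFAULT_COLLATE capability hint introduced earlier in the patch.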
Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - jdbcSplitQueryBuilder = new PostGreSqlQueryStringBuilder("\"", new PostgreSqlFederationExpressionParser("\""), Collections.emptyMap()); - databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", POSTGRES_NAME, + jdbcSplitQueryBuilder = new PostGreSqlQueryStringBuilder("\"", new PostgreSqlFederationExpressionParser("\"")); + final DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", POSTGRES_NAME, "postgres://jdbc:postgresql://hostname/user=A&password=B"); this.postGreSqlRecordHandler = new PostGreSqlRecordHandler(databaseConnectionConfig, amazonS3, secretsManager, athena, jdbcConnectionFactory, jdbcSplitQueryBuilder, com.google.common.collect.ImmutableMap.of()); @@ -230,95 +229,6 @@ public void buildSplitSqlForDateTest() logger.info("buildSplitSqlForDateTest - exit"); } - @Test - public void buildSplitSqlCollateAwareQuery() - throws SQLException - { - logger.info("buildSplitSqlCollateAwareQuery - enter"); - - TableName tableName = new TableName("testSchema", "testTable"); - - SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol1", Types.MinorType.INT.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol2", Types.MinorType.VARCHAR.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol3", Types.MinorType.BIGINT.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol4", Types.MinorType.FLOAT4.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol5", Types.MinorType.SMALLINT.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol6", Types.MinorType.TINYINT.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol7", Types.MinorType.FLOAT8.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol8", Types.MinorType.BIT.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol9", new ArrowType.Decimal(8, 2)).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("testCol10", new ArrowType.Utf8()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("partition_schema_name", Types.MinorType.VARCHAR.getType()).build()); - schemaBuilder.addField(FieldBuilder.newBuilder("partition_name", Types.MinorType.VARCHAR.getType()).build()); - Schema schema = schemaBuilder.build(); - - Split split = Mockito.mock(Split.class); - Mockito.when(split.getProperties()).thenReturn(ImmutableMap.of("partition_schema_name", "s0", "partition_name", "p0")); - Mockito.when(split.getProperty(Mockito.eq(com.amazonaws.athena.connectors.postgresql.PostGreSqlMetadataHandler.BLOCK_PARTITION_SCHEMA_COLUMN_NAME))).thenReturn("s0"); - Mockito.when(split.getProperty(Mockito.eq(com.amazonaws.athena.connectors.postgresql.PostGreSqlMetadataHandler.BLOCK_PARTITION_COLUMN_NAME))).thenReturn("p0"); - - Range range1a = Mockito.mock(Range.class, Mockito.RETURNS_DEEP_STUBS); - Mockito.when(range1a.isSingleValue()).thenReturn(true); - Mockito.when(range1a.getLow().getValue()).thenReturn(1); - Range range1b = Mockito.mock(Range.class, Mockito.RETURNS_DEEP_STUBS); - Mockito.when(range1b.isSingleValue()).thenReturn(true); - Mockito.when(range1b.getLow().getValue()).thenReturn(2); - ValueSet valueSet1 = Mockito.mock(SortedRangeSet.class, Mockito.RETURNS_DEEP_STUBS); - 
Mockito.when(valueSet1.getRanges().getOrderedRanges()).thenReturn(ImmutableList.of(range1a, range1b)); - - ValueSet valueSet2 = getRangeSet(Marker.Bound.EXACTLY, "1", Marker.Bound.BELOW, "10"); - ValueSet valueSet3 = getRangeSet(Marker.Bound.ABOVE, 2L, Marker.Bound.EXACTLY, 20L); - ValueSet valueSet4 = getSingleValueSet(1.1F); - ValueSet valueSet5 = getSingleValueSet(1); - ValueSet valueSet6 = getSingleValueSet(0); - ValueSet valueSet7 = getSingleValueSet(1.2d); - ValueSet valueSet8 = getSingleValueSet(true); - ValueSet valueSet9 = getSingleValueSet(BigDecimal.valueOf(12.34)); - ValueSet valueSet10 = getSingleValueSet("A"); - - Constraints constraints = Mockito.mock(Constraints.class); - Mockito.when(constraints.getSummary()).thenReturn(new ImmutableMap.Builder() - .put("testCol1", valueSet1) - .put("testCol2", valueSet2) - .put("testCol3", valueSet3) - .put("testCol4", valueSet4) - .put("testCol5", valueSet5) - .put("testCol6", valueSet6) - .put("testCol7", valueSet7) - .put("testCol8", valueSet8) - .put("testCol9", valueSet9) - .put("testCol10", valueSet10) - .build()); - - String expectedSql = "SELECT \"testCol1\", \"testCol2\", \"testCol3\", \"testCol4\", \"testCol5\", \"testCol6\", \"testCol7\", \"testCol8\", \"testCol9\", RTRIM(\"testCol10\") AS \"testCol10\" FROM \"s0\".\"p0\" WHERE (\"testCol1\" IN (?,?)) AND ((\"testCol2\" >= ? COLLATE \"C\" AND \"testCol2\" < ? COLLATE \"C\")) AND ((\"testCol3\" > ? AND \"testCol3\" <= ?)) AND (\"testCol4\" = ?) AND (\"testCol5\" = ?) AND (\"testCol6\" = ?) AND (\"testCol7\" = ?) AND (\"testCol8\" = ?) AND (\"testCol9\" = ?) AND (\"testCol10\" = ?)"; - PreparedStatement expectedPreparedStatement = Mockito.mock(PreparedStatement.class); - Mockito.when(this.connection.prepareStatement(Mockito.eq(expectedSql))).thenReturn(expectedPreparedStatement); - - //Setting Collate Aware query builder flag on - Map configOptions = ImmutableMap.of("postgresql_collate_experimental_flag", "true"); - PostGreSqlQueryStringBuilder localJdbcSplitQueryBuilder = new PostGreSqlQueryStringBuilder("\"", new PostgreSqlFederationExpressionParser("\""), configOptions); - PostGreSqlRecordHandler localPostgresqlRecordHandler = new PostGreSqlRecordHandler(databaseConnectionConfig, amazonS3, secretsManager, athena, jdbcConnectionFactory, localJdbcSplitQueryBuilder, configOptions); - PreparedStatement preparedStatement = localPostgresqlRecordHandler.buildSplitSql(this.connection, "testCatalogName", tableName, schema, constraints, split); - - Assert.assertEquals(expectedPreparedStatement, preparedStatement); - Mockito.verify(preparedStatement, Mockito.times(1)).setInt(1, 1); - Mockito.verify(preparedStatement, Mockito.times(1)).setInt(2, 2); - Mockito.verify(preparedStatement, Mockito.times(1)).setString(3, "1"); - Mockito.verify(preparedStatement, Mockito.times(1)).setString(4, "10"); - Mockito.verify(preparedStatement, Mockito.times(1)).setLong(5, 2L); - Mockito.verify(preparedStatement, Mockito.times(1)).setLong(6, 20L); - Mockito.verify(preparedStatement, Mockito.times(1)).setFloat(7, 1.1F); - Mockito.verify(preparedStatement, Mockito.times(1)).setShort(8, (short) 1); - Mockito.verify(preparedStatement, Mockito.times(1)).setByte(9, (byte) 0); - Mockito.verify(preparedStatement, Mockito.times(1)).setDouble(10, 1.2d); - Mockito.verify(preparedStatement, Mockito.times(1)).setBoolean(11, true); - Mockito.verify(preparedStatement, Mockito.times(1)).setBigDecimal(12, BigDecimal.valueOf(12.34)); - Mockito.verify(preparedStatement, Mockito.times(1)).setString(13, "A"); - - 
logger.info("buildSplitSqlCollateAwareQuery - exit"); - } - private ValueSet getSingleValueSet(Object value) { Range range = Mockito.mock(Range.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(range.isSingleValue()).thenReturn(true); diff --git a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java index b49b69c684..1546ea391b 100644 --- a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java +++ b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java @@ -63,9 +63,7 @@ public RedshiftRecordHandler(java.util.Map configOptions) public RedshiftRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), - new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, - new DatabaseConnectionInfo(REDSHIFT_DRIVER_CLASS, REDSHIFT_DEFAULT_PORT)), - new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER), configOptions), configOptions); + new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(REDSHIFT_DRIVER_CLASS, REDSHIFT_DEFAULT_PORT)), new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting diff --git a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java index bbb7fb6af5..8f459391b5 100644 --- a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java +++ b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java @@ -90,7 +90,7 @@ public void setup() this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - jdbcSplitQueryBuilder = new PostGreSqlQueryStringBuilder("\"", new PostgreSqlFederationExpressionParser("\""), Collections.emptyMap()); + jdbcSplitQueryBuilder = new PostGreSqlQueryStringBuilder("\"", new PostgreSqlFederationExpressionParser("\"")); final DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", REDSHIFT_NAME, "redshift://jdbc:redshift://hostname/user=A&password=B"); diff --git a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaQueryStringBuilder.java b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaQueryStringBuilder.java index ca65ce8efe..16ec948679 100644 --- a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaQueryStringBuilder.java +++ b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaQueryStringBuilder.java @@ -349,4 +349,11 @@ else if (singleValues.size() > 1) { return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; } + + private String toPredicate(String columnName, String operator, 
Object value, ArrowType type, + List accumulator) + { + accumulator.add(new TypeAndValue(type, value)); + return quote(columnName) + " " + operator + " ?"; + } } From fd26659c64e34f42d0e6bf259dd292dc7b6c0a54 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 21 Oct 2024 07:02:15 +0000 Subject: [PATCH 2/7] build(deps): bump software.amazon.awssdk:bom from 2.28.21 to 2.28.26 Bumps software.amazon.awssdk:bom from 2.28.21 to 2.28.26. --- updated-dependencies: - dependency-name: software.amazon.awssdk:bom dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- athena-dynamodb/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/athena-dynamodb/pom.xml b/athena-dynamodb/pom.xml index 045a804973..28d8f26fdb 100644 --- a/athena-dynamodb/pom.xml +++ b/athena-dynamodb/pom.xml @@ -13,7 +13,7 @@ software.amazon.awssdk bom - 2.28.21 + 2.28.26 pom import From 8fe875f262cb343132a8556c7985e9068de44de1 Mon Sep 17 00:00:00 2001 From: Aimery Methena <159072740+aimethed@users.noreply.github.com> Date: Tue, 22 Oct 2024 09:40:27 -0400 Subject: [PATCH 3/7] AWS SDK v2 migration (#2339) Signed-off-by: dependabot[bot] Co-authored-by: Trianz-Akshay <108925344+Trianz-Akshay@users.noreply.github.com> Co-authored-by: Jithendar Trianz <106380520+Jithendar12@users.noreply.github.com> Co-authored-by: VenkatasivareddyTR <110587813+VenkatasivareddyTR@users.noreply.github.com> Co-authored-by: ejeffrli <144148373+ejeffrli@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Jeffrey Lin Co-authored-by: Mario Rial Co-authored-by: AbdulRehman --- athena-aws-cmdb/Dockerfile | 9 + athena-aws-cmdb/athena-aws-cmdb.yaml | 5 +- athena-aws-cmdb/pom.xml | 24 +- .../aws/cmdb/AwsCmdbMetadataHandler.java | 8 +- .../aws/cmdb/AwsCmdbRecordHandler.java | 8 +- .../aws/cmdb/TableProviderFactory.java | 22 +- .../cmdb/tables/EmrClusterTableProvider.java | 76 ++-- .../aws/cmdb/tables/RdsTableProvider.java | 155 ++++---- .../aws/cmdb/tables/ec2/EbsTableProvider.java | 60 +-- .../aws/cmdb/tables/ec2/Ec2TableProvider.java | 126 +++---- .../cmdb/tables/ec2/ImagesTableProvider.java | 90 ++--- .../cmdb/tables/ec2/RouteTableProvider.java | 74 ++-- .../ec2/SecurityGroupsTableProvider.java | 62 ++-- .../cmdb/tables/ec2/SubnetTableProvider.java | 43 ++- .../aws/cmdb/tables/ec2/VpcTableProvider.java | 38 +- .../tables/s3/S3BucketsTableProvider.java | 26 +- .../tables/s3/S3ObjectsTableProvider.java | 54 +-- .../aws/cmdb/AwsCmdbMetadataHandlerTest.java | 12 +- .../aws/cmdb/AwsCmdbRecordHandlerTest.java | 12 +- .../aws/cmdb/TableProviderFactoryTest.java | 18 +- .../tables/AbstractTableProviderTest.java | 41 +- .../tables/EmrClusterTableProviderTest.java | 81 ++-- .../aws/cmdb/tables/RdsTableProviderTest.java | 155 ++++---- .../cmdb/tables/ec2/EbsTableProviderTest.java | 56 ++- .../cmdb/tables/ec2/Ec2TableProviderTest.java | 142 ++++--- .../tables/ec2/ImagesTableProviderTest.java | 81 ++-- .../tables/ec2/RouteTableProviderTest.java | 69 ++-- .../ec2/SecurityGroupsTableProviderTest.java | 54 ++- .../tables/ec2/SubnetTableProviderTest.java | 42 +-- .../cmdb/tables/ec2/VpcTableProviderTest.java | 37 +- .../tables/s3/S3BucketsTableProviderTest.java | 33 +- .../tables/s3/S3ObjectsTableProviderTest.java | 60 ++- athena-clickhouse/Dockerfile | 9 + 
athena-clickhouse/athena-clickhouse.yaml | 5 +- .../clickhouse/ClickHouseMetadataHandler.java | 8 +- .../ClickHouseMuxMetadataHandler.java | 6 +- .../ClickHouseMuxRecordHandler.java | 8 +- .../clickhouse/ClickHouseRecordHandler.java | 16 +- .../ClickHouseMetadataHandlerTest.java | 18 +- .../ClickHouseMuxJdbcMetadataHandlerTest.java | 12 +- .../ClickHouseMuxJdbcRecordHandlerTest.java | 18 +- athena-cloudera-hive/Dockerfile | 9 + .../athena-cloudera-hive.yaml | 5 +- athena-cloudera-hive/pom.xml | 14 +- .../cloudera/HiveMetadataHandler.java | 8 +- .../cloudera/HiveMuxMetadataHandler.java | 6 +- .../cloudera/HiveMuxRecordHandler.java | 8 +- .../cloudera/HiveRecordHandler.java | 13 +- .../cloudera/HiveMetadataHandlerTest.java | 19 +- .../cloudera/HiveMuxMetadataHandlerTest.java | 12 +- .../cloudera/HiveMuxRecordHandlerTest.java | 18 +- .../cloudera/HiveRecordHandlerTest.java | 25 +- athena-cloudera-impala/Dockerfile | 9 + .../athena-cloudera-impala.yaml | 5 +- athena-cloudera-impala/pom.xml | 14 +- .../ImpalaFederationExpressionParser.java | 2 +- .../cloudera/ImpalaMetadataHandler.java | 8 +- .../cloudera/ImpalaMuxMetadataHandler.java | 6 +- .../cloudera/ImpalaMuxRecordHandler.java | 8 +- .../cloudera/ImpalaRecordHandler.java | 14 +- .../cloudera/ImpalaMetadataHandlerTest.java | 18 +- .../ImpalaMuxMetadataHandlerTest.java | 12 +- .../cloudera/ImpalaMuxRecordHandlerTest.java | 18 +- .../ImpalaQueryStringBuilderTest.java | 1 - .../cloudera/ImpalaRecordHandlerTest.java | 25 +- athena-cloudwatch-metrics/Dockerfile | 9 + .../athena-cloudwatch-metrics.yaml | 5 +- athena-cloudwatch-metrics/pom.xml | 6 +- .../cloudwatch/metrics/MetricStatSerDe.java | 31 +- .../cloudwatch/metrics/MetricUtils.java | 62 ++-- .../metrics/MetricsExceptionFilter.java | 8 +- .../metrics/MetricsMetadataHandler.java | 60 +-- .../metrics/MetricsRecordHandler.java | 119 +++--- .../metrics/MetricStatSerDeTest.java | 30 +- .../cloudwatch/metrics/MetricUtilsTest.java | 97 +++-- .../metrics/MetricsMetadataHandlerTest.java | 38 +- .../metrics/MetricsRecordHandlerTest.java | 129 ++++--- athena-cloudwatch/Dockerfile | 9 + athena-cloudwatch/athena-cloudwatch.yaml | 5 +- athena-cloudwatch/pom.xml | 26 +- .../cloudwatch/CloudwatchExceptionFilter.java | 6 +- .../cloudwatch/CloudwatchMetadataHandler.java | 103 +++--- .../cloudwatch/CloudwatchRecordHandler.java | 72 ++-- .../cloudwatch/CloudwatchTableResolver.java | 68 ++-- .../cloudwatch/CloudwatchUtils.java | 37 +- .../CloudwatchMetadataHandlerTest.java | 122 +++--- .../CloudwatchRecordHandlerTest.java | 76 ++-- .../cloudwatch/integ/CloudwatchIntegTest.java | 35 +- athena-datalakegen2/Dockerfile | 9 + athena-datalakegen2/athena-datalakegen2.yaml | 5 +- athena-datalakegen2/pom.xml | 14 +- .../DataLakeGen2MetadataHandler.java | 8 +- .../DataLakeGen2MuxMetadataHandler.java | 6 +- .../DataLakeGen2MuxRecordHandler.java | 8 +- .../DataLakeGen2RecordHandler.java | 13 +- .../DataLakeGen2MetadataHandlerTest.java | 18 +- .../DataLakeGen2MuxMetadataHandlerTest.java | 12 +- .../DataLakeGen2MuxRecordHandlerTest.java | 18 +- .../DataLakeRecordHandlerTest.java | 18 +- athena-db2-as400/Dockerfile | 9 + athena-db2-as400/athena-db2-as400.yaml | 5 +- athena-db2-as400/pom.xml | 14 +- .../db2as400/Db2As400MetadataHandler.java | 8 +- .../db2as400/Db2As400MuxMetadataHandler.java | 6 +- .../db2as400/Db2As400MuxRecordHandler.java | 8 +- .../db2as400/Db2As400RecordHandler.java | 13 +- .../db2as400/Db2As400MetadataHandlerTest.java | 18 +- .../db2as400/Db2As400RecordHandlerTest.java | 18 +- athena-db2/Dockerfile | 9 + 
athena-db2/athena-db2.yaml | 5 +- athena-db2/pom.xml | 14 +- .../connectors/db2/Db2MetadataHandler.java | 8 +- .../connectors/db2/Db2MuxMetadataHandler.java | 6 +- .../connectors/db2/Db2MuxRecordHandler.java | 8 +- .../connectors/db2/Db2RecordHandler.java | 13 +- .../db2/Db2MetadataHandlerTest.java | 18 +- .../connectors/db2/Db2RecordHandlerTest.java | 18 +- athena-docdb/Dockerfile | 9 + athena-docdb/athena-docdb.yaml | 5 +- athena-docdb/pom.xml | 8 +- .../docdb/DocDBMetadataHandler.java | 22 +- .../connectors/docdb/DocDBRecordHandler.java | 17 +- .../docdb/DocDBMetadataHandlerTest.java | 12 +- .../docdb/DocDBRecordHandlerTest.java | 45 ++- .../docdb/integ/DocDbIntegTest.java | 46 +-- athena-dynamodb/Dockerfile | 9 + athena-dynamodb/athena-dynamodb.yaml | 5 +- athena-dynamodb/pom.xml | 25 +- .../dynamodb/DynamoDBMetadataHandler.java | 49 +-- .../dynamodb/DynamoDBRecordHandler.java | 28 +- .../dynamodb/qpt/DDBQueryPassthrough.java | 8 +- .../resolver/DynamoDBFieldResolver.java | 6 +- .../resolver/DynamoDBTableResolver.java | 6 +- .../dynamodb/util/DDBPredicateUtils.java | 8 +- .../dynamodb/util/DDBTableUtils.java | 6 +- .../dynamodb/util/DDBTypeUtils.java | 16 +- .../dynamodb/DynamoDBMetadataHandlerTest.java | 138 ++++--- .../dynamodb/DynamoDBRecordHandlerTest.java | 167 +++++---- .../dynamodb/DynamoDbIntegTest.java | 10 +- athena-elasticsearch/Dockerfile | 9 + .../athena-elasticsearch.yaml | 5 +- athena-elasticsearch/pom.xml | 62 +--- .../AWSRequestSigningApacheInterceptor.java | 106 +++--- .../AwsElasticsearchFactory.java | 11 +- .../elasticsearch/AwsRestHighLevelClient.java | 24 +- .../AwsRestHighLevelClientFactory.java | 4 +- .../ElasticsearchDomainMapProvider.java | 31 +- .../ElasticsearchMetadataHandler.java | 14 +- .../ElasticsearchRecordHandler.java | 19 +- .../ElasticsearchDomainMapProviderTest.java | 59 ++- .../ElasticsearchMetadataHandlerTest.java | 12 +- .../ElasticsearchRecordHandlerTest.java | 43 +-- .../integ/ElasticsearchIntegTest.java | 10 +- .../example/ExampleMetadataHandler.java | 8 +- .../example/ExampleRecordHandler.java | 29 +- .../example/ExampleMetadataHandlerTest.java | 8 +- .../example/ExampleRecordHandlerTest.java | 33 +- athena-federation-integ-test/pom.xml | 81 +--- .../connector/integ/IntegrationTestBase.java | 161 ++++---- .../integ/clients/CloudFormationClient.java | 69 ++-- .../SecretsManagerCredentialsProvider.java | 21 +- .../validation/FederationService.java | 30 -- .../validation/FederationServiceProvider.java | 87 ++--- .../validation/LambdaMetadataProvider.java | 12 +- .../validation/LambdaRecordProvider.java | 4 +- athena-federation-sdk/pom.xml | 123 +++--- .../CrossAccountCredentialsProvider.java | 38 +- .../connector/lambda/QueryStatusChecker.java | 16 +- .../connector/lambda/data/BlockUtils.java | 4 + .../lambda/data/DateTimeFormatterUtil.java | 6 +- .../lambda/data/S3BlockSpillReader.java | 40 +- .../connector/lambda/data/S3BlockSpiller.java | 54 ++- .../domain/spill/SpillLocationVerifier.java | 16 +- .../exceptions/AthenaConnectorException.java | 2 +- .../handlers/AthenaExceptionFilter.java | 2 +- .../lambda/handlers/GlueMetadataHandler.java | 143 ++++--- .../lambda/handlers/MetadataHandler.java | 26 +- .../lambda/handlers/RecordHandler.java | 21 +- .../lambda/metadata/MetadataService.java | 38 -- .../lambda/records/RecordService.java | 38 -- .../security/CachableSecretsManager.java | 17 +- .../lambda/security/KmsKeyFactory.java | 34 +- .../v2/LambdaFunctionExceptionSerDe.java | 30 +- .../serde/v2/ObjectMapperFactoryV2.java | 6 +- 
.../serde/v3/ObjectMapperFactoryV3.java | 6 +- .../serde/v4/ObjectMapperFactoryV4.java | 6 +- .../serde/v5/ObjectMapperFactoryV5.java | 6 +- .../lambda/QueryStatusCheckerTest.java | 45 +-- .../lambda/data/S3BlockSpillerTest.java | 46 +-- .../spill/SpillLocationVerifierTest.java | 38 +- .../handlers/GlueMetadataHandlerTest.java | 193 +++++----- .../security/CacheableSecretsManagerTest.java | 21 +- .../v2/LambdaFunctionExceptionSerDeTest.java | 21 +- athena-gcs/Dockerfile | 9 + athena-gcs/athena-gcs.yaml | 5 +- .../connectors/gcs/GcsMetadataHandler.java | 32 +- .../connectors/gcs/GcsRecordHandler.java | 17 +- .../gcs/GcsThrottlingExceptionFilter.java | 4 +- .../athena/connectors/gcs/GcsUtil.java | 24 +- .../connectors/gcs/common/PartitionUtil.java | 22 +- .../gcs/filter/FilterExpressionBuilder.java | 6 +- .../gcs/storage/StorageMetadata.java | 16 +- .../gcs/GcsCompositeHandlerTest.java | 28 +- .../gcs/GcsExceptionFilterTest.java | 9 +- .../gcs/GcsMetadataHandlerTest.java | 209 ++++++----- .../connectors/gcs/GcsRecordHandlerTest.java | 23 +- .../athena/connectors/gcs/GcsTestUtils.java | 9 +- .../athena/connectors/gcs/GenericGcsTest.java | 18 +- .../gcs/common/PartitionUtilTest.java | 126 ++++--- .../filter/FilterExpressionBuilderTest.java | 18 +- .../gcs/storage/StorageMetadataTest.java | 50 +-- athena-google-bigquery/Dockerfile | 9 + .../athena-google-bigquery.yaml | 5 +- athena-google-bigquery/pom.xml | 14 - .../bigquery/BigQueryExceptionFilter.java | 4 +- .../bigquery/BigQueryRecordHandler.java | 17 +- .../google/bigquery/BigQueryUtils.java | 18 +- .../BigQueryCompositeHandlerTest.java | 33 +- .../bigquery/BigQueryRecordHandlerTest.java | 14 +- .../bigquery/integ/BigQueryIntegTest.java | 1 - athena-hbase/Dockerfile | 9 + athena-hbase/athena-hbase.yaml | 5 +- athena-hbase/pom.xml | 165 +-------- .../connectors/hbase/HbaseKerberosUtils.java | 38 +- .../hbase/HbaseMetadataHandler.java | 18 +- .../connectors/hbase/HbaseRecordHandler.java | 19 +- .../hbase/HbaseMetadataHandlerTest.java | 12 +- .../hbase/HbaseRecordHandlerTest.java | 42 +-- .../hbase/integ/HbaseIntegTest.java | 111 +++--- athena-hortonworks-hive/Dockerfile | 9 + .../athena-hortonworks-hive.yaml | 5 +- athena-hortonworks-hive/pom.xml | 14 +- .../hortonworks/HiveMetadataHandler.java | 8 +- .../hortonworks/HiveMuxMetadataHandler.java | 6 +- .../hortonworks/HiveMuxRecordHandler.java | 8 +- .../hortonworks/HiveRecordHandler.java | 13 +- .../hortonworks/HiveMetadataHandlerTest.java | 18 +- .../HiveMuxMetadataHandlerTest.java | 12 +- .../hortonworks/HiveMuxRecordHandlerTest.java | 18 +- .../hortonworks/HiveRecordHandlerTest.java | 24 +- athena-jdbc/pom.xml | 74 +--- .../jdbc/MultiplexingJdbcMetadataHandler.java | 8 +- .../jdbc/MultiplexingJdbcRecordHandler.java | 12 +- .../jdbc/manager/JdbcMetadataHandler.java | 8 +- .../jdbc/manager/JdbcRecordHandler.java | 12 +- .../MultiplexingJdbcMetadataHandlerTest.java | 12 +- .../MultiplexingJdbcRecordHandlerTest.java | 18 +- .../jdbc/manager/JdbcMetadataHandlerTest.java | 18 +- .../jdbc/manager/JdbcRecordHandlerTest.java | 50 +-- athena-kafka/Dockerfile | 9 + athena-kafka/athena-kafka.yaml | 5 +- athena-kafka/pom.xml | 10 - .../connectors/kafka/GlueRegistryReader.java | 44 +-- .../kafka/KafkaMetadataHandler.java | 88 ++--- .../connectors/kafka/KafkaRecordHandler.java | 17 +- .../athena/connectors/kafka/KafkaUtils.java | 58 +-- .../kafka/KafkaCompositeHandlerTest.java | 12 +- .../kafka/KafkaMetadataHandlerTest.java | 124 ++++--- .../kafka/KafkaRecordHandlerTest.java | 124 ++++--- 
.../connectors/kafka/KafkaUtilsTest.java | 95 ++--- athena-msk/Dockerfile | 9 + athena-msk/athena-msk.yaml | 5 +- athena-msk/pom.xml | 21 +- .../msk/AmazonMskMetadataHandler.java | 93 ++--- .../msk/AmazonMskRecordHandler.java | 17 +- .../athena/connectors/msk/AmazonMskUtils.java | 58 +-- .../connectors/msk/GlueRegistryReader.java | 49 +-- .../msk/AmazonMskCompositeHandlerTest.java | 9 +- .../msk/AmazonMskMetadataHandlerTest.java | 123 +++--- .../msk/AmazonMskRecordHandlerTest.java | 120 +++--- .../connectors/msk/AmazonMskUtilsTest.java | 70 ++-- athena-mysql/Dockerfile | 9 + athena-mysql/athena-mysql.yaml | 5 +- athena-mysql/pom.xml | 14 +- .../mysql/MySqlMetadataHandler.java | 8 +- .../mysql/MySqlMuxMetadataHandler.java | 6 +- .../mysql/MySqlMuxRecordHandler.java | 8 +- .../connectors/mysql/MySqlRecordHandler.java | 15 +- .../mysql/MySqlMetadataHandlerTest.java | 18 +- .../MySqlMuxJdbcMetadataHandlerTest.java | 12 +- .../mysql/MySqlMuxJdbcRecordHandlerTest.java | 18 +- .../mysql/MySqlRecordHandlerTest.java | 18 +- .../mysql/integ/MySqlIntegTest.java | 31 +- athena-neptune/Dockerfile | 9 + athena-neptune/athena-neptune.yaml | 5 +- .../neptune/NeptuneMetadataHandler.java | 31 +- .../neptune/NeptuneRecordHandler.java | 21 +- .../neptune/NeptuneMetadataHandlerTest.java | 81 ++-- .../neptune/NeptuneRecordHandlerTest.java | 56 +-- athena-oracle/Dockerfile | 9 + athena-oracle/athena-oracle.yaml | 5 +- athena-oracle/pom.xml | 14 +- .../oracle/OracleMetadataHandler.java | 8 +- .../oracle/OracleMuxMetadataHandler.java | 6 +- .../oracle/OracleMuxRecordHandler.java | 8 +- .../oracle/OracleRecordHandler.java | 15 +- .../oracle/OracleMetadataHandlerTest.java | 18 +- .../OracleMuxJdbcMetadataHandlerTest.java | 14 +- .../OracleMuxJdbcRecordHandlerTest.java | 18 +- .../oracle/OracleRecordHandlerTest.java | 18 +- .../oracle/integ/OracleIntegTest.java | 46 +-- athena-postgresql/Dockerfile | 9 + athena-postgresql/athena-postgresql.yaml | 7 +- athena-postgresql/pom.xml | 14 +- .../postgresql/PostGreSqlMetadataHandler.java | 8 +- .../PostGreSqlMuxMetadataHandler.java | 6 +- .../PostGreSqlMuxRecordHandler.java | 8 +- .../postgresql/PostGreSqlRecordHandler.java | 18 +- .../PostGreSqlMetadataHandlerTest.java | 16 +- .../PostGreSqlMuxJdbcMetadataHandlerTest.java | 12 +- .../PostGreSqlMuxJdbcRecordHandlerTest.java | 18 +- .../PostGreSqlRecordHandlerTest.java | 18 +- .../postgresql/integ/PostGreSqlIntegTest.java | 32 +- athena-redis/Dockerfile | 9 + athena-redis/athena-redis.yaml | 5 +- athena-redis/pom.xml | 60 +-- .../redis/RedisMetadataHandler.java | 22 +- .../connectors/redis/RedisRecordHandler.java | 23 +- .../redis/RedisMetadataHandlerTest.java | 22 +- .../redis/RedisRecordHandlerTest.java | 48 ++- .../redis/integ/RedisIntegTest.java | 196 +++++----- athena-redshift/Dockerfile | 9 + athena-redshift/athena-redshift.yaml | 5 +- athena-redshift/pom.xml | 28 +- .../redshift/RedshiftMetadataHandler.java | 6 +- .../redshift/RedshiftMuxMetadataHandler.java | 6 +- .../redshift/RedshiftMuxRecordHandler.java | 8 +- .../redshift/RedshiftRecordHandler.java | 17 +- .../redshift/RedshiftMetadataHandlerTest.java | 16 +- .../RedshiftMuxJdbcMetadataHandlerTest.java | 12 +- .../RedshiftMuxJdbcRecordHandlerTest.java | 18 +- .../redshift/RedshiftRecordHandlerTest.java | 18 +- .../redshift/integ/RedshiftIntegTest.java | 33 +- athena-saphana/Dockerfile | 9 + athena-saphana/athena-saphana.yaml | 5 +- athena-saphana/pom.xml | 14 +- .../saphana/SaphanaMetadataHandler.java | 8 +- .../saphana/SaphanaMuxMetadataHandler.java | 6 +- 
.../saphana/SaphanaMuxRecordHandler.java | 8 +- .../saphana/SaphanaRecordHandler.java | 15 +- .../saphana/SaphanaMetadataHandlerTest.java | 18 +- .../SaphanaMuxJdbcMetadataHandlerTest.java | 12 +- .../SaphanaMuxJdbcRecordHandlerTest.java | 18 +- .../saphana/SaphanaRecordHandlerTest.java | 18 +- athena-snowflake/Dockerfile | 9 + athena-snowflake/athena-snowflake.yaml | 5 +- athena-snowflake/pom.xml | 14 +- .../snowflake/SnowflakeMetadataHandler.java | 8 +- .../SnowflakeMuxMetadataHandler.java | 6 +- .../snowflake/SnowflakeMuxRecordHandler.java | 8 +- .../snowflake/SnowflakeRecordHandler.java | 15 +- .../SnowflakeMetadataHandlerTest.java | 18 +- .../SnowflakeMuxJdbcMetadataHandlerTest.java | 12 +- .../SnowflakeMuxJdbcRecordHandlerTest.java | 18 +- .../snowflake/SnowflakeRecordHandlerTest.java | 18 +- athena-sqlserver/Dockerfile | 9 + athena-sqlserver/athena-sqlserver.yaml | 5 +- athena-sqlserver/pom.xml | 14 +- .../sqlserver/SqlServerMetadataHandler.java | 8 +- .../SqlServerMuxMetadataHandler.java | 6 +- .../sqlserver/SqlServerMuxRecordHandler.java | 8 +- .../sqlserver/SqlServerRecordHandler.java | 17 +- .../SqlServerMetadataHandlerTest.java | 18 +- .../SqlServerMuxMetadataHandlerTest.java | 12 +- .../SqlServerMuxRecordHandlerTest.java | 18 +- .../sqlserver/SqlServerRecordHandlerTest.java | 18 +- athena-synapse/Dockerfile | 9 + athena-synapse/athena-synapse.yaml | 5 +- athena-synapse/pom.xml | 14 +- .../synapse/SynapseMetadataHandler.java | 8 +- .../synapse/SynapseMuxMetadataHandler.java | 6 +- .../synapse/SynapseMuxRecordHandler.java | 8 +- .../synapse/SynapseRecordHandler.java | 17 +- .../synapse/SynapseMetadataHandlerTest.java | 19 +- .../SynapseMuxMetadataHandlerTest.java | 12 +- .../synapse/SynapseMuxRecordHandlerTest.java | 18 +- .../synapse/SynapseRecordHandlerTest.java | 18 +- athena-teradata/Dockerfile | 9 + athena-teradata/athena-teradata.yaml | 10 +- athena-teradata/pom.xml | 19 +- .../teradata/TeradataMetadataHandler.java | 8 +- .../teradata/TeradataMuxMetadataHandler.java | 6 +- .../teradata/TeradataMuxRecordHandler.java | 8 +- .../teradata/TeradataRecordHandler.java | 15 +- .../teradata/TeradataMetadataHandlerTest.java | 18 +- .../TeradataMuxJdbcMetadataHandlerTest.java | 12 +- .../TeradataMuxJdbcRecordHandlerTest.java | 18 +- .../teradata/TeradataRecordHandlerTest.java | 18 +- athena-timestream/Dockerfile | 9 + athena-timestream/athena-timestream.yaml | 5 +- athena-timestream/pom.xml | 18 +- .../timestream/TimestreamClientBuilder.java | 30 +- .../timestream/TimestreamMetadataHandler.java | 118 +++--- .../timestream/TimestreamRecordHandler.java | 63 ++-- .../connectors/timestream/TestUtils.java | 72 ++-- .../TimestreamClientBuilderTest.java | 7 +- .../TimestreamMetadataHandlerTest.java | 183 +++++---- .../TimestreamRecordHandlerTest.java | 78 ++-- .../timestream/integ/TimestreamIntegTest.java | 30 +- .../TimestreamWriteRecordRequestBuilder.java | 34 +- athena-tpcds/Dockerfile | 9 + athena-tpcds/athena-tpcds.yaml | 5 +- .../tpcds/TPCDSMetadataHandler.java | 8 +- .../connectors/tpcds/TPCDSRecordHandler.java | 13 +- .../tpcds/TPCDSMetadataHandlerTest.java | 8 +- .../tpcds/TPCDSRecordHandlerTest.java | 48 +-- athena-udfs/Dockerfile | 9 + athena-udfs/athena-udfs.yaml | 5 +- .../connectors/udfs/AthenaUDFHandler.java | 4 +- athena-vertica/Dockerfile | 9 + athena-vertica/athena-vertica.yaml | 5 +- athena-vertica/pom.xml | 10 + .../vertica/VerticaCompositeHandler.java | 12 +- .../connectors/vertica/VerticaConstants.java | 11 + .../vertica/VerticaMetadataHandler.java | 54 +-- 
.../vertica/VerticaRecordHandler.java | 120 +++--- .../vertica/VerticaSchemaUtils.java | 61 +++ .../vertica/VerticaMetadataHandlerTest.java | 60 ++- .../vertica/VerticaRecordHandlerTest.java | 349 ++++++++++++++++++ pom.xml | 2 +- .../bump_versions/bump_connectors_version.py | 4 + tools/bump_versions/common.py | 7 + 422 files changed, 6258 insertions(+), 6072 deletions(-) create mode 100644 athena-aws-cmdb/Dockerfile create mode 100644 athena-clickhouse/Dockerfile create mode 100644 athena-cloudera-hive/Dockerfile create mode 100644 athena-cloudera-impala/Dockerfile create mode 100644 athena-cloudwatch-metrics/Dockerfile create mode 100644 athena-cloudwatch/Dockerfile create mode 100644 athena-datalakegen2/Dockerfile create mode 100644 athena-db2-as400/Dockerfile create mode 100644 athena-db2/Dockerfile create mode 100644 athena-docdb/Dockerfile create mode 100644 athena-dynamodb/Dockerfile create mode 100644 athena-elasticsearch/Dockerfile delete mode 100644 athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/FederationService.java delete mode 100644 athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/metadata/MetadataService.java delete mode 100644 athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/records/RecordService.java create mode 100644 athena-gcs/Dockerfile create mode 100644 athena-google-bigquery/Dockerfile create mode 100644 athena-hbase/Dockerfile create mode 100644 athena-hortonworks-hive/Dockerfile create mode 100644 athena-kafka/Dockerfile create mode 100644 athena-msk/Dockerfile create mode 100644 athena-mysql/Dockerfile create mode 100644 athena-neptune/Dockerfile create mode 100644 athena-oracle/Dockerfile create mode 100644 athena-postgresql/Dockerfile create mode 100644 athena-redis/Dockerfile create mode 100644 athena-redshift/Dockerfile create mode 100644 athena-saphana/Dockerfile create mode 100644 athena-snowflake/Dockerfile create mode 100644 athena-sqlserver/Dockerfile create mode 100644 athena-synapse/Dockerfile create mode 100644 athena-teradata/Dockerfile create mode 100644 athena-timestream/Dockerfile create mode 100644 athena-tpcds/Dockerfile create mode 100644 athena-udfs/Dockerfile create mode 100644 athena-vertica/Dockerfile create mode 100644 athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandlerTest.java diff --git a/athena-aws-cmdb/Dockerfile b/athena-aws-cmdb/Dockerfile new file mode 100644 index 0000000000..a599a28963 --- /dev/null +++ b/athena-aws-cmdb/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-aws-cmdb-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-aws-cmdb-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.aws.cmdb.AwsCmdbCompositeHandler" ] \ No newline at end of file diff --git a/athena-aws-cmdb/athena-aws-cmdb.yaml b/athena-aws-cmdb/athena-aws-cmdb.yaml index b3265cd1eb..4365e6781d 100644 --- a/athena-aws-cmdb/athena-aws-cmdb.yaml +++ b/athena-aws-cmdb/athena-aws-cmdb.yaml @@ -52,10 +52,9 @@ Resources: spill_bucket: !Ref SpillBucket spill_prefix: !Ref SpillPrefix FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.aws.cmdb.AwsCmdbCompositeHandler" - CodeUri: "./target/athena-aws-cmdb-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub 
'292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-aws-cmdb:2022.47.1' Description: "Enables Amazon Athena to communicate with various AWS Services, making your resource inventories accessible via SQL." - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-aws-cmdb/pom.xml b/athena-aws-cmdb/pom.xml index 90c1ad7b59..6cc732de9c 100644 --- a/athena-aws-cmdb/pom.xml +++ b/athena-aws-cmdb/pom.xml @@ -16,9 +16,9 @@ withdep - com.amazonaws - aws-java-sdk-ec2 - ${aws-sdk.version} + software.amazon.awssdk + ec2 + ${aws-sdk-v2.version} @@ -28,14 +28,20 @@ - com.amazonaws - aws-java-sdk-emr - ${aws-sdk.version} + software.amazon.awssdk + emr + ${aws-sdk-v2.version} - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} + + + software.amazon.awssdk + netty-nio-client + + org.slf4j diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandler.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandler.java index 4a4b61f694..f2626625ba 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandler.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandler.java @@ -39,9 +39,9 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKey; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.List; import java.util.Map; @@ -77,8 +77,8 @@ public AwsCmdbMetadataHandler(java.util.Map configOptions) protected AwsCmdbMetadataHandler( TableProviderFactory tableProviderFactory, EncryptionKeyFactory keyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, String spillBucket, String spillPrefix, java.util.Map configOptions) diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandler.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandler.java index 9dcfe3ffe6..dc530d3f90 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandler.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -56,7 +56,7 @@ 
public AwsCmdbRecordHandler(java.util.Map configOptions) } @VisibleForTesting - protected AwsCmdbRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, TableProviderFactory tableProviderFactory, java.util.Map configOptions) + protected AwsCmdbRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, TableProviderFactory tableProviderFactory, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); tableProviders = tableProviderFactory.getTableProviders(); diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactory.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactory.java index d5868d33db..cdd1743950 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactory.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactory.java @@ -32,15 +32,11 @@ import com.amazonaws.athena.connectors.aws.cmdb.tables.ec2.VpcTableProvider; import com.amazonaws.athena.connectors.aws.cmdb.tables.s3.S3BucketsTableProvider; import com.amazonaws.athena.connectors.aws.cmdb.tables.s3.S3ObjectsTableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.AmazonEC2ClientBuilder; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder; -import com.amazonaws.services.rds.AmazonRDS; -import com.amazonaws.services.rds.AmazonRDSClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.emr.EmrClient; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.s3.S3Client; import java.util.ArrayList; import java.util.HashMap; @@ -59,15 +55,15 @@ public class TableProviderFactory public TableProviderFactory(java.util.Map configOptions) { this( - AmazonEC2ClientBuilder.standard().build(), - AmazonElasticMapReduceClientBuilder.standard().build(), - AmazonRDSClientBuilder.standard().build(), - AmazonS3ClientBuilder.standard().build(), + Ec2Client.create(), + EmrClient.create(), + RdsClient.create(), + S3Client.create(), configOptions); } @VisibleForTesting - protected TableProviderFactory(AmazonEC2 ec2, AmazonElasticMapReduce emr, AmazonRDS rds, AmazonS3 amazonS3, java.util.Map configOptions) + protected TableProviderFactory(Ec2Client ec2, EmrClient emr, RdsClient rds, S3Client amazonS3, java.util.Map configOptions) { addProvider(new Ec2TableProvider(ec2)); addProvider(new EbsTableProvider(ec2)); diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/EmrClusterTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/EmrClusterTableProvider.java index ee3b15da91..c3d10c7233 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/EmrClusterTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/EmrClusterTableProvider.java @@ -29,15 +29,15 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableRequest; import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import 
com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; -import com.amazonaws.services.elasticmapreduce.model.Cluster; -import com.amazonaws.services.elasticmapreduce.model.ClusterSummary; -import com.amazonaws.services.elasticmapreduce.model.DescribeClusterRequest; -import com.amazonaws.services.elasticmapreduce.model.DescribeClusterResult; -import com.amazonaws.services.elasticmapreduce.model.ListClustersRequest; -import com.amazonaws.services.elasticmapreduce.model.ListClustersResult; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.emr.EmrClient; +import software.amazon.awssdk.services.emr.model.Cluster; +import software.amazon.awssdk.services.emr.model.ClusterSummary; +import software.amazon.awssdk.services.emr.model.DescribeClusterRequest; +import software.amazon.awssdk.services.emr.model.DescribeClusterResponse; +import software.amazon.awssdk.services.emr.model.ListClustersRequest; +import software.amazon.awssdk.services.emr.model.ListClustersResponse; import java.util.List; import java.util.stream.Collectors; @@ -49,9 +49,9 @@ public class EmrClusterTableProvider implements TableProvider { private static final Schema SCHEMA; - private AmazonElasticMapReduce emr; + private EmrClient emr; - public EmrClusterTableProvider(AmazonElasticMapReduce emr) + public EmrClusterTableProvider(EmrClient emr) { this.emr = emr; } @@ -93,23 +93,23 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { boolean done = false; - ListClustersRequest request = new ListClustersRequest(); + ListClustersRequest request = ListClustersRequest.builder().build(); while (!done) { - ListClustersResult response = emr.listClusters(request); + ListClustersResponse response = emr.listClusters(request); - for (ClusterSummary next : response.getClusters()) { + for (ClusterSummary next : response.clusters()) { Cluster cluster = null; - if (!next.getStatus().getState().toLowerCase().contains("terminated")) { - DescribeClusterResult clusterResponse = emr.describeCluster(new DescribeClusterRequest().withClusterId(next.getId())); - cluster = clusterResponse.getCluster(); + if (!next.status().stateAsString().toLowerCase().contains("terminated")) { + DescribeClusterResponse clusterResponse = emr.describeCluster(DescribeClusterRequest.builder().clusterId(next.id()).build()); + cluster = clusterResponse.cluster(); } clusterToRow(next, cluster, spiller); } - request.setMarker(response.getMarker()); + request = request.toBuilder().marker(response.marker()).build(); - if (response.getMarker() == null || !queryStatusChecker.isQueryRunning()) { + if (response.marker() == null || !queryStatusChecker.isQueryRunning()) { done = true; } } @@ -131,31 +131,30 @@ private void clusterToRow(ClusterSummary clusterSummary, spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("id", row, clusterSummary.getId()); - matched &= block.offerValue("name", row, clusterSummary.getName()); - matched &= block.offerValue("instance_hours", row, clusterSummary.getNormalizedInstanceHours()); - matched &= block.offerValue("state", row, clusterSummary.getStatus().getState()); - matched &= block.offerValue("state_code", row, clusterSummary.getStatus().getStateChangeReason().getCode()); - matched &=
block.offerValue("state_msg", row, clusterSummary.getStatus().getStateChangeReason().getMessage()); + matched &= block.offerValue("id", row, clusterSummary.id()); + matched &= block.offerValue("name", row, clusterSummary.name()); + matched &= block.offerValue("instance_hours", row, clusterSummary.normalizedInstanceHours()); + matched &= block.offerValue("state", row, clusterSummary.status().stateAsString()); + matched &= block.offerValue("state_code", row, clusterSummary.status().stateChangeReason().codeAsString()); + matched &= block.offerValue("state_msg", row, clusterSummary.status().stateChangeReason().message()); if (cluster != null) { - matched &= block.offerValue("autoscaling_role", row, cluster.getAutoScalingRole()); - matched &= block.offerValue("custom_ami", row, cluster.getCustomAmiId()); - matched &= block.offerValue("instance_collection_type", row, cluster.getInstanceCollectionType()); - matched &= block.offerValue("log_uri", row, cluster.getLogUri()); - matched &= block.offerValue("master_public_dns", row, cluster.getMasterPublicDnsName()); - matched &= block.offerValue("release_label", row, cluster.getReleaseLabel()); - matched &= block.offerValue("running_ami", row, cluster.getRunningAmiVersion()); - matched &= block.offerValue("scale_down_behavior", row, cluster.getScaleDownBehavior()); - matched &= block.offerValue("service_role", row, cluster.getServiceRole()); - matched &= block.offerValue("service_role", row, cluster.getServiceRole()); - - List applications = cluster.getApplications().stream() - .map(next -> next.getName() + ":" + next.getVersion()).collect(Collectors.toList()); + matched &= block.offerValue("autoscaling_role", row, cluster.autoScalingRole()); + matched &= block.offerValue("custom_ami", row, cluster.customAmiId()); + matched &= block.offerValue("instance_collection_type", row, cluster.instanceCollectionTypeAsString()); + matched &= block.offerValue("log_uri", row, cluster.logUri()); + matched &= block.offerValue("master_public_dns", row, cluster.masterPublicDnsName()); + matched &= block.offerValue("release_label", row, cluster.releaseLabel()); + matched &= block.offerValue("running_ami", row, cluster.runningAmiVersion()); + matched &= block.offerValue("scale_down_behavior", row, cluster.scaleDownBehaviorAsString()); + matched &= block.offerValue("service_role", row, cluster.serviceRole()); + + List applications = cluster.applications().stream() + .map(next -> next.name() + ":" + next.version()).collect(Collectors.toList()); matched &= block.offerComplexValue("applications", row, FieldResolver.DEFAULT, applications); - List tags = cluster.getTags().stream() - .map(next -> next.getKey() + ":" + next.getValue()).collect(Collectors.toList()); + List tags = cluster.tags().stream() + .map(next -> next.key() + ":" + next.value()).collect(Collectors.toList()); matched &= block.offerComplexValue("tags", row, FieldResolver.DEFAULT, tags); } diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/RdsTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/RdsTableProvider.java index f3d9a18a8b..d424476646 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/RdsTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/RdsTableProvider.java @@ -30,22 +30,22 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.rds.AmazonRDS; -import com.amazonaws.services.rds.model.DBInstance; -import com.amazonaws.services.rds.model.DBInstanceStatusInfo; -import com.amazonaws.services.rds.model.DBParameterGroupStatus; -import com.amazonaws.services.rds.model.DBSecurityGroupMembership; -import com.amazonaws.services.rds.model.DBSubnetGroup; -import com.amazonaws.services.rds.model.DescribeDBInstancesRequest; -import com.amazonaws.services.rds.model.DescribeDBInstancesResult; -import com.amazonaws.services.rds.model.DomainMembership; -import com.amazonaws.services.rds.model.Endpoint; -import com.amazonaws.services.rds.model.Subnet; -import com.amazonaws.services.rds.model.Tag; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.rds.model.DBInstance; +import software.amazon.awssdk.services.rds.model.DBInstanceStatusInfo; +import software.amazon.awssdk.services.rds.model.DBParameterGroupStatus; +import software.amazon.awssdk.services.rds.model.DBSecurityGroupMembership; +import software.amazon.awssdk.services.rds.model.DBSubnetGroup; +import software.amazon.awssdk.services.rds.model.DescribeDbInstancesRequest; +import software.amazon.awssdk.services.rds.model.DescribeDbInstancesResponse; +import software.amazon.awssdk.services.rds.model.DomainMembership; +import software.amazon.awssdk.services.rds.model.Endpoint; +import software.amazon.awssdk.services.rds.model.Subnet; +import software.amazon.awssdk.services.rds.model.Tag; import java.util.stream.Collectors; @@ -56,9 +56,9 @@ public class RdsTableProvider implements TableProvider { private static final Schema SCHEMA; - private AmazonRDS rds; + private RdsClient rds; - public RdsTableProvider(AmazonRDS rds) + public RdsTableProvider(RdsClient rds) { this.rds = rds; } @@ -99,27 +99,24 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest @Override public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { - boolean done = false; - DescribeDBInstancesRequest request = new DescribeDBInstancesRequest(); + DescribeDbInstancesRequest.Builder requestBuilder = DescribeDbInstancesRequest.builder(); ValueSet idConstraint = recordsRequest.getConstraints().getSummary().get("instance_id"); if (idConstraint != null && idConstraint.isSingleValue()) { - request.setDBInstanceIdentifier(idConstraint.getSingleValue().toString()); + requestBuilder.dbInstanceIdentifier(idConstraint.getSingleValue().toString()); } - while (!done) { - DescribeDBInstancesResult response = rds.describeDBInstances(request); + DescribeDbInstancesResponse response; + do { + response = rds.describeDBInstances(requestBuilder.build()); - for (DBInstance instance : response.getDBInstances()) { + for (DBInstance instance : response.dbInstances()) { instanceToRow(instance, spiller); } - request.setMarker(response.getMarker()); - - if (response.getMarker() == null || !queryStatusChecker.isQueryRunning()) { - done = true; - } + requestBuilder.marker(response.marker()); } + while (response.marker() != null && queryStatusChecker.isQueryRunning()); } /** @@ -136,145 +133,145 @@ private void instanceToRow(DBInstance 
instance, spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("instance_id", row, instance.getDBInstanceIdentifier()); - matched &= block.offerValue("primary_az", row, instance.getAvailabilityZone()); - matched &= block.offerValue("storage_gb", row, instance.getAllocatedStorage()); - matched &= block.offerValue("is_encrypted", row, instance.getStorageEncrypted()); - matched &= block.offerValue("storage_type", row, instance.getStorageType()); - matched &= block.offerValue("backup_retention_days", row, instance.getBackupRetentionPeriod()); - matched &= block.offerValue("auto_upgrade", row, instance.getAutoMinorVersionUpgrade()); - matched &= block.offerValue("instance_class", row, instance.getDBInstanceClass()); - matched &= block.offerValue("port", row, instance.getDbInstancePort()); - matched &= block.offerValue("status", row, instance.getDBInstanceStatus()); - matched &= block.offerValue("dbi_resource_id", row, instance.getDbiResourceId()); - matched &= block.offerValue("name", row, instance.getDBName()); - matched &= block.offerValue("engine", row, instance.getEngine()); - matched &= block.offerValue("engine_version", row, instance.getEngineVersion()); - matched &= block.offerValue("license_model", row, instance.getLicenseModel()); - matched &= block.offerValue("secondary_az", row, instance.getSecondaryAvailabilityZone()); - matched &= block.offerValue("backup_window", row, instance.getPreferredBackupWindow()); - matched &= block.offerValue("maint_window", row, instance.getPreferredMaintenanceWindow()); - matched &= block.offerValue("read_replica_source_id", row, instance.getReadReplicaSourceDBInstanceIdentifier()); - matched &= block.offerValue("create_time", row, instance.getInstanceCreateTime()); - matched &= block.offerValue("public_access", row, instance.getPubliclyAccessible()); - matched &= block.offerValue("iops", row, instance.getIops()); - matched &= block.offerValue("is_multi_az", row, instance.getMultiAZ()); + matched &= block.offerValue("instance_id", row, instance.dbInstanceIdentifier()); + matched &= block.offerValue("primary_az", row, instance.availabilityZone()); + matched &= block.offerValue("storage_gb", row, instance.allocatedStorage()); + matched &= block.offerValue("is_encrypted", row, instance.storageEncrypted()); + matched &= block.offerValue("storage_type", row, instance.storageType()); + matched &= block.offerValue("backup_retention_days", row, instance.backupRetentionPeriod()); + matched &= block.offerValue("auto_upgrade", row, instance.autoMinorVersionUpgrade()); + matched &= block.offerValue("instance_class", row, instance.dbInstanceClass()); + matched &= block.offerValue("port", row, instance.dbInstancePort()); + matched &= block.offerValue("status", row, instance.dbInstanceStatus()); + matched &= block.offerValue("dbi_resource_id", row, instance.dbiResourceId()); + matched &= block.offerValue("name", row, instance.dbName()); + matched &= block.offerValue("engine", row, instance.engine()); + matched &= block.offerValue("engine_version", row, instance.engineVersion()); + matched &= block.offerValue("license_model", row, instance.licenseModel()); + matched &= block.offerValue("secondary_az", row, instance.secondaryAvailabilityZone()); + matched &= block.offerValue("backup_window", row, instance.preferredBackupWindow()); + matched &= block.offerValue("maint_window", row, instance.preferredMaintenanceWindow()); + matched &= block.offerValue("read_replica_source_id", row, 
instance.readReplicaSourceDBInstanceIdentifier()); + matched &= block.offerValue("create_time", row, instance.instanceCreateTime()); + matched &= block.offerValue("public_access", row, instance.publiclyAccessible()); + matched &= block.offerValue("iops", row, instance.iops()); + matched &= block.offerValue("is_multi_az", row, instance.multiAZ()); matched &= block.offerComplexValue("domains", row, (Field field, Object val) -> { if (field.getName().equals("domain")) { - return ((DomainMembership) val).getDomain(); + return ((DomainMembership) val).domain(); } else if (field.getName().equals("fqdn")) { - return ((DomainMembership) val).getFQDN(); + return ((DomainMembership) val).fqdn(); } else if (field.getName().equals("iam_role")) { - return ((DomainMembership) val).getIAMRoleName(); + return ((DomainMembership) val).iamRoleName(); } else if (field.getName().equals("status")) { - return ((DomainMembership) val).getStatus(); + return ((DomainMembership) val).status(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - instance.getDomainMemberships()); + instance.domainMemberships()); matched &= block.offerComplexValue("param_groups", row, (Field field, Object val) -> { if (field.getName().equals("name")) { - return ((DBParameterGroupStatus) val).getDBParameterGroupName(); + return ((DBParameterGroupStatus) val).dbParameterGroupName(); } else if (field.getName().equals("status")) { - return ((DBParameterGroupStatus) val).getParameterApplyStatus(); + return ((DBParameterGroupStatus) val).parameterApplyStatus(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - instance.getDBParameterGroups()); + instance.dbParameterGroups()); matched &= block.offerComplexValue("db_security_groups", row, (Field field, Object val) -> { if (field.getName().equals("name")) { - return ((DBSecurityGroupMembership) val).getDBSecurityGroupName(); + return ((DBSecurityGroupMembership) val).dbSecurityGroupName(); } else if (field.getName().equals("status")) { - return ((DBSecurityGroupMembership) val).getStatus(); + return ((DBSecurityGroupMembership) val).status(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - instance.getDBSecurityGroups()); + instance.dbSecurityGroups()); matched &= block.offerComplexValue("subnet_group", row, (Field field, Object val) -> { if (field.getName().equals("description")) { - return ((DBSubnetGroup) val).getDBSubnetGroupDescription(); + return ((DBSubnetGroup) val).dbSubnetGroupDescription(); } else if (field.getName().equals("name")) { - return ((DBSubnetGroup) val).getDBSubnetGroupName(); + return ((DBSubnetGroup) val).dbSubnetGroupName(); } else if (field.getName().equals("status")) { - return ((DBSubnetGroup) val).getSubnetGroupStatus(); + return ((DBSubnetGroup) val).subnetGroupStatus(); } else if (field.getName().equals("vpc")) { - return ((DBSubnetGroup) val).getVpcId(); + return ((DBSubnetGroup) val).vpcId(); } else if (field.getName().equals("subnets")) { - return ((DBSubnetGroup) val).getSubnets().stream() - .map(next -> next.getSubnetIdentifier()).collect(Collectors.toList()); + return ((DBSubnetGroup) val).subnets().stream() + .map(next -> next.subnetIdentifier()).collect(Collectors.toList()); } else if (val instanceof Subnet) { - return ((Subnet) val).getSubnetIdentifier(); + return ((Subnet) val).subnetIdentifier(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - instance.getDBSubnetGroup()); + instance.dbSubnetGroup()); matched &= block.offerComplexValue("endpoint", row, 
(Field field, Object val) -> { if (field.getName().equals("address")) { - return ((Endpoint) val).getAddress(); + return ((Endpoint) val).address(); } else if (field.getName().equals("port")) { - return ((Endpoint) val).getPort(); + return ((Endpoint) val).port(); } else if (field.getName().equals("zone")) { - return ((Endpoint) val).getHostedZoneId(); + return ((Endpoint) val).hostedZoneId(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - instance.getEndpoint()); + instance.endpoint()); matched &= block.offerComplexValue("status_infos", row, (Field field, Object val) -> { if (field.getName().equals("message")) { - return ((DBInstanceStatusInfo) val).getMessage(); + return ((DBInstanceStatusInfo) val).message(); } else if (field.getName().equals("is_normal")) { - return ((DBInstanceStatusInfo) val).getNormal(); + return ((DBInstanceStatusInfo) val).normal(); } else if (field.getName().equals("status")) { - return ((DBInstanceStatusInfo) val).getStatus(); + return ((DBInstanceStatusInfo) val).status(); } else if (field.getName().equals("type")) { - return ((DBInstanceStatusInfo) val).getStatusType(); + return ((DBInstanceStatusInfo) val).statusType(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - instance.getStatusInfos()); + instance.statusInfos()); matched &= block.offerComplexValue("tags", row, (Field field, Object val) -> { if (field.getName().equals("key")) { - return ((Tag) val).getKey(); + return ((Tag) val).key(); } else if (field.getName().equals("value")) { - return ((Tag) val).getValue(); + return ((Tag) val).value(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - instance.getTagList()); + instance.tagList()); return matched ? 1 : 0; }); diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/EbsTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/EbsTableProvider.java index 48b6503757..7356a34ea7 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/EbsTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/EbsTableProvider.java @@ -31,14 +31,14 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeVolumesRequest; -import com.amazonaws.services.ec2.model.DescribeVolumesResult; -import com.amazonaws.services.ec2.model.Volume; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeVolumesRequest; +import software.amazon.awssdk.services.ec2.model.DescribeVolumesResponse; +import software.amazon.awssdk.services.ec2.model.Volume; import java.util.Collections; import java.util.List; @@ -52,9 +52,9 @@ public class EbsTableProvider { private static final Logger logger = LoggerFactory.getLogger(EbsTableProvider.class); private static final Schema SCHEMA; - private AmazonEC2 ec2; + private Ec2Client ec2; - public EbsTableProvider(AmazonEC2 ec2) + public EbsTableProvider(Ec2Client ec2) { this.ec2 = ec2; } @@ -96,24 +96,24 @@ public GetTableResponse getTable(BlockAllocator 
blockAllocator, GetTableRequest public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { boolean done = false; - DescribeVolumesRequest request = new DescribeVolumesRequest(); + DescribeVolumesRequest.Builder request = DescribeVolumesRequest.builder(); ValueSet idConstraint = recordsRequest.getConstraints().getSummary().get("id"); if (idConstraint != null && idConstraint.isSingleValue()) { - request.setVolumeIds(Collections.singletonList(idConstraint.getSingleValue().toString())); + request.volumeIds(Collections.singletonList(idConstraint.getSingleValue().toString())); } while (!done) { - DescribeVolumesResult response = ec2.describeVolumes(request); + DescribeVolumesResponse response = ec2.describeVolumes(request.build()); - for (Volume volume : response.getVolumes()) { + for (Volume volume : response.volumes()) { logger.info("readWithConstraint: {}", response); instanceToRow(volume, spiller); } - request.setNextToken(response.getNextToken()); + request.nextToken(response.nextToken()); - if (response.getNextToken() == null || !queryStatusChecker.isQueryRunning()) { + if (response.nextToken() == null || !queryStatusChecker.isQueryRunning()) { done = true; } } @@ -133,26 +133,26 @@ private void instanceToRow(Volume volume, spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("id", row, volume.getVolumeId()); - matched &= block.offerValue("type", row, volume.getVolumeType()); - matched &= block.offerValue("availability_zone", row, volume.getAvailabilityZone()); - matched &= block.offerValue("created_time", row, volume.getCreateTime()); - matched &= block.offerValue("is_encrypted", row, volume.getEncrypted()); - matched &= block.offerValue("kms_key_id", row, volume.getKmsKeyId()); - matched &= block.offerValue("size", row, volume.getSize()); - matched &= block.offerValue("iops", row, volume.getIops()); - matched &= block.offerValue("snapshot_id", row, volume.getSnapshotId()); - matched &= block.offerValue("state", row, volume.getState()); - - if (volume.getAttachments().size() == 1) { - matched &= block.offerValue("target", row, volume.getAttachments().get(0).getInstanceId()); - matched &= block.offerValue("attached_device", row, volume.getAttachments().get(0).getDevice()); - matched &= block.offerValue("attachment_state", row, volume.getAttachments().get(0).getState()); - matched &= block.offerValue("attachment_time", row, volume.getAttachments().get(0).getAttachTime()); + matched &= block.offerValue("id", row, volume.volumeId()); + matched &= block.offerValue("type", row, volume.volumeTypeAsString()); + matched &= block.offerValue("availability_zone", row, volume.availabilityZone()); + matched &= block.offerValue("created_time", row, volume.createTime()); + matched &= block.offerValue("is_encrypted", row, volume.encrypted()); + matched &= block.offerValue("kms_key_id", row, volume.kmsKeyId()); + matched &= block.offerValue("size", row, volume.size()); + matched &= block.offerValue("iops", row, volume.iops()); + matched &= block.offerValue("snapshot_id", row, volume.snapshotId()); + matched &= block.offerValue("state", row, volume.stateAsString()); + + if (volume.attachments().size() == 1) { + matched &= block.offerValue("target", row, volume.attachments().get(0).instanceId()); + matched &= block.offerValue("attached_device", row, volume.attachments().get(0).device()); + matched &= block.offerValue("attachment_state", row, volume.attachments().get(0).stateAsString()); + 
matched &= block.offerValue("attachment_time", row, volume.attachments().get(0).attachTime()); } - List tags = volume.getTags().stream() - .map(next -> next.getKey() + ":" + next.getValue()).collect(Collectors.toList()); + List tags = volume.tags().stream() + .map(next -> next.key() + ":" + next.value()).collect(Collectors.toList()); matched &= block.offerComplexValue("tags", row, FieldResolver.DEFAULT, tags); return matched ? 1 : 0; diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/Ec2TableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/Ec2TableProvider.java index dfa8876284..6bf9dbb58d 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/Ec2TableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/Ec2TableProvider.java @@ -32,19 +32,19 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeInstancesRequest; -import com.amazonaws.services.ec2.model.DescribeInstancesResult; -import com.amazonaws.services.ec2.model.Instance; -import com.amazonaws.services.ec2.model.InstanceNetworkInterface; -import com.amazonaws.services.ec2.model.InstanceState; -import com.amazonaws.services.ec2.model.Reservation; -import com.amazonaws.services.ec2.model.StateReason; -import com.amazonaws.services.ec2.model.Tag; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeInstancesRequest; +import software.amazon.awssdk.services.ec2.model.DescribeInstancesResponse; +import software.amazon.awssdk.services.ec2.model.Instance; +import software.amazon.awssdk.services.ec2.model.InstanceNetworkInterface; +import software.amazon.awssdk.services.ec2.model.InstanceState; +import software.amazon.awssdk.services.ec2.model.Reservation; +import software.amazon.awssdk.services.ec2.model.StateReason; +import software.amazon.awssdk.services.ec2.model.Tag; import java.util.Collections; import java.util.List; @@ -57,9 +57,9 @@ public class Ec2TableProvider implements TableProvider { private static final Schema SCHEMA; - private AmazonEC2 ec2; + private Ec2Client ec2; - public Ec2TableProvider(AmazonEC2 ec2) + public Ec2TableProvider(Ec2Client ec2) { this.ec2 = ec2; } @@ -101,25 +101,25 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { boolean done = false; - DescribeInstancesRequest request = new DescribeInstancesRequest(); + DescribeInstancesRequest.Builder request = DescribeInstancesRequest.builder(); ValueSet idConstraint = recordsRequest.getConstraints().getSummary().get("instance_id"); if (idConstraint != null && idConstraint.isSingleValue()) { - request.setInstanceIds(Collections.singletonList(idConstraint.getSingleValue().toString())); + request.instanceIds(Collections.singletonList(idConstraint.getSingleValue().toString())); } while (!done) { - DescribeInstancesResult response = 
ec2.describeInstances(request); + DescribeInstancesResponse response = ec2.describeInstances(request.build()); - for (Reservation reservation : response.getReservations()) { - for (Instance instance : reservation.getInstances()) { + for (Reservation reservation : response.reservations()) { + for (Instance instance : reservation.instances()) { instanceToRow(instance, spiller); } } - request.setNextToken(response.getNextToken()); + request.nextToken(response.nextToken()); - if (response.getNextToken() == null || !queryStatusChecker.isQueryRunning()) { + if (response.nextToken() == null || !queryStatusChecker.isQueryRunning()) { done = true; } } @@ -139,106 +139,106 @@ private void instanceToRow(Instance instance, spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("instance_id", row, instance.getInstanceId()); - matched &= block.offerValue("image_id", row, instance.getImageId()); - matched &= block.offerValue("instance_type", row, instance.getInstanceType()); - matched &= block.offerValue("platform", row, instance.getPlatform()); - matched &= block.offerValue("private_dns_name", row, instance.getPrivateDnsName()); - matched &= block.offerValue("private_ip_address", row, instance.getPrivateIpAddress()); - matched &= block.offerValue("public_dns_name", row, instance.getPublicDnsName()); - matched &= block.offerValue("public_ip_address", row, instance.getPublicIpAddress()); - matched &= block.offerValue("subnet_id", row, instance.getSubnetId()); - matched &= block.offerValue("vpc_id", row, instance.getVpcId()); - matched &= block.offerValue("architecture", row, instance.getArchitecture()); - matched &= block.offerValue("instance_lifecycle", row, instance.getInstanceLifecycle()); - matched &= block.offerValue("root_device_name", row, instance.getRootDeviceName()); - matched &= block.offerValue("root_device_type", row, instance.getRootDeviceType()); - matched &= block.offerValue("spot_instance_request_id", row, instance.getSpotInstanceRequestId()); - matched &= block.offerValue("virtualization_type", row, instance.getVirtualizationType()); - matched &= block.offerValue("key_name", row, instance.getKeyName()); - matched &= block.offerValue("kernel_id", row, instance.getKernelId()); - matched &= block.offerValue("capacity_reservation_id", row, instance.getCapacityReservationId()); - matched &= block.offerValue("launch_time", row, instance.getLaunchTime()); + matched &= block.offerValue("instance_id", row, instance.instanceId()); + matched &= block.offerValue("image_id", row, instance.imageId()); + matched &= block.offerValue("instance_type", row, instance.instanceTypeAsString()); + matched &= block.offerValue("platform", row, instance.platformAsString()); + matched &= block.offerValue("private_dns_name", row, instance.privateDnsName()); + matched &= block.offerValue("private_ip_address", row, instance.privateIpAddress()); + matched &= block.offerValue("public_dns_name", row, instance.publicDnsName()); + matched &= block.offerValue("public_ip_address", row, instance.publicIpAddress()); + matched &= block.offerValue("subnet_id", row, instance.subnetId()); + matched &= block.offerValue("vpc_id", row, instance.vpcId()); + matched &= block.offerValue("architecture", row, instance.architectureAsString()); + matched &= block.offerValue("instance_lifecycle", row, instance.instanceLifecycleAsString()); + matched &= block.offerValue("root_device_name", row, instance.rootDeviceName()); + matched &= block.offerValue("root_device_type", row, 
instance.rootDeviceTypeAsString()); + matched &= block.offerValue("spot_instance_request_id", row, instance.spotInstanceRequestId()); + matched &= block.offerValue("virtualization_type", row, instance.virtualizationTypeAsString()); + matched &= block.offerValue("key_name", row, instance.keyName()); + matched &= block.offerValue("kernel_id", row, instance.kernelId()); + matched &= block.offerValue("capacity_reservation_id", row, instance.capacityReservationId()); + matched &= block.offerValue("launch_time", row, instance.launchTime()); matched &= block.offerComplexValue("state", row, (Field field, Object val) -> { if (field.getName().equals("name")) { - return ((InstanceState) val).getName(); + return ((InstanceState) val).nameAsString(); } else if (field.getName().equals("code")) { - return ((InstanceState) val).getCode(); + return ((InstanceState) val).code(); } throw new RuntimeException("Unknown field " + field.getName()); - }, instance.getState()); + }, instance.state()); matched &= block.offerComplexValue("network_interfaces", row, (Field field, Object val) -> { if (field.getName().equals("status")) { - return ((InstanceNetworkInterface) val).getStatus(); + return ((InstanceNetworkInterface) val).statusAsString(); } else if (field.getName().equals("subnet")) { - return ((InstanceNetworkInterface) val).getSubnetId(); + return ((InstanceNetworkInterface) val).subnetId(); } else if (field.getName().equals("vpc")) { - return ((InstanceNetworkInterface) val).getVpcId(); + return ((InstanceNetworkInterface) val).vpcId(); } else if (field.getName().equals("mac")) { - return ((InstanceNetworkInterface) val).getMacAddress(); + return ((InstanceNetworkInterface) val).macAddress(); } else if (field.getName().equals("private_dns")) { - return ((InstanceNetworkInterface) val).getPrivateDnsName(); + return ((InstanceNetworkInterface) val).privateDnsName(); } else if (field.getName().equals("private_ip")) { - return ((InstanceNetworkInterface) val).getPrivateIpAddress(); + return ((InstanceNetworkInterface) val).privateIpAddress(); } else if (field.getName().equals("security_groups")) { - return ((InstanceNetworkInterface) val).getGroups().stream().map(next -> next.getGroupName() + ":" + next.getGroupId()).collect(Collectors.toList()); + return ((InstanceNetworkInterface) val).groups().stream().map(next -> next.groupName() + ":" + next.groupId()).collect(Collectors.toList()); } else if (field.getName().equals("interface_id")) { - return ((InstanceNetworkInterface) val).getNetworkInterfaceId(); + return ((InstanceNetworkInterface) val).networkInterfaceId(); } throw new RuntimeException("Unknown field " + field.getName()); - }, instance.getNetworkInterfaces()); + }, instance.networkInterfaces()); matched &= block.offerComplexValue("state_reason", row, (Field field, Object val) -> { if (field.getName().equals("message")) { - return ((StateReason) val).getMessage(); + return ((StateReason) val).message(); } else if (field.getName().equals("code")) { - return ((StateReason) val).getCode(); + return ((StateReason) val).code(); } throw new RuntimeException("Unknown field " + field.getName()); - }, instance.getStateReason()); + }, instance.stateReason()); - matched &= block.offerValue("ebs_optimized", row, instance.getEbsOptimized()); + matched &= block.offerValue("ebs_optimized", row, instance.ebsOptimized()); - List securityGroups = instance.getSecurityGroups().stream() - .map(next -> next.getGroupId()).collect(Collectors.toList()); + List securityGroups = instance.securityGroups().stream() + .map(next -> 
next.groupId()).collect(Collectors.toList()); matched &= block.offerComplexValue("security_groups", row, FieldResolver.DEFAULT, securityGroups); - List securityGroupNames = instance.getSecurityGroups().stream() - .map(next -> next.getGroupName()).collect(Collectors.toList()); + List securityGroupNames = instance.securityGroups().stream() + .map(next -> next.groupName()).collect(Collectors.toList()); matched &= block.offerComplexValue("security_group_names", row, FieldResolver.DEFAULT, securityGroupNames); - List ebsVolumes = instance.getBlockDeviceMappings().stream() - .map(next -> next.getEbs().getVolumeId()).collect(Collectors.toList()); + List ebsVolumes = instance.blockDeviceMappings().stream() + .map(next -> next.ebs().volumeId()).collect(Collectors.toList()); matched &= block.offerComplexValue("ebs_volumes", row, FieldResolver.DEFAULT, ebsVolumes); matched &= block.offerComplexValue("tags", row, (Field field, Object val) -> { if (field.getName().equals("key")) { - return ((Tag) val).getKey(); + return ((Tag) val).key(); } else if (field.getName().equals("value")) { - return ((Tag) val).getValue(); + return ((Tag) val).value(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - instance.getTags()); + instance.tags()); return matched ? 1 : 0; }); diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/ImagesTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/ImagesTableProvider.java index 3858946948..a80ad779bf 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/ImagesTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/ImagesTableProvider.java @@ -31,17 +31,17 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.BlockDeviceMapping; -import com.amazonaws.services.ec2.model.DescribeImagesRequest; -import com.amazonaws.services.ec2.model.DescribeImagesResult; -import com.amazonaws.services.ec2.model.EbsBlockDevice; -import com.amazonaws.services.ec2.model.Image; -import com.amazonaws.services.ec2.model.Tag; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.BlockDeviceMapping; +import software.amazon.awssdk.services.ec2.model.DescribeImagesRequest; +import software.amazon.awssdk.services.ec2.model.DescribeImagesResponse; +import software.amazon.awssdk.services.ec2.model.EbsBlockDevice; +import software.amazon.awssdk.services.ec2.model.Image; +import software.amazon.awssdk.services.ec2.model.Tag; import java.util.Collections; import java.util.List; @@ -58,9 +58,9 @@ public class ImagesTableProvider //query for a specific owner. 
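A recurring detail in the Ec2TableProvider mappings above, and in the image and route mappings that follow: wherever the v1 getter returned a plain String for an enum-backed field (state, architecture, platform, and so on), the v2 model exposes both a typed enum and an xxxAsString() accessor, and this migration consistently picks the AsString form. A sketch of the distinction, using VolumeState as a representative v2 enum:

    import software.amazon.awssdk.services.ec2.model.Volume;

    static String volumeState(Volume volume)
    {
        // volume.state() returns the typed VolumeState enum, which collapses
        // any value this SDK build does not recognize to UNKNOWN_TO_SDK_VERSION;
        // stateAsString() preserves the literal service value, matching the
        // behavior of the old v1 string getters.
        return volume.stateAsString();
    }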
private final String defaultOwner; private static final Schema SCHEMA; - private AmazonEC2 ec2; + private Ec2Client ec2; - public ImagesTableProvider(AmazonEC2 ec2, java.util.Map configOptions) + public ImagesTableProvider(Ec2Client ec2, java.util.Map configOptions) { this.ec2 = ec2; this.defaultOwner = configOptions.get(DEFAULT_OWNER_ENV); @@ -104,28 +104,28 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest @Override public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { - DescribeImagesRequest request = new DescribeImagesRequest(); + DescribeImagesRequest.Builder request = DescribeImagesRequest.builder(); ValueSet idConstraint = recordsRequest.getConstraints().getSummary().get("id"); ValueSet ownerConstraint = recordsRequest.getConstraints().getSummary().get("owner"); if (idConstraint != null && idConstraint.isSingleValue()) { - request.setImageIds(Collections.singletonList(idConstraint.getSingleValue().toString())); + request.imageIds(Collections.singletonList(idConstraint.getSingleValue().toString())); } else if (ownerConstraint != null && ownerConstraint.isSingleValue()) { - request.setOwners(Collections.singletonList(ownerConstraint.getSingleValue().toString())); + request.owners(Collections.singletonList(ownerConstraint.getSingleValue().toString())); } else if (defaultOwner != null) { - request.setOwners(Collections.singletonList(defaultOwner)); + request.owners(Collections.singletonList(defaultOwner)); } else { throw new RuntimeException("A default owner account must be set or the query must have owner" + "in the where clause with exactly 1 value otherwise results may be too big."); } - DescribeImagesResult response = ec2.describeImages(request); + DescribeImagesResponse response = ec2.describeImages(request.build()); int count = 0; - for (Image next : response.getImages()) { + for (Image next : response.images()) { if (count++ > MAX_IMAGES) { throw new RuntimeException("Too many images returned, add an owner or id filter."); } @@ -147,34 +147,34 @@ private void instanceToRow(Image image, spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("id", row, image.getImageId()); - matched &= block.offerValue("architecture", row, image.getArchitecture()); - matched &= block.offerValue("created", row, image.getCreationDate()); - matched &= block.offerValue("description", row, image.getDescription()); - matched &= block.offerValue("hypervisor", row, image.getHypervisor()); - matched &= block.offerValue("location", row, image.getImageLocation()); - matched &= block.offerValue("type", row, image.getImageType()); - matched &= block.offerValue("kernel", row, image.getKernelId()); - matched &= block.offerValue("name", row, image.getName()); - matched &= block.offerValue("owner", row, image.getOwnerId()); - matched &= block.offerValue("platform", row, image.getPlatform()); - matched &= block.offerValue("ramdisk", row, image.getRamdiskId()); - matched &= block.offerValue("root_device", row, image.getRootDeviceName()); - matched &= block.offerValue("root_type", row, image.getRootDeviceType()); - matched &= block.offerValue("srvio_net", row, image.getSriovNetSupport()); - matched &= block.offerValue("state", row, image.getState()); - matched &= block.offerValue("virt_type", row, image.getVirtualizationType()); - matched &= block.offerValue("is_public", row, image.getPublic()); + matched &= block.offerValue("id", row, image.imageId()); + matched &= 
block.offerValue("architecture", row, image.architectureAsString()); + matched &= block.offerValue("created", row, image.creationDate()); + matched &= block.offerValue("description", row, image.description()); + matched &= block.offerValue("hypervisor", row, image.hypervisorAsString()); + matched &= block.offerValue("location", row, image.imageLocation()); + matched &= block.offerValue("type", row, image.imageTypeAsString()); + matched &= block.offerValue("kernel", row, image.kernelId()); + matched &= block.offerValue("name", row, image.name()); + matched &= block.offerValue("owner", row, image.ownerId()); + matched &= block.offerValue("platform", row, image.platformAsString()); + matched &= block.offerValue("ramdisk", row, image.ramdiskId()); + matched &= block.offerValue("root_device", row, image.rootDeviceName()); + matched &= block.offerValue("root_type", row, image.rootDeviceTypeAsString()); + matched &= block.offerValue("srvio_net", row, image.sriovNetSupport()); + matched &= block.offerValue("state", row, image.stateAsString()); + matched &= block.offerValue("virt_type", row, image.virtualizationTypeAsString()); + matched &= block.offerValue("is_public", row, image.publicLaunchPermissions()); - List tags = image.getTags(); + List tags = image.tags(); matched &= block.offerComplexValue("tags", row, (Field field, Object val) -> { if (field.getName().equals("key")) { - return ((Tag) val).getKey(); + return ((Tag) val).key(); } else if (field.getName().equals("value")) { - return ((Tag) val).getValue(); + return ((Tag) val).value(); } throw new RuntimeException("Unexpected field " + field.getName()); @@ -185,33 +185,33 @@ else if (field.getName().equals("value")) { row, (Field field, Object val) -> { if (field.getName().equals("dev_name")) { - return ((BlockDeviceMapping) val).getDeviceName(); + return ((BlockDeviceMapping) val).deviceName(); } else if (field.getName().equals("no_device")) { - return ((BlockDeviceMapping) val).getNoDevice(); + return ((BlockDeviceMapping) val).noDevice(); } else if (field.getName().equals("virt_name")) { - return ((BlockDeviceMapping) val).getVirtualName(); + return ((BlockDeviceMapping) val).virtualName(); } else if (field.getName().equals("ebs")) { - return ((BlockDeviceMapping) val).getEbs(); + return ((BlockDeviceMapping) val).ebs(); } else if (field.getName().equals("ebs_size")) { - return ((EbsBlockDevice) val).getVolumeSize(); + return ((EbsBlockDevice) val).volumeSize(); } else if (field.getName().equals("ebs_iops")) { - return ((EbsBlockDevice) val).getIops(); + return ((EbsBlockDevice) val).iops(); } else if (field.getName().equals("ebs_type")) { - return ((EbsBlockDevice) val).getVolumeType(); + return ((EbsBlockDevice) val).volumeTypeAsString(); } else if (field.getName().equals("ebs_kms_key")) { - return ((EbsBlockDevice) val).getKmsKeyId(); + return ((EbsBlockDevice) val).kmsKeyId(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - image.getBlockDeviceMappings()); + image.blockDeviceMappings()); return matched ? 
1 : 0; }); diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/RouteTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/RouteTableProvider.java index 24583be45e..7c71183464 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/RouteTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/RouteTableProvider.java @@ -31,13 +31,13 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeRouteTablesRequest; -import com.amazonaws.services.ec2.model.DescribeRouteTablesResult; -import com.amazonaws.services.ec2.model.Route; -import com.amazonaws.services.ec2.model.RouteTable; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeRouteTablesRequest; +import software.amazon.awssdk.services.ec2.model.DescribeRouteTablesResponse; +import software.amazon.awssdk.services.ec2.model.Route; +import software.amazon.awssdk.services.ec2.model.RouteTable; import java.util.Collections; import java.util.List; @@ -50,9 +50,9 @@ public class RouteTableProvider implements TableProvider { private static final Schema SCHEMA; - private AmazonEC2 ec2; + private Ec2Client ec2; - public RouteTableProvider(AmazonEC2 ec2) + public RouteTableProvider(Ec2Client ec2) { this.ec2 = ec2; } @@ -94,25 +94,25 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { boolean done = false; - DescribeRouteTablesRequest request = new DescribeRouteTablesRequest(); + DescribeRouteTablesRequest.Builder request = DescribeRouteTablesRequest.builder(); ValueSet idConstraint = recordsRequest.getConstraints().getSummary().get("route_table_id"); if (idConstraint != null && idConstraint.isSingleValue()) { - request.setRouteTableIds(Collections.singletonList(idConstraint.getSingleValue().toString())); + request.routeTableIds(Collections.singletonList(idConstraint.getSingleValue().toString())); } while (!done) { - DescribeRouteTablesResult response = ec2.describeRouteTables(request); + DescribeRouteTablesResponse response = ec2.describeRouteTables(request.build()); - for (RouteTable nextRouteTable : response.getRouteTables()) { - for (Route route : nextRouteTable.getRoutes()) { + for (RouteTable nextRouteTable : response.routeTables()) { + for (Route route : nextRouteTable.routes()) { instanceToRow(nextRouteTable, route, spiller); } } - request.setNextToken(response.getNextToken()); + request.nextToken(response.nextToken()); - if (response.getNextToken() == null || !queryStatusChecker.isQueryRunning()) { + if (response.nextToken() == null || !queryStatusChecker.isQueryRunning()) { done = true; } } @@ -134,33 +134,33 @@ private void instanceToRow(RouteTable routeTable, spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("route_table_id", row, routeTable.getRouteTableId()); - matched &= block.offerValue("owner", row, routeTable.getOwnerId()); - matched &= 
block.offerValue("vpc", row, routeTable.getVpcId()); - matched &= block.offerValue("dst_cidr", row, route.getDestinationCidrBlock()); - matched &= block.offerValue("dst_cidr_v6", row, route.getDestinationIpv6CidrBlock()); - matched &= block.offerValue("dst_prefix_list", row, route.getDestinationPrefixListId()); - matched &= block.offerValue("egress_igw", row, route.getEgressOnlyInternetGatewayId()); - matched &= block.offerValue("gateway", row, route.getGatewayId()); - matched &= block.offerValue("instance_id", row, route.getInstanceId()); - matched &= block.offerValue("instance_owner", row, route.getInstanceOwnerId()); - matched &= block.offerValue("nat_gateway", row, route.getNatGatewayId()); - matched &= block.offerValue("interface", row, route.getNetworkInterfaceId()); - matched &= block.offerValue("origin", row, route.getOrigin()); - matched &= block.offerValue("state", row, route.getState()); - matched &= block.offerValue("transit_gateway", row, route.getTransitGatewayId()); - matched &= block.offerValue("vpc_peering_con", row, route.getVpcPeeringConnectionId()); - - List associations = routeTable.getAssociations().stream() - .map(next -> next.getSubnetId() + ":" + next.getRouteTableId()).collect(Collectors.toList()); + matched &= block.offerValue("route_table_id", row, routeTable.routeTableId()); + matched &= block.offerValue("owner", row, routeTable.ownerId()); + matched &= block.offerValue("vpc", row, routeTable.vpcId()); + matched &= block.offerValue("dst_cidr", row, route.destinationCidrBlock()); + matched &= block.offerValue("dst_cidr_v6", row, route.destinationIpv6CidrBlock()); + matched &= block.offerValue("dst_prefix_list", row, route.destinationPrefixListId()); + matched &= block.offerValue("egress_igw", row, route.egressOnlyInternetGatewayId()); + matched &= block.offerValue("gateway", row, route.gatewayId()); + matched &= block.offerValue("instance_id", row, route.instanceId()); + matched &= block.offerValue("instance_owner", row, route.instanceOwnerId()); + matched &= block.offerValue("nat_gateway", row, route.natGatewayId()); + matched &= block.offerValue("interface", row, route.networkInterfaceId()); + matched &= block.offerValue("origin", row, route.originAsString()); + matched &= block.offerValue("state", row, route.stateAsString()); + matched &= block.offerValue("transit_gateway", row, route.transitGatewayId()); + matched &= block.offerValue("vpc_peering_con", row, route.vpcPeeringConnectionId()); + + List associations = routeTable.associations().stream() + .map(next -> next.subnetId() + ":" + next.routeTableId()).collect(Collectors.toList()); matched &= block.offerComplexValue("associations", row, FieldResolver.DEFAULT, associations); - List tags = routeTable.getTags().stream() - .map(next -> next.getKey() + ":" + next.getValue()).collect(Collectors.toList()); + List tags = routeTable.tags().stream() + .map(next -> next.key() + ":" + next.value()).collect(Collectors.toList()); matched &= block.offerComplexValue("tags", row, FieldResolver.DEFAULT, tags); - List propagatingVgws = routeTable.getPropagatingVgws().stream() - .map(next -> next.getGatewayId()).collect(Collectors.toList()); + List propagatingVgws = routeTable.propagatingVgws().stream() + .map(next -> next.gatewayId()).collect(Collectors.toList()); matched &= block.offerComplexValue("propagating_vgws", row, FieldResolver.DEFAULT, propagatingVgws); return matched ? 
1 : 0; diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SecurityGroupsTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SecurityGroupsTableProvider.java index 8f4f6dd3c3..94afbdf687 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SecurityGroupsTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SecurityGroupsTableProvider.java @@ -31,13 +31,13 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeSecurityGroupsRequest; -import com.amazonaws.services.ec2.model.DescribeSecurityGroupsResult; -import com.amazonaws.services.ec2.model.IpPermission; -import com.amazonaws.services.ec2.model.SecurityGroup; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeSecurityGroupsRequest; +import software.amazon.awssdk.services.ec2.model.DescribeSecurityGroupsResponse; +import software.amazon.awssdk.services.ec2.model.IpPermission; +import software.amazon.awssdk.services.ec2.model.SecurityGroup; import java.util.Collections; import java.util.List; @@ -53,9 +53,9 @@ public class SecurityGroupsTableProvider private static final String EGRESS = "egress"; private static final Schema SCHEMA; - private AmazonEC2 ec2; + private Ec2Client ec2; - public SecurityGroupsTableProvider(AmazonEC2 ec2) + public SecurityGroupsTableProvider(Ec2Client ec2) { this.ec2 = ec2; } @@ -97,34 +97,34 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { boolean done = false; - DescribeSecurityGroupsRequest request = new DescribeSecurityGroupsRequest(); + DescribeSecurityGroupsRequest.Builder request = DescribeSecurityGroupsRequest.builder(); ValueSet idConstraint = recordsRequest.getConstraints().getSummary().get("id"); if (idConstraint != null && idConstraint.isSingleValue()) { - request.setGroupIds(Collections.singletonList(idConstraint.getSingleValue().toString())); + request.groupIds(Collections.singletonList(idConstraint.getSingleValue().toString())); } ValueSet nameConstraint = recordsRequest.getConstraints().getSummary().get("name"); if (nameConstraint != null && nameConstraint.isSingleValue()) { - request.setGroupNames(Collections.singletonList(nameConstraint.getSingleValue().toString())); + request.groupNames(Collections.singletonList(nameConstraint.getSingleValue().toString())); } while (!done) { - DescribeSecurityGroupsResult response = ec2.describeSecurityGroups(request); + DescribeSecurityGroupsResponse response = ec2.describeSecurityGroups(request.build()); //Each rule is mapped to a row in the response. SGs have INGRESS and EGRESS rules. 
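The constraint handling in readWithConstraint above shows the shape shared by all of these providers: a single-value predicate on a key column is copied into the Describe request, and anything else falls back to a full scan that the engine filters afterward. Sketched in isolation (a minimal excerpt; the column name and request type vary per provider, and recordsRequest is the surrounding ReadRecordsRequest):

    // Only exact single-value predicates are pushed down; ranges and
    // multi-value sets are not expressible on this request and are left
    // for Athena to apply after the scan.
    DescribeSecurityGroupsRequest.Builder request = DescribeSecurityGroupsRequest.builder();
    ValueSet idConstraint = recordsRequest.getConstraints().getSummary().get("id");
    if (idConstraint != null && idConstraint.isSingleValue()) {
        request.groupIds(Collections.singletonList(idConstraint.getSingleValue().toString()));
    }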
- for (SecurityGroup next : response.getSecurityGroups()) { - for (IpPermission nextPerm : next.getIpPermissions()) { + for (SecurityGroup next : response.securityGroups()) { + for (IpPermission nextPerm : next.ipPermissions()) { instanceToRow(next, nextPerm, INGRESS, spiller); } - for (IpPermission nextPerm : next.getIpPermissionsEgress()) { + for (IpPermission nextPerm : next.ipPermissionsEgress()) { instanceToRow(next, nextPerm, EGRESS, spiller); } } - request.setNextToken(response.getNextToken()); - if (response.getNextToken() == null || !queryStatusChecker.isQueryRunning()) { + request.nextToken(response.nextToken()); + if (response.nextToken() == null || !queryStatusChecker.isQueryRunning()) { done = true; } } @@ -148,28 +148,28 @@ private void instanceToRow(SecurityGroup securityGroup, spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("id", row, securityGroup.getGroupId()); - matched &= block.offerValue("name", row, securityGroup.getGroupName()); - matched &= block.offerValue("description", row, securityGroup.getDescription()); - matched &= block.offerValue("from_port", row, permission.getFromPort()); - matched &= block.offerValue("to_port", row, permission.getFromPort()); - matched &= block.offerValue("protocol", row, permission.getIpProtocol()); - matched &= block.offerValue("direction", row, permission.getIpProtocol()); + matched &= block.offerValue("id", row, securityGroup.groupId()); + matched &= block.offerValue("name", row, securityGroup.groupName()); + matched &= block.offerValue("description", row, securityGroup.description()); + matched &= block.offerValue("from_port", row, permission.fromPort()); + matched &= block.offerValue("to_port", row, permission.toPort()); + matched &= block.offerValue("protocol", row, permission.ipProtocol()); + matched &= block.offerValue("direction", row, direction); - List ipv4Ranges = permission.getIpv4Ranges().stream() - .map(next -> next.getCidrIp() + ":" + next.getDescription()).collect(Collectors.toList()); + List ipv4Ranges = permission.ipRanges().stream() + .map(next -> next.cidrIp() + ":" + next.description()).collect(Collectors.toList()); matched &= block.offerComplexValue("ipv4_ranges", row, FieldResolver.DEFAULT, ipv4Ranges); - List ipv6Ranges = permission.getIpv6Ranges().stream() - .map(next -> next.getCidrIpv6() + ":" + next.getDescription()).collect(Collectors.toList()); + List ipv6Ranges = permission.ipv6Ranges().stream() + .map(next -> next.cidrIpv6() + ":" + next.description()).collect(Collectors.toList()); matched &= block.offerComplexValue("ipv6_ranges", row, FieldResolver.DEFAULT, ipv6Ranges); - List prefixLists = permission.getPrefixListIds().stream() - .map(next -> next.getPrefixListId() + ":" + next.getDescription()).collect(Collectors.toList()); + List prefixLists = permission.prefixListIds().stream() + .map(next -> next.prefixListId() + ":" + next.description()).collect(Collectors.toList()); matched &= block.offerComplexValue("prefix_lists", row, FieldResolver.DEFAULT, prefixLists); - List userIdGroups = permission.getUserIdGroupPairs().stream() - .map(next -> next.getUserId() + ":" + next.getGroupId()) + List userIdGroups = permission.userIdGroupPairs().stream() + .map(next -> next.userId() + ":" + next.groupId()) .collect(Collectors.toList()); matched &= block.offerComplexValue("user_id_groups", row, FieldResolver.DEFAULT, userIdGroups); diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SubnetTableProvider.java 
b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SubnetTableProvider.java index f64bb9bd26..444fd39510 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SubnetTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SubnetTableProvider.java @@ -31,12 +31,12 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeSubnetsRequest; -import com.amazonaws.services.ec2.model.DescribeSubnetsResult; -import com.amazonaws.services.ec2.model.Subnet; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeSubnetsRequest; +import software.amazon.awssdk.services.ec2.model.DescribeSubnetsResponse; +import software.amazon.awssdk.services.ec2.model.Subnet; import java.util.Collections; import java.util.List; @@ -49,9 +49,9 @@ public class SubnetTableProvider implements TableProvider { private static final Schema SCHEMA; - private AmazonEC2 ec2; + private Ec2Client ec2; - public SubnetTableProvider(AmazonEC2 ec2) + public SubnetTableProvider(Ec2Client ec2) { this.ec2 = ec2; } @@ -92,15 +92,15 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest @Override public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { - DescribeSubnetsRequest request = new DescribeSubnetsRequest(); + DescribeSubnetsRequest.Builder request = DescribeSubnetsRequest.builder(); ValueSet idConstraint = recordsRequest.getConstraints().getSummary().get("id"); if (idConstraint != null && idConstraint.isSingleValue()) { - request.setSubnetIds(Collections.singletonList(idConstraint.getSingleValue().toString())); + request.subnetIds(Collections.singletonList(idConstraint.getSingleValue().toString())); } - DescribeSubnetsResult response = ec2.describeSubnets(request); - for (Subnet subnet : response.getSubnets()) { + DescribeSubnetsResponse response = ec2.describeSubnets(request.build()); + for (Subnet subnet : response.subnets()) { instanceToRow(subnet, spiller); } } @@ -119,19 +119,18 @@ private void instanceToRow(Subnet subnet, spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("id", row, subnet.getSubnetId()); - matched &= block.offerValue("availability_zone", row, subnet.getAvailabilityZone()); - matched &= block.offerValue("available_ip_count", row, subnet.getAvailableIpAddressCount()); - matched &= block.offerValue("cidr_block", row, subnet.getCidrBlock()); - matched &= block.offerValue("default_for_az", row, subnet.getDefaultForAz()); - matched &= block.offerValue("map_public_ip", row, subnet.getMapPublicIpOnLaunch()); - matched &= block.offerValue("owner", row, subnet.getOwnerId()); - matched &= block.offerValue("state", row, subnet.getState()); - matched &= block.offerValue("vpc", row, subnet.getVpcId()); - matched &= block.offerValue("vpc", row, subnet.getVpcId()); + matched &= block.offerValue("id", row, subnet.subnetId()); + matched &= block.offerValue("availability_zone", row, subnet.availabilityZone()); + matched &= 
block.offerValue("available_ip_count", row, subnet.availableIpAddressCount()); + matched &= block.offerValue("cidr_block", row, subnet.cidrBlock()); + matched &= block.offerValue("default_for_az", row, subnet.defaultForAz()); + matched &= block.offerValue("map_public_ip", row, subnet.mapPublicIpOnLaunch()); + matched &= block.offerValue("owner", row, subnet.ownerId()); + matched &= block.offerValue("state", row, subnet.stateAsString()); + matched &= block.offerValue("vpc", row, subnet.vpcId()); - List tags = subnet.getTags().stream() - .map(next -> next.getKey() + ":" + next.getValue()).collect(Collectors.toList()); + List tags = subnet.tags().stream() + .map(next -> next.key() + ":" + next.value()).collect(Collectors.toList()); matched &= block.offerComplexValue("tags", row, FieldResolver.DEFAULT, tags); return matched ? 1 : 0; diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/VpcTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/VpcTableProvider.java index 18087ba5e5..44adc6a846 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/VpcTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/VpcTableProvider.java @@ -31,12 +31,12 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeVpcsRequest; -import com.amazonaws.services.ec2.model.DescribeVpcsResult; -import com.amazonaws.services.ec2.model.Vpc; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeVpcsRequest; +import software.amazon.awssdk.services.ec2.model.DescribeVpcsResponse; +import software.amazon.awssdk.services.ec2.model.Vpc; import java.util.Collections; import java.util.List; @@ -49,9 +49,9 @@ public class VpcTableProvider implements TableProvider { private static final Schema SCHEMA; - private AmazonEC2 ec2; + private Ec2Client ec2; - public VpcTableProvider(AmazonEC2 ec2) + public VpcTableProvider(Ec2Client ec2) { this.ec2 = ec2; } @@ -92,15 +92,15 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest @Override public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { - DescribeVpcsRequest request = new DescribeVpcsRequest(); + DescribeVpcsRequest.Builder request = DescribeVpcsRequest.builder(); ValueSet idConstraint = recordsRequest.getConstraints().getSummary().get("id"); if (idConstraint != null && idConstraint.isSingleValue()) { - request.setVpcIds(Collections.singletonList(idConstraint.getSingleValue().toString())); + request.vpcIds(Collections.singletonList(idConstraint.getSingleValue().toString())); } - DescribeVpcsResult response = ec2.describeVpcs(request); - for (Vpc vpc : response.getVpcs()) { + DescribeVpcsResponse response = ec2.describeVpcs(request.build()); + for (Vpc vpc : response.vpcs()) { instanceToRow(vpc, spiller); } } @@ -119,16 +119,16 @@ private void instanceToRow(Vpc vpc, spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("id", row, vpc.getVpcId()); - matched 
&= block.offerValue("cidr_block", row, vpc.getCidrBlock()); - matched &= block.offerValue("dhcp_opts", row, vpc.getDhcpOptionsId()); - matched &= block.offerValue("tenancy", row, vpc.getInstanceTenancy()); - matched &= block.offerValue("owner", row, vpc.getOwnerId()); - matched &= block.offerValue("state", row, vpc.getState()); - matched &= block.offerValue("is_default", row, vpc.getIsDefault()); + matched &= block.offerValue("id", row, vpc.vpcId()); + matched &= block.offerValue("cidr_block", row, vpc.cidrBlock()); + matched &= block.offerValue("dhcp_opts", row, vpc.dhcpOptionsId()); + matched &= block.offerValue("tenancy", row, vpc.instanceTenancyAsString()); + matched &= block.offerValue("owner", row, vpc.ownerId()); + matched &= block.offerValue("state", row, vpc.stateAsString()); + matched &= block.offerValue("is_default", row, vpc.isDefault()); - List tags = vpc.getTags().stream() - .map(next -> next.getKey() + ":" + next.getValue()).collect(Collectors.toList()); + List tags = vpc.tags().stream() + .map(next -> next.key() + ":" + next.value()).collect(Collectors.toList()); matched &= block.offerComplexValue("tags", row, FieldResolver.DEFAULT, tags); return matched ? 1 : 0; diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProvider.java index 0387ac6bf7..7ff28b61e5 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProvider.java @@ -29,10 +29,12 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.Bucket; -import com.amazonaws.services.s3.model.Owner; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.Bucket; +import software.amazon.awssdk.services.s3.model.GetBucketAclRequest; +import software.amazon.awssdk.services.s3.model.GetBucketAclResponse; +import software.amazon.awssdk.services.s3.model.Owner; /** * Maps your S3 Objects to a table. 
@@ -41,9 +43,9 @@ public class S3BucketsTableProvider implements TableProvider { private static final Schema SCHEMA; - private AmazonS3 amazonS3; + private S3Client amazonS3; - public S3BucketsTableProvider(AmazonS3 amazonS3) + public S3BucketsTableProvider(S3Client amazonS3) { this.amazonS3 = amazonS3; } @@ -84,7 +86,7 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest @Override public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { - for (Bucket next : amazonS3.listBuckets()) { + for (Bucket next : amazonS3.listBuckets().buckets()) { toRow(next, spiller); } } @@ -102,13 +104,15 @@ private void toRow(Bucket bucket, { spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("bucket_name", row, bucket.getName()); - matched &= block.offerValue("create_date", row, bucket.getCreationDate()); + matched &= block.offerValue("bucket_name", row, bucket.name()); + matched &= block.offerValue("create_date", row, bucket.creationDate()); - Owner owner = bucket.getOwner(); + GetBucketAclResponse response = amazonS3.getBucketAcl(GetBucketAclRequest.builder().bucket(bucket.name()).build()); + + Owner owner = response.owner(); if (owner != null) { - matched &= block.offerValue("owner_name", row, bucket.getOwner().getDisplayName()); - matched &= block.offerValue("owner_id", row, bucket.getOwner().getId()); + matched &= block.offerValue("owner_name", row, owner.displayName()); + matched &= block.offerValue("owner_id", row, owner.id()); } return matched ? 1 : 0; diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProvider.java index c58315f49e..88179b9382 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProvider.java @@ -30,12 +30,12 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.ListObjectsV2Request; -import com.amazonaws.services.s3.model.ListObjectsV2Result; -import com.amazonaws.services.s3.model.Owner; -import com.amazonaws.services.s3.model.S3ObjectSummary; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; +import software.amazon.awssdk.services.s3.model.Owner; +import software.amazon.awssdk.services.s3.model.S3Object; /** * Maps your S3 Objects to a table. @@ -45,9 +45,9 @@ public class S3ObjectsTableProvider { private static final int MAX_KEYS = 1000; private static final Schema SCHEMA; - private AmazonS3 amazonS3; + private S3Client amazonS3; - public S3ObjectsTableProvider(AmazonS3 amazonS3) + public S3ObjectsTableProvider(S3Client amazonS3) { this.amazonS3 = amazonS3; } @@ -98,42 +98,44 @@ public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsR "(e.g. 
where bucket_name='my_bucket')."); } - ListObjectsV2Request req = new ListObjectsV2Request().withBucketName(bucket).withMaxKeys(MAX_KEYS); - ListObjectsV2Result result; + ListObjectsV2Request req = ListObjectsV2Request.builder().bucket(bucket).maxKeys(MAX_KEYS).build(); + ListObjectsV2Response response; do { - result = amazonS3.listObjectsV2(req); - for (S3ObjectSummary objectSummary : result.getObjectSummaries()) { - toRow(objectSummary, spiller); + response = amazonS3.listObjectsV2(req); + for (S3Object s3Object : response.contents()) { + toRow(s3Object, spiller, bucket); } - req.setContinuationToken(result.getNextContinuationToken()); + req = req.toBuilder().continuationToken(response.nextContinuationToken()).build(); } - while (result.isTruncated() && queryStatusChecker.isQueryRunning()); + while (response.isTruncated() && queryStatusChecker.isQueryRunning()); } /** * Maps an S3 object into a row in our Apache Arrow response block(s). * - * @param objectSummary The S3 ObjectSummary to map. + * @param s3Object The S3 object to map. * @param spiller The BlockSpiller to use when we want to write a matching row to the response. + * @param bucket The name of the S3 bucket. * @note The current implementation is rather naive in how it maps fields. It leverages a static * list of fields that we'd like to provide and then explicitly filters and converts each field. */ - private void toRow(S3ObjectSummary objectSummary, - BlockSpiller spiller) + private void toRow(S3Object s3Object, + BlockSpiller spiller, + String bucket) { spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("bucket_name", row, objectSummary.getBucketName()); - matched &= block.offerValue("e_tag", row, objectSummary.getETag()); - matched &= block.offerValue("key", row, objectSummary.getKey()); - matched &= block.offerValue("bytes", row, objectSummary.getSize()); - matched &= block.offerValue("storage_class", row, objectSummary.getStorageClass()); - matched &= block.offerValue("last_modified", row, objectSummary.getLastModified()); + matched &= block.offerValue("bucket_name", row, bucket); + matched &= block.offerValue("e_tag", row, s3Object.eTag()); + matched &= block.offerValue("key", row, s3Object.key()); + matched &= block.offerValue("bytes", row, s3Object.size()); + matched &= block.offerValue("storage_class", row, s3Object.storageClassAsString()); + matched &= block.offerValue("last_modified", row, s3Object.lastModified()); - Owner owner = objectSummary.getOwner(); + Owner owner = s3Object.owner(); if (owner != null) { - matched &= block.offerValue("owner_name", row, owner.getDisplayName()); - matched &= block.offerValue("owner_id", row, owner.getId()); + matched &= block.offerValue("owner_name", row, owner.displayName()); + matched &= block.offerValue("owner_id", row, owner.id()); } return matched ? 
1 : 0; diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandlerTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandlerTest.java index ba8f6f815e..6c755e65aa 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandlerTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandlerTest.java @@ -38,15 +38,15 @@ import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.Collections; @@ -75,7 +75,7 @@ public class AwsCmdbMetadataHandlerTest private FederatedIdentity identity = new FederatedIdentity("arn", "account", Collections.emptyMap(), Collections.emptyList()); @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock private TableProviderFactory mockTableProviderFactory; @@ -98,10 +98,10 @@ public class AwsCmdbMetadataHandlerTest private Block mockBlock; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; private AwsCmdbMetadataHandler handler; diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandlerTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandlerTest.java index 9c78bb1ab8..09000c9e60 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandlerTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandlerTest.java @@ -32,15 +32,15 @@ import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.UUID; @@ -62,7 +62,7 @@ public class AwsCmdbRecordHandlerTest private FederatedIdentity identity = new FederatedIdentity("arn", "account", Collections.emptyMap(), Collections.emptyList()); @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock private TableProviderFactory mockTableProviderFactory; @@ -77,10 +77,10 @@ public class 
AwsCmdbRecordHandlerTest private TableProvider mockTableProvider; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Mock private QueryStatusChecker queryStatusChecker; diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactoryTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactoryTest.java index 19a77878e4..83e2f72c3b 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactoryTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactoryTest.java @@ -21,19 +21,19 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; -import com.amazonaws.services.rds.AmazonRDS; -import com.amazonaws.services.s3.AmazonS3; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.emr.EmrClient; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.s3.S3Client; import java.util.List; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; @RunWith(MockitoJUnitRunner.class) public class TableProviderFactoryTest @@ -42,16 +42,16 @@ public class TableProviderFactoryTest private int expectedTables = 11; @Mock - private AmazonEC2 mockEc2; + private Ec2Client mockEc2; @Mock - private AmazonElasticMapReduce mockEmr; + private EmrClient mockEmr; @Mock - private AmazonRDS mockRds; + private RdsClient mockRds; @Mock - private AmazonS3 amazonS3; + private S3Client amazonS3; private TableProviderFactory factory = new TableProviderFactory(mockEc2, mockEmr, mockRds, amazonS3, com.google.common.collect.ImmutableMap.of()); diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/AbstractTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/AbstractTableProviderTest.java index f4d6ba505a..8ab8620921 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/AbstractTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/AbstractTableProviderTest.java @@ -43,11 +43,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -59,8 +54,16 @@ import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; 
+import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -74,8 +77,6 @@ import static com.amazonaws.athena.connector.lambda.domain.predicate.Constraints.DEFAULT_NO_LIMIT; import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -99,7 +100,7 @@ public abstract class AbstractTableProviderTest private final List mockS3Store = new ArrayList<>(); @Mock - private AmazonS3 amazonS3; + private S3Client amazonS3; @Mock private QueryStatusChecker queryStatusChecker; @@ -129,24 +130,24 @@ public void setUp() { allocator = new BlockAllocatorImpl(); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); mockS3Store.add(byteHolder); - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) - .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); - ByteHolder byteHolder = mockS3Store.get(0); - mockS3Store.remove(0); - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + when(amazonS3.getObject(any(GetObjectRequest.class))) + .thenAnswer(new Answer() + { + @Override + public Object answer(InvocationOnMock invocationOnMock) + throws Throwable + { + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(mockS3Store.get(0).getBytes())); + } }); blockSpillReader = new S3BlockSpillReader(amazonS3, allocator); diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/EmrClusterTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/EmrClusterTableProviderTest.java index c88fc6943b..b593b275a2 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/EmrClusterTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/EmrClusterTableProviderTest.java @@ -21,17 +21,6 @@ import com.amazonaws.athena.connector.lambda.data.Block; import com.amazonaws.athena.connector.lambda.data.BlockUtils; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; -import com.amazonaws.services.elasticmapreduce.model.Application; -import com.amazonaws.services.elasticmapreduce.model.Cluster; -import com.amazonaws.services.elasticmapreduce.model.ClusterStateChangeReason; -import 
com.amazonaws.services.elasticmapreduce.model.ClusterStatus; -import com.amazonaws.services.elasticmapreduce.model.ClusterSummary; -import com.amazonaws.services.elasticmapreduce.model.DescribeClusterRequest; -import com.amazonaws.services.elasticmapreduce.model.DescribeClusterResult; -import com.amazonaws.services.elasticmapreduce.model.ListClustersRequest; -import com.amazonaws.services.elasticmapreduce.model.ListClustersResult; -import com.amazonaws.services.elasticmapreduce.model.Tag; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -41,6 +30,17 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.emr.EmrClient; +import software.amazon.awssdk.services.emr.model.Application; +import software.amazon.awssdk.services.emr.model.Cluster; +import software.amazon.awssdk.services.emr.model.ClusterStateChangeReason; +import software.amazon.awssdk.services.emr.model.ClusterStatus; +import software.amazon.awssdk.services.emr.model.ClusterSummary; +import software.amazon.awssdk.services.emr.model.DescribeClusterRequest; +import software.amazon.awssdk.services.emr.model.DescribeClusterResponse; +import software.amazon.awssdk.services.emr.model.ListClustersRequest; +import software.amazon.awssdk.services.emr.model.ListClustersResponse; +import software.amazon.awssdk.services.emr.model.Tag; import java.util.ArrayList; import java.util.List; @@ -49,7 +49,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -59,7 +58,7 @@ public class EmrClusterTableProviderTest private static final Logger logger = LoggerFactory.getLogger(EmrClusterTableProviderTest.class); @Mock - private AmazonElasticMapReduce mockEmr; + private EmrClient mockEmr; protected String getIdField() { @@ -96,24 +95,18 @@ protected void setUpRead() { when(mockEmr.listClusters(nullable(ListClustersRequest.class))) .thenAnswer((InvocationOnMock invocation) -> { - ListClustersResult mockResult = mock(ListClustersResult.class); List values = new ArrayList<>(); values.add(makeClusterSummary(getIdValue())); values.add(makeClusterSummary(getIdValue())); values.add(makeClusterSummary("fake-id")); - when(mockResult.getClusters()).thenReturn(values); + ListClustersResponse mockResult = ListClustersResponse.builder().clusters(values).build(); return mockResult; }); when(mockEmr.describeCluster(nullable(DescribeClusterRequest.class))) .thenAnswer((InvocationOnMock invocation) -> { DescribeClusterRequest request = (DescribeClusterRequest) invocation.getArguments()[0]; - DescribeClusterResult mockResult = mock(DescribeClusterResult.class); - List values = new ArrayList<>(); - values.add(makeClusterSummary(getIdValue())); - values.add(makeClusterSummary(getIdValue())); - values.add(makeClusterSummary("fake-id")); - when(mockResult.getCluster()).thenReturn(makeCluster(request.getClusterId())); + DescribeClusterResponse mockResult = DescribeClusterResponse.builder().cluster(makeCluster(request.clusterId())).build(); return mockResult; }); } @@ -170,32 +163,32 @@ private void validate(FieldReader fieldReader) private ClusterSummary makeClusterSummary(String id) { - return new ClusterSummary() - .withName("name") - .withId(id) - .withStatus(new 
ClusterStatus() - .withState("state") - .withStateChangeReason(new ClusterStateChangeReason() - .withCode("state_code") - .withMessage("state_msg"))) - .withNormalizedInstanceHours(100); + return ClusterSummary.builder() + .name("name") + .id(id) + .status(ClusterStatus.builder().state("state") + .stateChangeReason(ClusterStateChangeReason.builder() + .code("state_code") + .message("state_msg").build()).build()) + .normalizedInstanceHours(100).build(); } private Cluster makeCluster(String id) { - return new Cluster() - .withId(id) - .withName("name") - .withAutoScalingRole("autoscaling_role") - .withCustomAmiId("custom_ami") - .withInstanceCollectionType("instance_collection_type") - .withLogUri("log_uri") - .withMasterPublicDnsName("master_public_dns") - .withReleaseLabel("release_label") - .withRunningAmiVersion("running_ami") - .withScaleDownBehavior("scale_down_behavior") - .withServiceRole("service_role") - .withApplications(new Application().withName("name").withVersion("version")) - .withTags(new Tag("key", "value")); + return Cluster.builder() + .id(id) + .name("name") + .autoScalingRole("autoscaling_role") + .customAmiId("custom_ami") + .instanceCollectionType("instance_collection_type") + .logUri("log_uri") + .masterPublicDnsName("master_public_dns") + .releaseLabel("release_label") + .runningAmiVersion("running_ami") + .scaleDownBehavior("scale_down_behavior") + .serviceRole("service_role") + .applications(Application.builder().name("name").version("version").build()) + .tags(Tag.builder().key("key").value("value").build()) + .build(); } } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/RdsTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/RdsTableProviderTest.java index 76b8c858ef..7f3e586387 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/RdsTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/RdsTableProviderTest.java @@ -21,30 +21,6 @@ import com.amazonaws.athena.connector.lambda.data.Block; import com.amazonaws.athena.connector.lambda.data.BlockUtils; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; -import com.amazonaws.services.elasticmapreduce.model.Application; -import com.amazonaws.services.elasticmapreduce.model.Cluster; -import com.amazonaws.services.elasticmapreduce.model.ClusterStateChangeReason; -import com.amazonaws.services.elasticmapreduce.model.ClusterStatus; -import com.amazonaws.services.elasticmapreduce.model.ClusterSummary; -import com.amazonaws.services.elasticmapreduce.model.DescribeClusterRequest; -import com.amazonaws.services.elasticmapreduce.model.DescribeClusterResult; -import com.amazonaws.services.elasticmapreduce.model.ListClustersRequest; -import com.amazonaws.services.elasticmapreduce.model.ListClustersResult; -import com.amazonaws.services.rds.AmazonRDS; -import com.amazonaws.services.rds.model.DBInstance; -import com.amazonaws.services.rds.model.DBInstanceStatusInfo; -import com.amazonaws.services.rds.model.DBParameterGroup; -import com.amazonaws.services.rds.model.DBParameterGroupStatus; -import com.amazonaws.services.rds.model.DBSecurityGroupMembership; -import com.amazonaws.services.rds.model.DBSubnetGroup; -import com.amazonaws.services.rds.model.DescribeDBInstancesRequest; -import com.amazonaws.services.rds.model.DescribeDBInstancesResult; -import com.amazonaws.services.rds.model.DomainMembership; -import 
com.amazonaws.services.rds.model.Endpoint; -import com.amazonaws.services.rds.model.Subnet; -import com.amazonaws.services.rds.model.Tag; - import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -54,6 +30,18 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.rds.model.DBInstance; +import software.amazon.awssdk.services.rds.model.DBInstanceStatusInfo; +import software.amazon.awssdk.services.rds.model.DBParameterGroupStatus; +import software.amazon.awssdk.services.rds.model.DBSecurityGroupMembership; +import software.amazon.awssdk.services.rds.model.DBSubnetGroup; +import software.amazon.awssdk.services.rds.model.DescribeDbInstancesRequest; +import software.amazon.awssdk.services.rds.model.DescribeDbInstancesResponse; +import software.amazon.awssdk.services.rds.model.DomainMembership; +import software.amazon.awssdk.services.rds.model.Endpoint; +import software.amazon.awssdk.services.rds.model.Subnet; +import software.amazon.awssdk.services.rds.model.Tag; import java.util.ArrayList; import java.util.Date; @@ -74,7 +62,7 @@ public class RdsTableProviderTest private static final Logger logger = LoggerFactory.getLogger(RdsTableProviderTest.class); @Mock - private AmazonRDS mockRds; + private RdsClient mockRds; protected String getIdField() { @@ -110,19 +98,19 @@ protected TableProvider setUpSource() protected void setUpRead() { final AtomicLong requestCount = new AtomicLong(0); - when(mockRds.describeDBInstances(nullable(DescribeDBInstancesRequest.class))) + when(mockRds.describeDBInstances(nullable(DescribeDbInstancesRequest.class))) .thenAnswer((InvocationOnMock invocation) -> { - DescribeDBInstancesResult mockResult = mock(DescribeDBInstancesResult.class); List values = new ArrayList<>(); values.add(makeValue(getIdValue())); values.add(makeValue(getIdValue())); values.add(makeValue("fake-id")); - when(mockResult.getDBInstances()).thenReturn(values); + DescribeDbInstancesResponse.Builder resultBuilder = DescribeDbInstancesResponse.builder(); + resultBuilder.dbInstances(values); if (requestCount.incrementAndGet() < 3) { - when(mockResult.getMarker()).thenReturn(String.valueOf(requestCount.get())); + resultBuilder.marker(String.valueOf(requestCount.get())); } - return mockResult; + return resultBuilder.build(); }); } @@ -184,56 +172,61 @@ private void validate(FieldReader fieldReader) private DBInstance makeValue(String id) { - return new DBInstance() - .withDBInstanceIdentifier(id) - .withAvailabilityZone("primary_az") - .withAllocatedStorage(100) - .withStorageEncrypted(true) - .withBackupRetentionPeriod(100) - .withAutoMinorVersionUpgrade(true) - .withDBInstanceClass("instance_class") - .withDbInstancePort(100) - .withDBInstanceStatus("status") - .withStorageType("storage_type") - .withDbiResourceId("dbi_resource_id") - .withDBName("name") - .withDomainMemberships(new DomainMembership() - .withDomain("domain") - .withFQDN("fqdn") - .withIAMRoleName("iam_role") - .withStatus("status")) - .withEngine("engine") - .withEngineVersion("engine_version") - .withLicenseModel("license_model") - .withSecondaryAvailabilityZone("secondary_az") - .withPreferredBackupWindow("backup_window") - .withPreferredMaintenanceWindow("maint_window") - .withReadReplicaSourceDBInstanceIdentifier("read_replica_source_id") - .withDBParameterGroups(new DBParameterGroupStatus() 
- .withDBParameterGroupName("name") - .withParameterApplyStatus("status")) - .withDBSecurityGroups(new DBSecurityGroupMembership() - .withDBSecurityGroupName("name") - .withStatus("status")) - .withDBSubnetGroup(new DBSubnetGroup() - .withDBSubnetGroupName("name") - .withSubnetGroupStatus("status") - .withVpcId("vpc") - .withSubnets(new Subnet() - .withSubnetIdentifier("subnet"))) - .withStatusInfos(new DBInstanceStatusInfo() - .withStatus("status") - .withMessage("message") - .withNormal(true) - .withStatusType("type")) - .withEndpoint(new Endpoint() - .withAddress("address") - .withPort(100) - .withHostedZoneId("zone")) - .withInstanceCreateTime(new Date(100000)) - .withIops(100) - .withMultiAZ(true) - .withPubliclyAccessible(true) - .withTagList(new Tag().withKey("key").withValue("value")); + return DBInstance.builder() + .dbInstanceIdentifier(id) + .availabilityZone("primary_az") + .allocatedStorage(100) + .storageEncrypted(true) + .backupRetentionPeriod(100) + .autoMinorVersionUpgrade(true) + .dbInstanceClass("instance_class") + .dbInstancePort(100) + .dbInstanceStatus("status") + .storageType("storage_type") + .dbiResourceId("dbi_resource_id") + .dbName("name") + .domainMemberships(DomainMembership.builder() + .domain("domain") + .fqdn("fqdn") + .iamRoleName("iam_role") + .status("status") + .build()) + .engine("engine") + .engineVersion("engine_version") + .licenseModel("license_model") + .secondaryAvailabilityZone("secondary_az") + .preferredBackupWindow("backup_window") + .preferredMaintenanceWindow("maint_window") + .readReplicaSourceDBInstanceIdentifier("read_replica_source_id") + .dbParameterGroups(DBParameterGroupStatus.builder() + .dbParameterGroupName("name") + .parameterApplyStatus("status") + .build()) + .dbSecurityGroups(DBSecurityGroupMembership.builder() + .dbSecurityGroupName("name") + .status("status").build()) + .dbSubnetGroup(DBSubnetGroup.builder() + .dbSubnetGroupName("name") + .subnetGroupStatus("status") + .vpcId("vpc") + .subnets(Subnet.builder().subnetIdentifier("subnet").build()) + .build()) + .statusInfos(DBInstanceStatusInfo.builder() + .status("status") + .message("message") + .normal(true) + .statusType("type") + .build()) + .endpoint(Endpoint.builder() + .address("address") + .port(100) + .hostedZoneId("zone") + .build()) + .instanceCreateTime(new Date(100000).toInstant()) + .iops(100) + .multiAZ(true) + .publiclyAccessible(true) + .tagList(Tag.builder().key("key").value("value").build()) + .build(); } } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/EbsTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/EbsTableProviderTest.java index 35ebc15812..2fdc295b4f 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/EbsTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/EbsTableProviderTest.java @@ -23,12 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeVolumesRequest; -import com.amazonaws.services.ec2.model.DescribeVolumesResult; -import com.amazonaws.services.ec2.model.Tag; -import com.amazonaws.services.ec2.model.Volume; -import com.amazonaws.services.ec2.model.VolumeAttachment; import 
org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -38,6 +32,12 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeVolumesRequest; +import software.amazon.awssdk.services.ec2.model.DescribeVolumesResponse; +import software.amazon.awssdk.services.ec2.model.Tag; +import software.amazon.awssdk.services.ec2.model.Volume; +import software.amazon.awssdk.services.ec2.model.VolumeAttachment; import java.util.ArrayList; import java.util.Date; @@ -47,7 +47,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -57,7 +56,7 @@ public class EbsTableProviderTest private static final Logger logger = LoggerFactory.getLogger(EbsTableProviderTest.class); @Mock - private AmazonEC2 mockEc2; + private Ec2Client mockEc2; protected String getIdField() { @@ -95,14 +94,13 @@ protected void setUpRead() when(mockEc2.describeVolumes(nullable(DescribeVolumesRequest.class))).thenAnswer((InvocationOnMock invocation) -> { DescribeVolumesRequest request = (DescribeVolumesRequest) invocation.getArguments()[0]; - assertEquals(getIdValue(), request.getVolumeIds().get(0)); - DescribeVolumesResult mockResult = mock(DescribeVolumesResult.class); + assertEquals(getIdValue(), request.volumeIds().get(0)); + List values = new ArrayList<>(); values.add(makeVolume(getIdValue())); values.add(makeVolume(getIdValue())); values.add(makeVolume("fake-id")); - when(mockResult.getVolumes()).thenReturn(values); - return mockResult; + return DescribeVolumesResponse.builder().volumes(values).build(); }); } @@ -158,23 +156,23 @@ private void validate(FieldReader fieldReader) private Volume makeVolume(String id) { - Volume volume = new Volume(); - volume.withVolumeId(id) - .withVolumeType("type") - .withAttachments(new VolumeAttachment() - .withInstanceId("target") - .withDevice("attached_device") - .withState("attachment_state") - .withAttachTime(new Date(100_000))) - .withAvailabilityZone("availability_zone") - .withCreateTime(new Date(100_000)) - .withEncrypted(true) - .withKmsKeyId("kms_key_id") - .withSize(100) - .withIops(100) - .withSnapshotId("snapshot_id") - .withState("state") - .withTags(new Tag("key", "value")); + Volume volume = Volume.builder() + .volumeId(id) + .volumeType("type") + .attachments(VolumeAttachment.builder() + .instanceId("target") + .device("attached_device") + .state("attachment_state") + .attachTime(new Date(100_000).toInstant()).build()) + .availabilityZone("availability_zone") + .createTime(new Date(100_000).toInstant()) + .encrypted(true) + .kmsKeyId("kms_key_id") + .size(100) + .iops(100) + .snapshotId("snapshot_id") + .state("state") + .tags(Tag.builder().key("key").value("value").build()).build(); return volume; } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/Ec2TableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/Ec2TableProviderTest.java index 9712796cf6..2dd9bcfe2a 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/Ec2TableProviderTest.java +++ 
b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/Ec2TableProviderTest.java @@ -23,19 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeInstancesRequest; -import com.amazonaws.services.ec2.model.DescribeInstancesResult; -import com.amazonaws.services.ec2.model.EbsInstanceBlockDevice; -import com.amazonaws.services.ec2.model.GroupIdentifier; -import com.amazonaws.services.ec2.model.Instance; -import com.amazonaws.services.ec2.model.InstanceBlockDeviceMapping; -import com.amazonaws.services.ec2.model.InstanceNetworkInterface; -import com.amazonaws.services.ec2.model.InstanceState; -import com.amazonaws.services.ec2.model.Reservation; -import com.amazonaws.services.ec2.model.StateReason; -import com.amazonaws.services.ec2.model.Tag; - import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -45,14 +32,27 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeInstancesRequest; +import software.amazon.awssdk.services.ec2.model.DescribeInstancesResponse; +import software.amazon.awssdk.services.ec2.model.EbsInstanceBlockDevice; +import software.amazon.awssdk.services.ec2.model.GroupIdentifier; +import software.amazon.awssdk.services.ec2.model.Instance; +import software.amazon.awssdk.services.ec2.model.InstanceBlockDeviceMapping; +import software.amazon.awssdk.services.ec2.model.InstanceNetworkInterface; +import software.amazon.awssdk.services.ec2.model.InstanceState; +import software.amazon.awssdk.services.ec2.model.Reservation; +import software.amazon.awssdk.services.ec2.model.StateReason; +import software.amazon.awssdk.services.ec2.model.Tag; import java.util.ArrayList; import java.util.Date; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -62,7 +62,7 @@ public class Ec2TableProviderTest private static final Logger logger = LoggerFactory.getLogger(Ec2TableProviderTest.class); @Mock - private AmazonEC2 mockEc2; + private Ec2Client mockEc2; protected String getIdField() { @@ -100,13 +100,11 @@ protected void setUpRead() when(mockEc2.describeInstances(nullable(DescribeInstancesRequest.class))).thenAnswer((InvocationOnMock invocation) -> { DescribeInstancesRequest request = (DescribeInstancesRequest) invocation.getArguments()[0]; - assertEquals(getIdValue(), request.getInstanceIds().get(0)); - DescribeInstancesResult mockResult = mock(DescribeInstancesResult.class); + assertEquals(getIdValue(), request.instanceIds().get(0)); List reservations = new ArrayList<>(); reservations.add(makeReservation()); reservations.add(makeReservation()); - when(mockResult.getReservations()).thenReturn(reservations); - return mockResult; + return DescribeInstancesResponse.builder().reservations(reservations).build(); }); } @@ -162,68 +160,66 @@ private 
void validate(FieldReader fieldReader) private Reservation makeReservation() { - Reservation reservation = mock(Reservation.class); List instances = new ArrayList<>(); instances.add(makeInstance(getIdValue())); instances.add(makeInstance(getIdValue())); instances.add(makeInstance("non-matching-id")); - when(reservation.getInstances()).thenReturn(instances); - return reservation; + return Reservation.builder().instances(instances).build(); } private Instance makeInstance(String id) { - Instance instance = new Instance(); - instance.withInstanceId(id) - .withImageId("image_id") - .withInstanceType("instance_type") - .withPlatform("platform") - .withPrivateDnsName("private_dns_name") - .withPrivateIpAddress("private_ip_address") - .withPublicDnsName("public_dns_name") - .withPublicIpAddress("public_ip_address") - .withSubnetId("subnet_id") - .withVpcId("vpc_id") - .withArchitecture("architecture") - .withInstanceLifecycle("instance_lifecycle") - .withRootDeviceName("root_device_name") - .withRootDeviceType("root_device_type") - .withSpotInstanceRequestId("spot_instance_request_id") - .withVirtualizationType("virtualization_type") - .withKeyName("key_name") - .withKernelId("kernel_id") - .withCapacityReservationId("capacity_reservation_id") - .withLaunchTime(new Date(100_000)) - .withState(new InstanceState().withCode(100).withName("name")) - .withStateReason(new StateReason().withCode("code").withMessage("message")) - .withEbsOptimized(true) - .withTags(new Tag("key","value")); + Instance.Builder instance = Instance.builder() + .instanceId(id) + .imageId("image_id") + .instanceType("instance_type") + .platform("platform") + .privateDnsName("private_dns_name") + .privateIpAddress("private_ip_address") + .publicDnsName("public_dns_name") + .publicIpAddress("public_ip_address") + .subnetId("subnet_id") + .vpcId("vpc_id") + .architecture("architecture") + .instanceLifecycle("instance_lifecycle") + .rootDeviceName("root_device_name") + .rootDeviceType("root_device_type") + .spotInstanceRequestId("spot_instance_request_id") + .virtualizationType("virtualization_type") + .keyName("key_name") + .kernelId("kernel_id") + .capacityReservationId("capacity_reservation_id") + .launchTime(new Date(100_000).toInstant()) + .state(InstanceState.builder().code(100).name("name").build()) + .stateReason(StateReason.builder().code("code").message("message").build()) + .ebsOptimized(true) + .tags(Tag.builder().key("key").value("value").build()); List interfaces = new ArrayList<>(); - interfaces.add(new InstanceNetworkInterface() - .withStatus("status") - .withSubnetId("subnet") - .withVpcId("vpc") - .withMacAddress("mac_address") - .withPrivateDnsName("private_dns") - .withPrivateIpAddress("private_ip") - .withNetworkInterfaceId("interface_id") - .withGroups(new GroupIdentifier().withGroupId("group_id").withGroupName("group_name"))); - - interfaces.add(new InstanceNetworkInterface() - .withStatus("status") - .withSubnetId("subnet") - .withVpcId("vpc") - .withMacAddress("mac") - .withPrivateDnsName("private_dns") - .withPrivateIpAddress("private_ip") - .withNetworkInterfaceId("interface_id") - .withGroups(new GroupIdentifier().withGroupId("group_id").withGroupName("group_name"))); - - instance.withNetworkInterfaces(interfaces) - .withSecurityGroups(new GroupIdentifier().withGroupId("group_id").withGroupName("group_name")) - .withBlockDeviceMappings(new InstanceBlockDeviceMapping().withDeviceName("device_name").withEbs(new EbsInstanceBlockDevice().withVolumeId("volume_id"))); - - return instance; + 
interfaces.add(InstanceNetworkInterface.builder() + .status("status") + .subnetId("subnet") + .vpcId("vpc") + .macAddress("mac_address") + .privateDnsName("private_dns") + .privateIpAddress("private_ip") + .networkInterfaceId("interface_id") + .groups(GroupIdentifier.builder().groupId("group_id").groupName("group_name").build()).build()); + + interfaces.add(InstanceNetworkInterface.builder() + .status("status") + .subnetId("subnet") + .vpcId("vpc") + .macAddress("mac") + .privateDnsName("private_dns") + .privateIpAddress("private_ip") + .networkInterfaceId("interface_id") + .groups(GroupIdentifier.builder().groupId("group_id").groupName("group_name").build()).build()); + + instance.networkInterfaces(interfaces) + .securityGroups(GroupIdentifier.builder().groupId("group_id").groupName("group_name").build()) + .blockDeviceMappings(InstanceBlockDeviceMapping.builder().deviceName("device_name").ebs(EbsInstanceBlockDevice.builder().volumeId("volume_id").build()).build()); + + return instance.build(); } } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/ImagesTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/ImagesTableProviderTest.java index c1ab238c86..d981c618be 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/ImagesTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/ImagesTableProviderTest.java @@ -23,13 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.BlockDeviceMapping; -import com.amazonaws.services.ec2.model.DescribeImagesRequest; -import com.amazonaws.services.ec2.model.DescribeImagesResult; -import com.amazonaws.services.ec2.model.EbsBlockDevice; -import com.amazonaws.services.ec2.model.Image; -import com.amazonaws.services.ec2.model.Tag; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -39,6 +32,13 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.BlockDeviceMapping; +import software.amazon.awssdk.services.ec2.model.DescribeImagesRequest; +import software.amazon.awssdk.services.ec2.model.DescribeImagesResponse; +import software.amazon.awssdk.services.ec2.model.EbsBlockDevice; +import software.amazon.awssdk.services.ec2.model.Image; +import software.amazon.awssdk.services.ec2.model.Tag; import java.util.ArrayList; import java.util.List; @@ -47,7 +47,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -57,7 +56,7 @@ public class ImagesTableProviderTest private static final Logger logger = LoggerFactory.getLogger(ImagesTableProviderTest.class); @Mock - private AmazonEC2 mockEc2; + private Ec2Client mockEc2; protected String getIdField() { @@ -95,14 +94,12 @@ protected void setUpRead() 
when(mockEc2.describeImages(nullable(DescribeImagesRequest.class))).thenAnswer((InvocationOnMock invocation) -> { DescribeImagesRequest request = (DescribeImagesRequest) invocation.getArguments()[0]; - assertEquals(getIdValue(), request.getImageIds().get(0)); - DescribeImagesResult mockResult = mock(DescribeImagesResult.class); + assertEquals(getIdValue(), request.imageIds().get(0)); List values = new ArrayList<>(); values.add(makeImage(getIdValue())); values.add(makeImage(getIdValue())); values.add(makeImage("fake-id")); - when(mockResult.getImages()).thenReturn(values); - return mockResult; + return DescribeImagesResponse.builder().images(values).build(); }); } @@ -158,35 +155,35 @@ private void validate(FieldReader fieldReader) private Image makeImage(String id) { - Image image = new Image(); - image.withImageId(id) - .withArchitecture("architecture") - .withCreationDate("created") - .withDescription("description") - .withHypervisor("hypervisor") - .withImageLocation("location") - .withImageType("type") - .withKernelId("kernel") - .withName("name") - .withOwnerId("owner") - .withPlatform("platform") - .withRamdiskId("ramdisk") - .withRootDeviceName("root_device") - .withRootDeviceType("root_type") - .withSriovNetSupport("srvio_net") - .withState("state") - .withVirtualizationType("virt_type") - .withPublic(true) - .withTags(new Tag("key", "value")) - .withBlockDeviceMappings(new BlockDeviceMapping() - .withDeviceName("dev_name") - .withNoDevice("no_device") - .withVirtualName("virt_name") - .withEbs(new EbsBlockDevice() - .withIops(100) - .withKmsKeyId("ebs_kms_key") - .withVolumeType("ebs_type") - .withVolumeSize(100))); + Image image = Image.builder() + .imageId(id) + .architecture("architecture") + .creationDate("created") + .description("description") + .hypervisor("hypervisor") + .imageLocation("location") + .imageType("type") + .kernelId("kernel") + .name("name") + .ownerId("owner") + .platform("platform") + .ramdiskId("ramdisk") + .rootDeviceName("root_device") + .rootDeviceType("root_type") + .sriovNetSupport("srvio_net") + .state("state") + .virtualizationType("virt_type") + .publicLaunchPermissions(true) + .tags(Tag.builder().key("key").value("value").build()) + .blockDeviceMappings(BlockDeviceMapping.builder() + .deviceName("dev_name") + .noDevice("no_device") + .virtualName("virt_name") + .ebs(EbsBlockDevice.builder() + .iops(100) + .kmsKeyId("ebs_kms_key") + .volumeType("ebs_type") + .volumeSize(100).build()).build()).build(); return image; } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/RouteTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/RouteTableProviderTest.java index cf293d33e2..6151bea84d 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/RouteTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/RouteTableProviderTest.java @@ -23,14 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeRouteTablesRequest; -import com.amazonaws.services.ec2.model.DescribeRouteTablesResult; -import com.amazonaws.services.ec2.model.PropagatingVgw; -import com.amazonaws.services.ec2.model.Route; -import 
com.amazonaws.services.ec2.model.RouteTable; -import com.amazonaws.services.ec2.model.RouteTableAssociation; -import com.amazonaws.services.ec2.model.Tag; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -40,6 +32,14 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeRouteTablesRequest; +import software.amazon.awssdk.services.ec2.model.DescribeRouteTablesResponse; +import software.amazon.awssdk.services.ec2.model.PropagatingVgw; +import software.amazon.awssdk.services.ec2.model.Route; +import software.amazon.awssdk.services.ec2.model.RouteTable; +import software.amazon.awssdk.services.ec2.model.RouteTableAssociation; +import software.amazon.awssdk.services.ec2.model.Tag; import java.util.ArrayList; import java.util.List; @@ -48,7 +48,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -58,7 +57,7 @@ public class RouteTableProviderTest private static final Logger logger = LoggerFactory.getLogger(RouteTableProviderTest.class); @Mock - private AmazonEC2 mockEc2; + private Ec2Client mockEc2; protected String getIdField() { @@ -96,14 +95,12 @@ protected void setUpRead() when(mockEc2.describeRouteTables(nullable(DescribeRouteTablesRequest.class))).thenAnswer((InvocationOnMock invocation) -> { DescribeRouteTablesRequest request = (DescribeRouteTablesRequest) invocation.getArguments()[0]; - assertEquals(getIdValue(), request.getRouteTableIds().get(0)); - DescribeRouteTablesResult mockResult = mock(DescribeRouteTablesResult.class); + assertEquals(getIdValue(), request.routeTableIds().get(0)); List values = new ArrayList<>(); values.add(makeRouteTable(getIdValue())); values.add(makeRouteTable(getIdValue())); values.add(makeRouteTable("fake-id")); - when(mockResult.getRouteTables()).thenReturn(values); - return mockResult; + return DescribeRouteTablesResponse.builder().routeTables(values).build(); }); } @@ -159,28 +156,28 @@ private void validate(FieldReader fieldReader) private RouteTable makeRouteTable(String id) { - RouteTable routeTable = new RouteTable(); - routeTable.withRouteTableId(id) - .withOwnerId("owner") - .withVpcId("vpc") - .withAssociations(new RouteTableAssociation().withSubnetId("subnet").withRouteTableId("route_table_id")) - .withTags(new Tag("key", "value")) - .withPropagatingVgws(new PropagatingVgw().withGatewayId("gateway_id")) - .withRoutes(new Route() - .withDestinationCidrBlock("dst_cidr") - .withDestinationIpv6CidrBlock("dst_cidr_v6") - .withDestinationPrefixListId("dst_prefix_list") - .withEgressOnlyInternetGatewayId("egress_igw") - .withGatewayId("gateway") - .withInstanceId("instance_id") - .withInstanceOwnerId("instance_owner") - .withNatGatewayId("nat_gateway") - .withNetworkInterfaceId("interface") - .withOrigin("origin") - .withState("state") - .withTransitGatewayId("transit_gateway") - .withVpcPeeringConnectionId("vpc_peering_con") - ); + RouteTable routeTable = RouteTable.builder() + .routeTableId(id) + .ownerId("owner") + .vpcId("vpc") + .associations(RouteTableAssociation.builder().subnetId("subnet").routeTableId("route_table_id").build()) + 
.tags(Tag.builder().key("key").value("value").build()) + .propagatingVgws(PropagatingVgw.builder().gatewayId("gateway_id").build()) + .routes(Route.builder() + .destinationCidrBlock("dst_cidr") + .destinationIpv6CidrBlock("dst_cidr_v6") + .destinationPrefixListId("dst_prefix_list") + .egressOnlyInternetGatewayId("egress_igw") + .gatewayId("gateway") + .instanceId("instance_id") + .instanceOwnerId("instance_owner") + .natGatewayId("nat_gateway") + .networkInterfaceId("interface") + .origin("origin") + .state("state") + .transitGatewayId("transit_gateway") + .vpcPeeringConnectionId("vpc_peering_con").build() + ).build(); return routeTable; } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SecurityGroupsTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SecurityGroupsTableProviderTest.java index 471a54af25..ea89933190 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SecurityGroupsTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SecurityGroupsTableProviderTest.java @@ -23,15 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeSecurityGroupsRequest; -import com.amazonaws.services.ec2.model.DescribeSecurityGroupsResult; -import com.amazonaws.services.ec2.model.IpPermission; -import com.amazonaws.services.ec2.model.IpRange; -import com.amazonaws.services.ec2.model.Ipv6Range; -import com.amazonaws.services.ec2.model.PrefixListId; -import com.amazonaws.services.ec2.model.SecurityGroup; -import com.amazonaws.services.ec2.model.UserIdGroupPair; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -41,6 +32,15 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeSecurityGroupsRequest; +import software.amazon.awssdk.services.ec2.model.DescribeSecurityGroupsResponse; +import software.amazon.awssdk.services.ec2.model.IpPermission; +import software.amazon.awssdk.services.ec2.model.IpRange; +import software.amazon.awssdk.services.ec2.model.Ipv6Range; +import software.amazon.awssdk.services.ec2.model.PrefixListId; +import software.amazon.awssdk.services.ec2.model.SecurityGroup; +import software.amazon.awssdk.services.ec2.model.UserIdGroupPair; import java.util.ArrayList; import java.util.List; @@ -49,7 +49,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -59,7 +58,7 @@ public class SecurityGroupsTableProviderTest private static final Logger logger = LoggerFactory.getLogger(SecurityGroupsTableProviderTest.class); @Mock - private AmazonEC2 mockEc2; + private Ec2Client mockEc2; protected String getIdField() { @@ -98,14 +97,12 @@ protected void setUpRead() .thenAnswer((InvocationOnMock invocation) -> { DescribeSecurityGroupsRequest request = 
(DescribeSecurityGroupsRequest) invocation.getArguments()[0]; - assertEquals(getIdValue(), request.getGroupIds().get(0)); - DescribeSecurityGroupsResult mockResult = mock(DescribeSecurityGroupsResult.class); + assertEquals(getIdValue(), request.groupIds().get(0)); List values = new ArrayList<>(); values.add(makeSecurityGroup(getIdValue())); values.add(makeSecurityGroup(getIdValue())); values.add(makeSecurityGroup("fake-id")); - when(mockResult.getSecurityGroups()).thenReturn(values); - return mockResult; + return DescribeSecurityGroupsResponse.builder().securityGroups(values).build(); }); } @@ -161,19 +158,18 @@ private void validate(FieldReader fieldReader) private SecurityGroup makeSecurityGroup(String id) { - return new SecurityGroup() - .withGroupId(id) - .withGroupName("name") - .withDescription("description") - .withIpPermissions(new IpPermission() - .withIpProtocol("protocol") - .withFromPort(100) - .withToPort(100) - .withIpv4Ranges(new IpRange().withCidrIp("cidr").withDescription("description")) - - .withIpv6Ranges(new Ipv6Range().withCidrIpv6("cidr").withDescription("description")) - .withPrefixListIds(new PrefixListId().withPrefixListId("prefix").withDescription("description")) - .withUserIdGroupPairs(new UserIdGroupPair().withGroupId("group_id").withUserId("user_id")) - ); + return SecurityGroup.builder() + .groupId(id) + .groupName("name") + .description("description") + .ipPermissions(IpPermission.builder() + .ipProtocol("protocol") + .fromPort(100) + .toPort(100) + .ipRanges(IpRange.builder().cidrIp("cidr").description("description").build()) + .ipv6Ranges(Ipv6Range.builder().cidrIpv6("cidr").description("description").build()) + .prefixListIds(PrefixListId.builder().prefixListId("prefix").description("description").build()) + .userIdGroupPairs(UserIdGroupPair.builder().groupId("group_id").userId("user_id").build()).build() + ).build(); } } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SubnetTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SubnetTableProviderTest.java index 04437e13f9..4afd3e4e5e 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SubnetTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/SubnetTableProviderTest.java @@ -23,11 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeSubnetsRequest; -import com.amazonaws.services.ec2.model.DescribeSubnetsResult; -import com.amazonaws.services.ec2.model.Subnet; -import com.amazonaws.services.ec2.model.Tag; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -37,6 +32,11 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeSubnetsRequest; +import software.amazon.awssdk.services.ec2.model.DescribeSubnetsResponse; +import software.amazon.awssdk.services.ec2.model.Subnet; +import software.amazon.awssdk.services.ec2.model.Tag; import java.util.ArrayList; import java.util.List; @@ -45,7 +45,6 
@@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -55,7 +54,7 @@ public class SubnetTableProviderTest private static final Logger logger = LoggerFactory.getLogger(SubnetTableProviderTest.class); @Mock - private AmazonEC2 mockEc2; + private Ec2Client mockEc2; protected String getIdField() { @@ -93,15 +92,12 @@ protected void setUpRead() when(mockEc2.describeSubnets(nullable(DescribeSubnetsRequest.class))).thenAnswer((InvocationOnMock invocation) -> { DescribeSubnetsRequest request = (DescribeSubnetsRequest) invocation.getArguments()[0]; - assertEquals(getIdValue(), request.getSubnetIds().get(0)); - DescribeSubnetsResult mockResult = mock(DescribeSubnetsResult.class); + assertEquals(getIdValue(), request.subnetIds().get(0)); List values = new ArrayList<>(); values.add(makeSubnet(getIdValue())); values.add(makeSubnet(getIdValue())); values.add(makeSubnet("fake-id")); - when(mockResult.getSubnets()).thenReturn(values); - - return mockResult; + return DescribeSubnetsResponse.builder().subnets(values).build(); }); } @@ -157,16 +153,16 @@ private void validate(FieldReader fieldReader) private Subnet makeSubnet(String id) { - return new Subnet() - .withSubnetId(id) - .withAvailabilityZone("availability_zone") - .withCidrBlock("cidr_block") - .withAvailableIpAddressCount(100) - .withDefaultForAz(true) - .withMapPublicIpOnLaunch(true) - .withOwnerId("owner") - .withState("state") - .withTags(new Tag().withKey("key").withValue("value")) - .withVpcId("vpc"); + return Subnet.builder() + .subnetId(id) + .availabilityZone("availability_zone") + .cidrBlock("cidr_block") + .availableIpAddressCount(100) + .defaultForAz(true) + .mapPublicIpOnLaunch(true) + .ownerId("owner") + .state("state") + .tags(Tag.builder().key("key").value("value").build()) + .vpcId("vpc").build(); } } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/VpcTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/VpcTableProviderTest.java index 4abb29ccd1..900fdf67ca 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/VpcTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/ec2/VpcTableProviderTest.java @@ -23,11 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.DescribeVpcsRequest; -import com.amazonaws.services.ec2.model.DescribeVpcsResult; -import com.amazonaws.services.ec2.model.Tag; -import com.amazonaws.services.ec2.model.Vpc; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -37,6 +32,11 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.ec2.Ec2Client; +import software.amazon.awssdk.services.ec2.model.DescribeVpcsRequest; +import software.amazon.awssdk.services.ec2.model.DescribeVpcsResponse; +import software.amazon.awssdk.services.ec2.model.Tag; +import 
software.amazon.awssdk.services.ec2.model.Vpc; import java.util.ArrayList; import java.util.List; @@ -45,7 +45,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -55,7 +54,7 @@ public class VpcTableProviderTest private static final Logger logger = LoggerFactory.getLogger(VpcTableProviderTest.class); @Mock - private AmazonEC2 mockEc2; + private Ec2Client mockEc2; protected String getIdField() { @@ -93,14 +92,12 @@ protected void setUpRead() when(mockEc2.describeVpcs(nullable(DescribeVpcsRequest.class))).thenAnswer((InvocationOnMock invocation) -> { DescribeVpcsRequest request = (DescribeVpcsRequest) invocation.getArguments()[0]; - assertEquals(getIdValue(), request.getVpcIds().get(0)); - DescribeVpcsResult mockResult = mock(DescribeVpcsResult.class); + assertEquals(getIdValue(), request.vpcIds().get(0)); List values = new ArrayList<>(); values.add(makeVpc(getIdValue())); values.add(makeVpc(getIdValue())); values.add(makeVpc("fake-id")); - when(mockResult.getVpcs()).thenReturn(values); - return mockResult; + return DescribeVpcsResponse.builder().vpcs(values).build(); }); } @@ -156,15 +153,15 @@ private void validate(FieldReader fieldReader) private Vpc makeVpc(String id) { - Vpc vpc = new Vpc(); - vpc.withVpcId(id) - .withCidrBlock("cidr_block") - .withDhcpOptionsId("dhcp_opts") - .withInstanceTenancy("tenancy") - .withOwnerId("owner") - .withState("state") - .withIsDefault(true) - .withTags(new Tag("key", "valye")); + Vpc vpc = Vpc.builder() + .vpcId(id) + .cidrBlock("cidr_block") + .dhcpOptionsId("dhcp_opts") + .instanceTenancy("tenancy") + .ownerId("owner") + .state("state") + .isDefault(true) + .tags(Tag.builder().key("key").value("valye").build()).build(); return vpc; } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProviderTest.java index cb1372a917..348a077164 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProviderTest.java @@ -23,9 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.Bucket; -import com.amazonaws.services.s3.model.Owner; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -33,12 +30,19 @@ import org.mockito.invocation.InvocationOnMock; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.Bucket; +import software.amazon.awssdk.services.s3.model.GetBucketAclRequest; +import software.amazon.awssdk.services.s3.model.GetBucketAclResponse; +import software.amazon.awssdk.services.s3.model.ListBucketsResponse; +import software.amazon.awssdk.services.s3.model.Owner; import java.util.ArrayList; import java.util.Date; import java.util.List; import static 
org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; public class S3BucketsTableProviderTest @@ -47,7 +51,7 @@ public class S3BucketsTableProviderTest private static final Logger logger = LoggerFactory.getLogger(S3BucketsTableProviderTest.class); @Mock - private AmazonS3 mockS3; + private S3Client mockS3; protected String getIdField() { @@ -87,7 +91,15 @@ protected void setUpRead() values.add(makeBucket(getIdValue())); values.add(makeBucket(getIdValue())); values.add(makeBucket("fake-id")); - return values; + return ListBucketsResponse.builder().buckets(values).build(); + }); + when(mockS3.getBucketAcl(any(GetBucketAclRequest.class))).thenAnswer((InvocationOnMock invocation) -> { + return GetBucketAclResponse.builder() + .owner(Owner.builder() + .displayName("owner_name") + .id("owner_id") + .build()) + .build(); }); } @@ -143,13 +155,10 @@ private void validate(FieldReader fieldReader) private Bucket makeBucket(String id) { - Bucket bucket = new Bucket(); - bucket.setName(id); - Owner owner = new Owner(); - owner.setDisplayName("owner_name"); - owner.setId("owner_id"); - bucket.setOwner(owner); - bucket.setCreationDate(new Date(100_000)); + Bucket bucket = Bucket.builder() + .name(id) + .creationDate(new Date(100_000).toInstant()) + .build(); return bucket; } } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProviderTest.java index ec77efc11a..761730ee08 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProviderTest.java @@ -23,11 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.ListObjectsV2Request; -import com.amazonaws.services.s3.model.ListObjectsV2Result; -import com.amazonaws.services.s3.model.Owner; -import com.amazonaws.services.s3.model.S3ObjectSummary; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -35,6 +30,11 @@ import org.mockito.invocation.InvocationOnMock; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; +import software.amazon.awssdk.services.s3.model.Owner; +import software.amazon.awssdk.services.s3.model.S3Object; import java.util.ArrayList; import java.util.Date; @@ -45,7 +45,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class S3ObjectsTableProviderTest @@ -54,7 +53,7 @@ public class S3ObjectsTableProviderTest private static final Logger logger = LoggerFactory.getLogger(S3ObjectsTableProviderTest.class); @Mock - private AmazonS3 mockS3; + private S3Client mockS3; protected String getIdField() { @@ -92,25 +91,26 @@ protected 
void setUpRead() AtomicLong count = new AtomicLong(0); when(mockS3.listObjectsV2(nullable(ListObjectsV2Request.class))).thenAnswer((InvocationOnMock invocation) -> { ListObjectsV2Request request = (ListObjectsV2Request) invocation.getArguments()[0]; - assertEquals(getIdValue(), request.getBucketName()); + assertEquals(getIdValue(), request.bucket()); - ListObjectsV2Result mockResult = mock(ListObjectsV2Result.class); - List values = new ArrayList<>(); - values.add(makeObjectSummary(getIdValue())); - values.add(makeObjectSummary(getIdValue())); - values.add(makeObjectSummary("fake-id")); - when(mockResult.getObjectSummaries()).thenReturn(values); + List values = new ArrayList<>(); + values.add(makeS3Object()); + values.add(makeS3Object()); + ListObjectsV2Response.Builder responseBuilder = ListObjectsV2Response.builder().contents(values); if (count.get() > 0) { - assertNotNull(request.getContinuationToken()); + assertNotNull(request.continuationToken()); } if (count.incrementAndGet() < 2) { - when(mockResult.isTruncated()).thenReturn(true); - when(mockResult.getNextContinuationToken()).thenReturn("token"); + responseBuilder.isTruncated(true); + responseBuilder.nextContinuationToken("token"); + } + else { + responseBuilder.isTruncated(false); } - return mockResult; + return responseBuilder.build(); }); } @@ -167,19 +167,17 @@ private void validate(FieldReader fieldReader) } } - private S3ObjectSummary makeObjectSummary(String id) + private S3Object makeS3Object() { - S3ObjectSummary summary = new S3ObjectSummary(); - Owner owner = new Owner(); - owner.setId("owner_id"); - owner.setDisplayName("owner_name"); - summary.setOwner(owner); - summary.setBucketName(id); - summary.setETag("e_tag"); - summary.setKey("key"); - summary.setSize(100); - summary.setLastModified(new Date(100_000)); - summary.setStorageClass("storage_class"); - return summary; + Owner owner = Owner.builder().id("owner_id").displayName("owner_name").build(); + S3Object s3Object = S3Object.builder() + .owner(owner) + .eTag("e_tag") + .key("key") + .size((long)100) + .lastModified(new Date(100_000).toInstant()) + .storageClass("storage_class") + .build(); + return s3Object; } } diff --git a/athena-clickhouse/Dockerfile b/athena-clickhouse/Dockerfile new file mode 100644 index 0000000000..a092ba28cb --- /dev/null +++ b/athena-clickhouse/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-clickhouse-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-clickhouse-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.clickhouse.ClickHouseMuxCompositeHandler" ] \ No newline at end of file diff --git a/athena-clickhouse/athena-clickhouse.yaml b/athena-clickhouse/athena-clickhouse.yaml index ca6d171dcb..259aae7198 100644 --- a/athena-clickhouse/athena-clickhouse.yaml +++ b/athena-clickhouse/athena-clickhouse.yaml @@ -70,10 +70,9 @@ Resources: spill_prefix: !Ref SpillPrefix default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.clickhouse.ClickHouseMuxCompositeHandler" - CodeUri: "./target/athena-clickhouse-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-clickhouse:2022.47.1' Description: "Enables Amazon Athena to communicate with ClickHouse using JDBC" - Runtime: java11 
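# A minimal sketch, assuming the ECR image referenced above, of the image-packaged
# SAM function this hunk produces (the logical ID and literal sizing values below
# are hypothetical; the real template wires them through !Ref parameters). With
# PackageType: "Image" the entry point comes from the Dockerfile CMD, which is why
# Handler, CodeUri, and Runtime are dropped:
#
#   ConnectorFunction:
#     Type: AWS::Serverless::Function
#     Properties:
#       PackageType: "Image"
#       ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-clickhouse:2022.47.1'
#       Timeout: 900
#       MemorySize: 3008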
Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMetadataHandler.java b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMetadataHandler.java index 54174bce6a..fd6ecdb379 100644 --- a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMetadataHandler.java +++ b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMetadataHandler.java @@ -40,8 +40,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.mysql.MySqlMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableSet; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -49,6 +47,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -101,8 +101,8 @@ public ClickHouseMetadataHandler(DatabaseConnectionConfig databaseConnectionConf @VisibleForTesting protected ClickHouseMetadataHandler( DatabaseConnectionConfig databaseConnectionConfig, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxMetadataHandler.java b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxMetadataHandler.java index f062b46e13..d1ac2226f7 100644 --- a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxMetadataHandler.java +++ b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxMetadataHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -59,7 +59,7 @@ public ClickHouseMuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected ClickHouseMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected ClickHouseMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, 
configOptions); diff --git a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxRecordHandler.java b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxRecordHandler.java index 047864557a..ccb54ee88a 100644 --- a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxRecordHandler.java +++ b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -58,7 +58,7 @@ public ClickHouseMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - ClickHouseMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + ClickHouseMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseRecordHandler.java b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseRecordHandler.java index d5e485f503..6728a6c4e1 100644 --- a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseRecordHandler.java +++ b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseRecordHandler.java @@ -30,18 +30,16 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; import com.amazonaws.athena.connectors.mysql.MySqlFederationExpressionParser; +import com.amazonaws.athena.connectors.mysql.MySqlMuxCompositeHandler; import com.amazonaws.athena.connectors.mysql.MySqlQueryStringBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -77,13 +75,13 @@ public 
ClickHouseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig public ClickHouseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new MySqlQueryStringBuilder(MYSQL_QUOTE_CHARACTER, new MySqlFederationExpressionParser(MYSQL_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - ClickHouseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final AWSSecretsManager secretsManager, - final AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + ClickHouseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, + final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMetadataHandlerTest.java b/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMetadataHandlerTest.java index 783650c302..901201ca78 100644 --- a/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMetadataHandlerTest.java +++ b/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMetadataHandlerTest.java @@ -40,10 +40,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import net.jqwik.api.Table; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -51,6 +47,10 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.Connection; import java.sql.PreparedStatement; @@ -79,8 +79,8 @@ public class ClickHouseMetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private BlockAllocator blockAllocator; @Before @@ -90,9 +90,9 @@ public void setup() this.jdbcConnectionFactory = 
Mockito.mock(JdbcConnectionFactory.class); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.metadataHandler = new ClickHouseMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); this.blockAllocator = Mockito.mock(BlockAllocator.class); diff --git a/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcMetadataHandlerTest.java b/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcMetadataHandlerTest.java index afc3e2a1ae..aebc2c0deb 100644 --- a/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcMetadataHandlerTest.java +++ b/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcMetadataHandlerTest.java @@ -32,11 +32,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -49,8 +49,8 @@ public class ClickHouseMuxJdbcMetadataHandlerTest private ClickHouseMetadataHandler metadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -62,8 +62,8 @@ public void setup() //Mockito.when(this.allocator.createBlock(nullable(Schema.class))).thenReturn(Mockito.mock(Block.class)); this.metadataHandler = Mockito.mock(ClickHouseMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.metadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = 
Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", diff --git a/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcRecordHandlerTest.java b/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcRecordHandlerTest.java index e7ff91da7b..9adc8e9096 100644 --- a/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcRecordHandlerTest.java +++ b/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcRecordHandlerTest.java @@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -46,9 +46,9 @@ public class ClickHouseMuxJdbcRecordHandlerTest private Map recordHandlerMap; private ClickHouseRecordHandler recordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -57,9 +57,9 @@ public void setup() { this.recordHandler = Mockito.mock(ClickHouseRecordHandler.class); this.recordHandlerMap = Collections.singletonMap(ClickHouseConstants.NAME, this.recordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", ClickHouseConstants.NAME, diff --git a/athena-cloudera-hive/Dockerfile b/athena-cloudera-hive/Dockerfile new file mode 100644 index 0000000000..a56019f693 --- /dev/null +++ b/athena-cloudera-hive/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-cloudera-hive-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-cloudera-hive-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.cloudera.HiveMuxCompositeHandler" ] \ No newline at end of file diff --git a/athena-cloudera-hive/athena-cloudera-hive.yaml b/athena-cloudera-hive/athena-cloudera-hive.yaml index 0984f3e01d..70f2775b1b 100644 --- 
a/athena-cloudera-hive/athena-cloudera-hive.yaml +++ b/athena-cloudera-hive/athena-cloudera-hive.yaml @@ -65,10 +65,9 @@ Resources: spill_prefix: !Ref SpillPrefix default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.cloudera.HiveMuxCompositeHandler" - CodeUri: "./target/athena-cloudera-hive-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudera-hive:2022.47.1' Description: "Enables Amazon Athena to communicate with Coludera Hive using JDBC" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-cloudera-hive/pom.xml b/athena-cloudera-hive/pom.xml index 7e4dfdbcc3..cd0a11a82d 100644 --- a/athena-cloudera-hive/pom.xml +++ b/athena-cloudera-hive/pom.xml @@ -52,12 +52,18 @@ ${mockito.version} test - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMetadataHandler.java b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMetadataHandler.java index 6f6b38b4bc..a307a63383 100644 --- a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMetadataHandler.java +++ b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMetadataHandler.java @@ -49,8 +49,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -61,6 +59,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -96,8 +96,8 @@ public HiveMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, ja @VisibleForTesting protected HiveMetadataHandler( DatabaseConnectionConfig databaseConnectionConfiguration, - AWSSecretsManager secretManager, - AmazonAthena athena, + SecretsManagerClient secretManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxMetadataHandler.java b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxMetadataHandler.java index 6110d8cc25..b99cd881e6 100644 --- a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxMetadataHandler.java +++ b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxMetadataHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import 
com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -56,7 +56,7 @@ public HiveMuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected HiveMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected HiveMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandler.java b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandler.java index f87ee06bef..3dd28acccc 100644 --- a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandler.java +++ b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public HiveMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - HiveMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + HiveMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandler.java b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandler.java index ed5af5284d..95ff9f6a3e 100644 --- a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandler.java +++ b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandler.java @@ -28,15 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; 
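// A minimal sketch, assuming only the default region/credentials provider chains,
// of the v1 -> v2 client bootstrap these record handlers migrate to: each removed
// *ClientBuilder.defaultClient() call is replaced by the corresponding SDK v2
// create() factory. The class name below is hypothetical, for illustration only.
import software.amazon.awssdk.services.athena.AthenaClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

final class V2ClientBootstrapSketch
{
    // was AmazonS3ClientBuilder.defaultClient()
    static final S3Client S3 = S3Client.create();
    // was AWSSecretsManagerClientBuilder.defaultClient()
    static final SecretsManagerClient SECRETS_MANAGER = SecretsManagerClient.create();
    // was AmazonAthenaClientBuilder.defaultClient()
    static final AthenaClient ATHENA = AthenaClient.create();
}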
-import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -62,11 +59,11 @@ public HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java } public HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new HiveQueryStringBuilder(HIVE_QUOTE_CHARACTER, new HiveFederationExpressionParser(HIVE_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMetadataHandlerTest.java b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMetadataHandlerTest.java index b1520d4a3d..abc43fc0b1 100644 --- a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMetadataHandlerTest.java +++ b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMetadataHandlerTest.java @@ -28,10 +28,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; @@ -39,7 +35,10 @@ import org.junit.BeforeClass; import org.junit.Test; import org.mockito.Mockito; - +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import 
software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.*; import java.util.*; @@ -58,8 +57,8 @@ public class HiveMetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private BlockAllocator blockAllocator; @BeforeClass @@ -75,9 +74,9 @@ public void setup() this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.hiveMetadataHandler = new HiveMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); diff --git a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxMetadataHandlerTest.java b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxMetadataHandlerTest.java index 8f0f47fc63..344b1e4915 100644 --- a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxMetadataHandlerTest.java +++ b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxMetadataHandlerTest.java @@ -43,8 +43,8 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableRequest; import com.amazonaws.athena.connector.lambda.metadata.ListSchemasRequest; import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import static org.mockito.ArgumentMatchers.nullable; @@ -54,8 +54,8 @@ public class HiveMuxMetadataHandlerTest private HiveMetadataHandler hiveMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @BeforeClass @@ -68,8 +68,8 @@ public void setup() this.allocator = new BlockAllocatorImpl(); this.hiveMetadataHandler = Mockito.mock(HiveMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("metaHive", this.hiveMetadataHandler); 
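// A minimal self-contained sketch of the SDK v2 stubbing pattern these test hunks
// adopt: v2 requests and responses are immutable, builder-built values with
// value-based equals(), so Mockito.eq() on a freshly built request matches the one
// the handler constructs internally. Class and method names here are hypothetical;
// the secret id and JSON string are the fixtures already used in this patch.
import org.mockito.Mockito;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;

final class SecretsManagerStubSketch
{
    static SecretsManagerClient stubbedSecretsManager()
    {
        SecretsManagerClient secretsManager = Mockito.mock(SecretsManagerClient.class);
        Mockito.when(secretsManager.getSecretValue(
                        Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build())))
                .thenReturn(GetSecretValueResponse.builder()
                        .secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")
                        .build());
        return secretsManager;
    }
}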
- this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", HiveConstants.HIVE_NAME, diff --git a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandlerTest.java b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandlerTest.java index d3fb2d0ee3..31035ae1a8 100644 --- a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandlerTest.java +++ b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandlerTest.java @@ -29,15 +29,15 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.mockito.Mockito; import org.testng.Assert; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -50,9 +50,9 @@ public class HiveMuxRecordHandlerTest private Map recordHandlerMap; private HiveRecordHandler hiveRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @BeforeClass @@ -64,9 +64,9 @@ public void setup() { this.hiveRecordHandler = Mockito.mock(HiveRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("recordHive", this.hiveRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", HiveConstants.HIVE_NAME, diff --git a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandlerTest.java b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandlerTest.java index 8cfce879a6..108474f096 100644 --- a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandlerTest.java +++ 
b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandlerTest.java @@ -32,11 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -46,6 +41,12 @@ import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; + import java.sql.Connection; import java.sql.Date; import java.sql.PreparedStatement; @@ -62,18 +63,18 @@ public class HiveRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-cloudera-impala/Dockerfile b/athena-cloudera-impala/Dockerfile new file mode 100644 index 0000000000..2ed43aeaa9 --- /dev/null +++ b/athena-cloudera-impala/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-cloudera-impala-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-cloudera-impala-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.cloudera.ImpalaMuxCompositeHandler" ] \ No newline at end of file diff --git 
a/athena-cloudera-impala/athena-cloudera-impala.yaml b/athena-cloudera-impala/athena-cloudera-impala.yaml index ee292cea3a..60dc37ed9e 100644 --- a/athena-cloudera-impala/athena-cloudera-impala.yaml +++ b/athena-cloudera-impala/athena-cloudera-impala.yaml @@ -70,10 +70,9 @@ Resources: spill_prefix: !Ref SpillPrefix default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.cloudera.ImpalaMuxCompositeHandler" - CodeUri: "./target/athena-cloudera-impala-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudera-impala:2022.47.1' Description: "Enables Amazon Athena to communicate with Cloudera Impala using JDBC" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-cloudera-impala/pom.xml b/athena-cloudera-impala/pom.xml index d3b2a73d3d..cfdb74e7b3 100644 --- a/athena-cloudera-impala/pom.xml +++ b/athena-cloudera-impala/pom.xml @@ -48,12 +48,18 @@ ${mockito.version} test - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaFederationExpressionParser.java b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaFederationExpressionParser.java index 74469e70fb..3abc39655e 100644 --- a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaFederationExpressionParser.java +++ b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaFederationExpressionParser.java @@ -17,7 +17,7 @@ * limitations under the License. 
* #L% */ -package com.amazonaws.athena.connectors.hortonworks; +package com.amazonaws.athena.connectors.cloudera; import com.amazonaws.athena.connectors.jdbc.manager.JdbcFederationExpressionParser; import com.google.common.base.Joiner; diff --git a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMetadataHandler.java b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMetadataHandler.java index 5d75bff3cd..609a424199 100644 --- a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMetadataHandler.java +++ b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMetadataHandler.java @@ -49,8 +49,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -60,6 +58,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -92,8 +92,8 @@ public ImpalaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, @VisibleForTesting protected ImpalaMetadataHandler( DatabaseConnectionConfig databaseConnectionConfiguration, - AWSSecretsManager secretManager, - AmazonAthena athena, + SecretsManagerClient secretManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxMetadataHandler.java b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxMetadataHandler.java index dbe810912f..ec55031198 100644 --- a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxMetadataHandler.java +++ b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxMetadataHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -56,7 +56,7 @@ public ImpalaMuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected ImpalaMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected ImpalaMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, 
jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandler.java b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandler.java index d1461b523e..8dbac1f9e3 100644 --- a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandler.java +++ b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public ImpalaMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - ImpalaMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + ImpalaMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandler.java b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandler.java index 8a336a0b5f..59912af693 100644 --- a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandler.java +++ b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandler.java @@ -22,22 +22,18 @@ import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; -import com.amazonaws.athena.connectors.hortonworks.ImpalaFederationExpressionParser; import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import 
org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -63,11 +59,11 @@ public ImpalaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, ja } public ImpalaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new ImpalaQueryStringBuilder(IMPALA_QUOTE_CHARACTER, new ImpalaFederationExpressionParser(IMPALA_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - ImpalaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + ImpalaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMetadataHandlerTest.java b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMetadataHandlerTest.java index 09746df6da..d87f00e757 100644 --- a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMetadataHandlerTest.java +++ b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMetadataHandlerTest.java @@ -28,10 +28,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; @@ -39,6 +35,10 @@ import org.junit.BeforeClass; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.*; import java.util.*; @@ -58,8 +58,8 @@ public class ImpalaMetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity 
federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private BlockAllocator blockAllocator; @BeforeClass public static void dataSetUP() { @@ -73,9 +73,9 @@ public void setup() this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.impalaMetadataHandler = new ImpalaMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); diff --git a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxMetadataHandlerTest.java b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxMetadataHandlerTest.java index 8fe338fcb8..60f6a8af9e 100644 --- a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxMetadataHandlerTest.java +++ b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxMetadataHandlerTest.java @@ -43,8 +43,8 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableRequest; import com.amazonaws.athena.connector.lambda.metadata.ListSchemasRequest; import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import static org.mockito.ArgumentMatchers.nullable; @@ -54,8 +54,8 @@ public class ImpalaMuxMetadataHandlerTest private ImpalaMetadataHandler impalaMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @BeforeClass @@ -68,8 +68,8 @@ public void setup() this.allocator = new BlockAllocatorImpl(); this.impalaMetadataHandler = Mockito.mock(ImpalaMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("metaImpala", this.impalaMetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = 
Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", ImpalaConstants.IMPALA_NAME, diff --git a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandlerTest.java b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandlerTest.java index ec84d0ed0c..cff80beebb 100644 --- a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandlerTest.java +++ b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandlerTest.java @@ -29,15 +29,15 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.mockito.Mockito; import org.testng.Assert; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -50,9 +50,9 @@ public class ImpalaMuxRecordHandlerTest private Map recordHandlerMap; private ImpalaRecordHandler impalaRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @BeforeClass @@ -64,9 +64,9 @@ public void setup() { this.impalaRecordHandler = Mockito.mock(ImpalaRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("recordImpala", this.impalaRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", ImpalaConstants.IMPALA_NAME, diff --git a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaQueryStringBuilderTest.java b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaQueryStringBuilderTest.java index d87cc871c6..0b08ecfe45 100644 --- a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaQueryStringBuilderTest.java +++ b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaQueryStringBuilderTest.java @@ -20,7 +20,6 @@ package com.amazonaws.athena.connectors.cloudera; import 
com.amazonaws.athena.connector.lambda.domain.Split; -import com.amazonaws.athena.connectors.hortonworks.ImpalaFederationExpressionParser; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandlerTest.java b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandlerTest.java index 5e222fd508..bd0909b48f 100644 --- a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandlerTest.java +++ b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandlerTest.java @@ -28,16 +28,15 @@ import com.amazonaws.athena.connector.lambda.domain.predicate.Range; import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet; import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; -import com.amazonaws.athena.connectors.hortonworks.ImpalaFederationExpressionParser; import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -64,18 +63,18 @@ public class ImpalaRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); 
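// A minimal sketch (illustrative only, not part of this change set) of the SDK v2 stubbing
// idiom these test migrations apply throughout: mutable v1 requests such as
// new GetSecretValueRequest().withSecretId(...) become immutable builders, and *Result types
// become *Response types. Because v2 model objects implement equals()/hashCode() over their
// fields, Mockito.eq() on a freshly built request matches the request the handler builds
// internally. Variable names here are hypothetical; only Mockito and the SDK v2
// secretsmanager module are assumed:
//
//     SecretsManagerClient secretsManager = Mockito.mock(SecretsManagerClient.class);
//     GetSecretValueRequest expectedRequest = GetSecretValueRequest.builder()
//             .secretId("testSecret")
//             .build();
//     GetSecretValueResponse cannedResponse = GetSecretValueResponse.builder()
//             .secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")
//             .build();
//     Mockito.when(secretsManager.getSecretValue(Mockito.eq(expectedRequest)))
//             .thenReturn(cannedResponse);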
Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-cloudwatch-metrics/Dockerfile b/athena-cloudwatch-metrics/Dockerfile new file mode 100644 index 0000000000..b3eafc1e38 --- /dev/null +++ b/athena-cloudwatch-metrics/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-cloudwatch-metrics-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-cloudwatch-metrics-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.cloudwatch.metrics.MetricsCompositeHandler" ] \ No newline at end of file diff --git a/athena-cloudwatch-metrics/athena-cloudwatch-metrics.yaml b/athena-cloudwatch-metrics/athena-cloudwatch-metrics.yaml index d1d815063c..974b979e37 100644 --- a/athena-cloudwatch-metrics/athena-cloudwatch-metrics.yaml +++ b/athena-cloudwatch-metrics/athena-cloudwatch-metrics.yaml @@ -52,10 +52,9 @@ Resources: spill_bucket: !Ref SpillBucket spill_prefix: !Ref SpillPrefix FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.cloudwatch.metrics.MetricsCompositeHandler" - CodeUri: "./target/athena-cloudwatch-metrics-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudwatch-metrics:2022.47.1' Description: "Enables Amazon Athena to communicate with Cloudwatch Metrics, making your metrics data accessible via SQL" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-cloudwatch-metrics/pom.xml b/athena-cloudwatch-metrics/pom.xml index 6c8bff216e..b249525238 100644 --- a/athena-cloudwatch-metrics/pom.xml +++ b/athena-cloudwatch-metrics/pom.xml @@ -16,9 +16,9 @@ withdep - com.amazonaws - aws-java-sdk-cloudwatch - ${aws-sdk.version} + software.amazon.awssdk + cloudwatch + ${aws-sdk-v2.version} diff --git a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricStatSerDe.java b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricStatSerDe.java index 44bfcef8e0..e44c66e7f1 100644 --- a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricStatSerDe.java +++ b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricStatSerDe.java @@ -19,14 +19,14 @@ */ package com.amazonaws.athena.connectors.cloudwatch.metrics; -import com.amazonaws.services.cloudwatch.model.MetricStat; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.type.CollectionType; +import software.amazon.awssdk.services.cloudwatch.model.MetricStat; import java.io.IOException; import java.util.List; +import java.util.stream.Collectors; /** * Used to serialize and deserialize Cloudwatch Metrics MetricStat objects. 
This is used
@@ -48,7 +48,7 @@ private MetricStatSerDe() {}
     public static String serialize(List<MetricStat> metricStats)
     {
         try {
-            return mapper.writeValueAsString(new MetricStatHolder(metricStats));
+            return mapper.writeValueAsString(metricStats.stream().map(stat -> stat.toBuilder()).collect(Collectors.toList()));
         }
         catch (JsonProcessingException ex) {
             throw new RuntimeException(ex);
@@ -64,30 +64,11 @@ public static String serialize(List<MetricStat> metricStats)
     public static List<MetricStat> deserialize(String serializedMetricStats)
     {
         try {
-            return mapper.readValue(serializedMetricStats, MetricStatHolder.class).getMetricStats();
+            CollectionType metricStatBuilderCollection = mapper.getTypeFactory().constructCollectionType(List.class, MetricStat.serializableBuilderClass());
+            return ((List<MetricStat.Builder>) mapper.readValue(serializedMetricStats, metricStatBuilderCollection)).stream().map(stat -> stat.build()).collect(Collectors.toList());
         }
         catch (IOException ex) {
             throw new RuntimeException(ex);
         }
     }
-
-    /**
-     * Helper which allows us to use Jackson's Object Mapper to serialize a List of MetricStats.
-     */
-    private static class MetricStatHolder
-    {
-        private final List<MetricStat> metricStats;
-
-        @JsonCreator
-        public MetricStatHolder(@JsonProperty("metricStats") List<MetricStat> metricStats)
-        {
-            this.metricStats = metricStats;
-        }
-
-        @JsonProperty
-        public List<MetricStat> getMetricStats()
-        {
-            return metricStats;
-        }
-    }
 }
diff --git a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricUtils.java b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricUtils.java
index 40ebeacaeb..7c8b97aa90 100644
--- a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricUtils.java
+++ b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricUtils.java
@@ -26,15 +26,15 @@
 import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet;
 import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet;
 import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
-import com.amazonaws.services.cloudwatch.model.Dimension;
-import com.amazonaws.services.cloudwatch.model.DimensionFilter;
-import com.amazonaws.services.cloudwatch.model.GetMetricDataRequest;
-import com.amazonaws.services.cloudwatch.model.ListMetricsRequest;
-import com.amazonaws.services.cloudwatch.model.Metric;
-import com.amazonaws.services.cloudwatch.model.MetricDataQuery;
-import com.amazonaws.services.cloudwatch.model.MetricStat;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.cloudwatch.model.Dimension;
+import software.amazon.awssdk.services.cloudwatch.model.DimensionFilter;
+import software.amazon.awssdk.services.cloudwatch.model.GetMetricDataRequest;
+import software.amazon.awssdk.services.cloudwatch.model.ListMetricsRequest;
+import software.amazon.awssdk.services.cloudwatch.model.Metric;
+import software.amazon.awssdk.services.cloudwatch.model.MetricDataQuery;
+import software.amazon.awssdk.services.cloudwatch.model.MetricStat;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -70,11 +70,11 @@ private MetricUtils() {}
      */
     protected static boolean applyMetricConstraints(ConstraintEvaluator evaluator, Metric metric, String statistic)
     {
-        if (!evaluator.apply(NAMESPACE_FIELD, metric.getNamespace())) {
+        if (!evaluator.apply(NAMESPACE_FIELD, metric.namespace())) {
             return false;
         }
-        if (!evaluator.apply(METRIC_NAME_FIELD, metric.getMetricName())) {
+        if
(!evaluator.apply(METRIC_NAME_FIELD, metric.metricName())) { return false; } @@ -82,13 +82,13 @@ protected static boolean applyMetricConstraints(ConstraintEvaluator evaluator, M return false; } - for (Dimension next : metric.getDimensions()) { - if (evaluator.apply(DIMENSION_NAME_FIELD, next.getName()) && evaluator.apply(DIMENSION_VALUE_FIELD, next.getValue())) { + for (Dimension next : metric.dimensions()) { + if (evaluator.apply(DIMENSION_NAME_FIELD, next.name()) && evaluator.apply(DIMENSION_VALUE_FIELD, next.value())) { return true; } } - if (metric.getDimensions().isEmpty() && + if (metric.dimensions().isEmpty() && evaluator.apply(DIMENSION_NAME_FIELD, null) && evaluator.apply(DIMENSION_VALUE_FIELD, null)) { return true; @@ -100,28 +100,29 @@ protected static boolean applyMetricConstraints(ConstraintEvaluator evaluator, M /** * Attempts to push the supplied predicate constraints onto the Cloudwatch Metrics request. */ - protected static void pushDownPredicate(Constraints constraints, ListMetricsRequest listMetricsRequest) + protected static void pushDownPredicate(Constraints constraints, ListMetricsRequest.Builder listMetricsRequest) { Map summary = constraints.getSummary(); ValueSet namespaceConstraint = summary.get(NAMESPACE_FIELD); if (namespaceConstraint != null && namespaceConstraint.isSingleValue()) { - listMetricsRequest.setNamespace(namespaceConstraint.getSingleValue().toString()); + listMetricsRequest.namespace(namespaceConstraint.getSingleValue().toString()); } ValueSet metricConstraint = summary.get(METRIC_NAME_FIELD); if (metricConstraint != null && metricConstraint.isSingleValue()) { - listMetricsRequest.setMetricName(metricConstraint.getSingleValue().toString()); + listMetricsRequest.metricName(metricConstraint.getSingleValue().toString()); } ValueSet dimensionNameConstraint = summary.get(DIMENSION_NAME_FIELD); ValueSet dimensionValueConstraint = summary.get(DIMENSION_VALUE_FIELD); if (dimensionNameConstraint != null && dimensionNameConstraint.isSingleValue() && dimensionValueConstraint != null && dimensionValueConstraint.isSingleValue()) { - DimensionFilter filter = new DimensionFilter() - .withName(dimensionNameConstraint.getSingleValue().toString()) - .withValue(dimensionValueConstraint.getSingleValue().toString()); - listMetricsRequest.setDimensions(Collections.singletonList(filter)); + DimensionFilter filter = DimensionFilter.builder() + .name(dimensionNameConstraint.getSingleValue().toString()) + .value(dimensionValueConstraint.getSingleValue().toString()) + .build(); + listMetricsRequest.dimensions(Collections.singletonList(filter)); } } @@ -136,18 +137,15 @@ protected static GetMetricDataRequest makeGetMetricDataRequest(ReadRecordsReques Split split = readRecordsRequest.getSplit(); String serializedMetricStats = split.getProperty(MetricStatSerDe.SERIALIZED_METRIC_STATS_FIELD_NAME); List metricStats = MetricStatSerDe.deserialize(serializedMetricStats); - GetMetricDataRequest dataRequest = new GetMetricDataRequest(); - com.amazonaws.services.cloudwatch.model.Metric metric = new com.amazonaws.services.cloudwatch.model.Metric(); - metric.setNamespace(split.getProperty(NAMESPACE_FIELD)); - metric.setMetricName(split.getProperty(METRIC_NAME_FIELD)); + GetMetricDataRequest.Builder dataRequestBuilder = GetMetricDataRequest.builder(); List metricDataQueries = new ArrayList<>(); int metricId = 1; for (MetricStat nextMetricStat : metricStats) { - metricDataQueries.add(new MetricDataQuery().withMetricStat(nextMetricStat).withId("m" + metricId++)); + 
metricDataQueries.add(MetricDataQuery.builder().metricStat(nextMetricStat).id("m" + metricId++).build()); } - dataRequest.withMetricDataQueries(metricDataQueries); + dataRequestBuilder.metricDataQueries(metricDataQueries); ValueSet timeConstraint = readRecordsRequest.getConstraints().getSummary().get(TIMESTAMP_FIELD); if (timeConstraint instanceof SortedRangeSet && !timeConstraint.isNullAllowed()) { @@ -162,30 +160,30 @@ protected static GetMetricDataRequest makeGetMetricDataRequest(ReadRecordsReques Long lowerBound = (Long) basicPredicate.getLow().getValue(); //TODO: confirm timezone handling logger.info("makeGetMetricsRequest: with startTime " + (lowerBound * 1000) + " " + new Date(lowerBound * 1000)); - dataRequest.withStartTime(new Date(lowerBound * 1000)); + dataRequestBuilder.startTime(new Date(lowerBound * 1000).toInstant()); } else { //TODO: confirm timezone handling - dataRequest.withStartTime(new Date(0)); + dataRequestBuilder.startTime(new Date(0).toInstant()); } if (!basicPredicate.getHigh().isNullValue()) { Long upperBound = (Long) basicPredicate.getHigh().getValue(); //TODO: confirm timezone handling logger.info("makeGetMetricsRequest: with endTime " + (upperBound * 1000) + " " + new Date(upperBound * 1000)); - dataRequest.withEndTime(new Date(upperBound * 1000)); + dataRequestBuilder.endTime(new Date(upperBound * 1000).toInstant()); } else { //TODO: confirm timezone handling - dataRequest.withEndTime(new Date(System.currentTimeMillis())); + dataRequestBuilder.endTime(new Date(System.currentTimeMillis()).toInstant()); } } else { //TODO: confirm timezone handling - dataRequest.withStartTime(new Date(0)); - dataRequest.withEndTime(new Date(System.currentTimeMillis())); + dataRequestBuilder.startTime(new Date(0).toInstant()); + dataRequestBuilder.endTime(new Date(System.currentTimeMillis()).toInstant()); } - return dataRequest; + return dataRequestBuilder.build(); } } diff --git a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsExceptionFilter.java b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsExceptionFilter.java index 4810c6a017..1efb757f46 100644 --- a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsExceptionFilter.java +++ b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsExceptionFilter.java @@ -20,8 +20,8 @@ package com.amazonaws.athena.connectors.cloudwatch.metrics; import com.amazonaws.athena.connector.lambda.ThrottlingInvoker; -import com.amazonaws.services.cloudwatch.model.AmazonCloudWatchException; -import com.amazonaws.services.cloudwatch.model.LimitExceededException; +import software.amazon.awssdk.services.cloudwatch.model.CloudWatchException; +import software.amazon.awssdk.services.cloudwatch.model.LimitExceededException; /** * Used to identify Exceptions that are related to Cloudwatch Metrics throttling events. 
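// A minimal sketch (illustrative only, not part of this change set) of exercising the migrated
// filter that the hunk below rewrites: CloudWatchException is the SDK v2 replacement for v1's
// AmazonCloudWatchException, and like all v2 service exceptions it also carries structured
// error metadata via awsErrorDetails(). The "Throttling" error code shown here is an assumption
// for illustration; the actual filter below matches on message prefixes:
//
//     static boolean isThrottlingError(Exception ex)
//     {
//         if (!(ex instanceof CloudWatchException)) {
//             return false;
//         }
//         CloudWatchException cwe = (CloudWatchException) ex;
//         String message = cwe.getMessage() == null ? "" : cwe.getMessage();
//         return message.startsWith("Rate exceeded")
//                 || message.startsWith("Request has been throttled")
//                 || "Throttling".equals(cwe.awsErrorDetails().errorCode());
//     }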
@@ -36,11 +36,11 @@ private MetricsExceptionFilter() {} @Override public boolean isMatch(Exception ex) { - if (ex instanceof AmazonCloudWatchException && ex.getMessage().startsWith("Rate exceeded")) { + if (ex instanceof CloudWatchException && ex.getMessage().startsWith("Rate exceeded")) { return true; } - if (ex instanceof AmazonCloudWatchException && ex.getMessage().startsWith("Request has been throttled")) { + if (ex instanceof CloudWatchException && ex.getMessage().startsWith("Request has been throttled")) { return true; } diff --git a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsMetadataHandler.java b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsMetadataHandler.java index 866b465162..2b64e7c129 100644 --- a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsMetadataHandler.java +++ b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsMetadataHandler.java @@ -42,19 +42,18 @@ import com.amazonaws.athena.connectors.cloudwatch.metrics.tables.MetricSamplesTable; import com.amazonaws.athena.connectors.cloudwatch.metrics.tables.MetricsTable; import com.amazonaws.athena.connectors.cloudwatch.metrics.tables.Table; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder; -import com.amazonaws.services.cloudwatch.model.ListMetricsRequest; -import com.amazonaws.services.cloudwatch.model.ListMetricsResult; -import com.amazonaws.services.cloudwatch.model.Metric; -import com.amazonaws.services.cloudwatch.model.MetricStat; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.util.CollectionUtils; import com.google.common.collect.Lists; import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.cloudwatch.CloudWatchClient; +import software.amazon.awssdk.services.cloudwatch.model.ListMetricsRequest; +import software.amazon.awssdk.services.cloudwatch.model.ListMetricsResponse; +import software.amazon.awssdk.services.cloudwatch.model.Metric; +import software.amazon.awssdk.services.cloudwatch.model.MetricStat; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.utils.CollectionUtils; import java.util.ArrayList; import java.util.Collections; @@ -107,7 +106,7 @@ public class MetricsMetadataHandler //Used to handle throttling events by applying AIMD congestion control private final ThrottlingInvoker invoker; - private final AmazonCloudWatch metrics; + private final CloudWatchClient metrics; static { //The statistics supported by Cloudwatch Metrics by default @@ -133,16 +132,16 @@ public class MetricsMetadataHandler public MetricsMetadataHandler(java.util.Map configOptions) { super(SOURCE_TYPE, configOptions); - this.metrics = AmazonCloudWatchClientBuilder.standard().build(); + this.metrics = CloudWatchClient.create(); this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build(); } @VisibleForTesting protected MetricsMetadataHandler( - AmazonCloudWatch metrics, + CloudWatchClient metrics, EncryptionKeyFactory keyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + 
AthenaClient athena, String spillBucket, String spillPrefix, java.util.Map configOptions) @@ -235,33 +234,36 @@ public GetSplitsResponse doGetSplits(BlockAllocator blockAllocator, GetSplitsReq try (ConstraintEvaluator constraintEvaluator = new ConstraintEvaluator(blockAllocator, METRIC_DATA_TABLE.getSchema(), getSplitsRequest.getConstraints())) { - ListMetricsRequest listMetricsRequest = new ListMetricsRequest(); - MetricUtils.pushDownPredicate(getSplitsRequest.getConstraints(), listMetricsRequest); - listMetricsRequest.setNextToken(getSplitsRequest.getContinuationToken()); + ListMetricsRequest.Builder listMetricsRequestBuilder = ListMetricsRequest.builder(); + MetricUtils.pushDownPredicate(getSplitsRequest.getConstraints(), listMetricsRequestBuilder); + listMetricsRequestBuilder.nextToken(getSplitsRequest.getContinuationToken()); String period = getPeriodFromConstraint(getSplitsRequest.getConstraints()); Set splits = new HashSet<>(); - ListMetricsResult result = invoker.invoke(() -> metrics.listMetrics(listMetricsRequest)); + ListMetricsRequest listMetricsRequest = listMetricsRequestBuilder.build(); + ListMetricsResponse result = invoker.invoke(() -> metrics.listMetrics(listMetricsRequest)); List metricStats = new ArrayList<>(100); - for (Metric nextMetric : result.getMetrics()) { + for (Metric nextMetric : result.metrics()) { for (String nextStatistic : STATISTICS) { if (MetricUtils.applyMetricConstraints(constraintEvaluator, nextMetric, nextStatistic)) { - metricStats.add(new MetricStat() - .withMetric(new Metric() - .withNamespace(nextMetric.getNamespace()) - .withMetricName(nextMetric.getMetricName()) - .withDimensions(nextMetric.getDimensions())) - .withPeriod(Integer.valueOf(period)) - .withStat(nextStatistic)); + metricStats.add(MetricStat.builder() + .metric(Metric.builder() + .namespace(nextMetric.namespace()) + .metricName(nextMetric.metricName()) + .dimensions(nextMetric.dimensions()) + .build()) + .period(Integer.valueOf(period)) + .stat(nextStatistic) + .build()); } } } String continuationToken = null; - if (result.getNextToken() != null && - !result.getNextToken().equalsIgnoreCase(listMetricsRequest.getNextToken())) { - continuationToken = result.getNextToken(); + if (result.nextToken() != null && + !result.nextToken().equalsIgnoreCase(listMetricsRequest.nextToken())) { + continuationToken = result.nextToken(); } if (CollectionUtils.isNullOrEmpty(metricStats)) { diff --git a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandler.java b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandler.java index 93b18c62d3..3ca9219f96 100644 --- a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandler.java +++ b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandler.java @@ -29,29 +29,25 @@ import com.amazonaws.athena.connectors.cloudwatch.metrics.tables.MetricSamplesTable; import com.amazonaws.athena.connectors.cloudwatch.metrics.tables.MetricsTable; import com.amazonaws.athena.connectors.cloudwatch.metrics.tables.Table; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder; -import com.amazonaws.services.cloudwatch.model.Dimension; -import 
com.amazonaws.services.cloudwatch.model.GetMetricDataRequest; -import com.amazonaws.services.cloudwatch.model.GetMetricDataResult; -import com.amazonaws.services.cloudwatch.model.ListMetricsRequest; -import com.amazonaws.services.cloudwatch.model.ListMetricsResult; -import com.amazonaws.services.cloudwatch.model.Metric; -import com.amazonaws.services.cloudwatch.model.MetricDataQuery; -import com.amazonaws.services.cloudwatch.model.MetricDataResult; -import com.amazonaws.services.cloudwatch.model.MetricStat; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.cloudwatch.CloudWatchClient; +import software.amazon.awssdk.services.cloudwatch.model.Dimension; +import software.amazon.awssdk.services.cloudwatch.model.GetMetricDataRequest; +import software.amazon.awssdk.services.cloudwatch.model.GetMetricDataResponse; +import software.amazon.awssdk.services.cloudwatch.model.ListMetricsRequest; +import software.amazon.awssdk.services.cloudwatch.model.ListMetricsResponse; +import software.amazon.awssdk.services.cloudwatch.model.Metric; +import software.amazon.awssdk.services.cloudwatch.model.MetricDataQuery; +import software.amazon.awssdk.services.cloudwatch.model.MetricDataResult; +import software.amazon.awssdk.services.cloudwatch.model.MetricStat; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; -import java.util.Date; +import java.time.Instant; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -99,23 +95,23 @@ public class MetricsRecordHandler //Used to handle throttling events by applying AIMD congestion control private final ThrottlingInvoker invoker; - private final AmazonS3 amazonS3; - private final AmazonCloudWatch metrics; + private final S3Client amazonS3; + private final CloudWatchClient cloudwatchClient; public MetricsRecordHandler(java.util.Map configOptions) { - this(AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), - AmazonCloudWatchClientBuilder.standard().build(), configOptions); + this(S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), + CloudWatchClient.create(), configOptions); } @VisibleForTesting - protected MetricsRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, AmazonCloudWatch metrics, java.util.Map configOptions) + protected MetricsRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, CloudWatchClient metrics, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.amazonS3 = amazonS3; - this.metrics = metrics; + this.cloudwatchClient = metrics; this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions) .withInitialDelayMs(THROTTLING_INITIAL_DELAY) .withIncrease(THROTTLING_INCREMENTAL_INCREASE) @@ -146,37 +142,39 @@ else if (readRecordsRequest.getTableName().getTableName().equalsIgnoreCase(METRI private void readMetricsWithConstraint(BlockSpiller blockSpiller, 
ReadRecordsRequest request, QueryStatusChecker queryStatusChecker) throws TimeoutException { - ListMetricsRequest listMetricsRequest = new ListMetricsRequest(); - MetricUtils.pushDownPredicate(request.getConstraints(), listMetricsRequest); + ListMetricsRequest.Builder listMetricsRequestBuilder = ListMetricsRequest.builder(); + MetricUtils.pushDownPredicate(request.getConstraints(), listMetricsRequestBuilder); String prevToken; + String nextToken; Set requiredFields = new HashSet<>(); request.getSchema().getFields().stream().forEach(next -> requiredFields.add(next.getName())); ValueSet dimensionNameConstraint = request.getConstraints().getSummary().get(DIMENSION_NAME_FIELD); ValueSet dimensionValueConstraint = request.getConstraints().getSummary().get(DIMENSION_VALUE_FIELD); do { - prevToken = listMetricsRequest.getNextToken(); - ListMetricsResult result = invoker.invoke(() -> metrics.listMetrics(listMetricsRequest)); - for (Metric nextMetric : result.getMetrics()) { + ListMetricsRequest listMetricsRequest = listMetricsRequestBuilder.build(); + prevToken = listMetricsRequest.nextToken(); + ListMetricsResponse result = invoker.invoke(() -> cloudwatchClient.listMetrics(listMetricsRequest)); + for (Metric nextMetric : result.metrics()) { blockSpiller.writeRows((Block block, int row) -> { boolean matches = MetricUtils.applyMetricConstraints(blockSpiller.getConstraintEvaluator(), nextMetric, null); if (matches) { - matches &= block.offerValue(METRIC_NAME_FIELD, row, nextMetric.getMetricName()); - matches &= block.offerValue(NAMESPACE_FIELD, row, nextMetric.getNamespace()); + matches &= block.offerValue(METRIC_NAME_FIELD, row, nextMetric.metricName()); + matches &= block.offerValue(NAMESPACE_FIELD, row, nextMetric.namespace()); matches &= block.offerComplexValue(STATISTIC_FIELD, row, DEFAULT, STATISTICS); matches &= block.offerComplexValue(DIMENSIONS_FIELD, row, (Field field, Object val) -> { if (field.getName().equals(DIMENSION_NAME_FIELD)) { - return ((Dimension) val).getName(); + return ((Dimension) val).name(); } else if (field.getName().equals(DIMENSION_VALUE_FIELD)) { - return ((Dimension) val).getValue(); + return ((Dimension) val).value(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - nextMetric.getDimensions()); + nextMetric.dimensions()); //This field is 'faked' in that we just use it as a convenient way to filter single dimensions. As such //we always populate it with the value of the filter if the constraint passed and the filter was singleValue @@ -193,9 +191,10 @@ else if (field.getName().equals(DIMENSION_VALUE_FIELD)) { return matches ? 
1 : 0; }); } - listMetricsRequest.setNextToken(result.getNextToken()); + nextToken = result.nextToken(); + listMetricsRequestBuilder.nextToken(nextToken); } - while (listMetricsRequest.getNextToken() != null && !listMetricsRequest.getNextToken().equalsIgnoreCase(prevToken) && queryStatusChecker.isQueryRunning()); + while (nextToken != null && !nextToken.equalsIgnoreCase(prevToken) && queryStatusChecker.isQueryRunning()); } /** @@ -204,46 +203,49 @@ else if (field.getName().equals(DIMENSION_VALUE_FIELD)) { private void readMetricSamplesWithConstraint(BlockSpiller blockSpiller, ReadRecordsRequest request, QueryStatusChecker queryStatusChecker) throws TimeoutException { - GetMetricDataRequest dataRequest = MetricUtils.makeGetMetricDataRequest(request); + GetMetricDataRequest originalDataRequest = MetricUtils.makeGetMetricDataRequest(request); Map queries = new HashMap<>(); - for (MetricDataQuery query : dataRequest.getMetricDataQueries()) { - queries.put(query.getId(), query); + for (MetricDataQuery query : originalDataRequest.metricDataQueries()) { + queries.put(query.id(), query); } + GetMetricDataRequest.Builder dataRequestBuilder = originalDataRequest.toBuilder(); String prevToken; + String nextToken; ValueSet dimensionNameConstraint = request.getConstraints().getSummary().get(DIMENSION_NAME_FIELD); ValueSet dimensionValueConstraint = request.getConstraints().getSummary().get(DIMENSION_VALUE_FIELD); do { - prevToken = dataRequest.getNextToken(); - GetMetricDataResult result = invoker.invoke(() -> metrics.getMetricData(dataRequest)); - for (MetricDataResult nextMetric : result.getMetricDataResults()) { - MetricStat metricStat = queries.get(nextMetric.getId()).getMetricStat(); - List timestamps = nextMetric.getTimestamps(); - List values = nextMetric.getValues(); - for (int i = 0; i < nextMetric.getValues().size(); i++) { + GetMetricDataRequest dataRequest = dataRequestBuilder.build(); + prevToken = dataRequest.nextToken(); + GetMetricDataResponse result = invoker.invoke(() -> cloudwatchClient.getMetricData(dataRequest)); + for (MetricDataResult nextMetric : result.metricDataResults()) { + MetricStat metricStat = queries.get(nextMetric.id()).metricStat(); + List timestamps = nextMetric.timestamps(); + List values = nextMetric.values(); + for (int i = 0; i < nextMetric.values().size(); i++) { int sampleNum = i; blockSpiller.writeRows((Block block, int row) -> { /** * Most constraints were already applied at split generation so we only need to apply * a subset. */ - block.offerValue(METRIC_NAME_FIELD, row, metricStat.getMetric().getMetricName()); - block.offerValue(NAMESPACE_FIELD, row, metricStat.getMetric().getNamespace()); - block.offerValue(STATISTIC_FIELD, row, metricStat.getStat()); + block.offerValue(METRIC_NAME_FIELD, row, metricStat.metric().metricName()); + block.offerValue(NAMESPACE_FIELD, row, metricStat.metric().namespace()); + block.offerValue(STATISTIC_FIELD, row, metricStat.stat()); block.offerComplexValue(DIMENSIONS_FIELD, row, (Field field, Object val) -> { if (field.getName().equals(DIMENSION_NAME_FIELD)) { - return ((Dimension) val).getName(); + return ((Dimension) val).name(); } else if (field.getName().equals(DIMENSION_VALUE_FIELD)) { - return ((Dimension) val).getValue(); + return ((Dimension) val).value(); } throw new RuntimeException("Unexpected field " + field.getName()); }, - metricStat.getMetric().getDimensions()); + metricStat.metric().dimensions()); //This field is 'faked' in that we just use it as a convenient way to filter single dimensions. 
As such
//we always populate it with the value of the filter if the constraint passed and the filter was singleValue
@@ -257,19 +259,20 @@ else if (field.getName().equals(DIMENSION_VALUE_FIELD)) {
                             ? null : dimensionValueConstraint.getSingleValue().toString();
                     block.offerValue(DIMENSION_VALUE_FIELD, row, dimVal);
-                    block.offerValue(PERIOD_FIELD, row, metricStat.getPeriod());
+                    block.offerValue(PERIOD_FIELD, row, metricStat.period());
                     boolean matches = true;
                     block.offerValue(VALUE_FIELD, row, values.get(sampleNum));
-                    long timestamp = timestamps.get(sampleNum).getTime() / 1000;
+                    long timestamp = timestamps.get(sampleNum).getEpochSecond();
                     block.offerValue(TIMESTAMP_FIELD, row, timestamp);
                     return matches ? 1 : 0;
                 });
             }
         }
-        dataRequest.setNextToken(result.getNextToken());
+        nextToken = result.nextToken();
+        dataRequestBuilder.nextToken(result.nextToken());
     }
-    while (dataRequest.getNextToken() != null && !dataRequest.getNextToken().equalsIgnoreCase(prevToken) && queryStatusChecker.isQueryRunning());
+    while (nextToken != null && !nextToken.equalsIgnoreCase(prevToken) && queryStatusChecker.isQueryRunning());
 }
}
diff --git a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricStatSerDeTest.java b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricStatSerDeTest.java
index 63d15023bc..bfde6ac296 100644
--- a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricStatSerDeTest.java
+++ b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricStatSerDeTest.java
@@ -19,12 +19,12 @@
  */
 package com.amazonaws.athena.connectors.cloudwatch.metrics;
-import com.amazonaws.services.cloudwatch.model.Dimension;
-import com.amazonaws.services.cloudwatch.model.Metric;
-import com.amazonaws.services.cloudwatch.model.MetricStat;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.cloudwatch.model.Dimension;
+import software.amazon.awssdk.services.cloudwatch.model.Metric;
+import software.amazon.awssdk.services.cloudwatch.model.MetricStat;
 import java.util.ArrayList;
 import java.util.List;
@@ -34,8 +34,8 @@ public class MetricStatSerDeTest
 {
     private static final Logger logger = LoggerFactory.getLogger(MetricStatSerDeTest.class);
-    private static final String EXPECTED_SERIALIZATION = "{\"metricStats\":[{\"metric\":{\"namespace\":\"namespace\",\"metricName\":\"metricName\",\"dimensions\":[" +
-            "{\"name\":\"dim_name1\",\"value\":\"dim_value1\"},{\"name\":\"dim_name2\",\"value\":\"dim_value2\"}]},\"period\":60,\"stat\":\"p90\",\"unit\":null}]}";
+    private static final String EXPECTED_SERIALIZATION = "[{\"metric\":{\"namespace\":\"namespace\",\"metricName\":\"metricName\",\"dimensions\":[" +
+            "{\"name\":\"dim_name1\",\"value\":\"dim_value1\"},{\"name\":\"dim_name2\",\"value\":\"dim_value2\"}]},\"period\":60,\"stat\":\"p90\",\"unit\":null}]";
     @Test
     public void serializeTest()
@@ -48,17 +48,19 @@ public void serializeTest()
         String namespace = "namespace";
         List<Dimension> dimensions = new ArrayList<>();
-        dimensions.add(new Dimension().withName("dim_name1").withValue("dim_value1"));
-        dimensions.add(new Dimension().withName("dim_name2").withValue("dim_value2"));
+        dimensions.add(Dimension.builder().name("dim_name1").value("dim_value1").build());
+        dimensions.add(Dimension.builder().name("dim_name2").value("dim_value2").build());
         List<MetricStat> metricStats = new ArrayList<>();
-
metricStats.add(new MetricStat() - .withMetric(new Metric() - .withNamespace(namespace) - .withMetricName(metricName) - .withDimensions(dimensions)) - .withPeriod(60) - .withStat(statistic)); + metricStats.add(MetricStat.builder() + .metric(Metric.builder() + .namespace(namespace) + .metricName(metricName) + .dimensions(dimensions) + .build()) + .period(60) + .stat(statistic) + .build()); String actualSerialization = MetricStatSerDe.serialize(metricStats); logger.info("serializeTest: {}", actualSerialization); List actual = MetricStatSerDe.deserialize(actualSerialization); diff --git a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricUtilsTest.java b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricUtilsTest.java index 7929635f31..c32cd6cd5c 100644 --- a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricUtilsTest.java +++ b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricUtilsTest.java @@ -31,18 +31,18 @@ import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; -import com.amazonaws.services.cloudwatch.model.Dimension; -import com.amazonaws.services.cloudwatch.model.DimensionFilter; -import com.amazonaws.services.cloudwatch.model.GetMetricDataRequest; -import com.amazonaws.services.cloudwatch.model.ListMetricsRequest; -import com.amazonaws.services.cloudwatch.model.Metric; -import com.amazonaws.services.cloudwatch.model.MetricStat; import org.apache.arrow.vector.types.pojo.Schema; import com.google.common.collect.ImmutableList; import org.apache.arrow.vector.types.Types; import org.junit.After; import org.junit.Before; import org.junit.Test; +import software.amazon.awssdk.services.cloudwatch.model.Dimension; +import software.amazon.awssdk.services.cloudwatch.model.DimensionFilter; +import software.amazon.awssdk.services.cloudwatch.model.GetMetricDataRequest; +import software.amazon.awssdk.services.cloudwatch.model.ListMetricsRequest; +import software.amazon.awssdk.services.cloudwatch.model.Metric; +import software.amazon.awssdk.services.cloudwatch.model.MetricStat; import java.util.ArrayList; import java.util.Collections; @@ -100,33 +100,21 @@ public void applyMetricConstraints() ConstraintEvaluator constraintEvaluator = new ConstraintEvaluator(allocator, schema, new Constraints(constraintsMap, Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT)); - Metric metric = new Metric() - .withNamespace("match1") - .withMetricName("match2") - .withDimensions(new Dimension().withName("match4").withValue("match5")); + Metric metric = Metric.builder() + .namespace("match1") + .metricName("match2") + .dimensions(Dimension.builder().name("match4").value("match5").build()) + .build(); String statistic = "match3"; assertTrue(MetricUtils.applyMetricConstraints(constraintEvaluator, metric, statistic)); - assertFalse(MetricUtils.applyMetricConstraints(constraintEvaluator, copyMetric(metric).withNamespace("no_match"), statistic)); - assertFalse(MetricUtils.applyMetricConstraints(constraintEvaluator, copyMetric(metric).withMetricName("no_match"), statistic)); + assertFalse(MetricUtils.applyMetricConstraints(constraintEvaluator, metric.toBuilder().namespace("no_match").build(), statistic)); + 
assertFalse(MetricUtils.applyMetricConstraints(constraintEvaluator, metric.toBuilder().metricName("no_match").build(), statistic)); assertFalse(MetricUtils.applyMetricConstraints(constraintEvaluator, - copyMetric(metric).withDimensions(Collections.singletonList(new Dimension().withName("no_match").withValue("match5"))), statistic)); + metric.toBuilder().dimensions(Collections.singletonList(Dimension.builder().name("no_match").value("match5").build())).build(), statistic)); assertFalse(MetricUtils.applyMetricConstraints(constraintEvaluator, - copyMetric(metric).withDimensions(Collections.singletonList(new Dimension().withName("match4").withValue("no_match"))), statistic)); - assertFalse(MetricUtils.applyMetricConstraints(constraintEvaluator, copyMetric(metric), "no_match")); - } - - private Metric copyMetric(Metric metric) - { - Metric newMetric = new Metric() - .withNamespace(metric.getNamespace()) - .withMetricName(metric.getMetricName()); - - List dims = new ArrayList<>(); - for (Dimension next : metric.getDimensions()) { - dims.add(new Dimension().withName(next.getName()).withValue(next.getValue())); - } - return newMetric.withDimensions(dims); + metric.toBuilder().dimensions(Collections.singletonList(Dimension.builder().name("match4").value("no_match").build())).build(), statistic)); + assertFalse(MetricUtils.applyMetricConstraints(constraintEvaluator, metric, "no_match")); } @Test @@ -139,13 +127,14 @@ public void pushDownPredicate() constraintsMap.put(DIMENSION_NAME_FIELD, makeStringEquals(allocator, "match4")); constraintsMap.put(DIMENSION_VALUE_FIELD, makeStringEquals(allocator, "match5")); - ListMetricsRequest request = new ListMetricsRequest(); - MetricUtils.pushDownPredicate(new Constraints(constraintsMap, Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT), request); + ListMetricsRequest.Builder requestBuilder = ListMetricsRequest.builder(); + MetricUtils.pushDownPredicate(new Constraints(constraintsMap, Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT), requestBuilder); + ListMetricsRequest request = requestBuilder.build(); - assertEquals("match1", request.getNamespace()); - assertEquals("match2", request.getMetricName()); - assertEquals(1, request.getDimensions().size()); - assertEquals(new DimensionFilter().withName("match4").withValue("match5"), request.getDimensions().get(0)); + assertEquals("match1", request.namespace()); + assertEquals("match2", request.metricName()); + assertEquals(1, request.dimensions().size()); + assertEquals(DimensionFilter.builder().name("match4").value("match5").build(), request.dimensions().get(0)); } @Test @@ -159,17 +148,19 @@ public void makeGetMetricDataRequest() String namespace = "namespace"; List dimensions = new ArrayList<>(); - dimensions.add(new Dimension().withName("dim_name1").withValue("dim_value1")); - dimensions.add(new Dimension().withName("dim_name2").withValue("dim_value2")); + dimensions.add(Dimension.builder().name("dim_name1").value("dim_value1").build()); + dimensions.add(Dimension.builder().name("dim_name2").value("dim_value2").build()); List metricStats = new ArrayList<>(); - metricStats.add(new MetricStat() - .withMetric(new Metric() - .withNamespace(namespace) - .withMetricName(metricName) - .withDimensions(dimensions)) - .withPeriod(60) - .withStat(statistic)); + metricStats.add(MetricStat.builder() + .metric(Metric.builder() + .namespace(namespace) + .metricName(metricName) + .dimensions(dimensions) + .build()) + .period(60) + .stat(statistic) + .build()); Split split = 
Split.newBuilder(null, null) .add(NAMESPACE_FIELD, namespace) @@ -198,16 +189,16 @@ public void makeGetMetricDataRequest() ); GetMetricDataRequest actual = MetricUtils.makeGetMetricDataRequest(request); - assertEquals(1, actual.getMetricDataQueries().size()); - assertNotNull(actual.getMetricDataQueries().get(0).getId()); - MetricStat metricStat = actual.getMetricDataQueries().get(0).getMetricStat(); + assertEquals(1, actual.metricDataQueries().size()); + assertNotNull(actual.metricDataQueries().get(0).id()); + MetricStat metricStat = actual.metricDataQueries().get(0).metricStat(); assertNotNull(metricStat); - assertEquals(metricName, metricStat.getMetric().getMetricName()); - assertEquals(namespace, metricStat.getMetric().getNamespace()); - assertEquals(statistic, metricStat.getStat()); - assertEquals(period, metricStat.getPeriod()); - assertEquals(2, metricStat.getMetric().getDimensions().size()); - assertEquals(1000L, actual.getStartTime().getTime()); - assertTrue(actual.getStartTime().getTime() <= System.currentTimeMillis() + 1_000); + assertEquals(metricName, metricStat.metric().metricName()); + assertEquals(namespace, metricStat.metric().namespace()); + assertEquals(statistic, metricStat.stat()); + assertEquals(period, metricStat.period()); + assertEquals(2, metricStat.metric().dimensions().size()); + assertEquals(1000L, actual.startTime().toEpochMilli()); + assertTrue(actual.startTime().toEpochMilli() <= System.currentTimeMillis() + 1_000); } } diff --git a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsMetadataHandlerTest.java b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsMetadataHandlerTest.java index 0dcf33d5c1..a194c74185 100644 --- a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsMetadataHandlerTest.java +++ b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsMetadataHandlerTest.java @@ -43,12 +43,6 @@ import com.amazonaws.athena.connector.lambda.metadata.MetadataResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.model.ListMetricsRequest; -import com.amazonaws.services.cloudwatch.model.ListMetricsResult; -import com.amazonaws.services.cloudwatch.model.Metric; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.After; @@ -60,6 +54,12 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.cloudwatch.CloudWatchClient; +import software.amazon.awssdk.services.cloudwatch.model.ListMetricsRequest; +import software.amazon.awssdk.services.cloudwatch.model.ListMetricsResponse; +import software.amazon.awssdk.services.cloudwatch.model.Metric; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.Collections; @@ -92,13 +92,13 @@ public class MetricsMetadataHandlerTest private BlockAllocator allocator; @Mock - private AmazonCloudWatch mockMetrics; + private CloudWatchClient mockMetrics; @Mock - private 
AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Before public void setUp() @@ -273,17 +273,20 @@ public void doGetMetricSamplesSplits() ListMetricsRequest request = invocation.getArgument(0, ListMetricsRequest.class); //assert that the namespace filter was indeed pushed down - assertEquals(namespaceFilter, request.getNamespace()); - String nextToken = (request.getNextToken() == null) ? "valid" : null; + assertEquals(namespaceFilter, request.namespace()); + String nextToken = (request.nextToken() == null) ? "valid" : null; List metrics = new ArrayList<>(); for (int i = 0; i < numMetrics; i++) { //first page does not match constraints, but second page should - String mockNamespace = (request.getNextToken() == null) ? "NotMyNameSpace" : namespaceFilter; - metrics.add(new Metric().withNamespace(mockNamespace).withMetricName("metric-" + i)); + String mockNamespace = (request.nextToken() == null) ? "NotMyNameSpace" : namespaceFilter; + metrics.add(Metric.builder() + .namespace(mockNamespace) + .metricName("metric-" + i) + .build()); } - return new ListMetricsResult().withNextToken(nextToken).withMetrics(metrics); + return ListMetricsResponse.builder().nextToken(nextToken).metrics(metrics).build(); }); Schema schema = SchemaBuilder.newBuilder().addIntField("partitionId").build(); @@ -356,9 +359,12 @@ public void doGetMetricSamplesSplitsEmptyMetrics() when(mockMetrics.listMetrics(nullable(ListMetricsRequest.class))).thenAnswer((InvocationOnMock invocation) -> { List metrics = new ArrayList<>(); for (int i = 0; i < numMetrics; i++) { - metrics.add(new Metric().withNamespace(namespace).withMetricName("metric-" + i)); + metrics.add(Metric.builder() + .namespace(namespace) + .metricName("metric-" + i) + .build()); } - return new ListMetricsResult().withNextToken(null).withMetrics(metrics); + return ListMetricsResponse.builder().nextToken(null).metrics(metrics).build(); }); Schema schema = SchemaBuilder.newBuilder().addIntField("partitionId").build(); diff --git a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandlerTest.java b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandlerTest.java index bf90e3134a..8b50b97881 100644 --- a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandlerTest.java +++ b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandlerTest.java @@ -37,23 +37,6 @@ import com.amazonaws.athena.connectors.cloudwatch.metrics.tables.MetricSamplesTable; import com.amazonaws.athena.connectors.cloudwatch.metrics.tables.MetricsTable; import com.amazonaws.athena.connectors.cloudwatch.metrics.tables.Table; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.model.Dimension; -import com.amazonaws.services.cloudwatch.model.GetMetricDataRequest; -import com.amazonaws.services.cloudwatch.model.GetMetricDataResult; -import com.amazonaws.services.cloudwatch.model.ListMetricsRequest; -import com.amazonaws.services.cloudwatch.model.ListMetricsResult; -import com.amazonaws.services.cloudwatch.model.Metric; -import com.amazonaws.services.cloudwatch.model.MetricDataQuery; -import com.amazonaws.services.cloudwatch.model.MetricDataResult; -import 
com.amazonaws.services.cloudwatch.model.MetricStat; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.io.ByteStreams; import org.junit.After; import org.junit.Before; @@ -65,9 +48,29 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.cloudwatch.CloudWatchClient; +import software.amazon.awssdk.services.cloudwatch.model.Dimension; +import software.amazon.awssdk.services.cloudwatch.model.GetMetricDataRequest; +import software.amazon.awssdk.services.cloudwatch.model.GetMetricDataResponse; +import software.amazon.awssdk.services.cloudwatch.model.ListMetricsRequest; +import software.amazon.awssdk.services.cloudwatch.model.ListMetricsResponse; +import software.amazon.awssdk.services.cloudwatch.model.Metric; +import software.amazon.awssdk.services.cloudwatch.model.MetricDataQuery; +import software.amazon.awssdk.services.cloudwatch.model.MetricDataResult; +import software.amazon.awssdk.services.cloudwatch.model.MetricStat; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; import java.io.InputStream; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.Date; @@ -112,16 +115,16 @@ public class MetricsRecordHandlerTest private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @Mock - private AmazonCloudWatch mockMetrics; + private CloudWatchClient mockMetrics; @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Before public void setUp() @@ -132,31 +135,27 @@ public void setUp() handler = new MetricsRecordHandler(mockS3, mockSecretsManager, mockAthena, mockMetrics, com.google.common.collect.ImmutableMap.of()); spillReader = new S3BlockSpillReader(mockS3, allocator); - Mockito.lenient().when(mockS3.putObject(any())) + Mockito.lenient().when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("putObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - 
Mockito.lenient().when(mockS3.getObject(nullable(String.class), nullable(String.class))) + Mockito.lenient().when(mockS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); } @@ -183,17 +182,23 @@ public void readMetricsWithConstraint() ListMetricsRequest request = invocation.getArgument(0, ListMetricsRequest.class); numCalls.incrementAndGet(); //assert that the namespace filter was indeed pushed down - assertEquals(namespace, request.getNamespace()); - String nextToken = (request.getNextToken() == null) ? "valid" : null; + assertEquals(namespace, request.namespace()); + String nextToken = (request.nextToken() == null) ? "valid" : null; List metrics = new ArrayList<>(); for (int i = 0; i < numMetrics; i++) { - metrics.add(new Metric().withNamespace(namespace).withMetricName("metric-" + i) - .withDimensions(new Dimension().withName(dimName).withValue(dimValue))); - metrics.add(new Metric().withNamespace(namespace + i).withMetricName("metric-" + i)); + metrics.add(Metric.builder() + .namespace(namespace) + .metricName("metric-" + i) + .dimensions(Dimension.builder() + .name(dimName) + .value(dimValue) + .build()) + .build()); + metrics.add(Metric.builder().namespace(namespace + i).metricName("metric-" + i).build()); } - return new ListMetricsResult().withNextToken(nextToken).withMetrics(metrics); + return ListMetricsResponse.builder().nextToken(nextToken).metrics(metrics).build(); }); Map constraintsMap = new HashMap<>(); @@ -246,7 +251,7 @@ public void readMetricSamplesWithConstraint() String period = "60"; String dimName = "dimName"; String dimValue = "dimValue"; - List dimensions = Collections.singletonList(new Dimension().withName(dimName).withValue(dimValue)); + List dimensions = Collections.singletonList(Dimension.builder().name(dimName).value(dimValue).build()); int numMetrics = 10; int numSamples = 10; @@ -270,13 +275,15 @@ public void readMetricSamplesWithConstraint() .build(); List metricStats = new ArrayList<>(); - metricStats.add(new MetricStat() - .withMetric(new Metric() - .withNamespace(namespace) - .withMetricName(metricName) - .withDimensions(dimensions)) - .withPeriod(60) - .withStat(statistic)); + metricStats.add(MetricStat.builder() + .metric(Metric.builder() + .namespace(namespace) + .metricName(metricName) + .dimensions(dimensions) + .build()) + .period(60) + .stat(statistic) + .build()); Split split = Split.newBuilder(spillLocation, keyFactory.create()) .add(MetricStatSerDe.SERIALIZED_METRIC_STATS_FIELD_NAME, MetricStatSerDe.serialize(metricStats)) @@ -310,40 +317,40 @@ public void readMetricSamplesWithConstraint() logger.info("readMetricSamplesWithConstraint: exit"); } - private GetMetricDataResult mockMetricData(InvocationOnMock invocation, int numMetrics, int numSamples) + private GetMetricDataResponse mockMetricData(InvocationOnMock invocation, int numMetrics, int numSamples) { GetMetricDataRequest request = invocation.getArgument(0, GetMetricDataRequest.class); /** * Confirm that all available criteria were pushed down into 
Cloudwatch Metrics */ - List queries = request.getMetricDataQueries(); + List queries = request.metricDataQueries(); assertEquals(1, queries.size()); MetricDataQuery query = queries.get(0); - MetricStat stat = query.getMetricStat(); - assertEquals("m1", query.getId()); - assertNotNull(stat.getPeriod()); - assertNotNull(stat.getMetric()); - assertNotNull(stat.getStat()); - assertNotNull(stat.getMetric().getMetricName()); - assertNotNull(stat.getMetric().getNamespace()); - assertNotNull(stat.getMetric().getDimensions()); - assertEquals(1, stat.getMetric().getDimensions().size()); - - String nextToken = (request.getNextToken() == null) ? "valid" : null; + MetricStat stat = query.metricStat(); + assertEquals("m1", query.id()); + assertNotNull(stat.period()); + assertNotNull(stat.metric()); + assertNotNull(stat.stat()); + assertNotNull(stat.metric().metricName()); + assertNotNull(stat.metric().namespace()); + assertNotNull(stat.metric().dimensions()); + assertEquals(1, stat.metric().dimensions().size()); + + String nextToken = (request.nextToken() == null) ? "valid" : null; List samples = new ArrayList<>(); for (int i = 0; i < numMetrics; i++) { List values = new ArrayList<>(); - List timestamps = new ArrayList<>(); + List timestamps = new ArrayList<>(); for (double j = 0; j < numSamples; j++) { values.add(j); - timestamps.add(new Date(System.currentTimeMillis() + (int) j)); + timestamps.add(new Date(System.currentTimeMillis() + (int) j).toInstant()); } - samples.add(new MetricDataResult().withValues(values).withTimestamps(timestamps).withId("m1")); + samples.add(MetricDataResult.builder().values(values).timestamps(timestamps).id("m1").build()); } - return new GetMetricDataResult().withNextToken(nextToken).withMetricDataResults(samples); + return GetMetricDataResponse.builder().nextToken(nextToken).metricDataResults(samples).build(); } private class ByteHolder diff --git a/athena-cloudwatch/Dockerfile b/athena-cloudwatch/Dockerfile new file mode 100644 index 0000000000..9859ff8b4c --- /dev/null +++ b/athena-cloudwatch/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-cloudwatch-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-cloudwatch-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.cloudwatch.CloudwatchCompositeHandler" ] \ No newline at end of file diff --git a/athena-cloudwatch/athena-cloudwatch.yaml b/athena-cloudwatch/athena-cloudwatch.yaml index a5d69fbb2e..2e301dc882 100644 --- a/athena-cloudwatch/athena-cloudwatch.yaml +++ b/athena-cloudwatch/athena-cloudwatch.yaml @@ -66,10 +66,9 @@ Resources: spill_prefix: !Ref SpillPrefix kms_key_id: !If [HasKMSKeyId, !Ref KMSKeyId, !Ref "AWS::NoValue"] FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.cloudwatch.CloudwatchCompositeHandler" - CodeUri: "./target/athena-cloudwatch-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudwatch:2022.47.1' Description: "Enables Amazon Athena to communicate with Cloudwatch, making your log accessible via SQL" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory Role: !If [NotHasLambdaRole, !GetAtt FunctionRole.Arn, !Ref LambdaRole] diff --git a/athena-cloudwatch/pom.xml b/athena-cloudwatch/pom.xml index bd2dad00d8..95a34d1d37 100644 --- 
a/athena-cloudwatch/pom.xml +++ b/athena-cloudwatch/pom.xml @@ -29,15 +29,35 @@ test - com.amazonaws - aws-java-sdk-logs - ${aws-sdk.version} + software.amazon.awssdk + cloudwatchlogs + ${aws-sdk-v2.version} + + + + commons-logging + commons-logging + + + software.amazon.awssdk + netty-nio-client + + + + + software.amazon.awssdk + cloudwatch + ${aws-sdk-v2.version} commons-logging commons-logging + + software.amazon.awssdk + netty-nio-client + diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchExceptionFilter.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchExceptionFilter.java index c71db552cf..093aeedd7e 100644 --- a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchExceptionFilter.java +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchExceptionFilter.java @@ -20,8 +20,8 @@ package com.amazonaws.athena.connectors.cloudwatch; import com.amazonaws.athena.connector.lambda.ThrottlingInvoker; -import com.amazonaws.services.logs.model.AWSLogsException; -import com.amazonaws.services.logs.model.LimitExceededException; +import software.amazon.awssdk.services.cloudwatchlogs.model.CloudWatchLogsException; +import software.amazon.awssdk.services.cloudwatchlogs.model.LimitExceededException; /** * Used to identify Exceptions that are related to Cloudwatch Logs throttling events. @@ -36,7 +36,7 @@ private CloudwatchExceptionFilter() {} @Override public boolean isMatch(Exception ex) { - if (ex instanceof AWSLogsException && ex.getMessage().startsWith("Rate exceeded")) { + if (ex instanceof CloudWatchLogsException && ex.getMessage().startsWith("Rate exceeded")) { return true; } diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandler.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandler.java index cd52e12683..e62ca50477 100644 --- a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandler.java +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandler.java @@ -43,17 +43,6 @@ import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connectors.cloudwatch.qpt.CloudwatchQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.logs.AWSLogs; -import com.amazonaws.services.logs.AWSLogsClientBuilder; -import com.amazonaws.services.logs.model.DescribeLogGroupsRequest; -import com.amazonaws.services.logs.model.DescribeLogGroupsResult; -import com.amazonaws.services.logs.model.DescribeLogStreamsRequest; -import com.amazonaws.services.logs.model.DescribeLogStreamsResult; -import com.amazonaws.services.logs.model.GetQueryResultsResult; -import com.amazonaws.services.logs.model.LogStream; -import com.amazonaws.services.logs.model.ResultField; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -62,6 +51,16 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import 
software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogGroupsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogGroupsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogStreamsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogStreamsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetQueryResultsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.LogStream; +import software.amazon.awssdk.services.cloudwatchlogs.model.ResultField; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.Collections; @@ -123,7 +122,7 @@ public class CloudwatchMetadataHandler .build(); } - private final AWSLogs awsLogs; + private final CloudWatchLogsClient awsLogs; private final ThrottlingInvoker invoker; private final CloudwatchTableResolver tableResolver; private final CloudwatchQueryPassthrough queryPassthrough = new CloudwatchQueryPassthrough(); @@ -131,17 +130,17 @@ public class CloudwatchMetadataHandler public CloudwatchMetadataHandler(java.util.Map configOptions) { super(SOURCE_TYPE, configOptions); - this.awsLogs = AWSLogsClientBuilder.standard().build(); + this.awsLogs = CloudWatchLogsClient.create(); this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build(); - this.tableResolver = new CloudwatchTableResolver(this.invoker, awsLogs, MAX_RESULTS, MAX_RESULTS); + this.tableResolver = new CloudwatchTableResolver(this.invoker, awsLogs, MAX_RESULTS, MAX_RESULTS); } @VisibleForTesting protected CloudwatchMetadataHandler( - AWSLogs awsLogs, + CloudWatchLogsClient awsLogs, EncryptionKeyFactory keyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, String spillBucket, String spillPrefix, java.util.Map configOptions) @@ -161,19 +160,19 @@ protected CloudwatchMetadataHandler( public ListSchemasResponse doListSchemaNames(BlockAllocator blockAllocator, ListSchemasRequest listSchemasRequest) throws TimeoutException { - DescribeLogGroupsRequest request = new DescribeLogGroupsRequest(); - DescribeLogGroupsResult result; + DescribeLogGroupsRequest.Builder requestBuilder = DescribeLogGroupsRequest.builder(); + DescribeLogGroupsResponse response; List schemas = new ArrayList<>(); do { if (schemas.size() > MAX_RESULTS) { throw new RuntimeException("Too many log groups, exceeded max metadata results for schema count."); } - result = invoker.invoke(() -> awsLogs.describeLogGroups(request)); - result.getLogGroups().forEach(next -> schemas.add(next.getLogGroupName())); - request.setNextToken(result.getNextToken()); - logger.info("doListSchemaNames: Listing log groups {} {}", result.getNextToken(), schemas.size()); + response = invoker.invoke(() -> awsLogs.describeLogGroups(requestBuilder.build())); + response.logGroups().forEach(next -> schemas.add(next.logGroupName())); + requestBuilder.nextToken(response.nextToken()); + logger.info("doListSchemaNames: Listing log groups {} {}", response.nextToken(), schemas.size()); } - while (result.getNextToken() != null); + while (response.nextToken() != null); return new ListSchemasResponse(listSchemasRequest.getCatalogName(), schemas); } @@ -189,28 +188,28 @@ public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTables { String nextToken = null; String logGroupName = 
tableResolver.validateSchema(listTablesRequest.getSchemaName()); - DescribeLogStreamsRequest request = new DescribeLogStreamsRequest(logGroupName); - DescribeLogStreamsResult result; + DescribeLogStreamsRequest.Builder requestBuilder = DescribeLogStreamsRequest.builder().logGroupName(logGroupName); + DescribeLogStreamsResponse response; List tables = new ArrayList<>(); if (listTablesRequest.getPageSize() == UNLIMITED_PAGE_SIZE_VALUE) { do { if (tables.size() > MAX_RESULTS) { throw new RuntimeException("Too many log streams, exceeded max metadata results for table count."); } - result = invoker.invoke(() -> awsLogs.describeLogStreams(request)); - result.getLogStreams().forEach(next -> tables.add(toTableName(listTablesRequest, next))); - request.setNextToken(result.getNextToken()); - logger.info("doListTables: Listing log streams with token {} and size {}", result.getNextToken(), tables.size()); + response = invoker.invoke(() -> awsLogs.describeLogStreams(requestBuilder.build())); + response.logStreams().forEach(next -> tables.add(toTableName(listTablesRequest, next))); + requestBuilder.nextToken(response.nextToken()); + logger.info("doListTables: Listing log streams with token {} and size {}", response.nextToken(), tables.size()); } - while (result.getNextToken() != null); + while (response.nextToken() != null); } else { - request.setNextToken(listTablesRequest.getNextToken()); - request.setLimit(listTablesRequest.getPageSize()); - result = invoker.invoke(() -> awsLogs.describeLogStreams(request)); - result.getLogStreams().forEach(next -> tables.add(toTableName(listTablesRequest, next))); - nextToken = result.getNextToken(); - logger.info("doListTables: Listing log streams with token {} and size {}", result.getNextToken(), tables.size()); + requestBuilder.nextToken(listTablesRequest.getNextToken()); + requestBuilder.limit(listTablesRequest.getPageSize()); + response = invoker.invoke(() -> awsLogs.describeLogStreams(requestBuilder.build())); + response.logStreams().forEach(next -> tables.add(toTableName(listTablesRequest, next))); + nextToken = response.nextToken(); + logger.info("doListTables: Listing log streams with token {} and size {}", response.nextToken(), tables.size()); } // Don't add the ALL_LOG_STREAMS_TABLE unless we're at the end of listing out all the tables. @@ -276,26 +275,26 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request CloudwatchTableName cwTableName = tableResolver.validateTable(request.getTableName()); - DescribeLogStreamsRequest cwRequest = new DescribeLogStreamsRequest(cwTableName.getLogGroupName()); + DescribeLogStreamsRequest.Builder cwRequestBuilder = DescribeLogStreamsRequest.builder().logGroupName(cwTableName.getLogGroupName()); if (!ALL_LOG_STREAMS_TABLE.equals(cwTableName.getLogStreamName())) { - cwRequest.setLogStreamNamePrefix(cwTableName.getLogStreamName()); + cwRequestBuilder.logStreamNamePrefix(cwTableName.getLogStreamName()); } - DescribeLogStreamsResult result; + DescribeLogStreamsResponse response; do { - result = invoker.invoke(() -> awsLogs.describeLogStreams(cwRequest)); - for (LogStream next : result.getLogStreams()) { + response = invoker.invoke(() -> awsLogs.describeLogStreams(cwRequestBuilder.build())); + for (LogStream next : response.logStreams()) { //Each log stream that matches any possible partition pruning should be added to the partition list. 
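// A possible alternative to the hand-rolled do/while pagination in this method (a sketch,
// not part of this patch, assuming the standard SDK v2 generated paginator API):
//
//     awsLogs.describeLogStreamsPaginator(cwRequestBuilder.build())
//             .logStreams()
//             .forEach(stream -> logger.info("Found log stream {}", stream.logStreamName()));
//
// The explicit loop is kept here so each page fetch can be wrapped by the ThrottlingInvoker.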
blockWriter.writeRows((Block block, int rowNum) -> { - boolean matched = block.setValue(LOG_GROUP_FIELD, rowNum, cwRequest.getLogGroupName()); - matched &= block.setValue(LOG_STREAM_FIELD, rowNum, next.getLogStreamName()); - matched &= block.setValue(LOG_STREAM_SIZE_FIELD, rowNum, next.getStoredBytes()); + boolean matched = block.setValue(LOG_GROUP_FIELD, rowNum, cwTableName.getLogGroupName()); + matched &= block.setValue(LOG_STREAM_FIELD, rowNum, next.logStreamName()); + matched &= block.setValue(LOG_STREAM_SIZE_FIELD, rowNum, next.storedBytes()); return matched ? 1 : 0; }); } - cwRequest.setNextToken(result.getNextToken()); + cwRequestBuilder.nextToken(response.nextToken()); } - while (result.getNextToken() != null && queryStatusChecker.isQueryRunning()); + while (response.nextToken() != null && queryStatusChecker.isQueryRunning()); } /** @@ -367,11 +366,11 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge throw new IllegalArgumentException("No Query passed through [" + request + "]"); } // to get column names with limit 1 - GetQueryResultsResult getQueryResultsResult = getResult(invoker, awsLogs, request.getQueryPassthroughArguments(), 1); + GetQueryResultsResponse getQueryResultsResponse = getResult(invoker, awsLogs, request.getQueryPassthroughArguments(), 1); SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); - if (!getQueryResultsResult.getResults().isEmpty()) { - for (ResultField field : getQueryResultsResult.getResults().get(0)) { - schemaBuilder.addField(field.getField(), Types.MinorType.VARCHAR.getType()); + if (!getQueryResultsResponse.results().isEmpty()) { + for (ResultField field : getQueryResultsResponse.results().get(0)) { + schemaBuilder.addField(field.field(), Types.MinorType.VARCHAR.getType()); } } @@ -415,6 +414,6 @@ private String encodeContinuationToken(int partition) */ private TableName toTableName(ListTablesRequest request, LogStream logStream) { - return new TableName(request.getSchemaName(), logStream.getLogStreamName()); + return new TableName(request.getSchemaName(), logStream.logStreamName()); } } diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java index a5d29f0f9b..912b94d218 100644 --- a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java @@ -32,22 +32,18 @@ import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.cloudwatch.qpt.CloudwatchQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.logs.AWSLogs; -import com.amazonaws.services.logs.AWSLogsClientBuilder; -import com.amazonaws.services.logs.model.GetLogEventsRequest; -import com.amazonaws.services.logs.model.GetLogEventsResult; -import com.amazonaws.services.logs.model.GetQueryResultsResult; -import com.amazonaws.services.logs.model.OutputLogEvent; -import com.amazonaws.services.logs.model.ResultField; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import 
com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetLogEventsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetLogEventsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetQueryResultsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.OutputLogEvent; +import software.amazon.awssdk.services.cloudwatchlogs.model.ResultField; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.List; import java.util.Map; @@ -78,21 +74,21 @@ public class CloudwatchRecordHandler //Used to handle Throttling events and apply AIMD congestion control private final ThrottlingInvoker invoker; private final AtomicLong count = new AtomicLong(0); - private final AWSLogs awsLogs; + private final CloudWatchLogsClient awsLogs; private final CloudwatchQueryPassthrough queryPassthrough = new CloudwatchQueryPassthrough(); public CloudwatchRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), - AWSLogsClientBuilder.defaultClient(), + S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), + CloudWatchLogsClient.create(), configOptions); } @VisibleForTesting - protected CloudwatchRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, AWSLogs awsLogs, java.util.Map configOptions) + protected CloudwatchRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, CloudWatchLogsClient awsLogs, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.awsLogs = awsLogs; @@ -118,37 +114,38 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor invoker.setBlockSpiller(spiller); do { final String actualContinuationToken = continuationToken; - GetLogEventsResult logEventsResult = invoker.invoke(() -> awsLogs.getLogEvents( + GetLogEventsResponse logEventsResponse = invoker.invoke(() -> awsLogs.getLogEvents( pushDownConstraints(recordsRequest.getConstraints(), - new GetLogEventsRequest() - .withLogGroupName(split.getProperty(LOG_GROUP_FIELD)) + GetLogEventsRequest.builder() + .logGroupName(split.getProperty(LOG_GROUP_FIELD)) //We use the property instead of the table name because of the special all_streams table - .withLogStreamName(split.getProperty(LOG_STREAM_FIELD)) - .withNextToken(actualContinuationToken) + .logStreamName(split.getProperty(LOG_STREAM_FIELD)) + .nextToken(actualContinuationToken) // must be set to use nextToken correctly - .withStartFromHead(true) + .startFromHead(true) + .build() ))); - if (continuationToken == null || !continuationToken.equals(logEventsResult.getNextForwardToken())) { - continuationToken = logEventsResult.getNextForwardToken(); + if (continuationToken == null || !continuationToken.equals(logEventsResponse.nextForwardToken())) { + continuationToken = logEventsResponse.nextForwardToken(); } else { continuationToken = null; } - for (OutputLogEvent ole : logEventsResult.getEvents()) { + for (OutputLogEvent ole : logEventsResponse.events()) 
{ spiller.writeRows((Block block, int rowNum) -> { boolean matched = true; matched &= block.offerValue(LOG_STREAM_FIELD, rowNum, split.getProperty(LOG_STREAM_FIELD)); - matched &= block.offerValue(LOG_TIME_FIELD, rowNum, ole.getTimestamp()); - matched &= block.offerValue(LOG_MSG_FIELD, rowNum, ole.getMessage()); + matched &= block.offerValue(LOG_TIME_FIELD, rowNum, ole.timestamp()); + matched &= block.offerValue(LOG_MSG_FIELD, rowNum, ole.message()); return matched ? 1 : 0; }); } logger.info("readWithConstraint: LogGroup[{}] LogStream[{}] Continuation[{}] rows[{}]", tableName.getSchemaName(), tableName.getTableName(), continuationToken, - logEventsResult.getEvents().size()); + logEventsResponse.events().size()); } while (continuationToken != null && queryStatusChecker.isQueryRunning()); } @@ -158,13 +155,13 @@ private void getQueryPassthreoughResults(BlockSpiller spiller, ReadRecordsReques { Map qptArguments = recordsRequest.getConstraints().getQueryPassthroughArguments(); queryPassthrough.verify(qptArguments); - GetQueryResultsResult getQueryResultsResult = getResult(invoker, awsLogs, qptArguments, Integer.parseInt(qptArguments.get(CloudwatchQueryPassthrough.LIMIT))); + GetQueryResultsResponse getQueryResultsResponse = getResult(invoker, awsLogs, qptArguments, Integer.parseInt(qptArguments.get(CloudwatchQueryPassthrough.LIMIT))); - for (List resultList : getQueryResultsResult.getResults()) { + for (List resultList : getQueryResultsResponse.results()) { spiller.writeRows((Block block, int rowNum) -> { for (ResultField resultField : resultList) { boolean matched = true; - matched &= block.offerValue(resultField.getField(), rowNum, resultField.getValue()); + matched &= block.offerValue(resultField.field(), rowNum, resultField.value()); if (!matched) { return 0; } @@ -184,6 +181,7 @@ private void getQueryPassthreoughResults(BlockSpiller spiller, ReadRecordsReques */ private GetLogEventsRequest pushDownConstraints(Constraints constraints, GetLogEventsRequest request) { + GetLogEventsRequest.Builder requestBuilder = request.toBuilder(); ValueSet timeConstraint = constraints.getSummary().get(LOG_TIME_FIELD); if (timeConstraint instanceof SortedRangeSet && !timeConstraint.isNullAllowed()) { //SortedRangeSet is how >, <, between is represented which are easiest and most common when @@ -195,15 +193,15 @@ private GetLogEventsRequest pushDownConstraints(Constraints constraints, GetLogE if (!basicPredicate.getLow().isNullValue()) { Long lowerBound = (Long) basicPredicate.getLow().getValue(); - request.setStartTime(lowerBound); + requestBuilder.startTime(lowerBound); } if (!basicPredicate.getHigh().isNullValue()) { Long upperBound = (Long) basicPredicate.getHigh().getValue(); - request.setEndTime(upperBound); + requestBuilder.endTime(upperBound); } } - return request; + return requestBuilder.build(); } } diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchTableResolver.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchTableResolver.java index 4c7f25ec7e..d4059b0438 100644 --- a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchTableResolver.java +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchTableResolver.java @@ -21,18 +21,18 @@ import com.amazonaws.athena.connector.lambda.ThrottlingInvoker; import com.amazonaws.athena.connector.lambda.domain.TableName; -import com.amazonaws.services.logs.AWSLogs; -import 
com.amazonaws.services.logs.model.DescribeLogGroupsRequest; -import com.amazonaws.services.logs.model.DescribeLogGroupsResult; -import com.amazonaws.services.logs.model.DescribeLogStreamsRequest; -import com.amazonaws.services.logs.model.DescribeLogStreamsResult; -import com.amazonaws.services.logs.model.LogGroup; -import com.amazonaws.services.logs.model.LogStream; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogGroupsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogGroupsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogStreamsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogStreamsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.LogGroup; +import software.amazon.awssdk.services.cloudwatchlogs.model.LogStream; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; @@ -51,7 +51,7 @@ public class CloudwatchTableResolver { private static final Logger logger = LoggerFactory.getLogger(CloudwatchTableResolver.class); - private AWSLogs awsLogs; + private CloudWatchLogsClient logsClient; //Used to handle Throttling events using an AIMD strategy for congestion control. private ThrottlingInvoker invoker; //The LogStream pattern that is capitalized by LAMBDA @@ -67,14 +67,14 @@ public class CloudwatchTableResolver * Constructs an instance of the table resolver. * * @param invoker The ThrottlingInvoker to use to handle throttling events. - * @param awsLogs The AWSLogs client to use for cache misses. + * @param logsClient The CloudWatchLogsClient to use for cache misses. * @param maxSchemaCacheSize The max number of schemas to cache. * @param maxTableCacheSize The max tables to cache. 
*/ - public CloudwatchTableResolver(ThrottlingInvoker invoker, AWSLogs awsLogs, long maxSchemaCacheSize, long maxTableCacheSize) + public CloudwatchTableResolver(ThrottlingInvoker invoker, CloudWatchLogsClient logsClient, long maxSchemaCacheSize, long maxTableCacheSize) { this.invoker = invoker; - this.awsLogs = awsLogs; + this.logsClient = logsClient; this.tableCache = CacheBuilder.newBuilder() .maximumSize(maxTableCacheSize) .build( @@ -119,12 +119,12 @@ private CloudwatchTableName loadLogStreams(String logGroup, String logStream) logger.info("loadLogStreams: Did not find a match for the table, falling back to LogGroup scan for {}:{}", logGroup, logStream); - DescribeLogStreamsRequest validateTableRequest = new DescribeLogStreamsRequest(logGroup); - DescribeLogStreamsResult validateTableResult; + DescribeLogStreamsRequest.Builder validateTableRequestBuilder = DescribeLogStreamsRequest.builder().logGroupName(logGroup); + DescribeLogStreamsResponse validateTableResponse; do { - validateTableResult = invoker.invoke(() -> awsLogs.describeLogStreams(validateTableRequest)); - for (LogStream nextStream : validateTableResult.getLogStreams()) { - String logStreamName = nextStream.getLogStreamName(); + validateTableResponse = invoker.invoke(() -> logsClient.describeLogStreams(validateTableRequestBuilder.build())); + for (LogStream nextStream : validateTableResponse.logStreams()) { + String logStreamName = nextStream.logStreamName(); CloudwatchTableName nextCloudwatch = new CloudwatchTableName(logGroup, logStreamName); tableCache.put(nextCloudwatch.toTableName(), nextCloudwatch); if (nextCloudwatch.getLogStreamName().equalsIgnoreCase(logStream)) { @@ -134,9 +134,9 @@ private CloudwatchTableName loadLogStreams(String logGroup, String logStream) return nextCloudwatch; } } - validateTableRequest.setNextToken(validateTableResult.getNextToken()); + validateTableRequestBuilder.nextToken(validateTableResponse.nextToken()); } - while (validateTableResult.getNextToken() != null); + while (validateTableResponse.nextToken() != null); //We could not find a match throw new IllegalArgumentException("No such table " + logGroup + " " + logStream); @@ -163,11 +163,11 @@ private CloudwatchTableName loadLogStream(String logGroup, String logStream) LAMBDA_PATTERN, effectiveTableName); effectiveTableName = effectiveTableName.replace(LAMBDA_PATTERN, LAMBDA_ACTUAL_PATTERN); } - DescribeLogStreamsRequest request = new DescribeLogStreamsRequest(logGroup) - .withLogStreamNamePrefix(effectiveTableName); - DescribeLogStreamsResult result = invoker.invoke(() -> awsLogs.describeLogStreams(request)); - for (LogStream nextStream : result.getLogStreams()) { - String logStreamName = nextStream.getLogStreamName(); + DescribeLogStreamsRequest request = DescribeLogStreamsRequest.builder().logGroupName(logGroup) + .logStreamNamePrefix(effectiveTableName).build(); + DescribeLogStreamsResponse response = invoker.invoke(() -> logsClient.describeLogStreams(request)); + for (LogStream nextStream : response.logStreams()) { + String logStreamName = nextStream.logStreamName(); CloudwatchTableName nextCloudwatch = new CloudwatchTableName(logGroup, logStreamName); if (nextCloudwatch.getLogStreamName().equalsIgnoreCase(logStream)) { logger.info("loadLogStream: Matched {} for {}:{}", nextCloudwatch, logGroup, logStream); @@ -195,21 +195,21 @@ private String loadLogGroups(String schemaName) } logger.info("loadLogGroups: Did not find a match for the schema, falling back to LogGroup scan for {}", schemaName); - DescribeLogGroupsRequest 
validateSchemaRequest = new DescribeLogGroupsRequest(); - DescribeLogGroupsResult validateSchemaResult; + DescribeLogGroupsRequest.Builder validateSchemaRequestBuilder = DescribeLogGroupsRequest.builder(); + DescribeLogGroupsResponse validateSchemaResponse; do { - validateSchemaResult = invoker.invoke(() -> awsLogs.describeLogGroups(validateSchemaRequest)); - for (LogGroup next : validateSchemaResult.getLogGroups()) { - String nextLogGroupName = next.getLogGroupName(); + validateSchemaResponse = invoker.invoke(() -> logsClient.describeLogGroups(validateSchemaRequestBuilder.build())); + for (LogGroup next : validateSchemaResponse.logGroups()) { + String nextLogGroupName = next.logGroupName(); schemaCache.put(schemaName, nextLogGroupName); if (nextLogGroupName.equalsIgnoreCase(schemaName)) { logger.info("loadLogGroups: Matched {} for {}", nextLogGroupName, schemaName); return nextLogGroupName; } } - validateSchemaRequest.setNextToken(validateSchemaResult.getNextToken()); + validateSchemaRequestBuilder.nextToken(validateSchemaResponse.nextToken()); } - while (validateSchemaResult.getNextToken() != null); + while (validateSchemaResponse.nextToken() != null); //We could not find a match throw new IllegalArgumentException("No such schema " + schemaName); @@ -224,10 +224,10 @@ private String loadLogGroups(String schemaName) private String loadLogGroup(String schemaName) throws TimeoutException { - DescribeLogGroupsRequest request = new DescribeLogGroupsRequest().withLogGroupNamePrefix(schemaName); - DescribeLogGroupsResult result = invoker.invoke(() -> awsLogs.describeLogGroups(request)); - for (LogGroup next : result.getLogGroups()) { - String nextLogGroupName = next.getLogGroupName(); + DescribeLogGroupsRequest request = DescribeLogGroupsRequest.builder().logGroupNamePrefix(schemaName).build(); + DescribeLogGroupsResponse response = invoker.invoke(() -> logsClient.describeLogGroups(request)); + for (LogGroup next : response.logGroups()) { + String nextLogGroupName = next.logGroupName(); if (nextLogGroupName.equalsIgnoreCase(schemaName)) { logger.info("loadLogGroup: Matched {} for {}", nextLogGroupName, schemaName); return nextLogGroupName; diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchUtils.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchUtils.java index 5c19ec17ee..bb8a209d47 100644 --- a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchUtils.java +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchUtils.java @@ -21,13 +21,14 @@ import com.amazonaws.athena.connector.lambda.ThrottlingInvoker; import com.amazonaws.athena.connectors.cloudwatch.qpt.CloudwatchQueryPassthrough; -import com.amazonaws.services.logs.AWSLogs; -import com.amazonaws.services.logs.model.GetQueryResultsRequest; -import com.amazonaws.services.logs.model.GetQueryResultsResult; -import com.amazonaws.services.logs.model.StartQueryRequest; -import com.amazonaws.services.logs.model.StartQueryResult; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetQueryResultsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetQueryResultsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.QueryStatus; +import software.amazon.awssdk.services.cloudwatchlogs.model.StartQueryRequest; +import 
software.amazon.awssdk.services.cloudwatchlogs.model.StartQueryResponse; import java.time.Instant; import java.time.temporal.ChronoUnit; @@ -41,8 +42,8 @@ public final class CloudwatchUtils private CloudwatchUtils() {} public static StartQueryRequest startQueryRequest(Map qptArguments) { - return new StartQueryRequest().withEndTime(Long.valueOf(qptArguments.get(CloudwatchQueryPassthrough.ENDTIME))).withStartTime(Long.valueOf(qptArguments.get(CloudwatchQueryPassthrough.STARTTIME))) - .withQueryString(qptArguments.get(CloudwatchQueryPassthrough.QUERYSTRING)).withLogGroupNames(getLogGroupNames(qptArguments)); + return StartQueryRequest.builder().endTime(Long.valueOf(qptArguments.get(CloudwatchQueryPassthrough.ENDTIME))).startTime(Long.valueOf(qptArguments.get(CloudwatchQueryPassthrough.STARTTIME))) + .queryString(qptArguments.get(CloudwatchQueryPassthrough.QUERYSTRING)).logGroupNames(getLogGroupNames(qptArguments)).build(); } private static String[] getLogGroupNames(Map qptArguments) @@ -55,25 +56,25 @@ private static String[] getLogGroupNames(Map qptArguments) return logGroupNames; } - public static StartQueryResult getQueryResult(AWSLogs awsLogs, StartQueryRequest startQueryRequest) + public static StartQueryResponse getQueryResult(CloudWatchLogsClient awsLogs, StartQueryRequest startQueryRequest) { return awsLogs.startQuery(startQueryRequest); } - public static GetQueryResultsResult getQueryResults(AWSLogs awsLogs, StartQueryResult startQueryResult) + public static GetQueryResultsResponse getQueryResults(CloudWatchLogsClient awsLogs, StartQueryResponse startQueryResponse) { - return awsLogs.getQueryResults(new GetQueryResultsRequest().withQueryId(startQueryResult.getQueryId())); + return awsLogs.getQueryResults(GetQueryResultsRequest.builder().queryId(startQueryResponse.queryId()).build()); } - public static GetQueryResultsResult getResult(ThrottlingInvoker invoker, AWSLogs awsLogs, Map qptArguments, int limit) throws TimeoutException, InterruptedException + public static GetQueryResultsResponse getResult(ThrottlingInvoker invoker, CloudWatchLogsClient awsLogs, Map qptArguments, int limit) throws TimeoutException, InterruptedException { - StartQueryResult startQueryResult = invoker.invoke(() -> getQueryResult(awsLogs, startQueryRequest(qptArguments).withLimit(limit))); - String status = null; - GetQueryResultsResult getQueryResultsResult; + StartQueryResponse startQueryResponse = invoker.invoke(() -> getQueryResult(awsLogs, startQueryRequest(qptArguments).toBuilder().limit(limit).build())); + QueryStatus status = null; + GetQueryResultsResponse getQueryResultsResponse; Instant startTime = Instant.now(); // Record the start time do { - getQueryResultsResult = invoker.invoke(() -> getQueryResults(awsLogs, startQueryResult)); - status = getQueryResultsResult.getStatus(); + getQueryResultsResponse = invoker.invoke(() -> getQueryResults(awsLogs, startQueryResponse)); + status = getQueryResultsResponse.status(); Thread.sleep(1000); // Check if 10 minutes have passed @@ -82,8 +83,8 @@ public static GetQueryResultsResult getResult(ThrottlingInvoker invoker, AWSLogs if (elapsedMinutes >= RESULT_TIMEOUT) { throw new RuntimeException("Query execution timeout exceeded."); } - } while (!status.equalsIgnoreCase("Complete")); + } while (!status.equals(QueryStatus.COMPLETE)); - return getQueryResultsResult; + return getQueryResultsResponse; } } diff --git a/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandlerTest.java 
b/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandlerTest.java index 22a876dbae..f615b3c7b1 100644 --- a/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandlerTest.java +++ b/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandlerTest.java @@ -43,15 +43,6 @@ import com.amazonaws.athena.connector.lambda.metadata.MetadataResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.logs.AWSLogs; -import com.amazonaws.services.logs.model.DescribeLogGroupsRequest; -import com.amazonaws.services.logs.model.DescribeLogGroupsResult; -import com.amazonaws.services.logs.model.DescribeLogStreamsRequest; -import com.amazonaws.services.logs.model.DescribeLogStreamsResult; -import com.amazonaws.services.logs.model.LogGroup; -import com.amazonaws.services.logs.model.LogStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Schema; @@ -65,6 +56,15 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogGroupsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogGroupsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogStreamsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.DescribeLogStreamsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.LogGroup; +import software.amazon.awssdk.services.cloudwatchlogs.model.LogStream; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.Collections; @@ -92,26 +92,32 @@ public class CloudwatchMetadataHandlerTest private BlockAllocator allocator; @Mock - private AWSLogs mockAwsLogs; + private CloudWatchLogsClient mockAwsLogs; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Before public void setUp() throws Exception { Mockito.lenient().when(mockAwsLogs.describeLogStreams(nullable(DescribeLogStreamsRequest.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { - return new DescribeLogStreamsResult().withLogStreams(new LogStream().withLogStreamName("table-9"), - new LogStream().withLogStreamName("table-10")); + return DescribeLogStreamsResponse.builder() + .logStreams( + LogStream.builder().logStreamName("table-9").build(), + LogStream.builder().logStreamName("table-10").build()) + .build(); }); when(mockAwsLogs.describeLogGroups(nullable(DescribeLogGroupsRequest.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { - return new DescribeLogGroupsResult().withLogGroups(new LogGroup().withLogGroupName("schema-1"), - new LogGroup().withLogGroupName("schema-20")); + return DescribeLogGroupsResponse.builder() + .logGroups( + LogGroup.builder().logGroupName("schema-1").build(), + LogGroup.builder().logGroupName("schema-20").build()) + .build(); }); 
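// Note: SDK v2 model objects are immutable, so the stubs below build complete Response
// objects via builders rather than mutating v1 Result instances as the removed code did.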
handler = new CloudwatchMetadataHandler(mockAwsLogs, new LocalKeyFactory(), mockSecretsManager, mockAthena, "spillBucket", "spillPrefix", com.google.common.collect.ImmutableMap.of()); allocator = new BlockAllocatorImpl(); @@ -133,34 +139,33 @@ public void doListSchemaNames() when(mockAwsLogs.describeLogGroups(nullable(DescribeLogGroupsRequest.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { DescribeLogGroupsRequest request = (DescribeLogGroupsRequest) invocationOnMock.getArguments()[0]; - DescribeLogGroupsResult result = new DescribeLogGroupsResult(); + DescribeLogGroupsResponse.Builder responseBuilder = DescribeLogGroupsResponse.builder(); Integer nextToken; - if (request.getNextToken() == null) { + if (request.nextToken() == null) { nextToken = 1; } - else if (Integer.valueOf(request.getNextToken()) < 3) { - nextToken = Integer.valueOf(request.getNextToken()) + 1; + else if (Integer.valueOf(request.nextToken()) < 3) { + nextToken = Integer.valueOf(request.nextToken()) + 1; } else { nextToken = null; } List logGroups = new ArrayList<>(); - if (request.getNextToken() == null || Integer.valueOf(request.getNextToken()) < 3) { + if (request.nextToken() == null || Integer.valueOf(request.nextToken()) < 3) { for (int i = 0; i < 10; i++) { - LogGroup nextLogGroup = new LogGroup(); - nextLogGroup.setLogGroupName("schema-" + String.valueOf(i)); + LogGroup nextLogGroup = LogGroup.builder().logGroupName("schema-" + String.valueOf(i)).build(); logGroups.add(nextLogGroup); } } - result.withLogGroups(logGroups); + responseBuilder.logGroups(logGroups); if (nextToken != null) { - result.setNextToken(String.valueOf(nextToken)); + responseBuilder.nextToken(String.valueOf(nextToken)); } - return result; + return responseBuilder.build(); }); ListSchemasRequest req = new ListSchemasRequest(identity, "queryId", "default"); @@ -183,34 +188,33 @@ public void doListTables() when(mockAwsLogs.describeLogStreams(nullable(DescribeLogStreamsRequest.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { DescribeLogStreamsRequest request = (DescribeLogStreamsRequest) invocationOnMock.getArguments()[0]; - DescribeLogStreamsResult result = new DescribeLogStreamsResult(); + DescribeLogStreamsResponse.Builder responseBuilder = DescribeLogStreamsResponse.builder(); Integer nextToken; - if (request.getNextToken() == null) { + if (request.nextToken() == null) { nextToken = 1; } - else if (Integer.valueOf(request.getNextToken()) < 3) { - nextToken = Integer.valueOf(request.getNextToken()) + 1; + else if (Integer.valueOf(request.nextToken()) < 3) { + nextToken = Integer.valueOf(request.nextToken()) + 1; } else { nextToken = null; } List logStreams = new ArrayList<>(); - if (request.getNextToken() == null || Integer.valueOf(request.getNextToken()) < 3) { + if (request.nextToken() == null || Integer.valueOf(request.nextToken()) < 3) { for (int i = 0; i < 10; i++) { - LogStream nextLogStream = new LogStream(); - nextLogStream.setLogStreamName("table-" + String.valueOf(i)); + LogStream nextLogStream = LogStream.builder().logStreamName("table-" + String.valueOf(i)).build(); logStreams.add(nextLogStream); } } - result.withLogStreams(logStreams); + responseBuilder.logStreams(logStreams); if (nextToken != null) { - result.setNextToken(String.valueOf(nextToken)); + responseBuilder.nextToken(String.valueOf(nextToken)); } - return result; + return responseBuilder.build(); }); ListTablesRequest req = new ListTablesRequest(identity, "queryId", "default", @@ -238,35 +242,34 @@ public void doGetTable() 
when(mockAwsLogs.describeLogStreams(nullable(DescribeLogStreamsRequest.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { DescribeLogStreamsRequest request = (DescribeLogStreamsRequest) invocationOnMock.getArguments()[0]; - assertTrue(request.getLogGroupName().equals(expectedSchema)); - DescribeLogStreamsResult result = new DescribeLogStreamsResult(); + assertTrue(request.logGroupName().equals(expectedSchema)); + DescribeLogStreamsResponse.Builder responseBuilder = DescribeLogStreamsResponse.builder(); Integer nextToken; - if (request.getNextToken() == null) { + if (request.nextToken() == null) { nextToken = 1; } - else if (Integer.valueOf(request.getNextToken()) < 3) { - nextToken = Integer.valueOf(request.getNextToken()) + 1; + else if (Integer.valueOf(request.nextToken()) < 3) { + nextToken = Integer.valueOf(request.nextToken()) + 1; } else { nextToken = null; } List logStreams = new ArrayList<>(); - if (request.getNextToken() == null || Integer.valueOf(request.getNextToken()) < 3) { + if (request.nextToken() == null || Integer.valueOf(request.nextToken()) < 3) { for (int i = 0; i < 10; i++) { - LogStream nextLogStream = new LogStream(); - nextLogStream.setLogStreamName("table-" + String.valueOf(i)); + LogStream nextLogStream = LogStream.builder().logStreamName("table-" + String.valueOf(i)).build(); logStreams.add(nextLogStream); } } - result.withLogStreams(logStreams); + responseBuilder.logStreams(logStreams); if (nextToken != null) { - result.setNextToken(String.valueOf(nextToken)); + responseBuilder.nextToken(String.valueOf(nextToken)); } - return result; + return responseBuilder.build(); }); GetTableRequest req = new GetTableRequest(identity, "queryId", "default", new TableName(expectedSchema, "table-9"), Collections.emptyMap()); @@ -290,36 +293,37 @@ public void doGetTableLayout() when(mockAwsLogs.describeLogStreams(nullable(DescribeLogStreamsRequest.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { DescribeLogStreamsRequest request = (DescribeLogStreamsRequest) invocationOnMock.getArguments()[0]; - DescribeLogStreamsResult result = new DescribeLogStreamsResult(); + DescribeLogStreamsResponse.Builder responseBuilder = DescribeLogStreamsResponse.builder(); Integer nextToken; - if (request.getNextToken() == null) { + if (request.nextToken() == null) { nextToken = 1; } - else if (Integer.valueOf(request.getNextToken()) < 3) { - nextToken = Integer.valueOf(request.getNextToken()) + 1; + else if (Integer.valueOf(request.nextToken()) < 3) { + nextToken = Integer.valueOf(request.nextToken()) + 1; } else { nextToken = null; } List logStreams = new ArrayList<>(); - if (request.getNextToken() == null || Integer.valueOf(request.getNextToken()) < 3) { - int continuation = request.getNextToken() == null ? 0 : Integer.valueOf(request.getNextToken()); + if (request.nextToken() == null || Integer.valueOf(request.nextToken()) < 3) { + int continuation = request.nextToken() == null ? 
0 : Integer.valueOf(request.nextToken()); for (int i = 0 + continuation * 100; i < 300; i++) { - LogStream nextLogStream = new LogStream(); - nextLogStream.setLogStreamName("table-" + String.valueOf(i)); - nextLogStream.setStoredBytes(i * 1000L); + LogStream nextLogStream = LogStream.builder() + .logStreamName("table-" + String.valueOf(i)) + .storedBytes(i * 1000L) + .build(); logStreams.add(nextLogStream); } } - result.withLogStreams(logStreams); + responseBuilder.logStreams(logStreams); if (nextToken != null) { - result.setNextToken(String.valueOf(nextToken)); + responseBuilder.nextToken(String.valueOf(nextToken)); } - return result; + return responseBuilder.build(); }); Map constraintsMap = new HashMap<>(); diff --git a/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandlerTest.java b/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandlerTest.java index 6e3ec73623..f8b95fdafc 100644 --- a/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandlerTest.java +++ b/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandlerTest.java @@ -39,17 +39,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.logs.AWSLogs; -import com.amazonaws.services.logs.model.GetLogEventsRequest; -import com.amazonaws.services.logs.model.GetLogEventsResult; -import com.amazonaws.services.logs.model.OutputLogEvent; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.Types; @@ -63,6 +52,19 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetLogEventsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetLogEventsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.OutputLogEvent; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -77,7 +79,6 @@ import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ 
-94,16 +95,16 @@ public class CloudwatchRecordHandlerTest private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @Mock - private AWSLogs mockAwsLogs; + private CloudWatchLogsClient mockAwsLogs; @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Before public void setUp() @@ -116,70 +117,67 @@ public void setUp() handler = new CloudwatchRecordHandler(mockS3, mockSecretsManager, mockAthena, mockAwsLogs, com.google.common.collect.ImmutableMap.of()); spillReader = new S3BlockSpillReader(mockS3, allocator); - when(mockS3.putObject(any())) + when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(mockS3.getObject(nullable(String.class), nullable(String.class))) + when(mockS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); when(mockAwsLogs.getLogEvents(nullable(GetLogEventsRequest.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { GetLogEventsRequest request = (GetLogEventsRequest) invocationOnMock.getArguments()[0]; //Check that predicate pushdown was propagated to cloudwatch - assertNotNull(request.getStartTime()); - assertNotNull(request.getEndTime()); + assertNotNull(request.startTime()); + assertNotNull(request.endTime()); - GetLogEventsResult result = new GetLogEventsResult(); + GetLogEventsResponse.Builder responseBuilder = GetLogEventsResponse.builder(); Integer nextToken; - if (request.getNextToken() == null) { + if (request.nextToken() == null) { nextToken = 1; } - else if (Integer.valueOf(request.getNextToken()) < 3) { - nextToken = Integer.valueOf(request.getNextToken()) + 1; + else if (Integer.valueOf(request.nextToken()) < 3) { + nextToken = Integer.valueOf(request.nextToken()) + 1; } else { nextToken = null; } List logEvents = new ArrayList<>(); - if (request.getNextToken() == null || Integer.valueOf(request.getNextToken()) < 3) { - long continuation = request.getNextToken() == null ? 0 : Integer.valueOf(request.getNextToken()); + if (request.nextToken() == null || Integer.valueOf(request.nextToken()) < 3) { + long continuation = request.nextToken() == null ? 
0 : Integer.valueOf(request.nextToken()); for (int i = 0; i < 100_000; i++) { - OutputLogEvent outputLogEvent = new OutputLogEvent(); - outputLogEvent.setMessage("message-" + (continuation * i)); - outputLogEvent.setTimestamp(i * 100L); + OutputLogEvent outputLogEvent = OutputLogEvent.builder() + .message("message-" + (continuation * i)) + .timestamp(i * 100L) + .build(); logEvents.add(outputLogEvent); } } - result.withEvents(logEvents); + responseBuilder.events(logEvents); if (nextToken != null) { - result.setNextForwardToken(String.valueOf(nextToken)); + responseBuilder.nextForwardToken(String.valueOf(nextToken)); } - return result; + return responseBuilder.build(); }); } diff --git a/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/integ/CloudwatchIntegTest.java b/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/integ/CloudwatchIntegTest.java index 4f38711800..c9d1dd9f73 100644 --- a/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/integ/CloudwatchIntegTest.java +++ b/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/integ/CloudwatchIntegTest.java @@ -20,12 +20,6 @@ package com.amazonaws.athena.connectors.cloudwatch.integ; import com.amazonaws.athena.connector.integ.IntegrationTestBase; -import com.amazonaws.services.athena.model.Row; -import com.amazonaws.services.logs.AWSLogs; -import com.amazonaws.services.logs.AWSLogsClientBuilder; -import com.amazonaws.services.logs.model.DeleteLogGroupRequest; -import com.amazonaws.services.logs.model.InputLogEvent; -import com.amazonaws.services.logs.model.PutLogEventsRequest; import com.google.common.collect.ImmutableList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,6 +32,10 @@ import software.amazon.awscdk.services.iam.PolicyStatement; import software.amazon.awscdk.services.logs.LogGroup; import software.amazon.awscdk.services.logs.LogStream; +import software.amazon.awssdk.services.athena.model.Row; +import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import software.amazon.awssdk.services.cloudwatchlogs.model.InputLogEvent; +import software.amazon.awssdk.services.cloudwatchlogs.model.PutLogEventsRequest; import java.util.ArrayList; import java.util.List; @@ -134,20 +132,21 @@ protected void setUpTableData() logger.info("Setting up Log Group: {}, Log Stream: {}", logGroupName, logStreamName); logger.info("----------------------------------------------------"); - AWSLogs logsClient = AWSLogsClientBuilder.defaultClient(); + CloudWatchLogsClient logsClient = CloudWatchLogsClient.create(); try { - logsClient.putLogEvents(new PutLogEventsRequest() - .withLogGroupName(logGroupName) - .withLogStreamName(logStreamName) - .withLogEvents( - new InputLogEvent().withTimestamp(currentTimeMillis).withMessage("Space, the final frontier."), - new InputLogEvent().withTimestamp(fromTimeMillis).withMessage(logMessage), - new InputLogEvent().withTimestamp(toTimeMillis + 5000) - .withMessage("To boldly go where no man has gone before!"))); + logsClient.putLogEvents(PutLogEventsRequest.builder() + .logGroupName(logGroupName) + .logStreamName(logStreamName) + .logEvents( + InputLogEvent.builder().timestamp(currentTimeMillis).message("Space, the final frontier.").build(), + InputLogEvent.builder().timestamp(fromTimeMillis).message(logMessage).build(), + InputLogEvent.builder().timestamp(toTimeMillis + 5000) + .message("To boldly go where no man has gone before!").build()) + .build()); } finally { - logsClient.shutdown(); + 
logsClient.close(); } } @@ -268,13 +267,13 @@ public void selectColumnWithPredicateIntegTest() String query = String.format("select message from %s.\"%s\".\"%s\" where time between %d and %d;", lambdaFunctionName, logGroupName, logStreamName, fromTimeMillis, toTimeMillis); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List messages = new ArrayList<>(); - rows.forEach(row -> messages.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> messages.add(row.data().get(0).varCharValue())); logger.info("Messages: {}", messages); assertEquals("Wrong number of log messages found.", 1, messages.size()); assertTrue("Expecting log message: " + logMessage, messages.contains(logMessage)); diff --git a/athena-datalakegen2/Dockerfile b/athena-datalakegen2/Dockerfile new file mode 100644 index 0000000000..4e1929f607 --- /dev/null +++ b/athena-datalakegen2/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-datalakegen2-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-datalakegen2-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2MuxCompositeHandler" ] \ No newline at end of file diff --git a/athena-datalakegen2/athena-datalakegen2.yaml b/athena-datalakegen2/athena-datalakegen2.yaml index 32da145587..5890402513 100644 --- a/athena-datalakegen2/athena-datalakegen2.yaml +++ b/athena-datalakegen2/athena-datalakegen2.yaml @@ -71,10 +71,9 @@ Resources: spill_prefix: !Ref SpillPrefix default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2MuxCompositeHandler" - CodeUri: "./target/athena-datalakegen2-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-datalakegen2:2022.47.1' Description: "Enables Amazon Athena to communicate with DataLake Gen2 using JDBC" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-datalakegen2/pom.xml b/athena-datalakegen2/pom.xml index c72c6c4813..670a4396b9 100644 --- a/athena-datalakegen2/pom.xml +++ b/athena-datalakegen2/pom.xml @@ -32,12 +32,18 @@ mssql-jdbc ${mssql.jdbc.version} - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandler.java b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandler.java index 53fd9386fe..14d27bba34 100644 --- a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandler.java +++ b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandler.java @@ -47,8 +47,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter; import 
com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -58,6 +56,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -111,8 +111,8 @@ public DataLakeGen2MetadataHandler( @VisibleForTesting protected DataLakeGen2MetadataHandler( DatabaseConnectionConfig databaseConnectionConfig, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxMetadataHandler.java b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxMetadataHandler.java index 0132af948d..577a193ec7 100644 --- a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxMetadataHandler.java +++ b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxMetadataHandler.java @@ -24,9 +24,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public DataLakeGen2MuxMetadataHandler(java.util.Map configOption } @VisibleForTesting - protected DataLakeGen2MuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected DataLakeGen2MuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandler.java b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandler.java index f637195150..dd7c643f82 100644 --- a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandler.java +++ b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandler.java @@ -24,10 +24,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import 
com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -54,7 +54,7 @@ public DataLakeGen2MuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - DataLakeGen2MuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + DataLakeGen2MuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java index 16b3e5b584..f80e8bd0c0 100644 --- a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java +++ b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java @@ -28,15 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -54,12 +51,12 @@ public DataLakeGen2RecordHandler(java.util.Map configOptions) } public DataLakeGen2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new GenericJdbcConnectionFactory(databaseConnectionConfig, DataLakeGen2MetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(DataLakeGen2Constants.DRIVER_CLASS, DataLakeGen2Constants.DEFAULT_PORT)), new DataLakeGen2QueryStringBuilder(QUOTE_CHARACTER, new DataLakeGen2FederationExpressionParser(QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - DataLakeGen2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder 
jdbcSplitQueryBuilder, java.util.Map configOptions) + DataLakeGen2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandlerTest.java b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandlerTest.java index c37359bab8..c69dcf613a 100644 --- a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandlerTest.java +++ b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandlerTest.java @@ -38,10 +38,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; @@ -50,6 +46,10 @@ import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.Connection; import java.sql.ResultSet; @@ -77,8 +77,8 @@ public class DataLakeGen2MetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() @@ -89,9 +89,9 @@ public void setup() this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); logger.info(" this.connection.."+ this.connection); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"user\": \"testUser\", \"password\": 
\"testPassword\"}").build()); this.dataLakeGen2MetadataHandler = new DataLakeGen2MetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); } diff --git a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxMetadataHandlerTest.java b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxMetadataHandlerTest.java index 0608abdec3..a2ffc02ec4 100644 --- a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxMetadataHandlerTest.java +++ b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxMetadataHandlerTest.java @@ -32,11 +32,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -49,8 +49,8 @@ public class DataLakeGen2MuxMetadataHandlerTest private DataLakeGen2MetadataHandler dataLakeGen2MetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -60,8 +60,8 @@ public void setup() this.allocator = new BlockAllocatorImpl(); this.dataLakeGen2MetadataHandler = Mockito.mock(DataLakeGen2MetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.dataLakeGen2MetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", diff --git a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandlerTest.java b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandlerTest.java index 6b7f491bd0..dc2fa02473 100644 --- a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandlerTest.java +++ b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandlerTest.java @@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import 
com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -46,9 +46,9 @@ public class DataLakeGen2MuxRecordHandlerTest private Map recordHandlerMap; private DataLakeGen2RecordHandler dataLakeGen2RecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -57,9 +57,9 @@ public void setup() { this.dataLakeGen2RecordHandler = Mockito.mock(DataLakeGen2RecordHandler.class); this.recordHandlerMap = Collections.singletonMap(DataLakeGen2Constants.NAME, this.dataLakeGen2RecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", DataLakeGen2Constants.NAME, diff --git a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeRecordHandlerTest.java b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeRecordHandlerTest.java index 1dd198ae89..912d328fa3 100644 --- a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeRecordHandlerTest.java +++ b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeRecordHandlerTest.java @@ -31,9 +31,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -41,6 +38,9 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -56,18 +56,18 @@ public class DataLakeRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private 
SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() throws Exception { System.setProperty("aws.region", "us-east-1"); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-db2-as400/Dockerfile b/athena-db2-as400/Dockerfile new file mode 100644 index 0000000000..affd37e7bb --- /dev/null +++ b/athena-db2-as400/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-db2-as400-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-db2-as400-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.db2as400.Db2As400MuxCompositeHandler" ] \ No newline at end of file diff --git a/athena-db2-as400/athena-db2-as400.yaml b/athena-db2-as400/athena-db2-as400.yaml index c84dac623e..ea0a331051 100644 --- a/athena-db2-as400/athena-db2-as400.yaml +++ b/athena-db2-as400/athena-db2-as400.yaml @@ -72,10 +72,9 @@ Resources: spill_prefix: !Ref SpillPrefix default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.db2as400.Db2As400MuxCompositeHandler" - CodeUri: "./target/athena-db2-as400-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-db2-as400:2022.47.1' Description: "Enables Amazon Athena to communicate with DB2 on iSeries (AS400) using JDBC" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-db2-as400/pom.xml b/athena-db2-as400/pom.xml index 7c458b8caf..2165ff5019 100644 --- a/athena-db2-as400/pom.xml +++ b/athena-db2-as400/pom.xml @@ -33,12 +33,18 @@ jt400 20.0.7 - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MetadataHandler.java b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MetadataHandler.java index b083ceecb5..a589bbc33d 100644 --- a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MetadataHandler.java +++ b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MetadataHandler.java @@ -49,8 +49,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import 
com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -60,6 +58,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -121,8 +121,8 @@ public Db2As400MetadataHandler( @VisibleForTesting protected Db2As400MetadataHandler( DatabaseConnectionConfig databaseConnectionConfig, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxMetadataHandler.java b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxMetadataHandler.java index 490a72696b..705fe5e6ff 100644 --- a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxMetadataHandler.java +++ b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxMetadataHandler.java @@ -24,9 +24,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public Db2As400MuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected Db2As400MuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected Db2As400MuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxRecordHandler.java b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxRecordHandler.java index c2c19cc5d5..3d4706a208 100644 --- a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxRecordHandler.java +++ b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxRecordHandler.java @@ -24,10 +24,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import 
software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -54,7 +54,7 @@ public Db2As400MuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - Db2As400MuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + Db2As400MuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandler.java b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandler.java index e78ae1964b..69d0711852 100644 --- a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandler.java +++ b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandler.java @@ -29,15 +29,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -60,13 +57,13 @@ public Db2As400RecordHandler(java.util.Map configOptions) */ public Db2As400RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new GenericJdbcConnectionFactory(databaseConnectionConfig, null, new DatabaseConnectionInfo(Db2As400Constants.DRIVER_CLASS, Db2As400Constants.DEFAULT_PORT)), new Db2As400QueryStringBuilder(QUOTE_CHARACTER), configOptions); } @VisibleForTesting - Db2As400RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + Db2As400RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = 
Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400MetadataHandlerTest.java b/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400MetadataHandlerTest.java index 5f16236d1a..ce35bab8e8 100644 --- a/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400MetadataHandlerTest.java +++ b/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400MetadataHandlerTest.java @@ -41,10 +41,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; @@ -53,6 +49,10 @@ import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.Connection; import java.sql.PreparedStatement; @@ -80,9 +80,9 @@ public class Db2As400MetadataHandlerTest extends TestBase { private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; private BlockAllocator blockAllocator; - private AmazonAthena athena; + private AthenaClient athena; @Before public void setup() throws Exception { @@ -91,9 +91,9 @@ public void setup() throws Exception { this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); logger.info(" this.connection.."+ this.connection); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}").build()); this.db2As400MetadataHandler = new Db2As400MetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); this.blockAllocator = new BlockAllocatorImpl(); diff --git 
a/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandlerTest.java b/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandlerTest.java index fa2314b253..4ca5b947a8 100644 --- a/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandlerTest.java +++ b/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandlerTest.java @@ -31,9 +31,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -41,6 +38,9 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -54,16 +54,16 @@ public class Db2As400RecordHandlerTest { private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() throws Exception { System.setProperty("aws.region", "us-east-1"); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-db2/Dockerfile b/athena-db2/Dockerfile new file mode 100644 index 0000000000..0d8231fa29 --- /dev/null +++ b/athena-db2/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-db2-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-db2-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.db2.Db2MuxCompositeHandler" ] \ No newline at end of file diff --git a/athena-db2/athena-db2.yaml b/athena-db2/athena-db2.yaml index cbaaa93af9..7508f16712 100644 --- a/athena-db2/athena-db2.yaml +++ b/athena-db2/athena-db2.yaml @@ -72,10 +72,9 @@ Resources: spill_prefix: !Ref SpillPrefix default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.db2.Db2MuxCompositeHandler" - CodeUri: "./target/athena-db2-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub 
'292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-db2:2022.47.1' Description: "Enables Amazon Athena to communicate with DB2 using JDBC" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-db2/pom.xml b/athena-db2/pom.xml index fbe105f1b7..e018349754 100644 --- a/athena-db2/pom.xml +++ b/athena-db2/pom.xml @@ -33,12 +33,18 @@ jcc 11.5.9.0 - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandler.java b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandler.java index d5dec08242..965197ff0a 100644 --- a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandler.java +++ b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandler.java @@ -55,8 +55,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -67,6 +65,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -130,8 +130,8 @@ public Db2MetadataHandler( @VisibleForTesting protected Db2MetadataHandler( DatabaseConnectionConfig databaseConnectionConfig, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxMetadataHandler.java b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxMetadataHandler.java index ab596649ab..2fd0df2842 100644 --- a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxMetadataHandler.java +++ b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxMetadataHandler.java @@ -24,9 +24,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public Db2MuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected Db2MuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected Db2MuxMetadataHandler(SecretsManagerClient 
secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxRecordHandler.java b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxRecordHandler.java index 1919316e39..94fbe8c395 100644 --- a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxRecordHandler.java +++ b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxRecordHandler.java @@ -24,10 +24,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -54,7 +54,7 @@ public Db2MuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - Db2MuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + Db2MuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2RecordHandler.java b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2RecordHandler.java index 442d19fee3..8e9941f220 100644 --- a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2RecordHandler.java +++ b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2RecordHandler.java @@ -29,15 +29,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -61,13 +58,13 @@ public Db2RecordHandler(java.util.Map configOptions) */ public Db2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map 
configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new GenericJdbcConnectionFactory(databaseConnectionConfig, null, new DatabaseConnectionInfo(Db2Constants.DRIVER_CLASS, Db2Constants.DEFAULT_PORT)), new Db2QueryStringBuilder(QUOTE_CHARACTER, new Db2FederationExpressionParser(QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - Db2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + Db2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandlerTest.java b/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandlerTest.java index 02ff20fa93..81a1ebb474 100644 --- a/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandlerTest.java +++ b/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2MetadataHandlerTest.java @@ -41,10 +41,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; @@ -53,6 +49,10 @@ import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.Connection; import java.sql.PreparedStatement; @@ -81,9 +81,9 @@ public class Db2MetadataHandlerTest extends TestBase { private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; private BlockAllocator blockAllocator; - private AmazonAthena athena; + private AthenaClient athena; @Before public void setup() throws Exception { @@ -92,9 +92,9 @@ public void setup() throws Exception { this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); logger.info(" this.connection.."+ this.connection); 
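The stubbing pattern these tests move to leans on SDK v2 model objects being immutable and value-equal, so an exactly built request matches the handler's real call. A minimal sketch of that pattern, assuming a secret named testSecret (the credential JSON is illustrative):

    import org.mockito.Mockito;
    import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
    import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
    import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;

    // v2 model classes implement equals(), so matching on a fully built request
    // replaces v1's withSecretId(...) mutator comparison.
    SecretsManagerClient secretsManager = Mockito.mock(SecretsManagerClient.class);
    Mockito.when(secretsManager.getSecretValue(
                    GetSecretValueRequest.builder().secretId("testSecret").build()))
            .thenReturn(GetSecretValueResponse.builder()
                    .secretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}")
                    .build());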
Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}").build()); this.db2MetadataHandler = new Db2MetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); this.blockAllocator = new BlockAllocatorImpl(); diff --git a/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2RecordHandlerTest.java b/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2RecordHandlerTest.java index 801db06233..b7de058f8d 100644 --- a/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2RecordHandlerTest.java +++ b/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2RecordHandlerTest.java @@ -31,9 +31,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -41,6 +38,9 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -55,16 +55,16 @@ public class Db2RecordHandlerTest { private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() throws Exception { System.setProperty("aws.region", "us-east-1"); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-docdb/Dockerfile 
b/athena-docdb/Dockerfile new file mode 100644 index 0000000000..06e8a5c907 --- /dev/null +++ b/athena-docdb/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-docdb-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-docdb-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.docdb.DocDBCompositeHandler" ] \ No newline at end of file diff --git a/athena-docdb/athena-docdb.yaml b/athena-docdb/athena-docdb.yaml index efb7fc0b2e..588b05f52e 100644 --- a/athena-docdb/athena-docdb.yaml +++ b/athena-docdb/athena-docdb.yaml @@ -66,10 +66,9 @@ Resources: spill_prefix: !Ref SpillPrefix default_docdb: !Ref DocDBConnectionString FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.docdb.DocDBCompositeHandler" - CodeUri: "./target/athena-docdb-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-docdb:2022.47.1' Description: "Enables Amazon Athena to communicate with DocumentDB, making your DocumentDB data accessible via SQL." - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-docdb/pom.xml b/athena-docdb/pom.xml index 5dd645c20a..8982ee0159 100644 --- a/athena-docdb/pom.xml +++ b/athena-docdb/pom.xml @@ -28,11 +28,11 @@ 2022.47.1 test - + - com.amazonaws - aws-java-sdk-docdb - ${aws-sdk.version} + software.amazon.awssdk + docdb + ${aws-sdk-v2.version} test diff --git a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandler.java b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandler.java index 5a25b6f50c..191269fbd6 100644 --- a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandler.java +++ b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandler.java @@ -42,11 +42,6 @@ import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connectors.docdb.qpt.DocDBQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Database; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; import com.mongodb.client.MongoClient; @@ -58,6 +53,11 @@ import org.bson.Document; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Database; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.LinkedHashSet; @@ -95,13 +95,13 @@ public class DocDBMetadataHandler //is indeed enabled for use by this connector. private static final String DOCDB_METADATA_FLAG = "docdb-metadata-flag"; //Used to filter out Glue tables which lack a docdb metadata flag. 
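The hunk below shows the recurring v1-to-v2 Glue model change: generated v2 models drop the get- prefix in favor of fluent accessors. A minimal sketch of the renaming, with a hypothetical helper name:

    import software.amazon.awssdk.services.glue.model.Table;

    // v1: table.getParameters().containsKey(flag)
    // v2: fluent accessor, same null-safety caveat applies.
    static boolean hasMetadataFlag(Table table, String flag)
    {
        return table.parameters() != null && table.parameters().containsKey(flag);
    }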
- private static final TableFilter TABLE_FILTER = (Table table) -> table.getParameters().containsKey(DOCDB_METADATA_FLAG); + private static final TableFilter TABLE_FILTER = (Table table) -> table.parameters().containsKey(DOCDB_METADATA_FLAG); //The number of documents to scan when attempting to infer schema from a DocDB collection. private static final int SCHEMA_INFERRENCE_NUM_DOCS = 10; // used to filter out Glue databases which lack the docdb-metadata-flag in the URI. - private static final DatabaseFilter DB_FILTER = (Database database) -> (database.getLocationUri() != null && database.getLocationUri().contains(DOCDB_METADATA_FLAG)); + private static final DatabaseFilter DB_FILTER = (Database database) -> (database.locationUri() != null && database.locationUri().contains(DOCDB_METADATA_FLAG)); - private final AWSGlue glue; + private final GlueClient glue; private final DocDBConnectionFactory connectionFactory; private final DocDBQueryPassthrough queryPassthrough = new DocDBQueryPassthrough(); @@ -114,11 +114,11 @@ public DocDBMetadataHandler(java.util.Map configOptions) @VisibleForTesting protected DocDBMetadataHandler( - AWSGlue glue, + GlueClient glue, DocDBConnectionFactory connectionFactory, EncryptionKeyFactory keyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, String spillBucket, String spillPrefix, java.util.Map configOptions) diff --git a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandler.java b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandler.java index ecba05bc18..4b0459f57e 100644 --- a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandler.java +++ b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandler.java @@ -28,12 +28,6 @@ import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.docdb.qpt.DocDBQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoCursor; @@ -44,6 +38,9 @@ import org.bson.Document; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; import java.util.TreeMap; @@ -81,15 +78,15 @@ public class DocDBRecordHandler public DocDBRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), + S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), new DocDBConnectionFactory(), configOptions); } @VisibleForTesting - protected DocDBRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, DocDBConnectionFactory connectionFactory, java.util.Map configOptions) + protected DocDBRecordHandler(S3Client amazonS3,
SecretsManagerClient secretsManager, AthenaClient athena, DocDBConnectionFactory connectionFactory, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.connectionFactory = connectionFactory; diff --git a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandlerTest.java b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandlerTest.java index 866ecf164b..a69d0f4d31 100644 --- a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandlerTest.java +++ b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandlerTest.java @@ -39,9 +39,6 @@ import com.amazonaws.athena.connector.lambda.metadata.MetadataRequestType; import com.amazonaws.athena.connector.lambda.metadata.MetadataResponse; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.mongodb.client.FindIterable; import com.mongodb.client.MongoClient; @@ -63,6 +60,9 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.Arrays; @@ -100,13 +100,13 @@ public class DocDBMetadataHandlerTest private MongoClient mockClient; @Mock - private AWSGlue awsGlue; + private GlueClient awsGlue; @Mock - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Before public void setUp() diff --git a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java index 18a1947c79..866bc1ac41 100644 --- a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java +++ b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java @@ -40,14 +40,6 @@ import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import com.mongodb.client.FindIterable; @@ -71,6 +63,17 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; + +import software.amazon.awssdk.core.ResponseInputStream; 
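These imports support the S3 mock rewrite further down this file: v2 separates an upload's metadata (PutObjectRequest) from its payload (RequestBody) and returns downloads as a ResponseInputStream. A minimal sketch of the round trip the test simulates, with illustrative bucket and key names:

    import software.amazon.awssdk.core.ResponseInputStream;
    import software.amazon.awssdk.core.sync.RequestBody;
    import software.amazon.awssdk.services.s3.S3Client;
    import software.amazon.awssdk.services.s3.model.GetObjectRequest;
    import software.amazon.awssdk.services.s3.model.GetObjectResponse;
    import software.amazon.awssdk.services.s3.model.PutObjectRequest;

    // v2: putObject takes (request, payload); getObject returns a stream
    // wrapping the GetObjectResponse instead of v1's S3Object holder.
    static byte[] roundTrip(S3Client s3, byte[] payload) throws java.io.IOException
    {
        s3.putObject(PutObjectRequest.builder().bucket("spill-bucket").key("spill/key").build(),
                RequestBody.fromBytes(payload));
        try (ResponseInputStream<GetObjectResponse> in =
                s3.getObject(GetObjectRequest.builder().bucket("spill-bucket").key("spill/key").build())) {
            return in.readAllBytes();
        }
    }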
+import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -100,7 +103,7 @@ public class DocDBRecordHandlerTest private DocDBRecordHandler handler; private BlockAllocator allocator; private List mockS3Storage = new ArrayList<>(); - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpillReader spillReader; private Schema schemaForRead; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -116,16 +119,16 @@ public class DocDBRecordHandlerTest private MongoClient mockClient; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Mock - private AWSGlue awsGlue; + private GlueClient awsGlue; @Mock - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; @Mock MongoDatabase mockDatabase; @@ -171,7 +174,7 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); mockDatabase = mock(MongoDatabase.class); mockCollection = mock(MongoCollection.class); mockIterable = mock(FindIterable.class); @@ -179,31 +182,27 @@ public void setUp() when(mockClient.getDatabase(eq(DEFAULT_SCHEMA))).thenReturn(mockDatabase); when(mockDatabase.getCollection(eq(TEST_TABLE))).thenReturn(mockCollection); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); handler = new DocDBRecordHandler(amazonS3, mockSecretsManager, mockAthena, connectionFactory, com.google.common.collect.ImmutableMap.of()); diff --git a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/integ/DocDbIntegTest.java b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/integ/DocDbIntegTest.java index f20a65ceeb..bf0a314e8a 100644 --- 
a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/integ/DocDbIntegTest.java +++ b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/integ/DocDbIntegTest.java @@ -27,16 +27,6 @@ import com.amazonaws.athena.connector.integ.data.ConnectorVpcAttributes; import com.amazonaws.athena.connector.integ.data.SecretsManagerCredentials; import com.amazonaws.athena.connector.integ.providers.ConnectorPackagingAttributesProvider; -import com.amazonaws.services.athena.model.Row; -import com.amazonaws.services.docdb.AmazonDocDB; -import com.amazonaws.services.docdb.AmazonDocDBClientBuilder; -import com.amazonaws.services.docdb.model.DBCluster; -import com.amazonaws.services.docdb.model.DescribeDBClustersRequest; -import com.amazonaws.services.docdb.model.DescribeDBClustersResult; -import com.amazonaws.services.lambda.AWSLambda; -import com.amazonaws.services.lambda.AWSLambdaClientBuilder; -import com.amazonaws.services.lambda.model.InvocationType; -import com.amazonaws.services.lambda.model.InvokeRequest; import com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,6 +47,14 @@ import software.amazon.awscdk.services.ec2.Vpc; import software.amazon.awscdk.services.ec2.VpcAttributes; import software.amazon.awscdk.services.iam.PolicyDocument; +import software.amazon.awssdk.services.athena.model.Row; +import software.amazon.awssdk.services.docdb.DocDbClient; +import software.amazon.awssdk.services.docdb.model.DBCluster; +import software.amazon.awssdk.services.docdb.model.DescribeDbClustersRequest; +import software.amazon.awssdk.services.docdb.model.DescribeDbClustersResponse; +import software.amazon.awssdk.services.lambda.LambdaClient; +import software.amazon.awssdk.services.lambda.model.InvocationType; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; import java.util.ArrayList; import java.util.HashMap; @@ -192,15 +190,16 @@ private Stack getDocDbStack() { * Lambda. All exceptions thrown here will be caught in the calling function. */ private Endpoint getClusterData() { - AmazonDocDB docDbClient = AmazonDocDBClientBuilder.defaultClient(); + DocDbClient docDbClient = DocDbClient.create(); try { - DescribeDBClustersResult dbClustersResult = docDbClient.describeDBClusters(new DescribeDBClustersRequest() - .withDBClusterIdentifier(dbClusterName)); - DBCluster cluster = dbClustersResult.getDBClusters().get(0); - return new Endpoint(cluster.getEndpoint(), cluster.getPort()); + DescribeDbClustersResponse dbClustersResponse = docDbClient.describeDBClusters(DescribeDbClustersRequest.builder() + .dbClusterIdentifier(dbClusterName) + .build()); + DBCluster cluster = dbClustersResponse.dbClusters().get(0); + return new Endpoint(cluster.endpoint(), cluster.port()); } finally { - docDbClient.shutdown(); + docDbClient.close(); } } @@ -263,20 +262,21 @@ protected void setUpTableData() logger.info("----------------------------------------------------"); String mongoLambdaName = "integ-mongodb-" + UUID.randomUUID(); - AWSLambda lambdaClient = AWSLambdaClientBuilder.defaultClient(); + LambdaClient lambdaClient = LambdaClient.create(); CloudFormationClient cloudFormationMongoClient = new CloudFormationClient(getMongoLambdaStack(mongoLambdaName)); try { // Create the Lambda function. cloudFormationMongoClient.createStack(); // Invoke the Lambda function. 
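Both integ-test hunks here swap v1's shutdown() for v2's close(); since v2 clients implement AutoCloseable, try-with-resources is an equivalent way to scope them. A minimal sketch against DocumentDB, with an illustrative cluster identifier:

    import software.amazon.awssdk.services.docdb.DocDbClient;
    import software.amazon.awssdk.services.docdb.model.DescribeDbClustersResponse;

    // v2 clients are AutoCloseable; the consumer-builder overload below is
    // shorthand for DescribeDbClustersRequest.builder()...build().
    try (DocDbClient docDb = DocDbClient.create()) {
        DescribeDbClustersResponse resp =
                docDb.describeDBClusters(r -> r.dbClusterIdentifier("docdb-integ-cluster"));
    }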
- lambdaClient.invoke(new InvokeRequest() - .withFunctionName(mongoLambdaName) - .withInvocationType(InvocationType.RequestResponse)); + lambdaClient.invoke(InvokeRequest.builder() + .functionName(mongoLambdaName) + .invocationType(InvocationType.REQUEST_RESPONSE) + .build()); } finally { // Delete the Lambda function. cloudFormationMongoClient.deleteStack(); - lambdaClient.shutdown(); + lambdaClient.close(); } } @@ -371,13 +371,13 @@ public void selectColumnWithPredicateIntegTest() String query = String.format("select title from %s.%s.%s where year > 2012;", lambdaFunctionName, docdbDbName, docdbTableMovies); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List titles = new ArrayList<>(); - rows.forEach(row -> titles.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> titles.add(row.data().get(0).varCharValue())); logger.info("Titles: {}", titles); assertEquals("Wrong number of DB records found.", 1, titles.size()); assertTrue("Movie title not found: Interstellar.", titles.contains("Interstellar")); diff --git a/athena-dynamodb/Dockerfile b/athena-dynamodb/Dockerfile new file mode 100644 index 0000000000..868346d735 --- /dev/null +++ b/athena-dynamodb/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-dynamodb-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-dynamodb-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.dynamodb.DynamoDBCompositeHandler" ] \ No newline at end of file diff --git a/athena-dynamodb/athena-dynamodb.yaml b/athena-dynamodb/athena-dynamodb.yaml index ae3e023f58..f44ac89665 100644 --- a/athena-dynamodb/athena-dynamodb.yaml +++ b/athena-dynamodb/athena-dynamodb.yaml @@ -66,10 +66,9 @@ Resources: spill_prefix: !Ref SpillPrefix kms_key_id: !If [HasKMSKeyId, !Ref KMSKeyId, !Ref "AWS::NoValue"] FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.dynamodb.DynamoDBCompositeHandler" - CodeUri: "./target/athena-dynamodb-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-dynamodb:2022.47.1' Description: "Enables Amazon Athena to communicate with DynamoDB, making your tables accessible via SQL" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory Role: !If [NotHasLambdaRole, !GetAtt FunctionRole.Arn, !Ref LambdaRole] diff --git a/athena-dynamodb/pom.xml b/athena-dynamodb/pom.xml index 28d8f26fdb..81b5e92c61 100644 --- a/athena-dynamodb/pom.xml +++ b/athena-dynamodb/pom.xml @@ -8,17 +8,6 @@ 4.0.0 athena-dynamodb 2022.47.1 - - - - software.amazon.awssdk - bom - 2.28.26 - pom - import - - - com.amazonaws @@ -31,20 +20,16 @@ athena-federation-integ-test 2022.47.1 test - - - com.amazonaws - aws-java-sdk-sts - - software.amazon.awssdk dynamodb + ${aws-sdk-v2.version} software.amazon.awssdk dynamodb-enhanced + ${aws-sdk-v2.version} com.amazonaws @@ -55,6 +40,7 @@ software.amazon.awssdk url-connection-client + ${aws-sdk-v2.version} test @@ -111,13 +97,10 @@ test-jar test - - software.amazon.awssdk - sdk-core - software.amazon.awssdk sts + ${aws-sdk-v2.version} diff --git 
a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java index f6be93e6b1..d472551f13 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java @@ -57,14 +57,7 @@ import com.amazonaws.athena.connectors.dynamodb.util.DDBTableUtils; import com.amazonaws.athena.connectors.dynamodb.util.DDBTypeUtils; import com.amazonaws.athena.connectors.dynamodb.util.IncrementingValueNameProducer; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Database; -import com.amazonaws.services.glue.model.ErrorDetails; -import com.amazonaws.services.glue.model.FederationSourceErrorCode; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.util.json.Jackson; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -74,10 +67,17 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.enhanced.dynamodb.document.EnhancedDocument; +import software.amazon.awssdk.services.athena.AthenaClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementRequest; import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementResponse; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Database; +import software.amazon.awssdk.services.glue.model.ErrorDetails; +import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.Collections; @@ -134,15 +134,15 @@ public class DynamoDBMetadataHandler // defines the value that should be present in the Glue Database URI to enable the DB for DynamoDB. static final String DYNAMO_DB_FLAG = "dynamo-db-flag"; // used to filter out Glue tables which lack indications of being used for DDB. - private static final TableFilter TABLE_FILTER = (Table table) -> table.getStorageDescriptor().getLocation().contains(DYNAMODB) - || (table.getParameters() != null && DYNAMODB.equals(table.getParameters().get("classification"))) - || (table.getStorageDescriptor().getParameters() != null && DYNAMODB.equals(table.getStorageDescriptor().getParameters().get("classification"))); + private static final TableFilter TABLE_FILTER = (Table table) -> table.storageDescriptor().location().contains(DYNAMODB) + || (table.parameters() != null && DYNAMODB.equals(table.parameters().get("classification"))) + || (table.storageDescriptor().parameters() != null && DYNAMODB.equals(table.storageDescriptor().parameters().get("classification"))); // used to filter out Glue databases which lack the DYNAMO_DB_FLAG in the URI. 
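The DynamoDB hunks that follow repeat one conversion throughout: v1's mutable ErrorDetails with-setters become a v2 builder, and FederationSourceErrorCode constants move to v2's upper-snake-case enum naming. A minimal sketch, with an illustrative message:

    import software.amazon.awssdk.services.glue.model.ErrorDetails;
    import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode;

    // v1: new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())
    // v2: immutable builder plus SCREAMING_SNAKE_CASE enum constants.
    ErrorDetails details = ErrorDetails.builder()
            .errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString())
            .errorMessage("example error message")
            .build();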
- private static final DatabaseFilter DB_FILTER = (Database database) -> (database.getLocationUri() != null && database.getLocationUri().contains(DYNAMO_DB_FLAG)); + private static final DatabaseFilter DB_FILTER = (Database database) -> (database.locationUri() != null && database.locationUri().contains(DYNAMO_DB_FLAG)); private final ThrottlingInvoker invoker; private final DynamoDbClient ddbClient; - private final AWSGlue glueClient; + private final GlueClient glueClient; private final DynamoDBTableResolver tableResolver; private final DDBQueryPassthrough queryPassthrough; @@ -162,12 +162,12 @@ public DynamoDBMetadataHandler(java.util.Map configOptions) @VisibleForTesting DynamoDBMetadataHandler( EncryptionKeyFactory keyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, String spillBucket, String spillPrefix, DynamoDbClient ddbClient, - AWSGlue glueClient, + GlueClient glueClient, java.util.Map configOptions) { super(glueClient, keyFactory, secretsManager, athena, SOURCE_TYPE, spillBucket, spillPrefix, configOptions); @@ -258,7 +258,7 @@ public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesReque public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception { if (!request.isQueryPassthrough()) { - throw new AthenaConnectorException("No Query passed through [{}]" + request, new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString()).withErrorMessage("No Query passed through [{}]" + request)); + throw new AthenaConnectorException("No Query passed through [{}]" + request, ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).errorMessage("No Query passed through [{}]" + request).build()); } queryPassthrough.verify(request.getQueryPassthroughArguments()); @@ -327,7 +327,7 @@ public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTabl table = tableResolver.getTableMetadata(tableName); } catch (TimeoutException e) { - throw new AthenaConnectorException(e.getMessage(), new ErrorDetails().withErrorCode(FederationSourceErrorCode.OperationTimeoutException.toString()).withErrorMessage(e.getMessage())); + throw new AthenaConnectorException(e.getMessage(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.OPERATION_TIMEOUT_EXCEPTION.toString()).errorMessage(e.getMessage()).build()); } // add table name so we don't have to do case insensitive resolution again partitionSchemaBuilder.addMetadata(TABLE_METADATA, table.getName()); @@ -449,7 +449,14 @@ private void precomputeAdditionalMetadata(Set columnsToIgnore, Map partitionMetadata = partitions.getSchema().getCustomMetadata(); String partitionType = partitionMetadata.get(PARTITION_TYPE_METADATA); if (partitionType == null) { - throw new AthenaConnectorException(String.format("No metadata %s defined in Schema %s", PARTITION_TYPE_METADATA, partitions.getSchema()), new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException(String.format("No metadata %s defined in Schema %s", PARTITION_TYPE_METADATA, partitions.getSchema()), ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } if (QUERY_PARTITION_TYPE.equals(partitionType)) { String hashKeyName = partitionMetadata.get(HASH_KEY_NAME_METADATA); @@ -530,7 +537,7 @@ else if (SCAN_PARTITION_TYPE.equals(partitionType)) { return new 
GetSplitsResponse(request.getCatalogName(), splits, null); } else { - throw new AthenaConnectorException("Unexpected partition type " + partitionType, new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("Unexpected partition type " + partitionType, ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } } diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java index 4ecf630889..8215b578ce 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java @@ -36,13 +36,8 @@ import com.amazonaws.athena.connectors.dynamodb.util.DDBPredicateUtils; import com.amazonaws.athena.connectors.dynamodb.util.DDBRecordMetadata; import com.amazonaws.athena.connectors.dynamodb.util.DDBTypeUtils; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.model.ErrorDetails; -import com.amazonaws.services.glue.model.FederationSourceErrorCode; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.util.json.Jackson; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; @@ -52,6 +47,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.enhanced.dynamodb.document.EnhancedDocument; +import software.amazon.awssdk.services.athena.AthenaClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementRequest; @@ -60,6 +56,10 @@ import software.amazon.awssdk.services.dynamodb.model.QueryResponse; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.glue.model.ErrorDetails; +import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.util.ArrayList; @@ -131,7 +131,7 @@ public ThrottlingInvoker load(String tableName) } @VisibleForTesting - DynamoDBRecordHandler(DynamoDbClient ddbClient, AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, String sourceType, java.util.Map configOptions) + DynamoDBRecordHandler(DynamoDbClient ddbClient, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, String sourceType, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, sourceType, configOptions); this.ddbClient = ddbClient; @@ -209,7 +209,7 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor private void handleQueryPassthroughPartiQLQuery(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { if (!recordsRequest.getConstraints().isQueryPassThrough()) { - throw new 
AthenaConnectorException("Attempting to readConstraints with Query Passthrough without PartiQL Query", new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("Attempting to readConstraints with Query Passthrough without PartiQL Query", ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } queryPassthrough.verify(recordsRequest.getConstraints().getQueryPassthroughArguments()); @@ -324,11 +324,12 @@ private QueryRequest buildQueryRequest(Split split, String tableName, Schema sch Map expressionAttributeValues = new HashMap<>(); if (rangeKeyFilter != null || nonKeyFilter != null) { try { - expressionAttributeNames.putAll(Jackson.getObjectMapper().readValue(split.getProperty(EXPRESSION_NAMES_METADATA), STRING_MAP_TYPE_REFERENCE)); + ObjectMapper objectMapper = new ObjectMapper(); + expressionAttributeNames.putAll(objectMapper.readValue(split.getProperty(EXPRESSION_NAMES_METADATA), STRING_MAP_TYPE_REFERENCE)); expressionAttributeValues.putAll(EnhancedDocument.fromJson(split.getProperty(EXPRESSION_VALUES_METADATA)).toMap()); } catch (IOException e) { - throw new AthenaConnectorException(e.getMessage(), new ErrorDetails().withErrorCode(FederationSourceErrorCode.InternalServiceException.toString())); + throw new AthenaConnectorException(e.getMessage(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.INTERNAL_SERVICE_EXCEPTION.toString()).build()); } } @@ -391,11 +392,12 @@ private ScanRequest buildScanRequest(Split split, String tableName, Schema schem Map expressionAttributeValues = new HashMap<>(); if (rangeKeyFilter != null || nonKeyFilter != null) { try { - expressionAttributeNames.putAll(Jackson.getObjectMapper().readValue(split.getProperty(EXPRESSION_NAMES_METADATA), STRING_MAP_TYPE_REFERENCE)); + ObjectMapper objectMapper = new ObjectMapper(); + expressionAttributeNames.putAll(objectMapper.readValue(split.getProperty(EXPRESSION_NAMES_METADATA), STRING_MAP_TYPE_REFERENCE)); expressionAttributeValues.putAll(EnhancedDocument.fromJson(split.getProperty(EXPRESSION_VALUES_METADATA)).toMap()); } catch (IOException e) { - throw new AthenaConnectorException(e.getMessage(), new ErrorDetails().withErrorCode(FederationSourceErrorCode.InternalServiceException.toString())); + throw new AthenaConnectorException(e.getMessage(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.INTERNAL_SERVICE_EXCEPTION.toString()).build()); } } @@ -468,7 +470,7 @@ public Map next() } } catch (TimeoutException | ExecutionException e) { - throw new AthenaConnectorException(e.getMessage(), new ErrorDetails().withErrorCode(FederationSourceErrorCode.OperationTimeoutException.toString())); + throw new AthenaConnectorException(e.getMessage(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.OPERATION_TIMEOUT_EXCEPTION.toString()).build()); } currentPageIterator.set(iterator); if (iterator.hasNext()) { diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/qpt/DDBQueryPassthrough.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/qpt/DDBQueryPassthrough.java index 09a250a1fd..68a6d70403 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/qpt/DDBQueryPassthrough.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/qpt/DDBQueryPassthrough.java @@ -21,11 +21,11 @@ import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException; import 
com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature; -import com.amazonaws.services.glue.model.ErrorDetails; -import com.amazonaws.services.glue.model.FederationSourceErrorCode; import com.google.common.collect.ImmutableSet; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.glue.model.ErrorDetails; +import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode; import java.util.Arrays; import java.util.List; @@ -80,7 +80,7 @@ public void customConnectorVerifications(Map engineQptArguments) // Immediately check if the statement starts with "SELECT" if (!upperCaseStatement.startsWith("SELECT")) { - throw new AthenaConnectorException("Statement does not start with SELECT.", new ErrorDetails().withErrorCode(FederationSourceErrorCode.OperationNotSupportedException.toString())); + throw new AthenaConnectorException("Statement does not start with SELECT.", ErrorDetails.builder().errorCode(FederationSourceErrorCode.OPERATION_NOT_SUPPORTED_EXCEPTION.toString()).build()); } // List of disallowed keywords @@ -89,7 +89,7 @@ public void customConnectorVerifications(Map engineQptArguments) // Check if the statement contains any disallowed keywords for (String keyword : disallowedKeywords) { if (upperCaseStatement.contains(keyword)) { - throw new AthenaConnectorException("Unaccepted operation; only SELECT statements are allowed. Found: " + keyword, new ErrorDetails().withErrorCode(FederationSourceErrorCode.OperationNotSupportedException.toString())); + throw new AthenaConnectorException("Unaccepted operation; only SELECT statements are allowed. Found: " + keyword, ErrorDetails.builder().errorCode(FederationSourceErrorCode.OPERATION_NOT_SUPPORTED_EXCEPTION.toString()).build()); } } } diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/resolver/DynamoDBFieldResolver.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/resolver/DynamoDBFieldResolver.java index cebb175715..0a186d7763 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/resolver/DynamoDBFieldResolver.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/resolver/DynamoDBFieldResolver.java @@ -23,12 +23,12 @@ import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException; import com.amazonaws.athena.connectors.dynamodb.util.DDBRecordMetadata; import com.amazonaws.athena.connectors.dynamodb.util.DDBTypeUtils; -import com.amazonaws.services.glue.model.ErrorDetails; -import com.amazonaws.services.glue.model.FederationSourceErrorCode; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.glue.model.ErrorDetails; +import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode; import java.util.Map; @@ -90,7 +90,7 @@ public Object getFieldValue(Field field, Object originalValue) } throw new AthenaConnectorException("Invalid field value encountered in DB record for field: " + field + - ",value: " + fieldValue, new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + ",value: " + fieldValue, ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } // Return the field value of a map key diff --git 
a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/resolver/DynamoDBTableResolver.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/resolver/DynamoDBTableResolver.java index 290359507b..7ae1fd436e 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/resolver/DynamoDBTableResolver.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/resolver/DynamoDBTableResolver.java @@ -24,8 +24,6 @@ import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBPaginatedTables; import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBTable; import com.amazonaws.athena.connectors.dynamodb.util.DDBTableUtils; -import com.amazonaws.services.glue.model.ErrorDetails; -import com.amazonaws.services.glue.model.FederationSourceErrorCode; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.Multimap; import org.apache.arrow.vector.types.pojo.Schema; @@ -35,6 +33,8 @@ import software.amazon.awssdk.services.dynamodb.model.ListTablesRequest; import software.amazon.awssdk.services.dynamodb.model.ListTablesResponse; import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException; +import software.amazon.awssdk.services.glue.model.ErrorDetails; +import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode; import java.util.ArrayList; import java.util.Collection; @@ -121,7 +121,7 @@ public Schema getTableSchema(String tableName) return DDBTableUtils.peekTableForSchema(caseInsensitiveMatch.get(), invoker, ddbClient); } else { - throw new AthenaConnectorException(e.getMessage(), new ErrorDetails().withErrorCode(FederationSourceErrorCode.EntityNotFoundException.toString())); + throw new AthenaConnectorException(e.getMessage(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.ENTITY_NOT_FOUND_EXCEPTION.toString()).build()); } } } diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBPredicateUtils.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBPredicateUtils.java index 3c38e4dec7..bf7aa0854f 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBPredicateUtils.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBPredicateUtils.java @@ -27,12 +27,12 @@ import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException; import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBIndex; import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBTable; -import com.amazonaws.services.glue.model.ErrorDetails; -import com.amazonaws.services.glue.model.FederationSourceErrorCode; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.ProjectionType; +import software.amazon.awssdk.services.glue.model.ErrorDetails; +import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode; import java.util.ArrayList; import java.util.HashSet; @@ -192,7 +192,7 @@ private static void validateColumnRange(Range range) case EXACTLY: break; case BELOW: - throw new AthenaConnectorException("Low marker should never use BELOW bound", new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("Low marker should never use BELOW bound", 
ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); default: throw new AssertionError("Unhandled lower bound: " + range.getLow().getBound()); } @@ -200,7 +200,7 @@ private static void validateColumnRange(Range range) if (!range.getHigh().isUpperUnbounded()) { switch (range.getHigh().getBound()) { case ABOVE: - throw new AthenaConnectorException("High marker should never use ABOVE bound", new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("High marker should never use ABOVE bound", ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); case EXACTLY: break; case BELOW: diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTableUtils.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTableUtils.java index 923d03ec48..98332f78c1 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTableUtils.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTableUtils.java @@ -24,8 +24,6 @@ import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException; import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBIndex; import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBTable; -import com.amazonaws.services.glue.model.ErrorDetails; -import com.amazonaws.services.glue.model.FederationSourceErrorCode; import com.google.common.collect.ImmutableList; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; @@ -44,6 +42,8 @@ import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.ScanResponse; import software.amazon.awssdk.services.dynamodb.model.TableDescription; +import software.amazon.awssdk.services.glue.model.ErrorDetails; +import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode; import java.util.List; import java.util.Map; @@ -170,7 +170,7 @@ public static Schema peekTableForSchema(String tableName, ThrottlingInvoker invo logger.warn("Failed to retrieve table schema due to KMS issue, empty schema for table: {}. 
Error Message: {}", tableName, runtimeException.getMessage()); } else { - throw new AthenaConnectorException(runtimeException.getMessage(), new ErrorDetails().withErrorCode(FederationSourceErrorCode.OperationTimeoutException.toString())); + throw new AthenaConnectorException(runtimeException.getMessage(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.OPERATION_TIMEOUT_EXCEPTION.toString()).build()); } } return schemaBuilder.build(); diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTypeUtils.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTypeUtils.java index b5f27a434a..d1abcdefaa 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTypeUtils.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTypeUtils.java @@ -32,8 +32,6 @@ import com.amazonaws.athena.connector.lambda.domain.predicate.ConstraintProjector; import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException; import com.amazonaws.athena.connectors.dynamodb.resolver.DynamoDBFieldResolver; -import com.amazonaws.services.glue.model.ErrorDetails; -import com.amazonaws.services.glue.model.FederationSourceErrorCode; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.holders.NullableBitHolder; import org.apache.arrow.vector.types.Types; @@ -52,6 +50,8 @@ import software.amazon.awssdk.enhanced.dynamodb.internal.converter.attribute.EnhancedAttributeValue; import software.amazon.awssdk.enhanced.dynamodb.internal.converter.attribute.StringAttributeConverter; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.glue.model.ErrorDetails; +import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode; import software.amazon.awssdk.utils.ImmutableMap; import java.math.BigDecimal; @@ -191,7 +191,7 @@ else if (enhancedAttributeValue.isMap()) { } String attributeTypeName = (value == null || value.getClass() == null) ? 
"null" : enhancedAttributeValue.type().name(); - throw new AthenaConnectorException("Unknown Attribute Value Type[" + attributeTypeName + "] for field[" + key + "]", new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("Unknown Attribute Value Type[" + attributeTypeName + "] for field[" + key + "]", ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } /** @@ -265,7 +265,7 @@ public static Field getArrowFieldFromDDBType(String attributeName, String attrib case MAP: return new Field(attributeName, FieldType.nullable(Types.MinorType.STRUCT.getType()), null); default: - throw new AthenaConnectorException("Unknown type[" + attributeType + "] for field[" + attributeName + "]", new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("Unknown type[" + attributeType + "] for field[" + attributeName + "]", ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } } @@ -385,7 +385,7 @@ public static List coerceListToExpectedType(Object value, Field field, D if (!(value instanceof Collection)) { if (value instanceof Map) { - throw new AthenaConnectorException("Unexpected type (Map) encountered for: " + childField.getName(), new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("Unexpected type (Map) encountered for: " + childField.getName(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } return Collections.singletonList(coerceValueToExpectedType(value, childField, fieldType, recordMetadata)); } @@ -621,7 +621,7 @@ else if (value instanceof Map) { return handleMapType((Map) value); } else { - throw new AthenaConnectorException("Unsupported value type: " + value.getClass(), new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("Unsupported value type: " + value.getClass(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } } @@ -635,7 +635,7 @@ public static AttributeValue jsonToAttributeValue(String jsonString, String key) { EnhancedDocument enhancedDocument = EnhancedDocument.fromJson(jsonString); if (!enhancedDocument.isPresent(key)) { - throw new AthenaConnectorException("Unknown attribute Key", new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("Unknown attribute Key", ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } return enhancedDocument.toMap().get(key); } @@ -658,7 +658,7 @@ else if (firstElement instanceof Number) { } // Add other types if needed // Fallback for unsupported set types - throw new AthenaConnectorException("Unsupported Set element type: " + firstElement.getClass(), new ErrorDetails().withErrorCode(FederationSourceErrorCode.InvalidInputException.toString())); + throw new AthenaConnectorException("Unsupported Set element type: " + firstElement.getClass(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build()); } private static AttributeValue handleListType(List value) diff --git 
a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandlerTest.java b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandlerTest.java index bff76405e8..76d1e504ed 100644 --- a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandlerTest.java +++ b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandlerTest.java @@ -44,18 +44,7 @@ import com.amazonaws.athena.connector.lambda.metadata.MetadataResponse; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.dynamodb.util.DDBTypeUtils; -import com.amazonaws.services.athena.AmazonAthena; - import com.amazonaws.services.dynamodbv2.document.ItemUtils; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.Database; -import com.amazonaws.services.glue.model.GetDatabasesResult; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.GetTablesResult; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.amazonaws.util.json.Jackson; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -74,7 +63,19 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.enhanced.dynamodb.document.EnhancedDocument; +import software.amazon.awssdk.services.athena.AthenaClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.Database; +import software.amazon.awssdk.services.glue.model.GetDatabasesRequest; +import software.amazon.awssdk.services.glue.model.GetDatabasesResponse; +import software.amazon.awssdk.services.glue.model.GetTablesRequest; +import software.amazon.awssdk.services.glue.model.GetTablesResponse; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.glue.paginators.GetDatabasesIterable; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.time.Instant; import java.time.LocalDateTime; @@ -114,6 +115,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.*; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; /** @@ -129,13 +131,13 @@ public class DynamoDBMetadataHandlerTest public TestName testName = new TestName(); @Mock - private AWSGlue glueClient; + private GlueClient glueClient; @Mock - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; @Mock - private AmazonAthena athena; + private AthenaClient athena; private DynamoDBMetadataHandler handler; @@ -162,7 +164,7 @@ public void tearDown() public void doListSchemaNamesGlueError() throws Exception { - when(glueClient.getDatabases(any())).thenThrow(new AmazonServiceException("")); + when(glueClient.getDatabasesPaginator(any(GetDatabasesRequest.class))).thenThrow(new AmazonServiceException("")); ListSchemasRequest req = new ListSchemasRequest(TEST_IDENTITY, TEST_QUERY_ID, 
TEST_CATALOG_NAME); ListSchemasResponse res = handler.doListSchemaNames(allocator, req); @@ -176,12 +178,16 @@ public void doListSchemaNamesGlueError() public void doListSchemaNamesGlue() throws Exception { - GetDatabasesResult result = new GetDatabasesResult().withDatabaseList( - new Database().withName(DEFAULT_SCHEMA), - new Database().withName("ddb").withLocationUri(DYNAMO_DB_FLAG), - new Database().withName("s3").withLocationUri("blah")); + GetDatabasesResponse response = GetDatabasesResponse.builder() + .databaseList( + Database.builder().name(DEFAULT_SCHEMA).build(), + Database.builder().name("ddb").locationUri(DYNAMO_DB_FLAG).build(), + Database.builder().name("s3").locationUri("blah").build()) + .build(); - when(glueClient.getDatabases(any())).thenReturn(result); + GetDatabasesIterable mockIterable = mock(GetDatabasesIterable.class); + when(mockIterable.stream()).thenReturn(Collections.singletonList(response).stream()); + when(glueClient.getDatabasesPaginator(any(GetDatabasesRequest.class))).thenReturn(mockIterable); ListSchemasRequest req = new ListSchemasRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME); ListSchemasResponse res = handler.doListSchemaNames(allocator, req); @@ -202,25 +208,37 @@ public void doListTablesGlueAndDynamo() tableNames.add("table2"); tableNames.add("table3"); - GetTablesResult mockResult = new GetTablesResult(); List tableList = new ArrayList<>(); - tableList.add(new Table().withName("table1") - .withParameters(ImmutableMap.of("classification", "dynamodb")) - .withStorageDescriptor(new StorageDescriptor() - .withLocation("some.location"))); - tableList.add(new Table().withName("table2") - .withParameters(ImmutableMap.of()) - .withStorageDescriptor(new StorageDescriptor() - .withLocation("some.location") - .withParameters(ImmutableMap.of("classification", "dynamodb")))); - tableList.add(new Table().withName("table3") - .withParameters(ImmutableMap.of()) - .withStorageDescriptor(new StorageDescriptor() - .withLocation("arn:aws:dynamodb:us-east-1:012345678910:table/table3"))); - tableList.add(new Table().withName("notADynamoTable").withParameters(ImmutableMap.of()).withStorageDescriptor( - new StorageDescriptor().withParameters(ImmutableMap.of()).withLocation("some_location"))); - mockResult.setTableList(tableList); - when(glueClient.getTables(any())).thenReturn(mockResult); + tableList.add(Table.builder().name("table1") + .parameters(ImmutableMap.of("classification", "dynamodb")) + .storageDescriptor(StorageDescriptor.builder() + .location("some.location") + .build()) + .build()); + tableList.add(Table.builder().name("table2") + .parameters(ImmutableMap.of()) + .storageDescriptor(StorageDescriptor.builder() + .location("some.location") + .parameters(ImmutableMap.of("classification", "dynamodb")) + .build()) + .build()); + tableList.add(Table.builder().name("table3") + .parameters(ImmutableMap.of()) + .storageDescriptor(StorageDescriptor.builder() + .location("arn:aws:dynamodb:us-east-1:012345678910:table/table3") + .build()) + .build()); + tableList.add(Table.builder().name("notADynamoTable") + .parameters(ImmutableMap.of()) + .storageDescriptor(StorageDescriptor.builder() + .location("some_location") + .parameters(ImmutableMap.of()) + .build()) + .build()); + GetTablesResponse mockResponse = GetTablesResponse.builder() + .tableList(tableList) + .build(); + when(glueClient.getTables(any(GetTablesRequest.class))).thenReturn(mockResponse); ListTablesRequest req = new ListTablesRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, DEFAULT_SCHEMA, 
null, UNLIMITED_PAGE_SIZE_VALUE); @@ -257,7 +275,7 @@ public void doListPaginatedTables() public void doGetTable() throws Exception { - when(glueClient.getTable(any())).thenThrow(new AmazonServiceException("")); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenThrow(new AmazonServiceException("")); GetTableRequest req = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, TEST_TABLE_NAME, Collections.emptyMap()); GetTableResponse res = handler.doGetTable(allocator, req); @@ -273,7 +291,7 @@ public void doGetTable() public void doGetEmptyTable() throws Exception { - when(glueClient.getTable(any())).thenThrow(new AmazonServiceException("")); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenThrow(new AmazonServiceException("")); GetTableRequest req = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, TEST_TABLE_2_NAME, Collections.emptyMap()); GetTableResponse res = handler.doGetTable(allocator, req); @@ -288,7 +306,7 @@ public void doGetEmptyTable() public void testCaseInsensitiveResolve() throws Exception { - when(glueClient.getTable(any())).thenThrow(new AmazonServiceException("")); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenThrow(new AmazonServiceException("")); GetTableRequest req = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, TEST_TABLE_2_NAME, Collections.emptyMap()); GetTableResponse res = handler.doGetTable(allocator, req); @@ -594,20 +612,21 @@ public void validateSourceTableNamePropagation() throws Exception { List columns = new ArrayList<>(); - columns.add(new Column().withName("col1").withType("int")); - columns.add(new Column().withName("col2").withType("bigint")); - columns.add(new Column().withName("col3").withType("string")); + columns.add(Column.builder().name("col1").type("int").build()); + columns.add(Column.builder().name("col2").type("bigint").build()); + columns.add(Column.builder().name("col3").type("string").build()); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE, COLUMN_NAME_MAPPING_PROPERTY, "col1=Col1 , col2=Col2 ,col3=Col3", DATETIME_FORMAT_MAPPING_PROPERTY, "col1=datetime1,col3=datetime3 "); - Table table = new Table() - .withParameters(param) - .withPartitionKeys() - .withStorageDescriptor(new StorageDescriptor().withColumns(columns)); - GetTableResult mockResult = new GetTableResult().withTable(table); - when(glueClient.getTable(any())).thenReturn(mockResult); + Table table = Table.builder() + .parameters(param) + .storageDescriptor(StorageDescriptor.builder().columns(columns).build()) + .partitionKeys(Collections.EMPTY_SET) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse tableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(tableResponse); TableName tableName = new TableName(DEFAULT_SCHEMA, "glueTableForTestTable"); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, tableName, Collections.emptyMap()); @@ -635,20 +654,21 @@ public void doGetTableLayoutScanWithTypeOverride() throws Exception { List columns = new ArrayList<>(); - columns.add(new Column().withName("col1").withType("int")); - columns.add(new Column().withName("col2").withType("timestamptz")); - columns.add(new 
Column().withName("col3").withType("string")); + columns.add(Column.builder().name("col1").type("int").build()); + columns.add(Column.builder().name("col2").type("timestamptz").build()); + columns.add(Column.builder().name("col3").type("string").build()); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE, COLUMN_NAME_MAPPING_PROPERTY, "col1=Col1", DATETIME_FORMAT_MAPPING_PROPERTY, "col1=datetime1,col3=datetime3 "); - Table table = new Table() - .withParameters(param) - .withPartitionKeys() - .withStorageDescriptor(new StorageDescriptor().withColumns(columns)); - GetTableResult mockResult = new GetTableResult().withTable(table); - when(glueClient.getTable(any())).thenReturn(mockResult); + Table table = Table.builder() + .parameters(param) + .partitionKeys(Collections.EMPTY_SET) + .storageDescriptor(StorageDescriptor.builder().columns(columns).build()) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse tableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(tableResponse); TableName tableName = new TableName(DEFAULT_SCHEMA, "glueTableForTestTable"); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, tableName, Collections.emptyMap()); diff --git a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java index d9f3f421b1..9972e3fc0f 100644 --- a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java +++ b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java @@ -38,15 +38,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.dynamodb.util.DDBTypeUtils; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.EntityNotFoundException; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.complex.impl.UnionListReader; @@ -60,7 +51,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; - import org.junit.rules.TestName; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -68,7 +58,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.enhanced.dynamodb.document.EnhancedDocument; +import software.amazon.awssdk.services.athena.AthenaClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.EntityNotFoundException; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.Table; +import 
software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.time.LocalDate; import java.time.LocalDateTime; @@ -121,13 +119,13 @@ public class DynamoDBRecordHandlerTest private DynamoDBMetadataHandler metadataHandler; @Mock - private AWSGlue glueClient; + private GlueClient glueClient; @Mock - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; @Mock - private AmazonAthena athena; + private AthenaClient athena; @Rule public TestName testName = new TestName(); @@ -138,7 +136,7 @@ public void setup() logger.info("{}: enter", testName.getMethodName()); allocator = new BlockAllocatorImpl(); - handler = new DynamoDBRecordHandler(ddbClient, mock(AmazonS3.class), mock(AWSSecretsManager.class), mock(AmazonAthena.class), "source_type", com.google.common.collect.ImmutableMap.of()); + handler = new DynamoDBRecordHandler(ddbClient, mock(S3Client.class), mock(SecretsManagerClient.class), mock(AthenaClient.class), "source_type", com.google.common.collect.ImmutableMap.of()); metadataHandler = new DynamoDBMetadataHandler(new LocalKeyFactory(), secretsManager, athena, "spillBucket", "spillPrefix", ddbClient, glueClient, com.google.common.collect.ImmutableMap.of()); } @@ -398,25 +396,26 @@ public void testDateTimeSupportFromGlueTable() throws Exception TimeZone.setDefault(TimeZone.getTimeZone("UTC")); List columns = new ArrayList<>(); - columns.add(new Column().withName("col0").withType("string")); - columns.add(new Column().withName("col1").withType("timestamp")); - columns.add(new Column().withName("col2").withType("timestamp")); - columns.add(new Column().withName("col3").withType("date")); - columns.add(new Column().withName("col4").withType("date")); - columns.add(new Column().withName("col5").withType("timestamptz")); - columns.add(new Column().withName("col6").withType("timestamptz")); - columns.add(new Column().withName("col7").withType("timestamptz")); + columns.add(Column.builder().name("col0").type("string").build()); + columns.add(Column.builder().name("col1").type("timestamp").build()); + columns.add(Column.builder().name("col2").type("timestamp").build()); + columns.add(Column.builder().name("col3").type("date").build()); + columns.add(Column.builder().name("col4").type("date").build()); + columns.add(Column.builder().name("col5").type("timestamptz").build()); + columns.add(Column.builder().name("col6").type("timestamptz").build()); + columns.add(Column.builder().name("col7").type("timestamptz").build()); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE3, COLUMN_NAME_MAPPING_PROPERTY, "col1=Col1 , col2=Col2 ,col3=Col3, col4=Col4,col5=Col5,col6=Col6,col7=Col7", DATETIME_FORMAT_MAPPING_PROPERTY, "col1=yyyyMMdd'S'HHmmss,col3=dd/MM/yyyy "); - Table table = new Table() - .withParameters(param) - .withPartitionKeys() - .withStorageDescriptor(new StorageDescriptor().withColumns(columns)); - GetTableResult mockResult = new GetTableResult().withTable(table); - when(glueClient.getTable(any())).thenReturn(mockResult); + Table table = Table.builder() + .parameters(param) + .partitionKeys(Collections.EMPTY_SET) + .storageDescriptor(StorageDescriptor.builder().columns(columns).build()) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse tableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); + 
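// The Glue-related test hunks in this patch all apply the same two
// mechanical conversions, condensed here into one minimal sketch (the
// names mirror the test code in this diff; nothing new is introduced):
// SDK v1 mutable POJOs such as new Table().withName(...) become v2
// immutable builders, and Mockito stubs move from a bare any() to typed
// request matchers so they bind unambiguously to the v2 GlueClient
// overloads. Note the fully qualified glue.model.GetTableResponse in
// the tests above: it avoids a clash with the connector's own
// GetTableResponse, which these tests already import.
//
// Assumes: import software.amazon.awssdk.services.glue.model.*;
//          import com.google.common.collect.ImmutableMap;
//          import static org.mockito.ArgumentMatchers.any;
//          import static org.mockito.Mockito.when;
Table glueTable = Table.builder()                          // v2 models are immutable
        .name("table1")
        .parameters(ImmutableMap.of("classification", "dynamodb"))
        .storageDescriptor(StorageDescriptor.builder()
                .location("some.location")
                .build())
        .build();
GetTableResponse glueResponse = GetTableResponse.builder().table(glueTable).build();
when(glueClient.getTable(any(GetTableRequest.class))).thenReturn(glueResponse);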
when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(tableResponse); TableName tableName = new TableName(DEFAULT_SCHEMA, TEST_TABLE3); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, tableName, Collections.emptyMap()); @@ -463,17 +462,18 @@ public void testDateTimeSupportFromGlueTable() throws Exception public void testStructWithNullFromGlueTable() throws Exception { List columns = new ArrayList<>(); - columns.add(new Column().withName("col0").withType("string")); - columns.add(new Column().withName("col1").withType("struct")); + columns.add(Column.builder().name("col0").type("string").build()); + columns.add(Column.builder().name("col1").type("struct").build()); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE4, COLUMN_NAME_MAPPING_PROPERTY, "col0=Col0,col1=Col1,col2=Col2"); - Table table = new Table() - .withParameters(param) - .withPartitionKeys() - .withStorageDescriptor(new StorageDescriptor().withColumns(columns)); - GetTableResult mockResult = new GetTableResult().withTable(table); - when(glueClient.getTable(any())).thenReturn(mockResult); + Table table = Table.builder() + .parameters(param) + .partitionKeys(Collections.EMPTY_SET) + .storageDescriptor(StorageDescriptor.builder().columns(columns).build()) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse tableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(tableResponse); TableName tableName = new TableName(DEFAULT_SCHEMA, TEST_TABLE4); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, tableName, Collections.emptyMap()); @@ -520,7 +520,7 @@ public void testStructWithNullFromGlueTable() throws Exception @Test public void testStructWithNullFromDdbTable() throws Exception { - when(glueClient.getTable(any())).thenThrow(new EntityNotFoundException("")); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenThrow(EntityNotFoundException.builder().message("").build()); TableName tableName = new TableName(DEFAULT_SCHEMA, TEST_TABLE4); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, tableName, Collections.emptyMap()); @@ -571,19 +571,20 @@ public void testMapWithSchemaFromGlueTable() throws Exception } List columns = new ArrayList<>(); - columns.add(new Column().withName("col0").withType("string")); - columns.add(new Column().withName("outermap").withType("MAP>")); - columns.add(new Column().withName("structcol").withType("MAP>")); + columns.add(Column.builder().name("col0").type("string").build()); + columns.add(Column.builder().name("outermap").type("MAP>").build()); + columns.add(Column.builder().name("structcol").type("MAP>").build()); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE5, COLUMN_NAME_MAPPING_PROPERTY, "col0=Col0,col1=Col1,col2=Col2"); - Table table = new Table() - .withParameters(param) - .withPartitionKeys() - .withStorageDescriptor(new StorageDescriptor().withColumns(columns)); - GetTableResult mockResult = new GetTableResult().withTable(table); - when(glueClient.getTable(any())).thenReturn(mockResult); + Table table = Table.builder() + .parameters(param) + .partitionKeys(Collections.EMPTY_SET) + 
.storageDescriptor(StorageDescriptor.builder().columns(columns).build()) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse tableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(tableResponse); TableName tableName = new TableName(DEFAULT_SCHEMA, TEST_TABLE5); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, tableName, Collections.emptyMap()); @@ -624,19 +625,20 @@ public void testMapWithSchemaFromGlueTable() throws Exception public void testStructWithSchemaFromGlueTable() throws Exception { List columns = new ArrayList<>(); - columns.add(new Column().withName("col0").withType("string")); - columns.add(new Column().withName("outermap").withType("struct>")); - columns.add(new Column().withName("structcol").withType("struct>")); + columns.add(Column.builder().name("col0").type("string").build()); + columns.add(Column.builder().name("outermap").type("struct>").build()); + columns.add(Column.builder().name("structcol").type("struct>").build()); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE6, COLUMN_NAME_MAPPING_PROPERTY, "col0=Col0,col1=Col1,col2=Col2"); - Table table = new Table() - .withParameters(param) - .withPartitionKeys() - .withStorageDescriptor(new StorageDescriptor().withColumns(columns)); - GetTableResult mockResult = new GetTableResult().withTable(table); - when(glueClient.getTable(any())).thenReturn(mockResult); + Table table = Table.builder() + .parameters(param) + .partitionKeys(Collections.EMPTY_SET) + .storageDescriptor(StorageDescriptor.builder().columns(columns).build()) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse tableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(tableResponse); TableName tableName = new TableName(DEFAULT_SCHEMA, TEST_TABLE6); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, tableName, Collections.emptyMap()); @@ -678,20 +680,21 @@ public void testStructWithSchemaFromGlueTable() throws Exception public void testListWithSchemaFromGlueTable() throws Exception { List columns = new ArrayList<>(); - columns.add(new Column().withName("col0").withType("string")); - columns.add(new Column().withName("stringList").withType("ARRAY ")); - columns.add(new Column().withName("intList").withType("ARRAY ")); - columns.add(new Column().withName("listStructCol").withType("array>")); + columns.add(Column.builder().name("col0").type("string").build()); + columns.add(Column.builder().name("stringList").type("ARRAY ").build()); + columns.add(Column.builder().name("intList").type("ARRAY ").build()); + columns.add(Column.builder().name("listStructCol").type("array>").build()); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE7, COLUMN_NAME_MAPPING_PROPERTY, "col0=Col0,col1=Col1,col2=Col2"); - Table table = new Table() - .withParameters(param) - .withPartitionKeys() - .withStorageDescriptor(new StorageDescriptor().withColumns(columns)); - GetTableResult mockResult = new GetTableResult().withTable(table); - when(glueClient.getTable(any())).thenReturn(mockResult); + Table table = Table.builder() + .parameters(param) + .partitionKeys(Collections.EMPTY_SET) + 
.storageDescriptor(StorageDescriptor.builder().columns(columns).build()) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse tableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(tableResponse); TableName tableName = new TableName(DEFAULT_SCHEMA, TEST_TABLE7); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, tableName, Collections.emptyMap()); @@ -760,18 +763,19 @@ public void testNumMapWithSchemaFromGlueTable() throws Exception } List columns = new ArrayList<>(); - columns.add(new Column().withName("col0").withType("string")); - columns.add(new Column().withName("nummap").withType("map")); + columns.add(Column.builder().name("col0").type("string").build()); + columns.add(Column.builder().name("nummap").type("map").build()); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE8, COLUMN_NAME_MAPPING_PROPERTY, "col0=Col0,col1=Col1,col2=Col2"); - Table table = new Table() - .withParameters(param) - .withPartitionKeys() - .withStorageDescriptor(new StorageDescriptor().withColumns(columns)); - GetTableResult mockResult = new GetTableResult().withTable(table); - when(glueClient.getTable(any())).thenReturn(mockResult); + Table table = Table.builder() + .parameters(param) + .partitionKeys(Collections.EMPTY_SET) + .storageDescriptor(StorageDescriptor.builder().columns(columns).build()) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse mockResult = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(mockResult); TableName tableName = new TableName(DEFAULT_SCHEMA, TEST_TABLE8); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, TEST_CATALOG_NAME, tableName, Collections.emptyMap()); @@ -824,18 +828,19 @@ public void testNumMapWithSchemaFromGlueTable() throws Exception public void testNumStructWithSchemaFromGlueTable() throws Exception { List columns = new ArrayList<>(); - columns.add(new Column().withName("col0").withType("string")); - columns.add(new Column().withName("nummap").withType("struct")); + columns.add(Column.builder().name("col0").type("string").build()); + columns.add(Column.builder().name("nummap").type("struct").build()); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE8, COLUMN_NAME_MAPPING_PROPERTY, "col0=Col0,col1=Col1,col2=Col2"); - Table table = new Table() - .withParameters(param) - .withPartitionKeys() - .withStorageDescriptor(new StorageDescriptor().withColumns(columns)); - GetTableResult mockResult = new GetTableResult().withTable(table); - when(glueClient.getTable(any())).thenReturn(mockResult); + Table table = Table.builder() + .parameters(param) + .partitionKeys(Collections.EMPTY_SET) + .storageDescriptor(StorageDescriptor.builder().columns(columns).build()) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse mockResult = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); + when(glueClient.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(mockResult); TableName tableName = new TableName(DEFAULT_SCHEMA, TEST_TABLE8); GetTableRequest getTableRequest = new GetTableRequest(TEST_IDENTITY, TEST_QUERY_ID, 
TEST_CATALOG_NAME, tableName, Collections.emptyMap()); diff --git a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDbIntegTest.java b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDbIntegTest.java index 821c9e6b87..4e23966c90 100644 --- a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDbIntegTest.java +++ b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDbIntegTest.java @@ -20,7 +20,6 @@ package com.amazonaws.athena.connectors.dynamodb; import com.amazonaws.athena.connector.integ.IntegrationTestBase; -import com.amazonaws.services.athena.model.Row; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.slf4j.Logger; @@ -32,6 +31,7 @@ import software.amazon.awscdk.services.iam.PolicyDocument; import software.amazon.awscdk.services.iam.PolicyStatement; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.athena.model.Row; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import java.nio.ByteBuffer; @@ -239,13 +239,13 @@ public void selectColumnWithPredicateIntegTest() String query = String.format("select title from %s.%s.%s where year > 2000;", lambdaFunctionName, dynamodbDbName, movieTableName); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List titles = new ArrayList<>(); - rows.forEach(row -> titles.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> titles.add(row.data().get(0).varCharValue())); logger.info("Titles: {}", titles); assertEquals("Wrong number of DB records found.", 1, titles.size()); assertTrue("Movie title not found: Interstellar.", titles.contains("Interstellar")); @@ -265,13 +265,13 @@ public void selectFloat8TypeTest() String query = String.format("select float8_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Double.valueOf(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(Double.valueOf(row.data().get(0).varCharValue()))); AssertJUnit.assertEquals("Wrong number of DB records found.", 1, values.size()); AssertJUnit.assertTrue("Float8 not found: " + 1E-130, values.contains(1E-130)); } diff --git a/athena-elasticsearch/Dockerfile b/athena-elasticsearch/Dockerfile new file mode 100644 index 0000000000..d153e67a95 --- /dev/null +++ b/athena-elasticsearch/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-elasticsearch-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-elasticsearch-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.elasticsearch.ElasticsearchCompositeHandler" ] \ No newline at end of file diff --git a/athena-elasticsearch/athena-elasticsearch.yaml b/athena-elasticsearch/athena-elasticsearch.yaml index ff12e031ad..c2e6603bf9 100644 --- a/athena-elasticsearch/athena-elasticsearch.yaml +++ 
b/athena-elasticsearch/athena-elasticsearch.yaml @@ -102,10 +102,9 @@ Resources: query_timeout_search: !Ref QueryTimeoutSearch query_scroll_timeout: !Ref QueryScrollTimeout FunctionName: !Sub "${AthenaCatalogName}" - Handler: "com.amazonaws.athena.connectors.elasticsearch.ElasticsearchCompositeHandler" - CodeUri: "./target/athena-elasticsearch-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-elasticsearch:2022.47.1' Description: "The Elasticsearch Lambda Connector provides Athena users the ability to query data stored on Elasticsearch clusters." - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-elasticsearch/pom.xml b/athena-elasticsearch/pom.xml index 13f08500c8..316cf40b7f 100644 --- a/athena-elasticsearch/pom.xml +++ b/athena-elasticsearch/pom.xml @@ -62,60 +62,6 @@ ${log4j2Version} runtime - - com.amazonaws - jmespath-java - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - - - com.amazonaws - aws-java-sdk-core - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - software.amazon.awscdk @@ -135,11 +81,11 @@ elasticsearch-rest-high-level-client 7.10.2 - + - com.amazonaws - aws-java-sdk-elasticsearch - ${aws-sdk.version} + software.amazon.awssdk + elasticsearch + ${aws-sdk-v2.version} diff --git a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AWSRequestSigningApacheInterceptor.java b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AWSRequestSigningApacheInterceptor.java index 4108d663b8..2c3f58e215 100644 --- a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AWSRequestSigningApacheInterceptor.java +++ b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AWSRequestSigningApacheInterceptor.java @@ -19,10 +19,6 @@ */ package com.amazonaws.athena.connectors.elasticsearch; -import com.amazonaws.DefaultRequest; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.Signer; -import com.amazonaws.http.HttpMethodName; import org.apache.http.Header; import org.apache.http.HttpEntityEnclosingRequest; import org.apache.http.HttpException; @@ -34,9 +30,15 @@ import org.apache.http.entity.BasicHttpEntity; import org.apache.http.message.BasicHeader; import org.apache.http.protocol.HttpContext; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.http.auth.spi.signer.SignedRequest; import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; @@ -47,8 +49,8 @@ import static 
org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; /** - * An {@link HttpRequestInterceptor} that signs requests using any AWS {@link Signer} - * and {@link AWSCredentialsProvider}. + * An {@link HttpRequestInterceptor} that signs requests using any AWS {@link AwsV4HttpSigner} + * and {@link AwsCredentialsProvider}. */ public class AWSRequestSigningApacheInterceptor implements HttpRequestInterceptor { @@ -61,34 +63,35 @@ public class AWSRequestSigningApacheInterceptor implements HttpRequestIntercepto /** * The particular signer implementation. */ - private final Signer signer; + private final AwsV4HttpSigner signer; /** * The source of AWS credentials for signing. */ - private final AWSCredentialsProvider awsCredentialsProvider; + private final AwsCredentialsProvider awsCredentialsProvider; + private final String region; /** - * - * @param service service that we're connecting to - * @param signer particular signer implementation + * @param service service that we're connecting to + * @param signer particular signer implementation * @param awsCredentialsProvider source of AWS credentials for signing */ public AWSRequestSigningApacheInterceptor(final String service, - final Signer signer, - final AWSCredentialsProvider awsCredentialsProvider) + final AwsV4HttpSigner signer, + final AwsCredentialsProvider awsCredentialsProvider, + final String region) { this.service = service; this.signer = signer; this.awsCredentialsProvider = awsCredentialsProvider; + this.region = region; } /** * {@inheritDoc} */ @Override - public void process(final HttpRequest request, final HttpContext context) - throws HttpException, IOException + public void process(final HttpRequest request, final HttpContext context) throws HttpException, IOException { URIBuilder uriBuilder; try { @@ -98,55 +101,61 @@ public void process(final HttpRequest request, final HttpContext context) throw new IOException("Invalid URI", e); } - // Copy Apache HttpRequest to AWS DefaultRequest - DefaultRequest signableRequest = new DefaultRequest<>(service); - - HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST); - if (host != null) { - signableRequest.setEndpoint(URI.create(host.toURI())); - } - final HttpMethodName httpMethod = - HttpMethodName.fromValue(request.getRequestLine().getMethod()); - signableRequest.setHttpMethod(httpMethod); + // Build the SdkHttpFullRequest + SdkHttpFullRequest.Builder signableRequest = null; try { - signableRequest.setResourcePath(uriBuilder.build().getRawPath()); + signableRequest = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.fromValue(request.getRequestLine().getMethod())) // Set HTTP Method + .encodedPath(uriBuilder.build().getRawPath()) // Set Resource Path + .rawQueryParameters(nvpToMapParams(uriBuilder.getQueryParams())) // Set Query Parameters + .headers(headerArrayToMap(request.getAllHeaders())); } catch (URISyntaxException e) { throw new IOException("Invalid URI", e); } + // Set the endpoint (host) if present in the context + HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST); + if (host != null) { + signableRequest.uri(URI.create(host.toURI())); // Set the base endpoint URL + } + + // Handle content/body if it's an HttpEntityEnclosingRequest if (request instanceof HttpEntityEnclosingRequest) { - HttpEntityEnclosingRequest httpEntityEnclosingRequest = - (HttpEntityEnclosingRequest) request; + HttpEntityEnclosingRequest httpEntityEnclosingRequest = (HttpEntityEnclosingRequest) request; if (httpEntityEnclosingRequest.getEntity() != null) { - 
signableRequest.setContent(httpEntityEnclosingRequest.getEntity().getContent()); + InputStream contentStream = httpEntityEnclosingRequest.getEntity().getContent(); + signableRequest.contentStreamProvider(() -> contentStream); // Set content provider } else { - // This is a workaround from here: https://github.com/aws/aws-sdk-java/issues/2078 - signableRequest.setContent(new ByteArrayInputStream(new byte[0])); + // Workaround: provide an empty stream if no entity is present + signableRequest.contentStreamProvider(() -> new ByteArrayInputStream(new byte[0])); } } - signableRequest.setParameters(nvpToMapParams(uriBuilder.getQueryParams())); - signableRequest.setHeaders(headerArrayToMap(request.getAllHeaders())); - - // Sign it - signer.sign(signableRequest, awsCredentialsProvider.getCredentials()); - // Now copy everything back - request.setHeaders(mapToHeaderArray(signableRequest.getHeaders())); + // Sign the request + SdkHttpFullRequest.Builder finalSignableRequest = signableRequest; + SignedRequest signedRequest = + signer.sign(r -> r.identity(awsCredentialsProvider.resolveCredentials()) + .request(finalSignableRequest.build()) + .payload(finalSignableRequest.contentStreamProvider()) + .putProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, service) + .putProperty(AwsV4HttpSigner.REGION_NAME, region)); // Required for S3 only + // Now copy everything back to the original request (including signed headers) + request.setHeaders(mapToHeaderArray(signedRequest.request().headers())); + + // If the request has an entity (body), copy it back to the original request if (request instanceof HttpEntityEnclosingRequest) { - HttpEntityEnclosingRequest httpEntityEnclosingRequest = - (HttpEntityEnclosingRequest) request; + HttpEntityEnclosingRequest httpEntityEnclosingRequest = (HttpEntityEnclosingRequest) request; if (httpEntityEnclosingRequest.getEntity() != null) { BasicHttpEntity basicHttpEntity = new BasicHttpEntity(); - basicHttpEntity.setContent(signableRequest.getContent()); + basicHttpEntity.setContent(signableRequest.contentStreamProvider().newStream()); httpEntityEnclosingRequest.setEntity(basicHttpEntity); } } } /** - * * @param params list of HTTP query params as NameValuePairs * @return a multimap of HTTP query params */ @@ -165,12 +174,13 @@ private static Map> nvpToMapParams(final List headerArrayToMap(final Header[] headers) + private static Map> headerArrayToMap(final Header[] headers) { - Map headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + Map> headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); for (Header header : headers) { if (!skipHeader(header)) { - headersMap.put(header.getName(), header.getValue()); + // If the header name already exists, add the new value to the list + headersMap.computeIfAbsent(header.getName(), k -> new ArrayList<>()).add(header.getValue()); } } return headersMap; @@ -191,12 +201,12 @@ private static boolean skipHeader(final Header header) * @param mapHeaders Map of header entries * @return modeled Header objects */ - private static Header[] mapToHeaderArray(final Map mapHeaders) + private static Header[] mapToHeaderArray(final Map> mapHeaders) { Header[] headers = new Header[mapHeaders.size()]; int i = 0; - for (Map.Entry headerEntry : mapHeaders.entrySet()) { - headers[i++] = new BasicHeader(headerEntry.getKey(), headerEntry.getValue()); + for (Map.Entry> headerEntry : mapHeaders.entrySet()) { + headers[i++] = new BasicHeader(headerEntry.getKey(), headerEntry.getValue().get(0)); } return headers; } diff --git 
a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsElasticsearchFactory.java b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsElasticsearchFactory.java index 1d1dd9eaa4..0cf4c0005b 100644 --- a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsElasticsearchFactory.java +++ b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsElasticsearchFactory.java @@ -7,9 +7,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,8 +19,7 @@ */ package com.amazonaws.athena.connectors.elasticsearch; -import com.amazonaws.services.elasticsearch.AWSElasticsearch; -import com.amazonaws.services.elasticsearch.AWSElasticsearchClientBuilder; +import software.amazon.awssdk.services.elasticsearch.ElasticsearchClient; /** * This factory class provides an AWS ES Client. @@ -31,8 +30,8 @@ public class AwsElasticsearchFactory * Gets a default AWS ES client. * @return default AWS ES client. */ - public AWSElasticsearch getClient() + public ElasticsearchClient getClient() { - return AWSElasticsearchClientBuilder.defaultClient(); + return ElasticsearchClient.create(); } } diff --git a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsRestHighLevelClient.java b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsRestHighLevelClient.java index 3142dfda33..39c4ddd258 100644 --- a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsRestHighLevelClient.java +++ b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsRestHighLevelClient.java @@ -19,8 +19,6 @@ */ package com.amazonaws.athena.connectors.elasticsearch; -import com.amazonaws.auth.AWS4Signer; -import com.amazonaws.auth.AWSCredentialsProvider; import com.google.common.base.Splitter; import org.apache.http.HttpHost; import org.apache.http.HttpRequestInterceptor; @@ -46,6 +44,9 @@ import org.elasticsearch.search.SearchHit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.services.elasticsearch.ElasticsearchClient; import java.io.IOException; import java.util.LinkedHashMap; @@ -196,7 +197,7 @@ public static class Builder { private final String endpoint; private final RestClientBuilder clientBuilder; - private final AWS4Signer signer; + private final AwsV4HttpSigner signer; private final Splitter domainSplitter; /** @@ -207,7 +208,7 @@ public Builder(String endpoint) { this.endpoint = endpoint; this.clientBuilder = RestClient.builder(HttpHost.create(this.endpoint)); - this.signer = new AWS4Signer(); + this.signer = AwsV4HttpSigner.create(); this.domainSplitter = Splitter.on("."); } @@ -216,7 +217,7 @@ public Builder(String endpoint) * @param credentialsProvider is the AWS credentials provider. * @return self. 
*/ - public Builder withCredentials(AWSCredentialsProvider credentialsProvider) + public Builder withCredentials(AwsCredentialsProvider credentialsProvider) { /** * endpoint: @@ -231,16 +232,13 @@ public Builder withCredentials(AWSCredentialsProvider credentialsProvider) */ List domainSplits = domainSplitter.splitToList(endpoint); + HttpRequestInterceptor interceptor; if (domainSplits.size() > 1) { - signer.setRegionName(domainSplits.get(1)); - signer.setServiceName("es"); - } - - HttpRequestInterceptor interceptor = - new AWSRequestSigningApacheInterceptor(signer.getServiceName(), signer, credentialsProvider); + interceptor = new AWSRequestSigningApacheInterceptor(ElasticsearchClient.SERVICE_NAME, signer, credentialsProvider, domainSplits.get(1)); - clientBuilder.setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder - .addInterceptorLast(interceptor)); + clientBuilder.setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder + .addInterceptorLast(interceptor)); + } return this; } diff --git a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsRestHighLevelClientFactory.java b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsRestHighLevelClientFactory.java index 422c3884dc..6286d64eda 100644 --- a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsRestHighLevelClientFactory.java +++ b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/AwsRestHighLevelClientFactory.java @@ -19,9 +19,9 @@ */ package com.amazonaws.athena.connectors.elasticsearch; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -100,7 +100,7 @@ private AwsRestHighLevelClient createClient(String endpoint) { if (useAwsCredentials) { return new AwsRestHighLevelClient.Builder(endpoint) - .withCredentials(new DefaultAWSCredentialsProviderChain()).build(); + .withCredentials(DefaultCredentialsProvider.create()).build(); } else { Matcher credentials = credentialsPattern.matcher(endpoint); diff --git a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchDomainMapProvider.java b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchDomainMapProvider.java index 63df6af55c..b3051f842d 100644 --- a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchDomainMapProvider.java +++ b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchDomainMapProvider.java @@ -19,15 +19,14 @@ */ package com.amazonaws.athena.connectors.elasticsearch; -import com.amazonaws.services.elasticsearch.AWSElasticsearch; -import com.amazonaws.services.elasticsearch.model.DescribeElasticsearchDomainsRequest; -import com.amazonaws.services.elasticsearch.model.DescribeElasticsearchDomainsResult; -import com.amazonaws.services.elasticsearch.model.ListDomainNamesRequest; -import com.amazonaws.services.elasticsearch.model.ListDomainNamesResult; import com.google.common.base.Splitter; import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.elasticsearch.ElasticsearchClient; +import 
software.amazon.awssdk.services.elasticsearch.model.DescribeElasticsearchDomainsRequest; +import software.amazon.awssdk.services.elasticsearch.model.DescribeElasticsearchDomainsResponse; +import software.amazon.awssdk.services.elasticsearch.model.ListDomainNamesResponse; import java.util.ArrayList; import java.util.HashMap; @@ -100,14 +99,14 @@ public Map getDomainMap(String domainMapping) private Map getDomainMapFromAmazonElasticsearch() throws RuntimeException { - final AWSElasticsearch awsEsClient = awsElasticsearchFactory.getClient(); + final ElasticsearchClient awsEsClient = awsElasticsearchFactory.getClient(); final Map domainMap = new HashMap<>(); try { - ListDomainNamesResult listDomainNamesResult = awsEsClient.listDomainNames(new ListDomainNamesRequest()); + ListDomainNamesResponse listDomainNamesResponse = awsEsClient.listDomainNames(); List domainNames = new ArrayList<>(); - listDomainNamesResult.getDomainNames().forEach(domainInfo -> - domainNames.add(domainInfo.getDomainName())); + listDomainNamesResponse.domainNames().forEach(domainInfo -> + domainNames.add(domainInfo.domainName())); int startDomainNameIndex = 0; int endDomainNameIndex; @@ -117,13 +116,13 @@ private Map getDomainMapFromAmazonElasticsearch() // DescribeElasticsearchDomains - Describes the domain configuration for up to five specified Amazon // ES domains. Create multiple requests when list of Domain Names > 5. endDomainNameIndex = Math.min(startDomainNameIndex + 5, maxDomainNames); - DescribeElasticsearchDomainsRequest describeDomainsRequest = new DescribeElasticsearchDomainsRequest() - .withDomainNames(domainNames.subList(startDomainNameIndex, endDomainNameIndex)); - DescribeElasticsearchDomainsResult describeDomainsResult = + DescribeElasticsearchDomainsRequest describeDomainsRequest = DescribeElasticsearchDomainsRequest + .builder().domainNames(domainNames.subList(startDomainNameIndex, endDomainNameIndex)).build(); + DescribeElasticsearchDomainsResponse describeDomainsResult = awsEsClient.describeElasticsearchDomains(describeDomainsRequest); - describeDomainsResult.getDomainStatusList().forEach(domainStatus -> { - String domainEndpoint = (domainStatus.getEndpoint() == null) ? domainStatus.getEndpoints().get("vpc") : domainStatus.getEndpoint(); - domainMap.put(domainStatus.getDomainName(), endpointPrefix + domainEndpoint); + describeDomainsResult.domainStatusList().forEach(domainStatus -> { + String domainEndpoint = (domainStatus.endpoint() == null) ? 
domainStatus.endpoints().get("vpc") : domainStatus.endpoint(); + domainMap.put(domainStatus.domainName(), endpointPrefix + domainEndpoint); }); startDomainNameIndex = endDomainNameIndex; } @@ -138,7 +137,7 @@ private Map getDomainMapFromAmazonElasticsearch() throw new RuntimeException("Unable to create domain map: " + error.getMessage(), error); } finally { - awsEsClient.shutdown(); + awsEsClient.close(); } } diff --git a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchMetadataHandler.java b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchMetadataHandler.java index c51e32c100..836f0635b4 100644 --- a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchMetadataHandler.java +++ b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchMetadataHandler.java @@ -41,9 +41,6 @@ import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connectors.elasticsearch.qpt.ElasticsearchQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; @@ -54,6 +51,9 @@ import org.elasticsearch.client.indices.GetIndexResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.util.Arrays; @@ -109,7 +109,7 @@ public class ElasticsearchMetadataHandler protected static final String INDEX_KEY = "index"; - private final AWSGlue awsGlue; + private final GlueClient awsGlue; private final AwsRestHighLevelClientFactory clientFactory; private final ElasticsearchDomainMapProvider domainMapProvider; @@ -130,10 +130,10 @@ public ElasticsearchMetadataHandler(Map configOptions) @VisibleForTesting protected ElasticsearchMetadataHandler( - AWSGlue awsGlue, + GlueClient awsGlue, EncryptionKeyFactory keyFactory, - AWSSecretsManager awsSecretsManager, - AmazonAthena athena, + SecretsManagerClient awsSecretsManager, + AthenaClient athena, String spillBucket, String spillPrefix, ElasticsearchDomainMapProvider domainMapProvider, diff --git a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java index 2307ddcd46..1d90956ad1 100644 --- a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java +++ b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java @@ -27,12 +27,6 @@ import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.elasticsearch.qpt.ElasticsearchQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; 
-import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; import org.elasticsearch.action.search.ClearScrollRequest; @@ -48,6 +42,9 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.util.Iterator; @@ -91,8 +88,8 @@ public class ElasticsearchRecordHandler public ElasticsearchRecordHandler(Map configOptions) { - super(AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), SOURCE_TYPE, configOptions); + super(S3Client.create(), SecretsManagerClient.create(), + AthenaClient.create(), SOURCE_TYPE, configOptions); this.typeUtils = new ElasticsearchTypeUtils(); this.clientFactory = new AwsRestHighLevelClientFactory(configOptions.getOrDefault(AUTO_DISCOVER_ENDPOINT, "").equalsIgnoreCase("true")); @@ -102,9 +99,9 @@ public ElasticsearchRecordHandler(Map configOptions) @VisibleForTesting protected ElasticsearchRecordHandler( - AmazonS3 amazonS3, - AWSSecretsManager secretsManager, - AmazonAthena amazonAthena, + S3Client amazonS3, + SecretsManagerClient secretsManager, + AthenaClient amazonAthena, AwsRestHighLevelClientFactory clientFactory, long queryTimeout, long scrollTimeout, diff --git a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchDomainMapProviderTest.java b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchDomainMapProviderTest.java index 2f2895d34a..7dfd58be3d 100644 --- a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchDomainMapProviderTest.java +++ b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchDomainMapProviderTest.java @@ -7,9 +7,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -19,25 +19,24 @@ */ package com.amazonaws.athena.connectors.elasticsearch; -import com.amazonaws.services.elasticsearch.AWSElasticsearch; -import com.amazonaws.services.elasticsearch.model.DomainInfo; -import com.amazonaws.services.elasticsearch.model.DescribeElasticsearchDomainsRequest; -import com.amazonaws.services.elasticsearch.model.DescribeElasticsearchDomainsResult; -import com.amazonaws.services.elasticsearch.model.ElasticsearchDomainStatus; -import com.amazonaws.services.elasticsearch.model.ListDomainNamesResult; import com.google.common.collect.ImmutableList; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.elasticsearch.ElasticsearchClient; +import software.amazon.awssdk.services.elasticsearch.model.DescribeElasticsearchDomainsRequest; +import software.amazon.awssdk.services.elasticsearch.model.DescribeElasticsearchDomainsResponse; +import software.amazon.awssdk.services.elasticsearch.model.DomainInfo; +import software.amazon.awssdk.services.elasticsearch.model.ElasticsearchDomainStatus; +import software.amazon.awssdk.services.elasticsearch.model.ListDomainNamesResponse; import java.util.ArrayList; import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -83,40 +82,40 @@ public void getDomainMapFromAwsElasticsearchTest() AwsElasticsearchFactory mockElasticsearchFactory = mock(AwsElasticsearchFactory.class); ElasticsearchDomainMapProvider domainProvider = new ElasticsearchDomainMapProvider(true, mockElasticsearchFactory); - AWSElasticsearch mockClient = mock(AWSElasticsearch.class); - ListDomainNamesResult mockDomainInfo = mock(ListDomainNamesResult.class); + ElasticsearchClient mockClient = mock(ElasticsearchClient.class); + ListDomainNamesResponse mockDomainInfo = mock(ListDomainNamesResponse.class); List domainNames = ImmutableList.of("domain1", "domain2", "domain3", "domain4", "domain5", "domain6"); List domainInfo = new ArrayList<>(); List domainStatus = new ArrayList<>(); domainNames.forEach(domainName -> { - domainInfo.add(new DomainInfo().withDomainName(domainName)); - domainStatus.add(new ElasticsearchDomainStatus() - .withDomainName(domainName) - .withEndpoint("www.domain." + domainName)); + domainInfo.add(DomainInfo.builder().domainName(domainName).build()); + domainStatus.add(ElasticsearchDomainStatus.builder() + .domainName(domainName) + .endpoint("www.domain." 
+ domainName).build()); }); when(mockElasticsearchFactory.getClient()).thenReturn(mockClient); - when(mockClient.listDomainNames(any())).thenReturn(mockDomainInfo); - when(mockDomainInfo.getDomainNames()).thenReturn(domainInfo); - - when(mockClient.describeElasticsearchDomains(new DescribeElasticsearchDomainsRequest() - .withDomainNames(domainNames.subList(0, 5)))) - .thenReturn(new DescribeElasticsearchDomainsResult() - .withDomainStatusList(domainStatus.subList(0, 5))); - when(mockClient.describeElasticsearchDomains(new DescribeElasticsearchDomainsRequest() - .withDomainNames(domainNames.subList(5, 6)))) - .thenReturn(new DescribeElasticsearchDomainsResult() - .withDomainStatusList(domainStatus.subList(5, 6))); + when(mockClient.listDomainNames()).thenReturn(mockDomainInfo); + when(mockDomainInfo.domainNames()).thenReturn(domainInfo); + + when(mockClient.describeElasticsearchDomains(DescribeElasticsearchDomainsRequest.builder(). + domainNames(domainNames.subList(0, 5)).build())) + .thenReturn(DescribeElasticsearchDomainsResponse.builder() + .domainStatusList(domainStatus.subList(0, 5)).build()); + when(mockClient.describeElasticsearchDomains(DescribeElasticsearchDomainsRequest.builder() + .domainNames(domainNames.subList(5, 6)).build())) + .thenReturn(DescribeElasticsearchDomainsResponse.builder() + .domainStatusList(domainStatus.subList(5, 6)).build()); Map domainMap = domainProvider.getDomainMap(null); logger.info("Domain Map: {}", domainMap); - verify(mockClient).describeElasticsearchDomains(new DescribeElasticsearchDomainsRequest() - .withDomainNames(domainNames.subList(0, 5))); - verify(mockClient).describeElasticsearchDomains(new DescribeElasticsearchDomainsRequest() - .withDomainNames(domainNames.subList(5, 6))); + verify(mockClient).describeElasticsearchDomains(DescribeElasticsearchDomainsRequest.builder() + .domainNames(domainNames.subList(0, 5)).build()); + verify(mockClient).describeElasticsearchDomains(DescribeElasticsearchDomainsRequest.builder() + .domainNames(domainNames.subList(5, 6)).build()); assertEquals("Invalid number of domains.", domainNames.size(), domainMap.size()); domainNames.forEach(domainName -> { diff --git a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchMetadataHandlerTest.java b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchMetadataHandlerTest.java index ce3c50ce6f..cddfce30a4 100644 --- a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchMetadataHandlerTest.java +++ b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchMetadataHandlerTest.java @@ -28,9 +28,6 @@ import com.amazonaws.athena.connector.lambda.metadata.*; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -54,6 +51,9 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import 
java.io.IOException; import java.util.ArrayList; @@ -90,13 +90,13 @@ public class ElasticsearchMetadataHandlerTest private BlockAllocatorImpl allocator; @Mock - private AWSGlue awsGlue; + private GlueClient awsGlue; @Mock - private AWSSecretsManager awsSecretsManager; + private SecretsManagerClient awsSecretsManager; @Mock - private AmazonAthena amazonAthena; + private AthenaClient amazonAthena; @Mock private AwsRestHighLevelClient mockClient; diff --git a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java index f7db1f124c..1336badd71 100644 --- a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java +++ b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java @@ -37,13 +37,6 @@ import com.amazonaws.athena.connector.lambda.records.RecordResponse; import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -67,6 +60,16 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -116,19 +119,13 @@ public class ElasticsearchRecordHandlerTest private SearchResponse mockScrollResponse; @Mock - private AmazonS3 amazonS3; - - @Mock - private AWSSecretsManager awsSecretsManager; - - @Mock - private AmazonAthena athena; + private S3Client amazonS3; @Mock - PutObjectResult putObjectResult; + private SecretsManagerClient awsSecretsManager; @Mock - S3Object s3Object; + private AthenaClient athena; String[] expectedDocuments = {"[mytext : My favorite Sci-Fi movie is Interstellar.], [mykeyword : I love keywords.], [mylong : {11,12,13}], [myinteger : 666115], [myshort : 1972], [mybyte : 5], [mydouble : 47.5], [myscaled : 7], [myfloat : 5.6], [myhalf : 6.2], [mydatemilli : 2020-05-15T06:49:30], [mydatenano : {2020-05-15T06:50:01.457}], [myboolean : true], [mybinary : U29tZSBiaW5hcnkgYmxvYg==], [mynested : {[l1long : 357345987],[l1date : 2020-05-15T06:57:44.123],[l1nested : {[l2short : {1,2,3,4,5,6,7,8,9,10}],[l2binary : U29tZSBiaW5hcnkgYmxvYg==]}]}], [objlistouter : {}]" ,"[mytext : My favorite TV comedy is 
Seinfeld.], [mykeyword : I hate key-values.], [mylong : {14,null,16}], [myinteger : 732765666], [myshort : 1971], [mybyte : 7], [mydouble : 27.6], [myscaled : 10], [myfloat : 7.8], [myhalf : 7.3], [mydatemilli : null], [mydatenano : {2020-05-15T06:49:30.001}], [myboolean : false], [mybinary : U29tZSBiaW5hcnkgYmxvYg==], [mynested : {[l1long : 7322775555],[l1date : 2020-05-15T01:57:44.777],[l1nested : {[l2short : {11,12,13,14,15,16,null,18,19,20}],[l2binary : U29tZSBiaW5hcnkgYmxvYg==]}]}], [objlistouter : {{[objlistinner : {{[title : somebook],[hi : hi]}}],[test2 : title]}}]"}; @@ -276,31 +273,27 @@ public void setUp() allocator = new BlockAllocatorImpl(); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("putObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); spillReader = new S3BlockSpillReader(amazonS3, allocator); diff --git a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/integ/ElasticsearchIntegTest.java b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/integ/ElasticsearchIntegTest.java index bb5fef0426..d4d6a28885 100644 --- a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/integ/ElasticsearchIntegTest.java +++ b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/integ/ElasticsearchIntegTest.java @@ -21,7 +21,6 @@ import com.amazonaws.athena.connector.integ.IntegrationTestBase; import com.amazonaws.athena.connector.integ.clients.CloudFormationClient; -import com.amazonaws.services.athena.model.Row; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,6 +37,7 @@ import software.amazon.awscdk.services.iam.Effect; import software.amazon.awscdk.services.iam.PolicyDocument; import software.amazon.awscdk.services.iam.PolicyStatement; +import software.amazon.awssdk.services.athena.model.Row; import java.util.ArrayList; import java.util.List; @@ -267,13 +267,13 @@ public void selectColumnWithPredicateIntegTest() String query = String.format("select title from %s.%s.%s where year > 2000;", lambdaFunctionName, domainName, index); - List rows = 
startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List titles = new ArrayList<>(); - rows.forEach(row -> titles.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> titles.add(row.data().get(0).varCharValue())); logger.info("Titles: {}", titles); assertEquals("Wrong number of DB records found.", 1, titles.size()); assertTrue("Movie title not found: Interstellar.", titles.contains("Interstellar")); @@ -288,13 +288,13 @@ public void selectFloat8TypeTest() String query = String.format("select float8_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Double.valueOf(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(Double.valueOf(row.data().get(0).varCharValue()))); AssertJUnit.assertEquals("Wrong number of DB records found.", 1, values.size()); AssertJUnit.assertTrue("Float8 not found: " + 1e-32, values.contains(1e-32)); } diff --git a/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleMetadataHandler.java b/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleMetadataHandler.java index 21d2594da7..58f165e7e3 100644 --- a/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleMetadataHandler.java +++ b/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleMetadataHandler.java @@ -40,8 +40,6 @@ import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse; import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.complex.reader.FieldReader; //DO NOT REMOVE - this will not be _unused_ when customers go through the tutorial and uncomment @@ -49,6 +47,8 @@ import org.apache.arrow.vector.types.Types; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.Comparator; @@ -92,8 +92,8 @@ public ExampleMetadataHandler(java.util.Map configOptions) @VisibleForTesting protected ExampleMetadataHandler( EncryptionKeyFactory keyFactory, - AWSSecretsManager awsSecretsManager, - AmazonAthena athena, + SecretsManagerClient awsSecretsManager, + AthenaClient athena, String spillBucket, String spillPrefix, java.util.Map configOptions) diff --git a/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleRecordHandler.java b/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleRecordHandler.java index cc895cade2..402420e0bb 100644 --- a/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleRecordHandler.java +++ b/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleRecordHandler.java @@ -34,18 +34,18 @@ import com.amazonaws.athena.connector.lambda.domain.predicate.ConstraintProjector; import 
com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.holders.NullableIntHolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.NoSuchKeyException; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.BufferedReader; import java.io.IOException; @@ -77,15 +77,15 @@ public class ExampleRecordHandler */ private static final String SOURCE_TYPE = "example"; - private AmazonS3 amazonS3; + private S3Client amazonS3; public ExampleRecordHandler(java.util.Map configOptions) { - this(AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), configOptions); + this(S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), configOptions); } @VisibleForTesting - protected ExampleRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena amazonAthena, java.util.Map configOptions) + protected ExampleRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, java.util.Map configOptions) { super(amazonS3, secretsManager, amazonAthena, SOURCE_TYPE, configOptions); this.amazonS3 = amazonS3; @@ -230,10 +230,13 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor private BufferedReader openS3File(String bucket, String key) { logger.info("openS3File: opening file " + bucket + ":" + key); - if (amazonS3.doesObjectExist(bucket, key)) { - S3Object obj = amazonS3.getObject(bucket, key); + try { + ResponseInputStream responseStream = amazonS3.getObject(GetObjectRequest.builder().bucket(bucket).key(key).build()); logger.info("openS3File: opened file " + bucket + ":" + key); - return new BufferedReader(new InputStreamReader(obj.getObjectContent())); + return new BufferedReader(new InputStreamReader(responseStream)); + } + catch (NoSuchKeyException e) { + logger.error("openS3File: failed to open file " + bucket + ":" + key, e); } return null; } diff --git a/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleMetadataHandlerTest.java b/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleMetadataHandlerTest.java index 5f8ff32501..301ba9f167 100644 --- a/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleMetadataHandlerTest.java +++ b/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleMetadataHandlerTest.java @@ -43,8 +43,6 @@ import com.amazonaws.athena.connector.lambda.metadata.MetadataResponse; import 
com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -53,6 +51,8 @@ import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.Collections; @@ -72,8 +72,8 @@ public class ExampleMetadataHandlerTest private static final Logger logger = LoggerFactory.getLogger(ExampleMetadataHandlerTest.class); private ExampleMetadataHandler handler = new ExampleMetadataHandler(new LocalKeyFactory(), - mock(AWSSecretsManager.class), - mock(AmazonAthena.class), + mock(SecretsManagerClient.class), + mock(AthenaClient.class), "spill-bucket", "spill-prefix", com.google.common.collect.ImmutableMap.of()); diff --git a/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleRecordHandlerTest.java b/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleRecordHandlerTest.java index 2d597c4632..de2b30524b 100644 --- a/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleRecordHandlerTest.java +++ b/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleRecordHandlerTest.java @@ -33,11 +33,6 @@ import com.amazonaws.athena.connector.lambda.records.ReadRecordsResponse; import com.amazonaws.athena.connector.lambda.records.RecordResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.After; @@ -49,6 +44,12 @@ import org.mockito.stubbing.Answer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; import java.io.UnsupportedEncodingException; @@ -59,6 +60,7 @@ import static com.amazonaws.athena.connector.lambda.domain.predicate.Constraints.DEFAULT_NO_LIMIT; import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -72,9 +74,9 @@ public class ExampleRecordHandlerTest System.getenv("publishing").equalsIgnoreCase("true"); private BlockAllocatorImpl allocator; private Schema schemaForRead; - private AmazonS3 amazonS3; - private AWSSecretsManager awsSecretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient awsSecretsManager; + private AthenaClient athena; private S3BlockSpillReader 
spillReader; @Rule @@ -105,23 +107,18 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); - awsSecretsManager = mock(AWSSecretsManager.class); - athena = mock(AmazonAthena.class); + amazonS3 = mock(S3Client.class); + awsSecretsManager = mock(SecretsManagerClient.class); + athena = mock(AthenaClient.class); - when(amazonS3.doesObjectExist(nullable(String.class), nullable(String.class))).thenReturn(true); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - S3Object mockObject = mock(S3Object.class); - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(getFakeObject()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(getFakeObject())); } }); diff --git a/athena-federation-integ-test/pom.xml b/athena-federation-integ-test/pom.xml index f06fc32822..bee6d6be9d 100644 --- a/athena-federation-integ-test/pom.xml +++ b/athena-federation-integ-test/pom.xml @@ -11,60 +11,6 @@ jar Amazon Athena Query Federation Integ Test - - com.amazonaws - jmespath-java - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - - - com.amazonaws - aws-java-sdk-core - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - commons-cli commons-cli @@ -109,29 +55,28 @@ guava ${guava.version} + + software.amazon.awssdk + secretsmanager + ${aws-sdk-v2.version} + org.testng testng ${testng.version} - - - com.amazonaws - aws-java-sdk-athena - ${aws-sdk.version} - - + - com.amazonaws - aws-java-sdk-cloudformation - ${aws-sdk.version} + software.amazon.awssdk + athena + ${aws-sdk-v2.version} - + - com.amazonaws - aws-java-sdk-secretsmanager - ${aws-sdk.version} + software.amazon.awssdk + cloudformation + ${aws-sdk-v2.version} diff --git a/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/IntegrationTestBase.java b/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/IntegrationTestBase.java index bb2b7c5822..880debe5bc 100644 --- a/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/IntegrationTestBase.java +++ b/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/IntegrationTestBase.java @@ -25,18 +25,6 @@ import com.amazonaws.athena.connector.integ.data.TestConfig; import com.amazonaws.athena.connector.integ.providers.ConnectorVpcAttributesProvider; import com.amazonaws.athena.connector.integ.providers.SecretsManagerCredentialsProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.athena.model.Datum; -import com.amazonaws.services.athena.model.GetQueryExecutionRequest; -import com.amazonaws.services.athena.model.GetQueryExecutionResult; -import 
com.amazonaws.services.athena.model.GetQueryResultsRequest; -import com.amazonaws.services.athena.model.GetQueryResultsResult; -import com.amazonaws.services.athena.model.ListDatabasesRequest; -import com.amazonaws.services.athena.model.ListDatabasesResult; -import com.amazonaws.services.athena.model.ResultConfiguration; -import com.amazonaws.services.athena.model.Row; -import com.amazonaws.services.athena.model.StartQueryExecutionRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.annotations.AfterClass; @@ -44,6 +32,17 @@ import org.testng.annotations.Test; import software.amazon.awscdk.core.Stack; import software.amazon.awscdk.services.iam.PolicyDocument; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.athena.model.Datum; +import software.amazon.awssdk.services.athena.model.GetQueryExecutionRequest; +import software.amazon.awssdk.services.athena.model.GetQueryExecutionResponse; +import software.amazon.awssdk.services.athena.model.GetQueryResultsRequest; +import software.amazon.awssdk.services.athena.model.GetQueryResultsResponse; +import software.amazon.awssdk.services.athena.model.ListDatabasesRequest; +import software.amazon.awssdk.services.athena.model.ListDatabasesResponse; +import software.amazon.awssdk.services.athena.model.ResultConfiguration; +import software.amazon.awssdk.services.athena.model.Row; +import software.amazon.awssdk.services.athena.model.StartQueryExecutionRequest; import java.time.LocalDate; import java.time.LocalDateTime; @@ -94,7 +93,7 @@ public abstract class IntegrationTestBase private final ConnectorStackProvider connectorStackProvider; private final String lambdaFunctionName; - private final AmazonAthena athenaClient; + private final AthenaClient athenaClient; private final TestConfig testConfig; private final Optional vpcAttributes; private final Optional secretCredentials; @@ -128,7 +127,7 @@ protected void setSpecificResource(final Stack stack) }; lambdaFunctionName = connectorStackProvider.getLambdaFunctionName(); - athenaClient = AmazonAthenaClientBuilder.defaultClient(); + athenaClient = AthenaClient.create(); athenaWorkgroup = getAthenaWorkgroup(); athenaResultLocation = getAthenaResultLocation(); } @@ -262,14 +261,15 @@ protected void cleanUp() public List listDatabases() { logger.info("listDatabases({})", lambdaFunctionName); - ListDatabasesRequest listDatabasesRequest = new ListDatabasesRequest() - .withCatalogName(lambdaFunctionName); + ListDatabasesRequest listDatabasesRequest = ListDatabasesRequest.builder() + .catalogName(lambdaFunctionName) + .build(); - ListDatabasesResult listDatabasesResult = athenaClient.listDatabases(listDatabasesRequest); - logger.info("Results: [{}]", listDatabasesResult); + ListDatabasesResponse listDatabasesResponse = athenaClient.listDatabases(listDatabasesRequest); + logger.info("Results: [{}]", listDatabasesResponse); List dbNames = new ArrayList<>(); - listDatabasesResult.getDatabaseList().forEach(db -> dbNames.add(db.getName())); + listDatabasesResponse.databaseList().forEach(db -> dbNames.add(db.name())); return dbNames; } @@ -285,8 +285,8 @@ public List listTables(String databaseName) { String query = String.format("show tables in `%s`.`%s`;", lambdaFunctionName, databaseName); List tableNames = new ArrayList<>(); - startQueryExecution(query).getResultSet().getRows() - .forEach(row -> tableNames.add(row.getData().get(0).getVarCharValue())); + startQueryExecution(query).resultSet().rows() + .forEach(row -> 
tableNames.add(row.data().get(0).varCharValue())); return tableNames; } @@ -303,9 +303,9 @@ public Map describeTable(String databaseName, String tableName) { String query = String.format("describe `%s`.`%s`.`%s`;", lambdaFunctionName, databaseName, tableName); Map schema = new HashMap<>(); - startQueryExecution(query).getResultSet().getRows() + startQueryExecution(query).resultSet().rows() .forEach(row -> { - String property = row.getData().get(0).getVarCharValue(); + String property = row.data().get(0).varCharValue(); String[] columnProperties = property.split("\t"); if (columnProperties.length == 2) { schema.put(columnProperties[0], columnProperties[1]); @@ -321,21 +321,22 @@ public Map describeTable(String databaseName, String tableName) * @return The query results object containing the metadata and row information. * @throws RuntimeException The Query is cancelled or has failed. */ - public GetQueryResultsResult startQueryExecution(String query) + public GetQueryResultsResponse startQueryExecution(String query) throws RuntimeException { - StartQueryExecutionRequest startQueryExecutionRequest = new StartQueryExecutionRequest() - .withWorkGroup(athenaWorkgroup) - .withQueryString(query) - .withResultConfiguration(new ResultConfiguration().withOutputLocation(athenaResultLocation)); + StartQueryExecutionRequest startQueryExecutionRequest = StartQueryExecutionRequest.builder() + .workGroup(athenaWorkgroup) + .queryString(query) + .resultConfiguration(ResultConfiguration.builder().outputLocation(athenaResultLocation).build()) + .build(); String queryExecutionId = sendAthenaQuery(startQueryExecutionRequest); logger.info("Query: [{}], Query Id: [{}]", query, queryExecutionId); waitForAthenaQueryResults(queryExecutionId); - GetQueryResultsResult getQueryResultsResult = getAthenaQueryResults(queryExecutionId); + GetQueryResultsResponse getQueryResultsResponse = getAthenaQueryResults(queryExecutionId); //logger.info("Results: [{}]", getQueryResultsResult.toString()); - return getQueryResultsResult; + return getQueryResultsResponse; } /** @@ -345,7 +346,7 @@ public GetQueryResultsResult startQueryExecution(String query) */ private String sendAthenaQuery(StartQueryExecutionRequest startQueryExecutionRequest) { - return athenaClient.startQueryExecution(startQueryExecutionRequest).getQueryExecutionId(); + return athenaClient.startQueryExecution(startQueryExecutionRequest).queryExecutionId(); } /** @@ -357,12 +358,13 @@ private void waitForAthenaQueryResults(String queryExecutionId) throws RuntimeException { // Poll the state of the query request while it is queued or running - GetQueryExecutionRequest getQueryExecutionRequest = new GetQueryExecutionRequest() - .withQueryExecutionId(queryExecutionId); - GetQueryExecutionResult getQueryExecutionResult; + GetQueryExecutionRequest getQueryExecutionRequest = GetQueryExecutionRequest.builder() + .queryExecutionId(queryExecutionId) + .build(); + GetQueryExecutionResponse getQueryExecutionResponse; while (true) { - getQueryExecutionResult = athenaClient.getQueryExecution(getQueryExecutionRequest); - String queryState = getQueryExecutionResult.getQueryExecution().getStatus().getState(); + getQueryExecutionResponse = athenaClient.getQueryExecution(getQueryExecutionRequest); + String queryState = getQueryExecutionResponse.queryExecution().status().state().toString(); logger.info("Query State: {}", queryState); if (queryState.equals(ATHENA_QUERY_QUEUED_STATE) || queryState.equals(ATHENA_QUERY_RUNNING_STATE)) { try { @@ -374,8 +376,8 @@ private void 
waitForAthenaQueryResults(String queryExecutionId) } } else if (queryState.equals(ATHENA_QUERY_FAILED_STATE) || queryState.equals(ATHENA_QUERY_CANCELLED_STATE)) { - throw new RuntimeException(getQueryExecutionResult - .getQueryExecution().getStatus().getStateChangeReason()); + throw new RuntimeException(getQueryExecutionResponse + .queryExecution().status().stateChangeReason()); } break; } @@ -386,11 +388,12 @@ else if (queryState.equals(ATHENA_QUERY_FAILED_STATE) || queryState.equals(ATHEN * @param queryExecutionId The query's Id. * @return The query results object containing the metadata and row information. */ - private GetQueryResultsResult getAthenaQueryResults(String queryExecutionId) + private GetQueryResultsResponse getAthenaQueryResults(String queryExecutionId) { // Get query results - GetQueryResultsRequest getQueryResultsRequest = new GetQueryResultsRequest() - .withQueryExecutionId(queryExecutionId); + GetQueryResultsRequest getQueryResultsRequest = GetQueryResultsRequest.builder() + .queryExecutionId(queryExecutionId) + .build(); return athenaClient.getQueryResults(getQueryResultsRequest); } @@ -472,8 +475,8 @@ public float calculateThroughput(String lambdaFnName, String schemaName, String public List processQuery(String query) { List firstColValues = new ArrayList<>(); - skipColumnHeaderRow(startQueryExecution(query).getResultSet().getRows()) - .forEach(row -> firstColValues.add(row.getData().get(0).getVarCharValue())); + skipColumnHeaderRow(startQueryExecution(query).resultSet().rows()) + .forEach(row -> firstColValues.add(row.data().get(0).varCharValue())); return firstColValues; } public List skipColumnHeaderRow(List rows) @@ -493,13 +496,13 @@ public void selectIntegerTypeTest() String query = String.format("select int_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Integer.parseInt(row.getData().get(0).getVarCharValue().split("\\.")[0]))); + rows.forEach(row -> values.add(Integer.parseInt(row.data().get(0).varCharValue().split("\\.")[0]))); logger.info("Titles: {}", values); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Integer not found: " + TEST_DATATYPES_INT_VALUE, values.contains(TEST_DATATYPES_INT_VALUE)); @@ -514,13 +517,13 @@ public void selectVarcharTypeTest() String query = String.format("select varchar_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> values.add(row.data().get(0).varCharValue())); logger.info("Titles: {}", values); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Varchar not found: " + TEST_DATATYPES_VARCHAR_VALUE, values.contains(TEST_DATATYPES_VARCHAR_VALUE)); @@ -535,13 +538,13 @@ public void selectBooleanTypeTest() String query = String.format("select boolean_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = 
startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Boolean.valueOf(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(Boolean.valueOf(row.data().get(0).varCharValue()))); logger.info("Titles: {}", values); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Boolean not found: " + TEST_DATATYPES_BOOLEAN_VALUE, values.contains(TEST_DATATYPES_BOOLEAN_VALUE)); @@ -556,13 +559,13 @@ public void selectSmallintTypeTest() String query = String.format("select smallint_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Short.valueOf(row.getData().get(0).getVarCharValue().split("\\.")[0]))); + rows.forEach(row -> values.add(Short.valueOf(row.data().get(0).varCharValue().split("\\.")[0]))); logger.info("Titles: {}", values); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Smallint not found: " + TEST_DATATYPES_SHORT_VALUE, values.contains(TEST_DATATYPES_SHORT_VALUE)); @@ -577,13 +580,13 @@ public void selectBigintTypeTest() String query = String.format("select bigint_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Long.valueOf(row.getData().get(0).getVarCharValue().split("\\.")[0]))); + rows.forEach(row -> values.add(Long.valueOf(row.data().get(0).varCharValue().split("\\.")[0]))); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Bigint not found: " + TEST_DATATYPES_LONG_VALUE, values.contains(TEST_DATATYPES_LONG_VALUE)); } @@ -597,13 +600,13 @@ public void selectFloat4TypeTest() String query = String.format("select float4_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Float.valueOf(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(Float.valueOf(row.data().get(0).varCharValue()))); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Float4 not found: " + TEST_DATATYPES_SINGLE_PRECISION_VALUE, values.contains(TEST_DATATYPES_SINGLE_PRECISION_VALUE)); } @@ -617,13 +620,13 @@ public void selectFloat8TypeTest() String query = String.format("select float8_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> 
values.add(Double.valueOf(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(Double.valueOf(row.data().get(0).varCharValue()))); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Float8 not found: " + TEST_DATATYPES_DOUBLE_PRECISION_VALUE, values.contains(TEST_DATATYPES_DOUBLE_PRECISION_VALUE)); } @@ -637,13 +640,13 @@ public void selectDateTypeTest() String query = String.format("select date_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(LocalDate.parse(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(LocalDate.parse(row.data().get(0).varCharValue()))); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Date not found: " + TEST_DATATYPES_DATE_VALUE, values.contains(LocalDate.parse(TEST_DATATYPES_DATE_VALUE))); } @@ -657,15 +660,15 @@ public void selectTimestampTypeTest() String query = String.format("select timestamp_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); // for some reason, timestamps lose their 'T'. - rows.forEach(row -> values.add(LocalDateTime.parse(row.getData().get(0).getVarCharValue().replace(' ', 'T')))); - logger.info(rows.get(0).getData().get(0).getVarCharValue()); + rows.forEach(row -> values.add(LocalDateTime.parse(row.data().get(0).varCharValue().replace(' ', 'T')))); + logger.info(rows.get(0).data().get(0).varCharValue()); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Date not found: " + TEST_DATATYPES_TIMESTAMP_VALUE, values.contains(LocalDateTime.parse(TEST_DATATYPES_TIMESTAMP_VALUE))); } @@ -679,20 +682,19 @@ public void selectByteArrayTypeTest() String query = String.format("select byte_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(row.getData().get(0).getVarCharValue())); - Datum actual = rows.get(0).getData().get(0); - Datum expected = new Datum(); - expected.setVarCharValue("deadbeef"); - logger.info(rows.get(0).getData().get(0).getVarCharValue()); + rows.forEach(row -> values.add(row.data().get(0).varCharValue())); + Datum actual = rows.get(0).data().get(0); + Datum expected = Datum.builder().varCharValue("deadbeef").build(); + logger.info(rows.get(0).data().get(0).varCharValue()); assertEquals("Wrong number of DB records found.", 1, values.size()); - String bytestring = actual.getVarCharValue().replace(" ", ""); - assertEquals("Byte[] not found: " + Arrays.toString(TEST_DATATYPES_BYTE_ARRAY_VALUE), expected.getVarCharValue(), bytestring); + String bytestring = actual.varCharValue().replace(" ", ""); + assertEquals("Byte[] not found: " + Arrays.toString(TEST_DATATYPES_BYTE_ARRAY_VALUE), expected.varCharValue(), 
bytestring); } @Test @@ -704,17 +706,16 @@ public void selectVarcharListTypeTest() String query = String.format("select textarray_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(row.getData().get(0).getVarCharValue())); - Datum actual = rows.get(0).getData().get(0); - Datum expected = new Datum(); - expected.setVarCharValue(TEST_DATATYPES_VARCHAR_ARRAY_VALUE); - logger.info(rows.get(0).getData().get(0).getVarCharValue()); + rows.forEach(row -> values.add(row.data().get(0).varCharValue())); + Datum actual = rows.get(0).data().get(0); + Datum expected = Datum.builder().varCharValue(TEST_DATATYPES_VARCHAR_ARRAY_VALUE).build(); + logger.info(rows.get(0).data().get(0).varCharValue()); assertEquals("Wrong number of DB records found.", 1, values.size()); assertEquals("List not found: " + TEST_DATATYPES_VARCHAR_ARRAY_VALUE, expected, actual); } @@ -728,13 +729,13 @@ public void selectNullValueTest() String query = String.format("select int_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_NULL_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } - Datum actual = rows.get(0).getData().get(0); - assertNull("Value not 'null'. Received: " + actual.getVarCharValue(), actual.getVarCharValue()); + Datum actual = rows.get(0).data().get(0); + assertNull("Value not 'null'. Received: " + actual.varCharValue(), actual.varCharValue()); } @Test @@ -746,7 +747,7 @@ public void selectEmptyTableTest() String query = String.format("select int_type from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_EMPTY_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); diff --git a/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/clients/CloudFormationClient.java b/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/clients/CloudFormationClient.java index af5a92f9b7..37b290f0ad 100644 --- a/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/clients/CloudFormationClient.java +++ b/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/clients/CloudFormationClient.java @@ -19,15 +19,6 @@ */ package com.amazonaws.athena.connector.integ.clients; -import com.amazonaws.services.cloudformation.AmazonCloudFormation; -import com.amazonaws.services.cloudformation.AmazonCloudFormationClientBuilder; -import com.amazonaws.services.cloudformation.model.Capability; -import com.amazonaws.services.cloudformation.model.CreateStackRequest; -import com.amazonaws.services.cloudformation.model.CreateStackResult; -import com.amazonaws.services.cloudformation.model.DeleteStackRequest; -import com.amazonaws.services.cloudformation.model.DescribeStackEventsRequest; -import com.amazonaws.services.cloudformation.model.DescribeStackEventsResult; -import com.amazonaws.services.cloudformation.model.StackEvent; import com.fasterxml.jackson.databind.ObjectMapper; import 
com.fasterxml.jackson.databind.SerializationFeature; import org.slf4j.Logger; @@ -35,6 +26,14 @@ import org.testng.internal.collections.Pair; import software.amazon.awscdk.core.App; import software.amazon.awscdk.core.Stack; +import software.amazon.awssdk.services.cloudformation.model.Capability; +import software.amazon.awssdk.services.cloudformation.model.CreateStackRequest; +import software.amazon.awssdk.services.cloudformation.model.CreateStackResponse; +import software.amazon.awssdk.services.cloudformation.model.DeleteStackRequest; +import software.amazon.awssdk.services.cloudformation.model.DescribeStackEventsRequest; +import software.amazon.awssdk.services.cloudformation.model.DescribeStackEventsResponse; +import software.amazon.awssdk.services.cloudformation.model.ResourceStatus; +import software.amazon.awssdk.services.cloudformation.model.StackEvent; import java.util.List; @@ -46,13 +45,11 @@ public class CloudFormationClient { private static final Logger logger = LoggerFactory.getLogger(CloudFormationClient.class); - private static final String CF_CREATE_RESOURCE_IN_PROGRESS_STATUS = "CREATE_IN_PROGRESS"; - private static final String CF_CREATE_RESOURCE_FAILED_STATUS = "CREATE_FAILED"; private static final long sleepTimeMillis = 5000L; private final String stackName; private final String stackTemplate; - private final AmazonCloudFormation cloudFormationClient; + private final software.amazon.awssdk.services.cloudformation.CloudFormationClient cloudFormationClient; public CloudFormationClient(Pair stackPair) { @@ -66,7 +63,7 @@ public CloudFormationClient(App theApp, Stack theStack) stackTemplate = objectMapper .valueToTree(theApp.synth().getStackArtifact(theStack.getArtifactId()).getTemplate()) .toPrettyString(); - this.cloudFormationClient = AmazonCloudFormationClientBuilder.defaultClient(); + this.cloudFormationClient = software.amazon.awssdk.services.cloudformation.CloudFormationClient.create(); } /** @@ -81,11 +78,12 @@ public void createStack() logger.info("------------------------------------------------------"); // logger.info(stackTemplate); - CreateStackRequest createStackRequest = new CreateStackRequest() - .withStackName(stackName) - .withTemplateBody(stackTemplate) - .withDisableRollback(true) - .withCapabilities(Capability.CAPABILITY_NAMED_IAM); + CreateStackRequest createStackRequest = CreateStackRequest.builder() + .stackName(stackName) + .templateBody(stackTemplate) + .disableRollback(true) + .capabilities(Capability.CAPABILITY_NAMED_IAM) + .build(); processCreateStackRequest(createStackRequest); } @@ -98,22 +96,23 @@ private void processCreateStackRequest(CreateStackRequest createStackRequest) throws RuntimeException { // Create CloudFormation stack. 
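The hunk below keeps the original hand-rolled polling loop, now comparing the typed ResourceStatus enum instead of the old string constants. For reference, SDK v2 also generates waiters that can stand in for such loops; a minimal sketch of that alternative, which this patch does not adopt (the class and stack-name argument are placeholders):

import software.amazon.awssdk.services.cloudformation.CloudFormationClient;
import software.amazon.awssdk.services.cloudformation.model.DescribeStacksRequest;
import software.amazon.awssdk.services.cloudformation.waiters.CloudFormationWaiter;

public final class StackCreationWaiterSketch
{
    private StackCreationWaiterSketch() {}

    public static void awaitCreation(String stackName)
    {
        try (CloudFormationClient cfn = CloudFormationClient.create();
                CloudFormationWaiter waiter = cfn.waiter()) {
            // Blocks until the stack reaches CREATE_COMPLETE; throws if creation fails or times out.
            waiter.waitUntilStackCreateComplete(
                    DescribeStacksRequest.builder().stackName(stackName).build());
        }
    }
}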
- CreateStackResult result = cloudFormationClient.createStack(createStackRequest); - logger.info("Stack ID: {}", result.getStackId()); + CreateStackResponse response = cloudFormationClient.createStack(createStackRequest); + logger.info("Stack ID: {}", response.stackId()); - DescribeStackEventsRequest describeStackEventsRequest = new DescribeStackEventsRequest() - .withStackName(createStackRequest.getStackName()); - DescribeStackEventsResult describeStackEventsResult; + DescribeStackEventsRequest describeStackEventsRequest = DescribeStackEventsRequest.builder() + .stackName(createStackRequest.stackName()) + .build(); + DescribeStackEventsResponse describeStackEventsResponse; // Poll status of stack until stack has been created or creation has failed while (true) { - describeStackEventsResult = cloudFormationClient.describeStackEvents(describeStackEventsRequest); - StackEvent event = describeStackEventsResult.getStackEvents().get(0); - String resourceId = event.getLogicalResourceId(); - String resourceStatus = event.getResourceStatus(); + describeStackEventsResponse = cloudFormationClient.describeStackEvents(describeStackEventsRequest); + StackEvent event = describeStackEventsResponse.stackEvents().get(0); + String resourceId = event.logicalResourceId(); + ResourceStatus resourceStatus = event.resourceStatus(); logger.info("Resource Id: {}, Resource status: {}", resourceId, resourceStatus); - if (!resourceId.equals(event.getStackName()) || - resourceStatus.equals(CF_CREATE_RESOURCE_IN_PROGRESS_STATUS)) { + if (!resourceId.equals(event.stackName()) || + resourceStatus.equals(ResourceStatus.CREATE_IN_PROGRESS)) { try { Thread.sleep(sleepTimeMillis); continue; @@ -122,8 +121,8 @@ private void processCreateStackRequest(CreateStackRequest createStackRequest) throw new RuntimeException("Thread.sleep interrupted: " + e.getMessage(), e); } } - else if (resourceStatus.equals(CF_CREATE_RESOURCE_FAILED_STATUS)) { - throw new RuntimeException(getCloudFormationErrorReasons(describeStackEventsResult.getStackEvents())); + else if (resourceStatus.equals(ResourceStatus.CREATE_FAILED)) { + throw new RuntimeException(getCloudFormationErrorReasons(describeStackEventsResponse.stackEvents())); } break; } @@ -140,9 +139,9 @@ private String getCloudFormationErrorReasons(List stackEvents) new StringBuilder("CloudFormation stack creation failed due to the following reason(s):\n"); stackEvents.forEach(stackEvent -> { - if (stackEvent.getResourceStatus().equals(CF_CREATE_RESOURCE_FAILED_STATUS)) { + if (stackEvent.resourceStatus().equals(ResourceStatus.CREATE_FAILED)) { String errorMessage = String.format("Resource: %s, Reason: %s\n", - stackEvent.getLogicalResourceId(), stackEvent.getResourceStatusReason()); + stackEvent.logicalResourceId(), stackEvent.resourceStatusReason()); errorMessageBuilder.append(errorMessage); } }); @@ -160,14 +159,14 @@ public void deleteStack() logger.info("------------------------------------------------------"); try { - DeleteStackRequest request = new DeleteStackRequest().withStackName(stackName); + DeleteStackRequest request = DeleteStackRequest.builder().stackName(stackName).build(); cloudFormationClient.deleteStack(request); } catch (Exception e) { logger.error("Something went wrong... 
Manual resource cleanup may be needed!!!", e); } finally { - cloudFormationClient.shutdown(); + cloudFormationClient.close(); } } } diff --git a/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/providers/SecretsManagerCredentialsProvider.java b/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/providers/SecretsManagerCredentialsProvider.java index cabc1afa66..c2b5f4af13 100644 --- a/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/providers/SecretsManagerCredentialsProvider.java +++ b/athena-federation-integ-test/src/main/java/com/amazonaws/athena/connector/integ/providers/SecretsManagerCredentialsProvider.java @@ -21,11 +21,10 @@ import com.amazonaws.athena.connector.integ.data.SecretsManagerCredentials; import com.amazonaws.athena.connector.integ.data.TestConfig; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.io.IOException; import java.util.HashMap; @@ -55,23 +54,21 @@ public static Optional getCredentials(TestConfig test if (secretsManagerSecret.isPresent()) { String secret = secretsManagerSecret.get(); - AWSSecretsManager secretsManager = AWSSecretsManagerClientBuilder.defaultClient(); + SecretsManagerClient secretsManager = SecretsManagerClient.create(); try { - GetSecretValueResult secretValueResult = secretsManager.getSecretValue(new GetSecretValueRequest() - .withSecretId(secret)); + GetSecretValueResponse secretValueResult = secretsManager.getSecretValue(GetSecretValueRequest.builder() + .secretId(secret) + .build()); ObjectMapper objectMapper = new ObjectMapper(); - Map credentials = objectMapper.readValue(secretValueResult.getSecretString(), + Map credentials = objectMapper.readValue(secretValueResult.secretString(), HashMap.class); return Optional.of(new SecretsManagerCredentials(secret, credentials.get("username"), - credentials.get("password"), secretValueResult.getARN())); + credentials.get("password"), secretValueResult.arn())); } catch (IOException e) { throw new RuntimeException(String.format("Unable to parse SecretsManager secret (%s): %s", secret, e.getMessage()), e); } - finally { - secretsManager.shutdown(); - } } return Optional.empty(); diff --git a/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/FederationService.java b/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/FederationService.java deleted file mode 100644 index d15468e50b..0000000000 --- a/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/FederationService.java +++ /dev/null @@ -1,30 +0,0 @@ -/*- - * #%L - * Amazon Athena Query Federation SDK Tools - * %% - * Copyright (C) 2019 - 2020 Amazon Web Services - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
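Note that the SecretsManager hunk above also removes the finally { secretsManager.shutdown(); } block. SDK v2 clients implement SdkAutoCloseable, so if the per-call client should still be released deterministically, a try-with-resources form would restore that behavior (a sketch under that assumption; the secret id is a placeholder):

    try (SecretsManagerClient secretsManager = SecretsManagerClient.create()) {
        GetSecretValueResponse secretValue = secretsManager.getSecretValue(
                GetSecretValueRequest.builder().secretId("my-secret").build());
        // parse secretValue.secretString() as before
    } // close() releases the client's underlying HTTP connection pool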
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package com.amazonaws.athena.connector.validation; - -import com.amazonaws.athena.connector.lambda.request.FederationRequest; -import com.amazonaws.athena.connector.lambda.request.FederationResponse; -import com.amazonaws.services.lambda.invoke.LambdaFunction; - -public interface FederationService -{ - @LambdaFunction - FederationResponse call(final FederationRequest request); -} diff --git a/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/FederationServiceProvider.java b/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/FederationServiceProvider.java index 4306468cc1..4f3628a68f 100644 --- a/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/FederationServiceProvider.java +++ b/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/FederationServiceProvider.java @@ -19,19 +19,22 @@ */ package com.amazonaws.athena.connector.validation; +import com.amazonaws.athena.connector.lambda.request.FederationRequest; +import com.amazonaws.athena.connector.lambda.request.FederationResponse; import com.amazonaws.athena.connector.lambda.request.PingRequest; import com.amazonaws.athena.connector.lambda.request.PingResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory; -import com.amazonaws.services.lambda.AWSLambdaClientBuilder; -import com.amazonaws.services.lambda.invoke.LambdaFunction; -import com.amazonaws.services.lambda.invoke.LambdaFunctionNameResolver; -import com.amazonaws.services.lambda.invoke.LambdaInvokerFactory; -import com.amazonaws.services.lambda.invoke.LambdaInvokerFactoryConfig; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.lambda.LambdaClient; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import java.lang.reflect.Method; +import java.io.IOException; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -45,60 +48,58 @@ public class FederationServiceProvider private static final String VALIDATION_SUFFIX = "_validation"; - private static final Map serviceCache = new ConcurrentHashMap<>(); + private static final Map serdeVersionCache = new ConcurrentHashMap<>(); + + private static final LambdaClient lambdaClient = LambdaClient.create(); private FederationServiceProvider() { // Intentionally left blank. 
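The serdeVersionCache introduced above memoizes the negotiated SerDe version per function name. Because it is a ConcurrentHashMap, the containsKey/get/put sequence in callService below could also be collapsed into one atomic lookup; a possible condensation (sketch only, reusing the names defined in this file):

    int serDeVersion = serdeVersionCache.computeIfAbsent(lambdaFunction, name -> {
        ObjectMapper mapper = VersionedObjectMapperFactory.create(BLOCK_ALLOCATOR);
        PingResponse ping = invokeFunction(name,
                new PingRequest(identity, catalog, generateQueryId()),
                PingResponse.class, mapper);
        return ping.getSerDeVersion();
    });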
} - public static FederationService getService(String lambdaFunction, FederatedIdentity identity, String catalog) + private static R invokeFunction(String lambdaFunction, T request, Class responseClass, ObjectMapper objectMapper) { - FederationService service = serviceCache.get(lambdaFunction); - if (service != null) { - return service; + String payload; + try { + payload = objectMapper.writeValueAsString(request); + } + catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize request object", e); } - service = LambdaInvokerFactory.builder() - .lambdaClient(AWSLambdaClientBuilder.defaultClient()) - .objectMapper(VersionedObjectMapperFactory.create(BLOCK_ALLOCATOR)) - .lambdaFunctionNameResolver(new Mapper(lambdaFunction)) - .build(FederationService.class); - - PingRequest pingRequest = new PingRequest(identity, catalog, generateQueryId()); - PingResponse pingResponse = (PingResponse) service.call(pingRequest); + InvokeRequest invokeRequest = InvokeRequest.builder() + .functionName(lambdaFunction) + .payload(SdkBytes.fromUtf8String(payload)) + .build(); - int actualSerDeVersion = pingResponse.getSerDeVersion(); - log.info("SerDe version for function {}, catalog {} is {}", lambdaFunction, catalog, actualSerDeVersion); + InvokeResponse invokeResponse = lambdaClient.invoke(invokeRequest); - if (actualSerDeVersion != SERDE_VERSION) { - service = LambdaInvokerFactory.builder() - .lambdaClient(AWSLambdaClientBuilder.defaultClient()) - .objectMapper(VersionedObjectMapperFactory.create(BLOCK_ALLOCATOR, actualSerDeVersion)) - .lambdaFunctionNameResolver(new Mapper(lambdaFunction)) - .build(FederationService.class); + String response = invokeResponse.payload().asUtf8String(); + try { + return objectMapper.readValue(response, responseClass); + } + catch (IOException e) { + throw new RuntimeException("Failed to deserialize response payload", e); } - - serviceCache.put(lambdaFunction, service); - return service; } - public static final class Mapper - implements LambdaFunctionNameResolver + public static FederationResponse callService(String lambdaFunction, FederatedIdentity identity, String catalog, FederationRequest request) { - private final String function; - - private Mapper(String function) - { - this.function = function; + int serDeVersion = SERDE_VERSION; + if (serdeVersionCache.containsKey(lambdaFunction)) { + serDeVersion = serdeVersionCache.get(lambdaFunction); } - - @Override - public String getFunctionName(Method method, LambdaFunction lambdaFunction, - LambdaInvokerFactoryConfig lambdaInvokerFactoryConfig) - { - return function; + else { + ObjectMapper objectMapper = VersionedObjectMapperFactory.create(BLOCK_ALLOCATOR); + PingRequest pingRequest = new PingRequest(identity, catalog, generateQueryId()); + PingResponse pingResponse = invokeFunction(lambdaFunction, pingRequest, PingResponse.class, objectMapper); + + int actualSerDeVersion = pingResponse.getSerDeVersion(); + log.info("SerDe version for function {}, catalog {} is {}", lambdaFunction, catalog, actualSerDeVersion); + serdeVersionCache.put(lambdaFunction, actualSerDeVersion); } + + return invokeFunction(lambdaFunction, request, FederationResponse.class, VersionedObjectMapperFactory.create(BLOCK_ALLOCATOR, serDeVersion)); } public static String generateQueryId() diff --git a/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/LambdaMetadataProvider.java b/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/LambdaMetadataProvider.java index 
d3b2a93614..c603ddb463 100644 --- a/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/LambdaMetadataProvider.java +++ b/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/LambdaMetadataProvider.java @@ -42,8 +42,8 @@ import java.util.Set; import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE; +import static com.amazonaws.athena.connector.validation.FederationServiceProvider.callService; import static com.amazonaws.athena.connector.validation.FederationServiceProvider.generateQueryId; -import static com.amazonaws.athena.connector.validation.FederationServiceProvider.getService; /** * This class offers multiple convenience methods to retrieve metadata from a deployed Lambda. @@ -75,7 +75,7 @@ public static ListSchemasResponse listSchemas(String catalog, try (ListSchemasRequest request = new ListSchemasRequest(identity, queryId, catalog)) { log.info("Submitting request: {}", request); - ListSchemasResponse response = (ListSchemasResponse) getService(metadataFunction, identity, catalog).call(request); + ListSchemasResponse response = (ListSchemasResponse) callService(metadataFunction, identity, catalog, request); log.info("Received response: {}", response); return response; } @@ -107,7 +107,7 @@ public static ListTablesResponse listTables(String catalog, try (ListTablesRequest request = new ListTablesRequest(identity, queryId, catalog, schema, null, UNLIMITED_PAGE_SIZE_VALUE)) { log.info("Submitting request: {}", request); - ListTablesResponse response = (ListTablesResponse) getService(metadataFunction, identity, catalog).call(request); + ListTablesResponse response = (ListTablesResponse) callService(metadataFunction, identity, catalog, request); log.info("Received response: {}", response); return response; } @@ -136,7 +136,7 @@ public static GetTableResponse getTable(String catalog, try (GetTableRequest request = new GetTableRequest(identity, queryId, catalog, tableName, Collections.emptyMap())) { log.info("Submitting request: {}", request); - GetTableResponse response = (GetTableResponse) getService(metadataFunction, identity, catalog).call(request); + GetTableResponse response = (GetTableResponse) callService(metadataFunction, identity, catalog, request); log.info("Received response: {}", response); return response; } @@ -171,7 +171,7 @@ public static GetTableLayoutResponse getTableLayout(String catalog, try (GetTableLayoutRequest request = new GetTableLayoutRequest(identity, queryId, catalog, tableName, constraints, schema, partitionCols)) { log.info("Submitting request: {}", request); - GetTableLayoutResponse response = (GetTableLayoutResponse) getService(metadataFunction, identity, catalog).call(request); + GetTableLayoutResponse response = (GetTableLayoutResponse) callService(metadataFunction, identity, catalog, request); log.info("Received response: {}", response); return response; } @@ -208,7 +208,7 @@ public static GetSplitsResponse getSplits(String catalog, try (GetSplitsRequest request = new GetSplitsRequest(identity, queryId, catalog, tableName, partitions, partitionCols, constraints, contToken)) { log.info("Submitting request: {}", request); - GetSplitsResponse response = (GetSplitsResponse) getService(metadataFunction, identity, catalog).call(request); + GetSplitsResponse response = (GetSplitsResponse) callService(metadataFunction, identity, catalog, request); log.info("Received response: {}", response); return response; } diff --git 
a/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/LambdaRecordProvider.java b/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/LambdaRecordProvider.java index 87d301931e..296de3a854 100644 --- a/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/LambdaRecordProvider.java +++ b/athena-federation-sdk-tools/src/main/java/com/amazonaws/athena/connector/validation/LambdaRecordProvider.java @@ -29,8 +29,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static com.amazonaws.athena.connector.validation.FederationServiceProvider.callService; import static com.amazonaws.athena.connector.validation.FederationServiceProvider.generateQueryId; -import static com.amazonaws.athena.connector.validation.FederationServiceProvider.getService; /** * This class offers a convenience method to retrieve records from a deployed Lambda. @@ -81,7 +81,7 @@ public static ReadRecordsResponse readRecords(String catalog, MAX_BLOCK_SIZE, MAX_INLINE_BLOCK_SIZE)) { log.info("Submitting request: {}", request); - ReadRecordsResponse response = (ReadRecordsResponse) getService(recordFunction, identity, catalog).call(request); + ReadRecordsResponse response = (ReadRecordsResponse) callService(recordFunction, identity, catalog, request); log.info("Received response: {}", response); return response; } diff --git a/athena-federation-sdk/pom.xml b/athena-federation-sdk/pom.xml index e20494e405..75269ede6b 100644 --- a/athena-federation-sdk/pom.xml +++ b/athena-federation-sdk/pom.xml @@ -27,75 +27,85 @@ - com.amazonaws - jmespath-java - ${aws-sdk.version} + software.amazon.awssdk + apache-client + ${aws-sdk-v2.version} + + + software.amazon.awssdk + athena + ${aws-sdk-v2.version} - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations + software.amazon.awssdk + netty-nio-client - com.amazonaws - aws-java-sdk-core - ${aws-sdk.version} + software.amazon.awssdk + glue + ${aws-sdk-v2.version} - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor + software.amazon.awssdk + netty-nio-client + + + + software.amazon.awssdk + kms + ${aws-sdk-v2.version} + - com.fasterxml.jackson.core - jackson-core + software.amazon.awssdk + netty-nio-client + + + + software.amazon.awssdk + lambda + ${aws-sdk-v2.version} + - com.fasterxml.jackson.core - jackson-databind + software.amazon.awssdk + netty-nio-client + + + + software.amazon.awssdk + s3 + ${aws-sdk-v2.version} + - com.fasterxml.jackson.core - jackson-annotations + software.amazon.awssdk + netty-nio-client - com.amazonaws - aws-java-sdk-secretsmanager - ${aws-sdk.version} + software.amazon.awssdk + secretsmanager + ${aws-sdk-v2.version} commons-logging commons-logging + + software.amazon.awssdk + netty-nio-client + - com.amazonaws - aws-java-sdk-sts - ${aws-sdk.version} + software.amazon.awssdk + sts + ${aws-sdk-v2.version} com.fasterxml.jackson.datatype @@ -117,18 +127,12 @@ com.fasterxml.jackson.core jackson-annotations + + software.amazon.awssdk + netty-nio-client + - - com.amazonaws - aws-java-sdk-glue - ${aws-sdk.version} - - - com.amazonaws - aws-java-sdk-athena - ${aws-sdk.version} - org.apache.arrow arrow-vector @@ -171,21 +175,6 @@ aws-lambda-java-core 1.2.3 - - com.amazonaws - 
aws-java-sdk-lambda - ${aws-sdk.version} - - - com.amazonaws - aws-java-sdk-s3 - ${aws-sdk.version} - - - com.amazonaws - aws-java-sdk-kms - ${aws-sdk.version} - com.google.guava guava diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/credentials/CrossAccountCredentialsProvider.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/credentials/CrossAccountCredentialsProvider.java index 5ed2890fb3..309082d1d5 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/credentials/CrossAccountCredentialsProvider.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/credentials/CrossAccountCredentialsProvider.java @@ -19,17 +19,16 @@ */ package com.amazonaws.athena.connector.credentials; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicSessionCredentials; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.securitytoken.AWSSecurityTokenService; -import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceAsyncClientBuilder; -import com.amazonaws.services.securitytoken.model.AssumeRoleRequest; -import com.amazonaws.services.securitytoken.model.AssumeRoleResult; -import com.amazonaws.services.securitytoken.model.Credentials; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; +import software.amazon.awssdk.services.sts.model.AssumeRoleResponse; +import software.amazon.awssdk.services.sts.model.Credentials; import java.util.Map; @@ -40,19 +39,20 @@ public class CrossAccountCredentialsProvider private CrossAccountCredentialsProvider() {} - public static AWSCredentialsProvider getCrossAccountCredentialsIfPresent(Map configOptions, String roleSessionName) + public static AwsCredentialsProvider getCrossAccountCredentialsIfPresent(Map configOptions, String roleSessionName) { if (configOptions.containsKey(CROSS_ACCOUNT_ROLE_ARN_CONFIG)) { logger.debug("Found cross-account role arn to assume."); - AWSSecurityTokenService stsClient = AWSSecurityTokenServiceAsyncClientBuilder.standard().build(); - AssumeRoleRequest assumeRoleRequest = new AssumeRoleRequest() - .withRoleArn(configOptions.get(CROSS_ACCOUNT_ROLE_ARN_CONFIG)) - .withRoleSessionName(roleSessionName); - AssumeRoleResult assumeRoleResult = stsClient.assumeRole(assumeRoleRequest); - Credentials credentials = assumeRoleResult.getCredentials(); - BasicSessionCredentials basicSessionCredentials = new BasicSessionCredentials(credentials.getAccessKeyId(), credentials.getSecretAccessKey(), credentials.getSessionToken()); - return new AWSStaticCredentialsProvider(basicSessionCredentials); + StsClient stsClient = StsClient.create(); + AssumeRoleRequest assumeRoleRequest = AssumeRoleRequest.builder() + .roleArn(configOptions.get(CROSS_ACCOUNT_ROLE_ARN_CONFIG)) + .roleSessionName(roleSessionName) + .build(); + AssumeRoleResponse assumeRoleResponse = stsClient.assumeRole(assumeRoleRequest); + Credentials credentials = assumeRoleResponse.credentials(); + AwsSessionCredentials awsSessionCredentials = 
AwsSessionCredentials.builder().accessKeyId(credentials.accessKeyId()).secretAccessKey(credentials.secretAccessKey()).sessionToken(credentials.sessionToken()).build(); + return StaticCredentialsProvider.create(awsSessionCredentials); } - return DefaultAWSCredentialsProviderChain.getInstance(); + return DefaultCredentialsProvider.create(); } } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/QueryStatusChecker.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/QueryStatusChecker.java index 399424d441..491c46191f 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/QueryStatusChecker.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/QueryStatusChecker.java @@ -19,13 +19,13 @@ */ package com.amazonaws.athena.connector.lambda; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.model.GetQueryExecutionRequest; -import com.amazonaws.services.athena.model.GetQueryExecutionResult; -import com.amazonaws.services.athena.model.InvalidRequestException; import com.google.common.collect.ImmutableSet; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.athena.model.GetQueryExecutionRequest; +import software.amazon.awssdk.services.athena.model.GetQueryExecutionResponse; +import software.amazon.awssdk.services.athena.model.InvalidRequestException; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; @@ -50,12 +50,12 @@ public class QueryStatusChecker private boolean wasStarted = false; private final AtomicBoolean isRunning = new AtomicBoolean(true); - private final AmazonAthena athena; + private final AthenaClient athena; private final ThrottlingInvoker athenaInvoker; private final String queryId; private final Thread checkerThread; - public QueryStatusChecker(AmazonAthena athena, ThrottlingInvoker athenaInvoker, String queryId) + public QueryStatusChecker(AthenaClient athena, ThrottlingInvoker athenaInvoker, String queryId) { this.athena = athena; this.athenaInvoker = athenaInvoker; @@ -114,8 +114,8 @@ private void checkStatus(String queryId, int attempt) { logger.debug(format("Background thread checking status of Athena query %s, attempt %d", queryId, attempt)); try { - GetQueryExecutionResult queryExecution = athenaInvoker.invoke(() -> athena.getQueryExecution(new GetQueryExecutionRequest().withQueryExecutionId(queryId))); - String state = queryExecution.getQueryExecution().getStatus().getState(); + GetQueryExecutionResponse queryExecution = athenaInvoker.invoke(() -> athena.getQueryExecution(GetQueryExecutionRequest.builder().queryExecutionId(queryId).build())); + String state = queryExecution.queryExecution().status().state().toString(); if (TERMINAL_STATES.contains(state)) { logger.debug("Query {} has terminated with state {}", queryId, state); isRunning.set(false); diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/BlockUtils.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/BlockUtils.java index dfac5b00d6..268dad7ddb 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/BlockUtils.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/BlockUtils.java @@ -77,6 +77,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.time.Instant; 
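One note on the CrossAccountCredentialsProvider change above: the AwsCredentialsProvider it now returns plugs directly into any v2 client builder. A hypothetical wiring, not shown in this patch (the session name is illustrative):

    AwsCredentialsProvider credentials =
            CrossAccountCredentialsProvider.getCrossAccountCredentialsIfPresent(configOptions, "athena-federation");
    AthenaClient athena = AthenaClient.builder()
            .credentialsProvider(credentials)
            .build();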
import java.time.LocalDate; import java.time.LocalDateTime; import java.time.ZoneId; @@ -273,6 +274,9 @@ else if (value instanceof LocalDateTime) { pos, ((LocalDateTime) value).atZone(UTC_ZONE_ID).toInstant().toEpochMilli()); } + else if (value instanceof Instant) { + ((DateMilliVector) vector).setSafe(pos, ((Instant) value).toEpochMilli()); + } else { ((DateMilliVector) vector).setSafe(pos, (long) value); } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/DateTimeFormatterUtil.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/DateTimeFormatterUtil.java index 13f8052ee7..462177d5b9 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/DateTimeFormatterUtil.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/DateTimeFormatterUtil.java @@ -19,13 +19,13 @@ */ package com.amazonaws.athena.connector.lambda.data; -import com.amazonaws.util.StringUtils; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.commons.lang3.time.DateFormatUtils; import org.apache.commons.lang3.time.DateUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.utils.StringUtils; import java.math.BigDecimal; import java.text.ParseException; @@ -98,7 +98,7 @@ private DateTimeFormatterUtil() public static LocalDate stringToLocalDate(String value, String dateFormat, ZoneId defaultTimeZone) { - if (StringUtils.isNullOrEmpty(dateFormat)) { + if (StringUtils.isEmpty(dateFormat)) { logger.info("Unable to parse {} as Date type due to invalid dateformat", value); return null; } @@ -143,7 +143,7 @@ public static Object stringToZonedDateTime(String value, String dateFormat, Zone */ public static Object stringToDateTime(String value, String dateFormat, ZoneId defaultTimeZone) { - if (StringUtils.isNullOrEmpty(dateFormat)) { + if (StringUtils.isEmpty(dateFormat)) { logger.warn("Unable to parse {} as DateTime type due to invalid date format", value); return null; } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillReader.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillReader.java index 48806b99dc..6415484b41 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillReader.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillReader.java @@ -25,12 +25,14 @@ import com.amazonaws.athena.connector.lambda.security.BlockCrypto; import com.amazonaws.athena.connector.lambda.security.EncryptionKey; import com.amazonaws.athena.connector.lambda.security.NoOpBlockCrypto; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.S3Object; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import java.io.IOException; @@ -40,10 +42,10 @@ public class S3BlockSpillReader { private static final Logger logger = LoggerFactory.getLogger(S3BlockSpillReader.class); - private final AmazonS3 amazonS3; + private final S3Client amazonS3; 
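Stepping back to the BlockUtils hunk above: SDK v2 surfaces timestamps as java.time types where v1 used java.util.Date, so the new branch lets an Instant be written straight into a DateMilliVector. A sketch, assuming the surrounding setValue(FieldVector, int, Object) helper:

    // epoch millis come from Instant.toEpochMilli(), mirroring the LocalDateTime branch above it
    BlockUtils.setValue(dateMilliVector, rowPos, Instant.parse("2024-10-17T15:07:45Z"));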
private final BlockAllocator allocator; - public S3BlockSpillReader(AmazonS3 amazonS3, BlockAllocator allocator) + public S3BlockSpillReader(S3Client amazonS3, BlockAllocator allocator) { this.amazonS3 = requireNonNull(amazonS3, "amazonS3 was null"); this.allocator = requireNonNull(allocator, "allocator was null"); @@ -59,13 +61,16 @@ public S3BlockSpillReader(AmazonS3 amazonS3, BlockAllocator allocator) */ public Block read(S3SpillLocation spillLocation, EncryptionKey key, Schema schema) { - S3Object fullObject = null; + ResponseInputStream responseStream = null; try { logger.debug("read: Started reading block from S3"); - fullObject = amazonS3.getObject(spillLocation.getBucket(), spillLocation.getKey()); + responseStream = amazonS3.getObject(GetObjectRequest.builder() + .bucket(spillLocation.getBucket()) + .key(spillLocation.getKey()) + .build()); logger.debug("read: Completed reading block from S3"); BlockCrypto blockCrypto = (key != null) ? new AesGcmBlockCrypto(allocator) : new NoOpBlockCrypto(allocator); - Block block = blockCrypto.decrypt(key, ByteStreams.toByteArray(fullObject.getObjectContent()), schema); + Block block = blockCrypto.decrypt(key, ByteStreams.toByteArray(responseStream), schema); logger.debug("read: Completed decrypting block of size."); return block; } @@ -73,12 +78,12 @@ public Block read(S3SpillLocation spillLocation, EncryptionKey key, Schema schem throw new RuntimeException(ex); } finally { - if (fullObject != null) { + if (responseStream != null) { try { - fullObject.close(); + responseStream.close(); } catch (IOException ex) { - logger.warn("read: Exception while closing S3 object", ex); + logger.warn("read: Exception while closing S3 response stream", ex); } } } @@ -93,24 +98,27 @@ public Block read(S3SpillLocation spillLocation, EncryptionKey key, Schema schem */ public byte[] read(S3SpillLocation spillLocation, EncryptionKey key) { - S3Object fullObject = null; + ResponseInputStream responseStream = null; try { logger.debug("read: Started reading block from S3"); - fullObject = amazonS3.getObject(spillLocation.getBucket(), spillLocation.getKey()); + responseStream = amazonS3.getObject(GetObjectRequest.builder() + .bucket(spillLocation.getBucket()) + .key(spillLocation.getKey()) + .build()); logger.debug("read: Completed reading block from S3"); BlockCrypto blockCrypto = (key != null) ? 
new AesGcmBlockCrypto(allocator) : new NoOpBlockCrypto(allocator); - return blockCrypto.decrypt(key, ByteStreams.toByteArray(fullObject.getObjectContent())); + return blockCrypto.decrypt(key, ByteStreams.toByteArray(responseStream)); } catch (IOException ex) { throw new RuntimeException(ex); } finally { - if (fullObject != null) { + if (responseStream != null) { try { - fullObject.close(); + responseStream.close(); } catch (IOException ex) { - logger.warn("read: Exception while closing S3 object", ex); + logger.warn("read: Exception while closing S3 response stream", ex); } } } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java index 2604b5e228..de879feafd 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java @@ -27,10 +27,6 @@ import com.amazonaws.athena.connector.lambda.security.BlockCrypto; import com.amazonaws.athena.connector.lambda.security.EncryptionKey; import com.amazonaws.athena.connector.lambda.security.NoOpBlockCrypto; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.ObjectMetadata; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.S3Object; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.ByteStreams; @@ -38,10 +34,16 @@ import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -76,7 +78,7 @@ public class S3BlockSpiller private static final String SPILL_PUT_REQUEST_HEADERS_ENV = "spill_put_request_headers"; //Used to write to S3 - private final AmazonS3 amazonS3; + private final S3Client amazonS3; //Used to optionally encrypt Blocks. private final BlockCrypto blockCrypto; //Used to create new blocks. @@ -125,7 +127,7 @@ public class S3BlockSpiller * @param constraintEvaluator The ConstraintEvaluator that should be used to constrain writes. */ public S3BlockSpiller( - AmazonS3 amazonS3, + S3Client amazonS3, SpillConfig spillConfig, BlockAllocator allocator, Schema schema, @@ -146,7 +148,7 @@ public S3BlockSpiller( * @param maxRowsPerCall The max number of rows to allow callers to write in one call. 
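In both read paths above, v2's getObject returns a ResponseInputStream<GetObjectResponse> rather than an S3Object, and the stream must be fully read or closed to free its connection. The explicit finally blocks could equally be written as try-with-resources (a sketch; bucket and key are placeholders):

    GetObjectRequest request = GetObjectRequest.builder()
            .bucket("my-spill-bucket")
            .key("spills/block-0")
            .build();
    try (ResponseInputStream<GetObjectResponse> stream = amazonS3.getObject(request)) {
        byte[] data = ByteStreams.toByteArray(stream);
        // decrypt / deserialize as before
    }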
*/ public S3BlockSpiller( - AmazonS3 amazonS3, + S3Client amazonS3, SpillConfig spillConfig, BlockAllocator allocator, Schema schema, @@ -318,29 +320,24 @@ public void close() /** * Grabs the request headers from env and sets them on the request */ - private void setRequestHeadersFromEnv(PutObjectRequest request) + private Map getRequestHeadersFromEnv() { String headersFromEnvStr = configOptions.get(SPILL_PUT_REQUEST_HEADERS_ENV); if (headersFromEnvStr == null || headersFromEnvStr.isEmpty()) { - return; + return Collections.emptyMap(); } try { ObjectMapper mapper = new ObjectMapper(); TypeReference> typeRef = new TypeReference>() {}; Map headers = mapper.readValue(headersFromEnvStr, typeRef); - for (Map.Entry entry : headers.entrySet()) { - String oldValue = request.putCustomRequestHeader(entry.getKey(), entry.getValue()); - if (oldValue != null) { - logger.warn("Key: %s has been overwritten with: %s. Old value: %s", - entry.getKey(), entry.getValue(), oldValue); - } - } + return headers; } catch (com.fasterxml.jackson.core.JsonProcessingException e) { String message = String.format("Invalid value for environment variable: %s : %s", SPILL_PUT_REQUEST_HEADERS_ENV, headersFromEnvStr); logger.error(message, e); } + return Collections.emptyMap(); } /** @@ -361,15 +358,13 @@ protected SpillLocation write(Block block) // Set the contentLength otherwise the s3 client will buffer again since it // only sees the InputStream wrapper. - ObjectMetadata objMeta = new ObjectMetadata(); - objMeta.setContentLength(bytes.length); - PutObjectRequest request = new PutObjectRequest( - spillLocation.getBucket(), - spillLocation.getKey(), - new ByteArrayInputStream(bytes), - objMeta); - setRequestHeadersFromEnv(request); - amazonS3.putObject(request); + PutObjectRequest request = PutObjectRequest.builder() + .bucket(spillLocation.getBucket()) + .key(spillLocation.getKey()) + .contentLength((long) bytes.length) + .metadata(getRequestHeadersFromEnv()) + .build(); + amazonS3.putObject(request, RequestBody.fromBytes(bytes)); logger.info("write: Completed spilling block of size {} bytes", bytes.length); return spillLocation; @@ -393,9 +388,12 @@ protected Block read(S3SpillLocation spillLocation, EncryptionKey key, Schema sc { try { logger.debug("write: Started reading block from S3"); - S3Object fullObject = amazonS3.getObject(spillLocation.getBucket(), spillLocation.getKey()); + ResponseInputStream responseStream = amazonS3.getObject(GetObjectRequest.builder() + .bucket(spillLocation.getBucket()) + .key(spillLocation.getKey()) + .build()); logger.debug("write: Completed reading block from S3"); - Block block = blockCrypto.decrypt(key, ByteStreams.toByteArray(fullObject.getObjectContent()), schema); + Block block = blockCrypto.decrypt(key, ByteStreams.toByteArray(responseStream), schema); logger.debug("write: Completed decrypting block of size."); return block; } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifier.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifier.java index 35dea4ab31..4e06af8044 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifier.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifier.java @@ -20,12 +20,12 @@ package com.amazonaws.athena.connector.lambda.domain.spill; -import com.amazonaws.AmazonServiceException; -import 
com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.HeadBucketRequest; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.HeadBucketRequest; +import software.amazon.awssdk.services.s3.model.S3Exception; /** * This class is used to track the bucket and its state, and check its validity @@ -37,14 +37,14 @@ public class SpillLocationVerifier private enum BucketState {UNCHECKED, VALID, INVALID} - private final AmazonS3 amazons3; + private final S3Client amazons3; private String bucket; private BucketState state; /** * @param amazons3 The S3 object for the account. */ - public SpillLocationVerifier(AmazonS3 amazons3) + public SpillLocationVerifier(S3Client amazons3) { this.amazons3 = amazons3; this.bucket = null; @@ -83,11 +83,11 @@ public void checkBucketAuthZ(String spillBucket) void updateBucketState() { try { - amazons3.headBucket(new HeadBucketRequest(bucket)); + amazons3.headBucket(HeadBucketRequest.builder().bucket(bucket).build()); state = BucketState.VALID; } - catch (AmazonServiceException ex) { - int statusCode = ex.getStatusCode(); + catch (S3Exception ex) { + int statusCode = ex.statusCode(); // returns 404 if bucket was not found, 403 if bucket access is forbidden if (statusCode == 404 || statusCode == 403) { state = BucketState.INVALID; diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/exceptions/AthenaConnectorException.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/exceptions/AthenaConnectorException.java index 59c5f3c75e..3743c552eb 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/exceptions/AthenaConnectorException.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/exceptions/AthenaConnectorException.java @@ -19,7 +19,7 @@ */ package com.amazonaws.athena.connector.lambda.exceptions; -import com.amazonaws.services.glue.model.ErrorDetails; +import software.amazon.awssdk.services.glue.model.ErrorDetails; import javax.annotation.Nonnull; diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/AthenaExceptionFilter.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/AthenaExceptionFilter.java index 9a9d7e2c46..6109038723 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/AthenaExceptionFilter.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/AthenaExceptionFilter.java @@ -20,7 +20,7 @@ package com.amazonaws.athena.connector.lambda.handlers; import com.amazonaws.athena.connector.lambda.ThrottlingInvoker; -import com.amazonaws.services.athena.model.TooManyRequestsException; +import software.amazon.awssdk.services.athena.model.TooManyRequestsException; public class AthenaExceptionFilter implements ThrottlingInvoker.ExceptionFilter diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/GlueMetadataHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/GlueMetadataHandler.java index e33a639ea7..fcd80f1773 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/GlueMetadataHandler.java +++ 
b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/GlueMetadataHandler.java @@ -20,7 +20,6 @@ * #L% */ -import com.amazonaws.ClientConfiguration; import com.amazonaws.athena.connector.lambda.data.BlockAllocator; import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; import com.amazonaws.athena.connector.lambda.domain.TableName; @@ -33,18 +32,6 @@ import com.amazonaws.athena.connector.lambda.metadata.MetadataRequest; import com.amazonaws.athena.connector.lambda.metadata.glue.GlueFieldLexer; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.AWSGlueClientBuilder; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.Database; -import com.amazonaws.services.glue.model.GetDatabasesRequest; -import com.amazonaws.services.glue.model.GetDatabasesResult; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.GetTablesRequest; -import com.amazonaws.services.glue.model.GetTablesResult; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.base.Splitter; import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; @@ -54,7 +41,19 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.Database; +import software.amazon.awssdk.services.glue.model.GetDatabasesRequest; +import software.amazon.awssdk.services.glue.model.GetTablesRequest; +import software.amazon.awssdk.services.glue.model.GetTablesResponse; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.glue.paginators.GetDatabasesIterable; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; + +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -137,7 +136,7 @@ public abstract class GlueMetadataHandler // emulate behavior from prior versions. public static final String GLUE_TABLE_CONTAINS_PREVIOUSLY_UNSUPPORTED_TYPE = "glueTableContainsPreviouslyUnsupportedType"; - private final AWSGlue awsGlue; + private final GlueClient awsGlue; /** * Basic constructor which is recommended when extending this class. @@ -156,8 +155,10 @@ public GlueMetadataHandler(String sourceType, java.util.Map conf boolean disabled = configOptions.get(DISABLE_GLUE) != null && !"false".equalsIgnoreCase(configOptions.get(DISABLE_GLUE)); // null if the current instance does not want to leverage Glue for metadata - awsGlue = disabled ? null : (AWSGlueClientBuilder.standard() - .withClientConfiguration(new ClientConfiguration().withConnectionTimeout(CONNECT_TIMEOUT)) + awsGlue = disabled ? null : (GlueClient.builder() + .httpClientBuilder(ApacheHttpClient + .builder() + .connectionTimeout(Duration.ofMillis(CONNECT_TIMEOUT))) .build()); } @@ -168,7 +169,7 @@ public GlueMetadataHandler(String sourceType, java.util.Map conf * @param sourceType The source type, used in diagnostic logging. 
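The constructor hunk above shows how client tuning moves in v2: v1's ClientConfiguration.withConnectionTimeout becomes an ApacheHttpClient.Builder handed to the client builder, which is why software.amazon.awssdk:apache-client joins the SDK pom earlier in this patch. The idiom in isolation (a sketch; the socket timeout is illustrative, not taken from this patch):

    GlueClient glue = GlueClient.builder()
            .httpClientBuilder(ApacheHttpClient.builder()
                    .connectionTimeout(Duration.ofMillis(CONNECT_TIMEOUT))
                    .socketTimeout(Duration.ofSeconds(30)))
            .build();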
* @param configOptions The configOptions for this MetadataHandler. */ - public GlueMetadataHandler(AWSGlue awsGlue, String sourceType, java.util.Map configOptions) + public GlueMetadataHandler(GlueClient awsGlue, String sourceType, java.util.Map configOptions) { super(sourceType, configOptions); this.awsGlue = awsGlue; @@ -179,7 +180,7 @@ public GlueMetadataHandler(AWSGlue awsGlue, String sourceType, java.util.Map schemas = new ArrayList<>(); - String nextToken = null; - do { - getDatabasesRequest.setNextToken(nextToken); - GetDatabasesResult result = awsGlue.getDatabases(getDatabasesRequest); - - for (Database next : result.getDatabaseList()) { - if (filter == null || filter.filter(next)) { - schemas.add(next.getName()); - } - } + GetDatabasesIterable responses = awsGlue.getDatabasesPaginator(getDatabasesRequest); - nextToken = result.getNextToken(); - } - while (nextToken != null); + responses.stream().forEach(response -> response.databaseList() + .forEach(database -> { + if (filter == null || filter.filter(database)) { + schemas.add(database.name()); + } + })); return new ListSchemasResponse(request.getCatalogName(), schemas); } @@ -309,31 +305,30 @@ public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTables protected ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTablesRequest request, TableFilter filter) throws Exception { - GetTablesRequest getTablesRequest = new GetTablesRequest(); - getTablesRequest.setCatalogId(getCatalog(request)); - getTablesRequest.setDatabaseName(request.getSchemaName()); - Set tables = new HashSet<>(); String nextToken = request.getNextToken(); int pageSize = request.getPageSize(); do { - getTablesRequest.setNextToken(nextToken); + GetTablesRequest.Builder getTablesRequest = GetTablesRequest.builder() + .catalogId(getCatalog(request)) + .databaseName(request.getSchemaName()) + .nextToken(nextToken); if (pageSize != UNLIMITED_PAGE_SIZE_VALUE) { // Paginated requests will include the maxResults argument determined by the minimum value between the // pageSize and the maximum results supported by Glue (as defined in the Glue API docs). 
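doListSchemas above now relies on a v2 paginator (getDatabasesPaginator), which hides the nextToken loop entirely; doListTables keeps its manual loop only because it must honor the caller's page size. The paginator idiom on its own (sketch, reusing names from this class):

    GetDatabasesRequest request = GetDatabasesRequest.builder()
            .catalogId(catalogId)
            .build();
    awsGlue.getDatabasesPaginator(request).stream()        // lazily fetches page after page
            .flatMap(response -> response.databaseList().stream())
            .map(Database::name)
            .forEach(schemas::add);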
int maxResults = Math.min(pageSize, GET_TABLES_REQUEST_MAX_RESULTS); - getTablesRequest.setMaxResults(maxResults); + getTablesRequest.maxResults(maxResults); pageSize -= maxResults; } - GetTablesResult result = awsGlue.getTables(getTablesRequest); + GetTablesResponse response = awsGlue.getTables(getTablesRequest.build()); - for (Table next : result.getTableList()) { + for (Table next : response.tableList()) { if (filter == null || filter.filter(next)) { - tables.add(new TableName(request.getSchemaName(), next.getName())); + tables.add(new TableName(request.getSchemaName(), next.name())); } } - nextToken = result.getNextToken(); + nextToken = response.nextToken(); } while (nextToken != null && (pageSize == UNLIMITED_PAGE_SIZE_VALUE || pageSize > 0)); @@ -387,21 +382,23 @@ protected GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReq throws Exception { TableName tableName = request.getTableName(); - com.amazonaws.services.glue.model.GetTableRequest getTableRequest = new com.amazonaws.services.glue.model.GetTableRequest(); - getTableRequest.setCatalogId(getCatalog(request)); - getTableRequest.setDatabaseName(tableName.getSchemaName()); - getTableRequest.setName(tableName.getTableName()); + //Full class name required due to name overlap with athena + software.amazon.awssdk.services.glue.model.GetTableRequest getTableRequest = software.amazon.awssdk.services.glue.model.GetTableRequest.builder() + .catalogId(getCatalog(request)) + .databaseName(tableName.getSchemaName()) + .name(tableName.getTableName()) + .build(); - GetTableResult result = awsGlue.getTable(getTableRequest); - Table table = result.getTable(); + software.amazon.awssdk.services.glue.model.GetTableResponse response = awsGlue.getTable(getTableRequest); + Table table = response.table(); if (filter != null && !filter.filter(table)) { throw new RuntimeException("No matching table found " + request.getTableName()); } SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); - if (table.getParameters() != null) { - table.getParameters() + if (table.parameters() != null) { + table.parameters() .entrySet() .forEach(next -> schemaBuilder.addMetadata(next.getKey(), next.getValue())); } @@ -412,35 +409,37 @@ protected GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReq Map datetimeFormatMappingWithColumnName = new HashMap<>(); Set partitionCols = new HashSet<>(); - if (table.getPartitionKeys() != null) { - partitionCols = table.getPartitionKeys() - .stream().map(next -> columnNameMapping.getOrDefault(next.getName(), next.getName())).collect(Collectors.toSet()); + if (table.partitionKeys() != null) { + partitionCols = table.partitionKeys() + .stream() + .map(next -> columnNameMapping.getOrDefault(next.name(), next.name())) + .collect(Collectors.toSet()); } // partition columns should be added to the schema if they exist - List allColumns = Stream.of(table.getStorageDescriptor().getColumns(), table.getPartitionKeys() == null ? new ArrayList() : table.getPartitionKeys()) + List allColumns = Stream.of(table.storageDescriptor().columns(), table.partitionKeys() == null ? 
new ArrayList() : table.partitionKeys()) .flatMap(x -> x.stream()) .collect(Collectors.toList()); boolean glueTableContainsPreviouslyUnsupportedType = false; for (Column next : allColumns) { - String rawColumnName = next.getName(); + String rawColumnName = next.name(); String mappedColumnName = columnNameMapping.getOrDefault(rawColumnName, rawColumnName); // apply any type override provided in typeOverrideMapping from metadata // this is currently only used for timestamp with timezone support - logger.info("Column {} with registered type {}", rawColumnName, next.getType()); - Field arrowField = convertField(mappedColumnName, next.getType()); + logger.info("Column {} with registered type {}", rawColumnName, next.type()); + Field arrowField = convertField(mappedColumnName, next.type()); schemaBuilder.addField(arrowField); // Add non-null non-empty comments to metadata - if (next.getComment() != null && !next.getComment().trim().isEmpty()) { - schemaBuilder.addMetadata(mappedColumnName, next.getComment()); + if (next.comment() != null && !next.comment().trim().isEmpty()) { + schemaBuilder.addMetadata(mappedColumnName, next.comment()); } if (dateTimeFormatMapping.containsKey(rawColumnName)) { datetimeFormatMappingWithColumnName.put(mappedColumnName, dateTimeFormatMapping.get(rawColumnName)); } // Indicate that we found a `set` or `decimal` type so that we can set this metadata on the schemaBuilder later on - if (glueTableContainsPreviouslyUnsupportedType == false && isPreviouslyUnsupported(next.getType(), arrowField)) { + if (glueTableContainsPreviouslyUnsupportedType == false && isPreviouslyUnsupported(next.type(), arrowField)) { glueTableContainsPreviouslyUnsupportedType = true; } } @@ -449,8 +448,8 @@ protected GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReq populateSourceTableNameIfAvailable(table, schemaBuilder); - if (table.getViewOriginalText() != null && !table.getViewOriginalText().isEmpty()) { - schemaBuilder.addMetadata(VIEW_METADATA_FIELD, table.getViewOriginalText()); + if (table.viewOriginalText() != null && !table.viewOriginalText().isEmpty()) { + schemaBuilder.addMetadata(VIEW_METADATA_FIELD, table.viewOriginalText()); } schemaBuilder.addMetadata(GLUE_TABLE_CONTAINS_PREVIOUSLY_UNSUPPORTED_TYPE, String.valueOf(glueTableContainsPreviouslyUnsupportedType)); @@ -515,12 +514,12 @@ public interface DatabaseFilter */ protected static void populateSourceTableNameIfAvailable(Table table, SchemaBuilder schemaBuilder) { - String sourceTableProperty = table.getParameters().get(SOURCE_TABLE_PROPERTY); + String sourceTableProperty = table.parameters().get(SOURCE_TABLE_PROPERTY); if (sourceTableProperty != null) { // table property exists so nothing to do (assumes all table properties were already copied) return; } - String location = table.getStorageDescriptor().getLocation(); + String location = table.storageDescriptor().location(); if (location != null) { Matcher matcher = TABLE_ARN_REGEX.matcher(location); if (matcher.matches()) { @@ -550,7 +549,7 @@ protected static String getSourceTableName(Schema schema) */ protected static Map getColumnNameMapping(Table table) { - String columnNameMappingParam = table.getParameters().get(COLUMN_NAME_MAPPING_PROPERTY); + String columnNameMappingParam = table.parameters().get(COLUMN_NAME_MAPPING_PROPERTY); if (!Strings.isNullOrEmpty(columnNameMappingParam)) { return MAP_SPLITTER.split(columnNameMappingParam); } @@ -567,7 +566,7 @@ protected static Map getColumnNameMapping(Table table) */ private Map getDateTimeFormatMapping(Table 
table) { - String datetimeFormatMappingParam = table.getParameters().get(DATETIME_FORMAT_MAPPING_PROPERTY); + String datetimeFormatMappingParam = table.parameters().get(DATETIME_FORMAT_MAPPING_PROPERTY); if (!Strings.isNullOrEmpty(datetimeFormatMappingParam)) { return MAP_SPLITTER.split(datetimeFormatMappingParam); } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java index 467bb61442..0810ba64b1 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java @@ -58,20 +58,18 @@ import com.amazonaws.athena.connector.lambda.security.KmsKeyFactory; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.kms.AWSKMSClientBuilder; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.RequestStreamHandler; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.io.InputStream; @@ -119,7 +117,7 @@ public abstract class MetadataHandler protected static final String KMS_KEY_ID_ENV = "kms_key_id"; protected static final String DISABLE_SPILL_ENCRYPTION = "disable_spill_encryption"; private final CachableSecretsManager secretsManager; - private final AmazonAthena athena; + private final AthenaClient athena; private final ThrottlingInvoker athenaInvoker; private final EncryptionKeyFactory encryptionKeyFactory; private final String spillBucket; @@ -146,14 +144,14 @@ public MetadataHandler(String sourceType, java.util.Map configOp } else { this.encryptionKeyFactory = (this.configOptions.get(KMS_KEY_ID_ENV) != null) ? 
- new KmsKeyFactory(AWSKMSClientBuilder.standard().build(), this.configOptions.get(KMS_KEY_ID_ENV)) : + new KmsKeyFactory(KmsClient.create(), this.configOptions.get(KMS_KEY_ID_ENV)) : new LocalKeyFactory(); logger.debug("ENABLE_SPILL_ENCRYPTION with encryption factory: " + encryptionKeyFactory.getClass().getSimpleName()); } - this.secretsManager = new CachableSecretsManager(AWSSecretsManagerClientBuilder.defaultClient()); - this.athena = AmazonAthenaClientBuilder.defaultClient(); - this.verifier = new SpillLocationVerifier(AmazonS3ClientBuilder.standard().build()); + this.secretsManager = new CachableSecretsManager(SecretsManagerClient.create()); + this.athena = AthenaClient.create(); + this.verifier = new SpillLocationVerifier(S3Client.create()); this.athenaInvoker = ThrottlingInvoker.newDefaultBuilder(ATHENA_EXCEPTION_FILTER, configOptions).build(); } @@ -162,8 +160,8 @@ public MetadataHandler(String sourceType, java.util.Map configOp */ public MetadataHandler( EncryptionKeyFactory encryptionKeyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, String sourceType, String spillBucket, String spillPrefix, @@ -176,7 +174,7 @@ public MetadataHandler( this.sourceType = sourceType; this.spillBucket = spillBucket; this.spillPrefix = spillPrefix; - this.verifier = new SpillLocationVerifier(AmazonS3ClientBuilder.standard().build()); + this.verifier = new SpillLocationVerifier(S3Client.create()); this.athenaInvoker = ThrottlingInvoker.newDefaultBuilder(ATHENA_EXCEPTION_FILTER, configOptions).build(); } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java index f3b47a1a45..ac3e563005 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java @@ -40,17 +40,14 @@ import com.amazonaws.athena.connector.lambda.request.PingResponse; import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager; import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.RequestStreamHandler; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.io.InputStream; @@ -71,10 +68,10 @@ public abstract class RecordHandler private static final String MAX_BLOCK_SIZE_BYTES = "MAX_BLOCK_SIZE_BYTES"; private static final int NUM_SPILL_THREADS = 2; protected final java.util.Map configOptions; - private final AmazonS3 amazonS3; + private final S3Client amazonS3; private final String sourceType; private final CachableSecretsManager secretsManager; - private final 
AmazonAthena athena; + private final AthenaClient athena; private final ThrottlingInvoker athenaInvoker; /** @@ -83,9 +80,9 @@ public abstract class RecordHandler public RecordHandler(String sourceType, java.util.Map configOptions) { this.sourceType = sourceType; - this.amazonS3 = AmazonS3ClientBuilder.defaultClient(); - this.secretsManager = new CachableSecretsManager(AWSSecretsManagerClientBuilder.defaultClient()); - this.athena = AmazonAthenaClientBuilder.defaultClient(); + this.amazonS3 = S3Client.create(); + this.secretsManager = new CachableSecretsManager(SecretsManagerClient.create()); + this.athena = AthenaClient.create(); this.configOptions = configOptions; this.athenaInvoker = ThrottlingInvoker.newDefaultBuilder(ATHENA_EXCEPTION_FILTER, configOptions).build(); } @@ -93,7 +90,7 @@ public RecordHandler(String sourceType, java.util.Map configOpti /** * @param sourceType Used to aid in logging diagnostic info when raising a support case. */ - public RecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, String sourceType, java.util.Map configOptions) + public RecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, String sourceType, java.util.Map configOptions) { this.sourceType = sourceType; this.amazonS3 = amazonS3; diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/metadata/MetadataService.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/metadata/MetadataService.java deleted file mode 100644 index f5fc0b99fc..0000000000 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/metadata/MetadataService.java +++ /dev/null @@ -1,38 +0,0 @@ -package com.amazonaws.athena.connector.lambda.metadata; - -/*- - * #%L - * Amazon Athena Query Federation SDK - * %% - * Copyright (C) 2019 Amazon Web Services - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ - -import com.amazonaws.services.lambda.invoke.LambdaFunction; - -/** - * Lambda functions intended for Metadata operations associate with this interface. - */ -public interface MetadataService -{ - /** - * Returns metadata corresponding to the request type. - * - * @param request The metadata request. - * @return The metadata. 
- */ - @LambdaFunction(functionName = "metadata") - MetadataResponse getMetadata(final MetadataRequest request); -} diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/records/RecordService.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/records/RecordService.java deleted file mode 100644 index 3d70c240ca..0000000000 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/records/RecordService.java +++ /dev/null @@ -1,38 +0,0 @@ -package com.amazonaws.athena.connector.lambda.records; - -/*- - * #%L - * Amazon Athena Query Federation SDK - * %% - * Copyright (C) 2019 Amazon Web Services - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ - -import com.amazonaws.services.lambda.invoke.LambdaFunction; - -/** - * Lambda functions intended for Record operations associate with this interface. - */ -public interface RecordService -{ - /** - * Returns data/records corresponding to the request type. - * - * @param request The data/records request. - * @return The data/records. - */ - @LambdaFunction(functionName = "record") - RecordResponse readRecords(final RecordRequest request); -} diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java index e557210bff..00ccf90900 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java @@ -20,12 +20,12 @@ * #L% */ -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.util.Iterator; import java.util.LinkedHashMap; @@ -52,9 +52,9 @@ public class CachableSecretsManager private static final Pattern NAME_PATTERN = Pattern.compile(SECRET_NAME_PATTERN); private final LinkedHashMap cache = new LinkedHashMap<>(); - private final AWSSecretsManager secretsManager; + private final SecretsManagerClient secretsManager; - public CachableSecretsManager(AWSSecretsManager secretsManager) + public CachableSecretsManager(SecretsManagerClient secretsManager) { this.secretsManager = secretsManager; } @@ -97,9 +97,10 @@ public String getSecret(String secretName) if (cacheEntry == null || cacheEntry.getAge() > MAX_CACHE_AGE_MS) { logger.info("getSecret: Resolving secret[{}].", secretName); - 
GetSecretValueResult secretValueResult = secretsManager.getSecretValue(new GetSecretValueRequest() - .withSecretId(secretName)); - cacheEntry = new CacheEntry(secretName, secretValueResult.getSecretString()); + GetSecretValueResponse secretValueResult = secretsManager.getSecretValue(GetSecretValueRequest.builder() + .secretId(secretName) + .build()); + cacheEntry = new CacheEntry(secretName, secretValueResult.secretString()); evictCache(cache.size() >= MAX_CACHE_SIZE); cache.put(secretName, cacheEntry); } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/KmsKeyFactory.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/KmsKeyFactory.java index 478ed6d7e0..c9d0589bfd 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/KmsKeyFactory.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/KmsKeyFactory.java @@ -20,12 +20,12 @@ * #L% */ -import com.amazonaws.services.kms.AWSKMS; -import com.amazonaws.services.kms.model.DataKeySpec; -import com.amazonaws.services.kms.model.GenerateDataKeyRequest; -import com.amazonaws.services.kms.model.GenerateDataKeyResult; -import com.amazonaws.services.kms.model.GenerateRandomRequest; -import com.amazonaws.services.kms.model.GenerateRandomResult; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.kms.model.DataKeySpec; +import software.amazon.awssdk.services.kms.model.GenerateDataKeyRequest; +import software.amazon.awssdk.services.kms.model.GenerateDataKeyResponse; +import software.amazon.awssdk.services.kms.model.GenerateRandomRequest; +import software.amazon.awssdk.services.kms.model.GenerateRandomResponse; /** * An EncryptionKeyFactory that is backed by AWS KMS. 
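The two hunks above show the migration pattern this patch applies throughout: v1's mutable new GetSecretValueRequest().withSecretId(...) becomes an immutable builder, the *Result type becomes *Response, and accessors drop the get prefix. A minimal sketch of the resulting call shape, assuming nothing beyond the SDK v2 Secrets Manager client already imported in this file:

import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;

public final class SecretLookupSketch
{
    // v1: new GetSecretValueRequest().withSecretId(name), then result.getSecretString()
    // v2: immutable builder in, *Response out, accessors without the get- prefix
    public static String fetch(SecretsManagerClient secretsManager, String secretName)
    {
        GetSecretValueRequest request = GetSecretValueRequest.builder()
                .secretId(secretName)
                .build();
        GetSecretValueResponse response = secretsManager.getSecretValue(request);
        return response.secretString();
    }
}

The KMS hunk that follows applies the same rules plus one data-type change: v2 responses expose binary fields as SdkBytes, so the ByteBuffer-style getPlaintext().array() becomes plaintext().asByteArray().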
@@ -35,10 +35,10 @@ public class KmsKeyFactory implements EncryptionKeyFactory { - private final AWSKMS kmsClient; + private final KmsClient kmsClient; private final String masterKeyId; - public KmsKeyFactory(AWSKMS kmsClient, String masterKeyId) + public KmsKeyFactory(KmsClient kmsClient, String masterKeyId) { this.kmsClient = kmsClient; this.masterKeyId = masterKeyId; @@ -49,16 +49,18 @@ public KmsKeyFactory(AWSKMS kmsClient, String masterKeyId) */ public EncryptionKey create() { - GenerateDataKeyResult dataKeyResult = + GenerateDataKeyResponse dataKeyResponse = kmsClient.generateDataKey( - new GenerateDataKeyRequest() - .withKeyId(masterKeyId) - .withKeySpec(DataKeySpec.AES_256)); + GenerateDataKeyRequest.builder() + .keyId(masterKeyId) + .keySpec(DataKeySpec.AES_128) + .build()); - GenerateRandomRequest randomRequest = new GenerateRandomRequest() - .withNumberOfBytes(AesGcmBlockCrypto.NONCE_BYTES); - GenerateRandomResult randomResult = kmsClient.generateRandom(randomRequest); + GenerateRandomRequest randomRequest = GenerateRandomRequest.builder() + .numberOfBytes(AesGcmBlockCrypto.NONCE_BYTES) + .build(); + GenerateRandomResponse randomResponse = kmsClient.generateRandom(randomRequest); - return new EncryptionKey(dataKeyResult.getPlaintext().array(), randomResult.getPlaintext().array()); + return new EncryptionKey(dataKeyResponse.plaintext().asByteArray(), randomResponse.plaintext().asByteArray()); } } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v2/LambdaFunctionExceptionSerDe.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v2/LambdaFunctionExceptionSerDe.java index 595133c0da..637ef13771 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v2/LambdaFunctionExceptionSerDe.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v2/LambdaFunctionExceptionSerDe.java @@ -20,14 +20,13 @@ package com.amazonaws.athena.connector.lambda.serde.v2; import com.amazonaws.athena.connector.lambda.serde.BaseDeserializer; -import com.amazonaws.services.lambda.invoke.LambdaFunctionException; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonNode; import com.google.common.base.Joiner; +import software.amazon.awssdk.services.lambda.model.LambdaException; import java.io.IOException; -import java.lang.reflect.Constructor; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -46,15 +45,15 @@ public class LambdaFunctionExceptionSerDe private LambdaFunctionExceptionSerDe() {} - public static final class Deserializer extends BaseDeserializer + public static final class Deserializer extends BaseDeserializer { public Deserializer() { - super(LambdaFunctionException.class); + super(LambdaException.class); } @Override - public LambdaFunctionException deserialize(JsonParser jparser, DeserializationContext ctxt) + public LambdaException deserialize(JsonParser jparser, DeserializationContext ctxt) throws IOException { validateObjectStart(jparser.getCurrentToken()); @@ -63,18 +62,18 @@ public LambdaFunctionException deserialize(JsonParser jparser, DeserializationCo } @Override - public LambdaFunctionException doDeserialize(JsonParser jparser, DeserializationContext ctxt) + public LambdaException doDeserialize(JsonParser jparser, DeserializationContext ctxt) throws IOException { JsonNode root = jparser.getCodec().readTree(jparser); 
return recursiveParse(root); } - private LambdaFunctionException recursiveParse(JsonNode root) + private LambdaException recursiveParse(JsonNode root) { String errorType = getNullableStringValue(root, ERROR_TYPE_FIELD); String errorMessage = getNullableStringValue(root, ERROR_MESSAGE_FIELD); - LambdaFunctionException cause = null; + LambdaException cause = null; JsonNode causeNode = root.get(CAUSE_FIELD); if (causeNode != null) { cause = recursiveParse(causeNode); @@ -102,20 +101,7 @@ private LambdaFunctionException recursiveParse(JsonNode root) } } - // HACK: LambdaFunctionException is only intended to be instantiated by Lambda server-side, so its constructors - // are package-private or deprecated. Thus the need for reflection here. If the signature of the preferred - // constructor does change, we fall back to the deprecated constructor (which requires us to append the stackTrace - // to the errorMessage to not lose it). If the deprecated constructor is removed then this will not compile - // and the appropriate adjustment can be made. - try { - Constructor constructor = LambdaFunctionException.class.getDeclaredConstructor( - String.class, String.class, LambdaFunctionException.class, List.class); - constructor.setAccessible(true); - return constructor.newInstance(errorType, errorMessage, cause, stackTraces); - } - catch (ReflectiveOperationException e) { - return new LambdaFunctionException(appendStackTrace(errorMessage, stackTraces), false, errorType); - } + return (LambdaException) LambdaException.builder().cause(cause).message(appendStackTrace(errorMessage, stackTraces) + "\nErrorType: " + errorType).build(); } private String getNullableStringValue(JsonNode parent, String field) diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v2/ObjectMapperFactoryV2.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v2/ObjectMapperFactoryV2.java index 0a82530652..8e4799b185 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v2/ObjectMapperFactoryV2.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v2/ObjectMapperFactoryV2.java @@ -27,7 +27,6 @@ import com.amazonaws.athena.connector.lambda.serde.PingRequestSerDe; import com.amazonaws.athena.connector.lambda.serde.PingResponseSerDe; import com.amazonaws.athena.connector.lambda.serde.VersionedSerDe; -import com.amazonaws.services.lambda.invoke.LambdaFunctionException; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.databind.BeanDescription; import com.fasterxml.jackson.databind.DeserializationContext; @@ -52,12 +51,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.lambda.model.LambdaException; @Deprecated public class ObjectMapperFactoryV2 { private static final JsonFactory JSON_FACTORY = new JsonFactory(); - private static final String LAMDA_EXCEPTION_CLASS_NAME = LambdaFunctionException.class.getName(); + private static final String LAMDA_EXCEPTION_CLASS_NAME = LambdaException.class.getName(); private static final SerializerFactory SERIALIZER_FACTORY; @@ -153,7 +153,7 @@ private StrictObjectMapper(BlockAllocator allocator) ImmutableMap, JsonDeserializer> desers = ImmutableMap.of( FederationRequest.class, createRequestDeserializer(allocator), FederationResponse.class, createResponseDeserializer(allocator), - 
LambdaFunctionException.class, new LambdaFunctionExceptionSerDe.Deserializer()); + LambdaException.class, new LambdaFunctionExceptionSerDe.Deserializer()); SimpleDeserializers deserializers = new SimpleDeserializers(desers); DeserializerFactoryConfig dConfig = new DeserializerFactoryConfig().withAdditionalDeserializers(deserializers); _deserializationContext = new DefaultDeserializationContext.Impl(new StrictDeserializerFactory(dConfig)); diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v3/ObjectMapperFactoryV3.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v3/ObjectMapperFactoryV3.java index 412555b4ec..9a305cf9b1 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v3/ObjectMapperFactoryV3.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v3/ObjectMapperFactoryV3.java @@ -58,7 +58,6 @@ import com.amazonaws.athena.connector.lambda.serde.v2.UserDefinedFunctionRequestSerDe; import com.amazonaws.athena.connector.lambda.serde.v2.UserDefinedFunctionResponseSerDe; import com.amazonaws.athena.connector.lambda.serde.v2.ValueSetSerDe; -import com.amazonaws.services.lambda.invoke.LambdaFunctionException; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.databind.BeanDescription; import com.fasterxml.jackson.databind.DeserializationContext; @@ -83,11 +82,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.lambda.model.LambdaException; public class ObjectMapperFactoryV3 { private static final JsonFactory JSON_FACTORY = new JsonFactory(); - private static final String LAMDA_EXCEPTION_CLASS_NAME = LambdaFunctionException.class.getName(); + private static final String LAMDA_EXCEPTION_CLASS_NAME = LambdaException.class.getName(); private static final SerializerFactory SERIALIZER_FACTORY; @@ -183,7 +183,7 @@ private StrictObjectMapper(BlockAllocator allocator) ImmutableMap, JsonDeserializer> desers = ImmutableMap.of( FederationRequest.class, createRequestDeserializer(allocator), FederationResponse.class, createResponseDeserializer(allocator), - LambdaFunctionException.class, new LambdaFunctionExceptionSerDe.Deserializer()); + LambdaException.class, new LambdaFunctionExceptionSerDe.Deserializer()); SimpleDeserializers deserializers = new SimpleDeserializers(desers); DeserializerFactoryConfig dConfig = new DeserializerFactoryConfig().withAdditionalDeserializers(deserializers); _deserializationContext = new DefaultDeserializationContext.Impl(new StrictDeserializerFactory(dConfig)); diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v4/ObjectMapperFactoryV4.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v4/ObjectMapperFactoryV4.java index bbb38c56ce..3f8f7d0f00 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v4/ObjectMapperFactoryV4.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v4/ObjectMapperFactoryV4.java @@ -60,7 +60,6 @@ import com.amazonaws.athena.connector.lambda.serde.v2.UserDefinedFunctionRequestSerDe; import com.amazonaws.athena.connector.lambda.serde.v2.UserDefinedFunctionResponseSerDe; import com.amazonaws.athena.connector.lambda.serde.v2.ValueSetSerDe; -import 
com.amazonaws.services.lambda.invoke.LambdaFunctionException; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.databind.BeanDescription; import com.fasterxml.jackson.databind.DeserializationContext; @@ -85,11 +84,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.lambda.model.LambdaException; public class ObjectMapperFactoryV4 { private static final JsonFactory JSON_FACTORY = new JsonFactory(); - private static final String LAMDA_EXCEPTION_CLASS_NAME = LambdaFunctionException.class.getName(); + private static final String LAMDA_EXCEPTION_CLASS_NAME = LambdaException.class.getName(); private static final SerializerFactory SERIALIZER_FACTORY; @@ -185,7 +185,7 @@ private StrictObjectMapper(BlockAllocator allocator) ImmutableMap, JsonDeserializer> desers = ImmutableMap.of( FederationRequest.class, createRequestDeserializer(allocator), FederationResponse.class, createResponseDeserializer(allocator), - LambdaFunctionException.class, new LambdaFunctionExceptionSerDe.Deserializer()); + LambdaException.class, new LambdaFunctionExceptionSerDe.Deserializer()); SimpleDeserializers deserializers = new SimpleDeserializers(desers); DeserializerFactoryConfig dConfig = new DeserializerFactoryConfig().withAdditionalDeserializers(deserializers); _deserializationContext = new DefaultDeserializationContext.Impl(new StrictDeserializerFactory(dConfig)); diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v5/ObjectMapperFactoryV5.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v5/ObjectMapperFactoryV5.java index d47dc71c23..6b94467c9f 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v5/ObjectMapperFactoryV5.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/serde/v5/ObjectMapperFactoryV5.java @@ -71,7 +71,6 @@ import com.amazonaws.athena.connector.lambda.serde.v4.OrderByFieldSerDeV4; import com.amazonaws.athena.connector.lambda.serde.v4.SchemaSerDeV4; import com.amazonaws.athena.connector.lambda.serde.v4.VariableExpressionSerDeV4; -import com.amazonaws.services.lambda.invoke.LambdaFunctionException; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.databind.BeanDescription; import com.fasterxml.jackson.databind.DeserializationContext; @@ -96,11 +95,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.lambda.model.LambdaException; public class ObjectMapperFactoryV5 { private static final JsonFactory JSON_FACTORY = new JsonFactory(); - private static final String LAMDA_EXCEPTION_CLASS_NAME = LambdaFunctionException.class.getName(); + private static final String LAMDA_EXCEPTION_CLASS_NAME = LambdaException.class.getName(); private static final SerializerFactory SERIALIZER_FACTORY; @@ -196,7 +196,7 @@ private StrictObjectMapper(BlockAllocator allocator) ImmutableMap, JsonDeserializer> desers = ImmutableMap.of( FederationRequest.class, createRequestDeserializer(allocator), FederationResponse.class, createResponseDeserializer(allocator), - LambdaFunctionException.class, new LambdaFunctionExceptionSerDe.Deserializer()); + LambdaException.class, new LambdaFunctionExceptionSerDe.Deserializer()); SimpleDeserializers deserializers = new 
SimpleDeserializers(desers); DeserializerFactoryConfig dConfig = new DeserializerFactoryConfig().withAdditionalDeserializers(deserializers); _deserializationContext = new DefaultDeserializationContext.Impl(new StrictDeserializerFactory(dConfig)); diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/QueryStatusCheckerTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/QueryStatusCheckerTest.java index ba0a641783..a0b5b7bb41 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/QueryStatusCheckerTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/QueryStatusCheckerTest.java @@ -19,14 +19,16 @@ */ package com.amazonaws.athena.connector.lambda; -import com.amazonaws.AmazonServiceException; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.model.GetQueryExecutionRequest; -import com.amazonaws.services.athena.model.GetQueryExecutionResult; -import com.amazonaws.services.athena.model.InvalidRequestException; -import com.amazonaws.services.athena.model.QueryExecution; -import com.amazonaws.services.athena.model.QueryExecutionStatus; import com.google.common.collect.ImmutableList; + +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.athena.model.GetQueryExecutionRequest; +import software.amazon.awssdk.services.athena.model.GetQueryExecutionResponse; +import software.amazon.awssdk.services.athena.model.InvalidRequestException; +import software.amazon.awssdk.services.athena.model.QueryExecution; +import software.amazon.awssdk.services.athena.model.QueryExecutionStatus; + import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; @@ -52,20 +54,21 @@ public class QueryStatusCheckerTest private final ThrottlingInvoker athenaInvoker = ThrottlingInvoker.newDefaultBuilder(ATHENA_EXCEPTION_FILTER, com.google.common.collect.ImmutableMap.of()).build(); @Mock - private AmazonAthena athena; + private AthenaClient athena; @Test public void testFastTermination() throws InterruptedException { String queryId = "query0"; - GetQueryExecutionRequest request = new GetQueryExecutionRequest().withQueryExecutionId(queryId); - when(athena.getQueryExecution(request)).thenReturn(new GetQueryExecutionResult().withQueryExecution(new QueryExecution().withStatus(new QueryExecutionStatus().withState("FAILED")))); + GetQueryExecutionRequest request = GetQueryExecutionRequest.builder().queryExecutionId(queryId).build(); + when(athena.getQueryExecution(request)).thenReturn(GetQueryExecutionResponse.builder().queryExecution(QueryExecution.builder().status(QueryExecutionStatus.builder().state("FAILED").build()).build()).build()); QueryStatusChecker queryStatusChecker = new QueryStatusChecker(athena, athenaInvoker, queryId); assertTrue(queryStatusChecker.isQueryRunning()); Thread.sleep(2000); assertFalse(queryStatusChecker.isQueryRunning()); - verify(athena, times(1)).getQueryExecution(any()); + verify(athena, times(1)).getQueryExecution(any(GetQueryExecutionRequest.class)); + queryStatusChecker.close(); } @Test @@ -73,9 +76,9 @@ public void testSlowTermination() throws InterruptedException { String queryId = "query1"; - GetQueryExecutionRequest request = new GetQueryExecutionRequest().withQueryExecutionId(queryId); - GetQueryExecutionResult result1and2 = new GetQueryExecutionResult().withQueryExecution(new 
QueryExecution().withStatus(new QueryExecutionStatus().withState("RUNNING"))); - GetQueryExecutionResult result3 = new GetQueryExecutionResult().withQueryExecution(new QueryExecution().withStatus(new QueryExecutionStatus().withState("SUCCEEDED"))); + GetQueryExecutionRequest request = GetQueryExecutionRequest.builder().queryExecutionId(queryId).build(); + GetQueryExecutionResponse result1and2 = GetQueryExecutionResponse.builder().queryExecution(QueryExecution.builder().status(QueryExecutionStatus.builder().state("RUNNING").build()).build()).build(); + GetQueryExecutionResponse result3 = GetQueryExecutionResponse.builder().queryExecution(QueryExecution.builder().status(QueryExecutionStatus.builder().state("SUCCEEDED").build()).build()).build(); when(athena.getQueryExecution(request)).thenReturn(result1and2).thenReturn(result1and2).thenReturn(result3); try (QueryStatusChecker queryStatusChecker = new QueryStatusChecker(athena, athenaInvoker, queryId)) { assertTrue(queryStatusChecker.isQueryRunning()); @@ -83,7 +86,7 @@ public void testSlowTermination() assertTrue(queryStatusChecker.isQueryRunning()); Thread.sleep(3000); assertFalse(queryStatusChecker.isQueryRunning()); - verify(athena, times(3)).getQueryExecution(any()); + verify(athena, times(3)).getQueryExecution(any(GetQueryExecutionRequest.class)); } } @@ -92,13 +95,13 @@ public void testNotFound() throws InterruptedException { String queryId = "query2"; - GetQueryExecutionRequest request = new GetQueryExecutionRequest().withQueryExecutionId(queryId); - when(athena.getQueryExecution(request)).thenThrow(new InvalidRequestException("")); + GetQueryExecutionRequest request = GetQueryExecutionRequest.builder().queryExecutionId(queryId).build(); + when(athena.getQueryExecution(request)).thenThrow(InvalidRequestException.builder().message("").build()); try (QueryStatusChecker queryStatusChecker = new QueryStatusChecker(athena, athenaInvoker, queryId)) { assertTrue(queryStatusChecker.isQueryRunning()); Thread.sleep(2000); assertTrue(queryStatusChecker.isQueryRunning()); - verify(athena, times(1)).getQueryExecution(any()); + verify(athena, times(1)).getQueryExecution(any(GetQueryExecutionRequest.class)); } } @@ -107,13 +110,13 @@ public void testOtherError() throws InterruptedException { String queryId = "query3"; - GetQueryExecutionRequest request = new GetQueryExecutionRequest().withQueryExecutionId(queryId); - when(athena.getQueryExecution(request)).thenThrow(new AmazonServiceException("")); + GetQueryExecutionRequest request = GetQueryExecutionRequest.builder().queryExecutionId(queryId).build(); + when(athena.getQueryExecution(request)).thenThrow(AwsServiceException.builder().message("").build()); try (QueryStatusChecker queryStatusChecker = new QueryStatusChecker(athena, athenaInvoker, queryId)) { assertTrue(queryStatusChecker.isQueryRunning()); Thread.sleep(3000); assertTrue(queryStatusChecker.isQueryRunning()); - verify(athena, times(2)).getQueryExecution(any()); + verify(athena, times(2)).getQueryExecution(any(GetQueryExecutionRequest.class)); } } } diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java index 0c9de56318..0abc45c3ec 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java @@ -25,11 +25,6 @@ 
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Schema; @@ -44,6 +39,13 @@ import org.mockito.stubbing.Answer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -51,8 +53,6 @@ import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; @@ -71,7 +71,7 @@ public class S3BlockSpillerTest private String splitId = "splitId"; @Mock - private AmazonS3 mockS3; + private S3Client mockS3; private S3BlockSpiller blockWriter; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -130,18 +130,20 @@ public void spillTest() final ByteHolder byteHolder = new ByteHolder(); - ArgumentCaptor argument = ArgumentCaptor.forClass(PutObjectRequest.class); + ArgumentCaptor requestArgument = ArgumentCaptor.forClass(PutObjectRequest.class); + ArgumentCaptor bodyArgument = ArgumentCaptor.forClass(RequestBody.class); - when(mockS3.putObject(any())) + when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + PutObjectResponse response = PutObjectResponse.builder().build(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); - return mock(PutObjectResult.class); + return response; } }); @@ -151,9 +153,9 @@ public Object answer(InvocationOnMock invocationOnMock) assertEquals(bucket, ((S3SpillLocation) blockLocation).getBucket()); assertEquals(prefix + "/" + requestId + "/" + splitId + ".0", ((S3SpillLocation) blockLocation).getKey()); } - verify(mockS3, times(1)).putObject(argument.capture()); - assertEquals(argument.getValue().getBucketName(), bucket); - assertEquals(argument.getValue().getKey(), prefix + "/" + requestId + "/" + splitId + ".0"); + verify(mockS3, times(1)).putObject(requestArgument.capture(), bodyArgument.capture()); + assertEquals(requestArgument.getValue().bucket(), bucket); + assertEquals(requestArgument.getValue().key(), prefix + "/" + requestId + "/" + splitId + ".0"); SpillLocation blockLocation2 = blockWriter.write(expected); 
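The spill-test changes above follow from two v2 S3 API differences: putObject now takes the object metadata (PutObjectRequest) and the payload (RequestBody) as separate arguments, which is why the test captures two arguments instead of one, and getObject returns a ResponseInputStream<GetObjectResponse> directly rather than an S3Object with a separate content stream. A rough sketch of the production-side calls implied by these stubs; the bucket and key literals are illustrative only, not taken from the patch:

import java.io.IOException;

import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;

public final class SpillIoSketch
{
    public static byte[] roundTrip(S3Client s3, byte[] block) throws IOException
    {
        // Write: request metadata and payload travel as separate arguments in v2.
        s3.putObject(PutObjectRequest.builder()
                        .bucket("spill-bucket")
                        .key("prefix/query/split.0")
                        .build(),
                RequestBody.fromBytes(block));

        // Read: the response stream wraps both the payload and the GetObjectResponse.
        try (ResponseInputStream<GetObjectResponse> in = s3.getObject(GetObjectRequest.builder()
                .bucket("spill-bucket")
                .key("prefix/query/split.0")
                .build())) {
            return in.readAllBytes();
        }
    }
}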
@@ -162,25 +164,23 @@ public Object answer(InvocationOnMock invocationOnMock) assertEquals(prefix + "/" + requestId + "/" + splitId + ".1", ((S3SpillLocation) blockLocation2).getKey()); } - verify(mockS3, times(2)).putObject(argument.capture()); - assertEquals(argument.getValue().getBucketName(), bucket); - assertEquals(argument.getValue().getKey(), prefix + "/" + requestId + "/" + splitId + ".1"); + verify(mockS3, times(2)).putObject(requestArgument.capture(), bodyArgument.capture()); + assertEquals(requestArgument.getValue().bucket(), bucket); + assertEquals(requestArgument.getValue().key(), prefix + "/" + requestId + "/" + splitId + ".1"); verifyNoMoreInteractions(mockS3); reset(mockS3); logger.info("spillTest: Starting read test."); - when(mockS3.getObject(eq(bucket), eq(prefix + "/" + requestId + "/" + splitId + ".1"))) + when(mockS3.getObject(any(GetObjectRequest.class))) .thenAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - S3Object mockObject = mock(S3Object.class); - when(mockObject.getObjectContent()).thenReturn(new S3ObjectInputStream(new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); } }); @@ -189,7 +189,7 @@ public Object answer(InvocationOnMock invocationOnMock) assertEquals(expected, block); verify(mockS3, times(1)) - .getObject(eq(bucket), eq(prefix + "/" + requestId + "/" + splitId + ".1")); + .getObject(any(GetObjectRequest.class)); verifyNoMoreInteractions(mockS3); diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifierTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifierTest.java index a5077b8b14..b78c58c4eb 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifierTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifierTest.java @@ -1,7 +1,3 @@ -package com.amazonaws.athena.connector.lambda.domain.spill; - -import com.amazonaws.AmazonServiceException; - /*- * #%L * Amazon Athena Query Federation SDK @@ -22,19 +18,22 @@ * #L% */ -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.Bucket; -import com.amazonaws.services.s3.model.HeadBucketRequest; -import com.amazonaws.services.s3.model.HeadBucketResult; +package com.amazonaws.athena.connector.lambda.domain.spill; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.Spy; -import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.Bucket; +import software.amazon.awssdk.services.s3.model.HeadBucketRequest; +import software.amazon.awssdk.services.s3.model.HeadBucketResponse; +import software.amazon.awssdk.services.s3.model.ListBucketsResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; import java.util.ArrayList; import java.util.Arrays; @@ -43,9 +42,6 @@ import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doReturn; 
-import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -66,7 +62,7 @@ public void setup() logger.info("setUpBefore - enter"); bucketNames = Arrays.asList("bucket1", "bucket2", "bucket3"); - AmazonS3 mockS3 = createMockS3(bucketNames); + S3Client mockS3 = createMockS3(bucketNames); spyVerifier = spy(new SpillLocationVerifier(mockS3)); logger.info("setUpBefore - exit"); @@ -146,23 +142,21 @@ public void checkBucketAuthZFail() logger.info("checkBucketAuthZFail - exit"); } - private AmazonS3 createMockS3(List buckets) + private S3Client createMockS3(List buckets) { - AmazonS3 s3mock = mock(AmazonS3.class); + S3Client s3mock = mock(S3Client.class); when(s3mock.headBucket(any(HeadBucketRequest.class))) - .thenAnswer((Answer) invocationOnMock -> { - String bucketName = ((HeadBucketRequest) invocationOnMock.getArguments()[0]).getBucketName(); + .thenAnswer((Answer) invocationOnMock -> { + String bucketName = ((HeadBucketRequest) invocationOnMock.getArguments()[0]).bucket(); if (buckets.contains(bucketName)) { return null; } - AmazonServiceException exception; + AwsServiceException exception; if (bucketName.equals("forbidden")) { - exception = new AmazonServiceException("Forbidden"); - exception.setStatusCode(403); + exception = S3Exception.builder().statusCode(403).message("Forbidden").build(); } else { - exception = new AmazonServiceException("Not Found"); - exception.setStatusCode(404); + exception = S3Exception.builder().statusCode(404).message("Not Found").build(); } throw exception; }); diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/handlers/GlueMetadataHandlerTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/handlers/GlueMetadataHandlerTest.java index 4c8409877c..010cd57745 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/handlers/GlueMetadataHandlerTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/handlers/GlueMetadataHandlerTest.java @@ -41,20 +41,22 @@ import com.amazonaws.athena.connector.lambda.metadata.MetadataRequest; import com.amazonaws.athena.connector.lambda.security.IdentityUtil; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.Database; -import com.amazonaws.services.glue.model.GetDatabasesRequest; -import com.amazonaws.services.glue.model.GetDatabasesResult; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.GetTablesRequest; -import com.amazonaws.services.glue.model.GetTablesResult; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.glue.model.Table; import com.amazonaws.services.lambda.runtime.Context; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.Database; +import software.amazon.awssdk.services.glue.model.GetDatabaseResponse; +import software.amazon.awssdk.services.glue.model.GetDatabasesRequest; +import 
software.amazon.awssdk.services.glue.model.GetDatabasesResponse; +import software.amazon.awssdk.services.glue.model.GetTablesRequest; +import software.amazon.awssdk.services.glue.model.GetTablesResponse; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.glue.paginators.GetDatabasesIterable; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; + import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.After; @@ -114,11 +116,11 @@ public class GlueMetadataHandlerTest // consider the order of the tables deterministic (i.e. pagination will work irrespective of the order that the // tables are returned from the source). private final List
<Table> unPaginatedTables = new ImmutableList.Builder<Table>
() - .add(new Table().withName("table3")) - .add(new Table().withName("table2")) - .add(new Table().withName("table5")) - .add(new Table().withName("table4")) - .add(new Table().withName("table1")) + .add(Table.builder().name("table3").build()) + .add(Table.builder().name("table2").build()) + .add(Table.builder().name("table5").build()) + .add(Table.builder().name("table4").build()) + .add(Table.builder().name("table1").build()) .build(); // The following response is expected be returned from doListTables when the pagination pageSize is greater than @@ -136,7 +138,7 @@ public class GlueMetadataHandlerTest public TestName testName = new TestName(); @Mock - private AWSGlue mockGlue; + private GlueClient mockGlue; @Mock private Context mockContext; @@ -149,8 +151,8 @@ public void setUp() handler = new GlueMetadataHandler(mockGlue, new LocalKeyFactory(), - mock(AWSSecretsManager.class), - mock(AmazonAthena.class), + mock(SecretsManagerClient.class), + mock(AthenaClient.class), "glue-test", "spill-bucket", "spill-prefix", @@ -187,33 +189,39 @@ public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAlloca .thenAnswer((InvocationOnMock invocationOnMock) -> { GetTablesRequest request = (GetTablesRequest) invocationOnMock.getArguments()[0]; - String nextToken = request.getNextToken(); - int pageSize = request.getMaxResults() == null ? UNLIMITED_PAGE_SIZE_VALUE : request.getMaxResults(); - assertEquals(accountId, request.getCatalogId()); - assertEquals(schema, request.getDatabaseName()); - GetTablesResult mockResult = mock(GetTablesResult.class); + String nextToken = request.nextToken(); + int pageSize = request.maxResults() == null ? UNLIMITED_PAGE_SIZE_VALUE : request.maxResults(); + assertEquals(accountId, request.catalogId()); + assertEquals(schema, request.databaseName()); + GetTablesResponse response; if (pageSize == UNLIMITED_PAGE_SIZE_VALUE) { // Simulate full list of tables returned from Glue. - when(mockResult.getTableList()).thenReturn(unPaginatedTables); - when(mockResult.getNextToken()).thenReturn(null); + response = GetTablesResponse.builder() + .tableList(unPaginatedTables) + .nextToken(null) + .build(); } else { // Simulate paginated list of tables returned from Glue. List
paginatedTables = unPaginatedTables.stream() - .sorted(Comparator.comparing(Table::getName)) - .filter(table -> nextToken == null || table.getName().compareTo(nextToken) >= 0) + .sorted(Comparator.comparing(Table::name)) + .filter(table -> nextToken == null || table.name().compareTo(nextToken) >= 0) .limit(pageSize + 1) .collect(Collectors.toList()); if (paginatedTables.size() > pageSize) { - when(mockResult.getNextToken()).thenReturn(paginatedTables.get(pageSize).getName()); - when(mockResult.getTableList()).thenReturn(paginatedTables.subList(0, pageSize)); + response = GetTablesResponse.builder() + .tableList(paginatedTables.subList(0, pageSize)) + .nextToken(paginatedTables.get(pageSize).name()) + .build(); } else { - when(mockResult.getNextToken()).thenReturn(null); - when(mockResult.getTableList()).thenReturn(paginatedTables); + response = GetTablesResponse.builder() + .tableList(paginatedTables) + .nextToken(null) + .build(); } } - return mockResult; + return response; }); } @@ -222,6 +230,7 @@ public void tearDown() throws Exception { allocator.close(); + mockGlue.close(); logger.info("{}: exit ", testName.getMethodName()); } @@ -230,25 +239,18 @@ public void doListSchemaNames() throws Exception { List databases = new ArrayList<>(); - databases.add(new Database().withName("db1")); - databases.add(new Database().withName("db2")); + databases.add(Database.builder().name("db1").build()); + databases.add(Database.builder().name("db2").build()); - when(mockGlue.getDatabases(nullable(GetDatabasesRequest.class))) + when(mockGlue.getDatabasesPaginator(nullable(GetDatabasesRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { GetDatabasesRequest request = (GetDatabasesRequest) invocationOnMock.getArguments()[0]; - assertEquals(accountId, request.getCatalogId()); - GetDatabasesResult mockResult = mock(GetDatabasesResult.class); - if (request.getNextToken() == null) { - when(mockResult.getDatabaseList()).thenReturn(databases); - when(mockResult.getNextToken()).thenReturn("next"); - } - else { - //only return real info on 1st call - when(mockResult.getDatabaseList()).thenReturn(new ArrayList<>()); - when(mockResult.getNextToken()).thenReturn(null); - } - return mockResult; + assertEquals(accountId, request.catalogId()); + GetDatabasesIterable mockIterable = mock(GetDatabasesIterable.class); + GetDatabasesResponse response = GetDatabasesResponse.builder().databaseList(databases).build(); + when(mockIterable.stream()).thenReturn(Collections.singletonList(response).stream()); + return mockIterable; }); ListSchemasRequest req = new ListSchemasRequest(IdentityUtil.fakeIdentity(), queryId, catalog); @@ -256,10 +258,8 @@ public void doListSchemaNames() logger.info("doListSchemas - {}", res.getSchemas()); - assertEquals(databases.stream().map(next -> next.getName()).collect(Collectors.toList()), + assertEquals(databases.stream().map(next -> next.name()).collect(Collectors.toList()), new ArrayList<>(res.getSchemas())); - - verify(mockGlue, times(2)).getDatabases(nullable(GetDatabasesRequest.class)); } @Test @@ -330,38 +330,32 @@ public void doGetTable() expectedParams.put(DATETIME_FORMAT_MAPPING_PROPERTY, "col2=someformat2, col1=someformat1 "); List columns = new ArrayList<>(); - columns.add(new Column().withName("col1").withType("int").withComment("comment")); - columns.add(new Column().withName("col2").withType("bigint").withComment("comment")); - columns.add(new Column().withName("col3").withType("string").withComment("comment")); - columns.add(new 
Column().withName("col4").withType("timestamp").withComment("comment")); - columns.add(new Column().withName("col5").withType("date").withComment("comment")); - columns.add(new Column().withName("col6").withType("timestamptz").withComment("comment")); - columns.add(new Column().withName("col7").withType("timestamptz").withComment("comment")); - - List partitionKeys = new ArrayList<>(); - columns.add(new Column().withName("partition_col1").withType("int").withComment("comment")); - - Table mockTable = mock(Table.class); - StorageDescriptor mockSd = mock(StorageDescriptor.class); - - Mockito.lenient().when(mockTable.getName()).thenReturn(table); - when(mockTable.getStorageDescriptor()).thenReturn(mockSd); - when(mockTable.getParameters()).thenReturn(expectedParams); - when(mockSd.getColumns()).thenReturn(columns); - - when(mockGlue.getTable(nullable(com.amazonaws.services.glue.model.GetTableRequest.class))) + columns.add(Column.builder().name("col1").type("int").comment("comment").build()); + columns.add(Column.builder().name("col2").type("bigint").comment("comment").build()); + columns.add(Column.builder().name("col3").type("string").comment("comment").build()); + columns.add(Column.builder().name("col4").type("timestamp").comment("comment").build()); + columns.add(Column.builder().name("col5").type("date").comment("comment").build()); + columns.add(Column.builder().name("col6").type("timestamptz").comment("comment").build()); + columns.add(Column.builder().name("col7").type("timestamptz").comment("comment").build()); + columns.add(Column.builder().name("partition_col1").type("int").comment("comment").build()); + + StorageDescriptor sd = StorageDescriptor.builder().columns(columns).build(); + Table returnTable = Table.builder().storageDescriptor(sd).name(table).parameters(expectedParams).build(); + + when(mockGlue.getTable(nullable(software.amazon.awssdk.services.glue.model.GetTableRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - com.amazonaws.services.glue.model.GetTableRequest request = - (com.amazonaws.services.glue.model.GetTableRequest) invocationOnMock.getArguments()[0]; + software.amazon.awssdk.services.glue.model.GetTableRequest request = + (software.amazon.awssdk.services.glue.model.GetTableRequest) invocationOnMock.getArguments()[0]; - assertEquals(accountId, request.getCatalogId()); - assertEquals(schema, request.getDatabaseName()); - assertEquals(table, request.getName()); + assertEquals(accountId, request.catalogId()); + assertEquals(schema, request.databaseName()); + assertEquals(table, request.name()); - GetTableResult mockResult = mock(GetTableResult.class); - when(mockResult.getTable()).thenReturn(mockTable); - return mockResult; + software.amazon.awssdk.services.glue.model.GetTableResponse tableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder() + .table(returnTable) + .build(); + return tableResponse; }); GetTableRequest req = new GetTableRequest(IdentityUtil.fakeIdentity(), queryId, catalog, new TableName(schema, table), Collections.emptyMap()); @@ -401,8 +395,13 @@ public void populateSourceTableFromLocation() { Map params = new HashMap<>(); List partitions = Arrays.asList("aws", "aws-cn", "aws-us-gov"); for (String partition : partitions) { - StorageDescriptor storageDescriptor = new StorageDescriptor().withLocation(String.format("arn:%s:dynamodb:us-east-1:012345678910:table/My-Table", partition)); - Table table = new Table().withParameters(params).withStorageDescriptor(storageDescriptor); + StorageDescriptor 
storageDescriptor = StorageDescriptor.builder() + .location(String.format("arn:%s:dynamodb:us-east-1:012345678910:table/My-Table", partition)) + .build(); + Table table = Table.builder() + .parameters(params) + .storageDescriptor(storageDescriptor) + .build(); SchemaBuilder schemaBuilder = new SchemaBuilder(); populateSourceTableNameIfAvailable(table, schemaBuilder); Schema schema = schemaBuilder.build(); @@ -422,29 +421,25 @@ public void doGetTableEmptyComment() expectedParams.put("col1", "col1"); List columns = new ArrayList<>(); - columns.add(new Column().withName("col1").withType("int").withComment(" ")); - - Table mockTable = mock(Table.class); - StorageDescriptor mockSd = mock(StorageDescriptor.class); + columns.add(Column.builder().name("col1").type("int").comment(" ").build()); - Mockito.lenient().when(mockTable.getName()).thenReturn(table); - when(mockTable.getStorageDescriptor()).thenReturn(mockSd); - when(mockTable.getParameters()).thenReturn(expectedParams); - when(mockSd.getColumns()).thenReturn(columns); + StorageDescriptor sd = StorageDescriptor.builder().columns(columns).build(); + Table resultTable = Table.builder().storageDescriptor(sd).parameters(expectedParams).build(); - when(mockGlue.getTable(nullable(com.amazonaws.services.glue.model.GetTableRequest.class))) + when(mockGlue.getTable(nullable(software.amazon.awssdk.services.glue.model.GetTableRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - com.amazonaws.services.glue.model.GetTableRequest request = - (com.amazonaws.services.glue.model.GetTableRequest) invocationOnMock.getArguments()[0]; + software.amazon.awssdk.services.glue.model.GetTableRequest request = + (software.amazon.awssdk.services.glue.model.GetTableRequest) invocationOnMock.getArguments()[0]; - assertEquals(accountId, request.getCatalogId()); - assertEquals(schema, request.getDatabaseName()); - assertEquals(table, request.getName()); + assertEquals(accountId, request.catalogId()); + assertEquals(schema, request.databaseName()); + assertEquals(table, request.name()); - GetTableResult mockResult = mock(GetTableResult.class); - when(mockResult.getTable()).thenReturn(mockTable); - return mockResult; + software.amazon.awssdk.services.glue.model.GetTableResponse response = software.amazon.awssdk.services.glue.model.GetTableResponse.builder() + .table(resultTable) + .build(); + return response; }); GetTableRequest req = new GetTableRequest(IdentityUtil.fakeIdentity(), queryId, catalog, new TableName(schema, table), Collections.emptyMap()); diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/security/CacheableSecretsManagerTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/security/CacheableSecretsManagerTest.java index 181945f55f..3749e6edd2 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/security/CacheableSecretsManagerTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/security/CacheableSecretsManagerTest.java @@ -20,9 +20,10 @@ * #L% */ -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; + 
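Two properties of SDK v2 make the test rewrites below mostly mechanical: the clients (SecretsManagerClient, S3Client, AthenaClient) are plain interfaces that Mockito can mock directly, and the generated model classes are immutable value objects with equals/hashCode, which is what lets stubs such as when(athena.getQueryExecution(request)) earlier in this patch match on a concrete request instance. A hedged sketch of the stubbing pattern the following hunks adopt; the echoed secret value mirrors the evictionTest stub below:

import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;

public final class SecretsStubSketch
{
    public static SecretsManagerClient stub()
    {
        SecretsManagerClient mockSecretsManager = mock(SecretsManagerClient.class);
        when(mockSecretsManager.getSecretValue(nullable(GetSecretValueRequest.class)))
                .thenAnswer(invocation -> {
                    GetSecretValueRequest request = invocation.getArgument(0, GetSecretValueRequest.class);
                    // Build the response the same way production code builds requests.
                    return GetSecretValueResponse.builder()
                            .secretString(request.secretId() + "_value")
                            .build();
                });
        return mockSecretsManager;
    }
}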
import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -39,14 +40,14 @@ public class CacheableSecretsManagerTest { - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; private CachableSecretsManager cachableSecretsManager; @Before public void setup() { - mockSecretsManager = mock(AWSSecretsManager.class); + mockSecretsManager = mock(SecretsManagerClient.class); cachableSecretsManager = new CachableSecretsManager(mockSecretsManager); } @@ -67,8 +68,8 @@ public void expirationTest() when(mockSecretsManager.getSecretValue(nullable(GetSecretValueRequest.class))) .thenAnswer((InvocationOnMock invocation) -> { GetSecretValueRequest request = invocation.getArgument(0, GetSecretValueRequest.class); - if (request.getSecretId().equalsIgnoreCase("test")) { - return new GetSecretValueResult().withSecretString("value2"); + if (request.secretId().equalsIgnoreCase("test")) { + return GetSecretValueResponse.builder().secretString("value2").build(); } throw new RuntimeException(); }); @@ -86,7 +87,7 @@ public void evictionTest() when(mockSecretsManager.getSecretValue(nullable(GetSecretValueRequest.class))) .thenAnswer((InvocationOnMock invocation) -> { GetSecretValueRequest request = invocation.getArgument(0, GetSecretValueRequest.class); - return new GetSecretValueResult().withSecretString(request.getSecretId() + "_value"); + return GetSecretValueResponse.builder().secretString(request.secretId() + "_value").build(); }); assertEquals("test_value", cachableSecretsManager.getSecret("test")); @@ -101,11 +102,11 @@ public void resolveSecrets() when(mockSecretsManager.getSecretValue(nullable(GetSecretValueRequest.class))) .thenAnswer((InvocationOnMock invocation) -> { GetSecretValueRequest request = invocation.getArgument(0, GetSecretValueRequest.class); - String result = request.getSecretId(); + String result = request.secretId(); if (result.equalsIgnoreCase("unknown")) { throw new RuntimeException("Unknown secret!"); } - return new GetSecretValueResult().withSecretString(result); + return GetSecretValueResponse.builder().secretString(result).build(); }); String oneSecret = "${OneSecret}"; diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/serde/v2/LambdaFunctionExceptionSerDeTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/serde/v2/LambdaFunctionExceptionSerDeTest.java index bb5c75043d..a9972e9ca1 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/serde/v2/LambdaFunctionExceptionSerDeTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/serde/v2/LambdaFunctionExceptionSerDeTest.java @@ -20,23 +20,21 @@ package com.amazonaws.athena.connector.lambda.serde.v2; import com.amazonaws.athena.connector.lambda.serde.TypedSerDeTest; -import com.amazonaws.services.lambda.invoke.LambdaFunctionException; import com.google.common.collect.ImmutableList; import org.junit.Before; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.lambda.model.LambdaException; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.Constructor; -import java.util.List; import static org.junit.Assert.assertEquals; public class LambdaFunctionExceptionSerDeTest - extends TypedSerDeTest<LambdaFunctionException> + extends TypedSerDeTest<LambdaException> { private static final Logger logger =
LoggerFactory.getLogger(LambdaFunctionExceptionSerDeTest.class); @@ -47,14 +45,10 @@ public void beforeTest() String errorType = "com.amazonaws.services.dynamodbv2.model.ResourceNotFoundException"; String errorMessage = "Requested resource not found (Service: AmazonDynamoDBv2; Status Code: 400; Error Code: ResourceNotFoundException; Request ID: RIB6NOH4BNMAK6KQG88R5VE583VV4KQNSO5AEMVJF66Q9ASUAAJG)"; - ImmutableList<String> stackTrace = ImmutableList.of( - "com.amazonaws.http.AmazonHttpClient$RequestExecutor.handleErrorResponse(AmazonHttpClient.java:1701)", - "com.amazonaws.http.AmazonHttpClient$RequestExecutor.executeOneRequest(AmazonHttpClient.java:1356)"); - Constructor<LambdaFunctionException> constructor = LambdaFunctionException.class.getDeclaredConstructor( - String.class, String.class, LambdaFunctionException.class, List.class); - constructor.setAccessible(true); - expected = constructor.newInstance(errorType, errorMessage, null, stackTrace); - + ImmutableList<ImmutableList<String>> stackTrace = ImmutableList.of( + ImmutableList.of("com.amazonaws.http.AmazonHttpClient$RequestExecutor.handleErrorResponse(AmazonHttpClient.java:1701)"), + ImmutableList.of("com.amazonaws.http.AmazonHttpClient$RequestExecutor.executeOneRequest(AmazonHttpClient.java:1356)")); + expected = (LambdaException) LambdaException.builder().message(errorMessage + ". Stack trace: " + stackTrace + "\nErrorType: " + errorType).build(); String expectedSerDeFile = utils.getResourceOrFail("serde/v2", "LambdaFunctionException.json"); expectedSerDeText = utils.readAllAsString(expectedSerDeFile).trim(); } @@ -73,11 +67,10 @@ public void deserialize() logger.info("deserialize: enter"); InputStream input = new ByteArrayInputStream(expectedSerDeText.getBytes()); - LambdaFunctionException actual = mapper.readValue(input, LambdaFunctionException.class); + LambdaException actual = mapper.readValue(input, LambdaException.class); logger.info("deserialize: deserialized[{}]", actual.toString()); - assertEquals(expected.getType(), actual.getType()); assertEquals(expected.getMessage(), actual.getMessage()); assertEquals(expected.getCause(), actual.getCause()); expected.fillInStackTrace(); diff --git a/athena-gcs/Dockerfile b/athena-gcs/Dockerfile new file mode 100644 index 0000000000..04614fe43b --- /dev/null +++ b/athena-gcs/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-gcs.zip ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-gcs.zip + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.gcs.GcsCompositeHandler" ] \ No newline at end of file diff --git a/athena-gcs/athena-gcs.yaml b/athena-gcs/athena-gcs.yaml index 3c7e3b3f1b..fa97bc7f86 100644 --- a/athena-gcs/athena-gcs.yaml +++ b/athena-gcs/athena-gcs.yaml @@ -59,10 +59,9 @@ Resources: spill_prefix: !Ref SpillPrefix secret_manager_gcp_creds_name: !Ref GCSSecretName FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.gcs.GcsCompositeHandler" - CodeUri: "./target/athena-gcs.zip" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-gcs:2022.47.1' Description: "Amazon Athena GCS Connector" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git
a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsMetadataHandler.java b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsMetadataHandler.java index 231b5763a3..950c2c5745 100644 --- a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsMetadataHandler.java +++ b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsMetadataHandler.java @@ -39,12 +39,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connectors.gcs.common.PartitionUtil; import com.amazonaws.athena.connectors.gcs.storage.StorageMetadata; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.Database; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.VisibleForTesting; @@ -52,6 +46,12 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.Database; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.net.URI; @@ -80,11 +80,11 @@ public class GcsMetadataHandler */ private static final String SOURCE_TYPE = "gcs"; private static final CharSequence GCS_FLAG = "google-cloud-storage-flag"; - private static final DatabaseFilter DB_FILTER = (Database database) -> (database.getLocationUri() != null && database.getLocationUri().contains(GCS_FLAG)); + private static final DatabaseFilter DB_FILTER = (Database database) -> (database.locationUri() != null && database.locationUri().contains(GCS_FLAG)); // used to filter out Glue tables which lack indications of being used for GCS. - private static final TableFilter TABLE_FILTER = (Table table) -> table.getStorageDescriptor().getLocation().startsWith(GCS_LOCATION_PREFIX); + private static final TableFilter TABLE_FILTER = (Table table) -> table.storageDescriptor().location().startsWith(GCS_LOCATION_PREFIX); private final StorageMetadata datasource; - private final AWSGlue glueClient; + private final GlueClient glueClient; private final BufferAllocator allocator; public GcsMetadataHandler(BufferAllocator allocator, java.util.Map configOptions) throws IOException @@ -100,11 +100,11 @@ public GcsMetadataHandler(BufferAllocator allocator, java.util.Map configOptions) throws IOException { super(glueClient, keyFactory, awsSecretsManager, athena, SOURCE_TYPE, spillBucket, spillPrefix, configOptions); @@ -174,9 +174,9 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques //fetch schema from dataset api Schema schema = datasource.buildTableSchema(table, allocator); Map columnNameMapping = getColumnNameMapping(table); - List partitionKeys = table.getPartitionKeys() == null ? com.google.common.collect.ImmutableList.of() : table.getPartitionKeys(); + List partitionKeys = table.partitionKeys() == null ? 
com.google.common.collect.ImmutableList.of() : table.partitionKeys(); Set partitionCols = partitionKeys.stream() - .map(next -> columnNameMapping.getOrDefault(next.getName(), next.getName())).collect(Collectors.toSet()); + .map(next -> columnNameMapping.getOrDefault(next.name(), next.name())).collect(Collectors.toSet()); return new GetTableResponse(request.getCatalogName(), request.getTableName(), schema, partitionCols); } } @@ -246,14 +246,14 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest //getting storage file list List fileList = datasource.getStorageSplits(locationUri); SpillLocation spillLocation = makeSpillLocation(request); - LOGGER.info("Split list for {}.{} is \n{}", table.getDatabaseName(), table.getName(), fileList); + LOGGER.info("Split list for {}.{} is \n{}", table.databaseName(), table.name(), fileList); //creating splits based folder String storageSplitJson = new ObjectMapper().writeValueAsString(fileList); LOGGER.info("MetadataHandler=GcsMetadataHandler|Method=doGetSplits|Message=StorageSplit JSON\n{}", storageSplitJson); Split.Builder splitBuilder = Split.newBuilder(spillLocation, makeEncryptionKey()) - .add(FILE_FORMAT, table.getParameters().get(CLASSIFICATION_GLUE_TABLE_PARAM)) + .add(FILE_FORMAT, table.parameters().get(CLASSIFICATION_GLUE_TABLE_PARAM)) .add(STORAGE_SPLIT_JSON, storageSplitJson); // set partition column name and value in split diff --git a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandler.java b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandler.java index 2d59e6c7ec..4c86fe7e13 100644 --- a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandler.java +++ b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandler.java @@ -28,12 +28,6 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.arrow.dataset.file.FileFormat; @@ -53,6 +47,9 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.nio.charset.StandardCharsets; import java.util.List; @@ -77,9 +74,9 @@ public class GcsRecordHandler public GcsRecordHandler(BufferAllocator allocator, java.util.Map configOptions) { - this(AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), configOptions); + this(S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), configOptions); this.allocator = allocator; } @@ -91,7 +88,7 @@ public GcsRecordHandler(BufferAllocator allocator, java.util.Map * @param amazonAthena An instance of AmazonAthena */ @VisibleForTesting - protected GcsRecordHandler(AmazonS3 
amazonS3, AWSSecretsManager secretsManager, AmazonAthena amazonAthena, java.util.Map configOptions) + protected GcsRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, java.util.Map configOptions) { super(amazonS3, secretsManager, amazonAthena, SOURCE_TYPE, configOptions); this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build(); diff --git a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsThrottlingExceptionFilter.java b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsThrottlingExceptionFilter.java index cc71cc64f2..b26c36677d 100644 --- a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsThrottlingExceptionFilter.java +++ b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsThrottlingExceptionFilter.java @@ -20,7 +20,7 @@ package com.amazonaws.athena.connectors.gcs; import com.amazonaws.athena.connector.lambda.ThrottlingInvoker; -import com.amazonaws.services.athena.model.AmazonAthenaException; +import software.amazon.awssdk.services.athena.model.AthenaException; public class GcsThrottlingExceptionFilter implements ThrottlingInvoker.ExceptionFilter { @@ -29,7 +29,7 @@ public class GcsThrottlingExceptionFilter implements ThrottlingInvoker.Exception @Override public boolean isMatch(Exception ex) { - return (ex instanceof AmazonAthenaException && ex.getMessage().contains("Rate exceeded")) + return (ex instanceof AthenaException && ex.getMessage().contains("Rate exceeded")) || ex.getMessage().contains("Too Many Requests"); } } diff --git a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsUtil.java b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsUtil.java index 7a273585e2..aea6586626 100644 --- a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsUtil.java +++ b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsUtil.java @@ -22,16 +22,15 @@ import com.amazonaws.athena.connector.lambda.data.DateTimeFormatterUtil; import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.GetTableRequest; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.sun.jna.platform.unix.LibC; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.types.pojo.ArrowType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.GetTableRequest; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; @@ -100,7 +99,7 @@ public static void installCaCertificate() throws IOException, NoSuchAlgorithmExc */ public static void installGoogleCredentialsJsonFile(java.util.Map configOptions) throws IOException { - CachableSecretsManager secretsManager = new CachableSecretsManager(AWSSecretsManagerClientBuilder.defaultClient()); + CachableSecretsManager secretsManager = new CachableSecretsManager(SecretsManagerClient.create()); String gcsCredentialsJsonString = secretsManager.getSecret(configOptions.get(GCS_SECRET_KEY_ENV_VAR)); File destination = new 
File(GOOGLE_SERVICE_ACCOUNT_JSON_TEMP_FILE_LOCATION_VALUE); boolean destinationDirExists = new File(destination.getParent()).mkdirs(); @@ -143,14 +142,15 @@ public static String createUri(String path) * @param awsGlue AWS Glue client * @return Table object */ - public static Table getGlueTable(TableName tableName, AWSGlue awsGlue) + public static Table getGlueTable(TableName tableName, GlueClient awsGlue) { - GetTableRequest getTableRequest = new GetTableRequest(); - getTableRequest.setDatabaseName(tableName.getSchemaName()); - getTableRequest.setName(tableName.getTableName()); + GetTableRequest getTableRequest = GetTableRequest.builder() + .databaseName(tableName.getSchemaName()) + .name(tableName.getTableName()) + .build(); - GetTableResult result = awsGlue.getTable(getTableRequest); - return result.getTable(); + software.amazon.awssdk.services.glue.model.GetTableResponse response = awsGlue.getTable(getTableRequest); + return response.table(); } // The value returned here is going to block.offerValue, which eventually invokes BlockUtils.setValue() diff --git a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/common/PartitionUtil.java b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/common/PartitionUtil.java index efd0bcf9dd..6de7deffc5 100644 --- a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/common/PartitionUtil.java +++ b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/common/PartitionUtil.java @@ -19,12 +19,12 @@ */ package com.amazonaws.athena.connectors.gcs.common; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.Table; import org.apache.arrow.vector.FieldVector; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.Table; import java.net.URI; import java.net.URISyntaxException; @@ -72,9 +72,9 @@ private PartitionUtil() */ public static Map getPartitionColumnData(Table table, String partitionFolder) { - List partitionKeys = table.getPartitionKeys() == null ? com.google.common.collect.ImmutableList.of() : table.getPartitionKeys(); + List partitionKeys = table.partitionKeys() == null ? 
com.google.common.collect.ImmutableList.of() : table.partitionKeys(); return getRegExExpression(table).map(folderNameRegEx -> - getPartitionColumnData(table.getParameters().get(PARTITION_PATTERN_KEY), partitionFolder, folderNameRegEx, partitionKeys)) + getPartitionColumnData(table.parameters().get(PARTITION_PATTERN_KEY), partitionFolder, folderNameRegEx, partitionKeys)) .orElse(com.google.common.collect.ImmutableMap.of()); } @@ -93,7 +93,7 @@ protected static Map getPartitionColumnData(String partitionPatt Matcher partitionPatternMatcher = PARTITION_PATTERN.matcher(partitionPattern); Matcher partitionFolderMatcher = Pattern.compile(folderNameRegEx).matcher(partitionFolder); java.util.TreeSet partitionColumnsSet = partitionColumns.stream() - .map(c -> c.getName()) + .map(c -> c.name()) .collect(Collectors.toCollection(() -> new java.util.TreeSet<>(String.CASE_INSENSITIVE_ORDER))); while (partitionFolderMatcher.find()) { for (int j = 1; j <= partitionFolderMatcher.groupCount() && partitionPatternMatcher.find(); j++) { @@ -117,8 +117,8 @@ protected static Map getPartitionColumnData(String partitionPatt private static void validatePartitionColumnTypes(List columns) { for (Column column : columns) { - String columnType = column.getType().toLowerCase(); - LOGGER.info("validatePartitionColumnTypes - Field type of {} is {}", column.getName(), columnType); + String columnType = column.type().toLowerCase(); + LOGGER.info("validatePartitionColumnTypes - Field type of {} is {}", column.name(), columnType); switch (columnType) { case "string": case "varchar": @@ -140,9 +140,9 @@ private static void validatePartitionColumnTypes(List columns) */ protected static Optional getRegExExpression(Table table) { - List partitionColumns = table.getPartitionKeys() == null ? com.google.common.collect.ImmutableList.of() : table.getPartitionKeys(); + List partitionColumns = table.partitionKeys() == null ? 
com.google.common.collect.ImmutableList.of() : table.partitionKeys(); validatePartitionColumnTypes(partitionColumns); - String partitionPattern = table.getParameters().get(PARTITION_PATTERN_KEY); + String partitionPattern = table.parameters().get(PARTITION_PATTERN_KEY); // Check to see if there is a partition pattern configured for the Table by the user // if not, it returns empty value if (partitionPattern == null || StringUtils.isBlank(partitionPattern)) { @@ -170,8 +170,8 @@ protected static Optional getRegExExpression(Table table) public static URI getPartitionsFolderLocationUri(Table table, List fieldVectors, int readerPosition) throws URISyntaxException { String locationUri; - String tableLocation = table.getStorageDescriptor().getLocation(); - String partitionPattern = table.getParameters().get(PARTITION_PATTERN_KEY); + String tableLocation = table.storageDescriptor().location(); + String partitionPattern = table.parameters().get(PARTITION_PATTERN_KEY); if (null != partitionPattern) { for (FieldVector fieldVector : fieldVectors) { fieldVector.getReader().setPosition(readerPosition); diff --git a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/filter/FilterExpressionBuilder.java b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/filter/FilterExpressionBuilder.java index 9218895c35..13b7780ff6 100644 --- a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/filter/FilterExpressionBuilder.java +++ b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/filter/FilterExpressionBuilder.java @@ -21,9 +21,9 @@ import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; -import com.amazonaws.services.glue.model.Column; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.glue.model.Column; import java.util.List; import java.util.Map; @@ -50,8 +50,8 @@ public static Map>> getConstraintsForPartitionedCol { LOGGER.info("Constraint summaries: \n{}", constraints.getSummary()); return partitionColumns.stream().collect(Collectors.toMap( - column -> column.getName(), - column -> singleValuesStringSetFromValueSet(constraints.getSummary().get(column.getName())), + column -> column.name(), + column -> singleValuesStringSetFromValueSet(constraints.getSummary().get(column.name())), // Also we are forced to use Optional here because Collectors.toMap() doesn't allow null values to // be passed into the merge function (it asserts that the values are not null) // We shouldn't have duplicates but just merge the sets if we do. 
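Every Java hunk in this patch applies the same mechanical translation: mutable AWS SDK v1 model objects (new Column() with withX()/setX() mutators and getX() accessors) become immutable AWS SDK v2 builders with get-prefix-free accessors. A minimal, self-contained sketch of that pattern using the Glue model types touched above; the class name, table, and column values here are invented for illustration and are not part of the patch:

import java.util.Map;
import software.amazon.awssdk.services.glue.model.Column;
import software.amazon.awssdk.services.glue.model.StorageDescriptor;
import software.amazon.awssdk.services.glue.model.Table;

public final class GlueV2ModelSketch
{
    public static void main(String[] args)
    {
        // v2 models are created through builders and are immutable afterwards.
        Column year = Column.builder().name("year").type("varchar").build();
        StorageDescriptor sd = StorageDescriptor.builder()
                .location("gs://example-bucket/birthday/") // invented location
                .columns(year)
                .build();
        Table table = Table.builder()
                .name("birthday")            // invented table name
                .databaseName("example_db")  // invented database name
                .parameters(Map.of("classification", "parquet"))
                .storageDescriptor(sd)
                .partitionKeys(year)
                .build();

        // v1 table.getStorageDescriptor().getLocation() becomes:
        System.out.println(table.storageDescriptor().location());
        // v1 column.getName() becomes:
        System.out.println(table.partitionKeys().get(0).name());
    }
}

The same builder/accessor shape holds for requests and responses (GetTableRequest.builder()...build(), response.table()), which is why the handler changes in this patch are almost entirely one-for-one line substitutions.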
diff --git a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/storage/StorageMetadata.java b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/storage/StorageMetadata.java index dd48cba8d7..426654f2ca 100644 --- a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/storage/StorageMetadata.java +++ b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/storage/StorageMetadata.java @@ -26,9 +26,6 @@ import com.amazonaws.athena.connectors.gcs.GcsUtil; import com.amazonaws.athena.connectors.gcs.common.PartitionUtil; import com.amazonaws.athena.connectors.gcs.filter.FilterExpressionBuilder; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.Table; import com.google.api.gax.paging.Page; import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.storage.Blob; @@ -47,6 +44,9 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.Table; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -132,17 +132,17 @@ public List getStorageSplits(URI locationUri) * @return A list of {@link Map} instances * @throws URISyntaxException Throws if any occurs during parsing Uri */ - public List> getPartitionFolders(Schema schema, TableName tableInfo, Constraints constraints, AWSGlue awsGlue) + public List> getPartitionFolders(Schema schema, TableName tableInfo, Constraints constraints, GlueClient awsGlue) throws URISyntaxException { LOGGER.info("Getting partition folder(s) for table {}.{}", tableInfo.getSchemaName(), tableInfo.getTableName()); Table table = GcsUtil.getGlueTable(tableInfo, awsGlue); // Build expression only based on partition keys - List partitionColumns = table.getPartitionKeys() == null ? com.google.common.collect.ImmutableList.of() : table.getPartitionKeys(); + List partitionColumns = table.partitionKeys() == null ? 
com.google.common.collect.ImmutableList.of() : table.partitionKeys(); // getConstraintsForPartitionedColumns gives us a case insensitive mapping of column names to their value set Map>> columnValueConstraintMap = FilterExpressionBuilder.getConstraintsForPartitionedColumns(partitionColumns, constraints); LOGGER.info("columnValueConstraintMap for the request of {}.{} is \n{}", tableInfo.getSchemaName(), tableInfo.getTableName(), columnValueConstraintMap); - URI storageLocation = new URI(table.getStorageDescriptor().getLocation()); + URI storageLocation = new URI(table.storageDescriptor().location()); LOGGER.info("Listing object in location {} under the bucket {}", storageLocation.getAuthority(), storageLocation.getPath()); // Trim leading / String path = storageLocation.getPath().replaceFirst("^/", ""); @@ -226,9 +226,9 @@ private boolean partitionConstraintsSatisfied(Map partitionMap, public Schema buildTableSchema(Table table, BufferAllocator allocator) throws URISyntaxException { SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); - String locationUri = table.getStorageDescriptor().getLocation(); + String locationUri = table.storageDescriptor().location(); URI storageLocation = new URI(locationUri); - List fieldList = getFields(storageLocation.getAuthority(), storageLocation.getPath(), table.getParameters().get(CLASSIFICATION_GLUE_TABLE_PARAM), allocator); + List fieldList = getFields(storageLocation.getAuthority(), storageLocation.getPath(), table.parameters().get(CLASSIFICATION_GLUE_TABLE_PARAM), allocator); LOGGER.debug("Schema Fields\n{}", fieldList); for (Field field : fieldList) { diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsCompositeHandlerTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsCompositeHandlerTest.java index 1fa7d68b61..bc3d54c7dc 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsCompositeHandlerTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsCompositeHandlerTest.java @@ -19,11 +19,6 @@ */ package com.amazonaws.athena.connectors.gcs; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.ServiceAccountCredentials; import org.junit.jupiter.api.AfterAll; @@ -31,6 +26,10 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.mockito.Mockito; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.io.IOException; import java.security.KeyStoreException; @@ -46,23 +45,21 @@ @TestInstance(PER_CLASS) public class GcsCompositeHandlerTest extends GenericGcsTest { - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; private ServiceAccountCredentials serviceAccountCredentials; private GoogleCredentials credentials; @BeforeAll public void init() { super.initCommonMockedStatic(); - secretsManager = Mockito.mock(AWSSecretsManager.class); - 
mockedSecretManagerBuilder.when(AWSSecretsManagerClientBuilder::defaultClient).thenReturn(secretsManager); + secretsManager = Mockito.mock(SecretsManagerClient.class); + mockedSecretManagerBuilder.when(SecretsManagerClient::create).thenReturn(secretsManager); serviceAccountCredentials = Mockito.mock(ServiceAccountCredentials.class); mockedServiceAccountCredentials.when(() -> ServiceAccountCredentials.fromStream(Mockito.any())).thenReturn(serviceAccountCredentials); credentials = Mockito.mock(GoogleCredentials.class); mockedGoogleCredentials.when(() -> GoogleCredentials.fromStream(Mockito.any())).thenReturn(credentials); - AmazonS3ClientBuilder mockedAmazonS3Builder = Mockito.mock(AmazonS3ClientBuilder.class); - AmazonS3 mockedAmazonS3 = Mockito.mock(AmazonS3.class); - when(mockedAmazonS3Builder.build()).thenReturn(mockedAmazonS3); - mockedS3Builder.when(AmazonS3ClientBuilder::standard).thenReturn(mockedAmazonS3Builder); + S3Client mockedAmazonS3 = Mockito.mock(S3Client.class); + when(S3Client.create()).thenReturn(mockedAmazonS3); } @AfterAll @@ -74,8 +71,11 @@ public void cleanUp() { @Test public void testGcsCompositeHandler() throws IOException, CertificateEncodingException, NoSuchAlgorithmException, KeyStoreException { - GetSecretValueResult getSecretValueResult = new GetSecretValueResult().withVersionStages(com.google.common.collect.ImmutableList.of("v1")).withSecretString("{\"gcs_credential_keys\": \"test\"}"); - when(secretsManager.getSecretValue(Mockito.any())).thenReturn(getSecretValueResult); + GetSecretValueResponse getSecretValueResponse = GetSecretValueResponse.builder() + .versionStages(com.google.common.collect.ImmutableList.of("v1")) + .secretString("{\"gcs_credential_keys\": \"test\"}") + .build(); + when(secretsManager.getSecretValue(Mockito.isA(GetSecretValueRequest.class))).thenReturn(getSecretValueResponse); when(ServiceAccountCredentials.fromStream(Mockito.any())).thenReturn(serviceAccountCredentials); when(credentials.createScoped((Collection) any())).thenReturn(credentials); GcsCompositeHandler gcsCompositeHandler = new GcsCompositeHandler(); diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsExceptionFilterTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsExceptionFilterTest.java index 9bbb349420..2dbde1f7c8 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsExceptionFilterTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsExceptionFilterTest.java @@ -19,11 +19,10 @@ */ package com.amazonaws.athena.connectors.gcs; -import com.amazonaws.services.athena.model.AmazonAthenaException; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; - +import software.amazon.awssdk.services.athena.model.AthenaException; import static com.amazonaws.athena.connectors.gcs.GcsThrottlingExceptionFilter.EXCEPTION_FILTER; import static org.junit.Assert.assertEquals; @@ -36,11 +35,11 @@ public class GcsExceptionFilterTest @Test public void testIsMatch() { - boolean match = EXCEPTION_FILTER.isMatch(new AmazonAthenaException("Rate exceeded")); + boolean match = EXCEPTION_FILTER.isMatch(AthenaException.builder().message("Rate exceeded").build()); assertTrue(match); - boolean match1 = EXCEPTION_FILTER.isMatch(new AmazonAthenaException("Too Many Requests")); + boolean match1 = EXCEPTION_FILTER.isMatch(AthenaException.builder().message("Too Many Requests").build()); assertTrue(match1); - boolean match3 = EXCEPTION_FILTER.isMatch(new 
AmazonAthenaException("other")); + boolean match3 = EXCEPTION_FILTER.isMatch(AthenaException.builder().message("other").build()); assertFalse(match3); } } diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsMetadataHandlerTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsMetadataHandlerTest.java index 84005f8d02..1db09d5f77 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsMetadataHandlerTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsMetadataHandlerTest.java @@ -40,19 +40,6 @@ import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.gcs.storage.StorageMetadata; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.AWSGlueClientBuilder; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.Database; -import com.amazonaws.services.glue.model.GetDatabasesResult; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.GetTablesResult; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.google.api.gax.paging.Page; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.ServiceAccountCredentials; @@ -77,6 +64,20 @@ import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.Database; +import software.amazon.awssdk.services.glue.model.GetDatabasesRequest; +import software.amazon.awssdk.services.glue.model.GetDatabasesResponse; +import software.amazon.awssdk.services.glue.model.GetTablesRequest; +import software.amazon.awssdk.services.glue.model.GetTablesResponse; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.glue.paginators.GetDatabasesIterable; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.util.ArrayList; import java.util.Collection; @@ -128,19 +129,18 @@ public class GcsMetadataHandlerTest private BlockAllocator blockAllocator; private FederatedIdentity federatedIdentity; @Mock - private AWSGlue awsGlue; + private GlueClient awsGlue; @Mock - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; @Mock private ServiceAccountCredentials serviceAccountCredentials; @Mock - private AmazonAthena athena; + private AthenaClient athena; private MockedStatic mockedStorageOptions; private MockedStatic mockedServiceAccountCredentials; private MockedStatic mockedServiceGoogleCredentials; - private MockedStatic mockedAWSSecretsManagerClientBuilder; - private MockedStatic 
mockedAWSGlueClientBuilder; + private MockedStatic mockedAWSSecretsManagerClientBuilder; @Before public void setUp() throws Exception @@ -149,8 +149,7 @@ public void setUp() throws Exception mockedStorageOptions = mockStatic(StorageOptions.class); mockedServiceAccountCredentials = mockStatic(ServiceAccountCredentials.class); mockedServiceGoogleCredentials = mockStatic(GoogleCredentials.class); - mockedAWSSecretsManagerClientBuilder = mockStatic(AWSSecretsManagerClientBuilder.class); - mockedAWSGlueClientBuilder = mockStatic(AWSGlueClientBuilder.class); + mockedAWSSecretsManagerClientBuilder = mockStatic(SecretsManagerClient.class); Storage storage = mock(Storage.class); Blob blob = mock(Blob.class); @@ -170,10 +169,12 @@ public void setUp() throws Exception Mockito.when(GoogleCredentials.fromStream(Mockito.any())).thenReturn(credentials); Mockito.when(credentials.createScoped((Collection) any())).thenReturn(credentials); - Mockito.when(AWSSecretsManagerClientBuilder.defaultClient()).thenReturn(secretsManager); - GetSecretValueResult getSecretValueResult = new GetSecretValueResult().withVersionStages(ImmutableList.of("v1")).withSecretString("{\"gcs_credential_keys\": \"test\"}"); - Mockito.when(secretsManager.getSecretValue(Mockito.any())).thenReturn(getSecretValueResult); - Mockito.when(AWSGlueClientBuilder.defaultClient()).thenReturn(awsGlue); + Mockito.when(SecretsManagerClient.create()).thenReturn(secretsManager); + GetSecretValueResponse getSecretValueResponse = GetSecretValueResponse.builder() + .versionStages(ImmutableList.of("v1")) + .secretString("{\"gcs_credential_keys\": \"test\"}") + .build(); + Mockito.when(secretsManager.getSecretValue(Mockito.isA(GetSecretValueRequest.class))).thenReturn(getSecretValueResponse); gcsMetadataHandler = new GcsMetadataHandler(new LocalKeyFactory(), secretsManager, athena, "spillBucket", "spillPrefix", awsGlue, allocator, ImmutableMap.of()); blockAllocator = new BlockAllocatorImpl(); federatedIdentity = Mockito.mock(FederatedIdentity.class); @@ -186,18 +187,20 @@ public void tearDown() mockedServiceAccountCredentials.close(); mockedServiceGoogleCredentials.close(); mockedAWSSecretsManagerClientBuilder.close(); - mockedAWSGlueClientBuilder.close(); } @Test public void testDoListSchemaNames() throws Exception { - GetDatabasesResult result = new GetDatabasesResult().withDatabaseList( - new Database().withName(DATABASE_NAME).withLocationUri(S3_GOOGLE_CLOUD_STORAGE_FLAG), - new Database().withName(DATABASE_NAME1).withLocationUri(S3_GOOGLE_CLOUD_STORAGE_FLAG)); + GetDatabasesResponse response = GetDatabasesResponse.builder().databaseList( + Database.builder().name(DATABASE_NAME).locationUri(S3_GOOGLE_CLOUD_STORAGE_FLAG).build(), + Database.builder().name(DATABASE_NAME1).locationUri(S3_GOOGLE_CLOUD_STORAGE_FLAG).build() + ).build(); ListSchemasRequest listSchemasRequest = new ListSchemasRequest(federatedIdentity, QUERY_ID, CATALOG); - Mockito.when(awsGlue.getDatabases(any())).thenReturn(result); + GetDatabasesIterable mockIterable = mock(GetDatabasesIterable.class); + when(mockIterable.stream()).thenReturn(Collections.singletonList(response).stream()); + when(awsGlue.getDatabasesPaginator(any(GetDatabasesRequest.class))).thenReturn(mockIterable); ListSchemasResponse schemaNamesResponse = gcsMetadataHandler.doListSchemaNames(blockAllocator, listSchemasRequest); List expectedSchemaNames = new ArrayList<>(); expectedSchemaNames.add(DATABASE_NAME); @@ -216,19 +219,22 @@ public void testDoListSchemaNamesThrowsException() throws Exception @Test public void 
testDoListTables() throws Exception { - GetTablesResult getTablesResult = new GetTablesResult(); List<Table>
tableList = new ArrayList<>(); - tableList.add(new Table().withName(TABLE_1) - .withParameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)) - .withStorageDescriptor(new StorageDescriptor() - .withLocation(LOCATION))); - tableList.add(new Table().withName(TABLE_2) - .withParameters(ImmutableMap.of()) - .withStorageDescriptor(new StorageDescriptor() - .withLocation(LOCATION) - .withParameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)))); - getTablesResult.setTableList(tableList); - Mockito.when(awsGlue.getTables(any())).thenReturn(getTablesResult); + tableList.add(Table.builder().name(TABLE_1) + .parameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)) + .storageDescriptor(StorageDescriptor.builder() + .location(LOCATION) + .build()) + .build()); + tableList.add(Table.builder().name(TABLE_2) + .parameters(ImmutableMap.of()) + .storageDescriptor(StorageDescriptor.builder() + .location(LOCATION) + .parameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)) + .build()) + .build()); + GetTablesResponse getTablesResponse = GetTablesResponse.builder().tableList(tableList).build(); + Mockito.when(awsGlue.getTables(any(GetTablesRequest.class))).thenReturn(getTablesResponse); ListTablesRequest listTablesRequest = new ListTablesRequest(federatedIdentity, QUERY_ID, CATALOG, SCHEMA_NAME, TEST_TOKEN, 50); ListTablesResponse tableNamesResponse = gcsMetadataHandler.doListTables(blockAllocator, listTablesRequest); assertEquals(2, tableNamesResponse.getTables().size()); @@ -252,20 +258,24 @@ public void doGetTable() metadataSchema.put("dataFormat", PARQUET); Schema schema = new Schema(asList(field), metadataSchema); GetTableRequest getTableRequest = new GetTableRequest(federatedIdentity, QUERY_ID, "gcs", new TableName(SCHEMA_NAME, "testtable"), Collections.emptyMap()); - Table table = new Table(); - table.setName(TABLE_1); - table.setDatabaseName(DATABASE_NAME); - table.setParameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)); - table.setStorageDescriptor(new StorageDescriptor() - .withLocation(LOCATION).withColumns(new Column().withName("name").withType("String"))); - table.setCatalogId(CATALOG); List columns = ImmutableList.of( createColumn("name", "String") ); - table.setPartitionKeys(columns); - GetTableResult getTableResult = new GetTableResult(); - getTableResult.setTable(table); - Mockito.when(awsGlue.getTable(any())).thenReturn(getTableResult); + Table table = Table.builder() + .name(TABLE_1) + .databaseName(DATABASE_NAME) + .parameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)) + .storageDescriptor(StorageDescriptor.builder() + .location(LOCATION) + .columns(Column.builder().name("name").type("String").build()) + .build()) + .catalogId(CATALOG) + .partitionKeys(columns) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse getTableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder() + .table(table) + .build(); + Mockito.when(awsGlue.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(getTableResponse); StorageMetadata storageMetadata = mock(StorageMetadata.class); FieldUtils.writeField(gcsMetadataHandler, "datasource", storageMetadata, true); Mockito.when(storageMetadata.buildTableSchema(any(), any())).thenReturn(schema); @@ -281,24 +291,28 @@ public void testGetPartitions() throws Exception .addField("year", new ArrowType.Utf8()) .addField("month", new ArrowType.Utf8()) .addField("day", new ArrowType.Utf8()).build(); - Table 
table = new Table(); - table.setName(TABLE_1); - table.setDatabaseName(DATABASE_NAME); - table.setParameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET, - PARTITION_PATTERN_KEY, "year=${year}/birth_month${month}/${day}") - ); - table.setStorageDescriptor(new StorageDescriptor() - .withLocation(LOCATION).withColumns(new Column())); - table.setCatalogId(CATALOG); List columns = ImmutableList.of( createColumn("year", "varchar"), createColumn("month", "varchar"), createColumn("day", "varchar") ); - table.setPartitionKeys(columns); - GetTableResult getTableResult = new GetTableResult(); - getTableResult.setTable(table); - Mockito.when(awsGlue.getTable(any())).thenReturn(getTableResult); + Table table = Table.builder() + .name(TABLE_1) + .databaseName(DATABASE_NAME) + .parameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET, + PARTITION_PATTERN_KEY, "year=${year}/birth_month${month}/${day}") + ) + .storageDescriptor(StorageDescriptor.builder() + .location(LOCATION) + .columns(Column.builder().build()) + .build()) + .catalogId(CATALOG) + .partitionKeys(columns) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse getTableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder() + .table(table) + .build(); + Mockito.when(awsGlue.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(getTableResponse); GetTableLayoutRequest getTableLayoutRequest = Mockito.mock(GetTableLayoutRequest.class); Mockito.when(getTableLayoutRequest.getTableName()).thenReturn(new TableName(DATABASE_NAME, TABLE_1)); Mockito.when(getTableLayoutRequest.getSchema()).thenReturn(schema); @@ -318,17 +332,15 @@ public void testDoGetSplits() throws Exception QUERY_ID, CATALOG, TABLE_NAME, partitions, ImmutableList.of("year"), new Constraints(new HashMap<>(), Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT, Collections.emptyMap()), null); QueryStatusChecker queryStatusChecker = mock(QueryStatusChecker.class); - GetTableResult getTableResult = mock(GetTableResult.class); - StorageDescriptor storageDescriptor = mock(StorageDescriptor.class); - when(storageDescriptor.getLocation()).thenReturn(LOCATION); - Table table = mock(Table.class); - when(table.getStorageDescriptor()).thenReturn(storageDescriptor); - when(table.getParameters()).thenReturn(ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${year}/", CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)); - when(awsGlue.getTable(any())).thenReturn(getTableResult); - when(getTableResult.getTable()).thenReturn(table); - List columns = ImmutableList.of( - createColumn("year", "varchar") - ); + StorageDescriptor storageDescriptor = StorageDescriptor.builder().location(LOCATION).build(); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .parameters(ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${year}/", CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse getTableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder() + .table(table) + .build(); + when(awsGlue.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(getTableResponse); GetSplitsResponse response = gcsMetadataHandler.doGetSplits(blockAllocator, request); assertEquals(2, response.getSplits().size()); @@ -355,18 +367,17 @@ public void testDoGetSplitsProperty() throws Exception QUERY_ID, CATALOG, TABLE_NAME, partitions, ImmutableList.of("yearCol", "monthCol"), new 
Constraints(Collections.emptyMap(), Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT, Collections.emptyMap()), null); QueryStatusChecker queryStatusChecker = mock(QueryStatusChecker.class); - GetTableResult getTableResult = mock(GetTableResult.class); - StorageDescriptor storageDescriptor = mock(StorageDescriptor.class); - when(storageDescriptor.getLocation()).thenReturn(LOCATION); - Table table = mock(Table.class); - when(table.getStorageDescriptor()).thenReturn(storageDescriptor); - when(table.getParameters()).thenReturn(ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${yearCol}/month${monthCol}/", CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)); - when(awsGlue.getTable(any())).thenReturn(getTableResult); - when(getTableResult.getTable()).thenReturn(table); - List columns = ImmutableList.of( - createColumn("yearCol", "varchar"), - createColumn("monthCol", "varchar") - ); + StorageDescriptor storageDescriptor = StorageDescriptor.builder() + .location(LOCATION) + .build(); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .parameters(ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${yearCol}/month${monthCol}/", CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse getTableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder() + .table(table) + .build(); + when(awsGlue.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(getTableResponse); GetSplitsResponse response = gcsMetadataHandler.doGetSplits(blockAllocator, request); assertEquals(4, response.getSplits().size()); assertEquals(ImmutableList.of("2016", "2017", "2018", "2019"), response.getSplits().stream().map(split -> split.getProperties().get("yearCol")).sorted().collect(Collectors.toList())); @@ -381,17 +392,17 @@ public void testDoGetSplitsException() throws Exception QUERY_ID, CATALOG, TABLE_NAME, partitions, ImmutableList.of("gcs_file_format"), new Constraints(Collections.emptyMap(), Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT, Collections.emptyMap()), null); QueryStatusChecker queryStatusChecker = mock(QueryStatusChecker.class); - GetTableResult getTableResult = mock(GetTableResult.class); - StorageDescriptor storageDescriptor = mock(StorageDescriptor.class); - when(storageDescriptor.getLocation()).thenReturn(LOCATION); - Table table = mock(Table.class); - when(table.getStorageDescriptor()).thenReturn(storageDescriptor); - when(table.getParameters()).thenReturn(ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${gcs_file_format}/", CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)); - when(awsGlue.getTable(any())).thenReturn(getTableResult); - when(getTableResult.getTable()).thenReturn(table); - List columns = ImmutableList.of( - createColumn("gcs_file_format", "varchar") - ); + StorageDescriptor storageDescriptor = StorageDescriptor.builder() + .location(LOCATION) + .build(); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .parameters(ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${gcs_file_format}/", CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET)) + .build(); + software.amazon.awssdk.services.glue.model.GetTableResponse getTableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder() + .table(table) + .build(); + when(awsGlue.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(getTableResponse); gcsMetadataHandler.doGetSplits(blockAllocator, request); } } diff --git 
a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandlerTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandlerTest.java index 60274ad4f7..3e340142b7 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandlerTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandlerTest.java @@ -34,12 +34,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.auth.oauth2.GoogleCredentials; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -52,6 +46,9 @@ import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.File; import java.util.Collections; @@ -73,10 +70,10 @@ public class GcsRecordHandlerTest extends GenericGcsTest private static final Logger LOGGER = LoggerFactory.getLogger(GcsRecordHandlerTest.class); @Mock - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; @Mock - private AmazonAthena athena; + private AthenaClient athena; @Mock GoogleCredentials credentials; @@ -107,7 +104,7 @@ public void initCommonMockedStatic() LOGGER.info("Starting init."); federatedIdentity = Mockito.mock(FederatedIdentity.class); BlockAllocator allocator = new BlockAllocatorImpl(); - AmazonS3 amazonS3 = mock(AmazonS3.class); + S3Client amazonS3 = mock(S3Client.class); // Create Spill config // This will be enough for a single block @@ -125,11 +122,11 @@ public void initCommonMockedStatic() .withSpillLocation(s3SpillLocation) .build(); // To mock AmazonS3 via AmazonS3ClientBuilder - mockedS3Builder.when(AmazonS3ClientBuilder::defaultClient).thenReturn(amazonS3); - // To mock AWSSecretsManager via AWSSecretsManagerClientBuilder - mockedSecretManagerBuilder.when(AWSSecretsManagerClientBuilder::defaultClient).thenReturn(secretsManager); + mockedS3Builder.when(S3Client::create).thenReturn(amazonS3); + // To mock SecretsManagerClient via SecretsManagerClient + mockedSecretManagerBuilder.when(SecretsManagerClient::create).thenReturn(secretsManager); // To mock AmazonAthena via AmazonAthenaClientBuilder - mockedAthenaClientBuilder.when(AmazonAthenaClientBuilder::defaultClient).thenReturn(athena); + mockedAthenaClientBuilder.when(AthenaClient::create).thenReturn(athena); mockedGoogleCredentials.when(() -> GoogleCredentials.fromStream(any())).thenReturn(credentials); Schema schemaForRead = new Schema(GcsTestUtils.getTestSchemaFieldsArrow()); spillWriter = new S3BlockSpiller(amazonS3, spillConfig, allocator, schemaForRead, ConstraintEvaluator.emptyEvaluator(), com.google.common.collect.ImmutableMap.of()); diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsTestUtils.java 
b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsTestUtils.java index c99e26807d..50cae17b44 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsTestUtils.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsTestUtils.java @@ -20,8 +20,8 @@ package com.amazonaws.athena.connectors.gcs; import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; -import com.amazonaws.services.glue.model.Column; import com.google.common.collect.ImmutableMap; +import software.amazon.awssdk.services.glue.model.Column; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BitVector; @@ -163,9 +163,10 @@ public static VectorSchemaRoot getVectorSchemaRoot() public static Column createColumn(String name, String type) { - Column column = new Column(); - column.setName(name); - column.setType(type); + Column column = Column.builder() + .name(name) + .type(type) + .build(); return column; } public static Map createSummaryWithLValueRangeEqual(String fieldName, ArrowType fieldType, Object fieldValue) diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GenericGcsTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GenericGcsTest.java index bd8853cb23..7d6fbef4f4 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GenericGcsTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GenericGcsTest.java @@ -19,21 +19,21 @@ */ package com.amazonaws.athena.connectors.gcs; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.ServiceAccountCredentials; import org.mockito.MockedStatic; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.lang.reflect.Field; public class GenericGcsTest { - protected MockedStatic mockedS3Builder; - protected MockedStatic mockedSecretManagerBuilder; - protected MockedStatic mockedAthenaClientBuilder; + protected MockedStatic mockedS3Builder; + protected MockedStatic mockedSecretManagerBuilder; + protected MockedStatic mockedAthenaClientBuilder; protected MockedStatic mockedGoogleCredentials; protected MockedStatic mockedGcsUtil; @@ -41,9 +41,9 @@ public class GenericGcsTest protected void initCommonMockedStatic() { - mockedS3Builder = Mockito.mockStatic(AmazonS3ClientBuilder.class); - mockedSecretManagerBuilder = Mockito.mockStatic(AWSSecretsManagerClientBuilder.class); - mockedAthenaClientBuilder = Mockito.mockStatic(AmazonAthenaClientBuilder.class); + mockedS3Builder = Mockito.mockStatic(S3Client.class); + mockedSecretManagerBuilder = Mockito.mockStatic(SecretsManagerClient.class); + mockedAthenaClientBuilder = Mockito.mockStatic(AthenaClient.class); mockedGoogleCredentials = Mockito.mockStatic(GoogleCredentials.class); mockedGcsUtil = Mockito.mockStatic(GcsUtil.class); mockedServiceAccountCredentials = Mockito.mockStatic(ServiceAccountCredentials.class); diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/common/PartitionUtilTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/common/PartitionUtilTest.java index 87e845716f..2af0a10c03 
100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/common/PartitionUtilTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/common/PartitionUtilTest.java @@ -19,13 +19,13 @@ */ package com.amazonaws.athena.connectors.gcs.common; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.glue.model.Table; import org.junit.Before; import org.junit.Test; -import java.util.AbstractMap; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.Table; + import java.util.List; import java.util.Map; import java.util.Optional; @@ -36,32 +36,32 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class PartitionUtilTest { - private Table table; + private StorageDescriptor storageDescriptor; + private List defaultColumns; @Before public void setup() { - StorageDescriptor storageDescriptor = mock(StorageDescriptor.class); - when(storageDescriptor.getLocation()).thenReturn("gs://mydatalake1test/birthday/"); - table = mock(Table.class); - when(table.getStorageDescriptor()).thenReturn(storageDescriptor); - - List columns = com.google.common.collect.ImmutableList.of( + storageDescriptor = StorageDescriptor.builder() + .location("gs://mydatalake1test/birthday/") + .build(); + defaultColumns = com.google.common.collect.ImmutableList.of( createColumn("year", "bigint"), createColumn("month", "int") ); - when(table.getPartitionKeys()).thenReturn(columns); } @Test(expected = IllegalArgumentException.class) public void testFolderNameRegExPatterExpectException() { - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${year}/birth_month${month}/${day}")); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .partitionKeys(defaultColumns) + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${year}/birth_month${month}/${day}")) + .build(); Optional optionalRegEx = PartitionUtil.getRegExExpression(table); assertTrue(optionalRegEx.isPresent()); } @@ -69,7 +69,11 @@ public void testFolderNameRegExPatterExpectException() @Test(expected = IllegalArgumentException.class) public void testFolderNameRegExPatter() { - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${year}/birth_month${month}/")); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .partitionKeys(defaultColumns) + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${year}/birth_month${month}/")) + .build(); Optional optionalRegEx = PartitionUtil.getRegExExpression(table); assertTrue(optionalRegEx.isPresent()); assertFalse("Expression shouldn't contain a '{' character", optionalRegEx.get().contains("{")); @@ -87,7 +91,11 @@ public void dynamicFolderExpressionWithDigits() "year=2001/birth_month01/", "month01/" ); - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${year}/birth_month${month}/")); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .partitionKeys(defaultColumns) + 
.parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "year=${year}/birth_month${month}/")) + .build(); Optional optionalRegEx = PartitionUtil.getRegExExpression(table); assertTrue(optionalRegEx.isPresent()); Pattern folderMatchPattern = Pattern.compile(optionalRegEx.get()); @@ -112,12 +120,14 @@ public void dynamicFolderExpressionWithDefaultsDates() "creation_dt=2012-01-01/", "month01/" ); - // mock - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "creation_dt=${creation_dt}/")); List columns = com.google.common.collect.ImmutableList.of( createColumn("creation_dt", "date") ); - when(table.getPartitionKeys()).thenReturn(columns); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "creation_dt=${creation_dt}/")) + .partitionKeys(columns) + .build(); // build regex Optional optionalRegEx = PartitionUtil.getRegExExpression(table); assertTrue(optionalRegEx.isPresent()); @@ -144,12 +154,14 @@ public void dynamicFolderExpressionWithQuotedVarchar() // failed "state='UP'/", "month01/" ); - // mock - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "state='${stateName}'/")); List columns = com.google.common.collect.ImmutableList.of( createColumn("stateName", "string") ); - when(table.getPartitionKeys()).thenReturn(columns); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "state='${stateName}'/")) + .partitionKeys(columns) + .build(); // build regex Optional optionalRegEx = PartitionUtil.getRegExExpression(table); assertTrue(optionalRegEx.isPresent()); @@ -175,12 +187,13 @@ public void dynamicFolderExpressionWithUnquotedVarchar() "state=UP/", "month01/" ); - // mock - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "state=${stateName}/")); List columns = com.google.common.collect.ImmutableList.of( createColumn("stateName", "string") ); - when(table.getPartitionKeys()).thenReturn(columns); + Table table = Table.builder() + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, "state=${stateName}/")) + .partitionKeys(columns) + .build(); // build regex Optional optionalRegEx = PartitionUtil.getRegExExpression(table); assertTrue(optionalRegEx.isPresent()); @@ -199,11 +212,14 @@ public void dynamicFolderExpressionWithUnquotedVarchar() public void testGetHivePartitions() { String partitionPatten = "year=${year}/birth_month${month}/"; - // mock - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPatten )); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .partitionKeys(defaultColumns) + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPatten)) + .build(); Optional optionalRegEx = PartitionUtil.getRegExExpression(table); assertTrue(optionalRegEx.isPresent()); - Map partitions = PartitionUtil.getPartitionColumnData(partitionPatten, "year=2000/birth_month09/", optionalRegEx.get(), table.getPartitionKeys()); + Map partitions = PartitionUtil.getPartitionColumnData(partitionPatten, "year=2000/birth_month09/", optionalRegEx.get(), table.partitionKeys()); assertFalse("Map of partition values is empty", partitions.isEmpty()); assertEquals("Partitions map size is more than 2", 2, 
partitions.size()); // Assert partition 1 @@ -216,20 +232,21 @@ public void testGetHivePartitions() @Test(expected = IllegalArgumentException.class) public void testGetHiveNonHivePartitions() { - // mock List columns = com.google.common.collect.ImmutableList.of( createColumn("year", "bigint"), createColumn("month", "int"), createColumn("day", "int") ); - // mock - when(table.getPartitionKeys()).thenReturn(columns); - String partitionPatten = "year=${year}/birth_month${month}/${day}"; - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPatten + "/")); + String partitionPattern = "year=${year}/birth_month${month}/${day}"; + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPattern + "/")) + .partitionKeys(columns) + .build(); Optional optionalRegEx = PartitionUtil.getRegExExpression(table); assertTrue(optionalRegEx.isPresent()); - Map partitions = PartitionUtil.getPartitionColumnData(partitionPatten, "year=2000/birth_month09/12/", - optionalRegEx.get(), table.getPartitionKeys()); + Map partitions = PartitionUtil.getPartitionColumnData(partitionPattern, "year=2000/birth_month09/12/", + optionalRegEx.get(), table.partitionKeys()); assertFalse("List of column prefix is empty", partitions.isEmpty()); assertEquals("Partition size is more than 3", 3, partitions.size()); // Assert partition 1 @@ -251,8 +268,11 @@ public void testGetPartitionFolders() createColumn("month", "int"), createColumn("day", "int") ); - when(table.getPartitionKeys()).thenReturn(columns); String partitionPattern = "year=${year}/birth_month${month}/${day}"; + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .partitionKeys(columns) + .build(); // list of folders in a bucket List bucketFolders = com.google.common.collect.ImmutableList.of( @@ -271,7 +291,7 @@ public void testGetPartitionFolders() Pattern folderMatchingPattern = Pattern.compile(optionalRegEx.get()); for (String folder : bucketFolders) { if (folderMatchingPattern.matcher(folder).matches()) { - Map partitions = PartitionUtil.getPartitionColumnData(partitionPattern, folder, optionalRegEx.get(), table.getPartitionKeys()); + Map partitions = PartitionUtil.getPartitionColumnData(partitionPattern, folder, optionalRegEx.get(), table.partitionKeys()); assertFalse("List of storage partitions is empty", partitions.isEmpty()); assertEquals("Partition size is more than 3", 3, partitions.size()); } @@ -286,10 +306,12 @@ public void testHivePartition() createColumn("statename", "string"), createColumn("zipcode", "varchar") ); - when(table.getPartitionKeys()).thenReturn(columns); String partitionPattern = "StateName=${statename}/ZipCode=${zipcode}"; - // mock - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPattern + "/")); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPattern + "/")) + .partitionKeys(columns) + .build(); // list of folders in a bucket List bucketFolders = com.google.common.collect.ImmutableList.of( "StateName=WB/ZipCode=700099/", @@ -311,7 +333,7 @@ public void testHivePartition() folder = folder.substring(1); } if (folderMatchingPattern.matcher(folder).matches()) { - Map partitions = PartitionUtil.getPartitionColumnData(partitionPattern, folder, optionalRegEx.get(), table.getPartitionKeys()); 
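The recurring table.getPartitionKeys() to table.partitionKeys() edits reflect the v2 accessor convention: no get prefix on read, and list members also accept varargs in the builder. A minimal sketch (assuming the varargs overload the v2 code generator produces for list members):

    Table table = Table.builder()
            .partitionKeys(Column.builder().name("year").type("bigint").build())
            .build();

    List<Column> keys = table.partitionKeys();    // v1: table.getPartitionKeys()
    String name = keys.get(0).name();             // v1: keys.get(0).getName()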
+ Map partitions = PartitionUtil.getPartitionColumnData(partitionPattern, folder, optionalRegEx.get(), table.partitionKeys()); assertFalse("List of storage partitions is empty", partitions.isEmpty()); assertEquals("Partition size is more than 2", 2, partitions.size()); matchCount++; @@ -329,10 +351,12 @@ public void testNonHivePartition() createColumn("district", "varchar"), createColumn("zipcode", "string") ); - when(table.getPartitionKeys()).thenReturn(columns); String partitionPattern = "${statename}/${district}/${zipcode}"; - // mock - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPattern + "/")); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPattern + "/")) + .partitionKeys(columns) + .build(); // list of folders in a bucket List bucketFolders = com.google.common.collect.ImmutableList.of( "WB/Kolkata/700099/", @@ -351,7 +375,7 @@ public void testNonHivePartition() int matchCount = 0; for (String folder : bucketFolders) { if (folderMatchingPattern.matcher(folder).matches()) { - Map partitions = PartitionUtil.getPartitionColumnData(partitionPattern, folder, optionalRegEx.get(), table.getPartitionKeys()); + Map partitions = PartitionUtil.getPartitionColumnData(partitionPattern, folder, optionalRegEx.get(), table.partitionKeys()); assertFalse("List of storage partitions is empty", partitions.isEmpty()); assertEquals("Partition size is more than 3", 3, partitions.size()); matchCount++; @@ -369,10 +393,12 @@ public void testMixedLayoutStringOnlyPartition() createColumn("district", "varchar"), createColumn("zipcode", "string") ); - when(table.getPartitionKeys()).thenReturn(columns); String partitionPattern = "StateName=${statename}/District${district}/${zipcode}"; - // mock - when(table.getParameters()).thenReturn(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPattern + "/")); + Table table = Table.builder() + .storageDescriptor(storageDescriptor) + .parameters(com.google.common.collect.ImmutableMap.of(PARTITION_PATTERN_KEY, partitionPattern + "/")) + .partitionKeys(columns) + .build(); // list of folders in a bucket List bucketFolders = com.google.common.collect.ImmutableList.of( "StateName=WB/DistrictKolkata/700099/", @@ -391,7 +417,7 @@ public void testMixedLayoutStringOnlyPartition() int matchCount = 0; for (String folder : bucketFolders) { if (folderMatchingPattern.matcher(folder).matches()) { - Map partitions = PartitionUtil.getPartitionColumnData(partitionPattern, folder, optionalRegEx.get(), table.getPartitionKeys()); + Map partitions = PartitionUtil.getPartitionColumnData(partitionPattern, folder, optionalRegEx.get(), table.partitionKeys()); assertFalse("List of storage partitions is empty", partitions.isEmpty()); assertEquals("Partition size is more than 3", 3, partitions.size()); matchCount++; diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/filter/FilterExpressionBuilderTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/filter/FilterExpressionBuilderTest.java index b114faa5b0..610aa37ec9 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/filter/FilterExpressionBuilderTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/filter/FilterExpressionBuilderTest.java @@ -19,33 +19,19 @@ */ package com.amazonaws.athena.connectors.gcs.filter; -import 
com.amazonaws.athena.connector.lambda.data.Block; -import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl; -import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; -import com.amazonaws.athena.connector.lambda.domain.predicate.Marker; -import com.amazonaws.athena.connector.lambda.domain.predicate.Range; -import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet; -import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; import com.amazonaws.athena.connectors.gcs.GcsTestUtils; -import com.amazonaws.services.glue.model.Column; -import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; - +import software.amazon.awssdk.services.glue.model.Column; import java.util.Collections; -import java.util.List; import java.util.Map; import static com.amazonaws.athena.connector.lambda.domain.predicate.Constraints.DEFAULT_NO_LIMIT; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.anyString; @RunWith(MockitoJUnitRunner.class) public class FilterExpressionBuilderTest @@ -54,7 +40,7 @@ public class FilterExpressionBuilderTest public void testGetExpressions() { Map>> result = FilterExpressionBuilder.getConstraintsForPartitionedColumns( - com.google.common.collect.ImmutableList.of(new Column().withName("year")), + com.google.common.collect.ImmutableList.of(Column.builder().name("year").build()), new Constraints(GcsTestUtils.createSummaryWithLValueRangeEqual("year", new ArrowType.Utf8(), "1"), Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT, Collections.emptyMap())); assertEquals(result.size(), 1); diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/storage/StorageMetadataTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/storage/StorageMetadataTest.java index 55a8139b01..b49ffec104 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/storage/StorageMetadataTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/storage/StorageMetadataTest.java @@ -23,12 +23,6 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; import com.amazonaws.athena.connectors.gcs.GenericGcsTest; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.AWSGlueClient; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.glue.model.Table; import com.google.api.gax.paging.Page; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.ServiceAccountCredentials; @@ -53,6 +47,10 @@ import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.Table; import java.util.ArrayList; import java.util.Collection; @@ -141,13 
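In the StorageMetadataTest hunks just below, the v1 GetTableResult-with-setters pattern becomes a stubbed v2 GlueClient that returns a built GetTableResponse. The stubbing shape in isolation, inside a test method (table contents elided):

    GlueClient glue = Mockito.mock(GlueClient.class);
    GetTableResponse response = GetTableResponse.builder()
            .table(Table.builder().name("table1").build())
            .build();
    // Match on the v2 request type; the code under test builds GetTableRequest itself.
    Mockito.when(glue.getTable(any(GetTableRequest.class))).thenReturn(response);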
+139,15 @@ private void storageMock() throws Exception @Test public void testBuildTableSchema() throws Exception { - Table table = new Table(); - table.setName("birthday"); - table.setDatabaseName("default"); - table.setParameters(ImmutableMap.of("classification", "parquet")); - table.setStorageDescriptor(new StorageDescriptor() - .withLocation("gs://mydatalake1test/birthday/")); - table.setCatalogId("catalog"); + Table table = Table.builder() + .name("birthday") + .databaseName("default") + .parameters(ImmutableMap.of("classification", "parquet")) + .storageDescriptor(StorageDescriptor.builder() + .location("gs://mydatalake1test/birthday/") + .build()) + .catalogId("catalog") + .build(); storageMetadata = mock(StorageMetadata.class); storageMock(); when(storageMetadata.buildTableSchema(any(), any())).thenCallRealMethod(); @@ -166,7 +166,7 @@ public void testGetPartitionFolders() throws Exception { //single partition getStorageList(ImmutableList.of("year=2000/birthday.parquet", "year=2000/", "year=2000/birthday1.parquet")); - AWSGlue glue = Mockito.mock(AWSGlueClient.class); + GlueClient glue = Mockito.mock(GlueClient.class); List fieldList = ImmutableList.of(new Field("year", FieldType.nullable(new ArrowType.Int(64, true)), null)); List partKeys = ImmutableList.of(createColumn("year", "varchar")); Schema schema = getSchema(glue, fieldList, partKeys, "year=${year}/"); @@ -235,19 +235,23 @@ public void testGetPartitionFolders() throws Exception } @NotNull - private Schema getSchema(AWSGlue glue, List fieldList, List partKeys, String partitionPattern) + private Schema getSchema(GlueClient glue, List fieldList, List partKeys, String partitionPattern) { Map metadataSchema = new HashMap<>(); metadataSchema.put("dataFormat", "parquet"); Schema schema = new Schema(fieldList, metadataSchema); - GetTableResult getTablesResult = new GetTableResult(); - getTablesResult.setTable(new Table().withName(TABLE_1) - .withParameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET, - PARTITION_PATTERN_KEY, partitionPattern)) - .withPartitionKeys(partKeys) - .withStorageDescriptor(new StorageDescriptor() - .withLocation(LOCATION))); - Mockito.when(glue.getTable(any())).thenReturn(getTablesResult); + software.amazon.awssdk.services.glue.model.GetTableResponse getTableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder() + .table(Table.builder() + .name(TABLE_1) + .parameters(ImmutableMap.of(CLASSIFICATION_GLUE_TABLE_PARAM, PARQUET, + PARTITION_PATTERN_KEY, partitionPattern)) + .partitionKeys(partKeys) + .storageDescriptor(StorageDescriptor.builder() + .location(LOCATION) + .build()) + .build()) + .build(); + Mockito.when(glue.getTable(any(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(getTableResponse); return schema; } diff --git a/athena-google-bigquery/Dockerfile b/athena-google-bigquery/Dockerfile new file mode 100644 index 0000000000..b1dbf5ef11 --- /dev/null +++ b/athena-google-bigquery/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-google-bigquery-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-google-bigquery-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.google.bigquery.BigQueryCompositeHandler" ] \ No newline at end of file diff --git a/athena-google-bigquery/athena-google-bigquery.yaml 
b/athena-google-bigquery/athena-google-bigquery.yaml index e5dd95f9b2..6cdf0cb299 100644 --- a/athena-google-bigquery/athena-google-bigquery.yaml +++ b/athena-google-bigquery/athena-google-bigquery.yaml @@ -79,10 +79,9 @@ Resources: big_query_endpoint: !Ref BigQueryEndpoint GOOGLE_APPLICATION_CREDENTIALS: '/tmp/service-account.json' FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.google.bigquery.BigQueryCompositeHandler" - CodeUri: "./target/athena-google-bigquery-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-google-bigquery:2022.47.1' Description: "Enables Amazon Athena to communicate with BigQuery using Google SDK" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-google-bigquery/pom.xml b/athena-google-bigquery/pom.xml index 08c904e38f..7eb4561080 100644 --- a/athena-google-bigquery/pom.xml +++ b/athena-google-bigquery/pom.xml @@ -25,20 +25,6 @@ jna-platform 5.15.0 - - com.amazonaws - athena-jdbc - 2022.47.1 - test-jar - test - - - - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} - test - software.amazon.awscdk diff --git a/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryExceptionFilter.java b/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryExceptionFilter.java index 866c7593ca..315e96dfc8 100644 --- a/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryExceptionFilter.java +++ b/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryExceptionFilter.java @@ -21,8 +21,8 @@ package com.amazonaws.athena.connectors.google.bigquery; import com.amazonaws.athena.connector.lambda.ThrottlingInvoker; -import com.amazonaws.services.athena.model.AmazonAthenaException; import com.google.cloud.bigquery.BigQueryException; +import software.amazon.awssdk.services.athena.model.AthenaException; public class BigQueryExceptionFilter implements ThrottlingInvoker.ExceptionFilter { public static final ThrottlingInvoker.ExceptionFilter EXCEPTION_FILTER = new BigQueryExceptionFilter(); @@ -30,7 +30,7 @@ public class BigQueryExceptionFilter implements ThrottlingInvoker.ExceptionFilte @Override public boolean isMatch(Exception ex) { - if (ex instanceof AmazonAthenaException && ex.getMessage().contains("Rate exceeded")) { + if (ex instanceof AthenaException && ex.getMessage().contains("Rate exceeded")) { return true; } if (ex instanceof BigQueryException && ex.getMessage().contains("Exceeded rate limits")) { diff --git a/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandler.java b/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandler.java index 531f629d94..cf17a98c1a 100644 --- a/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandler.java +++ b/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandler.java @@ -29,12 +29,6 @@ import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.google.bigquery.qpt.BigQueryQueryPassthrough; -import 
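The BigQueryExceptionFilter change above swaps only the exception type; the ThrottlingInvoker contract is unchanged, and v2 Athena service errors now surface as AthenaException. For reference, how the filter is wired (the builder call is the one used in the record handler below; configOptions is the handler's config map):

    // Inside the record handler constructor:
    this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build();

    // The filter itself only needs to recognize throttling by type and message:
    public boolean isMatch(Exception ex)
    {
        return (ex instanceof AthenaException && ex.getMessage().contains("Rate exceeded"))
                || (ex instanceof BigQueryException && ex.getMessage().contains("Exceeded rate limits"));
    }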
com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.api.gax.rpc.ServerStream; import com.google.cloud.bigquery.BigQuery; import com.google.cloud.bigquery.BigQueryException; @@ -62,6 +56,9 @@ import org.apache.arrow.vector.types.pojo.Field; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.util.ArrayList; @@ -93,13 +90,13 @@ public class BigQueryRecordHandler BigQueryRecordHandler(java.util.Map configOptions, BufferAllocator allocator) { - this(AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), configOptions, allocator); + this(S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), configOptions, allocator); } @VisibleForTesting - public BigQueryRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, java.util.Map configOptions, BufferAllocator allocator) + public BigQueryRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions, BufferAllocator allocator) { super(amazonS3, secretsManager, athena, BigQueryConstants.SOURCE_TYPE, configOptions); this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build(); diff --git a/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryUtils.java b/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryUtils.java index 06745274ae..7ac14fb2ed 100644 --- a/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryUtils.java +++ b/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryUtils.java @@ -22,10 +22,6 @@ import com.amazonaws.athena.connector.lambda.data.DateTimeFormatterUtil; import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.google.api.gax.paging.Page; import com.google.auth.Credentials; import com.google.auth.oauth2.ServiceAccountCredentials; @@ -48,6 +44,9 @@ import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.io.ByteArrayInputStream; import java.io.File; @@ -81,11 +80,10 @@ private BigQueryUtils() public static Credentials getCredentialsFromSecretsManager(java.util.Map configOptions) throws IOException { - AWSSecretsManager secretsManager = AWSSecretsManagerClientBuilder.defaultClient(); - 
GetSecretValueRequest getSecretValueRequest = new GetSecretValueRequest(); - getSecretValueRequest.setSecretId(getEnvBigQueryCredsSmId(configOptions)); - GetSecretValueResult response = secretsManager.getSecretValue(getSecretValueRequest); - return ServiceAccountCredentials.fromStream(new ByteArrayInputStream(response.getSecretString().getBytes())).createScoped( + SecretsManagerClient secretsManager = SecretsManagerClient.create(); + GetSecretValueRequest getSecretValueRequest = GetSecretValueRequest.builder().secretId(getEnvBigQueryCredsSmId(configOptions)).build(); + GetSecretValueResponse response = secretsManager.getSecretValue(getSecretValueRequest); + return ServiceAccountCredentials.fromStream(new ByteArrayInputStream(response.secretString().getBytes())).createScoped( ImmutableSet.of( "https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/drive")); @@ -213,7 +211,7 @@ else if (subField.getType().getStandardType().name().equalsIgnoreCase("Struct")) */ public static void installGoogleCredentialsJsonFile(java.util.Map configOptions) throws IOException { - CachableSecretsManager secretsManager = new CachableSecretsManager(AWSSecretsManagerClientBuilder.defaultClient()); + CachableSecretsManager secretsManager = new CachableSecretsManager(SecretsManagerClient.create()); String gcsCredentialsJsonString = secretsManager.getSecret(configOptions.get(BigQueryConstants.ENV_BIG_QUERY_CREDS_SM_ID)); File destination = new File(TMP_SERVICE_ACCOUNT_JSON); boolean destinationDirExists = new File(destination.getParent()).mkdirs(); diff --git a/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryCompositeHandlerTest.java b/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryCompositeHandlerTest.java index 266598fe38..047c4118bc 100644 --- a/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryCompositeHandlerTest.java +++ b/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryCompositeHandlerTest.java @@ -19,9 +19,6 @@ */ package com.amazonaws.athena.connectors.google.bigquery; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.google.auth.oauth2.ServiceAccountCredentials; import org.junit.After; import org.junit.Before; @@ -31,6 +28,9 @@ import org.mockito.MockedStatic; import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.io.IOException; import java.util.Arrays; @@ -47,12 +47,12 @@ public class BigQueryCompositeHandlerTest System.setProperty("aws.region", "us-east-1"); } - MockedStatic awsSecretManagerClient; + MockedStatic awsSecretManagerClient; MockedStatic serviceAccountCredentialsStatic; MockedStatic bigQueryUtils; private BigQueryCompositeHandler bigQueryCompositeHandler; @Mock - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; @Mock private ServiceAccountCredentials serviceAccountCredentials; @@ -61,7 +61,7 @@ public void setUp() { bigQueryUtils = mockStatic(BigQueryUtils.class); serviceAccountCredentialsStatic = 
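The getCredentialsFromSecretsManager rewrite above is the canonical v2 request/response shape for Secrets Manager. Condensed, with the client closed when done (the secret id here is illustrative; the patch resolves it from configOptions):

    try (SecretsManagerClient secrets = SecretsManagerClient.create()) {
        GetSecretValueRequest request = GetSecretValueRequest.builder()
                .secretId("bigquery-service-account")    // illustrative id
                .build();
        GetSecretValueResponse response = secrets.getSecretValue(request);
        String secretJson = response.secretString();     // v1: response.getSecretString()
    }

v2 clients implement AutoCloseable, so try-with-resources (or an explicit close(), as the HBase integ test below does with EmrClient) releases the underlying connection pool.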
mockStatic(ServiceAccountCredentials.class); - awsSecretManagerClient = mockStatic(AWSSecretsManagerClientBuilder.class); + awsSecretManagerClient = mockStatic(SecretsManagerClient.class); } @After @@ -77,15 +77,18 @@ public void bigQueryCompositeHandlerTest() throws IOException { Exception ex = null; - Mockito.when(AWSSecretsManagerClientBuilder.defaultClient()).thenReturn(secretsManager); - GetSecretValueResult getSecretValueResult = new GetSecretValueResult().withVersionStages(Arrays.asList("v1")).withSecretString("{\n" + - " \"type\": \"service_account\",\n" + - " \"project_id\": \"mockProjectId\",\n" + - " \"private_key_id\": \"mockPrivateKeyId\",\n" + - " \"private_key\": \"-----BEGIN PRIVATE KEY-----\\nmockPrivateKeydsfhdskfhjdfjkdhgfdjkghfdngvfkvfnjvfdjkg\\n-----END PRIVATE KEY-----\\n\",\n" + - " \"client_email\": \"mockabc@mockprojectid.iam.gserviceaccount.com\",\n" + - " \"client_id\": \"000000000000000000000\"\n" + - "}"); + Mockito.when(SecretsManagerClient.create()).thenReturn(secretsManager); + GetSecretValueResponse getSecretValueResponse = GetSecretValueResponse.builder() + .versionStages(Arrays.asList("v1")) + .secretString("{\n" + + " \"type\": \"service_account\",\n" + + " \"project_id\": \"mockProjectId\",\n" + + " \"private_key_id\": \"mockPrivateKeyId\",\n" + + " \"private_key\": \"-----BEGIN PRIVATE KEY-----\\nmockPrivateKeydsfhdskfhjdfjkdhgfdjkghfdngvfkvfnjvfdjkg\\n-----END PRIVATE KEY-----\\n\",\n" + + " \"client_email\": \"mockabc@mockprojectid.iam.gserviceaccount.com\",\n" + + " \"client_id\": \"000000000000000000000\"\n" + + "}") + .build(); Mockito.when(ServiceAccountCredentials.fromStream(any())).thenReturn(serviceAccountCredentials); bigQueryCompositeHandler = new BigQueryCompositeHandler(); diff --git a/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandlerTest.java b/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandlerTest.java index 74f97587bd..371b939508 100644 --- a/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandlerTest.java +++ b/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandlerTest.java @@ -35,9 +35,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.ServerStreamingCallable; import com.google.cloud.bigquery.BigQuery; @@ -79,6 +76,9 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -107,12 +107,12 @@ public class BigQueryRecordHandlerTest BigQuery bigQuery; @Mock - AWSSecretsManager awsSecretsManager; + SecretsManagerClient awsSecretsManager; private String bucket = "bucket"; private String prefix = "prefix"; @Mock - private AmazonAthena athena; + private AthenaClient athena; @Mock private BigQueryReadClient 
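The test setup above now mocks the static factory SecretsManagerClient.create() rather than the v1 builder class. A field-held MockedStatic must still be closed in @After, as this test does; the same idea as a self-closing sketch:

    SecretsManagerClient mockClient = Mockito.mock(SecretsManagerClient.class);
    try (MockedStatic<SecretsManagerClient> mocked = Mockito.mockStatic(SecretsManagerClient.class)) {
        mocked.when(SecretsManagerClient::create).thenReturn(mockClient);
        // Anything on this thread that calls SecretsManagerClient.create() gets mockClient.
    }
    // On close, the static stub is removed and create() behaves normally again.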
bigQueryReadClient; @Mock @@ -120,7 +120,7 @@ public class BigQueryRecordHandlerTest @Mock private ArrowSchema arrowSchema; private BigQueryRecordHandler bigQueryRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpiller spillWriter; private S3BlockSpillReader spillReader; private Schema schemaForRead; @@ -200,7 +200,7 @@ public void init() mockedStatic.when(() -> BigQueryUtils.getBigQueryClient(any(Map.class))).thenReturn(bigQuery); federatedIdentity = Mockito.mock(FederatedIdentity.class); allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); //Create Spill config spillConfig = SpillConfig.newBuilder() diff --git a/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/integ/BigQueryIntegTest.java b/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/integ/BigQueryIntegTest.java index 51cb38fe07..ca41c3751c 100644 --- a/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/integ/BigQueryIntegTest.java +++ b/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/integ/BigQueryIntegTest.java @@ -21,7 +21,6 @@ import com.amazonaws.athena.connector.integ.IntegrationTestBase; import com.amazonaws.athena.connector.integ.data.TestConfig; -import com.amazonaws.services.athena.model.Row; import com.google.common.collect.ImmutableList; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; diff --git a/athena-hbase/Dockerfile b/athena-hbase/Dockerfile new file mode 100644 index 0000000000..6772c2c793 --- /dev/null +++ b/athena-hbase/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-hbase-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-hbase-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.hbase.HbaseCompositeHandler" ] \ No newline at end of file diff --git a/athena-hbase/athena-hbase.yaml b/athena-hbase/athena-hbase.yaml index 447f96048f..c9d70a24e0 100644 --- a/athena-hbase/athena-hbase.yaml +++ b/athena-hbase/athena-hbase.yaml @@ -85,10 +85,9 @@ Resources: principal_name: !Ref PrincipalName hbase_rpc_protection: !Ref HbaseRpcProtection FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.hbase.HbaseCompositeHandler" - CodeUri: "./target/athena-hbase-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-hbase:2022.47.1' Description: "Enables Amazon Athena to communicate with HBase, making your HBase data accessible via SQL" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-hbase/pom.xml b/athena-hbase/pom.xml index f6482a7232..cda70c9c2f 100644 --- a/athena-hbase/pom.xml +++ b/athena-hbase/pom.xml @@ -13,11 +13,6 @@ 2.6.0-hadoop3 - - org.slf4j - slf4j-simple - ${slf4j-log4j.version} - com.amazonaws aws-athena-federation-sdk @@ -45,141 +40,13 @@ ${aws-cdk.version} test - + - com.amazonaws - aws-java-sdk-emr - ${aws-sdk.version} + software.amazon.awssdk + emr + ${aws-sdk-v2.version} test - - org.apache.directory.server - apacheds-kerberos-codec - 2.0.0.AM27 - - - 
org.apache.directory.api - api-ldap-model - - - - - org.apache.avro - avro - 1.12.0 - - - com.fasterxml.jackson.module - jackson-module-jaxb-annotations - ${fasterxml.jackson.version} - - - org.codehaus.jettison - jettison - 1.5.4 - - - com.google.protobuf - protobuf-java - ${protobuf3.version} - - - org.apache.directory.api - api-ldap-model - 2.1.7 - - - org.eclipse.jetty - jetty-server - ${jetty.version} - - - org.eclipse.jetty - jetty-xml - ${jetty.version} - - - org.eclipse.jetty - jetty-webapp - ${jetty.version} - - - org.eclipse.jetty - jetty-servlet - ${jetty.version} - - - org.eclipse.jetty - jetty-io - ${jetty.version} - - - org.apache.zookeeper - zookeeper - 3.9.2 - - - org.apache.hadoop - hadoop-common - 3.4.0 - - - org.eclipse.jetty - jetty-webapp - - - org.eclipse.jetty - jetty-server - - - org.eclipse.jetty - jetty-xml - - - org.eclipse.jetty - jetty-servlet - - - org.apache.avro - avro - - - org.codehaus.jackson - jackson-mapper-asl - - - org.codehaus.jackson - jackson-xc - - - org.codehaus.jettison - jettison - - - org.eclipse.jetty - jetty-io - - - log4j - log4j - - - com.google.protobuf - protobuf-java - - - - - org.apache.hbase - hbase-common - ${hbase.version} - - - org.apache.hadoop - hadoop-common - - - org.apache.hbase hbase-client @@ -200,30 +67,6 @@ - - - org.apache.httpcomponents - httpclient - ${apache.httpclient.version} - - - - commons-logging - commons-logging - - - - - org.slf4j - slf4j-api - ${slf4j-log4j.version} - - - org.slf4j - jcl-over-slf4j - ${slf4j-log4j.version} - runtime - com.amazonaws aws-lambda-java-log4j2 diff --git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseKerberosUtils.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseKerberosUtils.java index f8e3282ebe..56f0ddcfa8 100644 --- a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseKerberosUtils.java +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseKerberosUtils.java @@ -19,17 +19,15 @@ */ package com.amazonaws.athena.connectors.hbase; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.GetObjectRequest; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectSummary; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; import java.io.BufferedInputStream; import java.io.File; @@ -68,20 +66,24 @@ public static Path copyConfigFilesFromS3ToTempFolder(java.util.Map responseStream = s3Client.getObject(GetObjectRequest.builder() + .bucket(s3Bucket[0]) + .key(s3Object.key()) + .build()); + InputStream inputStream = new BufferedInputStream(responseStream); + String key = s3Object.key(); String fName = key.substring(key.indexOf('/') + 1); if (!fName.isEmpty()) { File file = new File(tempDir + File.separator + fName); diff 
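The copyConfigFilesFromS3ToTempFolder rewrite above replaces v1's ObjectListing/S3Object streaming with v2's ListObjectsResponse and ResponseInputStream. The core loop in isolation, as a sketch (bucket name and temp directory are illustrative; the key handling mirrors the patch's prefix-stripping):

    S3Client s3 = S3Client.create();
    Path tempDir = Files.createTempDirectory("hbase-config");
    ListObjectsRequest listRequest = ListObjectsRequest.builder().bucket("my-config-bucket").build();
    for (S3Object object : s3.listObjects(listRequest).contents()) {
        String fileName = object.key().substring(object.key().indexOf('/') + 1);
        if (fileName.isEmpty()) {
            continue;    // skip folder placeholder keys, as the patch does
        }
        GetObjectRequest getRequest = GetObjectRequest.builder()
                .bucket("my-config-bucket")
                .key(object.key())
                .build();
        // getObject returns a ResponseInputStream<GetObjectResponse>; stream it to disk.
        try (InputStream in = s3.getObject(getRequest)) {
            Files.copy(in, tempDir.resolve(fileName));
        }
    }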
--git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandler.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandler.java index 6be56ee456..9898a93cec 100644 --- a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandler.java +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandler.java @@ -44,10 +44,6 @@ import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection; import com.amazonaws.athena.connectors.hbase.connection.HbaseConnectionFactory; import com.amazonaws.athena.connectors.hbase.qpt.HbaseQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.Types; @@ -58,6 +54,10 @@ import org.apache.hadoop.hbase.TableName; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.util.ArrayList; @@ -101,12 +101,12 @@ public class HbaseMetadataHandler //is indeed enabled for use by this connector. private static final String HBASE_METADATA_FLAG = "hbase-metadata-flag"; //Used to filter out Glue tables which lack HBase metadata flag. - private static final TableFilter TABLE_FILTER = (Table table) -> table.getParameters().containsKey(HBASE_METADATA_FLAG); + private static final TableFilter TABLE_FILTER = (Table table) -> table.parameters().containsKey(HBASE_METADATA_FLAG); //Used to denote the 'type' of this connector for diagnostic purposes. private static final String SOURCE_TYPE = "hbase"; //The number of rows to scan when attempting to infer schema from an HBase table. 
private static final int NUM_ROWS_TO_SCAN = 10; - private final AWSGlue awsGlue; + private final GlueClient awsGlue; private final HbaseConnectionFactory connectionFactory; private final HbaseQueryPassthrough queryPassthrough = new HbaseQueryPassthrough(); @@ -120,10 +120,10 @@ public HbaseMetadataHandler(java.util.Map configOptions) @VisibleForTesting protected HbaseMetadataHandler( - AWSGlue awsGlue, + GlueClient awsGlue, EncryptionKeyFactory keyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, HbaseConnectionFactory connectionFactory, String spillBucket, String spillPrefix, diff --git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java index 9e10b503df..ad51a3fc35 100644 --- a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java @@ -31,12 +31,6 @@ import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection; import com.amazonaws.athena.connectors.hbase.connection.HbaseConnectionFactory; import com.amazonaws.athena.connectors.hbase.qpt.HbaseQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -55,6 +49,9 @@ import org.apache.hadoop.hbase.util.Bytes; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.nio.charset.CharacterCodingException; @@ -83,7 +80,7 @@ public class HbaseRecordHandler //Used to denote the 'type' of this connector for diagnostic purposes. 
private static final String SOURCE_TYPE = "hbase"; - private final AmazonS3 amazonS3; + private final S3Client amazonS3; private final HbaseConnectionFactory connectionFactory; private final HbaseQueryPassthrough queryPassthrough = new HbaseQueryPassthrough(); @@ -91,15 +88,15 @@ public class HbaseRecordHandler public HbaseRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), + S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), new HbaseConnectionFactory(), configOptions); } @VisibleForTesting - protected HbaseRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, HbaseConnectionFactory connectionFactory, java.util.Map configOptions) + protected HbaseRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, HbaseConnectionFactory connectionFactory, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.amazonS3 = amazonS3; diff --git a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandlerTest.java b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandlerTest.java index eb43f48841..5447a8eec1 100644 --- a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandlerTest.java +++ b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandlerTest.java @@ -42,9 +42,6 @@ import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection; import com.amazonaws.athena.connectors.hbase.connection.HbaseConnectionFactory; import com.amazonaws.athena.connectors.hbase.connection.ResultProcessor; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.hadoop.hbase.HRegionInfo; @@ -63,6 +60,9 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.util.ArrayList; @@ -99,13 +99,13 @@ public class HbaseMetadataHandlerTest private HbaseConnectionFactory mockConnFactory; @Mock - private AWSGlue awsGlue; + private GlueClient awsGlue; @Mock - private AWSSecretsManager secretsManager; + private SecretsManagerClient secretsManager; @Mock - private AmazonAthena athena; + private AthenaClient athena; @Before public void setUp() diff --git a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandlerTest.java b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandlerTest.java index 9933043e5a..017608c74d 100644 --- a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandlerTest.java +++ b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandlerTest.java @@ -43,17 +43,6 @@ import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection; import com.amazonaws.athena.connectors.hbase.connection.HbaseConnectionFactory; import com.amazonaws.athena.connectors.hbase.connection.ResultProcessor; -import 
com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.Types; @@ -74,6 +63,15 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -106,7 +104,7 @@ public class HbaseRecordHandlerTest private HbaseRecordHandler handler; private BlockAllocator allocator; private List mockS3Storage = new ArrayList<>(); - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpillReader spillReader; private Schema schemaForRead; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -124,10 +122,10 @@ public class HbaseRecordHandlerTest private HbaseConnectionFactory mockConnFactory; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Before public void setUp() @@ -139,33 +137,29 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), 
null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); schemaForRead = TestUtils.makeSchema().addStringField(HbaseSchemaUtils.ROW_COLUMN_NAME).build(); diff --git a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/integ/HbaseIntegTest.java b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/integ/HbaseIntegTest.java index ee7204f1d3..1435e09b70 100644 --- a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/integ/HbaseIntegTest.java +++ b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/integ/HbaseIntegTest.java @@ -26,19 +26,6 @@ import com.amazonaws.athena.connector.integ.data.ConnectorStackAttributes; import com.amazonaws.athena.connector.integ.data.ConnectorVpcAttributes; import com.amazonaws.athena.connector.integ.providers.ConnectorPackagingAttributesProvider; -import com.amazonaws.services.athena.model.Row; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder; -import com.amazonaws.services.elasticmapreduce.model.Application; -import com.amazonaws.services.elasticmapreduce.model.ClusterSummary; -import com.amazonaws.services.elasticmapreduce.model.DescribeClusterRequest; -import com.amazonaws.services.elasticmapreduce.model.DescribeClusterResult; -import com.amazonaws.services.elasticmapreduce.model.ListClustersRequest; -import com.amazonaws.services.elasticmapreduce.model.ListClustersResult; -import com.amazonaws.services.lambda.AWSLambda; -import com.amazonaws.services.lambda.AWSLambdaClientBuilder; -import com.amazonaws.services.lambda.model.InvocationType; -import com.amazonaws.services.lambda.model.InvokeRequest; import com.google.common.collect.ImmutableList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,6 +37,17 @@ import software.amazon.awscdk.core.Stack; import software.amazon.awscdk.services.emr.CfnCluster; import software.amazon.awscdk.services.iam.PolicyDocument; +import software.amazon.awssdk.services.athena.model.Row; +import software.amazon.awssdk.services.emr.EmrClient; +import software.amazon.awssdk.services.emr.model.Application; +import software.amazon.awssdk.services.emr.model.ClusterSummary; +import software.amazon.awssdk.services.emr.model.DescribeClusterRequest; +import software.amazon.awssdk.services.emr.model.DescribeClusterResponse; +import software.amazon.awssdk.services.emr.model.ListClustersRequest; +import software.amazon.awssdk.services.emr.model.ListClustersResponse; +import software.amazon.awssdk.services.lambda.LambdaClient; +import software.amazon.awssdk.services.lambda.model.InvocationType; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; import java.time.LocalDate; import java.time.LocalDateTime; @@ -146,10 +144,10 @@ private Pair getHbaseStack() { .name(dbClusterName) .visibleToAllUsers(Boolean.TRUE) .applications(ImmutableList.of( - new Application().withName("HBase"), - new Application().withName("Hive"), - new Application().withName("Hue"), - new Application().withName("Phoenix"))) + Application.builder().name("HBase").build(), + Application.builder().name("Hive").build(), + Application.builder().name("Hue").build(), + Application.builder().name("Phoenix").build())) .instances(CfnCluster.JobFlowInstancesConfigProperty.builder() .emrManagedMasterSecurityGroup(vpcAttributes.getSecurityGroupId()) 
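For the HbaseRecordHandlerTest spill setup above: in v2 the object body travels beside the request as a RequestBody, and getObject returns a ResponseInputStream directly, so the v1 S3Object and S3ObjectInputStream mocks disappear. Reduced to its essentials (the backing store here is a plain byte list rather than the test's ByteHolder):

    List<byte[]> store = new ArrayList<>();
    S3Client s3 = mock(S3Client.class);

    when(s3.putObject(any(PutObjectRequest.class), any(RequestBody.class)))
            .thenAnswer(invocation -> {
                RequestBody body = invocation.getArgument(1);
                // Capture the spilled bytes from the request body's stream provider.
                store.add(ByteStreams.toByteArray(body.contentStreamProvider().newStream()));
                return PutObjectResponse.builder().build();
            });

    when(s3.getObject(any(GetObjectRequest.class)))
            .thenAnswer(invocation -> new ResponseInputStream<>(
                    GetObjectResponse.builder().build(),
                    new ByteArrayInputStream(store.remove(0))));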
.emrManagedSlaveSecurityGroup(vpcAttributes.getSecurityGroupId()) @@ -180,27 +178,27 @@ private Pair getHbaseStack() { */ private String getClusterData() { - AmazonElasticMapReduce emrClient = AmazonElasticMapReduceClientBuilder.defaultClient(); + EmrClient emrClient = EmrClient.create(); try { - ListClustersResult listClustersResult; + ListClustersResponse listClustersResult; String marker = null; Optional dbClusterId; do { // While cluster Id has not yet been found and there are more paginated results. // Get paginated list of EMR clusters. - listClustersResult = emrClient.listClusters(new ListClustersRequest().withMarker(marker)); + listClustersResult = emrClient.listClusters(ListClustersRequest.builder().marker(marker).build()); // Get the cluster id. dbClusterId = getClusterId(listClustersResult); // Get the marker for the next paginated request. - marker = listClustersResult.getMarker(); + marker = listClustersResult.marker(); } while (!dbClusterId.isPresent() && marker != null); // Get the cluster description using the cluster id. - DescribeClusterResult clusterResult = emrClient.describeCluster(new DescribeClusterRequest() - .withClusterId(dbClusterId.orElseThrow(() -> - new RuntimeException("Unable to get cluster description for: " + dbClusterName)))); - return clusterResult.getCluster().getMasterPublicDnsName(); + DescribeClusterResponse clusterResult = emrClient.describeCluster(DescribeClusterRequest.builder() + .clusterId(dbClusterId.orElseThrow(() -> + new RuntimeException("Unable to get cluster description for: " + dbClusterName))).build()); + return clusterResult.cluster().masterPublicDnsName(); } finally { - emrClient.shutdown(); + emrClient.close(); } } @@ -210,12 +208,12 @@ private String getClusterData() * @return Optional String containing the cluster Id that matches the cluster name, or Optional.empty() if match * was not found. */ - private Optional getClusterId(ListClustersResult listClustersResult) + private Optional getClusterId(ListClustersResponse listClustersResult) { - for (ClusterSummary clusterSummary : listClustersResult.getClusters()) { - if (clusterSummary.getName().equals(dbClusterName)) { + for (ClusterSummary clusterSummary : listClustersResult.clusters()) { + if (clusterSummary.name().equals(dbClusterName)) { // Found match for cluster name - return cluster id. - String clusterId = clusterSummary.getId(); + String clusterId = clusterSummary.id(); logger.info("Found Cluster Id for {}: {}", dbClusterName, clusterId); return Optional.of(clusterId); } @@ -279,20 +277,21 @@ protected void setUpTableData() logger.info("----------------------------------------------------"); String hbaseLambdaName = "integ-hbase-" + UUID.randomUUID(); - AWSLambda lambdaClient = AWSLambdaClientBuilder.defaultClient(); + LambdaClient lambdaClient = LambdaClient.create(); CloudFormationClient cloudFormationHbaseClient = new CloudFormationClient(getHbaseLambdaStack(hbaseLambdaName)); try { // Create the Lambda function. cloudFormationHbaseClient.createStack(); // Invoke the Lambda function. - lambdaClient.invoke(new InvokeRequest() - .withFunctionName(hbaseLambdaName) - .withInvocationType(InvocationType.RequestResponse)); + lambdaClient.invoke(InvokeRequest.builder() + .functionName(hbaseLambdaName) + .invocationType(InvocationType.REQUEST_RESPONSE) + .build()); } finally { // Delete the Lambda function. 
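The getClusterData changes above keep the same marker-based pagination; only the request building and accessors change. The loop on its own (cluster name is illustrative):

    String marker = null;
    Optional<String> clusterId = Optional.empty();
    try (EmrClient emr = EmrClient.create()) {
        do {
            ListClustersResponse page = emr.listClusters(
                    ListClustersRequest.builder().marker(marker).build());
            clusterId = page.clusters().stream()
                    .filter(summary -> summary.name().equals("integ-hbase-cluster"))
                    .map(ClusterSummary::id)
                    .findFirst();
            marker = page.marker();    // null once the last page has been read
        } while (!clusterId.isPresent() && marker != null);
    }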
cloudFormationHbaseClient.deleteStack(); - lambdaClient.shutdown(); + lambdaClient.close(); } } @@ -376,13 +375,13 @@ public void selectColumnWithPredicateIntegTest() String query = String .format("select \"info:lead_actor\" from %s.%s.%s where \"movie:title\" = 'Aliens';", lambdaFunctionName, hbaseDbName, hbaseTableName); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List actors = new ArrayList<>(); - rows.forEach(row -> actors.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> actors.add(row.data().get(0).varCharValue())); logger.info("Actors: {}", actors); assertEquals("Wrong number of DB records found.", 1, actors.size()); assertTrue("Actor not found: Sigourney Weaver.", actors.contains("Sigourney Weaver")); @@ -397,13 +396,13 @@ public void selectIntegerTypeTest() String query = String.format("select \"datatype:int_type\" from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Integer.parseInt(row.getData().get(0).getVarCharValue().split("\\.")[0]))); + rows.forEach(row -> values.add(Integer.parseInt(row.data().get(0).varCharValue().split("\\.")[0]))); logger.info("Titles: {}", values); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Integer not found: " + TEST_DATATYPES_INT_VALUE, values.contains(TEST_DATATYPES_INT_VALUE)); @@ -418,13 +417,13 @@ public void selectVarcharTypeTest() String query = String.format("select \"datatype:varchar_type\" from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> values.add(row.data().get(0).varCharValue())); logger.info("Titles: {}", values); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Varchar not found: " + TEST_DATATYPES_VARCHAR_VALUE, values.contains(TEST_DATATYPES_VARCHAR_VALUE)); @@ -439,13 +438,13 @@ public void selectBooleanTypeTest() String query = String.format("select \"datatype:boolean_type\" from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Boolean.valueOf(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(Boolean.valueOf(row.data().get(0).varCharValue()))); logger.info("Titles: {}", values); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Boolean not found: " + TEST_DATATYPES_BOOLEAN_VALUE, values.contains(TEST_DATATYPES_BOOLEAN_VALUE)); @@ -460,13 +459,13 @@ public void selectSmallintTypeTest() String query = String.format("select \"datatype:smallint_type\" from %s.%s.%s;", 
lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Short.valueOf(row.getData().get(0).getVarCharValue().split("\\.")[0]))); + rows.forEach(row -> values.add(Short.valueOf(row.data().get(0).varCharValue().split("\\.")[0]))); logger.info("Titles: {}", values); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Smallint not found: " + TEST_DATATYPES_SHORT_VALUE, values.contains(TEST_DATATYPES_SHORT_VALUE)); @@ -481,13 +480,13 @@ public void selectBigintTypeTest() String query = String.format("select \"datatype:bigint_type\" from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Long.valueOf(row.getData().get(0).getVarCharValue().split("\\.")[0]))); + rows.forEach(row -> values.add(Long.valueOf(row.data().get(0).varCharValue().split("\\.")[0]))); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Bigint not found: " + TEST_DATATYPES_LONG_VALUE, values.contains(TEST_DATATYPES_LONG_VALUE)); } @@ -501,13 +500,13 @@ public void selectFloat4TypeTest() String query = String.format("select \"datatype:float4_type\" from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Float.valueOf(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(Float.valueOf(row.data().get(0).varCharValue()))); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Float4 not found: " + TEST_DATATYPES_SINGLE_PRECISION_VALUE, values.contains(TEST_DATATYPES_SINGLE_PRECISION_VALUE)); } @@ -521,13 +520,13 @@ public void selectFloat8TypeTest() String query = String.format("select \"datatype:float8_type\" from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(Double.valueOf(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(Double.valueOf(row.data().get(0).varCharValue()))); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Float8 not found: " + TEST_DATATYPES_DOUBLE_PRECISION_VALUE, values.contains(TEST_DATATYPES_DOUBLE_PRECISION_VALUE)); } @@ -541,13 +540,13 @@ public void selectDateTypeTest() String query = String.format("select \"datatype:date_type\" from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // 
Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); - rows.forEach(row -> values.add(LocalDate.parse(row.getData().get(0).getVarCharValue()))); + rows.forEach(row -> values.add(LocalDate.parse(row.data().get(0).varCharValue()))); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Date not found: " + TEST_DATATYPES_DATE_VALUE, values.contains(LocalDate.parse(TEST_DATATYPES_DATE_VALUE))); } @@ -561,15 +560,15 @@ public void selectTimestampTypeTest() String query = String.format("select \"datatype:timestamp_type\" from %s.%s.%s;", lambdaFunctionName, INTEG_TEST_DATABASE_NAME, TEST_DATATYPES_TABLE_NAME); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List values = new ArrayList<>(); // for some reason, timestamps lose their 'T'. - rows.forEach(row -> values.add(LocalDateTime.parse(row.getData().get(0).getVarCharValue().replace(' ', 'T')))); - logger.info(rows.get(0).getData().get(0).getVarCharValue()); + rows.forEach(row -> values.add(LocalDateTime.parse(row.data().get(0).varCharValue().replace(' ', 'T')))); + logger.info(rows.get(0).data().get(0).varCharValue()); assertEquals("Wrong number of DB records found.", 1, values.size()); assertTrue("Date not found: " + TEST_DATATYPES_TIMESTAMP_VALUE, values.contains(LocalDateTime.parse(TEST_DATATYPES_TIMESTAMP_VALUE))); } diff --git a/athena-hortonworks-hive/Dockerfile b/athena-hortonworks-hive/Dockerfile new file mode 100644 index 0000000000..3a68e6d997 --- /dev/null +++ b/athena-hortonworks-hive/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-hortonworks-hive-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-hortonworks-hive-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.hortonworks.HiveMuxCompositeHandler" ] \ No newline at end of file diff --git a/athena-hortonworks-hive/athena-hortonworks-hive.yaml b/athena-hortonworks-hive/athena-hortonworks-hive.yaml index 5ea4c07bce..8f941be498 100644 --- a/athena-hortonworks-hive/athena-hortonworks-hive.yaml +++ b/athena-hortonworks-hive/athena-hortonworks-hive.yaml @@ -69,10 +69,9 @@ Resources: spill_prefix: !Ref SpillPrefix default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.hortonworks.HiveMuxCompositeHandler" - CodeUri: "./target/athena-hortonworks-hive-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-hortonworks-hive:2022.47.1' Description: "Enables Amazon Athena to communicate with Hortonworks Hive using JDBC" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-hortonworks-hive/pom.xml b/athena-hortonworks-hive/pom.xml index 1b67ad4b8c..dc11525b4c 100644 --- a/athena-hortonworks-hive/pom.xml +++ b/athena-hortonworks-hive/pom.xml @@ -48,12 +48,18 @@ test - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git 
a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMetadataHandler.java b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMetadataHandler.java index 4a42dca536..e3b49989ad 100644 --- a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMetadataHandler.java +++ b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMetadataHandler.java @@ -48,8 +48,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -60,6 +58,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -96,8 +96,8 @@ public HiveMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, ja @VisibleForTesting protected HiveMetadataHandler( DatabaseConnectionConfig databaseConnectionConfiguration, - AWSSecretsManager secretManager, - AmazonAthena athena, + SecretsManagerClient secretManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxMetadataHandler.java b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxMetadataHandler.java index fa13a931c4..81dfff872b 100644 --- a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxMetadataHandler.java +++ b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxMetadataHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -56,7 +56,7 @@ public HiveMuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected HiveMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected HiveMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandler.java 
b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandler.java index 6bc47f1687..aa676b7b99 100644 --- a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandler.java +++ b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public HiveMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - HiveMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + HiveMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandler.java b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandler.java index 47a7b235fd..1450634a1e 100644 --- a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandler.java +++ b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandler.java @@ -28,15 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -62,11 +59,11 @@ public HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java } public HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + 
this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new HiveQueryStringBuilder(HIVE_QUOTE_CHARACTER, new HiveFederationExpressionParser(HIVE_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMetadataHandlerTest.java b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMetadataHandlerTest.java index ff9065debe..c7acb9f360 100644 --- a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMetadataHandlerTest.java +++ b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMetadataHandlerTest.java @@ -28,10 +28,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; @@ -39,6 +35,10 @@ import org.junit.BeforeClass; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.*; import java.util.*; @@ -57,8 +57,8 @@ public class HiveMetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private BlockAllocator blockAllocator; @BeforeClass @@ -74,9 +74,9 @@ public void setup() this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new 
GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.hiveMetadataHandler = new HiveMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); diff --git a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxMetadataHandlerTest.java b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxMetadataHandlerTest.java index 5d93ab7427..eb04665d33 100644 --- a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxMetadataHandlerTest.java +++ b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxMetadataHandlerTest.java @@ -43,8 +43,8 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableRequest; import com.amazonaws.athena.connector.lambda.metadata.ListSchemasRequest; import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import static org.mockito.ArgumentMatchers.nullable; @@ -54,8 +54,8 @@ public class HiveMuxMetadataHandlerTest private HiveMetadataHandler hiveMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @BeforeClass @@ -68,8 +68,8 @@ public void setup() this.allocator = new BlockAllocatorImpl(); this.hiveMetadataHandler = Mockito.mock(HiveMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("metaHive", this.hiveMetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", HiveConstants.HIVE_NAME, diff --git a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandlerTest.java b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandlerTest.java index 9d78ee80e1..32dba90175 100644 --- a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandlerTest.java +++ b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandlerTest.java 
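The Secrets Manager stubbing rewritten throughout these tests keeps its Mockito.eq(...) matcher because v2 model objects are value objects with generated equals()/hashCode(), so two independently built GetSecretValueRequest instances with the same secretId compare equal. A self-contained sketch of the pattern these tests repeat (the class name is illustrative):

import org.mockito.Mockito;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;

public class SecretsManagerStubSketch
{
    public static void main(String[] args)
    {
        // SecretsManagerClient is an interface in v2, which keeps it easy to mock.
        SecretsManagerClient secretsManager = Mockito.mock(SecretsManagerClient.class);
        Mockito.when(secretsManager.getSecretValue(
                        Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build())))
                .thenReturn(GetSecretValueResponse.builder()
                        .secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")
                        .build());

        // A freshly built request with the same secretId matches by value.
        GetSecretValueResponse response = secretsManager.getSecretValue(
                GetSecretValueRequest.builder().secretId("testSecret").build());
        System.out.println(response.secretString());
    }
}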
@@ -29,15 +29,15 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.mockito.Mockito; import org.testng.Assert; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -50,9 +50,9 @@ public class HiveMuxRecordHandlerTest private Map recordHandlerMap; private HiveRecordHandler hiveRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @BeforeClass @@ -64,9 +64,9 @@ public void setup() { this.hiveRecordHandler = Mockito.mock(HiveRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("recordHive", this.hiveRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", HiveConstants.HIVE_NAME, diff --git a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandlerTest.java b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandlerTest.java index 679bd11228..c45cfaf7c4 100644 --- a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandlerTest.java +++ b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandlerTest.java @@ -46,13 +46,13 @@ import com.amazonaws.athena.connector.lambda.domain.predicate.Range; import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet; import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import 
software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import static com.amazonaws.athena.connectors.hortonworks.HiveConstants.HIVE_QUOTE_CHARACTER; import static org.mockito.ArgumentMatchers.any; @@ -64,18 +64,18 @@ public class HiveRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-jdbc/pom.xml b/athena-jdbc/pom.xml index b46fb4d096..06d59b115c 100644 --- a/athena-jdbc/pom.xml +++ b/athena-jdbc/pom.xml @@ -9,60 +9,6 @@ athena-jdbc 2022.47.1 - - com.amazonaws - jmespath-java - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - - - com.amazonaws - aws-java-sdk-core - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - com.amazonaws aws-athena-federation-sdk @@ -155,9 +101,9 @@ - com.amazonaws - aws-java-sdk-redshift - ${aws-sdk.version} + software.amazon.awssdk + redshift + ${aws-sdk-v2.version} test @@ -167,12 +113,18 @@ ${aws-cdk.version} test - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcMetadataHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcMetadataHandler.java index c20fa505da..992207a345 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcMetadataHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcMetadataHandler.java @@ -39,10 +39,10 @@ import 
com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -64,8 +64,8 @@ public class MultiplexingJdbcMetadataHandler * @param metadataHandlerMap catalog -> JdbcMetadataHandler */ protected MultiplexingJdbcMetadataHandler( - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandler.java index e458b7cbcd..e2cb4f227c 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandler.java @@ -30,12 +30,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -61,9 +61,9 @@ public MultiplexingJdbcRecordHandler(JdbcRecordHandlerFactory jdbcRecordHandlerF @VisibleForTesting protected MultiplexingJdbcRecordHandler( - AmazonS3 amazonS3, - AWSSecretsManager secretsManager, - AmazonAthena athena, + S3Client amazonS3, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java index bb4154e6dd..f901f8bbfa 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java @@ -45,8 +45,6 @@ import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough; import com.amazonaws.athena.connectors.jdbc.splits.Splitter; import com.amazonaws.athena.connectors.jdbc.splits.SplitterFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import 
com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -58,6 +56,8 @@ import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -115,8 +115,8 @@ protected JdbcMetadataHandler( @VisibleForTesting protected JdbcMetadataHandler( DatabaseConnectionConfig databaseConnectionConfig, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java index d1224b68f1..b58528d635 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java @@ -54,9 +54,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.connection.RdsSecretsCredentialProvider; import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.holders.NullableBigIntHolder; @@ -76,6 +73,9 @@ import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Array; import java.sql.Connection; @@ -111,9 +111,9 @@ protected JdbcRecordHandler(String sourceType, java.util.Map con } protected JdbcRecordHandler( - AmazonS3 amazonS3, - AWSSecretsManager secretsManager, - AmazonAthena athena, + S3Client amazonS3, + SecretsManagerClient secretsManager, + AthenaClient athena, DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) diff --git a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcMetadataHandlerTest.java b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcMetadataHandlerTest.java index 1c4c2cea54..484ffa93b4 100644 --- a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcMetadataHandlerTest.java +++ b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcMetadataHandlerTest.java @@ -32,11 +32,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import 
software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -49,8 +49,8 @@ public class MultiplexingJdbcMetadataHandlerTest private JdbcMetadataHandler fakeDatabaseHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -62,8 +62,8 @@ public void setup() //Mockito.when(this.allocator.createBlock(nullable(Schema.class))).thenReturn(Mockito.mock(Block.class)); this.fakeDatabaseHandler = Mockito.mock(JdbcMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.fakeDatabaseHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", diff --git a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandlerTest.java b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandlerTest.java index 46eac3ba57..60e229d6f7 100644 --- a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandlerTest.java +++ b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandlerTest.java @@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -46,9 +46,9 @@ public class MultiplexingJdbcRecordHandlerTest private Map recordHandlerMap; private JdbcRecordHandler fakeJdbcRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -57,9 +57,9 @@ public void setup() { this.fakeJdbcRecordHandler = Mockito.mock(JdbcRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("fakedatabase", this.fakeJdbcRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = 
Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", diff --git a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandlerTest.java b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandlerTest.java index 35bb083715..5d2d5f1291 100644 --- a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandlerTest.java +++ b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandlerTest.java @@ -39,15 +39,15 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.Connection; import java.sql.ResultSet; @@ -70,8 +70,8 @@ public class JdbcMetadataHandlerTest private FederatedIdentity federatedIdentity; private Connection connection; private BlockAllocator blockAllocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private ResultSet resultSetName; @Before @@ -82,9 +82,9 @@ public void setup() this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(connection.getCatalog()).thenReturn("testCatalog"); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", "fakedatabase://jdbc:fakedatabase://hostname/${testSecret}", "testSecret"); this.jdbcMetadataHandler = new JdbcMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, 
jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()) diff --git a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandlerTest.java b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandlerTest.java index df6d6f1e42..cfbeba5602 100644 --- a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandlerTest.java +++ b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandlerTest.java @@ -39,20 +39,21 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.holders.NullableFloat8Holder; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; -import org.mockito.stubbing.Answer; +import org.mockito.invocation.InvocationOnMock; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.io.ByteArrayInputStream; import java.nio.charset.StandardCharsets; @@ -75,9 +76,9 @@ public class JdbcRecordHandlerTest private JdbcRecordHandler jdbcRecordHandler; private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private FederatedIdentity federatedIdentity; private PreparedStatement preparedStatement; @@ -89,11 +90,11 @@ public void setup() this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new 
GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.preparedStatement = Mockito.mock(PreparedStatement.class); Mockito.when(this.connection.prepareStatement("someSql")).thenReturn(this.preparedStatement); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", @@ -143,15 +144,16 @@ public void readWithConstraint() BlockSpiller s3Spiller = new S3BlockSpiller(this.amazonS3, spillConfig, allocator, fieldSchema, constraintEvaluator, com.google.common.collect.ImmutableMap.of()); ReadRecordsRequest readRecordsRequest = new ReadRecordsRequest(this.federatedIdentity, "testCatalog", "testQueryId", inputTableName, fieldSchema, splitBuilder.build(), constraints, 1024, 1024); - Mockito.when(amazonS3.putObject(any())).thenAnswer((Answer) invocation -> { - ByteArrayInputStream byteArrayInputStream = (ByteArrayInputStream) ((PutObjectRequest) invocation.getArguments()[0]).getInputStream(); - int n = byteArrayInputStream.available(); - byte[] bytes = new byte[n]; - byteArrayInputStream.read(bytes, 0, n); - String data = new String(bytes, StandardCharsets.UTF_8); - Assert.assertTrue(data.contains("testVal1") || data.contains("testVal2") || data.contains("testPartitionValue")); - return new PutObjectResult(); - }); + Mockito.when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) + .thenAnswer((InvocationOnMock invocationOnMock) -> { + ByteArrayInputStream inputStream = (ByteArrayInputStream) ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); + int n = inputStream.available(); + byte[] bytes = new byte[n]; + inputStream.read(bytes, 0, n); + String data = new String(bytes, StandardCharsets.UTF_8); + Assert.assertTrue(data.contains("testVal1") || data.contains("testVal2") || data.contains("testPartitionValue")); + return PutObjectResponse.builder().build(); + }); this.jdbcRecordHandler.readWithConstraint(s3Spiller, readRecordsRequest, queryStatusChecker); } diff --git a/athena-kafka/Dockerfile b/athena-kafka/Dockerfile new file mode 100644 index 0000000000..fbab927e79 --- /dev/null +++ b/athena-kafka/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-kafka-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-kafka-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.kafka.KafkaCompositeHandler" ] \ No newline at end of file diff --git a/athena-kafka/athena-kafka.yaml b/athena-kafka/athena-kafka.yaml index 0d6949ee55..27fd31b9c2 100644 --- a/athena-kafka/athena-kafka.yaml +++ b/athena-kafka/athena-kafka.yaml @@ -101,10 +101,9 @@ Resources: schema_registry_url: !Ref SchemaRegistryUrl auth_type: !Ref AuthType FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.kafka.KafkaCompositeHandler" - CodeUri: "./target/athena-kafka-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-kafka:2022.47.1' Description: "Enables Amazon Athena to communicate with Kafka clusters" 
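The readWithConstraint() mock rewritten above reflects the v2 S3 signature change: the object payload moves out of PutObjectRequest into a separate RequestBody parameter, so the test now recovers the spilled bytes from the RequestBody's content stream provider rather than from PutObjectRequest.getInputStream(). A minimal sketch of that recovery path (the class name and payload are illustrative):

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

import software.amazon.awssdk.core.sync.RequestBody;

public class RequestBodySketch
{
    public static void main(String[] args) throws IOException
    {
        // In the test, this RequestBody is the second argument captured from
        // s3Client.putObject(PutObjectRequest, RequestBody).
        RequestBody body = RequestBody.fromString("testVal1", StandardCharsets.UTF_8);

        // contentStreamProvider().newStream() yields a fresh InputStream over
        // the payload, which the assertion can read and inspect.
        try (InputStream stream = body.contentStreamProvider().newStream()) {
            String data = new String(stream.readAllBytes(), StandardCharsets.UTF_8);
            System.out.println(data);
        }
    }
}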
- Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory Role: !If [NotHasLambdaRole, !GetAtt FunctionRole.Arn, !Ref LambdaRoleARN] diff --git a/athena-kafka/pom.xml b/athena-kafka/pom.xml index d9b4c7a2d1..7618263568 100644 --- a/athena-kafka/pom.xml +++ b/athena-kafka/pom.xml @@ -94,11 +94,6 @@ guava ${guava.version} - - com.amazonaws - aws-java-sdk-sts - ${aws-sdk.version} - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandler.java b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandler.java index d569f8dfe6..169fd1be81 100644 --- a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandler.java +++ b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandler.java @@ -46,8 +46,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -55,6 +53,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -114,8 +114,8 @@ public MySqlMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, J @VisibleForTesting protected MySqlMetadataHandler( DatabaseConnectionConfig databaseConnectionConfig, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxMetadataHandler.java b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxMetadataHandler.java index cbeaf1d92e..c31ce81f43 100644 --- a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxMetadataHandler.java +++ b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxMetadataHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -58,7 +58,7 @@ public MySqlMuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected MySqlMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected MySqlMuxMetadataHandler(SecretsManagerClient secretsManager, 
AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxRecordHandler.java b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxRecordHandler.java index f3bf4ab940..a159921eed 100644 --- a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxRecordHandler.java +++ b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public MySqlMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - MySqlMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + MySqlMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java index 57a60f24e8..8acf177306 100644 --- a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java +++ b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java @@ -29,17 +29,14 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -79,13 +76,13 @@ public MySqlRecordHandler(DatabaseConnectionConfig 
databaseConnectionConfig, jav public MySqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new MySqlQueryStringBuilder(MYSQL_QUOTE_CHARACTER, new MySqlFederationExpressionParser(MYSQL_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - MySqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final AWSSecretsManager secretsManager, - final AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + MySqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, + final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandlerTest.java b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandlerTest.java index 1f44feac4f..129067e4b8 100644 --- a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandlerTest.java +++ b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandlerTest.java @@ -40,16 +40,16 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.Connection; import java.sql.PreparedStatement; @@ -78,8 +78,8 @@ public class MySqlMetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private BlockAllocator blockAllocator; @Before @@ -89,9 +89,9 @@ public void setup() this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); 
Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.mySqlMetadataHandler = new MySqlMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); this.blockAllocator = Mockito.mock(BlockAllocator.class); diff --git a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcMetadataHandlerTest.java b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcMetadataHandlerTest.java index e09a78f2f0..d74d150cb4 100644 --- a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcMetadataHandlerTest.java +++ b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcMetadataHandlerTest.java @@ -32,11 +32,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -49,8 +49,8 @@ public class MySqlMuxJdbcMetadataHandlerTest private MySqlMetadataHandler mySqlMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -62,8 +62,8 @@ public void setup() //Mockito.when(this.allocator.createBlock(nullable(Schema.class))).thenReturn(Mockito.mock(Block.class)); this.mySqlMetadataHandler = Mockito.mock(MySqlMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.mySqlMetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", diff --git 
a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcRecordHandlerTest.java b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcRecordHandlerTest.java index 1af05bbddb..ea9c543c0b 100644 --- a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcRecordHandlerTest.java +++ b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcRecordHandlerTest.java @@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -46,9 +46,9 @@ public class MySqlMuxJdbcRecordHandlerTest private Map recordHandlerMap; private MySqlRecordHandler mySqlRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -57,9 +57,9 @@ public void setup() { this.mySqlRecordHandler = Mockito.mock(MySqlRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("mysql", this.mySqlRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "mysql", diff --git a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandlerTest.java b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandlerTest.java index 06248de46b..157c08228b 100644 --- a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandlerTest.java +++ b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandlerTest.java @@ -36,9 +36,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -48,6 +45,9 @@ import org.junit.Before; import org.junit.Test; import 
org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -66,17 +66,17 @@ public class MySqlRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/integ/MySqlIntegTest.java b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/integ/MySqlIntegTest.java index 7e2dc082e0..c5f3cb7bbc 100644 --- a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/integ/MySqlIntegTest.java +++ b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/integ/MySqlIntegTest.java @@ -26,12 +26,6 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo; import com.amazonaws.athena.connectors.jdbc.integ.JdbcTableUtils; -import com.amazonaws.services.athena.model.Row; -import com.amazonaws.services.rds.AmazonRDS; -import com.amazonaws.services.rds.AmazonRDSClientBuilder; -import com.amazonaws.services.rds.model.DescribeDBInstancesRequest; -import com.amazonaws.services.rds.model.DescribeDBInstancesResult; -import com.amazonaws.services.rds.model.Endpoint; import com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -53,6 +47,11 @@ import software.amazon.awscdk.services.rds.MysqlEngineVersion; import software.amazon.awscdk.services.rds.StorageType; import software.amazon.awscdk.services.secretsmanager.Secret; +import software.amazon.awssdk.services.athena.model.Row; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.rds.model.DescribeDbInstancesRequest; +import software.amazon.awssdk.services.rds.model.DescribeDbInstancesResponse; +import software.amazon.awssdk.services.rds.model.Endpoint; import java.util.ArrayList; import java.util.Collections; @@ -195,14 +194,14 @@ private Stack getMySqlStack() */ private Endpoint getInstanceData() { - AmazonRDS rdsClient = AmazonRDSClientBuilder.defaultClient(); + RdsClient rdsClient = RdsClient.create(); try { - DescribeDBInstancesResult instancesResult = rdsClient.describeDBInstances(new DescribeDBInstancesRequest() - .withDBInstanceIdentifier(dbInstanceName)); - return instancesResult.getDBInstances().get(0).getEndpoint(); + DescribeDbInstancesResponse instancesResponse = rdsClient.describeDBInstances(DescribeDbInstancesRequest.builder() + 
.dbInstanceIdentifier(dbInstanceName).build()); + return instancesResponse.dbInstances().get(0).endpoint(); } finally { - rdsClient.shutdown(); + rdsClient.close(); } } @@ -213,7 +212,7 @@ private Endpoint getInstanceData() private void setEnvironmentVars(Endpoint endpoint) { String connectionString = String.format("mysql://jdbc:mysql://%s:%s/mysql?user=%s&password=%s", - endpoint.getAddress(), endpoint.getPort(), username, password); + endpoint.address(), endpoint.port(), username, password); String connectionStringTag = lambdaFunctionName + "_connection_string"; environmentVars.put("default", connectionString); environmentVars.put(connectionStringTag, connectionString); @@ -436,13 +435,13 @@ public void selectColumnWithPredicateIntegTest() String query = String.format("select title from %s.%s.%s where year > 2010;", lambdaFunctionName, mysqlDbName, mysqlTableMovies); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List titles = new ArrayList<>(); - rows.forEach(row -> titles.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> titles.add(row.data().get(0).varCharValue())); logger.info("Titles: {}", titles); assertEquals("Wrong number of DB records found.", 1, titles.size()); assertTrue("Movie title not found: Interstellar.", titles.contains("Interstellar")); @@ -459,13 +458,13 @@ public void selectColumnBetweenDatesIntegTest() String query = String.format( "select first_name from %s.%s.%s where birthday between date('2005-10-01') and date('2005-10-31');", lambdaFunctionName, mysqlDbName, mysqlTableBday); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List names = new ArrayList<>(); - rows.forEach(row -> names.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> names.add(row.data().get(0).varCharValue())); logger.info("Names: {}", names); assertEquals("Wrong number of DB records found.", 1, names.size()); assertTrue("Name not found: Jane.", names.contains("Jane")); diff --git a/athena-neptune/Dockerfile b/athena-neptune/Dockerfile new file mode 100644 index 0000000000..c8573d87c1 --- /dev/null +++ b/athena-neptune/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-neptune-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-neptune-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.neptune.NeptuneCompositeHandler" ] \ No newline at end of file diff --git a/athena-neptune/athena-neptune.yaml b/athena-neptune/athena-neptune.yaml index 54ebf62cbb..114314291a 100644 --- a/athena-neptune/athena-neptune.yaml +++ b/athena-neptune/athena-neptune.yaml @@ -96,10 +96,9 @@ Resources: SERVICE_REGION: !Ref AWS::Region enable_caseinsensitivematch: !Ref EnableCaseInsensitiveMatch FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.neptune.NeptuneCompositeHandler" - CodeUri: "./target/athena-neptune-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-neptune:2022.47.1' Description: "Enables Amazon Athena to communicate with Neptune, making your 
Neptune graph data accessible via SQL." - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandler.java b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandler.java index ddb3759696..37c0c89c2e 100644 --- a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandler.java +++ b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandler.java @@ -44,12 +44,6 @@ import com.amazonaws.athena.connectors.neptune.propertygraph.PropertyGraphHandler; import com.amazonaws.athena.connectors.neptune.qpt.NeptuneQueryPassthrough; import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.GetTablesRequest; -import com.amazonaws.services.glue.model.GetTablesResult; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; @@ -59,6 +53,12 @@ import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.GetTablesRequest; +import software.amazon.awssdk.services.glue.model.GetTablesResponse; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; import java.util.HashSet; @@ -90,7 +90,7 @@ public class NeptuneMetadataHandler extends GlueMetadataHandler private final Logger logger = LoggerFactory.getLogger(NeptuneMetadataHandler.class); private static final String SOURCE_TYPE = "neptune"; // Used to denote the 'type' of this connector for diagnostic // purposes. - private final AWSGlue glue; + private final GlueClient glue; private final String glueDBName; private NeptuneConnection neptuneConnection = null; @@ -109,11 +109,11 @@ public NeptuneMetadataHandler(java.util.Map configOptions) @VisibleForTesting protected NeptuneMetadataHandler( - AWSGlue glue, + GlueClient glue, NeptuneConnection neptuneConnection, EncryptionKeyFactory keyFactory, - AWSSecretsManager awsSecretsManager, - AmazonAthena athena, + SecretsManagerClient awsSecretsManager, + AthenaClient athena, String spillBucket, String spillPrefix, java.util.Map configOptions) @@ -174,14 +174,15 @@ public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesReque logger.info("doListTables: enter - " + request); List tables = new ArrayList<>(); - GetTablesRequest getTablesRequest = new GetTablesRequest(); - getTablesRequest.setDatabaseName(request.getSchemaName()); + GetTablesRequest getTablesRequest = GetTablesRequest.builder() + .databaseName(request.getSchemaName()) + .build(); - GetTablesResult getTablesResult = glue.getTables(getTablesRequest); - List
<Table> glueTableList = getTablesResult.getTableList(); + GetTablesResponse getTablesResponse = glue.getTables(getTablesRequest); + List<Table>
glueTableList = getTablesResponse.tableList(); String schemaName = request.getSchemaName(); glueTableList.forEach(e -> { - tables.add(new TableName(schemaName, e.getName())); + tables.add(new TableName(schemaName, e.name())); }); return new ListTablesResponse(request.getCatalogName(), tables, null); diff --git a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandler.java b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandler.java index e6de3b070d..2456b3aa27 100644 --- a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandler.java +++ b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandler.java @@ -26,15 +26,12 @@ import com.amazonaws.athena.connectors.neptune.Enums.GraphType; import com.amazonaws.athena.connectors.neptune.propertygraph.PropertyGraphHandler; import com.amazonaws.athena.connectors.neptune.rdf.RDFHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; /** * This class is part of an tutorial that will walk you through how to build a @@ -65,18 +62,18 @@ public class NeptuneRecordHandler extends RecordHandler public NeptuneRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), + S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), NeptuneConnection.createConnection(configOptions), configOptions); } @VisibleForTesting protected NeptuneRecordHandler( - AmazonS3 amazonS3, - AWSSecretsManager secretsManager, - AmazonAthena amazonAthena, + S3Client amazonS3, + SecretsManagerClient secretsManager, + AthenaClient amazonAthena, NeptuneConnection neptuneConnection, java.util.Map configOptions) { diff --git a/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandlerTest.java b/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandlerTest.java index a6c31c88ae..f125037311 100644 --- a/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandlerTest.java +++ b/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneMetadataHandlerTest.java @@ -28,15 +28,6 @@ import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest; import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.GetTablesRequest; -import com.amazonaws.services.glue.model.GetTablesResult; -import com.amazonaws.services.glue.model.StorageDescriptor; -import 
com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.After; import org.junit.Before; @@ -46,6 +37,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.GetTablesRequest; +import software.amazon.awssdk.services.glue.model.GetTablesResponse; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; + import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -66,10 +66,7 @@ public class NeptuneMetadataHandlerTest extends TestBase { private static final Logger logger = LoggerFactory.getLogger(NeptuneMetadataHandlerTest.class); @Mock - private AWSGlue glue; - - @Mock - private GetTablesRequest glueReq = null; + private GlueClient glue; private NeptuneMetadataHandler handler = null; @@ -86,7 +83,7 @@ public void setUp() throws Exception { logger.info("setUpBefore - enter"); allocator = new BlockAllocatorImpl(); handler = new NeptuneMetadataHandler(glue,neptuneConnection, - new LocalKeyFactory(), mock(AWSSecretsManager.class), mock(AmazonAthena.class), "spill-bucket", + new LocalKeyFactory(), mock(SecretsManagerClient.class), mock(AthenaClient.class), "spill-bucket", "spill-prefix", com.google.common.collect.ImmutableMap.of()); logger.info("setUpBefore - exit"); } @@ -113,23 +110,19 @@ public void doListTables() { logger.info("doListTables - enter"); List
<Table> tables = new ArrayList<>
(); - Table table1 = new Table(); - table1.setName("table1"); - Table table2 = new Table(); - table2.setName("table2"); - Table table3 = new Table(); - table3.setName("table3"); + Table table1 = Table.builder().name("table1").build(); + Table table2 = Table.builder().name("table2").build(); + Table table3 = Table.builder().name("table3").build(); tables.add(table1); tables.add(table2); tables.add(table3); - GetTablesResult tableResult = new GetTablesResult(); - tableResult.setTableList(tables); + GetTablesResponse tableResponse = GetTablesResponse.builder().tableList(tables).build(); ListTablesRequest req = new ListTablesRequest(IDENTITY, "queryId", "default", "default", null, UNLIMITED_PAGE_SIZE_VALUE); - when(glue.getTables(nullable(GetTablesRequest.class))).thenReturn(tableResult); + when(glue.getTables(nullable(GetTablesRequest.class))).thenReturn(tableResponse); ListTablesResponse res = handler.doListTables(allocator, req); @@ -143,35 +136,33 @@ public void doGetTable() throws Exception { logger.info("doGetTable - enter"); - Table table = new Table(); - table.setName("table1"); - Map<String, String> expectedParams = new HashMap<>(); - expectedParams.put("sourceTable", table.getName()); - expectedParams.put("columnMapping", "col2=Col2,col3=Col3, col4=Col4"); - expectedParams.put("datetimeFormatMapping", "col2=someformat2, col1=someformat1 "); - - table.setParameters(expectedParams); List<Column> columns = new ArrayList<>(); - columns.add(new Column().withName("col1").withType("int").withComment("comment")); - columns.add(new Column().withName("col2").withType("bigint").withComment("comment")); - columns.add(new Column().withName("col3").withType("string").withComment("comment")); - columns.add(new Column().withName("col4").withType("timestamp").withComment("comment")); - columns.add(new Column().withName("col5").withType("date").withComment("comment")); - columns.add(new Column().withName("col6").withType("timestamptz").withComment("comment")); - columns.add(new Column().withName("col7").withType("timestamptz").withComment("comment")); - - StorageDescriptor storageDescriptor = new StorageDescriptor(); - storageDescriptor.setColumns(columns); - table.setStorageDescriptor(storageDescriptor); + columns.add(Column.builder().name("col1").type("int").comment("comment").build()); + columns.add(Column.builder().name("col2").type("bigint").comment("comment").build()); + columns.add(Column.builder().name("col3").type("string").comment("comment").build()); + columns.add(Column.builder().name("col4").type("timestamp").comment("comment").build()); + columns.add(Column.builder().name("col5").type("date").comment("comment").build()); + columns.add(Column.builder().name("col6").type("timestamptz").comment("comment").build()); + columns.add(Column.builder().name("col7").type("timestamptz").comment("comment").build()); + + StorageDescriptor storageDescriptor = StorageDescriptor.builder().columns(columns).build(); + Table table = Table.builder() + .name("table1") + .parameters(expectedParams) + .storageDescriptor(storageDescriptor) + .build(); + + expectedParams.put("sourceTable", table.name()); + expectedParams.put("columnMapping", "col2=Col2,col3=Col3, col4=Col4"); + expectedParams.put("datetimeFormatMapping", "col2=someformat2, col1=someformat1 "); GetTableRequest req = new GetTableRequest(IDENTITY, "queryId", "default", new TableName("schema1", "table1"), Collections.emptyMap()); - GetTableResult getTableResult = new GetTableResult(); - getTableResult.setTable(table); +
software.amazon.awssdk.services.glue.model.GetTableResponse getTableResponse = software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); - when(glue.getTable(nullable(com.amazonaws.services.glue.model.GetTableRequest.class))).thenReturn(getTableResult); + when(glue.getTable(nullable(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenReturn(getTableResponse); GetTableResponse res = handler.doGetTable(allocator, req); diff --git a/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandlerTest.java b/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandlerTest.java index 5f4ed5ea24..bde646b1d3 100644 --- a/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandlerTest.java +++ b/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandlerTest.java @@ -46,13 +46,6 @@ import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -77,6 +70,15 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -97,9 +99,9 @@ public class NeptuneRecordHandlerTest extends TestBase { private Schema schemaPGVertexForRead; private Schema schemaPGEdgeForRead; private Schema schemaPGQueryForRead; - private AmazonS3 amazonS3; - private AWSSecretsManager awsSecretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient awsSecretsManager; + private AthenaClient athena; private S3BlockSpillReader spillReader; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); private List mockS3Storage = new ArrayList<>(); @@ -164,34 +166,32 @@ public void setUp() { .build(); allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); - awsSecretsManager = mock(AWSSecretsManager.class); - athena = mock(AmazonAthena.class); + amazonS3 = mock(S3Client.class); + awsSecretsManager = mock(SecretsManagerClient.class); + athena = mock(AthenaClient.class); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream 
= ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); - ByteHolder byteHolder; - synchronized (mockS3Storage) { - byteHolder = mockS3Storage.get(0); - mockS3Storage.remove(0); - logger.info("getObject: total size " + mockS3Storage.size()); - } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream(new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; - }); + when(amazonS3.getObject(any(GetObjectRequest.class))) + .thenAnswer((InvocationOnMock invocationOnMock) -> { + ByteHolder byteHolder; + synchronized (mockS3Storage) { + byteHolder = mockS3Storage.get(0); + mockS3Storage.remove(0); + logger.info("getObject: total size " + mockS3Storage.size()); + } + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); + }); handler = new NeptuneRecordHandler(amazonS3, awsSecretsManager, athena, neptuneConnection, com.google.common.collect.ImmutableMap.of()); spillReader = new S3BlockSpillReader(amazonS3, allocator); diff --git a/athena-oracle/Dockerfile b/athena-oracle/Dockerfile new file mode 100644 index 0000000000..e85f8c566e --- /dev/null +++ b/athena-oracle/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-oracle-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-oracle-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.oracle.OracleMuxCompositeHandler" ] \ No newline at end of file diff --git a/athena-oracle/athena-oracle.yaml b/athena-oracle/athena-oracle.yaml index dd19543fff..e086cf82cb 100644 --- a/athena-oracle/athena-oracle.yaml +++ b/athena-oracle/athena-oracle.yaml @@ -82,10 +82,9 @@ Resources: default: !Ref DefaultConnectionString is_FIPS_Enabled: !Ref IsFIPSEnabled FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.oracle.OracleMuxCompositeHandler" - CodeUri: "./target/athena-oracle-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-oracle:2022.47.1' Description: "Enables Amazon Athena to communicate with ORACLE using JDBC" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-oracle/pom.xml b/athena-oracle/pom.xml index 5d5b2b6897..e727fd5e45 100644 --- a/athena-oracle/pom.xml +++ b/athena-oracle/pom.xml @@ -32,12 +32,18 @@ ojdbc8 23.5.0.24.07 - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git 
a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandler.java b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandler.java index 8570bc1df3..3932645fe7 100644 --- a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandler.java +++ b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandler.java @@ -52,8 +52,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; @@ -65,6 +63,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -125,8 +125,8 @@ public OracleMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, @VisibleForTesting protected OracleMetadataHandler( DatabaseConnectionConfig databaseConnectionConfig, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxMetadataHandler.java b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxMetadataHandler.java index 6bf4fac5b0..6b854ccc64 100644 --- a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxMetadataHandler.java +++ b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxMetadataHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -56,7 +56,7 @@ public OracleMuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected OracleMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected OracleMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxRecordHandler.java b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxRecordHandler.java index 02b873ec7c..9a1f0a09ae 100644 --- 
a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxRecordHandler.java +++ b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public OracleMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - OracleMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + OracleMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandler.java b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandler.java index 83c2d66654..c312b87f5b 100644 --- a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandler.java +++ b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandler.java @@ -28,17 +28,14 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -77,13 +74,13 @@ public OracleRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, ja public OracleRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new 
OracleQueryStringBuilder(ORACLE_QUOTE_CHARACTER, new OracleFederationExpressionParser(ORACLE_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - OracleRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final AWSSecretsManager secretsManager, - final AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + OracleRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, + final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandlerTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandlerTest.java index 63a4e6b7bd..c84dc15538 100644 --- a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandlerTest.java +++ b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandlerTest.java @@ -33,16 +33,16 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.Connection; import java.sql.PreparedStatement; @@ -72,8 +72,8 @@ public class OracleMetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() @@ -82,9 +82,9 @@ public void setup() this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + 
this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.oracleMetadataHandler = new OracleMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); } diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcMetadataHandlerTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcMetadataHandlerTest.java index 3520219b56..537cc0c969 100644 --- a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcMetadataHandlerTest.java +++ b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcMetadataHandlerTest.java @@ -32,13 +32,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.athena.connectors.oracle.OracleMetadataHandler; -import com.amazonaws.athena.connectors.oracle.OracleMuxMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -51,8 +49,8 @@ public class OracleMuxJdbcMetadataHandlerTest private OracleMetadataHandler oracleMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -64,8 +62,8 @@ public void setup() //Mockito.when(this.allocator.createBlock(nullable(Schema.class))).thenReturn(Mockito.mock(Block.class)); this.oracleMetadataHandler = Mockito.mock(OracleMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.oracleMetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcRecordHandlerTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcRecordHandlerTest.java index 84c01f1eaf..1ec10050b3 100644 --- a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcRecordHandlerTest.java +++ 
b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcRecordHandlerTest.java @@ -30,13 +30,13 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.oracle.OracleMuxRecordHandler; import com.amazonaws.athena.connectors.oracle.OracleRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -48,9 +48,9 @@ public class OracleMuxJdbcRecordHandlerTest private Map recordHandlerMap; private OracleRecordHandler oracleRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -59,9 +59,9 @@ public void setup() { this.oracleRecordHandler = Mockito.mock(OracleRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("oracle", this.oracleRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "oracle", diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandlerTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandlerTest.java index 7dfceed728..c95cf18e96 100644 --- a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandlerTest.java +++ b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandlerTest.java @@ -32,9 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -43,6 +40,9 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.Date; @@ -60,9 +60,9 @@ public class OracleRecordHandlerTest private Connection connection; private 
JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private static final String ORACLE_QUOTE_CHARACTER = "\""; @@ -71,9 +71,9 @@ public class OracleRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/integ/OracleIntegTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/integ/OracleIntegTest.java index a197daa374..26905abfc9 100644 --- a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/integ/OracleIntegTest.java +++ b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/integ/OracleIntegTest.java @@ -20,7 +20,6 @@ package com.amazonaws.athena.connectors.oracle.integ; import com.amazonaws.athena.connector.integ.data.TestConfig; -import com.amazonaws.services.athena.model.Row; import com.google.common.collect.ImmutableList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +34,7 @@ import software.amazon.awscdk.services.iam.Effect; import software.amazon.awscdk.services.iam.PolicyDocument; import software.amazon.awscdk.services.iam.PolicyStatement; +import software.amazon.awssdk.services.athena.model.Row; import static org.junit.Assert.assertEquals; @@ -223,13 +223,13 @@ public void fetchRangePartitionDataTest() logger.info("--------------------------------------------------"); String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\";", lambdaFunctionName, oracleDBName ,rangePartitionTable); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List rangePartitonData = new ArrayList<>(); - rows.forEach(row -> rangePartitonData.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> rangePartitonData.add(row.data().get(0).varCharValue())); logger.info("rangePartitonData: {}", rangePartitonData); } @@ -241,13 +241,13 @@ public void fetchAllDataTypeDataTest() logger.info("--------------------------------------------------"); String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\";", lambdaFunctionName, oracleDBName ,allDataTypeTable); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List allDataTypeData = new ArrayList<>(); - rows.forEach(row -> allDataTypeData.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> allDataTypeData.add(row.data().get(0).varCharValue())); logger.info("allDataTypeData: {}", allDataTypeData); } @@ -259,13 +259,13 @@ public void 
fetchListPartitionDataTest() logger.info("--------------------------------------------------"); String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\";", lambdaFunctionName, oracleDBName ,listPartitionTable); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List listPartitonData = new ArrayList<>(); - rows.forEach(row -> listPartitonData.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> listPartitonData.add(row.data().get(0).varCharValue())); logger.info("listPartitonData: {}", listPartitonData); } @@ -277,13 +277,13 @@ public void numberDataTypeWhereClauseTest() logger.info("--------------------------------------------------"); String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\" where NUMBER_TYPE=320;", lambdaFunctionName, oracleDBName ,allDataTypeTable); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List numberDataTypeData = new ArrayList<>(); - rows.forEach(row -> numberDataTypeData.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> numberDataTypeData.add(row.data().get(0).varCharValue())); logger.info("numberDataTypeData: {}", numberDataTypeData); } @@ -295,13 +295,13 @@ public void charDataTypeWhereClauseTest() logger.info("--------------------------------------------------"); String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\" where CHAR_TYPE='A';", lambdaFunctionName, oracleDBName ,allDataTypeTable); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List charDataTypeData = new ArrayList<>(); - rows.forEach(row -> charDataTypeData.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> charDataTypeData.add(row.data().get(0).varCharValue())); logger.info("charDataTypeData: {}", charDataTypeData); } @@ -313,13 +313,13 @@ public void dateDataTypeWhereClauseTest() logger.info("--------------------------------------------------"); String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\" where DATE_COL = date('2021-03-18');", lambdaFunctionName, oracleDBName ,allDataTypeTable); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List dateDataTypeData = new ArrayList<>(); - rows.forEach(row -> dateDataTypeData.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> dateDataTypeData.add(row.data().get(0).varCharValue())); logger.info("dateDataTypeData: {}", dateDataTypeData); } @@ -331,13 +331,13 @@ public void timestampDataTypeWhereClauseTest() logger.info("--------------------------------------------------"); String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\" where TIMESTAMP_WITH_3_FRAC_SEC_COL >= CAST('2021-03-18 09:00:00.123' AS TIMESTAMP);", lambdaFunctionName, oracleDBName ,allDataTypeTable); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List timestampDataTypeData = new ArrayList<>(); - 
rows.forEach(row -> timestampDataTypeData.add(row.getData().get(0).getVarCharValue()));
+        rows.forEach(row -> timestampDataTypeData.add(row.data().get(0).varCharValue()));
         logger.info("timestampDataTypeData: {}", timestampDataTypeData);
     }
@@ -349,13 +349,13 @@ public void varcharDataTypeWhereClauseTest()
         logger.info("--------------------------------------------------");
         String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\" where VARCHAR_10_COL ='ORACLEXPR';", lambdaFunctionName, oracleDBName ,allDataTypeTable);
-        List<Row> rows = startQueryExecution(query).getResultSet().getRows();
+        List<Row> rows = startQueryExecution(query).resultSet().rows();
         if (!rows.isEmpty()) {
             // Remove the column-header row
             rows.remove(0);
         }
         List<String> varcharDataTypeData = new ArrayList<>();
-        rows.forEach(row -> varcharDataTypeData.add(row.getData().get(0).getVarCharValue()));
+        rows.forEach(row -> varcharDataTypeData.add(row.data().get(0).varCharValue()));
         logger.info("varcharDataTypeData: {}", varcharDataTypeData);
     }
@@ -367,13 +367,13 @@ public void decimalDataTypeWhereClauseTest()
         logger.info("--------------------------------------------------");
         String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\" where NUMBER_3_SF_2_DP = 5.82;", lambdaFunctionName, oracleDBName ,allDataTypeTable);
-        List<Row> rows = startQueryExecution(query).getResultSet().getRows();
+        List<Row> rows = startQueryExecution(query).resultSet().rows();
         if (!rows.isEmpty()) {
             // Remove the column-header row
             rows.remove(0);
         }
         List<String> decimalDataTypeData = new ArrayList<>();
-        rows.forEach(row -> decimalDataTypeData.add(row.getData().get(0).getVarCharValue()));
+        rows.forEach(row -> decimalDataTypeData.add(row.data().get(0).varCharValue()));
         logger.info("decimalDataTypeData: {}", decimalDataTypeData);
     }
@@ -385,13 +385,13 @@ public void multiDataTypefilterClauseTest()
         logger.info("--------------------------------------------------");
         String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\" where DATE_COL= date('2021-03-18') and NUMBER_3_SF_2_DP = 5.82;", lambdaFunctionName, oracleDBName ,allDataTypeTable);
-        List<Row> rows = startQueryExecution(query).getResultSet().getRows();
+        List<Row> rows = startQueryExecution(query).resultSet().rows();
         if (!rows.isEmpty()) {
             // Remove the column-header row
             rows.remove(0);
         }
         List<String> multiDataTypeFilterData = new ArrayList<>();
-        rows.forEach(row -> multiDataTypeFilterData.add(row.getData().get(0).getVarCharValue()));
+        rows.forEach(row -> multiDataTypeFilterData.add(row.data().get(0).varCharValue()));
         logger.info("multiDataTypeFilterData: {}", multiDataTypeFilterData);
     }
@@ -404,13 +404,13 @@ public void floatDataTypeWhereClauseTest()
         logger.info("--------------------------------------------------");
         String query = String.format("select * from \"lambda:%s\".\"%s\".\"%s\" where float_col = 39840.0;", lambdaFunctionName, oracleDBName ,allDataTypeTable);
-        List<Row> rows = startQueryExecution(query).getResultSet().getRows();
+        List<Row> rows = startQueryExecution(query).resultSet().rows();
         if (!rows.isEmpty()) {
             // Remove the column-header row
             rows.remove(0);
         }
         List<String> floatDataTypeData = new ArrayList<>();
-        rows.forEach(row -> floatDataTypeData.add(row.getData().get(0).getVarCharValue()));
+        rows.forEach(row -> floatDataTypeData.add(row.data().get(0).varCharValue()));
         logger.info("floatDataTypeData: {}", floatDataTypeData);
     }
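[Reviewer note, illustration only -- not part of the patch] Every Oracle integration test above migrates the same result-iteration idiom, so it is worth spelling out once. A minimal sketch of the SDK v2 accessor pattern, assuming a software.amazon.awssdk.services.athena.model.GetQueryResultsResponse is already in hand; the header-row convention comes from the tests above, and the helper class itself is hypothetical:

import software.amazon.awssdk.services.athena.model.GetQueryResultsResponse;
import software.amazon.awssdk.services.athena.model.Row;

import java.util.ArrayList;
import java.util.List;

public final class AthenaResultUtils
{
    private AthenaResultUtils() {}

    // v2 drops the "get" prefix: resultSet().rows() replaces getResultSet().getRows(),
    // and row.data().get(0).varCharValue() replaces row.getData().get(0).getVarCharValue().
    public static List<String> firstColumnValues(GetQueryResultsResponse response)
    {
        // SDK v2 model objects return unmodifiable lists, so copy before mutating.
        List<Row> rows = new ArrayList<>(response.resultSet().rows());
        if (!rows.isEmpty()) {
            rows.remove(0); // Athena returns the column headers as the first row.
        }
        List<String> values = new ArrayList<>();
        rows.forEach(row -> values.add(row.data().get(0).varCharValue()));
        return values;
    }
}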
diff --git a/athena-postgresql/Dockerfile b/athena-postgresql/Dockerfile
new file mode 100644
index 0000000000..3376a994dc
--- /dev/null
+++ b/athena-postgresql/Dockerfile
@@ -0,0 +1,9 @@
+FROM public.ecr.aws/lambda/java:11
+
+# Copy function code and runtime dependencies from Maven layout
+COPY target/athena-postgresql-2022.47.1.jar ${LAMBDA_TASK_ROOT}
+# Unpack the jar
+RUN jar xf athena-postgresql-2022.47.1.jar
+
+# Command can be overwritten by providing a different command in the template directly.
+# No need to specify here (already defined in athena-postgresql.yaml because it has two different handlers)
\ No newline at end of file
diff --git a/athena-postgresql/athena-postgresql.yaml b/athena-postgresql/athena-postgresql.yaml
index ab68f4b22c..30553623ec 100644
--- a/athena-postgresql/athena-postgresql.yaml
+++ b/athena-postgresql/athena-postgresql.yaml
@@ -81,10 +81,11 @@ Resources:
           default: !Ref DefaultConnectionString
           default_scale: !Ref DefaultScale
       FunctionName: !Ref LambdaFunctionName
-      Handler: !Sub "com.amazonaws.athena.connectors.postgresql.${CompositeHandler}"
-      CodeUri: "./target/athena-postgresql-2022.47.1.jar"
+      PackageType: "Image"
+      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-postgresql:2022.47.1'
+      ImageConfig:
+        Command: [ !Sub "com.amazonaws.athena.connectors.postgresql.${CompositeHandler}" ]
       Description: "Enables Amazon Athena to communicate with PostgreSQL using JDBC"
-      Runtime: java11
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
       PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ]
diff --git a/athena-postgresql/pom.xml b/athena-postgresql/pom.xml
index 847b089a93..729080bbd9 100644
--- a/athena-postgresql/pom.xml
+++ b/athena-postgresql/pom.xml
@@ -39,12 +39,18 @@
            <version>${mockito.version}</version>
            <scope>test</scope>
        </dependency>
-       <dependency>
-           <groupId>com.amazonaws</groupId>
-           <artifactId>aws-java-sdk-rds</artifactId>
-           <version>${aws-sdk.version}</version>
+       <dependency>
+           <groupId>software.amazon.awssdk</groupId>
+           <artifactId>rds</artifactId>
+           <version>${aws-sdk-v2.version}</version>
            <scope>test</scope>
+           <exclusions>
+               <exclusion>
+                   <groupId>software.amazon.awssdk</groupId>
+                   <artifactId>netty-nio-client</artifactId>
+               </exclusion>
+           </exclusions>
        </dependency>
diff --git a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandler.java b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandler.java
index a765aa4e0f..64e3159440 100644
--- a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandler.java
+++ b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandler.java
@@ -47,8 +47,6 @@
 import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -58,6 +56,8 @@
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -119,8 +119,8 @@ public PostGreSqlMetadataHandler(DatabaseConnectionConfig databaseConnectionConf
     @VisibleForTesting
     protected PostGreSqlMetadataHandler(
             DatabaseConnectionConfig databaseConnectionConfig,
-            AWSSecretsManager secretsManager,
-            AmazonAthena athena,
+            SecretsManagerClient secretsManager,
+            AthenaClient athena,
             JdbcConnectionFactory jdbcConnectionFactory,
             java.util.Map<String, String> configOptions)
     {
diff --git a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxMetadataHandler.java b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxMetadataHandler.java
index 3f6ae65f6e..f6667acfbf 100644
--- a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxMetadataHandler.java
+++ b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxMetadataHandler.java
@@ -25,9 +25,9 @@
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import org.apache.arrow.util.VisibleForTesting;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Map;
@@ -58,7 +58,7 @@ public PostGreSqlMuxMetadataHandler(java.util.Map<String, String> configOptions)
     }
     @VisibleForTesting
-    protected PostGreSqlMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory,
+    protected PostGreSqlMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory,
             Map<String, JdbcMetadataHandler> metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map<String, String> configOptions)
     {
         super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions);
diff --git a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxRecordHandler.java b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxRecordHandler.java
index 61198d3fec..8b98b0813f 100644
--- a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxRecordHandler.java
+++ b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxRecordHandler.java
@@ -25,10 +25,10 @@
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Map;
@@ -57,7 +57,7 @@ public PostGreSqlMuxRecordHandler(java.util.Map<String, String> configOptions)
     }
     @VisibleForTesting
-    PostGreSqlMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory,
+    PostGreSqlMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory,
             DatabaseConnectionConfig databaseConnectionConfig, Map<String, JdbcRecordHandler> recordHandlerMap, java.util.Map<String, String> configOptions)
     {
         super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions);
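[Reviewer note, illustration only -- not part of the patch] The constructor changes in these handler diffs all follow one mechanical v1-to-v2 recipe: XyzClientBuilder.defaultClient() becomes XyzClient.create(), requests become immutable objects assembled with builder()...build(), response getters lose their "get" prefix, and shutdown() becomes close(). A self-contained sketch of the pattern against SecretsManagerClient (the secret name is illustrative only):

import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;

public final class SecretsManagerV2Sketch
{
    private SecretsManagerV2Sketch() {}

    public static String fetchSecret()
    {
        // v1: AWSSecretsManager client = AWSSecretsManagerClientBuilder.defaultClient();
        // v2 clients are AutoCloseable, so try-with-resources replaces the v1 shutdown() call.
        try (SecretsManagerClient client = SecretsManagerClient.create()) {
            // v1: new GetSecretValueRequest().withSecretId("testSecret")
            GetSecretValueRequest request = GetSecretValueRequest.builder()
                    .secretId("testSecret") // illustrative secret name
                    .build();
            // v1: GetSecretValueResult.getSecretString() -> v2: GetSecretValueResponse.secretString()
            GetSecretValueResponse response = client.getSecretValue(request);
            return response.secretString();
        }
    }
}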
diff --git a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java
index 877c450a05..fe4837c95a 100644
--- a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java
+++ b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java
@@ -29,17 +29,14 @@
 import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.s3.AmazonS3ClientBuilder;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
-import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.commons.lang3.Validate;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -71,13 +68,14 @@ public PostGreSqlRecordHandler(java.util.Map<String, String> configOptions)
     public PostGreSqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map<String, String> configOptions)
     {
-        this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(),
-            new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(POSTGRESQL_DRIVER_CLASS, POSTGRESQL_DEFAULT_PORT)), new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER)), configOptions);
+        this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(),
+            new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(POSTGRESQL_DRIVER_CLASS, POSTGRESQL_DEFAULT_PORT)),
+            new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER)), configOptions);
     }
     @VisibleForTesting
-    protected PostGreSqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, AWSSecretsManager secretsManager,
-        AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map<String, String> configOptions)
+    protected PostGreSqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager,
+        AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map<String, String> configOptions)
     {
         super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions);
         this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null");
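[Reviewer note, illustration only -- not part of the patch] The default constructor above swaps AmazonS3ClientBuilder.defaultClient() and friends for S3Client.create(). In SDK v2, create() is shorthand for builder().build() with the default region and credential providers; where a connector needs explicit configuration, the v2 builder is the counterpart of the v1 .standard().withRegion(...) chain. A sketch (the region choice is illustrative):

import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;

public final class ClientBuilderV2Sketch
{
    private ClientBuilderV2Sketch() {}

    public static S3Client defaultClient()
    {
        return S3Client.create(); // default region/credential chain
    }

    public static S3Client regionPinnedClient()
    {
        // Counterpart of v1 AmazonS3ClientBuilder.standard().withRegion("us-east-1").build()
        return S3Client.builder()
                .region(Region.US_EAST_1) // illustrative region
                .build();
    }
}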
diff --git a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandlerTest.java b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandlerTest.java
index 557619ad7f..97f72e4f8e 100644
--- a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandlerTest.java
+++ b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMetadataHandlerTest.java
@@ -40,10 +40,6 @@
 import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
-import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest;
-import com.amazonaws.services.secretsmanager.model.GetSecretValueResult;
 import com.google.common.collect.ImmutableMap;
 import org.apache.arrow.vector.types.DateUnit;
 import org.apache.arrow.vector.types.FloatingPointPrecision;
@@ -56,6 +52,10 @@
 import org.mockito.Mockito;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
+import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
+import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -86,8 +86,8 @@ public class PostGreSqlMetadataHandlerTest
     private JdbcConnectionFactory jdbcConnectionFactory;
     private Connection connection;
     private FederatedIdentity federatedIdentity;
-    private AWSSecretsManager secretsManager;
-    private AmazonAthena athena;
+    private SecretsManagerClient secretsManager;
+    private AthenaClient athena;
     @Before
     public void setup()
@@ -96,8 +96,8 @@ public void setup()
         this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class);
         this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS);
         Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection);
-        this.secretsManager = Mockito.mock(AWSSecretsManager.class);
-        Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}"));
+        this.secretsManager = Mockito.mock(SecretsManagerClient.class);
+        Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build());
         this.postGreSqlMetadataHandler = new PostGreSqlMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of());
         this.federatedIdentity = Mockito.mock(FederatedIdentity.class);
     }
diff --git a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcMetadataHandlerTest.java b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcMetadataHandlerTest.java
index f21d694c19..109706c13b 100644
---
a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcMetadataHandlerTest.java +++ b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcMetadataHandlerTest.java @@ -32,11 +32,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -49,8 +49,8 @@ public class PostGreSqlMuxJdbcMetadataHandlerTest private PostGreSqlMetadataHandler postGreSqlMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -62,8 +62,8 @@ public void setup() //Mockito.when(this.allocator.createBlock(nullable(Schema.class))).thenReturn(Mockito.mock(Block.class)); this.postGreSqlMetadataHandler = Mockito.mock(PostGreSqlMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("postgres", this.postGreSqlMetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "postgres", diff --git a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcRecordHandlerTest.java b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcRecordHandlerTest.java index ba498d8f97..eadd042db0 100644 --- a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcRecordHandlerTest.java +++ b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcRecordHandlerTest.java @@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -46,9 +46,9 @@ public class PostGreSqlMuxJdbcRecordHandlerTest private Map recordHandlerMap; private 
PostGreSqlRecordHandler postGreSqlRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -57,9 +57,9 @@ public void setup() { this.postGreSqlRecordHandler = Mockito.mock(PostGreSqlRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("postgres", this.postGreSqlRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "postgres", diff --git a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java index c9827a8464..e54c337d8e 100644 --- a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java +++ b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java @@ -33,9 +33,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -49,6 +46,9 @@ import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.math.BigDecimal; import java.sql.Connection; @@ -71,18 +71,18 @@ public class PostGreSqlRecordHandlerTest extends TestBase private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private MockedStatic mockedPostGreSqlMetadataHandler; @Before public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); 
Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/integ/PostGreSqlIntegTest.java b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/integ/PostGreSqlIntegTest.java index 3fd840c1a3..68f3913340 100644 --- a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/integ/PostGreSqlIntegTest.java +++ b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/integ/PostGreSqlIntegTest.java @@ -26,12 +26,6 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo; import com.amazonaws.athena.connectors.jdbc.integ.JdbcTableUtils; -import com.amazonaws.services.athena.model.Row; -import com.amazonaws.services.rds.AmazonRDS; -import com.amazonaws.services.rds.AmazonRDSClientBuilder; -import com.amazonaws.services.rds.model.DescribeDBInstancesRequest; -import com.amazonaws.services.rds.model.DescribeDBInstancesResult; -import com.amazonaws.services.rds.model.Endpoint; import com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -53,6 +47,11 @@ import software.amazon.awscdk.services.rds.PostgresInstanceEngineProps; import software.amazon.awscdk.services.rds.StorageType; import software.amazon.awscdk.services.secretsmanager.Secret; +import software.amazon.awssdk.services.athena.model.Row; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.rds.model.DescribeDbInstancesRequest; +import software.amazon.awssdk.services.rds.model.DescribeDbInstancesResponse; +import software.amazon.awssdk.services.rds.model.Endpoint; import java.util.ArrayList; import java.util.Collections; @@ -195,14 +194,15 @@ private Stack getPostGreSqlStack() */ private Endpoint getInstanceData() { - AmazonRDS rdsClient = AmazonRDSClientBuilder.defaultClient(); + RdsClient rdsClient = RdsClient.create(); try { - DescribeDBInstancesResult instancesResult = rdsClient.describeDBInstances(new DescribeDBInstancesRequest() - .withDBInstanceIdentifier(dbInstanceName)); - return instancesResult.getDBInstances().get(0).getEndpoint(); + DescribeDbInstancesResponse instancesResponse = rdsClient.describeDBInstances(DescribeDbInstancesRequest.builder() + .dbInstanceIdentifier(dbInstanceName) + .build()); + return instancesResponse.dbInstances().get(0).endpoint(); } finally { - rdsClient.shutdown(); + rdsClient.close(); } } @@ -213,7 +213,7 @@ private Endpoint getInstanceData() private void setEnvironmentVars(Endpoint endpoint) { String connectionString = String.format("postgres://jdbc:postgresql://%s:%s/postgres?user=%s&password=%s", - endpoint.getAddress(), endpoint.getPort(), username, password); + endpoint.address(), endpoint.port(), username, password); String connectionStringTag = lambdaFunctionName + "_connection_string"; environmentVars.put("default", connectionString); environmentVars.put(connectionStringTag, connectionString); @@ -439,13 +439,13 @@ public void selectColumnWithPredicateIntegTest() String query = String.format("select title from %s.%s.%s where year > 2010;", lambdaFunctionName, postgresDbName, postgresTableMovies); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List titles = new 
ArrayList<>(); - rows.forEach(row -> titles.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> titles.add(row.data().get(0).varCharValue())); logger.info("Titles: {}", titles); assertEquals("Wrong number of DB records found.", 1, titles.size()); assertTrue("Movie title not found: Interstellar.", titles.contains("Interstellar")); @@ -462,13 +462,13 @@ public void selectColumnBetweenDatesIntegTest() String query = String.format( "select first_name from %s.%s.%s where birthday between date('2005-10-01') and date('2005-10-31');", lambdaFunctionName, postgresDbName, postgresTableBday); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List names = new ArrayList<>(); - rows.forEach(row -> names.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> names.add(row.data().get(0).varCharValue())); logger.info("Names: {}", names); assertEquals("Wrong number of DB records found.", 1, names.size()); assertTrue("Name not found: Jane.", names.contains("Jane")); diff --git a/athena-redis/Dockerfile b/athena-redis/Dockerfile new file mode 100644 index 0000000000..3e9e9888f7 --- /dev/null +++ b/athena-redis/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-redis-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-redis-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.redis.RedisCompositeHandler" ] \ No newline at end of file diff --git a/athena-redis/athena-redis.yaml b/athena-redis/athena-redis.yaml index 85ffd47a79..c3bc541752 100644 --- a/athena-redis/athena-redis.yaml +++ b/athena-redis/athena-redis.yaml @@ -81,10 +81,9 @@ Resources: qpt_cluster: !Ref QPTConnectionCluster qpt_db_number: !Ref QPTConnectionDBNumber FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.redis.RedisCompositeHandler" - CodeUri: "./target/athena-redis-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-redis:2022.47.1' Description: "Enables Amazon Athena to communicate with Redis, making your Redis data accessible via SQL" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-redis/pom.xml b/athena-redis/pom.xml index b585bea651..20dace645e 100644 --- a/athena-redis/pom.xml +++ b/athena-redis/pom.xml @@ -9,60 +9,6 @@ athena-redis 2022.47.1 - - com.amazonaws - jmespath-java - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - - - com.amazonaws - aws-java-sdk-core - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - com.amazonaws aws-athena-federation-sdk @@ -92,9 +38,9 @@ 
${slf4j-log4j.version} - com.amazonaws - aws-java-sdk-elasticache - ${aws-sdk.version} + software.amazon.awssdk + elasticache + ${aws-sdk-v2.version} test diff --git a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandler.java b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandler.java index b0d6be7d06..5ddabdece9 100644 --- a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandler.java +++ b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandler.java @@ -47,11 +47,6 @@ import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionFactory; import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper; import com.amazonaws.athena.connectors.redis.qpt.RedisQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Database; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import io.lettuce.core.KeyScanCursor; import io.lettuce.core.Range; @@ -64,6 +59,11 @@ import org.apache.arrow.vector.util.Text; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Database; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Arrays; import java.util.HashSet; @@ -132,11 +132,11 @@ public class RedisMetadataHandler public static final String DEFAULT_REDIS_DB_NUMBER = "0"; //Used to filter out Glue tables which lack a redis endpoint. - private static final TableFilter TABLE_FILTER = (Table table) -> table.getParameters().containsKey(REDIS_ENDPOINT_PROP); + private static final TableFilter TABLE_FILTER = (Table table) -> table.parameters().containsKey(REDIS_ENDPOINT_PROP); //Used to filter out Glue databases which lack the REDIS_DB_FLAG in the URI. 
- private static final DatabaseFilter DB_FILTER = (Database database) -> (database.getLocationUri() != null && database.getLocationUri().contains(REDIS_DB_FLAG)); + private static final DatabaseFilter DB_FILTER = (Database database) -> (database.locationUri() != null && database.locationUri().contains(REDIS_DB_FLAG)); - private final AWSGlue awsGlue; + private final GlueClient awsGlue; private final RedisConnectionFactory redisConnectionFactory; private final RedisQueryPassthrough queryPassthrough = new RedisQueryPassthrough(); @@ -151,10 +151,10 @@ public RedisMetadataHandler(java.util.Map configOptions) @VisibleForTesting protected RedisMetadataHandler( - AWSGlue awsGlue, + GlueClient awsGlue, EncryptionKeyFactory keyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, RedisConnectionFactory redisConnectionFactory, String spillBucket, String spillPrefix, diff --git a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java index dfb490a3a6..5981aface2 100644 --- a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java +++ b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java @@ -29,12 +29,6 @@ import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionFactory; import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper; import com.amazonaws.athena.connectors.redis.qpt.RedisQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import io.lettuce.core.KeyScanCursor; import io.lettuce.core.ScanArgs; import io.lettuce.core.ScanCursor; @@ -45,6 +39,9 @@ import org.apache.arrow.vector.types.pojo.Field; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.HashMap; import java.util.HashSet; @@ -88,24 +85,24 @@ public class RedisRecordHandler private static final int SCAN_COUNT_SIZE = 100; private final RedisConnectionFactory redisConnectionFactory; - private final AmazonS3 amazonS3; + private final S3Client amazonS3; private final RedisQueryPassthrough queryPassthrough = new RedisQueryPassthrough(); public RedisRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.standard().build(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), + S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), new RedisConnectionFactory(), configOptions); } @VisibleForTesting - protected RedisRecordHandler(AmazonS3 amazonS3, - AWSSecretsManager secretsManager, - AmazonAthena athena, + protected RedisRecordHandler(S3Client amazonS3, + SecretsManagerClient secretsManager, + AthenaClient athena, RedisConnectionFactory redisConnectionFactory, java.util.Map configOptions) { diff --git a/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandlerTest.java 
b/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandlerTest.java index 2767e82eb1..2909de1d82 100644 --- a/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandlerTest.java +++ b/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisMetadataHandlerTest.java @@ -36,11 +36,6 @@ import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionFactory; import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper; import com.amazonaws.athena.connectors.redis.util.MockKeyScanCursor; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import io.lettuce.core.Range; import io.lettuce.core.ScanArgs; import io.lettuce.core.ScanCursor; @@ -57,6 +52,11 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.util.ArrayList; import java.util.Collections; @@ -103,13 +103,13 @@ public class RedisMetadataHandlerTest private RedisCommandsWrapper mockSyncCommands; @Mock - private AWSGlue mockGlue; + private GlueClient mockGlue; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Mock private RedisConnectionFactory mockFactory; @@ -129,10 +129,10 @@ public void setUp() when(mockSecretsManager.getSecretValue(nullable(GetSecretValueRequest.class))) .thenAnswer((InvocationOnMock invocation) -> { GetSecretValueRequest request = invocation.getArgument(0, GetSecretValueRequest.class); - if ("endpoint".equalsIgnoreCase(request.getSecretId())) { - return new GetSecretValueResult().withSecretString(decodedEndpoint); + if ("endpoint".equalsIgnoreCase(request.secretId())) { + return GetSecretValueResponse.builder().secretString(decodedEndpoint).build(); } - throw new RuntimeException("Unknown secret " + request.getSecretId()); + throw new RuntimeException("Unknown secret " + request.secretId()); }); } diff --git a/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisRecordHandlerTest.java b/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisRecordHandlerTest.java index e846d40ff2..d330af3ca3 100644 --- a/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisRecordHandlerTest.java +++ b/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisRecordHandlerTest.java @@ -40,15 +40,6 @@ import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper; import com.amazonaws.athena.connectors.redis.util.MockKeyScanCursor; import com.amazonaws.athena.connectors.redis.util.MockScoredValueScanCursor; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import 
com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import io.lettuce.core.ScanArgs; @@ -69,6 +60,17 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -105,7 +107,7 @@ public class RedisRecordHandlerTest private RedisRecordHandler handler; private BlockAllocator allocator; private List mockS3Storage = new ArrayList<>(); - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpillReader spillReader; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -119,13 +121,13 @@ public class RedisRecordHandlerTest private RedisCommandsWrapper mockSyncCommands; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock private RedisConnectionFactory mockFactory; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Before public void setUp() @@ -137,42 +139,38 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); - Mockito.lenient().when(amazonS3.putObject(any())) + Mockito.lenient().when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - Mockito.lenient().when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + Mockito.lenient().when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new 
ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); when(mockSecretsManager.getSecretValue(nullable(GetSecretValueRequest.class))) .thenAnswer((InvocationOnMock invocation) -> { GetSecretValueRequest request = invocation.getArgument(0, GetSecretValueRequest.class); - if ("endpoint".equalsIgnoreCase(request.getSecretId())) { - return new GetSecretValueResult().withSecretString(decodedEndpoint); + if ("endpoint".equalsIgnoreCase(request.secretId())) { + return GetSecretValueResponse.builder().secretString(decodedEndpoint).build(); } - throw new RuntimeException("Unknown secret " + request.getSecretId()); + throw new RuntimeException("Unknown secret " + request.secretId()); }); handler = new RedisRecordHandler(amazonS3, mockSecretsManager, mockAthena, mockFactory, com.google.common.collect.ImmutableMap.of()); diff --git a/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/integ/RedisIntegTest.java b/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/integ/RedisIntegTest.java index 64776c3708..98c81ab204 100644 --- a/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/integ/RedisIntegTest.java +++ b/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/integ/RedisIntegTest.java @@ -19,7 +19,6 @@ */ package com.amazonaws.athena.connectors.redis.integ; -import com.amazonaws.ClientConfiguration; import com.amazonaws.athena.connector.integ.ConnectorStackFactory; import com.amazonaws.athena.connector.integ.IntegrationTestBase; import com.amazonaws.athena.connector.integ.clients.CloudFormationClient; @@ -28,25 +27,6 @@ import com.amazonaws.athena.connector.integ.data.ConnectorVpcAttributes; import com.amazonaws.athena.connector.integ.data.SecretsManagerCredentials; import com.amazonaws.athena.connector.integ.providers.ConnectorPackagingAttributesProvider; -import com.amazonaws.services.athena.model.Row; -import com.amazonaws.services.elasticache.AmazonElastiCache; -import com.amazonaws.services.elasticache.AmazonElastiCacheClientBuilder; -import com.amazonaws.services.elasticache.model.DescribeCacheClustersRequest; -import com.amazonaws.services.elasticache.model.DescribeCacheClustersResult; -import com.amazonaws.services.elasticache.model.DescribeReplicationGroupsRequest; -import com.amazonaws.services.elasticache.model.DescribeReplicationGroupsResult; -import com.amazonaws.services.elasticache.model.Endpoint; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.AWSGlueClientBuilder; -import com.amazonaws.services.glue.model.EntityNotFoundException; -import com.amazonaws.services.glue.model.GetTableRequest; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.TableInput; -import com.amazonaws.services.glue.model.UpdateTableRequest; -import com.amazonaws.services.lambda.AWSLambda; -import com.amazonaws.services.lambda.AWSLambdaClientBuilder; -import com.amazonaws.services.lambda.model.InvocationType; -import com.amazonaws.services.lambda.model.InvokeRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.annotations.AfterClass; @@ -66,7 +46,23 @@ import software.amazon.awscdk.services.iam.PolicyDocument; import software.amazon.awscdk.services.s3.Bucket; import software.amazon.awscdk.services.s3.IBucket; - +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.services.athena.model.Row; +import 
software.amazon.awssdk.services.elasticache.ElastiCacheClient; +import software.amazon.awssdk.services.elasticache.model.DescribeCacheClustersRequest; +import software.amazon.awssdk.services.elasticache.model.DescribeCacheClustersResponse; +import software.amazon.awssdk.services.elasticache.model.DescribeReplicationGroupsRequest; +import software.amazon.awssdk.services.elasticache.model.DescribeReplicationGroupsResponse; +import software.amazon.awssdk.services.elasticache.model.Endpoint; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.EntityNotFoundException; +import software.amazon.awssdk.services.glue.model.TableInput; +import software.amazon.awssdk.services.glue.model.UpdateTableRequest; +import software.amazon.awssdk.services.lambda.LambdaClient; +import software.amazon.awssdk.services.lambda.model.InvocationType; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; + +import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -98,7 +94,7 @@ public class RedisIntegTest extends IntegrationTestBase private final String redisDbName; private final String redisTableNamePrefix; private final String lambdaFunctionName; - private final AWSGlue glue; + private final GlueClient glue; private final String redisStackName; private final Map environmentVars; @@ -120,8 +116,9 @@ public RedisIntegTest() redisDbName = (String) userSettings.get("redis_db_name"); redisTableNamePrefix = (String) userSettings.get("redis_table_name_prefix"); lambdaFunctionName = getLambdaFunctionName(); - glue = AWSGlueClientBuilder.standard() - .withClientConfiguration(new ClientConfiguration().withConnectionTimeout(GLUE_TIMEOUT)) + glue = GlueClient.builder() + .httpClientBuilder(ApacheHttpClient.builder() + .connectionTimeout(Duration.ofMillis(GLUE_TIMEOUT))) .build(); redisStackName = "integ-redis-instance-" + UUID.randomUUID(); environmentVars = new HashMap<>(); @@ -144,12 +141,12 @@ protected void setUp() Endpoint standaloneEndpoint = getRedisInstanceData(redisStandaloneName, false); logger.info("Got Endpoint: " + standaloneEndpoint.toString()); redisEndpoints.put(STANDALONE_KEY, String.format("%s:%s", - standaloneEndpoint.getAddress(), standaloneEndpoint.getPort())); + standaloneEndpoint.address(), standaloneEndpoint.port())); Endpoint clusterEndpoint = getRedisInstanceData(redisClusterName, true); logger.info("Got Endpoint: " + clusterEndpoint.toString()); redisEndpoints.put(CLUSTER_KEY, String.format("%s:%s:%s", - clusterEndpoint.getAddress(), clusterEndpoint.getPort(), redisPassword)); + clusterEndpoint.address(), clusterEndpoint.port(), redisPassword)); // Get endpoint information and set the connection string environment var for Lambda. environmentVars.put("standalone_connection", redisEndpoints.get(STANDALONE_KEY)); @@ -177,7 +174,7 @@ protected void cleanUp() // Delete the CloudFormation stack for Redis. cloudFormationClient.deleteStack(); // close glue client - glue.shutdown(); + glue.close(); } /** @@ -191,20 +188,21 @@ protected void setUpTableData() logger.info("----------------------------------------------------"); String redisLambdaName = "integ-redis-helper-" + UUID.randomUUID(); - AWSLambda lambdaClient = AWSLambdaClientBuilder.defaultClient(); + LambdaClient lambdaClient = LambdaClient.create(); CloudFormationClient cloudFormationRedisClient = new CloudFormationClient(getRedisLambdaStack(redisLambdaName)); try { // Create the Lambda function. 
cloudFormationRedisClient.createStack(); // Invoke the Lambda function. - lambdaClient.invoke(new InvokeRequest() - .withFunctionName(redisLambdaName) - .withInvocationType(InvocationType.RequestResponse)); + lambdaClient.invoke(InvokeRequest.builder() + .functionName(redisLambdaName) + .invocationType(InvocationType.REQUEST_RESPONSE) + .build()); } finally { // Delete the Lambda function. cloudFormationRedisClient.deleteStack(); - lambdaClient.shutdown(); + lambdaClient.close(); } } @@ -346,21 +344,26 @@ private Stack getRedisStack() */ private Endpoint getRedisInstanceData(String redisName, boolean isCluster) { - AmazonElastiCache elastiCacheClient = AmazonElastiCacheClientBuilder.defaultClient(); + ElastiCacheClient elastiCacheClient = ElastiCacheClient.create(); try { if (isCluster) { - DescribeReplicationGroupsResult describeResult = elastiCacheClient.describeReplicationGroups(new DescribeReplicationGroupsRequest() - .withReplicationGroupId(redisName)); - return describeResult.getReplicationGroups().get(0).getConfigurationEndpoint(); + DescribeReplicationGroupsRequest describeRequest = DescribeReplicationGroupsRequest.builder() + .replicationGroupId(redisName) + .build(); + DescribeReplicationGroupsResponse describeResponse = elastiCacheClient.describeReplicationGroups(describeRequest); + return describeResponse.replicationGroups().get(0).configurationEndpoint(); } else { - DescribeCacheClustersResult describeResult = elastiCacheClient.describeCacheClusters(new DescribeCacheClustersRequest() - .withCacheClusterId(redisName).withShowCacheNodeInfo(true)); - return describeResult.getCacheClusters().get(0).getCacheNodes().get(0).getEndpoint(); + DescribeCacheClustersRequest describeRequest = DescribeCacheClustersRequest.builder() + .cacheClusterId(redisName) + .showCacheNodeInfo(true) + .build(); + DescribeCacheClustersResponse describeResponse = elastiCacheClient.describeCacheClusters(describeRequest); + return describeResponse.cacheClusters().get(0).cacheNodes().get(0).endpoint(); } } finally { - elastiCacheClient.shutdown(); + elastiCacheClient.close(); } } @@ -371,15 +374,16 @@ private Endpoint getRedisInstanceData(String redisName, boolean isCluster) * @param tableName * @return Table */ - private com.amazonaws.services.glue.model.Table getGlueTable(String databaseName, String tableName) + private software.amazon.awssdk.services.glue.model.Table getGlueTable(String databaseName, String tableName) { - com.amazonaws.services.glue.model.Table table; - GetTableRequest getTableRequest = new GetTableRequest(); - getTableRequest.setDatabaseName(databaseName); - getTableRequest.setName(tableName); + software.amazon.awssdk.services.glue.model.Table table; + software.amazon.awssdk.services.glue.model.GetTableRequest getTableRequest = software.amazon.awssdk.services.glue.model.GetTableRequest.builder() + .databaseName(databaseName) + .name(tableName) + .build(); try { - GetTableResult tableResult = glue.getTable(getTableRequest); - table = tableResult.getTable(); + software.amazon.awssdk.services.glue.model.GetTableResponse tableResponse = glue.getTable(getTableRequest); + table = tableResponse.table(); } catch (EntityNotFoundException e) { throw e; } @@ -392,39 +396,39 @@ private com.amazonaws.services.glue.model.Table getGlueTable(String databaseName * @param table * @return TableInput */ - private TableInput createTableInput(com.amazonaws.services.glue.model.Table table) { - TableInput tableInput = new TableInput(); - tableInput.setDescription(table.getDescription()); - 
tableInput.setLastAccessTime(table.getLastAccessTime()); - tableInput.setOwner(table.getOwner()); - tableInput.setName(table.getName()); - if (Optional.ofNullable(table.getStorageDescriptor()).isPresent()) { - tableInput.setStorageDescriptor(table.getStorageDescriptor()); - if (Optional.ofNullable(table.getStorageDescriptor().getParameters()).isPresent()) - tableInput.setParameters(table.getStorageDescriptor().getParameters()); + private TableInput createTableInput(software.amazon.awssdk.services.glue.model.Table table) { + TableInput.Builder tableInput = TableInput.builder() + .description(table.description()) + .lastAccessTime(table.lastAccessTime()) + .owner(table.owner()) + .name(table.name()); + if (Optional.ofNullable(table.storageDescriptor()).isPresent()) { + tableInput.storageDescriptor(table.storageDescriptor()); + if (Optional.ofNullable(table.storageDescriptor().parameters()).isPresent()) + tableInput.parameters(table.storageDescriptor().parameters()); } - tableInput.setPartitionKeys(table.getPartitionKeys()); - tableInput.setTableType(table.getTableType()); - tableInput.setViewExpandedText(table.getViewExpandedText()); - tableInput.setViewOriginalText(table.getViewOriginalText()); - tableInput.setParameters(table.getParameters()); - return tableInput; + tableInput.partitionKeys(table.partitionKeys()); + tableInput.tableType(table.tableType()); + tableInput.viewExpandedText(table.viewExpandedText()); + tableInput.viewOriginalText(table.viewOriginalText()); + tableInput.parameters(table.parameters()); + return tableInput.build(); } private void selectHashValue() { String query = String.format("select * from \"%s\".\"%s\".\"%s\";", lambdaFunctionName, redisDbName, redisTableNamePrefix + "_1"); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List names = new ArrayList<>(); rows.forEach(row -> { - names.add(row.getData().get(1).getVarCharValue()); + names.add(row.data().get(1).varCharValue()); // redis key is added as an extra col by the connector. 
so expected #cols is #glue cols + 1 - assertEquals("Wrong number of columns found", 4, row.getData().size()); + assertEquals("Wrong number of columns found", 4, row.data().size()); }); logger.info("names: {}", names); assertEquals("Wrong number of DB records found.", 3, names.size()); @@ -437,15 +441,15 @@ private void selectZsetValue() { String query = String.format("select * from \"%s\".\"%s\".\"%s\";", lambdaFunctionName, redisDbName, redisTableNamePrefix + "_2"); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List names = new ArrayList<>(); rows.forEach(row -> { - names.add(row.getData().get(0).getVarCharValue()); - assertEquals("Wrong number of columns found", 2, row.getData().size()); + names.add(row.data().get(0).varCharValue()); + assertEquals("Wrong number of columns found", 2, row.data().size()); }); logger.info("names: {}", names); assertEquals("Wrong number of DB records found.", 3, names.size()); @@ -458,15 +462,15 @@ private void selectLiteralValue() { String query = String.format("select * from \"%s\".\"%s\".\"%s\";", lambdaFunctionName, redisDbName, redisTableNamePrefix + "_2"); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List names = new ArrayList<>(); rows.forEach(row -> { - names.add(row.getData().get(0).getVarCharValue()); - assertEquals("Wrong number of columns found", 2, row.getData().size()); + names.add(row.data().get(0).varCharValue()); + assertEquals("Wrong number of columns found", 2, row.data().size()); }); logger.info("names: {}", names); assertEquals("Wrong number of DB records found.", 3, names.size()); @@ -541,8 +545,8 @@ public void standaloneSelectPrefixWithHashValue() tableParams.put("redis-cluster-flag", "false"); tableParams.put("redis-ssl-flag", "false"); tableParams.put("redis-db-number", STANDALONE_REDIS_DB_NUMBER); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_1")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_1")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectHashValue(); } @@ -562,8 +566,8 @@ public void standaloneSelectZsetWithHashValue() tableParams.put("redis-cluster-flag", "false"); tableParams.put("redis-ssl-flag", "false"); tableParams.put("redis-db-number", STANDALONE_REDIS_DB_NUMBER); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_1")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_1")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectHashValue(); } @@ -582,8 +586,8 @@ public void clusterSelectPrefixWithHashValue() tableParams.put("redis-value-type", "hash"); // hash tableParams.put("redis-cluster-flag", "true"); tableParams.put("redis-ssl-flag", "true"); - TableInput 
tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_1")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_1")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectHashValue(); } @@ -602,8 +606,8 @@ public void clusterSelectZsetWithHashValue() tableParams.put("redis-value-type", "hash"); // hash tableParams.put("redis-cluster-flag", "true"); tableParams.put("redis-ssl-flag", "true"); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_1")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_1")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectHashValue(); } @@ -623,8 +627,8 @@ public void standaloneSelectPrefixWithZsetValue() tableParams.put("redis-cluster-flag", "false"); tableParams.put("redis-ssl-flag", "false"); tableParams.put("redis-db-number", STANDALONE_REDIS_DB_NUMBER); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectZsetValue(); } @@ -644,8 +648,8 @@ public void standaloneSelectZsetWithZsetValue() tableParams.put("redis-cluster-flag", "false"); tableParams.put("redis-ssl-flag", "false"); tableParams.put("redis-db-number", STANDALONE_REDIS_DB_NUMBER); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectZsetValue(); } @@ -664,8 +668,8 @@ public void clusterSelectPrefixWithZsetValue() tableParams.put("redis-value-type", "zset"); // zset tableParams.put("redis-cluster-flag", "true"); tableParams.put("redis-ssl-flag", "true"); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectZsetValue(); } @@ -684,8 +688,8 @@ public void clusterSelectZsetWithZsetValue() tableParams.put("redis-value-type", "zset"); // zset tableParams.put("redis-cluster-flag", "true"); 
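// The hunks above all apply one SDK v2 idiom: model classes such as TableInput are
// immutable, so toBuilder() copies an instance into a builder, one field is overridden,
// and build() produces the modified copy. A minimal self-contained sketch of that flow
// follows; the names myGlue/params and the class name are illustrative, not part of this patch.
import java.util.Map;
import software.amazon.awssdk.services.glue.GlueClient;
import software.amazon.awssdk.services.glue.model.TableInput;
import software.amazon.awssdk.services.glue.model.UpdateTableRequest;

class GlueTableParamsSketch
{
    static void overrideParameters(GlueClient myGlue, String databaseName, TableInput base, Map<String, String> params)
    {
        // Copy the existing table definition, replacing only its parameters.
        TableInput updated = base.toBuilder().parameters(params).build();
        myGlue.updateTable(UpdateTableRequest.builder()
                .databaseName(databaseName)
                .tableInput(updated)
                .build());
    }
}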
tableParams.put("redis-ssl-flag", "true"); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectZsetValue(); } @@ -705,8 +709,8 @@ public void standaloneSelectPrefixWithLiteralValue() tableParams.put("redis-cluster-flag", "false"); tableParams.put("redis-ssl-flag", "false"); tableParams.put("redis-db-number", STANDALONE_REDIS_DB_NUMBER); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectLiteralValue(); } @@ -726,8 +730,8 @@ public void standaloneSelectZsetWithLiteralValue() tableParams.put("redis-cluster-flag", "false"); tableParams.put("redis-ssl-flag", "false"); tableParams.put("redis-db-number", STANDALONE_REDIS_DB_NUMBER); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectLiteralValue(); } @@ -746,8 +750,8 @@ public void clusterSelectPrefixWithLiteralValue() tableParams.put("redis-value-type", "literal"); // literal tableParams.put("redis-cluster-flag", "true"); tableParams.put("redis-ssl-flag", "true"); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectLiteralValue(); } @@ -766,8 +770,8 @@ public void clusterSelectZsetWithLiteralValue() tableParams.put("redis-value-type", "literal"); // literal tableParams.put("redis-cluster-flag", "true"); tableParams.put("redis-ssl-flag", "true"); - TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).withParameters(tableParams); - glue.updateTable(new UpdateTableRequest().withDatabaseName(redisDbName).withTableInput(tableInput)); + TableInput tableInput = createTableInput(getGlueTable(redisDbName, redisTableNamePrefix + "_2")).toBuilder().parameters(tableParams).build(); + glue.updateTable(UpdateTableRequest.builder().databaseName(redisDbName).tableInput(tableInput).build()); selectLiteralValue(); } diff --git a/athena-redshift/Dockerfile b/athena-redshift/Dockerfile new file mode 100644 
index 0000000000..0e7d808823
--- /dev/null
+++ b/athena-redshift/Dockerfile
@@ -0,0 +1,9 @@
+FROM public.ecr.aws/lambda/java:11
+
+# Copy function code and runtime dependencies from Maven layout
+COPY target/athena-redshift-2022.47.1.jar ${LAMBDA_TASK_ROOT}
+# Unpack the jar
+RUN jar xf athena-redshift-2022.47.1.jar
+
+# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile)
+CMD [ "com.amazonaws.athena.connectors.redshift.RedshiftMuxCompositeHandler" ]
\ No newline at end of file
diff --git a/athena-redshift/athena-redshift.yaml b/athena-redshift/athena-redshift.yaml
index 11ad42864d..47cd238f89 100644
--- a/athena-redshift/athena-redshift.yaml
+++ b/athena-redshift/athena-redshift.yaml
@@ -79,10 +79,9 @@ Resources:
           default: !Ref DefaultConnectionString
           kms_key_id: !If [HasKMSKeyId, !Ref KMSKeyId, !Ref "AWS::NoValue"]
       FunctionName: !Ref LambdaFunctionName
-      Handler: "com.amazonaws.athena.connectors.redshift.RedshiftMuxCompositeHandler"
-      CodeUri: "./target/athena-redshift-2022.47.1.jar"
+      PackageType: "Image"
+      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-redshift:2022.47.1'
       Description: "Enables Amazon Athena to communicate with Redshift using JDBC"
-      Runtime: java11
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
       Role: !If [NotHasLambdaRole, !GetAtt FunctionRole.Arn, !Ref LambdaRole]
diff --git a/athena-redshift/pom.xml b/athena-redshift/pom.xml
index 7119660c3e..5ec2b541d0 100644
--- a/athena-redshift/pom.xml
+++ b/athena-redshift/pom.xml
@@ -38,16 +38,16 @@
             <type>test-jar</type>
             <scope>test</scope>
         </dependency>
-
+
         <dependency>
-            <groupId>com.amazonaws</groupId>
-            <artifactId>aws-java-sdk-redshift</artifactId>
-            <version>${aws-sdk.version}</version>
+            <groupId>software.amazon.awssdk</groupId>
+            <artifactId>redshift</artifactId>
+            <version>${aws-sdk-v2.version}</version>
         </dependency>
         <dependency>
-            <groupId>com.amazonaws</groupId>
-            <artifactId>aws-java-sdk-redshiftserverless</artifactId>
-            <version>${aws-sdk.version}</version>
+            <groupId>software.amazon.awssdk</groupId>
+            <artifactId>redshiftserverless</artifactId>
+            <version>${aws-sdk-v2.version}</version>
         </dependency>
@@ -56,12 +56,18 @@
             <version>${aws-cdk.version}</version>
             <scope>test</scope>
         </dependency>
-
+
         <dependency>
-            <groupId>com.amazonaws</groupId>
-            <artifactId>aws-java-sdk-rds</artifactId>
-            <version>${aws-sdk.version}</version>
+            <groupId>software.amazon.awssdk</groupId>
+            <artifactId>rds</artifactId>
+            <version>${aws-sdk-v2.version}</version>
             <scope>test</scope>
+            <exclusions>
+                <exclusion>
+                    <groupId>software.amazon.awssdk</groupId>
+                    <artifactId>netty-nio-client</artifactId>
+                </exclusion>
+            </exclusions>
         </dependency>
diff --git a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMetadataHandler.java b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMetadataHandler.java
index ab23c7b7b4..9a4dc18fea 100644
--- a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMetadataHandler.java
+++ b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMetadataHandler.java
@@ -36,12 +36,12 @@
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
 import com.amazonaws.athena.connectors.postgresql.PostGreSqlMetadataHandler;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.collect.ImmutableMap;
 import org.apache.arrow.util.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -80,7 +80,7 @@ public RedshiftMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig
     }
 
     @VisibleForTesting
-    RedshiftMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory,
java.util.Map configOptions) + RedshiftMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { super(databaseConnectionConfig, secretsManager, athena, jdbcConnectionFactory, configOptions); } diff --git a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxMetadataHandler.java b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxMetadataHandler.java index e0ff392750..6293dd99fd 100644 --- a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxMetadataHandler.java +++ b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxMetadataHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -58,7 +58,7 @@ public RedshiftMuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected RedshiftMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected RedshiftMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxRecordHandler.java b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxRecordHandler.java index 72f1ea1381..2fe7b8fa3e 100644 --- a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxRecordHandler.java +++ b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -58,7 +58,7 @@ public RedshiftMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - RedshiftMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + RedshiftMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig 
databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java index 1546ea391b..0b294de160 100644 --- a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java +++ b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java @@ -30,15 +30,12 @@ import com.amazonaws.athena.connectors.postgresql.PostGreSqlQueryStringBuilder; import com.amazonaws.athena.connectors.postgresql.PostGreSqlRecordHandler; import com.amazonaws.athena.connectors.postgresql.PostgreSqlFederationExpressionParser; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import static com.amazonaws.athena.connectors.postgresql.PostGreSqlConstants.POSTGRES_QUOTE_CHARACTER; import static com.amazonaws.athena.connectors.redshift.RedshiftConstants.REDSHIFT_DEFAULT_PORT; @@ -62,12 +59,14 @@ public RedshiftRecordHandler(java.util.Map configOptions) public RedshiftRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - super(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), - new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(REDSHIFT_DRIVER_CLASS, REDSHIFT_DEFAULT_PORT)), new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER)), configOptions); + super(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), + new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, + new DatabaseConnectionInfo(REDSHIFT_DRIVER_CLASS, REDSHIFT_DEFAULT_PORT)), + new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - RedshiftRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + RedshiftRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(databaseConnectionConfig, amazonS3, secretsManager, athena, jdbcConnectionFactory, 
jdbcSplitQueryBuilder, configOptions); } diff --git a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMetadataHandlerTest.java b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMetadataHandlerTest.java index 027c716876..2882d1a0d2 100644 --- a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMetadataHandlerTest.java +++ b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMetadataHandlerTest.java @@ -41,10 +41,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.postgresql.PostGreSqlMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; @@ -57,6 +53,10 @@ import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.Connection; import java.sql.PreparedStatement; @@ -88,8 +88,8 @@ public class RedshiftMetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() @@ -98,8 +98,8 @@ public void setup() this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.redshiftMetadataHandler = new RedshiftMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); } diff --git a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcMetadataHandlerTest.java b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcMetadataHandlerTest.java index b0c5a7d86d..7db8b60795 100644 --- 
a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcMetadataHandlerTest.java +++ b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcMetadataHandlerTest.java @@ -32,11 +32,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -49,8 +49,8 @@ public class RedshiftMuxJdbcMetadataHandlerTest private RedshiftMetadataHandler redshiftMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -62,8 +62,8 @@ public void setup() //Mockito.when(this.allocator.createBlock(nullable(Schema.class))).thenReturn(Mockito.mock(Block.class)); this.redshiftMetadataHandler = Mockito.mock(RedshiftMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("redshift", this.redshiftMetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "redshift", diff --git a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcRecordHandlerTest.java b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcRecordHandlerTest.java index 686e322ccc..2a4abf05bc 100644 --- a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcRecordHandlerTest.java +++ b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcRecordHandlerTest.java @@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -46,9 +46,9 @@ public class RedshiftMuxJdbcRecordHandlerTest private Map recordHandlerMap; private RedshiftRecordHandler redshiftRecordHandler; private 
JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -57,9 +57,9 @@ public void setup() { this.redshiftRecordHandler = Mockito.mock(RedshiftRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("redshift", this.redshiftRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "redshift", diff --git a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java index 8f459391b5..c3e73e0e1f 100644 --- a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java +++ b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java @@ -36,9 +36,6 @@ import com.amazonaws.athena.connectors.postgresql.PostGreSqlMetadataHandler; import com.amazonaws.athena.connectors.postgresql.PostGreSqlQueryStringBuilder; import com.amazonaws.athena.connectors.postgresql.PostgreSqlFederationExpressionParser; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -52,6 +49,9 @@ import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.math.BigDecimal; import java.sql.Connection; @@ -75,18 +75,18 @@ public class RedshiftRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private MockedStatic mockedPostGreSqlMetadataHandler; @Before public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); 
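// A compact sketch of the mocking pattern these test diffs converge on. SDK v2
// requests and responses are immutable value objects with equals()/hashCode(), so
// Mockito.eq() on a freshly built request still matches the one the handler builds.
// "testSecret" mirrors the fixture used elsewhere in this patch; the class and method
// names are illustrative only.
import org.mockito.Mockito;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;

class SecretsManagerMockSketch
{
    static SecretsManagerClient mockWithTestCredentials()
    {
        SecretsManagerClient secretsManager = Mockito.mock(SecretsManagerClient.class);
        // Stub the lookup so handlers under test resolve a fixed username/password pair.
        Mockito.when(secretsManager.getSecretValue(
                        Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build())))
                .thenReturn(GetSecretValueResponse.builder()
                        .secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")
                        .build());
        return secretsManager;
    }
}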
Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/integ/RedshiftIntegTest.java b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/integ/RedshiftIntegTest.java index 392a0f5611..d103901df4 100644 --- a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/integ/RedshiftIntegTest.java +++ b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/integ/RedshiftIntegTest.java @@ -26,16 +26,8 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo; import com.amazonaws.athena.connectors.jdbc.integ.JdbcTableUtils; -import com.amazonaws.services.athena.model.Datum; -import com.amazonaws.services.athena.model.Row; -import com.amazonaws.services.redshift.AmazonRedshift; -import com.amazonaws.services.redshift.AmazonRedshiftClientBuilder; -import com.amazonaws.services.redshift.model.DescribeClustersRequest; -import com.amazonaws.services.redshift.model.DescribeClustersResult; -import com.amazonaws.services.redshift.model.Endpoint; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.testng.AssertJUnit; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -51,6 +43,11 @@ import software.amazon.awscdk.services.redshift.ClusterType; import software.amazon.awscdk.services.redshift.Login; import software.amazon.awscdk.services.redshift.NodeType; +import software.amazon.awssdk.services.athena.model.Row; +import software.amazon.awssdk.services.redshift.RedshiftClient; +import software.amazon.awssdk.services.redshift.model.DescribeClustersRequest; +import software.amazon.awssdk.services.redshift.model.DescribeClustersResponse; +import software.amazon.awssdk.services.redshift.model.Endpoint; import java.util.ArrayList; import java.util.Collections; @@ -190,14 +187,14 @@ private Stack getRedshiftStack() */ private Endpoint getClusterData() { - AmazonRedshift redshiftClient = AmazonRedshiftClientBuilder.defaultClient(); + RedshiftClient redshiftClient = RedshiftClient.create(); try { - DescribeClustersResult clustersResult = redshiftClient.describeClusters(new DescribeClustersRequest() - .withClusterIdentifier(clusterName)); - return clustersResult.getClusters().get(0).getEndpoint(); + DescribeClustersResponse clustersResult = redshiftClient.describeClusters(DescribeClustersRequest.builder() + .clusterIdentifier(clusterName).build()); + return clustersResult.clusters().get(0).endpoint(); } finally { - redshiftClient.shutdown(); + redshiftClient.close(); } } @@ -208,7 +205,7 @@ private Endpoint getClusterData() private void setEnvironmentVars(Endpoint endpoint) { String connectionString = String.format("redshift://jdbc:redshift://%s:%s/public?user=%s&password=%s", - endpoint.getAddress(), endpoint.getPort(), username, password); + endpoint.address(), endpoint.port(), username, password); String connectionStringTag = lambdaFunctionName + "_connection_string"; environmentVars.put("default", connectionString); environmentVars.put(connectionStringTag, connectionString); @@ -442,13 +439,13 @@ public void selectColumnWithPredicateIntegTest() String query = String.format("select title from %s.%s.%s where year > 2000;", lambdaFunctionName, redshiftDbName, redshiftTableMovies); - List rows = 
startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List titles = new ArrayList<>(); - rows.forEach(row -> titles.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> titles.add(row.data().get(0).varCharValue())); logger.info("Titles: {}", titles); assertEquals("Wrong number of DB records found.", 1, titles.size()); assertTrue("Movie title not found: Interstellar.", titles.contains("Interstellar")); @@ -465,13 +462,13 @@ public void selectColumnBetweenDatesIntegTest() String query = String.format( "select first_name from %s.%s.%s where birthday between date('2003-1-1') and date('2005-12-31');", lambdaFunctionName, redshiftDbName, redshiftTableBday); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List names = new ArrayList<>(); - rows.forEach(row -> names.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> names.add(row.data().get(0).varCharValue())); logger.info("Names: {}", names); assertEquals("Wrong number of DB records found.", 1, names.size()); assertTrue("Name not found: Jane.", names.contains("Jane")); diff --git a/athena-saphana/Dockerfile b/athena-saphana/Dockerfile new file mode 100644 index 0000000000..5e55d28a12 --- /dev/null +++ b/athena-saphana/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-saphana.zip ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-saphana.zip + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.saphana.SaphanaMuxCompositeHandler" ] \ No newline at end of file diff --git a/athena-saphana/athena-saphana.yaml b/athena-saphana/athena-saphana.yaml index 13b945c42f..5a1d895933 100644 --- a/athena-saphana/athena-saphana.yaml +++ b/athena-saphana/athena-saphana.yaml @@ -69,10 +69,9 @@ Resources: spill_prefix: !Ref SpillPrefix default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.saphana.SaphanaMuxCompositeHandler" - CodeUri: "./target/athena-saphana.zip" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-saphana:2022.47.1' Description: "Enables Amazon Athena to communicate with Teradata using JDBC" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-saphana/pom.xml b/athena-saphana/pom.xml index e4a6199950..7677fac010 100644 --- a/athena-saphana/pom.xml +++ b/athena-saphana/pom.xml @@ -27,12 +27,18 @@ test-jar test - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMetadataHandler.java b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMetadataHandler.java index b0c9aca0a3..552c6751a4 100644 --- a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMetadataHandler.java +++ 
b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMetadataHandler.java @@ -52,8 +52,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -64,6 +62,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -102,8 +102,8 @@ SaphanaConstants.JDBC_PROPERTIES, new DatabaseConnectionInfo(SaphanaConstants.SA @VisibleForTesting protected SaphanaMetadataHandler( DatabaseConnectionConfig databaseConnectionConfig, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxMetadataHandler.java b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxMetadataHandler.java index 0a8d019de0..238de7146d 100644 --- a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxMetadataHandler.java +++ b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxMetadataHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -56,7 +56,7 @@ public SaphanaMuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected SaphanaMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected SaphanaMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxRecordHandler.java b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxRecordHandler.java index 050b8ba5d2..2414854794 100644 --- a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxRecordHandler.java +++ b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import 
com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public SaphanaMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - SaphanaMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + SaphanaMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandler.java b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandler.java index c65656c45f..67f7a93e6a 100644 --- a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandler.java +++ b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandler.java @@ -31,18 +31,15 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -68,7 +65,7 @@ public SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, j SaphanaConstants.SAPHANA_DEFAULT_PORT)), configOptions); } @VisibleForTesting - SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); 
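// The constructor changes above all swap v1's AmazonXxxClientBuilder.defaultClient()
// for v2's XxxClient.create(), which is shorthand for XxxClient.builder().build()
// using the default region and credentials providers. A hedged sketch of that
// equivalence (handler wiring omitted, class name illustrative):
import software.amazon.awssdk.services.athena.AthenaClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

class DefaultClientSketch
{
    static void createDefaults()
    {
        S3Client s3 = S3Client.create();                              // same as S3Client.builder().build()
        SecretsManagerClient secrets = SecretsManagerClient.create();
        AthenaClient athena = AthenaClient.create();
        // All v2 clients implement AutoCloseable; close them when finished.
        athena.close();
        secrets.close();
        s3.close();
    }
}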
this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); @@ -76,8 +73,8 @@ public SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, j public SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, GenericJdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), jdbcConnectionFactory, new SaphanaQueryStringBuilder(SAPHANA_QUOTE_CHARACTER, new SaphanaFederationExpressionParser(SAPHANA_QUOTE_CHARACTER)), configOptions); + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), + AthenaClient.create(), jdbcConnectionFactory, new SaphanaQueryStringBuilder(SAPHANA_QUOTE_CHARACTER, new SaphanaFederationExpressionParser(SAPHANA_QUOTE_CHARACTER)), configOptions); } @Override diff --git a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMetadataHandlerTest.java b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMetadataHandlerTest.java index bf8b86c007..59d9973b8e 100644 --- a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMetadataHandlerTest.java +++ b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMetadataHandlerTest.java @@ -48,10 +48,10 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.nullable; @@ -65,8 +65,8 @@ public class SaphanaMetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private BlockAllocator blockAllocator; private static final Schema PARTITION_SCHEMA = SchemaBuilder.newBuilder().addField("PART_ID", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build(); @@ -78,9 +78,9 @@ public void setup() this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", 
\"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.saphanaMetadataHandler = new SaphanaMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); this.blockAllocator = Mockito.mock(BlockAllocator.class); diff --git a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcMetadataHandlerTest.java b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcMetadataHandlerTest.java index 1e4dc8fffd..bab1b3d9c2 100644 --- a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcMetadataHandlerTest.java +++ b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcMetadataHandlerTest.java @@ -28,11 +28,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -45,8 +45,8 @@ public class SaphanaMuxJdbcMetadataHandlerTest { private SaphanaMetadataHandler saphanaMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -58,8 +58,8 @@ public void setup() //Mockito.when(this.allocator.createBlock(nullable(Schema.class))).thenReturn(Mockito.mock(Block.class)); this.saphanaMetadataHandler = Mockito.mock(SaphanaMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.saphanaMetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", diff --git a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcRecordHandlerTest.java b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcRecordHandlerTest.java index 1336af4d7c..5f17964b44 100644 --- a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcRecordHandlerTest.java +++ 
b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcRecordHandlerTest.java @@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -46,9 +46,9 @@ public class SaphanaMuxJdbcRecordHandlerTest private Map recordHandlerMap; private SaphanaRecordHandler saphanaRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -57,9 +57,9 @@ public void setup() { this.saphanaRecordHandler = Mockito.mock(SaphanaRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("saphana", this.saphanaRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "saphana", diff --git a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandlerTest.java b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandlerTest.java index 833dd0a0d3..c48ced9e6c 100644 --- a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandlerTest.java +++ b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandlerTest.java @@ -32,9 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -43,6 +40,9 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -58,17 +58,17 @@ public class 
diff --git a/athena-snowflake/Dockerfile b/athena-snowflake/Dockerfile
new file mode 100644
index 0000000000..8d4d9081a6
--- /dev/null
+++ b/athena-snowflake/Dockerfile
@@ -0,0 +1,9 @@
+FROM public.ecr.aws/lambda/java:11
+
+# Copy function code and runtime dependencies from Maven layout
+COPY target/athena-snowflake.zip ${LAMBDA_TASK_ROOT}
+# Unpack the jar
+RUN jar xf athena-snowflake.zip
+
+# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile)
+CMD [ "com.amazonaws.athena.connectors.snowflake.SnowflakeMuxCompositeHandler" ]
\ No newline at end of file
diff --git a/athena-snowflake/athena-snowflake.yaml b/athena-snowflake/athena-snowflake.yaml
index 24edcaf9f8..67bac6e7aa 100644
--- a/athena-snowflake/athena-snowflake.yaml
+++ b/athena-snowflake/athena-snowflake.yaml
@@ -69,10 +69,9 @@ Resources:
           spill_prefix: !Ref SpillPrefix
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
-      Handler: "com.amazonaws.athena.connectors.snowflake.SnowflakeMuxCompositeHandler"
-      CodeUri: "./target/athena-snowflake.zip"
+      PackageType: "Image"
+      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-snowflake:2022.47.1'
       Description: "Enables Amazon Athena to communicate with Snowflake using JDBC"
-      Runtime: java11
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
       PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ]
diff --git a/athena-snowflake/pom.xml b/athena-snowflake/pom.xml
index fcb730044b..aec0e7f807 100644
--- a/athena-snowflake/pom.xml
+++ b/athena-snowflake/pom.xml
@@ -32,12 +32,18 @@
         <artifactId>snowflake-jdbc</artifactId>
         <version>3.19.0</version>
-
+
     <dependency>
-        <groupId>com.amazonaws</groupId>
-        <artifactId>aws-java-sdk-rds</artifactId>
-        <version>${aws-sdk.version}</version>
+        <groupId>software.amazon.awssdk</groupId>
+        <artifactId>rds</artifactId>
+        <version>${aws-sdk-v2.version}</version>
         <scope>test</scope>
+        <exclusions>
+            <exclusion>
+                <groupId>software.amazon.awssdk</groupId>
+                <artifactId>netty-nio-client</artifactId>
+            </exclusion>
+        </exclusions>
     </dependency>
diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java
index 234718760a..d9c9b9d3f1 100644
--- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java
+++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java
@@ -55,8 +55,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Strings;
 import com.google.common.collect.ImmutableMap;
@@ -68,6 +66,8 @@ import org.apache.arrow.vector.types.pojo.Schema;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.DatabaseMetaData;
@@ -139,8 +139,8 @@ JDBC_PROPERTIES, new DatabaseConnectionInfo(SnowflakeConstants.SNOWFLAKE_DRIVER_
     @VisibleForTesting
     protected SnowflakeMetadataHandler(
         DatabaseConnectionConfig databaseConnectionConfig,
-        AWSSecretsManager secretsManager,
-        AmazonAthena athena,
+        SecretsManagerClient secretsManager,
+        AthenaClient athena,
         JdbcConnectionFactory jdbcConnectionFactory,
         java.util.Map configOptions)
     {
diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxMetadataHandler.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxMetadataHandler.java
index 5754598bd8..17c85a6189 100644
--- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxMetadataHandler.java
+++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxMetadataHandler.java
@@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import org.apache.arrow.util.VisibleForTesting;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Map;
@@ -56,7 +56,7 @@ public SnowflakeMuxMetadataHandler(java.util.Map configOptions)
     }
     @VisibleForTesting
-    protected SnowflakeMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory,
+    protected SnowflakeMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory,
         Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions)
     {
         super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions);
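The Mux handlers touched here only change their constructor parameter types; the multiplexing model itself is untouched. As a rough illustration of why these constructors take a Map of per-catalog delegates, here is a simplified stand-in (illustrative only; the real base classes in the connector SDK add validation and lifecycle handling on top of this idea):

    import java.util.Map;

    // Sketch of the multiplexing pattern: one Lambda fronting several catalogs,
    // each request routed to the handler registered for its catalog name.
    final class MuxDelegationSketch<T>
    {
        private final Map<String, T> delegates;

        MuxDelegationSketch(Map<String, T> delegates)
        {
            this.delegates = delegates;
        }

        T delegateFor(String catalogName)
        {
            T handler = delegates.get(catalogName);
            if (handler == null) {
                throw new IllegalArgumentException("No handler configured for catalog " + catalogName);
            }
            return handler;
        }
    }

This is why the tests below can drive a Mux handler with a Collections.singletonMap("fakedatabase", mockHandler) and mocked AWS clients.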
diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxRecordHandler.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxRecordHandler.java
index 07694a20a6..2fb0812375 100644
--- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxRecordHandler.java
+++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxRecordHandler.java
@@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Map;
@@ -55,7 +55,7 @@ public SnowflakeMuxRecordHandler(java.util.Map configOptions)
     }
     @VisibleForTesting
-    SnowflakeMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory,
+    SnowflakeMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory,
         DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions)
     {
         super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions);
diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java
index 1f4b8ad1b2..28ac13ff21 100644
--- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java
+++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java
@@ -30,15 +30,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.s3.AmazonS3ClientBuilder;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
-import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.commons.lang3.Validate;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -68,12 +65,12 @@ public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig,
     }
     public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, GenericJdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions)
     {
-        this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(),
+        this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(),
             jdbcConnectionFactory, new SnowflakeQueryStringBuilder(SNOWFLAKE_QUOTE_CHARACTER, new SnowflakeFederationExpressionParser(SNOWFLAKE_QUOTE_CHARACTER)), configOptions);
     }
     @VisibleForTesting
-    SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final AWSSecretsManager secretsManager,
-        final AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions)
+    SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager,
+        final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions)
     {
         super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions);
         this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null");
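The constructor above shows the client-creation half of the migration: the v1 *ClientBuilder.defaultClient() calls become the v2 static create() factories. Both resolve region and credentials from the default provider chains, so behavior in Lambda is equivalent. A small sketch of the two v2 forms, shown for S3 (the pinned region is an illustrative value, not something this patch sets):

    import software.amazon.awssdk.regions.Region;
    import software.amazon.awssdk.services.s3.S3Client;

    final class ClientCreationSketch
    {
        static S3Client defaultClient()
        {
            // v2 counterpart of AmazonS3ClientBuilder.defaultClient():
            // region and credentials come from the default provider chains.
            return S3Client.create();
        }

        static S3Client pinnedRegionClient()
        {
            // Explicit builder form, useful when the environment does not
            // supply AWS_REGION (for example in some local test setups).
            return S3Client.builder().region(Region.US_EAST_1).build();
        }
    }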
\"testPassword\"}").build()); this.snowflakeMetadataHandler = new SnowflakeMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of()); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); this.blockAllocator = Mockito.mock(BlockAllocator.class); diff --git a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcMetadataHandlerTest.java b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcMetadataHandlerTest.java index c566df626d..b6279fa517 100644 --- a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcMetadataHandlerTest.java +++ b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcMetadataHandlerTest.java @@ -28,11 +28,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -45,8 +45,8 @@ public class SnowflakeMuxJdbcMetadataHandlerTest private SnowflakeMetadataHandler snowflakeMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -56,8 +56,8 @@ public void setup() this.allocator = new BlockAllocatorImpl(); this.snowflakeMetadataHandler = Mockito.mock(SnowflakeMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.snowflakeMetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", diff --git a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcRecordHandlerTest.java b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcRecordHandlerTest.java index 5dee99bfb7..367fde0afc 100644 --- a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcRecordHandlerTest.java +++ b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcRecordHandlerTest.java @@ -30,13 +30,13 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.snowflake.SnowflakeMuxRecordHandler; import com.amazonaws.athena.connectors.snowflake.SnowflakeRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import 
com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -50,9 +50,9 @@ public class SnowflakeMuxJdbcRecordHandlerTest private Map recordHandlerMap; private SnowflakeRecordHandler snowflakeRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -61,9 +61,9 @@ public void setup() { this.snowflakeRecordHandler = Mockito.mock(SnowflakeRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("snowflake", this.snowflakeRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "snowflake", diff --git a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandlerTest.java b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandlerTest.java index 006cbd95e5..56531dcdac 100644 --- a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandlerTest.java +++ b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandlerTest.java @@ -33,9 +33,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -44,6 +41,9 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -59,17 +59,17 @@ public class SnowflakeRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() throws Exception { - 
this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-sqlserver/Dockerfile b/athena-sqlserver/Dockerfile new file mode 100644 index 0000000000..e602b9fc50 --- /dev/null +++ b/athena-sqlserver/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-sqlserver-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-sqlserver-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.sqlserver.SqlServerMuxCompositeHandler" ] \ No newline at end of file diff --git a/athena-sqlserver/athena-sqlserver.yaml b/athena-sqlserver/athena-sqlserver.yaml index 22c9e1b89c..9f09edfc88 100644 --- a/athena-sqlserver/athena-sqlserver.yaml +++ b/athena-sqlserver/athena-sqlserver.yaml @@ -76,10 +76,9 @@ Resources: spill_prefix: !Ref SpillPrefix default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.sqlserver.SqlServerMuxCompositeHandler" - CodeUri: "./target/athena-sqlserver-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-sqlserver:2022.47.1' Description: "Enables Amazon Athena to communicate with SQLSERVER using JDBC" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-sqlserver/pom.xml b/athena-sqlserver/pom.xml index 6bfab70343..3b723f0a87 100644 --- a/athena-sqlserver/pom.xml +++ b/athena-sqlserver/pom.xml @@ -32,12 +32,18 @@ mssql-jdbc ${mssql.jdbc.version} - + - com.amazonaws - aws-java-sdk-rds - ${aws-sdk.version} + software.amazon.awssdk + rds + ${aws-sdk-v2.version} test + + + software.amazon.awssdk + netty-nio-client + + diff --git a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandler.java b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandler.java index 604796ca3b..6321f91f38 100644 --- a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandler.java +++ b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandler.java @@ -54,8 +54,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -67,6 +65,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; 
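A detail worth noting in these setups: the connection factory is stubbed with nullable(JdbcCredentialProvider.class) rather than any(). Since Mockito 2, any(Class) no longer matches a null argument, and some code paths pass no credential provider, so nullable() is the matcher that covers both cases. A stripped-down illustration with hypothetical stand-in interfaces (not the connector SDK's real types):

    import static org.mockito.ArgumentMatchers.nullable;
    import org.mockito.Mockito;

    final class NullableMatcherSketch
    {
        // Hypothetical stand-ins for the connector interfaces.
        interface CredentialProvider {}
        interface ConnectionFactory
        {
            Object getConnection(CredentialProvider provider);
        }

        static ConnectionFactory stubbedFactory(Object connection)
        {
            ConnectionFactory factory = Mockito.mock(ConnectionFactory.class);
            // nullable() matches both a real provider and null;
            // any(CredentialProvider.class) would fail to match null.
            Mockito.when(factory.getConnection(nullable(CredentialProvider.class))).thenReturn(connection);
            return factory;
        }
    }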
diff --git a/athena-sqlserver/Dockerfile b/athena-sqlserver/Dockerfile
new file mode 100644
index 0000000000..e602b9fc50
--- /dev/null
+++ b/athena-sqlserver/Dockerfile
@@ -0,0 +1,9 @@
+FROM public.ecr.aws/lambda/java:11
+
+# Copy function code and runtime dependencies from Maven layout
+COPY target/athena-sqlserver-2022.47.1.jar ${LAMBDA_TASK_ROOT}
+# Unpack the jar
+RUN jar xf athena-sqlserver-2022.47.1.jar
+
+# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile)
+CMD [ "com.amazonaws.athena.connectors.sqlserver.SqlServerMuxCompositeHandler" ]
\ No newline at end of file
diff --git a/athena-sqlserver/athena-sqlserver.yaml b/athena-sqlserver/athena-sqlserver.yaml
index 22c9e1b89c..9f09edfc88 100644
--- a/athena-sqlserver/athena-sqlserver.yaml
+++ b/athena-sqlserver/athena-sqlserver.yaml
@@ -76,10 +76,9 @@ Resources:
           spill_prefix: !Ref SpillPrefix
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
-      Handler: "com.amazonaws.athena.connectors.sqlserver.SqlServerMuxCompositeHandler"
-      CodeUri: "./target/athena-sqlserver-2022.47.1.jar"
+      PackageType: "Image"
+      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-sqlserver:2022.47.1'
       Description: "Enables Amazon Athena to communicate with SQLSERVER using JDBC"
-      Runtime: java11
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
       PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ]
diff --git a/athena-sqlserver/pom.xml b/athena-sqlserver/pom.xml
index 6bfab70343..3b723f0a87 100644
--- a/athena-sqlserver/pom.xml
+++ b/athena-sqlserver/pom.xml
@@ -32,12 +32,18 @@
         <artifactId>mssql-jdbc</artifactId>
        <version>${mssql.jdbc.version}</version>
-
+
     <dependency>
-        <groupId>com.amazonaws</groupId>
-        <artifactId>aws-java-sdk-rds</artifactId>
-        <version>${aws-sdk.version}</version>
+        <groupId>software.amazon.awssdk</groupId>
+        <artifactId>rds</artifactId>
+        <version>${aws-sdk-v2.version}</version>
         <scope>test</scope>
+        <exclusions>
+            <exclusion>
+                <groupId>software.amazon.awssdk</groupId>
+                <artifactId>netty-nio-client</artifactId>
+            </exclusion>
+        </exclusions>
     </dependency>
diff --git a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandler.java b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandler.java
index 604796ca3b..6321f91f38 100644
--- a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandler.java
+++ b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandler.java
@@ -54,8 +54,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
@@ -67,6 +65,8 @@ import org.apache.arrow.vector.types.pojo.Schema;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.DatabaseMetaData;
@@ -150,8 +150,8 @@ public SqlServerMetadataHandler(DatabaseConnectionConfi
     @VisibleForTesting
     protected SqlServerMetadataHandler(
         DatabaseConnectionConfig databaseConnectionConfig,
-        AWSSecretsManager secretsManager,
-        AmazonAthena athena,
+        SecretsManagerClient secretsManager,
+        AthenaClient athena,
         JdbcConnectionFactory jdbcConnectionFactory,
         java.util.Map configOptions)
     {
diff --git a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxMetadataHandler.java b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxMetadataHandler.java
index 4320ac8afe..7ca913e30e 100644
--- a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxMetadataHandler.java
+++ b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxMetadataHandler.java
@@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import org.apache.arrow.util.VisibleForTesting;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Map;
@@ -56,7 +56,7 @@ public SqlServerMuxMetadataHandler(java.util.Map configOptions)
     }
     @VisibleForTesting
-    protected SqlServerMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory,
+    protected SqlServerMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory,
         Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions)
     {
         super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions);
diff --git a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandler.java b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandler.java
index 5a195fc0f3..e9d5009639 100644
--- a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandler.java
+++ b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandler.java
@@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Map;
@@ -55,7 +55,7 @@ public SqlServerMuxRecordHandler(java.util.Map configOptions)
     }
     @VisibleForTesting
-    SqlServerMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory,
+    SqlServerMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory,
         DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions)
     {
         super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions);
diff --git a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandler.java b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandler.java
index e1b64e79f5..073f5ad946 100644
--- a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandler.java
+++ b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandler.java
@@ -29,15 +29,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.s3.AmazonS3ClientBuilder;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
-import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.commons.lang3.Validate;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -63,13 +60,13 @@ public SqlServerRecordHandler(DatabaseConnectionConfig databaseConnectionConfig,
     public SqlServerRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions)
     {
-        this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(),
-            AmazonAthenaClientBuilder.defaultClient(), jdbcConnectionFactory, new SqlServerQueryStringBuilder(SQLSERVER_QUOTE_CHARACTER, new SqlServerFederationExpressionParser(SQLSERVER_QUOTE_CHARACTER)), configOptions);
+        this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(),
+            AthenaClient.create(), jdbcConnectionFactory, new SqlServerQueryStringBuilder(SQLSERVER_QUOTE_CHARACTER, new SqlServerFederationExpressionParser(SQLSERVER_QUOTE_CHARACTER)), configOptions);
     }
     @VisibleForTesting
-    SqlServerRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final AWSSecretsManager secretsManager,
-        final AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions)
+    SqlServerRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager,
+        final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions)
     {
         super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions);
         this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null");
diff --git a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandlerTest.java b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandlerTest.java
index 27704cfdb1..e8e46b63fc 100644
--- a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandlerTest.java
+++ b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMetadataHandlerTest.java
@@ -42,10 +42,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
-import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest;
-import com.amazonaws.services.secretsmanager.model.GetSecretValueResult;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.junit.Assert;
@@ -54,6 +50,10 @@ import org.mockito.Mockito;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
+import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
+import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -84,8 +84,8 @@ public class SqlServerMetadataHandlerTest
     private JdbcConnectionFactory jdbcConnectionFactory;
     private Connection connection;
     private FederatedIdentity federatedIdentity;
-    private AWSSecretsManager secretsManager;
-    private AmazonAthena athena;
+    private SecretsManagerClient secretsManager;
+    private AthenaClient athena;
     private BlockAllocator allocator;
     @Before
@@ -97,9 +97,9 @@ public void setup()
         this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS);
         logger.info(" this.connection.."+ this.connection);
         Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection);
-        this.secretsManager = Mockito.mock(AWSSecretsManager.class);
-        this.athena = Mockito.mock(AmazonAthena.class);
-        Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}"));
+        this.secretsManager = Mockito.mock(SecretsManagerClient.class);
+        this.athena = Mockito.mock(AthenaClient.class);
+        Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}").build());
         this.sqlServerMetadataHandler = new SqlServerMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of());
         this.federatedIdentity = Mockito.mock(FederatedIdentity.class);
         this.allocator = new BlockAllocatorImpl();
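Note the secret payloads are not uniform across connectors: this test expects a {"user": ...} key, while the Saphana and Snowflake tests above use {"username": ...}, so credential extraction must agree with the stored shape. A hedged sketch of pulling JDBC credentials out of such a secret string, assuming a Jackson ObjectMapper is available (the class and key handling are illustrative, not the connectors' actual parsing code):

    import com.fasterxml.jackson.databind.JsonNode;
    import com.fasterxml.jackson.databind.ObjectMapper;

    final class SecretParsingSketch
    {
        static String[] userAndPassword(String secretString) throws Exception
        {
            JsonNode node = new ObjectMapper().readTree(secretString);
            // Some secrets store the login under "user", others under "username".
            JsonNode user = node.has("user") ? node.get("user") : node.get("username");
            return new String[] {user.asText(), node.get("password").asText()};
        }
    }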
diff --git a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxMetadataHandlerTest.java b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxMetadataHandlerTest.java
index 96f8b3e5e2..cf8ec137a1 100644
--- a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxMetadataHandlerTest.java
+++ b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxMetadataHandlerTest.java
@@ -32,11 +32,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Collections;
 import java.util.Map;
@@ -49,8 +49,8 @@ public class SqlServerMuxMetadataHandlerTest
     private SqlServerMetadataHandler sqlServerMetadataHandler;
     private JdbcMetadataHandler jdbcMetadataHandler;
     private BlockAllocator allocator;
-    private AWSSecretsManager secretsManager;
-    private AmazonAthena athena;
+    private SecretsManagerClient secretsManager;
+    private AthenaClient athena;
     private QueryStatusChecker queryStatusChecker;
     private JdbcConnectionFactory jdbcConnectionFactory;
@@ -60,8 +60,8 @@ public void setup()
         this.allocator = new BlockAllocatorImpl();
         this.sqlServerMetadataHandler = Mockito.mock(SqlServerMetadataHandler.class);
         this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.sqlServerMetadataHandler);
-        this.secretsManager = Mockito.mock(AWSSecretsManager.class);
-        this.athena = Mockito.mock(AmazonAthena.class);
+        this.secretsManager = Mockito.mock(SecretsManagerClient.class);
+        this.athena = Mockito.mock(AthenaClient.class);
         this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class);
         this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class);
         DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase",
diff --git a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandlerTest.java b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandlerTest.java
index 7b337185fd..e6faa255d9 100644
--- a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandlerTest.java
+++ b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandlerTest.java
@@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.SQLException;
@@ -46,9 +46,9 @@ public class SqlServerMuxRecordHandlerTest
     private Map recordHandlerMap;
     private SqlServerRecordHandler sqlServerRecordHandler;
     private JdbcRecordHandler jdbcRecordHandler;
-    private AmazonS3 amazonS3;
-    private AWSSecretsManager secretsManager;
-    private AmazonAthena athena;
+    private S3Client amazonS3;
+    private SecretsManagerClient secretsManager;
+    private AthenaClient athena;
     private QueryStatusChecker queryStatusChecker;
     private JdbcConnectionFactory jdbcConnectionFactory;
@@ -57,9 +57,9 @@ public void setup()
 {
         this.sqlServerRecordHandler = Mockito.mock(SqlServerRecordHandler.class);
         this.recordHandlerMap = Collections.singletonMap(SqlServerConstants.NAME, this.sqlServerRecordHandler);
-        this.amazonS3 = Mockito.mock(AmazonS3.class);
-        this.secretsManager = Mockito.mock(AWSSecretsManager.class);
-        this.athena = Mockito.mock(AmazonAthena.class);
+        this.amazonS3 = Mockito.mock(S3Client.class);
+        this.secretsManager = Mockito.mock(SecretsManagerClient.class);
+        this.athena = Mockito.mock(AthenaClient.class);
         this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class);
         this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class);
         DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", SqlServerConstants.NAME,
diff --git a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandlerTest.java b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandlerTest.java
index 8ca6ebf791..c6f8f659dd 100644
--- a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandlerTest.java
+++ b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandlerTest.java
@@ -32,9 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.collect.ImmutableMap;
 import org.apache.arrow.vector.types.Types;
 import org.apache.arrow.vector.types.pojo.Schema;
@@ -42,6 +39,9 @@ import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -57,18 +57,18 @@ public class SqlServerRecordHandlerTest
     private Connection connection;
     private JdbcConnectionFactory jdbcConnectionFactory;
     private JdbcSplitQueryBuilder jdbcSplitQueryBuilder;
-    private AmazonS3 amazonS3;
-    private AWSSecretsManager secretsManager;
-    private AmazonAthena athena;
+    private S3Client amazonS3;
+    private SecretsManagerClient secretsManager;
+    private AthenaClient athena;
     @Before
     public void setup()
             throws Exception
     {
         System.setProperty("aws.region", "us-east-1");
-        this.amazonS3 = Mockito.mock(AmazonS3.class);
-        this.secretsManager = Mockito.mock(AWSSecretsManager.class);
-        this.athena = Mockito.mock(AmazonAthena.class);
+        this.amazonS3 = Mockito.mock(S3Client.class);
+        this.secretsManager = Mockito.mock(SecretsManagerClient.class);
+        this.athena = Mockito.mock(AthenaClient.class);
         this.connection = Mockito.mock(Connection.class);
         this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class);
         Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection);
diff --git a/athena-synapse/Dockerfile b/athena-synapse/Dockerfile
new file mode 100644
index 0000000000..2a7a05ec98
--- /dev/null
+++ b/athena-synapse/Dockerfile
@@ -0,0 +1,9 @@
+FROM public.ecr.aws/lambda/java:11
+
+# Copy function code and runtime dependencies from Maven layout
+COPY target/athena-synapse-2022.47.1.jar ${LAMBDA_TASK_ROOT}
+# Unpack the jar
+RUN jar xf athena-synapse-2022.47.1.jar
+
+# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile)
+CMD [ "com.amazonaws.athena.connectors.synapse.SynapseMuxCompositeHandler" ]
\ No newline at end of file
diff --git a/athena-synapse/athena-synapse.yaml b/athena-synapse/athena-synapse.yaml
index 3414752a64..1aa43f00be 100644
--- a/athena-synapse/athena-synapse.yaml
+++ b/athena-synapse/athena-synapse.yaml
@@ -78,10 +78,9 @@ Resources:
           spill_prefix: !Ref SpillPrefix
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
-      Handler: "com.amazonaws.athena.connectors.synapse.SynapseMuxCompositeHandler"
-      CodeUri: "./target/athena-synapse-2022.47.1.jar"
+      PackageType: "Image"
+      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-synapse:2022.47.1'
       Description: "Enables Amazon Athena to communicate with SYNAPSE using JDBC"
-      Runtime: java11
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
       Role: !If [NotHasLambdaRole, !GetAtt FunctionRole.Arn, !Ref LambdaRoleARN]
diff --git a/athena-synapse/pom.xml b/athena-synapse/pom.xml
index 6f1440edd9..bb40a5cad7 100644
--- a/athena-synapse/pom.xml
+++ b/athena-synapse/pom.xml
@@ -59,12 +59,18 @@
-
+
     <dependency>
-        <groupId>com.amazonaws</groupId>
-        <artifactId>aws-java-sdk-rds</artifactId>
-        <version>${aws-sdk.version}</version>
+        <groupId>software.amazon.awssdk</groupId>
+        <artifactId>rds</artifactId>
+        <version>${aws-sdk-v2.version}</version>
         <scope>test</scope>
+        <exclusions>
+            <exclusion>
+                <groupId>software.amazon.awssdk</groupId>
+                <artifactId>netty-nio-client</artifactId>
+            </exclusion>
+        </exclusions>
     </dependency>
diff --git a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandler.java b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandler.java
index 1ad714e9c5..f35bbe34a2 100644
--- a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandler.java
+++ b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandler.java
@@ -48,8 +48,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
@@ -63,6 +61,8 @@ import org.stringtemplate.v4.ST;
 import org.stringtemplate.v4.STGroup;
 import org.stringtemplate.v4.STGroupDir;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.DatabaseMetaData;
@@ -111,8 +111,8 @@ public SynapseMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig,
     @VisibleForTesting
     protected SynapseMetadataHandler(
         DatabaseConnectionConfig databaseConnectionConfig,
-        AWSSecretsManager secretsManager,
-        AmazonAthena athena,
+        SecretsManagerClient secretsManager,
+        AthenaClient athena,
         JdbcConnectionFactory jdbcConnectionFactory,
         java.util.Map configOptions)
     {
diff --git a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxMetadataHandler.java b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxMetadataHandler.java
index daa1b25609..8947d626b7 100644
--- a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxMetadataHandler.java
+++ b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxMetadataHandler.java
@@ -24,9 +24,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import org.apache.arrow.util.VisibleForTesting;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Map;
@@ -55,7 +55,7 @@ public SynapseMuxMetadataHandler(java.util.Map configOptions)
     }
     @VisibleForTesting
-    protected SynapseMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory,
+    protected SynapseMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory,
         Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions)
     {
         super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions);
diff --git a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandler.java b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandler.java
index 62b0f315c3..6fabc20bf7 100644
--- a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandler.java
+++ b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandler.java
@@ -24,10 +24,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Map;
@@ -54,7 +54,7 @@ public SynapseMuxRecordHandler(java.util.Map configOptions)
     }
     @VisibleForTesting
-    SynapseMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory,
+    SynapseMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory,
         DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions)
     {
         super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions);
diff --git a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandler.java b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandler.java
index f7a6fa70cd..a7a6aed815 100644
--- a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandler.java
+++ b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandler.java
@@ -33,12 +33,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.s3.AmazonS3ClientBuilder;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
-import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
@@ -46,6 +40,9 @@ import org.apache.commons.lang3.Validate;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -66,15 +63,15 @@ public SynapseRecordHandler(java.util.Map configOptions)
     }
     public SynapseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions)
     {
-        this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(),
-            AmazonAthenaClientBuilder.defaultClient(), new SynapseJdbcConnectionFactory(databaseConnectionConfig,
+        this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(),
+            AthenaClient.create(), new SynapseJdbcConnectionFactory(databaseConnectionConfig,
             SynapseMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(SynapseConstants.DRIVER_CLASS, SynapseConstants.DEFAULT_PORT)), new SynapseQueryStringBuilder(QUOTE_CHARACTER, new SynapseFederationExpressionParser(QUOTE_CHARACTER)), configOptions);
     }
     @VisibleForTesting
-    SynapseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final AWSSecretsManager secretsManager,
-        final AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions)
+    SynapseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager,
+        final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions)
     {
         super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions);
         this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null");
diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java
index 037b0540e5..7e90ae436f 100644
--- a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java
+++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java
@@ -38,10 +38,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
-import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest;
-import com.amazonaws.services.secretsmanager.model.GetSecretValueResult;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.junit.Assert;
@@ -50,7 +46,10 @@ import org.mockito.Mockito;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
+import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
+import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
@@ -80,8 +79,8 @@ public class SynapseMetadataHandlerTest
     private JdbcConnectionFactory jdbcConnectionFactory;
     private Connection connection;
     private FederatedIdentity federatedIdentity;
-    private AWSSecretsManager secretsManager;
-    private AmazonAthena athena;
+    private SecretsManagerClient secretsManager;
+    private AthenaClient athena;
     @Before
     public void setup()
@@ -92,9 +91,9 @@ public void setup()
         this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS);
         logger.info(" this.connection.."+ this.connection);
         Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection);
-        this.secretsManager = Mockito.mock(AWSSecretsManager.class);
-        this.athena = Mockito.mock(AmazonAthena.class);
-        Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}"));
+        this.secretsManager = Mockito.mock(SecretsManagerClient.class);
+        this.athena = Mockito.mock(AthenaClient.class);
+        Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"user\": \"testUser\", \"password\": \"testPassword\"}").build());
         this.synapseMetadataHandler = new SynapseMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of());
         this.federatedIdentity = Mockito.mock(FederatedIdentity.class);
     }
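All of the handlers changed in this patch share the same testing seam, which is what makes these mechanical type swaps possible: the public constructors build real AWS clients, while a @VisibleForTesting constructor accepts them, letting the tests inject Mockito mocks. Schematically, with placeholder names (not a class from this repository):

    import com.google.common.annotations.VisibleForTesting;
    import software.amazon.awssdk.services.athena.AthenaClient;
    import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

    abstract class HandlerSeamSketch
    {
        private final SecretsManagerClient secretsManager;
        private final AthenaClient athena;

        // Production path: build real clients from the Lambda environment.
        HandlerSeamSketch()
        {
            this(SecretsManagerClient.create(), AthenaClient.create());
        }

        // Test path: inject mocks; chaining keeps one canonical initializer.
        @VisibleForTesting
        HandlerSeamSketch(SecretsManagerClient secretsManager, AthenaClient athena)
        {
            this.secretsManager = secretsManager;
            this.athena = athena;
        }
    }

Because only the seam's parameter types change, the v2 migration leaves each handler's behavior and test structure intact.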
diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxMetadataHandlerTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxMetadataHandlerTest.java
index 9b233d4201..8c593cb948 100644
--- a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxMetadataHandlerTest.java
+++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxMetadataHandlerTest.java
@@ -32,11 +32,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.util.Collections;
 import java.util.Map;
@@ -49,8 +49,8 @@ public class SynapseMuxMetadataHandlerTest
     private SynapseMetadataHandler synapseMetadataHandler;
     private JdbcMetadataHandler jdbcMetadataHandler;
     private BlockAllocator allocator;
-    private AWSSecretsManager secretsManager;
-    private AmazonAthena athena;
+    private SecretsManagerClient secretsManager;
+    private AthenaClient athena;
     private QueryStatusChecker queryStatusChecker;
     private JdbcConnectionFactory jdbcConnectionFactory;
@@ -60,8 +60,8 @@ public void setup()
         this.allocator = new BlockAllocatorImpl();
         this.synapseMetadataHandler = Mockito.mock(SynapseMetadataHandler.class);
         this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.synapseMetadataHandler);
-        this.secretsManager = Mockito.mock(AWSSecretsManager.class);
-        this.athena = Mockito.mock(AmazonAthena.class);
+        this.secretsManager = Mockito.mock(SecretsManagerClient.class);
+        this.athena = Mockito.mock(AthenaClient.class);
         this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class);
         this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class);
         DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase",
diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandlerTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandlerTest.java
index 0b3daed367..3ed375cae4 100644
--- a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandlerTest.java
+++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandlerTest.java
@@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.SQLException;
@@ -46,9 +46,9 @@ public class SynapseMuxRecordHandlerTest
     private Map recordHandlerMap;
     private SynapseRecordHandler synapseRecordHandler;
     private JdbcRecordHandler jdbcRecordHandler;
-    private AmazonS3 amazonS3;
-    private AWSSecretsManager secretsManager;
-    private AmazonAthena athena;
+    private S3Client amazonS3;
+    private SecretsManagerClient secretsManager;
+    private AthenaClient athena;
     private QueryStatusChecker queryStatusChecker;
     private JdbcConnectionFactory jdbcConnectionFactory;
@@ -57,9 +57,9 @@ public void setup()
 {
         this.synapseRecordHandler = Mockito.mock(SynapseRecordHandler.class);
         this.recordHandlerMap = Collections.singletonMap(SynapseConstants.NAME, this.synapseRecordHandler);
-        this.amazonS3 = Mockito.mock(AmazonS3.class);
-        this.secretsManager = Mockito.mock(AWSSecretsManager.class);
-        this.athena = Mockito.mock(AmazonAthena.class);
+        this.amazonS3 = Mockito.mock(S3Client.class);
+        this.secretsManager = Mockito.mock(SecretsManagerClient.class);
+        this.athena = Mockito.mock(AthenaClient.class);
         this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class);
         this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class);
         DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", SynapseConstants.NAME,
diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java
index c8f27f7887..b0108974cc 100644
--- a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java
+++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java
@@ -31,9 +31,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
 import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.collect.ImmutableMap;
 import org.apache.arrow.vector.types.Types;
 import org.apache.arrow.vector.types.pojo.Schema;
@@ -41,6 +38,9 @@ import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -58,17 +58,17 @@ public class SynapseRecordHandlerTest
     private Connection connection;
     private JdbcConnectionFactory jdbcConnectionFactory;
     private JdbcSplitQueryBuilder jdbcSplitQueryBuilder;
-    private AmazonS3 amazonS3;
-    private AWSSecretsManager secretsManager;
-    private AmazonAthena athena;
+    private S3Client amazonS3;
+    private SecretsManagerClient secretsManager;
+    private AthenaClient athena;
     @Before
     public void setup()
             throws Exception
     {
-        this.amazonS3 = Mockito.mock(AmazonS3.class);
-        this.secretsManager = Mockito.mock(AWSSecretsManager.class);
-        this.athena = Mockito.mock(AmazonAthena.class);
+        this.amazonS3 = Mockito.mock(S3Client.class);
+        this.secretsManager = Mockito.mock(SecretsManagerClient.class);
+        this.athena = Mockito.mock(AthenaClient.class);
         this.connection = Mockito.mock(Connection.class);
         this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class);
         Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection);
diff --git a/athena-teradata/Dockerfile b/athena-teradata/Dockerfile
new file mode 100644
index 0000000000..8f58411065
--- /dev/null
+++ b/athena-teradata/Dockerfile
@@ -0,0 +1,9 @@
+FROM public.ecr.aws/lambda/java:11
+
+# Copy function code and runtime dependencies from Maven layout
+COPY target/athena-teradata-2022.47.1.jar ${LAMBDA_TASK_ROOT}
+# Unpack the jar
+RUN jar xf athena-teradata-2022.47.1.jar
+
+# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile)
+CMD [ "com.amazonaws.athena.connectors.teradata.TeradataMuxCompositeHandler" ]
\ No newline at end of file
diff --git a/athena-teradata/athena-teradata.yaml b/athena-teradata/athena-teradata.yaml
index 3e1431c6c0..01ee09c1db 100644
--- a/athena-teradata/athena-teradata.yaml
+++ b/athena-teradata/athena-teradata.yaml
@@ -30,9 +30,6 @@ Parameters:
     Description: 'The prefix within SpillBucket where this function can spill data.'
     Type: String
     Default: athena-spill
-  LambdaJDBCLayername:
-    Description: 'Lambda JDBC layer Name. Must be ARN of layer'
-    Type: String
   LambdaTimeout:
     Description: 'Maximum Lambda invocation runtime in seconds. (min 1 - 900 max)'
     Default: 900
@@ -77,12 +74,9 @@ Resources:
           default: !Ref DefaultConnectionString
           partitioncount: !Ref PartitionCount
       FunctionName: !Ref LambdaFunctionName
-      Handler: "com.amazonaws.athena.connectors.teradata.TeradataMuxCompositeHandler"
-      Layers:
-        - !Ref LambdaJDBCLayername
-      CodeUri: "./target/athena-teradata-2022.47.1.jar"
+      PackageType: "Image"
+      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-teradata:2022.47.1'
       Description: "Enables Amazon Athena to communicate with Teradata using JDBC"
-      Runtime: java11
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
       PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ]
diff --git a/athena-teradata/pom.xml b/athena-teradata/pom.xml
index f643b8709d..c44fd9da86 100644
--- a/athena-teradata/pom.xml
+++ b/athena-teradata/pom.xml
@@ -27,12 +27,18 @@
         <type>test-jar</type>
         <scope>test</scope>
-
+
     <dependency>
-        <groupId>com.amazonaws</groupId>
-        <artifactId>aws-java-sdk-rds</artifactId>
-        <version>${aws-sdk.version}</version>
+        <groupId>software.amazon.awssdk</groupId>
+        <artifactId>rds</artifactId>
+        <version>${aws-sdk-v2.version}</version>
         <scope>test</scope>
+        <exclusions>
+            <exclusion>
+                <groupId>software.amazon.awssdk</groupId>
+                <artifactId>netty-nio-client</artifactId>
+            </exclusion>
+        </exclusions>
     </dependency>
@@ -53,6 +59,11 @@
         <version>${mockito.version}</version>
         <scope>test</scope>
+    <dependency>
+        <groupId>com.teradata.jdbc</groupId>
+        <artifactId>terajdbc</artifactId>
+        <version>20.00.00.34</version>
+    </dependency>
diff --git a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMetadataHandler.java b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMetadataHandler.java
index b7d7ffe201..d230e8dcbf 100644
--- a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMetadataHandler.java
+++ b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMetadataHandler.java
@@ -50,8 +50,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
 import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
 import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
@@ -65,6 +63,8 @@ import org.apache.arrow.vector.types.pojo.Schema;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.services.athena.AthenaClient;
+import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
 import java.sql.Connection;
 import java.sql.DatabaseMetaData;
@@ -129,8 +129,8 @@ public TeradataMetadataHandler(
     @VisibleForTesting
     protected TeradataMetadataHandler(
         DatabaseConnectionConfig databaseConnectionConfig,
-        AWSSecretsManager secretsManager,
-        AmazonAthena athena,
+        SecretsManagerClient secretsManager,
+        AthenaClient athena,
JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { diff --git a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxMetadataHandler.java b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxMetadataHandler.java index c043e4c9e4..a54c73fe56 100644 --- a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxMetadataHandler.java +++ b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxMetadataHandler.java @@ -24,9 +24,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public TeradataMuxMetadataHandler(java.util.Map configOptions) } @VisibleForTesting - protected TeradataMuxMetadataHandler(AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + protected TeradataMuxMetadataHandler(SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, Map metadataHandlerMap, DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { super(secretsManager, athena, jdbcConnectionFactory, metadataHandlerMap, databaseConnectionConfig, configOptions); diff --git a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxRecordHandler.java b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxRecordHandler.java index 8ec63445a3..5667ddae62 100644 --- a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxRecordHandler.java +++ b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxRecordHandler.java @@ -25,10 +25,10 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public TeradataMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - TeradataMuxRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, + TeradataMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git 
a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandler.java b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandler.java index 4a9f820581..52382322a6 100644 --- a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandler.java +++ b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandler.java @@ -29,15 +29,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -60,13 +57,13 @@ public TeradataRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, public TeradataRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new TeradataQueryStringBuilder(TERADATA_QUOTE_CHARACTER, new TeradataFederationExpressionParser(TERADATA_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - TeradataRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final AWSSecretsManager secretsManager, - final AmazonAthena athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + TeradataRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, + final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMetadataHandlerTest.java b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMetadataHandlerTest.java index d7ad931f00..137b2db9d2 100644 --- a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMetadataHandlerTest.java +++ b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMetadataHandlerTest.java @@ -29,16 +29,16 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import 
com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.*; import java.util.*; @@ -58,8 +58,8 @@ public class TeradataMetadataHandlerTest private JdbcConnectionFactory jdbcConnectionFactory; private Connection connection; private FederatedIdentity federatedIdentity; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private BlockAllocator blockAllocator; @Before @@ -67,9 +67,9 @@ public void setup() throws Exception { this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); - Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); + Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); this.teradataMetadataHandler = new TeradataMetadataHandler(databaseConnectionConfig, this.secretsManager, this.athena, this.jdbcConnectionFactory, com.google.common.collect.ImmutableMap.of("partitioncount", "1000")); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); this.blockAllocator = Mockito.mock(BlockAllocator.class); diff --git a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcMetadataHandlerTest.java b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcMetadataHandlerTest.java index a0373109bc..5f2e376f82 100644 --- a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcMetadataHandlerTest.java +++ b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcMetadataHandlerTest.java @@ -28,11 +28,11 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import 
com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.Map; @@ -45,8 +45,8 @@ public class TeradataMuxJdbcMetadataHandlerTest { private TeradataMetadataHandler teradataMetadataHandler; private JdbcMetadataHandler jdbcMetadataHandler; private BlockAllocator allocator; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -56,8 +56,8 @@ public void setup() this.allocator = new BlockAllocatorImpl(); this.teradataMetadataHandler = Mockito.mock(TeradataMetadataHandler.class); this.metadataHandlerMap = Collections.singletonMap("fakedatabase", this.teradataMetadataHandler); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "fakedatabase", diff --git a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcRecordHandlerTest.java b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcRecordHandlerTest.java index 600b450c3e..0c768ba3db 100644 --- a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcRecordHandlerTest.java +++ b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcRecordHandlerTest.java @@ -28,13 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.SQLException; @@ -46,9 +46,9 @@ public class TeradataMuxJdbcRecordHandlerTest private Map recordHandlerMap; private TeradataRecordHandler teradataRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; private QueryStatusChecker queryStatusChecker; private JdbcConnectionFactory jdbcConnectionFactory; @@ -57,9 +57,9 @@ public void setup() { this.teradataRecordHandler = Mockito.mock(TeradataRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("teradata", 
this.teradataRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", "teradata", diff --git a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandlerTest.java b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandlerTest.java index 9118bc3492..4a306592df 100644 --- a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandlerTest.java +++ b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandlerTest.java @@ -32,9 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -43,6 +40,9 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; import java.sql.PreparedStatement; @@ -58,17 +58,17 @@ public class TeradataRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; + private S3Client amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; @Before public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.amazonS3 = Mockito.mock(S3Client.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); diff --git a/athena-timestream/Dockerfile b/athena-timestream/Dockerfile new file mode 100644 index 0000000000..8a0be2c8f9 --- /dev/null +++ b/athena-timestream/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-timestream-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-timestream-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ 
"com.amazonaws.athena.connectors.timestream.TimestreamCompositeHandler" ] \ No newline at end of file diff --git a/athena-timestream/athena-timestream.yaml b/athena-timestream/athena-timestream.yaml index e0e3de7840..1850ffbecc 100644 --- a/athena-timestream/athena-timestream.yaml +++ b/athena-timestream/athena-timestream.yaml @@ -52,10 +52,9 @@ Resources: spill_bucket: !Ref SpillBucket spill_prefix: !Ref SpillPrefix FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.timestream.TimestreamCompositeHandler" - CodeUri: "./target/athena-timestream-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-timestream:2022.47.1' Description: "Enables Amazon Athena to communicate with Amazon Timestream, making your time series data accessible from Athena." - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-timestream/pom.xml b/athena-timestream/pom.xml index a58b2c13c0..f00279ec9d 100644 --- a/athena-timestream/pom.xml +++ b/athena-timestream/pom.xml @@ -47,14 +47,14 @@ ${slf4j-log4j.version} - com.amazonaws - aws-java-sdk-timestreamwrite - ${aws-sdk.version} + software.amazon.awssdk + timestreamwrite + ${aws-sdk-v2.version} - com.amazonaws - aws-java-sdk-timestreamquery - ${aws-sdk.version} + software.amazon.awssdk + timestreamquery + ${aws-sdk-v2.version} org.slf4j @@ -85,6 +85,12 @@ ${log4j2Version} runtime + + org.mockito + mockito-inline + ${mockito.version} + test + diff --git a/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamClientBuilder.java b/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamClientBuilder.java index 5f0a228f73..473ea85932 100644 --- a/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamClientBuilder.java +++ b/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamClientBuilder.java @@ -19,38 +19,42 @@ */ package com.amazonaws.athena.connectors.timestream; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.services.timestreamquery.AmazonTimestreamQuery; -import com.amazonaws.services.timestreamquery.AmazonTimestreamQueryClientBuilder; -import com.amazonaws.services.timestreamwrite.AmazonTimestreamWrite; -import com.amazonaws.services.timestreamwrite.AmazonTimestreamWriteClientBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.client.config.SdkAdvancedClientOption; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; +import software.amazon.awssdk.services.timestreamquery.TimestreamQueryClient; +import software.amazon.awssdk.services.timestreamwrite.TimestreamWriteClient; public class TimestreamClientBuilder { private static final Logger logger = LoggerFactory.getLogger(TimestreamClientBuilder.class); - + static Region defaultRegion = DefaultAwsRegionProviderChain.builder().build().getRegion(); private TimestreamClientBuilder() { // prevent instantiation with private constructor } - public static AmazonTimestreamQuery buildQueryClient(String sourceType) + public static TimestreamQueryClient 
buildQueryClient(String sourceType) { - return AmazonTimestreamQueryClientBuilder.standard().withClientConfiguration(buildClientConfiguration(sourceType)).build(); + return TimestreamQueryClient.builder().region(defaultRegion).credentialsProvider(DefaultCredentialsProvider.create()) + .overrideConfiguration(buildClientConfiguration(sourceType)).build(); } - public static AmazonTimestreamWrite buildWriteClient(String sourceType) + public static TimestreamWriteClient buildWriteClient(String sourceType) { - return AmazonTimestreamWriteClientBuilder.standard().withClientConfiguration(buildClientConfiguration(sourceType)).build(); + return TimestreamWriteClient.builder().region(defaultRegion).credentialsProvider(DefaultCredentialsProvider.create()) + .overrideConfiguration(buildClientConfiguration(sourceType)).build(); } - static ClientConfiguration buildClientConfiguration(String sourceType) + static ClientOverrideConfiguration buildClientConfiguration(String sourceType) { String userAgent = "aws-athena-" + sourceType + "-connector"; - ClientConfiguration clientConfiguration = new ClientConfiguration().withUserAgentPrefix(userAgent); - logger.info("Created client configuration with user agent {} for Timestream SDK", clientConfiguration.getUserAgentPrefix()); + ClientOverrideConfiguration clientConfiguration = ClientOverrideConfiguration.builder().putAdvancedOption(SdkAdvancedClientOption.USER_AGENT_PREFIX, userAgent).build(); + logger.info("Created client configuration for Timestream SDK; user agent prefix present: {}", clientConfiguration.advancedOption(SdkAdvancedClientOption.USER_AGENT_PREFIX).isPresent()); return clientConfiguration; } } diff --git a/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamMetadataHandler.java b/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamMetadataHandler.java index 990673f527..84cdb9fb76 100644 --- a/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamMetadataHandler.java +++ b/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamMetadataHandler.java @@ -42,26 +42,26 @@ import com.amazonaws.athena.connector.util.PaginatedRequestIterator; import com.amazonaws.athena.connectors.timestream.qpt.TimestreamQueryPassthrough; import com.amazonaws.athena.connectors.timestream.query.QueryFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Table; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.timestreamquery.AmazonTimestreamQuery; -import com.amazonaws.services.timestreamquery.model.ColumnInfo; -import com.amazonaws.services.timestreamquery.model.Datum; -import com.amazonaws.services.timestreamquery.model.QueryRequest; -import com.amazonaws.services.timestreamquery.model.QueryResult; -import com.amazonaws.services.timestreamquery.model.Row; -import com.amazonaws.services.timestreamwrite.AmazonTimestreamWrite; -import com.amazonaws.services.timestreamwrite.model.ListDatabasesRequest; -import com.amazonaws.services.timestreamwrite.model.ListDatabasesResult; -import com.amazonaws.services.timestreamwrite.model.ListTablesResult; import com.google.common.collect.ImmutableMap; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import 
software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.timestreamquery.TimestreamQueryClient; +import software.amazon.awssdk.services.timestreamquery.model.ColumnInfo; +import software.amazon.awssdk.services.timestreamquery.model.Datum; +import software.amazon.awssdk.services.timestreamquery.model.QueryRequest; +import software.amazon.awssdk.services.timestreamquery.model.QueryResponse; +import software.amazon.awssdk.services.timestreamquery.model.Row; +import software.amazon.awssdk.services.timestreamwrite.TimestreamWriteClient; +import software.amazon.awssdk.services.timestreamwrite.model.Database; +import software.amazon.awssdk.services.timestreamwrite.model.ListDatabasesRequest; +import software.amazon.awssdk.services.timestreamwrite.model.ListDatabasesResponse; import java.util.Collections; import java.util.List; @@ -82,16 +82,16 @@ public class TimestreamMetadataHandler //is indeed enabled for use by this connector. private static final String METADATA_FLAG = "timestream-metadata-flag"; //Used to filter out Glue tables which lack a timestream metadata flag. - private static final TableFilter TABLE_FILTER = (Table table) -> table.getParameters().containsKey(METADATA_FLAG); + private static final TableFilter TABLE_FILTER = (Table table) -> table.parameters().containsKey(METADATA_FLAG); private static final long MAX_RESULTS = 100_000; //Used to generate TimeStream queries using templates query patterns. private final QueryFactory queryFactory = new QueryFactory(); - private final AWSGlue glue; - private final AmazonTimestreamQuery tsQuery; - private final AmazonTimestreamWrite tsMeta; + private final GlueClient glue; + private final TimestreamQueryClient tsQuery; + private final TimestreamWriteClient tsMeta; private final TimestreamQueryPassthrough queryPassthrough; @@ -106,12 +106,12 @@ public TimestreamMetadataHandler(java.util.Map configOptions) @VisibleForTesting protected TimestreamMetadataHandler( - AmazonTimestreamQuery tsQuery, - AmazonTimestreamWrite tsMeta, - AWSGlue glue, + TimestreamQueryClient tsQuery, + TimestreamWriteClient tsMeta, + GlueClient glue, EncryptionKeyFactory keyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, String spillBucket, String spillPrefix, java.util.Map configOptions) @@ -136,9 +136,9 @@ public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAlloca public ListSchemasResponse doListSchemaNames(BlockAllocator blockAllocator, ListSchemasRequest request) throws Exception { - List schemas = PaginatedRequestIterator.stream(this::doListSchemaNamesOnePage, ListDatabasesResult::getNextToken) - .flatMap(result -> result.getDatabases().stream()) - .map(db -> db.getDatabaseName()) + List schemas = PaginatedRequestIterator.stream(this::doListSchemaNamesOnePage, ListDatabasesResponse::nextToken) + .flatMap(result -> result.databases().stream()) + .map(Database::databaseName) .collect(Collectors.toList()); return new ListSchemasResponse( @@ -146,9 +146,9 @@ public ListSchemasResponse doListSchemaNames(BlockAllocator blockAllocator, List schemas); } - private ListDatabasesResult doListSchemaNamesOnePage(String nextToken) + private ListDatabasesResponse doListSchemaNamesOnePage(String nextToken) { - return tsMeta.listDatabases(new 
ListDatabasesRequest().withNextToken(nextToken)); + return tsMeta.listDatabases(ListDatabasesRequest.builder().nextToken(nextToken).build()); } @Override @@ -159,7 +159,7 @@ public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTables try { return doListTablesInternal(blockAllocator, request); } - catch (com.amazonaws.services.timestreamwrite.model.ResourceNotFoundException ex) { + catch (software.amazon.awssdk.services.timestreamwrite.model.ResourceNotFoundException ex) { // If it fails then we will retry after resolving the schema name by ignoring the casing String resolvedSchemaName = findSchemaNameIgnoringCase(request.getSchemaName()); request = new ListTablesRequest(request.getIdentity(), request.getQueryId(), request.getCatalogName(), resolvedSchemaName, request.getNextToken(), request.getPageSize()); @@ -191,43 +191,43 @@ private ListTablesResponse doListTablesInternal(BlockAllocator blockAllocator, L } // Otherwise don't retrieve all pages, just pass through the page token. - ListTablesResult timestreamResults = doListTablesOnePage(request.getSchemaName(), request.getNextToken()); - List tableNames = timestreamResults.getTables() + software.amazon.awssdk.services.timestreamwrite.model.ListTablesResponse timestreamResults = doListTablesOnePage(request.getSchemaName(), request.getNextToken()); + List tableNames = timestreamResults.tables() .stream() - .map(table -> new TableName(request.getSchemaName(), table.getTableName())) + .map(table -> new TableName(request.getSchemaName(), table.tableName())) .collect(Collectors.toList()); // Pass through whatever token we got from Glue to the user ListTablesResponse result = new ListTablesResponse( request.getCatalogName(), tableNames, - timestreamResults.getNextToken()); + timestreamResults.nextToken()); logger.debug("doListTables [paginated] result: {}", result); return result; } - private ListTablesResult doListTablesOnePage(String schemaName, String nextToken) + private software.amazon.awssdk.services.timestreamwrite.model.ListTablesResponse doListTablesOnePage(String schemaName, String nextToken) { // TODO: We should pass through the pageSize as withMaxResults(pageSize) - com.amazonaws.services.timestreamwrite.model.ListTablesRequest listTablesRequest = - new com.amazonaws.services.timestreamwrite.model.ListTablesRequest() - .withDatabaseName(schemaName) - .withNextToken(nextToken); + software.amazon.awssdk.services.timestreamwrite.model.ListTablesRequest listTablesRequest = software.amazon.awssdk.services.timestreamwrite.model.ListTablesRequest.builder() + .databaseName(schemaName) + .nextToken(nextToken) + .build(); return tsMeta.listTables(listTablesRequest); } private Stream getTableNamesInSchema(String schemaName) { - return PaginatedRequestIterator.stream((pageToken) -> doListTablesOnePage(schemaName, pageToken), ListTablesResult::getNextToken) - .flatMap(currResult -> currResult.getTables().stream()) - .map(table -> new TableName(schemaName, table.getTableName())); + return PaginatedRequestIterator.stream((pageToken) -> doListTablesOnePage(schemaName, pageToken), software.amazon.awssdk.services.timestreamwrite.model.ListTablesResponse::nextToken) + .flatMap(currResult -> currResult.tables().stream()) + .map(table -> new TableName(schemaName, table.tableName())); } private String findSchemaNameIgnoringCase(String schemaNameInsensitive) { - return PaginatedRequestIterator.stream(this::doListSchemaNamesOnePage, ListDatabasesResult::getNextToken) - .flatMap(result -> result.getDatabases().stream()) - .map(db -> 
db.getDatabaseName()) + return PaginatedRequestIterator.stream(this::doListSchemaNamesOnePage, ListDatabasesResponse::nextToken) + .flatMap(result -> result.databases().stream()) + .map(Database::databaseName) .filter(name -> name.equalsIgnoreCase(schemaNameInsensitive)) .findAny() .orElseThrow(() -> new RuntimeException(String.format("Could not find a case-insensitive match for schema name %s", schemaNameInsensitive))); @@ -238,9 +238,9 @@ private TableName findTableNameIgnoringCase(BlockAllocator blockAllocator, GetTa String caseInsenstiveSchemaNameMatch = findSchemaNameIgnoringCase(getTableRequest.getTableName().getSchemaName()); // based on AmazonMskMetadataHandler::findGlueRegistryNameIgnoringCasing - return PaginatedRequestIterator.stream((pageToken) -> doListTablesOnePage(caseInsenstiveSchemaNameMatch, pageToken), ListTablesResult::getNextToken) - .flatMap(result -> result.getTables().stream()) - .map(tbl -> new TableName(caseInsenstiveSchemaNameMatch, tbl.getTableName())) + return PaginatedRequestIterator.stream((pageToken) -> doListTablesOnePage(caseInsenstiveSchemaNameMatch, pageToken), software.amazon.awssdk.services.timestreamwrite.model.ListTablesResponse::nextToken) + .flatMap(result -> result.tables().stream()) + .map(tbl -> new TableName(caseInsenstiveSchemaNameMatch, tbl.tableName())) .filter(tbl -> tbl.getTableName().equalsIgnoreCase(getTableRequest.getTableName().getTableName())) .findAny() .orElseThrow(() -> new RuntimeException(String.format("Could not find a case-insensitive match for table name %s", getTableRequest.getTableName().getTableName()))); @@ -256,24 +256,24 @@ private Schema inferSchemaForTable(TableName tableName) logger.info("doGetTable: Retrieving schema for table[{}] from TimeStream using describeQuery[{}].", tableName, describeQuery); - QueryRequest queryRequest = new QueryRequest().withQueryString(describeQuery); + QueryRequest queryRequest = QueryRequest.builder().queryString(describeQuery).build(); SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); do { - QueryResult queryResult = tsQuery.query(queryRequest); - for (Row next : queryResult.getRows()) { - List datum = next.getData(); + QueryResponse queryResult = tsQuery.query(queryRequest); + for (Row next : queryResult.rows()) { + List datum = next.data(); if (datum.size() != 3) { throw new RuntimeException("Unexpected datum size " + datum.size() + " while getting schema from datum[" + datum.toString() + "]"); } - Field nextField = TimestreamSchemaUtils.makeField(datum.get(0).getScalarValue(), datum.get(1).getScalarValue()); + Field nextField = TimestreamSchemaUtils.makeField(datum.get(0).scalarValue(), datum.get(1).scalarValue()); schemaBuilder.addField(nextField); } - queryRequest = new QueryRequest().withNextToken(queryResult.getNextToken()); + queryRequest = QueryRequest.builder().nextToken(queryResult.nextToken()).build(); } - while (queryRequest.getNextToken() != null); + while (queryRequest.nextToken() != null); return schemaBuilder.build(); } @@ -300,7 +300,7 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques Schema schema = inferSchemaForTable(request.getTableName()); return new GetTableResponse(request.getCatalogName(), request.getTableName(), schema); } - catch (com.amazonaws.services.timestreamquery.model.ValidationException ex) { + catch (software.amazon.awssdk.services.timestreamquery.model.ValidationException ex) { logger.debug("Could not find table name matching {} in database {}. 
Falling back to case-insensitive lookup.", request.getTableName().getTableName(), request.getTableName().getSchemaName()); TableName resolvedTableName = findTableNameIgnoringCase(blockAllocator, request); @@ -319,13 +319,13 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge queryPassthrough.verify(request.getQueryPassthroughArguments()); String customerPassedQuery = request.getQueryPassthroughArguments().get(TimestreamQueryPassthrough.QUERY); - QueryRequest queryRequest = new QueryRequest().withQueryString(customerPassedQuery).withMaxRows(1); + QueryRequest queryRequest = QueryRequest.builder().queryString(customerPassedQuery).maxRows(1).build(); // Timestream Query does not provide a way to conduct a dry run or retrieve metadata results without execution. Therefore, we need to "seek" at least once before obtaining metadata. - QueryResult queryResult = tsQuery.query(queryRequest); - List columnInfo = queryResult.getColumnInfo(); + QueryResponse queryResult = tsQuery.query(queryRequest); + List columnInfo = queryResult.columnInfo(); SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); for (ColumnInfo column : columnInfo) { - Field nextField = TimestreamSchemaUtils.makeField(column.getName(), column.getType().getScalarType().toLowerCase()); + Field nextField = TimestreamSchemaUtils.makeField(column.name(), column.type().scalarTypeAsString().toLowerCase()); schemaBuilder.addField(nextField); } diff --git a/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandler.java b/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandler.java index 7c70089abc..f25b7d7b41 100644 --- a/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandler.java +++ b/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandler.java @@ -40,18 +40,6 @@ import com.amazonaws.athena.connectors.timestream.qpt.TimestreamQueryPassthrough; import com.amazonaws.athena.connectors.timestream.query.QueryFactory; import com.amazonaws.athena.connectors.timestream.query.SelectQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; -import com.amazonaws.services.timestreamquery.AmazonTimestreamQuery; -import com.amazonaws.services.timestreamquery.model.Datum; -import com.amazonaws.services.timestreamquery.model.QueryRequest; -import com.amazonaws.services.timestreamquery.model.QueryResult; -import com.amazonaws.services.timestreamquery.model.Row; -import com.amazonaws.services.timestreamquery.model.TimeSeriesDataPoint; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.holders.NullableBigIntHolder; @@ -62,6 +50,15 @@ import org.apache.arrow.vector.types.pojo.Field; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.timestreamquery.TimestreamQueryClient; +import software.amazon.awssdk.services.timestreamquery.model.Datum; 
+import software.amazon.awssdk.services.timestreamquery.model.QueryRequest; +import software.amazon.awssdk.services.timestreamquery.model.QueryResponse; +import software.amazon.awssdk.services.timestreamquery.model.Row; +import software.amazon.awssdk.services.timestreamquery.model.TimeSeriesDataPoint; import java.time.Instant; import java.time.ZoneId; @@ -88,21 +85,21 @@ public class TimestreamRecordHandler private static final String SOURCE_TYPE = "timestream"; private final QueryFactory queryFactory = new QueryFactory(); - private final AmazonTimestreamQuery tsQuery; + private final TimestreamQueryClient tsQuery; private final TimestreamQueryPassthrough queryPassthrough = new TimestreamQueryPassthrough(); public TimestreamRecordHandler(java.util.Map<String, String> configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), + S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), TimestreamClientBuilder.buildQueryClient(SOURCE_TYPE), configOptions); } @VisibleForTesting - protected TimestreamRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, AmazonTimestreamQuery tsQuery, java.util.Map<String, String> configOptions) + protected TimestreamRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, TimestreamQueryClient tsQuery, java.util.Map<String, String> configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.tsQuery = tsQuery; @@ -138,15 +135,15 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor long numRows = 0; do { - QueryResult queryResult = tsQuery.query(new QueryRequest().withQueryString(query).withNextToken(nextToken)); - List<Row> data = queryResult.getRows(); + QueryResponse queryResult = tsQuery.query(QueryRequest.builder().queryString(query).nextToken(nextToken).build()); + List<Row> data = queryResult.rows(); if (data != null) { numRows += data.size(); for (Row nextRow : data) { spiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, nextRow) ? 
1 : 0); } } - nextToken = queryResult.getNextToken(); + nextToken = queryResult.nextToken(); logger.info("readWithConstraint: numRows[{}]", numRows); } while (nextToken != null && !nextToken.isEmpty()); } @@ -161,7 +158,7 @@ private GeneratedRowWriter buildRowWriter(ReadRecordsRequest request) switch (Types.getMinorTypeForArrowType(nextField.getType())) { case VARCHAR: builder.withExtractor(nextField.getName(), (VarCharExtractor) (Object context, NullableVarCharHolder value) -> { - String stringValue = ((Row) context).getData().get(curFieldNum).getScalarValue(); + String stringValue = ((Row) context).data().get(curFieldNum).scalarValue(); if (stringValue != null) { value.isSet = 1; value.value = stringValue; @@ -173,7 +170,7 @@ private GeneratedRowWriter buildRowWriter(ReadRecordsRequest request) break; case FLOAT8: builder.withExtractor(nextField.getName(), (Float8Extractor) (Object context, NullableFloat8Holder value) -> { - String doubleValue = ((Row) context).getData().get(curFieldNum).getScalarValue(); + String doubleValue = ((Row) context).data().get(curFieldNum).scalarValue(); if (doubleValue != null) { value.isSet = 1; value.value = Double.valueOf(doubleValue); @@ -186,12 +183,12 @@ private GeneratedRowWriter buildRowWriter(ReadRecordsRequest request) case BIT: builder.withExtractor(nextField.getName(), (BitExtractor) (Object context, NullableBitHolder value) -> { value.isSet = 1; - value.value = Boolean.valueOf(((Row) context).getData().get(curFieldNum).getScalarValue()) == false ? 0 : 1; + value.value = Boolean.valueOf(((Row) context).data().get(curFieldNum).scalarValue()) == false ? 0 : 1; }); break; case BIGINT: builder.withExtractor(nextField.getName(), (BigIntExtractor) (Object context, NullableBigIntHolder value) -> { - String longValue = ((Row) context).getData().get(curFieldNum).getScalarValue(); + String longValue = ((Row) context).data().get(curFieldNum).scalarValue(); if (longValue != null) { value.isSet = 1; value.value = Long.valueOf(longValue); @@ -203,7 +200,7 @@ private GeneratedRowWriter buildRowWriter(ReadRecordsRequest request) break; case DATEMILLI: builder.withExtractor(nextField.getName(), (DateMilliExtractor) (Object context, NullableDateMilliHolder value) -> { - String dateMilliValue = ((Row) context).getData().get(curFieldNum).getScalarValue(); + String dateMilliValue = ((Row) context).data().get(curFieldNum).scalarValue(); if (dateMilliValue != null) { value.isSet = 1; value.value = Instant.from(TIMESTAMP_FORMATTER.parse(dateMilliValue)).toEpochMilli(); @@ -233,30 +230,30 @@ private void buildTimeSeriesExtractor(GeneratedRowWriter.RowWriterBuilder builde (FieldVector vector, Extractor extractor, ConstraintProjector constraint) -> (Object context, int rowNum) -> { Row row = (Row) context; - Datum datum = row.getData().get(curFieldNum); + Datum datum = row.data().get(curFieldNum); Field timeField = field.getChildren().get(0).getChildren().get(0); Field valueField = field.getChildren().get(0).getChildren().get(1); - if (datum.getTimeSeriesValue() != null) { List<Map<String, Object>> values = new ArrayList<>(); - for (TimeSeriesDataPoint nextDatum : datum.getTimeSeriesValue()) { Map<String, Object> eventMap = new HashMap<>(); + if (datum.timeSeriesValue() != null) { List<Map<String, Object>> values = new ArrayList<>(); + for (TimeSeriesDataPoint nextDatum : datum.timeSeriesValue()) { Map<String, Object> eventMap = new HashMap<>(); - eventMap.put(timeField.getName(), Instant.from(TIMESTAMP_FORMATTER.parse(nextDatum.getTime())).toEpochMilli()); + eventMap.put(timeField.getName(), Instant.from(TIMESTAMP_FORMATTER.parse(nextDatum.time())).toEpochMilli()); switch 
(Types.getMinorTypeForArrowType(valueField.getType())) { case FLOAT8: - eventMap.put(valueField.getName(), Double.valueOf(nextDatum.getValue().getScalarValue())); + eventMap.put(valueField.getName(), Double.valueOf(nextDatum.value().scalarValue())); break; case BIGINT: - eventMap.put(valueField.getName(), Long.valueOf(nextDatum.getValue().getScalarValue())); + eventMap.put(valueField.getName(), Long.valueOf(nextDatum.value().scalarValue())); break; case INT: - eventMap.put(valueField.getName(), Integer.valueOf(nextDatum.getValue().getScalarValue())); + eventMap.put(valueField.getName(), Integer.valueOf(nextDatum.value().scalarValue())); break; case BIT: eventMap.put(valueField.getName(), - Boolean.valueOf(((Row) context).getData().get(curFieldNum).getScalarValue()) == false ? 0 : 1); + Boolean.valueOf(((Row) context).data().get(curFieldNum).scalarValue()) == false ? 0 : 1); break; } values.add(eventMap); diff --git a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TestUtils.java b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TestUtils.java index b09fccbbbf..5656dccb58 100644 --- a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TestUtils.java +++ b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TestUtils.java @@ -19,32 +19,21 @@ */ package com.amazonaws.athena.connectors.timestream; -import com.amazonaws.athena.connector.lambda.data.BlockUtils; -import com.amazonaws.athena.connector.lambda.data.FieldResolver; -import com.amazonaws.athena.connector.lambda.data.writers.GeneratedRowWriter; -import com.amazonaws.athena.connector.lambda.data.writers.extractors.Extractor; -import com.amazonaws.athena.connector.lambda.domain.predicate.ConstraintProjector; -import com.amazonaws.services.timestreamquery.model.Datum; -import com.amazonaws.services.timestreamquery.model.QueryResult; -import com.amazonaws.services.timestreamquery.model.Row; -import com.amazonaws.services.timestreamquery.model.TimeSeriesDataPoint; -import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.mockito.stubbing.Answer; +import software.amazon.awssdk.services.timestreamquery.model.Datum; +import software.amazon.awssdk.services.timestreamquery.model.QueryResponse; +import software.amazon.awssdk.services.timestreamquery.model.Row; +import software.amazon.awssdk.services.timestreamquery.model.TimeSeriesDataPoint; -import java.text.SimpleDateFormat; import java.time.LocalDateTime; import java.util.ArrayList; -import java.util.Date; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Random; import java.util.concurrent.atomic.AtomicLong; -import static org.apache.arrow.vector.types.Types.MinorType.FLOAT8; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -59,17 +48,17 @@ private TestUtils() {} private static final String[] AZS = {"us-east-1a", "us-east-1b", "us-east-1c", "us-east-1d"}; - public static QueryResult makeMockQueryResult(Schema schemaForRead, int numRows) + public static QueryResponse makeMockQueryResult(Schema schemaForRead, int numRows) { return makeMockQueryResult(schemaForRead, numRows, 100, true); } - public static QueryResult makeMockQueryResult(Schema schemaForRead, int numRows, int maxDataGenerationRow, boolean isRandomAZ) + public static 
QueryResponse makeMockQueryResult(Schema schemaForRead, int numRows, int maxDataGenerationRow, boolean isRandomAZ) { - QueryResult mockResult = mock(QueryResult.class); + QueryResponse mockResult = mock(QueryResponse.class); final AtomicLong nextToken = new AtomicLong(0); - when(mockResult.getRows()).thenAnswer((Answer>) invocationOnMock -> { + when(mockResult.rows()).thenAnswer((Answer>) invocationOnMock -> { List rows = new ArrayList<>(); for (int i = 0; i < maxDataGenerationRow; i++) { nextToken.incrementAndGet(); @@ -78,15 +67,14 @@ public static QueryResult makeMockQueryResult(Schema schemaForRead, int numRows, columnData.add(makeValue(nextField, i, isRandomAZ)); } - Row row = new Row(); - row.setData(columnData); + Row row = Row.builder().data(columnData).build(); rows.add(row); } return rows; } ); - when(mockResult.getNextToken()).thenAnswer((Answer) invocationOnMock -> { + when(mockResult.nextToken()).thenAnswer((Answer) invocationOnMock -> { if (nextToken.get() < numRows) { return String.valueOf(nextToken.get()); } @@ -99,30 +87,30 @@ public static QueryResult makeMockQueryResult(Schema schemaForRead, int numRows, public static Datum makeValue(Field field, int num, boolean isRandomAZ) { - Datum datum = new Datum(); + Datum.Builder datum = Datum.builder(); switch (Types.getMinorTypeForArrowType(field.getType())) { case VARCHAR: if (field.getName().equals("az")) { - datum.setScalarValue(isRandomAZ ? AZS[RAND.nextInt(4)] : "us-east-1a"); + datum.scalarValue(isRandomAZ ? AZS[RAND.nextInt(4)] : "us-east-1a"); } else { - datum.setScalarValue(field.getName() + "_" + RAND.nextInt(10_000_000)); + datum.scalarValue(field.getName() + "_" + RAND.nextInt(10_000_000)); } break; case FLOAT8: - datum.setScalarValue(String.valueOf(RAND.nextFloat())); + datum.scalarValue(String.valueOf(RAND.nextFloat())); break; case INT: - datum.setScalarValue(String.valueOf(RAND.nextInt())); + datum.scalarValue(String.valueOf(RAND.nextInt())); break; case BIT: - datum.setScalarValue(String.valueOf(RAND.nextBoolean())); + datum.scalarValue(String.valueOf(RAND.nextBoolean())); break; case BIGINT: - datum.setScalarValue(String.valueOf(RAND.nextLong())); + datum.scalarValue(String.valueOf(RAND.nextLong())); break; case DATEMILLI: - datum.setScalarValue(startDate.plusDays(num).toString().replace('T', ' ')); + datum.scalarValue(startDate.plusDays(num).toString().replace('T', ' ')); break; case LIST: buildTimeSeries(field, datum, num); @@ -131,17 +119,17 @@ public static Datum makeValue(Field field, int num, boolean isRandomAZ) throw new RuntimeException("Unsupported field type[" + field.getType() + "] for field[" + field.getName() + "]"); } - return datum; + return datum.build(); } - private static void buildTimeSeries(Field field, Datum datum, int num) + private static void buildTimeSeries(Field field, Datum.Builder datum, int num) { List dataPoints = new ArrayList<>(); for (int i = 0; i < 10; i++) { - TimeSeriesDataPoint dataPoint = new TimeSeriesDataPoint(); - Datum dataPointValue = new Datum(); + TimeSeriesDataPoint.Builder dataPoint = TimeSeriesDataPoint.builder(); + Datum.Builder dataPointValue = Datum.builder(); - dataPoint.setTime(startDate.plusDays(num).toString().replace('T', ' ')); + dataPoint.time(startDate.plusDays(num).toString().replace('T', ' ')); /** * Presently we only support TimeSeries as LIST> @@ -152,22 +140,22 @@ private static void buildTimeSeries(Field field, Datum datum, int num) switch (Types.getMinorTypeForArrowType(baseSeriesType.getType())) { case FLOAT8: - 
dataPointValue.setScalarValue(String.valueOf(RAND.nextFloat())); + dataPointValue.scalarValue(String.valueOf(RAND.nextFloat())); break; case BIT: - dataPointValue.setScalarValue(String.valueOf(RAND.nextBoolean())); + dataPointValue.scalarValue(String.valueOf(RAND.nextBoolean())); break; case INT: - dataPointValue.setScalarValue(String.valueOf(RAND.nextInt())); + dataPointValue.scalarValue(String.valueOf(RAND.nextInt())); break; case BIGINT: - dataPointValue.setScalarValue(String.valueOf(RAND.nextLong())); + dataPointValue.scalarValue(String.valueOf(RAND.nextLong())); break; } - dataPoint.setValue(dataPointValue); - dataPoints.add(dataPoint); + dataPoint.value(dataPointValue.build()); + dataPoints.add(dataPoint.build()); } - datum.setTimeSeriesValue(dataPoints); + datum.timeSeriesValue(dataPoints); } } diff --git a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamClientBuilderTest.java b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamClientBuilderTest.java index c3c4d4a486..de4f83b0bb 100644 --- a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamClientBuilderTest.java +++ b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamClientBuilderTest.java @@ -19,8 +19,9 @@ */ package com.amazonaws.athena.connectors.timestream; -import com.amazonaws.ClientConfiguration; import org.junit.Test; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.client.config.SdkAdvancedClientOption; import static org.junit.Assert.assertEquals; @@ -29,7 +30,7 @@ public class TimestreamClientBuilderTest { @Test public void testUserAgentField() { - ClientConfiguration clientConfiguration = TimestreamClientBuilder.buildClientConfiguration("timestream"); - assertEquals("aws-athena-timestream-connector", clientConfiguration.getUserAgentPrefix()); + ClientOverrideConfiguration clientConfiguration = TimestreamClientBuilder.buildClientConfiguration("timestream"); + assertEquals("aws-athena-timestream-connector", clientConfiguration.advancedOption(SdkAdvancedClientOption.USER_AGENT_PREFIX).get()); } } diff --git a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamMetadataHandlerTest.java b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamMetadataHandlerTest.java index 45478f3b84..9262c45f68 100644 --- a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamMetadataHandlerTest.java +++ b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamMetadataHandlerTest.java @@ -40,23 +40,6 @@ import com.amazonaws.athena.connector.lambda.metadata.MetadataResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.glue.AWSGlue; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.timestreamquery.AmazonTimestreamQuery; -import com.amazonaws.services.timestreamquery.model.Datum; -import com.amazonaws.services.timestreamquery.model.QueryRequest; -import com.amazonaws.services.timestreamquery.model.QueryResult; -import 
com.amazonaws.services.timestreamquery.model.Row; -import com.amazonaws.services.timestreamwrite.AmazonTimestreamWrite; -import com.amazonaws.services.timestreamwrite.model.Database; -import com.amazonaws.services.timestreamwrite.model.ListDatabasesRequest; -import com.amazonaws.services.timestreamwrite.model.ListDatabasesResult; -import com.amazonaws.services.timestreamwrite.model.ListTablesResult; -import com.amazonaws.services.timestreamwrite.model.Table; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -69,6 +52,21 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.glue.GlueClient; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.timestreamquery.TimestreamQueryClient; +import software.amazon.awssdk.services.timestreamquery.model.Datum; +import software.amazon.awssdk.services.timestreamquery.model.QueryRequest; +import software.amazon.awssdk.services.timestreamquery.model.QueryResponse; +import software.amazon.awssdk.services.timestreamquery.model.Row; +import software.amazon.awssdk.services.timestreamwrite.TimestreamWriteClient; +import software.amazon.awssdk.services.timestreamwrite.model.Database; +import software.amazon.awssdk.services.timestreamwrite.model.ListDatabasesRequest; +import software.amazon.awssdk.services.timestreamwrite.model.ListDatabasesResponse; +import software.amazon.awssdk.services.timestreamwrite.model.Table; import java.util.ArrayList; import java.util.Collections; @@ -78,10 +76,9 @@ import static com.amazonaws.athena.connector.lambda.domain.predicate.Constraints.DEFAULT_NO_LIMIT; import static com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler.VIEW_METADATA_FIELD; import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -97,15 +94,15 @@ public class TimestreamMetadataHandlerTest private BlockAllocator allocator; @Mock - protected AWSSecretsManager mockSecretsManager; + protected SecretsManagerClient mockSecretsManager; @Mock - protected AmazonAthena mockAthena; + protected AthenaClient mockAthena; @Mock - protected AmazonTimestreamQuery mockTsQuery; + protected TimestreamQueryClient mockTsQuery; @Mock - protected AmazonTimestreamWrite mockTsMeta; + protected TimestreamWriteClient mockTsMeta; @Mock - protected AWSGlue mockGlue; + protected GlueClient mockGlue; @Before public void setUp() @@ -142,26 +139,26 @@ public void doListSchemaNames() String newNextToken = null; List databases = new ArrayList<>(); - if (request.getNextToken() == null) { + if (request.nextToken() == null) { for (int i = 0; i < 10; i++) { - databases.add(new Database().withDatabaseName("database_" + i)); + databases.add(Database.builder().databaseName("database_" + i).build()); } newNextToken = "1"; } 
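// Editor's note: a minimal, self-contained sketch (not part of the patch) of the AWS SDK v2
// style this hunk migrates to: v1's mutable withX(...) setters become immutable objects
// assembled through builders. The database name below is illustrative only.
import software.amazon.awssdk.services.timestreamwrite.model.Database;
import software.amazon.awssdk.services.timestreamwrite.model.ListDatabasesResponse;

class SdkV2BuilderSketch
{
    public static void main(String[] args)
    {
        // v1: new Database().withDatabaseName("database_0")
        // v2: Database.builder().databaseName("database_0").build()
        Database database = Database.builder().databaseName("database_0").build();
        ListDatabasesResponse response = ListDatabasesResponse.builder()
                .databases(database)   // the generated builder also accepts a Collection
                .nextToken(null)       // a null token marks the final page
                .build();
        System.out.println(response.databases().get(0).databaseName());
    }
}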
- else if (request.getNextToken().equals("1")) { + else if (request.nextToken().equals("1")) { for (int i = 10; i < 100; i++) { - databases.add(new Database().withDatabaseName("database_" + i)); + databases.add(Database.builder().databaseName("database_" + i).build()); } newNextToken = "2"; } - else if (request.getNextToken().equals("2")) { + else if (request.nextToken().equals("2")) { for (int i = 100; i < 1000; i++) { - databases.add(new Database().withDatabaseName("database_" + i)); + databases.add(Database.builder().databaseName("database_" + i).build()); } newNextToken = null; } - return new ListDatabasesResult().withDatabases(databases).withNextToken(newNextToken); + return ListDatabasesResponse.builder().databases(databases).nextToken(newNextToken).build(); }); ListSchemasRequest req = new ListSchemasRequest(identity, "queryId", "default"); @@ -184,33 +181,33 @@ public void doListTables() { logger.info("doListTables - enter"); - when(mockTsMeta.listTables(nullable(com.amazonaws.services.timestreamwrite.model.ListTablesRequest.class))) + when(mockTsMeta.listTables(nullable(software.amazon.awssdk.services.timestreamwrite.model.ListTablesRequest.class))) .thenAnswer((InvocationOnMock invocation) -> { - com.amazonaws.services.timestreamwrite.model.ListTablesRequest request = - invocation.getArgument(0, com.amazonaws.services.timestreamwrite.model.ListTablesRequest.class); + software.amazon.awssdk.services.timestreamwrite.model.ListTablesRequest request = + invocation.getArgument(0, software.amazon.awssdk.services.timestreamwrite.model.ListTablesRequest.class); String newNextToken = null; List
tables = new ArrayList<>(); - if (request.getNextToken() == null) { + if (request.nextToken() == null) { for (int i = 0; i < 10; i++) { - tables.add(new Table().withDatabaseName(request.getDatabaseName()).withTableName("table_" + i)); + tables.add(Table.builder().databaseName(request.databaseName()).tableName("table_" + i).build()); } newNextToken = "1"; } - else if (request.getNextToken().equals("1")) { + else if (request.nextToken().equals("1")) { for (int i = 10; i < 100; i++) { - tables.add(new Table().withDatabaseName(request.getDatabaseName()).withTableName("table_" + i)); + tables.add(Table.builder().databaseName(request.databaseName()).tableName("table_" + i).build()); } newNextToken = "2"; } - else if (request.getNextToken().equals("2")) { + else if (request.nextToken().equals("2")) { for (int i = 100; i < 1000; i++) { - tables.add(new Table().withDatabaseName(request.getDatabaseName()).withTableName("table_" + i)); + tables.add(Table.builder().databaseName(request.databaseName()).tableName("table_" + i).build()); } newNextToken = null; } - return new ListTablesResult().withTables(tables).withNextToken(newNextToken); + return software.amazon.awssdk.services.timestreamwrite.model.ListTablesResponse.builder().tables(tables).nextToken(newNextToken).build(); }); ListTablesRequest req = new ListTablesRequest(identity, "queryId", "default", defaultSchema, @@ -220,7 +217,7 @@ else if (request.getNextToken().equals("2")) { assertEquals(1000, res.getTables().size()); verify(mockTsMeta, times(3)) - .listTables(nullable(com.amazonaws.services.timestreamwrite.model.ListTablesRequest.class)); + .listTables(nullable(software.amazon.awssdk.services.timestreamwrite.model.ListTablesRequest.class)); Iterator schemaItr = res.getTables().iterator(); for (int i = 0; i < 1000; i++) { @@ -238,29 +235,29 @@ public void doGetTable() { logger.info("doGetTable - enter"); - when(mockGlue.getTable(nullable(com.amazonaws.services.glue.model.GetTableRequest.class))) - .thenReturn(mock(GetTableResult.class)); + when(mockGlue.getTable(nullable(software.amazon.awssdk.services.glue.model.GetTableRequest.class))) + .thenReturn(software.amazon.awssdk.services.glue.model.GetTableResponse.builder().build()); when(mockTsQuery.query(nullable(QueryRequest.class))).thenAnswer((InvocationOnMock invocation) -> { QueryRequest request = invocation.getArgument(0, QueryRequest.class); - assertEquals("DESCRIBE \"default\".\"table1\"", request.getQueryString()); + assertEquals("DESCRIBE \"default\".\"table1\"", request.queryString()); List rows = new ArrayList<>(); //TODO: Add types here - rows.add(new Row().withData(new Datum().withScalarValue("availability_zone"), - new Datum().withScalarValue("varchar"), - new Datum().withScalarValue("dimension"))); - rows.add(new Row().withData(new Datum().withScalarValue("measure_value"), - new Datum().withScalarValue("double"), - new Datum().withScalarValue("measure_value"))); - rows.add(new Row().withData(new Datum().withScalarValue("measure_name"), - new Datum().withScalarValue("varchar"), - new Datum().withScalarValue("measure_name"))); - rows.add(new Row().withData(new Datum().withScalarValue("time"), - new Datum().withScalarValue("timestamp"), - new Datum().withScalarValue("timestamp"))); - - return new QueryResult().withRows(rows); + rows.add(Row.builder().data(Datum.builder().scalarValue("availability_zone").build(), + Datum.builder().scalarValue("varchar").build(), + Datum.builder().scalarValue("dimension").build()).build()); + 
rows.add(Row.builder().data(Datum.builder().scalarValue("measure_value").build(), + Datum.builder().scalarValue("double").build(), + Datum.builder().scalarValue("measure_value").build()).build()); + rows.add(Row.builder().data(Datum.builder().scalarValue("measure_name").build(), + Datum.builder().scalarValue("varchar").build(), + Datum.builder().scalarValue("measure_name").build()).build()); + rows.add(Row.builder().data(Datum.builder().scalarValue("time").build(), + Datum.builder().scalarValue("timestamp").build(), + Datum.builder().scalarValue("timestamp").build()).build()); + + return QueryResponse.builder().rows(rows).build(); }); GetTableRequest req = new GetTableRequest(identity, @@ -294,23 +291,25 @@ public void doGetTableGlue() { logger.info("doGetTable - enter"); - when(mockGlue.getTable(nullable(com.amazonaws.services.glue.model.GetTableRequest.class))).thenAnswer((InvocationOnMock invocation) -> { - com.amazonaws.services.glue.model.GetTableRequest request = invocation.getArgument(0, - com.amazonaws.services.glue.model.GetTableRequest.class); + when(mockGlue.getTable(nullable(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenAnswer((InvocationOnMock invocation) -> { + software.amazon.awssdk.services.glue.model.GetTableRequest request = invocation.getArgument(0, + software.amazon.awssdk.services.glue.model.GetTableRequest.class); List columns = new ArrayList<>(); - columns.add(new Column().withName("col1").withType("varchar")); - columns.add(new Column().withName("col2").withType("double")); - com.amazonaws.services.glue.model.Table table = new com.amazonaws.services.glue.model.Table(); - table.setName(request.getName()); - table.setDatabaseName(request.getDatabaseName()); - StorageDescriptor storageDescriptor = new StorageDescriptor(); - storageDescriptor.setColumns(columns); - table.setStorageDescriptor(storageDescriptor); - table.setViewOriginalText("view text"); - table.setParameters(Collections.singletonMap("timestream-metadata-flag", "timestream-metadata-flag")); - - return new GetTableResult().withTable(table); + columns.add(Column.builder().name("col1").type("varchar").build()); + columns.add(Column.builder().name("col2").type("double").build()); + StorageDescriptor storageDescriptor = StorageDescriptor.builder() + .columns(columns) + .build(); + software.amazon.awssdk.services.glue.model.Table table = software.amazon.awssdk.services.glue.model.Table.builder() + .name(request.name()) + .databaseName(request.databaseName()) + .storageDescriptor(storageDescriptor) + .viewOriginalText("view text") + .parameters(Collections.singletonMap("timestream-metadata-flag", "timestream-metadata-flag")) + .build(); + + return software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); }); GetTableRequest req = new GetTableRequest(identity, @@ -340,25 +339,25 @@ public void doGetTimeSeriesTableGlue() { logger.info("doGetTimeSeriesTableGlue - enter"); - when(mockGlue.getTable(nullable(com.amazonaws.services.glue.model.GetTableRequest.class))).thenAnswer((InvocationOnMock invocation) -> { - com.amazonaws.services.glue.model.GetTableRequest request = invocation.getArgument(0, - com.amazonaws.services.glue.model.GetTableRequest.class); + when(mockGlue.getTable(nullable(software.amazon.awssdk.services.glue.model.GetTableRequest.class))).thenAnswer((InvocationOnMock invocation) -> { + software.amazon.awssdk.services.glue.model.GetTableRequest request = invocation.getArgument(0, + 
software.amazon.awssdk.services.glue.model.GetTableRequest.class); List columns = new ArrayList<>(); - columns.add(new Column().withName("az").withType("varchar")); - columns.add(new Column().withName("hostname").withType("varchar")); - columns.add(new Column().withName("region").withType("varchar")); - columns.add(new Column().withName("cpu_utilization").withType("ARRAY>")); - com.amazonaws.services.glue.model.Table table = new com.amazonaws.services.glue.model.Table(); - table.setName(request.getName()); - table.setDatabaseName(request.getDatabaseName()); - StorageDescriptor storageDescriptor = new StorageDescriptor(); - storageDescriptor.setColumns(columns); - table.setStorageDescriptor(storageDescriptor); - table.setViewOriginalText("SELECT az, hostname, region, cpu_utilization FROM TIMESERIES(metrics_table,'cpu_utilization')"); - table.setParameters(Collections.singletonMap("timestream-metadata-flag", "timestream-metadata-flag")); - - return new GetTableResult().withTable(table); + columns.add(Column.builder().name("az").type("varchar").build()); + columns.add(Column.builder().name("hostname").type("varchar").build()); + columns.add(Column.builder().name("region").type("varchar").build()); + columns.add(Column.builder().name("cpu_utilization").type("ARRAY>").build()); + StorageDescriptor storageDescriptor = StorageDescriptor.builder().columns(columns).build(); + software.amazon.awssdk.services.glue.model.Table table = software.amazon.awssdk.services.glue.model.Table.builder() + .name(request.name()) + .databaseName(request.databaseName()) + .storageDescriptor(storageDescriptor) + .viewOriginalText("SELECT az, hostname, region, cpu_utilization FROM TIMESERIES(metrics_table,'cpu_utilization')") + .parameters(Collections.singletonMap("timestream-metadata-flag", "timestream-metadata-flag")) + .build(); + + return software.amazon.awssdk.services.glue.model.GetTableResponse.builder().table(table).build(); }); GetTableRequest req = new GetTableRequest(identity, diff --git a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandlerTest.java b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandlerTest.java index 1682b4ef53..f3daeaff80 100644 --- a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandlerTest.java +++ b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandlerTest.java @@ -40,16 +40,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.timestreamquery.AmazonTimestreamQuery; -import com.amazonaws.services.timestreamquery.model.QueryRequest; -import com.amazonaws.services.timestreamquery.model.QueryResult; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; @@ -66,6 +56,18 @@ import org.mockito.stubbing.Answer; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.timestreamquery.TimestreamQueryClient; +import software.amazon.awssdk.services.timestreamquery.model.QueryRequest; +import software.amazon.awssdk.services.timestreamquery.model.QueryResponse; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -100,7 +102,7 @@ public class TimestreamRecordHandlerTest private TimestreamRecordHandler handler; private BlockAllocator allocator; private List<ByteHolder> mockS3Storage = new ArrayList<>(); - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpillReader spillReader; private Schema schemaForRead; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -113,13 +115,13 @@ public class TimestreamRecordHandlerTest public TestName testName = new TestName(); @Mock - private AmazonTimestreamQuery mockClient; + private TimestreamQueryClient mockClient; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; private class ByteHolder { @@ -144,31 +146,29 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); + logger.info("putObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); + logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); schemaForRead = SchemaBuilder.newBuilder() @@ -198,11 +198,11 @@ public void doReadRecordsNoSpill() int numRowsGenerated = 1_000; String expectedQuery = "SELECT measure_name, measure_value::double, az, time, hostname, region FROM \"my_schema\".\"my_table\" WHERE (\"az\" IN ('us-east-1a','us-east-1b'))"; - QueryResult mockResult =
makeMockQueryResult(schemaForRead, numRowsGenerated); + QueryResponse mockResult = makeMockQueryResult(schemaForRead, numRowsGenerated); when(mockClient.query(nullable(QueryRequest.class))) - .thenAnswer((Answer) invocationOnMock -> { + .thenAnswer((Answer) invocationOnMock -> { QueryRequest request = (QueryRequest) invocationOnMock.getArguments()[0]; - assertEquals(expectedQuery, request.getQueryString().replace("\n", "")); + assertEquals(expectedQuery, request.queryString().replace("\n", "")); return mockResult; } ); @@ -253,11 +253,11 @@ public void doReadRecordsSpill() { String expectedQuery = "SELECT measure_name, measure_value::double, az, time, hostname, region FROM \"my_schema\".\"my_table\" WHERE (\"az\" IN ('us-east-1a','us-east-1b'))"; - QueryResult mockResult = makeMockQueryResult(schemaForRead, 100_000); + QueryResponse mockResult = makeMockQueryResult(schemaForRead, 100_000); when(mockClient.query(nullable(QueryRequest.class))) - .thenAnswer((Answer) invocationOnMock -> { + .thenAnswer((Answer) invocationOnMock -> { QueryRequest request = (QueryRequest) invocationOnMock.getArguments()[0]; - assertEquals(expectedQuery, request.getQueryString().replace("\n", "")); + assertEquals(expectedQuery, request.queryString().replace("\n", "")); return mockResult; } ); @@ -327,11 +327,11 @@ public void readRecordsView() String expectedQuery = "WITH t1 AS ( select measure_name, az,sum(\"measure_value::double\") as value, count(*) as num_samples from \"my_schema\".\"my_table\" group by measure_name, az ) SELECT measure_name, az, value, num_samples FROM t1 WHERE (\"az\" IN ('us-east-1a','us-east-1b'))"; - QueryResult mockResult = makeMockQueryResult(schemaForReadView, 1_000); + QueryResponse mockResult = makeMockQueryResult(schemaForReadView, 1_000); when(mockClient.query(nullable(QueryRequest.class))) - .thenAnswer((Answer) invocationOnMock -> { + .thenAnswer((Answer) invocationOnMock -> { QueryRequest request = (QueryRequest) invocationOnMock.getArguments()[0]; - assertEquals(expectedQuery, request.getQueryString().replace("\n", "")); + assertEquals(expectedQuery, request.queryString().replace("\n", "")); return mockResult; } ); @@ -394,11 +394,11 @@ public void readRecordsTimeSeriesView() String expectedQuery = "WITH t1 AS ( select az, hostname, region, CREATE_TIME_SERIES(time, measure_value::double) as cpu_utilization from \"my_schema\".\"my_table\" WHERE measure_name = 'cpu_utilization' GROUP BY measure_name, az, hostname, region ) SELECT region, az, hostname, cpu_utilization FROM t1 WHERE (\"az\" IN ('us-east-1a','us-east-1b'))"; - QueryResult mockResult = makeMockQueryResult(schemaForReadView, 1_000); + QueryResponse mockResult = makeMockQueryResult(schemaForReadView, 1_000); when(mockClient.query(nullable(QueryRequest.class))) - .thenAnswer((Answer) invocationOnMock -> { + .thenAnswer((Answer) invocationOnMock -> { QueryRequest request = (QueryRequest) invocationOnMock.getArguments()[0]; - assertEquals("actual: " + request.getQueryString(), expectedQuery, request.getQueryString().replace("\n", "")); + assertEquals("actual: " + request.queryString(), expectedQuery, request.queryString().replace("\n", "")); return mockResult; } ); @@ -449,11 +449,11 @@ public void doReadRecordsNoSpillValidateTimeStamp() int numRows = 10; String expectedQuery = "SELECT measure_name, measure_value::double, az, time, hostname, region FROM \"my_schema\".\"my_table\" WHERE (\"az\" IN ('us-east-1a'))"; - QueryResult mockResult = makeMockQueryResult(schemaForRead, numRows, numRows, false); + QueryResponse 
mockResult = makeMockQueryResult(schemaForRead, numRows, numRows, false); when(mockClient.query(nullable(QueryRequest.class))) - .thenAnswer((Answer) invocationOnMock -> { + .thenAnswer((Answer) invocationOnMock -> { QueryRequest request = (QueryRequest) invocationOnMock.getArguments()[0]; - assertEquals(expectedQuery, request.getQueryString().replace("\n", "")); + assertEquals(expectedQuery, request.queryString().replace("\n", "")); return mockResult; } ); diff --git a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/integ/TimestreamIntegTest.java b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/integ/TimestreamIntegTest.java index 0fce4624bc..10f6575220 100644 --- a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/integ/TimestreamIntegTest.java +++ b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/integ/TimestreamIntegTest.java @@ -21,11 +21,6 @@ import com.amazonaws.athena.connector.integ.IntegrationTestBase; import com.amazonaws.athena.connectors.timestream.TimestreamClientBuilder; -import com.amazonaws.services.athena.model.Row; -import com.amazonaws.services.timestreamwrite.AmazonTimestreamWrite; -import com.amazonaws.services.timestreamwrite.model.CreateTableRequest; -import com.amazonaws.services.timestreamwrite.model.DeleteTableRequest; -import com.amazonaws.services.timestreamwrite.model.MeasureValueType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.slf4j.Logger; @@ -38,6 +33,11 @@ import software.amazon.awscdk.services.iam.PolicyDocument; import software.amazon.awscdk.services.iam.PolicyStatement; import software.amazon.awscdk.services.timestream.CfnDatabase; +import software.amazon.awssdk.services.athena.model.Row; +import software.amazon.awssdk.services.timestreamwrite.TimestreamWriteClient; +import software.amazon.awssdk.services.timestreamwrite.model.CreateTableRequest; +import software.amazon.awssdk.services.timestreamwrite.model.DeleteTableRequest; +import software.amazon.awssdk.services.timestreamwrite.model.MeasureValueType; import java.util.ArrayList; import java.util.List; @@ -58,7 +58,7 @@ public class TimestreamIntegTest extends IntegrationTestBase private final String jokePunchline; private final String lambdaFunctionName; private final long[] timeStream; - private final AmazonTimestreamWrite timestreamWriteClient; + private final TimestreamWriteClient timestreamWriteClient; public TimestreamIntegTest() { @@ -123,9 +123,9 @@ private void createTimestreamTable() logger.info("Creating the Timestream table: {}", timestreamTableName); logger.info("----------------------------------------------------"); - timestreamWriteClient.createTable(new CreateTableRequest() - .withDatabaseName(timestreamDbName) - .withTableName(timestreamTableName)); + timestreamWriteClient.createTable(CreateTableRequest.builder() + .databaseName(timestreamDbName) + .tableName(timestreamTableName).build()); } /** @@ -138,16 +138,16 @@ private void deleteTimstreamTable() logger.info("----------------------------------------------------"); try { - timestreamWriteClient.deleteTable(new DeleteTableRequest() - .withDatabaseName(timestreamDbName) - .withTableName(timestreamTableName)); + timestreamWriteClient.deleteTable(DeleteTableRequest.builder() + .databaseName(timestreamDbName) + .tableName(timestreamTableName).build()); } catch (Exception e) { // Do not rethrow here. 
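// Editor's note: an illustrative sketch (not part of the patch). SDK v2 service clients
// implement AutoCloseable, which is why the finally block below swaps v1's shutdown() for
// close(); equivalently the client can be scoped with try-with-resources, as sketched here.
// The database and table names are placeholders.
import software.amazon.awssdk.services.timestreamwrite.TimestreamWriteClient;
import software.amazon.awssdk.services.timestreamwrite.model.DeleteTableRequest;

class CloseableClientSketch
{
    public static void main(String[] args)
    {
        // try-with-resources closes the client and releases its connection pool
        try (TimestreamWriteClient client = TimestreamWriteClient.create()) {
            client.deleteTable(DeleteTableRequest.builder()
                    .databaseName("example_db")
                    .tableName("example_table")
                    .build());
        }
    }
}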
logger.error("Unable to delete Timestream table: " + e.getMessage(), e); } finally { - timestreamWriteClient.shutdown(); + timestreamWriteClient.close(); } } @@ -295,13 +295,13 @@ public void selectColumnWithPredicateIntegTest() String query = String.format("select conversation from \"%s\".\"%s\".\"%s\" where subject = '%s' order by time desc limit 1;", lambdaFunctionName, timestreamDbName, timestreamTableName, jokeProtagonist); - List rows = startQueryExecution(query).getResultSet().getRows(); + List rows = startQueryExecution(query).resultSet().rows(); if (!rows.isEmpty()) { // Remove the column-header row rows.remove(0); } List conversation = new ArrayList<>(); - rows.forEach(row -> conversation.add(row.getData().get(0).getVarCharValue())); + rows.forEach(row -> conversation.add(row.data().get(0).varCharValue())); logger.info("conversation: {}", conversation); assertEquals("Wrong number of DB records found.", 1, conversation.size()); assertTrue("Did not find correct conversation: " + jokePunchline, conversation.contains(jokePunchline)); diff --git a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/integ/TimestreamWriteRecordRequestBuilder.java b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/integ/TimestreamWriteRecordRequestBuilder.java index 6c9b8acddc..73a9e63bd5 100644 --- a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/integ/TimestreamWriteRecordRequestBuilder.java +++ b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/integ/TimestreamWriteRecordRequestBuilder.java @@ -19,11 +19,11 @@ */ package com.amazonaws.athena.connectors.timestream.integ; -import com.amazonaws.services.timestreamwrite.model.Dimension; -import com.amazonaws.services.timestreamwrite.model.MeasureValueType; -import com.amazonaws.services.timestreamwrite.model.Record; -import com.amazonaws.services.timestreamwrite.model.TimeUnit; -import com.amazonaws.services.timestreamwrite.model.WriteRecordsRequest; +import software.amazon.awssdk.services.timestreamwrite.model.Dimension; +import software.amazon.awssdk.services.timestreamwrite.model.MeasureValueType; +import software.amazon.awssdk.services.timestreamwrite.model.Record; +import software.amazon.awssdk.services.timestreamwrite.model.TimeUnit; +import software.amazon.awssdk.services.timestreamwrite.model.WriteRecordsRequest; import java.util.ArrayList; import java.util.List; @@ -104,14 +104,14 @@ public TimestreamWriteRecordRequestBuilder withRecord(Map column long timeMillis) { List dimensions = new ArrayList<>(); - columns.forEach((k, v) -> dimensions.add(new Dimension().withName(k).withValue(v))); - records.add(new Record() - .withDimensions(dimensions) - .withMeasureName(measureName) - .withMeasureValue(measureValue) - .withMeasureValueType(measureValueType) - .withTime(String.valueOf(timeMillis)) - .withTimeUnit(TimeUnit.MILLISECONDS)); + columns.forEach((k, v) -> dimensions.add(Dimension.builder().name(k).value(v).build())); + records.add(Record.builder() + .dimensions(dimensions) + .measureName(measureName) + .measureValue(measureValue) + .measureValueType(measureValueType) + .time(String.valueOf(timeMillis)) + .timeUnit(TimeUnit.MILLISECONDS).build()); return this; } @@ -121,9 +121,9 @@ public TimestreamWriteRecordRequestBuilder withRecord(Map column */ public WriteRecordsRequest build() { - return new WriteRecordsRequest() - .withDatabaseName(databaseName) - .withTableName(tableName) - .withRecords(records); + return WriteRecordsRequest.builder() + 
.databaseName(databaseName) + .tableName(tableName) + .records(records).build(); } } diff --git a/athena-tpcds/Dockerfile b/athena-tpcds/Dockerfile new file mode 100644 index 0000000000..7c4d31ffa1 --- /dev/null +++ b/athena-tpcds/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-tpcds-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-tpcds-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.tpcds.TPCDSCompositeHandler" ] \ No newline at end of file diff --git a/athena-tpcds/athena-tpcds.yaml b/athena-tpcds/athena-tpcds.yaml index 219e0a7dd2..4086565f7e 100644 --- a/athena-tpcds/athena-tpcds.yaml +++ b/athena-tpcds/athena-tpcds.yaml @@ -52,10 +52,9 @@ Resources: spill_bucket: !Ref SpillBucket spill_prefix: !Ref SpillPrefix FunctionName: !Ref AthenaCatalogName - Handler: "com.amazonaws.athena.connectors.tpcds.TPCDSCompositeHandler" - CodeUri: "./target/athena-tpcds-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-tpcds:2022.47.1' Description: "This connector enables Amazon Athena to communicate with a randomly generated TPC-DS data source." - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSMetadataHandler.java b/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSMetadataHandler.java index 42f714b262..878b1266e2 100644 --- a/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSMetadataHandler.java +++ b/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSMetadataHandler.java @@ -41,8 +41,6 @@ import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connectors.tpcds.qpt.TPCDSQueryPassthrough; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.teradata.tpcds.Table; @@ -50,6 +48,8 @@ import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.HashSet; @@ -98,8 +98,8 @@ public TPCDSMetadataHandler(java.util.Map configOptions) @VisibleForTesting protected TPCDSMetadataHandler( EncryptionKeyFactory keyFactory, - AWSSecretsManager secretsManager, - AmazonAthena athena, + SecretsManagerClient secretsManager, + AthenaClient athena, String spillBucket, String spillPrefix, java.util.Map configOptions) diff --git a/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandler.java b/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandler.java index 9a3cf97e5f..6260165190 100644 --- a/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandler.java +++ 
b/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandler.java @@ -25,12 +25,6 @@ import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; import com.teradata.tpcds.Results; import com.teradata.tpcds.Session; import com.teradata.tpcds.Table; @@ -41,6 +35,9 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; import java.math.BigDecimal; @@ -76,11 +73,11 @@ public class TPCDSRecordHandler public TPCDSRecordHandler(java.util.Map configOptions) { - super(AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(), AmazonAthenaClientBuilder.defaultClient(), SOURCE_TYPE, configOptions); + super(S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), SOURCE_TYPE, configOptions); } @VisibleForTesting - protected TPCDSRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena athena, java.util.Map configOptions) + protected TPCDSRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); } diff --git a/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSMetadataHandlerTest.java b/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSMetadataHandlerTest.java index 7e39835921..e88a55b6e2 100644 --- a/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSMetadataHandlerTest.java +++ b/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSMetadataHandlerTest.java @@ -42,8 +42,6 @@ import com.amazonaws.athena.connector.lambda.metadata.MetadataResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.After; @@ -54,6 +52,8 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; import java.util.HashMap; @@ -76,10 +76,10 @@ public class TPCDSMetadataHandlerTest private BlockAllocator allocator; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Before public void setUp() diff --git a/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandlerTest.java 
b/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandlerTest.java index 2fc214cee8..a13b453c55 100644 --- a/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandlerTest.java +++ b/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandlerTest.java @@ -41,13 +41,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableMap; import com.google.common.io.ByteStreams; import com.teradata.tpcds.Table; @@ -63,6 +56,15 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -98,13 +100,13 @@ public class TPCDSRecordHandlerTest private Schema schemaForRead; @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock - private AWSSecretsManager mockSecretsManager; + private SecretsManagerClient mockSecretsManager; @Mock - private AmazonAthena mockAthena; + private AthenaClient mockAthena; @Before public void setUp() @@ -127,30 +129,28 @@ public void setUp() handler = new TPCDSRecordHandler(mockS3, mockSecretsManager, mockAthena, com.google.common.collect.ImmutableMap.of()); spillReader = new S3BlockSpillReader(mockS3, allocator); - when(mockS3.putObject(any())) + when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); + ByteHolder byteHolder = new ByteHolder(); + byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); - ByteHolder byteHolder = new ByteHolder(); - byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); mockS3Storage.add(byteHolder); - return mock(PutObjectResult.class); + logger.info("putObject: total size " + mockS3Storage.size()); } + return PutObjectResponse.builder().build(); }); - when(mockS3.getObject(nullable(String.class), nullable(String.class))) - .thenAnswer((InvocationOnMock invocationOnMock) -> - { + when(mockS3.getObject(any(GetObjectRequest.class))) + .thenAnswer((InvocationOnMock invocationOnMock) -> { + ByteHolder byteHolder; synchronized (mockS3Storage) { - S3Object mockObject = mock(S3Object.class);
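// Editor's note: a self-contained sketch (not part of the patch) of the v2 S3 mocking
// pattern this hunk adopts: putObject(PutObjectRequest, RequestBody) captures the payload
// through the RequestBody's ContentStreamProvider, and getObject(GetObjectRequest) replays
// it as a ResponseInputStream. Names here are illustrative.
import java.io.ByteArrayInputStream;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class S3ClientMockSketch
{
    static S3Client inMemoryS3()
    {
        byte[][] lastBody = new byte[1][];
        S3Client s3 = mock(S3Client.class);
        when(s3.putObject(any(PutObjectRequest.class), any(RequestBody.class)))
                .thenAnswer(invocation -> {
                    // capture the uploaded bytes from the RequestBody's stream provider
                    RequestBody body = invocation.getArgument(1, RequestBody.class);
                    lastBody[0] = body.contentStreamProvider().newStream().readAllBytes();
                    return PutObjectResponse.builder().build();
                });
        when(s3.getObject(any(GetObjectRequest.class)))
                .thenAnswer(invocation -> new ResponseInputStream<>(
                        GetObjectResponse.builder().build(),
                        new ByteArrayInputStream(lastBody[0])));
        return s3;
    }
}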
- ByteHolder byteHolder = mockS3Storage.get(0); + byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + logger.info("getObject: total size " + mockS3Storage.size()); } + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); } diff --git a/athena-udfs/Dockerfile b/athena-udfs/Dockerfile new file mode 100644 index 0000000000..d18b85ae78 --- /dev/null +++ b/athena-udfs/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-udfs-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-udfs-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.udfs.AthenaUDFHandler" ] \ No newline at end of file diff --git a/athena-udfs/athena-udfs.yaml b/athena-udfs/athena-udfs.yaml index 64fd2f54ef..a6f04ad182 100644 --- a/athena-udfs/athena-udfs.yaml +++ b/athena-udfs/athena-udfs.yaml @@ -39,10 +39,9 @@ Resources: Type: 'AWS::Serverless::Function' Properties: FunctionName: !Ref LambdaFunctionName - Handler: "com.amazonaws.athena.connectors.udfs.AthenaUDFHandler" - CodeUri: "./target/athena-udfs-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-udfs:2022.47.1' Description: "This connector enables Amazon Athena to leverage common UDFs made available via Lambda." - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-udfs/src/main/java/com/amazonaws/athena/connectors/udfs/AthenaUDFHandler.java b/athena-udfs/src/main/java/com/amazonaws/athena/connectors/udfs/AthenaUDFHandler.java index 3c90b60e99..3d9be11baa 100644 --- a/athena-udfs/src/main/java/com/amazonaws/athena/connectors/udfs/AthenaUDFHandler.java +++ b/athena-udfs/src/main/java/com/amazonaws/athena/connectors/udfs/AthenaUDFHandler.java @@ -21,8 +21,8 @@ import com.amazonaws.athena.connector.lambda.handlers.UserDefinedFunctionHandler; import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClient; import com.google.common.annotations.VisibleForTesting; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import javax.crypto.BadPaddingException; import javax.crypto.Cipher; @@ -55,7 +55,7 @@ public class AthenaUDFHandler public AthenaUDFHandler() { - this(new CachableSecretsManager(AWSSecretsManagerClient.builder().build())); + this(new CachableSecretsManager(SecretsManagerClient.create())); } @VisibleForTesting diff --git a/athena-vertica/Dockerfile b/athena-vertica/Dockerfile new file mode 100644 index 0000000000..c06ed8b9c2 --- /dev/null +++ b/athena-vertica/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/java:11 + +# Copy function code and runtime dependencies from Maven layout +COPY target/athena-vertica-2022.47.1.jar ${LAMBDA_TASK_ROOT} +# Unpack the jar +RUN jar xf athena-vertica-2022.47.1.jar + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "com.amazonaws.athena.connectors.vertica.VerticaCompositeHandler" ] \ No 
newline at end of file diff --git a/athena-vertica/athena-vertica.yaml b/athena-vertica/athena-vertica.yaml index 39fcb916cd..a0fed75e01 100644 --- a/athena-vertica/athena-vertica.yaml +++ b/athena-vertica/athena-vertica.yaml @@ -82,10 +82,9 @@ Resources: default: !Ref VerticaConnectionString FunctionName: !Sub "${AthenaCatalogName}" - Handler: "com.amazonaws.athena.connectors.vertica.VerticaCompositeHandler" - CodeUri: "./target/athena-vertica-2022.47.1.jar" + PackageType: "Image" + ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-vertica:2022.47.1' Description: "Amazon Athena Vertica Connector" - Runtime: java11 Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory PermissionsBoundary: !If [ HasPermissionsBoundary, !Ref PermissionsBoundaryARN, !Ref "AWS::NoValue" ] diff --git a/athena-vertica/pom.xml b/athena-vertica/pom.xml index b5b086e449..4e8543e54c 100644 --- a/athena-vertica/pom.xml +++ b/athena-vertica/pom.xml @@ -22,6 +22,11 @@ + + net.java.dev.jna + jna-platform + 5.14.0 + org.slf4j slf4j-api @@ -32,6 +37,11 @@ jcl-over-slf4j ${slf4j-log4j.version} + + org.apache.arrow + arrow-dataset + ${apache.arrow.version} + org.apache.logging.log4j log4j-slf4j2-impl diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaCompositeHandler.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaCompositeHandler.java index 7ccbb0f34d..07467897b2 100644 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaCompositeHandler.java +++ b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaCompositeHandler.java @@ -21,6 +21,14 @@ import com.amazonaws.athena.connector.lambda.handlers.CompositeHandler; +import java.io.IOException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateEncodingException; + +import static com.amazonaws.athena.connectors.vertica.VerticaSchemaUtils.installCaCertificate; +import static com.amazonaws.athena.connectors.vertica.VerticaSchemaUtils.setupNativeEnvironmentVariables; + /** * Boilerplate composite handler that allows us to use a single Lambda function for both * Metadata and Data. 
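// Editor's note: the bodies of installCaCertificate() and setupNativeEnvironmentVariables()
// are outside this hunk, so the sketch below is a hypothetical illustration only: one way to
// export the JVM's default trust anchors to /tmp/cacert.pem (the SSL_CERT_FILE_LOCATION_VALUE
// defined in VerticaConstants below) that is consistent with the four checked exceptions the
// new constructor declares. Class and method names here are illustrative, not the patch's code.
import java.io.FileWriter;
import java.io.IOException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.Base64;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;

class CaCertExportSketch
{
    static void installCaCertificate() throws CertificateEncodingException, IOException, NoSuchAlgorithmException, KeyStoreException
    {
        TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        tmf.init((KeyStore) null); // a null KeyStore selects the JVM's default trust store
        Base64.Encoder pem = Base64.getMimeEncoder(64, new byte[] {'\n'});
        try (FileWriter writer = new FileWriter("/tmp/cacert.pem")) {
            for (TrustManager tm : tmf.getTrustManagers()) {
                if (tm instanceof X509TrustManager) {
                    for (X509Certificate cert : ((X509TrustManager) tm).getAcceptedIssuers()) {
                        // write each trust anchor in PEM form, 64 characters per line
                        writer.write("-----BEGIN CERTIFICATE-----\n");
                        writer.write(pem.encodeToString(cert.getEncoded()));
                        writer.write("\n-----END CERTIFICATE-----\n");
                    }
                }
            }
        }
    }
}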
@@ -28,8 +36,10 @@ public class VerticaCompositeHandler extends CompositeHandler { - public VerticaCompositeHandler() + public VerticaCompositeHandler() throws CertificateEncodingException, IOException, NoSuchAlgorithmException, KeyStoreException { super(new VerticaMetadataHandler(System.getenv()), new VerticaRecordHandler(System.getenv())); + installCaCertificate(); + setupNativeEnvironmentVariables(); } } diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java index dbcf85e8a5..a12d790501 100644 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java +++ b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java @@ -24,6 +24,17 @@ public final class VerticaConstants public static final String VERTICA_NAME = "vertica"; public static final String VERTICA_DRIVER_CLASS = "com.vertica.jdbc.Driver"; public static final int VERTICA_DEFAULT_PORT = 5433; + public static final String VERTICA_SPLIT_QUERY_ID = "query_id"; + public static final String VERTICA_SPLIT_EXPORT_BUCKET = "exportBucket"; + public static final String VERTICA_SPLIT_OBJECT_KEY = "s3ObjectKey"; + + /** + * A ssl file location constant to store the SSL certificate + * The file location is fixed at /tmp directory + * to retrieve ssl certificate location + */ + public static final String SSL_CERT_FILE_LOCATION = "SSL_CERT_FILE"; + public static final String SSL_CERT_FILE_LOCATION_VALUE = "/tmp/cacert.pem"; private VerticaConstants() {} } diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java index ee40632659..da52327eaf 100644 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java +++ b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java @@ -20,7 +20,6 @@ package com.amazonaws.athena.connectors.vertica; -import com.amazonaws.SdkClientException; import com.amazonaws.athena.connector.lambda.QueryStatusChecker; import com.amazonaws.athena.connector.lambda.data.Block; import com.amazonaws.athena.connector.lambda.data.BlockAllocator; @@ -48,11 +47,6 @@ import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough; import com.amazonaws.athena.connectors.vertica.query.QueryFactory; import com.amazonaws.athena.connectors.vertica.query.VerticaExportQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.ListObjectsRequest; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3ObjectSummary; import com.google.common.collect.ImmutableMap; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -62,6 +56,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.stringtemplate.v4.ST; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; 
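// Editor's note: an illustrative sketch (not part of the patch). The v2 S3Client does not
// expose getRegion() the way the v1 client did, so the hunk below resolves the region through
// the default provider chain (environment variables, system properties, profile file, then
// container/instance metadata) before passing it to buildSetAwsRegionSql(...).
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;

class RegionLookupSketch
{
    public static void main(String[] args)
    {
        Region region = DefaultAwsRegionProviderChain.builder().build().getRegion();
        // prints e.g. "us-east-1", the string the patch feeds into the export SQL
        System.out.println(region.toString());
    }
}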
import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -80,6 +80,9 @@ import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_DEFAULT_PORT; import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_DRIVER_CLASS; import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_NAME; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_EXPORT_BUCKET; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_OBJECT_KEY; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_QUERY_ID; import static com.amazonaws.athena.connectors.vertica.VerticaSchemaUtils.convertToArrowType; @@ -100,7 +103,7 @@ public class VerticaMetadataHandler private static final String[] TABLE_TYPES = {"TABLE"}; private final QueryFactory queryFactory = new QueryFactory(); private final VerticaSchemaUtils verticaSchemaUtils; - private AmazonS3 amazonS3; + private S3Client amazonS3; private final JdbcQueryPassthrough queryPassthrough = new JdbcQueryPassthrough(); @@ -117,11 +120,11 @@ public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, Map configOptions) { super(databaseConnectionConfig, jdbcConnectionFactory, configOptions); - amazonS3 = AmazonS3ClientBuilder.defaultClient(); + amazonS3 = S3Client.create(); verticaSchemaUtils = new VerticaSchemaUtils(); } @VisibleForTesting - public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, Map configOptions, AmazonS3 amazonS3, VerticaSchemaUtils verticaSchemaUtils) + public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, Map configOptions, S3Client amazonS3, VerticaSchemaUtils verticaSchemaUtils) { super(databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.amazonS3 = amazonS3; @@ -298,8 +301,8 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request } logger.info("Vertica Export Statement: {}", preparedSQLStmt); - // Build the Set AWS Region SQL - String awsRegionSql = queryBuilder.buildSetAwsRegionSql(amazonS3.getRegion().toString()); + // Build the Set AWS Region SQL - Assumes using the default region provider chain + String awsRegionSql = queryBuilder.buildSetAwsRegionSql(DefaultAwsRegionProviderChain.builder().build().getRegion().toString()); // write the prepared SQL statement to the partition column created in enhancePartitionSchema blockWriter.writeRows((Block block, int rowNum) ->{ @@ -374,16 +377,16 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest * For each generated S3 object, create a split and add data to the split. 
*/ Split split; - List s3ObjectSummaries = getlistExportedObjects(exportBucket, queryId); + List s3ObjectsList = getlistExportedObjects(exportBucket, queryId); - if(!s3ObjectSummaries.isEmpty()) + if(!s3ObjectsList.isEmpty()) { - for (S3ObjectSummary objectSummary : s3ObjectSummaries) + for (S3Object s3Object : s3ObjectsList) { split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()) - .add("query_id", queryID) - .add("exportBucket", exportBucket) - .add("s3ObjectKey", objectSummary.getKey()) + .add(VERTICA_SPLIT_QUERY_ID, queryID) + .add(VERTICA_SPLIT_EXPORT_BUCKET, exportBucket) + .add(VERTICA_SPLIT_OBJECT_KEY, s3Object.key()) .build(); splits.add(split); @@ -395,9 +398,9 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest //No records were exported by Vertica for the issued query, creating a "empty" split logger.info("No records were exported by Vertica"); split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()) - .add("query_id", queryID) - .add("exportBucket", exportBucket) - .add("s3ObjectKey", EMPTY_STRING) + .add(VERTICA_SPLIT_QUERY_ID, queryID) + .add(VERTICA_SPLIT_EXPORT_BUCKET, exportBucket) + .add(VERTICA_SPLIT_OBJECT_KEY, EMPTY_STRING) .build(); splits.add(split); return new GetSplitsResponse(catalogName,split); @@ -428,17 +431,20 @@ private void executeQueriesOnVertica(Connection connection, String sqlStatement, /* * Get the list of all the exported S3 objects */ - private List getlistExportedObjects(String s3ExportBucket, String queryId){ - ObjectListing objectListing; + private List getlistExportedObjects(String s3ExportBucket, String queryId){ + ListObjectsResponse listObjectsResponse; try { - objectListing = amazonS3.listObjects(new ListObjectsRequest().withBucketName(s3ExportBucket).withPrefix(queryId)); + listObjectsResponse = amazonS3.listObjects(ListObjectsRequest.builder() + .bucket(s3ExportBucket) + .prefix(queryId) + .build()); } catch (SdkClientException e) { throw new RuntimeException("Exception listing the exported objects : " + e.getMessage(), e); } - return objectListing.getObjectSummaries(); + return listObjectsResponse.contents(); } private void testAccess(Connection conn, TableName table) { diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandler.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandler.java index b315d0e454..29bec641d6 100644 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandler.java +++ b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandler.java @@ -32,28 +32,34 @@ import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.athena.AmazonAthenaClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.*; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; -import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.arrow.dataset.file.FileFormat; +import org.apache.arrow.dataset.file.FileSystemDatasetFactory; +import org.apache.arrow.dataset.jni.NativeMemoryPool; +import 
org.apache.arrow.dataset.scanner.ScanOptions; +import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.dataset.source.Dataset; +import org.apache.arrow.dataset.source.DatasetFactory; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.VisibleForTesting; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.holders.*; +import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; + +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_EXPORT_BUCKET; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_OBJECT_KEY; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_QUERY_ID; -import java.io.BufferedReader; import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; import java.math.BigDecimal; -import java.nio.charset.StandardCharsets; import java.time.LocalDate; import java.time.LocalDateTime; import java.util.HashMap; @@ -63,22 +69,18 @@ public class VerticaRecordHandler extends RecordHandler { private static final Logger logger = LoggerFactory.getLogger(VerticaRecordHandler.class); private static final String SOURCE_TYPE = "vertica"; - private static final String VERTICA_QUOTE_CHARACTER = "\""; - private static final String QUERY = "select * from S3Object s"; - private AmazonS3 amazonS3; public VerticaRecordHandler(java.util.Map configOptions) { - this(AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), configOptions); + this(S3Client.create(), + SecretsManagerClient.create(), + AthenaClient.create(), configOptions); } @VisibleForTesting - protected VerticaRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsManager, AmazonAthena amazonAthena, java.util.Map configOptions) + protected VerticaRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, java.util.Map configOptions) { super(amazonS3, secretsManager, amazonAthena, SOURCE_TYPE, configOptions); - this.amazonS3 = amazonS3; } /** @@ -102,9 +104,9 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor Schema schemaName = recordsRequest.getSchema(); Split split = recordsRequest.getSplit(); - String id = split.getProperty("query_id"); - String exportBucket = split.getProperty("exportBucket"); - String s3ObjectKey = split.getProperty("s3ObjectKey"); + String id = split.getProperty(VERTICA_SPLIT_QUERY_ID); + String exportBucket = split.getProperty(VERTICA_SPLIT_EXPORT_BUCKET); + String s3ObjectKey = split.getProperty(VERTICA_SPLIT_OBJECT_KEY); if(!s3ObjectKey.isEmpty()) { //get column name and type from the Schema @@ -129,25 +131,25 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor } GeneratedRowWriter rowWriter = builder.build(); - /* - Using S3 Select to read the S3 Parquet file generated in the split - */ - //Creating the read Request - SelectObjectContentRequest request = generateBaseParquetRequest(exportBucket, s3ObjectKey); - try 
(SelectObjectContentResult result = amazonS3.selectObjectContent(request)) {
-            InputStream resultInputStream = result.getPayload().getRecordsInputStream();
-            BufferedReader streamReader = new BufferedReader(new InputStreamReader(resultInputStream, StandardCharsets.UTF_8));
-            String inputStr;
-            while ((inputStr = streamReader.readLine()) != null) {
-                HashMap map = new HashMap<>();
-                //we are reading the parquet files, but serializing the output it as JSON as SDK provides a Parquet InputSerialization, but only a JSON or CSV OutputSerializatio
-                ObjectMapper objectMapper = new ObjectMapper();
-                map = objectMapper.readValue(inputStr, HashMap.class);
-                rowContext.setNameValue(map);
-
-                //Passing the RowContext to BlockWriter;
-                spiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, rowContext) ? 1 : 0);
+        /*
+         Using Arrow Dataset to read the S3 Parquet file generated in the split
+        */
+        try (ArrowReader reader = constructArrowReader(constructS3Uri(exportBucket, s3ObjectKey)))
+        {
+            while (reader.loadNextBatch()) {
+                VectorSchemaRoot root = reader.getVectorSchemaRoot();
+                for (int row = 0; row < root.getRowCount(); row++) {
+                    HashMap map = new HashMap<>();
+                    for (Field field : root.getSchema().getFields()) {
+                        map.put(field.getName(), root.getVector(field).getObject(row));
+                    }
+                    rowContext.setNameValue(map);
+
+                    //Passing the RowContext to BlockWriter;
+                    spiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, rowContext) ? 1 : 0);
+                }
             }
         } catch (Exception e) {
             throw new RuntimeException("Error in connecting to S3 and selecting the object content for object: " + s3ObjectKey, e);
         }
@@ -331,28 +333,24 @@ public HashMap getNameValue() {
         }
     }
-
-    /*
-    Method to create the Parquet read request
-     */
-    private static SelectObjectContentRequest generateBaseParquetRequest(String bucket, String key)
+    @VisibleForTesting
+    protected ArrowReader constructArrowReader(String uri)
     {
-        SelectObjectContentRequest request = new SelectObjectContentRequest();
-        request.setBucketName(bucket);
-        request.setKey(key);
-        request.setExpression(VerticaRecordHandler.QUERY);
-        request.setExpressionType(ExpressionType.SQL);
-
-        InputSerialization inputSerialization = new InputSerialization();
-        inputSerialization.setParquet(new ParquetInput());
-        inputSerialization.setCompressionType(CompressionType.NONE);
-        request.setInputSerialization(inputSerialization);
-
-        OutputSerialization outputSerialization = new OutputSerialization();
-        outputSerialization.setJson(new JSONOutput());
-        request.setOutputSerialization(outputSerialization);
+        BufferAllocator allocator = new RootAllocator();
+        DatasetFactory datasetFactory = new FileSystemDatasetFactory(
+                allocator,
+                NativeMemoryPool.getDefault(),
+                FileFormat.PARQUET,
+                uri);
+        Dataset dataset = datasetFactory.finish();
+        ScanOptions options = new ScanOptions(/*batchSize*/ 32768);
+        Scanner scanner = dataset.newScan(options);
+        return scanner.scanBatches();
+    }
 
-        return request;
+    private static String constructS3Uri(String bucket, String key)
+    {
+        return "s3://" + bucket + "/" + key;
     }
 }
diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaSchemaUtils.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaSchemaUtils.java
index 6939f8d2bb..547a01f26a 100644
--- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaSchemaUtils.java
+++ 
b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaSchemaUtils.java @@ -21,19 +21,39 @@ import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; import com.amazonaws.athena.connector.lambda.domain.TableName; +import com.sun.jna.platform.unix.LibC; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import java.io.FileWriter; +import java.io.IOException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.Certificate; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.ResultSet; import java.sql.SQLException; +import java.util.Base64; + +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.SSL_CERT_FILE_LOCATION; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.SSL_CERT_FILE_LOCATION_VALUE; public class VerticaSchemaUtils { private static final Logger logger = LoggerFactory.getLogger(VerticaSchemaUtils.class); + private static final String BEGIN_CERT = "-----BEGIN CERTIFICATE-----"; + private static final String END_CERT = "-----END CERTIFICATE-----"; + private static final String LINE_SEPARATOR = System.getProperty("line.separator"); + //Builds the table schema protected Schema buildTableSchema(Connection connection, TableName name) { @@ -125,4 +145,45 @@ public static void convertToArrowType(SchemaBuilder tableSchemaBuilder, String c tableSchemaBuilder.addStringField(colName); } } + + /** + * Write out the cacerts that we trust from the default java truststore. 
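+     * Each accepted issuer from the JVM's default trust managers is PEM-encoded and
+     * appended to the bundle at SSL_CERT_FILE_LOCATION_VALUE, so native libraries
+     * (pointed at the bundle by setupNativeEnvironmentVariables below) trust the
+     * same CAs as the JVM.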
+ * + */ + public static void installCaCertificate() throws IOException, NoSuchAlgorithmException, KeyStoreException, CertificateEncodingException + { + FileWriter caBundleWriter = new FileWriter(SSL_CERT_FILE_LOCATION_VALUE); + try { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + for (TrustManager trustManager : trustManagerFactory.getTrustManagers()) { + X509TrustManager x509TrustManager = (X509TrustManager) trustManager; + for (X509Certificate x509Certificate : x509TrustManager.getAcceptedIssuers()) { + caBundleWriter.write(formatCrtFileContents(x509Certificate)); + caBundleWriter.write(LINE_SEPARATOR); + } + } + } + finally { + caBundleWriter.close(); + } + } + + private static String formatCrtFileContents(Certificate certificate) throws CertificateEncodingException + { + Base64.Encoder encoder = Base64.getMimeEncoder(64, LINE_SEPARATOR.getBytes()); + byte[] rawCrtText = certificate.getEncoded(); + String encodedCertText = new String(encoder.encode(rawCrtText)); + String prettifiedCert = BEGIN_CERT + LINE_SEPARATOR + encodedCertText + LINE_SEPARATOR + END_CERT; + return prettifiedCert; + } + + public static void setupNativeEnvironmentVariables() + { + LibC.INSTANCE.setenv(SSL_CERT_FILE_LOCATION, SSL_CERT_FILE_LOCATION_VALUE, 1); + if (logger.isDebugEnabled()) { + logger.debug("Set native environment variables: {}: {}", + SSL_CERT_FILE_LOCATION, LibC.INSTANCE.getenv(SSL_CERT_FILE_LOCATION)); + } + } } diff --git a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java index 2a5adad8a4..48091b59e0 100644 --- a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java +++ b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java @@ -47,15 +47,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.vertica.query.QueryFactory; import com.amazonaws.athena.connectors.vertica.query.VerticaExportQueryBuilder; -import com.amazonaws.services.athena.AmazonAthena; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.ListObjectsRequest; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.Region; -import com.amazonaws.services.s3.model.S3ObjectSummary; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; import com.google.common.collect.ImmutableList; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.After; @@ -69,6 +60,14 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.stringtemplate.v4.ST; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import 
software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -103,9 +102,9 @@ public class VerticaMetadataHandlerTest extends TestBase private VerticaExportQueryBuilder verticaExportQueryBuilder; private VerticaSchemaUtils verticaSchemaUtils; private Connection connection; - private AWSSecretsManager secretsManager; - private AmazonAthena athena; - private AmazonS3 amazonS3; + private SecretsManagerClient secretsManager; + private AthenaClient athena; + private S3Client amazonS3; private FederatedIdentity federatedIdentity; private BlockAllocatorImpl allocator; private DatabaseMetaData databaseMetaData; @@ -117,11 +116,7 @@ public class VerticaMetadataHandlerTest extends TestBase private QueryStatusChecker queryStatusChecker; private VerticaMetadataHandler verticaMetadataHandlerMocked; @Mock - private AmazonS3 s3clientMock; - @Mock - private ListObjectsRequest listObjectsRequest; - @Mock - private ObjectListing objectListing; + private S3Client s3clientMock; private DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", VERTICA_NAME, "vertica://jdbc:vertica:thin:username/password@//127.0.0.1:1521/vrt"); @@ -134,8 +129,8 @@ public void setUp() throws Exception this.queryFactory = Mockito.mock(QueryFactory.class); this.verticaExportQueryBuilder = Mockito.mock(VerticaExportQueryBuilder.class); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.federatedIdentity = Mockito.mock(FederatedIdentity.class); this.databaseMetaData = Mockito.mock(DatabaseMetaData.class); this.tableName = Mockito.mock(TableName.class); @@ -144,17 +139,16 @@ public void setUp() throws Exception this.schemaBuilder = Mockito.mock(SchemaBuilder.class); this.blockWriter = Mockito.mock(BlockWriter.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); - Mockito.lenient().when(this.secretsManager.getSecretValue(Mockito.eq(new GetSecretValueRequest().withSecretId("testSecret")))).thenReturn(new GetSecretValueResult().withSecretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}")); + Mockito.lenient().when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); Mockito.when(connection.getMetaData()).thenReturn(databaseMetaData); - Mockito.when(amazonS3.getRegion()).thenReturn(Region.US_West_2); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.secretsManager = Mockito.mock(AWSSecretsManager.class); - this.athena = Mockito.mock(AmazonAthena.class); + this.secretsManager = Mockito.mock(SecretsManagerClient.class); + this.athena = Mockito.mock(AthenaClient.class); this.verticaMetadataHandler = new VerticaMetadataHandler(databaseConnectionConfig, this.jdbcConnectionFactory, 
com.google.common.collect.ImmutableMap.of(), amazonS3, verticaSchemaUtils);
         this.federatedIdentity = Mockito.mock(FederatedIdentity.class);
         this.allocator = new BlockAllocatorImpl();
@@ -344,21 +338,13 @@ public void doGetSplits() throws Exception
             BlockUtils.setValue(partitions.getFieldVector("awsRegionSql"), i, "us-west-2");
         }
 
-        List<S3ObjectSummary> s3ObjectSummariesList = new ArrayList<>();
-        S3ObjectSummary s3ObjectSummary = new S3ObjectSummary();
-        s3ObjectSummary.setBucketName("s3ExportBucket");
-        s3ObjectSummary.setKey("testKey");
-        s3ObjectSummariesList.add(s3ObjectSummary);
-        ListObjectsRequest listObjectsRequestObj = new ListObjectsRequest();
-        listObjectsRequestObj.setBucketName("s3ExportBucket");
-        listObjectsRequestObj.setPrefix("queryId");
-
+        List<S3Object> objectList = new ArrayList<>();
+        S3Object obj = S3Object.builder().key("testKey").build();
+        objectList.add(obj);
+        ListObjectsResponse listObjectsResponse = ListObjectsResponse.builder().contents(objectList).build();
         Mockito.when(verticaMetadataHandlerMocked.getS3ExportBucket()).thenReturn("testS3Bucket");
-        Mockito.lenient().when(listObjectsRequest.withBucketName(nullable(String.class))).thenReturn(listObjectsRequestObj);
-        Mockito.lenient().when(listObjectsRequest.withPrefix(nullable(String.class))).thenReturn(listObjectsRequestObj);
-        Mockito.when(amazonS3.listObjects(nullable(ListObjectsRequest.class))).thenReturn(objectListing);
-        Mockito.when(objectListing.getObjectSummaries()).thenReturn(s3ObjectSummariesList);
+        Mockito.when(amazonS3.listObjects(nullable(ListObjectsRequest.class))).thenReturn(listObjectsResponse);
         GetSplitsRequest originalReq = new GetSplitsRequest(this.federatedIdentity, "queryId", "catalog_name",
                 new TableName("schema", "table_name"),
diff --git a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandlerTest.java b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandlerTest.java
new file mode 100644
index 0000000000..b6ec304ad3
--- /dev/null
+++ b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandlerTest.java
@@ -0,0 +1,349 @@
+/*-
+ * #%L
+ * athena-vertica
+ * %%
+ * Copyright (C) 2019 - 2022 Amazon Web Services
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L% + */ +package com.amazonaws.athena.connectors.vertica; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnitRunner; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.amazonaws.athena.connector.lambda.data.Block; +import com.amazonaws.athena.connector.lambda.data.BlockAllocator; +import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl; +import com.amazonaws.athena.connector.lambda.data.BlockUtils; +import com.amazonaws.athena.connector.lambda.data.S3BlockSpillReader; +import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; +import com.amazonaws.athena.connector.lambda.domain.Split; +import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; +import com.amazonaws.athena.connector.lambda.domain.predicate.Range; +import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet; +import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; +import com.amazonaws.athena.connector.lambda.domain.spill.S3SpillLocation; +import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation; +import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; +import com.amazonaws.athena.connector.lambda.records.ReadRecordsResponse; +import com.amazonaws.athena.connector.lambda.records.RecordResponse; +import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse; +import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; +import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; +import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.io.ByteStreams; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; + +import static com.amazonaws.athena.connector.lambda.domain.predicate.Constraints.DEFAULT_NO_LIMIT; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_EXPORT_BUCKET; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_OBJECT_KEY; +import static 
com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_QUERY_ID;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class VerticaRecordHandlerTest
+        extends TestBase
+{
+    private static final Logger logger = LoggerFactory.getLogger(VerticaRecordHandlerTest.class);
+
+    private VerticaRecordHandler handler;
+    private BlockAllocator allocator;
+    private List<ByteHolder> mockS3Storage = new ArrayList<>();
+    private S3BlockSpillReader spillReader;
+    private FederatedIdentity identity = new FederatedIdentity("arn", "account", Collections.emptyMap(), Collections.emptyList());
+    private EncryptionKeyFactory keyFactory = new LocalKeyFactory();
+
+    private static final BufferAllocator bufferAllocator = new RootAllocator();
+
+    @Rule
+    public TestName testName = new TestName();
+
+    @Mock
+    private S3Client mockS3;
+
+    @Mock
+    private SecretsManagerClient mockSecretsManager;
+
+    @Mock
+    private AthenaClient mockAthena;
+
+    @Before
+    public void setup()
+    {
+        logger.info("{}: enter", testName.getMethodName());
+
+        allocator = new BlockAllocatorImpl();
+        handler = new VerticaRecordHandler(mockS3, mockSecretsManager, mockAthena, com.google.common.collect.ImmutableMap.of());
+        spillReader = new S3BlockSpillReader(mockS3, allocator);
+
+        Mockito.lenient().when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class)))
+                .thenAnswer((InvocationOnMock invocationOnMock) -> {
+                    InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream();
+                    ByteHolder byteHolder = new ByteHolder();
+                    byteHolder.setBytes(ByteStreams.toByteArray(inputStream));
+                    synchronized (mockS3Storage) {
+                        mockS3Storage.add(byteHolder);
+                        logger.info("putObject: total size " + mockS3Storage.size());
+                    }
+                    return PutObjectResponse.builder().build();
+                });
+
+        Mockito.lenient().when(mockS3.getObject(any(GetObjectRequest.class)))
+                .thenAnswer((InvocationOnMock invocationOnMock) -> {
+                    ByteHolder byteHolder;
+                    synchronized (mockS3Storage) {
+                        byteHolder = mockS3Storage.get(0);
+                        mockS3Storage.remove(0);
+                        logger.info("getObject: total size " + mockS3Storage.size());
+                    }
+                    return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes()));
+                });
+    }
+
+    @After
+    public void after()
+    {
+        allocator.close();
+        logger.info("{}: exit", testName.getMethodName());
+    }
+
+    @Test
+    public void doReadRecordsNoSpill()
+            throws Exception
+    {
+        logger.info("doReadRecordsNoSpill: enter");
+
+        VectorSchemaRoot schemaRoot = createRoot();
+        ArrowReader mockReader = mock(ArrowReader.class);
+        when(mockReader.loadNextBatch()).thenReturn(true, false);
+        when(mockReader.getVectorSchemaRoot()).thenReturn(schemaRoot);
+        VerticaRecordHandler handlerSpy = spy(handler);
+        doReturn(mockReader).when(handlerSpy).constructArrowReader(any());
+
+        Map<String, ValueSet> constraintsMap = new HashMap<>();
+        constraintsMap.put("time", SortedRangeSet.copyOf(Types.MinorType.BIGINT.getType(),
+                ImmutableList.of(Range.equal(allocator, Types.MinorType.BIGINT.getType(), 100L)), false));
+
+        S3SpillLocation splitLoc = S3SpillLocation.newBuilder()
+                .withBucket(UUID.randomUUID().toString())
+                .withSplitId(UUID.randomUUID().toString())
+
.withQueryId(UUID.randomUUID().toString()) + .withIsDirectory(true) + .build(); + + Split.Builder splitBuilder = Split.newBuilder(splitLoc, keyFactory.create()) + .add(VERTICA_SPLIT_QUERY_ID, "query_id") + .add(VERTICA_SPLIT_EXPORT_BUCKET, "export_bucket") + .add(VERTICA_SPLIT_OBJECT_KEY, "s3_object_key"); + + ReadRecordsRequest request = new ReadRecordsRequest(identity, + DEFAULT_CATALOG, + QUERY_ID, + TABLE_NAME, + schemaRoot.getSchema(), + splitBuilder.build(), + new Constraints(constraintsMap, Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT), + 100_000_000_000L, + 100_000_000_000L//100GB don't expect this to spill + ); + RecordResponse rawResponse = handlerSpy.doReadRecords(allocator, request); + + assertTrue(rawResponse instanceof ReadRecordsResponse); + + ReadRecordsResponse response = (ReadRecordsResponse) rawResponse; + logger.info("doReadRecordsNoSpill: rows[{}]", response.getRecordCount()); + + assertTrue(response.getRecords().getRowCount() == 2); + logger.info("doReadRecordsNoSpill: {}", BlockUtils.rowToString(response.getRecords(), 0)); + logger.info("doReadRecordsNoSpill: {}", BlockUtils.rowToString(response.getRecords(), 1)); + + for (Field field : schemaRoot.getSchema().getFields()) { + assertTrue(response.getRecords().getFieldVector(field.getName()).getObject(0).equals(schemaRoot.getVector(field).getObject(0))); + assertTrue(response.getRecords().getFieldVector(field.getName()).getObject(1).equals(schemaRoot.getVector(field).getObject(1))); + } + + logger.info("doReadRecordsNoSpill: exit"); + } + + @Test + public void doReadRecordsSpill() + throws Exception + { + logger.info("doReadRecordsSpill: enter"); + + VectorSchemaRoot schemaRoot = createRoot(); + ArrowReader mockReader = mock(ArrowReader.class); + when(mockReader.loadNextBatch()).thenReturn(true, false); + when(mockReader.getVectorSchemaRoot()).thenReturn(schemaRoot); + VerticaRecordHandler handlerSpy = spy(handler); + doReturn(mockReader).when(handlerSpy).constructArrowReader(any()); + + Map constraintsMap = new HashMap<>(); + constraintsMap.put("time", SortedRangeSet.copyOf(Types.MinorType.BIGINT.getType(), + ImmutableList.of(Range.equal(allocator, Types.MinorType.BIGINT.getType(), 100L)), false)); + + S3SpillLocation splitLoc = S3SpillLocation.newBuilder() + .withBucket(UUID.randomUUID().toString()) + .withSplitId(UUID.randomUUID().toString()) + .withQueryId(UUID.randomUUID().toString()) + .withIsDirectory(true) + .build(); + + Split.Builder splitBuilder = Split.newBuilder(splitLoc, keyFactory.create()) + .add(VERTICA_SPLIT_QUERY_ID, "query_id") + .add(VERTICA_SPLIT_EXPORT_BUCKET, "export_bucket") + .add(VERTICA_SPLIT_OBJECT_KEY, "s3_object_key"); + + ReadRecordsRequest request = new ReadRecordsRequest(identity, + DEFAULT_CATALOG, + QUERY_ID, + TABLE_NAME, + schemaRoot.getSchema(), + splitBuilder.build(), + new Constraints(constraintsMap, Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT), + 1_500_000L, //~1.5MB so we should see some spill + 0L + ); + RecordResponse rawResponse = handlerSpy.doReadRecords(allocator, request); + + assertTrue(rawResponse instanceof RemoteReadRecordsResponse); + + try (RemoteReadRecordsResponse response = (RemoteReadRecordsResponse) rawResponse) { + logger.info("doReadRecordsSpill: remoteBlocks[{}]", response.getRemoteBlocks().size()); + + //assertTrue(response.getNumberBlocks() > 1); + + int blockNum = 0; + for (SpillLocation next : response.getRemoteBlocks()) { + S3SpillLocation spillLocation = (S3SpillLocation) next; + try (Block block = 
spillReader.read(spillLocation, response.getEncryptionKey(), response.getSchema())) { + + logger.info("doReadRecordsSpill: blockNum[{}] and recordCount[{}]", blockNum++, block.getRowCount()); + // assertTrue(++blockNum < response.getRemoteBlocks().size() && block.getRowCount() > 10_000); + + logger.info("doReadRecordsSpill: {}", BlockUtils.rowToString(block, 0)); + assertNotNull(BlockUtils.rowToString(block, 0)); + } + } + } + + logger.info("doReadRecordsSpill: exit"); + } + + private class ByteHolder + { + private byte[] bytes; + + public void setBytes(byte[] bytes) + { + this.bytes = bytes; + } + + public byte[] getBytes() + { + return bytes; + } + } + + private VectorSchemaRoot createRoot() + { + Schema schema = SchemaBuilder.newBuilder() + .addBigIntField("day") + .addBigIntField("month") + .addBigIntField("year") + .addStringField("preparedStmt") + .addStringField("queryId") + .addStringField("awsRegionSql") + .build(); + VectorSchemaRoot schemaRoot = VectorSchemaRoot.create(schema, bufferAllocator); + BigIntVector dayVector = (BigIntVector) schemaRoot.getVector("day"); + dayVector.allocateNew(2); + dayVector.set(0, 0); + dayVector.set(1, 1); + dayVector.setValueCount(2); + BigIntVector monthVector = (BigIntVector) schemaRoot.getVector("month"); + monthVector.allocateNew(2); + monthVector.set(0, 0); + monthVector.set(1, 1); + monthVector.setValueCount(2); + BigIntVector yearVector = (BigIntVector) schemaRoot.getVector("year"); + yearVector.allocateNew(2); + yearVector.set(0, 2000); + yearVector.set(1, 2001); + yearVector.setValueCount(2); + VarCharVector stmtVector = (VarCharVector) schemaRoot.getVector("preparedStmt"); + stmtVector.allocateNew(2); + stmtVector.set(0, new Text("test1")); + stmtVector.set(1, new Text("test2")); + stmtVector.setValueCount(2); + VarCharVector idVector = (VarCharVector) schemaRoot.getVector("queryId"); + idVector.allocateNew(2); + idVector.set(0, new Text("queryID1")); + idVector.set(1, new Text("queryID2")); + idVector.setValueCount(2); + VarCharVector regionVector = (VarCharVector) schemaRoot.getVector("awsRegionSql"); + regionVector.allocateNew(2); + regionVector.set(0, new Text("region1")); + regionVector.set(1, new Text("region2")); + regionVector.setValueCount(2); + schemaRoot.setRowCount(2); + return schemaRoot; + } +} diff --git a/pom.xml b/pom.xml index 025b0bcc0d..210c00139d 100644 --- a/pom.xml +++ b/pom.xml @@ -14,7 +14,7 @@ 11 3.13.0 - 1.12.773 + 2.28.9 1.2.2 1.6.0 1.204.0 diff --git a/tools/bump_versions/bump_connectors_version.py b/tools/bump_versions/bump_connectors_version.py index d478fd1c3e..d6cd78edd4 100755 --- a/tools/bump_versions/bump_connectors_version.py +++ b/tools/bump_versions/bump_connectors_version.py @@ -49,3 +49,7 @@ # Bump the versions in the yaml files yaml_files = glob.glob(f"{connector}/*.yaml") + glob.glob(f"{connector}/*.yml") common.update_yaml(yaml_files, new_version) + + # Bump the versions in the Dockerfiles + dockerfiles = glob.glob(f"{connector}/Dockerfile") + common.update_dockerfile(dockerfiles, new_version) diff --git a/tools/bump_versions/common.py b/tools/bump_versions/common.py index 40ba70be79..bec31d3038 100755 --- a/tools/bump_versions/common.py +++ b/tools/bump_versions/common.py @@ -36,6 +36,13 @@ def update_yaml(yaml_files, new_version): for yml in yaml_files: subprocess.run(["sed", "-i", f"s/\(SemanticVersion:\s*\).*/\\1{new_version}/", yml]) subprocess.run(["sed", "-i", f"s/\(CodeUri:.*-\)[0-9]*\.[0-9]*\.[0-9]*\(-\?.*\.jar\)/\\1{new_version}\\2/", yml]) + subprocess.run(["sed", "-i", 
f"s/\(ImageUri:.*:\)[0-9]*\.[0-9]*\.[0-9]*\(\'\)/\\1{new_version}\\2/", yml]) + + +def update_dockerfile(dockerfiles, new_version): + for file in dockerfiles: + subprocess.run(["sed", "-i", f"s/\(COPY\s.*-\)[0-9]*\.[0-9]*\.[0-9]*\(\.jar.*\)/\\1{new_version}\\2/", file]) + subprocess.run(["sed", "-i", f"s/\(RUN\sjar\sxf.*-\)[0-9]*\.[0-9]*\.[0-9]*\(\.jar\)/\\1{new_version}\\2/", file]) def update_project_version(soup, new_version): From 0f433107efcc48249e6ccd7bd849a7819b96ce26 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 22 Oct 2024 13:41:13 +0000 Subject: [PATCH 4/7] build(deps): bump com.mysql:mysql-connector-j from 9.0.0 to 9.1.0 Bumps [com.mysql:mysql-connector-j](https://github.com/mysql/mysql-connector-j) from 9.0.0 to 9.1.0. - [Changelog](https://github.com/mysql/mysql-connector-j/blob/release/9.x/CHANGES) - [Commits](https://github.com/mysql/mysql-connector-j/compare/9.0.0...9.1.0) --- updated-dependencies: - dependency-name: com.mysql:mysql-connector-j dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- athena-mysql/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/athena-mysql/pom.xml b/athena-mysql/pom.xml index b281bdbd80..a8c9ac7b4e 100644 --- a/athena-mysql/pom.xml +++ b/athena-mysql/pom.xml @@ -35,7 +35,7 @@ com.mysql mysql-connector-j - 9.0.0 + 9.1.0 com.google.protobuf From 90b8c453493e51712b4bed422282daad692a05bd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 22 Oct 2024 13:48:32 +0000 Subject: [PATCH 5/7] build(deps): bump org.elasticsearch.client:elasticsearch-rest-client Bumps [org.elasticsearch.client:elasticsearch-rest-client](https://github.com/elastic/elasticsearch) from 8.15.2 to 8.15.3. - [Release notes](https://github.com/elastic/elasticsearch/releases) - [Changelog](https://github.com/elastic/elasticsearch/blob/main/CHANGELOG.md) - [Commits](https://github.com/elastic/elasticsearch/compare/v8.15.2...v8.15.3) --- updated-dependencies: - dependency-name: org.elasticsearch.client:elasticsearch-rest-client dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- athena-elasticsearch/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/athena-elasticsearch/pom.xml b/athena-elasticsearch/pom.xml index 316cf40b7f..0d040efdbc 100644 --- a/athena-elasticsearch/pom.xml +++ b/athena-elasticsearch/pom.xml @@ -73,7 +73,7 @@ org.elasticsearch.client elasticsearch-rest-client - 8.15.2 + 8.15.3 From 851c2b655dbf1d7d8054ecba007409f73922bc67 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 22 Oct 2024 13:54:35 +0000 Subject: [PATCH 6/7] build(deps): bump software.amazon.glue:schema-registry-serde Bumps [software.amazon.glue:schema-registry-serde](https://github.com/awslabs/aws-glue-schema-registry) from 1.1.20 to 1.1.21. - [Release notes](https://github.com/awslabs/aws-glue-schema-registry/releases) - [Changelog](https://github.com/awslabs/aws-glue-schema-registry/blob/master/CHANGELOG.md) - [Commits](https://github.com/awslabs/aws-glue-schema-registry/compare/v1.1.20...v1.1.21) --- updated-dependencies: - dependency-name: software.amazon.glue:schema-registry-serde dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] --- athena-msk/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/athena-msk/pom.xml b/athena-msk/pom.xml index c7a37a59a0..e12aad9167 100644 --- a/athena-msk/pom.xml +++ b/athena-msk/pom.xml @@ -98,7 +98,7 @@ software.amazon.glue schema-registry-serde - 1.1.20 + 1.1.21 io.confluent From d31f217dc179923a398c2cfb33d437cc4a6db6e6 Mon Sep 17 00:00:00 2001 From: Aimery Methena <159072740+aimethed@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:32:51 -0400 Subject: [PATCH 7/7] Opt in imageuri (#2340) --- athena-aws-cmdb/athena-aws-cmdb.yaml | 6 +++++- athena-clickhouse/athena-clickhouse.yaml | 6 +++++- athena-cloudera-hive/athena-cloudera-hive.yaml | 6 +++++- athena-cloudera-impala/athena-cloudera-impala.yaml | 6 +++++- athena-cloudwatch-metrics/athena-cloudwatch-metrics.yaml | 6 +++++- athena-cloudwatch/athena-cloudwatch.yaml | 7 +++++-- athena-datalakegen2/athena-datalakegen2.yaml | 6 +++++- athena-db2-as400/athena-db2-as400.yaml | 6 +++++- athena-db2/athena-db2.yaml | 6 +++++- athena-docdb/athena-docdb.yaml | 6 +++++- athena-dynamodb/athena-dynamodb.yaml | 7 +++++-- athena-elasticsearch/athena-elasticsearch.yaml | 7 +++++-- athena-example/athena-example.yaml | 1 - athena-gcs/athena-gcs.yaml | 7 +++++-- athena-google-bigquery/athena-google-bigquery.yaml | 6 +++++- athena-hbase/athena-hbase.yaml | 6 +++++- athena-hortonworks-hive/athena-hortonworks-hive.yaml | 6 +++++- athena-kafka/athena-kafka.yaml | 7 +++++-- athena-msk/athena-msk.yaml | 7 +++++-- athena-mysql/athena-mysql.yaml | 6 +++++- athena-neptune/athena-neptune.yaml | 7 +++++-- athena-oracle/athena-oracle.yaml | 6 +++++- athena-postgresql/athena-postgresql.yaml | 6 +++++- athena-redis/athena-redis.yaml | 6 +++++- athena-redshift/athena-redshift.yaml | 6 +++++- athena-saphana/athena-saphana.yaml | 6 +++++- athena-snowflake/athena-snowflake.yaml | 6 +++++- athena-sqlserver/athena-sqlserver.yaml | 6 +++++- athena-synapse/athena-synapse.yaml | 7 +++++-- athena-teradata/athena-teradata.yaml | 6 +++++- athena-timestream/athena-timestream.yaml | 6 +++++- athena-tpcds/athena-tpcds.yaml | 6 +++++- athena-udfs/athena-udfs.yaml | 6 +++++- athena-vertica/athena-vertica.yaml | 7 +++++-- 34 files changed, 165 insertions(+), 43 deletions(-) diff --git a/athena-aws-cmdb/athena-aws-cmdb.yaml b/athena-aws-cmdb/athena-aws-cmdb.yaml index 4365e6781d..8ead0e1db0 100644 --- a/athena-aws-cmdb/athena-aws-cmdb.yaml +++ b/athena-aws-cmdb/athena-aws-cmdb.yaml @@ -42,6 +42,8 @@ Parameters: Type: String Conditions: HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ] + IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"] + IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"] Resources: ConnectorConfig: Type: 'AWS::Serverless::Function' @@ -53,7 +55,9 @@ Resources: spill_prefix: !Ref SpillPrefix FunctionName: !Ref AthenaCatalogName PackageType: "Image" - ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-aws-cmdb:2022.47.1' + ImageUri: !Sub + - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-aws-cmdb:2022.47.1' + - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]] Description: "Enables Amazon Athena to communicate with various AWS Services, making your resource inventories accessible via SQL." 
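+      # The ImageUri account above is region-specific: me-south-1 pulls the connector
+      # image from 084828588479 and ap-east-1 from 183295418215; all other regions
+      # use 292517598671.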
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-clickhouse/athena-clickhouse.yaml b/athena-clickhouse/athena-clickhouse.yaml
index 259aae7198..8f65ef1018 100644
--- a/athena-clickhouse/athena-clickhouse.yaml
+++ b/athena-clickhouse/athena-clickhouse.yaml
@@ -59,6 +59,8 @@ Parameters:
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   NotHasLambdaRole: !Equals [!Ref LambdaRoleARN, ""]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -71,7 +73,9 @@ Resources:
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-clickhouse:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-clickhouse:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with ClickHouse using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-cloudera-hive/athena-cloudera-hive.yaml b/athena-cloudera-hive/athena-cloudera-hive.yaml
index 70f2775b1b..bc9da142a9 100644
--- a/athena-cloudera-hive/athena-cloudera-hive.yaml
+++ b/athena-cloudera-hive/athena-cloudera-hive.yaml
@@ -54,6 +54,8 @@ Parameters:
     Type: String
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -66,7 +68,9 @@ Resources:
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudera-hive:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudera-hive:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Cloudera Hive using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-cloudera-impala/athena-cloudera-impala.yaml b/athena-cloudera-impala/athena-cloudera-impala.yaml
index 60dc37ed9e..adf96ece5b 100644
--- a/athena-cloudera-impala/athena-cloudera-impala.yaml
+++ b/athena-cloudera-impala/athena-cloudera-impala.yaml
@@ -59,6 +59,8 @@ Parameters:
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   HasLambdaEncryptionKmsKeyARN: !Not [ !Equals [ !Ref LambdaEncryptionKmsKeyARN, "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -71,7 +73,9 @@ Resources:
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudera-impala:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudera-impala:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon 
Athena to communicate with Cloudera Impala using JDBC" Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory diff --git a/athena-cloudwatch-metrics/athena-cloudwatch-metrics.yaml b/athena-cloudwatch-metrics/athena-cloudwatch-metrics.yaml index 974b979e37..5f4ce7585c 100644 --- a/athena-cloudwatch-metrics/athena-cloudwatch-metrics.yaml +++ b/athena-cloudwatch-metrics/athena-cloudwatch-metrics.yaml @@ -42,6 +42,8 @@ Parameters: Type: String Conditions: HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ] + IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"] + IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"] Resources: ConnectorConfig: Type: 'AWS::Serverless::Function' @@ -53,7 +55,9 @@ Resources: spill_prefix: !Ref SpillPrefix FunctionName: !Ref AthenaCatalogName PackageType: "Image" - ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudwatch-metrics:2022.47.1' + ImageUri: !Sub + - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudwatch-metrics:2022.47.1' + - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]] Description: "Enables Amazon Athena to communicate with Cloudwatch Metrics, making your metrics data accessible via SQL" Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory diff --git a/athena-cloudwatch/athena-cloudwatch.yaml b/athena-cloudwatch/athena-cloudwatch.yaml index 2e301dc882..67e286740f 100644 --- a/athena-cloudwatch/athena-cloudwatch.yaml +++ b/athena-cloudwatch/athena-cloudwatch.yaml @@ -54,7 +54,8 @@ Conditions: NotHasLambdaRole: !Equals [!Ref LambdaRole, ""] HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ] CreateKMSPolicy: !And [ !Condition HasKMSKeyId, !Condition NotHasLambdaRole ] - + IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"] + IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"] Resources: ConnectorConfig: Type: 'AWS::Serverless::Function' @@ -67,7 +68,9 @@ Resources: kms_key_id: !If [HasKMSKeyId, !Ref KMSKeyId, !Ref "AWS::NoValue"] FunctionName: !Ref AthenaCatalogName PackageType: "Image" - ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudwatch:2022.47.1' + ImageUri: !Sub + - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-cloudwatch:2022.47.1' + - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]] Description: "Enables Amazon Athena to communicate with Cloudwatch, making your log accessible via SQL" Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory diff --git a/athena-datalakegen2/athena-datalakegen2.yaml b/athena-datalakegen2/athena-datalakegen2.yaml index 5890402513..96750e019c 100644 --- a/athena-datalakegen2/athena-datalakegen2.yaml +++ b/athena-datalakegen2/athena-datalakegen2.yaml @@ -60,6 +60,8 @@ Conditions: HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ] HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ] HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ] + IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"] + IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"] Resources: JdbcConnectorConfig: Type: 'AWS::Serverless::Function' @@ -72,7 +74,9 @@ Resources: default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName PackageType: "Image" - ImageUri: !Sub 
'292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-datalakegen2:2022.47.1' + ImageUri: !Sub + - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-datalakegen2:2022.47.1' + - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]] Description: "Enables Amazon Athena to communicate with DataLake Gen2 using JDBC" Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory diff --git a/athena-db2-as400/athena-db2-as400.yaml b/athena-db2-as400/athena-db2-as400.yaml index ea0a331051..1b6cf39bf1 100644 --- a/athena-db2-as400/athena-db2-as400.yaml +++ b/athena-db2-as400/athena-db2-as400.yaml @@ -61,6 +61,8 @@ Conditions: HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ] HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ] HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ] + IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"] + IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"] Resources: JdbcConnectorConfig: Type: 'AWS::Serverless::Function' @@ -73,7 +75,9 @@ Resources: default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName PackageType: "Image" - ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-db2-as400:2022.47.1' + ImageUri: !Sub + - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-db2-as400:2022.47.1' + - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]] Description: "Enables Amazon Athena to communicate with DB2 on iSeries (AS400) using JDBC" Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory diff --git a/athena-db2/athena-db2.yaml b/athena-db2/athena-db2.yaml index 7508f16712..d82d9585f4 100644 --- a/athena-db2/athena-db2.yaml +++ b/athena-db2/athena-db2.yaml @@ -61,6 +61,8 @@ Conditions: HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ] HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ] HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ] + IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"] + IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"] Resources: JdbcConnectorConfig: Type: 'AWS::Serverless::Function' @@ -73,7 +75,9 @@ Resources: default: !Ref DefaultConnectionString FunctionName: !Ref LambdaFunctionName PackageType: "Image" - ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-db2:2022.47.1' + ImageUri: !Sub + - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-db2:2022.47.1' + - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]] Description: "Enables Amazon Athena to communicate with DB2 using JDBC" Timeout: !Ref LambdaTimeout MemorySize: !Ref LambdaMemory diff --git a/athena-docdb/athena-docdb.yaml b/athena-docdb/athena-docdb.yaml index 588b05f52e..5b8a91261a 100644 --- a/athena-docdb/athena-docdb.yaml +++ b/athena-docdb/athena-docdb.yaml @@ -55,6 +55,8 @@ Parameters: Type: String Conditions: HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ] + IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"] + IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"] Resources: ConnectorConfig: Type: 'AWS::Serverless::Function' @@ -67,7 +69,9 @@ Resources: default_docdb: !Ref DocDBConnectionString FunctionName: !Ref AthenaCatalogName PackageType: "Image" - ImageUri: !Sub 
'292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-docdb:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-docdb:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with DocumentDB, making your DocumentDB data accessible via SQL."
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-dynamodb/athena-dynamodb.yaml b/athena-dynamodb/athena-dynamodb.yaml
index f44ac89665..7793c452bd 100644
--- a/athena-dynamodb/athena-dynamodb.yaml
+++ b/athena-dynamodb/athena-dynamodb.yaml
@@ -54,7 +54,8 @@ Conditions:
   NotHasLambdaRole: !Equals [!Ref LambdaRole, ""]
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   CreateKMSPolicy: !And [!Condition HasKMSKeyId, !Condition NotHasLambdaRole]
-
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   ConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -67,7 +68,9 @@ Resources:
           kms_key_id: !If [HasKMSKeyId, !Ref KMSKeyId, !Ref "AWS::NoValue"]
       FunctionName: !Ref AthenaCatalogName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-dynamodb:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-dynamodb:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with DynamoDB, making your tables accessible via SQL"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-elasticsearch/athena-elasticsearch.yaml b/athena-elasticsearch/athena-elasticsearch.yaml
index c2e6603bf9..7a57683d0c 100644
--- a/athena-elasticsearch/athena-elasticsearch.yaml
+++ b/athena-elasticsearch/athena-elasticsearch.yaml
@@ -86,7 +86,8 @@ Parameters:
 Conditions:
   IsVPCAccessSelected: !Equals [!Ref IsVPCAccess, true]
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
-
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   ConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -103,7 +104,9 @@ Resources:
           query_scroll_timeout: !Ref QueryScrollTimeout
       FunctionName: !Sub "${AthenaCatalogName}"
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-elasticsearch:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-elasticsearch:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "The Elasticsearch Lambda Connector provides Athena users the ability to query data stored on Elasticsearch clusters."
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-example/athena-example.yaml b/athena-example/athena-example.yaml
index 8f019191ea..19230b92a6 100644
--- a/athena-example/athena-example.yaml
+++ b/athena-example/athena-example.yaml
@@ -42,7 +42,6 @@ Parameters:
     Description: "WARNING: If set to 'true' encryption for spilled data is disabled."
     Default: "false"
     Type: String
-
 Resources:
   ConnectorConfig:
     Type: 'AWS::Serverless::Function'
diff --git a/athena-gcs/athena-gcs.yaml b/athena-gcs/athena-gcs.yaml
index fa97bc7f86..80ce5a6cbe 100644
--- a/athena-gcs/athena-gcs.yaml
+++ b/athena-gcs/athena-gcs.yaml
@@ -47,7 +47,8 @@ Parameters:
 
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
-
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   AthenaGCSConnector:
     Type: 'AWS::Serverless::Function'
@@ -60,7 +61,9 @@ Resources:
           secret_manager_gcp_creds_name: !Ref GCSSecretName
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-gcs:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-gcs:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Amazon Athena GCS Connector"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-google-bigquery/athena-google-bigquery.yaml b/athena-google-bigquery/athena-google-bigquery.yaml
index 6cdf0cb299..42acf445f5 100644
--- a/athena-google-bigquery/athena-google-bigquery.yaml
+++ b/athena-google-bigquery/athena-google-bigquery.yaml
@@ -65,6 +65,8 @@ Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   AthenaBigQueryConnector:
     Type: 'AWS::Serverless::Function'
@@ -80,7 +82,9 @@ Resources:
           GOOGLE_APPLICATION_CREDENTIALS: '/tmp/service-account.json'
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-google-bigquery:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-google-bigquery:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with BigQuery using Google SDK"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-hbase/athena-hbase.yaml b/athena-hbase/athena-hbase.yaml
index c9d70a24e0..ba51dfd827 100644
--- a/athena-hbase/athena-hbase.yaml
+++ b/athena-hbase/athena-hbase.yaml
@@ -70,6 +70,8 @@ Parameters:
     Type: String
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   ConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -86,7 +88,9 @@ Resources:
           hbase_rpc_protection: !Ref HbaseRpcProtection
       FunctionName: !Ref AthenaCatalogName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-hbase:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-hbase:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with HBase, making your HBase data accessible via SQL"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-hortonworks-hive/athena-hortonworks-hive.yaml b/athena-hortonworks-hive/athena-hortonworks-hive.yaml
index 8f941be498..8c1ac3a176 100644
--- a/athena-hortonworks-hive/athena-hortonworks-hive.yaml
+++ b/athena-hortonworks-hive/athena-hortonworks-hive.yaml
@@ -58,6 +58,8 @@ Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -70,7 +72,9 @@ Resources:
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-hortonworks-hive:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-hortonworks-hive:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Hortonworks Hive using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-kafka/athena-kafka.yaml b/athena-kafka/athena-kafka.yaml
index 27fd31b9c2..9606527e64 100644
--- a/athena-kafka/athena-kafka.yaml
+++ b/athena-kafka/athena-kafka.yaml
@@ -85,7 +85,8 @@ Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
-
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   AthenaKafkaConnector:
     Type: 'AWS::Serverless::Function'
@@ -102,7 +103,9 @@ Resources:
           auth_type: !Ref AuthType
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-kafka:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-kafka:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Kafka clusters"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-msk/athena-msk.yaml b/athena-msk/athena-msk.yaml
index 4921ad4d07..ab06956c09 100644
--- a/athena-msk/athena-msk.yaml
+++ b/athena-msk/athena-msk.yaml
@@ -81,7 +81,8 @@ Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
-
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   AthenaMSKConnector:
     Type: 'AWS::Serverless::Function'
@@ -97,7 +98,9 @@ Resources:
           auth_type: !Ref AuthType
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-msk:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-msk:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with MSK clusters"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-mysql/athena-mysql.yaml b/athena-mysql/athena-mysql.yaml
index 145745b880..2b0d305c3e 100644
--- a/athena-mysql/athena-mysql.yaml
+++ b/athena-mysql/athena-mysql.yaml
@@ -59,6 +59,8 @@ Parameters:
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   NotHasLambdaRole: !Equals [!Ref LambdaRoleARN, ""]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -71,7 +73,9 @@ Resources:
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-mysql:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-mysql:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with MySQL using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-neptune/athena-neptune.yaml b/athena-neptune/athena-neptune.yaml
index 114314291a..a468eeffb3 100644
--- a/athena-neptune/athena-neptune.yaml
+++ b/athena-neptune/athena-neptune.yaml
@@ -77,7 +77,8 @@ Parameters:
 
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
-
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   ConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -97,7 +98,9 @@ Resources:
           enable_caseinsensitivematch: !Ref EnableCaseInsensitiveMatch
       FunctionName: !Ref AthenaCatalogName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-neptune:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-neptune:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Neptune, making your Neptune graph data accessible via SQL."
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-oracle/athena-oracle.yaml b/athena-oracle/athena-oracle.yaml
index e086cf82cb..58d8329dc6 100644
--- a/athena-oracle/athena-oracle.yaml
+++ b/athena-oracle/athena-oracle.yaml
@@ -70,6 +70,8 @@ Conditions:
   NotHasLambdaRole: !Equals [!Ref LambdaRoleARN, ""]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -83,7 +85,9 @@ Resources:
           is_FIPS_Enabled: !Ref IsFIPSEnabled
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-oracle:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-oracle:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with ORACLE using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-postgresql/athena-postgresql.yaml b/athena-postgresql/athena-postgresql.yaml
index 30553623ec..1752d485e9 100644
--- a/athena-postgresql/athena-postgresql.yaml
+++ b/athena-postgresql/athena-postgresql.yaml
@@ -69,6 +69,8 @@ Parameters:
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   NotHasLambdaRole: !Equals [!Ref LambdaRoleARN, ""]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -82,7 +84,9 @@ Resources:
           default_scale: !Ref DefaultScale
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-postgresql:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-postgresql:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       ImageConfig:
         Command: [ !Sub "com.amazonaws.athena.connectors.postgresql.${CompositeHandler}" ]
       Description: "Enables Amazon Athena to communicate with PostgreSQL using JDBC"
diff --git a/athena-redis/athena-redis.yaml b/athena-redis/athena-redis.yaml
index c3bc541752..5d18c7ced1 100644
--- a/athena-redis/athena-redis.yaml
+++ b/athena-redis/athena-redis.yaml
@@ -67,6 +67,8 @@ Parameters:
     Type: Number
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   ConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -82,7 +84,9 @@ Resources:
           qpt_db_number: !Ref QPTConnectionDBNumber
       FunctionName: !Ref AthenaCatalogName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-redis:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-redis:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Redis, making your Redis data accessible via SQL"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-redshift/athena-redshift.yaml b/athena-redshift/athena-redshift.yaml
index 47cd238f89..58686808d4 100644
--- a/athena-redshift/athena-redshift.yaml
+++ b/athena-redshift/athena-redshift.yaml
@@ -67,6 +67,8 @@ Conditions:
     - !Condition NotHasLambdaRole
     - !Condition HasKMSKeyId
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -80,7 +82,9 @@ Resources:
          kms_key_id: !If [HasKMSKeyId, !Ref KMSKeyId, !Ref "AWS::NoValue"]
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-redshift:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-redshift:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Redshift using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-saphana/athena-saphana.yaml b/athena-saphana/athena-saphana.yaml
index 5a1d895933..201816368e 100644
--- a/athena-saphana/athena-saphana.yaml
+++ b/athena-saphana/athena-saphana.yaml
@@ -58,6 +58,8 @@ Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -70,7 +72,9 @@ Resources:
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-saphana:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-saphana:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Teradata using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-snowflake/athena-snowflake.yaml b/athena-snowflake/athena-snowflake.yaml
index 67bac6e7aa..0646c8a5e6 100644
--- a/athena-snowflake/athena-snowflake.yaml
+++ b/athena-snowflake/athena-snowflake.yaml
@@ -58,6 +58,8 @@ Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -70,7 +72,9 @@ Resources:
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-snowflake:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-snowflake:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Snowflake using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-sqlserver/athena-sqlserver.yaml b/athena-sqlserver/athena-sqlserver.yaml
index 9f09edfc88..8edbf4e082 100644
--- a/athena-sqlserver/athena-sqlserver.yaml
+++ b/athena-sqlserver/athena-sqlserver.yaml
@@ -65,6 +65,8 @@ Conditions:
   NotHasLambdaRole: !Equals [!Ref LambdaRoleARN, ""]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -77,7 +79,9 @@ Resources:
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-sqlserver:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-sqlserver:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with SQLSERVER using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-synapse/athena-synapse.yaml b/athena-synapse/athena-synapse.yaml
index 1aa43f00be..50f1092c7c 100644
--- a/athena-synapse/athena-synapse.yaml
+++ b/athena-synapse/athena-synapse.yaml
@@ -66,7 +66,8 @@ Conditions:
   HasPermissionsBoundary: !Not [!Equals [!Ref PermissionsBoundaryARN, ""]]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
-
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -79,7 +80,9 @@ Resources:
           default: !Ref DefaultConnectionString
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-synapse:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-synapse:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with SYNPASE using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-teradata/athena-teradata.yaml b/athena-teradata/athena-teradata.yaml
index 01ee09c1db..14177fa635 100644
--- a/athena-teradata/athena-teradata.yaml
+++ b/athena-teradata/athena-teradata.yaml
@@ -62,6 +62,8 @@ Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
   HasSecurityGroups: !Not [ !Equals [ !Join ["", !Ref SecurityGroupIds], "" ] ]
   HasSubnets: !Not [ !Equals [ !Join ["", !Ref SubnetIds], "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   JdbcConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -75,7 +77,9 @@ Resources:
           partitioncount: !Ref PartitionCount
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-teradata:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-teradata:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Teradata using JDBC"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-timestream/athena-timestream.yaml b/athena-timestream/athena-timestream.yaml
index 1850ffbecc..2e0387ce57 100644
--- a/athena-timestream/athena-timestream.yaml
+++ b/athena-timestream/athena-timestream.yaml
@@ -42,6 +42,8 @@ Parameters:
     Type: String
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   ConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -53,7 +55,9 @@ Resources:
          spill_prefix: !Ref SpillPrefix
       FunctionName: !Ref AthenaCatalogName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-timestream:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-timestream:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Enables Amazon Athena to communicate with Amazon Timestream, making your time series data accessible from Athena."
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-tpcds/athena-tpcds.yaml b/athena-tpcds/athena-tpcds.yaml
index 4086565f7e..ed88425b11 100644
--- a/athena-tpcds/athena-tpcds.yaml
+++ b/athena-tpcds/athena-tpcds.yaml
@@ -42,6 +42,8 @@ Parameters:
     Type: String
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   ConnectorConfig:
     Type: 'AWS::Serverless::Function'
@@ -53,7 +55,9 @@ Resources:
          spill_prefix: !Ref SpillPrefix
       FunctionName: !Ref AthenaCatalogName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-tpcds:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-tpcds:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "This connector enables Amazon Athena to communicate with a randomly generated TPC-DS data source."
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-udfs/athena-udfs.yaml b/athena-udfs/athena-udfs.yaml
index a6f04ad182..facaba9bcb 100644
--- a/athena-udfs/athena-udfs.yaml
+++ b/athena-udfs/athena-udfs.yaml
@@ -34,13 +34,17 @@ Parameters:
     Type: String
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   ConnectorConfig:
     Type: 'AWS::Serverless::Function'
     Properties:
       FunctionName: !Ref LambdaFunctionName
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-udfs:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-udfs:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "This connector enables Amazon Athena to leverage common UDFs made available via Lambda."
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
diff --git a/athena-vertica/athena-vertica.yaml b/athena-vertica/athena-vertica.yaml
index a0fed75e01..60775db042 100644
--- a/athena-vertica/athena-vertica.yaml
+++ b/athena-vertica/athena-vertica.yaml
@@ -63,7 +63,8 @@ Parameters:
 
 Conditions:
   HasPermissionsBoundary: !Not [ !Equals [ !Ref PermissionsBoundaryARN, "" ] ]
-
+  IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]
+  IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]
 Resources:
   LambdaSecurityGroup:
     Type: 'AWS::EC2::SecurityGroup'
@@ -83,7 +84,9 @@ Resources:
 
       FunctionName: !Sub "${AthenaCatalogName}"
       PackageType: "Image"
-      ImageUri: !Sub '292517598671.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-vertica:2022.47.1'
+      ImageUri: !Sub
+        - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-vertica:2022.47.1'
+        - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]
       Description: "Amazon Athena Vertica Connector"
       Timeout: !Ref LambdaTimeout
       MemorySize: !Ref LambdaMemory
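
The same mechanical change is applied to every template above, so one worked example may help when reviewing: the ECR registry account now comes from a nested Fn::If, and a computed value like that cannot be referenced from the single-string form of Fn::Sub, so each ImageUri switches to the two-element list form of !Sub, whose second element is a map of named substitution variables. A minimal sketch of how the pieces fit together (the connector name "example" is hypothetical; the condition names and account IDs are exactly those added by this patch):

    Conditions:
      IsRegionBAH: !Equals [!Ref "AWS::Region", "me-south-1"]   # Bahrain
      IsRegionHKG: !Equals [!Ref "AWS::Region", "ap-east-1"]    # Hong Kong
    Resources:
      ConnectorConfig:
        Type: 'AWS::Serverless::Function'
        Properties:
          PackageType: "Image"
          # ${Account} is supplied by the substitution map below;
          # ${AWS::Region} is the usual pseudo parameter.
          ImageUri: !Sub
            - '${Account}.dkr.ecr.${AWS::Region}.amazonaws.com/athena-federation-repository-example:2022.47.1'
            - Account: !If [IsRegionBAH, 084828588479, !If [IsRegionHKG, 183295418215, 292517598671]]

In me-south-1 this resolves to 084828588479.dkr.ecr.me-south-1.amazonaws.com/athena-federation-repository-example:2022.47.1, in ap-east-1 to the 183295418215 registry, and in every other region it falls through to the original 292517598671 registry, so behavior outside the two newly supported regions is unchanged.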